diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 0000000000..dfa6228492 --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,32 @@ +{ + "name": "dimos-dev", + "image": "ghcr.io/dimensionalos/dev:dev", + "customizations": { + "vscode": { + "extensions": [ + "charliermarsh.ruff", + "ms-python.vscode-pylance" + ] + } + }, + "containerEnv": { + "PYTHONPATH": "${localEnv:PYTHONPATH}:/workspaces/dimos" + }, + "postCreateCommand": "git config --global --add safe.directory /workspaces/dimos && cd /workspaces/dimos && pre-commit install", + "settings": { + "notebook.formatOnSave.enabled": true, + "notebook.codeActionsOnSave": { + "notebook.source.fixAll": "explicit", + "notebook.source.organizeImports": "explicit" + }, + "editor.codeActionsOnSave": { + "source.fixAll": "explicit", + "source.organizeImports": "explicit" + }, + "editor.defaultFormatter": "charliermarsh.ruff", + "editor.formatOnSave": true + }, + "runArgs": [ + "--cap-add=NET_ADMIN" + ] +} diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000000..72d14322f1 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,109 @@ +# Version control +.git +.gitignore +.github/ + +# Editor and IDE files +.vscode +.idea +*.swp +*.swo +.cursor/ +.cursorignore + +# Shell history +.bash_history +.zsh_history +.history + +# Python virtual environments +**/venv/ +**/.venv/ +**/env/ +**/.env/ +**/*-venv/ +**/*_venv/ +**/ENV/ + + +# Python build artifacts +__pycache__/ +*.pyc +*.pyo +*.pyd +.Python +*.egg-info/ +dist/ +build/ +*.so +*.dylib + +# Environment file +.env +.env.local +.env.*.local + +# Large data files +data/* +!data/.lfs/ + +# Model files (can be downloaded at runtime) +*.pt +*.pth +*.onnx +*.pb +*.h5 +*.ckpt +*.safetensors +checkpoints/ +assets/model-cache + +# Logs +*.log + +# Large media files (not needed for functionality) +*.png +*.jpg +*.jpeg +*.gif +*.mp4 +*.mov +*.avi +*.mkv +*.webm +*.MOV + +# Large font files +*.ttf +*.otf + +# Node modules (for dev tools, not needed in container) +node_modules/ +package-lock.json +package.json +bin/node_modules/ + +# Database files +*.db +*.sqlite +*.sqlite3 + +# OS generated files +.DS_Store +.DS_Store? +._* +.Spotlight-V100 +.Trashes +ehthumbs.db +Thumbs.db + +# Temporary files +tmp/ +temp/ +*.tmp +.python-version + +# Exclude all assets subdirectories +assets/*/* +!assets/agent/prompt.txt +!assets/* diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000000..6644370b86 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,27 @@ +# top-most EditorConfig file +root = true + +[*] +indent_style = space +indent_size = 2 +end_of_line = lf +charset = utf-8 +trim_trailing_whitespace = true +insert_final_newline = true + +[*.md] +indent_size = 4 + +[*.nix] +indent_size = 2 + +[*.{py,ipynb}] +indent_size = 4 +max_line_length = 100 + +[*.rs] +indent_style = space +indent_size = 4 + +[*.{ts,svelte}] +indent_size = 2 diff --git a/.envrc.nix b/.envrc.nix new file mode 100644 index 0000000000..a3f663db80 --- /dev/null +++ b/.envrc.nix @@ -0,0 +1,11 @@ +if ! has nix_direnv_version || ! nix_direnv_version 3.0.6; then + source_url "https://raw.githubusercontent.com/nix-community/nix-direnv/3.0.6/direnvrc" "sha256-RYcUJaRMf8oF5LznDrlCXbkOQrywm0HDv1VjYGaJGdM=" +fi +use flake . +for venv in venv .venv env; do + if [[ -f "$venv/bin/activate" ]]; then + source "$venv/bin/activate" + break + fi +done +dotenv_if_exists diff --git a/.envrc.venv b/.envrc.venv new file mode 100644 index 0000000000..e315a030c7 --- /dev/null +++ b/.envrc.venv @@ -0,0 +1,7 @@ +for venv in venv .venv env; do + if [[ -f "$venv/bin/activate" ]]; then + source "$venv/bin/activate" + break + fi +done +dotenv_if_exists diff --git a/.gitattributes b/.gitattributes index a81891f57a..302cb2e191 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,3 +1,16 @@ -* text=auto +# Handle line endings automatically for files Git considers text, +# converting them to LF on checkout. +* text=auto eol=lf +# Ensure Python files always use LF for line endings. *.py text eol=lf - +# Treat designated file types as binary and do not alter their contents or line endings. +*.png binary +*.jpg binary +*.ico binary +*.pdf binary +# Explicit LFS tracking for test files +/data/.lfs/*.tar.gz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text binary +*.mp4 filter=lfs diff=lfs merge=lfs -text binary +*.mov filter=lfs diff=lfs merge=lfs -text binary +*.gif filter=lfs diff=lfs merge=lfs -text binary diff --git a/.github/actions/docker-build/action.yml b/.github/actions/docker-build/action.yml new file mode 100644 index 0000000000..a538ad35fd --- /dev/null +++ b/.github/actions/docker-build/action.yml @@ -0,0 +1,59 @@ +name: docker-build +description: "Composite action to build and push a Docker target to GHCR" +inputs: + target: + description: "Dockerfile target stage to build" + required: true + tag: + description: "Image tag to push" + required: true + freespace: + description: "Remove large pre‑installed SDKs before building to free space" + required: false + default: "false" + context: + description: "Docker build context" + required: false + default: "." + +runs: + using: "composite" + steps: + - name: Free up disk space + if: ${{ inputs.freespace == 'true' }} + shell: bash + run: | + echo -e "pre cleanup space:\n $(df -h)" + sudo rm -rf /opt/ghc + sudo rm -rf /usr/share/dotnet + sudo rm -rf /usr/local/share/boost + sudo rm -rf /usr/local/lib/android + echo -e "post cleanup space:\n $(df -h)" + + - uses: actions/checkout@v4 + + - uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ github.token }} + + - uses: crazy-max/ghaction-github-runtime@v3 + + - uses: docker/setup-buildx-action@v3 + with: + driver: docker-container + install: true + use: true + + - name: Build & Push ${{ inputs.target }} + uses: docker/build-push-action@v6 + with: + push: true + context: ${{ inputs.context }} + file: docker/${{ inputs.target }}/Dockerfile + tags: ghcr.io/dimensionalos/${{ inputs.target }}:${{ inputs.tag }} + cache-from: type=gha,scope=${{ inputs.target }} + cache-to: type=gha,mode=max,scope=${{ inputs.target }} + build-args: | + FROM_TAG=${{ inputs.tag }} diff --git a/.github/workflows/_docker-build-template.yml b/.github/workflows/_docker-build-template.yml new file mode 100644 index 0000000000..478a9bec84 --- /dev/null +++ b/.github/workflows/_docker-build-template.yml @@ -0,0 +1,149 @@ +name: docker-build-template +on: + workflow_call: + inputs: + from-image: { type: string, required: true } + to-image: { type: string, required: true } + dockerfile: { type: string, required: true } + freespace: { type: boolean, default: true } + should-run: { type: boolean, default: false } + context: { type: string, default: '.' } + +# you can run this locally as well via +# ./bin/dockerbuild [image-name] +jobs: + build: + runs-on: [self-hosted, Linux] + permissions: + contents: read + packages: write + + steps: + - name: Fix permissions + if: ${{ inputs.should-run }} + run: | + sudo chown -R $USER:$USER ${{ github.workspace }} || true + + - uses: actions/checkout@v4 + if: ${{ inputs.should-run }} + with: + fetch-depth: 0 + + - name: free up disk space + # explicitly enable this for large builds + if: ${{ inputs.should-run && inputs.freespace }} + run: | + echo -e "pre cleanup space:\n $(df -h)" + sudo rm -rf /opt/ghc + sudo rm -rf /usr/share/dotnet + sudo rm -rf /usr/local/share/boost + sudo rm -rf /usr/local/lib/android + + echo "=== Cleaning images from deleted branches ===" + + # Get list of all remote branches + git ls-remote --heads origin | awk '{print $2}' | sed 's|refs/heads/||' > /tmp/active_branches.txt + + # Check each docker image tag against branch list + docker images --format "{{.Repository}}:{{.Tag}}|{{.ID}}" | \ + grep "ghcr.io/dimensionalos" | \ + grep -v ":" | \ + while IFS='|' read image_ref id; do + tag=$(echo "$image_ref" | cut -d: -f2) + + # Skip if tag matches an active branch + if grep -qx "$tag" /tmp/active_branches.txt; then + echo "Branch exists: $tag - keeping $image_ref" + else + echo "Branch deleted: $tag - removing $image_ref" + docker rmi "$id" 2>/dev/null || true + fi + done + + rm -f /tmp/active_branches.txt + + USAGE=$(df / | awk 'NR==2 {print $5}' | sed 's/%//') + echo "Pre-docker-cleanup disk usage: ${USAGE}%" + + if [ $USAGE -gt 60 ]; then + echo "=== Running quick cleanup (usage > 60%) ===" + + # Keep newest image per tag + docker images --format "{{.Repository}}|{{.Tag}}|{{.ID}}" | \ + grep "ghcr.io/dimensionalos" | \ + grep -v "" | \ + while IFS='|' read repo tag id; do + created_ts=$(docker inspect -f '{{.Created}}' "$id" 2>/dev/null) + created_unix=$(date -d "$created_ts" +%s 2>/dev/null || echo "0") + echo "${repo}|${tag}|${id}|${created_unix}" + done | sort -t'|' -k1,1 -k2,2 -k4,4nr | \ + awk -F'|' ' + { + repo=$1; tag=$2; id=$3 + repo_tag = repo ":" tag + + # Skip protected tags + if (tag ~ /^(main|dev|latest)$/) next + + # Keep newest per tag, remove older duplicates + if (!(repo_tag in seen_combos)) { + seen_combos[repo_tag] = 1 + } else { + system("docker rmi " id " 2>/dev/null || true") + } + }' + + docker image prune -f + docker volume prune -f + fi + + # Aggressive cleanup if still above 85% + USAGE=$(df / | awk 'NR==2 {print $5}' | sed 's/%//') + if [ $USAGE -gt 85 ]; then + echo "=== AGGRESSIVE cleanup (usage > 85%) - removing all except main/dev ===" + + # Remove ALL images except main and dev tags + docker images --format "{{.Repository}}:{{.Tag}} {{.ID}}" | \ + grep -E "ghcr.io/dimensionalos" | \ + grep -vE ":(main|dev)$" | \ + awk '{print $2}' | xargs -r docker rmi -f || true + + docker container prune -f + docker volume prune -a -f + docker network prune -f + docker image prune -f + fi + + echo -e "post cleanup space:\n $(df -h)" + + - uses: docker/login-action@v3 + if: ${{ inputs.should-run }} + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + # required for github cache of docker layers + - uses: crazy-max/ghaction-github-runtime@v3 + if: ${{ inputs.should-run }} + + # required for github cache of docker layers + - uses: docker/setup-buildx-action@v3 + if: ${{ inputs.should-run }} + with: + driver: docker-container + install: true + use: true + + - uses: docker/build-push-action@v6 + if: ${{ inputs.should-run }} + with: + push: true + context: ${{ inputs.context }} + file: docker/${{ inputs.dockerfile }}/Dockerfile + tags: ${{ inputs.to-image }} + cache-from: type=gha,scope=${{ inputs.dockerfile }} + cache-to: type=gha,mode=max,scope=${{ inputs.dockerfile }} + #cache-from: type=gha,scope=${{ inputs.dockerfile }}-${{ inputs.from-image }} + #cache-to: type=gha,mode=max,scope=${{ inputs.dockerfile }}-${{ inputs.from-image }} + build-args: FROM_IMAGE=${{ inputs.from-image }} diff --git a/.github/workflows/code-cleanup.yml b/.github/workflows/code-cleanup.yml new file mode 100644 index 0000000000..d9526207a6 --- /dev/null +++ b/.github/workflows/code-cleanup.yml @@ -0,0 +1,36 @@ +name: code-cleanup +on: + push: + paths-ignore: + - '**.md' + +permissions: + contents: write + packages: write + pull-requests: read + +jobs: + pre-commit: + runs-on: self-hosted + steps: + - name: Fix permissions + run: | + sudo chown -R $USER:$USER ${{ github.workspace }} || true + + - uses: actions/checkout@v3 + - uses: actions/setup-python@v3 + - name: Run pre-commit + id: pre-commit-first + uses: pre-commit/action@v3.0.1 + continue-on-error: true + + - name: Re-run pre-commit if failed initially + id: pre-commit-retry + if: steps.pre-commit-first.outcome == 'failure' + uses: pre-commit/action@v3.0.1 + continue-on-error: false + + - name: Commit code changes + uses: stefanzweifel/git-auto-commit-action@v5 + with: + commit_message: "CI code cleanup" diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml new file mode 100644 index 0000000000..0b57d5ba22 --- /dev/null +++ b/.github/workflows/docker.yml @@ -0,0 +1,245 @@ +name: docker +on: + push: + branches: + - main + - dev + paths-ignore: + - '**.md' + pull_request: + paths-ignore: + - '**.md' + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true + +permissions: + contents: read + packages: write + pull-requests: read + +jobs: + check-changes: + runs-on: [self-hosted, Linux] + outputs: + ros: ${{ steps.filter.outputs.ros }} + python: ${{ steps.filter.outputs.python }} + dev: ${{ steps.filter.outputs.dev }} + tests: ${{ steps.filter.outputs.tests }} + branch-tag: ${{ steps.set-tag.outputs.branch_tag }} + steps: + - name: Fix permissions + run: | + sudo chown -R $USER:$USER ${{ github.workspace }} || true + + - uses: actions/checkout@v4 + - id: filter + uses: dorny/paths-filter@v3 + with: + base: ${{ github.event.before }} + filters: | + # ros and python are (alternative) root images + # change to root stuff like docker.yml etc triggers rebuild of those + # which cascades into a full rebuild + ros: + - .github/workflows/_docker-build-template.yml + - .github/workflows/docker.yml + - docker/ros/** + + python: + - .github/workflows/_docker-build-template.yml + - .github/workflows/docker.yml + - docker/python/** + - pyproject.toml + + dev: + - docker/dev/** + + tests: + - dimos/** + + - name: Determine Branch Tag + id: set-tag + run: | + case "${GITHUB_REF_NAME}" in + main) branch_tag="latest" ;; + dev) branch_tag="dev" ;; + *) + branch_tag=$(echo "${GITHUB_REF_NAME}" \ + | tr '[:upper:]' '[:lower:]' \ + | sed -E 's#[^a-z0-9_.-]+#_#g' \ + | sed -E 's#^-+|-+$##g') + ;; + esac + echo "branch tag determined: ${branch_tag}" + echo branch_tag="${branch_tag}" >> "$GITHUB_OUTPUT" + + # just a debugger + inspect-needs: + needs: [check-changes, ros] + runs-on: dimos-runner-ubuntu-2204 + if: always() + steps: + - run: | + echo '${{ toJSON(needs) }}' + + ros: + needs: [check-changes] + if: needs.check-changes.outputs.ros == 'true' + uses: ./.github/workflows/_docker-build-template.yml + with: + should-run: true + from-image: ubuntu:22.04 + to-image: ghcr.io/dimensionalos/ros:${{ needs.check-changes.outputs.branch-tag }} + dockerfile: ros + + ros-python: + needs: [check-changes, ros] + if: always() + uses: ./.github/workflows/_docker-build-template.yml + with: + should-run: ${{ + needs.check-changes.outputs.python == 'true' && + needs.check-changes.result != 'error' && + needs.ros.result != 'error' + }} + + from-image: ghcr.io/dimensionalos/ros:${{ needs.ros.result == 'success' && needs.check-changes.outputs.branch-tag || 'dev' }} + to-image: ghcr.io/dimensionalos/ros-python:${{ needs.check-changes.outputs.branch-tag }} + dockerfile: python + + python: + needs: [check-changes] + if: needs.check-changes.outputs.python == 'true' + uses: ./.github/workflows/_docker-build-template.yml + with: + should-run: true + dockerfile: python + from-image: ubuntu:22.04 + to-image: ghcr.io/dimensionalos/python:${{ needs.check-changes.outputs.branch-tag }} + + dev: + needs: [check-changes, python] + if: always() + + uses: ./.github/workflows/_docker-build-template.yml + with: + should-run: ${{ + needs.check-changes.result == 'success' && + ((needs.python.result == 'success') || + (needs.python.result == 'skipped' && + needs.check-changes.outputs.dev == 'true')) }} + from-image: ghcr.io/dimensionalos/python:${{ needs.python.result == 'success' && needs.check-changes.outputs.branch-tag || 'dev' }} + to-image: ghcr.io/dimensionalos/dev:${{ needs.check-changes.outputs.branch-tag }} + dockerfile: dev + + ros-dev: + needs: [check-changes, ros-python] + if: always() + uses: ./.github/workflows/_docker-build-template.yml + with: + should-run: ${{ + needs.check-changes.result == 'success' && + (needs.check-changes.outputs.dev == 'true' || + (needs.ros-python.result == 'success' && (needs.check-changes.outputs.python == 'true' || needs.check-changes.outputs.ros == 'true'))) + }} + from-image: ghcr.io/dimensionalos/ros-python:${{ needs.ros-python.result == 'success' && needs.check-changes.outputs.branch-tag || 'dev' }} + to-image: ghcr.io/dimensionalos/ros-dev:${{ needs.check-changes.outputs.branch-tag }} + dockerfile: dev + + run-ros-tests: + needs: [check-changes, ros-dev] + if: always() + uses: ./.github/workflows/tests.yml + secrets: inherit + with: + should-run: ${{ + needs.check-changes.result == 'success' && + ((needs.ros-dev.result == 'success') || + (needs.ros-dev.result == 'skipped' && + needs.check-changes.outputs.tests == 'true')) + }} + cmd: "pytest && pytest -m ros" # run tests that depend on ros as well + dev-image: ros-dev:${{ (needs.check-changes.outputs.python == 'true' || needs.check-changes.outputs.dev == 'true' || needs.check-changes.outputs.ros == 'true') && needs.ros-dev.result == 'success' && needs.check-changes.outputs.branch-tag || 'dev' }} + + run-tests: + needs: [check-changes, dev] + if: always() + uses: ./.github/workflows/tests.yml + secrets: inherit + with: + should-run: ${{ + needs.check-changes.result == 'success' && + ((needs.dev.result == 'success') || + (needs.dev.result == 'skipped' && + needs.check-changes.outputs.tests == 'true')) + }} + cmd: "pytest" + dev-image: dev:${{ (needs.check-changes.outputs.python == 'true' || needs.check-changes.outputs.dev == 'true') && needs.dev.result == 'success' && needs.check-changes.outputs.branch-tag || 'dev' }} + + # we run in parallel with normal tests for speed + run-heavy-tests: + needs: [check-changes, dev] + if: always() + uses: ./.github/workflows/tests.yml + secrets: inherit + with: + should-run: ${{ + needs.check-changes.result == 'success' && + ((needs.dev.result == 'success') || + (needs.dev.result == 'skipped' && + needs.check-changes.outputs.tests == 'true')) + }} + cmd: "pytest -m heavy" + dev-image: dev:${{ (needs.check-changes.outputs.python == 'true' || needs.check-changes.outputs.dev == 'true') && needs.dev.result == 'success' && needs.check-changes.outputs.branch-tag || 'dev' }} + + run-lcm-tests: + needs: [check-changes, dev] + if: always() + uses: ./.github/workflows/tests.yml + secrets: inherit + with: + should-run: ${{ + needs.check-changes.result == 'success' && + ((needs.dev.result == 'success') || + (needs.dev.result == 'skipped' && + needs.check-changes.outputs.tests == 'true')) + }} + cmd: "pytest -m lcm" + dev-image: dev:${{ (needs.check-changes.outputs.python == 'true' || needs.check-changes.outputs.dev == 'true') && needs.dev.result == 'success' && needs.check-changes.outputs.branch-tag || 'dev' }} + + # Run module tests directly to avoid pytest forking issues + # run-module-tests: + # needs: [check-changes, dev] + # if: ${{ + # always() && + # needs.check-changes.result == 'success' && + # ((needs.dev.result == 'success') || + # (needs.dev.result == 'skipped' && + # needs.check-changes.outputs.tests == 'true')) + # }} + # runs-on: [self-hosted, x64, 16gb] + # container: + # image: ghcr.io/dimensionalos/dev:${{ needs.check-changes.outputs.dev == 'true' && needs.dev.result == 'success' && needs.check-changes.outputs.branch-tag || 'dev' }} + # steps: + # - name: Fix permissions + # run: | + # sudo chown -R $USER:$USER ${{ github.workspace }} || true + # + # - uses: actions/checkout@v4 + # with: + # lfs: true + # + # - name: Configure Git LFS + # run: | + # git config --global --add safe.directory '*' + # git lfs install + # git lfs fetch + # git lfs checkout + # + # - name: Run module tests + # env: + # CI: "true" + # run: | + # /entrypoint.sh bash -c "pytest -m module" diff --git a/.github/workflows/readme.md b/.github/workflows/readme.md new file mode 100644 index 0000000000..f82ba479bb --- /dev/null +++ b/.github/workflows/readme.md @@ -0,0 +1,51 @@ +# general structure of workflows + +Docker.yml checks for releavant file changes and re-builds required images +Currently images have a dependancy chain of ros -> python -> dev (in the future this might be a tree and can fork) + +On top of the dev image then tests are run. +Dev image is also what developers use in their own IDE via devcontainers +https://code.visualstudio.com/docs/devcontainers/containers + +# login to github docker repo + +create personal access token (classic, not fine grained) +https://github.com/settings/tokens + +add permissions +- read:packages scope to download container images and read their metadata. + + and optionally, + +- write:packages scope to download and upload container images and read and write their metadata. +- delete:packages scope to delete container images. + +more info @ https://docs.github.com/en/packages/working-with-a-github-packages-registry/working-with-the-container-registry + +login to docker via + +`sh +echo TOKEN | docker login ghcr.io -u GITHUB_USER --password-stdin +` + +pull dev image (dev branch) +`sh +docker pull ghcr.io/dimensionalos/dev:dev +` + +pull dev image (master) +`sh +docker pull ghcr.io/dimensionalos/dev:latest +` + +# todo + +Currently there is an issue with ensuring both correct docker image build ordering, and skipping unneccessary re-builds. + +(we need job dependancies for builds to wait to their images underneath to be built (for example py waits for ros)) +by default if a parent is skipped, it's children get skipped as well, unless they have always() in their conditional. + +Issue is once we put always() in the conditional, it seems that no matter what other check we put in the same conditional, job will always run. +for this reason we cannot skip python (and above) builds for now. Needs review. + +I think we will need to write our own build dispatcher in python that calls github workflows that build images. diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 0000000000..e84d7d43d2 --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,62 @@ +name: tests + +on: + workflow_call: + inputs: + should-run: + required: false + type: boolean + default: true + dev-image: + required: true + type: string + default: "dev:dev" + cmd: + required: true + type: string + +permissions: + contents: read + packages: read + +jobs: + + # cleanup: + # runs-on: dimos-runner-ubuntu-2204 + # steps: + # - name: exit early + # if: ${{ !inputs.should-run }} + # run: | + # exit 0 + + # - name: Free disk space + # run: | + # sudo rm -rf /opt/ghc + # sudo rm -rf /usr/share/dotnet + # sudo rm -rf /usr/local/share/boost + # sudo rm -rf /usr/local/lib/android + + run-tests: + runs-on: [self-hosted, Linux] + container: + image: ghcr.io/dimensionalos/${{ inputs.dev-image }} + env: + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} + ALIBABA_API_KEY: ${{ secrets.ALIBABA_API_KEY }} + + steps: + - uses: actions/checkout@v4 + + - name: Fix permissions + run: | + git config --global --add safe.directory '*' + + - name: Run tests + run: | + /entrypoint.sh bash -c "${{ inputs.cmd }}" + + - name: check disk space + if: failure() + run: | + df -h diff --git a/.gitignore b/.gitignore index 59da69968c..5a6e7a1e9a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,12 +1,57 @@ -.venv/ .vscode/ # Ignore Python cache files __pycache__/ *.pyc -.venv* -venv* + +# Ignore virtual environment directories +*venv*/ +.venv*/ .ssh/ +# Ignore python tooling dirs +*.egg-info/ +__pycache__ + .env **/.DS_Store + +# Ignore default runtime output folder +/assets/output/ +/assets/rgbd_data/ +/assets/saved_maps/ +/assets/model-cache/ +/assets/agent/memory.txt + +.bash_history + +# Ignore all test data directories but allow compressed files +/data/* +!/data/.lfs/ + +# node env (used by devcontainers cli) +node_modules +package.json +package-lock.json + +# Ignore build artifacts +dist/ +build/ +# docs build +site/ +docs/tutorials/**/tutorial_rendered.html +**/__marimo__/ + +# Ignore data directory but keep .lfs subdirectory +data/* +!data/.lfs/ +FastSAM-x.pt +yolo11n.pt + +/thread_monitor_report.csv + +# symlink one of .envrc.* if you'd like to use +.envrc +.claude + +/logs diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index 53459cea29..0000000000 --- a/.gitmodules +++ /dev/null @@ -1,11 +0,0 @@ -[submodule "dimos/external/colmap"] - path = dimos/external/colmap - url = https://github.com/colmap/colmap - -[submodule "dimos/external/openMVS"] - path = dimos/external/openMVS - url = https://github.com/cdcseacave/openMVS.git - -[submodule "dimos/external/vcpkg"] - path = dimos/external/vcpkg - url = https://github.com/microsoft/vcpkg.git diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000..066b8618bc --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,73 @@ +default_stages: [pre-commit] +exclude: (dimos/models/.*)|(deprecated) +repos: + + - repo: https://github.com/Lucas-C/pre-commit-hooks + rev: v1.5.5 + hooks: + - id: forbid-crlf + - id: remove-crlf + - id: insert-license + files: \.py$ + exclude: __init__\.py$ + args: + # use if you want to remove licences from all files + # (for globally changing wording or something) + #- --remove-header + - --license-filepath + - assets/license_file_header.txt + - --use-current-year + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.14.3 + hooks: + - id: ruff-format + stages: [pre-commit] + - id: ruff-check + args: [--fix, --unsafe-fixes] + + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.6.0 + hooks: + - id: check-case-conflict + - id: trailing-whitespace + language: python + types: [text] + - id: end-of-file-fixer + - id: mixed-line-ending + args: [--fix=lf] + - id: check-json + - id: check-toml + - id: check-yaml + exclude: mkdocs\.yml$ # Because uses PyYAML tags + # TODO: Consider adding this hook: https://github.com/RodrigoGonzalez/check-mkdocs + - id: pretty-format-json + name: format json + args: [ --autofix, --no-sort-keys ] + + - repo: https://github.com/editorconfig-checker/editorconfig-checker.python + rev: 3.4.1 + hooks: + - id: editorconfig-checker + alias: ec + args: [-disable-max-line-length, -disable-indentation] + + # - repo: local + # hooks: + # - id: mypy + # name: Type check + # # possible to also run within the dev image + # #entry: "./bin/dev mypy" + # entry: "./bin/mypy" + # language: python + # additional_dependencies: ["mypy==1.15.0", "numpy>=1.26.4,<2.0.0"] + # types: [python] + + - repo: local + hooks: + - id: lfs_check + name: LFS data + always_run: true + pass_filenames: false + entry: bin/lfs_check + language: script diff --git a/.python-version b/.python-version new file mode 100644 index 0000000000..e4fba21835 --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.12 diff --git a/.style.yapf b/.style.yapf new file mode 100644 index 0000000000..b8d6fb374a --- /dev/null +++ b/.style.yapf @@ -0,0 +1,3 @@ + [style] + based_on_style = google + column_limit = 80 diff --git a/AUTONOMY_STACK_README.md b/AUTONOMY_STACK_README.md new file mode 100644 index 0000000000..4a74500b2f --- /dev/null +++ b/AUTONOMY_STACK_README.md @@ -0,0 +1,284 @@ +# Autonomy Stack API Documentation + +## Prerequisites + +- Ubuntu 24.04 +- [ROS 2 Jazzy Installation](https://docs.ros.org/en/jazzy/Installation.html) + +Add the following line to your `~/.bashrc` to source the ROS 2 Jazzy setup script automatically: + +``` echo "source /opt/ros/jazzy/setup.bash" >> ~/.bashrc``` + +## MID360 Ethernet Configuration (skip for sim) + +### Step 1: Configure Network Interface + +1. Open Network Settings in Ubuntu +2. Find your Ethernet connection to the MID360 +3. Click the gear icon to edit settings +4. Go to IPv4 tab +5. Change Method from "Automatic (DHCP)" to "Manual" +6. Add the following settings: + - **Address**: 192.168.1.5 + - **Netmask**: 255.255.255.0 + - **Gateway**: 192.168.1.1 +7. Click "Apply" + +### Step 2: Configure MID360 IP in JSON + +1. Find your MID360 serial number (on sticker under QR code) +2. Note the last 2 digits (e.g., if serial ends in 89, use 189) +3. Edit the configuration file: + +```bash +cd ~/autonomy_stack_mecanum_wheel_platform +nano src/utilities/livox_ros_driver2/config/MID360_config.json +``` + +4. Update line 28 with your IP (192.168.1.1xx where xx = last 2 digits): + +```json +"ip" : "192.168.1.1xx", +``` + +5. Save and exit + +### Step 3: Verify Connection + +```bash +ping 192.168.1.1xx # Replace xx with your last 2 digits +``` + +## Robot Configuration + +### Setting Robot Type + +The system supports different robot configurations. Set the `ROBOT_CONFIG_PATH` environment variable to specify which robot configuration to use: + +```bash +# For Unitree G1 (default if not set) +export ROBOT_CONFIG_PATH="unitree/unitree_g1" + +# Add to ~/.bashrc to make permanent +echo 'export ROBOT_CONFIG_PATH="unitree/unitree_g1"' >> ~/.bashrc +``` + +Available robot configurations: +- `unitree/unitree_g1` - Unitree G1 robot (default) +- Add your custom robot configs in `src/base_autonomy/local_planner/config/` + +## Build the system + +You must do this every you make a code change, this is not Python + +```colcon build --symlink-install --cmake-args -DCMAKE_BUILD_TYPE=Release``` + +## System Launch + +### Simulation Mode + +```bash +cd ~/autonomy_stack_mecanum_wheel_platform + +# Base autonomy only +./system_simulation.sh + +# With route planner +./system_simulation_with_route_planner.sh + +# With exploration planner +./system_simulation_with_exploration_planner.sh +``` + +### Real Robot Mode + +```bash +cd ~/autonomy_stack_mecanum_wheel_platform + +# Base autonomy only +./system_real_robot.sh + +# With route planner +./system_real_robot_with_route_planner.sh + +# With exploration planner +./system_real_robot_with_exploration_planner.sh +``` + +## Quick Troubleshooting + +- **Cannot ping MID360**: Check Ethernet cable and network settings +- **SLAM drift**: Press clear-terrain-map button on joystick controller +- **Joystick not recognized**: Unplug and replug USB dongle + + +## ROS Topics + +### Input Topics (Commands) + +| Topic | Type | Description | +|-------|------|-------------| +| `/way_point` | `geometry_msgs/PointStamped` | Send navigation goal (position only) | +| `/goal_pose` | `geometry_msgs/PoseStamped` | Send goal with orientation | +| `/cancel_goal` | `std_msgs/Bool` | Cancel current goal (data: true) | +| `/joy` | `sensor_msgs/Joy` | Joystick input | +| `/stop` | `std_msgs/Int8` | Soft Stop (2=stop all commmand, 0 = release) | +| `/navigation_boundary` | `geometry_msgs/PolygonStamped` | Set navigation boundaries | +| `/added_obstacles` | `sensor_msgs/PointCloud2` | Virtual obstacles | + +### Output Topics (Status) + +| Topic | Type | Description | +|-------|------|-------------| +| `/state_estimation` | `nav_msgs/Odometry` | Robot pose from SLAM | +| `/registered_scan` | `sensor_msgs/PointCloud2` | Aligned lidar point cloud | +| `/terrain_map` | `sensor_msgs/PointCloud2` | Local terrain map | +| `/terrain_map_ext` | `sensor_msgs/PointCloud2` | Extended terrain map | +| `/path` | `nav_msgs/Path` | Local path being followed | +| `/cmd_vel` | `geometry_msgs/Twist` | Velocity commands to motors | +| `/goal_reached` | `std_msgs/Bool` | True when goal reached, false when cancelled/new goal | + +### Map Topics + +| Topic | Type | Description | +|-------|------|-------------| +| `/overall_map` | `sensor_msgs/PointCloud2` | Global map (only in sim)| +| `/registered_scan` | `sensor_msgs/PointCloud2` | Current scan in map frame | +| `/terrain_map` | `sensor_msgs/PointCloud2` | Local obstacle map | + +## Usage Examples + +### Send Goal +```bash +ros2 topic pub /way_point geometry_msgs/msg/PointStamped "{ + header: {frame_id: 'map'}, + point: {x: 5.0, y: 3.0, z: 0.0} +}" --once +``` + +### Cancel Goal +```bash +ros2 topic pub /cancel_goal std_msgs/msg/Bool "data: true" --once +``` + +### Monitor Robot State +```bash +ros2 topic echo /state_estimation +``` + +## Configuration Parameters + +### Vehicle Parameters (`localPlanner`) + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `vehicleLength` | 0.5 | Robot length (m) | +| `vehicleWidth` | 0.5 | Robot width (m) | +| `maxSpeed` | 0.875 | Maximum speed (m/s) | +| `autonomySpeed` | 0.875 | Autonomous mode speed (m/s) | + +### Goal Tolerance Parameters + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `goalReachedThreshold` | 0.3-0.5 | Distance to consider goal reached (m) | +| `goalClearRange` | 0.35-0.6 | Extra clearance around goal (m) | +| `goalBehindRange` | 0.35-0.8 | Stop pursuing if goal behind within this distance (m) | +| `omniDirGoalThre` | 1.0 | Distance for omnidirectional approach (m) | + +### Obstacle Avoidance + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `obstacleHeightThre` | 0.1-0.2 | Height threshold for obstacles (m) | +| `adjacentRange` | 3.5 | Sensor range for planning (m) | +| `minRelZ` | -0.4 | Minimum relative height to consider (m) | +| `maxRelZ` | 0.3 | Maximum relative height to consider (m) | + +### Path Planning + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `pathScale` | 0.875 | Path resolution scale | +| `minPathScale` | 0.675 | Minimum path scale when blocked | +| `minPathRange` | 0.8 | Minimum planning range (m) | +| `dirThre` | 90.0 | Direction threshold (degrees) | + +### Control Parameters (`pathFollower`) + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `lookAheadDis` | 0.5 | Look-ahead distance (m) | +| `maxAccel` | 2.0 | Maximum acceleration (m/s²) | +| `slowDwnDisThre` | 0.875 | Slow down distance threshold (m) | + +### SLAM Blind Zones (`feature_extraction_node`) + +| Parameter | Mecanum | Description | +|-----------|---------|-------------| +| `blindFront` | 0.1 | Front blind zone (m) | +| `blindBack` | -0.2 | Back blind zone (m) | +| `blindLeft` | 0.1 | Left blind zone (m) | +| `blindRight` | -0.1 | Right blind zone (m) | +| `blindDiskRadius` | 0.4 | Cylindrical blind zone radius (m) | + +## Operating Modes + +### Mode Control +- **Joystick L2**: Hold for autonomy mode +- **Joystick R2**: Hold to disable obstacle checking + +### Speed Control +The robot automatically adjusts speed based on: +1. Obstacle proximity +2. Path complexity +3. Goal distance + +## Tuning Guide + +### For Tighter Navigation +- Decrease `goalReachedThreshold` (e.g., 0.2) +- Decrease `goalClearRange` (e.g., 0.3) +- Decrease `vehicleLength/Width` slightly + +### For Smoother Navigation +- Increase `goalReachedThreshold` (e.g., 0.5) +- Increase `lookAheadDis` (e.g., 0.7) +- Decrease `maxAccel` (e.g., 1.5) + +### For Aggressive Obstacle Avoidance +- Increase `obstacleHeightThre` (e.g., 0.15) +- Increase `adjacentRange` (e.g., 4.0) +- Increase blind zone parameters + +## Common Issues + +### Robot Oscillates at Goal +- Increase `goalReachedThreshold` +- Increase `goalBehindRange` + +### Robot Stops Too Far from Goal +- Decrease `goalReachedThreshold` +- Decrease `goalClearRange` + +### Robot Hits Low Obstacles +- Decrease `obstacleHeightThre` +- Adjust `minRelZ` to include lower points + +## SLAM Configuration + +### Localization Mode +Set in `livox_mid360.yaml`: +```yaml +local_mode: true +init_x: 0.0 +init_y: 0.0 +init_yaw: 0.0 +``` + +### Mapping Performance +```yaml +mapping_line_resolution: 0.1 # Decrease for higher quality +mapping_plane_resolution: 0.2 # Decrease for higher quality +max_iterations: 5 # Increase for better accuracy +``` diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000..5e2927e3ad --- /dev/null +++ b/LICENSE @@ -0,0 +1,17 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + Copyright 2025 Dimensional Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md index d257127e75..bfc9b92a69 100644 --- a/README.md +++ b/README.md @@ -1 +1,168 @@ -The Dimensional Framework +![Screenshot 2025-02-18 at 16-31-22 DimOS Terminal](/assets/dimos_terminal.png) + +
+ + + + + +
+ dimOS interface +

A simple two-shot PlanningAgent

+
+ 3rd person POV +

3rd person POV

+
+
+ +# The Dimensional Framework +*The universal framework for AI-native generalist robotics* + + +## What is Dimensional? + +Dimensional is an open-source framework for building agentive generalist robots. DimOS allows off-the-shelf Agents to call tools/functions and read sensor/state data directly from ROS. + +The framework enables neurosymbolic orchestration of Agents as generalized spatial reasoners/planners and Robot state/action primitives as functions. + +The result: cross-embodied *"Dimensional Applications"* exceptional at generalization and robust at symbolic action execution. + + +### Features + +- **DimOS Agents** + - Agent() classes with planning, spatial reasoning, and Robot.Skill() function calling abilities. + - Integrate with any off-the-shelf hosted or local model: OpenAIAgent, ClaudeAgent, GeminiAgent 🚧, DeepSeekAgent 🚧, HuggingFaceRemoteAgent, HuggingFaceLocalAgent, etc. + - Modular agent architecture for easy extensibility and chaining of Agent output --> Subagents input. + - Agent spatial / language memory for location grounded reasoning and recall. + +- **DimOS Infrastructure** + - A reactive data streaming architecture using RxPY to manage real-time video (or other sensor input), outbound commands, and inbound robot state between the DimOS interface, Agents, and ROS2. + - Robot Command Queue to handle complex multi-step actions to Robot. + - Simulation bindings (Genesis, Isaacsim, etc.) to test your agentive application before deploying to a physical robot. + +- **DimOS Interface / Development Tools** + - Local development interface to control your robot, orchestrate agents, visualize camera/lidar streams, and debug your dimensional agentive application. + +--- + +## Installation + + + +## Python Installation +Tested on Ubuntu 22.04/24.04 + +```bash +sudo apt install python3-venv + +# Clone the repository +git clone --branch dev --single-branch https://github.com/dimensionalOS/dimos.git +cd dimos + +# Create and activate virtual environment +python3 -m venv venv +source venv/bin/activate + +sudo apt install portaudio19-dev python3-pyaudio + +# Install LFS +sudo apt install git-lfs +git lfs install + +# Install torch and torchvision if not already installed +# Example CUDA 11.7, Pytorch 2.0.1 (replace with your required pytorch version if different) +pip install torch==2.0.1 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 +``` + +#### Install dependencies +```bash +# CPU only (reccomended to attempt first) +pip install -e .[cpu,dev] + +# CUDA install +pip install -e .[cuda,dev] + +# Copy and configure environment variables +cp default.env .env +``` + +#### Test the install +```bash +pytest -s dimos/ +``` + + +#### Test Dimensional with a replay UnitreeGo2 stream (no robot required) +```bash +CONNECTION_TYPE=replay python dimos/robot/unitree_webrtc/unitree_go2.py +``` + +#### Test Dimensional with a simulated UnitreeGo2 in MuJoCo (no robot required) +```bash +pip install -e .[sim] +export DISPLAY=:1 # Or DISPLAY=:0 if getting GLFW/OpenGL X11 errors +CONNECTION_TYPE=mujoco python dimos/robot/unitree_webrtc/unitree_go2.py +``` + +#### Test Dimensional with a real UnitreeGo2 over WebRTC +```bash +export ROBOT_IP=192.168.X.XXX # Add the robot IP address +python dimos/robot/unitree_webrtc/unitree_go2.py +``` + +#### Test Dimensional with a real UnitreeGo2 running Agents +*OpenAI / Alibaba keys required* +```bash +export ROBOT_IP=192.168.X.XXX # Add the robot IP address +python dimos/robot/unitree_webrtc/run_agents2.py +``` + +## Quickstart + +Get started in minutes with our [Quickstart](./docs/quickstart.md): build an agentic robot that can make greetings! + + + +## Documentation + +For detailed documentation, please visit our [documentation site](#) (Coming Soon) + +## Contributing + +We welcome contributions! See our [Bounty List](https://docs.google.com/spreadsheets/d/1tzYTPvhO7Lou21cU6avSWTQOhACl5H8trSvhtYtsk8U/edit?usp=sharing) for open requests for contributions. If you would like to suggest a feature or sponsor a bounty, open an issue. + + + +## License + +This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details. + +## Acknowledgments + +Huge thanks to! +- The Roboverse Community and their unitree-specific help. Check out their [Discord](https://discord.gg/HEXNMCNhEh). +- @abizovnuralem for his work on the [Unitree Go2 ROS2 SDK](https://github.com/abizovnuralem/go2_ros2_sdk) we integrate with for DimOS. +- @legion1581 for his work on the [Unitree Go2 WebRTC Connect](https://github.com/legion1581/go2_webrtc_connect) from which we've pulled the ```Go2WebRTCConnection``` class and other types for seamless WebRTC-only integration with DimOS. +- @tfoldi for the webrtc_req integration via Unitree Go2 ROS2 SDK, which allows for seamless usage of Unitree WebRTC control primitives with DimOS. + +## Contact + +- GitHub Issues: For bug reports and feature requests +- Email: [build@dimensionalOS.com](mailto:build@dimensionalOS.com) + +## Known Issues +- Agent() failure to execute Nav2 action primitives (move, reverse, spinLeft, spinRight) is almost always due to the internal ROS2 collision avoidance, which will sometimes incorrectly display obstacles or be overly sensitive. Look for ```[behavior_server]: Collision Ahead - Exiting DriveOnHeading``` in the ROS logs. Reccomend restarting ROS2 or moving robot from objects to resolve. +- ```docker-compose up --build``` does not fully initialize the ROS2 environment due to ```std::bad_alloc``` errors. This will occur during continuous docker development if the ```docker-compose down``` is not run consistently before rebuilding and/or you are on a machine with less RAM, as ROS is very memory intensive. Reccomend running to clear your docker cache/images/containers with ```docker system prune``` and rebuild. diff --git a/assets/agent/prompt.txt b/assets/agent/prompt.txt new file mode 100644 index 0000000000..346e0388e0 --- /dev/null +++ b/assets/agent/prompt.txt @@ -0,0 +1,111 @@ +You are Daneel, an advanced AI agent created by the Dimensional team to control and operate the Unitree Go2 quadraped robot with a carrying case on your back. Your purpose is to assist humans by perceiving, understanding, and navigating physical environments while providing helpful interactions and completing tasks. + +CORE CAPABILITIES: + +Interaction with humans: +1. If asked to drop something off for someone, you can announce yourself to the person you are delivering to, wait 5 seconds, and then continue with your task. +2. If asked to pick up something, you can ask for help from the person you are picking up from, wait for them to respond, and then continue with your task. +3. If a human accidentally seems to call you "daniel" or something similar, don't worry about it or acknowledge it, as its due to the speech to text transcription being inaccurate. +4. When greeted, respond with what you are, Daneel, an AI agent trained to operate autonomously in physical space. + + +You operate in an robot agent loop, iteratively completing tasks through these steps: +1. Analyze Events: Understand user needs and current state through event stream, focusing on latest user messages and execution results +2. Select Tools: Choose next tool call based on current state, task planning, relevant knowledge and available data APIs +3. Wait for Execution: Selected tool action will be executed by sandbox environment with new observations added to event stream +4. Iterate: Choose only one tool call per iteration, patiently repeat above steps until task completion +5. Killing: Kill skills when necessary with KillSkill. When asked to stop any skill or task, use KillSkill to stop it. + +SPATIAL UNDERSTANDING & MEMORY: +- You constantly are appending to your SpatialMemory, storing visual and positional data for future reference +- You can query your spatial memory using navigation related skills to find previously visited locations based on natural language descriptions +- You maintain persistent spatial knowledge across sessions in a vector database (ChromaDB) +- You can record specific locations to your SavedRobotLocations using GetPose to create landmarks that can be revisited + +PERCEPTION & TEMPORAL AWARENESS: +- You can perceive the world through multiple sensory streams (video, audio, positional data) +- You maintain awareness of what has happened over time, building a temporal model of your environment +- You can identify and respond to changes in your surroundings +- You can recognize and track humans and objects in your field of view + +NAVIGATION & MOVEMENT: +- You can navigate to semantically described locations using NavigateWithText (e.g., "go to the kitchen") +- You can navigate to visually identified objects using NavigateWithText (e.g., "go to the red chair") +- You can follow humans through complex environments using FollowHuman +- You can perform various body movements and gestures (sit, stand, dance, etc.) +- You can stop any navigation process that is currently running using KillSkill + + +Saved Robot Locations: +- LOCATION_NAME: Position (X, Y, Z), Rotation (X, Y, Z) + +***ALWAYS CHECK FIRST if you can find a navigation query in the Saved Robot Locations before running the NavigateWithText tool call. If a saved location is found, get there with NavigateToGoal.*** + +***Don't use object detections for navigating to an object, ALWAYS run NavigateWithText. Only use object detections if NavigateWithText fails*** + +***When running NavigateWithText, set skip_visual_search flag to TRUE if the query is a general location such as kitchen or office, if it fails, then run without this flag*** + +***When navigating to an object not in current object detected, run NavigateWithText, DO NOT EXPLORE with raw move commands!!!*** + +PLANNING & REASONING: +- You can develop both short-term and long-term plans to achieve complex goals +- You can reason about spatial relationships and plan efficient navigation paths +- You can adapt plans when encountering obstacles or changes in the environment +- You can combine multiple skills in sequence to accomplish multi-step tasks + +COMMUNICATION: +- You can listen to human instructions using speech recognition +- You can respond verbally using the Speak skill with natural-sounding speech +- You maintain contextual awareness in conversations +- You provide clear progress updates during task execution + +ADAPTABILITY: +- You can generalize your understanding to new, previously unseen environments +- You can apply learned skills to novel situations +- You can adjust your behavior based on environmental feedback +- You actively build and refine your knowledge of the world through exploration + +INTERACTION GUIDELINES: + +1. UNDERSTANDING USER REQUESTS + - Parse user instructions carefully to identify the intended goal + - Consider both explicit requests and implicit needs + - Ask clarifying questions when user intent is ambiguous + +2. SKILL SELECTION AND EXECUTION + - Choose the most appropriate skill(s) for each task + - Provide all required parameters with correct values and types + - Execute skills in a logical sequence when multi-step actions are needed + - Monitor skill execution and handle any failures gracefully + +3. SPATIAL REASONING + - Leverage your spatial memory to navigate efficiently + - Build new spatial memories when exploring unfamiliar areas + - Use landmark-based navigation when possible + - Combine semantic and metric mapping for robust localization + +4. SAFETY AND ETHICS + - Prioritize human safety in all actions + - Respect privacy and personal boundaries + - Avoid actions that could damage the environment or the robot + - Be transparent about your capabilities and limitations + +5. COMMUNICATION STYLE + - Be concise but informative in your responses + - Provide clear status updates during extended tasks + - Use appropriate terminology based on the user's expertise level + - Maintain a helpful, supportive, and respectful tone + - Respond with the Speak skill after EVERY QUERY to inform the user of your actions + - When speaking be terse and as concise as possible with a sentence or so, as you would if responding conversationally + +When responding to users: +1. First, acknowledge and confirm your understanding of their request +2. Select and execute the appropriate skill(s) using exact function names and proper parameters +3. Provide meaningful feedback about the outcome of your actions +4. Suggest next steps or additional information when relevant + +Example: If a user asks "Can you find the kitchen?", you would: +1. Acknowledge: "I'll help you find the kitchen." +2. Execute: Call the Navigate skill with query="kitchen" +3. Feedback: Report success or failure of navigation attempt +4. Next steps: Offer to take further actions once at the kitchen location diff --git a/assets/agent/prompt_agents2.txt b/assets/agent/prompt_agents2.txt new file mode 100644 index 0000000000..e0a47b553e --- /dev/null +++ b/assets/agent/prompt_agents2.txt @@ -0,0 +1,103 @@ +You are Daneel, an advanced AI agent created by the Dimensional team to control and operate the Unitree Go2 quadraped robot with a carrying case on your back. Your purpose is to assist humans by perceiving, understanding, and navigating physical environments while providing helpful interactions and completing tasks. + +CORE CAPABILITIES: + +Interaction with humans: +1. If asked to drop something off for someone, you can announce yourself to the person you are delivering to, wait 5 seconds, and then continue with your task. +2. If asked to pick up something, you can ask for help from the person you are picking up from, wait for them to respond, and then continue with your task. +3. If a human accidentally seems to call you "daniel" or something similar, don't worry about it or acknowledge it, as its due to the speech to text transcription being inaccurate. +4. When greeted, respond with what you are, Daneel, an AI agent trained to operate autonomously in physical space. +5. Be helpful. This means being proactive and comunicative. + + +You operate in an robot agent loop, iteratively completing tasks through these steps: +1. Analyze Events: Understand user needs and current state through event stream, focusing on latest user messages and execution results +2. Select Tools: Choose next tool call based on current state, task planning, relevant knowledge and available data APIs +3. Wait for Execution: Selected tool action will be executed by sandbox environment with new observations added to event stream +4. Iterate: Choose only one tool call per iteration, patiently repeat above steps until task completion +5. Killing: Kill skills when necessary with KillSkill. When asked to stop any skill or task, use KillSkill to stop it. + +SPATIAL UNDERSTANDING & MEMORY: +- You constantly are appending to your spatial memory, storing visual and positional data for future reference. You also have things from the past stored in your spatial memory. +- You can query your spatial memory using navigation related skills to find previously visited locations based on natural language descriptions +- You maintain persistent spatial knowledge across sessions in a vector database (ChromaDB) +- You can record specific locations using the tool called `tag_location_in_spatial_memory(location_name='label')`. This creates landmarks that can be revisited. If someone says "what do you think about this bathroom?" you know from context that you are now in the bathroom and can tag it as "bathroom". If someone says "this is where I work out" you can tag it as "exercise location". +- For local area information use the `street_map_query` skill. Example: `street_map_query('Where is a large park nearby?')` + +PERCEPTION & TEMPORAL AWARENESS: +- You can perceive the world through multiple sensory streams (video, audio, positional data) +- You maintain awareness of what has happened over time, building a temporal model of your environment +- You can identify and respond to changes in your surroundings +- You can recognize and track humans and objects in your field of view + +NAVIGATION & MOVEMENT: +- You can navigate to semantically described locations using `navigate_with_text` (e.g., "go to the kitchen") +- You can navigate to visually identified objects using `navigate_with_text` (e.g., "go to the red chair") +- You can follow humans through complex environments using `follow_human` +- You can perform various body movements and gestures (sit, stand, dance, etc.) +- You can stop any navigation process that is currently running using `stop_movement` +- If you are told to go to a location use `navigate_with_text()` +- If you want to explore the environment and go to places you haven't been before you can call the 'start_exploration` tool + +PLANNING & REASONING: +- You can develop both short-term and long-term plans to achieve complex goals +- You can reason about spatial relationships and plan efficient navigation paths +- You can adapt plans when encountering obstacles or changes in the environment +- You can combine multiple skills in sequence to accomplish multi-step tasks + +COMMUNICATION: +- You can listen to human instructions using speech recognition +- You can respond verbally using the `speak_aloud` skill with natural-sounding speech +- You maintain contextual awareness in conversations +- You provide clear progress updates during task execution but always be concise. Never be verbose! + +ADAPTABILITY: +- You can generalize your understanding to new, previously unseen environments +- You can apply learned skills to novel situations +- You can adjust your behavior based on environmental feedback +- You actively build and refine your knowledge of the world through exploration + +INTERACTION GUIDELINES: + +1. UNDERSTANDING USER REQUESTS + - Parse user instructions carefully to identify the intended goal + - Consider both explicit requests and implicit needs + - Ask clarifying questions when user intent is very ambiguous. But you can also be proactive. If someone says "Go greet the new people who are arriving." you can guess that you need to move to the front door to expect new people. Both do the task, but also let people it's a bit ambiguous by saying "I'm heading to the front door. Let me know if I should be going somewhere else." + +2. SKILL SELECTION AND EXECUTION + - Choose the most appropriate skill(s) for each task + - Provide all required parameters with correct values and types + - Execute skills in a logical sequence when multi-step actions are needed + - Monitor skill execution and handle any failures gracefully + +3. SPATIAL REASONING + - Leverage your spatial memory to navigate efficiently + - Build new spatial memories when exploring unfamiliar areas + - Use landmark-based navigation when possible + - Combine semantic and metric mapping for robust localization + +4. SAFETY AND ETHICS + - Prioritize human safety in all actions + - Respect privacy and personal boundaries + - Avoid actions that could damage the environment or the robot + - Be transparent about your capabilities and limitations + +5. COMMUNICATION STYLE + - Be concise but informative in your responses + - Provide clear status updates during extended tasks + - Use appropriate terminology based on the user's expertise level + - Maintain a helpful, supportive, and respectful tone + - Respond with the `speak_aloud` skill after EVERY QUERY to inform the user of your actions + - When speaking be terse and as concise as possible with a sentence or so, as you would if responding conversationally + +When responding to users: +1. First, acknowledge and confirm your understanding of their request +2. Select and execute the appropriate skill(s) using exact function names and proper parameters +3. Provide meaningful feedback about the outcome of your actions +4. Suggest next steps or additional information when relevant + +Example: If a user asks "Can you find the kitchen?", you would: +1. Acknowledge: "I'll help you find the kitchen." +2. Execute: Call the Navigate skill with query="kitchen" +3. Feedback: Report success or failure of navigation attempt +4. Next steps: Offer to take further actions once at the kitchen location diff --git a/assets/dimensionalascii.txt b/assets/dimensionalascii.txt new file mode 100644 index 0000000000..9b35fb8778 --- /dev/null +++ b/assets/dimensionalascii.txt @@ -0,0 +1,7 @@ + + ██████╗ ██╗███╗ ███╗███████╗███╗ ██╗███████╗██╗ ██████╗ ███╗ ██╗ █████╗ ██╗ + ██╔══██╗██║████╗ ████║██╔════╝████╗ ██║██╔════╝██║██╔═══██╗████╗ ██║██╔══██╗██║ + ██║ ██║██║██╔████╔██║█████╗ ██╔██╗ ██║███████╗██║██║ ██║██╔██╗ ██║███████║██║ + ██║ ██║██║██║╚██╔╝██║██╔══╝ ██║╚██╗██║╚════██║██║██║ ██║██║╚██╗██║██╔══██║██║ + ██████╔╝██║██║ ╚═╝ ██║███████╗██║ ╚████║███████║██║╚██████╔╝██║ ╚████║██║ ██║███████╗ + ╚═════╝ ╚═╝╚═╝ ╚═╝╚══════╝╚═╝ ╚═══╝╚══════╝╚═╝ ╚═════╝ ╚═╝ ╚═══╝╚═╝ ╚═╝╚══════╝ diff --git a/assets/dimos_interface.gif b/assets/dimos_interface.gif new file mode 100644 index 0000000000..e610a2b390 --- /dev/null +++ b/assets/dimos_interface.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:13a5348ec51bef34d8cc3aa4afc99975befb7f118826df571130b1a2fa1b59e9 +size 13361230 diff --git a/assets/dimos_terminal.png b/assets/dimos_terminal.png new file mode 100644 index 0000000000..77f00e47fa Binary files /dev/null and b/assets/dimos_terminal.png differ diff --git a/assets/foxglove_g1_detections.json b/assets/foxglove_g1_detections.json new file mode 100644 index 0000000000..7def24fdaa --- /dev/null +++ b/assets/foxglove_g1_detections.json @@ -0,0 +1,915 @@ +{ + "configById": { + "3D!18i6zy7": { + "layers": { + "845139cb-26bc-40b3-8161-8ab60af4baf5": { + "visible": true, + "frameLocked": true, + "label": "Grid", + "instanceId": "845139cb-26bc-40b3-8161-8ab60af4baf5", + "layerId": "foxglove.Grid", + "lineWidth": 0.5, + "position": [ + 0, + 0, + 0 + ], + "rotation": [ + 0, + 0, + 0 + ], + "order": 1, + "size": 30, + "divisions": 30, + "color": "#248eff57" + }, + "ff758451-8c06-4419-a995-e93c825eb8be": { + "visible": false, + "frameLocked": true, + "label": "Grid", + "instanceId": "ff758451-8c06-4419-a995-e93c825eb8be", + "layerId": "foxglove.Grid", + "frameId": "base_link", + "divisions": 6, + "lineWidth": 1.5, + "color": "#24fff4ff", + "position": [ + 0, + 0, + 0 + ], + "rotation": [ + 0, + 0, + 0 + ], + "order": 2, + "size": 6 + } + }, + "cameraState": { + "perspective": true, + "distance": 17.147499997813583, + "phi": 41.70966129676441, + "thetaOffset": 46.32247127821147, + "targetOffset": [ + 1.489416869802203, + 3.0285403495275056, + -1.5060700211359088 + ], + "target": [ + 0, + 0, + 0 + ], + "targetOrientation": [ + 0, + 0, + 0, + 1 + ], + "fovy": 45, + "near": 0.5, + "far": 5000 + }, + "followMode": "follow-pose", + "scene": { + "enableStats": false, + "ignoreColladaUpAxis": false, + "syncCamera": true, + "transforms": { + "visible": true, + "showLabel": true, + "editable": true, + "enablePreloading": false, + "labelSize": 0.07 + } + }, + "transforms": { + "frame:camera_link": { + "visible": false + }, + "frame:sensor": { + "visible": false + }, + "frame:sensor_at_scan": { + "visible": false + }, + "frame:map": { + "visible": true + }, + "frame:world": { + "visible": true + } + }, + "topics": { + "/lidar": { + "stixelsEnabled": false, + "visible": true, + "colorField": "z", + "colorMode": "colormap", + "colorMap": "turbo", + "pointShape": "circle", + "pointSize": 2, + "explicitAlpha": 0.8, + "decayTime": 0, + "cubeSize": 0.05, + "cubeOutline": false, + "minValue": -2 + }, + "/odom": { + "visible": true, + "axisScale": 1 + }, + "/video": { + "visible": false + }, + "/global_map": { + "visible": true, + "colorField": "z", + "colorMode": "colormap", + "colorMap": "turbo", + "decayTime": 0, + "pointShape": "square", + "cubeOutline": false, + "cubeSize": 0.08, + "gradient": [ + "#06011dff", + "#d1e2e2ff" + ], + "stixelsEnabled": false, + "explicitAlpha": 0.339, + "minValue": -0.2, + "pointSize": 5 + }, + "/global_path": { + "visible": true, + "type": "line", + "arrowScale": [ + 1, + 0.15, + 0.15 + ], + "lineWidth": 0.05, + "gradient": [ + "#6bff7cff", + "#0081ffff" + ] + }, + "/global_target": { + "visible": true + }, + "/pt": { + "visible": false + }, + "/global_costmap": { + "visible": true, + "maxColor": "#6b2b2bff", + "frameLocked": false, + "unknownColor": "#80808000", + "colorMode": "custom", + "alpha": 0.517, + "minColor": "#1e00ff00", + "drawBehind": false + }, + "/global_gradient": { + "visible": true, + "maxColor": "#690066ff", + "unknownColor": "#30b89a00", + "minColor": "#00000000", + "colorMode": "custom", + "alpha": 0.3662, + "frameLocked": false, + "drawBehind": false + }, + "/global_cost_field": { + "visible": false, + "maxColor": "#ff0000ff", + "unknownColor": "#80808000" + }, + "/global_passable": { + "visible": false, + "maxColor": "#ffffff00", + "minColor": "#ff0000ff", + "unknownColor": "#80808000" + }, + "/image": { + "visible": true, + "cameraInfoTopic": "/camera_info", + "distance": 1.5, + "planarProjectionFactor": 0, + "color": "#e7e1ffff" + }, + "/camera_info": { + "visible": true, + "distance": 1.5, + "planarProjectionFactor": 0 + }, + "/local_costmap": { + "visible": false + }, + "/navigation_goal": { + "visible": true + }, + "/debug_camera_optical_points": { + "stixelsEnabled": false, + "visible": false, + "pointSize": 0.07, + "pointShape": "cube", + "colorField": "z", + "colorMode": "colormap", + "colorMap": "turbo" + }, + "/debug_world_points": { + "visible": false, + "colorField": "z", + "colorMode": "colormap", + "colorMap": "rainbow", + "pointShape": "cube" + }, + "/filtered_points_suitcase_0": { + "visible": false, + "colorField": "intensity", + "colorMode": "flat", + "colorMap": "turbo", + "pointShape": "cube", + "flatColor": "#ff0808ff", + "cubeSize": 0.149, + "pointSize": 28.57 + }, + "/filtered_points_combined": { + "visible": true, + "flatColor": "#ff0000ff", + "pointShape": "cube", + "pointSize": 6.63, + "colorField": "z", + "colorMode": "gradient", + "colorMap": "rainbow", + "cubeSize": 0.35, + "gradient": [ + "#d100caff", + "#ff0000ff" + ] + }, + "/filtered_points_box_7": { + "visible": true, + "flatColor": "#fbfaffff", + "colorField": "intensity", + "colorMode": "colormap", + "colorMap": "turbo" + }, + "/filtered_pointcloud": { + "visible": true, + "colorField": "z", + "colorMode": "flat", + "colorMap": "turbo", + "flatColor": "#ff0000ff", + "pointSize": 40.21, + "pointShape": "cube", + "cubeSize": 0.1, + "cubeOutline": true + }, + "/detected": { + "visible": false, + "pointSize": 1.5, + "pointShape": "cube", + "cubeSize": 0.118, + "colorField": "intensity", + "colorMode": "flat", + "colorMap": "turbo", + "flatColor": "#d70000ff", + "cubeOutline": true + }, + "/detected_0": { + "visible": true, + "colorField": "intensity", + "colorMode": "flat", + "colorMap": "turbo", + "pointSize": 1.6, + "pointShape": "cube", + "cubeSize": 0.1, + "flatColor": "#e00000ff", + "stixelsEnabled": false, + "decayTime": 0, + "cubeOutline": true + }, + "/detected_1": { + "visible": true, + "colorField": "intensity", + "colorMode": "flat", + "colorMap": "turbo", + "pointShape": "cube", + "cubeSize": 0.1, + "flatColor": "#00ff15ff", + "cubeOutline": true + }, + "/image_detected_0": { + "visible": false + }, + "/detected/pointcloud/1": { + "visible": true, + "colorField": "intensity", + "colorMode": "flat", + "colorMap": "turbo", + "pointShape": "cube", + "flatColor": "#15ff00ff", + "pointSize": 0.1, + "cubeSize": 0.05, + "cubeOutline": true + }, + "/detected/pointcloud/2": { + "visible": true, + "colorField": "intensity", + "colorMode": "flat", + "colorMap": "turbo", + "pointShape": "cube", + "flatColor": "#00ffe1ff", + "pointSize": 10, + "cubeOutline": true, + "cubeSize": 0.05 + }, + "/detected/pointcloud/0": { + "visible": true, + "colorField": "intensity", + "colorMode": "flat", + "colorMap": "turbo", + "pointShape": "cube", + "flatColor": "#ff0000ff", + "cubeOutline": true, + "cubeSize": 0.04 + }, + "/detected/image/0": { + "visible": false + }, + "/detected/image/3": { + "visible": false + }, + "/detected/pointcloud/3": { + "visible": true, + "pointSize": 1.5, + "pointShape": "cube", + "cubeSize": 0.1, + "flatColor": "#00fffaff", + "colorField": "intensity", + "colorMode": "flat", + "colorMap": "turbo" + }, + "/detected/image/1": { + "visible": false + }, + "/registered_scan": { + "visible": true, + "colorField": "z", + "colorMode": "colormap", + "colorMap": "turbo", + "pointShape": "circle", + "pointSize": 2 + }, + "/image/camera_info": { + "visible": true, + "distance": 2 + }, + "/map": { + "visible": true, + "colorField": "z", + "colorMode": "colormap", + "colorMap": "turbo", + "pointShape": "square", + "cubeSize": 0.13, + "explicitAlpha": 1, + "pointSize": 1, + "decayTime": 2 + }, + "/detection3d/markers": { + "visible": true, + "color": "#88ff00ff", + "showOutlines": true, + "selectedIdVariable": "" + }, + "/foxglove/scene_update": { + "visible": true + }, + "/scene_update": { + "visible": true, + "showOutlines": true, + "computeVertexNormals": true + }, + "/target": { + "visible": true, + "axisScale": 1 + }, + "/goal_pose": { + "visible": true, + "axisScale": 0.5 + }, + "/global_pointcloud": { + "visible": true, + "colorField": "intensity", + "colorMode": "colormap", + "colorMap": "turbo" + }, + "/pointcloud_map": { + "visible": false, + "colorField": "intensity", + "colorMode": "colormap", + "colorMap": "turbo" + }, + "/detectorDB/pointcloud/0": { + "visible": true, + "colorField": "intensity", + "colorMode": "colormap", + "colorMap": "turbo" + }, + "/path_active": { + "visible": true + }, + "/detector3d/image/0": { + "visible": true + }, + "/detector3d/pointcloud/0": { + "visible": true, + "colorField": "intensity", + "colorMode": "colormap", + "colorMap": "turbo" + }, + "/detectorDB/image/0": { + "visible": true + }, + "/detectorDB/scene_update": { + "visible": true + }, + "/detector3d/scene_update": { + "visible": true + }, + "/detector3d/image/1": { + "visible": true + }, + "/g1/camera_info": { + "visible": false + }, + "/detectorDB/image/1": { + "visible": true + } + }, + "publish": { + "type": "point", + "poseTopic": "/move_base_simple/goal", + "pointTopic": "/clicked_point", + "poseEstimateTopic": "/estimate", + "poseEstimateXDeviation": 0.5, + "poseEstimateYDeviation": 0.5, + "poseEstimateThetaDeviation": 0.26179939 + }, + "imageMode": {}, + "foxglovePanelTitle": "", + "followTf": "camera_link" + }, + "Image!3mnp456": { + "cameraState": { + "distance": 20, + "perspective": true, + "phi": 60, + "target": [ + 0, + 0, + 0 + ], + "targetOffset": [ + 0, + 0, + 0 + ], + "targetOrientation": [ + 0, + 0, + 0, + 1 + ], + "thetaOffset": 45, + "fovy": 45, + "near": 0.5, + "far": 5000 + }, + "followMode": "follow-pose", + "scene": { + "enableStats": false, + "transforms": { + "showLabel": false, + "visible": true + } + }, + "transforms": { + "frame:world": { + "visible": false + }, + "frame:camera_optical": { + "visible": false + }, + "frame:camera_link": { + "visible": false + }, + "frame:base_link": { + "visible": false + }, + "frame:sensor": { + "visible": false + } + }, + "topics": { + "/lidar": { + "visible": false, + "colorField": "z", + "colorMode": "colormap", + "colorMap": "turbo", + "pointSize": 6, + "explicitAlpha": 0.6, + "pointShape": "circle", + "cubeSize": 0.016 + }, + "/odom": { + "visible": false + }, + "/local_costmap": { + "visible": false + }, + "/global_costmap": { + "visible": false, + "minColor": "#ffffff00" + }, + "/detected_0": { + "visible": true, + "colorField": "intensity", + "colorMode": "flat", + "colorMap": "turbo", + "pointSize": 23, + "pointShape": "cube", + "cubeSize": 0.04, + "flatColor": "#ff0000ff", + "stixelsEnabled": false + }, + "/detected_1": { + "visible": true, + "colorField": "intensity", + "colorMode": "flat", + "colorMap": "turbo", + "pointSize": 20.51, + "flatColor": "#34ff00ff", + "pointShape": "cube", + "cubeSize": 0.04, + "cubeOutline": false + }, + "/filtered_pointcloud": { + "visible": true, + "colorField": "intensity", + "colorMode": "flat", + "colorMap": "rainbow", + "pointSize": 1.5, + "pointShape": "cube", + "flatColor": "#ff0000ff", + "cubeSize": 0.1 + }, + "/global_map": { + "visible": false, + "colorField": "z", + "colorMode": "colormap", + "colorMap": "turbo", + "pointShape": "cube", + "pointSize": 5 + }, + "/detected/pointcloud/1": { + "visible": false, + "colorField": "intensity", + "colorMode": "flat", + "colorMap": "turbo", + "pointShape": "cube", + "cubeSize": 0.01, + "flatColor": "#00ff1eff", + "pointSize": 15, + "decayTime": 0, + "cubeOutline": true + }, + "/detected/pointcloud/2": { + "visible": false, + "colorField": "intensity", + "colorMode": "flat", + "colorMap": "turbo", + "pointShape": "circle", + "cubeSize": 0.1, + "flatColor": "#00fbffff", + "pointSize": 0.01 + }, + "/detected/pointcloud/0": { + "visible": false, + "colorField": "intensity", + "colorMode": "flat", + "colorMap": "turbo", + "pointShape": "cube", + "flatColor": "#ff0000ff", + "pointSize": 15, + "cubeOutline": true, + "cubeSize": 0.03 + }, + "/registered_scan": { + "visible": false, + "colorField": "z", + "colorMode": "colormap", + "colorMap": "turbo", + "pointShape": "circle", + "pointSize": 6.49 + }, + "/detection3d/markers": { + "visible": false + }, + "/foxglove/scene_update": { + "visible": true + }, + "/scene_update": { + "visible": false + }, + "/map": { + "visible": false, + "colorField": "z", + "colorMode": "colormap", + "colorMap": "turbo", + "pointSize": 8 + } + }, + "layers": {}, + "publish": { + "type": "point", + "poseTopic": "/move_base_simple/goal", + "pointTopic": "/clicked_point", + "poseEstimateTopic": "/initialpose", + "poseEstimateXDeviation": 0.5, + "poseEstimateYDeviation": 0.5, + "poseEstimateThetaDeviation": 0.26179939 + }, + "imageMode": { + "imageTopic": "/image", + "colorMode": "gradient", + "annotations": { + "/detections": { + "visible": true + }, + "/annotations": { + "visible": true + }, + "/detector3d/annotations": { + "visible": true + }, + "/detectorDB/annotations": { + "visible": true + } + }, + "synchronize": false, + "rotation": 0, + "calibrationTopic": "/camera_info" + }, + "foxglovePanelTitle": "" + }, + "Plot!3heo336": { + "paths": [ + { + "timestampMethod": "publishTime", + "value": "/image.header.stamp.nsec", + "enabled": true, + "color": "#4e98e2", + "label": "image", + "showLine": true + }, + { + "timestampMethod": "publishTime", + "value": "/map.header.stamp.nsec", + "enabled": true, + "color": "#f5774d", + "label": "lidar", + "showLine": true + }, + { + "timestampMethod": "publishTime", + "value": "/tf.transforms[0].header.stamp.nsec", + "enabled": true, + "color": "#f7df71", + "label": "tf", + "showLine": true + }, + { + "timestampMethod": "publishTime", + "value": "/odom.header.stamp.nsec", + "enabled": true, + "color": "#5cd6a9", + "label": "odom", + "showLine": true + } + ], + "showXAxisLabels": true, + "showYAxisLabels": true, + "showLegend": true, + "legendDisplay": "floating", + "showPlotValuesInLegend": false, + "isSynced": true, + "xAxisVal": "timestamp", + "sidebarDimension": 240 + }, + "StateTransitions!2wj5twf": { + "paths": [ + { + "value": "/detectorDB/annotations.texts[0].text", + "timestampMethod": "receiveTime", + "customStates": { + "type": "discrete", + "states": [] + } + }, + { + "value": "/detectorDB/annotations.texts[1].text", + "timestampMethod": "receiveTime", + "customStates": { + "type": "discrete", + "states": [] + } + }, + { + "value": "/detectorDB/annotations.texts[2].text", + "timestampMethod": "receiveTime", + "customStates": { + "type": "discrete", + "states": [] + } + } + ], + "isSynced": true + }, + "Image!47pi3ov": { + "cameraState": { + "distance": 20, + "perspective": true, + "phi": 60, + "target": [ + 0, + 0, + 0 + ], + "targetOffset": [ + 0, + 0, + 0 + ], + "targetOrientation": [ + 0, + 0, + 0, + 1 + ], + "thetaOffset": 45, + "fovy": 45, + "near": 0.5, + "far": 5000 + }, + "followMode": "follow-pose", + "scene": {}, + "transforms": {}, + "topics": {}, + "layers": {}, + "publish": { + "type": "point", + "poseTopic": "/move_base_simple/goal", + "pointTopic": "/clicked_point", + "poseEstimateTopic": "/initialpose", + "poseEstimateXDeviation": 0.5, + "poseEstimateYDeviation": 0.5, + "poseEstimateThetaDeviation": 0.26179939 + }, + "imageMode": { + "imageTopic": "/detector3d/image/0" + } + }, + "Image!4kk50gw": { + "cameraState": { + "distance": 20, + "perspective": true, + "phi": 60, + "target": [ + 0, + 0, + 0 + ], + "targetOffset": [ + 0, + 0, + 0 + ], + "targetOrientation": [ + 0, + 0, + 0, + 1 + ], + "thetaOffset": 45, + "fovy": 45, + "near": 0.5, + "far": 5000 + }, + "followMode": "follow-pose", + "scene": {}, + "transforms": {}, + "topics": {}, + "layers": {}, + "publish": { + "type": "point", + "poseTopic": "/move_base_simple/goal", + "pointTopic": "/clicked_point", + "poseEstimateTopic": "/initialpose", + "poseEstimateXDeviation": 0.5, + "poseEstimateYDeviation": 0.5, + "poseEstimateThetaDeviation": 0.26179939 + }, + "imageMode": { + "imageTopic": "/detectorDB/image/1" + } + }, + "Image!2348e0b": { + "cameraState": { + "distance": 20, + "perspective": true, + "phi": 60, + "target": [ + 0, + 0, + 0 + ], + "targetOffset": [ + 0, + 0, + 0 + ], + "targetOrientation": [ + 0, + 0, + 0, + 1 + ], + "thetaOffset": 45, + "fovy": 45, + "near": 0.5, + "far": 5000 + }, + "followMode": "follow-pose", + "scene": {}, + "transforms": {}, + "topics": {}, + "layers": {}, + "publish": { + "type": "point", + "poseTopic": "/move_base_simple/goal", + "pointTopic": "/clicked_point", + "poseEstimateTopic": "/initialpose", + "poseEstimateXDeviation": 0.5, + "poseEstimateYDeviation": 0.5, + "poseEstimateThetaDeviation": 0.26179939 + }, + "imageMode": { + "imageTopic": "/detectorDB/image/2", + "synchronize": false + } + } + }, + "globalVariables": {}, + "userNodes": {}, + "playbackConfig": { + "speed": 1 + }, + "drawerConfig": { + "tracks": [] + }, + "layout": { + "first": { + "first": "3D!18i6zy7", + "second": "Image!3mnp456", + "direction": "row", + "splitPercentage": 44.31249231586115 + }, + "second": { + "first": { + "first": "Plot!3heo336", + "second": "StateTransitions!2wj5twf", + "direction": "column" + }, + "second": { + "first": "Image!47pi3ov", + "second": { + "first": "Image!4kk50gw", + "second": "Image!2348e0b", + "direction": "row" + }, + "direction": "row", + "splitPercentage": 33.06523681858802 + }, + "direction": "row", + "splitPercentage": 46.39139486467731 + }, + "direction": "column", + "splitPercentage": 65.20874751491054 + } +} diff --git a/assets/foxglove_image_sharpness_test.json b/assets/foxglove_image_sharpness_test.json new file mode 100644 index 0000000000..e68b79a7e4 --- /dev/null +++ b/assets/foxglove_image_sharpness_test.json @@ -0,0 +1,140 @@ +{ + "configById": { + "Image!1dpphsz": { + "cameraState": { + "distance": 20, + "perspective": true, + "phi": 60, + "target": [ + 0, + 0, + 0 + ], + "targetOffset": [ + 0, + 0, + 0 + ], + "targetOrientation": [ + 0, + 0, + 0, + 1 + ], + "thetaOffset": 45, + "fovy": 45, + "near": 0.5, + "far": 5000 + }, + "followMode": "follow-pose", + "scene": {}, + "transforms": {}, + "topics": {}, + "layers": {}, + "publish": { + "type": "point", + "poseTopic": "/move_base_simple/goal", + "pointTopic": "/clicked_point", + "poseEstimateTopic": "/initialpose", + "poseEstimateXDeviation": 0.5, + "poseEstimateYDeviation": 0.5, + "poseEstimateThetaDeviation": 0.26179939 + }, + "imageMode": { + "imageTopic": "/all" + } + }, + "Image!2xvd0hl": { + "cameraState": { + "distance": 20, + "perspective": true, + "phi": 60, + "target": [ + 0, + 0, + 0 + ], + "targetOffset": [ + 0, + 0, + 0 + ], + "targetOrientation": [ + 0, + 0, + 0, + 1 + ], + "thetaOffset": 45, + "fovy": 45, + "near": 0.5, + "far": 5000 + }, + "followMode": "follow-pose", + "scene": {}, + "transforms": {}, + "topics": {}, + "layers": {}, + "publish": { + "type": "point", + "poseTopic": "/move_base_simple/goal", + "pointTopic": "/clicked_point", + "poseEstimateTopic": "/initialpose", + "poseEstimateXDeviation": 0.5, + "poseEstimateYDeviation": 0.5, + "poseEstimateThetaDeviation": 0.26179939 + }, + "imageMode": { + "imageTopic": "/sharp" + } + }, + "Gauge!1iofczz": { + "path": "/sharpness.x", + "minValue": 0, + "maxValue": 1, + "colorMap": "red-yellow-green", + "colorMode": "colormap", + "gradient": [ + "#0000ff", + "#ff00ff" + ], + "reverse": false + }, + "Plot!1gy7vh9": { + "paths": [ + { + "timestampMethod": "receiveTime", + "value": "/sharpness.x", + "enabled": true, + "color": "#4e98e2" + } + ], + "showXAxisLabels": true, + "showYAxisLabels": true, + "showLegend": true, + "legendDisplay": "floating", + "showPlotValuesInLegend": false, + "isSynced": true, + "xAxisVal": "timestamp", + "sidebarDimension": 240 + } + }, + "globalVariables": {}, + "userNodes": {}, + "playbackConfig": { + "speed": 1 + }, + "layout": { + "first": { + "first": "Image!1dpphsz", + "second": "Image!2xvd0hl", + "direction": "row" + }, + "second": { + "first": "Gauge!1iofczz", + "second": "Plot!1gy7vh9", + "direction": "row" + }, + "direction": "column" + } +} diff --git a/assets/foxglove_unitree_lcm_dashboard.json b/assets/foxglove_unitree_lcm_dashboard.json new file mode 100644 index 0000000000..df4e2715bc --- /dev/null +++ b/assets/foxglove_unitree_lcm_dashboard.json @@ -0,0 +1,288 @@ +{ + "configById": { + "3D!18i6zy7": { + "layers": { + "845139cb-26bc-40b3-8161-8ab60af4baf5": { + "visible": true, + "frameLocked": true, + "label": "Grid", + "instanceId": "845139cb-26bc-40b3-8161-8ab60af4baf5", + "layerId": "foxglove.Grid", + "lineWidth": 0.5, + "position": [ + 0, + 0, + 0 + ], + "rotation": [ + 0, + 0, + 0 + ], + "order": 1, + "size": 30, + "divisions": 30, + "color": "#248eff57" + }, + "ff758451-8c06-4419-a995-e93c825eb8be": { + "visible": true, + "frameLocked": true, + "label": "Grid", + "instanceId": "ff758451-8c06-4419-a995-e93c825eb8be", + "layerId": "foxglove.Grid", + "frameId": "base_link", + "size": 3, + "divisions": 3, + "lineWidth": 1.5, + "color": "#24fff4ff", + "position": [ + 0, + 0, + 0 + ], + "rotation": [ + 0, + 0, + 0 + ], + "order": 2 + } + }, + "cameraState": { + "perspective": false, + "distance": 25.847108697365048, + "phi": 32.532756465990374, + "thetaOffset": -179.288640038416, + "targetOffset": [ + 1.620731759058286, + -2.9069622235988986, + -0.09942375087215619 + ], + "target": [ + 0, + 0, + 0 + ], + "targetOrientation": [ + 0, + 0, + 0, + 1 + ], + "fovy": 45, + "near": 0.5, + "far": 5000 + }, + "followMode": "follow-pose", + "scene": { + "enableStats": true, + "ignoreColladaUpAxis": false, + "syncCamera": false, + "transforms": { + "visible": true + } + }, + "transforms": {}, + "topics": { + "/lidar": { + "stixelsEnabled": false, + "visible": true, + "colorField": "z", + "colorMode": "colormap", + "colorMap": "turbo", + "pointShape": "circle", + "pointSize": 10, + "explicitAlpha": 1, + "decayTime": 0, + "cubeSize": 0.1, + "minValue": -0.3, + "cubeOutline": false + }, + "/odom": { + "visible": true, + "axisScale": 1 + }, + "/video": { + "visible": false + }, + "/global_map": { + "visible": true, + "colorField": "z", + "colorMode": "colormap", + "colorMap": "turbo", + "pointSize": 10, + "decayTime": 0, + "pointShape": "cube", + "cubeOutline": false, + "cubeSize": 0.08, + "gradient": [ + "#06011dff", + "#d1e2e2ff" + ], + "stixelsEnabled": false, + "explicitAlpha": 1, + "minValue": -0.2 + }, + "/global_path": { + "visible": true, + "type": "line", + "arrowScale": [ + 1, + 0.15, + 0.15 + ], + "lineWidth": 0.132, + "gradient": [ + "#6bff7cff", + "#0081ffff" + ] + }, + "/global_target": { + "visible": true + }, + "/pt": { + "visible": false + }, + "/global_costmap": { + "visible": true, + "maxColor": "#8d3939ff", + "frameLocked": false, + "unknownColor": "#80808000", + "colorMode": "custom", + "alpha": 0.517, + "minColor": "#1e00ff00" + }, + "/global_gradient": { + "visible": true, + "maxColor": "#690066ff", + "unknownColor": "#30b89a00", + "minColor": "#00000000", + "colorMode": "custom", + "alpha": 0.3662, + "frameLocked": false, + "drawBehind": false + }, + "/global_cost_field": { + "visible": false, + "maxColor": "#ff0000ff", + "unknownColor": "#80808000" + }, + "/global_passable": { + "visible": false, + "maxColor": "#ffffff00", + "minColor": "#ff0000ff", + "unknownColor": "#80808000" + } + }, + "publish": { + "type": "point", + "poseTopic": "/move_base_simple/goal", + "pointTopic": "/clicked_point", + "poseEstimateTopic": "/estimate", + "poseEstimateXDeviation": 0.5, + "poseEstimateYDeviation": 0.5, + "poseEstimateThetaDeviation": 0.26179939 + }, + "imageMode": {}, + "foxglovePanelTitle": "test", + "followTf": "world" + }, + "Image!3mnp456": { + "cameraState": { + "distance": 20, + "perspective": true, + "phi": 60, + "target": [ + 0, + 0, + 0 + ], + "targetOffset": [ + 0, + 0, + 0 + ], + "targetOrientation": [ + 0, + 0, + 0, + 1 + ], + "thetaOffset": 45, + "fovy": 45, + "near": 0.5, + "far": 5000 + }, + "followMode": "follow-pose", + "scene": { + "enableStats": true + }, + "transforms": {}, + "topics": {}, + "layers": {}, + "publish": { + "type": "point", + "poseTopic": "/move_base_simple/goal", + "pointTopic": "/clicked_point", + "poseEstimateTopic": "/initialpose", + "poseEstimateXDeviation": 0.5, + "poseEstimateYDeviation": 0.5, + "poseEstimateThetaDeviation": 0.26179939 + }, + "imageMode": { + "imageTopic": "/video", + "colorMode": "gradient" + }, + "foxglovePanelTitle": "/video" + }, + "Plot!a1gj37": { + "paths": [ + { + "timestampMethod": "receiveTime", + "value": "/odom.pose.position.y", + "enabled": true, + "color": "#4e98e2" + }, + { + "timestampMethod": "receiveTime", + "value": "/odom.pose.position.x", + "enabled": true, + "color": "#f5774d" + }, + { + "timestampMethod": "receiveTime", + "value": "/odom.pose.position.z", + "enabled": true, + "color": "#f7df71" + } + ], + "showXAxisLabels": true, + "showYAxisLabels": true, + "showLegend": true, + "legendDisplay": "floating", + "showPlotValuesInLegend": false, + "isSynced": true, + "xAxisVal": "timestamp", + "sidebarDimension": 240 + } + }, + "globalVariables": {}, + "userNodes": {}, + "playbackConfig": { + "speed": 1 + }, + "drawerConfig": { + "tracks": [] + }, + "layout": { + "first": "3D!18i6zy7", + "second": { + "first": "Image!3mnp456", + "second": "Plot!a1gj37", + "direction": "column", + "splitPercentage": 28.030303030303028 + }, + "direction": "row", + "splitPercentage": 69.43271928754422 + } +} diff --git a/assets/foxglove_unitree_yolo.json b/assets/foxglove_unitree_yolo.json new file mode 100644 index 0000000000..ab53e4a71e --- /dev/null +++ b/assets/foxglove_unitree_yolo.json @@ -0,0 +1,849 @@ +{ + "configById": { + "3D!18i6zy7": { + "layers": { + "845139cb-26bc-40b3-8161-8ab60af4baf5": { + "visible": true, + "frameLocked": true, + "label": "Grid", + "instanceId": "845139cb-26bc-40b3-8161-8ab60af4baf5", + "layerId": "foxglove.Grid", + "lineWidth": 0.5, + "position": [ + 0, + 0, + 0 + ], + "rotation": [ + 0, + 0, + 0 + ], + "order": 1, + "size": 30, + "divisions": 30, + "color": "#248eff57" + }, + "ff758451-8c06-4419-a995-e93c825eb8be": { + "visible": false, + "frameLocked": true, + "label": "Grid", + "instanceId": "ff758451-8c06-4419-a995-e93c825eb8be", + "layerId": "foxglove.Grid", + "frameId": "base_link", + "divisions": 6, + "lineWidth": 1.5, + "color": "#24fff4ff", + "position": [ + 0, + 0, + 0 + ], + "rotation": [ + 0, + 0, + 0 + ], + "order": 2, + "size": 6 + } + }, + "cameraState": { + "perspective": true, + "distance": 13.268408624096915, + "phi": 26.658696672199024, + "thetaOffset": 99.69918626426482, + "targetOffset": [ + 1.740213570345715, + 0.7318803628974015, + -1.5060700211358968 + ], + "target": [ + 0, + 0, + 0 + ], + "targetOrientation": [ + 0, + 0, + 0, + 1 + ], + "fovy": 45, + "near": 0.5, + "far": 5000 + }, + "followMode": "follow-pose", + "scene": { + "enableStats": false, + "ignoreColladaUpAxis": false, + "syncCamera": true, + "transforms": { + "visible": true, + "showLabel": true, + "editable": true, + "enablePreloading": false, + "labelSize": 0.07 + } + }, + "transforms": { + "frame:camera_link": { + "visible": false + }, + "frame:sensor": { + "visible": false + }, + "frame:sensor_at_scan": { + "visible": false + }, + "frame:map": { + "visible": true + } + }, + "topics": { + "/lidar": { + "stixelsEnabled": false, + "visible": true, + "colorField": "z", + "colorMode": "colormap", + "colorMap": "turbo", + "pointShape": "circle", + "pointSize": 2, + "explicitAlpha": 0.8, + "decayTime": 0, + "cubeSize": 0.05, + "cubeOutline": false, + "minValue": -2 + }, + "/odom": { + "visible": true, + "axisScale": 1 + }, + "/video": { + "visible": false + }, + "/global_map": { + "visible": true, + "colorField": "z", + "colorMode": "colormap", + "colorMap": "turbo", + "decayTime": 0, + "pointShape": "square", + "cubeOutline": false, + "cubeSize": 0.08, + "gradient": [ + "#06011dff", + "#d1e2e2ff" + ], + "stixelsEnabled": false, + "explicitAlpha": 0.339, + "minValue": -0.2, + "pointSize": 5 + }, + "/global_path": { + "visible": true, + "type": "line", + "arrowScale": [ + 1, + 0.15, + 0.15 + ], + "lineWidth": 0.05, + "gradient": [ + "#6bff7cff", + "#0081ffff" + ] + }, + "/global_target": { + "visible": true + }, + "/pt": { + "visible": false + }, + "/global_costmap": { + "visible": false, + "maxColor": "#6b2b2bff", + "frameLocked": false, + "unknownColor": "#80808000", + "colorMode": "custom", + "alpha": 0.517, + "minColor": "#1e00ff00", + "drawBehind": false + }, + "/global_gradient": { + "visible": true, + "maxColor": "#690066ff", + "unknownColor": "#30b89a00", + "minColor": "#00000000", + "colorMode": "custom", + "alpha": 0.3662, + "frameLocked": false, + "drawBehind": false + }, + "/global_cost_field": { + "visible": false, + "maxColor": "#ff0000ff", + "unknownColor": "#80808000" + }, + "/global_passable": { + "visible": false, + "maxColor": "#ffffff00", + "minColor": "#ff0000ff", + "unknownColor": "#80808000" + }, + "/image": { + "visible": true, + "cameraInfoTopic": "/camera_info", + "distance": 1.5, + "planarProjectionFactor": 0, + "color": "#e7e1ffff" + }, + "/camera_info": { + "visible": true, + "distance": 1.5, + "planarProjectionFactor": 0 + }, + "/local_costmap": { + "visible": false + }, + "/navigation_goal": { + "visible": true + }, + "/debug_camera_optical_points": { + "stixelsEnabled": false, + "visible": false, + "pointSize": 0.07, + "pointShape": "cube", + "colorField": "z", + "colorMode": "colormap", + "colorMap": "turbo" + }, + "/debug_world_points": { + "visible": false, + "colorField": "z", + "colorMode": "colormap", + "colorMap": "rainbow", + "pointShape": "cube" + }, + "/filtered_points_suitcase_0": { + "visible": false, + "colorField": "intensity", + "colorMode": "flat", + "colorMap": "turbo", + "pointShape": "cube", + "flatColor": "#ff0808ff", + "cubeSize": 0.149, + "pointSize": 28.57 + }, + "/filtered_points_combined": { + "visible": true, + "flatColor": "#ff0000ff", + "pointShape": "cube", + "pointSize": 6.63, + "colorField": "z", + "colorMode": "gradient", + "colorMap": "rainbow", + "cubeSize": 0.35, + "gradient": [ + "#d100caff", + "#ff0000ff" + ] + }, + "/filtered_points_box_7": { + "visible": true, + "flatColor": "#fbfaffff", + "colorField": "intensity", + "colorMode": "colormap", + "colorMap": "turbo" + }, + "/filtered_pointcloud": { + "visible": true, + "colorField": "z", + "colorMode": "flat", + "colorMap": "turbo", + "flatColor": "#ff0000ff", + "pointSize": 40.21, + "pointShape": "cube", + "cubeSize": 0.1, + "cubeOutline": true + }, + "/detected": { + "visible": false, + "pointSize": 1.5, + "pointShape": "cube", + "cubeSize": 0.118, + "colorField": "intensity", + "colorMode": "flat", + "colorMap": "turbo", + "flatColor": "#d70000ff", + "cubeOutline": true + }, + "/detected_0": { + "visible": true, + "colorField": "intensity", + "colorMode": "flat", + "colorMap": "turbo", + "pointSize": 1.6, + "pointShape": "cube", + "cubeSize": 0.1, + "flatColor": "#e00000ff", + "stixelsEnabled": false, + "decayTime": 0, + "cubeOutline": true + }, + "/detected_1": { + "visible": true, + "colorField": "intensity", + "colorMode": "flat", + "colorMap": "turbo", + "pointShape": "cube", + "cubeSize": 0.1, + "flatColor": "#00ff15ff", + "cubeOutline": true + }, + "/image_detected_0": { + "visible": false + }, + "/detected/pointcloud/1": { + "visible": true, + "colorField": "intensity", + "colorMode": "flat", + "colorMap": "turbo", + "pointShape": "cube", + "flatColor": "#15ff00ff", + "pointSize": 0.1, + "cubeSize": 0.05, + "cubeOutline": true + }, + "/detected/pointcloud/2": { + "visible": true, + "colorField": "intensity", + "colorMode": "flat", + "colorMap": "turbo", + "pointShape": "cube", + "flatColor": "#00ffe1ff", + "pointSize": 10, + "cubeOutline": true, + "cubeSize": 0.05 + }, + "/detected/pointcloud/0": { + "visible": true, + "colorField": "intensity", + "colorMode": "flat", + "colorMap": "turbo", + "pointShape": "cube", + "flatColor": "#ff0000ff", + "cubeOutline": true, + "cubeSize": 0.04 + }, + "/detected/image/0": { + "visible": false + }, + "/detected/image/3": { + "visible": false + }, + "/detected/pointcloud/3": { + "visible": true, + "pointSize": 1.5, + "pointShape": "cube", + "cubeSize": 0.1, + "flatColor": "#00fffaff", + "colorField": "intensity", + "colorMode": "flat", + "colorMap": "turbo" + }, + "/detected/image/1": { + "visible": false + }, + "/registered_scan": { + "visible": true, + "colorField": "z", + "colorMode": "colormap", + "colorMap": "turbo", + "pointShape": "circle", + "pointSize": 2 + }, + "/image/camera_info": { + "visible": true, + "distance": 2 + }, + "/map": { + "visible": true, + "colorField": "z", + "colorMode": "colormap", + "colorMap": "turbo", + "pointShape": "square", + "cubeSize": 0.13, + "explicitAlpha": 1, + "pointSize": 1, + "decayTime": 2 + }, + "/detection3d/markers": { + "visible": true, + "color": "#88ff00ff", + "showOutlines": true, + "selectedIdVariable": "" + }, + "/foxglove/scene_update": { + "visible": true + }, + "/scene_update": { + "visible": true, + "showOutlines": true, + "computeVertexNormals": true + }, + "/target": { + "visible": true, + "axisScale": 1 + }, + "/goal_pose": { + "visible": true, + "axisScale": 0.5 + } + }, + "publish": { + "type": "point", + "poseTopic": "/move_base_simple/goal", + "pointTopic": "/clicked_point", + "poseEstimateTopic": "/estimate", + "poseEstimateXDeviation": 0.5, + "poseEstimateYDeviation": 0.5, + "poseEstimateThetaDeviation": 0.26179939 + }, + "imageMode": {}, + "foxglovePanelTitle": "", + "followTf": "map" + }, + "Image!3mnp456": { + "cameraState": { + "distance": 20, + "perspective": true, + "phi": 60, + "target": [ + 0, + 0, + 0 + ], + "targetOffset": [ + 0, + 0, + 0 + ], + "targetOrientation": [ + 0, + 0, + 0, + 1 + ], + "thetaOffset": 45, + "fovy": 45, + "near": 0.5, + "far": 5000 + }, + "followMode": "follow-pose", + "scene": { + "enableStats": false, + "transforms": { + "showLabel": false, + "visible": true + } + }, + "transforms": { + "frame:world": { + "visible": true + }, + "frame:camera_optical": { + "visible": false + }, + "frame:camera_link": { + "visible": false + }, + "frame:base_link": { + "visible": false + } + }, + "topics": { + "/lidar": { + "visible": false, + "colorField": "z", + "colorMode": "colormap", + "colorMap": "turbo", + "pointSize": 6, + "explicitAlpha": 0.6, + "pointShape": "circle", + "cubeSize": 0.016 + }, + "/odom": { + "visible": false + }, + "/local_costmap": { + "visible": false + }, + "/global_costmap": { + "visible": false, + "minColor": "#ffffff00" + }, + "/detected_0": { + "visible": true, + "colorField": "intensity", + "colorMode": "flat", + "colorMap": "turbo", + "pointSize": 23, + "pointShape": "cube", + "cubeSize": 0.04, + "flatColor": "#ff0000ff", + "stixelsEnabled": false + }, + "/detected_1": { + "visible": true, + "colorField": "intensity", + "colorMode": "flat", + "colorMap": "turbo", + "pointSize": 20.51, + "flatColor": "#34ff00ff", + "pointShape": "cube", + "cubeSize": 0.04, + "cubeOutline": false + }, + "/filtered_pointcloud": { + "visible": true, + "colorField": "intensity", + "colorMode": "flat", + "colorMap": "rainbow", + "pointSize": 1.5, + "pointShape": "cube", + "flatColor": "#ff0000ff", + "cubeSize": 0.1 + }, + "/global_map": { + "visible": false, + "colorField": "z", + "colorMode": "colormap", + "colorMap": "turbo", + "pointShape": "cube", + "pointSize": 5 + }, + "/detected/pointcloud/1": { + "visible": false, + "colorField": "intensity", + "colorMode": "flat", + "colorMap": "turbo", + "pointShape": "cube", + "cubeSize": 0.01, + "flatColor": "#00ff1eff", + "pointSize": 15, + "decayTime": 0, + "cubeOutline": true + }, + "/detected/pointcloud/2": { + "visible": false, + "colorField": "intensity", + "colorMode": "flat", + "colorMap": "turbo", + "pointShape": "circle", + "cubeSize": 0.1, + "flatColor": "#00fbffff", + "pointSize": 0.01 + }, + "/detected/pointcloud/0": { + "visible": false, + "colorField": "intensity", + "colorMode": "flat", + "colorMap": "turbo", + "pointShape": "cube", + "flatColor": "#ff0000ff", + "pointSize": 15, + "cubeOutline": true, + "cubeSize": 0.03 + }, + "/registered_scan": { + "visible": false, + "colorField": "z", + "colorMode": "colormap", + "colorMap": "turbo", + "pointShape": "circle", + "pointSize": 6.49 + }, + "/detection3d/markers": { + "visible": false + }, + "/foxglove/scene_update": { + "visible": true + }, + "/scene_update": { + "visible": false + }, + "/map": { + "visible": false, + "colorField": "z", + "colorMode": "colormap", + "colorMap": "turbo", + "pointSize": 8 + } + }, + "layers": {}, + "publish": { + "type": "point", + "poseTopic": "/move_base_simple/goal", + "pointTopic": "/clicked_point", + "poseEstimateTopic": "/initialpose", + "poseEstimateXDeviation": 0.5, + "poseEstimateYDeviation": 0.5, + "poseEstimateThetaDeviation": 0.26179939 + }, + "imageMode": { + "imageTopic": "/image", + "colorMode": "gradient", + "annotations": { + "/detections": { + "visible": true + }, + "/annotations": { + "visible": true + } + }, + "synchronize": false, + "rotation": 0, + "calibrationTopic": "/camera_info" + }, + "foxglovePanelTitle": "" + }, + "Plot!3heo336": { + "paths": [ + { + "timestampMethod": "publishTime", + "value": "/image.header.stamp.sec", + "enabled": true, + "color": "#4e98e2", + "label": "image", + "showLine": false + }, + { + "timestampMethod": "publishTime", + "value": "/map.header.stamp.sec", + "enabled": true, + "color": "#f5774d", + "label": "lidar", + "showLine": false + }, + { + "timestampMethod": "publishTime", + "value": "/tf.transforms[0].header.stamp.sec", + "enabled": true, + "color": "#f7df71", + "label": "tf", + "showLine": false + }, + { + "timestampMethod": "publishTime", + "value": "/odom.header.stamp.sec", + "enabled": true, + "color": "#5cd6a9", + "label": "odom", + "showLine": false + } + ], + "showXAxisLabels": true, + "showYAxisLabels": true, + "showLegend": true, + "legendDisplay": "floating", + "showPlotValuesInLegend": false, + "isSynced": true, + "xAxisVal": "timestamp", + "sidebarDimension": 240 + }, + "Image!47pi3ov": { + "cameraState": { + "distance": 20, + "perspective": true, + "phi": 60, + "target": [ + 0, + 0, + 0 + ], + "targetOffset": [ + 0, + 0, + 0 + ], + "targetOrientation": [ + 0, + 0, + 0, + 1 + ], + "thetaOffset": 45, + "fovy": 45, + "near": 0.5, + "far": 5000 + }, + "followMode": "follow-pose", + "scene": {}, + "transforms": {}, + "topics": {}, + "layers": {}, + "publish": { + "type": "point", + "poseTopic": "/move_base_simple/goal", + "pointTopic": "/clicked_point", + "poseEstimateTopic": "/initialpose", + "poseEstimateXDeviation": 0.5, + "poseEstimateYDeviation": 0.5, + "poseEstimateThetaDeviation": 0.26179939 + }, + "imageMode": { + "imageTopic": "/detected/image/0" + } + }, + "Image!4kk50gw": { + "cameraState": { + "distance": 20, + "perspective": true, + "phi": 60, + "target": [ + 0, + 0, + 0 + ], + "targetOffset": [ + 0, + 0, + 0 + ], + "targetOrientation": [ + 0, + 0, + 0, + 1 + ], + "thetaOffset": 45, + "fovy": 45, + "near": 0.5, + "far": 5000 + }, + "followMode": "follow-pose", + "scene": {}, + "transforms": {}, + "topics": {}, + "layers": {}, + "publish": { + "type": "point", + "poseTopic": "/move_base_simple/goal", + "pointTopic": "/clicked_point", + "poseEstimateTopic": "/initialpose", + "poseEstimateXDeviation": 0.5, + "poseEstimateYDeviation": 0.5, + "poseEstimateThetaDeviation": 0.26179939 + }, + "imageMode": { + "imageTopic": "/detected/image/1" + } + }, + "Image!2348e0b": { + "cameraState": { + "distance": 20, + "perspective": true, + "phi": 60, + "target": [ + 0, + 0, + 0 + ], + "targetOffset": [ + 0, + 0, + 0 + ], + "targetOrientation": [ + 0, + 0, + 0, + 1 + ], + "thetaOffset": 45, + "fovy": 45, + "near": 0.5, + "far": 5000 + }, + "followMode": "follow-pose", + "scene": {}, + "transforms": {}, + "topics": {}, + "layers": {}, + "publish": { + "type": "point", + "poseTopic": "/move_base_simple/goal", + "pointTopic": "/clicked_point", + "poseEstimateTopic": "/initialpose", + "poseEstimateXDeviation": 0.5, + "poseEstimateYDeviation": 0.5, + "poseEstimateThetaDeviation": 0.26179939 + }, + "imageMode": { + "imageTopic": "/detected/image/2", + "synchronize": false + } + }, + "StateTransitions!pu21x4": { + "paths": [ + { + "value": "/annotations.texts[1].text", + "timestampMethod": "receiveTime", + "label": "detection1" + }, + { + "value": "/annotations.texts[3].text", + "timestampMethod": "receiveTime", + "label": "detection2" + }, + { + "value": "/annotations.texts[5].text", + "timestampMethod": "receiveTime", + "label": "detection3" + } + ], + "isSynced": true, + "showPoints": true, + "timeWindowMode": "automatic" + } + }, + "globalVariables": {}, + "userNodes": {}, + "playbackConfig": { + "speed": 1 + }, + "drawerConfig": { + "tracks": [] + }, + "layout": { + "first": { + "first": "3D!18i6zy7", + "second": "Image!3mnp456", + "direction": "row", + "splitPercentage": 47.265625 + }, + "second": { + "first": "Plot!3heo336", + "second": { + "first": { + "first": "Image!47pi3ov", + "second": { + "first": "Image!4kk50gw", + "second": "Image!2348e0b", + "direction": "row" + }, + "direction": "row", + "splitPercentage": 33.06523681858802 + }, + "second": "StateTransitions!pu21x4", + "direction": "column", + "splitPercentage": 86.63101604278076 + }, + "direction": "row", + "splitPercentage": 46.39139486467731 + }, + "direction": "column", + "splitPercentage": 81.62970106075217 + } +} diff --git a/assets/framecount.mp4 b/assets/framecount.mp4 new file mode 100644 index 0000000000..759ee6ab27 --- /dev/null +++ b/assets/framecount.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:92256a9cceda2410ec26d58b92f457070e54deb39bf3e6e5aca174e2c7cff216 +size 34548239 diff --git a/assets/license_file_header.txt b/assets/license_file_header.txt new file mode 100644 index 0000000000..a02322f92f --- /dev/null +++ b/assets/license_file_header.txt @@ -0,0 +1,13 @@ +Copyright 2025 Dimensional Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/assets/simple_demo.mp4 b/assets/simple_demo.mp4 new file mode 100644 index 0000000000..cb8a635e78 --- /dev/null +++ b/assets/simple_demo.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ff2459b880baaa509e8e0de8a45e8da48ebf7cb28d4927c62b10906baa83bda0 +size 50951922 diff --git a/assets/simple_demo_small.gif b/assets/simple_demo_small.gif new file mode 100644 index 0000000000..3c2cf54ef4 --- /dev/null +++ b/assets/simple_demo_small.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9a2b9a95d5b27cbc135cb84f6c6bc2131fa234403466befd2ee8ea81e2b2de45 +size 33374003 diff --git a/assets/test.png.REMOVED.git-id b/assets/test.png.REMOVED.git-id new file mode 100644 index 0000000000..1c0b4f7200 --- /dev/null +++ b/assets/test.png.REMOVED.git-id @@ -0,0 +1 @@ +5fcad46f3f747516107a37d353116e1234758116 diff --git a/assets/trimmed_video.mov.REMOVED.git-id b/assets/trimmed_video.mov.REMOVED.git-id deleted file mode 100644 index bcb0f67e9e..0000000000 --- a/assets/trimmed_video.mov.REMOVED.git-id +++ /dev/null @@ -1 +0,0 @@ -278582f74e0c093f1cf2b2f85adee53cade30f63 \ No newline at end of file diff --git a/assets/trimmed_video_office.mov b/assets/trimmed_video_office.mov new file mode 100644 index 0000000000..a3072be8fc --- /dev/null +++ b/assets/trimmed_video_office.mov @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d72f0cf95ce1728b4a0855d6b3fe4573f5e2e86fae718720c19a84198bdcbf9d +size 2311156 diff --git a/assets/video-f30-480p.mp4.REMOVED.git-id b/assets/video-f30-480p.mp4.REMOVED.git-id deleted file mode 100644 index b8ccffab9f..0000000000 --- a/assets/video-f30-480p.mp4.REMOVED.git-id +++ /dev/null @@ -1 +0,0 @@ -a1aa99e44f3ec2d5fa4f5045ee6301f172ef94f9 \ No newline at end of file diff --git a/base-requirements.txt b/base-requirements.txt new file mode 100644 index 0000000000..68b485fb9a --- /dev/null +++ b/base-requirements.txt @@ -0,0 +1,2 @@ +torch==2.0.1 +torchvision==0.15.2 diff --git a/bin/agent_web b/bin/agent_web new file mode 100755 index 0000000000..210bf7dd3d --- /dev/null +++ b/bin/agent_web @@ -0,0 +1,2 @@ +#!/usr/bin/env bash +python3 /app/tests/test_planning_agent_web_interface.py diff --git a/bin/cuda/fix_ort.sh b/bin/cuda/fix_ort.sh new file mode 100755 index 0000000000..182f387364 --- /dev/null +++ b/bin/cuda/fix_ort.sh @@ -0,0 +1,30 @@ +#!/usr/bin/env bash +# This script fixes the onnxruntime <--> onnxruntime-gpu package clash +# that occurs when chromadb and other dependencies require the CPU-only +# onnxruntime package. It removes onnxruntime and reinstalls the GPU version. +set -euo pipefail + +: "${GPU_VER:=1.18.1}" + +python - </dev/null +} + +image_pull() { + docker pull "$IMAGE_NAME" +} + +ensure_image_downloaded() { + if ! image_exists "$1"; then + echo "Image ${IMAGE_NAME} not found. Pulling..." + image_pull "$1" + fi +} + +check_image_running() { + if docker ps -q --filter "ancestor=${IMAGE_NAME}" | grep -q .; then + return 0 + else + return 1 + fi +} + +stop_image() { + if check_image_running ${IMAGE_NAME}; then + echo "Stopping containers from image ${IMAGE_NAME}..." + docker stop $(docker ps -q --filter "ancestor=${IMAGE_NAME}") + else + echo "No containers from image ${IMAGE_NAME} are running." + fi +} + + +get_tag() { + local branch_name + branch_name=$(git rev-parse --abbrev-ref HEAD) + + case "${branch_name}" in + master) image_tag="latest" ;; + main) image_tag="latest" ;; + dev) image_tag="dev" ;; + *) + image_tag=$(echo "${branch_name}" \ + | tr '[:upper:]' '[:lower:]' \ + | sed -E 's#[^a-z0-9_.-]+#_#g' \ + | sed -E 's#^-+|-+$##g') + ;; + esac + echo "${image_tag}" +} + + +build_image() { + local image_tag + image_tag=$(get_tag) + + docker build \ + --build-arg GIT_COMMIT=$(git rev-parse --short HEAD) \ + --build-arg GIT_BRANCH=$(git rev-parse --abbrev-ref HEAD) \ + -t "ghcr.io/dimensionalos/dev:${image_tag}" -f docker/dev/Dockerfile . +} + +remove_image() { + local tag=$(get_tag) + docker rm -f "dimos-dev-${tag}" 2>/dev/null || true +} + +devcontainer_install() { + # prompt user if we should install devcontainer + read -p "devcontainer CLI (https://github.com/devcontainers/cli) not found. Install into repo root? (y/n): " install_choice + if [[ "$install_choice" != "y" && "$install_choice" != "Y" ]]; then + echo "Devcontainer CLI installation aborted. Please install manually" + exit 1 + fi + + cd "$REPO_ROOT/bin/" + if [[ ! -d "$REPO_ROOT/bin/node_modules" ]]; then + npm init -y 1>/dev/null + fi + npm install @devcontainers/cli 1>&2 + if [[ $? -ne 0 ]]; then + echo "Failed to install devcontainer CLI. Please install it manually." + exit 1 + fi + echo $REPO_ROOT/bin/node_modules/.bin/devcontainer +} + + + +find_devcontainer_bin() { + local bin_path + bin_path=$(command -v devcontainer) + + if [[ -z "$bin_path" ]]; then + bin_path="$REPO_ROOT/bin/node_modules/.bin/devcontainer" + fi + + if [[ -x "$bin_path" ]]; then + echo "$bin_path" + else + devcontainer_install + fi +} + +# Passes all arguments to devcontainer command, ensuring: +# - devcontainer CLI is installed +# - docker image is running +# - the workspace folder is set to the repository root +run_devcontainer() { + local devcontainer_bin + devcontainer_bin=$(find_devcontainer_bin) + + if ! check_image_running; then + ensure_image_downloaded + $devcontainer_bin up --workspace-folder="$REPO_ROOT" --gpu-availability="detect" + fi + + exec $devcontainer_bin $1 --workspace-folder="$REPO_ROOT" "${@:2}" +} + +if [[ $# -eq 0 ]]; then + run_devcontainer exec bash +else + case "$1" in + build) + build_image + shift + ;; + stop) + stop_image + shift + ;; + down) + stop_image + shift + ;; + pull) + docker pull ghcr.io/dimensionalos/dev:dev + shift + ;; + *) + run_devcontainer exec "$@" + shift + ;; + esac +fi diff --git a/bin/dockerbuild b/bin/dockerbuild new file mode 100755 index 0000000000..b02e10d5ca --- /dev/null +++ b/bin/dockerbuild @@ -0,0 +1,32 @@ +#!/bin/bash + +# Exit on error +set -e + +# Check for directory argument +if [ $# -lt 1 ]; then + echo "Usage: $0 [additional-docker-build-args]" + echo "Example: $0 base-ros-python --no-cache" + exit 1 +fi + +# Get the docker directory name +DOCKER_DIR=$1 +shift # Remove the first argument, leaving any additional args + +# Check if directory exists +if [ ! -d "docker/$DOCKER_DIR" ]; then + echo "Error: Directory docker/$DOCKER_DIR does not exist" + exit 1 +fi + +# Set image name based on directory +IMAGE_NAME="ghcr.io/dimensionalos/$DOCKER_DIR" + +echo "Building image $IMAGE_NAME from docker/$DOCKER_DIR..." +echo "Build context: $(pwd)" + +# Build the docker image with the current directory as context +docker build -t "$IMAGE_NAME" -f "docker/$DOCKER_DIR/Dockerfile" "$@" . + +echo "Successfully built $IMAGE_NAME" diff --git a/bin/filter-errors-after-date b/bin/filter-errors-after-date new file mode 100755 index 0000000000..03c7de0ca7 --- /dev/null +++ b/bin/filter-errors-after-date @@ -0,0 +1,77 @@ +#!/usr/bin/env python3 + +# Used to filter errors to only show lines committed on or after a specific date +# Can be chained with filter-errors-for-user + +from datetime import datetime +import re +import subprocess +import sys + +_blame = {} + + +def _is_after_date(file, line_no, cutoff_date): + if file not in _blame: + _blame[file] = _get_git_blame_dates_for_file(file) + line_date = _blame[file].get(line_no) + if not line_date: + return False + return line_date >= cutoff_date + + +def _get_git_blame_dates_for_file(file_name): + try: + result = subprocess.run( + ["git", "blame", "--date=short", file_name], + capture_output=True, + text=True, + check=True, + ) + + blame_map = {} + # Each line looks like: ^abc123 (Author Name 2024-01-01 1) code + blame_pattern = re.compile(r"^[^\(]+\([^\)]+(\d{4}-\d{2}-\d{2})") + + for i, line in enumerate(result.stdout.split("\n")): + if not line: + continue + match = blame_pattern.match(line) + if match: + date_str = match.group(1) + blame_map[str(i + 1)] = date_str + + return blame_map + except subprocess.CalledProcessError: + return {} + + +def main(): + if len(sys.argv) != 2: + print("Usage: filter-errors-after-date ", file=sys.stderr) + print(" Example: filter-errors-after-date 2025-10-04", file=sys.stderr) + sys.exit(1) + + cutoff_date = sys.argv[1] + + try: + datetime.strptime(cutoff_date, "%Y-%m-%d") + except ValueError: + print(f"Error: Invalid date format '{cutoff_date}'. Use YYYY-MM-DD", file=sys.stderr) + sys.exit(1) + + for line in sys.stdin.readlines(): + split = re.findall(r"^([^:]+):(\d+):(.*)", line) + if not split or len(split[0]) != 3: + continue + + file, line_no = split[0][:2] + if not file.startswith("dimos/"): + continue + + if _is_after_date(file, line_no, cutoff_date): + print(":".join(split[0])) + + +if __name__ == "__main__": + main() diff --git a/bin/filter-errors-for-user b/bin/filter-errors-for-user new file mode 100755 index 0000000000..045b30b293 --- /dev/null +++ b/bin/filter-errors-for-user @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 + +# Used when running `./bin/mypy-strict --for-me` + +import re +import subprocess +import sys + +_blame = {} + + +def _is_for_user(file, line_no, user_email): + if file not in _blame: + _blame[file] = _get_git_blame_for_file(file) + return _blame[file][line_no] == user_email + + +def _get_git_blame_for_file(file_name): + try: + result = subprocess.run( + ["git", "blame", "--show-email", "-e", file_name], + capture_output=True, + text=True, + check=True, + ) + + blame_map = {} + # Each line looks like: ^abc123 ( 2024-01-01 12:00:00 +0000 1) code + blame_pattern = re.compile(r"^[^\(]+\(<([^>]+)>") + + for i, line in enumerate(result.stdout.split("\n")): + if not line: + continue + match = blame_pattern.match(line) + if match: + email = match.group(1) + blame_map[str(i + 1)] = email + + return blame_map + except subprocess.CalledProcessError: + return {} + + +def main(): + if len(sys.argv) != 2: + print("Usage: filter-errors-for-user ", file=sys.stderr) + sys.exit(1) + + user_email = sys.argv[1] + + for line in sys.stdin.readlines(): + split = re.findall(r"^([^:]+):(\d+):(.*)", line) + if not split or len(split[0]) != 3: + continue + file, line_no = split[0][:2] + if not file.startswith("dimos/"): + continue + if _is_for_user(file, line_no, user_email): + print(":".join(split[0])) + + +if __name__ == "__main__": + main() diff --git a/bin/lfs_check b/bin/lfs_check new file mode 100755 index 0000000000..0ddb847d56 --- /dev/null +++ b/bin/lfs_check @@ -0,0 +1,42 @@ +#!/bin/bash + +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' + +ROOT=$(git rev-parse --show-toplevel) +cd $ROOT + +new_data=() + +# Enable nullglob to make globs expand to nothing when not matching +shopt -s nullglob + +# Iterate through all directories in data/ +for dir_path in data/*; do + + # Extract directory name + dir_name=$(basename "$dir_path") + + # Skip .lfs directory if it exists + [ "$dir_name" = ".lfs" ] && continue + + # Define compressed file path + compressed_file="data/.lfs/${dir_name}.tar.gz" + + # Check if compressed file already exists + if [ -f "$compressed_file" ]; then + continue + fi + + new_data+=("$dir_name") +done + +if [ ${#new_data[@]} -gt 0 ]; then + echo -e "${RED}✗${NC} New test data detected at /data:" + echo -e " ${GREEN}${new_data[@]}${NC}" + echo -e "\nEither delete or run ${GREEN}./bin/lfs_push${NC}" + echo -e "(lfs_push will compress the files into /data/.lfs/, upload to LFS, and add them to your commit)" + exit 1 +fi diff --git a/bin/lfs_push b/bin/lfs_push new file mode 100755 index 0000000000..0d9e01d743 --- /dev/null +++ b/bin/lfs_push @@ -0,0 +1,97 @@ +#!/bin/bash +# Compresses directories in data/* into data/.lfs/dirname.tar.gz +# Pushes to LFS + +set -e + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' # No Color + +#echo -e "${GREEN}Running test data compression check...${NC}" + +ROOT=$(git rev-parse --show-toplevel) +cd $ROOT + +# Check if data/ exists +if [ ! -d "data/" ]; then + echo -e "${YELLOW}No data directory found, skipping compression.${NC}" + exit 0 +fi + +# Track if any compression was performed +compressed_dirs=() + +# Iterate through all directories in data/ +for dir_path in data/*; do + # Skip if no directories found (glob didn't match) + [ ! "$dir_path" ] && continue + + # Extract directory name + dir_name=$(basename "$dir_path") + + # Skip .lfs directory if it exists + [ "$dir_name" = ".lfs" ] && continue + + # Define compressed file path + compressed_file="data/.lfs/${dir_name}.tar.gz" + + # Check if compressed file already exists + if [ -f "$compressed_file" ]; then + continue + fi + + echo -e " ${YELLOW}Compressing${NC} $dir_path -> $compressed_file" + + # Show directory size before compression + dir_size=$(du -sh "$dir_path" | cut -f1) + echo -e " Data size: ${YELLOW}$dir_size${NC}" + + # Create compressed archive with progress bar + # Use tar with gzip compression, excluding hidden files and common temp files + tar -czf "$compressed_file" \ + --exclude='*.tmp' \ + --exclude='*.temp' \ + --exclude='.DS_Store' \ + --exclude='Thumbs.db' \ + --checkpoint=1000 \ + --checkpoint-action=dot \ + -C "data/" \ + "$dir_name" + + if [ $? -eq 0 ]; then + # Show compressed file size + compressed_size=$(du -sh "$compressed_file" | cut -f1) + echo -e " ${GREEN}✓${NC} Successfully compressed $dir_name (${GREEN}$dir_size${NC} → ${GREEN}$compressed_size${NC})" + compressed_dirs+=("$dir_name") + + # Add the compressed file to git LFS tracking + git add -f "$compressed_file" + + echo -e " ${GREEN}✓${NC} git-add $compressed_file" + + else + echo -e " ${RED}✗${NC} Failed to compress $dir_name" + exit 1 + fi +done + +if [ ${#compressed_dirs[@]} -gt 0 ]; then + # Create commit message with compressed directory names + if [ ${#compressed_dirs[@]} -eq 1 ]; then + commit_msg="Auto-compress test data: ${compressed_dirs[0]}" + else + # Join array elements with commas + dirs_list=$(IFS=', '; echo "${compressed_dirs[*]}") + commit_msg="Auto-compress test data: ${dirs_list}" + fi + + #git commit -m "$commit_msg" + echo -e "${GREEN}✓${NC} Compressed file references added. Uploading..." + git lfs push origin $(git branch --show-current) + echo -e "${GREEN}✓${NC} Uploaded to LFS" +else + echo -e "${GREEN}✓${NC} No test data to compress" +fi diff --git a/bin/mypy-strict b/bin/mypy-strict new file mode 100755 index 0000000000..33e1f0c798 --- /dev/null +++ b/bin/mypy-strict @@ -0,0 +1,106 @@ +#!/bin/bash +# +# Run mypy with strict settings on the dimos codebase. +# +# Usage: +# ./bin/mypy-strict # Run mypy and show all errors +# ./bin/mypy-strict --user me # Filter for your git user.email +# ./bin/mypy-strict --after cutoff # Filter for lines committed on or after 2025-10-08 +# ./bin/mypy-strict --after 2025-11-11 # Filter for lines committed on or after specific date +# ./bin/mypy-strict --user me --after cutoff # Chain filters +# + +set -euo pipefail + +ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" + +cd "$ROOT" + +. .venv/bin/activate + +run_mypy() { + export MYPYPATH=/opt/ros/jazzy/lib/python3.12/site-packages + + mypy_args=( + --show-error-codes + --hide-error-context + --no-pretty + dimos + ) + mypy "${mypy_args[@]}" +} + +main() { + local user_email="none" + local after_date="" + local in_this_branch="" + + # Parse arguments + while [[ $# -gt 0 ]]; do + case "$1" in + --user) + if [[ $# -lt 2 ]]; then + echo "Error: --user requires an argument" >&2 + exit 1 + fi + case "$2" in + me) + user_email="$(git config user.email || echo none)" + ;; + all) + user_email="none" + ;; + *) + user_email="$2" + ;; + esac + shift 2 + ;; + --after) + if [[ $# -lt 2 ]]; then + echo "Error: --after requires an argument" >&2 + exit 1 + fi + case "$2" in + cutoff) + after_date="2025-10-10" + ;; + start) + after_date="" + ;; + *) + after_date="$2" + ;; + esac + shift 2 + ;; + --in-this-branch) + in_this_branch=true + shift 1 + ;; + *) + echo "Error: Unknown argument '$1'" >&2 + exit 1 + ;; + esac + done + + # Build filter pipeline + local pipeline="run_mypy" + + if [[ -n "$after_date" ]]; then + pipeline="$pipeline | ./bin/filter-errors-after-date '$after_date'" + fi + + if [[ "$user_email" != "none" ]]; then + pipeline="$pipeline | ./bin/filter-errors-for-user '$user_email'" + fi + + if [[ "$in_this_branch" ]]; then + pipeline="$pipeline | grep -Ff <(git diff --name-only dev..HEAD) -" + fi + + eval "$pipeline" +} + +main "$@" diff --git a/bin/robot-debugger b/bin/robot-debugger new file mode 100755 index 0000000000..165a546a0c --- /dev/null +++ b/bin/robot-debugger @@ -0,0 +1,36 @@ +#!/bin/bash + +# Control the robot with a python shell (for debugging). +# +# You have to start the robot run file with: +# +# ROBOT_DEBUGGER=true python +# +# And now start this script +# +# $ ./bin/robot-debugger +# >>> robot.explore() +# True +# >>> + + +exec python -i <(cat < 0: + print("\nConnected.") + break + except ConnectionRefusedError: + print("Not started yet. Trying again...") + time.sleep(2) +else: + print("Failed to connect. Is it started?") + exit(1) + +robot = c.root.robot() +EOF +) diff --git a/bin/ros b/bin/ros new file mode 100755 index 0000000000..d0349a9d2c --- /dev/null +++ b/bin/ros @@ -0,0 +1,2 @@ +#!/usr/bin/env bash +ros2 launch go2_robot_sdk robot.launch.py diff --git a/data/.lfs/ab_lidar_frames.tar.gz b/data/.lfs/ab_lidar_frames.tar.gz new file mode 100644 index 0000000000..38c61cd506 --- /dev/null +++ b/data/.lfs/ab_lidar_frames.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ab4efaf5d7d4303424868fecaf10083378007adf20244fd17ed934e37f2996da +size 116271 diff --git a/data/.lfs/assets.tar.gz b/data/.lfs/assets.tar.gz new file mode 100644 index 0000000000..b7a2fcbd1c --- /dev/null +++ b/data/.lfs/assets.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7b14b01f5c907f117331213abfce9ef5d0c41d0524e14327b5cc706520fb2035 +size 2306191 diff --git a/data/.lfs/cafe-smol.jpg.tar.gz b/data/.lfs/cafe-smol.jpg.tar.gz new file mode 100644 index 0000000000..a05beb4900 --- /dev/null +++ b/data/.lfs/cafe-smol.jpg.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dd0c1e5aa5e8ec856cb471c5ed256c2d3a5633ed9a1e052291680eb86bf89a5e +size 8298 diff --git a/data/.lfs/cafe.jpg.tar.gz b/data/.lfs/cafe.jpg.tar.gz new file mode 100644 index 0000000000..dbb2d970a1 --- /dev/null +++ b/data/.lfs/cafe.jpg.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b8cf30439b41033ccb04b09b9fc8388d18fb544d55b85c155dbf85700b9e7603 +size 136165 diff --git a/data/.lfs/chair-image.png.tar.gz b/data/.lfs/chair-image.png.tar.gz new file mode 100644 index 0000000000..1a2aab4cf5 --- /dev/null +++ b/data/.lfs/chair-image.png.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1f3478f472b5750f118cf7225c2028beeaae41f1b4b726c697ac8c9b004eccbf +size 48504 diff --git a/data/.lfs/g1_zed.tar.gz b/data/.lfs/g1_zed.tar.gz new file mode 100644 index 0000000000..4029f48204 --- /dev/null +++ b/data/.lfs/g1_zed.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:955094035b3ac1edbc257ca1d24fa131f79ac6f502c8b35cc50329025c421dbe +size 1029559759 diff --git a/data/.lfs/lcm_msgs.tar.gz b/data/.lfs/lcm_msgs.tar.gz new file mode 100644 index 0000000000..2b2f28c252 --- /dev/null +++ b/data/.lfs/lcm_msgs.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:245395d0c3e200fcfcea8de5de217f645362b145b200c81abc3862e0afc1aa7e +size 327201 diff --git a/data/.lfs/models_clip.tar.gz b/data/.lfs/models_clip.tar.gz new file mode 100644 index 0000000000..a4ab2b5f88 --- /dev/null +++ b/data/.lfs/models_clip.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:102f11bb0aa952b3cebc4491c5ed3f2122e8c38c76002e22400da4f1e5ca90c5 +size 392327708 diff --git a/data/.lfs/models_contact_graspnet.tar.gz b/data/.lfs/models_contact_graspnet.tar.gz new file mode 100644 index 0000000000..73dd44d033 --- /dev/null +++ b/data/.lfs/models_contact_graspnet.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:431c4611a9e096fd8b0a83fecda39c5a575e72fa933f7bd29ff8cfad5bbb5f9d +size 52149165 diff --git a/data/.lfs/models_fastsam.tar.gz b/data/.lfs/models_fastsam.tar.gz new file mode 100644 index 0000000000..77278f4323 --- /dev/null +++ b/data/.lfs/models_fastsam.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:682cb3816451bd73722cc430fdfce15bbe72a07e50ef2ea81ddaed61d1f22a25 +size 39971209 diff --git a/data/.lfs/models_mobileclip.tar.gz b/data/.lfs/models_mobileclip.tar.gz new file mode 100644 index 0000000000..874c94de07 --- /dev/null +++ b/data/.lfs/models_mobileclip.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1f8022e365d9e456dcbd3913d36bf8c68a4cd086eb777c92a773c8192cd8235d +size 277814612 diff --git a/data/.lfs/models_yolo.tar.gz b/data/.lfs/models_yolo.tar.gz new file mode 100644 index 0000000000..650d4617ca --- /dev/null +++ b/data/.lfs/models_yolo.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:01796d5884cf29258820cf0e617bf834e9ffb63d8a4c7a54eea802e96fe6a818 +size 72476992 diff --git a/data/.lfs/mujoco_sim.tar.gz b/data/.lfs/mujoco_sim.tar.gz new file mode 100644 index 0000000000..57833fbbc6 --- /dev/null +++ b/data/.lfs/mujoco_sim.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d178439569ed81dfad05455419dc51da2c52021313b6d7b9259d9e30946db7c6 +size 60186340 diff --git a/data/.lfs/office_building_1.tar.gz b/data/.lfs/office_building_1.tar.gz new file mode 100644 index 0000000000..0dc013bd94 --- /dev/null +++ b/data/.lfs/office_building_1.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:70aac31ca76597b3eee1ddfcbe2ba71d432fd427176f66d8281d75da76641f49 +size 1061581652 diff --git a/data/.lfs/office_lidar.tar.gz b/data/.lfs/office_lidar.tar.gz new file mode 100644 index 0000000000..849e9e3d49 --- /dev/null +++ b/data/.lfs/office_lidar.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f4958965334660c4765553afa38081f00a769c8adf81e599e63fabc866c490fd +size 28576272 diff --git a/data/.lfs/osm_map_test.tar.gz b/data/.lfs/osm_map_test.tar.gz new file mode 100644 index 0000000000..b29104ea17 --- /dev/null +++ b/data/.lfs/osm_map_test.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:25097f1bffebd2651f1f4ba93cb749998a064adfdc0cb004981b2317f649c990 +size 1062262 diff --git a/data/.lfs/raw_odometry_rotate_walk.tar.gz b/data/.lfs/raw_odometry_rotate_walk.tar.gz new file mode 100644 index 0000000000..ce8bb1d2b0 --- /dev/null +++ b/data/.lfs/raw_odometry_rotate_walk.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:396345f0cd7a94bb9d85540d4bbce01b027618972f83e713e4550abf1d6ec445 +size 15685 diff --git a/data/.lfs/replay_g1.tar.gz b/data/.lfs/replay_g1.tar.gz new file mode 100644 index 0000000000..67750bd0cf --- /dev/null +++ b/data/.lfs/replay_g1.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:19ad1c53c4f8f9414c0921b94cd4c87e81bf0ad676881339f15ae2d8a8619311 +size 557410250 diff --git a/data/.lfs/replay_g1_run.tar.gz b/data/.lfs/replay_g1_run.tar.gz new file mode 100644 index 0000000000..86368ec788 --- /dev/null +++ b/data/.lfs/replay_g1_run.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:00cf21f65a15994895150f74044f5d00d7aa873d24f071d249ecbd09cb8f2b26 +size 559554274 diff --git a/data/.lfs/rgbd_frames.tar.gz b/data/.lfs/rgbd_frames.tar.gz new file mode 100644 index 0000000000..8081c76961 --- /dev/null +++ b/data/.lfs/rgbd_frames.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:381b9fd296a885f5211a668df16c68581d2aee458c8734c3256a7461f0decccd +size 948391033 diff --git a/data/.lfs/unitree_go2_lidar_corrected.tar.gz b/data/.lfs/unitree_go2_lidar_corrected.tar.gz new file mode 100644 index 0000000000..013f6b3fe1 --- /dev/null +++ b/data/.lfs/unitree_go2_lidar_corrected.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:51a817f2b5664c9e2f2856293db242e030f0edce276e21da0edc2821d947aad2 +size 1212727745 diff --git a/data/.lfs/unitree_go2_office_walk2.tar.gz b/data/.lfs/unitree_go2_office_walk2.tar.gz new file mode 100644 index 0000000000..ea392c4b4c --- /dev/null +++ b/data/.lfs/unitree_go2_office_walk2.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d208cdf537ad01eed2068a4665e454ed30b30894bd9b35c14b4056712faeef5d +size 1693876005 diff --git a/data/.lfs/unitree_office_walk.tar.gz b/data/.lfs/unitree_office_walk.tar.gz new file mode 100644 index 0000000000..419489dbb1 --- /dev/null +++ b/data/.lfs/unitree_office_walk.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bee487130eb662bca73c7d84f14eaea091bd6d7c3f1bfd5173babf660947bdec +size 553620791 diff --git a/data/.lfs/unitree_raw_webrtc_replay.tar.gz b/data/.lfs/unitree_raw_webrtc_replay.tar.gz new file mode 100644 index 0000000000..d41ff5c48f --- /dev/null +++ b/data/.lfs/unitree_raw_webrtc_replay.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a02c622cfee712002afc097825ab5e963071471c3445a20a004ef3532cf59888 +size 756280504 diff --git a/data/.lfs/video.tar.gz b/data/.lfs/video.tar.gz new file mode 100644 index 0000000000..6c0e01a0bb --- /dev/null +++ b/data/.lfs/video.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:530d2132ef84df228af776bd2a2ef387a31858c63ea21c94fb49c7e579b366c0 +size 4322822 diff --git a/default.env b/default.env index e570b8b559..5098a60892 100644 --- a/default.env +++ b/default.env @@ -1 +1,15 @@ OPENAI_API_KEY= +HUGGINGFACE_ACCESS_TOKEN= +ALIBABA_API_KEY= +ANTHROPIC_API_KEY= +HF_TOKEN= +HUGGINGFACE_PRV_ENDPOINT= +ROBOT_IP= +CONN_TYPE=webrtc +WEBRTC_SERVER_HOST=0.0.0.0 +WEBRTC_SERVER_PORT=9991 +DISPLAY=:0 + +# Optional +#DIMOS_MAX_WORKERS= +TEST_RTSP_URL= diff --git a/dimOS.egg-info/PKG-INFO b/dimOS.egg-info/PKG-INFO deleted file mode 100644 index 16cffd96ea..0000000000 --- a/dimOS.egg-info/PKG-INFO +++ /dev/null @@ -1,5 +0,0 @@ -Metadata-Version: 2.1 -Name: dimos -Version: 0.0.0 -Summary: Coming soon -Author-email: Stash Pomichter diff --git a/dimOS.egg-info/SOURCES.txt b/dimOS.egg-info/SOURCES.txt deleted file mode 100644 index 2a64a65d11..0000000000 --- a/dimOS.egg-info/SOURCES.txt +++ /dev/null @@ -1,10 +0,0 @@ -pyproject.toml -dimOS.egg-info/PKG-INFO -dimOS.egg-info/SOURCES.txt -dimOS.egg-info/dependency_links.txt -dimOS.egg-info/top_level.txt -dimos/__init__.py -dimos.egg-info/PKG-INFO -dimos.egg-info/SOURCES.txt -dimos.egg-info/dependency_links.txt -dimos.egg-info/top_level.txt \ No newline at end of file diff --git a/dimOS.egg-info/dependency_links.txt b/dimOS.egg-info/dependency_links.txt deleted file mode 100644 index 8b13789179..0000000000 --- a/dimOS.egg-info/dependency_links.txt +++ /dev/null @@ -1 +0,0 @@ - diff --git a/dimOS.egg-info/top_level.txt b/dimOS.egg-info/top_level.txt deleted file mode 100644 index 70edfe204b..0000000000 --- a/dimOS.egg-info/top_level.txt +++ /dev/null @@ -1 +0,0 @@ -dimos diff --git a/dimos/__init__.py b/dimos/__init__.py index 8b13789179..e69de29bb2 100644 --- a/dimos/__init__.py +++ b/dimos/__init__.py @@ -1 +0,0 @@ - diff --git a/dimos/agents/agent.py b/dimos/agents/agent.py index 7480fedac6..6c27298807 100644 --- a/dimos/agents/agent.py +++ b/dimos/agents/agent.py @@ -1,239 +1,917 @@ -import base64 -from openai import OpenAI -from dotenv import load_dotenv -import cv2 -import reactivex as rx -from reactivex import operators as ops -from reactivex.disposable import CompositeDisposable +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Agent framework for LLM-based autonomous systems. + +This module provides a flexible foundation for creating agents that can: +- Process image and text inputs through LLM APIs +- Store and retrieve contextual information using semantic memory +- Handle tool/function calling +- Process streaming inputs asynchronously + +The module offers base classes (Agent, LLMAgent) and concrete implementations +like OpenAIAgent that connect to specific LLM providers. +""" + +from __future__ import annotations + +# Standard library imports +import json import os +import threading +from typing import TYPE_CHECKING, Any +# Third-party imports from dotenv import load_dotenv +from openai import NOT_GIVEN, OpenAI +from pydantic import BaseModel +from reactivex import Observable, Observer, create, empty, just, operators as RxOps +from reactivex.disposable import CompositeDisposable, Disposable +from reactivex.subject import Subject + +# Local imports +from dimos.agents.memory.chroma_impl import OpenAISemanticMemory +from dimos.agents.prompt_builder.impl import PromptBuilder +from dimos.agents.tokenizer.openai_tokenizer import OpenAITokenizer +from dimos.skills.skills import AbstractSkill, SkillLibrary +from dimos.stream.frame_processor import FrameProcessor +from dimos.stream.stream_merger import create_stream_merger +from dimos.stream.video_operators import Operators as MyOps, VideoOperators as MyVidOps +from dimos.utils.logging_config import setup_logger +from dimos.utils.threadpool import get_scheduler + +if TYPE_CHECKING: + from reactivex.scheduler import ThreadPoolScheduler + + from dimos.agents.memory.base import AbstractAgentSemanticMemory + from dimos.agents.tokenizer.base import AbstractTokenizer + +# Initialize environment variables load_dotenv() -import threading +# Initialize logger for the agent module +logger = setup_logger() + +# Constants +_TOKEN_BUDGET_PARTS = 4 # Number of parts to divide token budget +_MAX_SAVED_FRAMES = 100 # Maximum number of frames to save + +# ----------------------------------------------------------------------------- +# region Agent Base Class +# ----------------------------------------------------------------------------- class Agent: - def __init__(self, dev_name:str="NA", agent_type:str="Base"): + """Base agent that manages memory and subscriptions.""" + + def __init__( + self, + dev_name: str = "NA", + agent_type: str = "Base", + agent_memory: AbstractAgentSemanticMemory | None = None, + pool_scheduler: ThreadPoolScheduler | None = None, + ) -> None: + """ + Initializes a new instance of the Agent. + + Args: + dev_name (str): The device name of the agent. + agent_type (str): The type of the agent (e.g., 'Base', 'Vision'). + agent_memory (AbstractAgentSemanticMemory): The memory system for the agent. + pool_scheduler (ThreadPoolScheduler): The scheduler to use for thread pool operations. + If None, the global scheduler from get_scheduler() will be used. + """ self.dev_name = dev_name self.agent_type = agent_type + self.agent_memory = agent_memory or OpenAISemanticMemory() self.disposables = CompositeDisposable() - - # def process_frame(self): - # """Processes a single frame. Should be implemented by subclasses.""" - # raise NotImplementedError("Frame processing must be handled by subclass") + self.pool_scheduler = pool_scheduler if pool_scheduler else get_scheduler() - def dispose_all(self): + def dispose_all(self) -> None: """Disposes of all active subscriptions managed by this agent.""" if self.disposables: self.disposables.dispose() else: - print("No disposables to dispose.") + logger.info("No disposables to dispose.") + + +# endregion Agent Base Class + + +# ----------------------------------------------------------------------------- +# region LLMAgent Base Class (Generic LLM Agent) +# ----------------------------------------------------------------------------- +class LLMAgent(Agent): + """Generic LLM agent containing common logic for LLM-based agents. + + This class implements functionality for: + - Updating the query + - Querying the agent's memory (for RAG) + - Building prompts via a prompt builder + - Handling tooling callbacks in responses + - Subscribing to image and query streams + - Emitting responses as an observable stream + + Subclasses must implement the `_send_query` method, which is responsible + for sending the prompt to a specific LLM API. + + Attributes: + query (str): The current query text to process. + prompt_builder (PromptBuilder): Handles construction of prompts. + system_query (str): System prompt for RAG context situations. + image_detail (str): Detail level for image processing ('low','high','auto'). + max_input_tokens_per_request (int): Maximum input token count. + max_output_tokens_per_request (int): Maximum output token count. + max_tokens_per_request (int): Total maximum token count. + rag_query_n (int): Number of results to fetch from memory. + rag_similarity_threshold (float): Minimum similarity for RAG results. + frame_processor (FrameProcessor): Processes video frames. + output_dir (str): Directory for output files. + response_subject (Subject): Subject that emits agent responses. + process_all_inputs (bool): Whether to process every input emission (True) or + skip emissions when the agent is busy processing a previous input (False). + """ + + logging_file_memory_lock = threading.Lock() + + def __init__( + self, + dev_name: str = "NA", + agent_type: str = "LLM", + agent_memory: AbstractAgentSemanticMemory | None = None, + pool_scheduler: ThreadPoolScheduler | None = None, + process_all_inputs: bool = False, + system_query: str | None = None, + max_output_tokens_per_request: int = 16384, + max_input_tokens_per_request: int = 128000, + input_query_stream: Observable | None = None, # type: ignore[type-arg] + input_data_stream: Observable | None = None, # type: ignore[type-arg] + input_video_stream: Observable | None = None, # type: ignore[type-arg] + ) -> None: + """ + Initializes a new instance of the LLMAgent. + Args: + dev_name (str): The device name of the agent. + agent_type (str): The type of the agent. + agent_memory (AbstractAgentSemanticMemory): The memory system for the agent. + pool_scheduler (ThreadPoolScheduler): The scheduler to use for thread pool operations. + If None, the global scheduler from get_scheduler() will be used. + process_all_inputs (bool): Whether to process every input emission (True) or + skip emissions when the agent is busy processing a previous input (False). + """ + super().__init__(dev_name, agent_type, agent_memory, pool_scheduler) + # These attributes can be configured by a subclass if needed. + self.query: str | None = None + self.prompt_builder: PromptBuilder | None = None + self.system_query: str | None = system_query + self.image_detail: str = "low" + self.max_input_tokens_per_request: int = max_input_tokens_per_request + self.max_output_tokens_per_request: int = max_output_tokens_per_request + self.max_tokens_per_request: int = ( + self.max_input_tokens_per_request + self.max_output_tokens_per_request + ) + self.rag_query_n: int = 4 + self.rag_similarity_threshold: float = 0.45 + self.frame_processor: FrameProcessor | None = None + self.output_dir: str = os.path.join(os.getcwd(), "assets", "agent") + self.process_all_inputs: bool = process_all_inputs + os.makedirs(self.output_dir, exist_ok=True) + + # Subject for emitting responses + self.response_subject = Subject() # type: ignore[var-annotated] + + # Conversation history for maintaining context between calls + self.conversation_history = [] # type: ignore[var-annotated] + + # Initialize input streams + self.input_video_stream = input_video_stream + self.input_query_stream = ( + input_query_stream + if (input_data_stream is None) + else ( + input_query_stream.pipe( # type: ignore[misc, union-attr] + RxOps.with_latest_from(input_data_stream), + RxOps.map( + lambda combined: { + "query": combined[0], # type: ignore[index] + "objects": combined[1] # type: ignore[index] + if len(combined) > 1 # type: ignore[arg-type] + else "No object data available", + } + ), + RxOps.map( + lambda data: f"{data['query']}\n\nCurrent objects detected:\n{data['objects']}" # type: ignore[index] + ), + RxOps.do_action( + lambda x: print(f"\033[34mEnriched query: {x.split(chr(10))[0]}\033[0m") # type: ignore[arg-type] + or [print(f"\033[34m{line}\033[0m") for line in x.split(chr(10))[1:]] # type: ignore[var-annotated] + ), + ) + ) + ) -class OpenAI_Agent(Agent): - memory_file_lock = threading.Lock() + # Setup stream subscriptions based on inputs provided + if (self.input_video_stream is not None) and (self.input_query_stream is not None): + self.merged_stream = create_stream_merger( + data_input_stream=self.input_video_stream, text_query_stream=self.input_query_stream + ) - def __init__(self, dev_name: str, agent_type:str="Vision", query="What do you see?", output_dir='/app/assets/agent'): - """ - Initializes a new OpenAI_Agent instance, an agent specialized in handling vision tasks. + logger.info("Subscribing to merged input stream...") + + # Define a query extractor for the merged stream + def query_extractor(emission): # type: ignore[no-untyped-def] + return (emission[0], emission[1][0]) + + self.disposables.add( + self.subscribe_to_image_processing( + self.merged_stream, query_extractor=query_extractor + ) + ) + else: + # If no merged stream, fall back to individual streams + if self.input_video_stream is not None: + logger.info("Subscribing to input video stream...") + self.disposables.add(self.subscribe_to_image_processing(self.input_video_stream)) + if self.input_query_stream is not None: + logger.info("Subscribing to input query stream...") + self.disposables.add(self.subscribe_to_query_processing(self.input_query_stream)) + + def _update_query(self, incoming_query: str | None) -> None: + """Updates the query if an incoming query is provided. Args: - dev_name (str): The name of the device. - agent_type (str): The type of the agent, defaulting to 'Vision'. + incoming_query (str): The new query text. """ - super().__init__(dev_name, agent_type) - self.client = OpenAI() - self.is_processing = False - self.query = query - self.output_dir = output_dir - os.makedirs(self.output_dir, exist_ok=True) + if incoming_query is not None: + self.query = incoming_query - def encode_image(self, image): + def _get_rag_context(self) -> tuple[str, str]: + """Queries the agent memory to retrieve RAG context. + + Returns: + Tuple[str, str]: A tuple containing the formatted results (for logging) + and condensed results (for use in the prompt). """ - Encodes an image array into a base64 string suitable for transmission. + results = self.agent_memory.query( + query_texts=self.query, + n_results=self.rag_query_n, + similarity_threshold=self.rag_similarity_threshold, + ) + formatted_results = "\n".join( + f"Document ID: {doc.id}\nMetadata: {doc.metadata}\nContent: {doc.page_content}\nScore: {score}\n" + for (doc, score) in results + ) + condensed_results = " | ".join(f"{doc.page_content}" for (doc, _) in results) + logger.info(f"Agent Memory Query Results:\n{formatted_results}") + logger.info("=== Results End ===") + return formatted_results, condensed_results + + def _build_prompt( + self, + base64_image: str | None, + dimensions: tuple[int, int] | None, + override_token_limit: bool, + condensed_results: str, + ) -> list: # type: ignore[type-arg] + """Builds a prompt message using the prompt builder. Args: - image (ndarray): An image array to encode. + base64_image (str): Optional Base64-encoded image. + dimensions (Tuple[int, int]): Optional image dimensions. + override_token_limit (bool): Whether to override token limits. + condensed_results (str): The condensed RAG context. Returns: - str: The base64 encoded string of the image. + list: A list of message dictionaries to be sent to the LLM. """ - _, buffer = cv2.imencode('.jpg', image) - if buffer is None: - raise ValueError("Failed to encode image") - return base64.b64encode(buffer).decode('utf-8') + # Budget for each component of the prompt + budgets = { + "system_prompt": self.max_input_tokens_per_request // _TOKEN_BUDGET_PARTS, + "user_query": self.max_input_tokens_per_request // _TOKEN_BUDGET_PARTS, + "image": self.max_input_tokens_per_request // _TOKEN_BUDGET_PARTS, + "rag": self.max_input_tokens_per_request // _TOKEN_BUDGET_PARTS, + } + + # Define truncation policies for each component + policies = { + "system_prompt": "truncate_end", + "user_query": "truncate_middle", + "image": "do_not_truncate", + "rag": "truncate_end", + } + + return self.prompt_builder.build( # type: ignore[no-any-return, union-attr] + user_query=self.query, + override_token_limit=override_token_limit, + base64_image=base64_image, + image_width=dimensions[0] if dimensions is not None else None, + image_height=dimensions[1] if dimensions is not None else None, + image_detail=self.image_detail, + rag_context=condensed_results, + system_prompt=self.system_query, + budgets=budgets, + policies=policies, + ) - # def encode_image(self, image): - # """ - # Creates an observable that encodes an image array into a base64 string. + def _handle_tooling(self, response_message, messages): # type: ignore[no-untyped-def] + """Handles tooling callbacks in the response message. - # Args: - # image (ndarray): An image array to encode. + If tool calls are present, the corresponding functions are executed and + a follow-up query is sent. - # Returns: - # Observable: An observable that emits the base64 encoded string of the image. - # """ - # def observable_image_encoder(observer, scheduler): - # try: - # _, buffer = cv2.imencode('.jpg', image) - # if buffer is None: - # observer.on_error(ValueError("Failed to encode image")) - # else: - # encoded_string = base64.b64encode(buffer).decode('utf-8') - # observer.on_next(encoded_string) - # observer.on_completed() - # except Exception as e: - # observer.on_error(e) + Args: + response_message: The response message containing tool calls. + messages (list): The original list of messages sent. + + Returns: + The final response message after processing tool calls, if any. + """ - # return rx.create(observable_image_encoder) + # TODO: Make this more generic or move implementation to OpenAIAgent. + # This is presently OpenAI-specific. + def _tooling_callback(message, messages, response_message, skill_library: SkillLibrary): # type: ignore[no-untyped-def] + has_called_tools = False + new_messages = [] + for tool_call in message.tool_calls: + has_called_tools = True + name = tool_call.function.name + args = json.loads(tool_call.function.arguments) + result = skill_library.call(name, **args) + logger.info(f"Function Call Results: {result}") + new_messages.append( + { + "role": "tool", + "tool_call_id": tool_call.id, + "content": str(result), + "name": name, + } + ) + if has_called_tools: + logger.info("Sending Another Query.") + messages.append(response_message) + messages.extend(new_messages) + # Delegate to sending the query again. + return self._send_query(messages) + else: + logger.info("No Need for Another Query.") + return None + + if response_message.tool_calls is not None: + return _tooling_callback( + response_message, + messages, + response_message, + self.skill_library, # type: ignore[attr-defined] + ) + return None + + def _observable_query( # type: ignore[no-untyped-def] + self, + observer: Observer, # type: ignore[type-arg] + base64_image: str | None = None, + dimensions: tuple[int, int] | None = None, + override_token_limit: bool = False, + incoming_query: str | None = None, + ): + """Prepares and sends a query to the LLM, emitting the response to the observer. - def query_openai_with_image(self, base64_image): + Args: + observer (Observer): The observer to emit responses to. + base64_image (str): Optional Base64-encoded image. + dimensions (Tuple[int, int]): Optional image dimensions. + override_token_limit (bool): Whether to override token limits. + incoming_query (str): Optional query to update the agent's query. + + Raises: + Exception: Propagates any exceptions encountered during processing. """ - Sends an encoded image to OpenAI's API for analysis and returns the response. + try: + self._update_query(incoming_query) + _, condensed_results = self._get_rag_context() + messages = self._build_prompt( + base64_image, dimensions, override_token_limit, condensed_results + ) + # logger.debug(f"Sending Query: {messages}") + logger.info("Sending Query.") + response_message = self._send_query(messages) + logger.info(f"Received Response: {response_message}") + if response_message is None: + raise Exception("Response message does not exist.") + + # TODO: Make this more generic. The parsed tag and tooling handling may be OpenAI-specific. + # If no skill library is provided or there are no tool calls, emit the response directly. + if ( + self.skill_library is None # type: ignore[attr-defined] + or self.skill_library.get_tools() in (None, NOT_GIVEN) # type: ignore[attr-defined] + or response_message.tool_calls is None + ): + final_msg = ( + response_message.parsed + if hasattr(response_message, "parsed") and response_message.parsed + else ( + response_message.content + if hasattr(response_message, "content") + else response_message + ) + ) + observer.on_next(final_msg) + self.response_subject.on_next(final_msg) + else: + response_message_2 = self._handle_tooling(response_message, messages) # type: ignore[no-untyped-call] + final_msg = ( + response_message_2 if response_message_2 is not None else response_message + ) + if isinstance(final_msg, BaseModel): # TODO: Test + final_msg = str(final_msg.content) # type: ignore[attr-defined] + observer.on_next(final_msg) + self.response_subject.on_next(final_msg) + observer.on_completed() + except Exception as e: + logger.error(f"Query failed in {self.dev_name}: {e}") + observer.on_error(e) + self.response_subject.on_error(e) + + def _send_query(self, messages: list) -> Any: # type: ignore[type-arg] + """Sends the query to the LLM API. + + This method must be implemented by subclasses with specifics of the LLM API. Args: - base64_image (str): The base64 encoded string of the image. - query (str): The query text to accompany the image. + messages (list): The prompt messages to be sent. Returns: - str: The content of the response from OpenAI. + Any: The response message from the LLM. + + Raises: + NotImplementedError: Always, unless overridden. """ - try: - response = self.client.chat.completions.create( - model="gpt-4o", - messages=[ - {"role": "user", "content": [{"type": "text", "text": self.query}, - {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}", "detail": "high"}}]}, - ], - max_tokens=300, + raise NotImplementedError("Subclasses must implement _send_query method.") + + def _log_response_to_file(self, response, output_dir: str | None = None) -> None: # type: ignore[no-untyped-def] + """Logs the LLM response to a file. + + Args: + response: The response message to log. + output_dir (str): The directory where the log file is stored. + """ + if output_dir is None: + output_dir = self.output_dir + if response is not None: + with self.logging_file_memory_lock: + log_path = os.path.join(output_dir, "memory.txt") + with open(log_path, "a") as file: + file.write(f"{self.dev_name}: {response}\n") + logger.info(f"LLM Response [{self.dev_name}]: {response}") + + def subscribe_to_image_processing( # type: ignore[no-untyped-def] + self, + frame_observable: Observable, # type: ignore[type-arg] + query_extractor=None, + ) -> Disposable: + """Subscribes to a stream of video frames for processing. + + This method sets up a subscription to process incoming video frames. + Each frame is encoded and then sent to the LLM by directly calling the + _observable_query method. The response is then logged to a file. + + Args: + frame_observable (Observable): An observable emitting video frames or + (query, frame) tuples if query_extractor is provided. + query_extractor (callable, optional): Function to extract query and frame from + each emission. If None, assumes emissions are + raw frames and uses self.system_query. + + Returns: + Disposable: A disposable representing the subscription. + """ + # Initialize frame processor if not already set + if self.frame_processor is None: + self.frame_processor = FrameProcessor(delete_on_init=True) + + print_emission_args = {"enabled": True, "dev_name": self.dev_name, "counts": {}} + + def _process_frame(emission) -> Observable: # type: ignore[no-untyped-def, type-arg] + """ + Processes a frame or (query, frame) tuple. + """ + # Extract query and frame + if query_extractor: + query, frame = query_extractor(emission) + else: + query = self.system_query + frame = emission + return just(frame).pipe( # type: ignore[call-overload, no-any-return] + MyOps.print_emission(id="B", **print_emission_args), # type: ignore[arg-type] + RxOps.observe_on(self.pool_scheduler), + MyOps.print_emission(id="C", **print_emission_args), # type: ignore[arg-type] + RxOps.subscribe_on(self.pool_scheduler), + MyOps.print_emission(id="D", **print_emission_args), # type: ignore[arg-type] + MyVidOps.with_jpeg_export( + self.frame_processor, # type: ignore[arg-type] + suffix=f"{self.dev_name}_frame_", + save_limit=_MAX_SAVED_FRAMES, + ), + MyOps.print_emission(id="E", **print_emission_args), # type: ignore[arg-type] + MyVidOps.encode_image(), + MyOps.print_emission(id="F", **print_emission_args), # type: ignore[arg-type] + RxOps.filter( + lambda base64_and_dims: base64_and_dims is not None + and base64_and_dims[0] is not None # type: ignore[index] + and base64_and_dims[1] is not None # type: ignore[index] + ), + MyOps.print_emission(id="G", **print_emission_args), # type: ignore[arg-type] + RxOps.flat_map( + lambda base64_and_dims: create( # type: ignore[arg-type, return-value] + lambda observer, _: self._observable_query( + observer, # type: ignore[arg-type] + base64_image=base64_and_dims[0], + dimensions=base64_and_dims[1], + incoming_query=query, + ) + ) + ), # Use the extracted query + MyOps.print_emission(id="H", **print_emission_args), # type: ignore[arg-type] ) - return response.choices[0].message.content - except Exception as e: - print(f"API request failed: {e}") - return None - - # def query_openai_with_image(self, base64_image, query="What’s in this image?"): - # """ - # Creates an observable that sends an encoded image to OpenAI's API for analysis. - - # Args: - # base64_image (str): The base64 encoded string of the image. - # query (str): The query text to accompany the image. - - # Returns: - # Observable: An observable that emits the response from OpenAI. - # """ - # def observable_openai_query(observer, scheduler): - # try: - # response = self.client.chat.completions.create( - # model="gpt-4o", - # messages=[ - # {"role": "user", "content": [{"type": "text", "text": query}, - # {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}", "detail": "high"}}]}, - # ], - # max_tokens=300, - # ) - # if response: - # observer.on_next(response.choices[0].message.content) - # observer.on_completed() - # else: - # observer.on_error(Exception("Failed to get a valid response from OpenAI")) - # except Exception as e: - # print(f"API request failed: {e}") - # observer.on_error(e) - - # return rx.create(observable_openai_query) - - # def send_query_and_handle_timeout(self, image_base64): - # """ - # Sends an image query to OpenAI and handles response or timeout. - - # Args: - # image_base64 (str): Base64 encoded string of the image to query. - - # Returns: - # Observable: Observable emitting either OpenAI response or timeout signal. - # """ - # # Setting a timeout for the OpenAI request - # timeout_seconds = 10 # Timeout after 10 seconds - # return rx.of(image_base64).pipe( - # ops.map(self.query_openai_with_image), - # ops.timeout(timeout_seconds), - # ops.catch(rx.catch(handler=lambda e: rx.of(f"Timeout or error occurred: {e}"))) - # ) - - # def process_image_stream(self, image_stream): - # """ - # Processes an image stream by encoding images and querying OpenAI. - - # Args: - # image_stream (Observable): An observable stream of image arrays. - - # Returns: - # Observable: An observable stream of OpenAI responses. - # """ - # return image_stream.pipe( - # ops.map(self.encode_image), # Assume this returns a base64 string immediately - # ops.exhaust_map(lambda image_base64: self.send_query_and_handle_timeout(image_base64)) - # ) - - def process_if_idle(self, image): - if not self.is_processing: - self.is_processing = True # Set processing flag - return self.encode_image(image).pipe( - ops.flat_map(self.query_openai_with_image), - ops.do_action(on_next=lambda _: None, on_completed=lambda: self.reset_processing_flag()) + + # Use a mutable flag to ensure only one frame is processed at a time. + is_processing = [False] + + def process_if_free(emission): # type: ignore[no-untyped-def] + if not self.process_all_inputs and is_processing[0]: + # Drop frame if a request is in progress and process_all_inputs is False + return empty() + else: + is_processing[0] = True + return _process_frame(emission).pipe( + MyOps.print_emission(id="I", **print_emission_args), # type: ignore[arg-type] + RxOps.observe_on(self.pool_scheduler), + MyOps.print_emission(id="J", **print_emission_args), # type: ignore[arg-type] + RxOps.subscribe_on(self.pool_scheduler), + MyOps.print_emission(id="K", **print_emission_args), # type: ignore[arg-type] + RxOps.do_action( + on_completed=lambda: is_processing.__setitem__(0, False), + on_error=lambda e: is_processing.__setitem__(0, False), + ), + MyOps.print_emission(id="L", **print_emission_args), # type: ignore[arg-type] + ) + + observable = frame_observable.pipe( + MyOps.print_emission(id="A", **print_emission_args), # type: ignore[arg-type] + RxOps.flat_map(process_if_free), + MyOps.print_emission(id="M", **print_emission_args), # type: ignore[arg-type] + ) + + disposable = observable.subscribe( + on_next=lambda response: self._log_response_to_file(response, self.output_dir), + on_error=lambda e: logger.error(f"Error encountered: {e}"), + on_completed=lambda: logger.info(f"Stream processing completed for {self.dev_name}"), + ) + self.disposables.add(disposable) + return disposable # type: ignore[no-any-return] + + def subscribe_to_query_processing(self, query_observable: Observable) -> Disposable: # type: ignore[type-arg] + """Subscribes to a stream of queries for processing. + + This method sets up a subscription to process incoming queries by directly + calling the _observable_query method. The responses are logged to a file. + + Args: + query_observable (Observable): An observable emitting queries. + + Returns: + Disposable: A disposable representing the subscription. + """ + print_emission_args = {"enabled": False, "dev_name": self.dev_name, "counts": {}} + + def _process_query(query) -> Observable: # type: ignore[no-untyped-def, type-arg] + """ + Processes a single query by logging it and passing it to _observable_query. + Returns an observable that emits the LLM response. + """ + return just(query).pipe( + MyOps.print_emission(id="Pr A", **print_emission_args), # type: ignore[arg-type] + RxOps.flat_map( + lambda query: create( # type: ignore[arg-type, return-value] + lambda observer, _: self._observable_query(observer, incoming_query=query) # type: ignore[arg-type] + ) + ), + MyOps.print_emission(id="Pr B", **print_emission_args), # type: ignore[arg-type] ) - else: - return rx.empty() # Ignore the emission if already processing - def reset_processing_flag(self): - self.is_processing = False + # A mutable flag indicating whether a query is currently being processed. + is_processing = [False] + + def process_if_free(query): # type: ignore[no-untyped-def] + logger.info(f"Processing Query: {query}") + if not self.process_all_inputs and is_processing[0]: + # Drop query if a request is already in progress and process_all_inputs is False + return empty() + else: + is_processing[0] = True + logger.info("Processing Query.") + return _process_query(query).pipe( + MyOps.print_emission(id="B", **print_emission_args), # type: ignore[arg-type] + RxOps.observe_on(self.pool_scheduler), + MyOps.print_emission(id="C", **print_emission_args), # type: ignore[arg-type] + RxOps.subscribe_on(self.pool_scheduler), + MyOps.print_emission(id="D", **print_emission_args), # type: ignore[arg-type] + RxOps.do_action( + on_completed=lambda: is_processing.__setitem__(0, False), + on_error=lambda e: is_processing.__setitem__(0, False), + ), + MyOps.print_emission(id="E", **print_emission_args), # type: ignore[arg-type] + ) + + observable = query_observable.pipe( + MyOps.print_emission(id="A", **print_emission_args), # type: ignore[arg-type] + RxOps.flat_map(lambda query: process_if_free(query)), # type: ignore[no-untyped-call] + MyOps.print_emission(id="F", **print_emission_args), # type: ignore[arg-type] + ) + + disposable = observable.subscribe( + on_next=lambda response: self._log_response_to_file(response, self.output_dir), + on_error=lambda e: logger.error(f"Error processing query for {self.dev_name}: {e}"), + on_completed=lambda: logger.info(f"Stream processing completed for {self.dev_name}"), + ) + self.disposables.add(disposable) + return disposable # type: ignore[no-any-return] + + def get_response_observable(self) -> Observable: # type: ignore[type-arg] + """Gets an observable that emits responses from this agent. - def process_image_stream(self, image_stream): + Returns: + Observable: An observable that emits string responses from the agent. """ - Processes an image stream by encoding images and querying OpenAI. + return self.response_subject.pipe( + RxOps.observe_on(self.pool_scheduler), + RxOps.subscribe_on(self.pool_scheduler), + RxOps.share(), + ) + + def run_observable_query(self, query_text: str, **kwargs) -> Observable: # type: ignore[no-untyped-def, type-arg] + """Creates an observable that processes a one-off text query to Agent and emits the response. + + This method provides a simple way to send a text query and get an observable + stream of the response. It's designed for one-off queries rather than + continuous processing of input streams. Useful for testing and development. Args: - image_stream (Observable): An observable stream of image arrays. + query_text (str): The query text to process. + **kwargs: Additional arguments to pass to _observable_query. Supported args vary by agent type. + For example, ClaudeAgent supports: base64_image, dimensions, override_token_limit, + reset_conversation, thinking_budget_tokens Returns: - Observable: An observable stream of OpenAI responses. + Observable: An observable that emits the response as a string. """ - # Process each and every entry, one after another - return image_stream.pipe( - ops.map(self.encode_image), - ops.map(self.query_openai_with_image), + return create( + lambda observer, _: self._observable_query( + observer, # type: ignore[arg-type] + incoming_query=query_text, + **kwargs, + ) ) - - # Process image, ignoring new images while processing - # return image_stream.pipe( - # ops.flat_map(self.process_if_idle), - # ops.filter(lambda x: x is not None) # Filter out ignored (None) emissions - # ) - - def subscribe_to_image_processing(self, frame_observable): + + def dispose_all(self) -> None: + """Disposes of all active subscriptions managed by this agent.""" + super().dispose_all() + self.response_subject.on_completed() + + +# endregion LLMAgent Base Class (Generic LLM Agent) + + +# ----------------------------------------------------------------------------- +# region OpenAIAgent Subclass (OpenAI-Specific Implementation) +# ----------------------------------------------------------------------------- +class OpenAIAgent(LLMAgent): + """OpenAI agent implementation that uses OpenAI's API for processing. + + This class implements the _send_query method to interact with OpenAI's API. + It also sets up OpenAI-specific parameters, such as the client, model name, + tokenizer, and response model. + """ + + def __init__( + self, + dev_name: str, + agent_type: str = "Vision", + query: str = "What do you see?", + input_query_stream: Observable | None = None, # type: ignore[type-arg] + input_data_stream: Observable | None = None, # type: ignore[type-arg] + input_video_stream: Observable | None = None, # type: ignore[type-arg] + output_dir: str = os.path.join(os.getcwd(), "assets", "agent"), + agent_memory: AbstractAgentSemanticMemory | None = None, + system_query: str | None = None, + max_input_tokens_per_request: int = 128000, + max_output_tokens_per_request: int = 16384, + model_name: str = "gpt-4o", + prompt_builder: PromptBuilder | None = None, + tokenizer: AbstractTokenizer | None = None, + rag_query_n: int = 4, + rag_similarity_threshold: float = 0.45, + skills: AbstractSkill | list[AbstractSkill] | SkillLibrary | None = None, + response_model: BaseModel | None = None, + frame_processor: FrameProcessor | None = None, + image_detail: str = "low", + pool_scheduler: ThreadPoolScheduler | None = None, + process_all_inputs: bool | None = None, + openai_client: OpenAI | None = None, + ) -> None: """ - Subscribes to an observable of frames, processes them, and handles the responses. + Initializes a new instance of the OpenAIAgent. Args: - frame_observable (Observable): An observable stream of image frames. + dev_name (str): The device name of the agent. + agent_type (str): The type of the agent. + query (str): The default query text. + input_query_stream (Observable): An observable for query input. + input_data_stream (Observable): An observable for data input. + input_video_stream (Observable): An observable for video frames. + output_dir (str): Directory for output files. + agent_memory (AbstractAgentSemanticMemory): The memory system. + system_query (str): The system prompt to use with RAG context. + max_input_tokens_per_request (int): Maximum tokens for input. + max_output_tokens_per_request (int): Maximum tokens for output. + model_name (str): The OpenAI model name to use. + prompt_builder (PromptBuilder): Custom prompt builder. + tokenizer (AbstractTokenizer): Custom tokenizer for token counting. + rag_query_n (int): Number of results to fetch in RAG queries. + rag_similarity_threshold (float): Minimum similarity for RAG results. + skills (Union[AbstractSkill, List[AbstractSkill], SkillLibrary]): Skills available to the agent. + response_model (BaseModel): Optional Pydantic model for responses. + frame_processor (FrameProcessor): Custom frame processor. + image_detail (str): Detail level for images ("low", "high", "auto"). + pool_scheduler (ThreadPoolScheduler): The scheduler to use for thread pool operations. + If None, the global scheduler from get_scheduler() will be used. + process_all_inputs (bool): Whether to process all inputs or skip when busy. + If None, defaults to True for text queries and merged streams, False for video streams. + openai_client (OpenAI): The OpenAI client to use. This can be used to specify + a custom OpenAI client if targetting another provider. """ - disposable = self.process_image_stream(frame_observable).subscribe( - on_next=self.log_response_to_file, # lambda response: print(f"OpenAI Response [{self.dev_name}]:", response), - on_error=lambda e: print("Error:", e), - on_completed=lambda: print("Stream processing completed.") + # Determine appropriate default for process_all_inputs if not provided + if process_all_inputs is None: + if input_query_stream is not None: + process_all_inputs = True + else: + process_all_inputs = False + + super().__init__( + dev_name=dev_name, + agent_type=agent_type, + agent_memory=agent_memory, + pool_scheduler=pool_scheduler, + process_all_inputs=process_all_inputs, + system_query=system_query, + input_query_stream=input_query_stream, + input_data_stream=input_data_stream, + input_video_stream=input_video_stream, ) - self.disposables.add(disposable) - - def log_response_to_file(self, response): + self.client = openai_client or OpenAI() + self.query = query + self.output_dir = output_dir + os.makedirs(self.output_dir, exist_ok=True) + + # Configure skill library. + self.skills = skills + self.skill_library = None + if isinstance(self.skills, SkillLibrary): + self.skill_library = self.skills + elif isinstance(self.skills, list): + self.skill_library = SkillLibrary() + for skill in self.skills: + self.skill_library.add(skill) + elif isinstance(self.skills, AbstractSkill): + self.skill_library = SkillLibrary() + self.skill_library.add(self.skills) + + self.response_model = response_model if response_model is not None else NOT_GIVEN + self.model_name = model_name + self.tokenizer = tokenizer or OpenAITokenizer(model_name=self.model_name) + self.prompt_builder = prompt_builder or PromptBuilder( + self.model_name, tokenizer=self.tokenizer + ) + self.rag_query_n = rag_query_n + self.rag_similarity_threshold = rag_similarity_threshold + self.image_detail = image_detail + self.max_output_tokens_per_request = max_output_tokens_per_request + self.max_input_tokens_per_request = max_input_tokens_per_request + self.max_tokens_per_request = max_input_tokens_per_request + max_output_tokens_per_request + + # Add static context to memory. + self._add_context_to_memory() + + self.frame_processor = frame_processor or FrameProcessor(delete_on_init=True) + + logger.info("OpenAI Agent Initialized.") + + def _add_context_to_memory(self) -> None: + """Adds initial context to the agent's memory.""" + context_data = [ + ( + "id0", + "Optical Flow is a technique used to track the movement of objects in a video sequence.", + ), + ( + "id1", + "Edge Detection is a technique used to identify the boundaries of objects in an image.", + ), + ("id2", "Video is a sequence of frames captured at regular intervals."), + ( + "id3", + "Colors in Optical Flow are determined by the movement of light, and can be used to track the movement of objects.", + ), + ( + "id4", + "Json is a data interchange format that is easy for humans to read and write, and easy for machines to parse and generate.", + ), + ] + for doc_id, text in context_data: + self.agent_memory.add_vector(doc_id, text) # type: ignore[no-untyped-call] + + def _send_query(self, messages: list) -> Any: # type: ignore[type-arg] + """Sends the query to OpenAI's API. + + Depending on whether a response model is provided, the appropriate API + call is made. + + Args: + messages (list): The prompt messages to send. + + Returns: + The response message from OpenAI. + + Raises: + Exception: If no response message is returned. + ConnectionError: If there's an issue connecting to the API. + ValueError: If the messages or other parameters are invalid. """ - Logs the response to a shared 'memory.txt' file with the device name prefixed, - using a lock to ensure thread safety. + try: + if self.response_model is not NOT_GIVEN: + response = self.client.beta.chat.completions.parse( + model=self.model_name, + messages=messages, + response_format=self.response_model, # type: ignore[arg-type] + tools=( + self.skill_library.get_tools() # type: ignore[arg-type] + if self.skill_library is not None + else NOT_GIVEN + ), + max_tokens=self.max_output_tokens_per_request, + ) + else: + response = self.client.chat.completions.create( # type: ignore[assignment] + model=self.model_name, + messages=messages, + max_tokens=self.max_output_tokens_per_request, + tools=( + self.skill_library.get_tools() # type: ignore[arg-type] + if self.skill_library is not None + else NOT_GIVEN + ), + ) + response_message = response.choices[0].message + if response_message is None: + logger.error("Response message does not exist.") + raise Exception("Response message does not exist.") + return response_message + except ConnectionError as ce: + logger.error(f"Connection error with API: {ce}") + raise + except ValueError as ve: + logger.error(f"Invalid parameters: {ve}") + raise + except Exception as e: + logger.error(f"Unexpected error in API call: {e}") + raise + + def stream_query(self, query_text: str) -> Observable: # type: ignore[type-arg] + """Creates an observable that processes a text query and emits the response. + + This method provides a simple way to send a text query and get an observable + stream of the response. It's designed for one-off queries rather than + continuous processing of input streams. Args: - response (str): The response to log. + query_text (str): The query text to process. + + Returns: + Observable: An observable that emits the response as a string. """ - with open('/app/assets/agent/memory.txt', 'a') as file: - file.write(f"{self.dev_name}: {response}\n") - print(f"OpenAI Response [{self.dev_name}]:", response) \ No newline at end of file + return create( + lambda observer, _: self._observable_query(observer, incoming_query=query_text) # type: ignore[arg-type] + ) + + +# endregion OpenAIAgent Subclass (OpenAI-Specific Implementation) diff --git a/dimos/agents/agent_config.py b/dimos/agents/agent_config.py new file mode 100644 index 0000000000..0831d4afe5 --- /dev/null +++ b/dimos/agents/agent_config.py @@ -0,0 +1,55 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dimos.agents.agent import Agent + + +class AgentConfig: + def __init__(self, agents: list[Agent] | None = None) -> None: + """ + Initialize an AgentConfig with a list of agents. + + Args: + agents (List[Agent], optional): List of Agent instances. Defaults to empty list. + """ + self.agents = agents if agents is not None else [] + + def add_agent(self, agent: Agent) -> None: + """ + Add an agent to the configuration. + + Args: + agent (Agent): Agent instance to add + """ + self.agents.append(agent) + + def remove_agent(self, agent: Agent) -> None: + """ + Remove an agent from the configuration. + + Args: + agent (Agent): Agent instance to remove + """ + if agent in self.agents: + self.agents.remove(agent) + + def get_agents(self) -> list[Agent]: + """ + Get the list of configured agents. + + Returns: + List[Agent]: List of configured agents + """ + return self.agents diff --git a/dimos/agents/agent_message.py b/dimos/agents/agent_message.py new file mode 100644 index 0000000000..a095e3fb00 --- /dev/null +++ b/dimos/agents/agent_message.py @@ -0,0 +1,100 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""AgentMessage type for multimodal agent communication.""" + +from dataclasses import dataclass, field +import time + +from dimos.agents.agent_types import AgentImage +from dimos.msgs.sensor_msgs.Image import Image + + +@dataclass +class AgentMessage: + """Message type for agent communication with text and images. + + This type supports multimodal messages containing both text strings + and AgentImage objects (base64 encoded) for vision-enabled agents. + + The messages field contains multiple text strings that will be combined + into a single message when sent to the LLM. + """ + + messages: list[str] = field(default_factory=list) + images: list[AgentImage] = field(default_factory=list) + sender_id: str | None = None + timestamp: float = field(default_factory=time.time) + + def add_text(self, text: str) -> None: + """Add a text message.""" + if text: # Only add non-empty text + self.messages.append(text) + + def add_image(self, image: Image | AgentImage) -> None: + """Add an image. Converts Image to AgentImage if needed.""" + if isinstance(image, Image): + # Convert to AgentImage + agent_image = AgentImage( + base64_jpeg=image.agent_encode(), # type: ignore[arg-type] + width=image.width, + height=image.height, + metadata={"format": image.format.value, "frame_id": image.frame_id}, + ) + self.images.append(agent_image) + elif isinstance(image, AgentImage): + self.images.append(image) + else: + raise TypeError(f"Expected Image or AgentImage, got {type(image)}") + + def has_text(self) -> bool: + """Check if message contains text.""" + # Check if we have any non-empty messages + return any(msg for msg in self.messages if msg) + + def has_images(self) -> bool: + """Check if message contains images.""" + return len(self.images) > 0 + + def is_multimodal(self) -> bool: + """Check if message contains both text and images.""" + return self.has_text() and self.has_images() + + def get_primary_text(self) -> str | None: + """Get the first text message, if any.""" + return self.messages[0] if self.messages else None + + def get_primary_image(self) -> AgentImage | None: + """Get the first image, if any.""" + return self.images[0] if self.images else None + + def get_combined_text(self) -> str: + """Get all text messages combined into a single string.""" + # Filter out any empty strings and join + return " ".join(msg for msg in self.messages if msg) + + def clear(self) -> None: + """Clear all content.""" + self.messages.clear() + self.images.clear() + + def __repr__(self) -> str: + """String representation.""" + return ( + f"AgentMessage(" + f"texts={len(self.messages)}, " + f"images={len(self.images)}, " + f"sender='{self.sender_id}', " + f"timestamp={self.timestamp})" + ) diff --git a/dimos/agents/agent_types.py b/dimos/agents/agent_types.py new file mode 100644 index 0000000000..f52bafdac6 --- /dev/null +++ b/dimos/agents/agent_types.py @@ -0,0 +1,255 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Agent-specific types for message passing.""" + +from dataclasses import dataclass, field +import json +import threading +import time +from typing import Any + + +@dataclass +class AgentImage: + """Image data encoded for agent consumption. + + Images are stored as base64-encoded JPEG strings ready for + direct use by LLM/vision models. + """ + + base64_jpeg: str + width: int | None = None + height: int | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + def __repr__(self) -> str: + return f"AgentImage(size={self.width}x{self.height}, metadata={list(self.metadata.keys())})" + + +@dataclass +class ToolCall: + """Represents a tool/function call request from the LLM.""" + + id: str + name: str + arguments: dict[str, Any] + status: str = "pending" # pending, executing, completed, failed + + def __repr__(self) -> str: + return f"ToolCall(id='{self.id}', name='{self.name}', status='{self.status}')" + + +@dataclass +class AgentResponse: + """Enhanced response from an agent query with tool support. + + Based on common LLM response patterns, includes content and metadata. + """ + + content: str + role: str = "assistant" + tool_calls: list[ToolCall] | None = None + requires_follow_up: bool = False # Indicates if tool execution is needed + metadata: dict[str, Any] = field(default_factory=dict) + timestamp: float = field(default_factory=time.time) + + def __repr__(self) -> str: + content_preview = self.content[:50] + "..." if len(self.content) > 50 else self.content + tool_info = f", tools={len(self.tool_calls)}" if self.tool_calls else "" + return f"AgentResponse(role='{self.role}', content='{content_preview}'{tool_info})" + + +@dataclass +class ConversationMessage: + """Single message in conversation history. + + Represents a message in the conversation that can be converted to + different formats (OpenAI, TensorZero, etc). + """ + + role: str # "system", "user", "assistant", "tool" + content: str | list[dict[str, Any]] # Text or content blocks + tool_calls: list[ToolCall] | None = None + tool_call_id: str | None = None # For tool responses + name: str | None = None # For tool messages (function name) + timestamp: float = field(default_factory=time.time) + + def to_openai_format(self) -> dict[str, Any]: + """Convert to OpenAI API format.""" + msg = {"role": self.role} + + # Handle content + if isinstance(self.content, str): + msg["content"] = self.content + else: + # Content is already a list of content blocks + msg["content"] = self.content # type: ignore[assignment] + + # Add tool calls if present + if self.tool_calls: + # Handle both ToolCall objects and dicts + if isinstance(self.tool_calls[0], dict): + msg["tool_calls"] = self.tool_calls # type: ignore[assignment] + else: + msg["tool_calls"] = [ # type: ignore[assignment] + { + "id": tc.id, + "type": "function", + "function": {"name": tc.name, "arguments": json.dumps(tc.arguments)}, + } + for tc in self.tool_calls + ] + + # Add tool_call_id for tool responses + if self.tool_call_id: + msg["tool_call_id"] = self.tool_call_id + + # Add name field if present (for tool messages) + if self.name: + msg["name"] = self.name + + return msg + + def __repr__(self) -> str: + content_preview = ( + str(self.content)[:50] + "..." if len(str(self.content)) > 50 else str(self.content) + ) + return f"ConversationMessage(role='{self.role}', content='{content_preview}')" + + +class ConversationHistory: + """Thread-safe conversation history manager. + + Manages conversation history with proper formatting for different + LLM providers and automatic trimming. + """ + + def __init__(self, max_size: int = 20) -> None: + """Initialize conversation history. + + Args: + max_size: Maximum number of messages to keep + """ + self._messages: list[ConversationMessage] = [] + self._lock = threading.Lock() + self.max_size = max_size + + def add_user_message(self, content: str | list[dict[str, Any]]) -> None: + """Add user message to history. + + Args: + content: Text string or list of content blocks (for multimodal) + """ + with self._lock: + self._messages.append(ConversationMessage(role="user", content=content)) + self._trim() + + def add_assistant_message(self, content: str, tool_calls: list[ToolCall] | None = None) -> None: + """Add assistant response to history. + + Args: + content: Response text + tool_calls: Optional list of tool calls made + """ + with self._lock: + self._messages.append( + ConversationMessage(role="assistant", content=content, tool_calls=tool_calls) + ) + self._trim() + + def add_tool_result(self, tool_call_id: str, content: str, name: str | None = None) -> None: + """Add tool execution result to history. + + Args: + tool_call_id: ID of the tool call this is responding to + content: Result of the tool execution + name: Optional name of the tool/function + """ + with self._lock: + self._messages.append( + ConversationMessage( + role="tool", content=content, tool_call_id=tool_call_id, name=name + ) + ) + self._trim() + + def add_raw_message(self, message: dict[str, Any]) -> None: + """Add a raw message dict to history. + + Args: + message: Message dict with role and content + """ + with self._lock: + # Extract fields from raw message + role = message.get("role", "user") + content = message.get("content", "") + + # Handle tool calls if present + tool_calls = None + if "tool_calls" in message: + tool_calls = [ + ToolCall( + id=tc["id"], + name=tc["function"]["name"], + arguments=json.loads(tc["function"]["arguments"]) + if isinstance(tc["function"]["arguments"], str) + else tc["function"]["arguments"], + status="completed", + ) + for tc in message["tool_calls"] + ] + + # Handle tool_call_id for tool responses + tool_call_id = message.get("tool_call_id") + + self._messages.append( + ConversationMessage( + role=role, content=content, tool_calls=tool_calls, tool_call_id=tool_call_id + ) + ) + self._trim() + + def to_openai_format(self) -> list[dict[str, Any]]: + """Export history in OpenAI format. + + Returns: + List of message dicts in OpenAI format + """ + with self._lock: + return [msg.to_openai_format() for msg in self._messages] + + def clear(self) -> None: + """Clear all conversation history.""" + with self._lock: + self._messages.clear() + + def size(self) -> int: + """Get number of messages in history. + + Returns: + Number of messages + """ + with self._lock: + return len(self._messages) + + def _trim(self) -> None: + """Trim history to max_size (must be called within lock).""" + if len(self._messages) > self.max_size: + # Keep the most recent messages + self._messages = self._messages[-self.max_size :] + + def __repr__(self) -> str: + with self._lock: + return f"ConversationHistory(messages={len(self._messages)}, max_size={self.max_size})" diff --git a/dimos/agents/claude_agent.py b/dimos/agents/claude_agent.py new file mode 100644 index 0000000000..12811becb7 --- /dev/null +++ b/dimos/agents/claude_agent.py @@ -0,0 +1,738 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Claude agent implementation for the DIMOS agent framework. + +This module provides a ClaudeAgent class that implements the LLMAgent interface +for Anthropic's Claude models. It handles conversion between the DIMOS skill format +and Claude's tools format. +""" + +from __future__ import annotations + +import json +import os +from typing import TYPE_CHECKING, Any + +import anthropic +from dotenv import load_dotenv + +# Local imports +from dimos.agents.agent import LLMAgent +from dimos.skills.skills import AbstractSkill, SkillLibrary +from dimos.stream.frame_processor import FrameProcessor +from dimos.utils.logging_config import setup_logger + +if TYPE_CHECKING: + from pydantic import BaseModel + from reactivex import Observable + from reactivex.scheduler import ThreadPoolScheduler + + from dimos.agents.memory.base import AbstractAgentSemanticMemory + from dimos.agents.prompt_builder.impl import PromptBuilder + +# Initialize environment variables +load_dotenv() + +# Initialize logger for the Claude agent +logger = setup_logger() + + +# Response object compatible with LLMAgent +class ResponseMessage: + def __init__(self, content: str = "", tool_calls=None, thinking_blocks=None) -> None: # type: ignore[no-untyped-def] + self.content = content + self.tool_calls = tool_calls or [] + self.thinking_blocks = thinking_blocks or [] + self.parsed = None + + def __str__(self) -> str: + # Return a string representation for logging + parts = [] + + # Include content if available + if self.content: + parts.append(self.content) + + # Include tool calls if available + if self.tool_calls: + tool_names = [tc.function.name for tc in self.tool_calls] + parts.append(f"[Tools called: {', '.join(tool_names)}]") + + return "\n".join(parts) if parts else "[No content]" + + +class ClaudeAgent(LLMAgent): + """Claude agent implementation that uses Anthropic's API for processing. + + This class implements the _send_query method to interact with Anthropic's API + and overrides _build_prompt to create Claude-formatted messages directly. + """ + + def __init__( + self, + dev_name: str, + agent_type: str = "Vision", + query: str = "What do you see?", + input_query_stream: Observable | None = None, # type: ignore[type-arg] + input_video_stream: Observable | None = None, # type: ignore[type-arg] + input_data_stream: Observable | None = None, # type: ignore[type-arg] + output_dir: str = os.path.join(os.getcwd(), "assets", "agent"), + agent_memory: AbstractAgentSemanticMemory | None = None, + system_query: str | None = None, + max_input_tokens_per_request: int = 128000, + max_output_tokens_per_request: int = 16384, + model_name: str = "claude-3-7-sonnet-20250219", + prompt_builder: PromptBuilder | None = None, + rag_query_n: int = 4, + rag_similarity_threshold: float = 0.45, + skills: AbstractSkill | None = None, + response_model: BaseModel | None = None, + frame_processor: FrameProcessor | None = None, + image_detail: str = "low", + pool_scheduler: ThreadPoolScheduler | None = None, + process_all_inputs: bool | None = None, + thinking_budget_tokens: int | None = 2000, + ) -> None: + """ + Initializes a new instance of the ClaudeAgent. + + Args: + dev_name (str): The device name of the agent. + agent_type (str): The type of the agent. + query (str): The default query text. + input_query_stream (Observable): An observable for query input. + input_video_stream (Observable): An observable for video frames. + output_dir (str): Directory for output files. + agent_memory (AbstractAgentSemanticMemory): The memory system. + system_query (str): The system prompt to use with RAG context. + max_input_tokens_per_request (int): Maximum tokens for input. + max_output_tokens_per_request (int): Maximum tokens for output. + model_name (str): The Claude model name to use. + prompt_builder (PromptBuilder): Custom prompt builder (not used in Claude implementation). + rag_query_n (int): Number of results to fetch in RAG queries. + rag_similarity_threshold (float): Minimum similarity for RAG results. + skills (AbstractSkill): Skills available to the agent. + response_model (BaseModel): Optional Pydantic model for responses. + frame_processor (FrameProcessor): Custom frame processor. + image_detail (str): Detail level for images ("low", "high", "auto"). + pool_scheduler (ThreadPoolScheduler): The scheduler to use for thread pool operations. + process_all_inputs (bool): Whether to process all inputs or skip when busy. + thinking_budget_tokens (int): Number of tokens to allocate for Claude's thinking. 0 disables thinking. + """ + # Determine appropriate default for process_all_inputs if not provided + if process_all_inputs is None: + # Default to True for text queries, False for video streams + if input_query_stream is not None and input_video_stream is None: + process_all_inputs = True + else: + process_all_inputs = False + + super().__init__( + dev_name=dev_name, + agent_type=agent_type, + agent_memory=agent_memory, + pool_scheduler=pool_scheduler, + process_all_inputs=process_all_inputs, + system_query=system_query, + input_query_stream=input_query_stream, + input_video_stream=input_video_stream, + input_data_stream=input_data_stream, + ) + + self.client = anthropic.Anthropic() + self.query = query + self.output_dir = output_dir + os.makedirs(self.output_dir, exist_ok=True) + + # Claude-specific parameters + self.thinking_budget_tokens = thinking_budget_tokens + self.claude_api_params = {} # type: ignore[var-annotated] # Will store params for Claude API calls + + # Configure skills + self.skills = skills + self.skill_library = None # Required for error 'ClaudeAgent' object has no attribute 'skill_library' due to skills refactor + if isinstance(self.skills, SkillLibrary): + self.skill_library = self.skills + elif isinstance(self.skills, list): + self.skill_library = SkillLibrary() + for skill in self.skills: + self.skill_library.add(skill) + elif isinstance(self.skills, AbstractSkill): + self.skill_library = SkillLibrary() + self.skill_library.add(self.skills) + + self.response_model = response_model + self.model_name = model_name + self.rag_query_n = rag_query_n + self.rag_similarity_threshold = rag_similarity_threshold + self.image_detail = image_detail + self.max_output_tokens_per_request = max_output_tokens_per_request + self.max_input_tokens_per_request = max_input_tokens_per_request + self.max_tokens_per_request = max_input_tokens_per_request + max_output_tokens_per_request + + # Add static context to memory. + self._add_context_to_memory() + + self.frame_processor = frame_processor or FrameProcessor(delete_on_init=True) + + # Ensure only one input stream is provided. + if self.input_video_stream is not None and self.input_query_stream is not None: + raise ValueError( + "More than one input stream provided. Please provide only one input stream." + ) + + logger.info("Claude Agent Initialized.") + + def _add_context_to_memory(self) -> None: + """Adds initial context to the agent's memory.""" + context_data = [ + ( + "id0", + "Optical Flow is a technique used to track the movement of objects in a video sequence.", + ), + ( + "id1", + "Edge Detection is a technique used to identify the boundaries of objects in an image.", + ), + ("id2", "Video is a sequence of frames captured at regular intervals."), + ( + "id3", + "Colors in Optical Flow are determined by the movement of light, and can be used to track the movement of objects.", + ), + ( + "id4", + "Json is a data interchange format that is easy for humans to read and write, and easy for machines to parse and generate.", + ), + ] + for doc_id, text in context_data: + self.agent_memory.add_vector(doc_id, text) # type: ignore[no-untyped-call] + + def _convert_tools_to_claude_format(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]: + """ + Converts DIMOS tools to Claude format. + + Args: + tools: List of tools in DIMOS format. + + Returns: + List of tools in Claude format. + """ + if not tools: + return [] + + claude_tools = [] + + for tool in tools: + # Skip if not a function + if tool.get("type") != "function": + continue + + function = tool.get("function", {}) + name = function.get("name") + description = function.get("description", "") + parameters = function.get("parameters", {}) + + claude_tool = { + "name": name, + "description": description, + "input_schema": { + "type": "object", + "properties": parameters.get("properties", {}), + "required": parameters.get("required", []), + }, + } + + claude_tools.append(claude_tool) + + return claude_tools + + def _build_prompt( # type: ignore[override] + self, + messages: list, # type: ignore[type-arg] + base64_image: str | list[str] | None = None, + dimensions: tuple[int, int] | None = None, + override_token_limit: bool = False, + rag_results: str = "", + thinking_budget_tokens: int | None = None, + ) -> list: # type: ignore[type-arg] + """Builds a prompt message specifically for Claude API, using local messages copy.""" + """Builds a prompt message specifically for Claude API. + + This method creates messages in Claude's format directly, without using + any OpenAI-specific formatting or token counting. + + Args: + base64_image (Union[str, List[str]]): Optional Base64-encoded image(s). + dimensions (Tuple[int, int]): Optional image dimensions. + override_token_limit (bool): Whether to override token limits. + rag_results (str): The condensed RAG context. + thinking_budget_tokens (int): Number of tokens to allocate for Claude's thinking. + + Returns: + dict: A dict containing Claude API parameters. + """ + + # Append user query to conversation history while handling RAG + if rag_results: + messages.append({"role": "user", "content": f"{rag_results}\n\n{self.query}"}) + logger.info( + f"Added new user message to conversation history with RAG context (now has {len(messages)} messages)" + ) + else: + messages.append({"role": "user", "content": self.query}) + logger.info( + f"Added new user message to conversation history (now has {len(messages)} messages)" + ) + + if base64_image is not None: + # Handle both single image (str) and multiple images (List[str]) + images = [base64_image] if isinstance(base64_image, str) else base64_image + + # Add each image as a separate entry in conversation history + for img in images: + img_content = [ + { + "type": "image", + "source": {"type": "base64", "media_type": "image/jpeg", "data": img}, + } + ] + messages.append({"role": "user", "content": img_content}) + + if images: + logger.info( + f"Added {len(images)} image(s) as separate entries to conversation history" + ) + + # Create Claude parameters with basic settings + claude_params = { + "model": self.model_name, + "max_tokens": self.max_output_tokens_per_request, + "temperature": 0, # Add temperature to make responses more deterministic + "messages": messages, + } + + # Add system prompt as a top-level parameter (not as a message) + if self.system_query: + claude_params["system"] = self.system_query + + # Store the parameters for use in _send_query + self.claude_api_params = claude_params.copy() + + # Add tools if skills are available + if self.skills and self.skills.get_tools(): + tools = self._convert_tools_to_claude_format(self.skills.get_tools()) + if tools: # Only add if we have valid tools + claude_params["tools"] = tools + # Enable tool calling with proper format + claude_params["tool_choice"] = {"type": "auto"} + + # Add thinking if enabled and hard code required temperature = 1 + if thinking_budget_tokens is not None and thinking_budget_tokens != 0: + claude_params["thinking"] = {"type": "enabled", "budget_tokens": thinking_budget_tokens} + claude_params["temperature"] = ( + 1 # Required to be 1 when thinking is enabled # Default to 0 for deterministic responses + ) + + # Store the parameters for use in _send_query and return them + self.claude_api_params = claude_params.copy() + return messages, claude_params # type: ignore[return-value] + + def _send_query(self, messages: list, claude_params: dict) -> Any: # type: ignore[override, type-arg] + """Sends the query to Anthropic's API using streaming for better thinking visualization. + + Args: + messages: Dict with 'claude_prompt' key containing Claude API parameters. + + Returns: + The response message in a format compatible with LLMAgent's expectations. + """ + try: + # Get Claude parameters + claude_params = claude_params.get("claude_prompt", None) or self.claude_api_params + + # Log request parameters with truncated base64 data + logger.debug(self._debug_api_call(claude_params)) + + # Initialize response containers + text_content = "" + tool_calls = [] + thinking_blocks = [] + + # Log the start of streaming and the query + logger.info("Sending streaming request to Claude API") + + # Log the query to memory.txt + with open(os.path.join(self.output_dir, "memory.txt"), "a") as f: + f.write(f"\n\nQUERY: {self.query}\n\n") + f.flush() + + # Stream the response + with self.client.messages.stream(**claude_params) as stream: + print("\n==== CLAUDE API RESPONSE STREAM STARTED ====") + + # Open the memory file once for the entire stream processing + with open(os.path.join(self.output_dir, "memory.txt"), "a") as memory_file: + # Track the current block being processed + current_block = {"type": None, "id": None, "content": "", "signature": None} + + for event in stream: + # Log each event to console + # print(f"EVENT: {event.type}") + # print(json.dumps(event.model_dump(), indent=2, default=str)) + + if event.type == "content_block_start": + # Initialize a new content block + block_type = event.content_block.type + current_block = { + "type": block_type, + "id": event.index, # type: ignore[dict-item] + "content": "", + "signature": None, + } + logger.debug(f"Starting {block_type} block...") + + elif event.type == "content_block_delta": + if event.delta.type == "thinking_delta": + # Accumulate thinking content + current_block["content"] = event.delta.thinking + memory_file.write(f"{event.delta.thinking}") + memory_file.flush() # Ensure content is written immediately + + elif event.delta.type == "text_delta": + # Accumulate text content + text_content += event.delta.text + current_block["content"] += event.delta.text # type: ignore[operator] + memory_file.write(f"{event.delta.text}") + memory_file.flush() + + elif event.delta.type == "signature_delta": + # Store signature for thinking blocks + current_block["signature"] = event.delta.signature + memory_file.write( + f"\n[Signature received for block {current_block['id']}]\n" + ) + memory_file.flush() + + elif event.type == "content_block_stop": + # Store completed blocks + if current_block["type"] == "thinking": + # IMPORTANT: Store the complete event.content_block to ensure we preserve + # the exact format that Claude expects in subsequent requests + if hasattr(event, "content_block"): + # Use the exact thinking block as provided by Claude + thinking_blocks.append(event.content_block.model_dump()) + memory_file.write( + f"\nTHINKING COMPLETE: block {current_block['id']}\n" + ) + else: + # Fallback to constructed thinking block if content_block missing + thinking_block = { + "type": "thinking", + "thinking": current_block["content"], + "signature": current_block["signature"], + } + thinking_blocks.append(thinking_block) + memory_file.write( + f"\nTHINKING COMPLETE: block {current_block['id']}\n" + ) + + elif current_block["type"] == "redacted_thinking": + # Handle redacted thinking blocks + if hasattr(event, "content_block") and hasattr( + event.content_block, "data" + ): + redacted_block = { + "type": "redacted_thinking", + "data": event.content_block.data, + } + thinking_blocks.append(redacted_block) + + elif current_block["type"] == "tool_use": + # Process tool use blocks when they're complete + if hasattr(event, "content_block"): + tool_block = event.content_block + tool_id = tool_block.id # type: ignore[union-attr] + tool_name = tool_block.name # type: ignore[union-attr] + tool_input = tool_block.input # type: ignore[union-attr] + + # Create a tool call object for LLMAgent compatibility + tool_call_obj = type( + "ToolCall", + (), + { + "id": tool_id, + "function": type( + "Function", + (), + { + "name": tool_name, + "arguments": json.dumps(tool_input), + }, + ), + }, + ) + tool_calls.append(tool_call_obj) + + # Write tool call information to memory.txt + memory_file.write(f"\n\nTOOL CALL: {tool_name}\n") + memory_file.write( + f"ARGUMENTS: {json.dumps(tool_input, indent=2)}\n" + ) + + # Reset current block + current_block = { + "type": None, + "id": None, + "content": "", + "signature": None, + } + memory_file.flush() + + elif ( + event.type == "message_delta" and event.delta.stop_reason == "tool_use" + ): + # When a tool use is detected + logger.info("Tool use stop reason detected in stream") + + # Mark the end of the response in memory.txt + memory_file.write("\n\nRESPONSE COMPLETE\n\n") + memory_file.flush() + + print("\n==== CLAUDE API RESPONSE STREAM COMPLETED ====") + + # Final response + logger.info( + f"Claude streaming complete. Text: {len(text_content)} chars, Tool calls: {len(tool_calls)}, Thinking blocks: {len(thinking_blocks)}" + ) + + # Return the complete response with all components + return ResponseMessage( + content=text_content, + tool_calls=tool_calls if tool_calls else None, + thinking_blocks=thinking_blocks if thinking_blocks else None, + ) + + except ConnectionError as ce: + logger.error(f"Connection error with Anthropic API: {ce}") + raise + except ValueError as ve: + logger.error(f"Invalid parameters for Anthropic API: {ve}") + raise + except Exception as e: + logger.error(f"Unexpected error in Anthropic API call: {e}") + logger.exception(e) # This will print the full traceback + raise + + def _observable_query( + self, + observer: Observer, # type: ignore[name-defined] + base64_image: str | None = None, + dimensions: tuple[int, int] | None = None, + override_token_limit: bool = False, + incoming_query: str | None = None, + reset_conversation: bool = False, + thinking_budget_tokens: int | None = None, + ) -> None: + """Main query handler that manages conversation history and Claude interactions. + + This is the primary method for handling all queries, whether they come through + direct_query or through the observable pattern. It manages the conversation + history, builds prompts, and handles tool calls. + + Args: + observer (Observer): The observer to emit responses to + base64_image (Optional[str]): Optional Base64-encoded image + dimensions (Optional[Tuple[int, int]]): Optional image dimensions + override_token_limit (bool): Whether to override token limits + incoming_query (Optional[str]): Optional query to update the agent's query + reset_conversation (bool): Whether to reset the conversation history + """ + + try: + logger.info("_observable_query called in claude") + import copy + + # Reset conversation history if requested + if reset_conversation: + self.conversation_history = [] + + # Create a local copy of conversation history and record its length + messages = copy.deepcopy(self.conversation_history) + base_len = len(messages) + + # Update query and get context + self._update_query(incoming_query) + _, rag_results = self._get_rag_context() + + # Build prompt and get Claude parameters + budget = ( + thinking_budget_tokens + if thinking_budget_tokens is not None + else self.thinking_budget_tokens + ) + messages, claude_params = self._build_prompt( + messages, base64_image, dimensions, override_token_limit, rag_results, budget + ) + + # Send query and get response + response_message = self._send_query(messages, claude_params) + + if response_message is None: + logger.error("Received None response from Claude API") + observer.on_next("") + observer.on_completed() + return + # Add thinking blocks and text content to conversation history + content_blocks = [] + if response_message.thinking_blocks: + content_blocks.extend(response_message.thinking_blocks) + if response_message.content: + content_blocks.append({"type": "text", "text": response_message.content}) + if content_blocks: + messages.append({"role": "assistant", "content": content_blocks}) + + # Handle tool calls if present + if response_message.tool_calls: + self._handle_tooling(response_message, messages) # type: ignore[no-untyped-call] + + # At the end, append only new messages (including tool-use/results) to the global conversation history under a lock + import threading + + if not hasattr(self, "_history_lock"): + self._history_lock = threading.Lock() + with self._history_lock: + for msg in messages[base_len:]: + self.conversation_history.append(msg) + + # After merging, run tooling callback (outside lock) + if response_message.tool_calls: + self._tooling_callback(response_message) + + # Send response to observers + result = response_message.content or "" + observer.on_next(result) + self.response_subject.on_next(result) + observer.on_completed() + except Exception as e: + logger.error(f"Query failed in {self.dev_name}: {e}") + # Send a user-friendly error message instead of propagating the error + error_message = "I apologize, but I'm having trouble processing your request right now. Please try again." + observer.on_next(error_message) + self.response_subject.on_next(error_message) + observer.on_completed() + + def _handle_tooling(self, response_message, messages): # type: ignore[no-untyped-def] + """Executes tools and appends tool-use/result blocks to messages.""" + if not hasattr(response_message, "tool_calls") or not response_message.tool_calls: + logger.info("No tool calls found in response message") + return None + + if len(response_message.tool_calls) > 1: + logger.warning( + "Multiple tool calls detected in response message. Not a tested feature." + ) + + # Execute all tools first and collect their results + for tool_call in response_message.tool_calls: + logger.info(f"Processing tool call: {tool_call.function.name}") + tool_use_block = { + "type": "tool_use", + "id": tool_call.id, + "name": tool_call.function.name, + "input": json.loads(tool_call.function.arguments), + } + messages.append({"role": "assistant", "content": [tool_use_block]}) + + try: + # Execute the tool + args = json.loads(tool_call.function.arguments) + tool_result = self.skills.call(tool_call.function.name, **args) # type: ignore[union-attr] + + # Check if the result is an error message + if isinstance(tool_result, str) and ( + "Error executing skill" in tool_result or "is not available" in tool_result + ): + # Log the error but provide a user-friendly message + logger.error(f"Tool execution failed: {tool_result}") + tool_result = "I apologize, but I'm having trouble executing that action right now. Please try again or ask for something else." + + # Add tool result to conversation history + if tool_result: + messages.append( + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": tool_call.id, + "content": f"{tool_result}", + } + ], + } + ) + except Exception as e: + logger.error(f"Unexpected error executing tool {tool_call.function.name}: {e}") + # Add error result to conversation history + messages.append( + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": tool_call.id, + "content": "I apologize, but I encountered an error while trying to execute that action. Please try again.", + } + ], + } + ) + + def _tooling_callback(self, response_message) -> None: # type: ignore[no-untyped-def] + """Runs the observable query for each tool call in the current response_message""" + if not hasattr(response_message, "tool_calls") or not response_message.tool_calls: + return + + try: + for tool_call in response_message.tool_calls: + tool_name = tool_call.function.name + tool_id = tool_call.id + self.run_observable_query( + query_text=f"Tool {tool_name}, ID: {tool_id} execution complete. Please summarize the results and continue.", + thinking_budget_tokens=0, + ).run() + except Exception as e: + logger.error(f"Error in tooling callback: {e}") + # Continue processing even if the callback fails + pass + + def _debug_api_call(self, claude_params: dict): # type: ignore[no-untyped-def, type-arg] + """Debugging function to log API calls with truncated base64 data.""" + # Remove tools to reduce verbosity + import copy + + log_params = copy.deepcopy(claude_params) + if "tools" in log_params: + del log_params["tools"] + + # Truncate base64 data in images - much cleaner approach + if "messages" in log_params: + for msg in log_params["messages"]: + if "content" in msg: + for content in msg["content"]: + if isinstance(content, dict) and content.get("type") == "image": + source = content.get("source", {}) + if source.get("type") == "base64" and "data" in source: + data = source["data"] + source["data"] = f"{data[:50]}..." + return json.dumps(log_params, indent=2, default=str) diff --git a/dimos/agents/memory/base.py b/dimos/agents/memory/base.py index 8167ce3571..283b7cfdce 100644 --- a/dimos/agents/memory/base.py +++ b/dimos/agents/memory/base.py @@ -1,9 +1,34 @@ -from abc import ABC, abstractmethod -import logging -from exceptions.agent_memory_exceptions import UnknownConnectionTypeError, AgentMemoryConnectionError +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -class AbstractAgentMemory(ABC): - def __init__(self, connection_type='local', **kwargs): +from abc import abstractmethod + +from dimos.exceptions.agent_memory_exceptions import ( + AgentMemoryConnectionError, + UnknownConnectionTypeError, +) +from dimos.utils.logging_config import setup_logger + +# TODO +# class AbstractAgentMemory(ABC): + +# TODO +# class AbstractAgentSymbolicMemory(AbstractAgentMemory): + + +class AbstractAgentSemanticMemory: # AbstractAgentMemory): + def __init__(self, connection_type: str = "local", **kwargs) -> None: # type: ignore[no-untyped-def] """ Initialize with dynamic connection parameters. Args: @@ -12,59 +37,98 @@ def __init__(self, connection_type='local', **kwargs): UnknownConnectionTypeError: If an unrecognized connection type is specified. AgentMemoryConnectionError: If initializing the database connection fails. """ - self.logger = logging.getLogger(self.__class__.__name__) - self.logger.info('Initializing AgentMemory with connection type: %s', connection_type) + self.logger = setup_logger() + self.logger.info("Initializing AgentMemory with connection type: %s", connection_type) self.connection_params = kwargs - self.db_connection = None # Holds the conection, whether local or remote, to the database used. - - if connection_type not in ['local', 'remote']: - error = UnknownConnectionTypeError(f"Invalid connection_type {connection_type}. Expected 'local' or 'remote'.") + self.db_connection = ( + None # Holds the conection, whether local or remote, to the database used. + ) + + if connection_type not in ["local", "remote"]: + error = UnknownConnectionTypeError( + f"Invalid connection_type {connection_type}. Expected 'local' or 'remote'." + ) self.logger.error(str(error)) raise error try: - if connection_type == 'remote': - self.connect() - elif connection_type == 'local': - self.create() + if connection_type == "remote": + self.connect() # type: ignore[no-untyped-call] + elif connection_type == "local": + self.create() # type: ignore[no-untyped-call] except Exception as e: self.logger.error("Failed to initialize database connection: %s", str(e), exc_info=True) - raise AgentMemoryConnectionError("Initialization failed due to an unexpected error.", cause=e) from e + raise AgentMemoryConnectionError( + "Initialization failed due to an unexpected error.", cause=e + ) from e @abstractmethod - def connect(self): - """Establish a connection to the database using dynamic parameters specified during initialization.""" + def connect(self): # type: ignore[no-untyped-def] + """Establish a connection to the data store using dynamic parameters specified during initialization.""" @abstractmethod - def create(self): - """Create a local instance of the database tailored to specific requirements.""" + def create(self): # type: ignore[no-untyped-def] + """Create a local instance of the data store tailored to specific requirements.""" + ## Create ## @abstractmethod - def add_vector(self, vector_id, vector_data): + def add_vector(self, vector_id, vector_data): # type: ignore[no-untyped-def] """Add a vector to the database. Args: vector_id (any): Unique identifier for the vector. vector_data (any): The actual data of the vector to be stored. """ + ## Read ## @abstractmethod - def get_vector(self, vector_id): + def get_vector(self, vector_id): # type: ignore[no-untyped-def] """Retrieve a vector from the database by its identifier. Args: vector_id (any): The identifier of the vector to retrieve. """ @abstractmethod - def update_vector(self, vector_id, new_vector_data): + def query(self, query_texts, n_results: int = 4, similarity_threshold=None): # type: ignore[no-untyped-def] + """Performs a semantic search in the vector database. + + Args: + query_texts (Union[str, List[str]]): The query text or list of query texts to search for. + n_results (int, optional): Number of results to return. Defaults to 4. + similarity_threshold (float, optional): Minimum similarity score for results to be included [0.0, 1.0]. Defaults to None. + + Returns: + List[Tuple[Document, Optional[float]]]: A list of tuples containing the search results. Each tuple + contains: + Document: The retrieved document object. + Optional[float]: The similarity score of the match, or None if not applicable. + + Raises: + ValueError: If query_texts is empty or invalid. + ConnectionError: If database connection fails during query. + """ + + ## Update ## + @abstractmethod + def update_vector(self, vector_id, new_vector_data): # type: ignore[no-untyped-def] """Update an existing vector in the database. Args: vector_id (any): The identifier of the vector to update. new_vector_data (any): The new data to replace the existing vector data. """ + ## Delete ## @abstractmethod - def delete_vector(self, vector_id): + def delete_vector(self, vector_id): # type: ignore[no-untyped-def] """Delete a vector from the database using its identifier. Args: vector_id (any): The identifier of the vector to delete. """ + + +# query(string, metadata/tag, n_rets, kwargs) + +# query by string, timestamp, id, n_rets + +# (some sort of tag/metadata) + +# temporal diff --git a/dimos/agents/memory/chroma_impl.py b/dimos/agents/memory/chroma_impl.py index b078578496..a40e8f5edb 100644 --- a/dimos/agents/memory/chroma_impl.py +++ b/dimos/agents/memory/chroma_impl.py @@ -1,50 +1,174 @@ -from agents.memory.base import AbstractAgentMemory +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections.abc import Sequence +import os + +from langchain_chroma import Chroma from langchain_openai import OpenAIEmbeddings +import torch + +from dimos.agents.memory.base import AbstractAgentSemanticMemory + + +class ChromaAgentSemanticMemory(AbstractAgentSemanticMemory): + """Base class for Chroma-based semantic memory implementations.""" + + def __init__(self, collection_name: str = "my_collection") -> None: + """Initialize the connection to the local Chroma DB.""" + self.collection_name = collection_name + self.db_connection = None + self.embeddings = None + super().__init__(connection_type="local") + + def connect(self): # type: ignore[no-untyped-def] + # Stub + return super().connect() # type: ignore[no-untyped-call, safe-super] + + def create(self): # type: ignore[no-untyped-def] + """Create the embedding function and initialize the Chroma database. + This method must be implemented by child classes.""" + raise NotImplementedError("Child classes must implement this method") + + def add_vector(self, vector_id, vector_data): # type: ignore[no-untyped-def] + """Add a vector to the ChromaDB collection.""" + if not self.db_connection: + raise Exception("Collection not initialized. Call connect() first.") + self.db_connection.add_texts( + ids=[vector_id], + texts=[vector_data], + metadatas=[{"name": vector_id}], + ) + def get_vector(self, vector_id): # type: ignore[no-untyped-def] + """Retrieve a vector from the ChromaDB by its identifier.""" + result = self.db_connection.get(include=["embeddings"], ids=[vector_id]) # type: ignore[attr-defined] + return result + + def query(self, query_texts, n_results: int = 4, similarity_threshold=None): # type: ignore[no-untyped-def] + """Query the collection with a specific text and return up to n results.""" + if not self.db_connection: + raise Exception("Collection not initialized. Call connect() first.") + + if similarity_threshold is not None: + if not (0 <= similarity_threshold <= 1): + raise ValueError("similarity_threshold must be between 0 and 1.") + return self.db_connection.similarity_search_with_relevance_scores( + query=query_texts, k=n_results, score_threshold=similarity_threshold + ) + else: + documents = self.db_connection.similarity_search(query=query_texts, k=n_results) + return [(doc, None) for doc in documents] + + def update_vector(self, vector_id, new_vector_data): # type: ignore[no-untyped-def] + # TODO + return super().connect() # type: ignore[no-untyped-call, safe-super] + + def delete_vector(self, vector_id): # type: ignore[no-untyped-def] + """Delete a vector from the ChromaDB using its identifier.""" + if not self.db_connection: + raise Exception("Collection not initialized. Call connect() first.") + self.db_connection.delete(ids=[vector_id]) + + +class OpenAISemanticMemory(ChromaAgentSemanticMemory): + """Semantic memory implementation using OpenAI's embedding API.""" + + def __init__( + self, + collection_name: str = "my_collection", + model: str = "text-embedding-3-large", + dimensions: int = 1024, + ) -> None: + """Initialize OpenAI-based semantic memory. -class AgentMemoryChroma(AbstractAgentMemory): - def __init__(self, connection_type='local', host='localhost', port=6379, db=0): - """Initialize the connection to the Chroma DB. Args: - host (str): The host on which Chroma DB is running. - port (int): The port on which Chroma DB is accessible. - db (int): The database index to use. - connection_type (str): Whether to connect to a local or remote database.' + collection_name (str): Name of the Chroma collection + model (str): OpenAI embedding model to use + dimensions (int): Dimension of the embedding vectors """ - super().__init__(connection_type=connection_type, host=host, port=port, db=db) - self.db_connection - - - def connect(self): - try: - import dimos.agents.memory.chroma_impl as chroma_impl - self.connection = chroma_impl.connect(self.host, self.port, self.db) - self.logger.info("Connected successfully to Chroma DB") - except Exception as e: - self.logger.error("Failed to connect to Chroma DB", exc_info=True) - - def add_vector(self, vector_id, vector_data): - try: - self.connection.add(vector_id, vector_data) - except Exception as e: - self.logger.error(f"Failed to add vector {vector_id}", exc_info=True) - - def get_vector(self, vector_id): - try: - return self.connection.get(vector_id) - except Exception as e: - self.logger.error(f"Failed to retrieve vector {vector_id}", exc_info=True) - return None - - def update_vector(self, vector_id, new_vector_data): - try: - self.connection.update(vector_id, new_vector_data) - except Exception as e: - self.logger.error(f"Failed to update vector {vector_id}", exc_info=True) - - def delete_vector(self, vector_id): - try: - self.connection.delete(vector_id) - except Exception as e: - self.logger.error(f"Failed to delete vector {vector_id}", exc_info=True) + self.model = model + self.dimensions = dimensions + super().__init__(collection_name=collection_name) + + def create(self): # type: ignore[no-untyped-def] + """Connect to OpenAI API and create the ChromaDB client.""" + # Get OpenAI key + self.OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") + if not self.OPENAI_API_KEY: + raise Exception("OpenAI key was not specified.") + + # Set embeddings + self.embeddings = OpenAIEmbeddings( # type: ignore[assignment] + model=self.model, + dimensions=self.dimensions, + api_key=self.OPENAI_API_KEY, # type: ignore[arg-type] + ) + + # Create the database + self.db_connection = Chroma( # type: ignore[assignment] + collection_name=self.collection_name, + embedding_function=self.embeddings, + collection_metadata={"hnsw:space": "cosine"}, + ) + + +class LocalSemanticMemory(ChromaAgentSemanticMemory): + """Semantic memory implementation using local models.""" + + def __init__( + self, + collection_name: str = "my_collection", + model_name: str = "sentence-transformers/all-MiniLM-L6-v2", + ) -> None: + """Initialize the local semantic memory using SentenceTransformer. + + Args: + collection_name (str): Name of the Chroma collection + model_name (str): Embeddings model + """ + + self.model_name = model_name + super().__init__(collection_name=collection_name) + + def create(self) -> None: + """Create local embedding model and initialize the ChromaDB client.""" + # Load the sentence transformer model + # Use CUDA if available, otherwise fall back to CPU + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"Using device: {device}") + self.model = SentenceTransformer(self.model_name, device=device) # type: ignore[name-defined] + + # Create a custom embedding class that implements the embed_query method + class SentenceTransformerEmbeddings: + def __init__(self, model) -> None: # type: ignore[no-untyped-def] + self.model = model + + def embed_query(self, text: str): # type: ignore[no-untyped-def] + """Embed a single query text.""" + return self.model.encode(text, normalize_embeddings=True).tolist() + + def embed_documents(self, texts: Sequence[str]): # type: ignore[no-untyped-def] + """Embed multiple documents/texts.""" + return self.model.encode(texts, normalize_embeddings=True).tolist() + + # Create an instance of our custom embeddings class + self.embeddings = SentenceTransformerEmbeddings(self.model) # type: ignore[assignment] + + # Create the database + self.db_connection = Chroma( # type: ignore[assignment] + collection_name=self.collection_name, + embedding_function=self.embeddings, + collection_metadata={"hnsw:space": "cosine"}, + ) diff --git a/dimos/agents/memory/image_embedding.py b/dimos/agents/memory/image_embedding.py new file mode 100644 index 0000000000..3751ca84de --- /dev/null +++ b/dimos/agents/memory/image_embedding.py @@ -0,0 +1,273 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Image embedding module for converting images to vector embeddings. + +This module provides a class for generating vector embeddings from images +using pre-trained models like CLIP, ResNet, etc. +""" + +import base64 +import io +import os + +import cv2 +import numpy as np +from PIL import Image + +from dimos.utils.data import get_data +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class ImageEmbeddingProvider: + """ + A provider for generating vector embeddings from images. + + This class uses pre-trained models to convert images into vector embeddings + that can be stored in a vector database and used for similarity search. + """ + + def __init__(self, model_name: str = "clip", dimensions: int = 512) -> None: + """ + Initialize the image embedding provider. + + Args: + model_name: Name of the embedding model to use ("clip", "resnet", etc.) + dimensions: Dimensions of the embedding vectors + """ + self.model_name = model_name + self.dimensions = dimensions + self.model = None + self.processor = None + self.model_path = None + + self._initialize_model() # type: ignore[no-untyped-call] + + logger.info(f"ImageEmbeddingProvider initialized with model {model_name}") + + def _initialize_model(self): # type: ignore[no-untyped-def] + """Initialize the specified embedding model.""" + try: + import onnxruntime as ort # type: ignore[import-untyped] + import torch + from transformers import ( # type: ignore[import-untyped] + AutoFeatureExtractor, + AutoModel, + CLIPProcessor, + ) + + if self.model_name == "clip": + model_id = get_data("models_clip") / "model.onnx" + self.model_path = str(model_id) # type: ignore[assignment] # Store for pickling + processor_id = "openai/clip-vit-base-patch32" + + providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] + + self.model = ort.InferenceSession(str(model_id), providers=providers) + + actual_providers = self.model.get_providers() # type: ignore[attr-defined] + self.processor = CLIPProcessor.from_pretrained(processor_id) + logger.info(f"Loaded CLIP model: {model_id} with providers: {actual_providers}") + elif self.model_name == "resnet": + model_id = "microsoft/resnet-50" # type: ignore[assignment] + self.model = AutoModel.from_pretrained(model_id) + self.processor = AutoFeatureExtractor.from_pretrained(model_id) + logger.info(f"Loaded ResNet model: {model_id}") + else: + raise ValueError(f"Unsupported model: {self.model_name}") + except ImportError as e: + logger.error(f"Failed to import required modules: {e}") + logger.error("Please install with: pip install transformers torch") + # Initialize with dummy model for type checking + self.model = None + self.processor = None + raise + + def get_embedding(self, image: np.ndarray | str | bytes) -> np.ndarray: # type: ignore[type-arg] + """ + Generate an embedding vector for the provided image. + + Args: + image: The image to embed, can be a numpy array (OpenCV format), + a file path, or a base64-encoded string + + Returns: + A numpy array containing the embedding vector + """ + if self.model is None or self.processor is None: + logger.error("Model not initialized. Using fallback random embedding.") + return np.random.randn(self.dimensions).astype(np.float32) + + pil_image = self._prepare_image(image) + + try: + import torch + + if self.model_name == "clip": + inputs = self.processor(images=pil_image, return_tensors="np") + + with torch.no_grad(): + ort_inputs = { + inp.name: inputs[inp.name] + for inp in self.model.get_inputs() + if inp.name in inputs + } + + # If required, add dummy text inputs + input_names = [i.name for i in self.model.get_inputs()] + batch_size = inputs["pixel_values"].shape[0] + if "input_ids" in input_names: + ort_inputs["input_ids"] = np.zeros((batch_size, 1), dtype=np.int64) + if "attention_mask" in input_names: + ort_inputs["attention_mask"] = np.ones((batch_size, 1), dtype=np.int64) + + # Run inference + ort_outputs = self.model.run(None, ort_inputs) + + # Look up correct output name + output_names = [o.name for o in self.model.get_outputs()] + if "image_embeds" in output_names: + image_embedding = ort_outputs[output_names.index("image_embeds")] + else: + raise RuntimeError(f"No 'image_embeds' found in outputs: {output_names}") + + embedding = image_embedding / np.linalg.norm(image_embedding, axis=1, keepdims=True) + embedding = embedding[0] + + elif self.model_name == "resnet": + inputs = self.processor(images=pil_image, return_tensors="pt") + + with torch.no_grad(): + outputs = self.model(**inputs) + + # Get the [CLS] token embedding + embedding = outputs.last_hidden_state[:, 0, :].numpy()[0] + else: + logger.warning(f"Unsupported model: {self.model_name}. Using random embedding.") + embedding = np.random.randn(self.dimensions).astype(np.float32) + + # Normalize and ensure correct dimensions + embedding = embedding / np.linalg.norm(embedding) + + logger.debug(f"Generated embedding with shape {embedding.shape}") + return embedding + + except Exception as e: + logger.error(f"Error generating embedding: {e}") + return np.random.randn(self.dimensions).astype(np.float32) + + def get_text_embedding(self, text: str) -> np.ndarray: # type: ignore[type-arg] + """ + Generate an embedding vector for the provided text. + + Args: + text: The text to embed + + Returns: + A numpy array containing the embedding vector + """ + if self.model is None or self.processor is None: + logger.error("Model not initialized. Using fallback random embedding.") + return np.random.randn(self.dimensions).astype(np.float32) + + if self.model_name != "clip": + logger.warning( + f"Text embeddings are only supported with CLIP model, not {self.model_name}. Using random embedding." + ) + return np.random.randn(self.dimensions).astype(np.float32) + + try: + import torch + + inputs = self.processor(text=[text], return_tensors="np", padding=True) + + with torch.no_grad(): + # Prepare ONNX input dict (handle only what's needed) + ort_inputs = { + inp.name: inputs[inp.name] + for inp in self.model.get_inputs() + if inp.name in inputs + } + # Determine which inputs are expected by the ONNX model + input_names = [i.name for i in self.model.get_inputs()] + batch_size = inputs["input_ids"].shape[0] # pulled from text input + + # If the model expects pixel_values (i.e., fused model), add dummy vision input + if "pixel_values" in input_names: + ort_inputs["pixel_values"] = np.zeros( + (batch_size, 3, 224, 224), dtype=np.float32 + ) + + # Run inference + ort_outputs = self.model.run(None, ort_inputs) + + # Determine correct output (usually 'last_hidden_state' or 'text_embeds') + output_names = [o.name for o in self.model.get_outputs()] + if "text_embeds" in output_names: + text_embedding = ort_outputs[output_names.index("text_embeds")] + else: + text_embedding = ort_outputs[0] # fallback to first output + + # Normalize + text_embedding = text_embedding / np.linalg.norm( + text_embedding, axis=1, keepdims=True + ) + text_embedding = text_embedding[0] # shape: (512,) + + logger.debug( + f"Generated text embedding with shape {text_embedding.shape} for text: '{text}'" + ) + return text_embedding + + except Exception as e: + logger.error(f"Error generating text embedding: {e}") + return np.random.randn(self.dimensions).astype(np.float32) + + def _prepare_image(self, image: np.ndarray | str | bytes) -> Image.Image: # type: ignore[type-arg] + """ + Convert the input image to PIL format required by the models. + + Args: + image: Input image in various formats + + Returns: + PIL Image object + """ + if isinstance(image, np.ndarray): + if len(image.shape) == 3 and image.shape[2] == 3: + image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + else: + image_rgb = image + + return Image.fromarray(image_rgb) + + elif isinstance(image, str): + if os.path.isfile(image): + return Image.open(image) + else: + try: + image_data = base64.b64decode(image) + return Image.open(io.BytesIO(image_data)) + except Exception as e: + logger.error(f"Failed to decode image string: {e}") + raise ValueError("Invalid image string format") + + elif isinstance(image, bytes): + return Image.open(io.BytesIO(image)) + + else: + raise ValueError(f"Unsupported image format: {type(image)}") diff --git a/dimos/agents/memory/spatial_vector_db.py b/dimos/agents/memory/spatial_vector_db.py new file mode 100644 index 0000000000..1eac1618d0 --- /dev/null +++ b/dimos/agents/memory/spatial_vector_db.py @@ -0,0 +1,338 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Spatial vector database for storing and querying images with XY locations. + +This module extends the ChromaDB implementation to support storing images with +their XY locations and querying by location or image similarity. +""" + +from typing import Any + +import chromadb +import numpy as np + +from dimos.agents.memory.visual_memory import VisualMemory +from dimos.types.robot_location import RobotLocation +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class SpatialVectorDB: + """ + A vector database for storing and querying images mapped to X,Y,theta absolute locations for SpatialMemory. + + This class extends the ChromaDB implementation to support storing images with + their absolute locations and querying by location, text, or image cosine semantic similarity. + """ + + def __init__( # type: ignore[no-untyped-def] + self, + collection_name: str = "spatial_memory", + chroma_client=None, + visual_memory=None, + embedding_provider=None, + ) -> None: + """ + Initialize the spatial vector database. + + Args: + collection_name: Name of the vector database collection + chroma_client: Optional ChromaDB client for persistence. If None, an in-memory client is used. + visual_memory: Optional VisualMemory instance for storing images. If None, a new one is created. + embedding_provider: Optional ImageEmbeddingProvider instance for computing embeddings. If None, one will be created. + """ + self.collection_name = collection_name + + # Use provided client or create in-memory client + self.client = chroma_client if chroma_client is not None else chromadb.Client() + + # Check if collection already exists - in newer ChromaDB versions list_collections returns names directly + existing_collections = self.client.list_collections() + + # Handle different versions of ChromaDB API + try: + collection_exists = collection_name in existing_collections + except: + try: + collection_exists = collection_name in [c.name for c in existing_collections] + except: + try: + self.client.get_collection(name=collection_name) + collection_exists = True + except Exception: + collection_exists = False + + # Get or create the collection + self.image_collection = self.client.get_or_create_collection( + name=collection_name, metadata={"hnsw:space": "cosine"} + ) + + # Use provided visual memory or create a new one + self.visual_memory = visual_memory if visual_memory is not None else VisualMemory() + + # Store the embedding provider to reuse for all operations + self.embedding_provider = embedding_provider + + # Initialize the location collection for text-based location tagging + location_collection_name = f"{collection_name}_locations" + self.location_collection = self.client.get_or_create_collection( + name=location_collection_name, metadata={"hnsw:space": "cosine"} + ) + + # Log initialization info with details about whether using existing collection + client_type = "persistent" if chroma_client is not None else "in-memory" + try: + count = len(self.image_collection.get(include=[])["ids"]) + if collection_exists: + logger.info( + f"Using EXISTING {client_type} collection '{collection_name}' with {count} entries" + ) + else: + logger.info(f"Created NEW {client_type} collection '{collection_name}'") + except Exception as e: + logger.info( + f"Initialized {client_type} collection '{collection_name}' (count error: {e!s})" + ) + + def add_image_vector( + self, + vector_id: str, + image: np.ndarray, # type: ignore[type-arg] + embedding: np.ndarray, # type: ignore[type-arg] + metadata: dict[str, Any], + ) -> None: + """ + Add an image with its embedding and metadata to the vector database. + + Args: + vector_id: Unique identifier for the vector + image: The image to store + embedding: The pre-computed embedding vector for the image + metadata: Metadata for the image, including x, y coordinates + """ + # Store the image in visual memory + self.visual_memory.add(vector_id, image) + + # Add the vector to ChromaDB + self.image_collection.add( + ids=[vector_id], embeddings=[embedding.tolist()], metadatas=[metadata] + ) + + logger.info(f"Added image vector {vector_id} with metadata: {metadata}") + + def query_by_embedding(self, embedding: np.ndarray, limit: int = 5) -> list[dict]: # type: ignore[type-arg] + """ + Query the vector database for images similar to the provided embedding. + + Args: + embedding: Query embedding vector + limit: Maximum number of results to return + + Returns: + List of results, each containing the image and its metadata + """ + results = self.image_collection.query( + query_embeddings=[embedding.tolist()], n_results=limit + ) + + return self._process_query_results(results) + + # TODO: implement efficient nearest neighbor search + def query_by_location( + self, x: float, y: float, radius: float = 2.0, limit: int = 5 + ) -> list[dict]: # type: ignore[type-arg] + """ + Query the vector database for images near the specified location. + + Args: + x: X coordinate + y: Y coordinate + radius: Search radius in meters + limit: Maximum number of results to return + + Returns: + List of results, each containing the image and its metadata + """ + results = self.image_collection.get() + + if not results or not results["ids"]: + return [] + + filtered_results = {"ids": [], "metadatas": [], "distances": []} # type: ignore[var-annotated] + + for i, metadata in enumerate(results["metadatas"]): # type: ignore[arg-type] + item_x = metadata.get("x") + item_y = metadata.get("y") + + if item_x is not None and item_y is not None: + distance = np.sqrt((x - item_x) ** 2 + (y - item_y) ** 2) + + if distance <= radius: + filtered_results["ids"].append(results["ids"][i]) + filtered_results["metadatas"].append(metadata) + filtered_results["distances"].append(distance) + + sorted_indices = np.argsort(filtered_results["distances"]) + filtered_results["ids"] = [filtered_results["ids"][i] for i in sorted_indices[:limit]] + filtered_results["metadatas"] = [ + filtered_results["metadatas"][i] for i in sorted_indices[:limit] + ] + filtered_results["distances"] = [ + filtered_results["distances"][i] for i in sorted_indices[:limit] + ] + + return self._process_query_results(filtered_results) + + def _process_query_results(self, results) -> list[dict]: # type: ignore[no-untyped-def, type-arg] + """Process query results to include decoded images.""" + if not results or not results["ids"]: + return [] + + processed_results = [] + + for i, vector_id in enumerate(results["ids"]): + if isinstance(vector_id, list) and not vector_id: + continue + + lookup_id = vector_id[0] if isinstance(vector_id, list) else vector_id + + # Create the result dictionary with metadata regardless of image availability + result = { + "metadata": results["metadatas"][i] if "metadatas" in results else {}, + "id": lookup_id, + } + + # Add distance if available + if "distances" in results: + result["distance"] = ( + results["distances"][i][0] + if isinstance(results["distances"][i], list) + else results["distances"][i] + ) + + # Get the image from visual memory + image = self.visual_memory.get(lookup_id) + result["image"] = image + + processed_results.append(result) + + return processed_results + + def query_by_text(self, text: str, limit: int = 5) -> list[dict]: # type: ignore[type-arg] + """ + Query the vector database for images matching the provided text description. + + This method uses CLIP's text-to-image matching capability to find images + that semantically match the text query (e.g., "where is the kitchen"). + + Args: + text: Text query to search for + limit: Maximum number of results to return + + Returns: + List of results, each containing the image, its metadata, and similarity score + """ + if self.embedding_provider is None: + from dimos.agents.memory.image_embedding import ImageEmbeddingProvider + + self.embedding_provider = ImageEmbeddingProvider(model_name="clip") + + text_embedding = self.embedding_provider.get_text_embedding(text) + + results = self.image_collection.query( + query_embeddings=[text_embedding.tolist()], + n_results=limit, + include=["documents", "metadatas", "distances"], + ) + + logger.info( + f"Text query: '{text}' returned {len(results['ids'] if 'ids' in results else [])} results" + ) + return self._process_query_results(results) + + def get_all_locations(self) -> list[tuple[float, float, float]]: + """Get all locations stored in the database.""" + # Get all items from the collection without embeddings + results = self.image_collection.get(include=["metadatas"]) + + if not results or "metadatas" not in results or not results["metadatas"]: + return [] + + # Extract x, y coordinates from metadata + locations = [] + for metadata in results["metadatas"]: + if isinstance(metadata, list) and metadata and isinstance(metadata[0], dict): + metadata = metadata[0] # Handle nested metadata + + if isinstance(metadata, dict) and "x" in metadata and "y" in metadata: + x = metadata.get("x", 0) + y = metadata.get("y", 0) + z = metadata.get("z", 0) if "z" in metadata else 0 + locations.append((x, y, z)) + + return locations + + @property + def image_storage(self): # type: ignore[no-untyped-def] + """Legacy accessor for compatibility with existing code.""" + return self.visual_memory.images + + def tag_location(self, location: RobotLocation) -> None: + """ + Tag a location with a semantic name/description for text-based retrieval. + + Args: + location: RobotLocation object with position/rotation data + """ + + location_id = location.location_id + metadata = location.to_vector_metadata() + + self.location_collection.add( + ids=[location_id], documents=[location.name], metadatas=[metadata] + ) + + def query_tagged_location(self, query: str) -> tuple[RobotLocation | None, float]: + """ + Query for a tagged location using semantic text search. + + Args: + query: Natural language query (e.g., "dining area", "place to eat") + + Returns: + The best matching RobotLocation or None if no matches found + """ + + results = self.location_collection.query( + query_texts=[query], n_results=1, include=["metadatas", "documents", "distances"] + ) + + if not (results and results["ids"] and len(results["ids"][0]) > 0): + return None, 0 + + best_match_metadata = results["metadatas"][0][0] # type: ignore[index] + distance = float(results["distances"][0][0] if "distances" in results else 0.0) # type: ignore[index] + + location = RobotLocation.from_vector_metadata(best_match_metadata) # type: ignore[arg-type] + + logger.info( + f"Found location '{location.name}' for query '{query}' (distance: {distance:.3f})" + if distance + else "" + ) + + return location, distance diff --git a/dimos/agents/memory/test_image_embedding.py b/dimos/agents/memory/test_image_embedding.py new file mode 100644 index 0000000000..9db8bddfc5 --- /dev/null +++ b/dimos/agents/memory/test_image_embedding.py @@ -0,0 +1,214 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Test module for the CLIP image embedding functionality in dimos. +""" + +import os +import time + +import numpy as np +import pytest +from reactivex import operators as ops + +from dimos.agents.memory.image_embedding import ImageEmbeddingProvider +from dimos.stream.video_provider import VideoProvider + + +@pytest.mark.heavy +class TestImageEmbedding: + """Test class for CLIP image embedding functionality.""" + + @pytest.mark.tofix + def test_clip_embedding_initialization(self) -> None: + """Test CLIP embedding provider initializes correctly.""" + try: + # Initialize the embedding provider with CLIP model + embedding_provider = ImageEmbeddingProvider(model_name="clip", dimensions=512) + assert embedding_provider.model is not None, "CLIP model failed to initialize" + assert embedding_provider.processor is not None, "CLIP processor failed to initialize" + assert embedding_provider.model_name == "clip", "Model name should be 'clip'" + assert embedding_provider.dimensions == 512, "Embedding dimensions should be 512" + except Exception as e: + pytest.skip(f"Skipping test due to model initialization error: {e}") + + @pytest.mark.tofix + def test_clip_embedding_process_video(self) -> None: + """Test CLIP embedding provider can process video frames and return embeddings.""" + try: + from dimos.utils.data import get_data + + video_path = get_data("assets") / "trimmed_video_office.mov" + + embedding_provider = ImageEmbeddingProvider(model_name="clip", dimensions=512) + + assert os.path.exists(video_path), f"Test video not found: {video_path}" + video_provider = VideoProvider(dev_name="test_video", video_source=video_path) + + video_stream = video_provider.capture_video_as_observable(realtime=False, fps=15) + + # Use ReactiveX operators to process the stream + def process_frame(frame): + try: + # Process frame with CLIP + embedding = embedding_provider.get_embedding(frame) + print( + f"Generated CLIP embedding with shape: {embedding.shape}, norm: {np.linalg.norm(embedding):.4f}" + ) + + return {"frame": frame, "embedding": embedding} + except Exception as e: + print(f"Error in process_frame: {e}") + return None + + embedding_stream = video_stream.pipe(ops.map(process_frame)) + + results = [] + frames_processed = 0 + target_frames = 10 + + def on_next(result) -> None: + nonlocal frames_processed, results + if not result: # Skip None results + return + + results.append(result) + frames_processed += 1 + + # Stop processing after target frames + if frames_processed >= target_frames: + subscription.dispose() + + def on_error(error) -> None: + pytest.fail(f"Error in embedding stream: {error}") + + def on_completed() -> None: + pass + + # Subscribe and wait for results + subscription = embedding_stream.subscribe( + on_next=on_next, on_error=on_error, on_completed=on_completed + ) + + timeout = 60.0 + start_time = time.time() + while frames_processed < target_frames and time.time() - start_time < timeout: + time.sleep(0.5) + print(f"Processed {frames_processed}/{target_frames} frames") + + # Clean up subscription + subscription.dispose() + video_provider.dispose_all() + + # Check if we have results + if len(results) == 0: + pytest.skip("No embeddings generated, but test connection established correctly") + return + + print(f"Processed {len(results)} frames with CLIP embeddings") + + # Analyze the results + assert len(results) > 0, "No embeddings generated" + + # Check properties of first embedding + first_result = results[0] + assert "embedding" in first_result, "Result doesn't contain embedding" + assert "frame" in first_result, "Result doesn't contain frame" + + # Check embedding shape and normalization + embedding = first_result["embedding"] + assert isinstance(embedding, np.ndarray), "Embedding is not a numpy array" + assert embedding.shape == (512,), ( + f"Embedding has wrong shape: {embedding.shape}, expected (512,)" + ) + assert abs(np.linalg.norm(embedding) - 1.0) < 1e-5, "Embedding is not normalized" + + # Save the first embedding for similarity tests + if len(results) > 1 and "embedding" in results[0]: + # Create a class variable to store embeddings for the similarity test + TestImageEmbedding.test_embeddings = { + "embedding1": results[0]["embedding"], + "embedding2": results[1]["embedding"] if len(results) > 1 else None, + } + print("Saved embeddings for similarity testing") + + print("CLIP embedding test passed successfully!") + + except Exception as e: + pytest.fail(f"Test failed with error: {e}") + + @pytest.mark.tofix + def test_clip_embedding_similarity(self) -> None: + """Test CLIP embedding similarity search and text-to-image queries.""" + try: + # Skip if previous test didn't generate embeddings + if not hasattr(TestImageEmbedding, "test_embeddings"): + pytest.skip("No embeddings available from previous test") + return + + # Get embeddings from previous test + embedding1 = TestImageEmbedding.test_embeddings["embedding1"] + embedding2 = TestImageEmbedding.test_embeddings["embedding2"] + + # Initialize embedding provider for text embeddings + embedding_provider = ImageEmbeddingProvider(model_name="clip", dimensions=512) + + # Test frame-to-frame similarity + if embedding1 is not None and embedding2 is not None: + # Compute cosine similarity + similarity = np.dot(embedding1, embedding2) + print(f"Similarity between first two frames: {similarity:.4f}") + + # Should be in range [-1, 1] + assert -1.0 <= similarity <= 1.0, f"Similarity out of valid range: {similarity}" + + # Test text-to-image similarity + if embedding1 is not None: + # Generate a list of text queries to test + text_queries = ["a video frame", "a person", "an outdoor scene", "a kitchen"] + + # Test each text query + for text_query in text_queries: + # Get text embedding + text_embedding = embedding_provider.get_text_embedding(text_query) + + # Check text embedding properties + assert isinstance(text_embedding, np.ndarray), ( + "Text embedding is not a numpy array" + ) + assert text_embedding.shape == (512,), ( + f"Text embedding has wrong shape: {text_embedding.shape}" + ) + assert abs(np.linalg.norm(text_embedding) - 1.0) < 1e-5, ( + "Text embedding is not normalized" + ) + + # Compute similarity between frame and text + text_similarity = np.dot(embedding1, text_embedding) + print(f"Similarity between frame and '{text_query}': {text_similarity:.4f}") + + # Should be in range [-1, 1] + assert -1.0 <= text_similarity <= 1.0, ( + f"Text-image similarity out of range: {text_similarity}" + ) + + print("CLIP embedding similarity tests passed successfully!") + + except Exception as e: + pytest.fail(f"Similarity test failed with error: {e}") + + +if __name__ == "__main__": + pytest.main(["-v", "--disable-warnings", __file__]) diff --git a/dimos/agents/memory/visual_memory.py b/dimos/agents/memory/visual_memory.py new file mode 100644 index 0000000000..98ad00e2fd --- /dev/null +++ b/dimos/agents/memory/visual_memory.py @@ -0,0 +1,182 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Visual memory storage for managing image data persistence and retrieval +""" + +import base64 +import os +import pickle + +import cv2 +import numpy as np + +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class VisualMemory: + """ + A class for storing and retrieving visual memories (images) with persistence. + + This class handles the storage, encoding, and retrieval of images associated + with vector database entries. It provides persistence mechanisms to save and + load the image data from disk. + """ + + def __init__(self, output_dir: str | None = None) -> None: + """ + Initialize the visual memory system. + + Args: + output_dir: Directory to store the serialized image data + """ + self.images = {} # type: ignore[var-annotated] # Maps IDs to encoded images + self.output_dir = output_dir + + if output_dir: + os.makedirs(output_dir, exist_ok=True) + logger.info(f"VisualMemory initialized with output directory: {output_dir}") + else: + logger.info("VisualMemory initialized with no persistence directory") + + def add(self, image_id: str, image: np.ndarray) -> None: # type: ignore[type-arg] + """ + Add an image to visual memory. + + Args: + image_id: Unique identifier for the image + image: The image data as a numpy array + """ + # Encode the image to base64 for storage + success, encoded_image = cv2.imencode(".jpg", image) + if not success: + logger.error(f"Failed to encode image {image_id}") + return + + image_bytes = encoded_image.tobytes() + b64_encoded = base64.b64encode(image_bytes).decode("utf-8") + + # Store the encoded image + self.images[image_id] = b64_encoded + logger.debug(f"Added image {image_id} to visual memory") + + def get(self, image_id: str) -> np.ndarray | None: # type: ignore[type-arg] + """ + Retrieve an image from visual memory. + + Args: + image_id: Unique identifier for the image + + Returns: + The decoded image as a numpy array, or None if not found + """ + if image_id not in self.images: + logger.warning( + f"Image not found in storage for ID {image_id}. Incomplete or corrupted image storage." + ) + return None + + try: + encoded_image = self.images[image_id] + image_bytes = base64.b64decode(encoded_image) + image_array = np.frombuffer(image_bytes, dtype=np.uint8) + image = cv2.imdecode(image_array, cv2.IMREAD_COLOR) + return image + except Exception as e: + logger.warning(f"Failed to decode image for ID {image_id}: {e!s}") + return None + + def contains(self, image_id: str) -> bool: + """ + Check if an image ID exists in visual memory. + + Args: + image_id: Unique identifier for the image + + Returns: + True if the image exists, False otherwise + """ + return image_id in self.images + + def count(self) -> int: + """ + Get the number of images in visual memory. + + Returns: + The number of images stored + """ + return len(self.images) + + def save(self, filename: str | None = None) -> str: + """ + Save the visual memory to disk. + + Args: + filename: Optional filename to save to. If None, uses a default name in the output directory. + + Returns: + The path where the data was saved + """ + if not self.output_dir: + logger.warning("No output directory specified for VisualMemory. Cannot save.") + return "" + + if not filename: + filename = "visual_memory.pkl" + + output_path = os.path.join(self.output_dir, filename) + + try: + with open(output_path, "wb") as f: + pickle.dump(self.images, f) + logger.info(f"Saved {len(self.images)} images to {output_path}") + return output_path + except Exception as e: + logger.error(f"Failed to save visual memory: {e!s}") + return "" + + @classmethod + def load(cls, path: str, output_dir: str | None = None) -> "VisualMemory": + """ + Load visual memory from disk. + + Args: + path: Path to the saved visual memory file + output_dir: Optional output directory for the new instance + + Returns: + A new VisualMemory instance with the loaded data + """ + instance = cls(output_dir=output_dir) + + if not os.path.exists(path): + logger.warning(f"Visual memory file {path} not found") + return instance + + try: + with open(path, "rb") as f: + instance.images = pickle.load(f) + logger.info(f"Loaded {len(instance.images)} images from {path}") + return instance + except Exception as e: + logger.error(f"Failed to load visual memory: {e!s}") + return instance + + def clear(self) -> None: + """Clear all images from memory.""" + self.images = {} + logger.info("Visual memory cleared") diff --git a/dimos/agents/modules/__init__.py b/dimos/agents/modules/__init__.py new file mode 100644 index 0000000000..ee1269f8f5 --- /dev/null +++ b/dimos/agents/modules/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Agent modules for DimOS.""" diff --git a/dimos/agents/modules/base.py b/dimos/agents/modules/base.py new file mode 100644 index 0000000000..d5641aee39 --- /dev/null +++ b/dimos/agents/modules/base.py @@ -0,0 +1,525 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base agent class with all features (non-module).""" + +import asyncio +from concurrent.futures import ThreadPoolExecutor +import json +from typing import Any + +from reactivex.subject import Subject + +from dimos.agents.agent_message import AgentMessage +from dimos.agents.agent_types import AgentResponse, ConversationHistory, ToolCall +from dimos.agents.memory.base import AbstractAgentSemanticMemory +from dimos.agents.memory.chroma_impl import OpenAISemanticMemory +from dimos.skills.skills import AbstractSkill, SkillLibrary +from dimos.utils.logging_config import setup_logger + +try: + from .gateway import UnifiedGatewayClient +except ImportError: + from dimos.agents.modules.gateway import UnifiedGatewayClient + +logger = setup_logger() + +# Vision-capable models +VISION_MODELS = { + "openai::gpt-4o", + "openai::gpt-4o-mini", + "openai::gpt-4-turbo", + "openai::gpt-4-vision-preview", + "anthropic::claude-3-haiku-20240307", + "anthropic::claude-3-sonnet-20241022", + "anthropic::claude-3-opus-20240229", + "anthropic::claude-3-5-sonnet-20241022", + "anthropic::claude-3-5-haiku-latest", + "qwen::qwen-vl-plus", + "qwen::qwen-vl-max", +} + + +class BaseAgent: + """Base agent with all features including memory, skills, and multimodal support. + + This class provides: + - LLM gateway integration + - Conversation history + - Semantic memory (RAG) + - Skills/tools execution + - Multimodal support (text, images, data) + - Model capability detection + """ + + def __init__( # type: ignore[no-untyped-def] + self, + model: str = "openai::gpt-4o-mini", + system_prompt: str | None = None, + skills: SkillLibrary | list[AbstractSkill] | AbstractSkill | None = None, + memory: AbstractAgentSemanticMemory | None = None, + temperature: float = 0.0, + max_tokens: int = 4096, + max_input_tokens: int = 128000, + max_history: int = 20, + rag_n: int = 4, + rag_threshold: float = 0.45, + seed: int | None = None, + # Legacy compatibility + dev_name: str = "BaseAgent", + agent_type: str = "LLM", + **kwargs, + ) -> None: + """Initialize the base agent with all features. + + Args: + model: Model identifier (e.g., "openai::gpt-4o", "anthropic::claude-3-haiku") + system_prompt: System prompt for the agent + skills: Skills/tools available to the agent + memory: Semantic memory system for RAG + temperature: Sampling temperature + max_tokens: Maximum tokens to generate + max_input_tokens: Maximum input tokens + max_history: Maximum conversation history to keep + rag_n: Number of RAG results to fetch + rag_threshold: Minimum similarity for RAG results + seed: Random seed for deterministic outputs (if supported by model) + dev_name: Device/agent name for logging + agent_type: Type of agent for logging + """ + self.model = model + self.system_prompt = system_prompt or "You are a helpful AI assistant." + self.temperature = temperature + self.max_tokens = max_tokens + self.max_input_tokens = max_input_tokens + self._max_history = max_history + self.rag_n = rag_n + self.rag_threshold = rag_threshold + self.seed = seed + self.dev_name = dev_name + self.agent_type = agent_type + + # Initialize skills + if skills is None: + self.skills = SkillLibrary() + elif isinstance(skills, SkillLibrary): + self.skills = skills + elif isinstance(skills, list): + self.skills = SkillLibrary() + for skill in skills: + self.skills.add(skill) + elif isinstance(skills, AbstractSkill): + self.skills = SkillLibrary() + self.skills.add(skills) + else: + self.skills = SkillLibrary() + + # Initialize memory - allow None for testing + if memory is False: # type: ignore[comparison-overlap] # Explicit False means no memory + self.memory = None + else: + self.memory = memory or OpenAISemanticMemory() # type: ignore[has-type] + + # Initialize gateway + self.gateway = UnifiedGatewayClient() + + # Conversation history with proper format management + self.conversation = ConversationHistory(max_size=self._max_history) + + # Thread pool for async operations + self._executor = ThreadPoolExecutor(max_workers=2) + + # Response subject for emitting responses + self.response_subject = Subject() # type: ignore[var-annotated] + + # Check model capabilities + self._supports_vision = self._check_vision_support() + + # Initialize memory with default context + self._initialize_memory() + + @property + def max_history(self) -> int: + """Get max history size.""" + return self._max_history + + @max_history.setter + def max_history(self, value: int) -> None: + """Set max history size and update conversation.""" + self._max_history = value + self.conversation.max_size = value + + def _check_vision_support(self) -> bool: + """Check if the model supports vision.""" + return self.model in VISION_MODELS + + def _initialize_memory(self) -> None: + """Initialize memory with default context.""" + try: + contexts = [ + ("ctx1", "I am an AI assistant that can help with various tasks."), + ("ctx2", f"I am using the {self.model} model."), + ( + "ctx3", + "I have access to tools and skills for specific operations." + if len(self.skills) > 0 + else "I do not have access to external tools.", + ), + ( + "ctx4", + "I can process images and visual content." + if self._supports_vision + else "I cannot process visual content.", + ), + ] + if self.memory: # type: ignore[has-type] + for doc_id, text in contexts: + self.memory.add_vector(doc_id, text) # type: ignore[has-type] + except Exception as e: + logger.warning(f"Failed to initialize memory: {e}") + + async def _process_query_async(self, agent_msg: AgentMessage) -> AgentResponse: + """Process query asynchronously and return AgentResponse.""" + query_text = agent_msg.get_combined_text() + logger.info(f"Processing query: {query_text}") + + # Get RAG context + rag_context = self._get_rag_context(query_text) + + # Check if trying to use images with non-vision model + if agent_msg.has_images() and not self._supports_vision: + logger.warning(f"Model {self.model} does not support vision. Ignoring image input.") + # Clear images from message + agent_msg.images.clear() + + # Build messages - pass AgentMessage directly + messages = self._build_messages(agent_msg, rag_context) + + # Get tools if available + tools = self.skills.get_tools() if len(self.skills) > 0 else None + + # Debug logging before gateway call + logger.debug("=== Gateway Request ===") + logger.debug(f"Model: {self.model}") + logger.debug(f"Number of messages: {len(messages)}") + for i, msg in enumerate(messages): + role = msg.get("role", "unknown") + content = msg.get("content", "") + if isinstance(content, str): + content_preview = content[:100] + elif isinstance(content, list): + content_preview = f"[{len(content)} content blocks]" + else: + content_preview = str(content)[:100] + logger.debug(f" Message {i}: role={role}, content={content_preview}...") + logger.debug(f"Tools available: {len(tools) if tools else 0}") + logger.debug("======================") + + # Prepare inference parameters + inference_params = { + "model": self.model, + "messages": messages, + "tools": tools, + "temperature": self.temperature, + "max_tokens": self.max_tokens, + "stream": False, + } + + # Add seed if provided + if self.seed is not None: + inference_params["seed"] = self.seed + + # Make inference call + response = await self.gateway.ainference(**inference_params) # type: ignore[arg-type] + + # Extract response + message = response["choices"][0]["message"] # type: ignore[index] + content = message.get("content", "") + + # Don't update history yet - wait until we have the complete interaction + # This follows Claude's pattern of locking history until tool execution is complete + + # Check for tool calls + tool_calls = None + if message.get("tool_calls"): + tool_calls = [ + ToolCall( + id=tc["id"], + name=tc["function"]["name"], + arguments=json.loads(tc["function"]["arguments"]), + status="pending", + ) + for tc in message["tool_calls"] + ] + + # Get the user message for history + user_message = messages[-1] + + # Handle tool calls (blocking by default) + final_content = await self._handle_tool_calls(tool_calls, messages, user_message) + + # Return response with tool information + return AgentResponse( + content=final_content, + role="assistant", + tool_calls=tool_calls, + requires_follow_up=False, # Already handled + metadata={"model": self.model}, + ) + else: + # No tools, add both user and assistant messages to history + # Get the user message content from the built message + user_msg = messages[-1] # Last message in messages is the user message + user_content = user_msg["content"] + + # Add to conversation history + logger.info("=== Adding to history (no tools) ===") + logger.info(f" Adding user message: {str(user_content)[:100]}...") + self.conversation.add_user_message(user_content) + logger.info(f" Adding assistant response: {content[:100]}...") + self.conversation.add_assistant_message(content) + logger.info(f" History size now: {self.conversation.size()}") + + return AgentResponse( + content=content, + role="assistant", + tool_calls=None, + requires_follow_up=False, + metadata={"model": self.model}, + ) + + def _get_rag_context(self, query: str) -> str: + """Get relevant context from memory.""" + if not self.memory: # type: ignore[has-type] + return "" + + try: + results = self.memory.query( # type: ignore[has-type] + query_texts=query, n_results=self.rag_n, similarity_threshold=self.rag_threshold + ) + + if results: + contexts = [doc.page_content for doc, _ in results] + return " | ".join(contexts) + except Exception as e: + logger.warning(f"RAG query failed: {e}") + + return "" + + def _build_messages( + self, agent_msg: AgentMessage, rag_context: str = "" + ) -> list[dict[str, Any]]: + """Build messages list from AgentMessage.""" + messages = [] + + # System prompt with RAG context if available + system_content = self.system_prompt + if rag_context: + system_content += f"\n\nRelevant context: {rag_context}" + messages.append({"role": "system", "content": system_content}) + + # Add conversation history in OpenAI format + history_messages = self.conversation.to_openai_format() + messages.extend(history_messages) + + # Debug history state + logger.info(f"=== Building messages with {len(history_messages)} history messages ===") + if history_messages: + for i, msg in enumerate(history_messages): + role = msg.get("role", "unknown") + content = msg.get("content", "") + if isinstance(content, str): + preview = content[:100] + elif isinstance(content, list): + preview = f"[{len(content)} content blocks]" + else: + preview = str(content)[:100] + logger.info(f" History[{i}]: role={role}, content={preview}") + + # Build user message content from AgentMessage + user_content = agent_msg.get_combined_text() if agent_msg.has_text() else "" + + # Handle images for vision models + if agent_msg.has_images() and self._supports_vision: + # Build content array with text and images + content = [] + if user_content: # Only add text if not empty + content.append({"type": "text", "text": user_content}) + + # Add all images from AgentMessage + for img in agent_msg.images: + content.append( + { + "type": "image_url", + "image_url": {"url": f"data:image/jpeg;base64,{img.base64_jpeg}"}, + } + ) + + logger.debug(f"Building message with {len(content)} content items (vision enabled)") + messages.append({"role": "user", "content": content}) # type: ignore[dict-item] + else: + # Text-only message + messages.append({"role": "user", "content": user_content}) + + return messages + + async def _handle_tool_calls( + self, + tool_calls: list[ToolCall], + messages: list[dict[str, Any]], + user_message: dict[str, Any], + ) -> str: + """Handle tool calls from LLM (blocking mode by default).""" + try: + # Build assistant message with tool calls + assistant_msg = { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": tc.id, + "type": "function", + "function": {"name": tc.name, "arguments": json.dumps(tc.arguments)}, + } + for tc in tool_calls + ], + } + messages.append(assistant_msg) + + # Execute tools and collect results + tool_results = [] + for tool_call in tool_calls: + logger.info(f"Executing tool: {tool_call.name}") + + try: + # Execute the tool + result = self.skills.call(tool_call.name, **tool_call.arguments) + tool_call.status = "completed" + + # Format tool result message + tool_result = { + "role": "tool", + "tool_call_id": tool_call.id, + "content": str(result), + "name": tool_call.name, + } + tool_results.append(tool_result) + + except Exception as e: + logger.error(f"Tool execution failed: {e}") + tool_call.status = "failed" + + # Add error result + tool_result = { + "role": "tool", + "tool_call_id": tool_call.id, + "content": f"Error: {e!s}", + "name": tool_call.name, + } + tool_results.append(tool_result) + + # Add tool results to messages + messages.extend(tool_results) + + # Prepare follow-up inference parameters + followup_params = { + "model": self.model, + "messages": messages, + "temperature": self.temperature, + "max_tokens": self.max_tokens, + } + + # Add seed if provided + if self.seed is not None: + followup_params["seed"] = self.seed + + # Get follow-up response + response = await self.gateway.ainference(**followup_params) # type: ignore[arg-type] + + # Extract final response + final_message = response["choices"][0]["message"] # type: ignore[index] + + # Now add all messages to history in order (like Claude does) + # Add user message + user_content = user_message["content"] + self.conversation.add_user_message(user_content) + + # Add assistant message with tool calls + self.conversation.add_assistant_message("", tool_calls) + + # Add tool results + for result in tool_results: + self.conversation.add_tool_result( + tool_call_id=result["tool_call_id"], content=result["content"] + ) + + # Add final assistant response + final_content = final_message.get("content", "") + self.conversation.add_assistant_message(final_content) + + return final_message.get("content", "") # type: ignore[no-any-return] + + except Exception as e: + logger.error(f"Error handling tool calls: {e}") + return f"Error executing tools: {e!s}" + + def query(self, message: str | AgentMessage) -> AgentResponse: + """Synchronous query method for direct usage. + + Args: + message: Either a string query or an AgentMessage with text and/or images + + Returns: + AgentResponse object with content and tool information + """ + # Convert string to AgentMessage if needed + if isinstance(message, str): + agent_msg = AgentMessage() + agent_msg.add_text(message) + else: + agent_msg = message + + # Run async method in a new event loop + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + return loop.run_until_complete(self._process_query_async(agent_msg)) + finally: + loop.close() + + async def aquery(self, message: str | AgentMessage) -> AgentResponse: + """Asynchronous query method. + + Args: + message: Either a string query or an AgentMessage with text and/or images + + Returns: + AgentResponse object with content and tool information + """ + # Convert string to AgentMessage if needed + if isinstance(message, str): + agent_msg = AgentMessage() + agent_msg.add_text(message) + else: + agent_msg = message + + return await self._process_query_async(agent_msg) + + def base_agent_dispose(self) -> None: + """Dispose of all resources and close gateway.""" + self.response_subject.on_completed() + if self._executor: + self._executor.shutdown(wait=False) + if self.gateway: + self.gateway.close() diff --git a/dimos/agents/modules/base_agent.py b/dimos/agents/modules/base_agent.py new file mode 100644 index 0000000000..8124c3718c --- /dev/null +++ b/dimos/agents/modules/base_agent.py @@ -0,0 +1,211 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base agent module that wraps BaseAgent for DimOS module usage.""" + +import threading +from typing import Any + +from dimos.agents.agent_message import AgentMessage +from dimos.agents.agent_types import AgentResponse +from dimos.agents.memory.base import AbstractAgentSemanticMemory +from dimos.core import In, Module, Out, rpc +from dimos.skills.skills import AbstractSkill, SkillLibrary +from dimos.utils.logging_config import setup_logger + +try: + from .base import BaseAgent +except ImportError: + from dimos.agents.modules.base import BaseAgent + +logger = setup_logger() + + +class BaseAgentModule(BaseAgent, Module): # type: ignore[misc] + """Agent module that inherits from BaseAgent and adds DimOS module interface. + + This provides a thin wrapper around BaseAgent functionality, exposing it + through the DimOS module system with RPC methods and stream I/O. + """ + + # Module I/O - AgentMessage based communication + message_in: In[AgentMessage] = None # type: ignore[assignment] # Primary input for AgentMessage + response_out: Out[AgentResponse] = None # type: ignore[assignment] # Output AgentResponse objects + + def __init__( # type: ignore[no-untyped-def] + self, + model: str = "openai::gpt-4o-mini", + system_prompt: str | None = None, + skills: SkillLibrary | list[AbstractSkill] | AbstractSkill | None = None, + memory: AbstractAgentSemanticMemory | None = None, + temperature: float = 0.0, + max_tokens: int = 4096, + max_input_tokens: int = 128000, + max_history: int = 20, + rag_n: int = 4, + rag_threshold: float = 0.45, + process_all_inputs: bool = False, + **kwargs, + ) -> None: + """Initialize the agent module. + + Args: + model: Model identifier (e.g., "openai::gpt-4o", "anthropic::claude-3-haiku") + system_prompt: System prompt for the agent + skills: Skills/tools available to the agent + memory: Semantic memory system for RAG + temperature: Sampling temperature + max_tokens: Maximum tokens to generate + max_input_tokens: Maximum input tokens + max_history: Maximum conversation history to keep + rag_n: Number of RAG results to fetch + rag_threshold: Minimum similarity for RAG results + process_all_inputs: Whether to process all inputs or drop when busy + **kwargs: Additional arguments passed to Module + """ + # Initialize Module first (important for DimOS) + Module.__init__(self, **kwargs) + + # Initialize BaseAgent with all functionality + BaseAgent.__init__( + self, + model=model, + system_prompt=system_prompt, + skills=skills, + memory=memory, + temperature=temperature, + max_tokens=max_tokens, + max_input_tokens=max_input_tokens, + max_history=max_history, + rag_n=rag_n, + rag_threshold=rag_threshold, + process_all_inputs=process_all_inputs, + # Don't pass streams - we'll connect them in start() + input_query_stream=None, + input_data_stream=None, + input_video_stream=None, + ) + + # Track module-specific subscriptions + self._module_disposables = [] # type: ignore[var-annotated] + + # For legacy stream support + self._latest_image = None + self._latest_data = None + self._image_lock = threading.Lock() + self._data_lock = threading.Lock() + + @rpc + def start(self) -> None: + """Start the agent module and connect streams.""" + super().start() + logger.info(f"Starting agent module with model: {self.model}") + + # Primary AgentMessage input + if self.message_in and self.message_in.connection is not None: + try: + disposable = self.message_in.observable().subscribe( # type: ignore[no-untyped-call] + lambda msg: self._handle_agent_message(msg) + ) + self._module_disposables.append(disposable) + except Exception as e: + logger.debug(f"Could not connect message_in: {e}") + + # Connect response output + if self.response_out: + disposable = self.response_subject.subscribe( + lambda response: self.response_out.publish(response) + ) + self._module_disposables.append(disposable) + + logger.info("Agent module started") + + @rpc + def stop(self) -> None: + """Stop the agent module.""" + logger.info("Stopping agent module") + + # Dispose module subscriptions + for disposable in self._module_disposables: + disposable.dispose() + self._module_disposables.clear() + + # Dispose BaseAgent resources + self.base_agent_dispose() + + logger.info("Agent module stopped") + super().stop() + + @rpc + def clear_history(self) -> None: + """Clear conversation history.""" + with self._history_lock: # type: ignore[attr-defined] + self.history = [] # type: ignore[var-annotated] + logger.info("Conversation history cleared") + + @rpc + def add_skill(self, skill: AbstractSkill) -> None: + """Add a skill to the agent.""" + self.skills.add(skill) + logger.info(f"Added skill: {skill.__class__.__name__}") + + @rpc + def set_system_prompt(self, prompt: str) -> None: + """Update system prompt.""" + self.system_prompt = prompt + logger.info("System prompt updated") + + @rpc + def get_conversation_history(self) -> list[dict[str, Any]]: + """Get current conversation history.""" + with self._history_lock: # type: ignore[attr-defined] + return self.history.copy() + + def _handle_agent_message(self, message: AgentMessage) -> None: + """Handle AgentMessage from module input.""" + # Process through BaseAgent query method + try: + response = self.query(message) + logger.debug(f"Publishing response: {response}") + self.response_subject.on_next(response) + except Exception as e: + logger.error(f"Agent message processing error: {e}") + self.response_subject.on_error(e) + + def _handle_module_query(self, query: str) -> None: + """Handle legacy query from module input.""" + # For simple text queries, just convert to AgentMessage + agent_msg = AgentMessage() + agent_msg.add_text(query) + + # Process through unified handler + self._handle_agent_message(agent_msg) + + def _update_latest_data(self, data: dict[str, Any]) -> None: + """Update latest data context.""" + with self._data_lock: + self._latest_data = data # type: ignore[assignment] + + def _update_latest_image(self, img: Any) -> None: + """Update latest image.""" + with self._image_lock: + self._latest_image = img + + def _format_data_context(self, data: dict[str, Any]) -> str: + """Format data dictionary as context string.""" + # Simple formatting - can be customized + parts = [] + for key, value in data.items(): + parts.append(f"{key}: {value}") + return "\n".join(parts) diff --git a/dimos/agents/modules/gateway/__init__.py b/dimos/agents/modules/gateway/__init__.py new file mode 100644 index 0000000000..7ae4beb037 --- /dev/null +++ b/dimos/agents/modules/gateway/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Gateway module for unified LLM access.""" + +from .client import UnifiedGatewayClient +from .utils import convert_tools_to_standard_format, parse_streaming_response + +__all__ = ["UnifiedGatewayClient", "convert_tools_to_standard_format", "parse_streaming_response"] diff --git a/dimos/agents/modules/gateway/client.py b/dimos/agents/modules/gateway/client.py new file mode 100644 index 0000000000..6e3c6c6706 --- /dev/null +++ b/dimos/agents/modules/gateway/client.py @@ -0,0 +1,211 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unified gateway client for LLM access.""" + +import asyncio +from collections.abc import AsyncIterator, Iterator +import logging +import os +from types import TracebackType +from typing import Any + +import httpx +from tenacity import retry, stop_after_attempt, wait_exponential + +from .tensorzero_embedded import TensorZeroEmbeddedGateway + +logger = logging.getLogger(__name__) + + +class UnifiedGatewayClient: + """Clean abstraction over TensorZero or other gateways. + + This client provides a unified interface for accessing multiple LLM providers + through a gateway service, with support for streaming, tools, and async operations. + """ + + def __init__( + self, gateway_url: str | None = None, timeout: float = 60.0, use_simple: bool = False + ) -> None: + """Initialize the gateway client. + + Args: + gateway_url: URL of the gateway service. Defaults to env var or localhost + timeout: Request timeout in seconds + use_simple: Deprecated parameter, always uses TensorZero + """ + self.gateway_url = gateway_url or os.getenv( + "TENSORZERO_GATEWAY_URL", "http://localhost:3000" + ) + self.timeout = timeout + self._client = None + self._async_client = None + + # Always use TensorZero embedded gateway + try: + self._tensorzero_client = TensorZeroEmbeddedGateway() + logger.info("Using TensorZero embedded gateway") + except Exception as e: + logger.error(f"Failed to initialize TensorZero: {e}") + raise + + def _get_client(self) -> httpx.Client: + """Get or create sync HTTP client.""" + if self._client is None: + self._client = httpx.Client( # type: ignore[assignment] + base_url=self.gateway_url, # type: ignore[arg-type] + timeout=self.timeout, + headers={"Content-Type": "application/json"}, + ) + return self._client # type: ignore[return-value] + + def _get_async_client(self) -> httpx.AsyncClient: + """Get or create async HTTP client.""" + if self._async_client is None: + self._async_client = httpx.AsyncClient( # type: ignore[assignment] + base_url=self.gateway_url, # type: ignore[arg-type] + timeout=self.timeout, + headers={"Content-Type": "application/json"}, + ) + return self._async_client # type: ignore[return-value] + + @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10)) + def inference( # type: ignore[no-untyped-def] + self, + model: str, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, + temperature: float = 0.0, + max_tokens: int | None = None, + stream: bool = False, + **kwargs, + ) -> dict[str, Any] | Iterator[dict[str, Any]]: + """Synchronous inference call. + + Args: + model: Model identifier (e.g., "openai::gpt-4o") + messages: List of message dicts with role and content + tools: Optional list of tools in standard format + temperature: Sampling temperature + max_tokens: Maximum tokens to generate + stream: Whether to stream the response + **kwargs: Additional model-specific parameters + + Returns: + Response dict or iterator of response chunks if streaming + """ + return self._tensorzero_client.inference( + model=model, + messages=messages, + tools=tools, + temperature=temperature, + max_tokens=max_tokens, + stream=stream, + **kwargs, + ) + + @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10)) + async def ainference( # type: ignore[no-untyped-def] + self, + model: str, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, + temperature: float = 0.0, + max_tokens: int | None = None, + stream: bool = False, + **kwargs, + ) -> dict[str, Any] | AsyncIterator[dict[str, Any]]: + """Asynchronous inference call. + + Args: + model: Model identifier (e.g., "anthropic::claude-3-7-sonnet") + messages: List of message dicts with role and content + tools: Optional list of tools in standard format + temperature: Sampling temperature + max_tokens: Maximum tokens to generate + stream: Whether to stream the response + **kwargs: Additional model-specific parameters + + Returns: + Response dict or async iterator of response chunks if streaming + """ + return await self._tensorzero_client.ainference( + model=model, + messages=messages, + tools=tools, + temperature=temperature, + max_tokens=max_tokens, + stream=stream, + **kwargs, + ) + + def close(self) -> None: + """Close the HTTP clients.""" + if self._client: + self._client.close() + self._client = None + if self._async_client: + # This needs to be awaited in an async context + # We'll handle this in __del__ with asyncio + pass + self._tensorzero_client.close() + + async def aclose(self) -> None: + """Async close method.""" + if self._async_client: + await self._async_client.aclose() + self._async_client = None + await self._tensorzero_client.aclose() + + def __del__(self) -> None: + """Cleanup on deletion.""" + self.close() + if self._async_client: + # Try to close async client if event loop is available + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + loop.create_task(self.aclose()) + else: + loop.run_until_complete(self.aclose()) + except RuntimeError: + # No event loop, just let it be garbage collected + pass + + def __enter__(self): # type: ignore[no-untyped-def] + """Context manager entry.""" + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + """Context manager exit.""" + self.close() + + async def __aenter__(self): # type: ignore[no-untyped-def] + """Async context manager entry.""" + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + """Async context manager exit.""" + await self.aclose() diff --git a/dimos/agents/modules/gateway/tensorzero_embedded.py b/dimos/agents/modules/gateway/tensorzero_embedded.py new file mode 100644 index 0000000000..4708788241 --- /dev/null +++ b/dimos/agents/modules/gateway/tensorzero_embedded.py @@ -0,0 +1,280 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""TensorZero embedded gateway client with correct config format.""" + +from collections.abc import AsyncIterator, Iterator +import logging +from pathlib import Path +from typing import Any + +logger = logging.getLogger(__name__) + + +class TensorZeroEmbeddedGateway: + """TensorZero embedded gateway using patch_openai_client.""" + + def __init__(self) -> None: + """Initialize TensorZero embedded gateway.""" + self._client = None + self._config_path = None + self._setup_config() + self._initialize_client() # type: ignore[no-untyped-call] + + def _setup_config(self) -> None: + """Create TensorZero configuration with correct format.""" + config_dir = Path("/tmp/tensorzero_embedded") + config_dir.mkdir(exist_ok=True) + self._config_path = config_dir / "tensorzero.toml" # type: ignore[assignment] + + # Create config using the correct format from working example + config_content = """ +# OpenAI Models +[models.gpt_4o_mini] +routing = ["openai"] + +[models.gpt_4o_mini.providers.openai] +type = "openai" +model_name = "gpt-4o-mini" + +[models.gpt_4o] +routing = ["openai"] + +[models.gpt_4o.providers.openai] +type = "openai" +model_name = "gpt-4o" + +# Claude Models +[models.claude_3_haiku] +routing = ["anthropic"] + +[models.claude_3_haiku.providers.anthropic] +type = "anthropic" +model_name = "claude-3-haiku-20240307" + +[models.claude_3_sonnet] +routing = ["anthropic"] + +[models.claude_3_sonnet.providers.anthropic] +type = "anthropic" +model_name = "claude-3-5-sonnet-20241022" + +[models.claude_3_opus] +routing = ["anthropic"] + +[models.claude_3_opus.providers.anthropic] +type = "anthropic" +model_name = "claude-3-opus-20240229" + +# Cerebras Models - disabled for CI (no API key) +# [models.llama_3_3_70b] +# routing = ["cerebras"] +# +# [models.llama_3_3_70b.providers.cerebras] +# type = "openai" +# model_name = "llama-3.3-70b" +# api_base = "https://api.cerebras.ai/v1" +# api_key_location = "env::CEREBRAS_API_KEY" + +# Qwen Models +[models.qwen_plus] +routing = ["qwen"] + +[models.qwen_plus.providers.qwen] +type = "openai" +model_name = "qwen-plus" +api_base = "https://dashscope-intl.aliyuncs.com/compatible-mode/v1" +api_key_location = "env::ALIBABA_API_KEY" + +[models.qwen_vl_plus] +routing = ["qwen"] + +[models.qwen_vl_plus.providers.qwen] +type = "openai" +model_name = "qwen-vl-plus" +api_base = "https://dashscope-intl.aliyuncs.com/compatible-mode/v1" +api_key_location = "env::ALIBABA_API_KEY" + +# Object storage - disable for embedded mode +[object_storage] +type = "disabled" + +# Single chat function with all models +# TensorZero will automatically skip models that don't support the input type +[functions.chat] +type = "chat" + +[functions.chat.variants.openai] +type = "chat_completion" +model = "gpt_4o_mini" +weight = 1.0 + +[functions.chat.variants.claude] +type = "chat_completion" +model = "claude_3_haiku" +weight = 0.5 + +# Cerebras disabled for CI (no API key) +# [functions.chat.variants.cerebras] +# type = "chat_completion" +# model = "llama_3_3_70b" +# weight = 0.0 + +[functions.chat.variants.qwen] +type = "chat_completion" +model = "qwen_plus" +weight = 0.3 + +# For vision queries, Qwen VL can be used +[functions.chat.variants.qwen_vision] +type = "chat_completion" +model = "qwen_vl_plus" +weight = 0.4 +""" + + with open(self._config_path, "w") as f: # type: ignore[call-overload] + f.write(config_content) + + logger.info(f"Created TensorZero config at {self._config_path}") + + def _initialize_client(self): # type: ignore[no-untyped-def] + """Initialize OpenAI client with TensorZero patch.""" + try: + from openai import OpenAI + from tensorzero import patch_openai_client + + self._client = OpenAI() # type: ignore[assignment] + + # Patch with TensorZero embedded gateway + patch_openai_client( + self._client, + clickhouse_url=None, # In-memory storage + config_file=str(self._config_path), + async_setup=False, + ) + + logger.info("TensorZero embedded gateway initialized successfully") + + except Exception as e: + logger.error(f"Failed to initialize TensorZero: {e}") + raise + + def _map_model_to_tensorzero(self, model: str) -> str: + """Map provider::model format to TensorZero function format.""" + # Always use the chat function - TensorZero will handle model selection + # based on input type and model capabilities automatically + return "tensorzero::function_name::chat" + + def inference( # type: ignore[no-untyped-def] + self, + model: str, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, + temperature: float = 0.0, + max_tokens: int | None = None, + stream: bool = False, + **kwargs, + ) -> dict[str, Any] | Iterator[dict[str, Any]]: + """Synchronous inference call through TensorZero.""" + + # Map model to TensorZero function + tz_model = self._map_model_to_tensorzero(model) + + # Prepare parameters + params = { + "model": tz_model, + "messages": messages, + "temperature": temperature, + } + + if max_tokens: + params["max_tokens"] = max_tokens + + if tools: + params["tools"] = tools + + if stream: + params["stream"] = True + + # Add any extra kwargs + params.update(kwargs) + + try: + # Make the call through patched client + if stream: + # Return streaming iterator + stream_response = self._client.chat.completions.create(**params) # type: ignore[attr-defined] + + def stream_generator(): # type: ignore[no-untyped-def] + for chunk in stream_response: + yield chunk.model_dump() + + return stream_generator() # type: ignore[no-any-return, no-untyped-call] + else: + response = self._client.chat.completions.create(**params) # type: ignore[attr-defined] + return response.model_dump() # type: ignore[no-any-return] + + except Exception as e: + logger.error(f"TensorZero inference failed: {e}") + raise + + async def ainference( # type: ignore[no-untyped-def] + self, + model: str, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, + temperature: float = 0.0, + max_tokens: int | None = None, + stream: bool = False, + **kwargs, + ) -> dict[str, Any] | AsyncIterator[dict[str, Any]]: + """Async inference with streaming support.""" + import asyncio + + loop = asyncio.get_event_loop() + + if stream: + # Create async generator from sync streaming + async def stream_generator(): # type: ignore[no-untyped-def] + # Run sync streaming in executor + sync_stream = await loop.run_in_executor( + None, + lambda: self.inference( + model, messages, tools, temperature, max_tokens, stream=True, **kwargs + ), + ) + + # Convert sync iterator to async + for chunk in sync_stream: + yield chunk + + return stream_generator() # type: ignore[no-any-return, no-untyped-call] + else: + result = await loop.run_in_executor( + None, + lambda: self.inference( + model, messages, tools, temperature, max_tokens, stream, **kwargs + ), + ) + return result # type: ignore[return-value] + + def close(self) -> None: + """Close the client.""" + # TensorZero embedded doesn't need explicit cleanup + pass + + async def aclose(self) -> None: + """Async close.""" + # TensorZero embedded doesn't need explicit cleanup + pass diff --git a/dimos/agents/modules/gateway/tensorzero_simple.py b/dimos/agents/modules/gateway/tensorzero_simple.py new file mode 100644 index 0000000000..4c9dbe4e26 --- /dev/null +++ b/dimos/agents/modules/gateway/tensorzero_simple.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Minimal TensorZero test to get it working.""" + +from pathlib import Path + +from dotenv import load_dotenv +from openai import OpenAI +from tensorzero import patch_openai_client + +load_dotenv() + +# Create minimal config +config_dir = Path("/tmp/tz_test") +config_dir.mkdir(exist_ok=True) +config_path = config_dir / "tensorzero.toml" + +# Minimal config based on TensorZero docs +config = """ +[models.gpt_4o_mini] +routing = ["openai"] + +[models.gpt_4o_mini.providers.openai] +type = "openai" +model_name = "gpt-4o-mini" + +[functions.my_function] +type = "chat" + +[functions.my_function.variants.my_variant] +type = "chat_completion" +model = "gpt_4o_mini" +""" + +with open(config_path, "w") as f: + f.write(config) + +print(f"Created config at {config_path}") + +# Create OpenAI client +client = OpenAI() + +# Patch with TensorZero +try: + patch_openai_client( + client, + clickhouse_url=None, # In-memory + config_file=str(config_path), + async_setup=False, + ) + print("✅ TensorZero initialized successfully!") +except Exception as e: + print(f"❌ Failed to initialize TensorZero: {e}") + exit(1) + +# Test basic inference +print("\nTesting basic inference...") +try: + response = client.chat.completions.create( + model="tensorzero::function_name::my_function", + messages=[{"role": "user", "content": "What is 2+2?"}], + temperature=0.0, + max_tokens=10, + ) + + content = response.choices[0].message.content + print(f"Response: {content}") + print("✅ Basic inference worked!") + +except Exception as e: + print(f"❌ Basic inference failed: {e}") + import traceback + + traceback.print_exc() + +print("\nTesting streaming...") +try: + stream = client.chat.completions.create( + model="tensorzero::function_name::my_function", + messages=[{"role": "user", "content": "Count from 1 to 3"}], + temperature=0.0, + max_tokens=20, + stream=True, + ) + + print("Stream response: ", end="", flush=True) + for chunk in stream: + if chunk.choices[0].delta.content: + print(chunk.choices[0].delta.content, end="", flush=True) + print("\n✅ Streaming worked!") + +except Exception as e: + print(f"\n❌ Streaming failed: {e}") diff --git a/dimos/agents/modules/gateway/utils.py b/dimos/agents/modules/gateway/utils.py new file mode 100644 index 0000000000..526d3b9724 --- /dev/null +++ b/dimos/agents/modules/gateway/utils.py @@ -0,0 +1,156 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility functions for gateway operations.""" + +import logging +from typing import Any + +logger = logging.getLogger(__name__) + + +def convert_tools_to_standard_format(tools: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Convert DimOS tool format to standard format accepted by gateways. + + DimOS tools come from pydantic_function_tool and have this format: + { + "type": "function", + "function": { + "name": "tool_name", + "description": "tool description", + "parameters": { + "type": "object", + "properties": {...}, + "required": [...] + } + } + } + + We keep this format as it's already standard JSON Schema format. + """ + if not tools: + return [] + + # Tools are already in the correct format from pydantic_function_tool + return tools + + +def parse_streaming_response(chunk: dict[str, Any]) -> dict[str, Any]: + """Parse a streaming response chunk into a standard format. + + Args: + chunk: Raw chunk from the gateway + + Returns: + Parsed chunk with standard fields: + - type: "content" | "tool_call" | "error" | "done" + - content: The actual content (text for content type, tool info for tool_call) + - metadata: Additional information + """ + # Handle TensorZero streaming format + if "choices" in chunk: + # OpenAI-style format from TensorZero + choice = chunk["choices"][0] if chunk["choices"] else {} + delta = choice.get("delta", {}) + + if "content" in delta: + return { + "type": "content", + "content": delta["content"], + "metadata": {"index": choice.get("index", 0)}, + } + elif "tool_calls" in delta: + tool_calls = delta["tool_calls"] + if tool_calls: + tool_call = tool_calls[0] + return { + "type": "tool_call", + "content": { + "id": tool_call.get("id"), + "name": tool_call.get("function", {}).get("name"), + "arguments": tool_call.get("function", {}).get("arguments", ""), + }, + "metadata": {"index": tool_call.get("index", 0)}, + } + elif choice.get("finish_reason"): + return { + "type": "done", + "content": None, + "metadata": {"finish_reason": choice["finish_reason"]}, + } + + # Handle direct content chunks + if isinstance(chunk, str): + return {"type": "content", "content": chunk, "metadata": {}} + + # Handle error responses + if "error" in chunk: + return {"type": "error", "content": chunk["error"], "metadata": chunk} + + # Default fallback + return {"type": "unknown", "content": chunk, "metadata": {}} + + +def create_tool_response(tool_id: str, result: Any, is_error: bool = False) -> dict[str, Any]: + """Create a properly formatted tool response. + + Args: + tool_id: The ID of the tool call + result: The result from executing the tool + is_error: Whether this is an error response + + Returns: + Formatted tool response message + """ + content = str(result) if not isinstance(result, str) else result + + return { + "role": "tool", + "tool_call_id": tool_id, + "content": content, + "name": None, # Will be filled by the calling code + } + + +def extract_image_from_message(message: dict[str, Any]) -> dict[str, Any] | None: + """Extract image data from a message if present. + + Args: + message: Message dict that may contain image data + + Returns: + Dict with image data and metadata, or None if no image + """ + content = message.get("content", []) + + # Handle list content (multimodal) + if isinstance(content, list): + for item in content: + if isinstance(item, dict): + # OpenAI format + if item.get("type") == "image_url": + return { + "format": "openai", + "data": item["image_url"]["url"], + "detail": item["image_url"].get("detail", "auto"), + } + # Anthropic format + elif item.get("type") == "image": + return { + "format": "anthropic", + "data": item["source"]["data"], + "media_type": item["source"].get("media_type", "image/jpeg"), + } + + return None diff --git a/dimos/data/__init__.py b/dimos/agents/prompt_builder/__init__.py similarity index 100% rename from dimos/data/__init__.py rename to dimos/agents/prompt_builder/__init__.py diff --git a/dimos/agents/prompt_builder/impl.py b/dimos/agents/prompt_builder/impl.py new file mode 100644 index 0000000000..f3c95d3c8e --- /dev/null +++ b/dimos/agents/prompt_builder/impl.py @@ -0,0 +1,224 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from textwrap import dedent + +from dimos.agents.tokenizer.base import AbstractTokenizer +from dimos.agents.tokenizer.openai_tokenizer import OpenAITokenizer + +# TODO: Make class more generic when implementing other tokenizers. Presently its OpenAI specific. +# TODO: Build out testing and logging + + +class PromptBuilder: + DEFAULT_SYSTEM_PROMPT = dedent(""" + You are an AI assistant capable of understanding and analyzing both visual and textual information. + Your task is to provide accurate and insightful responses based on the data provided to you. + Use the following information to assist the user with their query. Do not rely on any internal + knowledge or make assumptions beyond the provided data. + + Visual Context: You may have been given an image to analyze. Use the visual details to enhance your response. + Textual Context: There may be some text retrieved from a relevant database to assist you + + Instructions: + - Combine insights from both the image and the text to answer the user's question. + - If the information is insufficient to provide a complete answer, acknowledge the limitation. + - Maintain a professional and informative tone in your response. + """) + + def __init__( + self, + model_name: str = "gpt-4o", + max_tokens: int = 128000, + tokenizer: AbstractTokenizer | None = None, + ) -> None: + """ + Initialize the prompt builder. + Args: + model_name (str): Model used (e.g., 'gpt-4o', 'gpt-4', 'gpt-3.5-turbo'). + max_tokens (int): Maximum tokens allowed in the input prompt. + tokenizer (AbstractTokenizer): The tokenizer to use for token counting and truncation. + """ + self.model_name = model_name + self.max_tokens = max_tokens + self.tokenizer: AbstractTokenizer = tokenizer or OpenAITokenizer(model_name=self.model_name) + + def truncate_tokens(self, text: str, max_tokens, strategy): # type: ignore[no-untyped-def] + """ + Truncate text to fit within max_tokens using a specified strategy. + Args: + text (str): Input text to truncate. + max_tokens (int): Maximum tokens allowed. + strategy (str): Truncation strategy ('truncate_head', 'truncate_middle', 'truncate_end', 'do_not_truncate'). + Returns: + str: Truncated text. + """ + if strategy == "do_not_truncate" or not text: + return text + + tokens = self.tokenizer.tokenize_text(text) + if len(tokens) <= max_tokens: + return text + + if strategy == "truncate_head": + truncated = tokens[-max_tokens:] + elif strategy == "truncate_end": + truncated = tokens[:max_tokens] + elif strategy == "truncate_middle": + half = max_tokens // 2 + truncated = tokens[:half] + tokens[-half:] + else: + raise ValueError(f"Unknown truncation strategy: {strategy}") + + return self.tokenizer.detokenize_text(truncated) # type: ignore[no-untyped-call] + + def build( # type: ignore[no-untyped-def] + self, + system_prompt=None, + user_query=None, + base64_image=None, + image_width=None, + image_height=None, + image_detail: str = "low", + rag_context=None, + budgets=None, + policies=None, + override_token_limit: bool = False, + ): + """ + Builds a dynamic prompt tailored to token limits, respecting budgets and policies. + + Args: + system_prompt (str): System-level instructions. + user_query (str, optional): User's query. + base64_image (str, optional): Base64-encoded image string. + image_width (int, optional): Width of the image. + image_height (int, optional): Height of the image. + image_detail (str, optional): Detail level for the image ("low" or "high"). + rag_context (str, optional): Retrieved context. + budgets (dict, optional): Token budgets for each input type. Defaults to equal allocation. + policies (dict, optional): Truncation policies for each input type. + override_token_limit (bool, optional): Whether to override the token limit. Defaults to False. + + Returns: + dict: Messages array ready to send to the OpenAI API. + """ + if user_query is None: + raise ValueError("User query is required.") + + # Debug: + # base64_image = None + + budgets = budgets or { + "system_prompt": self.max_tokens // 4, + "user_query": self.max_tokens // 4, + "image": self.max_tokens // 4, + "rag": self.max_tokens // 4, + } + policies = policies or { + "system_prompt": "truncate_end", + "user_query": "truncate_middle", + "image": "do_not_truncate", + "rag": "truncate_end", + } + + # Validate and sanitize image_detail + if image_detail not in {"low", "high"}: + image_detail = "low" # Default to "low" if invalid or None + + # Determine which system prompt to use + if system_prompt is None: + system_prompt = self.DEFAULT_SYSTEM_PROMPT + + rag_context = rag_context or "" + + # Debug: + # print("system_prompt: ", system_prompt) + # print("rag_context: ", rag_context) + + # region Token Counts + if not override_token_limit: + rag_token_cnt = self.tokenizer.token_count(rag_context) + system_prompt_token_cnt = self.tokenizer.token_count(system_prompt) + user_query_token_cnt = self.tokenizer.token_count(user_query) + image_token_cnt = ( + self.tokenizer.image_token_count(image_width, image_height, image_detail) + if base64_image + else 0 + ) + else: + rag_token_cnt = 0 + system_prompt_token_cnt = 0 + user_query_token_cnt = 0 + image_token_cnt = 0 + # endregion Token Counts + + # Create a component dictionary for dynamic allocation + components = { + "system_prompt": {"text": system_prompt, "tokens": system_prompt_token_cnt}, + "user_query": {"text": user_query, "tokens": user_query_token_cnt}, + "image": {"text": None, "tokens": image_token_cnt}, + "rag": {"text": rag_context, "tokens": rag_token_cnt}, + } + + if not override_token_limit: + # Adjust budgets and apply truncation + total_tokens = sum(comp["tokens"] for comp in components.values()) + excess_tokens = total_tokens - self.max_tokens + if excess_tokens > 0: + for key, component in components.items(): + if excess_tokens <= 0: + break + if policies[key] != "do_not_truncate": + max_allowed = max(0, budgets[key] - excess_tokens) + components[key]["text"] = self.truncate_tokens( + component["text"], max_allowed, policies[key] + ) + tokens_after = self.tokenizer.token_count(components[key]["text"]) + excess_tokens -= component["tokens"] - tokens_after + component["tokens"] = tokens_after + + # Build the `messages` structure (OpenAI specific) + messages = [{"role": "system", "content": components["system_prompt"]["text"]}] + + if components["rag"]["text"]: + user_content = [ + { + "type": "text", + "text": f"{components['rag']['text']}\n\n{components['user_query']['text']}", + } + ] + else: + user_content = [{"type": "text", "text": components["user_query"]["text"]}] + + if base64_image: + user_content.append( + { + "type": "image_url", + "image_url": { # type: ignore[dict-item] + "url": f"data:image/jpeg;base64,{base64_image}", + "detail": image_detail, + }, + } + ) + messages.append({"role": "user", "content": user_content}) + + # Debug: + # print("system_prompt: ", system_prompt) + # print("user_query: ", user_query) + # print("user_content: ", user_content) + # print(f"Messages: {messages}") + + return messages diff --git a/dimos/types/__init__.py b/dimos/agents/tokenizer/__init__.py similarity index 100% rename from dimos/types/__init__.py rename to dimos/agents/tokenizer/__init__.py diff --git a/dimos/agents/tokenizer/base.py b/dimos/agents/tokenizer/base.py new file mode 100644 index 0000000000..97535bcfaa --- /dev/null +++ b/dimos/agents/tokenizer/base.py @@ -0,0 +1,37 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod + +# TODO: Add a class for specific tokenizer exceptions +# TODO: Build out testing and logging +# TODO: Create proper doc strings after multiple tokenizers are implemented + + +class AbstractTokenizer(ABC): + @abstractmethod + def tokenize_text(self, text: str): # type: ignore[no-untyped-def] + pass + + @abstractmethod + def detokenize_text(self, tokenized_text): # type: ignore[no-untyped-def] + pass + + @abstractmethod + def token_count(self, text: str): # type: ignore[no-untyped-def] + pass + + @abstractmethod + def image_token_count(self, image_width, image_height, image_detail: str = "low"): # type: ignore[no-untyped-def] + pass diff --git a/dimos/agents/tokenizer/huggingface_tokenizer.py b/dimos/agents/tokenizer/huggingface_tokenizer.py new file mode 100644 index 0000000000..1e297000aa --- /dev/null +++ b/dimos/agents/tokenizer/huggingface_tokenizer.py @@ -0,0 +1,89 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from transformers import AutoTokenizer # type: ignore[import-untyped] + +from dimos.agents.tokenizer.base import AbstractTokenizer +from dimos.utils.logging_config import setup_logger + + +class HuggingFaceTokenizer(AbstractTokenizer): + def __init__(self, model_name: str = "Qwen/Qwen2.5-0.5B", **kwargs) -> None: # type: ignore[no-untyped-def] + super().__init__(**kwargs) + + # Initilize the tokenizer for the huggingface models + self.model_name = model_name + try: + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) + except Exception as e: + raise ValueError( + f"Failed to initialize tokenizer for model {self.model_name}. Error: {e!s}" + ) + + def tokenize_text(self, text: str): # type: ignore[no-untyped-def] + """ + Tokenize a text string using the openai tokenizer. + """ + return self.tokenizer.encode(text) + + def detokenize_text(self, tokenized_text): # type: ignore[no-untyped-def] + """ + Detokenize a text string using the openai tokenizer. + """ + try: + return self.tokenizer.decode(tokenized_text, errors="ignore") + except Exception as e: + raise ValueError(f"Failed to detokenize text. Error: {e!s}") + + def token_count(self, text: str): # type: ignore[no-untyped-def] + """ + Gets the token count of a text string using the openai tokenizer. + """ + return len(self.tokenize_text(text)) if text else 0 + + @staticmethod + def image_token_count(image_width, image_height, image_detail: str = "high"): # type: ignore[no-untyped-def] + """ + Calculate the number of tokens in an image. Low detail is 85 tokens, high detail is 170 tokens per 512x512 square. + """ + logger = setup_logger() + + if image_detail == "low": + return 85 + elif image_detail == "high": + # Image dimensions + logger.debug(f"Image Width: {image_width}, Image Height: {image_height}") + if image_width is None or image_height is None: + raise ValueError( + "Image width and height must be provided for high detail image token count calculation." + ) + + # Scale image to fit within 2048 x 2048 + max_dimension = max(image_width, image_height) + if max_dimension > 2048: + scale_factor = 2048 / max_dimension + image_width = int(image_width * scale_factor) + image_height = int(image_height * scale_factor) + + # Scale shortest side to 768px + min_dimension = min(image_width, image_height) + scale_factor = 768 / min_dimension + image_width = int(image_width * scale_factor) + image_height = int(image_height * scale_factor) + + # Calculate number of 512px squares + num_squares = (image_width // 512) * (image_height // 512) + return 170 * num_squares + 85 + else: + raise ValueError("Detail specification of image is not 'low' or 'high'") diff --git a/dimos/agents/tokenizer/openai_tokenizer.py b/dimos/agents/tokenizer/openai_tokenizer.py new file mode 100644 index 0000000000..7bf3e3785b --- /dev/null +++ b/dimos/agents/tokenizer/openai_tokenizer.py @@ -0,0 +1,89 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tiktoken + +from dimos.agents.tokenizer.base import AbstractTokenizer +from dimos.utils.logging_config import setup_logger + + +class OpenAITokenizer(AbstractTokenizer): + def __init__(self, model_name: str = "gpt-4o", **kwargs) -> None: # type: ignore[no-untyped-def] + super().__init__(**kwargs) + + # Initilize the tokenizer for the openai set of models + self.model_name = model_name + try: + self.tokenizer = tiktoken.encoding_for_model(self.model_name) + except Exception as e: + raise ValueError( + f"Failed to initialize tokenizer for model {self.model_name}. Error: {e!s}" + ) + + def tokenize_text(self, text: str): # type: ignore[no-untyped-def] + """ + Tokenize a text string using the openai tokenizer. + """ + return self.tokenizer.encode(text) + + def detokenize_text(self, tokenized_text): # type: ignore[no-untyped-def] + """ + Detokenize a text string using the openai tokenizer. + """ + try: + return self.tokenizer.decode(tokenized_text, errors="ignore") + except Exception as e: + raise ValueError(f"Failed to detokenize text. Error: {e!s}") + + def token_count(self, text: str): # type: ignore[no-untyped-def] + """ + Gets the token count of a text string using the openai tokenizer. + """ + return len(self.tokenize_text(text)) if text else 0 + + @staticmethod + def image_token_count(image_width, image_height, image_detail: str = "high"): # type: ignore[no-untyped-def] + """ + Calculate the number of tokens in an image. Low detail is 85 tokens, high detail is 170 tokens per 512x512 square. + """ + logger = setup_logger() + + if image_detail == "low": + return 85 + elif image_detail == "high": + # Image dimensions + logger.debug(f"Image Width: {image_width}, Image Height: {image_height}") + if image_width is None or image_height is None: + raise ValueError( + "Image width and height must be provided for high detail image token count calculation." + ) + + # Scale image to fit within 2048 x 2048 + max_dimension = max(image_width, image_height) + if max_dimension > 2048: + scale_factor = 2048 / max_dimension + image_width = int(image_width * scale_factor) + image_height = int(image_height * scale_factor) + + # Scale shortest side to 768px + min_dimension = min(image_width, image_height) + scale_factor = 768 / min_dimension + image_width = int(image_width * scale_factor) + image_height = int(image_height * scale_factor) + + # Calculate number of 512px squares + num_squares = (image_width // 512) * (image_height // 512) + return 170 * num_squares + 85 + else: + raise ValueError("Detail specification of image is not 'low' or 'high'") diff --git a/dimos/agents2/__init__.py b/dimos/agents2/__init__.py new file mode 100644 index 0000000000..c817bb3aee --- /dev/null +++ b/dimos/agents2/__init__.py @@ -0,0 +1,13 @@ +from langchain_core.messages import ( + AIMessage, + HumanMessage, + MessageLikeRepresentation, + SystemMessage, + ToolCall, + ToolMessage, +) + +from dimos.agents2.agent import Agent, deploy +from dimos.agents2.spec import AgentSpec +from dimos.protocol.skill.skill import skill +from dimos.protocol.skill.type import Output, Reducer, Stream diff --git a/dimos/agents2/agent.py b/dimos/agents2/agent.py new file mode 100644 index 0000000000..8bf8e9625b --- /dev/null +++ b/dimos/agents2/agent.py @@ -0,0 +1,983 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""LLM-based agent orchestration bridging reasoning with robot skill execution. + +This module implements DimOS's neurosymbolic agent architecture: LLM-based agents +that invoke robot skills through a structured tool-calling protocol. + +Core Classes +------------ +Agent + Base agent class requiring explicit loop control via `query()` or `agent_loop()`. + Integrates with `SkillCoordinator` to execute long-running skills asynchronously. + +LlmAgent + Agent variant that auto-starts its processing loop on `start()`. Useful for + blueprint composition with `autoconnect()` and `ModuleCoordinator`. + +Exports +------- +The module's `__all__` includes: + +- `Agent`: For explicit loop control +- `llm_agent`: Blueprint factory (`LlmAgent.blueprint`) for composition +- `deploy`: Convenience helper for standalone agent deployment + +Internal utilities (not exported): + +- `SkillStateSummary`: TypedDict for skill state snapshots in LLM messages +- `snapshot_to_messages`: Transform skill state to LangChain message protocol + +Architecture +------------ +Agents coordinate with a `SkillCoordinator` to discover skills, bind them as LLM +tools, and execute them asynchronously with streaming updates. +The event-driven loop alternates between LLM invocations and skill execution, +with state changes triggering agent calls. + +Testing +------- +For deterministic testing without LLM API calls, use `MockModel` from +`dimos.agents2.testing` to inject predetermined responses: + +>>> from dimos.agents2.testing import MockModel +>>> from langchain_core.messages import AIMessage +>>> mock = MockModel(responses=[AIMessage(content="Test response")]) +>>> agent = Agent(system_prompt="Test", model_instance=mock) + +See also +-------- +dimos.agents2.spec : AgentSpec base class defining the agent interface +dimos.protocol.skill.coordinator : SkillCoordinator for managing skill lifecycle +dimos.agents2.cli.human : HumanInput module for CLI-based agent interaction +dimos.agents2.testing : Testing utilities including MockModel +dimos.utils.cli.agentspy : CLI tool for real-time monitoring of agent messages +""" + +import asyncio +import datetime +import json +from operator import itemgetter +import os +from typing import Annotated, Any, TypedDict +import uuid + +from annotated_doc import Doc +from langchain.chat_models import init_chat_model +from langchain_core.messages import ( + AIMessage, + HumanMessage, + SystemMessage, + ToolCall, + ToolMessage, +) +from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline + +from dimos.agents2.ollama_agent import ensure_ollama_model +from dimos.agents2.spec import AgentSpec, Model, Provider +from dimos.agents2.system_prompt import get_system_prompt +from dimos.core import DimosCluster, rpc +from dimos.protocol.skill.coordinator import ( + SkillCoordinator, + SkillState, + SkillStateDict, +) +from dimos.protocol.skill.skill import SkillContainer +from dimos.protocol.skill.type import Output +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +SYSTEM_MSG_APPEND = "\nYour message history will always be appended with a System Overview message that provides situational awareness." + + +def toolmsg_from_state(state: SkillState) -> ToolMessage: + if state.skill_config.output != Output.standard: + content = "output attached in separate messages" + else: + content = state.content() # type: ignore[assignment] + + return ToolMessage( + # if agent call has been triggered by another skill, + # and this specific skill didn't finish yet but we need a tool call response + # we return a message explaining that execution is still ongoing + content=content + or "Running, you will be called with an update, no need for subsequent tool calls", + name=state.name, + tool_call_id=state.call_id, + ) + + +class SkillStateSummary(TypedDict): + """Lightweight snapshot of skill execution state for LLM situational awareness. + + JSON-serializable representation of SkillState included in state overview messages + sent to LLM agents. Informs agents about skills running but not yet acknowledged + via ToolMessage, enabling tracking of ongoing operations. + + Typically created internally by the agent system. Users rarely construct these directly. + + Examples: + Creating a summary dictionary directly: + + >>> summary: SkillStateSummary = { + ... "name": "navigate_to", + ... "call_id": "abc-123", + ... "state": "running", + ... "data": "Moving to target location" + ... } + >>> print(summary["name"]) + navigate_to + >>> print(summary["state"]) + running + + Multiple summaries in a state overview message: + + >>> import json + >>> from langchain_core.messages import AIMessage + >>> + >>> summaries: list[SkillStateSummary] = [ + ... {"name": "scan_room", "call_id": "uuid1", "state": "running", "data": "Scanning..."}, + ... {"name": "navigate_to", "call_id": "uuid2", "state": "completed", "data": "Arrived"} + ... ] + >>> overview = "\\n".join(json.dumps(s) for s in summaries) + >>> msg = AIMessage(content=f"State Overview:\\n{overview}") + """ + + name: Annotated[ + str, Doc("The skill's registered identifier (e.g., 'navigate_to', 'scan_room').") + ] + call_id: Annotated[str, Doc("Unique identifier string for this specific skill invocation.")] + state: Annotated[str, Doc("Execution state: 'pending', 'running', 'completed', or 'error'.")] + data: Annotated[ + Any, + Doc( + """Skill output content or placeholder message. For standard output modes, + contains result of SkillState.content(). For Output.image, contains + literal string 'data will be in a separate message'.""" + ), + ] + + +def summary_from_state(state: SkillState, special_data: bool = False) -> SkillStateSummary: + content = state.content() + if isinstance(content, dict): + content = json.dumps(content) + + if not isinstance(content, str): + content = str(content) + + return { + "name": state.name, + "call_id": state.call_id, + "state": state.state.name, + "data": state.content() if not special_data else "data will be in a separate message", + } + + +def _custom_json_serializers(obj): # type: ignore[no-untyped-def] + if isinstance(obj, datetime.date | datetime.datetime): + return obj.isoformat() + raise TypeError(f"Type {type(obj)} not serializable") + + +def snapshot_to_messages( + state: Annotated[ + SkillStateDict, + Doc( + """Snapshot from SkillCoordinator.generate_snapshot() mapping call_id to + SkillState objects with execution state, outputs, and configuration.""" + ), + ], + tool_calls: Annotated[ + list[ToolCall], + Doc( + """Tool calls from the previous agent message, used to match skills requiring + ToolMessage responses per LangChain's protocol.""" + ), + ], +) -> Annotated[ + dict, + Doc( + """Dictionary with three keys mapping to message lists: + 'tool_msgs' (list[ToolMessage]): Tool responses for skills matching tool_calls; + 'history_msgs' (list[HumanMessage]): Persistent messages from Output.human skills; + 'state_msgs' (list[AIMessage | HumanMessage]): Transient state awareness messages.""" + ), +]: + """Transform skill execution snapshot into LangChain messages for agent loop. + + Internal function called by Agent.agent_loop() at two points during execution. + Implements a three-tier message routing protocol separating tool responses + (satisfying LangChain's tool calling protocol) from state awareness messages + (tracking long-running skills) and persistent history (human input, critical events). + + Notes: + This is an internal transformation layer. Users should not call this directly. + + Skills are processed sorted by duration (shortest first). Routing rules by output type: + + - Output.standard: Tool response if matching call_id, else state overview + - Output.human: Always routes to history_msgs, bypassing tool protocol + - Output.image: Tool response with placeholder plus separate HumanMessage + + tool_msgs and history_msgs persist in conversation; state_msgs are transient + and replaced on each update cycle. + + See also: + Agent.agent_loop: Primary caller during state change detection and message generation. + toolmsg_from_state: Helper creating tool response messages. + summary_from_state: Helper creating state overview summaries. + """ + # builds a set of tool call ids from a previous agent request + tool_call_ids = set( + map(itemgetter("id"), tool_calls), + ) + + # build a tool msg responses + tool_msgs: list[ToolMessage] = [] + + # build a general skill state overview (for longer running skills) + state_overview: list[dict[str, SkillStateSummary]] = [] + + # for special skills that want to return a separate message + # (images for example, requires to be a HumanMessage) + special_msgs: list[HumanMessage] = [] + + # for special skills that want to return a separate message that should + # stay in history, like actual human messages, critical events + history_msgs: list[HumanMessage] = [] + + # Initialize state_msg + state_msg = None + + for skill_state in sorted( + state.values(), + key=lambda skill_state: skill_state.duration(), + ): + if skill_state.call_id in tool_call_ids: + tool_msgs.append(toolmsg_from_state(skill_state)) + + if skill_state.skill_config.output == Output.human: + content = skill_state.content() + if not content: + continue + history_msgs.append(HumanMessage(content=content)) # type: ignore[arg-type] + continue + + special_data = skill_state.skill_config.output == Output.image + if special_data: + content = skill_state.content() + if not content: + continue + special_msgs.append(HumanMessage(content=content)) # type: ignore[arg-type] + + if skill_state.call_id in tool_call_ids: + continue + + state_overview.append(summary_from_state(skill_state, special_data)) # type: ignore[arg-type] + + if state_overview: + state_overview_str = "\n".join( + json.dumps(s, default=_custom_json_serializers) for s in state_overview + ) + state_msg = AIMessage("State Overview:\n" + state_overview_str) + + return { # type: ignore[return-value] + "tool_msgs": tool_msgs, + "history_msgs": history_msgs, + "state_msgs": ([state_msg] if state_msg else []) + special_msgs, + } + + +class Agent(AgentSpec): + """Neurosymbolic orchestrator bridging LLM reasoning with robot skill execution. + + Implements an event-driven agent loop that alternates between LLM invocations + and skill execution. Maintains conversation history, coordinates skill lifecycle + through a `SkillCoordinator`, and transforms skill state updates into LangChain + messages for continued reasoning. + + Lifecycle: INITIALIZED → STARTED (after `start()`) → RUNNING (during `agent_loop()`) → back to STARTED (loop completes) → STOPPED (after `stop()`). + + Agent vs. LlmAgent: + Use `Agent` when you need explicit control over when the processing loop starts + (typically via `query()` calls). Use `LlmAgent` when you want the agent to + auto-start its loop on `start()`, which is essential for blueprint composition + with `autoconnect()` and `ModuleCoordinator`. + + Attributes: + coordinator (SkillCoordinator): + Manages skill registration, execution, and state tracking. + system_message (SystemMessage): + Initial system prompt appended with state overview notice. + state_messages (list[AIMessage | HumanMessage]): + Transient messages for current skill state; replaced each update cycle. + + Notes: + The agent loop terminates when `coordinator.has_active_skills()` returns False. + Skills with `Return.none`, `Return.passive`, `Stream.none`, or `Stream.passive` + don't prevent termination. + + For testing, use `MockModel` from `dimos.agents2.testing` to inject + deterministic responses without requiring real LLM API calls. + + See also: + LlmAgent: Auto-starts loop on `start()` for blueprint composition. + AgentSpec: Base class defining agent interface. + SkillCoordinator: Skill lifecycle manager. + query: Synchronous blocking interface for agent queries. + agent_loop: Core async processing loop. + + Examples: + >>> from dimos.agents2.agent import Agent + >>> from dimos.agents2.testing import MockModel + >>> from langchain_core.messages import AIMessage + >>> mock = MockModel(responses=[AIMessage(content="The answer is 42")]) + >>> agent = Agent(system_prompt="You are a helpful assistant.", model_instance=mock) + >>> agent.start() + >>> result = agent.query("What is the meaning of life?") + >>> result + 'The answer is 42' + >>> agent.stop() + """ + + system_message: SystemMessage + state_messages: list[AIMessage | HumanMessage] + + def __init__( # type: ignore[no-untyped-def] + self, + *args, + **kwargs, + ) -> None: + AgentSpec.__init__(self, *args, **kwargs) + + self.state_messages = [] + self.coordinator = SkillCoordinator() + self._history = [] # type: ignore[var-annotated] + self._agent_id = str(uuid.uuid4()) + self._agent_stopped = False + + if self.config.system_prompt: + if isinstance(self.config.system_prompt, str): + self.system_message = SystemMessage(self.config.system_prompt + SYSTEM_MSG_APPEND) + else: + self.config.system_prompt.content += SYSTEM_MSG_APPEND # type: ignore[operator] + self.system_message = self.config.system_prompt + else: + self.system_message = SystemMessage(get_system_prompt() + SYSTEM_MSG_APPEND) + + self.publish(self.system_message) + + # Use provided model instance if available, otherwise initialize from config + if self.config.model_instance: + self._llm = self.config.model_instance + else: + # For Ollama provider, ensure the model is available before initializing + if self.config.provider.value.lower() == "ollama": + ensure_ollama_model(self.config.model) + + # For HuggingFace, we need to create a pipeline and wrap it in ChatHuggingFace + if self.config.provider.value.lower() == "huggingface": + llm = HuggingFacePipeline.from_model_id( + model_id=self.config.model, + task="text-generation", + pipeline_kwargs={ + "max_new_tokens": 512, + "temperature": 0.7, + }, + ) + self._llm = ChatHuggingFace(llm=llm, model_id=self.config.model) + else: + self._llm = init_chat_model( # type: ignore[call-overload] + model_provider=self.config.provider, model=self.config.model + ) + + @rpc + def get_agent_id(self) -> str: + return self._agent_id + + @rpc + def start(self) -> None: + super().start() + self.coordinator.start() + + @rpc + def stop(self) -> None: + self.coordinator.stop() + self._agent_stopped = True + super().stop() + + def clear_history(self) -> None: + self._history.clear() + + def append_history(self, *msgs: AIMessage | HumanMessage) -> None: + for msg in msgs: + self.publish(msg) # type: ignore[arg-type] + + self._history.extend(msgs) + + def history(self): # type: ignore[no-untyped-def] + return [self.system_message, *self._history, *self.state_messages] + + # Used by agent to execute tool calls + def execute_tool_calls(self, tool_calls: list[ToolCall]) -> None: + """Execute a list of tool calls from the agent.""" + if self._agent_stopped: + logger.warning("Agent is stopped, cannot execute tool calls.") + return + for tool_call in tool_calls: + logger.info(f"executing skill call {tool_call}") + self.coordinator.call_skill( + tool_call.get("id"), # type: ignore[arg-type] + tool_call.get("name"), # type: ignore[arg-type] + tool_call.get("args"), # type: ignore[arg-type] + ) + + # used to inject skill calls into the agent loop without agent asking for it + def run_implicit_skill( + self, + skill_name: Annotated[ + str, + Doc( + """Name of the registered skill to invoke. Must match a skill in the + coordinator's registry.""" + ), + ], + **kwargs, + ) -> None: + """Inject skill invocation without agent awareness or decision-making. + + Programmatic skill execution that bypasses normal agent reasoning. Primary use + is bootstrapping agent sessions with initial skills like HumanInput, which + must run before the agent can begin processing queries. + + Differences from execute_tool_calls(): + - **Trigger source**: Programmatic/external vs. agent-initiated via LLM + - **Call ID**: Always uses `False` (auto-generated) vs. LLM-provided ID + - **Visibility**: Implicit to agent vs. tracked in conversation history + - **Use cases**: Bootstrap/events/background vs. deliberate agent tool use + + Examples: + >>> from dimos.agents2.agent import Agent + >>> from dimos.agents2.testing import MockModel + >>> from langchain_core.messages import AIMessage + >>> mock = MockModel(responses=[AIMessage(content="Ready")]) + >>> agent = Agent(system_prompt="Test assistant", model_instance=mock) + >>> agent.start() + >>> # In practice, register a SkillContainer first, then run its skill: + >>> # agent.register_skills(my_skill_container) + >>> # agent.run_implicit_skill("my_skill", param="value") + >>> agent.stop() + + Notes: + - Uses `call_id=False` to trigger auto-generation by the coordinator + - Skills execute asynchronously through the coordinator + - Silently returns with warning if agent is stopped + + See also: + execute_tool_calls: Handle agent-initiated tool calls with LLM-provided IDs. + Agent.agent_loop: Main processing loop that handles skill responses. + """ + if self._agent_stopped: + logger.warning("Agent is stopped, cannot execute implicit skill calls.") + return + self.coordinator.call_skill(False, skill_name, {"args": kwargs}) + + async def agent_loop( + self, + first_query: Annotated[ + str, + Doc( + """Initial human query to process. If provided, appended to conversation history + as a HumanMessage before the loop begins. Defaults to empty string, enabling + event-driven mode where the agent waits for skills or external events.""" + ), + ] = "", + ) -> Annotated[ + str | None, + Doc( + """The content of the final AIMessage from the LLM on normal termination. + Returns literal string 'Agent is stopped.' if called when agent is stopped. + Returns None if an exception occurs during execution.""" + ), + ]: + """Run the agent's core processing loop until all skills complete. + + Implements an event-driven execution cycle that alternates between LLM reasoning + and skill execution. Each iteration: (1) binds current tools, (2) invokes the LLM + with conversation history plus state overview, (3) executes any tool calls, + (4) waits for skill updates via async event notification, and (5) transforms + skill results into structured messages for the next iteration. + + Examples: + Most users should use `query()` instead, which has an executable doctest. + (xdoctest 1.3.0 doesn't detect doctests in ``async def`` methods.) + + >>> response = await agent.agent_loop("What is 2+2?") # doctest: +SKIP + + Notes: + **When to use vs. alternatives**: + - Use `agent_loop()` directly: In async contexts where you need explicit control + - Use `query()`: For synchronous/blocking calls from non-async code + - Use `loop_thread()`: For fire-and-forget background processing + + **Termination conditions**: + The loop terminates when `coordinator.has_active_skills()` returns False. + Active skills are those configured with `Return.call_agent` or `Stream.call_agent`. + Skills with `Return.none` or `Return.passive` don't prevent termination. + + See also: + query: Synchronous wrapper for agent_loop, blocks until completion. + loop_thread: Fire-and-forget variant that schedules agent_loop in background. + coordinator.has_active_skills: Method determining loop termination. + """ + # TODO: Should I add a lock here to prevent concurrent calls to agent_loop? + + if self._agent_stopped: + logger.warning("Agent is stopped, cannot run agent loop.") + # return "Agent is stopped." + import traceback + + traceback.print_stack() + return "Agent is stopped." + + self.state_messages = [] + if first_query: + self.append_history(HumanMessage(first_query)) + + def _get_state() -> str: + # TODO: FIX THIS EXTREME HACK + update = self.coordinator.generate_snapshot(clear=False) + snapshot_msgs = snapshot_to_messages(update, msg.tool_calls) # type: ignore[attr-defined] + return json.dumps(snapshot_msgs, sort_keys=True, default=lambda o: repr(o)) + + try: + while True: + # we are getting tools from the coordinator on each turn + # since this allows for skillcontainers to dynamically provide new skills + tools = self.get_tools() # type: ignore[no-untyped-call] + self._llm = self._llm.bind_tools(tools) # type: ignore[assignment] + + # publish to /agent topic for observability + for state_msg in self.state_messages: + self.publish(state_msg) + + # history() builds our message history dynamically + # ensures we include latest system state, but not old ones. + messages = self.history() # type: ignore[no-untyped-call] + + # Some LLMs don't work without any human messages. Add an initial one. + if len(messages) == 1 and isinstance(messages[0], SystemMessage): + messages.append( + HumanMessage( + "Everything is initialized. I'll let you know when you should act." + ) + ) + self.append_history(messages[-1]) + + msg = self._llm.invoke(messages) + + self.append_history(msg) + + logger.info(f"Agent response: {msg.content}") + + state = _get_state() + + if msg.tool_calls: # type: ignore[attr-defined] + self.execute_tool_calls(msg.tool_calls) # type: ignore[attr-defined] + + # print(self) + # print(self.coordinator) + + self._write_debug_history_file() + + if not self.coordinator.has_active_skills(): + logger.info("No active tasks, exiting agent loop.") + return msg.content + + # coordinator will continue once a skill state has changed in + # such a way that agent call needs to be executed + + if state == _get_state(): + await self.coordinator.wait_for_updates() + + # we request a full snapshot of currently running, finished or errored out skills + # we ask for removal of finished skills from subsequent snapshots (clear=True) + update = self.coordinator.generate_snapshot(clear=True) + + # generate tool_msgs and general state update message, + # depending on a skill having associated tool call from previous interaction + # we will return a tool message, and not a general state message + snapshot_msgs = snapshot_to_messages(update, msg.tool_calls) # type: ignore[attr-defined] + + self.state_messages = snapshot_msgs.get("state_msgs", []) # type: ignore[attr-defined] + self.append_history( + *snapshot_msgs.get("tool_msgs", []), # type: ignore[attr-defined] + *snapshot_msgs.get("history_msgs", []), # type: ignore[attr-defined] + ) + + except Exception as e: + logger.error(f"Error in agent loop: {e}") + import traceback + + traceback.print_exc() + + @rpc + def loop_thread( + self, + ) -> Annotated[ + bool, + Doc( + """Always returns True, indicating the loop was scheduled (not that it started + executing or will complete successfully).""" + ), + ]: + """Start the agent's autonomous execution loop in the background. + + Fire-and-forget method that schedules the agent loop without blocking. + Returns immediately while the agent continues processing asynchronously + in its event loop thread. Unlike query(), this never waits for completion. + + Examples: + Basic fire-and-forget usage: + + >>> from dimos.agents2 import Agent + >>> agent = Agent(system_prompt="You are a helpful assistant.") + >>> agent.start() + >>> agent.loop_thread() # Returns immediately + True + >>> agent.stop() + >>> # Agent was processing autonomously in background until stopped + + Notes: + - Agent loop executes with empty initial query (`agent_loop("")`) + - Multiple calls create concurrent loops (may cause race conditions) + - Called automatically by `LlmAgent.start()` for auto-start behavior + + **When to use vs. alternatives**: + - Use `loop_thread()`: Fire-and-forget background processing + - Use `query()`: Blocking call that waits for agent response + - Use `query_async()`: Async contexts requiring await + + See also: + query: Blocking method that waits for agent response. + query_async: Async version for use in async contexts. + LlmAgent.start: Automatically calls loop_thread() on startup. + """ + asyncio.run_coroutine_threadsafe(self.agent_loop(), self._loop) # type: ignore[arg-type] + return True + + @rpc + def query( + self, + query: Annotated[str, Doc("The user query to process.")], + ) -> Annotated[ + str | None, + Doc( + """The agent's response (final AIMessage content). + Returns 'Agent is stopped.' if agent was stopped. + Returns None on error (exception logged, not propagated).""" + ), + ]: + """Send a query to the agent and block until response. + + Synchronous wrapper around `agent_loop()` for use in non-async code. + Blocks the calling thread until completion. Can be called from any + thread, including RPC handlers. + + Notes: + Uses `asyncio.run_coroutine_threadsafe()` to safely schedule execution + on the agent's event loop from any calling thread. + + Examples: + >>> from dimos.agents2.agent import Agent + >>> from dimos.agents2.testing import MockModel + >>> from langchain_core.messages import AIMessage + >>> mock = MockModel(responses=[AIMessage(content="The answer is 4")]) + >>> agent = Agent(system_prompt="Math assistant", model_instance=mock) + >>> agent.start() + >>> result = agent.query("What is 2+2?") + >>> result + 'The answer is 4' + >>> agent.stop() + + See also: + query_async: Async version for use in async contexts. + agent_loop: The underlying async processing loop. + loop_thread: Fire-and-forget variant for background processing. + """ + # TODO: could this be + # from distributed.utils import sync + # return sync(self._loop, self.agent_loop, query) + return asyncio.run_coroutine_threadsafe(self.agent_loop(query), self._loop).result() # type: ignore[arg-type] + + async def query_async( + self, + query: Annotated[str, Doc("The user query to process.")], + ) -> Annotated[ + str | None, + Doc( + """The agent's response (final AIMessage content). + Returns 'Agent is stopped.' if agent was stopped. + Returns None on error (exception logged, not propagated).""" + ), + ]: + """Send a query to the agent and await the response. + + Async wrapper around `agent_loop()` for use in async contexts. + Directly awaits the agent loop in the **caller's** event loop. + + Notes: + The caller and agent should typically share the same loop; cross-loop awaiting + can cause issues with skill coordination. + Not RPC-decorated (use `query()` for RPC). + + Examples: + >>> import asyncio + >>> from dimos.agents2.agent import Agent + >>> from dimos.agents2.testing import MockModel + >>> from langchain_core.messages import AIMessage + >>> async def test_query_async(): + ... mock = MockModel(responses=[AIMessage(content="The answer is 4")]) + ... agent = Agent(system_prompt="Math assistant", model_instance=mock) + ... agent.start() + ... result = await agent.query_async("What is 2+2?") + ... agent.stop() + ... return result + >>> asyncio.run(test_query_async()) + 'The answer is 4' + + See also: + query: Synchronous/blocking version for non-async contexts. + agent_loop: The underlying async processing loop. + loop_thread: Fire-and-forget variant for background processing. + """ + return await self.agent_loop(query) + + @rpc + def register_skills( + self, + container: Annotated[ + SkillContainer, Doc("Skill container instance to register with the agent.") + ], + run_implicit_name: Annotated[ + str | None, + Doc( + """Optional skill name to run implicitly after registration. + Commonly used to auto-start streaming skills like HumanInput.""" + ), + ] = None, + ): + """Register a skill container with the agent's coordinator. + + Makes all @skill decorated methods from the container available to the agent + for LLM tool calling. Optionally runs a specified skill implicitly after registration. + + Examples: + Basic registration: + + >>> from dimos.agents2.agent import Agent + >>> from dimos.agents2.testing import MockModel + >>> from langchain_core.messages import AIMessage + >>> mock = MockModel(responses=[AIMessage(content="Ready")]) + >>> agent = Agent(system_prompt="Test assistant", model_instance=mock) + >>> agent.start() + >>> # In practice, pass actual SkillContainer instances: + >>> # agent.register_skills(skill_container) + >>> agent.stop() + + See also: + run_implicit_skill: Invoke skills without agent awareness. + SkillCoordinator.register_skills: Underlying registration mechanism. + """ + ret = self.coordinator.register_skills(container) # type: ignore[func-returns-value] + + if run_implicit_name: + self.run_implicit_skill(run_implicit_name) + + return ret + + def get_tools(self): # type: ignore[no-untyped-def] + return self.coordinator.get_tools() + + def _write_debug_history_file(self) -> None: + file_path = os.getenv("DEBUG_AGENT_HISTORY_FILE") + if not file_path: + return + + history = [x.__dict__ for x in self.history()] # type: ignore[no-untyped-call] + + with open(file_path, "w") as f: + json.dump(history, f, default=lambda x: repr(x), indent=2) + + +class LlmAgent(Agent): + """Agent that automatically starts its processing loop on startup. + + LlmAgent is especially useful when combining with other modules in a blueprint. + When `start()` is called, it automatically invokes `loop_thread()`, eliminating + manual loop initiation and making agents composable as standard modules. + + When to use each: + Use LlmAgent when: + - Using blueprint pattern with autoconnect() + - Agent should start processing immediately on system startup + - Building autonomous systems with ModuleCoordinator + + Use Agent when: + - Using the deploy() helper for standalone agents + - Need explicit control over when processing begins + - Building query-driven systems with explicit query() calls + + Examples: + Direct instantiation showing auto-start behavior: + + >>> from dimos.agents2.agent import Agent, LlmAgent + >>> from dimos.agents2.testing import MockModel + >>> from langchain_core.messages import AIMessage + >>> mock1 = MockModel(responses=[AIMessage(content="Ready")]) + >>> agent = Agent(system_prompt="Test", model_instance=mock1) + >>> agent.start() + >>> agent.loop_thread() # Manual loop start required + True + >>> agent.stop() + >>> + >>> mock2 = MockModel(responses=[AIMessage(content="Ready")]) + >>> agent2 = LlmAgent(system_prompt="Test", model_instance=mock2) + >>> agent2.start() # Automatically calls loop_thread() + >>> agent2.stop() + + Blueprint pattern (typical production usage): + + >>> from dimos.core.blueprints import autoconnect # doctest: +SKIP + >>> from dimos.agents2.agent import llm_agent # doctest: +SKIP + >>> from dimos.agents2.cli.human import human_input # doctest: +SKIP + >>> from dimos.agents2.skills.demo_calculator_skill import demo_calculator_skill # doctest: +SKIP + >>> blueprint = autoconnect( # doctest: +SKIP + ... demo_calculator_skill(), + ... llm_agent(system_prompt="You are a helpful assistant."), + ... human_input(), + ... ) + >>> coordinator = blueprint.build() # doctest: +SKIP + >>> coordinator.loop() # doctest: +SKIP + + Notes: + The implementation overrides only start() to add self.loop_thread() after + super().start(). All other methods are inherited from Agent unchanged. + + This makes LlmAgent compatible with ModuleCoordinator's uniform start()/stop() + interface, eliminating agent-specific initialization in the blueprint system. + + See also: + Agent: Base agent class with manual loop control. + deploy: Convenience helper for standalone agent deployment. + """ + + @rpc + def start(self) -> None: + super().start() + self.loop_thread() + + @rpc + def stop(self) -> None: + super().stop() + + +llm_agent = LlmAgent.blueprint + + +def deploy( + dimos: Annotated[DimosCluster, Doc("The DimosCluster instance for distributed deployment.")], + system_prompt: Annotated[ + str, Doc("Initial instructions for the LLM agent's behavior.") + ] = "You are a helpful assistant for controlling a Unitree Go2 robot.", + model: Annotated[ + Model, Doc("The LLM model to use (e.g., GPT_4O, CLAUDE_35_SONNET).") + ] = Model.GPT_4O, + provider: Annotated[ + Provider, Doc("The model provider (e.g., OPENAI, ANTHROPIC).") + ] = Provider.OPENAI, # type: ignore[attr-defined] + skill_containers: Annotated[ + list[SkillContainer] | None, + Doc("Optional list of skill containers to register with the agent."), + ] = None, +) -> Annotated[Agent, Doc("The deployed and running agent instance.")]: + """Convenience helper for deploying a standalone LLM agent with HumanInput skill. + + Creates a fixed configuration: Agent + HumanInput. + Starts immediately and returns running instance. + Cannot be composed with other modules after creation. + If you need to compose modules in a more flexible way, use the blueprint pattern instead. + + Examples: + >>> # +SKIP: no MockModel injection point; uses real LLM providers + >>> from dimos import core # doctest: +SKIP + >>> from dimos.agents2.skills.demo_calculator_skill import DemoCalculatorSkill # doctest: +SKIP + >>> cluster = core.start() # doctest: +SKIP + >>> calculator = cluster.deploy(DemoCalculatorSkill) # doctest: +SKIP + >>> agent = deploy( # doctest: +SKIP + ... cluster, + ... system_prompt="You are a helpful calculator assistant.", + ... skill_containers=[calculator] + ... ) + + By contrast, the blueprint pattern is more flexible: + + >>> from dimos.core.blueprints import autoconnect # doctest: +SKIP + >>> from dimos.agents2.agent import llm_agent # doctest: +SKIP + >>> from dimos.agents2.cli.human import human_input # doctest: +SKIP + >>> from dimos.agents2.skills.demo_calculator_skill import demo_calculator_skill # doctest: +SKIP + >>> + >>> blueprint = autoconnect( # doctest: +SKIP + ... demo_calculator_skill(), + ... llm_agent(system_prompt="You are a helpful calculator assistant."), + ... human_input() + ... ) + >>> coordinator = blueprint.build() # doctest: +SKIP + + See also: + llm_agent: Blueprint-based agent for flexible composition. + Agent: The agent class instantiated by this function. + """ + from dimos.agents2.cli.human import HumanInput + + if skill_containers is None: + skill_containers = [] + agent = dimos.deploy( # type: ignore[attr-defined] + Agent, + system_prompt=system_prompt, + model=model, + provider=provider, + ) + + human_input = dimos.deploy(HumanInput) # type: ignore[attr-defined] + human_input.start() + + agent.register_skills(human_input) + + for skill_container in skill_containers: + print("Registering skill container:", skill_container) + agent.register_skills(skill_container) + + agent.run_implicit_skill("human") + agent.start() + agent.loop_thread() + + return agent # type: ignore[no-any-return] + + +__all__ = ["Agent", "deploy", "llm_agent"] diff --git a/dimos/agents2/cli/human.py b/dimos/agents2/cli/human.py new file mode 100644 index 0000000000..aa0879a2b0 --- /dev/null +++ b/dimos/agents2/cli/human.py @@ -0,0 +1,57 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import queue + +from reactivex.disposable import Disposable + +from dimos.agents2 import Output, Reducer, Stream, skill # type: ignore[attr-defined] +from dimos.core import pLCMTransport, rpc +from dimos.core.module import Module +from dimos.core.rpc_client import RpcCall + + +class HumanInput(Module): + running: bool = False + + @skill(stream=Stream.call_agent, reducer=Reducer.string, output=Output.human, hide_skill=True) # type: ignore[arg-type] + def human(self): # type: ignore[no-untyped-def] + """receives human input, no need to run this, it's running implicitly""" + if self.running: + return "already running" + self.running = True + transport = pLCMTransport("/human_input") # type: ignore[var-annotated] + + msg_queue = queue.Queue() # type: ignore[var-annotated] + unsub = transport.subscribe(msg_queue.put) # type: ignore[func-returns-value] + self._disposables.add(Disposable(unsub)) + yield from iter(msg_queue.get, None) + + @rpc + def start(self) -> None: + super().start() + + @rpc + def stop(self) -> None: + super().stop() + + @rpc + def set_LlmAgent_register_skills(self, callable: RpcCall) -> None: + callable.set_rpc(self.rpc) # type: ignore[arg-type] + callable(self, run_implicit_name="human") + + +human_input = HumanInput.blueprint + +__all__ = ["HumanInput", "human_input"] diff --git a/dimos/agents2/cli/web.py b/dimos/agents2/cli/web.py new file mode 100644 index 0000000000..09d5400cdc --- /dev/null +++ b/dimos/agents2/cli/web.py @@ -0,0 +1,87 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from threading import Thread +from typing import TYPE_CHECKING + +import reactivex as rx +import reactivex.operators as ops + +from dimos.core import Module, rpc +from dimos.core.transport import pLCMTransport +from dimos.stream.audio.node_normalizer import AudioNormalizer +from dimos.stream.audio.stt.node_whisper import WhisperNode +from dimos.utils.logging_config import setup_logger +from dimos.web.robot_web_interface import RobotWebInterface + +if TYPE_CHECKING: + from dimos.stream.audio.base import AudioEvent + +logger = setup_logger() + + +class WebInput(Module): + _web_interface: RobotWebInterface | None = None + _thread: Thread | None = None + _human_transport: pLCMTransport[str] | None = None + + @rpc + def start(self) -> None: + super().start() + + self._human_transport = pLCMTransport("/human_input") + + audio_subject: rx.subject.Subject[AudioEvent] = rx.subject.Subject() + + self._web_interface = RobotWebInterface( + port=5555, + text_streams={"agent_responses": rx.subject.Subject()}, + audio_subject=audio_subject, + ) + + normalizer = AudioNormalizer() + stt_node = WhisperNode() + + # Connect audio pipeline: browser audio → normalizer → whisper + normalizer.consume_audio(audio_subject.pipe(ops.share())) + stt_node.consume_audio(normalizer.emit_audio()) + + # Subscribe to both text input sources + # 1. Direct text from web interface + unsub = self._web_interface.query_stream.subscribe(self._human_transport.publish) + self._disposables.add(unsub) + + # 2. Transcribed text from STT + unsub = stt_node.emit_text().subscribe(self._human_transport.publish) + self._disposables.add(unsub) + + self._thread = Thread(target=self._web_interface.run, daemon=True) + self._thread.start() + + logger.info("Web interface started at http://localhost:5555") + + @rpc + def stop(self) -> None: + if self._web_interface: + self._web_interface.shutdown() + if self._thread: + self._thread.join(timeout=1.0) + if self._human_transport: + self._human_transport.lcm.stop() + super().stop() + + +web_input = WebInput.blueprint + +__all__ = ["WebInput", "web_input"] diff --git a/dimos/agents2/conftest.py b/dimos/agents2/conftest.py new file mode 100644 index 0000000000..c113bb4c27 --- /dev/null +++ b/dimos/agents2/conftest.py @@ -0,0 +1,85 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + +import pytest + +from dimos.agents2.agent import Agent +from dimos.agents2.testing import MockModel +from dimos.protocol.skill.test_coordinator import SkillContainerTest + + +@pytest.fixture +def fixture_dir(): + return Path(__file__).parent / "fixtures" + + +@pytest.fixture +def potato_system_prompt() -> str: + return "Your name is Mr. Potato, potatoes are bad at math. Use a tools if asked to calculate" + + +@pytest.fixture +def skill_container(): + container = SkillContainerTest() + try: + yield container + finally: + container.stop() + + +@pytest.fixture +def create_fake_agent(fixture_dir): + agent = None + + def _agent_factory(*, system_prompt, skill_containers, fixture): + mock_model = MockModel(json_path=fixture_dir / fixture) + + nonlocal agent + agent = Agent(system_prompt=system_prompt, model_instance=mock_model) + + for skill_container in skill_containers: + agent.register_skills(skill_container) + + agent.start() + + return agent + + try: + yield _agent_factory + finally: + if agent: + agent.stop() + + +@pytest.fixture +def create_potato_agent(potato_system_prompt, skill_container, fixture_dir): + agent = None + + def _agent_factory(*, fixture): + mock_model = MockModel(json_path=fixture_dir / fixture) + + nonlocal agent + agent = Agent(system_prompt=potato_system_prompt, model_instance=mock_model) + agent.register_skills(skill_container) + agent.start() + + return agent + + try: + yield _agent_factory + finally: + if agent: + agent.stop() diff --git a/dimos/agents2/constants.py b/dimos/agents2/constants.py new file mode 100644 index 0000000000..f363d1ff88 --- /dev/null +++ b/dimos/agents2/constants.py @@ -0,0 +1,17 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.constants import DIMOS_PROJECT_ROOT + +AGENT_SYSTEM_PROMPT_PATH = DIMOS_PROJECT_ROOT / "assets/agent/prompt_agents2.txt" diff --git a/dimos/agents2/fixtures/test_get_gps_position_for_queries.json b/dimos/agents2/fixtures/test_get_gps_position_for_queries.json new file mode 100644 index 0000000000..5d95b91bac --- /dev/null +++ b/dimos/agents2/fixtures/test_get_gps_position_for_queries.json @@ -0,0 +1,25 @@ +{ + "responses": [ + { + "content": "", + "tool_calls": [ + { + "name": "get_gps_position_for_queries", + "args": { + "args": [ + "Hyde Park", + "Regent Park", + "Russell Park" + ] + }, + "id": "call_xO0VDst53tzetEUq8mapKGS1", + "type": "tool_call" + } + ] + }, + { + "content": "Here are the latitude and longitude coordinates for the parks:\n\n- Hyde Park: Latitude 37.782601, Longitude -122.413201\n- Regent Park: Latitude 37.782602, Longitude -122.413202\n- Russell Park: Latitude 37.782603, Longitude -122.413203", + "tool_calls": [] + } + ] +} diff --git a/dimos/agents2/fixtures/test_go_to_object.json b/dimos/agents2/fixtures/test_go_to_object.json new file mode 100644 index 0000000000..80f1e95379 --- /dev/null +++ b/dimos/agents2/fixtures/test_go_to_object.json @@ -0,0 +1,27 @@ +{ + "responses": [ + { + "content": "I will navigate to the nearest chair.", + "tool_calls": [ + { + "name": "navigate_with_text", + "args": { + "args": [ + "chair" + ] + }, + "id": "call_LP4eewByfO9XaxMtnnWxDUz7", + "type": "tool_call" + } + ] + }, + { + "content": "I'm on my way to the chair. Let me know if there's anything else you'd like me to do!", + "tool_calls": [] + }, + { + "content": "I have successfully navigated to the chair. Let me know if you need anything else!", + "tool_calls": [] + } + ] +} diff --git a/dimos/agents2/fixtures/test_go_to_semantic_location.json b/dimos/agents2/fixtures/test_go_to_semantic_location.json new file mode 100644 index 0000000000..1a10711543 --- /dev/null +++ b/dimos/agents2/fixtures/test_go_to_semantic_location.json @@ -0,0 +1,23 @@ +{ + "responses": [ + { + "content": "", + "tool_calls": [ + { + "name": "navigate_with_text", + "args": { + "args": [ + "bookshelf" + ] + }, + "id": "call_yPoqcavMo05ogNNy5LMNQl2a", + "type": "tool_call" + } + ] + }, + { + "content": "I have successfully arrived at the bookshelf. Is there anything specific you need here?", + "tool_calls": [] + } + ] +} diff --git a/dimos/agents2/fixtures/test_how_much_is_124181112_plus_124124.json b/dimos/agents2/fixtures/test_how_much_is_124181112_plus_124124.json new file mode 100644 index 0000000000..f4dbe0c3a5 --- /dev/null +++ b/dimos/agents2/fixtures/test_how_much_is_124181112_plus_124124.json @@ -0,0 +1,52 @@ +{ + "responses": [ + { + "content": "", + "tool_calls": [ + { + "name": "add", + "args": { + "args": [ + 124181112, + 124124 + ] + }, + "id": "call_SSoVXz5yihrzR8TWIGnGKSpi", + "type": "tool_call" + } + ] + }, + { + "content": "Let me do some potato math... Calculating this will take some time, hold on! \ud83e\udd54", + "tool_calls": [] + }, + { + "content": "The result of adding 124,181,112 and 124,124 is 124,305,236. Potatoes work well with tools! \ud83e\udd54\ud83c\udf89", + "tool_calls": [] + }, + { + "content": "", + "tool_calls": [ + { + "name": "add", + "args": { + "args": [ + 1000000000, + -1000000 + ] + }, + "id": "call_ge9pv6IRa3yo0vjVaORvrGby", + "type": "tool_call" + } + ] + }, + { + "content": "Let's get those numbers crunched. Potatoes need a bit of time! \ud83e\udd54\ud83d\udcca", + "tool_calls": [] + }, + { + "content": "The result of one billion plus negative one million is 999,000,000. Potatoes are amazing with some help! \ud83e\udd54\ud83d\udca1", + "tool_calls": [] + } + ] +} diff --git a/dimos/agents2/fixtures/test_pounce.json b/dimos/agents2/fixtures/test_pounce.json new file mode 100644 index 0000000000..99e84d003a --- /dev/null +++ b/dimos/agents2/fixtures/test_pounce.json @@ -0,0 +1,38 @@ +{ + "responses": [ + { + "content": "", + "tool_calls": [ + { + "name": "execute_sport_command", + "args": { + "args": [ + "FrontPounce" + ] + }, + "id": "call_Ukj6bCAnHQLj28RHRp697blZ", + "type": "tool_call" + } + ] + }, + { + "content": "", + "tool_calls": [ + { + "name": "speak", + "args": { + "args": [ + "I have successfully performed a front pounce." + ] + }, + "id": "call_FR9DtqEvJ9zSY85qVD2UFrll", + "type": "tool_call" + } + ] + }, + { + "content": "I have successfully performed a front pounce.", + "tool_calls": [] + } + ] +} diff --git a/dimos/agents2/fixtures/test_set_gps_travel_points.json b/dimos/agents2/fixtures/test_set_gps_travel_points.json new file mode 100644 index 0000000000..eb5b2a9195 --- /dev/null +++ b/dimos/agents2/fixtures/test_set_gps_travel_points.json @@ -0,0 +1,30 @@ +{ + "responses": [ + { + "content": "I understand you want me to navigate to the specified location. I will set the GPS travel point accordingly.", + "tool_calls": [ + { + "name": "set_gps_travel_points", + "args": { + "args": [ + { + "lat": 37.782654, + "lon": -122.413273 + } + ] + }, + "id": "call_q6JCCYFuyAjqUgUibJHqcIMD", + "type": "tool_call" + } + ] + }, + { + "content": "I'm on my way to the specified location. Let me know if there is anything else I can assist you with!", + "tool_calls": [] + }, + { + "content": "I've reached the specified location. Do you need any further assistance?", + "tool_calls": [] + } + ] +} diff --git a/dimos/agents2/fixtures/test_set_gps_travel_points_multiple.json b/dimos/agents2/fixtures/test_set_gps_travel_points_multiple.json new file mode 100644 index 0000000000..9d8f7e9e00 --- /dev/null +++ b/dimos/agents2/fixtures/test_set_gps_travel_points_multiple.json @@ -0,0 +1,34 @@ +{ + "responses": [ + { + "content": "", + "tool_calls": [ + { + "name": "set_gps_travel_points", + "args": { + "args": [ + { + "lat": 37.782654, + "lon": -122.413273 + }, + { + "lat": 37.78266, + "lon": -122.41326 + }, + { + "lat": 37.78267, + "lon": -122.41327 + } + ] + }, + "id": "call_Q09MRMEgRnJPBOGZpM0j8sL2", + "type": "tool_call" + } + ] + }, + { + "content": "I've successfully set the travel points and will navigate to them sequentially.", + "tool_calls": [] + } + ] +} diff --git a/dimos/agents2/fixtures/test_show_your_love.json b/dimos/agents2/fixtures/test_show_your_love.json new file mode 100644 index 0000000000..941906e781 --- /dev/null +++ b/dimos/agents2/fixtures/test_show_your_love.json @@ -0,0 +1,38 @@ +{ + "responses": [ + { + "content": "", + "tool_calls": [ + { + "name": "execute_sport_command", + "args": { + "args": [ + "FingerHeart" + ] + }, + "id": "call_VFp6x9F00FBmiiUiemFWewop", + "type": "tool_call" + } + ] + }, + { + "content": "", + "tool_calls": [ + { + "name": "speak", + "args": { + "args": [ + "Here's a gesture to show you some love!" + ] + }, + "id": "call_WUUmBJ95s9PtVx8YQsmlJ4EU", + "type": "tool_call" + } + ] + }, + { + "content": "Just did a finger heart gesture to show my affection!", + "tool_calls": [] + } + ] +} diff --git a/dimos/agents2/fixtures/test_stop_movement.json b/dimos/agents2/fixtures/test_stop_movement.json new file mode 100644 index 0000000000..b80834213e --- /dev/null +++ b/dimos/agents2/fixtures/test_stop_movement.json @@ -0,0 +1,21 @@ +{ + "responses": [ + { + "content": "", + "tool_calls": [ + { + "name": "stop_movement", + "args": { + "args": null + }, + "id": "call_oAKe9W8s3xRGioZhBJJDOZB1", + "type": "tool_call" + } + ] + }, + { + "content": "I have stopped moving. Let me know if you need anything else!", + "tool_calls": [] + } + ] +} diff --git a/dimos/agents2/fixtures/test_take_a_look_around.json b/dimos/agents2/fixtures/test_take_a_look_around.json new file mode 100644 index 0000000000..c30fe71017 --- /dev/null +++ b/dimos/agents2/fixtures/test_take_a_look_around.json @@ -0,0 +1,23 @@ +{ + "responses": [ + { + "content": "", + "tool_calls": [ + { + "name": "start_exploration", + "args": { + "args": [ + 10 + ] + }, + "id": "call_AMNeD8zTkvyFHKG90DriDPuM", + "type": "tool_call" + } + ] + }, + { + "content": "I have completed a brief exploration of the surroundings. Let me know if there's anything specific you need!", + "tool_calls": [] + } + ] +} diff --git a/dimos/agents2/fixtures/test_what_do_you_see_in_this_picture.json b/dimos/agents2/fixtures/test_what_do_you_see_in_this_picture.json new file mode 100644 index 0000000000..27ac3453bc --- /dev/null +++ b/dimos/agents2/fixtures/test_what_do_you_see_in_this_picture.json @@ -0,0 +1,25 @@ +{ + "responses": [ + { + "content": "", + "tool_calls": [ + { + "name": "take_photo", + "args": { + "args": [] + }, + "id": "call_o6ikJtK3vObuEFD6hDtLoyGQ", + "type": "tool_call" + } + ] + }, + { + "content": "I took a photo, but as an AI, I can't see or interpret images. If there's anything specific you need to know, feel free to ask!", + "tool_calls": [] + }, + { + "content": "It looks like a cozy outdoor cafe where people are sitting and enjoying a meal. There are flowers and a nice, sunny ambiance. If you have any specific questions about the image, let me know!", + "tool_calls": [] + } + ] +} diff --git a/dimos/agents2/fixtures/test_what_is_your_name.json b/dimos/agents2/fixtures/test_what_is_your_name.json new file mode 100644 index 0000000000..a74d793b1d --- /dev/null +++ b/dimos/agents2/fixtures/test_what_is_your_name.json @@ -0,0 +1,8 @@ +{ + "responses": [ + { + "content": "Hi! My name is Mr. Potato. How can I assist you today?", + "tool_calls": [] + } + ] +} diff --git a/dimos/agents2/fixtures/test_where_am_i.json b/dimos/agents2/fixtures/test_where_am_i.json new file mode 100644 index 0000000000..2d274f8fa6 --- /dev/null +++ b/dimos/agents2/fixtures/test_where_am_i.json @@ -0,0 +1,21 @@ +{ + "responses": [ + { + "content": "", + "tool_calls": [ + { + "name": "where_am_i", + "args": { + "args": [] + }, + "id": "call_uRJLockZ5JWtGWbsSL1dpHm3", + "type": "tool_call" + } + ] + }, + { + "content": "You are on Bourbon Street.", + "tool_calls": [] + } + ] +} diff --git a/dimos/agents2/ollama_agent.py b/dimos/agents2/ollama_agent.py new file mode 100644 index 0000000000..4b35cc84f8 --- /dev/null +++ b/dimos/agents2/ollama_agent.py @@ -0,0 +1,39 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import ollama + +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +def ensure_ollama_model(model_name: str) -> None: + available_models = ollama.list() + model_exists = any(model_name == m.model for m in available_models.models) + if not model_exists: + logger.info(f"Ollama model '{model_name}' not found. Pulling...") + ollama.pull(model_name) + + +def ollama_installed() -> str | None: + try: + ollama.list() + return None + except Exception: + return ( + "Cannot connect to Ollama daemon. Please ensure Ollama is installed and running.\n" + "\n" + " For installation instructions, visit https://ollama.com/download" + ) diff --git a/dimos/agents2/skills/conftest.py b/dimos/agents2/skills/conftest.py new file mode 100644 index 0000000000..ec76d83628 --- /dev/null +++ b/dimos/agents2/skills/conftest.py @@ -0,0 +1,117 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial + +import pytest +from reactivex.scheduler import ThreadPoolScheduler + +from dimos.agents2.skills.google_maps_skill_container import GoogleMapsSkillContainer +from dimos.agents2.skills.gps_nav_skill import GpsNavSkillContainer +from dimos.agents2.skills.navigation import NavigationSkillContainer +from dimos.agents2.system_prompt import get_system_prompt +from dimos.robot.unitree_webrtc.unitree_skill_container import UnitreeSkillContainer + +system_prompt = get_system_prompt() + + +@pytest.fixture(autouse=True) +def cleanup_threadpool_scheduler(monkeypatch): + # TODO: get rid of this global threadpool + """Clean up and recreate the global ThreadPoolScheduler after each test.""" + # Disable ChromaDB telemetry to avoid leaking threads + monkeypatch.setenv("CHROMA_ANONYMIZED_TELEMETRY", "False") + yield + from dimos.utils import threadpool + + # Shutdown the global scheduler's executor + threadpool.scheduler.executor.shutdown(wait=True) + # Recreate it for the next test + threadpool.scheduler = ThreadPoolScheduler(max_workers=threadpool.get_max_workers()) + + +@pytest.fixture +def navigation_skill_container(mocker): + container = NavigationSkillContainer() + container.color_image.connection = mocker.MagicMock() + container.odom.connection = mocker.MagicMock() + container.start() + yield container + container.stop() + + +@pytest.fixture +def gps_nav_skill_container(mocker): + container = GpsNavSkillContainer() + container.gps_location.connection = mocker.MagicMock() + container.gps_goal = mocker.MagicMock() + container.start() + yield container + container.stop() + + +@pytest.fixture +def google_maps_skill_container(mocker): + container = GoogleMapsSkillContainer() + container.gps_location.connection = mocker.MagicMock() + container.start() + container._client = mocker.MagicMock() + yield container + container.stop() + + +@pytest.fixture +def unitree_skills(mocker): + container = UnitreeSkillContainer() + container._move = mocker.Mock() + container._publish_request = mocker.Mock() + container.start() + yield container + container.stop() + + +@pytest.fixture +def create_navigation_agent(navigation_skill_container, create_fake_agent): + return partial( + create_fake_agent, + system_prompt=system_prompt, + skill_containers=[navigation_skill_container], + ) + + +@pytest.fixture +def create_gps_nav_agent(gps_nav_skill_container, create_fake_agent): + return partial( + create_fake_agent, system_prompt=system_prompt, skill_containers=[gps_nav_skill_container] + ) + + +@pytest.fixture +def create_google_maps_agent( + gps_nav_skill_container, google_maps_skill_container, create_fake_agent +): + return partial( + create_fake_agent, + system_prompt=system_prompt, + skill_containers=[gps_nav_skill_container, google_maps_skill_container], + ) + + +@pytest.fixture +def create_unitree_skills_agent(unitree_skills, create_fake_agent): + return partial( + create_fake_agent, + system_prompt=system_prompt, + skill_containers=[unitree_skills], + ) diff --git a/dimos/agents2/skills/demo_calculator_skill.py b/dimos/agents2/skills/demo_calculator_skill.py new file mode 100644 index 0000000000..2ed8050ca5 --- /dev/null +++ b/dimos/agents2/skills/demo_calculator_skill.py @@ -0,0 +1,43 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.core.skill_module import SkillModule +from dimos.protocol.skill.skill import skill + + +class DemoCalculatorSkill(SkillModule): + def start(self) -> None: + super().start() + + def stop(self) -> None: + super().stop() + + @skill() + def sum_numbers(self, n1: int, n2: int, *args: int, **kwargs: int) -> str: + """This skill adds two numbers. Always use this tool. Never add up numbers yourself. + + Example: + + sum_numbers(100, 20) + + Args: + sum (str): The sum, as a string. E.g., "120" + """ + + return f"{int(n1) + int(n2)}" + + +demo_calculator_skill = DemoCalculatorSkill.blueprint + +__all__ = ["DemoCalculatorSkill", "demo_calculator_skill"] diff --git a/dimos/agents2/skills/demo_google_maps_skill.py b/dimos/agents2/skills/demo_google_maps_skill.py new file mode 100644 index 0000000000..132c1ad013 --- /dev/null +++ b/dimos/agents2/skills/demo_google_maps_skill.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dotenv import load_dotenv + +from dimos.agents2.agent import llm_agent +from dimos.agents2.cli.human import human_input +from dimos.agents2.skills.demo_robot import demo_robot +from dimos.agents2.skills.google_maps_skill_container import google_maps_skill +from dimos.agents2.system_prompt import get_system_prompt +from dimos.core.blueprints import autoconnect + +load_dotenv() + + +demo_google_maps_skill = autoconnect( + demo_robot(), + google_maps_skill(), + human_input(), + llm_agent(system_prompt=get_system_prompt()), +) diff --git a/dimos/agents2/skills/demo_gps_nav.py b/dimos/agents2/skills/demo_gps_nav.py new file mode 100644 index 0000000000..74dad77c7a --- /dev/null +++ b/dimos/agents2/skills/demo_gps_nav.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dotenv import load_dotenv + +from dimos.agents2.agent import llm_agent +from dimos.agents2.cli.human import human_input +from dimos.agents2.skills.demo_robot import demo_robot +from dimos.agents2.skills.gps_nav_skill import gps_nav_skill +from dimos.agents2.system_prompt import get_system_prompt +from dimos.core.blueprints import autoconnect + +load_dotenv() + + +demo_gps_nav_skill = autoconnect( + demo_robot(), + gps_nav_skill(), + human_input(), + llm_agent(system_prompt=get_system_prompt()), +) diff --git a/dimos/agents2/skills/demo_robot.py b/dimos/agents2/skills/demo_robot.py new file mode 100644 index 0000000000..829fcb576a --- /dev/null +++ b/dimos/agents2/skills/demo_robot.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from reactivex import interval + +from dimos.core.module import Module +from dimos.core.stream import Out +from dimos.mapping.types import LatLon + + +class DemoRobot(Module): + gps_location: Out[LatLon] = None # type: ignore[assignment] + + def start(self) -> None: + super().start() + self._disposables.add(interval(1.0).subscribe(lambda _: self._publish_gps_location())) + + def stop(self) -> None: + super().stop() + + def _publish_gps_location(self) -> None: + self.gps_location.publish(LatLon(lat=37.78092426217621, lon=-122.40682866540769)) + + +demo_robot = DemoRobot.blueprint + + +__all__ = ["DemoRobot", "demo_robot"] diff --git a/dimos/agents2/skills/demo_skill.py b/dimos/agents2/skills/demo_skill.py new file mode 100644 index 0000000000..5c3c6e73c6 --- /dev/null +++ b/dimos/agents2/skills/demo_skill.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dotenv import load_dotenv + +from dimos.agents2.agent import llm_agent +from dimos.agents2.cli.human import human_input +from dimos.agents2.skills.demo_calculator_skill import demo_calculator_skill +from dimos.agents2.system_prompt import get_system_prompt +from dimos.core.blueprints import autoconnect + +load_dotenv() + + +demo_skill = autoconnect( + demo_calculator_skill(), + human_input(), + llm_agent(system_prompt=get_system_prompt()), +) diff --git a/dimos/agents2/skills/google_maps_skill_container.py b/dimos/agents2/skills/google_maps_skill_container.py new file mode 100644 index 0000000000..e1d33731e6 --- /dev/null +++ b/dimos/agents2/skills/google_maps_skill_container.py @@ -0,0 +1,115 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from typing import Any + +from dimos.core.core import rpc +from dimos.core.skill_module import SkillModule +from dimos.core.stream import In +from dimos.mapping.google_maps.google_maps import GoogleMaps +from dimos.mapping.types import LatLon +from dimos.protocol.skill.skill import skill +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class GoogleMapsSkillContainer(SkillModule): + _latest_location: LatLon | None = None + _client: GoogleMaps + + gps_location: In[LatLon] = None # type: ignore[assignment] + + def __init__(self) -> None: + super().__init__() + self._client = GoogleMaps() + + @rpc + def start(self) -> None: + super().start() + self._disposables.add(self.gps_location.subscribe(self._on_gps_location)) # type: ignore[arg-type] + + @rpc + def stop(self) -> None: + super().stop() + + def _on_gps_location(self, location: LatLon) -> None: + self._latest_location = location + + def _get_latest_location(self) -> LatLon: + if not self._latest_location: + raise ValueError("The position has not been set yet.") + return self._latest_location + + @skill() + def where_am_i(self, context_radius: int = 200) -> str: + """This skill returns information about what street/locality/city/etc + you are in. It also gives you nearby landmarks. + + Example: + + where_am_i(context_radius=200) + + Args: + context_radius (int): default 200, how many meters to look around + """ + + location = self._get_latest_location() + + result = None + try: + result = self._client.get_location_context(location, radius=context_radius) + except Exception: + return "There is an issue with the Google Maps API." + + if not result: + return "Could not find anything about the current location." + + return result.model_dump_json() + + @skill() + def get_gps_position_for_queries(self, *queries: str) -> str: + """Get the GPS position (latitude/longitude) + + Example: + + get_gps_position_for_queries(['Fort Mason', 'Lafayette Park']) + # returns + [{"lat": 37.8059, "lon":-122.4290}, {"lat": 37.7915, "lon": -122.4276}] + + Args: + queries (list[str]): The places you want to look up. + """ + + location = self._get_latest_location() + + results: list[dict[str, Any] | str] = [] + + for query in queries: + try: + latlon = self._client.get_position(query, location) + except Exception: + latlon = None + if latlon: + results.append(latlon.model_dump()) + else: + results.append(f"no result for {query}") + + return json.dumps(results) + + +google_maps_skill = GoogleMapsSkillContainer.blueprint + +__all__ = ["GoogleMapsSkillContainer", "google_maps_skill"] diff --git a/dimos/agents2/skills/gps_nav_skill.py b/dimos/agents2/skills/gps_nav_skill.py new file mode 100644 index 0000000000..31a6a5f956 --- /dev/null +++ b/dimos/agents2/skills/gps_nav_skill.py @@ -0,0 +1,109 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from dimos.core.core import rpc +from dimos.core.rpc_client import RpcCall +from dimos.core.skill_module import SkillModule +from dimos.core.stream import In, Out +from dimos.mapping.types import LatLon +from dimos.mapping.utils.distance import distance_in_meters +from dimos.protocol.skill.skill import skill +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class GpsNavSkillContainer(SkillModule): + _latest_location: LatLon | None = None + _max_valid_distance: int = 50000 + _set_gps_travel_goal_points: RpcCall | None = None + + gps_location: In[LatLon] = None # type: ignore[assignment] + gps_goal: Out[LatLon] = None # type: ignore[assignment] + + def __init__(self) -> None: + super().__init__() + + @rpc + def start(self) -> None: + super().start() + self._disposables.add(self.gps_location.subscribe(self._on_gps_location)) # type: ignore[arg-type] + + @rpc + def stop(self) -> None: + super().stop() + + @rpc + def set_WebsocketVisModule_set_gps_travel_goal_points(self, callable: RpcCall) -> None: + self._set_gps_travel_goal_points = callable + self._set_gps_travel_goal_points.set_rpc(self.rpc) # type: ignore[arg-type] + + def _on_gps_location(self, location: LatLon) -> None: + self._latest_location = location + + def _get_latest_location(self) -> LatLon: + if not self._latest_location: + raise ValueError("The position has not been set yet.") + return self._latest_location + + @skill() + def set_gps_travel_points(self, *points: dict[str, float]) -> str: + """Define the movement path determined by GPS coordinates. Requires at least one. You can get the coordinates by using the `get_gps_position_for_queries` skill. + + Example: + + set_gps_travel_goals([{"lat": 37.8059, "lon":-122.4290}, {"lat": 37.7915, "lon": -122.4276}]) + # Travel first to {"lat": 37.8059, "lon":-122.4290} + # then travel to {"lat": 37.7915, "lon": -122.4276} + """ + + new_points = [self._convert_point(x) for x in points] + + if not all(new_points): + parsed = json.dumps([x.__dict__ if x else x for x in new_points]) + return f"Not all points were valid. I parsed this: {parsed}" + + for new_point in new_points: + distance = distance_in_meters(self._get_latest_location(), new_point) # type: ignore[arg-type] + if distance > self._max_valid_distance: + return f"Point {new_point} is too far ({int(distance)} meters away)." + + logger.info(f"Set travel points: {new_points}") + + if self.gps_goal._transport is not None: + self.gps_goal.publish(new_points) + + if self._set_gps_travel_goal_points: + self._set_gps_travel_goal_points(new_points) + + return "I've successfully set the travel points." + + def _convert_point(self, point: dict[str, float]) -> LatLon | None: + if not isinstance(point, dict): + return None + lat = point.get("lat") + lon = point.get("lon") + + if lat is None or lon is None: + return None + + return LatLon(lat=lat, lon=lon) + + +gps_nav_skill = GpsNavSkillContainer.blueprint + + +__all__ = ["GpsNavSkillContainer", "gps_nav_skill"] diff --git a/dimos/agents2/skills/navigation.py b/dimos/agents2/skills/navigation.py new file mode 100644 index 0000000000..442f11a42d --- /dev/null +++ b/dimos/agents2/skills/navigation.py @@ -0,0 +1,402 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from typing import Any + +from dimos.core.core import rpc +from dimos.core.skill_module import SkillModule +from dimos.core.stream import In +from dimos.models.qwen.video_query import BBox +from dimos.models.vl.qwen import QwenVlModel +from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Vector3 +from dimos.msgs.geometry_msgs.Vector3 import make_vector3 +from dimos.msgs.sensor_msgs import Image +from dimos.navigation.base import NavigationState +from dimos.navigation.visual.query import get_object_bbox_from_image +from dimos.protocol.skill.skill import skill +from dimos.types.robot_location import RobotLocation +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class NavigationSkillContainer(SkillModule): + _latest_image: Image | None = None + _latest_odom: PoseStamped | None = None + _skill_started: bool = False + _similarity_threshold: float = 0.23 + + rpc_calls: list[str] = [ + "SpatialMemory.tag_location", + "SpatialMemory.query_tagged_location", + "SpatialMemory.query_by_text", + "NavigationInterface.set_goal", + "NavigationInterface.get_state", + "NavigationInterface.is_goal_reached", + "NavigationInterface.cancel_goal", + "ObjectTracking.track", + "ObjectTracking.stop_track", + "ObjectTracking.is_tracking", + "WavefrontFrontierExplorer.stop_exploration", + "WavefrontFrontierExplorer.explore", + "WavefrontFrontierExplorer.is_exploration_active", + ] + + color_image: In[Image] = None # type: ignore[assignment] + odom: In[PoseStamped] = None # type: ignore[assignment] + + def __init__(self) -> None: + super().__init__() + self._skill_started = False + self._vl_model = QwenVlModel() + + @rpc + def start(self) -> None: + self._disposables.add(self.color_image.subscribe(self._on_color_image)) # type: ignore[arg-type] + self._disposables.add(self.odom.subscribe(self._on_odom)) # type: ignore[arg-type] + self._skill_started = True + + @rpc + def stop(self) -> None: + super().stop() + + def _on_color_image(self, image: Image) -> None: + self._latest_image = image + + def _on_odom(self, odom: PoseStamped) -> None: + self._latest_odom = odom + + @skill() + def tag_location(self, location_name: str) -> str: + """Tag this location in the spatial memory with a name. + + This associates the current location with the given name in the spatial memory, allowing you to navigate back to it. + + Args: + location_name (str): the name for the location + + Returns: + str: the outcome + """ + + if not self._skill_started: + raise ValueError(f"{self} has not been started.") + tf = self.tf.get("map", "base_link", time_tolerance=2.0) + if not tf: + return "Could not get the robot's current transform." + + position = tf.translation + rotation = tf.rotation.to_euler() + + location = RobotLocation( + name=location_name, + position=(position.x, position.y, position.z), + rotation=(rotation.x, rotation.y, rotation.z), + ) + + tag_location_rpc = self.get_rpc_calls("SpatialMemory.tag_location") + if not tag_location_rpc(location): + return f"Error: Failed to store '{location_name}' in the spatial memory" + + logger.info(f"Tagged {location}") + return f"Tagged '{location_name}': ({position.x},{position.y})." + + @skill() + def navigate_with_text(self, query: str) -> str: + """Navigate to a location by querying the existing semantic map using natural language. + + First attempts to locate an object in the robot's camera view using vision. + If the object is found, navigates to it. If not, falls back to querying the + semantic map for a location matching the description. + CALL THIS SKILL FOR ONE SUBJECT AT A TIME. For example: "Go to the person wearing a blue shirt in the living room", + you should call this skill twice, once for the person wearing a blue shirt and once for the living room. + Args: + query: Text query to search for in the semantic map + """ + + if not self._skill_started: + raise ValueError(f"{self} has not been started.") + success_msg = self._navigate_by_tagged_location(query) + if success_msg: + return success_msg + + logger.info(f"No tagged location found for {query}") + + success_msg = self._navigate_to_object(query) + if success_msg: + return success_msg + + logger.info(f"No object in view found for {query}") + + success_msg = self._navigate_using_semantic_map(query) + if success_msg: + return success_msg + + return f"No tagged location called '{query}'. No object in view matching '{query}'. No matching location found in semantic map for '{query}'." + + def _navigate_by_tagged_location(self, query: str) -> str | None: + try: + query_tagged_location_rpc = self.get_rpc_calls("SpatialMemory.query_tagged_location") + except Exception: + logger.warning("SpatialMemory module not connected, cannot query tagged locations") + return None + + robot_location = query_tagged_location_rpc(query) + + if not robot_location: + return None + + print("Found tagged location:", robot_location) + goal_pose = PoseStamped( + position=make_vector3(*robot_location.position), + orientation=Quaternion.from_euler(Vector3(*robot_location.rotation)), + frame_id="map", + ) + + result = self._navigate_to(goal_pose) + if not result: + return "Error: Faild to reach the tagged location." + + return ( + f"Successfuly arrived at location tagged '{robot_location.name}' from query '{query}'." + ) + + def _navigate_to(self, pose: PoseStamped) -> bool: + try: + set_goal_rpc, get_state_rpc, is_goal_reached_rpc = self.get_rpc_calls( + "NavigationInterface.set_goal", + "NavigationInterface.get_state", + "NavigationInterface.is_goal_reached", + ) + except Exception: + logger.error("Navigation module not connected properly") + return False + + logger.info( + f"Navigating to pose: ({pose.position.x:.2f}, {pose.position.y:.2f}, {pose.position.z:.2f})" + ) + set_goal_rpc(pose) + time.sleep(1.0) + + while get_state_rpc() == NavigationState.FOLLOWING_PATH: + time.sleep(0.25) + + time.sleep(1.0) + if not is_goal_reached_rpc(): + logger.info("Navigation was cancelled or failed") + return False + else: + logger.info("Navigation goal reached") + return True + + def _navigate_to_object(self, query: str) -> str | None: + try: + bbox = self._get_bbox_for_current_frame(query) + except Exception: + logger.error(f"Failed to get bbox for {query}", exc_info=True) + return None + + if bbox is None: + return None + + try: + track_rpc, stop_track_rpc, is_tracking_rpc = self.get_rpc_calls( + "ObjectTracking.track", "ObjectTracking.stop_track", "ObjectTracking.is_tracking" + ) + except Exception: + logger.error("ObjectTracking module not connected properly") + return None + + try: + get_state_rpc, is_goal_reached_rpc = self.get_rpc_calls( + "NavigationInterface.get_state", "NavigationInterface.is_goal_reached" + ) + except Exception: + logger.error("Navigation module not connected properly") + return None + + logger.info(f"Found {query} at {bbox}") + + # Start tracking - BBoxNavigationModule automatically generates goals + track_rpc(bbox) + + start_time = time.time() + timeout = 30.0 + goal_set = False + + while time.time() - start_time < timeout: + # Check if navigator finished + if get_state_rpc() == NavigationState.IDLE and goal_set: + logger.info("Waiting for goal result") + time.sleep(1.0) + if not is_goal_reached_rpc(): + logger.info(f"Goal cancelled, tracking '{query}' failed") + stop_track_rpc() + return None + else: + logger.info(f"Reached '{query}'") + stop_track_rpc() + return f"Successfully arrived at '{query}'" + + # If goal set and tracking lost, just continue (tracker will resume or timeout) + if goal_set and not is_tracking_rpc(): + continue + + # BBoxNavigationModule automatically sends goals when tracker publishes + # Just check if we have any detections to mark goal_set + if is_tracking_rpc(): + goal_set = True + + time.sleep(0.25) + + logger.warning(f"Navigation to '{query}' timed out after {timeout}s") + stop_track_rpc() + return None + + def _get_bbox_for_current_frame(self, query: str) -> BBox | None: + if self._latest_image is None: + return None + + return get_object_bbox_from_image(self._vl_model, self._latest_image, query) + + def _navigate_using_semantic_map(self, query: str) -> str: + try: + query_by_text_rpc = self.get_rpc_calls("SpatialMemory.query_by_text") + except Exception: + return "Error: The SpatialMemory module is not connected." + + results = query_by_text_rpc(query) + + if not results: + return f"No matching location found in semantic map for '{query}'" + + best_match = results[0] + + goal_pose = self._get_goal_pose_from_result(best_match) + + print("Goal pose for semantic nav:", goal_pose) + if not goal_pose: + return f"Found a result for '{query}' but it didn't have a valid position." + + result = self._navigate_to(goal_pose) + + if not result: + return f"Failed to navigate for '{query}'" + + return f"Successfuly arrived at '{query}'" + + @skill() + def follow_human(self, person: str) -> str: + """Follow a specific person""" + return "Not implemented yet." + + @skill() + def stop_movement(self) -> str: + """Immediatly stop moving.""" + + if not self._skill_started: + raise ValueError(f"{self} has not been started.") + + self._cancel_goal_and_stop() + + return "Stopped" + + def _cancel_goal_and_stop(self) -> None: + try: + cancel_goal_rpc = self.get_rpc_calls("NavigationInterface.cancel_goal") + except Exception: + logger.warning("Navigation module not connected, cannot cancel goal") + return + + try: + stop_exploration_rpc = self.get_rpc_calls("WavefrontFrontierExplorer.stop_exploration") + except Exception: + logger.warning("FrontierExplorer module not connected, cannot stop exploration") + return + + cancel_goal_rpc() + return stop_exploration_rpc() # type: ignore[no-any-return] + + @skill() + def start_exploration(self, timeout: float = 240.0) -> str: + """A skill that performs autonomous frontier exploration. + + This skill continuously finds and navigates to unknown frontiers in the environment + until no more frontiers are found or the exploration is stopped. + + Don't call any other skills except stop_movement skill when needed. + + Args: + timeout (float, optional): Maximum time (in seconds) allowed for exploration + """ + + if not self._skill_started: + raise ValueError(f"{self} has not been started.") + + try: + return self._start_exploration(timeout) + finally: + self._cancel_goal_and_stop() + + def _start_exploration(self, timeout: float) -> str: + try: + explore_rpc, is_exploration_active_rpc = self.get_rpc_calls( + "WavefrontFrontierExplorer.explore", + "WavefrontFrontierExplorer.is_exploration_active", + ) + except Exception: + return "Error: The WavefrontFrontierExplorer module is not connected." + + logger.info("Starting autonomous frontier exploration") + + start_time = time.time() + + has_started = explore_rpc() + if not has_started: + return "Error: Could not start exploration." + + while time.time() - start_time < timeout and is_exploration_active_rpc(): + time.sleep(0.5) + + return "Exploration completed successfuly" + + def _get_goal_pose_from_result(self, result: dict[str, Any]) -> PoseStamped | None: + similarity = 1.0 - (result.get("distance") or 1) + if similarity < self._similarity_threshold: + logger.warning( + f"Match found but similarity score ({similarity:.4f}) is below threshold ({self._similarity_threshold})" + ) + return None + + metadata = result.get("metadata") + if not metadata: + return None + print(metadata) + first = metadata[0] + print(first) + pos_x = first.get("pos_x", 0) + pos_y = first.get("pos_y", 0) + theta = first.get("rot_z", 0) + + return PoseStamped( + position=make_vector3(pos_x, pos_y, 0), + orientation=Quaternion.from_euler(make_vector3(0, 0, theta)), + frame_id="map", + ) + + +navigation_skill = NavigationSkillContainer.blueprint + +__all__ = ["NavigationSkillContainer", "navigation_skill"] diff --git a/dimos/agents2/skills/osm.py b/dimos/agents2/skills/osm.py new file mode 100644 index 0000000000..069f3ae7a9 --- /dev/null +++ b/dimos/agents2/skills/osm.py @@ -0,0 +1,80 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dimos.core.skill_module import SkillModule +from dimos.core.stream import In +from dimos.mapping.osm.current_location_map import CurrentLocationMap +from dimos.mapping.types import LatLon +from dimos.mapping.utils.distance import distance_in_meters +from dimos.models.vl.qwen import QwenVlModel +from dimos.protocol.skill.skill import skill +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class OsmSkill(SkillModule): + _latest_location: LatLon | None + _current_location_map: CurrentLocationMap + + gps_location: In[LatLon] = None # type: ignore[assignment] + + def __init__(self) -> None: + super().__init__() + self._latest_location = None + self._current_location_map = CurrentLocationMap(QwenVlModel()) + + def start(self) -> None: + super().start() + self._disposables.add(self.gps_location.subscribe(self._on_gps_location)) # type: ignore[arg-type] + + def stop(self) -> None: + super().stop() + + def _on_gps_location(self, location: LatLon) -> None: + self._latest_location = location + + @skill() + def street_map_query(self, query_sentence: str) -> str: + """This skill uses a vision language model to find something on the map + based on the query sentence. You can query it with something like "Where + can I find a coffee shop?" and it returns the latitude and longitude. + + Example: + + street_map_query("Where can I find a coffee shop?") + + Args: + query_sentence (str): The query sentence. + """ + + self._current_location_map.update_position(self._latest_location) # type: ignore[arg-type] + location = self._current_location_map.query_for_one_position_and_context( + query_sentence, + self._latest_location, # type: ignore[arg-type] + ) + if not location: + return "Could not find anything." + + latlon, context = location + + distance = int(distance_in_meters(latlon, self._latest_location)) # type: ignore[arg-type] + + return f"{context}. It's at position latitude={latlon.lat}, longitude={latlon.lon}. It is {distance} meters away." + + +osm_skill = OsmSkill.blueprint + +__all__ = ["OsmSkill", "osm_skill"] diff --git a/dimos/agents2/skills/speak_skill.py b/dimos/agents2/skills/speak_skill.py new file mode 100644 index 0000000000..073dda656a --- /dev/null +++ b/dimos/agents2/skills/speak_skill.py @@ -0,0 +1,104 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import threading +import time + +from reactivex import Subject + +from dimos.core.core import rpc +from dimos.core.skill_module import SkillModule +from dimos.protocol.skill.skill import skill +from dimos.stream.audio.node_output import SounddeviceAudioOutput +from dimos.stream.audio.tts.node_openai import OpenAITTSNode, Voice +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class SpeakSkill(SkillModule): + _tts_node: OpenAITTSNode | None = None + _audio_output: SounddeviceAudioOutput | None = None + _audio_lock: threading.Lock = threading.Lock() + + @rpc + def start(self) -> None: + super().start() + self._tts_node = OpenAITTSNode(speed=1.2, voice=Voice.ONYX) + self._audio_output = SounddeviceAudioOutput(sample_rate=24000) + self._audio_output.consume_audio(self._tts_node.emit_audio()) + + @rpc + def stop(self) -> None: + if self._tts_node: + self._tts_node.dispose() + self._tts_node = None + if self._audio_output: + self._audio_output.stop() + self._audio_output = None + super().stop() + + @skill() + def speak(self, text: str) -> str: + """Speak text out loud through the robot's speakers. + + USE THIS TOOL AS OFTEN AS NEEDED. People can't normally see what you say in text, but can hear what you speak. + + Try to be as concise as possible. Remember that speaking takes time, so get to the point quickly. + + Example usage: + + speak("Hello, I am your robot assistant.") + """ + if self._tts_node is None: + return "Error: TTS not initialized" + + # Use lock to prevent simultaneous speech + with self._audio_lock: + text_subject: Subject[str] = Subject() + audio_complete = threading.Event() + self._tts_node.consume_text(text_subject) + + def set_as_complete(_t: str) -> None: + audio_complete.set() + + def set_as_complete_e(_e: Exception) -> None: + audio_complete.set() + + subscription = self._tts_node.emit_text().subscribe( + on_next=set_as_complete, + on_error=set_as_complete_e, + ) + + text_subject.on_next(text) + text_subject.on_completed() + + timeout = max(5, len(text) * 0.1) + + if not audio_complete.wait(timeout=timeout): + logger.warning(f"TTS timeout reached for: {text}") + subscription.dispose() + return f"Warning: TTS timeout while speaking: {text}" + else: + # Small delay to ensure buffers flush + time.sleep(0.3) + + subscription.dispose() + + return f"Spoke: {text}" + + +speak_skill = SpeakSkill.blueprint + +__all__ = ["SpeakSkill", "speak_skill"] diff --git a/dimos/agents2/skills/test_google_maps_skill_container.py b/dimos/agents2/skills/test_google_maps_skill_container.py new file mode 100644 index 0000000000..0af206fbb1 --- /dev/null +++ b/dimos/agents2/skills/test_google_maps_skill_container.py @@ -0,0 +1,47 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + +from dimos.mapping.google_maps.types import Coordinates, LocationContext, Position +from dimos.mapping.types import LatLon + + +def test_where_am_i(create_google_maps_agent, google_maps_skill_container) -> None: + google_maps_skill_container._latest_location = LatLon(lat=37.782654, lon=-122.413273) + google_maps_skill_container._client.get_location_context.return_value = LocationContext( + street="Bourbon Street", coordinates=Coordinates(lat=37.782654, lon=-122.413273) + ) + agent = create_google_maps_agent(fixture="test_where_am_i.json") + + response = agent.query("what street am I on") + + assert "bourbon" in response.lower() + + +def test_get_gps_position_for_queries( + create_google_maps_agent, google_maps_skill_container +) -> None: + google_maps_skill_container._latest_location = LatLon(lat=37.782654, lon=-122.413273) + google_maps_skill_container._client.get_position.side_effect = [ + Position(lat=37.782601, lon=-122.413201, description="address 1"), + Position(lat=37.782602, lon=-122.413202, description="address 2"), + Position(lat=37.782603, lon=-122.413203, description="address 3"), + ] + agent = create_google_maps_agent(fixture="test_get_gps_position_for_queries.json") + + response = agent.query("what are the lat/lon for hyde park, regent park, russell park?") + + regex = r".*37\.782601.*122\.413201.*37\.782602.*122\.413202.*37\.782603.*122\.413203.*" + assert re.match(regex, response, re.DOTALL) diff --git a/dimos/agents2/skills/test_gps_nav_skills.py b/dimos/agents2/skills/test_gps_nav_skills.py new file mode 100644 index 0000000000..ab0d1ec318 --- /dev/null +++ b/dimos/agents2/skills/test_gps_nav_skills.py @@ -0,0 +1,58 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dimos.mapping.types import LatLon + + +def test_set_gps_travel_points(create_gps_nav_agent, gps_nav_skill_container, mocker) -> None: + gps_nav_skill_container._latest_location = LatLon(lat=37.782654, lon=-122.413273) + gps_nav_skill_container._set_gps_travel_goal_points = mocker.Mock() + agent = create_gps_nav_agent(fixture="test_set_gps_travel_points.json") + + agent.query("go to lat: 37.782654, lon: -122.413273") + + gps_nav_skill_container._set_gps_travel_goal_points.assert_called_once_with( + [LatLon(lat=37.782654, lon=-122.413273)] + ) + gps_nav_skill_container.gps_goal.publish.assert_called_once_with( + [LatLon(lat=37.782654, lon=-122.413273)] + ) + + +def test_set_gps_travel_points_multiple( + create_gps_nav_agent, gps_nav_skill_container, mocker +) -> None: + gps_nav_skill_container._latest_location = LatLon(lat=37.782654, lon=-122.413273) + gps_nav_skill_container._set_gps_travel_goal_points = mocker.Mock() + agent = create_gps_nav_agent(fixture="test_set_gps_travel_points_multiple.json") + + agent.query( + "go to lat: 37.782654, lon: -122.413273, then 37.782660,-122.413260, and then 37.782670,-122.413270" + ) + + gps_nav_skill_container._set_gps_travel_goal_points.assert_called_once_with( + [ + LatLon(lat=37.782654, lon=-122.413273), + LatLon(lat=37.782660, lon=-122.413260), + LatLon(lat=37.782670, lon=-122.413270), + ] + ) + gps_nav_skill_container.gps_goal.publish.assert_called_once_with( + [ + LatLon(lat=37.782654, lon=-122.413273), + LatLon(lat=37.782660, lon=-122.413260), + LatLon(lat=37.782670, lon=-122.413270), + ] + ) diff --git a/dimos/agents2/skills/test_navigation.py b/dimos/agents2/skills/test_navigation.py new file mode 100644 index 0000000000..e823fe8a86 --- /dev/null +++ b/dimos/agents2/skills/test_navigation.py @@ -0,0 +1,94 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dimos.msgs.geometry_msgs import PoseStamped, Vector3 +from dimos.utils.transform_utils import euler_to_quaternion + + +# @pytest.mark.skip +def test_stop_movement(create_navigation_agent, navigation_skill_container, mocker) -> None: + cancel_goal_mock = mocker.Mock() + stop_exploration_mock = mocker.Mock() + navigation_skill_container._bound_rpc_calls["NavigationInterface.cancel_goal"] = ( + cancel_goal_mock + ) + navigation_skill_container._bound_rpc_calls["WavefrontFrontierExplorer.stop_exploration"] = ( + stop_exploration_mock + ) + agent = create_navigation_agent(fixture="test_stop_movement.json") + + agent.query("stop") + + cancel_goal_mock.assert_called_once_with() + stop_exploration_mock.assert_called_once_with() + + +def test_take_a_look_around(create_navigation_agent, navigation_skill_container, mocker) -> None: + explore_mock = mocker.Mock() + is_exploration_active_mock = mocker.Mock() + navigation_skill_container._bound_rpc_calls["WavefrontFrontierExplorer.explore"] = explore_mock + navigation_skill_container._bound_rpc_calls[ + "WavefrontFrontierExplorer.is_exploration_active" + ] = is_exploration_active_mock + mocker.patch("dimos.agents2.skills.navigation.time.sleep") + agent = create_navigation_agent(fixture="test_take_a_look_around.json") + + agent.query("take a look around for 10 seconds") + + explore_mock.assert_called_once_with() + + +def test_go_to_semantic_location( + create_navigation_agent, navigation_skill_container, mocker +) -> None: + mocker.patch( + "dimos.agents2.skills.navigation.NavigationSkillContainer._navigate_by_tagged_location", + return_value=None, + ) + mocker.patch( + "dimos.agents2.skills.navigation.NavigationSkillContainer._navigate_to_object", + return_value=None, + ) + navigate_to_mock = mocker.patch( + "dimos.agents2.skills.navigation.NavigationSkillContainer._navigate_to", + return_value=True, + ) + query_by_text_mock = mocker.Mock( + return_value=[ + { + "distance": 0.5, + "metadata": [ + { + "pos_x": 1, + "pos_y": 2, + "rot_z": 3, + } + ], + } + ] + ) + navigation_skill_container._bound_rpc_calls["SpatialMemory.query_by_text"] = query_by_text_mock + agent = create_navigation_agent(fixture="test_go_to_semantic_location.json") + + agent.query("go to the bookshelf") + + query_by_text_mock.assert_called_once_with("bookshelf") + navigate_to_mock.assert_called_once_with( + PoseStamped( + position=Vector3(1, 2, 0), + orientation=euler_to_quaternion(Vector3(0, 0, 3)), + frame_id="world", + ), + ) diff --git a/dimos/agents2/skills/test_unitree_skill_container.py b/dimos/agents2/skills/test_unitree_skill_container.py new file mode 100644 index 0000000000..16088875c5 --- /dev/null +++ b/dimos/agents2/skills/test_unitree_skill_container.py @@ -0,0 +1,42 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def test_pounce(create_unitree_skills_agent, unitree_skills) -> None: + agent = create_unitree_skills_agent(fixture="test_pounce.json") + + response = agent.query("pounce") + + assert "front pounce" in response.lower() + unitree_skills._publish_request.assert_called_once_with( + "rt/api/sport/request", {"api_id": 1032} + ) + + +def test_show_your_love(create_unitree_skills_agent, unitree_skills) -> None: + agent = create_unitree_skills_agent(fixture="test_show_your_love.json") + + response = agent.query("show your love") + + assert "finger heart" in response.lower() + unitree_skills._publish_request.assert_called_once_with( + "rt/api/sport/request", {"api_id": 1036} + ) + + +def test_did_you_mean(unitree_skills) -> None: + assert ( + unitree_skills.execute_sport_command("Pounce") + == "There's no 'Pounce' command. Did you mean: ['FrontPounce', 'Pose']" + ) diff --git a/dimos/agents2/spec.py b/dimos/agents2/spec.py new file mode 100644 index 0000000000..0587cf040d --- /dev/null +++ b/dimos/agents2/spec.py @@ -0,0 +1,540 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Agent specification layer for the agents2 system. + +This module provides the abstract base class that all LLM-based agents in DimOS +must extend. The specification pattern separates "what agents must do" (history +management, query processing) from "how agents do it" (LLM provider, tool +calling implementation), enabling multiple agent implementations to share common +infrastructure. + +When to use this module +----------------------- +**For most users**: Use `llm_agent()` from `dimos.agents2.agent` directly. +This module is primarily relevant when: + +- Creating a custom agent implementation with different LLM backends +- Extending the agent system with new capabilities +- Understanding the agent architecture for debugging + +The `AgentSpec` abstract base class defines required methods (`clear_history`, +`append_history`, `history`, `query`) while providing shared infrastructure +(message transport, lifecycle management, conversation display). Concrete +implementations like `Agent` in `dimos.agents2.agent` handle LLM interaction, +tool calling, and neurosymbolic orchestration. + +Core classes +------------ +AgentSpec + Abstract base class defining required methods (history management, query + interface) and providing shared infrastructure (transport, lifecycle, + display). + +AgentConfig + Configuration dataclass specifying system prompt, model selection, skills, + and message transport settings. + +Enums +----- +Provider + Dynamically generated enum of LLM providers (OPENAI, ANTHROPIC, etc.) based + on LangChain's supported providers. + +Model + Common model identifiers across providers (GPT_4O, CLAUDE_35_SONNET, etc.). + +Type aliases +------------ +AnyMessage + Union of LangChain message types: SystemMessage, ToolMessage, AIMessage, + HumanMessage. + +See also +-------- +dimos.agents2.agent : Concrete Agent implementation and llm_agent() blueprint +dimos.protocol.service : Service base class for lifecycle management +dimos.core.module : Module base class for distributed execution +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from enum import Enum +from typing import Annotated, Any, Union + +from annotated_doc import Doc +from langchain.chat_models.base import _SUPPORTED_PROVIDERS +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.messages import ( + AIMessage, + HumanMessage, + SystemMessage, + ToolMessage, +) +from rich.console import Console +from rich.table import Table +from rich.text import Text + +from dimos.core import Module, rpc +from dimos.core.module import ModuleConfig +from dimos.protocol.pubsub import PubSub, lcm # type: ignore[attr-defined] +from dimos.protocol.service import Service # type: ignore[attr-defined] +from dimos.protocol.skill.skill import SkillContainer +from dimos.utils.generic import truncate_display_string +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +# Dynamically create ModelProvider enum from LangChain's supported providers +_providers = {provider.upper(): provider for provider in _SUPPORTED_PROVIDERS} +Provider = Enum("Provider", _providers, type=str) # type: ignore[misc] + + +class Model(str, Enum): + """Common model names across providers. + + Note: This is not exhaustive as model names change frequently. + Based on langchain's _attempt_infer_model_provider patterns. + """ + + # OpenAI models (prefix: gpt-3, gpt-4, o1, o3) + GPT_4O = "gpt-4o" + GPT_4O_MINI = "gpt-4o-mini" + GPT_4_TURBO = "gpt-4-turbo" + GPT_4_TURBO_PREVIEW = "gpt-4-turbo-preview" + GPT_4 = "gpt-4" + GPT_35_TURBO = "gpt-3.5-turbo" + GPT_35_TURBO_16K = "gpt-3.5-turbo-16k" + O1_PREVIEW = "o1-preview" + O1_MINI = "o1-mini" + O3_MINI = "o3-mini" + + # Anthropic models (prefix: claude) + CLAUDE_3_OPUS = "claude-3-opus-20240229" + CLAUDE_3_SONNET = "claude-3-sonnet-20240229" + CLAUDE_3_HAIKU = "claude-3-haiku-20240307" + CLAUDE_35_SONNET = "claude-3-5-sonnet-20241022" + CLAUDE_35_SONNET_LATEST = "claude-3-5-sonnet-latest" + CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219" + + # Google models (prefix: gemini) + GEMINI_20_FLASH = "gemini-2.0-flash" + GEMINI_15_PRO = "gemini-1.5-pro" + GEMINI_15_FLASH = "gemini-1.5-flash" + GEMINI_10_PRO = "gemini-1.0-pro" + + # Amazon Bedrock models (prefix: amazon) + AMAZON_TITAN_EXPRESS = "amazon.titan-text-express-v1" + AMAZON_TITAN_LITE = "amazon.titan-text-lite-v1" + + # Cohere models (prefix: command) + COMMAND_R_PLUS = "command-r-plus" + COMMAND_R = "command-r" + COMMAND = "command" + COMMAND_LIGHT = "command-light" + + # Fireworks models (prefix: accounts/fireworks) + FIREWORKS_LLAMA_V3_70B = "accounts/fireworks/models/llama-v3-70b-instruct" + FIREWORKS_MIXTRAL_8X7B = "accounts/fireworks/models/mixtral-8x7b-instruct" + + # Mistral models (prefix: mistral) + MISTRAL_LARGE = "mistral-large" + MISTRAL_MEDIUM = "mistral-medium" + MISTRAL_SMALL = "mistral-small" + MIXTRAL_8X7B = "mixtral-8x7b" + MIXTRAL_8X22B = "mixtral-8x22b" + MISTRAL_7B = "mistral-7b" + + # DeepSeek models (prefix: deepseek) + DEEPSEEK_CHAT = "deepseek-chat" + DEEPSEEK_CODER = "deepseek-coder" + DEEPSEEK_R1_DISTILL_LLAMA_70B = "deepseek-r1-distill-llama-70b" + + # xAI models (prefix: grok) + GROK_1 = "grok-1" + GROK_2 = "grok-2" + + # Perplexity models (prefix: sonar) + SONAR_SMALL_CHAT = "sonar-small-chat" + SONAR_MEDIUM_CHAT = "sonar-medium-chat" + SONAR_LARGE_CHAT = "sonar-large-chat" + + # Meta Llama models (various providers) + LLAMA_3_70B = "llama-3-70b" + LLAMA_3_8B = "llama-3-8b" + LLAMA_31_70B = "llama-3.1-70b" + LLAMA_31_8B = "llama-3.1-8b" + LLAMA_33_70B = "llama-3.3-70b" + LLAMA_2_70B = "llama-2-70b" + LLAMA_2_13B = "llama-2-13b" + LLAMA_2_7B = "llama-2-7b" + + +@dataclass +class AgentConfig(ModuleConfig): + """Configuration for agent instances specifying model, prompt, and transport. + + Notes: + Either use (`model`, `provider`) or `model_instance`, not both. When + `model_instance` is provided, the other two are ignored. + + Set `agent_transport` to None to disable message publishing. + + Examples: + Basic configuration with string prompt: + + >>> config = AgentConfig( + ... system_prompt="You are a helpful robot assistant.", + ... model=Model.GPT_4O_MINI, + ... provider=Provider.OPENAI + ... ) + + Using a mock model instance for testing: + + >>> from dimos.agents2.testing import MockModel + >>> from langchain_core.messages import AIMessage + >>> mock = MockModel(responses=[AIMessage(content="Test response")]) + >>> config = AgentConfig( + ... system_prompt="You are a helpful robot assistant.", + ... model_instance=mock + ... ) + + Disabling message transport: + + >>> config = AgentConfig( + ... system_prompt="Test agent", + ... agent_transport=None + ... ) + """ + + system_prompt: Annotated[ + str | SystemMessage | None, + Doc( + """Initial system instructions for the LLM. Can be provided as a string or + pre-constructed SystemMessage. If None, a default prompt is used from + `dimos.agents2.system_prompt.get_system_prompt()`.""" + ), + ] = None + + skills: Annotated[ + SkillContainer | list[SkillContainer] | None, + Doc( + """Pre-registered skill containers. Currently unused by the agent system. + Skills are typically registered via the skill coordinator after initialization.""" + ), + ] = None + + model: Annotated[Model, Doc("Which LLM model to use (e.g., GPT_4O, CLAUDE_35_SONNET).")] = ( + Model.GPT_4O + ) + + provider: Annotated[ + Provider, Doc("Which LLM provider hosts the model (e.g., OPENAI, ANTHROPIC).") + ] = Provider.OPENAI # type: ignore[attr-defined] + + model_instance: Annotated[ + BaseChatModel | None, + Doc( + """Direct LangChain chat model instance. When provided, overrides `model` and + `provider`. Useful for testing with mock models.""" + ), + ] = None + + agent_transport: Annotated[ + type[PubSub], + Doc( + """Transport class for publishing agent messages. Must be a PubSub subclass that + can be instantiated with no arguments. Used for observability (e.g., by the + `agentspy` CLI tool).""" + ), + ] = lcm.PickleLCM # type: ignore[type-arg] + + agent_topic: Annotated[Any, Doc("Topic identifier for agent message publishing.")] = field( + default_factory=lambda: lcm.Topic("/agent") + ) + + +AnyMessage = Union[SystemMessage, ToolMessage, AIMessage, HumanMessage] +"""Union of LangChain message types returned by `AgentSpec.history()`. + +Represents the four message types that can appear in an agent's conversation +history. Users typically encounter this type when inspecting message history +or implementing custom conversation processing logic. + +Message types +------------- +SystemMessage + Initial agent instructions loaded from `AgentConfig.system_prompt`. Always + appears as the first message (`history()[0]`). + +HumanMessage + User queries submitted via `query()`, or outputs from skills configured with + `Output.human`. May contain text or multimodal content (images, documents). + +AIMessage + LLM-generated responses. May include `tool_calls` requesting skill execution, + or represent transient state awareness messages tracking long-running skills. + +ToolMessage + Results from executed skills, linked to AIMessage tool calls via `call_id`. + The `content` field contains the skill's return value. + +Working with message history +----------------------------- +Use `isinstance()` to distinguish message types when processing conversation history: + +```pycon +>>> agent = llm_agent(...) +>>> for msg in agent.history(): +... if isinstance(msg, HumanMessage): +... print(f"User: {msg.content}") +... elif isinstance(msg, AIMessage): +... print(f"Agent: {msg.content}") +``` + +All message types are from `langchain_core.messages`. +""" + + +class AgentSpec(Service[AgentConfig], Module, ABC): + """Abstract specification for LLM-based agents in DimOS. + + Defines the interface contract that all agents must implement while providing + common infrastructure for message transport, lifecycle management, and conversation + display. Agents bridge high-level reasoning (LLMs) with low-level robot actions + (skills) through a neurosymbolic orchestration pattern. + + This class is abstract. Concrete implementations must provide: + + - History management (`clear_history`, `append_history`, `history`) + - Query processing (`query`) + + Concrete implementations receive: + + - Message publishing infrastructure (`publish`) + - Lifecycle coordination (`start`, `stop`) + - Rich conversation display (`__str__`) + - Transport initialization + + Inheritance: + Inherits from `Service[AgentConfig]` (configuration and lifecycle), + `Module` (distributed execution and RPC), and `ABC` (abstract base class marker). + + Attributes: + config (AgentConfig): Configuration instance created from `default_config` + and constructor kwargs. + transport (PubSub | None): Message transport for observability, initialized + from `config.agent_transport` if provided. + + Notes: + Concrete implementations must set `self._agent_id` (string identifier) for + the `__str__` method to function correctly. See `dimos.agents2.agent.Agent` + for the reference implementation. + + The `__init__` explicitly calls both `Service.__init__` and `Module.__init__` + because multiple inheritance would otherwise skip one initialization path. + + See also: + Agent: Concrete implementation in `dimos.agents2.agent` + AgentConfig: Configuration dataclass + Service: Base class for lifecycle management (`dimos.protocol.service`) + Module: Base class for distributed execution (`dimos.core.module`) + """ + + default_config: type[AgentConfig] = AgentConfig + + def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + """Initialize agent with configuration and transport.""" + Service.__init__(self, *args, **kwargs) + Module.__init__(self, *args, **kwargs) + + if self.config.agent_transport: + self.transport = self.config.agent_transport() + + def publish( + self, + msg: Annotated[ + AnyMessage, + Doc("Message to publish (SystemMessage, HumanMessage, AIMessage, or ToolMessage)."), + ], + ) -> None: + """Publish message to transport for observability. + + Used by concrete implementations to broadcast conversation messages to + external monitoring tools like `agentspy`. Fire-and-forget semantics: if + transport is None, the message is silently dropped. + """ + if self.transport: + self.transport.publish(self.config.agent_topic, msg) + + def start(self) -> None: + """Start agent lifecycle, delegating to Service and Module initialization.""" + super().start() + + def stop(self) -> None: + """Stop agent lifecycle, cleaning up transport and delegating to parent classes.""" + if hasattr(self, "transport") and self.transport: + self.transport.stop() # type: ignore[attr-defined] + self.transport = None # type: ignore[assignment] + super().stop() + + @rpc + @abstractmethod + def clear_history(self): # type: ignore[no-untyped-def] + """Clear persistent conversation history. + + Removes all accumulated conversation messages while preserving the system + message. Transient state messages (managed by `agent_loop()`) are unaffected + and will still appear in `history()`. Allows resetting conversation context + while maintaining agent identity. + """ + ... + + @abstractmethod + def append_history( + self, + *msgs: Annotated[ + AIMessage | HumanMessage, + Doc( + "Variable number of AIMessage or HumanMessage instances to append to the conversation history." + ), + ], + ) -> None: + """Add messages to conversation history. + + Implementations must extend the history with provided messages in order + and should publish each message to the transport for observability. + """ + ... + + @abstractmethod + def history( + self, + ) -> Annotated[ + list[AnyMessage], + Doc( + """List of all messages in the conversation, starting with the system + message. May include HumanMessage, AIMessage, ToolMessage, and + transient state awareness messages.""" + ), + ]: + """Return complete message history including transient state messages. + + Implementations must return a list where `history()[0]` is always the + SystemMessage (initial prompt), followed by messages in chronological + order. May include transient state messages representing skill execution. + """ + ... + + @rpc + @abstractmethod + def query( + self, + query: Annotated[str, Doc("User query string to process.")], + ) -> Annotated[ + str | None, + Doc("Final agent response as string, or None if no response is generated."), + ]: + """Process user query through agent reasoning loop. + + Implementations must append query as HumanMessage to history, execute + the agent loop until completion, process any tool calls via skill + coordinator, and return the final response. + + Notes: + This method typically blocks until the agent completes reasoning and + any invoked skills finish execution. + """ + ... + + def __str__( + self, + ) -> Annotated[ + str, + Doc( + """Formatted string containing agent ID header and colorized conversation + table suitable for terminal display.""" + ), + ]: + """Render conversation history as formatted, colorized table, with color-coded message types. + + Notes: + Requires `self._agent_id` to be set by concrete implementations. + + Message styling: + + - HumanMessage: Green text + - AIMessage: Magenta text (blue for state summaries) + - ToolMessage: Red text + - SystemMessage: Yellow text, truncated to 800 characters + + Tool calls within AIMessage are displayed as separate rows showing + the function name and arguments. + + Images in HumanMessage content are displayed as the placeholder text + "" rather than attempting to render binary data. + + See also: + history: Provides the messages rendered by this method. + """ + console = Console(force_terminal=True, legacy_windows=False) + table = Table(show_header=True) + + table.add_column("Message Type", style="cyan", no_wrap=True) + table.add_column("Content") + + for message in self.history(): + if isinstance(message, HumanMessage): + content = message.content + if not isinstance(content, str): + content = "" + + table.add_row(Text("Human", style="green"), Text(content, style="green")) + elif isinstance(message, AIMessage): + if hasattr(message, "metadata") and message.metadata.get("state"): + table.add_row( + Text("State Summary", style="blue"), + Text(message.content, style="blue"), # type: ignore[arg-type] + ) + else: + table.add_row( + Text("Agent", style="magenta"), + Text(message.content, style="magenta"), # type: ignore[arg-type] + ) + + for tool_call in message.tool_calls: + table.add_row( + "Tool Call", + Text( + f"{tool_call.get('name')}({tool_call.get('args')})", + style="bold magenta", + ), + ) + elif isinstance(message, ToolMessage): + table.add_row( + "Tool Response", Text(f"{message.name}() -> {message.content}"), style="red" + ) + elif isinstance(message, SystemMessage): + table.add_row( + "System", Text(truncate_display_string(message.content, 800), style="yellow") + ) + else: + table.add_row("Unknown", str(message)) + + # Render to string with title above + with console.capture() as capture: + console.print(Text(f" Agent ({self._agent_id})", style="bold blue")) # type: ignore[attr-defined] + console.print(table) + return capture.get().strip() diff --git a/dimos/agents2/system_prompt.py b/dimos/agents2/system_prompt.py new file mode 100644 index 0000000000..94b35c6ac9 --- /dev/null +++ b/dimos/agents2/system_prompt.py @@ -0,0 +1,25 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.agents2.constants import AGENT_SYSTEM_PROMPT_PATH + +_SYSTEM_PROMPT = None + + +def get_system_prompt() -> str: + global _SYSTEM_PROMPT + if _SYSTEM_PROMPT is None: + with open(AGENT_SYSTEM_PROMPT_PATH) as f: + _SYSTEM_PROMPT = f.read() + return _SYSTEM_PROMPT diff --git a/dimos/agents2/temp/webcam_agent.py b/dimos/agents2/temp/webcam_agent.py new file mode 100644 index 0000000000..5bcd964e2d --- /dev/null +++ b/dimos/agents2/temp/webcam_agent.py @@ -0,0 +1,151 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Run script for Unitree Go2 robot with agents2 framework. +This is the migrated version using the new LangChain-based agent system. +""" + +from threading import Thread +import time + +import reactivex as rx +import reactivex.operators as ops + +from dimos.agents2 import Agent, Output, Reducer, Stream, skill # type: ignore[attr-defined] +from dimos.agents2.cli.human import HumanInput +from dimos.agents2.spec import Model, Provider +from dimos.core import LCMTransport, Module, rpc, start +from dimos.hardware.camera import zed +from dimos.hardware.camera.module import CameraModule +from dimos.hardware.camera.webcam import Webcam +from dimos.msgs.geometry_msgs import Quaternion, Transform, Vector3 +from dimos.msgs.sensor_msgs import CameraInfo, Image +from dimos.protocol.skill.test_coordinator import SkillContainerTest +from dimos.web.robot_web_interface import RobotWebInterface + + +class WebModule(Module): + web_interface: RobotWebInterface = None # type: ignore[assignment] + human_query: rx.subject.Subject = None # type: ignore[assignment, type-arg] + agent_response: rx.subject.Subject = None # type: ignore[assignment, type-arg] + + thread: Thread = None # type: ignore[assignment] + + _human_messages_running = False + + def __init__(self) -> None: + super().__init__() + self.agent_response = rx.subject.Subject() + self.human_query = rx.subject.Subject() + + @rpc + def start(self) -> None: + super().start() + + text_streams = { + "agent_responses": self.agent_response, + } + + self.web_interface = RobotWebInterface( + port=5555, + text_streams=text_streams, + audio_subject=rx.subject.Subject(), + ) + + unsub = self.web_interface.query_stream.subscribe(self.human_query.on_next) + self._disposables.add(unsub) + + self.thread = Thread(target=self.web_interface.run, daemon=True) + self.thread.start() + + @rpc + def stop(self) -> None: + if self.web_interface: + self.web_interface.stop() # type: ignore[attr-defined] + if self.thread: + # TODO, you can't just wait for a server to close, you have to signal it to end. + self.thread.join(timeout=1.0) + + super().stop() + + @skill(stream=Stream.call_agent, reducer=Reducer.all, output=Output.human) # type: ignore[arg-type] + def human_messages(self): # type: ignore[no-untyped-def] + """Provide human messages from web interface. Don't use this tool, it's running implicitly already""" + if self._human_messages_running: + print("human_messages already running, not starting another") + return "already running" + self._human_messages_running = True + while True: + print("Waiting for human message...") + message = self.human_query.pipe(ops.first()).run() + print(f"Got human message: {message}") + yield message + + +def main() -> None: + dimos = start(4) + # Create agent + agent = Agent( + system_prompt="You are a helpful assistant for controlling a Unitree Go2 robot. ", + model=Model.GPT_4O, # Could add CLAUDE models to enum + provider=Provider.OPENAI, # type: ignore[attr-defined] # Would need ANTHROPIC provider + ) + + testcontainer = dimos.deploy(SkillContainerTest) # type: ignore[attr-defined] + webcam = dimos.deploy( # type: ignore[attr-defined] + CameraModule, + transform=Transform( + translation=Vector3(0.0, 0.0, 0.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id="base_link", + child_frame_id="camera_link", + ), + hardware=lambda: Webcam( + camera_index=0, + frequency=15, + stereo_slice="left", + camera_info=zed.CameraInfo.SingleWebcam, + ), + ) + + webcam.camera_info.transport = LCMTransport("/camera_info", CameraInfo) + + webcam.image.transport = LCMTransport("/image", Image) + + webcam.start() + + human_input = dimos.deploy(HumanInput) # type: ignore[attr-defined] + + time.sleep(1) + + agent.register_skills(human_input) + agent.register_skills(webcam) + agent.register_skills(testcontainer) + + agent.run_implicit_skill("video_stream") + agent.run_implicit_skill("human") + + agent.start() + agent.loop_thread() + + while True: + time.sleep(1) + + # webcam.stop() + + +if __name__ == "__main__": + main() diff --git a/dimos/agents2/test_agent.py b/dimos/agents2/test_agent.py new file mode 100644 index 0000000000..9e9d5bab18 --- /dev/null +++ b/dimos/agents2/test_agent.py @@ -0,0 +1,169 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import pytest_asyncio + +from dimos.agents2.agent import Agent +from dimos.core import start +from dimos.protocol.skill.test_coordinator import SkillContainerTest + +system_prompt = ( + "Your name is Mr. Potato, potatoes are bad at math. Use a tools if asked to calculate" +) + + +@pytest.fixture(scope="session") +def dimos_cluster(): + """Session-scoped fixture to initialize dimos cluster once.""" + dimos = start(2) + try: + yield dimos + finally: + dimos.shutdown() + + +@pytest_asyncio.fixture +async def local(): + """Local context: both agent and testcontainer run locally""" + testcontainer = SkillContainerTest() + agent = Agent(system_prompt=system_prompt) + try: + yield agent, testcontainer + except Exception as e: + print(f"Error: {e}") + import traceback + + traceback.print_exc() + raise e + finally: + # Ensure cleanup happens while event loop is still active + try: + agent.stop() + except Exception: + pass + try: + testcontainer.stop() + except Exception: + pass + + +@pytest_asyncio.fixture +async def dask_mixed(dimos_cluster): + """Dask context: testcontainer on dimos, agent local""" + testcontainer = dimos_cluster.deploy(SkillContainerTest) + agent = Agent(system_prompt=system_prompt) + try: + yield agent, testcontainer + finally: + try: + agent.stop() + except Exception: + pass + try: + testcontainer.stop() + except Exception: + pass + + +@pytest_asyncio.fixture +async def dask_full(dimos_cluster): + """Dask context: both agent and testcontainer deployed on dimos""" + testcontainer = dimos_cluster.deploy(SkillContainerTest) + agent = dimos_cluster.deploy(Agent, system_prompt=system_prompt) + try: + yield agent, testcontainer + finally: + try: + agent.stop() + except Exception: + pass + try: + testcontainer.stop() + except Exception: + pass + + +@pytest_asyncio.fixture(params=["local", "dask_mixed", "dask_full"]) +async def agent_context(request): + """Parametrized fixture that runs tests with different agent configurations""" + param = request.param + + if param == "local": + testcontainer = SkillContainerTest() + agent = Agent(system_prompt=system_prompt) + try: + yield agent, testcontainer + finally: + try: + agent.stop() + except Exception: + pass + try: + testcontainer.stop() + except Exception: + pass + elif param == "dask_mixed": + dimos_cluster = request.getfixturevalue("dimos_cluster") + testcontainer = dimos_cluster.deploy(SkillContainerTest) + agent = Agent(system_prompt=system_prompt) + try: + yield agent, testcontainer + finally: + try: + agent.stop() + except Exception: + pass + try: + testcontainer.stop() + except Exception: + pass + elif param == "dask_full": + dimos_cluster = request.getfixturevalue("dimos_cluster") + testcontainer = dimos_cluster.deploy(SkillContainerTest) + agent = dimos_cluster.deploy(Agent, system_prompt=system_prompt) + try: + yield agent, testcontainer + finally: + try: + agent.stop() + except Exception: + pass + try: + testcontainer.stop() + except Exception: + pass + + +# @pytest.mark.timeout(40) +@pytest.mark.tool +@pytest.mark.asyncio +async def test_agent_init(agent_context) -> None: + """Test agent initialization and basic functionality across different configurations""" + agent, testcontainer = agent_context + + agent.register_skills(testcontainer) + agent.start() + + # agent.run_implicit_skill("uptime_seconds") + + print("query agent") + # When running locally, call the async method directly + agent.query( + "hi there, please tell me what's your name and current date, and how much is 124181112 + 124124?" + ) + print("Agent loop finished, asking about camera") + agent.query("tell me what you see on the camera?") + + # you can run skillspy and agentspy in parallel with this test for a better observation of what's happening diff --git a/dimos/agents2/test_agent_direct.py b/dimos/agents2/test_agent_direct.py new file mode 100644 index 0000000000..97c509e386 --- /dev/null +++ b/dimos/agents2/test_agent_direct.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python3 + +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from contextlib import contextmanager + +from dimos.agents2.agent import Agent +from dimos.core import start +from dimos.protocol.skill.test_coordinator import SkillContainerTest + +system_prompt = ( + "Your name is Mr. Potato, potatoes are bad at math. Use a tools if asked to calculate" +) + + +@contextmanager +def dimos_cluster(): + dimos = start(2) + try: + yield dimos + finally: + dimos.close_all() + + +@contextmanager +def local(): + """Local context: both agent and testcontainer run locally""" + testcontainer = SkillContainerTest() + agent = Agent(system_prompt=system_prompt) + try: + yield agent, testcontainer + except Exception as e: + print(f"Error: {e}") + import traceback + + traceback.print_exc() + raise e + finally: + # Ensure cleanup happens while event loop is still active + agent.stop() + testcontainer.stop() + + +@contextmanager +def partial(): + """Dask context: testcontainer on dimos, agent local""" + with dimos_cluster() as dimos: + testcontainer = dimos.deploy(SkillContainerTest) + agent = Agent(system_prompt=system_prompt) + try: + yield agent, testcontainer + finally: + agent.stop() + testcontainer.stop() + + +@contextmanager +def full(): + """Dask context: both agent and testcontainer deployed on dimos""" + with dimos_cluster() as dimos: + testcontainer = dimos.deploy(SkillContainerTest) + agent = dimos.deploy(Agent, system_prompt=system_prompt) + try: + yield agent, testcontainer + finally: + agent.stop() + testcontainer.stop() + + +def check_agent(agent_context) -> None: + """Test agent initialization and basic functionality across different configurations""" + with agent_context() as [agent, testcontainer]: + agent.register_skills(testcontainer) + agent.start() + + print("query agent") + + agent.query( + "hi there, please tell me what's your name and current date, and how much is 124181112 + 124124?" + ) + + print("Agent loop finished, asking about camera") + + agent.query("tell me what you see on the camera?") + + print("=" * 150) + print("End of test", agent.get_agent_id()) + print("=" * 150) + + # you can run skillspy and agentspy in parallel with this test for a better observation of what's happening + + +if __name__ == "__main__": + list(map(check_agent, [local, partial, full])) diff --git a/dimos/agents2/test_agent_fake.py b/dimos/agents2/test_agent_fake.py new file mode 100644 index 0000000000..367985a356 --- /dev/null +++ b/dimos/agents2/test_agent_fake.py @@ -0,0 +1,36 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def test_what_is_your_name(create_potato_agent) -> None: + agent = create_potato_agent(fixture="test_what_is_your_name.json") + response = agent.query("hi there, please tell me what's your name?") + assert "Mr. Potato" in response + + +def test_how_much_is_124181112_plus_124124(create_potato_agent) -> None: + agent = create_potato_agent(fixture="test_how_much_is_124181112_plus_124124.json") + + response = agent.query("how much is 124181112 + 124124?") + assert "124305236" in response.replace(",", "") + + response = agent.query("how much is one billion plus -1000000, in digits please") + assert "999000000" in response.replace(",", "") + + +def test_what_do_you_see_in_this_picture(create_potato_agent) -> None: + agent = create_potato_agent(fixture="test_what_do_you_see_in_this_picture.json") + + response = agent.query("take a photo and tell me what do you see") + assert "outdoor cafe " in response diff --git a/dimos/agents2/test_mock_agent.py b/dimos/agents2/test_mock_agent.py new file mode 100644 index 0000000000..a4a90f6f29 --- /dev/null +++ b/dimos/agents2/test_mock_agent.py @@ -0,0 +1,202 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test agent with FakeChatModel for unit testing.""" + +import time + +from dimos_lcm.sensor_msgs import CameraInfo +from langchain_core.messages import AIMessage, HumanMessage +import pytest + +from dimos.agents2.agent import Agent +from dimos.agents2.testing import MockModel +from dimos.core import LCMTransport, start +from dimos.msgs.geometry_msgs import PoseStamped, Vector3 +from dimos.msgs.sensor_msgs import Image +from dimos.protocol.skill.test_coordinator import SkillContainerTest +from dimos.robot.unitree_webrtc.modular.connection_module import ConnectionModule +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage + + +def test_tool_call() -> None: + """Test agent initialization and tool call execution.""" + # Create a fake model that will respond with tool calls + fake_model = MockModel( + responses=[ + AIMessage( + content="I'll add those numbers for you.", + tool_calls=[ + { + "name": "add", + "args": {"args": {"x": 5, "y": 3}}, + "id": "tool_call_1", + } + ], + ), + AIMessage(content="Let me do some math..."), + AIMessage(content="The result of adding 5 and 3 is 8."), + ] + ) + + # Create agent with the fake model + agent = Agent( + model_instance=fake_model, + system_prompt="You are a helpful robot assistant with math skills.", + ) + + # Register skills with coordinator + skills = SkillContainerTest() + agent.coordinator.register_skills(skills) + agent.start() + + # Query the agent + agent.query("Please add 5 and 3") + + # Check that tools were bound + assert fake_model.tools is not None + assert len(fake_model.tools) > 0 + + # Verify the model was called and history updated + assert len(agent._history) > 0 + + agent.stop() + + +def test_image_tool_call() -> None: + """Test agent with image tool call execution.""" + dimos = start(2) + # Create a fake model that will respond with image tool calls + fake_model = MockModel( + responses=[ + AIMessage( + content="I'll take a photo for you.", + tool_calls=[ + { + "name": "take_photo", + "args": {"args": {}}, + "id": "tool_call_image_1", + } + ], + ), + AIMessage(content="I've taken the photo. The image shows a cafe scene."), + ] + ) + + # Create agent with the fake model + agent = Agent( + model_instance=fake_model, + system_prompt="You are a helpful robot assistant with camera capabilities.", + ) + + test_skill_module = dimos.deploy(SkillContainerTest) + + agent.register_skills(test_skill_module) + agent.start() + + agent.run_implicit_skill("get_detections") + + # Query the agent + agent.query("Please take a photo") + + # Check that tools were bound + assert fake_model.tools is not None + assert len(fake_model.tools) > 0 + + # Verify the model was called and history updated + assert len(agent._history) > 0 + + # Check that image was handled specially + # Look for HumanMessage with image content in history + human_messages_with_images = [ + msg + for msg in agent._history + if isinstance(msg, HumanMessage) and msg.content and isinstance(msg.content, list) + ] + assert len(human_messages_with_images) >= 0 # May have image messages + agent.stop() + test_skill_module.stop() + dimos.close_all() + + +@pytest.mark.tool +def test_tool_call_implicit_detections() -> None: + """Test agent with image tool call execution.""" + dimos = start(2) + # Create a fake model that will respond with image tool calls + fake_model = MockModel( + responses=[ + AIMessage( + content="I'll take a photo for you.", + tool_calls=[ + { + "name": "take_photo", + "args": {"args": {}}, + "id": "tool_call_image_1", + } + ], + ), + AIMessage(content="I've taken the photo. The image shows a cafe scene."), + ] + ) + + # Create agent with the fake model + agent = Agent( + model_instance=fake_model, + system_prompt="You are a helpful robot assistant with camera capabilities.", + ) + + robot_connection = dimos.deploy(ConnectionModule, connection_type="fake") + robot_connection.lidar.transport = LCMTransport("/lidar", LidarMessage) + robot_connection.odom.transport = LCMTransport("/odom", PoseStamped) + robot_connection.video.transport = LCMTransport("/image", Image) + robot_connection.movecmd.transport = LCMTransport("/cmd_vel", Vector3) + robot_connection.camera_info.transport = LCMTransport("/camera_info", CameraInfo) + robot_connection.start() + + test_skill_module = dimos.deploy(SkillContainerTest) + + agent.register_skills(test_skill_module) + agent.start() + + agent.run_implicit_skill("get_detections") + + print( + "Robot replay pipeline is running in the background.\nWaiting 8.5 seconds for some detections before quering agent" + ) + time.sleep(8.5) + + # Query the agent + agent.query("Please take a photo") + + # Check that tools were bound + assert fake_model.tools is not None + assert len(fake_model.tools) > 0 + + # Verify the model was called and history updated + assert len(agent._history) > 0 + + # Check that image was handled specially + # Look for HumanMessage with image content in history + human_messages_with_images = [ + msg + for msg in agent._history + if isinstance(msg, HumanMessage) and msg.content and isinstance(msg.content, list) + ] + assert len(human_messages_with_images) >= 0 + + agent.stop() + test_skill_module.stop() + robot_connection.stop() + dimos.stop() diff --git a/dimos/agents2/test_stash_agent.py b/dimos/agents2/test_stash_agent.py new file mode 100644 index 0000000000..0f904c6040 --- /dev/null +++ b/dimos/agents2/test_stash_agent.py @@ -0,0 +1,61 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from dimos.agents2.agent import Agent +from dimos.protocol.skill.test_coordinator import SkillContainerTest + + +@pytest.mark.tool +@pytest.mark.asyncio +async def test_agent_init() -> None: + system_prompt = ( + "Your name is Mr. Potato, potatoes are bad at math. Use a tools if asked to calculate" + ) + + # # Uncomment the following lines to use a dimos module system + # dimos = start(2) + # testcontainer = dimos.deploy(SkillContainerTest) + # agent = Agent(system_prompt=system_prompt) + + ## uncomment the following lines to run agents in a main loop without a module system + testcontainer = SkillContainerTest() + agent = Agent(system_prompt=system_prompt) + + agent.register_skills(testcontainer) + agent.start() + + agent.run_implicit_skill("uptime_seconds") + + await agent.query_async( + "hi there, please tell me what's your name and current date, and how much is 124181112 + 124124?" + ) + + # agent loop is considered finished once no active skills remain, + # agent will stop it's loop if passive streams are active + print("Agent loop finished, asking about camera") + + # we query again (this shows subsequent querying, but we could have asked for camera image in the original query, + # it all runs in parallel, and agent might get called once or twice depending on timing of skill responses) + # await agent.query_async("tell me what you see on the camera?") + + # you can run skillspy and agentspy in parallel with this test for a better observation of what's happening + await agent.query_async("tell me exactly everything we've talked about until now") + + print("Agent loop finished") + + agent.stop() + testcontainer.stop() + dimos.stop() diff --git a/dimos/agents2/testing.py b/dimos/agents2/testing.py new file mode 100644 index 0000000000..dc563b9ea9 --- /dev/null +++ b/dimos/agents2/testing.py @@ -0,0 +1,197 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Testing utilities for agents.""" + +from collections.abc import Iterator, Sequence +import json +import os +from pathlib import Path +from typing import Any + +from langchain.chat_models import init_chat_model +from langchain_core.callbacks.manager import CallbackManagerForLLMRun +from langchain_core.language_models.chat_models import SimpleChatModel +from langchain_core.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, +) +from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult +from langchain_core.runnables import Runnable + + +class MockModel(SimpleChatModel): + """Custom fake chat model that supports tool calls for testing. + + Can operate in two modes: + 1. Playback mode (default): Reads responses from a JSON file or list + 2. Record mode: Uses a real LLM and saves responses to a JSON file + """ + + responses: list[str | AIMessage] = [] + i: int = 0 + json_path: Path | None = None + record: bool = False + real_model: Any | None = None + recorded_messages: list[dict[str, Any]] = [] + + def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] + # Extract custom parameters before calling super().__init__ + responses = kwargs.pop("responses", []) + json_path = kwargs.pop("json_path", None) + model_provider = kwargs.pop("model_provider", "openai") + model_name = kwargs.pop("model_name", "gpt-4o") + + super().__init__(**kwargs) + + self.json_path = Path(json_path) if json_path else None + self.record = bool(os.getenv("RECORD")) + self.i = 0 + self._bound_tools: Sequence[Any] | None = None + self.recorded_messages = [] + + if self.record: + # Initialize real model for recording + self.real_model = init_chat_model(model_provider=model_provider, model=model_name) + self.responses = [] # Initialize empty for record mode + elif self.json_path: + self.responses = self._load_responses_from_json() # type: ignore[assignment] + elif responses: + self.responses = responses + else: + raise ValueError("no responses") + + @property + def _llm_type(self) -> str: + return "tool-call-fake-chat-model" + + def _load_responses_from_json(self) -> list[AIMessage]: + with open(self.json_path) as f: # type: ignore[arg-type] + data = json.load(f) + + responses = [] + for item in data.get("responses", []): + if isinstance(item, str): + responses.append(AIMessage(content=item)) + else: + # Reconstruct AIMessage from dict + msg = AIMessage( + content=item.get("content", ""), tool_calls=item.get("tool_calls", []) + ) + responses.append(msg) + return responses + + def _save_responses_to_json(self) -> None: + if not self.json_path: + return + + self.json_path.parent.mkdir(parents=True, exist_ok=True) + + data = { + "responses": [ + {"content": msg.content, "tool_calls": getattr(msg, "tool_calls", [])} + if isinstance(msg, AIMessage) + else msg + for msg in self.recorded_messages + ] + } + + with open(self.json_path, "w") as f: + json.dump(data, f, indent=2, default=str) + + def _call( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> str: + """Not used in _generate.""" + return "" + + def _generate( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> ChatResult: + if self.record: + # Recording mode - use real model and save responses + if not self.real_model: + raise ValueError("Real model not initialized for recording") + + # Bind tools if needed + model = self.real_model + if self._bound_tools: + model = model.bind_tools(self._bound_tools) + + result = model.invoke(messages) + self.recorded_messages.append(result) + self._save_responses_to_json() + + generation = ChatGeneration(message=result) + return ChatResult(generations=[generation]) + else: + # Playback mode - use predefined responses + if not self.responses: + raise ValueError("No responses available for playback. ") + + if self.i >= len(self.responses): + # Don't wrap around - stay at last response + response = self.responses[-1] + else: + response = self.responses[self.i] + self.i += 1 + + if isinstance(response, AIMessage): + message = response + else: + message = AIMessage(content=str(response)) + + generation = ChatGeneration(message=message) + return ChatResult(generations=[generation]) + + def _stream( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + """Stream not implemented for testing.""" + result = self._generate(messages, stop, run_manager, **kwargs) + message = result.generations[0].message + chunk = AIMessageChunk(content=message.content) + yield ChatGenerationChunk(message=chunk) + + def bind_tools( + self, + tools: Sequence[dict[str, Any] | type | Any], + *, + tool_choice: str | None = None, + **kwargs: Any, + ) -> Runnable: # type: ignore[type-arg] + """Store tools and return self.""" + self._bound_tools = tools + if self.record and self.real_model: + # Also bind tools to the real model + self.real_model = self.real_model.bind_tools(tools, tool_choice=tool_choice, **kwargs) + return self + + @property + def tools(self) -> Sequence[Any] | None: + """Get bound tools for inspection.""" + return self._bound_tools diff --git a/dimos/conftest.py b/dimos/conftest.py new file mode 100644 index 0000000000..57632216c2 --- /dev/null +++ b/dimos/conftest.py @@ -0,0 +1,124 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import threading + +import pytest + + +@pytest.fixture +def event_loop(): + loop = asyncio.new_event_loop() + yield loop + loop.close() + + +_session_threads = set() +_seen_threads = set() +_seen_threads_lock = threading.RLock() +_before_test_threads = {} # Map test name to set of thread IDs before test + +_skip_for = ["lcm", "heavy", "ros"] + + +@pytest.fixture(scope="module") +def dimos_cluster(): + from dimos.core import start + + dimos = start(4) + try: + yield dimos + finally: + dimos.stop() + + +@pytest.hookimpl() +def pytest_sessionfinish(session): + """Track threads that exist at session start - these are not leaks.""" + + yield + + # Check for session-level thread leaks at teardown + final_threads = [ + t + for t in threading.enumerate() + if t.name != "MainThread" and t.ident not in _session_threads + ] + + if final_threads: + thread_info = [f"{t.name} (daemon={t.daemon})" for t in final_threads] + pytest.fail( + f"\n{len(final_threads)} thread(s) leaked during test session: {thread_info}\n" + "Session-scoped fixtures must clean up all threads in their teardown." + ) + + +@pytest.fixture(autouse=True) +def monitor_threads(request): + # Skip monitoring for tests marked with specified markers + if any(request.node.get_closest_marker(marker) for marker in _skip_for): + yield + return + + # Capture threads before test runs + test_name = request.node.nodeid + with _seen_threads_lock: + _before_test_threads[test_name] = { + t.ident for t in threading.enumerate() if t.ident is not None + } + + yield + + with _seen_threads_lock: + before = _before_test_threads.get(test_name, set()) + current = {t.ident for t in threading.enumerate() if t.ident is not None} + + # New threads are ones that exist now but didn't exist before this test + new_thread_ids = current - before + + if not new_thread_ids: + return + + # Get the actual thread objects for new threads + new_threads = [ + t for t in threading.enumerate() if t.ident in new_thread_ids and t.name != "MainThread" + ] + + # Filter out expected persistent threads from Dask that are shared globally + # These threads are intentionally left running and cleaned up on process exit + expected_persistent_thread_prefixes = ["Dask-Offload"] + new_threads = [ + t + for t in new_threads + if not any(t.name.startswith(prefix) for prefix in expected_persistent_thread_prefixes) + ] + + # Filter out threads we've already seen (from previous tests) + truly_new = [t for t in new_threads if t.ident not in _seen_threads] + + # Mark all new threads as seen + for t in new_threads: + if t.ident is not None: + _seen_threads.add(t.ident) + + if not truly_new: + return + + thread_names = [t.name for t in truly_new] + + pytest.fail( + f"Non-closed threads created during this test. Thread names: {thread_names}. " + "Please look at the first test that fails and fix that." + ) diff --git a/dimos/constants.py b/dimos/constants.py new file mode 100644 index 0000000000..4e74ccbe1b --- /dev/null +++ b/dimos/constants.py @@ -0,0 +1,34 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + +DIMOS_PROJECT_ROOT = Path(__file__).parent.parent + +DIMOS_LOG_DIR = DIMOS_PROJECT_ROOT / "logs" + +""" +Constants for shared memory +Usually, auto-detection for size would be preferred. Sadly, though, channels are made +and frozen *before* the first frame is received. +Therefore, a maximum capacity for color image and depth image transfer should be defined +ahead of time. +""" +# Default color image size: 1920x1080 frame x 3 (RGB) x uint8 +DEFAULT_CAPACITY_COLOR_IMAGE = 1920 * 1080 * 3 +# Default depth image size: 1280x720 frame * 4 (float32 size) +DEFAULT_CAPACITY_DEPTH_IMAGE = 1280 * 720 * 4 + +# From https://github.com/lcm-proj/lcm.git +LCM_MAX_CHANNEL_NAME_LENGTH = 63 diff --git a/dimos/core/README_BLUEPRINTS.md b/dimos/core/README_BLUEPRINTS.md new file mode 100644 index 0000000000..9c4c08c60a --- /dev/null +++ b/dimos/core/README_BLUEPRINTS.md @@ -0,0 +1,260 @@ +# Blueprints + +Blueprints (`ModuleBlueprint`) are instructions for how to initialize a `Module`. + +You don't typically want to run a single module, so multiple blueprints are handled together in `ModuleBlueprintSet`. + +You create a `ModuleBlueprintSet` from a single module (say `ConnectionModule`) with: + +```python +blueprint = create_module_blueprint(ConnectionModule, 'arg1', 'arg2', kwarg='value') +``` + +But the same thing can be acomplished more succinctly as: + +```python +connection = ConnectionModule.blueprint +``` + +Now you can create the blueprint with: + +```python +blueprint = connection('arg1', 'arg2', kwarg='value') +``` + +## Linking blueprints + +You can link multiple blueprints together with `autoconnect`: + +```python +blueprint = autoconnect( + module1(), + module2(), + module3(), +) +``` + +`blueprint` itself is a `ModuleBlueprintSet` so you can link it with other modules: + +```python +expanded_blueprint = autoconnect( + blueprint, + module4(), + module5(), +) +``` + +Blueprints are frozen data classes, and `autoconnect()` always constructs an expanded blueprint so you never have to worry about changes in one affecting the other. + +### Duplicate module handling + +If the same module appears multiple times in `autoconnect`, the **later blueprint wins** and overrides earlier ones: + +```python +blueprint = autoconnect( + module_a(arg1=1), + module_b(), + module_a(arg1=2), # This one is used, the first is discarded +) +``` + +This is so you can "inherit" from one blueprint but override something you need to change. + +## How transports are linked + +Imagine you have this code: + +```python +class ModuleA(Module): + image: Out[Image] = None + start_explore: Out[Bool] = None + +class ModuleB(Module): + image: In[Image] = None + begin_explore: In[Bool] = None + +module_a = partial(create_module_blueprint, ModuleA) +module_b = partial(create_module_blueprint, ModuleB) + +autoconnect(module_a(), module_b()) +``` + +Connections are linked based on `(property_name, object_type)`. In this case `('image', Image)` will be connected between the two modules, but `begin_explore` will not be linked to `start_explore`. + +## Topic names + +By default, the name of the property is used to generate the topic name. So for `image`, the topic will be `/image`. + +Streams with the same name must have the same type. If two streams share a name but have different types, `build()` raises a `ValueError`. This ensures type safety - an `Out[Temperature]` won't accidentally connect to an `In[Pressure]` just because both are named `data`. + +If you don't like the default topic name you can always override it like in the next section. + +## Which transport is used? + +By default `LCMTransport` is used if the object supports `lcm_encode`. If it doesn't `pLCMTransport` is used (meaning "pickled LCM"). + +You can override transports with the `transports` method. It returns a new blueprint in which the override is set. + +```python +blueprint = autoconnect(...) +expanded_blueprint = autoconnect(blueprint, ...) +blueprint = blueprint.transports({ + ("image", Image): pSHMTransport( + "/go2/color_image", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE + ), + ("start_explore", Bool): pLCMTransport(), +}) +``` + +Note: `expanded_blueprint` does not get the transport overrides because it's created from the initial value of `blueprint`, not the second. + +## Remapping connections + +Sometimes you need to rename a connection to match what other modules expect. You can use `remappings` to rename module connections: + +```python +class ConnectionModule(Module): + color_image: Out[Image] = None # Outputs on 'color_image' + +class ProcessingModule(Module): + rgb_image: In[Image] = None # Expects input on 'rgb_image' + +# Without remapping, these wouldn't connect automatically +# With remapping, color_image is renamed to rgb_image +blueprint = ( + autoconnect( + ConnectionModule.blueprint(), + ProcessingModule.blueprint(), + ) + .remappings([ + (ConnectionModule, 'color_image', 'rgb_image'), + ]) +) +``` + +After remapping: +- The `color_image` output from `ConnectionModule` is treated as `rgb_image` +- It automatically connects to any module with an `rgb_image` input of type `Image` +- The topic name becomes `/rgb_image` instead of `/color_image` + +If you want to override the topic, you still have to do it manually: + +```python +blueprint +.remappings([ + (ConnectionModule, 'color_image', 'rgb_image'), +]) +.transports({ + ("rgb_image", Image): LCMTransport("/custom/rgb/image", Image), +}) +``` + +## Overriding global configuration. + +Each module can optionally take a `global_config` option in `__init__`. E.g.: + +```python +class ModuleA(Module): + + def __init__(self, global_config: GlobalConfig | None = None): + ... +``` + +The config is normally taken from .env or from environment variables. But you can specifically override the values for a specific blueprint: + +```python +blueprint = blueprint.global_config(n_dask_workers=8) +``` + +## Calling the methods of other modules + +Imagine you have this code: + +```python +class ModuleA(Module): + + @rpc + def get_time(self) -> str: + ... + +class ModuleB(Module): + def request_the_time(self) -> None: + ... +``` + +And you want to call `ModuleA.get_time` in `ModuleB.request_the_time`. + +You can do so by defining a method like `set__`. It will be called with an `RpcCall` that will call the original `ModuleA.get_time`. So you can write this: + +```python +class ModuleA(Module): + + @rpc + def get_time(self) -> str: + ... + +class ModuleB(Module): + @rpc # Note that it has to be an rpc method. + def set_ModuleA_get_time(self, rpc_call: RpcCall) -> None: + self._get_time = rpc_call + self._get_time.set_rpc(self.rpc) + + def request_the_time(self) -> None: + print(self._get_time()) +``` + +Note that `RpcCall.rpc` does not serialize, so you have to set it to the one from the module with `rpc_call.set_rpc(self.rpc)` + +## Defining skills + +Skills have to be registered with `LlmAgent.register_skills(self)`. + +```python +class SomeSkill(Module): + + @skill + def some_skill(self) -> None: + ... + + @rpc + def set_LlmAgent_register_skills(self, register_skills: RpcCall) -> None: + register_skills.set_rpc(self.rpc) + register_skills(RPCClient(self, self.__class__)) + + # The agent is just interested in the `@skill` methods, so you'll need this if your class + # has things that cannot be pickled. + def __getstate__(self): + pass + def __setstate__(self, _state): + pass +``` + +Or, you can avoid all of this by inheriting from `SkillModule` which does the above automatically: + +```python +class SomeSkill(SkillModule): + + @skill + def some_skill(self) -> None: + ... +``` + +## Building + +All you have to do to build a blueprint is call: + +```python +module_coordinator = blueprint.build(global_config=config) +``` + +This returns a `ModuleCoordinator` instance that manages all deployed modules. + +### Running and shutting down + +You can block the thread until it exits with: + +```python +module_coordinator.loop() +``` + +This will wait for Ctrl+C and then automatically stop all modules and clean up resources. diff --git a/dimos/core/__init__.py b/dimos/core/__init__.py new file mode 100644 index 0000000000..c8bb091596 --- /dev/null +++ b/dimos/core/__init__.py @@ -0,0 +1,295 @@ +from __future__ import annotations + +import multiprocessing as mp +import signal +import time + +from dask.distributed import Client, LocalCluster +from rich.console import Console + +import dimos.core.colors as colors +from dimos.core.core import rpc +from dimos.core.module import Module, ModuleBase, ModuleConfig +from dimos.core.rpc_client import RPCClient +from dimos.core.stream import In, Out, RemoteIn, RemoteOut, Transport +from dimos.core.transport import ( + LCMTransport, + SHMTransport, + ZenohTransport, + pLCMTransport, + pSHMTransport, +) +from dimos.protocol.rpc import LCMRPC +from dimos.protocol.rpc.spec import RPCSpec +from dimos.protocol.tf import LCMTF, TF, PubSubTF, TFConfig, TFSpec +from dimos.utils.actor_registry import ActorRegistry +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + +__all__ = [ + "LCMRPC", + "LCMTF", + "TF", + "DimosCluster", + "In", + "LCMTransport", + "Module", + "ModuleBase", + "ModuleConfig", + "Out", + "PubSubTF", + "RPCSpec", + "RemoteIn", + "RemoteOut", + "SHMTransport", + "TFConfig", + "TFSpec", + "Transport", + "ZenohTransport", + "pLCMTransport", + "pSHMTransport", + "rpc", + "start", +] + + +class CudaCleanupPlugin: + """Dask worker plugin to cleanup CUDA resources on shutdown.""" + + def setup(self, worker) -> None: # type: ignore[no-untyped-def] + """Called when worker starts.""" + pass + + def teardown(self, worker) -> None: # type: ignore[no-untyped-def] + """Clean up CUDA resources when worker shuts down.""" + try: + import sys + + if "cupy" in sys.modules: + import cupy as cp # type: ignore[import-not-found] + + # Clear memory pools + mempool = cp.get_default_memory_pool() + pinned_mempool = cp.get_default_pinned_memory_pool() + mempool.free_all_blocks() + pinned_mempool.free_all_blocks() + cp.cuda.Stream.null.synchronize() + mempool.free_all_blocks() + pinned_mempool.free_all_blocks() + except Exception: + pass + + +def patch_actor(actor, cls) -> None: ... # type: ignore[no-untyped-def] + + +DimosCluster = Client + + +def patchdask(dask_client: Client, local_cluster: LocalCluster) -> DimosCluster: + def deploy( # type: ignore[no-untyped-def] + actor_class, + *args, + **kwargs, + ): + console = Console() + with console.status(f"deploying [green]{actor_class.__name__}\n", spinner="arc"): + actor = dask_client.submit( # type: ignore[no-untyped-call] + actor_class, + *args, + **kwargs, + actor=True, + ).result() + + worker = actor.set_ref(actor).result() + logger.info("Deployed module.", module=actor._cls.__name__, worker_id=worker) + + # Register actor deployment in shared memory + ActorRegistry.update(str(actor), str(worker)) + + return RPCClient(actor, actor_class) + + def check_worker_memory() -> None: + """Check memory usage of all workers.""" + info = dask_client.scheduler_info() + console = Console() + total_workers = len(info.get("workers", {})) + total_memory_used = 0 + total_memory_limit = 0 + + for worker_addr, worker_info in info.get("workers", {}).items(): + metrics = worker_info.get("metrics", {}) + memory_used = metrics.get("memory", 0) + memory_limit = worker_info.get("memory_limit", 0) + + cpu_percent = metrics.get("cpu", 0) + managed_bytes = metrics.get("managed_bytes", 0) + spilled = metrics.get("spilled_bytes", {}).get("memory", 0) + worker_status = worker_info.get("status", "unknown") + worker_id = worker_info.get("id", "?") + + memory_used_gb = memory_used / 1e9 + memory_limit_gb = memory_limit / 1e9 + managed_gb = managed_bytes / 1e9 + spilled / 1e9 + + total_memory_used += memory_used + total_memory_limit += memory_limit + + percentage = (memory_used_gb / memory_limit_gb * 100) if memory_limit_gb > 0 else 0 + + if worker_status == "paused": + status = "[red]PAUSED" + elif percentage >= 95: + status = "[red]CRITICAL" + elif percentage >= 80: + status = "[yellow]WARNING" + else: + status = "[green]OK" + + console.print( + f"Worker-{worker_id} {worker_addr}: " + f"{memory_used_gb:.2f}/{memory_limit_gb:.2f}GB ({percentage:.1f}%) " + f"CPU:{cpu_percent:.0f}% Managed:{managed_gb:.2f}GB " + f"{status}" + ) + + if total_workers > 0: + total_used_gb = total_memory_used / 1e9 + total_limit_gb = total_memory_limit / 1e9 + total_percentage = (total_used_gb / total_limit_gb * 100) if total_limit_gb > 0 else 0 + console.print( + f"[bold]Total: {total_used_gb:.2f}/{total_limit_gb:.2f}GB ({total_percentage:.1f}%) across {total_workers} workers[/bold]" + ) + + def close_all() -> None: + # Prevents multiple calls to close_all + if hasattr(dask_client, "_closed") and dask_client._closed: + return + dask_client._closed = True # type: ignore[attr-defined] + + # Stop all SharedMemory transports before closing Dask + # This prevents the "leaked shared_memory objects" warning and hangs + try: + import gc + + from dimos.protocol.pubsub import shmpubsub + + for obj in gc.get_objects(): + if isinstance(obj, shmpubsub.SharedMemory | shmpubsub.PickleSharedMemory): + try: + obj.stop() + except Exception: + pass + except Exception: + pass + + # Get the event loop before shutting down + loop = dask_client.loop + + # Clear the actor registry + ActorRegistry.clear() + + # Close cluster and client with reasonable timeout + # The CudaCleanupPlugin will handle CUDA cleanup on each worker + try: + local_cluster.close(timeout=5) + except Exception: + pass + + try: + dask_client.close(timeout=5) # type: ignore[no-untyped-call] + except Exception: + pass + + if loop and hasattr(loop, "add_callback") and hasattr(loop, "stop"): + try: + loop.add_callback(loop.stop) + except Exception: + pass + + # Note: We do NOT shutdown the _offload_executor here because it's a global + # module-level ThreadPoolExecutor shared across all Dask clients in the process. + # Shutting it down here would break subsequent Dask client usage (e.g., in tests). + # The executor will be cleaned up when the Python process exits. + + # Give threads time to clean up + # Dask's IO loop and Profile threads are daemon threads + # that will be cleaned up when the process exits + # This is needed, solves race condition in CI thread check + time.sleep(0.1) + + dask_client.deploy = deploy # type: ignore[attr-defined] + dask_client.check_worker_memory = check_worker_memory # type: ignore[attr-defined] + dask_client.stop = lambda: dask_client.close() # type: ignore[attr-defined, no-untyped-call] + dask_client.close_all = close_all # type: ignore[attr-defined] + return dask_client + + +def start(n: int | None = None, memory_limit: str = "auto") -> DimosCluster: + """Start a Dask LocalCluster with specified workers and memory limits. + + Args: + n: Number of workers (defaults to CPU count) + memory_limit: Memory limit per worker (e.g., '4GB', '2GiB', or 'auto' for Dask's default) + + Returns: + DimosCluster: A patched Dask client with deploy(), check_worker_memory(), stop(), and close_all() methods + """ + + console = Console() + if not n: + n = mp.cpu_count() + with console.status( + f"[green]Initializing dimos local cluster with [bright_blue]{n} workers", spinner="arc" + ): + cluster = LocalCluster( # type: ignore[no-untyped-call] + n_workers=n, + threads_per_worker=4, + memory_limit=memory_limit, + plugins=[CudaCleanupPlugin()], # Register CUDA cleanup plugin + ) + client = Client(cluster) # type: ignore[no-untyped-call] + + console.print( + f"[green]Initialized dimos local cluster with [bright_blue]{n} workers, memory limit: {memory_limit}" + ) + + patched_client = patchdask(client, cluster) + patched_client._shutting_down = False # type: ignore[attr-defined] + + # Signal handler with proper exit handling + def signal_handler(sig, frame) -> None: # type: ignore[no-untyped-def] + # If already shutting down, force exit + if patched_client._shutting_down: # type: ignore[attr-defined] + import os + + console.print("[red]Force exit!") + os._exit(1) + + patched_client._shutting_down = True # type: ignore[attr-defined] + console.print(f"[yellow]Shutting down (signal {sig})...") + + try: + patched_client.close_all() # type: ignore[attr-defined] + except Exception: + pass + + import sys + + sys.exit(0) + + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + return patched_client + + +def wait_exit() -> None: + while True: + try: + time.sleep(1) + except KeyboardInterrupt: + print("exiting...") + return diff --git a/dimos/core/_test_future_annotations_helper.py b/dimos/core/_test_future_annotations_helper.py new file mode 100644 index 0000000000..08c5ec0063 --- /dev/null +++ b/dimos/core/_test_future_annotations_helper.py @@ -0,0 +1,36 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Helper module for testing blueprint handling with PEP 563 (future annotations). + +This file exists because `from __future__ import annotations` affects the entire file. +""" + +from __future__ import annotations + +from dimos.core.module import Module +from dimos.core.stream import In, Out # noqa + + +class FutureData: + pass + + +class FutureModuleOut(Module): + data: Out[FutureData] = None # type: ignore[assignment] + + +class FutureModuleIn(Module): + data: In[FutureData] = None # type: ignore[assignment] diff --git a/dimos/core/blueprints.py b/dimos/core/blueprints.py new file mode 100644 index 0000000000..e3a483e25c --- /dev/null +++ b/dimos/core/blueprints.py @@ -0,0 +1,627 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Declarative module composition and automatic connection wiring. + +Instead of manually connecting modules, define blueprints that specify each module's +data dependencies (`In[T]`, `Out[T]` streams) and RPC method requirements. The blueprint +system automatically wires streams, selects transports, and links RPC methods between +modules, when building the blueprint. + +Core components +--------------- +`ModuleBlueprint` + Immutable specification for instantiating a single module with its stream + connections. Created via `Module.blueprint()`. + +`ModuleBlueprintSet` + Container for multiple blueprints with builder methods for configuration: + `transports()`, `global_config()`, `remappings()`, and `build()`. + +`autoconnect()` + Combine multiple `ModuleBlueprintSet` instances into one composed system. + Deduplicates blueprints, merging configuration with last-wins semantics. + +Basic usage +----------- +Streams match by name and type. Use `.remappings()` when names differ: + + blueprint = autoconnect( + CameraModule.blueprint(), + ProcessorModule.blueprint() + ).remappings([ + (CameraModule, "color_image", "rgb"), + (ProcessorModule, "rgb_input", "rgb"), + ]) + coordinator = blueprint.build() + coordinator.loop() # Run until interrupted + +For detailed explanation of connection matching, composition patterns, transport +selection, and RPC wiring, see `/docs/concepts/blueprints.md`. + +See also +-------- +`dimos.core.module` + Module base class with `In[T]`/`Out[T]` stream declarations. + +`dimos.core.module_coordinator` + Runtime manager for deployed modules. Returned by `ModuleBlueprintSet.build()`. + +`dimos.core.transport` + Transport implementations (`LCMTransport`, `pLCMTransport`) for inter-module + communication. +""" + +from abc import ABC +from collections import defaultdict +from collections.abc import Callable, Mapping +from dataclasses import dataclass, field +from functools import cached_property, reduce +import inspect +import operator +import sys +from types import MappingProxyType +from typing import Annotated, Any, Literal, get_args, get_origin, get_type_hints + +from annotated_doc import Doc + +from dimos.core.global_config import GlobalConfig +from dimos.core.module import Module +from dimos.core.module_coordinator import ModuleCoordinator +from dimos.core.stream import In, Out +from dimos.core.transport import LCMTransport, pLCMTransport +from dimos.utils.generic import short_id + + +@dataclass(frozen=True) +class ModuleConnection: + name: str + type: type + direction: Literal["in", "out"] + + +@dataclass(frozen=True) +class ModuleBlueprint: + """Declarative specification for instantiating and wiring a module. + + A ModuleBlueprint captures everything needed to instantiate and connect a module + in a distributed system without actually creating the module instance. This separation + enables composition, configuration, and deployment to be handled independently. + + Blueprints are immutable and serve as the specification layer between + high-level system design and runtime deployment. They are typically created using + `Module.blueprint()`, combined into `ModuleBlueprintSet` containers via `autoconnect()`, + and deployed by `ModuleCoordinator.build()`. + """ + + module: Annotated[ + type[Module], + Doc( + "The module class to instantiate. This is the constructor, not an instance, allowing late binding during deployment." + ), + ] + connections: Annotated[ + tuple[ModuleConnection, ...], + Doc( + "Typed stream connections extracted from the module's type annotations. These specify how this module's streams should be wired to other modules during deployment." + ), + ] + args: Annotated[ + tuple[Any], Doc("Positional arguments to pass to the module's `__init__` method.") + ] + kwargs: Annotated[ + dict[str, Any], Doc("Keyword arguments to pass to the module's `__init__` method.") + ] + + +@dataclass(frozen=True) +class ModuleBlueprintSet: + blueprints: tuple[ModuleBlueprint, ...] + # TODO: Replace Any + transport_map: Mapping[tuple[str, type], Any] = field( + default_factory=lambda: MappingProxyType({}) + ) + global_config_overrides: Mapping[str, Any] = field(default_factory=lambda: MappingProxyType({})) + remapping_map: Mapping[tuple[type[Module], str], str] = field( + default_factory=lambda: MappingProxyType({}) + ) + requirement_checks: tuple[Callable[[], str | None], ...] = field(default_factory=tuple) + + def transports( + self, + transports: Annotated[ + dict[tuple[str, type], Any], + Doc( + """Dictionary mapping (connection_name, data_type) to transport instances. + Both the connection name and data type must match for the override to apply.""" + ), + ], + ) -> Annotated[ + "ModuleBlueprintSet", Doc("New ModuleBlueprintSet with merged transport overrides.") + ]: + """Register explicit transport overrides for specific connections. + + By default, dimos auto-selects transports based on whether the data type + has an `lcm_encode` method. Use this to override those defaults when you need: + - Shared memory (SHM) transports for high-bandwidth data like images + - Custom topic names + - Specific transport implementations for performance or compatibility + """ + return ModuleBlueprintSet( + blueprints=self.blueprints, + transport_map=MappingProxyType({**self.transport_map, **transports}), + global_config_overrides=self.global_config_overrides, + remapping_map=self.remapping_map, + requirement_checks=self.requirement_checks, + ) + + def global_config( + self, + **kwargs: Annotated[ + Any, + Doc( + """Key-value pairs to override in GlobalConfig (e.g., n_dask_workers, log_level). + Values are validated during build().""" + ), + ], + ) -> Annotated[ + "ModuleBlueprintSet", Doc("New ModuleBlueprintSet with merged configuration overrides.") + ]: + """Override GlobalConfig parameters for modules in this blueprint set. + + These overrides take precedence over configuration from .env files or + environment variables. Useful for deployment-specific settings, debugging, + or testing without changing global configuration. + """ + return ModuleBlueprintSet( + blueprints=self.blueprints, + transport_map=self.transport_map, + global_config_overrides=MappingProxyType({**self.global_config_overrides, **kwargs}), + remapping_map=self.remapping_map, + requirement_checks=self.requirement_checks, + ) + + def remappings( + self, + remappings: Annotated[ + list[tuple[type[Module], str, str]], + Doc( + """List of (module_class, old_name, new_name) tuples specifying + that the module's connection 'old_name' should be treated as 'new_name'.""" + ), + ], + ) -> Annotated[ + "ModuleBlueprintSet", Doc("New ModuleBlueprintSet with updated connection remappings.") + ]: + """Rename module connections to enable interoperability between modules. + + Allows modules with different naming conventions to communicate. Remapping + is transparent to modules—they use their original names internally. Remapped + names are used only for connection matching (connections match if remapped + names and types both match) and topic generation. + + Examples: + Connect modules with different naming conventions: + + >>> class CameraModule(Module): + ... color_image: Out[str] = None + >>> class ProcessorModule(Module): + ... rgb_input: In[str] = None + >>> + >>> blueprint = autoconnect( + ... CameraModule.blueprint(), + ... ProcessorModule.blueprint() + ... ) + >>> blueprint = blueprint.remappings([ + ... (CameraModule, "color_image", "rgb_image"), + ... (ProcessorModule, "rgb_input", "rgb_image") + ... ]) + >>> # Now both connections use "rgb_image" and will be connected + + Broadcast to multiple consumers: + + >>> class SensorModule(Module): + ... output: Out[str] = None + >>> class ProcessorA(Module): + ... input: In[str] = None + >>> class ProcessorB(Module): + ... input_stream: In[str] = None + >>> + >>> blueprint = autoconnect( + ... SensorModule.blueprint(), + ... ProcessorA.blueprint(), + ... ProcessorB.blueprint() + ... ) + >>> blueprint = blueprint.remappings([ + ... (SensorModule, "output", "shared_data"), + ... (ProcessorA, "input", "shared_data"), + ... (ProcessorB, "input_stream", "shared_data") + ... ]) + >>> # All three connections now share the same transport + """ + remappings_dict = dict(self.remapping_map) + for module, old, new in remappings: + remappings_dict[(module, old)] = new + + return ModuleBlueprintSet( + blueprints=self.blueprints, + transport_map=self.transport_map, + global_config_overrides=self.global_config_overrides, + remapping_map=MappingProxyType(remappings_dict), + requirement_checks=self.requirement_checks, + ) + + def requirements(self, *checks: Callable[[], str | None]) -> "ModuleBlueprintSet": + return ModuleBlueprintSet( + blueprints=self.blueprints, + transport_map=self.transport_map, + global_config_overrides=self.global_config_overrides, + remapping_map=self.remapping_map, + requirement_checks=self.requirement_checks + tuple(checks), + ) + + def _get_transport_for( + self, + name: Annotated[str, Doc("Connection name after remapping.")], + type: Annotated[type, Doc("Data type from the module's type annotations.")], + ) -> Annotated[Any, Doc("Transport instance for the connection.")]: + """Determine and create the appropriate transport for a connection. + + Selection priority: + 1. Explicit transport override in transport_map + 2. Auto-select based on type: LCMTransport if type has lcm_encode, else pLCMTransport + 3. Topic naming: /{name} if unique, else random ID + + Connections with identical (remapped_name, type) pairs share the same + transport instance, enabling pub/sub communication. + """ + transport = self.transport_map.get((name, type), None) + if transport: + return transport + + use_pickled = getattr(type, "lcm_encode", None) is None + topic = f"/{name}" if self._is_name_unique(name) else f"/{short_id()}" + transport = pLCMTransport(topic) if use_pickled else LCMTransport(topic, type) + + return transport + + @cached_property + def _all_name_types(self) -> set[tuple[str, type]]: + # Apply remappings to get the actual names that will be used + result = set() + for blueprint in self.blueprints: + for conn in blueprint.connections: + # Check if this connection should be remapped + remapped_name = self.remapping_map.get((blueprint.module, conn.name), conn.name) + result.add((remapped_name, conn.type)) + return result + + def _is_name_unique(self, name: str) -> bool: + return sum(1 for n, _ in self._all_name_types if n == name) == 1 + + def _check_requirements(self) -> None: + errors = [] + red = "\033[31m" + reset = "\033[0m" + + for check in self.requirement_checks: + error = check() + if error: + errors.append(error) + + if errors: + for error in errors: + print(f"{red}Error: {error}{reset}", file=sys.stderr) + sys.exit(1) + + def _verify_no_name_conflicts(self) -> None: + name_to_types = defaultdict(set) + name_to_modules = defaultdict(list) + + for blueprint in self.blueprints: + for conn in blueprint.connections: + connection_name = self.remapping_map.get((blueprint.module, conn.name), conn.name) + name_to_types[connection_name].add(conn.type) + name_to_modules[connection_name].append((blueprint.module, conn.type)) + + conflicts = {} + for conn_name, types in name_to_types.items(): + if len(types) > 1: + modules_by_type = defaultdict(list) + for module, conn_type in name_to_modules[conn_name]: + modules_by_type[conn_type].append(module) + conflicts[conn_name] = modules_by_type + + if not conflicts: + return + + error_lines = ["Blueprint cannot start because there are conflicting connections."] + for name, modules_by_type in conflicts.items(): + type_entries = [] + for conn_type, modules in modules_by_type.items(): + for module in modules: + type_str = f"{conn_type.__module__}.{conn_type.__name__}" + module_str = module.__name__ + type_entries.append((type_str, module_str)) + if len(type_entries) >= 2: + locations = ", ".join(f"{type_} in {module}" for type_, module in type_entries) + error_lines.append(f" - '{name}' has conflicting types. {locations}") + + raise ValueError("\n".join(error_lines)) + + def _deploy_all_modules( + self, module_coordinator: ModuleCoordinator, global_config: GlobalConfig + ) -> None: + for blueprint in self.blueprints: + kwargs = {**blueprint.kwargs} + sig = inspect.signature(blueprint.module.__init__) + if "global_config" in sig.parameters: + kwargs["global_config"] = global_config + module_coordinator.deploy(blueprint.module, *blueprint.args, **kwargs) + + def _connect_transports( + self, + module_coordinator: Annotated[ + ModuleCoordinator, + Doc( + """Coordinator containing deployed module instances. + All modules in self.blueprints must have been deployed first.""" + ), + ], + ) -> None: + """Establish transport connections between deployed modules. + + Processes all stream connections, applies name remappings, groups connections + by (remapped_name, type), and assigns the same transport instance to all + connections within each group. + + Key design decision: Connections with matching (remapped_name, type) share + the same transport instance, enabling pub/sub communication. Type mismatches + prevent sharing even with matching names, preserving type safety. Remapping + is transparent—modules use their original connection names internally. + """ + # Gather all the In/Out connections with remapping applied. + connections = defaultdict(list) + # Track original name -> remapped name for each module + module_conn_mapping = defaultdict(dict) # type: ignore[var-annotated] + + for blueprint in self.blueprints: + for conn in blueprint.connections: + # Check if this connection should be remapped + remapped_name = self.remapping_map.get((blueprint.module, conn.name), conn.name) + # Store the mapping for later use + module_conn_mapping[blueprint.module][conn.name] = remapped_name + # Group by remapped name and type + connections[remapped_name, conn.type].append((blueprint.module, conn.name)) + + # Connect all In/Out connections by remapped name and type. + for remapped_name, type in connections.keys(): + transport = self._get_transport_for(remapped_name, type) + for module, original_name in connections[(remapped_name, type)]: + instance = module_coordinator.get_instance(module) + instance.set_transport(original_name, transport) # type: ignore[union-attr] + + def _connect_rpc_methods( + self, + module_coordinator: Annotated[ + ModuleCoordinator, + Doc("Coordinator containing deployed module instances with initialized RPC servers."), + ], + ) -> None: + """Wire up inter-module RPC method calls. + + Processes two independent wiring mechanisms: + - Implicit: Calls `set_ClassName_method()` setters with bound methods + - Explicit: Populates `rpc_calls` requests via `set_rpc_method()` + + Interface-based requests (e.g., `NavigationInterface.get_state`) resolve to a + single implementation or fail at build time if ambiguous. Missing methods are + skipped (error raised at call time via `get_rpc_calls()`). + """ + # Gather all RPC methods. + rpc_methods = {} + rpc_methods_dot = {} + # Track interface methods to detect ambiguity + interface_methods = defaultdict(list) # interface_name.method -> [(module_class, method)] + + for blueprint in self.blueprints: + for method_name in blueprint.module.rpcs.keys(): # type: ignore[attr-defined] + method = getattr(module_coordinator.get_instance(blueprint.module), method_name) + # Register under concrete class name (backward compatibility) + rpc_methods[f"{blueprint.module.__name__}_{method_name}"] = method + rpc_methods_dot[f"{blueprint.module.__name__}.{method_name}"] = method + + # Also register under any interface names + for base in blueprint.module.__bases__: + # Check if this base is an abstract interface with the method + if ( + base is not Module + and issubclass(base, ABC) + and hasattr(base, method_name) + and getattr(base, method_name, None) is not None + ): + interface_key = f"{base.__name__}.{method_name}" + interface_methods[interface_key].append((blueprint.module, method)) + + # Check for ambiguity in interface methods and add non-ambiguous ones + for interface_key, implementations in interface_methods.items(): + if len(implementations) == 1: + rpc_methods_dot[interface_key] = implementations[0][1] + + # Fulfil method requests (so modules can call each other). + for blueprint in self.blueprints: + instance = module_coordinator.get_instance(blueprint.module) + for method_name in blueprint.module.rpcs.keys(): # type: ignore[attr-defined] + if not method_name.startswith("set_"): + continue + linked_name = method_name.removeprefix("set_") + if linked_name not in rpc_methods: + continue + getattr(instance, method_name)(rpc_methods[linked_name]) + for requested_method_name in instance.get_rpc_method_names(): # type: ignore[union-attr] + # Check if this is an ambiguous interface method + if ( + requested_method_name in interface_methods + and len(interface_methods[requested_method_name]) > 1 + ): + modules_str = ", ".join( + impl[0].__name__ for impl in interface_methods[requested_method_name] + ) + raise ValueError( + f"Ambiguous RPC method '{requested_method_name}' requested by " + f"{blueprint.module.__name__}. Multiple implementations found: " + f"{modules_str}. Please use a concrete class name instead." + ) + + if requested_method_name not in rpc_methods_dot: + continue + instance.set_rpc_method( # type: ignore[union-attr] + requested_method_name, rpc_methods_dot[requested_method_name] + ) + + def build( + self, + global_config: Annotated[ + GlobalConfig | None, + Doc( + """Base configuration for the system. Defaults to GlobalConfig(). + Blueprint overrides take precedence.""" + ), + ] = None, + ) -> Annotated[ + ModuleCoordinator, + Doc( + """Fully initialized ModuleCoordinator. Call coordinator.loop() to run + or coordinator.stop() to shut down cleanly.""" + ), + ]: + """Transform this blueprint specification into a running distributed system. + + Terminal operation in the blueprint builder pattern. Creates a fully initialized + ModuleCoordinator with all modules deployed, connected, and ready to run. + + Build process: + 1. Merge global_config with blueprint-level overrides (overrides take precedence) + 2. Create and start ModuleCoordinator with Dask cluster + 3. Deploy all modules + 4. Connect stream transports (In/Out with matching remapped names and types) + 5. Link RPC methods between modules + 6. Start all deployed modules + + Raises: + ValueError: If an RPC method request is ambiguous (multiple modules + implement the same interface method). Use concrete class name instead. + + Examples: + >>> class CameraModule(Module): + ... color_image: Out[str] = None + >>> class ProcessorModule(Module): + ... image_in: In[str] = None + >>> + >>> blueprint = ( + ... autoconnect( + ... CameraModule.blueprint(), + ... ProcessorModule.blueprint() + ... ) + ... .remappings([ + ... (CameraModule, "color_image", "rgb_input"), + ... (ProcessorModule, "image_in", "rgb_input"), + ... ]) + ... .global_config(n_dask_workers=2) + ... ) + >>> coordinator = blueprint.build() + >>> # ...do whatever you want to do with the coordinator + >>> coordinator.stop() + """ + if global_config is None: + global_config = GlobalConfig() + global_config = global_config.model_copy(update=self.global_config_overrides) + + self._check_requirements() + self._verify_no_name_conflicts() + + module_coordinator = ModuleCoordinator(global_config=global_config) + module_coordinator.start() + + self._deploy_all_modules(module_coordinator, global_config) + self._connect_transports(module_coordinator) + self._connect_rpc_methods(module_coordinator) + + module_coordinator.start_all_modules() + + return module_coordinator + + +def _make_module_blueprint( + module: type[Module], args: tuple[Any], kwargs: dict[str, Any] +) -> ModuleBlueprint: + connections: list[ModuleConnection] = [] + + # Use get_type_hints() to properly resolve string annotations. + try: + all_annotations = get_type_hints(module) + except Exception: + # Fallback to raw annotations if get_type_hints fails. + all_annotations = {} + for base_class in reversed(module.__mro__): + if hasattr(base_class, "__annotations__"): + all_annotations.update(base_class.__annotations__) + + for name, annotation in all_annotations.items(): + origin = get_origin(annotation) + if origin not in (In, Out): + continue + direction = "in" if origin == In else "out" + type_ = get_args(annotation)[0] + connections.append(ModuleConnection(name=name, type=type_, direction=direction)) # type: ignore[arg-type] + + return ModuleBlueprint(module=module, connections=tuple(connections), args=args, kwargs=kwargs) + + +def create_module_blueprint(module: type[Module], *args: Any, **kwargs: Any) -> ModuleBlueprintSet: + blueprint = _make_module_blueprint(module, args, kwargs) + return ModuleBlueprintSet(blueprints=(blueprint,)) + + +def autoconnect(*blueprints: ModuleBlueprintSet) -> ModuleBlueprintSet: + all_blueprints = tuple(_eliminate_duplicates([bp for bs in blueprints for bp in bs.blueprints])) + all_transports = dict( # type: ignore[var-annotated] + reduce(operator.iadd, [list(x.transport_map.items()) for x in blueprints], []) + ) + all_config_overrides = dict( # type: ignore[var-annotated] + reduce(operator.iadd, [list(x.global_config_overrides.items()) for x in blueprints], []) + ) + all_remappings = dict( # type: ignore[var-annotated] + reduce(operator.iadd, [list(x.remapping_map.items()) for x in blueprints], []) + ) + all_requirement_checks = tuple(check for bs in blueprints for check in bs.requirement_checks) + + return ModuleBlueprintSet( + blueprints=all_blueprints, + transport_map=MappingProxyType(all_transports), + global_config_overrides=MappingProxyType(all_config_overrides), + remapping_map=MappingProxyType(all_remappings), + requirement_checks=all_requirement_checks, + ) + + +def _eliminate_duplicates(blueprints: list[ModuleBlueprint]) -> list[ModuleBlueprint]: + # The duplicates are eliminated in reverse so that newer blueprints override older ones. + seen = set() + unique_blueprints = [] + for bp in reversed(blueprints): + if bp.module not in seen: + seen.add(bp.module) + unique_blueprints.append(bp) + return list(reversed(unique_blueprints)) diff --git a/dimos/core/colors.py b/dimos/core/colors.py new file mode 100644 index 0000000000..294cf5d43b --- /dev/null +++ b/dimos/core/colors.py @@ -0,0 +1,43 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def green(text: str) -> str: + """Return the given text in green color.""" + return f"\033[92m{text}\033[0m" + + +def blue(text: str) -> str: + """Return the given text in blue color.""" + return f"\033[94m{text}\033[0m" + + +def red(text: str) -> str: + """Return the given text in red color.""" + return f"\033[91m{text}\033[0m" + + +def yellow(text: str) -> str: + """Return the given text in yellow color.""" + return f"\033[93m{text}\033[0m" + + +def cyan(text: str) -> str: + """Return the given text in cyan color.""" + return f"\033[96m{text}\033[0m" + + +def orange(text: str) -> str: + """Return the given text in orange color.""" + return f"\033[38;5;208m{text}\033[0m" diff --git a/dimos/core/core.py b/dimos/core/core.py new file mode 100644 index 0000000000..e7a7d09f58 --- /dev/null +++ b/dimos/core/core.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import ( + TYPE_CHECKING, + Any, + TypeVar, +) + +from dimos.core.o3dpickle import register_picklers + +if TYPE_CHECKING: + from collections.abc import Callable + +# injects pickling system into o3d +register_picklers() +T = TypeVar("T") + + +def rpc(fn: Callable[..., Any]) -> Callable[..., Any]: + fn.__rpc__ = True # type: ignore[attr-defined] + return fn diff --git a/dimos/core/global_config.py b/dimos/core/global_config.py new file mode 100644 index 0000000000..7f0b145db4 --- /dev/null +++ b/dimos/core/global_config.py @@ -0,0 +1,41 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import cached_property + +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class GlobalConfig(BaseSettings): + robot_ip: str | None = None + simulation: bool = False + replay: bool = False + n_dask_workers: int = 2 + mujoco_room: str | None = None + robot_model: str | None = None + + model_config = SettingsConfigDict( + env_file=".env", + env_file_encoding="utf-8", + extra="ignore", + frozen=True, + ) + + @cached_property + def unitree_connection_type(self) -> str: + if self.replay: + return "replay" + if self.simulation: + return "mujoco" + return "webrtc" diff --git a/dimos/core/module.py b/dimos/core/module.py new file mode 100644 index 0000000000..cbdeae964d --- /dev/null +++ b/dimos/core/module.py @@ -0,0 +1,394 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +from collections.abc import Callable +from dataclasses import dataclass +from functools import partial +import inspect +import threading +from typing import ( + Any, + get_args, + get_origin, + get_type_hints, + overload, +) + +from dask.distributed import Actor, get_worker +from reactivex.disposable import CompositeDisposable + +from dimos.core import colors +from dimos.core.core import T, rpc +from dimos.core.resource import Resource +from dimos.core.rpc_client import RpcCall +from dimos.core.stream import In, Out, RemoteIn, RemoteOut, Transport +from dimos.protocol.rpc import LCMRPC, RPCSpec +from dimos.protocol.service import Configurable # type: ignore[attr-defined] +from dimos.protocol.skill.skill import SkillContainer +from dimos.protocol.tf import LCMTF, TFSpec +from dimos.utils.generic import classproperty + + +def get_loop() -> tuple[asyncio.AbstractEventLoop, threading.Thread | None]: + # we are actually instantiating a new loop here + # to not interfere with an existing dask loop + + # try: + # # here we attempt to figure out if we are running on a dask worker + # # if so we use the dask worker _loop as ours, + # # and we register our RPC server + # worker = get_worker() + # if worker.loop: + # print("using dask worker loop") + # return worker.loop.asyncio_loop + + # except ValueError: + # ... + + try: + running_loop = asyncio.get_running_loop() + return running_loop, None + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + thr = threading.Thread(target=loop.run_forever, daemon=True) + thr.start() + return loop, thr + + +@dataclass +class ModuleConfig: + rpc_transport: type[RPCSpec] = LCMRPC + tf_transport: type[TFSpec] = LCMTF + + +class ModuleBase(Configurable[ModuleConfig], SkillContainer, Resource): + """Base class for distributed modules in DimOS. + + ModuleBase provides core infrastructure for building distributed, communicating modules. + It integrates RPC communication, stream management, skill hosting, and lifecycle + management. For Dask cluster deployment, use `DaskModule` (aliased as `Module`), which + extends ModuleBase with Dask actor integration and distributed stream handling. + + Inherits from `Configurable[ModuleConfig]`, `SkillContainer`, and `Resource` to provide + configuration management, AI-agent-callable skill hosting, and lifecycle management. + + Core responsibilities: + + - Lifecycle management: Initialize, start, and stop module resources + - RPC infrastructure: Expose methods for remote procedure calls via @rpc decorator + - Stream discovery: Expose In/Out streams for blueprint auto-wiring + - Skill hosting (via `SkillContainer`): Methods decorated with `@skill` are callable by AI agents + - Serialization: Support pickling for distributed deployment across Dask workers + + Attributes: + _rpc: RPC transport instance for remote method calls. + _tf: Transform framework instance (lazy-initialized). + _loop: Event loop for async operations. + _loop_thread: Thread hosting the event loop (if created). + _disposables: Container for subscription cleanup. + _bound_rpc_calls: Bound external RPC methods. + rpc_calls: Declared RPC dependencies (subclass-defined). + default_config: Default configuration class (ModuleConfig or subclass). + + See the tutorials for guided examples of custom modules. + + Notes: + Most applications use `DaskModule` (available as `Module` alias) rather than + `ModuleBase` directly. + + When subclassing: + + - Always call `super().__init__()` in your `__init__` + - Always call `super().stop()` in your `stop()` + - Use `@rpc` to expose methods for remote invocation + - Use `@skill()` to expose methods to AI agents + - Declare RPC dependencies in the `rpc_calls` class attribute + """ + + _rpc: RPCSpec | None = None + _tf: TFSpec | None = None + _loop: asyncio.AbstractEventLoop | None = None + _loop_thread: threading.Thread | None + _disposables: CompositeDisposable + _bound_rpc_calls: dict[str, RpcCall] = {} + + rpc_calls: list[str] = [] + + default_config = ModuleConfig + + def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + super().__init__(*args, **kwargs) + self._loop, self._loop_thread = get_loop() + self._disposables = CompositeDisposable() + # we can completely override comms protocols if we want + try: + # here we attempt to figure out if we are running on a dask worker + # if so we use the dask worker _loop as ours, + # and we register our RPC server + self.rpc = self.config.rpc_transport() + self.rpc.serve_module_rpc(self) + self.rpc.start() # type: ignore[attr-defined] + except ValueError: + ... + + @rpc + def start(self) -> None: + pass + + @rpc + def stop(self) -> None: + self._close_module() + super().stop() + + def _close_module(self) -> None: + self._close_rpc() + if hasattr(self, "_loop") and self._loop_thread: + if self._loop_thread.is_alive(): + self._loop.call_soon_threadsafe(self._loop.stop) # type: ignore[union-attr] + self._loop_thread.join(timeout=2) + self._loop = None + self._loop_thread = None + if hasattr(self, "_tf") and self._tf is not None: + self._tf.stop() + self._tf = None + if hasattr(self, "_disposables"): + self._disposables.dispose() + + def _close_rpc(self) -> None: + # Using hasattr is needed because SkillCoordinator skips ModuleBase.__init__ and self.rpc is never set. + if hasattr(self, "rpc") and self.rpc: + self.rpc.stop() # type: ignore[attr-defined] + self.rpc = None # type: ignore[assignment] + + def __getstate__(self): # type: ignore[no-untyped-def] + """Exclude unpicklable runtime attributes when serializing.""" + state = self.__dict__.copy() + # Remove unpicklable attributes + state.pop("_disposables", None) + state.pop("_loop", None) + state.pop("_loop_thread", None) + state.pop("_rpc", None) + state.pop("_tf", None) + return state + + def __setstate__(self, state) -> None: # type: ignore[no-untyped-def] + """Restore object from pickled state.""" + self.__dict__.update(state) + # Reinitialize runtime attributes + self._disposables = CompositeDisposable() + self._loop = None + self._loop_thread = None + self._rpc = None + self._tf = None + + @property + def tf(self): # type: ignore[no-untyped-def] + if self._tf is None: + # self._tf = self.config.tf_transport() + self._tf = LCMTF() + return self._tf + + @tf.setter + def tf(self, value) -> None: # type: ignore[no-untyped-def] + import warnings + + warnings.warn( + "tf is available on all modules. Call self.tf.start() to activate tf functionality. No need to assign it", + UserWarning, + stacklevel=2, + ) + + @property + def outputs(self) -> dict[str, Out]: # type: ignore[type-arg] + return { + name: s + for name, s in self.__dict__.items() + if isinstance(s, Out) and not name.startswith("_") + } + + @property + def inputs(self) -> dict[str, In]: # type: ignore[type-arg] + return { + name: s + for name, s in self.__dict__.items() + if isinstance(s, In) and not name.startswith("_") + } + + @classmethod # type: ignore[misc] + @property + def rpcs(cls) -> dict[str, Callable]: # type: ignore[type-arg] + return { + name: getattr(cls, name) + for name in dir(cls) + if not name.startswith("_") + and name != "rpcs" # Exclude the rpcs property itself to prevent recursion + and callable(getattr(cls, name, None)) + and hasattr(getattr(cls, name), "__rpc__") + } + + @rpc + def io(self) -> str: + def _box(name: str) -> str: + return [ # type: ignore[return-value] + "┌┴" + "─" * (len(name) + 1) + "┐", + f"│ {name} │", + "└┬" + "─" * (len(name) + 1) + "┘", + ] + + # can't modify __str__ on a function like we are doing for I/O + # so we have a separate repr function here + def repr_rpc(fn: Callable) -> str: # type: ignore[type-arg] + sig = inspect.signature(fn) + # Remove 'self' parameter + params = [p for name, p in sig.parameters.items() if name != "self"] + + # Format parameters with colored types + param_strs = [] + for param in params: + param_str = param.name + if param.annotation != inspect.Parameter.empty: + type_name = getattr(param.annotation, "__name__", str(param.annotation)) + param_str += ": " + colors.green(type_name) + if param.default != inspect.Parameter.empty: + param_str += f" = {param.default}" + param_strs.append(param_str) + + # Format return type + return_annotation = "" + if sig.return_annotation != inspect.Signature.empty: + return_type = getattr(sig.return_annotation, "__name__", str(sig.return_annotation)) + return_annotation = " -> " + colors.green(return_type) + + return ( + "RPC " + colors.blue(fn.__name__) + f"({', '.join(param_strs)})" + return_annotation + ) + + ret = [ + *(f" ├─ {stream}" for stream in self.inputs.values()), + *_box(self.__class__.__name__), + *(f" ├─ {stream}" for stream in self.outputs.values()), + " │", + *(f" ├─ {repr_rpc(rpc)}" for rpc in self.rpcs.values()), + ] + + return "\n".join(ret) + + @classproperty + def blueprint(self): # type: ignore[no-untyped-def] + # Here to prevent circular imports. + from dimos.core.blueprints import create_module_blueprint + + return partial(create_module_blueprint, self) # type: ignore[arg-type] + + @rpc + def get_rpc_method_names(self) -> list[str]: + return self.rpc_calls + + @rpc + def set_rpc_method(self, method: str, callable: RpcCall) -> None: + callable.set_rpc(self.rpc) # type: ignore[arg-type] + self._bound_rpc_calls[method] = callable + + @overload + def get_rpc_calls(self, method: str) -> RpcCall: ... + + @overload + def get_rpc_calls(self, method1: str, method2: str, *methods: str) -> tuple[RpcCall, ...]: ... + + def get_rpc_calls(self, *methods: str) -> RpcCall | tuple[RpcCall, ...]: # type: ignore[misc] + missing = [m for m in methods if m not in self._bound_rpc_calls] + if missing: + raise ValueError( + f"RPC methods not found. Class: {self.__class__.__name__}, RPC methods: {', '.join(missing)}" + ) + result = tuple(self._bound_rpc_calls[m] for m in methods) + return result[0] if len(result) == 1 else result + + +class DaskModule(ModuleBase): + ref: Actor + worker: int + + def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + self.ref = None # type: ignore[assignment] + + # Get type hints with proper namespace resolution for subclasses + # Collect namespaces from all classes in the MRO chain + import sys + + globalns = {} + for cls in self.__class__.__mro__: + if cls.__module__ in sys.modules: + globalns.update(sys.modules[cls.__module__].__dict__) + + try: + hints = get_type_hints(self.__class__, globalns=globalns, include_extras=True) + except (NameError, AttributeError, TypeError): + # If we still can't resolve hints, skip type hint processing + # This can happen with complex forward references + hints = {} + + for name, ann in hints.items(): + origin = get_origin(ann) + if origin is Out: + inner, *_ = get_args(ann) or (Any,) + stream = Out(inner, name, self) # type: ignore[var-annotated] + setattr(self, name, stream) + elif origin is In: + inner, *_ = get_args(ann) or (Any,) + stream = In(inner, name, self) # type: ignore[assignment] + setattr(self, name, stream) + super().__init__(*args, **kwargs) + + def set_ref(self, ref) -> int: # type: ignore[no-untyped-def] + worker = get_worker() + self.ref = ref + self.worker = worker.name + return worker.name # type: ignore[no-any-return] + + def __str__(self) -> str: + return f"{self.__class__.__name__}" + + # called from remote + def set_transport(self, stream_name: str, transport: Transport) -> bool: # type: ignore[type-arg] + stream = getattr(self, stream_name, None) + if not stream: + raise ValueError(f"{stream_name} not found in {self.__class__.__name__}") + + if not isinstance(stream, Out) and not isinstance(stream, In): + raise TypeError(f"Output {stream_name} is not a valid stream") + + stream._transport = transport + return True + + # called from remote + def connect_stream(self, input_name: str, remote_stream: RemoteOut[T]): # type: ignore[no-untyped-def] + input_stream = getattr(self, input_name, None) + if not input_stream: + raise ValueError(f"{input_name} not found in {self.__class__.__name__}") + if not isinstance(input_stream, In): + raise TypeError(f"Input {input_name} is not a valid stream") + input_stream.connection = remote_stream + + def dask_receive_msg(self, input_name: str, msg: Any) -> None: + getattr(self, input_name).transport.dask_receive_msg(msg) + + def dask_register_subscriber(self, output_name: str, subscriber: RemoteIn[T]) -> None: + getattr(self, output_name).transport.dask_register_subscriber(subscriber) + + +# global setting +Module = DaskModule diff --git a/dimos/core/module_coordinator.py b/dimos/core/module_coordinator.py new file mode 100644 index 0000000000..450900359a --- /dev/null +++ b/dimos/core/module_coordinator.py @@ -0,0 +1,170 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from typing import TypeVar + +from dimos import core +from dimos.core import DimosCluster, Module +from dimos.core.global_config import GlobalConfig +from dimos.core.resource import Resource + +T = TypeVar("T", bound="Module") + + +class ModuleCoordinator(Resource): + """Orchestrate distributed module lifecycle on a Dask cluster. + + ModuleCoordinator manages cluster startup, module deployment as distributed + actors, and graceful shutdown. It maintains a type-indexed registry where + each module class maps to at most one deployed instance. + + Most users should use `blueprint.build()` instead, which handles cluster + startup, module deployment, and inter-module connections automatically. + Direct use of ModuleCoordinator is for custom configuration or manual + lifecycle control. + + Examples: + Recommended approach using blueprints: + + >>> from dimos.core.blueprints import autoconnect # doctest: +SKIP + >>> blueprint = autoconnect(SomeModule.blueprint(), OtherModule.blueprint()) # doctest: +SKIP + >>> coordinator = blueprint.build() # doctest: +SKIP + >>> coordinator.loop() # doctest: +SKIP + + Direct instantiation for custom control: + + >>> coordinator = ModuleCoordinator(n=4, memory_limit="4GB") # doctest: +SKIP + >>> coordinator.start() # doctest: +SKIP + >>> module = coordinator.deploy(MyModule) # doctest: +SKIP + >>> coordinator.start_all_modules() # doctest: +SKIP + >>> coordinator.loop() # doctest: +SKIP + """ + + _client: DimosCluster | None = None + _n: int | None = None + _memory_limit: str = "auto" + _deployed_modules: dict[type[Module], Module] = {} + + def __init__( + self, + n: int | None = None, + memory_limit: str = "auto", + global_config: GlobalConfig | None = None, + ) -> None: + """Initialize coordinator with cluster configuration. + + Args: + n: Number of Dask worker processes. Falls back to + `global_config.n_dask_workers` if None. + memory_limit: Memory limit per worker (e.g., "4GB", "500MB", "auto"). + global_config: System-wide settings. If None, uses defaults. + """ + cfg = global_config or GlobalConfig() + self._n = n if n is not None else cfg.n_dask_workers + self._memory_limit = memory_limit + + def start(self) -> None: + """Low-level API: Start the underlying Dask cluster. + + Spawns worker processes and initializes the distributed actor system. + After this returns, modules can be deployed via `deploy()`. + + Note: + Calling `start()` on an already-started coordinator overwrites the + existing cluster reference. Use `stop()` first if restarting. + """ + self._client = core.start(self._n, self._memory_limit) + + def stop(self) -> None: + """Shut down all modules and the cluster. + + Stops modules in reverse deployment order, then closes the Dask cluster. + + Note: + Raises `AttributeError` if called before `start()`. For automatic + lifecycle management, use `loop()` instead. + """ + for module in reversed(self._deployed_modules.values()): + module.stop() + + self._client.close_all() # type: ignore[union-attr] + + def deploy(self, module_class: type[T], *args, **kwargs) -> T: # type: ignore[no-untyped-def] + """Low-level API: Deploy a module class as a distributed actor. + + Creates an instance of `module_class` on a Dask worker and returns an + RPC proxy for remote method calls. The coordinator tracks one instance + per module class; deploying the same class again overwrites the registry + entry (the original actor remains in the cluster but becomes untracked). + + Args: + module_class: The Module subclass to deploy. + *args: Positional arguments for `module_class.__init__`. + **kwargs: Keyword arguments for `module_class.__init__`. + + Returns: + RPCClient proxy to the deployed actor. + + Raises: + ValueError: If `start()` has not been called. + """ + if not self._client: + raise ValueError("Not started") + + module = self._client.deploy(module_class, *args, **kwargs) # type: ignore[attr-defined] + self._deployed_modules[module_class] = module + return module # type: ignore[no-any-return] + + def start_all_modules(self) -> None: + """Low-level API: Call `start()` on all deployed modules. + + Initializes each module's RPC server and event loop in deployment order. + After this completes, modules are ready to process messages. + + If a module's `start()` raises, the exception propagates immediately. + Modules started before the failure remain running; no rollback occurs. + """ + for module in self._deployed_modules.values(): + module.start() + + def get_instance(self, module: type[T]) -> T | None: + """Retrieve a deployed module by its class. + + Args: + module: The module class to look up. + + Returns: + The RPCClient proxy if deployed, or None if not found. + """ + return self._deployed_modules.get(module) # type: ignore[return-value] + + def loop(self) -> None: + """Block until interrupted, then shut down gracefully. + + Sleeps indefinitely until Ctrl+C (SIGINT), then calls `stop()` to clean + up all modules and the cluster. This is the standard way to run a + long-lived DimOS application. + + Examples: + >>> coordinator = blueprint.build() # doctest: +SKIP + >>> coordinator.loop() # doctest: +SKIP + """ + try: + while True: + time.sleep(0.1) + except KeyboardInterrupt: + return + finally: + self.stop() diff --git a/dimos/core/o3dpickle.py b/dimos/core/o3dpickle.py new file mode 100644 index 0000000000..1912ab7739 --- /dev/null +++ b/dimos/core/o3dpickle.py @@ -0,0 +1,38 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copyreg + +import numpy as np +import open3d as o3d # type: ignore[import-untyped] + + +def reduce_external(obj): # type: ignore[no-untyped-def] + # Convert Vector3dVector to numpy array for pickling + points_array = np.asarray(obj.points) + return (reconstruct_pointcloud, (points_array,)) + + +def reconstruct_pointcloud(points_array): # type: ignore[no-untyped-def] + # Create new PointCloud and assign the points + pc = o3d.geometry.PointCloud() + pc.points = o3d.utility.Vector3dVector(points_array) + return pc + + +def register_picklers() -> None: + # Register for the actual PointCloud class that gets instantiated + # We need to create a dummy PointCloud to get its actual class + _dummy_pc = o3d.geometry.PointCloud() + copyreg.pickle(_dummy_pc.__class__, reduce_external) diff --git a/dimos/core/resource.py b/dimos/core/resource.py new file mode 100644 index 0000000000..21cdec6322 --- /dev/null +++ b/dimos/core/resource.py @@ -0,0 +1,23 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod + + +class Resource(ABC): + @abstractmethod + def start(self) -> None: ... + + @abstractmethod + def stop(self) -> None: ... diff --git a/dimos/core/rpc_client.py b/dimos/core/rpc_client.py new file mode 100644 index 0000000000..a3d1a2da0c --- /dev/null +++ b/dimos/core/rpc_client.py @@ -0,0 +1,141 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable +from typing import Any + +from dimos.protocol.rpc import LCMRPC +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class RpcCall: + _original_method: Callable[..., Any] | None + _rpc: LCMRPC | None + _name: str + _remote_name: str + _unsub_fns: list # type: ignore[type-arg] + _stop_rpc_client: Callable[[], None] | None = None + + def __init__( + self, + original_method: Callable[..., Any] | None, + rpc: LCMRPC, + name: str, + remote_name: str, + unsub_fns: list, # type: ignore[type-arg] + stop_client: Callable[[], None] | None = None, + ) -> None: + self._original_method = original_method + self._rpc = rpc + self._name = name + self._remote_name = remote_name + self._unsub_fns = unsub_fns + self._stop_rpc_client = stop_client + + if original_method: + self.__doc__ = original_method.__doc__ + self.__name__ = original_method.__name__ + self.__qualname__ = f"{self.__class__.__name__}.{original_method.__name__}" + + def set_rpc(self, rpc: LCMRPC) -> None: + self._rpc = rpc + + def __call__(self, *args, **kwargs): # type: ignore[no-untyped-def] + if not self._rpc: + logger.warning("RPC client not initialized") + return None + + # For stop, use call_nowait to avoid deadlock + # (the remote side stops its RPC service before responding) + if self._name == "stop": + self._rpc.call_nowait(f"{self._remote_name}/{self._name}", (args, kwargs)) # type: ignore[arg-type] + if self._stop_rpc_client: + self._stop_rpc_client() + return None + + result, unsub_fn = self._rpc.call_sync(f"{self._remote_name}/{self._name}", (args, kwargs)) # type: ignore[arg-type] + self._unsub_fns.append(unsub_fn) + return result + + def __getstate__(self): # type: ignore[no-untyped-def] + return (self._original_method, self._name, self._remote_name) + + def __setstate__(self, state) -> None: # type: ignore[no-untyped-def] + self._original_method, self._name, self._remote_name = state + self._unsub_fns = [] + self._rpc = None + self._stop_rpc_client = None + + +class RPCClient: + def __init__(self, actor_instance, actor_class) -> None: # type: ignore[no-untyped-def] + self.rpc = LCMRPC() + self.actor_class = actor_class + self.remote_name = actor_class.__name__ + self.actor_instance = actor_instance + self.rpcs = actor_class.rpcs.keys() + self.rpc.start() + self._unsub_fns = [] # type: ignore[var-annotated] + + def stop_rpc_client(self) -> None: + for unsub in self._unsub_fns: + try: + unsub() + except Exception: + pass + + self._unsub_fns = [] + + if self.rpc: + self.rpc.stop() + self.rpc = None # type: ignore[assignment] + + def __reduce__(self): # type: ignore[no-untyped-def] + # Return the class and the arguments needed to reconstruct the object + return ( + self.__class__, + (self.actor_instance, self.actor_class), + ) + + # passthrough + def __getattr__(self, name: str): # type: ignore[no-untyped-def] + # Check if accessing a known safe attribute to avoid recursion + if name in { + "__class__", + "__init__", + "__dict__", + "__getattr__", + "rpcs", + "remote_name", + "remote_instance", + "actor_instance", + }: + raise AttributeError(f"{name} is not found.") + + if name in self.rpcs: + original_method = getattr(self.actor_class, name, None) + return RpcCall( + original_method, + self.rpc, + name, + self.remote_name, + self._unsub_fns, + self.stop_rpc_client, + ) + + # return super().__getattr__(name) + # Try to avoid recursion by directly accessing attributes that are known + return self.actor_instance.__getattr__(name) diff --git a/dimos/core/skill_module.py b/dimos/core/skill_module.py new file mode 100644 index 0000000000..212d7bbb99 --- /dev/null +++ b/dimos/core/skill_module.py @@ -0,0 +1,32 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.core.module import Module +from dimos.core.rpc_client import RpcCall, RPCClient +from dimos.protocol.skill.skill import rpc + + +class SkillModule(Module): + """Use this module if you want to auto-register skills to an LlmAgent.""" + + @rpc + def set_LlmAgent_register_skills(self, callable: RpcCall) -> None: + callable.set_rpc(self.rpc) # type: ignore[arg-type] + callable(RPCClient(self, self.__class__)) + + def __getstate__(self) -> None: + pass + + def __setstate__(self, _state) -> None: # type: ignore[no-untyped-def] + pass diff --git a/dimos/core/stream.py b/dimos/core/stream.py new file mode 100644 index 0000000000..77057a8cbd --- /dev/null +++ b/dimos/core/stream.py @@ -0,0 +1,347 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import enum +from typing import ( + TYPE_CHECKING, + Annotated, + Any, + Generic, + TypeVar, +) + +from annotated_doc import Doc +from dask.distributed import Actor +import reactivex as rx +from reactivex import operators as ops +from reactivex.disposable import Disposable + +import dimos.core.colors as colors +from dimos.utils.logging_config import setup_logger +import dimos.utils.reactive as reactive +from dimos.utils.reactive import backpressure + +if TYPE_CHECKING: + from collections.abc import Callable + + from reactivex.observable import Observable + +T = TypeVar("T") + + +logger = setup_logger() + + +class ObservableMixin(Generic[T]): + # subscribes and returns the first value it receives + # might be nicer to write without rxpy but had this snippet ready + def get_next(self, timeout: float = 10.0) -> T: + try: + return ( # type: ignore[no-any-return] + self.observable() # type: ignore[no-untyped-call] + .pipe(ops.first(), *([ops.timeout(timeout)] if timeout is not None else [])) + .run() + ) + except Exception as e: + raise Exception(f"No value received after {timeout} seconds") from e + + def hot_latest(self) -> Callable[[], T]: + return reactive.getter_streaming(self.observable()) # type: ignore[no-untyped-call] + + def pure_observable(self) -> Observable[T]: + def _subscribe(observer, scheduler=None): # type: ignore[no-untyped-def] + unsubscribe = self.subscribe(observer.on_next) # type: ignore[attr-defined] + return Disposable(unsubscribe) + + return rx.create(_subscribe) + + # default return is backpressured because most + # use cases will want this by default + def observable(self): # type: ignore[no-untyped-def] + return backpressure(self.pure_observable()) + + +class State(enum.Enum): + UNBOUND = "unbound" # descriptor defined but not bound + READY = "ready" # bound to owner but not yet connected + CONNECTED = "connected" # input bound to an output + FLOWING = "flowing" # runtime: data observed + + +# TODO: Would be better to auto-generate the class inheritance diagram -- easy for this to get out of sync +class Transport(ObservableMixin[T]): + """Abstraction layer for message passing between modules. + + Transport decouples module communication from the underlying protocol (e.g. LCM, + SharedMemory, Dask RPC), enabling the same module code to deploy across + single-machine dev, robot+laptop test, or distributed cluster environments. + Transport selection happens at blueprint configuration time, not in module code. + + Available concrete implementations: + + Class Inheritance: + Transport[T] + └── PubSubTransport[T] (topic-based publish-subscribe) + ├── LCMTransport (network, typed) + │ └── JpegLcmTransport + ├── pLCMTransport (network, pickled) + ├── SHMTransport (local, typed) + ├── pSHMTransport (local, pickled) + ├── JpegShmTransport (local, JPEG)* + └── ZenohTransport (network, distributed) + + * JpegShmTransport extends PubSubTransport directly (not SHMTransport) + + Functional groups: + Network-capable: LCM*, pLCM*, JpegLcm*, Zenoh* + Local-only: SHM*, pSHM*, JpegShm* + + Requirements for concrete Transport implementations: + - **Serialization**: Must be serializable (via ``__reduce__``) for distributed + deployment across process boundaries. + - **Lazy Initialization**: Resources (sockets, shared memory, threads) + allocated on first broadcast() or subscribe(). + + Error handling contract: + - Errors in subscriber callbacks must not prevent delivery to other subscribers. + - Errors must not propagate to the publisher. + - Transport remains operational after subscriber errors. + + Backpressure strategy (via observable()): + Uses latest-value semantics optimized for robotics where current sensor + state matters more than historical completeness. Slow consumers skip + intermediate values; producers are never blocked. + + See also: + docs/concepts/transport.md: Guide to the Transport concept. + """ + + def broadcast( + self, + selfstream: Annotated[ + Out[T] | None, + Doc( + """The originating output stream. Provides routing context + and debugging. Implementations may ignore this if not needed.""" + ), + ], + value: Annotated[T, Doc("The message to deliver to all subscribers.")], + ) -> None: + """Deliver a value to all subscribers.""" + ... + + def publish(self, msg: T) -> None: + """Broadcast a message without stream context.""" + self.broadcast(None, msg) # type: ignore[arg-type] + + # used by local Input + def subscribe( + self, + selfstream: Annotated[ + In[T] | None, + Doc( + """The subscribing input stream. Provides routing context and debugging. + Implementations may ignore this parameter.""" + ), + ], + callback: Annotated[ + Callable[[T], Any], + Doc( + """Invoked for each received value. Errors in the callback + do not affect other subscribers or the publisher.""" + ), + ], + ) -> Annotated[ + Callable[[], None] | None, + Doc("Unsubscribe function, or None (implementation-specific)."), + ]: + """Register a callback to receive broadcasted values. + + Subscriptions persist until explicitly removed or the transport is destroyed. + """ + ... + + +class Stream(Generic[T]): + _transport: Transport | None # type: ignore[type-arg] + + def __init__( + self, + type: type[T], + name: str, + owner: Any | None = None, + transport: Transport | None = None, # type: ignore[type-arg] + ) -> None: + self.name = name + self.owner = owner + self.type = type + if transport: + self._transport = transport + if not hasattr(self, "_transport"): + self._transport = None + + @property + def type_name(self) -> str: + return getattr(self.type, "__name__", repr(self.type)) + + def _color_fn(self) -> Callable[[str], str]: + if self.state == State.UNBOUND: # type: ignore[attr-defined] + return colors.orange + if self.state == State.READY: # type: ignore[attr-defined] + return colors.blue + if self.state == State.CONNECTED: # type: ignore[attr-defined] + return colors.green + return lambda s: s + + def __str__(self) -> str: + return ( + self.__class__.__name__ + + " " + + self._color_fn()(f"{self.name}[{self.type_name}]") + + " @ " + + ( + colors.orange(self.owner) # type: ignore[arg-type] + if isinstance(self.owner, Actor) + else colors.green(self.owner) # type: ignore[arg-type] + ) + + ("" if not self._transport else " via " + str(self._transport)) + ) + + +class Out(Stream[T]): + _transport: Transport # type: ignore[type-arg] + + def __init__(self, *argv, **kwargs) -> None: # type: ignore[no-untyped-def] + super().__init__(*argv, **kwargs) + + @property + def transport(self) -> Transport[T]: + return self._transport + + @transport.setter + def transport(self, value: Transport[T]) -> None: + # just for type checking + ... + + @property + def state(self) -> State: + return State.UNBOUND if self.owner is None else State.READY + + def __reduce__(self): # type: ignore[no-untyped-def] + if self.owner is None or not hasattr(self.owner, "ref"): + raise ValueError("Cannot serialise Out without an owner ref") + return ( + RemoteOut, + ( + self.type, + self.name, + self.owner.ref, + self._transport, + ), + ) + + def publish(self, msg) -> None: # type: ignore[no-untyped-def] + if not hasattr(self, "_transport") or self._transport is None: + logger.warning(f"Trying to publish on Out {self} without a transport") + return + self._transport.broadcast(self, msg) + + +class RemoteStream(Stream[T]): + @property + def state(self) -> State: + return State.UNBOUND if self.owner is None else State.READY + + # this won't work but nvm + @property + def transport(self) -> Transport[T]: + return self._transport # type: ignore[return-value] + + @transport.setter + def transport(self, value: Transport[T]) -> None: + self.owner.set_transport(self.name, value).result() # type: ignore[union-attr] + self._transport = value + + +class RemoteOut(RemoteStream[T]): + def connect(self, other: RemoteIn[T]): # type: ignore[no-untyped-def] + return other.connect(self) + + def subscribe(self, cb) -> Callable[[], None]: # type: ignore[no-untyped-def] + return self.transport.subscribe(cb, self) # type: ignore[arg-type, func-returns-value, no-any-return] + + +# representation of Input +# as views from inside of the module +class In(Stream[T], ObservableMixin[T]): + connection: RemoteOut[T] | None = None + _transport: Transport # type: ignore[type-arg] + + def __str__(self) -> str: + mystr = super().__str__() + + if not self.connection: + return mystr + + return (mystr + " ◀─").ljust(60, "─") + f" {self.connection}" + + def __reduce__(self): # type: ignore[no-untyped-def] + if self.owner is None or not hasattr(self.owner, "ref"): + raise ValueError("Cannot serialise Out without an owner ref") + return (RemoteIn, (self.type, self.name, self.owner.ref, self._transport)) + + @property + def transport(self) -> Transport[T]: + if not self._transport: + self._transport = self.connection.transport # type: ignore[union-attr] + return self._transport + + @transport.setter + def transport(self, value: Transport[T]) -> None: + # just for type checking + ... + + def connect(self, value: Out[T]) -> None: + # just for type checking + ... + + @property + def state(self) -> State: + return State.UNBOUND if self.owner is None else State.READY + + # returns unsubscribe function + def subscribe(self, cb) -> Callable[[], None]: # type: ignore[no-untyped-def] + return self.transport.subscribe(cb, self) # type: ignore[arg-type, func-returns-value, no-any-return] + + +# representation of input outside of module +# used for configuring connections, setting a transport +class RemoteIn(RemoteStream[T]): + def connect(self, other: RemoteOut[T]) -> None: + return self.owner.connect_stream(self.name, other).result() # type: ignore[no-any-return, union-attr] + + # this won't work but that's ok + @property # type: ignore[misc] + def transport(self) -> Transport[T]: + return self._transport # type: ignore[return-value] + + def publish(self, msg) -> None: # type: ignore[no-untyped-def] + self.transport.broadcast(self, msg) # type: ignore[arg-type] + + @transport.setter # type: ignore[attr-defined, misc, no-redef] + def transport(self, value: Transport[T]) -> None: + self.owner.set_transport(self.name, value).result() # type: ignore[union-attr] + self._transport = value diff --git a/dimos/core/test_blueprints.py b/dimos/core/test_blueprints.py new file mode 100644 index 0000000000..b46c3f593b --- /dev/null +++ b/dimos/core/test_blueprints.py @@ -0,0 +1,371 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from dimos.core._test_future_annotations_helper import ( + FutureData, + FutureModuleIn, + FutureModuleOut, +) +from dimos.core.blueprints import ( + ModuleBlueprint, + ModuleBlueprintSet, + ModuleConnection, + _make_module_blueprint, + autoconnect, +) +from dimos.core.core import rpc +from dimos.core.global_config import GlobalConfig +from dimos.core.module import Module +from dimos.core.module_coordinator import ModuleCoordinator +from dimos.core.rpc_client import RpcCall +from dimos.core.stream import In, Out +from dimos.core.transport import LCMTransport +from dimos.protocol import pubsub + + +class Scratch: + pass + + +class Petting: + pass + + +class CatModule(Module): + pet_cat: In[Petting] + scratches: Out[Scratch] + + +class Data1: + pass + + +class Data2: + pass + + +class Data3: + pass + + +class ModuleA(Module): + data1: Out[Data1] = None + data2: Out[Data2] = None + + @rpc + def get_name(self) -> str: + return "A, Module A" + + +class ModuleB(Module): + data1: In[Data1] = None + data2: In[Data2] = None + data3: Out[Data3] = None + + _module_a_get_name: callable = None + + @rpc + def set_ModuleA_get_name(self, callable: RpcCall) -> None: + self._module_a_get_name = callable + self._module_a_get_name.set_rpc(self.rpc) + + @rpc + def what_is_as_name(self) -> str: + if self._module_a_get_name is None: + return "ModuleA.get_name not set" + return self._module_a_get_name() + + +class ModuleC(Module): + data3: In[Data3] = None + + +module_a = ModuleA.blueprint +module_b = ModuleB.blueprint +module_c = ModuleC.blueprint + + +def test_get_connection_set() -> None: + assert _make_module_blueprint(CatModule, args=("arg1"), kwargs={"k": "v"}) == ModuleBlueprint( + module=CatModule, + connections=( + ModuleConnection(name="pet_cat", type=Petting, direction="in"), + ModuleConnection(name="scratches", type=Scratch, direction="out"), + ), + args=("arg1"), + kwargs={"k": "v"}, + ) + + +def test_autoconnect() -> None: + blueprint_set = autoconnect(module_a(), module_b()) + + assert blueprint_set == ModuleBlueprintSet( + blueprints=( + ModuleBlueprint( + module=ModuleA, + connections=( + ModuleConnection(name="data1", type=Data1, direction="out"), + ModuleConnection(name="data2", type=Data2, direction="out"), + ), + args=(), + kwargs={}, + ), + ModuleBlueprint( + module=ModuleB, + connections=( + ModuleConnection(name="data1", type=Data1, direction="in"), + ModuleConnection(name="data2", type=Data2, direction="in"), + ModuleConnection(name="data3", type=Data3, direction="out"), + ), + args=(), + kwargs={}, + ), + ) + ) + + +def test_transports() -> None: + custom_transport = LCMTransport("/custom_topic", Data1) + blueprint_set = autoconnect(module_a(), module_b()).transports( + {("data1", Data1): custom_transport} + ) + + assert ("data1", Data1) in blueprint_set.transport_map + assert blueprint_set.transport_map[("data1", Data1)] == custom_transport + + +def test_global_config() -> None: + blueprint_set = autoconnect(module_a(), module_b()).global_config(option1=True, option2=42) + + assert "option1" in blueprint_set.global_config_overrides + assert blueprint_set.global_config_overrides["option1"] is True + assert "option2" in blueprint_set.global_config_overrides + assert blueprint_set.global_config_overrides["option2"] == 42 + + +def test_build_happy_path() -> None: + pubsub.lcm.autoconf() + + blueprint_set = autoconnect(module_a(), module_b(), module_c()) + + coordinator = blueprint_set.build(GlobalConfig()) + + try: + assert isinstance(coordinator, ModuleCoordinator) + + module_a_instance = coordinator.get_instance(ModuleA) + module_b_instance = coordinator.get_instance(ModuleB) + module_c_instance = coordinator.get_instance(ModuleC) + + assert module_a_instance is not None + assert module_b_instance is not None + assert module_c_instance is not None + + assert module_a_instance.data1.transport is not None + assert module_a_instance.data2.transport is not None + assert module_b_instance.data1.transport is not None + assert module_b_instance.data2.transport is not None + assert module_b_instance.data3.transport is not None + assert module_c_instance.data3.transport is not None + + assert module_a_instance.data1.transport.topic == module_b_instance.data1.transport.topic + assert module_a_instance.data2.transport.topic == module_b_instance.data2.transport.topic + assert module_b_instance.data3.transport.topic == module_c_instance.data3.transport.topic + + assert module_b_instance.what_is_as_name() == "A, Module A" + + finally: + coordinator.stop() + + +def test_name_conflicts_are_reported() -> None: + class ModuleA(Module): + shared_data: Out[Data1] = None + + class ModuleB(Module): + shared_data: In[Data2] = None + + blueprint_set = autoconnect(ModuleA.blueprint(), ModuleB.blueprint()) + + try: + blueprint_set._verify_no_name_conflicts() + pytest.fail("Expected ValueError to be raised") + except ValueError as e: + error_message = str(e) + assert "Blueprint cannot start because there are conflicting connections" in error_message + assert "'shared_data' has conflicting types" in error_message + assert "Data1 in ModuleA" in error_message + assert "Data2 in ModuleB" in error_message + + +def test_multiple_name_conflicts_are_reported() -> None: + class Module1(Module): + sensor_data: Out[Data1] = None + control_signal: Out[Data2] = None + + class Module2(Module): + sensor_data: In[Data2] = None + control_signal: In[Data3] = None + + blueprint_set = autoconnect(Module1.blueprint(), Module2.blueprint()) + + try: + blueprint_set._verify_no_name_conflicts() + pytest.fail("Expected ValueError to be raised") + except ValueError as e: + error_message = str(e) + assert "Blueprint cannot start because there are conflicting connections" in error_message + assert "'sensor_data' has conflicting types" in error_message + assert "'control_signal' has conflicting types" in error_message + + +def test_that_remapping_can_resolve_conflicts() -> None: + class Module1(Module): + data: Out[Data1] = None + + class Module2(Module): + data: Out[Data2] = None # Would conflict with Module1.data + + class Module3(Module): + data1: In[Data1] = None + data2: In[Data2] = None + + # Without remapping, should raise conflict error + blueprint_set = autoconnect(Module1.blueprint(), Module2.blueprint(), Module3.blueprint()) + + try: + blueprint_set._verify_no_name_conflicts() + pytest.fail("Expected ValueError due to conflict") + except ValueError as e: + assert "'data' has conflicting types" in str(e) + + # With remapping to resolve the conflict + blueprint_set_remapped = autoconnect( + Module1.blueprint(), Module2.blueprint(), Module3.blueprint() + ).remappings( + [ + (Module1, "data", "data1"), + (Module2, "data", "data2"), + ] + ) + + # Should not raise any exception after remapping + blueprint_set_remapped._verify_no_name_conflicts() + + +def test_remapping() -> None: + """Test that remapping connections works correctly.""" + pubsub.lcm.autoconf() + + # Define test modules with connections that will be remapped + class SourceModule(Module): + color_image: Out[Data1] = None # Will be remapped to 'remapped_data' + + class TargetModule(Module): + remapped_data: In[Data1] = None # Receives the remapped connection + + # Create blueprint with remapping + blueprint_set = autoconnect( + SourceModule.blueprint(), + TargetModule.blueprint(), + ).remappings( + [ + (SourceModule, "color_image", "remapped_data"), + ] + ) + + # Verify remappings are stored correctly + assert (SourceModule, "color_image") in blueprint_set.remapping_map + assert blueprint_set.remapping_map[(SourceModule, "color_image")] == "remapped_data" + + # Verify that remapped names are used in name resolution + assert ("remapped_data", Data1) in blueprint_set._all_name_types + # The original name shouldn't be in the name types since it's remapped + assert ("color_image", Data1) not in blueprint_set._all_name_types + + # Build and verify connections work + coordinator = blueprint_set.build(GlobalConfig()) + + try: + source_instance = coordinator.get_instance(SourceModule) + target_instance = coordinator.get_instance(TargetModule) + + assert source_instance is not None + assert target_instance is not None + + # Both should have transports set + assert source_instance.color_image.transport is not None + assert target_instance.remapped_data.transport is not None + + # They should be using the same transport (connected) + assert ( + source_instance.color_image.transport.topic + == target_instance.remapped_data.transport.topic + ) + + # The topic should be /remapped_data since that's the remapped name + assert target_instance.remapped_data.transport.topic == "/remapped_data" + + finally: + coordinator.stop() + + +def test_future_annotations_support() -> None: + """Test that modules using `from __future__ import annotations` work correctly. + + PEP 563 (future annotations) stores annotations as strings instead of actual types. + This test verifies that _make_module_blueprint properly resolves string annotations + to the actual In/Out types. + """ + + # Test that connections are properly extracted from modules with future annotations + out_blueprint = _make_module_blueprint(FutureModuleOut, args=(), kwargs={}) + assert len(out_blueprint.connections) == 1 + assert out_blueprint.connections[0] == ModuleConnection( + name="data", type=FutureData, direction="out" + ) + + in_blueprint = _make_module_blueprint(FutureModuleIn, args=(), kwargs={}) + assert len(in_blueprint.connections) == 1 + assert in_blueprint.connections[0] == ModuleConnection( + name="data", type=FutureData, direction="in" + ) + + +def test_future_annotations_autoconnect() -> None: + """Test that autoconnect works with modules using `from __future__ import annotations`.""" + + blueprint_set = autoconnect(FutureModuleOut.blueprint(), FutureModuleIn.blueprint()) + + coordinator = blueprint_set.build(GlobalConfig()) + + try: + out_instance = coordinator.get_instance(FutureModuleOut) + in_instance = coordinator.get_instance(FutureModuleIn) + + assert out_instance is not None + assert in_instance is not None + + # Both should have transports set + assert out_instance.data.transport is not None + assert in_instance.data.transport is not None + + # They should be connected via the same transport + assert out_instance.data.transport.topic == in_instance.data.transport.topic + + finally: + coordinator.stop() diff --git a/dimos/core/test_core.py b/dimos/core/test_core.py new file mode 100644 index 0000000000..561fb5e9ec --- /dev/null +++ b/dimos/core/test_core.py @@ -0,0 +1,145 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import pytest +from reactivex.disposable import Disposable + +from dimos.core import ( + In, + LCMTransport, + Module, + Out, + pLCMTransport, + rpc, + start, +) +from dimos.core.testing import MockRobotClient, dimos +from dimos.msgs.geometry_msgs import Vector3 +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.odometry import Odometry + +assert dimos + + +class Navigation(Module): + mov: Out[Vector3] = None + lidar: In[LidarMessage] = None + target_position: In[Vector3] = None + odometry: In[Odometry] = None + + odom_msg_count = 0 + lidar_msg_count = 0 + + @rpc + def navigate_to(self, target: Vector3) -> bool: ... + + def __init__(self) -> None: + super().__init__() + + @rpc + def start(self) -> None: + def _odom(msg) -> None: + self.odom_msg_count += 1 + print("RCV:", (time.perf_counter() - msg.pubtime) * 1000, msg) + self.mov.publish(msg.position) + + unsub = self.odometry.subscribe(_odom) + self._disposables.add(Disposable(unsub)) + + def _lidar(msg) -> None: + self.lidar_msg_count += 1 + if hasattr(msg, "pubtime"): + print("RCV:", (time.perf_counter() - msg.pubtime) * 1000, msg) + else: + print("RCV: unknown time", msg) + + unsub = self.lidar.subscribe(_lidar) + self._disposables.add(Disposable(unsub)) + + +def test_classmethods() -> None: + # Test class property access + class_rpcs = Navigation.rpcs + print("Class rpcs:", class_rpcs) + # Test instance property access + nav = Navigation() + instance_rpcs = nav.rpcs + print("Instance rpcs:", instance_rpcs) + + # Assertions + assert isinstance(class_rpcs, dict), "Class rpcs should be a dictionary" + assert isinstance(instance_rpcs, dict), "Instance rpcs should be a dictionary" + assert class_rpcs == instance_rpcs, "Class and instance rpcs should be identical" + + # Check that we have the expected RPC methods + assert "navigate_to" in class_rpcs, "navigate_to should be in rpcs" + assert "start" in class_rpcs, "start should be in rpcs" + assert len(class_rpcs) == 8 + + # Check that the values are callable + assert callable(class_rpcs["navigate_to"]), "navigate_to should be callable" + assert callable(class_rpcs["start"]), "start should be callable" + + # Check that they have the __rpc__ attribute + assert hasattr(class_rpcs["navigate_to"], "__rpc__"), ( + "navigate_to should have __rpc__ attribute" + ) + assert hasattr(class_rpcs["start"], "__rpc__"), "start should have __rpc__ attribute" + + nav._close_module() + + +@pytest.mark.module +def test_basic_deployment(dimos) -> None: + robot = dimos.deploy(MockRobotClient) + + print("\n") + print("lidar stream", robot.lidar) + print("odom stream", robot.odometry) + + nav = dimos.deploy(Navigation) + + # this one encodes proper LCM messages + robot.lidar.transport = LCMTransport("/lidar", LidarMessage) + + # odometry & mov using just a pickle over LCM + robot.odometry.transport = pLCMTransport("/odom") + nav.mov.transport = pLCMTransport("/mov") + + nav.lidar.connect(robot.lidar) + nav.odometry.connect(robot.odometry) + robot.mov.connect(nav.mov) + + robot.start() + nav.start() + + time.sleep(1) + robot.stop() + + print("robot.mov_msg_count", robot.mov_msg_count) + print("nav.odom_msg_count", nav.odom_msg_count) + print("nav.lidar_msg_count", nav.lidar_msg_count) + + assert robot.mov_msg_count >= 8 + assert nav.odom_msg_count >= 8 + assert nav.lidar_msg_count >= 8 + + dimos.shutdown() + + +if __name__ == "__main__": + client = start(1) # single process for CI memory + test_deployment(client) diff --git a/dimos/core/test_modules.py b/dimos/core/test_modules.py new file mode 100644 index 0000000000..7bd995c857 --- /dev/null +++ b/dimos/core/test_modules.py @@ -0,0 +1,334 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test that all Module subclasses implement required resource management methods.""" + +import ast +import inspect +from pathlib import Path + +import pytest + +from dimos.core.module import Module + + +class ModuleVisitor(ast.NodeVisitor): + """AST visitor to find classes and their base classes.""" + + def __init__(self, filepath: str) -> None: + self.filepath = filepath + self.classes: list[ + tuple[str, list[str], set[str]] + ] = [] # (class_name, base_classes, methods) + + def visit_ClassDef(self, node: ast.ClassDef) -> None: + """Visit a class definition.""" + # Get base class names + base_classes = [] + for base in node.bases: + if isinstance(base, ast.Name): + base_classes.append(base.id) + elif isinstance(base, ast.Attribute): + # Handle cases like dimos.core.Module + parts = [] + current = base + while isinstance(current, ast.Attribute): + parts.append(current.attr) + current = current.value + if isinstance(current, ast.Name): + parts.append(current.id) + base_classes.append(".".join(reversed(parts))) + + # Get method names defined in this class + methods = set() + for item in node.body: + if isinstance(item, ast.FunctionDef): + methods.add(item.name) + + self.classes.append((node.name, base_classes, methods)) + self.generic_visit(node) + + +def get_import_aliases(tree: ast.AST) -> dict[str, str]: + """Extract import aliases from the AST.""" + aliases = {} + + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + key = alias.asname if alias.asname else alias.name + aliases[key] = alias.name + elif isinstance(node, ast.ImportFrom): + module = node.module or "" + for alias in node.names: + key = alias.asname if alias.asname else alias.name + full_name = f"{module}.{alias.name}" if module else alias.name + aliases[key] = full_name + + return aliases + + +def is_module_subclass( + base_classes: list[str], + aliases: dict[str, str], + class_hierarchy: dict[str, list[str]] | None = None, + current_module_path: str | None = None, +) -> bool: + """Check if any base class is or resolves to dimos.core.Module or its variants (recursively).""" + target_classes = { + "Module", + "ModuleBase", + "DaskModule", + "dimos.core.Module", + "dimos.core.ModuleBase", + "dimos.core.DaskModule", + "dimos.core.module.Module", + "dimos.core.module.ModuleBase", + "dimos.core.module.DaskModule", + } + + def find_qualified_name(base: str, context_module: str | None = None) -> str: + """Find the qualified name for a base class, using import context if available.""" + if not class_hierarchy: + return base + + # First try exact match (already fully qualified or in hierarchy) + if base in class_hierarchy: + return base + + # Check if it's in our aliases (from imports) + if base in aliases: + resolved = aliases[base] + if resolved in class_hierarchy: + return resolved + # The resolved name might be a qualified name that exists + return resolved + + # If we have a context module and base is a simple name, + # try to find it in the same module first (for local classes) + if context_module and "." not in base: + same_module_qualified = f"{context_module}.{base}" + if same_module_qualified in class_hierarchy: + return same_module_qualified + + # Otherwise return the base as-is + return base + + def check_base( + base: str, visited: set[str] | None = None, context_module: str | None = None + ) -> bool: + if visited is None: + visited = set() + + # Avoid infinite recursion + if base in visited: + return False + visited.add(base) + + # Check direct match + if base in target_classes: + return True + + # Check if it's an alias + if base in aliases: + resolved = aliases[base] + if resolved in target_classes: + return True + # Continue checking with resolved name + base = resolved + + # If we have a class hierarchy, recursively check parent classes + if class_hierarchy: + # Resolve the base class name to a qualified name + qualified_name = find_qualified_name(base, context_module) + + if qualified_name in class_hierarchy: + # Check all parent classes + for parent_base in class_hierarchy[qualified_name]: + if check_base(parent_base, visited, None): # Parent lookups don't use context + return True + + return False + + for base in base_classes: + if check_base(base, context_module=current_module_path): + return True + + return False + + +def scan_file( + filepath: Path, + class_hierarchy: dict[str, list[str]] | None = None, + root_path: Path | None = None, +) -> list[tuple[str, str, bool, bool, set[str]]]: + """ + Scan a Python file for Module subclasses. + + Returns: + List of (class_name, filepath, has_start, has_stop, forbidden_methods) + """ + forbidden_method_names = {"acquire", "release", "open", "close", "shutdown", "clean", "cleanup"} + + try: + with open(filepath, encoding="utf-8") as f: + content = f.read() + + tree = ast.parse(content, filename=str(filepath)) + aliases = get_import_aliases(tree) + + visitor = ModuleVisitor(str(filepath)) + visitor.visit(tree) + + # Get module path for this file to properly resolve base classes + current_module_path = None + if root_path: + try: + rel_path = filepath.relative_to(root_path.parent) + module_parts = list(rel_path.parts[:-1]) + if rel_path.stem != "__init__": + module_parts.append(rel_path.stem) + current_module_path = ".".join(module_parts) + except ValueError: + pass + + results = [] + for class_name, base_classes, methods in visitor.classes: + if is_module_subclass(base_classes, aliases, class_hierarchy, current_module_path): + has_start = "start" in methods + has_stop = "stop" in methods + forbidden_found = methods & forbidden_method_names + results.append((class_name, str(filepath), has_start, has_stop, forbidden_found)) + + return results + + except (SyntaxError, UnicodeDecodeError): + # Skip files that can't be parsed + return [] + + +def build_class_hierarchy(root_path: Path) -> dict[str, list[str]]: + """Build a complete class hierarchy by scanning all Python files.""" + hierarchy = {} + + for filepath in sorted(root_path.rglob("*.py")): + # Skip __pycache__ and other irrelevant directories + if "__pycache__" in filepath.parts or ".venv" in filepath.parts: + continue + + try: + with open(filepath, encoding="utf-8") as f: + content = f.read() + + tree = ast.parse(content, filename=str(filepath)) + visitor = ModuleVisitor(str(filepath)) + visitor.visit(tree) + + # Convert filepath to module path (e.g., dimos/core/module.py -> dimos.core.module) + try: + rel_path = filepath.relative_to(root_path.parent) + except ValueError: + # If we can't get relative path, skip this file + continue + + # Convert path to module notation + module_parts = list(rel_path.parts[:-1]) # Exclude filename + if rel_path.stem != "__init__": + module_parts.append(rel_path.stem) # Add filename without .py + module_name = ".".join(module_parts) + + for class_name, base_classes, _ in visitor.classes: + # Use fully qualified name as key to avoid conflicts + qualified_name = f"{module_name}.{class_name}" if module_name else class_name + hierarchy[qualified_name] = base_classes + + except (SyntaxError, UnicodeDecodeError): + # Skip files that can't be parsed + continue + + return hierarchy + + +def scan_directory(root_path: Path) -> list[tuple[str, str, bool, bool, set[str]]]: + """Scan all Python files in the directory tree.""" + # First, build the complete class hierarchy + class_hierarchy = build_class_hierarchy(root_path) + + # Then scan for Module subclasses using the complete hierarchy + results = [] + + for filepath in sorted(root_path.rglob("*.py")): + # Skip __pycache__ and other irrelevant directories + if "__pycache__" in filepath.parts or ".venv" in filepath.parts: + continue + + file_results = scan_file(filepath, class_hierarchy, root_path) + results.extend(file_results) + + return results + + +def get_all_module_subclasses(): + """Find all Module subclasses in the dimos codebase.""" + # Get the dimos package directory + dimos_file = inspect.getfile(Module) + dimos_path = Path(dimos_file).parent.parent # Go up from dimos/core/module.py to dimos/ + + results = scan_directory(dimos_path) + + # Filter out test modules and base classes + filtered_results = [] + for class_name, filepath, has_start, has_stop, forbidden_methods in results: + # Skip base module classes themselves + if class_name in ("Module", "ModuleBase", "DaskModule", "SkillModule"): + continue + + # Skip test-only modules (those defined in test_ files) + if "test_" in Path(filepath).name: + continue + + filtered_results.append((class_name, filepath, has_start, has_stop, forbidden_methods)) + + return filtered_results + + +@pytest.mark.parametrize( + "class_name,filepath,has_start,has_stop,forbidden_methods", + get_all_module_subclasses(), + ids=lambda val: val[0] if isinstance(val, str) else str(val), +) +def test_module_has_start_and_stop( + class_name: str, filepath, has_start, has_stop, forbidden_methods +) -> None: + """Test that Module subclasses implement start and stop methods and don't use forbidden methods.""" + # Get relative path for better error messages + try: + rel_path = Path(filepath).relative_to(Path.cwd()) + except ValueError: + rel_path = filepath + + errors = [] + + # Check for missing required methods + if not has_start: + errors.append("missing required method: start") + if not has_stop: + errors.append("missing required method: stop") + + # Check for forbidden methods + if forbidden_methods: + forbidden_list = ", ".join(sorted(forbidden_methods)) + errors.append(f"has forbidden method(s): {forbidden_list}") + + assert not errors, f"{class_name} in {rel_path} has issues:\n - " + "\n - ".join(errors) diff --git a/dimos/core/test_rpcstress.py b/dimos/core/test_rpcstress.py new file mode 100644 index 0000000000..51c1c81e4b --- /dev/null +++ b/dimos/core/test_rpcstress.py @@ -0,0 +1,177 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import threading +import time + +from dimos.core import In, Module, Out, rpc + + +class Counter(Module): + current_count: int = 0 + + count_stream: Out[int] = None + + def __init__(self) -> None: + super().__init__() + self.current_count = 0 + + @rpc + def increment(self): + """Increment the counter and publish the new value.""" + self.current_count += 1 + self.count_stream.publish(self.current_count) + return self.current_count + + +class CounterValidator(Module): + """Calls counter.increment() as fast as possible and validates no numbers are skipped.""" + + count_in: In[int] = None + + def __init__(self, increment_func) -> None: + super().__init__() + self.increment_func = increment_func + self.last_seen = 0 + self.missing_numbers = [] + self.running = False + self.call_thread = None + self.call_count = 0 + self.total_latency = 0.0 + self.call_start_time = None + self.waiting_for_response = False + + @rpc + def start(self) -> None: + """Start the validator.""" + self.count_in.subscribe(self._on_count_received) + self.running = True + self.call_thread = threading.Thread(target=self._call_loop) + self.call_thread.start() + + @rpc + def stop(self) -> None: + """Stop the validator.""" + self.running = False + if self.call_thread: + self.call_thread.join() + + def _on_count_received(self, count: int) -> None: + """Check if we received all numbers in sequence and trigger next call.""" + # Calculate round trip time + if self.call_start_time: + latency = time.time() - self.call_start_time + self.total_latency += latency + + if count != self.last_seen + 1: + for missing in range(self.last_seen + 1, count): + self.missing_numbers.append(missing) + print(f"[VALIDATOR] Missing number detected: {missing}") + self.last_seen = count + + # Signal that we can make the next call + self.waiting_for_response = False + + def _call_loop(self) -> None: + """Call increment only after receiving response from previous call.""" + while self.running: + if not self.waiting_for_response: + try: + self.waiting_for_response = True + self.call_start_time = time.time() + result = self.increment_func() + call_time = time.time() - self.call_start_time + self.call_count += 1 + if self.call_count % 100 == 0: + avg_latency = ( + self.total_latency / self.call_count if self.call_count > 0 else 0 + ) + print( + f"[VALIDATOR] Made {self.call_count} calls, last result: {result}, RPC call time: {call_time * 1000:.2f}ms, avg RTT: {avg_latency * 1000:.2f}ms" + ) + except Exception as e: + print(f"[VALIDATOR] Error calling increment: {e}") + self.waiting_for_response = False + time.sleep(0.001) # Small delay on error + else: + # Don't sleep - busy wait for maximum speed + pass + + @rpc + def get_stats(self): + """Get validation statistics.""" + avg_latency = self.total_latency / self.call_count if self.call_count > 0 else 0 + return { + "call_count": self.call_count, + "last_seen": self.last_seen, + "missing_count": len(self.missing_numbers), + "missing_numbers": self.missing_numbers[:10] if self.missing_numbers else [], + "avg_rtt_ms": avg_latency * 1000, + "calls_per_sec": self.call_count / 10.0 if self.call_count > 0 else 0, + } + + +if __name__ == "__main__": + import dimos.core as core + from dimos.core import pLCMTransport + + # Start dimos with 2 workers + client = core.start(2) + + # Deploy counter module + counter = client.deploy(Counter) + counter.count_stream.transport = pLCMTransport("/counter_stream") + + # Deploy validator module with increment function + validator = client.deploy(CounterValidator, counter.increment) + validator.count_in.transport = pLCMTransport("/counter_stream") + + # Connect validator to counter's output + validator.count_in.connect(counter.count_stream) + + # Start modules + validator.start() + + print("[MAIN] Counter and validator started. Running for 10 seconds...") + + # Test direct RPC speed for comparison + print("\n[MAIN] Testing direct RPC call speed for 1 second...") + start = time.time() + direct_count = 0 + while time.time() - start < 1.0: + counter.increment() + direct_count += 1 + print(f"[MAIN] Direct RPC calls per second: {direct_count}") + + # Run for 10 seconds + time.sleep(10) + + # Get stats before stopping + stats = validator.get_stats() + print("\n[MAIN] Final statistics:") + print(f" - Total calls made: {stats['call_count']}") + print(f" - Last number seen: {stats['last_seen']}") + print(f" - Missing numbers: {stats['missing_count']}") + print(f" - Average RTT: {stats['avg_rtt_ms']:.2f}ms") + print(f" - Calls per second: {stats['calls_per_sec']:.1f}") + if stats["missing_numbers"]: + print(f" - First missing numbers: {stats['missing_numbers']}") + + # Stop modules + validator.stop() + + # Shutdown dimos + client.shutdown() + + print("[MAIN] Test complete.") diff --git a/dimos/core/test_stream.py b/dimos/core/test_stream.py new file mode 100644 index 0000000000..997e7e2cf1 --- /dev/null +++ b/dimos/core/test_stream.py @@ -0,0 +1,256 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable +import time + +import pytest + +from dimos.core import ( + In, + LCMTransport, + Module, + rpc, +) +from dimos.core.testing import MockRobotClient, dimos +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.odometry import Odometry + +assert dimos + + +class SubscriberBase(Module): + sub1_msgs: list[Odometry] = None + sub2_msgs: list[Odometry] = None + + def __init__(self) -> None: + self.sub1_msgs = [] + self.sub2_msgs = [] + super().__init__() + + @rpc + def sub1(self) -> None: ... + + @rpc + def sub2(self) -> None: ... + + @rpc + def active_subscribers(self): + return self.odom.transport.active_subscribers + + @rpc + def sub1_msgs_len(self) -> int: + return len(self.sub1_msgs) + + @rpc + def sub2_msgs_len(self) -> int: + return len(self.sub2_msgs) + + +class ClassicSubscriber(SubscriberBase): + odom: In[Odometry] = None + unsub: Callable[[], None] | None = None + unsub2: Callable[[], None] | None = None + + @rpc + def sub1(self) -> None: + self.unsub = self.odom.subscribe(self.sub1_msgs.append) + + @rpc + def sub2(self) -> None: + self.unsub2 = self.odom.subscribe(self.sub2_msgs.append) + + @rpc + def stop(self) -> None: + if self.unsub: + self.unsub() + self.unsub = None + if self.unsub2: + self.unsub2() + self.unsub2 = None + + +class RXPYSubscriber(SubscriberBase): + odom: In[Odometry] = None + unsub: Callable[[], None] | None = None + unsub2: Callable[[], None] | None = None + + hot: Callable[[], None] | None = None + + @rpc + def sub1(self) -> None: + self.unsub = self.odom.observable().subscribe(self.sub1_msgs.append) + + @rpc + def sub2(self) -> None: + self.unsub2 = self.odom.observable().subscribe(self.sub2_msgs.append) + + @rpc + def stop(self) -> None: + if self.unsub: + self.unsub.dispose() + self.unsub = None + if self.unsub2: + self.unsub2.dispose() + self.unsub2 = None + + @rpc + def get_next(self): + return self.odom.get_next() + + @rpc + def start_hot_getter(self) -> None: + self.hot = self.odom.hot_latest() + + @rpc + def stop_hot_getter(self) -> None: + self.hot.dispose() + + @rpc + def get_hot(self): + return self.hot() + + +class SpyLCMTransport(LCMTransport): + active_subscribers: int = 0 + + def __reduce__(self): + return (SpyLCMTransport, (self.topic.topic, self.topic.lcm_type)) + + def __init__(self, topic: str, type: type, **kwargs) -> None: + super().__init__(topic, type, **kwargs) + self._subscriber_map = {} # Maps unsubscribe functions to track active subs + + def subscribe(self, selfstream: In, callback: Callable) -> Callable[[], None]: + # Call parent subscribe to get the unsubscribe function + unsubscribe_fn = super().subscribe(selfstream, callback) + + # Increment counter + self.active_subscribers += 1 + + def wrapped_unsubscribe() -> None: + # Create wrapper that decrements counter when called + if wrapped_unsubscribe in self._subscriber_map: + self.active_subscribers -= 1 + del self._subscriber_map[wrapped_unsubscribe] + unsubscribe_fn() + + # Track this subscription + self._subscriber_map[wrapped_unsubscribe] = True + + return wrapped_unsubscribe + + +@pytest.mark.parametrize("subscriber_class", [ClassicSubscriber, RXPYSubscriber]) +@pytest.mark.module +def test_subscription(dimos, subscriber_class) -> None: + robot = dimos.deploy(MockRobotClient) + + robot.lidar.transport = SpyLCMTransport("/lidar", LidarMessage) + robot.odometry.transport = SpyLCMTransport("/odom", Odometry) + + subscriber = dimos.deploy(subscriber_class) + + subscriber.odom.connect(robot.odometry) + + robot.start() + subscriber.sub1() + time.sleep(0.25) + + assert subscriber.sub1_msgs_len() > 0 + assert subscriber.sub2_msgs_len() == 0 + assert subscriber.active_subscribers() == 1 + + subscriber.sub2() + + time.sleep(0.25) + subscriber.stop() + + assert subscriber.active_subscribers() == 0 + assert subscriber.sub1_msgs_len() != 0 + assert subscriber.sub2_msgs_len() != 0 + + total_msg_n = subscriber.sub1_msgs_len() + subscriber.sub2_msgs_len() + + time.sleep(0.25) + + # ensuring no new messages have passed through + assert total_msg_n == subscriber.sub1_msgs_len() + subscriber.sub2_msgs_len() + + robot.stop() + + +@pytest.mark.module +def test_get_next(dimos) -> None: + robot = dimos.deploy(MockRobotClient) + + robot.lidar.transport = SpyLCMTransport("/lidar", LidarMessage) + robot.odometry.transport = SpyLCMTransport("/odom", Odometry) + + subscriber = dimos.deploy(RXPYSubscriber) + subscriber.odom.connect(robot.odometry) + + robot.start() + time.sleep(0.1) + + odom = subscriber.get_next() + + assert isinstance(odom, Odometry) + assert subscriber.active_subscribers() == 0 + + time.sleep(0.2) + + next_odom = subscriber.get_next() + + assert isinstance(next_odom, Odometry) + assert subscriber.active_subscribers() == 0 + + assert next_odom != odom + robot.stop() + + +@pytest.mark.module +def test_hot_getter(dimos) -> None: + robot = dimos.deploy(MockRobotClient) + + robot.lidar.transport = SpyLCMTransport("/lidar", LidarMessage) + robot.odometry.transport = SpyLCMTransport("/odom", Odometry) + + subscriber = dimos.deploy(RXPYSubscriber) + subscriber.odom.connect(robot.odometry) + + robot.start() + + # we are robust to multiple calls + subscriber.start_hot_getter() + time.sleep(0.2) + odom = subscriber.get_hot() + subscriber.stop_hot_getter() + + assert isinstance(odom, Odometry) + time.sleep(0.3) + + # there are no subs + assert subscriber.active_subscribers() == 0 + + # we can restart though + subscriber.start_hot_getter() + time.sleep(0.3) + + next_odom = subscriber.get_hot() + assert isinstance(next_odom, Odometry) + assert next_odom != odom + subscriber.stop_hot_getter() + + robot.stop() diff --git a/dimos/core/testing.py b/dimos/core/testing.py new file mode 100644 index 0000000000..b988a43183 --- /dev/null +++ b/dimos/core/testing.py @@ -0,0 +1,83 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from threading import Event, Thread +import time + +import pytest + +from dimos.core import In, Module, Out, rpc, start +from dimos.msgs.geometry_msgs import Vector3 +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.odometry import Odometry +from dimos.utils.testing import SensorReplay + + +@pytest.fixture +def dimos(): # type: ignore[no-untyped-def] + """Fixture to create a Dimos client for testing.""" + client = start(2) + yield client + client.stop() # type: ignore[attr-defined] + + +class MockRobotClient(Module): + odometry: Out[Odometry] = None # type: ignore[assignment] + lidar: Out[LidarMessage] = None # type: ignore[assignment] + mov: In[Vector3] = None # type: ignore[assignment] + + mov_msg_count = 0 + + def mov_callback(self, msg) -> None: # type: ignore[no-untyped-def] + self.mov_msg_count += 1 + + def __init__(self) -> None: + super().__init__() + self._stop_event = Event() + self._thread = None + + @rpc + def start(self) -> None: + super().start() + + self._thread = Thread(target=self.odomloop) # type: ignore[assignment] + self._thread.start() # type: ignore[attr-defined] + self.mov.subscribe(self.mov_callback) + + @rpc + def stop(self) -> None: + self._stop_event.set() + if self._thread and self._thread.is_alive(): + self._thread.join(timeout=1.0) + + super().stop() + + def odomloop(self) -> None: + odomdata = SensorReplay("raw_odometry_rotate_walk", autocast=Odometry.from_msg) + lidardata = SensorReplay("office_lidar", autocast=LidarMessage.from_msg) + + lidariter = lidardata.iterate() + self._stop_event.clear() + while not self._stop_event.is_set(): + for odom in odomdata.iterate(): + if self._stop_event.is_set(): + return + print(odom) + odom.pubtime = time.perf_counter() + self.odometry.publish(odom) + + lidarmsg = next(lidariter) + lidarmsg.pubtime = time.perf_counter() # type: ignore[union-attr] + self.lidar.publish(lidarmsg) + time.sleep(0.1) diff --git a/dimos/core/transport.py b/dimos/core/transport.py new file mode 100644 index 0000000000..e3b1425ced --- /dev/null +++ b/dimos/core/transport.py @@ -0,0 +1,357 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pub/sub transports for streaming data between modules. + +This module provides transport implementations that connect module stream +endpoints (`In[T]` and `Out[T]`). Transports handle underlying protocol +details, allowing modules to communicate without knowing whether data travels +via shared memory, LCM multicast, or other mechanisms. + +Transport categories: + +- **Network-capable** (LCM variants): For distributed systems where modules + may run on different machines. Supports serialization and multicast. +- **Local-only** (SHM variants): For high-throughput communication between + processes on the same machine. Uses shared memory for near-zero-copy transfer. + +Selection guidance: + +- `LCMTransport`: Default for typed messages with `lcm_encode` support +- `pLCMTransport`: Python objects without LCM encoding, network-capable +- `JpegLcmTransport`: Images over network with compression +- `pSHMTransport`: High-throughput local data (images, point clouds) +- `SHMTransport`: Raw bytes over shared memory +- `JpegShmTransport`: Local images with reduced memory footprint + +Example: + Configure transports via `.transports()` on blueprint sets: + + from dimos.core.blueprints import autoconnect + from dimos.core.transport import LCMTransport, pSHMTransport + from dimos.msgs.sensor_msgs import Image + + blueprint = autoconnect( + connection(), + perception(), + ).transports({ + # Key: (stream_property_name, Type) + ("color_image", Image): LCMTransport("/robot/camera", Image), + }) + +All transports lazily initialize on first `broadcast()` or `subscribe()` +call. For the abstract interface, see `Transport` in `stream.py`. +""" + +from __future__ import annotations + +from typing import Annotated, Any, TypeVar + +from annotated_doc import Doc + +import dimos.core.colors as colors + +T = TypeVar("T") + +from typing import ( + TYPE_CHECKING, + TypeVar, +) + +from dimos.core.stream import In, Transport +from dimos.protocol.pubsub.jpeg_shm import JpegSharedMemory +from dimos.protocol.pubsub.lcmpubsub import LCM, JpegLCM, PickleLCM, Topic as LCMTopic +from dimos.protocol.pubsub.shmpubsub import PickleSharedMemory, SharedMemory + +if TYPE_CHECKING: + from collections.abc import Callable + +T = TypeVar("T") # type: ignore[misc] + + +class PubSubTransport(Transport[T]): + """Topic-based publish-subscribe transport. + + Extends `Transport` with a topic attribute for routing messages. Use this + when modules need to communicate over named channels rather than direct + point-to-point connections. + + The topic serves as a logical address that publishers broadcast to and + subscribers listen on, enabling many-to-many communication patterns. + """ + + topic: Any + + def __init__(self, topic: Any) -> None: + self.topic = topic + + def __str__(self) -> str: + return ( + colors.green(f"{self.__class__.__name__}(") + + colors.blue(self.topic) + + colors.green(")") + ) + + +class pLCMTransport(PubSubTransport[T]): + """LCM (Lightweight Communications and Marshalling) transport with pickle serialization for arbitrary Python objects. + + Uses UDP multicast via LCM for low-latency pub/sub messaging across processes + and machines. The "p" prefix indicates pickle serialization. + + Use when you need network messaging with Python objects that lack `lcm_encode` + support. For native LCM types, prefer `LCMTransport` (faster, cross-language). + + See also: + LCMTransport: Native LCM encoding (faster, cross-language). + pSHMTransport: Pickle over shared memory (single-machine only). + """ + + _started: bool = False + + def __init__(self, topic: str, **kwargs) -> None: # type: ignore[no-untyped-def] + super().__init__(topic) + self.lcm = PickleLCM(**kwargs) + + def __reduce__(self): # type: ignore[no-untyped-def] + return (pLCMTransport, (self.topic,)) + + def broadcast(self, _, msg) -> None: # type: ignore[no-untyped-def] + if not self._started: + self.lcm.start() + self._started = True + + self.lcm.publish(self.topic, msg) + + def subscribe(self, callback: Callable[[T], None], selfstream: In[T] = None) -> None: # type: ignore[assignment, override] + if not self._started: + self.lcm.start() + self._started = True + return self.lcm.subscribe(self.topic, lambda msg, topic: callback(msg)) # type: ignore[return-value] + + +class LCMTransport(PubSubTransport[T]): + """Publish-subscribe transport using LCM (Lightweight Communications and Marshalling) native encoding over UDP multicast. + + Default transport for typed messages (poses, images, point clouds, sensor readings) + that need to be shared across network boundaries. Uses LCM's native serialization + rather than pickle, enabling cross-language interoperability but requiring message + types that implement `lcm_encode`/`lcm_decode` methods. + + For pickle-based serialization (Python-only, any type), use `pLCMTransport`. + + More on LCM: + - It's a publish-subscribe messaging system that uses UDP multicast for its underlying transport. + - "provides a best-effort packet delivery mechanism and gives strong preference to the expedient delivery of recent messages" (LCM paper) + + Further reading + - [The LCM paper](https://people.csail.mit.edu/albert/pubs/2010-huang-olson-moore-lcm-iros.pdf) + """ + + _started: bool = False + + def __init__( + self, + topic: Annotated[str, Doc("Channel name for message routing.")], + type: Annotated[ + type, + Doc( + """LCM message type (must have `lcm_encode`/`lcm_decode` methods, + typically auto-generated from `.lcm` schema files).""" + ), + ], + **kwargs: Annotated[Any, Doc("Passed to the underlying LCM instance.")], + ) -> None: + super().__init__(LCMTopic(topic, type)) + if not hasattr(self, "lcm"): + self.lcm = LCM(**kwargs) + + def __reduce__(self): # type: ignore[no-untyped-def] + return (LCMTransport, (self.topic.topic, self.topic.lcm_type)) + + def broadcast(self, _, msg) -> None: # type: ignore[no-untyped-def] + if not self._started: + self.lcm.start() + self._started = True + + self.lcm.publish(self.topic, msg) + + def subscribe(self, callback: Callable[[T], None], selfstream: In[T] = None) -> None: # type: ignore[assignment, override] + if not self._started: + self.lcm.start() + self._started = True + return self.lcm.subscribe(self.topic, lambda msg, topic: callback(msg)) # type: ignore[return-value] + + +class JpegLcmTransport(LCMTransport): # type: ignore[type-arg] + """LCM transport with JPEG compression for image transmission over networks. + + Reduces bandwidth when transmitting images across network boundaries. + + Trade-offs: + - Lower bandwidth via compression + - Lossy compression (some quality loss) + - CPU overhead for encode/decode + + See also: + LCMTransport: Uncompressed LCM transport for general data. + JpegShmTransport: JPEG compression over shared memory (same-machine). + """ + + def __init__(self, topic: str, type: type, **kwargs) -> None: # type: ignore[no-untyped-def] + self.lcm = JpegLCM(**kwargs) # type: ignore[assignment] + super().__init__(topic, type) + + def __reduce__(self): # type: ignore[no-untyped-def] + return (JpegLcmTransport, (self.topic.topic, self.topic.lcm_type)) + + +class pSHMTransport(PubSubTransport[T]): + """Local-only transport using POSIX shared memory with pickle serialization. + + Provides high-throughput, low-latency communication between processes on the + same machine. Data is shared via memory-mapped regions rather than copied + over network sockets, making this ideal for large payloads like camera frames + or point clouds. + + Unlike network-capable transports (`pLCMTransport`, `LCMTransport`), this cannot + communicate across machines. Use `pLCMTransport` instead when network distribution is + needed. + """ + + _started: bool = False + + def __init__( + self, + topic: Annotated[str, Doc("Channel identifier for publish/subscribe routing.")], + **kwargs: Annotated[ + Any, + Doc( + """Passed to PickleSharedMemory. Key option: + default_capacity: Max payload size in bytes (default ~3.5MB). + This should be increased for very large data.""" + ), + ], + ) -> None: + super().__init__(topic) + self.shm = PickleSharedMemory(**kwargs) + + def __reduce__(self): # type: ignore[no-untyped-def] + return (pSHMTransport, (self.topic,)) + + def broadcast(self, _, msg) -> None: # type: ignore[no-untyped-def] + if not self._started: + self.shm.start() + self._started = True + + self.shm.publish(self.topic, msg) + + def subscribe(self, callback: Callable[[T], None], selfstream: In[T] = None) -> None: # type: ignore[assignment, override] + if not self._started: + self.shm.start() + self._started = True + return self.shm.subscribe(self.topic, lambda msg, topic: callback(msg)) # type: ignore[return-value] + + +class SHMTransport(PubSubTransport[T]): + """Shared memory transport for raw bytes data. + + Uses POSIX shared memory with minimal encoding overhead, providing high + throughput for local inter-process communication when data is already + in bytes format. Unlike `pSHMTransport`, which pickle-serializes Python + objects, this transport expects bytes-like data (bytes, bytearray, + or memoryview). + + Use this transport when: + - Processes run on the same machine + - Data is already bytes-like (e.g., sensor buffers, encoded frames) + - Maximum throughput is critical + + Use `pSHMTransport` instead when you need to send arbitrary Python objects. + """ + + _started: bool = False + + def __init__(self, topic: str, **kwargs) -> None: # type: ignore[no-untyped-def] + super().__init__(topic) + self.shm = SharedMemory(**kwargs) + + def __reduce__(self): # type: ignore[no-untyped-def] + return (SHMTransport, (self.topic,)) + + def broadcast(self, _, msg) -> None: # type: ignore[no-untyped-def] + if not self._started: + self.shm.start() + self._started = True + + self.shm.publish(self.topic, msg) + + def subscribe(self, callback: Callable[[T], None], selfstream: In[T] = None) -> None: # type: ignore[assignment, override] + if not self._started: + self.shm.start() + self._started = True + return self.shm.subscribe(self.topic, lambda msg, topic: callback(msg)) # type: ignore[arg-type, return-value] + + +class JpegShmTransport(PubSubTransport[T]): + """Shared memory transport with JPEG compression for Image objects. + + Uses shared memory for fast local inter-process communication while applying + JPEG compression to reduce memory footprint. Only works for local consumers + (same machine); not suitable for network transport. + + Trade-offs: + - Lower memory usage than uncompressed shared memory + - Adds CPU overhead for encode/decode + - Lossy compression (quality parameter controls fidelity vs size) + """ + + _started: bool = False + + def __init__( + self, + topic: Annotated[str, Doc("Channel identifier for pub/sub routing.")], + quality: Annotated[ + int, + Doc( + """JPEG compression quality (1-100). Lower values produce smaller + images with more artifacts.""" + ), + ] = 75, + **kwargs: Annotated[ + Any, Doc("Additional arguments passed to the underlying shared memory.") + ], + ) -> None: + super().__init__(topic) + self.shm = JpegSharedMemory(quality=quality, **kwargs) + self.quality = quality + + def __reduce__(self): # type: ignore[no-untyped-def] + return (JpegShmTransport, (self.topic, self.quality)) + + def broadcast(self, _, msg) -> None: # type: ignore[no-untyped-def] + if not self._started: + self.shm.start() + self._started = True + + self.shm.publish(self.topic, msg) + + def subscribe(self, callback: Callable[[T], None], selfstream: In[T] = None) -> None: # type: ignore[assignment, override] + if not self._started: + self.shm.start() + self._started = True + return self.shm.subscribe(self.topic, lambda msg, topic: callback(msg)) # type: ignore[arg-type, return-value] + + +class ZenohTransport(PubSubTransport[T]): ... diff --git a/dimos/data/data_pipeline.py b/dimos/data/data_pipeline.py deleted file mode 100644 index 5fe9c85631..0000000000 --- a/dimos/data/data_pipeline.py +++ /dev/null @@ -1,124 +0,0 @@ -from .depth import DepthProcessor -from .labels import LabelProcessor -from .pointcloud import PointCloudProcessor -from .segment import SegmentProcessor -from dimos.stream.videostream import VideoStream # Lukas to implement -import warnings -from concurrent.futures import ProcessPoolExecutor, as_completed -from collections import deque - -class DataPipeline: - def __init__(self, video_stream: VideoStream, - run_depth: bool = False, - run_labels: bool = False, - run_pointclouds: bool = False, - run_segmentations: bool = False, - max_workers: int = 4): - """ - Initialize the DataPipeline with specified pipeline layers. - - Args: - video_stream (VideoStream): The video stream to process. - run_depth (bool): Whether to run the depth map generation. - run_labels (bool): Whether to run the label/caption generation. - run_pointclouds (bool): Whether to run the point cloud generation. - run_segmentations (bool): Whether to run the segmentation generation. - max_workers (int): Maximum number of worker processes for parallel processing. - - Raises: - ValueError: If invalid pipeline configurations are provided. - """ - self.video_stream = video_stream - self.depth_processor = DepthProcessor(debug=True) if run_depth else None - self.labels_processor = LabelProcessor(debug=True) if run_labels else None - self.pointcloud_processor = PointCloudProcessor(debug=True) if run_pointclouds else None - self.segmentation_processor = SegmentationProcessor(debug=True) if run_segmentations else None - self.run_depth = run_depth - self.run_labels = run_labels - self.run_pointclouds = run_pointclouds - self.run_segmentations = run_segmentations - - self.max_workers = max_workers - - # Validate pipeline configuration - self._validate_pipeline() - - # Initialize the pipeline - self._initialize_pipeline() - - # Storage for processed data - self.generated_depth_maps = deque() - self.generated_labels = deque() - self.generated_pointclouds = deque() - self.generated_segmentations = deque() - - def _validate_pipeline(self): - """Validate the pipeline configuration based on dependencies.""" - if self.run_pointclouds and not self.run_depth: - raise ValueError("PointClouds generation requires Depth maps. " - "Enable run_depth=True to use run_pointclouds=True.") - - if self.run_segmentations and not self.run_labels: - raise ValueError("Segmentations generation requires Labels. " - "Enable run_labels=True to use run_segmentations=True.") - - if not any([self.run_depth, self.run_labels, self.run_pointclouds, self.run_segmentations]): - warnings.warn("No pipeline layers selected to run. The DataPipeline will be initialized without any processing.") - - def _initialize_pipeline(self): - """Initialize necessary components based on selected pipeline layers.""" - if self.run_depth: - print("Depth map generation enabled.") - - if self.run_labels: - print("Label generation enabled.") - - if self.run_pointclouds: - print("PointCloud generation enabled.") - - if self.run_segmentations: - print("Segmentation generation enabled.") - - def run(self): - """Execute the selected pipeline layers in parallel.""" - with ProcessPoolExecutor(max_workers=self.max_workers) as executor: - future_to_frame = {} - for frame in self.video_stream: - # Submit frame processing to the executor - future = executor.submit(self._process_frame, frame) - future_to_frame[future] = frame - - # Collect results as they become available - for future in as_completed(future_to_frame): - result = future.result() - depth_map, label, pointcloud, segmentation = result - - if depth_map is not None: - self.generated_depth_maps.append(depth_map) - if label is not None: - self.generated_labels.append(label) - if pointcloud is not None: - self.generated_pointclouds.append(pointcloud) - if segmentation is not None: - self.generated_segmentations.append(segmentation) - - def _process_frame(self, frame): - """Process a single frame and return results.""" - depth_map = None - label = None - pointcloud = None - segmentation = None - - if self.run_depth: - depth_map = self.depth_processor.process(frame) - - if self.run_labels: - label = self.labels_processor.caption_image_data(frame) - - if self.run_pointclouds and depth_map is not None: - pointcloud = self.pointcloud_processor.process_frame(frame, depth_map) - - if self.run_segmentations and label is not None: - segmentation = self.segmentation_processor.process_frame(frame, label) - - return depth_map, label, pointcloud, segmentation diff --git a/dimos/data/depth.py b/dimos/data/depth.py deleted file mode 100644 index b671924561..0000000000 --- a/dimos/data/depth.py +++ /dev/null @@ -1,85 +0,0 @@ -from dimos.models.depth.metric3d import Metric3D -import os -import pickle -import argparse -import pandas as pd -from PIL import Image -from io import BytesIO -import torch -import sys -import cv2 -import tarfile -import logging -import time -import tempfile -import gc -import io -import csv -import numpy as np - -class DepthProcessor: - def __init__(self, debug=False): - self.debug = debug - self.metric_3d = Metric3D() - self.depth_count = 0 - self.valid_depth_count = 0 - self.logger = logging.getLogger(__name__) - self.intrinsic = [707.0493, 707.0493, 604.0814, 180.5066] # Default intrinsic - - print("DepthProcessor initialized") - - if debug: - print("Running in debug mode") - self.logger.info("Running in debug mode") - - - def process(self, frame: Image.Image, intrinsics=None): - """Process a frame to generate a depth map. - - Args: - frame: PIL Image to process - intrinsics: Optional camera intrinsics parameters - - Returns: - PIL Image containing the depth map - """ - if intrinsics: - self.metric_3d.update_intrinsic(intrinsics) - else: - self.metric_3d.update_intrinsic(self.intrinsic) - - # Convert frame to numpy array suitable for processing - if isinstance(frame, Image.Image): - image = frame.convert('RGB') - elif isinstance(frame, np.ndarray): - image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) - else: - raise ValueError("Unsupported frame format. Must be PIL Image or numpy array.") - - image_np = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) - image_np = resize_image_for_vit(image_np) - - # Process image and run depth via Metric3D - try: - with torch.no_grad(): - depth_map = self.metric_3d.infer_depth(image_np) - - self.depth_count += 1 - - # Validate depth map - if is_depth_map_valid(np.array(depth_map)): - self.valid_depth_count += 1 - else: - self.logger.error(f"Invalid depth map for the provided frame.") - print("Invalid depth map for the provided frame.") - return None - - if self.debug: - # Save depth map locally or to S3 as needed - pass # Implement saving logic if required - - return depth_map - - except Exception as e: - self.logger.error(f"Error processing frame: {e}") - return None \ No newline at end of file diff --git a/dimos/data/labels.py b/dimos/data/labels.py deleted file mode 100644 index 1b422e3f99..0000000000 --- a/dimos/data/labels.py +++ /dev/null @@ -1,17 +0,0 @@ -from dimos.models.labels.llava-34b import Llava -from PIL import Image - -class LabelProcessor: - def __init__(self, debug: bool = False): - self.model = Llava(mmproj="/app/models/mmproj-model-f16.gguf", model_path="/app/models/llava-v1.6-34b.Q4_K_M.gguf", gpu=True) - self.prompt = 'Create a JSON representation where each entry consists of a key "object" with a numerical suffix starting from 1, and a corresponding "description" key with a value that is a concise, up to six-word sentence describing each main, distinct object or person in the image. Each pair should uniquely describe one element without repeating keys. An example: {"object1": { "description": "Man in red hat walking." },"object2": { "description": "Wooden pallet with boxes." },"object3": { "description": "Cardboard boxes stacked." },"object4": { "description": "Man in green vest standing." }}' - self.debug = debug - def caption_image_data(self, frame: Image.Image): - try: - output = self.model.run_inference(frame, self.prompt, return_json=True) - if self.debug: - print("output", output) - return output - except Exception as e: - logger.error(f"Error in captioning image: {e}") - return [] \ No newline at end of file diff --git a/dimos/data/pointcloud.py b/dimos/data/pointcloud.py deleted file mode 100644 index 61713cd587..0000000000 --- a/dimos/data/pointcloud.py +++ /dev/null @@ -1,113 +0,0 @@ -import os -import cv2 -import numpy as np -import open3d as o3d -from pathlib import Path -from PIL import Image -import logging - -from dimos.models.segmentation.segment_utils import apply_mask_to_image -from dimos.models.pointcloud.pointcloud_utils import ( - create_point_cloud_from_rgbd, - canonicalize_point_cloud -) - -# Setup logging -logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) - - -class PointCloudProcessor: - def __init__(self, output_dir, intrinsic_parameters=None): - """ - Initializes the PointCloudProcessor. - - Args: - output_dir (str): The directory where point clouds will be saved. - intrinsic_parameters (dict, optional): Camera intrinsic parameters. - Defaults to None, in which case default parameters are used. - """ - self.output_dir = output_dir - os.makedirs(self.output_dir, exist_ok=True) - self.logger = logger - - # Default intrinsic parameters - self.default_intrinsic_parameters = { - 'width': 640, - 'height': 480, - 'fx': 960.0, - 'fy': 960.0, - 'cx': 320.0, - 'cy': 240.0, - } - self.intrinsic_parameters = intrinsic_parameters if intrinsic_parameters else self.default_intrinsic_parameters - - def process_frame(self, image, depth_map, masks): - """ - Process a single frame to generate point clouds. - - Args: - image (PIL.Image.Image or np.ndarray): The RGB image. - depth_map (PIL.Image.Image or np.ndarray): The depth map corresponding to the image. - masks (list of np.ndarray): A list of binary masks for segmentation. - - Returns: - list of o3d.geometry.PointCloud: A list of point clouds for each mask. - bool: A flag indicating if the point clouds were canonicalized. - """ - try: - self.logger.info("STARTING POINT CLOUD PROCESSING ---------------------------------------") - - # Convert images to OpenCV format if they are PIL Images - if isinstance(image, Image.Image): - original_image_cv = cv2.cvtColor(np.array(image.convert('RGB')), cv2.COLOR_RGB2BGR) - else: - original_image_cv = image - - if isinstance(depth_map, Image.Image): - depth_image_cv = cv2.cvtColor(np.array(depth_map.convert('RGB')), cv2.COLOR_RGB2BGR) - else: - depth_image_cv = depth_map - - width, height = original_image_cv.shape[1], original_image_cv.shape[0] - intrinsic_parameters = self.intrinsic_parameters.copy() - intrinsic_parameters.update({ - 'width': width, - 'height': height, - 'cx': width / 2, - 'cy': height / 2, - }) - - point_clouds = [] - point_cloud_data = [] - - # Create original point cloud - original_pcd = create_point_cloud_from_rgbd(original_image_cv, depth_image_cv, intrinsic_parameters) - pcd, canonicalized, transformation = canonicalize_point_cloud(original_pcd, canonicalize_threshold=0.3) - - for idx, mask in enumerate(masks): - mask_binary = mask > 0 - - masked_rgb = apply_mask_to_image(original_image_cv, mask_binary) - masked_depth = apply_mask_to_image(depth_image_cv, mask_binary) - - pcd = create_point_cloud_from_rgbd(masked_rgb, masked_depth, intrinsic_parameters) - # Remove outliers - cl, ind = pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=2.0) - inlier_cloud = pcd.select_by_index(ind) - if canonicalized: - inlier_cloud.transform(transformation) - - point_clouds.append(inlier_cloud) - # Save point cloud to file - pointcloud_filename = f"pointcloud_{idx}.pcd" - pointcloud_filepath = os.path.join(self.output_dir, pointcloud_filename) - o3d.io.write_point_cloud(pointcloud_filepath, inlier_cloud) - point_cloud_data.append(pointcloud_filepath) - self.logger.info(f"Saved point cloud {pointcloud_filepath}") - - self.logger.info("DONE POINT CLOUD PROCESSING ---------------------------------------") - return point_clouds, canonicalized - except Exception as e: - self.logger.error(f"Error processing frame: {e}") - return [], False diff --git a/dimos/data/segment.py b/dimos/data/segment.py deleted file mode 100644 index 1e98ebe4b9..0000000000 --- a/dimos/data/segment.py +++ /dev/null @@ -1,72 +0,0 @@ -import cv2 -import numpy as np -from PIL import Image -import logging -from dimos.models.segmentation.segment_utils import sample_points_from_heatmap -from dimos.models.segmentation.sam import SAM -from dimos.models.segmentation.clipseg import CLIPSeg - -# Setup logging -logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) - - -class SegmentProcessor: - def __init__(self, device='cuda'): - # Initialize CLIPSeg and SAM models - self.clipseg = CLIPSeg(model_name="CIDAS/clipseg-rd64-refined", device=device) - self.sam = SAM(model_name="facebook/sam-vit-huge", device=device) - self.logger = logger - - def process_frame(self, image, captions): - """ - Process a single image and return segmentation masks. - - Args: - image (PIL.Image.Image or np.ndarray): The input image to process. - captions (list of str): A list of captions for segmentation. - - Returns: - list of np.ndarray: A list of segmentation masks corresponding to the captions. - """ - try: - self.logger.info("STARTING PROCESSING IMAGE ---------------------------------------") - self.logger.info(f"Processing image with captions: {captions}") - - # Convert image to PIL.Image if it's a numpy array - if isinstance(image, np.ndarray): - image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) - - preds = self.clipseg.run_inference(image, captions) - sampled_points = [] - sam_masks = [] - - original_size = image.size # (width, height) - - for idx in range(preds.shape[0]): - points = sample_points_from_heatmap(preds[idx][0], original_size, num_points=10) - if points: - sampled_points.append(points) - else: - self.logger.info(f"No points sampled for prediction index {idx}") - sampled_points.append([]) - - for idx in range(preds.shape[0]): - if sampled_points[idx]: - mask_tensor = self.sam.run_inference_from_points(image, [sampled_points[idx]]) - if mask_tensor: - # Convert mask tensor to a numpy array - mask = (255 * mask_tensor[0].numpy().squeeze()).astype(np.uint8) - sam_masks.append(mask) - else: - self.logger.info(f"No mask tensor returned for sampled points at index {idx}") - sam_masks.append(np.zeros((original_size[1], original_size[0]), dtype=np.uint8)) - else: - self.logger.info(f"No sampled points for prediction index {idx}, skipping mask inference") - sam_masks.append(np.zeros((original_size[1], original_size[0]), dtype=np.uint8)) - - self.logger.info("DONE PROCESSING IMAGE ---------------------------------------") - return sam_masks - except Exception as e: - self.logger.error(f"Error processing image: {e}") - return [] \ No newline at end of file diff --git a/dimos/data/videostream-data-pipeline.md b/dimos/data/videostream-data-pipeline.md deleted file mode 100644 index 5f44d8b143..0000000000 --- a/dimos/data/videostream-data-pipeline.md +++ /dev/null @@ -1,29 +0,0 @@ -# Example data pipeline from video stream implementation - -```bash - from dimos.stream.videostream import VideoStream - from dimos.data.data_pipeline import DataPipeline - - # init video stream from the camera source - video_stream = VideoStream(source=0) - - # init data pipeline with desired processors enabled, max workers is 4 by default - # depth only implementation - pipeline = DataPipeline( - video_stream=video_stream, - run_depth=True, - run_labels=False, - run_pointclouds=False, - run_segmentations=False - ) - - try: - # Run pipeline - pipeline.run() - except KeyboardInterrupt: - # Handle interrupt - print("Pipeline interrupted by user.") - finally: - # Release the video capture - video_stream.release() -``` diff --git a/dimos/environment/agent_environment.py b/dimos/environment/agent_environment.py deleted file mode 100644 index 312bc9cecd..0000000000 --- a/dimos/environment/agent_environment.py +++ /dev/null @@ -1,121 +0,0 @@ -import cv2 -import numpy as np -from pathlib import Path -from typing import List, Union -from .environment import Environment - -class AgentEnvironment(Environment): - def __init__(self): - super().__init__() - self.environment_type = "agent" - self.frames = [] - self.current_frame_idx = 0 - self._depth_maps = [] - self._segmentations = [] - self._point_clouds = [] - - def initialize_from_images(self, images: Union[List[str], List[np.ndarray]]) -> bool: - """Initialize environment from a list of image paths or numpy arrays. - - Args: - images: List of image paths or numpy arrays representing frames - - Returns: - bool: True if initialization successful, False otherwise - """ - try: - self.frames = [] - for img in images: - if isinstance(img, str): - frame = cv2.imread(img) - if frame is None: - raise ValueError(f"Failed to load image: {img}") - self.frames.append(frame) - elif isinstance(img, np.ndarray): - self.frames.append(img.copy()) - else: - raise ValueError(f"Unsupported image type: {type(img)}") - return True - except Exception as e: - print(f"Failed to initialize from images: {e}") - return False - - def initialize_from_file(self, file_path: str) -> bool: - """Initialize environment from a video file. - - Args: - file_path: Path to the video file - - Returns: - bool: True if initialization successful, False otherwise - """ - try: - if not Path(file_path).exists(): - raise FileNotFoundError(f"Video file not found: {file_path}") - - cap = cv2.VideoCapture(file_path) - self.frames = [] - - while cap.isOpened(): - ret, frame = cap.read() - if not ret: - break - self.frames.append(frame) - - cap.release() - return len(self.frames) > 0 - except Exception as e: - print(f"Failed to initialize from video: {e}") - return False - - def initialize_from_directory(self, directory_path: str) -> bool: - """Initialize environment from a directory of images.""" - # TODO: Implement directory initialization - raise NotImplementedError("Directory initialization not yet implemented") - - def label_objects(self) -> List[str]: - """Implementation of abstract method to label objects.""" - # TODO: Implement object labeling using a detection model - raise NotImplementedError("Object labeling not yet implemented") - - - def generate_segmentations(self, model: str = None, objects: List[str] = None, *args, **kwargs) -> List[np.ndarray]: - """Generate segmentations for the current frame.""" - # TODO: Implement segmentation generation using specified model - raise NotImplementedError("Segmentation generation not yet implemented") - - def get_segmentations(self) -> List[np.ndarray]: - """Return pre-computed segmentations for the current frame.""" - if self._segmentations: - return self._segmentations[self.current_frame_idx] - return [] - - def generate_point_cloud(self, object: str = None, *args, **kwargs) -> np.ndarray: - """Generate point cloud from the current frame.""" - # TODO: Implement point cloud generation - raise NotImplementedError("Point cloud generation not yet implemented") - - def get_point_cloud(self, object: str = None) -> np.ndarray: - """Return pre-computed point cloud.""" - if self._point_clouds: - return self._point_clouds[self.current_frame_idx] - return np.array([]) - - def generate_depth_map(self, stereo: bool = None, monocular: bool = None, model: str = None, *args, **kwargs) -> np.ndarray: - """Generate depth map for the current frame.""" - # TODO: Implement depth map generation using specified method - raise NotImplementedError("Depth map generation not yet implemented") - - def get_depth_map(self) -> np.ndarray: - """Return pre-computed depth map for the current frame.""" - if self._depth_maps: - return self._depth_maps[self.current_frame_idx] - return np.array([]) - - def get_frame_count(self) -> int: - """Return the total number of frames.""" - return len(self.frames) - - def get_current_frame_index(self) -> int: - """Return the current frame index.""" - return self.current_frame_idx diff --git a/dimos/environment/colmap_environment.py b/dimos/environment/colmap_environment.py deleted file mode 100644 index 4f74f65101..0000000000 --- a/dimos/environment/colmap_environment.py +++ /dev/null @@ -1,72 +0,0 @@ -import cv2 -import pycolmap -from pathlib import Path -from dimos.environment.environment import Environment - -class COLMAPEnvironment(Environment): - def initialize_from_images(self, image_dir): - """Initialize the environment from a set of image frames or video.""" - image_dir = Path(image_dir) - output_path = Path("colmap_output") - output_path.mkdir(exist_ok=True) - mvs_path = output_path / "mvs" - database_path = output_path / "database.db" - - # Step 1: Feature extraction - pycolmap.extract_features(database_path, image_dir) - - # Step 2: Feature matching - pycolmap.match_exhaustive(database_path) - - # Step 3: Sparse reconstruction - maps = pycolmap.incremental_mapping(database_path, image_dir, output_path) - maps[0].write(output_path) - - # Step 4: Dense reconstruction (optional) - pycolmap.undistort_images(mvs_path, output_path, image_dir) - pycolmap.patch_match_stereo(mvs_path) # Requires compilation with CUDA - pycolmap.stereo_fusion(mvs_path / "dense.ply", mvs_path) - - return maps - - def initialize_from_video(self, video_path, frame_output_dir): - """Extract frames from a video and initialize the environment.""" - video_path = Path(video_path) - frame_output_dir = Path(frame_output_dir) - frame_output_dir.mkdir(exist_ok=True) - - # Extract frames from the video - self._extract_frames_from_video(video_path, frame_output_dir) - - # Initialize from the extracted frames - return self.initialize_from_images(frame_output_dir) - - def _extract_frames_from_video(self, video_path, frame_output_dir): - """Extract frames from a video and save them to a directory.""" - cap = cv2.VideoCapture(str(video_path)) - frame_count = 0 - - while cap.isOpened(): - ret, frame = cap.read() - if not ret: - break - frame_filename = frame_output_dir / f"frame_{frame_count:04d}.jpg" - cv2.imwrite(str(frame_filename), frame) - frame_count += 1 - - cap.release() - - def label_objects(self): - pass - - def get_visualization(self, format_type): - pass - - def get_segmentations(self): - pass - - def get_point_cloud(self, object_id=None): - pass - - def get_depth_map(self): - pass diff --git a/dimos/environment/environment.py b/dimos/environment/environment.py index dc02febfc3..ba1923b765 100644 --- a/dimos/environment/environment.py +++ b/dimos/environment/environment.py @@ -1,8 +1,24 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from abc import ABC, abstractmethod + import numpy as np + class Environment(ABC): - def __init__(self): + def __init__(self) -> None: self.environment_type = None self.graph = None @@ -10,19 +26,21 @@ def __init__(self): def label_objects(self) -> list[str]: """ Label all objects in the environment. - + Returns: A list of string labels representing the objects in the environment. """ pass @abstractmethod - def get_visualization(self, format_type): + def get_visualization(self, format_type): # type: ignore[no-untyped-def] """Return different visualization formats like images, NERFs, or other 3D file types.""" pass - + @abstractmethod - def generate_segmentations(self, model: str = None, objects: list[str] = None, *args, **kwargs) -> list[np.ndarray]: + def generate_segmentations( # type: ignore[no-untyped-def] + self, model: str | None = None, objects: list[str] | None = None, *args, **kwargs + ) -> list[np.ndarray]: # type: ignore[type-arg] """ Generate object segmentations of objects[] using neural methods. @@ -42,7 +60,7 @@ def generate_segmentations(self, model: str = None, objects: list[str] = None, * pass @abstractmethod - def get_segmentations(self) -> list[np.ndarray]: + def get_segmentations(self) -> list[np.ndarray]: # type: ignore[type-arg] """ Get segmentations using a method like 'segment anything'. @@ -52,9 +70,8 @@ def get_segmentations(self) -> list[np.ndarray]: """ pass - @abstractmethod - def generate_point_cloud(self, object: str = None, *args, **kwargs) -> np.ndarray: + def generate_point_cloud(self, object: str | None = None, *args, **kwargs) -> np.ndarray: # type: ignore[no-untyped-def, type-arg] """ Generate a point cloud for the entire environment or a specific object. @@ -74,7 +91,7 @@ def generate_point_cloud(self, object: str = None, *args, **kwargs) -> np.ndarra pass @abstractmethod - def get_point_cloud(self, object: str = None) -> np.ndarray: + def get_point_cloud(self, object: str | None = None) -> np.ndarray: # type: ignore[type-arg] """ Return point clouds of the entire environment or a specific object. @@ -88,7 +105,14 @@ def get_point_cloud(self, object: str = None) -> np.ndarray: pass @abstractmethod - def generate_depth_map(self, stereo: bool = None, monocular: bool = None, model: str = None, *args, **kwargs) -> np.ndarray: + def generate_depth_map( # type: ignore[no-untyped-def] + self, + stereo: bool | None = None, + monocular: bool | None = None, + model: str | None = None, + *args, + **kwargs, + ) -> np.ndarray: # type: ignore[type-arg] """ Generate a depth map using monocular or stereo camera methods. @@ -110,7 +134,7 @@ def generate_depth_map(self, stereo: bool = None, monocular: bool = None, model: pass @abstractmethod - def get_depth_map(self) -> np.ndarray: + def get_depth_map(self) -> np.ndarray: # type: ignore[type-arg] """ Return a depth map of the environment. @@ -126,11 +150,11 @@ def get_depth_map(self) -> np.ndarray: """ pass - def initialize_from_images(self, images): + def initialize_from_images(self, images): # type: ignore[no-untyped-def] """Initialize the environment from a set of image frames or video.""" raise NotImplementedError("This method is not implemented for this environment type.") - def initialize_from_file(self, file_path): + def initialize_from_file(self, file_path): # type: ignore[no-untyped-def] """Initialize the environment from a spatial file type. Supported file types include: @@ -152,5 +176,3 @@ def initialize_from_file(self, file_path): NotImplementedError: If the method is not implemented for this environment type. """ raise NotImplementedError("This method is not implemented for this environment type.") - - diff --git a/dimos/environment/manipulation_environment.py b/dimos/environment/manipulation_environment.py deleted file mode 100644 index 48d1417a24..0000000000 --- a/dimos/environment/manipulation_environment.py +++ /dev/null @@ -1,5 +0,0 @@ -from dimos.environment.environment import Environment - -class ManipulationEnvironment(Environment): - # Implement specific methods as needed - pass diff --git a/dimos/environment/simulation_environment.py b/dimos/environment/simulation_environment.py deleted file mode 100644 index 7216ea4135..0000000000 --- a/dimos/environment/simulation_environment.py +++ /dev/null @@ -1,7 +0,0 @@ -from dimos.environment.environment import Environment - -class SimulationEnvironment(Environment): - def initialize_from_file(self, file_path): - """Initialize the environment from a spatial file type like GLTF.""" - # Implementation for initializing from a file - pass diff --git a/dimos/exceptions/agent_memory_exceptions.py b/dimos/exceptions/agent_memory_exceptions.py index 82a2a15207..eec80be83c 100644 --- a/dimos/exceptions/agent_memory_exceptions.py +++ b/dimos/exceptions/agent_memory_exceptions.py @@ -1,65 +1,93 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import traceback + class AgentMemoryError(Exception): """ Base class for all exceptions raised by AgentMemory operations. All custom exceptions related to AgentMemory should inherit from this class. - + Args: message (str): Human-readable message describing the error. """ - def __init__(self, message="Error in AgentMemory operation"): + + def __init__(self, message: str = "Error in AgentMemory operation") -> None: super().__init__(message) + class AgentMemoryConnectionError(AgentMemoryError): """ Exception raised for errors attempting to connect to the database. This includes failures due to network issues, authentication errors, or incorrect connection parameters. - + Args: message (str): Human-readable message describing the error. cause (Exception, optional): Original exception, if any, that led to this error. """ - def __init__(self, message="Failed to connect to the database", cause=None): + + def __init__(self, message: str = "Failed to connect to the database", cause=None) -> None: # type: ignore[no-untyped-def] super().__init__(message) if cause: self.cause = cause self.traceback = traceback.format_exc() if cause else None - def __str__(self): - return f"{self.message}\nCaused by: {repr(self.cause)}" if self.cause else self.message + def __str__(self) -> str: + return f"{self.message}\nCaused by: {self.cause!r}" if self.cause else self.message # type: ignore[attr-defined] + class UnknownConnectionTypeError(AgentMemoryConnectionError): """ Exception raised when an unknown or unsupported connection type is specified during AgentMemory setup. - + Args: message (str): Human-readable message explaining that an unknown connection type was used. """ - def __init__(self, message="Unknown connection type used in AgentMemory connection"): + + def __init__( + self, message: str = "Unknown connection type used in AgentMemory connection" + ) -> None: super().__init__(message) + class DataRetrievalError(AgentMemoryError): """ Exception raised for errors retrieving data from the database. This could occur due to query failures, timeouts, or corrupt data issues. - + Args: message (str): Human-readable message describing the data retrieval error. """ - def __init__(self, message="Error in retrieving data during AgentMemory operation"): + + def __init__( + self, message: str = "Error in retrieving data during AgentMemory operation" + ) -> None: super().__init__(message) + class DataNotFoundError(DataRetrievalError): """ Exception raised when the requested data is not found in the database. This is used when a query completes successfully but returns no result for the specified identifier. - + Args: vector_id (int or str): The identifier for the vector that was not found. message (str, optional): Human-readable message providing more detail. If not provided, a default message is generated. """ - def __init__(self, vector_id, message=None): + + def __init__(self, vector_id, message=None) -> None: # type: ignore[no-untyped-def] message = message or f"Requested data for vector ID {vector_id} was not found." super().__init__(message) self.vector_id = vector_id diff --git a/dimos/external/colmap b/dimos/external/colmap deleted file mode 160000 index 189478b69b..0000000000 --- a/dimos/external/colmap +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 189478b69bf09b80b6143c491f5b29023ef73e7a diff --git a/dimos/hardware/README.md b/dimos/hardware/README.md new file mode 100644 index 0000000000..2587e3595d --- /dev/null +++ b/dimos/hardware/README.md @@ -0,0 +1,29 @@ +# Hardware + +## Remote camera stream with timestamps + +### Required Ubuntu packages: + +```bash +sudo apt install gstreamer1.0-tools gstreamer1.0-plugins-base gstreamer1.0-plugins-good gstreamer1.0-plugins-bad gstreamer1.0-plugins-ugly gstreamer1.0-libav python3-gi python3-gi-cairo gir1.2-gstreamer-1.0 gir1.2-gst-plugins-base-1.0 v4l-utils gstreamer1.0-vaapi +``` + +### Usage + +On sender machine (with the camera): + +```bash +python3 dimos/hardware/gstreamer_sender.py --device /dev/video0 --host 0.0.0.0 --port 5000 +``` + +If it's a stereo camera and you only want to send the left side (the left camera): + +```bash +python3 dimos/hardware/gstreamer_sender.py --device /dev/video0 --host 0.0.0.0 --port 5000 --single-camera +``` + +On receiver machine: + +```bash +python3 dimos/hardware/gstreamer_camera_test_script.py --host 10.0.0.227 --port 5000 +``` diff --git a/examples/web/__init__.py b/dimos/hardware/__init__.py similarity index 100% rename from examples/web/__init__.py rename to dimos/hardware/__init__.py diff --git a/dimos/hardware/camera.py b/dimos/hardware/camera.py deleted file mode 100644 index aba6cf0274..0000000000 --- a/dimos/hardware/camera.py +++ /dev/null @@ -1,37 +0,0 @@ -from dimos.hardware.sensor import AbstractSensor - -class Camera(AbstractSensor): - def __init__(self, resolution=None, focal_length=None, sensor_size=None, sensor_type='Camera'): - super().__init__(sensor_type) - self.resolution = resolution # (width, height) in pixels - self.focal_length = focal_length # in millimeters - self.sensor_size = sensor_size # (width, height) in millimeters - - def get_sensor_type(self): - return self.sensor_type - - def calculate_intrinsics(self): - if not self.resolution or not self.focal_length or not self.sensor_size: - raise ValueError("Resolution, focal length, and sensor size must be provided") - - # Calculate pixel size - pixel_size_x = self.sensor_size[0] / self.resolution[0] - pixel_size_y = self.sensor_size[1] / self.resolution[1] - - # Calculate the principal point (assuming it's at the center of the image) - principal_point_x = self.resolution[0] / 2 - principal_point_y = self.resolution[1] / 2 - - # Calculate the focal length in pixels - focal_length_x = self.focal_length / pixel_size_x - focal_length_y = self.focal_length / pixel_size_y - - return { - 'focal_length_x': focal_length_x, - 'focal_length_y': focal_length_y, - 'principal_point_x': principal_point_x, - 'principal_point_y': principal_point_y - } - - def get_intrinsics(self): - return self.calculate_intrinsics() diff --git a/dimos/hardware/camera/module.py b/dimos/hardware/camera/module.py new file mode 100644 index 0000000000..d7977ec51d --- /dev/null +++ b/dimos/hardware/camera/module.py @@ -0,0 +1,120 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable +from dataclasses import dataclass, field +import queue +import time + +from dimos_lcm.sensor_msgs import CameraInfo # type: ignore[import-untyped] +import reactivex as rx +from reactivex import operators as ops +from reactivex.observable import Observable + +from dimos import spec +from dimos.agents2 import Output, Reducer, Stream, skill # type: ignore[attr-defined] +from dimos.core import Module, ModuleConfig, Out, rpc +from dimos.hardware.camera.spec import CameraHardware +from dimos.hardware.camera.webcam import Webcam +from dimos.msgs.geometry_msgs import Quaternion, Transform, Vector3 +from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image, sharpness_barrier + + +def default_transform(): # type: ignore[no-untyped-def] + return Transform( + translation=Vector3(0.0, 0.0, 0.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id="base_link", + child_frame_id="camera_link", + ) + + +@dataclass +class CameraModuleConfig(ModuleConfig): + frame_id: str = "camera_link" + transform: Transform | None = field(default_factory=default_transform) + hardware: Callable[[], CameraHardware] | CameraHardware = Webcam # type: ignore[type-arg] + frequency: float = 5.0 + + +class CameraModule(Module, spec.Camera): + color_image: Out[Image] = None # type: ignore[assignment] + camera_info: Out[CameraInfo] = None # type: ignore[assignment] + + hardware: Callable[[], CameraHardware] | CameraHardware = None # type: ignore[assignment, type-arg] + _skill_stream: Observable[Image] | None = None + + default_config = CameraModuleConfig + + def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + super().__init__(*args, **kwargs) + + @property + def hardware_camera_info(self) -> CameraInfo: + return self.hardware.camera_info # type: ignore[union-attr] + + @rpc + def start(self) -> str: # type: ignore[return] + if callable(self.config.hardware): # type: ignore[attr-defined] + self.hardware = self.config.hardware() # type: ignore[attr-defined] + else: + self.hardware = self.config.hardware # type: ignore[attr-defined] + + self._disposables.add(self.camera_info_stream().subscribe(self.publish_info)) + + stream = self.hardware.image_stream().pipe(sharpness_barrier(self.config.frequency)) # type: ignore[attr-defined, union-attr] + self._disposables.add(stream.subscribe(self.color_image.publish)) + + @rpc + def stop(self) -> None: + if self.hardware and hasattr(self.hardware, "stop"): + self.hardware.stop() + super().stop() + + @skill(stream=Stream.passive, output=Output.image, reducer=Reducer.latest) # type: ignore[arg-type] + def video_stream(self) -> Image: # type: ignore[misc] + """implicit video stream skill""" + _queue = queue.Queue(maxsize=1) # type: ignore[var-annotated] + self.hardware.image_stream().subscribe(_queue.put) # type: ignore[union-attr] + + yield from iter(_queue.get, None) + + def publish_info(self, camera_info: CameraInfo) -> None: + self.camera_info.publish(camera_info) + + if self.config.transform is None: # type: ignore[attr-defined] + return + + camera_link = self.config.transform # type: ignore[attr-defined] + camera_link.ts = camera_info.ts + camera_optical = Transform( + translation=Vector3(0.0, 0.0, 0.0), + rotation=Quaternion(-0.5, 0.5, -0.5, 0.5), + frame_id="camera_link", + child_frame_id="camera_optical", + ts=camera_link.ts, + ) + + self.tf.publish(camera_link, camera_optical) + + def camera_info_stream(self, frequency: float = 1.0) -> Observable[CameraInfo]: + def camera_info(_) -> CameraInfo: # type: ignore[no-untyped-def] + self.hardware.camera_info.ts = time.time() # type: ignore[union-attr] + return self.hardware.camera_info # type: ignore[union-attr] + + return rx.interval(1.0 / frequency).pipe(ops.map(camera_info)) + + +camera_module = CameraModule.blueprint diff --git a/dimos/hardware/camera/spec.py b/dimos/hardware/camera/spec.py new file mode 100644 index 0000000000..0cd97389da --- /dev/null +++ b/dimos/hardware/camera/spec.py @@ -0,0 +1,55 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod, abstractproperty +from typing import Generic, Protocol, TypeVar + +from dimos_lcm.sensor_msgs import CameraInfo # type: ignore[import-untyped] +from reactivex.observable import Observable + +from dimos.msgs.sensor_msgs import Image +from dimos.protocol.service import Configurable # type: ignore[attr-defined] + + +class CameraConfig(Protocol): + frame_id_prefix: str | None + + +CameraConfigT = TypeVar("CameraConfigT", bound=CameraConfig) + + +class CameraHardware(ABC, Configurable[CameraConfigT], Generic[CameraConfigT]): + @abstractmethod + def image_stream(self) -> Observable[Image]: + pass + + @abstractproperty + def camera_info(self) -> CameraInfo: + pass + + +# This is an example, feel free to change spec for stereo cameras +# e.g., separate camera_info or streams for left/right, etc. +class StereoCameraHardware(ABC, Configurable[CameraConfigT], Generic[CameraConfigT]): + @abstractmethod + def image_stream(self) -> Observable[Image]: + pass + + @abstractmethod + def depth_stream(self) -> Observable[Image]: + pass + + @abstractproperty + def camera_info(self) -> CameraInfo: + pass diff --git a/dimos/hardware/camera/test_webcam.py b/dimos/hardware/camera/test_webcam.py new file mode 100644 index 0000000000..8888ad7de9 --- /dev/null +++ b/dimos/hardware/camera/test_webcam.py @@ -0,0 +1,108 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import pytest + +from dimos import core +from dimos.hardware.camera import zed +from dimos.hardware.camera.module import CameraModule +from dimos.hardware.camera.webcam import Webcam +from dimos.msgs.geometry_msgs import Quaternion, Transform, Vector3 +from dimos.msgs.sensor_msgs import CameraInfo, Image + + +@pytest.mark.tool +def test_streaming_single() -> None: + dimos = core.start(1) + + camera = dimos.deploy( + CameraModule, + transform=Transform( + translation=Vector3(0.05, 0.0, 0.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id="sensor", + child_frame_id="camera_link", + ), + hardware=lambda: Webcam( + stereo_slice="left", + camera_index=0, + frequency=15, + camera_info=zed.CameraInfo.SingleWebcam, + ), + ) + + camera.image.transport = core.LCMTransport("/image1", Image) + camera.camera_info.transport = core.LCMTransport("/image1/camera_info", CameraInfo) + camera.start() + + try: + while True: + time.sleep(1) + except KeyboardInterrupt: + camera.stop() + dimos.stop() + + +@pytest.mark.tool +def test_streaming_double() -> None: + dimos = core.start(2) + + camera1 = dimos.deploy( + CameraModule, + transform=Transform( + translation=Vector3(0.05, 0.0, 0.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id="sensor", + child_frame_id="camera_link", + ), + hardware=lambda: Webcam( + stereo_slice="left", + camera_index=0, + frequency=15, + camera_info=zed.CameraInfo.SingleWebcam, + ), + ) + + camera2 = dimos.deploy( + CameraModule, + transform=Transform( + translation=Vector3(0.05, 0.0, 0.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id="sensor", + child_frame_id="camera_link", + ), + hardware=lambda: Webcam( + camera_index=4, + frequency=15, + stereo_slice="left", + camera_info=zed.CameraInfo.SingleWebcam, + ), + ) + + camera1.image.transport = core.LCMTransport("/image1", Image) + camera1.camera_info.transport = core.LCMTransport("/image1/camera_info", CameraInfo) + camera1.start() + camera2.image.transport = core.LCMTransport("/image2", Image) + camera2.camera_info.transport = core.LCMTransport("/image2/camera_info", CameraInfo) + camera2.start() + + try: + while True: + time.sleep(1) + except KeyboardInterrupt: + camera1.stop() + camera2.stop() + dimos.stop() diff --git a/dimos/hardware/camera/webcam.py b/dimos/hardware/camera/webcam.py new file mode 100644 index 0000000000..3f66869890 --- /dev/null +++ b/dimos/hardware/camera/webcam.py @@ -0,0 +1,170 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from functools import cache +import threading +import time +from typing import Literal + +import cv2 +from dimos_lcm.sensor_msgs import CameraInfo # type: ignore[import-untyped] +from reactivex import create +from reactivex.observable import Observable + +from dimos.hardware.camera.spec import CameraConfig, CameraHardware +from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import ImageFormat +from dimos.utils.reactive import backpressure + + +@dataclass +class WebcamConfig(CameraConfig): + camera_index: int = 0 # /dev/videoN + frame_width: int = 640 + frame_height: int = 480 + frequency: int = 15 + camera_info: CameraInfo = field(default_factory=CameraInfo) + frame_id_prefix: str | None = None + stereo_slice: Literal["left", "right"] | None = None # For stereo cameras + + +class Webcam(CameraHardware[WebcamConfig]): + default_config = WebcamConfig + + def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + super().__init__(*args, **kwargs) + self._capture = None + self._capture_thread = None + self._stop_event = threading.Event() + self._observer = None + + @cache + def image_stream(self) -> Observable[Image]: + """Create an observable that starts/stops camera on subscription""" + + def subscribe(observer, scheduler=None): # type: ignore[no-untyped-def] + # Store the observer so emit() can use it + self._observer = observer + + # Start the camera when someone subscribes + try: + self.start() # type: ignore[no-untyped-call] + except Exception as e: + observer.on_error(e) + return + + # Return a dispose function to stop camera when unsubscribed + def dispose() -> None: + self._observer = None + self.stop() + + return dispose + + return backpressure(create(subscribe)) + + def start(self): # type: ignore[no-untyped-def] + if self._capture_thread and self._capture_thread.is_alive(): + return + + # Open the video capture + self._capture = cv2.VideoCapture(self.config.camera_index) # type: ignore[assignment] + if not self._capture.isOpened(): # type: ignore[attr-defined] + raise RuntimeError(f"Failed to open camera {self.config.camera_index}") + + # Set camera properties + self._capture.set(cv2.CAP_PROP_FRAME_WIDTH, self.config.frame_width) # type: ignore[attr-defined] + self._capture.set(cv2.CAP_PROP_FRAME_HEIGHT, self.config.frame_height) # type: ignore[attr-defined] + + # Clear stop event and start the capture thread + self._stop_event.clear() + self._capture_thread = threading.Thread(target=self._capture_loop, daemon=True) # type: ignore[assignment] + self._capture_thread.start() # type: ignore[attr-defined] + + def stop(self) -> None: + """Stop capturing frames""" + # Signal thread to stop + self._stop_event.set() + + # Wait for thread to finish + if self._capture_thread and self._capture_thread.is_alive(): + self._capture_thread.join(timeout=(1.0 / self.config.frequency) + 0.1) + + # Release the capture + if self._capture: + self._capture.release() + self._capture = None + + def _frame(self, frame: str): # type: ignore[no-untyped-def] + if not self.config.frame_id_prefix: + return frame + else: + return f"{self.config.frame_id_prefix}/{frame}" + + def capture_frame(self) -> Image: + # Read frame + ret, frame = self._capture.read() # type: ignore[attr-defined] + if not ret: + raise RuntimeError(f"Failed to read frame from camera {self.config.camera_index}") + + # Convert BGR to RGB (OpenCV uses BGR by default) + frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + + # Create Image message + # Using Image.from_numpy() since it's designed for numpy arrays + # Setting format to RGB since we converted from BGR->RGB above + image = Image.from_numpy( + frame_rgb, + format=ImageFormat.RGB, # We converted to RGB above + frame_id=self._frame("camera_optical"), # Standard frame ID for camera images + ts=time.time(), # Current timestamp + ) + + if self.config.stereo_slice in ("left", "right"): + half_width = image.width // 2 + if self.config.stereo_slice == "left": + image = image.crop(0, 0, half_width, image.height) + else: + image = image.crop(half_width, 0, half_width, image.height) + + return image + + def _capture_loop(self) -> None: + """Capture frames at the configured frequency""" + frame_interval = 1.0 / self.config.frequency + next_frame_time = time.time() + + while self._capture and not self._stop_event.is_set(): + image = self.capture_frame() + + # Emit the image to the observer only if not stopping + if self._observer and not self._stop_event.is_set(): + self._observer.on_next(image) + + # Wait for next frame time or until stopped + next_frame_time += frame_interval + sleep_time = next_frame_time - time.time() + if sleep_time > 0: + # Use event.wait so we can be interrupted by stop + if self._stop_event.wait(timeout=sleep_time): + break # Stop was requested + else: + # We're running behind, reset timing + next_frame_time = time.time() + + @property + def camera_info(self) -> CameraInfo: + return self.config.camera_info + + def emit(self, image: Image) -> None: ... diff --git a/dimos/hardware/camera/zed/__init__.py b/dimos/hardware/camera/zed/__init__.py new file mode 100644 index 0000000000..6321ded4bd --- /dev/null +++ b/dimos/hardware/camera/zed/__init__.py @@ -0,0 +1,56 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ZED camera hardware interfaces.""" + +from pathlib import Path + +from dimos.msgs.sensor_msgs.CameraInfo import CalibrationProvider + +# Check if ZED SDK is available +try: + import pyzed.sl as sl # type: ignore[import-not-found] + + HAS_ZED_SDK = True +except ImportError: + HAS_ZED_SDK = False + +# Only import ZED classes if SDK is available +if HAS_ZED_SDK: + from dimos.hardware.camera.zed.camera import ZEDCamera, ZEDModule +else: + # Provide stub classes when SDK is not available + class ZEDCamera: # type: ignore[no-redef] + def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + raise ImportError( + "ZED SDK not installed. Please install pyzed package to use ZED camera functionality." + ) + + class ZEDModule: # type: ignore[no-redef] + def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + raise ImportError( + "ZED SDK not installed. Please install pyzed package to use ZED camera functionality." + ) + + +# Set up camera calibration provider (always available) +CALIBRATION_DIR = Path(__file__).parent +CameraInfo = CalibrationProvider(CALIBRATION_DIR) + +__all__ = [ + "HAS_ZED_SDK", + "CameraInfo", + "ZEDCamera", + "ZEDModule", +] diff --git a/dimos/hardware/camera/zed/camera.py b/dimos/hardware/camera/zed/camera.py new file mode 100644 index 0000000000..df0478ca34 --- /dev/null +++ b/dimos/hardware/camera/zed/camera.py @@ -0,0 +1,874 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from types import TracebackType +from typing import Any + +import cv2 +from dimos_lcm.sensor_msgs import CameraInfo # type: ignore[import-untyped] +import numpy as np +import open3d as o3d # type: ignore[import-untyped] +import pyzed.sl as sl # type: ignore[import-not-found] +from reactivex import interval + +from dimos.core import Module, Out, rpc +from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Vector3 + +# Import LCM message types +from dimos.msgs.sensor_msgs import Image, ImageFormat +from dimos.msgs.std_msgs import Header +from dimos.protocol.tf import TF +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class ZEDCamera: + """ZED Camera capture node with neural depth processing.""" + + def __init__( # type: ignore[no-untyped-def] + self, + camera_id: int = 0, + resolution: sl.RESOLUTION = sl.RESOLUTION.HD720, + depth_mode: sl.DEPTH_MODE = sl.DEPTH_MODE.NEURAL, + fps: int = 30, + **kwargs, + ) -> None: + """ + Initialize ZED Camera. + + Args: + camera_id: Camera ID (0 for first ZED) + resolution: ZED camera resolution + depth_mode: Depth computation mode + fps: Camera frame rate (default: 30) + """ + if sl is None: + raise ImportError("ZED SDK not installed. Please install pyzed package.") + + super().__init__(**kwargs) + + self.camera_id = camera_id + self.resolution = resolution + self.depth_mode = depth_mode + self.fps = fps + + # Initialize ZED camera + self.zed = sl.Camera() + self.init_params = sl.InitParameters() + self.init_params.camera_resolution = resolution + self.init_params.depth_mode = depth_mode + self.init_params.coordinate_system = sl.COORDINATE_SYSTEM.RIGHT_HANDED_Z_UP_X_FWD + self.init_params.coordinate_units = sl.UNIT.METER + self.init_params.camera_fps = fps + + # Set camera ID using the correct parameter name + if hasattr(self.init_params, "set_from_camera_id"): + self.init_params.set_from_camera_id(camera_id) + elif hasattr(self.init_params, "input"): + self.init_params.input.set_from_camera_id(camera_id) + + # Use enable_fill_mode instead of SENSING_MODE.STANDARD + self.runtime_params = sl.RuntimeParameters() + self.runtime_params.enable_fill_mode = True # False = STANDARD mode, True = FILL mode + + # Image containers + self.image_left = sl.Mat() + self.image_right = sl.Mat() + self.depth_map = sl.Mat() + self.point_cloud = sl.Mat() + self.confidence_map = sl.Mat() + + # Positional tracking + self.tracking_enabled = False + self.tracking_params = sl.PositionalTrackingParameters() + self.camera_pose = sl.Pose() + self.sensors_data = sl.SensorsData() + + self.is_opened = False + + def open(self) -> bool: + """Open the ZED camera.""" + try: + err = self.zed.open(self.init_params) + if err != sl.ERROR_CODE.SUCCESS: + logger.error(f"Failed to open ZED camera: {err}") + return False + + self.is_opened = True + logger.info("ZED camera opened successfully") + + # Get camera information + info = self.zed.get_camera_information() + logger.info(f"ZED Camera Model: {info.camera_model}") + logger.info(f"Serial Number: {info.serial_number}") + logger.info(f"Firmware: {info.camera_configuration.firmware_version}") + + return True + + except Exception as e: + logger.error(f"Error opening ZED camera: {e}") + return False + + def enable_positional_tracking( + self, + enable_area_memory: bool = False, + enable_pose_smoothing: bool = True, + enable_imu_fusion: bool = True, + set_floor_as_origin: bool = False, + initial_world_transform: sl.Transform | None = None, + ) -> bool: + """ + Enable positional tracking on the ZED camera. + + Args: + enable_area_memory: Enable area learning to correct tracking drift + enable_pose_smoothing: Enable pose smoothing + enable_imu_fusion: Enable IMU fusion if available + set_floor_as_origin: Set the floor as origin (useful for robotics) + initial_world_transform: Initial world transform + + Returns: + True if tracking enabled successfully + """ + if not self.is_opened: + logger.error("ZED camera not opened") + return False + + try: + # Configure tracking parameters + self.tracking_params.enable_area_memory = enable_area_memory + self.tracking_params.enable_pose_smoothing = enable_pose_smoothing + self.tracking_params.enable_imu_fusion = enable_imu_fusion + self.tracking_params.set_floor_as_origin = set_floor_as_origin + + if initial_world_transform is not None: + self.tracking_params.initial_world_transform = initial_world_transform + + # Enable tracking + err = self.zed.enable_positional_tracking(self.tracking_params) + if err != sl.ERROR_CODE.SUCCESS: + logger.error(f"Failed to enable positional tracking: {err}") + return False + + self.tracking_enabled = True + logger.info("Positional tracking enabled successfully") + return True + + except Exception as e: + logger.error(f"Error enabling positional tracking: {e}") + return False + + def disable_positional_tracking(self) -> None: + """Disable positional tracking.""" + if self.tracking_enabled: + self.zed.disable_positional_tracking() + self.tracking_enabled = False + logger.info("Positional tracking disabled") + + def get_pose( + self, reference_frame: sl.REFERENCE_FRAME = sl.REFERENCE_FRAME.WORLD + ) -> dict[str, Any] | None: + """ + Get the current camera pose. + + Args: + reference_frame: Reference frame (WORLD or CAMERA) + + Returns: + Dictionary containing: + - position: [x, y, z] in meters + - rotation: [x, y, z, w] quaternion + - euler_angles: [roll, pitch, yaw] in radians + - timestamp: Pose timestamp in nanoseconds + - confidence: Tracking confidence (0-100) + - valid: Whether pose is valid + """ + if not self.tracking_enabled: + logger.error("Positional tracking not enabled") + return None + + try: + # Get current pose + tracking_state = self.zed.get_position(self.camera_pose, reference_frame) + + if tracking_state == sl.POSITIONAL_TRACKING_STATE.OK: + # Extract translation + translation = self.camera_pose.get_translation().get() + + # Extract rotation (quaternion) + rotation = self.camera_pose.get_orientation().get() + + # Get Euler angles + euler = self.camera_pose.get_euler_angles() + + return { + "position": translation.tolist(), + "rotation": rotation.tolist(), # [x, y, z, w] + "euler_angles": euler.tolist(), # [roll, pitch, yaw] + "timestamp": self.camera_pose.timestamp.get_nanoseconds(), + "confidence": self.camera_pose.pose_confidence, + "valid": True, + "tracking_state": str(tracking_state), + } + else: + logger.warning(f"Tracking state: {tracking_state}") + return {"valid": False, "tracking_state": str(tracking_state)} + + except Exception as e: + logger.error(f"Error getting pose: {e}") + return None + + def get_imu_data(self) -> dict[str, Any] | None: + """ + Get IMU sensor data if available. + + Returns: + Dictionary containing: + - orientation: IMU orientation quaternion [x, y, z, w] + - angular_velocity: [x, y, z] in rad/s + - linear_acceleration: [x, y, z] in m/s² + - timestamp: IMU data timestamp + """ + if not self.is_opened: + logger.error("ZED camera not opened") + return None + + try: + # Get sensors data synchronized with images + if ( + self.zed.get_sensors_data(self.sensors_data, sl.TIME_REFERENCE.IMAGE) + == sl.ERROR_CODE.SUCCESS + ): + imu = self.sensors_data.get_imu_data() + + # Get IMU orientation + imu_orientation = imu.get_pose().get_orientation().get() + + # Get angular velocity + angular_vel = imu.get_angular_velocity() + + # Get linear acceleration + linear_accel = imu.get_linear_acceleration() + + return { + "orientation": imu_orientation.tolist(), + "angular_velocity": angular_vel.tolist(), + "linear_acceleration": linear_accel.tolist(), + "timestamp": self.sensors_data.timestamp.get_nanoseconds(), + "temperature": self.sensors_data.temperature.get(sl.SENSOR_LOCATION.IMU), + } + else: + return None + + except Exception as e: + logger.error(f"Error getting IMU data: {e}") + return None + + def capture_frame( + self, + ) -> tuple[np.ndarray | None, np.ndarray | None, np.ndarray | None]: # type: ignore[type-arg] + """ + Capture a frame from ZED camera. + + Returns: + Tuple of (left_image, right_image, depth_map) as numpy arrays + """ + if not self.is_opened: + logger.error("ZED camera not opened") + return None, None, None + + try: + # Grab frame + if self.zed.grab(self.runtime_params) == sl.ERROR_CODE.SUCCESS: + # Retrieve left image + self.zed.retrieve_image(self.image_left, sl.VIEW.LEFT) + left_img = self.image_left.get_data()[:, :, :3] # Remove alpha channel + + # Retrieve right image + self.zed.retrieve_image(self.image_right, sl.VIEW.RIGHT) + right_img = self.image_right.get_data()[:, :, :3] # Remove alpha channel + + # Retrieve depth map + self.zed.retrieve_measure(self.depth_map, sl.MEASURE.DEPTH) + depth = self.depth_map.get_data() + + return left_img, right_img, depth + else: + logger.warning("Failed to grab frame from ZED camera") + return None, None, None + + except Exception as e: + logger.error(f"Error capturing frame: {e}") + return None, None, None + + def capture_pointcloud(self) -> o3d.geometry.PointCloud | None: + """ + Capture point cloud from ZED camera. + + Returns: + Open3D point cloud with XYZ coordinates and RGB colors + """ + if not self.is_opened: + logger.error("ZED camera not opened") + return None + + try: + if self.zed.grab(self.runtime_params) == sl.ERROR_CODE.SUCCESS: + # Retrieve point cloud with RGBA data + self.zed.retrieve_measure(self.point_cloud, sl.MEASURE.XYZRGBA) + point_cloud_data = self.point_cloud.get_data() + + # Convert to numpy array format + _height, _width = point_cloud_data.shape[:2] + points = point_cloud_data.reshape(-1, 4) + + # Extract XYZ coordinates + xyz = points[:, :3] + + # Extract and unpack RGBA color data from 4th channel + rgba_packed = points[:, 3].view(np.uint32) + + # Unpack RGBA: each 32-bit value contains 4 bytes (R, G, B, A) + colors_rgba = np.zeros((len(rgba_packed), 4), dtype=np.uint8) + colors_rgba[:, 0] = rgba_packed & 0xFF # R + colors_rgba[:, 1] = (rgba_packed >> 8) & 0xFF # G + colors_rgba[:, 2] = (rgba_packed >> 16) & 0xFF # B + colors_rgba[:, 3] = (rgba_packed >> 24) & 0xFF # A + + # Extract RGB (ignore alpha) and normalize to [0, 1] + colors_rgb = colors_rgba[:, :3].astype(np.float64) / 255.0 + + # Filter out invalid points (NaN or inf) + valid = np.isfinite(xyz).all(axis=1) + valid_xyz = xyz[valid] + valid_colors = colors_rgb[valid] + + # Create Open3D point cloud + pcd = o3d.geometry.PointCloud() + + if len(valid_xyz) > 0: + pcd.points = o3d.utility.Vector3dVector(valid_xyz) + pcd.colors = o3d.utility.Vector3dVector(valid_colors) + + return pcd + else: + logger.warning("Failed to grab frame for point cloud") + return None + + except Exception as e: + logger.error(f"Error capturing point cloud: {e}") + return None + + def capture_frame_with_pose( + self, + ) -> tuple[np.ndarray | None, np.ndarray | None, np.ndarray | None, dict[str, Any] | None]: # type: ignore[type-arg] + """ + Capture a frame with synchronized pose data. + + Returns: + Tuple of (left_image, right_image, depth_map, pose_data) + """ + if not self.is_opened: + logger.error("ZED camera not opened") + return None, None, None, None + + try: + # Grab frame + if self.zed.grab(self.runtime_params) == sl.ERROR_CODE.SUCCESS: + # Get images and depth + left_img, right_img, depth = self.capture_frame() + + # Get synchronized pose if tracking is enabled + pose_data = None + if self.tracking_enabled: + pose_data = self.get_pose() + + return left_img, right_img, depth, pose_data + else: + logger.warning("Failed to grab frame from ZED camera") + return None, None, None, None + + except Exception as e: + logger.error(f"Error capturing frame with pose: {e}") + return None, None, None, None + + def close(self) -> None: + """Close the ZED camera.""" + if self.is_opened: + # Disable tracking if enabled + if self.tracking_enabled: + self.disable_positional_tracking() + + self.zed.close() + self.is_opened = False + logger.info("ZED camera closed") + + def get_camera_info(self) -> dict[str, Any]: + """Get ZED camera information and calibration parameters.""" + if not self.is_opened: + return {} + + try: + info = self.zed.get_camera_information() + calibration = info.camera_configuration.calibration_parameters + + # In ZED SDK 4.0+, the baseline calculation has changed + # Try to get baseline from the stereo parameters + try: + # Method 1: Try to get from stereo parameters if available + if hasattr(calibration, "getCameraBaseline"): + baseline = calibration.getCameraBaseline() + else: + # Method 2: Calculate from left and right camera positions + # The baseline is the distance between left and right cameras + + # Try different ways to get baseline in SDK 4.0+ + if hasattr(info.camera_configuration, "calibration_parameters_raw"): + # Use raw calibration if available + raw_calib = info.camera_configuration.calibration_parameters_raw + if hasattr(raw_calib, "T"): + baseline = abs(raw_calib.T[0]) + else: + baseline = 0.12 # Default ZED-M baseline approximation + else: + # Use default baseline for ZED-M + baseline = 0.12 # ZED-M baseline is approximately 120mm + except: + baseline = 0.12 # Fallback to approximate ZED-M baseline + + return { + "model": str(info.camera_model), + "serial_number": info.serial_number, + "firmware": info.camera_configuration.firmware_version, + "resolution": { + "width": info.camera_configuration.resolution.width, + "height": info.camera_configuration.resolution.height, + }, + "fps": info.camera_configuration.fps, + "left_cam": { + "fx": calibration.left_cam.fx, + "fy": calibration.left_cam.fy, + "cx": calibration.left_cam.cx, + "cy": calibration.left_cam.cy, + "k1": calibration.left_cam.disto[0], + "k2": calibration.left_cam.disto[1], + "p1": calibration.left_cam.disto[2], + "p2": calibration.left_cam.disto[3], + "k3": calibration.left_cam.disto[4], + }, + "right_cam": { + "fx": calibration.right_cam.fx, + "fy": calibration.right_cam.fy, + "cx": calibration.right_cam.cx, + "cy": calibration.right_cam.cy, + "k1": calibration.right_cam.disto[0], + "k2": calibration.right_cam.disto[1], + "p1": calibration.right_cam.disto[2], + "p2": calibration.right_cam.disto[3], + "k3": calibration.right_cam.disto[4], + }, + "baseline": baseline, + } + except Exception as e: + logger.error(f"Error getting camera info: {e}") + return {} + + def calculate_intrinsics(self): # type: ignore[no-untyped-def] + """Calculate camera intrinsics from ZED calibration.""" + info = self.get_camera_info() + if not info: + return super().calculate_intrinsics() # type: ignore[misc] + + left_cam = info.get("left_cam", {}) + resolution = info.get("resolution", {}) + + return { + "focal_length_x": left_cam.get("fx", 0), + "focal_length_y": left_cam.get("fy", 0), + "principal_point_x": left_cam.get("cx", 0), + "principal_point_y": left_cam.get("cy", 0), + "baseline": info.get("baseline", 0), + "resolution_width": resolution.get("width", 0), + "resolution_height": resolution.get("height", 0), + } + + def __enter__(self): # type: ignore[no-untyped-def] + """Context manager entry.""" + if not self.open(): + raise RuntimeError("Failed to open ZED camera") + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + """Context manager exit.""" + self.close() + + +class ZEDModule(Module): + """ + Dask module for ZED camera that publishes sensor data via LCM. + + Publishes: + - /zed/color_image: RGB camera images + - /zed/depth_image: Depth images + - /zed/camera_info: Camera calibration information + - /zed/pose: Camera pose (if tracking enabled) + """ + + # Define LCM outputs + color_image: Out[Image] = None # type: ignore[assignment] + depth_image: Out[Image] = None # type: ignore[assignment] + camera_info: Out[CameraInfo] = None # type: ignore[assignment] + pose: Out[PoseStamped] = None # type: ignore[assignment] + + def __init__( # type: ignore[no-untyped-def] + self, + camera_id: int = 0, + resolution: str = "HD720", + depth_mode: str = "NEURAL", + fps: int = 30, + enable_tracking: bool = True, + enable_imu_fusion: bool = True, + set_floor_as_origin: bool = True, + publish_rate: float = 30.0, + frame_id: str = "zed_camera", + recording_path: str | None = None, + **kwargs, + ) -> None: + """ + Initialize ZED Module. + + Args: + camera_id: Camera ID (0 for first ZED) + resolution: Resolution string ("HD720", "HD1080", "HD2K", "VGA") + depth_mode: Depth mode string ("NEURAL", "ULTRA", "QUALITY", "PERFORMANCE") + fps: Camera frame rate + enable_tracking: Enable positional tracking + enable_imu_fusion: Enable IMU fusion for tracking + set_floor_as_origin: Set floor as origin for tracking + publish_rate: Rate to publish messages (Hz) + frame_id: TF frame ID for messages + recording_path: Path to save recorded data + """ + super().__init__(**kwargs) + + self.camera_id = camera_id + self.fps = fps + self.enable_tracking = enable_tracking + self.enable_imu_fusion = enable_imu_fusion + self.set_floor_as_origin = set_floor_as_origin + self.publish_rate = publish_rate + self.frame_id = frame_id + self.recording_path = recording_path + + # Convert string parameters to ZED enums + self.resolution = getattr(sl.RESOLUTION, resolution, sl.RESOLUTION.HD720) + self.depth_mode = getattr(sl.DEPTH_MODE, depth_mode, sl.DEPTH_MODE.NEURAL) + + # Internal state + self.zed_camera = None + self._running = False + self._subscription = None + self._sequence = 0 + + # Initialize TF publisher + self.tf = TF() + + # Initialize storage for recording if path provided + self.storages = None + if self.recording_path: + from dimos.utils.testing import TimedSensorStorage + + self.storages = { + "color": TimedSensorStorage(f"{self.recording_path}/color"), + "depth": TimedSensorStorage(f"{self.recording_path}/depth"), + "pose": TimedSensorStorage(f"{self.recording_path}/pose"), + "camera_info": TimedSensorStorage(f"{self.recording_path}/camera_info"), + } + logger.info(f"Recording enabled - saving to {self.recording_path}") + + logger.info(f"ZEDModule initialized for camera {camera_id}") + + @rpc + def start(self) -> None: + """Start the ZED module and begin publishing data.""" + if self._running: + logger.warning("ZED module already running") + return + + super().start() + + try: + # Initialize ZED camera + self.zed_camera = ZEDCamera( # type: ignore[assignment] + camera_id=self.camera_id, + resolution=self.resolution, + depth_mode=self.depth_mode, + fps=self.fps, + ) + + # Open camera + if not self.zed_camera.open(): # type: ignore[attr-defined] + logger.error("Failed to open ZED camera") + return + + # Enable tracking if requested + if self.enable_tracking: + success = self.zed_camera.enable_positional_tracking( # type: ignore[attr-defined] + enable_imu_fusion=self.enable_imu_fusion, + set_floor_as_origin=self.set_floor_as_origin, + enable_pose_smoothing=True, + enable_area_memory=True, + ) + if not success: + logger.warning("Failed to enable positional tracking") + self.enable_tracking = False + + # Publish camera info once at startup + self._publish_camera_info() + + # Start periodic frame capture and publishing + self._running = True + publish_interval = 1.0 / self.publish_rate + + self._subscription = interval(publish_interval).subscribe( # type: ignore[assignment] + lambda _: self._capture_and_publish() + ) + + logger.info(f"ZED module started, publishing at {self.publish_rate} Hz") + + except Exception as e: + logger.error(f"Error starting ZED module: {e}") + self._running = False + + @rpc + def stop(self) -> None: + """Stop the ZED module.""" + if not self._running: + return + + self._running = False + + # Stop subscription + if self._subscription: + self._subscription.dispose() + self._subscription = None + + # Close camera + if self.zed_camera: + self.zed_camera.close() + self.zed_camera = None + + super().stop() + + def _capture_and_publish(self) -> None: + """Capture frame and publish all data.""" + if not self._running or not self.zed_camera: + return + + try: + # Capture frame with pose + left_img, _, depth, pose_data = self.zed_camera.capture_frame_with_pose() + + if left_img is None or depth is None: + return + + # Save raw color data if recording + if self.storages and left_img is not None: + self.storages["color"].save_one(left_img) + + # Save raw depth data if recording + if self.storages and depth is not None: + self.storages["depth"].save_one(depth) + + # Save raw pose data if recording + if self.storages and pose_data: + self.storages["pose"].save_one(pose_data) + + # Create header + header = Header(self.frame_id) + self._sequence += 1 + + # Publish color image + self._publish_color_image(left_img, header) + + # Publish depth image + self._publish_depth_image(depth, header) + + # Publish camera info periodically + self._publish_camera_info() + + # Publish pose if tracking enabled and valid + if self.enable_tracking and pose_data and pose_data.get("valid", False): + self._publish_pose(pose_data, header) + + except Exception as e: + logger.error(f"Error in capture and publish: {e}") + + def _publish_color_image(self, image: np.ndarray, header: Header) -> None: # type: ignore[type-arg] + """Publish color image as LCM message.""" + try: + # Convert BGR to RGB if needed + if len(image.shape) == 3 and image.shape[2] == 3: + image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + else: + image_rgb = image + + # Create LCM Image message + msg = Image( + data=image_rgb, + format=ImageFormat.RGB, + frame_id=header.frame_id, + ts=header.ts, + ) + + self.color_image.publish(msg) + + except Exception as e: + logger.error(f"Error publishing color image: {e}") + + def _publish_depth_image(self, depth: np.ndarray, header: Header) -> None: # type: ignore[type-arg] + """Publish depth image as LCM message.""" + try: + # Depth is float32 in meters + msg = Image( + data=depth, + format=ImageFormat.DEPTH, + frame_id=header.frame_id, + ts=header.ts, + ) + self.depth_image.publish(msg) + + except Exception as e: + logger.error(f"Error publishing depth image: {e}") + + def _publish_camera_info(self) -> None: + """Publish camera calibration information.""" + try: + info = self.zed_camera.get_camera_info() # type: ignore[attr-defined] + if not info: + return + + # Save raw camera info if recording + if self.storages: + self.storages["camera_info"].save_one(info) + + # Get calibration parameters + left_cam = info.get("left_cam", {}) + resolution = info.get("resolution", {}) + + # Create CameraInfo message + header = Header(self.frame_id) + + # Create camera matrix K (3x3) + K = [ + left_cam.get("fx", 0), + 0, + left_cam.get("cx", 0), + 0, + left_cam.get("fy", 0), + left_cam.get("cy", 0), + 0, + 0, + 1, + ] + + # Distortion coefficients + D = [ + left_cam.get("k1", 0), + left_cam.get("k2", 0), + left_cam.get("p1", 0), + left_cam.get("p2", 0), + left_cam.get("k3", 0), + ] + + # Identity rotation matrix + R = [1, 0, 0, 0, 1, 0, 0, 0, 1] + + # Projection matrix P (3x4) + P = [ + left_cam.get("fx", 0), + 0, + left_cam.get("cx", 0), + 0, + 0, + left_cam.get("fy", 0), + left_cam.get("cy", 0), + 0, + 0, + 0, + 1, + 0, + ] + + msg = CameraInfo( + D_length=len(D), + header=header, + height=resolution.get("height", 0), + width=resolution.get("width", 0), + distortion_model="plumb_bob", + D=D, + K=K, + R=R, + P=P, + binning_x=0, + binning_y=0, + ) + + self.camera_info.publish(msg) + + except Exception as e: + logger.error(f"Error publishing camera info: {e}") + + def _publish_pose(self, pose_data: dict[str, Any], header: Header) -> None: + """Publish camera pose as PoseStamped message and TF transform.""" + try: + position = pose_data.get("position", [0, 0, 0]) + rotation = pose_data.get("rotation", [0, 0, 0, 1]) # quaternion [x,y,z,w] + + # Create PoseStamped message + msg = PoseStamped(ts=header.ts, position=position, orientation=rotation) + self.pose.publish(msg) + + # Publish TF transform + camera_tf = Transform( + translation=Vector3(position), + rotation=Quaternion(rotation), + frame_id="zed_world", + child_frame_id="zed_camera_link", + ts=header.ts, + ) + self.tf.publish(camera_tf) + + except Exception as e: + logger.error(f"Error publishing pose: {e}") + + @rpc + def get_camera_info(self) -> dict[str, Any]: + """Get camera information and calibration parameters.""" + if self.zed_camera: + return self.zed_camera.get_camera_info() + return {} + + @rpc + def get_pose(self) -> dict[str, Any] | None: + """Get current camera pose if tracking is enabled.""" + if self.zed_camera and self.enable_tracking: + return self.zed_camera.get_pose() + return None diff --git a/dimos/hardware/camera/zed/single_webcam.yaml b/dimos/hardware/camera/zed/single_webcam.yaml new file mode 100644 index 0000000000..1ce9457559 --- /dev/null +++ b/dimos/hardware/camera/zed/single_webcam.yaml @@ -0,0 +1,27 @@ +# for cv2.VideoCapture and cutting only half of the frame +image_width: 640 +image_height: 376 +camera_name: zed_webcam_single +camera_matrix: + rows: 3 + cols: 3 + data: [379.45267, 0. , 302.43516, + 0. , 380.67871, 228.00954, + 0. , 0. , 1. ] +distortion_model: plumb_bob +distortion_coefficients: + rows: 1 + cols: 5 + data: [-0.309435, 0.092185, -0.009059, 0.003708, 0.000000] +rectification_matrix: + rows: 3 + cols: 3 + data: [1., 0., 0., + 0., 1., 0., + 0., 0., 1.] +projection_matrix: + rows: 3 + cols: 4 + data: [291.12888, 0. , 304.94086, 0. , + 0. , 347.95022, 231.8885 , 0. , + 0. , 0. , 1. , 0. ] diff --git a/dimos/hardware/camera/zed/test_zed.py b/dimos/hardware/camera/zed/test_zed.py new file mode 100644 index 0000000000..e31c751b43 --- /dev/null +++ b/dimos/hardware/camera/zed/test_zed.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo + + +def test_zed_import_and_calibration_access() -> None: + """Test that zed module can be imported and calibrations accessed.""" + # Import zed module from camera + from dimos.hardware.camera import zed + + # Test that CameraInfo is accessible + assert hasattr(zed, "CameraInfo") + + # Test snake_case access + camera_info_snake = zed.CameraInfo.single_webcam + assert isinstance(camera_info_snake, CameraInfo) + assert camera_info_snake.width == 640 + assert camera_info_snake.height == 376 + assert camera_info_snake.distortion_model == "plumb_bob" + + # Test PascalCase access + camera_info_pascal = zed.CameraInfo.SingleWebcam + assert isinstance(camera_info_pascal, CameraInfo) + assert camera_info_pascal.width == 640 + assert camera_info_pascal.height == 376 + + # Verify both access methods return the same cached object + assert camera_info_snake is camera_info_pascal + + print("✓ ZED import and calibration access test passed!") diff --git a/dimos/hardware/can_activate.sh b/dimos/hardware/can_activate.sh new file mode 100644 index 0000000000..addb892557 --- /dev/null +++ b/dimos/hardware/can_activate.sh @@ -0,0 +1,138 @@ +#!/bin/bash + +# The default CAN name can be set by the user via command-line parameters. +DEFAULT_CAN_NAME="${1:-can0}" + +# The default bitrate for a single CAN module can be set by the user via command-line parameters. +DEFAULT_BITRATE="${2:-1000000}" + +# USB hardware address (optional parameter) +USB_ADDRESS="${3}" +echo "-------------------START-----------------------" +# Check if ethtool is installed. +if ! dpkg -l | grep -q "ethtool"; then + echo "\e[31mError: ethtool not detected in the system.\e[0m" + echo "Please use the following command to install ethtool:" + echo "sudo apt update && sudo apt install ethtool" + exit 1 +fi + +# Check if can-utils is installed. +if ! dpkg -l | grep -q "can-utils"; then + echo "\e[31mError: can-utils not detected in the system.\e[0m" + echo "Please use the following command to install ethtool:" + echo "sudo apt update && sudo apt install can-utils" + exit 1 +fi + +echo "Both ethtool and can-utils are installed." + +# Retrieve the number of CAN modules in the current system. +CURRENT_CAN_COUNT=$(ip link show type can | grep -c "link/can") + +# Verify if the number of CAN modules in the current system matches the expected value. +if [ "$CURRENT_CAN_COUNT" -ne "1" ]; then + if [ -z "$USB_ADDRESS" ]; then + # Iterate through all CAN interfaces. + for iface in $(ip -br link show type can | awk '{print $1}'); do + # Use ethtool to retrieve bus-info. + BUS_INFO=$(sudo ethtool -i "$iface" | grep "bus-info" | awk '{print $2}') + + if [ -z "$BUS_INFO" ];then + echo "Error: Unable to retrieve bus-info for interface $iface." + continue + fi + + echo "Interface $iface is inserted into USB port $BUS_INFO" + done + echo -e " \e[31m Error: The number of CAN modules detected by the system ($CURRENT_CAN_COUNT) does not match the expected number (1). \e[0m" + echo -e " \e[31m Please add the USB hardware address parameter, such as: \e[0m" + echo -e " bash can_activate.sh can0 1000000 1-2:1.0" + echo "-------------------ERROR-----------------------" + exit 1 + fi +fi + +# Load the gs_usb module. +# sudo modprobe gs_usb +# if [ $? -ne 0 ]; then +# echo "Error: Unable to load the gs_usb module." +# exit 1 +# fi + +if [ -n "$USB_ADDRESS" ]; then + echo "Detected USB hardware address parameter: $USB_ADDRESS" + + # Use ethtool to find the CAN interface corresponding to the USB hardware address. + INTERFACE_NAME="" + for iface in $(ip -br link show type can | awk '{print $1}'); do + BUS_INFO=$(sudo ethtool -i "$iface" | grep "bus-info" | awk '{print $2}') + if [ "$BUS_INFO" = "$USB_ADDRESS" ]; then + INTERFACE_NAME="$iface" + break + fi + done + + if [ -z "$INTERFACE_NAME" ]; then + echo "Error: Unable to find CAN interface corresponding to USB hardware address $USB_ADDRESS." + exit 1 + else + echo "Found the interface corresponding to USB hardware address $USB_ADDRESS: $INTERFACE_NAME." + fi +else + # Retrieve the unique CAN interface. + INTERFACE_NAME=$(ip -br link show type can | awk '{print $1}') + + # Check if the interface name has been retrieved. + if [ -z "$INTERFACE_NAME" ]; then + echo "Error: Unable to detect CAN interface." + exit 1 + fi + BUS_INFO=$(sudo ethtool -i "$INTERFACE_NAME" | grep "bus-info" | awk '{print $2}') + echo "Expected to configure a single CAN module, detected interface $INTERFACE_NAME with corresponding USB address $BUS_INFO." +fi + +# Check if the current interface is already activated. +IS_LINK_UP=$(ip link show "$INTERFACE_NAME" | grep -q "UP" && echo "yes" || echo "no") + +# Retrieve the bitrate of the current interface. +CURRENT_BITRATE=$(ip -details link show "$INTERFACE_NAME" | grep -oP 'bitrate \K\d+') + +if [ "$IS_LINK_UP" = "yes" ] && [ "$CURRENT_BITRATE" -eq "$DEFAULT_BITRATE" ]; then + echo "Interface $INTERFACE_NAME is already activated with a bitrate of $DEFAULT_BITRATE." + + # Check if the interface name matches the default name. + if [ "$INTERFACE_NAME" != "$DEFAULT_CAN_NAME" ]; then + echo "Rename interface $INTERFACE_NAME to $DEFAULT_CAN_NAME." + sudo ip link set "$INTERFACE_NAME" down + sudo ip link set "$INTERFACE_NAME" name "$DEFAULT_CAN_NAME" + sudo ip link set "$DEFAULT_CAN_NAME" up + echo "The interface has been renamed to $DEFAULT_CAN_NAME and reactivated." + else + echo "The interface name is already $DEFAULT_CAN_NAME." + fi +else + # If the interface is not activated or the bitrate is different, configure it. + if [ "$IS_LINK_UP" = "yes" ]; then + echo "Interface $INTERFACE_NAME is already activated, but the bitrate is $CURRENT_BITRATE, which does not match the set value of $DEFAULT_BITRATE." + else + echo "Interface $INTERFACE_NAME is not activated or bitrate is not set." + fi + + # Set the interface bitrate and activate it. + sudo ip link set "$INTERFACE_NAME" down + sudo ip link set "$INTERFACE_NAME" type can bitrate $DEFAULT_BITRATE + sudo ip link set "$INTERFACE_NAME" up + echo "Interface $INTERFACE_NAME has been reset to bitrate $DEFAULT_BITRATE and activated." + + # Rename the interface to the default name. + if [ "$INTERFACE_NAME" != "$DEFAULT_CAN_NAME" ]; then + echo "Rename interface $INTERFACE_NAME to $DEFAULT_CAN_NAME." + sudo ip link set "$INTERFACE_NAME" down + sudo ip link set "$INTERFACE_NAME" name "$DEFAULT_CAN_NAME" + sudo ip link set "$DEFAULT_CAN_NAME" up + echo "The interface has been renamed to $DEFAULT_CAN_NAME and reactivated." + fi +fi + +echo "-------------------OVER------------------------" diff --git a/dimos/hardware/end_effector.py b/dimos/hardware/end_effector.py index 37de922bd5..e958261b91 100644 --- a/dimos/hardware/end_effector.py +++ b/dimos/hardware/end_effector.py @@ -1,6 +1,21 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + class EndEffector: - def __init__(self, effector_type=None): + def __init__(self, effector_type=None) -> None: # type: ignore[no-untyped-def] self.effector_type = effector_type - def get_effector_type(self): + def get_effector_type(self): # type: ignore[no-untyped-def] return self.effector_type diff --git a/dimos/hardware/fake_zed_module.py b/dimos/hardware/fake_zed_module.py new file mode 100644 index 0000000000..987dc2ec20 --- /dev/null +++ b/dimos/hardware/fake_zed_module.py @@ -0,0 +1,282 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +FakeZEDModule - Replays recorded ZED data for testing without hardware. +""" + +import functools +import logging + +from dimos_lcm.sensor_msgs import CameraInfo # type: ignore[import-untyped] +import numpy as np + +from dimos.core import Module, Out, rpc +from dimos.msgs.geometry_msgs import PoseStamped +from dimos.msgs.sensor_msgs import Image, ImageFormat +from dimos.msgs.std_msgs import Header +from dimos.protocol.tf import TF +from dimos.utils.logging_config import setup_logger +from dimos.utils.testing import TimedSensorReplay + +logger = setup_logger(level=logging.INFO) + + +class FakeZEDModule(Module): + """ + Fake ZED module that replays recorded data instead of real camera. + """ + + # Define LCM outputs (same as ZEDModule) + color_image: Out[Image] = None # type: ignore[assignment] + depth_image: Out[Image] = None # type: ignore[assignment] + camera_info: Out[CameraInfo] = None # type: ignore[assignment] + pose: Out[PoseStamped] = None # type: ignore[assignment] + + def __init__(self, recording_path: str, frame_id: str = "zed_camera", **kwargs) -> None: # type: ignore[no-untyped-def] + """ + Initialize FakeZEDModule with recording path. + + Args: + recording_path: Path to recorded data directory + frame_id: TF frame ID for messages + """ + super().__init__(**kwargs) + + self.recording_path = recording_path + self.frame_id = frame_id + self._running = False + + # Initialize TF publisher + self.tf = TF() + + logger.info(f"FakeZEDModule initialized with recording: {self.recording_path}") + + @functools.cache + def _get_color_stream(self): # type: ignore[no-untyped-def] + """Get cached color image stream.""" + logger.info(f"Loading color image stream from {self.recording_path}/color") + + def image_autocast(x): # type: ignore[no-untyped-def] + """Convert raw numpy array to Image.""" + if isinstance(x, np.ndarray): + return Image(data=x, format=ImageFormat.RGB) + elif isinstance(x, Image): + return x + return x + + color_replay = TimedSensorReplay(f"{self.recording_path}/color", autocast=image_autocast) + return color_replay.stream() + + @functools.cache + def _get_depth_stream(self): # type: ignore[no-untyped-def] + """Get cached depth image stream.""" + logger.info(f"Loading depth image stream from {self.recording_path}/depth") + + def depth_autocast(x): # type: ignore[no-untyped-def] + """Convert raw numpy array to depth Image.""" + if isinstance(x, np.ndarray): + # Depth images are float32 + return Image(data=x, format=ImageFormat.DEPTH) + elif isinstance(x, Image): + return x + return x + + depth_replay = TimedSensorReplay(f"{self.recording_path}/depth", autocast=depth_autocast) + return depth_replay.stream() + + @functools.cache + def _get_pose_stream(self): # type: ignore[no-untyped-def] + """Get cached pose stream.""" + logger.info(f"Loading pose stream from {self.recording_path}/pose") + + def pose_autocast(x): # type: ignore[no-untyped-def] + """Convert raw pose dict to PoseStamped.""" + if isinstance(x, dict): + import time + + return PoseStamped( + position=x.get("position", [0, 0, 0]), + orientation=x.get("rotation", [0, 0, 0, 1]), + ts=time.time(), + ) + elif isinstance(x, PoseStamped): + return x + return x + + pose_replay = TimedSensorReplay(f"{self.recording_path}/pose", autocast=pose_autocast) + return pose_replay.stream() + + @functools.cache + def _get_camera_info_stream(self): # type: ignore[no-untyped-def] + """Get cached camera info stream.""" + logger.info(f"Loading camera info stream from {self.recording_path}/camera_info") + + def camera_info_autocast(x): # type: ignore[no-untyped-def] + """Convert raw camera info dict to CameraInfo message.""" + if isinstance(x, dict): + # Extract calibration parameters + left_cam = x.get("left_cam", {}) + resolution = x.get("resolution", {}) + + # Create CameraInfo message + header = Header(self.frame_id) + + # Create camera matrix K (3x3) + K = [ + left_cam.get("fx", 0), + 0, + left_cam.get("cx", 0), + 0, + left_cam.get("fy", 0), + left_cam.get("cy", 0), + 0, + 0, + 1, + ] + + # Distortion coefficients + D = [ + left_cam.get("k1", 0), + left_cam.get("k2", 0), + left_cam.get("p1", 0), + left_cam.get("p2", 0), + left_cam.get("k3", 0), + ] + + # Identity rotation matrix + R = [1, 0, 0, 0, 1, 0, 0, 0, 1] + + # Projection matrix P (3x4) + P = [ + left_cam.get("fx", 0), + 0, + left_cam.get("cx", 0), + 0, + 0, + left_cam.get("fy", 0), + left_cam.get("cy", 0), + 0, + 0, + 0, + 1, + 0, + ] + + return CameraInfo( + D_length=len(D), + header=header, + height=resolution.get("height", 0), + width=resolution.get("width", 0), + distortion_model="plumb_bob", + D=D, + K=K, + R=R, + P=P, + binning_x=0, + binning_y=0, + ) + elif isinstance(x, CameraInfo): + return x + return x + + info_replay = TimedSensorReplay( + f"{self.recording_path}/camera_info", autocast=camera_info_autocast + ) + return info_replay.stream() + + @rpc + def start(self) -> None: + """Start replaying recorded data.""" + super().start() + + if self._running: + logger.warning("FakeZEDModule already running") + return + + logger.info("Starting FakeZEDModule replay...") + + self._running = True + + # Subscribe to all streams and publish + try: + # Color image stream + unsub = self._get_color_stream().subscribe( + lambda msg: self.color_image.publish(msg) if self._running else None + ) + self._disposables.add(unsub) + logger.info("Started color image replay stream") + except Exception as e: + logger.warning(f"Color image stream not available: {e}") + + try: + # Depth image stream + unsub = self._get_depth_stream().subscribe( + lambda msg: self.depth_image.publish(msg) if self._running else None + ) + self._disposables.add(unsub) + logger.info("Started depth image replay stream") + except Exception as e: + logger.warning(f"Depth image stream not available: {e}") + + try: + # Pose stream + unsub = self._get_pose_stream().subscribe( + lambda msg: self._publish_pose(msg) if self._running else None + ) + self._disposables.add(unsub) + logger.info("Started pose replay stream") + except Exception as e: + logger.warning(f"Pose stream not available: {e}") + + try: + # Camera info stream + unsub = self._get_camera_info_stream().subscribe( + lambda msg: self.camera_info.publish(msg) if self._running else None + ) + self._disposables.add(unsub) + logger.info("Started camera info replay stream") + except Exception as e: + logger.warning(f"Camera info stream not available: {e}") + + logger.info("FakeZEDModule replay started") + + @rpc + def stop(self) -> None: + if not self._running: + return + + self._running = False + + super().stop() + + def _publish_pose(self, msg) -> None: # type: ignore[no-untyped-def] + """Publish pose and TF transform.""" + if msg: + self.pose.publish(msg) + + # Publish TF transform from world to camera + import time + + from dimos.msgs.geometry_msgs import Quaternion, Transform, Vector3 + + transform = Transform( + translation=Vector3(*msg.position), + rotation=Quaternion(*msg.orientation), + frame_id="world", + child_frame_id=self.frame_id, + ts=time.time(), + ) + self.tf.publish(transform) diff --git a/dimos/hardware/gstreamer_camera.py b/dimos/hardware/gstreamer_camera.py new file mode 100644 index 0000000000..52635ae7b7 --- /dev/null +++ b/dimos/hardware/gstreamer_camera.py @@ -0,0 +1,309 @@ +#!/usr/bin/env python3 + +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import sys +import threading +import time + +import numpy as np + +from dimos.core import Module, Out, rpc +from dimos.msgs.sensor_msgs import Image, ImageFormat +from dimos.utils.logging_config import setup_logger + +# Add system path for gi module if needed +if "/usr/lib/python3/dist-packages" not in sys.path: + sys.path.insert(0, "/usr/lib/python3/dist-packages") + +import gi # type: ignore[import-not-found] + +gi.require_version("Gst", "1.0") +gi.require_version("GstApp", "1.0") +from gi.repository import GLib, Gst # type: ignore[import-not-found] + +logger = setup_logger(level=logging.INFO) + +Gst.init(None) + + +class GstreamerCameraModule(Module): + """Module that captures frames from a remote camera using GStreamer TCP with absolute timestamps.""" + + video: Out[Image] = None # type: ignore[assignment] + + def __init__( # type: ignore[no-untyped-def] + self, + host: str = "localhost", + port: int = 5000, + frame_id: str = "camera", + timestamp_offset: float = 0.0, + reconnect_interval: float = 5.0, + *args, + **kwargs, + ) -> None: + """Initialize the GStreamer TCP camera module. + + Args: + host: TCP server host to connect to + port: TCP server port + frame_id: Frame ID for the published images + timestamp_offset: Offset to add to timestamps (useful for clock synchronization) + reconnect_interval: Seconds to wait before attempting reconnection + """ + self.host = host + self.port = port + self.frame_id = frame_id + self.timestamp_offset = timestamp_offset + self.reconnect_interval = reconnect_interval + + self.pipeline = None + self.appsink = None + self.main_loop = None + self.main_loop_thread = None + self.running = False + self.should_reconnect = False + self.frame_count = 0 + self.last_log_time = time.time() + self.reconnect_timer_id = None + + Module.__init__(self, *args, **kwargs) + + @rpc + def start(self) -> None: + if self.running: + logger.warning("GStreamer camera module is already running") + return + + super().start() + + self.should_reconnect = True + self._connect() + + @rpc + def stop(self) -> None: + self.should_reconnect = False + self._cleanup_reconnect_timer() + + if not self.running: + return + + self.running = False + + if self.pipeline: + self.pipeline.set_state(Gst.State.NULL) + + if self.main_loop: + self.main_loop.quit() + + # Only join the thread if we're not calling from within it + if self.main_loop_thread and self.main_loop_thread != threading.current_thread(): + self.main_loop_thread.join(timeout=2.0) + + super().stop() + + def _connect(self) -> None: + if not self.should_reconnect: + return + + try: + self._create_pipeline() # type: ignore[no-untyped-call] + self._start_pipeline() # type: ignore[no-untyped-call] + self.running = True + logger.info(f"GStreamer TCP camera module connected to {self.host}:{self.port}") + except Exception as e: + logger.error(f"Failed to connect to {self.host}:{self.port}: {e}") + self._schedule_reconnect() + + def _cleanup_reconnect_timer(self) -> None: + if self.reconnect_timer_id: + GLib.source_remove(self.reconnect_timer_id) + self.reconnect_timer_id = None + + def _schedule_reconnect(self) -> None: + if not self.should_reconnect: + return + + self._cleanup_reconnect_timer() + logger.info(f"Scheduling reconnect in {self.reconnect_interval} seconds...") + self.reconnect_timer_id = GLib.timeout_add_seconds( + int(self.reconnect_interval), self._reconnect_timeout + ) + + def _reconnect_timeout(self) -> bool: + self.reconnect_timer_id = None + if self.should_reconnect: + logger.info("Attempting to reconnect...") + self._connect() + return False # Don't repeat the timeout + + def _handle_disconnect(self) -> None: + if not self.should_reconnect: + return + + self.running = False + + if self.pipeline: + self.pipeline.set_state(Gst.State.NULL) + self.pipeline = None + + self.appsink = None + + logger.warning(f"Disconnected from {self.host}:{self.port}") + self._schedule_reconnect() + + def _create_pipeline(self): # type: ignore[no-untyped-def] + # TCP client source with Matroska demuxer to extract absolute timestamps + pipeline_str = f""" + tcpclientsrc host={self.host} port={self.port} ! + matroskademux name=demux ! + h264parse ! + avdec_h264 ! + videoconvert ! + video/x-raw,format=BGR ! + appsink name=sink emit-signals=true sync=false max-buffers=1 drop=true + """ + + try: + self.pipeline = Gst.parse_launch(pipeline_str) + self.appsink = self.pipeline.get_by_name("sink") # type: ignore[attr-defined] + self.appsink.connect("new-sample", self._on_new_sample) # type: ignore[attr-defined] + except Exception as e: + logger.error(f"Failed to create GStreamer pipeline: {e}") + raise + + def _start_pipeline(self): # type: ignore[no-untyped-def] + """Start the GStreamer pipeline and main loop.""" + self.main_loop = GLib.MainLoop() + + # Start the pipeline + ret = self.pipeline.set_state(Gst.State.PLAYING) # type: ignore[attr-defined] + if ret == Gst.StateChangeReturn.FAILURE: + logger.error("Unable to set the pipeline to playing state") + raise RuntimeError("Failed to start GStreamer pipeline") + + # Run the main loop in a separate thread + self.main_loop_thread = threading.Thread(target=self._run_main_loop) # type: ignore[assignment] + self.main_loop_thread.daemon = True # type: ignore[attr-defined] + self.main_loop_thread.start() # type: ignore[attr-defined] + + # Set up bus message handling + bus = self.pipeline.get_bus() # type: ignore[attr-defined] + bus.add_signal_watch() + bus.connect("message", self._on_bus_message) + + def _run_main_loop(self) -> None: + try: + self.main_loop.run() # type: ignore[attr-defined] + except Exception as e: + logger.error(f"Main loop error: {e}") + + def _on_bus_message(self, bus, message) -> None: # type: ignore[no-untyped-def] + t = message.type + + if t == Gst.MessageType.EOS: + logger.info("End of stream - server disconnected") + self._handle_disconnect() + elif t == Gst.MessageType.ERROR: + err, debug = message.parse_error() + logger.error(f"GStreamer error: {err}, {debug}") + self._handle_disconnect() + elif t == Gst.MessageType.WARNING: + warn, debug = message.parse_warning() + logger.warning(f"GStreamer warning: {warn}, {debug}") + elif t == Gst.MessageType.STATE_CHANGED: + if message.src == self.pipeline: + _old_state, new_state, _pending_state = message.parse_state_changed() + if new_state == Gst.State.PLAYING: + logger.info("Pipeline is now playing - connected to TCP server") + + def _on_new_sample(self, appsink): # type: ignore[no-untyped-def] + """Handle new video samples from the appsink.""" + sample = appsink.emit("pull-sample") + if sample is None: + return Gst.FlowReturn.OK + + buffer = sample.get_buffer() + caps = sample.get_caps() + + # Extract video format information + struct = caps.get_structure(0) + width = struct.get_value("width") + height = struct.get_value("height") + + # Get the absolute timestamp from the buffer + # Matroska preserves the absolute timestamps we set in the sender + if buffer.pts != Gst.CLOCK_TIME_NONE: + # Convert nanoseconds to seconds and add offset + # This is the absolute time from when the frame was captured + timestamp = (buffer.pts / 1e9) + self.timestamp_offset + + # Skip frames with invalid timestamps (before year 2000) + # This filters out initial gray frames with relative timestamps + year_2000_timestamp = 946684800.0 # January 1, 2000 00:00:00 UTC + if timestamp < year_2000_timestamp: + logger.debug(f"Skipping frame with invalid timestamp: {timestamp:.6f}") + return Gst.FlowReturn.OK + + else: + return Gst.FlowReturn.OK + + # Map the buffer to access the data + success, map_info = buffer.map(Gst.MapFlags.READ) + if not success: + logger.error("Failed to map buffer") + return Gst.FlowReturn.ERROR + + try: + # Convert buffer data to numpy array + # The videoconvert element outputs BGR format + data = np.frombuffer(map_info.data, dtype=np.uint8) + + # Reshape to image dimensions + # For BGR format, we have 3 channels + image_array = data.reshape((height, width, 3)) + + # Create an Image message with the absolute timestamp + image_msg = Image( + data=image_array.copy(), # Make a copy to ensure data persistence + format=ImageFormat.BGR, + frame_id=self.frame_id, + ts=timestamp, + ) + + # Publish the image + if self.video and self.running: + self.video.publish(image_msg) + + # Log statistics periodically + self.frame_count += 1 + current_time = time.time() + if current_time - self.last_log_time >= 5.0: + fps = self.frame_count / (current_time - self.last_log_time) + logger.debug( + f"Receiving frames - FPS: {fps:.1f}, Resolution: {width}x{height}, " + f"Absolute timestamp: {timestamp:.6f}" + ) + self.frame_count = 0 + self.last_log_time = current_time + + except Exception as e: + logger.error(f"Error processing frame: {e}") + + finally: + buffer.unmap(map_info) + + return Gst.FlowReturn.OK diff --git a/dimos/hardware/gstreamer_camera_test_script.py b/dimos/hardware/gstreamer_camera_test_script.py new file mode 100755 index 0000000000..fbd2704cee --- /dev/null +++ b/dimos/hardware/gstreamer_camera_test_script.py @@ -0,0 +1,132 @@ +#!/usr/bin/env python3 + +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import time + +from dimos import core +from dimos.hardware.gstreamer_camera import GstreamerCameraModule +from dimos.msgs.sensor_msgs import Image +from dimos.protocol import pubsub + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Test script for GStreamer TCP camera module") + + # Network options + parser.add_argument( + "--host", default="localhost", help="TCP server host to connect to (default: localhost)" + ) + parser.add_argument("--port", type=int, default=5000, help="TCP server port (default: 5000)") + + # Camera options + parser.add_argument( + "--frame-id", + default="zed_camera", + help="Frame ID for published images (default: zed_camera)", + ) + parser.add_argument( + "--reconnect-interval", + type=float, + default=5.0, + help="Seconds to wait before attempting reconnection (default: 5.0)", + ) + + # Logging options + parser.add_argument("--verbose", action="store_true", help="Enable verbose logging") + + args = parser.parse_args() + + if args.verbose: + logging.getLogger().setLevel(logging.DEBUG) + + # Initialize LCM + pubsub.lcm.autoconf() # type: ignore[attr-defined] + + # Start dimos + logger.info("Starting dimos...") + dimos = core.start(8) + + # Deploy the GStreamer camera module + logger.info(f"Deploying GStreamer TCP camera module (connecting to {args.host}:{args.port})...") + camera = dimos.deploy( # type: ignore[attr-defined] + GstreamerCameraModule, + host=args.host, + port=args.port, + frame_id=args.frame_id, + reconnect_interval=args.reconnect_interval, + ) + + # Set up LCM transport for the video output + camera.video.transport = core.LCMTransport("/zed/video", Image) + + # Counter for received frames + frame_count = [0] + last_log_time = [time.time()] + first_timestamp = [None] + + def on_frame(msg) -> None: # type: ignore[no-untyped-def] + frame_count[0] += 1 + current_time = time.time() + + # Capture first timestamp to show absolute timestamps are preserved + if first_timestamp[0] is None: + first_timestamp[0] = msg.ts + logger.info(f"First frame absolute timestamp: {msg.ts:.6f}") + + # Log stats every 2 seconds + if current_time - last_log_time[0] >= 2.0: + fps = frame_count[0] / (current_time - last_log_time[0]) + timestamp_delta = msg.ts - first_timestamp[0] + logger.info( + f"Received {frame_count[0]} frames - FPS: {fps:.1f} - " + f"Resolution: {msg.width}x{msg.height} - " + f"Timestamp: {msg.ts:.3f} (delta: {timestamp_delta:.3f}s)" + ) + frame_count[0] = 0 + last_log_time[0] = current_time + + # Subscribe to video output for monitoring + camera.video.subscribe(on_frame) + + # Start the camera + logger.info("Starting GStreamer camera...") + camera.start() + + logger.info("GStreamer TCP camera module is running. Press Ctrl+C to stop.") + logger.info(f"Connecting to TCP server at {args.host}:{args.port}") + logger.info("Publishing frames to LCM topic: /zed/video") + logger.info("") + logger.info("To start the sender on the camera machine, run:") + logger.info( + f" python3 dimos/hardware/gstreamer_sender.py --device /dev/video0 --host 0.0.0.0 --port {args.port}" + ) + + try: + while True: + time.sleep(1) + except KeyboardInterrupt: + logger.info("Shutting down...") + camera.stop() + logger.info("Stopped.") + + +if __name__ == "__main__": + main() diff --git a/dimos/hardware/gstreamer_sender.py b/dimos/hardware/gstreamer_sender.py new file mode 100755 index 0000000000..93ed6ce4ec --- /dev/null +++ b/dimos/hardware/gstreamer_sender.py @@ -0,0 +1,359 @@ +#!/usr/bin/env python3 + +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import signal +import sys +import time + +# Add system path for gi module if needed +if "/usr/lib/python3/dist-packages" not in sys.path: + sys.path.insert(0, "/usr/lib/python3/dist-packages") + +import gi # type: ignore[import-not-found] + +gi.require_version("Gst", "1.0") +gi.require_version("GstVideo", "1.0") +from gi.repository import GLib, Gst # type: ignore[import-not-found] + +# Initialize GStreamer +Gst.init(None) + +# Setup logging +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger("gstreamer_tcp_sender") + + +class GStreamerTCPSender: + def __init__( + self, + device: str = "/dev/video0", + width: int = 2560, + height: int = 720, + framerate: int = 60, + format_str: str = "YUY2", + bitrate: int = 5000, + host: str = "0.0.0.0", + port: int = 5000, + single_camera: bool = False, + ) -> None: + """Initialize the GStreamer TCP sender. + + Args: + device: Video device path + width: Video width in pixels + height: Video height in pixels + framerate: Frame rate in fps + format_str: Video format + bitrate: H264 encoding bitrate in kbps + host: Host to listen on (0.0.0.0 for all interfaces) + port: TCP port for listening + single_camera: If True, crop to left half (for stereo cameras) + """ + self.device = device + self.width = width + self.height = height + self.framerate = framerate + self.format = format_str + self.bitrate = bitrate + self.host = host + self.port = port + self.single_camera = single_camera + + self.pipeline = None + self.videosrc = None + self.encoder = None + self.mux = None + self.main_loop = None + self.running = False + self.start_time = None + self.frame_count = 0 + + def create_pipeline(self): # type: ignore[no-untyped-def] + """Create the GStreamer pipeline with TCP server sink.""" + + # Create pipeline + self.pipeline = Gst.Pipeline.new("tcp-sender-pipeline") + + # Create elements + self.videosrc = Gst.ElementFactory.make("v4l2src", "source") + self.videosrc.set_property("device", self.device) # type: ignore[attr-defined] + self.videosrc.set_property("do-timestamp", True) # type: ignore[attr-defined] + logger.info(f"Using camera device: {self.device}") + + # Create caps filter for video format + capsfilter = Gst.ElementFactory.make("capsfilter", "capsfilter") + caps = Gst.Caps.from_string( + f"video/x-raw,width={self.width},height={self.height}," + f"format={self.format},framerate={self.framerate}/1" + ) + capsfilter.set_property("caps", caps) + + # Video converter + videoconvert = Gst.ElementFactory.make("videoconvert", "convert") + + # Crop element for single camera mode + videocrop = None + if self.single_camera: + videocrop = Gst.ElementFactory.make("videocrop", "crop") + # Crop to left half: for 2560x720 stereo, get left 1280x720 + videocrop.set_property("left", 0) + videocrop.set_property("right", self.width // 2) # Remove right half + videocrop.set_property("top", 0) + videocrop.set_property("bottom", 0) + + # H264 encoder + self.encoder = Gst.ElementFactory.make("x264enc", "encoder") + self.encoder.set_property("tune", "zerolatency") # type: ignore[attr-defined] + self.encoder.set_property("bitrate", self.bitrate) # type: ignore[attr-defined] + self.encoder.set_property("key-int-max", 30) # type: ignore[attr-defined] + + # H264 parser + h264parse = Gst.ElementFactory.make("h264parse", "parser") + + # Use matroskamux which preserves timestamps better + self.mux = Gst.ElementFactory.make("matroskamux", "mux") + self.mux.set_property("streamable", True) # type: ignore[attr-defined] + self.mux.set_property("writing-app", "gstreamer-tcp-sender") # type: ignore[attr-defined] + + # TCP server sink + tcpserversink = Gst.ElementFactory.make("tcpserversink", "sink") + tcpserversink.set_property("host", self.host) + tcpserversink.set_property("port", self.port) + tcpserversink.set_property("sync", False) + + # Add elements to pipeline + self.pipeline.add(self.videosrc) # type: ignore[attr-defined] + self.pipeline.add(capsfilter) # type: ignore[attr-defined] + self.pipeline.add(videoconvert) # type: ignore[attr-defined] + if videocrop: + self.pipeline.add(videocrop) # type: ignore[attr-defined] + self.pipeline.add(self.encoder) # type: ignore[attr-defined] + self.pipeline.add(h264parse) # type: ignore[attr-defined] + self.pipeline.add(self.mux) # type: ignore[attr-defined] + self.pipeline.add(tcpserversink) # type: ignore[attr-defined] + + # Link elements + if not self.videosrc.link(capsfilter): # type: ignore[attr-defined] + raise RuntimeError("Failed to link source to capsfilter") + if not capsfilter.link(videoconvert): + raise RuntimeError("Failed to link capsfilter to videoconvert") + + # Link through crop if in single camera mode + if videocrop: + if not videoconvert.link(videocrop): + raise RuntimeError("Failed to link videoconvert to videocrop") + if not videocrop.link(self.encoder): + raise RuntimeError("Failed to link videocrop to encoder") + else: + if not videoconvert.link(self.encoder): + raise RuntimeError("Failed to link videoconvert to encoder") + + if not self.encoder.link(h264parse): # type: ignore[attr-defined] + raise RuntimeError("Failed to link encoder to h264parse") + if not h264parse.link(self.mux): + raise RuntimeError("Failed to link h264parse to mux") + if not self.mux.link(tcpserversink): # type: ignore[attr-defined] + raise RuntimeError("Failed to link mux to tcpserversink") + + # Add probe to inject absolute timestamps + # Place probe after crop (if present) or after videoconvert + if videocrop: + probe_element = videocrop + else: + probe_element = videoconvert + probe_pad = probe_element.get_static_pad("src") + probe_pad.add_probe(Gst.PadProbeType.BUFFER, self._inject_absolute_timestamp, None) + + # Set up bus message handling + bus = self.pipeline.get_bus() # type: ignore[attr-defined] + bus.add_signal_watch() + bus.connect("message", self._on_bus_message) + + def _inject_absolute_timestamp(self, pad, info, user_data): # type: ignore[no-untyped-def] + buffer = info.get_buffer() + if buffer: + absolute_time = time.time() + absolute_time_ns = int(absolute_time * 1e9) + + # Set both PTS and DTS to the absolute time + # This will be preserved by matroskamux + buffer.pts = absolute_time_ns + buffer.dts = absolute_time_ns + + self.frame_count += 1 + return Gst.PadProbeReturn.OK + + def _on_bus_message(self, bus, message) -> None: # type: ignore[no-untyped-def] + t = message.type + + if t == Gst.MessageType.EOS: + logger.info("End of stream") + self.stop() + elif t == Gst.MessageType.ERROR: + err, debug = message.parse_error() + logger.error(f"Pipeline error: {err}, {debug}") + self.stop() + elif t == Gst.MessageType.WARNING: + warn, debug = message.parse_warning() + logger.warning(f"Pipeline warning: {warn}, {debug}") + elif t == Gst.MessageType.STATE_CHANGED: + if message.src == self.pipeline: + old_state, new_state, _pending_state = message.parse_state_changed() + logger.debug( + f"Pipeline state changed: {old_state.value_nick} -> {new_state.value_nick}" + ) + + def start(self): # type: ignore[no-untyped-def] + if self.running: + logger.warning("Sender is already running") + return + + logger.info("Creating TCP pipeline with absolute timestamps...") + self.create_pipeline() # type: ignore[no-untyped-call] + + logger.info("Starting pipeline...") + ret = self.pipeline.set_state(Gst.State.PLAYING) # type: ignore[attr-defined] + if ret == Gst.StateChangeReturn.FAILURE: + logger.error("Failed to start pipeline") + raise RuntimeError("Failed to start GStreamer pipeline") + + self.running = True + self.start_time = time.time() # type: ignore[assignment] + self.frame_count = 0 + + logger.info("TCP video sender started:") + logger.info(f" Source: {self.device}") + if self.single_camera: + output_width = self.width // 2 + logger.info(f" Input Resolution: {self.width}x{self.height} @ {self.framerate}fps") + logger.info( + f" Output Resolution: {output_width}x{self.height} @ {self.framerate}fps (left camera only)" + ) + else: + logger.info(f" Resolution: {self.width}x{self.height} @ {self.framerate}fps") + logger.info(f" Bitrate: {self.bitrate} kbps") + logger.info(f" TCP Server: {self.host}:{self.port}") + logger.info(" Container: Matroska (preserves absolute timestamps)") + logger.info(" Waiting for client connections...") + + self.main_loop = GLib.MainLoop() + try: + self.main_loop.run() # type: ignore[attr-defined] + except KeyboardInterrupt: + logger.info("Interrupted by user") + finally: + self.stop() + + def stop(self) -> None: + if not self.running: + return + + self.running = False + + if self.pipeline: + logger.info("Stopping pipeline...") + self.pipeline.set_state(Gst.State.NULL) + + if self.main_loop and self.main_loop.is_running(): + self.main_loop.quit() + + if self.frame_count > 0 and self.start_time: + elapsed = time.time() - self.start_time + avg_fps = self.frame_count / elapsed + logger.info(f"Total frames sent: {self.frame_count}, Average FPS: {avg_fps:.1f}") + + logger.info("TCP video sender stopped") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="GStreamer TCP video sender with absolute timestamps" + ) + + # Video source options + parser.add_argument( + "--device", default="/dev/video0", help="Video device path (default: /dev/video0)" + ) + + # Video format options + parser.add_argument("--width", type=int, default=2560, help="Video width (default: 2560)") + parser.add_argument("--height", type=int, default=720, help="Video height (default: 720)") + parser.add_argument("--framerate", type=int, default=15, help="Frame rate in fps (default: 15)") + parser.add_argument("--format", default="YUY2", help="Video format (default: YUY2)") + + # Encoding options + parser.add_argument( + "--bitrate", type=int, default=5000, help="H264 bitrate in kbps (default: 5000)" + ) + + # Network options + parser.add_argument( + "--host", + default="0.0.0.0", + help="Host to listen on (default: 0.0.0.0 for all interfaces)", + ) + parser.add_argument("--port", type=int, default=5000, help="TCP port (default: 5000)") + + # Camera options + parser.add_argument( + "--single-camera", + action="store_true", + help="Extract left camera only from stereo feed (crops 2560x720 to 1280x720)", + ) + + # Logging options + parser.add_argument("--verbose", action="store_true", help="Enable verbose logging") + + args = parser.parse_args() + + if args.verbose: + logging.getLogger().setLevel(logging.DEBUG) + + # Create and start sender + sender = GStreamerTCPSender( + device=args.device, + width=args.width, + height=args.height, + framerate=args.framerate, + format_str=args.format, + bitrate=args.bitrate, + host=args.host, + port=args.port, + single_camera=args.single_camera, + ) + + # Handle signals gracefully + def signal_handler(sig, frame) -> None: # type: ignore[no-untyped-def] + logger.info(f"Received signal {sig}, shutting down...") + sender.stop() + sys.exit(0) + + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + try: + sender.start() # type: ignore[no-untyped-call] + except Exception as e: + logger.error(f"Failed to start sender: {e}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/dimos/hardware/interface.py b/dimos/hardware/interface.py deleted file mode 100644 index 0ff9bb8d51..0000000000 --- a/dimos/hardware/interface.py +++ /dev/null @@ -1,31 +0,0 @@ -from dimos.hardware.end_effector import EndEffector -from dimos.hardware.camera import Camera -from dimos.hardware.stereo_camera import StereoCamera -from dimos.hardware.ufactory import UFactoryEndEffector, UFactory7DOFArm - -class HardwareInterface: - def __init__(self, end_effector: EndEffector = None, sensors: list = None, arm_architecture: UFactory7DOFArm = None): - self.end_effector = end_effector - self.sensors = sensors if sensors is not None else [] - self.arm_architecture = arm_architecture - - def get_configuration(self): - """Return the current hardware configuration.""" - return { - 'end_effector': self.end_effector, - 'sensors': [sensor.get_sensor_type() for sensor in self.sensors], - 'arm_architecture': self.arm_architecture - } - - def set_configuration(self, configuration): - """Set the hardware configuration.""" - self.end_effector = configuration.get('end_effector', self.end_effector) - self.sensors = configuration.get('sensors', self.sensors) - self.arm_architecture = configuration.get('arm_architecture', self.arm_architecture) - - def add_sensor(self, sensor): - """Add a sensor to the hardware interface.""" - if isinstance(sensor, (Camera, StereoCamera)): - self.sensors.append(sensor) - else: - raise ValueError("Sensor must be a Camera or StereoCamera instance.") diff --git a/dimos/hardware/piper_arm.py b/dimos/hardware/piper_arm.py new file mode 100644 index 0000000000..c043407e11 --- /dev/null +++ b/dimos/hardware/piper_arm.py @@ -0,0 +1,525 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# dimos/hardware/piper_arm.py + +import select +import sys +import termios +import threading +import time +import tty + +from dimos_lcm.geometry_msgs import Pose, Twist, Vector3 # type: ignore[import-untyped] +import kinpy as kp # type: ignore[import-not-found] +import numpy as np +from piper_sdk import * # type: ignore[import-not-found] # from the official Piper SDK +import pytest +from reactivex.disposable import Disposable +from scipy.spatial.transform import Rotation as R + +import dimos.core as core +from dimos.core import In, Module, rpc +import dimos.protocol.service.lcmservice as lcmservice +from dimos.utils.logging_config import setup_logger +from dimos.utils.transform_utils import euler_to_quaternion, quaternion_to_euler + +logger = setup_logger() + + +class PiperArm: + def __init__(self, arm_name: str = "arm") -> None: + self.arm = C_PiperInterface_V2() # type: ignore[name-defined] # noqa: F405 + self.arm.ConnectPort() + self.resetArm() + time.sleep(0.5) + self.resetArm() + time.sleep(0.5) + self.enable() + self.enable_gripper() # Enable gripper after arm is enabled + self.gotoZero() + time.sleep(1) + self.init_vel_controller() + + def enable(self) -> None: + while not self.arm.EnablePiper(): + pass + time.sleep(0.01) + logger.info("Arm enabled") + # self.arm.ModeCtrl( + # ctrl_mode=0x01, # CAN command mode + # move_mode=0x01, # “Move-J”, but ignored in MIT + # move_spd_rate_ctrl=100, # doesn’t matter in MIT + # is_mit_mode=0xAD # <-- the magic flag + # ) + self.arm.MotionCtrl_2(0x01, 0x01, 80, 0xAD) + + def gotoZero(self) -> None: + factor = 1000 + position = [57.0, 0.0, 215.0, 0, 90.0, 0, 0] + X = round(position[0] * factor) + Y = round(position[1] * factor) + Z = round(position[2] * factor) + RX = round(position[3] * factor) + RY = round(position[4] * factor) + RZ = round(position[5] * factor) + round(position[6] * factor) + logger.debug(f"Going to zero position: X={X}, Y={Y}, Z={Z}, RX={RX}, RY={RY}, RZ={RZ}") + self.arm.MotionCtrl_2(0x01, 0x00, 100, 0x00) + self.arm.EndPoseCtrl(X, Y, Z, RX, RY, RZ) + self.arm.GripperCtrl(0, 1000, 0x01, 0) + + def gotoObserve(self) -> None: + factor = 1000 + position = [57.0, 0.0, 280.0, 0, 120.0, 0, 0] + X = round(position[0] * factor) + Y = round(position[1] * factor) + Z = round(position[2] * factor) + RX = round(position[3] * factor) + RY = round(position[4] * factor) + RZ = round(position[5] * factor) + round(position[6] * factor) + logger.debug(f"Going to zero position: X={X}, Y={Y}, Z={Z}, RX={RX}, RY={RY}, RZ={RZ}") + self.arm.MotionCtrl_2(0x01, 0x00, 100, 0x00) + self.arm.EndPoseCtrl(X, Y, Z, RX, RY, RZ) + + def softStop(self) -> None: + self.gotoZero() + time.sleep(1) + self.arm.MotionCtrl_2( + 0x01, + 0x00, + 100, + ) + self.arm.MotionCtrl_1(0x01, 0, 0) + time.sleep(3) + + def cmd_ee_pose_values(self, x, y, z, r, p, y_, line_mode: bool = False) -> None: # type: ignore[no-untyped-def] + """Command end-effector to target pose in space (position + Euler angles)""" + factor = 1000 + pose = [ + x * factor * factor, + y * factor * factor, + z * factor * factor, + r * factor, + p * factor, + y_ * factor, + ] + self.arm.MotionCtrl_2(0x01, 0x02 if line_mode else 0x00, 100, 0x00) + self.arm.EndPoseCtrl( + int(pose[0]), int(pose[1]), int(pose[2]), int(pose[3]), int(pose[4]), int(pose[5]) + ) + + def cmd_ee_pose(self, pose: Pose, line_mode: bool = False) -> None: + """Command end-effector to target pose using Pose message""" + # Convert quaternion to euler angles + euler = quaternion_to_euler(pose.orientation, degrees=True) + + # Command the pose + self.cmd_ee_pose_values( + pose.position.x, + pose.position.y, + pose.position.z, + euler.x, + euler.y, + euler.z, + line_mode, + ) + + def get_ee_pose(self): # type: ignore[no-untyped-def] + """Return the current end-effector pose as Pose message with position in meters and quaternion orientation""" + pose = self.arm.GetArmEndPoseMsgs() + factor = 1000.0 + # Extract individual pose values and convert to base units + # Position values are divided by 1000 to convert from SDK units to meters + # Rotation values are divided by 1000 to convert from SDK units to radians + x = pose.end_pose.X_axis / factor / factor # Convert mm to m + y = pose.end_pose.Y_axis / factor / factor # Convert mm to m + z = pose.end_pose.Z_axis / factor / factor # Convert mm to m + rx = pose.end_pose.RX_axis / factor + ry = pose.end_pose.RY_axis / factor + rz = pose.end_pose.RZ_axis / factor + + # Create position vector (already in meters) + position = Vector3(x, y, z) + + orientation = euler_to_quaternion(Vector3(rx, ry, rz), degrees=True) + + return Pose(position, orientation) + + def cmd_gripper_ctrl(self, position, effort: float = 0.25) -> None: # type: ignore[no-untyped-def] + """Command end-effector gripper""" + factor = 1000 + position = position * factor * factor # meters + effort = effort * factor # N/m + + self.arm.GripperCtrl(abs(round(position)), abs(round(effort)), 0x01, 0) + logger.debug(f"Commanding gripper position: {position}mm") + + def enable_gripper(self) -> None: + """Enable the gripper using the initialization sequence""" + logger.info("Enabling gripper...") + while not self.arm.EnablePiper(): + time.sleep(0.01) + self.arm.GripperCtrl(0, 1000, 0x02, 0) + self.arm.GripperCtrl(0, 1000, 0x01, 0) + logger.info("Gripper enabled") + + def release_gripper(self) -> None: + """Release gripper by opening to 100mm (10cm)""" + logger.info("Releasing gripper (opening to 100mm)") + self.cmd_gripper_ctrl(0.1) # 0.1m = 100mm = 10cm + + def get_gripper_feedback(self) -> tuple[float, float]: + """ + Get current gripper feedback. + + Returns: + Tuple of (angle_degrees, effort) where: + - angle_degrees: Current gripper angle in degrees + - effort: Current gripper effort (0.0 to 1.0 range) + """ + gripper_msg = self.arm.GetArmGripperMsgs() + angle_degrees = ( + gripper_msg.gripper_state.grippers_angle / 1000.0 + ) # Convert from SDK units to degrees + effort = gripper_msg.gripper_state.grippers_effort / 1000.0 # Convert from SDK units to N/m + return angle_degrees, effort + + def close_gripper(self, commanded_effort: float = 0.5) -> None: + """ + Close the gripper. + + Args: + commanded_effort: Effort to use when closing gripper (default 0.25 N/m) + """ + # Command gripper to close (0.0 position) + self.cmd_gripper_ctrl(0.0, effort=commanded_effort) + logger.info("Closing gripper") + + def gripper_object_detected(self, commanded_effort: float = 0.25) -> bool: + """ + Check if an object is detected in the gripper based on effort feedback. + + Args: + commanded_effort: The effort that was used when closing gripper (default 0.25 N/m) + + Returns: + True if object is detected in gripper, False otherwise + """ + # Get gripper feedback + _angle_degrees, actual_effort = self.get_gripper_feedback() + + # Check if object is grasped (effort > 80% of commanded effort) + effort_threshold = 0.8 * commanded_effort + object_present = abs(actual_effort) > effort_threshold + + if object_present: + logger.info(f"Object detected in gripper (effort: {actual_effort:.3f} N/m)") + else: + logger.info(f"No object detected (effort: {actual_effort:.3f} N/m)") + + return object_present + + def resetArm(self) -> None: + self.arm.MotionCtrl_1(0x02, 0, 0) + self.arm.MotionCtrl_2(0, 0, 0, 0x00) + logger.info("Resetting arm") + + def init_vel_controller(self) -> None: + self.chain = kp.build_serial_chain_from_urdf( + open("dimos/hardware/piper_description.urdf"), "gripper_base" + ) + self.J = self.chain.jacobian(np.zeros(6)) + self.J_pinv = np.linalg.pinv(self.J) + self.dt = 0.01 + + def cmd_vel(self, x_dot, y_dot, z_dot, R_dot, P_dot, Y_dot) -> None: # type: ignore[no-untyped-def] + joint_state = self.arm.GetArmJointMsgs().joint_state + # print(f"[PiperArm] Current Joints (direct): {joint_state}", type(joint_state)) + joint_angles = np.array( + [ + joint_state.joint_1, + joint_state.joint_2, + joint_state.joint_3, + joint_state.joint_4, + joint_state.joint_5, + joint_state.joint_6, + ] + ) + # print(f"[PiperArm] Current Joints: {joint_angles}", type(joint_angles)) + factor = 57295.7795 # 1000*180/3.1415926 + joint_angles = joint_angles / factor # convert to radians + + q = np.array( + [ + joint_angles[0], + joint_angles[1], + joint_angles[2], + joint_angles[3], + joint_angles[4], + joint_angles[5], + ] + ) + J = self.chain.jacobian(q) + self.J_pinv = np.linalg.pinv(J) + dq = self.J_pinv @ np.array([x_dot, y_dot, z_dot, R_dot, P_dot, Y_dot]) * self.dt + newq = q + dq + + newq = newq * factor + + self.arm.MotionCtrl_2(0x01, 0x01, 100, 0xAD) + self.arm.JointCtrl( + round(newq[0]), + round(newq[1]), + round(newq[2]), + round(newq[3]), + round(newq[4]), + round(newq[5]), + ) + time.sleep(self.dt) + # print(f"[PiperArm] Moving to Joints to : {newq}") + + def cmd_vel_ee(self, x_dot, y_dot, z_dot, RX_dot, PY_dot, YZ_dot) -> None: # type: ignore[no-untyped-def] + factor = 1000 + x_dot = x_dot * factor + y_dot = y_dot * factor + z_dot = z_dot * factor + RX_dot = RX_dot * factor + PY_dot = PY_dot * factor + YZ_dot = YZ_dot * factor + + current_pose_msg = self.get_ee_pose() # type: ignore[no-untyped-call] + + # Convert quaternion to euler angles + quat = [ + current_pose_msg.orientation.x, + current_pose_msg.orientation.y, + current_pose_msg.orientation.z, + current_pose_msg.orientation.w, + ] + rotation = R.from_quat(quat) + euler = rotation.as_euler("xyz") # Returns [rx, ry, rz] in radians + + # Create current pose array [x, y, z, rx, ry, rz] + current_pose = np.array( + [ + current_pose_msg.position.x, + current_pose_msg.position.y, + current_pose_msg.position.z, + euler[0], + euler[1], + euler[2], + ] + ) + + # Apply velocity increment + current_pose = ( + current_pose + np.array([x_dot, y_dot, z_dot, RX_dot, PY_dot, YZ_dot]) * self.dt + ) + + self.cmd_ee_pose_values( + current_pose[0], + current_pose[1], + current_pose[2], + current_pose[3], + current_pose[4], + current_pose[5], + ) + time.sleep(self.dt) + + def disable(self) -> None: + self.softStop() + + while self.arm.DisablePiper(): + pass + time.sleep(0.01) + self.arm.DisconnectPort() + + +class VelocityController(Module): + cmd_vel: In[Twist] = None # type: ignore[assignment] + + def __init__(self, arm, period: float = 0.01, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + super().__init__(*args, **kwargs) + self.arm = arm + self.period = period + self.latest_cmd = None + self.last_cmd_time = None + self._thread = None + + @rpc + def start(self) -> None: + super().start() + + unsub = self.cmd_vel.subscribe(self.handle_cmd_vel) + self._disposables.add(Disposable(unsub)) + + def control_loop() -> None: + while True: + # Check for timeout (1 second) + if self.last_cmd_time and (time.time() - self.last_cmd_time) > 1.0: + logger.warning( + "No velocity command received for 1 second, stopping control loop" + ) + break + + cmd_vel = self.latest_cmd + + joint_state = self.arm.GetArmJointMsgs().joint_state + # print(f"[PiperArm] Current Joints (direct): {joint_state}", type(joint_state)) + joint_angles = np.array( + [ + joint_state.joint_1, + joint_state.joint_2, + joint_state.joint_3, + joint_state.joint_4, + joint_state.joint_5, + joint_state.joint_6, + ] + ) + factor = 57295.7795 # 1000*180/3.1415926 + joint_angles = joint_angles / factor # convert to radians + q = np.array( + [ + joint_angles[0], + joint_angles[1], + joint_angles[2], + joint_angles[3], + joint_angles[4], + joint_angles[5], + ] + ) + + J = self.chain.jacobian(q) # type: ignore[attr-defined] + self.J_pinv = np.linalg.pinv(J) + dq = ( + self.J_pinv + @ np.array( + [ + cmd_vel.linear.X, # type: ignore[attr-defined] + cmd_vel.linear.y, # type: ignore[attr-defined] + cmd_vel.linear.z, # type: ignore[attr-defined] + cmd_vel.angular.x, # type: ignore[attr-defined] + cmd_vel.angular.y, # type: ignore[attr-defined] + cmd_vel.angular.z, # type: ignore[attr-defined] + ] + ) + * self.dt # type: ignore[attr-defined] + ) + newq = q + dq + + newq = newq * factor # convert radians to scaled degree units for joint control + + self.arm.MotionCtrl_2(0x01, 0x01, 100, 0xAD) + self.arm.JointCtrl( + round(newq[0]), + round(newq[1]), + round(newq[2]), + round(newq[3]), + round(newq[4]), + round(newq[5]), + ) + time.sleep(self.period) + + self._thread = threading.Thread(target=control_loop, daemon=True) # type: ignore[assignment] + self._thread.start() # type: ignore[attr-defined] + + @rpc + def stop(self) -> None: + if self._thread: + # TODO: trigger the thread to stop + self._thread.join(2) + super().stop() + + def handle_cmd_vel(self, cmd_vel: Twist) -> None: + self.latest_cmd = cmd_vel + self.last_cmd_time = time.time() # type: ignore[assignment] + + +@pytest.mark.tool +def run_velocity_controller() -> None: + lcmservice.autoconf() + dimos = core.start(2) + + velocity_controller = dimos.deploy(VelocityController, arm=arm, period=0.01) # type: ignore[attr-defined] + velocity_controller.cmd_vel.transport = core.LCMTransport("/cmd_vel", Twist) + + velocity_controller.start() + + logger.info("Velocity controller started") + while True: + time.sleep(1) + + # velocity_controller.stop() + + +if __name__ == "__main__": + arm = PiperArm() + + def get_key(timeout: float = 0.1): # type: ignore[no-untyped-def] + """Non-blocking key reader for arrow keys.""" + fd = sys.stdin.fileno() + old_settings = termios.tcgetattr(fd) + try: + tty.setraw(fd) + rlist, _, _ = select.select([fd], [], [], timeout) + if rlist: + ch1 = sys.stdin.read(1) + if ch1 == "\x1b": # Arrow keys start with ESC + ch2 = sys.stdin.read(1) + if ch2 == "[": + ch3 = sys.stdin.read(1) + return ch1 + ch2 + ch3 + else: + return ch1 + return None + finally: + termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) + + def teleop_linear_vel(arm) -> None: # type: ignore[no-untyped-def] + print("Use arrow keys to control linear velocity (x/y/z). Press 'q' to quit.") + print("Up/Down: +x/-x, Left/Right: +y/-y, 'w'/'s': +z/-z") + x_dot, y_dot, z_dot = 0.0, 0.0, 0.0 + while True: + key = get_key(timeout=0.1) + if key == "\x1b[A": # Up arrow + x_dot += 0.01 + elif key == "\x1b[B": # Down arrow + x_dot -= 0.01 + elif key == "\x1b[C": # Right arrow + y_dot += 0.01 + elif key == "\x1b[D": # Left arrow + y_dot -= 0.01 + elif key == "w": + z_dot += 0.01 + elif key == "s": + z_dot -= 0.01 + elif key == "q": + logger.info("Exiting teleop") + arm.disable() + break + + # Optionally, clamp velocities to reasonable limits + x_dot = max(min(x_dot, 0.5), -0.5) + y_dot = max(min(y_dot, 0.5), -0.5) + z_dot = max(min(z_dot, 0.5), -0.5) + + # Only linear velocities, angular set to zero + arm.cmd_vel_ee(x_dot, y_dot, z_dot, 0, 0, 0) + logger.debug( + f"Current linear velocity: x={x_dot:.3f} m/s, y={y_dot:.3f} m/s, z={z_dot:.3f} m/s" + ) + + run_velocity_controller() diff --git a/dimos/hardware/piper_description.urdf b/dimos/hardware/piper_description.urdf new file mode 100755 index 0000000000..c8a5a11ded --- /dev/null +++ b/dimos/hardware/piper_description.urdf @@ -0,0 +1,497 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/dimos/hardware/sensor.py b/dimos/hardware/sensor.py index f4c3e68006..dc86d93e56 100644 --- a/dimos/hardware/sensor.py +++ b/dimos/hardware/sensor.py @@ -1,20 +1,35 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from abc import ABC, abstractmethod + class AbstractSensor(ABC): - def __init__(self, sensor_type=None): + def __init__(self, sensor_type=None) -> None: # type: ignore[no-untyped-def] self.sensor_type = sensor_type @abstractmethod - def get_sensor_type(self): + def get_sensor_type(self): # type: ignore[no-untyped-def] """Return the type of sensor.""" pass @abstractmethod - def calculate_intrinsics(self): + def calculate_intrinsics(self): # type: ignore[no-untyped-def] """Calculate the sensor's intrinsics.""" pass @abstractmethod - def get_intrinsics(self): + def get_intrinsics(self): # type: ignore[no-untyped-def] """Return the sensor's intrinsics.""" pass diff --git a/dimos/hardware/stereo_camera.py b/dimos/hardware/stereo_camera.py deleted file mode 100644 index a8bb5c3d92..0000000000 --- a/dimos/hardware/stereo_camera.py +++ /dev/null @@ -1,11 +0,0 @@ -from dimos.hardware.camera import Camera - -class StereoCamera(Camera): - def __init__(self, baseline=None, **kwargs): - super().__init__(**kwargs) - self.baseline = baseline - - def get_intrinsics(self): - intrinsics = super().get_intrinsics() - intrinsics['baseline'] = self.baseline - return intrinsics diff --git a/dimos/hardware/ufactory.py b/dimos/hardware/ufactory.py deleted file mode 100644 index 11459526a0..0000000000 --- a/dimos/hardware/ufactory.py +++ /dev/null @@ -1,16 +0,0 @@ -from dimos.hardware.end_effector import EndEffector - -class UFactoryEndEffector(EndEffector): - def __init__(self, model=None, **kwargs): - super().__init__(**kwargs) - self.model = model - - def get_model(self): - return self.model - -class UFactory7DOFArm: - def __init__(self, arm_length=None): - self.arm_length = arm_length - - def get_arm_length(self): - return self.arm_length diff --git a/dimos/data/diffusion.py b/dimos/manipulation/__init__.py similarity index 100% rename from dimos/data/diffusion.py rename to dimos/manipulation/__init__.py diff --git a/dimos/manipulation/manip_aio_pipeline.py b/dimos/manipulation/manip_aio_pipeline.py new file mode 100644 index 0000000000..c31b0dc335 --- /dev/null +++ b/dimos/manipulation/manip_aio_pipeline.py @@ -0,0 +1,592 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Asynchronous, reactive manipulation pipeline for realtime detection, filtering, and grasp generation. +""" + +import asyncio +import json +import threading +import time + +import cv2 +import numpy as np +import reactivex as rx +import reactivex.operators as ops +import websockets + +from dimos.perception.common.utils import colorize_depth +from dimos.perception.detection2d.detic_2d_det import ( # type: ignore[import-untyped] + Detic2DDetector, +) +from dimos.perception.grasp_generation.utils import draw_grasps_on_image +from dimos.perception.object_detection_stream import ObjectDetectionStream +from dimos.perception.pointcloud.pointcloud_filtering import PointcloudFiltering +from dimos.perception.pointcloud.utils import create_point_cloud_overlay_visualization +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class ManipulationPipeline: + """ + Clean separated stream pipeline with frame buffering. + + - Object detection runs independently on RGB stream + - Point cloud processing subscribes to both detection and ZED streams separately + - Simple frame buffering to match RGB+depth+objects + """ + + def __init__( + self, + camera_intrinsics: list[float], # [fx, fy, cx, cy] + min_confidence: float = 0.6, + max_objects: int = 10, + vocabulary: str | None = None, + grasp_server_url: str | None = None, + enable_grasp_generation: bool = False, + ) -> None: + """ + Initialize the manipulation pipeline. + + Args: + camera_intrinsics: [fx, fy, cx, cy] camera parameters + min_confidence: Minimum detection confidence threshold + max_objects: Maximum number of objects to process + vocabulary: Optional vocabulary for Detic detector + grasp_server_url: Optional WebSocket URL for Dimensional Grasp server + enable_grasp_generation: Whether to enable async grasp generation + """ + self.camera_intrinsics = camera_intrinsics + self.min_confidence = min_confidence + + # Grasp generation settings + self.grasp_server_url = grasp_server_url + self.enable_grasp_generation = enable_grasp_generation + + # Asyncio event loop for WebSocket communication + self.grasp_loop = None + self.grasp_loop_thread = None + + # Storage for grasp results and filtered objects + self.latest_grasps: list[dict] = [] # type: ignore[type-arg] # Simplified: just a list of grasps + self.grasps_consumed = False + self.latest_filtered_objects = [] # type: ignore[var-annotated] + self.latest_rgb_for_grasps = None # Store RGB image for grasp overlay + self.grasp_lock = threading.Lock() + + # Track pending requests - simplified to single task + self.grasp_task: asyncio.Task | None = None # type: ignore[type-arg] + + # Reactive subjects for streaming filtered objects and grasps + self.filtered_objects_subject = rx.subject.Subject() # type: ignore[var-annotated] + self.grasps_subject = rx.subject.Subject() # type: ignore[var-annotated] + self.grasp_overlay_subject = rx.subject.Subject() # type: ignore[var-annotated] # Add grasp overlay subject + + # Initialize grasp client if enabled + if self.enable_grasp_generation and self.grasp_server_url: + self._start_grasp_loop() + + # Initialize object detector + self.detector = Detic2DDetector(vocabulary=vocabulary, threshold=min_confidence) + + # Initialize point cloud processor + self.pointcloud_filter = PointcloudFiltering( + color_intrinsics=camera_intrinsics, + depth_intrinsics=camera_intrinsics, # ZED uses same intrinsics + max_num_objects=max_objects, + ) + + logger.info(f"Initialized ManipulationPipeline with confidence={min_confidence}") + + def create_streams(self, zed_stream: rx.Observable) -> dict[str, rx.Observable]: # type: ignore[type-arg] + """ + Create streams using exact old main logic. + """ + # Create ZED streams (from old main) + zed_frame_stream = zed_stream.pipe(ops.share()) + + # RGB stream for object detection (from old main) + video_stream = zed_frame_stream.pipe( + ops.map(lambda x: x.get("rgb") if x is not None else None), # type: ignore[attr-defined] + ops.filter(lambda x: x is not None), + ops.share(), + ) + object_detector = ObjectDetectionStream( + camera_intrinsics=self.camera_intrinsics, + min_confidence=self.min_confidence, + class_filter=None, + detector=self.detector, + video_stream=video_stream, + disable_depth=True, + ) + + # Store latest frames for point cloud processing (from old main) + latest_rgb = None + latest_depth = None + latest_point_cloud_overlay = None + frame_lock = threading.Lock() + + # Subscribe to combined ZED frames (from old main) + def on_zed_frame(zed_data) -> None: # type: ignore[no-untyped-def] + nonlocal latest_rgb, latest_depth + if zed_data is not None: + with frame_lock: + latest_rgb = zed_data.get("rgb") + latest_depth = zed_data.get("depth") + + # Depth stream for point cloud filtering (from old main) + def get_depth_or_overlay(zed_data): # type: ignore[no-untyped-def] + if zed_data is None: + return None + + # Check if we have a point cloud overlay available + with frame_lock: + overlay = latest_point_cloud_overlay + + if overlay is not None: + return overlay + else: + # Return regular colorized depth + return colorize_depth(zed_data.get("depth"), max_depth=10.0) + + depth_stream = zed_frame_stream.pipe( + ops.map(get_depth_or_overlay), ops.filter(lambda x: x is not None), ops.share() + ) + + # Process object detection results with point cloud filtering (from old main) + def on_detection_next(result) -> None: # type: ignore[no-untyped-def] + nonlocal latest_point_cloud_overlay + if result.get("objects"): + # Get latest RGB and depth frames + with frame_lock: + rgb = latest_rgb + depth = latest_depth + + if rgb is not None and depth is not None: + try: + filtered_objects = self.pointcloud_filter.process_images( + rgb, depth, result["objects"] + ) + + if filtered_objects: + # Store filtered objects + with self.grasp_lock: + self.latest_filtered_objects = filtered_objects + self.filtered_objects_subject.on_next(filtered_objects) + + # Create base image (colorized depth) + base_image = colorize_depth(depth, max_depth=10.0) + + # Create point cloud overlay visualization + overlay_viz = create_point_cloud_overlay_visualization( + base_image=base_image, # type: ignore[arg-type] + objects=filtered_objects, # type: ignore[arg-type] + intrinsics=self.camera_intrinsics, # type: ignore[arg-type] + ) + + # Store the overlay for the stream + with frame_lock: + latest_point_cloud_overlay = overlay_viz + + # Request grasps if enabled + if self.enable_grasp_generation and len(filtered_objects) > 0: + # Save RGB image for later grasp overlay + with frame_lock: + self.latest_rgb_for_grasps = rgb.copy() + + task = self.request_scene_grasps(filtered_objects) # type: ignore[arg-type] + if task: + # Check for results after a delay + def check_grasps_later() -> None: + time.sleep(2.0) # Wait for grasp processing + # Wait for task to complete + if hasattr(self, "grasp_task") and self.grasp_task: + try: + self.grasp_task.result( # type: ignore[call-arg] + timeout=3.0 + ) # Get result with timeout + except Exception as e: + logger.warning(f"Grasp task failed or timeout: {e}") + + # Try to get latest grasps and create overlay + with self.grasp_lock: + grasps = self.latest_grasps + + if grasps and hasattr(self, "latest_rgb_for_grasps"): + # Create grasp overlay on the saved RGB image + try: + bgr_image = cv2.cvtColor( # type: ignore[call-overload] + self.latest_rgb_for_grasps, cv2.COLOR_RGB2BGR + ) + result_bgr = draw_grasps_on_image( + bgr_image, + grasps, + self.camera_intrinsics, + max_grasps=-1, # Show all grasps + ) + result_rgb = cv2.cvtColor( + result_bgr, cv2.COLOR_BGR2RGB + ) + + # Emit grasp overlay immediately + self.grasp_overlay_subject.on_next(result_rgb) + + except Exception as e: + logger.error(f"Error creating grasp overlay: {e}") + + # Emit grasps to stream + self.grasps_subject.on_next(grasps) + + threading.Thread(target=check_grasps_later, daemon=True).start() + else: + logger.warning("Failed to create grasp task") + except Exception as e: + logger.error(f"Error in point cloud filtering: {e}") + with frame_lock: + latest_point_cloud_overlay = None + + def on_error(error) -> None: # type: ignore[no-untyped-def] + logger.error(f"Error in stream: {error}") + + def on_completed() -> None: + logger.info("Stream completed") + + def start_subscriptions() -> None: + """Start subscriptions in background thread (from old main)""" + # Subscribe to combined ZED frames + zed_frame_stream.subscribe(on_next=on_zed_frame) + + # Start subscriptions in background thread (from old main) + subscription_thread = threading.Thread(target=start_subscriptions, daemon=True) + subscription_thread.start() + time.sleep(2) # Give subscriptions time to start + + # Subscribe to object detection stream (from old main) + object_detector.get_stream().subscribe( # type: ignore[no-untyped-call] + on_next=on_detection_next, on_error=on_error, on_completed=on_completed + ) + + # Create visualization stream for web interface (from old main) + viz_stream = object_detector.get_stream().pipe( # type: ignore[no-untyped-call] + ops.map(lambda x: x["viz_frame"] if x is not None else None), # type: ignore[index] + ops.filter(lambda x: x is not None), + ) + + # Create filtered objects stream + filtered_objects_stream = self.filtered_objects_subject + + # Create grasps stream + grasps_stream = self.grasps_subject + + # Create grasp overlay subject for immediate emission + grasp_overlay_stream = self.grasp_overlay_subject + + return { + "detection_viz": viz_stream, + "pointcloud_viz": depth_stream, + "objects": object_detector.get_stream().pipe(ops.map(lambda x: x.get("objects", []))), # type: ignore[attr-defined, no-untyped-call] + "filtered_objects": filtered_objects_stream, + "grasps": grasps_stream, + "grasp_overlay": grasp_overlay_stream, + } + + def _start_grasp_loop(self) -> None: + """Start asyncio event loop in a background thread for WebSocket communication.""" + + def run_loop() -> None: + self.grasp_loop = asyncio.new_event_loop() # type: ignore[assignment] + asyncio.set_event_loop(self.grasp_loop) + self.grasp_loop.run_forever() # type: ignore[attr-defined] + + self.grasp_loop_thread = threading.Thread(target=run_loop, daemon=True) # type: ignore[assignment] + self.grasp_loop_thread.start() # type: ignore[attr-defined] + + # Wait for loop to start + while self.grasp_loop is None: + time.sleep(0.01) + + async def _send_grasp_request( + self, + points: np.ndarray, # type: ignore[type-arg] + colors: np.ndarray | None, # type: ignore[type-arg] + ) -> list[dict] | None: # type: ignore[type-arg] + """Send grasp request to Dimensional Grasp server.""" + try: + # Comprehensive client-side validation to prevent server errors + + # Validate points array + if points is None: + logger.error("Points array is None") + return None + if not isinstance(points, np.ndarray): + logger.error(f"Points is not numpy array: {type(points)}") + return None + if points.size == 0: + logger.error("Points array is empty") + return None + if len(points.shape) != 2 or points.shape[1] != 3: + logger.error(f"Points has invalid shape {points.shape}, expected (N, 3)") + return None + if points.shape[0] < 100: # Minimum points for stable grasp detection + logger.error(f"Insufficient points for grasp detection: {points.shape[0]} < 100") + return None + + # Validate and prepare colors + if colors is not None: + if not isinstance(colors, np.ndarray): + colors = None + elif colors.size == 0: + colors = None + elif len(colors.shape) != 2 or colors.shape[1] != 3: + colors = None + elif colors.shape[0] != points.shape[0]: + colors = None + + # If no valid colors, create default colors (required by server) + if colors is None: + # Create default white colors for all points + colors = np.ones((points.shape[0], 3), dtype=np.float32) * 0.5 + + # Ensure data types are correct (server expects float32) + points = points.astype(np.float32) + colors = colors.astype(np.float32) + + # Validate ranges (basic sanity checks) + if np.any(np.isnan(points)) or np.any(np.isinf(points)): + logger.error("Points contain NaN or Inf values") + return None + if np.any(np.isnan(colors)) or np.any(np.isinf(colors)): + logger.error("Colors contain NaN or Inf values") + return None + + # Clamp color values to valid range [0, 1] + colors = np.clip(colors, 0.0, 1.0) + + async with websockets.connect(self.grasp_server_url) as websocket: # type: ignore[arg-type] + request = { + "points": points.tolist(), + "colors": colors.tolist(), # Always send colors array + "lims": [-0.19, 0.12, 0.02, 0.15, 0.0, 1.0], # Default workspace limits + } + + await websocket.send(json.dumps(request)) + + response = await websocket.recv() + grasps = json.loads(response) + + # Handle server response validation + if isinstance(grasps, dict) and "error" in grasps: + logger.error(f"Server returned error: {grasps['error']}") + return None + elif isinstance(grasps, int | float) and grasps == 0: + return None + elif not isinstance(grasps, list): + logger.error( + f"Server returned unexpected response type: {type(grasps)}, value: {grasps}" + ) + return None + elif len(grasps) == 0: + return None + + converted_grasps = self._convert_grasp_format(grasps) + with self.grasp_lock: + self.latest_grasps = converted_grasps + self.grasps_consumed = False # Reset consumed flag + + # Emit to reactive stream + self.grasps_subject.on_next(self.latest_grasps) + + return converted_grasps + except websockets.exceptions.ConnectionClosed as e: + logger.error(f"WebSocket connection closed: {e}") + except websockets.exceptions.WebSocketException as e: + logger.error(f"WebSocket error: {e}") + except json.JSONDecodeError as e: + logger.error(f"Failed to parse server response as JSON: {e}") + except Exception as e: + logger.error(f"Error requesting grasps: {e}") + + return None + + def request_scene_grasps(self, objects: list[dict]) -> asyncio.Task | None: # type: ignore[type-arg] + """Request grasps for entire scene by combining all object point clouds.""" + if not self.grasp_loop or not objects: + return None + + all_points = [] + all_colors = [] + valid_objects = 0 + + for _i, obj in enumerate(objects): + # Validate point cloud data + if "point_cloud_numpy" not in obj or obj["point_cloud_numpy"] is None: + continue + + points = obj["point_cloud_numpy"] + if not isinstance(points, np.ndarray) or points.size == 0: + continue + + # Ensure points have correct shape (N, 3) + if len(points.shape) != 2 or points.shape[1] != 3: + continue + + # Validate colors if present + colors = None + if "colors_numpy" in obj and obj["colors_numpy"] is not None: + colors = obj["colors_numpy"] + if isinstance(colors, np.ndarray) and colors.size > 0: + # Ensure colors match points count and have correct shape + if colors.shape[0] != points.shape[0]: + colors = None # Ignore colors for this object + elif len(colors.shape) != 2 or colors.shape[1] != 3: + colors = None # Ignore colors for this object + + all_points.append(points) + if colors is not None: + all_colors.append(colors) + valid_objects += 1 + + if not all_points: + return None + + try: + combined_points = np.vstack(all_points) + + # Only combine colors if ALL objects have valid colors + combined_colors = None + if len(all_colors) == valid_objects and len(all_colors) > 0: + combined_colors = np.vstack(all_colors) + + # Validate final combined data + if combined_points.size == 0: + logger.warning("Combined point cloud is empty") + return None + + if combined_colors is not None and combined_colors.shape[0] != combined_points.shape[0]: + logger.warning( + f"Color/point count mismatch: {combined_colors.shape[0]} colors vs {combined_points.shape[0]} points, dropping colors" + ) + combined_colors = None + + except Exception as e: + logger.error(f"Failed to combine point clouds: {e}") + return None + + try: + # Check if there's already a grasp task running + if hasattr(self, "grasp_task") and self.grasp_task and not self.grasp_task.done(): + return self.grasp_task + + task = asyncio.run_coroutine_threadsafe( + self._send_grasp_request(combined_points, combined_colors), self.grasp_loop + ) + + self.grasp_task = task + return task + except Exception: + logger.warning("Failed to create grasp task") + return None + + def get_latest_grasps(self, timeout: float = 5.0) -> list[dict] | None: # type: ignore[type-arg] + """Get latest grasp results, waiting for new ones if current ones have been consumed.""" + # Mark current grasps as consumed and get a reference + with self.grasp_lock: + current_grasps = self.latest_grasps + self.grasps_consumed = True + + # If we already have grasps and they haven't been consumed, return them + if current_grasps is not None and not getattr(self, "grasps_consumed", False): + return current_grasps + + # Wait for new grasps + start_time = time.time() + while time.time() - start_time < timeout: + with self.grasp_lock: + # Check if we have new grasps (different from what we marked as consumed) + if self.latest_grasps is not None and not getattr(self, "grasps_consumed", False): + return self.latest_grasps + time.sleep(0.1) # Check every 100ms + + return None # Timeout reached + + def clear_grasps(self) -> None: + """Clear all stored grasp results.""" + with self.grasp_lock: + self.latest_grasps = [] + + def _prepare_colors(self, colors: np.ndarray | None) -> np.ndarray | None: # type: ignore[type-arg] + """Prepare colors array, converting from various formats if needed.""" + if colors is None: + return None + + if colors.max() > 1.0: + colors = colors / 255.0 + + return colors + + def _convert_grasp_format(self, grasps: list[dict]) -> list[dict]: # type: ignore[type-arg] + """Convert Grasp format to our visualization format.""" + converted = [] + + for i, grasp in enumerate(grasps): + rotation_matrix = np.array(grasp.get("rotation_matrix", np.eye(3))) + euler_angles = self._rotation_matrix_to_euler(rotation_matrix) + + converted_grasp = { + "id": f"grasp_{i}", + "score": grasp.get("score", 0.0), + "width": grasp.get("width", 0.0), + "height": grasp.get("height", 0.0), + "depth": grasp.get("depth", 0.0), + "translation": grasp.get("translation", [0, 0, 0]), + "rotation_matrix": rotation_matrix.tolist(), + "euler_angles": euler_angles, + } + converted.append(converted_grasp) + + converted.sort(key=lambda x: x["score"], reverse=True) + + return converted + + def _rotation_matrix_to_euler(self, rotation_matrix: np.ndarray) -> dict[str, float]: # type: ignore[type-arg] + """Convert rotation matrix to Euler angles (in radians).""" + sy = np.sqrt(rotation_matrix[0, 0] ** 2 + rotation_matrix[1, 0] ** 2) + + singular = sy < 1e-6 + + if not singular: + x = np.arctan2(rotation_matrix[2, 1], rotation_matrix[2, 2]) + y = np.arctan2(-rotation_matrix[2, 0], sy) + z = np.arctan2(rotation_matrix[1, 0], rotation_matrix[0, 0]) + else: + x = np.arctan2(-rotation_matrix[1, 2], rotation_matrix[1, 1]) + y = np.arctan2(-rotation_matrix[2, 0], sy) + z = 0 + + return {"roll": x, "pitch": y, "yaw": z} + + def cleanup(self) -> None: + """Clean up resources.""" + if hasattr(self.detector, "cleanup"): + self.detector.cleanup() + + if self.grasp_loop and self.grasp_loop_thread: + self.grasp_loop.call_soon_threadsafe(self.grasp_loop.stop) + self.grasp_loop_thread.join(timeout=1.0) + + if hasattr(self.pointcloud_filter, "cleanup"): + self.pointcloud_filter.cleanup() + logger.info("ManipulationPipeline cleaned up") diff --git a/dimos/manipulation/manip_aio_processer.py b/dimos/manipulation/manip_aio_processer.py new file mode 100644 index 0000000000..574edc6bf5 --- /dev/null +++ b/dimos/manipulation/manip_aio_processer.py @@ -0,0 +1,422 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Sequential manipulation processor for single-frame processing without reactive streams. +""" + +import time +from typing import Any + +import cv2 +import numpy as np + +from dimos.perception.common.utils import ( + colorize_depth, + combine_object_data, + detection_results_to_object_data, +) +from dimos.perception.detection2d.detic_2d_det import ( # type: ignore[import-untyped] + Detic2DDetector, +) +from dimos.perception.grasp_generation.grasp_generation import HostedGraspGenerator +from dimos.perception.grasp_generation.utils import create_grasp_overlay +from dimos.perception.pointcloud.pointcloud_filtering import PointcloudFiltering +from dimos.perception.pointcloud.utils import ( + create_point_cloud_overlay_visualization, + extract_and_cluster_misc_points, + overlay_point_clouds_on_image, +) +from dimos.perception.segmentation.sam_2d_seg import Sam2DSegmenter +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class ManipulationProcessor: + """ + Sequential manipulation processor for single-frame processing. + + Processes RGB-D frames through object detection, point cloud filtering, + and grasp generation in a single thread without reactive streams. + """ + + def __init__( + self, + camera_intrinsics: list[float], # [fx, fy, cx, cy] + min_confidence: float = 0.6, + max_objects: int = 20, + vocabulary: str | None = None, + enable_grasp_generation: bool = False, + grasp_server_url: str | None = None, # Required when enable_grasp_generation=True + enable_segmentation: bool = True, + ) -> None: + """ + Initialize the manipulation processor. + + Args: + camera_intrinsics: [fx, fy, cx, cy] camera parameters + min_confidence: Minimum detection confidence threshold + max_objects: Maximum number of objects to process + vocabulary: Optional vocabulary for Detic detector + enable_grasp_generation: Whether to enable grasp generation + grasp_server_url: WebSocket URL for Dimensional Grasp server (required when enable_grasp_generation=True) + enable_segmentation: Whether to enable semantic segmentation + segmentation_model: Segmentation model to use (SAM 2 or FastSAM) + """ + self.camera_intrinsics = camera_intrinsics + self.min_confidence = min_confidence + self.max_objects = max_objects + self.enable_grasp_generation = enable_grasp_generation + self.grasp_server_url = grasp_server_url + self.enable_segmentation = enable_segmentation + + # Validate grasp generation requirements + if enable_grasp_generation and not grasp_server_url: + raise ValueError("grasp_server_url is required when enable_grasp_generation=True") + + # Initialize object detector + self.detector = Detic2DDetector(vocabulary=vocabulary, threshold=min_confidence) + + # Initialize point cloud processor + self.pointcloud_filter = PointcloudFiltering( + color_intrinsics=camera_intrinsics, + depth_intrinsics=camera_intrinsics, # ZED uses same intrinsics + max_num_objects=max_objects, + ) + + # Initialize semantic segmentation + self.segmenter = None + if self.enable_segmentation: + self.segmenter = Sam2DSegmenter( + use_tracker=False, # Disable tracker for simple segmentation + use_analyzer=False, # Disable analyzer for simple segmentation + ) + + # Initialize grasp generator if enabled + self.grasp_generator = None + if self.enable_grasp_generation: + try: + self.grasp_generator = HostedGraspGenerator(server_url=grasp_server_url) # type: ignore[arg-type] + logger.info("Hosted grasp generator initialized successfully") + except Exception as e: + logger.error(f"Failed to initialize hosted grasp generator: {e}") + self.grasp_generator = None + self.enable_grasp_generation = False + + logger.info( + f"Initialized ManipulationProcessor with confidence={min_confidence}, " + f"grasp_generation={enable_grasp_generation}" + ) + + def process_frame( + self, + rgb_image: np.ndarray, # type: ignore[type-arg] + depth_image: np.ndarray, # type: ignore[type-arg] + generate_grasps: bool | None = None, + ) -> dict[str, Any]: + """ + Process a single RGB-D frame through the complete pipeline. + + Args: + rgb_image: RGB image (H, W, 3) + depth_image: Depth image (H, W) in meters + generate_grasps: Override grasp generation setting for this frame + + Returns: + Dictionary containing: + - detection_viz: Visualization of object detection + - pointcloud_viz: Visualization of point cloud overlay + - segmentation_viz: Visualization of semantic segmentation (if enabled) + - detection2d_objects: Raw detection results as ObjectData + - segmentation2d_objects: Raw segmentation results as ObjectData (if enabled) + - detected_objects: Detection (Object Detection) objects with point clouds filtered + - all_objects: Combined objects with intelligent duplicate removal + - full_pointcloud: Complete scene point cloud (if point cloud processing enabled) + - misc_clusters: List of clustered background/miscellaneous point clouds (DBSCAN) + - misc_voxel_grid: Open3D voxel grid approximating all misc/background points + - misc_pointcloud_viz: Visualization of misc/background cluster overlay + - grasps: Grasp results (list of dictionaries, if enabled) + - grasp_overlay: Grasp visualization overlay (if enabled) + - processing_time: Total processing time + """ + start_time = time.time() + results = {} + + try: + # Step 1: Object Detection + step_start = time.time() + detection_results = self.run_object_detection(rgb_image) + results["detection2d_objects"] = detection_results.get("objects", []) + results["detection_viz"] = detection_results.get("viz_frame") + detection_time = time.time() - step_start + + # Step 2: Semantic Segmentation (if enabled) + segmentation_time = 0 + if self.enable_segmentation: + step_start = time.time() + segmentation_results = self.run_segmentation(rgb_image) + results["segmentation2d_objects"] = segmentation_results.get("objects", []) + results["segmentation_viz"] = segmentation_results.get("viz_frame") + segmentation_time = time.time() - step_start # type: ignore[assignment] + + # Step 3: Point Cloud Processing + pointcloud_time = 0 + detection2d_objects = results.get("detection2d_objects", []) + segmentation2d_objects = results.get("segmentation2d_objects", []) + + # Process detection objects if available + detected_objects = [] + if detection2d_objects: + step_start = time.time() + detected_objects = self.run_pointcloud_filtering( + rgb_image, depth_image, detection2d_objects + ) + pointcloud_time += time.time() - step_start # type: ignore[assignment] + + # Process segmentation objects if available + segmentation_filtered_objects = [] + if segmentation2d_objects: + step_start = time.time() + segmentation_filtered_objects = self.run_pointcloud_filtering( + rgb_image, depth_image, segmentation2d_objects + ) + pointcloud_time += time.time() - step_start # type: ignore[assignment] + + # Combine all objects using intelligent duplicate removal + all_objects = combine_object_data( + detected_objects, # type: ignore[arg-type] + segmentation_filtered_objects, # type: ignore[arg-type] + overlap_threshold=0.8, + ) + + # Get full point cloud + full_pcd = self.pointcloud_filter.get_full_point_cloud() + + # Extract misc/background points and create voxel grid + misc_start = time.time() + misc_clusters, misc_voxel_grid = extract_and_cluster_misc_points( + full_pcd, + all_objects, # type: ignore[arg-type] + eps=0.03, + min_points=100, + enable_filtering=True, + voxel_size=0.02, + ) + misc_time = time.time() - misc_start + + # Store results + results.update( + { + "detected_objects": detected_objects, + "all_objects": all_objects, + "full_pointcloud": full_pcd, + "misc_clusters": misc_clusters, + "misc_voxel_grid": misc_voxel_grid, + } + ) + + # Create point cloud visualizations + base_image = colorize_depth(depth_image, max_depth=10.0) + + # Create visualizations + results["pointcloud_viz"] = ( + create_point_cloud_overlay_visualization( + base_image=base_image, # type: ignore[arg-type] + objects=all_objects, # type: ignore[arg-type] + intrinsics=self.camera_intrinsics, # type: ignore[arg-type] + ) + if all_objects + else base_image + ) + + results["detected_pointcloud_viz"] = ( + create_point_cloud_overlay_visualization( + base_image=base_image, # type: ignore[arg-type] + objects=detected_objects, + intrinsics=self.camera_intrinsics, # type: ignore[arg-type] + ) + if detected_objects + else base_image + ) + + if misc_clusters: + # Generate consistent colors for clusters + cluster_colors = [ + tuple((np.random.RandomState(i + 100).rand(3) * 255).astype(int)) + for i in range(len(misc_clusters)) + ] + results["misc_pointcloud_viz"] = overlay_point_clouds_on_image( + base_image=base_image, # type: ignore[arg-type] + point_clouds=misc_clusters, + camera_intrinsics=self.camera_intrinsics, + colors=cluster_colors, + point_size=2, + alpha=0.6, + ) + else: + results["misc_pointcloud_viz"] = base_image + + # Step 4: Grasp Generation (if enabled) + should_generate_grasps = ( + generate_grasps if generate_grasps is not None else self.enable_grasp_generation + ) + + if should_generate_grasps and all_objects and full_pcd: + grasps = self.run_grasp_generation(all_objects, full_pcd) # type: ignore[arg-type] + results["grasps"] = grasps + if grasps: + results["grasp_overlay"] = create_grasp_overlay( + rgb_image, grasps, self.camera_intrinsics + ) + + except Exception as e: + logger.error(f"Error processing frame: {e}") + results["error"] = str(e) + + # Add timing information + total_time = time.time() - start_time + results.update( + { + "processing_time": total_time, + "timing_breakdown": { + "detection": detection_time if "detection_time" in locals() else 0, + "segmentation": segmentation_time if "segmentation_time" in locals() else 0, + "pointcloud": pointcloud_time if "pointcloud_time" in locals() else 0, + "misc_extraction": misc_time if "misc_time" in locals() else 0, + "total": total_time, + }, + } + ) + + return results + + def run_object_detection(self, rgb_image: np.ndarray) -> dict[str, Any]: # type: ignore[type-arg] + """Run object detection on RGB image.""" + try: + # Convert RGB to BGR for Detic detector + bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR) + + # Use process_image method from Detic detector + bboxes, track_ids, class_ids, confidences, names, masks = self.detector.process_image( + bgr_image + ) + + # Convert to ObjectData format using utility function + objects = detection_results_to_object_data( + bboxes=bboxes, + track_ids=track_ids, + class_ids=class_ids, + confidences=confidences, + names=names, + masks=masks, + source="detection", + ) + + # Create visualization using detector's built-in method + viz_frame = self.detector.visualize_results( + rgb_image, bboxes, track_ids, class_ids, confidences, names + ) + + return {"objects": objects, "viz_frame": viz_frame} + + except Exception as e: + logger.error(f"Object detection failed: {e}") + return {"objects": [], "viz_frame": rgb_image.copy()} + + def run_pointcloud_filtering( + self, + rgb_image: np.ndarray, # type: ignore[type-arg] + depth_image: np.ndarray, # type: ignore[type-arg] + objects: list[dict], # type: ignore[type-arg] + ) -> list[dict]: # type: ignore[type-arg] + """Run point cloud filtering on detected objects.""" + try: + filtered_objects = self.pointcloud_filter.process_images( + rgb_image, + depth_image, + objects, # type: ignore[arg-type] + ) + return filtered_objects if filtered_objects else [] # type: ignore[return-value] + except Exception as e: + logger.error(f"Point cloud filtering failed: {e}") + return [] + + def run_segmentation(self, rgb_image: np.ndarray) -> dict[str, Any]: # type: ignore[type-arg] + """Run semantic segmentation on RGB image.""" + if not self.segmenter: + return {"objects": [], "viz_frame": rgb_image.copy()} + + try: + # Convert RGB to BGR for segmenter + bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR) + + # Get segmentation results + masks, bboxes, track_ids, probs, names = self.segmenter.process_image(bgr_image) # type: ignore[no-untyped-call] + + # Convert to ObjectData format using utility function + objects = detection_results_to_object_data( + bboxes=bboxes, + track_ids=track_ids, + class_ids=list(range(len(bboxes))), # Use indices as class IDs for segmentation + confidences=probs, + names=names, + masks=masks, + source="segmentation", + ) + + # Create visualization + if masks: + viz_bgr = self.segmenter.visualize_results( + bgr_image, masks, bboxes, track_ids, probs, names + ) + # Convert back to RGB + viz_frame = cv2.cvtColor(viz_bgr, cv2.COLOR_BGR2RGB) + else: + viz_frame = rgb_image.copy() + + return {"objects": objects, "viz_frame": viz_frame} + + except Exception as e: + logger.error(f"Segmentation failed: {e}") + return {"objects": [], "viz_frame": rgb_image.copy()} + + def run_grasp_generation(self, filtered_objects: list[dict], full_pcd) -> list[dict] | None: # type: ignore[no-untyped-def, type-arg] + """Run grasp generation using the configured generator.""" + if not self.grasp_generator: + logger.warning("Grasp generation requested but no generator available") + return None + + try: + # Generate grasps using the configured generator + grasps = self.grasp_generator.generate_grasps_from_objects(filtered_objects, full_pcd) # type: ignore[arg-type] + + # Return parsed results directly (list of grasp dictionaries) + return grasps + + except Exception as e: + logger.error(f"Grasp generation failed: {e}") + return None + + def cleanup(self) -> None: + """Clean up resources.""" + if hasattr(self.detector, "cleanup"): + self.detector.cleanup() + if hasattr(self.pointcloud_filter, "cleanup"): + self.pointcloud_filter.cleanup() + if self.segmenter and hasattr(self.segmenter, "cleanup"): + self.segmenter.cleanup() + if self.grasp_generator and hasattr(self.grasp_generator, "cleanup"): + self.grasp_generator.cleanup() + logger.info("ManipulationProcessor cleaned up") diff --git a/dimos/manipulation/manipulation_history.py b/dimos/manipulation/manipulation_history.py new file mode 100644 index 0000000000..d054803f28 --- /dev/null +++ b/dimos/manipulation/manipulation_history.py @@ -0,0 +1,417 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# [http://www.apache.org/licenses/LICENSE-2.0](http://www.apache.org/licenses/LICENSE-2.0) +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Module for manipulation history tracking and search.""" + +from dataclasses import dataclass, field +from datetime import datetime +import json +import os +import pickle +import time +from typing import Any + +from dimos.types.manipulation import ( + ManipulationTask, +) +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +@dataclass +class ManipulationHistoryEntry: + """An entry in the manipulation history. + + Attributes: + task: The manipulation task executed + timestamp: When the manipulation was performed + result: Result of the manipulation (success/failure) + manipulation_response: Response from the motion planner/manipulation executor + """ + + task: ManipulationTask + timestamp: float = field(default_factory=time.time) + result: dict[str, Any] = field(default_factory=dict) + manipulation_response: str | None = ( + None # Any elaborative response from the motion planner / manipulation executor + ) + + def __str__(self) -> str: + status = self.result.get("status", "unknown") + return f"ManipulationHistoryEntry(task='{self.task.description}', status={status}, time={datetime.fromtimestamp(self.timestamp).strftime('%H:%M:%S')})" + + +class ManipulationHistory: + """A simplified, dictionary-based storage for manipulation history. + + This class provides an efficient way to store and query manipulation tasks, + focusing on quick lookups and flexible search capabilities. + """ + + def __init__(self, output_dir: str | None = None, new_memory: bool = False) -> None: + """Initialize a new manipulation history. + + Args: + output_dir: Directory to save history to + new_memory: If True, creates a new memory instead of loading existing one + """ + self._history: list[ManipulationHistoryEntry] = [] + self._output_dir = output_dir + + if output_dir and not new_memory: + self.load_from_dir(output_dir) + elif output_dir: + os.makedirs(output_dir, exist_ok=True) + logger.info(f"Created new manipulation history at {output_dir}") + + def __len__(self) -> int: + """Return the number of entries in the history.""" + return len(self._history) + + def __str__(self) -> str: + """Return a string representation of the history.""" + if not self._history: + return "ManipulationHistory(empty)" + + return ( + f"ManipulationHistory(entries={len(self._history)}, " + f"time_range={datetime.fromtimestamp(self._history[0].timestamp).strftime('%Y-%m-%d %H:%M:%S')} to " + f"{datetime.fromtimestamp(self._history[-1].timestamp).strftime('%Y-%m-%d %H:%M:%S')})" + ) + + def clear(self) -> None: + """Clear all entries from the history.""" + self._history.clear() + logger.info("Cleared manipulation history") + + if self._output_dir: + self.save_history() + + def add_entry(self, entry: ManipulationHistoryEntry) -> None: + """Add an entry to the history. + + Args: + entry: The entry to add + """ + self._history.append(entry) + self._history.sort(key=lambda e: e.timestamp) + + if self._output_dir: + self.save_history() + + def save_history(self) -> None: + """Save the history to the output directory.""" + if not self._output_dir: + logger.warning("Cannot save history: no output directory specified") + return + + os.makedirs(self._output_dir, exist_ok=True) + history_path = os.path.join(self._output_dir, "manipulation_history.pickle") + + with open(history_path, "wb") as f: + pickle.dump(self._history, f) + + logger.info(f"Saved manipulation history to {history_path}") + + # Also save a JSON representation for easier inspection + json_path = os.path.join(self._output_dir, "manipulation_history.json") + try: + history_data = [ + { + "task": { + "description": entry.task.description, + "target_object": entry.task.target_object, + "target_point": entry.task.target_point, + "timestamp": entry.task.timestamp, + "task_id": entry.task.task_id, + "metadata": entry.task.metadata, + }, + "result": entry.result, + "timestamp": entry.timestamp, + "manipulation_response": entry.manipulation_response, + } + for entry in self._history + ] + + with open(json_path, "w") as f: + json.dump(history_data, f, indent=2) + + logger.info(f"Saved JSON representation to {json_path}") + except Exception as e: + logger.error(f"Failed to save JSON representation: {e}") + + def load_from_dir(self, directory: str) -> None: + """Load history from the specified directory. + + Args: + directory: Directory to load history from + """ + history_path = os.path.join(directory, "manipulation_history.pickle") + + if not os.path.exists(history_path): + logger.warning(f"No history found at {history_path}") + return + + try: + with open(history_path, "rb") as f: + self._history = pickle.load(f) + + logger.info( + f"Loaded manipulation history from {history_path} with {len(self._history)} entries" + ) + except Exception as e: + logger.error(f"Failed to load history: {e}") + + def get_all_entries(self) -> list[ManipulationHistoryEntry]: + """Get all entries in chronological order. + + Returns: + List of all manipulation history entries + """ + return self._history.copy() + + def get_entry_by_index(self, index: int) -> ManipulationHistoryEntry | None: + """Get an entry by its index. + + Args: + index: Index of the entry to retrieve + + Returns: + The entry at the specified index or None if index is out of bounds + """ + if 0 <= index < len(self._history): + return self._history[index] + return None + + def get_entries_by_timerange( + self, start_time: float, end_time: float + ) -> list[ManipulationHistoryEntry]: + """Get entries within a specific time range. + + Args: + start_time: Start time (UNIX timestamp) + end_time: End time (UNIX timestamp) + + Returns: + List of entries within the specified time range + """ + return [entry for entry in self._history if start_time <= entry.timestamp <= end_time] + + def get_entries_by_object(self, object_name: str) -> list[ManipulationHistoryEntry]: + """Get entries related to a specific object. + + Args: + object_name: Name of the object to search for + + Returns: + List of entries related to the specified object + """ + return [entry for entry in self._history if entry.task.target_object == object_name] + + def create_task_entry( + self, + task: ManipulationTask, + result: dict[str, Any] | None = None, + agent_response: str | None = None, + ) -> ManipulationHistoryEntry: + """Create a new manipulation history entry. + + Args: + task: The manipulation task + result: Result of the manipulation + agent_response: Response from the agent about this manipulation + + Returns: + The created history entry + """ + entry = ManipulationHistoryEntry( + task=task, result=result or {}, manipulation_response=agent_response + ) + self.add_entry(entry) + return entry + + def search(self, **kwargs) -> list[ManipulationHistoryEntry]: # type: ignore[no-untyped-def] + """Flexible search method that can search by any field in ManipulationHistoryEntry using dot notation. + + This method supports dot notation to access nested fields. String values automatically use + substring matching (contains), while all other types use exact matching. + + Examples: + # Time-based searches: + - search(**{"task.metadata.timestamp": ('>', start_time)}) - entries after start_time + - search(**{"task.metadata.timestamp": ('>=', time - 1800)}) - entries in last 30 mins + + # Constraint searches: + - search(**{"task.constraints.*.reference_point.x": 2.5}) - tasks with x=2.5 reference point + - search(**{"task.constraints.*.end_angle.x": 90}) - tasks with 90-degree x rotation + - search(**{"task.constraints.*.lock_x": True}) - tasks with x-axis translation locked + + # Object and result searches: + - search(**{"task.metadata.objects.*.label": "cup"}) - tasks involving cups + - search(**{"result.status": "success"}) - successful tasks + - search(**{"result.error": "Collision"}) - tasks that had collisions + + Args: + **kwargs: Key-value pairs for searching using dot notation for field paths. + + Returns: + List of matching entries + """ + if not kwargs: + return self._history.copy() + + results = self._history.copy() + + for key, value in kwargs.items(): + # For all searches, automatically determine if we should use contains for strings + results = [e for e in results if self._check_field_match(e, key, value)] + + return results + + def _check_field_match(self, entry, field_path, value) -> bool: # type: ignore[no-untyped-def] + """Check if a field matches the value, with special handling for strings, collections and comparisons. + + For string values, we automatically use substring matching (contains). + For collections (returned by * path), we check if any element matches. + For numeric values (like timestamps), supports >, <, >= and <= comparisons. + For all other types, we use exact matching. + + Args: + entry: The entry to check + field_path: Dot-separated path to the field + value: Value to match against. For comparisons, use tuples like: + ('>', timestamp) - greater than + ('<', timestamp) - less than + ('>=', timestamp) - greater or equal + ('<=', timestamp) - less or equal + + Returns: + True if the field matches the value, False otherwise + """ + try: + field_value = self._get_value_by_path(entry, field_path) # type: ignore[no-untyped-call] + + # Handle comparison operators for timestamps and numbers + if isinstance(value, tuple) and len(value) == 2: + op, compare_value = value + if op == ">": + return field_value > compare_value # type: ignore[no-any-return] + elif op == "<": + return field_value < compare_value # type: ignore[no-any-return] + elif op == ">=": + return field_value >= compare_value # type: ignore[no-any-return] + elif op == "<=": + return field_value <= compare_value # type: ignore[no-any-return] + + # Handle lists (from collection searches) + if isinstance(field_value, list): + for item in field_value: + # String values use contains matching + if isinstance(item, str) and isinstance(value, str): + if value in item: + return True + # All other types use exact matching + elif item == value: + return True + return False + + # String values use contains matching + elif isinstance(field_value, str) and isinstance(value, str): + return value in field_value + # All other types use exact matching + else: + return field_value == value # type: ignore[no-any-return] + + except (AttributeError, KeyError): + return False + + def _get_value_by_path(self, obj, path): # type: ignore[no-untyped-def] + """Get a value from an object using a dot-separated path. + + This method handles three special cases: + 1. Regular attribute access (obj.attr) + 2. Dictionary key access (dict[key]) + 3. Collection search (dict.*.attr) - when * is used, it searches all values in the collection + + Args: + obj: Object to get value from + path: Dot-separated path to the field (e.g., "task.metadata.robot") + + Returns: + Value at the specified path or list of values for collection searches + + Raises: + AttributeError: If an attribute in the path doesn't exist + KeyError: If a dictionary key in the path doesn't exist + """ + current = obj + parts = path.split(".") + + for i, part in enumerate(parts): + # Collection search (*.attr) - search across all items in a collection + if part == "*": + # Get remaining path parts + remaining_path = ".".join(parts[i + 1 :]) + + # Handle different collection types + if isinstance(current, dict): + items = current.values() + if not remaining_path: # If * is the last part, return all values + return list(items) + elif isinstance(current, list): + items = current # type: ignore[assignment] + if not remaining_path: # If * is the last part, return all items + return items + else: # Not a collection + raise AttributeError( + f"Cannot use wildcard on non-collection type: {type(current)}" + ) + + # Apply remaining path to each item in the collection + results = [] + for item in items: + try: + # Recursively get values from each item + value = self._get_value_by_path(item, remaining_path) # type: ignore[no-untyped-call] + if isinstance(value, list): # Flatten nested lists + results.extend(value) + else: + results.append(value) + except (AttributeError, KeyError): + # Skip items that don't have the attribute + pass + return results + + # Regular attribute/key access + elif isinstance(current, dict): + current = current[part] + else: + current = getattr(current, part) + + return current diff --git a/dimos/manipulation/manipulation_interface.py b/dimos/manipulation/manipulation_interface.py new file mode 100644 index 0000000000..edeb99c0f0 --- /dev/null +++ b/dimos/manipulation/manipulation_interface.py @@ -0,0 +1,286 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +ManipulationInterface provides a unified interface for accessing manipulation history. + +This module defines the ManipulationInterface class, which serves as an access point +for the robot's manipulation history, agent-generated constraints, and manipulation +metadata streams. +""" + +import os +from typing import TYPE_CHECKING, Any + +from dimos.manipulation.manipulation_history import ( + ManipulationHistory, +) +from dimos.perception.object_detection_stream import ObjectDetectionStream +from dimos.types.manipulation import ( + AbstractConstraint, + ManipulationTask, + ObjectData, +) +from dimos.utils.logging_config import setup_logger + +if TYPE_CHECKING: + from reactivex.disposable import Disposable + +logger = setup_logger() + + +class ManipulationInterface: + """ + Interface for accessing and managing robot manipulation data. + + This class provides a unified interface for managing manipulation tasks and constraints. + It maintains a list of constraints generated by the Agent and provides methods to + add and manage manipulation tasks. + """ + + def __init__( + self, + output_dir: str, + new_memory: bool = False, + perception_stream: ObjectDetectionStream = None, # type: ignore[assignment] + ) -> None: + """ + Initialize a new ManipulationInterface instance. + + Args: + output_dir: Directory for storing manipulation data + new_memory: If True, creates a new manipulation history from scratch + perception_stream: ObjectDetectionStream instance for real-time object data + """ + self.output_dir = output_dir + + # Create manipulation history directory + manipulation_dir = os.path.join(output_dir, "manipulation_history") + os.makedirs(manipulation_dir, exist_ok=True) + + # Initialize manipulation history + self.manipulation_history: ManipulationHistory = ManipulationHistory( + output_dir=manipulation_dir, new_memory=new_memory + ) + + # List of constraints generated by the Agent via constraint generation skills + self.agent_constraints: list[AbstractConstraint] = [] + + # Initialize object detection stream and related properties + self.perception_stream = perception_stream + self.latest_objects: list[ObjectData] = [] + self.stream_subscription: Disposable | None = None + + # Set up subscription to perception stream if available + self._setup_perception_subscription() + + logger.info("ManipulationInterface initialized") + + def add_constraint(self, constraint: AbstractConstraint) -> None: + """ + Add a constraint generated by the Agent via a constraint generation skill. + + Args: + constraint: The constraint to add to agent_constraints + """ + self.agent_constraints.append(constraint) + logger.info(f"Added agent constraint: {constraint}") + + def get_constraints(self) -> list[AbstractConstraint]: + """ + Get all constraints generated by the Agent via constraint generation skills. + + Returns: + List of all constraints created by the Agent + """ + return self.agent_constraints + + def get_constraint(self, constraint_id: str) -> AbstractConstraint | None: + """ + Get a specific constraint by its ID. + + Args: + constraint_id: ID of the constraint to retrieve + + Returns: + The matching constraint or None if not found + """ + # Find constraint with matching ID + for constraint in self.agent_constraints: + if constraint.id == constraint_id: + return constraint + + logger.warning(f"Constraint with ID {constraint_id} not found") + return None + + def add_manipulation_task( + self, task: ManipulationTask, manipulation_response: str | None = None + ) -> None: + """ + Add a manipulation task to ManipulationHistory. + + Args: + task: The ManipulationTask to add + manipulation_response: Optional response from the motion planner/executor + + """ + # Add task to history + self.manipulation_history.add_entry( # type: ignore[call-arg] + task=task, result=None, notes=None, manipulation_response=manipulation_response + ) + + def get_manipulation_task(self, task_id: str) -> ManipulationTask | None: + """ + Get a manipulation task by its ID. + + Args: + task_id: ID of the task to retrieve + + Returns: + The task object or None if not found + """ + return self.history.get_manipulation_task(task_id) # type: ignore[attr-defined, no-any-return] + + def get_all_manipulation_tasks(self) -> list[ManipulationTask]: + """ + Get all manipulation tasks. + + Returns: + List of all manipulation tasks + """ + return self.history.get_all_manipulation_tasks() # type: ignore[attr-defined, no-any-return] + + def update_task_status( + self, task_id: str, status: str, result: dict[str, Any] | None = None + ) -> ManipulationTask | None: + """ + Update the status and result of a manipulation task. + + Args: + task_id: ID of the task to update + status: New status for the task (e.g., 'completed', 'failed') + result: Optional dictionary with result data + + Returns: + The updated task or None if task not found + """ + return self.history.update_task_status(task_id, status, result) # type: ignore[attr-defined, no-any-return] + + # === Perception stream methods === + + def _setup_perception_subscription(self) -> None: + """ + Set up subscription to perception stream if available. + """ + if self.perception_stream: + # Subscribe to the stream and update latest_objects + self.stream_subscription = self.perception_stream.get_stream().subscribe( # type: ignore[no-untyped-call] + on_next=self._update_latest_objects, + on_error=lambda e: logger.error(f"Error in perception stream: {e}"), + ) + logger.info("Subscribed to perception stream") + + def _update_latest_objects(self, data) -> None: # type: ignore[no-untyped-def] + """ + Update the latest detected objects. + + Args: + data: Data from the object detection stream + """ + if "objects" in data: + self.latest_objects = data["objects"] + + def get_latest_objects(self) -> list[ObjectData]: + """ + Get the latest detected objects from the stream. + + Returns: + List of the most recently detected objects + """ + return self.latest_objects + + def get_object_by_id(self, object_id: int) -> ObjectData | None: + """ + Get a specific object by its tracking ID. + + Args: + object_id: Tracking ID of the object + + Returns: + The object data or None if not found + """ + for obj in self.latest_objects: + if obj["object_id"] == object_id: + return obj + return None + + def get_objects_by_label(self, label: str) -> list[ObjectData]: + """ + Get all objects with a specific label. + + Args: + label: Class label to filter objects by + + Returns: + List of objects matching the label + """ + return [obj for obj in self.latest_objects if obj["label"] == label] + + def set_perception_stream(self, perception_stream) -> None: # type: ignore[no-untyped-def] + """ + Set or update the perception stream. + + Args: + perception_stream: The PerceptionStream instance + """ + # Clean up existing subscription if any + self.cleanup_perception_subscription() + + # Set new stream and subscribe + self.perception_stream = perception_stream + self._setup_perception_subscription() + + def cleanup_perception_subscription(self) -> None: + """ + Clean up the stream subscription. + """ + if self.stream_subscription: + self.stream_subscription.dispose() + self.stream_subscription = None + + # === Utility methods === + + def clear_history(self) -> None: + """ + Clear all manipulation history data and agent constraints. + """ + self.manipulation_history.clear() + self.agent_constraints.clear() + logger.info("Cleared manipulation history and agent constraints") + + def __str__(self) -> str: + """ + String representation of the manipulation interface. + + Returns: + String representation with key stats + """ + has_stream = self.perception_stream is not None + return f"ManipulationInterface(history={self.manipulation_history}, agent_constraints={len(self.agent_constraints)}, perception_stream={has_stream}, detected_objects={len(self.latest_objects)})" + + def __del__(self) -> None: + """ + Clean up resources on deletion. + """ + self.cleanup_perception_subscription() diff --git a/dimos/manipulation/test_manipulation_history.py b/dimos/manipulation/test_manipulation_history.py new file mode 100644 index 0000000000..a1edfca787 --- /dev/null +++ b/dimos/manipulation/test_manipulation_history.py @@ -0,0 +1,458 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# [http://www.apache.org/licenses/LICENSE-2.0](http://www.apache.org/licenses/LICENSE-2.0) +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import time + +import pytest + +from dimos.manipulation.manipulation_history import ManipulationHistory, ManipulationHistoryEntry +from dimos.types.manipulation import ( + ForceConstraint, + ManipulationTask, + RotationConstraint, + TranslationConstraint, +) +from dimos.types.vector import Vector + + +@pytest.fixture +def sample_task(): + """Create a sample manipulation task for testing.""" + return ManipulationTask( + description="Pick up the cup", + target_object="cup", + target_point=(100, 200), + task_id="task1", + metadata={ + "timestamp": time.time(), + "objects": { + "cup1": { + "object_id": 1, + "label": "cup", + "confidence": 0.95, + "position": {"x": 1.5, "y": 2.0, "z": 0.5}, + }, + "table1": { + "object_id": 2, + "label": "table", + "confidence": 0.98, + "position": {"x": 0.0, "y": 0.0, "z": 0.0}, + }, + }, + }, + ) + + +@pytest.fixture +def sample_task_with_constraints(): + """Create a sample manipulation task with constraints for testing.""" + task = ManipulationTask( + description="Rotate the bottle", + target_object="bottle", + target_point=(150, 250), + task_id="task2", + metadata={ + "timestamp": time.time(), + "objects": { + "bottle1": { + "object_id": 3, + "label": "bottle", + "confidence": 0.92, + "position": {"x": 2.5, "y": 1.0, "z": 0.3}, + } + }, + }, + ) + + # Add rich translation constraint + translation_constraint = TranslationConstraint( + translation_axis="y", + reference_point=Vector(2.5, 1.0, 0.3), + bounds_min=Vector(2.0, 0.5, 0.3), + bounds_max=Vector(3.0, 1.5, 0.3), + target_point=Vector(2.7, 1.2, 0.3), + description="Constrained translation along Y-axis only", + ) + task.add_constraint(translation_constraint) + + # Add rich rotation constraint + rotation_constraint = RotationConstraint( + rotation_axis="roll", + start_angle=Vector(0, 0, 0), + end_angle=Vector(90, 0, 0), + pivot_point=Vector(2.5, 1.0, 0.3), + secondary_pivot_point=Vector(2.5, 1.0, 0.5), + description="Constrained rotation around X-axis (roll only)", + ) + task.add_constraint(rotation_constraint) + + # Add force constraint + force_constraint = ForceConstraint( + min_force=2.0, + max_force=5.0, + force_direction=Vector(0, 0, -1), + description="Apply moderate downward force during manipulation", + ) + task.add_constraint(force_constraint) + + return task + + +@pytest.fixture +def temp_output_dir(): + """Create a temporary directory for testing history saving/loading.""" + with tempfile.TemporaryDirectory() as temp_dir: + yield temp_dir + + +@pytest.fixture +def populated_history(sample_task, sample_task_with_constraints): + """Create a populated history with multiple entries for testing.""" + history = ManipulationHistory() + + # Add first entry + entry1 = ManipulationHistoryEntry( + task=sample_task, + result={"status": "success", "execution_time": 2.5}, + manipulation_response="Successfully picked up the cup", + ) + history.add_entry(entry1) + + # Add second entry + entry2 = ManipulationHistoryEntry( + task=sample_task_with_constraints, + result={"status": "failure", "error": "Collision detected"}, + manipulation_response="Failed to rotate the bottle due to collision", + ) + history.add_entry(entry2) + + return history + + +def test_manipulation_history_init() -> None: + """Test initialization of ManipulationHistory.""" + # Default initialization + history = ManipulationHistory() + assert len(history) == 0 + assert str(history) == "ManipulationHistory(empty)" + + # With output directory + with tempfile.TemporaryDirectory() as temp_dir: + history = ManipulationHistory(output_dir=temp_dir, new_memory=True) + assert len(history) == 0 + assert os.path.exists(temp_dir) + + +def test_manipulation_history_add_entry(sample_task) -> None: + """Test adding entries to ManipulationHistory.""" + history = ManipulationHistory() + + # Create and add entry + entry = ManipulationHistoryEntry( + task=sample_task, result={"status": "success"}, manipulation_response="Task completed" + ) + history.add_entry(entry) + + assert len(history) == 1 + assert history.get_entry_by_index(0) == entry + + +def test_manipulation_history_create_task_entry(sample_task) -> None: + """Test creating a task entry directly.""" + history = ManipulationHistory() + + entry = history.create_task_entry( + task=sample_task, result={"status": "success"}, agent_response="Task completed" + ) + + assert len(history) == 1 + assert entry.task == sample_task + assert entry.result["status"] == "success" + assert entry.manipulation_response == "Task completed" + + +def test_manipulation_history_save_load(temp_output_dir, sample_task) -> None: + """Test saving and loading history from disk.""" + # Create history and add entry + history = ManipulationHistory(output_dir=temp_output_dir) + history.create_task_entry( + task=sample_task, result={"status": "success"}, agent_response="Task completed" + ) + + # Check that files were created + pickle_path = os.path.join(temp_output_dir, "manipulation_history.pickle") + json_path = os.path.join(temp_output_dir, "manipulation_history.json") + assert os.path.exists(pickle_path) + assert os.path.exists(json_path) + + # Create new history that loads from the saved files + loaded_history = ManipulationHistory(output_dir=temp_output_dir) + assert len(loaded_history) == 1 + assert loaded_history.get_entry_by_index(0).task.description == sample_task.description + + +def test_manipulation_history_clear(populated_history) -> None: + """Test clearing the history.""" + assert len(populated_history) > 0 + + populated_history.clear() + assert len(populated_history) == 0 + assert str(populated_history) == "ManipulationHistory(empty)" + + +def test_manipulation_history_get_methods(populated_history) -> None: + """Test various getter methods of ManipulationHistory.""" + # get_all_entries + entries = populated_history.get_all_entries() + assert len(entries) == 2 + + # get_entry_by_index + entry = populated_history.get_entry_by_index(0) + assert entry.task.task_id == "task1" + + # Out of bounds index + assert populated_history.get_entry_by_index(100) is None + + # get_entries_by_timerange + start_time = time.time() - 3600 # 1 hour ago + end_time = time.time() + 3600 # 1 hour from now + entries = populated_history.get_entries_by_timerange(start_time, end_time) + assert len(entries) == 2 + + # get_entries_by_object + cup_entries = populated_history.get_entries_by_object("cup") + assert len(cup_entries) == 1 + assert cup_entries[0].task.task_id == "task1" + + bottle_entries = populated_history.get_entries_by_object("bottle") + assert len(bottle_entries) == 1 + assert bottle_entries[0].task.task_id == "task2" + + +def test_manipulation_history_search_basic(populated_history) -> None: + """Test basic search functionality.""" + # Search by exact match on top-level fields + results = populated_history.search(timestamp=populated_history.get_entry_by_index(0).timestamp) + assert len(results) == 1 + + # Search by task fields + results = populated_history.search(**{"task.task_id": "task1"}) + assert len(results) == 1 + assert results[0].task.target_object == "cup" + + # Search by result fields + results = populated_history.search(**{"result.status": "success"}) + assert len(results) == 1 + assert results[0].task.task_id == "task1" + + # Search by manipulation_response (substring match for strings) + results = populated_history.search(manipulation_response="picked up") + assert len(results) == 1 + assert results[0].task.task_id == "task1" + + +def test_manipulation_history_search_nested(populated_history) -> None: + """Test search with nested field paths.""" + # Search by nested metadata fields + results = populated_history.search( + **{ + "task.metadata.timestamp": populated_history.get_entry_by_index(0).task.metadata[ + "timestamp" + ] + } + ) + assert len(results) == 1 + + # Search by nested object fields + results = populated_history.search(**{"task.metadata.objects.cup1.label": "cup"}) + assert len(results) == 1 + assert results[0].task.task_id == "task1" + + # Search by position values + results = populated_history.search(**{"task.metadata.objects.cup1.position.x": 1.5}) + assert len(results) == 1 + assert results[0].task.task_id == "task1" + + +def test_manipulation_history_search_wildcards(populated_history) -> None: + """Test search with wildcard patterns.""" + # Search for any object with label "cup" + results = populated_history.search(**{"task.metadata.objects.*.label": "cup"}) + assert len(results) == 1 + assert results[0].task.task_id == "task1" + + # Search for any object with confidence > 0.95 + results = populated_history.search(**{"task.metadata.objects.*.confidence": 0.98}) + assert len(results) == 1 + assert results[0].task.task_id == "task1" + + # Search for any object position with x=2.5 + results = populated_history.search(**{"task.metadata.objects.*.position.x": 2.5}) + assert len(results) == 1 + assert results[0].task.task_id == "task2" + + +def test_manipulation_history_search_constraints(populated_history) -> None: + """Test search by constraint properties.""" + # Find entries with any TranslationConstraint with y-axis + results = populated_history.search(**{"task.constraints.*.translation_axis": "y"}) + assert len(results) == 1 + assert results[0].task.task_id == "task2" + + # Find entries with any RotationConstraint with roll axis + results = populated_history.search(**{"task.constraints.*.rotation_axis": "roll"}) + assert len(results) == 1 + assert results[0].task.task_id == "task2" + + +def test_manipulation_history_search_string_contains(populated_history) -> None: + """Test string contains searching.""" + # Basic string contains + results = populated_history.search(**{"task.description": "Pick"}) + assert len(results) == 1 + assert results[0].task.task_id == "task1" + + # Nested string contains + results = populated_history.search(manipulation_response="collision") + assert len(results) == 1 + assert results[0].task.task_id == "task2" + + +def test_manipulation_history_search_multiple_criteria(populated_history) -> None: + """Test search with multiple criteria.""" + # Multiple criteria - all must match + results = populated_history.search(**{"task.target_object": "cup", "result.status": "success"}) + assert len(results) == 1 + assert results[0].task.task_id == "task1" + + # Multiple criteria with no matches + results = populated_history.search(**{"task.target_object": "cup", "result.status": "failure"}) + assert len(results) == 0 + + # Combination of direct and wildcard paths + results = populated_history.search( + **{"task.target_object": "bottle", "task.metadata.objects.*.position.z": 0.3} + ) + assert len(results) == 1 + assert results[0].task.task_id == "task2" + + +def test_manipulation_history_search_nonexistent_fields(populated_history) -> None: + """Test search with fields that don't exist.""" + # Search by nonexistent field + results = populated_history.search(nonexistent_field="value") + assert len(results) == 0 + + # Search by nonexistent nested field + results = populated_history.search(**{"task.nonexistent_field": "value"}) + assert len(results) == 0 + + # Search by nonexistent object + results = populated_history.search(**{"task.metadata.objects.nonexistent_object": "value"}) + assert len(results) == 0 + + +def test_manipulation_history_search_timestamp_ranges(populated_history) -> None: + """Test searching by timestamp ranges.""" + # Get reference timestamps + entry1_time = populated_history.get_entry_by_index(0).task.metadata["timestamp"] + entry2_time = populated_history.get_entry_by_index(1).task.metadata["timestamp"] + mid_time = (entry1_time + entry2_time) / 2 + + # Search for timestamps before second entry + results = populated_history.search(**{"task.metadata.timestamp": ("<", entry2_time)}) + assert len(results) == 1 + assert results[0].task.task_id == "task1" + + # Search for timestamps after first entry + results = populated_history.search(**{"task.metadata.timestamp": (">", entry1_time)}) + assert len(results) == 1 + assert results[0].task.task_id == "task2" + + # Search within a time window using >= and <= + results = populated_history.search(**{"task.metadata.timestamp": (">=", mid_time - 1800)}) + assert len(results) == 2 + assert results[0].task.task_id == "task1" + assert results[1].task.task_id == "task2" + + +def test_manipulation_history_search_vector_fields(populated_history) -> None: + """Test searching by vector components in constraints.""" + # Search by reference point components + results = populated_history.search(**{"task.constraints.*.reference_point.x": 2.5}) + assert len(results) == 1 + assert results[0].task.task_id == "task2" + + # Search by target point components + results = populated_history.search(**{"task.constraints.*.target_point.z": 0.3}) + assert len(results) == 1 + assert results[0].task.task_id == "task2" + + # Search by rotation angles + results = populated_history.search(**{"task.constraints.*.end_angle.x": 90}) + assert len(results) == 1 + assert results[0].task.task_id == "task2" + + +def test_manipulation_history_search_execution_details(populated_history) -> None: + """Test searching by execution time and error patterns.""" + # Search by execution time + results = populated_history.search(**{"result.execution_time": 2.5}) + assert len(results) == 1 + assert results[0].task.task_id == "task1" + + # Search by error message pattern + results = populated_history.search(**{"result.error": "Collision"}) + assert len(results) == 1 + assert results[0].task.task_id == "task2" + + # Search by status + results = populated_history.search(**{"result.status": "success"}) + assert len(results) == 1 + assert results[0].task.task_id == "task1" + + +def test_manipulation_history_search_multiple_criteria(populated_history) -> None: + """Test search with multiple criteria.""" + # Multiple criteria - all must match + results = populated_history.search(**{"task.target_object": "cup", "result.status": "success"}) + assert len(results) == 1 + assert results[0].task.task_id == "task1" + + # Multiple criteria with no matches + results = populated_history.search(**{"task.target_object": "cup", "result.status": "failure"}) + assert len(results) == 0 + + # Combination of direct and wildcard paths + results = populated_history.search( + **{"task.target_object": "bottle", "task.metadata.objects.*.position.z": 0.3} + ) + assert len(results) == 1 + assert results[0].task.task_id == "task2" diff --git a/dimos/manipulation/visual_servoing/detection3d.py b/dimos/manipulation/visual_servoing/detection3d.py new file mode 100644 index 0000000000..94b4660d2b --- /dev/null +++ b/dimos/manipulation/visual_servoing/detection3d.py @@ -0,0 +1,302 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Real-time 3D object detection processor that extracts object poses from RGB-D data. +""" + +import cv2 +from dimos_lcm.vision_msgs import ( # type: ignore[import-untyped] + BoundingBox2D, + BoundingBox3D, + Detection2D, + Detection3D, + ObjectHypothesis, + ObjectHypothesisWithPose, + Point2D, + Pose2D, +) +import numpy as np + +from dimos.manipulation.visual_servoing.utils import ( + estimate_object_depth, + transform_pose, + visualize_detections_3d, +) +from dimos.msgs.geometry_msgs import Pose, Quaternion, Vector3 +from dimos.msgs.std_msgs import Header +from dimos.msgs.vision_msgs import Detection2DArray, Detection3DArray +from dimos.perception.common.utils import bbox2d_to_corners +from dimos.perception.detection2d.utils import calculate_object_size_from_bbox +from dimos.perception.pointcloud.utils import extract_centroids_from_masks +from dimos.perception.segmentation.sam_2d_seg import Sam2DSegmenter +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class Detection3DProcessor: + """ + Real-time 3D detection processor optimized for speed. + + Uses Sam (FastSAM) for segmentation and mask generation, then extracts + 3D centroids from depth data. + """ + + def __init__( + self, + camera_intrinsics: list[float], # [fx, fy, cx, cy] + min_confidence: float = 0.6, + min_points: int = 30, + max_depth: float = 1.0, + max_object_size: float = 0.15, + ) -> None: + """ + Initialize the real-time 3D detection processor. + + Args: + camera_intrinsics: [fx, fy, cx, cy] camera parameters + min_confidence: Minimum detection confidence threshold + min_points: Minimum 3D points required for valid detection + max_depth: Maximum valid depth in meters + """ + self.camera_intrinsics = camera_intrinsics + self.min_points = min_points + self.max_depth = max_depth + self.max_object_size = max_object_size + + # Initialize Sam segmenter with tracking enabled but analysis disabled + self.detector = Sam2DSegmenter( + use_tracker=False, + use_analyzer=False, + use_filtering=True, + ) + + self.min_confidence = min_confidence + + logger.info( + f"Initialized Detection3DProcessor with Sam segmenter, confidence={min_confidence}, " + f"min_points={min_points}, max_depth={max_depth}m, max_object_size={max_object_size}m" + ) + + def process_frame( + self, + rgb_image: np.ndarray, # type: ignore[type-arg] + depth_image: np.ndarray, # type: ignore[type-arg] + transform: np.ndarray | None = None, # type: ignore[type-arg] + ) -> tuple[Detection3DArray, Detection2DArray]: + """ + Process a single RGB-D frame to extract 3D object detections. + + Args: + rgb_image: RGB image (H, W, 3) + depth_image: Depth image (H, W) in meters + transform: Optional 4x4 transformation matrix to transform objects from camera frame to desired frame + + Returns: + Tuple of (Detection3DArray, Detection2DArray) with 3D and 2D information + """ + + # Convert RGB to BGR for Sam (OpenCV format) + bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR) + + # Run Sam segmentation with tracking + masks, bboxes, track_ids, probs, names = self.detector.process_image(bgr_image) # type: ignore[no-untyped-call] + + if not masks or len(masks) == 0: + return Detection3DArray( + detections_length=0, header=Header(), detections=[] + ), Detection2DArray(detections_length=0, header=Header(), detections=[]) + + # Convert CUDA tensors to numpy arrays if needed + numpy_masks = [] + for mask in masks: + if hasattr(mask, "cpu"): # PyTorch tensor + numpy_masks.append(mask.cpu().numpy()) + else: # Already numpy array + numpy_masks.append(mask) + + # Extract 3D centroids from masks + poses = extract_centroids_from_masks( + rgb_image=rgb_image, + depth_image=depth_image, + masks=numpy_masks, + camera_intrinsics=self.camera_intrinsics, + ) + + detections_3d = [] + detections_2d = [] + pose_dict = {p["mask_idx"]: p for p in poses if p["centroid"][2] < self.max_depth} + + for i, (bbox, name, prob, track_id) in enumerate( + zip(bboxes, names, probs, track_ids, strict=False) + ): + if i not in pose_dict: + continue + + pose = pose_dict[i] + obj_cam_pos = pose["centroid"] + + if obj_cam_pos[2] > self.max_depth: + continue + + # Calculate object size from bbox and depth + width_m, height_m = calculate_object_size_from_bbox( + bbox, obj_cam_pos[2], self.camera_intrinsics + ) + + # Calculate depth dimension using segmentation mask + depth_m = estimate_object_depth( + depth_image, numpy_masks[i] if i < len(numpy_masks) else None, bbox + ) + + size_x = max(width_m, 0.01) # Minimum 1cm width + size_y = max(height_m, 0.01) # Minimum 1cm height + size_z = max(depth_m, 0.01) # Minimum 1cm depth + + if min(size_x, size_y, size_z) > self.max_object_size: + continue + + # Transform to desired frame if transform matrix is provided + if transform is not None: + # Get orientation as euler angles, default to no rotation if not available + obj_cam_orientation = pose.get( + "rotation", np.array([0.0, 0.0, 0.0]) + ) # Default to no rotation + transformed_pose = transform_pose( + obj_cam_pos, obj_cam_orientation, transform, to_robot=True + ) + center_pose = transformed_pose + else: + # If no transform, use camera coordinates + center_pose = Pose( + position=Vector3(obj_cam_pos[0], obj_cam_pos[1], obj_cam_pos[2]), + orientation=Quaternion(0.0, 0.0, 0.0, 1.0), # Default orientation + ) + + # Create Detection3D object + detection = Detection3D( + results_length=1, + header=Header(), # Empty header + results=[ + ObjectHypothesisWithPose( + hypothesis=ObjectHypothesis(class_id=name, score=float(prob)) + ) + ], + bbox=BoundingBox3D(center=center_pose, size=Vector3(size_x, size_y, size_z)), + id=str(track_id), + ) + + detections_3d.append(detection) + + # Create corresponding Detection2D + x1, y1, x2, y2 = bbox + center_x = (x1 + x2) / 2.0 + center_y = (y1 + y2) / 2.0 + width = x2 - x1 + height = y2 - y1 + + detection_2d = Detection2D( + results_length=1, + header=Header(), + results=[ + ObjectHypothesisWithPose( + hypothesis=ObjectHypothesis(class_id=name, score=float(prob)) + ) + ], + bbox=BoundingBox2D( + center=Pose2D(position=Point2D(center_x, center_y), theta=0.0), + size_x=float(width), + size_y=float(height), + ), + id=str(track_id), + ) + detections_2d.append(detection_2d) + + # Create and return both arrays + return ( + Detection3DArray( + detections_length=len(detections_3d), header=Header(), detections=detections_3d + ), + Detection2DArray( + detections_length=len(detections_2d), header=Header(), detections=detections_2d + ), + ) + + def visualize_detections( + self, + rgb_image: np.ndarray, # type: ignore[type-arg] + detections_3d: list[Detection3D], + detections_2d: list[Detection2D], + show_coordinates: bool = True, + ) -> np.ndarray: # type: ignore[type-arg] + """ + Visualize detections with 3D position overlay next to bounding boxes. + + Args: + rgb_image: Original RGB image + detections_3d: List of Detection3D objects + detections_2d: List of Detection2D objects (must be 1:1 correspondence) + show_coordinates: Whether to show 3D coordinates + + Returns: + Visualization image + """ + # Extract 2D bboxes from Detection2D objects + + bboxes_2d = [] + for det_2d in detections_2d: + if det_2d.bbox: + x1, y1, x2, y2 = bbox2d_to_corners(det_2d.bbox) + bboxes_2d.append([x1, y1, x2, y2]) + + return visualize_detections_3d(rgb_image, detections_3d, show_coordinates, bboxes_2d) + + def get_closest_detection( + self, detections: list[Detection3D], class_filter: str | None = None + ) -> Detection3D | None: + """ + Get the closest detection with valid 3D data. + + Args: + detections: List of Detection3D objects + class_filter: Optional class name to filter by + + Returns: + Closest Detection3D or None + """ + valid_detections = [] + for d in detections: + # Check if has valid bbox center position + if d.bbox and d.bbox.center and d.bbox.center.position: + # Check class filter if specified + if class_filter is None or ( + d.results_length > 0 and d.results[0].hypothesis.class_id == class_filter + ): + valid_detections.append(d) + + if not valid_detections: + return None + + # Sort by depth (Z coordinate) + def get_z_coord(d): # type: ignore[no-untyped-def] + return abs(d.bbox.center.position.z) + + return min(valid_detections, key=get_z_coord) + + def cleanup(self) -> None: + """Clean up resources.""" + if hasattr(self.detector, "cleanup"): + self.detector.cleanup() + logger.info("Detection3DProcessor cleaned up") diff --git a/dimos/manipulation/visual_servoing/manipulation_module.py b/dimos/manipulation/visual_servoing/manipulation_module.py new file mode 100644 index 0000000000..076d127ac8 --- /dev/null +++ b/dimos/manipulation/visual_servoing/manipulation_module.py @@ -0,0 +1,949 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Manipulation module for robotic grasping with visual servoing. +Handles grasping logic, state machine, and hardware coordination as a Dimos module. +""" + +from collections import deque +from enum import Enum +import threading +import time +from typing import Any + +import cv2 +from dimos_lcm.sensor_msgs import CameraInfo # type: ignore[import-untyped] +import numpy as np +from reactivex.disposable import Disposable + +from dimos.core import In, Module, Out, rpc +from dimos.hardware.piper_arm import PiperArm +from dimos.manipulation.visual_servoing.detection3d import Detection3DProcessor +from dimos.manipulation.visual_servoing.pbvs import PBVS +from dimos.manipulation.visual_servoing.utils import ( + create_manipulation_visualization, + is_target_reached, + select_points_from_depth, + transform_points_3d, + update_target_grasp_pose, +) +from dimos.msgs.geometry_msgs import Pose, Quaternion, Vector3 +from dimos.msgs.sensor_msgs import Image +from dimos.msgs.vision_msgs import Detection2DArray, Detection3DArray +from dimos.perception.common.utils import find_clicked_detection +from dimos.utils.logging_config import setup_logger +from dimos.utils.transform_utils import ( + compose_transforms, + create_transform_from_6dof, + matrix_to_pose, + pose_to_matrix, +) + +logger = setup_logger() + + +class GraspStage(Enum): + """Enum for different grasp stages.""" + + IDLE = "idle" + PRE_GRASP = "pre_grasp" + GRASP = "grasp" + CLOSE_AND_RETRACT = "close_and_retract" + PLACE = "place" + RETRACT = "retract" + + +class Feedback: + """Feedback data containing state information about the manipulation process.""" + + def __init__( + self, + grasp_stage: GraspStage, + target_tracked: bool, + current_executed_pose: Pose | None = None, + current_ee_pose: Pose | None = None, + current_camera_pose: Pose | None = None, + target_pose: Pose | None = None, + waiting_for_reach: bool = False, + success: bool | None = None, + ) -> None: + self.grasp_stage = grasp_stage + self.target_tracked = target_tracked + self.current_executed_pose = current_executed_pose + self.current_ee_pose = current_ee_pose + self.current_camera_pose = current_camera_pose + self.target_pose = target_pose + self.waiting_for_reach = waiting_for_reach + self.success = success + + +class ManipulationModule(Module): + """ + Manipulation module for visual servoing and grasping. + + Subscribes to: + - ZED RGB images + - ZED depth images + - ZED camera info + + Publishes: + - Visualization images + + RPC methods: + - handle_keyboard_command: Process keyboard input + - pick_and_place: Execute pick and place task + """ + + # LCM inputs + rgb_image: In[Image] = None # type: ignore[assignment] + depth_image: In[Image] = None # type: ignore[assignment] + camera_info: In[CameraInfo] = None # type: ignore[assignment] + + # LCM outputs + viz_image: Out[Image] = None # type: ignore[assignment] + + def __init__( # type: ignore[no-untyped-def] + self, + ee_to_camera_6dof: list | None = None, # type: ignore[type-arg] + **kwargs, + ) -> None: + """ + Initialize manipulation module. + + Args: + ee_to_camera_6dof: EE to camera transform [x, y, z, rx, ry, rz] in meters and radians + workspace_min_radius: Minimum workspace radius in meters + workspace_max_radius: Maximum workspace radius in meters + min_grasp_pitch_degrees: Minimum grasp pitch angle (at max radius) + max_grasp_pitch_degrees: Maximum grasp pitch angle (at min radius) + """ + super().__init__(**kwargs) + + self.arm = PiperArm() + + if ee_to_camera_6dof is None: + ee_to_camera_6dof = [-0.065, 0.03, -0.095, 0.0, -1.57, 0.0] + pos = Vector3(ee_to_camera_6dof[0], ee_to_camera_6dof[1], ee_to_camera_6dof[2]) + rot = Vector3(ee_to_camera_6dof[3], ee_to_camera_6dof[4], ee_to_camera_6dof[5]) + self.T_ee_to_camera = create_transform_from_6dof(pos, rot) + + self.camera_intrinsics = None + self.detector = None + self.pbvs = None + + # Control state + self.last_valid_target = None + self.waiting_for_reach = False + self.current_executed_pose = None # Track the actual pose sent to arm + self.target_updated = False + self.waiting_start_time = None + self.reach_pose_timeout = 20.0 + + # Grasp parameters + self.grasp_width_offset = 0.03 + self.pregrasp_distance = 0.25 + self.grasp_distance_range = 0.03 + self.grasp_close_delay = 2.0 + self.grasp_reached_time = None + self.gripper_max_opening = 0.07 + + # Workspace limits and dynamic pitch parameters + self.workspace_min_radius = 0.2 + self.workspace_max_radius = 0.75 + self.min_grasp_pitch_degrees = 5.0 + self.max_grasp_pitch_degrees = 60.0 + + # Grasp stage tracking + self.grasp_stage = GraspStage.IDLE + + # Pose stabilization tracking + self.pose_history_size = 4 + self.pose_stabilization_threshold = 0.01 + self.stabilization_timeout = 25.0 + self.stabilization_start_time = None + self.reached_poses = deque(maxlen=self.pose_history_size) # type: ignore[var-annotated] + self.adjustment_count = 0 + + # Pose reachability tracking + self.ee_pose_history = deque(maxlen=20) # type: ignore[var-annotated] # Keep history of EE poses + self.stuck_pose_threshold = 0.001 # 1mm movement threshold + self.stuck_pose_adjustment_degrees = 5.0 + self.stuck_count = 0 + self.max_stuck_reattempts = 7 + + # State for visualization + self.current_visualization = None + self.last_detection_3d_array = None + self.last_detection_2d_array = None + + # Grasp result and task tracking + self.pick_success = None + self.final_pregrasp_pose = None + self.task_failed = False + self.overall_success = None + + # Task control + self.task_running = False + self.task_thread = None + self.stop_event = threading.Event() + + # Latest sensor data + self.latest_rgb = None + self.latest_depth = None + self.latest_camera_info = None + + # Target selection + self.target_click = None + + # Place target position and object info + self.home_pose = Pose( + position=Vector3(0.0, 0.0, 0.0), orientation=Quaternion(0.0, 0.0, 0.0, 1.0) + ) + self.place_target_position = None + self.target_object_height = None + self.retract_distance = 0.12 + self.place_pose = None + self.retract_pose = None + self.arm.gotoObserve() + + @rpc + def start(self) -> None: + """Start the manipulation module.""" + + unsub = self.rgb_image.subscribe(self._on_rgb_image) + self._disposables.add(Disposable(unsub)) + + unsub = self.depth_image.subscribe(self._on_depth_image) + self._disposables.add(Disposable(unsub)) + + unsub = self.camera_info.subscribe(self._on_camera_info) + self._disposables.add(Disposable(unsub)) + + logger.info("Manipulation module started") + + @rpc + def stop(self) -> None: + """Stop the manipulation module.""" + # Stop any running task + self.stop_event.set() + if self.task_thread and self.task_thread.is_alive(): + self.task_thread.join(timeout=5.0) + + self.reset_to_idle() + + if self.detector and hasattr(self.detector, "cleanup"): + self.detector.cleanup() + self.arm.disable() + + logger.info("Manipulation module stopped") + + def _on_rgb_image(self, msg: Image) -> None: + """Handle RGB image messages.""" + try: + self.latest_rgb = msg.data + except Exception as e: + logger.error(f"Error processing RGB image: {e}") + + def _on_depth_image(self, msg: Image) -> None: + """Handle depth image messages.""" + try: + self.latest_depth = msg.data + except Exception as e: + logger.error(f"Error processing depth image: {e}") + + def _on_camera_info(self, msg: CameraInfo) -> None: + """Handle camera info messages.""" + try: + self.camera_intrinsics = [msg.K[0], msg.K[4], msg.K[2], msg.K[5]] # type: ignore[assignment] + + if self.detector is None: + self.detector = Detection3DProcessor(self.camera_intrinsics) # type: ignore[arg-type, assignment] + self.pbvs = PBVS() # type: ignore[assignment] + logger.info("Initialized detection and PBVS processors") + + self.latest_camera_info = msg + except Exception as e: + logger.error(f"Error processing camera info: {e}") + + @rpc + def get_single_rgb_frame(self) -> np.ndarray | None: # type: ignore[type-arg] + """ + get the latest rgb frame from the camera + """ + return self.latest_rgb + + @rpc + def handle_keyboard_command(self, key: str) -> str: + """ + Handle keyboard commands for robot control. + + Args: + key: Keyboard key as string + + Returns: + Action taken as string, or empty string if no action + """ + key_code = ord(key) if len(key) == 1 else int(key) + + if key_code == ord("r"): + self.stop_event.set() + self.task_running = False + self.reset_to_idle() + return "reset" + elif key_code == ord("s"): + logger.info("SOFT STOP - Emergency stopping robot!") + self.arm.softStop() + self.stop_event.set() + self.task_running = False + return "stop" + elif key_code == ord(" ") and self.pbvs and self.pbvs.target_grasp_pose: + if self.grasp_stage == GraspStage.PRE_GRASP: + self.set_grasp_stage(GraspStage.GRASP) + logger.info("Executing target pose") + return "execute" + elif key_code == ord("g"): + logger.info("Opening gripper") + self.arm.release_gripper() + return "release" + + return "" + + @rpc + def pick_and_place( + self, + target_x: int | None = None, + target_y: int | None = None, + place_x: int | None = None, + place_y: int | None = None, + ) -> dict[str, Any]: + """ + Start a pick and place task. + + Args: + target_x: Optional X coordinate of target object + target_y: Optional Y coordinate of target object + place_x: Optional X coordinate of place location + place_y: Optional Y coordinate of place location + + Returns: + Dict with status and message + """ + if self.task_running: + return {"status": "error", "message": "Task already running"} + + if self.camera_intrinsics is None: + return {"status": "error", "message": "Camera not initialized"} + + if target_x is not None and target_y is not None: + self.target_click = (target_x, target_y) + if place_x is not None and self.latest_depth is not None: + points_3d_camera = select_points_from_depth( + self.latest_depth, + (place_x, place_y), + self.camera_intrinsics, + radius=10, + ) + + if points_3d_camera.size > 0: + ee_pose = self.arm.get_ee_pose() + ee_transform = pose_to_matrix(ee_pose) + camera_transform = compose_transforms(ee_transform, self.T_ee_to_camera) + + points_3d_world = transform_points_3d( + points_3d_camera, + camera_transform, + to_robot=True, + ) + + place_position = np.mean(points_3d_world, axis=0) + self.place_target_position = place_position + logger.info( + f"Place target set at position: ({place_position[0]:.3f}, {place_position[1]:.3f}, {place_position[2]:.3f})" + ) + else: + logger.warning("No valid depth points found at place location") + self.place_target_position = None + else: + self.place_target_position = None + + self.task_failed = False + self.stop_event.clear() + + if self.task_thread and self.task_thread.is_alive(): + self.stop_event.set() + self.task_thread.join(timeout=1.0) + self.task_thread = threading.Thread(target=self._run_pick_and_place, daemon=True) + self.task_thread.start() + + return {"status": "started", "message": "Pick and place task started"} + + def _run_pick_and_place(self) -> None: + """Run the pick and place task loop.""" + self.task_running = True + logger.info("Starting pick and place task") + + try: + while not self.stop_event.is_set(): + if self.task_failed: + logger.error("Task failed, terminating pick and place") + self.stop_event.set() + break + + feedback = self.update() + if feedback is None: + time.sleep(0.01) + continue + + if feedback.success is not None: # type: ignore[attr-defined] + if feedback.success: # type: ignore[attr-defined] + logger.info("Pick and place completed successfully!") + else: + logger.warning("Pick and place failed") + self.reset_to_idle() + self.stop_event.set() + break + + time.sleep(0.01) + + except Exception as e: + logger.error(f"Error in pick and place task: {e}") + self.task_failed = True + finally: + self.task_running = False + logger.info("Pick and place task ended") + + def set_grasp_stage(self, stage: GraspStage) -> None: + """Set the grasp stage.""" + self.grasp_stage = stage + logger.info(f"Grasp stage: {stage.value}") + + def calculate_dynamic_grasp_pitch(self, target_pose: Pose) -> float: + """ + Calculate grasp pitch dynamically based on distance from robot base. + Maps workspace radius to grasp pitch angle. + + Args: + target_pose: Target pose + + Returns: + Grasp pitch angle in degrees + """ + # Calculate 3D distance from robot base (assumes robot at origin) + position = target_pose.position + distance = np.sqrt(position.x**2 + position.y**2 + position.z**2) + + # Clamp distance to workspace limits + distance = np.clip(distance, self.workspace_min_radius, self.workspace_max_radius) + + # Linear interpolation: min_radius -> max_pitch, max_radius -> min_pitch + # Normalized distance (0 to 1) + normalized_dist = (distance - self.workspace_min_radius) / ( + self.workspace_max_radius - self.workspace_min_radius + ) + + # Inverse mapping: closer objects need higher pitch + pitch_degrees = self.max_grasp_pitch_degrees - ( + normalized_dist * (self.max_grasp_pitch_degrees - self.min_grasp_pitch_degrees) + ) + + return pitch_degrees # type: ignore[no-any-return] + + def check_within_workspace(self, target_pose: Pose) -> bool: + """ + Check if pose is within workspace limits and log error if not. + + Args: + target_pose: Target pose to validate + + Returns: + True if within workspace, False otherwise + """ + # Calculate 3D distance from robot base + position = target_pose.position + distance = np.sqrt(position.x**2 + position.y**2 + position.z**2) + + if not (self.workspace_min_radius <= distance <= self.workspace_max_radius): + logger.error( + f"Target outside workspace limits: distance {distance:.3f}m not in [{self.workspace_min_radius:.2f}, {self.workspace_max_radius:.2f}]" + ) + return False + + return True + + def _check_reach_timeout(self) -> tuple[bool, float]: + """Check if robot has exceeded timeout while reaching pose. + + Returns: + Tuple of (timed_out, time_elapsed) + """ + if self.waiting_start_time: + time_elapsed = time.time() - self.waiting_start_time + if time_elapsed > self.reach_pose_timeout: + logger.warning( + f"Robot failed to reach pose within {self.reach_pose_timeout}s timeout" + ) + self.task_failed = True + self.reset_to_idle() + return True, time_elapsed + return False, time_elapsed + return False, 0.0 + + def _check_if_stuck(self) -> bool: + """Check if robot is stuck based on recent pose history. + + Returns: + True if stuck. False if moving or + not enough history to figure this out. + """ + if len(self.ee_pose_history) < self.ee_pose_history.maxlen: # type: ignore[operator] + return False + + # Extract positions from pose history + positions = np.array( + [[p.position.x, p.position.y, p.position.z] for p in self.ee_pose_history] + ) + + # Calculate standard deviation of positions + std_devs = np.std(positions, axis=0) + # Check if all standard deviations are below stuck threshold + is_stuck = np.all(std_devs < self.stuck_pose_threshold) + + return is_stuck # type: ignore[return-value] + + def check_reach_and_adjust(self) -> bool: + """ + Check if robot has reached the current executed pose while waiting. + Handles timeout internally by failing the task. + Also detects if the robot is stuck (not moving towards target). + + Returns: + True if reached, False if still waiting or not in waiting state + """ + if not self.waiting_for_reach or not self.current_executed_pose: + return False + + # Get current end-effector pose + ee_pose = self.arm.get_ee_pose() + target_pose = self.current_executed_pose + + # Check for timeout - this will fail task and reset if timeout occurred + timed_out, _time_elapsed = self._check_reach_timeout() + if timed_out: + return False + + self.ee_pose_history.append(ee_pose) + + # Check if robot is stuck + is_stuck = self._check_if_stuck() + if is_stuck: + if self.grasp_stage == GraspStage.RETRACT or self.grasp_stage == GraspStage.PLACE: + self.waiting_for_reach = False + self.waiting_start_time = None + self.stuck_count = 0 + self.ee_pose_history.clear() + return True + self.stuck_count += 1 + pitch_degrees = self.calculate_dynamic_grasp_pitch(target_pose) + if self.stuck_count % 2 == 0: + pitch_degrees += self.stuck_pose_adjustment_degrees * (1 + self.stuck_count // 2) + else: + pitch_degrees -= self.stuck_pose_adjustment_degrees * (1 + self.stuck_count // 2) + + pitch_degrees = max( + self.min_grasp_pitch_degrees, min(self.max_grasp_pitch_degrees, pitch_degrees) + ) + updated_target_pose = update_target_grasp_pose(target_pose, ee_pose, 0.0, pitch_degrees) + self.arm.cmd_ee_pose(updated_target_pose) + self.current_executed_pose = updated_target_pose + self.ee_pose_history.clear() + self.waiting_for_reach = True + self.waiting_start_time = time.time() + return False + + if self.stuck_count >= self.max_stuck_reattempts: + self.task_failed = True + self.reset_to_idle() + return False + + if is_target_reached(target_pose, ee_pose, self.pbvs.target_tolerance): + self.waiting_for_reach = False + self.waiting_start_time = None + self.stuck_count = 0 + self.ee_pose_history.clear() + return True + return False + + def _update_tracking(self, detection_3d_array: Detection3DArray | None) -> bool: + """Update tracking with new detections.""" + if not detection_3d_array or not self.pbvs: + return False + + target_tracked = self.pbvs.update_tracking(detection_3d_array) + if target_tracked: + self.target_updated = True + self.last_valid_target = self.pbvs.get_current_target() + return target_tracked + + def reset_to_idle(self) -> None: + """Reset the manipulation system to IDLE state.""" + if self.pbvs: + self.pbvs.clear_target() + self.grasp_stage = GraspStage.IDLE + self.reached_poses.clear() + self.ee_pose_history.clear() + self.adjustment_count = 0 + self.waiting_for_reach = False + self.current_executed_pose = None + self.target_updated = False + self.stabilization_start_time = None + self.grasp_reached_time = None + self.waiting_start_time = None + self.pick_success = None + self.final_pregrasp_pose = None + self.overall_success = None + self.place_pose = None + self.retract_pose = None + self.stuck_count = 0 + + self.arm.gotoObserve() + + def execute_idle(self) -> None: + """Execute idle stage.""" + pass + + def execute_pre_grasp(self) -> None: + """Execute pre-grasp stage: visual servoing to pre-grasp position.""" + if self.waiting_for_reach: + if self.check_reach_and_adjust(): + self.reached_poses.append(self.current_executed_pose) + self.target_updated = False + time.sleep(0.2) + return + if ( + self.stabilization_start_time + and (time.time() - self.stabilization_start_time) > self.stabilization_timeout + ): + logger.warning( + f"Failed to get stable grasp after {self.stabilization_timeout} seconds, resetting" + ) + self.task_failed = True + self.reset_to_idle() + return + + ee_pose = self.arm.get_ee_pose() # type: ignore[no-untyped-call] + dynamic_pitch = self.calculate_dynamic_grasp_pitch(self.pbvs.current_target.bbox.center) # type: ignore[attr-defined] + + _, _, _, has_target, target_pose = self.pbvs.compute_control( # type: ignore[attr-defined] + ee_pose, self.pregrasp_distance, dynamic_pitch + ) + if target_pose and has_target: + # Validate target pose is within workspace + if not self.check_within_workspace(target_pose): + self.task_failed = True + self.reset_to_idle() + return + + if self.check_target_stabilized(): + logger.info("Target stabilized, transitioning to GRASP") + self.final_pregrasp_pose = self.current_executed_pose + self.grasp_stage = GraspStage.GRASP + self.adjustment_count = 0 + self.waiting_for_reach = False + elif not self.waiting_for_reach and self.target_updated: + self.arm.cmd_ee_pose(target_pose) + self.current_executed_pose = target_pose + self.waiting_for_reach = True + self.waiting_start_time = time.time() # type: ignore[assignment] + self.target_updated = False + self.adjustment_count += 1 + time.sleep(0.2) + + def execute_grasp(self) -> None: + """Execute grasp stage: move to final grasp position.""" + if self.waiting_for_reach: + if self.check_reach_and_adjust() and not self.grasp_reached_time: + self.grasp_reached_time = time.time() # type: ignore[assignment] + return + + if self.grasp_reached_time: + if (time.time() - self.grasp_reached_time) >= self.grasp_close_delay: + logger.info("Grasp delay completed, closing gripper") + self.grasp_stage = GraspStage.CLOSE_AND_RETRACT + return + + if self.last_valid_target: + # Calculate dynamic pitch for current target + dynamic_pitch = self.calculate_dynamic_grasp_pitch(self.last_valid_target.bbox.center) + normalized_pitch = dynamic_pitch / 90.0 + grasp_distance = -self.grasp_distance_range + ( + 2 * self.grasp_distance_range * normalized_pitch + ) + + ee_pose = self.arm.get_ee_pose() + _, _, _, has_target, target_pose = self.pbvs.compute_control( + ee_pose, grasp_distance, dynamic_pitch + ) + + if target_pose and has_target: + # Validate grasp pose is within workspace + if not self.check_within_workspace(target_pose): + self.task_failed = True + self.reset_to_idle() + return + + object_width = self.last_valid_target.bbox.size.x + gripper_opening = max( + 0.005, min(object_width + self.grasp_width_offset, self.gripper_max_opening) + ) + + logger.info(f"Executing grasp: gripper={gripper_opening * 1000:.1f}mm") + self.arm.cmd_gripper_ctrl(gripper_opening) + self.arm.cmd_ee_pose(target_pose, line_mode=True) + self.current_executed_pose = target_pose + self.waiting_for_reach = True + self.waiting_start_time = time.time() + + def execute_close_and_retract(self) -> None: + """Execute the retraction sequence after gripper has been closed.""" + if self.waiting_for_reach and self.final_pregrasp_pose: + if self.check_reach_and_adjust(): + logger.info("Reached pre-grasp retraction position") + self.pick_success = self.arm.gripper_object_detected() + if self.pick_success: + logger.info("Object successfully grasped!") + if self.place_target_position is not None: + logger.info("Transitioning to PLACE stage") + self.grasp_stage = GraspStage.PLACE + else: + self.overall_success = True + else: + logger.warning("No object detected in gripper") + self.task_failed = True + self.overall_success = False + return + if not self.waiting_for_reach: + logger.info("Retracting to pre-grasp position") + self.arm.cmd_ee_pose(self.final_pregrasp_pose, line_mode=True) + self.current_executed_pose = self.final_pregrasp_pose + self.arm.close_gripper() + self.waiting_for_reach = True + self.waiting_start_time = time.time() # type: ignore[assignment] + + def execute_place(self) -> None: + """Execute place stage: move to place position and release object.""" + if self.waiting_for_reach: + # Use the already executed pose instead of recalculating + if self.check_reach_and_adjust(): + logger.info("Reached place position, releasing gripper") + self.arm.release_gripper() + time.sleep(1.0) + self.place_pose = self.current_executed_pose + logger.info("Transitioning to RETRACT stage") + self.grasp_stage = GraspStage.RETRACT + return + + if not self.waiting_for_reach: + place_pose = self.get_place_target_pose() + if place_pose: + logger.info("Moving to place position") + self.arm.cmd_ee_pose(place_pose, line_mode=True) + self.current_executed_pose = place_pose # type: ignore[assignment] + self.waiting_for_reach = True + self.waiting_start_time = time.time() # type: ignore[assignment] + else: + logger.error("Failed to get place target pose") + self.task_failed = True + self.overall_success = False # type: ignore[assignment] + + def execute_retract(self) -> None: + """Execute retract stage: retract from place position.""" + if self.waiting_for_reach and self.retract_pose: + if self.check_reach_and_adjust(): + logger.info("Reached retract position") + logger.info("Returning to observe position") + self.arm.gotoObserve() + self.arm.close_gripper() + self.overall_success = True + logger.info("Pick and place completed successfully!") + return + + if not self.waiting_for_reach: + if self.place_pose: + pose_pitch = self.calculate_dynamic_grasp_pitch(self.place_pose) + self.retract_pose = update_target_grasp_pose( + self.place_pose, self.home_pose, self.retract_distance, pose_pitch + ) + logger.info("Retracting from place position") + self.arm.cmd_ee_pose(self.retract_pose, line_mode=True) + self.current_executed_pose = self.retract_pose + self.waiting_for_reach = True + self.waiting_start_time = time.time() + else: + logger.error("No place pose stored for retraction") + self.task_failed = True + self.overall_success = False # type: ignore[assignment] + + def capture_and_process( + self, + ) -> tuple[np.ndarray | None, Detection3DArray | None, Detection2DArray | None, Pose | None]: # type: ignore[type-arg] + """Capture frame from camera data and process detections.""" + if self.latest_rgb is None or self.latest_depth is None or self.detector is None: + return None, None, None, None + + ee_pose = self.arm.get_ee_pose() + ee_transform = pose_to_matrix(ee_pose) + camera_transform = compose_transforms(ee_transform, self.T_ee_to_camera) + camera_pose = matrix_to_pose(camera_transform) + detection_3d_array, detection_2d_array = self.detector.process_frame( + self.latest_rgb, self.latest_depth, camera_transform + ) + + return self.latest_rgb, detection_3d_array, detection_2d_array, camera_pose + + def pick_target(self, x: int, y: int) -> bool: + """Select a target object at the given pixel coordinates.""" + if not self.last_detection_2d_array or not self.last_detection_3d_array: + logger.warning("No detections available for target selection") + return False + + clicked_3d = find_clicked_detection( + (x, y), self.last_detection_2d_array.detections, self.last_detection_3d_array.detections + ) + if clicked_3d and self.pbvs: + # Validate workspace + if not self.check_within_workspace(clicked_3d.bbox.center): + self.task_failed = True + return False + + self.pbvs.set_target(clicked_3d) + + if clicked_3d.bbox and clicked_3d.bbox.size: + self.target_object_height = clicked_3d.bbox.size.z + logger.info(f"Target object height: {self.target_object_height:.3f}m") + + position = clicked_3d.bbox.center.position + logger.info( + f"Target selected: ID={clicked_3d.id}, pos=({position.x:.3f}, {position.y:.3f}, {position.z:.3f})" + ) + self.grasp_stage = GraspStage.PRE_GRASP + self.reached_poses.clear() + self.adjustment_count = 0 + self.waiting_for_reach = False + self.current_executed_pose = None + self.stabilization_start_time = time.time() + return True + return False + + def update(self) -> dict[str, Any] | None: + """Main update function that handles capture, processing, control, and visualization.""" + rgb, detection_3d_array, detection_2d_array, camera_pose = self.capture_and_process() + if rgb is None: + return None + + self.last_detection_3d_array = detection_3d_array # type: ignore[assignment] + self.last_detection_2d_array = detection_2d_array # type: ignore[assignment] + if self.target_click: + x, y = self.target_click + if self.pick_target(x, y): + self.target_click = None + + if ( + detection_3d_array + and self.grasp_stage in [GraspStage.PRE_GRASP, GraspStage.GRASP] + and not self.waiting_for_reach + ): + self._update_tracking(detection_3d_array) + stage_handlers = { + GraspStage.IDLE: self.execute_idle, + GraspStage.PRE_GRASP: self.execute_pre_grasp, + GraspStage.GRASP: self.execute_grasp, + GraspStage.CLOSE_AND_RETRACT: self.execute_close_and_retract, + GraspStage.PLACE: self.execute_place, + GraspStage.RETRACT: self.execute_retract, + } + if self.grasp_stage in stage_handlers: + stage_handlers[self.grasp_stage]() + + target_tracked = self.pbvs.get_current_target() is not None if self.pbvs else False + ee_pose = self.arm.get_ee_pose() # type: ignore[no-untyped-call] + feedback = Feedback( + grasp_stage=self.grasp_stage, + target_tracked=target_tracked, + current_executed_pose=self.current_executed_pose, + current_ee_pose=ee_pose, + current_camera_pose=camera_pose, + target_pose=self.pbvs.target_grasp_pose if self.pbvs else None, + waiting_for_reach=self.waiting_for_reach, + success=self.overall_success, + ) + + if self.task_running: + self.current_visualization = create_manipulation_visualization( # type: ignore[assignment] + rgb, feedback, detection_3d_array, detection_2d_array + ) + + if self.current_visualization is not None: + self._publish_visualization(self.current_visualization) + + return feedback # type: ignore[return-value] + + def _publish_visualization(self, viz_image: np.ndarray) -> None: # type: ignore[type-arg] + """Publish visualization image to LCM.""" + try: + viz_rgb = cv2.cvtColor(viz_image, cv2.COLOR_BGR2RGB) + msg = Image.from_numpy(viz_rgb) + self.viz_image.publish(msg) + except Exception as e: + logger.error(f"Error publishing visualization: {e}") + + def check_target_stabilized(self) -> bool: + """Check if the commanded poses have stabilized.""" + if len(self.reached_poses) < self.reached_poses.maxlen: # type: ignore[operator] + return False + + positions = np.array( + [[p.position.x, p.position.y, p.position.z] for p in self.reached_poses] + ) + std_devs = np.std(positions, axis=0) + return np.all(std_devs < self.pose_stabilization_threshold) # type: ignore[return-value] + + def get_place_target_pose(self) -> Pose | None: + """Get the place target pose with z-offset applied based on object height.""" + if self.place_target_position is None: + return None + + place_pos = self.place_target_position.copy() + if self.target_object_height is not None: + z_offset = self.target_object_height / 2.0 + place_pos[2] += z_offset + 0.1 + + place_center_pose = Pose( + position=Vector3(place_pos[0], place_pos[1], place_pos[2]), + orientation=Quaternion(0.0, 0.0, 0.0, 1.0), + ) + + ee_pose = self.arm.get_ee_pose() + + # Calculate dynamic pitch for place position + dynamic_pitch = self.calculate_dynamic_grasp_pitch(place_center_pose) + + place_pose = update_target_grasp_pose( + place_center_pose, + ee_pose, + grasp_distance=0.0, + grasp_pitch_degrees=dynamic_pitch, + ) + + return place_pose diff --git a/dimos/manipulation/visual_servoing/pbvs.py b/dimos/manipulation/visual_servoing/pbvs.py new file mode 100644 index 0000000000..2f4914da99 --- /dev/null +++ b/dimos/manipulation/visual_servoing/pbvs.py @@ -0,0 +1,488 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Position-Based Visual Servoing (PBVS) system for robotic manipulation. +Supports both eye-in-hand and eye-to-hand configurations. +""" + +from collections import deque + +from dimos_lcm.vision_msgs import Detection3D # type: ignore[import-untyped] +import numpy as np +from scipy.spatial.transform import Rotation as R + +from dimos.manipulation.visual_servoing.utils import ( + create_pbvs_visualization, + find_best_object_match, + is_target_reached, + update_target_grasp_pose, +) +from dimos.msgs.geometry_msgs import Pose, Quaternion, Vector3 +from dimos.msgs.vision_msgs import Detection3DArray +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class PBVS: + """ + High-level Position-Based Visual Servoing orchestrator. + + Handles: + - Object tracking and target management + - Pregrasp distance computation + - Grasp pose generation + - Coordination with low-level controller + + Note: This class is agnostic to camera mounting (eye-in-hand vs eye-to-hand). + The caller is responsible for providing appropriate camera and EE poses. + """ + + def __init__( + self, + position_gain: float = 0.5, + rotation_gain: float = 0.3, + max_velocity: float = 0.1, # m/s + max_angular_velocity: float = 0.5, # rad/s + target_tolerance: float = 0.01, # 1cm + max_tracking_distance_threshold: float = 0.12, # Max distance for target tracking (m) + min_size_similarity: float = 0.6, # Min size similarity threshold (0.0-1.0) + direct_ee_control: bool = True, # If True, output target poses instead of velocities + ) -> None: + """ + Initialize PBVS system. + + Args: + position_gain: Proportional gain for position control + rotation_gain: Proportional gain for rotation control + max_velocity: Maximum linear velocity command magnitude (m/s) + max_angular_velocity: Maximum angular velocity command magnitude (rad/s) + target_tolerance: Distance threshold for considering target reached (m) + max_tracking_distance: Maximum distance for valid target tracking (m) + min_size_similarity: Minimum size similarity for valid target tracking (0.0-1.0) + direct_ee_control: If True, output target poses instead of velocity commands + """ + # Initialize low-level controller only if not in direct control mode + if not direct_ee_control: + self.controller = PBVSController( + position_gain=position_gain, + rotation_gain=rotation_gain, + max_velocity=max_velocity, + max_angular_velocity=max_angular_velocity, + target_tolerance=target_tolerance, + ) + else: + self.controller = None # type: ignore[assignment] + + # Store parameters for direct mode error computation + self.target_tolerance = target_tolerance + + # Target tracking parameters + self.max_tracking_distance_threshold = max_tracking_distance_threshold + self.min_size_similarity = min_size_similarity + self.direct_ee_control = direct_ee_control + + # Target state + self.current_target = None + self.target_grasp_pose = None + + # Detection history for robust tracking + self.detection_history_size = 3 + self.detection_history = deque(maxlen=self.detection_history_size) # type: ignore[var-annotated] + + # For direct control mode visualization + self.last_position_error = None + self.last_target_reached = False + + logger.info( + f"Initialized PBVS system with controller gains: pos={position_gain}, rot={rotation_gain}, " + f"tracking_thresholds: distance={max_tracking_distance_threshold}m, size={min_size_similarity:.2f}" + ) + + def set_target(self, target_object: Detection3D) -> bool: + """ + Set a new target object for servoing. + + Args: + target_object: Detection3D object + + Returns: + True if target was set successfully + """ + if target_object and target_object.bbox and target_object.bbox.center: + self.current_target = target_object + self.target_grasp_pose = None # Will be computed when needed + logger.info(f"New target set: ID {target_object.id}") + return True + return False + + def clear_target(self) -> None: + """Clear the current target.""" + self.current_target = None + self.target_grasp_pose = None + self.last_position_error = None + self.last_target_reached = False + self.detection_history.clear() + if self.controller: + self.controller.clear_state() + logger.info("Target cleared") + + def get_current_target(self) -> Detection3D | None: + """ + Get the current target object. + + Returns: + Current target Detection3D or None if no target selected + """ + return self.current_target + + def update_tracking(self, new_detections: Detection3DArray | None = None) -> bool: + """ + Update target tracking with new detections using a rolling window. + If tracking is lost, keeps the old target pose. + + Args: + new_detections: Optional new detections for target tracking + + Returns: + True if target was successfully tracked, False if lost (but target is kept) + """ + # Check if we have a current target + if not self.current_target: + return False + + # Add new detections to history if provided + if new_detections is not None and new_detections.detections_length > 0: + self.detection_history.append(new_detections) + + # If no detection history, can't track + if not self.detection_history: + logger.debug("No detection history for target tracking - using last known pose") + return False + + # Collect all candidates from detection history + all_candidates = [] + for detection_array in self.detection_history: + all_candidates.extend(detection_array.detections) + + if not all_candidates: + logger.debug("No candidates in detection history") + return False + + # Use stage-dependent distance threshold + max_distance = self.max_tracking_distance_threshold + + # Find best match across all recent detections + match_result = find_best_object_match( + target_obj=self.current_target, + candidates=all_candidates, + max_distance=max_distance, + min_size_similarity=self.min_size_similarity, + ) + + if match_result.is_valid_match: + self.current_target = match_result.matched_object + self.target_grasp_pose = None # Recompute grasp pose + logger.debug( + f"Target tracking successful: distance={match_result.distance:.3f}m, " + f"size_similarity={match_result.size_similarity:.2f}, " + f"confidence={match_result.confidence:.2f}" + ) + return True + + logger.debug( + f"Target tracking lost across {len(self.detection_history)} frames: " + f"distance={match_result.distance:.3f}m, " + f"size_similarity={match_result.size_similarity:.2f}, " + f"thresholds: distance={max_distance:.3f}m, size={self.min_size_similarity:.2f}" + ) + return False + + def compute_control( + self, + ee_pose: Pose, + grasp_distance: float = 0.15, + grasp_pitch_degrees: float = 45.0, + ) -> tuple[Vector3 | None, Vector3 | None, bool, bool, Pose | None]: + """ + Compute PBVS control with position and orientation servoing. + + Args: + ee_pose: Current end-effector pose + grasp_distance: Distance to maintain from target (meters) + + Returns: + Tuple of (velocity_command, angular_velocity_command, target_reached, has_target, target_pose) + - velocity_command: Linear velocity vector or None if no target (None in direct_ee_control mode) + - angular_velocity_command: Angular velocity vector or None if no target (None in direct_ee_control mode) + - target_reached: True if within target tolerance + - has_target: True if currently tracking a target + - target_pose: Target EE pose (only in direct_ee_control mode, otherwise None) + """ + # Check if we have a target + if not self.current_target: + return None, None, False, False, None + + # Update target grasp pose with provided distance and pitch + self.target_grasp_pose = update_target_grasp_pose( + self.current_target.bbox.center, ee_pose, grasp_distance, grasp_pitch_degrees + ) + + if self.target_grasp_pose is None: + logger.warning("Failed to compute grasp pose") + return None, None, False, False, None + + # Compute errors for visualization before checking if reached (in case pose gets cleared) + if self.direct_ee_control and self.target_grasp_pose: + self.last_position_error = Vector3( + self.target_grasp_pose.position.x - ee_pose.position.x, + self.target_grasp_pose.position.y - ee_pose.position.y, + self.target_grasp_pose.position.z - ee_pose.position.z, + ) + + # Check if target reached using our separate function + target_reached = is_target_reached(self.target_grasp_pose, ee_pose, self.target_tolerance) + + # Return appropriate values based on control mode + if self.direct_ee_control: + # Direct control mode + if self.target_grasp_pose: + self.last_target_reached = target_reached + # Return has_target=True since we have a target + return None, None, target_reached, True, self.target_grasp_pose + else: + return None, None, False, True, None + else: + # Velocity control mode - use controller + velocity_cmd, angular_velocity_cmd, _controller_reached = ( + self.controller.compute_control(ee_pose, self.target_grasp_pose) + ) + # Return has_target=True since we have a target, regardless of tracking status + return velocity_cmd, angular_velocity_cmd, target_reached, True, None + + def create_status_overlay( # type: ignore[no-untyped-def] + self, + image: np.ndarray, # type: ignore[type-arg] + grasp_stage=None, + ) -> np.ndarray: # type: ignore[type-arg] + """ + Create PBVS status overlay on image. + + Args: + image: Input image + grasp_stage: Current grasp stage (optional) + + Returns: + Image with PBVS status overlay + """ + stage_value = grasp_stage.value if grasp_stage else "idle" + return create_pbvs_visualization( + image, + self.current_target, + self.last_position_error, + self.last_target_reached, + stage_value, + ) + + +class PBVSController: + """ + Low-level Position-Based Visual Servoing controller. + Pure control logic that computes velocity commands from poses. + + Handles: + - Position and orientation error computation + - Velocity command generation with gain control + - Target reached detection + """ + + def __init__( + self, + position_gain: float = 0.5, + rotation_gain: float = 0.3, + max_velocity: float = 0.1, # m/s + max_angular_velocity: float = 0.5, # rad/s + target_tolerance: float = 0.01, # 1cm + ) -> None: + """ + Initialize PBVS controller. + + Args: + position_gain: Proportional gain for position control + rotation_gain: Proportional gain for rotation control + max_velocity: Maximum linear velocity command magnitude (m/s) + max_angular_velocity: Maximum angular velocity command magnitude (rad/s) + target_tolerance: Distance threshold for considering target reached (m) + """ + self.position_gain = position_gain + self.rotation_gain = rotation_gain + self.max_velocity = max_velocity + self.max_angular_velocity = max_angular_velocity + self.target_tolerance = target_tolerance + + self.last_position_error = None + self.last_rotation_error = None + self.last_velocity_cmd = None + self.last_angular_velocity_cmd = None + self.last_target_reached = False + + logger.info( + f"Initialized PBVS controller: pos_gain={position_gain}, rot_gain={rotation_gain}, " + f"max_vel={max_velocity}m/s, max_ang_vel={max_angular_velocity}rad/s, " + f"target_tolerance={target_tolerance}m" + ) + + def clear_state(self) -> None: + """Clear controller state.""" + self.last_position_error = None + self.last_rotation_error = None + self.last_velocity_cmd = None + self.last_angular_velocity_cmd = None + self.last_target_reached = False + + def compute_control( + self, ee_pose: Pose, grasp_pose: Pose + ) -> tuple[Vector3 | None, Vector3 | None, bool]: + """ + Compute PBVS control with position and orientation servoing. + + Args: + ee_pose: Current end-effector pose + grasp_pose: Target grasp pose + + Returns: + Tuple of (velocity_command, angular_velocity_command, target_reached) + - velocity_command: Linear velocity vector + - angular_velocity_command: Angular velocity vector + - target_reached: True if within target tolerance + """ + # Calculate position error (target - EE position) + error = Vector3( + grasp_pose.position.x - ee_pose.position.x, + grasp_pose.position.y - ee_pose.position.y, + grasp_pose.position.z - ee_pose.position.z, + ) + self.last_position_error = error # type: ignore[assignment] + + # Compute velocity command with proportional control + velocity_cmd = Vector3( + error.x * self.position_gain, + error.y * self.position_gain, + error.z * self.position_gain, + ) + + # Limit velocity magnitude + vel_magnitude = np.linalg.norm([velocity_cmd.x, velocity_cmd.y, velocity_cmd.z]) + if vel_magnitude > self.max_velocity: + scale = self.max_velocity / vel_magnitude + velocity_cmd = Vector3( + float(velocity_cmd.x * scale), + float(velocity_cmd.y * scale), + float(velocity_cmd.z * scale), + ) + + self.last_velocity_cmd = velocity_cmd # type: ignore[assignment] + + # Compute angular velocity for orientation control + angular_velocity_cmd = self._compute_angular_velocity(grasp_pose.orientation, ee_pose) + + # Check if target reached + error_magnitude = np.linalg.norm([error.x, error.y, error.z]) + target_reached = bool(error_magnitude < self.target_tolerance) + self.last_target_reached = target_reached + + return velocity_cmd, angular_velocity_cmd, target_reached + + def _compute_angular_velocity(self, target_rot: Quaternion, current_pose: Pose) -> Vector3: + """ + Compute angular velocity commands for orientation control. + Uses quaternion error computation for better numerical stability. + + Args: + target_rot: Target orientation (quaternion) + current_pose: Current EE pose + + Returns: + Angular velocity command as Vector3 + """ + # Use quaternion error for better numerical stability + + # Convert to scipy Rotation objects + target_rot_scipy = R.from_quat([target_rot.x, target_rot.y, target_rot.z, target_rot.w]) + current_rot_scipy = R.from_quat( + [ + current_pose.orientation.x, + current_pose.orientation.y, + current_pose.orientation.z, + current_pose.orientation.w, + ] + ) + + # Compute rotation error: error = target * current^(-1) + error_rot = target_rot_scipy * current_rot_scipy.inv() + + # Convert to axis-angle representation for control + error_axis_angle = error_rot.as_rotvec() + + # Use axis-angle directly as angular velocity error (small angle approximation) + roll_error = error_axis_angle[0] + pitch_error = error_axis_angle[1] + yaw_error = error_axis_angle[2] + + self.last_rotation_error = Vector3(roll_error, pitch_error, yaw_error) # type: ignore[assignment] + + # Apply proportional control + angular_velocity = Vector3( + roll_error * self.rotation_gain, + pitch_error * self.rotation_gain, + yaw_error * self.rotation_gain, + ) + + # Limit angular velocity magnitude + ang_vel_magnitude = np.sqrt( + angular_velocity.x**2 + angular_velocity.y**2 + angular_velocity.z**2 + ) + if ang_vel_magnitude > self.max_angular_velocity: + scale = self.max_angular_velocity / ang_vel_magnitude + angular_velocity = Vector3( + angular_velocity.x * scale, angular_velocity.y * scale, angular_velocity.z * scale + ) + + self.last_angular_velocity_cmd = angular_velocity # type: ignore[assignment] + + return angular_velocity + + def create_status_overlay( + self, + image: np.ndarray, # type: ignore[type-arg] + current_target: Detection3D | None = None, + ) -> np.ndarray: # type: ignore[type-arg] + """ + Create PBVS status overlay on image. + + Args: + image: Input image + current_target: Current target object Detection3D (for display) + + Returns: + Image with PBVS status overlay + """ + return create_pbvs_visualization( + image, + current_target, + self.last_position_error, + self.last_target_reached, + "velocity_control", + ) diff --git a/dimos/manipulation/visual_servoing/utils.py b/dimos/manipulation/visual_servoing/utils.py new file mode 100644 index 0000000000..0f6a82458c --- /dev/null +++ b/dimos/manipulation/visual_servoing/utils.py @@ -0,0 +1,801 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any + +import cv2 +from dimos_lcm.vision_msgs import Detection2D, Detection3D # type: ignore[import-untyped] +import numpy as np + +from dimos.msgs.geometry_msgs import Pose, Quaternion, Vector3 +from dimos.perception.common.utils import project_2d_points_to_3d +from dimos.perception.detection2d.utils import plot_results +from dimos.utils.transform_utils import ( + compose_transforms, + euler_to_quaternion, + get_distance, + matrix_to_pose, + offset_distance, + optical_to_robot_frame, + pose_to_matrix, + robot_to_optical_frame, + yaw_towards_point, +) + + +def match_detection_by_id( + detection_3d: Detection3D, detections_3d: list[Detection3D], detections_2d: list[Detection2D] +) -> Detection2D | None: + """ + Find the corresponding Detection2D for a given Detection3D. + + Args: + detection_3d: The Detection3D to match + detections_3d: List of all Detection3D objects + detections_2d: List of all Detection2D objects (must be 1:1 correspondence) + + Returns: + Corresponding Detection2D if found, None otherwise + """ + for i, det_3d in enumerate(detections_3d): + if det_3d.id == detection_3d.id and i < len(detections_2d): + return detections_2d[i] + return None + + +def transform_pose( + obj_pos: np.ndarray, # type: ignore[type-arg] + obj_orientation: np.ndarray, # type: ignore[type-arg] + transform_matrix: np.ndarray, # type: ignore[type-arg] + to_optical: bool = False, + to_robot: bool = False, +) -> Pose: + """ + Transform object pose with optional frame convention conversion. + + Args: + obj_pos: Object position [x, y, z] + obj_orientation: Object orientation [roll, pitch, yaw] in radians + transform_matrix: 4x4 transformation matrix from camera frame to desired frame + to_optical: If True, input is in robot frame → convert result to optical frame + to_robot: If True, input is in optical frame → convert to robot frame first + + Returns: + Object pose in desired frame as Pose + """ + # Convert euler angles to quaternion using utility function + euler_vector = Vector3(obj_orientation[0], obj_orientation[1], obj_orientation[2]) + obj_orientation_quat = euler_to_quaternion(euler_vector) + + input_pose = Pose( + position=Vector3(obj_pos[0], obj_pos[1], obj_pos[2]), orientation=obj_orientation_quat + ) + + # Apply input frame conversion based on flags + if to_robot: + # Input is in optical frame → convert to robot frame first + pose_for_transform = optical_to_robot_frame(input_pose) + else: + # Default or to_optical: use input pose as-is + pose_for_transform = input_pose + + # Create transformation matrix from pose (relative to camera) + T_camera_object = pose_to_matrix(pose_for_transform) + + # Use compose_transforms to combine transformations + T_desired_object = compose_transforms(transform_matrix, T_camera_object) + + # Convert back to pose + result_pose = matrix_to_pose(T_desired_object) + + # Apply output frame conversion based on flags + if to_optical: + # Input was robot frame → convert result to optical frame + desired_pose = robot_to_optical_frame(result_pose) + else: + # Default or to_robot: use result as-is + desired_pose = result_pose + + return desired_pose + + +def transform_points_3d( + points_3d: np.ndarray, # type: ignore[type-arg] + transform_matrix: np.ndarray, # type: ignore[type-arg] + to_optical: bool = False, + to_robot: bool = False, +) -> np.ndarray: # type: ignore[type-arg] + """ + Transform 3D points with optional frame convention conversion. + Applies the same transformation pipeline as transform_pose but for multiple points. + + Args: + points_3d: Nx3 array of 3D points [x, y, z] + transform_matrix: 4x4 transformation matrix from camera frame to desired frame + to_optical: If True, input is in robot frame → convert result to optical frame + to_robot: If True, input is in optical frame → convert to robot frame first + + Returns: + Nx3 array of transformed 3D points in desired frame + """ + if points_3d.size == 0: + return np.zeros((0, 3), dtype=np.float32) + + points_3d = np.asarray(points_3d) + if points_3d.ndim == 1: + points_3d = points_3d.reshape(1, -1) + + transformed_points = [] + + for point in points_3d: + input_point_pose = Pose( + position=Vector3(point[0], point[1], point[2]), + orientation=Quaternion(0.0, 0.0, 0.0, 1.0), # Identity quaternion + ) + + # Apply input frame conversion based on flags + if to_robot: + # Input is in optical frame → convert to robot frame first + pose_for_transform = optical_to_robot_frame(input_point_pose) + else: + # Default or to_optical: use input pose as-is + pose_for_transform = input_point_pose + + # Create transformation matrix from point pose (relative to camera) + T_camera_point = pose_to_matrix(pose_for_transform) + + # Use compose_transforms to combine transformations + T_desired_point = compose_transforms(transform_matrix, T_camera_point) + + # Convert back to pose + result_pose = matrix_to_pose(T_desired_point) + + # Apply output frame conversion based on flags + if to_optical: + # Input was robot frame → convert result to optical frame + desired_pose = robot_to_optical_frame(result_pose) + else: + # Default or to_robot: use result as-is + desired_pose = result_pose + + transformed_point = [ + desired_pose.position.x, + desired_pose.position.y, + desired_pose.position.z, + ] + transformed_points.append(transformed_point) + + return np.array(transformed_points, dtype=np.float32) + + +def select_points_from_depth( + depth_image: np.ndarray, # type: ignore[type-arg] + target_point: tuple[int, int], + camera_intrinsics: list[float] | np.ndarray, # type: ignore[type-arg] + radius: int = 5, +) -> np.ndarray: # type: ignore[type-arg] + """ + Select points around a target point within a bounding box and project them to 3D. + + Args: + depth_image: Depth image in meters (H, W) + target_point: (x, y) target point coordinates + radius: Half-width of the bounding box (so bbox size is radius*2 x radius*2) + camera_intrinsics: Camera parameters as [fx, fy, cx, cy] list or 3x3 matrix + + Returns: + Nx3 array of 3D points (X, Y, Z) in camera frame + """ + x_target, y_target = target_point + height, width = depth_image.shape + + x_min = max(0, x_target - radius) + x_max = min(width, x_target + radius) + y_min = max(0, y_target - radius) + y_max = min(height, y_target + radius) + + # Create coordinate grids for the bounding box (vectorized) + y_coords, x_coords = np.meshgrid(range(y_min, y_max), range(x_min, x_max), indexing="ij") + + # Flatten to get all coordinate pairs + x_flat = x_coords.flatten() + y_flat = y_coords.flatten() + + # Extract corresponding depth values using advanced indexing + depth_flat = depth_image[y_flat, x_flat] + + valid_mask = (depth_flat > 0) & np.isfinite(depth_flat) + + if not np.any(valid_mask): + return np.zeros((0, 3), dtype=np.float32) + + points_2d = np.column_stack([x_flat[valid_mask], y_flat[valid_mask]]).astype(np.float32) + depth_values = depth_flat[valid_mask].astype(np.float32) + + points_3d = project_2d_points_to_3d(points_2d, depth_values, camera_intrinsics) + + return points_3d + + +def update_target_grasp_pose( + target_pose: Pose, ee_pose: Pose, grasp_distance: float = 0.0, grasp_pitch_degrees: float = 45.0 +) -> Pose | None: + """ + Update target grasp pose based on current target pose and EE pose. + + Args: + target_pose: Target pose to grasp + ee_pose: Current end-effector pose + grasp_distance: Distance to maintain from target (pregrasp or grasp distance) + grasp_pitch_degrees: Grasp pitch angle in degrees (default 90° for top-down) + + Returns: + Target grasp pose or None if target is invalid + """ + + target_pos = target_pose.position + + # Calculate orientation pointing from target towards EE + yaw_to_ee = yaw_towards_point(target_pos, ee_pose.position) + + # Create target pose with proper orientation + # Convert grasp pitch from degrees to radians with mapping: + # 0° (level) -> π/2 (1.57 rad), 90° (top-down) -> π (3.14 rad) + pitch_radians = 1.57 + np.radians(grasp_pitch_degrees) + + # Convert euler angles to quaternion using utility function + euler = Vector3(0.0, pitch_radians, yaw_to_ee) # roll=0, pitch=mapped, yaw=calculated + target_orientation = euler_to_quaternion(euler) + + updated_pose = Pose(target_pos, target_orientation) + + if grasp_distance > 0.0: + return offset_distance(updated_pose, grasp_distance) + else: + return updated_pose + + +def is_target_reached(target_pose: Pose, current_pose: Pose, tolerance: float = 0.01) -> bool: + """ + Check if the target pose has been reached within tolerance. + + Args: + target_pose: Target pose to reach + current_pose: Current pose (e.g., end-effector pose) + tolerance: Distance threshold for considering target reached (meters, default 0.01 = 1cm) + + Returns: + True if target is reached within tolerance, False otherwise + """ + # Calculate position error using distance utility + error_magnitude = get_distance(target_pose, current_pose) + return error_magnitude < tolerance + + +@dataclass +class ObjectMatchResult: + """Result of object matching with confidence metrics.""" + + matched_object: Detection3D | None + confidence: float + distance: float + size_similarity: float + is_valid_match: bool + + +def calculate_object_similarity( + target_obj: Detection3D, + candidate_obj: Detection3D, + distance_weight: float = 0.6, + size_weight: float = 0.4, +) -> tuple[float, float, float]: + """ + Calculate comprehensive similarity between two objects. + + Args: + target_obj: Target Detection3D object + candidate_obj: Candidate Detection3D object + distance_weight: Weight for distance component (0-1) + size_weight: Weight for size component (0-1) + + Returns: + Tuple of (total_similarity, distance_m, size_similarity) + """ + # Extract positions + target_pos = target_obj.bbox.center.position + candidate_pos = candidate_obj.bbox.center.position + + target_xyz = np.array([target_pos.x, target_pos.y, target_pos.z]) + candidate_xyz = np.array([candidate_pos.x, candidate_pos.y, candidate_pos.z]) + + # Calculate Euclidean distance + distance = np.linalg.norm(target_xyz - candidate_xyz) + distance_similarity = 1.0 / (1.0 + distance) # Exponential decay + + # Calculate size similarity by comparing each dimension individually + size_similarity = 1.0 # Default if no size info + target_size = target_obj.bbox.size + candidate_size = candidate_obj.bbox.size + + if target_size and candidate_size: + # Extract dimensions + target_dims = [target_size.x, target_size.y, target_size.z] + candidate_dims = [candidate_size.x, candidate_size.y, candidate_size.z] + + # Calculate similarity for each dimension pair + dim_similarities = [] + for target_dim, candidate_dim in zip(target_dims, candidate_dims, strict=False): + if target_dim == 0.0 and candidate_dim == 0.0: + dim_similarities.append(1.0) # Both dimensions are zero + elif target_dim == 0.0 or candidate_dim == 0.0: + dim_similarities.append(0.0) # One dimension is zero, other is not + else: + # Calculate similarity as min/max ratio + max_dim = max(target_dim, candidate_dim) + min_dim = min(target_dim, candidate_dim) + dim_similarity = min_dim / max_dim if max_dim > 0 else 0.0 + dim_similarities.append(dim_similarity) + + # Return average similarity across all dimensions + size_similarity = np.mean(dim_similarities) if dim_similarities else 0.0 # type: ignore[assignment] + + # Weighted combination + total_similarity = distance_weight * distance_similarity + size_weight * size_similarity + + return total_similarity, distance, size_similarity # type: ignore[return-value] + + +def find_best_object_match( + target_obj: Detection3D, + candidates: list[Detection3D], + max_distance: float = 0.1, + min_size_similarity: float = 0.4, + distance_weight: float = 0.7, + size_weight: float = 0.3, +) -> ObjectMatchResult: + """ + Find the best matching object from candidates using distance and size criteria. + + Args: + target_obj: Target Detection3D to match against + candidates: List of candidate Detection3D objects + max_distance: Maximum allowed distance for valid match (meters) + min_size_similarity: Minimum size similarity for valid match (0-1) + distance_weight: Weight for distance in similarity calculation + size_weight: Weight for size in similarity calculation + + Returns: + ObjectMatchResult with best match and confidence metrics + """ + if not candidates or not target_obj.bbox or not target_obj.bbox.center: + return ObjectMatchResult(None, 0.0, float("inf"), 0.0, False) + + best_match = None + best_confidence = 0.0 + best_distance = float("inf") + best_size_sim = 0.0 + + for candidate in candidates: + if not candidate.bbox or not candidate.bbox.center: + continue + + similarity, distance, size_sim = calculate_object_similarity( + target_obj, candidate, distance_weight, size_weight + ) + + # Check validity constraints + is_valid = distance <= max_distance and size_sim >= min_size_similarity + + if is_valid and similarity > best_confidence: + best_match = candidate + best_confidence = similarity + best_distance = distance + best_size_sim = size_sim + + return ObjectMatchResult( + matched_object=best_match, + confidence=best_confidence, + distance=best_distance, + size_similarity=best_size_sim, + is_valid_match=best_match is not None, + ) + + +def parse_zed_pose(zed_pose_data: dict[str, Any]) -> Pose | None: + """ + Parse ZED pose data dictionary into a Pose object. + + Args: + zed_pose_data: Dictionary from ZEDCamera.get_pose() containing: + - position: [x, y, z] in meters + - rotation: [x, y, z, w] quaternion + - euler_angles: [roll, pitch, yaw] in radians + - valid: Whether pose is valid + + Returns: + Pose object with position and orientation, or None if invalid + """ + if not zed_pose_data or not zed_pose_data.get("valid", False): + return None + + # Extract position + position = zed_pose_data.get("position", [0, 0, 0]) + pos_vector = Vector3(position[0], position[1], position[2]) + + quat = zed_pose_data["rotation"] + orientation = Quaternion(quat[0], quat[1], quat[2], quat[3]) + return Pose(position=pos_vector, orientation=orientation) + + +def estimate_object_depth( + depth_image: np.ndarray, # type: ignore[type-arg] + segmentation_mask: np.ndarray | None, # type: ignore[type-arg] + bbox: list[float], +) -> float: + """ + Estimate object depth dimension using segmentation mask and depth data. + Optimized for real-time performance. + + Args: + depth_image: Depth image in meters + segmentation_mask: Binary segmentation mask for the object + bbox: Bounding box [x1, y1, x2, y2] + + Returns: + Estimated object depth in meters + """ + x1, y1, x2, y2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3]) + + # Extract depth ROI once + roi_depth = depth_image[y1:y2, x1:x2] + + if segmentation_mask is not None and segmentation_mask.size > 0: + # Extract mask ROI efficiently + mask_roi = ( + segmentation_mask[y1:y2, x1:x2] + if segmentation_mask.shape != roi_depth.shape + else segmentation_mask + ) + + # Fast mask application using boolean indexing + valid_mask = mask_roi > 0 + if np.sum(valid_mask) > 10: # Early exit if not enough points + masked_depths = roi_depth[valid_mask] + + # Fast percentile calculation using numpy's optimized functions + depth_90 = np.percentile(masked_depths, 90) + depth_10 = np.percentile(masked_depths, 10) + depth_range = depth_90 - depth_10 + + # Clamp to reasonable bounds with single operation + return np.clip(depth_range, 0.02, 0.5) # type: ignore[no-any-return] + + # Fast fallback using area calculation + bbox_area = (x2 - x1) * (y2 - y1) + + # Vectorized area-based estimation + if bbox_area > 10000: + return 0.15 + elif bbox_area > 5000: + return 0.10 + else: + return 0.05 + + +# ============= Visualization Functions ============= + + +def create_manipulation_visualization( # type: ignore[no-untyped-def] + rgb_image: np.ndarray, # type: ignore[type-arg] + feedback, + detection_3d_array=None, + detection_2d_array=None, +) -> np.ndarray: # type: ignore[type-arg] + """ + Create simple visualization for manipulation class using feedback. + + Args: + rgb_image: RGB image array + feedback: Feedback object containing all state information + detection_3d_array: Optional 3D detections for object visualization + detection_2d_array: Optional 2D detections for object visualization + + Returns: + BGR image with visualization overlays + """ + # Convert to BGR for OpenCV + viz = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR) + + # Draw detections if available + if detection_3d_array and detection_2d_array: + # Extract 2D bboxes + bboxes_2d = [] + for det_2d in detection_2d_array.detections: + if det_2d.bbox: + x1 = det_2d.bbox.center.position.x - det_2d.bbox.size_x / 2 + y1 = det_2d.bbox.center.position.y - det_2d.bbox.size_y / 2 + x2 = det_2d.bbox.center.position.x + det_2d.bbox.size_x / 2 + y2 = det_2d.bbox.center.position.y + det_2d.bbox.size_y / 2 + bboxes_2d.append([x1, y1, x2, y2]) + + # Draw basic detections + rgb_with_detections = visualize_detections_3d( + rgb_image, detection_3d_array.detections, show_coordinates=True, bboxes_2d=bboxes_2d + ) + viz = cv2.cvtColor(rgb_with_detections, cv2.COLOR_RGB2BGR) + + # Add manipulation status overlay + status_y = 30 + cv2.putText( + viz, + "Eye-in-Hand Visual Servoing", + (10, status_y), + cv2.FONT_HERSHEY_SIMPLEX, + 0.6, + (0, 255, 255), + 2, + ) + + # Stage information + stage_text = f"Stage: {feedback.grasp_stage.value.upper()}" + stage_color = { + "idle": (100, 100, 100), + "pre_grasp": (0, 255, 255), + "grasp": (0, 255, 0), + "close_and_retract": (255, 0, 255), + "place": (0, 150, 255), + "retract": (255, 150, 0), + }.get(feedback.grasp_stage.value, (255, 255, 255)) + + cv2.putText( + viz, + stage_text, + (10, status_y + 25), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + stage_color, + 1, + ) + + # Target tracking status + if feedback.target_tracked: + cv2.putText( + viz, + "Target: TRACKED", + (10, status_y + 45), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + (0, 255, 0), + 1, + ) + elif feedback.grasp_stage.value != "idle": + cv2.putText( + viz, + "Target: LOST", + (10, status_y + 45), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + (0, 0, 255), + 1, + ) + + # Waiting status + if feedback.waiting_for_reach: + cv2.putText( + viz, + "Status: WAITING FOR ROBOT", + (10, status_y + 65), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + (255, 255, 0), + 1, + ) + + # Overall result + if feedback.success is not None: + result_text = "Pick & Place: SUCCESS" if feedback.success else "Pick & Place: FAILED" + result_color = (0, 255, 0) if feedback.success else (0, 0, 255) + cv2.putText( + viz, + result_text, + (10, status_y + 85), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + result_color, + 2, + ) + + # Control hints (bottom of image) + hint_text = "Click object to grasp | s=STOP | r=RESET | g=RELEASE" + cv2.putText( + viz, + hint_text, + (10, viz.shape[0] - 10), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + (200, 200, 200), + 1, + ) + + return viz + + +def create_pbvs_visualization( # type: ignore[no-untyped-def] + image: np.ndarray, # type: ignore[type-arg] + current_target=None, + position_error=None, + target_reached: bool = False, + grasp_stage: str = "idle", +) -> np.ndarray: # type: ignore[type-arg] + """ + Create simple PBVS visualization overlay. + + Args: + image: Input image (RGB or BGR) + current_target: Current target Detection3D + position_error: Position error Vector3 + target_reached: Whether target is reached + grasp_stage: Current grasp stage string + + Returns: + Image with PBVS overlay + """ + viz = image.copy() + + # Only show PBVS info if we have a target + if current_target is None: + return viz + + # Create status panel at bottom + height, width = viz.shape[:2] + panel_height = 100 + panel_y = height - panel_height + + # Semi-transparent overlay + overlay = viz.copy() + cv2.rectangle(overlay, (0, panel_y), (width, height), (0, 0, 0), -1) + viz = cv2.addWeighted(viz, 0.7, overlay, 0.3, 0) + + # PBVS Status + y_offset = panel_y + 20 + cv2.putText( + viz, + "PBVS Control", + (10, y_offset), + cv2.FONT_HERSHEY_SIMPLEX, + 0.6, + (0, 255, 255), + 2, + ) + + # Position error + if position_error: + error_mag = np.linalg.norm([position_error.x, position_error.y, position_error.z]) + error_text = f"Error: {error_mag * 100:.1f}cm" + error_color = (0, 255, 0) if target_reached else (0, 255, 255) + cv2.putText( + viz, + error_text, + (10, y_offset + 25), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + error_color, + 1, + ) + + # Stage + cv2.putText( + viz, + f"Stage: {grasp_stage}", + (10, y_offset + 45), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + (255, 150, 255), + 1, + ) + + # Target reached indicator + if target_reached: + cv2.putText( + viz, + "TARGET REACHED", + (width - 150, y_offset + 25), + cv2.FONT_HERSHEY_SIMPLEX, + 0.6, + (0, 255, 0), + 2, + ) + + return viz + + +def visualize_detections_3d( + rgb_image: np.ndarray, # type: ignore[type-arg] + detections: list[Detection3D], + show_coordinates: bool = True, + bboxes_2d: list[list[float]] | None = None, +) -> np.ndarray: # type: ignore[type-arg] + """ + Visualize detections with 3D position overlay next to bounding boxes. + + Args: + rgb_image: Original RGB image + detections: List of Detection3D objects + show_coordinates: Whether to show 3D coordinates next to bounding boxes + bboxes_2d: Optional list of 2D bounding boxes corresponding to detections + + Returns: + Visualization image + """ + if not detections: + return rgb_image.copy() + + # If no 2D bboxes provided, skip visualization + if bboxes_2d is None: + return rgb_image.copy() + + # Extract data for plot_results function + bboxes = bboxes_2d + track_ids = [int(det.id) if det.id.isdigit() else i for i, det in enumerate(detections)] + class_ids = [i for i in range(len(detections))] + confidences = [ + det.results[0].hypothesis.score if det.results_length > 0 else 0.0 for det in detections + ] + names = [ + det.results[0].hypothesis.class_id if det.results_length > 0 else "unknown" + for det in detections + ] + + # Use plot_results for basic visualization + viz = plot_results(rgb_image, bboxes, track_ids, class_ids, confidences, names) + + # Add 3D position coordinates if requested + if show_coordinates and bboxes_2d is not None: + for i, det in enumerate(detections): + if det.bbox and det.bbox.center and i < len(bboxes_2d): + position = det.bbox.center.position + bbox = bboxes_2d[i] + + pos_xyz = np.array([position.x, position.y, position.z]) + + # Get bounding box coordinates + _x1, y1, x2, _y2 = map(int, bbox) + + # Add position text next to bounding box (top-right corner) + pos_text = f"({pos_xyz[0]:.2f}, {pos_xyz[1]:.2f}, {pos_xyz[2]:.2f})" + text_x = x2 + 5 # Right edge of bbox + small offset + text_y = y1 + 15 # Top edge of bbox + small offset + + # Add background rectangle for better readability + text_size = cv2.getTextSize(pos_text, cv2.FONT_HERSHEY_SIMPLEX, 0.4, 1)[0] + cv2.rectangle( + viz, + (text_x - 2, text_y - text_size[1] - 2), + (text_x + text_size[0] + 2, text_y + 2), + (0, 0, 0), + -1, + ) + + cv2.putText( + viz, + pos_text, + (text_x, text_y), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + (255, 255, 255), + 1, + ) + + return viz # type: ignore[no-any-return] diff --git a/dimos/data/recording.py b/dimos/mapping/__init__.py similarity index 100% rename from dimos/data/recording.py rename to dimos/mapping/__init__.py diff --git a/dimos/mapping/google_maps/conftest.py b/dimos/mapping/google_maps/conftest.py new file mode 100644 index 0000000000..725100bcc8 --- /dev/null +++ b/dimos/mapping/google_maps/conftest.py @@ -0,0 +1,38 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from pathlib import Path + +import pytest + +from dimos.mapping.google_maps.google_maps import GoogleMaps + +_FIXTURE_DIR = Path(__file__).parent / "fixtures" + + +@pytest.fixture +def maps_client(mocker): + ret = GoogleMaps() + ret._client = mocker.MagicMock() + return ret + + +@pytest.fixture +def maps_fixture(): + def open_file(relative: str) -> str: + with open(_FIXTURE_DIR / relative) as f: + return json.load(f) + + return open_file diff --git a/dimos/mapping/google_maps/fixtures/get_location_context_places_nearby.json b/dimos/mapping/google_maps/fixtures/get_location_context_places_nearby.json new file mode 100644 index 0000000000..9196eaadee --- /dev/null +++ b/dimos/mapping/google_maps/fixtures/get_location_context_places_nearby.json @@ -0,0 +1,965 @@ +{ + "html_attributions": [], + "next_page_token": "AciIO2fBDpHRl2XoG9zreRkt9prSCk9LDy3sxfc-6uK7JcTxGSvbWY-XX87H38Pr547AkGKiHbzLzhvJxo99ZgbyGYP-9On6WhEFfvtiSnxWrLbz3V7Cfwpi_2GYt1TMeAqGnGlhFev1--1WgmfBnapSl95c7Myuh4Yby8UM34rMAWh9Md-T9DOVExJuqunnZMrS2ViNa1IRyboIu9ixrNTNYJXQ6hoSVlkM26Yw2sJB900sQFiChr_FrDIP6dbdIzZMZ3si7-3CFrR4gy6Y6wlyeVEiriGye9cFi8U0d0BprgdSIHC3hmp-pG8qtOHvn5tXJp6bDvU12hvRL32D4FFxgM1xKHqGdrun3N06tW2G_XuXZww3voN-bZh2y5y8ubZRJbcLjZQ-rpMUKVsfNPbdVYYPgV0oiLA8IlPQkbF5MM4M", + "results": [ + { + "geometry": { + "location": { + "lat": 37.7749295, + "lng": -122.4194155 + }, + "viewport": { + "northeast": { + "lat": 37.812, + "lng": -122.3482 + }, + "southwest": { + "lat": 37.70339999999999, + "lng": -122.527 + } + } + }, + "icon": "https://maps.gstatic.com/mapfiles/place_api/icons/v1/png_71/geocode-71.png", + "icon_background_color": "#7B9EB0", + "icon_mask_base_uri": "https://maps.gstatic.com/mapfiles/place_api/icons/v2/generic_pinlet", + "name": "San Francisco", + "photos": [ + { + "height": 675, + "html_attributions": [ + "Zameer Dalvi" + ], + "photo_reference": "AciIO2d9Esuu4AjK5SCX_Byk2t2jNOCJ1TkBc9V7So6HH2AjHH7SccRs-n7fGxN2bdQdm_t-jrSdyt7rmGoPil2-_phu5dXOszGmOG6HITWPRmQOajaPG4WrTQvAV6BCs5RGFq3NxJZ-uFyHCT472OFg15-d-iytsU_nKWjPuX1xCwmNDmuWxTc8YBWi05Cf0MxIFsVw7oj5gaHvGFx0ngYJlk67Jwl6vOTIBiEHfseOHkGhkMD7tX-RCPBhnaAUgGXRbuawYXkiu32c9RhxRaXReyFE_TtX09yqvmA6zr9WhaCLT0vTt4-KMOxpoACBnVt7gYVvRk-FWUXBiHISzppFi6o7FbEW4OE4WWsAXSFamzI5Z5Co9cAb8BTPZX8P3E-tZiWyoOb1WyhqjpGPKYsa7YJ_SRLFMI3kv8GWOb744A4t-3kLBIgZQi9nE5M4cfqmMmdofXLEct9srvrDVEjKns5kP3yp94xrV9205rGcqMtQ3rcQWhl62pLDxf3iEahwvxV-adcMVmaPjLFCrPiUCT1xKtBtRSQDjPcuUMBPaZ-7ylCuFvJLSEaEt8WpDiSDbn22NiuM0hPqu8tqL7hJpxsXPi6fLCreITtMwCBK_sS_-3C--VNxDhyAIAdjA3iOPnTtIw", + "width": 1080 + } + ], + "place_id": "ChIJIQBpAG2ahYAR_6128GcTUEo", + "reference": "ChIJIQBpAG2ahYAR_6128GcTUEo", + "scope": "GOOGLE", + "types": [ + "locality", + "political" + ], + "vicinity": "San Francisco" + }, + { + "business_status": "OPERATIONAL", + "geometry": { + "location": { + "lat": 37.7795744, + "lng": -122.4137147 + }, + "viewport": { + "northeast": { + "lat": 37.78132539999999, + "lng": -122.41152835 + }, + "southwest": { + "lat": 37.777981, + "lng": -122.41572655 + } + } + }, + "icon": "https://maps.gstatic.com/mapfiles/place_api/icons/v1/png_71/generic_business-71.png", + "icon_background_color": "#7B9EB0", + "icon_mask_base_uri": "https://maps.gstatic.com/mapfiles/place_api/icons/v2/generic_pinlet", + "name": "Civic Center / UN Plaza", + "photos": [ + { + "height": 3072, + "html_attributions": [ + "Neal" + ], + "photo_reference": "AciIO2eQ5UqsRXvWTmbnbL9VjSIH-1SXRtU1k0UuJlVEyM_giS9ELQ-M4rjAF2wkan-7aE2l4yFtF4QmTEvORdTaj_lgO-_r9nTF2z7FKAFGcFxLL4wff1BD2NRu1cfYVWvStgOkdKGbZOmqKEpSU7qoFM_GjUdLO5ztvMCAJ8_h0-3VDy33ha8hGIa8AGuLhpitRAsRK9sztugTtxtaOruuuTtagZdfpyIvUjW1pJMCR3thLaWO2C4DVElGqhv4tynPVByugRqINceswryUNVh1yf_TD664L6AyyqjIL5Vv2583bIEefWHB3uEYJA2ohOV2YW_XhH5rY8Xg5Rdy6i8EUtW9GiVH694YHIgDEZsT-Or4uw_OHHYANd3z7MuQmLZ_JzyUCr8_ex8qxfzluml2bkfciWx3cqJ7YzodaED5nvzjffEuKXwp8cIz5cWF-xm1XSbTWZK5dafqVTC83ps9wDvoCmkPY2lXOgXhmTv85VTQNe8nj75LsplDo73CPg4XFRi6fZi-oicmtCjdjzpjUTHbHe3PEGB1F11BOPh_Hx8QkZlbWwIFooJc9FF8dgAh1GQzlwYb93tcPmRLAiaunw-h9F3eKDb7YghwBPtiBh6HygyNMnA4gtqdBd_qGQ6rVt9cLGCz", + "width": 4080 + } + ], + "place_id": "ChIJYTKuRpuAhYAR8O67wA_IE9s", + "plus_code": { + "compound_code": "QHHP+RG Mid-Market, San Francisco, CA, USA", + "global_code": "849VQHHP+RG" + }, + "rating": 3.5, + "reference": "ChIJYTKuRpuAhYAR8O67wA_IE9s", + "scope": "GOOGLE", + "types": [ + "subway_station", + "transit_station", + "point_of_interest", + "establishment" + ], + "user_ratings_total": 375, + "vicinity": "1150 Market Street, San Francisco" + }, + { + "business_status": "OPERATIONAL", + "geometry": { + "location": { + "lat": 37.7802611, + "lng": -122.4145017 + }, + "viewport": { + "northeast": { + "lat": 37.7817168802915, + "lng": -122.4131737197085 + }, + "southwest": { + "lat": 37.7790189197085, + "lng": -122.4158716802915 + } + } + }, + "icon": "https://maps.gstatic.com/mapfiles/place_api/icons/v1/png_71/civic_building-71.png", + "icon_background_color": "#7B9EB0", + "icon_mask_base_uri": "https://maps.gstatic.com/mapfiles/place_api/icons/v2/civic-bldg_pinlet", + "name": "U.S. General Services Administration - Pacific Rim Region", + "opening_hours": { + "open_now": false + }, + "photos": [ + { + "height": 2448, + "html_attributions": [ + "Wai Ki Wong" + ], + "photo_reference": "AciIO2cN35fxs7byGa6qiTiJAxxJMorGoHDJp95RMFDnTMm-wDrb0QUZbujgJUBIV3uLQuBDpEdvyxxzc-fyzT3DgFlJSLKcnPcm_A-Fe3cj7rjPdEO9VMj0HHRf0aqDnRQXmtv2Ouh3QUH8OdvaoOlNMw293LOxjri9JvpjhPHCwJvwkKjxFYButiE_7XywtIRyQXRkZyDKxqKVxITircGB1P3efABFUQIye8hA71QZqTfYnBzT5wDSoV3oZRaB9aXUlTDGzNl3rJXE74BrlpgVhf-uYP_POcNqMbYmLXyWOjjVEZ4YZL58Ls53etW_ZUGGeiUAcrI3Uuq4glX5GRfGHssf_dqOWA29j0HZh6A_OFSluLSDbpy-HgXcW4Zg_qgF6XqobV78J_Ira4m8lgHiT3nDffo2YfELDcIvFxOJwpl1W3TUWawmHqvHiVTvHAQ_8-TcWE_rGCVIAAc8I0W25qRFngkVJ828ZIMHsnEiLLgsKTQlxKW94uAC8kgxh6v-iXP_7vP6-0aWGkFs4a2irwfQK5n5fKmDz7LBdVjyuAhoHwcCwE8VTn0wtwUcuiVCVBFs4-AnLWhwnVxf3fdmcMsZm91lPbm3fECbnt6SBhvXR48cM_ZZpMiyfIF1QuNE-vhfsnlK", + "width": 3264 + } + ], + "place_id": "ChIJZxSUVZeAhYAReWcieluNDvY", + "plus_code": { + "compound_code": "QHJP+45 Civic Center, San Francisco, CA, USA", + "global_code": "849VQHJP+45" + }, + "rating": 2, + "reference": "ChIJZxSUVZeAhYAReWcieluNDvY", + "scope": "GOOGLE", + "types": [ + "point_of_interest", + "establishment" + ], + "user_ratings_total": 4, + "vicinity": "50 United Nations Plaza, San Francisco" + }, + { + "business_status": "OPERATIONAL", + "geometry": { + "location": { + "lat": 37.7801589, + "lng": -122.4143371 + }, + "viewport": { + "northeast": { + "lat": 37.7818405302915, + "lng": -122.4131042697085 + }, + "southwest": { + "lat": 37.7791425697085, + "lng": -122.4158022302915 + } + } + }, + "icon": "https://maps.gstatic.com/mapfiles/place_api/icons/v1/png_71/civic_building-71.png", + "icon_background_color": "#7B9EB0", + "icon_mask_base_uri": "https://maps.gstatic.com/mapfiles/place_api/icons/v2/civic-bldg_pinlet", + "name": "Federal Office Building", + "opening_hours": { + "open_now": false + }, + "photos": [ + { + "height": 3024, + "html_attributions": [ + "espresso" + ], + "photo_reference": "AciIO2eg880rvWAbnaX5xUzP_b6dEVp4hOnZqnxo7W_1S2BZwdC0H9io5KptUm2MGue3FOw3KWPjTZeVu8B_gnFh-5EyAhJHqhlDrllLsL8K-cumkjtTDT3mxDaDXeU7XB9BWD7S0g0f4qbjEu_sKhvWAXE81_r1W5I8minbMbvzu3eU1sYICwWOk_5g4D1-690I_4V-4aJ-fDD04kHxsqkweZcxzUHgrmcKEOlt48UKVHe-GEOLD5-BRNZ3k4tx50T1SKqPeNUI_WtTrYkSkeNzCp4t9680YqCW7LBsES9viJdW_QBTgQd59gvMeIWEXQ-YBGPEobIS0hE73Eedi_1ATESgKI-tzOeeoeytLnmFFVC8c2obgt2Bd7cLOFjIjm5Oxn9jH0auBWPx8JsQifkXiyhXz2VP2AawCmID4TMtMwt-9ozTV6I_j5f_guI34w7MxKnHiyTQvupi0S4O2ByezHx56M7Ptmxjk8yia84SG20H7sRhEk3yeQHl_ujDGYhNFCtPmHWkCsdWm1go-FuMalIzkUL4ERuREN1hhdvYhswbbigJUG8mKKOBzHuPVLNK5KFs_N7E5l4g3v-drOKe1m_GafTHwQDRvEzJfL0UnIERhRYcRLMJWxeEbjtsnKch", + "width": 4032 + } + ], + "place_id": "ChIJzdTUHpuAhYAR3ZHR1a8TJ-k", + "plus_code": { + "compound_code": "QHJP+37 Civic Center, San Francisco, CA, USA", + "global_code": "849VQHJP+37" + }, + "rating": 4.2, + "reference": "ChIJzdTUHpuAhYAR3ZHR1a8TJ-k", + "scope": "GOOGLE", + "types": [ + "point_of_interest", + "establishment" + ], + "user_ratings_total": 5, + "vicinity": "50 United Nations Plaza, San Francisco" + }, + { + "business_status": "OPERATIONAL", + "geometry": { + "location": { + "lat": 37.7799364, + "lng": -122.4147625 + }, + "viewport": { + "northeast": { + "lat": 37.78122733029149, + "lng": -122.4136141697085 + }, + "southwest": { + "lat": 37.7785293697085, + "lng": -122.4163121302915 + } + } + }, + "icon": "https://maps.gstatic.com/mapfiles/place_api/icons/v1/png_71/civic_building-71.png", + "icon_background_color": "#7B9EB0", + "icon_mask_base_uri": "https://maps.gstatic.com/mapfiles/place_api/icons/v2/civic-bldg_pinlet", + "name": "UN Plaza", + "opening_hours": { + "open_now": false + }, + "photos": [ + { + "height": 3024, + "html_attributions": [ + "Douglas Cheung" + ], + "photo_reference": "AciIO2f7vWVfkBpMV-nKU0k06pZS--irdg7JnnJBrztgXtRf0MYFd0085Gfjm7TjBB4bCefkTdJBsNtyKiklgknHCuWhz3aqwx81XHDM51Jn-g5wI0hbG6dx8RpheFxfht_vpk9CQgjjg8mFEUp-aQaEc3hivi_bog295AUmEKdhTCRlYWLQJFPEpP-AKOpLwXdKYAjddd2nh18x9p8-gF0WphREBQFaOChd9lnWyuSKX-MOecG-ff1Brwpkcroc6VUeW6z1RQcLFNCUOomOpBCmeujvTquM_bI7a6T4WzM2o6Et_47EXmPzJhSAONorX8epNNHjZspoAd-LZ_PrBgy8H-WQEm6vlY88Dtc1Sucewnrv4Cd8xm2I1ywKPSsd2mgYBMVAipSS2XHuufe5FWzZM9vPZonW0Vb-X6HOAnVeQ52ZxNddc5pjDtU5GOZNb2oF-uLwo5-qrplZDryO5if0CPQRzE6iRbO9xLsWV0S7MGmxJ_bZk7nxWXjKAFNITIZ6dQcGJxuWH_LKDsF3Sfbg1emM4Xdujx0ZHhgFcBISAfHjX5hf0kBxGhpMlFIPxRns2Eng4HzTaebZAmMeqDoN_3KlnAof47SQyeLSQNy1K6PjWGrIPfaVOpubOTLJF_dLKt5pxQ", + "width": 4032 + } + ], + "place_id": "ChIJ60hDVZeAhYAReuCqOWYsr_k", + "plus_code": { + "compound_code": "QHHP+X3 Civic Center, San Francisco, CA, USA", + "global_code": "849VQHHP+X3" + }, + "rating": 4, + "reference": "ChIJ60hDVZeAhYAReuCqOWYsr_k", + "scope": "GOOGLE", + "types": [ + "city_hall", + "point_of_interest", + "local_government_office", + "establishment" + ], + "user_ratings_total": 428, + "vicinity": "355 McAllister Street, San Francisco" + }, + { + "business_status": "OPERATIONAL", + "geometry": { + "location": { + "lat": 37.781006, + "lng": -122.4143741 + }, + "viewport": { + "northeast": { + "lat": 37.78226673029149, + "lng": -122.4129892697085 + }, + "southwest": { + "lat": 37.7795687697085, + "lng": -122.4156872302915 + } + } + }, + "icon": "https://maps.gstatic.com/mapfiles/place_api/icons/v1/png_71/shopping-71.png", + "icon_background_color": "#4B96F3", + "icon_mask_base_uri": "https://maps.gstatic.com/mapfiles/place_api/icons/v2/shopping_pinlet", + "name": "McAllister Market & Deli", + "opening_hours": { + "open_now": false + }, + "photos": [ + { + "height": 4608, + "html_attributions": [ + "Asteria Moore" + ], + "photo_reference": "AciIO2chI9JnbQNwZt2yo7E--ruAq6ax7U4NrW_3PcNpGgFzXhxMqvYTtktvSLwFO5k21vHpEH-2AMYuaD6qctoIYdyt_g5EWhF88Ptb75HmmIEQzMqk2Ktpe3Vx06TnJKF47TZnQupjVdy_YTW3XGOGkA33Phe8I3I9szr54QqmYLFs6fPJMxo-M3keen9PlFiqqjvKAV170CuJ6HQ70AkRREWq3h18IcPUHHEKiZng5TKPSB7t_3dbyB_DWETnVQHu6P33XEmcKw77rgCuUogyxXZNMBulq305-FtBlH5lnvjy1F5Hpwf-q5cSB_40p082Joz0Vyazc1o4s-hnEyUnaQ6Zra1B_ODKvHqEKHoeJUKT4nAfFU4kBE5A7nmxkozqyks4MfaoN_P72atAhggEV5rog4EEtzFyeC1bx8GtQKhYccbeANSF5R9mAEpeefOrpYZpNW1uLffUMOpceZpZtNsE-yG59_v-56V1dxqCIGW9KOtVmfoEL0WLP6l-pMhKMv3EdSRmGqhbRtCA2fZNyFBWRyMwpfToRImtYxRbMiqriGONDU1e1m8j895QvLDknS6lY_qRMNv4YY3FLooGcag4YzcaDHwtI-ipxEcFknzhIIYt-_fdlTcUk0JMctC5re--5A", + "width": 2592 + } + ], + "place_id": "ChIJz3oI4ZqAhYARYviYtbeKIFQ", + "plus_code": { + "compound_code": "QHJP+C7 Civic Center, San Francisco, CA, USA", + "global_code": "849VQHJP+C7" + }, + "rating": 3.6, + "reference": "ChIJz3oI4ZqAhYARYviYtbeKIFQ", + "scope": "GOOGLE", + "types": [ + "liquor_store", + "atm", + "grocery_or_supermarket", + "finance", + "point_of_interest", + "food", + "store", + "establishment" + ], + "user_ratings_total": 12, + "vicinity": "136 McAllister Street, San Francisco" + }, + { + "business_status": "OPERATIONAL", + "geometry": { + "location": { + "lat": 37.7802423, + "lng": -122.4145234 + }, + "viewport": { + "northeast": { + "lat": 37.78171363029151, + "lng": -122.4131986197085 + }, + "southwest": { + "lat": 37.77901566970851, + "lng": -122.4158965802915 + } + } + }, + "icon": "https://maps.gstatic.com/mapfiles/place_api/icons/v1/png_71/civic_building-71.png", + "icon_background_color": "#7B9EB0", + "icon_mask_base_uri": "https://maps.gstatic.com/mapfiles/place_api/icons/v2/civic-bldg_pinlet", + "name": "US Health & Human Services Department/Office of the Regional Director", + "opening_hours": { + "open_now": false + }, + "photos": [ + { + "height": 1200, + "html_attributions": [ + "Patrick Berkeley" + ], + "photo_reference": "AciIO2eP4AmtKmitmIZbdY4bI2mc8aCNzT2vh8plui7wj0BJt-51HlfW7-arowozWM9Os9hSUBkXItcmlXnH08GpOYXc1u6gN-XmO7AL9ifSJfgWYt6XE0CkXfQ9iBQdHF1WFlfteWLOvL0mev0reMuAz78N7It7eWQY8HW3nm2_i14G_R51kbRK2djxoWjDqY9-xP5hTxWUs1u7JFqXtzOZAeMGlhFHHmqVe4A8nWMP7tr6Y385wmCIJvGwXivQmct7flmN6NpNqqp1U5CI1jy60x7Z2Zoq_uxzWpIB-1M-VRMJHblbb_1rPAc1Sg29n5XfhX4E1M1YqlEBdqg08VaqQSLbaJEHkvfDMFKlN36IsZmb8mZfFEinYSmkcISO6x-vuhgR7G4FJZLtt74goVGKIPsQoC9oPsPyN0mLaQJs9ZTS6D2mw5zIQXYBs2IfBdnG9sWDCQTujtdGWJv_SlWUHW499I-NK0MzNPjpLB4FW3dYOuqDQdk-8hzC1A5giSjr7J783WRLVhVKjfo8G8vCPCSY4JW6x3XB5bl9IJn5j_47sGhJOrHnHVkNaMmJMtdhGflXwT42-i033uzLJEGN1e887Jqe7OHRHqa97oPbXu3FQgVPjXvdBX33gmXc8XXeDg7gcQ", + "width": 1600 + } + ], + "place_id": "ChIJ84fbMZuAhYARravvIpQYCY8", + "plus_code": { + "compound_code": "QHJP+35 Civic Center, San Francisco, CA, USA", + "global_code": "849VQHJP+35" + }, + "rating": 4, + "reference": "ChIJ84fbMZuAhYARravvIpQYCY8", + "scope": "GOOGLE", + "types": [ + "local_government_office", + "health", + "point_of_interest", + "establishment" + ], + "user_ratings_total": 1, + "vicinity": "San Francisco Federal Building, 90 7th Street #5, San Francisco" + }, + { + "business_status": "OPERATIONAL", + "geometry": { + "location": { + "lat": 37.7794949, + "lng": -122.414318 + }, + "viewport": { + "northeast": { + "lat": 37.78079848029149, + "lng": -122.4128637197085 + }, + "southwest": { + "lat": 37.7781005197085, + "lng": -122.4155616802915 + } + } + }, + "icon": "https://maps.gstatic.com/mapfiles/place_api/icons/v1/png_71/school-71.png", + "icon_background_color": "#7B9EB0", + "icon_mask_base_uri": "https://maps.gstatic.com/mapfiles/place_api/icons/v2/school_pinlet", + "name": "Oasis For Girls", + "opening_hours": { + "open_now": false + }, + "photos": [ + { + "height": 3024, + "html_attributions": [ + "Alex" + ], + "photo_reference": "AciIO2cENrSmK967GV0iLgnIakOvEMavm9r5kA_LjIOHIji_Pc0T74VL-vwiFlUgoVgetRw9B-PzYrJ54EVfnbUQT-9XRi2LGt9rUOGX6V7h7lOVqgEJ1eaWEUtTDyk93eQRs3cc3GhXY2RIjL-nVdaxkwRc_RWpRPLcc8Om_aTYwyCQ5S7ZpmxPS419DoCJHt4sQJqzRsD6gz7I8AGj0c03MHYascQn4efsvFhjzaPex21ZKI9iGz923oe9WM8zq4BhgKJ3B9_IITYDuoO1mYdyIgU57ceuRoKb6n4zoCgyhLne1_SzGnFz7DrP9jL8luHSVHeoZcSKmU34Gr-sGfVs4kfH33lzlNurHQI6gIoOOWOXq7BTP-Jf5ArqGexfQfue7IGJpYjR4p5r4cJZ-dd0tzhlGvrZ2cSEnjQdv4oTx3U3kElm6foWI3xySsa1jmqsZ8BBBzEQ75rzHHhsW26xwwR9ZIKYV-_DZ9r0hrb0qPCEF3aAC9r2m6rfwrHWAfDy_-Egmv_5T1QyBFaAUT0Faay7EezCxCyWwx_0x0o2DRIOAcA8a01veJJPv1LhYcXCUnTgIATbSr-t30d9FdosyX0Vk9w4eSXU6B4qUWpusHVHPShTHhAcLMig0OOIXlZyyWtPT2sb", + "width": 4032 + } + ], + "place_id": "ChIJyTuyEoKAhYARr0GnPKZSGCk", + "plus_code": { + "compound_code": "QHHP+Q7 Civic Center, San Francisco, CA, USA", + "global_code": "849VQHHP+Q7" + }, + "rating": 5, + "reference": "ChIJyTuyEoKAhYARr0GnPKZSGCk", + "scope": "GOOGLE", + "types": [ + "point_of_interest", + "establishment" + ], + "user_ratings_total": 4, + "vicinity": "1170 Market Street, San Francisco" + }, + { + "business_status": "OPERATIONAL", + "geometry": { + "location": { + "lat": 37.77929669999999, + "lng": -122.4143825 + }, + "viewport": { + "northeast": { + "lat": 37.78060218029149, + "lng": -122.4129812697085 + }, + "southwest": { + "lat": 37.77790421970849, + "lng": -122.4156792302915 + } + } + }, + "icon": "https://maps.gstatic.com/mapfiles/place_api/icons/v1/png_71/generic_business-71.png", + "icon_background_color": "#7B9EB0", + "icon_mask_base_uri": "https://maps.gstatic.com/mapfiles/place_api/icons/v2/generic_pinlet", + "name": "San Francisco Culinary Bartenders & Service Employees Trust Funds", + "place_id": "ChIJpS60CuyBt4cRzO3UB4vL3L0", + "plus_code": { + "compound_code": "QHHP+P6 Civic Center, San Francisco, CA, USA", + "global_code": "849VQHHP+P6" + }, + "rating": 3.3, + "reference": "ChIJpS60CuyBt4cRzO3UB4vL3L0", + "scope": "GOOGLE", + "types": [ + "point_of_interest", + "establishment" + ], + "user_ratings_total": 6, + "vicinity": "1182 Market Street #320, San Francisco" + }, + { + "business_status": "CLOSED_TEMPORARILY", + "geometry": { + "location": { + "lat": 37.7801722, + "lng": -122.4140068 + }, + "viewport": { + "northeast": { + "lat": 37.7817733302915, + "lng": -122.4129124197085 + }, + "southwest": { + "lat": 37.7790753697085, + "lng": -122.4156103802915 + } + } + }, + "icon": "https://maps.gstatic.com/mapfiles/place_api/icons/v1/png_71/civic_building-71.png", + "icon_background_color": "#7B9EB0", + "icon_mask_base_uri": "https://maps.gstatic.com/mapfiles/place_api/icons/v2/civic-bldg_pinlet", + "name": "San Francisco Federal Executive Board", + "permanently_closed": true, + "photos": [ + { + "height": 943, + "html_attributions": [ + "San Francisco Federal Executive Board" + ], + "photo_reference": "AciIO2ecs5V8ZC8IEmpnMKdhn2pSWsCYSZ6C9Zf6lnQbp3owjaXeXRZuPMtnIJag_ga0uw8Jwa8SB-Wsb2YyB9PrdAzutETaYb56zja6D8NwiKdf9Z4EGnZ45JH20x7119EzrunOm1q4Ii6wuY0TudtYsadmJC0NPLnUZlua4PNnW7Zl76OQwLBcaPWu6rXBHCTT6iiBqSZeKiKJ8w4RzttHfN3oYB-IE02CXQPQX1xxFEeQ5cyuGPtv8ghXHRoSJdhvYDH_P0aSrOt9ibRtrH5kv7nAamKSVUNWvT5vuPrXao9PkaJd5f16tZiDoM_61tat9r1izspBFhU", + "width": 943 + } + ], + "place_id": "ChIJu4Q_XDqBhYARojXRyiKC12g", + "plus_code": { + "compound_code": "QHJP+39 Civic Center, San Francisco, CA, USA", + "global_code": "849VQHJP+39" + }, + "reference": "ChIJu4Q_XDqBhYARojXRyiKC12g", + "scope": "GOOGLE", + "types": [ + "point_of_interest", + "establishment" + ], + "vicinity": "50 United Nations Plaza, San Francisco" + }, + { + "business_status": "OPERATIONAL", + "geometry": { + "location": { + "lat": 37.779756, + "lng": -122.41415 + }, + "viewport": { + "northeast": { + "lat": 37.78130935000001, + "lng": -122.411308 + }, + "southwest": { + "lat": 37.77806635, + "lng": -122.4163908 + } + } + }, + "icon": "https://maps.gstatic.com/mapfiles/place_api/icons/v1/png_71/generic_business-71.png", + "icon_background_color": "#7B9EB0", + "icon_mask_base_uri": "https://maps.gstatic.com/mapfiles/place_api/icons/v2/generic_pinlet", + "name": "Civic Center/UN Plaza BART Station", + "photos": [ + { + "height": 4032, + "html_attributions": [ + "Arthur Glauberman" + ], + "photo_reference": "AciIO2f1VMpAIRJouUVjkeEUyHB-4jzRZ2_U3kfRr-LaavcPlVYClnn2DMGMiWo9Oun0t-qo9z5WIHp1BQBHazbPqrWnSGvQoO3FpJMra0OOGSgrpsD5T4dvinfSzWqwOOlRtMyQ4vlGvR99TpxcNVcasRyNflpZxRcYD9nBUPnrNUstxTCfKqSqLdYD3ZI0xZiX3wOJ_hlUVgRfSs04iqzREGvRR8cZRaufh1Hakq3bzaBL1KGuLF8ggV94iGQmzWYmU_FddWgH9ZhjGyMPi8LYdNmypH0fBenoYGVE_bUV9dWqh5dFIKDwCyxkbIseJ6Z49MRFnSEFTtBr02xVz7Q1vAx0iKSRAMof3o5dqEd5Y1fVhDuLk3KT5JisNQZd_yWXDflaHmEgjEqza7uTrdR6LWysHDD8EdUrGQxWWHmneyc3qdWlc0TBxhGp3Q8V0a3Ian1k75PqrfkyC_IITP0KIDmaylgMSMmAQbzvkeHDtPcibG-BiNn2FNK7T77m7GpQkubMwYOI1PkoGSmveiuooTTqj6PSDGrQdDfRllk_HSwcTnd9csLazAQP_tLKHX8lsHTtTE7Orkcf8IEUfmV35Ltx2HzLYytejCYYS7ZoSfgjDTZUOY41QQ-YS0tIDKHpgr_PJqtT", + "width": 3024 + } + ], + "place_id": "ChIJK0jeP5uAhYARcxPNUpvfc7A", + "plus_code": { + "compound_code": "QHHP+W8 Civic Center, San Francisco, CA, USA", + "global_code": "849VQHHP+W8" + }, + "rating": 3.5, + "reference": "ChIJK0jeP5uAhYARcxPNUpvfc7A", + "scope": "GOOGLE", + "types": [ + "transit_station", + "point_of_interest", + "establishment" + ], + "user_ratings_total": 2, + "vicinity": "United States" + }, + { + "business_status": "OPERATIONAL", + "geometry": { + "location": { + "lat": 37.779989, + "lng": -122.4138743 + }, + "viewport": { + "northeast": { + "lat": 37.7811369802915, + "lng": -122.4131672197085 + }, + "southwest": { + "lat": 37.7784390197085, + "lng": -122.4158651802915 + } + } + }, + "icon": "https://maps.gstatic.com/mapfiles/place_api/icons/v1/png_71/generic_business-71.png", + "icon_background_color": "#7B9EB0", + "icon_mask_base_uri": "https://maps.gstatic.com/mapfiles/place_api/icons/v2/generic_pinlet", + "name": "UN Skate Plaza", + "place_id": "ChIJR4ivYwCBhYAR2xEDgcXd8oE", + "plus_code": { + "compound_code": "QHHP+XF Civic Center, San Francisco, CA, USA", + "global_code": "849VQHHP+XF" + }, + "reference": "ChIJR4ivYwCBhYAR2xEDgcXd8oE", + "scope": "GOOGLE", + "types": [ + "point_of_interest", + "establishment" + ], + "vicinity": "1484 Market Street, San Francisco" + }, + { + "business_status": "OPERATIONAL", + "geometry": { + "location": { + "lat": 37.7798254, + "lng": -122.4149907 + }, + "viewport": { + "northeast": { + "lat": 37.7811608302915, + "lng": -122.4137199197085 + }, + "southwest": { + "lat": 37.77846286970851, + "lng": -122.4164178802915 + } + } + }, + "icon": "https://maps.gstatic.com/mapfiles/place_api/icons/v1/png_71/generic_business-71.png", + "icon_background_color": "#7B9EB0", + "icon_mask_base_uri": "https://maps.gstatic.com/mapfiles/place_api/icons/v2/generic_pinlet", + "name": "Curry Without Worry", + "opening_hours": { + "open_now": false + }, + "photos": [ + { + "height": 3024, + "html_attributions": [ + "Sterling Gerard" + ], + "photo_reference": "AciIO2cHQr4ENxn9-409JJPj5hKunwLPi9gn-eN4W0X85UOvQVHoKQBUA4AotH3pkFTPxm1X76omOi2jbTiRSL9-eRFhA9wWpiXoSj2ggXeHrUxLMQBZb7cQuH4lg9YCOasXwXz3-e3H1lrByl7en3XSTkvuZUDrbtHocGV-0XNw2YpOmVvN-mLcRxgUpWhguLsvnO7B5JzXjz4ewOAxBLF9f-ZOdRktRcHDczoA0zYsOFwri0CXVjfYdB4HxjwXBPm1vXQY1U5qRydrI0Eru1tbTI9alsrmBOL4l0BAY--_fd3luNnwiQAYHzBJoZ7pqHjGOHtHa-OH7GFawpbxKr8MqeT3KVMcDVWm8sOy-zd2Gjbez5CQ5ld0w-q_2QDTVzHV5ybrzDm1OIl4vIW9eBTQVwkBwnmUjKFSZEQ-ANezOwN6XfW_jkWleRJ28dpXLo25dhW7gmYZxRcGpPwWRpcH3jyenU59CRJ6EG8nqVhTs-JzGOawmsLs4Kyg4f16fJE2lDTySU82fcQgd8uBkJGE-XrFYNOakpMWBKo1GWNOvfPsceoyB4qiLwf7VFM5Sa8yQUmNxdKRvVvhqCRjzGwVQmcPEOgpANBuDTUdz9VscmOhPO_29jRMca1S9AuseiZBdmRO4HHv", + "width": 4032 + } + ], + "place_id": "ChIJKZtFDpuAhYAR7xKvaP5D1dI", + "plus_code": { + "compound_code": "QHHP+W2 Civic Center, San Francisco, CA, USA", + "global_code": "849VQHHP+W2" + }, + "rating": 4.7, + "reference": "ChIJKZtFDpuAhYAR7xKvaP5D1dI", + "scope": "GOOGLE", + "types": [ + "point_of_interest", + "establishment" + ], + "user_ratings_total": 14, + "vicinity": "50 United Nations Plaza, San Francisco" + }, + { + "business_status": "OPERATIONAL", + "geometry": { + "location": { + "lat": 37.7798179, + "lng": -122.4149928 + }, + "viewport": { + "northeast": { + "lat": 37.7811602302915, + "lng": -122.4137218697085 + }, + "southwest": { + "lat": 37.7784622697085, + "lng": -122.4164198302915 + } + } + }, + "icon": "https://maps.gstatic.com/mapfiles/place_api/icons/v1/png_71/generic_business-71.png", + "icon_background_color": "#7B9EB0", + "icon_mask_base_uri": "https://maps.gstatic.com/mapfiles/place_api/icons/v2/generic_pinlet", + "name": "UN Skatepark", + "photos": [ + { + "height": 4096, + "html_attributions": [ + "Ghassan G" + ], + "photo_reference": "AciIO2cuIIcUq2yO7nQ_aENkWHN-EBW8baPzWgyrlTnoDLJnZ3xkqA3qGN06NxagIX9LHoTMQKoBBtLKns2IEl90Mb3H_2P13nbPfRUkK0LEwZYq8jrhkAr1kkiuSzQZwXaQEw8o3W4kTBjRhrSnqv69l-mQjTnOMPnIvfdsfM-7-5cCCbReiG2UuhJaxEEP4HEQhpoKPdeysLMtlmOG3AkapY9hUggeffNhVVSc55UEM7CRWozNOoy8oVS6E-kixEK5Zvnrs2JgCarGttCGaQPrxg_R3LjCfWNCqbHD5pz5UGlN_Nixxf5un7OoTvmvxHCjSblmFZttvdfpoI9H54u-rdY6XBeCXON4hcc8vTt-H7pUoPOYQAQvOEsMknrcKQ10Fr7MdsMqp495fV0xc1WK-TMf0sd8aTHjJlDh0_yvi9gzBd47UzJddXi81F0y7HLNpwAHorBvYsPKM3c3pCCKjzOJKtieqvv-xvvdygIEFh4GvIfqInYEpsZeIgvnpUWZKeRoBeAh46AWyHe_-iZzkG94o5TRWiX1McziIr0nXb-2-V0uDhY1CZzDZZxTNPuaanEBSekt9tUMoF-TF-0YSyxGSlm4w8EfGhBrde4vKu2JyunwApDogalJbiDVsX5x7ZqwvBS6sBQxmxotvhRApbUOSRE", + "width": 3072 + } + ], + "place_id": "ChIJfZvlNy-BhYARYrz8xesnfo8", + "plus_code": { + "compound_code": "QHHP+W2 Civic Center, San Francisco, CA, USA", + "global_code": "849VQHHP+W2" + }, + "rating": 5, + "reference": "ChIJfZvlNy-BhYARYrz8xesnfo8", + "scope": "GOOGLE", + "types": [ + "point_of_interest", + "establishment" + ], + "user_ratings_total": 1, + "vicinity": "50 United Nations Plz, San Francisco" + }, + { + "business_status": "OPERATIONAL", + "geometry": { + "location": { + "lat": 37.7798164, + "lng": -122.4149956 + }, + "viewport": { + "northeast": { + "lat": 37.7811597302915, + "lng": -122.4137233697085 + }, + "southwest": { + "lat": 37.7784617697085, + "lng": -122.4164213302915 + } + } + }, + "icon": "https://maps.gstatic.com/mapfiles/place_api/icons/v1/png_71/generic_business-71.png", + "icon_background_color": "#7B9EB0", + "icon_mask_base_uri": "https://maps.gstatic.com/mapfiles/place_api/icons/v2/generic_pinlet", + "name": "Sim\u00f3n Bol\u00edvar Statue", + "opening_hours": { + "open_now": true + }, + "photos": [ + { + "height": 3452, + "html_attributions": [ + "Willians Rodriguez" + ], + "photo_reference": "AciIO2cWoTT9PaFCzH3sXfgvrgMG7uflXzfYSi4jJwNNBJRMxVPQp1TO-_3F0HFe4cWsF-z2g0MrluTzpSdWET57_kIPxx_rRh7TpX6Nv6jpWStd6hDBSAu-WGoaV8T2KESXe-N4WhG0afkZV61_rKqYtk9tc_NsE7Und84qxrQHTD2U-SYCSevUE4EkOGtinTv1o9Ll9yS2Svct_xPp5dAPJEJLBj2JBmWyn2p-sK-DzFHaGzP4r1NfAxQx0oQdoa3R0IUOXLIM6Xx8B_By8Vv9x9Z6wRlblRIM9CiX497_oDaYINg0w8lBtaEN5SSO7QxPRfV8o5NtJWBMqabnW7wepbRqq7BQh43-3HO_HXB1H6nP-cHLXetjXtN775nnAWlhXCEV_2Gb2HTRK0s7xQXHGZdKQCwDXAiTLtHFNGSaqQ3GhQ6iZdGquwh3q46lv6aRczhbo2kGRUgnkYYUa8AquE7Et0miHHw2zKc3lXX9FHQQannKHRc_yMQUpeKQGlBIxTmGvKLeatxHN6iLrtlfSIuHSc4FJWaYqkkiPAny1ZYcM61Jar67gMpf3-3RVwckUMqy4a9yDJawO-g8d-9svKI-5QlZXqlayrNnPsU6KSEgJhkJ95Fdi0nNM9qRYVFVbFVzosF0", + "width": 1868 + } + ], + "place_id": "ChIJxwBPDpuAhYAREmyxJOv11Nk", + "plus_code": { + "compound_code": "QHHP+W2 Civic Center, San Francisco, CA, USA", + "global_code": "849VQHHP+W2" + }, + "rating": 4.4, + "reference": "ChIJxwBPDpuAhYAREmyxJOv11Nk", + "scope": "GOOGLE", + "types": [ + "tourist_attraction", + "point_of_interest", + "establishment" + ], + "user_ratings_total": 23, + "vicinity": "50 United Nations Plaza, San Francisco" + }, + { + "business_status": "OPERATIONAL", + "geometry": { + "location": { + "lat": 37.77961519999999, + "lng": -122.4143835 + }, + "viewport": { + "northeast": { + "lat": 37.78097603029151, + "lng": -122.4127372697085 + }, + "southwest": { + "lat": 37.77827806970851, + "lng": -122.4154352302915 + } + } + }, + "icon": "https://maps.gstatic.com/mapfiles/place_api/icons/v1/png_71/generic_business-71.png", + "icon_background_color": "#7B9EB0", + "icon_mask_base_uri": "https://maps.gstatic.com/mapfiles/place_api/icons/v2/generic_pinlet", + "name": "Fitness Court at UN Plaza", + "opening_hours": { + "open_now": true + }, + "photos": [ + { + "height": 3213, + "html_attributions": [ + "Ally Lim" + ], + "photo_reference": "AciIO2c46aZ1fy0jImtc4i9AybRpqmgpwtnxt0yabDDt0HSzMy6bLyNo06EfEpKBi6cvmAnTmtGPILHAMUacEz6idLBwFO6ClbLGSLpaGmrE-ER462n6AvHQXwHXjL1REr-EU_cWAGUj7vMDJ_8oJwBlON1J6OoUi4N4eaJCgGa2nYN2KhQ_IsxlW06jBWAJ_8i5UzDCk9paPMLTlx6XGrN_ARqihZrDHp1ejLT9LsQuBny8qSHSq6N_cgDjhB6x8DLxLrNeZzFcY6RTwhLDeYqAaV1xlyQN68D8rCd-THrFbXYh0eqnCUNPO2mY0KgET5ifiuIsqEAfpOJp5JHKduPfdRphmIPJfag_kwtJ5kwmjQaDcpmLpVRLxBaFKDmjZ1oFjIm68YpF0z3Tz7chAD90lfLzKKIfQadS5xZLJR-34rJwZA6uiLx-9mEe3upotSZzDmtGQCEbkEJIbWA5TXa0Gr-dK4wQ2RHkzHhIprVlxu6oiXkBzrxx5De5dULfVOtZe25GbYgC6yOGVWppzAawylRfzfroxgD0Q4Qm3vZhrSVdousQjlhvOOd4vNjF4ab1SM0NrBHydXTzm9qO-Q9O45FAGe6DG_9ftmhsrMX57SZpBlnbsYFHZEgNOJhNkAyxcW6rvg", + "width": 5712 + } + ], + "place_id": "ChIJOxlsRwCBhYAR5FY6A3dg8Ek", + "plus_code": { + "compound_code": "QHHP+R6 Civic Center, San Francisco, CA, USA", + "global_code": "849VQHHP+R6" + }, + "rating": 5, + "reference": "ChIJOxlsRwCBhYAR5FY6A3dg8Ek", + "scope": "GOOGLE", + "types": [ + "gym", + "health", + "point_of_interest", + "establishment" + ], + "user_ratings_total": 3, + "vicinity": "3537 Fulton Street, San Francisco" + }, + { + "business_status": "OPERATIONAL", + "geometry": { + "location": { + "lat": 37.77961519999999, + "lng": -122.4143835 + }, + "viewport": { + "northeast": { + "lat": 37.7810261302915, + "lng": -122.4129955697085 + }, + "southwest": { + "lat": 37.7783281697085, + "lng": -122.4156935302915 + } + } + }, + "icon": "https://maps.gstatic.com/mapfiles/place_api/icons/v1/png_71/cafe-71.png", + "icon_background_color": "#FF9E67", + "icon_mask_base_uri": "https://maps.gstatic.com/mapfiles/place_api/icons/v2/cafe_pinlet", + "name": "United Nations Cafe", + "opening_hours": { + "open_now": false + }, + "photos": [ + { + "height": 1836, + "html_attributions": [ + "Steven Smith" + ], + "photo_reference": "AciIO2dhjLdgjy4fMy59en74_XnQ8CoXenGsfvaQ3MM7TohCqXE2tS7BYvyYoNu5gZbhJsNRulbldWgRUT1EpRPkiFZoqa1leeUttiHt1NUuSOEOYULofcZ8ShClkfIPk2U6i6-OajtQc5Aj9rYRtS8WmF_19ducNw0h4f3CSSuDPqKIloeNRsWm-uqi2faqjsgqe8iWvsmgABAmcdUhdAuDFWW31TnrtRe3D58TkvUJGv6-cpIDzuNv8gYPyokrz6lngguIGgNfy53t6xdLFbHMQFnLzgFx2NJbFeC2ZX3-WjKMXuy85hHuVUmucmLz80z6_yHa7kxlbpnruFdjhehwajdG7c0uy-HhxG7LVhRy9I4-aE0f5i4lBoZONibJ7KaHGoJLEMLcm5ig-hXHXfGoXIX3MIl5y5IOxhe4N4bimc1IsmMTs0MKw4O0ZbMhQ8yF4Uqb67ZWfIiEKEL7sXxkWGlgE65OAIutewzFNjOuWzsbQ7oCMK77hVI72s83jl3qT7SX4BQcy0wkSblVVTrm1VWf1PajA9Bzye0ZFi4yClaARpsQH8ZnOOsA3igFlJbjNohPzM8EaOPV3eWUqr8o-tkIp8IIAx5OLBqJjOs_E10AvQB7Pc4z2c6viTZDda9E", + "width": 3264 + } + ], + "place_id": "ChIJ4ZfeFJuAhYAREGTVnroeXsg", + "plus_code": { + "compound_code": "QHHP+R6 Civic Center, San Francisco, CA, USA", + "global_code": "849VQHHP+R6" + }, + "rating": 4.5, + "reference": "ChIJ4ZfeFJuAhYAREGTVnroeXsg", + "scope": "GOOGLE", + "types": [ + "cafe", + "point_of_interest", + "food", + "establishment" + ], + "user_ratings_total": 33, + "vicinity": "3537 Fulton Street, San Francisco" + }, + { + "business_status": "OPERATIONAL", + "geometry": { + "location": { + "lat": 37.78012649999999, + "lng": -122.4136321 + }, + "viewport": { + "northeast": { + "lat": 37.78198923029149, + "lng": -122.4121925197085 + }, + "southwest": { + "lat": 37.7792912697085, + "lng": -122.4148904802915 + } + } + }, + "icon": "https://maps.gstatic.com/mapfiles/place_api/icons/v1/png_71/generic_business-71.png", + "icon_background_color": "#7B9EB0", + "icon_mask_base_uri": "https://maps.gstatic.com/mapfiles/place_api/icons/v2/generic_pinlet", + "name": "UN Skate Plaza", + "photos": [ + { + "height": 2268, + "html_attributions": [ + "Ally Lim" + ], + "photo_reference": "AciIO2fVA9xB6yslvpFQ1lHcw50PP-CHL5GT3WOtJCZ9pXvXUQ_PO0UhhmED-HG6hgIzaN5asxwB8vmzFa4xU4PPKu_LIu4XoCl3PDszzyju1ve916Kpw4jxHkXej81y_IwngvIAFFEfehH5n3lgfdkiZW176mppdHS3A1FpuvUP7yRA3jhenmFvSwmhpJJ6qdicxFvd0Gk-0R-bgzE2bowKaDhUE05PdDInRQCc83j4DsKXfu0eyTUSxzKVJ_Cwy8qdyCfKLXKkdPC8puMSa4nHnaATsWwFNY0eIBKwjACewkHIw5cfCOtcnmg8C-k-iElrgDHrZbDuuFTazC44CAaY2IR-H6cylBKKo8vY73T0iWF2OFJN7hQiL41iWu49OkDv_0cLyOveKyCo-TXh-Fw3RXpsf4fOSsO8UO0l9okQ2f62L_2XRYSZtPMoax2ZrlCTiegxYScg4dvuEuKDQ6_lAqDUawZcb92EHPRV39JI8trLJLlpn0UjWEYQZJ6dVPEJkjcJbeVbxlCkxiIIrym5ljDDTCOv226BX8uEdWlEZSk5jrxt3Js7gNcNJYHlNbjb9KV1Oa_NWFU7AKzVXDJR7ZS-K9OAiAnISbJOviAroCh3vaVP958bxNJu6Cwt_jphUuYEnw", + "width": 4032 + } + ], + "place_id": "ChIJQaVbEAuBhYARTcbgmBM8tVE", + "plus_code": { + "compound_code": "QHJP+3G Civic Center, San Francisco, CA, USA", + "global_code": "849VQHJP+3G" + }, + "rating": 4.6, + "reference": "ChIJQaVbEAuBhYARTcbgmBM8tVE", + "scope": "GOOGLE", + "types": [ + "point_of_interest", + "establishment" + ], + "user_ratings_total": 21, + "vicinity": "1140 Market Street, San Francisco" + }, + { + "business_status": "OPERATIONAL", + "geometry": { + "location": { + "lat": 37.78093459999999, + "lng": -122.4144382 + }, + "viewport": { + "northeast": { + "lat": 37.7822385302915, + "lng": -122.4130778197085 + }, + "southwest": { + "lat": 37.7795405697085, + "lng": -122.4157757802915 + } + } + }, + "icon": "https://maps.gstatic.com/mapfiles/place_api/icons/v1/png_71/cafe-71.png", + "icon_background_color": "#FF9E67", + "icon_mask_base_uri": "https://maps.gstatic.com/mapfiles/place_api/icons/v2/cafe_pinlet", + "name": "Paris Cafe", + "opening_hours": { + "open_now": false + }, + "photos": [ + { + "height": 4032, + "html_attributions": [ + "Paris Cafe" + ], + "photo_reference": "AciIO2fMlGoVgo_TLdvq2CENHw2KFOvcDW45EWxcL8DAw7QPnBbPPS0665SVCCKmKdPI9upG7wCidO6UyCCcMGc4gF32SbUAAPa-whL7CHURZfb-9STDUqcrh-HWmP3K7ZmVoPpWHgFxkfsjfls6LzpphMo3DLXw5mdUIiRbg8d8PM0N-mVp-e7MBPMRIPm1t3RCBA3MdO5cBwHrRs2J3XB05ao22l6a-FBtIiaZWKEikHT9DsQnUH4bHgfvM7lPoCSCikwucTQasUYfXPbaNXm8z-LNvR6ZsTcGsOkRKsu5S7k7eEE3jK68GJxd7nV7C3217lyN12VxZ6U", + "width": 3024 + } + ], + "place_id": "ChIJOYG2HACBhYAR51qH-8IsnFM", + "plus_code": { + "compound_code": "QHJP+96 Civic Center, San Francisco, CA, USA", + "global_code": "849VQHJP+96" + }, + "price_level": 2, + "rating": 4.8, + "reference": "ChIJOYG2HACBhYAR51qH-8IsnFM", + "scope": "GOOGLE", + "types": [ + "cafe", + "point_of_interest", + "store", + "food", + "establishment" + ], + "user_ratings_total": 78, + "vicinity": "142 McAllister Street, San Francisco" + }, + { + "geometry": { + "location": { + "lat": 37.7773082, + "lng": -122.4196412 + }, + "viewport": { + "northeast": { + "lat": 37.78237885897592, + "lng": -122.4125122545961 + }, + "southwest": { + "lat": 37.77303595794733, + "lng": -122.4237308429429 + } + } + }, + "icon": "https://maps.gstatic.com/mapfiles/place_api/icons/v1/png_71/geocode-71.png", + "icon_background_color": "#7B9EB0", + "icon_mask_base_uri": "https://maps.gstatic.com/mapfiles/place_api/icons/v2/generic_pinlet", + "name": "Civic Center", + "photos": [ + { + "height": 2268, + "html_attributions": [ + "Tobias Peyerl" + ], + "photo_reference": "AciIO2cy7yjg95KUbhq9hn7tUsXX0uuUcS8pB9NHPMos5CJwF9b-za_UzQEnJeyopweobag8YKyuK5xbVUhjdgpb-QFhXNknGAD7vs6skcUi4i_2tPQ-ludpZX3_p3upeF2d0Y91HGvucbf6Opj7dKjNgp7gGyY-ZTwhfqo32bmEcu3G_CbTmvbyhuJXocIcJOIXwOM7VVxVB-_3vrcpWPHeV18Y6ilm_atTzkouUvclYwo5i_YInAZ_cNN1DPiNNsK4uHEOR-1wYHjaF8A2G-Y80ieN9G9TxZl6E04wxiiEx3lAYuUuOq4Be5RyMTSDKgv75gvjKmQPvxSD2nVKl8OKxXCWAujxI44xi0Mj_Jr7-K55rwJjTPpIPa-ng72LSvyQ4Er-tjC83O17SFUMNNxE5ixb-xDuARpu3UjB-0pzD8vJJ9BAnwHkUhvDueMMVrrQ7W7BNYw7T4-A-eiznIpS6pft_vc2Kkq3t-CE3-VlZAUC7dSoCiK-Kag77oB2WlIjJltl9dgtlNid2qoGE6nNkWBYlDnxADFBkHDEIeh6jIzqGMcUbr-rtw1H4otL8MjlWf65JpbCAmXifV1rSPqylFatmfp74jIuJSmnODs-lG_-R1eObSQ3oaDi280kJmvX6VOK5XDV", + "width": 4032 + } + ], + "place_id": "ChIJ3eJWtI6AhYAR2ovTWatCF8s", + "reference": "ChIJ3eJWtI6AhYAR2ovTWatCF8s", + "scope": "GOOGLE", + "types": [ + "neighborhood", + "political" + ], + "vicinity": "San Francisco" + } + ], + "status": "OK" +} diff --git a/dimos/mapping/google_maps/fixtures/get_location_context_reverse_geocode.json b/dimos/mapping/google_maps/fixtures/get_location_context_reverse_geocode.json new file mode 100644 index 0000000000..216c02aca9 --- /dev/null +++ b/dimos/mapping/google_maps/fixtures/get_location_context_reverse_geocode.json @@ -0,0 +1,1140 @@ +[ + { + "address_components": [ + { + "long_name": "50", + "short_name": "50", + "types": [ + "street_number" + ] + }, + { + "long_name": "United Nations Plaza", + "short_name": "United Nations Plaza", + "types": [ + "route" + ] + }, + { + "long_name": "Civic Center", + "short_name": "Civic Center", + "types": [ + "neighborhood", + "political" + ] + }, + { + "long_name": "San Francisco", + "short_name": "SF", + "types": [ + "locality", + "political" + ] + }, + { + "long_name": "San Francisco County", + "short_name": "San Francisco County", + "types": [ + "administrative_area_level_2", + "political" + ] + }, + { + "long_name": "California", + "short_name": "CA", + "types": [ + "administrative_area_level_1", + "political" + ] + }, + { + "long_name": "United States", + "short_name": "US", + "types": [ + "country", + "political" + ] + }, + { + "long_name": "94102", + "short_name": "94102", + "types": [ + "postal_code" + ] + }, + { + "long_name": "4917", + "short_name": "4917", + "types": [ + "postal_code_suffix" + ] + } + ], + "formatted_address": "50 United Nations Plaza, San Francisco, CA 94102, USA", + "geometry": { + "location": { + "lat": 37.78021, + "lng": -122.4144194 + }, + "location_type": "ROOFTOP", + "viewport": { + "northeast": { + "lat": 37.78155898029149, + "lng": -122.4130704197085 + }, + "southwest": { + "lat": 37.77886101970849, + "lng": -122.4157683802915 + } + } + }, + "navigation_points": [ + { + "location": { + "latitude": 37.7799875, + "longitude": -122.4143728 + }, + "restricted_travel_modes": [ + "DRIVE" + ] + }, + { + "location": { + "latitude": 37.7807662, + "longitude": -122.4145332 + }, + "restricted_travel_modes": [ + "WALK" + ] + } + ], + "place_id": "ChIJp9HdGZuAhYAR9HQeU37hyx0", + "types": [ + "street_address", + "subpremise" + ] + }, + { + "address_components": [ + { + "long_name": "50", + "short_name": "50", + "types": [ + "street_number" + ] + }, + { + "long_name": "Hyde Street", + "short_name": "Hyde St", + "types": [ + "route" + ] + }, + { + "long_name": "Civic Center", + "short_name": "Civic Center", + "types": [ + "neighborhood", + "political" + ] + }, + { + "long_name": "San Francisco", + "short_name": "SF", + "types": [ + "locality", + "political" + ] + }, + { + "long_name": "San Francisco County", + "short_name": "San Francisco County", + "types": [ + "administrative_area_level_2", + "political" + ] + }, + { + "long_name": "California", + "short_name": "CA", + "types": [ + "administrative_area_level_1", + "political" + ] + }, + { + "long_name": "United States", + "short_name": "US", + "types": [ + "country", + "political" + ] + }, + { + "long_name": "94102", + "short_name": "94102", + "types": [ + "postal_code" + ] + } + ], + "formatted_address": "50 Hyde St, San Francisco, CA 94102, USA", + "geometry": { + "bounds": { + "northeast": { + "lat": 37.78081540000001, + "lng": -122.4137806 + }, + "southwest": { + "lat": 37.7800522, + "lng": -122.415187 + } + }, + "location": { + "lat": 37.7805991, + "lng": -122.4147826 + }, + "location_type": "ROOFTOP", + "viewport": { + "northeast": { + "lat": 37.78178278029151, + "lng": -122.4131348197085 + }, + "southwest": { + "lat": 37.77908481970851, + "lng": -122.4158327802915 + } + } + }, + "navigation_points": [ + { + "location": { + "latitude": 37.7799291, + "longitude": -122.4143652 + }, + "restricted_travel_modes": [ + "WALK" + ] + } + ], + "place_id": "ChIJ7Q9FGZuAhYARSovheSUzVeE", + "types": [ + "premise", + "street_address" + ] + }, + { + "address_components": [ + { + "long_name": "Civic Center/UN Plaza BART Station", + "short_name": "Civic Center/UN Plaza BART Station", + "types": [ + "establishment", + "point_of_interest", + "transit_station" + ] + }, + { + "long_name": "San Francisco", + "short_name": "SF", + "types": [ + "locality", + "political" + ] + }, + { + "long_name": "San Francisco County", + "short_name": "San Francisco County", + "types": [ + "administrative_area_level_2", + "political" + ] + }, + { + "long_name": "California", + "short_name": "CA", + "types": [ + "administrative_area_level_1", + "political" + ] + }, + { + "long_name": "United States", + "short_name": "US", + "types": [ + "country", + "political" + ] + } + ], + "formatted_address": "Civic Center/UN Plaza BART Station, San Francisco, CA, USA", + "geometry": { + "location": { + "lat": 37.779756, + "lng": -122.41415 + }, + "location_type": "GEOMETRIC_CENTER", + "viewport": { + "northeast": { + "lat": 37.7811049802915, + "lng": -122.4128010197085 + }, + "southwest": { + "lat": 37.7784070197085, + "lng": -122.4154989802915 + } + } + }, + "navigation_points": [ + { + "location": { + "latitude": 37.7797284, + "longitude": -122.4142112 + }, + "restricted_travel_modes": [ + "DRIVE" + ] + }, + { + "location": { + "latitude": 37.779631, + "longitude": -122.4150367 + }, + "restricted_travel_modes": [ + "WALK" + ] + }, + { + "location": { + "latitude": 37.7795262, + "longitude": -122.4138289 + }, + "restricted_travel_modes": [ + "WALK" + ] + }, + { + "location": { + "latitude": 37.7796804, + "longitude": -122.4136322 + } + }, + { + "location": { + "latitude": 37.7804986, + "longitude": -122.4129601 + }, + "restricted_travel_modes": [ + "DRIVE" + ] + }, + { + "location": { + "latitude": 37.7788771, + "longitude": -122.414549 + } + } + ], + "place_id": "ChIJK0jeP5uAhYARcxPNUpvfc7A", + "plus_code": { + "compound_code": "QHHP+W8 Civic Center, San Francisco, CA, USA", + "global_code": "849VQHHP+W8" + }, + "types": [ + "establishment", + "point_of_interest", + "transit_station" + ] + }, + { + "address_components": [ + { + "long_name": "1-99", + "short_name": "1-99", + "types": [ + "street_number" + ] + }, + { + "long_name": "United Nations Plaza", + "short_name": "United Nations Plz", + "types": [ + "route" + ] + }, + { + "long_name": "Civic Center", + "short_name": "Civic Center", + "types": [ + "neighborhood", + "political" + ] + }, + { + "long_name": "San Francisco", + "short_name": "SF", + "types": [ + "locality", + "political" + ] + }, + { + "long_name": "San Francisco County", + "short_name": "San Francisco County", + "types": [ + "administrative_area_level_2", + "political" + ] + }, + { + "long_name": "California", + "short_name": "CA", + "types": [ + "administrative_area_level_1", + "political" + ] + }, + { + "long_name": "United States", + "short_name": "US", + "types": [ + "country", + "political" + ] + }, + { + "long_name": "94102", + "short_name": "94102", + "types": [ + "postal_code" + ] + }, + { + "long_name": "7402", + "short_name": "7402", + "types": [ + "postal_code_suffix" + ] + } + ], + "formatted_address": "1-99 United Nations Plz, San Francisco, CA 94102, USA", + "geometry": { + "location": { + "lat": 37.779675, + "lng": -122.41408 + }, + "location_type": "ROOFTOP", + "viewport": { + "northeast": { + "lat": 37.78102398029149, + "lng": -122.4127310197085 + }, + "southwest": { + "lat": 37.7783260197085, + "lng": -122.4154289802915 + } + } + }, + "navigation_points": [ + { + "location": { + "latitude": 37.7796351, + "longitude": -122.4141273 + }, + "restricted_travel_modes": [ + "DRIVE" + ] + }, + { + "location": { + "latitude": 37.7796283, + "longitude": -122.4138453 + }, + "restricted_travel_modes": [ + "WALK" + ] + } + ], + "place_id": "ChIJD8AMQJuAhYARgQPDkMbiVZE", + "plus_code": { + "compound_code": "QHHP+V9 Civic Center, San Francisco, CA, USA", + "global_code": "849VQHHP+V9" + }, + "types": [ + "street_address" + ] + }, + { + "address_components": [ + { + "long_name": "QHJP+36", + "short_name": "QHJP+36", + "types": [ + "plus_code" + ] + }, + { + "long_name": "Civic Center", + "short_name": "Civic Center", + "types": [ + "neighborhood", + "political" + ] + }, + { + "long_name": "San Francisco", + "short_name": "SF", + "types": [ + "locality", + "political" + ] + }, + { + "long_name": "San Francisco County", + "short_name": "San Francisco County", + "types": [ + "administrative_area_level_2", + "political" + ] + }, + { + "long_name": "California", + "short_name": "CA", + "types": [ + "administrative_area_level_1", + "political" + ] + }, + { + "long_name": "United States", + "short_name": "US", + "types": [ + "country", + "political" + ] + }, + { + "long_name": "94102", + "short_name": "94102", + "types": [ + "postal_code" + ] + } + ], + "formatted_address": "QHJP+36 Civic Center, San Francisco, CA, USA", + "geometry": { + "bounds": { + "northeast": { + "lat": 37.78025, + "lng": -122.414375 + }, + "southwest": { + "lat": 37.780125, + "lng": -122.4145 + } + }, + "location": { + "lat": 37.7801776, + "lng": -122.4144952 + }, + "location_type": "GEOMETRIC_CENTER", + "viewport": { + "northeast": { + "lat": 37.78153648029149, + "lng": -122.4130885197085 + }, + "southwest": { + "lat": 37.77883851970849, + "lng": -122.4157864802915 + } + } + }, + "place_id": "GhIJMIkO3NzjQkARVhbgFoeaXsA", + "plus_code": { + "compound_code": "QHJP+36 Civic Center, San Francisco, CA, USA", + "global_code": "849VQHJP+36" + }, + "types": [ + "plus_code" + ] + }, + { + "address_components": [ + { + "long_name": "39", + "short_name": "39", + "types": [ + "street_number" + ] + }, + { + "long_name": "Hyde Street", + "short_name": "Hyde St", + "types": [ + "route" + ] + }, + { + "long_name": "Civic Center", + "short_name": "Civic Center", + "types": [ + "neighborhood", + "political" + ] + }, + { + "long_name": "San Francisco", + "short_name": "SF", + "types": [ + "locality", + "political" + ] + }, + { + "long_name": "San Francisco County", + "short_name": "San Francisco County", + "types": [ + "administrative_area_level_2", + "political" + ] + }, + { + "long_name": "California", + "short_name": "CA", + "types": [ + "administrative_area_level_1", + "political" + ] + }, + { + "long_name": "United States", + "short_name": "US", + "types": [ + "country", + "political" + ] + }, + { + "long_name": "94102", + "short_name": "94102", + "types": [ + "postal_code" + ] + } + ], + "formatted_address": "39 Hyde St, San Francisco, CA 94102, USA", + "geometry": { + "location": { + "lat": 37.7800157, + "lng": -122.4151997 + }, + "location_type": "RANGE_INTERPOLATED", + "viewport": { + "northeast": { + "lat": 37.7813646802915, + "lng": -122.4138507197085 + }, + "southwest": { + "lat": 37.7786667197085, + "lng": -122.4165486802915 + } + } + }, + "place_id": "EigzOSBIeWRlIFN0LCBTYW4gRnJhbmNpc2NvLCBDQSA5NDEwMiwgVVNBIhoSGAoUChIJNcWgBpuAhYARvBLCxkfib9AQJw", + "types": [ + "street_address" + ] + }, + { + "address_components": [ + { + "long_name": "47-35", + "short_name": "47-35", + "types": [ + "street_number" + ] + }, + { + "long_name": "Hyde Street", + "short_name": "Hyde St", + "types": [ + "route" + ] + }, + { + "long_name": "Civic Center", + "short_name": "Civic Center", + "types": [ + "neighborhood", + "political" + ] + }, + { + "long_name": "San Francisco", + "short_name": "SF", + "types": [ + "locality", + "political" + ] + }, + { + "long_name": "San Francisco County", + "short_name": "San Francisco County", + "types": [ + "administrative_area_level_2", + "political" + ] + }, + { + "long_name": "California", + "short_name": "CA", + "types": [ + "administrative_area_level_1", + "political" + ] + }, + { + "long_name": "United States", + "short_name": "US", + "types": [ + "country", + "political" + ] + }, + { + "long_name": "94102", + "short_name": "94102", + "types": [ + "postal_code" + ] + } + ], + "formatted_address": "47-35 Hyde St, San Francisco, CA 94102, USA", + "geometry": { + "bounds": { + "northeast": { + "lat": 37.7803333, + "lng": -122.4151588 + }, + "southwest": { + "lat": 37.7798162, + "lng": -122.4152658 + } + }, + "location": { + "lat": 37.7800748, + "lng": -122.415212 + }, + "location_type": "GEOMETRIC_CENTER", + "viewport": { + "northeast": { + "lat": 37.7814237302915, + "lng": -122.4138633197085 + }, + "southwest": { + "lat": 37.7787257697085, + "lng": -122.4165612802915 + } + } + }, + "place_id": "ChIJNcWgBpuAhYARvBLCxkfib9A", + "types": [ + "route" + ] + }, + { + "address_components": [ + { + "long_name": "Civic Center", + "short_name": "Civic Center", + "types": [ + "neighborhood", + "political" + ] + }, + { + "long_name": "San Francisco", + "short_name": "SF", + "types": [ + "locality", + "political" + ] + }, + { + "long_name": "San Francisco County", + "short_name": "San Francisco County", + "types": [ + "administrative_area_level_2", + "political" + ] + }, + { + "long_name": "California", + "short_name": "CA", + "types": [ + "administrative_area_level_1", + "political" + ] + }, + { + "long_name": "United States", + "short_name": "US", + "types": [ + "country", + "political" + ] + }, + { + "long_name": "94102", + "short_name": "94102", + "types": [ + "postal_code" + ] + } + ], + "formatted_address": "Civic Center, San Francisco, CA 94102, USA", + "geometry": { + "bounds": { + "northeast": { + "lat": 37.7823789, + "lng": -122.4125123 + }, + "southwest": { + "lat": 37.773036, + "lng": -122.4237308 + } + }, + "location": { + "lat": 37.7773082, + "lng": -122.4196412 + }, + "location_type": "APPROXIMATE", + "viewport": { + "northeast": { + "lat": 37.7823789, + "lng": -122.4125123 + }, + "southwest": { + "lat": 37.773036, + "lng": -122.4237308 + } + } + }, + "place_id": "ChIJ3eJWtI6AhYAR2ovTWatCF8s", + "types": [ + "neighborhood", + "political" + ] + }, + { + "address_components": [ + { + "long_name": "94102", + "short_name": "94102", + "types": [ + "postal_code" + ] + }, + { + "long_name": "San Francisco", + "short_name": "SF", + "types": [ + "locality", + "political" + ] + }, + { + "long_name": "San Francisco County", + "short_name": "San Francisco County", + "types": [ + "administrative_area_level_2", + "political" + ] + }, + { + "long_name": "California", + "short_name": "CA", + "types": [ + "administrative_area_level_1", + "political" + ] + }, + { + "long_name": "United States", + "short_name": "US", + "types": [ + "country", + "political" + ] + } + ], + "formatted_address": "San Francisco, CA 94102, USA", + "geometry": { + "bounds": { + "northeast": { + "lat": 37.789226, + "lng": -122.4034491 + }, + "southwest": { + "lat": 37.7694409, + "lng": -122.429849 + } + }, + "location": { + "lat": 37.7786871, + "lng": -122.4212424 + }, + "location_type": "APPROXIMATE", + "viewport": { + "northeast": { + "lat": 37.789226, + "lng": -122.4034491 + }, + "southwest": { + "lat": 37.7694409, + "lng": -122.429849 + } + } + }, + "place_id": "ChIJs88qnZmAhYARk8u-7t1Sc2g", + "types": [ + "postal_code" + ] + }, + { + "address_components": [ + { + "long_name": "San Francisco County", + "short_name": "San Francisco County", + "types": [ + "administrative_area_level_2", + "political" + ] + }, + { + "long_name": "San Francisco", + "short_name": "SF", + "types": [ + "locality", + "political" + ] + }, + { + "long_name": "California", + "short_name": "CA", + "types": [ + "administrative_area_level_1", + "political" + ] + }, + { + "long_name": "United States", + "short_name": "US", + "types": [ + "country", + "political" + ] + } + ], + "formatted_address": "San Francisco County, San Francisco, CA, USA", + "geometry": { + "bounds": { + "northeast": { + "lat": 37.929824, + "lng": -122.28178 + }, + "southwest": { + "lat": 37.63983, + "lng": -123.1327983 + } + }, + "location": { + "lat": 37.7618219, + "lng": -122.5146439 + }, + "location_type": "APPROXIMATE", + "viewport": { + "northeast": { + "lat": 37.929824, + "lng": -122.28178 + }, + "southwest": { + "lat": 37.63983, + "lng": -123.1327983 + } + } + }, + "place_id": "ChIJIQBpAG2ahYARUksNqd0_1h8", + "types": [ + "administrative_area_level_2", + "political" + ] + }, + { + "address_components": [ + { + "long_name": "San Francisco", + "short_name": "SF", + "types": [ + "locality", + "political" + ] + }, + { + "long_name": "San Francisco County", + "short_name": "San Francisco County", + "types": [ + "administrative_area_level_2", + "political" + ] + }, + { + "long_name": "California", + "short_name": "CA", + "types": [ + "administrative_area_level_1", + "political" + ] + }, + { + "long_name": "United States", + "short_name": "US", + "types": [ + "country", + "political" + ] + } + ], + "formatted_address": "San Francisco, CA, USA", + "geometry": { + "bounds": { + "northeast": { + "lat": 37.929824, + "lng": -122.28178 + }, + "southwest": { + "lat": 37.6398299, + "lng": -123.1328145 + } + }, + "location": { + "lat": 37.7749295, + "lng": -122.4194155 + }, + "location_type": "APPROXIMATE", + "viewport": { + "northeast": { + "lat": 37.929824, + "lng": -122.28178 + }, + "southwest": { + "lat": 37.6398299, + "lng": -123.1328145 + } + } + }, + "place_id": "ChIJIQBpAG2ahYAR_6128GcTUEo", + "types": [ + "locality", + "political" + ] + }, + { + "address_components": [ + { + "long_name": "California", + "short_name": "CA", + "types": [ + "administrative_area_level_1", + "political" + ] + }, + { + "long_name": "United States", + "short_name": "US", + "types": [ + "country", + "political" + ] + } + ], + "formatted_address": "California, USA", + "geometry": { + "bounds": { + "northeast": { + "lat": 42.009503, + "lng": -114.131211 + }, + "southwest": { + "lat": 32.52950810000001, + "lng": -124.482003 + } + }, + "location": { + "lat": 36.778261, + "lng": -119.4179324 + }, + "location_type": "APPROXIMATE", + "viewport": { + "northeast": { + "lat": 42.009503, + "lng": -114.131211 + }, + "southwest": { + "lat": 32.52950810000001, + "lng": -124.482003 + } + } + }, + "place_id": "ChIJPV4oX_65j4ARVW8IJ6IJUYs", + "types": [ + "administrative_area_level_1", + "political" + ] + }, + { + "address_components": [ + { + "long_name": "United States", + "short_name": "US", + "types": [ + "country", + "political" + ] + } + ], + "formatted_address": "United States", + "geometry": { + "bounds": { + "northeast": { + "lat": 74.071038, + "lng": -66.885417 + }, + "southwest": { + "lat": 18.7763, + "lng": 166.9999999 + } + }, + "location": { + "lat": 38.7945952, + "lng": -106.5348379 + }, + "location_type": "APPROXIMATE", + "viewport": { + "northeast": { + "lat": 74.071038, + "lng": -66.885417 + }, + "southwest": { + "lat": 18.7763, + "lng": 166.9999999 + } + } + }, + "place_id": "ChIJCzYy5IS16lQRQrfeQ5K5Oxw", + "types": [ + "country", + "political" + ] + } +] diff --git a/dimos/mapping/google_maps/fixtures/get_position.json b/dimos/mapping/google_maps/fixtures/get_position.json new file mode 100644 index 0000000000..410d2add2a --- /dev/null +++ b/dimos/mapping/google_maps/fixtures/get_position.json @@ -0,0 +1,141 @@ +[ + { + "address_components": [ + { + "long_name": "Golden Gate Bridge", + "short_name": "Golden Gate Bridge", + "types": [ + "establishment", + "point_of_interest", + "tourist_attraction" + ] + }, + { + "long_name": "Golden Gate Bridge", + "short_name": "Golden Gate Brg", + "types": [ + "route" + ] + }, + { + "long_name": "San Francisco", + "short_name": "SF", + "types": [ + "locality", + "political" + ] + }, + { + "long_name": "San Francisco County", + "short_name": "San Francisco County", + "types": [ + "administrative_area_level_2", + "political" + ] + }, + { + "long_name": "California", + "short_name": "CA", + "types": [ + "administrative_area_level_1", + "political" + ] + }, + { + "long_name": "United States", + "short_name": "US", + "types": [ + "country", + "political" + ] + } + ], + "formatted_address": "Golden Gate Bridge, Golden Gate Brg, San Francisco, CA, USA", + "geometry": { + "location": { + "lat": 37.8199109, + "lng": -122.4785598 + }, + "location_type": "GEOMETRIC_CENTER", + "viewport": { + "northeast": { + "lat": 37.8324583, + "lng": -122.4756692 + }, + "southwest": { + "lat": 37.8075604, + "lng": -122.4810829 + } + } + }, + "navigation_points": [ + { + "location": { + "latitude": 37.8075604, + "longitude": -122.4756957 + } + }, + { + "location": { + "latitude": 37.80756119999999, + "longitude": -122.4756922 + }, + "restricted_travel_modes": [ + "WALK" + ] + }, + { + "location": { + "latitude": 37.8324279, + "longitude": -122.4810829 + } + }, + { + "location": { + "latitude": 37.8324382, + "longitude": -122.4810669 + }, + "restricted_travel_modes": [ + "WALK" + ] + }, + { + "location": { + "latitude": 37.8083987, + "longitude": -122.4765643 + }, + "restricted_travel_modes": [ + "DRIVE" + ] + }, + { + "location": { + "latitude": 37.8254712, + "longitude": -122.4791469 + }, + "restricted_travel_modes": [ + "DRIVE" + ] + }, + { + "location": { + "latitude": 37.8321189, + "longitude": -122.4808249 + }, + "restricted_travel_modes": [ + "DRIVE" + ] + } + ], + "place_id": "ChIJw____96GhYARCVVwg5cT7c0", + "plus_code": { + "compound_code": "RG9C+XH Presidio of San Francisco, San Francisco, CA", + "global_code": "849VRG9C+XH" + }, + "types": [ + "establishment", + "point_of_interest", + "tourist_attraction" + ] + } +] diff --git a/dimos/mapping/google_maps/fixtures/get_position_with_places.json b/dimos/mapping/google_maps/fixtures/get_position_with_places.json new file mode 100644 index 0000000000..d471a8368a --- /dev/null +++ b/dimos/mapping/google_maps/fixtures/get_position_with_places.json @@ -0,0 +1,53 @@ +{ + "html_attributions": [], + "results": [ + { + "business_status": "OPERATIONAL", + "formatted_address": "Golden Gate Brg, San Francisco, CA, United States", + "geometry": { + "location": { + "lat": 37.8199109, + "lng": -122.4785598 + }, + "viewport": { + "northeast": { + "lat": 37.84490724999999, + "lng": -122.47296235 + }, + "southwest": { + "lat": 37.79511145000001, + "lng": -122.48378975 + } + } + }, + "icon": "https://maps.gstatic.com/mapfiles/place_api/icons/v1/png_71/generic_business-71.png", + "icon_background_color": "#7B9EB0", + "icon_mask_base_uri": "https://maps.gstatic.com/mapfiles/place_api/icons/v2/generic_pinlet", + "name": "Golden Gate Bridge", + "photos": [ + { + "height": 12240, + "html_attributions": [ + "Jitesh Patil" + ], + "photo_reference": "AciIO2dcF-W6JeWe01lyR39crDHHon3awa5LlBNNhxAZcAExA3sTr33iFa8HjDgPPfdNrl3C-0Bzqp2qEndFz3acXtm1kmj7puXUOtO48-Qmovp9Nvi5k3XJVbIEPYYRCXOshrYQ1od2tHe-MBkvFNxsg4uNByEbJxkstLLTuEOmSbCEx53EQfuJoxbPQgRGphAPDFkTeiCODXd7KzdL9-2GvVYTrGl_IK-AIds1-UYwWJPOi1mkM-iXFVoVm0R1LOgt-ydhnAaRFQPzOlz9Oezc0kDiuxvzjTO4mgeY79Nqcxq2osBqYGyJTLINYfNphZHzncxWqpWXP_mvQt77YaW368RGbBGDrHubXHJBkj7sdru0N1-qf5Q28rsxCSI5yyNsHm8zFmNWm1PlWA_LItL5LpoxG9Xkuuhuvv3XjWtBs5hnHxNDHP4jbJinWz2DPd9IPxHH-BAfwfJGdtgW1juBAEDi8od5KP95Drt8e9XOaG6I5UIeJnvUqq4Q1McAiVx5rVn7FGwu3NsTAeeS4FCKy2Ql_YoQpcqzRO45w8tI4DqFd8F19pZHw3t7p1t7DwmzAMzIS_17_2aScA", + "width": 16320 + } + ], + "place_id": "ChIJw____96GhYARCVVwg5cT7c0", + "plus_code": { + "compound_code": "RG9C+XH Presidio of San Francisco, San Francisco, CA, USA", + "global_code": "849VRG9C+XH" + }, + "rating": 4.8, + "reference": "ChIJw____96GhYARCVVwg5cT7c0", + "types": [ + "tourist_attraction", + "point_of_interest", + "establishment" + ], + "user_ratings_total": 83799 + } + ], + "status": "OK" +} diff --git a/dimos/mapping/google_maps/google_maps.py b/dimos/mapping/google_maps/google_maps.py new file mode 100644 index 0000000000..7f5ce32e99 --- /dev/null +++ b/dimos/mapping/google_maps/google_maps.py @@ -0,0 +1,192 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import googlemaps # type: ignore[import-untyped] + +from dimos.mapping.google_maps.types import ( + Coordinates, + LocationContext, + NearbyPlace, + PlacePosition, + Position, +) +from dimos.mapping.types import LatLon +from dimos.mapping.utils.distance import distance_in_meters +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class GoogleMaps: + _client: googlemaps.Client + _max_nearby_places: int + + def __init__(self, api_key: str | None = None) -> None: + api_key = api_key or os.environ.get("GOOGLE_MAPS_API_KEY") + if not api_key: + raise ValueError("GOOGLE_MAPS_API_KEY environment variable not set") + self._client = googlemaps.Client(key=api_key) + self._max_nearby_places = 6 + + def get_position(self, query: str, current_location: LatLon | None = None) -> Position | None: + # Use location bias if current location is provided + if current_location: + geocode_results = self._client.geocode( + query, + bounds={ + "southwest": { + "lat": current_location.lat - 0.5, + "lng": current_location.lon - 0.5, + }, + "northeast": { + "lat": current_location.lat + 0.5, + "lng": current_location.lon + 0.5, + }, + }, + ) + else: + geocode_results = self._client.geocode(query) + + if not geocode_results: + return None + + result = geocode_results[0] + + location = result["geometry"]["location"] + + return Position( + lat=location["lat"], + lon=location["lng"], + description=result["formatted_address"], + ) + + def get_position_with_places( + self, query: str, current_location: LatLon | None = None + ) -> PlacePosition | None: + # Use location bias if current location is provided + if current_location: + places_results = self._client.places( + query, + location=(current_location.lat, current_location.lon), + radius=50000, # 50km radius for location bias + ) + else: + places_results = self._client.places(query) + + if not places_results or "results" not in places_results: + return None + + results = places_results["results"] + if not results: + return None + + place = results[0] + + location = place["geometry"]["location"] + + return PlacePosition( + lat=location["lat"], + lon=location["lng"], + description=place.get("name", ""), + address=place.get("formatted_address", ""), + types=place.get("types", []), + ) + + def get_location_context( + self, latlon: LatLon, radius: int = 100, n_nearby_places: int = 6 + ) -> LocationContext | None: + reverse_geocode_results = self._client.reverse_geocode((latlon.lat, latlon.lon)) + + if not reverse_geocode_results: + return None + + result = reverse_geocode_results[0] + + # Extract address components + components = {} + for component in result.get("address_components", []): + types = component.get("types", []) + if "street_number" in types: + components["street_number"] = component["long_name"] + elif "route" in types: + components["street"] = component["long_name"] + elif "neighborhood" in types: + components["neighborhood"] = component["long_name"] + elif "locality" in types: + components["locality"] = component["long_name"] + elif "administrative_area_level_1" in types: + components["admin_area"] = component["long_name"] + elif "country" in types: + components["country"] = component["long_name"] + elif "postal_code" in types: + components["postal_code"] = component["long_name"] + + nearby_places, place_types_summary = self._get_nearby_places( + latlon, radius, n_nearby_places + ) + + return LocationContext( + formatted_address=result.get("formatted_address", ""), + street_number=components.get("street_number", ""), + street=components.get("street", ""), + neighborhood=components.get("neighborhood", ""), + locality=components.get("locality", ""), + admin_area=components.get("admin_area", ""), + country=components.get("country", ""), + postal_code=components.get("postal_code", ""), + nearby_places=nearby_places, + place_types_summary=place_types_summary or "No specific landmarks nearby", + coordinates=Coordinates(lat=latlon.lat, lon=latlon.lon), + ) + + def _get_nearby_places( + self, latlon: LatLon, radius: int, n_nearby_places: int + ) -> tuple[list[NearbyPlace], str]: + nearby_places = [] + place_types_count: dict[str, int] = {} + + places_nearby = self._client.places_nearby(location=(latlon.lat, latlon.lon), radius=radius) + + if places_nearby and "results" in places_nearby: + for place in places_nearby["results"][:n_nearby_places]: + place_lat = place["geometry"]["location"]["lat"] + place_lon = place["geometry"]["location"]["lng"] + place_latlon = LatLon(lat=place_lat, lon=place_lon) + + place_info = NearbyPlace( + name=place.get("name", ""), + types=place.get("types", []), + vicinity=place.get("vicinity", ""), + distance=round(distance_in_meters(place_latlon, latlon), 1), + ) + + nearby_places.append(place_info) + + for place_type in place.get("types", []): + if place_type not in ["point_of_interest", "establishment"]: + place_types_count[place_type] = place_types_count.get(place_type, 0) + 1 + nearby_places.sort(key=lambda x: x.distance) + + place_types_summary = ", ".join( + [ + f"{count} {ptype.replace('_', ' ')}{'s' if count > 1 else ''}" + for ptype, count in sorted( + place_types_count.items(), key=lambda x: x[1], reverse=True + )[:5] + ] + ) + + return nearby_places, place_types_summary diff --git a/dimos/mapping/google_maps/test_google_maps.py b/dimos/mapping/google_maps/test_google_maps.py new file mode 100644 index 0000000000..13f7fa8eaa --- /dev/null +++ b/dimos/mapping/google_maps/test_google_maps.py @@ -0,0 +1,139 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dimos.mapping.types import LatLon + + +def test_get_position(maps_client, maps_fixture) -> None: + maps_client._client.geocode.return_value = maps_fixture("get_position.json") + + res = maps_client.get_position("golden gate bridge") + + assert res.model_dump() == { + "description": "Golden Gate Bridge, Golden Gate Brg, San Francisco, CA, USA", + "lat": 37.8199109, + "lon": -122.4785598, + } + + +def test_get_position_with_places(maps_client, maps_fixture) -> None: + maps_client._client.places.return_value = maps_fixture("get_position_with_places.json") + + res = maps_client.get_position_with_places("golden gate bridge") + + assert res.model_dump() == { + "address": "Golden Gate Brg, San Francisco, CA, United States", + "description": "Golden Gate Bridge", + "lat": 37.8199109, + "lon": -122.4785598, + "types": [ + "tourist_attraction", + "point_of_interest", + "establishment", + ], + } + + +def test_get_location_context(maps_client, maps_fixture) -> None: + maps_client._client.reverse_geocode.return_value = maps_fixture( + "get_location_context_reverse_geocode.json" + ) + maps_client._client.places_nearby.return_value = maps_fixture( + "get_location_context_places_nearby.json" + ) + + res = maps_client.get_location_context(LatLon(lat=37.78017758753598, lon=-122.4144951709186)) + + assert res.model_dump() == { + "admin_area": "California", + "coordinates": { + "lat": 37.78017758753598, + "lon": -122.4144951709186, + }, + "country": "United States", + "formatted_address": "50 United Nations Plaza, San Francisco, CA 94102, USA", + "locality": "San Francisco", + "nearby_places": [ + { + "distance": 9.3, + "name": "U.S. General Services Administration - Pacific Rim Region", + "types": [ + "point_of_interest", + "establishment", + ], + "vicinity": "50 United Nations Plaza, San Francisco", + }, + { + "distance": 14.0, + "name": "Federal Office Building", + "types": [ + "point_of_interest", + "establishment", + ], + "vicinity": "50 United Nations Plaza, San Francisco", + }, + { + "distance": 35.7, + "name": "UN Plaza", + "types": [ + "city_hall", + "point_of_interest", + "local_government_office", + "establishment", + ], + "vicinity": "355 McAllister Street, San Francisco", + }, + { + "distance": 92.7, + "name": "McAllister Market & Deli", + "types": [ + "liquor_store", + "atm", + "grocery_or_supermarket", + "finance", + "point_of_interest", + "food", + "store", + "establishment", + ], + "vicinity": "136 McAllister Street, San Francisco", + }, + { + "distance": 95.9, + "name": "Civic Center / UN Plaza", + "types": [ + "subway_station", + "transit_station", + "point_of_interest", + "establishment", + ], + "vicinity": "1150 Market Street, San Francisco", + }, + { + "distance": 726.3, + "name": "San Francisco", + "types": [ + "locality", + "political", + ], + "vicinity": "San Francisco", + }, + ], + "neighborhood": "Civic Center", + "place_types_summary": "1 locality, 1 political, 1 subway station, 1 transit station, 1 city hall", + "postal_code": "94102", + "street": "United Nations Plaza", + "street_number": "50", + } diff --git a/dimos/mapping/google_maps/types.py b/dimos/mapping/google_maps/types.py new file mode 100644 index 0000000000..29f9bee6eb --- /dev/null +++ b/dimos/mapping/google_maps/types.py @@ -0,0 +1,66 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from pydantic import BaseModel + + +class Coordinates(BaseModel): + """GPS coordinates.""" + + lat: float + lon: float + + +class Position(BaseModel): + """Basic position information from geocoding.""" + + lat: float + lon: float + description: str + + +class PlacePosition(BaseModel): + """Position with places API details.""" + + lat: float + lon: float + description: str + address: str + types: list[str] + + +class NearbyPlace(BaseModel): + """Information about a nearby place.""" + + name: str + types: list[str] + distance: float + vicinity: str + + +class LocationContext(BaseModel): + """Contextual information about a location.""" + + formatted_address: str | None = None + street_number: str | None = None + street: str | None = None + neighborhood: str | None = None + locality: str | None = None + admin_area: str | None = None + country: str | None = None + postal_code: str | None = None + nearby_places: list[NearbyPlace] = [] + place_types_summary: str | None = None + coordinates: Coordinates diff --git a/dimos/mapping/osm/README.md b/dimos/mapping/osm/README.md new file mode 100644 index 0000000000..cb94c0160b --- /dev/null +++ b/dimos/mapping/osm/README.md @@ -0,0 +1,43 @@ +# OpenStreetMap (OSM) + +This provides functionality to fetch and work with OpenStreetMap tiles, including coordinate conversions and location-based VLM queries. + +## Getting a MapImage + +```python +map_image = get_osm_map(LatLon(lat=..., lon=...), zoom_level=18, n_tiles=4)` +``` + +OSM tiles are 256x256 pixels so with 4 tiles you get a 1024x1024 map. + +You can translate pixel coordinates on the map to GPS location and back. + +```python +>>> map_image.pixel_to_latlon((300, 500)) +LatLon(lat=43.58571248, lon=12.23423511) +>>> map_image.latlon_to_pixel(LatLon(lat=43.58571248, lon=12.23423511)) +(300, 500) +``` + +## CurrentLocationMap + +This class maintains an appropriate context map for your current location so you can VLM queries. + +You have to update it with your current location and when you stray too far from the center it fetches a new map. + +```python +curr_map = CurrentLocationMap(QwenVlModel()) + +# Set your latest position. +curr_map.update_position(LatLon(lat=..., lon=...)) + +# If you want to get back a GPS position of a feature (Qwen gets your current position). +curr_map.query_for_one_position('Where is the closest farmacy?') +# Returns: +# LatLon(lat=..., lon=...) + +# If you also want to get back a description of the result. +curr_map.query_for_one_position_and_context('Where is the closest pharmacy?') +# Returns: +# (LatLon(lat=..., lon=...), "Lloyd's Pharmacy on Main Street") +``` diff --git a/dimos/manipulation/classical/classical_manipulation.py b/dimos/mapping/osm/__init__.py similarity index 100% rename from dimos/manipulation/classical/classical_manipulation.py rename to dimos/mapping/osm/__init__.py diff --git a/dimos/mapping/osm/current_location_map.py b/dimos/mapping/osm/current_location_map.py new file mode 100644 index 0000000000..cfd9f4281e --- /dev/null +++ b/dimos/mapping/osm/current_location_map.py @@ -0,0 +1,77 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dimos.mapping.osm.osm import MapImage, get_osm_map +from dimos.mapping.osm.query import query_for_one_position, query_for_one_position_and_context +from dimos.mapping.types import LatLon +from dimos.models.vl.base import VlModel +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class CurrentLocationMap: + _vl_model: VlModel + _position: LatLon | None + _map_image: MapImage | None + + def __init__(self, vl_model: VlModel) -> None: + self._vl_model = vl_model + self._position = None + self._map_image = None + self._zoom_level = 19 + self._n_tiles = 6 + # What ratio of the width is considered the center. 1.0 means the entire map is the center. + self._center_width = 0.4 + + def update_position(self, position: LatLon) -> None: + self._position = position + + def query_for_one_position(self, query: str) -> LatLon | None: + return query_for_one_position(self._vl_model, self._get_current_map(), query) # type: ignore[no-untyped-call] + + def query_for_one_position_and_context( + self, query: str, robot_position: LatLon + ) -> tuple[LatLon, str] | None: + return query_for_one_position_and_context( + self._vl_model, + self._get_current_map(), # type: ignore[no-untyped-call] + query, + robot_position, + ) + + def _get_current_map(self): # type: ignore[no-untyped-def] + if not self._position: + raise ValueError("Current position has not been set.") + + if not self._map_image or self._position_is_too_far_off_center(): + self._fetch_new_map() + return self._map_image + + return self._map_image + + def _fetch_new_map(self) -> None: + logger.info( + f"Getting a new OSM map, position={self._position}, zoom={self._zoom_level} n_tiles={self._n_tiles}" + ) + self._map_image = get_osm_map(self._position, self._zoom_level, self._n_tiles) # type: ignore[arg-type] + + def _position_is_too_far_off_center(self) -> bool: + x, y = self._map_image.latlon_to_pixel(self._position) # type: ignore[arg-type, union-attr] + width = self._map_image.image.width # type: ignore[union-attr] + size_min = width * (0.5 - self._center_width / 2) + size_max = width * (0.5 + self._center_width / 2) + + return x < size_min or x > size_max or y < size_min or y > size_max diff --git a/dimos/mapping/osm/demo_osm.py b/dimos/mapping/osm/demo_osm.py new file mode 100644 index 0000000000..2afb0b29a5 --- /dev/null +++ b/dimos/mapping/osm/demo_osm.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dotenv import load_dotenv + +from dimos.agents2.agent import llm_agent +from dimos.agents2.cli.human import human_input +from dimos.agents2.skills.demo_robot import demo_robot +from dimos.agents2.skills.osm import osm_skill +from dimos.agents2.system_prompt import get_system_prompt +from dimos.core.blueprints import autoconnect + +load_dotenv() + + +demo_osm = autoconnect( + demo_robot(), + osm_skill(), + human_input(), + llm_agent(system_prompt=get_system_prompt()), +) diff --git a/dimos/mapping/osm/osm.py b/dimos/mapping/osm/osm.py new file mode 100644 index 0000000000..aecfc8cf25 --- /dev/null +++ b/dimos/mapping/osm/osm.py @@ -0,0 +1,183 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass +import io +import math + +import numpy as np +from PIL import Image as PILImage +import requests + +from dimos.mapping.types import ImageCoord, LatLon +from dimos.msgs.sensor_msgs import Image, ImageFormat + + +@dataclass(frozen=True) +class MapImage: + image: Image + position: LatLon + zoom_level: int + n_tiles: int + + def pixel_to_latlon(self, position: ImageCoord) -> LatLon: + """Convert pixel coordinates to latitude/longitude. + + Args: + position: (x, y) pixel coordinates in the image + + Returns: + LatLon object with the corresponding latitude and longitude + """ + pixel_x, pixel_y = position + tile_size = 256 + + # Get the center tile coordinates + center_tile_x, center_tile_y = _lat_lon_to_tile( + self.position.lat, self.position.lon, self.zoom_level + ) + + # Calculate the actual top-left tile indices (integers) + start_tile_x = int(center_tile_x - self.n_tiles / 2.0) + start_tile_y = int(center_tile_y - self.n_tiles / 2.0) + + # Convert pixel position to exact tile coordinates + tile_x = start_tile_x + pixel_x / tile_size + tile_y = start_tile_y + pixel_y / tile_size + + # Convert tile coordinates to lat/lon + n = 2**self.zoom_level + lon = tile_x / n * 360.0 - 180.0 + lat_rad = math.atan(math.sinh(math.pi * (1 - 2 * tile_y / n))) + lat = math.degrees(lat_rad) + + return LatLon(lat=lat, lon=lon) + + def latlon_to_pixel(self, position: LatLon) -> ImageCoord: + """Convert latitude/longitude to pixel coordinates. + + Args: + position: LatLon object with latitude and longitude + + Returns: + (x, y) pixel coordinates in the image + Note: Can return negative values if position is outside the image bounds + """ + tile_size = 256 + + # Convert the input lat/lon to tile coordinates + tile_x, tile_y = _lat_lon_to_tile(position.lat, position.lon, self.zoom_level) + + # Get the center tile coordinates + center_tile_x, center_tile_y = _lat_lon_to_tile( + self.position.lat, self.position.lon, self.zoom_level + ) + + # Calculate the actual top-left tile indices (integers) + start_tile_x = int(center_tile_x - self.n_tiles / 2.0) + start_tile_y = int(center_tile_y - self.n_tiles / 2.0) + + # Calculate pixel position relative to top-left corner + pixel_x = int((tile_x - start_tile_x) * tile_size) + pixel_y = int((tile_y - start_tile_y) * tile_size) + + return (pixel_x, pixel_y) + + +def _lat_lon_to_tile(lat: float, lon: float, zoom: int) -> tuple[float, float]: + """Convert latitude/longitude to tile coordinates at given zoom level.""" + n = 2**zoom + x_tile = (lon + 180.0) / 360.0 * n + lat_rad = math.radians(lat) + y_tile = (1.0 - math.asinh(math.tan(lat_rad)) / math.pi) / 2.0 * n + return x_tile, y_tile + + +def _download_tile( + args: tuple[int, int, int, int, int], +) -> tuple[int, int, PILImage.Image | None]: + """Download a single tile. + + Args: + args: Tuple of (row, col, tile_x, tile_y, zoom_level) + + Returns: + Tuple of (row, col, tile_image or None if failed) + """ + row, col, tile_x, tile_y, zoom_level = args + url = f"https://tile.openstreetmap.org/{zoom_level}/{tile_x}/{tile_y}.png" + headers = {"User-Agent": "Dimos OSM Client/1.0"} + + try: + response = requests.get(url, headers=headers, timeout=10) + response.raise_for_status() + tile_img = PILImage.open(io.BytesIO(response.content)) + return row, col, tile_img + except Exception: + return row, col, None + + +def get_osm_map(position: LatLon, zoom_level: int = 18, n_tiles: int = 4) -> MapImage: + """ + Tiles are always 256x256 pixels. With n_tiles=4, this should produce a 1024x1024 image. + Downloads tiles in parallel with a maximum of 5 concurrent downloads. + + Args: + position (LatLon): center position + zoom_level (int, optional): Defaults to 18. + n_tiles (int, optional): generate a map of n_tiles by n_tiles. + """ + center_x, center_y = _lat_lon_to_tile(position.lat, position.lon, zoom_level) + + start_x = int(center_x - n_tiles / 2.0) + start_y = int(center_y - n_tiles / 2.0) + + tile_size = 256 + output_size = tile_size * n_tiles + output_img = PILImage.new("RGB", (output_size, output_size)) + + n_failed_tiles = 0 + + # Prepare all tile download tasks + download_tasks = [] + for row in range(n_tiles): + for col in range(n_tiles): + tile_x = start_x + col + tile_y = start_y + row + download_tasks.append((row, col, tile_x, tile_y, zoom_level)) + + # Download tiles in parallel with max 5 workers + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(_download_tile, task) for task in download_tasks] + + for future in as_completed(futures): + row, col, tile_img = future.result() + + if tile_img is not None: + paste_x = col * tile_size + paste_y = row * tile_size + output_img.paste(tile_img, (paste_x, paste_y)) + else: + n_failed_tiles += 1 + + if n_failed_tiles > 3: + raise ValueError("Failed to download all tiles for the requested map.") + + return MapImage( + image=Image.from_numpy(np.array(output_img), format=ImageFormat.RGB), + position=position, + zoom_level=zoom_level, + n_tiles=n_tiles, + ) diff --git a/dimos/mapping/osm/query.py b/dimos/mapping/osm/query.py new file mode 100644 index 0000000000..fd6e3694f6 --- /dev/null +++ b/dimos/mapping/osm/query.py @@ -0,0 +1,54 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + +from dimos.mapping.osm.osm import MapImage +from dimos.mapping.types import LatLon +from dimos.models.vl.base import VlModel +from dimos.utils.generic import extract_json_from_llm_response +from dimos.utils.logging_config import setup_logger + +_PROLOGUE = "This is an image of an open street map I'm on." +_JSON = "Please only respond with valid JSON." +logger = setup_logger() + + +def query_for_one_position(vl_model: VlModel, map_image: MapImage, query: str) -> LatLon | None: + full_query = f"{_PROLOGUE} {query} {_JSON} If there's a match return the x, y coordinates from the image. Example: `[123, 321]`. If there's no match return `null`." + response = vl_model.query(map_image.image.data, full_query) + coords = tuple(map(int, re.findall(r"\d+", response))) + if len(coords) != 2: + return None + return map_image.pixel_to_latlon(coords) + + +def query_for_one_position_and_context( + vl_model: VlModel, map_image: MapImage, query: str, robot_position: LatLon +) -> tuple[LatLon, str] | None: + example = '{"coordinates": [123, 321], "description": "A Starbucks on 27th Street"}' + x, y = map_image.latlon_to_pixel(robot_position) + my_location = f"I'm currently at x={x}, y={y}." + full_query = f"{_PROLOGUE} {my_location} {query} {_JSON} If there's a match return the x, y coordinates from the image and what is there. Example response: `{example}`. If there's no match return `null`." + logger.info(f"Qwen query: `{full_query}`") + response = vl_model.query(map_image.image.data, full_query) + + try: + doc = extract_json_from_llm_response(response) + return map_image.pixel_to_latlon(tuple(doc["coordinates"])), str(doc["description"]) + except Exception: + pass + + # TODO: Try more simplictic methods to parse. + return None diff --git a/dimos/mapping/osm/test_osm.py b/dimos/mapping/osm/test_osm.py new file mode 100644 index 0000000000..475e2b40fc --- /dev/null +++ b/dimos/mapping/osm/test_osm.py @@ -0,0 +1,71 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Generator +from typing import Any + +import cv2 +import numpy as np +import pytest +from requests import Request +import requests_mock + +from dimos.mapping.osm.osm import get_osm_map +from dimos.mapping.types import LatLon +from dimos.utils.data import get_data + +_fixture_dir = get_data("osm_map_test") + + +def _tile_callback(request: Request, context: Any) -> bytes: + parts = (request.url or "").split("/") + zoom, x, y_png = parts[-3], parts[-2], parts[-1] + y = y_png.removesuffix(".png") + tile_path = _fixture_dir / f"{zoom}_{x}_{y}.png" + context.headers["Content-Type"] = "image/png" + return tile_path.read_bytes() + + +@pytest.fixture +def mock_openstreetmap_org() -> Generator[None, None, None]: + with requests_mock.Mocker() as m: + m.get(requests_mock.ANY, content=_tile_callback) + yield + + +def test_get_osm_map(mock_openstreetmap_org: None) -> None: + position = LatLon(lat=37.751857, lon=-122.431265) + map_image = get_osm_map(position, 18, 4) + + assert map_image.position == position + assert map_image.n_tiles == 4 + + expected_image = cv2.imread(str(_fixture_dir / "full.png")) + expected_image_rgb = cv2.cvtColor(expected_image, cv2.COLOR_BGR2RGB) + assert np.array_equal(map_image.image.data, expected_image_rgb), "Map is not the same." + + +def test_pixel_to_latlon(mock_openstreetmap_org: None) -> None: + position = LatLon(lat=37.751857, lon=-122.431265) + map_image = get_osm_map(position, 18, 4) + latlon = map_image.pixel_to_latlon((100, 100)) + assert abs(latlon.lat - 37.7540056) < 0.0000001 + assert abs(latlon.lon - (-122.43385076)) < 0.0000001 + + +def test_latlon_to_pixel(mock_openstreetmap_org: None) -> None: + position = LatLon(lat=37.751857, lon=-122.431265) + map_image = get_osm_map(position, 18, 4) + coords = map_image.latlon_to_pixel(LatLon(lat=37.751, lon=-122.431)) + assert coords == (631, 808) diff --git a/dimos/mapping/types.py b/dimos/mapping/types.py new file mode 100644 index 0000000000..9584e8e8ba --- /dev/null +++ b/dimos/mapping/types.py @@ -0,0 +1,27 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dataclasses import dataclass +from typing import TypeAlias + + +@dataclass(frozen=True) +class LatLon: + lat: float + lon: float + alt: float | None = None + + +ImageCoord: TypeAlias = tuple[int, int] diff --git a/dimos/mapping/utils/distance.py b/dimos/mapping/utils/distance.py new file mode 100644 index 0000000000..6e8c48c205 --- /dev/null +++ b/dimos/mapping/utils/distance.py @@ -0,0 +1,48 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +from dimos.mapping.types import LatLon + + +def distance_in_meters(location1: LatLon, location2: LatLon) -> float: + """Calculate the great circle distance between two points on Earth using Haversine formula. + + Args: + location1: First location with latitude and longitude + location2: Second location with latitude and longitude + + Returns: + Distance in meters between the two points + """ + # Earth's radius in meters + EARTH_RADIUS_M = 6371000 + + # Convert degrees to radians + lat1_rad = math.radians(location1.lat) + lat2_rad = math.radians(location2.lat) + lon1_rad = math.radians(location1.lon) + lon2_rad = math.radians(location2.lon) + + # Haversine formula + dlat = lat2_rad - lat1_rad + dlon = lon2_rad - lon1_rad + + a = math.sin(dlat / 2) ** 2 + math.cos(lat1_rad) * math.cos(lat2_rad) * math.sin(dlon / 2) ** 2 + c = 2 * math.asin(math.sqrt(a)) + + distance = EARTH_RADIUS_M * c + + return distance diff --git a/dimos/models/Detic/.gitignore b/dimos/models/Detic/.gitignore new file mode 100644 index 0000000000..b794d988fb --- /dev/null +++ b/dimos/models/Detic/.gitignore @@ -0,0 +1,62 @@ +third_party/detectron2 +./models +configs-experimental +experiments +# output dir +index.html +data/* +slurm/ +slurm +slurm-output +slurm-output/ +output +instant_test_output +inference_test_output + + +*.png +*.diff +*.jpg +!/projects/DensePose/doc/images/*.jpg + +# compilation and distribution +__pycache__ +_ext +*.pyc +*.pyd +*.so +*.dll +*.egg-info/ +build/ +dist/ +wheels/ + +# pytorch/python/numpy formats +*.pth +*.pkl +*.ts +model_ts*.txt + +# ipython/jupyter notebooks +*.ipynb +**/.ipynb_checkpoints/ + +# Editor temporaries +*.swn +*.swo +*.swp +*~ + +# editor settings +.idea +.vscode +_darcs + +# project dirs +/detectron2/model_zoo/configs +/datasets/* +!/datasets/*.* +!/datasets/metadata +/projects/*/datasets +/models +/snippet diff --git a/dimos/models/Detic/.gitmodules b/dimos/models/Detic/.gitmodules new file mode 100644 index 0000000000..d945b4731e --- /dev/null +++ b/dimos/models/Detic/.gitmodules @@ -0,0 +1,6 @@ +[submodule "third_party/Deformable-DETR"] + path = third_party/Deformable-DETR + url = https://github.com/fundamentalvision/Deformable-DETR.git +[submodule "third_party/CenterNet2"] + path = third_party/CenterNet2 + url = https://github.com/xingyizhou/CenterNet2.git diff --git a/dimos/models/Detic/CODE_OF_CONDUCT.md b/dimos/models/Detic/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000..0f7ad8bfc1 --- /dev/null +++ b/dimos/models/Detic/CODE_OF_CONDUCT.md @@ -0,0 +1,5 @@ +# Code of Conduct + +Facebook has adopted a Code of Conduct that we expect project participants to adhere to. +Please read the [full text](https://code.fb.com/codeofconduct/) +so that you can understand what actions will and will not be tolerated. diff --git a/dimos/models/Detic/CONTRIBUTING.md b/dimos/models/Detic/CONTRIBUTING.md new file mode 100644 index 0000000000..282a20270b --- /dev/null +++ b/dimos/models/Detic/CONTRIBUTING.md @@ -0,0 +1,39 @@ +# Contributing to Detic +We want to make contributing to this project as easy and transparent as +possible. + +## Our Development Process +Minor changes and improvements will be released on an ongoing basis. Larger changes (e.g., changesets implementing a new paper) will be released on a more periodic basis. + +## Pull Requests +We actively welcome your pull requests. + +1. Fork the repo and create your branch from `main`. +2. If you've added code that should be tested, add tests. +3. If you've changed APIs, update the documentation. +4. Ensure the test suite passes. +5. Make sure your code lints. +6. If you haven't already, complete the Contributor License Agreement ("CLA"). + +## Contributor License Agreement ("CLA") +In order to accept your pull request, we need you to submit a CLA. You only need +to do this once to work on any of Facebook's open source projects. + +Complete your CLA here: + +## Issues +We use GitHub issues to track public bugs. Please ensure your description is +clear and has sufficient instructions to be able to reproduce the issue. + +Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe +disclosure of security bugs. In those cases, please go through the process +outlined on that page and do not file a public issue. + +## Coding Style +* 4 spaces for indentation rather than tabs +* 80 character line length +* PEP8 formatting following [Black](https://black.readthedocs.io/en/stable/) + +## License +By contributing to Detic, you agree that your contributions will be licensed +under the LICENSE file in the root directory of this source tree. diff --git a/dimos/models/Detic/LICENSE b/dimos/models/Detic/LICENSE new file mode 100644 index 0000000000..cd1b070674 --- /dev/null +++ b/dimos/models/Detic/LICENSE @@ -0,0 +1,202 @@ +Apache License +Version 2.0, January 2004 +http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + +"License" shall mean the terms and conditions for use, reproduction, +and distribution as defined by Sections 1 through 9 of this document. + +"Licensor" shall mean the copyright owner or entity authorized by +the copyright owner that is granting the License. + +"Legal Entity" shall mean the union of the acting entity and all +other entities that control, are controlled by, or are under common +control with that entity. For the purposes of this definition, +"control" means (i) the power, direct or indirect, to cause the +direction or management of such entity, whether by contract or +otherwise, or (ii) ownership of fifty percent (50%) or more of the +outstanding shares, or (iii) beneficial ownership of such entity. + +"You" (or "Your") shall mean an individual or Legal Entity +exercising permissions granted by this License. + +"Source" form shall mean the preferred form for making modifications, +including but not limited to software source code, documentation +source, and configuration files. + +"Object" form shall mean any form resulting from mechanical +transformation or translation of a Source form, including but +not limited to compiled object code, generated documentation, +and conversions to other media types. + +"Work" shall mean the work of authorship, whether in Source or +Object form, made available under the License, as indicated by a +copyright notice that is included in or attached to the work +(an example is provided in the Appendix below). + +"Derivative Works" shall mean any work, whether in Source or Object +form, that is based on (or derived from) the Work and for which the +editorial revisions, annotations, elaborations, or other modifications +represent, as a whole, an original work of authorship. For the purposes +of this License, Derivative Works shall not include works that remain +separable from, or merely link (or bind by name) to the interfaces of, +the Work and Derivative Works thereof. + +"Contribution" shall mean any work of authorship, including +the original version of the Work and any modifications or additions +to that Work or Derivative Works thereof, that is intentionally +submitted to Licensor for inclusion in the Work by the copyright owner +or by an individual or Legal Entity authorized to submit on behalf of +the copyright owner. For the purposes of this definition, "submitted" +means any form of electronic, verbal, or written communication sent +to the Licensor or its representatives, including but not limited to +communication on electronic mailing lists, source code control systems, +and issue tracking systems that are managed by, or on behalf of, the +Licensor for the purpose of discussing and improving the Work, but +excluding communication that is conspicuously marked or otherwise +designated in writing by the copyright owner as "Not a Contribution." + +"Contributor" shall mean Licensor and any individual or Legal Entity +on behalf of whom a Contribution has been received by Licensor and +subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of +this License, each Contributor hereby grants to You a perpetual, +worldwide, non-exclusive, no-charge, royalty-free, irrevocable +copyright license to reproduce, prepare Derivative Works of, +publicly display, publicly perform, sublicense, and distribute the +Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of +this License, each Contributor hereby grants to You a perpetual, +worldwide, non-exclusive, no-charge, royalty-free, irrevocable +(except as stated in this section) patent license to make, have made, +use, offer to sell, sell, import, and otherwise transfer the Work, +where such license applies only to those patent claims licensable +by such Contributor that are necessarily infringed by their +Contribution(s) alone or by combination of their Contribution(s) +with the Work to which such Contribution(s) was submitted. If You +institute patent litigation against any entity (including a +cross-claim or counterclaim in a lawsuit) alleging that the Work +or a Contribution incorporated within the Work constitutes direct +or contributory patent infringement, then any patent licenses +granted to You under this License for that Work shall terminate +as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the +Work or Derivative Works thereof in any medium, with or without +modifications, and in Source or Object form, provided that You +meet the following conditions: + +(a) You must give any other recipients of the Work or +Derivative Works a copy of this License; and + +(b) You must cause any modified files to carry prominent notices +stating that You changed the files; and + +(c) You must retain, in the Source form of any Derivative Works +that You distribute, all copyright, patent, trademark, and +attribution notices from the Source form of the Work, +excluding those notices that do not pertain to any part of +the Derivative Works; and + +(d) If the Work includes a "NOTICE" text file as part of its +distribution, then any Derivative Works that You distribute must +include a readable copy of the attribution notices contained +within such NOTICE file, excluding those notices that do not +pertain to any part of the Derivative Works, in at least one +of the following places: within a NOTICE text file distributed +as part of the Derivative Works; within the Source form or +documentation, if provided along with the Derivative Works; or, +within a display generated by the Derivative Works, if and +wherever such third-party notices normally appear. The contents +of the NOTICE file are for informational purposes only and +do not modify the License. You may add Your own attribution +notices within Derivative Works that You distribute, alongside +or as an addendum to the NOTICE text from the Work, provided +that such additional attribution notices cannot be construed +as modifying the License. + +You may add Your own copyright statement to Your modifications and +may provide additional or different license terms and conditions +for use, reproduction, or distribution of Your modifications, or +for any such Derivative Works as a whole, provided Your use, +reproduction, and distribution of the Work otherwise complies with +the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, +any Contribution intentionally submitted for inclusion in the Work +by You to the Licensor shall be under the terms and conditions of +this License, without any additional terms or conditions. +Notwithstanding the above, nothing herein shall supersede or modify +the terms of any separate license agreement you may have executed +with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade +names, trademarks, service marks, or product names of the Licensor, +except as required for reasonable and customary use in describing the +origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or +agreed to in writing, Licensor provides the Work (and each +Contributor provides its Contributions) on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +implied, including, without limitation, any warranties or conditions +of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A +PARTICULAR PURPOSE. You are solely responsible for determining the +appropriateness of using or redistributing the Work and assume any +risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, +whether in tort (including negligence), contract, or otherwise, +unless required by applicable law (such as deliberate and grossly +negligent acts) or agreed to in writing, shall any Contributor be +liable to You for damages, including any direct, indirect, special, +incidental, or consequential damages of any character arising as a +result of this License or out of the use or inability to use the +Work (including but not limited to damages for loss of goodwill, +work stoppage, computer failure or malfunction, or any and all +other commercial damages or losses), even if such Contributor +has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing +the Work or Derivative Works thereof, You may choose to offer, +and charge a fee for, acceptance of support, warranty, indemnity, +or other liability obligations and/or rights consistent with this +License. However, in accepting such obligations, You may act only +on Your own behalf and on Your sole responsibility, not on behalf +of any other Contributor, and only if You agree to indemnify, +defend, and hold each Contributor harmless for any liability +incurred by, or claims asserted against, such Contributor by reason +of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + +To apply the Apache License to your work, attach the following +boilerplate notice, with the fields enclosed by brackets "[]" +replaced with your own identifying information. (Don't include +the brackets!) The text should be enclosed in the appropriate +comment syntax for the file format. We also recommend that a +file or class name and description of purpose be included on the +same "printed page" as the copyright notice for easier +identification within third-party archives. + +Copyright [yyyy] [name of copyright owner] + + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/dimos/models/Detic/README.md b/dimos/models/Detic/README.md new file mode 100644 index 0000000000..3a1285cbc9 --- /dev/null +++ b/dimos/models/Detic/README.md @@ -0,0 +1,116 @@ +# Detecting Twenty-thousand Classes using Image-level Supervision + +**Detic**: A **Det**ector with **i**mage **c**lasses that can use image-level labels to easily train detectors. + +

+ +> [**Detecting Twenty-thousand Classes using Image-level Supervision**](http://arxiv.org/abs/2201.02605), +> Xingyi Zhou, Rohit Girdhar, Armand Joulin, Philipp Krähenbühl, Ishan Misra, +> *ECCV 2022 ([arXiv 2201.02605](http://arxiv.org/abs/2201.02605))* + + +## Features + +- Detects **any** class given class names (using [CLIP](https://github.com/openai/CLIP)). + +- We train the detector on ImageNet-21K dataset with 21K classes. + +- Cross-dataset generalization to OpenImages and Objects365 **without finetuning**. + +- State-of-the-art results on Open-vocabulary LVIS and Open-vocabulary COCO. + +- Works for DETR-style detectors. + + +## Installation + +See [installation instructions](docs/INSTALL.md). + +## Demo + +**Update April 2022**: we released more real-time models [here](docs/MODEL_ZOO.md#real-time-models). + +Replicate web demo and docker image: [![Replicate](https://replicate.com/facebookresearch/detic/badge)](https://replicate.com/facebookresearch/detic) + + +Integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio). Try out the web demo: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/akhaliq/Detic) + +Run our demo using Colab (no GPU needed): [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1QtTW9-ukX2HKZGvt0QvVGqjuqEykoZKI) + +We use the default detectron2 [demo interface](https://github.com/facebookresearch/detectron2/blob/main/GETTING_STARTED.md). +For example, to run our [21K model](docs/MODEL_ZOO.md#cross-dataset-evaluation) on a [messy desk image](https://web.eecs.umich.edu/~fouhey/fun/desk/desk.jpg) (image credit [David Fouhey](https://web.eecs.umich.edu/~fouhey)) with the lvis vocabulary, run + +~~~ +mkdir models +wget https://dl.fbaipublicfiles.com/detic/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.pth -O models/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.pth +wget https://eecs.engin.umich.edu/~fouhey/fun/desk/desk.jpg +python demo.py --config-file configs/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.yaml --input desk.jpg --output out.jpg --vocabulary lvis --opts MODEL.WEIGHTS models/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.pth +~~~ + +If setup correctly, the output should look like: + +

+ +The same model can run with other vocabularies (COCO, OpenImages, or Objects365), or a **custom vocabulary**. For example: + +~~~ +python demo.py --config-file configs/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.yaml --input desk.jpg --output out2.jpg --vocabulary custom --custom_vocabulary headphone,webcam,paper,coffe --confidence-threshold 0.3 --opts MODEL.WEIGHTS models/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.pth +~~~ + +The output should look like: + +

+ +Note that `headphone`, `paper` and `coffe` (typo intended) are **not** LVIS classes. Despite the misspelled class name, our detector can produce a reasonable detection for `coffe`. + +## Benchmark evaluation and training + +Please first [prepare datasets](datasets/README.md), then check our [MODEL ZOO](docs/MODEL_ZOO.md) to reproduce results in our paper. We highlight key results below: + +- Open-vocabulary LVIS + + | | mask mAP | mask mAP_novel | + |-----------------------|-----------|-----------------| + |Box-Supervised | 30.2 | 16.4 | + |Detic | 32.4 | 24.9 | + +- Standard LVIS + + | | Detector/ Backbone | mask mAP | mask mAP_rare | + |-----------------------|----------|-----------|-----------------| + |Box-Supervised | CenterNet2-ResNet50 | 31.5 | 25.6 | + |Detic | CenterNet2-ResNet50 | 33.2 | 29.7 | + |Box-Supervised | CenterNet2-SwinB | 40.7 | 35.9 | + |Detic | CenterNet2-SwinB | 41.7 | 41.7 | + + | | Detector/ Backbone | box mAP | box mAP_rare | + |-----------------------|----------|-----------|-----------------| + |Box-Supervised | DeformableDETR-ResNet50 | 31.7 | 21.4 | + |Detic | DeformableDETR-ResNet50 | 32.5 | 26.2 | + +- Cross-dataset generalization + + | | Backbone | Objects365 box mAP | OpenImages box mAP50 | + |-----------------------|----------|-----------|-----------------| + |Box-Supervised | SwinB | 19.1 | 46.2 | + |Detic | SwinB | 21.4 | 55.2 | + + +## License + +The majority of Detic is licensed under the [Apache 2.0 license](LICENSE), however portions of the project are available under separate license terms: SWIN-Transformer, CLIP, and TensorFlow Object Detection API are licensed under the MIT license; UniDet is licensed under the Apache 2.0 license; and the LVIS API is licensed under a [custom license](https://github.com/lvis-dataset/lvis-api/blob/master/LICENSE). If you later add other third party code, please keep this license info updated, and please let us know if that component is licensed under something other than CC-BY-NC, MIT, or CC0 + +## Ethical Considerations +Detic's wide range of detection capabilities may introduce similar challenges to many other visual recognition and open-set recognition methods. +As the user can define arbitrary detection classes, class design and semantics may impact the model output. + +## Citation + +If you find this project useful for your research, please use the following BibTeX entry. + + @inproceedings{zhou2022detecting, + title={Detecting Twenty-thousand Classes using Image-level Supervision}, + author={Zhou, Xingyi and Girdhar, Rohit and Joulin, Armand and Kr{\"a}henb{\"u}hl, Philipp and Misra, Ishan}, + booktitle={ECCV}, + year={2022} + } diff --git a/dimos/models/Detic/cog.yaml b/dimos/models/Detic/cog.yaml new file mode 100644 index 0000000000..3c8a94941e --- /dev/null +++ b/dimos/models/Detic/cog.yaml @@ -0,0 +1,28 @@ +build: + gpu: true + cuda: "10.1" + python_version: "3.8" + system_packages: + - "libgl1-mesa-glx" + - "libglib2.0-0" + python_packages: + - "ipython==7.30.1" + - "numpy==1.21.4" + - "torch==1.8.1" + - "torchvision==0.9.1" + - "dataclasses==0.6" + - "opencv-python==4.5.5.62" + - "imageio==2.9.0" + - "ftfy==6.0.3" + - "regex==2021.10.8" + - "tqdm==4.62.3" + - "timm==0.4.12" + - "fasttext==0.9.2" + - "scikit-learn==1.0.2" + - "lvis==0.5.3" + - "nltk==3.6.7" + - "git+https://github.com/openai/CLIP.git" + run: + - pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu101/torch1.8/index.html + +predict: "predict.py:Predictor" diff --git a/dimos/models/Detic/configs/Base-C2_L_R5021k_640b64_4x.yaml b/dimos/models/Detic/configs/Base-C2_L_R5021k_640b64_4x.yaml new file mode 100644 index 0000000000..eb3c3c0f3b --- /dev/null +++ b/dimos/models/Detic/configs/Base-C2_L_R5021k_640b64_4x.yaml @@ -0,0 +1,82 @@ +MODEL: + META_ARCHITECTURE: "CustomRCNN" + MASK_ON: True + PROPOSAL_GENERATOR: + NAME: "CenterNet" + WEIGHTS: "models/resnet50_miil_21k.pkl" + BACKBONE: + NAME: build_p67_timm_fpn_backbone + TIMM: + BASE_NAME: resnet50_in21k + FPN: + IN_FEATURES: ["layer3", "layer4", "layer5"] + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.12, 57.375] + ROI_HEADS: + NAME: DeticCascadeROIHeads + IN_FEATURES: ["p3", "p4", "p5"] + IOU_THRESHOLDS: [0.6] + NUM_CLASSES: 1203 + SCORE_THRESH_TEST: 0.02 + NMS_THRESH_TEST: 0.5 + ROI_BOX_CASCADE_HEAD: + IOUS: [0.6, 0.7, 0.8] + ROI_BOX_HEAD: + NAME: "FastRCNNConvFCHead" + NUM_FC: 2 + POOLER_RESOLUTION: 7 + CLS_AGNOSTIC_BBOX_REG: True + MULT_PROPOSAL_SCORE: True + + USE_SIGMOID_CE: True + USE_FED_LOSS: True + ROI_MASK_HEAD: + NAME: "MaskRCNNConvUpsampleHead" + NUM_CONV: 4 + POOLER_RESOLUTION: 14 + CLS_AGNOSTIC_MASK: True + CENTERNET: + NUM_CLASSES: 1203 + REG_WEIGHT: 1. + NOT_NORM_REG: True + ONLY_PROPOSAL: True + WITH_AGN_HM: True + INFERENCE_TH: 0.0001 + PRE_NMS_TOPK_TRAIN: 4000 + POST_NMS_TOPK_TRAIN: 2000 + PRE_NMS_TOPK_TEST: 1000 + POST_NMS_TOPK_TEST: 256 + NMS_TH_TRAIN: 0.9 + NMS_TH_TEST: 0.9 + POS_WEIGHT: 0.5 + NEG_WEIGHT: 0.5 + IGNORE_HIGH_FP: 0.85 +DATASETS: + TRAIN: ("lvis_v1_train",) + TEST: ("lvis_v1_val",) +DATALOADER: + SAMPLER_TRAIN: "RepeatFactorTrainingSampler" + REPEAT_THRESHOLD: 0.001 + NUM_WORKERS: 8 +TEST: + DETECTIONS_PER_IMAGE: 300 +SOLVER: + LR_SCHEDULER_NAME: "WarmupCosineLR" + CHECKPOINT_PERIOD: 1000000000 + WARMUP_ITERS: 10000 + WARMUP_FACTOR: 0.0001 + USE_CUSTOM_SOLVER: True + OPTIMIZER: "ADAMW" + MAX_ITER: 90000 + IMS_PER_BATCH: 64 + BASE_LR: 0.0002 + CLIP_GRADIENTS: + ENABLED: True +INPUT: + FORMAT: RGB + CUSTOM_AUG: EfficientDetResizeCrop + TRAIN_SIZE: 640 +OUTPUT_DIR: "./output/Detic/auto" +EVAL_PROPOSAL_AR: False +VERSION: 2 +FP16: True \ No newline at end of file diff --git a/dimos/models/Detic/configs/Base-DeformDETR_L_R50_4x.yaml b/dimos/models/Detic/configs/Base-DeformDETR_L_R50_4x.yaml new file mode 100644 index 0000000000..a689ee5bf3 --- /dev/null +++ b/dimos/models/Detic/configs/Base-DeformDETR_L_R50_4x.yaml @@ -0,0 +1,59 @@ +MODEL: + META_ARCHITECTURE: "DeformableDetr" + WEIGHTS: "detectron2://ImageNetPretrained/torchvision/R-50.pkl" + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.120, 57.375] + MASK_ON: False + RESNETS: + DEPTH: 50 + STRIDE_IN_1X1: False + OUT_FEATURES: ["res3", "res4", "res5"] + DETR: + CLS_WEIGHT: 2.0 + GIOU_WEIGHT: 2.0 + L1_WEIGHT: 5.0 + NUM_OBJECT_QUERIES: 300 + DIM_FEEDFORWARD: 1024 + WITH_BOX_REFINE: True + TWO_STAGE: True + NUM_CLASSES: 1203 + USE_FED_LOSS: True +DATASETS: + TRAIN: ("lvis_v1_train",) + TEST: ("lvis_v1_val",) +SOLVER: + CHECKPOINT_PERIOD: 10000000 + USE_CUSTOM_SOLVER: True + IMS_PER_BATCH: 32 + BASE_LR: 0.0002 + STEPS: (150000,) + MAX_ITER: 180000 + WARMUP_FACTOR: 1.0 + WARMUP_ITERS: 10 + WEIGHT_DECAY: 0.0001 + OPTIMIZER: "ADAMW" + BACKBONE_MULTIPLIER: 0.1 + CLIP_GRADIENTS: + ENABLED: True + CLIP_TYPE: "full_model" + CLIP_VALUE: 0.01 + NORM_TYPE: 2.0 + CUSTOM_MULTIPLIER: 0.1 + CUSTOM_MULTIPLIER_NAME: ['reference_points', 'sampling_offsets'] +INPUT: + FORMAT: "RGB" + MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800) + CROP: + ENABLED: True + TYPE: "absolute_range" + SIZE: (384, 600) + CUSTOM_AUG: "DETR" +TEST: + DETECTIONS_PER_IMAGE: 300 +DATALOADER: + FILTER_EMPTY_ANNOTATIONS: False + NUM_WORKERS: 4 + SAMPLER_TRAIN: "RepeatFactorTrainingSampler" + REPEAT_THRESHOLD: 0.001 +OUTPUT_DIR: "output/Detic/auto" +VERSION: 2 \ No newline at end of file diff --git a/dimos/models/Detic/configs/Base_OVCOCO_C4_1x.yaml b/dimos/models/Detic/configs/Base_OVCOCO_C4_1x.yaml new file mode 100644 index 0000000000..189d03cf58 --- /dev/null +++ b/dimos/models/Detic/configs/Base_OVCOCO_C4_1x.yaml @@ -0,0 +1,31 @@ +MODEL: + META_ARCHITECTURE: "CustomRCNN" + RPN: + PRE_NMS_TOPK_TEST: 6000 + POST_NMS_TOPK_TEST: 1000 + ROI_HEADS: + NAME: "CustomRes5ROIHeads" + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + RESNETS: + DEPTH: 50 + ROI_BOX_HEAD: + CLS_AGNOSTIC_BBOX_REG: True + USE_SIGMOID_CE: True + USE_ZEROSHOT_CLS: True + ZEROSHOT_WEIGHT_PATH: 'datasets/metadata/coco_clip_a+cname.npy' + IGNORE_ZERO_CATS: True + CAT_FREQ_PATH: 'datasets/coco/zero-shot/instances_train2017_seen_2_oriorder_cat_info.json' +DATASETS: + TRAIN: ("coco_zeroshot_train_oriorder",) + TEST: ("coco_generalized_zeroshot_val",) +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.02 + STEPS: (60000, 80000) + MAX_ITER: 90000 + CHECKPOINT_PERIOD: 1000000000 +INPUT: + MIN_SIZE_TRAIN: (800,) +VERSION: 2 +OUTPUT_DIR: output/Detic-COCO/auto +FP16: True \ No newline at end of file diff --git a/dimos/models/Detic/configs/BoxSup-C2_LCOCO_CLIP_CXT21k_640b32_4x.yaml b/dimos/models/Detic/configs/BoxSup-C2_LCOCO_CLIP_CXT21k_640b32_4x.yaml new file mode 100644 index 0000000000..7064a02100 --- /dev/null +++ b/dimos/models/Detic/configs/BoxSup-C2_LCOCO_CLIP_CXT21k_640b32_4x.yaml @@ -0,0 +1,17 @@ +_BASE_: "Base-C2_L_R5021k_640b64_4x.yaml" +MODEL: + ROI_BOX_HEAD: + USE_ZEROSHOT_CLS: True + WEIGHTS: '' + TIMM: + BASE_NAME: convnext_tiny_21k + OUT_LEVELS: [2, 3, 4] + PRETRAINED: True + FPN: + IN_FEATURES: ["layer2", "layer3", "layer4"] +SOLVER: + MAX_ITER: 180000 + IMS_PER_BATCH: 32 + BASE_LR: 0.0001 +DATASETS: + TRAIN: ("lvis_v1_train+coco",) \ No newline at end of file diff --git a/dimos/models/Detic/configs/BoxSup-C2_LCOCO_CLIP_R18_640b32_4x.yaml b/dimos/models/Detic/configs/BoxSup-C2_LCOCO_CLIP_R18_640b32_4x.yaml new file mode 100644 index 0000000000..07535ee960 --- /dev/null +++ b/dimos/models/Detic/configs/BoxSup-C2_LCOCO_CLIP_R18_640b32_4x.yaml @@ -0,0 +1,14 @@ +_BASE_: "Base-C2_L_R5021k_640b64_4x.yaml" +MODEL: + ROI_BOX_HEAD: + USE_ZEROSHOT_CLS: True + WEIGHTS: '' + TIMM: + BASE_NAME: resnet18 + PRETRAINED: True +SOLVER: + MAX_ITER: 180000 + IMS_PER_BATCH: 32 + BASE_LR: 0.0001 +DATASETS: + TRAIN: ("lvis_v1_train+coco",) \ No newline at end of file diff --git a/dimos/models/Detic/configs/BoxSup-C2_LCOCO_CLIP_R5021k_640b64_4x.yaml b/dimos/models/Detic/configs/BoxSup-C2_LCOCO_CLIP_R5021k_640b64_4x.yaml new file mode 100644 index 0000000000..8b5ae72d95 --- /dev/null +++ b/dimos/models/Detic/configs/BoxSup-C2_LCOCO_CLIP_R5021k_640b64_4x.yaml @@ -0,0 +1,6 @@ +_BASE_: "Base-C2_L_R5021k_640b64_4x.yaml" +MODEL: + ROI_BOX_HEAD: + USE_ZEROSHOT_CLS: True +DATASETS: + TRAIN: ("lvis_v1_train+coco",) \ No newline at end of file diff --git a/dimos/models/Detic/configs/BoxSup-C2_LCOCO_CLIP_SwinB_896b32_4x.yaml b/dimos/models/Detic/configs/BoxSup-C2_LCOCO_CLIP_SwinB_896b32_4x.yaml new file mode 100644 index 0000000000..39ee45ac96 --- /dev/null +++ b/dimos/models/Detic/configs/BoxSup-C2_LCOCO_CLIP_SwinB_896b32_4x.yaml @@ -0,0 +1,19 @@ +_BASE_: "Base-C2_L_R5021k_640b64_4x.yaml" +MODEL: + ROI_BOX_HEAD: + USE_ZEROSHOT_CLS: True + WEIGHTS: "models/swin_base_patch4_window7_224_22k.pkl" + BACKBONE: + NAME: build_swintransformer_fpn_backbone + SWIN: + SIZE: B-22k + FPN: + IN_FEATURES: ["swin1", "swin2", "swin3"] +SOLVER: + MAX_ITER: 180000 + IMS_PER_BATCH: 32 + BASE_LR: 0.0001 +INPUT: + TRAIN_SIZE: 896 +DATASETS: + TRAIN: ("lvis_v1_train+coco",) \ No newline at end of file diff --git a/dimos/models/Detic/configs/BoxSup-C2_L_CLIP_R5021k_640b64_4x.yaml b/dimos/models/Detic/configs/BoxSup-C2_L_CLIP_R5021k_640b64_4x.yaml new file mode 100644 index 0000000000..91a25ee2ad --- /dev/null +++ b/dimos/models/Detic/configs/BoxSup-C2_L_CLIP_R5021k_640b64_4x.yaml @@ -0,0 +1,4 @@ +_BASE_: "Base-C2_L_R5021k_640b64_4x.yaml" +MODEL: + ROI_BOX_HEAD: + USE_ZEROSHOT_CLS: True \ No newline at end of file diff --git a/dimos/models/Detic/configs/BoxSup-C2_L_CLIP_SwinB_896b32_4x.yaml b/dimos/models/Detic/configs/BoxSup-C2_L_CLIP_SwinB_896b32_4x.yaml new file mode 100644 index 0000000000..bf6e93a830 --- /dev/null +++ b/dimos/models/Detic/configs/BoxSup-C2_L_CLIP_SwinB_896b32_4x.yaml @@ -0,0 +1,17 @@ +_BASE_: "Base-C2_L_R5021k_640b64_4x.yaml" +MODEL: + ROI_BOX_HEAD: + USE_ZEROSHOT_CLS: True + WEIGHTS: "models/swin_base_patch4_window7_224_22k.pkl" + BACKBONE: + NAME: build_swintransformer_fpn_backbone + SWIN: + SIZE: B-22k + FPN: + IN_FEATURES: ["swin1", "swin2", "swin3"] +SOLVER: + MAX_ITER: 180000 + IMS_PER_BATCH: 32 + BASE_LR: 0.0001 +INPUT: + TRAIN_SIZE: 896 \ No newline at end of file diff --git a/dimos/models/Detic/configs/BoxSup-C2_Lbase_CLIP_R5021k_640b64_4x.yaml b/dimos/models/Detic/configs/BoxSup-C2_Lbase_CLIP_R5021k_640b64_4x.yaml new file mode 100644 index 0000000000..a4d73a060f --- /dev/null +++ b/dimos/models/Detic/configs/BoxSup-C2_Lbase_CLIP_R5021k_640b64_4x.yaml @@ -0,0 +1,6 @@ +_BASE_: "Base-C2_L_R5021k_640b64_4x.yaml" +MODEL: + ROI_BOX_HEAD: + USE_ZEROSHOT_CLS: True +DATASETS: + TRAIN: ("lvis_v1_train_norare",) \ No newline at end of file diff --git a/dimos/models/Detic/configs/BoxSup-C2_Lbase_CLIP_SwinB_896b32_4x.yaml b/dimos/models/Detic/configs/BoxSup-C2_Lbase_CLIP_SwinB_896b32_4x.yaml new file mode 100644 index 0000000000..f271ac558c --- /dev/null +++ b/dimos/models/Detic/configs/BoxSup-C2_Lbase_CLIP_SwinB_896b32_4x.yaml @@ -0,0 +1,19 @@ +_BASE_: "Base-C2_L_R5021k_640b64_4x.yaml" +MODEL: + ROI_BOX_HEAD: + USE_ZEROSHOT_CLS: True + WEIGHTS: "models/swin_base_patch4_window7_224_22k.pkl" + BACKBONE: + NAME: build_swintransformer_fpn_backbone + SWIN: + SIZE: B-22k + FPN: + IN_FEATURES: ["swin1", "swin2", "swin3"] +SOLVER: + MAX_ITER: 180000 + IMS_PER_BATCH: 32 + BASE_LR: 0.0001 +INPUT: + TRAIN_SIZE: 896 +DATASETS: + TRAIN: ("lvis_v1_train_norare",) \ No newline at end of file diff --git a/dimos/models/Detic/configs/BoxSup-DeformDETR_L_R50_2x.yaml b/dimos/models/Detic/configs/BoxSup-DeformDETR_L_R50_2x.yaml new file mode 100644 index 0000000000..aed66e1fba --- /dev/null +++ b/dimos/models/Detic/configs/BoxSup-DeformDETR_L_R50_2x.yaml @@ -0,0 +1,3 @@ +_BASE_: "Base-DeformDETR_L_R50_4x.yaml" +SOLVER: + IMS_PER_BATCH: 16 \ No newline at end of file diff --git a/dimos/models/Detic/configs/BoxSup-DeformDETR_L_R50_4x.yaml b/dimos/models/Detic/configs/BoxSup-DeformDETR_L_R50_4x.yaml new file mode 100644 index 0000000000..a5ee4566ff --- /dev/null +++ b/dimos/models/Detic/configs/BoxSup-DeformDETR_L_R50_4x.yaml @@ -0,0 +1 @@ +_BASE_: "Base-DeformDETR_L_R50_4x.yaml" \ No newline at end of file diff --git a/dimos/models/Detic/configs/BoxSup_OVCOCO_CLIP_R50_1x.yaml b/dimos/models/Detic/configs/BoxSup_OVCOCO_CLIP_R50_1x.yaml new file mode 100644 index 0000000000..b6c977fbac --- /dev/null +++ b/dimos/models/Detic/configs/BoxSup_OVCOCO_CLIP_R50_1x.yaml @@ -0,0 +1 @@ +_BASE_: "Base_OVCOCO_C4_1x.yaml" diff --git a/dimos/models/Detic/configs/BoxSup_ViLD_200e.py b/dimos/models/Detic/configs/BoxSup_ViLD_200e.py new file mode 100644 index 0000000000..b189c7b54f --- /dev/null +++ b/dimos/models/Detic/configs/BoxSup_ViLD_200e.py @@ -0,0 +1,108 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import os + +from detectron2.config import LazyCall as L +from detectron2.data.samplers import RepeatFactorTrainingSampler +import detectron2.data.transforms as T +from detectron2.evaluation.lvis_evaluation import LVISEvaluator +from detectron2.layers import ShapeSpec +from detectron2.layers.batch_norm import NaiveSyncBatchNorm +from detectron2.model_zoo import get_config +from detectron2.modeling.box_regression import Box2BoxTransform +from detectron2.modeling.matcher import Matcher +from detectron2.modeling.roi_heads import FastRCNNConvFCHead +from detectron2.solver import WarmupParamScheduler +from detectron2.solver.build import get_default_optimizer_params +from detic.modeling.roi_heads.detic_fast_rcnn import DeticFastRCNNOutputLayers +from detic.modeling.roi_heads.detic_roi_heads import DeticCascadeROIHeads +from detic.modeling.roi_heads.zero_shot_classifier import ZeroShotClassifier +from fvcore.common.param_scheduler import CosineParamScheduler +import torch + +default_configs = get_config("new_baselines/mask_rcnn_R_50_FPN_100ep_LSJ.py") +dataloader = default_configs["dataloader"] +model = default_configs["model"] +train = default_configs["train"] + +[model.roi_heads.pop(k) for k in ["box_head", "box_predictor", "proposal_matcher"]] + +model.roi_heads.update( + _target_=DeticCascadeROIHeads, + num_classes=1203, + box_heads=[ + L(FastRCNNConvFCHead)( + input_shape=ShapeSpec(channels=256, height=7, width=7), + conv_dims=[256, 256, 256, 256], + fc_dims=[1024], + conv_norm=lambda c: NaiveSyncBatchNorm(c, stats_mode="N"), + ) + for _ in range(1) + ], + box_predictors=[ + L(DeticFastRCNNOutputLayers)( + input_shape=ShapeSpec(channels=1024), + test_score_thresh=0.0001, + test_topk_per_image=300, + box2box_transform=L(Box2BoxTransform)(weights=(w1, w1, w2, w2)), + cls_agnostic_bbox_reg=True, + num_classes="${...num_classes}", + cls_score=L(ZeroShotClassifier)( + input_shape=ShapeSpec(channels=1024), + num_classes=1203, + zs_weight_path="datasets/metadata/lvis_v1_clip_a+cname.npy", + norm_weight=True, + # use_bias=-4.6, + ), + use_zeroshot_cls=True, + use_sigmoid_ce=True, + ignore_zero_cats=True, + cat_freq_path="datasets/lvis/lvis_v1_train_norare_cat_info.json", + ) + for (w1, w2) in [(10, 5)] + ], + proposal_matchers=[ + L(Matcher)(thresholds=[th], labels=[0, 1], allow_low_quality_matches=False) for th in [0.5] + ], +) +model.roi_heads.mask_head.num_classes = 1 + +dataloader.train.dataset.names = "lvis_v1_train_norare" +dataloader.train.sampler = L(RepeatFactorTrainingSampler)( + repeat_factors=L(RepeatFactorTrainingSampler.repeat_factors_from_category_frequency)( + dataset_dicts="${dataloader.train.dataset}", repeat_thresh=0.001 + ) +) +image_size = 896 +dataloader.train.mapper.augmentations = [ + L(T.ResizeScale)( + min_scale=0.1, max_scale=2.0, target_height=image_size, target_width=image_size + ), + L(T.FixedSizeCrop)(crop_size=(image_size, image_size)), + L(T.RandomFlip)(horizontal=True), +] +dataloader.train.num_workers = 32 + +dataloader.test.dataset.names = "lvis_v1_val" +dataloader.evaluator = L(LVISEvaluator)( + dataset_name="${..test.dataset.names}", +) + +num_nodes = 4 + +dataloader.train.total_batch_size = 64 * num_nodes +train.max_iter = 184375 * 2 // num_nodes + +lr_multiplier = L(WarmupParamScheduler)( + scheduler=CosineParamScheduler(1.0, 0.0), + warmup_length=500 / train.max_iter, + warmup_factor=0.067, +) + +optimizer = L(torch.optim.AdamW)( + params=L(get_default_optimizer_params)(weight_decay_norm=0.0), + lr=0.0002 * num_nodes, + weight_decay=1e-4, +) + +train.checkpointer.period = 20000 // num_nodes +train.output_dir = f"./output/Lazy/{os.path.basename(__file__)[:-3]}" diff --git a/dimos/models/Detic/configs/Detic_DeformDETR_LI_R50_4x_ft4x.yaml b/dimos/models/Detic/configs/Detic_DeformDETR_LI_R50_4x_ft4x.yaml new file mode 100644 index 0000000000..2da679cd4a --- /dev/null +++ b/dimos/models/Detic/configs/Detic_DeformDETR_LI_R50_4x_ft4x.yaml @@ -0,0 +1,22 @@ +_BASE_: "Base-DeformDETR_L_R50_4x.yaml" +MODEL: + WEIGHTS: "models/BoxSup-DeformDETR_L_R50_4x.pth" +INPUT: + CUSTOM_AUG: ResizeShortestEdge + MIN_SIZE_TRAIN_SAMPLING: range + MIN_SIZE_TRAIN: [480, 800] +DATASETS: + TRAIN: ("lvis_v1_train","imagenet_lvis_v1") + TEST: ("lvis_v1_val",) +DATALOADER: + SAMPLER_TRAIN: "MultiDatasetSampler" + DATASET_RATIO: [1, 4] + USE_DIFF_BS_SIZE: True + DATASET_BS: [4, 16] + USE_RFS: [True, False] + DATASET_MIN_SIZES: [[480, 800], [240, 400]] + DATASET_MAX_SIZES: [1333, 667] + FILTER_EMPTY_ANNOTATIONS: False + MULTI_DATASET_GROUPING: True + DATASET_ANN: ['box', 'image'] +WITH_IMAGE_LABELS: True diff --git a/dimos/models/Detic/configs/Detic_LCOCOI21k_CLIP_CXT21k_640b32_4x_ft4x_max-size.yaml b/dimos/models/Detic/configs/Detic_LCOCOI21k_CLIP_CXT21k_640b32_4x_ft4x_max-size.yaml new file mode 100644 index 0000000000..8c5befdbdc --- /dev/null +++ b/dimos/models/Detic/configs/Detic_LCOCOI21k_CLIP_CXT21k_640b32_4x_ft4x_max-size.yaml @@ -0,0 +1,39 @@ +_BASE_: "Base-C2_L_R5021k_640b64_4x.yaml" +MODEL: + DYNAMIC_CLASSIFIER: True + ROI_BOX_HEAD: + USE_ZEROSHOT_CLS: True + IMAGE_LABEL_LOSS: 'max_size' + ZEROSHOT_WEIGHT_PATH: 'datasets/metadata/lvis-21k_clip_a+cname.npy' + USE_FED_LOSS: False # Federated loss is enabled when DYNAMIC_CLASSIFIER is on + ROI_HEADS: + NUM_CLASSES: 22047 + WEIGHTS: "output/Detic/BoxSup-C2_LCOCO_CLIP_CXT21k_640b32_4x/model_final.pth" + TIMM: + BASE_NAME: convnext_tiny_21k + OUT_LEVELS: [2, 3, 4] + PRETRAINED: True + FPN: + IN_FEATURES: ["layer2", "layer3", "layer4"] +SOLVER: + MAX_ITER: 180000 + IMS_PER_BATCH: 32 + BASE_LR: 0.0001 + WARMUP_ITERS: 1000 + WARMUP_FACTOR: 0.001 +DATASETS: + TRAIN: ("lvis_v1_train+coco","imagenet_lvis-22k") +DATALOADER: + SAMPLER_TRAIN: "MultiDatasetSampler" + DATASET_RATIO: [1, 4] + USE_DIFF_BS_SIZE: True + DATASET_BS: [4, 16] + DATASET_INPUT_SIZE: [640, 320] + USE_RFS: [True, False] + DATASET_INPUT_SCALE: [[0.1, 2.0], [0.5, 1.5]] + FILTER_EMPTY_ANNOTATIONS: False + MULTI_DATASET_GROUPING: True + DATASET_ANN: ['box', 'image'] + NUM_WORKERS: 2 + USE_TAR_DATASET: True +WITH_IMAGE_LABELS: True \ No newline at end of file diff --git a/dimos/models/Detic/configs/Detic_LCOCOI21k_CLIP_R18_640b32_4x_ft4x_max-size.yaml b/dimos/models/Detic/configs/Detic_LCOCOI21k_CLIP_R18_640b32_4x_ft4x_max-size.yaml new file mode 100644 index 0000000000..e57e579dfd --- /dev/null +++ b/dimos/models/Detic/configs/Detic_LCOCOI21k_CLIP_R18_640b32_4x_ft4x_max-size.yaml @@ -0,0 +1,36 @@ +_BASE_: "Base-C2_L_R5021k_640b64_4x.yaml" +MODEL: + DYNAMIC_CLASSIFIER: True + ROI_BOX_HEAD: + USE_ZEROSHOT_CLS: True + IMAGE_LABEL_LOSS: 'max_size' + ZEROSHOT_WEIGHT_PATH: 'datasets/metadata/lvis-21k_clip_a+cname.npy' + USE_FED_LOSS: False # Federated loss is enabled when DYNAMIC_CLASSIFIER is on + ROI_HEADS: + NUM_CLASSES: 22047 + WEIGHTS: "output/Detic/BoxSup-C2_LCOCO_CLIP_R18_640b64_4x/model_final.pth" + TIMM: + BASE_NAME: resnet18 + PRETRAINED: True +SOLVER: + MAX_ITER: 180000 + IMS_PER_BATCH: 32 + BASE_LR: 0.0001 + WARMUP_ITERS: 1000 + WARMUP_FACTOR: 0.001 +DATASETS: + TRAIN: ("lvis_v1_train+coco","imagenet_lvis-22k") +DATALOADER: + SAMPLER_TRAIN: "MultiDatasetSampler" + DATASET_RATIO: [1, 4] + USE_DIFF_BS_SIZE: True + DATASET_BS: [4, 16] + DATASET_INPUT_SIZE: [640, 320] + USE_RFS: [True, False] + DATASET_INPUT_SCALE: [[0.1, 2.0], [0.5, 1.5]] + FILTER_EMPTY_ANNOTATIONS: False + MULTI_DATASET_GROUPING: True + DATASET_ANN: ['box', 'image'] + NUM_WORKERS: 2 + USE_TAR_DATASET: True +WITH_IMAGE_LABELS: True \ No newline at end of file diff --git a/dimos/models/Detic/configs/Detic_LCOCOI21k_CLIP_R5021k_640b32_4x_ft4x_max-size.yaml b/dimos/models/Detic/configs/Detic_LCOCOI21k_CLIP_R5021k_640b32_4x_ft4x_max-size.yaml new file mode 100644 index 0000000000..3d71d29c2f --- /dev/null +++ b/dimos/models/Detic/configs/Detic_LCOCOI21k_CLIP_R5021k_640b32_4x_ft4x_max-size.yaml @@ -0,0 +1,33 @@ +_BASE_: "Base-C2_L_R5021k_640b64_4x.yaml" +MODEL: + DYNAMIC_CLASSIFIER: True + ROI_BOX_HEAD: + USE_ZEROSHOT_CLS: True + IMAGE_LABEL_LOSS: 'max_size' + ZEROSHOT_WEIGHT_PATH: 'datasets/metadata/lvis-21k_clip_a+cname.npy' + USE_FED_LOSS: False # Federated loss is enabled when DYNAMIC_CLASSIFIER is on + ROI_HEADS: + NUM_CLASSES: 22047 + WEIGHTS: "output/Detic/BoxSup-C2_LCOCO_CLIP_R5021k_640b64_4x/model_final.pth" +SOLVER: + MAX_ITER: 180000 + IMS_PER_BATCH: 32 + BASE_LR: 0.0001 + WARMUP_ITERS: 1000 + WARMUP_FACTOR: 0.001 +DATASETS: + TRAIN: ("lvis_v1_train+coco","imagenet_lvis-22k") +DATALOADER: + SAMPLER_TRAIN: "MultiDatasetSampler" + DATASET_RATIO: [1, 4] + USE_DIFF_BS_SIZE: True + DATASET_BS: [4, 16] + DATASET_INPUT_SIZE: [640, 320] + USE_RFS: [True, False] + DATASET_INPUT_SCALE: [[0.1, 2.0], [0.5, 1.5]] + FILTER_EMPTY_ANNOTATIONS: False + MULTI_DATASET_GROUPING: True + DATASET_ANN: ['box', 'image'] + NUM_WORKERS: 2 + USE_TAR_DATASET: True +WITH_IMAGE_LABELS: True \ No newline at end of file diff --git a/dimos/models/Detic/configs/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.yaml b/dimos/models/Detic/configs/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.yaml new file mode 100644 index 0000000000..a3dba8d072 --- /dev/null +++ b/dimos/models/Detic/configs/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.yaml @@ -0,0 +1,43 @@ +_BASE_: "Base-C2_L_R5021k_640b64_4x.yaml" +MODEL: + WEIGHTS: "models/BoxSup-C2_LCOCO_CLIP_SwinB_896b32_4x.pth" + DYNAMIC_CLASSIFIER: True + ROI_BOX_HEAD: + USE_ZEROSHOT_CLS: True + IMAGE_LABEL_LOSS: 'max_size' + ZEROSHOT_WEIGHT_PATH: 'datasets/metadata/lvis-21k_clip_a+cname.npy' + USE_FED_LOSS: False # Federated loss is enabled when DYNAMIC_CLASSIFIER is on + ROI_HEADS: + NUM_CLASSES: 22047 + BACKBONE: + NAME: build_swintransformer_fpn_backbone + SWIN: + SIZE: B-22k + FPN: + IN_FEATURES: ["swin1", "swin2", "swin3"] + RESET_CLS_TESTS: True + TEST_CLASSIFIERS: ("datasets/metadata/oid_clip_a+cname.npy","datasets/metadata/o365_clip_a+cnamefix.npy") + TEST_NUM_CLASSES: [500, 365] +SOLVER: + MAX_ITER: 180000 + IMS_PER_BATCH: 32 + BASE_LR: 0.0001 + WARMUP_ITERS: 1000 + WARMUP_FACTOR: 0.001 +DATASETS: + TRAIN: ("lvis_v1_train+coco","imagenet_lvis-22k") + TEST: ('oid_val_expanded', 'objects365_v2_val') +DATALOADER: + SAMPLER_TRAIN: "MultiDatasetSampler" + DATASET_RATIO: [1, 16] + USE_DIFF_BS_SIZE: True + DATASET_BS: [4, 16] + DATASET_INPUT_SIZE: [896, 448] + USE_RFS: [True, False] + DATASET_INPUT_SCALE: [[0.1, 2.0], [0.5, 1.5]] + FILTER_EMPTY_ANNOTATIONS: False + MULTI_DATASET_GROUPING: True + DATASET_ANN: ['box', 'image'] + NUM_WORKERS: 4 + USE_TAR_DATASET: True +WITH_IMAGE_LABELS: True \ No newline at end of file diff --git a/dimos/models/Detic/configs/Detic_LI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.yaml b/dimos/models/Detic/configs/Detic_LI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.yaml new file mode 100644 index 0000000000..3b8633caac --- /dev/null +++ b/dimos/models/Detic/configs/Detic_LI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.yaml @@ -0,0 +1,43 @@ +_BASE_: "Base-C2_L_R5021k_640b64_4x.yaml" +MODEL: + WEIGHTS: "models/BoxSup-C2_L_CLIP_SwinB_896b32_4x.pth" + DYNAMIC_CLASSIFIER: True + ROI_BOX_HEAD: + USE_ZEROSHOT_CLS: True + IMAGE_LABEL_LOSS: 'max_size' + ZEROSHOT_WEIGHT_PATH: 'datasets/metadata/lvis-21k_clip_a+cname.npy' + USE_FED_LOSS: False # Federated loss is enabled when DYNAMIC_CLASSIFIER is on + ROI_HEADS: + NUM_CLASSES: 22047 + BACKBONE: + NAME: build_swintransformer_fpn_backbone + SWIN: + SIZE: B-22k + FPN: + IN_FEATURES: ["swin1", "swin2", "swin3"] + RESET_CLS_TESTS: True + TEST_CLASSIFIERS: ("datasets/metadata/oid_clip_a+cname.npy","datasets/metadata/o365_clip_a+cnamefix.npy") + TEST_NUM_CLASSES: [500, 365] +SOLVER: + MAX_ITER: 180000 + IMS_PER_BATCH: 32 + BASE_LR: 0.0001 + WARMUP_ITERS: 1000 + WARMUP_FACTOR: 0.001 +DATASETS: + TRAIN: ("lvis_v1_train","imagenet_lvis-22k") + TEST: ('oid_val_expanded', 'objects365_v2_val') +DATALOADER: + SAMPLER_TRAIN: "MultiDatasetSampler" + DATASET_RATIO: [1, 16] + USE_DIFF_BS_SIZE: True + DATASET_BS: [4, 16] + DATASET_INPUT_SIZE: [896, 448] + USE_RFS: [True, False] + DATASET_INPUT_SCALE: [[0.1, 2.0], [0.5, 1.5]] + FILTER_EMPTY_ANNOTATIONS: False + MULTI_DATASET_GROUPING: True + DATASET_ANN: ['box', 'image'] + NUM_WORKERS: 4 + USE_TAR_DATASET: True +WITH_IMAGE_LABELS: True \ No newline at end of file diff --git a/dimos/models/Detic/configs/Detic_LI_CLIP_R5021k_640b64_4x_ft4x_max-size.yaml b/dimos/models/Detic/configs/Detic_LI_CLIP_R5021k_640b64_4x_ft4x_max-size.yaml new file mode 100644 index 0000000000..ca93318e64 --- /dev/null +++ b/dimos/models/Detic/configs/Detic_LI_CLIP_R5021k_640b64_4x_ft4x_max-size.yaml @@ -0,0 +1,27 @@ +_BASE_: "Base-C2_L_R5021k_640b64_4x.yaml" +MODEL: + ROI_BOX_HEAD: + USE_ZEROSHOT_CLS: True + IMAGE_LABEL_LOSS: 'max_size' + WEIGHTS: "models/BoxSup-C2_L_CLIP_R5021k_640b64_4x.pth" +SOLVER: + MAX_ITER: 90000 + IMS_PER_BATCH: 64 + BASE_LR: 0.0002 + WARMUP_ITERS: 1000 + WARMUP_FACTOR: 0.001 +DATASETS: + TRAIN: ("lvis_v1_train","imagenet_lvis_v1") +DATALOADER: + SAMPLER_TRAIN: "MultiDatasetSampler" + DATASET_RATIO: [1, 4] + USE_DIFF_BS_SIZE: True + DATASET_BS: [8, 32] + DATASET_INPUT_SIZE: [640, 320] + USE_RFS: [True, False] + DATASET_INPUT_SCALE: [[0.1, 2.0], [0.5, 1.5]] + FILTER_EMPTY_ANNOTATIONS: False + MULTI_DATASET_GROUPING: True + DATASET_ANN: ['box', 'image'] + NUM_WORKERS: 8 +WITH_IMAGE_LABELS: True \ No newline at end of file diff --git a/dimos/models/Detic/configs/Detic_LI_CLIP_SwinB_896b32_4x_ft4x_max-size.yaml b/dimos/models/Detic/configs/Detic_LI_CLIP_SwinB_896b32_4x_ft4x_max-size.yaml new file mode 100644 index 0000000000..57ffa48ce6 --- /dev/null +++ b/dimos/models/Detic/configs/Detic_LI_CLIP_SwinB_896b32_4x_ft4x_max-size.yaml @@ -0,0 +1,33 @@ +_BASE_: "Base-C2_L_R5021k_640b64_4x.yaml" +MODEL: + ROI_BOX_HEAD: + USE_ZEROSHOT_CLS: True + IMAGE_LABEL_LOSS: 'max_size' + BACKBONE: + NAME: build_swintransformer_fpn_backbone + SWIN: + SIZE: B-22k + FPN: + IN_FEATURES: ["swin1", "swin2", "swin3"] + WEIGHTS: "models/BoxSup-C2_L_CLIP_SwinB_896b32_4x.pth" +SOLVER: + MAX_ITER: 180000 + IMS_PER_BATCH: 32 + BASE_LR: 0.0001 + WARMUP_ITERS: 1000 + WARMUP_FACTOR: 0.001 +DATASETS: + TRAIN: ("lvis_v1_train","imagenet_lvis_v1") +DATALOADER: + SAMPLER_TRAIN: "MultiDatasetSampler" + DATASET_RATIO: [1, 4] + USE_DIFF_BS_SIZE: True + DATASET_BS: [4, 16] + DATASET_INPUT_SIZE: [896, 448] + USE_RFS: [True, False] + DATASET_INPUT_SCALE: [[0.1, 2.0], [0.5, 1.5]] + FILTER_EMPTY_ANNOTATIONS: False + MULTI_DATASET_GROUPING: True + DATASET_ANN: ['box', 'image'] + NUM_WORKERS: 8 +WITH_IMAGE_LABELS: True \ No newline at end of file diff --git a/dimos/models/Detic/configs/Detic_LbaseCCcapimg_CLIP_R5021k_640b64_4x_ft4x_max-size.yaml b/dimos/models/Detic/configs/Detic_LbaseCCcapimg_CLIP_R5021k_640b64_4x_ft4x_max-size.yaml new file mode 100644 index 0000000000..ada6ffed06 --- /dev/null +++ b/dimos/models/Detic/configs/Detic_LbaseCCcapimg_CLIP_R5021k_640b64_4x_ft4x_max-size.yaml @@ -0,0 +1,30 @@ +_BASE_: "Base-C2_L_R5021k_640b64_4x.yaml" +MODEL: + WITH_CAPTION: True + SYNC_CAPTION_BATCH: True + ROI_BOX_HEAD: + ADD_IMAGE_BOX: True # caption loss is added to the image-box + USE_ZEROSHOT_CLS: True + IMAGE_LABEL_LOSS: 'max_size' + WEIGHTS: "models/BoxSup-C2_Lbase_CLIP_R5021k_640b64_4x.pth" +SOLVER: + MAX_ITER: 90000 + IMS_PER_BATCH: 64 + BASE_LR: 0.0002 + WARMUP_ITERS: 1000 + WARMUP_FACTOR: 0.001 +DATASETS: + TRAIN: ("lvis_v1_train_norare","cc3m_v1_train_tags") +DATALOADER: + SAMPLER_TRAIN: "MultiDatasetSampler" + DATASET_RATIO: [1, 4] + USE_DIFF_BS_SIZE: True + DATASET_BS: [8, 32] + DATASET_INPUT_SIZE: [640, 320] + USE_RFS: [True, False] + DATASET_INPUT_SCALE: [[0.1, 2.0], [0.5, 1.5]] + FILTER_EMPTY_ANNOTATIONS: False + MULTI_DATASET_GROUPING: True + DATASET_ANN: ['box', 'captiontag'] + NUM_WORKERS: 8 +WITH_IMAGE_LABELS: True \ No newline at end of file diff --git a/dimos/models/Detic/configs/Detic_LbaseCCimg_CLIP_R5021k_640b64_4x_ft4x_max-size.yaml b/dimos/models/Detic/configs/Detic_LbaseCCimg_CLIP_R5021k_640b64_4x_ft4x_max-size.yaml new file mode 100644 index 0000000000..aadcbc0ccd --- /dev/null +++ b/dimos/models/Detic/configs/Detic_LbaseCCimg_CLIP_R5021k_640b64_4x_ft4x_max-size.yaml @@ -0,0 +1,27 @@ +_BASE_: "Base-C2_L_R5021k_640b64_4x.yaml" +MODEL: + ROI_BOX_HEAD: + USE_ZEROSHOT_CLS: True + IMAGE_LABEL_LOSS: 'max_size' + WEIGHTS: "models/BoxSup-C2_Lbase_CLIP_R5021k_640b64_4x.pth" +SOLVER: + MAX_ITER: 90000 + IMS_PER_BATCH: 64 + BASE_LR: 0.0002 + WARMUP_ITERS: 1000 + WARMUP_FACTOR: 0.001 +DATASETS: + TRAIN: ("lvis_v1_train_norare","cc3m_v1_train_tags") +DATALOADER: + SAMPLER_TRAIN: "MultiDatasetSampler" + DATASET_RATIO: [1, 4] + USE_DIFF_BS_SIZE: True + DATASET_BS: [8, 32] + DATASET_INPUT_SIZE: [640, 320] + USE_RFS: [True, False] + DATASET_INPUT_SCALE: [[0.1, 2.0], [0.5, 1.5]] + FILTER_EMPTY_ANNOTATIONS: False + MULTI_DATASET_GROUPING: True + DATASET_ANN: ['box', 'image'] + NUM_WORKERS: 8 +WITH_IMAGE_LABELS: True \ No newline at end of file diff --git a/dimos/models/Detic/configs/Detic_LbaseI_CLIP_R5021k_640b64_4x_ft4x_max-size.yaml b/dimos/models/Detic/configs/Detic_LbaseI_CLIP_R5021k_640b64_4x_ft4x_max-size.yaml new file mode 100644 index 0000000000..3ef1e9a02a --- /dev/null +++ b/dimos/models/Detic/configs/Detic_LbaseI_CLIP_R5021k_640b64_4x_ft4x_max-size.yaml @@ -0,0 +1,27 @@ +_BASE_: "Base-C2_L_R5021k_640b64_4x.yaml" +MODEL: + ROI_BOX_HEAD: + USE_ZEROSHOT_CLS: True + IMAGE_LABEL_LOSS: 'max_size' + WEIGHTS: "models/BoxSup-C2_Lbase_CLIP_R5021k_640b64_4x.pth" +SOLVER: + MAX_ITER: 90000 + IMS_PER_BATCH: 64 + BASE_LR: 0.0002 + WARMUP_ITERS: 1000 + WARMUP_FACTOR: 0.001 +DATASETS: + TRAIN: ("lvis_v1_train_norare","imagenet_lvis_v1") +DATALOADER: + SAMPLER_TRAIN: "MultiDatasetSampler" + DATASET_RATIO: [1, 4] + USE_DIFF_BS_SIZE: True + DATASET_BS: [8, 32] + DATASET_INPUT_SIZE: [640, 320] + USE_RFS: [True, False] + DATASET_INPUT_SCALE: [[0.1, 2.0], [0.5, 1.5]] + FILTER_EMPTY_ANNOTATIONS: False + MULTI_DATASET_GROUPING: True + DATASET_ANN: ['box', 'image'] + NUM_WORKERS: 8 +WITH_IMAGE_LABELS: True \ No newline at end of file diff --git a/dimos/models/Detic/configs/Detic_LbaseI_CLIP_R5021k_640b64_4x_ft4x_predicted.yaml b/dimos/models/Detic/configs/Detic_LbaseI_CLIP_R5021k_640b64_4x_ft4x_predicted.yaml new file mode 100644 index 0000000000..9d6f1b350f --- /dev/null +++ b/dimos/models/Detic/configs/Detic_LbaseI_CLIP_R5021k_640b64_4x_ft4x_predicted.yaml @@ -0,0 +1,27 @@ +_BASE_: "Base-C2_L_R5021k_640b64_4x.yaml" +MODEL: + ROI_BOX_HEAD: + USE_ZEROSHOT_CLS: True + IMAGE_LABEL_LOSS: 'max_score' + WEIGHTS: "models/BoxSup-C2_Lbase_CLIP_R5021k_640b64_4x.pth" +SOLVER: + MAX_ITER: 90000 + IMS_PER_BATCH: 64 + BASE_LR: 0.0002 + WARMUP_ITERS: 1000 + WARMUP_FACTOR: 0.001 +DATASETS: + TRAIN: ("lvis_v1_train_norare","imagenet_lvis_v1") +DATALOADER: + SAMPLER_TRAIN: "MultiDatasetSampler" + DATASET_RATIO: [1, 4] + USE_DIFF_BS_SIZE: True + DATASET_BS: [8, 32] + DATASET_INPUT_SIZE: [640, 320] + USE_RFS: [True, False] + DATASET_INPUT_SCALE: [[0.1, 2.0], [0.5, 1.5]] + FILTER_EMPTY_ANNOTATIONS: False + MULTI_DATASET_GROUPING: True + DATASET_ANN: ['box', 'image'] + NUM_WORKERS: 8 +WITH_IMAGE_LABELS: True \ No newline at end of file diff --git a/dimos/models/Detic/configs/Detic_LbaseI_CLIP_SwinB_896b32_4x_ft4x_max-size.yaml b/dimos/models/Detic/configs/Detic_LbaseI_CLIP_SwinB_896b32_4x_ft4x_max-size.yaml new file mode 100644 index 0000000000..b25e2b6651 --- /dev/null +++ b/dimos/models/Detic/configs/Detic_LbaseI_CLIP_SwinB_896b32_4x_ft4x_max-size.yaml @@ -0,0 +1,33 @@ +_BASE_: "Base-C2_L_R5021k_640b64_4x.yaml" +MODEL: + ROI_BOX_HEAD: + USE_ZEROSHOT_CLS: True + IMAGE_LABEL_LOSS: 'max_size' + BACKBONE: + NAME: build_swintransformer_fpn_backbone + SWIN: + SIZE: B-22k + FPN: + IN_FEATURES: ["swin1", "swin2", "swin3"] + WEIGHTS: "models/BoxSup-C2_Lbase_CLIP_SwinB_896b32_4x.pth" +SOLVER: + MAX_ITER: 180000 + IMS_PER_BATCH: 32 + BASE_LR: 0.0001 + WARMUP_ITERS: 1000 + WARMUP_FACTOR: 0.001 +DATASETS: + TRAIN: ("lvis_v1_train_norare","imagenet_lvis_v1") +DATALOADER: + SAMPLER_TRAIN: "MultiDatasetSampler" + DATASET_RATIO: [1, 4] + USE_DIFF_BS_SIZE: True + DATASET_BS: [4, 16] + DATASET_INPUT_SIZE: [896, 448] + USE_RFS: [True, False] + DATASET_INPUT_SCALE: [[0.1, 2.0], [0.5, 1.5]] + FILTER_EMPTY_ANNOTATIONS: False + MULTI_DATASET_GROUPING: True + DATASET_ANN: ['box', 'image'] + NUM_WORKERS: 8 +WITH_IMAGE_LABELS: True \ No newline at end of file diff --git a/dimos/models/Detic/configs/Detic_OVCOCO_CLIP_R50_1x_caption.yaml b/dimos/models/Detic/configs/Detic_OVCOCO_CLIP_R50_1x_caption.yaml new file mode 100644 index 0000000000..aeafd50d7c --- /dev/null +++ b/dimos/models/Detic/configs/Detic_OVCOCO_CLIP_R50_1x_caption.yaml @@ -0,0 +1,33 @@ +_BASE_: "Base_OVCOCO_C4_1x.yaml" +MODEL: + WEIGHTS: "models/BoxSup_OVCOCO_CLIP_R50_1x.pth" + WITH_CAPTION: True + SYNC_CAPTION_BATCH: True + ROI_BOX_HEAD: + WS_NUM_PROPS: 1 + ADD_IMAGE_BOX: True + NEG_CAP_WEIGHT: 1.0 +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.02 + STEPS: (60000, 80000) + MAX_ITER: 90000 +DATASETS: + TRAIN: ("coco_zeroshot_train_oriorder", "coco_caption_train_tags") +INPUT: + CUSTOM_AUG: ResizeShortestEdge + MIN_SIZE_TRAIN_SAMPLING: range + MIN_SIZE_TRAIN: (800, 800) +DATALOADER: + SAMPLER_TRAIN: "MultiDatasetSampler" + DATASET_RATIO: [1, 4] + USE_DIFF_BS_SIZE: True + DATASET_BS: [2, 8] + USE_RFS: [False, False] + DATASET_MIN_SIZES: [[800, 800], [400, 400]] + DATASET_MAX_SIZES: [1333, 667] + FILTER_EMPTY_ANNOTATIONS: False + MULTI_DATASET_GROUPING: True + DATASET_ANN: ['box', 'caption'] + NUM_WORKERS: 8 +WITH_IMAGE_LABELS: True \ No newline at end of file diff --git a/dimos/models/Detic/configs/Detic_OVCOCO_CLIP_R50_1x_max-size.yaml b/dimos/models/Detic/configs/Detic_OVCOCO_CLIP_R50_1x_max-size.yaml new file mode 100644 index 0000000000..8daa4be6bb --- /dev/null +++ b/dimos/models/Detic/configs/Detic_OVCOCO_CLIP_R50_1x_max-size.yaml @@ -0,0 +1,30 @@ +_BASE_: "Base_OVCOCO_C4_1x.yaml" +MODEL: + WEIGHTS: "models/BoxSup_OVCOCO_CLIP_R50_1x.pth" + ROI_BOX_HEAD: + WS_NUM_PROPS: 32 + IMAGE_LABEL_LOSS: 'max_size' +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.02 + STEPS: (60000, 80000) + MAX_ITER: 90000 +DATASETS: + TRAIN: ("coco_zeroshot_train_oriorder", "coco_caption_train_tags") +INPUT: + CUSTOM_AUG: ResizeShortestEdge + MIN_SIZE_TRAIN_SAMPLING: range + MIN_SIZE_TRAIN: (800, 800) +DATALOADER: + SAMPLER_TRAIN: "MultiDatasetSampler" + DATASET_RATIO: [1, 4] + USE_DIFF_BS_SIZE: True + DATASET_BS: [2, 8] + USE_RFS: [False, False] + DATASET_MIN_SIZES: [[800, 800], [400, 400]] + DATASET_MAX_SIZES: [1333, 667] + FILTER_EMPTY_ANNOTATIONS: False + MULTI_DATASET_GROUPING: True + DATASET_ANN: ['box', 'image'] + NUM_WORKERS: 8 +WITH_IMAGE_LABELS: True \ No newline at end of file diff --git a/dimos/models/Detic/configs/Detic_OVCOCO_CLIP_R50_1x_max-size_caption.yaml b/dimos/models/Detic/configs/Detic_OVCOCO_CLIP_R50_1x_max-size_caption.yaml new file mode 100644 index 0000000000..3ba0a20a18 --- /dev/null +++ b/dimos/models/Detic/configs/Detic_OVCOCO_CLIP_R50_1x_max-size_caption.yaml @@ -0,0 +1,35 @@ +_BASE_: "Base_OVCOCO_C4_1x.yaml" +MODEL: + WEIGHTS: "models/BoxSup_OVCOCO_CLIP_R50_1x.pth" + WITH_CAPTION: True + SYNC_CAPTION_BATCH: True + ROI_BOX_HEAD: + WS_NUM_PROPS: 32 + ADD_IMAGE_BOX: True # caption loss is added to the image-box + IMAGE_LABEL_LOSS: 'max_size' + + NEG_CAP_WEIGHT: 1.0 +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.02 + STEPS: (60000, 80000) + MAX_ITER: 90000 +DATASETS: + TRAIN: ("coco_zeroshot_train_oriorder", "coco_caption_train_tags") +INPUT: + CUSTOM_AUG: ResizeShortestEdge + MIN_SIZE_TRAIN_SAMPLING: range + MIN_SIZE_TRAIN: (800, 800) +DATALOADER: + SAMPLER_TRAIN: "MultiDatasetSampler" + DATASET_RATIO: [1, 4] + USE_DIFF_BS_SIZE: True + DATASET_BS: [2, 8] + USE_RFS: [False, False] + DATASET_MIN_SIZES: [[800, 800], [400, 400]] + DATASET_MAX_SIZES: [1333, 667] + FILTER_EMPTY_ANNOTATIONS: False + MULTI_DATASET_GROUPING: True + DATASET_ANN: ['box', 'captiontag'] + NUM_WORKERS: 8 +WITH_IMAGE_LABELS: True \ No newline at end of file diff --git a/dimos/models/Detic/configs/Detic_ViLD_200e.py b/dimos/models/Detic/configs/Detic_ViLD_200e.py new file mode 100644 index 0000000000..470124a109 --- /dev/null +++ b/dimos/models/Detic/configs/Detic_ViLD_200e.py @@ -0,0 +1,157 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import os + +from detectron2.config import LazyCall as L +import detectron2.data.transforms as T +from detectron2.evaluation.lvis_evaluation import LVISEvaluator +from detectron2.layers import ShapeSpec +from detectron2.layers.batch_norm import NaiveSyncBatchNorm +from detectron2.model_zoo import get_config +from detectron2.modeling.box_regression import Box2BoxTransform +from detectron2.modeling.matcher import Matcher +from detectron2.modeling.roi_heads import FastRCNNConvFCHead +from detectron2.solver import WarmupParamScheduler +from detectron2.solver.build import get_default_optimizer_params +from detic.data.custom_dataset_dataloader import ( + MultiDatasetSampler, + build_custom_train_loader, + get_detection_dataset_dicts_with_source, +) +from detic.data.custom_dataset_mapper import CustomDatasetMapper +from detic.modeling.meta_arch.custom_rcnn import CustomRCNN +from detic.modeling.roi_heads.detic_fast_rcnn import DeticFastRCNNOutputLayers +from detic.modeling.roi_heads.detic_roi_heads import DeticCascadeROIHeads +from detic.modeling.roi_heads.zero_shot_classifier import ZeroShotClassifier +from fvcore.common.param_scheduler import CosineParamScheduler +import torch + +default_configs = get_config("new_baselines/mask_rcnn_R_50_FPN_100ep_LSJ.py") +dataloader = default_configs["dataloader"] +model = default_configs["model"] +train = default_configs["train"] + +train.init_checkpoint = "models/BoxSup_ViLD_200e.pth" + +[model.roi_heads.pop(k) for k in ["box_head", "box_predictor", "proposal_matcher"]] + +model.roi_heads.update( + _target_=DeticCascadeROIHeads, + num_classes=1203, + box_heads=[ + L(FastRCNNConvFCHead)( + input_shape=ShapeSpec(channels=256, height=7, width=7), + conv_dims=[256, 256, 256, 256], + fc_dims=[1024], + conv_norm=lambda c: NaiveSyncBatchNorm(c, stats_mode="N"), + ) + for _ in range(1) + ], + box_predictors=[ + L(DeticFastRCNNOutputLayers)( + input_shape=ShapeSpec(channels=1024), + test_score_thresh=0.0001, + test_topk_per_image=300, + box2box_transform=L(Box2BoxTransform)(weights=(w1, w1, w2, w2)), + cls_agnostic_bbox_reg=True, + num_classes="${...num_classes}", + cls_score=L(ZeroShotClassifier)( + input_shape=ShapeSpec(channels=1024), + num_classes=1203, + zs_weight_path="datasets/metadata/lvis_v1_clip_a+cname.npy", + norm_weight=True, + # use_bias=-4.6, + ), + use_zeroshot_cls=True, + use_sigmoid_ce=True, + ignore_zero_cats=True, + cat_freq_path="datasets/lvis/lvis_v1_train_norare_cat_info.json", + image_label_loss="max_size", + image_loss_weight=0.1, + ) + for (w1, w2) in [(10, 5)] + ], + proposal_matchers=[ + L(Matcher)(thresholds=[th], labels=[0, 1], allow_low_quality_matches=False) for th in [0.5] + ], + with_image_labels=True, + ws_num_props=128, +) +model.update( + _target_=CustomRCNN, + with_image_labels=True, +) +model.roi_heads.mask_head.num_classes = 1 + +train.ddp.find_unused_parameters = True + +num_nodes = 4 +image_size = 896 +image_size_weak = 448 +dataloader.train = L(build_custom_train_loader)( + dataset=L(get_detection_dataset_dicts_with_source)( + dataset_names=["lvis_v1_train_norare", "imagenet_lvis_v1"], + filter_empty=False, + ), + mapper=L(CustomDatasetMapper)( + is_train=True, + augmentations=[], + with_ann_type=True, + dataset_ann=["box", "image"], + use_diff_bs_size=True, + dataset_augs=[ + [ + L(T.ResizeScale)( + min_scale=0.1, max_scale=2.0, target_height=image_size, target_width=image_size + ), + L(T.FixedSizeCrop)(crop_size=(image_size, image_size)), + L(T.RandomFlip)(horizontal=True), + ], + [ + L(T.ResizeScale)( + min_scale=0.5, + max_scale=1.5, + target_height=image_size_weak, + target_width=image_size_weak, + ), + L(T.FixedSizeCrop)(crop_size=(image_size_weak, image_size_weak)), + L(T.RandomFlip)(horizontal=True), + ], + ], + image_format="BGR", + use_instance_mask=True, + ), + sampler=L(MultiDatasetSampler)( + dataset_dicts="${dataloader.train.dataset}", + dataset_ratio=[1, 4], + use_rfs=[True, False], + dataset_ann="${dataloader.train.mapper.dataset_ann}", + repeat_threshold=0.001, + ), + total_batch_size=64 * num_nodes, + multi_dataset_grouping=True, + use_diff_bs_size=True, + dataset_bs=[8, 8 * 4], + num_datasets=2, + num_workers=8, +) + +dataloader.test.dataset.names = "lvis_v1_val" +dataloader.evaluator = L(LVISEvaluator)( + dataset_name="${..test.dataset.names}", +) + +train.max_iter = 184375 * 2 // num_nodes +lr_multiplier = L(WarmupParamScheduler)( + scheduler=CosineParamScheduler(1.0, 0.0), + warmup_length=500 / train.max_iter, + warmup_factor=0.067, +) + +optimizer = L(torch.optim.AdamW)( + params=L(get_default_optimizer_params)(weight_decay_norm=0.0), + lr=0.0002 * num_nodes, + weight_decay=1e-4, +) + +train.checkpointer.period = 20000 // num_nodes +train.output_dir = f"./output/Lazy/{os.path.basename(__file__)[:-3]}" diff --git a/dimos/models/Detic/datasets/README.md b/dimos/models/Detic/datasets/README.md new file mode 100644 index 0000000000..e9f4a0b3fb --- /dev/null +++ b/dimos/models/Detic/datasets/README.md @@ -0,0 +1,207 @@ +# Prepare datasets for Detic + +The basic training of our model uses [LVIS](https://www.lvisdataset.org/) (which uses [COCO](https://cocodataset.org/) images) and [ImageNet-21K](https://www.image-net.org/download.php). +Some models are trained on [Conceptual Caption (CC3M)](https://ai.google.com/research/ConceptualCaptions/). +Optionally, we use [Objects365](https://www.objects365.org/) and [OpenImages (Challenge 2019 version)](https://storage.googleapis.com/openimages/web/challenge2019.html) for cross-dataset evaluation. +Before starting processing, please download the (selected) datasets from the official websites and place or sim-link them under `$Detic_ROOT/datasets/`. + +``` +$Detic_ROOT/datasets/ + metadata/ + lvis/ + coco/ + imagenet/ + cc3m/ + objects365/ + oid/ +``` +`metadata/` is our preprocessed meta-data (included in the repo). See the below [section](#Metadata) for details. +Please follow the following instruction to pre-process individual datasets. + +### COCO and LVIS + +First, download COCO and LVIS data place them in the following way: + +``` +lvis/ + lvis_v1_train.json + lvis_v1_val.json +coco/ + train2017/ + val2017/ + annotations/ + captions_train2017.json + instances_train2017.json + instances_val2017.json +``` + +Next, prepare the open-vocabulary LVIS training set using + +``` +python tools/remove_lvis_rare.py --ann datasets/lvis/lvis_v1_train.json +``` + +This will generate `datasets/lvis/lvis_v1_train_norare.json`. + +### ImageNet-21K + +The ImageNet-21K folder should look like: +``` +imagenet/ + ImageNet-21K/ + n01593028.tar + n01593282.tar + ... +``` + +We first unzip the overlapping classes of LVIS (we will directly work with the .tar file for the rest classes) and convert them into LVIS annotation format. + +~~~ +mkdir imagenet/annotations +python tools/unzip_imagenet_lvis.py --dst_path datasets/imagenet/ImageNet-LVIS +python tools/create_imagenetlvis_json.py --imagenet_path datasets/imagenet/ImageNet-LVIS --out_path datasets/imagenet/annotations/imagenet_lvis_image_info.json +~~~ +This creates `datasets/imagenet/annotations/imagenet_lvis_image_info.json`. + +[Optional] To train with all the 21K classes, run + +~~~ +python tools/get_imagenet_21k_full_tar_json.py +python tools/create_lvis_21k.py +~~~ +This creates `datasets/imagenet/annotations/imagenet-21k_image_info_lvis-21k.json` and `datasets/lvis/lvis_v1_train_lvis-21k.json` (combined LVIS and ImageNet-21K classes in `categories`). + +[Optional] To train on combined LVIS and COCO, run + +~~~ +python tools/merge_lvis_coco.py +~~~ +This creates `datasets/lvis/lvis_v1_train+coco_mask.json` + +### Conceptual Caption + + +Download the dataset from [this](https://ai.google.com/research/ConceptualCaptions/download) page and place them as: +``` +cc3m/ + GCC-training.tsv +``` + +Run the following command to download the images and convert the annotations to LVIS format (Note: download images takes long). + +~~~ +python tools/download_cc.py --ann datasets/cc3m/GCC-training.tsv --save_image_path datasets/cc3m/training/ --out_path datasets/cc3m/train_image_info.json +python tools/get_cc_tags.py +~~~ + +This creates `datasets/cc3m/train_image_info_tags.json`. + +### Objects365 +Download Objects365 (v2) from the website. We only need the validation set in this project: +``` +objects365/ + annotations/ + zhiyuan_objv2_val.json + val/ + images/ + v1/ + patch0/ + ... + patch15/ + v2/ + patch16/ + ... + patch49/ + +``` + +The original annotation has typos in the class names, we first fix them for our following use of language embeddings. + +``` +python tools/fix_o365_names.py --ann datasets/objects365/annotations/zhiyuan_objv2_val.json +``` +This creates `datasets/objects365/zhiyuan_objv2_val_fixname.json`. + +To train on Objects365, download the training images and use the command above. We note some images in the training annotation do not exist. +We use the following command to filter the missing images. +~~~ +python tools/fix_0365_path.py +~~~ +This creates `datasets/objects365/zhiyuan_objv2_train_fixname_fixmiss.json`. + +### OpenImages + +We followed the instructions in [UniDet](https://github.com/xingyizhou/UniDet/blob/master/docs/DATASETS.md#openimages) to convert the metadata for OpenImages. + +The converted folder should look like + +``` +oid/ + annotations/ + oid_challenge_2019_train_bbox.json + oid_challenge_2019_val_expanded.json + images/ + 0/ + 1/ + 2/ + ... +``` + +### Open-vocabulary COCO + +We first follow [OVR-CNN](https://github.com/alirezazareian/ovr-cnn/blob/master/ipynb/003.ipynb) to create the open-vocabulary COCO split. The converted files should be like + +``` +coco/ + zero-shot/ + instances_train2017_seen_2.json + instances_val2017_all_2.json +``` + +We further pre-process the annotation format for easier evaluation: + +``` +python tools/get_coco_zeroshot_oriorder.py --data_path datasets/coco/zero-shot/instances_train2017_seen_2.json +python tools/get_coco_zeroshot_oriorder.py --data_path datasets/coco/zero-shot/instances_val2017_all_2.json +``` + +Next, we preprocess the COCO caption data: + +``` +python tools/get_cc_tags.py --cc_ann datasets/coco/annotations/captions_train2017.json --out_path datasets/coco/captions_train2017_tags_allcaps.json --allcaps --convert_caption --cat_path datasets/coco/annotations/instances_val2017.json +``` +This creates `datasets/coco/captions_train2017_tags_allcaps.json`. + +### Metadata + +``` +metadata/ + lvis_v1_train_cat_info.json + coco_clip_a+cname.npy + lvis_v1_clip_a+cname.npy + o365_clip_a+cnamefix.npy + oid_clip_a+cname.npy + imagenet_lvis_wnid.txt + Objects365_names_fix.csv +``` + +`lvis_v1_train_cat_info.json` is used by the Federated loss. +This is created by +~~~ +python tools/get_lvis_cat_info.py --ann datasets/lvis/lvis_v1_train.json +~~~ + +`*_clip_a+cname.npy` is the pre-computed CLIP embeddings for each datasets. +They are created by (taking LVIS as an example) +~~~ +python tools/dump_clip_features.py --ann datasets/lvis/lvis_v1_val.json --out_path metadata/lvis_v1_clip_a+cname.npy +~~~ +Note we do not include the 21K class embeddings due to the large file size. +To create it, run +~~~ +python tools/dump_clip_features.py --ann datasets/lvis/lvis_v1_val_lvis-21k.json --out_path datasets/metadata/lvis-21k_clip_a+cname.npy +~~~ + +`imagenet_lvis_wnid.txt` is the list of matched classes between ImageNet-21K and LVIS. + +`Objects365_names_fix.csv` is our manual fix of the Objects365 names. \ No newline at end of file diff --git a/dimos/models/Detic/datasets/metadata/Objects365_names_fix.csv b/dimos/models/Detic/datasets/metadata/Objects365_names_fix.csv new file mode 100644 index 0000000000..c274707cc3 --- /dev/null +++ b/dimos/models/Detic/datasets/metadata/Objects365_names_fix.csv @@ -0,0 +1,365 @@ +1,Person,Person +2,Sneakers,Sneakers +3,Chair,Chair +4,Other Shoes,Other Shoes +5,Hat,Hat +6,Car,Car +7,Lamp,Lamp +8,Glasses,Glasses +9,Bottle,Bottle +10,Desk,Desk +11,Cup,Cup +12,Street Lights,Street Lights +13,Cabinet/shelf,Cabinet/shelf +14,Handbag/Satchel,Handbag/Satchel +15,Bracelet,Bracelet +16,Plate,Plate +17,Picture/Frame,Picture/Frame +18,Helmet,Helmet +19,Book,Book +20,Gloves,Gloves +21,Storage box,Storage box +22,Boat,Boat +23,Leather Shoes,Leather Shoes +24,Flower,Flower +25,Bench,Bench +26,Potted Plant,Potted Plant +27,Bowl/Basin,Bowl/Basin +28,Flag,Flag +29,Pillow,Pillow +30,Boots,Boots +31,Vase,Vase +32,Microphone,Microphone +33,Necklace,Necklace +34,Ring,Ring +35,SUV,SUV +36,Wine Glass,Wine Glass +37,Belt,Belt +38,Moniter/TV,Monitor/TV +39,Backpack,Backpack +40,Umbrella,Umbrella +41,Traffic Light,Traffic Light +42,Speaker,Speaker +43,Watch,Watch +44,Tie,Tie +45,Trash bin Can,Trash bin Can +46,Slippers,Slippers +47,Bicycle,Bicycle +48,Stool,Stool +49,Barrel/bucket,Barrel/bucket +50,Van,Van +51,Couch,Couch +52,Sandals,Sandals +53,Bakset,Basket +54,Drum,Drum +55,Pen/Pencil,Pen/Pencil +56,Bus,Bus +57,Wild Bird,Wild Bird +58,High Heels,High Heels +59,Motorcycle,Motorcycle +60,Guitar,Guitar +61,Carpet,Carpet +62,Cell Phone,Cell Phone +63,Bread,Bread +64,Camera,Camera +65,Canned,Canned +66,Truck,Truck +67,Traffic cone,Traffic cone +68,Cymbal,Cymbal +69,Lifesaver,Lifesaver +70,Towel,Towel +71,Stuffed Toy,Stuffed Toy +72,Candle,Candle +73,Sailboat,Sailboat +74,Laptop,Laptop +75,Awning,Awning +76,Bed,Bed +77,Faucet,Faucet +78,Tent,Tent +79,Horse,Horse +80,Mirror,Mirror +81,Power outlet,Power outlet +82,Sink,Sink +83,Apple,Apple +84,Air Conditioner,Air Conditioner +85,Knife,Knife +86,Hockey Stick,Hockey Stick +87,Paddle,Paddle +88,Pickup Truck,Pickup Truck +89,Fork,Fork +90,Traffic Sign,Traffic Sign +91,Ballon,Ballon +92,Tripod,Tripod +93,Dog,Dog +94,Spoon,Spoon +95,Clock,Clock +96,Pot,Pot +97,Cow,Cow +98,Cake,Cake +99,Dinning Table,Dining Table +100,Sheep,Sheep +101,Hanger,Hanger +102,Blackboard/Whiteboard,Blackboard/Whiteboard +103,Napkin,Napkin +104,Other Fish,Other Fish +105,Orange/Tangerine,Orange/Tangerine +106,Toiletry,Toiletry +107,Keyboard,Keyboard +108,Tomato,Tomato +109,Lantern,Lantern +110,Machinery Vehicle,Machinery Vehicle +111,Fan,Fan +112,Green Vegetables,Green Vegetables +113,Banana,Banana +114,Baseball Glove,Baseball Glove +115,Airplane,Airplane +116,Mouse,Mouse +117,Train,Train +118,Pumpkin,Pumpkin +119,Soccer,Soccer +120,Skiboard,Skiboard +121,Luggage,Luggage +122,Nightstand,Nightstand +123,Tea pot,Teapot +124,Telephone,Telephone +125,Trolley,Trolley +126,Head Phone,Head Phone +127,Sports Car,Sports Car +128,Stop Sign,Stop Sign +129,Dessert,Dessert +130,Scooter,Scooter +131,Stroller,Stroller +132,Crane,Crane +133,Remote,Remote +134,Refrigerator,Refrigerator +135,Oven,Oven +136,Lemon,Lemon +137,Duck,Duck +138,Baseball Bat,Baseball Bat +139,Surveillance Camera,Surveillance Camera +140,Cat,Cat +141,Jug,Jug +142,Broccoli,Broccoli +143,Piano,Piano +144,Pizza,Pizza +145,Elephant,Elephant +146,Skateboard,Skateboard +147,Surfboard,Surfboard +148,Gun,Gun +149,Skating and Skiing shoes,Skating and Skiing shoes +150,Gas stove,Gas stove +151,Donut,Donut +152,Bow Tie,Bow Tie +153,Carrot,Carrot +154,Toilet,Toilet +155,Kite,Kite +156,Strawberry,Strawberry +157,Other Balls,Other Balls +158,Shovel,Shovel +159,Pepper,Pepper +160,Computer Box,Computer Box +161,Toilet Paper,Toilet Paper +162,Cleaning Products,Cleaning Products +163,Chopsticks,Chopsticks +164,Microwave,Microwave +165,Pigeon,Pigeon +166,Baseball,Baseball +167,Cutting/chopping Board,Cutting/chopping Board +168,Coffee Table,Coffee Table +169,Side Table,Side Table +170,Scissors,Scissors +171,Marker,Marker +172,Pie,Pie +173,Ladder,Ladder +174,Snowboard,Snowboard +175,Cookies,Cookies +176,Radiator,Radiator +177,Fire Hydrant,Fire Hydrant +178,Basketball,Basketball +179,Zebra,Zebra +180,Grape,Grape +181,Giraffe,Giraffe +182,Potato,Potato +183,Sausage,Sausage +184,Tricycle,Tricycle +185,Violin,Violin +186,Egg,Egg +187,Fire Extinguisher,Fire Extinguisher +188,Candy,Candy +189,Fire Truck,Fire Truck +190,Billards,Billards +191,Converter,Converter +192,Bathtub,Bathtub +193,Wheelchair,Wheelchair +194,Golf Club,Golf Club +195,Briefcase,Briefcase +196,Cucumber,Cucumber +197,Cigar/Cigarette,Cigar/Cigarette +198,Paint Brush,Paint Brush +199,Pear,Pear +200,Heavy Truck,Heavy Truck +201,Hamburger,Hamburger +202,Extractor,Extractor +203,Extention Cord,Extension Cord +204,Tong,Tong +205,Tennis Racket,Tennis Racket +206,Folder,Folder +207,American Football,American Football +208,earphone,earphone +209,Mask,Mask +210,Kettle,Kettle +211,Tennis,Tennis +212,Ship,Ship +213,Swing,Swing +214,Coffee Machine,Coffee Machine +215,Slide,Slide +216,Carriage,Carriage +217,Onion,Onion +218,Green beans,Green beans +219,Projector,Projector +220,Frisbee,Frisbee +221,Washing Machine/Drying Machine,Washing Machine/Drying Machine +222,Chicken,Chicken +223,Printer,Printer +224,Watermelon,Watermelon +225,Saxophone,Saxophone +226,Tissue,Tissue +227,Toothbrush,Toothbrush +228,Ice cream,Ice cream +229,Hotair ballon,Hot air balloon +230,Cello,Cello +231,French Fries,French Fries +232,Scale,Scale +233,Trophy,Trophy +234,Cabbage,Cabbage +235,Hot dog,Hot dog +236,Blender,Blender +237,Peach,Peach +238,Rice,Rice +239,Wallet/Purse,Wallet/Purse +240,Volleyball,Volleyball +241,Deer,Deer +242,Goose,Goose +243,Tape,Tape +244,Tablet,Tablet +245,Cosmetics,Cosmetics +246,Trumpet,Trumpet +247,Pineapple,Pineapple +248,Golf Ball,Golf Ball +249,Ambulance,Ambulance +250,Parking meter,Parking meter +251,Mango,Mango +252,Key,Key +253,Hurdle,Hurdle +254,Fishing Rod,Fishing Rod +255,Medal,Medal +256,Flute,Flute +257,Brush,Brush +258,Penguin,Penguin +259,Megaphone,Megaphone +260,Corn,Corn +261,Lettuce,Lettuce +262,Garlic,Garlic +263,Swan,Swan +264,Helicopter,Helicopter +265,Green Onion,Green Onion +266,Sandwich,Sandwich +267,Nuts,Nuts +268,Speed Limit Sign,Speed Limit Sign +269,Induction Cooker,Induction Cooker +270,Broom,Broom +271,Trombone,Trombone +272,Plum,Plum +273,Rickshaw,Rickshaw +274,Goldfish,Goldfish +275,Kiwi fruit,Kiwi fruit +276,Router/modem,Router/modem +277,Poker Card,Poker Card +278,Toaster,Toaster +279,Shrimp,Shrimp +280,Sushi,Sushi +281,Cheese,Cheese +282,Notepaper,Notepaper +283,Cherry,Cherry +284,Pliers,Pliers +285,CD,CD +286,Pasta,Pasta +287,Hammer,Hammer +288,Cue,Cue +289,Avocado,Avocado +290,Hamimelon,Hami melon +291,Flask,Flask +292,Mushroon,Mushroom +293,Screwdriver,Screwdriver +294,Soap,Soap +295,Recorder,Recorder +296,Bear,Bear +297,Eggplant,Eggplant +298,Board Eraser,Board Eraser +299,Coconut,Coconut +300,Tape Measur/ Ruler,Tape Measure/ Ruler +301,Pig,Pig +302,Showerhead,Showerhead +303,Globe,Globe +304,Chips,Chips +305,Steak,Steak +306,Crosswalk Sign,Crosswalk Sign +307,Stapler,Stapler +308,Campel,Camel +309,Formula 1,Formula 1 +310,Pomegranate,Pomegranate +311,Dishwasher,Dishwasher +312,Crab,Crab +313,Hoverboard,Hoverboard +314,Meat ball,Meatball +315,Rice Cooker,Rice Cooker +316,Tuba,Tuba +317,Calculator,Calculator +318,Papaya,Papaya +319,Antelope,Antelope +320,Parrot,Parrot +321,Seal,Seal +322,Buttefly,Butterfly +323,Dumbbell,Dumbbell +324,Donkey,Donkey +325,Lion,Lion +326,Urinal,Urinal +327,Dolphin,Dolphin +328,Electric Drill,Electric Drill +329,Hair Dryer,Hair Dryer +330,Egg tart,Egg tart +331,Jellyfish,Jellyfish +332,Treadmill,Treadmill +333,Lighter,Lighter +334,Grapefruit,Grapefruit +335,Game board,Game board +336,Mop,Mop +337,Radish,Radish +338,Baozi,Baozi +339,Target,Target +340,French,French +341,Spring Rolls,Spring Rolls +342,Monkey,Monkey +343,Rabbit,Rabbit +344,Pencil Case,Pencil Case +345,Yak,Yak +346,Red Cabbage,Red Cabbage +347,Binoculars,Binoculars +348,Asparagus,Asparagus +349,Barbell,Barbell +350,Scallop,Scallop +351,Noddles,Noddles +352,Comb,Comb +353,Dumpling,Dumpling +354,Oyster,Oyster +355,Table Teniis paddle,Table Tennis paddle +356,Cosmetics Brush/Eyeliner Pencil,Cosmetics Brush/Eyeliner Pencil +357,Chainsaw,Chainsaw +358,Eraser,Eraser +359,Lobster,Lobster +360,Durian,Durian +361,Okra,Okra +362,Lipstick,Lipstick +363,Cosmetics Mirror,Cosmetics Mirror +364,Curling,Curling +365,Table Tennis,Table Tennis \ No newline at end of file diff --git a/dimos/models/Detic/datasets/metadata/coco_clip_a+cname.npy b/dimos/models/Detic/datasets/metadata/coco_clip_a+cname.npy new file mode 100644 index 0000000000..63b938afaf Binary files /dev/null and b/dimos/models/Detic/datasets/metadata/coco_clip_a+cname.npy differ diff --git a/dimos/models/Detic/datasets/metadata/imagenet_lvis_wnid.txt b/dimos/models/Detic/datasets/metadata/imagenet_lvis_wnid.txt new file mode 100644 index 0000000000..8433aa01af --- /dev/null +++ b/dimos/models/Detic/datasets/metadata/imagenet_lvis_wnid.txt @@ -0,0 +1,997 @@ +n02682922 +n02686379 +n02691156 +n02694662 +n07884567 +n01698434 +n07750586 +n02701002 +n02705944 +n02715229 +n07739125 +n07825850 +n07750872 +n02730930 +n02732072 +n02735538 +n02738449 +n02738535 +n02739550 +n02739668 +n07718747 +n02747177 +n02747802 +n07719213 +n02754103 +n07764847 +n02763901 +n02764044 +n02486410 +n02766534 +n02768226 +n02769748 +n02774152 +n02773838 +n07693725 +n02775483 +n07687381 +n02776205 +n02779435 +n02780815 +n02782093 +n12147226 +n07753592 +n02786058 +n02785648 +n02786198 +n02787622 +n02788021 +n02790996 +n02792552 +n02795169 +n02796318 +n02797295 +n02797881 +n02799071 +n02799175 +n02799323 +n02800213 +n02801938 +n02802426 +n02804252 +n02139199 +n02808304 +n02807616 +n02808440 +n07860805 +n02810471 +n07709881 +n02816656 +n02816768 +n02131653 +n02818832 +n02821202 +n02822220 +n02404186 +n02823124 +n02823428 +n02823510 +n02164464 +n02824448 +n07720875 +n02827606 +n02828299 +n02828884 +n02831237 +n02834778 +n02838728 +n02839110 +n02840245 +n02841315 +n01503061 +n02843553 +n02843158 +n02843276 +n02843684 +n02413050 +n07744811 +n02846511 +n02849154 +n02850358 +n02850732 +n02850950 +n02852173 +n02854926 +n07743544 +n02858304 +n02860415 +n02860640 +n07841495 +n02865351 +n02865931 +n02865665 +n02869837 +n02870880 +n02871147 +n02871824 +n02872752 +n02876657 +n02877962 +n02879087 +n02879718 +n02883205 +n02880940 +n02881757 +n02882301 +n02883344 +n02885462 +n02887489 +n02887970 +n02892201 +n02892767 +n02893692 +n07679356 +n02896294 +n02898585 +n02900705 +n11876803 +n02906734 +n07715221 +n07600285 +n02909870 +n02912557 +n01887623 +n02108672 +n02916179 +n02917067 +n02916936 +n02917377 +n07680932 +n02920259 +n07880968 +n02924116 +n07848338 +n02274259 +n02928608 +n02930766 +n02931294 +n02932523 +n02933112 +n02933462 +n02938886 +n01887896 +n02942349 +n02437136 +n02942699 +n02943241 +n02946348 +n02946921 +n02951585 +n02948072 +n02948557 +n07598256 +n07601572 +n02949202 +n02949542 +n02951358 +n07755929 +n02952374 +n02954340 +n02954938 +n02955767 +n07920349 +n02958343 +n02959942 +n02960352 +n02961225 +n02963159 +n02965300 +n11808468 +n02968473 +n02970408 +n02970849 +n02971356 +n02977438 +n07580359 +n02978881 +n02979836 +n02121620 +n07715103 +n07822518 +n02988304 +n02992529 +n03000247 +n03001627 +n03002711 +n03002948 +n03005285 +n03006903 +n07757132 +n01791625 +n12515925 +n07721456 +n03017168 +n07712559 +n03020416 +n07921360 +n07617611 +n07921455 +n03030353 +n03031012 +n03035715 +n03037709 +n03038281 +n03041114 +n12710415 +n03043958 +n03045074 +n03045337 +n03046257 +n03047052 +n03050864 +n03051249 +n03055418 +n03057021 +n03057920 +n03059103 +n01792158 +n02233338 +n07922764 +n07772935 +n03063338 +n03063968 +n03063689 +n03066849 +n07808587 +n03075370 +n03075768 +n06596364 +n03080497 +n03085013 +n07810907 +n03096960 +n03100240 +n03100346 +n03101156 +n03101986 +n03102654 +n03108853 +n03109150 +n07731952 +n07687789 +n03110669 +n03111296 +n07568095 +n03112869 +n03113835 +n02125311 +n03121897 +n03123917 +n03124170 +n01976957 +n07681926 +n03127925 +n03128248 +n03129001 +n07691650 +n03131574 +n03133415 +n03135917 +n07682197 +n01579028 +n03138344 +n03138669 +n03140292 +n03141327 +n03141065 +n03141823 +n01322685 +n07718472 +n03147509 +n03148324 +n03150232 +n03150511 +n03151077 +n03156279 +n03157348 +n03158885 +n02110341 +n07765073 +n03168217 +n02430045 +n03175843 +n03179701 +n03188531 +n03199901 +n03201208 +n03201776 +n03206908 +n03207305 +n03207743 +n03207835 +n03207941 +n03210683 +n03216710 +n02084071 +n03219135 +n03219483 +n02068974 +n02389559 +n03223299 +n07639069 +n01812337 +n02268443 +n03233905 +n03234164 +n03236735 +n03237416 +n03239054 +n03237340 +n03239726 +n03245889 +n03247083 +n03249569 +n03250847 +n01846331 +n01847170 +n03253886 +n03255030 +n03256032 +n03259009 +n01613294 +n03261776 +n03262248 +n03262809 +n07840804 +n07866723 +n07841345 +n03266371 +n07713074 +n03271030 +n03273913 +n02503517 +n02432983 +n03291819 +n03294833 +n03309356 +n01610955 +n03320046 +n03325088 +n03325941 +n02443346 +n03329302 +n03329663 +n07753113 +n03335030 +n03337140 +n03336839 +n03343737 +n03345487 +n03345837 +n03346455 +n03349469 +n02512053 +n03350204 +n03351979 +n03354903 +n03355925 +n02007558 +n03356982 +n03358172 +n03359137 +n03362639 +n03364008 +n03364156 +n03372549 +n02376542 +n03376595 +n03378174 +n03378765 +n03379051 +n03380724 +n03384352 +n03393912 +n07868200 +n03397947 +n01639765 +n07924033 +n03400231 +n07605474 +n03403643 +n03408444 +n03410740 +n03416900 +n03417042 +n07818277 +n03424325 +n02423022 +n07643981 +n03433877 +n02510455 +n07814925 +n02439033 +n03438071 +n03438257 +n03441112 +n02416519 +n03443912 +n01443537 +n03446070 +n03445924 +n03447447 +n01855672 +n02480855 +n12158031 +n07758680 +n03454885 +n03455488 +n03456024 +n07722485 +n03459328 +n03459591 +n02132580 +n03461288 +n03467517 +n02041246 +n03467984 +n03475581 +n03475961 +n03476313 +n03480579 +n07697100 +n03481172 +n03482252 +n03482405 +n02342885 +n03483316 +n03485198 +n03490006 +n03484083 +n03484576 +n03485794 +n03488188 +n03494537 +n03497657 +n03498441 +n03502331 +n03502200 +n03503997 +n03505504 +n03505667 +n03506028 +n03508101 +n03512147 +n03513137 +n02008041 +n03518445 +n03521076 +n02398521 +n03524150 +n02395406 +n03528901 +n07858978 +n03531546 +n03532342 +n03533014 +n02213107 +n02374451 +n03541923 +n03543254 +n07830593 +n03544143 +n03545470 +n01833805 +n07857731 +n02134084 +n07614500 +n07615774 +n03557692 +n03557840 +n03558404 +n03571280 +n03584254 +n03584829 +n03589791 +n07642933 +n03593526 +n03594734 +n03594945 +n07606669 +n03595614 +n03595860 +n03602883 +n03605598 +n03609235 +n03610418 +n03610524 +n03612814 +n03613294 +n03617312 +n03617480 +n03620967 +n02122948 +n07763629 +n03623198 +n03623556 +n03625646 +n03626760 +n01882714 +n03630383 +n03633091 +n02165456 +n02412440 +n03636649 +n03637181 +n03637318 +n03640988 +n03642806 +n07870167 +n03644858 +n03649909 +n03655072 +n11748002 +n07749582 +n07926250 +n03662719 +n03662887 +n03665924 +n03668067 +n07749731 +n03670208 +n02129165 +n07901587 +n01674464 +n07607605 +n03691459 +n03693474 +n03701391 +n03705379 +n03710193 +n01847806 +n03715892 +n02504770 +n02073831 +n07747951 +n03717131 +n03717447 +n03720163 +n03722007 +n07916041 +n10297234 +n07711569 +n03724417 +n03725035 +n03726760 +n03727946 +n03729402 +n03733805 +n03735637 +n07871436 +n07755411 +n03759954 +n03760671 +n03761084 +n07844042 +n03764736 +n03770679 +n07606278 +n03773035 +n03775071 +n03775199 +n03782190 +n02484322 +n03789946 +n03791053 +n03791235 +n03790512 +n03792334 +n03793489 +n07690273 +n03797390 +n13000891 +n03801880 +n03800933 +n03805280 +n03814817 +n03814906 +n03815615 +n03816136 +n06267145 +n03822656 +n03825080 +n03831203 +n03831382 +n03836602 +n03837422 +n01970164 +n03844045 +n07842753 +n12433081 +n07747607 +n07924834 +n01518878 +n03858418 +n03862676 +n03863108 +n01621127 +n03871628 +n03873416 +n03874599 +n03876231 +n03877472 +n03878674 +n03880531 +n03880323 +n03885904 +n07762244 +n03887697 +n03888257 +n01821203 +n03889726 +n03889871 +n03891051 +n03891332 +n01816887 +n03895866 +n03896103 +n07663899 +n07725376 +n07751004 +n07855510 +n07767847 +n03904909 +n03906106 +n03906224 +n02051845 +n03906997 +n03908204 +n03908618 +n03908714 +n03909160 +n02055803 +n07815588 +n03914337 +n03916031 +n07746186 +n00007846 +n01318894 +n03920867 +n03924069 +n03928116 +n07824988 +n03930630 +n01811909 +n03935335 +n03938244 +n03940256 +n07753275 +n03942813 +n03944138 +n03948459 +n07683617 +n03950228 +n03950359 +n07873807 +n03963198 +n03964495 +n03966976 +n03967562 +n03973839 +n03973628 +n03975926 +n03976657 +n03978966 +n03980874 +n02382437 +n03982430 +n07927512 +n03990474 +n03991062 +n07710616 +n03992703 +n03993180 +n03996416 +n07695742 +n04004475 +n04008634 +n04009552 +n04011827 +n07752602 +n07617188 +n02655020 +n02047614 +n02110958 +n07735510 +n04023249 +n01322604 +n07881205 +n04033995 +n02324045 +n04037443 +n04039381 +n04039848 +n04040759 +n04043733 +n04045397 +n04049405 +n02412080 +n07745466 +n02331046 +n04057215 +n04059516 +n04059947 +n04062428 +n04064401 +n04069276 +n04074963 +n02391994 +n04090263 +n04095210 +n04097866 +n04099969 +n02329401 +n04102618 +n04102162 +n04103206 +n07928887 +n04114844 +n04116098 +n04122825 +n04123740 +n04124202 +n04124098 +n04127249 +n04127904 +n07806221 +n02534734 +n07823460 +n04131690 +n04133789 +n07695965 +n04137217 +n04138977 +n04140631 +n04141076 +n04141975 +n04143897 +n04146614 +n04148054 +n04149813 +n04150980 +n04154565 +n04156140 +n04157320 +n02021795 +n01456756 +n04160586 +n01956764 +n04179913 +n04183329 +n01482330 +n04185071 +n04185529 +n04185804 +n04186051 +n04186455 +n04186848 +n02411705 +n02104523 +n07615289 +n04192698 +n04197391 +n04199027 +n04204081 +n04204347 +n04205318 +n04206225 +n04207343 +n04208210 +n04208936 +n04209133 +n04209239 +n04210120 +n04217882 +n04220250 +n04225987 +n04227900 +n04228054 +n04228581 +n04230387 +n04230603 +n04230808 +n04232153 +n04235291 +n04235860 +n04239436 +n04241394 +n07914271 +n01726692 +n04251791 +n04252077 +n04254680 +n04254777 +n04256520 +n04256891 +n04257790 +n04259630 +n07583197 +n04263257 +n04263502 +n07848093 +n07844867 +n04266014 +n04269944 +n04270891 +n04272054 +n04275175 +n01772222 +n01984695 +n04284002 +n04285803 +n04286575 +n02355227 +n04297098 +n04303497 +n02317335 +n04306847 +n04307986 +n04313503 +n04315713 +n04315948 +n07588947 +n04320871 +n04320973 +n04326896 +n04330340 +n04332243 +n04333129 +n07745940 +n06794110 +n04335886 +n07854707 +n04346511 +n04349401 +n04350581 +n04350905 +n11978233 +n04356056 +n04356595 +n07879450 +n04367480 +n04370288 +n04370048 +n04370456 +n07712063 +n04371563 +n04373894 +n04376876 +n07826091 +n04381587 +n04379243 +n04380533 +n04382880 +n07880751 +n04384910 +n04387400 +n04389033 +n04388743 +n04390577 +n04392113 +n04393549 +n04395024 +n04395106 +n07933154 +n04397452 +n04397768 +n04398044 +n04401088 +n04401680 +n04402449 +n04403413 +n04404997 +n04405907 +n04409515 +n04409806 +n07905979 +n04421872 +n04422727 +n04422875 +n04423845 +n04431745 +n04432203 +n02129604 +n04434932 +n04438304 +n04439712 +n07686873 +n04442312 +n04442441 +n15075141 +n07734017 +n04450749 +n04452615 +n04453156 +n04453390 +n04453910 +n04461696 +n04459362 +n04459773 +n04461879 +n04465501 +n06874185 +n04466871 +n04467665 +n04468005 +n04469514 +n04476259 +n04479046 +n04480853 +n04482393 +n04485082 +n04489008 +n04490091 +n07609632 +n04491769 +n04493381 +n04498389 +n11877646 +n01662784 +n04502197 +n04505036 +n04507155 +n04508949 +n04509417 +n04516116 +n04517823 +n04522168 +n04525305 +n04531873 +n04534520 +n07828987 +n04536866 +n07906111 +n04540053 +n01616318 +n04542943 +n04543158 +n04543772 +n04546194 +n04548280 +n04548362 +n02081571 +n04550184 +n04554684 +n04555897 +n04557648 +n04559166 +n04559451 +n04560113 +n04560804 +n04562122 +n04562262 +n04562935 +n04560292 +n07756951 +n04568069 +n04569063 +n04574067 +n04574999 +n04576002 +n04579667 +n04584207 +n04587559 +n04589325 +n04590746 +n04591713 +n04591887 +n04592099 +n04593629 +n04596742 +n02114100 +n04597913 +n04606574 +n04610013 +n07849336 +n04612840 +n02391049 +n07716358 diff --git a/dimos/models/Detic/datasets/metadata/lvis_v1_clip_a+cname.npy.REMOVED.git-id b/dimos/models/Detic/datasets/metadata/lvis_v1_clip_a+cname.npy.REMOVED.git-id new file mode 100644 index 0000000000..b62476a597 --- /dev/null +++ b/dimos/models/Detic/datasets/metadata/lvis_v1_clip_a+cname.npy.REMOVED.git-id @@ -0,0 +1 @@ +a9e5376ee4f7cd871f9b2830bcd6e79967875d7e \ No newline at end of file diff --git a/dimos/models/Detic/datasets/metadata/lvis_v1_train_cat_info.json b/dimos/models/Detic/datasets/metadata/lvis_v1_train_cat_info.json new file mode 100644 index 0000000000..95fef09233 --- /dev/null +++ b/dimos/models/Detic/datasets/metadata/lvis_v1_train_cat_info.json @@ -0,0 +1 @@ +[{"name": "aerosol_can", "instance_count": 109, "def": "a dispenser that holds a substance under pressure", "synonyms": ["aerosol_can", "spray_can"], "image_count": 64, "id": 1, "frequency": "c", "synset": "aerosol.n.02"}, {"name": "air_conditioner", "instance_count": 1081, "def": "a machine that keeps air cool and dry", "synonyms": ["air_conditioner"], "image_count": 364, "id": 2, "frequency": "f", "synset": "air_conditioner.n.01"}, {"name": "airplane", "instance_count": 3720, "def": "an aircraft that has a fixed wing and is powered by propellers or jets", "synonyms": ["airplane", "aeroplane"], "image_count": 1911, "id": 3, "frequency": "f", "synset": "airplane.n.01"}, {"name": "alarm_clock", "instance_count": 158, "def": "a clock that wakes a sleeper at some preset time", "synonyms": ["alarm_clock"], "image_count": 149, "id": 4, "frequency": "f", "synset": "alarm_clock.n.01"}, {"name": "alcohol", "instance_count": 207, "def": "a liquor or brew containing alcohol as the active agent", "synonyms": ["alcohol", "alcoholic_beverage"], "image_count": 29, "id": 5, "frequency": "c", "synset": "alcohol.n.01"}, {"name": "alligator", "instance_count": 39, "def": "amphibious reptiles related to crocodiles but with shorter broader snouts", "synonyms": ["alligator", "gator"], "image_count": 26, "id": 6, "frequency": "c", "synset": "alligator.n.02"}, {"name": "almond", "instance_count": 1700, "def": "oval-shaped edible seed of the almond tree", "synonyms": ["almond"], "image_count": 59, "id": 7, "frequency": "c", "synset": "almond.n.02"}, {"name": "ambulance", "instance_count": 25, "def": "a vehicle that takes people to and from hospitals", "synonyms": ["ambulance"], "image_count": 22, "id": 8, "frequency": "c", "synset": "ambulance.n.01"}, {"name": "amplifier", "instance_count": 16, "def": "electronic equipment that increases strength of signals", "synonyms": ["amplifier"], "image_count": 12, "id": 9, "frequency": "c", "synset": "amplifier.n.01"}, {"name": "anklet", "instance_count": 39, "def": "an ornament worn around the ankle", "synonyms": ["anklet", "ankle_bracelet"], "image_count": 28, "id": 10, "frequency": "c", "synset": "anklet.n.03"}, {"name": "antenna", "instance_count": 1018, "def": "an electrical device that sends or receives radio or television signals", "synonyms": ["antenna", "aerial", "transmitting_aerial"], "image_count": 505, "id": 11, "frequency": "f", "synset": "antenna.n.01"}, {"name": "apple", "instance_count": 17451, "def": "fruit with red or yellow or green skin and sweet to tart crisp whitish flesh", "synonyms": ["apple"], "image_count": 1207, "id": 12, "frequency": "f", "synset": "apple.n.01"}, {"name": "applesauce", "instance_count": 7, "def": "puree of stewed apples usually sweetened and spiced", "synonyms": ["applesauce"], "image_count": 4, "id": 13, "frequency": "r", "synset": "applesauce.n.01"}, {"name": "apricot", "instance_count": 62, "def": "downy yellow to rosy-colored fruit resembling a small peach", "synonyms": ["apricot"], "image_count": 10, "id": 14, "frequency": "r", "synset": "apricot.n.02"}, {"name": "apron", "instance_count": 881, "def": "a garment of cloth that is tied about the waist and worn to protect clothing", "synonyms": ["apron"], "image_count": 500, "id": 15, "frequency": "f", "synset": "apron.n.01"}, {"name": "aquarium", "instance_count": 36, "def": "a tank/pool/bowl filled with water for keeping live fish and underwater animals", "synonyms": ["aquarium", "fish_tank"], "image_count": 33, "id": 16, "frequency": "c", "synset": "aquarium.n.01"}, {"name": "arctic_(type_of_shoe)", "instance_count": 8, "def": "a waterproof overshoe that protects shoes from water or snow", "synonyms": ["arctic_(type_of_shoe)", "galosh", "golosh", "rubber_(type_of_shoe)", "gumshoe"], "image_count": 3, "id": 17, "frequency": "r", "synset": "arctic.n.02"}, {"name": "armband", "instance_count": 85, "def": "a band worn around the upper arm", "synonyms": ["armband"], "image_count": 44, "id": 18, "frequency": "c", "synset": "armband.n.02"}, {"name": "armchair", "instance_count": 1112, "def": "chair with a support on each side for arms", "synonyms": ["armchair"], "image_count": 561, "id": 19, "frequency": "f", "synset": "armchair.n.01"}, {"name": "armoire", "instance_count": 11, "def": "a large wardrobe or cabinet", "synonyms": ["armoire"], "image_count": 8, "id": 20, "frequency": "r", "synset": "armoire.n.01"}, {"name": "armor", "instance_count": 23, "def": "protective covering made of metal and used in combat", "synonyms": ["armor", "armour"], "image_count": 9, "id": 21, "frequency": "r", "synset": "armor.n.01"}, {"name": "artichoke", "instance_count": 293, "def": "a thistlelike flower head with edible fleshy leaves and heart", "synonyms": ["artichoke"], "image_count": 33, "id": 22, "frequency": "c", "synset": "artichoke.n.02"}, {"name": "trash_can", "instance_count": 2722, "def": "a bin that holds rubbish until it is collected", "synonyms": ["trash_can", "garbage_can", "wastebin", "dustbin", "trash_barrel", "trash_bin"], "image_count": 1883, "id": 23, "frequency": "f", "synset": "ashcan.n.01"}, {"name": "ashtray", "instance_count": 136, "def": "a receptacle for the ash from smokers' cigars or cigarettes", "synonyms": ["ashtray"], "image_count": 98, "id": 24, "frequency": "c", "synset": "ashtray.n.01"}, {"name": "asparagus", "instance_count": 969, "def": "edible young shoots of the asparagus plant", "synonyms": ["asparagus"], "image_count": 70, "id": 25, "frequency": "c", "synset": "asparagus.n.02"}, {"name": "atomizer", "instance_count": 67, "def": "a dispenser that turns a liquid (such as perfume) into a fine mist", "synonyms": ["atomizer", "atomiser", "spray", "sprayer", "nebulizer", "nebuliser"], "image_count": 46, "id": 26, "frequency": "c", "synset": "atomizer.n.01"}, {"name": "avocado", "instance_count": 1048, "def": "a pear-shaped fruit with green or blackish skin and rich yellowish pulp enclosing a single large seed", "synonyms": ["avocado"], "image_count": 117, "id": 27, "frequency": "f", "synset": "avocado.n.01"}, {"name": "award", "instance_count": 163, "def": "a tangible symbol signifying approval or distinction", "synonyms": ["award", "accolade"], "image_count": 41, "id": 28, "frequency": "c", "synset": "award.n.02"}, {"name": "awning", "instance_count": 4270, "def": "a canopy made of canvas to shelter people or things from rain or sun", "synonyms": ["awning"], "image_count": 1395, "id": 29, "frequency": "f", "synset": "awning.n.01"}, {"name": "ax", "instance_count": 8, "def": "an edge tool with a heavy bladed head mounted across a handle", "synonyms": ["ax", "axe"], "image_count": 7, "id": 30, "frequency": "r", "synset": "ax.n.01"}, {"name": "baboon", "instance_count": 3, "def": "large terrestrial monkeys having doglike muzzles", "synonyms": ["baboon"], "image_count": 1, "id": 31, "frequency": "r", "synset": "baboon.n.01"}, {"name": "baby_buggy", "instance_count": 447, "def": "a small vehicle with four wheels in which a baby or child is pushed around", "synonyms": ["baby_buggy", "baby_carriage", "perambulator", "pram", "stroller"], "image_count": 314, "id": 32, "frequency": "f", "synset": "baby_buggy.n.01"}, {"name": "basketball_backboard", "instance_count": 42, "def": "a raised vertical board with basket attached; used to play basketball", "synonyms": ["basketball_backboard"], "image_count": 31, "id": 33, "frequency": "c", "synset": "backboard.n.01"}, {"name": "backpack", "instance_count": 3907, "def": "a bag carried by a strap on your back or shoulder", "synonyms": ["backpack", "knapsack", "packsack", "rucksack", "haversack"], "image_count": 1905, "id": 34, "frequency": "f", "synset": "backpack.n.01"}, {"name": "handbag", "instance_count": 3947, "def": "a container used for carrying money and small personal items or accessories", "synonyms": ["handbag", "purse", "pocketbook"], "image_count": 1859, "id": 35, "frequency": "f", "synset": "bag.n.04"}, {"name": "suitcase", "instance_count": 8537, "def": "cases used to carry belongings when traveling", "synonyms": ["suitcase", "baggage", "luggage"], "image_count": 1623, "id": 36, "frequency": "f", "synset": "bag.n.06"}, {"name": "bagel", "instance_count": 372, "def": "glazed yeast-raised doughnut-shaped roll with hard crust", "synonyms": ["bagel", "beigel"], "image_count": 47, "id": 37, "frequency": "c", "synset": "bagel.n.01"}, {"name": "bagpipe", "instance_count": 6, "def": "a tubular wind instrument; the player blows air into a bag and squeezes it out", "synonyms": ["bagpipe"], "image_count": 3, "id": 38, "frequency": "r", "synset": "bagpipe.n.01"}, {"name": "baguet", "instance_count": 9, "def": "narrow French stick loaf", "synonyms": ["baguet", "baguette"], "image_count": 3, "id": 39, "frequency": "r", "synset": "baguet.n.01"}, {"name": "bait", "instance_count": 1, "def": "something used to lure fish or other animals into danger so they can be trapped or killed", "synonyms": ["bait", "lure"], "image_count": 1, "id": 40, "frequency": "r", "synset": "bait.n.02"}, {"name": "ball", "instance_count": 755, "def": "a spherical object used as a plaything", "synonyms": ["ball"], "image_count": 305, "id": 41, "frequency": "f", "synset": "ball.n.06"}, {"name": "ballet_skirt", "instance_count": 12, "def": "very short skirt worn by ballerinas", "synonyms": ["ballet_skirt", "tutu"], "image_count": 6, "id": 42, "frequency": "r", "synset": "ballet_skirt.n.01"}, {"name": "balloon", "instance_count": 1556, "def": "large tough nonrigid bag filled with gas or heated air", "synonyms": ["balloon"], "image_count": 210, "id": 43, "frequency": "f", "synset": "balloon.n.01"}, {"name": "bamboo", "instance_count": 243, "def": "woody tropical grass having hollow woody stems", "synonyms": ["bamboo"], "image_count": 36, "id": 44, "frequency": "c", "synset": "bamboo.n.02"}, {"name": "banana", "instance_count": 50552, "def": "elongated crescent-shaped yellow fruit with soft sweet flesh", "synonyms": ["banana"], "image_count": 1787, "id": 45, "frequency": "f", "synset": "banana.n.02"}, {"name": "Band_Aid", "instance_count": 19, "def": "trade name for an adhesive bandage to cover small cuts or blisters", "synonyms": ["Band_Aid"], "image_count": 17, "id": 46, "frequency": "c", "synset": "band_aid.n.01"}, {"name": "bandage", "instance_count": 92, "def": "a piece of soft material that covers and protects an injured part of the body", "synonyms": ["bandage"], "image_count": 51, "id": 47, "frequency": "c", "synset": "bandage.n.01"}, {"name": "bandanna", "instance_count": 219, "def": "large and brightly colored handkerchief; often used as a neckerchief", "synonyms": ["bandanna", "bandana"], "image_count": 138, "id": 48, "frequency": "f", "synset": "bandanna.n.01"}, {"name": "banjo", "instance_count": 3, "def": "a stringed instrument of the guitar family with a long neck and circular body", "synonyms": ["banjo"], "image_count": 3, "id": 49, "frequency": "r", "synset": "banjo.n.01"}, {"name": "banner", "instance_count": 5907, "def": "long strip of cloth or paper used for decoration or advertising", "synonyms": ["banner", "streamer"], "image_count": 1470, "id": 50, "frequency": "f", "synset": "banner.n.01"}, {"name": "barbell", "instance_count": 4, "def": "a bar to which heavy discs are attached at each end; used in weightlifting", "synonyms": ["barbell"], "image_count": 3, "id": 51, "frequency": "r", "synset": "barbell.n.01"}, {"name": "barge", "instance_count": 3, "def": "a flatbottom boat for carrying heavy loads (especially on canals)", "synonyms": ["barge"], "image_count": 2, "id": 52, "frequency": "r", "synset": "barge.n.01"}, {"name": "barrel", "instance_count": 707, "def": "a cylindrical container that holds liquids", "synonyms": ["barrel", "cask"], "image_count": 186, "id": 53, "frequency": "f", "synset": "barrel.n.02"}, {"name": "barrette", "instance_count": 119, "def": "a pin for holding women's hair in place", "synonyms": ["barrette"], "image_count": 76, "id": 54, "frequency": "c", "synset": "barrette.n.01"}, {"name": "barrow", "instance_count": 30, "def": "a cart for carrying small loads; has handles and one or more wheels", "synonyms": ["barrow", "garden_cart", "lawn_cart", "wheelbarrow"], "image_count": 26, "id": 55, "frequency": "c", "synset": "barrow.n.03"}, {"name": "baseball_base", "instance_count": 404, "def": "a place that the runner must touch before scoring", "synonyms": ["baseball_base"], "image_count": 303, "id": 56, "frequency": "f", "synset": "base.n.03"}, {"name": "baseball", "instance_count": 1013, "def": "a ball used in playing baseball", "synonyms": ["baseball"], "image_count": 738, "id": 57, "frequency": "f", "synset": "baseball.n.02"}, {"name": "baseball_bat", "instance_count": 2698, "def": "an implement used in baseball by the batter", "synonyms": ["baseball_bat"], "image_count": 1799, "id": 58, "frequency": "f", "synset": "baseball_bat.n.01"}, {"name": "baseball_cap", "instance_count": 9028, "def": "a cap with a bill", "synonyms": ["baseball_cap", "jockey_cap", "golf_cap"], "image_count": 1934, "id": 59, "frequency": "f", "synset": "baseball_cap.n.01"}, {"name": "baseball_glove", "instance_count": 2536, "def": "the handwear used by fielders in playing baseball", "synonyms": ["baseball_glove", "baseball_mitt"], "image_count": 1609, "id": 60, "frequency": "f", "synset": "baseball_glove.n.01"}, {"name": "basket", "instance_count": 3984, "def": "a container that is usually woven and has handles", "synonyms": ["basket", "handbasket"], "image_count": 1622, "id": 61, "frequency": "f", "synset": "basket.n.01"}, {"name": "basketball", "instance_count": 56, "def": "an inflated ball used in playing basketball", "synonyms": ["basketball"], "image_count": 41, "id": 62, "frequency": "c", "synset": "basketball.n.02"}, {"name": "bass_horn", "instance_count": 6, "def": "the lowest brass wind instrument", "synonyms": ["bass_horn", "sousaphone", "tuba"], "image_count": 4, "id": 63, "frequency": "r", "synset": "bass_horn.n.01"}, {"name": "bat_(animal)", "instance_count": 47, "def": "nocturnal mouselike mammal with forelimbs modified to form membranous wings", "synonyms": ["bat_(animal)"], "image_count": 11, "id": 64, "frequency": "c", "synset": "bat.n.01"}, {"name": "bath_mat", "instance_count": 336, "def": "a heavy towel or mat to stand on while drying yourself after a bath", "synonyms": ["bath_mat"], "image_count": 270, "id": 65, "frequency": "f", "synset": "bath_mat.n.01"}, {"name": "bath_towel", "instance_count": 1210, "def": "a large towel; to dry yourself after a bath", "synonyms": ["bath_towel"], "image_count": 349, "id": 66, "frequency": "f", "synset": "bath_towel.n.01"}, {"name": "bathrobe", "instance_count": 53, "def": "a loose-fitting robe of towelling; worn after a bath or swim", "synonyms": ["bathrobe"], "image_count": 42, "id": 67, "frequency": "c", "synset": "bathrobe.n.01"}, {"name": "bathtub", "instance_count": 868, "def": "a large open container that you fill with water and use to wash the body", "synonyms": ["bathtub", "bathing_tub"], "image_count": 823, "id": 68, "frequency": "f", "synset": "bathtub.n.01"}, {"name": "batter_(food)", "instance_count": 26, "def": "a liquid or semiliquid mixture, as of flour, eggs, and milk, used in cooking", "synonyms": ["batter_(food)"], "image_count": 6, "id": 69, "frequency": "r", "synset": "batter.n.02"}, {"name": "battery", "instance_count": 155, "def": "a portable device that produces electricity", "synonyms": ["battery"], "image_count": 48, "id": 70, "frequency": "c", "synset": "battery.n.02"}, {"name": "beachball", "instance_count": 3, "def": "large and light ball; for play at the seaside", "synonyms": ["beachball"], "image_count": 3, "id": 71, "frequency": "r", "synset": "beach_ball.n.01"}, {"name": "bead", "instance_count": 1371, "def": "a small ball with a hole through the middle used for ornamentation, jewellery, etc.", "synonyms": ["bead"], "image_count": 42, "id": 72, "frequency": "c", "synset": "bead.n.01"}, {"name": "bean_curd", "instance_count": 231, "def": "cheeselike food made of curdled soybean milk", "synonyms": ["bean_curd", "tofu"], "image_count": 24, "id": 73, "frequency": "c", "synset": "bean_curd.n.01"}, {"name": "beanbag", "instance_count": 20, "def": "a bag filled with dried beans or similar items; used in games or to sit on", "synonyms": ["beanbag"], "image_count": 16, "id": 74, "frequency": "c", "synset": "beanbag.n.01"}, {"name": "beanie", "instance_count": 1907, "def": "a small skullcap; formerly worn by schoolboys and college freshmen", "synonyms": ["beanie", "beany"], "image_count": 605, "id": 75, "frequency": "f", "synset": "beanie.n.01"}, {"name": "bear", "instance_count": 1069, "def": "large carnivorous or omnivorous mammals with shaggy coats and claws", "synonyms": ["bear"], "image_count": 646, "id": 76, "frequency": "f", "synset": "bear.n.01"}, {"name": "bed", "instance_count": 2137, "def": "a piece of furniture that provides a place to sleep", "synonyms": ["bed"], "image_count": 1765, "id": 77, "frequency": "f", "synset": "bed.n.01"}, {"name": "bedpan", "instance_count": 2, "def": "a shallow vessel used by a bedridden patient for defecation and urination", "synonyms": ["bedpan"], "image_count": 2, "id": 78, "frequency": "r", "synset": "bedpan.n.01"}, {"name": "bedspread", "instance_count": 188, "def": "decorative cover for a bed", "synonyms": ["bedspread", "bedcover", "bed_covering", "counterpane", "spread"], "image_count": 125, "id": 79, "frequency": "f", "synset": "bedspread.n.01"}, {"name": "cow", "instance_count": 8085, "def": "cattle/cow", "synonyms": ["cow"], "image_count": 1420, "id": 80, "frequency": "f", "synset": "beef.n.01"}, {"name": "beef_(food)", "instance_count": 1242, "def": "meat from an adult domestic bovine", "synonyms": ["beef_(food)", "boeuf_(food)"], "image_count": 140, "id": 81, "frequency": "f", "synset": "beef.n.02"}, {"name": "beeper", "instance_count": 4, "def": "an device that beeps when the person carrying it is being paged", "synonyms": ["beeper", "pager"], "image_count": 4, "id": 82, "frequency": "r", "synset": "beeper.n.01"}, {"name": "beer_bottle", "instance_count": 1227, "def": "a bottle that holds beer", "synonyms": ["beer_bottle"], "image_count": 322, "id": 83, "frequency": "f", "synset": "beer_bottle.n.01"}, {"name": "beer_can", "instance_count": 203, "def": "a can that holds beer", "synonyms": ["beer_can"], "image_count": 60, "id": 84, "frequency": "c", "synset": "beer_can.n.01"}, {"name": "beetle", "instance_count": 9, "def": "insect with hard wing covers", "synonyms": ["beetle"], "image_count": 2, "id": 85, "frequency": "r", "synset": "beetle.n.01"}, {"name": "bell", "instance_count": 590, "def": "a hollow device made of metal that makes a ringing sound when struck", "synonyms": ["bell"], "image_count": 231, "id": 86, "frequency": "f", "synset": "bell.n.01"}, {"name": "bell_pepper", "instance_count": 4369, "def": "large bell-shaped sweet pepper in green or red or yellow or orange or black varieties", "synonyms": ["bell_pepper", "capsicum"], "image_count": 333, "id": 87, "frequency": "f", "synset": "bell_pepper.n.02"}, {"name": "belt", "instance_count": 3683, "def": "a band to tie or buckle around the body (usually at the waist)", "synonyms": ["belt"], "image_count": 1941, "id": 88, "frequency": "f", "synset": "belt.n.02"}, {"name": "belt_buckle", "instance_count": 589, "def": "the buckle used to fasten a belt", "synonyms": ["belt_buckle"], "image_count": 367, "id": 89, "frequency": "f", "synset": "belt_buckle.n.01"}, {"name": "bench", "instance_count": 4374, "def": "a long seat for more than one person", "synonyms": ["bench"], "image_count": 1922, "id": 90, "frequency": "f", "synset": "bench.n.01"}, {"name": "beret", "instance_count": 57, "def": "a cap with no brim or bill; made of soft cloth", "synonyms": ["beret"], "image_count": 18, "id": 91, "frequency": "c", "synset": "beret.n.01"}, {"name": "bib", "instance_count": 96, "def": "a napkin tied under the chin of a child while eating", "synonyms": ["bib"], "image_count": 81, "id": 92, "frequency": "c", "synset": "bib.n.02"}, {"name": "Bible", "instance_count": 2, "def": "the sacred writings of the Christian religions", "synonyms": ["Bible"], "image_count": 1, "id": 93, "frequency": "r", "synset": "bible.n.01"}, {"name": "bicycle", "instance_count": 4566, "def": "a wheeled vehicle that has two wheels and is moved by foot pedals", "synonyms": ["bicycle", "bike_(bicycle)"], "image_count": 1852, "id": 94, "frequency": "f", "synset": "bicycle.n.01"}, {"name": "visor", "instance_count": 777, "def": "a brim that projects to the front to shade the eyes", "synonyms": ["visor", "vizor"], "image_count": 430, "id": 95, "frequency": "f", "synset": "bill.n.09"}, {"name": "billboard", "instance_count": 1025, "def": "large outdoor signboard", "synonyms": ["billboard"], "image_count": 247, "id": 96, "frequency": "f", "synset": "billboard.n.01"}, {"name": "binder", "instance_count": 311, "def": "holds loose papers or magazines", "synonyms": ["binder", "ring-binder"], "image_count": 94, "id": 97, "frequency": "c", "synset": "binder.n.03"}, {"name": "binoculars", "instance_count": 22, "def": "an optical instrument designed for simultaneous use by both eyes", "synonyms": ["binoculars", "field_glasses", "opera_glasses"], "image_count": 21, "id": 98, "frequency": "c", "synset": "binoculars.n.01"}, {"name": "bird", "instance_count": 11557, "def": "animal characterized by feathers and wings", "synonyms": ["bird"], "image_count": 1821, "id": 99, "frequency": "f", "synset": "bird.n.01"}, {"name": "birdfeeder", "instance_count": 16, "def": "an outdoor device that supplies food for wild birds", "synonyms": ["birdfeeder"], "image_count": 16, "id": 100, "frequency": "c", "synset": "bird_feeder.n.01"}, {"name": "birdbath", "instance_count": 12, "def": "an ornamental basin (usually in a garden) for birds to bathe in", "synonyms": ["birdbath"], "image_count": 12, "id": 101, "frequency": "c", "synset": "birdbath.n.01"}, {"name": "birdcage", "instance_count": 180, "def": "a cage in which a bird can be kept", "synonyms": ["birdcage"], "image_count": 25, "id": 102, "frequency": "c", "synset": "birdcage.n.01"}, {"name": "birdhouse", "instance_count": 60, "def": "a shelter for birds", "synonyms": ["birdhouse"], "image_count": 41, "id": 103, "frequency": "c", "synset": "birdhouse.n.01"}, {"name": "birthday_cake", "instance_count": 311, "def": "decorated cake served at a birthday party", "synonyms": ["birthday_cake"], "image_count": 244, "id": 104, "frequency": "f", "synset": "birthday_cake.n.01"}, {"name": "birthday_card", "instance_count": 23, "def": "a card expressing a birthday greeting", "synonyms": ["birthday_card"], "image_count": 7, "id": 105, "frequency": "r", "synset": "birthday_card.n.01"}, {"name": "pirate_flag", "instance_count": 1, "def": "a flag usually bearing a white skull and crossbones on a black background", "synonyms": ["pirate_flag"], "image_count": 1, "id": 106, "frequency": "r", "synset": "black_flag.n.01"}, {"name": "black_sheep", "instance_count": 214, "def": "sheep with a black coat", "synonyms": ["black_sheep"], "image_count": 40, "id": 107, "frequency": "c", "synset": "black_sheep.n.02"}, {"name": "blackberry", "instance_count": 406, "def": "large sweet black or very dark purple edible aggregate fruit", "synonyms": ["blackberry"], "image_count": 40, "id": 108, "frequency": "c", "synset": "blackberry.n.01"}, {"name": "blackboard", "instance_count": 154, "def": "sheet of slate; for writing with chalk", "synonyms": ["blackboard", "chalkboard"], "image_count": 104, "id": 109, "frequency": "f", "synset": "blackboard.n.01"}, {"name": "blanket", "instance_count": 3075, "def": "bedding that keeps a person warm in bed", "synonyms": ["blanket"], "image_count": 1671, "id": 110, "frequency": "f", "synset": "blanket.n.01"}, {"name": "blazer", "instance_count": 124, "def": "lightweight jacket; often striped in the colors of a club or school", "synonyms": ["blazer", "sport_jacket", "sport_coat", "sports_jacket", "sports_coat"], "image_count": 49, "id": 111, "frequency": "c", "synset": "blazer.n.01"}, {"name": "blender", "instance_count": 316, "def": "an electrically powered mixer that mix or chop or liquefy foods", "synonyms": ["blender", "liquidizer", "liquidiser"], "image_count": 243, "id": 112, "frequency": "f", "synset": "blender.n.01"}, {"name": "blimp", "instance_count": 3, "def": "a small nonrigid airship used for observation or as a barrage balloon", "synonyms": ["blimp"], "image_count": 2, "id": 113, "frequency": "r", "synset": "blimp.n.02"}, {"name": "blinker", "instance_count": 1269, "def": "a light that flashes on and off; used as a signal or to send messages", "synonyms": ["blinker", "flasher"], "image_count": 242, "id": 114, "frequency": "f", "synset": "blinker.n.01"}, {"name": "blouse", "instance_count": 623, "def": "a top worn by women", "synonyms": ["blouse"], "image_count": 271, "id": 115, "frequency": "f", "synset": "blouse.n.01"}, {"name": "blueberry", "instance_count": 2114, "def": "sweet edible dark-blue berries of blueberry plants", "synonyms": ["blueberry"], "image_count": 104, "id": 116, "frequency": "f", "synset": "blueberry.n.02"}, {"name": "gameboard", "instance_count": 17, "def": "a flat portable surface (usually rectangular) designed for board games", "synonyms": ["gameboard"], "image_count": 8, "id": 117, "frequency": "r", "synset": "board.n.09"}, {"name": "boat", "instance_count": 9981, "def": "a vessel for travel on water", "synonyms": ["boat", "ship_(boat)"], "image_count": 1758, "id": 118, "frequency": "f", "synset": "boat.n.01"}, {"name": "bob", "instance_count": 2, "def": "a small float usually made of cork; attached to a fishing line", "synonyms": ["bob", "bobber", "bobfloat"], "image_count": 1, "id": 119, "frequency": "r", "synset": "bob.n.05"}, {"name": "bobbin", "instance_count": 190, "def": "a thing around which thread/tape/film or other flexible materials can be wound", "synonyms": ["bobbin", "spool", "reel"], "image_count": 48, "id": 120, "frequency": "c", "synset": "bobbin.n.01"}, {"name": "bobby_pin", "instance_count": 43, "def": "a flat wire hairpin used to hold bobbed hair in place", "synonyms": ["bobby_pin", "hairgrip"], "image_count": 14, "id": 121, "frequency": "c", "synset": "bobby_pin.n.01"}, {"name": "boiled_egg", "instance_count": 125, "def": "egg cooked briefly in the shell in gently boiling water", "synonyms": ["boiled_egg", "coddled_egg"], "image_count": 40, "id": 122, "frequency": "c", "synset": "boiled_egg.n.01"}, {"name": "bolo_tie", "instance_count": 1, "def": "a cord fastened around the neck with an ornamental clasp and worn as a necktie", "synonyms": ["bolo_tie", "bolo", "bola_tie", "bola"], "image_count": 1, "id": 123, "frequency": "r", "synset": "bolo_tie.n.01"}, {"name": "deadbolt", "instance_count": 46, "def": "the part of a lock that is engaged or withdrawn with a key", "synonyms": ["deadbolt"], "image_count": 37, "id": 124, "frequency": "c", "synset": "bolt.n.03"}, {"name": "bolt", "instance_count": 11261, "def": "a screw that screws into a nut to form a fastener", "synonyms": ["bolt"], "image_count": 1510, "id": 125, "frequency": "f", "synset": "bolt.n.06"}, {"name": "bonnet", "instance_count": 10, "def": "a hat tied under the chin", "synonyms": ["bonnet"], "image_count": 6, "id": 126, "frequency": "r", "synset": "bonnet.n.01"}, {"name": "book", "instance_count": 33353, "def": "a written work or composition that has been published", "synonyms": ["book"], "image_count": 1903, "id": 127, "frequency": "f", "synset": "book.n.01"}, {"name": "bookcase", "instance_count": 113, "def": "a piece of furniture with shelves for storing books", "synonyms": ["bookcase"], "image_count": 70, "id": 128, "frequency": "c", "synset": "bookcase.n.01"}, {"name": "booklet", "instance_count": 439, "def": "a small book usually having a paper cover", "synonyms": ["booklet", "brochure", "leaflet", "pamphlet"], "image_count": 86, "id": 129, "frequency": "c", "synset": "booklet.n.01"}, {"name": "bookmark", "instance_count": 15, "def": "a marker (a piece of paper or ribbon) placed between the pages of a book", "synonyms": ["bookmark", "bookmarker"], "image_count": 7, "id": 130, "frequency": "r", "synset": "bookmark.n.01"}, {"name": "boom_microphone", "instance_count": 10, "def": "a pole carrying an overhead microphone projected over a film or tv set", "synonyms": ["boom_microphone", "microphone_boom"], "image_count": 5, "id": 131, "frequency": "r", "synset": "boom.n.04"}, {"name": "boot", "instance_count": 4194, "def": "footwear that covers the whole foot and lower leg", "synonyms": ["boot"], "image_count": 1406, "id": 132, "frequency": "f", "synset": "boot.n.01"}, {"name": "bottle", "instance_count": 7969, "def": "a glass or plastic vessel used for storing drinks or other liquids", "synonyms": ["bottle"], "image_count": 1901, "id": 133, "frequency": "f", "synset": "bottle.n.01"}, {"name": "bottle_opener", "instance_count": 15, "def": "an opener for removing caps or corks from bottles", "synonyms": ["bottle_opener"], "image_count": 15, "id": 134, "frequency": "c", "synset": "bottle_opener.n.01"}, {"name": "bouquet", "instance_count": 53, "def": "an arrangement of flowers that is usually given as a present", "synonyms": ["bouquet"], "image_count": 28, "id": 135, "frequency": "c", "synset": "bouquet.n.01"}, {"name": "bow_(weapon)", "instance_count": 6, "def": "a weapon for shooting arrows", "synonyms": ["bow_(weapon)"], "image_count": 6, "id": 136, "frequency": "r", "synset": "bow.n.04"}, {"name": "bow_(decorative_ribbons)", "instance_count": 1144, "def": "a decorative interlacing of ribbons", "synonyms": ["bow_(decorative_ribbons)"], "image_count": 494, "id": 137, "frequency": "f", "synset": "bow.n.08"}, {"name": "bow-tie", "instance_count": 359, "def": "a man's tie that ties in a bow", "synonyms": ["bow-tie", "bowtie"], "image_count": 234, "id": 138, "frequency": "f", "synset": "bow_tie.n.01"}, {"name": "bowl", "instance_count": 5308, "def": "a dish that is round and open at the top for serving foods", "synonyms": ["bowl"], "image_count": 1922, "id": 139, "frequency": "f", "synset": "bowl.n.03"}, {"name": "pipe_bowl", "instance_count": 1, "def": "a small round container that is open at the top for holding tobacco", "synonyms": ["pipe_bowl"], "image_count": 1, "id": 140, "frequency": "r", "synset": "bowl.n.08"}, {"name": "bowler_hat", "instance_count": 89, "def": "a felt hat that is round and hard with a narrow brim", "synonyms": ["bowler_hat", "bowler", "derby_hat", "derby", "plug_hat"], "image_count": 35, "id": 141, "frequency": "c", "synset": "bowler_hat.n.01"}, {"name": "bowling_ball", "instance_count": 38, "def": "a large ball with finger holes used in the sport of bowling", "synonyms": ["bowling_ball"], "image_count": 5, "id": 142, "frequency": "r", "synset": "bowling_ball.n.01"}, {"name": "box", "instance_count": 7855, "def": "a (usually rectangular) container; may have a lid", "synonyms": ["box"], "image_count": 1828, "id": 143, "frequency": "f", "synset": "box.n.01"}, {"name": "boxing_glove", "instance_count": 22, "def": "large glove coverings the fists of a fighter worn for the sport of boxing", "synonyms": ["boxing_glove"], "image_count": 8, "id": 144, "frequency": "r", "synset": "boxing_glove.n.01"}, {"name": "suspenders", "instance_count": 88, "def": "elastic straps that hold trousers up (usually used in the plural)", "synonyms": ["suspenders"], "image_count": 63, "id": 145, "frequency": "c", "synset": "brace.n.06"}, {"name": "bracelet", "instance_count": 3219, "def": "jewelry worn around the wrist for decoration", "synonyms": ["bracelet", "bangle"], "image_count": 1668, "id": 146, "frequency": "f", "synset": "bracelet.n.02"}, {"name": "brass_plaque", "instance_count": 4, "def": "a memorial made of brass", "synonyms": ["brass_plaque"], "image_count": 4, "id": 147, "frequency": "r", "synset": "brass.n.07"}, {"name": "brassiere", "instance_count": 118, "def": "an undergarment worn by women to support their breasts", "synonyms": ["brassiere", "bra", "bandeau"], "image_count": 95, "id": 148, "frequency": "c", "synset": "brassiere.n.01"}, {"name": "bread-bin", "instance_count": 17, "def": "a container used to keep bread or cake in", "synonyms": ["bread-bin", "breadbox"], "image_count": 17, "id": 149, "frequency": "c", "synset": "bread-bin.n.01"}, {"name": "bread", "instance_count": 6550, "def": "food made from dough of flour or meal and usually raised with yeast or baking powder and then baked", "synonyms": ["bread"], "image_count": 1567, "id": 150, "frequency": "f", "synset": "bread.n.01"}, {"name": "breechcloth", "instance_count": 3, "def": "a garment that provides covering for the loins", "synonyms": ["breechcloth", "breechclout", "loincloth"], "image_count": 2, "id": 151, "frequency": "r", "synset": "breechcloth.n.01"}, {"name": "bridal_gown", "instance_count": 118, "def": "a gown worn by the bride at a wedding", "synonyms": ["bridal_gown", "wedding_gown", "wedding_dress"], "image_count": 103, "id": 152, "frequency": "f", "synset": "bridal_gown.n.01"}, {"name": "briefcase", "instance_count": 84, "def": "a case with a handle; for carrying papers or files or books", "synonyms": ["briefcase"], "image_count": 50, "id": 153, "frequency": "c", "synset": "briefcase.n.01"}, {"name": "broccoli", "instance_count": 12166, "def": "plant with dense clusters of tight green flower buds", "synonyms": ["broccoli"], "image_count": 1309, "id": 154, "frequency": "f", "synset": "broccoli.n.01"}, {"name": "broach", "instance_count": 9, "def": "a decorative pin worn by women", "synonyms": ["broach"], "image_count": 6, "id": 155, "frequency": "r", "synset": "brooch.n.01"}, {"name": "broom", "instance_count": 144, "def": "bundle of straws or twigs attached to a long handle; used for cleaning", "synonyms": ["broom"], "image_count": 92, "id": 156, "frequency": "c", "synset": "broom.n.01"}, {"name": "brownie", "instance_count": 217, "def": "square or bar of very rich chocolate cake usually with nuts", "synonyms": ["brownie"], "image_count": 19, "id": 157, "frequency": "c", "synset": "brownie.n.03"}, {"name": "brussels_sprouts", "instance_count": 590, "def": "the small edible cabbage-like buds growing along a stalk", "synonyms": ["brussels_sprouts"], "image_count": 37, "id": 158, "frequency": "c", "synset": "brussels_sprouts.n.01"}, {"name": "bubble_gum", "instance_count": 4, "def": "a kind of chewing gum that can be blown into bubbles", "synonyms": ["bubble_gum"], "image_count": 4, "id": 159, "frequency": "r", "synset": "bubble_gum.n.01"}, {"name": "bucket", "instance_count": 1346, "def": "a roughly cylindrical vessel that is open at the top", "synonyms": ["bucket", "pail"], "image_count": 709, "id": 160, "frequency": "f", "synset": "bucket.n.01"}, {"name": "horse_buggy", "instance_count": 19, "def": "a small lightweight carriage; drawn by a single horse", "synonyms": ["horse_buggy"], "image_count": 9, "id": 161, "frequency": "r", "synset": "buggy.n.01"}, {"name": "bull", "instance_count": 230, "def": "a cow with horns", "synonyms": ["horned_cow"], "image_count": 82, "id": 162, "frequency": "c", "synset": "bull.n.11"}, {"name": "bulldog", "instance_count": 21, "def": "a thickset short-haired dog with a large head and strong undershot lower jaw", "synonyms": ["bulldog"], "image_count": 15, "id": 163, "frequency": "c", "synset": "bulldog.n.01"}, {"name": "bulldozer", "instance_count": 4, "def": "large powerful tractor; a large blade in front flattens areas of ground", "synonyms": ["bulldozer", "dozer"], "image_count": 3, "id": 164, "frequency": "r", "synset": "bulldozer.n.01"}, {"name": "bullet_train", "instance_count": 80, "def": "a high-speed passenger train", "synonyms": ["bullet_train"], "image_count": 61, "id": 165, "frequency": "c", "synset": "bullet_train.n.01"}, {"name": "bulletin_board", "instance_count": 76, "def": "a board that hangs on a wall; displays announcements", "synonyms": ["bulletin_board", "notice_board"], "image_count": 51, "id": 166, "frequency": "c", "synset": "bulletin_board.n.02"}, {"name": "bulletproof_vest", "instance_count": 27, "def": "a vest capable of resisting the impact of a bullet", "synonyms": ["bulletproof_vest"], "image_count": 5, "id": 167, "frequency": "r", "synset": "bulletproof_vest.n.01"}, {"name": "bullhorn", "instance_count": 15, "def": "a portable loudspeaker with built-in microphone and amplifier", "synonyms": ["bullhorn", "megaphone"], "image_count": 13, "id": 168, "frequency": "c", "synset": "bullhorn.n.01"}, {"name": "bun", "instance_count": 1780, "def": "small rounded bread either plain or sweet", "synonyms": ["bun", "roll"], "image_count": 642, "id": 169, "frequency": "f", "synset": "bun.n.01"}, {"name": "bunk_bed", "instance_count": 44, "def": "beds built one above the other", "synonyms": ["bunk_bed"], "image_count": 24, "id": 170, "frequency": "c", "synset": "bunk_bed.n.01"}, {"name": "buoy", "instance_count": 1404, "def": "a float attached by rope to the seabed to mark channels in a harbor or underwater hazards", "synonyms": ["buoy"], "image_count": 255, "id": 171, "frequency": "f", "synset": "buoy.n.01"}, {"name": "burrito", "instance_count": 14, "def": "a flour tortilla folded around a filling", "synonyms": ["burrito"], "image_count": 9, "id": 172, "frequency": "r", "synset": "burrito.n.01"}, {"name": "bus_(vehicle)", "instance_count": 3281, "def": "a vehicle carrying many passengers; used for public transport", "synonyms": ["bus_(vehicle)", "autobus", "charabanc", "double-decker", "motorbus", "motorcoach"], "image_count": 1808, "id": 173, "frequency": "f", "synset": "bus.n.01"}, {"name": "business_card", "instance_count": 84, "def": "a card on which are printed the person's name and business affiliation", "synonyms": ["business_card"], "image_count": 31, "id": 174, "frequency": "c", "synset": "business_card.n.01"}, {"name": "butter", "instance_count": 308, "def": "an edible emulsion of fat globules made by churning milk or cream; for cooking and table use", "synonyms": ["butter"], "image_count": 158, "id": 175, "frequency": "f", "synset": "butter.n.01"}, {"name": "butterfly", "instance_count": 296, "def": "insect typically having a slender body with knobbed antennae and broad colorful wings", "synonyms": ["butterfly"], "image_count": 80, "id": 176, "frequency": "c", "synset": "butterfly.n.01"}, {"name": "button", "instance_count": 7884, "def": "a round fastener sewn to shirts and coats etc to fit through buttonholes", "synonyms": ["button"], "image_count": 1884, "id": 177, "frequency": "f", "synset": "button.n.01"}, {"name": "cab_(taxi)", "instance_count": 414, "def": "a car that takes passengers where they want to go in exchange for money", "synonyms": ["cab_(taxi)", "taxi", "taxicab"], "image_count": 158, "id": 178, "frequency": "f", "synset": "cab.n.03"}, {"name": "cabana", "instance_count": 20, "def": "a small tent used as a dressing room beside the sea or a swimming pool", "synonyms": ["cabana"], "image_count": 2, "id": 179, "frequency": "r", "synset": "cabana.n.01"}, {"name": "cabin_car", "instance_count": 14, "def": "a car on a freight train for use of the train crew; usually the last car on the train", "synonyms": ["cabin_car", "caboose"], "image_count": 12, "id": 180, "frequency": "c", "synset": "cabin_car.n.01"}, {"name": "cabinet", "instance_count": 7371, "def": "a piece of furniture resembling a cupboard with doors and shelves and drawers", "synonyms": ["cabinet"], "image_count": 1659, "id": 181, "frequency": "f", "synset": "cabinet.n.01"}, {"name": "locker", "instance_count": 95, "def": "a storage compartment for clothes and valuables; usually it has a lock", "synonyms": ["locker", "storage_locker"], "image_count": 7, "id": 182, "frequency": "r", "synset": "cabinet.n.03"}, {"name": "cake", "instance_count": 2297, "def": "baked goods made from or based on a mixture of flour, sugar, eggs, and fat", "synonyms": ["cake"], "image_count": 834, "id": 183, "frequency": "f", "synset": "cake.n.03"}, {"name": "calculator", "instance_count": 60, "def": "a small machine that is used for mathematical calculations", "synonyms": ["calculator"], "image_count": 57, "id": 184, "frequency": "c", "synset": "calculator.n.02"}, {"name": "calendar", "instance_count": 251, "def": "a list or register of events (appointments/social events/court cases, etc)", "synonyms": ["calendar"], "image_count": 174, "id": 185, "frequency": "f", "synset": "calendar.n.02"}, {"name": "calf", "instance_count": 301, "def": "young of domestic cattle", "synonyms": ["calf"], "image_count": 95, "id": 186, "frequency": "c", "synset": "calf.n.01"}, {"name": "camcorder", "instance_count": 45, "def": "a portable television camera and videocassette recorder", "synonyms": ["camcorder"], "image_count": 27, "id": 187, "frequency": "c", "synset": "camcorder.n.01"}, {"name": "camel", "instance_count": 34, "def": "cud-chewing mammal used as a draft or saddle animal in desert regions", "synonyms": ["camel"], "image_count": 22, "id": 188, "frequency": "c", "synset": "camel.n.01"}, {"name": "camera", "instance_count": 2471, "def": "equipment for taking photographs", "synonyms": ["camera"], "image_count": 1391, "id": 189, "frequency": "f", "synset": "camera.n.01"}, {"name": "camera_lens", "instance_count": 167, "def": "a lens that focuses the image in a camera", "synonyms": ["camera_lens"], "image_count": 90, "id": 190, "frequency": "c", "synset": "camera_lens.n.01"}, {"name": "camper_(vehicle)", "instance_count": 102, "def": "a recreational vehicle equipped for camping out while traveling", "synonyms": ["camper_(vehicle)", "camping_bus", "motor_home"], "image_count": 40, "id": 191, "frequency": "c", "synset": "camper.n.02"}, {"name": "can", "instance_count": 1424, "def": "airtight sealed metal container for food or drink or paint etc.", "synonyms": ["can", "tin_can"], "image_count": 445, "id": 192, "frequency": "f", "synset": "can.n.01"}, {"name": "can_opener", "instance_count": 22, "def": "a device for cutting cans open", "synonyms": ["can_opener", "tin_opener"], "image_count": 21, "id": 193, "frequency": "c", "synset": "can_opener.n.01"}, {"name": "candle", "instance_count": 4288, "def": "stick of wax with a wick in the middle", "synonyms": ["candle", "candlestick"], "image_count": 1132, "id": 194, "frequency": "f", "synset": "candle.n.01"}, {"name": "candle_holder", "instance_count": 530, "def": "a holder with sockets for candles", "synonyms": ["candle_holder"], "image_count": 177, "id": 195, "frequency": "f", "synset": "candlestick.n.01"}, {"name": "candy_bar", "instance_count": 29, "def": "a candy shaped as a bar", "synonyms": ["candy_bar"], "image_count": 4, "id": 196, "frequency": "r", "synset": "candy_bar.n.01"}, {"name": "candy_cane", "instance_count": 107, "def": "a hard candy in the shape of a rod (usually with stripes)", "synonyms": ["candy_cane"], "image_count": 17, "id": 197, "frequency": "c", "synset": "candy_cane.n.01"}, {"name": "walking_cane", "instance_count": 106, "def": "a stick that people can lean on to help them walk", "synonyms": ["walking_cane"], "image_count": 84, "id": 198, "frequency": "c", "synset": "cane.n.01"}, {"name": "canister", "instance_count": 218, "def": "metal container for storing dry foods such as tea or flour", "synonyms": ["canister", "cannister"], "image_count": 55, "id": 199, "frequency": "c", "synset": "canister.n.02"}, {"name": "canoe", "instance_count": 96, "def": "small and light boat; pointed at both ends; propelled with a paddle", "synonyms": ["canoe"], "image_count": 30, "id": 200, "frequency": "c", "synset": "canoe.n.01"}, {"name": "cantaloup", "instance_count": 193, "def": "the fruit of a cantaloup vine; small to medium-sized melon with yellowish flesh", "synonyms": ["cantaloup", "cantaloupe"], "image_count": 25, "id": 201, "frequency": "c", "synset": "cantaloup.n.02"}, {"name": "canteen", "instance_count": 2, "def": "a flask for carrying water; used by soldiers or travelers", "synonyms": ["canteen"], "image_count": 2, "id": 202, "frequency": "r", "synset": "canteen.n.01"}, {"name": "cap_(headwear)", "instance_count": 636, "def": "a tight-fitting headwear", "synonyms": ["cap_(headwear)"], "image_count": 125, "id": 203, "frequency": "f", "synset": "cap.n.01"}, {"name": "bottle_cap", "instance_count": 5293, "def": "a top (as for a bottle)", "synonyms": ["bottle_cap", "cap_(container_lid)"], "image_count": 1135, "id": 204, "frequency": "f", "synset": "cap.n.02"}, {"name": "cape", "instance_count": 27, "def": "a sleeveless garment like a cloak but shorter", "synonyms": ["cape"], "image_count": 19, "id": 205, "frequency": "c", "synset": "cape.n.02"}, {"name": "cappuccino", "instance_count": 87, "def": "equal parts of espresso and steamed milk", "synonyms": ["cappuccino", "coffee_cappuccino"], "image_count": 72, "id": 206, "frequency": "c", "synset": "cappuccino.n.01"}, {"name": "car_(automobile)", "instance_count": 10528, "def": "a motor vehicle with four wheels", "synonyms": ["car_(automobile)", "auto_(automobile)", "automobile"], "image_count": 1926, "id": 207, "frequency": "f", "synset": "car.n.01"}, {"name": "railcar_(part_of_a_train)", "instance_count": 928, "def": "a wheeled vehicle adapted to the rails of railroad (mark each individual railcar separately)", "synonyms": ["railcar_(part_of_a_train)", "railway_car_(part_of_a_train)", "railroad_car_(part_of_a_train)"], "image_count": 159, "id": 208, "frequency": "f", "synset": "car.n.02"}, {"name": "elevator_car", "instance_count": 10, "def": "where passengers ride up and down", "synonyms": ["elevator_car"], "image_count": 7, "id": 209, "frequency": "r", "synset": "car.n.04"}, {"name": "car_battery", "instance_count": 1, "def": "a battery in a motor vehicle", "synonyms": ["car_battery", "automobile_battery"], "image_count": 1, "id": 210, "frequency": "r", "synset": "car_battery.n.01"}, {"name": "identity_card", "instance_count": 16, "def": "a card certifying the identity of the bearer", "synonyms": ["identity_card"], "image_count": 13, "id": 211, "frequency": "c", "synset": "card.n.02"}, {"name": "card", "instance_count": 122, "def": "a rectangular piece of paper used to send messages (e.g. greetings or pictures)", "synonyms": ["card"], "image_count": 35, "id": 212, "frequency": "c", "synset": "card.n.03"}, {"name": "cardigan", "instance_count": 22, "def": "knitted jacket that is fastened up the front with buttons or a zipper", "synonyms": ["cardigan"], "image_count": 18, "id": 213, "frequency": "c", "synset": "cardigan.n.01"}, {"name": "cargo_ship", "instance_count": 15, "def": "a ship designed to carry cargo", "synonyms": ["cargo_ship", "cargo_vessel"], "image_count": 8, "id": 214, "frequency": "r", "synset": "cargo_ship.n.01"}, {"name": "carnation", "instance_count": 22, "def": "plant with pink to purple-red spice-scented usually double flowers", "synonyms": ["carnation"], "image_count": 6, "id": 215, "frequency": "r", "synset": "carnation.n.01"}, {"name": "horse_carriage", "instance_count": 49, "def": "a vehicle with wheels drawn by one or more horses", "synonyms": ["horse_carriage"], "image_count": 35, "id": 216, "frequency": "c", "synset": "carriage.n.02"}, {"name": "carrot", "instance_count": 18049, "def": "deep orange edible root of the cultivated carrot plant", "synonyms": ["carrot"], "image_count": 1222, "id": 217, "frequency": "f", "synset": "carrot.n.01"}, {"name": "tote_bag", "instance_count": 231, "def": "a capacious bag or basket", "synonyms": ["tote_bag"], "image_count": 103, "id": 218, "frequency": "f", "synset": "carryall.n.01"}, {"name": "cart", "instance_count": 51, "def": "a heavy open wagon usually having two wheels and drawn by an animal", "synonyms": ["cart"], "image_count": 28, "id": 219, "frequency": "c", "synset": "cart.n.01"}, {"name": "carton", "instance_count": 206, "def": "a container made of cardboard for holding food or drink", "synonyms": ["carton"], "image_count": 63, "id": 220, "frequency": "c", "synset": "carton.n.02"}, {"name": "cash_register", "instance_count": 33, "def": "a cashbox with an adding machine to register transactions", "synonyms": ["cash_register", "register_(for_cash_transactions)"], "image_count": 28, "id": 221, "frequency": "c", "synset": "cash_register.n.01"}, {"name": "casserole", "instance_count": 12, "def": "food cooked and served in a casserole", "synonyms": ["casserole"], "image_count": 5, "id": 222, "frequency": "r", "synset": "casserole.n.01"}, {"name": "cassette", "instance_count": 74, "def": "a container that holds a magnetic tape used for recording or playing sound or video", "synonyms": ["cassette"], "image_count": 7, "id": 223, "frequency": "r", "synset": "cassette.n.01"}, {"name": "cast", "instance_count": 15, "def": "bandage consisting of a firm covering that immobilizes broken bones while they heal", "synonyms": ["cast", "plaster_cast", "plaster_bandage"], "image_count": 14, "id": 224, "frequency": "c", "synset": "cast.n.05"}, {"name": "cat", "instance_count": 2387, "def": "a domestic house cat", "synonyms": ["cat"], "image_count": 1918, "id": 225, "frequency": "f", "synset": "cat.n.01"}, {"name": "cauliflower", "instance_count": 1035, "def": "edible compact head of white undeveloped flowers", "synonyms": ["cauliflower"], "image_count": 133, "id": 226, "frequency": "f", "synset": "cauliflower.n.02"}, {"name": "cayenne_(spice)", "instance_count": 49, "def": "ground pods and seeds of pungent red peppers of the genus Capsicum", "synonyms": ["cayenne_(spice)", "cayenne_pepper_(spice)", "red_pepper_(spice)"], "image_count": 16, "id": 227, "frequency": "c", "synset": "cayenne.n.02"}, {"name": "CD_player", "instance_count": 37, "def": "electronic equipment for playing compact discs (CDs)", "synonyms": ["CD_player"], "image_count": 27, "id": 228, "frequency": "c", "synset": "cd_player.n.01"}, {"name": "celery", "instance_count": 911, "def": "widely cultivated herb with aromatic leaf stalks that are eaten raw or cooked", "synonyms": ["celery"], "image_count": 110, "id": 229, "frequency": "f", "synset": "celery.n.01"}, {"name": "cellular_telephone", "instance_count": 2902, "def": "a hand-held mobile telephone", "synonyms": ["cellular_telephone", "cellular_phone", "cellphone", "mobile_phone", "smart_phone"], "image_count": 1895, "id": 230, "frequency": "f", "synset": "cellular_telephone.n.01"}, {"name": "chain_mail", "instance_count": 13, "def": "(Middle Ages) flexible armor made of interlinked metal rings", "synonyms": ["chain_mail", "ring_mail", "chain_armor", "chain_armour", "ring_armor", "ring_armour"], "image_count": 4, "id": 231, "frequency": "r", "synset": "chain_mail.n.01"}, {"name": "chair", "instance_count": 11549, "def": "a seat for one person, with a support for the back", "synonyms": ["chair"], "image_count": 1927, "id": 232, "frequency": "f", "synset": "chair.n.01"}, {"name": "chaise_longue", "instance_count": 15, "def": "a long chair; for reclining", "synonyms": ["chaise_longue", "chaise", "daybed"], "image_count": 8, "id": 233, "frequency": "r", "synset": "chaise_longue.n.01"}, {"name": "chalice", "instance_count": 1, "def": "a bowl-shaped drinking vessel; especially the Eucharistic cup", "synonyms": ["chalice"], "image_count": 1, "id": 234, "frequency": "r", "synset": "chalice.n.01"}, {"name": "chandelier", "instance_count": 392, "def": "branched lighting fixture; often ornate; hangs from the ceiling", "synonyms": ["chandelier"], "image_count": 263, "id": 235, "frequency": "f", "synset": "chandelier.n.01"}, {"name": "chap", "instance_count": 19, "def": "leather leggings without a seat; worn over trousers by cowboys to protect their legs", "synonyms": ["chap"], "image_count": 10, "id": 236, "frequency": "r", "synset": "chap.n.04"}, {"name": "checkbook", "instance_count": 2, "def": "a book issued to holders of checking accounts", "synonyms": ["checkbook", "chequebook"], "image_count": 2, "id": 237, "frequency": "r", "synset": "checkbook.n.01"}, {"name": "checkerboard", "instance_count": 3, "def": "a board having 64 squares of two alternating colors", "synonyms": ["checkerboard"], "image_count": 3, "id": 238, "frequency": "r", "synset": "checkerboard.n.01"}, {"name": "cherry", "instance_count": 903, "def": "a red fruit with a single hard stone", "synonyms": ["cherry"], "image_count": 87, "id": 239, "frequency": "c", "synset": "cherry.n.03"}, {"name": "chessboard", "instance_count": 13, "def": "a checkerboard used to play chess", "synonyms": ["chessboard"], "image_count": 9, "id": 240, "frequency": "r", "synset": "chessboard.n.01"}, {"name": "chicken_(animal)", "instance_count": 417, "def": "a domestic fowl bred for flesh or eggs", "synonyms": ["chicken_(animal)"], "image_count": 71, "id": 241, "frequency": "c", "synset": "chicken.n.02"}, {"name": "chickpea", "instance_count": 265, "def": "the seed of the chickpea plant; usually dried", "synonyms": ["chickpea", "garbanzo"], "image_count": 13, "id": 242, "frequency": "c", "synset": "chickpea.n.01"}, {"name": "chili_(vegetable)", "instance_count": 354, "def": "very hot and finely tapering pepper of special pungency", "synonyms": ["chili_(vegetable)", "chili_pepper_(vegetable)", "chilli_(vegetable)", "chilly_(vegetable)", "chile_(vegetable)"], "image_count": 18, "id": 243, "frequency": "c", "synset": "chili.n.02"}, {"name": "chime", "instance_count": 2, "def": "an instrument consisting of a set of bells that are struck with a hammer", "synonyms": ["chime", "gong"], "image_count": 2, "id": 244, "frequency": "r", "synset": "chime.n.01"}, {"name": "chinaware", "instance_count": 41, "def": "dishware made of high quality porcelain", "synonyms": ["chinaware"], "image_count": 5, "id": 245, "frequency": "r", "synset": "chinaware.n.01"}, {"name": "crisp_(potato_chip)", "instance_count": 541, "def": "a thin crisp slice of potato fried in deep fat", "synonyms": ["crisp_(potato_chip)", "potato_chip"], "image_count": 45, "id": 246, "frequency": "c", "synset": "chip.n.04"}, {"name": "poker_chip", "instance_count": 21, "def": "a small disk-shaped counter used to represent money when gambling", "synonyms": ["poker_chip"], "image_count": 1, "id": 247, "frequency": "r", "synset": "chip.n.06"}, {"name": "chocolate_bar", "instance_count": 179, "def": "a bar of chocolate candy", "synonyms": ["chocolate_bar"], "image_count": 23, "id": 248, "frequency": "c", "synset": "chocolate_bar.n.01"}, {"name": "chocolate_cake", "instance_count": 80, "def": "cake containing chocolate", "synonyms": ["chocolate_cake"], "image_count": 32, "id": 249, "frequency": "c", "synset": "chocolate_cake.n.01"}, {"name": "chocolate_milk", "instance_count": 7, "def": "milk flavored with chocolate syrup", "synonyms": ["chocolate_milk"], "image_count": 4, "id": 250, "frequency": "r", "synset": "chocolate_milk.n.01"}, {"name": "chocolate_mousse", "instance_count": 1, "def": "dessert mousse made with chocolate", "synonyms": ["chocolate_mousse"], "image_count": 1, "id": 251, "frequency": "r", "synset": "chocolate_mousse.n.01"}, {"name": "choker", "instance_count": 1380, "def": "shirt collar, animal collar, or tight-fitting necklace", "synonyms": ["choker", "collar", "neckband"], "image_count": 858, "id": 252, "frequency": "f", "synset": "choker.n.03"}, {"name": "chopping_board", "instance_count": 840, "def": "a wooden board where meats or vegetables can be cut", "synonyms": ["chopping_board", "cutting_board", "chopping_block"], "image_count": 661, "id": 253, "frequency": "f", "synset": "chopping_board.n.01"}, {"name": "chopstick", "instance_count": 557, "def": "one of a pair of slender sticks used as oriental tableware to eat food with", "synonyms": ["chopstick"], "image_count": 168, "id": 254, "frequency": "f", "synset": "chopstick.n.01"}, {"name": "Christmas_tree", "instance_count": 303, "def": "an ornamented evergreen used as a Christmas decoration", "synonyms": ["Christmas_tree"], "image_count": 210, "id": 255, "frequency": "f", "synset": "christmas_tree.n.05"}, {"name": "slide", "instance_count": 106, "def": "sloping channel through which things can descend", "synonyms": ["slide"], "image_count": 65, "id": 256, "frequency": "c", "synset": "chute.n.02"}, {"name": "cider", "instance_count": 38, "def": "a beverage made from juice pressed from apples", "synonyms": ["cider", "cyder"], "image_count": 4, "id": 257, "frequency": "r", "synset": "cider.n.01"}, {"name": "cigar_box", "instance_count": 3, "def": "a box for holding cigars", "synonyms": ["cigar_box"], "image_count": 2, "id": 258, "frequency": "r", "synset": "cigar_box.n.01"}, {"name": "cigarette", "instance_count": 269, "def": "finely ground tobacco wrapped in paper; for smoking", "synonyms": ["cigarette"], "image_count": 159, "id": 259, "frequency": "f", "synset": "cigarette.n.01"}, {"name": "cigarette_case", "instance_count": 35, "def": "a small flat case for holding cigarettes", "synonyms": ["cigarette_case", "cigarette_pack"], "image_count": 31, "id": 260, "frequency": "c", "synset": "cigarette_case.n.01"}, {"name": "cistern", "instance_count": 901, "def": "a tank that holds the water used to flush a toilet", "synonyms": ["cistern", "water_tank"], "image_count": 811, "id": 261, "frequency": "f", "synset": "cistern.n.02"}, {"name": "clarinet", "instance_count": 1, "def": "a single-reed instrument with a straight tube", "synonyms": ["clarinet"], "image_count": 1, "id": 262, "frequency": "r", "synset": "clarinet.n.01"}, {"name": "clasp", "instance_count": 197, "def": "a fastener (as a buckle or hook) that is used to hold two things together", "synonyms": ["clasp"], "image_count": 42, "id": 263, "frequency": "c", "synset": "clasp.n.01"}, {"name": "cleansing_agent", "instance_count": 63, "def": "a preparation used in cleaning something", "synonyms": ["cleansing_agent", "cleanser", "cleaner"], "image_count": 27, "id": 264, "frequency": "c", "synset": "cleansing_agent.n.01"}, {"name": "cleat_(for_securing_rope)", "instance_count": 8, "def": "a fastener (usually with two projecting horns) around which a rope can be secured", "synonyms": ["cleat_(for_securing_rope)"], "image_count": 2, "id": 265, "frequency": "r", "synset": "cleat.n.02"}, {"name": "clementine", "instance_count": 108, "def": "a variety of mandarin orange", "synonyms": ["clementine"], "image_count": 5, "id": 266, "frequency": "r", "synset": "clementine.n.01"}, {"name": "clip", "instance_count": 301, "def": "any of various small fasteners used to hold loose articles together", "synonyms": ["clip"], "image_count": 95, "id": 267, "frequency": "c", "synset": "clip.n.03"}, {"name": "clipboard", "instance_count": 36, "def": "a small writing board with a clip at the top for holding papers", "synonyms": ["clipboard"], "image_count": 32, "id": 268, "frequency": "c", "synset": "clipboard.n.01"}, {"name": "clippers_(for_plants)", "instance_count": 1, "def": "shears for cutting grass or shrubbery (often used in the plural)", "synonyms": ["clippers_(for_plants)"], "image_count": 1, "id": 269, "frequency": "r", "synset": "clipper.n.03"}, {"name": "cloak", "instance_count": 1, "def": "a loose outer garment", "synonyms": ["cloak"], "image_count": 1, "id": 270, "frequency": "r", "synset": "cloak.n.02"}, {"name": "clock", "instance_count": 2677, "def": "a timepiece that shows the time of day", "synonyms": ["clock", "timepiece", "timekeeper"], "image_count": 1844, "id": 271, "frequency": "f", "synset": "clock.n.01"}, {"name": "clock_tower", "instance_count": 932, "def": "a tower with a large clock visible high up on an outside face", "synonyms": ["clock_tower"], "image_count": 897, "id": 272, "frequency": "f", "synset": "clock_tower.n.01"}, {"name": "clothes_hamper", "instance_count": 47, "def": "a hamper that holds dirty clothes to be washed or wet clothes to be dried", "synonyms": ["clothes_hamper", "laundry_basket", "clothes_basket"], "image_count": 31, "id": 273, "frequency": "c", "synset": "clothes_hamper.n.01"}, {"name": "clothespin", "instance_count": 111, "def": "wood or plastic fastener; for holding clothes on a clothesline", "synonyms": ["clothespin", "clothes_peg"], "image_count": 23, "id": 274, "frequency": "c", "synset": "clothespin.n.01"}, {"name": "clutch_bag", "instance_count": 1, "def": "a woman's strapless purse that is carried in the hand", "synonyms": ["clutch_bag"], "image_count": 1, "id": 275, "frequency": "r", "synset": "clutch_bag.n.01"}, {"name": "coaster", "instance_count": 390, "def": "a covering (plate or mat) that protects the surface of a table", "synonyms": ["coaster"], "image_count": 202, "id": 276, "frequency": "f", "synset": "coaster.n.03"}, {"name": "coat", "instance_count": 4145, "def": "an outer garment that has sleeves and covers the body from shoulder down", "synonyms": ["coat"], "image_count": 746, "id": 277, "frequency": "f", "synset": "coat.n.01"}, {"name": "coat_hanger", "instance_count": 282, "def": "a hanger that is shaped like a person's shoulders", "synonyms": ["coat_hanger", "clothes_hanger", "dress_hanger"], "image_count": 44, "id": 278, "frequency": "c", "synset": "coat_hanger.n.01"}, {"name": "coatrack", "instance_count": 16, "def": "a rack with hooks for temporarily holding coats and hats", "synonyms": ["coatrack", "hatrack"], "image_count": 14, "id": 279, "frequency": "c", "synset": "coatrack.n.01"}, {"name": "cock", "instance_count": 132, "def": "adult male chicken", "synonyms": ["cock", "rooster"], "image_count": 26, "id": 280, "frequency": "c", "synset": "cock.n.04"}, {"name": "cockroach", "instance_count": 1, "def": "any of numerous chiefly nocturnal insects; some are domestic pests", "synonyms": ["cockroach"], "image_count": 1, "id": 281, "frequency": "r", "synset": "cockroach.n.01"}, {"name": "cocoa_(beverage)", "instance_count": 4, "def": "a beverage made from cocoa powder and milk and sugar; usually drunk hot", "synonyms": ["cocoa_(beverage)", "hot_chocolate_(beverage)", "drinking_chocolate"], "image_count": 2, "id": 282, "frequency": "r", "synset": "cocoa.n.01"}, {"name": "coconut", "instance_count": 273, "def": "large hard-shelled brown oval nut with a fibrous husk", "synonyms": ["coconut", "cocoanut"], "image_count": 25, "id": 283, "frequency": "c", "synset": "coconut.n.02"}, {"name": "coffee_maker", "instance_count": 271, "def": "a kitchen appliance for brewing coffee automatically", "synonyms": ["coffee_maker", "coffee_machine"], "image_count": 238, "id": 284, "frequency": "f", "synset": "coffee_maker.n.01"}, {"name": "coffee_table", "instance_count": 709, "def": "low table where magazines can be placed and coffee or cocktails are served", "synonyms": ["coffee_table", "cocktail_table"], "image_count": 592, "id": 285, "frequency": "f", "synset": "coffee_table.n.01"}, {"name": "coffeepot", "instance_count": 32, "def": "tall pot in which coffee is brewed", "synonyms": ["coffeepot"], "image_count": 26, "id": 286, "frequency": "c", "synset": "coffeepot.n.01"}, {"name": "coil", "instance_count": 7, "def": "tubing that is wound in a spiral", "synonyms": ["coil"], "image_count": 5, "id": 287, "frequency": "r", "synset": "coil.n.05"}, {"name": "coin", "instance_count": 305, "def": "a flat metal piece (usually a disc) used as money", "synonyms": ["coin"], "image_count": 42, "id": 288, "frequency": "c", "synset": "coin.n.01"}, {"name": "colander", "instance_count": 16, "def": "bowl-shaped strainer; used to wash or drain foods", "synonyms": ["colander", "cullender"], "image_count": 13, "id": 289, "frequency": "c", "synset": "colander.n.01"}, {"name": "coleslaw", "instance_count": 72, "def": "basically shredded cabbage", "synonyms": ["coleslaw", "slaw"], "image_count": 46, "id": 290, "frequency": "c", "synset": "coleslaw.n.01"}, {"name": "coloring_material", "instance_count": 1, "def": "any material used for its color", "synonyms": ["coloring_material", "colouring_material"], "image_count": 1, "id": 291, "frequency": "r", "synset": "coloring_material.n.01"}, {"name": "combination_lock", "instance_count": 13, "def": "lock that can be opened only by turning dials in a special sequence", "synonyms": ["combination_lock"], "image_count": 8, "id": 292, "frequency": "r", "synset": "combination_lock.n.01"}, {"name": "pacifier", "instance_count": 40, "def": "device used for an infant to suck or bite on", "synonyms": ["pacifier", "teething_ring"], "image_count": 34, "id": 293, "frequency": "c", "synset": "comforter.n.04"}, {"name": "comic_book", "instance_count": 97, "def": "a magazine devoted to comic strips", "synonyms": ["comic_book"], "image_count": 5, "id": 294, "frequency": "r", "synset": "comic_book.n.01"}, {"name": "compass", "instance_count": 1, "def": "navigational instrument for finding directions", "synonyms": ["compass"], "image_count": 1, "id": 295, "frequency": "r", "synset": "compass.n.01"}, {"name": "computer_keyboard", "instance_count": 2745, "def": "a keyboard that is a data input device for computers", "synonyms": ["computer_keyboard", "keyboard_(computer)"], "image_count": 1871, "id": 296, "frequency": "f", "synset": "computer_keyboard.n.01"}, {"name": "condiment", "instance_count": 2985, "def": "a preparation (a sauce or relish or spice) to enhance flavor or enjoyment", "synonyms": ["condiment"], "image_count": 717, "id": 297, "frequency": "f", "synset": "condiment.n.01"}, {"name": "cone", "instance_count": 4081, "def": "a cone-shaped object used to direct traffic", "synonyms": ["cone", "traffic_cone"], "image_count": 1010, "id": 298, "frequency": "f", "synset": "cone.n.01"}, {"name": "control", "instance_count": 1775, "def": "a mechanism that controls the operation of a machine", "synonyms": ["control", "controller"], "image_count": 679, "id": 299, "frequency": "f", "synset": "control.n.09"}, {"name": "convertible_(automobile)", "instance_count": 4, "def": "a car that has top that can be folded or removed", "synonyms": ["convertible_(automobile)"], "image_count": 3, "id": 300, "frequency": "r", "synset": "convertible.n.01"}, {"name": "sofa_bed", "instance_count": 5, "def": "a sofa that can be converted into a bed", "synonyms": ["sofa_bed"], "image_count": 4, "id": 301, "frequency": "r", "synset": "convertible.n.03"}, {"name": "cooker", "instance_count": 1, "def": "a utensil for cooking", "synonyms": ["cooker"], "image_count": 1, "id": 302, "frequency": "r", "synset": "cooker.n.01"}, {"name": "cookie", "instance_count": 1920, "def": "any of various small flat sweet cakes (`biscuit' is the British term)", "synonyms": ["cookie", "cooky", "biscuit_(cookie)"], "image_count": 166, "id": 303, "frequency": "f", "synset": "cookie.n.01"}, {"name": "cooking_utensil", "instance_count": 18, "def": "a kitchen utensil made of material that does not melt easily; used for cooking", "synonyms": ["cooking_utensil"], "image_count": 2, "id": 304, "frequency": "r", "synset": "cooking_utensil.n.01"}, {"name": "cooler_(for_food)", "instance_count": 499, "def": "an insulated box for storing food often with ice", "synonyms": ["cooler_(for_food)", "ice_chest"], "image_count": 266, "id": 305, "frequency": "f", "synset": "cooler.n.01"}, {"name": "cork_(bottle_plug)", "instance_count": 326, "def": "the plug in the mouth of a bottle (especially a wine bottle)", "synonyms": ["cork_(bottle_plug)", "bottle_cork"], "image_count": 101, "id": 306, "frequency": "f", "synset": "cork.n.04"}, {"name": "corkboard", "instance_count": 7, "def": "a sheet consisting of cork granules", "synonyms": ["corkboard"], "image_count": 6, "id": 307, "frequency": "r", "synset": "corkboard.n.01"}, {"name": "corkscrew", "instance_count": 15, "def": "a bottle opener that pulls corks", "synonyms": ["corkscrew", "bottle_screw"], "image_count": 14, "id": 308, "frequency": "c", "synset": "corkscrew.n.01"}, {"name": "edible_corn", "instance_count": 1883, "def": "ears or kernels of corn that can be prepared and served for human food (only mark individual ears or kernels)", "synonyms": ["edible_corn", "corn", "maize"], "image_count": 133, "id": 309, "frequency": "f", "synset": "corn.n.03"}, {"name": "cornbread", "instance_count": 10, "def": "bread made primarily of cornmeal", "synonyms": ["cornbread"], "image_count": 2, "id": 310, "frequency": "r", "synset": "cornbread.n.01"}, {"name": "cornet", "instance_count": 65, "def": "a brass musical instrument with a narrow tube and a flared bell and many valves", "synonyms": ["cornet", "horn", "trumpet"], "image_count": 38, "id": 311, "frequency": "c", "synset": "cornet.n.01"}, {"name": "cornice", "instance_count": 149, "def": "a decorative framework to conceal curtain fixtures at the top of a window casing", "synonyms": ["cornice", "valance", "valance_board", "pelmet"], "image_count": 95, "id": 312, "frequency": "c", "synset": "cornice.n.01"}, {"name": "cornmeal", "instance_count": 1, "def": "coarsely ground corn", "synonyms": ["cornmeal"], "image_count": 1, "id": 313, "frequency": "r", "synset": "cornmeal.n.01"}, {"name": "corset", "instance_count": 12, "def": "a woman's close-fitting foundation garment", "synonyms": ["corset", "girdle"], "image_count": 12, "id": 314, "frequency": "c", "synset": "corset.n.01"}, {"name": "costume", "instance_count": 124, "def": "the attire characteristic of a country or a time or a social class", "synonyms": ["costume"], "image_count": 49, "id": 315, "frequency": "c", "synset": "costume.n.04"}, {"name": "cougar", "instance_count": 6, "def": "large American feline resembling a lion", "synonyms": ["cougar", "puma", "catamount", "mountain_lion", "panther"], "image_count": 5, "id": 316, "frequency": "r", "synset": "cougar.n.01"}, {"name": "coverall", "instance_count": 12, "def": "a loose-fitting protective garment that is worn over other clothing", "synonyms": ["coverall"], "image_count": 5, "id": 317, "frequency": "r", "synset": "coverall.n.01"}, {"name": "cowbell", "instance_count": 29, "def": "a bell hung around the neck of cow so that the cow can be easily located", "synonyms": ["cowbell"], "image_count": 16, "id": 318, "frequency": "c", "synset": "cowbell.n.01"}, {"name": "cowboy_hat", "instance_count": 535, "def": "a hat with a wide brim and a soft crown; worn by American ranch hands", "synonyms": ["cowboy_hat", "ten-gallon_hat"], "image_count": 216, "id": 319, "frequency": "f", "synset": "cowboy_hat.n.01"}, {"name": "crab_(animal)", "instance_count": 50, "def": "decapod having eyes on short stalks and a broad flattened shell and pincers", "synonyms": ["crab_(animal)"], "image_count": 12, "id": 320, "frequency": "c", "synset": "crab.n.01"}, {"name": "crabmeat", "instance_count": 5, "def": "the edible flesh of any of various crabs", "synonyms": ["crabmeat"], "image_count": 1, "id": 321, "frequency": "r", "synset": "crab.n.05"}, {"name": "cracker", "instance_count": 510, "def": "a thin crisp wafer", "synonyms": ["cracker"], "image_count": 54, "id": 322, "frequency": "c", "synset": "cracker.n.01"}, {"name": "crape", "instance_count": 12, "def": "small very thin pancake", "synonyms": ["crape", "crepe", "French_pancake"], "image_count": 5, "id": 323, "frequency": "r", "synset": "crape.n.01"}, {"name": "crate", "instance_count": 1832, "def": "a rugged box (usually made of wood); used for shipping", "synonyms": ["crate"], "image_count": 245, "id": 324, "frequency": "f", "synset": "crate.n.01"}, {"name": "crayon", "instance_count": 59, "def": "writing or drawing implement made of a colored stick of composition wax", "synonyms": ["crayon", "wax_crayon"], "image_count": 12, "id": 325, "frequency": "c", "synset": "crayon.n.01"}, {"name": "cream_pitcher", "instance_count": 10, "def": "a small pitcher for serving cream", "synonyms": ["cream_pitcher"], "image_count": 7, "id": 326, "frequency": "r", "synset": "cream_pitcher.n.01"}, {"name": "crescent_roll", "instance_count": 152, "def": "very rich flaky crescent-shaped roll", "synonyms": ["crescent_roll", "croissant"], "image_count": 35, "id": 327, "frequency": "c", "synset": "crescent_roll.n.01"}, {"name": "crib", "instance_count": 40, "def": "baby bed with high sides made of slats", "synonyms": ["crib", "cot"], "image_count": 36, "id": 328, "frequency": "c", "synset": "crib.n.01"}, {"name": "crock_pot", "instance_count": 128, "def": "an earthen jar (made of baked clay) or a modern electric crockpot", "synonyms": ["crock_pot", "earthenware_jar"], "image_count": 32, "id": 329, "frequency": "c", "synset": "crock.n.03"}, {"name": "crossbar", "instance_count": 6991, "def": "a horizontal bar that goes across something", "synonyms": ["crossbar"], "image_count": 1027, "id": 330, "frequency": "f", "synset": "crossbar.n.01"}, {"name": "crouton", "instance_count": 140, "def": "a small piece of toasted or fried bread; served in soup or salads", "synonyms": ["crouton"], "image_count": 10, "id": 331, "frequency": "r", "synset": "crouton.n.01"}, {"name": "crow", "instance_count": 24, "def": "black birds having a raucous call", "synonyms": ["crow"], "image_count": 12, "id": 332, "frequency": "c", "synset": "crow.n.01"}, {"name": "crowbar", "instance_count": 1, "def": "a heavy iron lever with one end forged into a wedge", "synonyms": ["crowbar", "wrecking_bar", "pry_bar"], "image_count": 1, "id": 333, "frequency": "r", "synset": "crowbar.n.01"}, {"name": "crown", "instance_count": 126, "def": "an ornamental jeweled headdress signifying sovereignty", "synonyms": ["crown"], "image_count": 67, "id": 334, "frequency": "c", "synset": "crown.n.04"}, {"name": "crucifix", "instance_count": 99, "def": "representation of the cross on which Jesus died", "synonyms": ["crucifix"], "image_count": 71, "id": 335, "frequency": "c", "synset": "crucifix.n.01"}, {"name": "cruise_ship", "instance_count": 35, "def": "a passenger ship used commercially for pleasure cruises", "synonyms": ["cruise_ship", "cruise_liner"], "image_count": 30, "id": 336, "frequency": "c", "synset": "cruise_ship.n.01"}, {"name": "police_cruiser", "instance_count": 86, "def": "a car in which policemen cruise the streets", "synonyms": ["police_cruiser", "patrol_car", "police_car", "squad_car"], "image_count": 48, "id": 337, "frequency": "c", "synset": "cruiser.n.01"}, {"name": "crumb", "instance_count": 3021, "def": "small piece of e.g. bread or cake", "synonyms": ["crumb"], "image_count": 249, "id": 338, "frequency": "f", "synset": "crumb.n.03"}, {"name": "crutch", "instance_count": 20, "def": "a wooden or metal staff that fits under the armpit and reaches to the ground", "synonyms": ["crutch"], "image_count": 13, "id": 339, "frequency": "c", "synset": "crutch.n.01"}, {"name": "cub_(animal)", "instance_count": 55, "def": "the young of certain carnivorous mammals such as the bear or wolf or lion", "synonyms": ["cub_(animal)"], "image_count": 29, "id": 340, "frequency": "c", "synset": "cub.n.03"}, {"name": "cube", "instance_count": 189, "def": "a block in the (approximate) shape of a cube", "synonyms": ["cube", "square_block"], "image_count": 14, "id": 341, "frequency": "c", "synset": "cube.n.05"}, {"name": "cucumber", "instance_count": 1533, "def": "cylindrical green fruit with thin green rind and white flesh eaten as a vegetable", "synonyms": ["cucumber", "cuke"], "image_count": 236, "id": 342, "frequency": "f", "synset": "cucumber.n.02"}, {"name": "cufflink", "instance_count": 17, "def": "jewelry consisting of linked buttons used to fasten the cuffs of a shirt", "synonyms": ["cufflink"], "image_count": 15, "id": 343, "frequency": "c", "synset": "cufflink.n.01"}, {"name": "cup", "instance_count": 4637, "def": "a small open container usually used for drinking; usually has a handle", "synonyms": ["cup"], "image_count": 1521, "id": 344, "frequency": "f", "synset": "cup.n.01"}, {"name": "trophy_cup", "instance_count": 80, "def": "a metal award or cup-shaped vessel with handles that is awarded as a trophy to a competition winner", "synonyms": ["trophy_cup"], "image_count": 25, "id": 345, "frequency": "c", "synset": "cup.n.08"}, {"name": "cupboard", "instance_count": 1623, "def": "a small room (or recess) or cabinet used for storage space", "synonyms": ["cupboard", "closet"], "image_count": 249, "id": 346, "frequency": "f", "synset": "cupboard.n.01"}, {"name": "cupcake", "instance_count": 1628, "def": "small cake baked in a muffin tin", "synonyms": ["cupcake"], "image_count": 139, "id": 347, "frequency": "f", "synset": "cupcake.n.01"}, {"name": "hair_curler", "instance_count": 20, "def": "a cylindrical tube around which the hair is wound to curl it", "synonyms": ["hair_curler", "hair_roller", "hair_crimper"], "image_count": 2, "id": 348, "frequency": "r", "synset": "curler.n.01"}, {"name": "curling_iron", "instance_count": 2, "def": "a cylindrical home appliance that heats hair that has been curled around it", "synonyms": ["curling_iron"], "image_count": 2, "id": 349, "frequency": "r", "synset": "curling_iron.n.01"}, {"name": "curtain", "instance_count": 4506, "def": "hanging cloth used as a blind (especially for a window)", "synonyms": ["curtain", "drapery"], "image_count": 1890, "id": 350, "frequency": "f", "synset": "curtain.n.01"}, {"name": "cushion", "instance_count": 7174, "def": "a soft bag filled with air or padding such as feathers or foam rubber", "synonyms": ["cushion"], "image_count": 1240, "id": 351, "frequency": "f", "synset": "cushion.n.03"}, {"name": "cylinder", "instance_count": 3, "def": "a cylindrical container", "synonyms": ["cylinder"], "image_count": 1, "id": 352, "frequency": "r", "synset": "cylinder.n.04"}, {"name": "cymbal", "instance_count": 24, "def": "a percussion instrument consisting of a concave brass disk", "synonyms": ["cymbal"], "image_count": 9, "id": 353, "frequency": "r", "synset": "cymbal.n.01"}, {"name": "dagger", "instance_count": 1, "def": "a short knife with a pointed blade used for piercing or stabbing", "synonyms": ["dagger"], "image_count": 1, "id": 354, "frequency": "r", "synset": "dagger.n.01"}, {"name": "dalmatian", "instance_count": 3, "def": "a large breed having a smooth white coat with black or brown spots", "synonyms": ["dalmatian"], "image_count": 3, "id": 355, "frequency": "r", "synset": "dalmatian.n.02"}, {"name": "dartboard", "instance_count": 11, "def": "a circular board of wood or cork used as the target in the game of darts", "synonyms": ["dartboard"], "image_count": 11, "id": 356, "frequency": "c", "synset": "dartboard.n.01"}, {"name": "date_(fruit)", "instance_count": 103, "def": "sweet edible fruit of the date palm with a single long woody seed", "synonyms": ["date_(fruit)"], "image_count": 4, "id": 357, "frequency": "r", "synset": "date.n.08"}, {"name": "deck_chair", "instance_count": 1787, "def": "a folding chair for use outdoors; a wooden frame supports a length of canvas", "synonyms": ["deck_chair", "beach_chair"], "image_count": 236, "id": 358, "frequency": "f", "synset": "deck_chair.n.01"}, {"name": "deer", "instance_count": 130, "def": "distinguished from Bovidae by the male's having solid deciduous antlers", "synonyms": ["deer", "cervid"], "image_count": 44, "id": 359, "frequency": "c", "synset": "deer.n.01"}, {"name": "dental_floss", "instance_count": 20, "def": "a soft thread for cleaning the spaces between the teeth", "synonyms": ["dental_floss", "floss"], "image_count": 19, "id": 360, "frequency": "c", "synset": "dental_floss.n.01"}, {"name": "desk", "instance_count": 1662, "def": "a piece of furniture with a writing surface and usually drawers or other compartments", "synonyms": ["desk"], "image_count": 1100, "id": 361, "frequency": "f", "synset": "desk.n.01"}, {"name": "detergent", "instance_count": 11, "def": "a surface-active chemical widely used in industry and laundering", "synonyms": ["detergent"], "image_count": 7, "id": 362, "frequency": "r", "synset": "detergent.n.01"}, {"name": "diaper", "instance_count": 89, "def": "garment consisting of a folded cloth drawn up between the legs and fastened at the waist", "synonyms": ["diaper"], "image_count": 69, "id": 363, "frequency": "c", "synset": "diaper.n.01"}, {"name": "diary", "instance_count": 2, "def": "yearly planner book", "synonyms": ["diary", "journal"], "image_count": 2, "id": 364, "frequency": "r", "synset": "diary.n.01"}, {"name": "die", "instance_count": 25, "def": "a small cube with 1 to 6 spots on the six faces; used in gambling", "synonyms": ["die", "dice"], "image_count": 8, "id": 365, "frequency": "r", "synset": "die.n.01"}, {"name": "dinghy", "instance_count": 15, "def": "a small boat of shallow draft with seats and oars with which it is propelled", "synonyms": ["dinghy", "dory", "rowboat"], "image_count": 5, "id": 366, "frequency": "r", "synset": "dinghy.n.01"}, {"name": "dining_table", "instance_count": 312, "def": "a table at which meals are served", "synonyms": ["dining_table"], "image_count": 227, "id": 367, "frequency": "f", "synset": "dining_table.n.01"}, {"name": "tux", "instance_count": 10, "def": "semiformal evening dress for men", "synonyms": ["tux", "tuxedo"], "image_count": 6, "id": 368, "frequency": "r", "synset": "dinner_jacket.n.01"}, {"name": "dish", "instance_count": 532, "def": "a piece of dishware normally used as a container for holding or serving food", "synonyms": ["dish"], "image_count": 106, "id": 369, "frequency": "f", "synset": "dish.n.01"}, {"name": "dish_antenna", "instance_count": 153, "def": "directional antenna consisting of a parabolic reflector", "synonyms": ["dish_antenna"], "image_count": 81, "id": 370, "frequency": "c", "synset": "dish.n.05"}, {"name": "dishrag", "instance_count": 32, "def": "a cloth for washing dishes or cleaning in general", "synonyms": ["dishrag", "dishcloth"], "image_count": 17, "id": 371, "frequency": "c", "synset": "dishrag.n.01"}, {"name": "dishtowel", "instance_count": 223, "def": "a towel for drying dishes", "synonyms": ["dishtowel", "tea_towel"], "image_count": 134, "id": 372, "frequency": "f", "synset": "dishtowel.n.01"}, {"name": "dishwasher", "instance_count": 317, "def": "a machine for washing dishes", "synonyms": ["dishwasher", "dishwashing_machine"], "image_count": 312, "id": 373, "frequency": "f", "synset": "dishwasher.n.01"}, {"name": "dishwasher_detergent", "instance_count": 9, "def": "dishsoap or dish detergent designed for use in dishwashers", "synonyms": ["dishwasher_detergent", "dishwashing_detergent", "dishwashing_liquid", "dishsoap"], "image_count": 8, "id": 374, "frequency": "r", "synset": "dishwasher_detergent.n.01"}, {"name": "dispenser", "instance_count": 610, "def": "a container so designed that the contents can be used in prescribed amounts", "synonyms": ["dispenser"], "image_count": 271, "id": 375, "frequency": "f", "synset": "dispenser.n.01"}, {"name": "diving_board", "instance_count": 2, "def": "a springboard from which swimmers can dive", "synonyms": ["diving_board"], "image_count": 2, "id": 376, "frequency": "r", "synset": "diving_board.n.01"}, {"name": "Dixie_cup", "instance_count": 352, "def": "a disposable cup made of paper; for holding drinks", "synonyms": ["Dixie_cup", "paper_cup"], "image_count": 103, "id": 377, "frequency": "f", "synset": "dixie_cup.n.01"}, {"name": "dog", "instance_count": 2684, "def": "a common domesticated dog", "synonyms": ["dog"], "image_count": 1938, "id": 378, "frequency": "f", "synset": "dog.n.01"}, {"name": "dog_collar", "instance_count": 733, "def": "a collar for a dog", "synonyms": ["dog_collar"], "image_count": 574, "id": 379, "frequency": "f", "synset": "dog_collar.n.01"}, {"name": "doll", "instance_count": 398, "def": "a toy replica of a HUMAN (NOT AN ANIMAL)", "synonyms": ["doll"], "image_count": 120, "id": 380, "frequency": "f", "synset": "doll.n.01"}, {"name": "dollar", "instance_count": 2, "def": "a piece of paper money worth one dollar", "synonyms": ["dollar", "dollar_bill", "one_dollar_bill"], "image_count": 2, "id": 381, "frequency": "r", "synset": "dollar.n.02"}, {"name": "dollhouse", "instance_count": 2, "def": "a house so small that it is likened to a child's plaything", "synonyms": ["dollhouse", "doll's_house"], "image_count": 2, "id": 382, "frequency": "r", "synset": "dollhouse.n.01"}, {"name": "dolphin", "instance_count": 38, "def": "any of various small toothed whales with a beaklike snout; larger than porpoises", "synonyms": ["dolphin"], "image_count": 13, "id": 383, "frequency": "c", "synset": "dolphin.n.02"}, {"name": "domestic_ass", "instance_count": 49, "def": "domestic beast of burden descended from the African wild ass; patient but stubborn", "synonyms": ["domestic_ass", "donkey"], "image_count": 29, "id": 384, "frequency": "c", "synset": "domestic_ass.n.01"}, {"name": "doorknob", "instance_count": 4072, "def": "a knob used to open a door (often called `doorhandle' in Great Britain)", "synonyms": ["doorknob", "doorhandle"], "image_count": 1710, "id": 385, "frequency": "f", "synset": "doorknob.n.01"}, {"name": "doormat", "instance_count": 78, "def": "a mat placed outside an exterior door for wiping the shoes before entering", "synonyms": ["doormat", "welcome_mat"], "image_count": 66, "id": 386, "frequency": "c", "synset": "doormat.n.02"}, {"name": "doughnut", "instance_count": 11911, "def": "a small ring-shaped friedcake", "synonyms": ["doughnut", "donut"], "image_count": 1008, "id": 387, "frequency": "f", "synset": "doughnut.n.02"}, {"name": "dove", "instance_count": 2, "def": "any of numerous small pigeons", "synonyms": ["dove"], "image_count": 1, "id": 388, "frequency": "r", "synset": "dove.n.01"}, {"name": "dragonfly", "instance_count": 8, "def": "slender-bodied non-stinging insect having iridescent wings that are outspread at rest", "synonyms": ["dragonfly"], "image_count": 3, "id": 389, "frequency": "r", "synset": "dragonfly.n.01"}, {"name": "drawer", "instance_count": 7927, "def": "a boxlike container in a piece of furniture; made so as to slide in and out", "synonyms": ["drawer"], "image_count": 1942, "id": 390, "frequency": "f", "synset": "drawer.n.01"}, {"name": "underdrawers", "instance_count": 23, "def": "underpants worn by men", "synonyms": ["underdrawers", "boxers", "boxershorts"], "image_count": 19, "id": 391, "frequency": "c", "synset": "drawers.n.01"}, {"name": "dress", "instance_count": 2842, "def": "a one-piece garment for a woman; has skirt and bodice", "synonyms": ["dress", "frock"], "image_count": 1488, "id": 392, "frequency": "f", "synset": "dress.n.01"}, {"name": "dress_hat", "instance_count": 76, "def": "a man's hat with a tall crown; usually covered with silk or with beaver fur", "synonyms": ["dress_hat", "high_hat", "opera_hat", "silk_hat", "top_hat"], "image_count": 46, "id": 393, "frequency": "c", "synset": "dress_hat.n.01"}, {"name": "dress_suit", "instance_count": 306, "def": "formalwear consisting of full evening dress for men", "synonyms": ["dress_suit"], "image_count": 106, "id": 394, "frequency": "f", "synset": "dress_suit.n.01"}, {"name": "dresser", "instance_count": 152, "def": "a cabinet with shelves", "synonyms": ["dresser"], "image_count": 115, "id": 395, "frequency": "f", "synset": "dresser.n.05"}, {"name": "drill", "instance_count": 24, "def": "a tool with a sharp rotating point for making holes in hard materials", "synonyms": ["drill"], "image_count": 19, "id": 396, "frequency": "c", "synset": "drill.n.01"}, {"name": "drone", "instance_count": 2, "def": "an aircraft without a pilot that is operated by remote control", "synonyms": ["drone"], "image_count": 2, "id": 397, "frequency": "r", "synset": "drone.n.04"}, {"name": "dropper", "instance_count": 1, "def": "pipet consisting of a small tube with a vacuum bulb at one end for drawing liquid in and releasing it a drop at a time", "synonyms": ["dropper", "eye_dropper"], "image_count": 1, "id": 398, "frequency": "r", "synset": "dropper.n.01"}, {"name": "drum_(musical_instrument)", "instance_count": 59, "def": "a musical percussion instrument; usually consists of a hollow cylinder with a membrane stretched across each end", "synonyms": ["drum_(musical_instrument)"], "image_count": 28, "id": 399, "frequency": "c", "synset": "drum.n.01"}, {"name": "drumstick", "instance_count": 25, "def": "a stick used for playing a drum", "synonyms": ["drumstick"], "image_count": 9, "id": 400, "frequency": "r", "synset": "drumstick.n.02"}, {"name": "duck", "instance_count": 1090, "def": "small web-footed broad-billed swimming bird", "synonyms": ["duck"], "image_count": 192, "id": 401, "frequency": "f", "synset": "duck.n.01"}, {"name": "duckling", "instance_count": 36, "def": "young duck", "synonyms": ["duckling"], "image_count": 12, "id": 402, "frequency": "c", "synset": "duckling.n.02"}, {"name": "duct_tape", "instance_count": 77, "def": "a wide silvery adhesive tape", "synonyms": ["duct_tape"], "image_count": 21, "id": 403, "frequency": "c", "synset": "duct_tape.n.01"}, {"name": "duffel_bag", "instance_count": 666, "def": "a large cylindrical bag of heavy cloth (does not include suitcases)", "synonyms": ["duffel_bag", "duffle_bag", "duffel", "duffle"], "image_count": 247, "id": 404, "frequency": "f", "synset": "duffel_bag.n.01"}, {"name": "dumbbell", "instance_count": 13, "def": "an exercising weight with two ball-like ends connected by a short handle", "synonyms": ["dumbbell"], "image_count": 6, "id": 405, "frequency": "r", "synset": "dumbbell.n.01"}, {"name": "dumpster", "instance_count": 95, "def": "a container designed to receive and transport and dump waste", "synonyms": ["dumpster"], "image_count": 64, "id": 406, "frequency": "c", "synset": "dumpster.n.01"}, {"name": "dustpan", "instance_count": 7, "def": "a short-handled receptacle into which dust can be swept", "synonyms": ["dustpan"], "image_count": 7, "id": 407, "frequency": "r", "synset": "dustpan.n.02"}, {"name": "eagle", "instance_count": 48, "def": "large birds of prey noted for their broad wings and strong soaring flight", "synonyms": ["eagle"], "image_count": 40, "id": 408, "frequency": "c", "synset": "eagle.n.01"}, {"name": "earphone", "instance_count": 767, "def": "device for listening to audio that is held over or inserted into the ear", "synonyms": ["earphone", "earpiece", "headphone"], "image_count": 542, "id": 409, "frequency": "f", "synset": "earphone.n.01"}, {"name": "earplug", "instance_count": 39, "def": "a soft plug that is inserted into the ear canal to block sound", "synonyms": ["earplug"], "image_count": 2, "id": 410, "frequency": "r", "synset": "earplug.n.01"}, {"name": "earring", "instance_count": 3070, "def": "jewelry to ornament the ear", "synonyms": ["earring"], "image_count": 1898, "id": 411, "frequency": "f", "synset": "earring.n.01"}, {"name": "easel", "instance_count": 43, "def": "an upright tripod for displaying something (usually an artist's canvas)", "synonyms": ["easel"], "image_count": 36, "id": 412, "frequency": "c", "synset": "easel.n.01"}, {"name": "eclair", "instance_count": 39, "def": "oblong cream puff", "synonyms": ["eclair"], "image_count": 4, "id": 413, "frequency": "r", "synset": "eclair.n.01"}, {"name": "eel", "instance_count": 1, "def": "an elongate fish with fatty flesh", "synonyms": ["eel"], "image_count": 1, "id": 414, "frequency": "r", "synset": "eel.n.01"}, {"name": "egg", "instance_count": 813, "def": "oval reproductive body of a fowl (especially a hen) used as food", "synonyms": ["egg", "eggs"], "image_count": 191, "id": 415, "frequency": "f", "synset": "egg.n.02"}, {"name": "egg_roll", "instance_count": 15, "def": "minced vegetables and meat wrapped in a pancake and fried", "synonyms": ["egg_roll", "spring_roll"], "image_count": 6, "id": 416, "frequency": "r", "synset": "egg_roll.n.01"}, {"name": "egg_yolk", "instance_count": 90, "def": "the yellow spherical part of an egg", "synonyms": ["egg_yolk", "yolk_(egg)"], "image_count": 41, "id": 417, "frequency": "c", "synset": "egg_yolk.n.01"}, {"name": "eggbeater", "instance_count": 52, "def": "a mixer for beating eggs or whipping cream", "synonyms": ["eggbeater", "eggwhisk"], "image_count": 39, "id": 418, "frequency": "c", "synset": "eggbeater.n.02"}, {"name": "eggplant", "instance_count": 337, "def": "egg-shaped vegetable having a shiny skin typically dark purple", "synonyms": ["eggplant", "aubergine"], "image_count": 46, "id": 419, "frequency": "c", "synset": "eggplant.n.01"}, {"name": "electric_chair", "instance_count": 1, "def": "a chair-shaped instrument of execution by electrocution", "synonyms": ["electric_chair"], "image_count": 1, "id": 420, "frequency": "r", "synset": "electric_chair.n.01"}, {"name": "refrigerator", "instance_count": 1702, "def": "a refrigerator in which the coolant is pumped around by an electric motor", "synonyms": ["refrigerator"], "image_count": 1451, "id": 421, "frequency": "f", "synset": "electric_refrigerator.n.01"}, {"name": "elephant", "instance_count": 5325, "def": "a common elephant", "synonyms": ["elephant"], "image_count": 1878, "id": 422, "frequency": "f", "synset": "elephant.n.01"}, {"name": "elk", "instance_count": 29, "def": "large northern deer with enormous flattened antlers in the male", "synonyms": ["elk", "moose"], "image_count": 11, "id": 423, "frequency": "c", "synset": "elk.n.01"}, {"name": "envelope", "instance_count": 210, "def": "a flat (usually rectangular) container for a letter, thin package, etc.", "synonyms": ["envelope"], "image_count": 82, "id": 424, "frequency": "c", "synset": "envelope.n.01"}, {"name": "eraser", "instance_count": 41, "def": "an implement used to erase something", "synonyms": ["eraser"], "image_count": 18, "id": 425, "frequency": "c", "synset": "eraser.n.01"}, {"name": "escargot", "instance_count": 5, "def": "edible snail usually served in the shell with a sauce of melted butter and garlic", "synonyms": ["escargot"], "image_count": 1, "id": 426, "frequency": "r", "synset": "escargot.n.01"}, {"name": "eyepatch", "instance_count": 9, "def": "a protective cloth covering for an injured eye", "synonyms": ["eyepatch"], "image_count": 7, "id": 427, "frequency": "r", "synset": "eyepatch.n.01"}, {"name": "falcon", "instance_count": 3, "def": "birds of prey having long pointed powerful wings adapted for swift flight", "synonyms": ["falcon"], "image_count": 3, "id": 428, "frequency": "r", "synset": "falcon.n.01"}, {"name": "fan", "instance_count": 737, "def": "a device for creating a current of air by movement of a surface or surfaces", "synonyms": ["fan"], "image_count": 575, "id": 429, "frequency": "f", "synset": "fan.n.01"}, {"name": "faucet", "instance_count": 3185, "def": "a regulator for controlling the flow of a liquid from a reservoir", "synonyms": ["faucet", "spigot", "tap"], "image_count": 1907, "id": 430, "frequency": "f", "synset": "faucet.n.01"}, {"name": "fedora", "instance_count": 14, "def": "a hat made of felt with a creased crown", "synonyms": ["fedora"], "image_count": 8, "id": 431, "frequency": "r", "synset": "fedora.n.01"}, {"name": "ferret", "instance_count": 5, "def": "domesticated albino variety of the European polecat bred for hunting rats and rabbits", "synonyms": ["ferret"], "image_count": 4, "id": 432, "frequency": "r", "synset": "ferret.n.02"}, {"name": "Ferris_wheel", "instance_count": 32, "def": "a large wheel with suspended seats that remain upright as the wheel rotates", "synonyms": ["Ferris_wheel"], "image_count": 32, "id": 433, "frequency": "c", "synset": "ferris_wheel.n.01"}, {"name": "ferry", "instance_count": 17, "def": "a boat that transports people or vehicles across a body of water and operates on a regular schedule", "synonyms": ["ferry", "ferryboat"], "image_count": 11, "id": 434, "frequency": "c", "synset": "ferry.n.01"}, {"name": "fig_(fruit)", "instance_count": 147, "def": "fleshy sweet pear-shaped yellowish or purple fruit eaten fresh or preserved or dried", "synonyms": ["fig_(fruit)"], "image_count": 4, "id": 435, "frequency": "r", "synset": "fig.n.04"}, {"name": "fighter_jet", "instance_count": 115, "def": "a high-speed military or naval airplane designed to destroy enemy targets", "synonyms": ["fighter_jet", "fighter_aircraft", "attack_aircraft"], "image_count": 54, "id": 436, "frequency": "c", "synset": "fighter.n.02"}, {"name": "figurine", "instance_count": 1056, "def": "a small carved or molded figure", "synonyms": ["figurine"], "image_count": 202, "id": 437, "frequency": "f", "synset": "figurine.n.01"}, {"name": "file_cabinet", "instance_count": 53, "def": "office furniture consisting of a container for keeping papers in order", "synonyms": ["file_cabinet", "filing_cabinet"], "image_count": 32, "id": 438, "frequency": "c", "synset": "file.n.03"}, {"name": "file_(tool)", "instance_count": 3, "def": "a steel hand tool with small sharp teeth on some or all of its surfaces; used for smoothing wood or metal", "synonyms": ["file_(tool)"], "image_count": 3, "id": 439, "frequency": "r", "synset": "file.n.04"}, {"name": "fire_alarm", "instance_count": 151, "def": "an alarm that is tripped off by fire or smoke", "synonyms": ["fire_alarm", "smoke_alarm"], "image_count": 130, "id": 440, "frequency": "f", "synset": "fire_alarm.n.02"}, {"name": "fire_engine", "instance_count": 179, "def": "large trucks that carry firefighters and equipment to the site of a fire", "synonyms": ["fire_engine", "fire_truck"], "image_count": 119, "id": 441, "frequency": "f", "synset": "fire_engine.n.01"}, {"name": "fire_extinguisher", "instance_count": 165, "def": "a manually operated device for extinguishing small fires", "synonyms": ["fire_extinguisher", "extinguisher"], "image_count": 141, "id": 442, "frequency": "f", "synset": "fire_extinguisher.n.01"}, {"name": "fire_hose", "instance_count": 67, "def": "a large hose that carries water from a fire hydrant to the site of the fire", "synonyms": ["fire_hose"], "image_count": 29, "id": 443, "frequency": "c", "synset": "fire_hose.n.01"}, {"name": "fireplace", "instance_count": 530, "def": "an open recess in a wall at the base of a chimney where a fire can be built", "synonyms": ["fireplace"], "image_count": 525, "id": 444, "frequency": "f", "synset": "fireplace.n.01"}, {"name": "fireplug", "instance_count": 1458, "def": "an upright hydrant for drawing water to use in fighting a fire", "synonyms": ["fireplug", "fire_hydrant", "hydrant"], "image_count": 1323, "id": 445, "frequency": "f", "synset": "fireplug.n.01"}, {"name": "first-aid_kit", "instance_count": 2, "def": "kit consisting of a set of bandages and medicines for giving first aid", "synonyms": ["first-aid_kit"], "image_count": 2, "id": 446, "frequency": "r", "synset": "first-aid_kit.n.01"}, {"name": "fish", "instance_count": 525, "def": "any of various mostly cold-blooded aquatic vertebrates usually having scales and breathing through gills", "synonyms": ["fish"], "image_count": 113, "id": 447, "frequency": "f", "synset": "fish.n.01"}, {"name": "fish_(food)", "instance_count": 96, "def": "the flesh of fish used as food", "synonyms": ["fish_(food)"], "image_count": 16, "id": 448, "frequency": "c", "synset": "fish.n.02"}, {"name": "fishbowl", "instance_count": 33, "def": "a transparent bowl in which small fish are kept", "synonyms": ["fishbowl", "goldfish_bowl"], "image_count": 7, "id": 449, "frequency": "r", "synset": "fishbowl.n.02"}, {"name": "fishing_rod", "instance_count": 84, "def": "a rod that is used in fishing to extend the fishing line", "synonyms": ["fishing_rod", "fishing_pole"], "image_count": 35, "id": 450, "frequency": "c", "synset": "fishing_rod.n.01"}, {"name": "flag", "instance_count": 7007, "def": "emblem usually consisting of a rectangular piece of cloth of distinctive design (do not include pole)", "synonyms": ["flag"], "image_count": 1908, "id": 451, "frequency": "f", "synset": "flag.n.01"}, {"name": "flagpole", "instance_count": 1082, "def": "a tall staff or pole on which a flag is raised", "synonyms": ["flagpole", "flagstaff"], "image_count": 353, "id": 452, "frequency": "f", "synset": "flagpole.n.02"}, {"name": "flamingo", "instance_count": 309, "def": "large pink web-footed bird with down-bent bill", "synonyms": ["flamingo"], "image_count": 18, "id": 453, "frequency": "c", "synset": "flamingo.n.01"}, {"name": "flannel", "instance_count": 18, "def": "a soft light woolen fabric; used for clothing", "synonyms": ["flannel"], "image_count": 14, "id": 454, "frequency": "c", "synset": "flannel.n.01"}, {"name": "flap", "instance_count": 218, "def": "any broad thin covering attached at one edge, such as a mud flap next to a wheel or a flap on an airplane wing", "synonyms": ["flap"], "image_count": 77, "id": 455, "frequency": "c", "synset": "flap.n.01"}, {"name": "flash", "instance_count": 10, "def": "a lamp for providing momentary light to take a photograph", "synonyms": ["flash", "flashbulb"], "image_count": 8, "id": 456, "frequency": "r", "synset": "flash.n.10"}, {"name": "flashlight", "instance_count": 48, "def": "a small portable battery-powered electric lamp", "synonyms": ["flashlight", "torch"], "image_count": 37, "id": 457, "frequency": "c", "synset": "flashlight.n.01"}, {"name": "fleece", "instance_count": 2, "def": "a soft bulky fabric with deep pile; used chiefly for clothing", "synonyms": ["fleece"], "image_count": 1, "id": 458, "frequency": "r", "synset": "fleece.n.03"}, {"name": "flip-flop_(sandal)", "instance_count": 1103, "def": "a backless sandal held to the foot by a thong between two toes", "synonyms": ["flip-flop_(sandal)"], "image_count": 346, "id": 459, "frequency": "f", "synset": "flip-flop.n.02"}, {"name": "flipper_(footwear)", "instance_count": 49, "def": "a shoe to aid a person in swimming", "synonyms": ["flipper_(footwear)", "fin_(footwear)"], "image_count": 19, "id": 460, "frequency": "c", "synset": "flipper.n.01"}, {"name": "flower_arrangement", "instance_count": 3960, "def": "a decorative arrangement of flowers", "synonyms": ["flower_arrangement", "floral_arrangement"], "image_count": 1779, "id": 461, "frequency": "f", "synset": "flower_arrangement.n.01"}, {"name": "flute_glass", "instance_count": 86, "def": "a tall narrow wineglass", "synonyms": ["flute_glass", "champagne_flute"], "image_count": 23, "id": 462, "frequency": "c", "synset": "flute.n.02"}, {"name": "foal", "instance_count": 30, "def": "a young horse", "synonyms": ["foal"], "image_count": 25, "id": 463, "frequency": "c", "synset": "foal.n.01"}, {"name": "folding_chair", "instance_count": 303, "def": "a chair that can be folded flat for storage", "synonyms": ["folding_chair"], "image_count": 67, "id": 464, "frequency": "c", "synset": "folding_chair.n.01"}, {"name": "food_processor", "instance_count": 22, "def": "a kitchen appliance for shredding, blending, chopping, or slicing food", "synonyms": ["food_processor"], "image_count": 19, "id": 465, "frequency": "c", "synset": "food_processor.n.01"}, {"name": "football_(American)", "instance_count": 35, "def": "the inflated oblong ball used in playing American football", "synonyms": ["football_(American)"], "image_count": 28, "id": 466, "frequency": "c", "synset": "football.n.02"}, {"name": "football_helmet", "instance_count": 7, "def": "a padded helmet with a face mask to protect the head of football players", "synonyms": ["football_helmet"], "image_count": 4, "id": 467, "frequency": "r", "synset": "football_helmet.n.01"}, {"name": "footstool", "instance_count": 41, "def": "a low seat or a stool to rest the feet of a seated person", "synonyms": ["footstool", "footrest"], "image_count": 27, "id": 468, "frequency": "c", "synset": "footstool.n.01"}, {"name": "fork", "instance_count": 3137, "def": "cutlery used for serving and eating food", "synonyms": ["fork"], "image_count": 1861, "id": 469, "frequency": "f", "synset": "fork.n.01"}, {"name": "forklift", "instance_count": 14, "def": "an industrial vehicle with a power operated fork in front that can be inserted under loads to lift and move them", "synonyms": ["forklift"], "image_count": 11, "id": 470, "frequency": "c", "synset": "forklift.n.01"}, {"name": "freight_car", "instance_count": 121, "def": "a railway car that carries freight", "synonyms": ["freight_car"], "image_count": 13, "id": 471, "frequency": "c", "synset": "freight_car.n.01"}, {"name": "French_toast", "instance_count": 41, "def": "bread slice dipped in egg and milk and fried", "synonyms": ["French_toast"], "image_count": 13, "id": 472, "frequency": "c", "synset": "french_toast.n.01"}, {"name": "freshener", "instance_count": 39, "def": "anything that freshens air by removing or covering odor", "synonyms": ["freshener", "air_freshener"], "image_count": 32, "id": 473, "frequency": "c", "synset": "freshener.n.01"}, {"name": "frisbee", "instance_count": 2332, "def": "a light, plastic disk propelled with a flip of the wrist for recreation or competition", "synonyms": ["frisbee"], "image_count": 1767, "id": 474, "frequency": "f", "synset": "frisbee.n.01"}, {"name": "frog", "instance_count": 84, "def": "a tailless stout-bodied amphibians with long hind limbs for leaping", "synonyms": ["frog", "toad", "toad_frog"], "image_count": 42, "id": 475, "frequency": "c", "synset": "frog.n.01"}, {"name": "fruit_juice", "instance_count": 37, "def": "drink produced by squeezing or crushing fruit", "synonyms": ["fruit_juice"], "image_count": 17, "id": 476, "frequency": "c", "synset": "fruit_juice.n.01"}, {"name": "frying_pan", "instance_count": 310, "def": "a pan used for frying foods", "synonyms": ["frying_pan", "frypan", "skillet"], "image_count": 128, "id": 477, "frequency": "f", "synset": "frying_pan.n.01"}, {"name": "fudge", "instance_count": 4, "def": "soft creamy candy", "synonyms": ["fudge"], "image_count": 1, "id": 478, "frequency": "r", "synset": "fudge.n.01"}, {"name": "funnel", "instance_count": 9, "def": "a cone-shaped utensil used to channel a substance into a container with a small mouth", "synonyms": ["funnel"], "image_count": 9, "id": 479, "frequency": "r", "synset": "funnel.n.02"}, {"name": "futon", "instance_count": 11, "def": "a pad that is used for sleeping on the floor or on a raised frame", "synonyms": ["futon"], "image_count": 10, "id": 480, "frequency": "r", "synset": "futon.n.01"}, {"name": "gag", "instance_count": 4, "def": "restraint put into a person's mouth to prevent speaking or shouting", "synonyms": ["gag", "muzzle"], "image_count": 4, "id": 481, "frequency": "r", "synset": "gag.n.02"}, {"name": "garbage", "instance_count": 18, "def": "a receptacle where waste can be discarded", "synonyms": ["garbage"], "image_count": 9, "id": 482, "frequency": "r", "synset": "garbage.n.03"}, {"name": "garbage_truck", "instance_count": 18, "def": "a truck for collecting domestic refuse", "synonyms": ["garbage_truck"], "image_count": 18, "id": 483, "frequency": "c", "synset": "garbage_truck.n.01"}, {"name": "garden_hose", "instance_count": 50, "def": "a hose used for watering a lawn or garden", "synonyms": ["garden_hose"], "image_count": 41, "id": 484, "frequency": "c", "synset": "garden_hose.n.01"}, {"name": "gargle", "instance_count": 38, "def": "a medicated solution used for gargling and rinsing the mouth", "synonyms": ["gargle", "mouthwash"], "image_count": 28, "id": 485, "frequency": "c", "synset": "gargle.n.01"}, {"name": "gargoyle", "instance_count": 8, "def": "an ornament consisting of a grotesquely carved figure of a person or animal", "synonyms": ["gargoyle"], "image_count": 3, "id": 486, "frequency": "r", "synset": "gargoyle.n.02"}, {"name": "garlic", "instance_count": 487, "def": "aromatic bulb used as seasoning", "synonyms": ["garlic", "ail"], "image_count": 65, "id": 487, "frequency": "c", "synset": "garlic.n.02"}, {"name": "gasmask", "instance_count": 12, "def": "a protective face mask with a filter", "synonyms": ["gasmask", "respirator", "gas_helmet"], "image_count": 9, "id": 488, "frequency": "r", "synset": "gasmask.n.01"}, {"name": "gazelle", "instance_count": 82, "def": "small swift graceful antelope of Africa and Asia having lustrous eyes", "synonyms": ["gazelle"], "image_count": 23, "id": 489, "frequency": "c", "synset": "gazelle.n.01"}, {"name": "gelatin", "instance_count": 248, "def": "an edible jelly made with gelatin and used as a dessert or salad base or a coating for foods", "synonyms": ["gelatin", "jelly"], "image_count": 24, "id": 490, "frequency": "c", "synset": "gelatin.n.02"}, {"name": "gemstone", "instance_count": 2, "def": "a crystalline rock that can be cut and polished for jewelry", "synonyms": ["gemstone"], "image_count": 1, "id": 491, "frequency": "r", "synset": "gem.n.02"}, {"name": "generator", "instance_count": 2, "def": "engine that converts mechanical energy into electrical energy by electromagnetic induction", "synonyms": ["generator"], "image_count": 2, "id": 492, "frequency": "r", "synset": "generator.n.02"}, {"name": "giant_panda", "instance_count": 112, "def": "large black-and-white herbivorous mammal of bamboo forests of China and Tibet", "synonyms": ["giant_panda", "panda", "panda_bear"], "image_count": 59, "id": 493, "frequency": "c", "synset": "giant_panda.n.01"}, {"name": "gift_wrap", "instance_count": 247, "def": "attractive wrapping paper suitable for wrapping gifts", "synonyms": ["gift_wrap"], "image_count": 48, "id": 494, "frequency": "c", "synset": "gift_wrap.n.01"}, {"name": "ginger", "instance_count": 93, "def": "the root of the common ginger plant; used fresh as a seasoning", "synonyms": ["ginger", "gingerroot"], "image_count": 17, "id": 495, "frequency": "c", "synset": "ginger.n.03"}, {"name": "giraffe", "instance_count": 3923, "def": "tall animal having a spotted coat and small horns and very long neck and legs", "synonyms": ["giraffe"], "image_count": 1877, "id": 496, "frequency": "f", "synset": "giraffe.n.01"}, {"name": "cincture", "instance_count": 56, "def": "a band of material around the waist that strengthens a skirt or trousers", "synonyms": ["cincture", "sash", "waistband", "waistcloth"], "image_count": 18, "id": 497, "frequency": "c", "synset": "girdle.n.02"}, {"name": "glass_(drink_container)", "instance_count": 6420, "def": "a container for holding liquids while drinking", "synonyms": ["glass_(drink_container)", "drinking_glass"], "image_count": 1920, "id": 498, "frequency": "f", "synset": "glass.n.02"}, {"name": "globe", "instance_count": 59, "def": "a sphere on which a map (especially of the earth) is represented", "synonyms": ["globe"], "image_count": 50, "id": 499, "frequency": "c", "synset": "globe.n.03"}, {"name": "glove", "instance_count": 5951, "def": "handwear covering the hand", "synonyms": ["glove"], "image_count": 1890, "id": 500, "frequency": "f", "synset": "glove.n.02"}, {"name": "goat", "instance_count": 842, "def": "a common goat", "synonyms": ["goat"], "image_count": 99, "id": 501, "frequency": "c", "synset": "goat.n.01"}, {"name": "goggles", "instance_count": 3202, "def": "tight-fitting spectacles worn to protect the eyes", "synonyms": ["goggles"], "image_count": 1530, "id": 502, "frequency": "f", "synset": "goggles.n.01"}, {"name": "goldfish", "instance_count": 11, "def": "small golden or orange-red freshwater fishes used as pond or aquarium pets", "synonyms": ["goldfish"], "image_count": 3, "id": 503, "frequency": "r", "synset": "goldfish.n.01"}, {"name": "golf_club", "instance_count": 14, "def": "golf equipment used by a golfer to hit a golf ball", "synonyms": ["golf_club", "golf-club"], "image_count": 11, "id": 504, "frequency": "c", "synset": "golf_club.n.02"}, {"name": "golfcart", "instance_count": 25, "def": "a small motor vehicle in which golfers can ride between shots", "synonyms": ["golfcart"], "image_count": 19, "id": 505, "frequency": "c", "synset": "golfcart.n.01"}, {"name": "gondola_(boat)", "instance_count": 8, "def": "long narrow flat-bottomed boat propelled by sculling; traditionally used on canals of Venice", "synonyms": ["gondola_(boat)"], "image_count": 3, "id": 506, "frequency": "r", "synset": "gondola.n.02"}, {"name": "goose", "instance_count": 413, "def": "loud, web-footed long-necked aquatic birds usually larger than ducks", "synonyms": ["goose"], "image_count": 63, "id": 507, "frequency": "c", "synset": "goose.n.01"}, {"name": "gorilla", "instance_count": 10, "def": "largest ape", "synonyms": ["gorilla"], "image_count": 5, "id": 508, "frequency": "r", "synset": "gorilla.n.01"}, {"name": "gourd", "instance_count": 101, "def": "any of numerous inedible fruits with hard rinds", "synonyms": ["gourd"], "image_count": 6, "id": 509, "frequency": "r", "synset": "gourd.n.02"}, {"name": "grape", "instance_count": 6377, "def": "any of various juicy fruit with green or purple skins; grow in clusters", "synonyms": ["grape"], "image_count": 233, "id": 510, "frequency": "f", "synset": "grape.n.01"}, {"name": "grater", "instance_count": 64, "def": "utensil with sharp perforations for shredding foods (as vegetables or cheese)", "synonyms": ["grater"], "image_count": 54, "id": 511, "frequency": "c", "synset": "grater.n.01"}, {"name": "gravestone", "instance_count": 778, "def": "a stone that is used to mark a grave", "synonyms": ["gravestone", "headstone", "tombstone"], "image_count": 36, "id": 512, "frequency": "c", "synset": "gravestone.n.01"}, {"name": "gravy_boat", "instance_count": 10, "def": "a dish (often boat-shaped) for serving gravy or sauce", "synonyms": ["gravy_boat", "gravy_holder"], "image_count": 10, "id": 513, "frequency": "r", "synset": "gravy_boat.n.01"}, {"name": "green_bean", "instance_count": 2571, "def": "a common bean plant cultivated for its slender green edible pods", "synonyms": ["green_bean"], "image_count": 124, "id": 514, "frequency": "f", "synset": "green_bean.n.02"}, {"name": "green_onion", "instance_count": 1618, "def": "a young onion before the bulb has enlarged", "synonyms": ["green_onion", "spring_onion", "scallion"], "image_count": 101, "id": 515, "frequency": "f", "synset": "green_onion.n.01"}, {"name": "griddle", "instance_count": 4, "def": "cooking utensil consisting of a flat heated surface on which food is cooked", "synonyms": ["griddle"], "image_count": 3, "id": 516, "frequency": "r", "synset": "griddle.n.01"}, {"name": "grill", "instance_count": 747, "def": "a framework of metal bars used as a partition or a grate", "synonyms": ["grill", "grille", "grillwork", "radiator_grille"], "image_count": 363, "id": 517, "frequency": "f", "synset": "grill.n.02"}, {"name": "grits", "instance_count": 3, "def": "coarsely ground corn boiled as a breakfast dish", "synonyms": ["grits", "hominy_grits"], "image_count": 3, "id": 518, "frequency": "r", "synset": "grits.n.01"}, {"name": "grizzly", "instance_count": 44, "def": "powerful brownish-yellow bear of the uplands of western North America", "synonyms": ["grizzly", "grizzly_bear"], "image_count": 30, "id": 519, "frequency": "c", "synset": "grizzly.n.01"}, {"name": "grocery_bag", "instance_count": 46, "def": "a sack for holding customer's groceries", "synonyms": ["grocery_bag"], "image_count": 18, "id": 520, "frequency": "c", "synset": "grocery_bag.n.01"}, {"name": "guitar", "instance_count": 315, "def": "a stringed instrument usually having six strings; played by strumming or plucking", "synonyms": ["guitar"], "image_count": 199, "id": 521, "frequency": "f", "synset": "guitar.n.01"}, {"name": "gull", "instance_count": 1398, "def": "mostly white aquatic bird having long pointed wings and short legs", "synonyms": ["gull", "seagull"], "image_count": 97, "id": 522, "frequency": "c", "synset": "gull.n.02"}, {"name": "gun", "instance_count": 68, "def": "a weapon that discharges a bullet at high velocity from a metal tube", "synonyms": ["gun"], "image_count": 32, "id": 523, "frequency": "c", "synset": "gun.n.01"}, {"name": "hairbrush", "instance_count": 165, "def": "a brush used to groom a person's hair", "synonyms": ["hairbrush"], "image_count": 121, "id": 524, "frequency": "f", "synset": "hairbrush.n.01"}, {"name": "hairnet", "instance_count": 53, "def": "a small net that someone wears over their hair to keep it in place", "synonyms": ["hairnet"], "image_count": 16, "id": 525, "frequency": "c", "synset": "hairnet.n.01"}, {"name": "hairpin", "instance_count": 20, "def": "a double pronged pin used to hold women's hair in place", "synonyms": ["hairpin"], "image_count": 12, "id": 526, "frequency": "c", "synset": "hairpin.n.01"}, {"name": "halter_top", "instance_count": 3, "def": "a woman's top that fastens behind the back and neck leaving the back and arms uncovered", "synonyms": ["halter_top"], "image_count": 2, "id": 527, "frequency": "r", "synset": "halter.n.03"}, {"name": "ham", "instance_count": 1765, "def": "meat cut from the thigh of a hog (usually smoked)", "synonyms": ["ham", "jambon", "gammon"], "image_count": 214, "id": 528, "frequency": "f", "synset": "ham.n.01"}, {"name": "hamburger", "instance_count": 126, "def": "a sandwich consisting of a patty of minced beef served on a bun", "synonyms": ["hamburger", "beefburger", "burger"], "image_count": 48, "id": 529, "frequency": "c", "synset": "hamburger.n.01"}, {"name": "hammer", "instance_count": 41, "def": "a hand tool with a heavy head and a handle; used to deliver an impulsive force by striking", "synonyms": ["hammer"], "image_count": 26, "id": 530, "frequency": "c", "synset": "hammer.n.02"}, {"name": "hammock", "instance_count": 15, "def": "a hanging bed of canvas or rope netting (usually suspended between two trees)", "synonyms": ["hammock"], "image_count": 13, "id": 531, "frequency": "c", "synset": "hammock.n.02"}, {"name": "hamper", "instance_count": 5, "def": "a basket usually with a cover", "synonyms": ["hamper"], "image_count": 4, "id": 532, "frequency": "r", "synset": "hamper.n.02"}, {"name": "hamster", "instance_count": 12, "def": "short-tailed burrowing rodent with large cheek pouches", "synonyms": ["hamster"], "image_count": 11, "id": 533, "frequency": "c", "synset": "hamster.n.01"}, {"name": "hair_dryer", "instance_count": 144, "def": "a hand-held electric blower that can blow warm air onto the hair", "synonyms": ["hair_dryer"], "image_count": 123, "id": 534, "frequency": "f", "synset": "hand_blower.n.01"}, {"name": "hand_glass", "instance_count": 7, "def": "a mirror intended to be held in the hand", "synonyms": ["hand_glass", "hand_mirror"], "image_count": 7, "id": 535, "frequency": "r", "synset": "hand_glass.n.01"}, {"name": "hand_towel", "instance_count": 619, "def": "a small towel used to dry the hands or face", "synonyms": ["hand_towel", "face_towel"], "image_count": 200, "id": 536, "frequency": "f", "synset": "hand_towel.n.01"}, {"name": "handcart", "instance_count": 204, "def": "wheeled vehicle that can be pushed by a person", "synonyms": ["handcart", "pushcart", "hand_truck"], "image_count": 91, "id": 537, "frequency": "c", "synset": "handcart.n.01"}, {"name": "handcuff", "instance_count": 10, "def": "shackle that consists of a metal loop that can be locked around the wrist", "synonyms": ["handcuff"], "image_count": 9, "id": 538, "frequency": "r", "synset": "handcuff.n.01"}, {"name": "handkerchief", "instance_count": 86, "def": "a square piece of cloth used for wiping the eyes or nose or as a costume accessory", "synonyms": ["handkerchief"], "image_count": 72, "id": 539, "frequency": "c", "synset": "handkerchief.n.01"}, {"name": "handle", "instance_count": 8314, "def": "the appendage to an object that is designed to be held in order to use or move it", "synonyms": ["handle", "grip", "handgrip"], "image_count": 1886, "id": 540, "frequency": "f", "synset": "handle.n.01"}, {"name": "handsaw", "instance_count": 5, "def": "a saw used with one hand for cutting wood", "synonyms": ["handsaw", "carpenter's_saw"], "image_count": 4, "id": 541, "frequency": "r", "synset": "handsaw.n.01"}, {"name": "hardback_book", "instance_count": 2, "def": "a book with cardboard or cloth or leather covers", "synonyms": ["hardback_book", "hardcover_book"], "image_count": 1, "id": 542, "frequency": "r", "synset": "hardback.n.01"}, {"name": "harmonium", "instance_count": 2, "def": "a free-reed instrument in which air is forced through the reeds by bellows", "synonyms": ["harmonium", "organ_(musical_instrument)", "reed_organ_(musical_instrument)"], "image_count": 1, "id": 543, "frequency": "r", "synset": "harmonium.n.01"}, {"name": "hat", "instance_count": 7213, "def": "headwear that protects the head from bad weather, sun, or worn for fashion", "synonyms": ["hat"], "image_count": 1932, "id": 544, "frequency": "f", "synset": "hat.n.01"}, {"name": "hatbox", "instance_count": 7, "def": "a round piece of luggage for carrying hats", "synonyms": ["hatbox"], "image_count": 4, "id": 545, "frequency": "r", "synset": "hatbox.n.01"}, {"name": "veil", "instance_count": 57, "def": "a garment that covers the head OR face", "synonyms": ["veil"], "image_count": 56, "id": 546, "frequency": "c", "synset": "head_covering.n.01"}, {"name": "headband", "instance_count": 1114, "def": "a band worn around or over the head", "synonyms": ["headband"], "image_count": 854, "id": 547, "frequency": "f", "synset": "headband.n.01"}, {"name": "headboard", "instance_count": 850, "def": "a vertical board or panel forming the head of a bedstead", "synonyms": ["headboard"], "image_count": 755, "id": 548, "frequency": "f", "synset": "headboard.n.01"}, {"name": "headlight", "instance_count": 7326, "def": "a powerful light with reflector; attached to the front of an automobile or locomotive", "synonyms": ["headlight", "headlamp"], "image_count": 1843, "id": 549, "frequency": "f", "synset": "headlight.n.01"}, {"name": "headscarf", "instance_count": 235, "def": "a kerchief worn over the head and tied under the chin", "synonyms": ["headscarf"], "image_count": 96, "id": 550, "frequency": "c", "synset": "headscarf.n.01"}, {"name": "headset", "instance_count": 10, "def": "receiver consisting of a pair of headphones", "synonyms": ["headset"], "image_count": 7, "id": 551, "frequency": "r", "synset": "headset.n.01"}, {"name": "headstall_(for_horses)", "instance_count": 133, "def": "the band that is the part of a bridle that fits around a horse's head", "synonyms": ["headstall_(for_horses)", "headpiece_(for_horses)"], "image_count": 74, "id": 552, "frequency": "c", "synset": "headstall.n.01"}, {"name": "heart", "instance_count": 347, "def": "a muscular organ; its contractions move the blood through the body", "synonyms": ["heart"], "image_count": 66, "id": 553, "frequency": "c", "synset": "heart.n.02"}, {"name": "heater", "instance_count": 64, "def": "device that heats water or supplies warmth to a room", "synonyms": ["heater", "warmer"], "image_count": 57, "id": 554, "frequency": "c", "synset": "heater.n.01"}, {"name": "helicopter", "instance_count": 68, "def": "an aircraft without wings that obtains its lift from the rotation of overhead blades", "synonyms": ["helicopter"], "image_count": 44, "id": 555, "frequency": "c", "synset": "helicopter.n.01"}, {"name": "helmet", "instance_count": 4845, "def": "a protective headgear made of hard material to resist blows", "synonyms": ["helmet"], "image_count": 1905, "id": 556, "frequency": "f", "synset": "helmet.n.02"}, {"name": "heron", "instance_count": 6, "def": "grey or white wading bird with long neck and long legs and (usually) long bill", "synonyms": ["heron"], "image_count": 4, "id": 557, "frequency": "r", "synset": "heron.n.02"}, {"name": "highchair", "instance_count": 98, "def": "a chair for feeding a very young child", "synonyms": ["highchair", "feeding_chair"], "image_count": 90, "id": 558, "frequency": "c", "synset": "highchair.n.01"}, {"name": "hinge", "instance_count": 5283, "def": "a joint that holds two parts together so that one can swing relative to the other", "synonyms": ["hinge"], "image_count": 1635, "id": 559, "frequency": "f", "synset": "hinge.n.01"}, {"name": "hippopotamus", "instance_count": 24, "def": "massive thick-skinned animal living in or around rivers of tropical Africa", "synonyms": ["hippopotamus"], "image_count": 8, "id": 560, "frequency": "r", "synset": "hippopotamus.n.01"}, {"name": "hockey_stick", "instance_count": 15, "def": "sports implement consisting of a stick used by hockey players to move the puck", "synonyms": ["hockey_stick"], "image_count": 5, "id": 561, "frequency": "r", "synset": "hockey_stick.n.01"}, {"name": "hog", "instance_count": 73, "def": "domestic swine", "synonyms": ["hog", "pig"], "image_count": 50, "id": 562, "frequency": "c", "synset": "hog.n.03"}, {"name": "home_plate_(baseball)", "instance_count": 551, "def": "(baseball) a rubber slab where the batter stands; it must be touched by a base runner in order to score", "synonyms": ["home_plate_(baseball)", "home_base_(baseball)"], "image_count": 545, "id": 563, "frequency": "f", "synset": "home_plate.n.01"}, {"name": "honey", "instance_count": 90, "def": "a sweet yellow liquid produced by bees", "synonyms": ["honey"], "image_count": 20, "id": 564, "frequency": "c", "synset": "honey.n.01"}, {"name": "fume_hood", "instance_count": 208, "def": "metal covering leading to a vent that exhausts smoke or fumes", "synonyms": ["fume_hood", "exhaust_hood"], "image_count": 193, "id": 565, "frequency": "f", "synset": "hood.n.06"}, {"name": "hook", "instance_count": 1157, "def": "a curved or bent implement for suspending or pulling something", "synonyms": ["hook"], "image_count": 285, "id": 566, "frequency": "f", "synset": "hook.n.05"}, {"name": "hookah", "instance_count": 3, "def": "a tobacco pipe with a long flexible tube connected to a container where the smoke is cooled by passing through water", "synonyms": ["hookah", "narghile", "nargileh", "sheesha", "shisha", "water_pipe"], "image_count": 3, "id": 567, "frequency": "r", "synset": "hookah.n.01"}, {"name": "hornet", "instance_count": 1, "def": "large stinging wasp", "synonyms": ["hornet"], "image_count": 1, "id": 568, "frequency": "r", "synset": "hornet.n.01"}, {"name": "horse", "instance_count": 4744, "def": "a common horse", "synonyms": ["horse"], "image_count": 1904, "id": 569, "frequency": "f", "synset": "horse.n.01"}, {"name": "hose", "instance_count": 610, "def": "a flexible pipe for conveying a liquid or gas", "synonyms": ["hose", "hosepipe"], "image_count": 294, "id": 570, "frequency": "f", "synset": "hose.n.03"}, {"name": "hot-air_balloon", "instance_count": 4, "def": "balloon for travel through the air in a basket suspended below a large bag of heated air", "synonyms": ["hot-air_balloon"], "image_count": 3, "id": 571, "frequency": "r", "synset": "hot-air_balloon.n.01"}, {"name": "hotplate", "instance_count": 6, "def": "a portable electric appliance for heating or cooking or keeping food warm", "synonyms": ["hotplate"], "image_count": 5, "id": 572, "frequency": "r", "synset": "hot_plate.n.01"}, {"name": "hot_sauce", "instance_count": 70, "def": "a pungent peppery sauce", "synonyms": ["hot_sauce"], "image_count": 24, "id": 573, "frequency": "c", "synset": "hot_sauce.n.01"}, {"name": "hourglass", "instance_count": 2, "def": "a sandglass timer that runs for sixty minutes", "synonyms": ["hourglass"], "image_count": 2, "id": 574, "frequency": "r", "synset": "hourglass.n.01"}, {"name": "houseboat", "instance_count": 4, "def": "a barge that is designed and equipped for use as a dwelling", "synonyms": ["houseboat"], "image_count": 2, "id": 575, "frequency": "r", "synset": "houseboat.n.01"}, {"name": "hummingbird", "instance_count": 18, "def": "tiny American bird having brilliant iridescent plumage and long slender bills", "synonyms": ["hummingbird"], "image_count": 16, "id": 576, "frequency": "c", "synset": "hummingbird.n.01"}, {"name": "hummus", "instance_count": 9, "def": "a thick spread made from mashed chickpeas", "synonyms": ["hummus", "humus", "hommos", "hoummos", "humous"], "image_count": 8, "id": 577, "frequency": "r", "synset": "hummus.n.01"}, {"name": "polar_bear", "instance_count": 196, "def": "white bear of Arctic regions", "synonyms": ["polar_bear"], "image_count": 154, "id": 578, "frequency": "f", "synset": "ice_bear.n.01"}, {"name": "icecream", "instance_count": 180, "def": "frozen dessert containing cream and sugar and flavoring", "synonyms": ["icecream"], "image_count": 66, "id": 579, "frequency": "c", "synset": "ice_cream.n.01"}, {"name": "popsicle", "instance_count": 1, "def": "ice cream or water ice on a small wooden stick", "synonyms": ["popsicle"], "image_count": 1, "id": 580, "frequency": "r", "synset": "ice_lolly.n.01"}, {"name": "ice_maker", "instance_count": 26, "def": "an appliance included in some electric refrigerators for making ice cubes", "synonyms": ["ice_maker"], "image_count": 24, "id": 581, "frequency": "c", "synset": "ice_maker.n.01"}, {"name": "ice_pack", "instance_count": 4, "def": "a waterproof bag filled with ice: applied to the body (especially the head) to cool or reduce swelling", "synonyms": ["ice_pack", "ice_bag"], "image_count": 1, "id": 582, "frequency": "r", "synset": "ice_pack.n.01"}, {"name": "ice_skate", "instance_count": 14, "def": "skate consisting of a boot with a steel blade fitted to the sole", "synonyms": ["ice_skate"], "image_count": 4, "id": 583, "frequency": "r", "synset": "ice_skate.n.01"}, {"name": "igniter", "instance_count": 77, "def": "a substance or device used to start a fire", "synonyms": ["igniter", "ignitor", "lighter"], "image_count": 75, "id": 584, "frequency": "c", "synset": "igniter.n.01"}, {"name": "inhaler", "instance_count": 7, "def": "a dispenser that produces a chemical vapor to be inhaled through mouth or nose", "synonyms": ["inhaler", "inhalator"], "image_count": 6, "id": 585, "frequency": "r", "synset": "inhaler.n.01"}, {"name": "iPod", "instance_count": 172, "def": "a pocket-sized device used to play music files", "synonyms": ["iPod"], "image_count": 126, "id": 586, "frequency": "f", "synset": "ipod.n.01"}, {"name": "iron_(for_clothing)", "instance_count": 38, "def": "home appliance consisting of a flat metal base that is heated and used to smooth cloth", "synonyms": ["iron_(for_clothing)", "smoothing_iron_(for_clothing)"], "image_count": 24, "id": 587, "frequency": "c", "synset": "iron.n.04"}, {"name": "ironing_board", "instance_count": 24, "def": "narrow padded board on collapsible supports; used for ironing clothes", "synonyms": ["ironing_board"], "image_count": 22, "id": 588, "frequency": "c", "synset": "ironing_board.n.01"}, {"name": "jacket", "instance_count": 8013, "def": "a waist-length coat", "synonyms": ["jacket"], "image_count": 1872, "id": 589, "frequency": "f", "synset": "jacket.n.01"}, {"name": "jam", "instance_count": 29, "def": "preserve of crushed fruit", "synonyms": ["jam"], "image_count": 16, "id": 590, "frequency": "c", "synset": "jam.n.01"}, {"name": "jar", "instance_count": 2002, "def": "a vessel (usually cylindrical) with a wide mouth and without handles", "synonyms": ["jar"], "image_count": 423, "id": 591, "frequency": "f", "synset": "jar.n.01"}, {"name": "jean", "instance_count": 5421, "def": "(usually plural) close-fitting trousers of heavy denim for manual work or casual wear", "synonyms": ["jean", "blue_jean", "denim"], "image_count": 1927, "id": 592, "frequency": "f", "synset": "jean.n.01"}, {"name": "jeep", "instance_count": 55, "def": "a car suitable for traveling over rough terrain", "synonyms": ["jeep", "landrover"], "image_count": 38, "id": 593, "frequency": "c", "synset": "jeep.n.01"}, {"name": "jelly_bean", "instance_count": 116, "def": "sugar-glazed jellied candy", "synonyms": ["jelly_bean", "jelly_egg"], "image_count": 3, "id": 594, "frequency": "r", "synset": "jelly_bean.n.01"}, {"name": "jersey", "instance_count": 8117, "def": "a close-fitting pullover shirt", "synonyms": ["jersey", "T-shirt", "tee_shirt"], "image_count": 1945, "id": 595, "frequency": "f", "synset": "jersey.n.03"}, {"name": "jet_plane", "instance_count": 87, "def": "an airplane powered by one or more jet engines", "synonyms": ["jet_plane", "jet-propelled_plane"], "image_count": 35, "id": 596, "frequency": "c", "synset": "jet.n.01"}, {"name": "jewel", "instance_count": 1, "def": "a precious or semiprecious stone incorporated into a piece of jewelry", "synonyms": ["jewel", "gem", "precious_stone"], "image_count": 1, "id": 597, "frequency": "r", "synset": "jewel.n.01"}, {"name": "jewelry", "instance_count": 51, "def": "an adornment (as a bracelet or ring or necklace) made of precious metals and set with gems (or imitation gems)", "synonyms": ["jewelry", "jewellery"], "image_count": 13, "id": 598, "frequency": "c", "synset": "jewelry.n.01"}, {"name": "joystick", "instance_count": 12, "def": "a control device for computers consisting of a vertical handle that can move freely in two directions", "synonyms": ["joystick"], "image_count": 9, "id": 599, "frequency": "r", "synset": "joystick.n.02"}, {"name": "jumpsuit", "instance_count": 21, "def": "one-piece garment fashioned after a parachutist's uniform", "synonyms": ["jumpsuit"], "image_count": 14, "id": 600, "frequency": "c", "synset": "jump_suit.n.01"}, {"name": "kayak", "instance_count": 124, "def": "a small canoe consisting of a light frame made watertight with animal skins", "synonyms": ["kayak"], "image_count": 37, "id": 601, "frequency": "c", "synset": "kayak.n.01"}, {"name": "keg", "instance_count": 6, "def": "small cask or barrel", "synonyms": ["keg"], "image_count": 3, "id": 602, "frequency": "r", "synset": "keg.n.02"}, {"name": "kennel", "instance_count": 4, "def": "outbuilding that serves as a shelter for a dog", "synonyms": ["kennel", "doghouse"], "image_count": 4, "id": 603, "frequency": "r", "synset": "kennel.n.01"}, {"name": "kettle", "instance_count": 130, "def": "a metal pot for stewing or boiling; usually has a lid", "synonyms": ["kettle", "boiler"], "image_count": 100, "id": 604, "frequency": "c", "synset": "kettle.n.01"}, {"name": "key", "instance_count": 447, "def": "metal instrument used to unlock a lock", "synonyms": ["key"], "image_count": 195, "id": 605, "frequency": "f", "synset": "key.n.01"}, {"name": "keycard", "instance_count": 1, "def": "a plastic card used to gain access typically to a door", "synonyms": ["keycard"], "image_count": 1, "id": 606, "frequency": "r", "synset": "keycard.n.01"}, {"name": "kilt", "instance_count": 19, "def": "a knee-length pleated tartan skirt worn by men as part of the traditional dress in the Highlands of northern Scotland", "synonyms": ["kilt"], "image_count": 12, "id": 607, "frequency": "c", "synset": "kilt.n.01"}, {"name": "kimono", "instance_count": 38, "def": "a loose robe; imitated from robes originally worn by Japanese", "synonyms": ["kimono"], "image_count": 24, "id": 608, "frequency": "c", "synset": "kimono.n.01"}, {"name": "kitchen_sink", "instance_count": 519, "def": "a sink in a kitchen", "synonyms": ["kitchen_sink"], "image_count": 489, "id": 609, "frequency": "f", "synset": "kitchen_sink.n.01"}, {"name": "kitchen_table", "instance_count": 11, "def": "a table in the kitchen", "synonyms": ["kitchen_table"], "image_count": 10, "id": 610, "frequency": "r", "synset": "kitchen_table.n.01"}, {"name": "kite", "instance_count": 11174, "def": "plaything consisting of a light frame covered with tissue paper; flown in wind at end of a string", "synonyms": ["kite"], "image_count": 1689, "id": 611, "frequency": "f", "synset": "kite.n.03"}, {"name": "kitten", "instance_count": 60, "def": "young domestic cat", "synonyms": ["kitten", "kitty"], "image_count": 42, "id": 612, "frequency": "c", "synset": "kitten.n.01"}, {"name": "kiwi_fruit", "instance_count": 702, "def": "fuzzy brown egg-shaped fruit with slightly tart green flesh", "synonyms": ["kiwi_fruit"], "image_count": 81, "id": 613, "frequency": "c", "synset": "kiwi.n.03"}, {"name": "knee_pad", "instance_count": 1765, "def": "protective garment consisting of a pad worn by football or baseball or hockey players", "synonyms": ["knee_pad"], "image_count": 894, "id": 614, "frequency": "f", "synset": "knee_pad.n.01"}, {"name": "knife", "instance_count": 3515, "def": "tool with a blade and point used as a cutting instrument", "synonyms": ["knife"], "image_count": 1868, "id": 615, "frequency": "f", "synset": "knife.n.01"}, {"name": "knitting_needle", "instance_count": 16, "def": "needle consisting of a slender rod with pointed ends; usually used in pairs", "synonyms": ["knitting_needle"], "image_count": 7, "id": 616, "frequency": "r", "synset": "knitting_needle.n.01"}, {"name": "knob", "instance_count": 8432, "def": "a round handle often found on a door", "synonyms": ["knob"], "image_count": 1567, "id": 617, "frequency": "f", "synset": "knob.n.02"}, {"name": "knocker_(on_a_door)", "instance_count": 10, "def": "a device (usually metal and ornamental) attached by a hinge to a door", "synonyms": ["knocker_(on_a_door)", "doorknocker"], "image_count": 10, "id": 618, "frequency": "r", "synset": "knocker.n.05"}, {"name": "koala", "instance_count": 15, "def": "sluggish tailless Australian marsupial with grey furry ears and coat", "synonyms": ["koala", "koala_bear"], "image_count": 8, "id": 619, "frequency": "r", "synset": "koala.n.01"}, {"name": "lab_coat", "instance_count": 42, "def": "a light coat worn to protect clothing from substances used while working in a laboratory", "synonyms": ["lab_coat", "laboratory_coat"], "image_count": 7, "id": 620, "frequency": "r", "synset": "lab_coat.n.01"}, {"name": "ladder", "instance_count": 975, "def": "steps consisting of two parallel members connected by rungs", "synonyms": ["ladder"], "image_count": 629, "id": 621, "frequency": "f", "synset": "ladder.n.01"}, {"name": "ladle", "instance_count": 226, "def": "a spoon-shaped vessel with a long handle frequently used to transfer liquids", "synonyms": ["ladle"], "image_count": 89, "id": 622, "frequency": "c", "synset": "ladle.n.01"}, {"name": "ladybug", "instance_count": 68, "def": "small round bright-colored and spotted beetle, typically red and black", "synonyms": ["ladybug", "ladybeetle", "ladybird_beetle"], "image_count": 15, "id": 623, "frequency": "c", "synset": "ladybug.n.01"}, {"name": "lamb_(animal)", "instance_count": 618, "def": "young sheep", "synonyms": ["lamb_(animal)"], "image_count": 134, "id": 624, "frequency": "f", "synset": "lamb.n.01"}, {"name": "lamb-chop", "instance_count": 8, "def": "chop cut from a lamb", "synonyms": ["lamb-chop", "lambchop"], "image_count": 4, "id": 625, "frequency": "r", "synset": "lamb_chop.n.01"}, {"name": "lamp", "instance_count": 4139, "def": "a piece of furniture holding one or more electric light bulbs", "synonyms": ["lamp"], "image_count": 1802, "id": 626, "frequency": "f", "synset": "lamp.n.02"}, {"name": "lamppost", "instance_count": 2234, "def": "a metal post supporting an outdoor lamp (such as a streetlight)", "synonyms": ["lamppost"], "image_count": 595, "id": 627, "frequency": "f", "synset": "lamppost.n.01"}, {"name": "lampshade", "instance_count": 2475, "def": "a protective ornamental shade used to screen a light bulb from direct view", "synonyms": ["lampshade"], "image_count": 1210, "id": 628, "frequency": "f", "synset": "lampshade.n.01"}, {"name": "lantern", "instance_count": 364, "def": "light in a transparent protective case", "synonyms": ["lantern"], "image_count": 48, "id": 629, "frequency": "c", "synset": "lantern.n.01"}, {"name": "lanyard", "instance_count": 1065, "def": "a cord worn around the neck to hold a knife or whistle, etc.", "synonyms": ["lanyard", "laniard"], "image_count": 418, "id": 630, "frequency": "f", "synset": "lanyard.n.02"}, {"name": "laptop_computer", "instance_count": 2852, "def": "a portable computer small enough to use in your lap", "synonyms": ["laptop_computer", "notebook_computer"], "image_count": 1846, "id": 631, "frequency": "f", "synset": "laptop.n.01"}, {"name": "lasagna", "instance_count": 7, "def": "baked dish of layers of lasagna pasta with sauce and cheese and meat or vegetables", "synonyms": ["lasagna", "lasagne"], "image_count": 5, "id": 632, "frequency": "r", "synset": "lasagna.n.01"}, {"name": "latch", "instance_count": 702, "def": "a bar that can be lowered or slid into a groove to fasten a door or gate", "synonyms": ["latch"], "image_count": 221, "id": 633, "frequency": "f", "synset": "latch.n.02"}, {"name": "lawn_mower", "instance_count": 12, "def": "garden tool for mowing grass on lawns", "synonyms": ["lawn_mower"], "image_count": 10, "id": 634, "frequency": "r", "synset": "lawn_mower.n.01"}, {"name": "leather", "instance_count": 20, "def": "an animal skin made smooth and flexible by removing the hair and then tanning", "synonyms": ["leather"], "image_count": 7, "id": 635, "frequency": "r", "synset": "leather.n.01"}, {"name": "legging_(clothing)", "instance_count": 154, "def": "a garment covering the leg (usually extending from the knee to the ankle)", "synonyms": ["legging_(clothing)", "leging_(clothing)", "leg_covering"], "image_count": 76, "id": 636, "frequency": "c", "synset": "legging.n.01"}, {"name": "Lego", "instance_count": 331, "def": "a child's plastic construction set for making models from blocks", "synonyms": ["Lego", "Lego_set"], "image_count": 22, "id": 637, "frequency": "c", "synset": "lego.n.01"}, {"name": "legume", "instance_count": 333, "def": "the fruit or seed of bean or pea plants", "synonyms": ["legume"], "image_count": 10, "id": 638, "frequency": "r", "synset": "legume.n.02"}, {"name": "lemon", "instance_count": 2168, "def": "yellow oval fruit with juicy acidic flesh", "synonyms": ["lemon"], "image_count": 341, "id": 639, "frequency": "f", "synset": "lemon.n.01"}, {"name": "lemonade", "instance_count": 2, "def": "sweetened beverage of diluted lemon juice", "synonyms": ["lemonade"], "image_count": 1, "id": 640, "frequency": "r", "synset": "lemonade.n.01"}, {"name": "lettuce", "instance_count": 5500, "def": "leafy plant commonly eaten in salad or on sandwiches", "synonyms": ["lettuce"], "image_count": 705, "id": 641, "frequency": "f", "synset": "lettuce.n.02"}, {"name": "license_plate", "instance_count": 4392, "def": "a plate mounted on the front and back of car and bearing the car's registration number", "synonyms": ["license_plate", "numberplate"], "image_count": 1900, "id": 642, "frequency": "f", "synset": "license_plate.n.01"}, {"name": "life_buoy", "instance_count": 524, "def": "a ring-shaped life preserver used to prevent drowning (NOT a life-jacket or vest)", "synonyms": ["life_buoy", "lifesaver", "life_belt", "life_ring"], "image_count": 188, "id": 643, "frequency": "f", "synset": "life_buoy.n.01"}, {"name": "life_jacket", "instance_count": 689, "def": "life preserver consisting of a sleeveless jacket of buoyant or inflatable design", "synonyms": ["life_jacket", "life_vest"], "image_count": 227, "id": 644, "frequency": "f", "synset": "life_jacket.n.01"}, {"name": "lightbulb", "instance_count": 7075, "def": "lightblub/source of light", "synonyms": ["lightbulb"], "image_count": 861, "id": 645, "frequency": "f", "synset": "light_bulb.n.01"}, {"name": "lightning_rod", "instance_count": 6, "def": "a metallic conductor that is attached to a high point and leads to the ground", "synonyms": ["lightning_rod", "lightning_conductor"], "image_count": 6, "id": 646, "frequency": "r", "synset": "lightning_rod.n.02"}, {"name": "lime", "instance_count": 1134, "def": "the green acidic fruit of any of various lime trees", "synonyms": ["lime"], "image_count": 115, "id": 647, "frequency": "f", "synset": "lime.n.06"}, {"name": "limousine", "instance_count": 6, "def": "long luxurious car; usually driven by a chauffeur", "synonyms": ["limousine"], "image_count": 5, "id": 648, "frequency": "r", "synset": "limousine.n.01"}, {"name": "lion", "instance_count": 69, "def": "large gregarious predatory cat of Africa and India", "synonyms": ["lion"], "image_count": 43, "id": 649, "frequency": "c", "synset": "lion.n.01"}, {"name": "lip_balm", "instance_count": 29, "def": "a balm applied to the lips", "synonyms": ["lip_balm"], "image_count": 14, "id": 650, "frequency": "c", "synset": "lip_balm.n.01"}, {"name": "liquor", "instance_count": 66, "def": "liquor or beer", "synonyms": ["liquor", "spirits", "hard_liquor", "liqueur", "cordial"], "image_count": 6, "id": 651, "frequency": "r", "synset": "liquor.n.01"}, {"name": "lizard", "instance_count": 22, "def": "a reptile with usually two pairs of legs and a tapering tail", "synonyms": ["lizard"], "image_count": 15, "id": 652, "frequency": "c", "synset": "lizard.n.01"}, {"name": "log", "instance_count": 7363, "def": "a segment of the trunk of a tree when stripped of branches", "synonyms": ["log"], "image_count": 1167, "id": 653, "frequency": "f", "synset": "log.n.01"}, {"name": "lollipop", "instance_count": 59, "def": "hard candy on a stick", "synonyms": ["lollipop"], "image_count": 15, "id": 654, "frequency": "c", "synset": "lollipop.n.02"}, {"name": "speaker_(stero_equipment)", "instance_count": 2029, "def": "electronic device that produces sound often as part of a stereo system", "synonyms": ["speaker_(stero_equipment)"], "image_count": 994, "id": 655, "frequency": "f", "synset": "loudspeaker.n.01"}, {"name": "loveseat", "instance_count": 41, "def": "small sofa that seats two people", "synonyms": ["loveseat"], "image_count": 28, "id": 656, "frequency": "c", "synset": "love_seat.n.01"}, {"name": "machine_gun", "instance_count": 5, "def": "a rapidly firing automatic gun", "synonyms": ["machine_gun"], "image_count": 2, "id": 657, "frequency": "r", "synset": "machine_gun.n.01"}, {"name": "magazine", "instance_count": 1379, "def": "a paperback periodic publication", "synonyms": ["magazine"], "image_count": 338, "id": 658, "frequency": "f", "synset": "magazine.n.02"}, {"name": "magnet", "instance_count": 5638, "def": "a device that attracts iron and produces a magnetic field", "synonyms": ["magnet"], "image_count": 334, "id": 659, "frequency": "f", "synset": "magnet.n.01"}, {"name": "mail_slot", "instance_count": 16, "def": "a slot (usually in a door) through which mail can be delivered", "synonyms": ["mail_slot"], "image_count": 15, "id": 660, "frequency": "c", "synset": "mail_slot.n.01"}, {"name": "mailbox_(at_home)", "instance_count": 240, "def": "a private box for delivery of mail", "synonyms": ["mailbox_(at_home)", "letter_box_(at_home)"], "image_count": 102, "id": 661, "frequency": "f", "synset": "mailbox.n.01"}, {"name": "mallard", "instance_count": 2, "def": "wild dabbling duck from which domestic ducks are descended", "synonyms": ["mallard"], "image_count": 1, "id": 662, "frequency": "r", "synset": "mallard.n.01"}, {"name": "mallet", "instance_count": 16, "def": "a sports implement with a long handle and a hammer-like head used to hit a ball", "synonyms": ["mallet"], "image_count": 8, "id": 663, "frequency": "r", "synset": "mallet.n.01"}, {"name": "mammoth", "instance_count": 2, "def": "any of numerous extinct elephants widely distributed in the Pleistocene", "synonyms": ["mammoth"], "image_count": 1, "id": 664, "frequency": "r", "synset": "mammoth.n.01"}, {"name": "manatee", "instance_count": 1, "def": "sirenian mammal of tropical coastal waters of America", "synonyms": ["manatee"], "image_count": 1, "id": 665, "frequency": "r", "synset": "manatee.n.01"}, {"name": "mandarin_orange", "instance_count": 401, "def": "a somewhat flat reddish-orange loose skinned citrus of China", "synonyms": ["mandarin_orange"], "image_count": 28, "id": 666, "frequency": "c", "synset": "mandarin.n.05"}, {"name": "manger", "instance_count": 126, "def": "a container (usually in a barn or stable) from which cattle or horses feed", "synonyms": ["manger", "trough"], "image_count": 91, "id": 667, "frequency": "c", "synset": "manger.n.01"}, {"name": "manhole", "instance_count": 445, "def": "a hole (usually with a flush cover) through which a person can gain access to an underground structure", "synonyms": ["manhole"], "image_count": 260, "id": 668, "frequency": "f", "synset": "manhole.n.01"}, {"name": "map", "instance_count": 186, "def": "a diagrammatic representation of the earth's surface (or part of it)", "synonyms": ["map"], "image_count": 131, "id": 669, "frequency": "f", "synset": "map.n.01"}, {"name": "marker", "instance_count": 501, "def": "a writing implement for making a mark", "synonyms": ["marker"], "image_count": 128, "id": 670, "frequency": "f", "synset": "marker.n.03"}, {"name": "martini", "instance_count": 3, "def": "a cocktail made of gin (or vodka) with dry vermouth", "synonyms": ["martini"], "image_count": 3, "id": 671, "frequency": "r", "synset": "martini.n.01"}, {"name": "mascot", "instance_count": 10, "def": "a person or animal that is adopted by a team or other group as a symbolic figure", "synonyms": ["mascot"], "image_count": 10, "id": 672, "frequency": "r", "synset": "mascot.n.01"}, {"name": "mashed_potato", "instance_count": 58, "def": "potato that has been peeled and boiled and then mashed", "synonyms": ["mashed_potato"], "image_count": 39, "id": 673, "frequency": "c", "synset": "mashed_potato.n.01"}, {"name": "masher", "instance_count": 2, "def": "a kitchen utensil used for mashing (e.g. potatoes)", "synonyms": ["masher"], "image_count": 2, "id": 674, "frequency": "r", "synset": "masher.n.02"}, {"name": "mask", "instance_count": 1595, "def": "a protective covering worn over the face", "synonyms": ["mask", "facemask"], "image_count": 925, "id": 675, "frequency": "f", "synset": "mask.n.04"}, {"name": "mast", "instance_count": 2985, "def": "a vertical spar for supporting sails", "synonyms": ["mast"], "image_count": 354, "id": 676, "frequency": "f", "synset": "mast.n.01"}, {"name": "mat_(gym_equipment)", "instance_count": 114, "def": "sports equipment consisting of a piece of thick padding on the floor for gymnastics", "synonyms": ["mat_(gym_equipment)", "gym_mat"], "image_count": 31, "id": 677, "frequency": "c", "synset": "mat.n.03"}, {"name": "matchbox", "instance_count": 11, "def": "a box for holding matches", "synonyms": ["matchbox"], "image_count": 10, "id": 678, "frequency": "r", "synset": "matchbox.n.01"}, {"name": "mattress", "instance_count": 354, "def": "a thick pad filled with resilient material used as a bed or part of a bed", "synonyms": ["mattress"], "image_count": 215, "id": 679, "frequency": "f", "synset": "mattress.n.01"}, {"name": "measuring_cup", "instance_count": 139, "def": "graduated cup used to measure liquid or granular ingredients", "synonyms": ["measuring_cup"], "image_count": 71, "id": 680, "frequency": "c", "synset": "measuring_cup.n.01"}, {"name": "measuring_stick", "instance_count": 57, "def": "measuring instrument having a sequence of marks at regular intervals", "synonyms": ["measuring_stick", "ruler_(measuring_stick)", "measuring_rod"], "image_count": 43, "id": 681, "frequency": "c", "synset": "measuring_stick.n.01"}, {"name": "meatball", "instance_count": 174, "def": "ground meat formed into a ball and fried or simmered in broth", "synonyms": ["meatball"], "image_count": 28, "id": 682, "frequency": "c", "synset": "meatball.n.01"}, {"name": "medicine", "instance_count": 243, "def": "something that treats or prevents or alleviates the symptoms of disease", "synonyms": ["medicine"], "image_count": 34, "id": 683, "frequency": "c", "synset": "medicine.n.02"}, {"name": "melon", "instance_count": 167, "def": "fruit of the gourd family having a hard rind and sweet juicy flesh", "synonyms": ["melon"], "image_count": 16, "id": 684, "frequency": "c", "synset": "melon.n.01"}, {"name": "microphone", "instance_count": 435, "def": "device for converting sound waves into electrical energy", "synonyms": ["microphone"], "image_count": 273, "id": 685, "frequency": "f", "synset": "microphone.n.01"}, {"name": "microscope", "instance_count": 3, "def": "magnifier of the image of small objects", "synonyms": ["microscope"], "image_count": 2, "id": 686, "frequency": "r", "synset": "microscope.n.01"}, {"name": "microwave_oven", "instance_count": 1105, "def": "kitchen appliance that cooks food by passing an electromagnetic wave through it", "synonyms": ["microwave_oven"], "image_count": 999, "id": 687, "frequency": "f", "synset": "microwave.n.02"}, {"name": "milestone", "instance_count": 5, "def": "stone post at side of a road to show distances", "synonyms": ["milestone", "milepost"], "image_count": 4, "id": 688, "frequency": "r", "synset": "milestone.n.01"}, {"name": "milk", "instance_count": 227, "def": "a white nutritious liquid secreted by mammals and used as food by human beings", "synonyms": ["milk"], "image_count": 107, "id": 689, "frequency": "f", "synset": "milk.n.01"}, {"name": "milk_can", "instance_count": 8, "def": "can for transporting milk", "synonyms": ["milk_can"], "image_count": 2, "id": 690, "frequency": "r", "synset": "milk_can.n.01"}, {"name": "milkshake", "instance_count": 1, "def": "frothy drink of milk and flavoring and sometimes fruit or ice cream", "synonyms": ["milkshake"], "image_count": 1, "id": 691, "frequency": "r", "synset": "milkshake.n.01"}, {"name": "minivan", "instance_count": 1046, "def": "a small box-shaped passenger van", "synonyms": ["minivan"], "image_count": 454, "id": 692, "frequency": "f", "synset": "minivan.n.01"}, {"name": "mint_candy", "instance_count": 27, "def": "a candy that is flavored with a mint oil", "synonyms": ["mint_candy"], "image_count": 9, "id": 693, "frequency": "r", "synset": "mint.n.05"}, {"name": "mirror", "instance_count": 3490, "def": "polished surface that forms images by reflecting light", "synonyms": ["mirror"], "image_count": 1901, "id": 694, "frequency": "f", "synset": "mirror.n.01"}, {"name": "mitten", "instance_count": 156, "def": "glove that encases the thumb separately and the other four fingers together", "synonyms": ["mitten"], "image_count": 61, "id": 695, "frequency": "c", "synset": "mitten.n.01"}, {"name": "mixer_(kitchen_tool)", "instance_count": 108, "def": "a kitchen utensil that is used for mixing foods", "synonyms": ["mixer_(kitchen_tool)", "stand_mixer"], "image_count": 91, "id": 696, "frequency": "c", "synset": "mixer.n.04"}, {"name": "money", "instance_count": 122, "def": "the official currency issued by a government or national bank", "synonyms": ["money"], "image_count": 46, "id": 697, "frequency": "c", "synset": "money.n.03"}, {"name": "monitor_(computer_equipment) computer_monitor", "instance_count": 2955, "def": "a computer monitor", "synonyms": ["monitor_(computer_equipment) computer_monitor"], "image_count": 1402, "id": 698, "frequency": "f", "synset": "monitor.n.04"}, {"name": "monkey", "instance_count": 166, "def": "any of various long-tailed primates", "synonyms": ["monkey"], "image_count": 74, "id": 699, "frequency": "c", "synset": "monkey.n.01"}, {"name": "motor", "instance_count": 985, "def": "machine that converts other forms of energy into mechanical energy and so imparts motion", "synonyms": ["motor"], "image_count": 421, "id": 700, "frequency": "f", "synset": "motor.n.01"}, {"name": "motor_scooter", "instance_count": 720, "def": "a wheeled vehicle with small wheels and a low-powered engine", "synonyms": ["motor_scooter", "scooter"], "image_count": 226, "id": 701, "frequency": "f", "synset": "motor_scooter.n.01"}, {"name": "motor_vehicle", "instance_count": 64, "def": "a self-propelled wheeled vehicle that does not run on rails", "synonyms": ["motor_vehicle", "automotive_vehicle"], "image_count": 10, "id": 702, "frequency": "r", "synset": "motor_vehicle.n.01"}, {"name": "motorcycle", "instance_count": 5247, "def": "a motor vehicle with two wheels and a strong frame", "synonyms": ["motorcycle"], "image_count": 1720, "id": 703, "frequency": "f", "synset": "motorcycle.n.01"}, {"name": "mound_(baseball)", "instance_count": 269, "def": "(baseball) the slight elevation on which the pitcher stands", "synonyms": ["mound_(baseball)", "pitcher's_mound"], "image_count": 261, "id": 704, "frequency": "f", "synset": "mound.n.01"}, {"name": "mouse_(computer_equipment)", "instance_count": 1832, "def": "a computer input device that controls an on-screen pointer (does not include trackpads / touchpads)", "synonyms": ["mouse_(computer_equipment)", "computer_mouse"], "image_count": 1337, "id": 705, "frequency": "f", "synset": "mouse.n.04"}, {"name": "mousepad", "instance_count": 333, "def": "a small portable pad that provides an operating surface for a computer mouse", "synonyms": ["mousepad"], "image_count": 293, "id": 706, "frequency": "f", "synset": "mousepad.n.01"}, {"name": "muffin", "instance_count": 352, "def": "a sweet quick bread baked in a cup-shaped pan", "synonyms": ["muffin"], "image_count": 62, "id": 707, "frequency": "c", "synset": "muffin.n.01"}, {"name": "mug", "instance_count": 1785, "def": "with handle and usually cylindrical", "synonyms": ["mug"], "image_count": 814, "id": 708, "frequency": "f", "synset": "mug.n.04"}, {"name": "mushroom", "instance_count": 6257, "def": "a common mushroom", "synonyms": ["mushroom"], "image_count": 407, "id": 709, "frequency": "f", "synset": "mushroom.n.02"}, {"name": "music_stool", "instance_count": 6, "def": "a stool for piano players; usually adjustable in height", "synonyms": ["music_stool", "piano_stool"], "image_count": 6, "id": 710, "frequency": "r", "synset": "music_stool.n.01"}, {"name": "musical_instrument", "instance_count": 33, "def": "any of various devices or contrivances that can be used to produce musical tones or sounds", "synonyms": ["musical_instrument", "instrument_(musical)"], "image_count": 16, "id": 711, "frequency": "c", "synset": "musical_instrument.n.01"}, {"name": "nailfile", "instance_count": 10, "def": "a small flat file for shaping the nails", "synonyms": ["nailfile"], "image_count": 7, "id": 712, "frequency": "r", "synset": "nailfile.n.01"}, {"name": "napkin", "instance_count": 3979, "def": "a small piece of table linen or paper that is used to wipe the mouth and to cover the lap in order to protect clothing", "synonyms": ["napkin", "table_napkin", "serviette"], "image_count": 1791, "id": 713, "frequency": "f", "synset": "napkin.n.01"}, {"name": "neckerchief", "instance_count": 4, "def": "a kerchief worn around the neck", "synonyms": ["neckerchief"], "image_count": 2, "id": 714, "frequency": "r", "synset": "neckerchief.n.01"}, {"name": "necklace", "instance_count": 2709, "def": "jewelry consisting of a cord or chain (often bearing gems) worn about the neck as an ornament", "synonyms": ["necklace"], "image_count": 1915, "id": 715, "frequency": "f", "synset": "necklace.n.01"}, {"name": "necktie", "instance_count": 4069, "def": "neckwear consisting of a long narrow piece of material worn under a collar and tied in knot at the front", "synonyms": ["necktie", "tie_(necktie)"], "image_count": 1940, "id": 716, "frequency": "f", "synset": "necktie.n.01"}, {"name": "needle", "instance_count": 61, "def": "a sharp pointed implement (usually metal)", "synonyms": ["needle"], "image_count": 13, "id": 717, "frequency": "c", "synset": "needle.n.03"}, {"name": "nest", "instance_count": 20, "def": "a structure in which animals lay eggs or give birth to their young", "synonyms": ["nest"], "image_count": 16, "id": 718, "frequency": "c", "synset": "nest.n.01"}, {"name": "newspaper", "instance_count": 1179, "def": "a daily or weekly publication on folded sheets containing news, articles, and advertisements", "synonyms": ["newspaper", "paper_(newspaper)"], "image_count": 448, "id": 719, "frequency": "f", "synset": "newspaper.n.01"}, {"name": "newsstand", "instance_count": 39, "def": "a stall where newspapers and other periodicals are sold", "synonyms": ["newsstand"], "image_count": 12, "id": 720, "frequency": "c", "synset": "newsstand.n.01"}, {"name": "nightshirt", "instance_count": 35, "def": "garments designed to be worn in bed", "synonyms": ["nightshirt", "nightwear", "sleepwear", "nightclothes"], "image_count": 18, "id": 721, "frequency": "c", "synset": "nightwear.n.01"}, {"name": "nosebag_(for_animals)", "instance_count": 4, "def": "a canvas bag that is used to feed an animal (such as a horse); covers the muzzle and fastens at the top of the head", "synonyms": ["nosebag_(for_animals)", "feedbag"], "image_count": 4, "id": 722, "frequency": "r", "synset": "nosebag.n.01"}, {"name": "noseband_(for_animals)", "instance_count": 120, "def": "a strap that is the part of a bridle that goes over the animal's nose", "synonyms": ["noseband_(for_animals)", "nosepiece_(for_animals)"], "image_count": 71, "id": 723, "frequency": "c", "synset": "noseband.n.01"}, {"name": "notebook", "instance_count": 290, "def": "a book with blank pages for recording notes or memoranda", "synonyms": ["notebook"], "image_count": 189, "id": 724, "frequency": "f", "synset": "notebook.n.01"}, {"name": "notepad", "instance_count": 187, "def": "a pad of paper for keeping notes", "synonyms": ["notepad"], "image_count": 74, "id": 725, "frequency": "c", "synset": "notepad.n.01"}, {"name": "nut", "instance_count": 790, "def": "a small metal block (usually square or hexagonal) with internal screw thread to be fitted onto a bolt", "synonyms": ["nut"], "image_count": 103, "id": 726, "frequency": "f", "synset": "nut.n.03"}, {"name": "nutcracker", "instance_count": 7, "def": "a hand tool used to crack nuts open", "synonyms": ["nutcracker"], "image_count": 3, "id": 727, "frequency": "r", "synset": "nutcracker.n.01"}, {"name": "oar", "instance_count": 488, "def": "an implement used to propel or steer a boat", "synonyms": ["oar"], "image_count": 110, "id": 728, "frequency": "f", "synset": "oar.n.01"}, {"name": "octopus_(food)", "instance_count": 5, "def": "tentacles of octopus prepared as food", "synonyms": ["octopus_(food)"], "image_count": 5, "id": 729, "frequency": "r", "synset": "octopus.n.01"}, {"name": "octopus_(animal)", "instance_count": 17, "def": "bottom-living cephalopod having a soft oval body with eight long tentacles", "synonyms": ["octopus_(animal)"], "image_count": 9, "id": 730, "frequency": "r", "synset": "octopus.n.02"}, {"name": "oil_lamp", "instance_count": 28, "def": "a lamp that burns oil (as kerosine) for light", "synonyms": ["oil_lamp", "kerosene_lamp", "kerosine_lamp"], "image_count": 15, "id": 731, "frequency": "c", "synset": "oil_lamp.n.01"}, {"name": "olive_oil", "instance_count": 36, "def": "oil from olives", "synonyms": ["olive_oil"], "image_count": 25, "id": 732, "frequency": "c", "synset": "olive_oil.n.01"}, {"name": "omelet", "instance_count": 10, "def": "beaten eggs cooked until just set; may be folded around e.g. ham or cheese or jelly", "synonyms": ["omelet", "omelette"], "image_count": 7, "id": 733, "frequency": "r", "synset": "omelet.n.01"}, {"name": "onion", "instance_count": 9779, "def": "the bulb of an onion plant", "synonyms": ["onion"], "image_count": 647, "id": 734, "frequency": "f", "synset": "onion.n.01"}, {"name": "orange_(fruit)", "instance_count": 13034, "def": "orange (FRUIT of an orange tree)", "synonyms": ["orange_(fruit)"], "image_count": 824, "id": 735, "frequency": "f", "synset": "orange.n.01"}, {"name": "orange_juice", "instance_count": 223, "def": "bottled or freshly squeezed juice of oranges", "synonyms": ["orange_juice"], "image_count": 100, "id": 736, "frequency": "c", "synset": "orange_juice.n.01"}, {"name": "ostrich", "instance_count": 71, "def": "fast-running African flightless bird with two-toed feet; largest living bird", "synonyms": ["ostrich"], "image_count": 47, "id": 737, "frequency": "c", "synset": "ostrich.n.02"}, {"name": "ottoman", "instance_count": 157, "def": "a thick standalone cushion used as a seat or footrest, often next to a chair", "synonyms": ["ottoman", "pouf", "pouffe", "hassock"], "image_count": 121, "id": 738, "frequency": "f", "synset": "ottoman.n.03"}, {"name": "oven", "instance_count": 929, "def": "kitchen appliance used for baking or roasting", "synonyms": ["oven"], "image_count": 731, "id": 739, "frequency": "f", "synset": "oven.n.01"}, {"name": "overalls_(clothing)", "instance_count": 76, "def": "work clothing consisting of denim trousers usually with a bib and shoulder straps", "synonyms": ["overalls_(clothing)"], "image_count": 73, "id": 740, "frequency": "c", "synset": "overall.n.01"}, {"name": "owl", "instance_count": 73, "def": "nocturnal bird of prey with hawk-like beak and claws and large head with front-facing eyes", "synonyms": ["owl"], "image_count": 49, "id": 741, "frequency": "c", "synset": "owl.n.01"}, {"name": "packet", "instance_count": 109, "def": "a small package or bundle", "synonyms": ["packet"], "image_count": 23, "id": 742, "frequency": "c", "synset": "packet.n.03"}, {"name": "inkpad", "instance_count": 12, "def": "absorbent material saturated with ink used to transfer ink evenly to a rubber stamp", "synonyms": ["inkpad", "inking_pad", "stamp_pad"], "image_count": 4, "id": 743, "frequency": "r", "synset": "pad.n.03"}, {"name": "pad", "instance_count": 264, "def": "mostly arm/knee pads labeled", "synonyms": ["pad"], "image_count": 62, "id": 744, "frequency": "c", "synset": "pad.n.04"}, {"name": "paddle", "instance_count": 306, "def": "a short light oar used without an oarlock to propel a canoe or small boat", "synonyms": ["paddle", "boat_paddle"], "image_count": 118, "id": 745, "frequency": "f", "synset": "paddle.n.04"}, {"name": "padlock", "instance_count": 184, "def": "a detachable, portable lock", "synonyms": ["padlock"], "image_count": 99, "id": 746, "frequency": "c", "synset": "padlock.n.01"}, {"name": "paintbrush", "instance_count": 91, "def": "a brush used as an applicator to apply paint", "synonyms": ["paintbrush"], "image_count": 40, "id": 747, "frequency": "c", "synset": "paintbrush.n.01"}, {"name": "painting", "instance_count": 2645, "def": "graphic art consisting of an artistic composition made by applying paints to a surface", "synonyms": ["painting"], "image_count": 1036, "id": 748, "frequency": "f", "synset": "painting.n.01"}, {"name": "pajamas", "instance_count": 163, "def": "loose-fitting nightclothes worn for sleeping or lounging", "synonyms": ["pajamas", "pyjamas"], "image_count": 105, "id": 749, "frequency": "f", "synset": "pajama.n.02"}, {"name": "palette", "instance_count": 68, "def": "board that provides a flat surface on which artists mix paints and the range of colors used", "synonyms": ["palette", "pallet"], "image_count": 21, "id": 750, "frequency": "c", "synset": "palette.n.02"}, {"name": "pan_(for_cooking)", "instance_count": 643, "def": "cooking utensil consisting of a wide metal vessel", "synonyms": ["pan_(for_cooking)", "cooking_pan"], "image_count": 229, "id": 751, "frequency": "f", "synset": "pan.n.01"}, {"name": "pan_(metal_container)", "instance_count": 21, "def": "shallow container made of metal", "synonyms": ["pan_(metal_container)"], "image_count": 7, "id": 752, "frequency": "r", "synset": "pan.n.03"}, {"name": "pancake", "instance_count": 295, "def": "a flat cake of thin batter fried on both sides on a griddle", "synonyms": ["pancake"], "image_count": 72, "id": 753, "frequency": "c", "synset": "pancake.n.01"}, {"name": "pantyhose", "instance_count": 11, "def": "a woman's tights consisting of underpants and stockings", "synonyms": ["pantyhose"], "image_count": 9, "id": 754, "frequency": "r", "synset": "pantyhose.n.01"}, {"name": "papaya", "instance_count": 206, "def": "large oval melon-like tropical fruit with yellowish flesh", "synonyms": ["papaya"], "image_count": 10, "id": 755, "frequency": "r", "synset": "papaya.n.02"}, {"name": "paper_plate", "instance_count": 957, "def": "a disposable plate made of cardboard", "synonyms": ["paper_plate"], "image_count": 328, "id": 756, "frequency": "f", "synset": "paper_plate.n.01"}, {"name": "paper_towel", "instance_count": 600, "def": "a disposable towel made of absorbent paper", "synonyms": ["paper_towel"], "image_count": 468, "id": 757, "frequency": "f", "synset": "paper_towel.n.01"}, {"name": "paperback_book", "instance_count": 3, "def": "a book with paper covers", "synonyms": ["paperback_book", "paper-back_book", "softback_book", "soft-cover_book"], "image_count": 1, "id": 758, "frequency": "r", "synset": "paperback_book.n.01"}, {"name": "paperweight", "instance_count": 4, "def": "a weight used to hold down a stack of papers", "synonyms": ["paperweight"], "image_count": 2, "id": 759, "frequency": "r", "synset": "paperweight.n.01"}, {"name": "parachute", "instance_count": 61, "def": "rescue equipment consisting of a device that fills with air and retards your fall", "synonyms": ["parachute"], "image_count": 24, "id": 760, "frequency": "c", "synset": "parachute.n.01"}, {"name": "parakeet", "instance_count": 46, "def": "any of numerous small slender long-tailed parrots", "synonyms": ["parakeet", "parrakeet", "parroket", "paraquet", "paroquet", "parroquet"], "image_count": 11, "id": 761, "frequency": "c", "synset": "parakeet.n.01"}, {"name": "parasail_(sports)", "instance_count": 385, "def": "parachute that will lift a person up into the air when it is towed by a motorboat or a car", "synonyms": ["parasail_(sports)"], "image_count": 72, "id": 762, "frequency": "c", "synset": "parasail.n.01"}, {"name": "parasol", "instance_count": 45, "def": "a handheld collapsible source of shade", "synonyms": ["parasol", "sunshade"], "image_count": 17, "id": 763, "frequency": "c", "synset": "parasol.n.01"}, {"name": "parchment", "instance_count": 17, "def": "a superior paper resembling sheepskin", "synonyms": ["parchment"], "image_count": 10, "id": 764, "frequency": "r", "synset": "parchment.n.01"}, {"name": "parka", "instance_count": 89, "def": "a kind of heavy jacket (`windcheater' is a British term)", "synonyms": ["parka", "anorak"], "image_count": 17, "id": 765, "frequency": "c", "synset": "parka.n.01"}, {"name": "parking_meter", "instance_count": 1075, "def": "a coin-operated timer located next to a parking space", "synonyms": ["parking_meter"], "image_count": 489, "id": 766, "frequency": "f", "synset": "parking_meter.n.01"}, {"name": "parrot", "instance_count": 76, "def": "usually brightly colored tropical birds with short hooked beaks and the ability to mimic sounds", "synonyms": ["parrot"], "image_count": 47, "id": 767, "frequency": "c", "synset": "parrot.n.01"}, {"name": "passenger_car_(part_of_a_train)", "instance_count": 465, "def": "a railcar where passengers ride", "synonyms": ["passenger_car_(part_of_a_train)", "coach_(part_of_a_train)"], "image_count": 93, "id": 768, "frequency": "c", "synset": "passenger_car.n.01"}, {"name": "passenger_ship", "instance_count": 1, "def": "a ship built to carry passengers", "synonyms": ["passenger_ship"], "image_count": 1, "id": 769, "frequency": "r", "synset": "passenger_ship.n.01"}, {"name": "passport", "instance_count": 12, "def": "a document issued by a country to a citizen allowing that person to travel abroad and re-enter the home country", "synonyms": ["passport"], "image_count": 12, "id": 770, "frequency": "c", "synset": "passport.n.02"}, {"name": "pastry", "instance_count": 4972, "def": "any of various baked foods made of dough or batter", "synonyms": ["pastry"], "image_count": 228, "id": 771, "frequency": "f", "synset": "pastry.n.02"}, {"name": "patty_(food)", "instance_count": 20, "def": "small flat mass of chopped food", "synonyms": ["patty_(food)"], "image_count": 5, "id": 772, "frequency": "r", "synset": "patty.n.01"}, {"name": "pea_(food)", "instance_count": 1869, "def": "seed of a pea plant used for food", "synonyms": ["pea_(food)"], "image_count": 76, "id": 773, "frequency": "c", "synset": "pea.n.01"}, {"name": "peach", "instance_count": 1041, "def": "downy juicy fruit with sweet yellowish or whitish flesh", "synonyms": ["peach"], "image_count": 71, "id": 774, "frequency": "c", "synset": "peach.n.03"}, {"name": "peanut_butter", "instance_count": 50, "def": "a spread made from ground peanuts", "synonyms": ["peanut_butter"], "image_count": 30, "id": 775, "frequency": "c", "synset": "peanut_butter.n.01"}, {"name": "pear", "instance_count": 1069, "def": "sweet juicy gritty-textured fruit available in many varieties", "synonyms": ["pear"], "image_count": 109, "id": 776, "frequency": "f", "synset": "pear.n.01"}, {"name": "peeler_(tool_for_fruit_and_vegetables)", "instance_count": 18, "def": "a device for peeling vegetables or fruits", "synonyms": ["peeler_(tool_for_fruit_and_vegetables)"], "image_count": 14, "id": 777, "frequency": "c", "synset": "peeler.n.03"}, {"name": "wooden_leg", "instance_count": 1, "def": "a prosthesis that replaces a missing leg", "synonyms": ["wooden_leg", "pegleg"], "image_count": 1, "id": 778, "frequency": "r", "synset": "peg.n.04"}, {"name": "pegboard", "instance_count": 9, "def": "a board perforated with regularly spaced holes into which pegs can be fitted", "synonyms": ["pegboard"], "image_count": 8, "id": 779, "frequency": "r", "synset": "pegboard.n.01"}, {"name": "pelican", "instance_count": 76, "def": "large long-winged warm-water seabird having a large bill with a distensible pouch for fish", "synonyms": ["pelican"], "image_count": 26, "id": 780, "frequency": "c", "synset": "pelican.n.01"}, {"name": "pen", "instance_count": 987, "def": "a writing implement with a point from which ink flows", "synonyms": ["pen"], "image_count": 339, "id": 781, "frequency": "f", "synset": "pen.n.01"}, {"name": "pencil", "instance_count": 543, "def": "a thin cylindrical pointed writing implement made of wood and graphite", "synonyms": ["pencil"], "image_count": 153, "id": 782, "frequency": "f", "synset": "pencil.n.01"}, {"name": "pencil_box", "instance_count": 2, "def": "a box for holding pencils", "synonyms": ["pencil_box", "pencil_case"], "image_count": 2, "id": 783, "frequency": "r", "synset": "pencil_box.n.01"}, {"name": "pencil_sharpener", "instance_count": 4, "def": "a rotary implement for sharpening the point on pencils", "synonyms": ["pencil_sharpener"], "image_count": 3, "id": 784, "frequency": "r", "synset": "pencil_sharpener.n.01"}, {"name": "pendulum", "instance_count": 18, "def": "an apparatus consisting of an object mounted so that it swings freely under the influence of gravity", "synonyms": ["pendulum"], "image_count": 8, "id": 785, "frequency": "r", "synset": "pendulum.n.01"}, {"name": "penguin", "instance_count": 229, "def": "short-legged flightless birds of cold southern regions having webbed feet and wings modified as flippers", "synonyms": ["penguin"], "image_count": 47, "id": 786, "frequency": "c", "synset": "penguin.n.01"}, {"name": "pennant", "instance_count": 235, "def": "a flag longer than it is wide (and often tapering)", "synonyms": ["pennant"], "image_count": 8, "id": 787, "frequency": "r", "synset": "pennant.n.02"}, {"name": "penny_(coin)", "instance_count": 15, "def": "a coin worth one-hundredth of the value of the basic unit", "synonyms": ["penny_(coin)"], "image_count": 6, "id": 788, "frequency": "r", "synset": "penny.n.02"}, {"name": "pepper", "instance_count": 697, "def": "pungent seasoning from the berry of the common pepper plant; whole or ground", "synonyms": ["pepper", "peppercorn"], "image_count": 116, "id": 789, "frequency": "f", "synset": "pepper.n.03"}, {"name": "pepper_mill", "instance_count": 91, "def": "a mill for grinding pepper", "synonyms": ["pepper_mill", "pepper_grinder"], "image_count": 69, "id": 790, "frequency": "c", "synset": "pepper_mill.n.01"}, {"name": "perfume", "instance_count": 28, "def": "a toiletry that emits and diffuses a fragrant odor", "synonyms": ["perfume"], "image_count": 13, "id": 791, "frequency": "c", "synset": "perfume.n.02"}, {"name": "persimmon", "instance_count": 22, "def": "orange fruit resembling a plum; edible when fully ripe", "synonyms": ["persimmon"], "image_count": 6, "id": 792, "frequency": "r", "synset": "persimmon.n.02"}, {"name": "person", "instance_count": 13439, "def": "a human being", "synonyms": ["person", "baby", "child", "boy", "girl", "man", "woman", "human"], "image_count": 1928, "id": 793, "frequency": "f", "synset": "person.n.01"}, {"name": "pet", "instance_count": 103, "def": "a domesticated animal kept for companionship or amusement", "synonyms": ["pet"], "image_count": 79, "id": 794, "frequency": "c", "synset": "pet.n.01"}, {"name": "pew_(church_bench)", "instance_count": 194, "def": "long bench with backs; used in church by the congregation", "synonyms": ["pew_(church_bench)", "church_bench"], "image_count": 14, "id": 795, "frequency": "c", "synset": "pew.n.01"}, {"name": "phonebook", "instance_count": 24, "def": "a directory containing an alphabetical list of telephone subscribers and their telephone numbers", "synonyms": ["phonebook", "telephone_book", "telephone_directory"], "image_count": 7, "id": 796, "frequency": "r", "synset": "phonebook.n.01"}, {"name": "phonograph_record", "instance_count": 138, "def": "sound recording consisting of a typically black disk with a continuous groove", "synonyms": ["phonograph_record", "phonograph_recording", "record_(phonograph_recording)"], "image_count": 20, "id": 797, "frequency": "c", "synset": "phonograph_record.n.01"}, {"name": "piano", "instance_count": 126, "def": "a keyboard instrument that is played by depressing keys that cause hammers to strike tuned strings and produce sounds", "synonyms": ["piano"], "image_count": 114, "id": 798, "frequency": "f", "synset": "piano.n.01"}, {"name": "pickle", "instance_count": 632, "def": "vegetables (especially cucumbers) preserved in brine or vinegar", "synonyms": ["pickle"], "image_count": 221, "id": 799, "frequency": "f", "synset": "pickle.n.01"}, {"name": "pickup_truck", "instance_count": 838, "def": "a light truck with an open body and low sides and a tailboard", "synonyms": ["pickup_truck"], "image_count": 502, "id": 800, "frequency": "f", "synset": "pickup.n.01"}, {"name": "pie", "instance_count": 228, "def": "dish baked in pastry-lined pan often with a pastry top", "synonyms": ["pie"], "image_count": 62, "id": 801, "frequency": "c", "synset": "pie.n.01"}, {"name": "pigeon", "instance_count": 1850, "def": "wild and domesticated birds having a heavy body and short legs", "synonyms": ["pigeon"], "image_count": 87, "id": 802, "frequency": "c", "synset": "pigeon.n.01"}, {"name": "piggy_bank", "instance_count": 5, "def": "a child's coin bank (often shaped like a pig)", "synonyms": ["piggy_bank", "penny_bank"], "image_count": 4, "id": 803, "frequency": "r", "synset": "piggy_bank.n.01"}, {"name": "pillow", "instance_count": 6115, "def": "a cushion to support the head of a sleeping person", "synonyms": ["pillow"], "image_count": 1912, "id": 804, "frequency": "f", "synset": "pillow.n.01"}, {"name": "pin_(non_jewelry)", "instance_count": 112, "def": "a small slender (often pointed) piece of wood or metal used to support or fasten or attach things", "synonyms": ["pin_(non_jewelry)"], "image_count": 7, "id": 805, "frequency": "r", "synset": "pin.n.09"}, {"name": "pineapple", "instance_count": 1636, "def": "large sweet fleshy tropical fruit with a tuft of stiff leaves", "synonyms": ["pineapple"], "image_count": 186, "id": 806, "frequency": "f", "synset": "pineapple.n.02"}, {"name": "pinecone", "instance_count": 141, "def": "the seed-producing cone of a pine tree", "synonyms": ["pinecone"], "image_count": 18, "id": 807, "frequency": "c", "synset": "pinecone.n.01"}, {"name": "ping-pong_ball", "instance_count": 4, "def": "light hollow ball used in playing table tennis", "synonyms": ["ping-pong_ball"], "image_count": 4, "id": 808, "frequency": "r", "synset": "ping-pong_ball.n.01"}, {"name": "pinwheel", "instance_count": 172, "def": "a toy consisting of vanes of colored paper or plastic that is pinned to a stick and spins when it is pointed into the wind", "synonyms": ["pinwheel"], "image_count": 3, "id": 809, "frequency": "r", "synset": "pinwheel.n.03"}, {"name": "tobacco_pipe", "instance_count": 7, "def": "a tube with a small bowl at one end; used for smoking tobacco", "synonyms": ["tobacco_pipe"], "image_count": 7, "id": 810, "frequency": "r", "synset": "pipe.n.01"}, {"name": "pipe", "instance_count": 4762, "def": "a long tube made of metal or plastic that is used to carry water or oil or gas etc.", "synonyms": ["pipe", "piping"], "image_count": 1413, "id": 811, "frequency": "f", "synset": "pipe.n.02"}, {"name": "pistol", "instance_count": 9, "def": "a firearm that is held and fired with one hand", "synonyms": ["pistol", "handgun"], "image_count": 7, "id": 812, "frequency": "r", "synset": "pistol.n.01"}, {"name": "pita_(bread)", "instance_count": 28, "def": "usually small round bread that can open into a pocket for filling", "synonyms": ["pita_(bread)", "pocket_bread"], "image_count": 12, "id": 813, "frequency": "c", "synset": "pita.n.01"}, {"name": "pitcher_(vessel_for_liquid)", "instance_count": 488, "def": "an open vessel with a handle and a spout for pouring", "synonyms": ["pitcher_(vessel_for_liquid)", "ewer"], "image_count": 248, "id": 814, "frequency": "f", "synset": "pitcher.n.02"}, {"name": "pitchfork", "instance_count": 4, "def": "a long-handled hand tool with sharp widely spaced prongs for lifting and pitching hay", "synonyms": ["pitchfork"], "image_count": 4, "id": 815, "frequency": "r", "synset": "pitchfork.n.01"}, {"name": "pizza", "instance_count": 4103, "def": "Italian open pie made of thin bread dough spread with a spiced mixture of e.g. tomato sauce and cheese", "synonyms": ["pizza"], "image_count": 1881, "id": 816, "frequency": "f", "synset": "pizza.n.01"}, {"name": "place_mat", "instance_count": 1123, "def": "a mat placed on a table for an individual place setting", "synonyms": ["place_mat"], "image_count": 529, "id": 817, "frequency": "f", "synset": "place_mat.n.01"}, {"name": "plate", "instance_count": 5214, "def": "dish on which food is served or from which food is eaten", "synonyms": ["plate"], "image_count": 1932, "id": 818, "frequency": "f", "synset": "plate.n.04"}, {"name": "platter", "instance_count": 148, "def": "a large shallow dish used for serving food", "synonyms": ["platter"], "image_count": 50, "id": 819, "frequency": "c", "synset": "platter.n.01"}, {"name": "playpen", "instance_count": 3, "def": "a portable enclosure in which babies may be left to play", "synonyms": ["playpen"], "image_count": 3, "id": 820, "frequency": "r", "synset": "playpen.n.01"}, {"name": "pliers", "instance_count": 49, "def": "a gripping hand tool with two hinged arms and (usually) serrated jaws", "synonyms": ["pliers", "plyers"], "image_count": 28, "id": 821, "frequency": "c", "synset": "pliers.n.01"}, {"name": "plow_(farm_equipment)", "instance_count": 12, "def": "a farm tool having one or more heavy blades to break the soil and cut a furrow prior to sowing", "synonyms": ["plow_(farm_equipment)", "plough_(farm_equipment)"], "image_count": 10, "id": 822, "frequency": "r", "synset": "plow.n.01"}, {"name": "plume", "instance_count": 11, "def": "a feather or cluster of feathers worn as an ornament", "synonyms": ["plume"], "image_count": 5, "id": 823, "frequency": "r", "synset": "plume.n.02"}, {"name": "pocket_watch", "instance_count": 20, "def": "a watch that is carried in a small watch pocket", "synonyms": ["pocket_watch"], "image_count": 5, "id": 824, "frequency": "r", "synset": "pocket_watch.n.01"}, {"name": "pocketknife", "instance_count": 21, "def": "a knife with a blade that folds into the handle; suitable for carrying in the pocket", "synonyms": ["pocketknife"], "image_count": 18, "id": 825, "frequency": "c", "synset": "pocketknife.n.01"}, {"name": "poker_(fire_stirring_tool)", "instance_count": 34, "def": "fire iron consisting of a metal rod with a handle; used to stir a fire", "synonyms": ["poker_(fire_stirring_tool)", "stove_poker", "fire_hook"], "image_count": 14, "id": 826, "frequency": "c", "synset": "poker.n.01"}, {"name": "pole", "instance_count": 14276, "def": "a long (usually round) rod of wood or metal or plastic", "synonyms": ["pole", "post"], "image_count": 1890, "id": 827, "frequency": "f", "synset": "pole.n.01"}, {"name": "polo_shirt", "instance_count": 1695, "def": "a shirt with short sleeves designed for comfort and casual wear", "synonyms": ["polo_shirt", "sport_shirt"], "image_count": 660, "id": 828, "frequency": "f", "synset": "polo_shirt.n.01"}, {"name": "poncho", "instance_count": 14, "def": "a blanket-like cloak with a hole in the center for the head", "synonyms": ["poncho"], "image_count": 8, "id": 829, "frequency": "r", "synset": "poncho.n.01"}, {"name": "pony", "instance_count": 57, "def": "any of various breeds of small gentle horses usually less than five feet high at the shoulder", "synonyms": ["pony"], "image_count": 25, "id": 830, "frequency": "c", "synset": "pony.n.05"}, {"name": "pool_table", "instance_count": 10, "def": "game equipment consisting of a heavy table on which pool is played", "synonyms": ["pool_table", "billiard_table", "snooker_table"], "image_count": 10, "id": 831, "frequency": "r", "synset": "pool_table.n.01"}, {"name": "pop_(soda)", "instance_count": 951, "def": "a sweet drink containing carbonated water and flavoring", "synonyms": ["pop_(soda)", "soda_(pop)", "tonic", "soft_drink"], "image_count": 218, "id": 832, "frequency": "f", "synset": "pop.n.02"}, {"name": "postbox_(public)", "instance_count": 57, "def": "public box for deposit of mail", "synonyms": ["postbox_(public)", "mailbox_(public)"], "image_count": 36, "id": 833, "frequency": "c", "synset": "postbox.n.01"}, {"name": "postcard", "instance_count": 276, "def": "a card for sending messages by post without an envelope", "synonyms": ["postcard", "postal_card", "mailing-card"], "image_count": 16, "id": 834, "frequency": "c", "synset": "postcard.n.01"}, {"name": "poster", "instance_count": 3378, "def": "a sign posted in a public place as an advertisement", "synonyms": ["poster", "placard"], "image_count": 808, "id": 835, "frequency": "f", "synset": "poster.n.01"}, {"name": "pot", "instance_count": 1719, "def": "metal or earthenware cooking vessel that is usually round and deep; often has a handle and lid", "synonyms": ["pot"], "image_count": 479, "id": 836, "frequency": "f", "synset": "pot.n.01"}, {"name": "flowerpot", "instance_count": 3902, "def": "a container in which plants are cultivated", "synonyms": ["flowerpot"], "image_count": 1404, "id": 837, "frequency": "f", "synset": "pot.n.04"}, {"name": "potato", "instance_count": 4393, "def": "an edible tuber native to South America", "synonyms": ["potato"], "image_count": 307, "id": 838, "frequency": "f", "synset": "potato.n.01"}, {"name": "potholder", "instance_count": 112, "def": "an insulated pad for holding hot pots", "synonyms": ["potholder"], "image_count": 57, "id": 839, "frequency": "c", "synset": "potholder.n.01"}, {"name": "pottery", "instance_count": 272, "def": "ceramic ware made from clay and baked in a kiln", "synonyms": ["pottery", "clayware"], "image_count": 28, "id": 840, "frequency": "c", "synset": "pottery.n.01"}, {"name": "pouch", "instance_count": 131, "def": "a small or medium size container for holding or carrying things", "synonyms": ["pouch"], "image_count": 80, "id": 841, "frequency": "c", "synset": "pouch.n.01"}, {"name": "power_shovel", "instance_count": 16, "def": "a machine for excavating", "synonyms": ["power_shovel", "excavator", "digger"], "image_count": 11, "id": 842, "frequency": "c", "synset": "power_shovel.n.01"}, {"name": "prawn", "instance_count": 779, "def": "any of various edible decapod crustaceans", "synonyms": ["prawn", "shrimp"], "image_count": 92, "id": 843, "frequency": "c", "synset": "prawn.n.01"}, {"name": "pretzel", "instance_count": 179, "def": "glazed and salted cracker typically in the shape of a loose knot", "synonyms": ["pretzel"], "image_count": 20, "id": 844, "frequency": "c", "synset": "pretzel.n.01"}, {"name": "printer", "instance_count": 217, "def": "a machine that prints", "synonyms": ["printer", "printing_machine"], "image_count": 194, "id": 845, "frequency": "f", "synset": "printer.n.03"}, {"name": "projectile_(weapon)", "instance_count": 64, "def": "a weapon that is forcibly thrown or projected at a targets", "synonyms": ["projectile_(weapon)", "missile"], "image_count": 23, "id": 846, "frequency": "c", "synset": "projectile.n.01"}, {"name": "projector", "instance_count": 54, "def": "an optical instrument that projects an enlarged image onto a screen", "synonyms": ["projector"], "image_count": 52, "id": 847, "frequency": "c", "synset": "projector.n.02"}, {"name": "propeller", "instance_count": 1458, "def": "a mechanical device that rotates to push against air or water", "synonyms": ["propeller", "propellor"], "image_count": 673, "id": 848, "frequency": "f", "synset": "propeller.n.01"}, {"name": "prune", "instance_count": 8, "def": "dried plum", "synonyms": ["prune"], "image_count": 2, "id": 849, "frequency": "r", "synset": "prune.n.01"}, {"name": "pudding", "instance_count": 2, "def": "any of various soft thick unsweetened baked dishes", "synonyms": ["pudding"], "image_count": 2, "id": 850, "frequency": "r", "synset": "pudding.n.01"}, {"name": "puffer_(fish)", "instance_count": 2, "def": "fishes whose elongated spiny body can inflate itself with water or air to form a globe", "synonyms": ["puffer_(fish)", "pufferfish", "blowfish", "globefish"], "image_count": 1, "id": 851, "frequency": "r", "synset": "puffer.n.02"}, {"name": "puffin", "instance_count": 4, "def": "seabirds having short necks and brightly colored compressed bills", "synonyms": ["puffin"], "image_count": 2, "id": 852, "frequency": "r", "synset": "puffin.n.01"}, {"name": "pug-dog", "instance_count": 13, "def": "small compact smooth-coated breed of Asiatic origin having a tightly curled tail and broad flat wrinkled muzzle", "synonyms": ["pug-dog"], "image_count": 8, "id": 853, "frequency": "r", "synset": "pug.n.01"}, {"name": "pumpkin", "instance_count": 1192, "def": "usually large pulpy deep-yellow round fruit of the squash family maturing in late summer or early autumn", "synonyms": ["pumpkin"], "image_count": 80, "id": 854, "frequency": "c", "synset": "pumpkin.n.02"}, {"name": "puncher", "instance_count": 6, "def": "a tool for making holes or indentations", "synonyms": ["puncher"], "image_count": 3, "id": 855, "frequency": "r", "synset": "punch.n.03"}, {"name": "puppet", "instance_count": 18, "def": "a small figure of a person operated from above with strings by a puppeteer", "synonyms": ["puppet", "marionette"], "image_count": 3, "id": 856, "frequency": "r", "synset": "puppet.n.01"}, {"name": "puppy", "instance_count": 57, "def": "a young dog", "synonyms": ["puppy"], "image_count": 15, "id": 857, "frequency": "c", "synset": "puppy.n.01"}, {"name": "quesadilla", "instance_count": 6, "def": "a tortilla that is filled with cheese and heated", "synonyms": ["quesadilla"], "image_count": 2, "id": 858, "frequency": "r", "synset": "quesadilla.n.01"}, {"name": "quiche", "instance_count": 33, "def": "a tart filled with rich unsweetened custard; often contains other ingredients (as cheese or ham or seafood or vegetables)", "synonyms": ["quiche"], "image_count": 10, "id": 859, "frequency": "r", "synset": "quiche.n.02"}, {"name": "quilt", "instance_count": 513, "def": "bedding made of two layers of cloth filled with stuffing and stitched together", "synonyms": ["quilt", "comforter"], "image_count": 386, "id": 860, "frequency": "f", "synset": "quilt.n.01"}, {"name": "rabbit", "instance_count": 139, "def": "any of various burrowing animals of the family Leporidae having long ears and short tails", "synonyms": ["rabbit"], "image_count": 65, "id": 861, "frequency": "c", "synset": "rabbit.n.01"}, {"name": "race_car", "instance_count": 6, "def": "a fast car that competes in races", "synonyms": ["race_car", "racing_car"], "image_count": 3, "id": 862, "frequency": "r", "synset": "racer.n.02"}, {"name": "racket", "instance_count": 64, "def": "a sports implement used to strike a ball in various games", "synonyms": ["racket", "racquet"], "image_count": 35, "id": 863, "frequency": "c", "synset": "racket.n.04"}, {"name": "radar", "instance_count": 13, "def": "measuring instrument in which the echo of a pulse of microwave radiation is used to detect and locate distant objects", "synonyms": ["radar"], "image_count": 5, "id": 864, "frequency": "r", "synset": "radar.n.01"}, {"name": "radiator", "instance_count": 195, "def": "a mechanism consisting of a metal honeycomb through which hot fluids circulate", "synonyms": ["radiator"], "image_count": 180, "id": 865, "frequency": "f", "synset": "radiator.n.03"}, {"name": "radio_receiver", "instance_count": 123, "def": "an electronic receiver that detects and demodulates and amplifies transmitted radio signals", "synonyms": ["radio_receiver", "radio_set", "radio", "tuner_(radio)"], "image_count": 99, "id": 866, "frequency": "c", "synset": "radio_receiver.n.01"}, {"name": "radish", "instance_count": 519, "def": "pungent edible root of any of various cultivated radish plants", "synonyms": ["radish", "daikon"], "image_count": 49, "id": 867, "frequency": "c", "synset": "radish.n.03"}, {"name": "raft", "instance_count": 66, "def": "a flat float (usually made of logs or planks) that can be used for transport or as a platform for swimmers", "synonyms": ["raft"], "image_count": 28, "id": 868, "frequency": "c", "synset": "raft.n.01"}, {"name": "rag_doll", "instance_count": 3, "def": "a cloth doll that is stuffed and (usually) painted", "synonyms": ["rag_doll"], "image_count": 1, "id": 869, "frequency": "r", "synset": "rag_doll.n.01"}, {"name": "raincoat", "instance_count": 303, "def": "a water-resistant coat", "synonyms": ["raincoat", "waterproof_jacket"], "image_count": 52, "id": 870, "frequency": "c", "synset": "raincoat.n.01"}, {"name": "ram_(animal)", "instance_count": 132, "def": "uncastrated adult male sheep", "synonyms": ["ram_(animal)"], "image_count": 36, "id": 871, "frequency": "c", "synset": "ram.n.05"}, {"name": "raspberry", "instance_count": 778, "def": "red or black edible aggregate berries usually smaller than the related blackberries", "synonyms": ["raspberry"], "image_count": 70, "id": 872, "frequency": "c", "synset": "raspberry.n.02"}, {"name": "rat", "instance_count": 6, "def": "any of various long-tailed rodents similar to but larger than a mouse", "synonyms": ["rat"], "image_count": 6, "id": 873, "frequency": "r", "synset": "rat.n.01"}, {"name": "razorblade", "instance_count": 35, "def": "a blade that has very sharp edge", "synonyms": ["razorblade"], "image_count": 29, "id": 874, "frequency": "c", "synset": "razorblade.n.01"}, {"name": "reamer_(juicer)", "instance_count": 26, "def": "a squeezer with a conical ridged center that is used for squeezing juice from citrus fruit", "synonyms": ["reamer_(juicer)", "juicer", "juice_reamer"], "image_count": 24, "id": 875, "frequency": "c", "synset": "reamer.n.01"}, {"name": "rearview_mirror", "instance_count": 3650, "def": "vehicle mirror (side or rearview)", "synonyms": ["rearview_mirror"], "image_count": 1115, "id": 876, "frequency": "f", "synset": "rearview_mirror.n.01"}, {"name": "receipt", "instance_count": 89, "def": "an acknowledgment (usually tangible) that payment has been made", "synonyms": ["receipt"], "image_count": 61, "id": 877, "frequency": "c", "synset": "receipt.n.02"}, {"name": "recliner", "instance_count": 28, "def": "an armchair whose back can be lowered and foot can be raised to allow the sitter to recline in it", "synonyms": ["recliner", "reclining_chair", "lounger_(chair)"], "image_count": 18, "id": 878, "frequency": "c", "synset": "recliner.n.01"}, {"name": "record_player", "instance_count": 22, "def": "machine in which rotating records cause a stylus to vibrate and the vibrations are amplified acoustically or electronically", "synonyms": ["record_player", "phonograph_(record_player)", "turntable"], "image_count": 18, "id": 879, "frequency": "c", "synset": "record_player.n.01"}, {"name": "reflector", "instance_count": 3426, "def": "device that reflects light, radiation, etc.", "synonyms": ["reflector"], "image_count": 665, "id": 880, "frequency": "f", "synset": "reflector.n.01"}, {"name": "remote_control", "instance_count": 2467, "def": "a device that can be used to control a machine or apparatus from a distance", "synonyms": ["remote_control"], "image_count": 1096, "id": 881, "frequency": "f", "synset": "remote_control.n.01"}, {"name": "rhinoceros", "instance_count": 50, "def": "massive powerful herbivorous odd-toed ungulate of southeast Asia and Africa having very thick skin and one or two horns on the snout", "synonyms": ["rhinoceros"], "image_count": 29, "id": 882, "frequency": "c", "synset": "rhinoceros.n.01"}, {"name": "rib_(food)", "instance_count": 32, "def": "cut of meat including one or more ribs", "synonyms": ["rib_(food)"], "image_count": 8, "id": 883, "frequency": "r", "synset": "rib.n.03"}, {"name": "rifle", "instance_count": 37, "def": "a shoulder firearm with a long barrel", "synonyms": ["rifle"], "image_count": 14, "id": 884, "frequency": "c", "synset": "rifle.n.01"}, {"name": "ring", "instance_count": 2314, "def": "jewelry consisting of a circlet of precious metal (often set with jewels) worn on the finger", "synonyms": ["ring"], "image_count": 1622, "id": 885, "frequency": "f", "synset": "ring.n.08"}, {"name": "river_boat", "instance_count": 3, "def": "a boat used on rivers or to ply a river", "synonyms": ["river_boat"], "image_count": 2, "id": 886, "frequency": "r", "synset": "river_boat.n.01"}, {"name": "road_map", "instance_count": 3, "def": "(NOT A ROAD) a MAP showing roads (for automobile travel)", "synonyms": ["road_map"], "image_count": 3, "id": 887, "frequency": "r", "synset": "road_map.n.02"}, {"name": "robe", "instance_count": 77, "def": "any loose flowing garment", "synonyms": ["robe"], "image_count": 32, "id": 888, "frequency": "c", "synset": "robe.n.01"}, {"name": "rocking_chair", "instance_count": 70, "def": "a chair mounted on rockers", "synonyms": ["rocking_chair"], "image_count": 55, "id": 889, "frequency": "c", "synset": "rocking_chair.n.01"}, {"name": "rodent", "instance_count": 2, "def": "relatively small placental mammals having a single pair of constantly growing incisor teeth specialized for gnawing", "synonyms": ["rodent"], "image_count": 1, "id": 890, "frequency": "r", "synset": "rodent.n.01"}, {"name": "roller_skate", "instance_count": 35, "def": "a shoe with pairs of rollers (small hard wheels) fixed to the sole", "synonyms": ["roller_skate"], "image_count": 10, "id": 891, "frequency": "r", "synset": "roller_skate.n.01"}, {"name": "Rollerblade", "instance_count": 31, "def": "an in-line variant of a roller skate", "synonyms": ["Rollerblade"], "image_count": 10, "id": 892, "frequency": "r", "synset": "rollerblade.n.01"}, {"name": "rolling_pin", "instance_count": 52, "def": "utensil consisting of a cylinder (usually of wood) with a handle at each end; used to roll out dough", "synonyms": ["rolling_pin"], "image_count": 47, "id": 893, "frequency": "c", "synset": "rolling_pin.n.01"}, {"name": "root_beer", "instance_count": 3, "def": "carbonated drink containing extracts of roots and herbs", "synonyms": ["root_beer"], "image_count": 3, "id": 894, "frequency": "r", "synset": "root_beer.n.01"}, {"name": "router_(computer_equipment)", "instance_count": 41, "def": "a device that forwards data packets between computer networks", "synonyms": ["router_(computer_equipment)"], "image_count": 29, "id": 895, "frequency": "c", "synset": "router.n.02"}, {"name": "rubber_band", "instance_count": 574, "def": "a narrow band of elastic rubber used to hold things (such as papers) together", "synonyms": ["rubber_band", "elastic_band"], "image_count": 342, "id": 896, "frequency": "f", "synset": "rubber_band.n.01"}, {"name": "runner_(carpet)", "instance_count": 32, "def": "a long narrow carpet", "synonyms": ["runner_(carpet)"], "image_count": 25, "id": 897, "frequency": "c", "synset": "runner.n.08"}, {"name": "plastic_bag", "instance_count": 3631, "def": "a bag made of paper or plastic for holding customer's purchases", "synonyms": ["plastic_bag", "paper_bag"], "image_count": 1469, "id": 898, "frequency": "f", "synset": "sack.n.01"}, {"name": "saddle_(on_an_animal)", "instance_count": 955, "def": "a seat for the rider of a horse or camel", "synonyms": ["saddle_(on_an_animal)"], "image_count": 521, "id": 899, "frequency": "f", "synset": "saddle.n.01"}, {"name": "saddle_blanket", "instance_count": 648, "def": "stable gear consisting of a blanket placed under the saddle", "synonyms": ["saddle_blanket", "saddlecloth", "horse_blanket"], "image_count": 347, "id": 900, "frequency": "f", "synset": "saddle_blanket.n.01"}, {"name": "saddlebag", "instance_count": 56, "def": "a large bag (or pair of bags) hung over a saddle", "synonyms": ["saddlebag"], "image_count": 35, "id": 901, "frequency": "c", "synset": "saddlebag.n.01"}, {"name": "safety_pin", "instance_count": 15, "def": "a pin in the form of a clasp; has a guard so the point of the pin will not stick the user", "synonyms": ["safety_pin"], "image_count": 7, "id": 902, "frequency": "r", "synset": "safety_pin.n.01"}, {"name": "sail", "instance_count": 863, "def": "a large piece of fabric by means of which wind is used to propel a sailing vessel", "synonyms": ["sail"], "image_count": 207, "id": 903, "frequency": "f", "synset": "sail.n.01"}, {"name": "salad", "instance_count": 171, "def": "food mixtures either arranged on a plate or tossed and served with a moist dressing; usually consisting of or including greens", "synonyms": ["salad"], "image_count": 108, "id": 904, "frequency": "f", "synset": "salad.n.01"}, {"name": "salad_plate", "instance_count": 6, "def": "a plate or bowl for individual servings of salad", "synonyms": ["salad_plate", "salad_bowl"], "image_count": 2, "id": 905, "frequency": "r", "synset": "salad_plate.n.01"}, {"name": "salami", "instance_count": 290, "def": "highly seasoned fatty sausage of pork and beef usually dried", "synonyms": ["salami"], "image_count": 34, "id": 906, "frequency": "c", "synset": "salami.n.01"}, {"name": "salmon_(fish)", "instance_count": 27, "def": "any of various large food and game fishes of northern waters", "synonyms": ["salmon_(fish)"], "image_count": 12, "id": 907, "frequency": "c", "synset": "salmon.n.01"}, {"name": "salmon_(food)", "instance_count": 14, "def": "flesh of any of various marine or freshwater fish of the family Salmonidae", "synonyms": ["salmon_(food)"], "image_count": 10, "id": 908, "frequency": "r", "synset": "salmon.n.03"}, {"name": "salsa", "instance_count": 22, "def": "spicy sauce of tomatoes and onions and chili peppers to accompany Mexican foods", "synonyms": ["salsa"], "image_count": 13, "id": 909, "frequency": "c", "synset": "salsa.n.01"}, {"name": "saltshaker", "instance_count": 543, "def": "a shaker with a perforated top for sprinkling salt", "synonyms": ["saltshaker"], "image_count": 361, "id": 910, "frequency": "f", "synset": "saltshaker.n.01"}, {"name": "sandal_(type_of_shoe)", "instance_count": 3145, "def": "a shoe consisting of a sole fastened by straps to the foot", "synonyms": ["sandal_(type_of_shoe)"], "image_count": 1023, "id": 911, "frequency": "f", "synset": "sandal.n.01"}, {"name": "sandwich", "instance_count": 2315, "def": "two (or more) slices of bread with a filling between them", "synonyms": ["sandwich"], "image_count": 782, "id": 912, "frequency": "f", "synset": "sandwich.n.01"}, {"name": "satchel", "instance_count": 3, "def": "luggage consisting of a small case with a flat bottom and (usually) a shoulder strap", "synonyms": ["satchel"], "image_count": 2, "id": 913, "frequency": "r", "synset": "satchel.n.01"}, {"name": "saucepan", "instance_count": 26, "def": "a deep pan with a handle; used for stewing or boiling", "synonyms": ["saucepan"], "image_count": 5, "id": 914, "frequency": "r", "synset": "saucepan.n.01"}, {"name": "saucer", "instance_count": 555, "def": "a small shallow dish for holding a cup at the table", "synonyms": ["saucer"], "image_count": 247, "id": 915, "frequency": "f", "synset": "saucer.n.02"}, {"name": "sausage", "instance_count": 2704, "def": "highly seasoned minced meat stuffed in casings", "synonyms": ["sausage"], "image_count": 221, "id": 916, "frequency": "f", "synset": "sausage.n.01"}, {"name": "sawhorse", "instance_count": 5, "def": "a framework for holding wood that is being sawed", "synonyms": ["sawhorse", "sawbuck"], "image_count": 4, "id": 917, "frequency": "r", "synset": "sawhorse.n.01"}, {"name": "saxophone", "instance_count": 13, "def": "a wind instrument with a `J'-shaped form typically made of brass", "synonyms": ["saxophone"], "image_count": 8, "id": 918, "frequency": "r", "synset": "sax.n.02"}, {"name": "scale_(measuring_instrument)", "instance_count": 178, "def": "a measuring instrument for weighing; shows amount of mass", "synonyms": ["scale_(measuring_instrument)"], "image_count": 158, "id": 919, "frequency": "f", "synset": "scale.n.07"}, {"name": "scarecrow", "instance_count": 4, "def": "an effigy in the shape of a man to frighten birds away from seeds", "synonyms": ["scarecrow", "strawman"], "image_count": 3, "id": 920, "frequency": "r", "synset": "scarecrow.n.01"}, {"name": "scarf", "instance_count": 1310, "def": "a garment worn around the head or neck or shoulders for warmth or decoration", "synonyms": ["scarf"], "image_count": 752, "id": 921, "frequency": "f", "synset": "scarf.n.01"}, {"name": "school_bus", "instance_count": 142, "def": "a bus used to transport children to or from school", "synonyms": ["school_bus"], "image_count": 64, "id": 922, "frequency": "c", "synset": "school_bus.n.01"}, {"name": "scissors", "instance_count": 1376, "def": "a tool having two crossed pivoting blades with looped handles", "synonyms": ["scissors"], "image_count": 707, "id": 923, "frequency": "f", "synset": "scissors.n.01"}, {"name": "scoreboard", "instance_count": 161, "def": "a large board for displaying the score of a contest (and some other information)", "synonyms": ["scoreboard"], "image_count": 143, "id": 924, "frequency": "f", "synset": "scoreboard.n.01"}, {"name": "scraper", "instance_count": 1, "def": "any of various hand tools for scraping", "synonyms": ["scraper"], "image_count": 1, "id": 925, "frequency": "r", "synset": "scraper.n.01"}, {"name": "screwdriver", "instance_count": 88, "def": "a hand tool for driving screws; has a tip that fits into the head of a screw", "synonyms": ["screwdriver"], "image_count": 49, "id": 926, "frequency": "c", "synset": "screwdriver.n.01"}, {"name": "scrubbing_brush", "instance_count": 141, "def": "a brush with short stiff bristles for heavy cleaning", "synonyms": ["scrubbing_brush"], "image_count": 126, "id": 927, "frequency": "f", "synset": "scrub_brush.n.01"}, {"name": "sculpture", "instance_count": 202, "def": "a three-dimensional work of art", "synonyms": ["sculpture"], "image_count": 76, "id": 928, "frequency": "c", "synset": "sculpture.n.01"}, {"name": "seabird", "instance_count": 126, "def": "a bird that frequents coastal waters and the open ocean: gulls; pelicans; gannets; cormorants; albatrosses; petrels; etc.", "synonyms": ["seabird", "seafowl"], "image_count": 11, "id": 929, "frequency": "c", "synset": "seabird.n.01"}, {"name": "seahorse", "instance_count": 23, "def": "small fish with horse-like heads bent sharply downward and curled tails", "synonyms": ["seahorse"], "image_count": 11, "id": 930, "frequency": "c", "synset": "seahorse.n.02"}, {"name": "seaplane", "instance_count": 4, "def": "an airplane that can land on or take off from water", "synonyms": ["seaplane", "hydroplane"], "image_count": 4, "id": 931, "frequency": "r", "synset": "seaplane.n.01"}, {"name": "seashell", "instance_count": 451, "def": "the shell of a marine organism", "synonyms": ["seashell"], "image_count": 39, "id": 932, "frequency": "c", "synset": "seashell.n.01"}, {"name": "sewing_machine", "instance_count": 11, "def": "a textile machine used as a home appliance for sewing", "synonyms": ["sewing_machine"], "image_count": 11, "id": 933, "frequency": "c", "synset": "sewing_machine.n.01"}, {"name": "shaker", "instance_count": 24, "def": "a container in which something can be shaken", "synonyms": ["shaker"], "image_count": 13, "id": 934, "frequency": "c", "synset": "shaker.n.03"}, {"name": "shampoo", "instance_count": 254, "def": "cleansing agent consisting of soaps or detergents used for washing the hair", "synonyms": ["shampoo"], "image_count": 91, "id": 935, "frequency": "c", "synset": "shampoo.n.01"}, {"name": "shark", "instance_count": 20, "def": "typically large carnivorous fishes with sharpe teeth", "synonyms": ["shark"], "image_count": 14, "id": 936, "frequency": "c", "synset": "shark.n.01"}, {"name": "sharpener", "instance_count": 7, "def": "any implement that is used to make something (an edge or a point) sharper", "synonyms": ["sharpener"], "image_count": 5, "id": 937, "frequency": "r", "synset": "sharpener.n.01"}, {"name": "Sharpie", "instance_count": 5, "def": "a pen with indelible ink that will write on any surface", "synonyms": ["Sharpie"], "image_count": 3, "id": 938, "frequency": "r", "synset": "sharpie.n.03"}, {"name": "shaver_(electric)", "instance_count": 12, "def": "a razor powered by an electric motor", "synonyms": ["shaver_(electric)", "electric_shaver", "electric_razor"], "image_count": 10, "id": 939, "frequency": "r", "synset": "shaver.n.03"}, {"name": "shaving_cream", "instance_count": 33, "def": "toiletry consisting that forms a rich lather for softening the beard before shaving", "synonyms": ["shaving_cream", "shaving_soap"], "image_count": 18, "id": 940, "frequency": "c", "synset": "shaving_cream.n.01"}, {"name": "shawl", "instance_count": 9, "def": "cloak consisting of an oblong piece of cloth used to cover the head and shoulders", "synonyms": ["shawl"], "image_count": 9, "id": 941, "frequency": "r", "synset": "shawl.n.01"}, {"name": "shears", "instance_count": 38, "def": "large scissors with strong blades", "synonyms": ["shears"], "image_count": 6, "id": 942, "frequency": "r", "synset": "shears.n.01"}, {"name": "sheep", "instance_count": 13304, "def": "woolly usually horned ruminant mammal related to the goat", "synonyms": ["sheep"], "image_count": 951, "id": 943, "frequency": "f", "synset": "sheep.n.01"}, {"name": "shepherd_dog", "instance_count": 2, "def": "any of various usually long-haired breeds of dog reared to herd and guard sheep", "synonyms": ["shepherd_dog", "sheepdog"], "image_count": 2, "id": 944, "frequency": "r", "synset": "shepherd_dog.n.01"}, {"name": "sherbert", "instance_count": 2, "def": "a frozen dessert made primarily of fruit juice and sugar", "synonyms": ["sherbert", "sherbet"], "image_count": 1, "id": 945, "frequency": "r", "synset": "sherbert.n.01"}, {"name": "shield", "instance_count": 41, "def": "armor carried on the arm to intercept blows", "synonyms": ["shield"], "image_count": 19, "id": 946, "frequency": "c", "synset": "shield.n.02"}, {"name": "shirt", "instance_count": 10177, "def": "a garment worn on the upper half of the body", "synonyms": ["shirt"], "image_count": 1942, "id": 947, "frequency": "f", "synset": "shirt.n.01"}, {"name": "shoe", "instance_count": 9374, "def": "common footwear covering the foot", "synonyms": ["shoe", "sneaker_(type_of_shoe)", "tennis_shoe"], "image_count": 1916, "id": 948, "frequency": "f", "synset": "shoe.n.01"}, {"name": "shopping_bag", "instance_count": 377, "def": "a bag made of plastic or strong paper (often with handles); used to transport goods after shopping", "synonyms": ["shopping_bag"], "image_count": 139, "id": 949, "frequency": "f", "synset": "shopping_bag.n.01"}, {"name": "shopping_cart", "instance_count": 90, "def": "a handcart that holds groceries or other goods while shopping", "synonyms": ["shopping_cart"], "image_count": 43, "id": 950, "frequency": "c", "synset": "shopping_cart.n.01"}, {"name": "short_pants", "instance_count": 5305, "def": "trousers that end at or above the knee", "synonyms": ["short_pants", "shorts_(clothing)", "trunks_(clothing)"], "image_count": 1969, "id": 951, "frequency": "f", "synset": "short_pants.n.01"}, {"name": "shot_glass", "instance_count": 24, "def": "a small glass adequate to hold a single swallow of whiskey", "synonyms": ["shot_glass"], "image_count": 5, "id": 952, "frequency": "r", "synset": "shot_glass.n.01"}, {"name": "shoulder_bag", "instance_count": 331, "def": "a large handbag that can be carried by a strap looped over the shoulder", "synonyms": ["shoulder_bag"], "image_count": 134, "id": 953, "frequency": "f", "synset": "shoulder_bag.n.01"}, {"name": "shovel", "instance_count": 110, "def": "a hand tool for lifting loose material such as snow, dirt, etc.", "synonyms": ["shovel"], "image_count": 74, "id": 954, "frequency": "c", "synset": "shovel.n.01"}, {"name": "shower_head", "instance_count": 450, "def": "a plumbing fixture that sprays water over you", "synonyms": ["shower_head"], "image_count": 381, "id": 955, "frequency": "f", "synset": "shower.n.01"}, {"name": "shower_cap", "instance_count": 1, "def": "a tight cap worn to keep hair dry while showering", "synonyms": ["shower_cap"], "image_count": 1, "id": 956, "frequency": "r", "synset": "shower_cap.n.01"}, {"name": "shower_curtain", "instance_count": 479, "def": "a curtain that keeps water from splashing out of the shower area", "synonyms": ["shower_curtain"], "image_count": 381, "id": 957, "frequency": "f", "synset": "shower_curtain.n.01"}, {"name": "shredder_(for_paper)", "instance_count": 6, "def": "a device that shreds documents", "synonyms": ["shredder_(for_paper)"], "image_count": 6, "id": 958, "frequency": "r", "synset": "shredder.n.01"}, {"name": "signboard", "instance_count": 8091, "def": "structure displaying a board on which advertisements can be posted", "synonyms": ["signboard"], "image_count": 1826, "id": 959, "frequency": "f", "synset": "signboard.n.01"}, {"name": "silo", "instance_count": 95, "def": "a cylindrical tower used for storing goods", "synonyms": ["silo"], "image_count": 28, "id": 960, "frequency": "c", "synset": "silo.n.01"}, {"name": "sink", "instance_count": 2182, "def": "plumbing fixture consisting of a water basin fixed to a wall or floor and having a drainpipe", "synonyms": ["sink"], "image_count": 1635, "id": 961, "frequency": "f", "synset": "sink.n.01"}, {"name": "skateboard", "instance_count": 3597, "def": "a board with wheels that is ridden in a standing or crouching position and propelled by foot", "synonyms": ["skateboard"], "image_count": 1967, "id": 962, "frequency": "f", "synset": "skateboard.n.01"}, {"name": "skewer", "instance_count": 81, "def": "a long pin for holding meat in position while it is being roasted", "synonyms": ["skewer"], "image_count": 16, "id": 963, "frequency": "c", "synset": "skewer.n.01"}, {"name": "ski", "instance_count": 8496, "def": "sports equipment for skiing on snow", "synonyms": ["ski"], "image_count": 1926, "id": 964, "frequency": "f", "synset": "ski.n.01"}, {"name": "ski_boot", "instance_count": 8124, "def": "a stiff boot that is fastened to a ski with a ski binding", "synonyms": ["ski_boot"], "image_count": 1789, "id": 965, "frequency": "f", "synset": "ski_boot.n.01"}, {"name": "ski_parka", "instance_count": 1727, "def": "a parka to be worn while skiing", "synonyms": ["ski_parka", "ski_jacket"], "image_count": 401, "id": 966, "frequency": "f", "synset": "ski_parka.n.01"}, {"name": "ski_pole", "instance_count": 8263, "def": "a pole with metal points used as an aid in skiing", "synonyms": ["ski_pole"], "image_count": 1968, "id": 967, "frequency": "f", "synset": "ski_pole.n.01"}, {"name": "skirt", "instance_count": 1784, "def": "a garment hanging from the waist; worn mainly by girls and women", "synonyms": ["skirt"], "image_count": 1167, "id": 968, "frequency": "f", "synset": "skirt.n.02"}, {"name": "skullcap", "instance_count": 1, "def": "rounded brimless cap fitting the crown of the head", "synonyms": ["skullcap"], "image_count": 1, "id": 969, "frequency": "r", "synset": "skullcap.n.01"}, {"name": "sled", "instance_count": 102, "def": "a vehicle or flat object for transportation over snow by sliding or pulled by dogs, etc.", "synonyms": ["sled", "sledge", "sleigh"], "image_count": 56, "id": 970, "frequency": "c", "synset": "sled.n.01"}, {"name": "sleeping_bag", "instance_count": 33, "def": "large padded bag designed to be slept in outdoors", "synonyms": ["sleeping_bag"], "image_count": 17, "id": 971, "frequency": "c", "synset": "sleeping_bag.n.01"}, {"name": "sling_(bandage)", "instance_count": 1, "def": "bandage to support an injured forearm; slung over the shoulder or neck", "synonyms": ["sling_(bandage)", "triangular_bandage"], "image_count": 1, "id": 972, "frequency": "r", "synset": "sling.n.05"}, {"name": "slipper_(footwear)", "instance_count": 121, "def": "low footwear that can be slipped on and off easily; usually worn indoors", "synonyms": ["slipper_(footwear)", "carpet_slipper_(footwear)"], "image_count": 58, "id": 973, "frequency": "c", "synset": "slipper.n.01"}, {"name": "smoothie", "instance_count": 53, "def": "a thick smooth drink consisting of fresh fruit pureed with ice cream or yoghurt or milk", "synonyms": ["smoothie"], "image_count": 9, "id": 974, "frequency": "r", "synset": "smoothie.n.02"}, {"name": "snake", "instance_count": 16, "def": "limbless scaly elongate reptile; some are venomous", "synonyms": ["snake", "serpent"], "image_count": 8, "id": 975, "frequency": "r", "synset": "snake.n.01"}, {"name": "snowboard", "instance_count": 2119, "def": "a board that resembles a broad ski or a small surfboard; used in a standing position to slide down snow-covered slopes", "synonyms": ["snowboard"], "image_count": 1124, "id": 976, "frequency": "f", "synset": "snowboard.n.01"}, {"name": "snowman", "instance_count": 61, "def": "a figure of a person made of packed snow", "synonyms": ["snowman"], "image_count": 31, "id": 977, "frequency": "c", "synset": "snowman.n.01"}, {"name": "snowmobile", "instance_count": 23, "def": "tracked vehicle for travel on snow having skis in front", "synonyms": ["snowmobile"], "image_count": 16, "id": 978, "frequency": "c", "synset": "snowmobile.n.01"}, {"name": "soap", "instance_count": 895, "def": "a cleansing agent made from the salts of vegetable or animal fats", "synonyms": ["soap"], "image_count": 491, "id": 979, "frequency": "f", "synset": "soap.n.01"}, {"name": "soccer_ball", "instance_count": 670, "def": "an inflated ball used in playing soccer (called `football' outside of the United States)", "synonyms": ["soccer_ball"], "image_count": 432, "id": 980, "frequency": "f", "synset": "soccer_ball.n.01"}, {"name": "sock", "instance_count": 6866, "def": "cloth covering for the foot; worn inside the shoe; reaches to between the ankle and the knee", "synonyms": ["sock"], "image_count": 1945, "id": 981, "frequency": "f", "synset": "sock.n.01"}, {"name": "sofa", "instance_count": 2408, "def": "an upholstered seat for more than one person", "synonyms": ["sofa", "couch", "lounge"], "image_count": 1899, "id": 982, "frequency": "f", "synset": "sofa.n.01"}, {"name": "softball", "instance_count": 5, "def": "ball used in playing softball", "synonyms": ["softball"], "image_count": 5, "id": 983, "frequency": "r", "synset": "softball.n.01"}, {"name": "solar_array", "instance_count": 52, "def": "electrical device consisting of a large array of connected solar cells", "synonyms": ["solar_array", "solar_battery", "solar_panel"], "image_count": 28, "id": 984, "frequency": "c", "synset": "solar_array.n.01"}, {"name": "sombrero", "instance_count": 22, "def": "a straw hat with a tall crown and broad brim; worn in American southwest and in Mexico", "synonyms": ["sombrero"], "image_count": 7, "id": 985, "frequency": "r", "synset": "sombrero.n.02"}, {"name": "soup", "instance_count": 193, "def": "liquid food especially of meat or fish or vegetable stock often containing pieces of solid food", "synonyms": ["soup"], "image_count": 146, "id": 986, "frequency": "f", "synset": "soup.n.01"}, {"name": "soup_bowl", "instance_count": 2, "def": "a bowl for serving soup", "synonyms": ["soup_bowl"], "image_count": 1, "id": 987, "frequency": "r", "synset": "soup_bowl.n.01"}, {"name": "soupspoon", "instance_count": 44, "def": "a spoon with a rounded bowl for eating soup", "synonyms": ["soupspoon"], "image_count": 25, "id": 988, "frequency": "c", "synset": "soupspoon.n.01"}, {"name": "sour_cream", "instance_count": 49, "def": "soured light cream", "synonyms": ["sour_cream", "soured_cream"], "image_count": 22, "id": 989, "frequency": "c", "synset": "sour_cream.n.01"}, {"name": "soya_milk", "instance_count": 2, "def": "a milk substitute containing soybean flour and water; used in some infant formulas and in making tofu", "synonyms": ["soya_milk", "soybean_milk", "soymilk"], "image_count": 1, "id": 990, "frequency": "r", "synset": "soya_milk.n.01"}, {"name": "space_shuttle", "instance_count": 10, "def": "a reusable spacecraft with wings for a controlled descent through the Earth's atmosphere", "synonyms": ["space_shuttle"], "image_count": 10, "id": 991, "frequency": "r", "synset": "space_shuttle.n.01"}, {"name": "sparkler_(fireworks)", "instance_count": 12, "def": "a firework that burns slowly and throws out a shower of sparks", "synonyms": ["sparkler_(fireworks)"], "image_count": 9, "id": 992, "frequency": "r", "synset": "sparkler.n.02"}, {"name": "spatula", "instance_count": 508, "def": "a hand tool with a thin flexible blade used to mix or spread soft substances", "synonyms": ["spatula"], "image_count": 308, "id": 993, "frequency": "f", "synset": "spatula.n.02"}, {"name": "spear", "instance_count": 9, "def": "a long pointed rod used as a tool or weapon", "synonyms": ["spear", "lance"], "image_count": 4, "id": 994, "frequency": "r", "synset": "spear.n.01"}, {"name": "spectacles", "instance_count": 3040, "def": "optical instrument consisting of a frame that holds a pair of lenses for correcting defective vision", "synonyms": ["spectacles", "specs", "eyeglasses", "glasses"], "image_count": 1969, "id": 995, "frequency": "f", "synset": "spectacles.n.01"}, {"name": "spice_rack", "instance_count": 54, "def": "a rack for displaying containers filled with spices", "synonyms": ["spice_rack"], "image_count": 45, "id": 996, "frequency": "c", "synset": "spice_rack.n.01"}, {"name": "spider", "instance_count": 19, "def": "predatory arachnid with eight legs, two poison fangs, two feelers, and usually two silk-spinning organs at the back end of the body", "synonyms": ["spider"], "image_count": 12, "id": 997, "frequency": "c", "synset": "spider.n.01"}, {"name": "crawfish", "instance_count": 5, "def": "large edible marine crustacean having a spiny carapace but lacking the large pincers of true lobsters", "synonyms": ["crawfish", "crayfish"], "image_count": 1, "id": 998, "frequency": "r", "synset": "spiny_lobster.n.02"}, {"name": "sponge", "instance_count": 116, "def": "a porous mass usable to absorb water typically used for cleaning", "synonyms": ["sponge"], "image_count": 85, "id": 999, "frequency": "c", "synset": "sponge.n.01"}, {"name": "spoon", "instance_count": 2111, "def": "a piece of cutlery with a shallow bowl-shaped container and a handle", "synonyms": ["spoon"], "image_count": 1127, "id": 1000, "frequency": "f", "synset": "spoon.n.01"}, {"name": "sportswear", "instance_count": 85, "def": "attire worn for sport or for casual wear", "synonyms": ["sportswear", "athletic_wear", "activewear"], "image_count": 11, "id": 1001, "frequency": "c", "synset": "sportswear.n.01"}, {"name": "spotlight", "instance_count": 403, "def": "a lamp that produces a strong beam of light to illuminate a restricted area; used to focus attention of a stage performer", "synonyms": ["spotlight"], "image_count": 60, "id": 1002, "frequency": "c", "synset": "spotlight.n.02"}, {"name": "squid_(food)", "instance_count": 6, "def": "(Italian cuisine) squid prepared as food", "synonyms": ["squid_(food)", "calamari", "calamary"], "image_count": 1, "id": 1003, "frequency": "r", "synset": "squid.n.01"}, {"name": "squirrel", "instance_count": 19, "def": "a kind of arboreal rodent having a long bushy tail", "synonyms": ["squirrel"], "image_count": 16, "id": 1004, "frequency": "c", "synset": "squirrel.n.01"}, {"name": "stagecoach", "instance_count": 1, "def": "a large coach-and-four formerly used to carry passengers and mail on regular routes between towns", "synonyms": ["stagecoach"], "image_count": 1, "id": 1005, "frequency": "r", "synset": "stagecoach.n.01"}, {"name": "stapler_(stapling_machine)", "instance_count": 68, "def": "a machine that inserts staples into sheets of paper in order to fasten them together", "synonyms": ["stapler_(stapling_machine)"], "image_count": 65, "id": 1006, "frequency": "c", "synset": "stapler.n.01"}, {"name": "starfish", "instance_count": 28, "def": "echinoderms characterized by five arms extending from a central disk", "synonyms": ["starfish", "sea_star"], "image_count": 13, "id": 1007, "frequency": "c", "synset": "starfish.n.01"}, {"name": "statue_(sculpture)", "instance_count": 1934, "def": "a sculpture representing a human or animal", "synonyms": ["statue_(sculpture)"], "image_count": 655, "id": 1008, "frequency": "f", "synset": "statue.n.01"}, {"name": "steak_(food)", "instance_count": 139, "def": "a slice of meat cut from the fleshy part of an animal or large fish", "synonyms": ["steak_(food)"], "image_count": 51, "id": 1009, "frequency": "c", "synset": "steak.n.01"}, {"name": "steak_knife", "instance_count": 1, "def": "a sharp table knife used in eating steak", "synonyms": ["steak_knife"], "image_count": 1, "id": 1010, "frequency": "r", "synset": "steak_knife.n.01"}, {"name": "steering_wheel", "instance_count": 901, "def": "a handwheel that is used for steering", "synonyms": ["steering_wheel"], "image_count": 673, "id": 1011, "frequency": "f", "synset": "steering_wheel.n.01"}, {"name": "stepladder", "instance_count": 5, "def": "a folding portable ladder hinged at the top", "synonyms": ["stepladder"], "image_count": 5, "id": 1012, "frequency": "r", "synset": "step_ladder.n.01"}, {"name": "step_stool", "instance_count": 43, "def": "a stool that has one or two steps that fold under the seat", "synonyms": ["step_stool"], "image_count": 36, "id": 1013, "frequency": "c", "synset": "step_stool.n.01"}, {"name": "stereo_(sound_system)", "instance_count": 77, "def": "electronic device for playing audio", "synonyms": ["stereo_(sound_system)"], "image_count": 54, "id": 1014, "frequency": "c", "synset": "stereo.n.01"}, {"name": "stew", "instance_count": 7, "def": "food prepared by stewing especially meat or fish with vegetables", "synonyms": ["stew"], "image_count": 5, "id": 1015, "frequency": "r", "synset": "stew.n.02"}, {"name": "stirrer", "instance_count": 18, "def": "an implement used for stirring", "synonyms": ["stirrer"], "image_count": 8, "id": 1016, "frequency": "r", "synset": "stirrer.n.02"}, {"name": "stirrup", "instance_count": 625, "def": "support consisting of metal loops into which rider's feet go", "synonyms": ["stirrup"], "image_count": 305, "id": 1017, "frequency": "f", "synset": "stirrup.n.01"}, {"name": "stool", "instance_count": 583, "def": "a simple seat without a back or arms", "synonyms": ["stool"], "image_count": 297, "id": 1018, "frequency": "f", "synset": "stool.n.01"}, {"name": "stop_sign", "instance_count": 1349, "def": "a traffic sign to notify drivers that they must come to a complete stop", "synonyms": ["stop_sign"], "image_count": 1053, "id": 1019, "frequency": "f", "synset": "stop_sign.n.01"}, {"name": "brake_light", "instance_count": 1334, "def": "a red light on the rear of a motor vehicle that signals when the brakes are applied", "synonyms": ["brake_light"], "image_count": 223, "id": 1020, "frequency": "f", "synset": "stoplight.n.01"}, {"name": "stove", "instance_count": 1133, "def": "a kitchen appliance used for cooking food", "synonyms": ["stove", "kitchen_stove", "range_(kitchen_appliance)", "kitchen_range", "cooking_stove"], "image_count": 1037, "id": 1021, "frequency": "f", "synset": "stove.n.01"}, {"name": "strainer", "instance_count": 99, "def": "a filter to retain larger pieces while smaller pieces and liquids pass through", "synonyms": ["strainer"], "image_count": 63, "id": 1022, "frequency": "c", "synset": "strainer.n.01"}, {"name": "strap", "instance_count": 7435, "def": "an elongated strip of material for binding things together or holding", "synonyms": ["strap"], "image_count": 1881, "id": 1023, "frequency": "f", "synset": "strap.n.01"}, {"name": "straw_(for_drinking)", "instance_count": 1154, "def": "a thin paper or plastic tube used to suck liquids into the mouth", "synonyms": ["straw_(for_drinking)", "drinking_straw"], "image_count": 507, "id": 1024, "frequency": "f", "synset": "straw.n.04"}, {"name": "strawberry", "instance_count": 4386, "def": "sweet fleshy red fruit", "synonyms": ["strawberry"], "image_count": 333, "id": 1025, "frequency": "f", "synset": "strawberry.n.01"}, {"name": "street_sign", "instance_count": 8350, "def": "a sign visible from the street", "synonyms": ["street_sign"], "image_count": 1911, "id": 1026, "frequency": "f", "synset": "street_sign.n.01"}, {"name": "streetlight", "instance_count": 7381, "def": "a lamp supported on a lamppost; for illuminating a street", "synonyms": ["streetlight", "street_lamp"], "image_count": 1765, "id": 1027, "frequency": "f", "synset": "streetlight.n.01"}, {"name": "string_cheese", "instance_count": 1, "def": "cheese formed in long strings twisted together", "synonyms": ["string_cheese"], "image_count": 1, "id": 1028, "frequency": "r", "synset": "string_cheese.n.01"}, {"name": "stylus", "instance_count": 11, "def": "a pointed tool for writing or drawing or engraving, including pens", "synonyms": ["stylus"], "image_count": 5, "id": 1029, "frequency": "r", "synset": "stylus.n.02"}, {"name": "subwoofer", "instance_count": 1, "def": "a loudspeaker that is designed to reproduce very low bass frequencies", "synonyms": ["subwoofer"], "image_count": 1, "id": 1030, "frequency": "r", "synset": "subwoofer.n.01"}, {"name": "sugar_bowl", "instance_count": 10, "def": "a dish in which sugar is served", "synonyms": ["sugar_bowl"], "image_count": 9, "id": 1031, "frequency": "r", "synset": "sugar_bowl.n.01"}, {"name": "sugarcane_(plant)", "instance_count": 31, "def": "juicy canes whose sap is a source of molasses and commercial sugar; fresh canes are sometimes chewed for the juice", "synonyms": ["sugarcane_(plant)"], "image_count": 2, "id": 1032, "frequency": "r", "synset": "sugarcane.n.01"}, {"name": "suit_(clothing)", "instance_count": 461, "def": "a set of garments (usually including a jacket and trousers or skirt) for outerwear all of the same fabric and color", "synonyms": ["suit_(clothing)"], "image_count": 151, "id": 1033, "frequency": "f", "synset": "suit.n.01"}, {"name": "sunflower", "instance_count": 618, "def": "any plant of the genus Helianthus having large flower heads with dark disk florets and showy yellow rays", "synonyms": ["sunflower"], "image_count": 82, "id": 1034, "frequency": "c", "synset": "sunflower.n.01"}, {"name": "sunglasses", "instance_count": 5603, "def": "spectacles that are darkened or polarized to protect the eyes from the glare of the sun", "synonyms": ["sunglasses"], "image_count": 1931, "id": 1035, "frequency": "f", "synset": "sunglasses.n.01"}, {"name": "sunhat", "instance_count": 170, "def": "a hat with a broad brim that protects the face from direct exposure to the sun", "synonyms": ["sunhat"], "image_count": 41, "id": 1036, "frequency": "c", "synset": "sunhat.n.01"}, {"name": "surfboard", "instance_count": 3835, "def": "a narrow buoyant board for riding surf", "synonyms": ["surfboard"], "image_count": 1895, "id": 1037, "frequency": "f", "synset": "surfboard.n.01"}, {"name": "sushi", "instance_count": 337, "def": "rice (with raw fish) wrapped in seaweed", "synonyms": ["sushi"], "image_count": 24, "id": 1038, "frequency": "c", "synset": "sushi.n.01"}, {"name": "mop", "instance_count": 22, "def": "cleaning implement consisting of absorbent material fastened to a handle; for cleaning floors", "synonyms": ["mop"], "image_count": 22, "id": 1039, "frequency": "c", "synset": "swab.n.02"}, {"name": "sweat_pants", "instance_count": 56, "def": "loose-fitting trousers with elastic cuffs; worn by athletes", "synonyms": ["sweat_pants"], "image_count": 35, "id": 1040, "frequency": "c", "synset": "sweat_pants.n.01"}, {"name": "sweatband", "instance_count": 145, "def": "a band of material tied around the forehead or wrist to absorb sweat", "synonyms": ["sweatband"], "image_count": 69, "id": 1041, "frequency": "c", "synset": "sweatband.n.02"}, {"name": "sweater", "instance_count": 1894, "def": "a crocheted or knitted garment covering the upper part of the body", "synonyms": ["sweater"], "image_count": 962, "id": 1042, "frequency": "f", "synset": "sweater.n.01"}, {"name": "sweatshirt", "instance_count": 1482, "def": "cotton knit pullover with long sleeves worn during athletic activity", "synonyms": ["sweatshirt"], "image_count": 588, "id": 1043, "frequency": "f", "synset": "sweatshirt.n.01"}, {"name": "sweet_potato", "instance_count": 137, "def": "the edible tuberous root of the sweet potato vine", "synonyms": ["sweet_potato"], "image_count": 21, "id": 1044, "frequency": "c", "synset": "sweet_potato.n.02"}, {"name": "swimsuit", "instance_count": 3141, "def": "garment worn for swimming", "synonyms": ["swimsuit", "swimwear", "bathing_suit", "swimming_costume", "bathing_costume", "swimming_trunks", "bathing_trunks"], "image_count": 825, "id": 1045, "frequency": "f", "synset": "swimsuit.n.01"}, {"name": "sword", "instance_count": 72, "def": "a cutting or thrusting weapon that has a long metal blade", "synonyms": ["sword"], "image_count": 52, "id": 1046, "frequency": "c", "synset": "sword.n.01"}, {"name": "syringe", "instance_count": 14, "def": "a medical instrument used to inject or withdraw fluids", "synonyms": ["syringe"], "image_count": 5, "id": 1047, "frequency": "r", "synset": "syringe.n.01"}, {"name": "Tabasco_sauce", "instance_count": 5, "def": "very spicy sauce (trade name Tabasco) made from fully-aged red peppers", "synonyms": ["Tabasco_sauce"], "image_count": 5, "id": 1048, "frequency": "r", "synset": "tabasco.n.02"}, {"name": "table-tennis_table", "instance_count": 5, "def": "a table used for playing table tennis", "synonyms": ["table-tennis_table", "ping-pong_table"], "image_count": 5, "id": 1049, "frequency": "r", "synset": "table-tennis_table.n.01"}, {"name": "table", "instance_count": 2804, "def": "a piece of furniture having a smooth flat top that is usually supported by one or more vertical legs", "synonyms": ["table"], "image_count": 1860, "id": 1050, "frequency": "f", "synset": "table.n.02"}, {"name": "table_lamp", "instance_count": 81, "def": "a lamp that sits on a table", "synonyms": ["table_lamp"], "image_count": 56, "id": 1051, "frequency": "c", "synset": "table_lamp.n.01"}, {"name": "tablecloth", "instance_count": 2496, "def": "a covering spread over a dining table", "synonyms": ["tablecloth"], "image_count": 1582, "id": 1052, "frequency": "f", "synset": "tablecloth.n.01"}, {"name": "tachometer", "instance_count": 10, "def": "measuring instrument for indicating speed of rotation", "synonyms": ["tachometer"], "image_count": 7, "id": 1053, "frequency": "r", "synset": "tachometer.n.01"}, {"name": "taco", "instance_count": 21, "def": "a small tortilla cupped around a filling", "synonyms": ["taco"], "image_count": 2, "id": 1054, "frequency": "r", "synset": "taco.n.02"}, {"name": "tag", "instance_count": 7550, "def": "a label associated with something for the purpose of identification or information", "synonyms": ["tag"], "image_count": 1562, "id": 1055, "frequency": "f", "synset": "tag.n.02"}, {"name": "taillight", "instance_count": 9222, "def": "lamp (usually red) mounted at the rear of a motor vehicle", "synonyms": ["taillight", "rear_light"], "image_count": 1885, "id": 1056, "frequency": "f", "synset": "taillight.n.01"}, {"name": "tambourine", "instance_count": 1, "def": "a shallow drum with a single drumhead and with metallic disks in the sides", "synonyms": ["tambourine"], "image_count": 1, "id": 1057, "frequency": "r", "synset": "tambourine.n.01"}, {"name": "army_tank", "instance_count": 7, "def": "an enclosed armored military vehicle; has a cannon and moves on caterpillar treads", "synonyms": ["army_tank", "armored_combat_vehicle", "armoured_combat_vehicle"], "image_count": 5, "id": 1058, "frequency": "r", "synset": "tank.n.01"}, {"name": "tank_(storage_vessel)", "instance_count": 304, "def": "a large (usually metallic) vessel for holding gases or liquids", "synonyms": ["tank_(storage_vessel)", "storage_tank"], "image_count": 137, "id": 1059, "frequency": "f", "synset": "tank.n.02"}, {"name": "tank_top_(clothing)", "instance_count": 1799, "def": "a tight-fitting sleeveless shirt with wide shoulder straps and low neck and no front opening", "synonyms": ["tank_top_(clothing)"], "image_count": 1094, "id": 1060, "frequency": "f", "synset": "tank_top.n.01"}, {"name": "tape_(sticky_cloth_or_paper)", "instance_count": 560, "def": "a long thin piece of cloth or paper as used for binding or fastening", "synonyms": ["tape_(sticky_cloth_or_paper)"], "image_count": 134, "id": 1061, "frequency": "f", "synset": "tape.n.01"}, {"name": "tape_measure", "instance_count": 35, "def": "measuring instrument consisting of a narrow strip (cloth or metal) marked in inches or centimeters and used for measuring lengths", "synonyms": ["tape_measure", "measuring_tape"], "image_count": 29, "id": 1062, "frequency": "c", "synset": "tape.n.04"}, {"name": "tapestry", "instance_count": 29, "def": "a heavy textile with a woven design; used for curtains and upholstery", "synonyms": ["tapestry"], "image_count": 22, "id": 1063, "frequency": "c", "synset": "tapestry.n.02"}, {"name": "tarp", "instance_count": 1315, "def": "waterproofed canvas", "synonyms": ["tarp"], "image_count": 522, "id": 1064, "frequency": "f", "synset": "tarpaulin.n.01"}, {"name": "tartan", "instance_count": 68, "def": "a cloth having a crisscross design", "synonyms": ["tartan", "plaid"], "image_count": 50, "id": 1065, "frequency": "c", "synset": "tartan.n.01"}, {"name": "tassel", "instance_count": 276, "def": "adornment consisting of a bunch of cords fastened at one end", "synonyms": ["tassel"], "image_count": 68, "id": 1066, "frequency": "c", "synset": "tassel.n.01"}, {"name": "tea_bag", "instance_count": 42, "def": "a measured amount of tea in a bag for an individual serving of tea", "synonyms": ["tea_bag"], "image_count": 16, "id": 1067, "frequency": "c", "synset": "tea_bag.n.01"}, {"name": "teacup", "instance_count": 152, "def": "a cup from which tea is drunk", "synonyms": ["teacup"], "image_count": 40, "id": 1068, "frequency": "c", "synset": "teacup.n.02"}, {"name": "teakettle", "instance_count": 40, "def": "kettle for boiling water to make tea", "synonyms": ["teakettle"], "image_count": 35, "id": 1069, "frequency": "c", "synset": "teakettle.n.01"}, {"name": "teapot", "instance_count": 209, "def": "pot for brewing tea; usually has a spout and handle", "synonyms": ["teapot"], "image_count": 135, "id": 1070, "frequency": "f", "synset": "teapot.n.01"}, {"name": "teddy_bear", "instance_count": 4886, "def": "plaything consisting of a child's toy bear (usually plush and stuffed with soft materials)", "synonyms": ["teddy_bear"], "image_count": 1413, "id": 1071, "frequency": "f", "synset": "teddy.n.01"}, {"name": "telephone", "instance_count": 945, "def": "electronic device for communicating by voice over long distances (includes wired and wireless/cell phones)", "synonyms": ["telephone", "phone", "telephone_set"], "image_count": 772, "id": 1072, "frequency": "f", "synset": "telephone.n.01"}, {"name": "telephone_booth", "instance_count": 62, "def": "booth for using a telephone", "synonyms": ["telephone_booth", "phone_booth", "call_box", "telephone_box", "telephone_kiosk"], "image_count": 50, "id": 1073, "frequency": "c", "synset": "telephone_booth.n.01"}, {"name": "telephone_pole", "instance_count": 3725, "def": "tall pole supporting telephone wires", "synonyms": ["telephone_pole", "telegraph_pole", "telegraph_post"], "image_count": 1015, "id": 1074, "frequency": "f", "synset": "telephone_pole.n.01"}, {"name": "telephoto_lens", "instance_count": 1, "def": "a camera lens that magnifies the image", "synonyms": ["telephoto_lens", "zoom_lens"], "image_count": 1, "id": 1075, "frequency": "r", "synset": "telephoto_lens.n.01"}, {"name": "television_camera", "instance_count": 117, "def": "television equipment for capturing and recording video", "synonyms": ["television_camera", "tv_camera"], "image_count": 65, "id": 1076, "frequency": "c", "synset": "television_camera.n.01"}, {"name": "television_set", "instance_count": 2205, "def": "an electronic device that receives television signals and displays them on a screen", "synonyms": ["television_set", "tv", "tv_set"], "image_count": 1900, "id": 1077, "frequency": "f", "synset": "television_receiver.n.01"}, {"name": "tennis_ball", "instance_count": 2835, "def": "ball about the size of a fist used in playing tennis", "synonyms": ["tennis_ball"], "image_count": 1302, "id": 1078, "frequency": "f", "synset": "tennis_ball.n.01"}, {"name": "tennis_racket", "instance_count": 3035, "def": "a racket used to play tennis", "synonyms": ["tennis_racket"], "image_count": 1977, "id": 1079, "frequency": "f", "synset": "tennis_racket.n.01"}, {"name": "tequila", "instance_count": 2, "def": "Mexican liquor made from fermented juices of an agave plant", "synonyms": ["tequila"], "image_count": 2, "id": 1080, "frequency": "r", "synset": "tequila.n.01"}, {"name": "thermometer", "instance_count": 33, "def": "measuring instrument for measuring temperature", "synonyms": ["thermometer"], "image_count": 29, "id": 1081, "frequency": "c", "synset": "thermometer.n.01"}, {"name": "thermos_bottle", "instance_count": 49, "def": "vacuum flask that preserves temperature of hot or cold drinks", "synonyms": ["thermos_bottle"], "image_count": 36, "id": 1082, "frequency": "c", "synset": "thermos.n.01"}, {"name": "thermostat", "instance_count": 153, "def": "a regulator for automatically regulating temperature by starting or stopping the supply of heat", "synonyms": ["thermostat"], "image_count": 138, "id": 1083, "frequency": "f", "synset": "thermostat.n.01"}, {"name": "thimble", "instance_count": 6, "def": "a small metal cap to protect the finger while sewing; can be used as a small container", "synonyms": ["thimble"], "image_count": 4, "id": 1084, "frequency": "r", "synset": "thimble.n.02"}, {"name": "thread", "instance_count": 320, "def": "a fine cord of twisted fibers (of cotton or silk or wool or nylon etc.) used in sewing and weaving", "synonyms": ["thread", "yarn"], "image_count": 67, "id": 1085, "frequency": "c", "synset": "thread.n.01"}, {"name": "thumbtack", "instance_count": 224, "def": "a tack for attaching papers to a bulletin board or drawing board", "synonyms": ["thumbtack", "drawing_pin", "pushpin"], "image_count": 26, "id": 1086, "frequency": "c", "synset": "thumbtack.n.01"}, {"name": "tiara", "instance_count": 31, "def": "a jeweled headdress worn by women on formal occasions", "synonyms": ["tiara"], "image_count": 25, "id": 1087, "frequency": "c", "synset": "tiara.n.01"}, {"name": "tiger", "instance_count": 67, "def": "large feline of forests in most of Asia having a tawny coat with black stripes", "synonyms": ["tiger"], "image_count": 33, "id": 1088, "frequency": "c", "synset": "tiger.n.02"}, {"name": "tights_(clothing)", "instance_count": 45, "def": "skintight knit hose covering the body from the waist to the feet worn by acrobats and dancers and as stockings by women and girls", "synonyms": ["tights_(clothing)", "leotards"], "image_count": 37, "id": 1089, "frequency": "c", "synset": "tights.n.01"}, {"name": "timer", "instance_count": 62, "def": "a timepiece that measures a time interval and signals its end", "synonyms": ["timer", "stopwatch"], "image_count": 50, "id": 1090, "frequency": "c", "synset": "timer.n.01"}, {"name": "tinfoil", "instance_count": 421, "def": "foil made of tin or an alloy of tin and lead", "synonyms": ["tinfoil"], "image_count": 270, "id": 1091, "frequency": "f", "synset": "tinfoil.n.01"}, {"name": "tinsel", "instance_count": 70, "def": "a showy decoration that is basically valueless", "synonyms": ["tinsel"], "image_count": 12, "id": 1092, "frequency": "c", "synset": "tinsel.n.01"}, {"name": "tissue_paper", "instance_count": 587, "def": "a soft thin (usually translucent) paper", "synonyms": ["tissue_paper"], "image_count": 316, "id": 1093, "frequency": "f", "synset": "tissue.n.02"}, {"name": "toast_(food)", "instance_count": 125, "def": "slice of bread that has been toasted", "synonyms": ["toast_(food)"], "image_count": 41, "id": 1094, "frequency": "c", "synset": "toast.n.01"}, {"name": "toaster", "instance_count": 240, "def": "a kitchen appliance (usually electric) for toasting bread", "synonyms": ["toaster"], "image_count": 224, "id": 1095, "frequency": "f", "synset": "toaster.n.02"}, {"name": "toaster_oven", "instance_count": 114, "def": "kitchen appliance consisting of a small electric oven for toasting or warming food", "synonyms": ["toaster_oven"], "image_count": 105, "id": 1096, "frequency": "f", "synset": "toaster_oven.n.01"}, {"name": "toilet", "instance_count": 2295, "def": "a plumbing fixture for defecation and urination", "synonyms": ["toilet"], "image_count": 1925, "id": 1097, "frequency": "f", "synset": "toilet.n.02"}, {"name": "toilet_tissue", "instance_count": 1683, "def": "a soft thin absorbent paper for use in toilets", "synonyms": ["toilet_tissue", "toilet_paper", "bathroom_tissue"], "image_count": 1021, "id": 1098, "frequency": "f", "synset": "toilet_tissue.n.01"}, {"name": "tomato", "instance_count": 12338, "def": "mildly acid red or yellow pulpy fruit eaten as a vegetable", "synonyms": ["tomato"], "image_count": 1213, "id": 1099, "frequency": "f", "synset": "tomato.n.01"}, {"name": "tongs", "instance_count": 294, "def": "any of various devices for taking hold of objects; usually have two hinged legs with handles above and pointed hooks below", "synonyms": ["tongs"], "image_count": 172, "id": 1100, "frequency": "f", "synset": "tongs.n.01"}, {"name": "toolbox", "instance_count": 39, "def": "a box or chest or cabinet for holding hand tools", "synonyms": ["toolbox"], "image_count": 28, "id": 1101, "frequency": "c", "synset": "toolbox.n.01"}, {"name": "toothbrush", "instance_count": 1683, "def": "small brush; has long handle; used to clean teeth", "synonyms": ["toothbrush"], "image_count": 745, "id": 1102, "frequency": "f", "synset": "toothbrush.n.01"}, {"name": "toothpaste", "instance_count": 326, "def": "a dentifrice in the form of a paste", "synonyms": ["toothpaste"], "image_count": 187, "id": 1103, "frequency": "f", "synset": "toothpaste.n.01"}, {"name": "toothpick", "instance_count": 423, "def": "pick consisting of a small strip of wood or plastic; used to pick food from between the teeth", "synonyms": ["toothpick"], "image_count": 147, "id": 1104, "frequency": "f", "synset": "toothpick.n.01"}, {"name": "cover", "instance_count": 306, "def": "covering for a hole (especially a hole in the top of a container)", "synonyms": ["cover"], "image_count": 136, "id": 1105, "frequency": "f", "synset": "top.n.09"}, {"name": "tortilla", "instance_count": 135, "def": "thin unleavened pancake made from cornmeal or wheat flour", "synonyms": ["tortilla"], "image_count": 34, "id": 1106, "frequency": "c", "synset": "tortilla.n.01"}, {"name": "tow_truck", "instance_count": 45, "def": "a truck equipped to hoist and pull wrecked cars (or to remove cars from no-parking zones)", "synonyms": ["tow_truck"], "image_count": 41, "id": 1107, "frequency": "c", "synset": "tow_truck.n.01"}, {"name": "towel", "instance_count": 2212, "def": "a rectangular piece of absorbent cloth (or paper) for drying or wiping", "synonyms": ["towel"], "image_count": 636, "id": 1108, "frequency": "f", "synset": "towel.n.01"}, {"name": "towel_rack", "instance_count": 987, "def": "a rack consisting of one or more bars on which towels can be hung", "synonyms": ["towel_rack", "towel_rail", "towel_bar"], "image_count": 570, "id": 1109, "frequency": "f", "synset": "towel_rack.n.01"}, {"name": "toy", "instance_count": 6756, "def": "a device regarded as providing amusement", "synonyms": ["toy"], "image_count": 1149, "id": 1110, "frequency": "f", "synset": "toy.n.03"}, {"name": "tractor_(farm_equipment)", "instance_count": 80, "def": "a wheeled vehicle with large wheels; used in farming and other applications", "synonyms": ["tractor_(farm_equipment)"], "image_count": 61, "id": 1111, "frequency": "c", "synset": "tractor.n.01"}, {"name": "traffic_light", "instance_count": 7298, "def": "a device to control vehicle traffic often consisting of three or more lights", "synonyms": ["traffic_light"], "image_count": 1890, "id": 1112, "frequency": "f", "synset": "traffic_light.n.01"}, {"name": "dirt_bike", "instance_count": 47, "def": "a lightweight motorcycle equipped with rugged tires and suspension for off-road use", "synonyms": ["dirt_bike"], "image_count": 18, "id": 1113, "frequency": "c", "synset": "trail_bike.n.01"}, {"name": "trailer_truck", "instance_count": 297, "def": "a truck consisting of a tractor and trailer together", "synonyms": ["trailer_truck", "tractor_trailer", "trucking_rig", "articulated_lorry", "semi_truck"], "image_count": 143, "id": 1114, "frequency": "f", "synset": "trailer_truck.n.01"}, {"name": "train_(railroad_vehicle)", "instance_count": 2192, "def": "public or private transport provided by a line of railway cars coupled together and drawn by a locomotive", "synonyms": ["train_(railroad_vehicle)", "railroad_train"], "image_count": 1517, "id": 1115, "frequency": "f", "synset": "train.n.01"}, {"name": "trampoline", "instance_count": 7, "def": "gymnastic apparatus consisting of a strong canvas sheet attached with springs to a metal frame", "synonyms": ["trampoline"], "image_count": 7, "id": 1116, "frequency": "r", "synset": "trampoline.n.01"}, {"name": "tray", "instance_count": 2397, "def": "an open receptacle for holding or displaying or serving articles or food", "synonyms": ["tray"], "image_count": 943, "id": 1117, "frequency": "f", "synset": "tray.n.01"}, {"name": "trench_coat", "instance_count": 16, "def": "a military style raincoat; belted with deep pockets", "synonyms": ["trench_coat"], "image_count": 6, "id": 1118, "frequency": "r", "synset": "trench_coat.n.01"}, {"name": "triangle_(musical_instrument)", "instance_count": 1, "def": "a percussion instrument consisting of a metal bar bent in the shape of an open triangle", "synonyms": ["triangle_(musical_instrument)"], "image_count": 1, "id": 1119, "frequency": "r", "synset": "triangle.n.05"}, {"name": "tricycle", "instance_count": 15, "def": "a vehicle with three wheels that is moved by foot pedals", "synonyms": ["tricycle"], "image_count": 11, "id": 1120, "frequency": "c", "synset": "tricycle.n.01"}, {"name": "tripod", "instance_count": 132, "def": "a three-legged rack used for support", "synonyms": ["tripod"], "image_count": 101, "id": 1121, "frequency": "f", "synset": "tripod.n.01"}, {"name": "trousers", "instance_count": 7806, "def": "a garment extending from the waist to the knee or ankle, covering each leg separately", "synonyms": ["trousers", "pants_(clothing)"], "image_count": 1909, "id": 1122, "frequency": "f", "synset": "trouser.n.01"}, {"name": "truck", "instance_count": 1797, "def": "an automotive vehicle suitable for hauling", "synonyms": ["truck"], "image_count": 800, "id": 1123, "frequency": "f", "synset": "truck.n.01"}, {"name": "truffle_(chocolate)", "instance_count": 4, "def": "creamy chocolate candy", "synonyms": ["truffle_(chocolate)", "chocolate_truffle"], "image_count": 1, "id": 1124, "frequency": "r", "synset": "truffle.n.03"}, {"name": "trunk", "instance_count": 334, "def": "luggage consisting of a large strong case used when traveling or for storage", "synonyms": ["trunk"], "image_count": 44, "id": 1125, "frequency": "c", "synset": "trunk.n.02"}, {"name": "vat", "instance_count": 15, "def": "a large vessel for holding or storing liquids", "synonyms": ["vat"], "image_count": 3, "id": 1126, "frequency": "r", "synset": "tub.n.02"}, {"name": "turban", "instance_count": 124, "def": "a traditional headdress consisting of a long scarf wrapped around the head", "synonyms": ["turban"], "image_count": 44, "id": 1127, "frequency": "c", "synset": "turban.n.01"}, {"name": "turkey_(food)", "instance_count": 120, "def": "flesh of large domesticated fowl usually roasted", "synonyms": ["turkey_(food)"], "image_count": 31, "id": 1128, "frequency": "c", "synset": "turkey.n.04"}, {"name": "turnip", "instance_count": 109, "def": "widely cultivated plant having a large fleshy edible white or yellow root", "synonyms": ["turnip"], "image_count": 7, "id": 1129, "frequency": "r", "synset": "turnip.n.01"}, {"name": "turtle", "instance_count": 31, "def": "any of various aquatic and land reptiles having a bony shell and flipper-like limbs for swimming", "synonyms": ["turtle"], "image_count": 20, "id": 1130, "frequency": "c", "synset": "turtle.n.02"}, {"name": "turtleneck_(clothing)", "instance_count": 13, "def": "a sweater or jersey with a high close-fitting collar", "synonyms": ["turtleneck_(clothing)", "polo-neck"], "image_count": 11, "id": 1131, "frequency": "c", "synset": "turtleneck.n.01"}, {"name": "typewriter", "instance_count": 14, "def": "hand-operated character printer for printing written messages one character at a time", "synonyms": ["typewriter"], "image_count": 13, "id": 1132, "frequency": "c", "synset": "typewriter.n.01"}, {"name": "umbrella", "instance_count": 9161, "def": "a lightweight handheld collapsible canopy", "synonyms": ["umbrella"], "image_count": 1924, "id": 1133, "frequency": "f", "synset": "umbrella.n.01"}, {"name": "underwear", "instance_count": 164, "def": "undergarment worn next to the skin and under the outer garments", "synonyms": ["underwear", "underclothes", "underclothing", "underpants"], "image_count": 113, "id": 1134, "frequency": "f", "synset": "underwear.n.01"}, {"name": "unicycle", "instance_count": 2, "def": "a vehicle with a single wheel that is driven by pedals", "synonyms": ["unicycle"], "image_count": 2, "id": 1135, "frequency": "r", "synset": "unicycle.n.01"}, {"name": "urinal", "instance_count": 381, "def": "a plumbing fixture (usually attached to the wall) used by men to urinate", "synonyms": ["urinal"], "image_count": 139, "id": 1136, "frequency": "f", "synset": "urinal.n.01"}, {"name": "urn", "instance_count": 81, "def": "a large vase that usually has a pedestal or feet", "synonyms": ["urn"], "image_count": 12, "id": 1137, "frequency": "c", "synset": "urn.n.01"}, {"name": "vacuum_cleaner", "instance_count": 38, "def": "an electrical home appliance that cleans by suction", "synonyms": ["vacuum_cleaner"], "image_count": 37, "id": 1138, "frequency": "c", "synset": "vacuum.n.04"}, {"name": "vase", "instance_count": 4971, "def": "an open jar of glass or porcelain used as an ornament or to hold flowers", "synonyms": ["vase"], "image_count": 1866, "id": 1139, "frequency": "f", "synset": "vase.n.01"}, {"name": "vending_machine", "instance_count": 65, "def": "a slot machine for selling goods", "synonyms": ["vending_machine"], "image_count": 47, "id": 1140, "frequency": "c", "synset": "vending_machine.n.01"}, {"name": "vent", "instance_count": 3370, "def": "a hole for the escape of gas or air", "synonyms": ["vent", "blowhole", "air_vent"], "image_count": 1468, "id": 1141, "frequency": "f", "synset": "vent.n.01"}, {"name": "vest", "instance_count": 1313, "def": "a man's sleeveless garment worn underneath a coat", "synonyms": ["vest", "waistcoat"], "image_count": 729, "id": 1142, "frequency": "f", "synset": "vest.n.01"}, {"name": "videotape", "instance_count": 228, "def": "a video recording made on magnetic tape", "synonyms": ["videotape"], "image_count": 24, "id": 1143, "frequency": "c", "synset": "videotape.n.01"}, {"name": "vinegar", "instance_count": 1, "def": "sour-tasting liquid produced usually by oxidation of the alcohol in wine or cider and used as a condiment or food preservative", "synonyms": ["vinegar"], "image_count": 1, "id": 1144, "frequency": "r", "synset": "vinegar.n.01"}, {"name": "violin", "instance_count": 10, "def": "bowed stringed instrument that is the highest member of the violin family", "synonyms": ["violin", "fiddle"], "image_count": 10, "id": 1145, "frequency": "r", "synset": "violin.n.01"}, {"name": "vodka", "instance_count": 3, "def": "unaged colorless liquor originating in Russia", "synonyms": ["vodka"], "image_count": 3, "id": 1146, "frequency": "r", "synset": "vodka.n.01"}, {"name": "volleyball", "instance_count": 33, "def": "an inflated ball used in playing volleyball", "synonyms": ["volleyball"], "image_count": 14, "id": 1147, "frequency": "c", "synset": "volleyball.n.02"}, {"name": "vulture", "instance_count": 16, "def": "any of various large birds of prey having naked heads and weak claws and feeding chiefly on carrion", "synonyms": ["vulture"], "image_count": 4, "id": 1148, "frequency": "r", "synset": "vulture.n.01"}, {"name": "waffle", "instance_count": 61, "def": "pancake batter baked in a waffle iron", "synonyms": ["waffle"], "image_count": 29, "id": 1149, "frequency": "c", "synset": "waffle.n.01"}, {"name": "waffle_iron", "instance_count": 4, "def": "a kitchen appliance for baking waffles", "synonyms": ["waffle_iron"], "image_count": 4, "id": 1150, "frequency": "r", "synset": "waffle_iron.n.01"}, {"name": "wagon", "instance_count": 121, "def": "any of various kinds of wheeled vehicles drawn by an animal or a tractor", "synonyms": ["wagon"], "image_count": 70, "id": 1151, "frequency": "c", "synset": "wagon.n.01"}, {"name": "wagon_wheel", "instance_count": 209, "def": "a wheel of a wagon", "synonyms": ["wagon_wheel"], "image_count": 46, "id": 1152, "frequency": "c", "synset": "wagon_wheel.n.01"}, {"name": "walking_stick", "instance_count": 21, "def": "a stick carried in the hand for support in walking", "synonyms": ["walking_stick"], "image_count": 14, "id": 1153, "frequency": "c", "synset": "walking_stick.n.01"}, {"name": "wall_clock", "instance_count": 100, "def": "a clock mounted on a wall", "synonyms": ["wall_clock"], "image_count": 48, "id": 1154, "frequency": "c", "synset": "wall_clock.n.01"}, {"name": "wall_socket", "instance_count": 3069, "def": "receptacle providing a place in a wiring system where current can be taken to run electrical devices", "synonyms": ["wall_socket", "wall_plug", "electric_outlet", "electrical_outlet", "outlet", "electric_receptacle"], "image_count": 1855, "id": 1155, "frequency": "f", "synset": "wall_socket.n.01"}, {"name": "wallet", "instance_count": 123, "def": "a pocket-size case for holding papers and paper money", "synonyms": ["wallet", "billfold"], "image_count": 113, "id": 1156, "frequency": "f", "synset": "wallet.n.01"}, {"name": "walrus", "instance_count": 1, "def": "either of two large northern marine mammals having ivory tusks and tough hide over thick blubber", "synonyms": ["walrus"], "image_count": 1, "id": 1157, "frequency": "r", "synset": "walrus.n.01"}, {"name": "wardrobe", "instance_count": 1, "def": "a tall piece of furniture that provides storage space for clothes; has a door and rails or hooks for hanging clothes", "synonyms": ["wardrobe"], "image_count": 1, "id": 1158, "frequency": "r", "synset": "wardrobe.n.01"}, {"name": "washbasin", "instance_count": 15, "def": "a bathroom sink that is permanently installed and connected to a water supply and drainpipe; where you can wash your hands and face", "synonyms": ["washbasin", "basin_(for_washing)", "washbowl", "washstand", "handbasin"], "image_count": 10, "id": 1159, "frequency": "r", "synset": "washbasin.n.01"}, {"name": "automatic_washer", "instance_count": 68, "def": "a home appliance for washing clothes and linens automatically", "synonyms": ["automatic_washer", "washing_machine"], "image_count": 54, "id": 1160, "frequency": "c", "synset": "washer.n.03"}, {"name": "watch", "instance_count": 2703, "def": "a small, portable timepiece", "synonyms": ["watch", "wristwatch"], "image_count": 1923, "id": 1161, "frequency": "f", "synset": "watch.n.01"}, {"name": "water_bottle", "instance_count": 1449, "def": "a bottle for holding water", "synonyms": ["water_bottle"], "image_count": 630, "id": 1162, "frequency": "f", "synset": "water_bottle.n.01"}, {"name": "water_cooler", "instance_count": 39, "def": "a device for cooling and dispensing drinking water", "synonyms": ["water_cooler"], "image_count": 31, "id": 1163, "frequency": "c", "synset": "water_cooler.n.01"}, {"name": "water_faucet", "instance_count": 109, "def": "a faucet for drawing water from a pipe or cask", "synonyms": ["water_faucet", "water_tap", "tap_(water_faucet)"], "image_count": 69, "id": 1164, "frequency": "c", "synset": "water_faucet.n.01"}, {"name": "water_heater", "instance_count": 7, "def": "a heater and storage tank to supply heated water", "synonyms": ["water_heater", "hot-water_heater"], "image_count": 7, "id": 1165, "frequency": "r", "synset": "water_heater.n.01"}, {"name": "water_jug", "instance_count": 23, "def": "a jug that holds water", "synonyms": ["water_jug"], "image_count": 11, "id": 1166, "frequency": "c", "synset": "water_jug.n.01"}, {"name": "water_gun", "instance_count": 1, "def": "plaything consisting of a toy pistol that squirts water", "synonyms": ["water_gun", "squirt_gun"], "image_count": 1, "id": 1167, "frequency": "r", "synset": "water_pistol.n.01"}, {"name": "water_scooter", "instance_count": 54, "def": "a motorboat resembling a motor scooter (NOT A SURFBOARD OR WATER SKI)", "synonyms": ["water_scooter", "sea_scooter", "jet_ski"], "image_count": 30, "id": 1168, "frequency": "c", "synset": "water_scooter.n.01"}, {"name": "water_ski", "instance_count": 98, "def": "broad ski for skimming over water towed by a speedboat (DO NOT MARK WATER)", "synonyms": ["water_ski"], "image_count": 50, "id": 1169, "frequency": "c", "synset": "water_ski.n.01"}, {"name": "water_tower", "instance_count": 60, "def": "a large reservoir for water", "synonyms": ["water_tower"], "image_count": 45, "id": 1170, "frequency": "c", "synset": "water_tower.n.01"}, {"name": "watering_can", "instance_count": 44, "def": "a container with a handle and a spout with a perforated nozzle; used to sprinkle water over plants", "synonyms": ["watering_can"], "image_count": 28, "id": 1171, "frequency": "c", "synset": "watering_can.n.01"}, {"name": "watermelon", "instance_count": 814, "def": "large oblong or roundish melon with a hard green rind and sweet watery red or occasionally yellowish pulp", "synonyms": ["watermelon"], "image_count": 114, "id": 1172, "frequency": "f", "synset": "watermelon.n.02"}, {"name": "weathervane", "instance_count": 237, "def": "mechanical device attached to an elevated structure; rotates freely to show the direction of the wind", "synonyms": ["weathervane", "vane_(weathervane)", "wind_vane"], "image_count": 193, "id": 1173, "frequency": "f", "synset": "weathervane.n.01"}, {"name": "webcam", "instance_count": 27, "def": "a digital camera designed to take digital photographs and transmit them over the internet", "synonyms": ["webcam"], "image_count": 21, "id": 1174, "frequency": "c", "synset": "webcam.n.01"}, {"name": "wedding_cake", "instance_count": 140, "def": "a rich cake with two or more tiers and covered with frosting and decorations; served at a wedding reception", "synonyms": ["wedding_cake", "bridecake"], "image_count": 91, "id": 1175, "frequency": "c", "synset": "wedding_cake.n.01"}, {"name": "wedding_ring", "instance_count": 49, "def": "a ring given to the bride and/or groom at the wedding", "synonyms": ["wedding_ring", "wedding_band"], "image_count": 31, "id": 1176, "frequency": "c", "synset": "wedding_ring.n.01"}, {"name": "wet_suit", "instance_count": 2907, "def": "a close-fitting garment made of a permeable material; worn in cold water to retain body heat", "synonyms": ["wet_suit"], "image_count": 1469, "id": 1177, "frequency": "f", "synset": "wet_suit.n.01"}, {"name": "wheel", "instance_count": 11272, "def": "a circular frame with spokes (or a solid disc) that can rotate on a shaft or axle", "synonyms": ["wheel"], "image_count": 1924, "id": 1178, "frequency": "f", "synset": "wheel.n.01"}, {"name": "wheelchair", "instance_count": 107, "def": "a movable chair mounted on large wheels", "synonyms": ["wheelchair"], "image_count": 87, "id": 1179, "frequency": "c", "synset": "wheelchair.n.01"}, {"name": "whipped_cream", "instance_count": 201, "def": "cream that has been beaten until light and fluffy", "synonyms": ["whipped_cream"], "image_count": 77, "id": 1180, "frequency": "c", "synset": "whipped_cream.n.01"}, {"name": "whistle", "instance_count": 13, "def": "a small wind instrument that produces a whistling sound by blowing into it", "synonyms": ["whistle"], "image_count": 11, "id": 1181, "frequency": "c", "synset": "whistle.n.03"}, {"name": "wig", "instance_count": 69, "def": "hairpiece covering the head and made of real or synthetic hair", "synonyms": ["wig"], "image_count": 47, "id": 1182, "frequency": "c", "synset": "wig.n.01"}, {"name": "wind_chime", "instance_count": 28, "def": "a decorative arrangement of pieces of metal or glass or pottery that hang together loosely so the wind can cause them to tinkle", "synonyms": ["wind_chime"], "image_count": 21, "id": 1183, "frequency": "c", "synset": "wind_chime.n.01"}, {"name": "windmill", "instance_count": 202, "def": "A mill or turbine that is powered by wind", "synonyms": ["windmill"], "image_count": 47, "id": 1184, "frequency": "c", "synset": "windmill.n.01"}, {"name": "window_box_(for_plants)", "instance_count": 253, "def": "a container for growing plants on a windowsill", "synonyms": ["window_box_(for_plants)"], "image_count": 70, "id": 1185, "frequency": "c", "synset": "window_box.n.01"}, {"name": "windshield_wiper", "instance_count": 4793, "def": "a mechanical device that cleans the windshield", "synonyms": ["windshield_wiper", "windscreen_wiper", "wiper_(for_windshield/screen)"], "image_count": 1838, "id": 1186, "frequency": "f", "synset": "windshield_wiper.n.01"}, {"name": "windsock", "instance_count": 26, "def": "a truncated cloth cone mounted on a mast/pole; shows wind direction", "synonyms": ["windsock", "air_sock", "air-sleeve", "wind_sleeve", "wind_cone"], "image_count": 19, "id": 1187, "frequency": "c", "synset": "windsock.n.01"}, {"name": "wine_bottle", "instance_count": 4449, "def": "a bottle for holding wine", "synonyms": ["wine_bottle"], "image_count": 531, "id": 1188, "frequency": "f", "synset": "wine_bottle.n.01"}, {"name": "wine_bucket", "instance_count": 21, "def": "a bucket of ice used to chill a bottle of wine", "synonyms": ["wine_bucket", "wine_cooler"], "image_count": 11, "id": 1189, "frequency": "c", "synset": "wine_bucket.n.01"}, {"name": "wineglass", "instance_count": 4259, "def": "a glass that has a stem and in which wine is served", "synonyms": ["wineglass"], "image_count": 941, "id": 1190, "frequency": "f", "synset": "wineglass.n.01"}, {"name": "blinder_(for_horses)", "instance_count": 271, "def": "blinds that prevent a horse from seeing something on either side", "synonyms": ["blinder_(for_horses)"], "image_count": 113, "id": 1191, "frequency": "f", "synset": "winker.n.02"}, {"name": "wok", "instance_count": 60, "def": "pan with a convex bottom; used for frying in Chinese cooking", "synonyms": ["wok"], "image_count": 26, "id": 1192, "frequency": "c", "synset": "wok.n.01"}, {"name": "wolf", "instance_count": 16, "def": "a wild carnivorous mammal of the dog family, living and hunting in packs", "synonyms": ["wolf"], "image_count": 5, "id": 1193, "frequency": "r", "synset": "wolf.n.01"}, {"name": "wooden_spoon", "instance_count": 123, "def": "a spoon made of wood", "synonyms": ["wooden_spoon"], "image_count": 56, "id": 1194, "frequency": "c", "synset": "wooden_spoon.n.02"}, {"name": "wreath", "instance_count": 119, "def": "an arrangement of flowers, leaves, or stems fastened in a ring", "synonyms": ["wreath"], "image_count": 73, "id": 1195, "frequency": "c", "synset": "wreath.n.01"}, {"name": "wrench", "instance_count": 80, "def": "a hand tool that is used to hold or twist a nut or bolt", "synonyms": ["wrench", "spanner"], "image_count": 32, "id": 1196, "frequency": "c", "synset": "wrench.n.03"}, {"name": "wristband", "instance_count": 268, "def": "band consisting of a part of a sleeve that covers the wrist", "synonyms": ["wristband"], "image_count": 128, "id": 1197, "frequency": "f", "synset": "wristband.n.01"}, {"name": "wristlet", "instance_count": 1330, "def": "a band or bracelet worn around the wrist", "synonyms": ["wristlet", "wrist_band"], "image_count": 623, "id": 1198, "frequency": "f", "synset": "wristlet.n.01"}, {"name": "yacht", "instance_count": 50, "def": "an expensive vessel propelled by sail or power and used for cruising or racing", "synonyms": ["yacht"], "image_count": 12, "id": 1199, "frequency": "c", "synset": "yacht.n.01"}, {"name": "yogurt", "instance_count": 116, "def": "a custard-like food made from curdled milk", "synonyms": ["yogurt", "yoghurt", "yoghourt"], "image_count": 52, "id": 1200, "frequency": "c", "synset": "yogurt.n.01"}, {"name": "yoke_(animal_equipment)", "instance_count": 20, "def": "gear joining two animals at the neck; NOT egg yolk", "synonyms": ["yoke_(animal_equipment)"], "image_count": 11, "id": 1201, "frequency": "c", "synset": "yoke.n.07"}, {"name": "zebra", "instance_count": 5443, "def": "any of several fleet black-and-white striped African equines", "synonyms": ["zebra"], "image_count": 1674, "id": 1202, "frequency": "f", "synset": "zebra.n.01"}, {"name": "zucchini", "instance_count": 798, "def": "small cucumber-shaped vegetable marrow; typically dark green", "synonyms": ["zucchini", "courgette"], "image_count": 81, "id": 1203, "frequency": "c", "synset": "zucchini.n.02"}] \ No newline at end of file diff --git a/dimos/models/Detic/datasets/metadata/o365_clip_a+cnamefix.npy b/dimos/models/Detic/datasets/metadata/o365_clip_a+cnamefix.npy new file mode 100644 index 0000000000..64a2e43c4b Binary files /dev/null and b/dimos/models/Detic/datasets/metadata/o365_clip_a+cnamefix.npy differ diff --git a/dimos/models/Detic/datasets/metadata/oid_clip_a+cname.npy.REMOVED.git-id b/dimos/models/Detic/datasets/metadata/oid_clip_a+cname.npy.REMOVED.git-id new file mode 100644 index 0000000000..2e1266c9d5 --- /dev/null +++ b/dimos/models/Detic/datasets/metadata/oid_clip_a+cname.npy.REMOVED.git-id @@ -0,0 +1 @@ +1a2c953b8d55d0e6bc09e98623a5243973c285ed \ No newline at end of file diff --git a/dimos/models/Detic/demo.py b/dimos/models/Detic/demo.py new file mode 100755 index 0000000000..e982f745a5 --- /dev/null +++ b/dimos/models/Detic/demo.py @@ -0,0 +1,228 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import argparse +import glob +import multiprocessing as mp +import os +import sys +import tempfile +import time +import warnings + +import cv2 +from detectron2.config import get_cfg +from detectron2.data.detection_utils import read_image +from detectron2.utils.logger import setup_logger +import mss +import numpy as np +import tqdm + +sys.path.insert(0, "third_party/CenterNet2/") +from centernet.config import add_centernet_config +from detic.config import add_detic_config +from detic.predictor import VisualizationDemo + + +# Fake a video capture object OpenCV style - half width, half height of first screen using MSS +class ScreenGrab: + def __init__(self) -> None: + self.sct = mss.mss() + m0 = self.sct.monitors[0] + self.monitor = {"top": 0, "left": 0, "width": m0["width"] / 2, "height": m0["height"] / 2} + + def read(self): + img = np.array(self.sct.grab(self.monitor)) + nf = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR) + return (True, nf) + + def isOpened(self) -> bool: + return True + + def release(self) -> bool: + return True + + +# constants +WINDOW_NAME = "Detic" + + +def setup_cfg(args): + cfg = get_cfg() + if args.cpu: + cfg.MODEL.DEVICE = "cpu" + add_centernet_config(cfg) + add_detic_config(cfg) + cfg.merge_from_file(args.config_file) + cfg.merge_from_list(args.opts) + # Set score_threshold for builtin models + cfg.MODEL.RETINANET.SCORE_THRESH_TEST = args.confidence_threshold + cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = args.confidence_threshold + cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = args.confidence_threshold + cfg.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_PATH = "rand" # load later + if not args.pred_all_class: + cfg.MODEL.ROI_HEADS.ONE_CLASS_PER_PROPOSAL = True + cfg.freeze() + return cfg + + +def get_parser(): + parser = argparse.ArgumentParser(description="Detectron2 demo for builtin configs") + parser.add_argument( + "--config-file", + default="configs/quick_schedules/mask_rcnn_R_50_FPN_inference_acc_test.yaml", + metavar="FILE", + help="path to config file", + ) + parser.add_argument("--webcam", help="Take inputs from webcam.") + parser.add_argument("--cpu", action="store_true", help="Use CPU only.") + parser.add_argument("--video-input", help="Path to video file.") + parser.add_argument( + "--input", + nargs="+", + help="A list of space separated input images; or a single glob pattern such as 'directory/*.jpg'", + ) + parser.add_argument( + "--output", + help="A file or directory to save output visualizations. If not given, will show output in an OpenCV window.", + ) + parser.add_argument( + "--vocabulary", + default="lvis", + choices=["lvis", "openimages", "objects365", "coco", "custom"], + help="", + ) + parser.add_argument( + "--custom_vocabulary", + default="", + help="", + ) + parser.add_argument("--pred_all_class", action="store_true") + parser.add_argument( + "--confidence-threshold", + type=float, + default=0.5, + help="Minimum score for instance predictions to be shown", + ) + parser.add_argument( + "--opts", + help="Modify config options using the command-line 'KEY VALUE' pairs", + default=[], + nargs=argparse.REMAINDER, + ) + return parser + + +def test_opencv_video_format(codec, file_ext) -> bool: + with tempfile.TemporaryDirectory(prefix="video_format_test") as dir: + filename = os.path.join(dir, "test_file" + file_ext) + writer = cv2.VideoWriter( + filename=filename, + fourcc=cv2.VideoWriter_fourcc(*codec), + fps=float(30), + frameSize=(10, 10), + isColor=True, + ) + [writer.write(np.zeros((10, 10, 3), np.uint8)) for _ in range(30)] + writer.release() + if os.path.isfile(filename): + return True + return False + + +if __name__ == "__main__": + mp.set_start_method("spawn", force=True) + args = get_parser().parse_args() + setup_logger(name="fvcore") + logger = setup_logger() + logger.info("Arguments: " + str(args)) + + cfg = setup_cfg(args) + + demo = VisualizationDemo(cfg, args) + + if args.input: + if len(args.input) == 1: + args.input = glob.glob(os.path.expanduser(args.input[0])) + assert args.input, "The input path(s) was not found" + for path in tqdm.tqdm(args.input, disable=not args.output): + img = read_image(path, format="BGR") + start_time = time.time() + predictions, visualized_output = demo.run_on_image(img) + logger.info( + "{}: {} in {:.2f}s".format( + path, + "detected {} instances".format(len(predictions["instances"])) + if "instances" in predictions + else "finished", + time.time() - start_time, + ) + ) + + if args.output: + if os.path.isdir(args.output): + assert os.path.isdir(args.output), args.output + out_filename = os.path.join(args.output, os.path.basename(path)) + else: + assert len(args.input) == 1, "Please specify a directory with args.output" + out_filename = args.output + visualized_output.save(out_filename) + else: + cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL) + cv2.imshow(WINDOW_NAME, visualized_output.get_image()[:, :, ::-1]) + if cv2.waitKey(0) == 27: + break # esc to quit + elif args.webcam: + assert args.input is None, "Cannot have both --input and --webcam!" + assert args.output is None, "output not yet supported with --webcam!" + if args.webcam == "screen": + cam = ScreenGrab() + else: + cam = cv2.VideoCapture(int(args.webcam)) + for vis in tqdm.tqdm(demo.run_on_video(cam)): + cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL) + cv2.imshow(WINDOW_NAME, vis) + if cv2.waitKey(1) == 27: + break # esc to quit + cam.release() + cv2.destroyAllWindows() + elif args.video_input: + video = cv2.VideoCapture(args.video_input) + width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)) + frames_per_second = video.get(cv2.CAP_PROP_FPS) + num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) + basename = os.path.basename(args.video_input) + codec, file_ext = ( + ("x264", ".mkv") if test_opencv_video_format("x264", ".mkv") else ("mp4v", ".mp4") + ) + if codec == ".mp4v": + warnings.warn("x264 codec not available, switching to mp4v", stacklevel=2) + if args.output: + if os.path.isdir(args.output): + output_fname = os.path.join(args.output, basename) + output_fname = os.path.splitext(output_fname)[0] + file_ext + else: + output_fname = args.output + assert not os.path.isfile(output_fname), output_fname + output_file = cv2.VideoWriter( + filename=output_fname, + # some installation of opencv may not support x264 (due to its license), + # you can try other format (e.g. MPEG) + fourcc=cv2.VideoWriter_fourcc(*codec), + fps=float(frames_per_second), + frameSize=(width, height), + isColor=True, + ) + assert os.path.isfile(args.video_input) + for vis_frame in tqdm.tqdm(demo.run_on_video(video), total=num_frames): + if args.output: + output_file.write(vis_frame) + else: + cv2.namedWindow(basename, cv2.WINDOW_NORMAL) + cv2.imshow(basename, vis_frame) + if cv2.waitKey(1) == 27: + break # esc to quit + video.release() + if args.output: + output_file.release() + else: + cv2.destroyAllWindows() diff --git a/dimos/models/Detic/detic/__init__.py b/dimos/models/Detic/detic/__init__.py new file mode 100644 index 0000000000..2f8aa0a44e --- /dev/null +++ b/dimos/models/Detic/detic/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from .data.datasets import cc, coco_zeroshot, imagenet, lvis_v1, objects365, oid +from .modeling.backbone import swintransformer, timm +from .modeling.meta_arch import custom_rcnn +from .modeling.roi_heads import detic_roi_heads, res5_roi_heads + +try: + from .modeling.meta_arch import d2_deformable_detr +except: + pass diff --git a/dimos/models/Detic/detic/config.py b/dimos/models/Detic/detic/config.py new file mode 100644 index 0000000000..c053f0bd06 --- /dev/null +++ b/dimos/models/Detic/detic/config.py @@ -0,0 +1,134 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from detectron2.config import CfgNode as CN + + +def add_detic_config(cfg) -> None: + _C = cfg + + _C.WITH_IMAGE_LABELS = False # Turn on co-training with classification data + + # Open-vocabulary classifier + _C.MODEL.ROI_BOX_HEAD.USE_ZEROSHOT_CLS = ( + False # Use fixed classifier for open-vocabulary detection + ) + _C.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_PATH = "datasets/metadata/lvis_v1_clip_a+cname.npy" + _C.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_DIM = 512 + _C.MODEL.ROI_BOX_HEAD.NORM_WEIGHT = True + _C.MODEL.ROI_BOX_HEAD.NORM_TEMP = 50.0 + _C.MODEL.ROI_BOX_HEAD.IGNORE_ZERO_CATS = False + _C.MODEL.ROI_BOX_HEAD.USE_BIAS = 0.0 # >= 0: not use + + _C.MODEL.ROI_BOX_HEAD.MULT_PROPOSAL_SCORE = False # CenterNet2 + _C.MODEL.ROI_BOX_HEAD.USE_SIGMOID_CE = False + _C.MODEL.ROI_BOX_HEAD.PRIOR_PROB = 0.01 + _C.MODEL.ROI_BOX_HEAD.USE_FED_LOSS = False # Federated Loss + _C.MODEL.ROI_BOX_HEAD.CAT_FREQ_PATH = "datasets/metadata/lvis_v1_train_cat_info.json" + _C.MODEL.ROI_BOX_HEAD.FED_LOSS_NUM_CAT = 50 + _C.MODEL.ROI_BOX_HEAD.FED_LOSS_FREQ_WEIGHT = 0.5 + + # Classification data configs + _C.MODEL.ROI_BOX_HEAD.IMAGE_LABEL_LOSS = "max_size" # max, softmax, sum + _C.MODEL.ROI_BOX_HEAD.IMAGE_LOSS_WEIGHT = 0.1 + _C.MODEL.ROI_BOX_HEAD.IMAGE_BOX_SIZE = 1.0 + _C.MODEL.ROI_BOX_HEAD.ADD_IMAGE_BOX = False # Used for image-box loss and caption loss + _C.MODEL.ROI_BOX_HEAD.WS_NUM_PROPS = 128 # num proposals for image-labeled data + _C.MODEL.ROI_BOX_HEAD.WITH_SOFTMAX_PROP = False # Used for WSDDN + _C.MODEL.ROI_BOX_HEAD.CAPTION_WEIGHT = 1.0 # Caption loss weight + _C.MODEL.ROI_BOX_HEAD.NEG_CAP_WEIGHT = 0.125 # Caption loss hyper-parameter + _C.MODEL.ROI_BOX_HEAD.ADD_FEATURE_TO_PROP = False # Used for WSDDN + _C.MODEL.ROI_BOX_HEAD.SOFTMAX_WEAK_LOSS = False # Used when USE_SIGMOID_CE is False + + _C.MODEL.ROI_HEADS.MASK_WEIGHT = 1.0 + _C.MODEL.ROI_HEADS.ONE_CLASS_PER_PROPOSAL = False # For demo only + + # Caption losses + _C.MODEL.CAP_BATCH_RATIO = 4 # Ratio between detection data and caption data + _C.MODEL.WITH_CAPTION = False + _C.MODEL.SYNC_CAPTION_BATCH = False # synchronize across GPUs to enlarge # "classes" + + # dynamic class sampling when training with 21K classes + _C.MODEL.DYNAMIC_CLASSIFIER = False + _C.MODEL.NUM_SAMPLE_CATS = 50 + + # Different classifiers in testing, used in cross-dataset evaluation + _C.MODEL.RESET_CLS_TESTS = False + _C.MODEL.TEST_CLASSIFIERS = [] + _C.MODEL.TEST_NUM_CLASSES = [] + + # Backbones + _C.MODEL.SWIN = CN() + _C.MODEL.SWIN.SIZE = "T" # 'T', 'S', 'B' + _C.MODEL.SWIN.USE_CHECKPOINT = False + _C.MODEL.SWIN.OUT_FEATURES = (1, 2, 3) # FPN stride 8 - 32 + + _C.MODEL.TIMM = CN() + _C.MODEL.TIMM.BASE_NAME = "resnet50" + _C.MODEL.TIMM.OUT_LEVELS = (3, 4, 5) + _C.MODEL.TIMM.NORM = "FrozenBN" + _C.MODEL.TIMM.FREEZE_AT = 0 + _C.MODEL.TIMM.PRETRAINED = False + _C.MODEL.DATASET_LOSS_WEIGHT = [] + + # Multi-dataset dataloader + _C.DATALOADER.DATASET_RATIO = [1, 1] # sample ratio + _C.DATALOADER.USE_RFS = [False, False] + _C.DATALOADER.MULTI_DATASET_GROUPING = False # Always true when multi-dataset is enabled + _C.DATALOADER.DATASET_ANN = ["box", "box"] # Annotation type of each dataset + _C.DATALOADER.USE_DIFF_BS_SIZE = False # Use different batchsize for each dataset + _C.DATALOADER.DATASET_BS = [8, 32] # Used when USE_DIFF_BS_SIZE is on + _C.DATALOADER.DATASET_INPUT_SIZE = [896, 384] # Used when USE_DIFF_BS_SIZE is on + _C.DATALOADER.DATASET_INPUT_SCALE = [(0.1, 2.0), (0.5, 1.5)] # Used when USE_DIFF_BS_SIZE is on + _C.DATALOADER.DATASET_MIN_SIZES = [(640, 800), (320, 400)] # Used when USE_DIFF_BS_SIZE is on + _C.DATALOADER.DATASET_MAX_SIZES = [1333, 667] # Used when USE_DIFF_BS_SIZE is on + _C.DATALOADER.USE_TAR_DATASET = False # for ImageNet-21K, directly reading from unziped files + _C.DATALOADER.TARFILE_PATH = "datasets/imagenet/metadata-22k/tar_files.npy" + _C.DATALOADER.TAR_INDEX_DIR = "datasets/imagenet/metadata-22k/tarindex_npy" + + _C.SOLVER.USE_CUSTOM_SOLVER = False + _C.SOLVER.OPTIMIZER = "SGD" + _C.SOLVER.BACKBONE_MULTIPLIER = 1.0 # Used in DETR + _C.SOLVER.CUSTOM_MULTIPLIER = 1.0 # Used in DETR + _C.SOLVER.CUSTOM_MULTIPLIER_NAME = [] # Used in DETR + + # Deformable DETR + _C.MODEL.DETR = CN() + _C.MODEL.DETR.NUM_CLASSES = 80 + _C.MODEL.DETR.FROZEN_WEIGHTS = "" # For Segmentation + _C.MODEL.DETR.GIOU_WEIGHT = 2.0 + _C.MODEL.DETR.L1_WEIGHT = 5.0 + _C.MODEL.DETR.DEEP_SUPERVISION = True + _C.MODEL.DETR.NO_OBJECT_WEIGHT = 0.1 + _C.MODEL.DETR.CLS_WEIGHT = 2.0 + _C.MODEL.DETR.NUM_FEATURE_LEVELS = 4 + _C.MODEL.DETR.TWO_STAGE = False + _C.MODEL.DETR.WITH_BOX_REFINE = False + _C.MODEL.DETR.FOCAL_ALPHA = 0.25 + _C.MODEL.DETR.NHEADS = 8 + _C.MODEL.DETR.DROPOUT = 0.1 + _C.MODEL.DETR.DIM_FEEDFORWARD = 2048 + _C.MODEL.DETR.ENC_LAYERS = 6 + _C.MODEL.DETR.DEC_LAYERS = 6 + _C.MODEL.DETR.PRE_NORM = False + _C.MODEL.DETR.HIDDEN_DIM = 256 + _C.MODEL.DETR.NUM_OBJECT_QUERIES = 100 + + _C.MODEL.DETR.USE_FED_LOSS = False + _C.MODEL.DETR.WEAK_WEIGHT = 0.1 + + _C.INPUT.CUSTOM_AUG = "" + _C.INPUT.TRAIN_SIZE = 640 + _C.INPUT.TEST_SIZE = 640 + _C.INPUT.SCALE_RANGE = (0.1, 2.0) + # 'default' for fixed short/ long edge, 'square' for max size=INPUT.SIZE + _C.INPUT.TEST_INPUT_TYPE = "default" + + _C.FIND_UNUSED_PARAM = True + _C.EVAL_PRED_AR = False + _C.EVAL_PROPOSAL_AR = False + _C.EVAL_CAT_SPEC_AR = False + _C.IS_DEBUG = False + _C.QUICK_DEBUG = False + _C.FP16 = False + _C.EVAL_AP_FIX = False + _C.GEN_PSEDO_LABELS = False + _C.SAVE_DEBUG_PATH = "output/save_debug/" diff --git a/dimos/models/Detic/detic/custom_solver.py b/dimos/models/Detic/detic/custom_solver.py new file mode 100644 index 0000000000..a552dea0f1 --- /dev/null +++ b/dimos/models/Detic/detic/custom_solver.py @@ -0,0 +1,75 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +import itertools +from typing import Any, Dict, List, Set + +from detectron2.config import CfgNode +from detectron2.solver.build import maybe_add_gradient_clipping +import torch + + +def match_name_keywords(n, name_keywords): + out = False + for b in name_keywords: + if b in n: + out = True + break + return out + + +def build_custom_optimizer(cfg: CfgNode, model: torch.nn.Module) -> torch.optim.Optimizer: + """ + Build an optimizer from config. + """ + params: list[dict[str, Any]] = [] + memo: set[torch.nn.parameter.Parameter] = set() + custom_multiplier_name = cfg.SOLVER.CUSTOM_MULTIPLIER_NAME + optimizer_type = cfg.SOLVER.OPTIMIZER + for key, value in model.named_parameters(recurse=True): + if not value.requires_grad: + continue + # Avoid duplicating parameters + if value in memo: + continue + memo.add(value) + lr = cfg.SOLVER.BASE_LR + weight_decay = cfg.SOLVER.WEIGHT_DECAY + if "backbone" in key: + lr = lr * cfg.SOLVER.BACKBONE_MULTIPLIER + if match_name_keywords(key, custom_multiplier_name): + lr = lr * cfg.SOLVER.CUSTOM_MULTIPLIER + print("Costum LR", key, lr) + param = {"params": [value], "lr": lr} + if optimizer_type != "ADAMW": + param["weight_decay"] = weight_decay + params += [param] + + def maybe_add_full_model_gradient_clipping(optim): # optim: the optimizer class + # detectron2 doesn't have full model gradient clipping now + clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE + enable = ( + cfg.SOLVER.CLIP_GRADIENTS.ENABLED + and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model" + and clip_norm_val > 0.0 + ) + + class FullModelGradientClippingOptimizer(optim): + def step(self, closure=None) -> None: + all_params = itertools.chain(*[x["params"] for x in self.param_groups]) + torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val) + super().step(closure=closure) + + return FullModelGradientClippingOptimizer if enable else optim + + if optimizer_type == "SGD": + optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)( + params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM, nesterov=cfg.SOLVER.NESTEROV + ) + elif optimizer_type == "ADAMW": + optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)( + params, cfg.SOLVER.BASE_LR, weight_decay=cfg.SOLVER.WEIGHT_DECAY + ) + else: + raise NotImplementedError(f"no optimizer type {optimizer_type}") + if not cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model": + optimizer = maybe_add_gradient_clipping(cfg, optimizer) + return optimizer diff --git a/dimos/models/Detic/detic/data/custom_build_augmentation.py b/dimos/models/Detic/detic/data/custom_build_augmentation.py new file mode 100644 index 0000000000..5a6049ae02 --- /dev/null +++ b/dimos/models/Detic/detic/data/custom_build_augmentation.py @@ -0,0 +1,47 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + + +from detectron2.data import transforms as T + +from .transforms.custom_augmentation_impl import EfficientDetResizeCrop +from typing import Optional + + +def build_custom_augmentation(cfg, is_train: bool, scale=None, size: Optional[int]=None, min_size: Optional[int]=None, max_size: Optional[int]=None): + """ + Create a list of default :class:`Augmentation` from config. + Now it includes resizing and flipping. + + Returns: + list[Augmentation] + """ + if cfg.INPUT.CUSTOM_AUG == "ResizeShortestEdge": + if is_train: + min_size = cfg.INPUT.MIN_SIZE_TRAIN if min_size is None else min_size + max_size = cfg.INPUT.MAX_SIZE_TRAIN if max_size is None else max_size + sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING + else: + min_size = cfg.INPUT.MIN_SIZE_TEST + max_size = cfg.INPUT.MAX_SIZE_TEST + sample_style = "choice" + augmentation = [T.ResizeShortestEdge(min_size, max_size, sample_style)] + elif cfg.INPUT.CUSTOM_AUG == "EfficientDetResizeCrop": + if is_train: + scale = cfg.INPUT.SCALE_RANGE if scale is None else scale + size = cfg.INPUT.TRAIN_SIZE if size is None else size + else: + scale = (1, 1) + size = cfg.INPUT.TEST_SIZE + augmentation = [EfficientDetResizeCrop(size, scale)] + else: + assert 0, cfg.INPUT.CUSTOM_AUG + + if is_train: + augmentation.append(T.RandomFlip()) + return augmentation + + +build_custom_transform_gen = build_custom_augmentation +""" +Alias for backward-compatibility. +""" diff --git a/dimos/models/Detic/detic/data/custom_dataset_dataloader.py b/dimos/models/Detic/detic/data/custom_dataset_dataloader.py new file mode 100644 index 0000000000..ff4bfc9ea4 --- /dev/null +++ b/dimos/models/Detic/detic/data/custom_dataset_dataloader.py @@ -0,0 +1,322 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Part of the code is from https://github.com/xingyizhou/UniDet/blob/master/projects/UniDet/unidet/data/multi_dataset_dataloader.py (Apache-2.0 License) +from collections import defaultdict +import itertools +import math +import operator +from typing import Iterator, Sequence, Optional + +from detectron2.config import configurable +from detectron2.data.build import ( + build_batch_data_loader, + check_metadata_consistency, + filter_images_with_few_keypoints, + filter_images_with_only_crowd_annotations, + get_detection_dataset_dicts, + print_instances_class_histogram, + worker_init_reset_seed, +) +from detectron2.data.catalog import DatasetCatalog, MetadataCatalog +from detectron2.data.common import DatasetFromList, MapDataset +from detectron2.data.dataset_mapper import DatasetMapper +from detectron2.data.samplers import RepeatFactorTrainingSampler, TrainingSampler +from detectron2.utils import comm +from detectron2.utils.comm import get_world_size +import torch +import torch.utils.data +from torch.utils.data.sampler import Sampler + + +def _custom_train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None): + sampler_name = cfg.DATALOADER.SAMPLER_TRAIN + if "MultiDataset" in sampler_name: + dataset_dicts = get_detection_dataset_dicts_with_source( + cfg.DATASETS.TRAIN, + filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS, + min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE + if cfg.MODEL.KEYPOINT_ON + else 0, + proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None, + ) + else: + dataset_dicts = get_detection_dataset_dicts( + cfg.DATASETS.TRAIN, + filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS, + min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE + if cfg.MODEL.KEYPOINT_ON + else 0, + proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None, + ) + + if mapper is None: + mapper = DatasetMapper(cfg, True) + + if sampler is not None: + pass + elif sampler_name == "TrainingSampler": + sampler = TrainingSampler(len(dataset)) + elif sampler_name == "MultiDatasetSampler": + sampler = MultiDatasetSampler( + dataset_dicts, + dataset_ratio=cfg.DATALOADER.DATASET_RATIO, + use_rfs=cfg.DATALOADER.USE_RFS, + dataset_ann=cfg.DATALOADER.DATASET_ANN, + repeat_threshold=cfg.DATALOADER.REPEAT_THRESHOLD, + ) + elif sampler_name == "RepeatFactorTrainingSampler": + repeat_factors = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency( + dataset_dicts, cfg.DATALOADER.REPEAT_THRESHOLD + ) + sampler = RepeatFactorTrainingSampler(repeat_factors) + else: + raise ValueError(f"Unknown training sampler: {sampler_name}") + + return { + "dataset": dataset_dicts, + "sampler": sampler, + "mapper": mapper, + "total_batch_size": cfg.SOLVER.IMS_PER_BATCH, + "aspect_ratio_grouping": cfg.DATALOADER.ASPECT_RATIO_GROUPING, + "num_workers": cfg.DATALOADER.NUM_WORKERS, + "multi_dataset_grouping": cfg.DATALOADER.MULTI_DATASET_GROUPING, + "use_diff_bs_size": cfg.DATALOADER.USE_DIFF_BS_SIZE, + "dataset_bs": cfg.DATALOADER.DATASET_BS, + "num_datasets": len(cfg.DATASETS.TRAIN), + } + + +@configurable(from_config=_custom_train_loader_from_config) +def build_custom_train_loader( + dataset, + *, + mapper, + sampler, + total_batch_size: int=16, + aspect_ratio_grouping: bool=True, + num_workers: int=0, + num_datasets: int=1, + multi_dataset_grouping: bool=False, + use_diff_bs_size: bool=False, + dataset_bs=None, +): + """ + Modified from detectron2.data.build.build_custom_train_loader, but supports + different samplers + """ + if dataset_bs is None: + dataset_bs = [] + if isinstance(dataset, list): + dataset = DatasetFromList(dataset, copy=False) + if mapper is not None: + dataset = MapDataset(dataset, mapper) + if sampler is None: + sampler = TrainingSampler(len(dataset)) + assert isinstance(sampler, torch.utils.data.sampler.Sampler) + if multi_dataset_grouping: + return build_multi_dataset_batch_data_loader( + use_diff_bs_size, + dataset_bs, + dataset, + sampler, + total_batch_size, + num_datasets=num_datasets, + num_workers=num_workers, + ) + else: + return build_batch_data_loader( + dataset, + sampler, + total_batch_size, + aspect_ratio_grouping=aspect_ratio_grouping, + num_workers=num_workers, + ) + + +def build_multi_dataset_batch_data_loader( + use_diff_bs_size: int, dataset_bs, dataset, sampler, total_batch_size: int, num_datasets: int, num_workers: int=0 +): + """ """ + world_size = get_world_size() + assert total_batch_size > 0 and total_batch_size % world_size == 0, ( + f"Total batch size ({total_batch_size}) must be divisible by the number of gpus ({world_size})." + ) + + batch_size = total_batch_size // world_size + data_loader = torch.utils.data.DataLoader( + dataset, + sampler=sampler, + num_workers=num_workers, + batch_sampler=None, + collate_fn=operator.itemgetter(0), # don't batch, but yield individual elements + worker_init_fn=worker_init_reset_seed, + ) # yield individual mapped dict + if use_diff_bs_size: + return DIFFMDAspectRatioGroupedDataset(data_loader, dataset_bs, num_datasets) + else: + return MDAspectRatioGroupedDataset(data_loader, batch_size, num_datasets) + + +def get_detection_dataset_dicts_with_source( + dataset_names: Sequence[str], filter_empty: bool=True, min_keypoints: int=0, proposal_files=None +): + assert len(dataset_names) + dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in dataset_names] + for dataset_name, dicts in zip(dataset_names, dataset_dicts, strict=False): + assert len(dicts), f"Dataset '{dataset_name}' is empty!" + + for source_id, (dataset_name, dicts) in enumerate(zip(dataset_names, dataset_dicts, strict=False)): + assert len(dicts), f"Dataset '{dataset_name}' is empty!" + for d in dicts: + d["dataset_source"] = source_id + + if "annotations" in dicts[0]: + try: + class_names = MetadataCatalog.get(dataset_name).thing_classes + check_metadata_consistency("thing_classes", dataset_name) + print_instances_class_histogram(dicts, class_names) + except AttributeError: # class names are not available for this dataset + pass + + assert proposal_files is None + + dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts)) + + has_instances = "annotations" in dataset_dicts[0] + if filter_empty and has_instances: + dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts) + if min_keypoints > 0 and has_instances: + dataset_dicts = filter_images_with_few_keypoints(dataset_dicts, min_keypoints) + + return dataset_dicts + + +class MultiDatasetSampler(Sampler): + def __init__( + self, + dataset_dicts, + dataset_ratio, + use_rfs, + dataset_ann, + repeat_threshold: float=0.001, + seed: int | None = None, + ) -> None: + """ """ + sizes = [0 for _ in range(len(dataset_ratio))] + for d in dataset_dicts: + sizes[d["dataset_source"]] += 1 + print("dataset sizes", sizes) + self.sizes = sizes + assert len(dataset_ratio) == len(sizes), ( + f"length of dataset ratio {len(dataset_ratio)} should be equal to number if dataset {len(sizes)}" + ) + if seed is None: + seed = comm.shared_random_seed() + self._seed = int(seed) + self._rank = comm.get_rank() + self._world_size = comm.get_world_size() + + self.dataset_ids = torch.tensor( + [d["dataset_source"] for d in dataset_dicts], dtype=torch.long + ) + + dataset_weight = [ + torch.ones(s) * max(sizes) / s * r / sum(dataset_ratio) + for i, (r, s) in enumerate(zip(dataset_ratio, sizes, strict=False)) + ] + dataset_weight = torch.cat(dataset_weight) + + rfs_factors = [] + st = 0 + for i, s in enumerate(sizes): + if use_rfs[i]: + if dataset_ann[i] == "box": + rfs_func = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency + else: + rfs_func = repeat_factors_from_tag_frequency + rfs_factor = rfs_func(dataset_dicts[st : st + s], repeat_thresh=repeat_threshold) + rfs_factor = rfs_factor * (s / rfs_factor.sum()) + else: + rfs_factor = torch.ones(s) + rfs_factors.append(rfs_factor) + st = st + s + rfs_factors = torch.cat(rfs_factors) + + self.weights = dataset_weight * rfs_factors + self.sample_epoch_size = len(self.weights) + + def __iter__(self) -> Iterator: + start = self._rank + yield from itertools.islice(self._infinite_indices(), start, None, self._world_size) + + def _infinite_indices(self): + g = torch.Generator() + g.manual_seed(self._seed) + while True: + ids = torch.multinomial( + self.weights, self.sample_epoch_size, generator=g, replacement=True + ) + [(self.dataset_ids[ids] == i).sum().int().item() for i in range(len(self.sizes))] + yield from ids + + +class MDAspectRatioGroupedDataset(torch.utils.data.IterableDataset): + def __init__(self, dataset, batch_size: int, num_datasets: int) -> None: + """ """ + self.dataset = dataset + self.batch_size = batch_size + self._buckets = [[] for _ in range(2 * num_datasets)] + + def __iter__(self) -> Iterator: + for d in self.dataset: + w, h = d["width"], d["height"] + aspect_ratio_bucket_id = 0 if w > h else 1 + bucket_id = d["dataset_source"] * 2 + aspect_ratio_bucket_id + bucket = self._buckets[bucket_id] + bucket.append(d) + if len(bucket) == self.batch_size: + yield bucket[:] + del bucket[:] + + +class DIFFMDAspectRatioGroupedDataset(torch.utils.data.IterableDataset): + def __init__(self, dataset, batch_sizes: Sequence[int], num_datasets: int) -> None: + """ """ + self.dataset = dataset + self.batch_sizes = batch_sizes + self._buckets = [[] for _ in range(2 * num_datasets)] + + def __iter__(self) -> Iterator: + for d in self.dataset: + w, h = d["width"], d["height"] + aspect_ratio_bucket_id = 0 if w > h else 1 + bucket_id = d["dataset_source"] * 2 + aspect_ratio_bucket_id + bucket = self._buckets[bucket_id] + bucket.append(d) + if len(bucket) == self.batch_sizes[d["dataset_source"]]: + yield bucket[:] + del bucket[:] + + +def repeat_factors_from_tag_frequency(dataset_dicts, repeat_thresh): + """ """ + category_freq = defaultdict(int) + for dataset_dict in dataset_dicts: + cat_ids = dataset_dict["pos_category_ids"] + for cat_id in cat_ids: + category_freq[cat_id] += 1 + num_images = len(dataset_dicts) + for k, v in category_freq.items(): + category_freq[k] = v / num_images + + category_rep = { + cat_id: max(1.0, math.sqrt(repeat_thresh / cat_freq)) + for cat_id, cat_freq in category_freq.items() + } + + rep_factors = [] + for dataset_dict in dataset_dicts: + cat_ids = dataset_dict["pos_category_ids"] + rep_factor = max({category_rep[cat_id] for cat_id in cat_ids}, default=1.0) + rep_factors.append(rep_factor) + + return torch.tensor(rep_factors, dtype=torch.float32) diff --git a/dimos/models/Detic/detic/data/custom_dataset_mapper.py b/dimos/models/Detic/detic/data/custom_dataset_mapper.py new file mode 100644 index 0000000000..46c86ffd84 --- /dev/null +++ b/dimos/models/Detic/detic/data/custom_dataset_mapper.py @@ -0,0 +1,284 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +import copy +import logging + +from detectron2.config import configurable +from detectron2.data import detection_utils as utils, transforms as T +from detectron2.data.dataset_mapper import DatasetMapper +import numpy as np +import torch + +from .custom_build_augmentation import build_custom_augmentation +from .tar_dataset import DiskTarDataset + +__all__ = ["CustomDatasetMapper"] + + +class CustomDatasetMapper(DatasetMapper): + @configurable + def __init__( + self, + is_train: bool, + with_ann_type: bool=False, + dataset_ann=None, + use_diff_bs_size: bool=False, + dataset_augs=None, + is_debug: bool=False, + use_tar_dataset: bool=False, + tarfile_path: str="", + tar_index_dir: str="", + **kwargs, + ) -> None: + """ + add image labels + """ + if dataset_augs is None: + dataset_augs = [] + if dataset_ann is None: + dataset_ann = [] + self.with_ann_type = with_ann_type + self.dataset_ann = dataset_ann + self.use_diff_bs_size = use_diff_bs_size + if self.use_diff_bs_size and is_train: + self.dataset_augs = [T.AugmentationList(x) for x in dataset_augs] + self.is_debug = is_debug + self.use_tar_dataset = use_tar_dataset + if self.use_tar_dataset: + print("Using tar dataset") + self.tar_dataset = DiskTarDataset(tarfile_path, tar_index_dir) + super().__init__(is_train, **kwargs) + + @classmethod + def from_config(cls, cfg, is_train: bool = True): + ret = super().from_config(cfg, is_train) + ret.update( + { + "with_ann_type": cfg.WITH_IMAGE_LABELS, + "dataset_ann": cfg.DATALOADER.DATASET_ANN, + "use_diff_bs_size": cfg.DATALOADER.USE_DIFF_BS_SIZE, + "is_debug": cfg.IS_DEBUG, + "use_tar_dataset": cfg.DATALOADER.USE_TAR_DATASET, + "tarfile_path": cfg.DATALOADER.TARFILE_PATH, + "tar_index_dir": cfg.DATALOADER.TAR_INDEX_DIR, + } + ) + if ret["use_diff_bs_size"] and is_train: + if cfg.INPUT.CUSTOM_AUG == "EfficientDetResizeCrop": + dataset_scales = cfg.DATALOADER.DATASET_INPUT_SCALE + dataset_sizes = cfg.DATALOADER.DATASET_INPUT_SIZE + ret["dataset_augs"] = [ + build_custom_augmentation(cfg, True, scale, size) + for scale, size in zip(dataset_scales, dataset_sizes, strict=False) + ] + else: + assert cfg.INPUT.CUSTOM_AUG == "ResizeShortestEdge" + min_sizes = cfg.DATALOADER.DATASET_MIN_SIZES + max_sizes = cfg.DATALOADER.DATASET_MAX_SIZES + ret["dataset_augs"] = [ + build_custom_augmentation(cfg, True, min_size=mi, max_size=ma) + for mi, ma in zip(min_sizes, max_sizes, strict=False) + ] + else: + ret["dataset_augs"] = [] + + return ret + + def __call__(self, dataset_dict): + """ + include image labels + """ + dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below + # USER: Write your own image loading if it's not from a file + if "file_name" in dataset_dict: + ori_image = utils.read_image(dataset_dict["file_name"], format=self.image_format) + else: + ori_image, _, _ = self.tar_dataset[dataset_dict["tar_index"]] + ori_image = utils._apply_exif_orientation(ori_image) + ori_image = utils.convert_PIL_to_numpy(ori_image, self.image_format) + utils.check_image_size(dataset_dict, ori_image) + + # USER: Remove if you don't do semantic/panoptic segmentation. + if "sem_seg_file_name" in dataset_dict: + sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name"), "L").squeeze(2) + else: + sem_seg_gt = None + + if self.is_debug: + dataset_dict["dataset_source"] = 0 + + ( + "dataset_source" in dataset_dict + and self.with_ann_type + and self.dataset_ann[dataset_dict["dataset_source"]] != "box" + ) + + aug_input = T.AugInput(copy.deepcopy(ori_image), sem_seg=sem_seg_gt) + if self.use_diff_bs_size and self.is_train: + transforms = self.dataset_augs[dataset_dict["dataset_source"]](aug_input) + else: + transforms = self.augmentations(aug_input) + image, sem_seg_gt = aug_input.image, aug_input.sem_seg + + image_shape = image.shape[:2] # h, w + dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) + + if sem_seg_gt is not None: + dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long")) + + # USER: Remove if you don't use pre-computed proposals. + # Most users would not need this feature. + if self.proposal_topk is not None: + utils.transform_proposals( + dataset_dict, image_shape, transforms, proposal_topk=self.proposal_topk + ) + + if not self.is_train: + # USER: Modify this if you want to keep them for some reason. + dataset_dict.pop("annotations", None) + dataset_dict.pop("sem_seg_file_name", None) + return dataset_dict + + if "annotations" in dataset_dict: + # USER: Modify this if you want to keep them for some reason. + for anno in dataset_dict["annotations"]: + if not self.use_instance_mask: + anno.pop("segmentation", None) + if not self.use_keypoint: + anno.pop("keypoints", None) + + # USER: Implement additional transformations if you have other types of data + all_annos = [ + ( + utils.transform_instance_annotations( + obj, + transforms, + image_shape, + keypoint_hflip_indices=self.keypoint_hflip_indices, + ), + obj.get("iscrowd", 0), + ) + for obj in dataset_dict.pop("annotations") + ] + annos = [ann[0] for ann in all_annos if ann[1] == 0] + instances = utils.annotations_to_instances( + annos, image_shape, mask_format=self.instance_mask_format + ) + + del all_annos + if self.recompute_boxes: + instances.gt_boxes = instances.gt_masks.get_bounding_boxes() + dataset_dict["instances"] = utils.filter_empty_instances(instances) + if self.with_ann_type: + dataset_dict["pos_category_ids"] = dataset_dict.get("pos_category_ids", []) + dataset_dict["ann_type"] = self.dataset_ann[dataset_dict["dataset_source"]] + if self.is_debug and ( + ("pos_category_ids" not in dataset_dict) or (dataset_dict["pos_category_ids"] == []) + ): + dataset_dict["pos_category_ids"] = [ + x for x in sorted(set(dataset_dict["instances"].gt_classes.tolist())) + ] + return dataset_dict + + +# DETR augmentation +def build_transform_gen(cfg, is_train: bool): + """ """ + if is_train: + min_size = cfg.INPUT.MIN_SIZE_TRAIN + max_size = cfg.INPUT.MAX_SIZE_TRAIN + sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING + else: + min_size = cfg.INPUT.MIN_SIZE_TEST + max_size = cfg.INPUT.MAX_SIZE_TEST + sample_style = "choice" + if sample_style == "range": + assert len(min_size) == 2, f"more than 2 ({len(min_size)}) min_size(s) are provided for ranges" + + logger = logging.getLogger(__name__) + tfm_gens = [] + if is_train: + tfm_gens.append(T.RandomFlip()) + tfm_gens.append(T.ResizeShortestEdge(min_size, max_size, sample_style)) + if is_train: + logger.info("TransformGens used in training: " + str(tfm_gens)) + return tfm_gens + + +class DetrDatasetMapper: + """ + A callable which takes a dataset dict in Detectron2 Dataset format, + and map it into a format used by DETR. + The callable currently does the following: + 1. Read the image from "file_name" + 2. Applies geometric transforms to the image and annotation + 3. Find and applies suitable cropping to the image and annotation + 4. Prepare image and annotation to Tensors + """ + + def __init__(self, cfg, is_train: bool=True) -> None: + if cfg.INPUT.CROP.ENABLED and is_train: + self.crop_gen = [ + T.ResizeShortestEdge([400, 500, 600], sample_style="choice"), + T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE), + ] + else: + self.crop_gen = None + + self.mask_on = cfg.MODEL.MASK_ON + self.tfm_gens = build_transform_gen(cfg, is_train) + logging.getLogger(__name__).info( + f"Full TransformGens used in training: {self.tfm_gens!s}, crop: {self.crop_gen!s}" + ) + + self.img_format = cfg.INPUT.FORMAT + self.is_train = is_train + + def __call__(self, dataset_dict): + """ + Args: + dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. + Returns: + dict: a format that builtin models in detectron2 accept + """ + dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below + image = utils.read_image(dataset_dict["file_name"], format=self.img_format) + utils.check_image_size(dataset_dict, image) + + if self.crop_gen is None: + image, transforms = T.apply_transform_gens(self.tfm_gens, image) + else: + if np.random.rand() > 0.5: + image, transforms = T.apply_transform_gens(self.tfm_gens, image) + else: + image, transforms = T.apply_transform_gens( + self.tfm_gens[:-1] + self.crop_gen + self.tfm_gens[-1:], image + ) + + image_shape = image.shape[:2] # h, w + + # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory, + # but not efficient on large generic data structures due to the use of pickle & mp.Queue. + # Therefore it's important to use torch.Tensor. + dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) + + if not self.is_train: + # USER: Modify this if you want to keep them for some reason. + dataset_dict.pop("annotations", None) + return dataset_dict + + if "annotations" in dataset_dict: + # USER: Modify this if you want to keep them for some reason. + for anno in dataset_dict["annotations"]: + if not self.mask_on: + anno.pop("segmentation", None) + anno.pop("keypoints", None) + + # USER: Implement additional transformations if you have other types of data + annos = [ + utils.transform_instance_annotations(obj, transforms, image_shape) + for obj in dataset_dict.pop("annotations") + if obj.get("iscrowd", 0) == 0 + ] + instances = utils.annotations_to_instances(annos, image_shape) + dataset_dict["instances"] = utils.filter_empty_instances(instances) + return dataset_dict diff --git a/dimos/models/Detic/detic/data/datasets/cc.py b/dimos/models/Detic/detic/data/datasets/cc.py new file mode 100644 index 0000000000..be9c7f4a8b --- /dev/null +++ b/dimos/models/Detic/detic/data/datasets/cc.py @@ -0,0 +1,20 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import os + +from detectron2.data.datasets.lvis import get_lvis_instances_meta + +from .lvis_v1 import custom_register_lvis_instances + +_CUSTOM_SPLITS = { + "cc3m_v1_val": ("cc3m/validation/", "cc3m/val_image_info.json"), + "cc3m_v1_train": ("cc3m/training/", "cc3m/train_image_info.json"), + "cc3m_v1_train_tags": ("cc3m/training/", "cc3m/train_image_info_tags.json"), +} + +for key, (image_root, json_file) in _CUSTOM_SPLITS.items(): + custom_register_lvis_instances( + key, + get_lvis_instances_meta("lvis_v1"), + os.path.join("datasets", json_file) if "://" not in json_file else json_file, + os.path.join("datasets", image_root), + ) diff --git a/dimos/models/Detic/detic/data/datasets/coco_zeroshot.py b/dimos/models/Detic/detic/data/datasets/coco_zeroshot.py new file mode 100644 index 0000000000..80c360593d --- /dev/null +++ b/dimos/models/Detic/detic/data/datasets/coco_zeroshot.py @@ -0,0 +1,148 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import os + +from detectron2.data.datasets.builtin_meta import _get_builtin_metadata +from detectron2.data.datasets.register_coco import register_coco_instances + +from .lvis_v1 import custom_register_lvis_instances + +categories_seen = [ + {"id": 1, "name": "person"}, + {"id": 2, "name": "bicycle"}, + {"id": 3, "name": "car"}, + {"id": 4, "name": "motorcycle"}, + {"id": 7, "name": "train"}, + {"id": 8, "name": "truck"}, + {"id": 9, "name": "boat"}, + {"id": 15, "name": "bench"}, + {"id": 16, "name": "bird"}, + {"id": 19, "name": "horse"}, + {"id": 20, "name": "sheep"}, + {"id": 23, "name": "bear"}, + {"id": 24, "name": "zebra"}, + {"id": 25, "name": "giraffe"}, + {"id": 27, "name": "backpack"}, + {"id": 31, "name": "handbag"}, + {"id": 33, "name": "suitcase"}, + {"id": 34, "name": "frisbee"}, + {"id": 35, "name": "skis"}, + {"id": 38, "name": "kite"}, + {"id": 42, "name": "surfboard"}, + {"id": 44, "name": "bottle"}, + {"id": 48, "name": "fork"}, + {"id": 50, "name": "spoon"}, + {"id": 51, "name": "bowl"}, + {"id": 52, "name": "banana"}, + {"id": 53, "name": "apple"}, + {"id": 54, "name": "sandwich"}, + {"id": 55, "name": "orange"}, + {"id": 56, "name": "broccoli"}, + {"id": 57, "name": "carrot"}, + {"id": 59, "name": "pizza"}, + {"id": 60, "name": "donut"}, + {"id": 62, "name": "chair"}, + {"id": 65, "name": "bed"}, + {"id": 70, "name": "toilet"}, + {"id": 72, "name": "tv"}, + {"id": 73, "name": "laptop"}, + {"id": 74, "name": "mouse"}, + {"id": 75, "name": "remote"}, + {"id": 78, "name": "microwave"}, + {"id": 79, "name": "oven"}, + {"id": 80, "name": "toaster"}, + {"id": 82, "name": "refrigerator"}, + {"id": 84, "name": "book"}, + {"id": 85, "name": "clock"}, + {"id": 86, "name": "vase"}, + {"id": 90, "name": "toothbrush"}, +] + +categories_unseen = [ + {"id": 5, "name": "airplane"}, + {"id": 6, "name": "bus"}, + {"id": 17, "name": "cat"}, + {"id": 18, "name": "dog"}, + {"id": 21, "name": "cow"}, + {"id": 22, "name": "elephant"}, + {"id": 28, "name": "umbrella"}, + {"id": 32, "name": "tie"}, + {"id": 36, "name": "snowboard"}, + {"id": 41, "name": "skateboard"}, + {"id": 47, "name": "cup"}, + {"id": 49, "name": "knife"}, + {"id": 61, "name": "cake"}, + {"id": 63, "name": "couch"}, + {"id": 76, "name": "keyboard"}, + {"id": 81, "name": "sink"}, + {"id": 87, "name": "scissors"}, +] + + +def _get_metadata(cat): + if cat == "all": + return _get_builtin_metadata("coco") + elif cat == "seen": + id_to_name = {x["id"]: x["name"] for x in categories_seen} + else: + assert cat == "unseen" + id_to_name = {x["id"]: x["name"] for x in categories_unseen} + + thing_dataset_id_to_contiguous_id = {x: i for i, x in enumerate(sorted(id_to_name))} + thing_classes = [id_to_name[k] for k in sorted(id_to_name)] + return { + "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id, + "thing_classes": thing_classes, + } + + +_PREDEFINED_SPLITS_COCO = { + "coco_zeroshot_train": ( + "coco/train2017", + "coco/zero-shot/instances_train2017_seen_2.json", + "seen", + ), + "coco_zeroshot_val": ( + "coco/val2017", + "coco/zero-shot/instances_val2017_unseen_2.json", + "unseen", + ), + "coco_not_zeroshot_val": ( + "coco/val2017", + "coco/zero-shot/instances_val2017_seen_2.json", + "seen", + ), + "coco_generalized_zeroshot_val": ( + "coco/val2017", + "coco/zero-shot/instances_val2017_all_2_oriorder.json", + "all", + ), + "coco_zeroshot_train_oriorder": ( + "coco/train2017", + "coco/zero-shot/instances_train2017_seen_2_oriorder.json", + "all", + ), +} + +for key, (image_root, json_file, cat) in _PREDEFINED_SPLITS_COCO.items(): + register_coco_instances( + key, + _get_metadata(cat), + os.path.join("datasets", json_file) if "://" not in json_file else json_file, + os.path.join("datasets", image_root), + ) + +_CUSTOM_SPLITS_COCO = { + "cc3m_coco_train_tags": ("cc3m/training/", "cc3m/coco_train_image_info_tags.json"), + "coco_caption_train_tags": ( + "coco/train2017/", + "coco/annotations/captions_train2017_tags_allcaps.json", + ), +} + +for key, (image_root, json_file) in _CUSTOM_SPLITS_COCO.items(): + custom_register_lvis_instances( + key, + _get_builtin_metadata("coco"), + os.path.join("datasets", json_file) if "://" not in json_file else json_file, + os.path.join("datasets", image_root), + ) diff --git a/dimos/models/Detic/detic/data/datasets/imagenet.py b/dimos/models/Detic/detic/data/datasets/imagenet.py new file mode 100644 index 0000000000..caa7aa8fe0 --- /dev/null +++ b/dimos/models/Detic/detic/data/datasets/imagenet.py @@ -0,0 +1,47 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import os + +from detectron2.data import DatasetCatalog, MetadataCatalog +from detectron2.data.datasets.lvis import get_lvis_instances_meta + +from .lvis_v1 import custom_load_lvis_json, get_lvis_22k_meta + + +def custom_register_imagenet_instances(name: str, metadata, json_file, image_root) -> None: + """ """ + DatasetCatalog.register(name, lambda: custom_load_lvis_json(json_file, image_root, name)) + MetadataCatalog.get(name).set( + json_file=json_file, image_root=image_root, evaluator_type="imagenet", **metadata + ) + + +_CUSTOM_SPLITS_IMAGENET = { + "imagenet_lvis_v1": ( + "imagenet/ImageNet-LVIS/", + "imagenet/annotations/imagenet_lvis_image_info.json", + ), +} + +for key, (image_root, json_file) in _CUSTOM_SPLITS_IMAGENET.items(): + custom_register_imagenet_instances( + key, + get_lvis_instances_meta("lvis_v1"), + os.path.join("datasets", json_file) if "://" not in json_file else json_file, + os.path.join("datasets", image_root), + ) + + +_CUSTOM_SPLITS_IMAGENET_22K = { + "imagenet_lvis-22k": ( + "imagenet/ImageNet-LVIS/", + "imagenet/annotations/imagenet-22k_image_info_lvis-22k.json", + ), +} + +for key, (image_root, json_file) in _CUSTOM_SPLITS_IMAGENET_22K.items(): + custom_register_imagenet_instances( + key, + get_lvis_22k_meta(), + os.path.join("datasets", json_file) if "://" not in json_file else json_file, + os.path.join("datasets", image_root), + ) diff --git a/dimos/models/Detic/detic/data/datasets/lvis_22k_categories.py.REMOVED.git-id b/dimos/models/Detic/detic/data/datasets/lvis_22k_categories.py.REMOVED.git-id new file mode 100644 index 0000000000..ac45ed8b95 --- /dev/null +++ b/dimos/models/Detic/detic/data/datasets/lvis_22k_categories.py.REMOVED.git-id @@ -0,0 +1 @@ +d1b3cc370afdb22dbff33647a9404c764e54a649 \ No newline at end of file diff --git a/dimos/models/Detic/detic/data/datasets/lvis_v1.py b/dimos/models/Detic/detic/data/datasets/lvis_v1.py new file mode 100644 index 0000000000..659a5fbbc0 --- /dev/null +++ b/dimos/models/Detic/detic/data/datasets/lvis_v1.py @@ -0,0 +1,153 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import logging +import os + +from detectron2.data import DatasetCatalog, MetadataCatalog +from detectron2.data.datasets.lvis import get_lvis_instances_meta +from detectron2.structures import BoxMode +from fvcore.common.file_io import PathManager +from fvcore.common.timer import Timer +from typing import Optional + +logger = logging.getLogger(__name__) + +__all__ = ["custom_load_lvis_json", "custom_register_lvis_instances"] + + +def custom_register_lvis_instances(name: str, metadata, json_file, image_root) -> None: + """ """ + DatasetCatalog.register(name, lambda: custom_load_lvis_json(json_file, image_root, name)) + MetadataCatalog.get(name).set( + json_file=json_file, image_root=image_root, evaluator_type="lvis", **metadata + ) + + +def custom_load_lvis_json(json_file, image_root, dataset_name: Optional[str]=None): + """ + Modifications: + use `file_name` + convert neg_category_ids + add pos_category_ids + """ + from lvis import LVIS + + json_file = PathManager.get_local_path(json_file) + + timer = Timer() + lvis_api = LVIS(json_file) + if timer.seconds() > 1: + logger.info(f"Loading {json_file} takes {timer.seconds():.2f} seconds.") + + catid2contid = { + x["id"]: i + for i, x in enumerate(sorted(lvis_api.dataset["categories"], key=lambda x: x["id"])) + } + if len(lvis_api.dataset["categories"]) == 1203: + for x in lvis_api.dataset["categories"]: + assert catid2contid[x["id"]] == x["id"] - 1 + img_ids = sorted(lvis_api.imgs.keys()) + imgs = lvis_api.load_imgs(img_ids) + anns = [lvis_api.img_ann_map[img_id] for img_id in img_ids] + + ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image] + assert len(set(ann_ids)) == len(ann_ids), f"Annotation ids in '{json_file}' are not unique" + + imgs_anns = list(zip(imgs, anns, strict=False)) + logger.info(f"Loaded {len(imgs_anns)} images in the LVIS v1 format from {json_file}") + + dataset_dicts = [] + + for img_dict, anno_dict_list in imgs_anns: + record = {} + if "file_name" in img_dict: + file_name = img_dict["file_name"] + if img_dict["file_name"].startswith("COCO"): + file_name = file_name[-16:] + record["file_name"] = os.path.join(image_root, file_name) + elif "coco_url" in img_dict: + # e.g., http://images.cocodataset.org/train2017/000000391895.jpg + file_name = img_dict["coco_url"][30:] + record["file_name"] = os.path.join(image_root, file_name) + elif "tar_index" in img_dict: + record["tar_index"] = img_dict["tar_index"] + + record["height"] = img_dict["height"] + record["width"] = img_dict["width"] + record["not_exhaustive_category_ids"] = img_dict.get("not_exhaustive_category_ids", []) + record["neg_category_ids"] = img_dict.get("neg_category_ids", []) + # NOTE: modified by Xingyi: convert to 0-based + record["neg_category_ids"] = [catid2contid[x] for x in record["neg_category_ids"]] + if "pos_category_ids" in img_dict: + record["pos_category_ids"] = [ + catid2contid[x] for x in img_dict.get("pos_category_ids", []) + ] + if "captions" in img_dict: + record["captions"] = img_dict["captions"] + if "caption_features" in img_dict: + record["caption_features"] = img_dict["caption_features"] + image_id = record["image_id"] = img_dict["id"] + + objs = [] + for anno in anno_dict_list: + assert anno["image_id"] == image_id + if anno.get("iscrowd", 0) > 0: + continue + obj = {"bbox": anno["bbox"], "bbox_mode": BoxMode.XYWH_ABS} + obj["category_id"] = catid2contid[anno["category_id"]] + if "segmentation" in anno: + segm = anno["segmentation"] + valid_segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6] + # assert len(segm) == len( + # valid_segm + # ), "Annotation contains an invalid polygon with < 3 points" + if not len(segm) == len(valid_segm): + print("Annotation contains an invalid polygon with < 3 points") + assert len(segm) > 0 + obj["segmentation"] = segm + objs.append(obj) + record["annotations"] = objs + dataset_dicts.append(record) + + return dataset_dicts + + +_CUSTOM_SPLITS_LVIS = { + "lvis_v1_train+coco": ("coco/", "lvis/lvis_v1_train+coco_mask.json"), + "lvis_v1_train_norare": ("coco/", "lvis/lvis_v1_train_norare.json"), +} + + +for key, (image_root, json_file) in _CUSTOM_SPLITS_LVIS.items(): + custom_register_lvis_instances( + key, + get_lvis_instances_meta(key), + os.path.join("datasets", json_file) if "://" not in json_file else json_file, + os.path.join("datasets", image_root), + ) + + +def get_lvis_22k_meta(): + from .lvis_22k_categories import CATEGORIES + + cat_ids = [k["id"] for k in CATEGORIES] + assert min(cat_ids) == 1 and max(cat_ids) == len(cat_ids), ( + "Category ids are not in [1, #categories], as expected" + ) + # Ensure that the category list is sorted by id + lvis_categories = sorted(CATEGORIES, key=lambda x: x["id"]) + thing_classes = [k["name"] for k in lvis_categories] + meta = {"thing_classes": thing_classes} + return meta + + +_CUSTOM_SPLITS_LVIS_22K = { + "lvis_v1_train_22k": ("coco/", "lvis/lvis_v1_train_lvis-22k.json"), +} + +for key, (image_root, json_file) in _CUSTOM_SPLITS_LVIS_22K.items(): + custom_register_lvis_instances( + key, + get_lvis_22k_meta(), + os.path.join("datasets", json_file) if "://" not in json_file else json_file, + os.path.join("datasets", image_root), + ) diff --git a/dimos/models/Detic/detic/data/datasets/objects365.py b/dimos/models/Detic/detic/data/datasets/objects365.py new file mode 100644 index 0000000000..236e609287 --- /dev/null +++ b/dimos/models/Detic/detic/data/datasets/objects365.py @@ -0,0 +1,781 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import os + +from detectron2.data.datasets.register_coco import register_coco_instances + +# categories_v2 = [ +# {'id': 1, 'name': 'Person'}, +# {'id': 2, 'name': 'Sneakers'}, +# {'id': 3, 'name': 'Chair'}, +# {'id': 4, 'name': 'Other Shoes'}, +# {'id': 5, 'name': 'Hat'}, +# {'id': 6, 'name': 'Car'}, +# {'id': 7, 'name': 'Lamp'}, +# {'id': 8, 'name': 'Glasses'}, +# {'id': 9, 'name': 'Bottle'}, +# {'id': 10, 'name': 'Desk'}, +# {'id': 11, 'name': 'Cup'}, +# {'id': 12, 'name': 'Street Lights'}, +# {'id': 13, 'name': 'Cabinet/shelf'}, +# {'id': 14, 'name': 'Handbag/Satchel'}, +# {'id': 15, 'name': 'Bracelet'}, +# {'id': 16, 'name': 'Plate'}, +# {'id': 17, 'name': 'Picture/Frame'}, +# {'id': 18, 'name': 'Helmet'}, +# {'id': 19, 'name': 'Book'}, +# {'id': 20, 'name': 'Gloves'}, +# {'id': 21, 'name': 'Storage box'}, +# {'id': 22, 'name': 'Boat'}, +# {'id': 23, 'name': 'Leather Shoes'}, +# {'id': 24, 'name': 'Flower'}, +# {'id': 25, 'name': 'Bench'}, +# {'id': 26, 'name': 'Potted Plant'}, +# {'id': 27, 'name': 'Bowl/Basin'}, +# {'id': 28, 'name': 'Flag'}, +# {'id': 29, 'name': 'Pillow'}, +# {'id': 30, 'name': 'Boots'}, +# {'id': 31, 'name': 'Vase'}, +# {'id': 32, 'name': 'Microphone'}, +# {'id': 33, 'name': 'Necklace'}, +# {'id': 34, 'name': 'Ring'}, +# {'id': 35, 'name': 'SUV'}, +# {'id': 36, 'name': 'Wine Glass'}, +# {'id': 37, 'name': 'Belt'}, +# {'id': 38, 'name': 'Moniter/TV'}, +# {'id': 39, 'name': 'Backpack'}, +# {'id': 40, 'name': 'Umbrella'}, +# {'id': 41, 'name': 'Traffic Light'}, +# {'id': 42, 'name': 'Speaker'}, +# {'id': 43, 'name': 'Watch'}, +# {'id': 44, 'name': 'Tie'}, +# {'id': 45, 'name': 'Trash bin Can'}, +# {'id': 46, 'name': 'Slippers'}, +# {'id': 47, 'name': 'Bicycle'}, +# {'id': 48, 'name': 'Stool'}, +# {'id': 49, 'name': 'Barrel/bucket'}, +# {'id': 50, 'name': 'Van'}, +# {'id': 51, 'name': 'Couch'}, +# {'id': 52, 'name': 'Sandals'}, +# {'id': 53, 'name': 'Bakset'}, +# {'id': 54, 'name': 'Drum'}, +# {'id': 55, 'name': 'Pen/Pencil'}, +# {'id': 56, 'name': 'Bus'}, +# {'id': 57, 'name': 'Wild Bird'}, +# {'id': 58, 'name': 'High Heels'}, +# {'id': 59, 'name': 'Motorcycle'}, +# {'id': 60, 'name': 'Guitar'}, +# {'id': 61, 'name': 'Carpet'}, +# {'id': 62, 'name': 'Cell Phone'}, +# {'id': 63, 'name': 'Bread'}, +# {'id': 64, 'name': 'Camera'}, +# {'id': 65, 'name': 'Canned'}, +# {'id': 66, 'name': 'Truck'}, +# {'id': 67, 'name': 'Traffic cone'}, +# {'id': 68, 'name': 'Cymbal'}, +# {'id': 69, 'name': 'Lifesaver'}, +# {'id': 70, 'name': 'Towel'}, +# {'id': 71, 'name': 'Stuffed Toy'}, +# {'id': 72, 'name': 'Candle'}, +# {'id': 73, 'name': 'Sailboat'}, +# {'id': 74, 'name': 'Laptop'}, +# {'id': 75, 'name': 'Awning'}, +# {'id': 76, 'name': 'Bed'}, +# {'id': 77, 'name': 'Faucet'}, +# {'id': 78, 'name': 'Tent'}, +# {'id': 79, 'name': 'Horse'}, +# {'id': 80, 'name': 'Mirror'}, +# {'id': 81, 'name': 'Power outlet'}, +# {'id': 82, 'name': 'Sink'}, +# {'id': 83, 'name': 'Apple'}, +# {'id': 84, 'name': 'Air Conditioner'}, +# {'id': 85, 'name': 'Knife'}, +# {'id': 86, 'name': 'Hockey Stick'}, +# {'id': 87, 'name': 'Paddle'}, +# {'id': 88, 'name': 'Pickup Truck'}, +# {'id': 89, 'name': 'Fork'}, +# {'id': 90, 'name': 'Traffic Sign'}, +# {'id': 91, 'name': 'Ballon'}, +# {'id': 92, 'name': 'Tripod'}, +# {'id': 93, 'name': 'Dog'}, +# {'id': 94, 'name': 'Spoon'}, +# {'id': 95, 'name': 'Clock'}, +# {'id': 96, 'name': 'Pot'}, +# {'id': 97, 'name': 'Cow'}, +# {'id': 98, 'name': 'Cake'}, +# {'id': 99, 'name': 'Dinning Table'}, +# {'id': 100, 'name': 'Sheep'}, +# {'id': 101, 'name': 'Hanger'}, +# {'id': 102, 'name': 'Blackboard/Whiteboard'}, +# {'id': 103, 'name': 'Napkin'}, +# {'id': 104, 'name': 'Other Fish'}, +# {'id': 105, 'name': 'Orange/Tangerine'}, +# {'id': 106, 'name': 'Toiletry'}, +# {'id': 107, 'name': 'Keyboard'}, +# {'id': 108, 'name': 'Tomato'}, +# {'id': 109, 'name': 'Lantern'}, +# {'id': 110, 'name': 'Machinery Vehicle'}, +# {'id': 111, 'name': 'Fan'}, +# {'id': 112, 'name': 'Green Vegetables'}, +# {'id': 113, 'name': 'Banana'}, +# {'id': 114, 'name': 'Baseball Glove'}, +# {'id': 115, 'name': 'Airplane'}, +# {'id': 116, 'name': 'Mouse'}, +# {'id': 117, 'name': 'Train'}, +# {'id': 118, 'name': 'Pumpkin'}, +# {'id': 119, 'name': 'Soccer'}, +# {'id': 120, 'name': 'Skiboard'}, +# {'id': 121, 'name': 'Luggage'}, +# {'id': 122, 'name': 'Nightstand'}, +# {'id': 123, 'name': 'Tea pot'}, +# {'id': 124, 'name': 'Telephone'}, +# {'id': 125, 'name': 'Trolley'}, +# {'id': 126, 'name': 'Head Phone'}, +# {'id': 127, 'name': 'Sports Car'}, +# {'id': 128, 'name': 'Stop Sign'}, +# {'id': 129, 'name': 'Dessert'}, +# {'id': 130, 'name': 'Scooter'}, +# {'id': 131, 'name': 'Stroller'}, +# {'id': 132, 'name': 'Crane'}, +# {'id': 133, 'name': 'Remote'}, +# {'id': 134, 'name': 'Refrigerator'}, +# {'id': 135, 'name': 'Oven'}, +# {'id': 136, 'name': 'Lemon'}, +# {'id': 137, 'name': 'Duck'}, +# {'id': 138, 'name': 'Baseball Bat'}, +# {'id': 139, 'name': 'Surveillance Camera'}, +# {'id': 140, 'name': 'Cat'}, +# {'id': 141, 'name': 'Jug'}, +# {'id': 142, 'name': 'Broccoli'}, +# {'id': 143, 'name': 'Piano'}, +# {'id': 144, 'name': 'Pizza'}, +# {'id': 145, 'name': 'Elephant'}, +# {'id': 146, 'name': 'Skateboard'}, +# {'id': 147, 'name': 'Surfboard'}, +# {'id': 148, 'name': 'Gun'}, +# {'id': 149, 'name': 'Skating and Skiing shoes'}, +# {'id': 150, 'name': 'Gas stove'}, +# {'id': 151, 'name': 'Donut'}, +# {'id': 152, 'name': 'Bow Tie'}, +# {'id': 153, 'name': 'Carrot'}, +# {'id': 154, 'name': 'Toilet'}, +# {'id': 155, 'name': 'Kite'}, +# {'id': 156, 'name': 'Strawberry'}, +# {'id': 157, 'name': 'Other Balls'}, +# {'id': 158, 'name': 'Shovel'}, +# {'id': 159, 'name': 'Pepper'}, +# {'id': 160, 'name': 'Computer Box'}, +# {'id': 161, 'name': 'Toilet Paper'}, +# {'id': 162, 'name': 'Cleaning Products'}, +# {'id': 163, 'name': 'Chopsticks'}, +# {'id': 164, 'name': 'Microwave'}, +# {'id': 165, 'name': 'Pigeon'}, +# {'id': 166, 'name': 'Baseball'}, +# {'id': 167, 'name': 'Cutting/chopping Board'}, +# {'id': 168, 'name': 'Coffee Table'}, +# {'id': 169, 'name': 'Side Table'}, +# {'id': 170, 'name': 'Scissors'}, +# {'id': 171, 'name': 'Marker'}, +# {'id': 172, 'name': 'Pie'}, +# {'id': 173, 'name': 'Ladder'}, +# {'id': 174, 'name': 'Snowboard'}, +# {'id': 175, 'name': 'Cookies'}, +# {'id': 176, 'name': 'Radiator'}, +# {'id': 177, 'name': 'Fire Hydrant'}, +# {'id': 178, 'name': 'Basketball'}, +# {'id': 179, 'name': 'Zebra'}, +# {'id': 180, 'name': 'Grape'}, +# {'id': 181, 'name': 'Giraffe'}, +# {'id': 182, 'name': 'Potato'}, +# {'id': 183, 'name': 'Sausage'}, +# {'id': 184, 'name': 'Tricycle'}, +# {'id': 185, 'name': 'Violin'}, +# {'id': 186, 'name': 'Egg'}, +# {'id': 187, 'name': 'Fire Extinguisher'}, +# {'id': 188, 'name': 'Candy'}, +# {'id': 189, 'name': 'Fire Truck'}, +# {'id': 190, 'name': 'Billards'}, +# {'id': 191, 'name': 'Converter'}, +# {'id': 192, 'name': 'Bathtub'}, +# {'id': 193, 'name': 'Wheelchair'}, +# {'id': 194, 'name': 'Golf Club'}, +# {'id': 195, 'name': 'Briefcase'}, +# {'id': 196, 'name': 'Cucumber'}, +# {'id': 197, 'name': 'Cigar/Cigarette '}, +# {'id': 198, 'name': 'Paint Brush'}, +# {'id': 199, 'name': 'Pear'}, +# {'id': 200, 'name': 'Heavy Truck'}, +# {'id': 201, 'name': 'Hamburger'}, +# {'id': 202, 'name': 'Extractor'}, +# {'id': 203, 'name': 'Extention Cord'}, +# {'id': 204, 'name': 'Tong'}, +# {'id': 205, 'name': 'Tennis Racket'}, +# {'id': 206, 'name': 'Folder'}, +# {'id': 207, 'name': 'American Football'}, +# {'id': 208, 'name': 'earphone'}, +# {'id': 209, 'name': 'Mask'}, +# {'id': 210, 'name': 'Kettle'}, +# {'id': 211, 'name': 'Tennis'}, +# {'id': 212, 'name': 'Ship'}, +# {'id': 213, 'name': 'Swing'}, +# {'id': 214, 'name': 'Coffee Machine'}, +# {'id': 215, 'name': 'Slide'}, +# {'id': 216, 'name': 'Carriage'}, +# {'id': 217, 'name': 'Onion'}, +# {'id': 218, 'name': 'Green beans'}, +# {'id': 219, 'name': 'Projector'}, +# {'id': 220, 'name': 'Frisbee'}, +# {'id': 221, 'name': 'Washing Machine/Drying Machine'}, +# {'id': 222, 'name': 'Chicken'}, +# {'id': 223, 'name': 'Printer'}, +# {'id': 224, 'name': 'Watermelon'}, +# {'id': 225, 'name': 'Saxophone'}, +# {'id': 226, 'name': 'Tissue'}, +# {'id': 227, 'name': 'Toothbrush'}, +# {'id': 228, 'name': 'Ice cream'}, +# {'id': 229, 'name': 'Hotair ballon'}, +# {'id': 230, 'name': 'Cello'}, +# {'id': 231, 'name': 'French Fries'}, +# {'id': 232, 'name': 'Scale'}, +# {'id': 233, 'name': 'Trophy'}, +# {'id': 234, 'name': 'Cabbage'}, +# {'id': 235, 'name': 'Hot dog'}, +# {'id': 236, 'name': 'Blender'}, +# {'id': 237, 'name': 'Peach'}, +# {'id': 238, 'name': 'Rice'}, +# {'id': 239, 'name': 'Wallet/Purse'}, +# {'id': 240, 'name': 'Volleyball'}, +# {'id': 241, 'name': 'Deer'}, +# {'id': 242, 'name': 'Goose'}, +# {'id': 243, 'name': 'Tape'}, +# {'id': 244, 'name': 'Tablet'}, +# {'id': 245, 'name': 'Cosmetics'}, +# {'id': 246, 'name': 'Trumpet'}, +# {'id': 247, 'name': 'Pineapple'}, +# {'id': 248, 'name': 'Golf Ball'}, +# {'id': 249, 'name': 'Ambulance'}, +# {'id': 250, 'name': 'Parking meter'}, +# {'id': 251, 'name': 'Mango'}, +# {'id': 252, 'name': 'Key'}, +# {'id': 253, 'name': 'Hurdle'}, +# {'id': 254, 'name': 'Fishing Rod'}, +# {'id': 255, 'name': 'Medal'}, +# {'id': 256, 'name': 'Flute'}, +# {'id': 257, 'name': 'Brush'}, +# {'id': 258, 'name': 'Penguin'}, +# {'id': 259, 'name': 'Megaphone'}, +# {'id': 260, 'name': 'Corn'}, +# {'id': 261, 'name': 'Lettuce'}, +# {'id': 262, 'name': 'Garlic'}, +# {'id': 263, 'name': 'Swan'}, +# {'id': 264, 'name': 'Helicopter'}, +# {'id': 265, 'name': 'Green Onion'}, +# {'id': 266, 'name': 'Sandwich'}, +# {'id': 267, 'name': 'Nuts'}, +# {'id': 268, 'name': 'Speed Limit Sign'}, +# {'id': 269, 'name': 'Induction Cooker'}, +# {'id': 270, 'name': 'Broom'}, +# {'id': 271, 'name': 'Trombone'}, +# {'id': 272, 'name': 'Plum'}, +# {'id': 273, 'name': 'Rickshaw'}, +# {'id': 274, 'name': 'Goldfish'}, +# {'id': 275, 'name': 'Kiwi fruit'}, +# {'id': 276, 'name': 'Router/modem'}, +# {'id': 277, 'name': 'Poker Card'}, +# {'id': 278, 'name': 'Toaster'}, +# {'id': 279, 'name': 'Shrimp'}, +# {'id': 280, 'name': 'Sushi'}, +# {'id': 281, 'name': 'Cheese'}, +# {'id': 282, 'name': 'Notepaper'}, +# {'id': 283, 'name': 'Cherry'}, +# {'id': 284, 'name': 'Pliers'}, +# {'id': 285, 'name': 'CD'}, +# {'id': 286, 'name': 'Pasta'}, +# {'id': 287, 'name': 'Hammer'}, +# {'id': 288, 'name': 'Cue'}, +# {'id': 289, 'name': 'Avocado'}, +# {'id': 290, 'name': 'Hamimelon'}, +# {'id': 291, 'name': 'Flask'}, +# {'id': 292, 'name': 'Mushroon'}, +# {'id': 293, 'name': 'Screwdriver'}, +# {'id': 294, 'name': 'Soap'}, +# {'id': 295, 'name': 'Recorder'}, +# {'id': 296, 'name': 'Bear'}, +# {'id': 297, 'name': 'Eggplant'}, +# {'id': 298, 'name': 'Board Eraser'}, +# {'id': 299, 'name': 'Coconut'}, +# {'id': 300, 'name': 'Tape Measur/ Ruler'}, +# {'id': 301, 'name': 'Pig'}, +# {'id': 302, 'name': 'Showerhead'}, +# {'id': 303, 'name': 'Globe'}, +# {'id': 304, 'name': 'Chips'}, +# {'id': 305, 'name': 'Steak'}, +# {'id': 306, 'name': 'Crosswalk Sign'}, +# {'id': 307, 'name': 'Stapler'}, +# {'id': 308, 'name': 'Campel'}, +# {'id': 309, 'name': 'Formula 1 '}, +# {'id': 310, 'name': 'Pomegranate'}, +# {'id': 311, 'name': 'Dishwasher'}, +# {'id': 312, 'name': 'Crab'}, +# {'id': 313, 'name': 'Hoverboard'}, +# {'id': 314, 'name': 'Meat ball'}, +# {'id': 315, 'name': 'Rice Cooker'}, +# {'id': 316, 'name': 'Tuba'}, +# {'id': 317, 'name': 'Calculator'}, +# {'id': 318, 'name': 'Papaya'}, +# {'id': 319, 'name': 'Antelope'}, +# {'id': 320, 'name': 'Parrot'}, +# {'id': 321, 'name': 'Seal'}, +# {'id': 322, 'name': 'Buttefly'}, +# {'id': 323, 'name': 'Dumbbell'}, +# {'id': 324, 'name': 'Donkey'}, +# {'id': 325, 'name': 'Lion'}, +# {'id': 326, 'name': 'Urinal'}, +# {'id': 327, 'name': 'Dolphin'}, +# {'id': 328, 'name': 'Electric Drill'}, +# {'id': 329, 'name': 'Hair Dryer'}, +# {'id': 330, 'name': 'Egg tart'}, +# {'id': 331, 'name': 'Jellyfish'}, +# {'id': 332, 'name': 'Treadmill'}, +# {'id': 333, 'name': 'Lighter'}, +# {'id': 334, 'name': 'Grapefruit'}, +# {'id': 335, 'name': 'Game board'}, +# {'id': 336, 'name': 'Mop'}, +# {'id': 337, 'name': 'Radish'}, +# {'id': 338, 'name': 'Baozi'}, +# {'id': 339, 'name': 'Target'}, +# {'id': 340, 'name': 'French'}, +# {'id': 341, 'name': 'Spring Rolls'}, +# {'id': 342, 'name': 'Monkey'}, +# {'id': 343, 'name': 'Rabbit'}, +# {'id': 344, 'name': 'Pencil Case'}, +# {'id': 345, 'name': 'Yak'}, +# {'id': 346, 'name': 'Red Cabbage'}, +# {'id': 347, 'name': 'Binoculars'}, +# {'id': 348, 'name': 'Asparagus'}, +# {'id': 349, 'name': 'Barbell'}, +# {'id': 350, 'name': 'Scallop'}, +# {'id': 351, 'name': 'Noddles'}, +# {'id': 352, 'name': 'Comb'}, +# {'id': 353, 'name': 'Dumpling'}, +# {'id': 354, 'name': 'Oyster'}, +# {'id': 355, 'name': 'Table Teniis paddle'}, +# {'id': 356, 'name': 'Cosmetics Brush/Eyeliner Pencil'}, +# {'id': 357, 'name': 'Chainsaw'}, +# {'id': 358, 'name': 'Eraser'}, +# {'id': 359, 'name': 'Lobster'}, +# {'id': 360, 'name': 'Durian'}, +# {'id': 361, 'name': 'Okra'}, +# {'id': 362, 'name': 'Lipstick'}, +# {'id': 363, 'name': 'Cosmetics Mirror'}, +# {'id': 364, 'name': 'Curling'}, +# {'id': 365, 'name': 'Table Tennis '}, +# ] + +""" +The official Objects365 category names contains typos. +Below is a manual fix. +""" +categories_v2_fix = [ + {"id": 1, "name": "Person"}, + {"id": 2, "name": "Sneakers"}, + {"id": 3, "name": "Chair"}, + {"id": 4, "name": "Other Shoes"}, + {"id": 5, "name": "Hat"}, + {"id": 6, "name": "Car"}, + {"id": 7, "name": "Lamp"}, + {"id": 8, "name": "Glasses"}, + {"id": 9, "name": "Bottle"}, + {"id": 10, "name": "Desk"}, + {"id": 11, "name": "Cup"}, + {"id": 12, "name": "Street Lights"}, + {"id": 13, "name": "Cabinet/shelf"}, + {"id": 14, "name": "Handbag/Satchel"}, + {"id": 15, "name": "Bracelet"}, + {"id": 16, "name": "Plate"}, + {"id": 17, "name": "Picture/Frame"}, + {"id": 18, "name": "Helmet"}, + {"id": 19, "name": "Book"}, + {"id": 20, "name": "Gloves"}, + {"id": 21, "name": "Storage box"}, + {"id": 22, "name": "Boat"}, + {"id": 23, "name": "Leather Shoes"}, + {"id": 24, "name": "Flower"}, + {"id": 25, "name": "Bench"}, + {"id": 26, "name": "Potted Plant"}, + {"id": 27, "name": "Bowl/Basin"}, + {"id": 28, "name": "Flag"}, + {"id": 29, "name": "Pillow"}, + {"id": 30, "name": "Boots"}, + {"id": 31, "name": "Vase"}, + {"id": 32, "name": "Microphone"}, + {"id": 33, "name": "Necklace"}, + {"id": 34, "name": "Ring"}, + {"id": 35, "name": "SUV"}, + {"id": 36, "name": "Wine Glass"}, + {"id": 37, "name": "Belt"}, + {"id": 38, "name": "Monitor/TV"}, + {"id": 39, "name": "Backpack"}, + {"id": 40, "name": "Umbrella"}, + {"id": 41, "name": "Traffic Light"}, + {"id": 42, "name": "Speaker"}, + {"id": 43, "name": "Watch"}, + {"id": 44, "name": "Tie"}, + {"id": 45, "name": "Trash bin Can"}, + {"id": 46, "name": "Slippers"}, + {"id": 47, "name": "Bicycle"}, + {"id": 48, "name": "Stool"}, + {"id": 49, "name": "Barrel/bucket"}, + {"id": 50, "name": "Van"}, + {"id": 51, "name": "Couch"}, + {"id": 52, "name": "Sandals"}, + {"id": 53, "name": "Basket"}, + {"id": 54, "name": "Drum"}, + {"id": 55, "name": "Pen/Pencil"}, + {"id": 56, "name": "Bus"}, + {"id": 57, "name": "Wild Bird"}, + {"id": 58, "name": "High Heels"}, + {"id": 59, "name": "Motorcycle"}, + {"id": 60, "name": "Guitar"}, + {"id": 61, "name": "Carpet"}, + {"id": 62, "name": "Cell Phone"}, + {"id": 63, "name": "Bread"}, + {"id": 64, "name": "Camera"}, + {"id": 65, "name": "Canned"}, + {"id": 66, "name": "Truck"}, + {"id": 67, "name": "Traffic cone"}, + {"id": 68, "name": "Cymbal"}, + {"id": 69, "name": "Lifesaver"}, + {"id": 70, "name": "Towel"}, + {"id": 71, "name": "Stuffed Toy"}, + {"id": 72, "name": "Candle"}, + {"id": 73, "name": "Sailboat"}, + {"id": 74, "name": "Laptop"}, + {"id": 75, "name": "Awning"}, + {"id": 76, "name": "Bed"}, + {"id": 77, "name": "Faucet"}, + {"id": 78, "name": "Tent"}, + {"id": 79, "name": "Horse"}, + {"id": 80, "name": "Mirror"}, + {"id": 81, "name": "Power outlet"}, + {"id": 82, "name": "Sink"}, + {"id": 83, "name": "Apple"}, + {"id": 84, "name": "Air Conditioner"}, + {"id": 85, "name": "Knife"}, + {"id": 86, "name": "Hockey Stick"}, + {"id": 87, "name": "Paddle"}, + {"id": 88, "name": "Pickup Truck"}, + {"id": 89, "name": "Fork"}, + {"id": 90, "name": "Traffic Sign"}, + {"id": 91, "name": "Ballon"}, + {"id": 92, "name": "Tripod"}, + {"id": 93, "name": "Dog"}, + {"id": 94, "name": "Spoon"}, + {"id": 95, "name": "Clock"}, + {"id": 96, "name": "Pot"}, + {"id": 97, "name": "Cow"}, + {"id": 98, "name": "Cake"}, + {"id": 99, "name": "Dining Table"}, + {"id": 100, "name": "Sheep"}, + {"id": 101, "name": "Hanger"}, + {"id": 102, "name": "Blackboard/Whiteboard"}, + {"id": 103, "name": "Napkin"}, + {"id": 104, "name": "Other Fish"}, + {"id": 105, "name": "Orange/Tangerine"}, + {"id": 106, "name": "Toiletry"}, + {"id": 107, "name": "Keyboard"}, + {"id": 108, "name": "Tomato"}, + {"id": 109, "name": "Lantern"}, + {"id": 110, "name": "Machinery Vehicle"}, + {"id": 111, "name": "Fan"}, + {"id": 112, "name": "Green Vegetables"}, + {"id": 113, "name": "Banana"}, + {"id": 114, "name": "Baseball Glove"}, + {"id": 115, "name": "Airplane"}, + {"id": 116, "name": "Mouse"}, + {"id": 117, "name": "Train"}, + {"id": 118, "name": "Pumpkin"}, + {"id": 119, "name": "Soccer"}, + {"id": 120, "name": "Skiboard"}, + {"id": 121, "name": "Luggage"}, + {"id": 122, "name": "Nightstand"}, + {"id": 123, "name": "Teapot"}, + {"id": 124, "name": "Telephone"}, + {"id": 125, "name": "Trolley"}, + {"id": 126, "name": "Head Phone"}, + {"id": 127, "name": "Sports Car"}, + {"id": 128, "name": "Stop Sign"}, + {"id": 129, "name": "Dessert"}, + {"id": 130, "name": "Scooter"}, + {"id": 131, "name": "Stroller"}, + {"id": 132, "name": "Crane"}, + {"id": 133, "name": "Remote"}, + {"id": 134, "name": "Refrigerator"}, + {"id": 135, "name": "Oven"}, + {"id": 136, "name": "Lemon"}, + {"id": 137, "name": "Duck"}, + {"id": 138, "name": "Baseball Bat"}, + {"id": 139, "name": "Surveillance Camera"}, + {"id": 140, "name": "Cat"}, + {"id": 141, "name": "Jug"}, + {"id": 142, "name": "Broccoli"}, + {"id": 143, "name": "Piano"}, + {"id": 144, "name": "Pizza"}, + {"id": 145, "name": "Elephant"}, + {"id": 146, "name": "Skateboard"}, + {"id": 147, "name": "Surfboard"}, + {"id": 148, "name": "Gun"}, + {"id": 149, "name": "Skating and Skiing shoes"}, + {"id": 150, "name": "Gas stove"}, + {"id": 151, "name": "Donut"}, + {"id": 152, "name": "Bow Tie"}, + {"id": 153, "name": "Carrot"}, + {"id": 154, "name": "Toilet"}, + {"id": 155, "name": "Kite"}, + {"id": 156, "name": "Strawberry"}, + {"id": 157, "name": "Other Balls"}, + {"id": 158, "name": "Shovel"}, + {"id": 159, "name": "Pepper"}, + {"id": 160, "name": "Computer Box"}, + {"id": 161, "name": "Toilet Paper"}, + {"id": 162, "name": "Cleaning Products"}, + {"id": 163, "name": "Chopsticks"}, + {"id": 164, "name": "Microwave"}, + {"id": 165, "name": "Pigeon"}, + {"id": 166, "name": "Baseball"}, + {"id": 167, "name": "Cutting/chopping Board"}, + {"id": 168, "name": "Coffee Table"}, + {"id": 169, "name": "Side Table"}, + {"id": 170, "name": "Scissors"}, + {"id": 171, "name": "Marker"}, + {"id": 172, "name": "Pie"}, + {"id": 173, "name": "Ladder"}, + {"id": 174, "name": "Snowboard"}, + {"id": 175, "name": "Cookies"}, + {"id": 176, "name": "Radiator"}, + {"id": 177, "name": "Fire Hydrant"}, + {"id": 178, "name": "Basketball"}, + {"id": 179, "name": "Zebra"}, + {"id": 180, "name": "Grape"}, + {"id": 181, "name": "Giraffe"}, + {"id": 182, "name": "Potato"}, + {"id": 183, "name": "Sausage"}, + {"id": 184, "name": "Tricycle"}, + {"id": 185, "name": "Violin"}, + {"id": 186, "name": "Egg"}, + {"id": 187, "name": "Fire Extinguisher"}, + {"id": 188, "name": "Candy"}, + {"id": 189, "name": "Fire Truck"}, + {"id": 190, "name": "Billards"}, + {"id": 191, "name": "Converter"}, + {"id": 192, "name": "Bathtub"}, + {"id": 193, "name": "Wheelchair"}, + {"id": 194, "name": "Golf Club"}, + {"id": 195, "name": "Briefcase"}, + {"id": 196, "name": "Cucumber"}, + {"id": 197, "name": "Cigar/Cigarette "}, + {"id": 198, "name": "Paint Brush"}, + {"id": 199, "name": "Pear"}, + {"id": 200, "name": "Heavy Truck"}, + {"id": 201, "name": "Hamburger"}, + {"id": 202, "name": "Extractor"}, + {"id": 203, "name": "Extension Cord"}, + {"id": 204, "name": "Tong"}, + {"id": 205, "name": "Tennis Racket"}, + {"id": 206, "name": "Folder"}, + {"id": 207, "name": "American Football"}, + {"id": 208, "name": "earphone"}, + {"id": 209, "name": "Mask"}, + {"id": 210, "name": "Kettle"}, + {"id": 211, "name": "Tennis"}, + {"id": 212, "name": "Ship"}, + {"id": 213, "name": "Swing"}, + {"id": 214, "name": "Coffee Machine"}, + {"id": 215, "name": "Slide"}, + {"id": 216, "name": "Carriage"}, + {"id": 217, "name": "Onion"}, + {"id": 218, "name": "Green beans"}, + {"id": 219, "name": "Projector"}, + {"id": 220, "name": "Frisbee"}, + {"id": 221, "name": "Washing Machine/Drying Machine"}, + {"id": 222, "name": "Chicken"}, + {"id": 223, "name": "Printer"}, + {"id": 224, "name": "Watermelon"}, + {"id": 225, "name": "Saxophone"}, + {"id": 226, "name": "Tissue"}, + {"id": 227, "name": "Toothbrush"}, + {"id": 228, "name": "Ice cream"}, + {"id": 229, "name": "Hot air balloon"}, + {"id": 230, "name": "Cello"}, + {"id": 231, "name": "French Fries"}, + {"id": 232, "name": "Scale"}, + {"id": 233, "name": "Trophy"}, + {"id": 234, "name": "Cabbage"}, + {"id": 235, "name": "Hot dog"}, + {"id": 236, "name": "Blender"}, + {"id": 237, "name": "Peach"}, + {"id": 238, "name": "Rice"}, + {"id": 239, "name": "Wallet/Purse"}, + {"id": 240, "name": "Volleyball"}, + {"id": 241, "name": "Deer"}, + {"id": 242, "name": "Goose"}, + {"id": 243, "name": "Tape"}, + {"id": 244, "name": "Tablet"}, + {"id": 245, "name": "Cosmetics"}, + {"id": 246, "name": "Trumpet"}, + {"id": 247, "name": "Pineapple"}, + {"id": 248, "name": "Golf Ball"}, + {"id": 249, "name": "Ambulance"}, + {"id": 250, "name": "Parking meter"}, + {"id": 251, "name": "Mango"}, + {"id": 252, "name": "Key"}, + {"id": 253, "name": "Hurdle"}, + {"id": 254, "name": "Fishing Rod"}, + {"id": 255, "name": "Medal"}, + {"id": 256, "name": "Flute"}, + {"id": 257, "name": "Brush"}, + {"id": 258, "name": "Penguin"}, + {"id": 259, "name": "Megaphone"}, + {"id": 260, "name": "Corn"}, + {"id": 261, "name": "Lettuce"}, + {"id": 262, "name": "Garlic"}, + {"id": 263, "name": "Swan"}, + {"id": 264, "name": "Helicopter"}, + {"id": 265, "name": "Green Onion"}, + {"id": 266, "name": "Sandwich"}, + {"id": 267, "name": "Nuts"}, + {"id": 268, "name": "Speed Limit Sign"}, + {"id": 269, "name": "Induction Cooker"}, + {"id": 270, "name": "Broom"}, + {"id": 271, "name": "Trombone"}, + {"id": 272, "name": "Plum"}, + {"id": 273, "name": "Rickshaw"}, + {"id": 274, "name": "Goldfish"}, + {"id": 275, "name": "Kiwi fruit"}, + {"id": 276, "name": "Router/modem"}, + {"id": 277, "name": "Poker Card"}, + {"id": 278, "name": "Toaster"}, + {"id": 279, "name": "Shrimp"}, + {"id": 280, "name": "Sushi"}, + {"id": 281, "name": "Cheese"}, + {"id": 282, "name": "Notepaper"}, + {"id": 283, "name": "Cherry"}, + {"id": 284, "name": "Pliers"}, + {"id": 285, "name": "CD"}, + {"id": 286, "name": "Pasta"}, + {"id": 287, "name": "Hammer"}, + {"id": 288, "name": "Cue"}, + {"id": 289, "name": "Avocado"}, + {"id": 290, "name": "Hami melon"}, + {"id": 291, "name": "Flask"}, + {"id": 292, "name": "Mushroom"}, + {"id": 293, "name": "Screwdriver"}, + {"id": 294, "name": "Soap"}, + {"id": 295, "name": "Recorder"}, + {"id": 296, "name": "Bear"}, + {"id": 297, "name": "Eggplant"}, + {"id": 298, "name": "Board Eraser"}, + {"id": 299, "name": "Coconut"}, + {"id": 300, "name": "Tape Measure/ Ruler"}, + {"id": 301, "name": "Pig"}, + {"id": 302, "name": "Showerhead"}, + {"id": 303, "name": "Globe"}, + {"id": 304, "name": "Chips"}, + {"id": 305, "name": "Steak"}, + {"id": 306, "name": "Crosswalk Sign"}, + {"id": 307, "name": "Stapler"}, + {"id": 308, "name": "Camel"}, + {"id": 309, "name": "Formula 1 "}, + {"id": 310, "name": "Pomegranate"}, + {"id": 311, "name": "Dishwasher"}, + {"id": 312, "name": "Crab"}, + {"id": 313, "name": "Hoverboard"}, + {"id": 314, "name": "Meatball"}, + {"id": 315, "name": "Rice Cooker"}, + {"id": 316, "name": "Tuba"}, + {"id": 317, "name": "Calculator"}, + {"id": 318, "name": "Papaya"}, + {"id": 319, "name": "Antelope"}, + {"id": 320, "name": "Parrot"}, + {"id": 321, "name": "Seal"}, + {"id": 322, "name": "Butterfly"}, + {"id": 323, "name": "Dumbbell"}, + {"id": 324, "name": "Donkey"}, + {"id": 325, "name": "Lion"}, + {"id": 326, "name": "Urinal"}, + {"id": 327, "name": "Dolphin"}, + {"id": 328, "name": "Electric Drill"}, + {"id": 329, "name": "Hair Dryer"}, + {"id": 330, "name": "Egg tart"}, + {"id": 331, "name": "Jellyfish"}, + {"id": 332, "name": "Treadmill"}, + {"id": 333, "name": "Lighter"}, + {"id": 334, "name": "Grapefruit"}, + {"id": 335, "name": "Game board"}, + {"id": 336, "name": "Mop"}, + {"id": 337, "name": "Radish"}, + {"id": 338, "name": "Baozi"}, + {"id": 339, "name": "Target"}, + {"id": 340, "name": "French"}, + {"id": 341, "name": "Spring Rolls"}, + {"id": 342, "name": "Monkey"}, + {"id": 343, "name": "Rabbit"}, + {"id": 344, "name": "Pencil Case"}, + {"id": 345, "name": "Yak"}, + {"id": 346, "name": "Red Cabbage"}, + {"id": 347, "name": "Binoculars"}, + {"id": 348, "name": "Asparagus"}, + {"id": 349, "name": "Barbell"}, + {"id": 350, "name": "Scallop"}, + {"id": 351, "name": "Noddles"}, + {"id": 352, "name": "Comb"}, + {"id": 353, "name": "Dumpling"}, + {"id": 354, "name": "Oyster"}, + {"id": 355, "name": "Table Tennis paddle"}, + {"id": 356, "name": "Cosmetics Brush/Eyeliner Pencil"}, + {"id": 357, "name": "Chainsaw"}, + {"id": 358, "name": "Eraser"}, + {"id": 359, "name": "Lobster"}, + {"id": 360, "name": "Durian"}, + {"id": 361, "name": "Okra"}, + {"id": 362, "name": "Lipstick"}, + {"id": 363, "name": "Cosmetics Mirror"}, + {"id": 364, "name": "Curling"}, + {"id": 365, "name": "Table Tennis "}, +] + + +def _get_builtin_metadata(): + id_to_name = {x["id"]: x["name"] for x in categories_v2_fix} + thing_dataset_id_to_contiguous_id = { + x["id"]: i for i, x in enumerate(sorted(categories_v2_fix, key=lambda x: x["id"])) + } + thing_classes = [id_to_name[k] for k in sorted(id_to_name)] + return { + "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id, + "thing_classes": thing_classes, + } + + +_PREDEFINED_SPLITS_OBJECTS365 = { + "objects365_v2_train": ( + "objects365/train", + "objects365/annotations/zhiyuan_objv2_train_fixname_fixmiss.json", + ), + # 80,000 images, 1,240,587 annotations + "objects365_v2_val": ( + "objects365/val", + "objects365/annotations/zhiyuan_objv2_val_fixname.json", + ), + "objects365_v2_val_rare": ( + "objects365/val", + "objects365/annotations/zhiyuan_objv2_val_fixname_rare.json", + ), +} + +for key, (image_root, json_file) in _PREDEFINED_SPLITS_OBJECTS365.items(): + register_coco_instances( + key, + _get_builtin_metadata(), + os.path.join("datasets", json_file) if "://" not in json_file else json_file, + os.path.join("datasets", image_root), + ) diff --git a/dimos/models/Detic/detic/data/datasets/oid.py b/dimos/models/Detic/detic/data/datasets/oid.py new file mode 100644 index 0000000000..0308a8da1d --- /dev/null +++ b/dimos/models/Detic/detic/data/datasets/oid.py @@ -0,0 +1,544 @@ +# Part of the code is from https://github.com/xingyizhou/UniDet/blob/master/projects/UniDet/unidet/data/datasets/oid.py +# Copyright (c) Facebook, Inc. and its affiliates. +import os + +from .register_oid import register_oid_instances + +categories = [ + {"id": 1, "name": "Infant bed", "freebase_id": "/m/061hd_"}, + {"id": 2, "name": "Rose", "freebase_id": "/m/06m11"}, + {"id": 3, "name": "Flag", "freebase_id": "/m/03120"}, + {"id": 4, "name": "Flashlight", "freebase_id": "/m/01kb5b"}, + {"id": 5, "name": "Sea turtle", "freebase_id": "/m/0120dh"}, + {"id": 6, "name": "Camera", "freebase_id": "/m/0dv5r"}, + {"id": 7, "name": "Animal", "freebase_id": "/m/0jbk"}, + {"id": 8, "name": "Glove", "freebase_id": "/m/0174n1"}, + {"id": 9, "name": "Crocodile", "freebase_id": "/m/09f_2"}, + {"id": 10, "name": "Cattle", "freebase_id": "/m/01xq0k1"}, + {"id": 11, "name": "House", "freebase_id": "/m/03jm5"}, + {"id": 12, "name": "Guacamole", "freebase_id": "/m/02g30s"}, + {"id": 13, "name": "Penguin", "freebase_id": "/m/05z6w"}, + {"id": 14, "name": "Vehicle registration plate", "freebase_id": "/m/01jfm_"}, + {"id": 15, "name": "Bench", "freebase_id": "/m/076lb9"}, + {"id": 16, "name": "Ladybug", "freebase_id": "/m/0gj37"}, + {"id": 17, "name": "Human nose", "freebase_id": "/m/0k0pj"}, + {"id": 18, "name": "Watermelon", "freebase_id": "/m/0kpqd"}, + {"id": 19, "name": "Flute", "freebase_id": "/m/0l14j_"}, + {"id": 20, "name": "Butterfly", "freebase_id": "/m/0cyf8"}, + {"id": 21, "name": "Washing machine", "freebase_id": "/m/0174k2"}, + {"id": 22, "name": "Raccoon", "freebase_id": "/m/0dq75"}, + {"id": 23, "name": "Segway", "freebase_id": "/m/076bq"}, + {"id": 24, "name": "Taco", "freebase_id": "/m/07crc"}, + {"id": 25, "name": "Jellyfish", "freebase_id": "/m/0d8zb"}, + {"id": 26, "name": "Cake", "freebase_id": "/m/0fszt"}, + {"id": 27, "name": "Pen", "freebase_id": "/m/0k1tl"}, + {"id": 28, "name": "Cannon", "freebase_id": "/m/020kz"}, + {"id": 29, "name": "Bread", "freebase_id": "/m/09728"}, + {"id": 30, "name": "Tree", "freebase_id": "/m/07j7r"}, + {"id": 31, "name": "Shellfish", "freebase_id": "/m/0fbdv"}, + {"id": 32, "name": "Bed", "freebase_id": "/m/03ssj5"}, + {"id": 33, "name": "Hamster", "freebase_id": "/m/03qrc"}, + {"id": 34, "name": "Hat", "freebase_id": "/m/02dl1y"}, + {"id": 35, "name": "Toaster", "freebase_id": "/m/01k6s3"}, + {"id": 36, "name": "Sombrero", "freebase_id": "/m/02jfl0"}, + {"id": 37, "name": "Tiara", "freebase_id": "/m/01krhy"}, + {"id": 38, "name": "Bowl", "freebase_id": "/m/04kkgm"}, + {"id": 39, "name": "Dragonfly", "freebase_id": "/m/0ft9s"}, + {"id": 40, "name": "Moths and butterflies", "freebase_id": "/m/0d_2m"}, + {"id": 41, "name": "Antelope", "freebase_id": "/m/0czz2"}, + {"id": 42, "name": "Vegetable", "freebase_id": "/m/0f4s2w"}, + {"id": 43, "name": "Torch", "freebase_id": "/m/07dd4"}, + {"id": 44, "name": "Building", "freebase_id": "/m/0cgh4"}, + {"id": 45, "name": "Power plugs and sockets", "freebase_id": "/m/03bbps"}, + {"id": 46, "name": "Blender", "freebase_id": "/m/02pjr4"}, + {"id": 47, "name": "Billiard table", "freebase_id": "/m/04p0qw"}, + {"id": 48, "name": "Cutting board", "freebase_id": "/m/02pdsw"}, + {"id": 49, "name": "Bronze sculpture", "freebase_id": "/m/01yx86"}, + {"id": 50, "name": "Turtle", "freebase_id": "/m/09dzg"}, + {"id": 51, "name": "Broccoli", "freebase_id": "/m/0hkxq"}, + {"id": 52, "name": "Tiger", "freebase_id": "/m/07dm6"}, + {"id": 53, "name": "Mirror", "freebase_id": "/m/054_l"}, + {"id": 54, "name": "Bear", "freebase_id": "/m/01dws"}, + {"id": 55, "name": "Zucchini", "freebase_id": "/m/027pcv"}, + {"id": 56, "name": "Dress", "freebase_id": "/m/01d40f"}, + {"id": 57, "name": "Volleyball", "freebase_id": "/m/02rgn06"}, + {"id": 58, "name": "Guitar", "freebase_id": "/m/0342h"}, + {"id": 59, "name": "Reptile", "freebase_id": "/m/06bt6"}, + {"id": 60, "name": "Golf cart", "freebase_id": "/m/0323sq"}, + {"id": 61, "name": "Tart", "freebase_id": "/m/02zvsm"}, + {"id": 62, "name": "Fedora", "freebase_id": "/m/02fq_6"}, + {"id": 63, "name": "Carnivore", "freebase_id": "/m/01lrl"}, + {"id": 64, "name": "Car", "freebase_id": "/m/0k4j"}, + {"id": 65, "name": "Lighthouse", "freebase_id": "/m/04h7h"}, + {"id": 66, "name": "Coffeemaker", "freebase_id": "/m/07xyvk"}, + {"id": 67, "name": "Food processor", "freebase_id": "/m/03y6mg"}, + {"id": 68, "name": "Truck", "freebase_id": "/m/07r04"}, + {"id": 69, "name": "Bookcase", "freebase_id": "/m/03__z0"}, + {"id": 70, "name": "Surfboard", "freebase_id": "/m/019w40"}, + {"id": 71, "name": "Footwear", "freebase_id": "/m/09j5n"}, + {"id": 72, "name": "Bench", "freebase_id": "/m/0cvnqh"}, + {"id": 73, "name": "Necklace", "freebase_id": "/m/01llwg"}, + {"id": 74, "name": "Flower", "freebase_id": "/m/0c9ph5"}, + {"id": 75, "name": "Radish", "freebase_id": "/m/015x5n"}, + {"id": 76, "name": "Marine mammal", "freebase_id": "/m/0gd2v"}, + {"id": 77, "name": "Frying pan", "freebase_id": "/m/04v6l4"}, + {"id": 78, "name": "Tap", "freebase_id": "/m/02jz0l"}, + {"id": 79, "name": "Peach", "freebase_id": "/m/0dj6p"}, + {"id": 80, "name": "Knife", "freebase_id": "/m/04ctx"}, + {"id": 81, "name": "Handbag", "freebase_id": "/m/080hkjn"}, + {"id": 82, "name": "Laptop", "freebase_id": "/m/01c648"}, + {"id": 83, "name": "Tent", "freebase_id": "/m/01j61q"}, + {"id": 84, "name": "Ambulance", "freebase_id": "/m/012n7d"}, + {"id": 85, "name": "Christmas tree", "freebase_id": "/m/025nd"}, + {"id": 86, "name": "Eagle", "freebase_id": "/m/09csl"}, + {"id": 87, "name": "Limousine", "freebase_id": "/m/01lcw4"}, + {"id": 88, "name": "Kitchen & dining room table", "freebase_id": "/m/0h8n5zk"}, + {"id": 89, "name": "Polar bear", "freebase_id": "/m/0633h"}, + {"id": 90, "name": "Tower", "freebase_id": "/m/01fdzj"}, + {"id": 91, "name": "Football", "freebase_id": "/m/01226z"}, + {"id": 92, "name": "Willow", "freebase_id": "/m/0mw_6"}, + {"id": 93, "name": "Human head", "freebase_id": "/m/04hgtk"}, + {"id": 94, "name": "Stop sign", "freebase_id": "/m/02pv19"}, + {"id": 95, "name": "Banana", "freebase_id": "/m/09qck"}, + {"id": 96, "name": "Mixer", "freebase_id": "/m/063rgb"}, + {"id": 97, "name": "Binoculars", "freebase_id": "/m/0lt4_"}, + {"id": 98, "name": "Dessert", "freebase_id": "/m/0270h"}, + {"id": 99, "name": "Bee", "freebase_id": "/m/01h3n"}, + {"id": 100, "name": "Chair", "freebase_id": "/m/01mzpv"}, + {"id": 101, "name": "Wood-burning stove", "freebase_id": "/m/04169hn"}, + {"id": 102, "name": "Flowerpot", "freebase_id": "/m/0fm3zh"}, + {"id": 103, "name": "Beaker", "freebase_id": "/m/0d20w4"}, + {"id": 104, "name": "Oyster", "freebase_id": "/m/0_cp5"}, + {"id": 105, "name": "Woodpecker", "freebase_id": "/m/01dy8n"}, + {"id": 106, "name": "Harp", "freebase_id": "/m/03m5k"}, + {"id": 107, "name": "Bathtub", "freebase_id": "/m/03dnzn"}, + {"id": 108, "name": "Wall clock", "freebase_id": "/m/0h8mzrc"}, + {"id": 109, "name": "Sports uniform", "freebase_id": "/m/0h8mhzd"}, + {"id": 110, "name": "Rhinoceros", "freebase_id": "/m/03d443"}, + {"id": 111, "name": "Beehive", "freebase_id": "/m/01gllr"}, + {"id": 112, "name": "Cupboard", "freebase_id": "/m/0642b4"}, + {"id": 113, "name": "Chicken", "freebase_id": "/m/09b5t"}, + {"id": 114, "name": "Man", "freebase_id": "/m/04yx4"}, + {"id": 115, "name": "Blue jay", "freebase_id": "/m/01f8m5"}, + {"id": 116, "name": "Cucumber", "freebase_id": "/m/015x4r"}, + {"id": 117, "name": "Balloon", "freebase_id": "/m/01j51"}, + {"id": 118, "name": "Kite", "freebase_id": "/m/02zt3"}, + {"id": 119, "name": "Fireplace", "freebase_id": "/m/03tw93"}, + {"id": 120, "name": "Lantern", "freebase_id": "/m/01jfsr"}, + {"id": 121, "name": "Missile", "freebase_id": "/m/04ylt"}, + {"id": 122, "name": "Book", "freebase_id": "/m/0bt_c3"}, + {"id": 123, "name": "Spoon", "freebase_id": "/m/0cmx8"}, + {"id": 124, "name": "Grapefruit", "freebase_id": "/m/0hqkz"}, + {"id": 125, "name": "Squirrel", "freebase_id": "/m/071qp"}, + {"id": 126, "name": "Orange", "freebase_id": "/m/0cyhj_"}, + {"id": 127, "name": "Coat", "freebase_id": "/m/01xygc"}, + {"id": 128, "name": "Punching bag", "freebase_id": "/m/0420v5"}, + {"id": 129, "name": "Zebra", "freebase_id": "/m/0898b"}, + {"id": 130, "name": "Billboard", "freebase_id": "/m/01knjb"}, + {"id": 131, "name": "Bicycle", "freebase_id": "/m/0199g"}, + {"id": 132, "name": "Door handle", "freebase_id": "/m/03c7gz"}, + {"id": 133, "name": "Mechanical fan", "freebase_id": "/m/02x984l"}, + {"id": 134, "name": "Ring binder", "freebase_id": "/m/04zwwv"}, + {"id": 135, "name": "Table", "freebase_id": "/m/04bcr3"}, + {"id": 136, "name": "Parrot", "freebase_id": "/m/0gv1x"}, + {"id": 137, "name": "Sock", "freebase_id": "/m/01nq26"}, + {"id": 138, "name": "Vase", "freebase_id": "/m/02s195"}, + {"id": 139, "name": "Weapon", "freebase_id": "/m/083kb"}, + {"id": 140, "name": "Shotgun", "freebase_id": "/m/06nrc"}, + {"id": 141, "name": "Glasses", "freebase_id": "/m/0jyfg"}, + {"id": 142, "name": "Seahorse", "freebase_id": "/m/0nybt"}, + {"id": 143, "name": "Belt", "freebase_id": "/m/0176mf"}, + {"id": 144, "name": "Watercraft", "freebase_id": "/m/01rzcn"}, + {"id": 145, "name": "Window", "freebase_id": "/m/0d4v4"}, + {"id": 146, "name": "Giraffe", "freebase_id": "/m/03bk1"}, + {"id": 147, "name": "Lion", "freebase_id": "/m/096mb"}, + {"id": 148, "name": "Tire", "freebase_id": "/m/0h9mv"}, + {"id": 149, "name": "Vehicle", "freebase_id": "/m/07yv9"}, + {"id": 150, "name": "Canoe", "freebase_id": "/m/0ph39"}, + {"id": 151, "name": "Tie", "freebase_id": "/m/01rkbr"}, + {"id": 152, "name": "Shelf", "freebase_id": "/m/0gjbg72"}, + {"id": 153, "name": "Picture frame", "freebase_id": "/m/06z37_"}, + {"id": 154, "name": "Printer", "freebase_id": "/m/01m4t"}, + {"id": 155, "name": "Human leg", "freebase_id": "/m/035r7c"}, + {"id": 156, "name": "Boat", "freebase_id": "/m/019jd"}, + {"id": 157, "name": "Slow cooker", "freebase_id": "/m/02tsc9"}, + {"id": 158, "name": "Croissant", "freebase_id": "/m/015wgc"}, + {"id": 159, "name": "Candle", "freebase_id": "/m/0c06p"}, + {"id": 160, "name": "Pancake", "freebase_id": "/m/01dwwc"}, + {"id": 161, "name": "Pillow", "freebase_id": "/m/034c16"}, + {"id": 162, "name": "Coin", "freebase_id": "/m/0242l"}, + {"id": 163, "name": "Stretcher", "freebase_id": "/m/02lbcq"}, + {"id": 164, "name": "Sandal", "freebase_id": "/m/03nfch"}, + {"id": 165, "name": "Woman", "freebase_id": "/m/03bt1vf"}, + {"id": 166, "name": "Stairs", "freebase_id": "/m/01lynh"}, + {"id": 167, "name": "Harpsichord", "freebase_id": "/m/03q5t"}, + {"id": 168, "name": "Stool", "freebase_id": "/m/0fqt361"}, + {"id": 169, "name": "Bus", "freebase_id": "/m/01bjv"}, + {"id": 170, "name": "Suitcase", "freebase_id": "/m/01s55n"}, + {"id": 171, "name": "Human mouth", "freebase_id": "/m/0283dt1"}, + {"id": 172, "name": "Juice", "freebase_id": "/m/01z1kdw"}, + {"id": 173, "name": "Skull", "freebase_id": "/m/016m2d"}, + {"id": 174, "name": "Door", "freebase_id": "/m/02dgv"}, + {"id": 175, "name": "Violin", "freebase_id": "/m/07y_7"}, + {"id": 176, "name": "Chopsticks", "freebase_id": "/m/01_5g"}, + {"id": 177, "name": "Digital clock", "freebase_id": "/m/06_72j"}, + {"id": 178, "name": "Sunflower", "freebase_id": "/m/0ftb8"}, + {"id": 179, "name": "Leopard", "freebase_id": "/m/0c29q"}, + {"id": 180, "name": "Bell pepper", "freebase_id": "/m/0jg57"}, + {"id": 181, "name": "Harbor seal", "freebase_id": "/m/02l8p9"}, + {"id": 182, "name": "Snake", "freebase_id": "/m/078jl"}, + {"id": 183, "name": "Sewing machine", "freebase_id": "/m/0llzx"}, + {"id": 184, "name": "Goose", "freebase_id": "/m/0dbvp"}, + {"id": 185, "name": "Helicopter", "freebase_id": "/m/09ct_"}, + {"id": 186, "name": "Seat belt", "freebase_id": "/m/0dkzw"}, + {"id": 187, "name": "Coffee cup", "freebase_id": "/m/02p5f1q"}, + {"id": 188, "name": "Microwave oven", "freebase_id": "/m/0fx9l"}, + {"id": 189, "name": "Hot dog", "freebase_id": "/m/01b9xk"}, + {"id": 190, "name": "Countertop", "freebase_id": "/m/0b3fp9"}, + {"id": 191, "name": "Serving tray", "freebase_id": "/m/0h8n27j"}, + {"id": 192, "name": "Dog bed", "freebase_id": "/m/0h8n6f9"}, + {"id": 193, "name": "Beer", "freebase_id": "/m/01599"}, + {"id": 194, "name": "Sunglasses", "freebase_id": "/m/017ftj"}, + {"id": 195, "name": "Golf ball", "freebase_id": "/m/044r5d"}, + {"id": 196, "name": "Waffle", "freebase_id": "/m/01dwsz"}, + {"id": 197, "name": "Palm tree", "freebase_id": "/m/0cdl1"}, + {"id": 198, "name": "Trumpet", "freebase_id": "/m/07gql"}, + {"id": 199, "name": "Ruler", "freebase_id": "/m/0hdln"}, + {"id": 200, "name": "Helmet", "freebase_id": "/m/0zvk5"}, + {"id": 201, "name": "Ladder", "freebase_id": "/m/012w5l"}, + {"id": 202, "name": "Office building", "freebase_id": "/m/021sj1"}, + {"id": 203, "name": "Tablet computer", "freebase_id": "/m/0bh9flk"}, + {"id": 204, "name": "Toilet paper", "freebase_id": "/m/09gtd"}, + {"id": 205, "name": "Pomegranate", "freebase_id": "/m/0jwn_"}, + {"id": 206, "name": "Skirt", "freebase_id": "/m/02wv6h6"}, + {"id": 207, "name": "Gas stove", "freebase_id": "/m/02wv84t"}, + {"id": 208, "name": "Cookie", "freebase_id": "/m/021mn"}, + {"id": 209, "name": "Cart", "freebase_id": "/m/018p4k"}, + {"id": 210, "name": "Raven", "freebase_id": "/m/06j2d"}, + {"id": 211, "name": "Egg", "freebase_id": "/m/033cnk"}, + {"id": 212, "name": "Burrito", "freebase_id": "/m/01j3zr"}, + {"id": 213, "name": "Goat", "freebase_id": "/m/03fwl"}, + {"id": 214, "name": "Kitchen knife", "freebase_id": "/m/058qzx"}, + {"id": 215, "name": "Skateboard", "freebase_id": "/m/06_fw"}, + {"id": 216, "name": "Salt and pepper shakers", "freebase_id": "/m/02x8cch"}, + {"id": 217, "name": "Lynx", "freebase_id": "/m/04g2r"}, + {"id": 218, "name": "Boot", "freebase_id": "/m/01b638"}, + {"id": 219, "name": "Platter", "freebase_id": "/m/099ssp"}, + {"id": 220, "name": "Ski", "freebase_id": "/m/071p9"}, + {"id": 221, "name": "Swimwear", "freebase_id": "/m/01gkx_"}, + {"id": 222, "name": "Swimming pool", "freebase_id": "/m/0b_rs"}, + {"id": 223, "name": "Drinking straw", "freebase_id": "/m/03v5tg"}, + {"id": 224, "name": "Wrench", "freebase_id": "/m/01j5ks"}, + {"id": 225, "name": "Drum", "freebase_id": "/m/026t6"}, + {"id": 226, "name": "Ant", "freebase_id": "/m/0_k2"}, + {"id": 227, "name": "Human ear", "freebase_id": "/m/039xj_"}, + {"id": 228, "name": "Headphones", "freebase_id": "/m/01b7fy"}, + {"id": 229, "name": "Fountain", "freebase_id": "/m/0220r2"}, + {"id": 230, "name": "Bird", "freebase_id": "/m/015p6"}, + {"id": 231, "name": "Jeans", "freebase_id": "/m/0fly7"}, + {"id": 232, "name": "Television", "freebase_id": "/m/07c52"}, + {"id": 233, "name": "Crab", "freebase_id": "/m/0n28_"}, + {"id": 234, "name": "Microphone", "freebase_id": "/m/0hg7b"}, + {"id": 235, "name": "Home appliance", "freebase_id": "/m/019dx1"}, + {"id": 236, "name": "Snowplow", "freebase_id": "/m/04vv5k"}, + {"id": 237, "name": "Beetle", "freebase_id": "/m/020jm"}, + {"id": 238, "name": "Artichoke", "freebase_id": "/m/047v4b"}, + {"id": 239, "name": "Jet ski", "freebase_id": "/m/01xs3r"}, + {"id": 240, "name": "Stationary bicycle", "freebase_id": "/m/03kt2w"}, + {"id": 241, "name": "Human hair", "freebase_id": "/m/03q69"}, + {"id": 242, "name": "Brown bear", "freebase_id": "/m/01dxs"}, + {"id": 243, "name": "Starfish", "freebase_id": "/m/01h8tj"}, + {"id": 244, "name": "Fork", "freebase_id": "/m/0dt3t"}, + {"id": 245, "name": "Lobster", "freebase_id": "/m/0cjq5"}, + {"id": 246, "name": "Corded phone", "freebase_id": "/m/0h8lkj8"}, + {"id": 247, "name": "Drink", "freebase_id": "/m/0271t"}, + {"id": 248, "name": "Saucer", "freebase_id": "/m/03q5c7"}, + {"id": 249, "name": "Carrot", "freebase_id": "/m/0fj52s"}, + {"id": 250, "name": "Insect", "freebase_id": "/m/03vt0"}, + {"id": 251, "name": "Clock", "freebase_id": "/m/01x3z"}, + {"id": 252, "name": "Castle", "freebase_id": "/m/0d5gx"}, + {"id": 253, "name": "Tennis racket", "freebase_id": "/m/0h8my_4"}, + {"id": 254, "name": "Ceiling fan", "freebase_id": "/m/03ldnb"}, + {"id": 255, "name": "Asparagus", "freebase_id": "/m/0cjs7"}, + {"id": 256, "name": "Jaguar", "freebase_id": "/m/0449p"}, + {"id": 257, "name": "Musical instrument", "freebase_id": "/m/04szw"}, + {"id": 258, "name": "Train", "freebase_id": "/m/07jdr"}, + {"id": 259, "name": "Cat", "freebase_id": "/m/01yrx"}, + {"id": 260, "name": "Rifle", "freebase_id": "/m/06c54"}, + {"id": 261, "name": "Dumbbell", "freebase_id": "/m/04h8sr"}, + {"id": 262, "name": "Mobile phone", "freebase_id": "/m/050k8"}, + {"id": 263, "name": "Taxi", "freebase_id": "/m/0pg52"}, + {"id": 264, "name": "Shower", "freebase_id": "/m/02f9f_"}, + {"id": 265, "name": "Pitcher", "freebase_id": "/m/054fyh"}, + {"id": 266, "name": "Lemon", "freebase_id": "/m/09k_b"}, + {"id": 267, "name": "Invertebrate", "freebase_id": "/m/03xxp"}, + {"id": 268, "name": "Turkey", "freebase_id": "/m/0jly1"}, + {"id": 269, "name": "High heels", "freebase_id": "/m/06k2mb"}, + {"id": 270, "name": "Bust", "freebase_id": "/m/04yqq2"}, + {"id": 271, "name": "Elephant", "freebase_id": "/m/0bwd_0j"}, + {"id": 272, "name": "Scarf", "freebase_id": "/m/02h19r"}, + {"id": 273, "name": "Barrel", "freebase_id": "/m/02zn6n"}, + {"id": 274, "name": "Trombone", "freebase_id": "/m/07c6l"}, + {"id": 275, "name": "Pumpkin", "freebase_id": "/m/05zsy"}, + {"id": 276, "name": "Box", "freebase_id": "/m/025dyy"}, + {"id": 277, "name": "Tomato", "freebase_id": "/m/07j87"}, + {"id": 278, "name": "Frog", "freebase_id": "/m/09ld4"}, + {"id": 279, "name": "Bidet", "freebase_id": "/m/01vbnl"}, + {"id": 280, "name": "Human face", "freebase_id": "/m/0dzct"}, + {"id": 281, "name": "Houseplant", "freebase_id": "/m/03fp41"}, + {"id": 282, "name": "Van", "freebase_id": "/m/0h2r6"}, + {"id": 283, "name": "Shark", "freebase_id": "/m/0by6g"}, + {"id": 284, "name": "Ice cream", "freebase_id": "/m/0cxn2"}, + {"id": 285, "name": "Swim cap", "freebase_id": "/m/04tn4x"}, + {"id": 286, "name": "Falcon", "freebase_id": "/m/0f6wt"}, + {"id": 287, "name": "Ostrich", "freebase_id": "/m/05n4y"}, + {"id": 288, "name": "Handgun", "freebase_id": "/m/0gxl3"}, + {"id": 289, "name": "Whiteboard", "freebase_id": "/m/02d9qx"}, + {"id": 290, "name": "Lizard", "freebase_id": "/m/04m9y"}, + {"id": 291, "name": "Pasta", "freebase_id": "/m/05z55"}, + {"id": 292, "name": "Snowmobile", "freebase_id": "/m/01x3jk"}, + {"id": 293, "name": "Light bulb", "freebase_id": "/m/0h8l4fh"}, + {"id": 294, "name": "Window blind", "freebase_id": "/m/031b6r"}, + {"id": 295, "name": "Muffin", "freebase_id": "/m/01tcjp"}, + {"id": 296, "name": "Pretzel", "freebase_id": "/m/01f91_"}, + {"id": 297, "name": "Computer monitor", "freebase_id": "/m/02522"}, + {"id": 298, "name": "Horn", "freebase_id": "/m/0319l"}, + {"id": 299, "name": "Furniture", "freebase_id": "/m/0c_jw"}, + {"id": 300, "name": "Sandwich", "freebase_id": "/m/0l515"}, + {"id": 301, "name": "Fox", "freebase_id": "/m/0306r"}, + {"id": 302, "name": "Convenience store", "freebase_id": "/m/0crjs"}, + {"id": 303, "name": "Fish", "freebase_id": "/m/0ch_cf"}, + {"id": 304, "name": "Fruit", "freebase_id": "/m/02xwb"}, + {"id": 305, "name": "Earrings", "freebase_id": "/m/01r546"}, + {"id": 306, "name": "Curtain", "freebase_id": "/m/03rszm"}, + {"id": 307, "name": "Grape", "freebase_id": "/m/0388q"}, + {"id": 308, "name": "Sofa bed", "freebase_id": "/m/03m3pdh"}, + {"id": 309, "name": "Horse", "freebase_id": "/m/03k3r"}, + {"id": 310, "name": "Luggage and bags", "freebase_id": "/m/0hf58v5"}, + {"id": 311, "name": "Desk", "freebase_id": "/m/01y9k5"}, + {"id": 312, "name": "Crutch", "freebase_id": "/m/05441v"}, + {"id": 313, "name": "Bicycle helmet", "freebase_id": "/m/03p3bw"}, + {"id": 314, "name": "Tick", "freebase_id": "/m/0175cv"}, + {"id": 315, "name": "Airplane", "freebase_id": "/m/0cmf2"}, + {"id": 316, "name": "Canary", "freebase_id": "/m/0ccs93"}, + {"id": 317, "name": "Spatula", "freebase_id": "/m/02d1br"}, + {"id": 318, "name": "Watch", "freebase_id": "/m/0gjkl"}, + {"id": 319, "name": "Lily", "freebase_id": "/m/0jqgx"}, + {"id": 320, "name": "Kitchen appliance", "freebase_id": "/m/0h99cwc"}, + {"id": 321, "name": "Filing cabinet", "freebase_id": "/m/047j0r"}, + {"id": 322, "name": "Aircraft", "freebase_id": "/m/0k5j"}, + {"id": 323, "name": "Cake stand", "freebase_id": "/m/0h8n6ft"}, + {"id": 324, "name": "Candy", "freebase_id": "/m/0gm28"}, + {"id": 325, "name": "Sink", "freebase_id": "/m/0130jx"}, + {"id": 326, "name": "Mouse", "freebase_id": "/m/04rmv"}, + {"id": 327, "name": "Wine", "freebase_id": "/m/081qc"}, + {"id": 328, "name": "Wheelchair", "freebase_id": "/m/0qmmr"}, + {"id": 329, "name": "Goldfish", "freebase_id": "/m/03fj2"}, + {"id": 330, "name": "Refrigerator", "freebase_id": "/m/040b_t"}, + {"id": 331, "name": "French fries", "freebase_id": "/m/02y6n"}, + {"id": 332, "name": "Drawer", "freebase_id": "/m/0fqfqc"}, + {"id": 333, "name": "Treadmill", "freebase_id": "/m/030610"}, + {"id": 334, "name": "Picnic basket", "freebase_id": "/m/07kng9"}, + {"id": 335, "name": "Dice", "freebase_id": "/m/029b3"}, + {"id": 336, "name": "Cabbage", "freebase_id": "/m/0fbw6"}, + {"id": 337, "name": "Football helmet", "freebase_id": "/m/07qxg_"}, + {"id": 338, "name": "Pig", "freebase_id": "/m/068zj"}, + {"id": 339, "name": "Person", "freebase_id": "/m/01g317"}, + {"id": 340, "name": "Shorts", "freebase_id": "/m/01bfm9"}, + {"id": 341, "name": "Gondola", "freebase_id": "/m/02068x"}, + {"id": 342, "name": "Honeycomb", "freebase_id": "/m/0fz0h"}, + {"id": 343, "name": "Doughnut", "freebase_id": "/m/0jy4k"}, + {"id": 344, "name": "Chest of drawers", "freebase_id": "/m/05kyg_"}, + {"id": 345, "name": "Land vehicle", "freebase_id": "/m/01prls"}, + {"id": 346, "name": "Bat", "freebase_id": "/m/01h44"}, + {"id": 347, "name": "Monkey", "freebase_id": "/m/08pbxl"}, + {"id": 348, "name": "Dagger", "freebase_id": "/m/02gzp"}, + {"id": 349, "name": "Tableware", "freebase_id": "/m/04brg2"}, + {"id": 350, "name": "Human foot", "freebase_id": "/m/031n1"}, + {"id": 351, "name": "Mug", "freebase_id": "/m/02jvh9"}, + {"id": 352, "name": "Alarm clock", "freebase_id": "/m/046dlr"}, + {"id": 353, "name": "Pressure cooker", "freebase_id": "/m/0h8ntjv"}, + {"id": 354, "name": "Human hand", "freebase_id": "/m/0k65p"}, + {"id": 355, "name": "Tortoise", "freebase_id": "/m/011k07"}, + {"id": 356, "name": "Baseball glove", "freebase_id": "/m/03grzl"}, + {"id": 357, "name": "Sword", "freebase_id": "/m/06y5r"}, + {"id": 358, "name": "Pear", "freebase_id": "/m/061_f"}, + {"id": 359, "name": "Miniskirt", "freebase_id": "/m/01cmb2"}, + {"id": 360, "name": "Traffic sign", "freebase_id": "/m/01mqdt"}, + {"id": 361, "name": "Girl", "freebase_id": "/m/05r655"}, + {"id": 362, "name": "Roller skates", "freebase_id": "/m/02p3w7d"}, + {"id": 363, "name": "Dinosaur", "freebase_id": "/m/029tx"}, + {"id": 364, "name": "Porch", "freebase_id": "/m/04m6gz"}, + {"id": 365, "name": "Human beard", "freebase_id": "/m/015h_t"}, + {"id": 366, "name": "Submarine sandwich", "freebase_id": "/m/06pcq"}, + {"id": 367, "name": "Screwdriver", "freebase_id": "/m/01bms0"}, + {"id": 368, "name": "Strawberry", "freebase_id": "/m/07fbm7"}, + {"id": 369, "name": "Wine glass", "freebase_id": "/m/09tvcd"}, + {"id": 370, "name": "Seafood", "freebase_id": "/m/06nwz"}, + {"id": 371, "name": "Racket", "freebase_id": "/m/0dv9c"}, + {"id": 372, "name": "Wheel", "freebase_id": "/m/083wq"}, + {"id": 373, "name": "Sea lion", "freebase_id": "/m/0gd36"}, + {"id": 374, "name": "Toy", "freebase_id": "/m/0138tl"}, + {"id": 375, "name": "Tea", "freebase_id": "/m/07clx"}, + {"id": 376, "name": "Tennis ball", "freebase_id": "/m/05ctyq"}, + {"id": 377, "name": "Waste container", "freebase_id": "/m/0bjyj5"}, + {"id": 378, "name": "Mule", "freebase_id": "/m/0dbzx"}, + {"id": 379, "name": "Cricket ball", "freebase_id": "/m/02ctlc"}, + {"id": 380, "name": "Pineapple", "freebase_id": "/m/0fp6w"}, + {"id": 381, "name": "Coconut", "freebase_id": "/m/0djtd"}, + {"id": 382, "name": "Doll", "freebase_id": "/m/0167gd"}, + {"id": 383, "name": "Coffee table", "freebase_id": "/m/078n6m"}, + {"id": 384, "name": "Snowman", "freebase_id": "/m/0152hh"}, + {"id": 385, "name": "Lavender", "freebase_id": "/m/04gth"}, + {"id": 386, "name": "Shrimp", "freebase_id": "/m/0ll1f78"}, + {"id": 387, "name": "Maple", "freebase_id": "/m/0cffdh"}, + {"id": 388, "name": "Cowboy hat", "freebase_id": "/m/025rp__"}, + {"id": 389, "name": "Goggles", "freebase_id": "/m/02_n6y"}, + {"id": 390, "name": "Rugby ball", "freebase_id": "/m/0wdt60w"}, + {"id": 391, "name": "Caterpillar", "freebase_id": "/m/0cydv"}, + {"id": 392, "name": "Poster", "freebase_id": "/m/01n5jq"}, + {"id": 393, "name": "Rocket", "freebase_id": "/m/09rvcxw"}, + {"id": 394, "name": "Organ", "freebase_id": "/m/013y1f"}, + {"id": 395, "name": "Saxophone", "freebase_id": "/m/06ncr"}, + {"id": 396, "name": "Traffic light", "freebase_id": "/m/015qff"}, + {"id": 397, "name": "Cocktail", "freebase_id": "/m/024g6"}, + {"id": 398, "name": "Plastic bag", "freebase_id": "/m/05gqfk"}, + {"id": 399, "name": "Squash", "freebase_id": "/m/0dv77"}, + {"id": 400, "name": "Mushroom", "freebase_id": "/m/052sf"}, + {"id": 401, "name": "Hamburger", "freebase_id": "/m/0cdn1"}, + {"id": 402, "name": "Light switch", "freebase_id": "/m/03jbxj"}, + {"id": 403, "name": "Parachute", "freebase_id": "/m/0cyfs"}, + {"id": 404, "name": "Teddy bear", "freebase_id": "/m/0kmg4"}, + {"id": 405, "name": "Winter melon", "freebase_id": "/m/02cvgx"}, + {"id": 406, "name": "Deer", "freebase_id": "/m/09kx5"}, + {"id": 407, "name": "Musical keyboard", "freebase_id": "/m/057cc"}, + {"id": 408, "name": "Plumbing fixture", "freebase_id": "/m/02pkr5"}, + {"id": 409, "name": "Scoreboard", "freebase_id": "/m/057p5t"}, + {"id": 410, "name": "Baseball bat", "freebase_id": "/m/03g8mr"}, + {"id": 411, "name": "Envelope", "freebase_id": "/m/0frqm"}, + {"id": 412, "name": "Adhesive tape", "freebase_id": "/m/03m3vtv"}, + {"id": 413, "name": "Briefcase", "freebase_id": "/m/0584n8"}, + {"id": 414, "name": "Paddle", "freebase_id": "/m/014y4n"}, + {"id": 415, "name": "Bow and arrow", "freebase_id": "/m/01g3x7"}, + {"id": 416, "name": "Telephone", "freebase_id": "/m/07cx4"}, + {"id": 417, "name": "Sheep", "freebase_id": "/m/07bgp"}, + {"id": 418, "name": "Jacket", "freebase_id": "/m/032b3c"}, + {"id": 419, "name": "Boy", "freebase_id": "/m/01bl7v"}, + {"id": 420, "name": "Pizza", "freebase_id": "/m/0663v"}, + {"id": 421, "name": "Otter", "freebase_id": "/m/0cn6p"}, + {"id": 422, "name": "Office supplies", "freebase_id": "/m/02rdsp"}, + {"id": 423, "name": "Couch", "freebase_id": "/m/02crq1"}, + {"id": 424, "name": "Cello", "freebase_id": "/m/01xqw"}, + {"id": 425, "name": "Bull", "freebase_id": "/m/0cnyhnx"}, + {"id": 426, "name": "Camel", "freebase_id": "/m/01x_v"}, + {"id": 427, "name": "Ball", "freebase_id": "/m/018xm"}, + {"id": 428, "name": "Duck", "freebase_id": "/m/09ddx"}, + {"id": 429, "name": "Whale", "freebase_id": "/m/084zz"}, + {"id": 430, "name": "Shirt", "freebase_id": "/m/01n4qj"}, + {"id": 431, "name": "Tank", "freebase_id": "/m/07cmd"}, + {"id": 432, "name": "Motorcycle", "freebase_id": "/m/04_sv"}, + {"id": 433, "name": "Accordion", "freebase_id": "/m/0mkg"}, + {"id": 434, "name": "Owl", "freebase_id": "/m/09d5_"}, + {"id": 435, "name": "Porcupine", "freebase_id": "/m/0c568"}, + {"id": 436, "name": "Sun hat", "freebase_id": "/m/02wbtzl"}, + {"id": 437, "name": "Nail", "freebase_id": "/m/05bm6"}, + {"id": 438, "name": "Scissors", "freebase_id": "/m/01lsmm"}, + {"id": 439, "name": "Swan", "freebase_id": "/m/0dftk"}, + {"id": 440, "name": "Lamp", "freebase_id": "/m/0dtln"}, + {"id": 441, "name": "Crown", "freebase_id": "/m/0nl46"}, + {"id": 442, "name": "Piano", "freebase_id": "/m/05r5c"}, + {"id": 443, "name": "Sculpture", "freebase_id": "/m/06msq"}, + {"id": 444, "name": "Cheetah", "freebase_id": "/m/0cd4d"}, + {"id": 445, "name": "Oboe", "freebase_id": "/m/05kms"}, + {"id": 446, "name": "Tin can", "freebase_id": "/m/02jnhm"}, + {"id": 447, "name": "Mango", "freebase_id": "/m/0fldg"}, + {"id": 448, "name": "Tripod", "freebase_id": "/m/073bxn"}, + {"id": 449, "name": "Oven", "freebase_id": "/m/029bxz"}, + {"id": 450, "name": "Mouse", "freebase_id": "/m/020lf"}, + {"id": 451, "name": "Barge", "freebase_id": "/m/01btn"}, + {"id": 452, "name": "Coffee", "freebase_id": "/m/02vqfm"}, + {"id": 453, "name": "Snowboard", "freebase_id": "/m/06__v"}, + {"id": 454, "name": "Common fig", "freebase_id": "/m/043nyj"}, + {"id": 455, "name": "Salad", "freebase_id": "/m/0grw1"}, + {"id": 456, "name": "Marine invertebrates", "freebase_id": "/m/03hl4l9"}, + {"id": 457, "name": "Umbrella", "freebase_id": "/m/0hnnb"}, + {"id": 458, "name": "Kangaroo", "freebase_id": "/m/04c0y"}, + {"id": 459, "name": "Human arm", "freebase_id": "/m/0dzf4"}, + {"id": 460, "name": "Measuring cup", "freebase_id": "/m/07v9_z"}, + {"id": 461, "name": "Snail", "freebase_id": "/m/0f9_l"}, + {"id": 462, "name": "Loveseat", "freebase_id": "/m/0703r8"}, + {"id": 463, "name": "Suit", "freebase_id": "/m/01xyhv"}, + {"id": 464, "name": "Teapot", "freebase_id": "/m/01fh4r"}, + {"id": 465, "name": "Bottle", "freebase_id": "/m/04dr76w"}, + {"id": 466, "name": "Alpaca", "freebase_id": "/m/0pcr"}, + {"id": 467, "name": "Kettle", "freebase_id": "/m/03s_tn"}, + {"id": 468, "name": "Trousers", "freebase_id": "/m/07mhn"}, + {"id": 469, "name": "Popcorn", "freebase_id": "/m/01hrv5"}, + {"id": 470, "name": "Centipede", "freebase_id": "/m/019h78"}, + {"id": 471, "name": "Spider", "freebase_id": "/m/09kmb"}, + {"id": 472, "name": "Sparrow", "freebase_id": "/m/0h23m"}, + {"id": 473, "name": "Plate", "freebase_id": "/m/050gv4"}, + {"id": 474, "name": "Bagel", "freebase_id": "/m/01fb_0"}, + {"id": 475, "name": "Personal care", "freebase_id": "/m/02w3_ws"}, + {"id": 476, "name": "Apple", "freebase_id": "/m/014j1m"}, + {"id": 477, "name": "Brassiere", "freebase_id": "/m/01gmv2"}, + {"id": 478, "name": "Bathroom cabinet", "freebase_id": "/m/04y4h8h"}, + {"id": 479, "name": "studio couch", "freebase_id": "/m/026qbn5"}, + {"id": 480, "name": "Computer keyboard", "freebase_id": "/m/01m2v"}, + {"id": 481, "name": "Table tennis racket", "freebase_id": "/m/05_5p_0"}, + {"id": 482, "name": "Sushi", "freebase_id": "/m/07030"}, + {"id": 483, "name": "Cabinetry", "freebase_id": "/m/01s105"}, + {"id": 484, "name": "Street light", "freebase_id": "/m/033rq4"}, + {"id": 485, "name": "Towel", "freebase_id": "/m/0162_1"}, + {"id": 486, "name": "Nightstand", "freebase_id": "/m/02z51p"}, + {"id": 487, "name": "Rabbit", "freebase_id": "/m/06mf6"}, + {"id": 488, "name": "Dolphin", "freebase_id": "/m/02hj4"}, + {"id": 489, "name": "Dog", "freebase_id": "/m/0bt9lr"}, + {"id": 490, "name": "Jug", "freebase_id": "/m/08hvt4"}, + {"id": 491, "name": "Wok", "freebase_id": "/m/084rd"}, + {"id": 492, "name": "Fire hydrant", "freebase_id": "/m/01pns0"}, + {"id": 493, "name": "Human eye", "freebase_id": "/m/014sv8"}, + {"id": 494, "name": "Skyscraper", "freebase_id": "/m/079cl"}, + {"id": 495, "name": "Backpack", "freebase_id": "/m/01940j"}, + {"id": 496, "name": "Potato", "freebase_id": "/m/05vtc"}, + {"id": 497, "name": "Paper towel", "freebase_id": "/m/02w3r3"}, + {"id": 498, "name": "Lifejacket", "freebase_id": "/m/054xkw"}, + {"id": 499, "name": "Bicycle wheel", "freebase_id": "/m/01bqk0"}, + {"id": 500, "name": "Toilet", "freebase_id": "/m/09g1w"}, +] + + +def _get_builtin_metadata(cats): + {x["id"]: x["name"] for x in cats} + thing_dataset_id_to_contiguous_id = {i + 1: i for i in range(len(cats))} + thing_classes = [x["name"] for x in sorted(cats, key=lambda x: x["id"])] + return { + "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id, + "thing_classes": thing_classes, + } + + +_PREDEFINED_SPLITS_OID = { + # cat threshold: 500, 1500: r 170, c 151, f 179 + "oid_train": ("oid/images/", "oid/annotations/oid_challenge_2019_train_bbox.json"), + # "expanded" duplicates annotations to their father classes based on the official + # hierarchy. This is used in the official evaulation protocol. + # https://storage.googleapis.com/openimages/web/evaluation.html + "oid_val_expanded": ( + "oid/images/validation/", + "oid/annotations/oid_challenge_2019_val_expanded.json", + ), + "oid_val_expanded_rare": ( + "oid/images/validation/", + "oid/annotations/oid_challenge_2019_val_expanded_rare.json", + ), +} + + +for key, (image_root, json_file) in _PREDEFINED_SPLITS_OID.items(): + register_oid_instances( + key, + _get_builtin_metadata(categories), + os.path.join("datasets", json_file) if "://" not in json_file else json_file, + os.path.join("datasets", image_root), + ) diff --git a/dimos/models/Detic/detic/data/datasets/register_oid.py b/dimos/models/Detic/detic/data/datasets/register_oid.py new file mode 100644 index 0000000000..0739556041 --- /dev/null +++ b/dimos/models/Detic/detic/data/datasets/register_oid.py @@ -0,0 +1,115 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Modified by Xingyi Zhou from https://github.com/facebookresearch/detectron2/blob/master/detectron2/data/datasets/coco.py +import contextlib +import io +import logging +import os + +from detectron2.data import DatasetCatalog, MetadataCatalog +from detectron2.structures import BoxMode +from fvcore.common.file_io import PathManager +from fvcore.common.timer import Timer +from typing import Optional + +logger = logging.getLogger(__name__) + +""" +This file contains functions to register a COCO-format dataset to the DatasetCatalog. +""" + +__all__ = ["register_coco_instances", "register_coco_panoptic_separated"] + + +def register_oid_instances(name: str, metadata, json_file, image_root) -> None: + """ """ + # 1. register a function which returns dicts + DatasetCatalog.register(name, lambda: load_coco_json_mem_efficient(json_file, image_root, name)) + + # 2. Optionally, add metadata about this dataset, + # since they might be useful in evaluation, visualization or logging + MetadataCatalog.get(name).set( + json_file=json_file, image_root=image_root, evaluator_type="oid", **metadata + ) + + +def load_coco_json_mem_efficient( + json_file, image_root, dataset_name: Optional[str]=None, extra_annotation_keys=None +): + """ + Actually not mem efficient + """ + from pycocotools.coco import COCO + + timer = Timer() + json_file = PathManager.get_local_path(json_file) + with contextlib.redirect_stdout(io.StringIO()): + coco_api = COCO(json_file) + if timer.seconds() > 1: + logger.info(f"Loading {json_file} takes {timer.seconds():.2f} seconds.") + + id_map = None + if dataset_name is not None: + meta = MetadataCatalog.get(dataset_name) + cat_ids = sorted(coco_api.getCatIds()) + cats = coco_api.loadCats(cat_ids) + # The categories in a custom json file may not be sorted. + thing_classes = [c["name"] for c in sorted(cats, key=lambda x: x["id"])] + meta.thing_classes = thing_classes + + if not (min(cat_ids) == 1 and max(cat_ids) == len(cat_ids)): + if "coco" not in dataset_name: + logger.warning( + """ + Category ids in annotations are not in [1, #categories]! We'll apply a mapping for you. + """ + ) + id_map = {v: i for i, v in enumerate(cat_ids)} + meta.thing_dataset_id_to_contiguous_id = id_map + + # sort indices for reproducible results + img_ids = sorted(coco_api.imgs.keys()) + imgs = coco_api.loadImgs(img_ids) + logger.info(f"Loaded {len(imgs)} images in COCO format from {json_file}") + + dataset_dicts = [] + + ann_keys = ["iscrowd", "bbox", "category_id"] + (extra_annotation_keys or []) + + for img_dict in imgs: + record = {} + record["file_name"] = os.path.join(image_root, img_dict["file_name"]) + record["height"] = img_dict["height"] + record["width"] = img_dict["width"] + image_id = record["image_id"] = img_dict["id"] + anno_dict_list = coco_api.imgToAnns[image_id] + if "neg_category_ids" in img_dict: + record["neg_category_ids"] = [id_map[x] for x in img_dict["neg_category_ids"]] + + objs = [] + for anno in anno_dict_list: + assert anno["image_id"] == image_id + + assert anno.get("ignore", 0) == 0 + + obj = {key: anno[key] for key in ann_keys if key in anno} + + segm = anno.get("segmentation", None) + if segm: # either list[list[float]] or dict(RLE) + if not isinstance(segm, dict): + # filter out invalid polygons (< 3 points) + segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6] + if len(segm) == 0: + num_instances_without_valid_segmentation += 1 + continue # ignore this instance + obj["segmentation"] = segm + + obj["bbox_mode"] = BoxMode.XYWH_ABS + + if id_map: + obj["category_id"] = id_map[obj["category_id"]] + objs.append(obj) + record["annotations"] = objs + dataset_dicts.append(record) + + del coco_api + return dataset_dicts diff --git a/dimos/models/Detic/detic/data/tar_dataset.py b/dimos/models/Detic/detic/data/tar_dataset.py new file mode 100644 index 0000000000..8c87a056d1 --- /dev/null +++ b/dimos/models/Detic/detic/data/tar_dataset.py @@ -0,0 +1,145 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +import gzip +import io +import os + +import numpy as np +from PIL import Image +from torch.utils.data import Dataset + +try: + from PIL import UnidentifiedImageError + + unidentified_error_available = True +except ImportError: + # UnidentifiedImageError isn't available in older versions of PIL + unidentified_error_available = False + + +class DiskTarDataset(Dataset): + def __init__( + self, + tarfile_path: str="dataset/imagenet/ImageNet-21k/metadata/tar_files.npy", + tar_index_dir: str="dataset/imagenet/ImageNet-21k/metadata/tarindex_npy", + preload: bool=False, + num_synsets: str="all", + ) -> None: + """ + - preload (bool): Recommend to set preload to False when using + - num_synsets (integer or string "all"): set to small number for debugging + will load subset of dataset + """ + tar_files = np.load(tarfile_path) + + chunk_datasets = [] + dataset_lens = [] + if isinstance(num_synsets, int): + assert num_synsets < len(tar_files) + tar_files = tar_files[:num_synsets] + for tar_file in tar_files: + dataset = _TarDataset(tar_file, tar_index_dir, preload=preload) + chunk_datasets.append(dataset) + dataset_lens.append(len(dataset)) + + self.chunk_datasets = chunk_datasets + self.dataset_lens = np.array(dataset_lens).astype(np.int32) + self.dataset_cumsums = np.cumsum(self.dataset_lens) + self.num_samples = sum(self.dataset_lens) + labels = np.zeros(self.dataset_lens.sum(), dtype=np.int64) + sI = 0 + for k in range(len(self.dataset_lens)): + assert (sI + self.dataset_lens[k]) <= len(labels), ( + f"{k} {sI + self.dataset_lens[k]} vs. {len(labels)}" + ) + labels[sI : (sI + self.dataset_lens[k])] = k + sI += self.dataset_lens[k] + self.labels = labels + + def __len__(self) -> int: + return self.num_samples + + def __getitem__(self, index): + assert index >= 0 and index < len(self) + # find the dataset file we need to go to + d_index = np.searchsorted(self.dataset_cumsums, index) + + # edge case, if index is at edge of chunks, move right + if index in self.dataset_cumsums: + d_index += 1 + + assert d_index == self.labels[index], ( + f"{d_index} vs. {self.labels[index]} mismatch for {index}" + ) + + # change index to local dataset index + if d_index == 0: + local_index = index + else: + local_index = index - self.dataset_cumsums[d_index - 1] + data_bytes = self.chunk_datasets[d_index][local_index] + exception_to_catch = UnidentifiedImageError if unidentified_error_available else Exception + try: + image = Image.open(data_bytes).convert("RGB") + except exception_to_catch: + image = Image.fromarray(np.ones((224, 224, 3), dtype=np.uint8) * 128) + d_index = -1 + + # label is the dataset (synset) we indexed into + return image, d_index, index + + def __repr__(self) -> str: + st = f"DiskTarDataset(subdatasets={len(self.dataset_lens)},samples={self.num_samples})" + return st + + +class _TarDataset: + def __init__(self, filename, npy_index_dir, preload: bool=False) -> None: + # translated from + # fbcode/experimental/deeplearning/matthijs/comp_descs/tardataset.lua + self.filename = filename + self.names = [] + self.offsets = [] + self.npy_index_dir = npy_index_dir + names, offsets = self.load_index() + + self.num_samples = len(names) + if preload: + self.data = np.memmap(filename, mode="r", dtype="uint8") + self.offsets = offsets + else: + self.data = None + + def __len__(self) -> int: + return self.num_samples + + def load_index(self): + basename = os.path.basename(self.filename) + basename = os.path.splitext(basename)[0] + names = np.load(os.path.join(self.npy_index_dir, f"{basename}_names.npy")) + offsets = np.load(os.path.join(self.npy_index_dir, f"{basename}_offsets.npy")) + return names, offsets + + def __getitem__(self, idx: int): + if self.data is None: + self.data = np.memmap(self.filename, mode="r", dtype="uint8") + _, self.offsets = self.load_index() + + ofs = self.offsets[idx] * 512 + fsize = 512 * (self.offsets[idx + 1] - self.offsets[idx]) + data = self.data[ofs : ofs + fsize] + + if data[:13].tostring() == "././@LongLink": + data = data[3 * 512 :] + else: + data = data[512:] + + # just to make it more fun a few JPEGs are GZIP compressed... + # catch this case + if tuple(data[:2]) == (0x1F, 0x8B): + s = io.BytesIO(data.tostring()) + g = gzip.GzipFile(None, "r", 0, s) + sdata = g.read() + else: + sdata = data.tostring() + return io.BytesIO(sdata) diff --git a/dimos/models/Detic/detic/data/transforms/custom_augmentation_impl.py b/dimos/models/Detic/detic/data/transforms/custom_augmentation_impl.py new file mode 100644 index 0000000000..7cabc91e0f --- /dev/null +++ b/dimos/models/Detic/detic/data/transforms/custom_augmentation_impl.py @@ -0,0 +1,49 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# Part of the code is from https://github.com/rwightman/efficientdet-pytorch/blob/master/effdet/data/transforms.py +# Modified by Xingyi Zhou +# The original code is under Apache-2.0 License +from detectron2.data.transforms.augmentation import Augmentation +import numpy as np +from PIL import Image + +from .custom_transform import EfficientDetResizeCropTransform + +__all__ = [ + "EfficientDetResizeCrop", +] + + +class EfficientDetResizeCrop(Augmentation): + """ + Scale the shorter edge to the given size, with a limit of `max_size` on the longer edge. + If `max_size` is reached, then downscale so that the longer edge does not exceed max_size. + """ + + def __init__(self, size: int, scale, interp=Image.BILINEAR) -> None: + """ """ + super().__init__() + self.target_size = (size, size) + self.scale = scale + self.interp = interp + + def get_transform(self, img): + # Select a random scale factor. + scale_factor = np.random.uniform(*self.scale) + scaled_target_height = scale_factor * self.target_size[0] + scaled_target_width = scale_factor * self.target_size[1] + # Recompute the accurate scale_factor using rounded scaled image size. + width, height = img.shape[1], img.shape[0] + img_scale_y = scaled_target_height / height + img_scale_x = scaled_target_width / width + img_scale = min(img_scale_y, img_scale_x) + + # Select non-zero random offset (x, y) if scaled image is larger than target size + scaled_h = int(height * img_scale) + scaled_w = int(width * img_scale) + offset_y = scaled_h - self.target_size[0] + offset_x = scaled_w - self.target_size[1] + offset_y = int(max(0.0, float(offset_y)) * np.random.uniform(0, 1)) + offset_x = int(max(0.0, float(offset_x)) * np.random.uniform(0, 1)) + return EfficientDetResizeCropTransform( + scaled_h, scaled_w, offset_y, offset_x, img_scale, self.target_size, self.interp + ) diff --git a/dimos/models/Detic/detic/data/transforms/custom_transform.py b/dimos/models/Detic/detic/data/transforms/custom_transform.py new file mode 100644 index 0000000000..2017c27a5f --- /dev/null +++ b/dimos/models/Detic/detic/data/transforms/custom_transform.py @@ -0,0 +1,102 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# Part of the code is from https://github.com/rwightman/efficientdet-pytorch/blob/master/effdet/data/transforms.py +# Modified by Xingyi Zhou +# The original code is under Apache-2.0 License +from fvcore.transforms.transform import ( + Transform, +) +import numpy as np +from PIL import Image +import torch +import torch.nn.functional as F + +try: + import cv2 +except ImportError: + # OpenCV is an optional dependency at the moment + pass + +__all__ = [ + "EfficientDetResizeCropTransform", +] + + +class EfficientDetResizeCropTransform(Transform): + """ """ + + def __init__(self, scaled_h, scaled_w, offset_y, offset_x, img_scale, target_size: int, interp=None) -> None: + """ + Args: + h, w (int): original image size + new_h, new_w (int): new image size + interp: PIL interpolation methods, defaults to bilinear. + """ + # TODO decide on PIL vs opencv + super().__init__() + if interp is None: + interp = Image.BILINEAR + self._set_attributes(locals()) + + def apply_image(self, img, interp=None): + assert len(img.shape) <= 4 + + if img.dtype == np.uint8: + pil_image = Image.fromarray(img) + interp_method = interp if interp is not None else self.interp + pil_image = pil_image.resize((self.scaled_w, self.scaled_h), interp_method) + ret = np.asarray(pil_image) + right = min(self.scaled_w, self.offset_x + self.target_size[1]) + lower = min(self.scaled_h, self.offset_y + self.target_size[0]) + if len(ret.shape) <= 3: + ret = ret[self.offset_y : lower, self.offset_x : right] + else: + ret = ret[..., self.offset_y : lower, self.offset_x : right, :] + else: + # PIL only supports uint8 + img = torch.from_numpy(img) + shape = list(img.shape) + shape_4d = shape[:2] + [1] * (4 - len(shape)) + shape[2:] + img = img.view(shape_4d).permute(2, 3, 0, 1) # hw(c) -> nchw + _PIL_RESIZE_TO_INTERPOLATE_MODE = {Image.BILINEAR: "bilinear", Image.BICUBIC: "bicubic"} + mode = _PIL_RESIZE_TO_INTERPOLATE_MODE[self.interp] + img = F.interpolate(img, (self.scaled_h, self.scaled_w), mode=mode, align_corners=False) + shape[:2] = (self.scaled_h, self.scaled_w) + ret = img.permute(2, 3, 0, 1).view(shape).numpy() # nchw -> hw(c) + right = min(self.scaled_w, self.offset_x + self.target_size[1]) + lower = min(self.scaled_h, self.offset_y + self.target_size[0]) + if len(ret.shape) <= 3: + ret = ret[self.offset_y : lower, self.offset_x : right] + else: + ret = ret[..., self.offset_y : lower, self.offset_x : right, :] + return ret + + def apply_coords(self, coords): + coords[:, 0] = coords[:, 0] * self.img_scale + coords[:, 1] = coords[:, 1] * self.img_scale + coords[:, 0] -= self.offset_x + coords[:, 1] -= self.offset_y + return coords + + def apply_segmentation(self, segmentation): + segmentation = self.apply_image(segmentation, interp=Image.NEAREST) + return segmentation + + def inverse(self): + raise NotImplementedError + + def inverse_apply_coords(self, coords): + coords[:, 0] += self.offset_x + coords[:, 1] += self.offset_y + coords[:, 0] = coords[:, 0] / self.img_scale + coords[:, 1] = coords[:, 1] / self.img_scale + return coords + + def inverse_apply_box(self, box: np.ndarray) -> np.ndarray: + """ """ + idxs = np.array([(0, 1), (2, 1), (0, 3), (2, 3)]).flatten() + coords = np.asarray(box).reshape(-1, 4)[:, idxs].reshape(-1, 2) + coords = self.inverse_apply_coords(coords).reshape((-1, 4, 2)) + minxy = coords.min(axis=1) + maxxy = coords.max(axis=1) + trans_boxes = np.concatenate((minxy, maxxy), axis=1) + return trans_boxes diff --git a/dimos/models/Detic/detic/evaluation/custom_coco_eval.py b/dimos/models/Detic/detic/evaluation/custom_coco_eval.py new file mode 100644 index 0000000000..759d885f00 --- /dev/null +++ b/dimos/models/Detic/detic/evaluation/custom_coco_eval.py @@ -0,0 +1,106 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import itertools + +from detectron2.evaluation.coco_evaluation import COCOEvaluator +from detectron2.utils.logger import create_small_table +import numpy as np +from tabulate import tabulate + +from ..data.datasets.coco_zeroshot import categories_seen, categories_unseen +from typing import Optional, Sequence + + +class CustomCOCOEvaluator(COCOEvaluator): + def _derive_coco_results(self, coco_eval, iou_type, class_names: Optional[Sequence[str]]=None): + """ + Additionally plot mAP for 'seen classes' and 'unseen classes' + """ + + metrics = { + "bbox": ["AP", "AP50", "AP75", "APs", "APm", "APl"], + "segm": ["AP", "AP50", "AP75", "APs", "APm", "APl"], + "keypoints": ["AP", "AP50", "AP75", "APm", "APl"], + }[iou_type] + + if coco_eval is None: + self._logger.warn("No predictions from the model!") + return {metric: float("nan") for metric in metrics} + + # the standard metrics + results = { + metric: float(coco_eval.stats[idx] * 100 if coco_eval.stats[idx] >= 0 else "nan") + for idx, metric in enumerate(metrics) + } + self._logger.info( + f"Evaluation results for {iou_type}: \n" + create_small_table(results) + ) + if not np.isfinite(sum(results.values())): + self._logger.info("Some metrics cannot be computed and is shown as NaN.") + + if class_names is None or len(class_names) <= 1: + return results + # Compute per-category AP + # from https://github.com/facebookresearch/Detectron/blob/a6a835f5b8208c45d0dce217ce9bbda915f44df7/detectron/datasets/json_dataset_evaluator.py#L222-L252 + precisions = coco_eval.eval["precision"] + # precision has dims (iou, recall, cls, area range, max dets) + assert len(class_names) == precisions.shape[2] + + seen_names = set([x["name"] for x in categories_seen]) + unseen_names = set([x["name"] for x in categories_unseen]) + results_per_category = [] + results_per_category50 = [] + results_per_category50_seen = [] + results_per_category50_unseen = [] + for idx, name in enumerate(class_names): + # area range index 0: all area ranges + # max dets index -1: typically 100 per image + precision = precisions[:, :, idx, 0, -1] + precision = precision[precision > -1] + ap = np.mean(precision) if precision.size else float("nan") + results_per_category.append((f"{name}", float(ap * 100))) + precision50 = precisions[0, :, idx, 0, -1] + precision50 = precision50[precision50 > -1] + ap50 = np.mean(precision50) if precision50.size else float("nan") + results_per_category50.append((f"{name}", float(ap50 * 100))) + if name in seen_names: + results_per_category50_seen.append(float(ap50 * 100)) + if name in unseen_names: + results_per_category50_unseen.append(float(ap50 * 100)) + + # tabulate it + N_COLS = min(6, len(results_per_category) * 2) + results_flatten = list(itertools.chain(*results_per_category)) + results_2d = itertools.zip_longest(*[results_flatten[i::N_COLS] for i in range(N_COLS)]) + table = tabulate( + results_2d, + tablefmt="pipe", + floatfmt=".3f", + headers=["category", "AP"] * (N_COLS // 2), + numalign="left", + ) + self._logger.info(f"Per-category {iou_type} AP: \n" + table) + + N_COLS = min(6, len(results_per_category50) * 2) + results_flatten = list(itertools.chain(*results_per_category50)) + results_2d = itertools.zip_longest(*[results_flatten[i::N_COLS] for i in range(N_COLS)]) + table = tabulate( + results_2d, + tablefmt="pipe", + floatfmt=".3f", + headers=["category", "AP50"] * (N_COLS // 2), + numalign="left", + ) + self._logger.info(f"Per-category {iou_type} AP50: \n" + table) + self._logger.info( + f"Seen {iou_type} AP50: {sum(results_per_category50_seen) / len(results_per_category50_seen)}" + ) + self._logger.info( + f"Unseen {iou_type} AP50: {sum(results_per_category50_unseen) / len(results_per_category50_unseen)}" + ) + + results.update({"AP-" + name: ap for name, ap in results_per_category}) + results["AP50-seen"] = sum(results_per_category50_seen) / len(results_per_category50_seen) + results["AP50-unseen"] = sum(results_per_category50_unseen) / len( + results_per_category50_unseen + ) + return results diff --git a/dimos/models/Detic/detic/evaluation/oideval.py b/dimos/models/Detic/detic/evaluation/oideval.py new file mode 100644 index 0000000000..aa5a954aef --- /dev/null +++ b/dimos/models/Detic/detic/evaluation/oideval.py @@ -0,0 +1,683 @@ +# Part of the code is from https://github.com/tensorflow/models/blob/master/research/object_detection/metrics/oid_challenge_evaluation.py +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# The original code is under Apache License, Version 2.0 (the "License"); +# Part of the code is from https://github.com/lvis-dataset/lvis-api/blob/master/lvis/eval.py +# Copyright (c) 2019, Agrim Gupta and Ross Girshick +# Modified by Xingyi Zhou +# This script re-implement OpenImages evaluation in detectron2 +# The code is from https://github.com/xingyizhou/UniDet/blob/master/projects/UniDet/unidet/evaluation/oideval.py +# The original code is under Apache-2.0 License +# Copyright (c) Facebook, Inc. and its affiliates. +from collections import OrderedDict, defaultdict +import copy +import datetime +import itertools +import json +import logging +import os + +from detectron2.data import MetadataCatalog +from detectron2.evaluation import DatasetEvaluator +from detectron2.evaluation.coco_evaluation import instances_to_coco_json +import detectron2.utils.comm as comm +from detectron2.utils.logger import create_small_table +from fvcore.common.file_io import PathManager +from lvis.lvis import LVIS +from lvis.results import LVISResults +import numpy as np +import pycocotools.mask as mask_utils +from tabulate import tabulate +import torch +from typing import Optional, Sequence + + +def compute_average_precision(precision, recall): + """Compute Average Precision according to the definition in VOCdevkit. + Precision is modified to ensure that it does not decrease as recall + decrease. + Args: + precision: A float [N, 1] numpy array of precisions + recall: A float [N, 1] numpy array of recalls + Raises: + ValueError: if the input is not of the correct format + Returns: + average_precison: The area under the precision recall curve. NaN if + precision and recall are None. + """ + if precision is None: + if recall is not None: + raise ValueError("If precision is None, recall must also be None") + return np.NAN + + if not isinstance(precision, np.ndarray) or not isinstance(recall, np.ndarray): + raise ValueError("precision and recall must be numpy array") + if precision.dtype != np.float or recall.dtype != np.float: + raise ValueError("input must be float numpy array.") + if len(precision) != len(recall): + raise ValueError("precision and recall must be of the same size.") + if not precision.size: + return 0.0 + if np.amin(precision) < 0 or np.amax(precision) > 1: + raise ValueError("Precision must be in the range of [0, 1].") + if np.amin(recall) < 0 or np.amax(recall) > 1: + raise ValueError("recall must be in the range of [0, 1].") + if not all(recall[i] <= recall[i + 1] for i in range(len(recall) - 1)): + raise ValueError("recall must be a non-decreasing array") + + recall = np.concatenate([[0], recall, [1]]) + precision = np.concatenate([[0], precision, [0]]) + + for i in range(len(precision) - 2, -1, -1): + precision[i] = np.maximum(precision[i], precision[i + 1]) + indices = np.where(recall[1:] != recall[:-1])[0] + 1 + average_precision = np.sum((recall[indices] - recall[indices - 1]) * precision[indices]) + return average_precision + + +class OIDEval: + def __init__( + self, + lvis_gt, + lvis_dt, + iou_type: str="bbox", + expand_pred_label: bool=False, + oid_hierarchy_path: str="./datasets/oid/annotations/challenge-2019-label500-hierarchy.json", + ) -> None: + """Constructor for OIDEval. + Args: + lvis_gt (LVIS class instance, or str containing path of annotation file) + lvis_dt (LVISResult class instance, or str containing path of result file, + or list of dict) + iou_type (str): segm or bbox evaluation + """ + self.logger = logging.getLogger(__name__) + + if iou_type not in ["bbox", "segm"]: + raise ValueError(f"iou_type: {iou_type} is not supported.") + + if isinstance(lvis_gt, LVIS): + self.lvis_gt = lvis_gt + elif isinstance(lvis_gt, str): + self.lvis_gt = LVIS(lvis_gt) + else: + raise TypeError(f"Unsupported type {lvis_gt} of lvis_gt.") + + if isinstance(lvis_dt, LVISResults): + self.lvis_dt = lvis_dt + elif isinstance(lvis_dt, str | list): + # self.lvis_dt = LVISResults(self.lvis_gt, lvis_dt, max_dets=-1) + self.lvis_dt = LVISResults(self.lvis_gt, lvis_dt) + else: + raise TypeError(f"Unsupported type {lvis_dt} of lvis_dt.") + + if expand_pred_label: + oid_hierarchy = json.load(open(oid_hierarchy_path)) + cat_info = self.lvis_gt.dataset["categories"] + freebase2id = {x["freebase_id"]: x["id"] for x in cat_info} + {x["id"]: x["freebase_id"] for x in cat_info} + {x["id"]: x["name"] for x in cat_info} + + fas = defaultdict(set) + + def dfs(hierarchy, cur_id): + all_childs = set() + if "Subcategory" in hierarchy: + for x in hierarchy["Subcategory"]: + childs = dfs(x, freebase2id[x["LabelName"]]) + all_childs.update(childs) + if cur_id != -1: + for c in all_childs: + fas[c].add(cur_id) + all_childs.add(cur_id) + return all_childs + + dfs(oid_hierarchy, -1) + + expanded_pred = [] + id_count = 0 + for d in self.lvis_dt.dataset["annotations"]: + cur_id = d["category_id"] + ids = [cur_id] + [x for x in fas[cur_id]] + for cat_id in ids: + new_box = copy.deepcopy(d) + id_count = id_count + 1 + new_box["id"] = id_count + new_box["category_id"] = cat_id + expanded_pred.append(new_box) + + print( + "Expanding original {} preds to {} preds".format( + len(self.lvis_dt.dataset["annotations"]), len(expanded_pred) + ) + ) + self.lvis_dt.dataset["annotations"] = expanded_pred + self.lvis_dt._create_index() + + # per-image per-category evaluation results + self.eval_imgs = defaultdict(list) + self.eval = {} # accumulated evaluation results + self._gts = defaultdict(list) # gt for evaluation + self._dts = defaultdict(list) # dt for evaluation + self.params = Params(iou_type=iou_type) # parameters + self.results = OrderedDict() + self.ious = {} # ious between all gts and dts + + self.params.img_ids = sorted(self.lvis_gt.get_img_ids()) + self.params.cat_ids = sorted(self.lvis_gt.get_cat_ids()) + + def _to_mask(self, anns, lvis) -> None: + for ann in anns: + rle = lvis.ann_to_rle(ann) + ann["segmentation"] = rle + + def _prepare(self) -> None: + """Prepare self._gts and self._dts for evaluation based on params.""" + + cat_ids = self.params.cat_ids if self.params.cat_ids else None + + gts = self.lvis_gt.load_anns( + self.lvis_gt.get_ann_ids(img_ids=self.params.img_ids, cat_ids=cat_ids) + ) + dts = self.lvis_dt.load_anns( + self.lvis_dt.get_ann_ids(img_ids=self.params.img_ids, cat_ids=cat_ids) + ) + # convert ground truth to mask if iou_type == 'segm' + if self.params.iou_type == "segm": + self._to_mask(gts, self.lvis_gt) + self._to_mask(dts, self.lvis_dt) + + for gt in gts: + self._gts[gt["image_id"], gt["category_id"]].append(gt) + + # For federated dataset evaluation we will filter out all dt for an + # image which belong to categories not present in gt and not present in + # the negative list for an image. In other words detector is not penalized + # for categories about which we don't have gt information about their + # presence or absence in an image. + img_data = self.lvis_gt.load_imgs(ids=self.params.img_ids) + # per image map of categories not present in image + img_nl = {d["id"]: d["neg_category_ids"] for d in img_data} + # per image list of categories present in image + img_pl = {d["id"]: d["pos_category_ids"] for d in img_data} + # img_pl = defaultdict(set) + for ann in gts: + # img_pl[ann["image_id"]].add(ann["category_id"]) + assert ann["category_id"] in img_pl[ann["image_id"]] + # print('check pos ids OK.') + + for dt in dts: + img_id, cat_id = dt["image_id"], dt["category_id"] + if cat_id not in img_nl[img_id] and cat_id not in img_pl[img_id]: + continue + self._dts[img_id, cat_id].append(dt) + + def evaluate(self) -> None: + """ + Run per image evaluation on given images and store results + (a list of dict) in self.eval_imgs. + """ + self.logger.info("Running per image evaluation.") + self.logger.info(f"Evaluate annotation type *{self.params.iou_type}*") + + self.params.img_ids = list(np.unique(self.params.img_ids)) + + if self.params.use_cats: + cat_ids = self.params.cat_ids + else: + cat_ids = [-1] + + self._prepare() + + self.ious = { + (img_id, cat_id): self.compute_iou(img_id, cat_id) + for img_id in self.params.img_ids + for cat_id in cat_ids + } + + # loop through images, area range, max detection number + print("Evaluating ...") + self.eval_imgs = [ + self.evaluate_img_google(img_id, cat_id, area_rng) + for cat_id in cat_ids + for area_rng in self.params.area_rng + for img_id in self.params.img_ids + ] + + def _get_gt_dt(self, img_id, cat_id): + """Create gt, dt which are list of anns/dets. If use_cats is true + only anns/dets corresponding to tuple (img_id, cat_id) will be + used. Else, all anns/dets in image are used and cat_id is not used. + """ + if self.params.use_cats: + gt = self._gts[img_id, cat_id] + dt = self._dts[img_id, cat_id] + else: + gt = [_ann for _cat_id in self.params.cat_ids for _ann in self._gts[img_id, cat_id]] + dt = [_ann for _cat_id in self.params.cat_ids for _ann in self._dts[img_id, cat_id]] + return gt, dt + + def compute_iou(self, img_id, cat_id): + gt, dt = self._get_gt_dt(img_id, cat_id) + + if len(gt) == 0 and len(dt) == 0: + return [] + + # Sort detections in decreasing order of score. + idx = np.argsort([-d["score"] for d in dt], kind="mergesort") + dt = [dt[i] for i in idx] + + # iscrowd = [int(False)] * len(gt) + iscrowd = [int("iscrowd" in g and g["iscrowd"] > 0) for g in gt] + + if self.params.iou_type == "segm": + ann_type = "segmentation" + elif self.params.iou_type == "bbox": + ann_type = "bbox" + else: + raise ValueError("Unknown iou_type for iou computation.") + gt = [g[ann_type] for g in gt] + dt = [d[ann_type] for d in dt] + + # compute iou between each dt and gt region + # will return array of shape len(dt), len(gt) + ious = mask_utils.iou(dt, gt, iscrowd) + return ious + + def evaluate_img_google(self, img_id, cat_id, area_rng): + gt, dt = self._get_gt_dt(img_id, cat_id) + if len(gt) == 0 and len(dt) == 0: + return None + + if len(dt) == 0: + return { + "image_id": img_id, + "category_id": cat_id, + "area_rng": area_rng, + "dt_ids": [], + "dt_matches": np.array([], dtype=np.int32).reshape(1, -1), + "dt_scores": [], + "dt_ignore": np.array([], dtype=np.int32).reshape(1, -1), + "num_gt": len(gt), + } + + no_crowd_inds = [i for i, g in enumerate(gt) if ("iscrowd" not in g) or g["iscrowd"] == 0] + crowd_inds = [i for i, g in enumerate(gt) if "iscrowd" in g and g["iscrowd"] == 1] + dt_idx = np.argsort([-d["score"] for d in dt], kind="mergesort") + + if len(self.ious[img_id, cat_id]) > 0: + ious = self.ious[img_id, cat_id] + iou = ious[:, no_crowd_inds] + iou = iou[dt_idx] + ioa = ious[:, crowd_inds] + ioa = ioa[dt_idx] + else: + iou = np.zeros((len(dt_idx), 0)) + ioa = np.zeros((len(dt_idx), 0)) + scores = np.array([dt[i]["score"] for i in dt_idx]) + + num_detected_boxes = len(dt) + tp_fp_labels = np.zeros(num_detected_boxes, dtype=bool) + is_matched_to_group_of = np.zeros(num_detected_boxes, dtype=bool) + + def compute_match_iou(iou) -> None: + max_overlap_gt_ids = np.argmax(iou, axis=1) + is_gt_detected = np.zeros(iou.shape[1], dtype=bool) + for i in range(num_detected_boxes): + gt_id = max_overlap_gt_ids[i] + is_evaluatable = ( + not tp_fp_labels[i] and iou[i, gt_id] >= 0.5 and not is_matched_to_group_of[i] + ) + if is_evaluatable: + if not is_gt_detected[gt_id]: + tp_fp_labels[i] = True + is_gt_detected[gt_id] = True + + def compute_match_ioa(ioa): + scores_group_of = np.zeros(ioa.shape[1], dtype=float) + tp_fp_labels_group_of = np.ones(ioa.shape[1], dtype=float) + max_overlap_group_of_gt_ids = np.argmax(ioa, axis=1) + for i in range(num_detected_boxes): + gt_id = max_overlap_group_of_gt_ids[i] + is_evaluatable = ( + not tp_fp_labels[i] and ioa[i, gt_id] >= 0.5 and not is_matched_to_group_of[i] + ) + if is_evaluatable: + is_matched_to_group_of[i] = True + scores_group_of[gt_id] = max(scores_group_of[gt_id], scores[i]) + selector = np.where((scores_group_of > 0) & (tp_fp_labels_group_of > 0)) + scores_group_of = scores_group_of[selector] + tp_fp_labels_group_of = tp_fp_labels_group_of[selector] + + return scores_group_of, tp_fp_labels_group_of + + if iou.shape[1] > 0: + compute_match_iou(iou) + + scores_box_group_of = np.ndarray([0], dtype=float) + tp_fp_labels_box_group_of = np.ndarray([0], dtype=float) + + if ioa.shape[1] > 0: + scores_box_group_of, tp_fp_labels_box_group_of = compute_match_ioa(ioa) + + valid_entries = ~is_matched_to_group_of + + scores = np.concatenate((scores[valid_entries], scores_box_group_of)) + tp_fps = np.concatenate( + (tp_fp_labels[valid_entries].astype(float), tp_fp_labels_box_group_of) + ) + + return { + "image_id": img_id, + "category_id": cat_id, + "area_rng": area_rng, + "dt_matches": np.array([1 if x > 0 else 0 for x in tp_fps], dtype=np.int32).reshape( + 1, -1 + ), + "dt_scores": [x for x in scores], + "dt_ignore": np.array([0 for x in scores], dtype=np.int32).reshape(1, -1), + "num_gt": len(gt), + } + + def accumulate(self) -> None: + """Accumulate per image evaluation results and store the result in + self.eval. + """ + self.logger.info("Accumulating evaluation results.") + + if not self.eval_imgs: + self.logger.warn("Please run evaluate first.") + + if self.params.use_cats: + cat_ids = self.params.cat_ids + else: + cat_ids = [-1] + + num_thrs = 1 + num_recalls = 1 + + num_cats = len(cat_ids) + num_area_rngs = 1 + num_imgs = len(self.params.img_ids) + + # -1 for absent categories + precision = -np.ones((num_thrs, num_recalls, num_cats, num_area_rngs)) + recall = -np.ones((num_thrs, num_cats, num_area_rngs)) + + # Initialize dt_pointers + dt_pointers = {} + for cat_idx in range(num_cats): + dt_pointers[cat_idx] = {} + for area_idx in range(num_area_rngs): + dt_pointers[cat_idx][area_idx] = {} + + # Per category evaluation + for cat_idx in range(num_cats): + Nk = cat_idx * num_area_rngs * num_imgs + for area_idx in range(num_area_rngs): + Na = area_idx * num_imgs + E = [self.eval_imgs[Nk + Na + img_idx] for img_idx in range(num_imgs)] + # Remove elements which are None + E = [e for e in E if e is not None] + if len(E) == 0: + continue + + dt_scores = np.concatenate([e["dt_scores"] for e in E], axis=0) + dt_idx = np.argsort(-dt_scores, kind="mergesort") + dt_scores = dt_scores[dt_idx] + dt_m = np.concatenate([e["dt_matches"] for e in E], axis=1)[:, dt_idx] + dt_ig = np.concatenate([e["dt_ignore"] for e in E], axis=1)[:, dt_idx] + + num_gt = sum([e["num_gt"] for e in E]) + if num_gt == 0: + continue + + tps = np.logical_and(dt_m, np.logical_not(dt_ig)) + fps = np.logical_and(np.logical_not(dt_m), np.logical_not(dt_ig)) + tp_sum = np.cumsum(tps, axis=1).astype(dtype=np.float) + fp_sum = np.cumsum(fps, axis=1).astype(dtype=np.float) + + dt_pointers[cat_idx][area_idx] = { + "tps": tps, + "fps": fps, + } + + for iou_thr_idx, (tp, fp) in enumerate(zip(tp_sum, fp_sum, strict=False)): + tp = np.array(tp) + fp = np.array(fp) + num_tp = len(tp) + rc = tp / num_gt + + if num_tp: + recall[iou_thr_idx, cat_idx, area_idx] = rc[-1] + else: + recall[iou_thr_idx, cat_idx, area_idx] = 0 + + # np.spacing(1) ~= eps + pr = tp / (fp + tp + np.spacing(1)) + pr = pr.tolist() + + for i in range(num_tp - 1, 0, -1): + if pr[i] > pr[i - 1]: + pr[i - 1] = pr[i] + + mAP = compute_average_precision( + np.array(pr, np.float).reshape(-1), np.array(rc, np.float).reshape(-1) + ) + precision[iou_thr_idx, :, cat_idx, area_idx] = mAP + + self.eval = { + "params": self.params, + "counts": [num_thrs, num_recalls, num_cats, num_area_rngs], + "date": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), + "precision": precision, + "recall": recall, + "dt_pointers": dt_pointers, + } + + def _summarize(self, summary_type): + s = self.eval["precision"] + if len(s[s > -1]) == 0: + mean_s = -1 + else: + mean_s = np.mean(s[s > -1]) + # print(s.reshape(1, 1, -1, 1)) + return mean_s + + def summarize(self): + """Compute and display summary metrics for evaluation results.""" + if not self.eval: + raise RuntimeError("Please run accumulate() first.") + + self.results["AP50"] = self._summarize("ap") + + def run(self) -> None: + """Wrapper function which calculates the results.""" + self.evaluate() + self.accumulate() + self.summarize() + + def print_results(self) -> None: + template = " {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} catIds={:>3s}] = {:0.3f}" + + for key, value in self.results.items(): + max_dets = self.params.max_dets + if "AP" in key: + title = "Average Precision" + _type = "(AP)" + else: + title = "Average Recall" + _type = "(AR)" + + if len(key) > 2 and key[2].isdigit(): + iou_thr = float(key[2:]) / 100 + iou = f"{iou_thr:0.2f}" + else: + iou = f"{self.params.iou_thrs[0]:0.2f}:{self.params.iou_thrs[-1]:0.2f}" + + cat_group_name = "all" + area_rng = "all" + + print(template.format(title, _type, iou, area_rng, max_dets, cat_group_name, value)) + + def get_results(self): + if not self.results: + self.logger.warn("results is empty. Call run().") + return self.results + + +class Params: + def __init__(self, iou_type) -> None: + self.img_ids = [] + self.cat_ids = [] + # np.arange causes trouble. the data point on arange is slightly + # larger than the true value + self.iou_thrs = np.linspace( + 0.5, 0.95, int(np.round((0.95 - 0.5) / 0.05)) + 1, endpoint=True + ) + self.google_style = True + # print('Using google style PR curve') + self.iou_thrs = self.iou_thrs[:1] + self.max_dets = 1000 + + self.area_rng = [ + [0**2, 1e5**2], + ] + self.area_rng_lbl = ["all"] + self.use_cats = 1 + self.iou_type = iou_type + + +class OIDEvaluator(DatasetEvaluator): + def __init__(self, dataset_name: str, cfg, distributed, output_dir=None) -> None: + self._distributed = distributed + self._output_dir = output_dir + + self._cpu_device = torch.device("cpu") + self._logger = logging.getLogger(__name__) + + self._metadata = MetadataCatalog.get(dataset_name) + json_file = PathManager.get_local_path(self._metadata.json_file) + self._oid_api = LVIS(json_file) + # Test set json files do not contain annotations (evaluation must be + # performed using the LVIS evaluation server). + self._do_evaluation = len(self._oid_api.get_ann_ids()) > 0 + self._mask_on = cfg.MODEL.MASK_ON + + def reset(self) -> None: + self._predictions = [] + self._oid_results = [] + + def process(self, inputs, outputs) -> None: + for input, output in zip(inputs, outputs, strict=False): + prediction = {"image_id": input["image_id"]} + instances = output["instances"].to(self._cpu_device) + prediction["instances"] = instances_to_coco_json(instances, input["image_id"]) + self._predictions.append(prediction) + + def evaluate(self): + if self._distributed: + comm.synchronize() + self._predictions = comm.gather(self._predictions, dst=0) + self._predictions = list(itertools.chain(*self._predictions)) + + if not comm.is_main_process(): + return + + if len(self._predictions) == 0: + self._logger.warning("[LVISEvaluator] Did not receive valid predictions.") + return {} + + self._logger.info("Preparing results in the OID format ...") + self._oid_results = list(itertools.chain(*[x["instances"] for x in self._predictions])) + + # unmap the category ids for LVIS (from 0-indexed to 1-indexed) + for result in self._oid_results: + result["category_id"] += 1 + + PathManager.mkdirs(self._output_dir) + file_path = os.path.join(self._output_dir, "oid_instances_results.json") + self._logger.info(f"Saving results to {file_path}") + with PathManager.open(file_path, "w") as f: + f.write(json.dumps(self._oid_results)) + f.flush() + + if not self._do_evaluation: + self._logger.info("Annotations are not available for evaluation.") + return + + self._logger.info("Evaluating predictions ...") + self._results = OrderedDict() + res, mAP = _evaluate_predictions_on_oid( + self._oid_api, + file_path, + eval_seg=self._mask_on, + class_names=self._metadata.get("thing_classes"), + ) + self._results["bbox"] = res + mAP_out_path = os.path.join(self._output_dir, "oid_mAP.npy") + self._logger.info("Saving mAP to" + mAP_out_path) + np.save(mAP_out_path, mAP) + return copy.deepcopy(self._results) + + +def _evaluate_predictions_on_oid(oid_gt, oid_results_path, eval_seg: bool=False, class_names: Optional[Sequence[str]]=None): + logger = logging.getLogger(__name__) + + results = {} + oid_eval = OIDEval(oid_gt, oid_results_path, "bbox", expand_pred_label=False) + oid_eval.run() + oid_eval.print_results() + results["AP50"] = oid_eval.get_results()["AP50"] + + if eval_seg: + oid_eval = OIDEval(oid_gt, oid_results_path, "segm", expand_pred_label=False) + oid_eval.run() + oid_eval.print_results() + results["AP50_segm"] = oid_eval.get_results()["AP50"] + else: + oid_eval = OIDEval(oid_gt, oid_results_path, "bbox", expand_pred_label=True) + oid_eval.run() + oid_eval.print_results() + results["AP50_expand"] = oid_eval.get_results()["AP50"] + + mAP = np.zeros(len(class_names)) - 1 + precisions = oid_eval.eval["precision"] + assert len(class_names) == precisions.shape[2] + results_per_category = [] + id2apiid = sorted(oid_gt.get_cat_ids()) + inst_aware_ap, inst_count = 0, 0 + for idx, name in enumerate(class_names): + precision = precisions[:, :, idx, 0] + precision = precision[precision > -1] + ap = np.mean(precision) if precision.size else float("nan") + inst_num = len(oid_gt.get_ann_ids(cat_ids=[id2apiid[idx]])) + if inst_num > 0: + results_per_category.append( + ( + "{} {}".format( + name.replace(" ", "_"), + inst_num if inst_num < 1000 else f"{inst_num / 1000:.1f}k", + ), + float(ap * 100), + ) + ) + inst_aware_ap += inst_num * ap + inst_count += inst_num + mAP[idx] = ap + # logger.info("{} {} {:.2f}".format(name, inst_num, ap * 100)) + inst_aware_ap = inst_aware_ap * 100 / inst_count + N_COLS = min(6, len(results_per_category) * 2) + results_flatten = list(itertools.chain(*results_per_category)) + results_2d = itertools.zip_longest(*[results_flatten[i::N_COLS] for i in range(N_COLS)]) + table = tabulate( + results_2d, + tablefmt="pipe", + floatfmt=".3f", + headers=["category", "AP"] * (N_COLS // 2), + numalign="left", + ) + logger.info("Per-category {} AP: \n".format("bbox") + table) + logger.info("Instance-aware {} AP: {:.4f}".format("bbox", inst_aware_ap)) + + logger.info("Evaluation results for bbox: \n" + create_small_table(results)) + return results, mAP diff --git a/dimos/models/Detic/detic/modeling/backbone/swintransformer.py b/dimos/models/Detic/detic/modeling/backbone/swintransformer.py new file mode 100644 index 0000000000..b7da6328e3 --- /dev/null +++ b/dimos/models/Detic/detic/modeling/backbone/swintransformer.py @@ -0,0 +1,825 @@ +# -------------------------------------------------------- +# Swin Transformer +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ze Liu, Yutong Lin, Yixuan Wei +# -------------------------------------------------------- + +# Copyright (c) Facebook, Inc. and its affiliates. +# Modified by Xingyi Zhou from https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/master/mmdet/models/backbones/swin_transformer.py + + +from centernet.modeling.backbone.bifpn import BiFPN +from centernet.modeling.backbone.fpn_p5 import LastLevelP6P7_P5 +from detectron2.layers import ShapeSpec +from detectron2.modeling.backbone.backbone import Backbone +from detectron2.modeling.backbone.build import BACKBONE_REGISTRY +from detectron2.modeling.backbone.fpn import FPN +import numpy as np +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from typing import Optional, Sequence + +# from .checkpoint import load_checkpoint + + +class Mlp(nn.Module): + """Multilayer perceptron.""" + + def __init__( + self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop: float=0.0 + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size: int): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size: int, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + """Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__( + self, + dim: int, + window_size: int, + num_heads: int, + qkv_bias: bool=True, + qk_scale=None, + attn_drop: float=0.0, + proj_drop: float=0.0, + ) -> None: + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads) + ) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=0.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """Forward function. + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B_, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1) + ].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1 + ) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1 + ).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SwinTransformerBlock(nn.Module): + """Swin Transformer Block. + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__( + self, + dim: int, + num_heads: int, + window_size: int=7, + shift_size: int=0, + mlp_ratio: float=4.0, + qkv_bias: bool=True, + qk_scale=None, + drop: float=0.0, + attn_drop: float=0.0, + drop_path: float=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + ) -> None: + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop + ) + + self.H = None + self.W = None + + def forward(self, x, mask_matrix): + """Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + mask_matrix: Attention mask for cyclic shift. + """ + B, L, C = x.shape + H, W = self.H, self.W + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # pad feature maps to multiples of window size + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + + # partition windows + x_windows = window_partition( + shifted_x, self.window_size + ) # nW*B, window_size, window_size, C + x_windows = x_windows.view( + -1, self.window_size * self.window_size, C + ) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + +class PatchMerging(nn.Module): + """Patch Merging Layer + Args: + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim: int, norm_layer=nn.LayerNorm) -> None: + super().__init__() + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x, H, W): + """Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + +class BasicLayer(nn.Module): + """A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of feature channels + depth (int): Depths of this stage. + num_heads (int): Number of attention head. + window_size (int): Local window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__( + self, + dim: int, + depth: int, + num_heads: int, + window_size: int=7, + mlp_ratio: float=4.0, + qkv_bias: bool=True, + qk_scale=None, + drop: float=0.0, + attn_drop: float=0.0, + drop_path: float=0.0, + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint: bool=False, + ) -> None: + super().__init__() + self.window_size = window_size + self.shift_size = window_size // 2 + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList( + [ + SwinTransformerBlock( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + ) + for i in range(depth) + ] + ) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, H, W): + """Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + + # calculate attention mask for SW-MSA + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 + h_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + w_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition( + img_mask, self.window_size + ) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, (-100.0)).masked_fill( + attn_mask == 0, 0.0 + ) + + for blk in self.blocks: + blk.H, blk.W = H, W + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, attn_mask) + else: + x = blk(x, attn_mask) + if self.downsample is not None: + x_down = self.downsample(x, H, W) + Wh, Ww = (H + 1) // 2, (W + 1) // 2 + return x, H, W, x_down, Wh, Ww + else: + return x, H, W, x, H, W + + +class PatchEmbed(nn.Module): + """Image to Patch Embedding + Args: + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, patch_size: int=4, in_chans: int=3, embed_dim: int=96, norm_layer=None) -> None: + super().__init__() + patch_size = to_2tuple(patch_size) + self.patch_size = patch_size + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, H, W = x.size() + if W % self.patch_size[1] != 0: + x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) + if H % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) + + x = self.proj(x) # B C Wh Ww + if self.norm is not None: + Wh, Ww = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) + + return x + + +class SwinTransformer(Backbone): + """Swin Transformer backbone. + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + pretrain_img_size (int): Input image size for training the pretrained model, + used in absolute postion embedding. Default 224. + patch_size (int | tuple(int)): Patch size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + depths (tuple[int]): Depths of each Swin Transformer stage. + num_heads (tuple[int]): Number of attention head of each stage. + window_size (int): Window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. + attn_drop_rate (float): Attention dropout rate. Default: 0. + drop_path_rate (float): Stochastic depth rate. Default: 0.2. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False. + patch_norm (bool): If True, add normalization after patch embedding. Default: True. + out_indices (Sequence[int]): Output from which stages. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__( + self, + pretrain_img_size: int=224, + patch_size: int=4, + in_chans: int=3, + embed_dim: int=96, + depths: Optional[Sequence[int]]=None, + num_heads: Optional[int]=None, + window_size: int=7, + mlp_ratio: float=4.0, + qkv_bias: bool=True, + qk_scale=None, + drop_rate: float=0.0, + attn_drop_rate: float=0.0, + drop_path_rate: float=0.2, + norm_layer=nn.LayerNorm, + ape: bool=False, + patch_norm: bool=True, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + use_checkpoint: bool=False, + ) -> None: + if num_heads is None: + num_heads = [3, 6, 12, 24] + if depths is None: + depths = [2, 2, 6, 2] + super().__init__() + + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None, + ) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_2tuple(pretrain_img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [ + pretrain_img_size[0] // patch_size[0], + pretrain_img_size[1] // patch_size[1], + ] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]) + ) + trunc_normal_(self.absolute_pos_embed, std=0.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(embed_dim * 2**i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + ) + self.layers.append(layer) + + num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)] + self.num_features = num_features + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer]) + layer_name = f"norm{i_layer}" + self.add_module(layer_name, layer) + + self._freeze_stages() + self._out_features = [f"swin{i}" for i in self.out_indices] + self._out_feature_channels = { + f"swin{i}": self.embed_dim * 2**i for i in self.out_indices + } + self._out_feature_strides = {f"swin{i}": 2 ** (i + 2) for i in self.out_indices} + self._size_devisibility = 32 + + def _freeze_stages(self) -> None: + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained: Optional[bool]=None): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + def _init_weights(m) -> None: + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + if isinstance(pretrained, str): + self.apply(_init_weights) + # load_checkpoint(self, pretrained, strict=False) + elif pretrained is None: + self.apply(_init_weights) + else: + raise TypeError("pretrained must be a str or None") + + def forward(self, x): + """Forward function.""" + x = self.patch_embed(x) + + Wh, Ww = x.size(2), x.size(3) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed = F.interpolate( + self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic" + ) + x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C + else: + x = x.flatten(2).transpose(1, 2) + x = self.pos_drop(x) + + # outs = [] + outs = {} + for i in range(self.num_layers): + layer = self.layers[i] + x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) + + if i in self.out_indices: + norm_layer = getattr(self, f"norm{i}") + x_out = norm_layer(x_out) + + out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() + # outs.append(out) + outs[f"swin{i}"] = out + + return outs + + def train(self, mode: bool=True) -> None: + """Convert the model into training mode while keep layers freezed.""" + super().train(mode) + self._freeze_stages() + + +size2config = { + "T": { + "window_size": 7, + "embed_dim": 96, + "depth": [2, 2, 6, 2], + "num_heads": [3, 6, 12, 24], + "drop_path_rate": 0.2, + "pretrained": "models/swin_tiny_patch4_window7_224.pth", + }, + "S": { + "window_size": 7, + "embed_dim": 96, + "depth": [2, 2, 18, 2], + "num_heads": [3, 6, 12, 24], + "drop_path_rate": 0.2, + "pretrained": "models/swin_small_patch4_window7_224.pth", + }, + "B": { + "window_size": 7, + "embed_dim": 128, + "depth": [2, 2, 18, 2], + "num_heads": [4, 8, 16, 32], + "drop_path_rate": 0.3, + "pretrained": "models/swin_base_patch4_window7_224.pth", + }, + "B-22k": { + "window_size": 7, + "embed_dim": 128, + "depth": [2, 2, 18, 2], + "num_heads": [4, 8, 16, 32], + "drop_path_rate": 0.3, + "pretrained": "models/swin_base_patch4_window7_224_22k.pth", + }, + "B-22k-384": { + "window_size": 12, + "embed_dim": 128, + "depth": [2, 2, 18, 2], + "num_heads": [4, 8, 16, 32], + "drop_path_rate": 0.3, + "pretrained": "models/swin_base_patch4_window12_384_22k.pth", + }, + "L-22k": { + "window_size": 7, + "embed_dim": 192, + "depth": [2, 2, 18, 2], + "num_heads": [6, 12, 24, 48], + "drop_path_rate": 0.3, # TODO (xingyi): this is unclear + "pretrained": "models/swin_large_patch4_window7_224_22k.pth", + }, + "L-22k-384": { + "window_size": 12, + "embed_dim": 192, + "depth": [2, 2, 18, 2], + "num_heads": [6, 12, 24, 48], + "drop_path_rate": 0.3, # TODO (xingyi): this is unclear + "pretrained": "models/swin_large_patch4_window12_384_22k.pth", + }, +} + + +@BACKBONE_REGISTRY.register() +def build_swintransformer_backbone(cfg, input_shape): + """ """ + config = size2config[cfg.MODEL.SWIN.SIZE] + out_indices = cfg.MODEL.SWIN.OUT_FEATURES + model = SwinTransformer( + embed_dim=config["embed_dim"], + window_size=config["window_size"], + depths=config["depth"], + num_heads=config["num_heads"], + drop_path_rate=config["drop_path_rate"], + out_indices=out_indices, + frozen_stages=-1, + use_checkpoint=cfg.MODEL.SWIN.USE_CHECKPOINT, + ) + # print('Initializing', config['pretrained']) + model.init_weights(config["pretrained"]) + return model + + +@BACKBONE_REGISTRY.register() +def build_swintransformer_fpn_backbone(cfg, input_shape: ShapeSpec): + """ """ + bottom_up = build_swintransformer_backbone(cfg, input_shape) + in_features = cfg.MODEL.FPN.IN_FEATURES + out_channels = cfg.MODEL.FPN.OUT_CHANNELS + backbone = FPN( + bottom_up=bottom_up, + in_features=in_features, + out_channels=out_channels, + norm=cfg.MODEL.FPN.NORM, + top_block=LastLevelP6P7_P5(out_channels, out_channels), + fuse_type=cfg.MODEL.FPN.FUSE_TYPE, + ) + return backbone + + +@BACKBONE_REGISTRY.register() +def build_swintransformer_bifpn_backbone(cfg, input_shape: ShapeSpec): + """ """ + bottom_up = build_swintransformer_backbone(cfg, input_shape) + in_features = cfg.MODEL.FPN.IN_FEATURES + backbone = BiFPN( + cfg=cfg, + bottom_up=bottom_up, + in_features=in_features, + out_channels=cfg.MODEL.BIFPN.OUT_CHANNELS, + norm=cfg.MODEL.BIFPN.NORM, + num_levels=cfg.MODEL.BIFPN.NUM_LEVELS, + num_bifpn=cfg.MODEL.BIFPN.NUM_BIFPN, + separable_conv=cfg.MODEL.BIFPN.SEPARABLE_CONV, + ) + return backbone diff --git a/dimos/models/Detic/detic/modeling/backbone/timm.py b/dimos/models/Detic/detic/modeling/backbone/timm.py new file mode 100644 index 0000000000..a15e03f875 --- /dev/null +++ b/dimos/models/Detic/detic/modeling/backbone/timm.py @@ -0,0 +1,217 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +import copy + +from detectron2.layers.batch_norm import FrozenBatchNorm2d +from detectron2.modeling.backbone import FPN, Backbone +from detectron2.modeling.backbone.build import BACKBONE_REGISTRY +import fvcore.nn.weight_init as weight_init +from timm import create_model +from timm.models.convnext import ConvNeXt, checkpoint_filter_fn, default_cfgs +from timm.models.helpers import build_model_with_cfg +from timm.models.registry import register_model +from timm.models.resnet import Bottleneck, ResNet, default_cfgs as default_cfgs_resnet +import torch +from torch import nn +import torch.nn.functional as F + + +@register_model +def convnext_tiny_21k(pretrained: bool=False, **kwargs): + model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs) + cfg = default_cfgs["convnext_tiny"] + cfg["url"] = "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth" + model = build_model_with_cfg( + ConvNeXt, + "convnext_tiny", + pretrained, + default_cfg=cfg, + pretrained_filter_fn=checkpoint_filter_fn, + feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True), + **model_args, + ) + return model + + +class CustomResNet(ResNet): + def __init__(self, **kwargs) -> None: + self.out_indices = kwargs.pop("out_indices") + super().__init__(**kwargs) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.act1(x) + x = self.maxpool(x) + ret = [x] + x = self.layer1(x) + ret.append(x) + x = self.layer2(x) + ret.append(x) + x = self.layer3(x) + ret.append(x) + x = self.layer4(x) + ret.append(x) + return [ret[i] for i in self.out_indices] + + def load_pretrained(self, cached_file) -> None: + data = torch.load(cached_file, map_location="cpu") + if "state_dict" in data: + self.load_state_dict(data["state_dict"]) + else: + self.load_state_dict(data) + + +model_params = { + "resnet50_in21k": dict(block=Bottleneck, layers=[3, 4, 6, 3]), +} + + +def create_timm_resnet(variant, out_indices, pretrained: bool=False, **kwargs): + params = model_params[variant] + default_cfgs_resnet["resnet50_in21k"] = copy.deepcopy(default_cfgs_resnet["resnet50"]) + default_cfgs_resnet["resnet50_in21k"]["url"] = ( + "https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/resnet50_miil_21k.pth" + ) + default_cfgs_resnet["resnet50_in21k"]["num_classes"] = 11221 + + return build_model_with_cfg( + CustomResNet, + variant, + pretrained, + default_cfg=default_cfgs_resnet[variant], + out_indices=out_indices, + pretrained_custom_load=True, + **params, + **kwargs, + ) + + +class LastLevelP6P7_P5(nn.Module): + """ """ + + def __init__(self, in_channels, out_channels) -> None: + super().__init__() + self.num_levels = 2 + self.in_feature = "p5" + self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1) + self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1) + for module in [self.p6, self.p7]: + weight_init.c2_xavier_fill(module) + + def forward(self, c5): + p6 = self.p6(c5) + p7 = self.p7(F.relu(p6)) + return [p6, p7] + + +def freeze_module(x): + """ """ + for p in x.parameters(): + p.requires_grad = False + FrozenBatchNorm2d.convert_frozen_batchnorm(x) + return x + + +class TIMM(Backbone): + def __init__(self, base_name: str, out_levels, freeze_at: int=0, norm: str="FrozenBN", pretrained: bool=False) -> None: + super().__init__() + out_indices = [x - 1 for x in out_levels] + if base_name in model_params: + self.base = create_timm_resnet(base_name, out_indices=out_indices, pretrained=False) + elif "eff" in base_name or "resnet" in base_name or "regnet" in base_name: + self.base = create_model( + base_name, features_only=True, out_indices=out_indices, pretrained=pretrained + ) + elif "convnext" in base_name: + drop_path_rate = 0.2 if ("tiny" in base_name or "small" in base_name) else 0.3 + self.base = create_model( + base_name, + features_only=True, + out_indices=out_indices, + pretrained=pretrained, + drop_path_rate=drop_path_rate, + ) + else: + assert 0, base_name + feature_info = [ + dict(num_chs=f["num_chs"], reduction=f["reduction"]) + for i, f in enumerate(self.base.feature_info) + ] + self._out_features = [f"layer{x}" for x in out_levels] + self._out_feature_channels = { + f"layer{l}": feature_info[l - 1]["num_chs"] for l in out_levels + } + self._out_feature_strides = { + f"layer{l}": feature_info[l - 1]["reduction"] for l in out_levels + } + self._size_divisibility = max(self._out_feature_strides.values()) + if "resnet" in base_name: + self.freeze(freeze_at) + if norm == "FrozenBN": + self = FrozenBatchNorm2d.convert_frozen_batchnorm(self) + + def freeze(self, freeze_at: int=0) -> None: + """ """ + if freeze_at >= 1: + print("Frezing", self.base.conv1) + self.base.conv1 = freeze_module(self.base.conv1) + if freeze_at >= 2: + print("Frezing", self.base.layer1) + self.base.layer1 = freeze_module(self.base.layer1) + + def forward(self, x): + features = self.base(x) + ret = {k: v for k, v in zip(self._out_features, features, strict=False)} + return ret + + @property + def size_divisibility(self): + return self._size_divisibility + + +@BACKBONE_REGISTRY.register() +def build_timm_backbone(cfg, input_shape): + model = TIMM( + cfg.MODEL.TIMM.BASE_NAME, + cfg.MODEL.TIMM.OUT_LEVELS, + freeze_at=cfg.MODEL.TIMM.FREEZE_AT, + norm=cfg.MODEL.TIMM.NORM, + pretrained=cfg.MODEL.TIMM.PRETRAINED, + ) + return model + + +@BACKBONE_REGISTRY.register() +def build_p67_timm_fpn_backbone(cfg, input_shape): + """ """ + bottom_up = build_timm_backbone(cfg, input_shape) + in_features = cfg.MODEL.FPN.IN_FEATURES + out_channels = cfg.MODEL.FPN.OUT_CHANNELS + backbone = FPN( + bottom_up=bottom_up, + in_features=in_features, + out_channels=out_channels, + norm=cfg.MODEL.FPN.NORM, + top_block=LastLevelP6P7_P5(out_channels, out_channels), + fuse_type=cfg.MODEL.FPN.FUSE_TYPE, + ) + return backbone + + +@BACKBONE_REGISTRY.register() +def build_p35_timm_fpn_backbone(cfg, input_shape): + """ """ + bottom_up = build_timm_backbone(cfg, input_shape) + + in_features = cfg.MODEL.FPN.IN_FEATURES + out_channels = cfg.MODEL.FPN.OUT_CHANNELS + backbone = FPN( + bottom_up=bottom_up, + in_features=in_features, + out_channels=out_channels, + norm=cfg.MODEL.FPN.NORM, + top_block=None, + fuse_type=cfg.MODEL.FPN.FUSE_TYPE, + ) + return backbone diff --git a/dimos/models/Detic/detic/modeling/debug.py b/dimos/models/Detic/detic/modeling/debug.py new file mode 100644 index 0000000000..f37849019e --- /dev/null +++ b/dimos/models/Detic/detic/modeling/debug.py @@ -0,0 +1,408 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import os + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +from typing import Optional, Sequence + +COLORS = ((np.random.rand(1300, 3) * 0.4 + 0.6) * 255).astype(np.uint8).reshape(1300, 1, 1, 3) + + +def _get_color_image(heatmap): + heatmap = heatmap.reshape(heatmap.shape[0], heatmap.shape[1], heatmap.shape[2], 1) + if heatmap.shape[0] == 1: + color_map = ( + (heatmap * np.ones((1, 1, 1, 3), np.uint8) * 255).max(axis=0).astype(np.uint8) + ) # H, W, 3 + else: + color_map = (heatmap * COLORS[: heatmap.shape[0]]).max(axis=0).astype(np.uint8) # H, W, 3 + + return color_map + + +def _blend_image(image, color_map, a: float=0.7): + color_map = cv2.resize(color_map, (image.shape[1], image.shape[0])) + ret = np.clip(image * (1 - a) + color_map * a, 0, 255).astype(np.uint8) + return ret + + +def _blend_image_heatmaps(image, color_maps, a: float=0.7): + merges = np.zeros((image.shape[0], image.shape[1], 3), np.float32) + for color_map in color_maps: + color_map = cv2.resize(color_map, (image.shape[1], image.shape[0])) + merges = np.maximum(merges, color_map) + ret = np.clip(image * (1 - a) + merges * a, 0, 255).astype(np.uint8) + return ret + + +def _decompose_level(x, shapes_per_level, N): + """ + x: LNHiWi x C + """ + x = x.view(x.shape[0], -1) + ret = [] + st = 0 + for l in range(len(shapes_per_level)): + ret.append([]) + h = shapes_per_level[l][0].int().item() + w = shapes_per_level[l][1].int().item() + for i in range(N): + ret[l].append(x[st + h * w * i : st + h * w * (i + 1)].view(h, w, -1).permute(2, 0, 1)) + st += h * w * N + return ret + + +def _imagelist_to_tensor(images): + images = [x for x in images] + image_sizes = [x.shape[-2:] for x in images] + h = max([size[0] for size in image_sizes]) + w = max([size[1] for size in image_sizes]) + S = 32 + h, w = ((h - 1) // S + 1) * S, ((w - 1) // S + 1) * S + images = [F.pad(x, (0, w - x.shape[2], 0, h - x.shape[1], 0, 0)) for x in images] + images = torch.stack(images) + return images + + +def _ind2il(ind, shapes_per_level, N): + r = ind + l = 0 + S = 0 + while r - S >= N * shapes_per_level[l][0] * shapes_per_level[l][1]: + S += N * shapes_per_level[l][0] * shapes_per_level[l][1] + l += 1 + i = (r - S) // (shapes_per_level[l][0] * shapes_per_level[l][1]) + return i, l + + +def debug_train( + images, + gt_instances, + flattened_hms, + reg_targets, + labels: Sequence[str], + pos_inds, + shapes_per_level, + locations, + strides: Sequence[int], +) -> None: + """ + images: N x 3 x H x W + flattened_hms: LNHiWi x C + shapes_per_level: L x 2 [(H_i, W_i)] + locations: LNHiWi x 2 + """ + reg_inds = torch.nonzero(reg_targets.max(dim=1)[0] > 0).squeeze(1) + N = len(images) + images = _imagelist_to_tensor(images) + repeated_locations = [torch.cat([loc] * N, dim=0) for loc in locations] + locations = torch.cat(repeated_locations, dim=0) + gt_hms = _decompose_level(flattened_hms, shapes_per_level, N) + masks = flattened_hms.new_zeros((flattened_hms.shape[0], 1)) + masks[pos_inds] = 1 + masks = _decompose_level(masks, shapes_per_level, N) + for i in range(len(images)): + image = images[i].detach().cpu().numpy().transpose(1, 2, 0) + color_maps = [] + for l in range(len(gt_hms)): + color_map = _get_color_image(gt_hms[l][i].detach().cpu().numpy()) + color_maps.append(color_map) + cv2.imshow(f"gthm_{l}", color_map) + blend = _blend_image_heatmaps(image.copy(), color_maps) + if gt_instances is not None: + bboxes = gt_instances[i].gt_boxes.tensor + for j in range(len(bboxes)): + bbox = bboxes[j] + cv2.rectangle( + blend, + (int(bbox[0]), int(bbox[1])), + (int(bbox[2]), int(bbox[3])), + (0, 0, 255), + 3, + cv2.LINE_AA, + ) + + for j in range(len(pos_inds)): + image_id, l = _ind2il(pos_inds[j], shapes_per_level, N) + if image_id != i: + continue + loc = locations[pos_inds[j]] + cv2.drawMarker( + blend, (int(loc[0]), int(loc[1])), (0, 255, 255), markerSize=(l + 1) * 16 + ) + + for j in range(len(reg_inds)): + image_id, l = _ind2il(reg_inds[j], shapes_per_level, N) + if image_id != i: + continue + ltrb = reg_targets[reg_inds[j]] + ltrb *= strides[l] + loc = locations[reg_inds[j]] + bbox = [(loc[0] - ltrb[0]), (loc[1] - ltrb[1]), (loc[0] + ltrb[2]), (loc[1] + ltrb[3])] + cv2.rectangle( + blend, + (int(bbox[0]), int(bbox[1])), + (int(bbox[2]), int(bbox[3])), + (255, 0, 0), + 1, + cv2.LINE_AA, + ) + cv2.circle(blend, (int(loc[0]), int(loc[1])), 2, (255, 0, 0), -1) + + cv2.imshow("blend", blend) + cv2.waitKey() + + +def debug_test( + images, + logits_pred, + reg_pred, + agn_hm_pred=None, + preds=None, + vis_thresh: float=0.3, + debug_show_name: bool=False, + mult_agn: bool=False, +) -> None: + """ + images: N x 3 x H x W + class_target: LNHiWi x C + cat_agn_heatmap: LNHiWi + shapes_per_level: L x 2 [(H_i, W_i)] + """ + if preds is None: + preds = [] + if agn_hm_pred is None: + agn_hm_pred = [] + len(images) + for i in range(len(images)): + image = images[i].detach().cpu().numpy().transpose(1, 2, 0) + image.copy().astype(np.uint8) + pred_image = image.copy().astype(np.uint8) + color_maps = [] + L = len(logits_pred) + for l in range(L): + if logits_pred[0] is not None: + stride = min(image.shape[0], image.shape[1]) / min( + logits_pred[l][i].shape[1], logits_pred[l][i].shape[2] + ) + else: + stride = min(image.shape[0], image.shape[1]) / min( + agn_hm_pred[l][i].shape[1], agn_hm_pred[l][i].shape[2] + ) + stride = stride if stride < 60 else 64 if stride < 100 else 128 + if logits_pred[0] is not None: + if mult_agn: + logits_pred[l][i] = logits_pred[l][i] * agn_hm_pred[l][i] + color_map = _get_color_image(logits_pred[l][i].detach().cpu().numpy()) + color_maps.append(color_map) + cv2.imshow(f"predhm_{l}", color_map) + + if debug_show_name: + from detectron2.data.datasets.lvis_v1_categories import LVIS_CATEGORIES + + cat2name = [x["name"] for x in LVIS_CATEGORIES] + for j in range(len(preds[i].scores) if preds is not None else 0): + if preds[i].scores[j] > vis_thresh: + bbox = ( + preds[i].proposal_boxes[j] + if preds[i].has("proposal_boxes") + else preds[i].pred_boxes[j] + ) + bbox = bbox.tensor[0].detach().cpu().numpy().astype(np.int32) + cat = int(preds[i].pred_classes[j]) if preds[i].has("pred_classes") else 0 + cl = COLORS[cat, 0, 0] + cv2.rectangle( + pred_image, + (int(bbox[0]), int(bbox[1])), + (int(bbox[2]), int(bbox[3])), + (int(cl[0]), int(cl[1]), int(cl[2])), + 2, + cv2.LINE_AA, + ) + if debug_show_name: + txt = "{}{:.1f}".format( + cat2name[cat] if cat > 0 else "", preds[i].scores[j] + ) + font = cv2.FONT_HERSHEY_SIMPLEX + cat_size = cv2.getTextSize(txt, font, 0.5, 2)[0] + cv2.rectangle( + pred_image, + (int(bbox[0]), int(bbox[1] - cat_size[1] - 2)), + (int(bbox[0] + cat_size[0]), int(bbox[1] - 2)), + (int(cl[0]), int(cl[1]), int(cl[2])), + -1, + ) + cv2.putText( + pred_image, + txt, + (int(bbox[0]), int(bbox[1] - 2)), + font, + 0.5, + (0, 0, 0), + thickness=1, + lineType=cv2.LINE_AA, + ) + + if agn_hm_pred[l] is not None: + agn_hm_ = agn_hm_pred[l][i, 0, :, :, None].detach().cpu().numpy() + agn_hm_ = (agn_hm_ * np.array([255, 255, 255]).reshape(1, 1, 3)).astype(np.uint8) + cv2.imshow(f"agn_hm_{l}", agn_hm_) + blend = _blend_image_heatmaps(image.copy(), color_maps) + cv2.imshow("blend", blend) + cv2.imshow("preds", pred_image) + cv2.waitKey() + + +global cnt +cnt = 0 + + +def debug_second_stage( + images, + instances, + proposals=None, + vis_thresh: float=0.3, + save_debug: bool=False, + debug_show_name: bool=False, + image_labels: Optional[Sequence[str]]=None, + save_debug_path: str="output/save_debug/", + bgr: bool=False, +) -> None: + if image_labels is None: + image_labels = [] + images = _imagelist_to_tensor(images) + if "COCO" in save_debug_path: + from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES + + cat2name = [x["name"] for x in COCO_CATEGORIES] + else: + from detectron2.data.datasets.lvis_v1_categories import LVIS_CATEGORIES + + cat2name = ["({}){}".format(x["frequency"], x["name"]) for x in LVIS_CATEGORIES] + for i in range(len(images)): + image = images[i].detach().cpu().numpy().transpose(1, 2, 0).astype(np.uint8).copy() + if bgr: + image = image[:, :, ::-1].copy() + if instances[i].has("gt_boxes"): + bboxes = instances[i].gt_boxes.tensor.cpu().numpy() + scores = np.ones(bboxes.shape[0]) + cats = instances[i].gt_classes.cpu().numpy() + else: + bboxes = instances[i].pred_boxes.tensor.cpu().numpy() + scores = instances[i].scores.cpu().numpy() + cats = instances[i].pred_classes.cpu().numpy() + for j in range(len(bboxes)): + if scores[j] > vis_thresh: + bbox = bboxes[j] + cl = COLORS[cats[j], 0, 0] + cl = (int(cl[0]), int(cl[1]), int(cl[2])) + cv2.rectangle( + image, + (int(bbox[0]), int(bbox[1])), + (int(bbox[2]), int(bbox[3])), + cl, + 2, + cv2.LINE_AA, + ) + if debug_show_name: + cat = cats[j] + txt = "{}{:.1f}".format(cat2name[cat] if cat > 0 else "", scores[j]) + font = cv2.FONT_HERSHEY_SIMPLEX + cat_size = cv2.getTextSize(txt, font, 0.5, 2)[0] + cv2.rectangle( + image, + (int(bbox[0]), int(bbox[1] - cat_size[1] - 2)), + (int(bbox[0] + cat_size[0]), int(bbox[1] - 2)), + (int(cl[0]), int(cl[1]), int(cl[2])), + -1, + ) + cv2.putText( + image, + txt, + (int(bbox[0]), int(bbox[1] - 2)), + font, + 0.5, + (0, 0, 0), + thickness=1, + lineType=cv2.LINE_AA, + ) + if proposals is not None: + proposal_image = ( + images[i].detach().cpu().numpy().transpose(1, 2, 0).astype(np.uint8).copy() + ) + if bgr: + proposal_image = proposal_image.copy() + else: + proposal_image = proposal_image[:, :, ::-1].copy() + bboxes = proposals[i].proposal_boxes.tensor.cpu().numpy() + if proposals[i].has("scores"): + scores = proposals[i].scores.detach().cpu().numpy() + else: + scores = proposals[i].objectness_logits.detach().cpu().numpy() + # selected = -1 + # if proposals[i].has('image_loss'): + # selected = proposals[i].image_loss.argmin() + if proposals[i].has("selected"): + selected = proposals[i].selected + else: + selected = [-1 for _ in range(len(bboxes))] + for j in range(len(bboxes)): + if scores[j] > vis_thresh or selected[j] >= 0: + bbox = bboxes[j] + cl = (209, 159, 83) + th = 2 + if selected[j] >= 0: + cl = (0, 0, 0xA4) + th = 4 + cv2.rectangle( + proposal_image, + (int(bbox[0]), int(bbox[1])), + (int(bbox[2]), int(bbox[3])), + cl, + th, + cv2.LINE_AA, + ) + if selected[j] >= 0 and debug_show_name: + cat = selected[j].item() + txt = f"{cat2name[cat]}" + font = cv2.FONT_HERSHEY_SIMPLEX + cat_size = cv2.getTextSize(txt, font, 0.5, 2)[0] + cv2.rectangle( + proposal_image, + (int(bbox[0]), int(bbox[1] - cat_size[1] - 2)), + (int(bbox[0] + cat_size[0]), int(bbox[1] - 2)), + (int(cl[0]), int(cl[1]), int(cl[2])), + -1, + ) + cv2.putText( + proposal_image, + txt, + (int(bbox[0]), int(bbox[1] - 2)), + font, + 0.5, + (0, 0, 0), + thickness=1, + lineType=cv2.LINE_AA, + ) + + if save_debug: + global cnt + cnt = (cnt + 1) % 5000 + if not os.path.exists(save_debug_path): + os.mkdir(save_debug_path) + save_name = f"{save_debug_path}/{cnt:05d}.jpg" + if i < len(image_labels): + image_label = image_labels[i] + save_name = f"{save_debug_path}/{cnt:05d}" + for x in image_label: + class_name = cat2name[x] + save_name = save_name + f"|{class_name}" + save_name = save_name + ".jpg" + cv2.imwrite(save_name, proposal_image) + else: + cv2.imshow("image", image) + if proposals is not None: + cv2.imshow("proposals", proposal_image) + cv2.waitKey() diff --git a/dimos/models/Detic/detic/modeling/meta_arch/custom_rcnn.py b/dimos/models/Detic/detic/modeling/meta_arch/custom_rcnn.py new file mode 100644 index 0000000000..872084f7cb --- /dev/null +++ b/dimos/models/Detic/detic/modeling/meta_arch/custom_rcnn.py @@ -0,0 +1,227 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from typing import Dict, List, Optional, Tuple + +from detectron2.config import configurable +from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY +from detectron2.modeling.meta_arch.rcnn import GeneralizedRCNN +from detectron2.structures import Instances +import detectron2.utils.comm as comm +from detectron2.utils.events import get_event_storage +import torch +from torch.cuda.amp import autocast + +from ..text.text_encoder import build_text_encoder +from ..utils import get_fed_loss_inds, load_class_freq + + +@META_ARCH_REGISTRY.register() +class CustomRCNN(GeneralizedRCNN): + """ + Add image labels + """ + + @configurable + def __init__( + self, + with_image_labels: bool=False, + dataset_loss_weight=None, + fp16: bool=False, + sync_caption_batch: bool=False, + roi_head_name: str="", + cap_batch_ratio: int=4, + with_caption: bool=False, + dynamic_classifier: bool=False, + **kwargs, + ) -> None: + """ """ + if dataset_loss_weight is None: + dataset_loss_weight = [] + self.with_image_labels = with_image_labels + self.dataset_loss_weight = dataset_loss_weight + self.fp16 = fp16 + self.with_caption = with_caption + self.sync_caption_batch = sync_caption_batch + self.roi_head_name = roi_head_name + self.cap_batch_ratio = cap_batch_ratio + self.dynamic_classifier = dynamic_classifier + self.return_proposal = False + if self.dynamic_classifier: + self.freq_weight = kwargs.pop("freq_weight") + self.num_classes = kwargs.pop("num_classes") + self.num_sample_cats = kwargs.pop("num_sample_cats") + super().__init__(**kwargs) + assert self.proposal_generator is not None + if self.with_caption: + assert not self.dynamic_classifier + self.text_encoder = build_text_encoder(pretrain=True) + for v in self.text_encoder.parameters(): + v.requires_grad = False + + @classmethod + def from_config(cls, cfg): + ret = super().from_config(cfg) + ret.update( + { + "with_image_labels": cfg.WITH_IMAGE_LABELS, + "dataset_loss_weight": cfg.MODEL.DATASET_LOSS_WEIGHT, + "fp16": cfg.FP16, + "with_caption": cfg.MODEL.WITH_CAPTION, + "sync_caption_batch": cfg.MODEL.SYNC_CAPTION_BATCH, + "dynamic_classifier": cfg.MODEL.DYNAMIC_CLASSIFIER, + "roi_head_name": cfg.MODEL.ROI_HEADS.NAME, + "cap_batch_ratio": cfg.MODEL.CAP_BATCH_RATIO, + } + ) + if ret["dynamic_classifier"]: + ret["freq_weight"] = load_class_freq( + cfg.MODEL.ROI_BOX_HEAD.CAT_FREQ_PATH, cfg.MODEL.ROI_BOX_HEAD.FED_LOSS_FREQ_WEIGHT + ) + ret["num_classes"] = cfg.MODEL.ROI_HEADS.NUM_CLASSES + ret["num_sample_cats"] = cfg.MODEL.NUM_SAMPLE_CATS + return ret + + def inference( + self, + batched_inputs: tuple[dict[str, torch.Tensor]], + detected_instances: list[Instances] | None = None, + do_postprocess: bool = True, + ): + assert not self.training + assert detected_instances is None + + images = self.preprocess_image(batched_inputs) + features = self.backbone(images.tensor) + proposals, _ = self.proposal_generator(images, features, None) + results, _ = self.roi_heads(images, features, proposals) + if do_postprocess: + assert not torch.jit.is_scripting(), "Scripting is not supported for postprocess." + return CustomRCNN._postprocess(results, batched_inputs, images.image_sizes) + else: + return results + + def forward(self, batched_inputs: list[dict[str, torch.Tensor]]): + """ + Add ann_type + Ignore proposal loss when training with image labels + """ + if not self.training: + return self.inference(batched_inputs) + + images = self.preprocess_image(batched_inputs) + + ann_type = "box" + gt_instances = [x["instances"].to(self.device) for x in batched_inputs] + if self.with_image_labels: + for inst, x in zip(gt_instances, batched_inputs, strict=False): + inst._ann_type = x["ann_type"] + inst._pos_category_ids = x["pos_category_ids"] + ann_types = [x["ann_type"] for x in batched_inputs] + assert len(set(ann_types)) == 1 + ann_type = ann_types[0] + if ann_type in ["prop", "proptag"]: + for t in gt_instances: + t.gt_classes *= 0 + + if self.fp16: # TODO (zhouxy): improve + with autocast(): + features = self.backbone(images.tensor.half()) + features = {k: v.float() for k, v in features.items()} + else: + features = self.backbone(images.tensor) + + cls_features, cls_inds, caption_features = None, None, None + + if self.with_caption and "caption" in ann_type: + inds = [torch.randint(len(x["captions"]), (1,))[0].item() for x in batched_inputs] + caps = [x["captions"][ind] for ind, x in zip(inds, batched_inputs, strict=False)] + caption_features = self.text_encoder(caps).float() + if self.sync_caption_batch: + caption_features = self._sync_caption_features( + caption_features, ann_type, len(batched_inputs) + ) + + if self.dynamic_classifier and ann_type != "caption": + cls_inds = self._sample_cls_inds(gt_instances, ann_type) # inds, inv_inds + ind_with_bg = [*cls_inds[0].tolist(), -1] + cls_features = ( + self.roi_heads.box_predictor[0] + .cls_score.zs_weight[:, ind_with_bg] + .permute(1, 0) + .contiguous() + ) + + classifier_info = cls_features, cls_inds, caption_features + proposals, proposal_losses = self.proposal_generator(images, features, gt_instances) + + if self.roi_head_name in ["StandardROIHeads", "CascadeROIHeads"]: + proposals, detector_losses = self.roi_heads(images, features, proposals, gt_instances) + else: + proposals, detector_losses = self.roi_heads( + images, + features, + proposals, + gt_instances, + ann_type=ann_type, + classifier_info=classifier_info, + ) + + if self.vis_period > 0: + storage = get_event_storage() + if storage.iter % self.vis_period == 0: + self.visualize_training(batched_inputs, proposals) + + losses = {} + losses.update(detector_losses) + if self.with_image_labels: + if ann_type in ["box", "prop", "proptag"]: + losses.update(proposal_losses) + else: # ignore proposal loss for non-bbox data + losses.update({k: v * 0 for k, v in proposal_losses.items()}) + else: + losses.update(proposal_losses) + if len(self.dataset_loss_weight) > 0: + dataset_sources = [x["dataset_source"] for x in batched_inputs] + assert len(set(dataset_sources)) == 1 + dataset_source = dataset_sources[0] + for k in losses: + losses[k] *= self.dataset_loss_weight[dataset_source] + + if self.return_proposal: + return proposals, losses + else: + return losses + + def _sync_caption_features(self, caption_features, ann_type, BS): + has_caption_feature = caption_features is not None + BS = (BS * self.cap_batch_ratio) if (ann_type == "box") else BS + rank = torch.full((BS, 1), comm.get_rank(), dtype=torch.float32, device=self.device) + if not has_caption_feature: + caption_features = rank.new_zeros((BS, 512)) + caption_features = torch.cat([caption_features, rank], dim=1) + global_caption_features = comm.all_gather(caption_features) + caption_features = ( + torch.cat([x.to(self.device) for x in global_caption_features], dim=0) + if has_caption_feature + else None + ) # (NB) x (D + 1) + return caption_features + + def _sample_cls_inds(self, gt_instances, ann_type: str="box"): + if ann_type == "box": + gt_classes = torch.cat([x.gt_classes for x in gt_instances]) + C = len(self.freq_weight) + freq_weight = self.freq_weight + else: + gt_classes = torch.cat( + [ + torch.tensor(x._pos_category_ids, dtype=torch.long, device=x.gt_classes.device) + for x in gt_instances + ] + ) + C = self.num_classes + freq_weight = None + assert gt_classes.max() < C, f"{gt_classes.max()} {C}" + inds = get_fed_loss_inds(gt_classes, self.num_sample_cats, C, weight=freq_weight) + cls_id_map = gt_classes.new_full((self.num_classes + 1,), len(inds)) + cls_id_map[inds] = torch.arange(len(inds), device=cls_id_map.device) + return inds, cls_id_map diff --git a/dimos/models/Detic/detic/modeling/meta_arch/d2_deformable_detr.py b/dimos/models/Detic/detic/modeling/meta_arch/d2_deformable_detr.py new file mode 100644 index 0000000000..9c2ec8e81e --- /dev/null +++ b/dimos/models/Detic/detic/modeling/meta_arch/d2_deformable_detr.py @@ -0,0 +1,318 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from detectron2.modeling import META_ARCH_REGISTRY, build_backbone +from detectron2.structures import Boxes, Instances +from models.backbone import Joiner +from models.deformable_detr import DeformableDETR, SetCriterion +from models.deformable_transformer import DeformableTransformer +from models.matcher import HungarianMatcher +from models.position_encoding import PositionEmbeddingSine +from models.segmentation import sigmoid_focal_loss +import torch +from torch import nn +import torch.nn.functional as F +from util.box_ops import box_cxcywh_to_xyxy, box_xyxy_to_cxcywh +from util.misc import NestedTensor, accuracy + +from ..utils import get_fed_loss_inds, load_class_freq +from typing import Sequence + +__all__ = ["DeformableDetr"] + + +class CustomSetCriterion(SetCriterion): + def __init__( + self, num_classes: int, matcher, weight_dict, losses, focal_alpha: float=0.25, use_fed_loss: bool=False + ) -> None: + super().__init__(num_classes, matcher, weight_dict, losses, focal_alpha) + self.use_fed_loss = use_fed_loss + if self.use_fed_loss: + self.register_buffer("fed_loss_weight", load_class_freq(freq_weight=0.5)) + + def loss_labels(self, outputs, targets, indices, num_boxes: int, log: bool=True): + """Classification loss (NLL) + targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] + """ + assert "pred_logits" in outputs + src_logits = outputs["pred_logits"] + + idx = self._get_src_permutation_idx(indices) + target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices, strict=False)]) + target_classes = torch.full( + src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device + ) + target_classes[idx] = target_classes_o + + target_classes_onehot = torch.zeros( + [src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1], + dtype=src_logits.dtype, + layout=src_logits.layout, + device=src_logits.device, + ) + target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1) + + target_classes_onehot = target_classes_onehot[:, :, :-1] # B x N x C + if self.use_fed_loss: + inds = get_fed_loss_inds( + gt_classes=target_classes_o, + num_sample_cats=50, + weight=self.fed_loss_weight, + C=target_classes_onehot.shape[2], + ) + loss_ce = ( + sigmoid_focal_loss( + src_logits[:, :, inds], + target_classes_onehot[:, :, inds], + num_boxes, + alpha=self.focal_alpha, + gamma=2, + ) + * src_logits.shape[1] + ) + else: + loss_ce = ( + sigmoid_focal_loss( + src_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2 + ) + * src_logits.shape[1] + ) + losses = {"loss_ce": loss_ce} + + if log: + # TODO this should probably be a separate loss, not hacked in this one here + losses["class_error"] = 100 - accuracy(src_logits[idx], target_classes_o)[0] + return losses + + +class MaskedBackbone(nn.Module): + """This is a thin wrapper around D2's backbone to provide padding masking""" + + def __init__(self, cfg) -> None: + super().__init__() + self.backbone = build_backbone(cfg) + backbone_shape = self.backbone.output_shape() + self.feature_strides = [backbone_shape[f].stride for f in backbone_shape.keys()] + self.strides = [backbone_shape[f].stride for f in backbone_shape.keys()] + self.num_channels = [backbone_shape[x].channels for x in backbone_shape.keys()] + + def forward(self, tensor_list: NestedTensor): + xs = self.backbone(tensor_list.tensors) + out = {} + for name, x in xs.items(): + m = tensor_list.mask + assert m is not None + mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] + out[name] = NestedTensor(x, mask) + return out + + +@META_ARCH_REGISTRY.register() +class DeformableDetr(nn.Module): + """ + Implement Deformable Detr + """ + + def __init__(self, cfg) -> None: + super().__init__() + self.with_image_labels = cfg.WITH_IMAGE_LABELS + self.weak_weight = cfg.MODEL.DETR.WEAK_WEIGHT + + self.device = torch.device(cfg.MODEL.DEVICE) + self.test_topk = cfg.TEST.DETECTIONS_PER_IMAGE + self.num_classes = cfg.MODEL.DETR.NUM_CLASSES + self.mask_on = cfg.MODEL.MASK_ON + hidden_dim = cfg.MODEL.DETR.HIDDEN_DIM + num_queries = cfg.MODEL.DETR.NUM_OBJECT_QUERIES + + # Transformer parameters: + nheads = cfg.MODEL.DETR.NHEADS + dropout = cfg.MODEL.DETR.DROPOUT + dim_feedforward = cfg.MODEL.DETR.DIM_FEEDFORWARD + enc_layers = cfg.MODEL.DETR.ENC_LAYERS + dec_layers = cfg.MODEL.DETR.DEC_LAYERS + num_feature_levels = cfg.MODEL.DETR.NUM_FEATURE_LEVELS + two_stage = cfg.MODEL.DETR.TWO_STAGE + with_box_refine = cfg.MODEL.DETR.WITH_BOX_REFINE + + # Loss parameters: + giou_weight = cfg.MODEL.DETR.GIOU_WEIGHT + l1_weight = cfg.MODEL.DETR.L1_WEIGHT + deep_supervision = cfg.MODEL.DETR.DEEP_SUPERVISION + cls_weight = cfg.MODEL.DETR.CLS_WEIGHT + focal_alpha = cfg.MODEL.DETR.FOCAL_ALPHA + + N_steps = hidden_dim // 2 + d2_backbone = MaskedBackbone(cfg) + backbone = Joiner(d2_backbone, PositionEmbeddingSine(N_steps, normalize=True)) + + transformer = DeformableTransformer( + d_model=hidden_dim, + nhead=nheads, + num_encoder_layers=enc_layers, + num_decoder_layers=dec_layers, + dim_feedforward=dim_feedforward, + dropout=dropout, + activation="relu", + return_intermediate_dec=True, + num_feature_levels=num_feature_levels, + dec_n_points=4, + enc_n_points=4, + two_stage=two_stage, + two_stage_num_proposals=num_queries, + ) + + self.detr = DeformableDETR( + backbone, + transformer, + num_classes=self.num_classes, + num_queries=num_queries, + num_feature_levels=num_feature_levels, + aux_loss=deep_supervision, + with_box_refine=with_box_refine, + two_stage=two_stage, + ) + + if self.mask_on: + assert 0, "Mask is not supported yet :(" + + matcher = HungarianMatcher( + cost_class=cls_weight, cost_bbox=l1_weight, cost_giou=giou_weight + ) + weight_dict = {"loss_ce": cls_weight, "loss_bbox": l1_weight} + weight_dict["loss_giou"] = giou_weight + if deep_supervision: + aux_weight_dict = {} + for i in range(dec_layers - 1): + aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()}) + weight_dict.update(aux_weight_dict) + print("weight_dict", weight_dict) + losses = ["labels", "boxes", "cardinality"] + if self.mask_on: + losses += ["masks"] + self.criterion = CustomSetCriterion( + self.num_classes, + matcher=matcher, + weight_dict=weight_dict, + focal_alpha=focal_alpha, + losses=losses, + use_fed_loss=cfg.MODEL.DETR.USE_FED_LOSS, + ) + pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(self.device).view(3, 1, 1) + pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).to(self.device).view(3, 1, 1) + self.normalizer = lambda x: (x - pixel_mean) / pixel_std + + def forward(self, batched_inputs): + """ + Args: + Returns: + dict[str: Tensor]: + mapping from a named loss to a tensor storing the loss. Used during training only. + """ + images = self.preprocess_image(batched_inputs) + output = self.detr(images) + if self.training: + gt_instances = [x["instances"].to(self.device) for x in batched_inputs] + targets = self.prepare_targets(gt_instances) + loss_dict = self.criterion(output, targets) + weight_dict = self.criterion.weight_dict + for k in loss_dict.keys(): + if k in weight_dict: + loss_dict[k] *= weight_dict[k] + if self.with_image_labels: + if batched_inputs[0]["ann_type"] in ["image", "captiontag"]: + loss_dict["loss_image"] = self.weak_weight * self._weak_loss( + output, batched_inputs + ) + else: + loss_dict["loss_image"] = images[0].new_zeros([1], dtype=torch.float32)[0] + # import pdb; pdb.set_trace() + return loss_dict + else: + image_sizes = output["pred_boxes"].new_tensor( + [(t["height"], t["width"]) for t in batched_inputs] + ) + results = self.post_process(output, image_sizes) + return results + + def prepare_targets(self, targets): + new_targets = [] + for targets_per_image in targets: + h, w = targets_per_image.image_size + image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float, device=self.device) + gt_classes = targets_per_image.gt_classes + gt_boxes = targets_per_image.gt_boxes.tensor / image_size_xyxy + gt_boxes = box_xyxy_to_cxcywh(gt_boxes) + new_targets.append({"labels": gt_classes, "boxes": gt_boxes}) + if self.mask_on and hasattr(targets_per_image, "gt_masks"): + assert 0, "Mask is not supported yet :(" + gt_masks = targets_per_image.gt_masks + gt_masks = convert_coco_poly_to_mask(gt_masks.polygons, h, w) + new_targets[-1].update({"masks": gt_masks}) + return new_targets + + def post_process(self, outputs, target_sizes: Sequence[int]): + """ """ + out_logits, out_bbox = outputs["pred_logits"], outputs["pred_boxes"] + assert len(out_logits) == len(target_sizes) + assert target_sizes.shape[1] == 2 + + prob = out_logits.sigmoid() + topk_values, topk_indexes = torch.topk( + prob.view(out_logits.shape[0], -1), self.test_topk, dim=1 + ) + scores = topk_values + topk_boxes = topk_indexes // out_logits.shape[2] + labels = topk_indexes % out_logits.shape[2] + boxes = box_cxcywh_to_xyxy(out_bbox) + boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4)) + + # and from relative [0, 1] to absolute [0, height] coordinates + img_h, img_w = target_sizes.unbind(1) + scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) + boxes = boxes * scale_fct[:, None, :] + + results = [] + for s, l, b, size in zip(scores, labels, boxes, target_sizes, strict=False): + r = Instances((size[0], size[1])) + r.pred_boxes = Boxes(b) + r.scores = s + r.pred_classes = l + results.append({"instances": r}) + return results + + def preprocess_image(self, batched_inputs): + """ + Normalize, pad and batch the input images. + """ + images = [self.normalizer(x["image"].to(self.device)) for x in batched_inputs] + return images + + def _weak_loss(self, outputs, batched_inputs): + loss = 0 + for b, x in enumerate(batched_inputs): + labels = x["pos_category_ids"] + pred_logits = [outputs["pred_logits"][b]] + pred_boxes = [outputs["pred_boxes"][b]] + for xx in outputs["aux_outputs"]: + pred_logits.append(xx["pred_logits"][b]) + pred_boxes.append(xx["pred_boxes"][b]) + pred_logits = torch.stack(pred_logits, dim=0) # L x N x C + pred_boxes = torch.stack(pred_boxes, dim=0) # L x N x 4 + for label in labels: + loss += self._max_size_loss(pred_logits, pred_boxes, label) / len(labels) + loss = loss / len(batched_inputs) + return loss + + def _max_size_loss(self, logits, boxes, label: str): + """ + Inputs: + logits: L x N x C + boxes: L x N x 4 + """ + target = logits.new_zeros((logits.shape[0], logits.shape[2])) + target[:, label] = 1.0 + sizes = boxes[..., 2] * boxes[..., 3] # L x N + ind = sizes.argmax(dim=1) # L + loss = F.binary_cross_entropy_with_logits( + logits[range(len(ind)), ind], target, reduction="sum" + ) + return loss diff --git a/dimos/models/Detic/detic/modeling/roi_heads/detic_fast_rcnn.py b/dimos/models/Detic/detic/modeling/roi_heads/detic_fast_rcnn.py new file mode 100644 index 0000000000..aaa7ca233e --- /dev/null +++ b/dimos/models/Detic/detic/modeling/roi_heads/detic_fast_rcnn.py @@ -0,0 +1,569 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import math + +from detectron2.config import configurable +from detectron2.layers import ShapeSpec, cat, nonzero_tuple +from detectron2.modeling.roi_heads.fast_rcnn import ( + FastRCNNOutputLayers, + _log_classification_stats, + fast_rcnn_inference, +) +import detectron2.utils.comm as comm +from detectron2.utils.events import get_event_storage +from fvcore.nn import giou_loss, smooth_l1_loss +import fvcore.nn.weight_init as weight_init +import torch +from torch import nn +from torch.nn import functional as F + +from ..utils import get_fed_loss_inds, load_class_freq +from .zero_shot_classifier import ZeroShotClassifier +from typing import Sequence + +__all__ = ["DeticFastRCNNOutputLayers"] + + +class DeticFastRCNNOutputLayers(FastRCNNOutputLayers): + @configurable + def __init__( + self, + input_shape: ShapeSpec, + *, + mult_proposal_score: bool=False, + cls_score=None, + sync_caption_batch: bool=False, + use_sigmoid_ce: bool=False, + use_fed_loss: bool=False, + ignore_zero_cats: bool=False, + fed_loss_num_cat: int=50, + dynamic_classifier: bool=False, + image_label_loss: str="", + use_zeroshot_cls: bool=False, + image_loss_weight: float=0.1, + with_softmax_prop: bool=False, + caption_weight: float=1.0, + neg_cap_weight: float=1.0, + add_image_box: bool=False, + debug: bool=False, + prior_prob: float=0.01, + cat_freq_path: str="", + fed_loss_freq_weight: float=0.5, + softmax_weak_loss: bool=False, + **kwargs, + ) -> None: + super().__init__( + input_shape=input_shape, + **kwargs, + ) + self.mult_proposal_score = mult_proposal_score + self.sync_caption_batch = sync_caption_batch + self.use_sigmoid_ce = use_sigmoid_ce + self.use_fed_loss = use_fed_loss + self.ignore_zero_cats = ignore_zero_cats + self.fed_loss_num_cat = fed_loss_num_cat + self.dynamic_classifier = dynamic_classifier + self.image_label_loss = image_label_loss + self.use_zeroshot_cls = use_zeroshot_cls + self.image_loss_weight = image_loss_weight + self.with_softmax_prop = with_softmax_prop + self.caption_weight = caption_weight + self.neg_cap_weight = neg_cap_weight + self.add_image_box = add_image_box + self.softmax_weak_loss = softmax_weak_loss + self.debug = debug + + if softmax_weak_loss: + assert image_label_loss in ["max_size"] + + if self.use_sigmoid_ce: + bias_value = -math.log((1 - prior_prob) / prior_prob) + nn.init.constant_(self.cls_score.bias, bias_value) + + if self.use_fed_loss or self.ignore_zero_cats: + freq_weight = load_class_freq(cat_freq_path, fed_loss_freq_weight) + self.register_buffer("freq_weight", freq_weight) + else: + self.freq_weight = None + + if self.use_fed_loss and len(self.freq_weight) < self.num_classes: + # assert self.num_classes == 11493 + print("Extending federated loss weight") + self.freq_weight = torch.cat( + [ + self.freq_weight, + self.freq_weight.new_zeros(self.num_classes - len(self.freq_weight)), + ] + ) + + assert (not self.dynamic_classifier) or (not self.use_fed_loss) + input_size = input_shape.channels * (input_shape.width or 1) * (input_shape.height or 1) + + if self.use_zeroshot_cls: + del self.cls_score + del self.bbox_pred + assert cls_score is not None + self.cls_score = cls_score + self.bbox_pred = nn.Sequential( + nn.Linear(input_size, input_size), nn.ReLU(inplace=True), nn.Linear(input_size, 4) + ) + weight_init.c2_xavier_fill(self.bbox_pred[0]) + nn.init.normal_(self.bbox_pred[-1].weight, std=0.001) + nn.init.constant_(self.bbox_pred[-1].bias, 0) + + if self.with_softmax_prop: + self.prop_score = nn.Sequential( + nn.Linear(input_size, input_size), + nn.ReLU(inplace=True), + nn.Linear(input_size, self.num_classes + 1), + ) + weight_init.c2_xavier_fill(self.prop_score[0]) + nn.init.normal_(self.prop_score[-1].weight, mean=0, std=0.001) + nn.init.constant_(self.prop_score[-1].bias, 0) + + @classmethod + def from_config(cls, cfg, input_shape): + ret = super().from_config(cfg, input_shape) + ret.update( + { + "mult_proposal_score": cfg.MODEL.ROI_BOX_HEAD.MULT_PROPOSAL_SCORE, + "sync_caption_batch": cfg.MODEL.SYNC_CAPTION_BATCH, + "use_sigmoid_ce": cfg.MODEL.ROI_BOX_HEAD.USE_SIGMOID_CE, + "use_fed_loss": cfg.MODEL.ROI_BOX_HEAD.USE_FED_LOSS, + "ignore_zero_cats": cfg.MODEL.ROI_BOX_HEAD.IGNORE_ZERO_CATS, + "fed_loss_num_cat": cfg.MODEL.ROI_BOX_HEAD.FED_LOSS_NUM_CAT, + "dynamic_classifier": cfg.MODEL.DYNAMIC_CLASSIFIER, + "image_label_loss": cfg.MODEL.ROI_BOX_HEAD.IMAGE_LABEL_LOSS, + "use_zeroshot_cls": cfg.MODEL.ROI_BOX_HEAD.USE_ZEROSHOT_CLS, + "image_loss_weight": cfg.MODEL.ROI_BOX_HEAD.IMAGE_LOSS_WEIGHT, + "with_softmax_prop": cfg.MODEL.ROI_BOX_HEAD.WITH_SOFTMAX_PROP, + "caption_weight": cfg.MODEL.ROI_BOX_HEAD.CAPTION_WEIGHT, + "neg_cap_weight": cfg.MODEL.ROI_BOX_HEAD.NEG_CAP_WEIGHT, + "add_image_box": cfg.MODEL.ROI_BOX_HEAD.ADD_IMAGE_BOX, + "debug": cfg.DEBUG or cfg.SAVE_DEBUG or cfg.IS_DEBUG, + "prior_prob": cfg.MODEL.ROI_BOX_HEAD.PRIOR_PROB, + "cat_freq_path": cfg.MODEL.ROI_BOX_HEAD.CAT_FREQ_PATH, + "fed_loss_freq_weight": cfg.MODEL.ROI_BOX_HEAD.FED_LOSS_FREQ_WEIGHT, + "softmax_weak_loss": cfg.MODEL.ROI_BOX_HEAD.SOFTMAX_WEAK_LOSS, + } + ) + if ret["use_zeroshot_cls"]: + ret["cls_score"] = ZeroShotClassifier(cfg, input_shape) + return ret + + def losses( + self, predictions, proposals, use_advanced_loss: bool=True, classifier_info=(None, None, None) + ): + """ + enable advanced loss + """ + scores, proposal_deltas = predictions + gt_classes = ( + cat([p.gt_classes for p in proposals], dim=0) if len(proposals) else torch.empty(0) + ) + num_classes = self.num_classes + if self.dynamic_classifier: + _, cls_id_map = classifier_info[1] + gt_classes = cls_id_map[gt_classes] + num_classes = scores.shape[1] - 1 + assert cls_id_map[self.num_classes] == num_classes + _log_classification_stats(scores, gt_classes) + + if len(proposals): + proposal_boxes = cat([p.proposal_boxes.tensor for p in proposals], dim=0) # Nx4 + assert not proposal_boxes.requires_grad, "Proposals should not require gradients!" + gt_boxes = cat( + [(p.gt_boxes if p.has("gt_boxes") else p.proposal_boxes).tensor for p in proposals], + dim=0, + ) + else: + proposal_boxes = gt_boxes = torch.empty((0, 4), device=proposal_deltas.device) + + if self.use_sigmoid_ce: + loss_cls = self.sigmoid_cross_entropy_loss(scores, gt_classes) + else: + loss_cls = self.softmax_cross_entropy_loss(scores, gt_classes) + return { + "loss_cls": loss_cls, + "loss_box_reg": self.box_reg_loss( + proposal_boxes, gt_boxes, proposal_deltas, gt_classes, num_classes=num_classes + ), + } + + def sigmoid_cross_entropy_loss(self, pred_class_logits, gt_classes): + if pred_class_logits.numel() == 0: + return pred_class_logits.new_zeros([1])[0] # This is more robust than .sum() * 0. + + B = pred_class_logits.shape[0] + C = pred_class_logits.shape[1] - 1 + + target = pred_class_logits.new_zeros(B, C + 1) + target[range(len(gt_classes)), gt_classes] = 1 # B x (C + 1) + target = target[:, :C] # B x C + + weight = 1 + + if self.use_fed_loss and (self.freq_weight is not None): # fedloss + appeared = get_fed_loss_inds( + gt_classes, num_sample_cats=self.fed_loss_num_cat, C=C, weight=self.freq_weight + ) + appeared_mask = appeared.new_zeros(C + 1) + appeared_mask[appeared] = 1 # C + 1 + appeared_mask = appeared_mask[:C] + fed_w = appeared_mask.view(1, C).expand(B, C) + weight = weight * fed_w.float() + if self.ignore_zero_cats and (self.freq_weight is not None): + w = (self.freq_weight.view(-1) > 1e-4).float() + weight = weight * w.view(1, C).expand(B, C) + # import pdb; pdb.set_trace() + + cls_loss = F.binary_cross_entropy_with_logits( + pred_class_logits[:, :-1], target, reduction="none" + ) # B x C + loss = torch.sum(cls_loss * weight) / B + return loss + + def softmax_cross_entropy_loss(self, pred_class_logits, gt_classes): + """ + change _no_instance handling + """ + if pred_class_logits.numel() == 0: + return pred_class_logits.new_zeros([1])[0] + + if self.ignore_zero_cats and (self.freq_weight is not None): + zero_weight = torch.cat( + [(self.freq_weight.view(-1) > 1e-4).float(), self.freq_weight.new_ones(1)] + ) # C + 1 + loss = F.cross_entropy( + pred_class_logits, gt_classes, weight=zero_weight, reduction="mean" + ) + elif self.use_fed_loss and (self.freq_weight is not None): # fedloss + C = pred_class_logits.shape[1] - 1 + appeared = get_fed_loss_inds( + gt_classes, num_sample_cats=self.fed_loss_num_cat, C=C, weight=self.freq_weight + ) + appeared_mask = appeared.new_zeros(C + 1).float() + appeared_mask[appeared] = 1.0 # C + 1 + appeared_mask[C] = 1.0 + loss = F.cross_entropy( + pred_class_logits, gt_classes, weight=appeared_mask, reduction="mean" + ) + else: + loss = F.cross_entropy(pred_class_logits, gt_classes, reduction="mean") + return loss + + def box_reg_loss(self, proposal_boxes, gt_boxes, pred_deltas, gt_classes, num_classes: int=-1): + """ + Allow custom background index + """ + num_classes = num_classes if num_classes > 0 else self.num_classes + box_dim = proposal_boxes.shape[1] # 4 or 5 + fg_inds = nonzero_tuple((gt_classes >= 0) & (gt_classes < num_classes))[0] + if pred_deltas.shape[1] == box_dim: # cls-agnostic regression + fg_pred_deltas = pred_deltas[fg_inds] + else: + fg_pred_deltas = pred_deltas.view(-1, self.num_classes, box_dim)[ + fg_inds, gt_classes[fg_inds] + ] + + if self.box_reg_loss_type == "smooth_l1": + gt_pred_deltas = self.box2box_transform.get_deltas( + proposal_boxes[fg_inds], + gt_boxes[fg_inds], + ) + loss_box_reg = smooth_l1_loss( + fg_pred_deltas, gt_pred_deltas, self.smooth_l1_beta, reduction="sum" + ) + elif self.box_reg_loss_type == "giou": + fg_pred_boxes = self.box2box_transform.apply_deltas( + fg_pred_deltas, proposal_boxes[fg_inds] + ) + loss_box_reg = giou_loss(fg_pred_boxes, gt_boxes[fg_inds], reduction="sum") + else: + raise ValueError(f"Invalid bbox reg loss type '{self.box_reg_loss_type}'") + return loss_box_reg / max(gt_classes.numel(), 1.0) + + def inference(self, predictions, proposals): + """ + enable use proposal boxes + """ + predictions = (predictions[0], predictions[1]) + boxes = self.predict_boxes(predictions, proposals) + scores = self.predict_probs(predictions, proposals) + if self.mult_proposal_score: + proposal_scores = [p.get("objectness_logits") for p in proposals] + scores = [(s * ps[:, None]) ** 0.5 for s, ps in zip(scores, proposal_scores, strict=False)] + image_shapes = [x.image_size for x in proposals] + return fast_rcnn_inference( + boxes, + scores, + image_shapes, + self.test_score_thresh, + self.test_nms_thresh, + self.test_topk_per_image, + ) + + def predict_probs(self, predictions, proposals): + """ + support sigmoid + """ + # scores, _ = predictions + scores = predictions[0] + num_inst_per_image = [len(p) for p in proposals] + if self.use_sigmoid_ce: + probs = scores.sigmoid() + else: + probs = F.softmax(scores, dim=-1) + return probs.split(num_inst_per_image, dim=0) + + def image_label_losses( + self, + predictions, + proposals, + image_labels: Sequence[str], + classifier_info=(None, None, None), + ann_type: str="image", + ): + """ + Inputs: + scores: N x (C + 1) + image_labels B x 1 + """ + num_inst_per_image = [len(p) for p in proposals] + scores = predictions[0] + scores = scores.split(num_inst_per_image, dim=0) # B x n x (C + 1) + if self.with_softmax_prop: + prop_scores = predictions[2].split(num_inst_per_image, dim=0) + else: + prop_scores = [None for _ in num_inst_per_image] + B = len(scores) + img_box_count = 0 + select_size_count = 0 + select_x_count = 0 + select_y_count = 0 + max_score_count = 0 + storage = get_event_storage() + loss = scores[0].new_zeros([1])[0] + caption_loss = scores[0].new_zeros([1])[0] + for idx, (score, labels, prop_score, p) in enumerate( + zip(scores, image_labels, prop_scores, proposals, strict=False) + ): + if score.shape[0] == 0: + loss += score.new_zeros([1])[0] + continue + if "caption" in ann_type: + score, caption_loss_img = self._caption_loss(score, classifier_info, idx, B) + caption_loss += self.caption_weight * caption_loss_img + if ann_type == "caption": + continue + + if self.debug: + p.selected = score.new_zeros((len(p),), dtype=torch.long) - 1 + for i_l, label in enumerate(labels): + if self.dynamic_classifier: + if idx == 0 and i_l == 0 and comm.is_main_process(): + storage.put_scalar("stats_label", label) + label = classifier_info[1][1][label] + assert label < score.shape[1] + if self.image_label_loss in ["wsod", "wsddn"]: + loss_i, ind = self._wsddn_loss(score, prop_score, label) + elif self.image_label_loss == "max_score": + loss_i, ind = self._max_score_loss(score, label) + elif self.image_label_loss == "max_size": + loss_i, ind = self._max_size_loss(score, label, p) + elif self.image_label_loss == "first": + loss_i, ind = self._first_loss(score, label) + elif self.image_label_loss == "image": + loss_i, ind = self._image_loss(score, label) + elif self.image_label_loss == "min_loss": + loss_i, ind = self._min_loss_loss(score, label) + else: + assert 0 + loss += loss_i / len(labels) + if type(ind) == type([]): + img_box_count = sum(ind) / len(ind) + if self.debug: + for ind_i in ind: + p.selected[ind_i] = label + else: + img_box_count = ind + select_size_count = p[ind].proposal_boxes.area() / ( + p.image_size[0] * p.image_size[1] + ) + max_score_count = score[ind, label].sigmoid() + select_x_count = ( + (p.proposal_boxes.tensor[ind, 0] + p.proposal_boxes.tensor[ind, 2]) + / 2 + / p.image_size[1] + ) + select_y_count = ( + (p.proposal_boxes.tensor[ind, 1] + p.proposal_boxes.tensor[ind, 3]) + / 2 + / p.image_size[0] + ) + if self.debug: + p.selected[ind] = label + + loss = loss / B + storage.put_scalar("stats_l_image", loss.item()) + if "caption" in ann_type: + caption_loss = caption_loss / B + loss = loss + caption_loss + storage.put_scalar("stats_l_caption", caption_loss.item()) + if comm.is_main_process(): + storage.put_scalar("pool_stats", img_box_count) + storage.put_scalar("stats_select_size", select_size_count) + storage.put_scalar("stats_select_x", select_x_count) + storage.put_scalar("stats_select_y", select_y_count) + storage.put_scalar("stats_max_label_score", max_score_count) + + return { + "image_loss": loss * self.image_loss_weight, + "loss_cls": score.new_zeros([1])[0], + "loss_box_reg": score.new_zeros([1])[0], + } + + def forward(self, x, classifier_info=(None, None, None)): + """ + enable classifier_info + """ + if x.dim() > 2: + x = torch.flatten(x, start_dim=1) + scores = [] + + if classifier_info[0] is not None: + cls_scores = self.cls_score(x, classifier=classifier_info[0]) + scores.append(cls_scores) + else: + cls_scores = self.cls_score(x) + scores.append(cls_scores) + + if classifier_info[2] is not None: + cap_cls = classifier_info[2] + if self.sync_caption_batch: + caption_scores = self.cls_score(x, classifier=cap_cls[:, :-1]) + else: + caption_scores = self.cls_score(x, classifier=cap_cls) + scores.append(caption_scores) + scores = torch.cat(scores, dim=1) # B x C' or B x N or B x (C'+N) + + proposal_deltas = self.bbox_pred(x) + if self.with_softmax_prop: + prop_score = self.prop_score(x) + return scores, proposal_deltas, prop_score + else: + return scores, proposal_deltas + + def _caption_loss(self, score, classifier_info, idx: int, B): + assert classifier_info[2] is not None + assert self.add_image_box + cls_and_cap_num = score.shape[1] + cap_num = classifier_info[2].shape[0] + score, caption_score = score.split([cls_and_cap_num - cap_num, cap_num], dim=1) + # n x (C + 1), n x B + caption_score = caption_score[-1:] # 1 x B # -1: image level box + caption_target = caption_score.new_zeros( + caption_score.shape + ) # 1 x B or 1 x MB, M: num machines + if self.sync_caption_batch: + # caption_target: 1 x MB + rank = comm.get_rank() + global_idx = B * rank + idx + assert (classifier_info[2][global_idx, -1] - rank) ** 2 < 1e-8, f"{rank} {global_idx} {classifier_info[2][global_idx, -1]} {classifier_info[2].shape} {classifier_info[2][:, -1]}" + caption_target[:, global_idx] = 1.0 + else: + assert caption_score.shape[1] == B + caption_target[:, idx] = 1.0 + caption_loss_img = F.binary_cross_entropy_with_logits( + caption_score, caption_target, reduction="none" + ) + if self.sync_caption_batch: + fg_mask = (caption_target > 0.5).float() + assert (fg_mask.sum().item() - 1.0) ** 2 < 1e-8, f"{fg_mask.shape} {fg_mask}" + pos_loss = (caption_loss_img * fg_mask).sum() + neg_loss = (caption_loss_img * (1.0 - fg_mask)).sum() + caption_loss_img = pos_loss + self.neg_cap_weight * neg_loss + else: + caption_loss_img = caption_loss_img.sum() + return score, caption_loss_img + + def _wsddn_loss(self, score, prop_score, label: str): + assert prop_score is not None + loss = 0 + final_score = score.sigmoid() * F.softmax(prop_score, dim=0) # B x (C + 1) + img_score = torch.clamp(torch.sum(final_score, dim=0), min=1e-10, max=1 - 1e-10) # (C + 1) + target = img_score.new_zeros(img_score.shape) # (C + 1) + target[label] = 1.0 + loss += F.binary_cross_entropy(img_score, target) + ind = final_score[:, label].argmax() + return loss, ind + + def _max_score_loss(self, score, label: str): + loss = 0 + target = score.new_zeros(score.shape[1]) + target[label] = 1.0 + ind = score[:, label].argmax().item() + loss += F.binary_cross_entropy_with_logits(score[ind], target, reduction="sum") + return loss, ind + + def _min_loss_loss(self, score, label: str): + loss = 0 + target = score.new_zeros(score.shape) + target[:, label] = 1.0 + with torch.no_grad(): + x = F.binary_cross_entropy_with_logits(score, target, reduction="none").sum(dim=1) # n + ind = x.argmin().item() + loss += F.binary_cross_entropy_with_logits(score[ind], target[0], reduction="sum") + return loss, ind + + def _first_loss(self, score, label: str): + loss = 0 + target = score.new_zeros(score.shape[1]) + target[label] = 1.0 + ind = 0 + loss += F.binary_cross_entropy_with_logits(score[ind], target, reduction="sum") + return loss, ind + + def _image_loss(self, score, label: str): + assert self.add_image_box + target = score.new_zeros(score.shape[1]) + target[label] = 1.0 + ind = score.shape[0] - 1 + loss = F.binary_cross_entropy_with_logits(score[ind], target, reduction="sum") + return loss, ind + + def _max_size_loss(self, score, label: str, p): + loss = 0 + target = score.new_zeros(score.shape[1]) + target[label] = 1.0 + sizes = p.proposal_boxes.area() + ind = sizes[:-1].argmax().item() if len(sizes) > 1 else 0 + if self.softmax_weak_loss: + loss += F.cross_entropy( + score[ind : ind + 1], + score.new_tensor(label, dtype=torch.long).view(1), + reduction="sum", + ) + else: + loss += F.binary_cross_entropy_with_logits(score[ind], target, reduction="sum") + return loss, ind + + +def put_label_distribution(storage, hist_name: str, hist_counts, num_classes: int) -> None: + """ """ + ht_min, ht_max = 0, num_classes + hist_edges = torch.linspace( + start=ht_min, end=ht_max, steps=num_classes + 1, dtype=torch.float32 + ) + + hist_params = dict( + tag=hist_name, + min=ht_min, + max=ht_max, + num=float(hist_counts.sum()), + sum=float((hist_counts * torch.arange(len(hist_counts))).sum()), + sum_squares=float(((hist_counts * torch.arange(len(hist_counts))) ** 2).sum()), + bucket_limits=hist_edges[1:].tolist(), + bucket_counts=hist_counts.tolist(), + global_step=storage._iter, + ) + storage._histograms.append(hist_params) diff --git a/dimos/models/Detic/detic/modeling/roi_heads/detic_roi_heads.py b/dimos/models/Detic/detic/modeling/roi_heads/detic_roi_heads.py new file mode 100644 index 0000000000..a8f1f4efe2 --- /dev/null +++ b/dimos/models/Detic/detic/modeling/roi_heads/detic_roi_heads.py @@ -0,0 +1,258 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from detectron2.config import configurable +from detectron2.modeling.box_regression import Box2BoxTransform +from detectron2.modeling.roi_heads.cascade_rcnn import CascadeROIHeads, _ScaleGradient +from detectron2.modeling.roi_heads.fast_rcnn import fast_rcnn_inference +from detectron2.modeling.roi_heads.roi_heads import ROI_HEADS_REGISTRY +from detectron2.structures import Boxes, Instances +from detectron2.utils.events import get_event_storage +import torch + +from .detic_fast_rcnn import DeticFastRCNNOutputLayers +from typing import Sequence + + +@ROI_HEADS_REGISTRY.register() +class DeticCascadeROIHeads(CascadeROIHeads): + @configurable + def __init__( + self, + *, + mult_proposal_score: bool = False, + with_image_labels: bool = False, + add_image_box: bool = False, + image_box_size: float = 1.0, + ws_num_props: int = 512, + add_feature_to_prop: bool = False, + mask_weight: float = 1.0, + one_class_per_proposal: bool = False, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.mult_proposal_score = mult_proposal_score + self.with_image_labels = with_image_labels + self.add_image_box = add_image_box + self.image_box_size = image_box_size + self.ws_num_props = ws_num_props + self.add_feature_to_prop = add_feature_to_prop + self.mask_weight = mask_weight + self.one_class_per_proposal = one_class_per_proposal + + @classmethod + def from_config(cls, cfg, input_shape): + ret = super().from_config(cfg, input_shape) + ret.update( + { + "mult_proposal_score": cfg.MODEL.ROI_BOX_HEAD.MULT_PROPOSAL_SCORE, + "with_image_labels": cfg.WITH_IMAGE_LABELS, + "add_image_box": cfg.MODEL.ROI_BOX_HEAD.ADD_IMAGE_BOX, + "image_box_size": cfg.MODEL.ROI_BOX_HEAD.IMAGE_BOX_SIZE, + "ws_num_props": cfg.MODEL.ROI_BOX_HEAD.WS_NUM_PROPS, + "add_feature_to_prop": cfg.MODEL.ROI_BOX_HEAD.ADD_FEATURE_TO_PROP, + "mask_weight": cfg.MODEL.ROI_HEADS.MASK_WEIGHT, + "one_class_per_proposal": cfg.MODEL.ROI_HEADS.ONE_CLASS_PER_PROPOSAL, + } + ) + return ret + + @classmethod + def _init_box_head(cls, cfg, input_shape): + ret = super()._init_box_head(cfg, input_shape) + del ret["box_predictors"] + cascade_bbox_reg_weights = cfg.MODEL.ROI_BOX_CASCADE_HEAD.BBOX_REG_WEIGHTS + box_predictors = [] + for box_head, bbox_reg_weights in zip(ret["box_heads"], cascade_bbox_reg_weights, strict=False): + box_predictors.append( + DeticFastRCNNOutputLayers( + cfg, + box_head.output_shape, + box2box_transform=Box2BoxTransform(weights=bbox_reg_weights), + ) + ) + ret["box_predictors"] = box_predictors + return ret + + def _forward_box( + self, features, proposals, targets=None, ann_type: str="box", classifier_info=(None, None, None) + ): + """ + Add mult proposal scores at testing + Add ann_type + """ + if (not self.training) and self.mult_proposal_score: + if len(proposals) > 0 and proposals[0].has("scores"): + proposal_scores = [p.get("scores") for p in proposals] + else: + proposal_scores = [p.get("objectness_logits") for p in proposals] + + features = [features[f] for f in self.box_in_features] + head_outputs = [] # (predictor, predictions, proposals) + prev_pred_boxes = None + image_sizes = [x.image_size for x in proposals] + + for k in range(self.num_cascade_stages): + if k > 0: + proposals = self._create_proposals_from_boxes( + prev_pred_boxes, image_sizes, logits=[p.objectness_logits for p in proposals] + ) + if self.training and ann_type in ["box"]: + proposals = self._match_and_label_boxes(proposals, k, targets) + predictions = self._run_stage(features, proposals, k, classifier_info=classifier_info) + prev_pred_boxes = self.box_predictor[k].predict_boxes( + (predictions[0], predictions[1]), proposals + ) + head_outputs.append((self.box_predictor[k], predictions, proposals)) + + if self.training: + losses = {} + storage = get_event_storage() + for stage, (predictor, predictions, proposals) in enumerate(head_outputs): + with storage.name_scope(f"stage{stage}"): + if ann_type != "box": + stage_losses = {} + if ann_type in ["image", "caption", "captiontag"]: + image_labels = [x._pos_category_ids for x in targets] + weak_losses = predictor.image_label_losses( + predictions, + proposals, + image_labels, + classifier_info=classifier_info, + ann_type=ann_type, + ) + stage_losses.update(weak_losses) + else: # supervised + stage_losses = predictor.losses( + (predictions[0], predictions[1]), + proposals, + classifier_info=classifier_info, + ) + if self.with_image_labels: + stage_losses["image_loss"] = predictions[0].new_zeros([1])[0] + losses.update({k + f"_stage{stage}": v for k, v in stage_losses.items()}) + return losses + else: + # Each is a list[Tensor] of length #image. Each tensor is Ri x (K+1) + scores_per_stage = [h[0].predict_probs(h[1], h[2]) for h in head_outputs] + scores = [ + sum(list(scores_per_image)) * (1.0 / self.num_cascade_stages) + for scores_per_image in zip(*scores_per_stage, strict=False) + ] + if self.mult_proposal_score: + scores = [(s * ps[:, None]) ** 0.5 for s, ps in zip(scores, proposal_scores, strict=False)] + if self.one_class_per_proposal: + scores = [s * (s == s[:, :-1].max(dim=1)[0][:, None]).float() for s in scores] + predictor, predictions, proposals = head_outputs[-1] + boxes = predictor.predict_boxes((predictions[0], predictions[1]), proposals) + pred_instances, _ = fast_rcnn_inference( + boxes, + scores, + image_sizes, + predictor.test_score_thresh, + predictor.test_nms_thresh, + predictor.test_topk_per_image, + ) + return pred_instances + + def forward( + self, + images, + features, + proposals, + targets=None, + ann_type: str="box", + classifier_info=(None, None, None), + ): + """ + enable debug and image labels + classifier_info is shared across the batch + """ + if self.training: + if ann_type in ["box", "prop", "proptag"]: + proposals = self.label_and_sample_proposals(proposals, targets) + else: + proposals = self.get_top_proposals(proposals) + + losses = self._forward_box( + features, proposals, targets, ann_type=ann_type, classifier_info=classifier_info + ) + if ann_type == "box" and targets[0].has("gt_masks"): + mask_losses = self._forward_mask(features, proposals) + losses.update({k: v * self.mask_weight for k, v in mask_losses.items()}) + losses.update(self._forward_keypoint(features, proposals)) + else: + losses.update( + self._get_empty_mask_loss( + features, proposals, device=proposals[0].objectness_logits.device + ) + ) + return proposals, losses + else: + pred_instances = self._forward_box(features, proposals, classifier_info=classifier_info) + pred_instances = self.forward_with_given_boxes(features, pred_instances) + return pred_instances, {} + + def get_top_proposals(self, proposals): + for i in range(len(proposals)): + proposals[i].proposal_boxes.clip(proposals[i].image_size) + proposals = [p[: self.ws_num_props] for p in proposals] + for i, p in enumerate(proposals): + p.proposal_boxes.tensor = p.proposal_boxes.tensor.detach() + if self.add_image_box: + proposals[i] = self._add_image_box(p) + return proposals + + def _add_image_box(self, p): + image_box = Instances(p.image_size) + n = 1 + h, w = p.image_size + f = self.image_box_size + image_box.proposal_boxes = Boxes( + p.proposal_boxes.tensor.new_tensor( + [ + w * (1.0 - f) / 2.0, + h * (1.0 - f) / 2.0, + w * (1.0 - (1.0 - f) / 2.0), + h * (1.0 - (1.0 - f) / 2.0), + ] + ).view(n, 4) + ) + image_box.objectness_logits = p.objectness_logits.new_ones(n) + return Instances.cat([p, image_box]) + + def _get_empty_mask_loss(self, features, proposals, device): + if self.mask_on: + return {"loss_mask": torch.zeros((1,), device=device, dtype=torch.float32)[0]} + else: + return {} + + def _create_proposals_from_boxes(self, boxes, image_sizes: Sequence[int], logits): + """ + Add objectness_logits + """ + boxes = [Boxes(b.detach()) for b in boxes] + proposals = [] + for boxes_per_image, image_size, logit in zip(boxes, image_sizes, logits, strict=False): + boxes_per_image.clip(image_size) + if self.training: + inds = boxes_per_image.nonempty() + boxes_per_image = boxes_per_image[inds] + logit = logit[inds] + prop = Instances(image_size) + prop.proposal_boxes = boxes_per_image + prop.objectness_logits = logit + proposals.append(prop) + return proposals + + def _run_stage(self, features, proposals, stage, classifier_info=(None, None, None)): + """ + Support classifier_info and add_feature_to_prop + """ + pool_boxes = [x.proposal_boxes for x in proposals] + box_features = self.box_pooler(features, pool_boxes) + box_features = _ScaleGradient.apply(box_features, 1.0 / self.num_cascade_stages) + box_features = self.box_head[stage](box_features) + if self.add_feature_to_prop: + feats_per_image = box_features.split([len(p) for p in proposals], dim=0) + for feat, p in zip(feats_per_image, proposals, strict=False): + p.feat = feat + return self.box_predictor[stage](box_features, classifier_info=classifier_info) diff --git a/dimos/models/Detic/detic/modeling/roi_heads/res5_roi_heads.py b/dimos/models/Detic/detic/modeling/roi_heads/res5_roi_heads.py new file mode 100644 index 0000000000..642f889b5d --- /dev/null +++ b/dimos/models/Detic/detic/modeling/roi_heads/res5_roi_heads.py @@ -0,0 +1,175 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from detectron2.config import configurable +from detectron2.layers import ShapeSpec +from detectron2.modeling.roi_heads.roi_heads import ROI_HEADS_REGISTRY, Res5ROIHeads +from detectron2.structures import Boxes, Instances +import torch + +from ..debug import debug_second_stage +from .detic_fast_rcnn import DeticFastRCNNOutputLayers + + +@ROI_HEADS_REGISTRY.register() +class CustomRes5ROIHeads(Res5ROIHeads): + @configurable + def __init__(self, **kwargs) -> None: + cfg = kwargs.pop("cfg") + super().__init__(**kwargs) + stage_channel_factor = 2**3 + out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS * stage_channel_factor + + self.with_image_labels = cfg.WITH_IMAGE_LABELS + self.ws_num_props = cfg.MODEL.ROI_BOX_HEAD.WS_NUM_PROPS + self.add_image_box = cfg.MODEL.ROI_BOX_HEAD.ADD_IMAGE_BOX + self.add_feature_to_prop = cfg.MODEL.ROI_BOX_HEAD.ADD_FEATURE_TO_PROP + self.image_box_size = cfg.MODEL.ROI_BOX_HEAD.IMAGE_BOX_SIZE + self.box_predictor = DeticFastRCNNOutputLayers( + cfg, ShapeSpec(channels=out_channels, height=1, width=1) + ) + + self.save_debug = cfg.SAVE_DEBUG + self.save_debug_path = cfg.SAVE_DEBUG_PATH + if self.save_debug: + self.debug_show_name = cfg.DEBUG_SHOW_NAME + self.vis_thresh = cfg.VIS_THRESH + self.pixel_mean = ( + torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(torch.device(cfg.MODEL.DEVICE)).view(3, 1, 1) + ) + self.pixel_std = ( + torch.Tensor(cfg.MODEL.PIXEL_STD).to(torch.device(cfg.MODEL.DEVICE)).view(3, 1, 1) + ) + self.bgr = cfg.INPUT.FORMAT == "BGR" + + @classmethod + def from_config(cls, cfg, input_shape): + ret = super().from_config(cfg, input_shape) + ret["cfg"] = cfg + return ret + + def forward( + self, + images, + features, + proposals, + targets=None, + ann_type: str="box", + classifier_info=(None, None, None), + ): + """ + enable debug and image labels + classifier_info is shared across the batch + """ + if not self.save_debug: + del images + + if self.training: + if ann_type in ["box"]: + proposals = self.label_and_sample_proposals(proposals, targets) + else: + proposals = self.get_top_proposals(proposals) + + proposal_boxes = [x.proposal_boxes for x in proposals] + box_features = self._shared_roi_transform( + [features[f] for f in self.in_features], proposal_boxes + ) + predictions = self.box_predictor( + box_features.mean(dim=[2, 3]), classifier_info=classifier_info + ) + + if self.add_feature_to_prop: + feats_per_image = box_features.mean(dim=[2, 3]).split( + [len(p) for p in proposals], dim=0 + ) + for feat, p in zip(feats_per_image, proposals, strict=False): + p.feat = feat + + if self.training: + del features + if ann_type != "box": + image_labels = [x._pos_category_ids for x in targets] + losses = self.box_predictor.image_label_losses( + predictions, + proposals, + image_labels, + classifier_info=classifier_info, + ann_type=ann_type, + ) + else: + losses = self.box_predictor.losses((predictions[0], predictions[1]), proposals) + if self.with_image_labels: + assert "image_loss" not in losses + losses["image_loss"] = predictions[0].new_zeros([1])[0] + if self.save_debug: + def denormalizer(x): + return x * self.pixel_std + self.pixel_mean + if ann_type != "box": + image_labels = [x._pos_category_ids for x in targets] + else: + image_labels = [[] for x in targets] + debug_second_stage( + [denormalizer(x.clone()) for x in images], + targets, + proposals=proposals, + save_debug=self.save_debug, + debug_show_name=self.debug_show_name, + vis_thresh=self.vis_thresh, + image_labels=image_labels, + save_debug_path=self.save_debug_path, + bgr=self.bgr, + ) + return proposals, losses + else: + pred_instances, _ = self.box_predictor.inference(predictions, proposals) + pred_instances = self.forward_with_given_boxes(features, pred_instances) + if self.save_debug: + def denormalizer(x): + return x * self.pixel_std + self.pixel_mean + debug_second_stage( + [denormalizer(x.clone()) for x in images], + pred_instances, + proposals=proposals, + save_debug=self.save_debug, + debug_show_name=self.debug_show_name, + vis_thresh=self.vis_thresh, + save_debug_path=self.save_debug_path, + bgr=self.bgr, + ) + return pred_instances, {} + + def get_top_proposals(self, proposals): + for i in range(len(proposals)): + proposals[i].proposal_boxes.clip(proposals[i].image_size) + proposals = [p[: self.ws_num_props] for p in proposals] + for i, p in enumerate(proposals): + p.proposal_boxes.tensor = p.proposal_boxes.tensor.detach() + if self.add_image_box: + proposals[i] = self._add_image_box(p) + return proposals + + def _add_image_box(self, p, use_score: bool=False): + image_box = Instances(p.image_size) + n = 1 + h, w = p.image_size + if self.image_box_size < 1.0: + f = self.image_box_size + image_box.proposal_boxes = Boxes( + p.proposal_boxes.tensor.new_tensor( + [ + w * (1.0 - f) / 2.0, + h * (1.0 - f) / 2.0, + w * (1.0 - (1.0 - f) / 2.0), + h * (1.0 - (1.0 - f) / 2.0), + ] + ).view(n, 4) + ) + else: + image_box.proposal_boxes = Boxes( + p.proposal_boxes.tensor.new_tensor([0, 0, w, h]).view(n, 4) + ) + if use_score: + image_box.scores = p.objectness_logits.new_ones(n) + image_box.pred_classes = p.objectness_logits.new_zeros(n, dtype=torch.long) + image_box.objectness_logits = p.objectness_logits.new_ones(n) + else: + image_box.objectness_logits = p.objectness_logits.new_ones(n) + return Instances.cat([p, image_box]) diff --git a/dimos/models/Detic/detic/modeling/roi_heads/zero_shot_classifier.py b/dimos/models/Detic/detic/modeling/roi_heads/zero_shot_classifier.py new file mode 100644 index 0000000000..d436e6be34 --- /dev/null +++ b/dimos/models/Detic/detic/modeling/roi_heads/zero_shot_classifier.py @@ -0,0 +1,88 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from detectron2.config import configurable +from detectron2.layers import ShapeSpec +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + + +class ZeroShotClassifier(nn.Module): + @configurable + def __init__( + self, + input_shape: ShapeSpec, + *, + num_classes: int, + zs_weight_path: str, + zs_weight_dim: int = 512, + use_bias: float = 0.0, + norm_weight: bool = True, + norm_temperature: float = 50.0, + ) -> None: + super().__init__() + if isinstance(input_shape, int): # some backward compatibility + input_shape = ShapeSpec(channels=input_shape) + input_size = input_shape.channels * (input_shape.width or 1) * (input_shape.height or 1) + self.norm_weight = norm_weight + self.norm_temperature = norm_temperature + + self.use_bias = use_bias < 0 + if self.use_bias: + self.cls_bias = nn.Parameter(torch.ones(1) * use_bias) + + self.linear = nn.Linear(input_size, zs_weight_dim) + + if zs_weight_path == "rand": + zs_weight = torch.randn((zs_weight_dim, num_classes)) + nn.init.normal_(zs_weight, std=0.01) + else: + zs_weight = ( + torch.tensor(np.load(zs_weight_path), dtype=torch.float32) + .permute(1, 0) + .contiguous() + ) # D x C + zs_weight = torch.cat( + [zs_weight, zs_weight.new_zeros((zs_weight_dim, 1))], dim=1 + ) # D x (C + 1) + + if self.norm_weight: + zs_weight = F.normalize(zs_weight, p=2, dim=0) + + if zs_weight_path == "rand": + self.zs_weight = nn.Parameter(zs_weight) + else: + self.register_buffer("zs_weight", zs_weight) + + assert self.zs_weight.shape[1] == num_classes + 1, self.zs_weight.shape + + @classmethod + def from_config(cls, cfg, input_shape): + return { + "input_shape": input_shape, + "num_classes": cfg.MODEL.ROI_HEADS.NUM_CLASSES, + "zs_weight_path": cfg.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_PATH, + "zs_weight_dim": cfg.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_DIM, + "use_bias": cfg.MODEL.ROI_BOX_HEAD.USE_BIAS, + "norm_weight": cfg.MODEL.ROI_BOX_HEAD.NORM_WEIGHT, + "norm_temperature": cfg.MODEL.ROI_BOX_HEAD.NORM_TEMP, + } + + def forward(self, x, classifier=None): + """ + Inputs: + x: B x D' + classifier_info: (C', C' x D) + """ + x = self.linear(x) + if classifier is not None: + zs_weight = classifier.permute(1, 0).contiguous() # D x C' + zs_weight = F.normalize(zs_weight, p=2, dim=0) if self.norm_weight else zs_weight + else: + zs_weight = self.zs_weight + if self.norm_weight: + x = self.norm_temperature * F.normalize(x, p=2, dim=1) + x = torch.mm(x, zs_weight) + if self.use_bias: + x = x + self.cls_bias + return x diff --git a/dimos/models/Detic/detic/modeling/text/text_encoder.py b/dimos/models/Detic/detic/modeling/text/text_encoder.py new file mode 100644 index 0000000000..7c9b15bdf5 --- /dev/null +++ b/dimos/models/Detic/detic/modeling/text/text_encoder.py @@ -0,0 +1,198 @@ +# This code is modified from https://github.com/openai/CLIP/blob/main/clip/clip.py +# Modified by Xingyi Zhou +# The original code is under MIT license +# Copyright (c) Facebook, Inc. and its affiliates. +from collections import OrderedDict +from typing import List, Union + +from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer +import torch +from torch import nn + +__all__ = ["tokenize"] + +count = 0 + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None) -> None: + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential( + OrderedDict( + [ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)), + ] + ) + ) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + self.attn_mask = ( + self.attn_mask.to(dtype=x.dtype, device=x.device) + if self.attn_mask is not None + else None + ) + return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None) -> None: + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential( + *[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)] + ) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + + +class CLIPTEXT(nn.Module): + def __init__( + self, + embed_dim: int=512, + # text + context_length: int=77, + vocab_size: int=49408, + transformer_width: int=512, + transformer_heads: int=8, + transformer_layers: int=12, + ) -> None: + super().__init__() + + self._tokenizer = _Tokenizer() + self.context_length = context_length + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask(), + ) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter( + torch.empty(self.context_length, transformer_width) + ) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) + # self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.initialize_parameters() + + def initialize_parameters(self) -> None: + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + proj_std = (self.transformer.width**-0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width**-0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width**-0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def device(self): + return self.text_projection.device + + @property + def dtype(self): + return self.text_projection.dtype + + def tokenize(self, texts: Union[str, list[str]], context_length: int = 77) -> torch.LongTensor: + """ """ + if isinstance(texts, str): + texts = [texts] + + sot_token = self._tokenizer.encoder["<|startoftext|>"] + eot_token = self._tokenizer.encoder["<|endoftext|>"] + all_tokens = [[sot_token, *self._tokenizer.encode(text), eot_token] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + st = torch.randint(len(tokens) - context_length + 1, (1,))[0].item() + tokens = tokens[st : st + context_length] + # raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") + result[i, : len(tokens)] = torch.tensor(tokens) + + return result + + def encode_text(self, text: str): + x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] + x = x + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + return x + + def forward(self, captions): + """ + captions: list of strings + """ + text = self.tokenize(captions).to(self.device) # B x L x D + features = self.encode_text(text) # B x D + return features + + +def build_text_encoder(pretrain: bool=True): + text_encoder = CLIPTEXT() + if pretrain: + import clip + + pretrained_model, _ = clip.load("ViT-B/32", device="cpu") + state_dict = pretrained_model.state_dict() + to_delete_keys = ["logit_scale", "input_resolution", "context_length", "vocab_size"] + [ + k for k in state_dict.keys() if k.startswith("visual.") + ] + for k in to_delete_keys: + if k in state_dict: + del state_dict[k] + print("Loading pretrained CLIP") + text_encoder.load_state_dict(state_dict) + # import pdb; pdb.set_trace() + return text_encoder diff --git a/dimos/models/Detic/detic/modeling/utils.py b/dimos/models/Detic/detic/modeling/utils.py new file mode 100644 index 0000000000..f24a0699a1 --- /dev/null +++ b/dimos/models/Detic/detic/modeling/utils.py @@ -0,0 +1,46 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import json + +import numpy as np +import torch +from torch.nn import functional as F + + +def load_class_freq(path: str="datasets/metadata/lvis_v1_train_cat_info.json", freq_weight: float=1.0): + cat_info = json.load(open(path)) + cat_info = torch.tensor([c["image_count"] for c in sorted(cat_info, key=lambda x: x["id"])]) + freq_weight = cat_info.float() ** freq_weight + return freq_weight + + +def get_fed_loss_inds(gt_classes, num_sample_cats: int, C, weight=None): + appeared = torch.unique(gt_classes) # C' + prob = appeared.new_ones(C + 1).float() + prob[-1] = 0 + if len(appeared) < num_sample_cats: + if weight is not None: + prob[:C] = weight.float().clone() + prob[appeared] = 0 + more_appeared = torch.multinomial(prob, num_sample_cats - len(appeared), replacement=False) + appeared = torch.cat([appeared, more_appeared]) + return appeared + + +def reset_cls_test(model, cls_path, num_classes: int) -> None: + model.roi_heads.num_classes = num_classes + if type(cls_path) == str: + print("Resetting zs_weight", cls_path) + zs_weight = ( + torch.tensor(np.load(cls_path), dtype=torch.float32).permute(1, 0).contiguous() + ) # D x C + else: + zs_weight = cls_path + zs_weight = torch.cat( + [zs_weight, zs_weight.new_zeros((zs_weight.shape[0], 1))], dim=1 + ) # D x (C + 1) + if model.roi_heads.box_predictor[0].cls_score.norm_weight: + zs_weight = F.normalize(zs_weight, p=2, dim=0) + zs_weight = zs_weight.to(model.device) + for k in range(len(model.roi_heads.box_predictor)): + del model.roi_heads.box_predictor[k].cls_score.zs_weight + model.roi_heads.box_predictor[k].cls_score.zs_weight = zs_weight diff --git a/dimos/models/Detic/detic/predictor.py b/dimos/models/Detic/detic/predictor.py new file mode 100644 index 0000000000..a85941e25a --- /dev/null +++ b/dimos/models/Detic/detic/predictor.py @@ -0,0 +1,254 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import atexit +import bisect +from collections import deque +import multiprocessing as mp + +import cv2 +from detectron2.data import MetadataCatalog +from detectron2.engine.defaults import DefaultPredictor +from detectron2.utils.video_visualizer import VideoVisualizer +from detectron2.utils.visualizer import ColorMode, Visualizer +import torch + +from .modeling.utils import reset_cls_test + + +def get_clip_embeddings(vocabulary, prompt: str="a "): + from detic.modeling.text.text_encoder import build_text_encoder + + text_encoder = build_text_encoder(pretrain=True) + text_encoder.eval() + texts = [prompt + x for x in vocabulary] + emb = text_encoder(texts).detach().permute(1, 0).contiguous().cpu() + return emb + + +BUILDIN_CLASSIFIER = { + "lvis": "datasets/metadata/lvis_v1_clip_a+cname.npy", + "objects365": "datasets/metadata/o365_clip_a+cnamefix.npy", + "openimages": "datasets/metadata/oid_clip_a+cname.npy", + "coco": "datasets/metadata/coco_clip_a+cname.npy", +} + +BUILDIN_METADATA_PATH = { + "lvis": "lvis_v1_val", + "objects365": "objects365_v2_val", + "openimages": "oid_val_expanded", + "coco": "coco_2017_val", +} + + +class VisualizationDemo: + def __init__(self, cfg, args, instance_mode=ColorMode.IMAGE, parallel: bool=False) -> None: + """ + Args: + cfg (CfgNode): + instance_mode (ColorMode): + parallel (bool): whether to run the model in different processes from visualization. + Useful since the visualization logic can be slow. + """ + if args.vocabulary == "custom": + self.metadata = MetadataCatalog.get("__unused") + self.metadata.thing_classes = args.custom_vocabulary.split(",") + classifier = get_clip_embeddings(self.metadata.thing_classes) + else: + self.metadata = MetadataCatalog.get(BUILDIN_METADATA_PATH[args.vocabulary]) + classifier = BUILDIN_CLASSIFIER[args.vocabulary] + + num_classes = len(self.metadata.thing_classes) + self.cpu_device = torch.device("cpu") + self.instance_mode = instance_mode + + self.parallel = parallel + if parallel: + num_gpu = torch.cuda.device_count() + self.predictor = AsyncPredictor(cfg, num_gpus=num_gpu) + else: + self.predictor = DefaultPredictor(cfg) + reset_cls_test(self.predictor.model, classifier, num_classes) + + def run_on_image(self, image): + """ + Args: + image (np.ndarray): an image of shape (H, W, C) (in BGR order). + This is the format used by OpenCV. + + Returns: + predictions (dict): the output of the model. + vis_output (VisImage): the visualized image output. + """ + vis_output = None + predictions = self.predictor(image) + # Convert image from OpenCV BGR format to Matplotlib RGB format. + image = image[:, :, ::-1] + visualizer = Visualizer(image, self.metadata, instance_mode=self.instance_mode) + if "panoptic_seg" in predictions: + panoptic_seg, segments_info = predictions["panoptic_seg"] + vis_output = visualizer.draw_panoptic_seg_predictions( + panoptic_seg.to(self.cpu_device), segments_info + ) + else: + if "sem_seg" in predictions: + vis_output = visualizer.draw_sem_seg( + predictions["sem_seg"].argmax(dim=0).to(self.cpu_device) + ) + if "instances" in predictions: + instances = predictions["instances"].to(self.cpu_device) + vis_output = visualizer.draw_instance_predictions(predictions=instances) + + return predictions, vis_output + + def _frame_from_video(self, video): + while video.isOpened(): + success, frame = video.read() + if success: + yield frame + else: + break + + def run_on_video(self, video): + """ + Visualizes predictions on frames of the input video. + + Args: + video (cv2.VideoCapture): a :class:`VideoCapture` object, whose source can be + either a webcam or a video file. + + Yields: + ndarray: BGR visualizations of each video frame. + """ + video_visualizer = VideoVisualizer(self.metadata, self.instance_mode) + + def process_predictions(frame, predictions): + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + if "panoptic_seg" in predictions: + panoptic_seg, segments_info = predictions["panoptic_seg"] + vis_frame = video_visualizer.draw_panoptic_seg_predictions( + frame, panoptic_seg.to(self.cpu_device), segments_info + ) + elif "instances" in predictions: + predictions = predictions["instances"].to(self.cpu_device) + vis_frame = video_visualizer.draw_instance_predictions(frame, predictions) + elif "sem_seg" in predictions: + vis_frame = video_visualizer.draw_sem_seg( + frame, predictions["sem_seg"].argmax(dim=0).to(self.cpu_device) + ) + + # Converts Matplotlib RGB format to OpenCV BGR format + vis_frame = cv2.cvtColor(vis_frame.get_image(), cv2.COLOR_RGB2BGR) + return vis_frame + + frame_gen = self._frame_from_video(video) + if self.parallel: + buffer_size = self.predictor.default_buffer_size + + frame_data = deque() + + for cnt, frame in enumerate(frame_gen): + frame_data.append(frame) + self.predictor.put(frame) + + if cnt >= buffer_size: + frame = frame_data.popleft() + predictions = self.predictor.get() + yield process_predictions(frame, predictions) + + while len(frame_data): + frame = frame_data.popleft() + predictions = self.predictor.get() + yield process_predictions(frame, predictions) + else: + for frame in frame_gen: + yield process_predictions(frame, self.predictor(frame)) + + +class AsyncPredictor: + """ + A predictor that runs the model asynchronously, possibly on >1 GPUs. + Because rendering the visualization takes considerably amount of time, + this helps improve throughput a little bit when rendering videos. + """ + + class _StopToken: + pass + + class _PredictWorker(mp.Process): + def __init__(self, cfg, task_queue, result_queue) -> None: + self.cfg = cfg + self.task_queue = task_queue + self.result_queue = result_queue + super().__init__() + + def run(self) -> None: + predictor = DefaultPredictor(self.cfg) + + while True: + task = self.task_queue.get() + if isinstance(task, AsyncPredictor._StopToken): + break + idx, data = task + result = predictor(data) + self.result_queue.put((idx, result)) + + def __init__(self, cfg, num_gpus: int = 1) -> None: + """ + Args: + cfg (CfgNode): + num_gpus (int): if 0, will run on CPU + """ + num_workers = max(num_gpus, 1) + self.task_queue = mp.Queue(maxsize=num_workers * 3) + self.result_queue = mp.Queue(maxsize=num_workers * 3) + self.procs = [] + for gpuid in range(max(num_gpus, 1)): + cfg = cfg.clone() + cfg.defrost() + cfg.MODEL.DEVICE = f"cuda:{gpuid}" if num_gpus > 0 else "cpu" + self.procs.append( + AsyncPredictor._PredictWorker(cfg, self.task_queue, self.result_queue) + ) + + self.put_idx = 0 + self.get_idx = 0 + self.result_rank = [] + self.result_data = [] + + for p in self.procs: + p.start() + atexit.register(self.shutdown) + + def put(self, image) -> None: + self.put_idx += 1 + self.task_queue.put((self.put_idx, image)) + + def get(self): + self.get_idx += 1 # the index needed for this request + if len(self.result_rank) and self.result_rank[0] == self.get_idx: + res = self.result_data[0] + del self.result_data[0], self.result_rank[0] + return res + + while True: + # make sure the results are returned in the correct order + idx, res = self.result_queue.get() + if idx == self.get_idx: + return res + insert = bisect.bisect(self.result_rank, idx) + self.result_rank.insert(insert, idx) + self.result_data.insert(insert, res) + + def __len__(self) -> int: + return self.put_idx - self.get_idx + + def __call__(self, image): + self.put(image) + return self.get() + + def shutdown(self) -> None: + for _ in self.procs: + self.task_queue.put(AsyncPredictor._StopToken()) + + @property + def default_buffer_size(self): + return len(self.procs) * 5 diff --git a/dimos/models/Detic/docs/INSTALL.md b/dimos/models/Detic/docs/INSTALL.md new file mode 100644 index 0000000000..1d5fbc4ae1 --- /dev/null +++ b/dimos/models/Detic/docs/INSTALL.md @@ -0,0 +1,33 @@ +# Installation + +### Requirements +- Linux or macOS with Python ≥ 3.6 +- PyTorch ≥ 1.8. + Install them together at [pytorch.org](https://pytorch.org) to make sure of this. Note, please check + PyTorch version matches that is required by Detectron2. +- Detectron2: follow [Detectron2 installation instructions](https://detectron2.readthedocs.io/tutorials/install.html). + + +### Example conda environment setup +```bash +conda create --name detic python=3.8 -y +conda activate detic +conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch-lts -c nvidia + +# under your working directory +git clone git@github.com:facebookresearch/detectron2.git +cd detectron2 +pip install -e . + +cd .. +git clone https://github.com/facebookresearch/Detic.git --recurse-submodules +cd Detic +pip install -r requirements.txt +``` + +Our project uses two submodules, [CenterNet2](https://github.com/xingyizhou/CenterNet2.git) and [Deformable-DETR](https://github.com/fundamentalvision/Deformable-DETR.git). If you forget to add `--recurse-submodules`, do `git submodule init` and then `git submodule update`. To train models with Deformable-DETR (optional), we need to compile it + +``` +cd third_party/Deformable-DETR/models/ops +./make.sh +``` \ No newline at end of file diff --git a/dimos/models/Detic/docs/MODEL_ZOO.md b/dimos/models/Detic/docs/MODEL_ZOO.md new file mode 100644 index 0000000000..fe7c795197 --- /dev/null +++ b/dimos/models/Detic/docs/MODEL_ZOO.md @@ -0,0 +1,143 @@ +# Detic model zoo + +## Introduction + +This file documents a collection of models reported in our paper. +The training time was measured on [Big Basin](https://engineering.fb.com/data-center-engineering/introducing-big-basin-our-next-generation-ai-hardware/) +servers with 8 NVIDIA V100 GPUs & NVLink. + +#### How to Read the Tables + +The "Name" column contains a link to the config file. +To train a model, run + +``` +python train_net.py --num-gpus 8 --config-file /path/to/config/name.yaml +``` + +To evaluate a model with a trained/ pretrained model, run + +``` +python train_net.py --num-gpus 8 --config-file /path/to/config/name.yaml --eval-only MODEL.WEIGHTS /path/to/weight.pth +``` + +#### Third-party ImageNet-21K Pretrained Models + +Our paper uses ImageNet-21K pretrained models that are not part of Detectron2 (ResNet-50-21K from [MIIL](https://github.com/Alibaba-MIIL/ImageNet21K) and SwinB-21K from [Swin-Transformer](https://github.com/microsoft/Swin-Transformer)). Before training, +please download the models and place them under `DETIC_ROOT/models/`, and following [this tool](../tools/convert-thirdparty-pretrained-model-to-d2.py) to convert the format. + + +## Open-vocabulary LVIS + +| Name |Training time | mask mAP | mask mAP_novel | Download | +|-----------------------|------------------|-----------|-----------------|----------| +|[Box-Supervised_C2_R50_640_4x](../configs/BoxSup-C2_Lbase_CLIP_R5021k_640b64_4x.yaml) | 17h | 30.2 | 16.4 | [model](https://dl.fbaipublicfiles.com/detic/BoxSup-C2_Lbase_CLIP_R5021k_640b64_4x.pth) | +|[Detic_C2_IN-L_R50_640_4x](../configs/Detic_LbaseI_CLIP_R5021k_640b64_4x_ft4x_max-size.yaml) | 22h | 32.4 | 24.9 | [model](https://dl.fbaipublicfiles.com/detic/Detic_LbaseI_CLIP_R5021k_640b64_4x_ft4x_max-size.pth) | +|[Detic_C2_CCimg_R50_640_4x](../configs/Detic_LbaseCCimg_CLIP_R5021k_640b64_4x_ft4x_max-size.yaml) | 22h | 31.0 | 19.8 | [model](https://dl.fbaipublicfiles.com/detic/Detic_LbaseCCimg_CLIP_R5021k_640b64_4x_ft4x_max-size.pth) | +|[Detic_C2_CCcapimg_R50_640_4x](../configs/Detic_LbaseCCcapimg_CLIP_R5021k_640b64_4x_ft4x_max-size.yaml) | 22h | 31.0 | 21.3 | [model](https://dl.fbaipublicfiles.com/detic/Detic_LbaseCCcapimg_CLIP_R5021k_640b64_4x_ft4x_max-size.pth) | +|[Box-Supervised_C2_SwinB_896_4x](../configs/BoxSup-C2_Lbase_CLIP_SwinB_896b32_4x.yaml) | 43h | 38.4 | 21.9 | [model](https://dl.fbaipublicfiles.com/detic/BoxSup-C2_Lbase_CLIP_SwinB_896b32_4x.pth) | +|[Detic_C2_IN-L_SwinB_896_4x](../configs/Detic_LbaseI_CLIP_SwinB_896b32_4x_ft4x_max-size.yaml) | 47h | 40.7 | 33.8 | [model](https://dl.fbaipublicfiles.com/detic/Detic_LbaseI_CLIP_SwinB_896b32_4x_ft4x_max-size.pth) | + + +#### Note + +- The open-vocabulary LVIS setup is LVIS without rare class annotations in training. We evaluate rare classes as novel classes in testing. + +- The models with `C2` are trained using our improved LVIS baseline (Appendix D of the paper), including CenterNet2 detector, Federated Loss, large-scale jittering, etc. + +- All models use [CLIP](https://github.com/openai/CLIP) embeddings as classifiers. This makes the box-supervised models have non-zero mAP on novel classes. + +- The models with `IN-L` use the overlap classes between ImageNet-21K and LVIS as image-labeled data. + +- The models with `CC` use Conception Captions. `CCimg` uses image labels extracted from the captions (using a naive text-match) as image-labeled data. `CCcapimg` additionally uses the row captions (Appendix C of the paper). + +- The Detic models are finetuned on the corresponding Box-Supervised models above (indicated by MODEL.WEIGHTS in the config files). Please train or download the Box-Supervised model and place them under `DETIC_ROOT/models/` before training the Detic models. + + +## Standard LVIS + +| Name |Training time | mask mAP | mask mAP_rare | Download | +|-----------------------|------------------|-----------|-----------------|----------| +|[Box-Supervised_C2_R50_640_4x](../configs/BoxSup-C2_L_CLIP_R5021k_640b64_4x.yaml) | 17h | 31.5 | 25.6 | [model](https://dl.fbaipublicfiles.com/detic/BoxSup-C2_L_CLIP_R5021k_640b64_4x.pth) | +|[Detic_C2_R50_640_4x](../configs/Detic_LI_CLIP_R5021k_640b64_4x_ft4x_max-size.yaml) | 22h | 33.2 | 29.7 | [model](https://dl.fbaipublicfiles.com/detic/Detic_LI_CLIP_R5021k_640b64_4x_ft4x_max-size.pth) | +|[Box-Supervised_C2_SwinB_896_4x](../configs/BoxSup-C2_L_CLIP_SwinB_896b32_4x.yaml) | 43h | 40.7 | 35.9 | [model](https://dl.fbaipublicfiles.com/detic/BoxSup-C2_L_CLIP_SwinB_896b32_4x.pth) | +|[Detic_C2_SwinB_896_4x](../configs/Detic_LI_CLIP_SwinB_896b32_4x_ft4x_max-size.yaml) | 47h | 41.7 | 41.7 | [model](https://dl.fbaipublicfiles.com/detic/Detic_LI_CLIP_SwinB_896b32_4x_ft4x_max-size.pth) | + + +| Name |Training time | box mAP | box mAP_rare | Download | +|-----------------------|------------------|-----------|-----------------|----------| +|[Box-Supervised_DeformDETR_R50_4x](../configs/BoxSup-DeformDETR_L_R50_4x.yaml) | 31h | 31.7 | 21.4 | [model](https://dl.fbaipublicfiles.com/detic/BoxSup-DeformDETR_L_R50_4x.pth) | +|[Detic_DeformDETR_R50_4x](../configs/Detic_DeformDETR_LI_R50_4x_ft4x.yaml) | 47h | 32.5 | 26.2 | [model](https://dl.fbaipublicfiles.com/detic/Detic_DeformDETR_LI_R50_4x_ft4x.pth) | + + +#### Note + +- All Detic models use the overlap classes between ImageNet-21K and LVIS as image-labeled data; + +- The models with `C2` are trained using our improved LVIS baseline in the paper, including CenterNet2 detector, Federated loss, large-scale jittering, etc. + +- The models with `DeformDETR` are Deformable DETR models. We train the models with Federated Loss. + +## Open-vocabulary COCO + +| Name |Training time | box mAP50 | box mAP50_novel | Download | +|-----------------------|------------------|-----------|-----------------|----------| +|[BoxSup_CLIP_R50_1x](../configs/BoxSup_OVCOCO_CLIP_R50_1x.yaml) | 12h | 39.3 | 1.3 | [model](https://dl.fbaipublicfiles.com/detic/BoxSup_OVCOCO_CLIP_R50_1x.pth) | +|[Detic_CLIP_R50_1x_image](../configs/Detic_OVCOCO_CLIP_R50_1x_max-size.yaml) | 13h | 44.7 | 24.1 | [model](https://dl.fbaipublicfiles.com/detic/Detic_OVCOCO_CLIP_R50_1x_max-size.pth) | +|[Detic_CLIP_R50_1x_caption](../configs/Detic_OVCOCO_CLIP_R50_1x_caption.yaml) | 16h | 43.8 | 21.0 | [model](https://dl.fbaipublicfiles.com/detic/Detic_OVCOCO_CLIP_R50_1x_caption.pth) | +|[Detic_CLIP_R50_1x_caption-image](../configs/Detic_OVCOCO_CLIP_R50_1x_max-size_caption.yaml) | 16h | 45.0 | 27.8 | [model](https://dl.fbaipublicfiles.com/detic/Detic_OVCOCO_CLIP_R50_1x_max-size_caption.pth) | + +#### Note + +- All models are trained with ResNet50-C4 without multi-scale augmentation. All models use CLIP embeddings as the classifier. + +- We extract class names from COCO-captions as image-labels. `Detic_CLIP_R50_1x_image` uses the max-size loss; `Detic_CLIP_R50_1x_caption` directly uses CLIP caption embedding within each mini-batch for classification; `Detic_CLIP_R50_1x_caption-image` uses both losses. + +- We report box mAP50 under the "generalized" open-vocabulary setting. + + +## Cross-dataset evaluation + + +| Name |Training time | Objects365 box mAP | OpenImages box mAP50 | Download | +|-----------------------|------------------|-----------|-----------------|----------| +|[Box-Supervised_C2_SwinB_896_4x](../configs/BoxSup-C2_L_CLIP_SwinB_896b32_4x.yaml) | 43h | 19.1 | 46.2 | [model](https://dl.fbaipublicfiles.com/detic/BoxSup-C2_L_CLIP_SwinB_896b32_4x.pth) | +|[Detic_C2_SwinB_896_4x](../configs/Detic_LI_CLIP_SwinB_896b32_4x_ft4x_max-size.yaml) | 47h | 21.2 |53.0 | [model](https://dl.fbaipublicfiles.com/detic/Detic_LI_CLIP_SwinB_896b32_4x_ft4x_max-size.pth) | +|[Detic_C2_SwinB_896_4x_IN-21K](../configs/Detic_LI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.yaml) | 47h | 21.4 | 55.2 | [model](https://dl.fbaipublicfiles.com/detic/Detic_LI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.pth) | +|[Box-Supervised_C2_SwinB_896_4x+COCO](../configs/BoxSup-C2_LCOCO_CLIP_SwinB_896b32_4x.yaml) | 43h | 19.7 | 46.4 | [model](https://dl.fbaipublicfiles.com/detic/BoxSup-C2_LCOCO_CLIP_SwinB_896b32_4x.pth) | +|[Detic_C2_SwinB_896_4x_IN-21K+COCO](../configs/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.yaml) | 47h | 21.6 | 54.6 | [model](https://dl.fbaipublicfiles.com/detic/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.pth) | + + + +#### Note + +- `Box-Supervised_C2_SwinB_896_4x` and `Detic_C2_SwinB_896_4x` are the same model in the [Standard LVIS](#standard-lvis) section, but evaluated with Objects365/ OpenImages vocabulary (i.e. CLIP embeddings of the corresponding class names as classifier). To run the evaluation on Objects365/ OpenImages, run + + ``` + python train_net.py --num-gpus 8 --config-file configs/Detic_C2_SwinB_896_4x.yaml --eval-only DATASETS.TEST "('oid_val_expanded','objects365_v2_val',)" MODEL.RESET_CLS_TESTS True MODEL.TEST_CLASSIFIERS "('datasets/metadata/oid_clip_a+cname.npy','datasets/metadata/o365_clip_a+cnamefix.npy',)" MODEL.TEST_NUM_CLASSES "(500,365)" MODEL.MASK_ON False + ``` + +- `Detic_C2_SwinB_896_4x_IN-21K` trains on the full ImageNet-22K. We additionally use a dynamic class sampling ("Modified Federated Loss" in Section 4.4) and use a larger data sampling ratio of ImageNet images (1:16 instead of 1:4). + +- `Detic_C2_SwinB_896_4x_IN-21K-COCO` is a model trained on combined LVIS-COCO and ImageNet-21K for better demo purposes. LVIS models do not detect persons well due to its federated annotation protocol. LVIS+COCO models give better visual results. + + +## Real-time models + +| Name | Run time (ms) | LVIS box mAP | Download | +|-----------------------|------------------|-----------|-----------------| +|[Detic_C2_SwinB_896_4x_IN-21K+COCO (800x1333, no threshold)](../configs/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.yaml) | 115 | 44.4 | [model](https://dl.fbaipublicfiles.com/detic/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.pth) | +|[Detic_C2_SwinB_896_4x_IN-21K+COCO](../configs/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.yaml) | 46 | 35.0 | [model](https://dl.fbaipublicfiles.com/detic/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.pth) | +|[Detic_C2_ConvNeXtT_896_4x_IN-21K+COCO](../configs/Detic_LCOCOI21k_CLIP_CXT21k_640b32_4x_ft4x_max-size.yaml) | 26 | 30.7 | [model](https://dl.fbaipublicfiles.com/detic/Detic_LCOCOI21k_CLIP_CXT21k_640b32_4x_ft4x_max-size.pth) | +|[Detic_C2_R5021k_896_4x_IN-21K+COCO](../configs/Detic_LCOCOI21k_CLIP_R5021k_640b32_4x_ft4x_max-size.yaml) | 23 | 29.0 | [model](https://dl.fbaipublicfiles.com/detic/Detic_LCOCOI21k_CLIP_R5021k_640b32_4x_ft4x_max-size.pth) | +|[Detic_C2_R18_896_4x_IN-21K+COCO](../configs/Detic_LCOCOI21k_CLIP_R18_640b32_4x_ft4x_max-size.yaml) | 18 | 22.1 | [model](https://dl.fbaipublicfiles.com/detic/Detic_LCOCOI21k_CLIP_R18_640b32_4x_ft4x_max-size.pth) | + +- `Detic_C2_SwinB_896_4x_IN-21K+COCO (800x1333, thresh 0.02)` is the entry on the [Cross-dataset evaluation](#Cross-dataset evaluation) section without the mask head. All other entries use a max-size of 640 and an output score threshold of 0.3 using the following command (e.g., with R50). + + ``` + python train_net.py --config-file configs/Detic_LCOCOI21k_CLIP_R5021k_640b32_4x_ft4x_max-size.yaml --num-gpus 2 --eval-only DATASETS.TEST "('lvis_v1_val',)" MODEL.RESET_CLS_TESTS True MODEL.TEST_CLASSIFIERS "('datasets/metadata/lvis_v1_clip_a+cname.npy',)" MODEL.TEST_NUM_CLASSES "(1203,)" MODEL.MASK_ON False MODEL.WEIGHTS models/Detic_LCOCOI21k_CLIP_R5021k_640b32_4x_ft4x_max-size.pth INPUT.MIN_SIZE_TEST 640 INPUT.MAX_SIZE_TEST 640 MODEL.ROI_HEADS.SCORE_THRESH_TEST 0.3 + ``` + +- All models are trained using the same training recipe except for different backbones. +- The ConvNeXtT and Res50 models are initialized from their corresponding ImageNet-21K pretrained models. The Res18 model is initialized from its ImageNet-1K pretrained model. +- The runtimes are measured on a local workstation with a Titan RTX GPU. diff --git a/dimos/models/Detic/docs/example_output_custom.jpeg b/dimos/models/Detic/docs/example_output_custom.jpeg new file mode 100644 index 0000000000..ac6aa3fb93 Binary files /dev/null and b/dimos/models/Detic/docs/example_output_custom.jpeg differ diff --git a/dimos/models/Detic/docs/example_output_lvis.jpeg b/dimos/models/Detic/docs/example_output_lvis.jpeg new file mode 100644 index 0000000000..3d22122059 Binary files /dev/null and b/dimos/models/Detic/docs/example_output_lvis.jpeg differ diff --git a/dimos/models/Detic/docs/teaser.jpeg.REMOVED.git-id b/dimos/models/Detic/docs/teaser.jpeg.REMOVED.git-id new file mode 100644 index 0000000000..7024286d06 --- /dev/null +++ b/dimos/models/Detic/docs/teaser.jpeg.REMOVED.git-id @@ -0,0 +1 @@ +2e8fbac2f8fc89249a3a3a957d02c2c0701686d7 \ No newline at end of file diff --git a/dimos/models/Detic/lazy_train_net.py b/dimos/models/Detic/lazy_train_net.py new file mode 100644 index 0000000000..3525a1f63a --- /dev/null +++ b/dimos/models/Detic/lazy_train_net.py @@ -0,0 +1,132 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +""" +Training script using the new "LazyConfig" python config files. +This scripts reads a given python config file and runs the training or evaluation. +It can be used to train any models or dataset as long as they can be +instantiated by the recursive construction defined in the given config file. +Besides lazy construction of models, dataloader, etc., this scripts expects a +few common configuration parameters currently defined in "configs/common/train.py". +To add more complicated training logic, you can easily add other configs +in the config file and implement a new train_net.py to handle them. +""" + +import logging +import sys + +from detectron2.checkpoint import DetectionCheckpointer +from detectron2.config import LazyConfig, instantiate +from detectron2.engine import ( + AMPTrainer, + SimpleTrainer, + default_argument_parser, + default_setup, + default_writers, + hooks, + launch, +) +from detectron2.engine.defaults import create_ddp_model +from detectron2.evaluation import inference_on_dataset, print_csv_format +from detectron2.utils import comm + +sys.path.insert(0, "third_party/CenterNet2/") +sys.path.insert(0, "third_party/Deformable-DETR") +logger = logging.getLogger("detectron2") + + +def do_test(cfg, model): + if "evaluator" in cfg.dataloader: + ret = inference_on_dataset( + model, instantiate(cfg.dataloader.test), instantiate(cfg.dataloader.evaluator) + ) + print_csv_format(ret) + return ret + + +def do_train(args, cfg) -> None: + """ + Args: + cfg: an object with the following attributes: + model: instantiate to a module + dataloader.{train,test}: instantiate to dataloaders + dataloader.evaluator: instantiate to evaluator for test set + optimizer: instantaite to an optimizer + lr_multiplier: instantiate to a fvcore scheduler + train: other misc config defined in `common_train.py`, including: + output_dir (str) + init_checkpoint (str) + amp.enabled (bool) + max_iter (int) + eval_period, log_period (int) + device (str) + checkpointer (dict) + ddp (dict) + """ + model = instantiate(cfg.model) + logger = logging.getLogger("detectron2") + logger.info(f"Model:\n{model}") + model.to(cfg.train.device) + + cfg.optimizer.params.model = model + optim = instantiate(cfg.optimizer) + + train_loader = instantiate(cfg.dataloader.train) + + model = create_ddp_model(model, **cfg.train.ddp) + trainer = (AMPTrainer if cfg.train.amp.enabled else SimpleTrainer)(model, train_loader, optim) + checkpointer = DetectionCheckpointer( + model, + cfg.train.output_dir, + optimizer=optim, + trainer=trainer, + ) + train_hooks = [ + hooks.IterationTimer(), + hooks.LRScheduler(scheduler=instantiate(cfg.lr_multiplier)), + hooks.PeriodicCheckpointer(checkpointer, **cfg.train.checkpointer) + if comm.is_main_process() + else None, + hooks.EvalHook(cfg.train.eval_period, lambda: do_test(cfg, model)), + hooks.PeriodicWriter( + default_writers(cfg.train.output_dir, cfg.train.max_iter), + period=cfg.train.log_period, + ) + if comm.is_main_process() + else None, + ] + trainer.register_hooks(train_hooks) + + checkpointer.resume_or_load(cfg.train.init_checkpoint, resume=args.resume) + if args.resume and checkpointer.has_checkpoint(): + # The checkpoint stores the training iteration that just finished, thus we start + # at the next iteration + start_iter = trainer.iter + 1 + else: + start_iter = 0 + trainer.train(start_iter, cfg.train.max_iter) + + +def main(args) -> None: + cfg = LazyConfig.load(args.config_file) + cfg = LazyConfig.apply_overrides(cfg, args.opts) + default_setup(cfg, args) + + if args.eval_only: + model = instantiate(cfg.model) + model.to(cfg.train.device) + model = create_ddp_model(model) + DetectionCheckpointer(model).load(cfg.train.init_checkpoint) + print(do_test(cfg, model)) + else: + do_train(args, cfg) + + +if __name__ == "__main__": + args = default_argument_parser().parse_args() + launch( + main, + args.num_gpus, + num_machines=args.num_machines, + machine_rank=args.machine_rank, + dist_url=args.dist_url, + args=(args,), + ) diff --git a/dimos/models/Detic/predict.py b/dimos/models/Detic/predict.py new file mode 100644 index 0000000000..bf71d007a1 --- /dev/null +++ b/dimos/models/Detic/predict.py @@ -0,0 +1,102 @@ +from pathlib import Path +import sys +import tempfile +import time + +import cog +import cv2 +from detectron2.config import get_cfg +from detectron2.data import MetadataCatalog + +# import some common detectron2 utilities +from detectron2.engine import DefaultPredictor +from detectron2.utils.visualizer import Visualizer + +# Detic libraries +sys.path.insert(0, "third_party/CenterNet2/") +from centernet.config import add_centernet_config +from detic.config import add_detic_config +from detic.modeling.text.text_encoder import build_text_encoder +from detic.modeling.utils import reset_cls_test + + +class Predictor(cog.Predictor): + def setup(self) -> None: + cfg = get_cfg() + add_centernet_config(cfg) + add_detic_config(cfg) + cfg.merge_from_file("configs/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.yaml") + cfg.MODEL.WEIGHTS = "Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.pth" + cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 # set threshold for this model + cfg.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_PATH = "rand" + cfg.MODEL.ROI_HEADS.ONE_CLASS_PER_PROPOSAL = True + self.predictor = DefaultPredictor(cfg) + self.BUILDIN_CLASSIFIER = { + "lvis": "datasets/metadata/lvis_v1_clip_a+cname.npy", + "objects365": "datasets/metadata/o365_clip_a+cnamefix.npy", + "openimages": "datasets/metadata/oid_clip_a+cname.npy", + "coco": "datasets/metadata/coco_clip_a+cname.npy", + } + self.BUILDIN_METADATA_PATH = { + "lvis": "lvis_v1_val", + "objects365": "objects365_v2_val", + "openimages": "oid_val_expanded", + "coco": "coco_2017_val", + } + + @cog.input( + "image", + type=Path, + help="input image", + ) + @cog.input( + "vocabulary", + type=str, + default="lvis", + options=["lvis", "objects365", "openimages", "coco", "custom"], + help="Choose vocabulary", + ) + @cog.input( + "custom_vocabulary", + type=str, + default=None, + help="Type your own vocabularies, separated by coma ','", + ) + def predict(self, image, vocabulary, custom_vocabulary): + image = cv2.imread(str(image)) + if not vocabulary == "custom": + metadata = MetadataCatalog.get(self.BUILDIN_METADATA_PATH[vocabulary]) + classifier = self.BUILDIN_CLASSIFIER[vocabulary] + num_classes = len(metadata.thing_classes) + reset_cls_test(self.predictor.model, classifier, num_classes) + + else: + assert custom_vocabulary is not None and len(custom_vocabulary.split(",")) > 0, ( + "Please provide your own vocabularies when vocabulary is set to 'custom'." + ) + metadata = MetadataCatalog.get(str(time.time())) + metadata.thing_classes = custom_vocabulary.split(",") + classifier = get_clip_embeddings(metadata.thing_classes) + num_classes = len(metadata.thing_classes) + reset_cls_test(self.predictor.model, classifier, num_classes) + # Reset visualization threshold + output_score_threshold = 0.3 + for cascade_stages in range(len(self.predictor.model.roi_heads.box_predictor)): + self.predictor.model.roi_heads.box_predictor[ + cascade_stages + ].test_score_thresh = output_score_threshold + + outputs = self.predictor(image) + v = Visualizer(image[:, :, ::-1], metadata) + out = v.draw_instance_predictions(outputs["instances"].to("cpu")) + out_path = Path(tempfile.mkdtemp()) / "out.png" + cv2.imwrite(str(out_path), out.get_image()[:, :, ::-1]) + return out_path + + +def get_clip_embeddings(vocabulary, prompt: str="a "): + text_encoder = build_text_encoder(pretrain=True) + text_encoder.eval() + texts = [prompt + x for x in vocabulary] + emb = text_encoder(texts).detach().permute(1, 0).contiguous().cpu() + return emb diff --git a/dimos/models/Detic/requirements.txt b/dimos/models/Detic/requirements.txt new file mode 100644 index 0000000000..518274db24 --- /dev/null +++ b/dimos/models/Detic/requirements.txt @@ -0,0 +1,11 @@ +opencv-python +mss +timm +dataclasses +ftfy +regex +fasttext +scikit-learn +lvis +nltk +git+https://github.com/openai/CLIP.git diff --git a/dimos/models/Detic/third_party/CenterNet2/.github/CODE_OF_CONDUCT.md b/dimos/models/Detic/third_party/CenterNet2/.github/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000..0f7ad8bfc1 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/.github/CODE_OF_CONDUCT.md @@ -0,0 +1,5 @@ +# Code of Conduct + +Facebook has adopted a Code of Conduct that we expect project participants to adhere to. +Please read the [full text](https://code.fb.com/codeofconduct/) +so that you can understand what actions will and will not be tolerated. diff --git a/dimos/models/Detic/third_party/CenterNet2/.github/CONTRIBUTING.md b/dimos/models/Detic/third_party/CenterNet2/.github/CONTRIBUTING.md new file mode 100644 index 0000000000..9bab709cae --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/.github/CONTRIBUTING.md @@ -0,0 +1,68 @@ +# Contributing to detectron2 + +## Issues +We use GitHub issues to track public bugs and questions. +Please make sure to follow one of the +[issue templates](https://github.com/facebookresearch/detectron2/issues/new/choose) +when reporting any issues. + +Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe +disclosure of security bugs. In those cases, please go through the process +outlined on that page and do not file a public issue. + +## Pull Requests +We actively welcome pull requests. + +However, if you're adding any significant features (e.g. > 50 lines), please +make sure to discuss with maintainers about your motivation and proposals in an issue +before sending a PR. This is to save your time so you don't spend time on a PR that we'll not accept. + +We do not always accept new features, and we take the following +factors into consideration: + +1. Whether the same feature can be achieved without modifying detectron2. + Detectron2 is designed so that you can implement many extensions from the outside, e.g. + those in [projects](https://github.com/facebookresearch/detectron2/tree/master/projects). + * If some part of detectron2 is not extensible enough, you can also bring up a more general issue to + improve it. Such feature request may be useful to more users. +2. Whether the feature is potentially useful to a large audience (e.g. an impactful detection paper, a popular dataset, + a significant speedup, a widely useful utility), + or only to a small portion of users (e.g., a less-known paper, an improvement not in the object + detection field, a trick that's not very popular in the community, code to handle a non-standard type of data) + * Adoption of additional models, datasets, new task are by default not added to detectron2 before they + receive significant popularity in the community. + We sometimes accept such features in `projects/`, or as a link in `projects/README.md`. +3. Whether the proposed solution has a good design / interface. This can be discussed in the issue prior to PRs, or + in the form of a draft PR. +4. Whether the proposed solution adds extra mental/practical overhead to users who don't + need such feature. +5. Whether the proposed solution breaks existing APIs. + +To add a feature to an existing function/class `Func`, there are always two approaches: +(1) add new arguments to `Func`; (2) write a new `Func_with_new_feature`. +To meet the above criteria, we often prefer approach (2), because: + +1. It does not involve modifying or potentially breaking existing code. +2. It does not add overhead to users who do not need the new feature. +3. Adding new arguments to a function/class is not scalable w.r.t. all the possible new research ideas in the future. + +When sending a PR, please do: + +1. If a PR contains multiple orthogonal changes, split it to several PRs. +2. If you've added code that should be tested, add tests. +3. For PRs that need experiments (e.g. adding a new model or new methods), + you don't need to update model zoo, but do provide experiment results in the description of the PR. +4. If APIs are changed, update the documentation. +5. We use the [Google style docstrings](https://www.sphinx-doc.org/en/master/usage/extensions/napoleon.html) in python. +6. Make sure your code lints with `./dev/linter.sh`. + + +## Contributor License Agreement ("CLA") +In order to accept your pull request, we need you to submit a CLA. You only need +to do this once to work on any of Facebook's open source projects. + +Complete your CLA here: + +## License +By contributing to detectron2, you agree that your contributions will be licensed +under the LICENSE file in the root directory of this source tree. diff --git a/dimos/models/Detic/third_party/CenterNet2/.github/Detectron2-Logo-Horz.svg b/dimos/models/Detic/third_party/CenterNet2/.github/Detectron2-Logo-Horz.svg new file mode 100644 index 0000000000..eb2d643ddd --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/.github/Detectron2-Logo-Horz.svg @@ -0,0 +1 @@ +Detectron2-Logo-Horz \ No newline at end of file diff --git a/dimos/models/Detic/third_party/CenterNet2/.github/ISSUE_TEMPLATE.md b/dimos/models/Detic/third_party/CenterNet2/.github/ISSUE_TEMPLATE.md new file mode 100644 index 0000000000..5e8aaa2d37 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/.github/ISSUE_TEMPLATE.md @@ -0,0 +1,5 @@ + +Please select an issue template from +https://github.com/facebookresearch/detectron2/issues/new/choose . + +Otherwise your issue will be closed. diff --git a/dimos/models/Detic/third_party/CenterNet2/.github/ISSUE_TEMPLATE/bugs.md b/dimos/models/Detic/third_party/CenterNet2/.github/ISSUE_TEMPLATE/bugs.md new file mode 100644 index 0000000000..d0235c708a --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/.github/ISSUE_TEMPLATE/bugs.md @@ -0,0 +1,38 @@ +--- +name: "🐛 Bugs" +about: Report bugs in detectron2 +title: Please read & provide the following + +--- + +## Instructions To Reproduce the 🐛 Bug: +1. Full runnable code or full changes you made: +``` +If making changes to the project itself, please use output of the following command: +git rev-parse HEAD; git diff + + +``` +2. What exact command you run: +3. __Full logs__ or other relevant observations: +``` + +``` +4. please simplify the steps as much as possible so they do not require additional resources to + run, such as a private dataset. + +## Expected behavior: + +If there are no obvious error in "full logs" provided above, +please tell us the expected behavior. + +## Environment: + +Provide your environment information using the following command: +``` +wget -nc -q https://github.com/facebookresearch/detectron2/raw/main/detectron2/utils/collect_env.py && python collect_env.py +``` + +If your issue looks like an installation issue / environment issue, +please first try to solve it yourself with the instructions in +https://detectron2.readthedocs.io/tutorials/install.html#common-installation-issues diff --git a/dimos/models/Detic/third_party/CenterNet2/.github/ISSUE_TEMPLATE/config.yml b/dimos/models/Detic/third_party/CenterNet2/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000000..c60c2e1430 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,17 @@ +# require an issue template to be chosen +blank_issues_enabled: false + +contact_links: + - name: How-To / All Other Questions + url: https://github.com/facebookresearch/detectron2/discussions + about: Use "github discussions" for community support on general questions that don't belong to the above issue categories + - name: Detectron2 Documentation + url: https://detectron2.readthedocs.io/index.html + about: Check if your question is answered in tutorials or API docs + +# Unexpected behaviors & bugs are split to two templates. +# When they are one template, users think "it's not a bug" and don't choose the template. +# +# But the file name is still "unexpected-problems-bugs.md" so that old references +# to this issue template still works. +# It's ok since this template should be a superset of "bugs.md" (unexpected behaviors is a superset of bugs) diff --git a/dimos/models/Detic/third_party/CenterNet2/.github/ISSUE_TEMPLATE/documentation.md b/dimos/models/Detic/third_party/CenterNet2/.github/ISSUE_TEMPLATE/documentation.md new file mode 100644 index 0000000000..88214d62e5 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/.github/ISSUE_TEMPLATE/documentation.md @@ -0,0 +1,14 @@ +--- +name: "\U0001F4DA Documentation Issue" +about: Report a problem about existing documentation, comments, website or tutorials. +labels: documentation + +--- + +## 📚 Documentation Issue + +This issue category is for problems about existing documentation, not for asking how-to questions. + +* Provide a link to an existing documentation/comment/tutorial: + +* How should the above documentation/comment/tutorial improve: diff --git a/dimos/models/Detic/third_party/CenterNet2/.github/ISSUE_TEMPLATE/feature-request.md b/dimos/models/Detic/third_party/CenterNet2/.github/ISSUE_TEMPLATE/feature-request.md new file mode 100644 index 0000000000..03a1e93d72 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/.github/ISSUE_TEMPLATE/feature-request.md @@ -0,0 +1,31 @@ +--- +name: "\U0001F680Feature Request" +about: Suggest an improvement or new feature +labels: enhancement + +--- + +## 🚀 Feature +A clear and concise description of the feature proposal. + +## Motivation & Examples + +Tell us why the feature is useful. + +Describe what the feature would look like, if it is implemented. +Best demonstrated using **code examples** in addition to words. + +## Note + +We only consider adding new features if they are relevant to many users. + +If you request implementation of research papers -- we only consider papers that have enough significance and prevalance in the object detection field. + +We do not take requests for most projects in the `projects/` directory, because they are research code release that is mainly for other researchers to reproduce results. + +"Make X faster/accurate" is not a valid feature request. "Implement a concrete feature that can make X faster/accurate" can be a valid feature request. + +Instead of adding features inside detectron2, +you can implement many features by [extending detectron2](https://detectron2.readthedocs.io/tutorials/extend.html). +The [projects/](https://github.com/facebookresearch/detectron2/tree/main/projects/) directory contains many of such examples. + diff --git a/dimos/models/Detic/third_party/CenterNet2/.github/ISSUE_TEMPLATE/unexpected-problems-bugs.md b/dimos/models/Detic/third_party/CenterNet2/.github/ISSUE_TEMPLATE/unexpected-problems-bugs.md new file mode 100644 index 0000000000..5db8f22415 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/.github/ISSUE_TEMPLATE/unexpected-problems-bugs.md @@ -0,0 +1,44 @@ +--- +name: "😩 Unexpected behaviors" +about: Report unexpected behaviors when using detectron2 +title: Please read & provide the following + +--- + +If you do not know the root cause of the problem, please post according to this template: + +## Instructions To Reproduce the Issue: + +Check https://stackoverflow.com/help/minimal-reproducible-example for how to ask good questions. +Simplify the steps to reproduce the issue using suggestions from the above link, and provide them below: + +1. Full runnable code or full changes you made: +``` +If making changes to the project itself, please use output of the following command: +git rev-parse HEAD; git diff + + +``` +2. What exact command you run: +3. __Full logs__ or other relevant observations: +``` + +``` + +## Expected behavior: + +If there are no obvious crash in "full logs" provided above, +please tell us the expected behavior. + +If you expect a model to converge / work better, we do not help with such issues, unless +a model fails to reproduce the results in detectron2 model zoo, or proves existence of bugs. + +## Environment: + +Paste the output of the following command: +``` +wget -nc -nv https://github.com/facebookresearch/detectron2/raw/main/detectron2/utils/collect_env.py && python collect_env.py +``` + +If your issue looks like an installation issue / environment issue, +please first check common issues in https://detectron2.readthedocs.io/tutorials/install.html#common-installation-issues diff --git a/dimos/models/Detic/third_party/CenterNet2/.github/pull_request_template.md b/dimos/models/Detic/third_party/CenterNet2/.github/pull_request_template.md new file mode 100644 index 0000000000..d71729baee --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/.github/pull_request_template.md @@ -0,0 +1,10 @@ +Thanks for your contribution! + +If you're sending a large PR (e.g., >100 lines), +please open an issue first about the feature / bug, and indicate how you want to contribute. + +We do not always accept features. +See https://detectron2.readthedocs.io/notes/contributing.html#pull-requests about how we handle PRs. + +Before submitting a PR, please run `dev/linter.sh` to lint the code. + diff --git a/dimos/models/Detic/third_party/CenterNet2/.github/workflows/check-template.yml b/dimos/models/Detic/third_party/CenterNet2/.github/workflows/check-template.yml new file mode 100644 index 0000000000..3caed9df3c --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/.github/workflows/check-template.yml @@ -0,0 +1,86 @@ +name: Check issue template + +on: + issues: + types: [opened] + +jobs: + check-template: + runs-on: ubuntu-latest + # comment this out when testing with https://github.com/nektos/act + if: ${{ github.repository_owner == 'facebookresearch' }} + steps: + - uses: actions/checkout@v2 + - uses: actions/github-script@v3 + with: + github-token: ${{secrets.GITHUB_TOKEN}} + script: | + // Arguments available: + // - github: A pre-authenticated octokit/rest.js client + // - context: An object containing the context of the workflow run + // - core: A reference to the @actions/core package + // - io: A reference to the @actions/io package + const fs = require('fs'); + const editDistance = require(`${process.env.GITHUB_WORKSPACE}/.github/workflows/levenshtein.js`).getEditDistance + issue = await github.issues.get({ + owner: context.issue.owner, + repo: context.issue.repo, + issue_number: context.issue.number, + }); + const hasLabel = issue.data.labels.length > 0; + if (hasLabel || issue.state === "closed") { + // don't require template on them + core.debug("Issue " + issue.data.title + " was skipped."); + return; + } + + sameAsTemplate = function(filename, body) { + let tmpl = fs.readFileSync(`.github/ISSUE_TEMPLATE/${filename}`, 'utf8'); + tmpl = tmpl.toLowerCase().split("---").slice(2).join("").trim(); + tmpl = tmpl.replace(/(\r\n|\n|\r)/gm, ""); + let bodyr = body.replace(/(\r\n|\n|\r)/gm, ""); + let dist = editDistance(tmpl, bodyr); + return dist < 8; + }; + + checkFail = async function(msg) { + core.info("Processing '" + issue.data.title + "' with message: " + msg); + await github.issues.addLabels({ + owner: context.issue.owner, + repo: context.issue.repo, + issue_number: context.issue.number, + labels: ["needs-more-info"], + }); + await github.issues.createComment({ + owner: context.issue.owner, + repo: context.issue.repo, + issue_number: context.issue.number, + body: msg, + }); + }; + + const body = issue.data.body.toLowerCase().trim(); + + if (sameAsTemplate("bugs.md", body) || sameAsTemplate("unexpected-problems-bugs.md", body)) { + await checkFail(` + We found that not enough information is provided about this issue. + Please provide details following the [issue template](https://github.com/facebookresearch/detectron2/issues/new/choose).`) + return; + } + + const hasInstructions = body.indexOf("reproduce") != -1; + const hasEnvironment = (body.indexOf("environment") != -1) || (body.indexOf("colab") != -1) || (body.indexOf("docker") != -1); + if (hasInstructions && hasEnvironment) { + core.debug("Issue " + issue.data.title + " follows template."); + return; + } + + let message = "You've chosen to report an unexpected problem or bug. Unless you already know the root cause of it, please include details about it by filling the [issue template](https://github.com/facebookresearch/detectron2/issues/new/choose).\n"; + message += "The following information is missing: "; + if (!hasInstructions) { + message += "\"Instructions To Reproduce the Issue and __Full__ Logs\"; "; + } + if (!hasEnvironment) { + message += "\"Your Environment\"; "; + } + await checkFail(message); diff --git a/dimos/models/Detic/third_party/CenterNet2/.github/workflows/levenshtein.js b/dimos/models/Detic/third_party/CenterNet2/.github/workflows/levenshtein.js new file mode 100644 index 0000000000..67a5e3613c --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/.github/workflows/levenshtein.js @@ -0,0 +1,44 @@ +/* +Copyright (c) 2011 Andrei Mackenzie + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +*/ + +// Compute the edit distance between the two given strings +exports.getEditDistance = function(a, b){ + if(a.length == 0) return b.length; + if(b.length == 0) return a.length; + + var matrix = []; + + // increment along the first column of each row + var i; + for(i = 0; i <= b.length; i++){ + matrix[i] = [i]; + } + + // increment each column in the first row + var j; + for(j = 0; j <= a.length; j++){ + matrix[0][j] = j; + } + + // Fill in the rest of the matrix + for(i = 1; i <= b.length; i++){ + for(j = 1; j <= a.length; j++){ + if(b.charAt(i-1) == a.charAt(j-1)){ + matrix[i][j] = matrix[i-1][j-1]; + } else { + matrix[i][j] = Math.min(matrix[i-1][j-1] + 1, // substitution + Math.min(matrix[i][j-1] + 1, // insertion + matrix[i-1][j] + 1)); // deletion + } + } + } + + return matrix[b.length][a.length]; +}; diff --git a/dimos/models/Detic/third_party/CenterNet2/.github/workflows/needs-reply.yml b/dimos/models/Detic/third_party/CenterNet2/.github/workflows/needs-reply.yml new file mode 100644 index 0000000000..4affabd349 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/.github/workflows/needs-reply.yml @@ -0,0 +1,98 @@ +name: Close/Lock issues after inactivity + +on: + schedule: + - cron: "0 0 * * *" + +jobs: + close-issues-needs-more-info: + runs-on: ubuntu-latest + if: ${{ github.repository_owner == 'facebookresearch' }} + steps: + - name: Close old issues that need reply + uses: actions/github-script@v3 + with: + github-token: ${{secrets.GITHUB_TOKEN}} + # Modified from https://github.com/dwieeb/needs-reply + script: | + // Arguments available: + // - github: A pre-authenticated octokit/rest.js client + // - context: An object containing the context of the workflow run + // - core: A reference to the @actions/core package + // - io: A reference to the @actions/io package + const kLabelToCheck = "needs-more-info"; + const kInvalidLabel = "invalid/unrelated"; + const kDaysBeforeClose = 7; + const kMessage = "Requested information was not provided in 7 days, so we're closing this issue.\n\nPlease open new issue if information becomes available. Otherwise, use [github discussions](https://github.com/facebookresearch/detectron2/discussions) for free-form discussions." + + issues = await github.issues.listForRepo({ + owner: context.repo.owner, + repo: context.repo.repo, + state: 'open', + labels: kLabelToCheck, + sort: 'updated', + direction: 'asc', + per_page: 30, + page: 1, + }); + issues = issues.data; + if (issues.length === 0) { + core.info('No more issues found to process. Exiting.'); + return; + } + for (const issue of issues) { + if (!!issue.pull_request) + continue; + core.info(`Processing issue #${issue.number}`); + + let updatedAt = new Date(issue.updated_at).getTime(); + const numComments = issue.comments; + const comments = await github.issues.listComments({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: issue.number, + per_page: 30, + page: Math.floor((numComments - 1) / 30) + 1, // the last page + }); + const lastComments = comments.data + .map(l => new Date(l.created_at).getTime()) + .sort(); + if (lastComments.length > 0) { + updatedAt = lastComments[lastComments.length - 1]; + } + + const now = new Date().getTime(); + const daysSinceUpdated = (now - updatedAt) / 1000 / 60 / 60 / 24; + + if (daysSinceUpdated < kDaysBeforeClose) { + core.info(`Skipping #${issue.number} because it has been updated in the last ${daysSinceUpdated} days`); + continue; + } + core.info(`Closing #${issue.number} because it has not been updated in the last ${daysSinceUpdated} days`); + await github.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: issue.number, + body: kMessage, + }); + const newLabels = numComments <= 2 ? [kInvalidLabel, kLabelToCheck] : issue.labels; + await github.issues.update({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: issue.number, + labels: newLabels, + state: 'closed', + }); + } + + lock-issues-after-closed: + runs-on: ubuntu-latest + if: ${{ github.repository_owner == 'facebookresearch' }} + steps: + - name: Lock closed issues that have no activity for a while + uses: dessant/lock-threads@v2 + with: + github-token: ${{ github.token }} + issue-lock-inactive-days: '300' + process-only: 'issues' + issue-exclude-labels: 'enhancement,bug,documentation' diff --git a/dimos/models/Detic/third_party/CenterNet2/.github/workflows/remove-needs-reply.yml b/dimos/models/Detic/third_party/CenterNet2/.github/workflows/remove-needs-reply.yml new file mode 100644 index 0000000000..1f000b28ca --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/.github/workflows/remove-needs-reply.yml @@ -0,0 +1,25 @@ +name: Remove needs-more-info label + +on: + issue_comment: + types: [created] + issues: + types: [edited] + +jobs: + remove-needs-more-info-label: + runs-on: ubuntu-latest + # 1. issue_comment events could include PR comment, filter them out + # 2. Only trigger action if event was produced by the original author + if: ${{ !github.event.issue.pull_request && github.event.sender.login == github.event.issue.user.login }} + steps: + - name: Remove needs-more-info label + uses: octokit/request-action@v2.x + continue-on-error: true + with: + route: DELETE /repos/:repository/issues/:issue/labels/:label + repository: ${{ github.repository }} + issue: ${{ github.event.issue.number }} + label: needs-more-info + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/dimos/models/Detic/third_party/CenterNet2/.github/workflows/workflow.yml b/dimos/models/Detic/third_party/CenterNet2/.github/workflows/workflow.yml new file mode 100644 index 0000000000..6085b32a50 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/.github/workflows/workflow.yml @@ -0,0 +1,81 @@ +name: CI +on: [push, pull_request] + +# Run linter with github actions for quick feedbacks. +# Run macos tests with github actions. Linux (CPU & GPU) tests currently runs on CircleCI +jobs: + linter: + runs-on: ubuntu-latest + # run on PRs, or commits to facebookresearch (not internal) + if: ${{ github.repository_owner == 'facebookresearch' || github.event_name == 'pull_request' }} + steps: + - uses: actions/checkout@v2 + - name: Set up Python 3.6 + uses: actions/setup-python@v2 + with: + python-version: 3.6 + - name: Install dependencies + # flake8-bugbear flake8-comprehensions are useful but not available internally + run: | + python -m pip install --upgrade pip + python -m pip install flake8==3.8.1 isort==4.3.21 + python -m pip install black==21.4b2 + flake8 --version + - name: Lint + run: | + echo "Running isort" + isort -c -sp . + echo "Running black" + black -l 100 --check . + echo "Running flake8" + flake8 . + + macos_tests: + runs-on: macos-latest + # run on PRs, or commits to facebookresearch (not internal) + if: ${{ github.repository_owner == 'facebookresearch' || github.event_name == 'pull_request' }} + strategy: + fail-fast: false + matrix: + torch: ["1.8", "1.9", "1.10"] + include: + - torch: "1.8" + torchvision: 0.9 + - torch: "1.9" + torchvision: "0.10" + - torch: "1.10" + torchvision: "0.11.1" + env: + # point datasets to ~/.torch so it's cached by CI + DETECTRON2_DATASETS: ~/.torch/datasets + steps: + - name: Checkout + uses: actions/checkout@v2 + - name: Set up Python 3.6 + uses: actions/setup-python@v2 + with: + python-version: 3.6 + - name: Cache dependencies + uses: actions/cache@v2 + with: + path: | + ${{ env.pythonLocation }}/lib/python3.6/site-packages + ~/.torch + key: ${{ runner.os }}-torch${{ matrix.torch }}-${{ hashFiles('setup.py') }}-20210420 + + - name: Install dependencies + run: | + python -m pip install -U pip + python -m pip install ninja opencv-python-headless onnx pytest-xdist + python -m pip install torch==${{matrix.torch}} torchvision==${{matrix.torchvision}} -f https://download.pytorch.org/whl/torch_stable.html + # install from github to get latest; install iopath first since fvcore depends on it + python -m pip install -U 'git+https://github.com/facebookresearch/iopath' + python -m pip install -U 'git+https://github.com/facebookresearch/fvcore' + + - name: Build and install + run: | + CC=clang CXX=clang++ python -m pip install -e .[all] + python -m detectron2.utils.collect_env + ./datasets/prepare_for_tests.sh + - name: Run unittests + run: python -m pytest -n 4 --durations=15 -v tests/ diff --git a/dimos/models/Detic/third_party/CenterNet2/.gitignore b/dimos/models/Detic/third_party/CenterNet2/.gitignore new file mode 100644 index 0000000000..e045ffa557 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/.gitignore @@ -0,0 +1,58 @@ +third_party/detectron2 +slurm* +# output dir +output +instant_test_output +inference_test_output + + +*.png +*.json +*.diff +# *.jpg +!/projects/DensePose/doc/images/*.jpg + +# compilation and distribution +__pycache__ +_ext +*.pyc +*.pyd +*.so +*.dll +*.egg-info/ +build/ +dist/ +wheels/ + +# pytorch/python/numpy formats +*.pth +*.pkl +*.npy +*.ts +model_ts*.txt + +# ipython/jupyter notebooks +*.ipynb +**/.ipynb_checkpoints/ + +# Editor temporaries +*.swn +*.swo +*.swp +*~ + +# editor settings +.idea +.vscode +_darcs + +# project dirs +/detectron2/model_zoo/configs +/datasets/* +!/datasets/*.* +!/datasets/lvis/ +/datasets/lvis/* +!/datasets/lvis/lvis_v1_train_cat_info.json +/projects/*/datasets +/models +/snippet diff --git a/dimos/models/Detic/third_party/CenterNet2/LICENSE b/dimos/models/Detic/third_party/CenterNet2/LICENSE new file mode 100644 index 0000000000..cd1b070674 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/LICENSE @@ -0,0 +1,202 @@ +Apache License +Version 2.0, January 2004 +http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + +"License" shall mean the terms and conditions for use, reproduction, +and distribution as defined by Sections 1 through 9 of this document. + +"Licensor" shall mean the copyright owner or entity authorized by +the copyright owner that is granting the License. + +"Legal Entity" shall mean the union of the acting entity and all +other entities that control, are controlled by, or are under common +control with that entity. For the purposes of this definition, +"control" means (i) the power, direct or indirect, to cause the +direction or management of such entity, whether by contract or +otherwise, or (ii) ownership of fifty percent (50%) or more of the +outstanding shares, or (iii) beneficial ownership of such entity. + +"You" (or "Your") shall mean an individual or Legal Entity +exercising permissions granted by this License. + +"Source" form shall mean the preferred form for making modifications, +including but not limited to software source code, documentation +source, and configuration files. + +"Object" form shall mean any form resulting from mechanical +transformation or translation of a Source form, including but +not limited to compiled object code, generated documentation, +and conversions to other media types. + +"Work" shall mean the work of authorship, whether in Source or +Object form, made available under the License, as indicated by a +copyright notice that is included in or attached to the work +(an example is provided in the Appendix below). + +"Derivative Works" shall mean any work, whether in Source or Object +form, that is based on (or derived from) the Work and for which the +editorial revisions, annotations, elaborations, or other modifications +represent, as a whole, an original work of authorship. For the purposes +of this License, Derivative Works shall not include works that remain +separable from, or merely link (or bind by name) to the interfaces of, +the Work and Derivative Works thereof. + +"Contribution" shall mean any work of authorship, including +the original version of the Work and any modifications or additions +to that Work or Derivative Works thereof, that is intentionally +submitted to Licensor for inclusion in the Work by the copyright owner +or by an individual or Legal Entity authorized to submit on behalf of +the copyright owner. For the purposes of this definition, "submitted" +means any form of electronic, verbal, or written communication sent +to the Licensor or its representatives, including but not limited to +communication on electronic mailing lists, source code control systems, +and issue tracking systems that are managed by, or on behalf of, the +Licensor for the purpose of discussing and improving the Work, but +excluding communication that is conspicuously marked or otherwise +designated in writing by the copyright owner as "Not a Contribution." + +"Contributor" shall mean Licensor and any individual or Legal Entity +on behalf of whom a Contribution has been received by Licensor and +subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of +this License, each Contributor hereby grants to You a perpetual, +worldwide, non-exclusive, no-charge, royalty-free, irrevocable +copyright license to reproduce, prepare Derivative Works of, +publicly display, publicly perform, sublicense, and distribute the +Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of +this License, each Contributor hereby grants to You a perpetual, +worldwide, non-exclusive, no-charge, royalty-free, irrevocable +(except as stated in this section) patent license to make, have made, +use, offer to sell, sell, import, and otherwise transfer the Work, +where such license applies only to those patent claims licensable +by such Contributor that are necessarily infringed by their +Contribution(s) alone or by combination of their Contribution(s) +with the Work to which such Contribution(s) was submitted. If You +institute patent litigation against any entity (including a +cross-claim or counterclaim in a lawsuit) alleging that the Work +or a Contribution incorporated within the Work constitutes direct +or contributory patent infringement, then any patent licenses +granted to You under this License for that Work shall terminate +as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the +Work or Derivative Works thereof in any medium, with or without +modifications, and in Source or Object form, provided that You +meet the following conditions: + +(a) You must give any other recipients of the Work or +Derivative Works a copy of this License; and + +(b) You must cause any modified files to carry prominent notices +stating that You changed the files; and + +(c) You must retain, in the Source form of any Derivative Works +that You distribute, all copyright, patent, trademark, and +attribution notices from the Source form of the Work, +excluding those notices that do not pertain to any part of +the Derivative Works; and + +(d) If the Work includes a "NOTICE" text file as part of its +distribution, then any Derivative Works that You distribute must +include a readable copy of the attribution notices contained +within such NOTICE file, excluding those notices that do not +pertain to any part of the Derivative Works, in at least one +of the following places: within a NOTICE text file distributed +as part of the Derivative Works; within the Source form or +documentation, if provided along with the Derivative Works; or, +within a display generated by the Derivative Works, if and +wherever such third-party notices normally appear. The contents +of the NOTICE file are for informational purposes only and +do not modify the License. You may add Your own attribution +notices within Derivative Works that You distribute, alongside +or as an addendum to the NOTICE text from the Work, provided +that such additional attribution notices cannot be construed +as modifying the License. + +You may add Your own copyright statement to Your modifications and +may provide additional or different license terms and conditions +for use, reproduction, or distribution of Your modifications, or +for any such Derivative Works as a whole, provided Your use, +reproduction, and distribution of the Work otherwise complies with +the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, +any Contribution intentionally submitted for inclusion in the Work +by You to the Licensor shall be under the terms and conditions of +this License, without any additional terms or conditions. +Notwithstanding the above, nothing herein shall supersede or modify +the terms of any separate license agreement you may have executed +with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade +names, trademarks, service marks, or product names of the Licensor, +except as required for reasonable and customary use in describing the +origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or +agreed to in writing, Licensor provides the Work (and each +Contributor provides its Contributions) on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +implied, including, without limitation, any warranties or conditions +of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A +PARTICULAR PURPOSE. You are solely responsible for determining the +appropriateness of using or redistributing the Work and assume any +risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, +whether in tort (including negligence), contract, or otherwise, +unless required by applicable law (such as deliberate and grossly +negligent acts) or agreed to in writing, shall any Contributor be +liable to You for damages, including any direct, indirect, special, +incidental, or consequential damages of any character arising as a +result of this License or out of the use or inability to use the +Work (including but not limited to damages for loss of goodwill, +work stoppage, computer failure or malfunction, or any and all +other commercial damages or losses), even if such Contributor +has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing +the Work or Derivative Works thereof, You may choose to offer, +and charge a fee for, acceptance of support, warranty, indemnity, +or other liability obligations and/or rights consistent with this +License. However, in accepting such obligations, You may act only +on Your own behalf and on Your sole responsibility, not on behalf +of any other Contributor, and only if You agree to indemnify, +defend, and hold each Contributor harmless for any liability +incurred by, or claims asserted against, such Contributor by reason +of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + +To apply the Apache License to your work, attach the following +boilerplate notice, with the fields enclosed by brackets "[]" +replaced with your own identifying information. (Don't include +the brackets!) The text should be enclosed in the appropriate +comment syntax for the file format. We also recommend that a +file or class name and description of purpose be included on the +same "printed page" as the copyright notice for easier +identification within third-party archives. + +Copyright [yyyy] [name of copyright owner] + + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/dimos/models/Detic/third_party/CenterNet2/README.md b/dimos/models/Detic/third_party/CenterNet2/README.md new file mode 100644 index 0000000000..7ccbf8818f --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/README.md @@ -0,0 +1,81 @@ +# Probabilistic two-stage detection +Two-stage object detectors that use class-agnostic one-stage detectors as the proposal network. + + +

+ +> [**Probabilistic two-stage detection**](http://arxiv.org/abs/2103.07461), +> Xingyi Zhou, Vladlen Koltun, Philipp Krähenbühl, +> *arXiv technical report ([arXiv 2103.07461](http://arxiv.org/abs/2103.07461))* + +Contact: [zhouxy@cs.utexas.edu](mailto:zhouxy@cs.utexas.edu). Any questions or discussions are welcomed! + +## Summary + +- Two-stage CenterNet: First stage estimates object probabilities, second stage conditionally classifies objects. + +- Resulting detector is faster and more accurate than both traditional two-stage detectors (fewer proposals required), and one-stage detectors (lighter first stage head). + +- Our best model achieves 56.4 mAP on COCO test-dev. + +- This repo also includes a detectron2-based CenterNet implementation with better accuracy (42.5 mAP at 70FPS) and a new FPN version of CenterNet (40.2 mAP with Res50_1x). + +## Main results + +All models are trained with multi-scale training, and tested with a single scale. The FPS is tested on a Titan RTX GPU. +More models and details can be found in the [MODEL_ZOO](docs/MODEL_ZOO.md). + +#### COCO + +| Model | COCO val mAP | FPS | +|-------------------------------------------|---------------|-------| +| CenterNet-S4_DLA_8x | 42.5 | 71 | +| CenterNet2_R50_1x | 42.9 | 24 | +| CenterNet2_X101-DCN_2x | 49.9 | 8 | +| CenterNet2_R2-101-DCN-BiFPN_4x+4x_1560_ST | 56.1 | 5 | +| CenterNet2_DLA-BiFPN-P5_24x_ST | 49.2 | 38 | + + +#### LVIS + +| Model | val mAP box | +| ------------------------- | ----------- | +| CenterNet2_R50_1x | 26.5 | +| CenterNet2_FedLoss_R50_1x | 28.3 | + + +#### Objects365 + +| Model | val mAP | +|-------------------------------------------|----------| +| CenterNet2_R50_1x | 22.6 | + +## Installation + +Our project is developed on [detectron2](https://github.com/facebookresearch/detectron2). Please follow the official detectron2 [installation](https://github.com/facebookresearch/detectron2/blob/master/INSTALL.md). + +We use the default detectron2 demo script. To run inference on an image folder using our pre-trained model, run + +~~~ +python demo.py --config-file configs/CenterNet2_R50_1x.yaml --input path/to/image/ --opts MODEL.WEIGHTS models/CenterNet2_R50_1x.pth +~~~ + +## Benchmark evaluation and training + +Please check detectron2 [GETTING_STARTED.md](https://github.com/facebookresearch/detectron2/blob/master/GETTING_STARTED.md) for running evaluation and training. Our config files are under `configs` and the pre-trained models are in the [MODEL_ZOO](docs/MODEL_ZOO.md). + + +## License + +Our code is under [Apache 2.0 license](LICENSE). `centernet/modeling/backbone/bifpn_fcos.py` are from [AdelaiDet](https://github.com/aim-uofa/AdelaiDet), which follows the original [non-commercial license](https://github.com/aim-uofa/AdelaiDet/blob/master/LICENSE). + +## Citation + +If you find this project useful for your research, please use the following BibTeX entry. + + @inproceedings{zhou2021probablistic, + title={Probabilistic two-stage detection}, + author={Zhou, Xingyi and Koltun, Vladlen and Kr{\"a}henb{\"u}hl, Philipp}, + booktitle={arXiv preprint arXiv:2103.07461}, + year={2021} + } diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/__init__.py b/dimos/models/Detic/third_party/CenterNet2/centernet/__init__.py new file mode 100644 index 0000000000..5e2e7afac6 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/__init__.py @@ -0,0 +1,12 @@ +from .data.datasets import nuimages +from .data.datasets.coco import _PREDEFINED_SPLITS_COCO +from .data.datasets.objects365 import categories_v1 +from .modeling.backbone.bifpn import build_resnet_bifpn_backbone +from .modeling.backbone.bifpn_fcos import build_fcos_resnet_bifpn_backbone +from .modeling.backbone.dla import build_dla_backbone +from .modeling.backbone.dlafpn import build_dla_fpn3_backbone +from .modeling.backbone.fpn_p5 import build_p67_resnet_fpn_backbone +from .modeling.backbone.res2net import build_p67_res2net_fpn_backbone +from .modeling.dense_heads.centernet import CenterNet +from .modeling.meta_arch.centernet_detector import CenterNetDetector +from .modeling.roi_heads.custom_roi_heads import CustomCascadeROIHeads, CustomROIHeads diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/config.py b/dimos/models/Detic/third_party/CenterNet2/centernet/config.py new file mode 100644 index 0000000000..255eb36340 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/config.py @@ -0,0 +1,88 @@ +from detectron2.config import CfgNode as CN + + +def add_centernet_config(cfg) -> None: + _C = cfg + + _C.MODEL.CENTERNET = CN() + _C.MODEL.CENTERNET.NUM_CLASSES = 80 + _C.MODEL.CENTERNET.IN_FEATURES = ["p3", "p4", "p5", "p6", "p7"] + _C.MODEL.CENTERNET.FPN_STRIDES = [8, 16, 32, 64, 128] + _C.MODEL.CENTERNET.PRIOR_PROB = 0.01 + _C.MODEL.CENTERNET.INFERENCE_TH = 0.05 + _C.MODEL.CENTERNET.CENTER_NMS = False + _C.MODEL.CENTERNET.NMS_TH_TRAIN = 0.6 + _C.MODEL.CENTERNET.NMS_TH_TEST = 0.6 + _C.MODEL.CENTERNET.PRE_NMS_TOPK_TRAIN = 1000 + _C.MODEL.CENTERNET.POST_NMS_TOPK_TRAIN = 100 + _C.MODEL.CENTERNET.PRE_NMS_TOPK_TEST = 1000 + _C.MODEL.CENTERNET.POST_NMS_TOPK_TEST = 100 + _C.MODEL.CENTERNET.NORM = "GN" + _C.MODEL.CENTERNET.USE_DEFORMABLE = False + _C.MODEL.CENTERNET.NUM_CLS_CONVS = 4 + _C.MODEL.CENTERNET.NUM_BOX_CONVS = 4 + _C.MODEL.CENTERNET.NUM_SHARE_CONVS = 0 + _C.MODEL.CENTERNET.LOC_LOSS_TYPE = "giou" + _C.MODEL.CENTERNET.SIGMOID_CLAMP = 1e-4 + _C.MODEL.CENTERNET.HM_MIN_OVERLAP = 0.8 + _C.MODEL.CENTERNET.MIN_RADIUS = 4 + _C.MODEL.CENTERNET.SOI = [[0, 80], [64, 160], [128, 320], [256, 640], [512, 10000000]] + _C.MODEL.CENTERNET.POS_WEIGHT = 1.0 + _C.MODEL.CENTERNET.NEG_WEIGHT = 1.0 + _C.MODEL.CENTERNET.REG_WEIGHT = 2.0 + _C.MODEL.CENTERNET.HM_FOCAL_BETA = 4 + _C.MODEL.CENTERNET.HM_FOCAL_ALPHA = 0.25 + _C.MODEL.CENTERNET.LOSS_GAMMA = 2.0 + _C.MODEL.CENTERNET.WITH_AGN_HM = False + _C.MODEL.CENTERNET.ONLY_PROPOSAL = False + _C.MODEL.CENTERNET.AS_PROPOSAL = False + _C.MODEL.CENTERNET.IGNORE_HIGH_FP = -1.0 + _C.MODEL.CENTERNET.MORE_POS = False + _C.MODEL.CENTERNET.MORE_POS_THRESH = 0.2 + _C.MODEL.CENTERNET.MORE_POS_TOPK = 9 + _C.MODEL.CENTERNET.NOT_NORM_REG = True + _C.MODEL.CENTERNET.NOT_NMS = False + _C.MODEL.CENTERNET.NO_REDUCE = False + + _C.MODEL.ROI_BOX_HEAD.USE_SIGMOID_CE = False + _C.MODEL.ROI_BOX_HEAD.PRIOR_PROB = 0.01 + _C.MODEL.ROI_BOX_HEAD.USE_EQL_LOSS = False + _C.MODEL.ROI_BOX_HEAD.CAT_FREQ_PATH = "datasets/lvis/lvis_v1_train_cat_info.json" + _C.MODEL.ROI_BOX_HEAD.EQL_FREQ_CAT = 200 + _C.MODEL.ROI_BOX_HEAD.USE_FED_LOSS = False + _C.MODEL.ROI_BOX_HEAD.FED_LOSS_NUM_CAT = 50 + _C.MODEL.ROI_BOX_HEAD.FED_LOSS_FREQ_WEIGHT = 0.5 + _C.MODEL.ROI_BOX_HEAD.MULT_PROPOSAL_SCORE = False + + _C.MODEL.BIFPN = CN() + _C.MODEL.BIFPN.NUM_LEVELS = 5 + _C.MODEL.BIFPN.NUM_BIFPN = 6 + _C.MODEL.BIFPN.NORM = "GN" + _C.MODEL.BIFPN.OUT_CHANNELS = 160 + _C.MODEL.BIFPN.SEPARABLE_CONV = False + + _C.MODEL.DLA = CN() + _C.MODEL.DLA.OUT_FEATURES = ["dla2"] + _C.MODEL.DLA.USE_DLA_UP = True + _C.MODEL.DLA.NUM_LAYERS = 34 + _C.MODEL.DLA.MS_OUTPUT = False + _C.MODEL.DLA.NORM = "BN" + _C.MODEL.DLA.DLAUP_IN_FEATURES = ["dla3", "dla4", "dla5"] + _C.MODEL.DLA.DLAUP_NODE = "conv" + + _C.SOLVER.RESET_ITER = False + _C.SOLVER.TRAIN_ITER = -1 + + _C.INPUT.CUSTOM_AUG = "" + _C.INPUT.TRAIN_SIZE = 640 + _C.INPUT.TEST_SIZE = 640 + _C.INPUT.SCALE_RANGE = (0.1, 2.0) + # 'default' for fixed short/ long edge, 'square' for max size=INPUT.SIZE + _C.INPUT.TEST_INPUT_TYPE = "default" + _C.INPUT.NOT_CLAMP_BOX = False + + _C.DEBUG = False + _C.SAVE_DEBUG = False + _C.SAVE_PTH = False + _C.VIS_THRESH = 0.3 + _C.DEBUG_SHOW_NAME = False diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/data/custom_build_augmentation.py b/dimos/models/Detic/third_party/CenterNet2/centernet/data/custom_build_augmentation.py new file mode 100644 index 0000000000..1bcb7cee66 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/data/custom_build_augmentation.py @@ -0,0 +1,43 @@ +from detectron2.data import transforms as T + +from .transforms.custom_augmentation_impl import EfficientDetResizeCrop + + +def build_custom_augmentation(cfg, is_train: bool): + """ + Create a list of default :class:`Augmentation` from config. + Now it includes resizing and flipping. + + Returns: + list[Augmentation] + """ + if cfg.INPUT.CUSTOM_AUG == "ResizeShortestEdge": + if is_train: + min_size = cfg.INPUT.MIN_SIZE_TRAIN + max_size = cfg.INPUT.MAX_SIZE_TRAIN + sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING + else: + min_size = cfg.INPUT.MIN_SIZE_TEST + max_size = cfg.INPUT.MAX_SIZE_TEST + sample_style = "choice" + augmentation = [T.ResizeShortestEdge(min_size, max_size, sample_style)] + elif cfg.INPUT.CUSTOM_AUG == "EfficientDetResizeCrop": + if is_train: + scale = cfg.INPUT.SCALE_RANGE + size = cfg.INPUT.TRAIN_SIZE + else: + scale = (1, 1) + size = cfg.INPUT.TEST_SIZE + augmentation = [EfficientDetResizeCrop(size, scale)] + else: + assert 0, cfg.INPUT.CUSTOM_AUG + + if is_train: + augmentation.append(T.RandomFlip()) + return augmentation + + +build_custom_transform_gen = build_custom_augmentation +""" +Alias for backward-compatibility. +""" diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/data/custom_dataset_dataloader.py b/dimos/models/Detic/third_party/CenterNet2/centernet/data/custom_dataset_dataloader.py new file mode 100644 index 0000000000..a7cfdd523d --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/data/custom_dataset_dataloader.py @@ -0,0 +1,217 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from collections import defaultdict +import itertools +import logging +from typing import Iterator, Sequence, Optional + +from detectron2.data.build import ( + build_batch_data_loader, + check_metadata_consistency, + filter_images_with_few_keypoints, + filter_images_with_only_crowd_annotations, + get_detection_dataset_dicts, + print_instances_class_histogram, +) +from detectron2.data.catalog import DatasetCatalog, MetadataCatalog +from detectron2.data.common import DatasetFromList, MapDataset +from detectron2.data.samplers import RepeatFactorTrainingSampler, TrainingSampler +from detectron2.utils import comm +import torch +import torch.utils.data +from torch.utils.data.sampler import Sampler + +# from .custom_build_augmentation import build_custom_augmentation + + +def build_custom_train_loader(cfg, mapper=None): + """ + Modified from detectron2.data.build.build_custom_train_loader, but supports + different samplers + """ + source_aware = cfg.DATALOADER.SOURCE_AWARE + if source_aware: + dataset_dicts = get_detection_dataset_dicts_with_source( + cfg.DATASETS.TRAIN, + filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS, + min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE + if cfg.MODEL.KEYPOINT_ON + else 0, + proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None, + ) + sizes = [0 for _ in range(len(cfg.DATASETS.TRAIN))] + for d in dataset_dicts: + sizes[d["dataset_source"]] += 1 + print("dataset sizes", sizes) + else: + dataset_dicts = get_detection_dataset_dicts( + cfg.DATASETS.TRAIN, + filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS, + min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE + if cfg.MODEL.KEYPOINT_ON + else 0, + proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None, + ) + dataset = DatasetFromList(dataset_dicts, copy=False) + + if mapper is None: + assert 0 + # mapper = DatasetMapper(cfg, True) + dataset = MapDataset(dataset, mapper) + + sampler_name = cfg.DATALOADER.SAMPLER_TRAIN + logger = logging.getLogger(__name__) + logger.info(f"Using training sampler {sampler_name}") + # TODO avoid if-else? + if sampler_name == "TrainingSampler": + sampler = TrainingSampler(len(dataset)) + elif sampler_name == "MultiDatasetSampler": + assert source_aware + sampler = MultiDatasetSampler(cfg, sizes, dataset_dicts) + elif sampler_name == "RepeatFactorTrainingSampler": + repeat_factors = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency( + dataset_dicts, cfg.DATALOADER.REPEAT_THRESHOLD + ) + sampler = RepeatFactorTrainingSampler(repeat_factors) + elif sampler_name == "ClassAwareSampler": + sampler = ClassAwareSampler(dataset_dicts) + else: + raise ValueError(f"Unknown training sampler: {sampler_name}") + + return build_batch_data_loader( + dataset, + sampler, + cfg.SOLVER.IMS_PER_BATCH, + aspect_ratio_grouping=cfg.DATALOADER.ASPECT_RATIO_GROUPING, + num_workers=cfg.DATALOADER.NUM_WORKERS, + ) + + +class ClassAwareSampler(Sampler): + def __init__(self, dataset_dicts, seed: int | None = None) -> None: + """ + Args: + size (int): the total number of data of the underlying dataset to sample from + seed (int): the initial seed of the shuffle. Must be the same + across all workers. If None, will use a random seed shared + among workers (require synchronization among all workers). + """ + self._size = len(dataset_dicts) + assert self._size > 0 + if seed is None: + seed = comm.shared_random_seed() + self._seed = int(seed) + + self._rank = comm.get_rank() + self._world_size = comm.get_world_size() + self.weights = self._get_class_balance_factor(dataset_dicts) + + def __iter__(self) -> Iterator: + start = self._rank + yield from itertools.islice(self._infinite_indices(), start, None, self._world_size) + + def _infinite_indices(self): + g = torch.Generator() + g.manual_seed(self._seed) + while True: + ids = torch.multinomial(self.weights, self._size, generator=g, replacement=True) + yield from ids + + def _get_class_balance_factor(self, dataset_dicts, l: float=1.0): + # 1. For each category c, compute the fraction of images that contain it: f(c) + ret = [] + category_freq = defaultdict(int) + for dataset_dict in dataset_dicts: # For each image (without repeats) + cat_ids = {ann["category_id"] for ann in dataset_dict["annotations"]} + for cat_id in cat_ids: + category_freq[cat_id] += 1 + for _i, dataset_dict in enumerate(dataset_dicts): + cat_ids = {ann["category_id"] for ann in dataset_dict["annotations"]} + ret.append(sum([1.0 / (category_freq[cat_id] ** l) for cat_id in cat_ids])) + return torch.tensor(ret).float() + + +def get_detection_dataset_dicts_with_source( + dataset_names: Sequence[str], filter_empty: bool=True, min_keypoints: int=0, proposal_files=None +): + assert len(dataset_names) + dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in dataset_names] + for dataset_name, dicts in zip(dataset_names, dataset_dicts, strict=False): + assert len(dicts), f"Dataset '{dataset_name}' is empty!" + + for source_id, (dataset_name, dicts) in enumerate(zip(dataset_names, dataset_dicts, strict=False)): + assert len(dicts), f"Dataset '{dataset_name}' is empty!" + for d in dicts: + d["dataset_source"] = source_id + + if "annotations" in dicts[0]: + try: + class_names = MetadataCatalog.get(dataset_name).thing_classes + check_metadata_consistency("thing_classes", dataset_name) + print_instances_class_histogram(dicts, class_names) + except AttributeError: # class names are not available for this dataset + pass + + assert proposal_files is None + + dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts)) + + has_instances = "annotations" in dataset_dicts[0] + if filter_empty and has_instances: + dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts) + if min_keypoints > 0 and has_instances: + dataset_dicts = filter_images_with_few_keypoints(dataset_dicts, min_keypoints) + + return dataset_dicts + + +class MultiDatasetSampler(Sampler): + def __init__(self, cfg, sizes: Sequence[int], dataset_dicts, seed: int | None = None) -> None: + """ + Args: + size (int): the total number of data of the underlying dataset to sample from + seed (int): the initial seed of the shuffle. Must be the same + across all workers. If None, will use a random seed shared + among workers (require synchronization among all workers). + """ + self.sizes = sizes + dataset_ratio = cfg.DATALOADER.DATASET_RATIO + self._batch_size = cfg.SOLVER.IMS_PER_BATCH + assert len(dataset_ratio) == len(sizes), ( + f"length of dataset ratio {len(dataset_ratio)} should be equal to number if dataset {len(sizes)}" + ) + if seed is None: + seed = comm.shared_random_seed() + self._seed = int(seed) + self._rank = comm.get_rank() + self._world_size = comm.get_world_size() + + self._ims_per_gpu = self._batch_size // self._world_size + self.dataset_ids = torch.tensor( + [d["dataset_source"] for d in dataset_dicts], dtype=torch.long + ) + + dataset_weight = [ + torch.ones(s) * max(sizes) / s * r / sum(dataset_ratio) + for i, (r, s) in enumerate(zip(dataset_ratio, sizes, strict=False)) + ] + dataset_weight = torch.cat(dataset_weight) + self.weights = dataset_weight + self.sample_epoch_size = len(self.weights) + + def __iter__(self) -> Iterator: + start = self._rank + yield from itertools.islice(self._infinite_indices(), start, None, self._world_size) + + def _infinite_indices(self): + g = torch.Generator() + g.manual_seed(self._seed) + while True: + ids = torch.multinomial( + self.weights, self.sample_epoch_size, generator=g, replacement=True + ) + nums = [(self.dataset_ids[ids] == i).sum().int().item() for i in range(len(self.sizes))] + print("_rank, len, nums", self._rank, len(ids), nums, flush=True) + # print('_rank, len, nums, self.dataset_ids[ids[:10]], ', + # self._rank, len(ids), nums, self.dataset_ids[ids[:10]], + # flush=True) + yield from ids diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/data/datasets/coco.py b/dimos/models/Detic/third_party/CenterNet2/centernet/data/datasets/coco.py new file mode 100644 index 0000000000..33ff5a6980 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/data/datasets/coco.py @@ -0,0 +1,53 @@ +import os + +from detectron2.data import DatasetCatalog, MetadataCatalog +from detectron2.data.datasets.builtin_meta import _get_builtin_metadata +from detectron2.data.datasets.coco import load_coco_json +from detectron2.data.datasets.register_coco import register_coco_instances + + +def register_distill_coco_instances(name: str, metadata, json_file, image_root) -> None: + """ + add extra_annotation_keys + """ + assert isinstance(name, str), name + assert isinstance(json_file, str | os.PathLike), json_file + assert isinstance(image_root, str | os.PathLike), image_root + # 1. register a function which returns dicts + DatasetCatalog.register( + name, lambda: load_coco_json(json_file, image_root, name, extra_annotation_keys=["score"]) + ) + + # 2. Optionally, add metadata about this dataset, + # since they might be useful in evaluation, visualization or logging + MetadataCatalog.get(name).set( + json_file=json_file, image_root=image_root, evaluator_type="coco", **metadata + ) + + +_PREDEFINED_SPLITS_COCO = { + "coco_2017_unlabeled": ("coco/unlabeled2017", "coco/annotations/image_info_unlabeled2017.json"), +} + +for key, (image_root, json_file) in _PREDEFINED_SPLITS_COCO.items(): + register_coco_instances( + key, + _get_builtin_metadata("coco"), + os.path.join("datasets", json_file) if "://" not in json_file else json_file, + os.path.join("datasets", image_root), + ) + +_PREDEFINED_SPLITS_DISTILL_COCO = { + "coco_un_yolov4_55_0.5": ( + "coco/unlabeled2017", + "coco/annotations/yolov4_cocounlabeled_55_ann0.5.json", + ), +} + +for key, (image_root, json_file) in _PREDEFINED_SPLITS_DISTILL_COCO.items(): + register_distill_coco_instances( + key, + _get_builtin_metadata("coco"), + os.path.join("datasets", json_file) if "://" not in json_file else json_file, + os.path.join("datasets", image_root), + ) diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/data/datasets/nuimages.py b/dimos/models/Detic/third_party/CenterNet2/centernet/data/datasets/nuimages.py new file mode 100644 index 0000000000..fdcd40242f --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/data/datasets/nuimages.py @@ -0,0 +1,41 @@ +import os + +from detectron2.data.datasets.register_coco import register_coco_instances + +categories = [ + {"id": 0, "name": "car"}, + {"id": 1, "name": "truck"}, + {"id": 2, "name": "trailer"}, + {"id": 3, "name": "bus"}, + {"id": 4, "name": "construction_vehicle"}, + {"id": 5, "name": "bicycle"}, + {"id": 6, "name": "motorcycle"}, + {"id": 7, "name": "pedestrian"}, + {"id": 8, "name": "traffic_cone"}, + {"id": 9, "name": "barrier"}, +] + + +def _get_builtin_metadata(): + id_to_name = {x["id"]: x["name"] for x in categories} + thing_dataset_id_to_contiguous_id = {i: i for i in range(len(categories))} + thing_classes = [id_to_name[k] for k in sorted(id_to_name)] + return { + "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id, + "thing_classes": thing_classes, + } + + +_PREDEFINED_SPLITS = { + "nuimages_train": ("nuimages", "nuimages/annotations/nuimages_v1.0-train.json"), + "nuimages_val": ("nuimages", "nuimages/annotations/nuimages_v1.0-val.json"), + "nuimages_mini": ("nuimages", "nuimages/annotations/nuimages_v1.0-mini.json"), +} + +for key, (image_root, json_file) in _PREDEFINED_SPLITS.items(): + register_coco_instances( + key, + _get_builtin_metadata(), + os.path.join("datasets", json_file) if "://" not in json_file else json_file, + os.path.join("datasets", image_root), + ) diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/data/datasets/objects365.py b/dimos/models/Detic/third_party/CenterNet2/centernet/data/datasets/objects365.py new file mode 100644 index 0000000000..e3e8383a91 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/data/datasets/objects365.py @@ -0,0 +1,398 @@ +import os + +from detectron2.data.datasets.register_coco import register_coco_instances + +categories_v1 = [ + {"id": 164, "name": "cutting/chopping board"}, + {"id": 49, "name": "tie"}, + {"id": 306, "name": "crosswalk sign"}, + {"id": 145, "name": "gun"}, + {"id": 14, "name": "street lights"}, + {"id": 223, "name": "bar soap"}, + {"id": 74, "name": "wild bird"}, + {"id": 219, "name": "ice cream"}, + {"id": 37, "name": "stool"}, + {"id": 25, "name": "storage box"}, + {"id": 153, "name": "giraffe"}, + {"id": 52, "name": "pen/pencil"}, + {"id": 61, "name": "high heels"}, + {"id": 340, "name": "mangosteen"}, + {"id": 22, "name": "bracelet"}, + {"id": 155, "name": "piano"}, + {"id": 162, "name": "vent"}, + {"id": 75, "name": "laptop"}, + {"id": 236, "name": "toaster"}, + {"id": 231, "name": "fire truck"}, + {"id": 42, "name": "basket"}, + {"id": 150, "name": "zebra"}, + {"id": 124, "name": "head phone"}, + {"id": 90, "name": "sheep"}, + {"id": 322, "name": "steak"}, + {"id": 39, "name": "couch"}, + {"id": 209, "name": "toothbrush"}, + {"id": 59, "name": "bicycle"}, + {"id": 336, "name": "red cabbage"}, + {"id": 228, "name": "golf ball"}, + {"id": 120, "name": "tomato"}, + {"id": 132, "name": "computer box"}, + {"id": 8, "name": "cup"}, + {"id": 183, "name": "basketball"}, + {"id": 298, "name": "butterfly"}, + {"id": 250, "name": "garlic"}, + {"id": 12, "name": "desk"}, + {"id": 141, "name": "microwave"}, + {"id": 171, "name": "strawberry"}, + {"id": 200, "name": "kettle"}, + {"id": 63, "name": "van"}, + {"id": 300, "name": "cheese"}, + {"id": 215, "name": "marker"}, + {"id": 100, "name": "blackboard/whiteboard"}, + {"id": 186, "name": "printer"}, + {"id": 333, "name": "bread/bun"}, + {"id": 243, "name": "penguin"}, + {"id": 364, "name": "iron"}, + {"id": 180, "name": "ladder"}, + {"id": 34, "name": "flag"}, + {"id": 78, "name": "cell phone"}, + {"id": 97, "name": "fan"}, + {"id": 224, "name": "scale"}, + {"id": 151, "name": "duck"}, + {"id": 319, "name": "flute"}, + {"id": 156, "name": "stop sign"}, + {"id": 290, "name": "rickshaw"}, + {"id": 128, "name": "sailboat"}, + {"id": 165, "name": "tennis racket"}, + {"id": 241, "name": "cigar"}, + {"id": 101, "name": "balloon"}, + {"id": 308, "name": "hair drier"}, + {"id": 167, "name": "skating and skiing shoes"}, + {"id": 237, "name": "helicopter"}, + {"id": 65, "name": "sink"}, + {"id": 129, "name": "tangerine"}, + {"id": 330, "name": "crab"}, + {"id": 320, "name": "measuring cup"}, + {"id": 260, "name": "fishing rod"}, + {"id": 346, "name": "saw"}, + {"id": 216, "name": "ship"}, + {"id": 46, "name": "coffee table"}, + {"id": 194, "name": "facial mask"}, + {"id": 281, "name": "stapler"}, + {"id": 118, "name": "refrigerator"}, + {"id": 40, "name": "belt"}, + {"id": 349, "name": "starfish"}, + {"id": 87, "name": "hanger"}, + {"id": 116, "name": "baseball glove"}, + {"id": 261, "name": "cherry"}, + {"id": 334, "name": "baozi"}, + {"id": 267, "name": "screwdriver"}, + {"id": 158, "name": "converter"}, + {"id": 335, "name": "lion"}, + {"id": 170, "name": "baseball"}, + {"id": 111, "name": "skis"}, + {"id": 136, "name": "broccoli"}, + {"id": 342, "name": "eraser"}, + {"id": 337, "name": "polar bear"}, + {"id": 139, "name": "shovel"}, + {"id": 193, "name": "extension cord"}, + {"id": 284, "name": "goldfish"}, + {"id": 174, "name": "pepper"}, + {"id": 138, "name": "stroller"}, + {"id": 328, "name": "yak"}, + {"id": 83, "name": "clock"}, + {"id": 235, "name": "tricycle"}, + {"id": 248, "name": "parking meter"}, + {"id": 274, "name": "trophy"}, + {"id": 324, "name": "binoculars"}, + {"id": 51, "name": "traffic light"}, + {"id": 314, "name": "donkey"}, + {"id": 45, "name": "barrel/bucket"}, + {"id": 292, "name": "pomegranate"}, + {"id": 13, "name": "handbag"}, + {"id": 262, "name": "tablet"}, + {"id": 68, "name": "apple"}, + {"id": 226, "name": "cabbage"}, + {"id": 23, "name": "flower"}, + {"id": 58, "name": "faucet"}, + {"id": 206, "name": "tong"}, + {"id": 291, "name": "trombone"}, + {"id": 160, "name": "carrot"}, + {"id": 172, "name": "bow tie"}, + {"id": 122, "name": "tent"}, + {"id": 163, "name": "cookies"}, + {"id": 115, "name": "remote"}, + {"id": 175, "name": "coffee machine"}, + {"id": 238, "name": "green beans"}, + {"id": 233, "name": "cello"}, + {"id": 28, "name": "wine glass"}, + {"id": 295, "name": "mushroom"}, + {"id": 344, "name": "scallop"}, + {"id": 125, "name": "lantern"}, + {"id": 123, "name": "shampoo/shower gel"}, + {"id": 285, "name": "meat balls"}, + {"id": 266, "name": "key"}, + {"id": 296, "name": "calculator"}, + {"id": 168, "name": "scissors"}, + {"id": 103, "name": "cymbal"}, + {"id": 6, "name": "bottle"}, + {"id": 264, "name": "nuts"}, + {"id": 234, "name": "notepaper"}, + {"id": 211, "name": "mango"}, + {"id": 287, "name": "toothpaste"}, + {"id": 196, "name": "chopsticks"}, + {"id": 140, "name": "baseball bat"}, + {"id": 244, "name": "hurdle"}, + {"id": 195, "name": "tennis ball"}, + {"id": 144, "name": "surveillance camera"}, + {"id": 271, "name": "volleyball"}, + {"id": 94, "name": "keyboard"}, + {"id": 339, "name": "seal"}, + {"id": 11, "name": "picture/frame"}, + {"id": 348, "name": "okra"}, + {"id": 191, "name": "sausage"}, + {"id": 166, "name": "candy"}, + {"id": 62, "name": "ring"}, + {"id": 311, "name": "dolphin"}, + {"id": 273, "name": "eggplant"}, + {"id": 84, "name": "drum"}, + {"id": 143, "name": "surfboard"}, + {"id": 288, "name": "antelope"}, + {"id": 204, "name": "clutch"}, + {"id": 207, "name": "slide"}, + {"id": 43, "name": "towel/napkin"}, + {"id": 352, "name": "durian"}, + {"id": 276, "name": "board eraser"}, + {"id": 315, "name": "electric drill"}, + {"id": 312, "name": "sushi"}, + {"id": 198, "name": "pie"}, + {"id": 106, "name": "pickup truck"}, + {"id": 176, "name": "bathtub"}, + {"id": 26, "name": "vase"}, + {"id": 133, "name": "elephant"}, + {"id": 256, "name": "sandwich"}, + {"id": 327, "name": "noodles"}, + {"id": 10, "name": "glasses"}, + {"id": 109, "name": "airplane"}, + {"id": 95, "name": "tripod"}, + {"id": 247, "name": "CD"}, + {"id": 121, "name": "machinery vehicle"}, + {"id": 365, "name": "flashlight"}, + {"id": 53, "name": "microphone"}, + {"id": 270, "name": "pliers"}, + {"id": 362, "name": "chainsaw"}, + {"id": 259, "name": "bear"}, + {"id": 197, "name": "electronic stove and gas stove"}, + {"id": 89, "name": "pot/pan"}, + {"id": 220, "name": "tape"}, + {"id": 338, "name": "lighter"}, + {"id": 177, "name": "snowboard"}, + {"id": 214, "name": "violin"}, + {"id": 217, "name": "chicken"}, + {"id": 2, "name": "sneakers"}, + {"id": 161, "name": "washing machine"}, + {"id": 131, "name": "kite"}, + {"id": 354, "name": "rabbit"}, + {"id": 86, "name": "bus"}, + {"id": 275, "name": "dates"}, + {"id": 282, "name": "camel"}, + {"id": 88, "name": "nightstand"}, + {"id": 179, "name": "grapes"}, + {"id": 229, "name": "pine apple"}, + {"id": 56, "name": "necklace"}, + {"id": 18, "name": "leather shoes"}, + {"id": 358, "name": "hoverboard"}, + {"id": 345, "name": "pencil case"}, + {"id": 359, "name": "pasta"}, + {"id": 157, "name": "radiator"}, + {"id": 201, "name": "hamburger"}, + {"id": 268, "name": "globe"}, + {"id": 332, "name": "barbell"}, + {"id": 329, "name": "mop"}, + {"id": 252, "name": "horn"}, + {"id": 350, "name": "eagle"}, + {"id": 169, "name": "folder"}, + {"id": 137, "name": "toilet"}, + {"id": 5, "name": "lamp"}, + {"id": 27, "name": "bench"}, + {"id": 249, "name": "swan"}, + {"id": 76, "name": "knife"}, + {"id": 341, "name": "comb"}, + {"id": 64, "name": "watch"}, + {"id": 105, "name": "telephone"}, + {"id": 3, "name": "chair"}, + {"id": 33, "name": "boat"}, + {"id": 107, "name": "orange"}, + {"id": 60, "name": "bread"}, + {"id": 147, "name": "cat"}, + {"id": 135, "name": "gas stove"}, + {"id": 307, "name": "papaya"}, + {"id": 227, "name": "router/modem"}, + {"id": 357, "name": "asparagus"}, + {"id": 73, "name": "motorcycle"}, + {"id": 77, "name": "traffic sign"}, + {"id": 67, "name": "fish"}, + {"id": 326, "name": "radish"}, + {"id": 213, "name": "egg"}, + {"id": 203, "name": "cucumber"}, + {"id": 17, "name": "helmet"}, + {"id": 110, "name": "luggage"}, + {"id": 80, "name": "truck"}, + {"id": 199, "name": "frisbee"}, + {"id": 232, "name": "peach"}, + {"id": 1, "name": "person"}, + {"id": 29, "name": "boots"}, + {"id": 310, "name": "chips"}, + {"id": 142, "name": "skateboard"}, + {"id": 44, "name": "slippers"}, + {"id": 4, "name": "hat"}, + {"id": 178, "name": "suitcase"}, + {"id": 24, "name": "tv"}, + {"id": 119, "name": "train"}, + {"id": 82, "name": "power outlet"}, + {"id": 245, "name": "swing"}, + {"id": 15, "name": "book"}, + {"id": 294, "name": "jellyfish"}, + {"id": 192, "name": "fire extinguisher"}, + {"id": 212, "name": "deer"}, + {"id": 181, "name": "pear"}, + {"id": 347, "name": "table tennis paddle"}, + {"id": 113, "name": "trolley"}, + {"id": 91, "name": "guitar"}, + {"id": 202, "name": "golf club"}, + {"id": 221, "name": "wheelchair"}, + {"id": 254, "name": "saxophone"}, + {"id": 117, "name": "paper towel"}, + {"id": 303, "name": "race car"}, + {"id": 240, "name": "carriage"}, + {"id": 246, "name": "radio"}, + {"id": 318, "name": "parrot"}, + {"id": 251, "name": "french fries"}, + {"id": 98, "name": "dog"}, + {"id": 112, "name": "soccer"}, + {"id": 355, "name": "french horn"}, + {"id": 79, "name": "paddle"}, + {"id": 283, "name": "lettuce"}, + {"id": 9, "name": "car"}, + {"id": 258, "name": "kiwi fruit"}, + {"id": 325, "name": "llama"}, + {"id": 187, "name": "billiards"}, + {"id": 210, "name": "facial cleanser"}, + {"id": 81, "name": "cow"}, + {"id": 331, "name": "microscope"}, + {"id": 148, "name": "lemon"}, + {"id": 302, "name": "pomelo"}, + {"id": 85, "name": "fork"}, + {"id": 154, "name": "pumpkin"}, + {"id": 289, "name": "shrimp"}, + {"id": 71, "name": "teddy bear"}, + {"id": 184, "name": "potato"}, + {"id": 102, "name": "air conditioner"}, + {"id": 208, "name": "hot dog"}, + {"id": 222, "name": "plum"}, + {"id": 316, "name": "spring rolls"}, + {"id": 230, "name": "crane"}, + {"id": 149, "name": "liquid soap"}, + {"id": 55, "name": "canned"}, + {"id": 35, "name": "speaker"}, + {"id": 108, "name": "banana"}, + {"id": 297, "name": "treadmill"}, + {"id": 99, "name": "spoon"}, + {"id": 104, "name": "mouse"}, + {"id": 182, "name": "american football"}, + {"id": 299, "name": "egg tart"}, + {"id": 127, "name": "cleaning products"}, + {"id": 313, "name": "urinal"}, + {"id": 286, "name": "medal"}, + {"id": 239, "name": "brush"}, + {"id": 96, "name": "hockey"}, + {"id": 279, "name": "dumbbell"}, + {"id": 32, "name": "umbrella"}, + {"id": 272, "name": "hammer"}, + {"id": 16, "name": "plate"}, + {"id": 21, "name": "potted plant"}, + {"id": 242, "name": "earphone"}, + {"id": 70, "name": "candle"}, + {"id": 185, "name": "paint brush"}, + {"id": 48, "name": "toy"}, + {"id": 130, "name": "pizza"}, + {"id": 255, "name": "trumpet"}, + {"id": 361, "name": "hotair balloon"}, + {"id": 188, "name": "fire hydrant"}, + {"id": 50, "name": "bed"}, + {"id": 253, "name": "avocado"}, + {"id": 293, "name": "coconut"}, + {"id": 257, "name": "cue"}, + {"id": 280, "name": "hamimelon"}, + {"id": 66, "name": "horse"}, + {"id": 173, "name": "pigeon"}, + {"id": 190, "name": "projector"}, + {"id": 69, "name": "camera"}, + {"id": 30, "name": "bowl"}, + {"id": 269, "name": "broom"}, + {"id": 343, "name": "pitaya"}, + {"id": 305, "name": "tuba"}, + {"id": 309, "name": "green onion"}, + {"id": 363, "name": "lobster"}, + {"id": 225, "name": "watermelon"}, + {"id": 47, "name": "suv"}, + {"id": 31, "name": "dining table"}, + {"id": 54, "name": "sandals"}, + {"id": 351, "name": "monkey"}, + {"id": 218, "name": "onion"}, + {"id": 36, "name": "trash bin/can"}, + {"id": 20, "name": "glove"}, + {"id": 277, "name": "rice"}, + {"id": 152, "name": "sports car"}, + {"id": 360, "name": "target"}, + {"id": 205, "name": "blender"}, + {"id": 19, "name": "pillow"}, + {"id": 72, "name": "cake"}, + {"id": 93, "name": "tea pot"}, + {"id": 353, "name": "game board"}, + {"id": 38, "name": "backpack"}, + {"id": 356, "name": "ambulance"}, + {"id": 146, "name": "life saver"}, + {"id": 189, "name": "goose"}, + {"id": 278, "name": "tape measure/ruler"}, + {"id": 92, "name": "traffic cone"}, + {"id": 134, "name": "toiletries"}, + {"id": 114, "name": "oven"}, + {"id": 317, "name": "tortoise/turtle"}, + {"id": 265, "name": "corn"}, + {"id": 126, "name": "donut"}, + {"id": 57, "name": "mirror"}, + {"id": 7, "name": "cabinet/shelf"}, + {"id": 263, "name": "green vegetables"}, + {"id": 159, "name": "tissue "}, + {"id": 321, "name": "shark"}, + {"id": 301, "name": "pig"}, + {"id": 41, "name": "carpet"}, + {"id": 304, "name": "rice cooker"}, + {"id": 323, "name": "poker card"}, +] + + +def _get_builtin_metadata(version): + if version == "v1": + id_to_name = {x["id"]: x["name"] for x in categories_v1} + else: + assert 0, version + thing_dataset_id_to_contiguous_id = {i + 1: i for i in range(365)} + thing_classes = [id_to_name[k] for k in sorted(id_to_name)] + return { + "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id, + "thing_classes": thing_classes, + } + + +_PREDEFINED_SPLITS_OBJECTS365 = { + "objects365_train": ("objects365/train", "objects365/annotations/objects365_train.json"), + "objects365_val": ("objects365/val", "objects365/annotations/objects365_val.json"), +} + +for key, (image_root, json_file) in _PREDEFINED_SPLITS_OBJECTS365.items(): + register_coco_instances( + key, + _get_builtin_metadata("v1"), + os.path.join("datasets", json_file) if "://" not in json_file else json_file, + os.path.join("datasets", image_root), + ) diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/data/transforms/custom_augmentation_impl.py b/dimos/models/Detic/third_party/CenterNet2/centernet/data/transforms/custom_augmentation_impl.py new file mode 100644 index 0000000000..f4ec0ad07f --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/data/transforms/custom_augmentation_impl.py @@ -0,0 +1,53 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# Modified by Xingyi Zhou +""" +Implement many useful :class:`Augmentation`. +""" + +from detectron2.data.transforms.augmentation import Augmentation +import numpy as np +from PIL import Image + +from .custom_transform import EfficientDetResizeCropTransform + +__all__ = [ + "EfficientDetResizeCrop", +] + + +class EfficientDetResizeCrop(Augmentation): + """ + Scale the shorter edge to the given size, with a limit of `max_size` on the longer edge. + If `max_size` is reached, then downscale so that the longer edge does not exceed max_size. + """ + + def __init__(self, size: int, scale, interp=Image.BILINEAR) -> None: + """ + Args: + """ + super().__init__() + self.target_size = (size, size) + self.scale = scale + self.interp = interp + + def get_transform(self, img): + # Select a random scale factor. + scale_factor = np.random.uniform(*self.scale) + scaled_target_height = scale_factor * self.target_size[0] + scaled_target_width = scale_factor * self.target_size[1] + # Recompute the accurate scale_factor using rounded scaled image size. + width, height = img.shape[1], img.shape[0] + img_scale_y = scaled_target_height / height + img_scale_x = scaled_target_width / width + img_scale = min(img_scale_y, img_scale_x) + + # Select non-zero random offset (x, y) if scaled image is larger than target size + scaled_h = int(height * img_scale) + scaled_w = int(width * img_scale) + offset_y = scaled_h - self.target_size[0] + offset_x = scaled_w - self.target_size[1] + offset_y = int(max(0.0, float(offset_y)) * np.random.uniform(0, 1)) + offset_x = int(max(0.0, float(offset_x)) * np.random.uniform(0, 1)) + return EfficientDetResizeCropTransform( + scaled_h, scaled_w, offset_y, offset_x, img_scale, self.target_size, self.interp + ) diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/data/transforms/custom_transform.py b/dimos/models/Detic/third_party/CenterNet2/centernet/data/transforms/custom_transform.py new file mode 100644 index 0000000000..6635a5999b --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/data/transforms/custom_transform.py @@ -0,0 +1,88 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# Modified by Xingyi Zhou +# File: transform.py + +from fvcore.transforms.transform import ( + Transform, +) +import numpy as np +from PIL import Image +import torch +import torch.nn.functional as F + +try: + import cv2 +except ImportError: + # OpenCV is an optional dependency at the moment + pass + +__all__ = [ + "EfficientDetResizeCropTransform", +] + + +class EfficientDetResizeCropTransform(Transform): + """ """ + + def __init__(self, scaled_h, scaled_w, offset_y, offset_x, img_scale, target_size: int, interp=None) -> None: + """ + Args: + h, w (int): original image size + new_h, new_w (int): new image size + interp: PIL interpolation methods, defaults to bilinear. + """ + # TODO decide on PIL vs opencv + super().__init__() + if interp is None: + interp = Image.BILINEAR + self._set_attributes(locals()) + + def apply_image(self, img, interp=None): + # assert img.shape[:2] == (self.h, self.w) + assert len(img.shape) <= 4 + + if img.dtype == np.uint8: + pil_image = Image.fromarray(img) + interp_method = interp if interp is not None else self.interp + pil_image = pil_image.resize((self.scaled_w, self.scaled_h), interp_method) + ret = np.asarray(pil_image) + right = min(self.scaled_w, self.offset_x + self.target_size[1]) + lower = min(self.scaled_h, self.offset_y + self.target_size[0]) + # img = img.crop((self.offset_x, self.offset_y, right, lower)) + if len(ret.shape) <= 3: + ret = ret[self.offset_y : lower, self.offset_x : right] + else: + ret = ret[..., self.offset_y : lower, self.offset_x : right, :] + else: + # PIL only supports uint8 + img = torch.from_numpy(img) + shape = list(img.shape) + shape_4d = shape[:2] + [1] * (4 - len(shape)) + shape[2:] + img = img.view(shape_4d).permute(2, 3, 0, 1) # hw(c) -> nchw + _PIL_RESIZE_TO_INTERPOLATE_MODE = {Image.BILINEAR: "bilinear", Image.BICUBIC: "bicubic"} + mode = _PIL_RESIZE_TO_INTERPOLATE_MODE[self.interp] + img = F.interpolate(img, (self.scaled_h, self.scaled_w), mode=mode, align_corners=False) + shape[:2] = (self.scaled_h, self.scaled_w) + ret = img.permute(2, 3, 0, 1).view(shape).numpy() # nchw -> hw(c) + right = min(self.scaled_w, self.offset_x + self.target_size[1]) + lower = min(self.scaled_h, self.offset_y + self.target_size[0]) + if len(ret.shape) <= 3: + ret = ret[self.offset_y : lower, self.offset_x : right] + else: + ret = ret[..., self.offset_y : lower, self.offset_x : right, :] + return ret + + def apply_coords(self, coords): + coords[:, 0] = coords[:, 0] * self.img_scale + coords[:, 1] = coords[:, 1] * self.img_scale + coords[:, 0] -= self.offset_x + coords[:, 1] -= self.offset_y + return coords + + def apply_segmentation(self, segmentation): + segmentation = self.apply_image(segmentation, interp=Image.NEAREST) + return segmentation + + def inverse(self): + raise NotImplementedError + # return ResizeTransform(self.new_h, self.new_w, self.h, self.w, self.interp) diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/backbone/bifpn.py b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/backbone/bifpn.py new file mode 100644 index 0000000000..733b502da4 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/backbone/bifpn.py @@ -0,0 +1,527 @@ +# Modified from https://github.com/rwightman/efficientdet-pytorch/blob/master/effdet/efficientdet.py +# The original file is under Apache-2.0 License +from collections import OrderedDict +import math + +from detectron2.layers import Conv2d, ShapeSpec +from detectron2.layers.batch_norm import get_norm +from detectron2.modeling.backbone import Backbone +from detectron2.modeling.backbone.build import BACKBONE_REGISTRY +from detectron2.modeling.backbone.resnet import build_resnet_backbone +import torch +from torch import nn + +from .dlafpn import dla34 + + +def get_fpn_config(base_reduction: int=8): + """BiFPN config with sum.""" + p = { + "nodes": [ + {"reduction": base_reduction << 3, "inputs_offsets": [3, 4]}, + {"reduction": base_reduction << 2, "inputs_offsets": [2, 5]}, + {"reduction": base_reduction << 1, "inputs_offsets": [1, 6]}, + {"reduction": base_reduction, "inputs_offsets": [0, 7]}, + {"reduction": base_reduction << 1, "inputs_offsets": [1, 7, 8]}, + {"reduction": base_reduction << 2, "inputs_offsets": [2, 6, 9]}, + {"reduction": base_reduction << 3, "inputs_offsets": [3, 5, 10]}, + {"reduction": base_reduction << 4, "inputs_offsets": [4, 11]}, + ], + "weight_method": "fastattn", + } + return p + + +def swish(x, inplace: bool = False): + """Swish - Described in: https://arxiv.org/abs/1710.05941""" + return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid()) + + +class Swish(nn.Module): + def __init__(self, inplace: bool = False) -> None: + super().__init__() + self.inplace = inplace + + def forward(self, x): + return swish(x, self.inplace) + + +class SequentialAppend(nn.Sequential): + def __init__(self, *args) -> None: + super().__init__(*args) + + def forward(self, x): + for module in self: + x.append(module(x)) + return x + + +class SequentialAppendLast(nn.Sequential): + def __init__(self, *args) -> None: + super().__init__(*args) + + # def forward(self, x: List[torch.Tensor]): + def forward(self, x): + for module in self: + x.append(module(x[-1])) + return x + + +class ConvBnAct2d(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size: int, + stride: int=1, + dilation: int=1, + padding: str="", + bias: bool=False, + norm: str="", + act_layer=Swish, + ) -> None: + super().__init__() + # self.conv = create_conv2d( + # in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias) + self.conv = Conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=kernel_size // 2, + bias=(norm == ""), + ) + self.bn = get_norm(norm, out_channels) + self.act = None if act_layer is None else act_layer(inplace=True) + + def forward(self, x): + x = self.conv(x) + if self.bn is not None: + x = self.bn(x) + if self.act is not None: + x = self.act(x) + return x + + +class SeparableConv2d(nn.Module): + """Separable Conv""" + + def __init__( + self, + in_channels, + out_channels, + kernel_size: int=3, + stride: int=1, + dilation: int=1, + padding: str="", + bias: bool=False, + channel_multiplier: float=1.0, + pw_kernel_size: int=1, + act_layer=Swish, + norm: str="", + ) -> None: + super().__init__() + + # self.conv_dw = create_conv2d( + # in_channels, int(in_channels * channel_multiplier), kernel_size, + # stride=stride, dilation=dilation, padding=padding, depthwise=True) + + self.conv_dw = Conv2d( + in_channels, + int(in_channels * channel_multiplier), + kernel_size=kernel_size, + stride=stride, + padding=kernel_size // 2, + bias=bias, + groups=out_channels, + ) + # print('conv_dw', kernel_size, stride) + # self.conv_pw = create_conv2d( + # int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) + + self.conv_pw = Conv2d( + int(in_channels * channel_multiplier), + out_channels, + kernel_size=pw_kernel_size, + padding=pw_kernel_size // 2, + bias=(norm == ""), + ) + # print('conv_pw', pw_kernel_size) + + self.bn = get_norm(norm, out_channels) + self.act = None if act_layer is None else act_layer(inplace=True) + + def forward(self, x): + x = self.conv_dw(x) + x = self.conv_pw(x) + if self.bn is not None: + x = self.bn(x) + if self.act is not None: + x = self.act(x) + return x + + +class ResampleFeatureMap(nn.Sequential): + def __init__( + self, + in_channels, + out_channels, + reduction_ratio: float=1.0, + pad_type: str="", + pooling_type: str="max", + norm: str="", + apply_bn: bool=False, + conv_after_downsample: bool=False, + redundant_bias: bool=False, + ) -> None: + super().__init__() + pooling_type = pooling_type or "max" + self.in_channels = in_channels + self.out_channels = out_channels + self.reduction_ratio = reduction_ratio + self.conv_after_downsample = conv_after_downsample + + conv = None + if in_channels != out_channels: + conv = ConvBnAct2d( + in_channels, + out_channels, + kernel_size=1, + padding=pad_type, + norm=norm if apply_bn else "", + bias=not apply_bn or redundant_bias, + act_layer=None, + ) + + if reduction_ratio > 1: + stride_size = int(reduction_ratio) + if conv is not None and not self.conv_after_downsample: + self.add_module("conv", conv) + self.add_module( + "downsample", + # create_pool2d( + # pooling_type, kernel_size=stride_size + 1, stride=stride_size, padding=pad_type) + # nn.MaxPool2d(kernel_size=stride_size + 1, stride=stride_size, padding=pad_type) + nn.MaxPool2d(kernel_size=stride_size, stride=stride_size), + ) + if conv is not None and self.conv_after_downsample: + self.add_module("conv", conv) + else: + if conv is not None: + self.add_module("conv", conv) + if reduction_ratio < 1: + scale = int(1 // reduction_ratio) + self.add_module("upsample", nn.UpsamplingNearest2d(scale_factor=scale)) + + +class FpnCombine(nn.Module): + def __init__( + self, + feature_info, + fpn_config, + fpn_channels, + inputs_offsets, + target_reduction, + pad_type: str="", + pooling_type: str="max", + norm: str="", + apply_bn_for_resampling: bool=False, + conv_after_downsample: bool=False, + redundant_bias: bool=False, + weight_method: str="attn", + ) -> None: + super().__init__() + self.inputs_offsets = inputs_offsets + self.weight_method = weight_method + + self.resample = nn.ModuleDict() + for _idx, offset in enumerate(inputs_offsets): + in_channels = fpn_channels + if offset < len(feature_info): + in_channels = feature_info[offset]["num_chs"] + input_reduction = feature_info[offset]["reduction"] + else: + node_idx = offset - len(feature_info) + # print('node_idx, len', node_idx, len(fpn_config['nodes'])) + input_reduction = fpn_config["nodes"][node_idx]["reduction"] + reduction_ratio = target_reduction / input_reduction + self.resample[str(offset)] = ResampleFeatureMap( + in_channels, + fpn_channels, + reduction_ratio=reduction_ratio, + pad_type=pad_type, + pooling_type=pooling_type, + norm=norm, + apply_bn=apply_bn_for_resampling, + conv_after_downsample=conv_after_downsample, + redundant_bias=redundant_bias, + ) + + if weight_method == "attn" or weight_method == "fastattn": + # WSM + self.edge_weights = nn.Parameter(torch.ones(len(inputs_offsets)), requires_grad=True) + else: + self.edge_weights = None + + def forward(self, x): + dtype = x[0].dtype + nodes = [] + for offset in self.inputs_offsets: + input_node = x[offset] + input_node = self.resample[str(offset)](input_node) + nodes.append(input_node) + + if self.weight_method == "attn": + normalized_weights = torch.softmax(self.edge_weights.type(dtype), dim=0) + x = torch.stack(nodes, dim=-1) * normalized_weights + elif self.weight_method == "fastattn": + edge_weights = nn.functional.relu(self.edge_weights.type(dtype)) + weights_sum = torch.sum(edge_weights) + x = torch.stack( + [(nodes[i] * edge_weights[i]) / (weights_sum + 0.0001) for i in range(len(nodes))], + dim=-1, + ) + elif self.weight_method == "sum": + x = torch.stack(nodes, dim=-1) + else: + raise ValueError(f"unknown weight_method {self.weight_method}") + x = torch.sum(x, dim=-1) + return x + + +class BiFpnLayer(nn.Module): + def __init__( + self, + feature_info, + fpn_config, + fpn_channels, + num_levels: int=5, + pad_type: str="", + pooling_type: str="max", + norm: str="", + act_layer=Swish, + apply_bn_for_resampling: bool=False, + conv_after_downsample: bool=True, + conv_bn_relu_pattern: bool=False, + separable_conv: bool=True, + redundant_bias: bool=False, + ) -> None: + super().__init__() + self.fpn_config = fpn_config + self.num_levels = num_levels + self.conv_bn_relu_pattern = False + + self.feature_info = [] + self.fnode = SequentialAppend() + for i, fnode_cfg in enumerate(fpn_config["nodes"]): + # logging.debug('fnode {} : {}'.format(i, fnode_cfg)) + # print('fnode {} : {}'.format(i, fnode_cfg)) + fnode_layers = OrderedDict() + + # combine features + reduction = fnode_cfg["reduction"] + fnode_layers["combine"] = FpnCombine( + feature_info, + fpn_config, + fpn_channels, + fnode_cfg["inputs_offsets"], + target_reduction=reduction, + pad_type=pad_type, + pooling_type=pooling_type, + norm=norm, + apply_bn_for_resampling=apply_bn_for_resampling, + conv_after_downsample=conv_after_downsample, + redundant_bias=redundant_bias, + weight_method=fpn_config["weight_method"], + ) + self.feature_info.append(dict(num_chs=fpn_channels, reduction=reduction)) + + # after combine ops + after_combine = OrderedDict() + if not conv_bn_relu_pattern: + after_combine["act"] = act_layer(inplace=True) + conv_bias = redundant_bias + conv_act = None + else: + conv_bias = False + conv_act = act_layer + conv_kwargs = dict( + in_channels=fpn_channels, + out_channels=fpn_channels, + kernel_size=3, + padding=pad_type, + bias=conv_bias, + norm=norm, + act_layer=conv_act, + ) + after_combine["conv"] = ( + SeparableConv2d(**conv_kwargs) if separable_conv else ConvBnAct2d(**conv_kwargs) + ) + fnode_layers["after_combine"] = nn.Sequential(after_combine) + + self.fnode.add_module(str(i), nn.Sequential(fnode_layers)) + + self.feature_info = self.feature_info[-num_levels::] + + def forward(self, x): + x = self.fnode(x) + return x[-self.num_levels : :] + + +class BiFPN(Backbone): + def __init__( + self, + cfg, + bottom_up, + in_features, + out_channels, + norm: str="", + num_levels: int=5, + num_bifpn: int=4, + separable_conv: bool=False, + ) -> None: + super().__init__() + assert isinstance(bottom_up, Backbone) + + # Feature map strides and channels from the bottom up network (e.g. ResNet) + input_shapes = bottom_up.output_shape() + in_strides = [input_shapes[f].stride for f in in_features] + in_channels = [input_shapes[f].channels for f in in_features] + + self.num_levels = num_levels + self.num_bifpn = num_bifpn + self.bottom_up = bottom_up + self.in_features = in_features + self._size_divisibility = 128 + levels = [int(math.log2(s)) for s in in_strides] + self._out_feature_strides = {f"p{int(math.log2(s))}": s for s in in_strides} + if len(in_features) < num_levels: + for l in range(num_levels - len(in_features)): + s = l + levels[-1] + self._out_feature_strides[f"p{s + 1}"] = 2 ** (s + 1) + self._out_features = list(sorted(self._out_feature_strides.keys())) + self._out_feature_channels = {k: out_channels for k in self._out_features} + + # print('self._out_feature_strides', self._out_feature_strides) + # print('self._out_feature_channels', self._out_feature_channels) + + feature_info = [ + {"num_chs": in_channels[level], "reduction": in_strides[level]} + for level in range(len(self.in_features)) + ] + # self.config = config + fpn_config = get_fpn_config() + self.resample = SequentialAppendLast() + for level in range(num_levels): + if level < len(feature_info): + in_chs = in_channels[level] # feature_info[level]['num_chs'] + reduction = in_strides[level] # feature_info[level]['reduction'] + else: + # Adds a coarser level by downsampling the last feature map + reduction_ratio = 2 + self.resample.add_module( + str(level), + ResampleFeatureMap( + in_channels=in_chs, + out_channels=out_channels, + pad_type="same", + pooling_type=None, + norm=norm, + reduction_ratio=reduction_ratio, + apply_bn=True, + conv_after_downsample=False, + redundant_bias=False, + ), + ) + in_chs = out_channels + reduction = int(reduction * reduction_ratio) + feature_info.append(dict(num_chs=in_chs, reduction=reduction)) + + self.cell = nn.Sequential() + for rep in range(self.num_bifpn): + # logging.debug('building cell {}'.format(rep)) + # print('building cell {}'.format(rep)) + fpn_layer = BiFpnLayer( + feature_info=feature_info, + fpn_config=fpn_config, + fpn_channels=out_channels, + num_levels=self.num_levels, + pad_type="same", + pooling_type=None, + norm=norm, + act_layer=Swish, + separable_conv=separable_conv, + apply_bn_for_resampling=True, + conv_after_downsample=False, + conv_bn_relu_pattern=False, + redundant_bias=False, + ) + self.cell.add_module(str(rep), fpn_layer) + feature_info = fpn_layer.feature_info + # import pdb; pdb.set_trace() + + @property + def size_divisibility(self): + return self._size_divisibility + + def forward(self, x): + # print('input shapes', x.shape) + bottom_up_features = self.bottom_up(x) + x = [bottom_up_features[f] for f in self.in_features] + assert len(self.resample) == self.num_levels - len(x) + x = self.resample(x) + [xx.shape for xx in x] + # print('resample shapes', shapes) + x = self.cell(x) + out = {f: xx for f, xx in zip(self._out_features, x, strict=False)} + # import pdb; pdb.set_trace() + return out + + +@BACKBONE_REGISTRY.register() +def build_resnet_bifpn_backbone(cfg, input_shape: ShapeSpec): + """ + Args: + cfg: a detectron2 CfgNode + + Returns: + backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`. + """ + bottom_up = build_resnet_backbone(cfg, input_shape) + in_features = cfg.MODEL.FPN.IN_FEATURES + backbone = BiFPN( + cfg=cfg, + bottom_up=bottom_up, + in_features=in_features, + out_channels=cfg.MODEL.BIFPN.OUT_CHANNELS, + norm=cfg.MODEL.BIFPN.NORM, + num_levels=cfg.MODEL.BIFPN.NUM_LEVELS, + num_bifpn=cfg.MODEL.BIFPN.NUM_BIFPN, + separable_conv=cfg.MODEL.BIFPN.SEPARABLE_CONV, + ) + return backbone + + +@BACKBONE_REGISTRY.register() +def build_p37_dla_bifpn_backbone(cfg, input_shape: ShapeSpec): + """ + Args: + cfg: a detectron2 CfgNode + Returns: + backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`. + """ + bottom_up = dla34(cfg) + in_features = cfg.MODEL.FPN.IN_FEATURES + assert cfg.MODEL.BIFPN.NUM_LEVELS == 5 + + backbone = BiFPN( + cfg=cfg, + bottom_up=bottom_up, + in_features=in_features, + out_channels=cfg.MODEL.BIFPN.OUT_CHANNELS, + norm=cfg.MODEL.BIFPN.NORM, + num_levels=cfg.MODEL.BIFPN.NUM_LEVELS, + num_bifpn=cfg.MODEL.BIFPN.NUM_BIFPN, + separable_conv=cfg.MODEL.BIFPN.SEPARABLE_CONV, + ) + return backbone diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/backbone/bifpn_fcos.py b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/backbone/bifpn_fcos.py new file mode 100644 index 0000000000..27ad4e62fc --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/backbone/bifpn_fcos.py @@ -0,0 +1,455 @@ +# This file is modified from https://github.com/aim-uofa/AdelaiDet/blob/master/adet/modeling/backbone/bifpn.py +# The original file is under 2-clause BSD License for academic use, and *non-commercial use*. +from detectron2.layers import Conv2d, ShapeSpec, get_norm +from detectron2.modeling import BACKBONE_REGISTRY +from detectron2.modeling.backbone import Backbone, build_resnet_backbone +import torch +from torch import nn +import torch.nn.functional as F + +from .dlafpn import dla34 +from typing import Sequence + +__all__ = [] + + +def swish(x): + return x * x.sigmoid() + + +def split_name(name: str): + for i, c in enumerate(name): + if not c.isalpha(): + return name[:i], int(name[i:]) + raise ValueError() + + +class FeatureMapResampler(nn.Module): + def __init__(self, in_channels, out_channels, stride: int, norm: str="") -> None: + super().__init__() + if in_channels != out_channels: + self.reduction = Conv2d( + in_channels, + out_channels, + kernel_size=1, + bias=(norm == ""), + norm=get_norm(norm, out_channels), + activation=None, + ) + else: + self.reduction = None + + assert stride <= 2 + self.stride = stride + + def forward(self, x): + if self.reduction is not None: + x = self.reduction(x) + + if self.stride == 2: + x = F.max_pool2d(x, kernel_size=self.stride + 1, stride=self.stride, padding=1) + elif self.stride == 1: + pass + else: + raise NotImplementedError() + return x + + +class BackboneWithTopLevels(Backbone): + def __init__(self, backbone, out_channels, num_top_levels: int, norm: str="") -> None: + super().__init__() + self.backbone = backbone + backbone_output_shape = backbone.output_shape() + + self._out_feature_channels = { + name: shape.channels for name, shape in backbone_output_shape.items() + } + self._out_feature_strides = { + name: shape.stride for name, shape in backbone_output_shape.items() + } + self._out_features = list(self._out_feature_strides.keys()) + + last_feature_name = max(self._out_feature_strides.keys(), key=lambda x: split_name(x)[1]) + self.last_feature_name = last_feature_name + self.num_top_levels = num_top_levels + + last_channels = self._out_feature_channels[last_feature_name] + last_stride = self._out_feature_strides[last_feature_name] + + prefix, suffix = split_name(last_feature_name) + prev_channels = last_channels + for i in range(num_top_levels): + name = prefix + str(suffix + i + 1) + self.add_module(name, FeatureMapResampler(prev_channels, out_channels, 2, norm)) + prev_channels = out_channels + + self._out_feature_channels[name] = out_channels + self._out_feature_strides[name] = last_stride * 2 ** (i + 1) + self._out_features.append(name) + + def forward(self, x): + outputs = self.backbone(x) + last_features = outputs[self.last_feature_name] + prefix, suffix = split_name(self.last_feature_name) + + x = last_features + for i in range(self.num_top_levels): + name = prefix + str(suffix + i + 1) + x = self.__getattr__(name)(x) + outputs[name] = x + + return outputs + + +class SingleBiFPN(Backbone): + """ + This module implements Feature Pyramid Network. + It creates pyramid features built on top of some input feature maps. + """ + + def __init__(self, in_channels_list, out_channels, norm: str="") -> None: + """ + Args: + bottom_up (Backbone): module representing the bottom up subnetwork. + Must be a subclass of :class:`Backbone`. The multi-scale feature + maps generated by the bottom up network, and listed in `in_features`, + are used to generate FPN levels. + in_features (list[str]): names of the input feature maps coming + from the backbone to which FPN is attached. For example, if the + backbone produces ["res2", "res3", "res4"], any *contiguous* sublist + of these may be used; order must be from high to low resolution. + out_channels (int): number of channels in the output feature maps. + norm (str): the normalization to use. + """ + super().__init__() + + self.out_channels = out_channels + # build 5-levels bifpn + if len(in_channels_list) == 5: + self.nodes = [ + {"feat_level": 3, "inputs_offsets": [3, 4]}, + {"feat_level": 2, "inputs_offsets": [2, 5]}, + {"feat_level": 1, "inputs_offsets": [1, 6]}, + {"feat_level": 0, "inputs_offsets": [0, 7]}, + {"feat_level": 1, "inputs_offsets": [1, 7, 8]}, + {"feat_level": 2, "inputs_offsets": [2, 6, 9]}, + {"feat_level": 3, "inputs_offsets": [3, 5, 10]}, + {"feat_level": 4, "inputs_offsets": [4, 11]}, + ] + elif len(in_channels_list) == 3: + self.nodes = [ + {"feat_level": 1, "inputs_offsets": [1, 2]}, + {"feat_level": 0, "inputs_offsets": [0, 3]}, + {"feat_level": 1, "inputs_offsets": [1, 3, 4]}, + {"feat_level": 2, "inputs_offsets": [2, 5]}, + ] + else: + raise NotImplementedError + + node_info = [_ for _ in in_channels_list] + + num_output_connections = [0 for _ in in_channels_list] + for fnode in self.nodes: + feat_level = fnode["feat_level"] + inputs_offsets = fnode["inputs_offsets"] + inputs_offsets_str = "_".join(map(str, inputs_offsets)) + for input_offset in inputs_offsets: + num_output_connections[input_offset] += 1 + + in_channels = node_info[input_offset] + if in_channels != out_channels: + lateral_conv = Conv2d( + in_channels, out_channels, kernel_size=1, norm=get_norm(norm, out_channels) + ) + self.add_module(f"lateral_{input_offset}_f{feat_level}", lateral_conv) + node_info.append(out_channels) + num_output_connections.append(0) + + # generate attention weights + name = f"weights_f{feat_level}_{inputs_offsets_str}" + self.__setattr__( + name, + nn.Parameter( + torch.ones(len(inputs_offsets), dtype=torch.float32), requires_grad=True + ), + ) + + # generate convolutions after combination + name = f"outputs_f{feat_level}_{inputs_offsets_str}" + self.add_module( + name, + Conv2d( + out_channels, + out_channels, + kernel_size=3, + padding=1, + norm=get_norm(norm, out_channels), + bias=(norm == ""), + ), + ) + + def forward(self, feats): + """ + Args: + input (dict[str->Tensor]): mapping feature map name (e.g., "p5") to + feature map tensor for each feature level in high to low resolution order. + Returns: + dict[str->Tensor]: + mapping from feature map name to FPN feature map tensor + in high to low resolution order. Returned feature names follow the FPN + paper convention: "p", where stage has stride = 2 ** stage e.g., + ["n2", "n3", ..., "n6"]. + """ + feats = [_ for _ in feats] + num_levels = len(feats) + num_output_connections = [0 for _ in feats] + for fnode in self.nodes: + feat_level = fnode["feat_level"] + inputs_offsets = fnode["inputs_offsets"] + inputs_offsets_str = "_".join(map(str, inputs_offsets)) + input_nodes = [] + _, _, target_h, target_w = feats[feat_level].size() + for input_offset in inputs_offsets: + num_output_connections[input_offset] += 1 + input_node = feats[input_offset] + + # reduction + if input_node.size(1) != self.out_channels: + name = f"lateral_{input_offset}_f{feat_level}" + input_node = self.__getattr__(name)(input_node) + + # maybe downsample + _, _, h, w = input_node.size() + if h > target_h and w > target_w: + height_stride_size = int((h - 1) // target_h + 1) + width_stride_size = int((w - 1) // target_w + 1) + assert height_stride_size == width_stride_size == 2 + input_node = F.max_pool2d( + input_node, + kernel_size=(height_stride_size + 1, width_stride_size + 1), + stride=(height_stride_size, width_stride_size), + padding=1, + ) + elif h <= target_h and w <= target_w: + if h < target_h or w < target_w: + input_node = F.interpolate( + input_node, size=(target_h, target_w), mode="nearest" + ) + else: + raise NotImplementedError() + input_nodes.append(input_node) + + # attention + name = f"weights_f{feat_level}_{inputs_offsets_str}" + weights = F.relu(self.__getattr__(name)) + norm_weights = weights / (weights.sum() + 0.0001) + + new_node = torch.stack(input_nodes, dim=-1) + new_node = (norm_weights * new_node).sum(dim=-1) + new_node = swish(new_node) + + name = f"outputs_f{feat_level}_{inputs_offsets_str}" + feats.append(self.__getattr__(name)(new_node)) + + num_output_connections.append(0) + + output_feats = [] + for idx in range(num_levels): + for i, fnode in enumerate(reversed(self.nodes)): + if fnode["feat_level"] == idx: + output_feats.append(feats[-1 - i]) + break + else: + raise ValueError() + return output_feats + + +class BiFPN(Backbone): + """ + This module implements Feature Pyramid Network. + It creates pyramid features built on top of some input feature maps. + """ + + def __init__(self, bottom_up, in_features, out_channels, num_top_levels: int, num_repeats: int, norm: str="") -> None: + """ + Args: + bottom_up (Backbone): module representing the bottom up subnetwork. + Must be a subclass of :class:`Backbone`. The multi-scale feature + maps generated by the bottom up network, and listed in `in_features`, + are used to generate FPN levels. + in_features (list[str]): names of the input feature maps coming + from the backbone to which FPN is attached. For example, if the + backbone produces ["res2", "res3", "res4"], any *contiguous* sublist + of these may be used; order must be from high to low resolution. + out_channels (int): number of channels in the output feature maps. + num_top_levels (int): the number of the top levels (p6 or p7). + num_repeats (int): the number of repeats of BiFPN. + norm (str): the normalization to use. + """ + super().__init__() + assert isinstance(bottom_up, Backbone) + + # add extra feature levels (i.e., 6 and 7) + self.bottom_up = BackboneWithTopLevels(bottom_up, out_channels, num_top_levels, norm) + bottom_up_output_shapes = self.bottom_up.output_shape() + + in_features = sorted(in_features, key=lambda x: split_name(x)[1]) + self._size_divisibility = 128 # bottom_up_output_shapes[in_features[-1]].stride + self.out_channels = out_channels + self.min_level = split_name(in_features[0])[1] + + # add the names for top blocks + prefix, last_suffix = split_name(in_features[-1]) + for i in range(num_top_levels): + in_features.append(prefix + str(last_suffix + i + 1)) + self.in_features = in_features + + # generate output features + self._out_features = [f"p{split_name(name)[1]}" for name in in_features] + self._out_feature_strides = { + out_name: bottom_up_output_shapes[in_name].stride + for out_name, in_name in zip(self._out_features, in_features, strict=False) + } + self._out_feature_channels = {k: out_channels for k in self._out_features} + + # build bifpn + self.repeated_bifpn = nn.ModuleList() + for i in range(num_repeats): + if i == 0: + in_channels_list = [bottom_up_output_shapes[name].channels for name in in_features] + else: + in_channels_list = [self._out_feature_channels[name] for name in self._out_features] + self.repeated_bifpn.append(SingleBiFPN(in_channels_list, out_channels, norm)) + + @property + def size_divisibility(self): + return self._size_divisibility + + def forward(self, x): + """ + Args: + input (dict[str->Tensor]): mapping feature map name (e.g., "p5") to + feature map tensor for each feature level in high to low resolution order. + Returns: + dict[str->Tensor]: + mapping from feature map name to FPN feature map tensor + in high to low resolution order. Returned feature names follow the FPN + paper convention: "p", where stage has stride = 2 ** stage e.g., + ["n2", "n3", ..., "n6"]. + """ + bottom_up_features = self.bottom_up(x) + feats = [bottom_up_features[f] for f in self.in_features] + + for bifpn in self.repeated_bifpn: + feats = bifpn(feats) + + return dict(zip(self._out_features, feats, strict=False)) + + +def _assert_strides_are_log2_contiguous(strides: Sequence[int]) -> None: + """ + Assert that each stride is 2x times its preceding stride, i.e. "contiguous in log2". + """ + for i, stride in enumerate(strides[1:], 1): + assert stride == 2 * strides[i - 1], f"Strides {stride} {strides[i - 1]} are not log2 contiguous" + + +@BACKBONE_REGISTRY.register() +def build_fcos_resnet_bifpn_backbone(cfg, input_shape: ShapeSpec): + """ + Args: + cfg: a detectron2 CfgNode + Returns: + backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`. + """ + bottom_up = build_resnet_backbone(cfg, input_shape) + in_features = cfg.MODEL.FPN.IN_FEATURES + out_channels = cfg.MODEL.BIFPN.OUT_CHANNELS + num_repeats = cfg.MODEL.BIFPN.NUM_BIFPN + top_levels = 2 + + backbone = BiFPN( + bottom_up=bottom_up, + in_features=in_features, + out_channels=out_channels, + num_top_levels=top_levels, + num_repeats=num_repeats, + norm=cfg.MODEL.BIFPN.NORM, + ) + return backbone + + +@BACKBONE_REGISTRY.register() +def build_p35_fcos_resnet_bifpn_backbone(cfg, input_shape: ShapeSpec): + """ + Args: + cfg: a detectron2 CfgNode + Returns: + backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`. + """ + bottom_up = build_resnet_backbone(cfg, input_shape) + in_features = cfg.MODEL.FPN.IN_FEATURES + out_channels = cfg.MODEL.BIFPN.OUT_CHANNELS + num_repeats = cfg.MODEL.BIFPN.NUM_BIFPN + top_levels = 0 + + backbone = BiFPN( + bottom_up=bottom_up, + in_features=in_features, + out_channels=out_channels, + num_top_levels=top_levels, + num_repeats=num_repeats, + norm=cfg.MODEL.BIFPN.NORM, + ) + return backbone + + +@BACKBONE_REGISTRY.register() +def build_p35_fcos_dla_bifpn_backbone(cfg, input_shape: ShapeSpec): + """ + Args: + cfg: a detectron2 CfgNode + Returns: + backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`. + """ + bottom_up = dla34(cfg) + in_features = cfg.MODEL.FPN.IN_FEATURES + out_channels = cfg.MODEL.BIFPN.OUT_CHANNELS + num_repeats = cfg.MODEL.BIFPN.NUM_BIFPN + top_levels = 0 + + backbone = BiFPN( + bottom_up=bottom_up, + in_features=in_features, + out_channels=out_channels, + num_top_levels=top_levels, + num_repeats=num_repeats, + norm=cfg.MODEL.BIFPN.NORM, + ) + return backbone + + +@BACKBONE_REGISTRY.register() +def build_p37_fcos_dla_bifpn_backbone(cfg, input_shape: ShapeSpec): + """ + Args: + cfg: a detectron2 CfgNode + Returns: + backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`. + """ + bottom_up = dla34(cfg) + in_features = cfg.MODEL.FPN.IN_FEATURES + out_channels = cfg.MODEL.BIFPN.OUT_CHANNELS + num_repeats = cfg.MODEL.BIFPN.NUM_BIFPN + assert cfg.MODEL.BIFPN.NUM_LEVELS == 5 + top_levels = 2 + + backbone = BiFPN( + bottom_up=bottom_up, + in_features=in_features, + out_channels=out_channels, + num_top_levels=top_levels, + num_repeats=num_repeats, + norm=cfg.MODEL.BIFPN.NORM, + ) + return backbone diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/backbone/dla.py b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/backbone/dla.py new file mode 100644 index 0000000000..8b6464153b --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/backbone/dla.py @@ -0,0 +1,542 @@ +import math +from os.path import join + +from detectron2.layers import ( + Conv2d, + DeformConv, + ModulatedDeformConv, + ShapeSpec, + get_norm, +) +from detectron2.modeling.backbone.backbone import Backbone +from detectron2.modeling.backbone.build import BACKBONE_REGISTRY +from detectron2.modeling.backbone.fpn import FPN +from detectron2.modeling.backbone.resnet import BasicStem, BottleneckBlock, DeformBottleneckBlock +import fvcore.nn.weight_init as weight_init +import numpy as np +import torch +from torch import nn +import torch.nn.functional as F +import torch.utils.model_zoo as model_zoo + +__all__ = [ + "BasicStem", + "BottleneckBlock", + "DeformBottleneckBlock", +] + +DCNV1 = False + +HASH = { + 34: "ba72cf86", + 60: "24839fc4", +} + + +def get_model_url(data, name: str, hash): + return join("http://dl.yf.io/dla/models", data, f"{name}-{hash}.pth") + + +class BasicBlock(nn.Module): + def __init__(self, inplanes, planes, stride: int=1, dilation: int=1, norm: str="BN") -> None: + super().__init__() + self.conv1 = nn.Conv2d( + inplanes, + planes, + kernel_size=3, + stride=stride, + padding=dilation, + bias=False, + dilation=dilation, + ) + self.bn1 = get_norm(norm, planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d( + planes, planes, kernel_size=3, stride=1, padding=dilation, bias=False, dilation=dilation + ) + self.bn2 = get_norm(norm, planes) + self.stride = stride + + def forward(self, x, residual=None): + if residual is None: + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 2 + + def __init__(self, inplanes, planes, stride: int=1, dilation: int=1, norm: str="BN") -> None: + super().__init__() + expansion = Bottleneck.expansion + bottle_planes = planes // expansion + self.conv1 = nn.Conv2d(inplanes, bottle_planes, kernel_size=1, bias=False) + self.bn1 = get_norm(norm, bottle_planes) + self.conv2 = nn.Conv2d( + bottle_planes, + bottle_planes, + kernel_size=3, + stride=stride, + padding=dilation, + bias=False, + dilation=dilation, + ) + self.bn2 = get_norm(norm, bottle_planes) + self.conv3 = nn.Conv2d(bottle_planes, planes, kernel_size=1, bias=False) + self.bn3 = get_norm(norm, planes) + self.relu = nn.ReLU(inplace=True) + self.stride = stride + + def forward(self, x, residual=None): + if residual is None: + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + out += residual + out = self.relu(out) + + return out + + +class Root(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size: int, residual, norm: str="BN") -> None: + super().__init__() + self.conv = nn.Conv2d( + in_channels, out_channels, 1, stride=1, bias=False, padding=(kernel_size - 1) // 2 + ) + self.bn = get_norm(norm, out_channels) + self.relu = nn.ReLU(inplace=True) + self.residual = residual + + def forward(self, *x): + children = x + x = self.conv(torch.cat(x, 1)) + x = self.bn(x) + if self.residual: + x += children[0] + x = self.relu(x) + + return x + + +class Tree(nn.Module): + def __init__( + self, + levels, + block, + in_channels, + out_channels, + stride: int=1, + level_root: bool=False, + root_dim: int=0, + root_kernel_size: int=1, + dilation: int=1, + root_residual: bool=False, + norm: str="BN", + ) -> None: + super().__init__() + if root_dim == 0: + root_dim = 2 * out_channels + if level_root: + root_dim += in_channels + if levels == 1: + self.tree1 = block(in_channels, out_channels, stride, dilation=dilation, norm=norm) + self.tree2 = block(out_channels, out_channels, 1, dilation=dilation, norm=norm) + else: + self.tree1 = Tree( + levels - 1, + block, + in_channels, + out_channels, + stride, + root_dim=0, + root_kernel_size=root_kernel_size, + dilation=dilation, + root_residual=root_residual, + norm=norm, + ) + self.tree2 = Tree( + levels - 1, + block, + out_channels, + out_channels, + root_dim=root_dim + out_channels, + root_kernel_size=root_kernel_size, + dilation=dilation, + root_residual=root_residual, + norm=norm, + ) + if levels == 1: + self.root = Root(root_dim, out_channels, root_kernel_size, root_residual, norm=norm) + self.level_root = level_root + self.root_dim = root_dim + self.downsample = None + self.project = None + self.levels = levels + if stride > 1: + self.downsample = nn.MaxPool2d(stride, stride=stride) + if in_channels != out_channels: + self.project = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False), + get_norm(norm, out_channels), + ) + + def forward(self, x, residual=None, children=None): + children = [] if children is None else children + bottom = self.downsample(x) if self.downsample else x + residual = self.project(bottom) if self.project else bottom + if self.level_root: + children.append(bottom) + x1 = self.tree1(x, residual) + if self.levels == 1: + x2 = self.tree2(x1) + x = self.root(x2, x1, *children) + else: + children.append(x1) + x = self.tree2(x1, children=children) + return x + + +class DLA(nn.Module): + def __init__( + self, num_layers: int, levels, channels, block=BasicBlock, residual_root: bool=False, norm: str="BN" + ) -> None: + """ + Args: + """ + super().__init__() + self.norm = norm + self.channels = channels + self.base_layer = nn.Sequential( + nn.Conv2d(3, channels[0], kernel_size=7, stride=1, padding=3, bias=False), + get_norm(self.norm, channels[0]), + nn.ReLU(inplace=True), + ) + self.level0 = self._make_conv_level(channels[0], channels[0], levels[0]) + self.level1 = self._make_conv_level(channels[0], channels[1], levels[1], stride=2) + self.level2 = Tree( + levels[2], + block, + channels[1], + channels[2], + 2, + level_root=False, + root_residual=residual_root, + norm=norm, + ) + self.level3 = Tree( + levels[3], + block, + channels[2], + channels[3], + 2, + level_root=True, + root_residual=residual_root, + norm=norm, + ) + self.level4 = Tree( + levels[4], + block, + channels[3], + channels[4], + 2, + level_root=True, + root_residual=residual_root, + norm=norm, + ) + self.level5 = Tree( + levels[5], + block, + channels[4], + channels[5], + 2, + level_root=True, + root_residual=residual_root, + norm=norm, + ) + self.load_pretrained_model( + data="imagenet", name=f"dla{num_layers}", hash=HASH[num_layers] + ) + + def load_pretrained_model(self, data, name: str, hash) -> None: + model_url = get_model_url(data, name, hash) + model_weights = model_zoo.load_url(model_url) + num_classes = len(model_weights[list(model_weights.keys())[-1]]) + self.fc = nn.Conv2d( + self.channels[-1], num_classes, kernel_size=1, stride=1, padding=0, bias=True + ) + print("Loading pretrained") + self.load_state_dict(model_weights, strict=False) + + def _make_conv_level(self, inplanes, planes, convs, stride: int=1, dilation: int=1): + modules = [] + for i in range(convs): + modules.extend( + [ + nn.Conv2d( + inplanes, + planes, + kernel_size=3, + stride=stride if i == 0 else 1, + padding=dilation, + bias=False, + dilation=dilation, + ), + get_norm(self.norm, planes), + nn.ReLU(inplace=True), + ] + ) + inplanes = planes + return nn.Sequential(*modules) + + def forward(self, x): + y = [] + x = self.base_layer(x) + for i in range(6): + x = getattr(self, f"level{i}")(x) + y.append(x) + return y + + +def fill_up_weights(up) -> None: + w = up.weight.data + f = math.ceil(w.size(2) / 2) + c = (2 * f - 1 - f % 2) / (2.0 * f) + for i in range(w.size(2)): + for j in range(w.size(3)): + w[0, 0, i, j] = (1 - math.fabs(i / f - c)) * (1 - math.fabs(j / f - c)) + for c in range(1, w.size(0)): + w[c, 0, :, :] = w[0, 0, :, :] + + +class _DeformConv(nn.Module): + def __init__(self, chi, cho, norm: str="BN") -> None: + super().__init__() + self.actf = nn.Sequential(get_norm(norm, cho), nn.ReLU(inplace=True)) + if DCNV1: + self.offset = Conv2d(chi, 18, kernel_size=3, stride=1, padding=1, dilation=1) + self.conv = DeformConv( + chi, cho, kernel_size=(3, 3), stride=1, padding=1, dilation=1, deformable_groups=1 + ) + else: + self.offset = Conv2d(chi, 27, kernel_size=3, stride=1, padding=1, dilation=1) + self.conv = ModulatedDeformConv( + chi, cho, kernel_size=3, stride=1, padding=1, dilation=1, deformable_groups=1 + ) + nn.init.constant_(self.offset.weight, 0) + nn.init.constant_(self.offset.bias, 0) + + def forward(self, x): + if DCNV1: + offset = self.offset(x) + x = self.conv(x, offset) + else: + offset_mask = self.offset(x) + offset_x, offset_y, mask = torch.chunk(offset_mask, 3, dim=1) + offset = torch.cat((offset_x, offset_y), dim=1) + mask = mask.sigmoid() + x = self.conv(x, offset, mask) + x = self.actf(x) + return x + + +class IDAUp(nn.Module): + def __init__(self, o, channels, up_f, norm: str="BN") -> None: + super().__init__() + for i in range(1, len(channels)): + c = channels[i] + f = int(up_f[i]) + proj = _DeformConv(c, o, norm=norm) + node = _DeformConv(o, o, norm=norm) + + up = nn.ConvTranspose2d( + o, o, f * 2, stride=f, padding=f // 2, output_padding=0, groups=o, bias=False + ) + fill_up_weights(up) + + setattr(self, "proj_" + str(i), proj) + setattr(self, "up_" + str(i), up) + setattr(self, "node_" + str(i), node) + + def forward(self, layers, startp, endp) -> None: + for i in range(startp + 1, endp): + upsample = getattr(self, "up_" + str(i - startp)) + project = getattr(self, "proj_" + str(i - startp)) + layers[i] = upsample(project(layers[i])) + node = getattr(self, "node_" + str(i - startp)) + layers[i] = node(layers[i] + layers[i - 1]) + + +class DLAUp(nn.Module): + def __init__(self, startp, channels, scales, in_channels=None, norm: str="BN") -> None: + super().__init__() + self.startp = startp + if in_channels is None: + in_channels = channels + self.channels = channels + channels = list(channels) + scales = np.array(scales, dtype=int) + for i in range(len(channels) - 1): + j = -i - 2 + setattr( + self, + f"ida_{i}", + IDAUp(channels[j], in_channels[j:], scales[j:] // scales[j], norm=norm), + ) + scales[j + 1 :] = scales[j] + in_channels[j + 1 :] = [channels[j] for _ in channels[j + 1 :]] + + def forward(self, layers): + out = [layers[-1]] # start with 32 + for i in range(len(layers) - self.startp - 1): + ida = getattr(self, f"ida_{i}") + ida(layers, len(layers) - i - 2, len(layers)) + out.insert(0, layers[-1]) + return out + + +DLA_CONFIGS = { + 34: ([1, 1, 1, 2, 2, 1], [16, 32, 64, 128, 256, 512], BasicBlock), + 60: ([1, 1, 1, 2, 3, 1], [16, 32, 128, 256, 512, 1024], Bottleneck), +} + + +class DLASeg(Backbone): + def __init__(self, num_layers: int, out_features, use_dla_up: bool=True, ms_output: bool=False, norm: str="BN") -> None: + super().__init__() + # depth = 34 + levels, channels, Block = DLA_CONFIGS[num_layers] + self.base = DLA( + num_layers=num_layers, levels=levels, channels=channels, block=Block, norm=norm + ) + down_ratio = 4 + self.first_level = int(np.log2(down_ratio)) + self.ms_output = ms_output + self.last_level = 5 if not self.ms_output else 6 + channels = self.base.channels + scales = [2**i for i in range(len(channels[self.first_level :]))] + self.use_dla_up = use_dla_up + if self.use_dla_up: + self.dla_up = DLAUp(self.first_level, channels[self.first_level :], scales, norm=norm) + out_channel = channels[self.first_level] + if not self.ms_output: # stride 4 DLA + self.ida_up = IDAUp( + out_channel, + channels[self.first_level : self.last_level], + [2**i for i in range(self.last_level - self.first_level)], + norm=norm, + ) + self._out_features = out_features + self._out_feature_channels = {f"dla{i}": channels[i] for i in range(6)} + self._out_feature_strides = {f"dla{i}": 2**i for i in range(6)} + self._size_divisibility = 32 + + @property + def size_divisibility(self): + return self._size_divisibility + + def forward(self, x): + x = self.base(x) + if self.use_dla_up: + x = self.dla_up(x) + if not self.ms_output: # stride 4 dla + y = [] + for i in range(self.last_level - self.first_level): + y.append(x[i].clone()) + self.ida_up(y, 0, len(y)) + ret = {} + for i in range(self.last_level - self.first_level): + out_feature = f"dla{i}" + if out_feature in self._out_features: + ret[out_feature] = y[i] + else: + ret = {} + st = self.first_level if self.use_dla_up else 0 + for i in range(self.last_level - st): + out_feature = f"dla{i + st}" + if out_feature in self._out_features: + ret[out_feature] = x[i] + + return ret + + +@BACKBONE_REGISTRY.register() +def build_dla_backbone(cfg, input_shape): + """ + Create a ResNet instance from config. + + Returns: + ResNet: a :class:`ResNet` instance. + """ + return DLASeg( + out_features=cfg.MODEL.DLA.OUT_FEATURES, + num_layers=cfg.MODEL.DLA.NUM_LAYERS, + use_dla_up=cfg.MODEL.DLA.USE_DLA_UP, + ms_output=cfg.MODEL.DLA.MS_OUTPUT, + norm=cfg.MODEL.DLA.NORM, + ) + + +class LastLevelP6P7(nn.Module): + """ + This module is used in RetinaNet to generate extra layers, P6 and P7 from + C5 feature. + """ + + def __init__(self, in_channels, out_channels) -> None: + super().__init__() + self.num_levels = 2 + self.in_feature = "dla5" + self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1) + self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1) + for module in [self.p6, self.p7]: + weight_init.c2_xavier_fill(module) + + def forward(self, c5): + p6 = self.p6(c5) + p7 = self.p7(F.relu(p6)) + return [p6, p7] + + +@BACKBONE_REGISTRY.register() +def build_retinanet_dla_fpn_backbone(cfg, input_shape: ShapeSpec): + """ + Args: + cfg: a detectron2 CfgNode + Returns: + backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`. + """ + bottom_up = build_dla_backbone(cfg, input_shape) + in_features = cfg.MODEL.FPN.IN_FEATURES + out_channels = cfg.MODEL.FPN.OUT_CHANNELS + in_channels_p6p7 = bottom_up.output_shape()["dla5"].channels + backbone = FPN( + bottom_up=bottom_up, + in_features=in_features, + out_channels=out_channels, + norm=cfg.MODEL.FPN.NORM, + top_block=LastLevelP6P7(in_channels_p6p7, out_channels), + fuse_type=cfg.MODEL.FPN.FUSE_TYPE, + ) + return backbone diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/backbone/dlafpn.py b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/backbone/dlafpn.py new file mode 100644 index 0000000000..54f05bf719 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/backbone/dlafpn.py @@ -0,0 +1,563 @@ +#!/usr/bin/env python + +# this file is from https://github.com/ucbdrive/dla/blob/master/dla.py. + +import math +from os.path import join + +from detectron2.layers import Conv2d, ModulatedDeformConv, ShapeSpec +from detectron2.layers.batch_norm import get_norm +from detectron2.modeling.backbone import FPN, Backbone +from detectron2.modeling.backbone.build import BACKBONE_REGISTRY +import fvcore.nn.weight_init as weight_init +import numpy as np +import torch +from torch import nn +import torch.nn.functional as F +import torch.utils.model_zoo as model_zoo +from typing import Optional + +WEB_ROOT = "http://dl.yf.io/dla/models" + + +def get_model_url(data, name: str, hash): + return join("http://dl.yf.io/dla/models", data, f"{name}-{hash}.pth") + + +def conv3x3(in_planes, out_planes, stride: int=1): + "3x3 convolution with padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) + + +class BasicBlock(nn.Module): + def __init__(self, cfg, inplanes, planes, stride: int=1, dilation: int=1) -> None: + super().__init__() + self.conv1 = nn.Conv2d( + inplanes, + planes, + kernel_size=3, + stride=stride, + padding=dilation, + bias=False, + dilation=dilation, + ) + self.bn1 = get_norm(cfg.MODEL.DLA.NORM, planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d( + planes, planes, kernel_size=3, stride=1, padding=dilation, bias=False, dilation=dilation + ) + self.bn2 = get_norm(cfg.MODEL.DLA.NORM, planes) + self.stride = stride + + def forward(self, x, residual=None): + if residual is None: + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 2 + + def __init__(self, cfg, inplanes, planes, stride: int=1, dilation: int=1) -> None: + super().__init__() + expansion = Bottleneck.expansion + bottle_planes = planes // expansion + self.conv1 = nn.Conv2d(inplanes, bottle_planes, kernel_size=1, bias=False) + self.bn1 = get_norm(cfg.MODEL.DLA.NORM, bottle_planes) + self.conv2 = nn.Conv2d( + bottle_planes, + bottle_planes, + kernel_size=3, + stride=stride, + padding=dilation, + bias=False, + dilation=dilation, + ) + self.bn2 = get_norm(cfg.MODEL.DLA.NORM, bottle_planes) + self.conv3 = nn.Conv2d(bottle_planes, planes, kernel_size=1, bias=False) + self.bn3 = get_norm(cfg.MODEL.DLA.NORM, planes) + self.relu = nn.ReLU(inplace=True) + self.stride = stride + + def forward(self, x, residual=None): + if residual is None: + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + out += residual + out = self.relu(out) + + return out + + +class Root(nn.Module): + def __init__(self, cfg, in_channels, out_channels, kernel_size: int, residual) -> None: + super().__init__() + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=1, + bias=False, + padding=(kernel_size - 1) // 2, + ) + self.bn = get_norm(cfg.MODEL.DLA.NORM, out_channels) + self.relu = nn.ReLU(inplace=True) + self.residual = residual + + def forward(self, *x): + children = x + x = self.conv(torch.cat(x, 1)) + x = self.bn(x) + if self.residual: + x += children[0] + x = self.relu(x) + + return x + + +class Tree(nn.Module): + def __init__( + self, + cfg, + levels, + block, + in_channels, + out_channels, + stride: int=1, + level_root: bool=False, + root_dim: int=0, + root_kernel_size: int=1, + dilation: int=1, + root_residual: bool=False, + ) -> None: + super().__init__() + if root_dim == 0: + root_dim = 2 * out_channels + if level_root: + root_dim += in_channels + if levels == 1: + self.tree1 = block(cfg, in_channels, out_channels, stride, dilation=dilation) + self.tree2 = block(cfg, out_channels, out_channels, 1, dilation=dilation) + else: + self.tree1 = Tree( + cfg, + levels - 1, + block, + in_channels, + out_channels, + stride, + root_dim=0, + root_kernel_size=root_kernel_size, + dilation=dilation, + root_residual=root_residual, + ) + self.tree2 = Tree( + cfg, + levels - 1, + block, + out_channels, + out_channels, + root_dim=root_dim + out_channels, + root_kernel_size=root_kernel_size, + dilation=dilation, + root_residual=root_residual, + ) + if levels == 1: + self.root = Root(cfg, root_dim, out_channels, root_kernel_size, root_residual) + self.level_root = level_root + self.root_dim = root_dim + self.downsample = None + self.project = None + self.levels = levels + if stride > 1: + self.downsample = nn.MaxPool2d(stride, stride=stride) + if in_channels != out_channels: + self.project = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False), + get_norm(cfg.MODEL.DLA.NORM, out_channels), + ) + + def forward(self, x, residual=None, children=None): + if self.training and residual is not None: + x = x + residual.sum() * 0.0 + children = [] if children is None else children + bottom = self.downsample(x) if self.downsample else x + residual = self.project(bottom) if self.project else bottom + if self.level_root: + children.append(bottom) + x1 = self.tree1(x, residual) + if self.levels == 1: + x2 = self.tree2(x1) + x = self.root(x2, x1, *children) + else: + children.append(x1) + x = self.tree2(x1, children=children) + return x + + +class DLA(Backbone): + def __init__(self, cfg, levels, channels, block=BasicBlock, residual_root: bool=False) -> None: + super().__init__() + self.cfg = cfg + self.channels = channels + + self._out_features = [f"dla{i}" for i in range(6)] + self._out_feature_channels = {k: channels[i] for i, k in enumerate(self._out_features)} + self._out_feature_strides = {k: 2**i for i, k in enumerate(self._out_features)} + + self.base_layer = nn.Sequential( + nn.Conv2d(3, channels[0], kernel_size=7, stride=1, padding=3, bias=False), + get_norm(cfg.MODEL.DLA.NORM, channels[0]), + nn.ReLU(inplace=True), + ) + self.level0 = self._make_conv_level(channels[0], channels[0], levels[0]) + self.level1 = self._make_conv_level(channels[0], channels[1], levels[1], stride=2) + self.level2 = Tree( + cfg, + levels[2], + block, + channels[1], + channels[2], + 2, + level_root=False, + root_residual=residual_root, + ) + self.level3 = Tree( + cfg, + levels[3], + block, + channels[2], + channels[3], + 2, + level_root=True, + root_residual=residual_root, + ) + self.level4 = Tree( + cfg, + levels[4], + block, + channels[3], + channels[4], + 2, + level_root=True, + root_residual=residual_root, + ) + self.level5 = Tree( + cfg, + levels[5], + block, + channels[4], + channels[5], + 2, + level_root=True, + root_residual=residual_root, + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2.0 / n)) + + self.load_pretrained_model(data="imagenet", name="dla34", hash="ba72cf86") + + def load_pretrained_model(self, data, name: str, hash) -> None: + model_url = get_model_url(data, name, hash) + model_weights = model_zoo.load_url(model_url) + del model_weights["fc.weight"] + del model_weights["fc.bias"] + print("Loading pretrained DLA!") + self.load_state_dict(model_weights, strict=True) + + def _make_conv_level(self, inplanes, planes, convs, stride: int=1, dilation: int=1): + modules = [] + for i in range(convs): + modules.extend( + [ + nn.Conv2d( + inplanes, + planes, + kernel_size=3, + stride=stride if i == 0 else 1, + padding=dilation, + bias=False, + dilation=dilation, + ), + get_norm(self.cfg.MODEL.DLA.NORM, planes), + nn.ReLU(inplace=True), + ] + ) + inplanes = planes + return nn.Sequential(*modules) + + def forward(self, x): + y = {} + x = self.base_layer(x) + for i in range(6): + name = f"level{i}" + x = getattr(self, name)(x) + y[f"dla{i}"] = x + return y + + +def fill_up_weights(up) -> None: + w = up.weight.data + f = math.ceil(w.size(2) / 2) + c = (2 * f - 1 - f % 2) / (2.0 * f) + for i in range(w.size(2)): + for j in range(w.size(3)): + w[0, 0, i, j] = (1 - math.fabs(i / f - c)) * (1 - math.fabs(j / f - c)) + for c in range(1, w.size(0)): + w[c, 0, :, :] = w[0, 0, :, :] + + +class Conv(nn.Module): + def __init__(self, chi, cho, norm) -> None: + super().__init__() + self.conv = nn.Sequential( + nn.Conv2d(chi, cho, kernel_size=1, stride=1, bias=False), + get_norm(norm, cho), + nn.ReLU(inplace=True), + ) + + def forward(self, x): + return self.conv(x) + + +class DeformConv(nn.Module): + def __init__(self, chi, cho, norm) -> None: + super().__init__() + self.actf = nn.Sequential(get_norm(norm, cho), nn.ReLU(inplace=True)) + self.offset = Conv2d(chi, 27, kernel_size=3, stride=1, padding=1, dilation=1) + self.conv = ModulatedDeformConv( + chi, cho, kernel_size=3, stride=1, padding=1, dilation=1, deformable_groups=1 + ) + nn.init.constant_(self.offset.weight, 0) + nn.init.constant_(self.offset.bias, 0) + + def forward(self, x): + offset_mask = self.offset(x) + offset_x, offset_y, mask = torch.chunk(offset_mask, 3, dim=1) + offset = torch.cat((offset_x, offset_y), dim=1) + mask = mask.sigmoid() + x = self.conv(x, offset, mask) + x = self.actf(x) + return x + + +class IDAUp(nn.Module): + def __init__(self, o, channels, up_f, norm: str="FrozenBN", node_type=Conv) -> None: + super().__init__() + for i in range(1, len(channels)): + c = channels[i] + f = int(up_f[i]) + proj = node_type(c, o, norm) + node = node_type(o, o, norm) + + up = nn.ConvTranspose2d( + o, o, f * 2, stride=f, padding=f // 2, output_padding=0, groups=o, bias=False + ) + fill_up_weights(up) + + setattr(self, "proj_" + str(i), proj) + setattr(self, "up_" + str(i), up) + setattr(self, "node_" + str(i), node) + + def forward(self, layers, startp, endp) -> None: + for i in range(startp + 1, endp): + upsample = getattr(self, "up_" + str(i - startp)) + project = getattr(self, "proj_" + str(i - startp)) + layers[i] = upsample(project(layers[i])) + node = getattr(self, "node_" + str(i - startp)) + layers[i] = node(layers[i] + layers[i - 1]) + + +DLAUP_NODE_MAP = { + "conv": Conv, + "dcn": DeformConv, +} + + +class DLAUP(Backbone): + def __init__(self, bottom_up, in_features, norm, dlaup_node: str="conv") -> None: + super().__init__() + assert isinstance(bottom_up, Backbone) + self.bottom_up = bottom_up + input_shapes = bottom_up.output_shape() + in_strides = [input_shapes[f].stride for f in in_features] + in_channels = [input_shapes[f].channels for f in in_features] + in_levels = [int(math.log2(input_shapes[f].stride)) for f in in_features] + self.in_features = in_features + out_features = [f"dlaup{l}" for l in in_levels] + self._out_features = out_features + self._out_feature_channels = { + f"dlaup{l}": in_channels[i] for i, l in enumerate(in_levels) + } + self._out_feature_strides = {f"dlaup{l}": 2**l for l in in_levels} + + print("self._out_features", self._out_features) + print("self._out_feature_channels", self._out_feature_channels) + print("self._out_feature_strides", self._out_feature_strides) + self._size_divisibility = 32 + + node_type = DLAUP_NODE_MAP[dlaup_node] + + self.startp = int(math.log2(in_strides[0])) + self.channels = in_channels + channels = list(in_channels) + scales = np.array([2**i for i in range(len(out_features))], dtype=int) + for i in range(len(channels) - 1): + j = -i - 2 + setattr( + self, + f"ida_{i}", + IDAUp( + channels[j], + in_channels[j:], + scales[j:] // scales[j], + norm=norm, + node_type=node_type, + ), + ) + scales[j + 1 :] = scales[j] + in_channels[j + 1 :] = [channels[j] for _ in channels[j + 1 :]] + + @property + def size_divisibility(self): + return self._size_divisibility + + def forward(self, x): + bottom_up_features = self.bottom_up(x) + layers = [bottom_up_features[f] for f in self.in_features] + out = [layers[-1]] # start with 32 + for i in range(len(layers) - 1): + ida = getattr(self, f"ida_{i}") + ida(layers, len(layers) - i - 2, len(layers)) + out.insert(0, layers[-1]) + ret = {} + for k, v in zip(self._out_features, out, strict=False): + ret[k] = v + # import pdb; pdb.set_trace() + return ret + + +def dla34(cfg, pretrained: Optional[bool]=None): # DLA-34 + model = DLA(cfg, [1, 1, 1, 2, 2, 1], [16, 32, 64, 128, 256, 512], block=BasicBlock) + return model + + +class LastLevelP6P7(nn.Module): + """ + This module is used in RetinaNet to generate extra layers, P6 and P7 from + C5 feature. + """ + + def __init__(self, in_channels, out_channels) -> None: + super().__init__() + self.num_levels = 2 + self.in_feature = "dla5" + self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1) + self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1) + for module in [self.p6, self.p7]: + weight_init.c2_xavier_fill(module) + + def forward(self, c5): + p6 = self.p6(c5) + p7 = self.p7(F.relu(p6)) + return [p6, p7] + + +@BACKBONE_REGISTRY.register() +def build_dla_fpn3_backbone(cfg, input_shape: ShapeSpec): + """ + Args: + cfg: a detectron2 CfgNode + Returns: + backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`. + """ + + depth_to_creator = {"dla34": dla34} + bottom_up = depth_to_creator[f"dla{cfg.MODEL.DLA.NUM_LAYERS}"](cfg) + in_features = cfg.MODEL.FPN.IN_FEATURES + out_channels = cfg.MODEL.FPN.OUT_CHANNELS + + backbone = FPN( + bottom_up=bottom_up, + in_features=in_features, + out_channels=out_channels, + norm=cfg.MODEL.FPN.NORM, + top_block=None, + fuse_type=cfg.MODEL.FPN.FUSE_TYPE, + ) + + return backbone + + +@BACKBONE_REGISTRY.register() +def build_dla_fpn5_backbone(cfg, input_shape: ShapeSpec): + """ + Args: + cfg: a detectron2 CfgNode + Returns: + backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`. + """ + + depth_to_creator = {"dla34": dla34} + bottom_up = depth_to_creator[f"dla{cfg.MODEL.DLA.NUM_LAYERS}"](cfg) + in_features = cfg.MODEL.FPN.IN_FEATURES + out_channels = cfg.MODEL.FPN.OUT_CHANNELS + in_channels_top = bottom_up.output_shape()["dla5"].channels + + backbone = FPN( + bottom_up=bottom_up, + in_features=in_features, + out_channels=out_channels, + norm=cfg.MODEL.FPN.NORM, + top_block=LastLevelP6P7(in_channels_top, out_channels), + fuse_type=cfg.MODEL.FPN.FUSE_TYPE, + ) + + return backbone + + +@BACKBONE_REGISTRY.register() +def build_dlaup_backbone(cfg, input_shape: ShapeSpec): + """ + Args: + cfg: a detectron2 CfgNode + Returns: + backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`. + """ + + depth_to_creator = {"dla34": dla34} + bottom_up = depth_to_creator[f"dla{cfg.MODEL.DLA.NUM_LAYERS}"](cfg) + + backbone = DLAUP( + bottom_up=bottom_up, + in_features=cfg.MODEL.DLA.DLAUP_IN_FEATURES, + norm=cfg.MODEL.DLA.NORM, + dlaup_node=cfg.MODEL.DLA.DLAUP_NODE, + ) + + return backbone diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/backbone/fpn_p5.py b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/backbone/fpn_p5.py new file mode 100644 index 0000000000..4ce285b6c6 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/backbone/fpn_p5.py @@ -0,0 +1,75 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from detectron2.layers import ShapeSpec +from detectron2.modeling.backbone.build import BACKBONE_REGISTRY +from detectron2.modeling.backbone.fpn import FPN +from detectron2.modeling.backbone.resnet import build_resnet_backbone +import fvcore.nn.weight_init as weight_init +from torch import nn +import torch.nn.functional as F + + +class LastLevelP6P7_P5(nn.Module): + """ + This module is used in RetinaNet to generate extra layers, P6 and P7 from + C5 feature. + """ + + def __init__(self, in_channels, out_channels) -> None: + super().__init__() + self.num_levels = 2 + self.in_feature = "p5" + self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1) + self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1) + for module in [self.p6, self.p7]: + weight_init.c2_xavier_fill(module) + + def forward(self, c5): + p6 = self.p6(c5) + p7 = self.p7(F.relu(p6)) + return [p6, p7] + + +@BACKBONE_REGISTRY.register() +def build_p67_resnet_fpn_backbone(cfg, input_shape: ShapeSpec): + """ + Args: + cfg: a detectron2 CfgNode + + Returns: + backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`. + """ + bottom_up = build_resnet_backbone(cfg, input_shape) + in_features = cfg.MODEL.FPN.IN_FEATURES + out_channels = cfg.MODEL.FPN.OUT_CHANNELS + backbone = FPN( + bottom_up=bottom_up, + in_features=in_features, + out_channels=out_channels, + norm=cfg.MODEL.FPN.NORM, + top_block=LastLevelP6P7_P5(out_channels, out_channels), + fuse_type=cfg.MODEL.FPN.FUSE_TYPE, + ) + return backbone + + +@BACKBONE_REGISTRY.register() +def build_p35_resnet_fpn_backbone(cfg, input_shape: ShapeSpec): + """ + Args: + cfg: a detectron2 CfgNode + + Returns: + backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`. + """ + bottom_up = build_resnet_backbone(cfg, input_shape) + in_features = cfg.MODEL.FPN.IN_FEATURES + out_channels = cfg.MODEL.FPN.OUT_CHANNELS + backbone = FPN( + bottom_up=bottom_up, + in_features=in_features, + out_channels=out_channels, + norm=cfg.MODEL.FPN.NORM, + top_block=None, + fuse_type=cfg.MODEL.FPN.FUSE_TYPE, + ) + return backbone diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/backbone/res2net.py b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/backbone/res2net.py new file mode 100644 index 0000000000..e04400032e --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/backbone/res2net.py @@ -0,0 +1,810 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# This file is modified from https://github.com/Res2Net/Res2Net-detectron2/blob/master/detectron2/modeling/backbone/resnet.py +# The original file is under Apache-2.0 License +from detectron2.layers import ( + CNNBlockBase, + Conv2d, + DeformConv, + ModulatedDeformConv, + ShapeSpec, + get_norm, +) +from detectron2.modeling.backbone import Backbone +from detectron2.modeling.backbone.build import BACKBONE_REGISTRY +from detectron2.modeling.backbone.fpn import FPN +import fvcore.nn.weight_init as weight_init +import numpy as np +import torch +from torch import nn +import torch.nn.functional as F + +from .bifpn import BiFPN +from .fpn_p5 import LastLevelP6P7_P5 +from typing import Optional + +__all__ = [ + "BasicBlock", + "BasicStem", + "BottleneckBlock", + "DeformBottleneckBlock", + "ResNet", + "ResNetBlockBase", + "build_res2net_backbone", + "make_stage", +] + + +ResNetBlockBase = CNNBlockBase +""" +Alias for backward compatibiltiy. +""" + + +class BasicBlock(CNNBlockBase): + """ + The basic residual block for ResNet-18 and ResNet-34, with two 3x3 conv layers + and a projection shortcut if needed. + """ + + def __init__(self, in_channels, out_channels, *, stride: int=1, norm: str="BN") -> None: + """ + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + stride (int): Stride for the first conv. + norm (str or callable): normalization for all conv layers. + See :func:`layers.get_norm` for supported format. + """ + super().__init__(in_channels, out_channels, stride) + + if in_channels != out_channels: + self.shortcut = Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=stride, + bias=False, + norm=get_norm(norm, out_channels), + ) + else: + self.shortcut = None + + self.conv1 = Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=1, + bias=False, + norm=get_norm(norm, out_channels), + ) + + self.conv2 = Conv2d( + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False, + norm=get_norm(norm, out_channels), + ) + + for layer in [self.conv1, self.conv2, self.shortcut]: + if layer is not None: # shortcut can be None + weight_init.c2_msra_fill(layer) + + def forward(self, x): + out = self.conv1(x) + out = F.relu_(out) + out = self.conv2(out) + + if self.shortcut is not None: + shortcut = self.shortcut(x) + else: + shortcut = x + + out += shortcut + out = F.relu_(out) + return out + + +class BottleneckBlock(CNNBlockBase): + """ + The standard bottle2neck residual block used by Res2Net-50, 101 and 152. + """ + + def __init__( + self, + in_channels, + out_channels, + *, + bottleneck_channels, + stride: int=1, + num_groups: int=1, + norm: str="BN", + stride_in_1x1: bool=False, + dilation: int=1, + basewidth: int=26, + scale: int=4, + ) -> None: + """ + Args: + bottleneck_channels (int): number of output channels for the 3x3 + "bottleneck" conv layers. + num_groups (int): number of groups for the 3x3 conv layer. + norm (str or callable): normalization for all conv layers. + See :func:`layers.get_norm` for supported format. + stride_in_1x1 (bool): when stride>1, whether to put stride in the + first 1x1 convolution or the bottleneck 3x3 convolution. + dilation (int): the dilation rate of the 3x3 conv layer. + """ + super().__init__(in_channels, out_channels, stride) + + if in_channels != out_channels: + self.shortcut = nn.Sequential( + nn.AvgPool2d( + kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False + ), + Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + bias=False, + norm=get_norm(norm, out_channels), + ), + ) + else: + self.shortcut = None + + # The original MSRA ResNet models have stride in the first 1x1 conv + # The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have + # stride in the 3x3 conv + stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride) + width = bottleneck_channels // scale + + self.conv1 = Conv2d( + in_channels, + bottleneck_channels, + kernel_size=1, + stride=stride_1x1, + bias=False, + norm=get_norm(norm, bottleneck_channels), + ) + if scale == 1: + self.nums = 1 + else: + self.nums = scale - 1 + if self.in_channels != self.out_channels and stride_3x3 != 2: + self.pool = nn.AvgPool2d(kernel_size=3, stride=stride_3x3, padding=1) + + convs = [] + bns = [] + for _i in range(self.nums): + convs.append( + nn.Conv2d( + width, + width, + kernel_size=3, + stride=stride_3x3, + padding=1 * dilation, + bias=False, + groups=num_groups, + dilation=dilation, + ) + ) + bns.append(get_norm(norm, width)) + self.convs = nn.ModuleList(convs) + self.bns = nn.ModuleList(bns) + + self.conv3 = Conv2d( + bottleneck_channels, + out_channels, + kernel_size=1, + bias=False, + norm=get_norm(norm, out_channels), + ) + self.scale = scale + self.width = width + self.in_channels = in_channels + self.out_channels = out_channels + self.stride_3x3 = stride_3x3 + for layer in [self.conv1, self.conv3]: + if layer is not None: # shortcut can be None + weight_init.c2_msra_fill(layer) + if self.shortcut is not None: + for layer in self.shortcut.modules(): + if isinstance(layer, Conv2d): + weight_init.c2_msra_fill(layer) + + for layer in self.convs: + if layer is not None: # shortcut can be None + weight_init.c2_msra_fill(layer) + + # Zero-initialize the last normalization in each residual branch, + # so that at the beginning, the residual branch starts with zeros, + # and each residual block behaves like an identity. + # See Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour": + # "For BN layers, the learnable scaling coefficient γ is initialized + # to be 1, except for each residual block's last BN + # where γ is initialized to be 0." + + # nn.init.constant_(self.conv3.norm.weight, 0) + # TODO this somehow hurts performance when training GN models from scratch. + # Add it as an option when we need to use this code to train a backbone. + + def forward(self, x): + out = self.conv1(x) + out = F.relu_(out) + + spx = torch.split(out, self.width, 1) + for i in range(self.nums): + if i == 0 or self.in_channels != self.out_channels: + sp = spx[i] + else: + sp = sp + spx[i] + sp = self.convs[i](sp) + sp = F.relu_(self.bns[i](sp)) + if i == 0: + out = sp + else: + out = torch.cat((out, sp), 1) + if self.scale != 1 and self.stride_3x3 == 1: + out = torch.cat((out, spx[self.nums]), 1) + elif self.scale != 1 and self.stride_3x3 == 2: + out = torch.cat((out, self.pool(spx[self.nums])), 1) + + out = self.conv3(out) + + if self.shortcut is not None: + shortcut = self.shortcut(x) + else: + shortcut = x + + out += shortcut + out = F.relu_(out) + return out + + +class DeformBottleneckBlock(ResNetBlockBase): + """ + Not implemented for res2net yet. + Similar to :class:`BottleneckBlock`, but with deformable conv in the 3x3 convolution. + """ + + def __init__( + self, + in_channels, + out_channels, + *, + bottleneck_channels, + stride: int=1, + num_groups: int=1, + norm: str="BN", + stride_in_1x1: bool=False, + dilation: int=1, + deform_modulated: bool=False, + deform_num_groups: int=1, + basewidth: int=26, + scale: int=4, + ) -> None: + super().__init__(in_channels, out_channels, stride) + self.deform_modulated = deform_modulated + + if in_channels != out_channels: + # self.shortcut = Conv2d( + # in_channels, + # out_channels, + # kernel_size=1, + # stride=stride, + # bias=False, + # norm=get_norm(norm, out_channels), + # ) + self.shortcut = nn.Sequential( + nn.AvgPool2d( + kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False + ), + Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + bias=False, + norm=get_norm(norm, out_channels), + ), + ) + else: + self.shortcut = None + + stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride) + width = bottleneck_channels // scale + + self.conv1 = Conv2d( + in_channels, + bottleneck_channels, + kernel_size=1, + stride=stride_1x1, + bias=False, + norm=get_norm(norm, bottleneck_channels), + ) + + if scale == 1: + self.nums = 1 + else: + self.nums = scale - 1 + if self.in_channels != self.out_channels and stride_3x3 != 2: + self.pool = nn.AvgPool2d(kernel_size=3, stride=stride_3x3, padding=1) + + if deform_modulated: + deform_conv_op = ModulatedDeformConv + # offset channels are 2 or 3 (if with modulated) * kernel_size * kernel_size + offset_channels = 27 + else: + deform_conv_op = DeformConv + offset_channels = 18 + + # self.conv2_offset = Conv2d( + # bottleneck_channels, + # offset_channels * deform_num_groups, + # kernel_size=3, + # stride=stride_3x3, + # padding=1 * dilation, + # dilation=dilation, + # ) + # self.conv2 = deform_conv_op( + # bottleneck_channels, + # bottleneck_channels, + # kernel_size=3, + # stride=stride_3x3, + # padding=1 * dilation, + # bias=False, + # groups=num_groups, + # dilation=dilation, + # deformable_groups=deform_num_groups, + # norm=get_norm(norm, bottleneck_channels), + # ) + + conv2_offsets = [] + convs = [] + bns = [] + for _i in range(self.nums): + conv2_offsets.append( + Conv2d( + width, + offset_channels * deform_num_groups, + kernel_size=3, + stride=stride_3x3, + padding=1 * dilation, + bias=False, + groups=num_groups, + dilation=dilation, + ) + ) + convs.append( + deform_conv_op( + width, + width, + kernel_size=3, + stride=stride_3x3, + padding=1 * dilation, + bias=False, + groups=num_groups, + dilation=dilation, + deformable_groups=deform_num_groups, + ) + ) + bns.append(get_norm(norm, width)) + self.conv2_offsets = nn.ModuleList(conv2_offsets) + self.convs = nn.ModuleList(convs) + self.bns = nn.ModuleList(bns) + + self.conv3 = Conv2d( + bottleneck_channels, + out_channels, + kernel_size=1, + bias=False, + norm=get_norm(norm, out_channels), + ) + self.scale = scale + self.width = width + self.in_channels = in_channels + self.out_channels = out_channels + self.stride_3x3 = stride_3x3 + # for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]: + # if layer is not None: # shortcut can be None + # weight_init.c2_msra_fill(layer) + + # nn.init.constant_(self.conv2_offset.weight, 0) + # nn.init.constant_(self.conv2_offset.bias, 0) + for layer in [self.conv1, self.conv3]: + if layer is not None: # shortcut can be None + weight_init.c2_msra_fill(layer) + if self.shortcut is not None: + for layer in self.shortcut.modules(): + if isinstance(layer, Conv2d): + weight_init.c2_msra_fill(layer) + + for layer in self.convs: + if layer is not None: # shortcut can be None + weight_init.c2_msra_fill(layer) + + for layer in self.conv2_offsets: + if layer.weight is not None: + nn.init.constant_(layer.weight, 0) + if layer.bias is not None: + nn.init.constant_(layer.bias, 0) + + def forward(self, x): + out = self.conv1(x) + out = F.relu_(out) + + # if self.deform_modulated: + # offset_mask = self.conv2_offset(out) + # offset_x, offset_y, mask = torch.chunk(offset_mask, 3, dim=1) + # offset = torch.cat((offset_x, offset_y), dim=1) + # mask = mask.sigmoid() + # out = self.conv2(out, offset, mask) + # else: + # offset = self.conv2_offset(out) + # out = self.conv2(out, offset) + # out = F.relu_(out) + + spx = torch.split(out, self.width, 1) + for i in range(self.nums): + if i == 0 or self.in_channels != self.out_channels: + sp = spx[i].contiguous() + else: + sp = sp + spx[i].contiguous() + + # sp = self.convs[i](sp) + if self.deform_modulated: + offset_mask = self.conv2_offsets[i](sp) + offset_x, offset_y, mask = torch.chunk(offset_mask, 3, dim=1) + offset = torch.cat((offset_x, offset_y), dim=1) + mask = mask.sigmoid() + sp = self.convs[i](sp, offset, mask) + else: + offset = self.conv2_offsets[i](sp) + sp = self.convs[i](sp, offset) + sp = F.relu_(self.bns[i](sp)) + if i == 0: + out = sp + else: + out = torch.cat((out, sp), 1) + if self.scale != 1 and self.stride_3x3 == 1: + out = torch.cat((out, spx[self.nums]), 1) + elif self.scale != 1 and self.stride_3x3 == 2: + out = torch.cat((out, self.pool(spx[self.nums])), 1) + + out = self.conv3(out) + + if self.shortcut is not None: + shortcut = self.shortcut(x) + else: + shortcut = x + + out += shortcut + out = F.relu_(out) + return out + + +def make_stage(block_class, num_blocks: int, first_stride, *, in_channels, out_channels, **kwargs): + """ + Create a list of blocks just like those in a ResNet stage. + Args: + block_class (type): a subclass of ResNetBlockBase + num_blocks (int): + first_stride (int): the stride of the first block. The other blocks will have stride=1. + in_channels (int): input channels of the entire stage. + out_channels (int): output channels of **every block** in the stage. + kwargs: other arguments passed to the constructor of every block. + Returns: + list[nn.Module]: a list of block module. + """ + assert "stride" not in kwargs, "Stride of blocks in make_stage cannot be changed." + blocks = [] + for i in range(num_blocks): + blocks.append( + block_class( + in_channels=in_channels, + out_channels=out_channels, + stride=first_stride if i == 0 else 1, + **kwargs, + ) + ) + in_channels = out_channels + return blocks + + +class BasicStem(CNNBlockBase): + """ + The standard ResNet stem (layers before the first residual block). + """ + + def __init__(self, in_channels: int=3, out_channels: int=64, norm: str="BN") -> None: + """ + Args: + norm (str or callable): norm after the first conv layer. + See :func:`layers.get_norm` for supported format. + """ + super().__init__(in_channels, out_channels, 4) + self.in_channels = in_channels + self.conv1 = nn.Sequential( + Conv2d( + in_channels, + 32, + kernel_size=3, + stride=2, + padding=1, + bias=False, + ), + get_norm(norm, 32), + nn.ReLU(inplace=True), + Conv2d( + 32, + 32, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ), + get_norm(norm, 32), + nn.ReLU(inplace=True), + Conv2d( + 32, + out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ), + ) + self.bn1 = get_norm(norm, out_channels) + + for layer in self.conv1: + if isinstance(layer, Conv2d): + weight_init.c2_msra_fill(layer) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = F.relu_(x) + x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1) + return x + + +class ResNet(Backbone): + def __init__(self, stem, stages, num_classes: Optional[int]=None, out_features=None) -> None: + """ + Args: + stem (nn.Module): a stem module + stages (list[list[CNNBlockBase]]): several (typically 4) stages, + each contains multiple :class:`CNNBlockBase`. + num_classes (None or int): if None, will not perform classification. + Otherwise, will create a linear layer. + out_features (list[str]): name of the layers whose outputs should + be returned in forward. Can be anything in "stem", "linear", or "res2" ... + If None, will return the output of the last layer. + """ + super().__init__() + self.stem = stem + self.num_classes = num_classes + + current_stride = self.stem.stride + self._out_feature_strides = {"stem": current_stride} + self._out_feature_channels = {"stem": self.stem.out_channels} + + self.stages_and_names = [] + for i, blocks in enumerate(stages): + assert len(blocks) > 0, len(blocks) + for block in blocks: + assert isinstance(block, CNNBlockBase), block + + name = "res" + str(i + 2) + stage = nn.Sequential(*blocks) + + self.add_module(name, stage) + self.stages_and_names.append((stage, name)) + + self._out_feature_strides[name] = current_stride = int( + current_stride * np.prod([k.stride for k in blocks]) + ) + self._out_feature_channels[name] = curr_channels = blocks[-1].out_channels + + if num_classes is not None: + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.linear = nn.Linear(curr_channels, num_classes) + + # Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour": + # "The 1000-way fully-connected layer is initialized by + # drawing weights from a zero-mean Gaussian with standard deviation of 0.01." + nn.init.normal_(self.linear.weight, std=0.01) + name = "linear" + + if out_features is None: + out_features = [name] + self._out_features = out_features + assert len(self._out_features) + children = [x[0] for x in self.named_children()] + for out_feature in self._out_features: + assert out_feature in children, "Available children: {}".format(", ".join(children)) + + def forward(self, x): + outputs = {} + x = self.stem(x) + if "stem" in self._out_features: + outputs["stem"] = x + for stage, name in self.stages_and_names: + x = stage(x) + if name in self._out_features: + outputs[name] = x + if self.num_classes is not None: + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.linear(x) + if "linear" in self._out_features: + outputs["linear"] = x + return outputs + + def output_shape(self): + return { + name: ShapeSpec( + channels=self._out_feature_channels[name], stride=self._out_feature_strides[name] + ) + for name in self._out_features + } + + def freeze(self, freeze_at: int=0): + """ + Freeze the first several stages of the ResNet. Commonly used in + fine-tuning. + Args: + freeze_at (int): number of stem and stages to freeze. + `1` means freezing the stem. `2` means freezing the stem and + the first stage, etc. + Returns: + nn.Module: this ResNet itself + """ + if freeze_at >= 1: + self.stem.freeze() + for idx, (stage, _) in enumerate(self.stages_and_names, start=2): + if freeze_at >= idx: + for block in stage.children(): + block.freeze() + return self + + +@BACKBONE_REGISTRY.register() +def build_res2net_backbone(cfg, input_shape): + """ + Create a Res2Net instance from config. + Returns: + ResNet: a :class:`ResNet` instance. + """ + # need registration of new blocks/stems? + norm = cfg.MODEL.RESNETS.NORM + stem = BasicStem( + in_channels=input_shape.channels, + out_channels=cfg.MODEL.RESNETS.STEM_OUT_CHANNELS, + norm=norm, + ) + + # fmt: off + freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT + out_features = cfg.MODEL.RESNETS.OUT_FEATURES + depth = cfg.MODEL.RESNETS.DEPTH + num_groups = cfg.MODEL.RESNETS.NUM_GROUPS + width_per_group = cfg.MODEL.RESNETS.WIDTH_PER_GROUP + scale = 4 + bottleneck_channels = num_groups * width_per_group * scale + in_channels = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS + out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS + stride_in_1x1 = cfg.MODEL.RESNETS.STRIDE_IN_1X1 + res5_dilation = cfg.MODEL.RESNETS.RES5_DILATION + deform_on_per_stage = cfg.MODEL.RESNETS.DEFORM_ON_PER_STAGE + deform_modulated = cfg.MODEL.RESNETS.DEFORM_MODULATED + deform_num_groups = cfg.MODEL.RESNETS.DEFORM_NUM_GROUPS + # fmt: on + assert res5_dilation in {1, 2}, f"res5_dilation cannot be {res5_dilation}." + + num_blocks_per_stage = { + 18: [2, 2, 2, 2], + 34: [3, 4, 6, 3], + 50: [3, 4, 6, 3], + 101: [3, 4, 23, 3], + 152: [3, 8, 36, 3], + }[depth] + + if depth in [18, 34]: + assert out_channels == 64, "Must set MODEL.RESNETS.RES2_OUT_CHANNELS = 64 for R18/R34" + assert not any(deform_on_per_stage), ( + "MODEL.RESNETS.DEFORM_ON_PER_STAGE unsupported for R18/R34" + ) + assert res5_dilation == 1, "Must set MODEL.RESNETS.RES5_DILATION = 1 for R18/R34" + assert num_groups == 1, "Must set MODEL.RESNETS.NUM_GROUPS = 1 for R18/R34" + + stages = [] + + # Avoid creating variables without gradients + # It consumes extra memory and may cause allreduce to fail + out_stage_idx = [{"res2": 2, "res3": 3, "res4": 4, "res5": 5}[f] for f in out_features] + max_stage_idx = max(out_stage_idx) + for idx, stage_idx in enumerate(range(2, max_stage_idx + 1)): + dilation = res5_dilation if stage_idx == 5 else 1 + first_stride = 1 if idx == 0 or (stage_idx == 5 and dilation == 2) else 2 + stage_kargs = { + "num_blocks": num_blocks_per_stage[idx], + "first_stride": first_stride, + "in_channels": in_channels, + "out_channels": out_channels, + "norm": norm, + } + # Use BasicBlock for R18 and R34. + if depth in [18, 34]: + stage_kargs["block_class"] = BasicBlock + else: + stage_kargs["bottleneck_channels"] = bottleneck_channels + stage_kargs["stride_in_1x1"] = stride_in_1x1 + stage_kargs["dilation"] = dilation + stage_kargs["num_groups"] = num_groups + stage_kargs["scale"] = scale + + if deform_on_per_stage[idx]: + stage_kargs["block_class"] = DeformBottleneckBlock + stage_kargs["deform_modulated"] = deform_modulated + stage_kargs["deform_num_groups"] = deform_num_groups + else: + stage_kargs["block_class"] = BottleneckBlock + blocks = make_stage(**stage_kargs) + in_channels = out_channels + out_channels *= 2 + bottleneck_channels *= 2 + stages.append(blocks) + return ResNet(stem, stages, out_features=out_features).freeze(freeze_at) + + +@BACKBONE_REGISTRY.register() +def build_p67_res2net_fpn_backbone(cfg, input_shape: ShapeSpec): + """ + Args: + cfg: a detectron2 CfgNode + + Returns: + backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`. + """ + bottom_up = build_res2net_backbone(cfg, input_shape) + in_features = cfg.MODEL.FPN.IN_FEATURES + out_channels = cfg.MODEL.FPN.OUT_CHANNELS + backbone = FPN( + bottom_up=bottom_up, + in_features=in_features, + out_channels=out_channels, + norm=cfg.MODEL.FPN.NORM, + top_block=LastLevelP6P7_P5(out_channels, out_channels), + fuse_type=cfg.MODEL.FPN.FUSE_TYPE, + ) + return backbone + + +@BACKBONE_REGISTRY.register() +def build_res2net_bifpn_backbone(cfg, input_shape: ShapeSpec): + """ + Args: + cfg: a detectron2 CfgNode + + Returns: + backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`. + """ + bottom_up = build_res2net_backbone(cfg, input_shape) + in_features = cfg.MODEL.FPN.IN_FEATURES + backbone = BiFPN( + cfg=cfg, + bottom_up=bottom_up, + in_features=in_features, + out_channels=cfg.MODEL.BIFPN.OUT_CHANNELS, + norm=cfg.MODEL.BIFPN.NORM, + num_levels=cfg.MODEL.BIFPN.NUM_LEVELS, + num_bifpn=cfg.MODEL.BIFPN.NUM_BIFPN, + separable_conv=cfg.MODEL.BIFPN.SEPARABLE_CONV, + ) + return backbone diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/debug.py b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/debug.py new file mode 100644 index 0000000000..63186b05c5 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/debug.py @@ -0,0 +1,341 @@ +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +from typing import Sequence + +COLORS = ((np.random.rand(1300, 3) * 0.4 + 0.6) * 255).astype(np.uint8).reshape(1300, 1, 1, 3) + + +def _get_color_image(heatmap): + heatmap = heatmap.reshape(heatmap.shape[0], heatmap.shape[1], heatmap.shape[2], 1) + if heatmap.shape[0] == 1: + color_map = ( + (heatmap * np.ones((1, 1, 1, 3), np.uint8) * 255).max(axis=0).astype(np.uint8) + ) # H, W, 3 + else: + color_map = (heatmap * COLORS[: heatmap.shape[0]]).max(axis=0).astype(np.uint8) # H, W, 3 + + return color_map + + +def _blend_image(image, color_map, a: float=0.7): + color_map = cv2.resize(color_map, (image.shape[1], image.shape[0])) + ret = np.clip(image * (1 - a) + color_map * a, 0, 255).astype(np.uint8) + return ret + + +def _blend_image_heatmaps(image, color_maps, a: float=0.7): + merges = np.zeros((image.shape[0], image.shape[1], 3), np.float32) + for color_map in color_maps: + color_map = cv2.resize(color_map, (image.shape[1], image.shape[0])) + merges = np.maximum(merges, color_map) + ret = np.clip(image * (1 - a) + merges * a, 0, 255).astype(np.uint8) + return ret + + +def _decompose_level(x, shapes_per_level, N): + """ + x: LNHiWi x C + """ + x = x.view(x.shape[0], -1) + ret = [] + st = 0 + for l in range(len(shapes_per_level)): + ret.append([]) + h = shapes_per_level[l][0].int().item() + w = shapes_per_level[l][1].int().item() + for i in range(N): + ret[l].append(x[st + h * w * i : st + h * w * (i + 1)].view(h, w, -1).permute(2, 0, 1)) + st += h * w * N + return ret + + +def _imagelist_to_tensor(images): + images = [x for x in images] + image_sizes = [x.shape[-2:] for x in images] + h = max([size[0] for size in image_sizes]) + w = max([size[1] for size in image_sizes]) + S = 32 + h, w = ((h - 1) // S + 1) * S, ((w - 1) // S + 1) * S + images = [F.pad(x, (0, w - x.shape[2], 0, h - x.shape[1], 0, 0)) for x in images] + images = torch.stack(images) + return images + + +def _ind2il(ind, shapes_per_level, N): + r = ind + l = 0 + S = 0 + while r - S >= N * shapes_per_level[l][0] * shapes_per_level[l][1]: + S += N * shapes_per_level[l][0] * shapes_per_level[l][1] + l += 1 + i = (r - S) // (shapes_per_level[l][0] * shapes_per_level[l][1]) + return i, l + + +def debug_train( + images, + gt_instances, + flattened_hms, + reg_targets, + labels: Sequence[str], + pos_inds, + shapes_per_level, + locations, + strides: Sequence[int], +) -> None: + """ + images: N x 3 x H x W + flattened_hms: LNHiWi x C + shapes_per_level: L x 2 [(H_i, W_i)] + locations: LNHiWi x 2 + """ + reg_inds = torch.nonzero(reg_targets.max(dim=1)[0] > 0).squeeze(1) + N = len(images) + images = _imagelist_to_tensor(images) + repeated_locations = [torch.cat([loc] * N, dim=0) for loc in locations] + locations = torch.cat(repeated_locations, dim=0) + gt_hms = _decompose_level(flattened_hms, shapes_per_level, N) + masks = flattened_hms.new_zeros((flattened_hms.shape[0], 1)) + masks[pos_inds] = 1 + masks = _decompose_level(masks, shapes_per_level, N) + for i in range(len(images)): + image = images[i].detach().cpu().numpy().transpose(1, 2, 0) + color_maps = [] + for l in range(len(gt_hms)): + color_map = _get_color_image(gt_hms[l][i].detach().cpu().numpy()) + color_maps.append(color_map) + cv2.imshow(f"gthm_{l}", color_map) + blend = _blend_image_heatmaps(image.copy(), color_maps) + if gt_instances is not None: + bboxes = gt_instances[i].gt_boxes.tensor + for j in range(len(bboxes)): + bbox = bboxes[j] + cv2.rectangle( + blend, + (int(bbox[0]), int(bbox[1])), + (int(bbox[2]), int(bbox[3])), + (0, 0, 255), + 3, + cv2.LINE_AA, + ) + + for j in range(len(pos_inds)): + image_id, l = _ind2il(pos_inds[j], shapes_per_level, N) + if image_id != i: + continue + loc = locations[pos_inds[j]] + cv2.drawMarker( + blend, (int(loc[0]), int(loc[1])), (0, 255, 255), markerSize=(l + 1) * 16 + ) + + for j in range(len(reg_inds)): + image_id, l = _ind2il(reg_inds[j], shapes_per_level, N) + if image_id != i: + continue + ltrb = reg_targets[reg_inds[j]] + ltrb *= strides[l] + loc = locations[reg_inds[j]] + bbox = [(loc[0] - ltrb[0]), (loc[1] - ltrb[1]), (loc[0] + ltrb[2]), (loc[1] + ltrb[3])] + cv2.rectangle( + blend, + (int(bbox[0]), int(bbox[1])), + (int(bbox[2]), int(bbox[3])), + (255, 0, 0), + 1, + cv2.LINE_AA, + ) + cv2.circle(blend, (int(loc[0]), int(loc[1])), 2, (255, 0, 0), -1) + + cv2.imshow("blend", blend) + cv2.waitKey() + + +def debug_test( + images, + logits_pred, + reg_pred, + agn_hm_pred=None, + preds=None, + vis_thresh: float=0.3, + debug_show_name: bool=False, + mult_agn: bool=False, +) -> None: + """ + images: N x 3 x H x W + class_target: LNHiWi x C + cat_agn_heatmap: LNHiWi + shapes_per_level: L x 2 [(H_i, W_i)] + """ + if preds is None: + preds = [] + if agn_hm_pred is None: + agn_hm_pred = [] + len(images) + for i in range(len(images)): + image = images[i].detach().cpu().numpy().transpose(1, 2, 0) + image.copy().astype(np.uint8) + pred_image = image.copy().astype(np.uint8) + color_maps = [] + L = len(logits_pred) + for l in range(L): + if logits_pred[0] is not None: + stride = min(image.shape[0], image.shape[1]) / min( + logits_pred[l][i].shape[1], logits_pred[l][i].shape[2] + ) + else: + stride = min(image.shape[0], image.shape[1]) / min( + agn_hm_pred[l][i].shape[1], agn_hm_pred[l][i].shape[2] + ) + stride = stride if stride < 60 else 64 if stride < 100 else 128 + if logits_pred[0] is not None: + if mult_agn: + logits_pred[l][i] = logits_pred[l][i] * agn_hm_pred[l][i] + color_map = _get_color_image(logits_pred[l][i].detach().cpu().numpy()) + color_maps.append(color_map) + cv2.imshow(f"predhm_{l}", color_map) + + if debug_show_name: + from detectron2.data.datasets.lvis_v1_categories import LVIS_CATEGORIES + + cat2name = [x["name"] for x in LVIS_CATEGORIES] + for j in range(len(preds[i].scores) if preds is not None else 0): + if preds[i].scores[j] > vis_thresh: + bbox = ( + preds[i].proposal_boxes[j] + if preds[i].has("proposal_boxes") + else preds[i].pred_boxes[j] + ) + bbox = bbox.tensor[0].detach().cpu().numpy().astype(np.int32) + cat = int(preds[i].pred_classes[j]) if preds[i].has("pred_classes") else 0 + cl = COLORS[cat, 0, 0] + cv2.rectangle( + pred_image, + (int(bbox[0]), int(bbox[1])), + (int(bbox[2]), int(bbox[3])), + (int(cl[0]), int(cl[1]), int(cl[2])), + 2, + cv2.LINE_AA, + ) + if debug_show_name: + txt = "{}{:.1f}".format( + cat2name[cat] if cat > 0 else "", preds[i].scores[j] + ) + font = cv2.FONT_HERSHEY_SIMPLEX + cat_size = cv2.getTextSize(txt, font, 0.5, 2)[0] + cv2.rectangle( + pred_image, + (int(bbox[0]), int(bbox[1] - cat_size[1] - 2)), + (int(bbox[0] + cat_size[0]), int(bbox[1] - 2)), + (int(cl[0]), int(cl[1]), int(cl[2])), + -1, + ) + cv2.putText( + pred_image, + txt, + (int(bbox[0]), int(bbox[1] - 2)), + font, + 0.5, + (0, 0, 0), + thickness=1, + lineType=cv2.LINE_AA, + ) + + if agn_hm_pred[l] is not None: + agn_hm_ = agn_hm_pred[l][i, 0, :, :, None].detach().cpu().numpy() + agn_hm_ = (agn_hm_ * np.array([255, 255, 255]).reshape(1, 1, 3)).astype(np.uint8) + cv2.imshow(f"agn_hm_{l}", agn_hm_) + blend = _blend_image_heatmaps(image.copy(), color_maps) + cv2.imshow("blend", blend) + cv2.imshow("preds", pred_image) + cv2.waitKey() + + +global cnt +cnt = 0 + + +def debug_second_stage( + images, instances, proposals=None, vis_thresh: float=0.3, save_debug: bool=False, debug_show_name: bool=False +) -> None: + images = _imagelist_to_tensor(images) + if debug_show_name: + from detectron2.data.datasets.lvis_v1_categories import LVIS_CATEGORIES + + cat2name = [x["name"] for x in LVIS_CATEGORIES] + for i in range(len(images)): + image = images[i].detach().cpu().numpy().transpose(1, 2, 0).astype(np.uint8).copy() + if instances[i].has("gt_boxes"): + bboxes = instances[i].gt_boxes.tensor.cpu().numpy() + scores = np.ones(bboxes.shape[0]) + cats = instances[i].gt_classes.cpu().numpy() + else: + bboxes = instances[i].pred_boxes.tensor.cpu().numpy() + scores = instances[i].scores.cpu().numpy() + cats = instances[i].pred_classes.cpu().numpy() + for j in range(len(bboxes)): + if scores[j] > vis_thresh: + bbox = bboxes[j] + cl = COLORS[cats[j], 0, 0] + cl = (int(cl[0]), int(cl[1]), int(cl[2])) + cv2.rectangle( + image, + (int(bbox[0]), int(bbox[1])), + (int(bbox[2]), int(bbox[3])), + cl, + 2, + cv2.LINE_AA, + ) + if debug_show_name: + cat = cats[j] + txt = "{}{:.1f}".format(cat2name[cat] if cat > 0 else "", scores[j]) + font = cv2.FONT_HERSHEY_SIMPLEX + cat_size = cv2.getTextSize(txt, font, 0.5, 2)[0] + cv2.rectangle( + image, + (int(bbox[0]), int(bbox[1] - cat_size[1] - 2)), + (int(bbox[0] + cat_size[0]), int(bbox[1] - 2)), + (int(cl[0]), int(cl[1]), int(cl[2])), + -1, + ) + cv2.putText( + image, + txt, + (int(bbox[0]), int(bbox[1] - 2)), + font, + 0.5, + (0, 0, 0), + thickness=1, + lineType=cv2.LINE_AA, + ) + if proposals is not None: + proposal_image = ( + images[i].detach().cpu().numpy().transpose(1, 2, 0).astype(np.uint8).copy() + ) + bboxes = proposals[i].proposal_boxes.tensor.cpu().numpy() + if proposals[i].has("scores"): + scores = proposals[i].scores.cpu().numpy() + else: + scores = proposals[i].objectness_logits.sigmoid().cpu().numpy() + for j in range(len(bboxes)): + if scores[j] > vis_thresh: + bbox = bboxes[j] + cl = (209, 159, 83) + cv2.rectangle( + proposal_image, + (int(bbox[0]), int(bbox[1])), + (int(bbox[2]), int(bbox[3])), + cl, + 2, + cv2.LINE_AA, + ) + + cv2.imshow("image", image) + if proposals is not None: + cv2.imshow("proposals", proposal_image) + if save_debug: + global cnt + cnt += 1 + cv2.imwrite(f"output/save_debug/{cnt}.jpg", proposal_image) + cv2.waitKey() diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/dense_heads/centernet.py b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/dense_heads/centernet.py new file mode 100644 index 0000000000..cd68ed3f40 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/dense_heads/centernet.py @@ -0,0 +1,912 @@ +from detectron2.config import configurable +from detectron2.layers import cat +from detectron2.modeling.proposal_generator.build import PROPOSAL_GENERATOR_REGISTRY +from detectron2.structures import Boxes, Instances +from detectron2.utils.comm import get_world_size +import torch +from torch import nn + +from ..debug import debug_test, debug_train +from ..layers.heatmap_focal_loss import binary_heatmap_focal_loss_jit, heatmap_focal_loss_jit +from ..layers.iou_loss import IOULoss +from ..layers.ml_nms import ml_nms +from .centernet_head import CenterNetHead +from .utils import _transpose, reduce_sum +from typing import Sequence + +__all__ = ["CenterNet"] + +INF = 100000000 + + +@PROPOSAL_GENERATOR_REGISTRY.register() +class CenterNet(nn.Module): + @configurable + def __init__( + self, + # input_shape: Dict[str, ShapeSpec], + in_channels: int=256, + *, + num_classes: int=80, + in_features=("p3", "p4", "p5", "p6", "p7"), + strides: Sequence[int]=(8, 16, 32, 64, 128), + score_thresh: float=0.05, + hm_min_overlap: float=0.8, + loc_loss_type: str="giou", + min_radius: int=4, + hm_focal_alpha: float=0.25, + hm_focal_beta: int=4, + loss_gamma: float=2.0, + reg_weight: float=2.0, + not_norm_reg: bool=True, + with_agn_hm: bool=False, + only_proposal: bool=False, + as_proposal: bool=False, + not_nms: bool=False, + pos_weight: float=1.0, + neg_weight: float=1.0, + sigmoid_clamp: float=1e-4, + ignore_high_fp=-1.0, + center_nms: bool=False, + sizes_of_interest=None, + more_pos: bool=False, + more_pos_thresh: float=0.2, + more_pos_topk: int=9, + pre_nms_topk_train: int=1000, + pre_nms_topk_test: int=1000, + post_nms_topk_train: int=100, + post_nms_topk_test: int=100, + nms_thresh_train: float=0.6, + nms_thresh_test: float=0.6, + no_reduce: bool=False, + not_clamp_box: bool=False, + debug: bool=False, + vis_thresh: float=0.5, + pixel_mean=None, + pixel_std=None, + device: str="cuda", + centernet_head=None, + ) -> None: + if pixel_std is None: + pixel_std = [1.0, 1.0, 1.0] + if pixel_mean is None: + pixel_mean = [103.53, 116.28, 123.675] + if sizes_of_interest is None: + sizes_of_interest = [[0, 80], [64, 160], [128, 320], [256, 640], [512, 10000000]] + super().__init__() + self.num_classes = num_classes + self.in_features = in_features + self.strides = strides + self.score_thresh = score_thresh + self.min_radius = min_radius + self.hm_focal_alpha = hm_focal_alpha + self.hm_focal_beta = hm_focal_beta + self.loss_gamma = loss_gamma + self.reg_weight = reg_weight + self.not_norm_reg = not_norm_reg + self.with_agn_hm = with_agn_hm + self.only_proposal = only_proposal + self.as_proposal = as_proposal + self.not_nms = not_nms + self.pos_weight = pos_weight + self.neg_weight = neg_weight + self.sigmoid_clamp = sigmoid_clamp + self.ignore_high_fp = ignore_high_fp + self.center_nms = center_nms + self.sizes_of_interest = sizes_of_interest + self.more_pos = more_pos + self.more_pos_thresh = more_pos_thresh + self.more_pos_topk = more_pos_topk + self.pre_nms_topk_train = pre_nms_topk_train + self.pre_nms_topk_test = pre_nms_topk_test + self.post_nms_topk_train = post_nms_topk_train + self.post_nms_topk_test = post_nms_topk_test + self.nms_thresh_train = nms_thresh_train + self.nms_thresh_test = nms_thresh_test + self.no_reduce = no_reduce + self.not_clamp_box = not_clamp_box + + self.debug = debug + self.vis_thresh = vis_thresh + if self.center_nms: + self.not_nms = True + self.iou_loss = IOULoss(loc_loss_type) + assert (not self.only_proposal) or self.with_agn_hm + # delta for rendering heatmap + self.delta = (1 - hm_min_overlap) / (1 + hm_min_overlap) + if centernet_head is None: + self.centernet_head = CenterNetHead( + in_channels=in_channels, + num_levels=len(in_features), + with_agn_hm=with_agn_hm, + only_proposal=only_proposal, + ) + else: + self.centernet_head = centernet_head + if self.debug: + pixel_mean = torch.Tensor(pixel_mean).to(torch.device(device)).view(3, 1, 1) + pixel_std = torch.Tensor(pixel_std).to(torch.device(device)).view(3, 1, 1) + self.denormalizer = lambda x: x * pixel_std + pixel_mean + + @classmethod + def from_config(cls, cfg, input_shape): + ret = { + # 'input_shape': input_shape, + "in_channels": input_shape[cfg.MODEL.CENTERNET.IN_FEATURES[0]].channels, + "num_classes": cfg.MODEL.CENTERNET.NUM_CLASSES, + "in_features": cfg.MODEL.CENTERNET.IN_FEATURES, + "strides": cfg.MODEL.CENTERNET.FPN_STRIDES, + "score_thresh": cfg.MODEL.CENTERNET.INFERENCE_TH, + "loc_loss_type": cfg.MODEL.CENTERNET.LOC_LOSS_TYPE, + "hm_min_overlap": cfg.MODEL.CENTERNET.HM_MIN_OVERLAP, + "min_radius": cfg.MODEL.CENTERNET.MIN_RADIUS, + "hm_focal_alpha": cfg.MODEL.CENTERNET.HM_FOCAL_ALPHA, + "hm_focal_beta": cfg.MODEL.CENTERNET.HM_FOCAL_BETA, + "loss_gamma": cfg.MODEL.CENTERNET.LOSS_GAMMA, + "reg_weight": cfg.MODEL.CENTERNET.REG_WEIGHT, + "not_norm_reg": cfg.MODEL.CENTERNET.NOT_NORM_REG, + "with_agn_hm": cfg.MODEL.CENTERNET.WITH_AGN_HM, + "only_proposal": cfg.MODEL.CENTERNET.ONLY_PROPOSAL, + "as_proposal": cfg.MODEL.CENTERNET.AS_PROPOSAL, + "not_nms": cfg.MODEL.CENTERNET.NOT_NMS, + "pos_weight": cfg.MODEL.CENTERNET.POS_WEIGHT, + "neg_weight": cfg.MODEL.CENTERNET.NEG_WEIGHT, + "sigmoid_clamp": cfg.MODEL.CENTERNET.SIGMOID_CLAMP, + "ignore_high_fp": cfg.MODEL.CENTERNET.IGNORE_HIGH_FP, + "center_nms": cfg.MODEL.CENTERNET.CENTER_NMS, + "sizes_of_interest": cfg.MODEL.CENTERNET.SOI, + "more_pos": cfg.MODEL.CENTERNET.MORE_POS, + "more_pos_thresh": cfg.MODEL.CENTERNET.MORE_POS_THRESH, + "more_pos_topk": cfg.MODEL.CENTERNET.MORE_POS_TOPK, + "pre_nms_topk_train": cfg.MODEL.CENTERNET.PRE_NMS_TOPK_TRAIN, + "pre_nms_topk_test": cfg.MODEL.CENTERNET.PRE_NMS_TOPK_TEST, + "post_nms_topk_train": cfg.MODEL.CENTERNET.POST_NMS_TOPK_TRAIN, + "post_nms_topk_test": cfg.MODEL.CENTERNET.POST_NMS_TOPK_TEST, + "nms_thresh_train": cfg.MODEL.CENTERNET.NMS_TH_TRAIN, + "nms_thresh_test": cfg.MODEL.CENTERNET.NMS_TH_TEST, + "no_reduce": cfg.MODEL.CENTERNET.NO_REDUCE, + "not_clamp_box": cfg.INPUT.NOT_CLAMP_BOX, + "debug": cfg.DEBUG, + "vis_thresh": cfg.VIS_THRESH, + "pixel_mean": cfg.MODEL.PIXEL_MEAN, + "pixel_std": cfg.MODEL.PIXEL_STD, + "device": cfg.MODEL.DEVICE, + "centernet_head": CenterNetHead( + cfg, [input_shape[f] for f in cfg.MODEL.CENTERNET.IN_FEATURES] + ), + } + return ret + + def forward(self, images, features_dict, gt_instances): + features = [features_dict[f] for f in self.in_features] + clss_per_level, reg_pred_per_level, agn_hm_pred_per_level = self.centernet_head(features) + grids = self.compute_grids(features) + shapes_per_level = grids[0].new_tensor( + [(x.shape[2], x.shape[3]) for x in reg_pred_per_level] + ) + + if not self.training: + return self.inference( + images, clss_per_level, reg_pred_per_level, agn_hm_pred_per_level, grids + ) + else: + pos_inds, labels, reg_targets, flattened_hms = self._get_ground_truth( + grids, shapes_per_level, gt_instances + ) + # logits_pred: M x F, reg_pred: M x 4, agn_hm_pred: M + logits_pred, reg_pred, agn_hm_pred = self._flatten_outputs( + clss_per_level, reg_pred_per_level, agn_hm_pred_per_level + ) + + if self.more_pos: + # add more pixels as positive if \ + # 1. they are within the center3x3 region of an object + # 2. their regression losses are small (= 0).squeeze(1) + reg_pred = reg_pred[reg_inds] + reg_targets_pos = reg_targets[reg_inds] + reg_weight_map = flattened_hms.max(dim=1)[0] + reg_weight_map = reg_weight_map[reg_inds] + reg_weight_map = reg_weight_map * 0 + 1 if self.not_norm_reg else reg_weight_map + if self.no_reduce: + reg_norm = max(reg_weight_map.sum(), 1) + else: + reg_norm = max(reduce_sum(reg_weight_map.sum()).item() / num_gpus, 1) + + reg_loss = ( + self.reg_weight + * self.iou_loss(reg_pred, reg_targets_pos, reg_weight_map, reduction="sum") + / reg_norm + ) + losses["loss_centernet_loc"] = reg_loss + + if self.with_agn_hm: + cat_agn_heatmap = flattened_hms.max(dim=1)[0] # M + agn_pos_loss, agn_neg_loss = binary_heatmap_focal_loss_jit( + agn_hm_pred.float(), + cat_agn_heatmap.float(), + pos_inds, + alpha=self.hm_focal_alpha, + beta=self.hm_focal_beta, + gamma=self.loss_gamma, + sigmoid_clamp=self.sigmoid_clamp, + ignore_high_fp=self.ignore_high_fp, + ) + agn_pos_loss = self.pos_weight * agn_pos_loss / num_pos_avg + agn_neg_loss = self.neg_weight * agn_neg_loss / num_pos_avg + losses["loss_centernet_agn_pos"] = agn_pos_loss + losses["loss_centernet_agn_neg"] = agn_neg_loss + + if self.debug: + print("losses", losses) + print("total_num_pos", total_num_pos) + return losses + + def compute_grids(self, features): + grids = [] + for level, feature in enumerate(features): + h, w = feature.size()[-2:] + shifts_x = torch.arange( + 0, + w * self.strides[level], + step=self.strides[level], + dtype=torch.float32, + device=feature.device, + ) + shifts_y = torch.arange( + 0, + h * self.strides[level], + step=self.strides[level], + dtype=torch.float32, + device=feature.device, + ) + shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x) + shift_x = shift_x.reshape(-1) + shift_y = shift_y.reshape(-1) + grids_per_level = torch.stack((shift_x, shift_y), dim=1) + self.strides[level] // 2 + grids.append(grids_per_level) + return grids + + def _get_ground_truth(self, grids, shapes_per_level, gt_instances): + """ + Input: + grids: list of tensors [(hl x wl, 2)]_l + shapes_per_level: list of tuples L x 2: + gt_instances: gt instances + Retuen: + pos_inds: N + labels: N + reg_targets: M x 4 + flattened_hms: M x C or M x 1 + N: number of objects in all images + M: number of pixels from all FPN levels + """ + + # get positive pixel index + if not self.more_pos: + pos_inds, labels = self._get_label_inds(gt_instances, shapes_per_level) + else: + pos_inds, labels = None, None + heatmap_channels = self.num_classes + L = len(grids) + num_loc_list = [len(loc) for loc in grids] + strides = torch.cat( + [shapes_per_level.new_ones(num_loc_list[l]) * self.strides[l] for l in range(L)] + ).float() # M + reg_size_ranges = torch.cat( + [ + shapes_per_level.new_tensor(self.sizes_of_interest[l]) + .float() + .view(1, 2) + .expand(num_loc_list[l], 2) + for l in range(L) + ] + ) # M x 2 + grids = torch.cat(grids, dim=0) # M x 2 + M = grids.shape[0] + + reg_targets = [] + flattened_hms = [] + for i in range(len(gt_instances)): # images + boxes = gt_instances[i].gt_boxes.tensor # N x 4 + area = gt_instances[i].gt_boxes.area() # N + gt_classes = gt_instances[i].gt_classes # N in [0, self.num_classes] + + N = boxes.shape[0] + if N == 0: + reg_targets.append(grids.new_zeros((M, 4)) - INF) + flattened_hms.append( + grids.new_zeros((M, 1 if self.only_proposal else heatmap_channels)) + ) + continue + + l = grids[:, 0].view(M, 1) - boxes[:, 0].view(1, N) # M x N + t = grids[:, 1].view(M, 1) - boxes[:, 1].view(1, N) # M x N + r = boxes[:, 2].view(1, N) - grids[:, 0].view(M, 1) # M x N + b = boxes[:, 3].view(1, N) - grids[:, 1].view(M, 1) # M x N + reg_target = torch.stack([l, t, r, b], dim=2) # M x N x 4 + + centers = (boxes[:, [0, 1]] + boxes[:, [2, 3]]) / 2 # N x 2 + centers_expanded = centers.view(1, N, 2).expand(M, N, 2) # M x N x 2 + strides_expanded = strides.view(M, 1, 1).expand(M, N, 2) + centers_discret = ( + (centers_expanded / strides_expanded).int() * strides_expanded + ).float() + strides_expanded / 2 # M x N x 2 + + is_peak = ((grids.view(M, 1, 2).expand(M, N, 2) - centers_discret) ** 2).sum( + dim=2 + ) == 0 # M x N + is_in_boxes = reg_target.min(dim=2)[0] > 0 # M x N + is_center3x3 = self.get_center3x3(grids, centers, strides) & is_in_boxes # M x N + is_cared_in_the_level = self.assign_reg_fpn(reg_target, reg_size_ranges) # M x N + reg_mask = is_center3x3 & is_cared_in_the_level # M x N + + dist2 = ((grids.view(M, 1, 2).expand(M, N, 2) - centers_expanded) ** 2).sum( + dim=2 + ) # M x N + dist2[is_peak] = 0 + radius2 = self.delta**2 * 2 * area # N + radius2 = torch.clamp(radius2, min=self.min_radius**2) + weighted_dist2 = dist2 / radius2.view(1, N).expand(M, N) # M x N + reg_target = self._get_reg_targets( + reg_target, weighted_dist2.clone(), reg_mask, area + ) # M x 4 + + if self.only_proposal: + flattened_hm = self._create_agn_heatmaps_from_dist(weighted_dist2.clone()) # M x 1 + else: + flattened_hm = self._create_heatmaps_from_dist( + weighted_dist2.clone(), gt_classes, channels=heatmap_channels + ) # M x C + + reg_targets.append(reg_target) + flattened_hms.append(flattened_hm) + + # transpose im first training_targets to level first ones + reg_targets = _transpose(reg_targets, num_loc_list) + flattened_hms = _transpose(flattened_hms, num_loc_list) + for l in range(len(reg_targets)): + reg_targets[l] = reg_targets[l] / float(self.strides[l]) + reg_targets = cat([x for x in reg_targets], dim=0) # MB x 4 + flattened_hms = cat([x for x in flattened_hms], dim=0) # MB x C + + return pos_inds, labels, reg_targets, flattened_hms + + def _get_label_inds(self, gt_instances, shapes_per_level): + """ + Inputs: + gt_instances: [n_i], sum n_i = N + shapes_per_level: L x 2 [(h_l, w_l)]_L + Returns: + pos_inds: N' + labels: N' + """ + pos_inds = [] + labels = [] + L = len(self.strides) + B = len(gt_instances) + shapes_per_level = shapes_per_level.long() + loc_per_level = (shapes_per_level[:, 0] * shapes_per_level[:, 1]).long() # L + level_bases = [] + s = 0 + for l in range(L): + level_bases.append(s) + s = s + B * loc_per_level[l] + level_bases = shapes_per_level.new_tensor(level_bases).long() # L + strides_default = shapes_per_level.new_tensor(self.strides).float() # L + for im_i in range(B): + targets_per_im = gt_instances[im_i] + bboxes = targets_per_im.gt_boxes.tensor # n x 4 + n = bboxes.shape[0] + centers = (bboxes[:, [0, 1]] + bboxes[:, [2, 3]]) / 2 # n x 2 + centers = centers.view(n, 1, 2).expand(n, L, 2).contiguous() + if self.not_clamp_box: + h, w = gt_instances[im_i]._image_size + centers[:, :, 0].clamp_(min=0).clamp_(max=w - 1) + centers[:, :, 1].clamp_(min=0).clamp_(max=h - 1) + strides = strides_default.view(1, L, 1).expand(n, L, 2) + centers_inds = (centers / strides).long() # n x L x 2 + Ws = shapes_per_level[:, 1].view(1, L).expand(n, L) + pos_ind = ( + level_bases.view(1, L).expand(n, L) + + im_i * loc_per_level.view(1, L).expand(n, L) + + centers_inds[:, :, 1] * Ws + + centers_inds[:, :, 0] + ) # n x L + is_cared_in_the_level = self.assign_fpn_level(bboxes) + pos_ind = pos_ind[is_cared_in_the_level].view(-1) + label = ( + targets_per_im.gt_classes.view(n, 1).expand(n, L)[is_cared_in_the_level].view(-1) + ) + + pos_inds.append(pos_ind) # n' + labels.append(label) # n' + pos_inds = torch.cat(pos_inds, dim=0).long() + labels = torch.cat(labels, dim=0) + return pos_inds, labels # N, N + + def assign_fpn_level(self, boxes): + """ + Inputs: + boxes: n x 4 + size_ranges: L x 2 + Return: + is_cared_in_the_level: n x L + """ + size_ranges = boxes.new_tensor(self.sizes_of_interest).view( + len(self.sizes_of_interest), 2 + ) # L x 2 + crit = ((boxes[:, 2:] - boxes[:, :2]) ** 2).sum(dim=1) ** 0.5 / 2 # n + n, L = crit.shape[0], size_ranges.shape[0] + crit = crit.view(n, 1).expand(n, L) + size_ranges_expand = size_ranges.view(1, L, 2).expand(n, L, 2) + is_cared_in_the_level = (crit >= size_ranges_expand[:, :, 0]) & ( + crit <= size_ranges_expand[:, :, 1] + ) + return is_cared_in_the_level + + def assign_reg_fpn(self, reg_targets_per_im, size_ranges): + """ + TODO (Xingyi): merge it with assign_fpn_level + Inputs: + reg_targets_per_im: M x N x 4 + size_ranges: M x 2 + """ + crit = ((reg_targets_per_im[:, :, :2] + reg_targets_per_im[:, :, 2:]) ** 2).sum( + dim=2 + ) ** 0.5 / 2 # M x N + is_cared_in_the_level = (crit >= size_ranges[:, [0]]) & (crit <= size_ranges[:, [1]]) + return is_cared_in_the_level + + def _get_reg_targets(self, reg_targets, dist, mask, area): + """ + reg_targets (M x N x 4): long tensor + dist (M x N) + is_*: M x N + """ + dist[mask == 0] = INF * 1.0 + min_dist, min_inds = dist.min(dim=1) # M + reg_targets_per_im = reg_targets[range(len(reg_targets)), min_inds] # M x N x 4 --> M x 4 + reg_targets_per_im[min_dist == INF] = -INF + return reg_targets_per_im + + def _create_heatmaps_from_dist(self, dist, labels: Sequence[str], channels): + """ + dist: M x N + labels: N + return: + heatmaps: M x C + """ + heatmaps = dist.new_zeros((dist.shape[0], channels)) + for c in range(channels): + inds = labels == c # N + if inds.int().sum() == 0: + continue + heatmaps[:, c] = torch.exp(-dist[:, inds].min(dim=1)[0]) + zeros = heatmaps[:, c] < 1e-4 + heatmaps[zeros, c] = 0 + return heatmaps + + def _create_agn_heatmaps_from_dist(self, dist): + """ + TODO (Xingyi): merge it with _create_heatmaps_from_dist + dist: M x N + return: + heatmaps: M x 1 + """ + heatmaps = dist.new_zeros((dist.shape[0], 1)) + heatmaps[:, 0] = torch.exp(-dist.min(dim=1)[0]) + zeros = heatmaps < 1e-4 + heatmaps[zeros] = 0 + return heatmaps + + def _flatten_outputs(self, clss, reg_pred, agn_hm_pred): + # Reshape: (N, F, Hl, Wl) -> (N, Hl, Wl, F) -> (sum_l N*Hl*Wl, F) + clss = ( + cat([x.permute(0, 2, 3, 1).reshape(-1, x.shape[1]) for x in clss], dim=0) + if clss[0] is not None + else None + ) + reg_pred = cat([x.permute(0, 2, 3, 1).reshape(-1, 4) for x in reg_pred], dim=0) + agn_hm_pred = ( + cat([x.permute(0, 2, 3, 1).reshape(-1) for x in agn_hm_pred], dim=0) + if self.with_agn_hm + else None + ) + return clss, reg_pred, agn_hm_pred + + def get_center3x3(self, locations, centers, strides: Sequence[int]): + """ + Inputs: + locations: M x 2 + centers: N x 2 + strides: M + """ + M, N = locations.shape[0], centers.shape[0] + locations_expanded = locations.view(M, 1, 2).expand(M, N, 2) # M x N x 2 + centers_expanded = centers.view(1, N, 2).expand(M, N, 2) # M x N x 2 + strides_expanded = strides.view(M, 1, 1).expand(M, N, 2) # M x N + centers_discret = ( + (centers_expanded / strides_expanded).int() * strides_expanded + ).float() + strides_expanded / 2 # M x N x 2 + dist_x = (locations_expanded[:, :, 0] - centers_discret[:, :, 0]).abs() + dist_y = (locations_expanded[:, :, 1] - centers_discret[:, :, 1]).abs() + return (dist_x <= strides_expanded[:, :, 0]) & (dist_y <= strides_expanded[:, :, 0]) + + @torch.no_grad() + def inference(self, images, clss_per_level, reg_pred_per_level, agn_hm_pred_per_level, grids): + logits_pred = [x.sigmoid() if x is not None else None for x in clss_per_level] + agn_hm_pred_per_level = [ + x.sigmoid() if x is not None else None for x in agn_hm_pred_per_level + ] + + if self.only_proposal: + proposals = self.predict_instances( + grids, + agn_hm_pred_per_level, + reg_pred_per_level, + images.image_sizes, + [None for _ in agn_hm_pred_per_level], + ) + else: + proposals = self.predict_instances( + grids, logits_pred, reg_pred_per_level, images.image_sizes, agn_hm_pred_per_level + ) + if self.as_proposal or self.only_proposal: + for p in range(len(proposals)): + proposals[p].proposal_boxes = proposals[p].get("pred_boxes") + proposals[p].objectness_logits = proposals[p].get("scores") + proposals[p].remove("pred_boxes") + + if self.debug: + debug_test( + [self.denormalizer(x) for x in images], + logits_pred, + reg_pred_per_level, + agn_hm_pred_per_level, + preds=proposals, + vis_thresh=self.vis_thresh, + debug_show_name=False, + ) + return proposals, {} + + @torch.no_grad() + def predict_instances( + self, grids, logits_pred, reg_pred, image_sizes: Sequence[int], agn_hm_pred, is_proposal: bool=False + ): + sampled_boxes = [] + for l in range(len(grids)): + sampled_boxes.append( + self.predict_single_level( + grids[l], + logits_pred[l], + reg_pred[l] * self.strides[l], + image_sizes, + agn_hm_pred[l], + l, + is_proposal=is_proposal, + ) + ) + boxlists = list(zip(*sampled_boxes, strict=False)) + boxlists = [Instances.cat(boxlist) for boxlist in boxlists] + boxlists = self.nms_and_topK(boxlists, nms=not self.not_nms) + return boxlists + + @torch.no_grad() + def predict_single_level( + self, grids, heatmap, reg_pred, image_sizes: Sequence[int], agn_hm, level, is_proposal: bool=False + ): + N, C, H, W = heatmap.shape + # put in the same format as grids + if self.center_nms: + heatmap_nms = nn.functional.max_pool2d(heatmap, (3, 3), stride=1, padding=1) + heatmap = heatmap * (heatmap_nms == heatmap).float() + heatmap = heatmap.permute(0, 2, 3, 1) # N x H x W x C + heatmap = heatmap.reshape(N, -1, C) # N x HW x C + box_regression = reg_pred.view(N, 4, H, W).permute(0, 2, 3, 1) # N x H x W x 4 + box_regression = box_regression.reshape(N, -1, 4) + + candidate_inds = heatmap > self.score_thresh # 0.05 + pre_nms_top_n = candidate_inds.view(N, -1).sum(1) # N + pre_nms_topk = self.pre_nms_topk_train if self.training else self.pre_nms_topk_test + pre_nms_top_n = pre_nms_top_n.clamp(max=pre_nms_topk) # N + + if agn_hm is not None: + agn_hm = agn_hm.view(N, 1, H, W).permute(0, 2, 3, 1) + agn_hm = agn_hm.reshape(N, -1) + heatmap = heatmap * agn_hm[:, :, None] + + results = [] + for i in range(N): + per_box_cls = heatmap[i] # HW x C + per_candidate_inds = candidate_inds[i] # n + per_box_cls = per_box_cls[per_candidate_inds] # n + + per_candidate_nonzeros = per_candidate_inds.nonzero() # n + per_box_loc = per_candidate_nonzeros[:, 0] # n + per_class = per_candidate_nonzeros[:, 1] # n + + per_box_regression = box_regression[i] # HW x 4 + per_box_regression = per_box_regression[per_box_loc] # n x 4 + per_grids = grids[per_box_loc] # n x 2 + + per_pre_nms_top_n = pre_nms_top_n[i] # 1 + + if per_candidate_inds.sum().item() > per_pre_nms_top_n.item(): + per_box_cls, top_k_indices = per_box_cls.topk(per_pre_nms_top_n, sorted=False) + per_class = per_class[top_k_indices] + per_box_regression = per_box_regression[top_k_indices] + per_grids = per_grids[top_k_indices] + + detections = torch.stack( + [ + per_grids[:, 0] - per_box_regression[:, 0], + per_grids[:, 1] - per_box_regression[:, 1], + per_grids[:, 0] + per_box_regression[:, 2], + per_grids[:, 1] + per_box_regression[:, 3], + ], + dim=1, + ) # n x 4 + + # avoid invalid boxes in RoI heads + detections[:, 2] = torch.max(detections[:, 2], detections[:, 0] + 0.01) + detections[:, 3] = torch.max(detections[:, 3], detections[:, 1] + 0.01) + boxlist = Instances(image_sizes[i]) + boxlist.scores = torch.sqrt(per_box_cls) if self.with_agn_hm else per_box_cls # n + # import pdb; pdb.set_trace() + boxlist.pred_boxes = Boxes(detections) + boxlist.pred_classes = per_class + results.append(boxlist) + return results + + @torch.no_grad() + def nms_and_topK(self, boxlists, nms: bool=True): + num_images = len(boxlists) + results = [] + for i in range(num_images): + nms_thresh = self.nms_thresh_train if self.training else self.nms_thresh_test + result = ml_nms(boxlists[i], nms_thresh) if nms else boxlists[i] + if self.debug: + print("#proposals before nms", len(boxlists[i])) + print("#proposals after nms", len(result)) + num_dets = len(result) + post_nms_topk = self.post_nms_topk_train if self.training else self.post_nms_topk_test + if num_dets > post_nms_topk: + cls_scores = result.scores + image_thresh, _ = torch.kthvalue( + cls_scores.float().cpu(), num_dets - post_nms_topk + 1 + ) + keep = cls_scores >= image_thresh.item() + keep = torch.nonzero(keep).squeeze(1) + result = result[keep] + if self.debug: + print("#proposals after filter", len(result)) + results.append(result) + return results + + @torch.no_grad() + def _add_more_pos(self, reg_pred, gt_instances, shapes_per_level): + labels, level_masks, c33_inds, c33_masks, c33_regs = self._get_c33_inds( + gt_instances, shapes_per_level + ) + N, L, K = labels.shape[0], len(self.strides), 9 + c33_inds[c33_masks == 0] = 0 + reg_pred_c33 = reg_pred[c33_inds].detach() # N x L x K + invalid_reg = c33_masks == 0 + c33_regs_expand = c33_regs.view(N * L * K, 4).clamp(min=0) + if N > 0: + with torch.no_grad(): + c33_reg_loss = ( + self.iou_loss( + reg_pred_c33.view(N * L * K, 4), c33_regs_expand, None, reduction="none" + ) + .view(N, L, K) + .detach() + ) # N x L x K + else: + c33_reg_loss = reg_pred_c33.new_zeros((N, L, K)).detach() + c33_reg_loss[invalid_reg] = INF # N x L x K + c33_reg_loss.view(N * L, K)[level_masks.view(N * L), 4] = 0 # real center + c33_reg_loss = c33_reg_loss.view(N, L * K) + if N == 0: + loss_thresh = c33_reg_loss.new_ones(N).float() + else: + loss_thresh = torch.kthvalue(c33_reg_loss, self.more_pos_topk, dim=1)[0] # N + loss_thresh[loss_thresh > self.more_pos_thresh] = self.more_pos_thresh # N + new_pos = c33_reg_loss.view(N, L, K) < loss_thresh.view(N, 1, 1).expand(N, L, K) + pos_inds = c33_inds[new_pos].view(-1) # P + labels = labels.view(N, 1, 1).expand(N, L, K)[new_pos].view(-1) + return pos_inds, labels + + @torch.no_grad() + def _get_c33_inds(self, gt_instances, shapes_per_level): + """ + TODO (Xingyi): The current implementation is ugly. Refactor. + Get the center (and the 3x3 region near center) locations of each objects + Inputs: + gt_instances: [n_i], sum n_i = N + shapes_per_level: L x 2 [(h_l, w_l)]_L + """ + labels = [] + level_masks = [] + c33_inds = [] + c33_masks = [] + c33_regs = [] + L = len(self.strides) + B = len(gt_instances) + shapes_per_level = shapes_per_level.long() + loc_per_level = (shapes_per_level[:, 0] * shapes_per_level[:, 1]).long() # L + level_bases = [] + s = 0 + for l in range(L): + level_bases.append(s) + s = s + B * loc_per_level[l] + level_bases = shapes_per_level.new_tensor(level_bases).long() # L + strides_default = shapes_per_level.new_tensor(self.strides).float() # L + K = 9 + dx = shapes_per_level.new_tensor([-1, 0, 1, -1, 0, 1, -1, 0, 1]).long() + dy = shapes_per_level.new_tensor([-1, -1, -1, 0, 0, 0, 1, 1, 1]).long() + for im_i in range(B): + targets_per_im = gt_instances[im_i] + bboxes = targets_per_im.gt_boxes.tensor # n x 4 + n = bboxes.shape[0] + if n == 0: + continue + centers = (bboxes[:, [0, 1]] + bboxes[:, [2, 3]]) / 2 # n x 2 + centers = centers.view(n, 1, 2).expand(n, L, 2) + + strides = strides_default.view(1, L, 1).expand(n, L, 2) # + centers_inds = (centers / strides).long() # n x L x 2 + center_grids = centers_inds * strides + strides // 2 # n x L x 2 + l = center_grids[:, :, 0] - bboxes[:, 0].view(n, 1).expand(n, L) + t = center_grids[:, :, 1] - bboxes[:, 1].view(n, 1).expand(n, L) + r = bboxes[:, 2].view(n, 1).expand(n, L) - center_grids[:, :, 0] + b = bboxes[:, 3].view(n, 1).expand(n, L) - center_grids[:, :, 1] # n x L + reg = torch.stack([l, t, r, b], dim=2) # n x L x 4 + reg = reg / strides_default.view(1, L, 1).expand(n, L, 4).float() + + Ws = shapes_per_level[:, 1].view(1, L).expand(n, L) + Hs = shapes_per_level[:, 0].view(1, L).expand(n, L) + expand_Ws = Ws.view(n, L, 1).expand(n, L, K) + expand_Hs = Hs.view(n, L, 1).expand(n, L, K) + label = targets_per_im.gt_classes.view(n).clone() + mask = reg.min(dim=2)[0] >= 0 # n x L + mask = mask & self.assign_fpn_level(bboxes) + labels.append(label) # n + level_masks.append(mask) # n x L + + Dy = dy.view(1, 1, K).expand(n, L, K) + Dx = dx.view(1, 1, K).expand(n, L, K) + c33_ind = ( + level_bases.view(1, L, 1).expand(n, L, K) + + im_i * loc_per_level.view(1, L, 1).expand(n, L, K) + + (centers_inds[:, :, 1:2].expand(n, L, K) + Dy) * expand_Ws + + (centers_inds[:, :, 0:1].expand(n, L, K) + Dx) + ) # n x L x K + + c33_mask = ( + ((centers_inds[:, :, 1:2].expand(n, L, K) + dy) < expand_Hs) + & ((centers_inds[:, :, 1:2].expand(n, L, K) + dy) >= 0) + & ((centers_inds[:, :, 0:1].expand(n, L, K) + dx) < expand_Ws) + & ((centers_inds[:, :, 0:1].expand(n, L, K) + dx) >= 0) + ) + # TODO (Xingyi): think about better way to implement this + # Currently it hard codes the 3x3 region + c33_reg = reg.view(n, L, 1, 4).expand(n, L, K, 4).clone() + c33_reg[:, :, [0, 3, 6], 0] -= 1 + c33_reg[:, :, [0, 3, 6], 2] += 1 + c33_reg[:, :, [2, 5, 8], 0] += 1 + c33_reg[:, :, [2, 5, 8], 2] -= 1 + c33_reg[:, :, [0, 1, 2], 1] -= 1 + c33_reg[:, :, [0, 1, 2], 3] += 1 + c33_reg[:, :, [6, 7, 8], 1] += 1 + c33_reg[:, :, [6, 7, 8], 3] -= 1 + c33_mask = c33_mask & (c33_reg.min(dim=3)[0] >= 0) # n x L x K + c33_inds.append(c33_ind) + c33_masks.append(c33_mask) + c33_regs.append(c33_reg) + + if len(level_masks) > 0: + labels = torch.cat(labels, dim=0) + level_masks = torch.cat(level_masks, dim=0) + c33_inds = torch.cat(c33_inds, dim=0).long() + c33_regs = torch.cat(c33_regs, dim=0) + c33_masks = torch.cat(c33_masks, dim=0) + else: + labels = shapes_per_level.new_zeros(0).long() + level_masks = shapes_per_level.new_zeros((0, L)).bool() + c33_inds = shapes_per_level.new_zeros((0, L, K)).long() + c33_regs = shapes_per_level.new_zeros((0, L, K, 4)).float() + c33_masks = shapes_per_level.new_zeros((0, L, K)).bool() + return labels, level_masks, c33_inds, c33_masks, c33_regs # N x L, N x L x K diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/dense_heads/centernet_head.py b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/dense_heads/centernet_head.py new file mode 100644 index 0000000000..e2e1852e27 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/dense_heads/centernet_head.py @@ -0,0 +1,168 @@ +import math + +from detectron2.config import configurable +from detectron2.layers import get_norm +import torch +from torch import nn +from torch.nn import functional as F + +from ..layers.deform_conv import DFConv2d + +__all__ = ["CenterNetHead"] + + +class Scale(nn.Module): + def __init__(self, init_value: float=1.0) -> None: + super().__init__() + self.scale = nn.Parameter(torch.FloatTensor([init_value])) + + def forward(self, input): + return input * self.scale + + +class CenterNetHead(nn.Module): + @configurable + def __init__( + self, + # input_shape: List[ShapeSpec], + in_channels, + num_levels: int, + *, + num_classes: int=80, + with_agn_hm: bool=False, + only_proposal: bool=False, + norm: str="GN", + num_cls_convs: int=4, + num_box_convs: int=4, + num_share_convs: int=0, + use_deformable: bool=False, + prior_prob: float=0.01, + ) -> None: + super().__init__() + self.num_classes = num_classes + self.with_agn_hm = with_agn_hm + self.only_proposal = only_proposal + self.out_kernel = 3 + + head_configs = { + "cls": (num_cls_convs if not self.only_proposal else 0, use_deformable), + "bbox": (num_box_convs, use_deformable), + "share": (num_share_convs, use_deformable), + } + + # in_channels = [s.channels for s in input_shape] + # assert len(set(in_channels)) == 1, \ + # "Each level must have the same channel!" + # in_channels = in_channels[0] + channels = { + "cls": in_channels, + "bbox": in_channels, + "share": in_channels, + } + for head in head_configs: + tower = [] + num_convs, use_deformable = head_configs[head] + channel = channels[head] + for i in range(num_convs): + if use_deformable and i == num_convs - 1: + conv_func = DFConv2d + else: + conv_func = nn.Conv2d + tower.append( + conv_func( + in_channels if i == 0 else channel, + channel, + kernel_size=3, + stride=1, + padding=1, + bias=True, + ) + ) + if norm == "GN" and channel % 32 != 0: + tower.append(nn.GroupNorm(25, channel)) + elif norm != "": + tower.append(get_norm(norm, channel)) + tower.append(nn.ReLU()) + self.add_module(f"{head}_tower", nn.Sequential(*tower)) + + self.bbox_pred = nn.Conv2d( + in_channels, 4, kernel_size=self.out_kernel, stride=1, padding=self.out_kernel // 2 + ) + + self.scales = nn.ModuleList([Scale(init_value=1.0) for _ in range(num_levels)]) + + for modules in [ + self.cls_tower, + self.bbox_tower, + self.share_tower, + self.bbox_pred, + ]: + for l in modules.modules(): + if isinstance(l, nn.Conv2d): + torch.nn.init.normal_(l.weight, std=0.01) + torch.nn.init.constant_(l.bias, 0) + + torch.nn.init.constant_(self.bbox_pred.bias, 8.0) + prior_prob = prior_prob + bias_value = -math.log((1 - prior_prob) / prior_prob) + + if self.with_agn_hm: + self.agn_hm = nn.Conv2d( + in_channels, 1, kernel_size=self.out_kernel, stride=1, padding=self.out_kernel // 2 + ) + torch.nn.init.constant_(self.agn_hm.bias, bias_value) + torch.nn.init.normal_(self.agn_hm.weight, std=0.01) + + if not self.only_proposal: + cls_kernel_size = self.out_kernel + self.cls_logits = nn.Conv2d( + in_channels, + self.num_classes, + kernel_size=cls_kernel_size, + stride=1, + padding=cls_kernel_size // 2, + ) + + torch.nn.init.constant_(self.cls_logits.bias, bias_value) + torch.nn.init.normal_(self.cls_logits.weight, std=0.01) + + @classmethod + def from_config(cls, cfg, input_shape): + ret = { + # 'input_shape': input_shape, + "in_channels": next(s.channels for s in input_shape), + "num_levels": len(input_shape), + "num_classes": cfg.MODEL.CENTERNET.NUM_CLASSES, + "with_agn_hm": cfg.MODEL.CENTERNET.WITH_AGN_HM, + "only_proposal": cfg.MODEL.CENTERNET.ONLY_PROPOSAL, + "norm": cfg.MODEL.CENTERNET.NORM, + "num_cls_convs": cfg.MODEL.CENTERNET.NUM_CLS_CONVS, + "num_box_convs": cfg.MODEL.CENTERNET.NUM_BOX_CONVS, + "num_share_convs": cfg.MODEL.CENTERNET.NUM_SHARE_CONVS, + "use_deformable": cfg.MODEL.CENTERNET.USE_DEFORMABLE, + "prior_prob": cfg.MODEL.CENTERNET.PRIOR_PROB, + } + return ret + + def forward(self, x): + clss = [] + bbox_reg = [] + agn_hms = [] + for l, feature in enumerate(x): + feature = self.share_tower(feature) + cls_tower = self.cls_tower(feature) + bbox_tower = self.bbox_tower(feature) + if not self.only_proposal: + clss.append(self.cls_logits(cls_tower)) + else: + clss.append(None) + + if self.with_agn_hm: + agn_hms.append(self.agn_hm(bbox_tower)) + else: + agn_hms.append(None) + reg = self.bbox_pred(bbox_tower) + reg = self.scales[l](reg) + bbox_reg.append(F.relu(reg)) + + return clss, bbox_reg, agn_hms diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/dense_heads/utils.py b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/dense_heads/utils.py new file mode 100644 index 0000000000..ea962943ca --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/dense_heads/utils.py @@ -0,0 +1,32 @@ +from detectron2.utils.comm import get_world_size +import torch + +# from .data import CenterNetCrop + +__all__ = ["_transpose", "reduce_sum"] + +INF = 1000000000 + + +def _transpose(training_targets, num_loc_list): + """ + This function is used to transpose image first training targets to + level first ones + :return: level first training targets + """ + for im_i in range(len(training_targets)): + training_targets[im_i] = torch.split(training_targets[im_i], num_loc_list, dim=0) + + targets_level_first = [] + for targets_per_level in zip(*training_targets, strict=False): + targets_level_first.append(torch.cat(targets_per_level, dim=0)) + return targets_level_first + + +def reduce_sum(tensor): + world_size = get_world_size() + if world_size < 2: + return tensor + tensor = tensor.clone() + torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM) + return tensor diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/layers/deform_conv.py b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/layers/deform_conv.py new file mode 100644 index 0000000000..643660c6bc --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/layers/deform_conv.py @@ -0,0 +1,114 @@ +from detectron2.layers import Conv2d +import torch +from torch import nn + + +class _NewEmptyTensorOp(torch.autograd.Function): + @staticmethod + def forward(ctx, x, new_shape): + ctx.shape = x.shape + return x.new_empty(new_shape) + + @staticmethod + def backward(ctx, grad): + shape = ctx.shape + return _NewEmptyTensorOp.apply(grad, shape), None + + +class DFConv2d(nn.Module): + """Deformable convolutional layer""" + + def __init__( + self, + in_channels, + out_channels, + with_modulated_dcn: bool=True, + kernel_size: int=3, + stride: int=1, + groups: int=1, + dilation: int=1, + deformable_groups: int=1, + bias: bool=False, + padding=None, + ) -> None: + super().__init__() + if isinstance(kernel_size, list | tuple): + assert isinstance(stride, list | tuple) + assert isinstance(dilation, list | tuple) + assert len(kernel_size) == 2 + assert len(stride) == 2 + assert len(dilation) == 2 + padding = ( + dilation[0] * (kernel_size[0] - 1) // 2, + dilation[1] * (kernel_size[1] - 1) // 2, + ) + offset_base_channels = kernel_size[0] * kernel_size[1] + else: + padding = dilation * (kernel_size - 1) // 2 + offset_base_channels = kernel_size * kernel_size + if with_modulated_dcn: + from detectron2.layers.deform_conv import ModulatedDeformConv + + offset_channels = offset_base_channels * 3 # default: 27 + conv_block = ModulatedDeformConv + else: + from detectron2.layers.deform_conv import DeformConv + + offset_channels = offset_base_channels * 2 # default: 18 + conv_block = DeformConv + self.offset = Conv2d( + in_channels, + deformable_groups * offset_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=1, + dilation=dilation, + ) + nn.init.constant_(self.offset.weight, 0) + nn.init.constant_(self.offset.bias, 0) + """ + for l in [self.offset, ]: + nn.init.kaiming_uniform_(l.weight, a=1) + torch.nn.init.constant_(l.bias, 0.) + """ + self.conv = conv_block( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + deformable_groups=deformable_groups, + bias=bias, + ) + self.with_modulated_dcn = with_modulated_dcn + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.offset_split = offset_base_channels * deformable_groups * 2 + + def forward(self, x, return_offset: bool=False): + if x.numel() > 0: + if not self.with_modulated_dcn: + offset_mask = self.offset(x) + x = self.conv(x, offset_mask) + else: + offset_mask = self.offset(x) + offset = offset_mask[:, : self.offset_split, :, :] + mask = offset_mask[:, self.offset_split :, :, :].sigmoid() + x = self.conv(x, offset, mask) + if return_offset: + return x, offset_mask + return x + # get output shape + output_shape = [ + (i + 2 * p - (di * (k - 1) + 1)) // d + 1 + for i, p, di, k, d in zip( + x.shape[-2:], self.padding, self.dilation, self.kernel_size, self.stride, strict=False + ) + ] + output_shape = [x.shape[0], self.conv.weight.shape[0], *output_shape] + return _NewEmptyTensorOp.apply(x, output_shape) diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/layers/heatmap_focal_loss.py b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/layers/heatmap_focal_loss.py new file mode 100644 index 0000000000..50ccf371c9 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/layers/heatmap_focal_loss.py @@ -0,0 +1,91 @@ +import torch +from typing import Sequence + + +# TODO: merge these two function +def heatmap_focal_loss( + inputs, + targets, + pos_inds, + labels: Sequence[str], + alpha: float = -1, + beta: float = 4, + gamma: float = 2, + reduction: str = "sum", + sigmoid_clamp: float = 1e-4, + ignore_high_fp: float = -1.0, +): + """ + Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. + Args: + inputs: (sum_l N*Hl*Wl, C) + targets: (sum_l N*Hl*Wl, C) + pos_inds: N + labels: N + Returns: + Loss tensor with the reduction option applied. + """ + pred = torch.clamp(inputs.sigmoid_(), min=sigmoid_clamp, max=1 - sigmoid_clamp) + neg_weights = torch.pow(1 - targets, beta) + pos_pred_pix = pred[pos_inds] # N x C + pos_pred = pos_pred_pix.gather(1, labels.unsqueeze(1)) + pos_loss = torch.log(pos_pred) * torch.pow(1 - pos_pred, gamma) + neg_loss = torch.log(1 - pred) * torch.pow(pred, gamma) * neg_weights + + if ignore_high_fp > 0: + not_high_fp = (pred < ignore_high_fp).float() + neg_loss = not_high_fp * neg_loss + + if reduction == "sum": + pos_loss = pos_loss.sum() + neg_loss = neg_loss.sum() + + if alpha >= 0: + pos_loss = alpha * pos_loss + neg_loss = (1 - alpha) * neg_loss + + return -pos_loss, -neg_loss + + +heatmap_focal_loss_jit = torch.jit.script(heatmap_focal_loss) +# heatmap_focal_loss_jit = heatmap_focal_loss + + +def binary_heatmap_focal_loss( + inputs, + targets, + pos_inds, + alpha: float = -1, + beta: float = 4, + gamma: float = 2, + sigmoid_clamp: float = 1e-4, + ignore_high_fp: float = -1.0, +): + """ + Args: + inputs: (sum_l N*Hl*Wl,) + targets: (sum_l N*Hl*Wl,) + pos_inds: N + Returns: + Loss tensor with the reduction option applied. + """ + pred = torch.clamp(inputs.sigmoid_(), min=sigmoid_clamp, max=1 - sigmoid_clamp) + neg_weights = torch.pow(1 - targets, beta) + pos_pred = pred[pos_inds] # N + pos_loss = torch.log(pos_pred) * torch.pow(1 - pos_pred, gamma) + neg_loss = torch.log(1 - pred) * torch.pow(pred, gamma) * neg_weights + if ignore_high_fp > 0: + not_high_fp = (pred < ignore_high_fp).float() + neg_loss = not_high_fp * neg_loss + + pos_loss = -pos_loss.sum() + neg_loss = -neg_loss.sum() + + if alpha >= 0: + pos_loss = alpha * pos_loss + neg_loss = (1 - alpha) * neg_loss + + return pos_loss, neg_loss + + +binary_heatmap_focal_loss_jit = torch.jit.script(binary_heatmap_focal_loss) diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/layers/iou_loss.py b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/layers/iou_loss.py new file mode 100644 index 0000000000..55fa2a186d --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/layers/iou_loss.py @@ -0,0 +1,115 @@ +import torch +from torch import nn + + +class IOULoss(nn.Module): + def __init__(self, loc_loss_type: str="iou") -> None: + super().__init__() + self.loc_loss_type = loc_loss_type + + def forward(self, pred, target, weight=None, reduction: str="sum"): + pred_left = pred[:, 0] + pred_top = pred[:, 1] + pred_right = pred[:, 2] + pred_bottom = pred[:, 3] + + target_left = target[:, 0] + target_top = target[:, 1] + target_right = target[:, 2] + target_bottom = target[:, 3] + + target_aera = (target_left + target_right) * (target_top + target_bottom) + pred_aera = (pred_left + pred_right) * (pred_top + pred_bottom) + + w_intersect = torch.min(pred_left, target_left) + torch.min(pred_right, target_right) + h_intersect = torch.min(pred_bottom, target_bottom) + torch.min(pred_top, target_top) + + g_w_intersect = torch.max(pred_left, target_left) + torch.max(pred_right, target_right) + g_h_intersect = torch.max(pred_bottom, target_bottom) + torch.max(pred_top, target_top) + ac_uion = g_w_intersect * g_h_intersect + + area_intersect = w_intersect * h_intersect + area_union = target_aera + pred_aera - area_intersect + + ious = (area_intersect + 1.0) / (area_union + 1.0) + gious = ious - (ac_uion - area_union) / ac_uion + if self.loc_loss_type == "iou": + losses = -torch.log(ious) + elif self.loc_loss_type == "linear_iou": + losses = 1 - ious + elif self.loc_loss_type == "giou": + losses = 1 - gious + else: + raise NotImplementedError + + if weight is not None: + losses = losses * weight + else: + losses = losses + + if reduction == "sum": + return losses.sum() + elif reduction == "batch": + return losses.sum(dim=[1]) + elif reduction == "none": + return losses + else: + raise NotImplementedError + + +def giou_loss( + boxes1: torch.Tensor, + boxes2: torch.Tensor, + reduction: str = "none", + eps: float = 1e-7, +) -> torch.Tensor: + """ + Generalized Intersection over Union Loss (Hamid Rezatofighi et. al) + https://arxiv.org/abs/1902.09630 + Gradient-friendly IoU loss with an additional penalty that is non-zero when the + boxes do not overlap and scales with the size of their smallest enclosing box. + This loss is symmetric, so the boxes1 and boxes2 arguments are interchangeable. + Args: + boxes1, boxes2 (Tensor): box locations in XYXY format, shape (N, 4) or (4,). + reduction: 'none' | 'mean' | 'sum' + 'none': No reduction will be applied to the output. + 'mean': The output will be averaged. + 'sum': The output will be summed. + eps (float): small number to prevent division by zero + """ + + x1, y1, x2, y2 = boxes1.unbind(dim=-1) + x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1) + + assert (x2 >= x1).all(), "bad box: x1 larger than x2" + assert (y2 >= y1).all(), "bad box: y1 larger than y2" + + # Intersection keypoints + xkis1 = torch.max(x1, x1g) + ykis1 = torch.max(y1, y1g) + xkis2 = torch.min(x2, x2g) + ykis2 = torch.min(y2, y2g) + + intsctk = torch.zeros_like(x1) + mask = (ykis2 > ykis1) & (xkis2 > xkis1) + intsctk[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask]) + unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsctk + iouk = intsctk / (unionk + eps) + + # smallest enclosing box + xc1 = torch.min(x1, x1g) + yc1 = torch.min(y1, y1g) + xc2 = torch.max(x2, x2g) + yc2 = torch.max(y2, y2g) + + area_c = (xc2 - xc1) * (yc2 - yc1) + miouk = iouk - ((area_c - unionk) / (area_c + eps)) + + loss = 1 - miouk + + if reduction == "mean": + loss = loss.mean() + elif reduction == "sum": + loss = loss.sum() + + return loss diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/layers/ml_nms.py b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/layers/ml_nms.py new file mode 100644 index 0000000000..429c986cfe --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/layers/ml_nms.py @@ -0,0 +1,29 @@ +from detectron2.layers import batched_nms + + +def ml_nms(boxlist, nms_thresh, max_proposals=-1, score_field: str="scores", label_field: str="labels"): + """ + Performs non-maximum suppression on a boxlist, with scores specified + in a boxlist field via score_field. + Arguments: + boxlist(BoxList) + nms_thresh (float) + max_proposals (int): if > 0, then only the top max_proposals are kept + after non-maximum suppression + score_field (str) + """ + if nms_thresh <= 0: + return boxlist + if boxlist.has("pred_boxes"): + boxes = boxlist.pred_boxes.tensor + labels = boxlist.pred_classes + else: + boxes = boxlist.proposal_boxes.tensor + labels = boxlist.proposal_boxes.tensor.new_zeros(len(boxlist.proposal_boxes.tensor)) + scores = boxlist.scores + + keep = batched_nms(boxes, scores, labels, nms_thresh) + if max_proposals > 0: + keep = keep[:max_proposals] + boxlist = boxlist[keep] + return boxlist diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/meta_arch/centernet_detector.py b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/meta_arch/centernet_detector.py new file mode 100644 index 0000000000..02cd3da416 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/meta_arch/centernet_detector.py @@ -0,0 +1,63 @@ +from detectron2.modeling import build_backbone, build_proposal_generator, detector_postprocess +from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY +from detectron2.structures import ImageList +import torch +from torch import nn + + +@META_ARCH_REGISTRY.register() +class CenterNetDetector(nn.Module): + def __init__(self, cfg) -> None: + super().__init__() + self.mean, self.std = cfg.MODEL.PIXEL_MEAN, cfg.MODEL.PIXEL_STD + self.register_buffer("pixel_mean", torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(-1, 1, 1)) + self.register_buffer("pixel_std", torch.Tensor(cfg.MODEL.PIXEL_STD).view(-1, 1, 1)) + + self.backbone = build_backbone(cfg) + self.proposal_generator = build_proposal_generator( + cfg, self.backbone.output_shape() + ) # TODO: change to a more precise name + + def forward(self, batched_inputs): + if not self.training: + return self.inference(batched_inputs) + images = self.preprocess_image(batched_inputs) + features = self.backbone(images.tensor) + gt_instances = [x["instances"].to(self.device) for x in batched_inputs] + + _, proposal_losses = self.proposal_generator(images, features, gt_instances) + return proposal_losses + + @property + def device(self): + return self.pixel_mean.device + + @torch.no_grad() + def inference(self, batched_inputs, do_postprocess: bool=True): + images = self.preprocess_image(batched_inputs) + inp = images.tensor + features = self.backbone(inp) + proposals, _ = self.proposal_generator(images, features, None) + + processed_results = [] + for results_per_image, input_per_image, image_size in zip( + proposals, batched_inputs, images.image_sizes, strict=False + ): + if do_postprocess: + height = input_per_image.get("height", image_size[0]) + width = input_per_image.get("width", image_size[1]) + r = detector_postprocess(results_per_image, height, width) + processed_results.append({"instances": r}) + else: + r = results_per_image + processed_results.append(r) + return processed_results + + def preprocess_image(self, batched_inputs): + """ + Normalize, pad and batch the input images. + """ + images = [x["image"].to(self.device) for x in batched_inputs] + images = [(x - self.pixel_mean) / self.pixel_std for x in images] + images = ImageList.from_tensors(images, self.backbone.size_divisibility) + return images diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/roi_heads/custom_fast_rcnn.py b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/roi_heads/custom_fast_rcnn.py new file mode 100644 index 0000000000..b48b5447ac --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/roi_heads/custom_fast_rcnn.py @@ -0,0 +1,151 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# Part of the code is from https://github.com/tztztztztz/eql.detectron2/blob/master/projects/EQL/eql/fast_rcnn.py +import math + +from detectron2.layers import ShapeSpec, cat +from detectron2.modeling.roi_heads.fast_rcnn import ( + FastRCNNOutputLayers, + _log_classification_stats, + fast_rcnn_inference, +) +import torch +from torch import nn +from torch.nn import functional as F + +from .fed_loss import get_fed_loss_inds, load_class_freq + +__all__ = ["CustomFastRCNNOutputLayers"] + + +class CustomFastRCNNOutputLayers(FastRCNNOutputLayers): + def __init__(self, cfg, input_shape: ShapeSpec, **kwargs) -> None: + super().__init__(cfg, input_shape, **kwargs) + self.use_sigmoid_ce = cfg.MODEL.ROI_BOX_HEAD.USE_SIGMOID_CE + if self.use_sigmoid_ce: + prior_prob = cfg.MODEL.ROI_BOX_HEAD.PRIOR_PROB + bias_value = -math.log((1 - prior_prob) / prior_prob) + nn.init.constant_(self.cls_score.bias, bias_value) + + self.cfg = cfg + self.use_fed_loss = cfg.MODEL.ROI_BOX_HEAD.USE_FED_LOSS + if self.use_fed_loss: + self.fed_loss_num_cat = cfg.MODEL.ROI_BOX_HEAD.FED_LOSS_NUM_CAT + self.register_buffer( + "freq_weight", + load_class_freq( + cfg.MODEL.ROI_BOX_HEAD.CAT_FREQ_PATH, + cfg.MODEL.ROI_BOX_HEAD.FED_LOSS_FREQ_WEIGHT, + ), + ) + + def losses(self, predictions, proposals): + """ + enable advanced loss + """ + scores, proposal_deltas = predictions + gt_classes = ( + cat([p.gt_classes for p in proposals], dim=0) if len(proposals) else torch.empty(0) + ) + _log_classification_stats(scores, gt_classes) + + if len(proposals): + proposal_boxes = cat([p.proposal_boxes.tensor for p in proposals], dim=0) # Nx4 + assert not proposal_boxes.requires_grad, "Proposals should not require gradients!" + gt_boxes = cat( + [(p.gt_boxes if p.has("gt_boxes") else p.proposal_boxes).tensor for p in proposals], + dim=0, + ) + else: + proposal_boxes = gt_boxes = torch.empty((0, 4), device=proposal_deltas.device) + + if self.use_sigmoid_ce: + loss_cls = self.sigmoid_cross_entropy_loss(scores, gt_classes) + else: + loss_cls = self.softmax_cross_entropy_loss(scores, gt_classes) + return { + "loss_cls": loss_cls, + "loss_box_reg": self.box_reg_loss( + proposal_boxes, gt_boxes, proposal_deltas, gt_classes + ), + } + + def sigmoid_cross_entropy_loss(self, pred_class_logits, gt_classes): + if pred_class_logits.numel() == 0: + return pred_class_logits.new_zeros([1])[0] # This is more robust than .sum() * 0. + + B = pred_class_logits.shape[0] + C = pred_class_logits.shape[1] - 1 + + target = pred_class_logits.new_zeros(B, C + 1) + target[range(len(gt_classes)), gt_classes] = 1 # B x (C + 1) + target = target[:, :C] # B x C + + weight = 1 + if self.use_fed_loss and (self.freq_weight is not None): # fedloss + appeared = get_fed_loss_inds( + gt_classes, num_sample_cats=self.fed_loss_num_cat, C=C, weight=self.freq_weight + ) + appeared_mask = appeared.new_zeros(C + 1) + appeared_mask[appeared] = 1 # C + 1 + appeared_mask = appeared_mask[:C] + fed_w = appeared_mask.view(1, C).expand(B, C) + weight = weight * fed_w.float() + + cls_loss = F.binary_cross_entropy_with_logits( + pred_class_logits[:, :-1], target, reduction="none" + ) # B x C + loss = torch.sum(cls_loss * weight) / B + return loss + + def softmax_cross_entropy_loss(self, pred_class_logits, gt_classes): + """ + change _no_instance handling + """ + if pred_class_logits.numel() == 0: + return pred_class_logits.new_zeros([1])[0] + + if self.use_fed_loss and (self.freq_weight is not None): + C = pred_class_logits.shape[1] - 1 + appeared = get_fed_loss_inds( + gt_classes, num_sample_cats=self.fed_loss_num_cat, C=C, weight=self.freq_weight + ) + appeared_mask = appeared.new_zeros(C + 1).float() + appeared_mask[appeared] = 1.0 # C + 1 + appeared_mask[C] = 1.0 + loss = F.cross_entropy( + pred_class_logits, gt_classes, weight=appeared_mask, reduction="mean" + ) + else: + loss = F.cross_entropy(pred_class_logits, gt_classes, reduction="mean") + return loss + + def inference(self, predictions, proposals): + """ + enable use proposal boxes + """ + boxes = self.predict_boxes(predictions, proposals) + scores = self.predict_probs(predictions, proposals) + if self.cfg.MODEL.ROI_BOX_HEAD.MULT_PROPOSAL_SCORE: + proposal_scores = [p.get("objectness_logits") for p in proposals] + scores = [(s * ps[:, None]) ** 0.5 for s, ps in zip(scores, proposal_scores, strict=False)] + image_shapes = [x.image_size for x in proposals] + return fast_rcnn_inference( + boxes, + scores, + image_shapes, + self.test_score_thresh, + self.test_nms_thresh, + self.test_topk_per_image, + ) + + def predict_probs(self, predictions, proposals): + """ + support sigmoid + """ + scores, _ = predictions + num_inst_per_image = [len(p) for p in proposals] + if self.use_sigmoid_ce: + probs = scores.sigmoid() + else: + probs = F.softmax(scores, dim=-1) + return probs.split(num_inst_per_image, dim=0) diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/roi_heads/custom_roi_heads.py b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/roi_heads/custom_roi_heads.py new file mode 100644 index 0000000000..d0478de2f3 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/roi_heads/custom_roi_heads.py @@ -0,0 +1,182 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from detectron2.modeling.box_regression import Box2BoxTransform +from detectron2.modeling.roi_heads.cascade_rcnn import CascadeROIHeads +from detectron2.modeling.roi_heads.fast_rcnn import fast_rcnn_inference +from detectron2.modeling.roi_heads.roi_heads import ROI_HEADS_REGISTRY, StandardROIHeads +from detectron2.utils.events import get_event_storage +import torch + +from .custom_fast_rcnn import CustomFastRCNNOutputLayers + + +@ROI_HEADS_REGISTRY.register() +class CustomROIHeads(StandardROIHeads): + @classmethod + def _init_box_head(cls, cfg, input_shape): + ret = super()._init_box_head(cfg, input_shape) + del ret["box_predictor"] + ret["box_predictor"] = CustomFastRCNNOutputLayers(cfg, ret["box_head"].output_shape) + cls.debug = cfg.DEBUG + if cls.debug: + cls.debug_show_name = cfg.DEBUG_SHOW_NAME + cls.save_debug = cfg.SAVE_DEBUG + cls.vis_thresh = cfg.VIS_THRESH + cls.pixel_mean = ( + torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(torch.device(cfg.MODEL.DEVICE)).view(3, 1, 1) + ) + cls.pixel_std = ( + torch.Tensor(cfg.MODEL.PIXEL_STD).to(torch.device(cfg.MODEL.DEVICE)).view(3, 1, 1) + ) + return ret + + def forward(self, images, features, proposals, targets=None): + """ + enable debug + """ + if not self.debug: + del images + if self.training: + assert targets + proposals = self.label_and_sample_proposals(proposals, targets) + del targets + + if self.training: + losses = self._forward_box(features, proposals) + losses.update(self._forward_mask(features, proposals)) + losses.update(self._forward_keypoint(features, proposals)) + return proposals, losses + else: + pred_instances = self._forward_box(features, proposals) + pred_instances = self.forward_with_given_boxes(features, pred_instances) + if self.debug: + from ..debug import debug_second_stage + + def denormalizer(x): + return x * self.pixel_std + self.pixel_mean + debug_second_stage( + [denormalizer(images[0].clone())], + pred_instances, + proposals=proposals, + debug_show_name=self.debug_show_name, + ) + return pred_instances, {} + + +@ROI_HEADS_REGISTRY.register() +class CustomCascadeROIHeads(CascadeROIHeads): + @classmethod + def _init_box_head(cls, cfg, input_shape): + cls.mult_proposal_score = cfg.MODEL.ROI_BOX_HEAD.MULT_PROPOSAL_SCORE + ret = super()._init_box_head(cfg, input_shape) + del ret["box_predictors"] + cascade_bbox_reg_weights = cfg.MODEL.ROI_BOX_CASCADE_HEAD.BBOX_REG_WEIGHTS + box_predictors = [] + for box_head, bbox_reg_weights in zip(ret["box_heads"], cascade_bbox_reg_weights, strict=False): + box_predictors.append( + CustomFastRCNNOutputLayers( + cfg, + box_head.output_shape, + box2box_transform=Box2BoxTransform(weights=bbox_reg_weights), + ) + ) + ret["box_predictors"] = box_predictors + cls.debug = cfg.DEBUG + if cls.debug: + cls.debug_show_name = cfg.DEBUG_SHOW_NAME + cls.save_debug = cfg.SAVE_DEBUG + cls.vis_thresh = cfg.VIS_THRESH + cls.pixel_mean = ( + torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(torch.device(cfg.MODEL.DEVICE)).view(3, 1, 1) + ) + cls.pixel_std = ( + torch.Tensor(cfg.MODEL.PIXEL_STD).to(torch.device(cfg.MODEL.DEVICE)).view(3, 1, 1) + ) + return ret + + def _forward_box(self, features, proposals, targets=None): + """ + Add mult proposal scores at testing + """ + if (not self.training) and self.mult_proposal_score: + if len(proposals) > 0 and proposals[0].has("scores"): + proposal_scores = [p.get("scores") for p in proposals] + else: + proposal_scores = [p.get("objectness_logits") for p in proposals] + + features = [features[f] for f in self.box_in_features] + head_outputs = [] # (predictor, predictions, proposals) + prev_pred_boxes = None + image_sizes = [x.image_size for x in proposals] + for k in range(self.num_cascade_stages): + if k > 0: + proposals = self._create_proposals_from_boxes(prev_pred_boxes, image_sizes) + if self.training: + proposals = self._match_and_label_boxes(proposals, k, targets) + predictions = self._run_stage(features, proposals, k) + prev_pred_boxes = self.box_predictor[k].predict_boxes(predictions, proposals) + head_outputs.append((self.box_predictor[k], predictions, proposals)) + + if self.training: + losses = {} + storage = get_event_storage() + for stage, (predictor, predictions, proposals) in enumerate(head_outputs): + with storage.name_scope(f"stage{stage}"): + stage_losses = predictor.losses(predictions, proposals) + losses.update({k + f"_stage{stage}": v for k, v in stage_losses.items()}) + return losses + else: + # Each is a list[Tensor] of length #image. Each tensor is Ri x (K+1) + scores_per_stage = [h[0].predict_probs(h[1], h[2]) for h in head_outputs] + scores = [ + sum(list(scores_per_image)) * (1.0 / self.num_cascade_stages) + for scores_per_image in zip(*scores_per_stage, strict=False) + ] + + if self.mult_proposal_score: + scores = [(s * ps[:, None]) ** 0.5 for s, ps in zip(scores, proposal_scores, strict=False)] + + predictor, predictions, proposals = head_outputs[-1] + boxes = predictor.predict_boxes(predictions, proposals) + pred_instances, _ = fast_rcnn_inference( + boxes, + scores, + image_sizes, + predictor.test_score_thresh, + predictor.test_nms_thresh, + predictor.test_topk_per_image, + ) + + return pred_instances + + def forward(self, images, features, proposals, targets=None): + """ + enable debug + """ + if not self.debug: + del images + if self.training: + proposals = self.label_and_sample_proposals(proposals, targets) + + if self.training: + losses = self._forward_box(features, proposals, targets) + losses.update(self._forward_mask(features, proposals)) + losses.update(self._forward_keypoint(features, proposals)) + return proposals, losses + else: + # import pdb; pdb.set_trace() + pred_instances = self._forward_box(features, proposals) + pred_instances = self.forward_with_given_boxes(features, pred_instances) + if self.debug: + from ..debug import debug_second_stage + + def denormalizer(x): + return x * self.pixel_std + self.pixel_mean + debug_second_stage( + [denormalizer(x.clone()) for x in images], + pred_instances, + proposals=proposals, + save_debug=self.save_debug, + debug_show_name=self.debug_show_name, + vis_thresh=self.vis_thresh, + ) + return pred_instances, {} diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/roi_heads/fed_loss.py b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/roi_heads/fed_loss.py new file mode 100644 index 0000000000..8a41607ea9 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/roi_heads/fed_loss.py @@ -0,0 +1,25 @@ +import json + +import torch + + +def load_class_freq(path: str="datasets/lvis/lvis_v1_train_cat_info.json", freq_weight: float=0.5): + cat_info = json.load(open(path)) + cat_info = torch.tensor([c["image_count"] for c in sorted(cat_info, key=lambda x: x["id"])]) + freq_weight = cat_info.float() ** freq_weight + return freq_weight + + +def get_fed_loss_inds(gt_classes, num_sample_cats: int=50, C: int=1203, weight=None, fed_cls_inds=-1): + appeared = torch.unique(gt_classes) # C' + prob = appeared.new_ones(C + 1).float() + prob[-1] = 0 + if len(appeared) < num_sample_cats: + if weight is not None: + prob[:C] = weight.float().clone() + prob[appeared] = 0 + if fed_cls_inds > 0: + prob[fed_cls_inds:] = 0 + more_appeared = torch.multinomial(prob, num_sample_cats - len(appeared), replacement=False) + appeared = torch.cat([appeared, more_appeared]) + return appeared diff --git a/dimos/models/Detic/third_party/CenterNet2/configs/Base-CenterNet-FPN.yaml b/dimos/models/Detic/third_party/CenterNet2/configs/Base-CenterNet-FPN.yaml new file mode 100644 index 0000000000..bef3dc10de --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/configs/Base-CenterNet-FPN.yaml @@ -0,0 +1,28 @@ +MODEL: + META_ARCHITECTURE: "CenterNetDetector" + PROPOSAL_GENERATOR: + NAME: "CenterNet" + BACKBONE: + NAME: "build_p67_resnet_fpn_backbone" + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + RESNETS: + DEPTH: 50 + OUT_FEATURES: ["res3", "res4", "res5"] + FPN: + IN_FEATURES: ["res3", "res4", "res5"] +DATASETS: + TRAIN: ("coco_2017_train",) + TEST: ("coco_2017_val",) +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.01 + STEPS: (60000, 80000) + MAX_ITER: 90000 + CHECKPOINT_PERIOD: 1000000000 + WARMUP_ITERS: 4000 + WARMUP_FACTOR: 0.00025 + CLIP_GRADIENTS: + ENABLED: True +INPUT: + MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) +OUTPUT_DIR: "./output/CenterNet2/auto" diff --git a/dimos/models/Detic/third_party/CenterNet2/configs/Base-CenterNet2.yaml b/dimos/models/Detic/third_party/CenterNet2/configs/Base-CenterNet2.yaml new file mode 100644 index 0000000000..6893723101 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/configs/Base-CenterNet2.yaml @@ -0,0 +1,56 @@ +MODEL: + META_ARCHITECTURE: "GeneralizedRCNN" + PROPOSAL_GENERATOR: + NAME: "CenterNet" + BACKBONE: + NAME: "build_p67_resnet_fpn_backbone" + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + RESNETS: + DEPTH: 50 + OUT_FEATURES: ["res3", "res4", "res5"] + FPN: + IN_FEATURES: ["res3", "res4", "res5"] + ROI_HEADS: + NAME: CustomCascadeROIHeads + IN_FEATURES: ["p3", "p4", "p5", "p6", "p7"] + IOU_THRESHOLDS: [0.6] + NMS_THRESH_TEST: 0.7 + ROI_BOX_CASCADE_HEAD: + IOUS: [0.6, 0.7, 0.8] + ROI_BOX_HEAD: + NAME: "FastRCNNConvFCHead" + NUM_FC: 2 + POOLER_RESOLUTION: 7 + CLS_AGNOSTIC_BBOX_REG: True + MULT_PROPOSAL_SCORE: True + CENTERNET: + REG_WEIGHT: 1. + NOT_NORM_REG: True + ONLY_PROPOSAL: True + WITH_AGN_HM: True + INFERENCE_TH: 0.0001 + PRE_NMS_TOPK_TRAIN: 4000 + POST_NMS_TOPK_TRAIN: 2000 + PRE_NMS_TOPK_TEST: 1000 + POST_NMS_TOPK_TEST: 256 + NMS_TH_TRAIN: 0.9 + NMS_TH_TEST: 0.9 + POS_WEIGHT: 0.5 + NEG_WEIGHT: 0.5 + IGNORE_HIGH_FP: 0.85 +DATASETS: + TRAIN: ("coco_2017_train",) + TEST: ("coco_2017_val",) +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.02 + STEPS: (60000, 80000) + MAX_ITER: 90000 + CHECKPOINT_PERIOD: 1000000000 + WARMUP_ITERS: 4000 + WARMUP_FACTOR: 0.00025 + CLIP_GRADIENTS: + ENABLED: True +INPUT: + MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) +OUTPUT_DIR: "./output/CenterNet2/auto" diff --git a/dimos/models/Detic/third_party/CenterNet2/configs/Base_S4_DLA.yaml b/dimos/models/Detic/third_party/CenterNet2/configs/Base_S4_DLA.yaml new file mode 100644 index 0000000000..7e01be7e55 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/configs/Base_S4_DLA.yaml @@ -0,0 +1,40 @@ +MODEL: + META_ARCHITECTURE: "CenterNetDetector" + PROPOSAL_GENERATOR: + NAME: "CenterNet" + PIXEL_STD: [57.375, 57.120, 58.395] + BACKBONE: + NAME: "build_dla_backbone" + DLA: + NORM: "BN" + CENTERNET: + IN_FEATURES: ["dla2"] + FPN_STRIDES: [4] + SOI: [[0, 1000000]] + NUM_CLS_CONVS: 1 + NUM_BOX_CONVS: 1 + REG_WEIGHT: 1. + MORE_POS: True + HM_FOCAL_ALPHA: 0.25 +DATASETS: + TRAIN: ("coco_2017_train",) + TEST: ("coco_2017_val",) +SOLVER: + LR_SCHEDULER_NAME: "WarmupCosineLR" + MAX_ITER: 90000 + BASE_LR: 0.04 + IMS_PER_BATCH: 64 + WEIGHT_DECAY: 0.0001 + CHECKPOINT_PERIOD: 1000000 + CLIP_GRADIENTS: + ENABLED: True +INPUT: + CUSTOM_AUG: EfficientDetResizeCrop + TRAIN_SIZE: 640 + MIN_SIZE_TEST: 608 + MAX_SIZE_TEST: 900 +TEST: + EVAL_PERIOD: 7500 +DATALOADER: + NUM_WORKERS: 8 +OUTPUT_DIR: "output/CenterNet2/auto" diff --git a/dimos/models/Detic/third_party/CenterNet2/configs/CenterNet-FPN_R50_1x.yaml b/dimos/models/Detic/third_party/CenterNet2/configs/CenterNet-FPN_R50_1x.yaml new file mode 100644 index 0000000000..6ea7d9b703 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/configs/CenterNet-FPN_R50_1x.yaml @@ -0,0 +1,4 @@ +_BASE_: "Base-CenterNet-FPN.yaml" +MODEL: + CENTERNET: + MORE_POS: True \ No newline at end of file diff --git a/dimos/models/Detic/third_party/CenterNet2/configs/CenterNet-S4_DLA_8x.yaml b/dimos/models/Detic/third_party/CenterNet2/configs/CenterNet-S4_DLA_8x.yaml new file mode 100644 index 0000000000..b3d88be9f5 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/configs/CenterNet-S4_DLA_8x.yaml @@ -0,0 +1,5 @@ +_BASE_: "Base_S4_DLA.yaml" +SOLVER: + MAX_ITER: 90000 + BASE_LR: 0.08 + IMS_PER_BATCH: 128 \ No newline at end of file diff --git a/dimos/models/Detic/third_party/CenterNet2/configs/CenterNet2-F_R50_1x.yaml b/dimos/models/Detic/third_party/CenterNet2/configs/CenterNet2-F_R50_1x.yaml new file mode 100644 index 0000000000..c40eecc13a --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/configs/CenterNet2-F_R50_1x.yaml @@ -0,0 +1,4 @@ +_BASE_: "Base-CenterNet2.yaml" +MODEL: + ROI_HEADS: + NAME: CustomROIHeads \ No newline at end of file diff --git a/dimos/models/Detic/third_party/CenterNet2/configs/CenterNet2_DLA-BiFPN-P3_24x.yaml b/dimos/models/Detic/third_party/CenterNet2/configs/CenterNet2_DLA-BiFPN-P3_24x.yaml new file mode 100644 index 0000000000..d7491447eb --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/configs/CenterNet2_DLA-BiFPN-P3_24x.yaml @@ -0,0 +1,36 @@ +_BASE_: "Base-CenterNet2.yaml" +MODEL: + BACKBONE: + NAME: "build_p35_fcos_dla_bifpn_backbone" + BIFPN: + OUT_CHANNELS: 160 + NUM_LEVELS: 3 + NUM_BIFPN: 4 + DLA: + NUM_LAYERS: 34 + NORM: "SyncBN" + FPN: + IN_FEATURES: ["dla3", "dla4", "dla5"] + ROI_HEADS: + IN_FEATURES: ["p3", "p4", "p5"] + CENTERNET: + POST_NMS_TOPK_TEST: 128 + FPN_STRIDES: [8, 16, 32] + IN_FEATURES: ['p3', 'p4', 'p5'] + SOI: [[0, 64], [48, 192], [128, 1000000]] +DATASETS: + TRAIN: ("coco_2017_train",) + TEST: ("coco_2017_val",) +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.02 + STEPS: (300000, 340000) + MAX_ITER: 360000 + CHECKPOINT_PERIOD: 100000 + WARMUP_ITERS: 4000 + WARMUP_FACTOR: 0.00025 +INPUT: + MIN_SIZE_TRAIN: (256, 288, 320, 352, 384, 416, 448, 480, 512, 544, 576, 608) + MAX_SIZE_TRAIN: 900 + MAX_SIZE_TEST: 736 + MIN_SIZE_TEST: 512 \ No newline at end of file diff --git a/dimos/models/Detic/third_party/CenterNet2/configs/CenterNet2_DLA-BiFPN-P3_4x.yaml b/dimos/models/Detic/third_party/CenterNet2/configs/CenterNet2_DLA-BiFPN-P3_4x.yaml new file mode 100644 index 0000000000..d7491447eb --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/configs/CenterNet2_DLA-BiFPN-P3_4x.yaml @@ -0,0 +1,36 @@ +_BASE_: "Base-CenterNet2.yaml" +MODEL: + BACKBONE: + NAME: "build_p35_fcos_dla_bifpn_backbone" + BIFPN: + OUT_CHANNELS: 160 + NUM_LEVELS: 3 + NUM_BIFPN: 4 + DLA: + NUM_LAYERS: 34 + NORM: "SyncBN" + FPN: + IN_FEATURES: ["dla3", "dla4", "dla5"] + ROI_HEADS: + IN_FEATURES: ["p3", "p4", "p5"] + CENTERNET: + POST_NMS_TOPK_TEST: 128 + FPN_STRIDES: [8, 16, 32] + IN_FEATURES: ['p3', 'p4', 'p5'] + SOI: [[0, 64], [48, 192], [128, 1000000]] +DATASETS: + TRAIN: ("coco_2017_train",) + TEST: ("coco_2017_val",) +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.02 + STEPS: (300000, 340000) + MAX_ITER: 360000 + CHECKPOINT_PERIOD: 100000 + WARMUP_ITERS: 4000 + WARMUP_FACTOR: 0.00025 +INPUT: + MIN_SIZE_TRAIN: (256, 288, 320, 352, 384, 416, 448, 480, 512, 544, 576, 608) + MAX_SIZE_TRAIN: 900 + MAX_SIZE_TEST: 736 + MIN_SIZE_TEST: 512 \ No newline at end of file diff --git a/dimos/models/Detic/third_party/CenterNet2/configs/CenterNet2_DLA-BiFPN-P5_640_16x.yaml b/dimos/models/Detic/third_party/CenterNet2/configs/CenterNet2_DLA-BiFPN-P5_640_16x.yaml new file mode 100644 index 0000000000..80413a62d6 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/configs/CenterNet2_DLA-BiFPN-P5_640_16x.yaml @@ -0,0 +1,29 @@ +_BASE_: "Base-CenterNet2.yaml" +MODEL: + BACKBONE: + NAME: "build_p37_dla_bifpn_backbone" + BIFPN: + OUT_CHANNELS: 160 + NUM_LEVELS: 5 + NUM_BIFPN: 3 + CENTERNET: + POST_NMS_TOPK_TEST: 128 + WEIGHTS: '' + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.12, 57.375] + FPN: + IN_FEATURES: ["dla3", "dla4", "dla5"] +SOLVER: + LR_SCHEDULER_NAME: "WarmupCosineLR" + MAX_ITER: 360000 + BASE_LR: 0.08 + IMS_PER_BATCH: 64 + CHECKPOINT_PERIOD: 90000 +TEST: + EVAL_PERIOD: 7500 +INPUT: + FORMAT: RGB + CUSTOM_AUG: EfficientDetResizeCrop + TRAIN_SIZE: 640 + MIN_SIZE_TEST: 608 + MAX_SIZE_TEST: 900 diff --git a/dimos/models/Detic/third_party/CenterNet2/configs/CenterNet2_DLA-BiFPN-P5_640_16x_ST.yaml b/dimos/models/Detic/third_party/CenterNet2/configs/CenterNet2_DLA-BiFPN-P5_640_16x_ST.yaml new file mode 100644 index 0000000000..8813b39c1c --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/configs/CenterNet2_DLA-BiFPN-P5_640_16x_ST.yaml @@ -0,0 +1,30 @@ +_BASE_: "Base-CenterNet2.yaml" +MODEL: + BACKBONE: + NAME: "build_p37_dla_bifpn_backbone" + BIFPN: + OUT_CHANNELS: 160 + NUM_LEVELS: 5 + NUM_BIFPN: 3 + CENTERNET: + POST_NMS_TOPK_TEST: 128 + WEIGHTS: '' + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.12, 57.375] + FPN: + IN_FEATURES: ["dla3", "dla4", "dla5"] +SOLVER: + LR_SCHEDULER_NAME: "WarmupCosineLR" + MAX_ITER: 360000 + BASE_LR: 0.08 + IMS_PER_BATCH: 64 +TEST: + EVAL_PERIOD: 7500 +INPUT: + FORMAT: RGB + CUSTOM_AUG: EfficientDetResizeCrop + TRAIN_SIZE: 640 + MIN_SIZE_TEST: 608 + MAX_SIZE_TEST: 900 +DATASETS: + TRAIN: ("coco_2017_train","coco_un_yolov4_55_0.5",) diff --git a/dimos/models/Detic/third_party/CenterNet2/configs/CenterNet2_DLA-fcosBiFPN-P5_640_16x_ST.yaml b/dimos/models/Detic/third_party/CenterNet2/configs/CenterNet2_DLA-fcosBiFPN-P5_640_16x_ST.yaml new file mode 100644 index 0000000000..f94f1358ce --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/configs/CenterNet2_DLA-fcosBiFPN-P5_640_16x_ST.yaml @@ -0,0 +1,30 @@ +_BASE_: "Base-CenterNet2.yaml" +MODEL: + BACKBONE: + NAME: "build_p37_fcos_dla_bifpn_backbone" + BIFPN: + OUT_CHANNELS: 160 + NUM_LEVELS: 5 + NUM_BIFPN: 3 + CENTERNET: + POST_NMS_TOPK_TEST: 128 + WEIGHTS: '' + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.12, 57.375] + FPN: + IN_FEATURES: ["dla3", "dla4", "dla5"] +TEST: + EVAL_PERIOD: 7500 +SOLVER: + LR_SCHEDULER_NAME: "WarmupCosineLR" + MAX_ITER: 360000 + BASE_LR: 0.08 + IMS_PER_BATCH: 64 +INPUT: + FORMAT: RGB + CUSTOM_AUG: EfficientDetResizeCrop + TRAIN_SIZE: 640 + MIN_SIZE_TEST: 608 + MAX_SIZE_TEST: 900 +DATASETS: + TRAIN: ("coco_2017_train","coco_un_yolov4_55_0.5",) diff --git a/dimos/models/Detic/third_party/CenterNet2/configs/CenterNet2_R2-101-DCN-BiFPN_1280_4x.yaml b/dimos/models/Detic/third_party/CenterNet2/configs/CenterNet2_R2-101-DCN-BiFPN_1280_4x.yaml new file mode 100644 index 0000000000..e07574b351 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/configs/CenterNet2_R2-101-DCN-BiFPN_1280_4x.yaml @@ -0,0 +1,32 @@ +_BASE_: "Base-CenterNet2.yaml" +MODEL: + BACKBONE: + NAME: "build_res2net_bifpn_backbone" + BIFPN: + NUM_BIFPN: 7 + OUT_CHANNELS: 288 + WEIGHTS: "output/r2_101.pkl" + RESNETS: + DEPTH: 101 + WIDTH_PER_GROUP: 26 + DEFORM_ON_PER_STAGE: [False, False, True, True] # on Res4, Res5 + DEFORM_MODULATED: True + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.12, 57.375] + CENTERNET: + USE_DEFORMABLE: True + ROI_HEADS: + IN_FEATURES: ["p3", "p4"] +INPUT: + FORMAT: RGB +TEST: + EVAL_PERIOD: 7500 +SOLVER: + MAX_ITER: 180000 + CHECKPOINT_PERIOD: 60000 + LR_SCHEDULER_NAME: "WarmupCosineLR" + BASE_LR: 0.04 + IMS_PER_BATCH: 32 +INPUT: + CUSTOM_AUG: EfficientDetResizeCrop + TRAIN_SIZE: 1280 diff --git a/dimos/models/Detic/third_party/CenterNet2/configs/CenterNet2_R2-101-DCN-BiFPN_4x+4x_1560_ST.yaml b/dimos/models/Detic/third_party/CenterNet2/configs/CenterNet2_R2-101-DCN-BiFPN_4x+4x_1560_ST.yaml new file mode 100644 index 0000000000..81fcab0972 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/configs/CenterNet2_R2-101-DCN-BiFPN_4x+4x_1560_ST.yaml @@ -0,0 +1,36 @@ +_BASE_: "Base-CenterNet2.yaml" +MODEL: + BACKBONE: + NAME: "build_res2net_bifpn_backbone" + BIFPN: + NUM_BIFPN: 7 + OUT_CHANNELS: 288 + WEIGHTS: "output/r2_101.pkl" + RESNETS: + DEPTH: 101 + WIDTH_PER_GROUP: 26 + DEFORM_ON_PER_STAGE: [False, False, True, True] # on Res4, Res5 + DEFORM_MODULATED: True + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.12, 57.375] + CENTERNET: + USE_DEFORMABLE: True + ROI_HEADS: + IN_FEATURES: ["p3", "p4"] +TEST: + EVAL_PERIOD: 7500 +SOLVER: + MAX_ITER: 180000 + CHECKPOINT_PERIOD: 7500 + LR_SCHEDULER_NAME: "WarmupCosineLR" + BASE_LR: 0.04 + IMS_PER_BATCH: 32 +DATASETS: + TRAIN: "('coco_2017_train', 'coco_un_yolov4_55_0.5')" +INPUT: + FORMAT: RGB + CUSTOM_AUG: EfficientDetResizeCrop + TRAIN_SIZE: 1280 + TEST_SIZE: 1560 + TEST_INPUT_TYPE: 'square' + \ No newline at end of file diff --git a/dimos/models/Detic/third_party/CenterNet2/configs/CenterNet2_R2-101-DCN_896_4x.yaml b/dimos/models/Detic/third_party/CenterNet2/configs/CenterNet2_R2-101-DCN_896_4x.yaml new file mode 100644 index 0000000000..fd6c49ee40 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/configs/CenterNet2_R2-101-DCN_896_4x.yaml @@ -0,0 +1,29 @@ +_BASE_: "Base-CenterNet2.yaml" +MODEL: + BACKBONE: + NAME: "build_p67_res2net_fpn_backbone" + WEIGHTS: "output/r2_101.pkl" + RESNETS: + DEPTH: 101 + WIDTH_PER_GROUP: 26 + DEFORM_ON_PER_STAGE: [False, False, True, True] # on Res4, Res5 + DEFORM_MODULATED: True + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.12, 57.375] + CENTERNET: + USE_DEFORMABLE: True + ROI_HEADS: + IN_FEATURES: ["p3", "p4"] +INPUT: + FORMAT: RGB +TEST: + EVAL_PERIOD: 7500 +SOLVER: + MAX_ITER: 180000 + CHECKPOINT_PERIOD: 600000 + LR_SCHEDULER_NAME: "WarmupCosineLR" + BASE_LR: 0.04 + IMS_PER_BATCH: 32 +INPUT: + CUSTOM_AUG: EfficientDetResizeCrop + TRAIN_SIZE: 896 \ No newline at end of file diff --git a/dimos/models/Detic/third_party/CenterNet2/configs/CenterNet2_R50_1x.yaml b/dimos/models/Detic/third_party/CenterNet2/configs/CenterNet2_R50_1x.yaml new file mode 100644 index 0000000000..9dcdf5b8b6 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/configs/CenterNet2_R50_1x.yaml @@ -0,0 +1 @@ +_BASE_: "Base-CenterNet2.yaml" diff --git a/dimos/models/Detic/third_party/CenterNet2/configs/CenterNet2_X101-DCN_2x.yaml b/dimos/models/Detic/third_party/CenterNet2/configs/CenterNet2_X101-DCN_2x.yaml new file mode 100644 index 0000000000..009c68085b --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/configs/CenterNet2_X101-DCN_2x.yaml @@ -0,0 +1,22 @@ +_BASE_: "Base-CenterNet2.yaml" +MODEL: + CENTERNET: + USE_DEFORMABLE: True + WEIGHTS: "detectron2://ImageNetPretrained/FAIR/X-101-32x8d.pkl" + PIXEL_STD: [57.375, 57.120, 58.395] + RESNETS: + STRIDE_IN_1X1: False + NUM_GROUPS: 32 + WIDTH_PER_GROUP: 8 + DEPTH: 101 + DEFORM_ON_PER_STAGE: [False, False, True, True] # on Res4, Res5 + DEFORM_MODULATED: True + ROI_HEADS: + IN_FEATURES: ["p3", "p4"] +SOLVER: + STEPS: (120000, 160000) + MAX_ITER: 180000 + CHECKPOINT_PERIOD: 40000 +INPUT: + MIN_SIZE_TRAIN: (480, 960) + MIN_SIZE_TRAIN_SAMPLING: "range" diff --git a/dimos/models/Detic/third_party/CenterNet2/configs/LVIS_CenterNet2_R50_1x.yaml b/dimos/models/Detic/third_party/CenterNet2/configs/LVIS_CenterNet2_R50_1x.yaml new file mode 100644 index 0000000000..912e8925dc --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/configs/LVIS_CenterNet2_R50_1x.yaml @@ -0,0 +1,17 @@ +_BASE_: "Base-CenterNet2.yaml" +MODEL: + ROI_HEADS: + NUM_CLASSES: 1203 + SCORE_THRESH_TEST: 0.02 + NMS_THRESH_TEST: 0.5 + CENTERNET: + NUM_CLASSES: 1203 + +DATASETS: + TRAIN: ("lvis_v1_train",) + TEST: ("lvis_v1_val",) +DATALOADER: + SAMPLER_TRAIN: "RepeatFactorTrainingSampler" + REPEAT_THRESHOLD: 0.001 +TEST: + DETECTIONS_PER_IMAGE: 300 diff --git a/dimos/models/Detic/third_party/CenterNet2/configs/LVIS_CenterNet2_R50_Fed_1x.yaml b/dimos/models/Detic/third_party/CenterNet2/configs/LVIS_CenterNet2_R50_Fed_1x.yaml new file mode 100644 index 0000000000..d6b6c823f2 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/configs/LVIS_CenterNet2_R50_Fed_1x.yaml @@ -0,0 +1,19 @@ +_BASE_: "Base-CenterNet2.yaml" +MODEL: + ROI_HEADS: + NUM_CLASSES: 1203 + SCORE_THRESH_TEST: 0.02 + NMS_THRESH_TEST: 0.5 + CENTERNET: + NUM_CLASSES: 1203 + ROI_BOX_HEAD: + USE_SIGMOID_CE: True + USE_FED_LOSS: True +DATASETS: + TRAIN: ("lvis_v1_train",) + TEST: ("lvis_v1_val",) +DATALOADER: + SAMPLER_TRAIN: "RepeatFactorTrainingSampler" + REPEAT_THRESHOLD: 0.001 +TEST: + DETECTIONS_PER_IMAGE: 300 diff --git a/dimos/models/Detic/third_party/CenterNet2/configs/O365_CenterNet2_R50_1x.yaml b/dimos/models/Detic/third_party/CenterNet2/configs/O365_CenterNet2_R50_1x.yaml new file mode 100644 index 0000000000..514e52cddc --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/configs/O365_CenterNet2_R50_1x.yaml @@ -0,0 +1,13 @@ +_BASE_: "Base-CenterNet2.yaml" +MODEL: + ROI_HEADS: + NUM_CLASSES: 365 + CENTERNET: + NUM_CLASSES: 365 +DATASETS: + TRAIN: ("objects365_train",) + TEST: ("objects365_val",) +DATALOADER: + SAMPLER_TRAIN: "ClassAwareSampler" +TEST: + DETECTIONS_PER_IMAGE: 300 \ No newline at end of file diff --git a/dimos/models/Detic/third_party/CenterNet2/configs/nuImages_CenterNet2_DLA_640_8x.yaml b/dimos/models/Detic/third_party/CenterNet2/configs/nuImages_CenterNet2_DLA_640_8x.yaml new file mode 100644 index 0000000000..c400e92ce7 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/configs/nuImages_CenterNet2_DLA_640_8x.yaml @@ -0,0 +1,42 @@ +_BASE_: "Base-CenterNet2.yaml" +MODEL: + MASK_ON: True + ROI_MASK_HEAD: + NAME: "MaskRCNNConvUpsampleHead" + NUM_CONV: 4 + POOLER_RESOLUTION: 14 + ROI_HEADS: + NUM_CLASSES: 10 + IN_FEATURES: ["dla2"] + BACKBONE: + NAME: "build_dla_backbone" + DLA: + NORM: "BN" + CENTERNET: + IN_FEATURES: ["dla2"] + FPN_STRIDES: [4] + SOI: [[0, 1000000]] + NUM_CLS_CONVS: 1 + NUM_BOX_CONVS: 1 + REG_WEIGHT: 1. + MORE_POS: True + HM_FOCAL_ALPHA: 0.25 + POST_NMS_TOPK_TEST: 128 + WEIGHTS: '' + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.12, 57.375] +SOLVER: + MAX_ITER: 180000 + STEPS: (120000, 160000) + BASE_LR: 0.08 + IMS_PER_BATCH: 64 +INPUT: + FORMAT: RGB + CUSTOM_AUG: EfficientDetResizeCrop + TRAIN_SIZE: 640 + MIN_SIZE_TEST: 608 + MAX_SIZE_TEST: 900 + MASK_FORMAT: bitmask +DATASETS: + TRAIN: ("nuimages_train",) + TEST: ("nuimages_val",) diff --git a/dimos/models/Detic/third_party/CenterNet2/datasets/README.md b/dimos/models/Detic/third_party/CenterNet2/datasets/README.md new file mode 100644 index 0000000000..0eb44cc3b2 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/datasets/README.md @@ -0,0 +1,140 @@ +# Use Builtin Datasets + +A dataset can be used by accessing [DatasetCatalog](https://detectron2.readthedocs.io/modules/data.html#detectron2.data.DatasetCatalog) +for its data, or [MetadataCatalog](https://detectron2.readthedocs.io/modules/data.html#detectron2.data.MetadataCatalog) for its metadata (class names, etc). +This document explains how to setup the builtin datasets so they can be used by the above APIs. +[Use Custom Datasets](https://detectron2.readthedocs.io/tutorials/datasets.html) gives a deeper dive on how to use `DatasetCatalog` and `MetadataCatalog`, +and how to add new datasets to them. + +Detectron2 has builtin support for a few datasets. +The datasets are assumed to exist in a directory specified by the environment variable +`DETECTRON2_DATASETS`. +Under this directory, detectron2 will look for datasets in the structure described below, if needed. +``` +$DETECTRON2_DATASETS/ + coco/ + lvis/ + cityscapes/ + VOC20{07,12}/ +``` + +You can set the location for builtin datasets by `export DETECTRON2_DATASETS=/path/to/datasets`. +If left unset, the default is `./datasets` relative to your current working directory. + +The [model zoo](https://github.com/facebookresearch/detectron2/blob/master/MODEL_ZOO.md) +contains configs and models that use these builtin datasets. + +## Expected dataset structure for [COCO instance/keypoint detection](https://cocodataset.org/#download): + +``` +coco/ + annotations/ + instances_{train,val}2017.json + person_keypoints_{train,val}2017.json + {train,val}2017/ + # image files that are mentioned in the corresponding json +``` + +You can use the 2014 version of the dataset as well. + +Some of the builtin tests (`dev/run_*_tests.sh`) uses a tiny version of the COCO dataset, +which you can download with `./datasets/prepare_for_tests.sh`. + +## Expected dataset structure for PanopticFPN: + +Extract panoptic annotations from [COCO website](https://cocodataset.org/#download) +into the following structure: +``` +coco/ + annotations/ + panoptic_{train,val}2017.json + panoptic_{train,val}2017/ # png annotations + panoptic_stuff_{train,val}2017/ # generated by the script mentioned below +``` + +Install panopticapi by: +``` +pip install git+https://github.com/cocodataset/panopticapi.git +``` +Then, run `python datasets/prepare_panoptic_fpn.py`, to extract semantic annotations from panoptic annotations. + +## Expected dataset structure for [LVIS instance segmentation](https://www.lvisdataset.org/dataset): +``` +coco/ + {train,val,test}2017/ +lvis/ + lvis_v0.5_{train,val}.json + lvis_v0.5_image_info_test.json + lvis_v1_{train,val}.json + lvis_v1_image_info_test{,_challenge}.json +``` + +Install lvis-api by: +``` +pip install git+https://github.com/lvis-dataset/lvis-api.git +``` + +To evaluate models trained on the COCO dataset using LVIS annotations, +run `python datasets/prepare_cocofied_lvis.py` to prepare "cocofied" LVIS annotations. + +## Expected dataset structure for [cityscapes](https://www.cityscapes-dataset.com/downloads/): +``` +cityscapes/ + gtFine/ + train/ + aachen/ + color.png, instanceIds.png, labelIds.png, polygons.json, + labelTrainIds.png + ... + val/ + test/ + # below are generated Cityscapes panoptic annotation + cityscapes_panoptic_train.json + cityscapes_panoptic_train/ + cityscapes_panoptic_val.json + cityscapes_panoptic_val/ + cityscapes_panoptic_test.json + cityscapes_panoptic_test/ + leftImg8bit/ + train/ + val/ + test/ +``` +Install cityscapes scripts by: +``` +pip install git+https://github.com/mcordts/cityscapesScripts.git +``` + +Note: to create labelTrainIds.png, first prepare the above structure, then run cityscapesescript with: +``` +CITYSCAPES_DATASET=/path/to/abovementioned/cityscapes python cityscapesscripts/preparation/createTrainIdLabelImgs.py +``` +These files are not needed for instance segmentation. + +Note: to generate Cityscapes panoptic dataset, run cityscapesescript with: +``` +CITYSCAPES_DATASET=/path/to/abovementioned/cityscapes python cityscapesscripts/preparation/createPanopticImgs.py +``` +These files are not needed for semantic and instance segmentation. + +## Expected dataset structure for [Pascal VOC](http://host.robots.ox.ac.uk/pascal/VOC/index.html): +``` +VOC20{07,12}/ + Annotations/ + ImageSets/ + Main/ + trainval.txt + test.txt + # train.txt or val.txt, if you use these splits + JPEGImages/ +``` + +## Expected dataset structure for [ADE20k Scene Parsing](http://sceneparsing.csail.mit.edu/): +``` +ADEChallengeData2016/ + annotations/ + annotations_detectron2/ + images/ + objectInfo150.txt +``` +The directory `annotations_detectron2` is generated by running `python datasets/prepare_ade20k_sem_seg.py`. diff --git a/dimos/models/Detic/third_party/CenterNet2/demo.py b/dimos/models/Detic/third_party/CenterNet2/demo.py new file mode 100644 index 0000000000..3177d838ac --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/demo.py @@ -0,0 +1,183 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +import argparse +import glob +import multiprocessing as mp +import os +import time + +from centernet.config import add_centernet_config +import cv2 +from detectron2.config import get_cfg +from detectron2.data.detection_utils import read_image +from detectron2.utils.logger import setup_logger +from predictor import VisualizationDemo +import tqdm + +# constants +WINDOW_NAME = "CenterNet2 detections" + +from detectron2.data import MetadataCatalog +from detectron2.utils.video_visualizer import VideoVisualizer +from detectron2.utils.visualizer import ColorMode + + +def setup_cfg(args): + # load config from file and command-line arguments + cfg = get_cfg() + add_centernet_config(cfg) + cfg.merge_from_file(args.config_file) + cfg.merge_from_list(args.opts) + # Set score_threshold for builtin models + cfg.MODEL.RETINANET.SCORE_THRESH_TEST = args.confidence_threshold + cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = args.confidence_threshold + if cfg.MODEL.META_ARCHITECTURE in ["ProposalNetwork", "CenterNetDetector"]: + cfg.MODEL.CENTERNET.INFERENCE_TH = args.confidence_threshold + cfg.MODEL.CENTERNET.NMS_TH = cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST + cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = args.confidence_threshold + cfg.freeze() + return cfg + + +def get_parser(): + parser = argparse.ArgumentParser(description="Detectron2 demo for builtin models") + parser.add_argument( + "--config-file", + default="configs/quick_schedules/mask_rcnn_R_50_FPN_inference_acc_test.yaml", + metavar="FILE", + help="path to config file", + ) + parser.add_argument("--webcam", action="store_true", help="Take inputs from webcam.") + parser.add_argument("--video-input", help="Path to video file.") + parser.add_argument("--input", nargs="+", help="A list of space separated input images") + parser.add_argument( + "--output", + help="A file or directory to save output visualizations. If not given, will show output in an OpenCV window.", + ) + + parser.add_argument( + "--confidence-threshold", + type=float, + default=0.3, + help="Minimum score for instance predictions to be shown", + ) + parser.add_argument( + "--opts", + help="Modify config options using the command-line 'KEY VALUE' pairs", + default=[], + nargs=argparse.REMAINDER, + ) + return parser + + +if __name__ == "__main__": + mp.set_start_method("spawn", force=True) + args = get_parser().parse_args() + logger = setup_logger() + logger.info("Arguments: " + str(args)) + + cfg = setup_cfg(args) + + demo = VisualizationDemo(cfg) + output_file = None + if args.input: + if len(args.input) == 1: + args.input = glob.glob(os.path.expanduser(args.input[0])) + files = os.listdir(args.input[0]) + args.input = [args.input[0] + x for x in files] + assert args.input, "The input path(s) was not found" + visualizer = VideoVisualizer( + MetadataCatalog.get(cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused"), + instance_mode=ColorMode.IMAGE, + ) + for path in tqdm.tqdm(args.input, disable=not args.output): + # use PIL, to be consistent with evaluation + img = read_image(path, format="BGR") + start_time = time.time() + predictions, visualized_output = demo.run_on_image(img, visualizer=visualizer) + if "instances" in predictions: + logger.info( + "{}: detected {} instances in {:.2f}s".format( + path, len(predictions["instances"]), time.time() - start_time + ) + ) + else: + logger.info( + "{}: detected {} instances in {:.2f}s".format( + path, len(predictions["proposals"]), time.time() - start_time + ) + ) + + if args.output: + if os.path.isdir(args.output): + assert os.path.isdir(args.output), args.output + out_filename = os.path.join(args.output, os.path.basename(path)) + visualized_output.save(out_filename) + else: + # assert len(args.input) == 1, "Please specify a directory with args.output" + # out_filename = args.output + if output_file is None: + width = visualized_output.get_image().shape[1] + height = visualized_output.get_image().shape[0] + frames_per_second = 15 + output_file = cv2.VideoWriter( + filename=args.output, + # some installation of opencv may not support x264 (due to its license), + # you can try other format (e.g. MPEG) + fourcc=cv2.VideoWriter_fourcc(*"x264"), + fps=float(frames_per_second), + frameSize=(width, height), + isColor=True, + ) + output_file.write(visualized_output.get_image()[:, :, ::-1]) + else: + # cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL) + cv2.imshow(WINDOW_NAME, visualized_output.get_image()[:, :, ::-1]) + if cv2.waitKey(1) == 27: + break # esc to quit + elif args.webcam: + assert args.input is None, "Cannot have both --input and --webcam!" + cam = cv2.VideoCapture(0) + for vis in tqdm.tqdm(demo.run_on_video(cam)): + cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL) + cv2.imshow(WINDOW_NAME, vis) + if cv2.waitKey(1) == 27: + break # esc to quit + cv2.destroyAllWindows() + elif args.video_input: + video = cv2.VideoCapture(args.video_input) + width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)) + frames_per_second = 15 # video.get(cv2.CAP_PROP_FPS) + num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) + basename = os.path.basename(args.video_input) + + if args.output: + if os.path.isdir(args.output): + output_fname = os.path.join(args.output, basename) + output_fname = os.path.splitext(output_fname)[0] + ".mkv" + else: + output_fname = args.output + # assert not os.path.isfile(output_fname), output_fname + output_file = cv2.VideoWriter( + filename=output_fname, + # some installation of opencv may not support x264 (due to its license), + # you can try other format (e.g. MPEG) + fourcc=cv2.VideoWriter_fourcc(*"x264"), + fps=float(frames_per_second), + frameSize=(width, height), + isColor=True, + ) + assert os.path.isfile(args.video_input) + for vis_frame in tqdm.tqdm(demo.run_on_video(video), total=num_frames): + if args.output: + output_file.write(vis_frame) + + cv2.namedWindow(basename, cv2.WINDOW_NORMAL) + cv2.imshow(basename, vis_frame) + if cv2.waitKey(1) == 27: + break # esc to quit + video.release() + if args.output: + output_file.release() + else: + cv2.destroyAllWindows() diff --git a/dimos/models/Detic/third_party/CenterNet2/docs/MODEL_ZOO.md b/dimos/models/Detic/third_party/CenterNet2/docs/MODEL_ZOO.md new file mode 100644 index 0000000000..97063b95c8 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/docs/MODEL_ZOO.md @@ -0,0 +1,73 @@ +# MODEL_ZOO + +### Common settings and notes + +- Multiscale training is used by default in all models. The results are all reported using single-scale testing. +- We report runtime on our local workstation with a TitanXp GPU and a Titan RTX GPU. +- All models are trained on 8-GPU servers by default. The 1280 models are trained on 24G GPUs. Reducing the batchsize with the linear learning rate rule should be fine. +- All models can be downloaded directly from [Google drive](https://drive.google.com/drive/folders/1meZIsz8E3Ia9CRxLOAULDLeYrKMhhjJE). + + +## COCO + +### CenterNet + +| Model | val mAP | FPS (Titan Xp/ Titan RTX) | links | +|-------------------------------------------|---------|---------|-----------| +| CenterNet-S4_DLA_8x | 42.5 | 50 / 71 |[config](../configs/CenterNet-S4_DLA_8x.yaml)/[model](https://drive.google.com/file/d/1AVfs9OoLePk_sqTPvqdRi1cXmO2cD0W_)| +| CenterNet-FPN_R50_1x | 40.2 | 20 / 24 |[config](../configs/CenterNet-FPN_R50_1x.yaml)/[model](https://drive.google.com/file/d/1iYlmjsBt9YIcaI8NzEwiMoaDDMHRmcR9)| + +#### Note + +- `CenterNet-S4_DLA_8x` is a re-implemented version of the original CenterNet (stride 4), with several changes, including + - Using top-left-right-bottom box encoding and GIoU Loss; adding regression loss to the center 3x3 region. + - Adding more positive pixels for the heatmap loss whose regression loss is small and is within the center3x3 region. + - Using more heavy crop augmentation (EfficientDet-style crop ratio 0.1-2), and removing color augmentations. + - Using standard NMS instead of max pooling. + - Using RetinaNet-style optimizer (SGD), learning rate rule (0.01 for each batch size 16), and schedule (8x12 epochs). +- `CenterNet-FPN_R50_1x` is a (new) FPN version of CenterNet. It includes the changes above, and assigns objects to FPN levels based on a fixed size range. The model is trained with standard short edge 640-800 multi-scale training with 12 epochs (1x). + + +### CenterNet2 + +| Model | val mAP | FPS (Titan Xp/ Titan RTX) | links | +|-------------------------------------------|---------|---------|-----------| +| CenterNet2-F_R50_1x | 41.7 | 22 / 27 |[config](../configs/CenterNet2-F_R50_1x.yaml)/[model](X)| +| CenterNet2_R50_1x | 42.9 | 18 / 24 |[config](../configs/CenterNet2_R50_1x.yaml)/[model](https://drive.google.com/file/d/1Qn0E_F1cmXtKPEdyZ_lSt-bnM9NueQpq)| +| CenterNet2_X101-DCN_2x | 49.9 | 6 / 8 |[config](../configs/CenterNet2_X101-DCN_2x.yaml)/[model](https://drive.google.com/file/d/1yuJbIlUgMiXdaDWRWArcsRsSoHti9e1y)| +| CenterNet2_DLA-BiFPN-P3_4x | 43.8 | 40 / 50|[config](../configs/CenterNet2_DLA-BiFPN-P3_4x.yaml)/[model](https://drive.google.com/file/d/1UGrnOE0W8Tgu6ffcCOQEbeUgThtDkbuQ)| +| CenterNet2_DLA-BiFPN-P3_24x | 45.6 | 40 / 50 |[config](../configs/CenterNet2_DLA-BiFPN-P3_24x.yaml)/[model](https://drive.google.com/file/d/17osgvr_Zhp9SS2uMa_YLiKwkKJIDtwPZ)| +| CenterNet2_R2-101-DCN_896_4x | 51.2 | 9 / 13 |[config](../configs/CenterNet2_R2-101-DCN_896_4x.yaml)/[model](https://drive.google.com/file/d/1YiJm7UtMstl63E8I4qQ8owteYC5zRFuQ)| +| CenterNet2_R2-101-DCN-BiFPN_1280_4x | 52.9 | 6 / 8 |[config](../configs/CenterNet2_R2-101-DCN-BiFPN_1280_4x.yaml)/[model](https://drive.google.com/file/d/1BIfEH04Lm3EvW9ov76yEPntUOJxaVoKd)| +| CenterNet2_R2-101-DCN-BiFPN_4x+4x_1560_ST | 56.1 | 3 / 5 |[config](../configs/CenterNet2_R2-101-DCN-BiFPN_4x+4x_1560_ST.yaml)/[model](https://drive.google.com/file/d/1GZyzJLB3FTcs8C7MpZRQWw44liYPyOMD)| +| CenterNet2_DLA-BiFPN-P5_640_24x_ST | 49.2 | 33 / 38 |[config](../configs/CenterNet2_DLA-BiFPN-P5_640_24x_ST.yaml)/[model](https://drive.google.com/file/d/1pGXpnHhvi66my_p5dASTnTjvaaj0FEvE)| + +#### Note + +- `CenterNet2-F_R50_1x` uses Faster RCNN as the second stage. All other CenterNet2 models use Cascade RCNN as the second stage. +- `CenterNet2_DLA-BiFPN-P3_4x` follows the same training setting as [realtime-FCOS](https://github.com/aim-uofa/AdelaiDet/blob/master/configs/FCOS-Detection/README.md). +- `CenterNet2_DLA-BiFPN-P3_24x` is trained by repeating the `4x` schedule (starting from learning rate 0.01) 6 times. +- R2 means [Res2Net](https://github.com/Res2Net/Res2Net-detectron2) backbone. To train Res2Net models, you need to download the ImageNet pre-trained weight [here](https://github.com/Res2Net/Res2Net-detectron2) and place it in `output/r2_101.pkl`. +- The last 4 models in the table are trained with the EfficientDet-style resize-and-crop augmentation, instead of the default random resizing short edge in detectron2. We found this trains faster (per-iteration) and gives better performance under a long schedule. +- `_ST` means using [self-training](https://arxiv.org/abs/2006.06882) using pseudo-labels produced by [Scaled-YOLOv4](https://github.com/WongKinYiu/ScaledYOLOv4) on COCO unlabeled images, with a hard score threshold 0.5. Our processed pseudo-labels can be downloaded [here](https://drive.google.com/file/d/1R9tHlUaIrujmK6T08yJ0T77b2XzekisC). +- `CenterNet2_R2-101-DCN-BiFPN_4x+4x_1560_ST` finetunes from `CenterNet2_R2-101-DCN-BiFPN_1280_4x` for an additional `4x` schedule with the self-training data. It is trained under `1280x1280` but tested under `1560x1560`. + +## LVIS v1 + +| Model | val mAP box | links | +|-------------------------------------------|--------------|-----------| +| LVIS_CenterNet2_R50_1x | 26.5 |[config](../configs/LVIS_CenterNet2_R50_1x.yaml)/[model](https://drive.google.com/file/d/1oOOKEDQIWW19AHhfnTb7HYZ3Z9gkZn_K)| +| LVIS_CenterNet2_R50_Fed_1x | 28.3 |[config](../configs/LVIS_CenterNet2_R50_Fed_1x.yaml)/[model](https://drive.google.com/file/d/1ETurGA7KIC5XMkMBI8MOIMDD_iJyMTif)| + +- The models are trained with repeat-factor sampling. +- `LVIS_CenterNet2_R50_Fed_1x` is CenterNet2 with our federated loss. Check our Appendix D of our [paper](https://arxiv.org/abs/2103.07461) or our [technical report at LVIS challenge](https://www.lvisdataset.org/assets/challenge_reports/2020/CenterNet2.pdf) for references. + +## Objects365 + +| Model | val mAP| links | +|-------------------------------------------|---------|-----------| +| O365_CenterNet2_R50_1x | 22.6 |[config](../configs/O365_CenterNet2_R50_1x.yaml)/[model](https://drive.google.com/file/d/11d1Qx75otBAQQL2raxMTVJb17Qr56M3O)| + +#### Note +- Objects365 dataset can be downloaded [here](https://www.objects365.org/overview.html). +- The model is trained with class-aware sampling. diff --git a/dimos/models/Detic/third_party/CenterNet2/predictor.py b/dimos/models/Detic/third_party/CenterNet2/predictor.py new file mode 100644 index 0000000000..0bdee56264 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/predictor.py @@ -0,0 +1,241 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +import atexit +import bisect +from collections import deque +import multiprocessing as mp + +import cv2 +from detectron2.data import MetadataCatalog +from detectron2.engine.defaults import DefaultPredictor +from detectron2.utils.video_visualizer import VideoVisualizer +from detectron2.utils.visualizer import ColorMode, Visualizer +import torch + + +class VisualizationDemo: + def __init__(self, cfg, instance_mode=ColorMode.IMAGE, parallel: bool=False) -> None: + """ + Args: + cfg (CfgNode): + instance_mode (ColorMode): + parallel (bool): whether to run the model in different processes from visualization. + Useful since the visualization logic can be slow. + """ + self.metadata = MetadataCatalog.get( + cfg.DATASETS.TRAIN[0] if len(cfg.DATASETS.TRAIN) else "__unused" + ) + self.cpu_device = torch.device("cpu") + self.instance_mode = instance_mode + + self.parallel = parallel + if parallel: + num_gpu = torch.cuda.device_count() + self.predictor = AsyncPredictor(cfg, num_gpus=num_gpu) + else: + self.predictor = DefaultPredictor(cfg) + + def run_on_image(self, image, visualizer=None): + """ + Args: + image (np.ndarray): an image of shape (H, W, C) (in BGR order). + This is the format used by OpenCV. + + Returns: + predictions (dict): the output of the model. + vis_output (VisImage): the visualized image output. + """ + vis_output = None + predictions = self.predictor(image) + # Convert image from OpenCV BGR format to Matplotlib RGB format. + image = image[:, :, ::-1] + use_video_vis = True + if visualizer is None: + use_video_vis = False + visualizer = Visualizer(image, self.metadata, instance_mode=self.instance_mode) + if "panoptic_seg" in predictions: + panoptic_seg, segments_info = predictions["panoptic_seg"] + vis_output = visualizer.draw_panoptic_seg_predictions( + panoptic_seg.to(self.cpu_device), segments_info + ) + else: + if "sem_seg" in predictions: + vis_output = visualizer.draw_sem_seg( + predictions["sem_seg"].argmax(dim=0).to(self.cpu_device) + ) + if "instances" in predictions: + instances = predictions["instances"].to(self.cpu_device) + if use_video_vis: + vis_output = visualizer.draw_instance_predictions(image, predictions=instances) + else: + vis_output = visualizer.draw_instance_predictions(predictions=instances) + elif "proposals" in predictions: + instances = predictions["proposals"].to(self.cpu_device) + instances.pred_boxes = instances.proposal_boxes + instances.scores = instances.objectness_logits + instances.pred_classes[:] = -1 + if use_video_vis: + vis_output = visualizer.draw_instance_predictions(image, predictions=instances) + else: + vis_output = visualizer.draw_instance_predictions(predictions=instances) + + return predictions, vis_output + + def _frame_from_video(self, video): + while video.isOpened(): + success, frame = video.read() + if success: + yield frame + else: + break + + def run_on_video(self, video): + """ + Visualizes predictions on frames of the input video. + + Args: + video (cv2.VideoCapture): a :class:`VideoCapture` object, whose source can be + either a webcam or a video file. + + Yields: + ndarray: BGR visualizations of each video frame. + """ + video_visualizer = VideoVisualizer(self.metadata, self.instance_mode) + + def process_predictions(frame, predictions): + frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) + if "panoptic_seg" in predictions: + panoptic_seg, segments_info = predictions["panoptic_seg"] + vis_frame = video_visualizer.draw_panoptic_seg_predictions( + frame, panoptic_seg.to(self.cpu_device), segments_info + ) + elif "instances" in predictions: + predictions = predictions["instances"].to(self.cpu_device) + vis_frame = video_visualizer.draw_instance_predictions(frame, predictions) + elif "sem_seg" in predictions: + vis_frame = video_visualizer.draw_sem_seg( + frame, predictions["sem_seg"].argmax(dim=0).to(self.cpu_device) + ) + elif "proposals" in predictions: + predictions = predictions["proposals"].to(self.cpu_device) + predictions.pred_boxes = predictions.proposal_boxes + predictions.scores = predictions.objectness_logits + predictions.pred_classes[:] = -1 + vis_frame = video_visualizer.draw_instance_predictions(frame, predictions) + + # Converts Matplotlib RGB format to OpenCV BGR format + vis_frame = cv2.cvtColor(vis_frame.get_image(), cv2.COLOR_RGB2BGR) + return vis_frame + + frame_gen = self._frame_from_video(video) + if self.parallel: + buffer_size = self.predictor.default_buffer_size + + frame_data = deque() + + for cnt, frame in enumerate(frame_gen): + frame_data.append(frame) + self.predictor.put(frame) + + if cnt >= buffer_size: + frame = frame_data.popleft() + predictions = self.predictor.get() + yield process_predictions(frame, predictions) + + while len(frame_data): + frame = frame_data.popleft() + predictions = self.predictor.get() + yield process_predictions(frame, predictions) + else: + for frame in frame_gen: + yield process_predictions(frame, self.predictor(frame)) + + +class AsyncPredictor: + """ + A predictor that runs the model asynchronously, possibly on >1 GPUs. + Because rendering the visualization takes considerably amount of time, + this helps improve throughput when rendering videos. + """ + + class _StopToken: + pass + + class _PredictWorker(mp.Process): + def __init__(self, cfg, task_queue, result_queue) -> None: + self.cfg = cfg + self.task_queue = task_queue + self.result_queue = result_queue + super().__init__() + + def run(self) -> None: + predictor = DefaultPredictor(self.cfg) + + while True: + task = self.task_queue.get() + if isinstance(task, AsyncPredictor._StopToken): + break + idx, data = task + result = predictor(data) + self.result_queue.put((idx, result)) + + def __init__(self, cfg, num_gpus: int = 1) -> None: + """ + Args: + cfg (CfgNode): + num_gpus (int): if 0, will run on CPU + """ + num_workers = max(num_gpus, 1) + self.task_queue = mp.Queue(maxsize=num_workers * 3) + self.result_queue = mp.Queue(maxsize=num_workers * 3) + self.procs = [] + for gpuid in range(max(num_gpus, 1)): + cfg = cfg.clone() + cfg.defrost() + cfg.MODEL.DEVICE = f"cuda:{gpuid}" if num_gpus > 0 else "cpu" + self.procs.append( + AsyncPredictor._PredictWorker(cfg, self.task_queue, self.result_queue) + ) + + self.put_idx = 0 + self.get_idx = 0 + self.result_rank = [] + self.result_data = [] + + for p in self.procs: + p.start() + atexit.register(self.shutdown) + + def put(self, image) -> None: + self.put_idx += 1 + self.task_queue.put((self.put_idx, image)) + + def get(self): + self.get_idx += 1 # the index needed for this request + if len(self.result_rank) and self.result_rank[0] == self.get_idx: + res = self.result_data[0] + del self.result_data[0], self.result_rank[0] + return res + + while True: + # make sure the results are returned in the correct order + idx, res = self.result_queue.get() + if idx == self.get_idx: + return res + insert = bisect.bisect(self.result_rank, idx) + self.result_rank.insert(insert, idx) + self.result_data.insert(insert, res) + + def __len__(self) -> int: + return self.put_idx - self.get_idx + + def __call__(self, image): + self.put(image) + return self.get() + + def shutdown(self) -> None: + for _ in self.procs: + self.task_queue.put(AsyncPredictor._StopToken()) + + @property + def default_buffer_size(self): + return len(self.procs) * 5 diff --git a/dimos/models/Detic/third_party/CenterNet2/requirements.txt b/dimos/models/Detic/third_party/CenterNet2/requirements.txt new file mode 100644 index 0000000000..0dd006bbc3 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/requirements.txt @@ -0,0 +1 @@ +opencv-python diff --git a/dimos/models/Detic/third_party/CenterNet2/tools/README.md b/dimos/models/Detic/third_party/CenterNet2/tools/README.md new file mode 100644 index 0000000000..0b40d5319c --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/tools/README.md @@ -0,0 +1,49 @@ + +This directory contains a few example scripts that demonstrate features of detectron2. + + +* `train_net.py` + +An example training script that's made to train builtin models of detectron2. + +For usage, see [GETTING_STARTED.md](../GETTING_STARTED.md). + +* `plain_train_net.py` + +Similar to `train_net.py`, but implements a training loop instead of using `Trainer`. +This script includes fewer features but it may be more friendly to hackers. + +* `benchmark.py` + +Benchmark the training speed, inference speed or data loading speed of a given config. + +Usage: +``` +python benchmark.py --config-file config.yaml --task train/eval/data [optional DDP flags] +``` + +* `analyze_model.py` + +Analyze FLOPs, parameters, activations of a detectron2 model. See its `--help` for usage. + +* `visualize_json_results.py` + +Visualize the json instance detection/segmentation results dumped by `COCOEvalutor` or `LVISEvaluator` + +Usage: +``` +python visualize_json_results.py --input x.json --output dir/ --dataset coco_2017_val +``` +If not using a builtin dataset, you'll need your own script or modify this script. + +* `visualize_data.py` + +Visualize ground truth raw annotations or training data (after preprocessing/augmentations). + +Usage: +``` +python visualize_data.py --config-file config.yaml --source annotation/dataloader --output-dir dir/ [--show] +``` + +NOTE: the script does not stop by itself when using `--source dataloader` because a training +dataloader is usually infinite. diff --git a/dimos/manipulation/classical/grasp_gen.py b/dimos/models/Detic/third_party/CenterNet2/tools/__init__.py similarity index 100% rename from dimos/manipulation/classical/grasp_gen.py rename to dimos/models/Detic/third_party/CenterNet2/tools/__init__.py diff --git a/dimos/models/Detic/third_party/CenterNet2/tools/analyze_model.py b/dimos/models/Detic/third_party/CenterNet2/tools/analyze_model.py new file mode 100755 index 0000000000..7b7b9e3432 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/tools/analyze_model.py @@ -0,0 +1,155 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +from collections import Counter +import logging + +from detectron2.checkpoint import DetectionCheckpointer +from detectron2.config import CfgNode, LazyConfig, get_cfg, instantiate +from detectron2.data import build_detection_test_loader +from detectron2.engine import default_argument_parser +from detectron2.modeling import build_model +from detectron2.utils.analysis import ( + FlopCountAnalysis, + activation_count_operators, + parameter_count_table, +) +from detectron2.utils.logger import setup_logger +from fvcore.nn import flop_count_table # can also try flop_count_str +import numpy as np +import tqdm + +logger = logging.getLogger("detectron2") + + +def setup(args): + if args.config_file.endswith(".yaml"): + cfg = get_cfg() + cfg.merge_from_file(args.config_file) + cfg.DATALOADER.NUM_WORKERS = 0 + cfg.merge_from_list(args.opts) + cfg.freeze() + else: + cfg = LazyConfig.load(args.config_file) + cfg = LazyConfig.apply_overrides(cfg, args.opts) + setup_logger(name="fvcore") + setup_logger() + return cfg + + +def do_flop(cfg) -> None: + if isinstance(cfg, CfgNode): + data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0]) + model = build_model(cfg) + DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS) + else: + data_loader = instantiate(cfg.dataloader.test) + model = instantiate(cfg.model) + model.to(cfg.train.device) + DetectionCheckpointer(model).load(cfg.train.init_checkpoint) + model.eval() + + counts = Counter() + total_flops = [] + for idx, data in zip(tqdm.trange(args.num_inputs), data_loader): # noqa + flops = FlopCountAnalysis(model, data) + if idx > 0: + flops.unsupported_ops_warnings(False).uncalled_modules_warnings(False) + counts += flops.by_operator() + total_flops.append(flops.total()) + + logger.info("Flops table computed from only one input sample:\n" + flop_count_table(flops)) + logger.info( + "Average GFlops for each type of operators:\n" + + str([(k, v / (idx + 1) / 1e9) for k, v in counts.items()]) + ) + logger.info( + f"Total GFlops: {np.mean(total_flops) / 1e9:.1f}±{np.std(total_flops) / 1e9:.1f}" + ) + + +def do_activation(cfg) -> None: + if isinstance(cfg, CfgNode): + data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0]) + model = build_model(cfg) + DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS) + else: + data_loader = instantiate(cfg.dataloader.test) + model = instantiate(cfg.model) + model.to(cfg.train.device) + DetectionCheckpointer(model).load(cfg.train.init_checkpoint) + model.eval() + + counts = Counter() + total_activations = [] + for idx, data in zip(tqdm.trange(args.num_inputs), data_loader): # noqa + count = activation_count_operators(model, data) + counts += count + total_activations.append(sum(count.values())) + logger.info( + "(Million) Activations for Each Type of Operators:\n" + + str([(k, v / idx) for k, v in counts.items()]) + ) + logger.info( + f"Total (Million) Activations: {np.mean(total_activations)}±{np.std(total_activations)}" + ) + + +def do_parameter(cfg) -> None: + if isinstance(cfg, CfgNode): + model = build_model(cfg) + else: + model = instantiate(cfg.model) + logger.info("Parameter Count:\n" + parameter_count_table(model, max_depth=5)) + + +def do_structure(cfg) -> None: + if isinstance(cfg, CfgNode): + model = build_model(cfg) + else: + model = instantiate(cfg.model) + logger.info("Model Structure:\n" + str(model)) + + +if __name__ == "__main__": + parser = default_argument_parser( + epilog=""" +Examples: + +To show parameters of a model: +$ ./analyze_model.py --tasks parameter \\ + --config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml + +Flops and activations are data-dependent, therefore inputs and model weights +are needed to count them: + +$ ./analyze_model.py --num-inputs 100 --tasks flop \\ + --config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml \\ + MODEL.WEIGHTS /path/to/model.pkl +""" + ) + parser.add_argument( + "--tasks", + choices=["flop", "activation", "parameter", "structure"], + required=True, + nargs="+", + ) + parser.add_argument( + "-n", + "--num-inputs", + default=100, + type=int, + help="number of inputs used to compute statistics for flops/activations, both are data dependent.", + ) + args = parser.parse_args() + assert not args.eval_only + assert args.num_gpus == 1 + + cfg = setup(args) + + for task in args.tasks: + { + "flop": do_flop, + "activation": do_activation, + "parameter": do_parameter, + "structure": do_structure, + }[task](cfg) diff --git a/dimos/models/Detic/third_party/CenterNet2/tools/benchmark.py b/dimos/models/Detic/third_party/CenterNet2/tools/benchmark.py new file mode 100755 index 0000000000..48f398d83d --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/tools/benchmark.py @@ -0,0 +1,195 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +""" +A script to benchmark builtin models. + +Note: this script has an extra dependency of psutil. +""" + +import itertools +import logging + +from detectron2.checkpoint import DetectionCheckpointer +from detectron2.config import LazyConfig, get_cfg, instantiate +from detectron2.data import ( + DatasetFromList, + build_detection_test_loader, + build_detection_train_loader, +) +from detectron2.data.benchmark import DataLoaderBenchmark +from detectron2.engine import AMPTrainer, SimpleTrainer, default_argument_parser, hooks, launch +from detectron2.modeling import build_model +from detectron2.solver import build_optimizer +from detectron2.utils import comm +from detectron2.utils.collect_env import collect_env_info +from detectron2.utils.events import CommonMetricPrinter +from detectron2.utils.logger import setup_logger +from fvcore.common.timer import Timer +import psutil +import torch +from torch.nn.parallel import DistributedDataParallel +import tqdm + +logger = logging.getLogger("detectron2") + + +def setup(args): + if args.config_file.endswith(".yaml"): + cfg = get_cfg() + cfg.merge_from_file(args.config_file) + cfg.SOLVER.BASE_LR = 0.001 # Avoid NaNs. Not useful in this script anyway. + cfg.merge_from_list(args.opts) + cfg.freeze() + else: + cfg = LazyConfig.load(args.config_file) + cfg = LazyConfig.apply_overrides(cfg, args.opts) + setup_logger(distributed_rank=comm.get_rank()) + return cfg + + +def create_data_benchmark(cfg, args): + if args.config_file.endswith(".py"): + dl_cfg = cfg.dataloader.train + dl_cfg._target_ = DataLoaderBenchmark + return instantiate(dl_cfg) + else: + kwargs = build_detection_train_loader.from_config(cfg) + kwargs.pop("aspect_ratio_grouping", None) + kwargs["_target_"] = DataLoaderBenchmark + return instantiate(kwargs) + + +def RAM_msg() -> str: + vram = psutil.virtual_memory() + return f"RAM Usage: {(vram.total - vram.available) / 1024**3:.2f}/{vram.total / 1024**3:.2f} GB" + + +def benchmark_data(args) -> None: + cfg = setup(args) + logger.info("After spawning " + RAM_msg()) + + benchmark = create_data_benchmark(cfg, args) + benchmark.benchmark_distributed(250, 10) + # test for a few more rounds + for k in range(10): + logger.info(f"Iteration {k} " + RAM_msg()) + benchmark.benchmark_distributed(250, 1) + + +def benchmark_data_advanced(args) -> None: + # benchmark dataloader with more details to help analyze performance bottleneck + cfg = setup(args) + benchmark = create_data_benchmark(cfg, args) + + if comm.get_rank() == 0: + benchmark.benchmark_dataset(100) + benchmark.benchmark_mapper(100) + benchmark.benchmark_workers(100, warmup=10) + benchmark.benchmark_IPC(100, warmup=10) + if comm.get_world_size() > 1: + benchmark.benchmark_distributed(100) + logger.info("Rerun ...") + benchmark.benchmark_distributed(100) + + +def benchmark_train(args) -> None: + cfg = setup(args) + model = build_model(cfg) + logger.info(f"Model:\n{model}") + if comm.get_world_size() > 1: + model = DistributedDataParallel( + model, device_ids=[comm.get_local_rank()], broadcast_buffers=False + ) + optimizer = build_optimizer(cfg, model) + checkpointer = DetectionCheckpointer(model, optimizer=optimizer) + checkpointer.load(cfg.MODEL.WEIGHTS) + + cfg.defrost() + cfg.DATALOADER.NUM_WORKERS = 2 + data_loader = build_detection_train_loader(cfg) + dummy_data = list(itertools.islice(data_loader, 100)) + + def f(): + data = DatasetFromList(dummy_data, copy=False, serialize=False) + while True: + yield from data + + max_iter = 400 + trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(model, f(), optimizer) + trainer.register_hooks( + [ + hooks.IterationTimer(), + hooks.PeriodicWriter([CommonMetricPrinter(max_iter)]), + hooks.TorchProfiler( + lambda trainer: trainer.iter == max_iter - 1, cfg.OUTPUT_DIR, save_tensorboard=True + ), + ] + ) + trainer.train(1, max_iter) + + +@torch.no_grad() +def benchmark_eval(args) -> None: + cfg = setup(args) + if args.config_file.endswith(".yaml"): + model = build_model(cfg) + DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS) + + cfg.defrost() + cfg.DATALOADER.NUM_WORKERS = 0 + data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0]) + else: + model = instantiate(cfg.model) + model.to(cfg.train.device) + DetectionCheckpointer(model).load(cfg.train.init_checkpoint) + + cfg.dataloader.num_workers = 0 + data_loader = instantiate(cfg.dataloader.test) + + model.eval() + logger.info(f"Model:\n{model}") + dummy_data = DatasetFromList(list(itertools.islice(data_loader, 100)), copy=False) + + def f(): + while True: + yield from dummy_data + + for k in range(5): # warmup + model(dummy_data[k]) + + max_iter = 300 + timer = Timer() + with tqdm.tqdm(total=max_iter) as pbar: + for idx, d in enumerate(f()): + if idx == max_iter: + break + model(d) + pbar.update() + logger.info(f"{max_iter} iters in {timer.seconds()} seconds.") + + +if __name__ == "__main__": + parser = default_argument_parser() + parser.add_argument("--task", choices=["train", "eval", "data", "data_advanced"], required=True) + args = parser.parse_args() + assert not args.eval_only + + logger.info("Environment info:\n" + collect_env_info()) + if "data" in args.task: + print("Initial " + RAM_msg()) + if args.task == "data": + f = benchmark_data + if args.task == "data_advanced": + f = benchmark_data_advanced + elif args.task == "train": + """ + Note: training speed may not be representative. + The training cost of a R-CNN model varies with the content of the data + and the quality of the model. + """ + f = benchmark_train + elif args.task == "eval": + f = benchmark_eval + # only benchmark single-GPU inference. + assert args.num_gpus == 1 and args.num_machines == 1 + launch(f, args.num_gpus, args.num_machines, args.machine_rank, args.dist_url, args=(args,)) diff --git a/dimos/models/Detic/third_party/CenterNet2/tools/convert-torchvision-to-d2.py b/dimos/models/Detic/third_party/CenterNet2/tools/convert-torchvision-to-d2.py new file mode 100755 index 0000000000..8bf0565d5e --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/tools/convert-torchvision-to-d2.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. + +import pickle as pkl +import sys + +import torch + +""" +Usage: + # download one of the ResNet{18,34,50,101,152} models from torchvision: + wget https://download.pytorch.org/models/resnet50-19c8e357.pth -O r50.pth + # run the conversion + ./convert-torchvision-to-d2.py r50.pth r50.pkl + + # Then, use r50.pkl with the following changes in config: + +MODEL: + WEIGHTS: "/path/to/r50.pkl" + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.120, 57.375] + RESNETS: + DEPTH: 50 + STRIDE_IN_1X1: False +INPUT: + FORMAT: "RGB" + + These models typically produce slightly worse results than the + pre-trained ResNets we use in official configs, which are the + original ResNet models released by MSRA. +""" + +if __name__ == "__main__": + input = sys.argv[1] + + obj = torch.load(input, map_location="cpu") + + newmodel = {} + for k in list(obj.keys()): + old_k = k + if "layer" not in k: + k = "stem." + k + for t in [1, 2, 3, 4]: + k = k.replace(f"layer{t}", f"res{t + 1}") + for t in [1, 2, 3]: + k = k.replace(f"bn{t}", f"conv{t}.norm") + k = k.replace("downsample.0", "shortcut") + k = k.replace("downsample.1", "shortcut.norm") + print(old_k, "->", k) + newmodel[k] = obj.pop(old_k).detach().numpy() + + res = {"model": newmodel, "__author__": "torchvision", "matching_heuristics": True} + + with open(sys.argv[2], "wb") as f: + pkl.dump(res, f) + if obj: + print("Unconverted keys:", obj.keys()) diff --git a/dimos/models/Detic/third_party/CenterNet2/tools/deploy/CMakeLists.txt b/dimos/models/Detic/third_party/CenterNet2/tools/deploy/CMakeLists.txt new file mode 100644 index 0000000000..80dae12500 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/tools/deploy/CMakeLists.txt @@ -0,0 +1,15 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# See https://pytorch.org/tutorials/advanced/cpp_frontend.html +cmake_minimum_required(VERSION 3.12 FATAL_ERROR) +project(torchscript_mask_rcnn) + +find_package(Torch REQUIRED) +find_package(OpenCV REQUIRED) +find_package(TorchVision REQUIRED) # needed by export-method=tracing/scripting + +add_executable(torchscript_mask_rcnn torchscript_mask_rcnn.cpp) +target_link_libraries( + torchscript_mask_rcnn + -Wl,--no-as-needed TorchVision::TorchVision -Wl,--as-needed + "${TORCH_LIBRARIES}" ${OpenCV_LIBS}) +set_property(TARGET torchscript_mask_rcnn PROPERTY CXX_STANDARD 14) diff --git a/dimos/models/Detic/third_party/CenterNet2/tools/deploy/README.md b/dimos/models/Detic/third_party/CenterNet2/tools/deploy/README.md new file mode 100644 index 0000000000..e33cbeb54c --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/tools/deploy/README.md @@ -0,0 +1,66 @@ +See [deployment tutorial](https://detectron2.readthedocs.io/tutorials/deployment.html) +for some high-level background about deployment. + +This directory contains the following examples: + +1. An example script `export_model.py` + that exports a detectron2 model for deployment using different methods and formats. + +2. A C++ example that runs inference with Mask R-CNN model in TorchScript format. + +## Build +Deployment depends on libtorch and OpenCV. Some require more dependencies: + +* Running TorchScript-format models produced by `--export-method=caffe2_tracing` requires libtorch + to be built with caffe2 enabled. +* Running TorchScript-format models produced by `--export-method=tracing/scripting` requires libtorchvision (C++ library of torchvision). + +All methods are supported in one C++ file that requires all the above dependencies. +Adjust it and remove code you don't need. +As a reference, we provide a [Dockerfile](../../docker/deploy.Dockerfile) that installs all the above dependencies and builds the C++ example. + +## Use + +We show a few example commands to export and execute a Mask R-CNN model in C++. + +* `export-method=tracing, format=torchscript`: +``` +./export_model.py --config-file ../../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml \ + --output ./output --export-method tracing --format torchscript \ + MODEL.WEIGHTS detectron2://COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x/137849600/model_final_f10217.pkl \ + MODEL.DEVICE cuda + +./build/torchscript_mask_rcnn output/model.ts input.jpg tracing +``` + +* `export-method=scripting, format=torchscript`: +``` +./export_model.py --config-file ../../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml \ + --output ./output --export-method scripting --format torchscript \ + MODEL.WEIGHTS detectron2://COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x/137849600/model_final_f10217.pkl \ + +./build/torchscript_mask_rcnn output/model.ts input.jpg scripting +``` + +* `export-method=caffe2_tracing, format=torchscript`: + +``` +./export_model.py --config-file ../../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml \ + --output ./output --export-method caffe2_tracing --format torchscript \ + MODEL.WEIGHTS detectron2://COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x/137849600/model_final_f10217.pkl \ + +./build/torchscript_mask_rcnn output/model.ts input.jpg caffe2_tracing +``` + + +## Notes: + +1. Tracing/Caffe2-tracing requires valid weights & sample inputs. + Therefore the above commands require pre-trained models and [COCO dataset](https://detectron2.readthedocs.io/tutorials/builtin_datasets.html). + You can modify the script to obtain sample inputs in other ways instead of from COCO. + +2. `--run-eval` is implemented only for tracing mode + to evaluate the exported model using the dataset in the config. + It's recommended to always verify the accuracy in case the conversion is not successful. + Evaluation can be slow if model is exported to CPU or dataset is too large ("coco_2017_val_100" is a small subset of COCO useful for evaluation). + `caffe2_tracing` accuracy may be slightly different (within 0.1 AP) from original model due to numerical precisions between different runtime. diff --git a/dimos/models/Detic/third_party/CenterNet2/tools/deploy/export_model.py b/dimos/models/Detic/third_party/CenterNet2/tools/deploy/export_model.py new file mode 100755 index 0000000000..6b9d2d60be --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/tools/deploy/export_model.py @@ -0,0 +1,233 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +import argparse +import os +from typing import Dict, List, Tuple + +from detectron2.checkpoint import DetectionCheckpointer +from detectron2.config import get_cfg +from detectron2.data import build_detection_test_loader, detection_utils +import detectron2.data.transforms as T +from detectron2.evaluation import COCOEvaluator, inference_on_dataset, print_csv_format +from detectron2.export import TracingAdapter, dump_torchscript_IR, scripting_with_instances +from detectron2.modeling import GeneralizedRCNN, RetinaNet, build_model +from detectron2.modeling.postprocessing import detector_postprocess +from detectron2.projects.point_rend import add_pointrend_config +from detectron2.structures import Boxes +from detectron2.utils.env import TORCH_VERSION +from detectron2.utils.file_io import PathManager +from detectron2.utils.logger import setup_logger +import torch +from torch import Tensor, nn + + +def setup_cfg(args): + cfg = get_cfg() + # cuda context is initialized before creating dataloader, so we don't fork anymore + cfg.DATALOADER.NUM_WORKERS = 0 + add_pointrend_config(cfg) + cfg.merge_from_file(args.config_file) + cfg.merge_from_list(args.opts) + cfg.freeze() + return cfg + + +def export_caffe2_tracing(cfg, torch_model, inputs): + from detectron2.export import Caffe2Tracer + + tracer = Caffe2Tracer(cfg, torch_model, inputs) + if args.format == "caffe2": + caffe2_model = tracer.export_caffe2() + caffe2_model.save_protobuf(args.output) + # draw the caffe2 graph + caffe2_model.save_graph(os.path.join(args.output, "model.svg"), inputs=inputs) + return caffe2_model + elif args.format == "onnx": + import onnx + + onnx_model = tracer.export_onnx() + onnx.save(onnx_model, os.path.join(args.output, "model.onnx")) + elif args.format == "torchscript": + ts_model = tracer.export_torchscript() + with PathManager.open(os.path.join(args.output, "model.ts"), "wb") as f: + torch.jit.save(ts_model, f) + dump_torchscript_IR(ts_model, args.output) + + +# experimental. API not yet final +def export_scripting(torch_model): + assert TORCH_VERSION >= (1, 8) + fields = { + "proposal_boxes": Boxes, + "objectness_logits": Tensor, + "pred_boxes": Boxes, + "scores": Tensor, + "pred_classes": Tensor, + "pred_masks": Tensor, + "pred_keypoints": torch.Tensor, + "pred_keypoint_heatmaps": torch.Tensor, + } + assert args.format == "torchscript", "Scripting only supports torchscript format." + + class ScriptableAdapterBase(nn.Module): + # Use this adapter to workaround https://github.com/pytorch/pytorch/issues/46944 + # by not retuning instances but dicts. Otherwise the exported model is not deployable + def __init__(self) -> None: + super().__init__() + self.model = torch_model + self.eval() + + if isinstance(torch_model, GeneralizedRCNN): + + class ScriptableAdapter(ScriptableAdapterBase): + def forward(self, inputs: tuple[dict[str, torch.Tensor]]) -> list[dict[str, Tensor]]: + instances = self.model.inference(inputs, do_postprocess=False) + return [i.get_fields() for i in instances] + + else: + + class ScriptableAdapter(ScriptableAdapterBase): + def forward(self, inputs: tuple[dict[str, torch.Tensor]]) -> list[dict[str, Tensor]]: + instances = self.model(inputs) + return [i.get_fields() for i in instances] + + ts_model = scripting_with_instances(ScriptableAdapter(), fields) + with PathManager.open(os.path.join(args.output, "model.ts"), "wb") as f: + torch.jit.save(ts_model, f) + dump_torchscript_IR(ts_model, args.output) + # TODO inference in Python now missing postprocessing glue code + return None + + +# experimental. API not yet final +def export_tracing(torch_model, inputs): + assert TORCH_VERSION >= (1, 8) + image = inputs[0]["image"] + inputs = [{"image": image}] # remove other unused keys + + if isinstance(torch_model, GeneralizedRCNN): + + def inference(model, inputs): + # use do_postprocess=False so it returns ROI mask + inst = model.inference(inputs, do_postprocess=False)[0] + return [{"instances": inst}] + + else: + inference = None # assume that we just call the model directly + + traceable_model = TracingAdapter(torch_model, inputs, inference) + + if args.format == "torchscript": + ts_model = torch.jit.trace(traceable_model, (image,)) + with PathManager.open(os.path.join(args.output, "model.ts"), "wb") as f: + torch.jit.save(ts_model, f) + dump_torchscript_IR(ts_model, args.output) + elif args.format == "onnx": + with PathManager.open(os.path.join(args.output, "model.onnx"), "wb") as f: + torch.onnx.export(traceable_model, (image,), f, opset_version=11) + logger.info("Inputs schema: " + str(traceable_model.inputs_schema)) + logger.info("Outputs schema: " + str(traceable_model.outputs_schema)) + + if args.format != "torchscript": + return None + if not isinstance(torch_model, GeneralizedRCNN | RetinaNet): + return None + + def eval_wrapper(inputs): + """ + The exported model does not contain the final resize step, which is typically + unused in deployment but needed for evaluation. We add it manually here. + """ + input = inputs[0] + instances = traceable_model.outputs_schema(ts_model(input["image"]))[0]["instances"] + postprocessed = detector_postprocess(instances, input["height"], input["width"]) + return [{"instances": postprocessed}] + + return eval_wrapper + + +def get_sample_inputs(args): + if args.sample_image is None: + # get a first batch from dataset + data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0]) + first_batch = next(iter(data_loader)) + return first_batch + else: + # get a sample data + original_image = detection_utils.read_image(args.sample_image, format=cfg.INPUT.FORMAT) + # Do same preprocessing as DefaultPredictor + aug = T.ResizeShortestEdge( + [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST + ) + height, width = original_image.shape[:2] + image = aug.get_transform(original_image).apply_image(original_image) + image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1)) + + inputs = {"image": image, "height": height, "width": width} + + # Sample ready + sample_inputs = [inputs] + return sample_inputs + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Export a model for deployment.") + parser.add_argument( + "--format", + choices=["caffe2", "onnx", "torchscript"], + help="output format", + default="torchscript", + ) + parser.add_argument( + "--export-method", + choices=["caffe2_tracing", "tracing", "scripting"], + help="Method to export models", + default="tracing", + ) + parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file") + parser.add_argument("--sample-image", default=None, type=str, help="sample image for input") + parser.add_argument("--run-eval", action="store_true") + parser.add_argument("--output", help="output directory for the converted model") + parser.add_argument( + "opts", + help="Modify config options using the command-line", + default=None, + nargs=argparse.REMAINDER, + ) + args = parser.parse_args() + logger = setup_logger() + logger.info("Command line arguments: " + str(args)) + PathManager.mkdirs(args.output) + # Disable respecialization on new shapes. Otherwise --run-eval will be slow + torch._C._jit_set_bailout_depth(1) + + cfg = setup_cfg(args) + + # create a torch model + torch_model = build_model(cfg) + DetectionCheckpointer(torch_model).resume_or_load(cfg.MODEL.WEIGHTS) + torch_model.eval() + + # get sample data + sample_inputs = get_sample_inputs(args) + + # convert and save model + if args.export_method == "caffe2_tracing": + exported_model = export_caffe2_tracing(cfg, torch_model, sample_inputs) + elif args.export_method == "scripting": + exported_model = export_scripting(torch_model) + elif args.export_method == "tracing": + exported_model = export_tracing(torch_model, sample_inputs) + + # run evaluation with the converted model + if args.run_eval: + assert exported_model is not None, ( + f"Python inference is not yet implemented for export_method={args.export_method}, format={args.format}." + ) + logger.info("Running evaluation ... this takes a long time if you export to CPU.") + dataset = cfg.DATASETS.TEST[0] + data_loader = build_detection_test_loader(cfg, dataset) + # NOTE: hard-coded evaluator. change to the evaluator for your dataset + evaluator = COCOEvaluator(dataset, output_dir=args.output) + metrics = inference_on_dataset(exported_model, data_loader, evaluator) + print_csv_format(metrics) diff --git a/dimos/models/Detic/third_party/CenterNet2/tools/deploy/torchscript_mask_rcnn.cpp b/dimos/models/Detic/third_party/CenterNet2/tools/deploy/torchscript_mask_rcnn.cpp new file mode 100644 index 0000000000..b40f13b81f --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/tools/deploy/torchscript_mask_rcnn.cpp @@ -0,0 +1,187 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// @lint-ignore-every CLANGTIDY +// This is an example code that demonstrates how to run inference +// with a torchscript format Mask R-CNN model exported by ./export_model.py +// using export method=tracing, caffe2_tracing & scripting. + +#include +#include +#include + +#include +#include +#include +#include + +// only needed for export_method=tracing +#include // @oss-only +// @fb-only: #include + +using namespace std; + +c10::IValue get_caffe2_tracing_inputs(cv::Mat& img, c10::Device device) { + const int height = img.rows; + const int width = img.cols; + // FPN models require divisibility of 32. + // Tracing mode does padding inside the graph, but caffe2_tracing does not. + assert(height % 32 == 0 && width % 32 == 0); + const int channels = 3; + + auto input = + torch::from_blob(img.data, {1, height, width, channels}, torch::kUInt8); + // NHWC to NCHW + input = input.to(device, torch::kFloat).permute({0, 3, 1, 2}).contiguous(); + + std::array im_info_data{height * 1.0f, width * 1.0f, 1.0f}; + auto im_info = + torch::from_blob(im_info_data.data(), {1, 3}).clone().to(device); + return std::make_tuple(input, im_info); +} + +c10::IValue get_tracing_inputs(cv::Mat& img, c10::Device device) { + const int height = img.rows; + const int width = img.cols; + const int channels = 3; + + auto input = + torch::from_blob(img.data, {height, width, channels}, torch::kUInt8); + // HWC to CHW + input = input.to(device, torch::kFloat).permute({2, 0, 1}).contiguous(); + return input; +} + +// create a Tuple[Dict[str, Tensor]] which is the input type of scripted model +c10::IValue get_scripting_inputs(cv::Mat& img, c10::Device device) { + const int height = img.rows; + const int width = img.cols; + const int channels = 3; + + auto img_tensor = + torch::from_blob(img.data, {height, width, channels}, torch::kUInt8); + // HWC to CHW + img_tensor = + img_tensor.to(device, torch::kFloat).permute({2, 0, 1}).contiguous(); + auto dic = c10::Dict(); + dic.insert("image", img_tensor); + return std::make_tuple(dic); +} + +c10::IValue +get_inputs(std::string export_method, cv::Mat& img, c10::Device device) { + // Given an image, create inputs in the format required by the model. + if (export_method == "tracing") + return get_tracing_inputs(img, device); + if (export_method == "caffe2_tracing") + return get_caffe2_tracing_inputs(img, device); + if (export_method == "scripting") + return get_scripting_inputs(img, device); + abort(); +} + +struct MaskRCNNOutputs { + at::Tensor pred_boxes, pred_classes, pred_masks, scores; + int num_instances() const { + return pred_boxes.sizes()[0]; + } +}; + +MaskRCNNOutputs get_outputs(std::string export_method, c10::IValue outputs) { + // Given outputs of the model, extract tensors from it to turn into a + // common MaskRCNNOutputs format. + if (export_method == "tracing") { + auto out_tuple = outputs.toTuple()->elements(); + // They are ordered alphabetically by their field name in Instances + return MaskRCNNOutputs{ + out_tuple[0].toTensor(), + out_tuple[1].toTensor(), + out_tuple[2].toTensor(), + out_tuple[3].toTensor()}; + } + if (export_method == "caffe2_tracing") { + auto out_tuple = outputs.toTuple()->elements(); + // A legacy order used by caffe2 models + return MaskRCNNOutputs{ + out_tuple[0].toTensor(), + out_tuple[2].toTensor(), + out_tuple[3].toTensor(), + out_tuple[1].toTensor()}; + } + if (export_method == "scripting") { + // With the ScriptableAdapter defined in export_model.py, the output is + // List[Dict[str, Any]]. + auto out_dict = outputs.toList().get(0).toGenericDict(); + return MaskRCNNOutputs{ + out_dict.at("pred_boxes").toTensor(), + out_dict.at("pred_classes").toTensor(), + out_dict.at("pred_masks").toTensor(), + out_dict.at("scores").toTensor()}; + } + abort(); +} + +int main(int argc, const char* argv[]) { + if (argc != 4) { + cerr << R"xx( +Usage: + ./torchscript_mask_rcnn model.ts input.jpg EXPORT_METHOD + + EXPORT_METHOD can be "tracing", "caffe2_tracing" or "scripting". +)xx"; + return 1; + } + std::string image_file = argv[2]; + std::string export_method = argv[3]; + assert( + export_method == "caffe2_tracing" || export_method == "tracing" || + export_method == "scripting"); + + torch::jit::getBailoutDepth() = 1; + torch::autograd::AutoGradMode guard(false); + auto module = torch::jit::load(argv[1]); + + assert(module.buffers().size() > 0); + // Assume that the entire model is on the same device. + // We just put input to this device. + auto device = (*begin(module.buffers())).device(); + + cv::Mat input_img = cv::imread(image_file, cv::IMREAD_COLOR); + auto inputs = get_inputs(export_method, input_img, device); + + // Run the network + auto output = module.forward({inputs}); + if (device.is_cuda()) + c10::cuda::getCurrentCUDAStream().synchronize(); + + // run 3 more times to benchmark + int N_benchmark = 3, N_warmup = 1; + auto start_time = chrono::high_resolution_clock::now(); + for (int i = 0; i < N_benchmark + N_warmup; ++i) { + if (i == N_warmup) + start_time = chrono::high_resolution_clock::now(); + output = module.forward({inputs}); + if (device.is_cuda()) + c10::cuda::getCurrentCUDAStream().synchronize(); + } + auto end_time = chrono::high_resolution_clock::now(); + auto ms = chrono::duration_cast(end_time - start_time) + .count(); + cout << "Latency (should vary with different inputs): " + << ms * 1.0 / 1e6 / N_benchmark << " seconds" << endl; + + // Parse Mask R-CNN outputs + auto rcnn_outputs = get_outputs(export_method, output); + cout << "Number of detected objects: " << rcnn_outputs.num_instances() + << endl; + + cout << "pred_boxes: " << rcnn_outputs.pred_boxes.toString() << " " + << rcnn_outputs.pred_boxes.sizes() << endl; + cout << "scores: " << rcnn_outputs.scores.toString() << " " + << rcnn_outputs.scores.sizes() << endl; + cout << "pred_classes: " << rcnn_outputs.pred_classes.toString() << " " + << rcnn_outputs.pred_classes.sizes() << endl; + cout << "pred_masks: " << rcnn_outputs.pred_masks.toString() << " " + << rcnn_outputs.pred_masks.sizes() << endl; + + cout << rcnn_outputs.pred_boxes << endl; + return 0; +} diff --git a/dimos/models/Detic/third_party/CenterNet2/tools/lazyconfig_train_net.py b/dimos/models/Detic/third_party/CenterNet2/tools/lazyconfig_train_net.py new file mode 100755 index 0000000000..8f40a40c39 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/tools/lazyconfig_train_net.py @@ -0,0 +1,132 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +""" +Training script using the new "LazyConfig" python config files. + +This scripts reads a given python config file and runs the training or evaluation. +It can be used to train any models or dataset as long as they can be +instantiated by the recursive construction defined in the given config file. + +Besides lazy construction of models, dataloader, etc., this scripts expects a +few common configuration parameters currently defined in "configs/common/train.py". +To add more complicated training logic, you can easily add other configs +in the config file and implement a new train_net.py to handle them. +""" + +import logging + +from detectron2.checkpoint import DetectionCheckpointer +from detectron2.config import LazyConfig, instantiate +from detectron2.engine import ( + AMPTrainer, + SimpleTrainer, + default_argument_parser, + default_setup, + default_writers, + hooks, + launch, +) +from detectron2.engine.defaults import create_ddp_model +from detectron2.evaluation import inference_on_dataset, print_csv_format +from detectron2.utils import comm + +logger = logging.getLogger("detectron2") + + +def do_test(cfg, model): + if "evaluator" in cfg.dataloader: + ret = inference_on_dataset( + model, instantiate(cfg.dataloader.test), instantiate(cfg.dataloader.evaluator) + ) + print_csv_format(ret) + return ret + + +def do_train(args, cfg) -> None: + """ + Args: + cfg: an object with the following attributes: + model: instantiate to a module + dataloader.{train,test}: instantiate to dataloaders + dataloader.evaluator: instantiate to evaluator for test set + optimizer: instantaite to an optimizer + lr_multiplier: instantiate to a fvcore scheduler + train: other misc config defined in `configs/common/train.py`, including: + output_dir (str) + init_checkpoint (str) + amp.enabled (bool) + max_iter (int) + eval_period, log_period (int) + device (str) + checkpointer (dict) + ddp (dict) + """ + model = instantiate(cfg.model) + logger = logging.getLogger("detectron2") + logger.info(f"Model:\n{model}") + model.to(cfg.train.device) + + cfg.optimizer.params.model = model + optim = instantiate(cfg.optimizer) + + train_loader = instantiate(cfg.dataloader.train) + + model = create_ddp_model(model, **cfg.train.ddp) + trainer = (AMPTrainer if cfg.train.amp.enabled else SimpleTrainer)(model, train_loader, optim) + checkpointer = DetectionCheckpointer( + model, + cfg.train.output_dir, + trainer=trainer, + ) + trainer.register_hooks( + [ + hooks.IterationTimer(), + hooks.LRScheduler(scheduler=instantiate(cfg.lr_multiplier)), + hooks.PeriodicCheckpointer(checkpointer, **cfg.train.checkpointer) + if comm.is_main_process() + else None, + hooks.EvalHook(cfg.train.eval_period, lambda: do_test(cfg, model)), + hooks.PeriodicWriter( + default_writers(cfg.train.output_dir, cfg.train.max_iter), + period=cfg.train.log_period, + ) + if comm.is_main_process() + else None, + ] + ) + + checkpointer.resume_or_load(cfg.train.init_checkpoint, resume=args.resume) + if args.resume and checkpointer.has_checkpoint(): + # The checkpoint stores the training iteration that just finished, thus we start + # at the next iteration + start_iter = trainer.iter + 1 + else: + start_iter = 0 + trainer.train(start_iter, cfg.train.max_iter) + + +def main(args) -> None: + cfg = LazyConfig.load(args.config_file) + cfg = LazyConfig.apply_overrides(cfg, args.opts) + default_setup(cfg, args) + + if args.eval_only: + model = instantiate(cfg.model) + model.to(cfg.train.device) + model = create_ddp_model(model) + DetectionCheckpointer(model).load(cfg.train.init_checkpoint) + print(do_test(cfg, model)) + else: + do_train(args, cfg) + + +if __name__ == "__main__": + args = default_argument_parser().parse_args() + launch( + main, + args.num_gpus, + num_machines=args.num_machines, + machine_rank=args.machine_rank, + dist_url=args.dist_url, + args=(args,), + ) diff --git a/dimos/models/Detic/third_party/CenterNet2/tools/lightning_train_net.py b/dimos/models/Detic/third_party/CenterNet2/tools/lightning_train_net.py new file mode 100644 index 0000000000..dbb6cb6e43 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/tools/lightning_train_net.py @@ -0,0 +1,237 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# Lightning Trainer should be considered beta at this point +# We have confirmed that training and validation run correctly and produce correct results +# Depending on how you launch the trainer, there are issues with processes terminating correctly +# This module is still dependent on D2 logging, but could be transferred to use Lightning logging + +from collections import OrderedDict +import logging +import os +import time +from typing import Any, Dict, List +import weakref + +from detectron2.checkpoint import DetectionCheckpointer +from detectron2.config import get_cfg +from detectron2.data import build_detection_test_loader, build_detection_train_loader +from detectron2.engine import ( + DefaultTrainer, + SimpleTrainer, + default_argument_parser, + default_setup, + default_writers, + hooks, +) +from detectron2.evaluation import print_csv_format +from detectron2.evaluation.testing import flatten_results_dict +from detectron2.modeling import build_model +from detectron2.solver import build_lr_scheduler, build_optimizer +import detectron2.utils.comm as comm +from detectron2.utils.events import EventStorage +from detectron2.utils.logger import setup_logger +import pytorch_lightning as pl # type: ignore +from pytorch_lightning import LightningDataModule, LightningModule +from train_net import build_evaluator + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("detectron2") + + +class TrainingModule(LightningModule): + def __init__(self, cfg) -> None: + super().__init__() + if not logger.isEnabledFor(logging.INFO): # setup_logger is not called for d2 + setup_logger() + self.cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size()) + self.storage: EventStorage = None + self.model = build_model(self.cfg) + + self.start_iter = 0 + self.max_iter = cfg.SOLVER.MAX_ITER + + def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None: + checkpoint["iteration"] = self.storage.iter + + def on_load_checkpoint(self, checkpointed_state: dict[str, Any]) -> None: + self.start_iter = checkpointed_state["iteration"] + self.storage.iter = self.start_iter + + def setup(self, stage: str) -> None: + if self.cfg.MODEL.WEIGHTS: + self.checkpointer = DetectionCheckpointer( + # Assume you want to save checkpoints together with logs/statistics + self.model, + self.cfg.OUTPUT_DIR, + ) + logger.info(f"Load model weights from checkpoint: {self.cfg.MODEL.WEIGHTS}.") + # Only load weights, use lightning checkpointing if you want to resume + self.checkpointer.load(self.cfg.MODEL.WEIGHTS) + + self.iteration_timer = hooks.IterationTimer() + self.iteration_timer.before_train() + self.data_start = time.perf_counter() + self.writers = None + + def training_step(self, batch, batch_idx): + data_time = time.perf_counter() - self.data_start + # Need to manually enter/exit since trainer may launch processes + # This ideally belongs in setup, but setup seems to run before processes are spawned + if self.storage is None: + self.storage = EventStorage(0) + self.storage.__enter__() + self.iteration_timer.trainer = weakref.proxy(self) + self.iteration_timer.before_step() + self.writers = ( + default_writers(self.cfg.OUTPUT_DIR, self.max_iter) + if comm.is_main_process() + else {} + ) + + loss_dict = self.model(batch) + SimpleTrainer.write_metrics(loss_dict, data_time) + + opt = self.optimizers() + self.storage.put_scalar( + "lr", opt.param_groups[self._best_param_group_id]["lr"], smoothing_hint=False + ) + self.iteration_timer.after_step() + self.storage.step() + # A little odd to put before step here, but it's the best way to get a proper timing + self.iteration_timer.before_step() + + if self.storage.iter % 20 == 0: + for writer in self.writers: + writer.write() + return sum(loss_dict.values()) + + def training_step_end(self, training_step_outpus): + self.data_start = time.perf_counter() + return training_step_outpus + + def training_epoch_end(self, training_step_outputs) -> None: + self.iteration_timer.after_train() + if comm.is_main_process(): + self.checkpointer.save("model_final") + for writer in self.writers: + writer.write() + writer.close() + self.storage.__exit__(None, None, None) + + def _process_dataset_evaluation_results(self) -> OrderedDict: + results = OrderedDict() + for idx, dataset_name in enumerate(self.cfg.DATASETS.TEST): + results[dataset_name] = self._evaluators[idx].evaluate() + if comm.is_main_process(): + print_csv_format(results[dataset_name]) + + if len(results) == 1: + results = next(iter(results.values())) + return results + + def _reset_dataset_evaluators(self) -> None: + self._evaluators = [] + for dataset_name in self.cfg.DATASETS.TEST: + evaluator = build_evaluator(self.cfg, dataset_name) + evaluator.reset() + self._evaluators.append(evaluator) + + def on_validation_epoch_start(self, _outputs) -> None: + self._reset_dataset_evaluators() + + def validation_epoch_end(self, _outputs): + results = self._process_dataset_evaluation_results(_outputs) + + flattened_results = flatten_results_dict(results) + for k, v in flattened_results.items(): + try: + v = float(v) + except Exception as e: + raise ValueError( + f"[EvalHook] eval_function should return a nested dict of float. Got '{k}: {v}' instead." + ) from e + self.storage.put_scalars(**flattened_results, smoothing_hint=False) + + def validation_step(self, batch, batch_idx: int, dataloader_idx: int = 0) -> None: + if not isinstance(batch, list): + batch = [batch] + outputs = self.model(batch) + self._evaluators[dataloader_idx].process(batch, outputs) + + def configure_optimizers(self): + optimizer = build_optimizer(self.cfg, self.model) + self._best_param_group_id = hooks.LRScheduler.get_best_param_group_id(optimizer) + scheduler = build_lr_scheduler(self.cfg, optimizer) + return [optimizer], [{"scheduler": scheduler, "interval": "step"}] + + +class DataModule(LightningDataModule): + def __init__(self, cfg) -> None: + super().__init__() + self.cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size()) + + def train_dataloader(self): + return build_detection_train_loader(self.cfg) + + def val_dataloader(self): + dataloaders = [] + for dataset_name in self.cfg.DATASETS.TEST: + dataloaders.append(build_detection_test_loader(self.cfg, dataset_name)) + return dataloaders + + +def main(args) -> None: + cfg = setup(args) + train(cfg, args) + + +def train(cfg, args) -> None: + trainer_params = { + # training loop is bounded by max steps, use a large max_epochs to make + # sure max_steps is met first + "max_epochs": 10**8, + "max_steps": cfg.SOLVER.MAX_ITER, + "val_check_interval": cfg.TEST.EVAL_PERIOD if cfg.TEST.EVAL_PERIOD > 0 else 10**8, + "num_nodes": args.num_machines, + "gpus": args.num_gpus, + "num_sanity_val_steps": 0, + } + if cfg.SOLVER.AMP.ENABLED: + trainer_params["precision"] = 16 + + last_checkpoint = os.path.join(cfg.OUTPUT_DIR, "last.ckpt") + if args.resume: + # resume training from checkpoint + trainer_params["resume_from_checkpoint"] = last_checkpoint + logger.info(f"Resuming training from checkpoint: {last_checkpoint}.") + + trainer = pl.Trainer(**trainer_params) + logger.info(f"start to train with {args.num_machines} nodes and {args.num_gpus} GPUs") + + module = TrainingModule(cfg) + data_module = DataModule(cfg) + if args.eval_only: + logger.info("Running inference") + trainer.validate(module, data_module) + else: + logger.info("Running training") + trainer.fit(module, data_module) + + +def setup(args): + """ + Create configs and perform basic setups. + """ + cfg = get_cfg() + cfg.merge_from_file(args.config_file) + cfg.merge_from_list(args.opts) + cfg.freeze() + default_setup(cfg, args) + return cfg + + +if __name__ == "__main__": + parser = default_argument_parser() + args = parser.parse_args() + logger.info("Command Line Args:", args) + main(args) diff --git a/dimos/models/Detic/third_party/CenterNet2/tools/plain_train_net.py b/dimos/models/Detic/third_party/CenterNet2/tools/plain_train_net.py new file mode 100755 index 0000000000..a06d19aff2 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/tools/plain_train_net.py @@ -0,0 +1,223 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +""" +Detectron2 training script with a plain training loop. + +This script reads a given config file and runs the training or evaluation. +It is an entry point that is able to train standard models in detectron2. + +In order to let one script support training of many models, +this script contains logic that are specific to these built-in models and therefore +may not be suitable for your own project. +For example, your research project perhaps only needs a single "evaluator". + +Therefore, we recommend you to use detectron2 as a library and take +this file as an example of how to use the library. +You may want to write your own script with your datasets and other customizations. + +Compared to "train_net.py", this script supports fewer default features. +It also includes fewer abstraction, therefore is easier to add custom logic. +""" + +from collections import OrderedDict +import logging +import os + +from detectron2.checkpoint import DetectionCheckpointer, PeriodicCheckpointer +from detectron2.config import get_cfg +from detectron2.data import ( + MetadataCatalog, + build_detection_test_loader, + build_detection_train_loader, +) +from detectron2.engine import default_argument_parser, default_setup, default_writers, launch +from detectron2.evaluation import ( + CityscapesInstanceEvaluator, + CityscapesSemSegEvaluator, + COCOEvaluator, + COCOPanopticEvaluator, + DatasetEvaluators, + LVISEvaluator, + PascalVOCDetectionEvaluator, + SemSegEvaluator, + inference_on_dataset, + print_csv_format, +) +from detectron2.modeling import build_model +from detectron2.solver import build_lr_scheduler, build_optimizer +import detectron2.utils.comm as comm +from detectron2.utils.events import EventStorage +import torch +from torch.nn.parallel import DistributedDataParallel + +logger = logging.getLogger("detectron2") + + +def get_evaluator(cfg, dataset_name: str, output_folder=None): + """ + Create evaluator(s) for a given dataset. + This uses the special metadata "evaluator_type" associated with each builtin dataset. + For your own dataset, you can simply create an evaluator manually in your + script and do not have to worry about the hacky if-else logic here. + """ + if output_folder is None: + output_folder = os.path.join(cfg.OUTPUT_DIR, "inference") + evaluator_list = [] + evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type + if evaluator_type in ["sem_seg", "coco_panoptic_seg"]: + evaluator_list.append( + SemSegEvaluator( + dataset_name, + distributed=True, + output_dir=output_folder, + ) + ) + if evaluator_type in ["coco", "coco_panoptic_seg"]: + evaluator_list.append(COCOEvaluator(dataset_name, output_dir=output_folder)) + if evaluator_type == "coco_panoptic_seg": + evaluator_list.append(COCOPanopticEvaluator(dataset_name, output_folder)) + if evaluator_type == "cityscapes_instance": + assert torch.cuda.device_count() > comm.get_rank(), ( + "CityscapesEvaluator currently do not work with multiple machines." + ) + return CityscapesInstanceEvaluator(dataset_name) + if evaluator_type == "cityscapes_sem_seg": + assert torch.cuda.device_count() > comm.get_rank(), ( + "CityscapesEvaluator currently do not work with multiple machines." + ) + return CityscapesSemSegEvaluator(dataset_name) + if evaluator_type == "pascal_voc": + return PascalVOCDetectionEvaluator(dataset_name) + if evaluator_type == "lvis": + return LVISEvaluator(dataset_name, cfg, True, output_folder) + if len(evaluator_list) == 0: + raise NotImplementedError( + f"no Evaluator for the dataset {dataset_name} with the type {evaluator_type}" + ) + if len(evaluator_list) == 1: + return evaluator_list[0] + return DatasetEvaluators(evaluator_list) + + +def do_test(cfg, model): + results = OrderedDict() + for dataset_name in cfg.DATASETS.TEST: + data_loader = build_detection_test_loader(cfg, dataset_name) + evaluator = get_evaluator( + cfg, dataset_name, os.path.join(cfg.OUTPUT_DIR, "inference", dataset_name) + ) + results_i = inference_on_dataset(model, data_loader, evaluator) + results[dataset_name] = results_i + if comm.is_main_process(): + logger.info(f"Evaluation results for {dataset_name} in csv format:") + print_csv_format(results_i) + if len(results) == 1: + results = next(iter(results.values())) + return results + + +def do_train(cfg, model, resume: bool=False) -> None: + model.train() + optimizer = build_optimizer(cfg, model) + scheduler = build_lr_scheduler(cfg, optimizer) + + checkpointer = DetectionCheckpointer( + model, cfg.OUTPUT_DIR, optimizer=optimizer, scheduler=scheduler + ) + start_iter = ( + checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1 + ) + max_iter = cfg.SOLVER.MAX_ITER + + periodic_checkpointer = PeriodicCheckpointer( + checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter + ) + + writers = default_writers(cfg.OUTPUT_DIR, max_iter) if comm.is_main_process() else [] + + # compared to "train_net.py", we do not support accurate timing and + # precise BN here, because they are not trivial to implement in a small training loop + data_loader = build_detection_train_loader(cfg) + logger.info(f"Starting training from iteration {start_iter}") + with EventStorage(start_iter) as storage: + for data, iteration in zip(data_loader, range(start_iter, max_iter), strict=False): + storage.iter = iteration + + loss_dict = model(data) + losses = sum(loss_dict.values()) + assert torch.isfinite(losses).all(), loss_dict + + loss_dict_reduced = {k: v.item() for k, v in comm.reduce_dict(loss_dict).items()} + losses_reduced = sum(loss for loss in loss_dict_reduced.values()) + if comm.is_main_process(): + storage.put_scalars(total_loss=losses_reduced, **loss_dict_reduced) + + optimizer.zero_grad() + losses.backward() + optimizer.step() + storage.put_scalar("lr", optimizer.param_groups[0]["lr"], smoothing_hint=False) + scheduler.step() + + if ( + cfg.TEST.EVAL_PERIOD > 0 + and (iteration + 1) % cfg.TEST.EVAL_PERIOD == 0 + and iteration != max_iter - 1 + ): + do_test(cfg, model) + # Compared to "train_net.py", the test results are not dumped to EventStorage + comm.synchronize() + + if iteration - start_iter > 5 and ( + (iteration + 1) % 20 == 0 or iteration == max_iter - 1 + ): + for writer in writers: + writer.write() + periodic_checkpointer.step(iteration) + + +def setup(args): + """ + Create configs and perform basic setups. + """ + cfg = get_cfg() + cfg.merge_from_file(args.config_file) + cfg.merge_from_list(args.opts) + cfg.freeze() + default_setup( + cfg, args + ) # if you don't like any of the default setup, write your own setup code + return cfg + + +def main(args): + cfg = setup(args) + + model = build_model(cfg) + logger.info(f"Model:\n{model}") + if args.eval_only: + DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( + cfg.MODEL.WEIGHTS, resume=args.resume + ) + return do_test(cfg, model) + + distributed = comm.get_world_size() > 1 + if distributed: + model = DistributedDataParallel( + model, device_ids=[comm.get_local_rank()], broadcast_buffers=False + ) + + do_train(cfg, model, resume=args.resume) + return do_test(cfg, model) + + +if __name__ == "__main__": + args = default_argument_parser().parse_args() + print("Command Line Args:", args) + launch( + main, + args.num_gpus, + num_machines=args.num_machines, + machine_rank=args.machine_rank, + dist_url=args.dist_url, + args=(args,), + ) diff --git a/dimos/models/Detic/third_party/CenterNet2/tools/train_net.py b/dimos/models/Detic/third_party/CenterNet2/tools/train_net.py new file mode 100755 index 0000000000..deb2ca6db8 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/tools/train_net.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +""" +A main training script. + +This scripts reads a given config file and runs the training or evaluation. +It is an entry point that is made to train standard models in detectron2. + +In order to let one script support training of many models, +this script contains logic that are specific to these built-in models and therefore +may not be suitable for your own project. +For example, your research project perhaps only needs a single "evaluator". + +Therefore, we recommend you to use detectron2 as an library and take +this file as an example of how to use the library. +You may want to write your own script with your datasets and other customizations. +""" + +from collections import OrderedDict +import logging +import os + +from detectron2.checkpoint import DetectionCheckpointer +from detectron2.config import get_cfg +from detectron2.data import MetadataCatalog +from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, hooks, launch +from detectron2.evaluation import ( + CityscapesInstanceEvaluator, + CityscapesSemSegEvaluator, + COCOEvaluator, + COCOPanopticEvaluator, + DatasetEvaluators, + LVISEvaluator, + PascalVOCDetectionEvaluator, + SemSegEvaluator, + verify_results, +) +from detectron2.modeling import GeneralizedRCNNWithTTA +import detectron2.utils.comm as comm +import torch + + +def build_evaluator(cfg, dataset_name: str, output_folder=None): + """ + Create evaluator(s) for a given dataset. + This uses the special metadata "evaluator_type" associated with each builtin dataset. + For your own dataset, you can simply create an evaluator manually in your + script and do not have to worry about the hacky if-else logic here. + """ + if output_folder is None: + output_folder = os.path.join(cfg.OUTPUT_DIR, "inference") + evaluator_list = [] + evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type + if evaluator_type in ["sem_seg", "coco_panoptic_seg"]: + evaluator_list.append( + SemSegEvaluator( + dataset_name, + distributed=True, + output_dir=output_folder, + ) + ) + if evaluator_type in ["coco", "coco_panoptic_seg"]: + evaluator_list.append(COCOEvaluator(dataset_name, output_dir=output_folder)) + if evaluator_type == "coco_panoptic_seg": + evaluator_list.append(COCOPanopticEvaluator(dataset_name, output_folder)) + if evaluator_type == "cityscapes_instance": + assert torch.cuda.device_count() > comm.get_rank(), ( + "CityscapesEvaluator currently do not work with multiple machines." + ) + return CityscapesInstanceEvaluator(dataset_name) + if evaluator_type == "cityscapes_sem_seg": + assert torch.cuda.device_count() > comm.get_rank(), ( + "CityscapesEvaluator currently do not work with multiple machines." + ) + return CityscapesSemSegEvaluator(dataset_name) + elif evaluator_type == "pascal_voc": + return PascalVOCDetectionEvaluator(dataset_name) + elif evaluator_type == "lvis": + return LVISEvaluator(dataset_name, output_dir=output_folder) + if len(evaluator_list) == 0: + raise NotImplementedError( + f"no Evaluator for the dataset {dataset_name} with the type {evaluator_type}" + ) + elif len(evaluator_list) == 1: + return evaluator_list[0] + return DatasetEvaluators(evaluator_list) + + +class Trainer(DefaultTrainer): + """ + We use the "DefaultTrainer" which contains pre-defined default logic for + standard training workflow. They may not work for you, especially if you + are working on a new research project. In that case you can write your + own training loop. You can use "tools/plain_train_net.py" as an example. + """ + + @classmethod + def build_evaluator(cls, cfg, dataset_name: str, output_folder=None): + return build_evaluator(cfg, dataset_name, output_folder) + + @classmethod + def test_with_TTA(cls, cfg, model): + logger = logging.getLogger("detectron2.trainer") + # In the end of training, run an evaluation with TTA + # Only support some R-CNN models. + logger.info("Running inference with test-time augmentation ...") + model = GeneralizedRCNNWithTTA(cfg, model) + evaluators = [ + cls.build_evaluator( + cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA") + ) + for name in cfg.DATASETS.TEST + ] + res = cls.test(cfg, model, evaluators) + res = OrderedDict({k + "_TTA": v for k, v in res.items()}) + return res + + +def setup(args): + """ + Create configs and perform basic setups. + """ + cfg = get_cfg() + cfg.merge_from_file(args.config_file) + cfg.merge_from_list(args.opts) + cfg.freeze() + default_setup(cfg, args) + return cfg + + +def main(args): + cfg = setup(args) + + if args.eval_only: + model = Trainer.build_model(cfg) + DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( + cfg.MODEL.WEIGHTS, resume=args.resume + ) + res = Trainer.test(cfg, model) + if cfg.TEST.AUG.ENABLED: + res.update(Trainer.test_with_TTA(cfg, model)) + if comm.is_main_process(): + verify_results(cfg, res) + return res + + """ + If you'd like to do anything fancier than the standard training logic, + consider writing your own training loop (see plain_train_net.py) or + subclassing the trainer. + """ + trainer = Trainer(cfg) + trainer.resume_or_load(resume=args.resume) + if cfg.TEST.AUG.ENABLED: + trainer.register_hooks( + [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))] + ) + return trainer.train() + + +if __name__ == "__main__": + args = default_argument_parser().parse_args() + print("Command Line Args:", args) + launch( + main, + args.num_gpus, + num_machines=args.num_machines, + machine_rank=args.machine_rank, + dist_url=args.dist_url, + args=(args,), + ) diff --git a/dimos/models/Detic/third_party/CenterNet2/tools/visualize_data.py b/dimos/models/Detic/third_party/CenterNet2/tools/visualize_data.py new file mode 100755 index 0000000000..99abfdff4e --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/tools/visualize_data.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +import argparse +from itertools import chain +import os + +import cv2 +from detectron2.config import get_cfg +from detectron2.data import ( + DatasetCatalog, + MetadataCatalog, + build_detection_train_loader, + detection_utils as utils, +) +from detectron2.data.build import filter_images_with_few_keypoints +from detectron2.utils.logger import setup_logger +from detectron2.utils.visualizer import Visualizer +import tqdm + + +def setup(args): + cfg = get_cfg() + if args.config_file: + cfg.merge_from_file(args.config_file) + cfg.merge_from_list(args.opts) + cfg.DATALOADER.NUM_WORKERS = 0 + cfg.freeze() + return cfg + + +def parse_args(in_args=None): + parser = argparse.ArgumentParser(description="Visualize ground-truth data") + parser.add_argument( + "--source", + choices=["annotation", "dataloader"], + required=True, + help="visualize the annotations or the data loader (with pre-processing)", + ) + parser.add_argument("--config-file", metavar="FILE", help="path to config file") + parser.add_argument("--output-dir", default="./", help="path to output directory") + parser.add_argument("--show", action="store_true", help="show output in a window") + parser.add_argument( + "opts", + help="Modify config options using the command-line", + default=None, + nargs=argparse.REMAINDER, + ) + return parser.parse_args(in_args) + + +if __name__ == "__main__": + args = parse_args() + logger = setup_logger() + logger.info("Arguments: " + str(args)) + cfg = setup(args) + + dirname = args.output_dir + os.makedirs(dirname, exist_ok=True) + metadata = MetadataCatalog.get(cfg.DATASETS.TRAIN[0]) + + def output(vis, fname) -> None: + if args.show: + print(fname) + cv2.imshow("window", vis.get_image()[:, :, ::-1]) + cv2.waitKey() + else: + filepath = os.path.join(dirname, fname) + print(f"Saving to {filepath} ...") + vis.save(filepath) + + scale = 1.0 + if args.source == "dataloader": + train_data_loader = build_detection_train_loader(cfg) + for batch in train_data_loader: + for per_image in batch: + # Pytorch tensor is in (C, H, W) format + img = per_image["image"].permute(1, 2, 0).cpu().detach().numpy() + img = utils.convert_image_to_rgb(img, cfg.INPUT.FORMAT) + + visualizer = Visualizer(img, metadata=metadata, scale=scale) + target_fields = per_image["instances"].get_fields() + labels = [metadata.thing_classes[i] for i in target_fields["gt_classes"]] + vis = visualizer.overlay_instances( + labels=labels, + boxes=target_fields.get("gt_boxes", None), + masks=target_fields.get("gt_masks", None), + keypoints=target_fields.get("gt_keypoints", None), + ) + output(vis, str(per_image["image_id"]) + ".jpg") + else: + dicts = list(chain.from_iterable([DatasetCatalog.get(k) for k in cfg.DATASETS.TRAIN])) + if cfg.MODEL.KEYPOINT_ON: + dicts = filter_images_with_few_keypoints(dicts, 1) + for dic in tqdm.tqdm(dicts): + img = utils.read_image(dic["file_name"], "RGB") + visualizer = Visualizer(img, metadata=metadata, scale=scale) + vis = visualizer.draw_dataset_dict(dic) + output(vis, os.path.basename(dic["file_name"])) diff --git a/dimos/models/Detic/third_party/CenterNet2/tools/visualize_json_results.py b/dimos/models/Detic/third_party/CenterNet2/tools/visualize_json_results.py new file mode 100755 index 0000000000..04dea72446 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/tools/visualize_json_results.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. + +import argparse +from collections import defaultdict +import json +import os + +import cv2 +from detectron2.data import DatasetCatalog, MetadataCatalog +from detectron2.structures import Boxes, BoxMode, Instances +from detectron2.utils.file_io import PathManager +from detectron2.utils.logger import setup_logger +from detectron2.utils.visualizer import Visualizer +import numpy as np +import tqdm + + +def create_instances(predictions, image_size: int): + ret = Instances(image_size) + + score = np.asarray([x["score"] for x in predictions]) + chosen = (score > args.conf_threshold).nonzero()[0] + score = score[chosen] + bbox = np.asarray([predictions[i]["bbox"] for i in chosen]).reshape(-1, 4) + bbox = BoxMode.convert(bbox, BoxMode.XYWH_ABS, BoxMode.XYXY_ABS) + + labels = np.asarray([dataset_id_map(predictions[i]["category_id"]) for i in chosen]) + + ret.scores = score + ret.pred_boxes = Boxes(bbox) + ret.pred_classes = labels + + try: + ret.pred_masks = [predictions[i]["segmentation"] for i in chosen] + except KeyError: + pass + return ret + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="A script that visualizes the json predictions from COCO or LVIS dataset." + ) + parser.add_argument("--input", required=True, help="JSON file produced by the model") + parser.add_argument("--output", required=True, help="output directory") + parser.add_argument("--dataset", help="name of the dataset", default="coco_2017_val") + parser.add_argument("--conf-threshold", default=0.5, type=float, help="confidence threshold") + args = parser.parse_args() + + logger = setup_logger() + + with PathManager.open(args.input, "r") as f: + predictions = json.load(f) + + pred_by_image = defaultdict(list) + for p in predictions: + pred_by_image[p["image_id"]].append(p) + + dicts = list(DatasetCatalog.get(args.dataset)) + metadata = MetadataCatalog.get(args.dataset) + if hasattr(metadata, "thing_dataset_id_to_contiguous_id"): + + def dataset_id_map(ds_id): + return metadata.thing_dataset_id_to_contiguous_id[ds_id] + + elif "lvis" in args.dataset: + # LVIS results are in the same format as COCO results, but have a different + # mapping from dataset category id to contiguous category id in [0, #categories - 1] + def dataset_id_map(ds_id): + return ds_id - 1 + + else: + raise ValueError(f"Unsupported dataset: {args.dataset}") + + os.makedirs(args.output, exist_ok=True) + + for dic in tqdm.tqdm(dicts): + img = cv2.imread(dic["file_name"], cv2.IMREAD_COLOR)[:, :, ::-1] + basename = os.path.basename(dic["file_name"]) + + predictions = create_instances(pred_by_image[dic["image_id"]], img.shape[:2]) + vis = Visualizer(img, metadata) + vis_pred = vis.draw_instance_predictions(predictions).get_image() + + vis = Visualizer(img, metadata) + vis_gt = vis.draw_dataset_dict(dic).get_image() + + concat = np.concatenate((vis_pred, vis_gt), axis=1) + cv2.imwrite(os.path.join(args.output, basename), concat[:, :, ::-1]) diff --git a/dimos/models/Detic/third_party/CenterNet2/train_net.py b/dimos/models/Detic/third_party/CenterNet2/train_net.py new file mode 100644 index 0000000000..92859d7586 --- /dev/null +++ b/dimos/models/Detic/third_party/CenterNet2/train_net.py @@ -0,0 +1,227 @@ +from collections import OrderedDict +import datetime +import logging +import os +import time + +from centernet.config import add_centernet_config +from centernet.data.custom_build_augmentation import build_custom_augmentation +from detectron2.checkpoint import DetectionCheckpointer, PeriodicCheckpointer +from detectron2.config import get_cfg +from detectron2.data import ( + MetadataCatalog, + build_detection_test_loader, +) +from detectron2.data.build import build_detection_train_loader +from detectron2.data.dataset_mapper import DatasetMapper +from detectron2.engine import default_argument_parser, default_setup, launch +from detectron2.evaluation import ( + COCOEvaluator, + LVISEvaluator, + inference_on_dataset, + print_csv_format, +) +from detectron2.modeling import build_model +from detectron2.modeling.test_time_augmentation import GeneralizedRCNNWithTTA +from detectron2.solver import build_lr_scheduler, build_optimizer +import detectron2.utils.comm as comm +from detectron2.utils.events import ( + CommonMetricPrinter, + EventStorage, + JSONWriter, + TensorboardXWriter, +) +from fvcore.common.timer import Timer +import torch +from torch.nn.parallel import DistributedDataParallel + +logger = logging.getLogger("detectron2") + + +def do_test(cfg, model): + results = OrderedDict() + for dataset_name in cfg.DATASETS.TEST: + mapper = ( + None + if cfg.INPUT.TEST_INPUT_TYPE == "default" + else DatasetMapper(cfg, False, augmentations=build_custom_augmentation(cfg, False)) + ) + data_loader = build_detection_test_loader(cfg, dataset_name, mapper=mapper) + output_folder = os.path.join(cfg.OUTPUT_DIR, f"inference_{dataset_name}") + evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type + + if evaluator_type == "lvis": + evaluator = LVISEvaluator(dataset_name, cfg, True, output_folder) + elif evaluator_type == "coco": + evaluator = COCOEvaluator(dataset_name, cfg, True, output_folder) + else: + assert 0, evaluator_type + + results[dataset_name] = inference_on_dataset(model, data_loader, evaluator) + if comm.is_main_process(): + logger.info(f"Evaluation results for {dataset_name} in csv format:") + print_csv_format(results[dataset_name]) + if len(results) == 1: + results = next(iter(results.values())) + return results + + +def do_train(cfg, model, resume: bool=False) -> None: + model.train() + optimizer = build_optimizer(cfg, model) + scheduler = build_lr_scheduler(cfg, optimizer) + + checkpointer = DetectionCheckpointer( + model, cfg.OUTPUT_DIR, optimizer=optimizer, scheduler=scheduler + ) + + start_iter = ( + checkpointer.resume_or_load( + cfg.MODEL.WEIGHTS, + resume=resume, + ).get("iteration", -1) + + 1 + ) + if cfg.SOLVER.RESET_ITER: + logger.info("Reset loaded iteration. Start training from iteration 0.") + start_iter = 0 + max_iter = cfg.SOLVER.MAX_ITER if cfg.SOLVER.TRAIN_ITER < 0 else cfg.SOLVER.TRAIN_ITER + + periodic_checkpointer = PeriodicCheckpointer( + checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter + ) + + writers = ( + [ + CommonMetricPrinter(max_iter), + JSONWriter(os.path.join(cfg.OUTPUT_DIR, "metrics.json")), + TensorboardXWriter(cfg.OUTPUT_DIR), + ] + if comm.is_main_process() + else [] + ) + + mapper = ( + DatasetMapper(cfg, True) + if cfg.INPUT.CUSTOM_AUG == "" + else DatasetMapper(cfg, True, augmentations=build_custom_augmentation(cfg, True)) + ) + if cfg.DATALOADER.SAMPLER_TRAIN in ["TrainingSampler", "RepeatFactorTrainingSampler"]: + data_loader = build_detection_train_loader(cfg, mapper=mapper) + else: + from centernet.data.custom_dataset_dataloader import build_custom_train_loader + + data_loader = build_custom_train_loader(cfg, mapper=mapper) + + logger.info(f"Starting training from iteration {start_iter}") + with EventStorage(start_iter) as storage: + step_timer = Timer() + data_timer = Timer() + start_time = time.perf_counter() + for data, iteration in zip(data_loader, range(start_iter, max_iter), strict=False): + data_time = data_timer.seconds() + storage.put_scalars(data_time=data_time) + step_timer.reset() + iteration = iteration + 1 + storage.step() + loss_dict = model(data) + + losses = sum(loss for k, loss in loss_dict.items()) + assert torch.isfinite(losses).all(), loss_dict + + loss_dict_reduced = {k: v.item() for k, v in comm.reduce_dict(loss_dict).items()} + losses_reduced = sum(loss for loss in loss_dict_reduced.values()) + if comm.is_main_process(): + storage.put_scalars(total_loss=losses_reduced, **loss_dict_reduced) + + optimizer.zero_grad() + losses.backward() + optimizer.step() + + storage.put_scalar("lr", optimizer.param_groups[0]["lr"], smoothing_hint=False) + + step_time = step_timer.seconds() + storage.put_scalars(time=step_time) + data_timer.reset() + scheduler.step() + + if ( + cfg.TEST.EVAL_PERIOD > 0 + and iteration % cfg.TEST.EVAL_PERIOD == 0 + and iteration != max_iter + ): + do_test(cfg, model) + comm.synchronize() + + if iteration - start_iter > 5 and (iteration % 20 == 0 or iteration == max_iter): + for writer in writers: + writer.write() + periodic_checkpointer.step(iteration) + + total_time = time.perf_counter() - start_time + logger.info( + f"Total training time: {datetime.timedelta(seconds=int(total_time))!s}" + ) + + +def setup(args): + """ + Create configs and perform basic setups. + """ + cfg = get_cfg() + add_centernet_config(cfg) + cfg.merge_from_file(args.config_file) + cfg.merge_from_list(args.opts) + if "/auto" in cfg.OUTPUT_DIR: + file_name = os.path.basename(args.config_file)[:-5] + cfg.OUTPUT_DIR = cfg.OUTPUT_DIR.replace("/auto", f"/{file_name}") + logger.info(f"OUTPUT_DIR: {cfg.OUTPUT_DIR}") + cfg.freeze() + default_setup(cfg, args) + return cfg + + +def main(args): + cfg = setup(args) + + model = build_model(cfg) + logger.info(f"Model:\n{model}") + if args.eval_only: + DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( + cfg.MODEL.WEIGHTS, resume=args.resume + ) + if cfg.TEST.AUG.ENABLED: + logger.info("Running inference with test-time augmentation ...") + model = GeneralizedRCNNWithTTA(cfg, model, batch_size=1) + + return do_test(cfg, model) + + distributed = comm.get_world_size() > 1 + if distributed: + model = DistributedDataParallel( + model, + device_ids=[comm.get_local_rank()], + broadcast_buffers=False, + find_unused_parameters=True, + ) + + do_train(cfg, model, resume=args.resume) + return do_test(cfg, model) + + +if __name__ == "__main__": + args = default_argument_parser() + args.add_argument("--manual_device", default="") + args = args.parse_args() + if args.manual_device != "": + os.environ["CUDA_VISIBLE_DEVICES"] = args.manual_device + args.dist_url = f"tcp://127.0.0.1:{torch.randint(11111, 60000, (1,))[0].item()}" + print("Command Line Args:", args) + launch( + main, + args.num_gpus, + num_machines=args.num_machines, + machine_rank=args.machine_rank, + dist_url=args.dist_url, + args=(args,), + ) diff --git a/dimos/models/Detic/third_party/Deformable-DETR/LICENSE b/dimos/models/Detic/third_party/Deformable-DETR/LICENSE new file mode 100644 index 0000000000..522e5bd3b6 --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/LICENSE @@ -0,0 +1,220 @@ +Copyright (c) 2020 SenseTime. All Rights Reserved. + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2020 SenseTime + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + +DETR + +Copyright 2020 - present, Facebook, Inc + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/dimos/models/Detic/third_party/Deformable-DETR/README.md b/dimos/models/Detic/third_party/Deformable-DETR/README.md new file mode 100644 index 0000000000..c9db563511 --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/README.md @@ -0,0 +1,169 @@ +# Deformable DETR + +By [Xizhou Zhu](https://scholar.google.com/citations?user=02RXI00AAAAJ), [Weijie Su](https://www.weijiesu.com/), [Lewei Lu](https://www.linkedin.com/in/lewei-lu-94015977/), [Bin Li](http://staff.ustc.edu.cn/~binli/), [Xiaogang Wang](http://www.ee.cuhk.edu.hk/~xgwang/), [Jifeng Dai](https://jifengdai.org/). + +This repository is an official implementation of the paper [Deformable DETR: Deformable Transformers for End-to-End Object Detection](https://arxiv.org/abs/2010.04159). + + +## Introduction + +**TL; DR.** Deformable DETR is an efficient and fast-converging end-to-end object detector. It mitigates the high complexity and slow convergence issues of DETR via a novel sampling-based efficient attention mechanism. + +![deformable_detr](./figs/illustration.png) + +![deformable_detr](./figs/convergence.png) + +**Abstract.** DETR has been recently proposed to eliminate the need for many hand-designed components in object detection while demonstrating good performance. However, it suffers from slow convergence and limited feature spatial resolution, due to the limitation of Transformer attention modules in processing image feature maps. To mitigate these issues, we proposed Deformable DETR, whose attention modules only attend to a small set of key sampling points around a reference. Deformable DETR can achieve better performance than DETR (especially on small objects) with 10× less training epochs. Extensive experiments on the COCO benchmark demonstrate the effectiveness of our approach. + +## License + +This project is released under the [Apache 2.0 license](./LICENSE). + +## Changelog + +See [changelog.md](./docs/changelog.md) for detailed logs of major changes. + + +## Citing Deformable DETR +If you find Deformable DETR useful in your research, please consider citing: +```bibtex +@article{zhu2020deformable, + title={Deformable DETR: Deformable Transformers for End-to-End Object Detection}, + author={Zhu, Xizhou and Su, Weijie and Lu, Lewei and Li, Bin and Wang, Xiaogang and Dai, Jifeng}, + journal={arXiv preprint arXiv:2010.04159}, + year={2020} +} +``` + +## Main Results + +| Method | Epochs | AP | APS | APM | APL | params
(M)
| FLOPs
(G)
| Total
Train
Time
(GPU
hours)
| Train
Speed
(GPU
hours
/epoch)
| Infer
Speed
(FPS)
| Batch
Infer
Speed
(FPS)
| URL | +| ----------------------------------- | :----: | :--: | :----: | :---: | :------------------------------: | :--------------------:| :----------------------------------------------------------: | :--: | :---: | :---: | ----- | ----- | +| Faster R-CNN + FPN | 109 | 42.0 | 26.6 | 45.4 | 53.4 | 42 | 180 | 380 | 3.5 | 25.6 | 28.0 | - | +| DETR | 500 | 42.0 | 20.5 | 45.8 | 61.1 | 41 | 86 | 2000 | 4.0 | 27.0 | 38.3 | - | +| DETR-DC5 | 500 | 43.3 | 22.5 | 47.3 | 61.1 | 41 |187|7000|14.0|11.4|12.4| - | +| DETR-DC5 | 50 | 35.3 | 15.2 | 37.5 | 53.6 | 41 |187|700|14.0|11.4|12.4| - | +| DETR-DC5+ | 50 | 36.2 | 16.3 | 39.2 | 53.9 | 41 |187|700|14.0|11.4|12.4| - | +| **Deformable DETR
(single scale)
** | 50 | 39.4 | 20.6 | 43.0 | 55.5 | 34 |78|160|3.2|27.0|42.4| [config](./configs/r50_deformable_detr_single_scale.sh)
[log](https://drive.google.com/file/d/1n3ZnZ-UAqmTUR4AZoM4qQntIDn6qCZx4/view?usp=sharing)
[model](https://drive.google.com/file/d/1WEjQ9_FgfI5sw5OZZ4ix-OKk-IJ_-SDU/view?usp=sharing)
| +| **Deformable DETR
(single scale, DC5)
** | 50 | 41.5 | 24.1 | 45.3 | 56.0 | 34 |128|215|4.3|22.1|29.4| [config](./configs/r50_deformable_detr_single_scale_dc5.sh)
[log](https://drive.google.com/file/d/1-UfTp2q4GIkJjsaMRIkQxa5k5vn8_n-B/view?usp=sharing)
[model](https://drive.google.com/file/d/1m_TgMjzH7D44fbA-c_jiBZ-xf-odxGdk/view?usp=sharing)
| +| **Deformable DETR** | 50 | 44.5 | 27.1 | 47.6 | 59.6 | 40 |173|325|6.5|15.0|19.4|[config](./configs/r50_deformable_detr.sh)
[log](https://drive.google.com/file/d/18YSLshFjc_erOLfFC-hHu4MX4iyz1Dqr/view?usp=sharing)
[model](https://drive.google.com/file/d/1nDWZWHuRwtwGden77NLM9JoWe-YisJnA/view?usp=sharing)
| +| **+ iterative bounding box refinement** | 50 | 46.2 | 28.3 | 49.2 | 61.5 | 41 |173|325|6.5|15.0|19.4|[config](./configs/r50_deformable_detr_plus_iterative_bbox_refinement.sh)
[log](https://drive.google.com/file/d/1DFNloITi1SFBWjYzvVEAI75ndwmGM1Uj/view?usp=sharing)
[model](https://drive.google.com/file/d/1JYKyRYzUH7uo9eVfDaVCiaIGZb5YTCuI/view?usp=sharing)
| +| **++ two-stage Deformable DETR** | 50 | 46.9 | 29.6 | 50.1 | 61.6 | 41 |173|340|6.8|14.5|18.8|[config](./configs/r50_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh)
[log](https://drive.google.com/file/d/1ozi0wbv5-Sc5TbWt1jAuXco72vEfEtbY/view?usp=sharing)
[model](https://drive.google.com/file/d/15I03A7hNTpwuLNdfuEmW9_taZMNVssEp/view?usp=sharing)
| + +*Note:* + +1. All models of Deformable DETR are trained with total batch size of 32. +2. Training and inference speed are measured on NVIDIA Tesla V100 GPU. +3. "Deformable DETR (single scale)" means only using res5 feature map (of stride 32) as input feature maps for Deformable Transformer Encoder. +4. "DC5" means removing the stride in C5 stage of ResNet and add a dilation of 2 instead. +5. "DETR-DC5+" indicates DETR-DC5 with some modifications, including using Focal Loss for bounding box classification and increasing number of object queries to 300. +6. "Batch Infer Speed" refer to inference with batch size = 4 to maximize GPU utilization. +7. The original implementation is based on our internal codebase. There are slight differences in the final accuracy and running time due to the plenty details in platform switch. + + +## Installation + +### Requirements + +* Linux, CUDA>=9.2, GCC>=5.4 + +* Python>=3.7 + + We recommend you to use Anaconda to create a conda environment: + ```bash + conda create -n deformable_detr python=3.7 pip + ``` + Then, activate the environment: + ```bash + conda activate deformable_detr + ``` + +* PyTorch>=1.5.1, torchvision>=0.6.1 (following instructions [here](https://pytorch.org/)) + + For example, if your CUDA version is 9.2, you could install pytorch and torchvision as following: + ```bash + conda install pytorch=1.5.1 torchvision=0.6.1 cudatoolkit=9.2 -c pytorch + ``` + +* Other requirements + ```bash + pip install -r requirements.txt + ``` + +### Compiling CUDA operators +```bash +cd ./models/ops +sh ./make.sh +# unit test (should see all checking is True) +python test.py +``` + +## Usage + +### Dataset preparation + +Please download [COCO 2017 dataset](https://cocodataset.org/) and organize them as following: + +``` +code_root/ +└── data/ + └── coco/ + ├── train2017/ + ├── val2017/ + └── annotations/ + ├── instances_train2017.json + └── instances_val2017.json +``` + +### Training + +#### Training on single node + +For example, the command for training Deformable DETR on 8 GPUs is as following: + +```bash +GPUS_PER_NODE=8 ./tools/run_dist_launch.sh 8 ./configs/r50_deformable_detr.sh +``` + +#### Training on multiple nodes + +For example, the command for training Deformable DETR on 2 nodes of each with 8 GPUs is as following: + +On node 1: + +```bash +MASTER_ADDR= NODE_RANK=0 GPUS_PER_NODE=8 ./tools/run_dist_launch.sh 16 ./configs/r50_deformable_detr.sh +``` + +On node 2: + +```bash +MASTER_ADDR= NODE_RANK=1 GPUS_PER_NODE=8 ./tools/run_dist_launch.sh 16 ./configs/r50_deformable_detr.sh +``` + +#### Training on slurm cluster + +If you are using slurm cluster, you can simply run the following command to train on 1 node with 8 GPUs: + +```bash +GPUS_PER_NODE=8 ./tools/run_dist_slurm.sh deformable_detr 8 configs/r50_deformable_detr.sh +``` + +Or 2 nodes of each with 8 GPUs: + +```bash +GPUS_PER_NODE=8 ./tools/run_dist_slurm.sh deformable_detr 16 configs/r50_deformable_detr.sh +``` +#### Some tips to speed-up training +* If your file system is slow to read images, you may consider enabling '--cache_mode' option to load whole dataset into memory at the beginning of training. +* You may increase the batch size to maximize the GPU utilization, according to GPU memory of yours, e.g., set '--batch_size 3' or '--batch_size 4'. + +### Evaluation + +You can get the config file and pretrained model of Deformable DETR (the link is in "Main Results" session), then run following command to evaluate it on COCO 2017 validation set: + +```bash + --resume --eval +``` + +You can also run distributed evaluation by using ```./tools/run_dist_launch.sh``` or ```./tools/run_dist_slurm.sh```. diff --git a/dimos/models/Detic/third_party/Deformable-DETR/benchmark.py b/dimos/models/Detic/third_party/Deformable-DETR/benchmark.py new file mode 100644 index 0000000000..3a4fcbd4e6 --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/benchmark.py @@ -0,0 +1,70 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +""" +Benchmark inference speed of Deformable DETR. +""" + +import argparse +import os +import time + +from datasets import build_dataset +from main import get_args_parser as get_main_args_parser +from models import build_model +import torch +from util.misc import nested_tensor_from_tensor_list + + +def get_benckmark_arg_parser(): + parser = argparse.ArgumentParser("Benchmark inference speed of Deformable DETR.") + parser.add_argument("--num_iters", type=int, default=300, help="total iters to benchmark speed") + parser.add_argument( + "--warm_iters", type=int, default=5, help="ignore first several iters that are very slow" + ) + parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference") + parser.add_argument("--resume", type=str, help="load the pre-trained checkpoint") + return parser + + +@torch.no_grad() +def measure_average_inference_time(model, inputs, num_iters: int=100, warm_iters: int=5): + ts = [] + for iter_ in range(num_iters): + torch.cuda.synchronize() + t_ = time.perf_counter() + model(inputs) + torch.cuda.synchronize() + t = time.perf_counter() - t_ + if iter_ >= warm_iters: + ts.append(t) + print(ts) + return sum(ts) / len(ts) + + +def benchmark(): + args, _ = get_benckmark_arg_parser().parse_known_args() + main_args = get_main_args_parser().parse_args(_) + assert args.warm_iters < args.num_iters and args.num_iters > 0 and args.warm_iters >= 0 + assert args.batch_size > 0 + assert args.resume is None or os.path.exists(args.resume) + dataset = build_dataset("val", main_args) + model, _, _ = build_model(main_args) + model.cuda() + model.eval() + if args.resume is not None: + ckpt = torch.load(args.resume, map_location=lambda storage, loc: storage) + model.load_state_dict(ckpt["model"]) + inputs = nested_tensor_from_tensor_list( + [dataset.__getitem__(0)[0].cuda() for _ in range(args.batch_size)] + ) + t = measure_average_inference_time(model, inputs, args.num_iters, args.warm_iters) + return 1.0 / t * args.batch_size + + +if __name__ == "__main__": + fps = benchmark() + print(f"Inference Speed: {fps:.1f} FPS") diff --git a/dimos/models/Detic/third_party/Deformable-DETR/configs/r50_deformable_detr.sh b/dimos/models/Detic/third_party/Deformable-DETR/configs/r50_deformable_detr.sh new file mode 100755 index 0000000000..a42953f266 --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/configs/r50_deformable_detr.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/r50_deformable_detr +PY_ARGS=${@:1} + +python -u main.py \ + --output_dir ${EXP_DIR} \ + ${PY_ARGS} diff --git a/dimos/models/Detic/third_party/Deformable-DETR/configs/r50_deformable_detr_plus_iterative_bbox_refinement.sh b/dimos/models/Detic/third_party/Deformable-DETR/configs/r50_deformable_detr_plus_iterative_bbox_refinement.sh new file mode 100755 index 0000000000..8ea20006b1 --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/configs/r50_deformable_detr_plus_iterative_bbox_refinement.sh @@ -0,0 +1,11 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/r50_deformable_detr_plus_iterative_bbox_refinement +PY_ARGS=${@:1} + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --with_box_refine \ + ${PY_ARGS} diff --git a/dimos/models/Detic/third_party/Deformable-DETR/configs/r50_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh b/dimos/models/Detic/third_party/Deformable-DETR/configs/r50_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh new file mode 100755 index 0000000000..722c658e45 --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/configs/r50_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage.sh @@ -0,0 +1,12 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/r50_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage +PY_ARGS=${@:1} + +python -u main.py \ + --output_dir ${EXP_DIR} \ + --with_box_refine \ + --two_stage \ + ${PY_ARGS} diff --git a/dimos/models/Detic/third_party/Deformable-DETR/configs/r50_deformable_detr_single_scale.sh b/dimos/models/Detic/third_party/Deformable-DETR/configs/r50_deformable_detr_single_scale.sh new file mode 100755 index 0000000000..a24e54718d --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/configs/r50_deformable_detr_single_scale.sh @@ -0,0 +1,11 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/r50_deformable_detr_single_scale +PY_ARGS=${@:1} + +python -u main.py \ + --num_feature_levels 1 \ + --output_dir ${EXP_DIR} \ + ${PY_ARGS} diff --git a/dimos/models/Detic/third_party/Deformable-DETR/configs/r50_deformable_detr_single_scale_dc5.sh b/dimos/models/Detic/third_party/Deformable-DETR/configs/r50_deformable_detr_single_scale_dc5.sh new file mode 100755 index 0000000000..26d35d6a49 --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/configs/r50_deformable_detr_single_scale_dc5.sh @@ -0,0 +1,12 @@ +#!/usr/bin/env bash + +set -x + +EXP_DIR=exps/r50_deformable_detr_single_scale_dc5 +PY_ARGS=${@:1} + +python -u main.py \ + --num_feature_levels 1 \ + --dilation \ + --output_dir ${EXP_DIR} \ + ${PY_ARGS} diff --git a/dimos/models/Detic/third_party/Deformable-DETR/datasets/__init__.py b/dimos/models/Detic/third_party/Deformable-DETR/datasets/__init__.py new file mode 100644 index 0000000000..870166e145 --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/datasets/__init__.py @@ -0,0 +1,34 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +import torch.utils.data + +from .coco import build as build_coco +from .torchvision_datasets import CocoDetection + + +def get_coco_api_from_dataset(dataset): + for _ in range(10): + # if isinstance(dataset, torchvision.datasets.CocoDetection): + # break + if isinstance(dataset, torch.utils.data.Subset): + dataset = dataset.dataset + if isinstance(dataset, CocoDetection): + return dataset.coco + + +def build_dataset(image_set, args): + if args.dataset_file == "coco": + return build_coco(image_set, args) + if args.dataset_file == "coco_panoptic": + # to avoid making panopticapi required for coco + from .coco_panoptic import build as build_coco_panoptic + + return build_coco_panoptic(image_set, args) + raise ValueError(f"dataset {args.dataset_file} not supported") diff --git a/dimos/models/Detic/third_party/Deformable-DETR/datasets/coco.py b/dimos/models/Detic/third_party/Deformable-DETR/datasets/coco.py new file mode 100644 index 0000000000..aa00ce49e3 --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/datasets/coco.py @@ -0,0 +1,194 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +COCO dataset which returns image_id for evaluation. + +Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py +""" + +from pathlib import Path + +from pycocotools import mask as coco_mask +import torch +import torch.utils.data +from util.misc import get_local_rank, get_local_size + +import datasets.transforms as T + +from .torchvision_datasets import CocoDetection as TvCocoDetection + + +class CocoDetection(TvCocoDetection): + def __init__( + self, + img_folder, + ann_file, + transforms, + return_masks, + cache_mode: bool=False, + local_rank: int=0, + local_size: int=1, + ) -> None: + super().__init__( + img_folder, + ann_file, + cache_mode=cache_mode, + local_rank=local_rank, + local_size=local_size, + ) + self._transforms = transforms + self.prepare = ConvertCocoPolysToMask(return_masks) + + def __getitem__(self, idx: int): + img, target = super().__getitem__(idx) + image_id = self.ids[idx] + target = {"image_id": image_id, "annotations": target} + img, target = self.prepare(img, target) + if self._transforms is not None: + img, target = self._transforms(img, target) + return img, target + + +def convert_coco_poly_to_mask(segmentations, height, width: int): + masks = [] + for polygons in segmentations: + rles = coco_mask.frPyObjects(polygons, height, width) + mask = coco_mask.decode(rles) + if len(mask.shape) < 3: + mask = mask[..., None] + mask = torch.as_tensor(mask, dtype=torch.uint8) + mask = mask.any(dim=2) + masks.append(mask) + if masks: + masks = torch.stack(masks, dim=0) + else: + masks = torch.zeros((0, height, width), dtype=torch.uint8) + return masks + + +class ConvertCocoPolysToMask: + def __init__(self, return_masks: bool=False) -> None: + self.return_masks = return_masks + + def __call__(self, image, target): + w, h = image.size + + image_id = target["image_id"] + image_id = torch.tensor([image_id]) + + anno = target["annotations"] + + anno = [obj for obj in anno if "iscrowd" not in obj or obj["iscrowd"] == 0] + + boxes = [obj["bbox"] for obj in anno] + # guard against no boxes via resizing + boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) + boxes[:, 2:] += boxes[:, :2] + boxes[:, 0::2].clamp_(min=0, max=w) + boxes[:, 1::2].clamp_(min=0, max=h) + + classes = [obj["category_id"] for obj in anno] + classes = torch.tensor(classes, dtype=torch.int64) + + if self.return_masks: + segmentations = [obj["segmentation"] for obj in anno] + masks = convert_coco_poly_to_mask(segmentations, h, w) + + keypoints = None + if anno and "keypoints" in anno[0]: + keypoints = [obj["keypoints"] for obj in anno] + keypoints = torch.as_tensor(keypoints, dtype=torch.float32) + num_keypoints = keypoints.shape[0] + if num_keypoints: + keypoints = keypoints.view(num_keypoints, -1, 3) + + keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) + boxes = boxes[keep] + classes = classes[keep] + if self.return_masks: + masks = masks[keep] + if keypoints is not None: + keypoints = keypoints[keep] + + target = {} + target["boxes"] = boxes + target["labels"] = classes + if self.return_masks: + target["masks"] = masks + target["image_id"] = image_id + if keypoints is not None: + target["keypoints"] = keypoints + + # for conversion to coco api + area = torch.tensor([obj["area"] for obj in anno]) + iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno]) + target["area"] = area[keep] + target["iscrowd"] = iscrowd[keep] + + target["orig_size"] = torch.as_tensor([int(h), int(w)]) + target["size"] = torch.as_tensor([int(h), int(w)]) + + return image, target + + +def make_coco_transforms(image_set): + normalize = T.Compose([T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) + + scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800] + + if image_set == "train": + return T.Compose( + [ + T.RandomHorizontalFlip(), + T.RandomSelect( + T.RandomResize(scales, max_size=1333), + T.Compose( + [ + T.RandomResize([400, 500, 600]), + T.RandomSizeCrop(384, 600), + T.RandomResize(scales, max_size=1333), + ] + ), + ), + normalize, + ] + ) + + if image_set == "val": + return T.Compose( + [ + T.RandomResize([800], max_size=1333), + normalize, + ] + ) + + raise ValueError(f"unknown {image_set}") + + +def build(image_set, args): + root = Path(args.coco_path) + assert root.exists(), f"provided COCO path {root} does not exist" + mode = "instances" + PATHS = { + "train": (root / "train2017", root / "annotations" / f"{mode}_train2017.json"), + "val": (root / "val2017", root / "annotations" / f"{mode}_val2017.json"), + } + + img_folder, ann_file = PATHS[image_set] + dataset = CocoDetection( + img_folder, + ann_file, + transforms=make_coco_transforms(image_set), + return_masks=args.masks, + cache_mode=args.cache_mode, + local_rank=get_local_rank(), + local_size=get_local_size(), + ) + return dataset diff --git a/dimos/models/Detic/third_party/Deformable-DETR/datasets/coco_eval.py b/dimos/models/Detic/third_party/Deformable-DETR/datasets/coco_eval.py new file mode 100644 index 0000000000..1a0e7962bd --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/datasets/coco_eval.py @@ -0,0 +1,265 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +COCO evaluator that works in distributed mode. + +Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py +The difference is that there is less copy-pasting from pycocotools +in the end of the file, as python3 can suppress prints with contextlib +""" + +import contextlib +import copy +import os + +import numpy as np +from pycocotools.coco import COCO +from pycocotools.cocoeval import COCOeval +import pycocotools.mask as mask_util +import torch +from util.misc import all_gather + + +class CocoEvaluator: + def __init__(self, coco_gt, iou_types) -> None: + assert isinstance(iou_types, list | tuple) + coco_gt = copy.deepcopy(coco_gt) + self.coco_gt = coco_gt + + self.iou_types = iou_types + self.coco_eval = {} + for iou_type in iou_types: + self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type) + + self.img_ids = [] + self.eval_imgs = {k: [] for k in iou_types} + + def update(self, predictions) -> None: + img_ids = list(np.unique(list(predictions.keys()))) + self.img_ids.extend(img_ids) + + for iou_type in self.iou_types: + results = self.prepare(predictions, iou_type) + + # suppress pycocotools prints + with open(os.devnull, "w") as devnull: + with contextlib.redirect_stdout(devnull): + coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO() + coco_eval = self.coco_eval[iou_type] + + coco_eval.cocoDt = coco_dt + coco_eval.params.imgIds = list(img_ids) + img_ids, eval_imgs = evaluate(coco_eval) + + self.eval_imgs[iou_type].append(eval_imgs) + + def synchronize_between_processes(self) -> None: + for iou_type in self.iou_types: + self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2) + create_common_coco_eval( + self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type] + ) + + def accumulate(self) -> None: + for coco_eval in self.coco_eval.values(): + coco_eval.accumulate() + + def summarize(self) -> None: + for iou_type, coco_eval in self.coco_eval.items(): + print(f"IoU metric: {iou_type}") + coco_eval.summarize() + + def prepare(self, predictions, iou_type): + if iou_type == "bbox": + return self.prepare_for_coco_detection(predictions) + elif iou_type == "segm": + return self.prepare_for_coco_segmentation(predictions) + elif iou_type == "keypoints": + return self.prepare_for_coco_keypoint(predictions) + else: + raise ValueError(f"Unknown iou type {iou_type}") + + def prepare_for_coco_detection(self, predictions): + coco_results = [] + for original_id, prediction in predictions.items(): + if len(prediction) == 0: + continue + + boxes = prediction["boxes"] + boxes = convert_to_xywh(boxes).tolist() + scores = prediction["scores"].tolist() + labels = prediction["labels"].tolist() + + coco_results.extend( + [ + { + "image_id": original_id, + "category_id": labels[k], + "bbox": box, + "score": scores[k], + } + for k, box in enumerate(boxes) + ] + ) + return coco_results + + def prepare_for_coco_segmentation(self, predictions): + coco_results = [] + for original_id, prediction in predictions.items(): + if len(prediction) == 0: + continue + + scores = prediction["scores"] + labels = prediction["labels"] + masks = prediction["masks"] + + masks = masks > 0.5 + + scores = prediction["scores"].tolist() + labels = prediction["labels"].tolist() + + rles = [ + mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0] + for mask in masks + ] + for rle in rles: + rle["counts"] = rle["counts"].decode("utf-8") + + coco_results.extend( + [ + { + "image_id": original_id, + "category_id": labels[k], + "segmentation": rle, + "score": scores[k], + } + for k, rle in enumerate(rles) + ] + ) + return coco_results + + def prepare_for_coco_keypoint(self, predictions): + coco_results = [] + for original_id, prediction in predictions.items(): + if len(prediction) == 0: + continue + + boxes = prediction["boxes"] + boxes = convert_to_xywh(boxes).tolist() + scores = prediction["scores"].tolist() + labels = prediction["labels"].tolist() + keypoints = prediction["keypoints"] + keypoints = keypoints.flatten(start_dim=1).tolist() + + coco_results.extend( + [ + { + "image_id": original_id, + "category_id": labels[k], + "keypoints": keypoint, + "score": scores[k], + } + for k, keypoint in enumerate(keypoints) + ] + ) + return coco_results + + +def convert_to_xywh(boxes): + xmin, ymin, xmax, ymax = boxes.unbind(1) + return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1) + + +def merge(img_ids, eval_imgs): + all_img_ids = all_gather(img_ids) + all_eval_imgs = all_gather(eval_imgs) + + merged_img_ids = [] + for p in all_img_ids: + merged_img_ids.extend(p) + + merged_eval_imgs = [] + for p in all_eval_imgs: + merged_eval_imgs.append(p) + + merged_img_ids = np.array(merged_img_ids) + merged_eval_imgs = np.concatenate(merged_eval_imgs, 2) + + # keep only unique (and in sorted order) images + merged_img_ids, idx = np.unique(merged_img_ids, return_index=True) + merged_eval_imgs = merged_eval_imgs[..., idx] + + return merged_img_ids, merged_eval_imgs + + +def create_common_coco_eval(coco_eval, img_ids, eval_imgs) -> None: + img_ids, eval_imgs = merge(img_ids, eval_imgs) + img_ids = list(img_ids) + eval_imgs = list(eval_imgs.flatten()) + + coco_eval.evalImgs = eval_imgs + coco_eval.params.imgIds = img_ids + coco_eval._paramsEval = copy.deepcopy(coco_eval.params) + + +################################################################# +# From pycocotools, just removed the prints and fixed +# a Python3 bug about unicode not defined +################################################################# + + +def evaluate(self): + """ + Run per image evaluation on given images and store results (a list of dict) in self.evalImgs + :return: None + """ + # tic = time.time() + # print('Running per image evaluation...') + p = self.params + # add backward compatibility if useSegm is specified in params + if p.useSegm is not None: + p.iouType = "segm" if p.useSegm == 1 else "bbox" + print(f"useSegm (deprecated) is not None. Running {p.iouType} evaluation") + # print('Evaluate annotation type *{}*'.format(p.iouType)) + p.imgIds = list(np.unique(p.imgIds)) + if p.useCats: + p.catIds = list(np.unique(p.catIds)) + p.maxDets = sorted(p.maxDets) + self.params = p + + self._prepare() + # loop through images, area range, max detection number + catIds = p.catIds if p.useCats else [-1] + + if p.iouType == "segm" or p.iouType == "bbox": + computeIoU = self.computeIoU + elif p.iouType == "keypoints": + computeIoU = self.computeOks + self.ious = {(imgId, catId): computeIoU(imgId, catId) for imgId in p.imgIds for catId in catIds} + + evaluateImg = self.evaluateImg + maxDet = p.maxDets[-1] + evalImgs = [ + evaluateImg(imgId, catId, areaRng, maxDet) + for catId in catIds + for areaRng in p.areaRng + for imgId in p.imgIds + ] + # this is NOT in the pycocotools code, but could be done outside + evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds)) + self._paramsEval = copy.deepcopy(self.params) + # toc = time.time() + # print('DONE (t={:0.2f}s).'.format(toc-tic)) + return p.imgIds, evalImgs + + +################################################################# +# end of straight copy from pycocotools, just removing the prints +################################################################# diff --git a/dimos/models/Detic/third_party/Deformable-DETR/datasets/coco_panoptic.py b/dimos/models/Detic/third_party/Deformable-DETR/datasets/coco_panoptic.py new file mode 100644 index 0000000000..d1dd9bda59 --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/datasets/coco_panoptic.py @@ -0,0 +1,119 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +import json +from pathlib import Path + +import numpy as np +from panopticapi.utils import rgb2id +from PIL import Image +import torch +from util.box_ops import masks_to_boxes + +from .coco import make_coco_transforms + + +class CocoPanoptic: + def __init__(self, img_folder, ann_folder, ann_file, transforms=None, return_masks: bool=True) -> None: + with open(ann_file) as f: + self.coco = json.load(f) + + # sort 'images' field so that they are aligned with 'annotations' + # i.e., in alphabetical order + self.coco["images"] = sorted(self.coco["images"], key=lambda x: x["id"]) + # sanity check + if "annotations" in self.coco: + for img, ann in zip(self.coco["images"], self.coco["annotations"], strict=False): + assert img["file_name"][:-4] == ann["file_name"][:-4] + + self.img_folder = img_folder + self.ann_folder = ann_folder + self.ann_file = ann_file + self.transforms = transforms + self.return_masks = return_masks + + def __getitem__(self, idx: int): + ann_info = ( + self.coco["annotations"][idx] + if "annotations" in self.coco + else self.coco["images"][idx] + ) + img_path = Path(self.img_folder) / ann_info["file_name"].replace(".png", ".jpg") + ann_path = Path(self.ann_folder) / ann_info["file_name"] + + img = Image.open(img_path).convert("RGB") + w, h = img.size + if "segments_info" in ann_info: + masks = np.asarray(Image.open(ann_path), dtype=np.uint32) + masks = rgb2id(masks) + + ids = np.array([ann["id"] for ann in ann_info["segments_info"]]) + masks = masks == ids[:, None, None] + + masks = torch.as_tensor(masks, dtype=torch.uint8) + labels = torch.tensor( + [ann["category_id"] for ann in ann_info["segments_info"]], dtype=torch.int64 + ) + + target = {} + target["image_id"] = torch.tensor( + [ann_info["image_id"] if "image_id" in ann_info else ann_info["id"]] + ) + if self.return_masks: + target["masks"] = masks + target["labels"] = labels + + target["boxes"] = masks_to_boxes(masks) + + target["size"] = torch.as_tensor([int(h), int(w)]) + target["orig_size"] = torch.as_tensor([int(h), int(w)]) + if "segments_info" in ann_info: + for name in ["iscrowd", "area"]: + target[name] = torch.tensor([ann[name] for ann in ann_info["segments_info"]]) + + if self.transforms is not None: + img, target = self.transforms(img, target) + + return img, target + + def __len__(self) -> int: + return len(self.coco["images"]) + + def get_height_and_width(self, idx: int): + img_info = self.coco["images"][idx] + height = img_info["height"] + width = img_info["width"] + return height, width + + +def build(image_set, args): + img_folder_root = Path(args.coco_path) + ann_folder_root = Path(args.coco_panoptic_path) + assert img_folder_root.exists(), f"provided COCO path {img_folder_root} does not exist" + assert ann_folder_root.exists(), f"provided COCO path {ann_folder_root} does not exist" + mode = "panoptic" + PATHS = { + "train": ("train2017", Path("annotations") / f"{mode}_train2017.json"), + "val": ("val2017", Path("annotations") / f"{mode}_val2017.json"), + } + + img_folder, ann_file = PATHS[image_set] + img_folder_path = img_folder_root / img_folder + ann_folder = ann_folder_root / f"{mode}_{img_folder}" + ann_file = ann_folder_root / ann_file + + dataset = CocoPanoptic( + img_folder_path, + ann_folder, + ann_file, + transforms=make_coco_transforms(image_set), + return_masks=args.masks, + ) + + return dataset diff --git a/dimos/models/Detic/third_party/Deformable-DETR/datasets/data_prefetcher.py b/dimos/models/Detic/third_party/Deformable-DETR/datasets/data_prefetcher.py new file mode 100644 index 0000000000..4942500801 --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/datasets/data_prefetcher.py @@ -0,0 +1,74 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +import torch + + +def to_cuda(samples, targets, device): + samples = samples.to(device, non_blocking=True) + targets = [{k: v.to(device, non_blocking=True) for k, v in t.items()} for t in targets] + return samples, targets + + +class data_prefetcher: + def __init__(self, loader, device, prefetch: bool=True) -> None: + self.loader = iter(loader) + self.prefetch = prefetch + self.device = device + if prefetch: + self.stream = torch.cuda.Stream() + self.preload() + + def preload(self) -> None: + try: + self.next_samples, self.next_targets = next(self.loader) + except StopIteration: + self.next_samples = None + self.next_targets = None + return + # if record_stream() doesn't work, another option is to make sure device inputs are created + # on the main stream. + # self.next_input_gpu = torch.empty_like(self.next_input, device='cuda') + # self.next_target_gpu = torch.empty_like(self.next_target, device='cuda') + # Need to make sure the memory allocated for next_* is not still in use by the main stream + # at the time we start copying to next_*: + # self.stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(self.stream): + self.next_samples, self.next_targets = to_cuda( + self.next_samples, self.next_targets, self.device + ) + # more code for the alternative if record_stream() doesn't work: + # copy_ will record the use of the pinned source tensor in this side stream. + # self.next_input_gpu.copy_(self.next_input, non_blocking=True) + # self.next_target_gpu.copy_(self.next_target, non_blocking=True) + # self.next_input = self.next_input_gpu + # self.next_target = self.next_target_gpu + + # With Amp, it isn't necessary to manually convert data to half. + # if args.fp16: + # self.next_input = self.next_input.half() + # else: + + def next(self): + if self.prefetch: + torch.cuda.current_stream().wait_stream(self.stream) + samples = self.next_samples + targets = self.next_targets + if samples is not None: + samples.record_stream(torch.cuda.current_stream()) + if targets is not None: + for t in targets: + for _k, v in t.items(): + v.record_stream(torch.cuda.current_stream()) + self.preload() + else: + try: + samples, targets = next(self.loader) + samples, targets = to_cuda(samples, targets, self.device) + except StopIteration: + samples = None + targets = None + return samples, targets diff --git a/dimos/models/Detic/third_party/Deformable-DETR/datasets/panoptic_eval.py b/dimos/models/Detic/third_party/Deformable-DETR/datasets/panoptic_eval.py new file mode 100644 index 0000000000..1a8ed7a82f --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/datasets/panoptic_eval.py @@ -0,0 +1,57 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +import json +import os + +import util.misc as utils + +try: + from panopticapi.evaluation import pq_compute +except ImportError: + pass + + +class PanopticEvaluator: + def __init__(self, ann_file, ann_folder, output_dir: str="panoptic_eval") -> None: + self.gt_json = ann_file + self.gt_folder = ann_folder + if utils.is_main_process(): + if not os.path.exists(output_dir): + os.mkdir(output_dir) + self.output_dir = output_dir + self.predictions = [] + + def update(self, predictions) -> None: + for p in predictions: + with open(os.path.join(self.output_dir, p["file_name"]), "wb") as f: + f.write(p.pop("png_string")) + + self.predictions += predictions + + def synchronize_between_processes(self) -> None: + all_predictions = utils.all_gather(self.predictions) + merged_predictions = [] + for p in all_predictions: + merged_predictions += p + self.predictions = merged_predictions + + def summarize(self): + if utils.is_main_process(): + json_data = {"annotations": self.predictions} + predictions_json = os.path.join(self.output_dir, "predictions.json") + with open(predictions_json, "w") as f: + f.write(json.dumps(json_data)) + return pq_compute( + self.gt_json, + predictions_json, + gt_folder=self.gt_folder, + pred_folder=self.output_dir, + ) + return None diff --git a/dimos/models/Detic/third_party/Deformable-DETR/datasets/samplers.py b/dimos/models/Detic/third_party/Deformable-DETR/datasets/samplers.py new file mode 100644 index 0000000000..5c2fff2d46 --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/datasets/samplers.py @@ -0,0 +1,148 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from codes in torch.utils.data.distributed +# ------------------------------------------------------------------------ + +import math +import os + +import torch +import torch.distributed as dist +from torch.utils.data.sampler import Sampler +from typing import Iterator, Optional + + +class DistributedSampler(Sampler): + """Sampler that restricts data loading to a subset of the dataset. + It is especially useful in conjunction with + :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each + process can pass a DistributedSampler instance as a DataLoader sampler, + and load a subset of the original dataset that is exclusive to it. + .. note:: + Dataset is assumed to be of constant size. + Arguments: + dataset: Dataset used for sampling. + num_replicas (optional): Number of processes participating in + distributed training. + rank (optional): Rank of the current process within num_replicas. + """ + + def __init__( + self, dataset, num_replicas: Optional[int]=None, rank=None, local_rank=None, local_size: Optional[int]=None, shuffle: bool=True + ) -> None: + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.num_samples = math.ceil(len(self.dataset) * 1.0 / self.num_replicas) + self.total_size = self.num_samples * self.num_replicas + self.shuffle = shuffle + + def __iter__(self) -> Iterator: + if self.shuffle: + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() + else: + indices = torch.arange(len(self.dataset)).tolist() + + # add extra samples to make it evenly divisible + indices += indices[: (self.total_size - len(indices))] + assert len(indices) == self.total_size + + # subsample + offset = self.num_samples * self.rank + indices = indices[offset : offset + self.num_samples] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self) -> int: + return self.num_samples + + def set_epoch(self, epoch: int) -> None: + self.epoch = epoch + + +class NodeDistributedSampler(Sampler): + """Sampler that restricts data loading to a subset of the dataset. + It is especially useful in conjunction with + :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each + process can pass a DistributedSampler instance as a DataLoader sampler, + and load a subset of the original dataset that is exclusive to it. + .. note:: + Dataset is assumed to be of constant size. + Arguments: + dataset: Dataset used for sampling. + num_replicas (optional): Number of processes participating in + distributed training. + rank (optional): Rank of the current process within num_replicas. + """ + + def __init__( + self, dataset, num_replicas: Optional[int]=None, rank=None, local_rank=None, local_size: Optional[int]=None, shuffle: bool=True + ) -> None: + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + if local_rank is None: + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + if local_size is None: + local_size = int(os.environ.get("LOCAL_SIZE", 1)) + self.dataset = dataset + self.shuffle = shuffle + self.num_replicas = num_replicas + self.num_parts = local_size + self.rank = rank + self.local_rank = local_rank + self.epoch = 0 + self.num_samples = math.ceil(len(self.dataset) * 1.0 / self.num_replicas) + self.total_size = self.num_samples * self.num_replicas + + self.total_size_parts = self.num_samples * self.num_replicas // self.num_parts + + def __iter__(self) -> Iterator: + if self.shuffle: + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() + else: + indices = torch.arange(len(self.dataset)).tolist() + indices = [i for i in indices if i % self.num_parts == self.local_rank] + + # add extra samples to make it evenly divisible + indices += indices[: (self.total_size_parts - len(indices))] + assert len(indices) == self.total_size_parts + + # subsample + indices = indices[ + self.rank // self.num_parts : self.total_size_parts : self.num_replicas + // self.num_parts + ] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self) -> int: + return self.num_samples + + def set_epoch(self, epoch: int) -> None: + self.epoch = epoch diff --git a/dimos/models/Detic/third_party/Deformable-DETR/datasets/torchvision_datasets/__init__.py b/dimos/models/Detic/third_party/Deformable-DETR/datasets/torchvision_datasets/__init__.py new file mode 100644 index 0000000000..162303c4ce --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/datasets/torchvision_datasets/__init__.py @@ -0,0 +1,7 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +from .coco import CocoDetection diff --git a/dimos/models/Detic/third_party/Deformable-DETR/datasets/torchvision_datasets/coco.py b/dimos/models/Detic/third_party/Deformable-DETR/datasets/torchvision_datasets/coco.py new file mode 100644 index 0000000000..65eb674294 --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/datasets/torchvision_datasets/coco.py @@ -0,0 +1,96 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from torchvision +# ------------------------------------------------------------------------ + +""" +Copy-Paste from torchvision, but add utility of caching images on memory +""" + +from io import BytesIO +import os +import os.path + +from PIL import Image +from torchvision.datasets.vision import VisionDataset +import tqdm + + +class CocoDetection(VisionDataset): + """`MS Coco Detection `_ Dataset. + Args: + root (string): Root directory where images are downloaded to. + annFile (string): Path to json annotation file. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.ToTensor`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + transforms (callable, optional): A function/transform that takes input sample and its target as entry + and returns a transformed version. + """ + + def __init__( + self, + root, + annFile, + transform=None, + target_transform=None, + transforms=None, + cache_mode: bool=False, + local_rank: int=0, + local_size: int=1, + ) -> None: + super().__init__(root, transforms, transform, target_transform) + from pycocotools.coco import COCO + + self.coco = COCO(annFile) + self.ids = list(sorted(self.coco.imgs.keys())) + self.cache_mode = cache_mode + self.local_rank = local_rank + self.local_size = local_size + if cache_mode: + self.cache = {} + self.cache_images() + + def cache_images(self) -> None: + self.cache = {} + for index, img_id in zip(tqdm.trange(len(self.ids)), self.ids, strict=False): + if index % self.local_size != self.local_rank: + continue + path = self.coco.loadImgs(img_id)[0]["file_name"] + with open(os.path.join(self.root, path), "rb") as f: + self.cache[path] = f.read() + + def get_image(self, path): + if self.cache_mode: + if path not in self.cache.keys(): + with open(os.path.join(self.root, path), "rb") as f: + self.cache[path] = f.read() + return Image.open(BytesIO(self.cache[path])).convert("RGB") + return Image.open(os.path.join(self.root, path)).convert("RGB") + + def __getitem__(self, index): + """ + Args: + index (int): Index + Returns: + tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``. + """ + coco = self.coco + img_id = self.ids[index] + ann_ids = coco.getAnnIds(imgIds=img_id) + target = coco.loadAnns(ann_ids) + + path = coco.loadImgs(img_id)[0]["file_name"] + + img = self.get_image(path) + if self.transforms is not None: + img, target = self.transforms(img, target) + + return img, target + + def __len__(self) -> int: + return len(self.ids) diff --git a/dimos/models/Detic/third_party/Deformable-DETR/datasets/transforms.py b/dimos/models/Detic/third_party/Deformable-DETR/datasets/transforms.py new file mode 100644 index 0000000000..3c2947ee36 --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/datasets/transforms.py @@ -0,0 +1,290 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +Transforms and data augmentation for both image + bbox. +""" + +import random + +import PIL +import torch +import torchvision.transforms as T +import torchvision.transforms.functional as F +from util.box_ops import box_xyxy_to_cxcywh +from util.misc import interpolate +from typing import Optional, Sequence + + +def crop(image, target, region): + cropped_image = F.crop(image, *region) + + target = target.copy() + i, j, h, w = region + + # should we do something wrt the original size? + target["size"] = torch.tensor([h, w]) + + fields = ["labels", "area", "iscrowd"] + + if "boxes" in target: + boxes = target["boxes"] + max_size = torch.as_tensor([w, h], dtype=torch.float32) + cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) + cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) + cropped_boxes = cropped_boxes.clamp(min=0) + area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1) + target["boxes"] = cropped_boxes.reshape(-1, 4) + target["area"] = area + fields.append("boxes") + + if "masks" in target: + # FIXME should we update the area here if there are no boxes? + target["masks"] = target["masks"][:, i : i + h, j : j + w] + fields.append("masks") + + # remove elements for which the boxes or masks that have zero area + if "boxes" in target or "masks" in target: + # favor boxes selection when defining which elements to keep + # this is compatible with previous implementation + if "boxes" in target: + cropped_boxes = target["boxes"].reshape(-1, 2, 2) + keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) + else: + keep = target["masks"].flatten(1).any(1) + + for field in fields: + target[field] = target[field][keep] + + return cropped_image, target + + +def hflip(image, target): + flipped_image = F.hflip(image) + + w, h = image.size + + target = target.copy() + if "boxes" in target: + boxes = target["boxes"] + boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor( + [w, 0, w, 0] + ) + target["boxes"] = boxes + + if "masks" in target: + target["masks"] = target["masks"].flip(-1) + + return flipped_image, target + + +def resize(image, target, size: int, max_size: Optional[int]=None): + # size can be min_size (scalar) or (w, h) tuple + + def get_size_with_aspect_ratio(image_size: int, size: int, max_size: Optional[int]=None): + w, h = image_size + if max_size is not None: + min_original_size = float(min((w, h))) + max_original_size = float(max((w, h))) + if max_original_size / min_original_size * size > max_size: + size = round(max_size * min_original_size / max_original_size) + + if (w <= h and w == size) or (h <= w and h == size): + return (h, w) + + if w < h: + ow = size + oh = int(size * h / w) + else: + oh = size + ow = int(size * w / h) + + return (oh, ow) + + def get_size(image_size: int, size: int, max_size: Optional[int]=None): + if isinstance(size, list | tuple): + return size[::-1] + else: + return get_size_with_aspect_ratio(image_size, size, max_size) + + size = get_size(image.size, size, max_size) + rescaled_image = F.resize(image, size) + + if target is None: + return rescaled_image, None + + ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size, strict=False)) + ratio_width, ratio_height = ratios + + target = target.copy() + if "boxes" in target: + boxes = target["boxes"] + scaled_boxes = boxes * torch.as_tensor( + [ratio_width, ratio_height, ratio_width, ratio_height] + ) + target["boxes"] = scaled_boxes + + if "area" in target: + area = target["area"] + scaled_area = area * (ratio_width * ratio_height) + target["area"] = scaled_area + + h, w = size + target["size"] = torch.tensor([h, w]) + + if "masks" in target: + target["masks"] = ( + interpolate(target["masks"][:, None].float(), size, mode="nearest")[:, 0] > 0.5 + ) + + return rescaled_image, target + + +def pad(image, target, padding): + # assumes that we only pad on the bottom right corners + padded_image = F.pad(image, (0, 0, padding[0], padding[1])) + if target is None: + return padded_image, None + target = target.copy() + # should we do something wrt the original size? + target["size"] = torch.tensor(padded_image[::-1]) + if "masks" in target: + target["masks"] = torch.nn.functional.pad(target["masks"], (0, padding[0], 0, padding[1])) + return padded_image, target + + +class RandomCrop: + def __init__(self, size: int) -> None: + self.size = size + + def __call__(self, img, target): + region = T.RandomCrop.get_params(img, self.size) + return crop(img, target, region) + + +class RandomSizeCrop: + def __init__(self, min_size: int, max_size: int) -> None: + self.min_size = min_size + self.max_size = max_size + + def __call__(self, img: PIL.Image.Image, target: dict): + w = random.randint(self.min_size, min(img.width, self.max_size)) + h = random.randint(self.min_size, min(img.height, self.max_size)) + region = T.RandomCrop.get_params(img, [h, w]) + return crop(img, target, region) + + +class CenterCrop: + def __init__(self, size: int) -> None: + self.size = size + + def __call__(self, img, target): + image_width, image_height = img.size + crop_height, crop_width = self.size + crop_top = round((image_height - crop_height) / 2.0) + crop_left = round((image_width - crop_width) / 2.0) + return crop(img, target, (crop_top, crop_left, crop_height, crop_width)) + + +class RandomHorizontalFlip: + def __init__(self, p: float=0.5) -> None: + self.p = p + + def __call__(self, img, target): + if random.random() < self.p: + return hflip(img, target) + return img, target + + +class RandomResize: + def __init__(self, sizes: Sequence[int], max_size: Optional[int]=None) -> None: + assert isinstance(sizes, list | tuple) + self.sizes = sizes + self.max_size = max_size + + def __call__(self, img, target=None): + size = random.choice(self.sizes) + return resize(img, target, size, self.max_size) + + +class RandomPad: + def __init__(self, max_pad) -> None: + self.max_pad = max_pad + + def __call__(self, img, target): + pad_x = random.randint(0, self.max_pad) + pad_y = random.randint(0, self.max_pad) + return pad(img, target, (pad_x, pad_y)) + + +class RandomSelect: + """ + Randomly selects between transforms1 and transforms2, + with probability p for transforms1 and (1 - p) for transforms2 + """ + + def __init__(self, transforms1, transforms2, p: float=0.5) -> None: + self.transforms1 = transforms1 + self.transforms2 = transforms2 + self.p = p + + def __call__(self, img, target): + if random.random() < self.p: + return self.transforms1(img, target) + return self.transforms2(img, target) + + +class ToTensor: + def __call__(self, img, target): + return F.to_tensor(img), target + + +class RandomErasing: + def __init__(self, *args, **kwargs) -> None: + self.eraser = T.RandomErasing(*args, **kwargs) + + def __call__(self, img, target): + return self.eraser(img), target + + +class Normalize: + def __init__(self, mean, std) -> None: + self.mean = mean + self.std = std + + def __call__(self, image, target=None): + image = F.normalize(image, mean=self.mean, std=self.std) + if target is None: + return image, None + target = target.copy() + h, w = image.shape[-2:] + if "boxes" in target: + boxes = target["boxes"] + boxes = box_xyxy_to_cxcywh(boxes) + boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32) + target["boxes"] = boxes + return image, target + + +class Compose: + def __init__(self, transforms) -> None: + self.transforms = transforms + + def __call__(self, image, target): + for t in self.transforms: + image, target = t(image, target) + return image, target + + def __repr__(self) -> str: + format_string = self.__class__.__name__ + "(" + for t in self.transforms: + format_string += "\n" + format_string += f" {t}" + format_string += "\n)" + return format_string diff --git a/dimos/models/Detic/third_party/Deformable-DETR/docs/changelog.md b/dimos/models/Detic/third_party/Deformable-DETR/docs/changelog.md new file mode 100644 index 0000000000..1ed5e79a4d --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/docs/changelog.md @@ -0,0 +1,3 @@ +## Changelog + +**[2020.12.07]** Fix a bug of sampling offset normalization (see [this issue](https://github.com/fundamentalvision/Deformable-DETR/issues/6)) in the MSDeformAttn module. The final accuracy on COCO is slightly improved. Code and pre-trained models have been updated. This bug only occurs in this released version but not in the original implementation used in our paper. \ No newline at end of file diff --git a/dimos/models/Detic/third_party/Deformable-DETR/engine.py b/dimos/models/Detic/third_party/Deformable-DETR/engine.py new file mode 100644 index 0000000000..7e6e7c2c20 --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/engine.py @@ -0,0 +1,177 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +Train and eval functions used in main.py +""" + +import math +import os +import sys +from typing import Iterable + +from datasets.coco_eval import CocoEvaluator +from datasets.data_prefetcher import data_prefetcher +from datasets.panoptic_eval import PanopticEvaluator +import torch +import util.misc as utils + + +def train_one_epoch( + model: torch.nn.Module, + criterion: torch.nn.Module, + data_loader: Iterable, + optimizer: torch.optim.Optimizer, + device: torch.device, + epoch: int, + max_norm: float = 0, +): + model.train() + criterion.train() + metric_logger = utils.MetricLogger(delimiter=" ") + metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}")) + metric_logger.add_meter("class_error", utils.SmoothedValue(window_size=1, fmt="{value:.2f}")) + metric_logger.add_meter("grad_norm", utils.SmoothedValue(window_size=1, fmt="{value:.2f}")) + header = f"Epoch: [{epoch}]" + print_freq = 10 + + prefetcher = data_prefetcher(data_loader, device, prefetch=True) + samples, targets = prefetcher.next() + + # for samples, targets in metric_logger.log_every(data_loader, print_freq, header): + for _ in metric_logger.log_every(range(len(data_loader)), print_freq, header): + outputs = model(samples) + loss_dict = criterion(outputs, targets) + weight_dict = criterion.weight_dict + losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) + + # reduce losses over all GPUs for logging purposes + loss_dict_reduced = utils.reduce_dict(loss_dict) + loss_dict_reduced_unscaled = {f"{k}_unscaled": v for k, v in loss_dict_reduced.items()} + loss_dict_reduced_scaled = { + k: v * weight_dict[k] for k, v in loss_dict_reduced.items() if k in weight_dict + } + losses_reduced_scaled = sum(loss_dict_reduced_scaled.values()) + + loss_value = losses_reduced_scaled.item() + + if not math.isfinite(loss_value): + print(f"Loss is {loss_value}, stopping training") + print(loss_dict_reduced) + sys.exit(1) + + optimizer.zero_grad() + losses.backward() + if max_norm > 0: + grad_total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) + else: + grad_total_norm = utils.get_total_grad_norm(model.parameters(), max_norm) + optimizer.step() + + metric_logger.update( + loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled + ) + metric_logger.update(class_error=loss_dict_reduced["class_error"]) + metric_logger.update(lr=optimizer.param_groups[0]["lr"]) + metric_logger.update(grad_norm=grad_total_norm) + + samples, targets = prefetcher.next() + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} + + +@torch.no_grad() +def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, output_dir): + model.eval() + criterion.eval() + + metric_logger = utils.MetricLogger(delimiter=" ") + metric_logger.add_meter("class_error", utils.SmoothedValue(window_size=1, fmt="{value:.2f}")) + header = "Test:" + + iou_types = tuple(k for k in ("segm", "bbox") if k in postprocessors.keys()) + coco_evaluator = CocoEvaluator(base_ds, iou_types) + # coco_evaluator.coco_eval[iou_types[0]].params.iouThrs = [0, 0.1, 0.5, 0.75] + + panoptic_evaluator = None + if "panoptic" in postprocessors.keys(): + panoptic_evaluator = PanopticEvaluator( + data_loader.dataset.ann_file, + data_loader.dataset.ann_folder, + output_dir=os.path.join(output_dir, "panoptic_eval"), + ) + + for samples, targets in metric_logger.log_every(data_loader, 10, header): + samples = samples.to(device) + targets = [{k: v.to(device) for k, v in t.items()} for t in targets] + + outputs = model(samples) + loss_dict = criterion(outputs, targets) + weight_dict = criterion.weight_dict + + # reduce losses over all GPUs for logging purposes + loss_dict_reduced = utils.reduce_dict(loss_dict) + loss_dict_reduced_scaled = { + k: v * weight_dict[k] for k, v in loss_dict_reduced.items() if k in weight_dict + } + loss_dict_reduced_unscaled = {f"{k}_unscaled": v for k, v in loss_dict_reduced.items()} + metric_logger.update( + loss=sum(loss_dict_reduced_scaled.values()), + **loss_dict_reduced_scaled, + **loss_dict_reduced_unscaled, + ) + metric_logger.update(class_error=loss_dict_reduced["class_error"]) + + orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0) + results = postprocessors["bbox"](outputs, orig_target_sizes) + if "segm" in postprocessors.keys(): + target_sizes = torch.stack([t["size"] for t in targets], dim=0) + results = postprocessors["segm"](results, outputs, orig_target_sizes, target_sizes) + res = {target["image_id"].item(): output for target, output in zip(targets, results, strict=False)} + if coco_evaluator is not None: + coco_evaluator.update(res) + + if panoptic_evaluator is not None: + res_pano = postprocessors["panoptic"](outputs, target_sizes, orig_target_sizes) + for i, target in enumerate(targets): + image_id = target["image_id"].item() + file_name = f"{image_id:012d}.png" + res_pano[i]["image_id"] = image_id + res_pano[i]["file_name"] = file_name + + panoptic_evaluator.update(res_pano) + + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + if coco_evaluator is not None: + coco_evaluator.synchronize_between_processes() + if panoptic_evaluator is not None: + panoptic_evaluator.synchronize_between_processes() + + # accumulate predictions from all images + if coco_evaluator is not None: + coco_evaluator.accumulate() + coco_evaluator.summarize() + panoptic_res = None + if panoptic_evaluator is not None: + panoptic_res = panoptic_evaluator.summarize() + stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()} + if coco_evaluator is not None: + if "bbox" in postprocessors.keys(): + stats["coco_eval_bbox"] = coco_evaluator.coco_eval["bbox"].stats.tolist() + if "segm" in postprocessors.keys(): + stats["coco_eval_masks"] = coco_evaluator.coco_eval["segm"].stats.tolist() + if panoptic_res is not None: + stats["PQ_all"] = panoptic_res["All"] + stats["PQ_th"] = panoptic_res["Things"] + stats["PQ_st"] = panoptic_res["Stuff"] + return stats, coco_evaluator diff --git a/dimos/models/Detic/third_party/Deformable-DETR/main.py b/dimos/models/Detic/third_party/Deformable-DETR/main.py new file mode 100644 index 0000000000..187b93a868 --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/main.py @@ -0,0 +1,418 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + + +import argparse +import datetime +import json +from pathlib import Path +import random +import time + +import datasets +from datasets import build_dataset, get_coco_api_from_dataset +import datasets.samplers as samplers +from engine import evaluate, train_one_epoch +from models import build_model +import numpy as np +import torch +from torch.utils.data import DataLoader +import util.misc as utils + + +def get_args_parser(): + parser = argparse.ArgumentParser("Deformable DETR Detector", add_help=False) + parser.add_argument("--lr", default=2e-4, type=float) + parser.add_argument("--lr_backbone_names", default=["backbone.0"], type=str, nargs="+") + parser.add_argument("--lr_backbone", default=2e-5, type=float) + parser.add_argument( + "--lr_linear_proj_names", + default=["reference_points", "sampling_offsets"], + type=str, + nargs="+", + ) + parser.add_argument("--lr_linear_proj_mult", default=0.1, type=float) + parser.add_argument("--batch_size", default=2, type=int) + parser.add_argument("--weight_decay", default=1e-4, type=float) + parser.add_argument("--epochs", default=50, type=int) + parser.add_argument("--lr_drop", default=40, type=int) + parser.add_argument("--lr_drop_epochs", default=None, type=int, nargs="+") + parser.add_argument( + "--clip_max_norm", default=0.1, type=float, help="gradient clipping max norm" + ) + + parser.add_argument("--sgd", action="store_true") + + # Variants of Deformable DETR + parser.add_argument("--with_box_refine", default=False, action="store_true") + parser.add_argument("--two_stage", default=False, action="store_true") + + # Model parameters + parser.add_argument( + "--frozen_weights", + type=str, + default=None, + help="Path to the pretrained model. If set, only the mask head will be trained", + ) + + # * Backbone + parser.add_argument( + "--backbone", default="resnet50", type=str, help="Name of the convolutional backbone to use" + ) + parser.add_argument( + "--dilation", + action="store_true", + help="If true, we replace stride with dilation in the last convolutional block (DC5)", + ) + parser.add_argument( + "--position_embedding", + default="sine", + type=str, + choices=("sine", "learned"), + help="Type of positional embedding to use on top of the image features", + ) + parser.add_argument( + "--position_embedding_scale", default=2 * np.pi, type=float, help="position / size * scale" + ) + parser.add_argument( + "--num_feature_levels", default=4, type=int, help="number of feature levels" + ) + + # * Transformer + parser.add_argument( + "--enc_layers", default=6, type=int, help="Number of encoding layers in the transformer" + ) + parser.add_argument( + "--dec_layers", default=6, type=int, help="Number of decoding layers in the transformer" + ) + parser.add_argument( + "--dim_feedforward", + default=1024, + type=int, + help="Intermediate size of the feedforward layers in the transformer blocks", + ) + parser.add_argument( + "--hidden_dim", + default=256, + type=int, + help="Size of the embeddings (dimension of the transformer)", + ) + parser.add_argument( + "--dropout", default=0.1, type=float, help="Dropout applied in the transformer" + ) + parser.add_argument( + "--nheads", + default=8, + type=int, + help="Number of attention heads inside the transformer's attentions", + ) + parser.add_argument("--num_queries", default=300, type=int, help="Number of query slots") + parser.add_argument("--dec_n_points", default=4, type=int) + parser.add_argument("--enc_n_points", default=4, type=int) + + # * Segmentation + parser.add_argument( + "--masks", action="store_true", help="Train segmentation head if the flag is provided" + ) + + # Loss + parser.add_argument( + "--no_aux_loss", + dest="aux_loss", + action="store_false", + help="Disables auxiliary decoding losses (loss at each layer)", + ) + + # * Matcher + parser.add_argument( + "--set_cost_class", default=2, type=float, help="Class coefficient in the matching cost" + ) + parser.add_argument( + "--set_cost_bbox", default=5, type=float, help="L1 box coefficient in the matching cost" + ) + parser.add_argument( + "--set_cost_giou", default=2, type=float, help="giou box coefficient in the matching cost" + ) + + # * Loss coefficients + parser.add_argument("--mask_loss_coef", default=1, type=float) + parser.add_argument("--dice_loss_coef", default=1, type=float) + parser.add_argument("--cls_loss_coef", default=2, type=float) + parser.add_argument("--bbox_loss_coef", default=5, type=float) + parser.add_argument("--giou_loss_coef", default=2, type=float) + parser.add_argument("--focal_alpha", default=0.25, type=float) + + # dataset parameters + parser.add_argument("--dataset_file", default="coco") + parser.add_argument("--coco_path", default="./data/coco", type=str) + parser.add_argument("--coco_panoptic_path", type=str) + parser.add_argument("--remove_difficult", action="store_true") + + parser.add_argument("--output_dir", default="", help="path where to save, empty for no saving") + parser.add_argument("--device", default="cuda", help="device to use for training / testing") + parser.add_argument("--seed", default=42, type=int) + parser.add_argument("--resume", default="", help="resume from checkpoint") + parser.add_argument("--start_epoch", default=0, type=int, metavar="N", help="start epoch") + parser.add_argument("--eval", action="store_true") + parser.add_argument("--num_workers", default=2, type=int) + parser.add_argument( + "--cache_mode", default=False, action="store_true", help="whether to cache images on memory" + ) + + return parser + + +def main(args) -> None: + utils.init_distributed_mode(args) + print(f"git:\n {utils.get_sha()}\n") + + if args.frozen_weights is not None: + assert args.masks, "Frozen training is meant for segmentation only" + print(args) + + device = torch.device(args.device) + + # fix the seed for reproducibility + seed = args.seed + utils.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + + model, criterion, postprocessors = build_model(args) + model.to(device) + + model_without_ddp = model + n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) + print("number of params:", n_parameters) + + dataset_train = build_dataset(image_set="train", args=args) + dataset_val = build_dataset(image_set="val", args=args) + + if args.distributed: + if args.cache_mode: + sampler_train = samplers.NodeDistributedSampler(dataset_train) + sampler_val = samplers.NodeDistributedSampler(dataset_val, shuffle=False) + else: + sampler_train = samplers.DistributedSampler(dataset_train) + sampler_val = samplers.DistributedSampler(dataset_val, shuffle=False) + else: + sampler_train = torch.utils.data.RandomSampler(dataset_train) + sampler_val = torch.utils.data.SequentialSampler(dataset_val) + + batch_sampler_train = torch.utils.data.BatchSampler( + sampler_train, args.batch_size, drop_last=True + ) + + data_loader_train = DataLoader( + dataset_train, + batch_sampler=batch_sampler_train, + collate_fn=utils.collate_fn, + num_workers=args.num_workers, + pin_memory=True, + ) + data_loader_val = DataLoader( + dataset_val, + args.batch_size, + sampler=sampler_val, + drop_last=False, + collate_fn=utils.collate_fn, + num_workers=args.num_workers, + pin_memory=True, + ) + + # lr_backbone_names = ["backbone.0", "backbone.neck", "input_proj", "transformer.encoder"] + def match_name_keywords(n, name_keywords): + out = False + for b in name_keywords: + if b in n: + out = True + break + return out + + for n, _p in model_without_ddp.named_parameters(): + print(n) + + param_dicts = [ + { + "params": [ + p + for n, p in model_without_ddp.named_parameters() + if not match_name_keywords(n, args.lr_backbone_names) + and not match_name_keywords(n, args.lr_linear_proj_names) + and p.requires_grad + ], + "lr": args.lr, + }, + { + "params": [ + p + for n, p in model_without_ddp.named_parameters() + if match_name_keywords(n, args.lr_backbone_names) and p.requires_grad + ], + "lr": args.lr_backbone, + }, + { + "params": [ + p + for n, p in model_without_ddp.named_parameters() + if match_name_keywords(n, args.lr_linear_proj_names) and p.requires_grad + ], + "lr": args.lr * args.lr_linear_proj_mult, + }, + ] + if args.sgd: + optimizer = torch.optim.SGD( + param_dicts, lr=args.lr, momentum=0.9, weight_decay=args.weight_decay + ) + else: + optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, weight_decay=args.weight_decay) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop) + + if args.distributed: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) + model_without_ddp = model.module + + if args.dataset_file == "coco_panoptic": + # We also evaluate AP during panoptic training, on original coco DS + coco_val = datasets.coco.build("val", args) + base_ds = get_coco_api_from_dataset(coco_val) + else: + base_ds = get_coco_api_from_dataset(dataset_val) + + if args.frozen_weights is not None: + checkpoint = torch.load(args.frozen_weights, map_location="cpu") + model_without_ddp.detr.load_state_dict(checkpoint["model"]) + + output_dir = Path(args.output_dir) + if args.resume: + if args.resume.startswith("https"): + checkpoint = torch.hub.load_state_dict_from_url( + args.resume, map_location="cpu", check_hash=True + ) + else: + checkpoint = torch.load(args.resume, map_location="cpu") + missing_keys, unexpected_keys = model_without_ddp.load_state_dict( + checkpoint["model"], strict=False + ) + unexpected_keys = [ + k + for k in unexpected_keys + if not (k.endswith("total_params") or k.endswith("total_ops")) + ] + if len(missing_keys) > 0: + print(f"Missing Keys: {missing_keys}") + if len(unexpected_keys) > 0: + print(f"Unexpected Keys: {unexpected_keys}") + if ( + not args.eval + and "optimizer" in checkpoint + and "lr_scheduler" in checkpoint + and "epoch" in checkpoint + ): + import copy + + p_groups = copy.deepcopy(optimizer.param_groups) + optimizer.load_state_dict(checkpoint["optimizer"]) + for pg, pg_old in zip(optimizer.param_groups, p_groups, strict=False): + pg["lr"] = pg_old["lr"] + pg["initial_lr"] = pg_old["initial_lr"] + print(optimizer.param_groups) + lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) + # todo: this is a hack for doing experiment that resume from checkpoint and also modify lr scheduler (e.g., decrease lr in advance). + args.override_resumed_lr_drop = True + if args.override_resumed_lr_drop: + print( + "Warning: (hack) args.override_resumed_lr_drop is set to True, so args.lr_drop would override lr_drop in resumed lr_scheduler." + ) + lr_scheduler.step_size = args.lr_drop + lr_scheduler.base_lrs = list( + map(lambda group: group["initial_lr"], optimizer.param_groups) + ) + lr_scheduler.step(lr_scheduler.last_epoch) + args.start_epoch = checkpoint["epoch"] + 1 + # check the resumed model + if not args.eval: + test_stats, coco_evaluator = evaluate( + model, criterion, postprocessors, data_loader_val, base_ds, device, args.output_dir + ) + + if args.eval: + test_stats, coco_evaluator = evaluate( + model, criterion, postprocessors, data_loader_val, base_ds, device, args.output_dir + ) + if args.output_dir: + utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth") + return + + print("Start training") + start_time = time.time() + for epoch in range(args.start_epoch, args.epochs): + if args.distributed: + sampler_train.set_epoch(epoch) + train_stats = train_one_epoch( + model, criterion, data_loader_train, optimizer, device, epoch, args.clip_max_norm + ) + lr_scheduler.step() + if args.output_dir: + checkpoint_paths = [output_dir / "checkpoint.pth"] + # extra checkpoint before LR drop and every 5 epochs + if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 5 == 0: + checkpoint_paths.append(output_dir / f"checkpoint{epoch:04}.pth") + for checkpoint_path in checkpoint_paths: + utils.save_on_master( + { + "model": model_without_ddp.state_dict(), + "optimizer": optimizer.state_dict(), + "lr_scheduler": lr_scheduler.state_dict(), + "epoch": epoch, + "args": args, + }, + checkpoint_path, + ) + + test_stats, coco_evaluator = evaluate( + model, criterion, postprocessors, data_loader_val, base_ds, device, args.output_dir + ) + + log_stats = { + **{f"train_{k}": v for k, v in train_stats.items()}, + **{f"test_{k}": v for k, v in test_stats.items()}, + "epoch": epoch, + "n_parameters": n_parameters, + } + + if args.output_dir and utils.is_main_process(): + with (output_dir / "log.txt").open("a") as f: + f.write(json.dumps(log_stats) + "\n") + + # for evaluation logs + if coco_evaluator is not None: + (output_dir / "eval").mkdir(exist_ok=True) + if "bbox" in coco_evaluator.coco_eval: + filenames = ["latest.pth"] + if epoch % 50 == 0: + filenames.append(f"{epoch:03}.pth") + for name in filenames: + torch.save( + coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval" / name + ) + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print(f"Training time {total_time_str}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + "Deformable DETR training and evaluation script", parents=[get_args_parser()] + ) + args = parser.parse_args() + if args.output_dir: + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + main(args) diff --git a/dimos/models/Detic/third_party/Deformable-DETR/models/__init__.py b/dimos/models/Detic/third_party/Deformable-DETR/models/__init__.py new file mode 100644 index 0000000000..46b898b988 --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/models/__init__.py @@ -0,0 +1,14 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +from .deformable_detr import build + + +def build_model(args): + return build(args) diff --git a/dimos/models/Detic/third_party/Deformable-DETR/models/backbone.py b/dimos/models/Detic/third_party/Deformable-DETR/models/backbone.py new file mode 100644 index 0000000000..cd973fa891 --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/models/backbone.py @@ -0,0 +1,142 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +Backbone modules. +""" + +from typing import Dict, List + +import torch +from torch import nn +import torch.nn.functional as F +import torchvision +from torchvision.models._utils import IntermediateLayerGetter +from util.misc import NestedTensor, is_main_process + +from .position_encoding import build_position_encoding + + +class FrozenBatchNorm2d(torch.nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. + + Copy-paste from torchvision.misc.ops with added eps before rqsrt, + without which any other models than torchvision.models.resnet[18,34,50,101] + produce nans. + """ + + def __init__(self, n, eps: float=1e-5) -> None: + super().__init__() + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + self.eps = eps + + def _load_from_state_dict( + self, state_dict, prefix: str, local_metadata, strict: bool, missing_keys, unexpected_keys, error_msgs + ) -> None: + num_batches_tracked_key = prefix + "num_batches_tracked" + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + + def forward(self, x): + # move reshapes to the beginning + # to make it fuser-friendly + w = self.weight.reshape(1, -1, 1, 1) + b = self.bias.reshape(1, -1, 1, 1) + rv = self.running_var.reshape(1, -1, 1, 1) + rm = self.running_mean.reshape(1, -1, 1, 1) + eps = self.eps + scale = w * (rv + eps).rsqrt() + bias = b - rm * scale + return x * scale + bias + + +class BackboneBase(nn.Module): + def __init__(self, backbone: nn.Module, train_backbone: bool, return_interm_layers: bool) -> None: + super().__init__() + for name, parameter in backbone.named_parameters(): + if ( + not train_backbone + or ("layer2" not in name + and "layer3" not in name + and "layer4" not in name) + ): + parameter.requires_grad_(False) + if return_interm_layers: + # return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} + return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"} + self.strides = [8, 16, 32] + self.num_channels = [512, 1024, 2048] + else: + return_layers = {"layer4": "0"} + self.strides = [32] + self.num_channels = [2048] + self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) + + def forward(self, tensor_list: NestedTensor): + xs = self.body(tensor_list.tensors) + out: dict[str, NestedTensor] = {} + for name, x in xs.items(): + m = tensor_list.mask + assert m is not None + mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] + out[name] = NestedTensor(x, mask) + return out + + +class Backbone(BackboneBase): + """ResNet backbone with frozen BatchNorm.""" + + def __init__(self, name: str, train_backbone: bool, return_interm_layers: bool, dilation: bool) -> None: + norm_layer = FrozenBatchNorm2d + backbone = getattr(torchvision.models, name)( + replace_stride_with_dilation=[False, False, dilation], + pretrained=is_main_process(), + norm_layer=norm_layer, + ) + assert name not in ("resnet18", "resnet34"), "number of channels are hard coded" + super().__init__(backbone, train_backbone, return_interm_layers) + if dilation: + self.strides[-1] = self.strides[-1] // 2 + + +class Joiner(nn.Sequential): + def __init__(self, backbone, position_embedding) -> None: + super().__init__(backbone, position_embedding) + self.strides = backbone.strides + self.num_channels = backbone.num_channels + + def forward(self, tensor_list: NestedTensor): + xs = self[0](tensor_list) + out: list[NestedTensor] = [] + pos = [] + for _name, x in sorted(xs.items()): + out.append(x) + + # position encoding + for x in out: + pos.append(self[1](x).to(x.tensors.dtype)) + + return out, pos + + +def build_backbone(args): + position_embedding = build_position_encoding(args) + train_backbone = args.lr_backbone > 0 + return_interm_layers = args.masks or (args.num_feature_levels > 1) + backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation) + model = Joiner(backbone, position_embedding) + return model diff --git a/dimos/models/Detic/third_party/Deformable-DETR/models/deformable_detr.py b/dimos/models/Detic/third_party/Deformable-DETR/models/deformable_detr.py new file mode 100644 index 0000000000..661c6b3d98 --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/models/deformable_detr.py @@ -0,0 +1,552 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +Deformable DETR model and criterion classes. +""" + +import copy +import math + +import torch +from torch import nn +import torch.nn.functional as F +from util import box_ops +from util.misc import ( + NestedTensor, + accuracy, + get_world_size, + interpolate, + inverse_sigmoid, + is_dist_avail_and_initialized, + nested_tensor_from_tensor_list, +) + +from .backbone import build_backbone +from .deformable_transformer import build_deforamble_transformer +from .matcher import build_matcher +from .segmentation import ( + DETRsegm, + PostProcessPanoptic, + PostProcessSegm, + dice_loss, + sigmoid_focal_loss, +) +from typing import Sequence + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +class DeformableDETR(nn.Module): + """This is the Deformable DETR module that performs object detection""" + + def __init__( + self, + backbone, + transformer, + num_classes: int, + num_queries: int, + num_feature_levels: int, + aux_loss: bool=True, + with_box_refine: bool=False, + two_stage: bool=False, + ) -> None: + """Initializes the model. + Parameters: + backbone: torch module of the backbone to be used. See backbone.py + transformer: torch module of the transformer architecture. See transformer.py + num_classes: number of object classes + num_queries: number of object queries, ie detection slot. This is the maximal number of objects + DETR can detect in a single image. For COCO, we recommend 100 queries. + aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. + with_box_refine: iterative bounding box refinement + two_stage: two-stage Deformable DETR + """ + super().__init__() + self.num_queries = num_queries + self.transformer = transformer + hidden_dim = transformer.d_model + self.class_embed = nn.Linear(hidden_dim, num_classes) + self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) + self.num_feature_levels = num_feature_levels + if not two_stage: + self.query_embed = nn.Embedding(num_queries, hidden_dim * 2) + if num_feature_levels > 1: + num_backbone_outs = len(backbone.strides) + input_proj_list = [] + for _ in range(num_backbone_outs): + in_channels = backbone.num_channels[_] + input_proj_list.append( + nn.Sequential( + nn.Conv2d(in_channels, hidden_dim, kernel_size=1), + nn.GroupNorm(32, hidden_dim), + ) + ) + for _ in range(num_feature_levels - num_backbone_outs): + input_proj_list.append( + nn.Sequential( + nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1), + nn.GroupNorm(32, hidden_dim), + ) + ) + in_channels = hidden_dim + self.input_proj = nn.ModuleList(input_proj_list) + else: + self.input_proj = nn.ModuleList( + [ + nn.Sequential( + nn.Conv2d(backbone.num_channels[0], hidden_dim, kernel_size=1), + nn.GroupNorm(32, hidden_dim), + ) + ] + ) + self.backbone = backbone + self.aux_loss = aux_loss + self.with_box_refine = with_box_refine + self.two_stage = two_stage + + prior_prob = 0.01 + bias_value = -math.log((1 - prior_prob) / prior_prob) + self.class_embed.bias.data = torch.ones(num_classes) * bias_value + nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0) + nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0) + for proj in self.input_proj: + nn.init.xavier_uniform_(proj[0].weight, gain=1) + nn.init.constant_(proj[0].bias, 0) + + # if two-stage, the last class_embed and bbox_embed is for region proposal generation + num_pred = ( + (transformer.decoder.num_layers + 1) if two_stage else transformer.decoder.num_layers + ) + if with_box_refine: + self.class_embed = _get_clones(self.class_embed, num_pred) + self.bbox_embed = _get_clones(self.bbox_embed, num_pred) + nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:], -2.0) + # hack implementation for iterative bounding box refinement + self.transformer.decoder.bbox_embed = self.bbox_embed + else: + nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0) + self.class_embed = nn.ModuleList([self.class_embed for _ in range(num_pred)]) + self.bbox_embed = nn.ModuleList([self.bbox_embed for _ in range(num_pred)]) + self.transformer.decoder.bbox_embed = None + if two_stage: + # hack implementation for two-stage + self.transformer.decoder.class_embed = self.class_embed + for box_embed in self.bbox_embed: + nn.init.constant_(box_embed.layers[-1].bias.data[2:], 0.0) + + def forward(self, samples: NestedTensor): + """The forward expects a NestedTensor, which consists of: + - samples.tensor: batched images, of shape [batch_size x 3 x H x W] + - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels + + It returns a dict with the following elements: + - "pred_logits": the classification logits (including no-object) for all queries. + Shape= [batch_size x num_queries x (num_classes + 1)] + - "pred_boxes": The normalized boxes coordinates for all queries, represented as + (center_x, center_y, height, width). These values are normalized in [0, 1], + relative to the size of each individual image (disregarding possible padding). + See PostProcess for information on how to retrieve the unnormalized bounding box. + - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of + dictionnaries containing the two above keys for each decoder layer. + """ + if not isinstance(samples, NestedTensor): + samples = nested_tensor_from_tensor_list(samples) + features, pos = self.backbone(samples) + + srcs = [] + masks = [] + for l, feat in enumerate(features): + src, mask = feat.decompose() + srcs.append(self.input_proj[l](src)) + masks.append(mask) + assert mask is not None + if self.num_feature_levels > len(srcs): + _len_srcs = len(srcs) + for l in range(_len_srcs, self.num_feature_levels): + if l == _len_srcs: + src = self.input_proj[l](features[-1].tensors) + else: + src = self.input_proj[l](srcs[-1]) + m = samples.mask + mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0] + pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype) + srcs.append(src) + masks.append(mask) + pos.append(pos_l) + + query_embeds = None + if not self.two_stage: + query_embeds = self.query_embed.weight + hs, init_reference, inter_references, enc_outputs_class, enc_outputs_coord_unact = ( + self.transformer(srcs, masks, pos, query_embeds) + ) + + outputs_classes = [] + outputs_coords = [] + for lvl in range(hs.shape[0]): + if lvl == 0: + reference = init_reference + else: + reference = inter_references[lvl - 1] + reference = inverse_sigmoid(reference) + outputs_class = self.class_embed[lvl](hs[lvl]) + tmp = self.bbox_embed[lvl](hs[lvl]) + if reference.shape[-1] == 4: + tmp += reference + else: + assert reference.shape[-1] == 2 + tmp[..., :2] += reference + outputs_coord = tmp.sigmoid() + outputs_classes.append(outputs_class) + outputs_coords.append(outputs_coord) + outputs_class = torch.stack(outputs_classes) + outputs_coord = torch.stack(outputs_coords) + + out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]} + if self.aux_loss: + out["aux_outputs"] = self._set_aux_loss(outputs_class, outputs_coord) + + if self.two_stage: + enc_outputs_coord = enc_outputs_coord_unact.sigmoid() + out["enc_outputs"] = {"pred_logits": enc_outputs_class, "pred_boxes": enc_outputs_coord} + return out + + @torch.jit.unused + def _set_aux_loss(self, outputs_class, outputs_coord): + # this is a workaround to make torchscript happy, as torchscript + # doesn't support dictionary with non-homogeneous values, such + # as a dict having both a Tensor and a list. + return [ + {"pred_logits": a, "pred_boxes": b} + for a, b in zip(outputs_class[:-1], outputs_coord[:-1], strict=False) + ] + + +class SetCriterion(nn.Module): + """This class computes the loss for DETR. + The process happens in two steps: + 1) we compute hungarian assignment between ground truth boxes and the outputs of the model + 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) + """ + + def __init__(self, num_classes: int, matcher, weight_dict, losses, focal_alpha: float=0.25) -> None: + """Create the criterion. + Parameters: + num_classes: number of object categories, omitting the special no-object category + matcher: module able to compute a matching between targets and proposals + weight_dict: dict containing as key the names of the losses and as values their relative weight. + losses: list of all the losses to be applied. See get_loss for list of available losses. + focal_alpha: alpha in Focal Loss + """ + super().__init__() + self.num_classes = num_classes + self.matcher = matcher + self.weight_dict = weight_dict + self.losses = losses + self.focal_alpha = focal_alpha + + def loss_labels(self, outputs, targets, indices, num_boxes: int, log: bool=True): + """Classification loss (NLL) + targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] + """ + assert "pred_logits" in outputs + src_logits = outputs["pred_logits"] + + idx = self._get_src_permutation_idx(indices) + target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices, strict=False)]) + target_classes = torch.full( + src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device + ) + target_classes[idx] = target_classes_o + + target_classes_onehot = torch.zeros( + [src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1], + dtype=src_logits.dtype, + layout=src_logits.layout, + device=src_logits.device, + ) + target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1) + + target_classes_onehot = target_classes_onehot[:, :, :-1] + loss_ce = ( + sigmoid_focal_loss( + src_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2 + ) + * src_logits.shape[1] + ) + losses = {"loss_ce": loss_ce} + + if log: + # TODO this should probably be a separate loss, not hacked in this one here + losses["class_error"] = 100 - accuracy(src_logits[idx], target_classes_o)[0] + return losses + + @torch.no_grad() + def loss_cardinality(self, outputs, targets, indices, num_boxes: int): + """Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes + This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients + """ + pred_logits = outputs["pred_logits"] + device = pred_logits.device + tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device) + # Count the number of predictions that are NOT "no-object" (which is the last class) + card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1) + card_err = F.l1_loss(card_pred.float(), tgt_lengths.float()) + losses = {"cardinality_error": card_err} + return losses + + def loss_boxes(self, outputs, targets, indices, num_boxes: int): + """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss + targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4] + The target boxes are expected in format (center_x, center_y, h, w), normalized by the image size. + """ + assert "pred_boxes" in outputs + idx = self._get_src_permutation_idx(indices) + src_boxes = outputs["pred_boxes"][idx] + target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices, strict=False)], dim=0) + + loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none") + + losses = {} + losses["loss_bbox"] = loss_bbox.sum() / num_boxes + + loss_giou = 1 - torch.diag( + box_ops.generalized_box_iou( + box_ops.box_cxcywh_to_xyxy(src_boxes), box_ops.box_cxcywh_to_xyxy(target_boxes) + ) + ) + losses["loss_giou"] = loss_giou.sum() / num_boxes + return losses + + def loss_masks(self, outputs, targets, indices, num_boxes: int): + """Compute the losses related to the masks: the focal loss and the dice loss. + targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w] + """ + assert "pred_masks" in outputs + + src_idx = self._get_src_permutation_idx(indices) + tgt_idx = self._get_tgt_permutation_idx(indices) + + src_masks = outputs["pred_masks"] + + # TODO use valid to mask invalid areas due to padding in loss + target_masks, valid = nested_tensor_from_tensor_list( + [t["masks"] for t in targets] + ).decompose() + target_masks = target_masks.to(src_masks) + + src_masks = src_masks[src_idx] + # upsample predictions to the target size + src_masks = interpolate( + src_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False + ) + src_masks = src_masks[:, 0].flatten(1) + + target_masks = target_masks[tgt_idx].flatten(1) + + losses = { + "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes), + "loss_dice": dice_loss(src_masks, target_masks, num_boxes), + } + return losses + + def _get_src_permutation_idx(self, indices): + # permute predictions following indices + batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) + src_idx = torch.cat([src for (src, _) in indices]) + return batch_idx, src_idx + + def _get_tgt_permutation_idx(self, indices): + # permute targets following indices + batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) + tgt_idx = torch.cat([tgt for (_, tgt) in indices]) + return batch_idx, tgt_idx + + def get_loss(self, loss, outputs, targets, indices, num_boxes: int, **kwargs): + loss_map = { + "labels": self.loss_labels, + "cardinality": self.loss_cardinality, + "boxes": self.loss_boxes, + "masks": self.loss_masks, + } + assert loss in loss_map, f"do you really want to compute {loss} loss?" + return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs) + + def forward(self, outputs, targets): + """This performs the loss computation. + Parameters: + outputs: dict of tensors, see the output specification of the model for the format + targets: list of dicts, such that len(targets) == batch_size. + The expected keys in each dict depends on the losses applied, see each loss' doc + """ + outputs_without_aux = { + k: v for k, v in outputs.items() if k != "aux_outputs" and k != "enc_outputs" + } + + # Retrieve the matching between the outputs of the last layer and the targets + indices = self.matcher(outputs_without_aux, targets) + + # Compute the average number of target boxes accross all nodes, for normalization purposes + num_boxes = sum(len(t["labels"]) for t in targets) + num_boxes = torch.as_tensor( + [num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device + ) + if is_dist_avail_and_initialized(): + torch.distributed.all_reduce(num_boxes) + num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item() + + # Compute all the requested losses + losses = {} + for loss in self.losses: + kwargs = {} + losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes, **kwargs)) + + # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. + if "aux_outputs" in outputs: + for i, aux_outputs in enumerate(outputs["aux_outputs"]): + indices = self.matcher(aux_outputs, targets) + for loss in self.losses: + if loss == "masks": + # Intermediate masks losses are too costly to compute, we ignore them. + continue + kwargs = {} + if loss == "labels": + # Logging is enabled only for the last layer + kwargs["log"] = False + l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs) + l_dict = {k + f"_{i}": v for k, v in l_dict.items()} + losses.update(l_dict) + + if "enc_outputs" in outputs: + enc_outputs = outputs["enc_outputs"] + bin_targets = copy.deepcopy(targets) + for bt in bin_targets: + bt["labels"] = torch.zeros_like(bt["labels"]) + indices = self.matcher(enc_outputs, bin_targets) + for loss in self.losses: + if loss == "masks": + # Intermediate masks losses are too costly to compute, we ignore them. + continue + kwargs = {} + if loss == "labels": + # Logging is enabled only for the last layer + kwargs["log"] = False + l_dict = self.get_loss(loss, enc_outputs, bin_targets, indices, num_boxes, **kwargs) + l_dict = {k + "_enc": v for k, v in l_dict.items()} + losses.update(l_dict) + + return losses + + +class PostProcess(nn.Module): + """This module converts the model's output into the format expected by the coco api""" + + @torch.no_grad() + def forward(self, outputs, target_sizes: Sequence[int]): + """Perform the computation + Parameters: + outputs: raw outputs of the model + target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch + For evaluation, this must be the original image size (before any data augmentation) + For visualization, this should be the image size after data augment, but before padding + """ + out_logits, out_bbox = outputs["pred_logits"], outputs["pred_boxes"] + + assert len(out_logits) == len(target_sizes) + assert target_sizes.shape[1] == 2 + + prob = out_logits.sigmoid() + topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 100, dim=1) + scores = topk_values + topk_boxes = topk_indexes // out_logits.shape[2] + labels = topk_indexes % out_logits.shape[2] + boxes = box_ops.box_cxcywh_to_xyxy(out_bbox) + boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4)) + + # and from relative [0, 1] to absolute [0, height] coordinates + img_h, img_w = target_sizes.unbind(1) + scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) + boxes = boxes * scale_fct[:, None, :] + + results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes, strict=False)] + + return results + + +class MLP(nn.Module): + """Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers: int) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim, *h], [*h, output_dim], strict=False) + ) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +def build(args): + num_classes = 20 if args.dataset_file != "coco" else 91 + if args.dataset_file == "coco_panoptic": + num_classes = 250 + device = torch.device(args.device) + + backbone = build_backbone(args) + + transformer = build_deforamble_transformer(args) + model = DeformableDETR( + backbone, + transformer, + num_classes=num_classes, + num_queries=args.num_queries, + num_feature_levels=args.num_feature_levels, + aux_loss=args.aux_loss, + with_box_refine=args.with_box_refine, + two_stage=args.two_stage, + ) + if args.masks: + model = DETRsegm(model, freeze_detr=(args.frozen_weights is not None)) + matcher = build_matcher(args) + weight_dict = {"loss_ce": args.cls_loss_coef, "loss_bbox": args.bbox_loss_coef} + weight_dict["loss_giou"] = args.giou_loss_coef + if args.masks: + weight_dict["loss_mask"] = args.mask_loss_coef + weight_dict["loss_dice"] = args.dice_loss_coef + # TODO this is a hack + if args.aux_loss: + aux_weight_dict = {} + for i in range(args.dec_layers - 1): + aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()}) + aux_weight_dict.update({k + "_enc": v for k, v in weight_dict.items()}) + weight_dict.update(aux_weight_dict) + + losses = ["labels", "boxes", "cardinality"] + if args.masks: + losses += ["masks"] + # num_classes, matcher, weight_dict, losses, focal_alpha=0.25 + criterion = SetCriterion( + num_classes, matcher, weight_dict, losses, focal_alpha=args.focal_alpha + ) + criterion.to(device) + postprocessors = {"bbox": PostProcess()} + if args.masks: + postprocessors["segm"] = PostProcessSegm() + if args.dataset_file == "coco_panoptic": + is_thing_map = {i: i <= 90 for i in range(201)} + postprocessors["panoptic"] = PostProcessPanoptic(is_thing_map, threshold=0.85) + + return model, criterion, postprocessors diff --git a/dimos/models/Detic/third_party/Deformable-DETR/models/deformable_transformer.py b/dimos/models/Detic/third_party/Deformable-DETR/models/deformable_transformer.py new file mode 100644 index 0000000000..f3cde19e1b --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/models/deformable_transformer.py @@ -0,0 +1,507 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +import copy +import math + +import torch +from torch import nn +import torch.nn.functional as F +from torch.nn.init import constant_, normal_, xavier_uniform_ +from util.misc import inverse_sigmoid + +from models.ops.modules import MSDeformAttn + + +class DeformableTransformer(nn.Module): + def __init__( + self, + d_model: int=256, + nhead: int=8, + num_encoder_layers: int=6, + num_decoder_layers: int=6, + dim_feedforward: int=1024, + dropout: float=0.1, + activation: str="relu", + return_intermediate_dec: bool=False, + num_feature_levels: int=4, + dec_n_points: int=4, + enc_n_points: int=4, + two_stage: bool=False, + two_stage_num_proposals: int=300, + ) -> None: + super().__init__() + + self.d_model = d_model + self.nhead = nhead + self.two_stage = two_stage + self.two_stage_num_proposals = two_stage_num_proposals + + encoder_layer = DeformableTransformerEncoderLayer( + d_model, dim_feedforward, dropout, activation, num_feature_levels, nhead, enc_n_points + ) + self.encoder = DeformableTransformerEncoder(encoder_layer, num_encoder_layers) + + decoder_layer = DeformableTransformerDecoderLayer( + d_model, dim_feedforward, dropout, activation, num_feature_levels, nhead, dec_n_points + ) + self.decoder = DeformableTransformerDecoder( + decoder_layer, num_decoder_layers, return_intermediate_dec + ) + + self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model)) + + if two_stage: + self.enc_output = nn.Linear(d_model, d_model) + self.enc_output_norm = nn.LayerNorm(d_model) + self.pos_trans = nn.Linear(d_model * 2, d_model * 2) + self.pos_trans_norm = nn.LayerNorm(d_model * 2) + else: + self.reference_points = nn.Linear(d_model, 2) + + self._reset_parameters() + + def _reset_parameters(self) -> None: + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + for m in self.modules(): + if isinstance(m, MSDeformAttn): + m._reset_parameters() + if not self.two_stage: + xavier_uniform_(self.reference_points.weight.data, gain=1.0) + constant_(self.reference_points.bias.data, 0.0) + normal_(self.level_embed) + + def get_proposal_pos_embed(self, proposals): + num_pos_feats = 128 + temperature = 10000 + scale = 2 * math.pi + + dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device) + dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats) + # N, L, 4 + proposals = proposals.sigmoid() * scale + # N, L, 4, 128 + pos = proposals[:, :, :, None] / dim_t + # N, L, 4, 64, 2 + pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2) + return pos + + def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes): + N_, S_, C_ = memory.shape + proposals = [] + _cur = 0 + for lvl, (H_, W_) in enumerate(spatial_shapes): + mask_flatten_ = memory_padding_mask[:, _cur : (_cur + H_ * W_)].view(N_, H_, W_, 1) + valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1) + valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1) + + grid_y, grid_x = torch.meshgrid( + torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device), + torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device), + ) + grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) + + scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2) + grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale + wh = torch.ones_like(grid) * 0.05 * (2.0**lvl) + proposal = torch.cat((grid, wh), -1).view(N_, -1, 4) + proposals.append(proposal) + _cur += H_ * W_ + output_proposals = torch.cat(proposals, 1) + output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all( + -1, keepdim=True + ) + output_proposals = torch.log(output_proposals / (1 - output_proposals)) + output_proposals = output_proposals.masked_fill( + memory_padding_mask.unsqueeze(-1), float("inf") + ) + output_proposals = output_proposals.masked_fill(~output_proposals_valid, float("inf")) + + output_memory = memory + output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0)) + output_memory = output_memory.masked_fill(~output_proposals_valid, float(0)) + output_memory = self.enc_output_norm(self.enc_output(output_memory)) + return output_memory, output_proposals + + def get_valid_ratio(self, mask): + _, H, W = mask.shape + valid_H = torch.sum(~mask[:, :, 0], 1) + valid_W = torch.sum(~mask[:, 0, :], 1) + valid_ratio_h = valid_H.float() / H + valid_ratio_w = valid_W.float() / W + valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) + return valid_ratio + + def forward(self, srcs, masks, pos_embeds, query_embed=None): + assert self.two_stage or query_embed is not None + + # prepare input for encoder + src_flatten = [] + mask_flatten = [] + lvl_pos_embed_flatten = [] + spatial_shapes = [] + for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds, strict=False)): + bs, c, h, w = src.shape + spatial_shape = (h, w) + spatial_shapes.append(spatial_shape) + src = src.flatten(2).transpose(1, 2) + mask = mask.flatten(1) + pos_embed = pos_embed.flatten(2).transpose(1, 2) + lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1) + lvl_pos_embed_flatten.append(lvl_pos_embed) + src_flatten.append(src) + mask_flatten.append(mask) + src_flatten = torch.cat(src_flatten, 1) + mask_flatten = torch.cat(mask_flatten, 1) + lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) + spatial_shapes = torch.as_tensor( + spatial_shapes, dtype=torch.long, device=src_flatten.device + ) + level_start_index = torch.cat( + (spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]) + ) + valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) + + # encoder + memory = self.encoder( + src_flatten, + spatial_shapes, + level_start_index, + valid_ratios, + lvl_pos_embed_flatten, + mask_flatten, + ) + + # prepare input for decoder + bs, _, c = memory.shape + if self.two_stage: + output_memory, output_proposals = self.gen_encoder_output_proposals( + memory, mask_flatten, spatial_shapes + ) + + # hack implementation for two-stage Deformable DETR + enc_outputs_class = self.decoder.class_embed[self.decoder.num_layers](output_memory) + enc_outputs_coord_unact = ( + self.decoder.bbox_embed[self.decoder.num_layers](output_memory) + output_proposals + ) + + topk = self.two_stage_num_proposals + topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1] + topk_coords_unact = torch.gather( + enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4) + ) + topk_coords_unact = topk_coords_unact.detach() + reference_points = topk_coords_unact.sigmoid() + init_reference_out = reference_points + pos_trans_out = self.pos_trans_norm( + self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact)) + ) + query_embed, tgt = torch.split(pos_trans_out, c, dim=2) + else: + query_embed, tgt = torch.split(query_embed, c, dim=1) + query_embed = query_embed.unsqueeze(0).expand(bs, -1, -1) + tgt = tgt.unsqueeze(0).expand(bs, -1, -1) + reference_points = self.reference_points(query_embed).sigmoid() + init_reference_out = reference_points + + # decoder + hs, inter_references = self.decoder( + tgt, + reference_points, + memory, + spatial_shapes, + level_start_index, + valid_ratios, + query_embed, + mask_flatten, + ) + + inter_references_out = inter_references + if self.two_stage: + return ( + hs, + init_reference_out, + inter_references_out, + enc_outputs_class, + enc_outputs_coord_unact, + ) + return hs, init_reference_out, inter_references_out, None, None + + +class DeformableTransformerEncoderLayer(nn.Module): + def __init__( + self, + d_model: int=256, + d_ffn: int=1024, + dropout: float=0.1, + activation: str="relu", + n_levels: int=4, + n_heads: int=8, + n_points: int=4, + ) -> None: + super().__init__() + + # self attention + self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points) + self.dropout1 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm(d_model) + + # ffn + self.linear1 = nn.Linear(d_model, d_ffn) + self.activation = _get_activation_fn(activation) + self.dropout2 = nn.Dropout(dropout) + self.linear2 = nn.Linear(d_ffn, d_model) + self.dropout3 = nn.Dropout(dropout) + self.norm2 = nn.LayerNorm(d_model) + + @staticmethod + def with_pos_embed(tensor, pos): + return tensor if pos is None else tensor + pos + + def forward_ffn(self, src): + src2 = self.linear2(self.dropout2(self.activation(self.linear1(src)))) + src = src + self.dropout3(src2) + src = self.norm2(src) + return src + + def forward( + self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None + ): + # self attention + src2 = self.self_attn( + self.with_pos_embed(src, pos), + reference_points, + src, + spatial_shapes, + level_start_index, + padding_mask, + ) + src = src + self.dropout1(src2) + src = self.norm1(src) + + # ffn + src = self.forward_ffn(src) + + return src + + +class DeformableTransformerEncoder(nn.Module): + def __init__(self, encoder_layer, num_layers: int) -> None: + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + + @staticmethod + def get_reference_points(spatial_shapes, valid_ratios, device): + reference_points_list = [] + for lvl, (H_, W_) in enumerate(spatial_shapes): + ref_y, ref_x = torch.meshgrid( + torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), + torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device), + ) + ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_) + ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_) + ref = torch.stack((ref_x, ref_y), -1) + reference_points_list.append(ref) + reference_points = torch.cat(reference_points_list, 1) + reference_points = reference_points[:, :, None] * valid_ratios[:, None] + return reference_points + + def forward( + self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None + ): + output = src + reference_points = self.get_reference_points( + spatial_shapes, valid_ratios, device=src.device + ) + for _, layer in enumerate(self.layers): + output = layer( + output, pos, reference_points, spatial_shapes, level_start_index, padding_mask + ) + + return output + + +class DeformableTransformerDecoderLayer(nn.Module): + def __init__( + self, + d_model: int=256, + d_ffn: int=1024, + dropout: float=0.1, + activation: str="relu", + n_levels: int=4, + n_heads: int=8, + n_points: int=4, + ) -> None: + super().__init__() + + # cross attention + self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points) + self.dropout1 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm(d_model) + + # self attention + self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout) + self.dropout2 = nn.Dropout(dropout) + self.norm2 = nn.LayerNorm(d_model) + + # ffn + self.linear1 = nn.Linear(d_model, d_ffn) + self.activation = _get_activation_fn(activation) + self.dropout3 = nn.Dropout(dropout) + self.linear2 = nn.Linear(d_ffn, d_model) + self.dropout4 = nn.Dropout(dropout) + self.norm3 = nn.LayerNorm(d_model) + + @staticmethod + def with_pos_embed(tensor, pos): + return tensor if pos is None else tensor + pos + + def forward_ffn(self, tgt): + tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout4(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward( + self, + tgt, + query_pos, + reference_points, + src, + src_spatial_shapes, + level_start_index, + src_padding_mask=None, + ): + # self attention + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1))[ + 0 + ].transpose(0, 1) + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + + # cross attention + tgt2 = self.cross_attn( + self.with_pos_embed(tgt, query_pos), + reference_points, + src, + src_spatial_shapes, + level_start_index, + src_padding_mask, + ) + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + + # ffn + tgt = self.forward_ffn(tgt) + + return tgt + + +class DeformableTransformerDecoder(nn.Module): + def __init__(self, decoder_layer, num_layers: int, return_intermediate: bool=False) -> None: + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.return_intermediate = return_intermediate + # hack implementation for iterative bounding box refinement and two-stage Deformable DETR + self.bbox_embed = None + self.class_embed = None + + def forward( + self, + tgt, + reference_points, + src, + src_spatial_shapes, + src_level_start_index, + src_valid_ratios, + query_pos=None, + src_padding_mask=None, + ): + output = tgt + + intermediate = [] + intermediate_reference_points = [] + for lid, layer in enumerate(self.layers): + if reference_points.shape[-1] == 4: + reference_points_input = ( + reference_points[:, :, None] + * torch.cat([src_valid_ratios, src_valid_ratios], -1)[:, None] + ) + else: + assert reference_points.shape[-1] == 2 + reference_points_input = reference_points[:, :, None] * src_valid_ratios[:, None] + output = layer( + output, + query_pos, + reference_points_input, + src, + src_spatial_shapes, + src_level_start_index, + src_padding_mask, + ) + + # hack implementation for iterative bounding box refinement + if self.bbox_embed is not None: + tmp = self.bbox_embed[lid](output) + if reference_points.shape[-1] == 4: + new_reference_points = tmp + inverse_sigmoid(reference_points) + new_reference_points = new_reference_points.sigmoid() + else: + assert reference_points.shape[-1] == 2 + new_reference_points = tmp + new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(reference_points) + new_reference_points = new_reference_points.sigmoid() + reference_points = new_reference_points.detach() + + if self.return_intermediate: + intermediate.append(output) + intermediate_reference_points.append(reference_points) + + if self.return_intermediate: + return torch.stack(intermediate), torch.stack(intermediate_reference_points) + + return output, reference_points + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(f"activation should be relu/gelu, not {activation}.") + + +def build_deforamble_transformer(args): + return DeformableTransformer( + d_model=args.hidden_dim, + nhead=args.nheads, + num_encoder_layers=args.enc_layers, + num_decoder_layers=args.dec_layers, + dim_feedforward=args.dim_feedforward, + dropout=args.dropout, + activation="relu", + return_intermediate_dec=True, + num_feature_levels=args.num_feature_levels, + dec_n_points=args.dec_n_points, + enc_n_points=args.enc_n_points, + two_stage=args.two_stage, + two_stage_num_proposals=args.num_queries, + ) diff --git a/dimos/models/Detic/third_party/Deformable-DETR/models/matcher.py b/dimos/models/Detic/third_party/Deformable-DETR/models/matcher.py new file mode 100644 index 0000000000..7cbcf4a82e --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/models/matcher.py @@ -0,0 +1,107 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +Modules to compute the matching cost and solve the corresponding LSAP. +""" + +from scipy.optimize import linear_sum_assignment +import torch +from torch import nn +from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou + + +class HungarianMatcher(nn.Module): + """This class computes an assignment between the targets and the predictions of the network + + For efficiency reasons, the targets don't include the no_object. Because of this, in general, + there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, + while the others are un-matched (and thus treated as non-objects). + """ + + def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1) -> None: + """Creates the matcher + + Params: + cost_class: This is the relative weight of the classification error in the matching cost + cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost + cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost + """ + super().__init__() + self.cost_class = cost_class + self.cost_bbox = cost_bbox + self.cost_giou = cost_giou + assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0" + + def forward(self, outputs, targets): + """Performs the matching + + Params: + outputs: This is a dict that contains at least these entries: + "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits + "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates + + targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: + "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth + objects in the target) containing the class labels + "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates + + Returns: + A list of size batch_size, containing tuples of (index_i, index_j) where: + - index_i is the indices of the selected predictions (in order) + - index_j is the indices of the corresponding selected targets (in order) + For each batch element, it holds: + len(index_i) = len(index_j) = min(num_queries, num_target_boxes) + """ + with torch.no_grad(): + bs, num_queries = outputs["pred_logits"].shape[:2] + + # We flatten to compute the cost matrices in a batch + out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid() + out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] + + # Also concat the target labels and boxes + tgt_ids = torch.cat([v["labels"] for v in targets]) + tgt_bbox = torch.cat([v["boxes"] for v in targets]) + + # Compute the classification cost. + alpha = 0.25 + gamma = 2.0 + neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log()) + pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log()) + cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids] + + # Compute the L1 cost between boxes + cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) + + # Compute the giou cost betwen boxes + cost_giou = -generalized_box_iou( + box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox) + ) + + # Final cost matrix + C = ( + self.cost_bbox * cost_bbox + + self.cost_class * cost_class + + self.cost_giou * cost_giou + ) + C = C.view(bs, num_queries, -1).cpu() + + sizes = [len(v["boxes"]) for v in targets] + indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))] + return [ + (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) + for i, j in indices + ] + + +def build_matcher(args): + return HungarianMatcher( + cost_class=args.set_cost_class, cost_bbox=args.set_cost_bbox, cost_giou=args.set_cost_giou + ) diff --git a/dimos/models/Detic/third_party/Deformable-DETR/models/ops/functions/__init__.py b/dimos/models/Detic/third_party/Deformable-DETR/models/ops/functions/__init__.py new file mode 100644 index 0000000000..c528f3c6cf --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/models/ops/functions/__init__.py @@ -0,0 +1,9 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from .ms_deform_attn_func import MSDeformAttnFunction diff --git a/dimos/models/Detic/third_party/Deformable-DETR/models/ops/functions/ms_deform_attn_func.py b/dimos/models/Detic/third_party/Deformable-DETR/models/ops/functions/ms_deform_attn_func.py new file mode 100644 index 0000000000..965811ed7f --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/models/ops/functions/ms_deform_attn_func.py @@ -0,0 +1,94 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + + +import MultiScaleDeformableAttention as MSDA +import torch +from torch.autograd import Function +from torch.autograd.function import once_differentiable +import torch.nn.functional as F + + +class MSDeformAttnFunction(Function): + @staticmethod + def forward( + ctx, + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + im2col_step, + ): + ctx.im2col_step = im2col_step + output = MSDA.ms_deform_attn_forward( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + ctx.im2col_step, + ) + ctx.save_for_backward( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + ) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + ( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + ) = ctx.saved_tensors + grad_value, grad_sampling_loc, grad_attn_weight = MSDA.ms_deform_attn_backward( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + grad_output, + ctx.im2col_step, + ) + + return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None + + +def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): + # for debug and test only, + # need to use cuda version instead + N_, S_, M_, D_ = value.shape + _, Lq_, M_, L_, P_, _ = sampling_locations.shape + value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) + sampling_grids = 2 * sampling_locations - 1 + sampling_value_list = [] + for lid_, (H_, W_) in enumerate(value_spatial_shapes): + # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ + value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_ * M_, D_, H_, W_) + # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 + sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) + # N_*M_, D_, Lq_, P_ + sampling_value_l_ = F.grid_sample( + value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False + ) + sampling_value_list.append(sampling_value_l_) + # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) + attention_weights = attention_weights.transpose(1, 2).reshape(N_ * M_, 1, Lq_, L_ * P_) + output = ( + (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights) + .sum(-1) + .view(N_, M_ * D_, Lq_) + ) + return output.transpose(1, 2).contiguous() diff --git a/dimos/models/Detic/third_party/Deformable-DETR/models/ops/make.sh b/dimos/models/Detic/third_party/Deformable-DETR/models/ops/make.sh new file mode 100755 index 0000000000..106b685722 --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/models/ops/make.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +python setup.py build install diff --git a/dimos/models/Detic/third_party/Deformable-DETR/models/ops/modules/__init__.py b/dimos/models/Detic/third_party/Deformable-DETR/models/ops/modules/__init__.py new file mode 100644 index 0000000000..f82cb1ad9d --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/models/ops/modules/__init__.py @@ -0,0 +1,9 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from .ms_deform_attn import MSDeformAttn diff --git a/dimos/models/Detic/third_party/Deformable-DETR/models/ops/modules/ms_deform_attn.py b/dimos/models/Detic/third_party/Deformable-DETR/models/ops/modules/ms_deform_attn.py new file mode 100644 index 0000000000..1d70af7cc4 --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/models/ops/modules/ms_deform_attn.py @@ -0,0 +1,147 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + + +import math +import warnings + +import torch +from torch import nn +import torch.nn.functional as F +from torch.nn.init import constant_, xavier_uniform_ + +from ..functions import MSDeformAttnFunction + + +def _is_power_of_2(n): + if (not isinstance(n, int)) or (n < 0): + raise ValueError(f"invalid input for _is_power_of_2: {n} (type: {type(n)})") + return (n & (n - 1) == 0) and n != 0 + + +class MSDeformAttn(nn.Module): + def __init__(self, d_model: int=256, n_levels: int=4, n_heads: int=8, n_points: int=4) -> None: + """ + Multi-Scale Deformable Attention Module + :param d_model hidden dimension + :param n_levels number of feature levels + :param n_heads number of attention heads + :param n_points number of sampling points per attention head per feature level + """ + super().__init__() + if d_model % n_heads != 0: + raise ValueError( + f"d_model must be divisible by n_heads, but got {d_model} and {n_heads}" + ) + _d_per_head = d_model // n_heads + # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation + if not _is_power_of_2(_d_per_head): + warnings.warn( + "You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " + "which is more efficient in our CUDA implementation.", stacklevel=2 + ) + + self.im2col_step = 64 + + self.d_model = d_model + self.n_levels = n_levels + self.n_heads = n_heads + self.n_points = n_points + + self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) + self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) + self.value_proj = nn.Linear(d_model, d_model) + self.output_proj = nn.Linear(d_model, d_model) + + self._reset_parameters() + + def _reset_parameters(self) -> None: + constant_(self.sampling_offsets.weight.data, 0.0) + thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = ( + (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) + .view(self.n_heads, 1, 1, 2) + .repeat(1, self.n_levels, self.n_points, 1) + ) + for i in range(self.n_points): + grid_init[:, :, i, :] *= i + 1 + with torch.no_grad(): + self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) + constant_(self.attention_weights.weight.data, 0.0) + constant_(self.attention_weights.bias.data, 0.0) + xavier_uniform_(self.value_proj.weight.data) + constant_(self.value_proj.bias.data, 0.0) + xavier_uniform_(self.output_proj.weight.data) + constant_(self.output_proj.bias.data, 0.0) + + def forward( + self, + query, + reference_points, + input_flatten, + input_spatial_shapes, + input_level_start_index, + input_padding_mask=None, + ): + r""" + :param query (N, Length_{query}, C) + :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area + or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes + :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) + :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] + :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] + :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements + + :return output (N, Length_{query}, C) + """ + N, Len_q, _ = query.shape + N, Len_in, _ = input_flatten.shape + assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in + + value = self.value_proj(input_flatten) + if input_padding_mask is not None: + value = value.masked_fill(input_padding_mask[..., None], float(0)) + value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) + sampling_offsets = self.sampling_offsets(query).view( + N, Len_q, self.n_heads, self.n_levels, self.n_points, 2 + ) + attention_weights = self.attention_weights(query).view( + N, Len_q, self.n_heads, self.n_levels * self.n_points + ) + attention_weights = F.softmax(attention_weights, -1).view( + N, Len_q, self.n_heads, self.n_levels, self.n_points + ) + # N, Len_q, n_heads, n_levels, n_points, 2 + if reference_points.shape[-1] == 2: + offset_normalizer = torch.stack( + [input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1 + ) + sampling_locations = ( + reference_points[:, :, None, :, None, :] + + sampling_offsets / offset_normalizer[None, None, None, :, None, :] + ) + elif reference_points.shape[-1] == 4: + sampling_locations = ( + reference_points[:, :, None, :, None, :2] + + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 + ) + else: + raise ValueError( + f"Last dim of reference_points must be 2 or 4, but get {reference_points.shape[-1]} instead." + ) + output = MSDeformAttnFunction.apply( + value, + input_spatial_shapes, + input_level_start_index, + sampling_locations, + attention_weights, + self.im2col_step, + ) + output = self.output_proj(output) + return output diff --git a/dimos/models/Detic/third_party/Deformable-DETR/models/ops/setup.py b/dimos/models/Detic/third_party/Deformable-DETR/models/ops/setup.py new file mode 100644 index 0000000000..7a5560a83f --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/models/ops/setup.py @@ -0,0 +1,73 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +import glob +import os + +from setuptools import find_packages, setup +import torch +from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension + +requirements = ["torch", "torchvision"] + + +def get_extensions(): + this_dir = os.path.dirname(os.path.abspath(__file__)) + extensions_dir = os.path.join(this_dir, "src") + + main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) + source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) + source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) + + sources = main_file + source_cpu + extension = CppExtension + extra_compile_args = {"cxx": []} + define_macros = [] + + if torch.cuda.is_available() and CUDA_HOME is not None: + extension = CUDAExtension + sources += source_cuda + define_macros += [("WITH_CUDA", None)] + extra_compile_args["nvcc"] = [ + "-DCUDA_HAS_FP16=1", + "-D__CUDA_NO_HALF_OPERATORS__", + "-D__CUDA_NO_HALF_CONVERSIONS__", + "-D__CUDA_NO_HALF2_OPERATORS__", + ] + else: + raise NotImplementedError("Cuda is not availabel") + + sources = [os.path.join(extensions_dir, s) for s in sources] + include_dirs = [extensions_dir] + ext_modules = [ + extension( + "MultiScaleDeformableAttention", + sources, + include_dirs=include_dirs, + define_macros=define_macros, + extra_compile_args=extra_compile_args, + ) + ] + return ext_modules + + +setup( + name="MultiScaleDeformableAttention", + version="1.0", + author="Weijie Su", + url="https://github.com/fundamentalvision/Deformable-DETR", + description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention", + packages=find_packages( + exclude=( + "configs", + "tests", + ) + ), + ext_modules=get_extensions(), + cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, +) diff --git a/dimos/models/Detic/third_party/Deformable-DETR/models/ops/src/cpu/ms_deform_attn_cpu.cpp b/dimos/models/Detic/third_party/Deformable-DETR/models/ops/src/cpu/ms_deform_attn_cpu.cpp new file mode 100644 index 0000000000..e1bf854de1 --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/models/ops/src/cpu/ms_deform_attn_cpu.cpp @@ -0,0 +1,41 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#include + +#include +#include + + +at::Tensor +ms_deform_attn_cpu_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step) +{ + AT_ERROR("Not implement on cpu"); +} + +std::vector +ms_deform_attn_cpu_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step) +{ + AT_ERROR("Not implement on cpu"); +} + diff --git a/dimos/models/Detic/third_party/Deformable-DETR/models/ops/src/cpu/ms_deform_attn_cpu.h b/dimos/models/Detic/third_party/Deformable-DETR/models/ops/src/cpu/ms_deform_attn_cpu.h new file mode 100644 index 0000000000..81b7b58a3d --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/models/ops/src/cpu/ms_deform_attn_cpu.h @@ -0,0 +1,33 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#pragma once +#include + +at::Tensor +ms_deform_attn_cpu_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step); + +std::vector +ms_deform_attn_cpu_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step); + + diff --git a/dimos/models/Detic/third_party/Deformable-DETR/models/ops/src/cuda/ms_deform_attn_cuda.cu b/dimos/models/Detic/third_party/Deformable-DETR/models/ops/src/cuda/ms_deform_attn_cuda.cu new file mode 100644 index 0000000000..d6d583647c --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/models/ops/src/cuda/ms_deform_attn_cuda.cu @@ -0,0 +1,153 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#include +#include "cuda/ms_deform_im2col_cuda.cuh" + +#include +#include +#include +#include + + +at::Tensor ms_deform_attn_cuda_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step) +{ + AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); + AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); + AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); + AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); + AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); + + AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); + AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); + AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); + AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); + AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); + + const int batch = value.size(0); + const int spatial_size = value.size(1); + const int num_heads = value.size(2); + const int channels = value.size(3); + + const int num_levels = spatial_shapes.size(0); + + const int num_query = sampling_loc.size(1); + const int num_point = sampling_loc.size(4); + + const int im2col_step_ = std::min(batch, im2col_step); + + AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); + + auto output = at::zeros({batch, num_query, num_heads, channels}, value.options()); + + const int batch_n = im2col_step_; + auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); + auto per_value_size = spatial_size * num_heads * channels; + auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; + auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; + for (int n = 0; n < batch/im2col_step_; ++n) + { + auto columns = output_n.select(0, n); + AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] { + ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(), + value.data() + n * im2col_step_ * per_value_size, + spatial_shapes.data(), + level_start_index.data(), + sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, + attn_weight.data() + n * im2col_step_ * per_attn_weight_size, + batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, + columns.data()); + + })); + } + + output = output.view({batch, num_query, num_heads*channels}); + + return output; +} + + +std::vector ms_deform_attn_cuda_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step) +{ + + AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); + AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); + AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); + AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); + AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); + AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous"); + + AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); + AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); + AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); + AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); + AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); + AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor"); + + const int batch = value.size(0); + const int spatial_size = value.size(1); + const int num_heads = value.size(2); + const int channels = value.size(3); + + const int num_levels = spatial_shapes.size(0); + + const int num_query = sampling_loc.size(1); + const int num_point = sampling_loc.size(4); + + const int im2col_step_ = std::min(batch, im2col_step); + + AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); + + auto grad_value = at::zeros_like(value); + auto grad_sampling_loc = at::zeros_like(sampling_loc); + auto grad_attn_weight = at::zeros_like(attn_weight); + + const int batch_n = im2col_step_; + auto per_value_size = spatial_size * num_heads * channels; + auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; + auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; + auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); + + for (int n = 0; n < batch/im2col_step_; ++n) + { + auto grad_output_g = grad_output_n.select(0, n); + AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] { + ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(), + grad_output_g.data(), + value.data() + n * im2col_step_ * per_value_size, + spatial_shapes.data(), + level_start_index.data(), + sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, + attn_weight.data() + n * im2col_step_ * per_attn_weight_size, + batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, + grad_value.data() + n * im2col_step_ * per_value_size, + grad_sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, + grad_attn_weight.data() + n * im2col_step_ * per_attn_weight_size); + + })); + } + + return { + grad_value, grad_sampling_loc, grad_attn_weight + }; +} \ No newline at end of file diff --git a/dimos/models/Detic/third_party/Deformable-DETR/models/ops/src/cuda/ms_deform_attn_cuda.h b/dimos/models/Detic/third_party/Deformable-DETR/models/ops/src/cuda/ms_deform_attn_cuda.h new file mode 100644 index 0000000000..c7ae53f99c --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/models/ops/src/cuda/ms_deform_attn_cuda.h @@ -0,0 +1,30 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#pragma once +#include + +at::Tensor ms_deform_attn_cuda_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step); + +std::vector ms_deform_attn_cuda_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step); + diff --git a/dimos/models/Detic/third_party/Deformable-DETR/models/ops/src/cuda/ms_deform_im2col_cuda.cuh b/dimos/models/Detic/third_party/Deformable-DETR/models/ops/src/cuda/ms_deform_im2col_cuda.cuh new file mode 100644 index 0000000000..6bc2acb7ae --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/models/ops/src/cuda/ms_deform_im2col_cuda.cuh @@ -0,0 +1,1327 @@ +/*! +************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************** +* Modified from DCN (https://github.com/msracver/Deformable-ConvNets) +* Copyright (c) 2018 Microsoft +************************************************************************** +*/ + +#include +#include +#include + +#include +#include + +#include + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ + i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 1024; +inline int GET_BLOCKS(const int N, const int num_threads) +{ + return (N + num_threads - 1) / num_threads; +} + + +template +__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + } + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + + +template +__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c, + const scalar_t &top_grad, + const scalar_t &attn_weight, + scalar_t* &grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + const scalar_t top_grad_value = top_grad * attn_weight; + scalar_t grad_h_weight = 0, grad_w_weight = 0; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + grad_h_weight -= hw * v1; + grad_w_weight -= hh * v1; + atomicAdd(grad_value+ptr1, w1*top_grad_value); + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + grad_h_weight -= lw * v2; + grad_w_weight += hh * v2; + atomicAdd(grad_value+ptr2, w2*top_grad_value); + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + grad_h_weight += hw * v3; + grad_w_weight -= lh * v3; + atomicAdd(grad_value+ptr3, w3*top_grad_value); + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + grad_h_weight += lw * v4; + grad_w_weight += lh * v4; + atomicAdd(grad_value+ptr4, w4*top_grad_value); + } + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + *grad_attn_weight = top_grad * val; + *grad_sampling_loc = width * grad_w_weight * top_grad_value; + *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value; +} + + +template +__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c, + const scalar_t &top_grad, + const scalar_t &attn_weight, + scalar_t* &grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + const scalar_t top_grad_value = top_grad * attn_weight; + scalar_t grad_h_weight = 0, grad_w_weight = 0; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + grad_h_weight -= hw * v1; + grad_w_weight -= hh * v1; + atomicAdd(grad_value+ptr1, w1*top_grad_value); + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + grad_h_weight -= lw * v2; + grad_w_weight += hh * v2; + atomicAdd(grad_value+ptr2, w2*top_grad_value); + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + grad_h_weight += hw * v3; + grad_w_weight -= lh * v3; + atomicAdd(grad_value+ptr3, w3*top_grad_value); + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + grad_h_weight += lw * v4; + grad_w_weight += lh * v4; + atomicAdd(grad_value+ptr4, w4*top_grad_value); + } + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + atomicAdd(grad_attn_weight, top_grad * val); + atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value); + atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value); +} + + +template +__global__ void ms_deformable_im2col_gpu_kernel(const int n, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + scalar_t *data_col_ptr = data_col + index; + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + scalar_t col = 0; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride); + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight; + } + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + } + } + *data_col_ptr = col; + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; + __shared__ scalar_t cache_grad_attn_weight[blockSize]; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + if (tid == 0) + { + scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0]; + int sid=2; + for (unsigned int tid = 1; tid < blockSize; ++tid) + { + _grad_w += cache_grad_sampling_loc[sid]; + _grad_h += cache_grad_sampling_loc[sid + 1]; + _grad_a += cache_grad_attn_weight[tid]; + sid += 2; + } + + + *grad_sampling_loc = _grad_w; + *(grad_sampling_loc + 1) = _grad_h; + *grad_attn_weight = _grad_a; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; + __shared__ scalar_t cache_grad_attn_weight[blockSize]; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockSize/2; s>0; s>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; + } + __syncthreads(); + } + + if (tid == 0) + { + *grad_sampling_loc = cache_grad_sampling_loc[0]; + *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; + *grad_attn_weight = cache_grad_attn_weight[0]; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + if (tid == 0) + { + scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0]; + int sid=2; + for (unsigned int tid = 1; tid < blockDim.x; ++tid) + { + _grad_w += cache_grad_sampling_loc[sid]; + _grad_h += cache_grad_sampling_loc[sid + 1]; + _grad_a += cache_grad_attn_weight[tid]; + sid += 2; + } + + + *grad_sampling_loc = _grad_w; + *(grad_sampling_loc + 1) = _grad_h; + *grad_attn_weight = _grad_a; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; + if (tid + (s << 1) < spre) + { + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; + } + } + __syncthreads(); + } + + if (tid == 0) + { + *grad_sampling_loc = cache_grad_sampling_loc[0]; + *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; + *grad_attn_weight = cache_grad_attn_weight[0]; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; + if (tid + (s << 1) < spre) + { + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; + } + } + __syncthreads(); + } + + if (tid == 0) + { + atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]); + atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]); + atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]); + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear_gm( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + grad_sampling_loc, grad_attn_weight); + } + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +void ms_deformable_im2col_cuda(cudaStream_t stream, + const scalar_t* data_value, + const int64_t* data_spatial_shapes, + const int64_t* data_level_start_index, + const scalar_t* data_sampling_loc, + const scalar_t* data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t* data_col) +{ + const int num_kernels = batch_size * num_query * num_heads * channels; + const int num_actual_kernels = batch_size * num_query * num_heads * channels; + const int num_threads = CUDA_NUM_THREADS; + ms_deformable_im2col_gpu_kernel + <<>>( + num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, + batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); + } + +} + +template +void ms_deformable_col2im_cuda(cudaStream_t stream, + const scalar_t* grad_col, + const scalar_t* data_value, + const int64_t * data_spatial_shapes, + const int64_t * data_level_start_index, + const scalar_t * data_sampling_loc, + const scalar_t * data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t* grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) +{ + const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels; + const int num_kernels = batch_size * num_query * num_heads * channels; + const int num_actual_kernels = batch_size * num_query * num_heads * channels; + if (channels > 1024) + { + if ((channels & 1023) == 0) + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + else + { + ms_deformable_col2im_gpu_kernel_gm + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + } + else{ + switch(channels) + { + case 1: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 2: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 4: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 8: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 16: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 32: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 64: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 128: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 256: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 512: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 1024: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + default: + if (channels < 64) + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + else + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + } + } + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); + } + +} \ No newline at end of file diff --git a/dimos/models/Detic/third_party/Deformable-DETR/models/ops/src/ms_deform_attn.h b/dimos/models/Detic/third_party/Deformable-DETR/models/ops/src/ms_deform_attn.h new file mode 100644 index 0000000000..ac0ef2ec25 --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/models/ops/src/ms_deform_attn.h @@ -0,0 +1,62 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#pragma once + +#include "cpu/ms_deform_attn_cpu.h" + +#ifdef WITH_CUDA +#include "cuda/ms_deform_attn_cuda.h" +#endif + + +at::Tensor +ms_deform_attn_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step) +{ + if (value.type().is_cuda()) + { +#ifdef WITH_CUDA + return ms_deform_attn_cuda_forward( + value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + +std::vector +ms_deform_attn_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step) +{ + if (value.type().is_cuda()) + { +#ifdef WITH_CUDA + return ms_deform_attn_cuda_backward( + value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + diff --git a/dimos/models/Detic/third_party/Deformable-DETR/models/ops/src/vision.cpp b/dimos/models/Detic/third_party/Deformable-DETR/models/ops/src/vision.cpp new file mode 100644 index 0000000000..2201f63a51 --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/models/ops/src/vision.cpp @@ -0,0 +1,16 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#include "ms_deform_attn.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); + m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); +} diff --git a/dimos/models/Detic/third_party/Deformable-DETR/models/ops/test.py b/dimos/models/Detic/third_party/Deformable-DETR/models/ops/test.py new file mode 100644 index 0000000000..720d6473b2 --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/models/ops/test.py @@ -0,0 +1,121 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + + +from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch +import torch +from torch.autograd import gradcheck + +N, M, D = 1, 2, 2 +Lq, L, P = 2, 2, 2 +shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda() +level_start_index = torch.cat((shapes.new_zeros((1,)), shapes.prod(1).cumsum(0)[:-1])) +S = sum([(H * W).item() for H, W in shapes]) + + +torch.manual_seed(3) + + +@torch.no_grad() +def check_forward_equal_with_pytorch_double() -> None: + value = torch.rand(N, S, M, D).cuda() * 0.01 + sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() + attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 + attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) + im2col_step = 2 + output_pytorch = ( + ms_deform_attn_core_pytorch( + value.double(), shapes, sampling_locations.double(), attention_weights.double() + ) + .detach() + .cpu() + ) + output_cuda = ( + MSDeformAttnFunction.apply( + value.double(), + shapes, + level_start_index, + sampling_locations.double(), + attention_weights.double(), + im2col_step, + ) + .detach() + .cpu() + ) + fwdok = torch.allclose(output_cuda, output_pytorch) + max_abs_err = (output_cuda - output_pytorch).abs().max() + max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() + + print( + f"* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}" + ) + + +@torch.no_grad() +def check_forward_equal_with_pytorch_float() -> None: + value = torch.rand(N, S, M, D).cuda() * 0.01 + sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() + attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 + attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) + im2col_step = 2 + output_pytorch = ( + ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights) + .detach() + .cpu() + ) + output_cuda = ( + MSDeformAttnFunction.apply( + value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step + ) + .detach() + .cpu() + ) + fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3) + max_abs_err = (output_cuda - output_pytorch).abs().max() + max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() + + print( + f"* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}" + ) + + +def check_gradient_numerical( + channels: int=4, grad_value: bool=True, grad_sampling_loc: bool=True, grad_attn_weight: bool=True +) -> None: + value = torch.rand(N, S, M, channels).cuda() * 0.01 + sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() + attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 + attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) + im2col_step = 2 + func = MSDeformAttnFunction.apply + + value.requires_grad = grad_value + sampling_locations.requires_grad = grad_sampling_loc + attention_weights.requires_grad = grad_attn_weight + + gradok = gradcheck( + func, + ( + value.double(), + shapes, + level_start_index, + sampling_locations.double(), + attention_weights.double(), + im2col_step, + ), + ) + + print(f"* {gradok} check_gradient_numerical(D={channels})") + + +if __name__ == "__main__": + check_forward_equal_with_pytorch_double() + check_forward_equal_with_pytorch_float() + + for channels in [30, 32, 64, 71, 1025, 2048, 3096]: + check_gradient_numerical(channels, True, True, True) diff --git a/dimos/models/Detic/third_party/Deformable-DETR/models/position_encoding.py b/dimos/models/Detic/third_party/Deformable-DETR/models/position_encoding.py new file mode 100644 index 0000000000..2ce5038e5e --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/models/position_encoding.py @@ -0,0 +1,112 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +Various positional encodings for the transformer. +""" + +import math + +import torch +from torch import nn +from util.misc import NestedTensor + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + + def __init__(self, num_pos_feats: int=64, temperature: int=10000, normalize: bool=False, scale=None) -> None: + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, tensor_list: NestedTensor): + x = tensor_list.tensors + mask = tensor_list.mask + assert mask is not None + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +class PositionEmbeddingLearned(nn.Module): + """ + Absolute pos embedding, learned. + """ + + def __init__(self, num_pos_feats: int=256) -> None: + super().__init__() + self.row_embed = nn.Embedding(50, num_pos_feats) + self.col_embed = nn.Embedding(50, num_pos_feats) + self.reset_parameters() + + def reset_parameters(self) -> None: + nn.init.uniform_(self.row_embed.weight) + nn.init.uniform_(self.col_embed.weight) + + def forward(self, tensor_list: NestedTensor): + x = tensor_list.tensors + h, w = x.shape[-2:] + i = torch.arange(w, device=x.device) + j = torch.arange(h, device=x.device) + x_emb = self.col_embed(i) + y_emb = self.row_embed(j) + pos = ( + torch.cat( + [ + x_emb.unsqueeze(0).repeat(h, 1, 1), + y_emb.unsqueeze(1).repeat(1, w, 1), + ], + dim=-1, + ) + .permute(2, 0, 1) + .unsqueeze(0) + .repeat(x.shape[0], 1, 1, 1) + ) + return pos + + +def build_position_encoding(args): + N_steps = args.hidden_dim // 2 + if args.position_embedding in ("v2", "sine"): + # TODO find a better way of exposing other arguments + position_embedding = PositionEmbeddingSine(N_steps, normalize=True) + elif args.position_embedding in ("v3", "learned"): + position_embedding = PositionEmbeddingLearned(N_steps) + else: + raise ValueError(f"not supported {args.position_embedding}") + + return position_embedding diff --git a/dimos/models/Detic/third_party/Deformable-DETR/models/segmentation.py b/dimos/models/Detic/third_party/Deformable-DETR/models/segmentation.py new file mode 100644 index 0000000000..2450a5c447 --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/models/segmentation.py @@ -0,0 +1,398 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +This file provides the definition of the convolutional heads used to predict masks, as well as the losses +""" + +from collections import defaultdict +import io + +from PIL import Image +import torch +import torch.nn as nn +import torch.nn.functional as F +import util.box_ops as box_ops +from util.misc import NestedTensor, interpolate, nested_tensor_from_tensor_list +from typing import Optional, Sequence + +try: + from panopticapi.utils import id2rgb, rgb2id +except ImportError: + pass + + +class DETRsegm(nn.Module): + def __init__(self, detr, freeze_detr: bool=False) -> None: + super().__init__() + self.detr = detr + + if freeze_detr: + for p in self.parameters(): + p.requires_grad_(False) + + hidden_dim, nheads = detr.transformer.d_model, detr.transformer.nhead + self.bbox_attention = MHAttentionMap(hidden_dim, hidden_dim, nheads, dropout=0) + self.mask_head = MaskHeadSmallConv(hidden_dim + nheads, [1024, 512, 256], hidden_dim) + + def forward(self, samples: NestedTensor): + if not isinstance(samples, NestedTensor): + samples = nested_tensor_from_tensor_list(samples) + features, pos = self.detr.backbone(samples) + + bs = features[-1].tensors.shape[0] + + src, mask = features[-1].decompose() + src_proj = self.detr.input_proj(src) + hs, memory = self.detr.transformer(src_proj, mask, self.detr.query_embed.weight, pos[-1]) + + outputs_class = self.detr.class_embed(hs) + outputs_coord = self.detr.bbox_embed(hs).sigmoid() + out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]} + if self.detr.aux_loss: + out["aux_outputs"] = [ + {"pred_logits": a, "pred_boxes": b} + for a, b in zip(outputs_class[:-1], outputs_coord[:-1], strict=False) + ] + + # FIXME h_boxes takes the last one computed, keep this in mind + bbox_mask = self.bbox_attention(hs[-1], memory, mask=mask) + + seg_masks = self.mask_head( + src_proj, bbox_mask, [features[2].tensors, features[1].tensors, features[0].tensors] + ) + outputs_seg_masks = seg_masks.view( + bs, self.detr.num_queries, seg_masks.shape[-2], seg_masks.shape[-1] + ) + + out["pred_masks"] = outputs_seg_masks + return out + + +class MaskHeadSmallConv(nn.Module): + """ + Simple convolutional head, using group norm. + Upsampling is done using a FPN approach + """ + + def __init__(self, dim: int, fpn_dims, context_dim) -> None: + super().__init__() + + inter_dims = [ + dim, + context_dim // 2, + context_dim // 4, + context_dim // 8, + context_dim // 16, + context_dim // 64, + ] + self.lay1 = torch.nn.Conv2d(dim, dim, 3, padding=1) + self.gn1 = torch.nn.GroupNorm(8, dim) + self.lay2 = torch.nn.Conv2d(dim, inter_dims[1], 3, padding=1) + self.gn2 = torch.nn.GroupNorm(8, inter_dims[1]) + self.lay3 = torch.nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1) + self.gn3 = torch.nn.GroupNorm(8, inter_dims[2]) + self.lay4 = torch.nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1) + self.gn4 = torch.nn.GroupNorm(8, inter_dims[3]) + self.lay5 = torch.nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1) + self.gn5 = torch.nn.GroupNorm(8, inter_dims[4]) + self.out_lay = torch.nn.Conv2d(inter_dims[4], 1, 3, padding=1) + + self.dim = dim + + self.adapter1 = torch.nn.Conv2d(fpn_dims[0], inter_dims[1], 1) + self.adapter2 = torch.nn.Conv2d(fpn_dims[1], inter_dims[2], 1) + self.adapter3 = torch.nn.Conv2d(fpn_dims[2], inter_dims[3], 1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_uniform_(m.weight, a=1) + nn.init.constant_(m.bias, 0) + + def forward(self, x, bbox_mask, fpns): + def expand(tensor, length: int): + return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1) + + x = torch.cat([expand(x, bbox_mask.shape[1]), bbox_mask.flatten(0, 1)], 1) + + x = self.lay1(x) + x = self.gn1(x) + x = F.relu(x) + x = self.lay2(x) + x = self.gn2(x) + x = F.relu(x) + + cur_fpn = self.adapter1(fpns[0]) + if cur_fpn.size(0) != x.size(0): + cur_fpn = expand(cur_fpn, x.size(0) / cur_fpn.size(0)) + x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") + x = self.lay3(x) + x = self.gn3(x) + x = F.relu(x) + + cur_fpn = self.adapter2(fpns[1]) + if cur_fpn.size(0) != x.size(0): + cur_fpn = expand(cur_fpn, x.size(0) / cur_fpn.size(0)) + x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") + x = self.lay4(x) + x = self.gn4(x) + x = F.relu(x) + + cur_fpn = self.adapter3(fpns[2]) + if cur_fpn.size(0) != x.size(0): + cur_fpn = expand(cur_fpn, x.size(0) / cur_fpn.size(0)) + x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") + x = self.lay5(x) + x = self.gn5(x) + x = F.relu(x) + + x = self.out_lay(x) + return x + + +class MHAttentionMap(nn.Module): + """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)""" + + def __init__(self, query_dim, hidden_dim, num_heads: int, dropout: int=0, bias: bool=True) -> None: + super().__init__() + self.num_heads = num_heads + self.hidden_dim = hidden_dim + self.dropout = nn.Dropout(dropout) + + self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias) + self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias) + + nn.init.zeros_(self.k_linear.bias) + nn.init.zeros_(self.q_linear.bias) + nn.init.xavier_uniform_(self.k_linear.weight) + nn.init.xavier_uniform_(self.q_linear.weight) + self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5 + + def forward(self, q, k, mask=None): + q = self.q_linear(q) + k = F.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias) + qh = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads) + kh = k.view( + k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1] + ) + weights = torch.einsum("bqnc,bnchw->bqnhw", qh * self.normalize_fact, kh) + + if mask is not None: + weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), float("-inf")) + weights = F.softmax(weights.flatten(2), dim=-1).view_as(weights) + weights = self.dropout(weights) + return weights + + +def dice_loss(inputs, targets, num_boxes: int): + """ + Compute the DICE loss, similar to generalized IOU for masks + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + """ + inputs = inputs.sigmoid() + inputs = inputs.flatten(1) + numerator = 2 * (inputs * targets).sum(1) + denominator = inputs.sum(-1) + targets.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + return loss.sum() / num_boxes + + +def sigmoid_focal_loss(inputs, targets, num_boxes: int, alpha: float = 0.25, gamma: float = 2): + """ + Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + alpha: (optional) Weighting factor in range (0,1) to balance + positive vs negative examples. Default = -1 (no weighting). + gamma: Exponent of the modulating factor (1 - p_t) to + balance easy vs hard examples. + Returns: + Loss tensor + """ + prob = inputs.sigmoid() + ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + p_t = prob * targets + (1 - prob) * (1 - targets) + loss = ce_loss * ((1 - p_t) ** gamma) + + if alpha >= 0: + alpha_t = alpha * targets + (1 - alpha) * (1 - targets) + loss = alpha_t * loss + + return loss.mean(1).sum() / num_boxes + + +class PostProcessSegm(nn.Module): + def __init__(self, threshold: float=0.5) -> None: + super().__init__() + self.threshold = threshold + + @torch.no_grad() + def forward(self, results, outputs, orig_target_sizes: Sequence[int], max_target_sizes: Sequence[int]): + assert len(orig_target_sizes) == len(max_target_sizes) + max_h, max_w = max_target_sizes.max(0)[0].tolist() + outputs_masks = outputs["pred_masks"].squeeze(2) + outputs_masks = F.interpolate( + outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False + ) + outputs_masks = (outputs_masks.sigmoid() > self.threshold).cpu() + + for i, (cur_mask, t, tt) in enumerate( + zip(outputs_masks, max_target_sizes, orig_target_sizes, strict=False) + ): + img_h, img_w = t[0], t[1] + results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1) + results[i]["masks"] = F.interpolate( + results[i]["masks"].float(), size=tuple(tt.tolist()), mode="nearest" + ).byte() + + return results + + +class PostProcessPanoptic(nn.Module): + """This class converts the output of the model to the final panoptic result, in the format expected by the + coco panoptic API""" + + def __init__(self, is_thing_map: bool, threshold: float=0.85) -> None: + """ + Parameters: + is_thing_map: This is a whose keys are the class ids, and the values a boolean indicating whether + the class is a thing (True) or a stuff (False) class + threshold: confidence threshold: segments with confidence lower than this will be deleted + """ + super().__init__() + self.threshold = threshold + self.is_thing_map = is_thing_map + + def forward(self, outputs, processed_sizes: Sequence[int], target_sizes: Optional[Sequence[int]]=None): + """This function computes the panoptic prediction from the model's predictions. + Parameters: + outputs: This is a dict coming directly from the model. See the model doc for the content. + processed_sizes: This is a list of tuples (or torch tensors) of sizes of the images that were passed to the + model, ie the size after data augmentation but before batching. + target_sizes: This is a list of tuples (or torch tensors) corresponding to the requested final size + of each prediction. If left to None, it will default to the processed_sizes + """ + if target_sizes is None: + target_sizes = processed_sizes + assert len(processed_sizes) == len(target_sizes) + out_logits, raw_masks, raw_boxes = ( + outputs["pred_logits"], + outputs["pred_masks"], + outputs["pred_boxes"], + ) + assert len(out_logits) == len(raw_masks) == len(target_sizes) + preds = [] + + def to_tuple(tup): + if isinstance(tup, tuple): + return tup + return tuple(tup.cpu().tolist()) + + for cur_logits, cur_masks, cur_boxes, size, target_size in zip( + out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes, strict=False + ): + # we filter empty queries and detection below threshold + scores, labels = cur_logits.softmax(-1).max(-1) + keep = labels.ne(outputs["pred_logits"].shape[-1] - 1) & (scores > self.threshold) + cur_scores, cur_classes = cur_logits.softmax(-1).max(-1) + cur_scores = cur_scores[keep] + cur_classes = cur_classes[keep] + cur_masks = cur_masks[keep] + cur_masks = interpolate(cur_masks[None], to_tuple(size), mode="bilinear").squeeze(0) + cur_boxes = box_ops.box_cxcywh_to_xyxy(cur_boxes[keep]) + + h, w = cur_masks.shape[-2:] + assert len(cur_boxes) == len(cur_classes) + + # It may be that we have several predicted masks for the same stuff class. + # In the following, we track the list of masks ids for each stuff class (they are merged later on) + cur_masks = cur_masks.flatten(1) + stuff_equiv_classes = defaultdict(lambda: []) + for k, label in enumerate(cur_classes): + if not self.is_thing_map[label.item()]: + stuff_equiv_classes[label.item()].append(k) + + def get_ids_area(masks, scores, dedup: bool=False): + # This helper function creates the final panoptic segmentation image + # It also returns the area of the masks that appears on the image + + m_id = masks.transpose(0, 1).softmax(-1) + + if m_id.shape[-1] == 0: + # We didn't detect any mask :( + m_id = torch.zeros((h, w), dtype=torch.long, device=m_id.device) + else: + m_id = m_id.argmax(-1).view(h, w) + + if dedup: + # Merge the masks corresponding to the same stuff class + for equiv in stuff_equiv_classes.values(): + if len(equiv) > 1: + for eq_id in equiv: + m_id.masked_fill_(m_id.eq(eq_id), equiv[0]) + + final_h, final_w = to_tuple(target_size) + + seg_img = Image.fromarray(id2rgb(m_id.view(h, w).cpu().numpy())) + seg_img = seg_img.resize(size=(final_w, final_h), resample=Image.NEAREST) + + np_seg_img = ( + torch.ByteTensor(torch.ByteStorage.from_buffer(seg_img.tobytes())) + .view(final_h, final_w, 3) + .numpy() + ) + m_id = torch.from_numpy(rgb2id(np_seg_img)) + + area = [] + for i in range(len(scores)): + area.append(m_id.eq(i).sum().item()) + return area, seg_img + + area, seg_img = get_ids_area(cur_masks, cur_scores, dedup=True) + if cur_classes.numel() > 0: + # We know filter empty masks as long as we find some + while True: + filtered_small = torch.as_tensor( + [area[i] <= 4 for i, c in enumerate(cur_classes)], + dtype=torch.bool, + device=keep.device, + ) + if filtered_small.any().item(): + cur_scores = cur_scores[~filtered_small] + cur_classes = cur_classes[~filtered_small] + cur_masks = cur_masks[~filtered_small] + area, seg_img = get_ids_area(cur_masks, cur_scores) + else: + break + + else: + cur_classes = torch.ones(1, dtype=torch.long, device=cur_classes.device) + + segments_info = [] + for i, a in enumerate(area): + cat = cur_classes[i].item() + segments_info.append( + {"id": i, "isthing": self.is_thing_map[cat], "category_id": cat, "area": a} + ) + del cur_classes + + with io.BytesIO() as out: + seg_img.save(out, format="PNG") + predictions = {"png_string": out.getvalue(), "segments_info": segments_info} + preds.append(predictions) + return preds diff --git a/dimos/models/Detic/third_party/Deformable-DETR/requirements.txt b/dimos/models/Detic/third_party/Deformable-DETR/requirements.txt new file mode 100644 index 0000000000..fd846723be --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/requirements.txt @@ -0,0 +1,4 @@ +pycocotools +tqdm +cython +scipy diff --git a/dimos/models/Detic/third_party/Deformable-DETR/tools/launch.py b/dimos/models/Detic/third_party/Deformable-DETR/tools/launch.py new file mode 100644 index 0000000000..1d60ae4994 --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/tools/launch.py @@ -0,0 +1,204 @@ +# -------------------------------------------------------------------------------------------------------------------------- +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# -------------------------------------------------------------------------------------------------------------------------- +# Modified from https://github.com/pytorch/pytorch/blob/173f224570017b4b1a3a1a13d0bff280a54d9cd9/torch/distributed/launch.py +# -------------------------------------------------------------------------------------------------------------------------- + +r""" +`torch.distributed.launch` is a module that spawns up multiple distributed +training processes on each of the training nodes. +The utility can be used for single-node distributed training, in which one or +more processes per node will be spawned. The utility can be used for either +CPU training or GPU training. If the utility is used for GPU training, +each distributed process will be operating on a single GPU. This can achieve +well-improved single-node training performance. It can also be used in +multi-node distributed training, by spawning up multiple processes on each node +for well-improved multi-node distributed training performance as well. +This will especially be benefitial for systems with multiple Infiniband +interfaces that have direct-GPU support, since all of them can be utilized for +aggregated communication bandwidth. +In both cases of single-node distributed training or multi-node distributed +training, this utility will launch the given number of processes per node +(``--nproc_per_node``). If used for GPU training, this number needs to be less +or euqal to the number of GPUs on the current system (``nproc_per_node``), +and each process will be operating on a single GPU from *GPU 0 to +GPU (nproc_per_node - 1)*. +**How to use this module:** +1. Single-Node multi-process distributed training +:: + >>> python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE + YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other + arguments of your training script) +2. Multi-Node multi-process distributed training: (e.g. two nodes) +Node 1: *(IP: 192.168.1.1, and has a free port: 1234)* +:: + >>> python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE + --nnodes=2 --node_rank=0 --master_addr="192.168.1.1" + --master_port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 + and all other arguments of your training script) +Node 2: +:: + >>> python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE + --nnodes=2 --node_rank=1 --master_addr="192.168.1.1" + --master_port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 + and all other arguments of your training script) +3. To look up what optional arguments this module offers: +:: + >>> python -m torch.distributed.launch --help +**Important Notices:** +1. This utilty and multi-process distributed (single-node or +multi-node) GPU training currently only achieves the best performance using +the NCCL distributed backend. Thus NCCL backend is the recommended backend to +use for GPU training. +2. In your training program, you must parse the command-line argument: +``--local_rank=LOCAL_PROCESS_RANK``, which will be provided by this module. +If your training program uses GPUs, you should ensure that your code only +runs on the GPU device of LOCAL_PROCESS_RANK. This can be done by: +Parsing the local_rank argument +:: + >>> import argparse + >>> parser = argparse.ArgumentParser() + >>> parser.add_argument("--local_rank", type=int) + >>> args = parser.parse_args() +Set your device to local rank using either +:: + >>> torch.cuda.set_device(arg.local_rank) # before your code runs +or +:: + >>> with torch.cuda.device(arg.local_rank): + >>> # your code to run +3. In your training program, you are supposed to call the following function +at the beginning to start the distributed backend. You need to make sure that +the init_method uses ``env://``, which is the only supported ``init_method`` +by this module. +:: + torch.distributed.init_process_group(backend='YOUR BACKEND', + init_method='env://') +4. In your training program, you can either use regular distributed functions +or use :func:`torch.nn.parallel.DistributedDataParallel` module. If your +training program uses GPUs for training and you would like to use +:func:`torch.nn.parallel.DistributedDataParallel` module, +here is how to configure it. +:: + model = torch.nn.parallel.DistributedDataParallel(model, + device_ids=[arg.local_rank], + output_device=arg.local_rank) +Please ensure that ``device_ids`` argument is set to be the only GPU device id +that your code will be operating on. This is generally the local rank of the +process. In other words, the ``device_ids`` needs to be ``[args.local_rank]``, +and ``output_device`` needs to be ``args.local_rank`` in order to use this +utility +5. Another way to pass ``local_rank`` to the subprocesses via environment variable +``LOCAL_RANK``. This behavior is enabled when you launch the script with +``--use_env=True``. You must adjust the subprocess example above to replace +``args.local_rank`` with ``os.environ['LOCAL_RANK']``; the launcher +will not pass ``--local_rank`` when you specify this flag. +.. warning:: + ``local_rank`` is NOT globally unique: it is only unique per process + on a machine. Thus, don't use it to decide if you should, e.g., + write to a networked filesystem. See + https://github.com/pytorch/pytorch/issues/12042 for an example of + how things can go wrong if you don't do this correctly. +""" + +from argparse import REMAINDER, ArgumentParser +import os +import subprocess + + +def parse_args(): + """ + Helper function parsing the command line options + @retval ArgumentParser + """ + parser = ArgumentParser( + description="PyTorch distributed training launch " + "helper utilty that will spawn up " + "multiple distributed processes" + ) + + # Optional arguments for the launch helper + parser.add_argument( + "--nnodes", type=int, default=1, help="The number of nodes to use for distributed training" + ) + parser.add_argument( + "--node_rank", + type=int, + default=0, + help="The rank of the node for multi-node distributed training", + ) + parser.add_argument( + "--nproc_per_node", + type=int, + default=1, + help="The number of processes to launch on each node, " + "for GPU training, this is recommended to be set " + "to the number of GPUs in your system so that " + "each process can be bound to a single GPU.", + ) + parser.add_argument( + "--master_addr", + default="127.0.0.1", + type=str, + help="Master node (rank 0)'s address, should be either " + "the IP address or the hostname of node 0, for " + "single node multi-proc training, the " + "--master_addr can simply be 127.0.0.1", + ) + parser.add_argument( + "--master_port", + default=29500, + type=int, + help="Master node (rank 0)'s free port that needs to be used for communciation during distributed training", + ) + + # positional + parser.add_argument( + "training_script", + type=str, + help="The full path to the single GPU training " + "program/script to be launched in parallel, " + "followed by all the arguments for the " + "training script", + ) + + # rest from the training program + parser.add_argument("training_script_args", nargs=REMAINDER) + return parser.parse_args() + + +def main(): + args = parse_args() + + # world size in terms of number of processes + dist_world_size = args.nproc_per_node * args.nnodes + + # set PyTorch distributed related environmental variables + current_env = os.environ.copy() + current_env["MASTER_ADDR"] = args.master_addr + current_env["MASTER_PORT"] = str(args.master_port) + current_env["WORLD_SIZE"] = str(dist_world_size) + + processes = [] + + for local_rank in range(0, args.nproc_per_node): + # each process's rank + dist_rank = args.nproc_per_node * args.node_rank + local_rank + current_env["RANK"] = str(dist_rank) + current_env["LOCAL_RANK"] = str(local_rank) + + cmd = [args.training_script, *args.training_script_args] + + process = subprocess.Popen(cmd, env=current_env) + processes.append(process) + + for process in processes: + process.wait() + if process.returncode != 0: + raise subprocess.CalledProcessError(returncode=process.returncode, cmd=process.args) + + +if __name__ == "__main__": + main() diff --git a/dimos/models/Detic/third_party/Deformable-DETR/tools/run_dist_launch.sh b/dimos/models/Detic/third_party/Deformable-DETR/tools/run_dist_launch.sh new file mode 100755 index 0000000000..f6f6c4fb6f --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/tools/run_dist_launch.sh @@ -0,0 +1,29 @@ +#!/usr/bin/env bash +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +set -x + +GPUS=$1 +RUN_COMMAND=${@:2} +if [ $GPUS -lt 8 ]; then + GPUS_PER_NODE=${GPUS_PER_NODE:-$GPUS} +else + GPUS_PER_NODE=${GPUS_PER_NODE:-8} +fi +MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} +MASTER_PORT=${MASTER_PORT:-"29500"} +NODE_RANK=${NODE_RANK:-0} + +let "NNODES=GPUS/GPUS_PER_NODE" + +python ./tools/launch.py \ + --nnodes ${NNODES} \ + --node_rank ${NODE_RANK} \ + --master_addr ${MASTER_ADDR} \ + --master_port ${MASTER_PORT} \ + --nproc_per_node ${GPUS_PER_NODE} \ + ${RUN_COMMAND} \ No newline at end of file diff --git a/dimos/models/Detic/third_party/Deformable-DETR/tools/run_dist_slurm.sh b/dimos/models/Detic/third_party/Deformable-DETR/tools/run_dist_slurm.sh new file mode 100755 index 0000000000..bd73d0bbb7 --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/tools/run_dist_slurm.sh @@ -0,0 +1,33 @@ +#!/usr/bin/env bash +# -------------------------------------------------------------------------------------------------------------------------- +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# -------------------------------------------------------------------------------------------------------------------------- +# Modified from https://github.com/open-mmlab/mmdetection/blob/3b53fe15d87860c6941f3dda63c0f27422da6266/tools/slurm_train.sh +# -------------------------------------------------------------------------------------------------------------------------- + +set -x + +PARTITION=$1 +JOB_NAME=$2 +GPUS=$3 +RUN_COMMAND=${@:4} +if [ $GPUS -lt 8 ]; then + GPUS_PER_NODE=${GPUS_PER_NODE:-$GPUS} +else + GPUS_PER_NODE=${GPUS_PER_NODE:-8} +fi +CPUS_PER_TASK=${CPUS_PER_TASK:-4} +SRUN_ARGS=${SRUN_ARGS:-""} + +srun -p ${PARTITION} \ + --job-name=${JOB_NAME} \ + --gres=gpu:${GPUS_PER_NODE} \ + --ntasks=${GPUS} \ + --ntasks-per-node=${GPUS_PER_NODE} \ + --cpus-per-task=${CPUS_PER_TASK} \ + --kill-on-bad-exit=1 \ + ${SRUN_ARGS} \ + ${RUN_COMMAND} + diff --git a/dimos/models/Detic/third_party/Deformable-DETR/util/__init__.py b/dimos/models/Detic/third_party/Deformable-DETR/util/__init__.py new file mode 100644 index 0000000000..4ebdc90b7f --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/util/__init__.py @@ -0,0 +1,8 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ diff --git a/dimos/models/Detic/third_party/Deformable-DETR/util/box_ops.py b/dimos/models/Detic/third_party/Deformable-DETR/util/box_ops.py new file mode 100644 index 0000000000..5864b68d3b --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/util/box_ops.py @@ -0,0 +1,95 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +Utilities for bounding box manipulation and GIoU. +""" + +import torch +from torchvision.ops.boxes import box_area + + +def box_cxcywh_to_xyxy(x): + x_c, y_c, w, h = x.unbind(-1) + b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)] + return torch.stack(b, dim=-1) + + +def box_xyxy_to_cxcywh(x): + x0, y0, x1, y1 = x.unbind(-1) + b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)] + return torch.stack(b, dim=-1) + + +# modified from torchvision to also return the union +def box_iou(boxes1, boxes2): + area1 = box_area(boxes1) + area2 = box_area(boxes2) + + lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] + rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] + + wh = (rb - lt).clamp(min=0) # [N,M,2] + inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] + + union = area1[:, None] + area2 - inter + + iou = inter / union + return iou, union + + +def generalized_box_iou(boxes1, boxes2): + """ + Generalized IoU from https://giou.stanford.edu/ + + The boxes should be in [x0, y0, x1, y1] format + + Returns a [N, M] pairwise matrix, where N = len(boxes1) + and M = len(boxes2) + """ + # degenerate boxes gives inf / nan results + # so do an early check + assert (boxes1[:, 2:] >= boxes1[:, :2]).all() + assert (boxes2[:, 2:] >= boxes2[:, :2]).all() + iou, union = box_iou(boxes1, boxes2) + + lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) + rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) + + wh = (rb - lt).clamp(min=0) # [N,M,2] + area = wh[:, :, 0] * wh[:, :, 1] + + return iou - (area - union) / area + + +def masks_to_boxes(masks): + """Compute the bounding boxes around the provided masks + + The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. + + Returns a [N, 4] tensors, with the boxes in xyxy format + """ + if masks.numel() == 0: + return torch.zeros((0, 4), device=masks.device) + + h, w = masks.shape[-2:] + + y = torch.arange(0, h, dtype=torch.float) + x = torch.arange(0, w, dtype=torch.float) + y, x = torch.meshgrid(y, x) + + x_mask = masks * x.unsqueeze(0) + x_max = x_mask.flatten(1).max(-1)[0] + x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] + + y_mask = masks * y.unsqueeze(0) + y_max = y_mask.flatten(1).max(-1)[0] + y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] + + return torch.stack([x_min, y_min, x_max, y_max], 1) diff --git a/dimos/models/Detic/third_party/Deformable-DETR/util/misc.py b/dimos/models/Detic/third_party/Deformable-DETR/util/misc.py new file mode 100644 index 0000000000..0615de5b5f --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/util/misc.py @@ -0,0 +1,538 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +Misc functions, including distributed helpers. + +Mostly copy-paste from torchvision references. +""" + +from collections import defaultdict, deque +import datetime +import os +import pickle +import subprocess +import time +from typing import List, Optional + +import torch +from torch import Tensor +import torch.distributed as dist + +# needed due to empty tensor bug in pytorch and torchvision 0.5 +import torchvision + +if float(torchvision.__version__[:3]) < 0.5: + import math + + from torchvision.ops.misc import _NewEmptyTensorOp + + def _check_size_scale_factor(dim: int, size: int, scale_factor): + # type: (int, Optional[List[int]], Optional[float]) -> None + if size is None and scale_factor is None: + raise ValueError("either size or scale_factor should be defined") + if size is not None and scale_factor is not None: + raise ValueError("only one of size or scale_factor should be defined") + if not (scale_factor is not None and len(scale_factor) != dim): + raise ValueError( + f"scale_factor shape must match input shape. Input is {dim}D, scale_factor size is {len(scale_factor)}" + ) + + def _output_size(dim: int, input, size: int, scale_factor): + # type: (int, Tensor, Optional[List[int]], Optional[float]) -> List[int] + assert dim == 2 + _check_size_scale_factor(dim, size, scale_factor) + if size is not None: + return size + # if dim is not 2 or scale_factor is iterable use _ntuple instead of concat + assert scale_factor is not None and isinstance(scale_factor, int | float) + scale_factors = [scale_factor, scale_factor] + # math.floor might return float in py2.7 + return [math.floor(input.size(i + 2) * scale_factors[i]) for i in range(dim)] +elif float(torchvision.__version__[:3]) < 0.7: + from torchvision.ops import _new_empty_tensor + from torchvision.ops.misc import _output_size + + +class SmoothedValue: + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size: int=20, fmt=None) -> None: + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n: int=1) -> None: + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self) -> None: + """ + Warning: does not synchronize the deque! + """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self) -> str: + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value, + ) + + +def all_gather(data): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + world_size = get_world_size() + if world_size == 1: + return [data] + + # serialized to a Tensor + buffer = pickle.dumps(data) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to("cuda") + + # obtain Tensor size of each rank + local_size = torch.tensor([tensor.numel()], device="cuda") + size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] + dist.all_gather(size_list, local_size) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + + # receiving Tensor from all ranks + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) + if local_size != max_size: + padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") + tensor = torch.cat((tensor, padding), dim=0) + dist.all_gather(tensor_list, tensor) + + data_list = [] + for size, tensor in zip(size_list, tensor_list, strict=False): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def reduce_dict(input_dict, average: bool=True): + """ + Args: + input_dict (dict): all the values will be reduced + average (bool): whether to do average or sum + Reduce the values in the dictionary from all processes so that all processes + have the averaged results. Returns a dict with the same fields as + input_dict, after reduction. + """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.all_reduce(values) + if average: + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values, strict=False)} + return reduced_dict + + +class MetricLogger: + def __init__(self, delimiter: str="\t") -> None: + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs) -> None: + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, float | int) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'") + + def __str__(self) -> str: + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append(f"{name}: {meter!s}") + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self) -> None: + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name: str, meter) -> None: + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = "" + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt="{avg:.4f}") + data_time = SmoothedValue(fmt="{avg:.4f}") + space_fmt = ":" + str(len(str(len(iterable)))) + "d" + if torch.cuda.is_available(): + log_msg = self.delimiter.join( + [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + "max mem: {memory:.0f}", + ] + ) + else: + log_msg = self.delimiter.join( + [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + ] + ) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB, + ) + ) + else: + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + ) + ) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print( + f"{header} Total time: {total_time_str} ({total_time / len(iterable):.4f} s / it)" + ) + + +def get_sha(): + cwd = os.path.dirname(os.path.abspath(__file__)) + + def _run(command): + return subprocess.check_output(command, cwd=cwd).decode("ascii").strip() + + sha = "N/A" + diff = "clean" + branch = "N/A" + try: + sha = _run(["git", "rev-parse", "HEAD"]) + subprocess.check_output(["git", "diff"], cwd=cwd) + diff = _run(["git", "diff-index", "HEAD"]) + diff = "has uncommited changes" if diff else "clean" + branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"]) + except Exception: + pass + message = f"sha: {sha}, status: {diff}, branch: {branch}" + return message + + +def collate_fn(batch): + batch = list(zip(*batch, strict=False)) + batch[0] = nested_tensor_from_tensor_list(batch[0]) + return tuple(batch) + + +def _max_by_axis(the_list): + # type: (List[List[int]]) -> List[int] + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + +def nested_tensor_from_tensor_list(tensor_list: list[Tensor]): + # TODO make this more general + if tensor_list[0].ndim == 3: + # TODO make it support different-sized images + max_size = _max_by_axis([list(img.shape) for img in tensor_list]) + # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) + batch_shape = [len(tensor_list), *max_size] + b, c, h, w = batch_shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.ones((b, h, w), dtype=torch.bool, device=device) + for img, pad_img, m in zip(tensor_list, tensor, mask, strict=False): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + m[: img.shape[1], : img.shape[2]] = False + else: + raise ValueError("not supported") + return NestedTensor(tensor, mask) + + +class NestedTensor: + def __init__(self, tensors, mask: Tensor | None) -> None: + self.tensors = tensors + self.mask = mask + + def to(self, device, non_blocking: bool=False): + # type: (Device) -> NestedTensor + cast_tensor = self.tensors.to(device, non_blocking=non_blocking) + mask = self.mask + if mask is not None: + assert mask is not None + cast_mask = mask.to(device, non_blocking=non_blocking) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def record_stream(self, *args, **kwargs) -> None: + self.tensors.record_stream(*args, **kwargs) + if self.mask is not None: + self.mask.record_stream(*args, **kwargs) + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self) -> str: + return str(self.tensors) + + +def setup_for_distributed(is_master: bool) -> None: + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + + builtin_print = __builtin__.print + + def print(*args, **kwargs) -> None: + force = kwargs.pop("force", False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized() -> bool: + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def get_local_size(): + if not is_dist_avail_and_initialized(): + return 1 + return int(os.environ["LOCAL_SIZE"]) + + +def get_local_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return int(os.environ["LOCAL_RANK"]) + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs) -> None: + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args) -> None: + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ["WORLD_SIZE"]) + args.gpu = int(os.environ["LOCAL_RANK"]) + args.dist_url = "env://" + os.environ["LOCAL_SIZE"] = str(torch.cuda.device_count()) + elif "SLURM_PROCID" in os.environ: + proc_id = int(os.environ["SLURM_PROCID"]) + ntasks = int(os.environ["SLURM_NTASKS"]) + node_list = os.environ["SLURM_NODELIST"] + num_gpus = torch.cuda.device_count() + addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1") + os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "29500") + os.environ["MASTER_ADDR"] = addr + os.environ["WORLD_SIZE"] = str(ntasks) + os.environ["RANK"] = str(proc_id) + os.environ["LOCAL_RANK"] = str(proc_id % num_gpus) + os.environ["LOCAL_SIZE"] = str(num_gpus) + args.dist_url = "env://" + args.world_size = ntasks + args.rank = proc_id + args.gpu = proc_id % num_gpus + else: + print("Not using distributed mode") + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = "nccl" + print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True) + torch.distributed.init_process_group( + backend=args.dist_backend, + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank, + ) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +@torch.no_grad() +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k""" + if target.numel() == 0: + return [torch.zeros([], device=output.device)] + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def interpolate(input, size: Optional[int]=None, scale_factor=None, mode: str="nearest", align_corners=None): + # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor + """ + Equivalent to nn.functional.interpolate, but with support for empty batch sizes. + This will eventually be supported natively by PyTorch, and this + class can go away. + """ + if float(torchvision.__version__[:3]) < 0.7: + if input.numel() > 0: + return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners) + + output_shape = _output_size(2, input, size, scale_factor) + output_shape = list(input.shape[:-2]) + list(output_shape) + if float(torchvision.__version__[:3]) < 0.5: + return _NewEmptyTensorOp.apply(input, output_shape) + return _new_empty_tensor(input, output_shape) + else: + return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) + + +def get_total_grad_norm(parameters, norm_type: int=2): + parameters = list(filter(lambda p: p.grad is not None, parameters)) + norm_type = float(norm_type) + device = parameters[0].grad.device + total_norm = torch.norm( + torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), + norm_type, + ) + return total_norm + + +def inverse_sigmoid(x, eps: float=1e-5): + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1 / x2) diff --git a/dimos/models/Detic/third_party/Deformable-DETR/util/plot_utils.py b/dimos/models/Detic/third_party/Deformable-DETR/util/plot_utils.py new file mode 100644 index 0000000000..0af3b9e5e6 --- /dev/null +++ b/dimos/models/Detic/third_party/Deformable-DETR/util/plot_utils.py @@ -0,0 +1,120 @@ +# ------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +Plotting utilities to visualize training logs. +""" + +from pathlib import Path, PurePath + +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns +import torch + + +def plot_logs( + logs, fields=("class_error", "loss_bbox_unscaled", "mAP"), ewm_col: int=0, log_name: str="log.txt" +): + """ + Function to plot specific fields from training log(s). Plots both training and test results. + + :: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file + - fields = which results to plot from each log file - plots both training and test for each field. + - ewm_col = optional, which column to use as the exponential weighted smoothing of the plots + - log_name = optional, name of log file if different than default 'log.txt'. + + :: Outputs - matplotlib plots of results in fields, color coded for each log file. + - solid lines are training results, dashed lines are test results. + + """ + func_name = "plot_utils.py::plot_logs" + + # verify logs is a list of Paths (list[Paths]) or single Pathlib object Path, + # convert single Path to list to avoid 'not iterable' error + + if not isinstance(logs, list): + if isinstance(logs, PurePath): + logs = [logs] + print(f"{func_name} info: logs param expects a list argument, converted to list[Path].") + else: + raise ValueError( + f"{func_name} - invalid argument for logs parameter.\n \ + Expect list[Path] or single Path obj, received {type(logs)}" + ) + + # verify valid dir(s) and that every item in list is Path object + for _i, dir in enumerate(logs): + if not isinstance(dir, PurePath): + raise ValueError( + f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}" + ) + if dir.exists(): + continue + raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}") + + # load log file(s) and plot + dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs] + + fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5)) + + for df, color in zip(dfs, sns.color_palette(n_colors=len(logs)), strict=False): + for j, field in enumerate(fields): + if field == "mAP": + coco_eval = ( + pd.DataFrame(pd.np.stack(df.test_coco_eval.dropna().values)[:, 1]) + .ewm(com=ewm_col) + .mean() + ) + axs[j].plot(coco_eval, c=color) + else: + df.interpolate().ewm(com=ewm_col).mean().plot( + y=[f"train_{field}", f"test_{field}"], + ax=axs[j], + color=[color] * 2, + style=["-", "--"], + ) + for ax, field in zip(axs, fields, strict=False): + ax.legend([Path(p).name for p in logs]) + ax.set_title(field) + + +def plot_precision_recall(files, naming_scheme: str="iter"): + if naming_scheme == "exp_id": + # name becomes exp_id + names = [f.parts[-3] for f in files] + elif naming_scheme == "iter": + names = [f.stem for f in files] + else: + raise ValueError(f"not supported {naming_scheme}") + fig, axs = plt.subplots(ncols=2, figsize=(16, 5)) + for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names, strict=False): + data = torch.load(f) + # precision is n_iou, n_points, n_cat, n_area, max_det + precision = data["precision"] + recall = data["params"].recThrs + scores = data["scores"] + # take precision for all classes, all areas and 100 detections + precision = precision[0, :, :, 0, -1].mean(1) + scores = scores[0, :, :, 0, -1].mean(1) + prec = precision.mean() + rec = data["recall"][0, :, 0, -1].mean() + print( + f"{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, " + + f"score={scores.mean():0.3f}, " + + f"f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}" + ) + axs[0].plot(recall, precision, c=color) + axs[1].plot(recall, scores, c=color) + + axs[0].set_title("Precision / Recall") + axs[0].legend(names) + axs[1].set_title("Scores / Recall") + axs[1].legend(names) + return fig, axs diff --git a/dimos/models/Detic/tools/convert-thirdparty-pretrained-model-to-d2.py b/dimos/models/Detic/tools/convert-thirdparty-pretrained-model-to-d2.py new file mode 100644 index 0000000000..567e71f7c4 --- /dev/null +++ b/dimos/models/Detic/tools/convert-thirdparty-pretrained-model-to-d2.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +import argparse +import pickle + +import torch + +""" +Usage: + +cd DETIC_ROOT/models/ +wget https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/resnet50_miil_21k.pth +python ../tools/convert-thirdparty-pretrained-model-to-d2.py --path resnet50_miil_21k.pth + +wget https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth +python ../tools/convert-thirdparty-pretrained-model-to-d2.py --path swin_base_patch4_window7_224_22k.pth + +""" + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--path", default="") + args = parser.parse_args() + + print("Loading", args.path) + model = torch.load(args.path, map_location="cpu") + # import pdb; pdb.set_trace() + if "model" in model: + model = model["model"] + if "state_dict" in model: + model = model["state_dict"] + ret = {"model": model, "__author__": "third_party", "matching_heuristics": True} + out_path = args.path.replace(".pth", ".pkl") + print("Saving to", out_path) + pickle.dump(ret, open(out_path, "wb")) diff --git a/dimos/models/Detic/tools/create_imagenetlvis_json.py b/dimos/models/Detic/tools/create_imagenetlvis_json.py new file mode 100644 index 0000000000..4f53874421 --- /dev/null +++ b/dimos/models/Detic/tools/create_imagenetlvis_json.py @@ -0,0 +1,56 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import argparse +import json +import os + +from detectron2.data.detection_utils import read_image +from nltk.corpus import wordnet + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--imagenet_path", default="datasets/imagenet/ImageNet-LVIS") + parser.add_argument("--lvis_meta_path", default="datasets/lvis/lvis_v1_val.json") + parser.add_argument( + "--out_path", default="datasets/imagenet/annotations/imagenet_lvis_image_info.json" + ) + args = parser.parse_args() + + print("Loading LVIS meta") + data = json.load(open(args.lvis_meta_path)) + print("Done") + synset2cat = {x["synset"]: x for x in data["categories"]} + count = 0 + images = [] + image_counts = {} + folders = sorted(os.listdir(args.imagenet_path)) + for i, folder in enumerate(folders): + class_path = args.imagenet_path + folder + files = sorted(os.listdir(class_path)) + synset = wordnet.synset_from_pos_and_offset("n", int(folder[1:])).name() + cat = synset2cat[synset] + cat_id = cat["id"] + cat_name = cat["name"] + cat_images = [] + for file in files: + count = count + 1 + file_name = f"{folder}/{file}" + # img = cv2.imread('{}/{}'.format(args.imagenet_path, file_name)) + img = read_image(f"{args.imagenet_path}/{file_name}") + h, w = img.shape[:2] + image = { + "id": count, + "file_name": file_name, + "pos_category_ids": [cat_id], + "width": w, + "height": h, + } + cat_images.append(image) + images.extend(cat_images) + image_counts[cat_id] = len(cat_images) + print(i, cat_name, len(cat_images)) + print("# Images", len(images)) + for x in data["categories"]: + x["image_count"] = image_counts[x["id"]] if x["id"] in image_counts else 0 + out = {"categories": data["categories"], "images": images, "annotations": []} + print("Writing to", args.out_path) + json.dump(out, open(args.out_path, "w")) diff --git a/dimos/models/Detic/tools/create_lvis_21k.py b/dimos/models/Detic/tools/create_lvis_21k.py new file mode 100644 index 0000000000..a1f24446ac --- /dev/null +++ b/dimos/models/Detic/tools/create_lvis_21k.py @@ -0,0 +1,75 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import argparse +import copy +import json + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--imagenet_path", default="datasets/imagenet/annotations/imagenet-21k_image_info.json" + ) + parser.add_argument("--lvis_path", default="datasets/lvis/lvis_v1_train.json") + parser.add_argument("--save_categories", default="") + parser.add_argument("--not_save_imagenet", action="store_true") + parser.add_argument("--not_save_lvis", action="store_true") + parser.add_argument("--mark", default="lvis-21k") + args = parser.parse_args() + + print("Loading", args.imagenet_path) + in_data = json.load(open(args.imagenet_path)) + print("Loading", args.lvis_path) + lvis_data = json.load(open(args.lvis_path)) + + categories = copy.deepcopy(lvis_data["categories"]) + cat_count = max(x["id"] for x in categories) + synset2id = {x["synset"]: x["id"] for x in categories} + name2id = {x["name"]: x["id"] for x in categories} + in_id_map = {} + for x in in_data["categories"]: + if x["synset"] in synset2id: + in_id_map[x["id"]] = synset2id[x["synset"]] + elif x["name"] in name2id: + in_id_map[x["id"]] = name2id[x["name"]] + x["id"] = name2id[x["name"]] + else: + cat_count = cat_count + 1 + name2id[x["name"]] = cat_count + in_id_map[x["id"]] = cat_count + x["id"] = cat_count + categories.append(x) + + print("lvis cats", len(lvis_data["categories"])) + print("imagenet cats", len(in_data["categories"])) + print("merge cats", len(categories)) + + filtered_images = [] + for x in in_data["images"]: + x["pos_category_ids"] = [in_id_map[xx] for xx in x["pos_category_ids"]] + x["pos_category_ids"] = [xx for xx in sorted(set(x["pos_category_ids"])) if xx >= 0] + if len(x["pos_category_ids"]) > 0: + filtered_images.append(x) + + in_data["categories"] = categories + lvis_data["categories"] = categories + + if not args.not_save_imagenet: + in_out_path = args.imagenet_path[:-5] + f"_{args.mark}.json" + for k, v in in_data.items(): + print("imagenet", k, len(v)) + print("Saving Imagenet to", in_out_path) + json.dump(in_data, open(in_out_path, "w")) + + if not args.not_save_lvis: + lvis_out_path = args.lvis_path[:-5] + f"_{args.mark}.json" + for k, v in lvis_data.items(): + print("lvis", k, len(v)) + print("Saving LVIS to", lvis_out_path) + json.dump(lvis_data, open(lvis_out_path, "w")) + + if args.save_categories != "": + for x in categories: + for k in ["image_count", "instance_count", "synonyms", "def"]: + if k in x: + del x[k] + CATEGORIES = repr(categories) + " # noqa" + open(args.save_categories, "w").write(f"CATEGORIES = {CATEGORIES}") diff --git a/dimos/models/Detic/tools/download_cc.py b/dimos/models/Detic/tools/download_cc.py new file mode 100644 index 0000000000..ef7b4b0f7d --- /dev/null +++ b/dimos/models/Detic/tools/download_cc.py @@ -0,0 +1,45 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import argparse +import json +import os + +import numpy as np +from PIL import Image + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--ann", default="datasets/cc3m/Train_GCC-training.tsv") + parser.add_argument("--save_image_path", default="datasets/cc3m/training/") + parser.add_argument("--cat_info", default="datasets/lvis/lvis_v1_val.json") + parser.add_argument("--out_path", default="datasets/cc3m/train_image_info.json") + parser.add_argument("--not_download_image", action="store_true") + args = parser.parse_args() + categories = json.load(open(args.cat_info))["categories"] + images = [] + if not os.path.exists(args.save_image_path): + os.makedirs(args.save_image_path) + f = open(args.ann) + for i, line in enumerate(f): + cap, path = line[:-1].split("\t") + print(i, cap, path) + if not args.not_download_image: + os.system(f"wget {path} -O {args.save_image_path}/{i + 1}.jpg") + try: + img = Image.open(open(f"{args.save_image_path}/{i + 1}.jpg", "rb")) + img = np.asarray(img.convert("RGB")) + h, w = img.shape[:2] + except: + continue + image_info = { + "id": i + 1, + "file_name": f"{i + 1}.jpg", + "height": h, + "width": w, + "captions": [cap], + } + images.append(image_info) + data = {"categories": categories, "images": images, "annotations": []} + for k, v in data.items(): + print(k, len(v)) + print("Saving to", args.out_path) + json.dump(data, open(args.out_path, "w")) diff --git a/dimos/models/Detic/tools/dump_clip_features.py b/dimos/models/Detic/tools/dump_clip_features.py new file mode 100644 index 0000000000..31be161f6d --- /dev/null +++ b/dimos/models/Detic/tools/dump_clip_features.py @@ -0,0 +1,122 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import argparse +import itertools +import json + +from nltk.corpus import wordnet +import numpy as np +import torch + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--ann", default="datasets/lvis/lvis_v1_val.json") + parser.add_argument("--out_path", default="") + parser.add_argument("--prompt", default="a") + parser.add_argument("--model", default="clip") + parser.add_argument("--clip_model", default="ViT-B/32") + parser.add_argument("--fix_space", action="store_true") + parser.add_argument("--use_underscore", action="store_true") + parser.add_argument("--avg_synonyms", action="store_true") + parser.add_argument("--use_wn_name", action="store_true") + args = parser.parse_args() + + print("Loading", args.ann) + data = json.load(open(args.ann)) + cat_names = [x["name"] for x in sorted(data["categories"], key=lambda x: x["id"])] + if "synonyms" in data["categories"][0]: + if args.use_wn_name: + synonyms = [ + [xx.name() for xx in wordnet.synset(x["synset"]).lemmas()] + if x["synset"] != "stop_sign.n.01" + else ["stop_sign"] + for x in sorted(data["categories"], key=lambda x: x["id"]) + ] + else: + synonyms = [x["synonyms"] for x in sorted(data["categories"], key=lambda x: x["id"])] + else: + synonyms = [] + if args.fix_space: + cat_names = [x.replace("_", " ") for x in cat_names] + if args.use_underscore: + cat_names = [x.strip().replace("/ ", "/").replace(" ", "_") for x in cat_names] + print("cat_names", cat_names) + device = "cuda" if torch.cuda.is_available() else "cpu" + + if args.prompt == "a": + sentences = ["a " + x for x in cat_names] + sentences_synonyms = [["a " + xx for xx in x] for x in synonyms] + if args.prompt == "none": + sentences = [x for x in cat_names] + sentences_synonyms = [[xx for xx in x] for x in synonyms] + elif args.prompt == "photo": + sentences = [f"a photo of a {x}" for x in cat_names] + sentences_synonyms = [[f"a photo of a {xx}" for xx in x] for x in synonyms] + elif args.prompt == "scene": + sentences = [f"a photo of a {x} in the scene" for x in cat_names] + sentences_synonyms = [ + [f"a photo of a {xx} in the scene" for xx in x] for x in synonyms + ] + + print("sentences_synonyms", len(sentences_synonyms), sum(len(x) for x in sentences_synonyms)) + if args.model == "clip": + import clip + + print("Loading CLIP") + model, preprocess = clip.load(args.clip_model, device=device) + if args.avg_synonyms: + sentences = list(itertools.chain.from_iterable(sentences_synonyms)) + print("flattened_sentences", len(sentences)) + text = clip.tokenize(sentences).to(device) + with torch.no_grad(): + if len(text) > 10000: + text_features = torch.cat( + [ + model.encode_text(text[: len(text) // 2]), + model.encode_text(text[len(text) // 2 :]), + ], + dim=0, + ) + else: + text_features = model.encode_text(text) + print("text_features.shape", text_features.shape) + if args.avg_synonyms: + synonyms_per_cat = [len(x) for x in sentences_synonyms] + text_features = text_features.split(synonyms_per_cat, dim=0) + text_features = [x.mean(dim=0) for x in text_features] + text_features = torch.stack(text_features, dim=0) + print("after stack", text_features.shape) + text_features = text_features.cpu().numpy() + elif args.model in ["bert", "roberta"]: + from transformers import AutoModel, AutoTokenizer + + if args.model == "bert": + model_name = "bert-large-uncased" + if args.model == "roberta": + model_name = "roberta-large" + tokenizer = AutoTokenizer.from_pretrained(model_name) + model = AutoModel.from_pretrained(model_name) + model.eval() + if args.avg_synonyms: + sentences = list(itertools.chain.from_iterable(sentences_synonyms)) + print("flattened_sentences", len(sentences)) + inputs = tokenizer(sentences, padding=True, return_tensors="pt") + with torch.no_grad(): + model_outputs = model(**inputs) + outputs = model_outputs.pooler_output + text_features = outputs.detach().cpu() + if args.avg_synonyms: + synonyms_per_cat = [len(x) for x in sentences_synonyms] + text_features = text_features.split(synonyms_per_cat, dim=0) + text_features = [x.mean(dim=0) for x in text_features] + text_features = torch.stack(text_features, dim=0) + print("after stack", text_features.shape) + text_features = text_features.numpy() + print("text_features.shape", text_features.shape) + else: + assert 0, args.model + if args.out_path != "": + print("saveing to", args.out_path) + np.save(open(args.out_path, "wb"), text_features) + import pdb + + pdb.set_trace() diff --git a/dimos/models/Detic/tools/fix_o365_names.py b/dimos/models/Detic/tools/fix_o365_names.py new file mode 100644 index 0000000000..5aee27a14f --- /dev/null +++ b/dimos/models/Detic/tools/fix_o365_names.py @@ -0,0 +1,36 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import argparse +import copy +import json + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--ann", default="datasets/objects365/annotations/zhiyuan_objv2_val.json") + parser.add_argument("--fix_name_map", default="datasets/metadata/Objects365_names_fix.csv") + args = parser.parse_args() + + new_names = {} + old_names = {} + with open(args.fix_name_map) as f: + for line in f: + tmp = line.strip().split(",") + old_names[int(tmp[0])] = tmp[1] + new_names[int(tmp[0])] = tmp[2] + data = json.load(open(args.ann)) + + cat_info = copy.deepcopy(data["categories"]) + + for x in cat_info: + if old_names[x["id"]].strip() != x["name"].strip(): + print("{} {} {}".format(x, old_names[x["id"]], new_names[x["id"]])) + import pdb + + pdb.set_trace() + if old_names[x["id"]] != new_names[x["id"]]: + print("Renaming", x["id"], x["name"], new_names[x["id"]]) + x["name"] = new_names[x["id"]] + + data["categories"] = cat_info + out_name = args.ann[:-5] + "_fixname.json" + print("Saving to", out_name) + json.dump(data, open(out_name, "w")) diff --git a/dimos/models/Detic/tools/fix_o365_path.py b/dimos/models/Detic/tools/fix_o365_path.py new file mode 100644 index 0000000000..c43358fff0 --- /dev/null +++ b/dimos/models/Detic/tools/fix_o365_path.py @@ -0,0 +1,31 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import argparse +import json +import os + +import path + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--ann", default="datasets/objects365/annotations/zhiyuan_objv2_train_fixname.json" + ) + parser.add_argument("--img_dir", default="datasets/objects365/train/") + args = parser.parse_args() + + print("Loading", args.ann) + data = json.load(open(args.ann)) + images = [] + count = 0 + for x in data["images"]: + path = "{}/{}".format(args.img_dir, x["file_name"]) + if os.path.exists(path): + images.append(x) + else: + print(path) + count = count + 1 + print("Missing", count, "images") + data["images"] = images + out_name = args.ann[:-5] + "_fixmiss.json" + print("Saving to", out_name) + json.dump(data, open(out_name, "w")) diff --git a/dimos/models/Detic/tools/get_cc_tags.py b/dimos/models/Detic/tools/get_cc_tags.py new file mode 100644 index 0000000000..0a5cdab8ec --- /dev/null +++ b/dimos/models/Detic/tools/get_cc_tags.py @@ -0,0 +1,198 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import argparse +from collections import defaultdict +import json + +from detectron2.data.datasets.lvis_v1_categories import LVIS_CATEGORIES + +# This mapping is extracted from the official LVIS mapping: +# https://github.com/lvis-dataset/lvis-api/blob/master/data/coco_to_synset.json +COCO_SYNSET_CATEGORIES = [ + {"synset": "person.n.01", "coco_cat_id": 1}, + {"synset": "bicycle.n.01", "coco_cat_id": 2}, + {"synset": "car.n.01", "coco_cat_id": 3}, + {"synset": "motorcycle.n.01", "coco_cat_id": 4}, + {"synset": "airplane.n.01", "coco_cat_id": 5}, + {"synset": "bus.n.01", "coco_cat_id": 6}, + {"synset": "train.n.01", "coco_cat_id": 7}, + {"synset": "truck.n.01", "coco_cat_id": 8}, + {"synset": "boat.n.01", "coco_cat_id": 9}, + {"synset": "traffic_light.n.01", "coco_cat_id": 10}, + {"synset": "fireplug.n.01", "coco_cat_id": 11}, + {"synset": "stop_sign.n.01", "coco_cat_id": 13}, + {"synset": "parking_meter.n.01", "coco_cat_id": 14}, + {"synset": "bench.n.01", "coco_cat_id": 15}, + {"synset": "bird.n.01", "coco_cat_id": 16}, + {"synset": "cat.n.01", "coco_cat_id": 17}, + {"synset": "dog.n.01", "coco_cat_id": 18}, + {"synset": "horse.n.01", "coco_cat_id": 19}, + {"synset": "sheep.n.01", "coco_cat_id": 20}, + {"synset": "beef.n.01", "coco_cat_id": 21}, + {"synset": "elephant.n.01", "coco_cat_id": 22}, + {"synset": "bear.n.01", "coco_cat_id": 23}, + {"synset": "zebra.n.01", "coco_cat_id": 24}, + {"synset": "giraffe.n.01", "coco_cat_id": 25}, + {"synset": "backpack.n.01", "coco_cat_id": 27}, + {"synset": "umbrella.n.01", "coco_cat_id": 28}, + {"synset": "bag.n.04", "coco_cat_id": 31}, + {"synset": "necktie.n.01", "coco_cat_id": 32}, + {"synset": "bag.n.06", "coco_cat_id": 33}, + {"synset": "frisbee.n.01", "coco_cat_id": 34}, + {"synset": "ski.n.01", "coco_cat_id": 35}, + {"synset": "snowboard.n.01", "coco_cat_id": 36}, + {"synset": "ball.n.06", "coco_cat_id": 37}, + {"synset": "kite.n.03", "coco_cat_id": 38}, + {"synset": "baseball_bat.n.01", "coco_cat_id": 39}, + {"synset": "baseball_glove.n.01", "coco_cat_id": 40}, + {"synset": "skateboard.n.01", "coco_cat_id": 41}, + {"synset": "surfboard.n.01", "coco_cat_id": 42}, + {"synset": "tennis_racket.n.01", "coco_cat_id": 43}, + {"synset": "bottle.n.01", "coco_cat_id": 44}, + {"synset": "wineglass.n.01", "coco_cat_id": 46}, + {"synset": "cup.n.01", "coco_cat_id": 47}, + {"synset": "fork.n.01", "coco_cat_id": 48}, + {"synset": "knife.n.01", "coco_cat_id": 49}, + {"synset": "spoon.n.01", "coco_cat_id": 50}, + {"synset": "bowl.n.03", "coco_cat_id": 51}, + {"synset": "banana.n.02", "coco_cat_id": 52}, + {"synset": "apple.n.01", "coco_cat_id": 53}, + {"synset": "sandwich.n.01", "coco_cat_id": 54}, + {"synset": "orange.n.01", "coco_cat_id": 55}, + {"synset": "broccoli.n.01", "coco_cat_id": 56}, + {"synset": "carrot.n.01", "coco_cat_id": 57}, + # {"synset": "frank.n.02", "coco_cat_id": 58}, + {"synset": "sausage.n.01", "coco_cat_id": 58}, + {"synset": "pizza.n.01", "coco_cat_id": 59}, + {"synset": "doughnut.n.02", "coco_cat_id": 60}, + {"synset": "cake.n.03", "coco_cat_id": 61}, + {"synset": "chair.n.01", "coco_cat_id": 62}, + {"synset": "sofa.n.01", "coco_cat_id": 63}, + {"synset": "pot.n.04", "coco_cat_id": 64}, + {"synset": "bed.n.01", "coco_cat_id": 65}, + {"synset": "dining_table.n.01", "coco_cat_id": 67}, + {"synset": "toilet.n.02", "coco_cat_id": 70}, + {"synset": "television_receiver.n.01", "coco_cat_id": 72}, + {"synset": "laptop.n.01", "coco_cat_id": 73}, + {"synset": "mouse.n.04", "coco_cat_id": 74}, + {"synset": "remote_control.n.01", "coco_cat_id": 75}, + {"synset": "computer_keyboard.n.01", "coco_cat_id": 76}, + {"synset": "cellular_telephone.n.01", "coco_cat_id": 77}, + {"synset": "microwave.n.02", "coco_cat_id": 78}, + {"synset": "oven.n.01", "coco_cat_id": 79}, + {"synset": "toaster.n.02", "coco_cat_id": 80}, + {"synset": "sink.n.01", "coco_cat_id": 81}, + {"synset": "electric_refrigerator.n.01", "coco_cat_id": 82}, + {"synset": "book.n.01", "coco_cat_id": 84}, + {"synset": "clock.n.01", "coco_cat_id": 85}, + {"synset": "vase.n.01", "coco_cat_id": 86}, + {"synset": "scissors.n.01", "coco_cat_id": 87}, + {"synset": "teddy.n.01", "coco_cat_id": 88}, + {"synset": "hand_blower.n.01", "coco_cat_id": 89}, + {"synset": "toothbrush.n.01", "coco_cat_id": 90}, +] + + +def map_name(x): + x = x.replace("_", " ") + if "(" in x: + x = x[: x.find("(")] + return x.lower().strip() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--cc_ann", default="datasets/cc3m/train_image_info.json") + parser.add_argument("--out_path", default="datasets/cc3m/train_image_info_tags.json") + parser.add_argument("--keep_images", action="store_true") + parser.add_argument("--allcaps", action="store_true") + parser.add_argument("--cat_path", default="") + parser.add_argument("--convert_caption", action="store_true") + # parser.add_argument('--lvis_ann', default='datasets/lvis/lvis_v1_val.json') + args = parser.parse_args() + + # lvis_data = json.load(open(args.lvis_ann, 'r')) + cc_data = json.load(open(args.cc_ann)) + if args.convert_caption: + num_caps = 0 + caps = defaultdict(list) + for x in cc_data["annotations"]: + caps[x["image_id"]].append(x["caption"]) + for x in cc_data["images"]: + x["captions"] = caps[x["id"]] + num_caps += len(x["captions"]) + print("# captions", num_caps) + + if args.cat_path != "": + print("Loading", args.cat_path) + cats = json.load(open(args.cat_path))["categories"] + if "synonyms" not in cats[0]: + cocoid2synset = {x["coco_cat_id"]: x["synset"] for x in COCO_SYNSET_CATEGORIES} + synset2synonyms = {x["synset"]: x["synonyms"] for x in LVIS_CATEGORIES} + for x in cats: + synonyms = synset2synonyms[cocoid2synset[x["id"]]] + x["synonyms"] = synonyms + x["frequency"] = "f" + cc_data["categories"] = cats + + id2cat = {x["id"]: x for x in cc_data["categories"]} + class_count = {x["id"]: 0 for x in cc_data["categories"]} + class_data = { + x["id"]: [" " + map_name(xx) + " " for xx in x["synonyms"]] for x in cc_data["categories"] + } + num_examples = 5 + examples = {x["id"]: [] for x in cc_data["categories"]} + + print("class_data", class_data) + + images = [] + for i, x in enumerate(cc_data["images"]): + if i % 10000 == 0: + print(i, len(cc_data["images"])) + if args.allcaps: + caption = (" ".join(x["captions"])).lower() + else: + caption = x["captions"][0].lower() + x["pos_category_ids"] = [] + for cat_id, cat_names in class_data.items(): + find = False + for c in cat_names: + if c in caption or caption.startswith(c[1:]) or caption.endswith(c[:-1]): + find = True + break + if find: + x["pos_category_ids"].append(cat_id) + class_count[cat_id] += 1 + if len(examples[cat_id]) < num_examples: + examples[cat_id].append(caption) + if len(x["pos_category_ids"]) > 0 or args.keep_images: + images.append(x) + + zero_class = [] + for cat_id, count in class_count.items(): + print(id2cat[cat_id]["name"], count, end=", ") + if count == 0: + zero_class.append(id2cat[cat_id]) + print("==") + print("zero class", zero_class) + + # for freq in ['r', 'c', 'f']: + # print('#cats', freq, len([x for x in cc_data['categories'] \ + # if x['frequency'] == freq] and class_count[x['id']] > 0)) + + for freq in ["r", "c", "f"]: + print( + "#Images", + freq, + sum([v for k, v in class_count.items() if id2cat[k]["frequency"] == freq]), + ) + + try: + out_data = {"images": images, "categories": cc_data["categories"], "annotations": []} + for k, v in out_data.items(): + print(k, len(v)) + if args.keep_images and not args.out_path.endswith("_full.json"): + args.out_path = args.out_path[:-5] + "_full.json" + print("Writing to", args.out_path) + json.dump(out_data, open(args.out_path, "w")) + except: + pass diff --git a/dimos/models/Detic/tools/get_coco_zeroshot_oriorder.py b/dimos/models/Detic/tools/get_coco_zeroshot_oriorder.py new file mode 100644 index 0000000000..688b0a92e5 --- /dev/null +++ b/dimos/models/Detic/tools/get_coco_zeroshot_oriorder.py @@ -0,0 +1,20 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import argparse +import json + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--data_path", default="datasets/coco/annotations/instances_val2017_unseen_2.json" + ) + parser.add_argument("--cat_path", default="datasets/coco/annotations/instances_val2017.json") + args = parser.parse_args() + print("Loading", args.cat_path) + cat = json.load(open(args.cat_path))["categories"] + + print("Loading", args.data_path) + data = json.load(open(args.data_path)) + data["categories"] = cat + out_path = args.data_path[:-5] + "_oriorder.json" + print("Saving to", out_path) + json.dump(data, open(out_path, "w")) diff --git a/dimos/models/Detic/tools/get_imagenet_21k_full_tar_json.py b/dimos/models/Detic/tools/get_imagenet_21k_full_tar_json.py new file mode 100644 index 0000000000..00502db11f --- /dev/null +++ b/dimos/models/Detic/tools/get_imagenet_21k_full_tar_json.py @@ -0,0 +1,82 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import argparse +import json +import operator +import sys +import time + +from nltk.corpus import wordnet +import numpy as np +import torch +from tqdm import tqdm + +sys.path.insert(0, "third_party/CenterNet2/") +sys.path.insert(0, "third_party/Deformable-DETR") +from detic.data.tar_dataset import DiskTarDataset + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--imagenet_dir", default="datasets/imagenet/ImageNet-21k/") + parser.add_argument("--tarfile_path", default="datasets/imagenet/metadata-22k/tar_files.npy") + parser.add_argument("--tar_index_dir", default="datasets/imagenet/metadata-22k/tarindex_npy") + parser.add_argument( + "--out_path", default="datasets/imagenet/annotations/imagenet-22k_image_info.json" + ) + parser.add_argument("--workers", default=16, type=int) + args = parser.parse_args() + + start_time = time.time() + print("Building dataset") + dataset = DiskTarDataset(args.tarfile_path, args.tar_index_dir) + end_time = time.time() + print(f"Took {end_time - start_time} seconds to make the dataset.") + print(f"Have {len(dataset)} samples.") + print("dataset", dataset) + + tar_files = np.load(args.tarfile_path) + categories = [] + for i, tar_file in enumerate(tar_files): + wnid = tar_file[-13:-4] + synset = wordnet.synset_from_pos_and_offset("n", int(wnid[1:])) + synonyms = [x.name() for x in synset.lemmas()] + category = { + "id": i + 1, + "synset": synset.name(), + "name": synonyms[0], + "def": synset.definition(), + "synonyms": synonyms, + } + categories.append(category) + print("categories", len(categories)) + + data_loader = torch.utils.data.DataLoader( + dataset, + batch_size=1, + shuffle=False, + num_workers=args.workers, + collate_fn=operator.itemgetter(0), + ) + images = [] + for img, label, index in tqdm(data_loader): + if label == -1: + continue + image = { + "id": int(index) + 1, + "pos_category_ids": [int(label) + 1], + "height": int(img.height), + "width": int(img.width), + "tar_index": int(index), + } + images.append(image) + + data = {"categories": categories, "images": images, "annotations": []} + try: + for k, v in data.items(): + print(k, len(v)) + print("Saving to ", args.out_path) + json.dump(data, open(args.out_path, "w")) + except: + pass + import pdb + + pdb.set_trace() diff --git a/dimos/models/Detic/tools/get_lvis_cat_info.py b/dimos/models/Detic/tools/get_lvis_cat_info.py new file mode 100644 index 0000000000..414a615b8a --- /dev/null +++ b/dimos/models/Detic/tools/get_lvis_cat_info.py @@ -0,0 +1,43 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import argparse +import json + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--ann", default="datasets/lvis/lvis_v1_train.json") + parser.add_argument("--add_freq", action="store_true") + parser.add_argument("--r_thresh", type=int, default=10) + parser.add_argument("--c_thresh", type=int, default=100) + args = parser.parse_args() + + print("Loading", args.ann) + data = json.load(open(args.ann)) + cats = data["categories"] + image_count = {x["id"]: set() for x in cats} + ann_count = {x["id"]: 0 for x in cats} + for x in data["annotations"]: + image_count[x["category_id"]].add(x["image_id"]) + ann_count[x["category_id"]] += 1 + num_freqs = {x: 0 for x in ["r", "f", "c"]} + for x in cats: + x["image_count"] = len(image_count[x["id"]]) + x["instance_count"] = ann_count[x["id"]] + if args.add_freq: + freq = "f" + if x["image_count"] < args.c_thresh: + freq = "c" + if x["image_count"] < args.r_thresh: + freq = "r" + x["frequency"] = freq + num_freqs[freq] += 1 + print(cats) + image_counts = sorted([x["image_count"] for x in cats]) + # print('image count', image_counts) + # import pdb; pdb.set_trace() + if args.add_freq: + for x in ["r", "c", "f"]: + print(x, num_freqs[x]) + out = cats # {'categories': cats} + out_path = args.ann[:-5] + "_cat_info.json" + print("Saving to", out_path) + json.dump(out, open(out_path, "w")) diff --git a/dimos/models/Detic/tools/merge_lvis_coco.py b/dimos/models/Detic/tools/merge_lvis_coco.py new file mode 100644 index 0000000000..1a76a02f0b --- /dev/null +++ b/dimos/models/Detic/tools/merge_lvis_coco.py @@ -0,0 +1,206 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from collections import defaultdict +import json + +from detectron2.structures import Boxes, pairwise_iou +import torch + +COCO_PATH = "datasets/coco/annotations/instances_train2017.json" +IMG_PATH = "datasets/coco/train2017/" +LVIS_PATH = "datasets/lvis/lvis_v1_train.json" +NO_SEG = False +if NO_SEG: + SAVE_PATH = "datasets/lvis/lvis_v1_train+coco_box.json" +else: + SAVE_PATH = "datasets/lvis/lvis_v1_train+coco_mask.json" +THRESH = 0.7 +DEBUG = False + +# This mapping is extracted from the official LVIS mapping: +# https://github.com/lvis-dataset/lvis-api/blob/master/data/coco_to_synset.json +COCO_SYNSET_CATEGORIES = [ + {"synset": "person.n.01", "coco_cat_id": 1}, + {"synset": "bicycle.n.01", "coco_cat_id": 2}, + {"synset": "car.n.01", "coco_cat_id": 3}, + {"synset": "motorcycle.n.01", "coco_cat_id": 4}, + {"synset": "airplane.n.01", "coco_cat_id": 5}, + {"synset": "bus.n.01", "coco_cat_id": 6}, + {"synset": "train.n.01", "coco_cat_id": 7}, + {"synset": "truck.n.01", "coco_cat_id": 8}, + {"synset": "boat.n.01", "coco_cat_id": 9}, + {"synset": "traffic_light.n.01", "coco_cat_id": 10}, + {"synset": "fireplug.n.01", "coco_cat_id": 11}, + {"synset": "stop_sign.n.01", "coco_cat_id": 13}, + {"synset": "parking_meter.n.01", "coco_cat_id": 14}, + {"synset": "bench.n.01", "coco_cat_id": 15}, + {"synset": "bird.n.01", "coco_cat_id": 16}, + {"synset": "cat.n.01", "coco_cat_id": 17}, + {"synset": "dog.n.01", "coco_cat_id": 18}, + {"synset": "horse.n.01", "coco_cat_id": 19}, + {"synset": "sheep.n.01", "coco_cat_id": 20}, + {"synset": "beef.n.01", "coco_cat_id": 21}, + {"synset": "elephant.n.01", "coco_cat_id": 22}, + {"synset": "bear.n.01", "coco_cat_id": 23}, + {"synset": "zebra.n.01", "coco_cat_id": 24}, + {"synset": "giraffe.n.01", "coco_cat_id": 25}, + {"synset": "backpack.n.01", "coco_cat_id": 27}, + {"synset": "umbrella.n.01", "coco_cat_id": 28}, + {"synset": "bag.n.04", "coco_cat_id": 31}, + {"synset": "necktie.n.01", "coco_cat_id": 32}, + {"synset": "bag.n.06", "coco_cat_id": 33}, + {"synset": "frisbee.n.01", "coco_cat_id": 34}, + {"synset": "ski.n.01", "coco_cat_id": 35}, + {"synset": "snowboard.n.01", "coco_cat_id": 36}, + {"synset": "ball.n.06", "coco_cat_id": 37}, + {"synset": "kite.n.03", "coco_cat_id": 38}, + {"synset": "baseball_bat.n.01", "coco_cat_id": 39}, + {"synset": "baseball_glove.n.01", "coco_cat_id": 40}, + {"synset": "skateboard.n.01", "coco_cat_id": 41}, + {"synset": "surfboard.n.01", "coco_cat_id": 42}, + {"synset": "tennis_racket.n.01", "coco_cat_id": 43}, + {"synset": "bottle.n.01", "coco_cat_id": 44}, + {"synset": "wineglass.n.01", "coco_cat_id": 46}, + {"synset": "cup.n.01", "coco_cat_id": 47}, + {"synset": "fork.n.01", "coco_cat_id": 48}, + {"synset": "knife.n.01", "coco_cat_id": 49}, + {"synset": "spoon.n.01", "coco_cat_id": 50}, + {"synset": "bowl.n.03", "coco_cat_id": 51}, + {"synset": "banana.n.02", "coco_cat_id": 52}, + {"synset": "apple.n.01", "coco_cat_id": 53}, + {"synset": "sandwich.n.01", "coco_cat_id": 54}, + {"synset": "orange.n.01", "coco_cat_id": 55}, + {"synset": "broccoli.n.01", "coco_cat_id": 56}, + {"synset": "carrot.n.01", "coco_cat_id": 57}, + # {"synset": "frank.n.02", "coco_cat_id": 58}, + {"synset": "sausage.n.01", "coco_cat_id": 58}, + {"synset": "pizza.n.01", "coco_cat_id": 59}, + {"synset": "doughnut.n.02", "coco_cat_id": 60}, + {"synset": "cake.n.03", "coco_cat_id": 61}, + {"synset": "chair.n.01", "coco_cat_id": 62}, + {"synset": "sofa.n.01", "coco_cat_id": 63}, + {"synset": "pot.n.04", "coco_cat_id": 64}, + {"synset": "bed.n.01", "coco_cat_id": 65}, + {"synset": "dining_table.n.01", "coco_cat_id": 67}, + {"synset": "toilet.n.02", "coco_cat_id": 70}, + {"synset": "television_receiver.n.01", "coco_cat_id": 72}, + {"synset": "laptop.n.01", "coco_cat_id": 73}, + {"synset": "mouse.n.04", "coco_cat_id": 74}, + {"synset": "remote_control.n.01", "coco_cat_id": 75}, + {"synset": "computer_keyboard.n.01", "coco_cat_id": 76}, + {"synset": "cellular_telephone.n.01", "coco_cat_id": 77}, + {"synset": "microwave.n.02", "coco_cat_id": 78}, + {"synset": "oven.n.01", "coco_cat_id": 79}, + {"synset": "toaster.n.02", "coco_cat_id": 80}, + {"synset": "sink.n.01", "coco_cat_id": 81}, + {"synset": "electric_refrigerator.n.01", "coco_cat_id": 82}, + {"synset": "book.n.01", "coco_cat_id": 84}, + {"synset": "clock.n.01", "coco_cat_id": 85}, + {"synset": "vase.n.01", "coco_cat_id": 86}, + {"synset": "scissors.n.01", "coco_cat_id": 87}, + {"synset": "teddy.n.01", "coco_cat_id": 88}, + {"synset": "hand_blower.n.01", "coco_cat_id": 89}, + {"synset": "toothbrush.n.01", "coco_cat_id": 90}, +] + + +def get_bbox(ann): + bbox = ann["bbox"] + return [bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]] + + +if __name__ == "__main__": + file_name_key = "file_name" if "v0.5" in LVIS_PATH else "coco_url" + coco_data = json.load(open(COCO_PATH)) + lvis_data = json.load(open(LVIS_PATH)) + + coco_cats = coco_data["categories"] + lvis_cats = lvis_data["categories"] + + num_find = 0 + num_not_find = 0 + num_twice = 0 + coco2lviscats = {} + synset2lvisid = {x["synset"]: x["id"] for x in lvis_cats} + # cocoid2synset = {x['coco_cat_id']: x['synset'] for x in COCO_SYNSET_CATEGORIES} + coco2lviscats = { + x["coco_cat_id"]: synset2lvisid[x["synset"]] + for x in COCO_SYNSET_CATEGORIES + if x["synset"] in synset2lvisid + } + print(len(coco2lviscats)) + + lvis_file2id = {x[file_name_key][-16:]: x["id"] for x in lvis_data["images"]} + lvis_id2img = {x["id"]: x for x in lvis_data["images"]} + lvis_catid2name = {x["id"]: x["name"] for x in lvis_data["categories"]} + + coco_file2anns = {} + coco_id2img = {x["id"]: x for x in coco_data["images"]} + coco_img2anns = defaultdict(list) + for ann in coco_data["annotations"]: + coco_img = coco_id2img[ann["image_id"]] + file_name = coco_img["file_name"][-16:] + if ann["category_id"] in coco2lviscats and file_name in lvis_file2id: + lvis_image_id = lvis_file2id[file_name] + lvis_image = lvis_id2img[lvis_image_id] + lvis_cat_id = coco2lviscats[ann["category_id"]] + if lvis_cat_id in lvis_image["neg_category_ids"]: + continue + if DEBUG: + import cv2 + + img_path = IMG_PATH + file_name + img = cv2.imread(img_path) + print(lvis_catid2name[lvis_cat_id]) + print("neg", [lvis_catid2name[x] for x in lvis_image["neg_category_ids"]]) + cv2.imshow("img", img) + cv2.waitKey() + ann["category_id"] = lvis_cat_id + ann["image_id"] = lvis_image_id + coco_img2anns[file_name].append(ann) + + lvis_img2anns = defaultdict(list) + for ann in lvis_data["annotations"]: + lvis_img = lvis_id2img[ann["image_id"]] + file_name = lvis_img[file_name_key][-16:] + lvis_img2anns[file_name].append(ann) + + ann_id_count = 0 + anns = [] + for file_name in lvis_img2anns: + coco_anns = coco_img2anns[file_name] + lvis_anns = lvis_img2anns[file_name] + ious = pairwise_iou( + Boxes(torch.tensor([get_bbox(x) for x in coco_anns])), + Boxes(torch.tensor([get_bbox(x) for x in lvis_anns])), + ) + + for ann in lvis_anns: + ann_id_count = ann_id_count + 1 + ann["id"] = ann_id_count + anns.append(ann) + + for i, ann in enumerate(coco_anns): + if len(ious[i]) == 0 or ious[i].max() < THRESH: + ann_id_count = ann_id_count + 1 + ann["id"] = ann_id_count + anns.append(ann) + else: + duplicated = False + for j in range(len(ious[i])): + if ( + ious[i, j] >= THRESH + and coco_anns[i]["category_id"] == lvis_anns[j]["category_id"] + ): + duplicated = True + if not duplicated: + ann_id_count = ann_id_count + 1 + ann["id"] = ann_id_count + anns.append(ann) + if NO_SEG: + for ann in anns: + del ann["segmentation"] + lvis_data["annotations"] = anns + + print("# Images", len(lvis_data["images"])) + print("# Anns", len(lvis_data["annotations"])) + json.dump(lvis_data, open(SAVE_PATH, "w")) diff --git a/dimos/models/Detic/tools/preprocess_imagenet22k.py b/dimos/models/Detic/tools/preprocess_imagenet22k.py new file mode 100644 index 0000000000..edf2d2bbf7 --- /dev/null +++ b/dimos/models/Detic/tools/preprocess_imagenet22k.py @@ -0,0 +1,147 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. + +import os +import sys + +import numpy as np + +sys.path.insert(0, "third_party/CenterNet2/") +sys.path.insert(0, "third_party/Deformable-DETR") +import gzip +import io +import time + +from detic.data.tar_dataset import _TarDataset + + +class _RawTarDataset: + def __init__(self, filename, indexname: str, preload: bool=False) -> None: + self.filename = filename + self.names = [] + self.offsets = [] + + for l in open(indexname): + ll = l.split() + a, b, c = ll[:3] + offset = int(b[:-1]) + if l.endswith("** Block of NULs **\n"): + self.offsets.append(offset) + break + else: + if c.endswith("JPEG"): + self.names.append(c) + self.offsets.append(offset) + else: + # ignore directories + pass + if preload: + self.data = np.memmap(filename, mode="r", dtype="uint8") + else: + self.data = None + + def __len__(self) -> int: + return len(self.names) + + def __getitem__(self, idx: int): + if self.data is None: + self.data = np.memmap(self.filename, mode="r", dtype="uint8") + ofs = self.offsets[idx] * 512 + fsize = 512 * (self.offsets[idx + 1] - self.offsets[idx]) + data = self.data[ofs : ofs + fsize] + + if data[:13].tostring() == "././@LongLink": + data = data[3 * 512 :] + else: + data = data[512:] + + # just to make it more fun a few JPEGs are GZIP compressed... + # catch this case + if tuple(data[:2]) == (0x1F, 0x8B): + s = io.StringIO(data.tostring()) + g = gzip.GzipFile(None, "r", 0, s) + sdata = g.read() + else: + sdata = data.tostring() + return sdata + + +def preprocess() -> None: + # Follow https://github.com/Alibaba-MIIL/ImageNet21K/blob/main/dataset_preprocessing/processing_script.sh + # Expect 12358684 samples with 11221 classes + # ImageNet folder has 21841 classes (synsets) + + i22kdir = "/datasets01/imagenet-22k/062717/" + i22ktarlogs = "/checkpoint/imisra/datasets/imagenet-22k/tarindex" + class_names_file = "/checkpoint/imisra/datasets/imagenet-22k/words.txt" + + output_dir = "/checkpoint/zhouxy/Datasets/ImageNet/metadata-22k/" + i22knpytarlogs = "/checkpoint/zhouxy/Datasets/ImageNet/metadata-22k/tarindex_npy" + print("Listing dir") + log_files = os.listdir(i22ktarlogs) + log_files = [x for x in log_files if x.endswith(".tarlog")] + log_files.sort() + dataset_lens = [] + min_count = 0 + create_npy_tarlogs = True + print("Creating folders") + if create_npy_tarlogs: + os.makedirs(i22knpytarlogs, exist_ok=True) + for log_file in log_files: + syn = log_file.replace(".tarlog", "") + dataset = _RawTarDataset( + os.path.join(i22kdir, syn + ".tar"), + os.path.join(i22ktarlogs, syn + ".tarlog"), + preload=False, + ) + names = np.array(dataset.names) + offsets = np.array(dataset.offsets, dtype=np.int64) + np.save(os.path.join(i22knpytarlogs, f"{syn}_names.npy"), names) + np.save(os.path.join(i22knpytarlogs, f"{syn}_offsets.npy"), offsets) + + os.makedirs(output_dir, exist_ok=True) + + start_time = time.time() + for log_file in log_files: + syn = log_file.replace(".tarlog", "") + dataset = _TarDataset(os.path.join(i22kdir, syn + ".tar"), i22knpytarlogs) + # dataset = _RawTarDataset(os.path.join(i22kdir, syn + ".tar"), + # os.path.join(i22ktarlogs, syn + ".tarlog"), + # preload=False) + dataset_lens.append(len(dataset)) + end_time = time.time() + print(f"Time {end_time - start_time}") + + dataset_lens = np.array(dataset_lens) + dataset_valid = dataset_lens > min_count + + syn2class = {} + with open(class_names_file) as fh: + for line in fh: + line = line.strip().split("\t") + syn2class[line[0]] = line[1] + + tarlog_files = [] + class_names = [] + tar_files = [] + for k in range(len(dataset_valid)): + if not dataset_valid[k]: + continue + syn = log_files[k].replace(".tarlog", "") + tarlog_files.append(os.path.join(i22ktarlogs, syn + ".tarlog")) + tar_files.append(os.path.join(i22kdir, syn + ".tar")) + class_names.append(syn2class[syn]) + + tarlog_files = np.array(tarlog_files) + tar_files = np.array(tar_files) + class_names = np.array(class_names) + print(f"Have {len(class_names)} classes and {dataset_lens[dataset_valid].sum()} samples") + + np.save(os.path.join(output_dir, "tarlog_files.npy"), tarlog_files) + np.save(os.path.join(output_dir, "tar_files.npy"), tar_files) + np.save(os.path.join(output_dir, "class_names.npy"), class_names) + np.save(os.path.join(output_dir, "tar_files.npy"), tar_files) + + +if __name__ == "__main__": + preprocess() diff --git a/dimos/models/Detic/tools/remove_lvis_rare.py b/dimos/models/Detic/tools/remove_lvis_rare.py new file mode 100644 index 0000000000..423dd6e6e2 --- /dev/null +++ b/dimos/models/Detic/tools/remove_lvis_rare.py @@ -0,0 +1,21 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import argparse +import json + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--ann", default="datasets/lvis/lvis_v1_train.json") + args = parser.parse_args() + + print("Loading", args.ann) + data = json.load(open(args.ann)) + catid2freq = {x["id"]: x["frequency"] for x in data["categories"]} + print("ori #anns", len(data["annotations"])) + exclude = ["r"] + data["annotations"] = [ + x for x in data["annotations"] if catid2freq[x["category_id"]] not in exclude + ] + print("filtered #anns", len(data["annotations"])) + out_path = args.ann[:-5] + "_norare.json" + print("Saving to", out_path) + json.dump(data, open(out_path, "w")) diff --git a/dimos/models/Detic/tools/unzip_imagenet_lvis.py b/dimos/models/Detic/tools/unzip_imagenet_lvis.py new file mode 100644 index 0000000000..fd969c28bb --- /dev/null +++ b/dimos/models/Detic/tools/unzip_imagenet_lvis.py @@ -0,0 +1,18 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import argparse +import os + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--src_path", default="datasets/imagenet/ImageNet-21K/") + parser.add_argument("--dst_path", default="datasets/imagenet/ImageNet-LVIS/") + parser.add_argument("--data_path", default="datasets/imagenet_lvis_wnid.txt") + args = parser.parse_args() + + f = open(args.data_path) + for i, line in enumerate(f): + cmd = "mkdir {x} && tar -xf {src}/{l}.tar -C {x}".format( + src=args.src_path, l=line.strip(), x=args.dst_path + "/" + line.strip() + ) + print(i, cmd) + os.system(cmd) diff --git a/dimos/models/Detic/train_net.py b/dimos/models/Detic/train_net.py new file mode 100644 index 0000000000..54ab6136f4 --- /dev/null +++ b/dimos/models/Detic/train_net.py @@ -0,0 +1,266 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from collections import OrderedDict +import datetime +import logging +import os +import sys +import time + +from detectron2.checkpoint import DetectionCheckpointer, PeriodicCheckpointer +from detectron2.config import get_cfg +from detectron2.data import ( + MetadataCatalog, + build_detection_test_loader, +) +from detectron2.data.build import build_detection_train_loader +from detectron2.data.dataset_mapper import DatasetMapper +from detectron2.engine import default_argument_parser, default_setup, launch +from detectron2.evaluation import ( + COCOEvaluator, + LVISEvaluator, + inference_on_dataset, + print_csv_format, +) +from detectron2.modeling import build_model +from detectron2.solver import build_lr_scheduler, build_optimizer +import detectron2.utils.comm as comm +from detectron2.utils.events import ( + CommonMetricPrinter, + EventStorage, + JSONWriter, + TensorboardXWriter, +) +from detectron2.utils.logger import setup_logger +from fvcore.common.timer import Timer +import torch +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel + +sys.path.insert(0, "third_party/CenterNet2/") +from centernet.config import add_centernet_config + +sys.path.insert(0, "third_party/Deformable-DETR") +from detic.config import add_detic_config +from detic.custom_solver import build_custom_optimizer +from detic.data.custom_build_augmentation import build_custom_augmentation +from detic.data.custom_dataset_dataloader import build_custom_train_loader +from detic.data.custom_dataset_mapper import CustomDatasetMapper, DetrDatasetMapper +from detic.evaluation.custom_coco_eval import CustomCOCOEvaluator +from detic.evaluation.oideval import OIDEvaluator +from detic.modeling.utils import reset_cls_test + +logger = logging.getLogger("detectron2") + + +def do_test(cfg, model): + results = OrderedDict() + for d, dataset_name in enumerate(cfg.DATASETS.TEST): + if cfg.MODEL.RESET_CLS_TESTS: + reset_cls_test(model, cfg.MODEL.TEST_CLASSIFIERS[d], cfg.MODEL.TEST_NUM_CLASSES[d]) + mapper = ( + None + if cfg.INPUT.TEST_INPUT_TYPE == "default" + else DatasetMapper(cfg, False, augmentations=build_custom_augmentation(cfg, False)) + ) + data_loader = build_detection_test_loader(cfg, dataset_name, mapper=mapper) + output_folder = os.path.join(cfg.OUTPUT_DIR, f"inference_{dataset_name}") + evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type + + if evaluator_type == "lvis" or cfg.GEN_PSEDO_LABELS: + evaluator = LVISEvaluator(dataset_name, cfg, True, output_folder) + elif evaluator_type == "coco": + if dataset_name == "coco_generalized_zeroshot_val": + # Additionally plot mAP for 'seen classes' and 'unseen classes' + evaluator = CustomCOCOEvaluator(dataset_name, cfg, True, output_folder) + else: + evaluator = COCOEvaluator(dataset_name, cfg, True, output_folder) + elif evaluator_type == "oid": + evaluator = OIDEvaluator(dataset_name, cfg, True, output_folder) + else: + assert 0, evaluator_type + + results[dataset_name] = inference_on_dataset(model, data_loader, evaluator) + if comm.is_main_process(): + logger.info(f"Evaluation results for {dataset_name} in csv format:") + print_csv_format(results[dataset_name]) + if len(results) == 1: + results = next(iter(results.values())) + return results + + +def do_train(cfg, model, resume: bool=False) -> None: + model.train() + if cfg.SOLVER.USE_CUSTOM_SOLVER: + optimizer = build_custom_optimizer(cfg, model) + else: + assert cfg.SOLVER.OPTIMIZER == "SGD" + assert cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE != "full_model" + assert cfg.SOLVER.BACKBONE_MULTIPLIER == 1.0 + optimizer = build_optimizer(cfg, model) + scheduler = build_lr_scheduler(cfg, optimizer) + + checkpointer = DetectionCheckpointer( + model, cfg.OUTPUT_DIR, optimizer=optimizer, scheduler=scheduler + ) + + start_iter = ( + checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1 + ) + if not resume: + start_iter = 0 + max_iter = cfg.SOLVER.MAX_ITER if cfg.SOLVER.TRAIN_ITER < 0 else cfg.SOLVER.TRAIN_ITER + + periodic_checkpointer = PeriodicCheckpointer( + checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter + ) + + writers = ( + [ + CommonMetricPrinter(max_iter), + JSONWriter(os.path.join(cfg.OUTPUT_DIR, "metrics.json")), + TensorboardXWriter(cfg.OUTPUT_DIR), + ] + if comm.is_main_process() + else [] + ) + + use_custom_mapper = cfg.WITH_IMAGE_LABELS + MapperClass = CustomDatasetMapper if use_custom_mapper else DatasetMapper + mapper = ( + MapperClass(cfg, True) + if cfg.INPUT.CUSTOM_AUG == "" + else DetrDatasetMapper(cfg, True) + if cfg.INPUT.CUSTOM_AUG == "DETR" + else MapperClass(cfg, True, augmentations=build_custom_augmentation(cfg, True)) + ) + if cfg.DATALOADER.SAMPLER_TRAIN in ["TrainingSampler", "RepeatFactorTrainingSampler"]: + data_loader = build_detection_train_loader(cfg, mapper=mapper) + else: + data_loader = build_custom_train_loader(cfg, mapper=mapper) + + if cfg.FP16: + scaler = GradScaler() + + logger.info(f"Starting training from iteration {start_iter}") + with EventStorage(start_iter) as storage: + step_timer = Timer() + data_timer = Timer() + start_time = time.perf_counter() + for data, iteration in zip(data_loader, range(start_iter, max_iter), strict=False): + data_time = data_timer.seconds() + storage.put_scalars(data_time=data_time) + step_timer.reset() + iteration = iteration + 1 + storage.step() + loss_dict = model(data) + + losses = sum(loss for k, loss in loss_dict.items()) + assert torch.isfinite(losses).all(), loss_dict + + loss_dict_reduced = {k: v.item() for k, v in comm.reduce_dict(loss_dict).items()} + losses_reduced = sum(loss for loss in loss_dict_reduced.values()) + if comm.is_main_process(): + storage.put_scalars(total_loss=losses_reduced, **loss_dict_reduced) + + optimizer.zero_grad() + if cfg.FP16: + scaler.scale(losses).backward() + scaler.step(optimizer) + scaler.update() + else: + losses.backward() + optimizer.step() + + storage.put_scalar("lr", optimizer.param_groups[0]["lr"], smoothing_hint=False) + + step_time = step_timer.seconds() + storage.put_scalars(time=step_time) + data_timer.reset() + scheduler.step() + + if ( + cfg.TEST.EVAL_PERIOD > 0 + and iteration % cfg.TEST.EVAL_PERIOD == 0 + and iteration != max_iter + ): + do_test(cfg, model) + comm.synchronize() + + if iteration - start_iter > 5 and (iteration % 20 == 0 or iteration == max_iter): + for writer in writers: + writer.write() + periodic_checkpointer.step(iteration) + + total_time = time.perf_counter() - start_time + logger.info( + f"Total training time: {datetime.timedelta(seconds=int(total_time))!s}" + ) + + +def setup(args): + """ + Create configs and perform basic setups. + """ + cfg = get_cfg() + add_centernet_config(cfg) + add_detic_config(cfg) + cfg.merge_from_file(args.config_file) + cfg.merge_from_list(args.opts) + if "/auto" in cfg.OUTPUT_DIR: + file_name = os.path.basename(args.config_file)[:-5] + cfg.OUTPUT_DIR = cfg.OUTPUT_DIR.replace("/auto", f"/{file_name}") + logger.info(f"OUTPUT_DIR: {cfg.OUTPUT_DIR}") + cfg.freeze() + default_setup(cfg, args) + setup_logger(output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="detic") + return cfg + + +def main(args): + cfg = setup(args) + + model = build_model(cfg) + logger.info(f"Model:\n{model}") + if args.eval_only: + DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( + cfg.MODEL.WEIGHTS, resume=args.resume + ) + + return do_test(cfg, model) + + distributed = comm.get_world_size() > 1 + if distributed: + model = DistributedDataParallel( + model, + device_ids=[comm.get_local_rank()], + broadcast_buffers=False, + find_unused_parameters=cfg.FIND_UNUSED_PARAM, + ) + + do_train(cfg, model, resume=args.resume) + return do_test(cfg, model) + + +if __name__ == "__main__": + args = default_argument_parser() + args = args.parse_args() + if args.num_machines == 1: + args.dist_url = f"tcp://127.0.0.1:{torch.randint(11111, 60000, (1,))[0].item()}" + else: + if args.dist_url == "host": + args.dist_url = "tcp://{}:12345".format(os.environ["SLURM_JOB_NODELIST"]) + elif not args.dist_url.startswith("tcp"): + tmp = os.popen( + f"echo $(scontrol show job {args.dist_url} | grep BatchHost)" + ).read() + tmp = tmp[tmp.find("=") + 1 : -1] + args.dist_url = f"tcp://{tmp}:12345" + print("Command Line Args:", args) + launch( + main, + args.num_gpus, + num_machines=args.num_machines, + machine_rank=args.machine_rank, + dist_url=args.dist_url, + args=(args,), + ) diff --git a/dimos/models/depth/metric3d.py b/dimos/models/depth/metric3d.py index c489e6daa5..0da829aa92 100644 --- a/dimos/models/depth/metric3d.py +++ b/dimos/models/depth/metric3d.py @@ -1,9 +1,19 @@ -import os -import sys -import torch -from PIL import Image +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import cv2 -import numpy as np +import torch # May need to add this back for import to work # external_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'external', 'Metric3D')) @@ -12,25 +22,29 @@ class Metric3D: - def __init__(self): - #self.conf = get_config("zoedepth", "infer") - #self.depth_model = build_model(self.conf) - self.depth_model = torch.hub.load('yvanyin/metric3d', 'metric3d_vit_small', pretrain=True).cuda() + def __init__(self, camera_intrinsics=None, gt_depth_scale: float=256.0) -> None: # type: ignore[no-untyped-def] + # self.conf = get_config("zoedepth", "infer") + # self.depth_model = build_model(self.conf) + self.depth_model = torch.hub.load( # type: ignore[no-untyped-call] + "yvanyin/metric3d", "metric3d_vit_small", pretrain=True + ).cuda() if torch.cuda.device_count() > 1: print(f"Using {torch.cuda.device_count()} GPUs!") - #self.depth_model = torch.nn.DataParallel(self.depth_model) + # self.depth_model = torch.nn.DataParallel(self.depth_model) self.depth_model.eval() - self.intrinsic = [707.0493, 707.0493, 604.0814, 180.5066] - self.gt_depth_scale = 256.0 # And this + self.intrinsic = camera_intrinsics + self.intrinsic_scaled = None + self.gt_depth_scale = gt_depth_scale # And this self.pad_info = None self.rgb_origin = None - ''' + + """ Input: Single image in RGB format Output: Depth map - ''' + """ - def update_intrinsic(self, intrinsic): + def update_intrinsic(self, intrinsic): # type: ignore[no-untyped-def] """ Update the intrinsic parameters dynamically. Ensure that the input intrinsic is valid. @@ -40,49 +54,54 @@ def update_intrinsic(self, intrinsic): self.intrinsic = intrinsic print(f"Intrinsics updated to: {self.intrinsic}") - def infer_depth(self, img, debug=False): + def infer_depth(self, img, debug: bool=False): # type: ignore[no-untyped-def] if debug: print(f"Input image: {img}") try: if isinstance(img, str): print(f"Image type string: {type(img)}") - self.rgb_origin = cv2.imread(img)[:, :, ::-1] + self.rgb_origin = cv2.imread(img)[:, :, ::-1] # type: ignore[assignment] else: - print(f"Image type not string: {type(img)}, cv2 conversion assumed to be handled. If not, this will throw an error") + # print(f"Image type not string: {type(img)}, cv2 conversion assumed to be handled. If not, this will throw an error") self.rgb_origin = img except Exception as e: print(f"Error parsing into infer_depth: {e}") - img = self.rescale_input(img, self.rgb_origin) + img = self.rescale_input(img, self.rgb_origin) # type: ignore[no-untyped-call] with torch.no_grad(): - pred_depth, confidence, output_dict = self.depth_model.inference({'input': img}) - print("Inference completed.") + pred_depth, confidence, output_dict = self.depth_model.inference({"input": img}) # Convert to PIL format - depth_image = self.unpad_transform_depth(pred_depth) - out_16bit_numpy = (depth_image.squeeze().cpu().numpy() * 256).astype(np.uint16) - depth_map_pil = Image.fromarray(out_16bit_numpy) + depth_image = self.unpad_transform_depth(pred_depth) # type: ignore[no-untyped-call] - return depth_map_pil - def save_depth(self, pred_depth): + return depth_image.cpu().numpy() + + def save_depth(self, pred_depth) -> None: # type: ignore[no-untyped-def] # Save the depth map to a file pred_depth_np = pred_depth.cpu().numpy() - output_depth_file = 'output_depth_map.png' + output_depth_file = "output_depth_map.png" cv2.imwrite(output_depth_file, pred_depth_np) print(f"Depth map saved to {output_depth_file}") # Adjusts input size to fit pretrained ViT model - def rescale_input(self, rgb, rgb_origin): + def rescale_input(self, rgb, rgb_origin): # type: ignore[no-untyped-def] #### ajust input size to fit pretrained model # keep ratio resize input_size = (616, 1064) # for vit model # input_size = (544, 1216) # for convnext model h, w = rgb_origin.shape[:2] scale = min(input_size[0] / h, input_size[1] / w) - rgb = cv2.resize(rgb_origin, (int(w * scale), int(h * scale)), interpolation=cv2.INTER_LINEAR) + rgb = cv2.resize( + rgb_origin, (int(w * scale), int(h * scale)), interpolation=cv2.INTER_LINEAR + ) # remember to scale intrinsic, hold depth - self.intrinsic = [self.intrinsic[0] * scale, self.intrinsic[1] * scale, self.intrinsic[2] * scale, self.intrinsic[3] * scale] + self.intrinsic_scaled = [ # type: ignore[assignment] + self.intrinsic[0] * scale, + self.intrinsic[1] * scale, + self.intrinsic[2] * scale, + self.intrinsic[3] * scale, + ] # padding to input_size padding = [123.675, 116.28, 103.53] h, w = rgb.shape[:2] @@ -90,9 +109,16 @@ def rescale_input(self, rgb, rgb_origin): pad_w = input_size[1] - w pad_h_half = pad_h // 2 pad_w_half = pad_w // 2 - rgb = cv2.copyMakeBorder(rgb, pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half, - cv2.BORDER_CONSTANT, value=padding) - self.pad_info = [pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half] + rgb = cv2.copyMakeBorder( + rgb, + pad_h_half, + pad_h - pad_h_half, + pad_w_half, + pad_w - pad_w_half, + cv2.BORDER_CONSTANT, + value=padding, + ) + self.pad_info = [pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half] # type: ignore[assignment] #### normalize mean = torch.tensor([123.675, 116.28, 103.53]).float()[:, None, None] @@ -101,35 +127,41 @@ def rescale_input(self, rgb, rgb_origin): rgb = torch.div((rgb - mean), std) rgb = rgb[None, :, :, :].cuda() return rgb - def unpad_transform_depth(self, pred_depth): + + def unpad_transform_depth(self, pred_depth): # type: ignore[no-untyped-def] # un pad pred_depth = pred_depth.squeeze() - pred_depth = pred_depth[self.pad_info[0]: pred_depth.shape[0] - self.pad_info[1], - self.pad_info[2]: pred_depth.shape[1] - self.pad_info[3]] + pred_depth = pred_depth[ + self.pad_info[0] : pred_depth.shape[0] - self.pad_info[1], # type: ignore[index] + self.pad_info[2] : pred_depth.shape[1] - self.pad_info[3], # type: ignore[index] + ] # upsample to original size - pred_depth = torch.nn.functional.interpolate(pred_depth[None, None, :, :], self.rgb_origin.shape[:2], - mode='bilinear').squeeze() + pred_depth = torch.nn.functional.interpolate( + pred_depth[None, None, :, :], self.rgb_origin.shape[:2], mode="bilinear" # type: ignore[attr-defined] + ).squeeze() ###################### canonical camera space ###################### #### de-canonical transform - canonical_to_real_scale = self.intrinsic[0] / 1000.0 # 1000.0 is the focal length of canonical camera + canonical_to_real_scale = ( + self.intrinsic_scaled[0] / 1000.0 # type: ignore[index] + ) # 1000.0 is the focal length of canonical camera pred_depth = pred_depth * canonical_to_real_scale # now the depth is metric - pred_depth = torch.clamp(pred_depth, 0, 300) + pred_depth = torch.clamp(pred_depth, 0, 1000) return pred_depth - """Set new intrinsic value.""" - def update_intrinsic(self, intrinsic): + + def update_intrinsic(self, intrinsic) -> None: # type: ignore[no-redef, no-untyped-def] self.intrinsic = intrinsic - def eval_predicted_depth(self, depth_file, pred_depth): + def eval_predicted_depth(self, depth_file, pred_depth) -> None: # type: ignore[no-untyped-def] if depth_file is not None: gt_depth = cv2.imread(depth_file, -1) gt_depth = gt_depth / self.gt_depth_scale - gt_depth = torch.from_numpy(gt_depth).float().cuda() + gt_depth = torch.from_numpy(gt_depth).float().cuda() # type: ignore[assignment] assert gt_depth.shape == pred_depth.shape - mask = (gt_depth > 1e-8) + mask = gt_depth > 1e-8 abs_rel_err = (torch.abs(pred_depth[mask] - gt_depth[mask]) / gt_depth[mask]).mean() - print('abs_rel_err:', abs_rel_err.item()) \ No newline at end of file + print("abs_rel_err:", abs_rel_err.item()) diff --git a/dimos/models/embedding/__init__.py b/dimos/models/embedding/__init__.py new file mode 100644 index 0000000000..981e25e5c2 --- /dev/null +++ b/dimos/models/embedding/__init__.py @@ -0,0 +1,30 @@ +from dimos.models.embedding.base import Embedding, EmbeddingModel + +__all__ = [ + "Embedding", + "EmbeddingModel", +] + +# Optional: CLIP support +try: + from dimos.models.embedding.clip import CLIPEmbedding, CLIPModel + + __all__.extend(["CLIPEmbedding", "CLIPModel"]) +except ImportError: + pass + +# Optional: MobileCLIP support +try: + from dimos.models.embedding.mobileclip import MobileCLIPEmbedding, MobileCLIPModel + + __all__.extend(["MobileCLIPEmbedding", "MobileCLIPModel"]) +except ImportError: + pass + +# Optional: TorchReID support +try: + from dimos.models.embedding.treid import TorchReIDEmbedding, TorchReIDModel + + __all__.extend(["TorchReIDEmbedding", "TorchReIDModel"]) +except ImportError: + pass diff --git a/dimos/models/embedding/base.py b/dimos/models/embedding/base.py new file mode 100644 index 0000000000..6633b6a5f1 --- /dev/null +++ b/dimos/models/embedding/base.py @@ -0,0 +1,150 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import ABC, abstractmethod +import time +from typing import TYPE_CHECKING, Generic, TypeVar + +import numpy as np +import torch + +from dimos.types.timestamped import Timestamped + +if TYPE_CHECKING: + from dimos.msgs.sensor_msgs import Image + + +class Embedding(Timestamped): + """Base class for embeddings with vector data. + + Supports both torch.Tensor (for GPU-accelerated comparisons) and np.ndarray. + Embeddings are kept as torch.Tensor on device by default for efficiency. + """ + + vector: torch.Tensor | np.ndarray # type: ignore[type-arg] + + def __init__(self, vector: torch.Tensor | np.ndarray, timestamp: float | None = None) -> None: # type: ignore[type-arg] + self.vector = vector + if timestamp: + self.timestamp = timestamp + else: + self.timestamp = time.time() + + def __matmul__(self, other: Embedding) -> float: + """Compute cosine similarity via @ operator.""" + if isinstance(self.vector, torch.Tensor): + other_tensor = other.to_torch(self.vector.device) + result = self.vector @ other_tensor + return result.item() + return float(self.vector @ other.to_numpy()) + + def to_numpy(self) -> np.ndarray: # type: ignore[type-arg] + """Convert to numpy array (moves to CPU if needed).""" + if isinstance(self.vector, torch.Tensor): + return self.vector.detach().cpu().numpy() + return self.vector + + def to_torch(self, device: str | torch.device | None = None) -> torch.Tensor: + """Convert to torch tensor on specified device.""" + if isinstance(self.vector, np.ndarray): + tensor = torch.from_numpy(self.vector) + return tensor.to(device) if device else tensor + + if device is not None and self.vector.device != torch.device(device): + return self.vector.to(device) + return self.vector + + def to_cpu(self) -> Embedding: + """Move embedding to CPU, returning self for chaining.""" + if isinstance(self.vector, torch.Tensor): + self.vector = self.vector.cpu() + return self + + +E = TypeVar("E", bound="Embedding") + + +class EmbeddingModel(ABC, Generic[E]): + """Abstract base class for embedding models supporting vision and language.""" + + device: str + normalize: bool = True + + @abstractmethod + def embed(self, *images: Image) -> E | list[E]: + """ + Embed one or more images. + Returns single Embedding if one image, list if multiple. + """ + pass + + @abstractmethod + def embed_text(self, *texts: str) -> E | list[E]: + """ + Embed one or more text strings. + Returns single Embedding if one text, list if multiple. + """ + pass + + def compare_one_to_many(self, query: E, candidates: list[E]) -> torch.Tensor: + """ + Efficiently compare one query against many candidates on GPU. + + Args: + query: Query embedding + candidates: List of candidate embeddings + + Returns: + torch.Tensor of similarities (N,) + """ + query_tensor = query.to_torch(self.device) + candidate_tensors = torch.stack([c.to_torch(self.device) for c in candidates]) + return query_tensor @ candidate_tensors.T + + def compare_many_to_many(self, queries: list[E], candidates: list[E]) -> torch.Tensor: + """ + Efficiently compare all queries against all candidates on GPU. + + Args: + queries: List of query embeddings + candidates: List of candidate embeddings + + Returns: + torch.Tensor of similarities (M, N) where M=len(queries), N=len(candidates) + """ + query_tensors = torch.stack([q.to_torch(self.device) for q in queries]) + candidate_tensors = torch.stack([c.to_torch(self.device) for c in candidates]) + return query_tensors @ candidate_tensors.T + + def query(self, query_emb: E, candidates: list[E], top_k: int = 5) -> list[tuple[int, float]]: + """ + Find top-k most similar candidates to query (GPU accelerated). + + Args: + query_emb: Query embedding + candidates: List of candidate embeddings + top_k: Number of top results to return + + Returns: + List of (index, similarity) tuples sorted by similarity (descending) + """ + similarities = self.compare_one_to_many(query_emb, candidates) + top_values, top_indices = similarities.topk(k=min(top_k, len(candidates))) + return [(idx.item(), val.item()) for idx, val in zip(top_indices, top_values, strict=False)] + + def warmup(self) -> None: + """Optional warmup method to pre-load model.""" + pass diff --git a/dimos/models/embedding/clip.py b/dimos/models/embedding/clip.py new file mode 100644 index 0000000000..6fd3f70009 --- /dev/null +++ b/dimos/models/embedding/clip.py @@ -0,0 +1,122 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from PIL import Image as PILImage +import torch +import torch.nn.functional as F +from transformers import CLIPModel as HFCLIPModel, CLIPProcessor # type: ignore[import-untyped] + +from dimos.models.embedding.base import Embedding, EmbeddingModel +from dimos.msgs.sensor_msgs import Image + +_CUDA_INITIALIZED = False + + +class CLIPEmbedding(Embedding): ... + + +class CLIPModel(EmbeddingModel[CLIPEmbedding]): + """CLIP embedding model for vision-language re-identification.""" + + def __init__( + self, + model_name: str = "openai/clip-vit-base-patch32", + device: str | None = None, + normalize: bool = False, + ) -> None: + """ + Initialize CLIP model. + + Args: + model_name: HuggingFace model name (e.g., "openai/clip-vit-base-patch32") + device: Device to run on (cuda/cpu), auto-detects if None + normalize: Whether to L2 normalize embeddings + """ + self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") + self.normalize = normalize + + # Load model and processor + self.model = HFCLIPModel.from_pretrained(model_name).eval().to(self.device) + self.processor = CLIPProcessor.from_pretrained(model_name) + + def embed(self, *images: Image) -> CLIPEmbedding | list[CLIPEmbedding]: + """Embed one or more images. + + Returns embeddings as torch.Tensor on device for efficient GPU comparisons. + """ + # Convert to PIL images + pil_images = [PILImage.fromarray(img.to_opencv()) for img in images] + + # Process images + with torch.inference_mode(): + inputs = self.processor(images=pil_images, return_tensors="pt").to(self.device) + image_features = self.model.get_image_features(**inputs) + + if self.normalize: + image_features = F.normalize(image_features, dim=-1) + + # Create embeddings (keep as torch.Tensor on device) + embeddings = [] + for i, feat in enumerate(image_features): + timestamp = images[i].ts + embeddings.append(CLIPEmbedding(vector=feat, timestamp=timestamp)) + + return embeddings[0] if len(images) == 1 else embeddings + + def embed_text(self, *texts: str) -> CLIPEmbedding | list[CLIPEmbedding]: + """Embed one or more text strings. + + Returns embeddings as torch.Tensor on device for efficient GPU comparisons. + """ + with torch.inference_mode(): + inputs = self.processor(text=list(texts), return_tensors="pt", padding=True).to( + self.device + ) + text_features = self.model.get_text_features(**inputs) + + if self.normalize: + text_features = F.normalize(text_features, dim=-1) + + # Create embeddings (keep as torch.Tensor on device) + embeddings = [] + for feat in text_features: + embeddings.append(CLIPEmbedding(vector=feat)) + + return embeddings[0] if len(texts) == 1 else embeddings + + def warmup(self) -> None: + """Warmup the model with a dummy forward pass.""" + # WORKAROUND: HuggingFace CLIP fails with CUBLAS_STATUS_ALLOC_FAILED when it's + # the first model to use CUDA. Initialize CUDA context with a dummy operation. + # This only needs to happen once per process. + global _CUDA_INITIALIZED + if self.device == "cuda" and not _CUDA_INITIALIZED: + try: + # Initialize CUDA with a small matmul operation to setup cuBLAS properly + _ = torch.zeros(1, 1, device="cuda") @ torch.zeros(1, 1, device="cuda") + torch.cuda.synchronize() + _CUDA_INITIALIZED = True + except Exception: + # If initialization fails, continue anyway - the warmup might still work + pass + + dummy_image = torch.randn(1, 3, 224, 224).to(self.device) + dummy_text_inputs = self.processor(text=["warmup"], return_tensors="pt", padding=True).to( + self.device + ) + + with torch.inference_mode(): + # Use pixel_values directly for image warmup + self.model.get_image_features(pixel_values=dummy_image) + self.model.get_text_features(**dummy_text_inputs) diff --git a/dimos/models/embedding/embedding_models_disabled_tests.py b/dimos/models/embedding/embedding_models_disabled_tests.py new file mode 100644 index 0000000000..6c80595571 --- /dev/null +++ b/dimos/models/embedding/embedding_models_disabled_tests.py @@ -0,0 +1,404 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest + +from dimos.msgs.sensor_msgs import Image +from dimos.utils.data import get_data + + +@pytest.fixture(scope="session", params=["clip", "mobileclip", "treid"]) +def embedding_model(request): # type: ignore[no-untyped-def] + """Load embedding model once for all tests. Parametrized for different models.""" + if request.param == "mobileclip": + from dimos.models.embedding.mobileclip import MobileCLIPModel + + model_path = get_data("models_mobileclip") / "mobileclip2_s0.pt" + model = MobileCLIPModel(model_name="MobileCLIP2-S0", model_path=model_path) + elif request.param == "clip": + from dimos.models.embedding.clip import CLIPModel + + model = CLIPModel(model_name="openai/clip-vit-base-patch32") # type: ignore[assignment] + elif request.param == "treid": + from dimos.models.embedding.treid import TorchReIDModel + + model = TorchReIDModel(model_name="osnet_x1_0") # type: ignore[assignment] + else: + raise ValueError(f"Unknown model: {request.param}") + + model.warmup() + return model + + +@pytest.fixture(scope="session") +def test_image(): # type: ignore[no-untyped-def] + """Load test image.""" + return Image.from_file(get_data("cafe.jpg")).to_rgb() # type: ignore[arg-type] + + +@pytest.mark.heavy +def test_single_image_embedding(embedding_model, test_image) -> None: # type: ignore[no-untyped-def] + """Test embedding a single image.""" + embedding = embedding_model.embed(test_image) + + # Embedding should be torch.Tensor on device + import torch + + assert isinstance(embedding.vector, torch.Tensor), "Embedding should be torch.Tensor" + assert embedding.vector.device.type in ["cuda", "cpu"], "Should be on valid device" + + # Test conversion to numpy + vector_np = embedding.to_numpy() + print(f"\nEmbedding shape: {vector_np.shape}") + print(f"Embedding dtype: {vector_np.dtype}") + print(f"Embedding norm: {np.linalg.norm(vector_np):.4f}") + + assert vector_np.shape[0] > 0, "Embedding should have features" + assert np.isfinite(vector_np).all(), "Embedding should contain finite values" + + # Check L2 normalization + norm = np.linalg.norm(vector_np) + assert abs(norm - 1.0) < 0.01, f"Embedding should be L2 normalized, got norm={norm}" + + +@pytest.mark.heavy +def test_batch_image_embedding(embedding_model, test_image) -> None: # type: ignore[no-untyped-def] + """Test embedding multiple images at once.""" + embeddings = embedding_model.embed(test_image, test_image, test_image) + + assert isinstance(embeddings, list), "Batch embedding should return list" + assert len(embeddings) == 3, "Should return 3 embeddings" + + # Check all embeddings are similar (same image) + sim_01 = embeddings[0] @ embeddings[1] + sim_02 = embeddings[0] @ embeddings[2] + + print(f"\nSimilarity between same images: {sim_01:.6f}, {sim_02:.6f}") + + assert sim_01 > 0.99, f"Same image embeddings should be very similar, got {sim_01}" + assert sim_02 > 0.99, f"Same image embeddings should be very similar, got {sim_02}" + + +@pytest.mark.heavy +def test_single_text_embedding(embedding_model) -> None: # type: ignore[no-untyped-def] + """Test embedding a single text string.""" + import torch + + if not hasattr(embedding_model, "embed_text"): + pytest.skip("Model does not support text embeddings") + + embedding = embedding_model.embed_text("a cafe") + + # Should be torch.Tensor + assert isinstance(embedding.vector, torch.Tensor), "Text embedding should be torch.Tensor" + + vector_np = embedding.to_numpy() + print(f"\nText embedding shape: {vector_np.shape}") + print(f"Text embedding norm: {np.linalg.norm(vector_np):.4f}") + + assert vector_np.shape[0] > 0, "Text embedding should have features" + assert np.isfinite(vector_np).all(), "Text embedding should contain finite values" + + # Check L2 normalization + norm = np.linalg.norm(vector_np) + assert abs(norm - 1.0) < 0.01, f"Text embedding should be L2 normalized, got norm={norm}" + + +@pytest.mark.heavy +def test_batch_text_embedding(embedding_model) -> None: # type: ignore[no-untyped-def] + """Test embedding multiple text strings at once.""" + import torch + + if not hasattr(embedding_model, "embed_text"): + pytest.skip("Model does not support text embeddings") + + embeddings = embedding_model.embed_text("a cafe", "a person", "a dog") + + assert isinstance(embeddings, list), "Batch text embedding should return list" + assert len(embeddings) == 3, "Should return 3 text embeddings" + + # All should be torch.Tensor and normalized + for i, emb in enumerate(embeddings): + assert isinstance(emb.vector, torch.Tensor), f"Embedding {i} should be torch.Tensor" + norm = np.linalg.norm(emb.to_numpy()) + assert abs(norm - 1.0) < 0.01, f"Text embedding {i} should be L2 normalized" + + +@pytest.mark.heavy +def test_text_image_similarity(embedding_model, test_image) -> None: # type: ignore[no-untyped-def] + """Test cross-modal text-image similarity using @ operator.""" + if not hasattr(embedding_model, "embed_text"): + pytest.skip("Model does not support text embeddings") + + img_embedding = embedding_model.embed(test_image) + + # Embed text queries + queries = ["a cafe", "a person", "a car", "a dog", "potato", "food"] + text_embeddings = embedding_model.embed_text(*queries) + + # Compute similarities using @ operator + similarities = {} + for query, text_emb in zip(queries, text_embeddings, strict=False): + similarity = img_embedding @ text_emb + similarities[query] = similarity + print(f"\n'{query}': {similarity:.4f}") + + # Cafe image should match "a cafe" better than "a dog" + assert similarities["a cafe"] > similarities["a dog"], "Should recognize cafe scene" + assert similarities["a person"] > similarities["a car"], "Should detect people in cafe" + + +@pytest.mark.heavy +def test_cosine_distance(embedding_model, test_image) -> None: # type: ignore[no-untyped-def] + """Test cosine distance computation (1 - similarity).""" + emb1 = embedding_model.embed(test_image) + emb2 = embedding_model.embed(test_image) + + # Similarity using @ operator + similarity = emb1 @ emb2 + + # Distance is 1 - similarity + distance = 1.0 - similarity + + print(f"\nSimilarity (same image): {similarity:.6f}") + print(f"Distance (same image): {distance:.6f}") + + assert similarity > 0.99, f"Same image should have high similarity, got {similarity}" + assert distance < 0.01, f"Same image should have low distance, got {distance}" + + +@pytest.mark.heavy +def test_query_functionality(embedding_model, test_image) -> None: # type: ignore[no-untyped-def] + """Test query method for top-k retrieval.""" + if not hasattr(embedding_model, "embed_text"): + pytest.skip("Model does not support text embeddings") + + # Create a query and some candidates + query_text = embedding_model.embed_text("a cafe") + + # Create candidate embeddings + candidate_texts = ["a cafe", "a restaurant", "a person", "a dog", "a car"] + candidates = embedding_model.embed_text(*candidate_texts) + + # Query for top-3 + results = embedding_model.query(query_text, candidates, top_k=3) + + print("\nTop-3 results:") + for idx, sim in results: + print(f" {candidate_texts[idx]}: {sim:.4f}") + + assert len(results) == 3, "Should return top-3 results" + assert results[0][0] == 0, "Top match should be 'a cafe' itself" + assert results[0][1] > results[1][1], "Results should be sorted by similarity" + assert results[1][1] > results[2][1], "Results should be sorted by similarity" + + +@pytest.mark.heavy +def test_embedding_operator(embedding_model, test_image) -> None: # type: ignore[no-untyped-def] + """Test that @ operator works on embeddings.""" + emb1 = embedding_model.embed(test_image) + emb2 = embedding_model.embed(test_image) + + # Use @ operator + similarity = emb1 @ emb2 + + assert isinstance(similarity, float), "@ operator should return float" + assert 0.0 <= similarity <= 1.0, "Cosine similarity should be in [0, 1]" + assert similarity > 0.99, "Same image should have similarity near 1.0" + + +@pytest.mark.heavy +def test_warmup(embedding_model) -> None: # type: ignore[no-untyped-def] + """Test that warmup runs without error.""" + # Warmup is already called in fixture, but test it explicitly + embedding_model.warmup() + # Just verify no exceptions raised + assert True + + +@pytest.mark.heavy +def test_compare_one_to_many(embedding_model, test_image) -> None: # type: ignore[no-untyped-def] + """Test GPU-accelerated one-to-many comparison.""" + import torch + + # Create query and gallery + query_emb = embedding_model.embed(test_image) + gallery_embs = embedding_model.embed(test_image, test_image, test_image) + + # Compare on GPU + similarities = embedding_model.compare_one_to_many(query_emb, gallery_embs) + + print(f"\nOne-to-many similarities: {similarities}") + + # Should return torch.Tensor + assert isinstance(similarities, torch.Tensor), "Should return torch.Tensor" + assert similarities.shape == (3,), "Should have 3 similarities" + assert similarities.device.type in ["cuda", "cpu"], "Should be on device" + + # All should be ~1.0 (same image) + similarities_np = similarities.cpu().numpy() + assert np.all(similarities_np > 0.99), "Same images should have similarity ~1.0" + + +@pytest.mark.heavy +def test_compare_many_to_many(embedding_model) -> None: # type: ignore[no-untyped-def] + """Test GPU-accelerated many-to-many comparison.""" + import torch + + if not hasattr(embedding_model, "embed_text"): + pytest.skip("Model does not support text embeddings") + + # Create queries and candidates + queries = embedding_model.embed_text("a cafe", "a person") + candidates = embedding_model.embed_text("a cafe", "a restaurant", "a dog") + + # Compare on GPU + similarities = embedding_model.compare_many_to_many(queries, candidates) + + print(f"\nMany-to-many similarities:\n{similarities}") + + # Should return torch.Tensor + assert isinstance(similarities, torch.Tensor), "Should return torch.Tensor" + assert similarities.shape == (2, 3), "Should be (2, 3) similarity matrix" + assert similarities.device.type in ["cuda", "cpu"], "Should be on device" + + # First query should match first candidate best + similarities_np = similarities.cpu().numpy() + assert similarities_np[0, 0] > similarities_np[0, 2], "Cafe should match cafe better than dog" + + +@pytest.mark.heavy +def test_gpu_query_performance(embedding_model, test_image) -> None: # type: ignore[no-untyped-def] + """Test that query method uses GPU acceleration.""" + # Create a larger gallery + gallery_size = 20 + gallery_images = [test_image] * gallery_size + gallery_embs = embedding_model.embed(*gallery_images) + + query_emb = embedding_model.embed(test_image) + + # Query should use GPU-accelerated comparison + results = embedding_model.query(query_emb, gallery_embs, top_k=5) + + print(f"\nTop-5 results from gallery of {gallery_size}") + for idx, sim in results: + print(f" Index {idx}: {sim:.4f}") + + assert len(results) == 5, "Should return top-5 results" + # All should be high similarity (same image, allow some variation for image preprocessing) + for idx, sim in results: + assert sim > 0.90, f"Same images should have high similarity, got {sim}" + + +@pytest.mark.heavy +def test_embedding_performance(embedding_model) -> None: # type: ignore[no-untyped-def] + """Measure embedding performance over multiple real video frames.""" + import time + + from dimos.utils.testing import TimedSensorReplay + + # Load actual video frames + data_dir = "unitree_go2_lidar_corrected" + get_data(data_dir) + + video_replay = TimedSensorReplay(f"{data_dir}/video") # type: ignore[var-annotated] + + # Collect 10 real frames from the video + test_images = [] + for _ts, frame in video_replay.iterate_ts(duration=1.0): + test_images.append(frame.to_rgb()) + if len(test_images) >= 10: + break + + if len(test_images) < 10: + pytest.skip(f"Not enough video frames found (got {len(test_images)})") + + # Measure single image embedding time + times = [] + for img in test_images: + start = time.perf_counter() + _ = embedding_model.embed(img) + end = time.perf_counter() + elapsed_ms = (end - start) * 1000 + times.append(elapsed_ms) + + # Calculate statistics + avg_time = sum(times) / len(times) + min_time = min(times) + max_time = max(times) + std_time = (sum((t - avg_time) ** 2 for t in times) / len(times)) ** 0.5 + + print("\n" + "=" * 60) + print("Embedding Performance Statistics:") + print("=" * 60) + print(f"Number of images: {len(test_images)}") + print(f"Average time: {avg_time:.2f} ms") + print(f"Min time: {min_time:.2f} ms") + print(f"Max time: {max_time:.2f} ms") + print(f"Std dev: {std_time:.2f} ms") + print(f"Throughput: {1000 / avg_time:.1f} images/sec") + print("=" * 60) + + # Also test batch embedding performance + start = time.perf_counter() + batch_embeddings = embedding_model.embed(*test_images) + end = time.perf_counter() + batch_time = (end - start) * 1000 + batch_per_image = batch_time / len(test_images) + + print("\nBatch Embedding Performance:") + print(f"Total batch time: {batch_time:.2f} ms") + print(f"Time per image (batched): {batch_per_image:.2f} ms") + print(f"Batch throughput: {1000 / batch_per_image:.1f} images/sec") + print(f"Speedup vs single: {avg_time / batch_per_image:.2f}x") + print("=" * 60) + + # Verify embeddings are valid + assert len(batch_embeddings) == len(test_images) + assert all(e.vector is not None for e in batch_embeddings) + + # Sanity check: verify embeddings are meaningful by testing text-image similarity + # Skip for models that don't support text embeddings + if hasattr(embedding_model, "embed_text"): + print("\n" + "=" * 60) + print("Sanity Check: Text-Image Similarity on First Frame") + print("=" * 60) + first_frame_emb = batch_embeddings[0] + + # Test common object/scene queries + test_queries = [ + "indoor scene", + "outdoor scene", + "a person", + "a dog", + "a robot", + "grass and trees", + "furniture", + "a car", + ] + + text_embeddings = embedding_model.embed_text(*test_queries) + similarities = [] + for query, text_emb in zip(test_queries, text_embeddings, strict=False): + sim = first_frame_emb @ text_emb + similarities.append((query, sim)) + + # Sort by similarity + similarities.sort(key=lambda x: x[1], reverse=True) + + print("Top matching concepts:") + for query, sim in similarities[:5]: + print(f" '{query}': {sim:.4f}") + print("=" * 60) diff --git a/dimos/models/embedding/mobileclip.py b/dimos/models/embedding/mobileclip.py new file mode 100644 index 0000000000..0aa157c118 --- /dev/null +++ b/dimos/models/embedding/mobileclip.py @@ -0,0 +1,112 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + +import open_clip # type: ignore[import-not-found] +from PIL import Image as PILImage +import torch +import torch.nn.functional as F + +from dimos.models.embedding.base import Embedding, EmbeddingModel +from dimos.msgs.sensor_msgs import Image + + +class MobileCLIPEmbedding(Embedding): ... + + +class MobileCLIPModel(EmbeddingModel[MobileCLIPEmbedding]): + """MobileCLIP embedding model for vision-language re-identification.""" + + def __init__( + self, + model_name: str = "MobileCLIP2-S4", + model_path: Path | str | None = None, + device: str | None = None, + normalize: bool = True, + ) -> None: + """ + Initialize MobileCLIP model. + + Args: + model_name: Name of the model architecture + model_path: Path to pretrained weights + device: Device to run on (cuda/cpu), auto-detects if None + normalize: Whether to L2 normalize embeddings + """ + if not OPEN_CLIP_AVAILABLE: # type: ignore[name-defined] + raise ImportError( + "open_clip is required for MobileCLIPModel. " + "Install it with: pip install open-clip-torch" + ) + + self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") + self.normalize = normalize + + # Load model + pretrained = str(model_path) if model_path else None + self.model, _, self.preprocess = open_clip.create_model_and_transforms( + model_name, pretrained=pretrained + ) + self.tokenizer = open_clip.get_tokenizer(model_name) + self.model = self.model.eval().to(self.device) + + def embed(self, *images: Image) -> MobileCLIPEmbedding | list[MobileCLIPEmbedding]: + """Embed one or more images. + + Returns embeddings as torch.Tensor on device for efficient GPU comparisons. + """ + # Convert to PIL images + pil_images = [PILImage.fromarray(img.to_opencv()) for img in images] + + # Preprocess and batch + with torch.inference_mode(): + batch = torch.stack([self.preprocess(img) for img in pil_images]).to(self.device) + feats = self.model.encode_image(batch) + if self.normalize: + feats = F.normalize(feats, dim=-1) + + # Create embeddings (keep as torch.Tensor on device) + embeddings = [] + for i, feat in enumerate(feats): + timestamp = images[i].ts + embeddings.append(MobileCLIPEmbedding(vector=feat, timestamp=timestamp)) + + return embeddings[0] if len(images) == 1 else embeddings + + def embed_text(self, *texts: str) -> MobileCLIPEmbedding | list[MobileCLIPEmbedding]: + """Embed one or more text strings. + + Returns embeddings as torch.Tensor on device for efficient GPU comparisons. + """ + with torch.inference_mode(): + text_tokens = self.tokenizer(list(texts)).to(self.device) + feats = self.model.encode_text(text_tokens) + if self.normalize: + feats = F.normalize(feats, dim=-1) + + # Create embeddings (keep as torch.Tensor on device) + embeddings = [] + for feat in feats: + embeddings.append(MobileCLIPEmbedding(vector=feat)) + + return embeddings[0] if len(texts) == 1 else embeddings + + def warmup(self) -> None: + """Warmup the model with a dummy forward pass.""" + dummy_image = torch.randn(1, 3, 224, 224).to(self.device) + dummy_text = self.tokenizer(["warmup"]).to(self.device) + with torch.inference_mode(): + self.model.encode_image(dummy_image) + self.model.encode_text(dummy_text) diff --git a/dimos/models/embedding/treid.py b/dimos/models/embedding/treid.py new file mode 100644 index 0000000000..db5db46a55 --- /dev/null +++ b/dimos/models/embedding/treid.py @@ -0,0 +1,125 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + +import torch +import torch.nn.functional as F +from torchreid import utils as torchreid_utils # type: ignore[import-not-found] + +from dimos.models.embedding.base import Embedding, EmbeddingModel +from dimos.msgs.sensor_msgs import Image + +_CUDA_INITIALIZED = False + + +class TorchReIDEmbedding(Embedding): ... + + +class TorchReIDModel(EmbeddingModel[TorchReIDEmbedding]): + """TorchReID embedding model for person re-identification.""" + + def __init__( + self, + model_name: str = "se_resnext101_32x4d", + model_path: Path | str | None = None, + device: str | None = None, + normalize: bool = False, + ) -> None: + """ + Initialize TorchReID model. + + Args: + model_name: Name of the model architecture (e.g., "osnet_x1_0", "osnet_x0_75") + model_path: Path to pretrained weights (.pth.tar file) + device: Device to run on (cuda/cpu), auto-detects if None + normalize: Whether to L2 normalize embeddings + """ + if not TORCHREID_AVAILABLE: # type: ignore[name-defined] + raise ImportError( + "torchreid is required for TorchReIDModel. Install it with: pip install torchreid" + ) + + self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") + self.normalize = normalize + + # Load model using torchreid's FeatureExtractor + model_path_str = str(model_path) if model_path else "" + self.extractor = torchreid_utils.FeatureExtractor( + model_name=model_name, + model_path=model_path_str, + device=self.device, + ) + + def embed(self, *images: Image) -> TorchReIDEmbedding | list[TorchReIDEmbedding]: + """Embed one or more images. + + Returns embeddings as torch.Tensor on device for efficient GPU comparisons. + """ + # Convert to numpy arrays - torchreid expects numpy arrays or file paths + np_images = [img.to_opencv() for img in images] + + # Extract features + with torch.inference_mode(): + features = self.extractor(np_images) + + # torchreid may return either numpy array or torch tensor depending on configuration + if isinstance(features, torch.Tensor): + features_tensor = features.to(self.device) + else: + features_tensor = torch.from_numpy(features).to(self.device) + + if self.normalize: + features_tensor = F.normalize(features_tensor, dim=-1) + + # Create embeddings (keep as torch.Tensor on device) + embeddings = [] + for i, feat in enumerate(features_tensor): + timestamp = images[i].ts + embeddings.append(TorchReIDEmbedding(vector=feat, timestamp=timestamp)) + + return embeddings[0] if len(images) == 1 else embeddings + + def embed_text(self, *texts: str) -> TorchReIDEmbedding | list[TorchReIDEmbedding]: + """Text embedding not supported for ReID models. + + TorchReID models are vision-only person re-identification models + and do not support text embeddings. + """ + raise NotImplementedError( + "TorchReID models are vision-only and do not support text embeddings. " + "Use CLIP or MobileCLIP for text-image similarity." + ) + + def warmup(self) -> None: + """Warmup the model with a dummy forward pass.""" + # WORKAROUND: TorchReID can fail with CUBLAS errors when it's the first model to use CUDA. + # Initialize CUDA context with a dummy operation. This only needs to happen once per process. + global _CUDA_INITIALIZED + if self.device == "cuda" and not _CUDA_INITIALIZED: + try: + # Initialize CUDA with a small matmul operation to setup cuBLAS properly + _ = torch.zeros(1, 1, device="cuda") @ torch.zeros(1, 1, device="cuda") + torch.cuda.synchronize() + _CUDA_INITIALIZED = True + except Exception: + # If initialization fails, continue anyway - the warmup might still work + pass + + # Create a dummy 256x128 image (typical person ReID input size) as numpy array + import numpy as np + + dummy_image = np.random.randint(0, 256, (256, 128, 3), dtype=np.uint8) + with torch.inference_mode(): + _ = self.extractor([dummy_image]) diff --git a/dimos/models/labels/llava-34b.py b/dimos/models/labels/llava-34b.py deleted file mode 100644 index 4838745728..0000000000 --- a/dimos/models/labels/llava-34b.py +++ /dev/null @@ -1,53 +0,0 @@ -import json - -# llava v1.6 -from llama_cpp import Llama -from llama_cpp.llama_chat_format import Llava15ChatHandler - -from vqasynth.datasets.utils import image_to_base64_data_uri - -class Llava: - def __init__(self, mmproj="/app/models/mmproj-model-f16.gguf", model_path="/app/models/llava-v1.6-34b.Q4_K_M.gguf", gpu=True): - chat_handler = Llava15ChatHandler(clip_model_path=mmproj, verbose=True) - n_gpu_layers = 0 - if gpu: - n_gpu_layers = -1 - self.llm = Llama(model_path=model_path, chat_handler=chat_handler, n_ctx=2048, logits_all=True, n_gpu_layers=n_gpu_layers) - - def run_inference(self, image, prompt, return_json=True): - data_uri = image_to_base64_data_uri(image) - res = self.llm.create_chat_completion( - messages = [ - {"role": "system", "content": "You are an assistant who perfectly describes images."}, - { - "role": "user", - "content": [ - {"type": "image_url", "image_url": {"url": data_uri}}, - {"type" : "text", "text": prompt} - ] - } - ] - ) - if return_json: - - return list(set(self.extract_descriptions_from_incomplete_json(res["choices"][0]["message"]["content"]))) - - return res["choices"][0]["message"]["content"] - - def extract_descriptions_from_incomplete_json(self, json_like_str): - last_object_idx = json_like_str.rfind(',"object') - - if last_object_idx != -1: - json_str = json_like_str[:last_object_idx] + '}' - else: - json_str = json_like_str.strip() - if not json_str.endswith('}'): - json_str += '}' - - try: - json_obj = json.loads(json_str) - descriptions = [details['description'].replace(".","") for key, details in json_obj.items() if 'description' in details] - - return descriptions - except json.JSONDecodeError as e: - raise ValueError(f"Error parsing JSON: {e}") diff --git a/dimos/manipulation/classical/grounding.py b/dimos/models/manipulation/__init__.py similarity index 100% rename from dimos/manipulation/classical/grounding.py rename to dimos/models/manipulation/__init__.py diff --git a/dimos/models/manipulation/contact_graspnet_pytorch/README.md b/dimos/models/manipulation/contact_graspnet_pytorch/README.md new file mode 100644 index 0000000000..bf95fa39cd --- /dev/null +++ b/dimos/models/manipulation/contact_graspnet_pytorch/README.md @@ -0,0 +1,52 @@ +# ContactGraspNet PyTorch Module + +This module provides a PyTorch implementation of ContactGraspNet for robotic grasping on dimOS. + +## Setup Instructions + +### 1. Install Required Dependencies + +Install the manipulation extras from the main repository: + +```bash +# From the root directory of the dimos repository +pip install -e ".[manipulation]" +``` + +This will install all the necessary dependencies for using the contact_graspnet_pytorch module, including: +- PyTorch +- Open3D +- Other manipulation-specific dependencies + +### 2. Testing the Module + +To test that the module is properly installed and functioning: + +```bash +# From the root directory of the dimos repository +pytest -s dimos/models/manipulation/contact_graspnet_pytorch/test_contact_graspnet.py +``` + +The test will verify that: +- The model can be loaded +- Inference runs correctly +- Grasping outputs are generated as expected + +### 3. Using in Your Code + +Reference ```inference.py``` for usage example. + +### Troubleshooting + +If you encounter issues with imports or missing dependencies: + +1. Verify that the manipulation extras are properly installed: + ```python + import contact_graspnet_pytorch + print("Module loaded successfully!") + ``` + +2. If LFS data files are missing, ensure Git LFS is installed and initialized: + ```bash + git lfs pull + ``` \ No newline at end of file diff --git a/dimos/models/manipulation/contact_graspnet_pytorch/inference.py b/dimos/models/manipulation/contact_graspnet_pytorch/inference.py new file mode 100644 index 0000000000..0769fc150d --- /dev/null +++ b/dimos/models/manipulation/contact_graspnet_pytorch/inference.py @@ -0,0 +1,120 @@ +import argparse +import glob +import os + +from contact_graspnet_pytorch import config_utils # type: ignore[import-not-found] +from contact_graspnet_pytorch.checkpoints import CheckpointIO # type: ignore[import-not-found] +from contact_graspnet_pytorch.contact_grasp_estimator import ( # type: ignore[import-not-found] + GraspEstimator, +) +from contact_graspnet_pytorch.data import ( # type: ignore[import-not-found] + load_available_input_data, +) +import numpy as np + +from dimos.utils.data import get_data + + +def inference(global_config, # type: ignore[no-untyped-def] + ckpt_dir, + input_paths, + local_regions: bool=True, + filter_grasps: bool=True, + skip_border_objects: bool=False, + z_range = None, + forward_passes: int=1, + K=None,): + """ + Predict 6-DoF grasp distribution for given model and input data + + :param global_config: config.yaml from checkpoint directory + :param checkpoint_dir: checkpoint directory + :param input_paths: .png/.npz/.npy file paths that contain depth/pointcloud and optionally intrinsics/segmentation/rgb + :param K: Camera Matrix with intrinsics to convert depth to point cloud + :param local_regions: Crop 3D local regions around given segments. + :param skip_border_objects: When extracting local_regions, ignore segments at depth map boundary. + :param filter_grasps: Filter and assign grasp contacts according to segmap. + :param segmap_id: only return grasps from specified segmap_id. + :param z_range: crop point cloud at a minimum/maximum z distance from camera to filter out outlier points. Default: [0.2, 1.8] m + :param forward_passes: Number of forward passes to run on each point cloud. Default: 1 + """ + # Build the model + if z_range is None: + z_range = [0.2, 1.8] + grasp_estimator = GraspEstimator(global_config) + + # Load the weights + model_checkpoint_dir = get_data(ckpt_dir) + checkpoint_io = CheckpointIO(checkpoint_dir=model_checkpoint_dir, model=grasp_estimator.model) + try: + checkpoint_io.load('model.pt') + except FileExistsError: + print('No model checkpoint found') + + + os.makedirs('results', exist_ok=True) + + # Process example test scenes + for p in glob.glob(input_paths): + print('Loading ', p) + + pc_segments = {} + segmap, rgb, depth, cam_K, pc_full, pc_colors = load_available_input_data(p, K=K) + + if segmap is None and (local_regions or filter_grasps): + raise ValueError('Need segmentation map to extract local regions or filter grasps') + + if pc_full is None: + print('Converting depth to point cloud(s)...') + pc_full, pc_segments, pc_colors = grasp_estimator.extract_point_clouds(depth, cam_K, segmap=segmap, rgb=rgb, + skip_border_objects=skip_border_objects, + z_range=z_range) + + print(pc_full.shape) + + print('Generating Grasps...') + pred_grasps_cam, scores, contact_pts, _ = grasp_estimator.predict_scene_grasps(pc_full, + pc_segments=pc_segments, + local_regions=local_regions, + filter_grasps=filter_grasps, + forward_passes=forward_passes) + + # Save results + np.savez('results/predictions_{}'.format(os.path.basename(p.replace('png','npz').replace('npy','npz'))), + pc_full=pc_full, pred_grasps_cam=pred_grasps_cam, scores=scores, contact_pts=contact_pts, pc_colors=pc_colors) + + # Visualize results + # show_image(rgb, segmap) + # visualize_grasps(pc_full, pred_grasps_cam, scores, plot_opencv_cam=True, pc_colors=pc_colors) + + if not glob.glob(input_paths): + print('No files found: ', input_paths) + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument('--ckpt_dir', default='models_contact_graspnet', help='Log dir') + parser.add_argument('--np_path', default='test_data/7.npy', help='Input data: npz/npy file with keys either "depth" & camera matrix "K" or just point cloud "pc" in meters. Optionally, a 2D "segmap"') + parser.add_argument('--K', default=None, help='Flat Camera Matrix, pass as "[fx, 0, cx, 0, fy, cy, 0, 0 ,1]"') + parser.add_argument('--z_range', default=[0.2,1.8], help='Z value threshold to crop the input point cloud') + parser.add_argument('--local_regions', action='store_true', default=True, help='Crop 3D local regions around given segments.') + parser.add_argument('--filter_grasps', action='store_true', default=True, help='Filter grasp contacts according to segmap.') + parser.add_argument('--skip_border_objects', action='store_true', default=False, help='When extracting local_regions, ignore segments at depth map boundary.') + parser.add_argument('--forward_passes', type=int, default=1, help='Run multiple parallel forward passes to mesh_utils more potential contact points.') + parser.add_argument('--arg_configs', nargs="*", type=str, default=[], help='overwrite config parameters') + FLAGS = parser.parse_args() + + global_config = config_utils.load_config(FLAGS.ckpt_dir, batch_size=FLAGS.forward_passes, arg_configs=FLAGS.arg_configs) + + print(str(global_config)) + print(f'pid: {os.getpid()!s}') + + inference(global_config, + FLAGS.ckpt_dir, + FLAGS.np_path, + local_regions=FLAGS.local_regions, + filter_grasps=FLAGS.filter_grasps, + skip_border_objects=FLAGS.skip_border_objects, + z_range=eval(str(FLAGS.z_range)), + forward_passes=FLAGS.forward_passes, + K=eval(str(FLAGS.K))) diff --git a/dimos/models/manipulation/contact_graspnet_pytorch/test_contact_graspnet.py b/dimos/models/manipulation/contact_graspnet_pytorch/test_contact_graspnet.py new file mode 100644 index 0000000000..7964a24954 --- /dev/null +++ b/dimos/models/manipulation/contact_graspnet_pytorch/test_contact_graspnet.py @@ -0,0 +1,71 @@ +import glob +import os + +import numpy as np +import pytest + + +def is_manipulation_installed() -> bool: + """Check if the manipulation extras are installed.""" + try: + import contact_graspnet_pytorch + return True + except ImportError: + return False + +@pytest.mark.skipif(not is_manipulation_installed(), + reason="This test requires 'pip install .[manipulation]' to be run") +def test_contact_graspnet_inference() -> None: + """Test contact graspnet inference with local regions and filter grasps.""" + # Skip test if manipulation dependencies not installed + if not is_manipulation_installed(): + pytest.skip("contact_graspnet_pytorch not installed. Run 'pip install .[manipulation]' first.") + return + + try: + from contact_graspnet_pytorch import config_utils + + from dimos.models.manipulation.contact_graspnet_pytorch.inference import inference + from dimos.utils.data import get_data + except ImportError: + pytest.skip("Required modules could not be imported. Make sure you have run 'pip install .[manipulation]'.") + return + + # Test data path - use the default test data path + test_data_path = os.path.join(get_data("models_contact_graspnet"), "test_data/0.npy") + + # Check if test data exists + test_files = glob.glob(test_data_path) + if not test_files: + pytest.fail(f"No test data found at {test_data_path}") + + # Load config with default values + ckpt_dir = 'models_contact_graspnet' + global_config = config_utils.load_config(ckpt_dir, batch_size=1) + + # Run inference function with the same params as the command line + result_files_before = glob.glob('results/predictions_*.npz') + + inference( + global_config=global_config, + ckpt_dir=ckpt_dir, + input_paths=test_data_path, + local_regions=True, + filter_grasps=True, + skip_border_objects=False, + z_range=[0.2, 1.8], + forward_passes=1, + K=None + ) + + # Verify results were created + result_files_after = glob.glob('results/predictions_*.npz') + assert len(result_files_after) >= len(result_files_before), "No result files were generated" + + # Load at least one result file and verify it contains expected data + if result_files_after: + latest_result = sorted(result_files_after)[-1] + result_data = np.load(latest_result, allow_pickle=True) + expected_keys = ['pc_full', 'pred_grasps_cam', 'scores', 'contact_pts', 'pc_colors'] + for key in expected_keys: + assert key in result_data.files, f"Expected key '{key}' not found in results" diff --git a/dimos/models/pointcloud/pointcloud_utils.py b/dimos/models/pointcloud/pointcloud_utils.py index 74ff131c55..3d79f3a33d 100644 --- a/dimos/models/pointcloud/pointcloud_utils.py +++ b/dimos/models/pointcloud/pointcloud_utils.py @@ -1,39 +1,63 @@ -import pickle -import numpy as np -import open3d as o3d +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import random -def save_pointcloud(pcd, file_path): +import numpy as np +import open3d as o3d # type: ignore[import-untyped] + + +def save_pointcloud(pcd, file_path) -> None: # type: ignore[no-untyped-def] """ Save a point cloud to a file using Open3D. """ o3d.io.write_point_cloud(file_path, pcd) -def restore_pointclouds(pointcloud_paths): + +def restore_pointclouds(pointcloud_paths): # type: ignore[no-untyped-def] restored_pointclouds = [] for path in pointcloud_paths: restored_pointclouds.append(o3d.io.read_point_cloud(path)) return restored_pointclouds -def create_point_cloud_from_rgbd(rgb_image, depth_image, intrinsic_parameters): +def create_point_cloud_from_rgbd(rgb_image, depth_image, intrinsic_parameters): # type: ignore[no-untyped-def] rgbd_image = o3d.geometry.RGBDImage.create_from_color_and_depth( o3d.geometry.Image(rgb_image), o3d.geometry.Image(depth_image), - depth_scale=0.125, #1000.0, - depth_trunc=10.0, #10.0, - convert_rgb_to_intensity=False + depth_scale=0.125, # 1000.0, + depth_trunc=10.0, # 10.0, + convert_rgb_to_intensity=False, ) intrinsic = o3d.camera.PinholeCameraIntrinsic() - intrinsic.set_intrinsics(intrinsic_parameters['width'], intrinsic_parameters['height'], - intrinsic_parameters['fx'], intrinsic_parameters['fy'], - intrinsic_parameters['cx'], intrinsic_parameters['cy']) + intrinsic.set_intrinsics( + intrinsic_parameters["width"], + intrinsic_parameters["height"], + intrinsic_parameters["fx"], + intrinsic_parameters["fy"], + intrinsic_parameters["cx"], + intrinsic_parameters["cy"], + ) pcd = o3d.geometry.PointCloud.create_from_rgbd_image(rgbd_image, intrinsic) return pcd -def canonicalize_point_cloud(pcd, canonicalize_threshold=0.3): + +def canonicalize_point_cloud(pcd, canonicalize_threshold: float=0.3): # type: ignore[no-untyped-def] # Segment the largest plane, assumed to be the floor - plane_model, inliers = pcd.segment_plane(distance_threshold=0.01, ransac_n=3, num_iterations=1000) + plane_model, inliers = pcd.segment_plane( + distance_threshold=0.01, ransac_n=3, num_iterations=1000 + ) canonicalized = False if len(inliers) / len(pcd.points) > canonicalize_threshold: @@ -61,9 +85,9 @@ def canonicalize_point_cloud(pcd, canonicalize_threshold=0.3): pcd.transform(transformation) # Additional 180-degree rotation around the Z-axis - rotation_z_180 = np.array([[np.cos(np.pi), -np.sin(np.pi), 0], - [np.sin(np.pi), np.cos(np.pi), 0], - [0, 0, 1]]) + rotation_z_180 = np.array( + [[np.cos(np.pi), -np.sin(np.pi), 0], [np.sin(np.pi), np.cos(np.pi), 0], [0, 0, 1]] + ) pcd.rotate(rotation_z_180, center=(0, 0, 0)) return pcd, canonicalized, transformation @@ -72,7 +96,7 @@ def canonicalize_point_cloud(pcd, canonicalize_threshold=0.3): # Distance calculations -def human_like_distance(distance_meters): +def human_like_distance(distance_meters) -> str: # type: ignore[no-untyped-def] # Define the choices with units included, focusing on the 0.1 to 10 meters range if distance_meters < 1: # For distances less than 1 meter choices = [ @@ -115,7 +139,7 @@ def human_like_distance(distance_meters): cumulative_distribution = [] cumulative_sum = 0 for value, unit, probability in choices: - cumulative_sum += probability / total_probability # Normalize probabilities + cumulative_sum += probability / total_probability # type: ignore[assignment] # Normalize probabilities cumulative_distribution.append((cumulative_sum, value, unit)) # Randomly choose based on the cumulative distribution @@ -127,20 +151,23 @@ def human_like_distance(distance_meters): # Fallback to the last choice if something goes wrong return f"{choices[-1][0]} {choices[-1][1]}" -def calculate_distances_between_point_clouds(A, B): + +def calculate_distances_between_point_clouds(A, B): # type: ignore[no-untyped-def] dist_pcd1_to_pcd2 = np.asarray(A.compute_point_cloud_distance(B)) dist_pcd2_to_pcd1 = np.asarray(B.compute_point_cloud_distance(A)) combined_distances = np.concatenate((dist_pcd1_to_pcd2, dist_pcd2_to_pcd1)) avg_dist = np.mean(combined_distances) return human_like_distance(avg_dist) -def calculate_centroid(pcd): + +def calculate_centroid(pcd): # type: ignore[no-untyped-def] """Calculate the centroid of a point cloud.""" points = np.asarray(pcd.points) centroid = np.mean(points, axis=0) return centroid -def calculate_relative_positions(centroids): + +def calculate_relative_positions(centroids): # type: ignore[no-untyped-def] """Calculate the relative positions between centroids of point clouds.""" num_centroids = len(centroids) relative_positions_info = [] @@ -150,15 +177,14 @@ def calculate_relative_positions(centroids): relative_vector = centroids[j] - centroids[i] distance = np.linalg.norm(relative_vector) - relative_positions_info.append({ - 'pcd_pair': (i, j), - 'relative_vector': relative_vector, - 'distance': distance - }) + relative_positions_info.append( + {"pcd_pair": (i, j), "relative_vector": relative_vector, "distance": distance} + ) return relative_positions_info -def get_bounding_box_height(pcd): + +def get_bounding_box_height(pcd): # type: ignore[no-untyped-def] """ Compute the height of the bounding box for a given point cloud. @@ -171,7 +197,8 @@ def get_bounding_box_height(pcd): aabb = pcd.get_axis_aligned_bounding_box() return aabb.get_extent()[1] # Assuming the Y-axis is the up-direction -def compare_bounding_box_height(pcd_i, pcd_j): + +def compare_bounding_box_height(pcd_i, pcd_j): # type: ignore[no-untyped-def] """ Compare the bounding box heights of two point clouds. @@ -182,7 +209,7 @@ def compare_bounding_box_height(pcd_i, pcd_j): Returns: bool: True if the bounding box of pcd_i is taller than that of pcd_j, False otherwise. """ - height_i = get_bounding_box_height(pcd_i) - height_j = get_bounding_box_height(pcd_j) + height_i = get_bounding_box_height(pcd_i) # type: ignore[no-untyped-call] + height_j = get_bounding_box_height(pcd_j) # type: ignore[no-untyped-call] return height_i > height_j diff --git a/dimos/models/qwen/video_query.py b/dimos/models/qwen/video_query.py new file mode 100644 index 0000000000..0b14bdfbc8 --- /dev/null +++ b/dimos/models/qwen/video_query.py @@ -0,0 +1,241 @@ +"""Utility functions for one-off video frame queries using Qwen model.""" + +import json +import os + +import numpy as np +from openai import OpenAI +from reactivex import Observable, operators as ops +from reactivex.subject import Subject + +from dimos.agents.agent import OpenAIAgent +from dimos.agents.tokenizer.huggingface_tokenizer import HuggingFaceTokenizer +from dimos.utils.threadpool import get_scheduler + +BBox = tuple[float, float, float, float] # (x1, y1, x2, y2) + + +def query_single_frame_observable( + video_observable: Observable, # type: ignore[type-arg] + query: str, + api_key: str | None = None, + model_name: str = "qwen2.5-vl-72b-instruct", +) -> Observable: # type: ignore[type-arg] + """Process a single frame from a video observable with Qwen model. + + Args: + video_observable: An observable that emits video frames + query: The query to ask about the frame + api_key: Alibaba API key. If None, will try to get from ALIBABA_API_KEY env var + model_name: The Qwen model to use. Defaults to qwen2.5-vl-72b-instruct + + Returns: + Observable: An observable that emits a single response string + + Example: + ```python + video_obs = video_provider.capture_video_as_observable() + single_frame = video_obs.pipe(ops.take(1)) + response = query_single_frame_observable(single_frame, "What objects do you see?") + response.subscribe(print) + ``` + """ + # Get API key from env if not provided + api_key = api_key or os.getenv("ALIBABA_API_KEY") + if not api_key: + raise ValueError( + "Alibaba API key must be provided or set in ALIBABA_API_KEY environment variable" + ) + + # Create Qwen client + qwen_client = OpenAI( + base_url="https://dashscope-intl.aliyuncs.com/compatible-mode/v1", + api_key=api_key, + ) + + # Create response subject + response_subject = Subject() # type: ignore[var-annotated] + + # Create temporary agent for processing + agent = OpenAIAgent( + dev_name="QwenSingleFrameAgent", + openai_client=qwen_client, + model_name=model_name, + tokenizer=HuggingFaceTokenizer(model_name=f"Qwen/{model_name}"), + max_output_tokens_per_request=100, + system_query=query, + pool_scheduler=get_scheduler(), + ) + + # Take only first frame + single_frame = video_observable.pipe(ops.take(1)) + + # Subscribe to frame processing and forward response to our subject + agent.subscribe_to_image_processing(single_frame) + + # Forward agent responses to our response subject + agent.get_response_observable().subscribe( + on_next=lambda x: response_subject.on_next(x), + on_error=lambda e: response_subject.on_error(e), + on_completed=lambda: response_subject.on_completed(), + ) + + # Clean up agent when response subject completes + response_subject.subscribe(on_completed=lambda: agent.dispose_all()) + + return response_subject + + +def query_single_frame( + image: np.ndarray, # type: ignore[type-arg] + query: str = "Return the center coordinates of the fridge handle as a tuple (x,y)", + api_key: str | None = None, + model_name: str = "qwen2.5-vl-72b-instruct", +) -> str: + """Process a single numpy image array with Qwen model. + + Args: + image: A numpy array image to process (H, W, 3) in RGB format + query: The query to ask about the image + api_key: Alibaba API key. If None, will try to get from ALIBABA_API_KEY env var + model_name: The Qwen model to use. Defaults to qwen2.5-vl-72b-instruct + + Returns: + str: The model's response + + Example: + ```python + import cv2 + image = cv2.imread('image.jpg') + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # Convert to RGB + response = query_single_frame(image, "Return the center coordinates of the object _____ as a tuple (x,y)") + print(response) + ``` + """ + # Get API key from env if not provided + api_key = api_key or os.getenv("ALIBABA_API_KEY") + if not api_key: + raise ValueError( + "Alibaba API key must be provided or set in ALIBABA_API_KEY environment variable" + ) + + # Create Qwen client + qwen_client = OpenAI( + base_url="https://dashscope-intl.aliyuncs.com/compatible-mode/v1", + api_key=api_key, + ) + + # Create temporary agent for processing + agent = OpenAIAgent( + dev_name="QwenSingleFrameAgent", + openai_client=qwen_client, + model_name=model_name, + tokenizer=HuggingFaceTokenizer(model_name=f"Qwen/{model_name}"), + max_output_tokens_per_request=8192, + system_query=query, + pool_scheduler=get_scheduler(), + ) + + # Use the numpy array directly (no conversion needed) + frame = image + + # Create a Subject that will emit the image once + frame_subject = Subject() # type: ignore[var-annotated] + + # Subscribe to frame processing + agent.subscribe_to_image_processing(frame_subject) + + # Create response observable + response_observable = agent.get_response_observable() + + # Emit the image + frame_subject.on_next(frame) + frame_subject.on_completed() + + # Take first response and run synchronously + response = response_observable.pipe(ops.take(1)).run() + + # Clean up + agent.dispose_all() + + return response # type: ignore[no-any-return] + + +def get_bbox_from_qwen( + video_stream: Observable, object_name: str | None = None # type: ignore[type-arg] +) -> tuple[BBox, float] | None: + """Get bounding box coordinates from Qwen for a specific object or any object. + + Args: + video_stream: Observable video stream + object_name: Optional name of object to detect + + Returns: + Tuple of (bbox, size) where bbox is (x1, y1, x2, y2) and size is height in meters, + or None if no detection + """ + prompt = ( + f"Look at this image and find the {object_name if object_name else 'most prominent object'}. Estimate the approximate height of the subject." + "Return ONLY a JSON object with format: {'name': 'object_name', 'bbox': [x1, y1, x2, y2], 'size': height_in_meters} " + "where x1,y1 is the top-left and x2,y2 is the bottom-right corner of the bounding box. If not found, return None." + ) + + response = query_single_frame_observable(video_stream, prompt).pipe(ops.take(1)).run() + + try: + # Extract JSON from response + start_idx = response.find("{") + end_idx = response.rfind("}") + 1 + if start_idx >= 0 and end_idx > start_idx: + json_str = response[start_idx:end_idx] + result = json.loads(json_str) + + # Extract and validate bbox + if "bbox" in result and len(result["bbox"]) == 4: + bbox = tuple(result["bbox"]) # Convert list to tuple + return (bbox, result["size"]) + except Exception as e: + print(f"Error parsing Qwen response: {e}") + print(f"Raw response: {response}") + + return None + + +def get_bbox_from_qwen_frame(frame, object_name: str | None = None) -> BBox | None: # type: ignore[no-untyped-def] + """Get bounding box coordinates from Qwen for a specific object or any object using a single frame. + + Args: + frame: A single image frame (numpy array in RGB format) + object_name: Optional name of object to detect + + Returns: + BBox: Bounding box as (x1, y1, x2, y2) or None if no detection + """ + # Ensure frame is numpy array + if not isinstance(frame, np.ndarray): + raise ValueError("Frame must be a numpy array") + + prompt = ( + f"Look at this image and find the {object_name if object_name else 'most prominent object'}. " + "Return ONLY a JSON object with format: {'name': 'object_name', 'bbox': [x1, y1, x2, y2]} " + "where x1,y1 is the top-left and x2,y2 is the bottom-right corner of the bounding box. If not found, return None." + ) + + response = query_single_frame(frame, prompt) + + try: + # Extract JSON from response + start_idx = response.find("{") + end_idx = response.rfind("}") + 1 + if start_idx >= 0 and end_idx > start_idx: + json_str = response[start_idx:end_idx] + result = json.loads(json_str) + + # Extract and validate bbox + if "bbox" in result and len(result["bbox"]) == 4: + return tuple(result["bbox"]) # Convert list to tuple + except Exception as e: + print(f"Error parsing Qwen response: {e}") + print(f"Raw response: {response}") + + return None diff --git a/dimos/models/segmentation/clipseg.py b/dimos/models/segmentation/clipseg.py deleted file mode 100644 index ddc0cc55d4..0000000000 --- a/dimos/models/segmentation/clipseg.py +++ /dev/null @@ -1,14 +0,0 @@ -from transformers import AutoProcessor, CLIPSegForImageSegmentation -import torch -import numpy as np - -class CLIPSeg: - def __init__(self, model_name="CIDAS/clipseg-rd64-refined"): - self.clipseg_processor = AutoProcessor.from_pretrained(model_name) - self.clipseg_model = CLIPSegForImageSegmentation.from_pretrained(model_name) - - def run_inference(self, image, text_descriptions): - inputs = self.clipseg_processor(text=text_descriptions, images=[image] * len(text_descriptions), padding=True, return_tensors="pt") - outputs = self.clipseg_model(**inputs) - logits = outputs.logits - return logits.detach().unsqueeze(1) \ No newline at end of file diff --git a/dimos/models/segmentation/sam.py b/dimos/models/segmentation/sam.py deleted file mode 100644 index 0a1934dcb0..0000000000 --- a/dimos/models/segmentation/sam.py +++ /dev/null @@ -1,15 +0,0 @@ -from transformers import SamModel, SamProcessor -import torch -import numpy as np - -class SAM: - def __init__(self, model_name="facebook/sam-vit-huge", device="cuda"): - self.device = device - self.sam_model = SamModel.from_pretrained(model_name).to(self.device) - self.sam_processor = SamProcessor.from_pretrained(model_name) - - def run_inference_from_points(self, image, points): - sam_inputs = self.sam_processor(image, input_points=points, return_tensors="pt").to(self.device) - with torch.no_grad(): - sam_outputs = self.sam_model(**sam_inputs) - return self.sam_processor.image_processor.post_process_masks(sam_outputs.pred_masks.cpu(), sam_inputs["original_sizes"].cpu(), sam_inputs["reshaped_input_sizes"].cpu()) diff --git a/dimos/models/segmentation/segment_utils.py b/dimos/models/segmentation/segment_utils.py index 197ef9e11f..e203c56bf6 100644 --- a/dimos/models/segmentation/segment_utils.py +++ b/dimos/models/segmentation/segment_utils.py @@ -1,7 +1,22 @@ -import torch +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import numpy as np +import torch -def find_medoid_and_closest_points(points, num_closest=5): + +def find_medoid_and_closest_points(points, num_closest: int=5): # type: ignore[no-untyped-def] """ Find the medoid from a collection of points and the closest points to the medoid. @@ -18,38 +33,41 @@ def find_medoid_and_closest_points(points, num_closest=5): medoid_idx = np.argmin(distance_sums) medoid = points[medoid_idx] sorted_indices = np.argsort(distances[medoid_idx]) - closest_indices = sorted_indices[1:num_closest + 1] + closest_indices = sorted_indices[1 : num_closest + 1] return medoid, points[closest_indices] -def sample_points_from_heatmap(heatmap, original_size, num_points=5, percentile=0.95): + +def sample_points_from_heatmap(heatmap, original_size: int, num_points: int=5, percentile: float=0.95): # type: ignore[no-untyped-def] """ Sample points from the given heatmap, focusing on areas with higher values. """ - width, height = original_size + width, height = original_size # type: ignore[misc] threshold = np.percentile(heatmap.numpy(), percentile) masked_heatmap = torch.where(heatmap > threshold, heatmap, torch.tensor(0.0)) probabilities = torch.softmax(masked_heatmap.flatten(), dim=0) attn = torch.sigmoid(heatmap) w = attn.shape[0] - sampled_indices = torch.multinomial(torch.tensor(probabilities.ravel()), num_points, replacement=True) + sampled_indices = torch.multinomial( + torch.tensor(probabilities.ravel()), num_points, replacement=True + ) sampled_coords = np.array(np.unravel_index(sampled_indices, attn.shape)).T - medoid, sampled_coords = find_medoid_and_closest_points(sampled_coords) + _medoid, sampled_coords = find_medoid_and_closest_points(sampled_coords) pts = [] for pt in sampled_coords.tolist(): x, y = pt - x = height * x / w - y = width * y / w + x = height * x / w # type: ignore[has-type] + y = width * y / w # type: ignore[has-type] pts.append([y, x]) return pts -def apply_mask_to_image(image, mask): +def apply_mask_to_image(image, mask): # type: ignore[no-untyped-def] """ Apply a binary mask to an image. The mask should be a binary array where the regions to keep are True. """ masked_image = image.copy() for c in range(masked_image.shape[2]): masked_image[:, :, c] = masked_image[:, :, c] * mask - return masked_image \ No newline at end of file + return masked_image diff --git a/dimos/models/vl/README.md b/dimos/models/vl/README.md new file mode 100644 index 0000000000..c252d47957 --- /dev/null +++ b/dimos/models/vl/README.md @@ -0,0 +1,67 @@ +# Vision Language Models + +This provides vision language model implementations for processing images and text queries. + +## QwenVL Model + +The `QwenVlModel` class provides access to Alibaba's Qwen2.5-VL model for vision-language tasks. + +### Example Usage + +```python +from dimos.models.vl.qwen import QwenVlModel +from dimos.msgs.sensor_msgs.Image import Image + +# Initialize the model (requires ALIBABA_API_KEY environment variable) +model = QwenVlModel() + +image = Image.from_file("path/to/your/image.jpg") + +response = model.query(image.data, "What do you see in this image?") +print(response) +``` + +## Moondream Hosted Model + +The `MoondreamHostedVlModel` class provides access to the hosted Moondream API for fast vision-language tasks. + +**Prerequisites:** + +You must export your API key before using the model: +```bash +export MOONDREAM_API_KEY="your_api_key_here" +``` + +### Capabilities + +The model supports four modes of operation: + +1. **Caption**: Generate a description of the image. +2. **Query**: Ask natural language questions about the image. +3. **Detect**: Find bounding boxes for specific objects. +4. **Point**: Locate the center points of specific objects. + +### Example Usage + +```python +from dimos.models.vl.moondream_hosted import MoondreamHostedVlModel +from dimos.msgs.sensor_msgs import Image + +model = MoondreamHostedVlModel() +image = Image.from_file("path/to/image.jpg") + +# 1. Caption +print(f"Caption: {model.caption(image)}") + +# 2. Query +print(f"Answer: {model.query(image, 'Is there a person in the image?')}") + +# 3. Detect (returns ImageDetections2D) +detections = model.query_detections(image, "person") +for det in detections.detections: + print(f"Found person at {det.bbox}") + +# 4. Point (returns list of (x, y) coordinates) +points = model.point(image, "person") +print(f"Person centers: {points}") +``` diff --git a/dimos/models/vl/__init__.py b/dimos/models/vl/__init__.py new file mode 100644 index 0000000000..3ea4a28453 --- /dev/null +++ b/dimos/models/vl/__init__.py @@ -0,0 +1,4 @@ +from dimos.models.vl.base import VlModel +from dimos.models.vl.moondream import MoondreamVlModel +from dimos.models.vl.moondream_hosted import MoondreamHostedVlModel +from dimos.models.vl.qwen import QwenVlModel diff --git a/dimos/models/vl/base.py b/dimos/models/vl/base.py new file mode 100644 index 0000000000..acb998d274 --- /dev/null +++ b/dimos/models/vl/base.py @@ -0,0 +1,106 @@ +from abc import ABC, abstractmethod +import json +import logging + +from dimos.msgs.sensor_msgs import Image +from dimos.perception.detection.type import Detection2DBBox, ImageDetections2D +from dimos.utils.data import get_data +from dimos.utils.decorators import retry +from dimos.utils.llm_utils import extract_json + +logger = logging.getLogger(__name__) + + +def vlm_detection_to_detection2d( + vlm_detection: list, track_id: int, image: Image # type: ignore[type-arg] +) -> Detection2DBBox | None: + """Convert a single VLM detection [label, x1, y1, x2, y2] to Detection2DBBox. + + Args: + vlm_detection: Single detection list containing [label, x1, y1, x2, y2] + track_id: Track ID to assign to this detection + image: Source image for the detection + + Returns: + Detection2DBBox instance or None if invalid + """ + # Validate list structure + if not isinstance(vlm_detection, list): + logger.debug(f"VLM detection is not a list: {type(vlm_detection)}") + return None + + if len(vlm_detection) != 5: + logger.debug( + f"Invalid VLM detection length: {len(vlm_detection)}, expected 5. Got: {vlm_detection}" + ) + return None + + # Extract label + name = str(vlm_detection[0]) + + # Validate and convert coordinates + try: + coords = [float(x) for x in vlm_detection[1:]] + except (ValueError, TypeError) as e: + logger.debug(f"Invalid VLM detection coordinates: {vlm_detection[1:]}. Error: {e}") + return None + + bbox = tuple(coords) + + # Use -1 for class_id since VLM doesn't provide it + # confidence defaults to 1.0 for VLM + return Detection2DBBox( + bbox=bbox, # type: ignore[arg-type] + track_id=track_id, + class_id=-1, + confidence=1.0, + name=name, + ts=image.ts, + image=image, + ) + + +class VlModel(ABC): + @abstractmethod + def query(self, image: Image, query: str, **kwargs) -> str: ... # type: ignore[no-untyped-def] + + def warmup(self) -> None: + try: + image = Image.from_file(get_data("cafe-smol.jpg")).to_rgb() # type: ignore[arg-type] + self._model.detect(image, "person", settings={"max_objects": 1}) # type: ignore[attr-defined] + except Exception: + pass + + # requery once if JSON parsing fails + @retry(max_retries=2, on_exception=json.JSONDecodeError, delay=0.0) # type: ignore[misc] + def query_json(self, image: Image, query: str) -> dict: # type: ignore[type-arg] + response = self.query(image, query) + return extract_json(response) # type: ignore[return-value] + + def query_detections(self, image: Image, query: str, **kwargs) -> ImageDetections2D: # type: ignore[no-untyped-def] + full_query = f"""show me bounding boxes in pixels for this query: `{query}` + + format should be: + `[ + [label, x1, y1, x2, y2] + ... + ]` + + (etc, multiple matches are possible) + + If there's no match return `[]`. Label is whatever you think is appropriate + Only respond with the coordinates, no other text.""" + + image_detections = ImageDetections2D(image) + + try: + detection_tuples = self.query_json(image, full_query) + except Exception: + return image_detections + + for track_id, detection_tuple in enumerate(detection_tuples): + detection2d = vlm_detection_to_detection2d(detection_tuple, track_id, image) + if detection2d is not None and detection2d.is_valid(): + image_detections.detections.append(detection2d) + + return image_detections diff --git a/dimos/models/vl/moondream.py b/dimos/models/vl/moondream.py new file mode 100644 index 0000000000..485377e305 --- /dev/null +++ b/dimos/models/vl/moondream.py @@ -0,0 +1,113 @@ +from functools import cached_property +import warnings + +import numpy as np +from PIL import Image as PILImage +import torch +from transformers import AutoModelForCausalLM # type: ignore[import-untyped] + +from dimos.models.vl.base import VlModel +from dimos.msgs.sensor_msgs import Image +from dimos.perception.detection.type import Detection2DBBox, ImageDetections2D + + +class MoondreamVlModel(VlModel): + _model_name: str + _device: str + _dtype: torch.dtype + + def __init__( + self, + model_name: str = "vikhyatk/moondream2", + device: str | None = None, + dtype: torch.dtype = torch.bfloat16, + ) -> None: + self._model_name = model_name + self._device = device or ("cuda" if torch.cuda.is_available() else "cpu") + self._dtype = dtype + + @cached_property + def _model(self) -> AutoModelForCausalLM: + model = AutoModelForCausalLM.from_pretrained( + self._model_name, + trust_remote_code=True, + torch_dtype=self._dtype, + ) + model = model.to(self._device) + model.compile() + + return model + + def query(self, image: Image | np.ndarray, query: str, **kwargs) -> str: # type: ignore[no-untyped-def, type-arg] + if isinstance(image, np.ndarray): + warnings.warn( + "MoondreamVlModel.query should receive standard dimos Image type, not a numpy array", + DeprecationWarning, + stacklevel=2, + ) + image = Image.from_numpy(image) + + # Convert dimos Image to PIL Image + # dimos Image stores data in RGB/BGR format, convert to RGB for PIL + rgb_image = image.to_rgb() + pil_image = PILImage.fromarray(rgb_image.data) + + # Query the model + result = self._model.query(image=pil_image, question=query, reasoning=False) + + # Handle both dict and string responses + if isinstance(result, dict): + return result.get("answer", str(result)) # type: ignore[no-any-return] + + return str(result) + + def query_detections(self, image: Image, query: str, **kwargs) -> ImageDetections2D: # type: ignore[no-untyped-def] + """Detect objects using Moondream's native detect method. + + Args: + image: Input image + query: Object query (e.g., "person", "car") + max_objects: Maximum number of objects to detect + + Returns: + ImageDetections2D containing detected bounding boxes + """ + pil_image = PILImage.fromarray(image.data) + + settings = {"max_objects": kwargs.get("max_objects", 5)} + result = self._model.detect(pil_image, query, settings=settings) + + # Convert to ImageDetections2D + image_detections = ImageDetections2D(image) + + # Get image dimensions for converting normalized coords to pixels + height, width = image.height, image.width + + for track_id, obj in enumerate(result.get("objects", [])): + # Convert normalized coordinates (0-1) to pixel coordinates + x_min_norm = obj["x_min"] + y_min_norm = obj["y_min"] + x_max_norm = obj["x_max"] + y_max_norm = obj["y_max"] + + x1 = x_min_norm * width + y1 = y_min_norm * height + x2 = x_max_norm * width + y2 = y_max_norm * height + + bbox = (x1, y1, x2, y2) + + detection = Detection2DBBox( + bbox=bbox, + track_id=track_id, + class_id=-1, # Moondream doesn't provide class IDs + confidence=1.0, # Moondream doesn't provide confidence scores + name=query, # Use the query as the object name + ts=image.ts, + image=image, + ) + + if detection.is_valid(): + image_detections.detections.append(detection) + + return image_detections diff --git a/dimos/models/vl/moondream_hosted.py b/dimos/models/vl/moondream_hosted.py new file mode 100644 index 0000000000..528517d4c7 --- /dev/null +++ b/dimos/models/vl/moondream_hosted.py @@ -0,0 +1,133 @@ +from functools import cached_property +import os +import warnings + +import moondream as md # type: ignore[import-untyped] +import numpy as np +from PIL import Image as PILImage + +from dimos.models.vl.base import VlModel +from dimos.msgs.sensor_msgs import Image +from dimos.perception.detection.type import Detection2DBBox, ImageDetections2D + + +class MoondreamHostedVlModel(VlModel): + _api_key: str | None + + def __init__(self, api_key: str | None = None) -> None: + self._api_key = api_key + + @cached_property + def _client(self) -> md.vl: + api_key = self._api_key or os.getenv("MOONDREAM_API_KEY") + if not api_key: + raise ValueError( + "Moondream API key must be provided or set in MOONDREAM_API_KEY environment variable" + ) + return md.vl(api_key=api_key) + + def _to_pil_image(self, image: Image | np.ndarray) -> PILImage.Image: # type: ignore[type-arg] + if isinstance(image, np.ndarray): + warnings.warn( + "MoondreamHostedVlModel should receive standard dimos Image type, not a numpy array", + DeprecationWarning, + stacklevel=3, + ) + image = Image.from_numpy(image) + + rgb_image = image.to_rgb() + return PILImage.fromarray(rgb_image.data) + + def query(self, image: Image | np.ndarray, query: str, **kwargs) -> str: # type: ignore[no-untyped-def, type-arg] + pil_image = self._to_pil_image(image) + + result = self._client.query(pil_image, query) + return result.get("answer", str(result)) # type: ignore[no-any-return] + + def caption(self, image: Image | np.ndarray, length: str = "normal") -> str: # type: ignore[type-arg] + """Generate a caption for the image. + + Args: + image: Input image + length: Caption length ("normal", "short", "long") + """ + pil_image = self._to_pil_image(image) + result = self._client.caption(pil_image, length=length) + return result.get("caption", str(result)) # type: ignore[no-any-return] + + def query_detections(self, image: Image, query: str, **kwargs) -> ImageDetections2D: # type: ignore[no-untyped-def] + """Detect objects using Moondream's hosted detect method. + + Args: + image: Input image + query: Object query (e.g., "person", "car") + max_objects: Maximum number of objects to detect (not directly supported by hosted API args in docs, + but we handle the output) + + Returns: + ImageDetections2D containing detected bounding boxes + """ + pil_image = self._to_pil_image(image) + + # API docs: detect(image, object) -> {"objects": [...]} + result = self._client.detect(pil_image, query) + objects = result.get("objects", []) + + # Convert to ImageDetections2D + image_detections = ImageDetections2D(image) + height, width = image.height, image.width + + for track_id, obj in enumerate(objects): + # Expected format from docs: Region with x_min, y_min, x_max, y_max + # Assuming normalized coordinates as per local model and standard VLM behavior + x_min_norm = obj.get("x_min", 0.0) + y_min_norm = obj.get("y_min", 0.0) + x_max_norm = obj.get("x_max", 1.0) + y_max_norm = obj.get("y_max", 1.0) + + x1 = x_min_norm * width + y1 = y_min_norm * height + x2 = x_max_norm * width + y2 = y_max_norm * height + + bbox = (x1, y1, x2, y2) + + detection = Detection2DBBox( + bbox=bbox, + track_id=track_id, + class_id=-1, + confidence=1.0, + name=query, + ts=image.ts, + image=image, + ) + + if detection.is_valid(): + image_detections.detections.append(detection) + + return image_detections + + def point(self, image: Image, query: str) -> list[tuple[float, float]]: + """Get coordinates of specific objects in an image. + + Args: + image: Input image + query: Object query + + Returns: + List of (x, y) pixel coordinates + """ + pil_image = self._to_pil_image(image) + result = self._client.point(pil_image, query) + points = result.get("points", []) + + pixel_points = [] + height, width = image.height, image.width + + for p in points: + x_norm = p.get("x", 0.0) + y_norm = p.get("y", 0.0) + pixel_points.append((x_norm * width, y_norm * height)) + + return pixel_points + diff --git a/dimos/models/vl/qwen.py b/dimos/models/vl/qwen.py new file mode 100644 index 0000000000..4a5948b486 --- /dev/null +++ b/dimos/models/vl/qwen.py @@ -0,0 +1,62 @@ +from functools import cached_property +import os + +import numpy as np +from openai import OpenAI + +from dimos.models.vl.base import VlModel +from dimos.msgs.sensor_msgs import Image + + +class QwenVlModel(VlModel): + _model_name: str + _api_key: str | None + + def __init__(self, api_key: str | None = None, model_name: str = "qwen2.5-vl-72b-instruct") -> None: + self._model_name = model_name + self._api_key = api_key + + @cached_property + def _client(self) -> OpenAI: + api_key = self._api_key or os.getenv("ALIBABA_API_KEY") + if not api_key: + raise ValueError( + "Alibaba API key must be provided or set in ALIBABA_API_KEY environment variable" + ) + + return OpenAI( + base_url="https://dashscope-intl.aliyuncs.com/compatible-mode/v1", + api_key=api_key, + ) + + def query(self, image: Image | np.ndarray, query: str) -> str: # type: ignore[override, type-arg] + if isinstance(image, np.ndarray): + import warnings + + warnings.warn( + "QwenVlModel.query should receive standard dimos Image type, not a numpy array", + DeprecationWarning, + stacklevel=2, + ) + + image = Image.from_numpy(image) + + img_base64 = image.to_base64() + + response = self._client.chat.completions.create( + model=self._model_name, + messages=[ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{img_base64}"}, + }, + {"type": "text", "text": query}, + ], + } + ], + ) + + return response.choices[0].message.content # type: ignore[return-value] diff --git a/dimos/models/vl/test_base.py b/dimos/models/vl/test_base.py new file mode 100644 index 0000000000..3d8575fab3 --- /dev/null +++ b/dimos/models/vl/test_base.py @@ -0,0 +1,105 @@ +import os +from unittest.mock import MagicMock + +import pytest + +from dimos.models.vl.qwen import QwenVlModel +from dimos.msgs.sensor_msgs import Image +from dimos.perception.detection.type import ImageDetections2D +from dimos.utils.data import get_data + +# Captured actual response from Qwen API for cafe.jpg with query "humans" +# Added garbage around JSON to ensure we are robustly extracting it +MOCK_QWEN_RESPONSE = """ + Locating humans for you 😊😊 + + [ + ["humans", 76, 368, 219, 580], + ["humans", 354, 372, 512, 525], + ["humans", 409, 370, 615, 748], + ["humans", 628, 350, 762, 528], + ["humans", 785, 323, 960, 650] + ] + + Here is some trash at the end of the response :) + Let me know if you need anything else 😀😊 + """ + + +def test_query_detections_mocked() -> None: + """Test query_detections with mocked API response (no API key required).""" + # Load test image + image = Image.from_file(get_data("cafe.jpg")) + + # Create model and mock the query method + model = QwenVlModel() + model.query = MagicMock(return_value=MOCK_QWEN_RESPONSE) + + # Query for humans in the image + query = "humans" + detections = model.query_detections(image, query) + + # Verify the return type + assert isinstance(detections, ImageDetections2D) + + # Should have 5 detections based on our mock data + assert len(detections.detections) == 5, ( + f"Expected 5 detections, got {len(detections.detections)}" + ) + + # Verify each detection + img_height, img_width = image.shape[:2] + + for i, detection in enumerate(detections.detections): + # Verify attributes + assert detection.name == "humans" + assert detection.confidence == 1.0 + assert detection.class_id == -1 # VLM detections use -1 for class_id + assert detection.track_id == i + assert len(detection.bbox) == 4 + + assert detection.is_valid() + + # Verify bbox coordinates are valid (out-of-bounds detections are discarded) + x1, y1, x2, y2 = detection.bbox + assert x2 > x1, f"Detection {i}: Invalid x coordinates: x1={x1}, x2={x2}" + assert y2 > y1, f"Detection {i}: Invalid y coordinates: y1={y1}, y2={y2}" + + # Check bounds (out-of-bounds detections would have been discarded) + assert 0 <= x1 <= img_width, f"Detection {i}: x1={x1} out of bounds" + assert 0 <= x2 <= img_width, f"Detection {i}: x2={x2} out of bounds" + assert 0 <= y1 <= img_height, f"Detection {i}: y1={y1} out of bounds" + assert 0 <= y2 <= img_height, f"Detection {i}: y2={y2} out of bounds" + + print(f"✓ Successfully processed {len(detections.detections)} mocked detections") + + +@pytest.mark.tool +@pytest.mark.skipif(not os.getenv("ALIBABA_API_KEY"), reason="ALIBABA_API_KEY not set") +def test_query_detections_real() -> None: + """Test query_detections with real API calls (requires API key).""" + # Load test image + image = Image.from_file(get_data("cafe.jpg")) + + # Initialize the model (will use real API) + model = QwenVlModel() + + # Query for humans in the image + query = "humans" + detections = model.query_detections(image, query) + + assert isinstance(detections, ImageDetections2D) + print(detections) + + # Check that detections were found + if detections.detections: + for detection in detections.detections: + # Verify each detection has expected attributes + assert detection.bbox is not None + assert len(detection.bbox) == 4 + assert detection.name + assert detection.confidence == 1.0 + assert detection.class_id == -1 # VLM detections use -1 for class_id + assert detection.is_valid() + + print(f"Found {len(detections.detections)} detections for query '{query}'") diff --git a/dimos/models/vl/test_models.py b/dimos/models/vl/test_models.py new file mode 100644 index 0000000000..b33e0905e6 --- /dev/null +++ b/dimos/models/vl/test_models.py @@ -0,0 +1,91 @@ +import time +from typing import TYPE_CHECKING + +from dimos_lcm.foxglove_msgs.ImageAnnotations import ImageAnnotations +import pytest + +from dimos.core import LCMTransport +from dimos.models.vl.moondream import MoondreamVlModel +from dimos.models.vl.qwen import QwenVlModel +from dimos.msgs.sensor_msgs import Image +from dimos.perception.detection.type import ImageDetections2D +from dimos.utils.data import get_data + +if TYPE_CHECKING: + from dimos.models.vl.base import VlModel + + +@pytest.mark.parametrize( + "model_class,model_name", + [ + (MoondreamVlModel, "Moondream"), + (QwenVlModel, "Qwen"), + ], + ids=["moondream", "qwen"], +) +@pytest.mark.gpu +def test_vlm(model_class, model_name: str) -> None: + image = Image.from_file(get_data("cafe.jpg")).to_rgb() + + print(f"Testing {model_name}") + + # Initialize model + print(f"Loading {model_name} model...") + model: VlModel = model_class() + model.warmup() + + queries = [ + "glasses", + "blue shirt", + "bulb", + "cigarette", + "reflection of a car", + "knee", + "flowers on the left table", + "shoes", + "leftmost persons ear", + "rightmost arm", + ] + + all_detections = ImageDetections2D(image) + query_times = [] + + # # First, run YOLO detection + # print("\nRunning YOLO detection...") + # yolo_detector = Yolo2DDetector() + # yolo_detections = yolo_detector.process_image(image) + # print(f" YOLO found {len(yolo_detections.detections)} objects") + # all_detections.detections.extend(yolo_detections.detections) + # annotations_transport.publish(all_detections.to_foxglove_annotations()) + + # Publish to LCM with model-specific channel names + annotations_transport: LCMTransport[ImageAnnotations] = LCMTransport( + "/annotations", ImageAnnotations + ) + + image_transport: LCMTransport[Image] = LCMTransport("/image", Image) + + image_transport.publish(image) + + # Then run VLM queries + for query in queries: + print(f"\nQuerying for: {query}") + start_time = time.time() + detections = model.query_detections(image, query, max_objects=5) + query_time = time.time() - start_time + query_times.append(query_time) + + print(f" Found {len(detections)} detections in {query_time:.3f}s") + all_detections.detections.extend(detections.detections) + annotations_transport.publish(all_detections.to_foxglove_annotations()) + + avg_time = sum(query_times) / len(query_times) if query_times else 0 + print(f"\n{model_name} Results:") + print(f" Average query time: {avg_time:.3f}s") + print(f" Total detections: {len(all_detections)}") + print(all_detections) + + annotations_transport.publish(all_detections.to_foxglove_annotations()) + + annotations_transport.lcm.stop() + image_transport.lcm.stop() diff --git a/dimos/models/vl/test_moondream_hosted.py b/dimos/models/vl/test_moondream_hosted.py new file mode 100644 index 0000000000..1f3d59d1b9 --- /dev/null +++ b/dimos/models/vl/test_moondream_hosted.py @@ -0,0 +1,98 @@ +import os +import time + +import pytest + +from dimos.models.vl.moondream_hosted import MoondreamHostedVlModel +from dimos.msgs.sensor_msgs import Image +from dimos.perception.detection.type import ImageDetections2D + +# Skip all tests in this module if API key is missing +pytestmark = pytest.mark.skipif( + not os.getenv("MOONDREAM_API_KEY"), + reason="MOONDREAM_API_KEY not set" +) + +@pytest.fixture +def model(): + return MoondreamHostedVlModel() + +@pytest.fixture +def test_image(): + image_path = os.path.join(os.getcwd(), "assets/test.png") + if not os.path.exists(image_path): + pytest.skip(f"Test image not found at {image_path}") + return Image.from_file(image_path) + +def test_caption(model, test_image) -> None: + """Test generating a caption.""" + print("\n--- Testing Caption ---") + caption = model.caption(test_image) + print(f"Caption: {caption}") + assert isinstance(caption, str) + assert len(caption) > 0 + +def test_query(model, test_image) -> None: + """Test querying the image.""" + print("\n--- Testing Query ---") + question = "Is there an xbox controller in the image?" + answer = model.query(test_image, question) + print(f"Question: {question}") + print(f"Answer: {answer}") + assert isinstance(answer, str) + assert len(answer) > 0 + # The answer should likely be positive given the user's prompt + assert "yes" in answer.lower() or "controller" in answer.lower() + +def test_query_latency(model, test_image) -> None: + """Test that a simple query returns in under 1 second.""" + print("\n--- Testing Query Latency ---") + question = "What is this?" + + # Warmup (optional, but good practice if first call establishes connection) + # model.query(test_image, "warmup") + + start_time = time.perf_counter() + model.query(test_image, question) + end_time = time.perf_counter() + + duration = end_time - start_time + print(f"Query took {duration:.4f} seconds") + + assert duration < 1.0, f"Query took too long: {duration:.4f}s > 1.0s" + +@pytest.mark.parametrize("subject", ["xbox controller", "lip balm"]) +def test_detect(model, test_image, subject: str) -> None: + """Test detecting objects.""" + print(f"\n--- Testing Detect: {subject} ---") + detections = model.query_detections(test_image, subject) + + assert isinstance(detections, ImageDetections2D) + print(f"Found {len(detections.detections)} detections for {subject}") + + # We expect to find at least one of each in the provided test image + assert len(detections.detections) > 0 + + for det in detections.detections: + assert det.is_valid() + assert det.name == subject + # Check if bbox coordinates are within image dimensions + x1, y1, x2, y2 = det.bbox + assert 0 <= x1 < x2 <= test_image.width + assert 0 <= y1 < y2 <= test_image.height + +@pytest.mark.parametrize("subject", ["xbox controller", "lip balm"]) +def test_point(model, test_image, subject: str) -> None: + """Test pointing at objects.""" + print(f"\n--- Testing Point: {subject} ---") + points = model.point(test_image, subject) + + print(f"Found {len(points)} points for {subject}: {points}") + assert isinstance(points, list) + assert len(points) > 0 + + for x, y in points: + assert isinstance(x, (int, float)) + assert isinstance(y, (int, float)) + assert 0 <= x <= test_image.width + assert 0 <= y <= test_image.height diff --git a/dimos/manipulation/classical/pose_estimation.py b/dimos/msgs/__init__.py similarity index 100% rename from dimos/manipulation/classical/pose_estimation.py rename to dimos/msgs/__init__.py diff --git a/dimos/msgs/foxglove_msgs/Color.py b/dimos/msgs/foxglove_msgs/Color.py new file mode 100644 index 0000000000..93f521727e --- /dev/null +++ b/dimos/msgs/foxglove_msgs/Color.py @@ -0,0 +1,65 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import hashlib + +from dimos_lcm.foxglove_msgs import Color as LCMColor # type: ignore[import-untyped] + + +class Color(LCMColor): # type: ignore[misc] + """Color with convenience methods.""" + + @classmethod + def from_string(cls, name: str, alpha: float = 0.2, brightness: float = 1.0) -> Color: + """Generate a consistent color from a string using hash function. + + Args: + name: String to generate color from + alpha: Transparency value (0.0-1.0) + brightness: Brightness multiplier (0.0-2.0). Values > 1.0 lighten towards white. + + Returns: + Color instance with deterministic RGB values + """ + # Hash the string to get consistent values + hash_obj = hashlib.md5(name.encode()) + hash_bytes = hash_obj.digest() + + # Use first 3 bytes for RGB (0-255) + r = hash_bytes[0] / 255.0 + g = hash_bytes[1] / 255.0 + b = hash_bytes[2] / 255.0 + + # Apply brightness adjustment + # If brightness > 1.0, mix with white to lighten + if brightness > 1.0: + mix_factor = brightness - 1.0 # 0.0 to 1.0 + r = r + (1.0 - r) * mix_factor + g = g + (1.0 - g) * mix_factor + b = b + (1.0 - b) * mix_factor + else: + # If brightness < 1.0, darken by scaling + r *= brightness + g *= brightness + b *= brightness + + # Create and return color instance + color = cls() + color.r = min(1.0, r) + color.g = min(1.0, g) + color.b = min(1.0, b) + color.a = alpha + return color diff --git a/dimos/msgs/foxglove_msgs/ImageAnnotations.py b/dimos/msgs/foxglove_msgs/ImageAnnotations.py new file mode 100644 index 0000000000..291fab65fe --- /dev/null +++ b/dimos/msgs/foxglove_msgs/ImageAnnotations.py @@ -0,0 +1,35 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos_lcm.foxglove_msgs.ImageAnnotations import ( # type: ignore[import-untyped] + ImageAnnotations as FoxgloveImageAnnotations, +) + + +class ImageAnnotations(FoxgloveImageAnnotations): # type: ignore[misc] + def __add__(self, other: "ImageAnnotations") -> "ImageAnnotations": + points = self.points + other.points + texts = self.texts + other.texts + + return ImageAnnotations( + texts=texts, + texts_length=len(texts), + points=points, + points_length=len(points), + ) + + def agent_encode(self) -> str: + if len(self.texts) == 0: + return None # type: ignore[return-value] + return list(map(lambda t: t.text, self.texts)) # type: ignore[return-value] diff --git a/dimos/msgs/foxglove_msgs/__init__.py b/dimos/msgs/foxglove_msgs/__init__.py new file mode 100644 index 0000000000..945ebf94c9 --- /dev/null +++ b/dimos/msgs/foxglove_msgs/__init__.py @@ -0,0 +1,3 @@ +from dimos.msgs.foxglove_msgs.ImageAnnotations import ImageAnnotations + +__all__ = ["ImageAnnotations"] diff --git a/dimos/msgs/geometry_msgs/Pose.py b/dimos/msgs/geometry_msgs/Pose.py new file mode 100644 index 0000000000..2a578e7f80 --- /dev/null +++ b/dimos/msgs/geometry_msgs/Pose.py @@ -0,0 +1,275 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TypeAlias + +from dimos_lcm.geometry_msgs import ( # type: ignore[import-untyped] + Pose as LCMPose, + Transform as LCMTransform, +) + +try: + from geometry_msgs.msg import ( # type: ignore[attr-defined] + Point as ROSPoint, + Pose as ROSPose, + Quaternion as ROSQuaternion, + ) +except ImportError: + ROSPose = None # type: ignore[assignment, misc] + ROSPoint = None # type: ignore[assignment, misc] + ROSQuaternion = None # type: ignore[assignment, misc] + +from plum import dispatch + +from dimos.msgs.geometry_msgs.Quaternion import Quaternion, QuaternionConvertable +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3, VectorConvertable + +# Types that can be converted to/from Pose +PoseConvertable: TypeAlias = ( + tuple[VectorConvertable, QuaternionConvertable] + | LCMPose + | Vector3 + | dict[str, VectorConvertable | QuaternionConvertable] +) + + +class Pose(LCMPose): # type: ignore[misc] + position: Vector3 + orientation: Quaternion + msg_name = "geometry_msgs.Pose" + + @dispatch + def __init__(self) -> None: + """Initialize a pose at origin with identity orientation.""" + self.position = Vector3(0.0, 0.0, 0.0) + self.orientation = Quaternion(0.0, 0.0, 0.0, 1.0) + + @dispatch # type: ignore[no-redef] + def __init__(self, x: int | float, y: int | float, z: int | float) -> None: + """Initialize a pose with position and identity orientation.""" + self.position = Vector3(x, y, z) + self.orientation = Quaternion(0.0, 0.0, 0.0, 1.0) + + @dispatch # type: ignore[no-redef] + def __init__( + self, + x: int | float, + y: int | float, + z: int | float, + qx: int | float, + qy: int | float, + qz: int | float, + qw: int | float, + ) -> None: + """Initialize a pose with position and orientation.""" + self.position = Vector3(x, y, z) + self.orientation = Quaternion(qx, qy, qz, qw) + + @dispatch # type: ignore[no-redef] + def __init__( + self, + position: VectorConvertable | Vector3 | None = None, + orientation: QuaternionConvertable | Quaternion | None = None, + ) -> None: + """Initialize a pose with position and orientation.""" + if orientation is None: + orientation = [0, 0, 0, 1] + if position is None: + position = [0, 0, 0] + self.position = Vector3(position) + self.orientation = Quaternion(orientation) + + @dispatch # type: ignore[no-redef] + def __init__(self, pose_tuple: tuple[VectorConvertable, QuaternionConvertable]) -> None: + """Initialize from a tuple of (position, orientation).""" + self.position = Vector3(pose_tuple[0]) + self.orientation = Quaternion(pose_tuple[1]) + + @dispatch # type: ignore[no-redef] + def __init__(self, pose_dict: dict[str, VectorConvertable | QuaternionConvertable]) -> None: + """Initialize from a dictionary with 'position' and 'orientation' keys.""" + self.position = Vector3(pose_dict["position"]) + self.orientation = Quaternion(pose_dict["orientation"]) + + @dispatch # type: ignore[no-redef] + def __init__(self, pose: Pose) -> None: + """Initialize from another Pose (copy constructor).""" + self.position = Vector3(pose.position) + self.orientation = Quaternion(pose.orientation) + + @dispatch # type: ignore[no-redef] + def __init__(self, lcm_pose: LCMPose) -> None: + """Initialize from an LCM Pose.""" + self.position = Vector3(lcm_pose.position.x, lcm_pose.position.y, lcm_pose.position.z) + self.orientation = Quaternion( + lcm_pose.orientation.x, + lcm_pose.orientation.y, + lcm_pose.orientation.z, + lcm_pose.orientation.w, + ) + + @property + def x(self) -> float: + """X coordinate of position.""" + return self.position.x + + @property + def y(self) -> float: + """Y coordinate of position.""" + return self.position.y + + @property + def z(self) -> float: + """Z coordinate of position.""" + return self.position.z + + @property + def roll(self) -> float: + """Roll angle in radians.""" + return self.orientation.to_euler().roll + + @property + def pitch(self) -> float: + """Pitch angle in radians.""" + return self.orientation.to_euler().pitch + + @property + def yaw(self) -> float: + """Yaw angle in radians.""" + return self.orientation.to_euler().yaw + + def __repr__(self) -> str: + return f"Pose(position={self.position!r}, orientation={self.orientation!r})" + + def __str__(self) -> str: + return ( + f"Pose(pos=[{self.x:.3f}, {self.y:.3f}, {self.z:.3f}], " + f"euler=[{self.roll:.3f}, {self.pitch:.3f}, {self.yaw:.3f}]), " + f"quaternion=[{self.orientation}])" + ) + + def __eq__(self, other) -> bool: # type: ignore[no-untyped-def] + """Check if two poses are equal.""" + if not isinstance(other, Pose): + return False + return self.position == other.position and self.orientation == other.orientation + + def __matmul__(self, transform: LCMTransform | Transform) -> Pose: + return self + transform + + def __add__(self, other: Pose | PoseConvertable | LCMTransform | Transform) -> Pose: + """Compose two poses or apply a transform (transform composition). + + The operation self + other represents applying transformation 'other' + in the coordinate frame defined by 'self'. This is equivalent to: + - First apply transformation 'self' (from world to self's frame) + - Then apply transformation 'other' (from self's frame to other's frame) + + This matches ROS tf convention where: + T_world_to_other = T_world_to_self * T_self_to_other + + Args: + other: The pose or transform to compose with this one + + Returns: + A new Pose representing the composed transformation + + Example: + robot_pose = Pose(1, 0, 0) # Robot at (1,0,0) facing forward + object_in_robot = Pose(2, 0, 0) # Object 2m in front of robot + object_in_world = robot_pose + object_in_robot # Object at (3,0,0) in world + + # Or with a Transform: + transform = Transform() + transform.translation = Vector3(2, 0, 0) + transform.rotation = Quaternion(0, 0, 0, 1) + new_pose = pose + transform + """ + # Handle Transform objects + if isinstance(other, LCMTransform | Transform): + # Convert Transform to Pose using its translation and rotation + other_position = Vector3(other.translation) + other_orientation = Quaternion(other.rotation) + elif isinstance(other, Pose): + other_position = other.position + other_orientation = other.orientation + else: + # Convert to Pose if it's a convertible type + other_pose = Pose(other) + other_position = other_pose.position + other_orientation = other_pose.orientation + + # Compose orientations: self.orientation * other.orientation + new_orientation = self.orientation * other_orientation + + # Transform other's position by self's orientation, then add to self's position + rotated_position = self.orientation.rotate_vector(other_position) + new_position = self.position + rotated_position + + return Pose(new_position, new_orientation) + + @classmethod + def from_ros_msg(cls, ros_msg: ROSPose) -> Pose: + """Create a Pose from a ROS geometry_msgs/Pose message. + + Args: + ros_msg: ROS Pose message + + Returns: + Pose instance + """ + position = Vector3(ros_msg.position.x, ros_msg.position.y, ros_msg.position.z) + orientation = Quaternion( + ros_msg.orientation.x, + ros_msg.orientation.y, + ros_msg.orientation.z, + ros_msg.orientation.w, + ) + return cls(position, orientation) + + def to_ros_msg(self) -> ROSPose: + """Convert to a ROS geometry_msgs/Pose message. + + Returns: + ROS Pose message + """ + ros_msg = ROSPose() # type: ignore[no-untyped-call] + ros_msg.position = ROSPoint( # type: ignore[no-untyped-call] + x=float(self.position.x), y=float(self.position.y), z=float(self.position.z) + ) + ros_msg.orientation = ROSQuaternion( # type: ignore[no-untyped-call] + x=float(self.orientation.x), + y=float(self.orientation.y), + z=float(self.orientation.z), + w=float(self.orientation.w), + ) + return ros_msg + + +@dispatch +def to_pose(value: Pose) -> Pose: + """Pass through Pose objects.""" + return value + + +@dispatch # type: ignore[no-redef] +def to_pose(value: PoseConvertable) -> Pose: + """Convert a pose-compatible value to a Pose object.""" + return Pose(value) + + +PoseLike: TypeAlias = PoseConvertable | Pose diff --git a/dimos/msgs/geometry_msgs/PoseStamped.py b/dimos/msgs/geometry_msgs/PoseStamped.py new file mode 100644 index 0000000000..3c394b1744 --- /dev/null +++ b/dimos/msgs/geometry_msgs/PoseStamped.py @@ -0,0 +1,154 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import time +from typing import BinaryIO, TypeAlias + +from dimos_lcm.geometry_msgs import PoseStamped as LCMPoseStamped # type: ignore[import-untyped] + +try: + from geometry_msgs.msg import PoseStamped as ROSPoseStamped # type: ignore[attr-defined] +except ImportError: + ROSPoseStamped = None # type: ignore[assignment, misc] + +from plum import dispatch + +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.Quaternion import Quaternion, QuaternionConvertable +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3, VectorConvertable +from dimos.types.timestamped import Timestamped + +# Types that can be converted to/from Pose +PoseConvertable: TypeAlias = ( + tuple[VectorConvertable, QuaternionConvertable] + | LCMPoseStamped + | dict[str, VectorConvertable | QuaternionConvertable] +) + + +def sec_nsec(ts): # type: ignore[no-untyped-def] + s = int(ts) + return [s, int((ts - s) * 1_000_000_000)] + + +class PoseStamped(Pose, Timestamped): + msg_name = "geometry_msgs.PoseStamped" + ts: float + frame_id: str + + @dispatch + def __init__(self, ts: float = 0.0, frame_id: str = "", **kwargs) -> None: # type: ignore[no-untyped-def] + self.frame_id = frame_id + self.ts = ts if ts != 0 else time.time() + super().__init__(**kwargs) + + def lcm_encode(self) -> bytes: + lcm_mgs = LCMPoseStamped() + lcm_mgs.pose = self + [lcm_mgs.header.stamp.sec, lcm_mgs.header.stamp.nsec] = sec_nsec(self.ts) # type: ignore[no-untyped-call] + lcm_mgs.header.frame_id = self.frame_id + return lcm_mgs.lcm_encode() # type: ignore[no-any-return] + + @classmethod + def lcm_decode(cls, data: bytes | BinaryIO) -> PoseStamped: + lcm_msg = LCMPoseStamped.lcm_decode(data) + return cls( + ts=lcm_msg.header.stamp.sec + (lcm_msg.header.stamp.nsec / 1_000_000_000), + frame_id=lcm_msg.header.frame_id, + position=[lcm_msg.pose.position.x, lcm_msg.pose.position.y, lcm_msg.pose.position.z], + orientation=[ + lcm_msg.pose.orientation.x, + lcm_msg.pose.orientation.y, + lcm_msg.pose.orientation.z, + lcm_msg.pose.orientation.w, + ], + ) + + def __str__(self) -> str: + return ( + f"PoseStamped(pos=[{self.x:.3f}, {self.y:.3f}, {self.z:.3f}], " + f"euler=[{self.roll:.3f}, {self.pitch:.3f}, {self.yaw:.3f}])" + ) + + def new_transform_to(self, name: str) -> Transform: + return self.find_transform( + PoseStamped( + frame_id=name, + position=Vector3(0, 0, 0), + orientation=Quaternion(0, 0, 0, 1), # Identity quaternion + ) + ) + + def new_transform_from(self, name: str) -> Transform: + return self.new_transform_to(name).inverse() + + def find_transform(self, other: PoseStamped) -> Transform: + inv_orientation = self.orientation.conjugate() + + pos_diff = other.position - self.position + + local_translation = inv_orientation.rotate_vector(pos_diff) + + relative_rotation = inv_orientation * other.orientation + + return Transform( + child_frame_id=other.frame_id, + frame_id=self.frame_id, + translation=local_translation, + rotation=relative_rotation, + ) + + @classmethod + def from_ros_msg(cls, ros_msg: ROSPoseStamped) -> PoseStamped: # type: ignore[override] + """Create a PoseStamped from a ROS geometry_msgs/PoseStamped message. + + Args: + ros_msg: ROS PoseStamped message + + Returns: + PoseStamped instance + """ + # Convert timestamp from ROS header + ts = ros_msg.header.stamp.sec + (ros_msg.header.stamp.nanosec / 1_000_000_000) + + # Convert pose + pose = Pose.from_ros_msg(ros_msg.pose) + + return cls( + ts=ts, + frame_id=ros_msg.header.frame_id, + position=pose.position, + orientation=pose.orientation, + ) + + def to_ros_msg(self) -> ROSPoseStamped: # type: ignore[override] + """Convert to a ROS geometry_msgs/PoseStamped message. + + Returns: + ROS PoseStamped message + """ + ros_msg = ROSPoseStamped() # type: ignore[no-untyped-call] + + # Set header + ros_msg.header.frame_id = self.frame_id + ros_msg.header.stamp.sec = int(self.ts) + ros_msg.header.stamp.nanosec = int((self.ts - int(self.ts)) * 1_000_000_000) + + # Set pose + ros_msg.pose = Pose.to_ros_msg(self) + + return ros_msg diff --git a/dimos/msgs/geometry_msgs/PoseWithCovariance.py b/dimos/msgs/geometry_msgs/PoseWithCovariance.py new file mode 100644 index 0000000000..b49ecffab4 --- /dev/null +++ b/dimos/msgs/geometry_msgs/PoseWithCovariance.py @@ -0,0 +1,233 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, TypeAlias + +from dimos_lcm.geometry_msgs import ( # type: ignore[import-untyped] + PoseWithCovariance as LCMPoseWithCovariance, +) +import numpy as np +from plum import dispatch + +try: + from geometry_msgs.msg import ( # type: ignore[attr-defined] + PoseWithCovariance as ROSPoseWithCovariance, + ) +except ImportError: + ROSPoseWithCovariance = None # type: ignore[assignment, misc] + +from dimos.msgs.geometry_msgs.Pose import Pose, PoseConvertable + +if TYPE_CHECKING: + from dimos.msgs.geometry_msgs.Quaternion import Quaternion + from dimos.msgs.geometry_msgs.Vector3 import Vector3 + +# Types that can be converted to/from PoseWithCovariance +PoseWithCovarianceConvertable: TypeAlias = ( + tuple[PoseConvertable, list[float] | np.ndarray] # type: ignore[type-arg] + | LCMPoseWithCovariance + | dict[str, PoseConvertable | list[float] | np.ndarray] # type: ignore[type-arg] +) + + +class PoseWithCovariance(LCMPoseWithCovariance): # type: ignore[misc] + pose: Pose + msg_name = "geometry_msgs.PoseWithCovariance" + + @dispatch + def __init__(self) -> None: + """Initialize with default pose and zero covariance.""" + self.pose = Pose() + self.covariance = np.zeros(36) + + @dispatch # type: ignore[no-redef] + def __init__( + self, + pose: Pose | PoseConvertable, + covariance: list[float] | np.ndarray | None = None, # type: ignore[type-arg] + ) -> None: + """Initialize with pose and optional covariance.""" + self.pose = Pose(pose) if not isinstance(pose, Pose) else pose + if covariance is None: + self.covariance = np.zeros(36) + else: + self.covariance = np.array(covariance, dtype=float).reshape(36) + + @dispatch # type: ignore[no-redef] + def __init__(self, pose_with_cov: PoseWithCovariance) -> None: + """Initialize from another PoseWithCovariance (copy constructor).""" + self.pose = Pose(pose_with_cov.pose) + self.covariance = np.array(pose_with_cov.covariance).copy() + + @dispatch # type: ignore[no-redef] + def __init__(self, lcm_pose_with_cov: LCMPoseWithCovariance) -> None: + """Initialize from an LCM PoseWithCovariance.""" + self.pose = Pose(lcm_pose_with_cov.pose) + self.covariance = np.array(lcm_pose_with_cov.covariance) + + @dispatch # type: ignore[no-redef] + def __init__(self, pose_dict: dict[str, PoseConvertable | list[float] | np.ndarray]) -> None: # type: ignore[type-arg] + """Initialize from a dictionary with 'pose' and 'covariance' keys.""" + self.pose = Pose(pose_dict["pose"]) + covariance = pose_dict.get("covariance") + if covariance is None: + self.covariance = np.zeros(36) + else: + self.covariance = np.array(covariance, dtype=float).reshape(36) + + @dispatch # type: ignore[no-redef] + def __init__(self, pose_tuple: tuple[PoseConvertable, list[float] | np.ndarray]) -> None: # type: ignore[type-arg] + """Initialize from a tuple of (pose, covariance).""" + self.pose = Pose(pose_tuple[0]) + self.covariance = np.array(pose_tuple[1], dtype=float).reshape(36) + + def __getattribute__(self, name: str): # type: ignore[no-untyped-def] + """Override to ensure covariance is always returned as numpy array.""" + if name == "covariance": + cov = object.__getattribute__(self, "covariance") + if not isinstance(cov, np.ndarray): + return np.array(cov, dtype=float) + return cov + return super().__getattribute__(name) + + def __setattr__(self, name: str, value) -> None: # type: ignore[no-untyped-def] + """Override to ensure covariance is stored as numpy array.""" + if name == "covariance": + if not isinstance(value, np.ndarray): + value = np.array(value, dtype=float).reshape(36) + super().__setattr__(name, value) + + @property + def x(self) -> float: + """X coordinate of position.""" + return self.pose.x + + @property + def y(self) -> float: + """Y coordinate of position.""" + return self.pose.y + + @property + def z(self) -> float: + """Z coordinate of position.""" + return self.pose.z + + @property + def position(self) -> Vector3: + """Position vector.""" + return self.pose.position + + @property + def orientation(self) -> Quaternion: + """Orientation quaternion.""" + return self.pose.orientation + + @property + def roll(self) -> float: + """Roll angle in radians.""" + return self.pose.roll + + @property + def pitch(self) -> float: + """Pitch angle in radians.""" + return self.pose.pitch + + @property + def yaw(self) -> float: + """Yaw angle in radians.""" + return self.pose.yaw + + @property + def covariance_matrix(self) -> np.ndarray: # type: ignore[type-arg] + """Get covariance as 6x6 matrix.""" + return self.covariance.reshape(6, 6) # type: ignore[has-type, no-any-return] + + @covariance_matrix.setter + def covariance_matrix(self, value: np.ndarray) -> None: # type: ignore[type-arg] + """Set covariance from 6x6 matrix.""" + self.covariance = np.array(value).reshape(36) # type: ignore[has-type] + + def __repr__(self) -> str: + return f"PoseWithCovariance(pose={self.pose!r}, covariance=<{self.covariance.shape[0] if isinstance(self.covariance, np.ndarray) else len(self.covariance)} elements>)" # type: ignore[has-type] + + def __str__(self) -> str: + return ( + f"PoseWithCovariance(pos=[{self.x:.3f}, {self.y:.3f}, {self.z:.3f}], " + f"euler=[{self.roll:.3f}, {self.pitch:.3f}, {self.yaw:.3f}], " + f"cov_trace={np.trace(self.covariance_matrix):.3f})" + ) + + def __eq__(self, other) -> bool: # type: ignore[no-untyped-def] + """Check if two PoseWithCovariance are equal.""" + if not isinstance(other, PoseWithCovariance): + return False + return self.pose == other.pose and np.allclose(self.covariance, other.covariance) # type: ignore[has-type] + + def lcm_encode(self) -> bytes: + """Encode to LCM binary format.""" + lcm_msg = LCMPoseWithCovariance() + lcm_msg.pose = self.pose + # LCM expects list, not numpy array + if isinstance(self.covariance, np.ndarray): # type: ignore[has-type] + lcm_msg.covariance = self.covariance.tolist() # type: ignore[has-type] + else: + lcm_msg.covariance = list(self.covariance) # type: ignore[has-type] + return lcm_msg.lcm_encode() # type: ignore[no-any-return] + + @classmethod + def lcm_decode(cls, data: bytes) -> PoseWithCovariance: + """Decode from LCM binary format.""" + lcm_msg = LCMPoseWithCovariance.lcm_decode(data) + pose = Pose( + position=[lcm_msg.pose.position.x, lcm_msg.pose.position.y, lcm_msg.pose.position.z], + orientation=[ + lcm_msg.pose.orientation.x, + lcm_msg.pose.orientation.y, + lcm_msg.pose.orientation.z, + lcm_msg.pose.orientation.w, + ], + ) + return cls(pose, lcm_msg.covariance) + + @classmethod + def from_ros_msg(cls, ros_msg: ROSPoseWithCovariance) -> PoseWithCovariance: + """Create a PoseWithCovariance from a ROS geometry_msgs/PoseWithCovariance message. + + Args: + ros_msg: ROS PoseWithCovariance message + + Returns: + PoseWithCovariance instance + """ + + pose = Pose.from_ros_msg(ros_msg.pose) + return cls(pose, list(ros_msg.covariance)) + + def to_ros_msg(self) -> ROSPoseWithCovariance: + """Convert to a ROS geometry_msgs/PoseWithCovariance message. + + Returns: + ROS PoseWithCovariance message + """ + + ros_msg = ROSPoseWithCovariance() # type: ignore[no-untyped-call] + ros_msg.pose = self.pose.to_ros_msg() + # ROS expects list, not numpy array + if isinstance(self.covariance, np.ndarray): # type: ignore[has-type] + ros_msg.covariance = self.covariance.tolist() # type: ignore[has-type] + else: + ros_msg.covariance = list(self.covariance) # type: ignore[has-type] + return ros_msg diff --git a/dimos/msgs/geometry_msgs/PoseWithCovarianceStamped.py b/dimos/msgs/geometry_msgs/PoseWithCovarianceStamped.py new file mode 100644 index 0000000000..33a45a3a41 --- /dev/null +++ b/dimos/msgs/geometry_msgs/PoseWithCovarianceStamped.py @@ -0,0 +1,165 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import time +from typing import TypeAlias + +from dimos_lcm.geometry_msgs import ( # type: ignore[import-untyped] + PoseWithCovarianceStamped as LCMPoseWithCovarianceStamped, +) +import numpy as np +from plum import dispatch + +try: + from geometry_msgs.msg import ( # type: ignore[attr-defined] + PoseWithCovarianceStamped as ROSPoseWithCovarianceStamped, + ) +except ImportError: + ROSPoseWithCovarianceStamped = None # type: ignore[assignment, misc] + +from dimos.msgs.geometry_msgs.Pose import Pose, PoseConvertable +from dimos.msgs.geometry_msgs.PoseWithCovariance import PoseWithCovariance +from dimos.types.timestamped import Timestamped + +# Types that can be converted to/from PoseWithCovarianceStamped +PoseWithCovarianceStampedConvertable: TypeAlias = ( + tuple[PoseConvertable, list[float] | np.ndarray] # type: ignore[type-arg] + | LCMPoseWithCovarianceStamped + | dict[str, PoseConvertable | list[float] | np.ndarray | float | str] # type: ignore[type-arg] +) + + +def sec_nsec(ts): # type: ignore[no-untyped-def] + s = int(ts) + return [s, int((ts - s) * 1_000_000_000)] + + +class PoseWithCovarianceStamped(PoseWithCovariance, Timestamped): + msg_name = "geometry_msgs.PoseWithCovarianceStamped" + ts: float + frame_id: str + + @dispatch + def __init__(self, ts: float = 0.0, frame_id: str = "", **kwargs) -> None: + """Initialize with timestamp and frame_id.""" + self.frame_id = frame_id + self.ts = ts if ts != 0 else time.time() + super().__init__(**kwargs) + + @dispatch # type: ignore[no-redef] + def __init__( + self, + ts: float = 0.0, + frame_id: str = "", + pose: Pose | PoseConvertable | None = None, + covariance: list[float] | np.ndarray | None = None, # type: ignore[type-arg] + ) -> None: + """Initialize with timestamp, frame_id, pose and covariance.""" + self.frame_id = frame_id + self.ts = ts if ts != 0 else time.time() + if pose is None: + super().__init__() + else: + super().__init__(pose, covariance) + + def lcm_encode(self) -> bytes: + lcm_msg = LCMPoseWithCovarianceStamped() + lcm_msg.pose.pose = self.pose + # LCM expects list, not numpy array + if isinstance(self.covariance, np.ndarray): # type: ignore[has-type] + lcm_msg.pose.covariance = self.covariance.tolist() # type: ignore[has-type] + else: + lcm_msg.pose.covariance = list(self.covariance) # type: ignore[has-type] + [lcm_msg.header.stamp.sec, lcm_msg.header.stamp.nsec] = sec_nsec(self.ts) # type: ignore[no-untyped-call] + lcm_msg.header.frame_id = self.frame_id + return lcm_msg.lcm_encode() # type: ignore[no-any-return] + + @classmethod + def lcm_decode(cls, data: bytes) -> PoseWithCovarianceStamped: + lcm_msg = LCMPoseWithCovarianceStamped.lcm_decode(data) + return cls( + ts=lcm_msg.header.stamp.sec + (lcm_msg.header.stamp.nsec / 1_000_000_000), + frame_id=lcm_msg.header.frame_id, + pose=Pose( + position=[ + lcm_msg.pose.pose.position.x, + lcm_msg.pose.pose.position.y, + lcm_msg.pose.pose.position.z, + ], + orientation=[ + lcm_msg.pose.pose.orientation.x, + lcm_msg.pose.pose.orientation.y, + lcm_msg.pose.pose.orientation.z, + lcm_msg.pose.pose.orientation.w, + ], + ), + covariance=lcm_msg.pose.covariance, + ) + + def __str__(self) -> str: + return ( + f"PoseWithCovarianceStamped(pos=[{self.x:.3f}, {self.y:.3f}, {self.z:.3f}], " + f"euler=[{self.roll:.3f}, {self.pitch:.3f}, {self.yaw:.3f}], " + f"cov_trace={np.trace(self.covariance_matrix):.3f})" + ) + + @classmethod + def from_ros_msg(cls, ros_msg: ROSPoseWithCovarianceStamped) -> PoseWithCovarianceStamped: # type: ignore[override] + """Create a PoseWithCovarianceStamped from a ROS geometry_msgs/PoseWithCovarianceStamped message. + + Args: + ros_msg: ROS PoseWithCovarianceStamped message + + Returns: + PoseWithCovarianceStamped instance + """ + + # Convert timestamp from ROS header + ts = ros_msg.header.stamp.sec + (ros_msg.header.stamp.nanosec / 1_000_000_000) + + # Convert pose with covariance + pose_with_cov = PoseWithCovariance.from_ros_msg(ros_msg.pose) + + return cls( + ts=ts, + frame_id=ros_msg.header.frame_id, + pose=pose_with_cov.pose, + covariance=pose_with_cov.covariance, # type: ignore[has-type] + ) + + def to_ros_msg(self) -> ROSPoseWithCovarianceStamped: # type: ignore[override] + """Convert to a ROS geometry_msgs/PoseWithCovarianceStamped message. + + Returns: + ROS PoseWithCovarianceStamped message + """ + + ros_msg = ROSPoseWithCovarianceStamped() # type: ignore[no-untyped-call] + + # Set header + ros_msg.header.frame_id = self.frame_id + ros_msg.header.stamp.sec = int(self.ts) + ros_msg.header.stamp.nanosec = int((self.ts - int(self.ts)) * 1_000_000_000) + + # Set pose with covariance + ros_msg.pose.pose = self.pose.to_ros_msg() + # ROS expects list, not numpy array + if isinstance(self.covariance, np.ndarray): # type: ignore[has-type] + ros_msg.pose.covariance = self.covariance.tolist() # type: ignore[has-type] + else: + ros_msg.pose.covariance = list(self.covariance) # type: ignore[has-type] + + return ros_msg diff --git a/dimos/msgs/geometry_msgs/Quaternion.py b/dimos/msgs/geometry_msgs/Quaternion.py new file mode 100644 index 0000000000..cacfadbad0 --- /dev/null +++ b/dimos/msgs/geometry_msgs/Quaternion.py @@ -0,0 +1,246 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence +from io import BytesIO +import struct +from typing import BinaryIO, TypeAlias + +from dimos_lcm.geometry_msgs import Quaternion as LCMQuaternion # type: ignore[import-untyped] +import numpy as np +from plum import dispatch +from scipy.spatial.transform import Rotation as R + +from dimos.msgs.geometry_msgs.Vector3 import Vector3 + +# Types that can be converted to/from Quaternion +QuaternionConvertable: TypeAlias = Sequence[int | float] | LCMQuaternion | np.ndarray # type: ignore[type-arg] + + +class Quaternion(LCMQuaternion): # type: ignore[misc] + x: float = 0.0 + y: float = 0.0 + z: float = 0.0 + w: float = 1.0 + msg_name = "geometry_msgs.Quaternion" + + @classmethod + def lcm_decode(cls, data: bytes | BinaryIO): # type: ignore[no-untyped-def] + if not hasattr(data, "read"): + data = BytesIO(data) + if data.read(8) != cls._get_packed_fingerprint(): + raise ValueError("Decode error") + return cls._lcm_decode_one(data) # type: ignore[no-untyped-call] + + @classmethod + def _lcm_decode_one(cls, buf): # type: ignore[no-untyped-def] + return cls(struct.unpack(">dddd", buf.read(32))) + + @dispatch + def __init__(self) -> None: ... + + @dispatch # type: ignore[no-redef] + def __init__(self, x: int | float, y: int | float, z: int | float, w: int | float) -> None: + self.x = float(x) + self.y = float(y) + self.z = float(z) + self.w = float(w) + + @dispatch # type: ignore[no-redef] + def __init__(self, sequence: Sequence[int | float] | np.ndarray) -> None: # type: ignore[type-arg] + if isinstance(sequence, np.ndarray): + if sequence.size != 4: + raise ValueError("Quaternion requires exactly 4 components [x, y, z, w]") + else: + if len(sequence) != 4: + raise ValueError("Quaternion requires exactly 4 components [x, y, z, w]") + + self.x = sequence[0] + self.y = sequence[1] + self.z = sequence[2] + self.w = sequence[3] + + @dispatch # type: ignore[no-redef] + def __init__(self, quaternion: Quaternion) -> None: + """Initialize from another Quaternion (copy constructor).""" + self.x, self.y, self.z, self.w = quaternion.x, quaternion.y, quaternion.z, quaternion.w + + @dispatch # type: ignore[no-redef] + def __init__(self, lcm_quaternion: LCMQuaternion) -> None: + """Initialize from an LCM Quaternion.""" + self.x, self.y, self.z, self.w = ( + lcm_quaternion.x, + lcm_quaternion.y, + lcm_quaternion.z, + lcm_quaternion.w, + ) + + def to_tuple(self) -> tuple[float, float, float, float]: + """Tuple representation of the quaternion (x, y, z, w).""" + return (self.x, self.y, self.z, self.w) + + def to_list(self) -> list[float]: + """List representation of the quaternion (x, y, z, w).""" + return [self.x, self.y, self.z, self.w] + + def to_numpy(self) -> np.ndarray: # type: ignore[type-arg] + """Numpy array representation of the quaternion (x, y, z, w).""" + return np.array([self.x, self.y, self.z, self.w]) + + @property + def euler(self) -> Vector3: + return self.to_euler() + + @property + def radians(self) -> Vector3: + return self.to_euler() + + def to_radians(self) -> Vector3: + """Radians representation of the quaternion (x, y, z, w).""" + return self.to_euler() + + @classmethod + def from_euler(cls, vector: Vector3) -> Quaternion: + """Convert Euler angles (roll, pitch, yaw) in radians to quaternion. + + Args: + vector: Vector3 containing (roll, pitch, yaw) in radians + + Returns: + Quaternion representation + """ + + # Calculate quaternion components + cy = np.cos(vector.yaw * 0.5) + sy = np.sin(vector.yaw * 0.5) + cp = np.cos(vector.pitch * 0.5) + sp = np.sin(vector.pitch * 0.5) + cr = np.cos(vector.roll * 0.5) + sr = np.sin(vector.roll * 0.5) + + w = cr * cp * cy + sr * sp * sy + x = sr * cp * cy - cr * sp * sy + y = cr * sp * cy + sr * cp * sy + z = cr * cp * sy - sr * sp * cy + + return cls(x, y, z, w) + + def to_euler(self) -> Vector3: + """Convert quaternion to Euler angles (roll, pitch, yaw) in radians. + + Returns: + Vector3: Euler angles as (roll, pitch, yaw) in radians + """ + # Use scipy for accurate quaternion to euler conversion + quat = [self.x, self.y, self.z, self.w] + rotation = R.from_quat(quat) + euler_angles = rotation.as_euler("xyz") # roll, pitch, yaw + + return Vector3(euler_angles[0], euler_angles[1], euler_angles[2]) + + def __getitem__(self, idx: int) -> float: + """Allow indexing into quaternion components: 0=x, 1=y, 2=z, 3=w.""" + if idx == 0: + return self.x + elif idx == 1: + return self.y + elif idx == 2: + return self.z + elif idx == 3: + return self.w + else: + raise IndexError(f"Quaternion index {idx} out of range [0-3]") + + def __repr__(self) -> str: + return f"Quaternion({self.x:.6f}, {self.y:.6f}, {self.z:.6f}, {self.w:.6f})" + + def __str__(self) -> str: + return self.__repr__() + + def __eq__(self, other) -> bool: # type: ignore[no-untyped-def] + if not isinstance(other, Quaternion): + return False + return self.x == other.x and self.y == other.y and self.z == other.z and self.w == other.w + + def __mul__(self, other: Quaternion) -> Quaternion: + """Multiply two quaternions (Hamilton product). + + The result represents the composition of rotations: + q1 * q2 represents rotating by q2 first, then by q1. + """ + if not isinstance(other, Quaternion): + raise TypeError(f"Cannot multiply Quaternion with {type(other)}") + + # Hamilton product formula + w = self.w * other.w - self.x * other.x - self.y * other.y - self.z * other.z + x = self.w * other.x + self.x * other.w + self.y * other.z - self.z * other.y + y = self.w * other.y - self.x * other.z + self.y * other.w + self.z * other.x + z = self.w * other.z + self.x * other.y - self.y * other.x + self.z * other.w + + return Quaternion(x, y, z, w) + + def conjugate(self) -> Quaternion: + """Return the conjugate of the quaternion. + + For unit quaternions, the conjugate represents the inverse rotation. + """ + return Quaternion(-self.x, -self.y, -self.z, self.w) + + def inverse(self) -> Quaternion: + """Return the inverse of the quaternion. + + For unit quaternions, this is equivalent to the conjugate. + For non-unit quaternions, this is conjugate / norm^2. + """ + norm_sq = self.x**2 + self.y**2 + self.z**2 + self.w**2 + if norm_sq == 0: + raise ZeroDivisionError("Cannot invert zero quaternion") + + # For unit quaternions (norm_sq ≈ 1), this simplifies to conjugate + if np.isclose(norm_sq, 1.0): + return self.conjugate() + + # For non-unit quaternions + conj = self.conjugate() + return Quaternion(conj.x / norm_sq, conj.y / norm_sq, conj.z / norm_sq, conj.w / norm_sq) + + def normalize(self) -> Quaternion: + """Return a normalized (unit) quaternion.""" + norm = np.sqrt(self.x**2 + self.y**2 + self.z**2 + self.w**2) + if norm == 0: + raise ZeroDivisionError("Cannot normalize zero quaternion") + return Quaternion(self.x / norm, self.y / norm, self.z / norm, self.w / norm) + + def rotate_vector(self, vector: Vector3) -> Vector3: + """Rotate a 3D vector by this quaternion. + + Args: + vector: The vector to rotate + + Returns: + The rotated vector + """ + # For unit quaternions, conjugate equals inverse, so we use conjugate for efficiency + # The rotation formula is: q * v * q^* where q^* is the conjugate + + # Convert vector to pure quaternion (w=0) + v_quat = Quaternion(vector.x, vector.y, vector.z, 0) + + # Apply rotation: q * v * q^* (conjugate for unit quaternions) + rotated = self * v_quat * self.conjugate() + + # Extract vector components + return Vector3(rotated.x, rotated.y, rotated.z) diff --git a/dimos/msgs/geometry_msgs/Transform.py b/dimos/msgs/geometry_msgs/Transform.py new file mode 100644 index 0000000000..3cb964dd3b --- /dev/null +++ b/dimos/msgs/geometry_msgs/Transform.py @@ -0,0 +1,361 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import time +from typing import BinaryIO + +from dimos_lcm.geometry_msgs import ( # type: ignore[import-untyped] + Transform as LCMTransform, + TransformStamped as LCMTransformStamped, +) + +try: + from geometry_msgs.msg import ( # type: ignore[attr-defined] + Quaternion as ROSQuaternion, + Transform as ROSTransform, + TransformStamped as ROSTransformStamped, + Vector3 as ROSVector3, + ) +except ImportError: + ROSTransformStamped = None # type: ignore[assignment, misc] + ROSTransform = None # type: ignore[assignment, misc] + ROSVector3 = None # type: ignore[assignment, misc] + ROSQuaternion = None # type: ignore[assignment, misc] + +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.std_msgs import Header +from dimos.types.timestamped import Timestamped + + +class Transform(Timestamped): + translation: Vector3 + rotation: Quaternion + ts: float + frame_id: str + child_frame_id: str + msg_name = "tf2_msgs.TFMessage" + + def __init__( # type: ignore[no-untyped-def] + self, + translation: Vector3 | None = None, + rotation: Quaternion | None = None, + frame_id: str = "world", + child_frame_id: str = "unset", + ts: float = 0.0, + **kwargs, + ) -> None: + self.frame_id = frame_id + self.child_frame_id = child_frame_id + self.ts = ts if ts != 0.0 else time.time() + self.translation = translation if translation is not None else Vector3() + self.rotation = rotation if rotation is not None else Quaternion() + + def now(self) -> Transform: + """Return a copy of this Transform with the current timestamp.""" + return Transform( + translation=self.translation, + rotation=self.rotation, + frame_id=self.frame_id, + child_frame_id=self.child_frame_id, + ts=time.time(), + ) + + def __repr__(self) -> str: + return f"Transform(translation={self.translation!r}, rotation={self.rotation!r})" + + def __str__(self) -> str: + return f"Transform:\n {self.frame_id} -> {self.child_frame_id} Translation: {self.translation}\n Rotation: {self.rotation}" + + def __eq__(self, other) -> bool: # type: ignore[no-untyped-def] + """Check if two transforms are equal.""" + if not isinstance(other, Transform): + return False + return self.translation == other.translation and self.rotation == other.rotation + + @classmethod + def identity(cls) -> Transform: + """Create an identity transform.""" + return cls() + + def lcm_transform(self) -> LCMTransformStamped: + return LCMTransformStamped( + child_frame_id=self.child_frame_id, + header=Header(self.ts, self.frame_id), + transform=LCMTransform( + translation=self.translation, + rotation=self.rotation, + ), + ) + + def apply(self, other: Transform) -> Transform: + return self.__add__(other) + + def __add__(self, other: Transform) -> Transform: + """Compose two transforms (transform composition). + + The operation self + other represents applying transformation 'other' + in the coordinate frame defined by 'self'. This is equivalent to: + - First apply transformation 'self' (from frame A to frame B) + - Then apply transformation 'other' (from frame B to frame C) + + Args: + other: The transform to compose with this one + + Returns: + A new Transform representing the composed transformation + + Example: + t1 = Transform(Vector3(1, 0, 0), Quaternion(0, 0, 0, 1)) + t2 = Transform(Vector3(2, 0, 0), Quaternion(0, 0, 0, 1)) + t3 = t1 + t2 # Combined transform: translation (3, 0, 0) + """ + if not isinstance(other, Transform): + raise TypeError(f"Cannot add Transform and {type(other).__name__}") + + # Compose orientations: self.rotation * other.rotation + new_rotation = self.rotation * other.rotation + + # Transform other's translation by self's rotation, then add to self's translation + rotated_translation = self.rotation.rotate_vector(other.translation) + new_translation = self.translation + rotated_translation + + return Transform( + translation=new_translation, + rotation=new_rotation, + frame_id=self.frame_id, + child_frame_id=other.child_frame_id, + ts=self.ts, + ) + + def inverse(self) -> Transform: + """Compute the inverse transform. + + The inverse transform reverses the direction of the transformation. + If this transform goes from frame A to frame B, the inverse goes from B to A. + + Returns: + A new Transform representing the inverse transformation + """ + # Inverse rotation + inv_rotation = self.rotation.inverse() + + # Inverse translation: -R^(-1) * t + inv_translation = inv_rotation.rotate_vector(self.translation) + inv_translation = Vector3(-inv_translation.x, -inv_translation.y, -inv_translation.z) + + return Transform( + translation=inv_translation, + rotation=inv_rotation, + frame_id=self.child_frame_id, # Swap frame references + child_frame_id=self.frame_id, + ts=self.ts, + ) + + @classmethod + def from_ros_transform_stamped(cls, ros_msg: ROSTransformStamped) -> Transform: + """Create a Transform from a ROS geometry_msgs/TransformStamped message. + + Args: + ros_msg: ROS TransformStamped message + + Returns: + Transform instance + """ + + # Convert timestamp + ts = ros_msg.header.stamp.sec + (ros_msg.header.stamp.nanosec / 1_000_000_000) + + # Convert translation + translation = Vector3( + ros_msg.transform.translation.x, + ros_msg.transform.translation.y, + ros_msg.transform.translation.z, + ) + + # Convert rotation + rotation = Quaternion( + ros_msg.transform.rotation.x, + ros_msg.transform.rotation.y, + ros_msg.transform.rotation.z, + ros_msg.transform.rotation.w, + ) + + return cls( + translation=translation, + rotation=rotation, + frame_id=ros_msg.header.frame_id, + child_frame_id=ros_msg.child_frame_id, + ts=ts, + ) + + def to_ros_transform_stamped(self) -> ROSTransformStamped: + """Convert to a ROS geometry_msgs/TransformStamped message. + + Returns: + ROS TransformStamped message + """ + + ros_msg = ROSTransformStamped() # type: ignore[no-untyped-call] + + # Set header + ros_msg.header.frame_id = self.frame_id + ros_msg.header.stamp.sec = int(self.ts) + ros_msg.header.stamp.nanosec = int((self.ts - int(self.ts)) * 1_000_000_000) + + # Set child frame + ros_msg.child_frame_id = self.child_frame_id + + # Set transform + ros_msg.transform.translation = ROSVector3( # type: ignore[no-untyped-call] + x=self.translation.x, y=self.translation.y, z=self.translation.z + ) + ros_msg.transform.rotation = ROSQuaternion( # type: ignore[no-untyped-call] + x=self.rotation.x, y=self.rotation.y, z=self.rotation.z, w=self.rotation.w + ) + + return ros_msg + + def __neg__(self) -> Transform: + """Unary minus operator returns the inverse transform.""" + return self.inverse() + + @classmethod + def from_pose(cls, frame_id: str, pose: Pose | PoseStamped) -> Transform: # type: ignore[name-defined] + """Create a Transform from a Pose or PoseStamped. + + Args: + pose: A Pose or PoseStamped object to convert + + Returns: + A Transform with the same translation and rotation as the pose + """ + # Import locally to avoid circular imports + from dimos.msgs.geometry_msgs.Pose import Pose + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped + + # Handle both Pose and PoseStamped + if isinstance(pose, PoseStamped): + return cls( + translation=pose.position, + rotation=pose.orientation, + frame_id=pose.frame_id, + child_frame_id=frame_id, + ts=pose.ts, + ) + elif isinstance(pose, Pose): + return cls( + translation=pose.position, + rotation=pose.orientation, + child_frame_id=frame_id, + ) + else: + raise TypeError(f"Expected Pose or PoseStamped, got {type(pose).__name__}") + + def to_pose(self, **kwargs) -> PoseStamped: # type: ignore[name-defined, no-untyped-def] + """Create a Transform from a Pose or PoseStamped. + + Args: + pose: A Pose or PoseStamped object to convert + + Returns: + A Transform with the same translation and rotation as the pose + """ + # Import locally to avoid circular imports + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped + + # Handle both Pose and PoseStamped + return PoseStamped( + **{ + "position": self.translation, + "orientation": self.rotation, + "frame_id": self.frame_id, + }, + **kwargs, + ) + + def to_matrix(self) -> np.ndarray: # type: ignore[name-defined] + """Convert Transform to a 4x4 transformation matrix. + + Returns a homogeneous transformation matrix that represents both + the rotation and translation of this transform. + + Returns: + np.ndarray: A 4x4 homogeneous transformation matrix + """ + import numpy as np + + # Extract quaternion components + x, y, z, w = self.rotation.x, self.rotation.y, self.rotation.z, self.rotation.w + + # Build rotation matrix from quaternion using standard formula + # This avoids numerical issues compared to converting to axis-angle first + rotation_matrix = np.array( + [ + [1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)], + [2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)], + [2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)], + ] + ) + + # Build 4x4 homogeneous transformation matrix + matrix = np.eye(4) + matrix[:3, :3] = rotation_matrix + matrix[:3, 3] = [self.translation.x, self.translation.y, self.translation.z] + + return matrix + + def lcm_encode(self) -> bytes: + # we get a circular import otherwise + from dimos.msgs.tf2_msgs.TFMessage import TFMessage + + return TFMessage(self).lcm_encode() + + @classmethod + def lcm_decode(cls, data: bytes | BinaryIO) -> Transform: + """Decode from LCM TFMessage bytes.""" + from dimos_lcm.tf2_msgs import TFMessage as LCMTFMessage # type: ignore[import-untyped] + + lcm_msg = LCMTFMessage.lcm_decode(data) + + if not lcm_msg.transforms: + raise ValueError("No transforms found in LCM message") + + # Get the first transform from the message + lcm_transform_stamped = lcm_msg.transforms[0] + + # Extract timestamp from header + ts = lcm_transform_stamped.header.stamp.sec + ( + lcm_transform_stamped.header.stamp.nsec / 1_000_000_000 + ) + + # Create and return Transform instance + return cls( + translation=Vector3( + lcm_transform_stamped.transform.translation.x, + lcm_transform_stamped.transform.translation.y, + lcm_transform_stamped.transform.translation.z, + ), + rotation=Quaternion( + lcm_transform_stamped.transform.rotation.x, + lcm_transform_stamped.transform.rotation.y, + lcm_transform_stamped.transform.rotation.z, + lcm_transform_stamped.transform.rotation.w, + ), + frame_id=lcm_transform_stamped.header.frame_id, + child_frame_id=lcm_transform_stamped.child_frame_id, + ts=ts, + ) diff --git a/dimos/msgs/geometry_msgs/Twist.py b/dimos/msgs/geometry_msgs/Twist.py new file mode 100644 index 0000000000..0ff7498a11 --- /dev/null +++ b/dimos/msgs/geometry_msgs/Twist.py @@ -0,0 +1,139 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dimos_lcm.geometry_msgs import Twist as LCMTwist # type: ignore[import-untyped] +from plum import dispatch + +try: + from geometry_msgs.msg import ( # type: ignore[attr-defined] + Twist as ROSTwist, + Vector3 as ROSVector3, + ) +except ImportError: + ROSTwist = None # type: ignore[assignment, misc] + ROSVector3 = None # type: ignore[assignment, misc] + +# Import Quaternion at runtime for beartype compatibility +# (beartype needs to resolve forward references at runtime) +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3, VectorLike + + +class Twist(LCMTwist): # type: ignore[misc] + linear: Vector3 + angular: Vector3 + msg_name = "geometry_msgs.Twist" + + @dispatch + def __init__(self) -> None: + """Initialize a zero twist (no linear or angular velocity).""" + self.linear = Vector3() + self.angular = Vector3() + + @dispatch # type: ignore[no-redef] + def __init__(self, linear: VectorLike, angular: VectorLike) -> None: + """Initialize a twist from linear and angular velocities.""" + + self.linear = Vector3(linear) + self.angular = Vector3(angular) + + @dispatch # type: ignore[no-redef] + def __init__(self, linear: VectorLike, angular: Quaternion) -> None: + """Initialize a twist from linear velocity and angular as quaternion (converted to euler).""" + self.linear = Vector3(linear) + self.angular = angular.to_euler() + + @dispatch # type: ignore[no-redef] + def __init__(self, twist: Twist) -> None: + """Initialize from another Twist (copy constructor).""" + self.linear = Vector3(twist.linear) + self.angular = Vector3(twist.angular) + + @dispatch # type: ignore[no-redef] + def __init__(self, lcm_twist: LCMTwist) -> None: + """Initialize from an LCM Twist.""" + self.linear = Vector3(lcm_twist.linear) + self.angular = Vector3(lcm_twist.angular) + + @dispatch # type: ignore[no-redef] + def __init__(self, **kwargs) -> None: + """Handle keyword arguments for LCM compatibility.""" + linear = kwargs.get("linear", Vector3()) + angular = kwargs.get("angular", Vector3()) + + self.__init__(linear, angular) + + def __repr__(self) -> str: + return f"Twist(linear={self.linear!r}, angular={self.angular!r})" + + def __str__(self) -> str: + return f"Twist:\n Linear: {self.linear}\n Angular: {self.angular}" + + def __eq__(self, other) -> bool: # type: ignore[no-untyped-def] + """Check if two twists are equal.""" + if not isinstance(other, Twist): + return False + return self.linear == other.linear and self.angular == other.angular + + @classmethod + def zero(cls) -> Twist: + """Create a zero twist (no motion).""" + return cls() + + def is_zero(self) -> bool: + """Check if this is a zero twist (no linear or angular velocity).""" + return self.linear.is_zero() and self.angular.is_zero() + + def __bool__(self) -> bool: + """Boolean conversion for Twist. + + A Twist is considered False if it's a zero twist (no motion), + and True otherwise. + + Returns: + False if twist is zero, True otherwise + """ + return not self.is_zero() + + @classmethod + def from_ros_msg(cls, ros_msg: ROSTwist) -> Twist: + """Create a Twist from a ROS geometry_msgs/Twist message. + + Args: + ros_msg: ROS Twist message + + Returns: + Twist instance + """ + + linear = Vector3(ros_msg.linear.x, ros_msg.linear.y, ros_msg.linear.z) + angular = Vector3(ros_msg.angular.x, ros_msg.angular.y, ros_msg.angular.z) + return cls(linear, angular) + + def to_ros_msg(self) -> ROSTwist: + """Convert to a ROS geometry_msgs/Twist message. + + Returns: + ROS Twist message + """ + + ros_msg = ROSTwist() # type: ignore[no-untyped-call] + ros_msg.linear = ROSVector3(x=self.linear.x, y=self.linear.y, z=self.linear.z) # type: ignore[no-untyped-call] + ros_msg.angular = ROSVector3(x=self.angular.x, y=self.angular.y, z=self.angular.z) # type: ignore[no-untyped-call] + return ros_msg + + +__all__ = ["Quaternion", "Twist"] diff --git a/dimos/msgs/geometry_msgs/TwistStamped.py b/dimos/msgs/geometry_msgs/TwistStamped.py new file mode 100644 index 0000000000..3f28db8e98 --- /dev/null +++ b/dimos/msgs/geometry_msgs/TwistStamped.py @@ -0,0 +1,118 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import time +from typing import BinaryIO, TypeAlias + +from dimos_lcm.geometry_msgs import TwistStamped as LCMTwistStamped # type: ignore[import-untyped] +from plum import dispatch + +try: + from geometry_msgs.msg import TwistStamped as ROSTwistStamped # type: ignore[attr-defined] +except ImportError: + ROSTwistStamped = None # type: ignore[assignment, misc] + +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import VectorConvertable +from dimos.types.timestamped import Timestamped + +# Types that can be converted to/from TwistStamped +TwistConvertable: TypeAlias = ( + tuple[VectorConvertable, VectorConvertable] | LCMTwistStamped | dict[str, VectorConvertable] +) + + +def sec_nsec(ts): # type: ignore[no-untyped-def] + s = int(ts) + return [s, int((ts - s) * 1_000_000_000)] + + +class TwistStamped(Twist, Timestamped): + msg_name = "geometry_msgs.TwistStamped" + ts: float + frame_id: str + + @dispatch + def __init__(self, ts: float = 0.0, frame_id: str = "", **kwargs) -> None: # type: ignore[no-untyped-def] + self.frame_id = frame_id + self.ts = ts if ts != 0 else time.time() + super().__init__(**kwargs) + + def lcm_encode(self) -> bytes: + lcm_msg = LCMTwistStamped() + lcm_msg.twist = self + [lcm_msg.header.stamp.sec, lcm_msg.header.stamp.nsec] = sec_nsec(self.ts) # type: ignore[no-untyped-call] + lcm_msg.header.frame_id = self.frame_id + return lcm_msg.lcm_encode() # type: ignore[no-any-return] + + @classmethod + def lcm_decode(cls, data: bytes | BinaryIO) -> TwistStamped: + lcm_msg = LCMTwistStamped.lcm_decode(data) + return cls( + ts=lcm_msg.header.stamp.sec + (lcm_msg.header.stamp.nsec / 1_000_000_000), + frame_id=lcm_msg.header.frame_id, + linear=[lcm_msg.twist.linear.x, lcm_msg.twist.linear.y, lcm_msg.twist.linear.z], + angular=[lcm_msg.twist.angular.x, lcm_msg.twist.angular.y, lcm_msg.twist.angular.z], + ) + + def __str__(self) -> str: + return ( + f"TwistStamped(linear=[{self.linear.x:.3f}, {self.linear.y:.3f}, {self.linear.z:.3f}], " + f"angular=[{self.angular.x:.3f}, {self.angular.y:.3f}, {self.angular.z:.3f}])" + ) + + @classmethod + def from_ros_msg(cls, ros_msg: ROSTwistStamped) -> TwistStamped: # type: ignore[override] + """Create a TwistStamped from a ROS geometry_msgs/TwistStamped message. + + Args: + ros_msg: ROS TwistStamped message + + Returns: + TwistStamped instance + """ + + # Convert timestamp from ROS header + ts = ros_msg.header.stamp.sec + (ros_msg.header.stamp.nanosec / 1_000_000_000) + + # Convert twist + twist = Twist.from_ros_msg(ros_msg.twist) + + return cls( + ts=ts, + frame_id=ros_msg.header.frame_id, + linear=twist.linear, + angular=twist.angular, + ) + + def to_ros_msg(self) -> ROSTwistStamped: # type: ignore[override] + """Convert to a ROS geometry_msgs/TwistStamped message. + + Returns: + ROS TwistStamped message + """ + + ros_msg = ROSTwistStamped() # type: ignore[no-untyped-call] + + # Set header + ros_msg.header.frame_id = self.frame_id + ros_msg.header.stamp.sec = int(self.ts) + ros_msg.header.stamp.nanosec = int((self.ts - int(self.ts)) * 1_000_000_000) + + # Set twist + ros_msg.twist = Twist.to_ros_msg(self) + + return ros_msg diff --git a/dimos/msgs/geometry_msgs/TwistWithCovariance.py b/dimos/msgs/geometry_msgs/TwistWithCovariance.py new file mode 100644 index 0000000000..1b67b8c7c4 --- /dev/null +++ b/dimos/msgs/geometry_msgs/TwistWithCovariance.py @@ -0,0 +1,229 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TypeAlias + +from dimos_lcm.geometry_msgs import ( # type: ignore[import-untyped] + TwistWithCovariance as LCMTwistWithCovariance, +) +import numpy as np +from plum import dispatch + +try: + from geometry_msgs.msg import ( # type: ignore[attr-defined] + TwistWithCovariance as ROSTwistWithCovariance, + ) +except ImportError: + ROSTwistWithCovariance = None # type: ignore[assignment, misc] + +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3, VectorConvertable + +# Types that can be converted to/from TwistWithCovariance +TwistWithCovarianceConvertable: TypeAlias = ( + tuple[Twist | tuple[VectorConvertable, VectorConvertable], list[float] | np.ndarray] # type: ignore[type-arg] + | LCMTwistWithCovariance + | dict[str, Twist | tuple[VectorConvertable, VectorConvertable] | list[float] | np.ndarray] # type: ignore[type-arg] +) + + +class TwistWithCovariance(LCMTwistWithCovariance): # type: ignore[misc] + twist: Twist + msg_name = "geometry_msgs.TwistWithCovariance" + + @dispatch + def __init__(self) -> None: + """Initialize with default twist and zero covariance.""" + self.twist = Twist() + self.covariance = np.zeros(36) + + @dispatch # type: ignore[no-redef] + def __init__( + self, + twist: Twist | tuple[VectorConvertable, VectorConvertable], + covariance: list[float] | np.ndarray | None = None, # type: ignore[type-arg] + ) -> None: + """Initialize with twist and optional covariance.""" + if isinstance(twist, Twist): + self.twist = twist + else: + # Assume it's a tuple of (linear, angular) + self.twist = Twist(twist[0], twist[1]) + + if covariance is None: + self.covariance = np.zeros(36) + else: + self.covariance = np.array(covariance, dtype=float).reshape(36) + + @dispatch # type: ignore[no-redef] + def __init__(self, twist_with_cov: TwistWithCovariance) -> None: + """Initialize from another TwistWithCovariance (copy constructor).""" + self.twist = Twist(twist_with_cov.twist) + self.covariance = np.array(twist_with_cov.covariance).copy() + + @dispatch # type: ignore[no-redef] + def __init__(self, lcm_twist_with_cov: LCMTwistWithCovariance) -> None: + """Initialize from an LCM TwistWithCovariance.""" + self.twist = Twist(lcm_twist_with_cov.twist) + self.covariance = np.array(lcm_twist_with_cov.covariance) + + @dispatch # type: ignore[no-redef] + def __init__( + self, + twist_dict: dict[ # type: ignore[type-arg] + str, Twist | tuple[VectorConvertable, VectorConvertable] | list[float] | np.ndarray + ], + ) -> None: + """Initialize from a dictionary with 'twist' and 'covariance' keys.""" + twist = twist_dict["twist"] + if isinstance(twist, Twist): + self.twist = twist + else: + # Assume it's a tuple of (linear, angular) + self.twist = Twist(twist[0], twist[1]) + + covariance = twist_dict.get("covariance") + if covariance is None: + self.covariance = np.zeros(36) + else: + self.covariance = np.array(covariance, dtype=float).reshape(36) + + @dispatch # type: ignore[no-redef] + def __init__( + self, + twist_tuple: tuple[ # type: ignore[type-arg] + Twist | tuple[VectorConvertable, VectorConvertable], list[float] | np.ndarray + ], + ) -> None: + """Initialize from a tuple of (twist, covariance).""" + twist = twist_tuple[0] + if isinstance(twist, Twist): + self.twist = twist + else: + # Assume it's a tuple of (linear, angular) + self.twist = Twist(twist[0], twist[1]) + self.covariance = np.array(twist_tuple[1], dtype=float).reshape(36) + + def __getattribute__(self, name: str): # type: ignore[no-untyped-def] + """Override to ensure covariance is always returned as numpy array.""" + if name == "covariance": + cov = object.__getattribute__(self, "covariance") + if not isinstance(cov, np.ndarray): + return np.array(cov, dtype=float) + return cov + return super().__getattribute__(name) + + def __setattr__(self, name: str, value) -> None: # type: ignore[no-untyped-def] + """Override to ensure covariance is stored as numpy array.""" + if name == "covariance": + if not isinstance(value, np.ndarray): + value = np.array(value, dtype=float).reshape(36) + super().__setattr__(name, value) + + @property + def linear(self) -> Vector3: + """Linear velocity vector.""" + return self.twist.linear + + @property + def angular(self) -> Vector3: + """Angular velocity vector.""" + return self.twist.angular + + @property + def covariance_matrix(self) -> np.ndarray: # type: ignore[type-arg] + """Get covariance as 6x6 matrix.""" + return self.covariance.reshape(6, 6) # type: ignore[has-type, no-any-return] + + @covariance_matrix.setter + def covariance_matrix(self, value: np.ndarray) -> None: # type: ignore[type-arg] + """Set covariance from 6x6 matrix.""" + self.covariance = np.array(value).reshape(36) # type: ignore[has-type] + + def __repr__(self) -> str: + return f"TwistWithCovariance(twist={self.twist!r}, covariance=<{self.covariance.shape[0] if isinstance(self.covariance, np.ndarray) else len(self.covariance)} elements>)" # type: ignore[has-type] + + def __str__(self) -> str: + return ( + f"TwistWithCovariance(linear=[{self.linear.x:.3f}, {self.linear.y:.3f}, {self.linear.z:.3f}], " + f"angular=[{self.angular.x:.3f}, {self.angular.y:.3f}, {self.angular.z:.3f}], " + f"cov_trace={np.trace(self.covariance_matrix):.3f})" + ) + + def __eq__(self, other) -> bool: # type: ignore[no-untyped-def] + """Check if two TwistWithCovariance are equal.""" + if not isinstance(other, TwistWithCovariance): + return False + return self.twist == other.twist and np.allclose(self.covariance, other.covariance) # type: ignore[has-type] + + def is_zero(self) -> bool: + """Check if this is a zero twist (no linear or angular velocity).""" + return self.twist.is_zero() + + def __bool__(self) -> bool: + """Boolean conversion - False if zero twist, True otherwise.""" + return not self.is_zero() + + def lcm_encode(self) -> bytes: + """Encode to LCM binary format.""" + lcm_msg = LCMTwistWithCovariance() + lcm_msg.twist = self.twist + # LCM expects list, not numpy array + if isinstance(self.covariance, np.ndarray): # type: ignore[has-type] + lcm_msg.covariance = self.covariance.tolist() # type: ignore[has-type] + else: + lcm_msg.covariance = list(self.covariance) # type: ignore[has-type] + return lcm_msg.lcm_encode() # type: ignore[no-any-return] + + @classmethod + def lcm_decode(cls, data: bytes) -> TwistWithCovariance: + """Decode from LCM binary format.""" + lcm_msg = LCMTwistWithCovariance.lcm_decode(data) + twist = Twist( + linear=[lcm_msg.twist.linear.x, lcm_msg.twist.linear.y, lcm_msg.twist.linear.z], + angular=[lcm_msg.twist.angular.x, lcm_msg.twist.angular.y, lcm_msg.twist.angular.z], + ) + return cls(twist, lcm_msg.covariance) + + @classmethod + def from_ros_msg(cls, ros_msg: ROSTwistWithCovariance) -> TwistWithCovariance: + """Create a TwistWithCovariance from a ROS geometry_msgs/TwistWithCovariance message. + + Args: + ros_msg: ROS TwistWithCovariance message + + Returns: + TwistWithCovariance instance + """ + + twist = Twist.from_ros_msg(ros_msg.twist) + return cls(twist, list(ros_msg.covariance)) + + def to_ros_msg(self) -> ROSTwistWithCovariance: + """Convert to a ROS geometry_msgs/TwistWithCovariance message. + + Returns: + ROS TwistWithCovariance message + """ + + ros_msg = ROSTwistWithCovariance() # type: ignore[no-untyped-call] + ros_msg.twist = self.twist.to_ros_msg() + # ROS expects list, not numpy array + if isinstance(self.covariance, np.ndarray): # type: ignore[has-type] + ros_msg.covariance = self.covariance.tolist() # type: ignore[has-type] + else: + ros_msg.covariance = list(self.covariance) # type: ignore[has-type] + return ros_msg diff --git a/dimos/msgs/geometry_msgs/TwistWithCovarianceStamped.py b/dimos/msgs/geometry_msgs/TwistWithCovarianceStamped.py new file mode 100644 index 0000000000..a00349798e --- /dev/null +++ b/dimos/msgs/geometry_msgs/TwistWithCovarianceStamped.py @@ -0,0 +1,173 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import time +from typing import TypeAlias + +from dimos_lcm.geometry_msgs import ( # type: ignore[import-untyped] + TwistWithCovarianceStamped as LCMTwistWithCovarianceStamped, +) +import numpy as np +from plum import dispatch + +try: + from geometry_msgs.msg import ( # type: ignore[attr-defined] + TwistWithCovarianceStamped as ROSTwistWithCovarianceStamped, + ) +except ImportError: + ROSTwistWithCovarianceStamped = None # type: ignore[assignment, misc] + +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.TwistWithCovariance import TwistWithCovariance +from dimos.msgs.geometry_msgs.Vector3 import VectorConvertable +from dimos.types.timestamped import Timestamped + +# Types that can be converted to/from TwistWithCovarianceStamped +TwistWithCovarianceStampedConvertable: TypeAlias = ( + tuple[Twist | tuple[VectorConvertable, VectorConvertable], list[float] | np.ndarray] # type: ignore[type-arg] + | LCMTwistWithCovarianceStamped + | dict[ + str, + Twist + | tuple[VectorConvertable, VectorConvertable] + | list[float] + | np.ndarray # type: ignore[type-arg] + | float + | str, + ] +) + + +def sec_nsec(ts): # type: ignore[no-untyped-def] + s = int(ts) + return [s, int((ts - s) * 1_000_000_000)] + + +class TwistWithCovarianceStamped(TwistWithCovariance, Timestamped): + msg_name = "geometry_msgs.TwistWithCovarianceStamped" + ts: float + frame_id: str + + @dispatch + def __init__(self, ts: float = 0.0, frame_id: str = "", **kwargs) -> None: + """Initialize with timestamp and frame_id.""" + self.frame_id = frame_id + self.ts = ts if ts != 0 else time.time() + super().__init__(**kwargs) + + @dispatch # type: ignore[no-redef] + def __init__( + self, + ts: float = 0.0, + frame_id: str = "", + twist: Twist | tuple[VectorConvertable, VectorConvertable] | None = None, + covariance: list[float] | np.ndarray | None = None, # type: ignore[type-arg] + ) -> None: + """Initialize with timestamp, frame_id, twist and covariance.""" + self.frame_id = frame_id + self.ts = ts if ts != 0 else time.time() + if twist is None: + super().__init__() + else: + super().__init__(twist, covariance) + + def lcm_encode(self) -> bytes: + lcm_msg = LCMTwistWithCovarianceStamped() + lcm_msg.twist.twist = self.twist + # LCM expects list, not numpy array + if isinstance(self.covariance, np.ndarray): # type: ignore[has-type] + lcm_msg.twist.covariance = self.covariance.tolist() # type: ignore[has-type] + else: + lcm_msg.twist.covariance = list(self.covariance) # type: ignore[has-type] + [lcm_msg.header.stamp.sec, lcm_msg.header.stamp.nsec] = sec_nsec(self.ts) # type: ignore[no-untyped-call] + lcm_msg.header.frame_id = self.frame_id + return lcm_msg.lcm_encode() # type: ignore[no-any-return] + + @classmethod + def lcm_decode(cls, data: bytes) -> TwistWithCovarianceStamped: + lcm_msg = LCMTwistWithCovarianceStamped.lcm_decode(data) + return cls( + ts=lcm_msg.header.stamp.sec + (lcm_msg.header.stamp.nsec / 1_000_000_000), + frame_id=lcm_msg.header.frame_id, + twist=Twist( + linear=[ + lcm_msg.twist.twist.linear.x, + lcm_msg.twist.twist.linear.y, + lcm_msg.twist.twist.linear.z, + ], + angular=[ + lcm_msg.twist.twist.angular.x, + lcm_msg.twist.twist.angular.y, + lcm_msg.twist.twist.angular.z, + ], + ), + covariance=lcm_msg.twist.covariance, + ) + + def __str__(self) -> str: + return ( + f"TwistWithCovarianceStamped(linear=[{self.linear.x:.3f}, {self.linear.y:.3f}, {self.linear.z:.3f}], " + f"angular=[{self.angular.x:.3f}, {self.angular.y:.3f}, {self.angular.z:.3f}], " + f"cov_trace={np.trace(self.covariance_matrix):.3f})" + ) + + @classmethod + def from_ros_msg(cls, ros_msg: ROSTwistWithCovarianceStamped) -> TwistWithCovarianceStamped: # type: ignore[override] + """Create a TwistWithCovarianceStamped from a ROS geometry_msgs/TwistWithCovarianceStamped message. + + Args: + ros_msg: ROS TwistWithCovarianceStamped message + + Returns: + TwistWithCovarianceStamped instance + """ + + # Convert timestamp from ROS header + ts = ros_msg.header.stamp.sec + (ros_msg.header.stamp.nanosec / 1_000_000_000) + + # Convert twist with covariance + twist_with_cov = TwistWithCovariance.from_ros_msg(ros_msg.twist) + + return cls( + ts=ts, + frame_id=ros_msg.header.frame_id, + twist=twist_with_cov.twist, + covariance=twist_with_cov.covariance, # type: ignore[has-type] + ) + + def to_ros_msg(self) -> ROSTwistWithCovarianceStamped: # type: ignore[override] + """Convert to a ROS geometry_msgs/TwistWithCovarianceStamped message. + + Returns: + ROS TwistWithCovarianceStamped message + """ + + ros_msg = ROSTwistWithCovarianceStamped() # type: ignore[no-untyped-call] + + # Set header + ros_msg.header.frame_id = self.frame_id + ros_msg.header.stamp.sec = int(self.ts) + ros_msg.header.stamp.nanosec = int((self.ts - int(self.ts)) * 1_000_000_000) + + # Set twist with covariance + ros_msg.twist.twist = self.twist.to_ros_msg() + # ROS expects list, not numpy array + if isinstance(self.covariance, np.ndarray): # type: ignore[has-type] + ros_msg.twist.covariance = self.covariance.tolist() # type: ignore[has-type] + else: + ros_msg.twist.covariance = list(self.covariance) # type: ignore[has-type] + + return ros_msg diff --git a/dimos/msgs/geometry_msgs/Vector3.py b/dimos/msgs/geometry_msgs/Vector3.py new file mode 100644 index 0000000000..129c0c9a38 --- /dev/null +++ b/dimos/msgs/geometry_msgs/Vector3.py @@ -0,0 +1,456 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence +from typing import TypeAlias + +from dimos_lcm.geometry_msgs import Vector3 as LCMVector3 # type: ignore[import-untyped] +import numpy as np +from plum import dispatch + +# Types that can be converted to/from Vector +VectorConvertable: TypeAlias = Sequence[int | float] | LCMVector3 | np.ndarray # type: ignore[type-arg] + + +def _ensure_3d(data: np.ndarray) -> np.ndarray: # type: ignore[type-arg] + """Ensure the data array is exactly 3D by padding with zeros or raising an exception if too long.""" + if len(data) == 3: + return data + elif len(data) < 3: + padded = np.zeros(3, dtype=float) + padded[: len(data)] = data + return padded + else: + raise ValueError( + f"Vector3 cannot be initialized with more than 3 components. Got {len(data)} components." + ) + + +class Vector3(LCMVector3): # type: ignore[misc] + x: float = 0.0 + y: float = 0.0 + z: float = 0.0 + msg_name = "geometry_msgs.Vector3" + + @dispatch + def __init__(self) -> None: + """Initialize a zero 3D vector.""" + self.x = 0.0 + self.y = 0.0 + self.z = 0.0 + + @dispatch # type: ignore[no-redef] + def __init__(self, x: int | float) -> None: + """Initialize a 3D vector from a single numeric value (x, 0, 0).""" + self.x = float(x) + self.y = 0.0 + self.z = 0.0 + + @dispatch # type: ignore[no-redef] + def __init__(self, x: int | float, y: int | float) -> None: + """Initialize a 3D vector from x, y components (z=0).""" + self.x = float(x) + self.y = float(y) + self.z = 0.0 + + @dispatch # type: ignore[no-redef] + def __init__(self, x: int | float, y: int | float, z: int | float) -> None: + """Initialize a 3D vector from x, y, z components.""" + self.x = float(x) + self.y = float(y) + self.z = float(z) + + @dispatch # type: ignore[no-redef] + def __init__(self, sequence: Sequence[int | float]) -> None: + """Initialize from a sequence (list, tuple) of numbers, ensuring 3D.""" + data = _ensure_3d(np.array(sequence, dtype=float)) + self.x = float(data[0]) + self.y = float(data[1]) + self.z = float(data[2]) + + @dispatch # type: ignore[no-redef] + def __init__(self, array: np.ndarray) -> None: # type: ignore[type-arg] + """Initialize from a numpy array, ensuring 3D.""" + data = _ensure_3d(np.array(array, dtype=float)) + self.x = float(data[0]) + self.y = float(data[1]) + self.z = float(data[2]) + + @dispatch # type: ignore[no-redef] + def __init__(self, vector: Vector3) -> None: + """Initialize from another Vector3 (copy constructor).""" + self.x = vector.x + self.y = vector.y + self.z = vector.z + + @dispatch # type: ignore[no-redef] + def __init__(self, lcm_vector: LCMVector3) -> None: + """Initialize from an LCM Vector3.""" + self.x = float(lcm_vector.x) + self.y = float(lcm_vector.y) + self.z = float(lcm_vector.z) + + @property + def as_tuple(self) -> tuple[float, float, float]: + return (self.x, self.y, self.z) + + @property + def yaw(self) -> float: + return self.z + + @property + def pitch(self) -> float: + return self.y + + @property + def roll(self) -> float: + return self.x + + @property + def data(self) -> np.ndarray: # type: ignore[type-arg] + """Get the underlying numpy array.""" + return np.array([self.x, self.y, self.z], dtype=float) + + def __getitem__(self, idx: int): # type: ignore[no-untyped-def] + if idx == 0: + return self.x + elif idx == 1: + return self.y + elif idx == 2: + return self.z + else: + raise IndexError(f"Vector3 index {idx} out of range [0-2]") + + def __repr__(self) -> str: + return f"Vector({self.data})" + + def __str__(self) -> str: + def getArrow(): # type: ignore[no-untyped-def] + repr = ["←", "↖", "↑", "↗", "→", "↘", "↓", "↙"] + + if self.x == 0 and self.y == 0: + return "·" + + # Calculate angle in radians and convert to directional index + angle = np.arctan2(self.y, self.x) + # Map angle to 0-7 index (8 directions) with proper orientation + dir_index = int(((angle + np.pi) * 4 / np.pi) % 8) + # Get directional arrow symbol + return repr[dir_index] + + return f"{getArrow()} Vector {self.__repr__()}" # type: ignore[no-untyped-call] + + def agent_encode(self) -> dict: # type: ignore[type-arg] + """Encode the vector for agent communication.""" + return {"x": self.x, "y": self.y, "z": self.z} + + def serialize(self) -> dict: # type: ignore[type-arg] + """Serialize the vector to a tuple.""" + return {"type": "vector", "c": (self.x, self.y, self.z)} + + def __eq__(self, other) -> bool: # type: ignore[no-untyped-def] + """Check if two vectors are equal using numpy's allclose for floating point comparison.""" + if not isinstance(other, Vector3): + return False + return np.allclose([self.x, self.y, self.z], [other.x, other.y, other.z]) + + def __add__(self, other: VectorConvertable | Vector3) -> Vector3: + other_vector: Vector3 = to_vector(other) + return self.__class__( + self.x + other_vector.x, self.y + other_vector.y, self.z + other_vector.z + ) + + def __sub__(self, other: VectorConvertable | Vector3) -> Vector3: + other_vector = to_vector(other) + return self.__class__( + self.x - other_vector.x, self.y - other_vector.y, self.z - other_vector.z + ) + + def __mul__(self, scalar: float) -> Vector3: + return self.__class__(self.x * scalar, self.y * scalar, self.z * scalar) + + def __rmul__(self, scalar: float) -> Vector3: + return self.__mul__(scalar) + + def __truediv__(self, scalar: float) -> Vector3: + return self.__class__(self.x / scalar, self.y / scalar, self.z / scalar) + + def __neg__(self) -> Vector3: + return self.__class__(-self.x, -self.y, -self.z) + + def dot(self, other: VectorConvertable | Vector3) -> float: + """Compute dot product.""" + other_vector = to_vector(other) + return self.x * other_vector.x + self.y * other_vector.y + self.z * other_vector.z # type: ignore[no-any-return] + + def cross(self, other: VectorConvertable | Vector3) -> Vector3: + """Compute cross product (3D vectors only).""" + other_vector = to_vector(other) + return self.__class__( + self.y * other_vector.z - self.z * other_vector.y, + self.z * other_vector.x - self.x * other_vector.z, + self.x * other_vector.y - self.y * other_vector.x, + ) + + def magnitude(self) -> float: + """Alias for length().""" + return self.length() + + def length(self) -> float: + """Compute the Euclidean length (magnitude) of the vector.""" + return float(np.sqrt(self.x * self.x + self.y * self.y + self.z * self.z)) + + def length_squared(self) -> float: + """Compute the squared length of the vector (faster than length()).""" + return float(self.x * self.x + self.y * self.y + self.z * self.z) + + def normalize(self) -> Vector3: + """Return a normalized unit vector in the same direction.""" + length = self.length() + if length < 1e-10: # Avoid division by near-zero + return self.__class__(0.0, 0.0, 0.0) + return self.__class__(self.x / length, self.y / length, self.z / length) + + def to_2d(self) -> Vector3: + """Convert a vector to a 2D vector by taking only the x and y components (z=0).""" + return self.__class__(self.x, self.y, 0.0) + + def distance(self, other: VectorConvertable | Vector3) -> float: + """Compute Euclidean distance to another vector.""" + other_vector = to_vector(other) + dx = self.x - other_vector.x + dy = self.y - other_vector.y + dz = self.z - other_vector.z + return float(np.sqrt(dx * dx + dy * dy + dz * dz)) + + def distance_squared(self, other: VectorConvertable | Vector3) -> float: + """Compute squared Euclidean distance to another vector (faster than distance()).""" + other_vector = to_vector(other) + dx = self.x - other_vector.x + dy = self.y - other_vector.y + dz = self.z - other_vector.z + return float(dx * dx + dy * dy + dz * dz) + + def angle(self, other: VectorConvertable | Vector3) -> float: + """Compute the angle (in radians) between this vector and another.""" + other_vector = to_vector(other) + this_length = self.length() + other_length = other_vector.length() + + if this_length < 1e-10 or other_length < 1e-10: + return 0.0 + + cos_angle = np.clip( + self.dot(other_vector) / (this_length * other_length), + -1.0, + 1.0, + ) + return float(np.arccos(cos_angle)) + + def project(self, onto: VectorConvertable | Vector3) -> Vector3: + """Project this vector onto another vector.""" + onto_vector = to_vector(onto) + onto_length_sq = ( + onto_vector.x * onto_vector.x + + onto_vector.y * onto_vector.y + + onto_vector.z * onto_vector.z + ) + if onto_length_sq < 1e-10: + return self.__class__(0.0, 0.0, 0.0) + + scalar_projection = self.dot(onto_vector) / onto_length_sq + return self.__class__( + scalar_projection * onto_vector.x, + scalar_projection * onto_vector.y, + scalar_projection * onto_vector.z, + ) + + @classmethod + def zeros(cls) -> Vector3: + """Create a zero 3D vector.""" + return cls() + + @classmethod + def ones(cls) -> Vector3: + """Create a 3D vector of ones.""" + return cls(1.0, 1.0, 1.0) + + @classmethod + def unit_x(cls) -> Vector3: + """Create a unit vector in the x direction.""" + return cls(1.0, 0.0, 0.0) + + @classmethod + def unit_y(cls) -> Vector3: + """Create a unit vector in the y direction.""" + return cls(0.0, 1.0, 0.0) + + @classmethod + def unit_z(cls) -> Vector3: + """Create a unit vector in the z direction.""" + return cls(0.0, 0.0, 1.0) + + def to_list(self) -> list[float]: + """Convert the vector to a list.""" + return [self.x, self.y, self.z] + + def to_tuple(self) -> tuple[float, float, float]: + """Convert the vector to a tuple.""" + return (self.x, self.y, self.z) + + def to_numpy(self) -> np.ndarray: # type: ignore[type-arg] + """Convert the vector to a numpy array.""" + return np.array([self.x, self.y, self.z], dtype=float) + + def is_zero(self) -> bool: + """Check if this is a zero vector (all components are zero). + + Returns: + True if all components are zero, False otherwise + """ + return np.allclose([self.x, self.y, self.z], 0.0) + + @property + def quaternion(self): # type: ignore[no-untyped-def] + return self.to_quaternion() # type: ignore[no-untyped-call] + + def to_quaternion(self): # type: ignore[no-untyped-def] + """Convert Vector3 representing Euler angles (roll, pitch, yaw) to a Quaternion. + + Assumes this Vector3 contains Euler angles in radians: + - x component: roll (rotation around x-axis) + - y component: pitch (rotation around y-axis) + - z component: yaw (rotation around z-axis) + + Returns: + Quaternion: The equivalent quaternion representation + """ + # Import here to avoid circular imports + from dimos.msgs.geometry_msgs.Quaternion import Quaternion + + # Extract Euler angles + roll = self.x + pitch = self.y + yaw = self.z + + # Convert Euler angles to quaternion using ZYX convention + # Source: https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles + + # Compute half angles + cy = np.cos(yaw * 0.5) + sy = np.sin(yaw * 0.5) + cp = np.cos(pitch * 0.5) + sp = np.sin(pitch * 0.5) + cr = np.cos(roll * 0.5) + sr = np.sin(roll * 0.5) + + # Compute quaternion components + w = cr * cp * cy + sr * sp * sy + x = sr * cp * cy - cr * sp * sy + y = cr * sp * cy + sr * cp * sy + z = cr * cp * sy - sr * sp * cy + + return Quaternion(x, y, z, w) + + def __bool__(self) -> bool: + """Boolean conversion for Vector. + + A Vector is considered False if it's a zero vector (all components are zero), + and True otherwise. + + Returns: + False if vector is zero, True otherwise + """ + return not self.is_zero() + + +@dispatch +def to_numpy(value: Vector3) -> np.ndarray: # type: ignore[type-arg] + """Convert a Vector3 to a numpy array.""" + return value.to_numpy() + + +@dispatch # type: ignore[no-redef] +def to_numpy(value: np.ndarray) -> np.ndarray: # type: ignore[type-arg] + """Pass through numpy arrays.""" + return value + + +@dispatch # type: ignore[no-redef] +def to_numpy(value: Sequence[int | float]) -> np.ndarray: # type: ignore[type-arg] + """Convert a sequence to a numpy array.""" + return np.array(value, dtype=float) + + +@dispatch +def to_vector(value: Vector3) -> Vector3: + """Pass through Vector3 objects.""" + return value + + +@dispatch # type: ignore[no-redef] +def to_vector(value: VectorConvertable | Vector3) -> Vector3: + """Convert a vector-compatible value to a Vector3 object.""" + return Vector3(value) + + +@dispatch +def to_tuple(value: Vector3) -> tuple[float, float, float]: + """Convert a Vector3 to a tuple.""" + return value.to_tuple() + + +@dispatch # type: ignore[no-redef] +def to_tuple(value: np.ndarray) -> tuple[float, ...]: # type: ignore[type-arg] + """Convert a numpy array to a tuple.""" + return tuple(value.tolist()) + + +@dispatch # type: ignore[no-redef] +def to_tuple(value: Sequence[int | float]) -> tuple[float, ...]: + """Convert a sequence to a tuple.""" + if isinstance(value, tuple): + return value + else: + return tuple(value) + + +@dispatch +def to_list(value: Vector3) -> list[float]: + """Convert a Vector3 to a list.""" + return value.to_list() + + +@dispatch # type: ignore[no-redef] +def to_list(value: np.ndarray) -> list[float]: # type: ignore[type-arg] + """Convert a numpy array to a list.""" + return value.tolist() + + +@dispatch # type: ignore[no-redef] +def to_list(value: Sequence[int | float]) -> list[float]: + """Convert a sequence to a list.""" + if isinstance(value, list): + return value + else: + return list(value) + + +VectorLike: TypeAlias = VectorConvertable | Vector3 + + +def make_vector3(x: float, y: float, z: float) -> Vector3: + return Vector3(x, y, z) diff --git a/dimos/msgs/geometry_msgs/__init__.py b/dimos/msgs/geometry_msgs/__init__.py new file mode 100644 index 0000000000..683aa2e37c --- /dev/null +++ b/dimos/msgs/geometry_msgs/__init__.py @@ -0,0 +1,28 @@ +from dimos.msgs.geometry_msgs.Pose import Pose, PoseLike, to_pose +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.PoseWithCovariance import PoseWithCovariance +from dimos.msgs.geometry_msgs.PoseWithCovarianceStamped import PoseWithCovarianceStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.TwistStamped import TwistStamped +from dimos.msgs.geometry_msgs.TwistWithCovariance import TwistWithCovariance +from dimos.msgs.geometry_msgs.TwistWithCovarianceStamped import TwistWithCovarianceStamped +from dimos.msgs.geometry_msgs.Vector3 import Vector3, VectorLike + +__all__ = [ + "Pose", + "PoseLike", + "PoseStamped", + "PoseWithCovariance", + "PoseWithCovarianceStamped", + "Quaternion", + "Transform", + "Twist", + "TwistStamped", + "TwistWithCovariance", + "TwistWithCovarianceStamped", + "Vector3", + "VectorLike", + "to_pose", +] diff --git a/dimos/msgs/geometry_msgs/test_Pose.py b/dimos/msgs/geometry_msgs/test_Pose.py new file mode 100644 index 0000000000..50bfaf1388 --- /dev/null +++ b/dimos/msgs/geometry_msgs/test_Pose.py @@ -0,0 +1,808 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pickle + +from dimos_lcm.geometry_msgs import Pose as LCMPose +import numpy as np +import pytest + +try: + from geometry_msgs.msg import Point as ROSPoint, Pose as ROSPose, Quaternion as ROSQuaternion +except ImportError: + ROSPose = None + ROSPoint = None + ROSQuaternion = None + +from dimos.msgs.geometry_msgs.Pose import Pose, to_pose +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 + + +def test_pose_default_init() -> None: + """Test that default initialization creates a pose at origin with identity orientation.""" + pose = Pose() + + # Position should be at origin + assert pose.position.x == 0.0 + assert pose.position.y == 0.0 + assert pose.position.z == 0.0 + + # Orientation should be identity quaternion + assert pose.orientation.x == 0.0 + assert pose.orientation.y == 0.0 + assert pose.orientation.z == 0.0 + assert pose.orientation.w == 1.0 + + # Test convenience properties + assert pose.x == 0.0 + assert pose.y == 0.0 + assert pose.z == 0.0 + + +def test_pose_pose_init() -> None: + """Test initialization with position coordinates only (identity orientation).""" + pose_data = Pose(1.0, 2.0, 3.0) + + pose = to_pose(pose_data) + + # Position should be as specified + assert pose.position.x == 1.0 + assert pose.position.y == 2.0 + assert pose.position.z == 3.0 + + # Orientation should be identity quaternion + assert pose.orientation.x == 0.0 + assert pose.orientation.y == 0.0 + assert pose.orientation.z == 0.0 + assert pose.orientation.w == 1.0 + + # Test convenience properties + assert pose.x == 1.0 + assert pose.y == 2.0 + assert pose.z == 3.0 + + +def test_pose_position_init() -> None: + """Test initialization with position coordinates only (identity orientation).""" + pose = Pose(1.0, 2.0, 3.0) + + # Position should be as specified + assert pose.position.x == 1.0 + assert pose.position.y == 2.0 + assert pose.position.z == 3.0 + + # Orientation should be identity quaternion + assert pose.orientation.x == 0.0 + assert pose.orientation.y == 0.0 + assert pose.orientation.z == 0.0 + assert pose.orientation.w == 1.0 + + # Test convenience properties + assert pose.x == 1.0 + assert pose.y == 2.0 + assert pose.z == 3.0 + + +def test_pose_full_init() -> None: + """Test initialization with position and orientation coordinates.""" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + + # Position should be as specified + assert pose.position.x == 1.0 + assert pose.position.y == 2.0 + assert pose.position.z == 3.0 + + # Orientation should be as specified + assert pose.orientation.x == 0.1 + assert pose.orientation.y == 0.2 + assert pose.orientation.z == 0.3 + assert pose.orientation.w == 0.9 + + # Test convenience properties + assert pose.x == 1.0 + assert pose.y == 2.0 + assert pose.z == 3.0 + + +def test_pose_vector_position_init() -> None: + """Test initialization with Vector3 position (identity orientation).""" + position = Vector3(4.0, 5.0, 6.0) + pose = Pose(position) + + # Position should match the vector + assert pose.position.x == 4.0 + assert pose.position.y == 5.0 + assert pose.position.z == 6.0 + + # Orientation should be identity + assert pose.orientation.x == 0.0 + assert pose.orientation.y == 0.0 + assert pose.orientation.z == 0.0 + assert pose.orientation.w == 1.0 + + +def test_pose_vector_quaternion_init() -> None: + """Test initialization with Vector3 position and Quaternion orientation.""" + position = Vector3(1.0, 2.0, 3.0) + orientation = Quaternion(0.1, 0.2, 0.3, 0.9) + pose = Pose(position, orientation) + + # Position should match the vector + assert pose.position.x == 1.0 + assert pose.position.y == 2.0 + assert pose.position.z == 3.0 + + # Orientation should match the quaternion + assert pose.orientation.x == 0.1 + assert pose.orientation.y == 0.2 + assert pose.orientation.z == 0.3 + assert pose.orientation.w == 0.9 + + +def test_pose_list_init() -> None: + """Test initialization with lists for position and orientation.""" + position_list = [1.0, 2.0, 3.0] + orientation_list = [0.1, 0.2, 0.3, 0.9] + pose = Pose(position_list, orientation_list) + + # Position should match the list + assert pose.position.x == 1.0 + assert pose.position.y == 2.0 + assert pose.position.z == 3.0 + + # Orientation should match the list + assert pose.orientation.x == 0.1 + assert pose.orientation.y == 0.2 + assert pose.orientation.z == 0.3 + assert pose.orientation.w == 0.9 + + +def test_pose_tuple_init() -> None: + """Test initialization from a tuple of (position, orientation).""" + position = [1.0, 2.0, 3.0] + orientation = [0.1, 0.2, 0.3, 0.9] + pose_tuple = (position, orientation) + pose = Pose(pose_tuple) + + # Position should match + assert pose.position.x == 1.0 + assert pose.position.y == 2.0 + assert pose.position.z == 3.0 + + # Orientation should match + assert pose.orientation.x == 0.1 + assert pose.orientation.y == 0.2 + assert pose.orientation.z == 0.3 + assert pose.orientation.w == 0.9 + + +def test_pose_dict_init() -> None: + """Test initialization from a dictionary with 'position' and 'orientation' keys.""" + pose_dict = {"position": [1.0, 2.0, 3.0], "orientation": [0.1, 0.2, 0.3, 0.9]} + pose = Pose(pose_dict) + + # Position should match + assert pose.position.x == 1.0 + assert pose.position.y == 2.0 + assert pose.position.z == 3.0 + + # Orientation should match + assert pose.orientation.x == 0.1 + assert pose.orientation.y == 0.2 + assert pose.orientation.z == 0.3 + assert pose.orientation.w == 0.9 + + +def test_pose_copy_init() -> None: + """Test initialization from another Pose (copy constructor).""" + original = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + copy = Pose(original) + + # Position should match + assert copy.position.x == 1.0 + assert copy.position.y == 2.0 + assert copy.position.z == 3.0 + + # Orientation should match + assert copy.orientation.x == 0.1 + assert copy.orientation.y == 0.2 + assert copy.orientation.z == 0.3 + assert copy.orientation.w == 0.9 + + # Should be a copy, not the same object + assert copy is not original + assert copy == original + + +def test_pose_lcm_init() -> None: + """Test initialization from an LCM Pose.""" + # Create LCM pose + lcm_pose = LCMPose() + lcm_pose.position.x = 1.0 + lcm_pose.position.y = 2.0 + lcm_pose.position.z = 3.0 + lcm_pose.orientation.x = 0.1 + lcm_pose.orientation.y = 0.2 + lcm_pose.orientation.z = 0.3 + lcm_pose.orientation.w = 0.9 + + pose = Pose(lcm_pose) + + # Position should match + assert pose.position.x == 1.0 + assert pose.position.y == 2.0 + assert pose.position.z == 3.0 + + # Orientation should match + assert pose.orientation.x == 0.1 + assert pose.orientation.y == 0.2 + assert pose.orientation.z == 0.3 + assert pose.orientation.w == 0.9 + + +def test_pose_properties() -> None: + """Test pose property access.""" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + + # Test position properties + assert pose.x == 1.0 + assert pose.y == 2.0 + assert pose.z == 3.0 + + # Test orientation properties (through quaternion's to_euler method) + euler = pose.orientation.to_euler() + assert pose.roll == euler.x + assert pose.pitch == euler.y + assert pose.yaw == euler.z + + +def test_pose_euler_properties_identity() -> None: + """Test pose Euler angle properties with identity orientation.""" + pose = Pose(1.0, 2.0, 3.0) # Identity orientation + + # Identity quaternion should give zero Euler angles + assert np.isclose(pose.roll, 0.0, atol=1e-10) + assert np.isclose(pose.pitch, 0.0, atol=1e-10) + assert np.isclose(pose.yaw, 0.0, atol=1e-10) + + # Euler property should also be zeros + assert np.isclose(pose.orientation.euler.x, 0.0, atol=1e-10) + assert np.isclose(pose.orientation.euler.y, 0.0, atol=1e-10) + assert np.isclose(pose.orientation.euler.z, 0.0, atol=1e-10) + + +def test_pose_repr() -> None: + """Test pose string representation.""" + pose = Pose(1.234, 2.567, 3.891, 0.1, 0.2, 0.3, 0.9) + + repr_str = repr(pose) + + # Should contain position and orientation info + assert "Pose" in repr_str + assert "position" in repr_str + assert "orientation" in repr_str + + # Should contain the actual values (approximately) + assert "1.234" in repr_str or "1.23" in repr_str + assert "2.567" in repr_str or "2.57" in repr_str + + +def test_pose_str() -> None: + """Test pose string formatting.""" + pose = Pose(1.234, 2.567, 3.891, 0.1, 0.2, 0.3, 0.9) + + str_repr = str(pose) + + # Should contain position coordinates + assert "1.234" in str_repr + assert "2.567" in str_repr + assert "3.891" in str_repr + + # Should contain Euler angles + assert "euler" in str_repr + + # Should be formatted with specified precision + assert str_repr.count("Pose") == 1 + + +def test_pose_equality() -> None: + """Test pose equality comparison.""" + pose1 = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + pose2 = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + pose3 = Pose(1.1, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) # Different position + pose4 = Pose(1.0, 2.0, 3.0, 0.11, 0.2, 0.3, 0.9) # Different orientation + + # Equal poses + assert pose1 == pose2 + assert pose2 == pose1 + + # Different poses + assert pose1 != pose3 + assert pose1 != pose4 + assert pose3 != pose4 + + # Different types + assert pose1 != "not a pose" + assert pose1 != [1.0, 2.0, 3.0] + assert pose1 is not None + + +def test_pose_with_numpy_arrays() -> None: + """Test pose initialization with numpy arrays.""" + position_array = np.array([1.0, 2.0, 3.0]) + orientation_array = np.array([0.1, 0.2, 0.3, 0.9]) + + pose = Pose(position_array, orientation_array) + + # Position should match + assert pose.position.x == 1.0 + assert pose.position.y == 2.0 + assert pose.position.z == 3.0 + + # Orientation should match + assert pose.orientation.x == 0.1 + assert pose.orientation.y == 0.2 + assert pose.orientation.z == 0.3 + assert pose.orientation.w == 0.9 + + +def test_pose_with_mixed_types() -> None: + """Test pose initialization with mixed input types.""" + # Position as tuple, orientation as list + pose1 = Pose((1.0, 2.0, 3.0), [0.1, 0.2, 0.3, 0.9]) + + # Position as numpy array, orientation as Vector3/Quaternion + position = np.array([1.0, 2.0, 3.0]) + orientation = Quaternion(0.1, 0.2, 0.3, 0.9) + pose2 = Pose(position, orientation) + + # Both should result in the same pose + assert pose1.position.x == pose2.position.x + assert pose1.position.y == pose2.position.y + assert pose1.position.z == pose2.position.z + assert pose1.orientation.x == pose2.orientation.x + assert pose1.orientation.y == pose2.orientation.y + assert pose1.orientation.z == pose2.orientation.z + assert pose1.orientation.w == pose2.orientation.w + + +def test_to_pose_passthrough() -> None: + """Test to_pose function with Pose input (passthrough).""" + original = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + result = to_pose(original) + + # Should be the same object (passthrough) + assert result is original + + +def test_to_pose_conversion() -> None: + """Test to_pose function with convertible inputs.""" + # Note: The to_pose conversion function has type checking issues in the current implementation + # Test direct construction instead to verify the intended functionality + + # Test the intended functionality by creating poses directly + pose_tuple = ([1.0, 2.0, 3.0], [0.1, 0.2, 0.3, 0.9]) + result1 = Pose(pose_tuple) + + assert isinstance(result1, Pose) + assert result1.position.x == 1.0 + assert result1.position.y == 2.0 + assert result1.position.z == 3.0 + assert result1.orientation.x == 0.1 + assert result1.orientation.y == 0.2 + assert result1.orientation.z == 0.3 + assert result1.orientation.w == 0.9 + + # Test with dictionary + pose_dict = {"position": [1.0, 2.0, 3.0], "orientation": [0.1, 0.2, 0.3, 0.9]} + result2 = Pose(pose_dict) + + assert isinstance(result2, Pose) + assert result2.position.x == 1.0 + assert result2.position.y == 2.0 + assert result2.position.z == 3.0 + assert result2.orientation.x == 0.1 + assert result2.orientation.y == 0.2 + assert result2.orientation.z == 0.3 + assert result2.orientation.w == 0.9 + + +def test_pose_euler_roundtrip() -> None: + """Test conversion from Euler angles to quaternion and back.""" + # Start with known Euler angles (small angles to avoid gimbal lock) + roll = 0.1 + pitch = 0.2 + yaw = 0.3 + + # Create quaternion from Euler angles + euler_vector = Vector3(roll, pitch, yaw) + quaternion = euler_vector.to_quaternion() + + # Create pose with this quaternion + pose = Pose(Vector3(0, 0, 0), quaternion) + + # Convert back to Euler angles + result_euler = pose.orientation.euler + + # Should get back the original Euler angles (within tolerance) + assert np.isclose(result_euler.x, roll, atol=1e-6) + assert np.isclose(result_euler.y, pitch, atol=1e-6) + assert np.isclose(result_euler.z, yaw, atol=1e-6) + + +def test_pose_zero_position() -> None: + """Test pose with zero position vector.""" + # Use manual construction since Vector3.zeros has signature issues + pose = Pose(0.0, 0.0, 0.0) # Position at origin with identity orientation + + assert pose.x == 0.0 + assert pose.y == 0.0 + assert pose.z == 0.0 + assert np.isclose(pose.roll, 0.0, atol=1e-10) + assert np.isclose(pose.pitch, 0.0, atol=1e-10) + assert np.isclose(pose.yaw, 0.0, atol=1e-10) + + +def test_pose_unit_vectors() -> None: + """Test pose with unit vector positions.""" + # Test unit x vector position + pose_x = Pose(Vector3.unit_x()) + assert pose_x.x == 1.0 + assert pose_x.y == 0.0 + assert pose_x.z == 0.0 + + # Test unit y vector position + pose_y = Pose(Vector3.unit_y()) + assert pose_y.x == 0.0 + assert pose_y.y == 1.0 + assert pose_y.z == 0.0 + + # Test unit z vector position + pose_z = Pose(Vector3.unit_z()) + assert pose_z.x == 0.0 + assert pose_z.y == 0.0 + assert pose_z.z == 1.0 + + +def test_pose_negative_coordinates() -> None: + """Test pose with negative coordinates.""" + pose = Pose(-1.0, -2.0, -3.0, -0.1, -0.2, -0.3, 0.9) + + # Position should be negative + assert pose.x == -1.0 + assert pose.y == -2.0 + assert pose.z == -3.0 + + # Orientation should be as specified + assert pose.orientation.x == -0.1 + assert pose.orientation.y == -0.2 + assert pose.orientation.z == -0.3 + assert pose.orientation.w == 0.9 + + +def test_pose_large_coordinates() -> None: + """Test pose with large coordinate values.""" + large_value = 1000.0 + pose = Pose(large_value, large_value, large_value) + + assert pose.x == large_value + assert pose.y == large_value + assert pose.z == large_value + + # Orientation should still be identity + assert pose.orientation.x == 0.0 + assert pose.orientation.y == 0.0 + assert pose.orientation.z == 0.0 + assert pose.orientation.w == 1.0 + + +@pytest.mark.parametrize( + "x,y,z", + [(0.0, 0.0, 0.0), (1.0, 2.0, 3.0), (-1.0, -2.0, -3.0), (0.5, -0.5, 1.5), (100.0, -100.0, 0.0)], +) +def test_pose_parametrized_positions(x, y, z) -> None: + """Parametrized test for various position values.""" + pose = Pose(x, y, z) + + assert pose.x == x + assert pose.y == y + assert pose.z == z + + # Should have identity orientation + assert pose.orientation.x == 0.0 + assert pose.orientation.y == 0.0 + assert pose.orientation.z == 0.0 + assert pose.orientation.w == 1.0 + + +@pytest.mark.parametrize( + "qx,qy,qz,qw", + [ + (0.0, 0.0, 0.0, 1.0), # Identity + (1.0, 0.0, 0.0, 0.0), # 180° around x + (0.0, 1.0, 0.0, 0.0), # 180° around y + (0.0, 0.0, 1.0, 0.0), # 180° around z + (0.5, 0.5, 0.5, 0.5), # Equal components + ], +) +def test_pose_parametrized_orientations(qx, qy, qz, qw) -> None: + """Parametrized test for various orientation values.""" + pose = Pose(0.0, 0.0, 0.0, qx, qy, qz, qw) + + # Position should be at origin + assert pose.x == 0.0 + assert pose.y == 0.0 + assert pose.z == 0.0 + + # Orientation should match + assert pose.orientation.x == qx + assert pose.orientation.y == qy + assert pose.orientation.z == qz + assert pose.orientation.w == qw + + +def test_lcm_encode_decode() -> None: + """Test encoding and decoding of Pose to/from binary LCM format.""" + + def encodepass() -> None: + pose_source = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + binary_msg = pose_source.lcm_encode() + pose_dest = Pose.lcm_decode(binary_msg) + assert isinstance(pose_dest, Pose) + assert pose_dest is not pose_source + assert pose_dest == pose_source + # Verify we get our custom types back + assert isinstance(pose_dest.position, Vector3) + assert isinstance(pose_dest.orientation, Quaternion) + + import timeit + + print(f"{timeit.timeit(encodepass, number=1000)} ms per cycle") + + +def test_pickle_encode_decode() -> None: + """Test encoding and decoding of Pose to/from binary LCM format.""" + + def encodepass() -> None: + pose_source = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + binary_msg = pickle.dumps(pose_source) + pose_dest = pickle.loads(binary_msg) + assert isinstance(pose_dest, Pose) + assert pose_dest is not pose_source + assert pose_dest == pose_source + + import timeit + + print(f"{timeit.timeit(encodepass, number=1000)} ms per cycle") + + +def test_pose_addition_translation_only() -> None: + """Test pose addition with translation only (identity rotations).""" + # Two poses with only translations + pose1 = Pose(1.0, 2.0, 3.0) # First translation + pose2 = Pose(4.0, 5.0, 6.0) # Second translation + + # Adding should combine translations + result = pose1 + pose2 + + assert result.position.x == 5.0 # 1 + 4 + assert result.position.y == 7.0 # 2 + 5 + assert result.position.z == 9.0 # 3 + 6 + + # Orientation should remain identity + assert result.orientation.x == 0.0 + assert result.orientation.y == 0.0 + assert result.orientation.z == 0.0 + assert result.orientation.w == 1.0 + + +def test_pose_addition_with_rotation() -> None: + """Test pose addition with rotation applied to translation.""" + # First pose: at origin, rotated 90 degrees around Z (yaw) + # 90 degree rotation quaternion around Z: (0, 0, sin(pi/4), cos(pi/4)) + angle = np.pi / 2 # 90 degrees + pose1 = Pose(0.0, 0.0, 0.0, 0.0, 0.0, np.sin(angle / 2), np.cos(angle / 2)) + + # Second pose: 1 unit forward (along X in its frame) + pose2 = Pose(1.0, 0.0, 0.0) + + # After rotation, the forward direction should be along Y + result = pose1 + pose2 + + # Position should be rotated + assert np.isclose(result.position.x, 0.0, atol=1e-10) + assert np.isclose(result.position.y, 1.0, atol=1e-10) + assert np.isclose(result.position.z, 0.0, atol=1e-10) + + # Orientation should be same as pose1 (pose2 has identity rotation) + assert np.isclose(result.orientation.x, 0.0, atol=1e-10) + assert np.isclose(result.orientation.y, 0.0, atol=1e-10) + assert np.isclose(result.orientation.z, np.sin(angle / 2), atol=1e-10) + assert np.isclose(result.orientation.w, np.cos(angle / 2), atol=1e-10) + + +def test_pose_addition_rotation_composition() -> None: + """Test that rotations are properly composed.""" + # First pose: 45 degrees around Z + angle1 = np.pi / 4 # 45 degrees + pose1 = Pose(0.0, 0.0, 0.0, 0.0, 0.0, np.sin(angle1 / 2), np.cos(angle1 / 2)) + + # Second pose: another 45 degrees around Z + angle2 = np.pi / 4 # 45 degrees + pose2 = Pose(0.0, 0.0, 0.0, 0.0, 0.0, np.sin(angle2 / 2), np.cos(angle2 / 2)) + + # Result should be 90 degrees around Z + result = pose1 + pose2 + + # Check final angle is 90 degrees + expected_angle = angle1 + angle2 # 90 degrees + expected_qz = np.sin(expected_angle / 2) + expected_qw = np.cos(expected_angle / 2) + + assert np.isclose(result.orientation.z, expected_qz, atol=1e-10) + assert np.isclose(result.orientation.w, expected_qw, atol=1e-10) + + +def test_pose_addition_full_transform() -> None: + """Test full pose composition with translation and rotation.""" + # Robot pose: at (2, 1, 0), facing 90 degrees left (positive yaw) + robot_yaw = np.pi / 2 # 90 degrees + robot_pose = Pose(2.0, 1.0, 0.0, 0.0, 0.0, np.sin(robot_yaw / 2), np.cos(robot_yaw / 2)) + + # Object in robot frame: 3 units forward, 1 unit right + object_in_robot = Pose(3.0, -1.0, 0.0) + + # Compose to get object in world frame + object_in_world = robot_pose + object_in_robot + + # Robot is facing left (90 degrees), so: + # - Robot's forward (X) is world's negative Y + # - Robot's right (negative Y) is world's X + # So object should be at: robot_pos + rotated_offset + # rotated_offset: (3, -1) rotated 90° CCW = (1, 3) + assert np.isclose(object_in_world.position.x, 3.0, atol=1e-10) # 2 + 1 + assert np.isclose(object_in_world.position.y, 4.0, atol=1e-10) # 1 + 3 + assert np.isclose(object_in_world.position.z, 0.0, atol=1e-10) + + # Orientation should match robot's orientation (object has no rotation) + assert np.isclose(object_in_world.yaw, robot_yaw, atol=1e-10) + + +def test_pose_addition_chain() -> None: + """Test chaining multiple pose additions.""" + # Create a chain of transformations + pose1 = Pose(1.0, 0.0, 0.0) # Move 1 unit in X + pose2 = Pose(0.0, 1.0, 0.0) # Move 1 unit in Y (relative to pose1) + pose3 = Pose(0.0, 0.0, 1.0) # Move 1 unit in Z (relative to pose1+pose2) + + # Chain them together + result = pose1 + pose2 + pose3 + + # Should accumulate all translations + assert result.position.x == 1.0 + assert result.position.y == 1.0 + assert result.position.z == 1.0 + + +def test_pose_addition_with_convertible() -> None: + """Test pose addition with convertible types.""" + pose1 = Pose(1.0, 2.0, 3.0) + + # Add with tuple + pose_tuple = ([4.0, 5.0, 6.0], [0.0, 0.0, 0.0, 1.0]) + result1 = pose1 + pose_tuple + assert result1.position.x == 5.0 + assert result1.position.y == 7.0 + assert result1.position.z == 9.0 + + # Add with dict + pose_dict = {"position": [1.0, 0.0, 0.0], "orientation": [0.0, 0.0, 0.0, 1.0]} + result2 = pose1 + pose_dict + assert result2.position.x == 2.0 + assert result2.position.y == 2.0 + assert result2.position.z == 3.0 + + +def test_pose_identity_addition() -> None: + """Test that adding identity pose leaves pose unchanged.""" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + identity = Pose() # Identity pose at origin + + result = pose + identity + + # Should be unchanged + assert result.position.x == pose.position.x + assert result.position.y == pose.position.y + assert result.position.z == pose.position.z + assert result.orientation.x == pose.orientation.x + assert result.orientation.y == pose.orientation.y + assert result.orientation.z == pose.orientation.z + assert result.orientation.w == pose.orientation.w + + +def test_pose_addition_3d_rotation() -> None: + """Test pose addition with 3D rotations.""" + # First pose: rotated around X axis (roll) + roll = np.pi / 4 # 45 degrees + pose1 = Pose(1.0, 0.0, 0.0, np.sin(roll / 2), 0.0, 0.0, np.cos(roll / 2)) + + # Second pose: movement along Y and Z in local frame + pose2 = Pose(0.0, 1.0, 1.0) + + # Compose transformations + result = pose1 + pose2 + + # The Y and Z movement should be rotated around X + # After 45° rotation around X: + # - Local Y -> world Y * cos(45°) - Z * sin(45°) + # - Local Z -> world Y * sin(45°) + Z * cos(45°) + cos45 = np.cos(roll) + sin45 = np.sin(roll) + + assert np.isclose(result.position.x, 1.0, atol=1e-10) # X unchanged + assert np.isclose(result.position.y, cos45 - sin45, atol=1e-10) + assert np.isclose(result.position.z, sin45 + cos45, atol=1e-10) + + +@pytest.mark.ros +def test_pose_from_ros_msg() -> None: + """Test creating a Pose from a ROS Pose message.""" + ros_msg = ROSPose() + ros_msg.position = ROSPoint(x=1.0, y=2.0, z=3.0) + ros_msg.orientation = ROSQuaternion(x=0.1, y=0.2, z=0.3, w=0.9) + + pose = Pose.from_ros_msg(ros_msg) + + assert pose.position.x == 1.0 + assert pose.position.y == 2.0 + assert pose.position.z == 3.0 + assert pose.orientation.x == 0.1 + assert pose.orientation.y == 0.2 + assert pose.orientation.z == 0.3 + assert pose.orientation.w == 0.9 + + +@pytest.mark.ros +def test_pose_to_ros_msg() -> None: + """Test converting a Pose to a ROS Pose message.""" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + + ros_msg = pose.to_ros_msg() + + assert isinstance(ros_msg, ROSPose) + assert ros_msg.position.x == 1.0 + assert ros_msg.position.y == 2.0 + assert ros_msg.position.z == 3.0 + assert ros_msg.orientation.x == 0.1 + assert ros_msg.orientation.y == 0.2 + assert ros_msg.orientation.z == 0.3 + assert ros_msg.orientation.w == 0.9 + + +@pytest.mark.ros +def test_pose_ros_roundtrip() -> None: + """Test round-trip conversion between Pose and ROS Pose.""" + original = Pose(1.5, 2.5, 3.5, 0.15, 0.25, 0.35, 0.85) + + ros_msg = original.to_ros_msg() + restored = Pose.from_ros_msg(ros_msg) + + assert restored.position.x == original.position.x + assert restored.position.y == original.position.y + assert restored.position.z == original.position.z + assert restored.orientation.x == original.orientation.x + assert restored.orientation.y == original.orientation.y + assert restored.orientation.z == original.orientation.z + assert restored.orientation.w == original.orientation.w diff --git a/dimos/msgs/geometry_msgs/test_PoseStamped.py b/dimos/msgs/geometry_msgs/test_PoseStamped.py new file mode 100644 index 0000000000..603723b610 --- /dev/null +++ b/dimos/msgs/geometry_msgs/test_PoseStamped.py @@ -0,0 +1,139 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pickle +import time + +import pytest + +try: + from geometry_msgs.msg import PoseStamped as ROSPoseStamped +except ImportError: + ROSPoseStamped = None + +from dimos.msgs.geometry_msgs import PoseStamped + + +def test_lcm_encode_decode() -> None: + """Test encoding and decoding of Pose to/from binary LCM format.""" + + pose_source = PoseStamped( + ts=time.time(), + position=(1.0, 2.0, 3.0), + orientation=(0.1, 0.2, 0.3, 0.9), + ) + binary_msg = pose_source.lcm_encode() + pose_dest = PoseStamped.lcm_decode(binary_msg) + + assert isinstance(pose_dest, PoseStamped) + assert pose_dest is not pose_source + + print(pose_source.position) + print(pose_source.orientation) + + print(pose_dest.position) + print(pose_dest.orientation) + assert pose_dest == pose_source + + +def test_pickle_encode_decode() -> None: + """Test encoding and decoding of PoseStamped to/from binary LCM format.""" + + pose_source = PoseStamped( + ts=time.time(), + position=(1.0, 2.0, 3.0), + orientation=(0.1, 0.2, 0.3, 0.9), + ) + binary_msg = pickle.dumps(pose_source) + pose_dest = pickle.loads(binary_msg) + assert isinstance(pose_dest, PoseStamped) + assert pose_dest is not pose_source + assert pose_dest == pose_source + + +@pytest.mark.ros +def test_pose_stamped_from_ros_msg() -> None: + """Test creating a PoseStamped from a ROS PoseStamped message.""" + ros_msg = ROSPoseStamped() + ros_msg.header.frame_id = "world" + ros_msg.header.stamp.sec = 123 + ros_msg.header.stamp.nanosec = 456000000 + ros_msg.pose.position.x = 1.0 + ros_msg.pose.position.y = 2.0 + ros_msg.pose.position.z = 3.0 + ros_msg.pose.orientation.x = 0.1 + ros_msg.pose.orientation.y = 0.2 + ros_msg.pose.orientation.z = 0.3 + ros_msg.pose.orientation.w = 0.9 + + pose_stamped = PoseStamped.from_ros_msg(ros_msg) + + assert pose_stamped.frame_id == "world" + assert pose_stamped.ts == 123.456 + assert pose_stamped.position.x == 1.0 + assert pose_stamped.position.y == 2.0 + assert pose_stamped.position.z == 3.0 + assert pose_stamped.orientation.x == 0.1 + assert pose_stamped.orientation.y == 0.2 + assert pose_stamped.orientation.z == 0.3 + assert pose_stamped.orientation.w == 0.9 + + +@pytest.mark.ros +def test_pose_stamped_to_ros_msg() -> None: + """Test converting a PoseStamped to a ROS PoseStamped message.""" + pose_stamped = PoseStamped( + ts=123.456, + frame_id="base_link", + position=(1.0, 2.0, 3.0), + orientation=(0.1, 0.2, 0.3, 0.9), + ) + + ros_msg = pose_stamped.to_ros_msg() + + assert isinstance(ros_msg, ROSPoseStamped) + assert ros_msg.header.frame_id == "base_link" + assert ros_msg.header.stamp.sec == 123 + assert ros_msg.header.stamp.nanosec == 456000000 + assert ros_msg.pose.position.x == 1.0 + assert ros_msg.pose.position.y == 2.0 + assert ros_msg.pose.position.z == 3.0 + assert ros_msg.pose.orientation.x == 0.1 + assert ros_msg.pose.orientation.y == 0.2 + assert ros_msg.pose.orientation.z == 0.3 + assert ros_msg.pose.orientation.w == 0.9 + + +@pytest.mark.ros +def test_pose_stamped_ros_roundtrip() -> None: + """Test round-trip conversion between PoseStamped and ROS PoseStamped.""" + original = PoseStamped( + ts=123.789, + frame_id="odom", + position=(1.5, 2.5, 3.5), + orientation=(0.15, 0.25, 0.35, 0.85), + ) + + ros_msg = original.to_ros_msg() + restored = PoseStamped.from_ros_msg(ros_msg) + + assert restored.frame_id == original.frame_id + assert restored.ts == original.ts + assert restored.position.x == original.position.x + assert restored.position.y == original.position.y + assert restored.position.z == original.position.z + assert restored.orientation.x == original.orientation.x + assert restored.orientation.y == original.orientation.y + assert restored.orientation.z == original.orientation.z + assert restored.orientation.w == original.orientation.w diff --git a/dimos/msgs/geometry_msgs/test_PoseWithCovariance.py b/dimos/msgs/geometry_msgs/test_PoseWithCovariance.py new file mode 100644 index 0000000000..d62ca6e806 --- /dev/null +++ b/dimos/msgs/geometry_msgs/test_PoseWithCovariance.py @@ -0,0 +1,388 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos_lcm.geometry_msgs import PoseWithCovariance as LCMPoseWithCovariance +import numpy as np +import pytest + +try: + from geometry_msgs.msg import ( + Point as ROSPoint, + Pose as ROSPose, + PoseWithCovariance as ROSPoseWithCovariance, + Quaternion as ROSQuaternion, + ) +except ImportError: + ROSPoseWithCovariance = None + ROSPose = None + ROSPoint = None + ROSQuaternion = None + +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.PoseWithCovariance import PoseWithCovariance + + +def test_pose_with_covariance_default_init() -> None: + """Test that default initialization creates a pose at origin with zero covariance.""" + pose_cov = PoseWithCovariance() + + # Pose should be at origin with identity orientation + assert pose_cov.pose.position.x == 0.0 + assert pose_cov.pose.position.y == 0.0 + assert pose_cov.pose.position.z == 0.0 + assert pose_cov.pose.orientation.x == 0.0 + assert pose_cov.pose.orientation.y == 0.0 + assert pose_cov.pose.orientation.z == 0.0 + assert pose_cov.pose.orientation.w == 1.0 + + # Covariance should be all zeros + assert np.all(pose_cov.covariance == 0.0) + assert pose_cov.covariance.shape == (36,) + + +def test_pose_with_covariance_pose_init() -> None: + """Test initialization with a Pose object.""" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + pose_cov = PoseWithCovariance(pose) + + # Pose should match + assert pose_cov.pose.position.x == 1.0 + assert pose_cov.pose.position.y == 2.0 + assert pose_cov.pose.position.z == 3.0 + assert pose_cov.pose.orientation.x == 0.1 + assert pose_cov.pose.orientation.y == 0.2 + assert pose_cov.pose.orientation.z == 0.3 + assert pose_cov.pose.orientation.w == 0.9 + + # Covariance should be zeros by default + assert np.all(pose_cov.covariance == 0.0) + + +def test_pose_with_covariance_pose_and_covariance_init() -> None: + """Test initialization with pose and covariance.""" + pose = Pose(1.0, 2.0, 3.0) + covariance = np.arange(36, dtype=float) + pose_cov = PoseWithCovariance(pose, covariance) + + # Pose should match + assert pose_cov.pose.position.x == 1.0 + assert pose_cov.pose.position.y == 2.0 + assert pose_cov.pose.position.z == 3.0 + + # Covariance should match + assert np.array_equal(pose_cov.covariance, covariance) + + +def test_pose_with_covariance_list_covariance() -> None: + """Test initialization with covariance as a list.""" + pose = Pose(1.0, 2.0, 3.0) + covariance_list = list(range(36)) + pose_cov = PoseWithCovariance(pose, covariance_list) + + # Covariance should be converted to numpy array + assert isinstance(pose_cov.covariance, np.ndarray) + assert np.array_equal(pose_cov.covariance, np.array(covariance_list)) + + +def test_pose_with_covariance_copy_init() -> None: + """Test copy constructor.""" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + covariance = np.arange(36, dtype=float) + original = PoseWithCovariance(pose, covariance) + copy = PoseWithCovariance(original) + + # Should be equal but not the same object + assert copy == original + assert copy is not original + assert copy.pose is not original.pose + assert copy.covariance is not original.covariance + + # Modify original to ensure they're independent + original.covariance[0] = 999.0 + assert copy.covariance[0] != 999.0 + + +def test_pose_with_covariance_lcm_init() -> None: + """Test initialization from LCM message.""" + lcm_msg = LCMPoseWithCovariance() + lcm_msg.pose.position.x = 1.0 + lcm_msg.pose.position.y = 2.0 + lcm_msg.pose.position.z = 3.0 + lcm_msg.pose.orientation.x = 0.1 + lcm_msg.pose.orientation.y = 0.2 + lcm_msg.pose.orientation.z = 0.3 + lcm_msg.pose.orientation.w = 0.9 + lcm_msg.covariance = list(range(36)) + + pose_cov = PoseWithCovariance(lcm_msg) + + # Pose should match + assert pose_cov.pose.position.x == 1.0 + assert pose_cov.pose.position.y == 2.0 + assert pose_cov.pose.position.z == 3.0 + assert pose_cov.pose.orientation.x == 0.1 + assert pose_cov.pose.orientation.y == 0.2 + assert pose_cov.pose.orientation.z == 0.3 + assert pose_cov.pose.orientation.w == 0.9 + + # Covariance should match + assert np.array_equal(pose_cov.covariance, np.arange(36)) + + +def test_pose_with_covariance_dict_init() -> None: + """Test initialization from dictionary.""" + pose_dict = {"pose": Pose(1.0, 2.0, 3.0), "covariance": list(range(36))} + pose_cov = PoseWithCovariance(pose_dict) + + assert pose_cov.pose.position.x == 1.0 + assert pose_cov.pose.position.y == 2.0 + assert pose_cov.pose.position.z == 3.0 + assert np.array_equal(pose_cov.covariance, np.arange(36)) + + +def test_pose_with_covariance_dict_init_no_covariance() -> None: + """Test initialization from dictionary without covariance.""" + pose_dict = {"pose": Pose(1.0, 2.0, 3.0)} + pose_cov = PoseWithCovariance(pose_dict) + + assert pose_cov.pose.position.x == 1.0 + assert np.all(pose_cov.covariance == 0.0) + + +def test_pose_with_covariance_tuple_init() -> None: + """Test initialization from tuple.""" + pose = Pose(1.0, 2.0, 3.0) + covariance = np.arange(36, dtype=float) + pose_tuple = (pose, covariance) + pose_cov = PoseWithCovariance(pose_tuple) + + assert pose_cov.pose.position.x == 1.0 + assert pose_cov.pose.position.y == 2.0 + assert pose_cov.pose.position.z == 3.0 + assert np.array_equal(pose_cov.covariance, covariance) + + +def test_pose_with_covariance_properties() -> None: + """Test convenience properties.""" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + pose_cov = PoseWithCovariance(pose) + + # Position properties + assert pose_cov.x == 1.0 + assert pose_cov.y == 2.0 + assert pose_cov.z == 3.0 + assert pose_cov.position.x == 1.0 + assert pose_cov.position.y == 2.0 + assert pose_cov.position.z == 3.0 + + # Orientation properties + assert pose_cov.orientation.x == 0.1 + assert pose_cov.orientation.y == 0.2 + assert pose_cov.orientation.z == 0.3 + assert pose_cov.orientation.w == 0.9 + + # Euler angle properties + assert pose_cov.roll == pose.roll + assert pose_cov.pitch == pose.pitch + assert pose_cov.yaw == pose.yaw + + +def test_pose_with_covariance_matrix_property() -> None: + """Test covariance matrix property.""" + pose = Pose() + covariance_array = np.arange(36, dtype=float) + pose_cov = PoseWithCovariance(pose, covariance_array) + + # Get as matrix + cov_matrix = pose_cov.covariance_matrix + assert cov_matrix.shape == (6, 6) + assert cov_matrix[0, 0] == 0.0 + assert cov_matrix[5, 5] == 35.0 + + # Set from matrix + new_matrix = np.eye(6) * 2.0 + pose_cov.covariance_matrix = new_matrix + assert np.array_equal(pose_cov.covariance[:6], [2.0, 0.0, 0.0, 0.0, 0.0, 0.0]) + + +def test_pose_with_covariance_repr() -> None: + """Test string representation.""" + pose = Pose(1.234, 2.567, 3.891) + pose_cov = PoseWithCovariance(pose) + + repr_str = repr(pose_cov) + assert "PoseWithCovariance" in repr_str + assert "pose=" in repr_str + assert "covariance=" in repr_str + assert "36 elements" in repr_str + + +def test_pose_with_covariance_str() -> None: + """Test string formatting.""" + pose = Pose(1.234, 2.567, 3.891) + covariance = np.eye(6).flatten() + pose_cov = PoseWithCovariance(pose, covariance) + + str_repr = str(pose_cov) + assert "PoseWithCovariance" in str_repr + assert "1.234" in str_repr + assert "2.567" in str_repr + assert "3.891" in str_repr + assert "cov_trace" in str_repr + assert "6.000" in str_repr # Trace of identity matrix is 6 + + +def test_pose_with_covariance_equality() -> None: + """Test equality comparison.""" + pose1 = Pose(1.0, 2.0, 3.0) + cov1 = np.arange(36, dtype=float) + pose_cov1 = PoseWithCovariance(pose1, cov1) + + pose2 = Pose(1.0, 2.0, 3.0) + cov2 = np.arange(36, dtype=float) + pose_cov2 = PoseWithCovariance(pose2, cov2) + + # Equal + assert pose_cov1 == pose_cov2 + + # Different pose + pose3 = Pose(1.1, 2.0, 3.0) + pose_cov3 = PoseWithCovariance(pose3, cov1) + assert pose_cov1 != pose_cov3 + + # Different covariance + cov3 = np.arange(36, dtype=float) + 1 + pose_cov4 = PoseWithCovariance(pose1, cov3) + assert pose_cov1 != pose_cov4 + + # Different type + assert pose_cov1 != "not a pose" + assert pose_cov1 is not None + + +def test_pose_with_covariance_lcm_encode_decode() -> None: + """Test LCM encoding and decoding.""" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + covariance = np.arange(36, dtype=float) + source = PoseWithCovariance(pose, covariance) + + # Encode and decode + binary_msg = source.lcm_encode() + decoded = PoseWithCovariance.lcm_decode(binary_msg) + + # Should be equal + assert decoded == source + assert isinstance(decoded, PoseWithCovariance) + assert isinstance(decoded.pose, Pose) + assert isinstance(decoded.covariance, np.ndarray) + + +@pytest.mark.ros +def test_pose_with_covariance_from_ros_msg() -> None: + """Test creating from ROS message.""" + ros_msg = ROSPoseWithCovariance() + ros_msg.pose.position = ROSPoint(x=1.0, y=2.0, z=3.0) + ros_msg.pose.orientation = ROSQuaternion(x=0.1, y=0.2, z=0.3, w=0.9) + ros_msg.covariance = [float(i) for i in range(36)] + + pose_cov = PoseWithCovariance.from_ros_msg(ros_msg) + + assert pose_cov.pose.position.x == 1.0 + assert pose_cov.pose.position.y == 2.0 + assert pose_cov.pose.position.z == 3.0 + assert pose_cov.pose.orientation.x == 0.1 + assert pose_cov.pose.orientation.y == 0.2 + assert pose_cov.pose.orientation.z == 0.3 + assert pose_cov.pose.orientation.w == 0.9 + assert np.array_equal(pose_cov.covariance, np.arange(36)) + + +@pytest.mark.ros +def test_pose_with_covariance_to_ros_msg() -> None: + """Test converting to ROS message.""" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + covariance = np.arange(36, dtype=float) + pose_cov = PoseWithCovariance(pose, covariance) + + ros_msg = pose_cov.to_ros_msg() + + assert isinstance(ros_msg, ROSPoseWithCovariance) + assert ros_msg.pose.position.x == 1.0 + assert ros_msg.pose.position.y == 2.0 + assert ros_msg.pose.position.z == 3.0 + assert ros_msg.pose.orientation.x == 0.1 + assert ros_msg.pose.orientation.y == 0.2 + assert ros_msg.pose.orientation.z == 0.3 + assert ros_msg.pose.orientation.w == 0.9 + assert list(ros_msg.covariance) == list(range(36)) + + +@pytest.mark.ros +def test_pose_with_covariance_ros_roundtrip() -> None: + """Test round-trip conversion with ROS messages.""" + pose = Pose(1.5, 2.5, 3.5, 0.15, 0.25, 0.35, 0.85) + covariance = np.random.rand(36) + original = PoseWithCovariance(pose, covariance) + + ros_msg = original.to_ros_msg() + restored = PoseWithCovariance.from_ros_msg(ros_msg) + + assert restored == original + + +def test_pose_with_covariance_zero_covariance() -> None: + """Test with zero covariance matrix.""" + pose = Pose(1.0, 2.0, 3.0) + pose_cov = PoseWithCovariance(pose) + + assert np.all(pose_cov.covariance == 0.0) + assert np.trace(pose_cov.covariance_matrix) == 0.0 + + +def test_pose_with_covariance_diagonal_covariance() -> None: + """Test with diagonal covariance matrix.""" + pose = Pose() + covariance = np.zeros(36) + # Set diagonal elements + for i in range(6): + covariance[i * 6 + i] = i + 1 + + pose_cov = PoseWithCovariance(pose, covariance) + + cov_matrix = pose_cov.covariance_matrix + assert np.trace(cov_matrix) == sum(range(1, 7)) # 1+2+3+4+5+6 = 21 + + # Check diagonal elements + for i in range(6): + assert cov_matrix[i, i] == i + 1 + + # Check off-diagonal elements are zero + for i in range(6): + for j in range(6): + if i != j: + assert cov_matrix[i, j] == 0.0 + + +@pytest.mark.parametrize( + "x,y,z", + [(0.0, 0.0, 0.0), (1.0, 2.0, 3.0), (-1.0, -2.0, -3.0), (100.0, -100.0, 0.0)], +) +def test_pose_with_covariance_parametrized_positions(x, y, z) -> None: + """Parametrized test for various position values.""" + pose = Pose(x, y, z) + pose_cov = PoseWithCovariance(pose) + + assert pose_cov.x == x + assert pose_cov.y == y + assert pose_cov.z == z diff --git a/dimos/msgs/geometry_msgs/test_PoseWithCovarianceStamped.py b/dimos/msgs/geometry_msgs/test_PoseWithCovarianceStamped.py new file mode 100644 index 0000000000..1d04bd8e87 --- /dev/null +++ b/dimos/msgs/geometry_msgs/test_PoseWithCovarianceStamped.py @@ -0,0 +1,368 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import numpy as np +import pytest + +try: + from builtin_interfaces.msg import Time as ROSTime + from geometry_msgs.msg import ( + Point as ROSPoint, + Pose as ROSPose, + PoseWithCovariance as ROSPoseWithCovariance, + PoseWithCovarianceStamped as ROSPoseWithCovarianceStamped, + Quaternion as ROSQuaternion, + ) + from std_msgs.msg import Header as ROSHeader +except ImportError: + ROSHeader = None + ROSPoseWithCovarianceStamped = None + ROSPose = None + ROSQuaternion = None + ROSPoint = None + ROSTime = None + ROSPoseWithCovariance = None + + +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.PoseWithCovariance import PoseWithCovariance +from dimos.msgs.geometry_msgs.PoseWithCovarianceStamped import PoseWithCovarianceStamped + + +def test_pose_with_covariance_stamped_default_init() -> None: + """Test default initialization.""" + if ROSPoseWithCovariance is None: + pytest.skip("ROS not available") + if ROSTime is None: + pytest.skip("ROS not available") + if ROSPoint is None: + pytest.skip("ROS not available") + if ROSQuaternion is None: + pytest.skip("ROS not available") + if ROSPose is None: + pytest.skip("ROS not available") + if ROSPoseWithCovarianceStamped is None: + pytest.skip("ROS not available") + if ROSHeader is None: + pytest.skip("ROS not available") + pose_cov_stamped = PoseWithCovarianceStamped() + + # Should have current timestamp + assert pose_cov_stamped.ts > 0 + assert pose_cov_stamped.frame_id == "" + + # Pose should be at origin with identity orientation + assert pose_cov_stamped.pose.position.x == 0.0 + assert pose_cov_stamped.pose.position.y == 0.0 + assert pose_cov_stamped.pose.position.z == 0.0 + assert pose_cov_stamped.pose.orientation.w == 1.0 + + # Covariance should be all zeros + assert np.all(pose_cov_stamped.covariance == 0.0) + + +def test_pose_with_covariance_stamped_with_timestamp() -> None: + """Test initialization with specific timestamp.""" + ts = 1234567890.123456 + frame_id = "base_link" + pose_cov_stamped = PoseWithCovarianceStamped(ts=ts, frame_id=frame_id) + + assert pose_cov_stamped.ts == ts + assert pose_cov_stamped.frame_id == frame_id + + +def test_pose_with_covariance_stamped_with_pose() -> None: + """Test initialization with pose.""" + ts = 1234567890.123456 + frame_id = "map" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + covariance = np.arange(36, dtype=float) + + pose_cov_stamped = PoseWithCovarianceStamped( + ts=ts, frame_id=frame_id, pose=pose, covariance=covariance + ) + + assert pose_cov_stamped.ts == ts + assert pose_cov_stamped.frame_id == frame_id + assert pose_cov_stamped.pose.position.x == 1.0 + assert pose_cov_stamped.pose.position.y == 2.0 + assert pose_cov_stamped.pose.position.z == 3.0 + assert np.array_equal(pose_cov_stamped.covariance, covariance) + + +def test_pose_with_covariance_stamped_properties() -> None: + """Test convenience properties.""" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + covariance = np.eye(6).flatten() + pose_cov_stamped = PoseWithCovarianceStamped( + ts=1234567890.0, frame_id="odom", pose=pose, covariance=covariance + ) + + # Position properties + assert pose_cov_stamped.x == 1.0 + assert pose_cov_stamped.y == 2.0 + assert pose_cov_stamped.z == 3.0 + + # Orientation properties + assert pose_cov_stamped.orientation.x == 0.1 + assert pose_cov_stamped.orientation.y == 0.2 + assert pose_cov_stamped.orientation.z == 0.3 + assert pose_cov_stamped.orientation.w == 0.9 + + # Euler angles + assert pose_cov_stamped.roll == pose.roll + assert pose_cov_stamped.pitch == pose.pitch + assert pose_cov_stamped.yaw == pose.yaw + + # Covariance matrix + cov_matrix = pose_cov_stamped.covariance_matrix + assert cov_matrix.shape == (6, 6) + assert np.trace(cov_matrix) == 6.0 + + +def test_pose_with_covariance_stamped_str() -> None: + """Test string representation.""" + pose = Pose(1.234, 2.567, 3.891) + covariance = np.eye(6).flatten() * 2.0 + pose_cov_stamped = PoseWithCovarianceStamped( + ts=1234567890.0, frame_id="world", pose=pose, covariance=covariance + ) + + str_repr = str(pose_cov_stamped) + assert "PoseWithCovarianceStamped" in str_repr + assert "1.234" in str_repr + assert "2.567" in str_repr + assert "3.891" in str_repr + assert "cov_trace" in str_repr + assert "12.000" in str_repr # Trace of 2*identity is 12 + + +def test_pose_with_covariance_stamped_lcm_encode_decode() -> None: + """Test LCM encoding and decoding.""" + ts = 1234567890.123456 + frame_id = "camera_link" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + covariance = np.arange(36, dtype=float) + + source = PoseWithCovarianceStamped(ts=ts, frame_id=frame_id, pose=pose, covariance=covariance) + + # Encode and decode + binary_msg = source.lcm_encode() + decoded = PoseWithCovarianceStamped.lcm_decode(binary_msg) + + # Check timestamp (may lose some precision) + assert abs(decoded.ts - ts) < 1e-6 + assert decoded.frame_id == frame_id + + # Check pose + assert decoded.pose.position.x == 1.0 + assert decoded.pose.position.y == 2.0 + assert decoded.pose.position.z == 3.0 + assert decoded.pose.orientation.x == 0.1 + assert decoded.pose.orientation.y == 0.2 + assert decoded.pose.orientation.z == 0.3 + assert decoded.pose.orientation.w == 0.9 + + # Check covariance + assert np.array_equal(decoded.covariance, covariance) + + +@pytest.mark.ros +def test_pose_with_covariance_stamped_from_ros_msg() -> None: + """Test creating from ROS message.""" + ros_msg = ROSPoseWithCovarianceStamped() + + # Set header + ros_msg.header = ROSHeader() + ros_msg.header.stamp = ROSTime() + ros_msg.header.stamp.sec = 1234567890 + ros_msg.header.stamp.nanosec = 123456000 + ros_msg.header.frame_id = "laser" + + # Set pose with covariance + ros_msg.pose = ROSPoseWithCovariance() + ros_msg.pose.pose = ROSPose() + ros_msg.pose.pose.position = ROSPoint(x=1.0, y=2.0, z=3.0) + ros_msg.pose.pose.orientation = ROSQuaternion(x=0.1, y=0.2, z=0.3, w=0.9) + ros_msg.pose.covariance = [float(i) for i in range(36)] + + pose_cov_stamped = PoseWithCovarianceStamped.from_ros_msg(ros_msg) + + assert pose_cov_stamped.ts == 1234567890.123456 + assert pose_cov_stamped.frame_id == "laser" + assert pose_cov_stamped.pose.position.x == 1.0 + assert pose_cov_stamped.pose.position.y == 2.0 + assert pose_cov_stamped.pose.position.z == 3.0 + assert pose_cov_stamped.pose.orientation.x == 0.1 + assert pose_cov_stamped.pose.orientation.y == 0.2 + assert pose_cov_stamped.pose.orientation.z == 0.3 + assert pose_cov_stamped.pose.orientation.w == 0.9 + assert np.array_equal(pose_cov_stamped.covariance, np.arange(36)) + + +@pytest.mark.ros +def test_pose_with_covariance_stamped_to_ros_msg() -> None: + """Test converting to ROS message.""" + ts = 1234567890.567890 + frame_id = "imu" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + covariance = np.arange(36, dtype=float) + + pose_cov_stamped = PoseWithCovarianceStamped( + ts=ts, frame_id=frame_id, pose=pose, covariance=covariance + ) + + ros_msg = pose_cov_stamped.to_ros_msg() + + assert isinstance(ros_msg, ROSPoseWithCovarianceStamped) + assert ros_msg.header.frame_id == frame_id + assert ros_msg.header.stamp.sec == 1234567890 + assert abs(ros_msg.header.stamp.nanosec - 567890000) < 100 # Allow small rounding error + + assert ros_msg.pose.pose.position.x == 1.0 + assert ros_msg.pose.pose.position.y == 2.0 + assert ros_msg.pose.pose.position.z == 3.0 + assert ros_msg.pose.pose.orientation.x == 0.1 + assert ros_msg.pose.pose.orientation.y == 0.2 + assert ros_msg.pose.pose.orientation.z == 0.3 + assert ros_msg.pose.pose.orientation.w == 0.9 + assert list(ros_msg.pose.covariance) == list(range(36)) + + +@pytest.mark.ros +def test_pose_with_covariance_stamped_ros_roundtrip() -> None: + """Test round-trip conversion with ROS messages.""" + ts = 2147483647.987654 # Max int32 value for ROS Time.sec + frame_id = "robot_base" + pose = Pose(1.5, 2.5, 3.5, 0.15, 0.25, 0.35, 0.85) + covariance = np.random.rand(36) + + original = PoseWithCovarianceStamped(ts=ts, frame_id=frame_id, pose=pose, covariance=covariance) + + ros_msg = original.to_ros_msg() + restored = PoseWithCovarianceStamped.from_ros_msg(ros_msg) + + # Check timestamp (loses some precision in conversion) + assert abs(restored.ts - ts) < 1e-6 + assert restored.frame_id == frame_id + + # Check pose + assert restored.pose.position.x == original.pose.position.x + assert restored.pose.position.y == original.pose.position.y + assert restored.pose.position.z == original.pose.position.z + assert restored.pose.orientation.x == original.pose.orientation.x + assert restored.pose.orientation.y == original.pose.orientation.y + assert restored.pose.orientation.z == original.pose.orientation.z + assert restored.pose.orientation.w == original.pose.orientation.w + + # Check covariance + assert np.allclose(restored.covariance, original.covariance) + + +def test_pose_with_covariance_stamped_zero_timestamp() -> None: + """Test that zero timestamp gets replaced with current time.""" + pose_cov_stamped = PoseWithCovarianceStamped(ts=0.0) + + # Should have been replaced with current time + assert pose_cov_stamped.ts > 0 + assert pose_cov_stamped.ts <= time.time() + + +def test_pose_with_covariance_stamped_inheritance() -> None: + """Test that it properly inherits from PoseWithCovariance and Timestamped.""" + pose = Pose(1.0, 2.0, 3.0) + covariance = np.eye(6).flatten() + pose_cov_stamped = PoseWithCovarianceStamped( + ts=1234567890.0, frame_id="test", pose=pose, covariance=covariance + ) + + # Should be instance of parent classes + assert isinstance(pose_cov_stamped, PoseWithCovariance) + + # Should have Timestamped attributes + assert hasattr(pose_cov_stamped, "ts") + assert hasattr(pose_cov_stamped, "frame_id") + + # Should have PoseWithCovariance attributes + assert hasattr(pose_cov_stamped, "pose") + assert hasattr(pose_cov_stamped, "covariance") + + +def test_pose_with_covariance_stamped_sec_nsec() -> None: + """Test the sec_nsec helper function.""" + from dimos.msgs.geometry_msgs.PoseWithCovarianceStamped import sec_nsec + + # Test integer seconds + s, ns = sec_nsec(1234567890.0) + assert s == 1234567890 + assert ns == 0 + + # Test fractional seconds + s, ns = sec_nsec(1234567890.123456789) + assert s == 1234567890 + assert abs(ns - 123456789) < 100 # Allow small rounding error + + # Test small fractional seconds + s, ns = sec_nsec(0.000000001) + assert s == 0 + assert ns == 1 + + # Test large timestamp + s, ns = sec_nsec(9999999999.999999999) + # Due to floating point precision, this might round to 10000000000 + assert s in [9999999999, 10000000000] + if s == 9999999999: + assert abs(ns - 999999999) < 10 + else: + assert ns == 0 + + +@pytest.mark.ros +@pytest.mark.parametrize( + "frame_id", + ["", "map", "odom", "base_link", "camera_optical_frame", "sensor/lidar/front"], +) +def test_pose_with_covariance_stamped_frame_ids(frame_id) -> None: + """Test various frame ID values.""" + pose_cov_stamped = PoseWithCovarianceStamped(frame_id=frame_id) + assert pose_cov_stamped.frame_id == frame_id + + # Test roundtrip through ROS + ros_msg = pose_cov_stamped.to_ros_msg() + assert ros_msg.header.frame_id == frame_id + + restored = PoseWithCovarianceStamped.from_ros_msg(ros_msg) + assert restored.frame_id == frame_id + + +def test_pose_with_covariance_stamped_different_covariances() -> None: + """Test with different covariance patterns.""" + pose = Pose(1.0, 2.0, 3.0) + + # Zero covariance + zero_cov = np.zeros(36) + pose_cov1 = PoseWithCovarianceStamped(pose=pose, covariance=zero_cov) + assert np.all(pose_cov1.covariance == 0.0) + + # Identity covariance + identity_cov = np.eye(6).flatten() + pose_cov2 = PoseWithCovarianceStamped(pose=pose, covariance=identity_cov) + assert np.trace(pose_cov2.covariance_matrix) == 6.0 + + # Full covariance + full_cov = np.random.rand(36) + pose_cov3 = PoseWithCovarianceStamped(pose=pose, covariance=full_cov) + assert np.array_equal(pose_cov3.covariance, full_cov) diff --git a/dimos/msgs/geometry_msgs/test_Quaternion.py b/dimos/msgs/geometry_msgs/test_Quaternion.py new file mode 100644 index 0000000000..21c1e8caeb --- /dev/null +++ b/dimos/msgs/geometry_msgs/test_Quaternion.py @@ -0,0 +1,387 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos_lcm.geometry_msgs import Quaternion as LCMQuaternion +import numpy as np +import pytest + +from dimos.msgs.geometry_msgs.Quaternion import Quaternion + + +def test_quaternion_default_init() -> None: + """Test that default initialization creates an identity quaternion (w=1, x=y=z=0).""" + q = Quaternion() + assert q.x == 0.0 + assert q.y == 0.0 + assert q.z == 0.0 + assert q.w == 1.0 + assert q.to_tuple() == (0.0, 0.0, 0.0, 1.0) + + +def test_quaternion_component_init() -> None: + """Test initialization with four float components (x, y, z, w).""" + q = Quaternion(0.5, 0.5, 0.5, 0.5) + assert q.x == 0.5 + assert q.y == 0.5 + assert q.z == 0.5 + assert q.w == 0.5 + + # Test with different values + q2 = Quaternion(1.0, 2.0, 3.0, 4.0) + assert q2.x == 1.0 + assert q2.y == 2.0 + assert q2.z == 3.0 + assert q2.w == 4.0 + + # Test with negative values + q3 = Quaternion(-1.0, -2.0, -3.0, -4.0) + assert q3.x == -1.0 + assert q3.y == -2.0 + assert q3.z == -3.0 + assert q3.w == -4.0 + + # Test with integers (should convert to float) + q4 = Quaternion(1, 2, 3, 4) + assert q4.x == 1.0 + assert q4.y == 2.0 + assert q4.z == 3.0 + assert q4.w == 4.0 + assert isinstance(q4.x, float) + + +def test_quaternion_sequence_init() -> None: + """Test initialization from sequence (list, tuple) of 4 numbers.""" + # From list + q1 = Quaternion([0.1, 0.2, 0.3, 0.4]) + assert q1.x == 0.1 + assert q1.y == 0.2 + assert q1.z == 0.3 + assert q1.w == 0.4 + + # From tuple + q2 = Quaternion((0.5, 0.6, 0.7, 0.8)) + assert q2.x == 0.5 + assert q2.y == 0.6 + assert q2.z == 0.7 + assert q2.w == 0.8 + + # Test with integers in sequence + q3 = Quaternion([1, 2, 3, 4]) + assert q3.x == 1.0 + assert q3.y == 2.0 + assert q3.z == 3.0 + assert q3.w == 4.0 + + # Test error with wrong length + with pytest.raises(ValueError, match="Quaternion requires exactly 4 components"): + Quaternion([1, 2, 3]) # Only 3 components + + with pytest.raises(ValueError, match="Quaternion requires exactly 4 components"): + Quaternion([1, 2, 3, 4, 5]) # Too many components + + +def test_quaternion_numpy_init() -> None: + """Test initialization from numpy array.""" + # From numpy array + arr = np.array([0.1, 0.2, 0.3, 0.4]) + q1 = Quaternion(arr) + assert q1.x == 0.1 + assert q1.y == 0.2 + assert q1.z == 0.3 + assert q1.w == 0.4 + + # Test with different dtypes + arr_int = np.array([1, 2, 3, 4], dtype=int) + q2 = Quaternion(arr_int) + assert q2.x == 1.0 + assert q2.y == 2.0 + assert q2.z == 3.0 + assert q2.w == 4.0 + + # Test error with wrong size + with pytest.raises(ValueError, match="Quaternion requires exactly 4 components"): + Quaternion(np.array([1, 2, 3])) # Only 3 elements + + with pytest.raises(ValueError, match="Quaternion requires exactly 4 components"): + Quaternion(np.array([1, 2, 3, 4, 5])) # Too many elements + + +def test_quaternion_copy_init() -> None: + """Test initialization from another Quaternion (copy constructor).""" + original = Quaternion(0.1, 0.2, 0.3, 0.4) + copy = Quaternion(original) + + assert copy.x == 0.1 + assert copy.y == 0.2 + assert copy.z == 0.3 + assert copy.w == 0.4 + + # Verify it's a copy, not the same object + assert copy is not original + assert copy == original + + +def test_quaternion_lcm_init() -> None: + """Test initialization from LCM Quaternion.""" + lcm_quat = LCMQuaternion() + lcm_quat.x = 0.1 + lcm_quat.y = 0.2 + lcm_quat.z = 0.3 + lcm_quat.w = 0.4 + + q = Quaternion(lcm_quat) + assert q.x == 0.1 + assert q.y == 0.2 + assert q.z == 0.3 + assert q.w == 0.4 + + +def test_quaternion_properties() -> None: + """Test quaternion component properties.""" + q = Quaternion(1.0, 2.0, 3.0, 4.0) + + # Test property access + assert q.x == 1.0 + assert q.y == 2.0 + assert q.z == 3.0 + assert q.w == 4.0 + + # Test as_tuple property + assert q.to_tuple() == (1.0, 2.0, 3.0, 4.0) + + +def test_quaternion_indexing() -> None: + """Test quaternion indexing support.""" + q = Quaternion(1.0, 2.0, 3.0, 4.0) + + # Test indexing + assert q[0] == 1.0 + assert q[1] == 2.0 + assert q[2] == 3.0 + assert q[3] == 4.0 + + +def test_quaternion_euler() -> None: + """Test quaternion to Euler angles conversion.""" + + # Test identity quaternion (should give zero angles) + q_identity = Quaternion() + angles = q_identity.to_euler() + assert np.isclose(angles.x, 0.0, atol=1e-10) # roll + assert np.isclose(angles.y, 0.0, atol=1e-10) # pitch + assert np.isclose(angles.z, 0.0, atol=1e-10) # yaw + + # Test 90 degree rotation around Z-axis (yaw) + q_z90 = Quaternion(0, 0, np.sin(np.pi / 4), np.cos(np.pi / 4)) + angles_z90 = q_z90.to_euler() + assert np.isclose(angles_z90.roll, 0.0, atol=1e-10) # roll should be 0 + assert np.isclose(angles_z90.pitch, 0.0, atol=1e-10) # pitch should be 0 + assert np.isclose(angles_z90.yaw, np.pi / 2, atol=1e-10) # yaw should be π/2 (90 degrees) + + # Test 90 degree rotation around X-axis (roll) + q_x90 = Quaternion(np.sin(np.pi / 4), 0, 0, np.cos(np.pi / 4)) + angles_x90 = q_x90.to_euler() + assert np.isclose(angles_x90.x, np.pi / 2, atol=1e-10) # roll should be π/2 + assert np.isclose(angles_x90.y, 0.0, atol=1e-10) # pitch should be 0 + assert np.isclose(angles_x90.z, 0.0, atol=1e-10) # yaw should be 0 + + +def test_lcm_encode_decode() -> None: + """Test encoding and decoding of Quaternion to/from binary LCM format.""" + q_source = Quaternion(1.0, 2.0, 3.0, 4.0) + + binary_msg = q_source.lcm_encode() + + q_dest = Quaternion.lcm_decode(binary_msg) + + assert isinstance(q_dest, Quaternion) + assert q_dest is not q_source + assert q_dest == q_source + + +def test_quaternion_multiplication() -> None: + """Test quaternion multiplication (Hamilton product).""" + # Test identity multiplication + q1 = Quaternion(0.5, 0.5, 0.5, 0.5) + identity = Quaternion(0, 0, 0, 1) + + result = q1 * identity + assert np.allclose([result.x, result.y, result.z, result.w], [q1.x, q1.y, q1.z, q1.w]) + + # Test multiplication order matters (non-commutative) + q2 = Quaternion(0.1, 0.2, 0.3, 0.4) + q3 = Quaternion(0.4, 0.3, 0.2, 0.1) + + result1 = q2 * q3 + result2 = q3 * q2 + + # Results should be different + assert not np.allclose( + [result1.x, result1.y, result1.z, result1.w], [result2.x, result2.y, result2.z, result2.w] + ) + + # Test specific multiplication case + # 90 degree rotations around Z axis + angle = np.pi / 2 + q_90z = Quaternion(0, 0, np.sin(angle / 2), np.cos(angle / 2)) + + # Two 90 degree rotations should give 180 degrees + result = q_90z * q_90z + expected_angle = np.pi + assert np.isclose(result.x, 0, atol=1e-10) + assert np.isclose(result.y, 0, atol=1e-10) + assert np.isclose(result.z, np.sin(expected_angle / 2), atol=1e-10) + assert np.isclose(result.w, np.cos(expected_angle / 2), atol=1e-10) + + +def test_quaternion_conjugate() -> None: + """Test quaternion conjugate.""" + q = Quaternion(0.1, 0.2, 0.3, 0.4) + conj = q.conjugate() + + # Conjugate should negate x, y, z but keep w + assert conj.x == -q.x + assert conj.y == -q.y + assert conj.z == -q.z + assert conj.w == q.w + + # Test that q * q^* gives a real quaternion (x=y=z=0) + result = q * conj + assert np.isclose(result.x, 0, atol=1e-10) + assert np.isclose(result.y, 0, atol=1e-10) + assert np.isclose(result.z, 0, atol=1e-10) + # w should be the squared norm + expected_w = q.x**2 + q.y**2 + q.z**2 + q.w**2 + assert np.isclose(result.w, expected_w, atol=1e-10) + + +def test_quaternion_inverse() -> None: + """Test quaternion inverse.""" + # Test with unit quaternion + q_unit = Quaternion(0, 0, 0, 1).normalize() # Already normalized but being explicit + inv = q_unit.inverse() + + # For unit quaternion, inverse equals conjugate + conj = q_unit.conjugate() + assert np.allclose([inv.x, inv.y, inv.z, inv.w], [conj.x, conj.y, conj.z, conj.w]) + + # Test that q * q^-1 = identity + q = Quaternion(0.5, 0.5, 0.5, 0.5) + inv = q.inverse() + result = q * inv + + assert np.isclose(result.x, 0, atol=1e-10) + assert np.isclose(result.y, 0, atol=1e-10) + assert np.isclose(result.z, 0, atol=1e-10) + assert np.isclose(result.w, 1, atol=1e-10) + + # Test inverse of non-unit quaternion + q_non_unit = Quaternion(2, 0, 0, 0) # Non-unit quaternion + inv = q_non_unit.inverse() + result = q_non_unit * inv + + assert np.isclose(result.x, 0, atol=1e-10) + assert np.isclose(result.y, 0, atol=1e-10) + assert np.isclose(result.z, 0, atol=1e-10) + assert np.isclose(result.w, 1, atol=1e-10) + + +def test_quaternion_normalize() -> None: + """Test quaternion normalization.""" + # Test non-unit quaternion + q = Quaternion(1, 2, 3, 4) + q_norm = q.normalize() + + # Check that magnitude is 1 + magnitude = np.sqrt(q_norm.x**2 + q_norm.y**2 + q_norm.z**2 + q_norm.w**2) + assert np.isclose(magnitude, 1.0, atol=1e-10) + + # Check that direction is preserved + scale = np.sqrt(q.x**2 + q.y**2 + q.z**2 + q.w**2) + assert np.isclose(q_norm.x, q.x / scale, atol=1e-10) + assert np.isclose(q_norm.y, q.y / scale, atol=1e-10) + assert np.isclose(q_norm.z, q.z / scale, atol=1e-10) + assert np.isclose(q_norm.w, q.w / scale, atol=1e-10) + + +def test_quaternion_rotate_vector() -> None: + """Test rotating vectors with quaternions.""" + from dimos.msgs.geometry_msgs.Vector3 import Vector3 + + # Test rotation of unit vectors + # 90 degree rotation around Z axis + angle = np.pi / 2 + q_rot = Quaternion(0, 0, np.sin(angle / 2), np.cos(angle / 2)) + + # Rotate X unit vector + v_x = Vector3(1, 0, 0) + v_rotated = q_rot.rotate_vector(v_x) + + # Should now point along Y axis + assert np.isclose(v_rotated.x, 0, atol=1e-10) + assert np.isclose(v_rotated.y, 1, atol=1e-10) + assert np.isclose(v_rotated.z, 0, atol=1e-10) + + # Rotate Y unit vector + v_y = Vector3(0, 1, 0) + v_rotated = q_rot.rotate_vector(v_y) + + # Should now point along negative X axis + assert np.isclose(v_rotated.x, -1, atol=1e-10) + assert np.isclose(v_rotated.y, 0, atol=1e-10) + assert np.isclose(v_rotated.z, 0, atol=1e-10) + + # Test that Z vector is unchanged (rotation axis) + v_z = Vector3(0, 0, 1) + v_rotated = q_rot.rotate_vector(v_z) + + assert np.isclose(v_rotated.x, 0, atol=1e-10) + assert np.isclose(v_rotated.y, 0, atol=1e-10) + assert np.isclose(v_rotated.z, 1, atol=1e-10) + + # Test identity rotation + q_identity = Quaternion(0, 0, 0, 1) + v = Vector3(1, 2, 3) + v_rotated = q_identity.rotate_vector(v) + + assert np.isclose(v_rotated.x, v.x, atol=1e-10) + assert np.isclose(v_rotated.y, v.y, atol=1e-10) + assert np.isclose(v_rotated.z, v.z, atol=1e-10) + + +def test_quaternion_inverse_zero() -> None: + """Test that inverting zero quaternion raises error.""" + q_zero = Quaternion(0, 0, 0, 0) + + with pytest.raises(ZeroDivisionError, match="Cannot invert zero quaternion"): + q_zero.inverse() + + +def test_quaternion_normalize_zero() -> None: + """Test that normalizing zero quaternion raises error.""" + q_zero = Quaternion(0, 0, 0, 0) + + with pytest.raises(ZeroDivisionError, match="Cannot normalize zero quaternion"): + q_zero.normalize() + + +def test_quaternion_multiplication_type_error() -> None: + """Test that multiplying quaternion with non-quaternion raises error.""" + q = Quaternion(1, 0, 0, 0) + + with pytest.raises(TypeError, match="Cannot multiply Quaternion with"): + q * 5.0 + + with pytest.raises(TypeError, match="Cannot multiply Quaternion with"): + q * [1, 2, 3, 4] diff --git a/dimos/msgs/geometry_msgs/test_Transform.py b/dimos/msgs/geometry_msgs/test_Transform.py new file mode 100644 index 0000000000..be3baee6cb --- /dev/null +++ b/dimos/msgs/geometry_msgs/test_Transform.py @@ -0,0 +1,510 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import time + +import numpy as np +import pytest + +try: + from geometry_msgs.msg import TransformStamped as ROSTransformStamped +except ImportError: + ROSTransformStamped = None + + +from dimos.msgs.geometry_msgs import Pose, PoseStamped, Quaternion, Transform, Vector3 + + +def test_transform_initialization() -> None: + # Test default initialization (identity transform) + tf = Transform() + assert tf.translation.x == 0.0 + assert tf.translation.y == 0.0 + assert tf.translation.z == 0.0 + assert tf.rotation.x == 0.0 + assert tf.rotation.y == 0.0 + assert tf.rotation.z == 0.0 + assert tf.rotation.w == 1.0 + + # Test initialization with Vector3 and Quaternion + trans = Vector3(1.0, 2.0, 3.0) + rot = Quaternion(0.0, 0.0, 0.707107, 0.707107) # 90 degrees around Z + tf2 = Transform(translation=trans, rotation=rot) + assert tf2.translation == trans + assert tf2.rotation == rot + + # Test initialization with only translation + tf5 = Transform(translation=Vector3(7.0, 8.0, 9.0)) + assert tf5.translation.x == 7.0 + assert tf5.translation.y == 8.0 + assert tf5.translation.z == 9.0 + assert tf5.rotation.w == 1.0 # Identity rotation + + # Test initialization with only rotation + tf6 = Transform(rotation=Quaternion(0.0, 0.0, 0.0, 1.0)) + assert tf6.translation.is_zero() # Zero translation + assert tf6.rotation.w == 1.0 + + # Test keyword argument initialization + tf7 = Transform(translation=Vector3(1, 2, 3), rotation=Quaternion()) + assert tf7.translation == Vector3(1, 2, 3) + assert tf7.rotation == Quaternion() + + # Test keyword with only translation + tf8 = Transform(translation=Vector3(4, 5, 6)) + assert tf8.translation == Vector3(4, 5, 6) + assert tf8.rotation.w == 1.0 + + # Test keyword with only rotation + tf9 = Transform(rotation=Quaternion(0, 0, 1, 0)) + assert tf9.translation.is_zero() + assert tf9.rotation == Quaternion(0, 0, 1, 0) + + +def test_transform_identity() -> None: + # Test identity class method + tf = Transform.identity() + assert tf.translation.is_zero() + assert tf.rotation.x == 0.0 + assert tf.rotation.y == 0.0 + assert tf.rotation.z == 0.0 + assert tf.rotation.w == 1.0 + + # Identity should equal default constructor + assert tf == Transform() + + +def test_transform_equality() -> None: + tf1 = Transform(translation=Vector3(1, 2, 3), rotation=Quaternion(0, 0, 0, 1)) + tf2 = Transform(translation=Vector3(1, 2, 3), rotation=Quaternion(0, 0, 0, 1)) + tf3 = Transform(translation=Vector3(1, 2, 4), rotation=Quaternion(0, 0, 0, 1)) # Different z + tf4 = Transform( + translation=Vector3(1, 2, 3), rotation=Quaternion(0, 0, 1, 0) + ) # Different rotation + + assert tf1 == tf2 + assert tf1 != tf3 + assert tf1 != tf4 + assert tf1 != "not a transform" + + +def test_transform_string_representations() -> None: + tf = Transform( + translation=Vector3(1.5, -2.0, 3.14), rotation=Quaternion(0, 0, 0.707107, 0.707107) + ) + + # Test repr + repr_str = repr(tf) + assert "Transform" in repr_str + assert "translation=" in repr_str + assert "rotation=" in repr_str + assert "1.5" in repr_str + + # Test str + str_str = str(tf) + assert "Transform:" in str_str + assert "Translation:" in str_str + assert "Rotation:" in str_str + + +def test_pose_add_transform() -> None: + initial_pose = Pose(1.0, 0.0, 0.0) + + # 90 degree rotation around Z axis + angle = np.pi / 2 + transform = Transform( + translation=Vector3(2.0, 1.0, 0.0), + rotation=Quaternion(0.0, 0.0, np.sin(angle / 2), np.cos(angle / 2)), + ) + + transformed_pose = initial_pose @ transform + + # - Translation (2, 1, 0) is added directly to position (1, 0, 0) + # - Result position: (3, 1, 0) + assert np.isclose(transformed_pose.position.x, 3.0, atol=1e-10) + assert np.isclose(transformed_pose.position.y, 1.0, atol=1e-10) + assert np.isclose(transformed_pose.position.z, 0.0, atol=1e-10) + + # Rotation should be 90 degrees around Z + assert np.isclose(transformed_pose.orientation.x, 0.0, atol=1e-10) + assert np.isclose(transformed_pose.orientation.y, 0.0, atol=1e-10) + assert np.isclose(transformed_pose.orientation.z, np.sin(angle / 2), atol=1e-10) + assert np.isclose(transformed_pose.orientation.w, np.cos(angle / 2), atol=1e-10) + + initial_pose_stamped = PoseStamped( + position=initial_pose.position, orientation=initial_pose.orientation + ) + transformed_pose_stamped = PoseStamped( + position=transformed_pose.position, orientation=transformed_pose.orientation + ) + + found_tf = initial_pose_stamped.find_transform(transformed_pose_stamped) + + assert found_tf.translation == transform.translation + assert found_tf.rotation == transform.rotation + assert found_tf.translation.x == transform.translation.x + assert found_tf.translation.y == transform.translation.y + assert found_tf.translation.z == transform.translation.z + + assert found_tf.rotation.x == transform.rotation.x + assert found_tf.rotation.y == transform.rotation.y + assert found_tf.rotation.z == transform.rotation.z + assert found_tf.rotation.w == transform.rotation.w + + print(found_tf.rotation, found_tf.translation) + + +def test_pose_add_transform_with_rotation() -> None: + # Create a pose at (0, 0, 0) rotated 90 degrees around Z + angle = np.pi / 2 + initial_pose = Pose(0.0, 0.0, 0.0, 0.0, 0.0, np.sin(angle / 2), np.cos(angle / 2)) + + # Add 45 degree rotation to transform1 + rotation_angle = np.pi / 4 # 45 degrees + transform1 = Transform( + translation=Vector3(1.0, 0.0, 0.0), + rotation=Quaternion( + 0.0, 0.0, np.sin(rotation_angle / 2), np.cos(rotation_angle / 2) + ), # 45� around Z + ) + + transform2 = Transform( + translation=Vector3(0.0, 1.0, 1.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), # No rotation + ) + + transformed_pose1 = initial_pose @ transform1 + transformed_pose2 = initial_pose @ transform1 @ transform2 + + # Test transformed_pose1: initial_pose + transform1 + # Since the pose is rotated 90� (facing +Y), moving forward (local X) + # means moving in the +Y direction in world frame + assert np.isclose(transformed_pose1.position.x, 0.0, atol=1e-10) + assert np.isclose(transformed_pose1.position.y, 1.0, atol=1e-10) + assert np.isclose(transformed_pose1.position.z, 0.0, atol=1e-10) + + # Orientation should be 90� + 45� = 135� around Z + total_angle1 = angle + rotation_angle # 135 degrees + assert np.isclose(transformed_pose1.orientation.x, 0.0, atol=1e-10) + assert np.isclose(transformed_pose1.orientation.y, 0.0, atol=1e-10) + assert np.isclose(transformed_pose1.orientation.z, np.sin(total_angle1 / 2), atol=1e-10) + assert np.isclose(transformed_pose1.orientation.w, np.cos(total_angle1 / 2), atol=1e-10) + + # Test transformed_pose2: initial_pose + transform1 + transform2 + # Starting from (0, 0, 0) facing 90�: + # + # - Apply transform1: move 1 forward (along +Y) � (0, 1, 0), now facing 135� + # + # - Apply transform2: move 1 in local Y and 1 up + # At 135�, local Y points at 225� (135� + 90�) + # + # x += cos(225�) = -2/2, y += sin(225�) = -2/2 + sqrt2_2 = np.sqrt(2) / 2 + expected_x = 0.0 - sqrt2_2 # 0 - 2/2 H -0.707 + expected_y = 1.0 - sqrt2_2 # 1 - 2/2 H 0.293 + expected_z = 1.0 # 0 + 1 + + assert np.isclose(transformed_pose2.position.x, expected_x, atol=1e-10) + assert np.isclose(transformed_pose2.position.y, expected_y, atol=1e-10) + assert np.isclose(transformed_pose2.position.z, expected_z, atol=1e-10) + + # Orientation should be 135� (only transform1 has rotation) + total_angle2 = total_angle1 # 135 degrees (transform2 has no rotation) + assert np.isclose(transformed_pose2.orientation.x, 0.0, atol=1e-10) + assert np.isclose(transformed_pose2.orientation.y, 0.0, atol=1e-10) + assert np.isclose(transformed_pose2.orientation.z, np.sin(total_angle2 / 2), atol=1e-10) + assert np.isclose(transformed_pose2.orientation.w, np.cos(total_angle2 / 2), atol=1e-10) + + +def test_lcm_encode_decode() -> None: + angle = np.pi / 2 + transform = Transform( + translation=Vector3(2.0, 1.0, 0.0), + rotation=Quaternion(0.0, 0.0, np.sin(angle / 2), np.cos(angle / 2)), + ) + + data = transform.lcm_encode() + + decoded_transform = Transform.lcm_decode(data) + + assert decoded_transform == transform + + +def test_transform_addition() -> None: + # Test 1: Simple translation addition (no rotation) + t1 = Transform( + translation=Vector3(1, 0, 0), + rotation=Quaternion(0, 0, 0, 1), # identity rotation + ) + t2 = Transform( + translation=Vector3(2, 0, 0), + rotation=Quaternion(0, 0, 0, 1), # identity rotation + ) + t3 = t1 + t2 + assert t3.translation == Vector3(3, 0, 0) + assert t3.rotation == Quaternion(0, 0, 0, 1) + + # Test 2: 90-degree rotation composition + # First transform: move 1 unit in X + t1 = Transform( + translation=Vector3(1, 0, 0), + rotation=Quaternion(0, 0, 0, 1), # identity + ) + # Second transform: move 1 unit in X with 90-degree rotation around Z + angle = np.pi / 2 + t2 = Transform( + translation=Vector3(1, 0, 0), + rotation=Quaternion(0, 0, np.sin(angle / 2), np.cos(angle / 2)), + ) + t3 = t1 + t2 + assert t3.translation == Vector3(2, 0, 0) + # Rotation should be 90 degrees around Z + assert np.isclose(t3.rotation.x, 0.0, atol=1e-10) + assert np.isclose(t3.rotation.y, 0.0, atol=1e-10) + assert np.isclose(t3.rotation.z, np.sin(angle / 2), atol=1e-10) + assert np.isclose(t3.rotation.w, np.cos(angle / 2), atol=1e-10) + + # Test 3: Rotation affects translation + # First transform: 90-degree rotation around Z + t1 = Transform( + translation=Vector3(0, 0, 0), + rotation=Quaternion(0, 0, np.sin(angle / 2), np.cos(angle / 2)), # 90° around Z + ) + # Second transform: move 1 unit in X + t2 = Transform( + translation=Vector3(1, 0, 0), + rotation=Quaternion(0, 0, 0, 1), # identity + ) + t3 = t1 + t2 + # X direction rotated 90° becomes Y direction + assert np.isclose(t3.translation.x, 0.0, atol=1e-10) + assert np.isclose(t3.translation.y, 1.0, atol=1e-10) + assert np.isclose(t3.translation.z, 0.0, atol=1e-10) + # Rotation remains 90° around Z + assert np.isclose(t3.rotation.z, np.sin(angle / 2), atol=1e-10) + assert np.isclose(t3.rotation.w, np.cos(angle / 2), atol=1e-10) + + # Test 4: Frame tracking + t1 = Transform( + translation=Vector3(1, 0, 0), + rotation=Quaternion(0, 0, 0, 1), + frame_id="world", + child_frame_id="robot", + ) + t2 = Transform( + translation=Vector3(2, 0, 0), + rotation=Quaternion(0, 0, 0, 1), + frame_id="robot", + child_frame_id="sensor", + ) + t3 = t1 + t2 + assert t3.frame_id == "world" + assert t3.child_frame_id == "sensor" + + # Test 5: Type error + with pytest.raises(TypeError): + t1 + "not a transform" + + +def test_transform_from_pose() -> None: + """Test converting Pose to Transform""" + # Create a Pose with position and orientation + pose = Pose( + position=Vector3(1.0, 2.0, 3.0), + orientation=Quaternion(0.0, 0.0, 0.707, 0.707), # 90 degrees around Z + ) + + # Convert to Transform + transform = Transform.from_pose("base_link", pose) + + # Check that translation and rotation match + assert transform.translation == pose.position + assert transform.rotation == pose.orientation + assert transform.frame_id == "world" # default frame_id + assert transform.child_frame_id == "base_link" # passed as first argument + + +# validating results from example @ +# https://foxglove.dev/blog/understanding-ros-transforms +def test_transform_from_ros() -> None: + """Test converting PoseStamped to Transform""" + test_time = time.time() + pose_stamped = PoseStamped( + ts=test_time, + frame_id="base_link", + position=Vector3(1, -1, 0), + orientation=Quaternion.from_euler(Vector3(0, 0, math.pi / 6)), + ) + transform_base_link_to_arm = Transform.from_pose("arm_base_link", pose_stamped) + + transform_arm_to_end = Transform.from_pose( + "end", + PoseStamped( + ts=test_time, + frame_id="arm_base_link", + position=Vector3(1, 1, 0), + orientation=Quaternion.from_euler(Vector3(0, 0, math.pi / 6)), + ), + ) + + print(transform_base_link_to_arm) + print(transform_arm_to_end) + + end_effector_global_pose = transform_base_link_to_arm + transform_arm_to_end + + assert end_effector_global_pose.translation.x == pytest.approx(1.366, abs=1e-3) + assert end_effector_global_pose.translation.y == pytest.approx(0.366, abs=1e-3) + + +def test_transform_from_pose_stamped() -> None: + """Test converting PoseStamped to Transform""" + # Create a PoseStamped with position, orientation, timestamp and frame + test_time = time.time() + pose_stamped = PoseStamped( + ts=test_time, + frame_id="map", + position=Vector3(4.0, 5.0, 6.0), + orientation=Quaternion(0.0, 0.707, 0.0, 0.707), # 90 degrees around Y + ) + + # Convert to Transform + transform = Transform.from_pose("robot_base", pose_stamped) + + # Check that all fields match + assert transform.translation == pose_stamped.position + assert transform.rotation == pose_stamped.orientation + assert transform.frame_id == pose_stamped.frame_id + assert transform.ts == pose_stamped.ts + assert transform.child_frame_id == "robot_base" # passed as first argument + + +def test_transform_from_pose_variants() -> None: + """Test from_pose with different Pose initialization methods""" + # Test with Pose created from x,y,z + pose1 = Pose(1.0, 2.0, 3.0) + transform1 = Transform.from_pose("base_link", pose1) + assert transform1.translation.x == 1.0 + assert transform1.translation.y == 2.0 + assert transform1.translation.z == 3.0 + assert transform1.rotation.w == 1.0 # Identity quaternion + + # Test with Pose created from tuple + pose2 = Pose(([7.0, 8.0, 9.0], [0.0, 0.0, 0.0, 1.0])) + transform2 = Transform.from_pose("base_link", pose2) + assert transform2.translation.x == 7.0 + assert transform2.translation.y == 8.0 + assert transform2.translation.z == 9.0 + + # Test with Pose created from dict + pose3 = Pose({"position": [10.0, 11.0, 12.0], "orientation": [0.0, 0.0, 0.0, 1.0]}) + transform3 = Transform.from_pose("base_link", pose3) + assert transform3.translation.x == 10.0 + assert transform3.translation.y == 11.0 + assert transform3.translation.z == 12.0 + + +def test_transform_from_pose_invalid_type() -> None: + """Test that from_pose raises TypeError for invalid types""" + with pytest.raises(TypeError): + Transform.from_pose("not a pose") + + with pytest.raises(TypeError): + Transform.from_pose(42) + + with pytest.raises(TypeError): + Transform.from_pose(None) + + +@pytest.mark.ros +def test_transform_from_ros_transform_stamped() -> None: + """Test creating a Transform from a ROS TransformStamped message.""" + ros_msg = ROSTransformStamped() + ros_msg.header.frame_id = "world" + ros_msg.header.stamp.sec = 123 + ros_msg.header.stamp.nanosec = 456000000 + ros_msg.child_frame_id = "robot" + ros_msg.transform.translation.x = 1.0 + ros_msg.transform.translation.y = 2.0 + ros_msg.transform.translation.z = 3.0 + ros_msg.transform.rotation.x = 0.1 + ros_msg.transform.rotation.y = 0.2 + ros_msg.transform.rotation.z = 0.3 + ros_msg.transform.rotation.w = 0.9 + + transform = Transform.from_ros_transform_stamped(ros_msg) + + assert transform.frame_id == "world" + assert transform.child_frame_id == "robot" + assert transform.ts == 123.456 + assert transform.translation.x == 1.0 + assert transform.translation.y == 2.0 + assert transform.translation.z == 3.0 + assert transform.rotation.x == 0.1 + assert transform.rotation.y == 0.2 + assert transform.rotation.z == 0.3 + assert transform.rotation.w == 0.9 + + +@pytest.mark.ros +def test_transform_to_ros_transform_stamped() -> None: + """Test converting a Transform to a ROS TransformStamped message.""" + transform = Transform( + translation=Vector3(4.0, 5.0, 6.0), + rotation=Quaternion(0.15, 0.25, 0.35, 0.85), + frame_id="base_link", + child_frame_id="sensor", + ts=124.789, + ) + + ros_msg = transform.to_ros_transform_stamped() + + assert isinstance(ros_msg, ROSTransformStamped) + assert ros_msg.header.frame_id == "base_link" + assert ros_msg.child_frame_id == "sensor" + assert ros_msg.header.stamp.sec == 124 + assert ros_msg.header.stamp.nanosec == 789000000 + assert ros_msg.transform.translation.x == 4.0 + assert ros_msg.transform.translation.y == 5.0 + assert ros_msg.transform.translation.z == 6.0 + assert ros_msg.transform.rotation.x == 0.15 + assert ros_msg.transform.rotation.y == 0.25 + assert ros_msg.transform.rotation.z == 0.35 + assert ros_msg.transform.rotation.w == 0.85 + + +@pytest.mark.ros +def test_transform_ros_roundtrip() -> None: + """Test round-trip conversion between Transform and ROS TransformStamped.""" + original = Transform( + translation=Vector3(7.5, 8.5, 9.5), + rotation=Quaternion(0.0, 0.0, 0.383, 0.924), # ~45 degrees around Z + frame_id="odom", + child_frame_id="base_footprint", + ts=99.123, + ) + + ros_msg = original.to_ros_transform_stamped() + restored = Transform.from_ros_transform_stamped(ros_msg) + + assert restored.frame_id == original.frame_id + assert restored.child_frame_id == original.child_frame_id + assert restored.ts == original.ts + assert restored.translation.x == original.translation.x + assert restored.translation.y == original.translation.y + assert restored.translation.z == original.translation.z + assert restored.rotation.x == original.rotation.x + assert restored.rotation.y == original.rotation.y + assert restored.rotation.z == original.rotation.z + assert restored.rotation.w == original.rotation.w diff --git a/dimos/msgs/geometry_msgs/test_Twist.py b/dimos/msgs/geometry_msgs/test_Twist.py new file mode 100644 index 0000000000..f83ffa3fdd --- /dev/null +++ b/dimos/msgs/geometry_msgs/test_Twist.py @@ -0,0 +1,301 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest + +try: + from geometry_msgs.msg import Twist as ROSTwist, Vector3 as ROSVector3 +except ImportError: + ROSTwist = None + ROSVector3 = None + +from dimos_lcm.geometry_msgs import Twist as LCMTwist + +from dimos.msgs.geometry_msgs import Quaternion, Twist, Vector3 + + +def test_twist_initialization() -> None: + # Test default initialization (zero twist) + tw = Twist() + assert tw.linear.x == 0.0 + assert tw.linear.y == 0.0 + assert tw.linear.z == 0.0 + assert tw.angular.x == 0.0 + assert tw.angular.y == 0.0 + assert tw.angular.z == 0.0 + + # Test initialization with Vector3 linear and angular + lin = Vector3(1.0, 2.0, 3.0) + ang = Vector3(0.1, 0.2, 0.3) + tw2 = Twist(lin, ang) + assert tw2.linear == lin + assert tw2.angular == ang + + # Test copy constructor + tw3 = Twist(tw2) + assert tw3.linear == tw2.linear + assert tw3.angular == tw2.angular + assert tw3 == tw2 + # Ensure it's a deep copy + tw3.linear.x = 10.0 + assert tw2.linear.x == 1.0 + + # Test initialization from LCM Twist + lcm_tw = LCMTwist() + lcm_tw.linear = Vector3(4.0, 5.0, 6.0) + lcm_tw.angular = Vector3(0.4, 0.5, 0.6) + tw4 = Twist(lcm_tw) + assert tw4.linear.x == 4.0 + assert tw4.linear.y == 5.0 + assert tw4.linear.z == 6.0 + assert tw4.angular.x == 0.4 + assert tw4.angular.y == 0.5 + assert tw4.angular.z == 0.6 + + # Test initialization with linear and angular as quaternion + quat = Quaternion(0, 0, 0.707107, 0.707107) # 90 degrees around Z + tw5 = Twist(Vector3(1.0, 2.0, 3.0), quat) + assert tw5.linear == Vector3(1.0, 2.0, 3.0) + # Quaternion should be converted to euler angles + euler = quat.to_euler() + assert np.allclose(tw5.angular.x, euler.x) + assert np.allclose(tw5.angular.y, euler.y) + assert np.allclose(tw5.angular.z, euler.z) + + # Test keyword argument initialization + tw7 = Twist(linear=Vector3(1, 2, 3), angular=Vector3(0.1, 0.2, 0.3)) + assert tw7.linear == Vector3(1, 2, 3) + assert tw7.angular == Vector3(0.1, 0.2, 0.3) + + # Test keyword with only linear + tw8 = Twist(linear=Vector3(4, 5, 6)) + assert tw8.linear == Vector3(4, 5, 6) + assert tw8.angular.is_zero() + + # Test keyword with only angular + tw9 = Twist(angular=Vector3(0.4, 0.5, 0.6)) + assert tw9.linear.is_zero() + assert tw9.angular == Vector3(0.4, 0.5, 0.6) + + # Test keyword with angular as quaternion + tw10 = Twist(angular=Quaternion(0, 0, 0.707107, 0.707107)) + assert tw10.linear.is_zero() + euler = Quaternion(0, 0, 0.707107, 0.707107).to_euler() + assert np.allclose(tw10.angular.x, euler.x) + assert np.allclose(tw10.angular.y, euler.y) + assert np.allclose(tw10.angular.z, euler.z) + + # Test keyword with linear and angular as quaternion + tw11 = Twist(linear=Vector3(1, 0, 0), angular=Quaternion(0, 0, 0, 1)) + assert tw11.linear == Vector3(1, 0, 0) + assert tw11.angular.is_zero() # Identity quaternion -> zero euler angles + + +def test_twist_zero() -> None: + # Test zero class method + tw = Twist.zero() + assert tw.linear.is_zero() + assert tw.angular.is_zero() + assert tw.is_zero() + + # Zero should equal default constructor + assert tw == Twist() + + +def test_twist_equality() -> None: + tw1 = Twist(Vector3(1, 2, 3), Vector3(0.1, 0.2, 0.3)) + tw2 = Twist(Vector3(1, 2, 3), Vector3(0.1, 0.2, 0.3)) + tw3 = Twist(Vector3(1, 2, 4), Vector3(0.1, 0.2, 0.3)) # Different linear z + tw4 = Twist(Vector3(1, 2, 3), Vector3(0.1, 0.2, 0.4)) # Different angular z + + assert tw1 == tw2 + assert tw1 != tw3 + assert tw1 != tw4 + assert tw1 != "not a twist" + + +def test_twist_string_representations() -> None: + tw = Twist(Vector3(1.5, -2.0, 3.14), Vector3(0.1, -0.2, 0.3)) + + # Test repr + repr_str = repr(tw) + assert "Twist" in repr_str + assert "linear=" in repr_str + assert "angular=" in repr_str + assert "1.5" in repr_str + assert "0.1" in repr_str + + # Test str + str_str = str(tw) + assert "Twist:" in str_str + assert "Linear:" in str_str + assert "Angular:" in str_str + + +def test_twist_is_zero() -> None: + # Test zero twist + tw1 = Twist() + assert tw1.is_zero() + + # Test non-zero linear + tw2 = Twist(linear=Vector3(0.1, 0, 0)) + assert not tw2.is_zero() + + # Test non-zero angular + tw3 = Twist(angular=Vector3(0, 0, 0.1)) + assert not tw3.is_zero() + + # Test both non-zero + tw4 = Twist(Vector3(1, 2, 3), Vector3(0.1, 0.2, 0.3)) + assert not tw4.is_zero() + + +def test_twist_bool() -> None: + # Test zero twist is False + tw1 = Twist() + assert not tw1 + + # Test non-zero twist is True + tw2 = Twist(linear=Vector3(1, 0, 0)) + assert tw2 + + tw3 = Twist(angular=Vector3(0, 0, 0.1)) + assert tw3 + + tw4 = Twist(Vector3(1, 2, 3), Vector3(0.1, 0.2, 0.3)) + assert tw4 + + +def test_twist_lcm_encoding() -> None: + # Test encoding and decoding + tw = Twist(Vector3(1.5, 2.5, 3.5), Vector3(0.1, 0.2, 0.3)) + + # Encode + encoded = tw.lcm_encode() + assert isinstance(encoded, bytes) + + # Decode + decoded = Twist.lcm_decode(encoded) + assert decoded.linear == tw.linear + assert decoded.angular == tw.angular + + assert isinstance(decoded.linear, Vector3) + assert decoded == tw + + +def test_twist_with_lists() -> None: + # Test initialization with lists instead of Vector3 + tw1 = Twist(linear=[1, 2, 3], angular=[0.1, 0.2, 0.3]) + assert tw1.linear == Vector3(1, 2, 3) + assert tw1.angular == Vector3(0.1, 0.2, 0.3) + + # Test with numpy arrays + tw2 = Twist(linear=np.array([4, 5, 6]), angular=np.array([0.4, 0.5, 0.6])) + assert tw2.linear == Vector3(4, 5, 6) + assert tw2.angular == Vector3(0.4, 0.5, 0.6) + + +@pytest.mark.ros +def test_twist_from_ros_msg() -> None: + """Test Twist.from_ros_msg conversion.""" + # Create ROS message + ros_msg = ROSTwist() + ros_msg.linear = ROSVector3(x=10.0, y=20.0, z=30.0) + ros_msg.angular = ROSVector3(x=1.0, y=2.0, z=3.0) + + # Convert to LCM + lcm_msg = Twist.from_ros_msg(ros_msg) + + assert isinstance(lcm_msg, Twist) + assert lcm_msg.linear.x == 10.0 + assert lcm_msg.linear.y == 20.0 + assert lcm_msg.linear.z == 30.0 + assert lcm_msg.angular.x == 1.0 + assert lcm_msg.angular.y == 2.0 + assert lcm_msg.angular.z == 3.0 + + +@pytest.mark.ros +def test_twist_to_ros_msg() -> None: + """Test Twist.to_ros_msg conversion.""" + # Create LCM message + lcm_msg = Twist(linear=Vector3(40.0, 50.0, 60.0), angular=Vector3(4.0, 5.0, 6.0)) + + # Convert to ROS + ros_msg = lcm_msg.to_ros_msg() + + assert isinstance(ros_msg, ROSTwist) + assert ros_msg.linear.x == 40.0 + assert ros_msg.linear.y == 50.0 + assert ros_msg.linear.z == 60.0 + assert ros_msg.angular.x == 4.0 + assert ros_msg.angular.y == 5.0 + assert ros_msg.angular.z == 6.0 + + +@pytest.mark.ros +def test_ros_zero_twist_conversion() -> None: + """Test conversion of zero twist messages between ROS and LCM.""" + # Test ROS to LCM with zero twist + ros_zero = ROSTwist() + lcm_zero = Twist.from_ros_msg(ros_zero) + assert lcm_zero.is_zero() + + # Test LCM to ROS with zero twist + lcm_zero2 = Twist.zero() + ros_zero2 = lcm_zero2.to_ros_msg() + assert ros_zero2.linear.x == 0.0 + assert ros_zero2.linear.y == 0.0 + assert ros_zero2.linear.z == 0.0 + assert ros_zero2.angular.x == 0.0 + assert ros_zero2.angular.y == 0.0 + assert ros_zero2.angular.z == 0.0 + + +@pytest.mark.ros +def test_ros_negative_values_conversion() -> None: + """Test ROS conversion with negative values.""" + # Create ROS message with negative values + ros_msg = ROSTwist() + ros_msg.linear = ROSVector3(x=-1.5, y=-2.5, z=-3.5) + ros_msg.angular = ROSVector3(x=-0.1, y=-0.2, z=-0.3) + + # Convert to LCM and back + lcm_msg = Twist.from_ros_msg(ros_msg) + ros_msg2 = lcm_msg.to_ros_msg() + + assert ros_msg2.linear.x == -1.5 + assert ros_msg2.linear.y == -2.5 + assert ros_msg2.linear.z == -3.5 + assert ros_msg2.angular.x == -0.1 + assert ros_msg2.angular.y == -0.2 + assert ros_msg2.angular.z == -0.3 + + +@pytest.mark.ros +def test_ros_roundtrip_conversion() -> None: + """Test round-trip conversion maintains data integrity.""" + # LCM -> ROS -> LCM + original_lcm = Twist(linear=Vector3(1.234, 5.678, 9.012), angular=Vector3(0.111, 0.222, 0.333)) + ros_intermediate = original_lcm.to_ros_msg() + final_lcm = Twist.from_ros_msg(ros_intermediate) + + assert final_lcm == original_lcm + assert final_lcm.linear.x == 1.234 + assert final_lcm.linear.y == 5.678 + assert final_lcm.linear.z == 9.012 + assert final_lcm.angular.x == 0.111 + assert final_lcm.angular.y == 0.222 + assert final_lcm.angular.z == 0.333 diff --git a/dimos/msgs/geometry_msgs/test_TwistStamped.py b/dimos/msgs/geometry_msgs/test_TwistStamped.py new file mode 100644 index 0000000000..7ba2f59e7d --- /dev/null +++ b/dimos/msgs/geometry_msgs/test_TwistStamped.py @@ -0,0 +1,158 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pickle +import time + +import pytest + +try: + from geometry_msgs.msg import TwistStamped as ROSTwistStamped +except ImportError: + ROSTwistStamped = None + +from dimos.msgs.geometry_msgs.TwistStamped import TwistStamped + + +def test_lcm_encode_decode() -> None: + """Test encoding and decoding of TwistStamped to/from binary LCM format.""" + twist_source = TwistStamped( + ts=time.time(), + linear=(1.0, 2.0, 3.0), + angular=(0.1, 0.2, 0.3), + ) + binary_msg = twist_source.lcm_encode() + twist_dest = TwistStamped.lcm_decode(binary_msg) + + assert isinstance(twist_dest, TwistStamped) + assert twist_dest is not twist_source + + print(twist_source.linear) + print(twist_source.angular) + + print(twist_dest.linear) + print(twist_dest.angular) + assert twist_dest == twist_source + + +def test_pickle_encode_decode() -> None: + """Test encoding and decoding of TwistStamped to/from binary pickle format.""" + + twist_source = TwistStamped( + ts=time.time(), + linear=(1.0, 2.0, 3.0), + angular=(0.1, 0.2, 0.3), + ) + binary_msg = pickle.dumps(twist_source) + twist_dest = pickle.loads(binary_msg) + assert isinstance(twist_dest, TwistStamped) + assert twist_dest is not twist_source + assert twist_dest == twist_source + + +@pytest.mark.ros +def test_twist_stamped_from_ros_msg() -> None: + """Test creating a TwistStamped from a ROS TwistStamped message.""" + ros_msg = ROSTwistStamped() + ros_msg.header.frame_id = "world" + ros_msg.header.stamp.sec = 123 + ros_msg.header.stamp.nanosec = 456000000 + ros_msg.twist.linear.x = 1.0 + ros_msg.twist.linear.y = 2.0 + ros_msg.twist.linear.z = 3.0 + ros_msg.twist.angular.x = 0.1 + ros_msg.twist.angular.y = 0.2 + ros_msg.twist.angular.z = 0.3 + + twist_stamped = TwistStamped.from_ros_msg(ros_msg) + + assert twist_stamped.frame_id == "world" + assert twist_stamped.ts == 123.456 + assert twist_stamped.linear.x == 1.0 + assert twist_stamped.linear.y == 2.0 + assert twist_stamped.linear.z == 3.0 + assert twist_stamped.angular.x == 0.1 + assert twist_stamped.angular.y == 0.2 + assert twist_stamped.angular.z == 0.3 + + +@pytest.mark.ros +def test_twist_stamped_to_ros_msg() -> None: + """Test converting a TwistStamped to a ROS TwistStamped message.""" + twist_stamped = TwistStamped( + ts=123.456, + frame_id="base_link", + linear=(1.0, 2.0, 3.0), + angular=(0.1, 0.2, 0.3), + ) + + ros_msg = twist_stamped.to_ros_msg() + + assert isinstance(ros_msg, ROSTwistStamped) + assert ros_msg.header.frame_id == "base_link" + assert ros_msg.header.stamp.sec == 123 + assert ros_msg.header.stamp.nanosec == 456000000 + assert ros_msg.twist.linear.x == 1.0 + assert ros_msg.twist.linear.y == 2.0 + assert ros_msg.twist.linear.z == 3.0 + assert ros_msg.twist.angular.x == 0.1 + assert ros_msg.twist.angular.y == 0.2 + assert ros_msg.twist.angular.z == 0.3 + + +@pytest.mark.ros +def test_twist_stamped_ros_roundtrip() -> None: + """Test round-trip conversion between TwistStamped and ROS TwistStamped.""" + original = TwistStamped( + ts=123.789, + frame_id="odom", + linear=(1.5, 2.5, 3.5), + angular=(0.15, 0.25, 0.35), + ) + + ros_msg = original.to_ros_msg() + restored = TwistStamped.from_ros_msg(ros_msg) + + assert restored.frame_id == original.frame_id + assert restored.ts == original.ts + assert restored.linear.x == original.linear.x + assert restored.linear.y == original.linear.y + assert restored.linear.z == original.linear.z + assert restored.angular.x == original.angular.x + assert restored.angular.y == original.angular.y + assert restored.angular.z == original.angular.z + + +if __name__ == "__main__": + print("Running test_lcm_encode_decode...") + test_lcm_encode_decode() + print("✓ test_lcm_encode_decode passed") + + print("Running test_pickle_encode_decode...") + test_pickle_encode_decode() + print("✓ test_pickle_encode_decode passed") + + print("Running test_twist_stamped_from_ros_msg...") + test_twist_stamped_from_ros_msg() + print("✓ test_twist_stamped_from_ros_msg passed") + + print("Running test_twist_stamped_to_ros_msg...") + test_twist_stamped_to_ros_msg() + print("✓ test_twist_stamped_to_ros_msg passed") + + print("Running test_twist_stamped_ros_roundtrip...") + test_twist_stamped_ros_roundtrip() + print("✓ test_twist_stamped_ros_roundtrip passed") + + print("\nAll tests passed!") diff --git a/dimos/msgs/geometry_msgs/test_TwistWithCovariance.py b/dimos/msgs/geometry_msgs/test_TwistWithCovariance.py new file mode 100644 index 0000000000..746b0c3646 --- /dev/null +++ b/dimos/msgs/geometry_msgs/test_TwistWithCovariance.py @@ -0,0 +1,423 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest + +try: + from geometry_msgs.msg import ( + Twist as ROSTwist, + TwistWithCovariance as ROSTwistWithCovariance, + Vector3 as ROSVector3, + ) +except ImportError: + ROSTwist = None + ROSTwistWithCovariance = None + ROSVector3 = None + +from dimos_lcm.geometry_msgs import TwistWithCovariance as LCMTwistWithCovariance + +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.TwistWithCovariance import TwistWithCovariance +from dimos.msgs.geometry_msgs.Vector3 import Vector3 + + +def test_twist_with_covariance_default_init() -> None: + """Test that default initialization creates a zero twist with zero covariance.""" + if ROSVector3 is None: + pytest.skip("ROS not available") + if ROSTwistWithCovariance is None: + pytest.skip("ROS not available") + twist_cov = TwistWithCovariance() + + # Twist should be zero + assert twist_cov.twist.linear.x == 0.0 + assert twist_cov.twist.linear.y == 0.0 + assert twist_cov.twist.linear.z == 0.0 + assert twist_cov.twist.angular.x == 0.0 + assert twist_cov.twist.angular.y == 0.0 + assert twist_cov.twist.angular.z == 0.0 + + # Covariance should be all zeros + assert np.all(twist_cov.covariance == 0.0) + assert twist_cov.covariance.shape == (36,) + + +def test_twist_with_covariance_twist_init() -> None: + """Test initialization with a Twist object.""" + linear = Vector3(1.0, 2.0, 3.0) + angular = Vector3(0.1, 0.2, 0.3) + twist = Twist(linear, angular) + twist_cov = TwistWithCovariance(twist) + + # Twist should match + assert twist_cov.twist.linear.x == 1.0 + assert twist_cov.twist.linear.y == 2.0 + assert twist_cov.twist.linear.z == 3.0 + assert twist_cov.twist.angular.x == 0.1 + assert twist_cov.twist.angular.y == 0.2 + assert twist_cov.twist.angular.z == 0.3 + + # Covariance should be zeros by default + assert np.all(twist_cov.covariance == 0.0) + + +def test_twist_with_covariance_twist_and_covariance_init() -> None: + """Test initialization with twist and covariance.""" + twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + covariance = np.arange(36, dtype=float) + twist_cov = TwistWithCovariance(twist, covariance) + + # Twist should match + assert twist_cov.twist.linear.x == 1.0 + assert twist_cov.twist.linear.y == 2.0 + assert twist_cov.twist.linear.z == 3.0 + + # Covariance should match + assert np.array_equal(twist_cov.covariance, covariance) + + +def test_twist_with_covariance_tuple_init() -> None: + """Test initialization with tuple of (linear, angular) velocities.""" + linear = [1.0, 2.0, 3.0] + angular = [0.1, 0.2, 0.3] + covariance = np.arange(36, dtype=float) + twist_cov = TwistWithCovariance((linear, angular), covariance) + + # Twist should match + assert twist_cov.twist.linear.x == 1.0 + assert twist_cov.twist.linear.y == 2.0 + assert twist_cov.twist.linear.z == 3.0 + assert twist_cov.twist.angular.x == 0.1 + assert twist_cov.twist.angular.y == 0.2 + assert twist_cov.twist.angular.z == 0.3 + + # Covariance should match + assert np.array_equal(twist_cov.covariance, covariance) + + +def test_twist_with_covariance_list_covariance() -> None: + """Test initialization with covariance as a list.""" + twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + covariance_list = list(range(36)) + twist_cov = TwistWithCovariance(twist, covariance_list) + + # Covariance should be converted to numpy array + assert isinstance(twist_cov.covariance, np.ndarray) + assert np.array_equal(twist_cov.covariance, np.array(covariance_list)) + + +def test_twist_with_covariance_copy_init() -> None: + """Test copy constructor.""" + twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + covariance = np.arange(36, dtype=float) + original = TwistWithCovariance(twist, covariance) + copy = TwistWithCovariance(original) + + # Should be equal but not the same object + assert copy == original + assert copy is not original + assert copy.twist is not original.twist + assert copy.covariance is not original.covariance + + # Modify original to ensure they're independent + original.covariance[0] = 999.0 + assert copy.covariance[0] != 999.0 + + +def test_twist_with_covariance_lcm_init() -> None: + """Test initialization from LCM message.""" + lcm_msg = LCMTwistWithCovariance() + lcm_msg.twist.linear.x = 1.0 + lcm_msg.twist.linear.y = 2.0 + lcm_msg.twist.linear.z = 3.0 + lcm_msg.twist.angular.x = 0.1 + lcm_msg.twist.angular.y = 0.2 + lcm_msg.twist.angular.z = 0.3 + lcm_msg.covariance = list(range(36)) + + twist_cov = TwistWithCovariance(lcm_msg) + + # Twist should match + assert twist_cov.twist.linear.x == 1.0 + assert twist_cov.twist.linear.y == 2.0 + assert twist_cov.twist.linear.z == 3.0 + assert twist_cov.twist.angular.x == 0.1 + assert twist_cov.twist.angular.y == 0.2 + assert twist_cov.twist.angular.z == 0.3 + + # Covariance should match + assert np.array_equal(twist_cov.covariance, np.arange(36)) + + +def test_twist_with_covariance_dict_init() -> None: + """Test initialization from dictionary.""" + twist_dict = { + "twist": Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)), + "covariance": list(range(36)), + } + twist_cov = TwistWithCovariance(twist_dict) + + assert twist_cov.twist.linear.x == 1.0 + assert twist_cov.twist.linear.y == 2.0 + assert twist_cov.twist.linear.z == 3.0 + assert np.array_equal(twist_cov.covariance, np.arange(36)) + + +def test_twist_with_covariance_dict_init_no_covariance() -> None: + """Test initialization from dictionary without covariance.""" + twist_dict = {"twist": Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3))} + twist_cov = TwistWithCovariance(twist_dict) + + assert twist_cov.twist.linear.x == 1.0 + assert np.all(twist_cov.covariance == 0.0) + + +def test_twist_with_covariance_tuple_of_tuple_init() -> None: + """Test initialization from tuple of (twist_tuple, covariance).""" + twist_tuple = ([1.0, 2.0, 3.0], [0.1, 0.2, 0.3]) + covariance = np.arange(36, dtype=float) + twist_cov = TwistWithCovariance((twist_tuple, covariance)) + + assert twist_cov.twist.linear.x == 1.0 + assert twist_cov.twist.linear.y == 2.0 + assert twist_cov.twist.linear.z == 3.0 + assert twist_cov.twist.angular.x == 0.1 + assert twist_cov.twist.angular.y == 0.2 + assert twist_cov.twist.angular.z == 0.3 + assert np.array_equal(twist_cov.covariance, covariance) + + +def test_twist_with_covariance_properties() -> None: + """Test convenience properties.""" + twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + twist_cov = TwistWithCovariance(twist) + + # Linear and angular properties + assert twist_cov.linear.x == 1.0 + assert twist_cov.linear.y == 2.0 + assert twist_cov.linear.z == 3.0 + assert twist_cov.angular.x == 0.1 + assert twist_cov.angular.y == 0.2 + assert twist_cov.angular.z == 0.3 + + +def test_twist_with_covariance_matrix_property() -> None: + """Test covariance matrix property.""" + twist = Twist() + covariance_array = np.arange(36, dtype=float) + twist_cov = TwistWithCovariance(twist, covariance_array) + + # Get as matrix + cov_matrix = twist_cov.covariance_matrix + assert cov_matrix.shape == (6, 6) + assert cov_matrix[0, 0] == 0.0 + assert cov_matrix[5, 5] == 35.0 + + # Set from matrix + new_matrix = np.eye(6) * 2.0 + twist_cov.covariance_matrix = new_matrix + assert np.array_equal(twist_cov.covariance[:6], [2.0, 0.0, 0.0, 0.0, 0.0, 0.0]) + + +def test_twist_with_covariance_repr() -> None: + """Test string representation.""" + twist = Twist(Vector3(1.234, 2.567, 3.891), Vector3(0.1, 0.2, 0.3)) + twist_cov = TwistWithCovariance(twist) + + repr_str = repr(twist_cov) + assert "TwistWithCovariance" in repr_str + assert "twist=" in repr_str + assert "covariance=" in repr_str + assert "36 elements" in repr_str + + +def test_twist_with_covariance_str() -> None: + """Test string formatting.""" + twist = Twist(Vector3(1.234, 2.567, 3.891), Vector3(0.1, 0.2, 0.3)) + covariance = np.eye(6).flatten() + twist_cov = TwistWithCovariance(twist, covariance) + + str_repr = str(twist_cov) + assert "TwistWithCovariance" in str_repr + assert "1.234" in str_repr + assert "2.567" in str_repr + assert "3.891" in str_repr + assert "cov_trace" in str_repr + assert "6.000" in str_repr # Trace of identity matrix is 6 + + +def test_twist_with_covariance_equality() -> None: + """Test equality comparison.""" + twist1 = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + cov1 = np.arange(36, dtype=float) + twist_cov1 = TwistWithCovariance(twist1, cov1) + + twist2 = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + cov2 = np.arange(36, dtype=float) + twist_cov2 = TwistWithCovariance(twist2, cov2) + + # Equal + assert twist_cov1 == twist_cov2 + + # Different twist + twist3 = Twist(Vector3(1.1, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + twist_cov3 = TwistWithCovariance(twist3, cov1) + assert twist_cov1 != twist_cov3 + + # Different covariance + cov3 = np.arange(36, dtype=float) + 1 + twist_cov4 = TwistWithCovariance(twist1, cov3) + assert twist_cov1 != twist_cov4 + + # Different type + assert twist_cov1 != "not a twist" + assert twist_cov1 is not None + + +def test_twist_with_covariance_is_zero() -> None: + """Test is_zero method.""" + # Zero twist + twist_cov1 = TwistWithCovariance() + assert twist_cov1.is_zero() + assert not twist_cov1 # Boolean conversion + + # Non-zero twist + twist = Twist(Vector3(1.0, 0.0, 0.0), Vector3(0.0, 0.0, 0.0)) + twist_cov2 = TwistWithCovariance(twist) + assert not twist_cov2.is_zero() + assert twist_cov2 # Boolean conversion + + +def test_twist_with_covariance_lcm_encode_decode() -> None: + """Test LCM encoding and decoding.""" + twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + covariance = np.arange(36, dtype=float) + source = TwistWithCovariance(twist, covariance) + + # Encode and decode + binary_msg = source.lcm_encode() + decoded = TwistWithCovariance.lcm_decode(binary_msg) + + # Should be equal + assert decoded == source + assert isinstance(decoded, TwistWithCovariance) + assert isinstance(decoded.twist, Twist) + assert isinstance(decoded.covariance, np.ndarray) + + +@pytest.mark.ros +def test_twist_with_covariance_from_ros_msg() -> None: + """Test creating from ROS message.""" + ros_msg = ROSTwistWithCovariance() + ros_msg.twist.linear = ROSVector3(x=1.0, y=2.0, z=3.0) + ros_msg.twist.angular = ROSVector3(x=0.1, y=0.2, z=0.3) + ros_msg.covariance = [float(i) for i in range(36)] + + twist_cov = TwistWithCovariance.from_ros_msg(ros_msg) + + assert twist_cov.twist.linear.x == 1.0 + assert twist_cov.twist.linear.y == 2.0 + assert twist_cov.twist.linear.z == 3.0 + assert twist_cov.twist.angular.x == 0.1 + assert twist_cov.twist.angular.y == 0.2 + assert twist_cov.twist.angular.z == 0.3 + assert np.array_equal(twist_cov.covariance, np.arange(36)) + + +@pytest.mark.ros +def test_twist_with_covariance_to_ros_msg() -> None: + """Test converting to ROS message.""" + twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + covariance = np.arange(36, dtype=float) + twist_cov = TwistWithCovariance(twist, covariance) + + ros_msg = twist_cov.to_ros_msg() + + assert isinstance(ros_msg, ROSTwistWithCovariance) + assert ros_msg.twist.linear.x == 1.0 + assert ros_msg.twist.linear.y == 2.0 + assert ros_msg.twist.linear.z == 3.0 + assert ros_msg.twist.angular.x == 0.1 + assert ros_msg.twist.angular.y == 0.2 + assert ros_msg.twist.angular.z == 0.3 + assert list(ros_msg.covariance) == list(range(36)) + + +@pytest.mark.ros +def test_twist_with_covariance_ros_roundtrip() -> None: + """Test round-trip conversion with ROS messages.""" + twist = Twist(Vector3(1.5, 2.5, 3.5), Vector3(0.15, 0.25, 0.35)) + covariance = np.random.rand(36) + original = TwistWithCovariance(twist, covariance) + + ros_msg = original.to_ros_msg() + restored = TwistWithCovariance.from_ros_msg(ros_msg) + + assert restored == original + + +def test_twist_with_covariance_zero_covariance() -> None: + """Test with zero covariance matrix.""" + twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + twist_cov = TwistWithCovariance(twist) + + assert np.all(twist_cov.covariance == 0.0) + assert np.trace(twist_cov.covariance_matrix) == 0.0 + + +def test_twist_with_covariance_diagonal_covariance() -> None: + """Test with diagonal covariance matrix.""" + twist = Twist() + covariance = np.zeros(36) + # Set diagonal elements + for i in range(6): + covariance[i * 6 + i] = i + 1 + + twist_cov = TwistWithCovariance(twist, covariance) + + cov_matrix = twist_cov.covariance_matrix + assert np.trace(cov_matrix) == sum(range(1, 7)) # 1+2+3+4+5+6 = 21 + + # Check diagonal elements + for i in range(6): + assert cov_matrix[i, i] == i + 1 + + # Check off-diagonal elements are zero + for i in range(6): + for j in range(6): + if i != j: + assert cov_matrix[i, j] == 0.0 + + +@pytest.mark.parametrize( + "linear,angular", + [ + ([0.0, 0.0, 0.0], [0.0, 0.0, 0.0]), + ([1.0, 2.0, 3.0], [0.1, 0.2, 0.3]), + ([-1.0, -2.0, -3.0], [-0.1, -0.2, -0.3]), + ([100.0, -100.0, 0.0], [3.14, -3.14, 0.0]), + ], +) +def test_twist_with_covariance_parametrized_velocities(linear, angular) -> None: + """Parametrized test for various velocity values.""" + twist = Twist(linear, angular) + twist_cov = TwistWithCovariance(twist) + + assert twist_cov.linear.x == linear[0] + assert twist_cov.linear.y == linear[1] + assert twist_cov.linear.z == linear[2] + assert twist_cov.angular.x == angular[0] + assert twist_cov.angular.y == angular[1] + assert twist_cov.angular.z == angular[2] diff --git a/dimos/msgs/geometry_msgs/test_TwistWithCovarianceStamped.py b/dimos/msgs/geometry_msgs/test_TwistWithCovarianceStamped.py new file mode 100644 index 0000000000..f0d7e5b4ab --- /dev/null +++ b/dimos/msgs/geometry_msgs/test_TwistWithCovarianceStamped.py @@ -0,0 +1,392 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import numpy as np +import pytest + +try: + from builtin_interfaces.msg import Time as ROSTime + from geometry_msgs.msg import ( + Twist as ROSTwist, + TwistWithCovariance as ROSTwistWithCovariance, + TwistWithCovarianceStamped as ROSTwistWithCovarianceStamped, + Vector3 as ROSVector3, + ) + from std_msgs.msg import Header as ROSHeader +except ImportError: + ROSTwistWithCovarianceStamped = None + ROSTwist = None + ROSHeader = None + ROSTime = None + ROSTwistWithCovariance = None + ROSVector3 = None + + +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.TwistWithCovariance import TwistWithCovariance +from dimos.msgs.geometry_msgs.TwistWithCovarianceStamped import TwistWithCovarianceStamped +from dimos.msgs.geometry_msgs.Vector3 import Vector3 + + +def test_twist_with_covariance_stamped_default_init() -> None: + """Test default initialization.""" + if ROSVector3 is None: + pytest.skip("ROS not available") + if ROSTwistWithCovariance is None: + pytest.skip("ROS not available") + if ROSTime is None: + pytest.skip("ROS not available") + if ROSHeader is None: + pytest.skip("ROS not available") + if ROSTwist is None: + pytest.skip("ROS not available") + if ROSTwistWithCovarianceStamped is None: + pytest.skip("ROS not available") + twist_cov_stamped = TwistWithCovarianceStamped() + + # Should have current timestamp + assert twist_cov_stamped.ts > 0 + assert twist_cov_stamped.frame_id == "" + + # Twist should be zero + assert twist_cov_stamped.twist.linear.x == 0.0 + assert twist_cov_stamped.twist.linear.y == 0.0 + assert twist_cov_stamped.twist.linear.z == 0.0 + assert twist_cov_stamped.twist.angular.x == 0.0 + assert twist_cov_stamped.twist.angular.y == 0.0 + assert twist_cov_stamped.twist.angular.z == 0.0 + + # Covariance should be all zeros + assert np.all(twist_cov_stamped.covariance == 0.0) + + +def test_twist_with_covariance_stamped_with_timestamp() -> None: + """Test initialization with specific timestamp.""" + ts = 1234567890.123456 + frame_id = "base_link" + twist_cov_stamped = TwistWithCovarianceStamped(ts=ts, frame_id=frame_id) + + assert twist_cov_stamped.ts == ts + assert twist_cov_stamped.frame_id == frame_id + + +def test_twist_with_covariance_stamped_with_twist() -> None: + """Test initialization with twist.""" + ts = 1234567890.123456 + frame_id = "odom" + twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + covariance = np.arange(36, dtype=float) + + twist_cov_stamped = TwistWithCovarianceStamped( + ts=ts, frame_id=frame_id, twist=twist, covariance=covariance + ) + + assert twist_cov_stamped.ts == ts + assert twist_cov_stamped.frame_id == frame_id + assert twist_cov_stamped.twist.linear.x == 1.0 + assert twist_cov_stamped.twist.linear.y == 2.0 + assert twist_cov_stamped.twist.linear.z == 3.0 + assert np.array_equal(twist_cov_stamped.covariance, covariance) + + +def test_twist_with_covariance_stamped_with_tuple() -> None: + """Test initialization with tuple of velocities.""" + ts = 1234567890.123456 + frame_id = "robot_base" + linear = [1.0, 2.0, 3.0] + angular = [0.1, 0.2, 0.3] + covariance = np.arange(36, dtype=float) + + twist_cov_stamped = TwistWithCovarianceStamped( + ts=ts, frame_id=frame_id, twist=(linear, angular), covariance=covariance + ) + + assert twist_cov_stamped.ts == ts + assert twist_cov_stamped.frame_id == frame_id + assert twist_cov_stamped.twist.linear.x == 1.0 + assert twist_cov_stamped.twist.angular.x == 0.1 + assert np.array_equal(twist_cov_stamped.covariance, covariance) + + +def test_twist_with_covariance_stamped_properties() -> None: + """Test convenience properties.""" + twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + covariance = np.eye(6).flatten() + twist_cov_stamped = TwistWithCovarianceStamped( + ts=1234567890.0, frame_id="cmd_vel", twist=twist, covariance=covariance + ) + + # Linear and angular properties + assert twist_cov_stamped.linear.x == 1.0 + assert twist_cov_stamped.linear.y == 2.0 + assert twist_cov_stamped.linear.z == 3.0 + assert twist_cov_stamped.angular.x == 0.1 + assert twist_cov_stamped.angular.y == 0.2 + assert twist_cov_stamped.angular.z == 0.3 + + # Covariance matrix + cov_matrix = twist_cov_stamped.covariance_matrix + assert cov_matrix.shape == (6, 6) + assert np.trace(cov_matrix) == 6.0 + + +def test_twist_with_covariance_stamped_str() -> None: + """Test string representation.""" + twist = Twist(Vector3(1.234, 2.567, 3.891), Vector3(0.111, 0.222, 0.333)) + covariance = np.eye(6).flatten() * 2.0 + twist_cov_stamped = TwistWithCovarianceStamped( + ts=1234567890.0, frame_id="world", twist=twist, covariance=covariance + ) + + str_repr = str(twist_cov_stamped) + assert "TwistWithCovarianceStamped" in str_repr + assert "1.234" in str_repr + assert "2.567" in str_repr + assert "3.891" in str_repr + assert "cov_trace" in str_repr + assert "12.000" in str_repr # Trace of 2*identity is 12 + + +def test_twist_with_covariance_stamped_lcm_encode_decode() -> None: + """Test LCM encoding and decoding.""" + ts = 1234567890.123456 + frame_id = "camera_link" + twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + covariance = np.arange(36, dtype=float) + + source = TwistWithCovarianceStamped( + ts=ts, frame_id=frame_id, twist=twist, covariance=covariance + ) + + # Encode and decode + binary_msg = source.lcm_encode() + decoded = TwistWithCovarianceStamped.lcm_decode(binary_msg) + + # Check timestamp (may lose some precision) + assert abs(decoded.ts - ts) < 1e-6 + assert decoded.frame_id == frame_id + + # Check twist + assert decoded.twist.linear.x == 1.0 + assert decoded.twist.linear.y == 2.0 + assert decoded.twist.linear.z == 3.0 + assert decoded.twist.angular.x == 0.1 + assert decoded.twist.angular.y == 0.2 + assert decoded.twist.angular.z == 0.3 + + # Check covariance + assert np.array_equal(decoded.covariance, covariance) + + +@pytest.mark.ros +def test_twist_with_covariance_stamped_from_ros_msg() -> None: + """Test creating from ROS message.""" + ros_msg = ROSTwistWithCovarianceStamped() + + # Set header + ros_msg.header = ROSHeader() + ros_msg.header.stamp = ROSTime() + ros_msg.header.stamp.sec = 1234567890 + ros_msg.header.stamp.nanosec = 123456000 + ros_msg.header.frame_id = "laser" + + # Set twist with covariance + ros_msg.twist = ROSTwistWithCovariance() + ros_msg.twist.twist = ROSTwist() + ros_msg.twist.twist.linear = ROSVector3(x=1.0, y=2.0, z=3.0) + ros_msg.twist.twist.angular = ROSVector3(x=0.1, y=0.2, z=0.3) + ros_msg.twist.covariance = [float(i) for i in range(36)] + + twist_cov_stamped = TwistWithCovarianceStamped.from_ros_msg(ros_msg) + + assert twist_cov_stamped.ts == 1234567890.123456 + assert twist_cov_stamped.frame_id == "laser" + assert twist_cov_stamped.twist.linear.x == 1.0 + assert twist_cov_stamped.twist.linear.y == 2.0 + assert twist_cov_stamped.twist.linear.z == 3.0 + assert twist_cov_stamped.twist.angular.x == 0.1 + assert twist_cov_stamped.twist.angular.y == 0.2 + assert twist_cov_stamped.twist.angular.z == 0.3 + assert np.array_equal(twist_cov_stamped.covariance, np.arange(36)) + + +@pytest.mark.ros +def test_twist_with_covariance_stamped_to_ros_msg() -> None: + """Test converting to ROS message.""" + ts = 1234567890.567890 + frame_id = "imu" + twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + covariance = np.arange(36, dtype=float) + + twist_cov_stamped = TwistWithCovarianceStamped( + ts=ts, frame_id=frame_id, twist=twist, covariance=covariance + ) + + ros_msg = twist_cov_stamped.to_ros_msg() + + assert isinstance(ros_msg, ROSTwistWithCovarianceStamped) + assert ros_msg.header.frame_id == frame_id + assert ros_msg.header.stamp.sec == 1234567890 + assert abs(ros_msg.header.stamp.nanosec - 567890000) < 100 # Allow small rounding error + + assert ros_msg.twist.twist.linear.x == 1.0 + assert ros_msg.twist.twist.linear.y == 2.0 + assert ros_msg.twist.twist.linear.z == 3.0 + assert ros_msg.twist.twist.angular.x == 0.1 + assert ros_msg.twist.twist.angular.y == 0.2 + assert ros_msg.twist.twist.angular.z == 0.3 + assert list(ros_msg.twist.covariance) == list(range(36)) + + +@pytest.mark.ros +def test_twist_with_covariance_stamped_ros_roundtrip() -> None: + """Test round-trip conversion with ROS messages.""" + ts = 2147483647.987654 # Max int32 value for ROS Time.sec + frame_id = "robot_base" + twist = Twist(Vector3(1.5, 2.5, 3.5), Vector3(0.15, 0.25, 0.35)) + covariance = np.random.rand(36) + + original = TwistWithCovarianceStamped( + ts=ts, frame_id=frame_id, twist=twist, covariance=covariance + ) + + ros_msg = original.to_ros_msg() + restored = TwistWithCovarianceStamped.from_ros_msg(ros_msg) + + # Check timestamp (loses some precision in conversion) + assert abs(restored.ts - ts) < 1e-6 + assert restored.frame_id == frame_id + + # Check twist + assert restored.twist.linear.x == original.twist.linear.x + assert restored.twist.linear.y == original.twist.linear.y + assert restored.twist.linear.z == original.twist.linear.z + assert restored.twist.angular.x == original.twist.angular.x + assert restored.twist.angular.y == original.twist.angular.y + assert restored.twist.angular.z == original.twist.angular.z + + # Check covariance + assert np.allclose(restored.covariance, original.covariance) + + +def test_twist_with_covariance_stamped_zero_timestamp() -> None: + """Test that zero timestamp gets replaced with current time.""" + twist_cov_stamped = TwistWithCovarianceStamped(ts=0.0) + + # Should have been replaced with current time + assert twist_cov_stamped.ts > 0 + assert twist_cov_stamped.ts <= time.time() + + +def test_twist_with_covariance_stamped_inheritance() -> None: + """Test that it properly inherits from TwistWithCovariance and Timestamped.""" + twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + covariance = np.eye(6).flatten() + twist_cov_stamped = TwistWithCovarianceStamped( + ts=1234567890.0, frame_id="test", twist=twist, covariance=covariance + ) + + # Should be instance of parent classes + assert isinstance(twist_cov_stamped, TwistWithCovariance) + + # Should have Timestamped attributes + assert hasattr(twist_cov_stamped, "ts") + assert hasattr(twist_cov_stamped, "frame_id") + + # Should have TwistWithCovariance attributes + assert hasattr(twist_cov_stamped, "twist") + assert hasattr(twist_cov_stamped, "covariance") + + +def test_twist_with_covariance_stamped_is_zero() -> None: + """Test is_zero method inheritance.""" + # Zero twist + twist_cov_stamped1 = TwistWithCovarianceStamped() + assert twist_cov_stamped1.is_zero() + assert not twist_cov_stamped1 # Boolean conversion + + # Non-zero twist + twist = Twist(Vector3(0.5, 0.0, 0.0), Vector3(0.0, 0.0, 0.0)) + twist_cov_stamped2 = TwistWithCovarianceStamped(twist=twist) + assert not twist_cov_stamped2.is_zero() + assert twist_cov_stamped2 # Boolean conversion + + +def test_twist_with_covariance_stamped_sec_nsec() -> None: + """Test the sec_nsec helper function.""" + from dimos.msgs.geometry_msgs.TwistWithCovarianceStamped import sec_nsec + + # Test integer seconds + s, ns = sec_nsec(1234567890.0) + assert s == 1234567890 + assert ns == 0 + + # Test fractional seconds + s, ns = sec_nsec(1234567890.123456789) + assert s == 1234567890 + assert abs(ns - 123456789) < 100 # Allow small rounding error + + # Test small fractional seconds + s, ns = sec_nsec(0.000000001) + assert s == 0 + assert ns == 1 + + # Test large timestamp + s, ns = sec_nsec(9999999999.999999999) + # Due to floating point precision, this might round to 10000000000 + assert s in [9999999999, 10000000000] + if s == 9999999999: + assert abs(ns - 999999999) < 10 + else: + assert ns == 0 + + +@pytest.mark.ros +@pytest.mark.parametrize( + "frame_id", + ["", "map", "odom", "base_link", "cmd_vel", "sensor/velocity/front"], +) +def test_twist_with_covariance_stamped_frame_ids(frame_id) -> None: + """Test various frame ID values.""" + twist_cov_stamped = TwistWithCovarianceStamped(frame_id=frame_id) + assert twist_cov_stamped.frame_id == frame_id + + # Test roundtrip through ROS + ros_msg = twist_cov_stamped.to_ros_msg() + assert ros_msg.header.frame_id == frame_id + + restored = TwistWithCovarianceStamped.from_ros_msg(ros_msg) + assert restored.frame_id == frame_id + + +def test_twist_with_covariance_stamped_different_covariances() -> None: + """Test with different covariance patterns.""" + twist = Twist(Vector3(1.0, 0.0, 0.0), Vector3(0.0, 0.0, 0.5)) + + # Zero covariance + zero_cov = np.zeros(36) + twist_cov1 = TwistWithCovarianceStamped(twist=twist, covariance=zero_cov) + assert np.all(twist_cov1.covariance == 0.0) + + # Identity covariance + identity_cov = np.eye(6).flatten() + twist_cov2 = TwistWithCovarianceStamped(twist=twist, covariance=identity_cov) + assert np.trace(twist_cov2.covariance_matrix) == 6.0 + + # Full covariance + full_cov = np.random.rand(36) + twist_cov3 = TwistWithCovarianceStamped(twist=twist, covariance=full_cov) + assert np.array_equal(twist_cov3.covariance, full_cov) diff --git a/dimos/msgs/geometry_msgs/test_Vector3.py b/dimos/msgs/geometry_msgs/test_Vector3.py new file mode 100644 index 0000000000..099e35eb19 --- /dev/null +++ b/dimos/msgs/geometry_msgs/test_Vector3.py @@ -0,0 +1,462 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest + +from dimos.msgs.geometry_msgs.Vector3 import Vector3 + + +def test_vector_default_init() -> None: + """Test that default initialization of Vector() has x,y,z components all zero.""" + v = Vector3() + assert v.x == 0.0 + assert v.y == 0.0 + assert v.z == 0.0 + assert len(v.data) == 3 + assert v.to_list() == [0.0, 0.0, 0.0] + assert v.is_zero() # Zero vector should be considered zero + + +def test_vector_specific_init() -> None: + """Test initialization with specific values and different input types.""" + + v1 = Vector3(1.0, 2.0) # 2D vector (now becomes 3D with z=0) + assert v1.x == 1.0 + assert v1.y == 2.0 + assert v1.z == 0.0 + + v2 = Vector3(3.0, 4.0, 5.0) # 3D vector + assert v2.x == 3.0 + assert v2.y == 4.0 + assert v2.z == 5.0 + + v3 = Vector3([6.0, 7.0, 8.0]) + assert v3.x == 6.0 + assert v3.y == 7.0 + assert v3.z == 8.0 + + v4 = Vector3((9.0, 10.0, 11.0)) + assert v4.x == 9.0 + assert v4.y == 10.0 + assert v4.z == 11.0 + + v5 = Vector3(np.array([12.0, 13.0, 14.0])) + assert v5.x == 12.0 + assert v5.y == 13.0 + assert v5.z == 14.0 + + original = Vector3([15.0, 16.0, 17.0]) + v6 = Vector3(original) + assert v6.x == 15.0 + assert v6.y == 16.0 + assert v6.z == 17.0 + + assert v6 is not original + assert v6 == original + + +def test_vector_addition() -> None: + """Test vector addition.""" + v1 = Vector3(1.0, 2.0, 3.0) + v2 = Vector3(4.0, 5.0, 6.0) + + v_add = v1 + v2 + assert v_add.x == 5.0 + assert v_add.y == 7.0 + assert v_add.z == 9.0 + + +def test_vector_subtraction() -> None: + """Test vector subtraction.""" + v1 = Vector3(1.0, 2.0, 3.0) + v2 = Vector3(4.0, 5.0, 6.0) + + v_sub = v2 - v1 + assert v_sub.x == 3.0 + assert v_sub.y == 3.0 + assert v_sub.z == 3.0 + + +def test_vector_scalar_multiplication() -> None: + """Test vector multiplication by a scalar.""" + v1 = Vector3(1.0, 2.0, 3.0) + + v_mul = v1 * 2.0 + assert v_mul.x == 2.0 + assert v_mul.y == 4.0 + assert v_mul.z == 6.0 + + # Test right multiplication + v_rmul = 2.0 * v1 + assert v_rmul.x == 2.0 + assert v_rmul.y == 4.0 + assert v_rmul.z == 6.0 + + +def test_vector_scalar_division() -> None: + """Test vector division by a scalar.""" + v2 = Vector3(4.0, 5.0, 6.0) + + v_div = v2 / 2.0 + assert v_div.x == 2.0 + assert v_div.y == 2.5 + assert v_div.z == 3.0 + + +def test_vector_dot_product() -> None: + """Test vector dot product.""" + v1 = Vector3(1.0, 2.0, 3.0) + v2 = Vector3(4.0, 5.0, 6.0) + + dot = v1.dot(v2) + assert dot == 32.0 + + +def test_vector_length() -> None: + """Test vector length calculation.""" + # 2D vector with length 5 (now 3D with z=0) + v1 = Vector3(3.0, 4.0) + assert v1.length() == 5.0 + + # 3D vector + v2 = Vector3(2.0, 3.0, 6.0) + assert v2.length() == pytest.approx(7.0, 0.001) + + # Test length_squared + assert v1.length_squared() == 25.0 + assert v2.length_squared() == 49.0 + + +def test_vector_normalize() -> None: + """Test vector normalization.""" + v = Vector3(2.0, 3.0, 6.0) + assert not v.is_zero() + + v_norm = v.normalize() + length = v.length() + expected_x = 2.0 / length + expected_y = 3.0 / length + expected_z = 6.0 / length + + assert np.isclose(v_norm.x, expected_x) + assert np.isclose(v_norm.y, expected_y) + assert np.isclose(v_norm.z, expected_z) + assert np.isclose(v_norm.length(), 1.0) + assert not v_norm.is_zero() + + # Test normalizing a zero vector + v_zero = Vector3(0.0, 0.0, 0.0) + assert v_zero.is_zero() + v_zero_norm = v_zero.normalize() + assert v_zero_norm.x == 0.0 + assert v_zero_norm.y == 0.0 + assert v_zero_norm.z == 0.0 + assert v_zero_norm.is_zero() + + +def test_vector_to_2d() -> None: + """Test conversion to 2D vector.""" + v = Vector3(2.0, 3.0, 6.0) + + v_2d = v.to_2d() + assert v_2d.x == 2.0 + assert v_2d.y == 3.0 + assert v_2d.z == 0.0 # z should be 0 for 2D conversion + + # Already 2D vector (z=0) + v2 = Vector3(4.0, 5.0) + v2_2d = v2.to_2d() + assert v2_2d.x == 4.0 + assert v2_2d.y == 5.0 + assert v2_2d.z == 0.0 + + +def test_vector_distance() -> None: + """Test distance calculations between vectors.""" + v1 = Vector3(1.0, 2.0, 3.0) + v2 = Vector3(4.0, 6.0, 8.0) + + # Distance + dist = v1.distance(v2) + expected_dist = np.sqrt(9.0 + 16.0 + 25.0) # sqrt((4-1)² + (6-2)² + (8-3)²) + assert dist == pytest.approx(expected_dist) + + # Distance squared + dist_sq = v1.distance_squared(v2) + assert dist_sq == 50.0 # 9 + 16 + 25 + + +def test_vector_cross_product() -> None: + """Test vector cross product.""" + v1 = Vector3(1.0, 0.0, 0.0) # Unit x vector + v2 = Vector3(0.0, 1.0, 0.0) # Unit y vector + + # v1 × v2 should be unit z vector + cross = v1.cross(v2) + assert cross.x == 0.0 + assert cross.y == 0.0 + assert cross.z == 1.0 + + # Test with more complex vectors + a = Vector3(2.0, 3.0, 4.0) + b = Vector3(5.0, 6.0, 7.0) + c = a.cross(b) + + # Cross product manually calculated: + # (3*7-4*6, 4*5-2*7, 2*6-3*5) + assert c.x == -3.0 + assert c.y == 6.0 + assert c.z == -3.0 + + # Test with vectors that have z=0 (still works as they're 3D) + v_2d1 = Vector3(1.0, 2.0) # (1, 2, 0) + v_2d2 = Vector3(3.0, 4.0) # (3, 4, 0) + cross_2d = v_2d1.cross(v_2d2) + # (2*0-0*4, 0*3-1*0, 1*4-2*3) = (0, 0, -2) + assert cross_2d.x == 0.0 + assert cross_2d.y == 0.0 + assert cross_2d.z == -2.0 + + +def test_vector_zeros() -> None: + """Test Vector3.zeros class method.""" + # 3D zero vector + v_zeros = Vector3.zeros() + assert v_zeros.x == 0.0 + assert v_zeros.y == 0.0 + assert v_zeros.z == 0.0 + assert v_zeros.is_zero() + + +def test_vector_ones() -> None: + """Test Vector3.ones class method.""" + # 3D ones vector + v_ones = Vector3.ones() + assert v_ones.x == 1.0 + assert v_ones.y == 1.0 + assert v_ones.z == 1.0 + + +def test_vector_conversion_methods() -> None: + """Test vector conversion methods (to_list, to_tuple, to_numpy).""" + v = Vector3(1.0, 2.0, 3.0) + + # to_list + assert v.to_list() == [1.0, 2.0, 3.0] + + # to_tuple + assert v.to_tuple() == (1.0, 2.0, 3.0) + + # to_numpy + np_array = v.to_numpy() + assert isinstance(np_array, np.ndarray) + assert np.array_equal(np_array, np.array([1.0, 2.0, 3.0])) + + +def test_vector_equality() -> None: + """Test vector equality.""" + v1 = Vector3(1, 2, 3) + v2 = Vector3(1, 2, 3) + v3 = Vector3(4, 5, 6) + + assert v1 == v2 + assert v1 != v3 + assert v1 != Vector3(1, 2) # Now (1, 2, 0) vs (1, 2, 3) + assert v1 != Vector3(1.1, 2, 3) # Different values + assert v1 != [1, 2, 3] + + +def test_vector_is_zero() -> None: + """Test is_zero method for vectors.""" + # Default zero vector + v0 = Vector3() + assert v0.is_zero() + + # Explicit zero vector + v1 = Vector3(0.0, 0.0, 0.0) + assert v1.is_zero() + + # Zero vector with different initialization (now always 3D) + v2 = Vector3(0.0, 0.0) # Becomes (0, 0, 0) + assert v2.is_zero() + + # Non-zero vectors + v3 = Vector3(1.0, 0.0, 0.0) + assert not v3.is_zero() + + v4 = Vector3(0.0, 2.0, 0.0) + assert not v4.is_zero() + + v5 = Vector3(0.0, 0.0, 3.0) + assert not v5.is_zero() + + # Almost zero (within tolerance) + v6 = Vector3(1e-10, 1e-10, 1e-10) + assert v6.is_zero() + + # Almost zero (outside tolerance) + v7 = Vector3(1e-6, 1e-6, 1e-6) + assert not v7.is_zero() + + +def test_vector_bool_conversion(): + """Test boolean conversion of vectors.""" + # Zero vectors should be False + v0 = Vector3() + assert not bool(v0) + + v1 = Vector3(0.0, 0.0, 0.0) + assert not bool(v1) + + # Almost zero vectors should be False + v2 = Vector3(1e-10, 1e-10, 1e-10) + assert not bool(v2) + + # Non-zero vectors should be True + v3 = Vector3(1.0, 0.0, 0.0) + assert bool(v3) + + v4 = Vector3(0.0, 2.0, 0.0) + assert bool(v4) + + v5 = Vector3(0.0, 0.0, 3.0) + assert bool(v5) + + # Direct use in if statements + if v0: + raise AssertionError("Zero vector should be False in boolean context") + else: + pass # Expected path + + if v3: + pass # Expected path + else: + raise AssertionError("Non-zero vector should be True in boolean context") + + +def test_vector_add() -> None: + """Test vector addition operator.""" + v1 = Vector3(1.0, 2.0, 3.0) + v2 = Vector3(4.0, 5.0, 6.0) + + # Using __add__ method + v_add = v1.__add__(v2) + assert v_add.x == 5.0 + assert v_add.y == 7.0 + assert v_add.z == 9.0 + + # Using + operator + v_add_op = v1 + v2 + assert v_add_op.x == 5.0 + assert v_add_op.y == 7.0 + assert v_add_op.z == 9.0 + + # Adding zero vector should return original vector + v_zero = Vector3.zeros() + assert (v1 + v_zero) == v1 + + +def test_vector_add_dim_mismatch() -> None: + """Test vector addition with different input dimensions (now all vectors are 3D).""" + v1 = Vector3(1.0, 2.0) # Becomes (1, 2, 0) + v2 = Vector3(4.0, 5.0, 6.0) # (4, 5, 6) + + # Using + operator - should work fine now since both are 3D + v_add_op = v1 + v2 + assert v_add_op.x == 5.0 # 1 + 4 + assert v_add_op.y == 7.0 # 2 + 5 + assert v_add_op.z == 6.0 # 0 + 6 + + +def test_yaw_pitch_roll_accessors() -> None: + """Test yaw, pitch, and roll accessor properties.""" + # Test with a 3D vector + v = Vector3(1.0, 2.0, 3.0) + + # According to standard convention: + # roll = rotation around x-axis = x component + # pitch = rotation around y-axis = y component + # yaw = rotation around z-axis = z component + assert v.roll == 1.0 # Should return x component + assert v.pitch == 2.0 # Should return y component + assert v.yaw == 3.0 # Should return z component + + # Test with a 2D vector (z should be 0.0) + v_2d = Vector3(4.0, 5.0) + assert v_2d.roll == 4.0 # Should return x component + assert v_2d.pitch == 5.0 # Should return y component + assert v_2d.yaw == 0.0 # Should return z component (defaults to 0 for 2D) + + # Test with empty vector (all should be 0.0) + v_empty = Vector3() + assert v_empty.roll == 0.0 + assert v_empty.pitch == 0.0 + assert v_empty.yaw == 0.0 + + # Test with negative values + v_neg = Vector3(-1.5, -2.5, -3.5) + assert v_neg.roll == -1.5 + assert v_neg.pitch == -2.5 + assert v_neg.yaw == -3.5 + + +def test_vector_to_quaternion() -> None: + """Test vector to quaternion conversion.""" + # Test with zero Euler angles (should produce identity quaternion) + v_zero = Vector3(0.0, 0.0, 0.0) + q_identity = v_zero.to_quaternion() + + # Identity quaternion should have w=1, x=y=z=0 + assert np.isclose(q_identity.x, 0.0, atol=1e-10) + assert np.isclose(q_identity.y, 0.0, atol=1e-10) + assert np.isclose(q_identity.z, 0.0, atol=1e-10) + assert np.isclose(q_identity.w, 1.0, atol=1e-10) + + # Test with small angles (to avoid gimbal lock issues) + v_small = Vector3(0.1, 0.2, 0.3) # Small roll, pitch, yaw + q_small = v_small.to_quaternion() + + # Quaternion should be normalized (magnitude = 1) + magnitude = np.sqrt(q_small.x**2 + q_small.y**2 + q_small.z**2 + q_small.w**2) + assert np.isclose(magnitude, 1.0, atol=1e-10) + + # Test conversion back to Euler (should be close to original) + v_back = q_small.to_euler() + assert np.isclose(v_back.x, 0.1, atol=1e-6) + assert np.isclose(v_back.y, 0.2, atol=1e-6) + assert np.isclose(v_back.z, 0.3, atol=1e-6) + + # Test with π/2 rotation around x-axis + v_x_90 = Vector3(np.pi / 2, 0.0, 0.0) + q_x_90 = v_x_90.to_quaternion() + + # Should be approximately (sin(π/4), 0, 0, cos(π/4)) = (√2/2, 0, 0, √2/2) + expected = np.sqrt(2) / 2 + assert np.isclose(q_x_90.x, expected, atol=1e-10) + assert np.isclose(q_x_90.y, 0.0, atol=1e-10) + assert np.isclose(q_x_90.z, 0.0, atol=1e-10) + assert np.isclose(q_x_90.w, expected, atol=1e-10) + + +def test_lcm_encode_decode() -> None: + v_source = Vector3(1.0, 2.0, 3.0) + + binary_msg = v_source.lcm_encode() + + v_dest = Vector3.lcm_decode(binary_msg) + + assert isinstance(v_dest, Vector3) + assert v_dest is not v_source + assert v_dest == v_source diff --git a/dimos/msgs/geometry_msgs/test_publish.py b/dimos/msgs/geometry_msgs/test_publish.py new file mode 100644 index 0000000000..b3d2324af0 --- /dev/null +++ b/dimos/msgs/geometry_msgs/test_publish.py @@ -0,0 +1,54 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import lcm +import pytest + +from dimos.msgs.geometry_msgs import Vector3 + + +@pytest.mark.tool +def test_runpublish() -> None: + for i in range(10): + msg = Vector3(-5 + i, -5 + i, i) + lc = lcm.LCM() + lc.publish("thing1_vector3#geometry_msgs.Vector3", msg.encode()) + time.sleep(0.1) + print(f"Published: {msg}") + + +@pytest.mark.tool +def test_receive() -> None: + lc = lcm.LCM() + + def receive(bla, msg) -> None: + # print("receive", bla, msg) + print(Vector3.decode(msg)) + + lc.subscribe("thing1_vector3#geometry_msgs.Vector3", receive) + + def _loop() -> None: + while True: + """LCM message handling loop""" + try: + lc.handle() + # loop 10000 times + for _ in range(10000000): + 3 + 3 # noqa: B018 + except Exception as e: + print(f"Error in LCM handling: {e}") + + _loop() diff --git a/dimos/msgs/nav_msgs/OccupancyGrid.py b/dimos/msgs/nav_msgs/OccupancyGrid.py new file mode 100644 index 0000000000..c0437ac36d --- /dev/null +++ b/dimos/msgs/nav_msgs/OccupancyGrid.py @@ -0,0 +1,611 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from enum import IntEnum +import time +from typing import TYPE_CHECKING, BinaryIO + +from dimos_lcm.nav_msgs import ( # type: ignore[import-untyped] + MapMetaData, + OccupancyGrid as LCMOccupancyGrid, +) +from dimos_lcm.std_msgs import Time as LCMTime # type: ignore[import-untyped] +import numpy as np +from scipy import ndimage + +from dimos.msgs.geometry_msgs import Pose, Vector3, VectorLike +from dimos.types.timestamped import Timestamped + +if TYPE_CHECKING: + from dimos.msgs.sensor_msgs import PointCloud2 + + +class CostValues(IntEnum): + """Standard cost values for occupancy grid cells. + + These values follow the ROS nav_msgs/OccupancyGrid convention: + - 0: Free space + - 1-99: Occupied space with varying cost levels + - 100: Lethal obstacle (definitely occupied) + - -1: Unknown space + """ + + UNKNOWN = -1 # Unknown space + FREE = 0 # Free space + OCCUPIED = 100 # Occupied/lethal space + + +class OccupancyGrid(Timestamped): + """ + Convenience wrapper for nav_msgs/OccupancyGrid with numpy array support. + """ + + msg_name = "nav_msgs.OccupancyGrid" + + # Attributes + ts: float + frame_id: str + info: MapMetaData + grid: np.ndarray # type: ignore[type-arg] + + def __init__( + self, + grid: np.ndarray | None = None, # type: ignore[type-arg] + width: int | None = None, + height: int | None = None, + resolution: float = 0.05, + origin: Pose | None = None, + frame_id: str = "world", + ts: float = 0.0, + ) -> None: + """Initialize OccupancyGrid. + + Args: + grid: 2D numpy array of int8 values (height x width) + width: Width in cells (used if grid is None) + height: Height in cells (used if grid is None) + resolution: Grid resolution in meters/cell + origin: Origin pose of the grid + frame_id: Reference frame + ts: Timestamp (defaults to current time if 0) + """ + + self.frame_id = frame_id + self.ts = ts if ts != 0 else time.time() + + if grid is not None: + # Initialize from numpy array + if grid.ndim != 2: + raise ValueError("Grid must be a 2D array") + height, width = grid.shape + self.info = MapMetaData( + map_load_time=self._to_lcm_time(), # type: ignore[no-untyped-call] + resolution=resolution, + width=width, + height=height, + origin=origin or Pose(), + ) + self.grid = grid.astype(np.int8) + elif width is not None and height is not None: + # Initialize with dimensions + self.info = MapMetaData( + map_load_time=self._to_lcm_time(), # type: ignore[no-untyped-call] + resolution=resolution, + width=width, + height=height, + origin=origin or Pose(), + ) + self.grid = np.full((height, width), -1, dtype=np.int8) + else: + # Initialize empty + self.info = MapMetaData(map_load_time=self._to_lcm_time()) # type: ignore[no-untyped-call] + self.grid = np.array([], dtype=np.int8) + + def _to_lcm_time(self): # type: ignore[no-untyped-def] + """Convert timestamp to LCM Time.""" + + s = int(self.ts) + return LCMTime(sec=s, nsec=int((self.ts - s) * 1_000_000_000)) + + @property + def width(self) -> int: + """Width of the grid in cells.""" + return self.info.width # type: ignore[no-any-return] + + @property + def height(self) -> int: + """Height of the grid in cells.""" + return self.info.height # type: ignore[no-any-return] + + @property + def resolution(self) -> float: + """Grid resolution in meters/cell.""" + return self.info.resolution # type: ignore[no-any-return] + + @property + def origin(self) -> Pose: + """Origin pose of the grid.""" + return self.info.origin # type: ignore[no-any-return] + + @property + def total_cells(self) -> int: + """Total number of cells in the grid.""" + return self.width * self.height + + @property + def occupied_cells(self) -> int: + """Number of occupied cells (value >= 1).""" + return int(np.sum(self.grid >= 1)) + + @property + def free_cells(self) -> int: + """Number of free cells (value == 0).""" + return int(np.sum(self.grid == 0)) + + @property + def unknown_cells(self) -> int: + """Number of unknown cells (value == -1).""" + return int(np.sum(self.grid == -1)) + + @property + def occupied_percent(self) -> float: + """Percentage of cells that are occupied.""" + return (self.occupied_cells / self.total_cells * 100) if self.total_cells > 0 else 0.0 + + @property + def free_percent(self) -> float: + """Percentage of cells that are free.""" + return (self.free_cells / self.total_cells * 100) if self.total_cells > 0 else 0.0 + + @property + def unknown_percent(self) -> float: + """Percentage of cells that are unknown.""" + return (self.unknown_cells / self.total_cells * 100) if self.total_cells > 0 else 0.0 + + def inflate(self, radius: float) -> OccupancyGrid: + """Inflate obstacles by a given radius (binary inflation). + Args: + radius: Inflation radius in meters + Returns: + New OccupancyGrid with inflated obstacles + """ + # Convert radius to grid cells + cell_radius = int(np.ceil(radius / self.resolution)) + + # Get grid as numpy array + grid_array = self.grid + + # Create circular kernel for binary inflation + 2 * cell_radius + 1 + y, x = np.ogrid[-cell_radius : cell_radius + 1, -cell_radius : cell_radius + 1] + kernel = (x**2 + y**2 <= cell_radius**2).astype(np.uint8) + + # Find occupied cells + occupied_mask = grid_array >= CostValues.OCCUPIED + + # Binary inflation + inflated = ndimage.binary_dilation(occupied_mask, structure=kernel) + result_grid = grid_array.copy() + result_grid[inflated] = CostValues.OCCUPIED + + # Create new OccupancyGrid with inflated data using numpy constructor + return OccupancyGrid( + grid=result_grid, + resolution=self.resolution, + origin=self.origin, + frame_id=self.frame_id, + ts=self.ts, + ) + + def world_to_grid(self, point: VectorLike) -> Vector3: + """Convert world coordinates to grid coordinates. + + Args: + point: A vector-like object containing X,Y coordinates + + Returns: + Vector3 with grid coordinates + """ + positionVector = Vector3(point) + # Get origin position + ox = self.origin.position.x + oy = self.origin.position.y + + # Convert to grid coordinates (simplified, assuming no rotation) + grid_x = (positionVector.x - ox) / self.resolution + grid_y = (positionVector.y - oy) / self.resolution + + return Vector3(grid_x, grid_y, 0.0) + + def grid_to_world(self, grid_point: VectorLike) -> Vector3: + """Convert grid coordinates to world coordinates. + + Args: + grid_point: Vector-like object containing grid coordinates + + Returns: + World position as Vector3 + """ + gridVector = Vector3(grid_point) + # Get origin position + ox = self.origin.position.x + oy = self.origin.position.y + + # Convert to world (simplified, no rotation) + x = ox + gridVector.x * self.resolution + y = oy + gridVector.y * self.resolution + + return Vector3(x, y, 0.0) + + def __str__(self) -> str: + """Create a concise string representation.""" + origin_pos = self.origin.position + + parts = [ + f"▦ OccupancyGrid[{self.frame_id}]", + f"{self.width}x{self.height}", + f"({self.width * self.resolution:.1f}x{self.height * self.resolution:.1f}m @", + f"{1 / self.resolution:.0f}cm res)", + f"Origin: ({origin_pos.x:.2f}, {origin_pos.y:.2f})", + f"▣ {self.occupied_percent:.1f}%", + f"□ {self.free_percent:.1f}%", + f"◌ {self.unknown_percent:.1f}%", + ] + + return " ".join(parts) + + def __repr__(self) -> str: + """Create a detailed representation.""" + return ( + f"OccupancyGrid(width={self.width}, height={self.height}, " + f"resolution={self.resolution}, frame_id='{self.frame_id}', " + f"occupied={self.occupied_cells}, free={self.free_cells}, " + f"unknown={self.unknown_cells})" + ) + + def lcm_encode(self) -> bytes: + """Encode OccupancyGrid to LCM bytes.""" + # Create LCM message + lcm_msg = LCMOccupancyGrid() + + # Build header on demand + s = int(self.ts) + lcm_msg.header.stamp.sec = s + lcm_msg.header.stamp.nsec = int((self.ts - s) * 1_000_000_000) + lcm_msg.header.frame_id = self.frame_id + + # Copy map metadata + lcm_msg.info = self.info + + # Convert numpy array to flat data list + if self.grid.size > 0: + flat_data = self.grid.flatten() + lcm_msg.data_length = len(flat_data) + lcm_msg.data = flat_data.tolist() + else: + lcm_msg.data_length = 0 + lcm_msg.data = [] + + return lcm_msg.lcm_encode() # type: ignore[no-any-return] + + @classmethod + def lcm_decode(cls, data: bytes | BinaryIO) -> OccupancyGrid: + """Decode LCM bytes to OccupancyGrid.""" + lcm_msg = LCMOccupancyGrid.lcm_decode(data) + + # Extract timestamp and frame_id from header + ts = lcm_msg.header.stamp.sec + (lcm_msg.header.stamp.nsec / 1_000_000_000) + frame_id = lcm_msg.header.frame_id + + # Extract grid data + if lcm_msg.data and lcm_msg.info.width > 0 and lcm_msg.info.height > 0: + grid = np.array(lcm_msg.data, dtype=np.int8).reshape( + (lcm_msg.info.height, lcm_msg.info.width) + ) + else: + grid = np.array([], dtype=np.int8) + + # Create new instance + instance = cls( + grid=grid, + resolution=lcm_msg.info.resolution, + origin=lcm_msg.info.origin, + frame_id=frame_id, + ts=ts, + ) + instance.info = lcm_msg.info + return instance + + @classmethod + def from_pointcloud( + cls, + cloud: PointCloud2, + resolution: float = 0.05, + min_height: float = 0.1, + max_height: float = 2.0, + frame_id: str | None = None, + mark_free_radius: float = 0.4, + ) -> OccupancyGrid: + """Create an OccupancyGrid from a PointCloud2 message. + + Args: + cloud: PointCloud2 message containing 3D points + resolution: Grid resolution in meters/cell (default: 0.05) + min_height: Minimum height threshold for including points (default: 0.1) + max_height: Maximum height threshold for including points (default: 2.0) + frame_id: Reference frame for the grid (default: uses cloud's frame_id) + mark_free_radius: Radius in meters around obstacles to mark as free space (default: 0.0) + If 0, only immediate neighbors are marked free. + Set to preserve unknown areas for exploration. + + Returns: + OccupancyGrid with occupied cells where points were projected + """ + + # Get points as numpy array + points = cloud.as_numpy() + + if len(points) == 0: + # Return empty grid + return cls( + width=1, height=1, resolution=resolution, frame_id=frame_id or cloud.frame_id + ) + + # Filter points by height for obstacles + obstacle_mask = (points[:, 2] >= min_height) & (points[:, 2] <= max_height) + obstacle_points = points[obstacle_mask] + + # Get points below min_height for marking as free space + ground_mask = points[:, 2] < min_height + ground_points = points[ground_mask] + + # Find bounds of the point cloud in X-Y plane (use all points) + if len(points) > 0: + min_x = np.min(points[:, 0]) + max_x = np.max(points[:, 0]) + min_y = np.min(points[:, 1]) + max_y = np.max(points[:, 1]) + else: + # Return empty grid if no points at all + return cls( + width=1, height=1, resolution=resolution, frame_id=frame_id or cloud.frame_id + ) + + # Add some padding around the bounds + padding = 1.0 # 1 meter padding + min_x -= padding + max_x += padding + min_y -= padding + max_y += padding + + # Calculate grid dimensions + width = int(np.ceil((max_x - min_x) / resolution)) + height = int(np.ceil((max_y - min_y) / resolution)) + + # Create origin pose (bottom-left corner of the grid) + origin = Pose() + origin.position.x = min_x + origin.position.y = min_y + origin.position.z = 0.0 + origin.orientation.w = 1.0 # No rotation + + # Initialize grid (all unknown) + grid = np.full((height, width), -1, dtype=np.int8) + + # First, mark ground points as free space + if len(ground_points) > 0: + ground_x = ((ground_points[:, 0] - min_x) / resolution).astype(np.int32) + ground_y = ((ground_points[:, 1] - min_y) / resolution).astype(np.int32) + + # Clip indices to grid bounds + ground_x = np.clip(ground_x, 0, width - 1) + ground_y = np.clip(ground_y, 0, height - 1) + + # Mark ground cells as free + grid[ground_y, ground_x] = 0 # Free space + + # Then mark obstacle points (will override ground if at same location) + if len(obstacle_points) > 0: + obs_x = ((obstacle_points[:, 0] - min_x) / resolution).astype(np.int32) + obs_y = ((obstacle_points[:, 1] - min_y) / resolution).astype(np.int32) + + # Clip indices to grid bounds + obs_x = np.clip(obs_x, 0, width - 1) + obs_y = np.clip(obs_y, 0, height - 1) + + # Mark cells as occupied + grid[obs_y, obs_x] = 100 # Lethal obstacle + + # Apply mark_free_radius to expand free space areas + if mark_free_radius > 0: + # Expand existing free space areas by the specified radius + # This will NOT expand from obstacles, only from free space + + free_mask = grid == 0 # Current free space + free_radius_cells = int(np.ceil(mark_free_radius / resolution)) + + # Create circular kernel + y, x = np.ogrid[ + -free_radius_cells : free_radius_cells + 1, + -free_radius_cells : free_radius_cells + 1, + ] + kernel = x**2 + y**2 <= free_radius_cells**2 + + # Dilate free space areas + expanded_free = ndimage.binary_dilation(free_mask, structure=kernel, iterations=1) + + # Mark expanded areas as free, but don't override obstacles + grid[expanded_free & (grid != 100)] = 0 + + # Create and return OccupancyGrid + # Get timestamp from cloud if available + ts = cloud.ts if hasattr(cloud, "ts") and cloud.ts is not None else 0.0 + + occupancy_grid = cls( + grid=grid, + resolution=resolution, + origin=origin, + frame_id=frame_id or cloud.frame_id, + ts=ts, + ) + + return occupancy_grid + + def gradient(self, obstacle_threshold: int = 50, max_distance: float = 2.0) -> OccupancyGrid: + """Create a gradient OccupancyGrid for path planning. + + Creates a gradient where free space has value 0 and values increase near obstacles. + This can be used as a cost map for path planning algorithms like A*. + + Args: + obstacle_threshold: Cell values >= this are considered obstacles (default: 50) + max_distance: Maximum distance to compute gradient in meters (default: 2.0) + + Returns: + New OccupancyGrid with gradient values: + - -1: Unknown cells (preserved as-is) + - 0: Free space far from obstacles + - 1-99: Increasing cost as you approach obstacles + - 100: At obstacles + + Note: Unknown cells remain as unknown (-1) and do not receive gradient values. + """ + + # Remember which cells are unknown + unknown_mask = self.grid == CostValues.UNKNOWN + + # Create binary obstacle map + # Consider cells >= threshold as obstacles (1), everything else as free (0) + # Unknown cells are not considered obstacles for distance calculation + obstacle_map = (self.grid >= obstacle_threshold).astype(np.float32) + + # Compute distance transform (distance to nearest obstacle in cells) + # Unknown cells are treated as if they don't exist for distance calculation + distance_cells = ndimage.distance_transform_edt(1 - obstacle_map) + + # Convert to meters and clip to max distance + distance_meters = np.clip(distance_cells * self.resolution, 0, max_distance) # type: ignore[operator] + + # Invert and scale to 0-100 range + # Far from obstacles (max_distance) -> 0 + # At obstacles (0 distance) -> 100 + gradient_values = (1 - distance_meters / max_distance) * 100 + + # Ensure obstacles are exactly 100 + gradient_values[obstacle_map > 0] = CostValues.OCCUPIED + + # Convert to int8 for OccupancyGrid + gradient_data = gradient_values.astype(np.int8) + + # Preserve unknown cells as unknown (don't apply gradient to them) + gradient_data[unknown_mask] = CostValues.UNKNOWN + + # Create new OccupancyGrid with gradient + gradient_grid = OccupancyGrid( + grid=gradient_data, + resolution=self.resolution, + origin=self.origin, + frame_id=self.frame_id, + ts=self.ts, + ) + + return gradient_grid + + def filter_above(self, threshold: int) -> OccupancyGrid: + """Create a new OccupancyGrid with only values above threshold. + + Args: + threshold: Keep cells with values > threshold + + Returns: + New OccupancyGrid where: + - Cells > threshold: kept as-is + - Cells <= threshold: set to -1 (unknown) + - Unknown cells (-1): preserved + """ + new_grid = self.grid.copy() + + # Create mask for cells to filter (not unknown and <= threshold) + filter_mask = (new_grid != -1) & (new_grid <= threshold) + + # Set filtered cells to unknown + new_grid[filter_mask] = -1 + + # Create new OccupancyGrid + filtered = OccupancyGrid( + new_grid, + resolution=self.resolution, + origin=self.origin, + frame_id=self.frame_id, + ts=self.ts, + ) + + return filtered + + def filter_below(self, threshold: int) -> OccupancyGrid: + """Create a new OccupancyGrid with only values below threshold. + + Args: + threshold: Keep cells with values < threshold + + Returns: + New OccupancyGrid where: + - Cells < threshold: kept as-is + - Cells >= threshold: set to -1 (unknown) + - Unknown cells (-1): preserved + """ + new_grid = self.grid.copy() + + # Create mask for cells to filter (not unknown and >= threshold) + filter_mask = (new_grid != -1) & (new_grid >= threshold) + + # Set filtered cells to unknown + new_grid[filter_mask] = -1 + + # Create new OccupancyGrid + filtered = OccupancyGrid( + new_grid, + resolution=self.resolution, + origin=self.origin, + frame_id=self.frame_id, + ts=self.ts, + ) + + return filtered + + def max(self) -> OccupancyGrid: + """Create a new OccupancyGrid with all non-unknown cells set to maximum value. + + Returns: + New OccupancyGrid where: + - All non-unknown cells: set to CostValues.OCCUPIED (100) + - Unknown cells: preserved as CostValues.UNKNOWN (-1) + """ + new_grid = self.grid.copy() + + # Set all non-unknown cells to max + new_grid[new_grid != CostValues.UNKNOWN] = CostValues.OCCUPIED + + # Create new OccupancyGrid + maxed = OccupancyGrid( + new_grid, + resolution=self.resolution, + origin=self.origin, + frame_id=self.frame_id, + ts=self.ts, + ) + + return maxed diff --git a/dimos/msgs/nav_msgs/Odometry.py b/dimos/msgs/nav_msgs/Odometry.py new file mode 100644 index 0000000000..d7297f1725 --- /dev/null +++ b/dimos/msgs/nav_msgs/Odometry.py @@ -0,0 +1,381 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import time +from typing import TYPE_CHECKING, TypeAlias + +from dimos_lcm.nav_msgs import Odometry as LCMOdometry # type: ignore[import-untyped] +import numpy as np +from plum import dispatch + +try: + from nav_msgs.msg import Odometry as ROSOdometry # type: ignore[attr-defined] +except ImportError: + ROSOdometry = None # type: ignore[assignment, misc] + +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.PoseWithCovariance import PoseWithCovariance +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.TwistWithCovariance import TwistWithCovariance +from dimos.types.timestamped import Timestamped + +if TYPE_CHECKING: + from dimos.msgs.geometry_msgs.Vector3 import Vector3 + +# Types that can be converted to/from Odometry +OdometryConvertable: TypeAlias = ( + LCMOdometry | dict[str, float | str | PoseWithCovariance | TwistWithCovariance | Pose | Twist] +) + + +def sec_nsec(ts): # type: ignore[no-untyped-def] + s = int(ts) + return [s, int((ts - s) * 1_000_000_000)] + + +class Odometry(LCMOdometry, Timestamped): # type: ignore[misc] + pose: PoseWithCovariance + twist: TwistWithCovariance + msg_name = "nav_msgs.Odometry" + ts: float + frame_id: str + child_frame_id: str + + @dispatch + def __init__( + self, + ts: float = 0.0, + frame_id: str = "", + child_frame_id: str = "", + pose: PoseWithCovariance | Pose | None = None, + twist: TwistWithCovariance | Twist | None = None, + ) -> None: + """Initialize with timestamp, frame IDs, pose and twist. + + Args: + ts: Timestamp in seconds (defaults to current time if 0) + frame_id: Reference frame ID (e.g., "odom", "map") + child_frame_id: Child frame ID (e.g., "base_link", "base_footprint") + pose: Pose with covariance (or just Pose, covariance will be zero) + twist: Twist with covariance (or just Twist, covariance will be zero) + """ + self.ts = ts if ts != 0 else time.time() + self.frame_id = frame_id + self.child_frame_id = child_frame_id + + # Handle pose + if pose is None: + self.pose = PoseWithCovariance() + elif isinstance(pose, PoseWithCovariance): + self.pose = pose + elif isinstance(pose, Pose): + self.pose = PoseWithCovariance(pose) + else: + self.pose = PoseWithCovariance(Pose(pose)) + + # Handle twist + if twist is None: + self.twist = TwistWithCovariance() + elif isinstance(twist, TwistWithCovariance): + self.twist = twist + elif isinstance(twist, Twist): + self.twist = TwistWithCovariance(twist) + else: + self.twist = TwistWithCovariance(Twist(twist)) + + @dispatch # type: ignore[no-redef] + def __init__(self, odometry: Odometry) -> None: + """Initialize from another Odometry (copy constructor).""" + self.ts = odometry.ts + self.frame_id = odometry.frame_id + self.child_frame_id = odometry.child_frame_id + self.pose = PoseWithCovariance(odometry.pose) + self.twist = TwistWithCovariance(odometry.twist) + + @dispatch # type: ignore[no-redef] + def __init__(self, lcm_odometry: LCMOdometry) -> None: + """Initialize from an LCM Odometry.""" + self.ts = lcm_odometry.header.stamp.sec + (lcm_odometry.header.stamp.nsec / 1_000_000_000) + self.frame_id = lcm_odometry.header.frame_id + self.child_frame_id = lcm_odometry.child_frame_id + self.pose = PoseWithCovariance(lcm_odometry.pose) + self.twist = TwistWithCovariance(lcm_odometry.twist) + + @dispatch # type: ignore[no-redef] + def __init__( + self, + odometry_dict: dict[ + str, float | str | PoseWithCovariance | TwistWithCovariance | Pose | Twist + ], + ) -> None: + """Initialize from a dictionary.""" + self.ts = odometry_dict.get("ts", odometry_dict.get("timestamp", time.time())) + self.frame_id = odometry_dict.get("frame_id", "") + self.child_frame_id = odometry_dict.get("child_frame_id", "") + + # Handle pose + pose = odometry_dict.get("pose") + if pose is None: + self.pose = PoseWithCovariance() + elif isinstance(pose, PoseWithCovariance): + self.pose = pose + elif isinstance(pose, Pose): + self.pose = PoseWithCovariance(pose) + else: + self.pose = PoseWithCovariance(Pose(pose)) + + # Handle twist + twist = odometry_dict.get("twist") + if twist is None: + self.twist = TwistWithCovariance() + elif isinstance(twist, TwistWithCovariance): + self.twist = twist + elif isinstance(twist, Twist): + self.twist = TwistWithCovariance(twist) + else: + self.twist = TwistWithCovariance(Twist(twist)) + + @property + def position(self) -> Vector3: + """Get position from pose.""" + return self.pose.position + + @property + def orientation(self): # type: ignore[no-untyped-def] + """Get orientation from pose.""" + return self.pose.orientation + + @property + def linear_velocity(self) -> Vector3: + """Get linear velocity from twist.""" + return self.twist.linear + + @property + def angular_velocity(self) -> Vector3: + """Get angular velocity from twist.""" + return self.twist.angular + + @property + def x(self) -> float: + """X position.""" + return self.pose.x + + @property + def y(self) -> float: + """Y position.""" + return self.pose.y + + @property + def z(self) -> float: + """Z position.""" + return self.pose.z + + @property + def vx(self) -> float: + """Linear velocity in X.""" + return self.twist.linear.x + + @property + def vy(self) -> float: + """Linear velocity in Y.""" + return self.twist.linear.y + + @property + def vz(self) -> float: + """Linear velocity in Z.""" + return self.twist.linear.z + + @property + def wx(self) -> float: + """Angular velocity around X (roll rate).""" + return self.twist.angular.x + + @property + def wy(self) -> float: + """Angular velocity around Y (pitch rate).""" + return self.twist.angular.y + + @property + def wz(self) -> float: + """Angular velocity around Z (yaw rate).""" + return self.twist.angular.z + + @property + def roll(self) -> float: + """Roll angle in radians.""" + return self.pose.roll + + @property + def pitch(self) -> float: + """Pitch angle in radians.""" + return self.pose.pitch + + @property + def yaw(self) -> float: + """Yaw angle in radians.""" + return self.pose.yaw + + def __repr__(self) -> str: + return ( + f"Odometry(ts={self.ts:.6f}, frame_id='{self.frame_id}', " + f"child_frame_id='{self.child_frame_id}', pose={self.pose!r}, twist={self.twist!r})" + ) + + def __str__(self) -> str: + return ( + f"Odometry:\n" + f" Timestamp: {self.ts:.6f}\n" + f" Frame: {self.frame_id} -> {self.child_frame_id}\n" + f" Position: [{self.x:.3f}, {self.y:.3f}, {self.z:.3f}]\n" + f" Orientation: [roll={self.roll:.3f}, pitch={self.pitch:.3f}, yaw={self.yaw:.3f}]\n" + f" Linear Velocity: [{self.vx:.3f}, {self.vy:.3f}, {self.vz:.3f}]\n" + f" Angular Velocity: [{self.wx:.3f}, {self.wy:.3f}, {self.wz:.3f}]" + ) + + def __eq__(self, other) -> bool: # type: ignore[no-untyped-def] + """Check if two Odometry messages are equal.""" + if not isinstance(other, Odometry): + return False + return ( + abs(self.ts - other.ts) < 1e-6 + and self.frame_id == other.frame_id + and self.child_frame_id == other.child_frame_id + and self.pose == other.pose + and self.twist == other.twist + ) + + def lcm_encode(self) -> bytes: + """Encode to LCM binary format.""" + lcm_msg = LCMOdometry() + + # Set header + [lcm_msg.header.stamp.sec, lcm_msg.header.stamp.nsec] = sec_nsec(self.ts) # type: ignore[no-untyped-call] + lcm_msg.header.frame_id = self.frame_id + lcm_msg.child_frame_id = self.child_frame_id + + # Set pose with covariance + lcm_msg.pose.pose = self.pose.pose + if isinstance(self.pose.covariance, np.ndarray): # type: ignore[has-type] + lcm_msg.pose.covariance = self.pose.covariance.tolist() # type: ignore[has-type] + else: + lcm_msg.pose.covariance = list(self.pose.covariance) # type: ignore[has-type] + + # Set twist with covariance + lcm_msg.twist.twist = self.twist.twist + if isinstance(self.twist.covariance, np.ndarray): # type: ignore[has-type] + lcm_msg.twist.covariance = self.twist.covariance.tolist() # type: ignore[has-type] + else: + lcm_msg.twist.covariance = list(self.twist.covariance) # type: ignore[has-type] + + return lcm_msg.lcm_encode() # type: ignore[no-any-return] + + @classmethod + def lcm_decode(cls, data: bytes) -> Odometry: + """Decode from LCM binary format.""" + lcm_msg = LCMOdometry.lcm_decode(data) + + # Extract timestamp + ts = lcm_msg.header.stamp.sec + (lcm_msg.header.stamp.nsec / 1_000_000_000) + + # Create pose with covariance + pose = Pose( + position=[ + lcm_msg.pose.pose.position.x, + lcm_msg.pose.pose.position.y, + lcm_msg.pose.pose.position.z, + ], + orientation=[ + lcm_msg.pose.pose.orientation.x, + lcm_msg.pose.pose.orientation.y, + lcm_msg.pose.pose.orientation.z, + lcm_msg.pose.pose.orientation.w, + ], + ) + pose_with_cov = PoseWithCovariance(pose, lcm_msg.pose.covariance) + + # Create twist with covariance + twist = Twist( + linear=[ + lcm_msg.twist.twist.linear.x, + lcm_msg.twist.twist.linear.y, + lcm_msg.twist.twist.linear.z, + ], + angular=[ + lcm_msg.twist.twist.angular.x, + lcm_msg.twist.twist.angular.y, + lcm_msg.twist.twist.angular.z, + ], + ) + twist_with_cov = TwistWithCovariance(twist, lcm_msg.twist.covariance) + + return cls( + ts=ts, + frame_id=lcm_msg.header.frame_id, + child_frame_id=lcm_msg.child_frame_id, + pose=pose_with_cov, + twist=twist_with_cov, + ) + + @classmethod + def from_ros_msg(cls, ros_msg: ROSOdometry) -> Odometry: + """Create an Odometry from a ROS nav_msgs/Odometry message. + + Args: + ros_msg: ROS Odometry message + + Returns: + Odometry instance + """ + + # Convert timestamp from ROS header + ts = ros_msg.header.stamp.sec + (ros_msg.header.stamp.nanosec / 1_000_000_000) + + # Convert pose and twist with covariance + pose_with_cov = PoseWithCovariance.from_ros_msg(ros_msg.pose) + twist_with_cov = TwistWithCovariance.from_ros_msg(ros_msg.twist) + + return cls( + ts=ts, + frame_id=ros_msg.header.frame_id, + child_frame_id=ros_msg.child_frame_id, + pose=pose_with_cov, + twist=twist_with_cov, + ) + + def to_ros_msg(self) -> ROSOdometry: + """Convert to a ROS nav_msgs/Odometry message. + + Returns: + ROS Odometry message + """ + + ros_msg = ROSOdometry() # type: ignore[no-untyped-call] + + # Set header + ros_msg.header.frame_id = self.frame_id + ros_msg.header.stamp.sec = int(self.ts) + ros_msg.header.stamp.nanosec = int((self.ts - int(self.ts)) * 1_000_000_000) + + # Set child frame ID + ros_msg.child_frame_id = self.child_frame_id + + # Set pose with covariance + ros_msg.pose = self.pose.to_ros_msg() + + # Set twist with covariance + ros_msg.twist = self.twist.to_ros_msg() + + return ros_msg diff --git a/dimos/msgs/nav_msgs/Path.py b/dimos/msgs/nav_msgs/Path.py new file mode 100644 index 0000000000..61e5434b1e --- /dev/null +++ b/dimos/msgs/nav_msgs/Path.py @@ -0,0 +1,233 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import time +from typing import TYPE_CHECKING, BinaryIO + +from dimos_lcm.geometry_msgs import ( # type: ignore[import-untyped] + Point as LCMPoint, + Pose as LCMPose, + PoseStamped as LCMPoseStamped, + Quaternion as LCMQuaternion, +) +from dimos_lcm.nav_msgs import Path as LCMPath # type: ignore[import-untyped] +from dimos_lcm.std_msgs import Header as LCMHeader, Time as LCMTime # type: ignore[import-untyped] + +try: + from nav_msgs.msg import Path as ROSPath # type: ignore[attr-defined] +except ImportError: + ROSPath = None # type: ignore[assignment, misc] + +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.types.timestamped import Timestamped + +if TYPE_CHECKING: + from collections.abc import Iterator + + +def sec_nsec(ts): # type: ignore[no-untyped-def] + s = int(ts) + return [s, int((ts - s) * 1_000_000_000)] + + +class Path(Timestamped): + msg_name = "nav_msgs.Path" + ts: float + frame_id: str + poses: list[PoseStamped] + + def __init__( # type: ignore[no-untyped-def] + self, + ts: float = 0.0, + frame_id: str = "world", + poses: list[PoseStamped] | None = None, + **kwargs, + ) -> None: + self.frame_id = frame_id + self.ts = ts if ts != 0 else time.time() + self.poses = poses if poses is not None else [] + + def __len__(self) -> int: + """Return the number of poses in the path.""" + return len(self.poses) + + def __bool__(self) -> bool: + """Return True if path has poses.""" + return len(self.poses) > 0 + + def head(self) -> PoseStamped | None: + """Return the first pose in the path, or None if empty.""" + return self.poses[0] if self.poses else None + + def last(self) -> PoseStamped | None: + """Return the last pose in the path, or None if empty.""" + return self.poses[-1] if self.poses else None + + def tail(self) -> Path: + """Return a new Path with all poses except the first.""" + return Path(ts=self.ts, frame_id=self.frame_id, poses=self.poses[1:] if self.poses else []) + + def push(self, pose: PoseStamped) -> Path: + """Return a new Path with the pose appended (immutable).""" + return Path(ts=self.ts, frame_id=self.frame_id, poses=[*self.poses, pose]) + + def push_mut(self, pose: PoseStamped) -> None: + """Append a pose to this path (mutable).""" + self.poses.append(pose) + + def lcm_encode(self) -> bytes: + """Encode Path to LCM bytes.""" + lcm_msg = LCMPath() + + # Set poses + lcm_msg.poses_length = len(self.poses) + lcm_poses = [] # Build list separately to avoid LCM library reuse issues + for pose in self.poses: + lcm_pose = LCMPoseStamped() + # Create new pose objects to avoid LCM library reuse bug + lcm_pose.pose = LCMPose() + lcm_pose.pose.position = LCMPoint() + lcm_pose.pose.orientation = LCMQuaternion() + + # Set the pose geometry data + lcm_pose.pose.position.x = pose.x + lcm_pose.pose.position.y = pose.y + lcm_pose.pose.position.z = pose.z + lcm_pose.pose.orientation.x = pose.orientation.x + lcm_pose.pose.orientation.y = pose.orientation.y + lcm_pose.pose.orientation.z = pose.orientation.z + lcm_pose.pose.orientation.w = pose.orientation.w + + # Create new header to avoid reuse + lcm_pose.header = LCMHeader() + lcm_pose.header.stamp = LCMTime() + + # Set the header with pose timestamp but path's frame_id + [lcm_pose.header.stamp.sec, lcm_pose.header.stamp.nsec] = sec_nsec(pose.ts) # type: ignore[no-untyped-call] + lcm_pose.header.frame_id = self.frame_id # All poses use path's frame_id + lcm_poses.append(lcm_pose) + lcm_msg.poses = lcm_poses + + # Set header with path's own timestamp + [lcm_msg.header.stamp.sec, lcm_msg.header.stamp.nsec] = sec_nsec(self.ts) # type: ignore[no-untyped-call] + lcm_msg.header.frame_id = self.frame_id + + return lcm_msg.lcm_encode() # type: ignore[no-any-return] + + @classmethod + def lcm_decode(cls, data: bytes | BinaryIO) -> Path: + """Decode LCM bytes to Path.""" + lcm_msg = LCMPath.lcm_decode(data) + + # Decode header + header_ts = lcm_msg.header.stamp.sec + (lcm_msg.header.stamp.nsec / 1_000_000_000) + frame_id = lcm_msg.header.frame_id + + # Decode poses - all use the path's frame_id + poses = [] + for lcm_pose in lcm_msg.poses: + pose = PoseStamped( + ts=lcm_pose.header.stamp.sec + (lcm_pose.header.stamp.nsec / 1_000_000_000), + frame_id=frame_id, # Use path's frame_id for all poses + position=[ + lcm_pose.pose.position.x, + lcm_pose.pose.position.y, + lcm_pose.pose.position.z, + ], + orientation=[ + lcm_pose.pose.orientation.x, + lcm_pose.pose.orientation.y, + lcm_pose.pose.orientation.z, + lcm_pose.pose.orientation.w, + ], + ) + poses.append(pose) + + # Use header timestamp for the path + return cls(ts=header_ts, frame_id=frame_id, poses=poses) + + def __str__(self) -> str: + """String representation of Path.""" + return f"Path(frame_id='{self.frame_id}', poses={len(self.poses)})" + + def __getitem__(self, index: int | slice) -> PoseStamped | list[PoseStamped]: + """Allow indexing and slicing of poses.""" + return self.poses[index] + + def __iter__(self) -> Iterator: # type: ignore[type-arg] + """Allow iteration over poses.""" + return iter(self.poses) + + def slice(self, start: int, end: int | None = None) -> Path: + """Return a new Path with a slice of poses.""" + return Path(ts=self.ts, frame_id=self.frame_id, poses=self.poses[start:end]) + + def extend(self, other: Path) -> Path: + """Return a new Path with poses from both paths (immutable).""" + return Path(ts=self.ts, frame_id=self.frame_id, poses=self.poses + other.poses) + + def extend_mut(self, other: Path) -> None: + """Extend this path with poses from another path (mutable).""" + self.poses.extend(other.poses) + + def reverse(self) -> Path: + """Return a new Path with poses in reverse order.""" + return Path(ts=self.ts, frame_id=self.frame_id, poses=list(reversed(self.poses))) + + def clear(self) -> None: + """Clear all poses from this path (mutable).""" + self.poses.clear() + + @classmethod + def from_ros_msg(cls, ros_msg: ROSPath) -> Path: + """Create a Path from a ROS nav_msgs/Path message. + + Args: + ros_msg: ROS Path message + + Returns: + Path instance + """ + + # Convert timestamp from ROS header + ts = ros_msg.header.stamp.sec + (ros_msg.header.stamp.nanosec / 1_000_000_000) + + # Convert poses + poses = [] + for ros_pose_stamped in ros_msg.poses: + poses.append(PoseStamped.from_ros_msg(ros_pose_stamped)) + + return cls(ts=ts, frame_id=ros_msg.header.frame_id, poses=poses) + + def to_ros_msg(self) -> ROSPath: + """Convert to a ROS nav_msgs/Path message. + + Returns: + ROS Path message + """ + + ros_msg = ROSPath() # type: ignore[no-untyped-call] + + # Set header + ros_msg.header.frame_id = self.frame_id + ros_msg.header.stamp.sec = int(self.ts) + ros_msg.header.stamp.nanosec = int((self.ts - int(self.ts)) * 1_000_000_000) + + # Convert poses + for pose in self.poses: + ros_msg.poses.append(pose.to_ros_msg()) + + return ros_msg diff --git a/dimos/msgs/nav_msgs/__init__.py b/dimos/msgs/nav_msgs/__init__.py new file mode 100644 index 0000000000..9d099068ad --- /dev/null +++ b/dimos/msgs/nav_msgs/__init__.py @@ -0,0 +1,9 @@ +from dimos.msgs.nav_msgs.OccupancyGrid import ( # type: ignore[attr-defined] + CostValues, + MapMetaData, + OccupancyGrid, +) +from dimos.msgs.nav_msgs.Odometry import Odometry +from dimos.msgs.nav_msgs.Path import Path + +__all__ = ["CostValues", "MapMetaData", "OccupancyGrid", "Odometry", "Path"] diff --git a/dimos/msgs/nav_msgs/test_OccupancyGrid.py b/dimos/msgs/nav_msgs/test_OccupancyGrid.py new file mode 100644 index 0000000000..3ecd758f47 --- /dev/null +++ b/dimos/msgs/nav_msgs/test_OccupancyGrid.py @@ -0,0 +1,471 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the OccupancyGrid convenience class.""" + +import pickle + +import numpy as np +import pytest + +from dimos.msgs.geometry_msgs import Pose +from dimos.msgs.nav_msgs import OccupancyGrid +from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.protocol.pubsub.lcmpubsub import LCM, Topic +from dimos.utils.testing import get_data + + +def test_empty_grid() -> None: + """Test creating an empty grid.""" + grid = OccupancyGrid() + assert grid.width == 0 + assert grid.height == 0 + assert grid.grid.shape == (0,) + assert grid.total_cells == 0 + assert grid.frame_id == "world" + + +def test_grid_with_dimensions() -> None: + """Test creating a grid with specified dimensions.""" + grid = OccupancyGrid(width=10, height=10, resolution=0.1, frame_id="map") + assert grid.width == 10 + assert grid.height == 10 + assert grid.resolution == 0.1 + assert grid.frame_id == "map" + assert grid.grid.shape == (10, 10) + assert np.all(grid.grid == -1) # All unknown + assert grid.unknown_cells == 100 + assert grid.unknown_percent == 100.0 + + +def test_grid_from_numpy_array() -> None: + """Test creating a grid from a numpy array.""" + data = np.zeros((20, 30), dtype=np.int8) + data[5:10, 10:20] = 100 # Add some obstacles + data[15:18, 5:8] = -1 # Add unknown area + + origin = Pose(1.0, 2.0, 0.0) + grid = OccupancyGrid(grid=data, resolution=0.05, origin=origin, frame_id="odom") + + assert grid.width == 30 + assert grid.height == 20 + assert grid.resolution == 0.05 + assert grid.frame_id == "odom" + assert grid.origin.position.x == 1.0 + assert grid.origin.position.y == 2.0 + assert grid.grid.shape == (20, 30) + + # Check cell counts + assert grid.occupied_cells == 50 # 5x10 obstacle area + assert grid.free_cells == 541 # Total - occupied - unknown + assert grid.unknown_cells == 9 # 3x3 unknown area + + # Check percentages (approximately) + assert abs(grid.occupied_percent - 8.33) < 0.1 + assert abs(grid.free_percent - 90.17) < 0.1 + assert abs(grid.unknown_percent - 1.5) < 0.1 + + +def test_world_grid_coordinate_conversion() -> None: + """Test converting between world and grid coordinates.""" + data = np.zeros((20, 30), dtype=np.int8) + origin = Pose(1.0, 2.0, 0.0) + grid = OccupancyGrid(grid=data, resolution=0.05, origin=origin, frame_id="odom") + + # Test world to grid + grid_pos = grid.world_to_grid((2.5, 3.0)) + assert int(grid_pos.x) == 30 + assert int(grid_pos.y) == 20 + + # Test grid to world + world_pos = grid.grid_to_world((10, 5)) + assert world_pos.x == 1.5 + assert world_pos.y == 2.25 + + +def test_lcm_encode_decode() -> None: + """Test LCM encoding and decoding.""" + data = np.zeros((20, 30), dtype=np.int8) + data[5:10, 10:20] = 100 # Add some obstacles + data[15:18, 5:8] = -1 # Add unknown area + origin = Pose(1.0, 2.0, 0.0) + grid = OccupancyGrid(grid=data, resolution=0.05, origin=origin, frame_id="odom") + + # Set a specific value for testing + # Convert world coordinates to grid indices + grid_pos = grid.world_to_grid((1.5, 2.25)) + grid.grid[int(grid_pos.y), int(grid_pos.x)] = 50 + + # Encode + lcm_data = grid.lcm_encode() + assert isinstance(lcm_data, bytes) + assert len(lcm_data) > 0 + + # Decode + decoded = OccupancyGrid.lcm_decode(lcm_data) + + # Check that data matches exactly (grid arrays should be identical) + assert np.array_equal(grid.grid, decoded.grid) + assert grid.width == decoded.width + assert grid.height == decoded.height + assert abs(grid.resolution - decoded.resolution) < 1e-6 # Use approximate equality for floats + assert abs(grid.origin.position.x - decoded.origin.position.x) < 1e-6 + assert abs(grid.origin.position.y - decoded.origin.position.y) < 1e-6 + assert grid.frame_id == decoded.frame_id + + # Check that the actual grid data was preserved (don't rely on float conversions) + assert decoded.grid[5, 10] == 50 # Value we set should be preserved in grid + + +def test_string_representation() -> None: + """Test string representations.""" + grid = OccupancyGrid(width=10, height=10, resolution=0.1, frame_id="map") + + # Test __str__ + str_repr = str(grid) + assert "OccupancyGrid[map]" in str_repr + assert "10x10" in str_repr + assert "1.0x1.0m" in str_repr + assert "10cm res" in str_repr + + # Test __repr__ + repr_str = repr(grid) + assert "OccupancyGrid(" in repr_str + assert "width=10" in repr_str + assert "height=10" in repr_str + assert "resolution=0.1" in repr_str + + +def test_grid_property_sync() -> None: + """Test that the grid property works correctly.""" + grid = OccupancyGrid(width=5, height=5, resolution=0.1, frame_id="map") + + # Modify via numpy array + grid.grid[2, 3] = 100 + assert grid.grid[2, 3] == 100 + + # Check that we can access grid values + grid.grid[0, 0] = 50 + assert grid.grid[0, 0] == 50 + + +def test_invalid_grid_dimensions() -> None: + """Test handling of invalid grid dimensions.""" + # Test with non-2D array + with pytest.raises(ValueError, match="Grid must be a 2D array"): + OccupancyGrid(grid=np.zeros(10), resolution=0.1) + + +def test_from_pointcloud() -> None: + """Test creating OccupancyGrid from PointCloud2.""" + file_path = get_data("lcm_msgs") / "sensor_msgs/PointCloud2.pickle" + with open(file_path, "rb") as f: + lcm_msg = pickle.loads(f.read()) + + pointcloud = PointCloud2.lcm_decode(lcm_msg) + + # Convert pointcloud to occupancy grid + occupancygrid = OccupancyGrid.from_pointcloud( + pointcloud, resolution=0.05, min_height=0.1, max_height=2.0 + ) + # Apply inflation separately if needed + occupancygrid = occupancygrid.inflate(0.1) + + # Check that grid was created with reasonable properties + assert occupancygrid.width > 0 + assert occupancygrid.height > 0 + assert occupancygrid.resolution == 0.05 + assert occupancygrid.frame_id == pointcloud.frame_id + assert occupancygrid.occupied_cells > 0 # Should have some occupied cells + + +def test_gradient() -> None: + """Test converting occupancy grid to gradient field.""" + # Create a small test grid with an obstacle in the middle + data = np.zeros((10, 10), dtype=np.int8) + data[4:6, 4:6] = 100 # 2x2 obstacle in center + + grid = OccupancyGrid(grid=data, resolution=0.1) # 0.1m per cell + + # Convert to gradient + gradient_grid = grid.gradient(obstacle_threshold=50, max_distance=1.0) + + # Check that we get an OccupancyGrid back + assert isinstance(gradient_grid, OccupancyGrid) + assert gradient_grid.grid.shape == (10, 10) + assert gradient_grid.resolution == grid.resolution + assert gradient_grid.frame_id == grid.frame_id + + # Obstacle cells should have value 100 + assert gradient_grid.grid[4, 4] == 100 + assert gradient_grid.grid[5, 5] == 100 + + # Adjacent cells should have high values (near obstacles) + assert gradient_grid.grid[3, 4] > 85 # Very close to obstacle + assert gradient_grid.grid[4, 3] > 85 # Very close to obstacle + + # Cells at moderate distance should have moderate values + assert 30 < gradient_grid.grid[0, 0] < 60 # Corner is ~0.57m away + + # Check that gradient decreases with distance + assert gradient_grid.grid[3, 4] > gradient_grid.grid[2, 4] # Closer is higher + assert gradient_grid.grid[2, 4] > gradient_grid.grid[0, 4] # Further is lower + + # Test with unknown cells + data_with_unknown = data.copy() + data_with_unknown[0:2, 0:2] = -1 # Add unknown area (close to obstacle) + data_with_unknown[8:10, 8:10] = -1 # Add unknown area (far from obstacle) + + grid_with_unknown = OccupancyGrid(data_with_unknown, resolution=0.1) + gradient_with_unknown = grid_with_unknown.gradient(max_distance=1.0) # 1m max distance + + # Unknown cells should remain unknown (new behavior - unknowns are preserved) + assert gradient_with_unknown.grid[0, 0] == -1 # Should remain unknown + assert gradient_with_unknown.grid[1, 1] == -1 # Should remain unknown + assert gradient_with_unknown.grid[8, 8] == -1 # Should remain unknown + assert gradient_with_unknown.grid[9, 9] == -1 # Should remain unknown + + # Unknown cells count should be preserved + assert gradient_with_unknown.unknown_cells == 8 # All unknowns preserved + + +def test_filter_above() -> None: + """Test filtering cells above threshold.""" + # Create test grid with various values + data = np.array( + [[-1, 0, 20, 50], [10, 30, 60, 80], [40, 70, 90, 100], [-1, 15, 25, -1]], dtype=np.int8 + ) + + grid = OccupancyGrid(grid=data, resolution=0.1) + + # Filter to keep only values > 50 + filtered = grid.filter_above(50) + + # Check that values > 50 are preserved + assert filtered.grid[1, 2] == 60 + assert filtered.grid[1, 3] == 80 + assert filtered.grid[2, 1] == 70 + assert filtered.grid[2, 2] == 90 + assert filtered.grid[2, 3] == 100 + + # Check that values <= 50 are set to -1 (unknown) + assert filtered.grid[0, 1] == -1 # was 0 + assert filtered.grid[0, 2] == -1 # was 20 + assert filtered.grid[0, 3] == -1 # was 50 + assert filtered.grid[1, 0] == -1 # was 10 + assert filtered.grid[1, 1] == -1 # was 30 + assert filtered.grid[2, 0] == -1 # was 40 + + # Check that unknown cells are preserved + assert filtered.grid[0, 0] == -1 + assert filtered.grid[3, 0] == -1 + assert filtered.grid[3, 3] == -1 + + # Check dimensions and metadata preserved + assert filtered.width == grid.width + assert filtered.height == grid.height + assert filtered.resolution == grid.resolution + assert filtered.frame_id == grid.frame_id + + +def test_filter_below() -> None: + """Test filtering cells below threshold.""" + # Create test grid with various values + data = np.array( + [[-1, 0, 20, 50], [10, 30, 60, 80], [40, 70, 90, 100], [-1, 15, 25, -1]], dtype=np.int8 + ) + + grid = OccupancyGrid(grid=data, resolution=0.1) + + # Filter to keep only values < 50 + filtered = grid.filter_below(50) + + # Check that values < 50 are preserved + assert filtered.grid[0, 1] == 0 + assert filtered.grid[0, 2] == 20 + assert filtered.grid[1, 0] == 10 + assert filtered.grid[1, 1] == 30 + assert filtered.grid[2, 0] == 40 + assert filtered.grid[3, 1] == 15 + assert filtered.grid[3, 2] == 25 + + # Check that values >= 50 are set to -1 (unknown) + assert filtered.grid[0, 3] == -1 # was 50 + assert filtered.grid[1, 2] == -1 # was 60 + assert filtered.grid[1, 3] == -1 # was 80 + assert filtered.grid[2, 1] == -1 # was 70 + assert filtered.grid[2, 2] == -1 # was 90 + assert filtered.grid[2, 3] == -1 # was 100 + + # Check that unknown cells are preserved + assert filtered.grid[0, 0] == -1 + assert filtered.grid[3, 0] == -1 + assert filtered.grid[3, 3] == -1 + + # Check dimensions and metadata preserved + assert filtered.width == grid.width + assert filtered.height == grid.height + assert filtered.resolution == grid.resolution + assert filtered.frame_id == grid.frame_id + + +def test_max() -> None: + """Test setting all non-unknown cells to maximum.""" + # Create test grid with various values + data = np.array( + [[-1, 0, 20, 50], [10, 30, 60, 80], [40, 70, 90, 100], [-1, 15, 25, -1]], dtype=np.int8 + ) + + grid = OccupancyGrid(grid=data, resolution=0.1) + + # Apply max + maxed = grid.max() + + # Check that all non-unknown cells are set to 100 + assert maxed.grid[0, 1] == 100 # was 0 + assert maxed.grid[0, 2] == 100 # was 20 + assert maxed.grid[0, 3] == 100 # was 50 + assert maxed.grid[1, 0] == 100 # was 10 + assert maxed.grid[1, 1] == 100 # was 30 + assert maxed.grid[1, 2] == 100 # was 60 + assert maxed.grid[1, 3] == 100 # was 80 + assert maxed.grid[2, 0] == 100 # was 40 + assert maxed.grid[2, 1] == 100 # was 70 + assert maxed.grid[2, 2] == 100 # was 90 + assert maxed.grid[2, 3] == 100 # was 100 (already max) + assert maxed.grid[3, 1] == 100 # was 15 + assert maxed.grid[3, 2] == 100 # was 25 + + # Check that unknown cells are preserved + assert maxed.grid[0, 0] == -1 + assert maxed.grid[3, 0] == -1 + assert maxed.grid[3, 3] == -1 + + # Check dimensions and metadata preserved + assert maxed.width == grid.width + assert maxed.height == grid.height + assert maxed.resolution == grid.resolution + assert maxed.frame_id == grid.frame_id + + # Verify statistics + assert maxed.unknown_cells == 3 # Same as original + assert maxed.occupied_cells == 13 # All non-unknown cells + assert maxed.free_cells == 0 # No free cells + + +@pytest.mark.lcm +def test_lcm_broadcast() -> None: + """Test broadcasting OccupancyGrid and gradient over LCM.""" + file_path = get_data("lcm_msgs") / "sensor_msgs/PointCloud2.pickle" + with open(file_path, "rb") as f: + lcm_msg = pickle.loads(f.read()) + + pointcloud = PointCloud2.lcm_decode(lcm_msg) + + # Create occupancy grid from pointcloud + occupancygrid = OccupancyGrid.from_pointcloud( + pointcloud, resolution=0.05, min_height=0.1, max_height=2.0 + ) + # Apply inflation separately if needed + occupancygrid = occupancygrid.inflate(0.1) + + # Create gradient field with larger max_distance for better visualization + gradient_grid = occupancygrid.gradient(obstacle_threshold=70, max_distance=2.0) + + # Debug: Print actual values to see the difference + print("\n=== DEBUG: Comparing grids ===") + print(f"Original grid unique values: {np.unique(occupancygrid.grid)}") + print(f"Gradient grid unique values: {np.unique(gradient_grid.grid)}") + + # Find an area with occupied cells to show the difference + occupied_indices = np.argwhere(occupancygrid.grid == 100) + if len(occupied_indices) > 0: + # Pick a point near an occupied cell + idx = len(occupied_indices) // 2 # Middle occupied cell + sample_y, sample_x = occupied_indices[idx] + sample_size = 15 + + # Ensure we don't go out of bounds + y_start = max(0, sample_y - sample_size // 2) + y_end = min(occupancygrid.height, y_start + sample_size) + x_start = max(0, sample_x - sample_size // 2) + x_end = min(occupancygrid.width, x_start + sample_size) + + print(f"\nSample area around occupied cell ({sample_x}, {sample_y}):") + print("Original occupancy grid:") + print(occupancygrid.grid[y_start:y_end, x_start:x_end]) + print("\nGradient grid (same area):") + print(gradient_grid.grid[y_start:y_end, x_start:x_end]) + else: + print("\nNo occupied cells found for sampling") + + # Check statistics + print("\nOriginal grid stats:") + print(f" Occupied (100): {np.sum(occupancygrid.grid == 100)} cells") + print(f" Inflated (99): {np.sum(occupancygrid.grid == 99)} cells") + print(f" Free (0): {np.sum(occupancygrid.grid == 0)} cells") + print(f" Unknown (-1): {np.sum(occupancygrid.grid == -1)} cells") + + print("\nGradient grid stats:") + print(f" Max gradient (100): {np.sum(gradient_grid.grid == 100)} cells") + print( + f" High gradient (80-99): {np.sum((gradient_grid.grid >= 80) & (gradient_grid.grid < 100))} cells" + ) + print( + f" Medium gradient (40-79): {np.sum((gradient_grid.grid >= 40) & (gradient_grid.grid < 80))} cells" + ) + print( + f" Low gradient (1-39): {np.sum((gradient_grid.grid >= 1) & (gradient_grid.grid < 40))} cells" + ) + print(f" Zero gradient (0): {np.sum(gradient_grid.grid == 0)} cells") + print(f" Unknown (-1): {np.sum(gradient_grid.grid == -1)} cells") + + # # Save debug images + # import matplotlib.pyplot as plt + + # fig, axes = plt.subplots(1, 2, figsize=(12, 5)) + + # # Original + # ax = axes[0] + # im1 = ax.imshow(occupancygrid.grid, origin="lower", cmap="gray_r", vmin=-1, vmax=100) + # ax.set_title(f"Original Occupancy Grid\n{occupancygrid}") + # plt.colorbar(im1, ax=ax) + + # # Gradient + # ax = axes[1] + # im2 = ax.imshow(gradient_grid.grid, origin="lower", cmap="hot", vmin=-1, vmax=100) + # ax.set_title(f"Gradient Grid\n{gradient_grid}") + # plt.colorbar(im2, ax=ax) + + # plt.tight_layout() + # plt.savefig("lcm_debug_grids.png", dpi=150) + # print("\nSaved debug visualization to lcm_debug_grids.png") + # plt.close() + + # Broadcast all the data + lcm = LCM() + lcm.start() + lcm.publish(Topic("/global_map", PointCloud2), pointcloud) + lcm.publish(Topic("/global_costmap", OccupancyGrid), occupancygrid) + lcm.publish(Topic("/global_gradient", OccupancyGrid), gradient_grid) + + print("\nPublished to LCM:") + print(f" /global_map: PointCloud2 with {len(pointcloud)} points") + print(f" /global_costmap: {occupancygrid}") + print(f" /global_gradient: {gradient_grid}") + print("\nGradient info:") + print(" Values: 0 (free far from obstacles) -> 100 (at obstacles)") + print(f" Unknown cells: {gradient_grid.unknown_cells} (preserved as -1)") + print(" Max distance for gradient: 5.0 meters") diff --git a/dimos/msgs/nav_msgs/test_Odometry.py b/dimos/msgs/nav_msgs/test_Odometry.py new file mode 100644 index 0000000000..ecdc83c6b4 --- /dev/null +++ b/dimos/msgs/nav_msgs/test_Odometry.py @@ -0,0 +1,504 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import numpy as np +import pytest + +try: + from builtin_interfaces.msg import Time as ROSTime + from geometry_msgs.msg import ( + Point as ROSPoint, + Pose as ROSPose, + PoseWithCovariance as ROSPoseWithCovariance, + Quaternion as ROSQuaternion, + Twist as ROSTwist, + TwistWithCovariance as ROSTwistWithCovariance, + Vector3 as ROSVector3, + ) + from nav_msgs.msg import Odometry as ROSOdometry + from std_msgs.msg import Header as ROSHeader +except ImportError: + ROSTwist = None + ROSHeader = None + ROSPose = None + ROSPoseWithCovariance = None + ROSQuaternion = None + ROSOdometry = None + ROSPoint = None + ROSTime = None + ROSTwistWithCovariance = None + ROSVector3 = None + + +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.PoseWithCovariance import PoseWithCovariance +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.TwistWithCovariance import TwistWithCovariance +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.nav_msgs.Odometry import Odometry + + +def test_odometry_default_init() -> None: + """Test default initialization.""" + if ROSVector3 is None: + pytest.skip("ROS not available") + if ROSTwistWithCovariance is None: + pytest.skip("ROS not available") + if ROSTime is None: + pytest.skip("ROS not available") + if ROSPoint is None: + pytest.skip("ROS not available") + if ROSOdometry is None: + pytest.skip("ROS not available") + if ROSQuaternion is None: + pytest.skip("ROS not available") + if ROSPoseWithCovariance is None: + pytest.skip("ROS not available") + if ROSPose is None: + pytest.skip("ROS not available") + if ROSHeader is None: + pytest.skip("ROS not available") + if ROSTwist is None: + pytest.skip("ROS not available") + odom = Odometry() + + # Should have current timestamp + assert odom.ts > 0 + assert odom.frame_id == "" + assert odom.child_frame_id == "" + + # Pose should be at origin with identity orientation + assert odom.pose.position.x == 0.0 + assert odom.pose.position.y == 0.0 + assert odom.pose.position.z == 0.0 + assert odom.pose.orientation.w == 1.0 + + # Twist should be zero + assert odom.twist.linear.x == 0.0 + assert odom.twist.linear.y == 0.0 + assert odom.twist.linear.z == 0.0 + assert odom.twist.angular.x == 0.0 + assert odom.twist.angular.y == 0.0 + assert odom.twist.angular.z == 0.0 + + # Covariances should be zero + assert np.all(odom.pose.covariance == 0.0) + assert np.all(odom.twist.covariance == 0.0) + + +def test_odometry_with_frames() -> None: + """Test initialization with frame IDs.""" + ts = 1234567890.123456 + frame_id = "odom" + child_frame_id = "base_link" + + odom = Odometry(ts=ts, frame_id=frame_id, child_frame_id=child_frame_id) + + assert odom.ts == ts + assert odom.frame_id == frame_id + assert odom.child_frame_id == child_frame_id + + +def test_odometry_with_pose_and_twist() -> None: + """Test initialization with pose and twist.""" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + twist = Twist(Vector3(0.5, 0.0, 0.0), Vector3(0.0, 0.0, 0.1)) + + odom = Odometry(ts=1000.0, frame_id="odom", child_frame_id="base_link", pose=pose, twist=twist) + + assert odom.pose.pose.position.x == 1.0 + assert odom.pose.pose.position.y == 2.0 + assert odom.pose.pose.position.z == 3.0 + assert odom.twist.twist.linear.x == 0.5 + assert odom.twist.twist.angular.z == 0.1 + + +def test_odometry_with_covariances() -> None: + """Test initialization with pose and twist with covariances.""" + pose = Pose(1.0, 2.0, 3.0) + pose_cov = np.arange(36, dtype=float) + pose_with_cov = PoseWithCovariance(pose, pose_cov) + + twist = Twist(Vector3(0.5, 0.0, 0.0), Vector3(0.0, 0.0, 0.1)) + twist_cov = np.arange(36, 72, dtype=float) + twist_with_cov = TwistWithCovariance(twist, twist_cov) + + odom = Odometry( + ts=1000.0, + frame_id="odom", + child_frame_id="base_link", + pose=pose_with_cov, + twist=twist_with_cov, + ) + + assert odom.pose.position.x == 1.0 + assert np.array_equal(odom.pose.covariance, pose_cov) + assert odom.twist.linear.x == 0.5 + assert np.array_equal(odom.twist.covariance, twist_cov) + + +def test_odometry_copy_constructor() -> None: + """Test copy constructor.""" + original = Odometry( + ts=1000.0, + frame_id="odom", + child_frame_id="base_link", + pose=Pose(1.0, 2.0, 3.0), + twist=Twist(Vector3(0.5, 0.0, 0.0), Vector3(0.0, 0.0, 0.1)), + ) + + copy = Odometry(original) + + assert copy == original + assert copy is not original + assert copy.pose is not original.pose + assert copy.twist is not original.twist + + +def test_odometry_dict_init() -> None: + """Test initialization from dictionary.""" + odom_dict = { + "ts": 1000.0, + "frame_id": "odom", + "child_frame_id": "base_link", + "pose": Pose(1.0, 2.0, 3.0), + "twist": Twist(Vector3(0.5, 0.0, 0.0), Vector3(0.0, 0.0, 0.1)), + } + + odom = Odometry(odom_dict) + + assert odom.ts == 1000.0 + assert odom.frame_id == "odom" + assert odom.child_frame_id == "base_link" + assert odom.pose.position.x == 1.0 + assert odom.twist.linear.x == 0.5 + + +def test_odometry_properties() -> None: + """Test convenience properties.""" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + twist = Twist(Vector3(0.5, 0.6, 0.7), Vector3(0.1, 0.2, 0.3)) + + odom = Odometry(ts=1000.0, frame_id="odom", child_frame_id="base_link", pose=pose, twist=twist) + + # Position properties + assert odom.x == 1.0 + assert odom.y == 2.0 + assert odom.z == 3.0 + assert odom.position.x == 1.0 + assert odom.position.y == 2.0 + assert odom.position.z == 3.0 + + # Orientation properties + assert odom.orientation.x == 0.1 + assert odom.orientation.y == 0.2 + assert odom.orientation.z == 0.3 + assert odom.orientation.w == 0.9 + + # Velocity properties + assert odom.vx == 0.5 + assert odom.vy == 0.6 + assert odom.vz == 0.7 + assert odom.linear_velocity.x == 0.5 + assert odom.linear_velocity.y == 0.6 + assert odom.linear_velocity.z == 0.7 + + # Angular velocity properties + assert odom.wx == 0.1 + assert odom.wy == 0.2 + assert odom.wz == 0.3 + assert odom.angular_velocity.x == 0.1 + assert odom.angular_velocity.y == 0.2 + assert odom.angular_velocity.z == 0.3 + + # Euler angles + assert odom.roll == pose.roll + assert odom.pitch == pose.pitch + assert odom.yaw == pose.yaw + + +def test_odometry_str_repr() -> None: + """Test string representations.""" + odom = Odometry( + ts=1234567890.123456, + frame_id="odom", + child_frame_id="base_link", + pose=Pose(1.234, 2.567, 3.891), + twist=Twist(Vector3(0.5, 0.0, 0.0), Vector3(0.0, 0.0, 0.1)), + ) + + repr_str = repr(odom) + assert "Odometry" in repr_str + assert "1234567890.123456" in repr_str + assert "odom" in repr_str + assert "base_link" in repr_str + + str_repr = str(odom) + assert "Odometry" in str_repr + assert "odom -> base_link" in str_repr + assert "1.234" in str_repr + assert "0.500" in str_repr + + +def test_odometry_equality() -> None: + """Test equality comparison.""" + odom1 = Odometry( + ts=1000.0, + frame_id="odom", + child_frame_id="base_link", + pose=Pose(1.0, 2.0, 3.0), + twist=Twist(Vector3(0.5, 0.0, 0.0), Vector3(0.0, 0.0, 0.1)), + ) + + odom2 = Odometry( + ts=1000.0, + frame_id="odom", + child_frame_id="base_link", + pose=Pose(1.0, 2.0, 3.0), + twist=Twist(Vector3(0.5, 0.0, 0.0), Vector3(0.0, 0.0, 0.1)), + ) + + odom3 = Odometry( + ts=1000.0, + frame_id="odom", + child_frame_id="base_link", + pose=Pose(1.1, 2.0, 3.0), # Different position + twist=Twist(Vector3(0.5, 0.0, 0.0), Vector3(0.0, 0.0, 0.1)), + ) + + assert odom1 == odom2 + assert odom1 != odom3 + assert odom1 != "not an odometry" + + +def test_odometry_lcm_encode_decode() -> None: + """Test LCM encoding and decoding.""" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + pose_cov = np.arange(36, dtype=float) + twist = Twist(Vector3(0.5, 0.6, 0.7), Vector3(0.1, 0.2, 0.3)) + twist_cov = np.arange(36, 72, dtype=float) + + source = Odometry( + ts=1234567890.123456, + frame_id="odom", + child_frame_id="base_link", + pose=PoseWithCovariance(pose, pose_cov), + twist=TwistWithCovariance(twist, twist_cov), + ) + + # Encode and decode + binary_msg = source.lcm_encode() + decoded = Odometry.lcm_decode(binary_msg) + + # Check values (allowing for timestamp precision loss) + assert abs(decoded.ts - source.ts) < 1e-6 + assert decoded.frame_id == source.frame_id + assert decoded.child_frame_id == source.child_frame_id + assert decoded.pose == source.pose + assert decoded.twist == source.twist + + +@pytest.mark.ros +def test_odometry_from_ros_msg() -> None: + """Test creating from ROS message.""" + ros_msg = ROSOdometry() + + # Set header + ros_msg.header = ROSHeader() + ros_msg.header.stamp = ROSTime() + ros_msg.header.stamp.sec = 1234567890 + ros_msg.header.stamp.nanosec = 123456000 + ros_msg.header.frame_id = "odom" + ros_msg.child_frame_id = "base_link" + + # Set pose with covariance + ros_msg.pose = ROSPoseWithCovariance() + ros_msg.pose.pose = ROSPose() + ros_msg.pose.pose.position = ROSPoint(x=1.0, y=2.0, z=3.0) + ros_msg.pose.pose.orientation = ROSQuaternion(x=0.1, y=0.2, z=0.3, w=0.9) + ros_msg.pose.covariance = [float(i) for i in range(36)] + + # Set twist with covariance + ros_msg.twist = ROSTwistWithCovariance() + ros_msg.twist.twist = ROSTwist() + ros_msg.twist.twist.linear = ROSVector3(x=0.5, y=0.6, z=0.7) + ros_msg.twist.twist.angular = ROSVector3(x=0.1, y=0.2, z=0.3) + ros_msg.twist.covariance = [float(i) for i in range(36, 72)] + + odom = Odometry.from_ros_msg(ros_msg) + + assert odom.ts == 1234567890.123456 + assert odom.frame_id == "odom" + assert odom.child_frame_id == "base_link" + assert odom.pose.position.x == 1.0 + assert odom.twist.linear.x == 0.5 + assert np.array_equal(odom.pose.covariance, np.arange(36)) + assert np.array_equal(odom.twist.covariance, np.arange(36, 72)) + + +@pytest.mark.ros +def test_odometry_to_ros_msg() -> None: + """Test converting to ROS message.""" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + pose_cov = np.arange(36, dtype=float) + twist = Twist(Vector3(0.5, 0.6, 0.7), Vector3(0.1, 0.2, 0.3)) + twist_cov = np.arange(36, 72, dtype=float) + + odom = Odometry( + ts=1234567890.567890, + frame_id="odom", + child_frame_id="base_link", + pose=PoseWithCovariance(pose, pose_cov), + twist=TwistWithCovariance(twist, twist_cov), + ) + + ros_msg = odom.to_ros_msg() + + assert isinstance(ros_msg, ROSOdometry) + assert ros_msg.header.frame_id == "odom" + assert ros_msg.header.stamp.sec == 1234567890 + assert abs(ros_msg.header.stamp.nanosec - 567890000) < 100 # Allow small rounding error + assert ros_msg.child_frame_id == "base_link" + + # Check pose + assert ros_msg.pose.pose.position.x == 1.0 + assert ros_msg.pose.pose.position.y == 2.0 + assert ros_msg.pose.pose.position.z == 3.0 + assert ros_msg.pose.pose.orientation.x == 0.1 + assert ros_msg.pose.pose.orientation.y == 0.2 + assert ros_msg.pose.pose.orientation.z == 0.3 + assert ros_msg.pose.pose.orientation.w == 0.9 + assert list(ros_msg.pose.covariance) == list(range(36)) + + # Check twist + assert ros_msg.twist.twist.linear.x == 0.5 + assert ros_msg.twist.twist.linear.y == 0.6 + assert ros_msg.twist.twist.linear.z == 0.7 + assert ros_msg.twist.twist.angular.x == 0.1 + assert ros_msg.twist.twist.angular.y == 0.2 + assert ros_msg.twist.twist.angular.z == 0.3 + assert list(ros_msg.twist.covariance) == list(range(36, 72)) + + +@pytest.mark.ros +def test_odometry_ros_roundtrip() -> None: + """Test round-trip conversion with ROS messages.""" + pose = Pose(1.5, 2.5, 3.5, 0.15, 0.25, 0.35, 0.85) + pose_cov = np.random.rand(36) + twist = Twist(Vector3(0.55, 0.65, 0.75), Vector3(0.15, 0.25, 0.35)) + twist_cov = np.random.rand(36) + + original = Odometry( + ts=2147483647.987654, # Max int32 value for ROS Time.sec + frame_id="world", + child_frame_id="robot", + pose=PoseWithCovariance(pose, pose_cov), + twist=TwistWithCovariance(twist, twist_cov), + ) + + ros_msg = original.to_ros_msg() + restored = Odometry.from_ros_msg(ros_msg) + + # Check values (allowing for timestamp precision loss) + assert abs(restored.ts - original.ts) < 1e-6 + assert restored.frame_id == original.frame_id + assert restored.child_frame_id == original.child_frame_id + assert restored.pose == original.pose + assert restored.twist == original.twist + + +def test_odometry_zero_timestamp() -> None: + """Test that zero timestamp gets replaced with current time.""" + odom = Odometry(ts=0.0) + + # Should have been replaced with current time + assert odom.ts > 0 + assert odom.ts <= time.time() + + +def test_odometry_with_just_pose() -> None: + """Test initialization with just a Pose (no covariance).""" + pose = Pose(1.0, 2.0, 3.0) + + odom = Odometry(pose=pose) + + assert odom.pose.position.x == 1.0 + assert odom.pose.position.y == 2.0 + assert odom.pose.position.z == 3.0 + assert np.all(odom.pose.covariance == 0.0) # Should have zero covariance + assert np.all(odom.twist.covariance == 0.0) # Twist should also be zero + + +def test_odometry_with_just_twist() -> None: + """Test initialization with just a Twist (no covariance).""" + twist = Twist(Vector3(0.5, 0.0, 0.0), Vector3(0.0, 0.0, 0.1)) + + odom = Odometry(twist=twist) + + assert odom.twist.linear.x == 0.5 + assert odom.twist.angular.z == 0.1 + assert np.all(odom.twist.covariance == 0.0) # Should have zero covariance + assert np.all(odom.pose.covariance == 0.0) # Pose should also be zero + + +@pytest.mark.ros +@pytest.mark.parametrize( + "frame_id,child_frame_id", + [ + ("odom", "base_link"), + ("map", "odom"), + ("world", "robot"), + ("base_link", "camera_link"), + ("", ""), # Empty frames + ], +) +def test_odometry_frame_combinations(frame_id, child_frame_id) -> None: + """Test various frame ID combinations.""" + odom = Odometry(frame_id=frame_id, child_frame_id=child_frame_id) + + assert odom.frame_id == frame_id + assert odom.child_frame_id == child_frame_id + + # Test roundtrip through ROS + ros_msg = odom.to_ros_msg() + assert ros_msg.header.frame_id == frame_id + assert ros_msg.child_frame_id == child_frame_id + + restored = Odometry.from_ros_msg(ros_msg) + assert restored.frame_id == frame_id + assert restored.child_frame_id == child_frame_id + + +def test_odometry_typical_robot_scenario() -> None: + """Test a typical robot odometry scenario.""" + # Robot moving forward at 0.5 m/s with slight rotation + odom = Odometry( + ts=1000.0, + frame_id="odom", + child_frame_id="base_footprint", + pose=Pose(10.0, 5.0, 0.0, 0.0, 0.0, np.sin(0.1), np.cos(0.1)), # 0.2 rad yaw + twist=Twist( + Vector3(0.5, 0.0, 0.0), Vector3(0.0, 0.0, 0.05) + ), # Moving forward, turning slightly + ) + + # Check we can access all the typical properties + assert odom.x == 10.0 + assert odom.y == 5.0 + assert odom.z == 0.0 + assert abs(odom.yaw - 0.2) < 0.01 # Approximately 0.2 radians + assert odom.vx == 0.5 # Forward velocity + assert odom.wz == 0.05 # Yaw rate diff --git a/dimos/msgs/nav_msgs/test_Path.py b/dimos/msgs/nav_msgs/test_Path.py new file mode 100644 index 0000000000..d933123b2b --- /dev/null +++ b/dimos/msgs/nav_msgs/test_Path.py @@ -0,0 +1,391 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +try: + from geometry_msgs.msg import PoseStamped as ROSPoseStamped + from nav_msgs.msg import Path as ROSPath +except ImportError: + ROSPoseStamped = None + ROSPath = None + +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.nav_msgs.Path import Path + + +def create_test_pose(x: float, y: float, z: float, frame_id: str = "map") -> PoseStamped: + """Helper to create a test PoseStamped.""" + return PoseStamped( + frame_id=frame_id, + position=[x, y, z], + orientation=Quaternion(0, 0, 0, 1), # Identity quaternion + ) + + +def test_init_empty() -> None: + """Test creating an empty path.""" + path = Path(frame_id="map") + assert path.frame_id == "map" + assert len(path) == 0 + assert not path # Should be falsy when empty + assert path.poses == [] + + +def test_init_with_poses() -> None: + """Test creating a path with initial poses.""" + poses = [create_test_pose(i, i, 0) for i in range(3)] + path = Path(frame_id="map", poses=poses) + assert len(path) == 3 + assert bool(path) # Should be truthy when has poses + assert path.poses == poses + + +def test_head() -> None: + """Test getting the first pose.""" + poses = [create_test_pose(i, i, 0) for i in range(3)] + path = Path(poses=poses) + assert path.head() == poses[0] + + # Test empty path + empty_path = Path() + assert empty_path.head() is None + + +def test_last() -> None: + """Test getting the last pose.""" + poses = [create_test_pose(i, i, 0) for i in range(3)] + path = Path(poses=poses) + assert path.last() == poses[-1] + + # Test empty path + empty_path = Path() + assert empty_path.last() is None + + +def test_tail() -> None: + """Test getting all poses except the first.""" + poses = [create_test_pose(i, i, 0) for i in range(3)] + path = Path(poses=poses) + tail = path.tail() + assert len(tail) == 2 + assert tail.poses == poses[1:] + assert tail.frame_id == path.frame_id + + # Test single element path + single_path = Path(poses=[poses[0]]) + assert len(single_path.tail()) == 0 + + # Test empty path + empty_path = Path() + assert len(empty_path.tail()) == 0 + + +def test_push_immutable() -> None: + """Test immutable push operation.""" + path = Path(frame_id="map") + pose1 = create_test_pose(1, 1, 0) + pose2 = create_test_pose(2, 2, 0) + + # Push should return new path + path2 = path.push(pose1) + assert len(path) == 0 # Original unchanged + assert len(path2) == 1 + assert path2.poses[0] == pose1 + + # Chain pushes + path3 = path2.push(pose2) + assert len(path2) == 1 # Previous unchanged + assert len(path3) == 2 + assert path3.poses == [pose1, pose2] + + +def test_push_mutable() -> None: + """Test mutable push operation.""" + path = Path(frame_id="map") + pose1 = create_test_pose(1, 1, 0) + pose2 = create_test_pose(2, 2, 0) + + # Push should modify in place + path.push_mut(pose1) + assert len(path) == 1 + assert path.poses[0] == pose1 + + path.push_mut(pose2) + assert len(path) == 2 + assert path.poses == [pose1, pose2] + + +def test_indexing() -> None: + """Test indexing and slicing.""" + poses = [create_test_pose(i, i, 0) for i in range(5)] + path = Path(poses=poses) + + # Single index + assert path[0] == poses[0] + assert path[-1] == poses[-1] + + # Slicing + assert path[1:3] == poses[1:3] + assert path[:2] == poses[:2] + assert path[3:] == poses[3:] + + +def test_iteration() -> None: + """Test iterating over poses.""" + poses = [create_test_pose(i, i, 0) for i in range(3)] + path = Path(poses=poses) + + collected = [] + for pose in path: + collected.append(pose) + assert collected == poses + + +def test_slice_method() -> None: + """Test slice method.""" + poses = [create_test_pose(i, i, 0) for i in range(5)] + path = Path(frame_id="map", poses=poses) + + sliced = path.slice(1, 4) + assert len(sliced) == 3 + assert sliced.poses == poses[1:4] + assert sliced.frame_id == "map" + + # Test open-ended slice + sliced2 = path.slice(2) + assert sliced2.poses == poses[2:] + + +def test_extend_immutable() -> None: + """Test immutable extend operation.""" + poses1 = [create_test_pose(i, i, 0) for i in range(2)] + poses2 = [create_test_pose(i + 2, i + 2, 0) for i in range(2)] + + path1 = Path(frame_id="map", poses=poses1) + path2 = Path(frame_id="odom", poses=poses2) + + extended = path1.extend(path2) + assert len(path1) == 2 # Original unchanged + assert len(extended) == 4 + assert extended.poses == poses1 + poses2 + assert extended.frame_id == "map" # Keeps first path's frame + + +def test_extend_mutable() -> None: + """Test mutable extend operation.""" + poses1 = [create_test_pose(i, i, 0) for i in range(2)] + poses2 = [create_test_pose(i + 2, i + 2, 0) for i in range(2)] + + path1 = Path(frame_id="map", poses=poses1.copy()) # Use copy to avoid modifying original + path2 = Path(frame_id="odom", poses=poses2) + + path1.extend_mut(path2) + assert len(path1) == 4 + # Check poses are the same as concatenation + for _i, (p1, p2) in enumerate(zip(path1.poses, poses1 + poses2, strict=False)): + assert p1.x == p2.x + assert p1.y == p2.y + assert p1.z == p2.z + + +def test_reverse() -> None: + """Test reverse operation.""" + poses = [create_test_pose(i, i, 0) for i in range(3)] + path = Path(poses=poses) + + reversed_path = path.reverse() + assert len(path) == 3 # Original unchanged + assert reversed_path.poses == list(reversed(poses)) + + +def test_clear() -> None: + """Test clear operation.""" + poses = [create_test_pose(i, i, 0) for i in range(3)] + path = Path(poses=poses) + + path.clear() + assert len(path) == 0 + assert path.poses == [] + + +def test_lcm_encode_decode() -> None: + """Test encoding and decoding of Path to/from binary LCM format.""" + # Create path with poses + # Use timestamps that can be represented exactly in float64 + path_ts = 1234567890.5 + poses = [ + PoseStamped( + ts=1234567890.0 + i * 0.1, # Use simpler timestamps + frame_id=f"frame_{i}", + position=[i * 1.5, i * 2.5, i * 3.5], + orientation=(0.1 * i, 0.2 * i, 0.3 * i, 0.9), + ) + for i in range(3) + ] + + path_source = Path(ts=path_ts, frame_id="world", poses=poses) + + # Encode to binary + binary_msg = path_source.lcm_encode() + + # Decode from binary + path_dest = Path.lcm_decode(binary_msg) + + assert isinstance(path_dest, Path) + assert path_dest is not path_source + + # Check header + assert path_dest.frame_id == path_source.frame_id + # Path timestamp should be preserved + assert abs(path_dest.ts - path_source.ts) < 1e-6 # Microsecond precision + + # Check poses + assert len(path_dest.poses) == len(path_source.poses) + + for orig, decoded in zip(path_source.poses, path_dest.poses, strict=False): + # Check pose timestamps + assert abs(decoded.ts - orig.ts) < 1e-6 + # All poses should have the path's frame_id + assert decoded.frame_id == path_dest.frame_id + + # Check position + assert decoded.x == orig.x + assert decoded.y == orig.y + assert decoded.z == orig.z + + # Check orientation + assert decoded.orientation.x == orig.orientation.x + assert decoded.orientation.y == orig.orientation.y + assert decoded.orientation.z == orig.orientation.z + assert decoded.orientation.w == orig.orientation.w + + +def test_lcm_encode_decode_empty() -> None: + """Test encoding and decoding of empty Path.""" + path_source = Path(frame_id="base_link") + + binary_msg = path_source.lcm_encode() + path_dest = Path.lcm_decode(binary_msg) + + assert isinstance(path_dest, Path) + assert path_dest.frame_id == path_source.frame_id + assert len(path_dest.poses) == 0 + + +def test_str_representation() -> None: + """Test string representation.""" + path = Path(frame_id="map") + assert str(path) == "Path(frame_id='map', poses=0)" + + path.push_mut(create_test_pose(1, 1, 0)) + path.push_mut(create_test_pose(2, 2, 0)) + assert str(path) == "Path(frame_id='map', poses=2)" + + +@pytest.mark.ros +def test_path_from_ros_msg() -> None: + """Test creating a Path from a ROS Path message.""" + ros_msg = ROSPath() + ros_msg.header.frame_id = "map" + ros_msg.header.stamp.sec = 123 + ros_msg.header.stamp.nanosec = 456000000 + + # Add some poses + for i in range(3): + ros_pose = ROSPoseStamped() + ros_pose.header.frame_id = "map" + ros_pose.header.stamp.sec = 123 + i + ros_pose.header.stamp.nanosec = 0 + ros_pose.pose.position.x = float(i) + ros_pose.pose.position.y = float(i * 2) + ros_pose.pose.position.z = float(i * 3) + ros_pose.pose.orientation.x = 0.0 + ros_pose.pose.orientation.y = 0.0 + ros_pose.pose.orientation.z = 0.0 + ros_pose.pose.orientation.w = 1.0 + ros_msg.poses.append(ros_pose) + + path = Path.from_ros_msg(ros_msg) + + assert path.frame_id == "map" + assert path.ts == 123.456 + assert len(path.poses) == 3 + + for i, pose in enumerate(path.poses): + assert pose.position.x == float(i) + assert pose.position.y == float(i * 2) + assert pose.position.z == float(i * 3) + assert pose.orientation.w == 1.0 + + +@pytest.mark.ros +def test_path_to_ros_msg() -> None: + """Test converting a Path to a ROS Path message.""" + poses = [ + PoseStamped( + ts=124.0 + i, frame_id="odom", position=[i, i * 2, i * 3], orientation=[0, 0, 0, 1] + ) + for i in range(3) + ] + + path = Path(ts=123.456, frame_id="odom", poses=poses) + + ros_msg = path.to_ros_msg() + + assert isinstance(ros_msg, ROSPath) + assert ros_msg.header.frame_id == "odom" + assert ros_msg.header.stamp.sec == 123 + assert ros_msg.header.stamp.nanosec == 456000000 + assert len(ros_msg.poses) == 3 + + for i, ros_pose in enumerate(ros_msg.poses): + assert ros_pose.pose.position.x == float(i) + assert ros_pose.pose.position.y == float(i * 2) + assert ros_pose.pose.position.z == float(i * 3) + assert ros_pose.pose.orientation.w == 1.0 + + +@pytest.mark.ros +def test_path_ros_roundtrip() -> None: + """Test round-trip conversion between Path and ROS Path.""" + poses = [ + PoseStamped( + ts=100.0 + i * 0.1, + frame_id="world", + position=[i * 1.5, i * 2.5, i * 3.5], + orientation=[0.1, 0.2, 0.3, 0.9], + ) + for i in range(3) + ] + + original = Path(ts=99.789, frame_id="world", poses=poses) + + ros_msg = original.to_ros_msg() + restored = Path.from_ros_msg(ros_msg) + + assert restored.frame_id == original.frame_id + assert restored.ts == original.ts + assert len(restored.poses) == len(original.poses) + + for orig_pose, rest_pose in zip(original.poses, restored.poses, strict=False): + assert rest_pose.position.x == orig_pose.position.x + assert rest_pose.position.y == orig_pose.position.y + assert rest_pose.position.z == orig_pose.position.z + assert rest_pose.orientation.x == orig_pose.orientation.x + assert rest_pose.orientation.y == orig_pose.orientation.y + assert rest_pose.orientation.z == orig_pose.orientation.z + assert rest_pose.orientation.w == orig_pose.orientation.w diff --git a/dimos/msgs/sensor_msgs/CameraInfo.py b/dimos/msgs/sensor_msgs/CameraInfo.py new file mode 100644 index 0000000000..1b3885867a --- /dev/null +++ b/dimos/msgs/sensor_msgs/CameraInfo.py @@ -0,0 +1,473 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import time + +# Import LCM types +from dimos_lcm.sensor_msgs import CameraInfo as LCMCameraInfo # type: ignore[import-untyped] +from dimos_lcm.std_msgs.Header import Header # type: ignore[import-untyped] +import numpy as np + +# Import ROS types +try: + from sensor_msgs.msg import ( # type: ignore[attr-defined] + CameraInfo as ROSCameraInfo, + RegionOfInterest as ROSRegionOfInterest, + ) + from std_msgs.msg import Header as ROSHeader # type: ignore[attr-defined] + + ROS_AVAILABLE = True +except ImportError: + ROS_AVAILABLE = False + +from dimos.types.timestamped import Timestamped + + +class CameraInfo(Timestamped): + """Camera calibration information message.""" + + msg_name = "sensor_msgs.CameraInfo" + + def __init__( + self, + height: int = 0, + width: int = 0, + distortion_model: str = "", + D: list[float] | None = None, + K: list[float] | None = None, + R: list[float] | None = None, + P: list[float] | None = None, + binning_x: int = 0, + binning_y: int = 0, + frame_id: str = "", + ts: float | None = None, + ) -> None: + """Initialize CameraInfo. + + Args: + height: Image height + width: Image width + distortion_model: Name of distortion model (e.g., "plumb_bob") + D: Distortion coefficients + K: 3x3 intrinsic camera matrix + R: 3x3 rectification matrix + P: 3x4 projection matrix + binning_x: Horizontal binning + binning_y: Vertical binning + frame_id: Frame ID + ts: Timestamp + """ + self.ts = ts if ts is not None else time.time() + self.frame_id = frame_id + self.height = height + self.width = width + self.distortion_model = distortion_model + + # Initialize distortion coefficients + self.D = D if D is not None else [] + + # Initialize 3x3 intrinsic camera matrix (row-major) + self.K = K if K is not None else [0.0] * 9 + + # Initialize 3x3 rectification matrix (row-major) + self.R = R if R is not None else [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0] + + # Initialize 3x4 projection matrix (row-major) + self.P = P if P is not None else [0.0] * 12 + + self.binning_x = binning_x + self.binning_y = binning_y + + # Region of interest (not used in basic implementation) + self.roi_x_offset = 0 + self.roi_y_offset = 0 + self.roi_height = 0 + self.roi_width = 0 + self.roi_do_rectify = False + + @classmethod + def from_yaml(cls, yaml_file: str) -> CameraInfo: + """Create CameraInfo from YAML file. + + Args: + yaml_file: Path to YAML file containing camera calibration data + + Returns: + CameraInfo instance with loaded calibration data + """ + import yaml + + with open(yaml_file) as f: + data = yaml.safe_load(f) + + # Extract basic parameters + width = data.get("image_width", 0) + height = data.get("image_height", 0) + distortion_model = data.get("distortion_model", "") + + # Extract matrices + camera_matrix = data.get("camera_matrix", {}) + K = camera_matrix.get("data", [0.0] * 9) + + distortion_coeffs = data.get("distortion_coefficients", {}) + D = distortion_coeffs.get("data", []) + + rect_matrix = data.get("rectification_matrix", {}) + R = rect_matrix.get("data", [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]) + + proj_matrix = data.get("projection_matrix", {}) + P = proj_matrix.get("data", [0.0] * 12) + + # Create CameraInfo instance + return cls( + height=height, + width=width, + distortion_model=distortion_model, + D=D, + K=K, + R=R, + P=P, + frame_id="camera_optical", + ) + + def get_K_matrix(self) -> np.ndarray: # type: ignore[type-arg] + """Get intrinsic matrix as numpy array.""" + return np.array(self.K, dtype=np.float64).reshape(3, 3) + + def get_P_matrix(self) -> np.ndarray: # type: ignore[type-arg] + """Get projection matrix as numpy array.""" + return np.array(self.P, dtype=np.float64).reshape(3, 4) + + def get_R_matrix(self) -> np.ndarray: # type: ignore[type-arg] + """Get rectification matrix as numpy array.""" + return np.array(self.R, dtype=np.float64).reshape(3, 3) + + def get_D_coeffs(self) -> np.ndarray: # type: ignore[type-arg] + """Get distortion coefficients as numpy array.""" + return np.array(self.D, dtype=np.float64) + + def set_K_matrix(self, K: np.ndarray): # type: ignore[no-untyped-def, type-arg] + """Set intrinsic matrix from numpy array.""" + if K.shape != (3, 3): + raise ValueError(f"K matrix must be 3x3, got {K.shape}") + self.K = K.flatten().tolist() + + def set_P_matrix(self, P: np.ndarray): # type: ignore[no-untyped-def, type-arg] + """Set projection matrix from numpy array.""" + if P.shape != (3, 4): + raise ValueError(f"P matrix must be 3x4, got {P.shape}") + self.P = P.flatten().tolist() + + def set_R_matrix(self, R: np.ndarray): # type: ignore[no-untyped-def, type-arg] + """Set rectification matrix from numpy array.""" + if R.shape != (3, 3): + raise ValueError(f"R matrix must be 3x3, got {R.shape}") + self.R = R.flatten().tolist() + + def set_D_coeffs(self, D: np.ndarray) -> None: # type: ignore[type-arg] + """Set distortion coefficients from numpy array.""" + self.D = D.flatten().tolist() + + def lcm_encode(self) -> bytes: + """Convert to LCM CameraInfo message.""" + msg = LCMCameraInfo() + + # Header + msg.header = Header() + msg.header.seq = 0 + msg.header.frame_id = self.frame_id + msg.header.stamp.sec = int(self.ts) + msg.header.stamp.nsec = int((self.ts - int(self.ts)) * 1e9) + + # Image dimensions + msg.height = self.height + msg.width = self.width + + # Distortion model + msg.distortion_model = self.distortion_model + + # Distortion coefficients + msg.D_length = len(self.D) + msg.D = self.D + + # Camera matrices (all stored as row-major) + msg.K = self.K + msg.R = self.R + msg.P = self.P + + # Binning + msg.binning_x = self.binning_x + msg.binning_y = self.binning_y + + # ROI + msg.roi.x_offset = self.roi_x_offset + msg.roi.y_offset = self.roi_y_offset + msg.roi.height = self.roi_height + msg.roi.width = self.roi_width + msg.roi.do_rectify = self.roi_do_rectify + + return msg.lcm_encode() # type: ignore[no-any-return] + + @classmethod + def lcm_decode(cls, data: bytes) -> CameraInfo: + """Decode from LCM CameraInfo bytes.""" + msg = LCMCameraInfo.lcm_decode(data) + + # Extract timestamp + ts = msg.header.stamp.sec + msg.header.stamp.nsec / 1e9 if hasattr(msg, "header") else None + + camera_info = cls( + height=msg.height, + width=msg.width, + distortion_model=msg.distortion_model, + D=list(msg.D) if msg.D_length > 0 else [], + K=list(msg.K), + R=list(msg.R), + P=list(msg.P), + binning_x=msg.binning_x, + binning_y=msg.binning_y, + frame_id=msg.header.frame_id if hasattr(msg, "header") else "", + ts=ts, + ) + + # Set ROI if present + if hasattr(msg, "roi"): + camera_info.roi_x_offset = msg.roi.x_offset + camera_info.roi_y_offset = msg.roi.y_offset + camera_info.roi_height = msg.roi.height + camera_info.roi_width = msg.roi.width + camera_info.roi_do_rectify = msg.roi.do_rectify + + return camera_info + + @classmethod + def from_ros_msg(cls, ros_msg: ROSCameraInfo) -> CameraInfo: + """Create CameraInfo from ROS sensor_msgs/CameraInfo message. + + Args: + ros_msg: ROS CameraInfo message + + Returns: + CameraInfo instance + """ + if not ROS_AVAILABLE: + raise ImportError("ROS packages not available. Cannot convert from ROS message.") + + # Extract timestamp + ts = ros_msg.header.stamp.sec + ros_msg.header.stamp.nanosec / 1e9 + + camera_info = cls( + height=ros_msg.height, + width=ros_msg.width, + distortion_model=ros_msg.distortion_model, + D=list(ros_msg.d), + K=list(ros_msg.k), + R=list(ros_msg.r), + P=list(ros_msg.p), + binning_x=ros_msg.binning_x, + binning_y=ros_msg.binning_y, + frame_id=ros_msg.header.frame_id, + ts=ts, + ) + + # Set ROI + camera_info.roi_x_offset = ros_msg.roi.x_offset + camera_info.roi_y_offset = ros_msg.roi.y_offset + camera_info.roi_height = ros_msg.roi.height + camera_info.roi_width = ros_msg.roi.width + camera_info.roi_do_rectify = ros_msg.roi.do_rectify + + return camera_info + + def to_ros_msg(self) -> ROSCameraInfo: + """Convert to ROS sensor_msgs/CameraInfo message. + + Returns: + ROS CameraInfo message + """ + if not ROS_AVAILABLE: + raise ImportError("ROS packages not available. Cannot convert to ROS message.") + + ros_msg = ROSCameraInfo() # type: ignore[no-untyped-call] + + # Set header + ros_msg.header = ROSHeader() # type: ignore[no-untyped-call] + ros_msg.header.frame_id = self.frame_id + ros_msg.header.stamp.sec = int(self.ts) + ros_msg.header.stamp.nanosec = int((self.ts - int(self.ts)) * 1e9) + + # Image dimensions + ros_msg.height = self.height + ros_msg.width = self.width + + # Distortion model and coefficients + ros_msg.distortion_model = self.distortion_model + ros_msg.d = self.D + + # Camera matrices (all row-major) + ros_msg.k = self.K + ros_msg.r = self.R + ros_msg.p = self.P + + # Binning + ros_msg.binning_x = self.binning_x + ros_msg.binning_y = self.binning_y + + # ROI + ros_msg.roi = ROSRegionOfInterest() # type: ignore[no-untyped-call] + ros_msg.roi.x_offset = self.roi_x_offset + ros_msg.roi.y_offset = self.roi_y_offset + ros_msg.roi.height = self.roi_height + ros_msg.roi.width = self.roi_width + ros_msg.roi.do_rectify = self.roi_do_rectify + + return ros_msg + + def __repr__(self) -> str: + """String representation.""" + return ( + f"CameraInfo(height={self.height}, width={self.width}, " + f"distortion_model='{self.distortion_model}', " + f"frame_id='{self.frame_id}', ts={self.ts})" + ) + + def __str__(self) -> str: + """Human-readable string.""" + return ( + f"CameraInfo:\n" + f" Resolution: {self.width}x{self.height}\n" + f" Distortion model: {self.distortion_model}\n" + f" Frame ID: {self.frame_id}\n" + f" Binning: {self.binning_x}x{self.binning_y}" + ) + + def __eq__(self, other) -> bool: # type: ignore[no-untyped-def] + """Check if two CameraInfo messages are equal.""" + if not isinstance(other, CameraInfo): + return False + + return ( + self.height == other.height + and self.width == other.width + and self.distortion_model == other.distortion_model + and self.D == other.D + and self.K == other.K + and self.R == other.R + and self.P == other.P + and self.binning_x == other.binning_x + and self.binning_y == other.binning_y + and self.frame_id == other.frame_id + ) + + +class CalibrationProvider: + """Provides lazy-loaded access to camera calibration YAML files in a directory.""" + + def __init__(self, calibration_dir) -> None: # type: ignore[no-untyped-def] + """Initialize with a directory containing calibration YAML files. + + Args: + calibration_dir: Path to directory containing .yaml calibration files + """ + from pathlib import Path + + self._calibration_dir = Path(calibration_dir) + self._cache = {} # type: ignore[var-annotated] + + def _to_snake_case(self, name: str) -> str: + """Convert PascalCase to snake_case.""" + import re + + # Insert underscore before capital letters (except first char) + s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) + # Insert underscore before capital letter followed by lowercase + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() + + def _find_yaml_file(self, name: str): # type: ignore[no-untyped-def] + """Find YAML file matching the given name (tries both snake_case and exact match). + + Args: + name: Attribute name to look for + + Returns: + Path to YAML file if found, None otherwise + """ + # Try exact match first + yaml_file = self._calibration_dir / f"{name}.yaml" + if yaml_file.exists(): + return yaml_file + + # Try snake_case conversion for PascalCase names + snake_name = self._to_snake_case(name) + if snake_name != name: + yaml_file = self._calibration_dir / f"{snake_name}.yaml" + if yaml_file.exists(): + return yaml_file + + return None + + def __getattr__(self, name: str) -> CameraInfo: + """Load calibration YAML file on first access. + + Supports both snake_case and PascalCase attribute names. + For example, both 'single_webcam' and 'SingleWebcam' will load 'single_webcam.yaml'. + + Args: + name: Attribute name (can be PascalCase or snake_case) + + Returns: + CameraInfo object loaded from the YAML file + + Raises: + AttributeError: If no matching YAML file exists + """ + # Check cache first + if name in self._cache: + return self._cache[name] # type: ignore[no-any-return] + + # Also check if the snake_case version is cached (for PascalCase access) + snake_name = self._to_snake_case(name) + if snake_name != name and snake_name in self._cache: + return self._cache[snake_name] # type: ignore[no-any-return] + + # Find matching YAML file + yaml_file = self._find_yaml_file(name) + if not yaml_file: + raise AttributeError(f"No calibration file found for: {name}") + + # Load and cache the CameraInfo + camera_info = CameraInfo.from_yaml(str(yaml_file)) + + # Cache both the requested name and the snake_case version + self._cache[name] = camera_info + if snake_name != name: + self._cache[snake_name] = camera_info + + return camera_info + + def __dir__(self): # type: ignore[no-untyped-def] + """List available calibrations in both snake_case and PascalCase.""" + calibrations = [] + if self._calibration_dir.exists() and self._calibration_dir.is_dir(): + yaml_files = self._calibration_dir.glob("*.yaml") + for f in yaml_files: + stem = f.stem + calibrations.append(stem) + # Add PascalCase version + pascal = "".join(word.capitalize() for word in stem.split("_")) + if pascal != stem: + calibrations.append(pascal) + return calibrations diff --git a/dimos/msgs/sensor_msgs/Image.py b/dimos/msgs/sensor_msgs/Image.py new file mode 100644 index 0000000000..a2f7c79f0a --- /dev/null +++ b/dimos/msgs/sensor_msgs/Image.py @@ -0,0 +1,734 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import base64 +import time +from typing import TYPE_CHECKING, Literal, TypedDict + +import cv2 +from dimos_lcm.sensor_msgs.Image import Image as LCMImage # type: ignore[import-untyped] +from dimos_lcm.std_msgs.Header import Header # type: ignore[import-untyped] +import numpy as np +import reactivex as rx +from reactivex import operators as ops +from turbojpeg import TurboJPEG # type: ignore[import-untyped] + +from dimos.msgs.sensor_msgs.image_impls.AbstractImage import ( + HAS_CUDA, + HAS_NVIMGCODEC, + NVIMGCODEC_LAST_USED, + ImageFormat, +) +from dimos.msgs.sensor_msgs.image_impls.CudaImage import CudaImage +from dimos.msgs.sensor_msgs.image_impls.NumpyImage import NumpyImage +from dimos.types.timestamped import Timestamped, TimestampedBufferCollection, to_human_readable +from dimos.utils.reactive import quality_barrier + +if TYPE_CHECKING: + from reactivex.observable import Observable + + from dimos.msgs.sensor_msgs.image_impls.AbstractImage import ( + AbstractImage, + ) + +try: + import cupy as cp # type: ignore +except Exception: + cp = None + +try: + from sensor_msgs.msg import Image as ROSImage # type: ignore[attr-defined] +except ImportError: + ROSImage = None # type: ignore[assignment, misc] + + +class AgentImageMessage(TypedDict): + """Type definition for agent-compatible image representation.""" + + type: Literal["image"] + source_type: Literal["base64"] + mime_type: Literal["image/jpeg", "image/png"] + data: str # Base64 encoded image data + + +class Image(Timestamped): + msg_name = "sensor_msgs.Image" + + def __init__( # type: ignore[no-untyped-def] + self, + impl: AbstractImage | None = None, + *, + data=None, + format: ImageFormat | None = None, + frame_id: str | None = None, + ts: float | None = None, + ) -> None: + """Construct an Image facade. + + Usage: + - Image(impl=) + - Image(data=, format=ImageFormat.RGB, frame_id=str, ts=float) + + Notes: + - When constructed from `data`, uses CudaImage if `data` is a CuPy array and CUDA is available; otherwise NumpyImage. + - `format` defaults to ImageFormat.RGB; `frame_id` defaults to ""; `ts` defaults to `time.time()`. + """ + # Disallow mixing impl with raw kwargs + if impl is not None and any(x is not None for x in (data, format, frame_id, ts)): + raise TypeError( + "Provide either 'impl' or ('data', 'format', 'frame_id', 'ts'), not both" + ) + + if impl is not None: + self._impl = impl + return + + # Raw constructor path + if data is None: + raise TypeError("'data' is required when constructing Image without 'impl'") + fmt = format if format is not None else ImageFormat.BGR + fid = frame_id if frame_id is not None else "" + tstamp = ts if ts is not None else time.time() + + # Detect CuPy array without a hard dependency + is_cu = False + try: + import cupy as _cp + + is_cu = isinstance(data, _cp.ndarray) + except Exception: + is_cu = False + + if is_cu and HAS_CUDA: + self._impl = CudaImage(data, fmt, fid, tstamp) + else: + self._impl = NumpyImage(np.asarray(data), fmt, fid, tstamp) + + def __str__(self) -> str: + dev = "cuda" if self.is_cuda else "cpu" + return ( + f"Image(shape={self.shape}, format={self.format.value}, dtype={self.dtype}, " + f"dev={dev}, ts={to_human_readable(self.ts)})" + ) + + @classmethod + def from_impl(cls, impl: AbstractImage) -> Image: + return cls(impl) + + @classmethod + def from_numpy( # type: ignore[no-untyped-def] + cls, + np_image: np.ndarray, # type: ignore[type-arg] + format: ImageFormat = ImageFormat.BGR, + to_cuda: bool = False, + **kwargs, + ) -> Image: + if kwargs.pop("to_gpu", False): + to_cuda = True + if to_cuda and HAS_CUDA: + return cls( + CudaImage( + np_image if hasattr(np_image, "shape") else np.asarray(np_image), + format, + kwargs.get("frame_id", ""), + kwargs.get("ts", time.time()), + ) + ) + return cls( + NumpyImage( + np.asarray(np_image), + format, + kwargs.get("frame_id", ""), + kwargs.get("ts", time.time()), + ) + ) + + @classmethod + def from_file( # type: ignore[no-untyped-def] + cls, filepath: str, format: ImageFormat = ImageFormat.RGB, to_cuda: bool = False, **kwargs + ) -> Image: + if kwargs.pop("to_gpu", False): + to_cuda = True + arr = cv2.imread(filepath, cv2.IMREAD_UNCHANGED) + if arr is None: + raise ValueError(f"Could not load image from {filepath}") + if arr.ndim == 2: + detected = ImageFormat.GRAY16 if arr.dtype == np.uint16 else ImageFormat.GRAY + elif arr.shape[2] == 3: + detected = ImageFormat.BGR # OpenCV default + elif arr.shape[2] == 4: + detected = ImageFormat.BGRA # OpenCV default + else: + detected = format + return cls(CudaImage(arr, detected) if to_cuda and HAS_CUDA else NumpyImage(arr, detected)) + + @classmethod + def from_opencv( # type: ignore[no-untyped-def] + cls, + cv_image: np.ndarray, # type: ignore[type-arg] + format: ImageFormat = ImageFormat.BGR, + **kwargs, + ) -> Image: + """Construct from an OpenCV image (NumPy array).""" + return cls( + NumpyImage(cv_image, format, kwargs.get("frame_id", ""), kwargs.get("ts", time.time())) + ) + + @classmethod + def from_depth( # type: ignore[no-untyped-def] + cls, depth_data, frame_id: str = "", ts: float | None = None, to_cuda: bool = False + ) -> Image: + arr = np.asarray(depth_data) + if arr.dtype != np.float32: + arr = arr.astype(np.float32) + impl = ( + CudaImage(arr, ImageFormat.DEPTH, frame_id, time.time() if ts is None else ts) + if to_cuda and HAS_CUDA + else NumpyImage(arr, ImageFormat.DEPTH, frame_id, time.time() if ts is None else ts) + ) + return cls(impl) + + # Delegation + @property + def is_cuda(self) -> bool: + return self._impl.is_cuda + + @property + def data(self): # type: ignore[no-untyped-def] + return self._impl.data + + @data.setter + def data(self, value) -> None: # type: ignore[no-untyped-def] + # Preserve backend semantics: ensure array type matches implementation + if isinstance(self._impl, NumpyImage): + self._impl.data = np.asarray(value) + elif isinstance(self._impl, CudaImage): + if cp is None: + raise RuntimeError("CuPy not available to set CUDA image data") + self._impl.data = cp.asarray(value) + else: + self._impl.data = value + + @property + def format(self) -> ImageFormat: + return self._impl.format + + @format.setter + def format(self, value) -> None: # type: ignore[no-untyped-def] + if isinstance(value, ImageFormat): + self._impl.format = value + elif isinstance(value, str): + try: + self._impl.format = ImageFormat[value] + except KeyError as e: + raise ValueError(f"Invalid ImageFormat: {value}") from e + else: + raise TypeError("format must be ImageFormat or str name") + + @property + def frame_id(self) -> str: + return self._impl.frame_id + + @frame_id.setter + def frame_id(self, value: str) -> None: + self._impl.frame_id = str(value) + + @property + def ts(self) -> float: + return self._impl.ts + + @ts.setter + def ts(self, value: float) -> None: + self._impl.ts = float(value) + + @property + def height(self) -> int: + return self._impl.height + + @property + def width(self) -> int: + return self._impl.width + + @property + def channels(self) -> int: + return self._impl.channels + + @property + def shape(self): # type: ignore[no-untyped-def] + return self._impl.shape + + @property + def dtype(self): # type: ignore[no-untyped-def] + return self._impl.dtype + + def copy(self) -> Image: + return Image(self._impl.copy()) + + def to_cpu(self) -> Image: + if isinstance(self._impl, NumpyImage): + return self.copy() + + data = self._impl.data.get() # CuPy array to NumPy + + return Image( + NumpyImage( + data, + self._impl.format, + self._impl.frame_id, + self._impl.ts, + ) + ) + + def to_cupy(self) -> Image: + if isinstance(self._impl, CudaImage): + return self.copy() + return Image( + CudaImage( + np.asarray(self._impl.data), self._impl.format, self._impl.frame_id, self._impl.ts + ) + ) + + def to_opencv(self) -> np.ndarray: # type: ignore[type-arg] + return self._impl.to_opencv() + + def to_rgb(self) -> Image: + return Image(self._impl.to_rgb()) + + def to_bgr(self) -> Image: + return Image(self._impl.to_bgr()) + + def to_grayscale(self) -> Image: + return Image(self._impl.to_grayscale()) + + def resize(self, width: int, height: int, interpolation: int = cv2.INTER_LINEAR) -> Image: + return Image(self._impl.resize(width, height, interpolation)) + + def crop(self, x: int, y: int, width: int, height: int) -> Image: + return Image(self._impl.crop(x, y, width, height)) # type: ignore[attr-defined] + + @property + def sharpness(self) -> float: + """Return sharpness score.""" + return self._impl.sharpness() + + def save(self, filepath: str) -> bool: + return self._impl.save(filepath) + + def to_base64( + self, + quality: int = 80, + *, + max_width: int | None = None, + max_height: int | None = None, + ) -> str: + """Encode the image as a base64 JPEG string. + + Args: + quality: JPEG quality (0-100). + max_width: Optional maximum width to constrain the encoded image. + max_height: Optional maximum height to constrain the encoded image. + + Returns: + Base64-encoded JPEG representation of the image. + """ + bgr_image = self.to_bgr().to_opencv() + height, width = bgr_image.shape[:2] + + scale = 1.0 + if max_width is not None and width > max_width: + scale = min(scale, max_width / width) + if max_height is not None and height > max_height: + scale = min(scale, max_height / height) + + if scale < 1.0: + new_width = max(1, round(width * scale)) + new_height = max(1, round(height * scale)) + bgr_image = cv2.resize(bgr_image, (new_width, new_height), interpolation=cv2.INTER_AREA) + + encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(np.clip(quality, 0, 100))] + success, buffer = cv2.imencode(".jpg", bgr_image, encode_param) + if not success: + raise ValueError("Failed to encode image as JPEG") + + return base64.b64encode(buffer.tobytes()).decode("utf-8") + + def agent_encode(self) -> AgentImageMessage: + return [ # type: ignore[return-value] + { + "type": "image_url", + "image_url": {"url": f"data:image/jpeg;base64,{self.to_base64()}"}, + } + ] + + # LCM encode/decode + def lcm_encode(self, frame_id: str | None = None) -> bytes: + """Convert to LCM Image message.""" + msg = LCMImage() + + # Header + msg.header = Header() + msg.header.seq = 0 + msg.header.frame_id = frame_id or self.frame_id + + # Set timestamp + if self.ts is not None: + msg.header.stamp.sec = int(self.ts) + msg.header.stamp.nsec = int((self.ts - int(self.ts)) * 1e9) + else: + now = time.time() + msg.header.stamp.sec = int(now) + msg.header.stamp.nsec = int((now - int(now)) * 1e9) + + # Image properties + msg.height = self.height + msg.width = self.width + msg.encoding = _get_lcm_encoding(self.format, self.dtype) + msg.is_bigendian = False + + # Calculate step (bytes per row) + channels = 1 if self.data.ndim == 2 else self.data.shape[2] + msg.step = self.width * self.dtype.itemsize * channels + + # Image data - use raw data to preserve format + image_bytes = self.data.tobytes() + msg.data_length = len(image_bytes) + msg.data = image_bytes + + return msg.lcm_encode() # type: ignore[no-any-return] + + @classmethod + def lcm_decode(cls, data: bytes, **kwargs) -> Image: # type: ignore[no-untyped-def] + msg = LCMImage.lcm_decode(data) + fmt, dtype, channels = _parse_lcm_encoding(msg.encoding) + arr = np.frombuffer(msg.data, dtype=dtype) + if channels == 1: + arr = arr.reshape((msg.height, msg.width)) + else: + arr = arr.reshape((msg.height, msg.width, channels)) + return cls( + NumpyImage( + arr, + fmt, + msg.header.frame_id if hasattr(msg, "header") else "", + ( + msg.header.stamp.sec + msg.header.stamp.nsec / 1e9 + if hasattr(msg, "header") + and hasattr(msg.header, "stamp") + and msg.header.stamp.sec > 0 + else time.time() + ), + ) + ) + + def lcm_jpeg_encode(self, quality: int = 75, frame_id: str | None = None) -> bytes: + """Convert to LCM Image message with JPEG-compressed data. + + Args: + quality: JPEG compression quality (0-100, default 75) + frame_id: Optional frame ID override + + Returns: + LCM-encoded bytes with JPEG-compressed image data + """ + jpeg = TurboJPEG() + msg = LCMImage() + + # Header + msg.header = Header() + msg.header.seq = 0 + msg.header.frame_id = frame_id or self.frame_id + + # Set timestamp + if self.ts is not None: + msg.header.stamp.sec = int(self.ts) + msg.header.stamp.nsec = int((self.ts - int(self.ts)) * 1e9) + else: + now = time.time() + msg.header.stamp.sec = int(now) + msg.header.stamp.nsec = int((now - int(now)) * 1e9) + + # Get image in BGR format for JPEG encoding + bgr_image = self.to_bgr().to_opencv() + + # Encode as JPEG + jpeg_data = jpeg.encode(bgr_image, quality=quality) + + # Store JPEG data and metadata + msg.height = self.height + msg.width = self.width + msg.encoding = "jpeg" + msg.is_bigendian = False + msg.step = 0 # Not applicable for compressed format + + msg.data_length = len(jpeg_data) + msg.data = jpeg_data + + return msg.lcm_encode() # type: ignore[no-any-return] + + @classmethod + def lcm_jpeg_decode(cls, data: bytes, **kwargs) -> Image: # type: ignore[no-untyped-def] + """Decode an LCM Image message with JPEG-compressed data. + + Args: + data: LCM-encoded bytes containing JPEG-compressed image + + Returns: + Image instance + """ + jpeg = TurboJPEG() + msg = LCMImage.lcm_decode(data) + + if msg.encoding != "jpeg": + raise ValueError(f"Expected JPEG encoding, got {msg.encoding}") + + # Decode JPEG data + bgr_array = jpeg.decode(msg.data) + + return cls( + NumpyImage( + bgr_array, + ImageFormat.BGR, + msg.header.frame_id if hasattr(msg, "header") else "", + ( + msg.header.stamp.sec + msg.header.stamp.nsec / 1e9 + if hasattr(msg, "header") + and hasattr(msg.header, "stamp") + and msg.header.stamp.sec > 0 + else time.time() + ), + ) + ) + + # PnP wrappers + def solve_pnp(self, *args, **kwargs): # type: ignore[no-untyped-def] + return self._impl.solve_pnp(*args, **kwargs) # type: ignore + + def solve_pnp_ransac(self, *args, **kwargs): # type: ignore[no-untyped-def] + return self._impl.solve_pnp_ransac(*args, **kwargs) # type: ignore + + def solve_pnp_batch(self, *args, **kwargs): # type: ignore[no-untyped-def] + return self._impl.solve_pnp_batch(*args, **kwargs) # type: ignore + + def create_csrt_tracker(self, *args, **kwargs): # type: ignore[no-untyped-def] + return self._impl.create_csrt_tracker(*args, **kwargs) # type: ignore + + def csrt_update(self, *args, **kwargs): # type: ignore[no-untyped-def] + return self._impl.csrt_update(*args, **kwargs) # type: ignore + + @classmethod + def from_ros_msg(cls, ros_msg: ROSImage) -> Image: + """Create an Image from a ROS sensor_msgs/Image message. + + Args: + ros_msg: ROS Image message + + Returns: + Image instance + """ + # Convert timestamp from ROS header + ts = ros_msg.header.stamp.sec + (ros_msg.header.stamp.nanosec / 1_000_000_000) + + # Parse encoding to determine format and data type + format_info = cls._parse_encoding(ros_msg.encoding) + + # Convert data from ROS message (array.array) to numpy array + data_array = np.frombuffer(ros_msg.data, dtype=format_info["dtype"]) + + # Reshape to image dimensions + if format_info["channels"] == 1: + data_array = data_array.reshape((ros_msg.height, ros_msg.width)) + else: + data_array = data_array.reshape( + (ros_msg.height, ros_msg.width, format_info["channels"]) + ) + + # Crop to center 1/3 of the image (simulate 120-degree FOV from 360-degree) + original_width = data_array.shape[1] + crop_width = original_width // 3 + start_x = (original_width - crop_width) // 2 + end_x = start_x + crop_width + + # Crop the image horizontally to center 1/3 + if len(data_array.shape) == 2: + # Grayscale image + data_array = data_array[:, start_x:end_x] + else: + # Color image + data_array = data_array[:, start_x:end_x, :] + + # Fix color channel order: if ROS sends RGB but we expect BGR, swap channels + # ROS typically uses rgb8 encoding, but OpenCV/our system expects BGR + if format_info["format"] == ImageFormat.RGB: + # Convert RGB to BGR by swapping channels + if len(data_array.shape) == 3 and data_array.shape[2] == 3: + data_array = data_array[:, :, [2, 1, 0]] # RGB -> BGR + format_info["format"] = ImageFormat.BGR + elif format_info["format"] == ImageFormat.RGBA: + # Convert RGBA to BGRA by swapping channels + if len(data_array.shape) == 3 and data_array.shape[2] == 4: + data_array = data_array[:, :, [2, 1, 0, 3]] # RGBA -> BGRA + format_info["format"] = ImageFormat.BGRA + + return cls( + data=data_array, + format=format_info["format"], + frame_id=ros_msg.header.frame_id, + ts=ts, + ) + + @staticmethod + def _parse_encoding(encoding: str) -> dict: # type: ignore[type-arg] + """Translate ROS encoding strings into format metadata.""" + encoding_map = { + "mono8": {"format": ImageFormat.GRAY, "dtype": np.uint8, "channels": 1}, + "mono16": {"format": ImageFormat.GRAY16, "dtype": np.uint16, "channels": 1}, + "rgb8": {"format": ImageFormat.RGB, "dtype": np.uint8, "channels": 3}, + "rgba8": {"format": ImageFormat.RGBA, "dtype": np.uint8, "channels": 4}, + "bgr8": {"format": ImageFormat.BGR, "dtype": np.uint8, "channels": 3}, + "bgra8": {"format": ImageFormat.BGRA, "dtype": np.uint8, "channels": 4}, + "32FC1": {"format": ImageFormat.DEPTH, "dtype": np.float32, "channels": 1}, + "32FC3": {"format": ImageFormat.RGB, "dtype": np.float32, "channels": 3}, + "64FC1": {"format": ImageFormat.DEPTH, "dtype": np.float64, "channels": 1}, + "16UC1": {"format": ImageFormat.DEPTH16, "dtype": np.uint16, "channels": 1}, + "16SC1": {"format": ImageFormat.DEPTH16, "dtype": np.int16, "channels": 1}, + } + + key = encoding.strip() + for candidate in (key, key.lower(), key.upper()): + if candidate in encoding_map: + return dict(encoding_map[candidate]) + + raise ValueError(f"Unsupported encoding: {encoding}") + + def __repr__(self) -> str: + dev = "cuda" if self.is_cuda else "cpu" + return f"Image(shape={self.shape}, format={self.format.value}, dtype={self.dtype}, dev={dev}, frame_id='{self.frame_id}', ts={self.ts})" + + def __eq__(self, other) -> bool: # type: ignore[no-untyped-def] + if not isinstance(other, Image): + return False + return ( + np.array_equal(self.data, other.data) + and self.format == other.format + and self.frame_id == other.frame_id + and abs(self.ts - other.ts) < 1e-6 + ) + + def __len__(self) -> int: + return int(self.height * self.width) + + def __getstate__(self): # type: ignore[no-untyped-def] + return {"data": self.data, "format": self.format, "frame_id": self.frame_id, "ts": self.ts} + + def __setstate__(self, state) -> None: # type: ignore[no-untyped-def] + self.__init__( # type: ignore[misc] + data=state.get("data"), + format=state.get("format"), + frame_id=state.get("frame_id"), + ts=state.get("ts"), + ) + + +# Re-exports for tests +HAS_CUDA = HAS_CUDA +ImageFormat = ImageFormat +NVIMGCODEC_LAST_USED = NVIMGCODEC_LAST_USED +HAS_NVIMGCODEC = HAS_NVIMGCODEC +__all__ = [ + "HAS_CUDA", + "HAS_NVIMGCODEC", + "NVIMGCODEC_LAST_USED", + "ImageFormat", + "sharpness_barrier", + "sharpness_window", +] + + +def sharpness_window(target_frequency: float, source: Observable[Image]) -> Observable[Image]: + """Emit the sharpest Image seen within each sliding time window.""" + if target_frequency <= 0: + raise ValueError("target_frequency must be positive") + + window = TimestampedBufferCollection(1.0 / target_frequency) # type: ignore[var-annotated] + source.subscribe(window.add) + + thread_scheduler = ThreadPoolScheduler(max_workers=1) # type: ignore[name-defined] + + def find_best(*_args): # type: ignore[no-untyped-def] + if not window._items: + return None + return max(window._items, key=lambda img: img.sharpness) + + return rx.interval(1.0 / target_frequency).pipe( + ops.observe_on(thread_scheduler), + ops.map(find_best), + ops.filter(lambda img: img is not None), + ) + + +def sharpness_barrier(target_frequency: float): # type: ignore[no-untyped-def] + """Select the sharpest Image within each time window.""" + if target_frequency <= 0: + raise ValueError("target_frequency must be positive") + return quality_barrier(lambda image: image.sharpness, target_frequency) # type: ignore[attr-defined] + + +def _get_lcm_encoding(fmt: ImageFormat, dtype: np.dtype) -> str: # type: ignore[type-arg] + if fmt == ImageFormat.GRAY: + if dtype == np.uint8: + return "mono8" + if dtype == np.uint16: + return "mono16" + if fmt == ImageFormat.GRAY16: + return "mono16" + if fmt == ImageFormat.RGB: + return "rgb8" + if fmt == ImageFormat.RGBA: + return "rgba8" + if fmt == ImageFormat.BGR: + return "bgr8" + if fmt == ImageFormat.BGRA: + return "bgra8" + if fmt == ImageFormat.DEPTH: + if dtype == np.float32: + return "32FC1" + if dtype == np.float64: + return "64FC1" + if fmt == ImageFormat.DEPTH16: + if dtype == np.uint16: + return "16UC1" + if dtype == np.int16: + return "16SC1" + raise ValueError(f"Unsupported LCM encoding for fmt={fmt}, dtype={dtype}") + + +def _parse_lcm_encoding(enc: str): # type: ignore[no-untyped-def] + m = { + "mono8": (ImageFormat.GRAY, np.uint8, 1), + "mono16": (ImageFormat.GRAY16, np.uint16, 1), + "rgb8": (ImageFormat.RGB, np.uint8, 3), + "rgba8": (ImageFormat.RGBA, np.uint8, 4), + "bgr8": (ImageFormat.BGR, np.uint8, 3), + "bgra8": (ImageFormat.BGRA, np.uint8, 4), + "32FC1": (ImageFormat.DEPTH, np.float32, 1), + "32FC3": (ImageFormat.RGB, np.float32, 3), + "64FC1": (ImageFormat.DEPTH, np.float64, 1), + "16UC1": (ImageFormat.DEPTH16, np.uint16, 1), + "16SC1": (ImageFormat.DEPTH16, np.int16, 1), + } + if enc not in m: + raise ValueError(f"Unsupported encoding: {enc}") + return m[enc] diff --git a/dimos/msgs/sensor_msgs/Joy.py b/dimos/msgs/sensor_msgs/Joy.py new file mode 100644 index 0000000000..b9f823584c --- /dev/null +++ b/dimos/msgs/sensor_msgs/Joy.py @@ -0,0 +1,181 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import time +from typing import TypeAlias + +from dimos_lcm.sensor_msgs import Joy as LCMJoy # type: ignore[import-untyped] + +try: + from sensor_msgs.msg import Joy as ROSJoy # type: ignore[attr-defined] +except ImportError: + ROSJoy = None # type: ignore[assignment, misc] + +from plum import dispatch + +from dimos.types.timestamped import Timestamped + +# Types that can be converted to/from Joy +JoyConvertable: TypeAlias = ( + tuple[list[float], list[int]] | dict[str, list[float] | list[int]] | LCMJoy +) + + +def sec_nsec(ts): # type: ignore[no-untyped-def] + s = int(ts) + return [s, int((ts - s) * 1_000_000_000)] + + +class Joy(Timestamped): + msg_name = "sensor_msgs.Joy" + ts: float + frame_id: str + axes: list[float] + buttons: list[int] + + @dispatch + def __init__( + self, + ts: float = 0.0, + frame_id: str = "", + axes: list[float] | None = None, + buttons: list[int] | None = None, + ) -> None: + """Initialize a Joy message. + + Args: + ts: Timestamp in seconds + frame_id: Frame ID for the message + axes: List of axis values (typically -1.0 to 1.0) + buttons: List of button states (0 or 1) + """ + self.ts = ts if ts != 0 else time.time() + self.frame_id = frame_id + self.axes = axes if axes is not None else [] + self.buttons = buttons if buttons is not None else [] + + @dispatch # type: ignore[no-redef] + def __init__(self, joy_tuple: tuple[list[float], list[int]]) -> None: + """Initialize from a tuple of (axes, buttons).""" + self.ts = time.time() + self.frame_id = "" + self.axes = list(joy_tuple[0]) + self.buttons = list(joy_tuple[1]) + + @dispatch # type: ignore[no-redef] + def __init__(self, joy_dict: dict[str, list[float] | list[int]]) -> None: + """Initialize from a dictionary with 'axes' and 'buttons' keys.""" + self.ts = joy_dict.get("ts", time.time()) + self.frame_id = joy_dict.get("frame_id", "") + self.axes = list(joy_dict.get("axes", [])) + self.buttons = list(joy_dict.get("buttons", [])) + + @dispatch # type: ignore[no-redef] + def __init__(self, joy: Joy) -> None: + """Initialize from another Joy (copy constructor).""" + self.ts = joy.ts + self.frame_id = joy.frame_id + self.axes = list(joy.axes) + self.buttons = list(joy.buttons) + + @dispatch # type: ignore[no-redef] + def __init__(self, lcm_joy: LCMJoy) -> None: + """Initialize from an LCM Joy message.""" + self.ts = lcm_joy.header.stamp.sec + (lcm_joy.header.stamp.nsec / 1_000_000_000) + self.frame_id = lcm_joy.header.frame_id + self.axes = list(lcm_joy.axes) + self.buttons = list(lcm_joy.buttons) + + def lcm_encode(self) -> bytes: + lcm_msg = LCMJoy() + [lcm_msg.header.stamp.sec, lcm_msg.header.stamp.nsec] = sec_nsec(self.ts) # type: ignore[no-untyped-call] + lcm_msg.header.frame_id = self.frame_id + lcm_msg.axes_length = len(self.axes) + lcm_msg.axes = self.axes + lcm_msg.buttons_length = len(self.buttons) + lcm_msg.buttons = self.buttons + return lcm_msg.lcm_encode() # type: ignore[no-any-return] + + @classmethod + def lcm_decode(cls, data: bytes) -> Joy: + lcm_msg = LCMJoy.lcm_decode(data) + return cls( + ts=lcm_msg.header.stamp.sec + (lcm_msg.header.stamp.nsec / 1_000_000_000), + frame_id=lcm_msg.header.frame_id, + axes=list(lcm_msg.axes) if lcm_msg.axes else [], + buttons=list(lcm_msg.buttons) if lcm_msg.buttons else [], + ) + + def __str__(self) -> str: + return ( + f"Joy(axes={len(self.axes)} values, buttons={len(self.buttons)} values, " + f"frame_id='{self.frame_id}')" + ) + + def __repr__(self) -> str: + return ( + f"Joy(ts={self.ts}, frame_id='{self.frame_id}', " + f"axes={self.axes}, buttons={self.buttons})" + ) + + def __eq__(self, other) -> bool: # type: ignore[no-untyped-def] + """Check if two Joy messages are equal.""" + if not isinstance(other, Joy): + return False + return ( + self.axes == other.axes + and self.buttons == other.buttons + and self.frame_id == other.frame_id + ) + + @classmethod + def from_ros_msg(cls, ros_msg: ROSJoy) -> Joy: + """Create a Joy from a ROS sensor_msgs/Joy message. + + Args: + ros_msg: ROS Joy message + + Returns: + Joy instance + """ + # Convert timestamp from ROS header + ts = ros_msg.header.stamp.sec + (ros_msg.header.stamp.nanosec / 1_000_000_000) + + return cls( + ts=ts, + frame_id=ros_msg.header.frame_id, + axes=list(ros_msg.axes), + buttons=list(ros_msg.buttons), + ) + + def to_ros_msg(self) -> ROSJoy: + """Convert to a ROS sensor_msgs/Joy message. + + Returns: + ROS Joy message + """ + ros_msg = ROSJoy() # type: ignore[no-untyped-call] + + # Set header + ros_msg.header.frame_id = self.frame_id + ros_msg.header.stamp.sec = int(self.ts) + ros_msg.header.stamp.nanosec = int((self.ts - int(self.ts)) * 1_000_000_000) + + # Set axes and buttons + ros_msg.axes = self.axes + ros_msg.buttons = self.buttons + + return ros_msg diff --git a/dimos/msgs/sensor_msgs/PointCloud2.py b/dimos/msgs/sensor_msgs/PointCloud2.py new file mode 100644 index 0000000000..b64ee3021a --- /dev/null +++ b/dimos/msgs/sensor_msgs/PointCloud2.py @@ -0,0 +1,557 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import functools +import struct + +# Import LCM types +from dimos_lcm.sensor_msgs.PointCloud2 import ( # type: ignore[import-untyped] + PointCloud2 as LCMPointCloud2, +) +from dimos_lcm.sensor_msgs.PointField import PointField # type: ignore[import-untyped] +from dimos_lcm.std_msgs.Header import Header # type: ignore[import-untyped] +import numpy as np +import open3d as o3d # type: ignore[import-untyped] + +from dimos.msgs.geometry_msgs import Vector3 + +# Import ROS types +try: + from sensor_msgs.msg import ( # type: ignore[attr-defined] + PointCloud2 as ROSPointCloud2, + PointField as ROSPointField, + ) + from std_msgs.msg import Header as ROSHeader # type: ignore[attr-defined] + + ROS_AVAILABLE = True +except ImportError: + ROS_AVAILABLE = False + +from dimos.types.timestamped import Timestamped + + +# TODO: encode/decode need to be updated to work with full spectrum of pointcloud2 fields +class PointCloud2(Timestamped): + msg_name = "sensor_msgs.PointCloud2" + + def __init__( + self, + pointcloud: o3d.geometry.PointCloud = None, + frame_id: str = "world", + ts: float | None = None, + ) -> None: + self.ts = ts # type: ignore[assignment] + self.pointcloud = pointcloud if pointcloud is not None else o3d.geometry.PointCloud() + self.frame_id = frame_id + + @classmethod + def from_numpy( + cls, + points: np.ndarray, # type: ignore[type-arg] + frame_id: str = "world", + timestamp: float | None = None, + ) -> PointCloud2: + """Create PointCloud2 from numpy array of shape (N, 3). + + Args: + points: Nx3 numpy array of 3D points + frame_id: Frame ID for the point cloud + timestamp: Timestamp for the point cloud (defaults to current time) + + Returns: + PointCloud2 instance + """ + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(points) + return cls(pointcloud=pcd, ts=timestamp, frame_id=frame_id) + + def __str__(self) -> str: + return f"PointCloud2(frame_id='{self.frame_id}', num_points={len(self.pointcloud.points)})" + + @functools.cached_property + def center(self) -> Vector3: + """Calculate the center of the pointcloud in world frame.""" + center = np.asarray(self.pointcloud.points).mean(axis=0) + return Vector3(*center) + + def points(self): # type: ignore[no-untyped-def] + return self.pointcloud.points + + def __add__(self, other: PointCloud2) -> PointCloud2: + """Combine two PointCloud2 instances into one. + + The resulting point cloud contains points from both inputs. + The frame_id and timestamp are taken from the first point cloud. + + Args: + other: Another PointCloud2 instance to combine with + + Returns: + New PointCloud2 instance containing combined points + """ + if not isinstance(other, PointCloud2): + raise ValueError("Can only add PointCloud2 to another PointCloud2") + + return PointCloud2( + pointcloud=self.pointcloud + other.pointcloud, + frame_id=self.frame_id, + ts=max(self.ts, other.ts), + ) + + # TODO what's the usual storage here? is it already numpy? + def as_numpy(self) -> np.ndarray: # type: ignore[type-arg] + """Get points as numpy array.""" + return np.asarray(self.pointcloud.points) + + @functools.cache + def get_axis_aligned_bounding_box(self) -> o3d.geometry.AxisAlignedBoundingBox: + """Get axis-aligned bounding box of the point cloud.""" + return self.pointcloud.get_axis_aligned_bounding_box() + + @functools.cache + def get_oriented_bounding_box(self) -> o3d.geometry.OrientedBoundingBox: + """Get oriented bounding box of the point cloud.""" + return self.pointcloud.get_oriented_bounding_box() + + @functools.cache + def get_bounding_box_dimensions(self) -> tuple[float, float, float]: + """Get dimensions (width, height, depth) of axis-aligned bounding box.""" + bbox = self.get_axis_aligned_bounding_box() + extent = bbox.get_extent() + return tuple(extent) + + def bounding_box_intersects(self, other: PointCloud2) -> bool: + # Get axis-aligned bounding boxes + bbox1 = self.get_axis_aligned_bounding_box() + bbox2 = other.get_axis_aligned_bounding_box() + + # Get min and max bounds + min1 = bbox1.get_min_bound() + max1 = bbox1.get_max_bound() + min2 = bbox2.get_min_bound() + max2 = bbox2.get_max_bound() + + # Check overlap in all three dimensions + # Boxes intersect if they overlap in ALL dimensions + return ( # type: ignore[no-any-return] + min1[0] <= max2[0] + and max1[0] >= min2[0] + and min1[1] <= max2[1] + and max1[1] >= min2[1] + and min1[2] <= max2[2] + and max1[2] >= min2[2] + ) + + def lcm_encode(self, frame_id: str | None = None) -> bytes: + """Convert to LCM PointCloud2 message.""" + msg = LCMPointCloud2() + + # Header + msg.header = Header() + msg.header.seq = 0 # Initialize sequence number + msg.header.frame_id = frame_id or self.frame_id + + msg.header.stamp.sec = int(self.ts) + msg.header.stamp.nsec = int((self.ts - int(self.ts)) * 1e9) + + points = self.as_numpy() + if len(points) == 0: + # Empty point cloud + msg.height = 0 + msg.width = 0 + msg.point_step = 16 # 4 floats * 4 bytes (x, y, z, intensity) + msg.row_step = 0 + msg.data_length = 0 + msg.data = b"" + msg.is_dense = True + msg.is_bigendian = False + msg.fields_length = 4 # x, y, z, intensity + msg.fields = self._create_xyz_field() + return msg.lcm_encode() # type: ignore[no-any-return] + + # Point cloud dimensions + msg.height = 1 # Unorganized point cloud + msg.width = len(points) + + # Define fields (X, Y, Z, intensity as float32) + msg.fields_length = 4 # x, y, z, intensity + msg.fields = self._create_xyz_field() + + # Point step and row step + msg.point_step = 16 # 4 floats * 4 bytes each (x, y, z, intensity) + msg.row_step = msg.point_step * msg.width + + # Convert points to bytes with intensity padding (little endian float32) + # Add intensity column (zeros) to make it 4 columns: x, y, z, intensity + points_with_intensity = np.column_stack( + [ + points, # x, y, z columns + np.zeros(len(points), dtype=np.float32), # intensity column (padding) + ] + ) + data_bytes = points_with_intensity.astype(np.float32).tobytes() + msg.data_length = len(data_bytes) + msg.data = data_bytes + + # Properties + msg.is_dense = True # No invalid points + msg.is_bigendian = False # Little endian + + return msg.lcm_encode() # type: ignore[no-any-return] + + @classmethod + def lcm_decode(cls, data: bytes) -> PointCloud2: + msg = LCMPointCloud2.lcm_decode(data) + + if msg.width == 0 or msg.height == 0: + # Empty point cloud + pc = o3d.geometry.PointCloud() + return cls( + pointcloud=pc, + frame_id=msg.header.frame_id if hasattr(msg, "header") else "", + ts=msg.header.stamp.sec + msg.header.stamp.nsec / 1e9 + if hasattr(msg, "header") and msg.header.stamp.sec > 0 + else None, + ) + + # Parse field information to find X, Y, Z offsets + x_offset = y_offset = z_offset = None + for msgfield in msg.fields: + if msgfield.name == "x": + x_offset = msgfield.offset + elif msgfield.name == "y": + y_offset = msgfield.offset + elif msgfield.name == "z": + z_offset = msgfield.offset + + if any(offset is None for offset in [x_offset, y_offset, z_offset]): + raise ValueError("PointCloud2 message missing X, Y, or Z msgfields") + + # Extract points from binary data + num_points = msg.width * msg.height + points = np.zeros((num_points, 3), dtype=np.float32) + + data = msg.data + point_step = msg.point_step + + for i in range(num_points): + base_offset = i * point_step + + # Extract X, Y, Z (assuming float32, little endian) + x_bytes = data[base_offset + x_offset : base_offset + x_offset + 4] + y_bytes = data[base_offset + y_offset : base_offset + y_offset + 4] + z_bytes = data[base_offset + z_offset : base_offset + z_offset + 4] + + points[i, 0] = struct.unpack(" 0 + else None, + ) + + def _create_xyz_field(self) -> list: # type: ignore[type-arg] + """Create standard X, Y, Z field definitions for LCM PointCloud2.""" + fields = [] + + # X field + x_field = PointField() + x_field.name = "x" + x_field.offset = 0 + x_field.datatype = 7 # FLOAT32 + x_field.count = 1 + fields.append(x_field) + + # Y field + y_field = PointField() + y_field.name = "y" + y_field.offset = 4 + y_field.datatype = 7 # FLOAT32 + y_field.count = 1 + fields.append(y_field) + + # Z field + z_field = PointField() + z_field.name = "z" + z_field.offset = 8 + z_field.datatype = 7 # FLOAT32 + z_field.count = 1 + fields.append(z_field) + + # I field + i_field = PointField() + i_field.name = "intensity" + i_field.offset = 12 + i_field.datatype = 7 # FLOAT32 + i_field.count = 1 + fields.append(i_field) + + return fields + + def __len__(self) -> int: + """Return number of points.""" + return len(self.pointcloud.points) + + def filter_by_height( + self, + min_height: float | None = None, + max_height: float | None = None, + ) -> PointCloud2: + """Filter points based on their height (z-coordinate). + + This method creates a new PointCloud2 containing only points within the specified + height range. All metadata (frame_id, timestamp) is preserved. + + Args: + min_height: Optional minimum height threshold. Points with z < min_height are filtered out. + If None, no lower limit is applied. + max_height: Optional maximum height threshold. Points with z > max_height are filtered out. + If None, no upper limit is applied. + + Returns: + New PointCloud2 instance containing only the filtered points. + + Raises: + ValueError: If both min_height and max_height are None (no filtering would occur). + + Example: + # Remove ground points below 0.1m height + filtered_pc = pointcloud.filter_by_height(min_height=0.1) + + # Keep only points between ground level and 2m height + filtered_pc = pointcloud.filter_by_height(min_height=0.0, max_height=2.0) + + # Remove points above 1.5m (e.g., ceiling) + filtered_pc = pointcloud.filter_by_height(max_height=1.5) + """ + # Validate that at least one threshold is provided + if min_height is None and max_height is None: + raise ValueError("At least one of min_height or max_height must be specified") + + # Get points as numpy array + points = self.as_numpy() + + if len(points) == 0: + # Empty pointcloud - return a copy + return PointCloud2( + pointcloud=o3d.geometry.PointCloud(), + frame_id=self.frame_id, + ts=self.ts, + ) + + # Extract z-coordinates (height values) - column index 2 + heights = points[:, 2] + + # Create boolean mask for filtering based on height thresholds + # Start with all True values + mask = np.ones(len(points), dtype=bool) + + # Apply minimum height filter if specified + if min_height is not None: + mask &= heights >= min_height + + # Apply maximum height filter if specified + if max_height is not None: + mask &= heights <= max_height + + # Apply mask to filter points + filtered_points = points[mask] + + # Create new PointCloud2 with filtered points + return PointCloud2.from_numpy( + points=filtered_points, + frame_id=self.frame_id, + timestamp=self.ts, + ) + + def __repr__(self) -> str: + """String representation.""" + return f"PointCloud(points={len(self)}, frame_id='{self.frame_id}', ts={self.ts})" + + @classmethod + def from_ros_msg(cls, ros_msg: ROSPointCloud2) -> PointCloud2: + """Convert from ROS sensor_msgs/PointCloud2 message. + + Args: + ros_msg: ROS PointCloud2 message + + Returns: + PointCloud2 instance + """ + if not ROS_AVAILABLE: + raise ImportError("ROS packages not available. Cannot convert from ROS message.") + + # Handle empty point cloud + if ros_msg.width == 0 or ros_msg.height == 0: + pc = o3d.geometry.PointCloud() + return cls( + pointcloud=pc, + frame_id=ros_msg.header.frame_id, + ts=ros_msg.header.stamp.sec + ros_msg.header.stamp.nanosec / 1e9, + ) + + # Parse field information to find X, Y, Z offsets + x_offset = y_offset = z_offset = None + for field in ros_msg.fields: + if field.name == "x": + x_offset = field.offset + elif field.name == "y": + y_offset = field.offset + elif field.name == "z": + z_offset = field.offset + + if any(offset is None for offset in [x_offset, y_offset, z_offset]): + raise ValueError("PointCloud2 message missing X, Y, or Z fields") + + # Extract points from binary data using numpy for bulk conversion + num_points = ros_msg.width * ros_msg.height + data = ros_msg.data + point_step = ros_msg.point_step + + # Determine byte order + byte_order = ">" if ros_msg.is_bigendian else "<" + + # Check if we can use fast numpy path (common case: sequential float32 x,y,z) + if ( + x_offset == 0 + and y_offset == 4 + and z_offset == 8 + and point_step >= 12 + and not ros_msg.is_bigendian + ): + # Fast path: direct numpy reshape for tightly packed float32 x,y,z + # This is the most common case for point clouds + if point_step == 12: + # Perfectly packed x,y,z with no padding + points = np.frombuffer(data, dtype=np.float32).reshape(-1, 3) + else: + # Has additional fields after x,y,z, need to extract with stride + dt = np.dtype( + [("x", " 0: # type: ignore[operator] + dt_fields.append(("_pad_x", f"V{x_offset}")) + dt_fields.append(("x", f"{byte_order}f4")) + + # Add padding between x and y if needed + gap_xy = y_offset - x_offset - 4 # type: ignore[operator] + if gap_xy > 0: + dt_fields.append(("_pad_xy", f"V{gap_xy}")) + dt_fields.append(("y", f"{byte_order}f4")) + + # Add padding between y and z if needed + gap_yz = z_offset - y_offset - 4 # type: ignore[operator] + if gap_yz > 0: + dt_fields.append(("_pad_yz", f"V{gap_yz}")) + dt_fields.append(("z", f"{byte_order}f4")) + + # Add padding at the end to match point_step + remaining = point_step - z_offset - 4 + if remaining > 0: + dt_fields.append(("_pad_end", f"V{remaining}")) + + dt = np.dtype(dt_fields) + structured = np.frombuffer(data, dtype=dt, count=num_points) + points = np.column_stack((structured["x"], structured["y"], structured["z"])) + + # Filter out NaN and Inf values if not dense + if not ros_msg.is_dense: + mask = np.isfinite(points).all(axis=1) + points = points[mask] + + # Create Open3D point cloud + pc = o3d.geometry.PointCloud() + pc.points = o3d.utility.Vector3dVector(points) + + # Extract timestamp + ts = ros_msg.header.stamp.sec + ros_msg.header.stamp.nanosec / 1e9 + + return cls( + pointcloud=pc, + frame_id=ros_msg.header.frame_id, + ts=ts, + ) + + def to_ros_msg(self) -> ROSPointCloud2: + """Convert to ROS sensor_msgs/PointCloud2 message. + + Returns: + ROS PointCloud2 message + """ + if not ROS_AVAILABLE: + raise ImportError("ROS packages not available. Cannot convert to ROS message.") + + ros_msg = ROSPointCloud2() # type: ignore[no-untyped-call] + + # Set header + ros_msg.header = ROSHeader() # type: ignore[no-untyped-call] + ros_msg.header.frame_id = self.frame_id + ros_msg.header.stamp.sec = int(self.ts) + ros_msg.header.stamp.nanosec = int((self.ts - int(self.ts)) * 1e9) + + points = self.as_numpy() + + if len(points) == 0: + # Empty point cloud + ros_msg.height = 0 + ros_msg.width = 0 + ros_msg.fields = [] + ros_msg.is_bigendian = False + ros_msg.point_step = 0 + ros_msg.row_step = 0 + ros_msg.data = b"" + ros_msg.is_dense = True + return ros_msg + + # Set dimensions + ros_msg.height = 1 # Unorganized point cloud + ros_msg.width = len(points) + + # Define fields (X, Y, Z as float32) + ros_msg.fields = [ + ROSPointField(name="x", offset=0, datatype=ROSPointField.FLOAT32, count=1), # type: ignore[no-untyped-call] + ROSPointField(name="y", offset=4, datatype=ROSPointField.FLOAT32, count=1), # type: ignore[no-untyped-call] + ROSPointField(name="z", offset=8, datatype=ROSPointField.FLOAT32, count=1), # type: ignore[no-untyped-call] + ] + + # Set point step and row step + ros_msg.point_step = 12 # 3 floats * 4 bytes each + ros_msg.row_step = ros_msg.point_step * ros_msg.width + + # Convert points to bytes (little endian float32) + ros_msg.data = points.astype(np.float32).tobytes() + + # Set properties + ros_msg.is_bigendian = False # Little endian + ros_msg.is_dense = True # No invalid points + + return ros_msg diff --git a/dimos/msgs/sensor_msgs/__init__.py b/dimos/msgs/sensor_msgs/__init__.py new file mode 100644 index 0000000000..130df72964 --- /dev/null +++ b/dimos/msgs/sensor_msgs/__init__.py @@ -0,0 +1,6 @@ +from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat +from dimos.msgs.sensor_msgs.Joy import Joy +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 + +__all__ = ["CameraInfo", "Image", "ImageFormat", "Joy", "PointCloud2"] diff --git a/dimos/msgs/sensor_msgs/image_impls/AbstractImage.py b/dimos/msgs/sensor_msgs/image_impls/AbstractImage.py new file mode 100644 index 0000000000..09e5ee4a81 --- /dev/null +++ b/dimos/msgs/sensor_msgs/image_impls/AbstractImage.py @@ -0,0 +1,212 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import ABC, abstractmethod +import base64 +from enum import Enum +import os +from typing import Any + +import cv2 +import numpy as np + +try: + import cupy as cp # type: ignore + + HAS_CUDA = True +except Exception: # pragma: no cover - optional dependency + cp = None + HAS_CUDA = False + +# Optional nvImageCodec (preferred GPU codec) +USE_NVIMGCODEC = os.environ.get("USE_NVIMGCODEC", "0") == "1" +NVIMGCODEC_LAST_USED = False +try: # pragma: no cover - optional dependency + if HAS_CUDA and USE_NVIMGCODEC: + from nvidia import nvimgcodec # type: ignore + + try: + _enc_probe = nvimgcodec.Encoder() + HAS_NVIMGCODEC = True + except Exception: + nvimgcodec = None + HAS_NVIMGCODEC = False + else: + nvimgcodec = None + HAS_NVIMGCODEC = False +except Exception: # pragma: no cover - optional dependency + nvimgcodec = None + HAS_NVIMGCODEC = False + + +class ImageFormat(Enum): + BGR = "BGR" + RGB = "RGB" + RGBA = "RGBA" + BGRA = "BGRA" + GRAY = "GRAY" + GRAY16 = "GRAY16" + DEPTH = "DEPTH" + DEPTH16 = "DEPTH16" + + +def _is_cu(x) -> bool: # type: ignore[no-untyped-def] + return HAS_CUDA and cp is not None and isinstance(x, cp.ndarray) + + +def _ascontig(x): # type: ignore[no-untyped-def] + if _is_cu(x): + return x if x.flags["C_CONTIGUOUS"] else cp.ascontiguousarray(x) + return x if x.flags["C_CONTIGUOUS"] else np.ascontiguousarray(x) + + +def _to_cpu(x): # type: ignore[no-untyped-def] + return cp.asnumpy(x) if _is_cu(x) else x + + +def _to_cu(x): # type: ignore[no-untyped-def] + if HAS_CUDA and cp is not None and isinstance(x, np.ndarray): + return cp.asarray(x) + return x + + +def _encode_nvimgcodec_cuda(bgr_cu, quality: int = 80) -> bytes: # type: ignore[no-untyped-def] # pragma: no cover - optional + if not HAS_NVIMGCODEC or nvimgcodec is None: + raise RuntimeError("nvimgcodec not available") + if bgr_cu.ndim != 3 or bgr_cu.shape[2] != 3: + raise RuntimeError("nvimgcodec expects HxWx3 image") + if bgr_cu.dtype != cp.uint8: + raise RuntimeError("nvimgcodec requires uint8 input") + if not bgr_cu.flags["C_CONTIGUOUS"]: + bgr_cu = cp.ascontiguousarray(bgr_cu) + encoder = nvimgcodec.Encoder() + try: + img = nvimgcodec.Image(bgr_cu, nvimgcodec.PixelFormat.BGR) + except Exception: + img = nvimgcodec.Image(cp.asnumpy(bgr_cu), nvimgcodec.PixelFormat.BGR) + if hasattr(nvimgcodec, "EncodeParams"): + params = nvimgcodec.EncodeParams(quality=quality) + bitstreams = encoder.encode([img], [params]) + else: + bitstreams = encoder.encode([img]) + bs0 = bitstreams[0] + if hasattr(bs0, "buf"): + return bytes(bs0.buf) + return bytes(bs0) + + +class AbstractImage(ABC): + data: Any + format: ImageFormat + frame_id: str + ts: float + + @property + @abstractmethod + def is_cuda(self) -> bool: # pragma: no cover - abstract + ... + + @property + def height(self) -> int: + return int(self.data.shape[0]) + + @property + def width(self) -> int: + return int(self.data.shape[1]) + + @property + def channels(self) -> int: + if getattr(self.data, "ndim", 0) == 2: + return 1 + if getattr(self.data, "ndim", 0) == 3: + return int(self.data.shape[2]) + raise ValueError("Invalid image dimensions") + + @property + def shape(self): # type: ignore[no-untyped-def] + return tuple(self.data.shape) + + @property + def dtype(self): # type: ignore[no-untyped-def] + return self.data.dtype + + @abstractmethod + def to_opencv(self) -> np.ndarray: # type: ignore[type-arg] # pragma: no cover - abstract + ... + + @abstractmethod + def to_rgb(self) -> AbstractImage: # pragma: no cover - abstract + ... + + @abstractmethod + def to_bgr(self) -> AbstractImage: # pragma: no cover - abstract + ... + + @abstractmethod + def to_grayscale(self) -> AbstractImage: # pragma: no cover - abstract + ... + + @abstractmethod + def resize( + self, width: int, height: int, interpolation: int = cv2.INTER_LINEAR + ) -> AbstractImage: # pragma: no cover - abstract + ... + + @abstractmethod + def sharpness(self) -> float: # pragma: no cover - abstract + ... + + def copy(self) -> AbstractImage: + return self.__class__( + data=self.data.copy(), format=self.format, frame_id=self.frame_id, ts=self.ts + ) # type: ignore + + def save(self, filepath: str) -> bool: + global NVIMGCODEC_LAST_USED + if self.is_cuda and HAS_NVIMGCODEC and nvimgcodec is not None: + try: + bgr = self.to_bgr() + if _is_cu(bgr.data): + jpeg = _encode_nvimgcodec_cuda(bgr.data) + NVIMGCODEC_LAST_USED = True + with open(filepath, "wb") as f: + f.write(jpeg) + return True + except Exception: + NVIMGCODEC_LAST_USED = False + arr = self.to_opencv() + return cv2.imwrite(filepath, arr) + + def to_base64(self, quality: int = 80) -> str: + global NVIMGCODEC_LAST_USED + if self.is_cuda and HAS_NVIMGCODEC and nvimgcodec is not None: + try: + bgr = self.to_bgr() + if _is_cu(bgr.data): + jpeg = _encode_nvimgcodec_cuda(bgr.data, quality=quality) + NVIMGCODEC_LAST_USED = True + return base64.b64encode(jpeg).decode("utf-8") + except Exception: + NVIMGCODEC_LAST_USED = False + bgr = self.to_bgr() + success, buffer = cv2.imencode( + ".jpg", + _to_cpu(bgr.data), # type: ignore[no-untyped-call] + [int(cv2.IMWRITE_JPEG_QUALITY), int(quality)], + ) + if not success: + raise ValueError("Failed to encode image as JPEG") + return base64.b64encode(buffer.tobytes()).decode("utf-8") diff --git a/dimos/msgs/sensor_msgs/image_impls/CudaImage.py b/dimos/msgs/sensor_msgs/image_impls/CudaImage.py new file mode 100644 index 0000000000..ec3ab346da --- /dev/null +++ b/dimos/msgs/sensor_msgs/image_impls/CudaImage.py @@ -0,0 +1,939 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass, field +import time + +import cv2 +import numpy as np + +from dimos.msgs.sensor_msgs.image_impls.AbstractImage import ( + HAS_CUDA, + AbstractImage, + ImageFormat, + _ascontig, + _is_cu, + _to_cpu, +) + +try: + import cupy as cp # type: ignore + from cupyx.scipy import ( # type: ignore[import-not-found] + ndimage as cndimage, + signal as csignal, + ) +except Exception: # pragma: no cover + cp = None + cndimage = None + csignal = None + + +_CUDA_SRC = r""" +extern "C" { + +__device__ __forceinline__ void rodrigues_R(const float r[3], float R[9]){ + float theta = sqrtf(r[0]*r[0] + r[1]*r[1] + r[2]*r[2]); + if(theta < 1e-8f){ + R[0]=1.f; R[1]=0.f; R[2]=0.f; + R[3]=0.f; R[4]=1.f; R[5]=0.f; + R[6]=0.f; R[7]=0.f; R[8]=1.f; + return; + } + float kx=r[0]/theta, ky=r[1]/theta, kz=r[2]/theta; + float c=cosf(theta), s=sinf(theta), v=1.f-c; + R[0]=kx*kx*v + c; R[1]=kx*ky*v - kz*s; R[2]=kx*kz*v + ky*s; + R[3]=ky*kx*v + kz*s; R[4]=ky*ky*v + c; R[5]=ky*kz*v - kx*s; + R[6]=kz*kx*v - ky*s; R[7]=kz*ky*v + kx*s; R[8]=kz*kz*v + c; +} + +__device__ __forceinline__ void mat3x3_vec3(const float R[9], const float x[3], float y[3]){ + y[0] = R[0]*x[0] + R[1]*x[1] + R[2]*x[2]; + y[1] = R[3]*x[0] + R[4]*x[1] + R[5]*x[2]; + y[2] = R[6]*x[0] + R[7]*x[1] + R[8]*x[2]; +} + +__device__ __forceinline__ void cross_mat(const float v[3], float S[9]){ + S[0]=0.f; S[1]=-v[2]; S[2]= v[1]; + S[3]= v[2]; S[4]=0.f; S[5]=-v[0]; + S[6]=-v[1]; S[7]= v[0]; S[8]=0.f; +} + +// Solve a 6x6 system (JTJ * x = JTr) with Gauss-Jordan; JTJ is SPD after damping. +__device__ void solve6_gauss_jordan(float A[36], float b[6], float x[6]){ + float M[6][7]; + #pragma unroll + for(int r=0;r<6;++r){ + #pragma unroll + for(int c=0;c<6;++c) M[r][c] = A[r*6 + c]; + M[r][6] = b[r]; + } + for(int piv=0;piv<6;++piv){ + float invd = 1.f / M[piv][piv]; + for(int c=piv;c<7;++c) M[piv][c] *= invd; + for(int r=0;r<6;++r){ + if(r==piv) continue; + float f = M[r][piv]; + if(fabsf(f) < 1e-20f) continue; + for(int c=piv;c<7;++c) M[r][c] -= f * M[piv][c]; + } + } + #pragma unroll + for(int r=0;r<6;++r) x[r] = M[r][6]; +} + +// One block solves one pose; dynamic shared memory holds per-thread accumulators. +__global__ void pnp_gn_batch( + const float* __restrict__ obj, // (B,N,3) + const float* __restrict__ img, // (B,N,2) + const int N, + const float* __restrict__ intr, // (B,4) -> fx, fy, cx, cy + const int max_iters, + const float damping, + float* __restrict__ rvec_out, // (B,3) + float* __restrict__ tvec_out // (B,3) +){ + if(N <= 0) return; + int b = blockIdx.x; + const float* obj_b = obj + b * N * 3; + const float* img_b = img + b * N * 2; + float fx = intr[4*b + 0]; + float fy = intr[4*b + 1]; + float cx = intr[4*b + 2]; + float cy = intr[4*b + 3]; + + __shared__ float s_R[9]; + __shared__ float s_rvec[3]; + __shared__ float s_tvec[3]; + __shared__ float s_JTJ[36]; + __shared__ float s_JTr[6]; + __shared__ int s_done; + + extern __shared__ float scratch[]; + float* sh_JTJ = scratch; + float* sh_JTr = scratch + 36 * blockDim.x; + + if(threadIdx.x==0){ + s_rvec[0]=0.f; s_rvec[1]=0.f; s_rvec[2]=0.f; + s_tvec[0]=0.f; s_tvec[1]=0.f; s_tvec[2]=2.f; + } + __syncthreads(); + + for(int it=0; itmatrix) for NumPy/CuPy arrays.""" + + if cp is not None and ( + isinstance(x, cp.ndarray) or getattr(x, "__cuda_array_interface__", None) is not None + ): + xp = cp + else: + xp = np + arr = xp.asarray(x, dtype=xp.float64) + + if not inverse and arr.ndim >= 2 and arr.shape[-2:] == (3, 3): + inverse = True + + if not inverse: + vec = arr + if vec.ndim >= 2 and vec.shape[-1] == 1: + vec = vec[..., 0] + if vec.shape[-1] != 3: + raise ValueError("Rodrigues expects vectors of shape (..., 3)") + orig_shape = vec.shape[:-1] + vec = vec.reshape(-1, 3) + n = vec.shape[0] + theta = xp.linalg.norm(vec, axis=1) + small = theta < 1e-12 + + def _skew(v): # type: ignore[no-untyped-def] + vx, vy, vz = v[:, 0], v[:, 1], v[:, 2] + O = xp.zeros_like(vx) + return xp.stack( + [ + xp.stack([O, -vz, vy], axis=-1), + xp.stack([vz, O, -vx], axis=-1), + xp.stack([-vy, vx, O], axis=-1), + ], + axis=-2, + ) + + K = _skew(vec) # type: ignore[no-untyped-call] + theta2 = theta * theta + theta4 = theta2 * theta2 + theta_safe = xp.where(small, 1.0, theta) + theta2_safe = xp.where(small, 1.0, theta2) + A = xp.where(small, 1.0 - theta2 / 6.0 + theta4 / 120.0, xp.sin(theta) / theta_safe)[ + :, None, None + ] + B = xp.where( + small, + 0.5 - theta2 / 24.0 + theta4 / 720.0, + (1.0 - xp.cos(theta)) / theta2_safe, + )[:, None, None] + I = xp.eye(3, dtype=arr.dtype) + I = I[None, :, :] if n == 1 else xp.broadcast_to(I, (n, 3, 3)) + KK = xp.matmul(K, K) + out = I + A * K + B * KK + return out.reshape((*orig_shape, 3, 3)) if orig_shape else out[0] + + mat = arr + if mat.shape[-2:] != (3, 3): + raise ValueError("Rodrigues expects rotation matrices of shape (..., 3, 3)") + orig_shape = mat.shape[:-2] + mat = mat.reshape(-1, 3, 3) + trace = xp.trace(mat, axis1=1, axis2=2) + trace = xp.clip((trace - 1.0) / 2.0, -1.0, 1.0) + theta = xp.arccos(trace) + v = xp.stack( + [ + mat[:, 2, 1] - mat[:, 1, 2], + mat[:, 0, 2] - mat[:, 2, 0], + mat[:, 1, 0] - mat[:, 0, 1], + ], + axis=1, + ) + norm_v = xp.linalg.norm(v, axis=1) + small = theta < 1e-7 + eps = 1e-8 + norm_safe = xp.where(norm_v < eps, 1.0, norm_v) + r_general = theta[:, None] * v / norm_safe[:, None] + r_small = 0.5 * v + r = xp.where(small[:, None], r_small, r_general) + pi_mask = xp.abs(theta - xp.pi) < 1e-4 + if np.any(pi_mask) if xp is np else bool(cp.asnumpy(pi_mask).any()): + diag = xp.diagonal(mat, axis1=1, axis2=2) + axis_candidates = xp.clip((diag + 1.0) / 2.0, 0.0, None) + axis = xp.sqrt(axis_candidates) + signs = xp.sign(v) + axis = xp.where(signs == 0, axis, xp.copysign(axis, signs)) + axis_norm = xp.linalg.norm(axis, axis=1) + axis_norm = xp.where(axis_norm < eps, 1.0, axis_norm) + axis = axis / axis_norm[:, None] + r_pi = theta[:, None] * axis + r = xp.where(pi_mask[:, None], r_pi, r) + out = r.reshape((*orig_shape, 3)) if orig_shape else r[0] + return out + + +def _undistort_points_cuda( + img_px: cp.ndarray, K: cp.ndarray, dist: cp.ndarray, iterations: int = 8 +) -> cp.ndarray: + """Iteratively undistort pixel coordinates on device (Brown–Conrady). + + Returns pixel coordinates after undistortion (fx*xu+cx, fy*yu+cy). + """ + N = img_px.shape[0] + ones = cp.ones((N, 1), dtype=cp.float64) + uv1 = cp.concatenate([img_px.astype(cp.float64), ones], axis=1) + Kinv = cp.linalg.inv(K) + xdyd1 = uv1 @ Kinv.T + xd = xdyd1[:, 0] + yd = xdyd1[:, 1] + xu = xd.copy() + yu = yd.copy() + k1 = dist[0] + k2 = dist[1] if dist.size > 1 else 0.0 + p1 = dist[2] if dist.size > 2 else 0.0 + p2 = dist[3] if dist.size > 3 else 0.0 + k3 = dist[4] if dist.size > 4 else 0.0 + for _ in range(iterations): + r2 = xu * xu + yu * yu + r4 = r2 * r2 + r6 = r4 * r2 + radial = 1.0 + k1 * r2 + k2 * r4 + k3 * r6 + delta_x = 2.0 * p1 * xu * yu + p2 * (r2 + 2.0 * xu * xu) + delta_y = p1 * (r2 + 2.0 * yu * yu) + 2.0 * p2 * xu * yu + xu = (xd - delta_x) / radial + yu = (yd - delta_y) / radial + fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2] + return cp.stack([fx * xu + cx, fy * yu + cy], axis=1) + + +@dataclass +class CudaImage(AbstractImage): + data: any # type: ignore[valid-type] # cupy.ndarray + format: ImageFormat = field(default=ImageFormat.BGR) + frame_id: str = field(default="") + ts: float = field(default_factory=time.time) + + def __post_init__(self): # type: ignore[no-untyped-def] + if not HAS_CUDA or cp is None: + raise RuntimeError("CuPy/CUDA not available") + if not _is_cu(self.data): + # Accept NumPy arrays and move to device automatically + try: + self.data = cp.asarray(self.data) + except Exception as e: + raise ValueError("CudaImage requires a CuPy array") from e + if self.data.ndim < 2: # type: ignore[attr-defined] + raise ValueError("Image data must be at least 2D") + self.data = _ascontig(self.data) # type: ignore[no-untyped-call] + + @property + def is_cuda(self) -> bool: + return True + + def to_opencv(self) -> np.ndarray: # type: ignore[type-arg] + if self.format in (ImageFormat.BGR, ImageFormat.RGB, ImageFormat.RGBA, ImageFormat.BGRA): + return _to_cpu(self.to_bgr().data) # type: ignore[no-any-return, no-untyped-call] + return _to_cpu(self.data) # type: ignore[no-any-return, no-untyped-call] + + def to_rgb(self) -> CudaImage: + if self.format == ImageFormat.RGB: + return self.copy() # type: ignore + if self.format == ImageFormat.BGR: + return CudaImage(_bgr_to_rgb_cuda(self.data), ImageFormat.RGB, self.frame_id, self.ts) # type: ignore[no-untyped-call] + if self.format == ImageFormat.RGBA: + return self.copy() # type: ignore + if self.format == ImageFormat.BGRA: + return CudaImage( + _bgra_to_rgba_cuda(self.data), # type: ignore[no-untyped-call] + ImageFormat.RGBA, + self.frame_id, + self.ts, + ) + if self.format == ImageFormat.GRAY: + return CudaImage(_gray_to_rgb_cuda(self.data), ImageFormat.RGB, self.frame_id, self.ts) # type: ignore[no-untyped-call] + if self.format in (ImageFormat.GRAY16, ImageFormat.DEPTH16): + gray8 = (self.data.astype(cp.float32) / 256.0).clip(0, 255).astype(cp.uint8) # type: ignore + return CudaImage(_gray_to_rgb_cuda(gray8), ImageFormat.RGB, self.frame_id, self.ts) # type: ignore[no-untyped-call] + return self.copy() # type: ignore + + def to_bgr(self) -> CudaImage: + if self.format == ImageFormat.BGR: + return self.copy() # type: ignore + if self.format == ImageFormat.RGB: + return CudaImage(_rgb_to_bgr_cuda(self.data), ImageFormat.BGR, self.frame_id, self.ts) # type: ignore[no-untyped-call] + if self.format == ImageFormat.RGBA: + return CudaImage( + _rgba_to_bgra_cuda(self.data)[..., :3], # type: ignore[no-untyped-call] + ImageFormat.BGR, + self.frame_id, + self.ts, + ) + if self.format == ImageFormat.BGRA: + return CudaImage(self.data[..., :3], ImageFormat.BGR, self.frame_id, self.ts) # type: ignore[index] + if self.format in (ImageFormat.GRAY, ImageFormat.DEPTH): + return CudaImage( + _rgb_to_bgr_cuda(_gray_to_rgb_cuda(self.data)), # type: ignore[no-untyped-call] + ImageFormat.BGR, + self.frame_id, + self.ts, + ) + if self.format in (ImageFormat.GRAY16, ImageFormat.DEPTH16): + gray8 = (self.data.astype(cp.float32) / 256.0).clip(0, 255).astype(cp.uint8) # type: ignore + return CudaImage( + _rgb_to_bgr_cuda(_gray_to_rgb_cuda(gray8)), # type: ignore[no-untyped-call] + ImageFormat.BGR, + self.frame_id, + self.ts, + ) + return self.copy() # type: ignore + + def to_grayscale(self) -> CudaImage: + if self.format in (ImageFormat.GRAY, ImageFormat.GRAY16, ImageFormat.DEPTH): + return self.copy() # type: ignore + if self.format == ImageFormat.BGR: + return CudaImage( + _rgb_to_gray_cuda(_bgr_to_rgb_cuda(self.data)), # type: ignore[no-untyped-call] + ImageFormat.GRAY, + self.frame_id, + self.ts, + ) + if self.format == ImageFormat.RGB: + return CudaImage(_rgb_to_gray_cuda(self.data), ImageFormat.GRAY, self.frame_id, self.ts) # type: ignore[no-untyped-call] + if self.format in (ImageFormat.RGBA, ImageFormat.BGRA): + rgb = ( + self.data[..., :3] # type: ignore[index] + if self.format == ImageFormat.RGBA + else _bgra_to_rgba_cuda(self.data)[..., :3] # type: ignore[no-untyped-call] + ) + return CudaImage(_rgb_to_gray_cuda(rgb), ImageFormat.GRAY, self.frame_id, self.ts) # type: ignore[no-untyped-call] + raise ValueError(f"Unsupported format: {self.format}") + + def resize(self, width: int, height: int, interpolation: int = cv2.INTER_LINEAR) -> CudaImage: + return CudaImage( + _resize_bilinear_hwc_cuda(self.data, height, width), self.format, self.frame_id, self.ts + ) + + def crop(self, x: int, y: int, width: int, height: int) -> CudaImage: + """Crop the image to the specified region. + + Args: + x: Starting x coordinate (left edge) + y: Starting y coordinate (top edge) + width: Width of the cropped region + height: Height of the cropped region + + Returns: + A new CudaImage containing the cropped region + """ + # Get current image dimensions + img_height, img_width = self.data.shape[:2] # type: ignore[attr-defined] + + # Clamp the crop region to image bounds + x = max(0, min(x, img_width)) + y = max(0, min(y, img_height)) + x_end = min(x + width, img_width) + y_end = min(y + height, img_height) + + # Perform the crop using array slicing + if self.data.ndim == 2: # type: ignore[attr-defined] + # Grayscale image + cropped_data = self.data[y:y_end, x:x_end] # type: ignore[index] + else: + # Color image (HxWxC) + cropped_data = self.data[y:y_end, x:x_end, :] # type: ignore[index] + + # Return a new CudaImage with the cropped data + return CudaImage(cropped_data, self.format, self.frame_id, self.ts) + + def sharpness(self) -> float: + if cp is None: + return 0.0 + try: + from cupyx.scipy import ndimage as cndimage + + gray = self.to_grayscale().data.astype(cp.float32) # type: ignore[attr-defined] + deriv5 = cp.asarray([1, 2, 0, -2, -1], dtype=cp.float32) + smooth5 = cp.asarray([1, 4, 6, 4, 1], dtype=cp.float32) + gx = cndimage.convolve1d(gray, deriv5, axis=1, mode="reflect") + gx = cndimage.convolve1d(gx, smooth5, axis=0, mode="reflect") + gy = cndimage.convolve1d(gray, deriv5, axis=0, mode="reflect") + gy = cndimage.convolve1d(gy, smooth5, axis=1, mode="reflect") + magnitude = cp.hypot(gx, gy) + mean_mag = float(cp.asnumpy(magnitude.mean())) + except Exception: + return 0.0 + if mean_mag <= 0: + return 0.0 + return float(np.clip((np.log10(mean_mag + 1) - 1.7) / 2.0, 0.0, 1.0)) + + # CUDA tracker (template NCC with small scale pyramid) + @dataclass + class BBox: + x: int + y: int + w: int + h: int + + def create_csrt_tracker(self, bbox: BBox): # type: ignore[no-untyped-def] + if csignal is None: + raise RuntimeError("cupyx.scipy.signal not available for CUDA tracker") + x, y, w, h = map(int, bbox) # type: ignore[call-overload] + gray = self.to_grayscale().data.astype(cp.float32) # type: ignore[attr-defined] + tmpl = gray[y : y + h, x : x + w] + if tmpl.size == 0: + raise ValueError("Invalid bbox for CUDA tracker") + return _CudaTemplateTracker(tmpl, x0=x, y0=y) + + def csrt_update(self, tracker) -> tuple[bool, tuple[int, int, int, int]]: # type: ignore[no-untyped-def] + if not isinstance(tracker, _CudaTemplateTracker): + raise TypeError("Expected CUDA tracker instance") + gray = self.to_grayscale().data.astype(cp.float32) # type: ignore[attr-defined] + x, y, w, h = tracker.update(gray) + return True, (int(x), int(y), int(w), int(h)) + + # PnP – Gauss–Newton (no distortion in batch), iterative per-instance + def solve_pnp( + self, + object_points: np.ndarray, # type: ignore[type-arg] + image_points: np.ndarray, # type: ignore[type-arg] + camera_matrix: np.ndarray, # type: ignore[type-arg] + dist_coeffs: np.ndarray | None = None, # type: ignore[type-arg] + flags: int = cv2.SOLVEPNP_ITERATIVE, + ) -> tuple[bool, np.ndarray, np.ndarray]: # type: ignore[type-arg] + if not HAS_CUDA or cp is None or (dist_coeffs is not None and np.any(dist_coeffs)): + obj = np.asarray(object_points, dtype=np.float32).reshape(-1, 3) + img = np.asarray(image_points, dtype=np.float32).reshape(-1, 2) + K = np.asarray(camera_matrix, dtype=np.float64) + dist = None if dist_coeffs is None else np.asarray(dist_coeffs, dtype=np.float64) + ok, rvec, tvec = cv2.solvePnP(obj, img, K, dist, flags=flags) # type: ignore[arg-type] + return bool(ok), rvec.astype(np.float64), tvec.astype(np.float64) + + rvec, tvec = _solve_pnp_cuda_kernel(object_points, image_points, camera_matrix) + ok = np.isfinite(rvec).all() and np.isfinite(tvec).all() + return ok, rvec, tvec + + def solve_pnp_batch( + self, + object_points_batch: np.ndarray, # type: ignore[type-arg] + image_points_batch: np.ndarray, # type: ignore[type-arg] + camera_matrix: np.ndarray, # type: ignore[type-arg] + dist_coeffs: np.ndarray | None = None, # type: ignore[type-arg] + iterations: int = 15, + damping: float = 1e-6, + ) -> tuple[np.ndarray, np.ndarray]: # type: ignore[type-arg] + """Batched PnP (each block = one instance).""" + if not HAS_CUDA or cp is None or (dist_coeffs is not None and np.any(dist_coeffs)): + obj = np.asarray(object_points_batch, dtype=np.float32) + img = np.asarray(image_points_batch, dtype=np.float32) + if obj.ndim != 3 or img.ndim != 3 or obj.shape[:2] != img.shape[:2]: + raise ValueError( + "Batched object/image arrays must be shaped (B,N,...) with matching sizes" + ) + K = np.asarray(camera_matrix, dtype=np.float64) + dist = None if dist_coeffs is None else np.asarray(dist_coeffs, dtype=np.float64) + B = obj.shape[0] + r_list = np.empty((B, 3, 1), dtype=np.float64) + t_list = np.empty((B, 3, 1), dtype=np.float64) + for b in range(B): + K_b = K if K.ndim == 2 else K[b] + dist_b = None + if dist is not None: + if dist.ndim == 1: + dist_b = dist + elif dist.ndim == 2: + dist_b = dist[b] + else: + raise ValueError("dist_coeffs must be 1D or batched 2D") + ok, rvec, tvec = cv2.solvePnP( + obj[b], + img[b], + K_b, + dist_b, # type: ignore[arg-type] + flags=cv2.SOLVEPNP_ITERATIVE, + ) + if not ok: + raise RuntimeError(f"cv2.solvePnP failed for batch index {b}") + r_list[b] = rvec.astype(np.float64) + t_list[b] = tvec.astype(np.float64) + return r_list, t_list + + return _solve_pnp_cuda_kernel( # type: ignore[no-any-return] + object_points_batch, + image_points_batch, + camera_matrix, + iterations=iterations, + damping=damping, + ) + + def solve_pnp_ransac( + self, + object_points: np.ndarray, # type: ignore[type-arg] + image_points: np.ndarray, # type: ignore[type-arg] + camera_matrix: np.ndarray, # type: ignore[type-arg] + dist_coeffs: np.ndarray | None = None, # type: ignore[type-arg] + iterations_count: int = 100, + reprojection_error: float = 3.0, + confidence: float = 0.99, + min_sample: int = 6, + ) -> tuple[bool, np.ndarray, np.ndarray, np.ndarray]: # type: ignore[type-arg] + """RANSAC with CUDA PnP solver.""" + if not HAS_CUDA or cp is None or (dist_coeffs is not None and np.any(dist_coeffs)): + obj = np.asarray(object_points, dtype=np.float32) + img = np.asarray(image_points, dtype=np.float32) + K = np.asarray(camera_matrix, dtype=np.float64) + dist = None if dist_coeffs is None else np.asarray(dist_coeffs, dtype=np.float64) + ok, rvec, tvec, mask = cv2.solvePnPRansac( + obj, + img, + K, + dist, # type: ignore[arg-type] + iterationsCount=int(iterations_count), + reprojectionError=float(reprojection_error), + confidence=float(confidence), + flags=cv2.SOLVEPNP_ITERATIVE, + ) + mask_flat = np.zeros((obj.shape[0],), dtype=np.uint8) + if mask is not None and len(mask) > 0: + mask_flat[mask.flatten()] = 1 + return bool(ok), rvec.astype(np.float64), tvec.astype(np.float64), mask_flat + + obj = cp.asarray(object_points, dtype=cp.float32) + img = cp.asarray(image_points, dtype=cp.float32) + camera_matrix_np = np.asarray(_to_cpu(camera_matrix), dtype=np.float32) # type: ignore[no-untyped-call] + fx = float(camera_matrix_np[0, 0]) + fy = float(camera_matrix_np[1, 1]) + cx = float(camera_matrix_np[0, 2]) + cy = float(camera_matrix_np[1, 2]) + N = obj.shape[0] + rng = cp.random.RandomState(1234) + best_inliers = -1 + _best_r, _best_t, best_mask = None, None, None + + for _ in range(iterations_count): + idx = rng.choice(N, size=min_sample, replace=False) + rvec, tvec = _solve_pnp_cuda_kernel(obj[idx], img[idx], camera_matrix_np) + R = _rodrigues(cp.asarray(rvec.flatten())) + Xc = obj @ R.T + cp.asarray(tvec.flatten()) + invZ = 1.0 / cp.clip(Xc[:, 2], 1e-6, None) + u_hat = fx * Xc[:, 0] * invZ + cx + v_hat = fy * Xc[:, 1] * invZ + cy + err = cp.sqrt((img[:, 0] - u_hat) ** 2 + (img[:, 1] - v_hat) ** 2) + mask = (err < reprojection_error).astype(cp.uint8) + inliers = int(mask.sum()) + if inliers > best_inliers: + best_inliers, _best_r, _best_t, best_mask = inliers, rvec, tvec, mask + if inliers >= int(confidence * N): + break + + if best_inliers <= 0: + return False, np.zeros((3, 1)), np.zeros((3, 1)), np.zeros((N,), dtype=np.uint8) + in_idx = cp.nonzero(best_mask)[0] + rvec, tvec = _solve_pnp_cuda_kernel(obj[in_idx], img[in_idx], camera_matrix_np) + return True, rvec, tvec, cp.asnumpy(best_mask) + + +class _CudaTemplateTracker: + def __init__( + self, + tmpl: cp.ndarray, + scale_step: float = 1.05, + lr: float = 0.1, + search_radius: int = 16, + x0: int = 0, + y0: int = 0, + ) -> None: + self.tmpl = tmpl.astype(cp.float32) + self.h, self.w = int(tmpl.shape[0]), int(tmpl.shape[1]) + self.scale_step = float(scale_step) + self.lr = float(lr) + self.search_radius = int(search_radius) + # Cosine window + wy = cp.hanning(self.h).astype(cp.float32) + wx = cp.hanning(self.w).astype(cp.float32) + self.window = wy[:, None] * wx[None, :] + self.tmpl = self.tmpl * self.window + self.y = int(y0) + self.x = int(x0) + + def update(self, gray: cp.ndarray): # type: ignore[no-untyped-def] + H, W = int(gray.shape[0]), int(gray.shape[1]) + r = self.search_radius + x0 = max(0, self.x - r) + y0 = max(0, self.y - r) + x1 = min(W, self.x + self.w + r) + y1 = min(H, self.y + self.h + r) + search = gray[y0:y1, x0:x1] + if search.shape[0] < self.h or search.shape[1] < self.w: + search = gray + x0 = y0 = 0 + best = (self.x, self.y, self.w, self.h) + best_score = -1e9 + for s in (1.0 / self.scale_step, 1.0, self.scale_step): + th = max(1, round(self.h * s)) + tw = max(1, round(self.w * s)) + tmpl_s = _resize_bilinear_hwc_cuda(self.tmpl, th, tw) + if tmpl_s.ndim == 3: + tmpl_s = tmpl_s[..., 0] + tmpl_s = tmpl_s.astype(cp.float32) + tmpl_zm = tmpl_s - tmpl_s.mean() + tmpl_energy = cp.sqrt(cp.sum(tmpl_zm * tmpl_zm)) + 1e-6 + # NCC via correlate2d and local std + ones = cp.ones((th, tw), dtype=cp.float32) + num = csignal.correlate2d(search, tmpl_zm, mode="valid") + sumS = csignal.correlate2d(search, ones, mode="valid") + sumS2 = csignal.correlate2d(search * search, ones, mode="valid") + n = float(th * tw) + meanS = sumS / n + varS = cp.clip(sumS2 - n * meanS * meanS, 0.0, None) + stdS = cp.sqrt(varS) + 1e-6 + res = num / (stdS * tmpl_energy) + ij = cp.unravel_index(cp.argmax(res), res.shape) + dy, dx = int(ij[0].get()), int(ij[1].get()) + score = float(res[ij].get()) + if score > best_score: + best_score = score + best = (x0 + dx, y0 + dy, tw, th) + x, y, w, h = best + patch = gray[y : y + h, x : x + w] + if patch.shape[0] != self.h or patch.shape[1] != self.w: + patch = _resize_bilinear_hwc_cuda(patch, self.h, self.w) + if patch.ndim == 3: + patch = patch[..., 0] + patch = patch.astype(cp.float32) * self.window + self.tmpl = (1.0 - self.lr) * self.tmpl + self.lr * patch + self.x, self.y, self.w, self.h = x, y, w, h + return x, y, w, h diff --git a/dimos/msgs/sensor_msgs/image_impls/NumpyImage.py b/dimos/msgs/sensor_msgs/image_impls/NumpyImage.py new file mode 100644 index 0000000000..8aa5478a20 --- /dev/null +++ b/dimos/msgs/sensor_msgs/image_impls/NumpyImage.py @@ -0,0 +1,243 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass, field +import time + +import cv2 +import numpy as np + +from dimos.msgs.sensor_msgs.image_impls.AbstractImage import ( + AbstractImage, + ImageFormat, +) + + +@dataclass +class NumpyImage(AbstractImage): + data: np.ndarray # type: ignore[type-arg] + format: ImageFormat = field(default=ImageFormat.BGR) + frame_id: str = field(default="") + ts: float = field(default_factory=time.time) + + def __post_init__(self): # type: ignore[no-untyped-def] + if not isinstance(self.data, np.ndarray) or self.data.ndim < 2: + raise ValueError("NumpyImage requires a 2D/3D NumPy array") + + @property + def is_cuda(self) -> bool: + return False + + def to_opencv(self) -> np.ndarray: # type: ignore[type-arg] + arr = self.data + if self.format == ImageFormat.BGR: + return arr + if self.format == ImageFormat.RGB: + return cv2.cvtColor(arr, cv2.COLOR_RGB2BGR) + if self.format == ImageFormat.RGBA: + return cv2.cvtColor(arr, cv2.COLOR_RGBA2BGR) + if self.format == ImageFormat.BGRA: + return cv2.cvtColor(arr, cv2.COLOR_BGRA2BGR) + if self.format in ( + ImageFormat.GRAY, + ImageFormat.GRAY16, + ImageFormat.DEPTH, + ImageFormat.DEPTH16, + ): + return arr + raise ValueError(f"Unsupported format: {self.format}") + + def to_rgb(self) -> NumpyImage: + if self.format == ImageFormat.RGB: + return self.copy() # type: ignore + arr = self.data + if self.format == ImageFormat.BGR: + return NumpyImage( + cv2.cvtColor(arr, cv2.COLOR_BGR2RGB), ImageFormat.RGB, self.frame_id, self.ts + ) + if self.format == ImageFormat.RGBA: + return self.copy() # type: ignore[return-value] # RGBA contains RGB + alpha + if self.format == ImageFormat.BGRA: + rgba = cv2.cvtColor(arr, cv2.COLOR_BGRA2RGBA) + return NumpyImage(rgba, ImageFormat.RGBA, self.frame_id, self.ts) + if self.format in (ImageFormat.GRAY, ImageFormat.GRAY16, ImageFormat.DEPTH16): + gray8 = (arr / 256).astype(np.uint8) if self.format != ImageFormat.GRAY else arr + rgb = cv2.cvtColor(gray8, cv2.COLOR_GRAY2RGB) + return NumpyImage(rgb, ImageFormat.RGB, self.frame_id, self.ts) + return self.copy() # type: ignore + + def to_bgr(self) -> NumpyImage: + if self.format == ImageFormat.BGR: + return self.copy() # type: ignore + arr = self.data + if self.format == ImageFormat.RGB: + return NumpyImage( + cv2.cvtColor(arr, cv2.COLOR_RGB2BGR), ImageFormat.BGR, self.frame_id, self.ts + ) + if self.format == ImageFormat.RGBA: + return NumpyImage( + cv2.cvtColor(arr, cv2.COLOR_RGBA2BGR), ImageFormat.BGR, self.frame_id, self.ts + ) + if self.format == ImageFormat.BGRA: + return NumpyImage( + cv2.cvtColor(arr, cv2.COLOR_BGRA2BGR), ImageFormat.BGR, self.frame_id, self.ts + ) + if self.format in (ImageFormat.GRAY, ImageFormat.GRAY16, ImageFormat.DEPTH16): + gray8 = (arr / 256).astype(np.uint8) if self.format != ImageFormat.GRAY else arr + return NumpyImage( + cv2.cvtColor(gray8, cv2.COLOR_GRAY2BGR), ImageFormat.BGR, self.frame_id, self.ts + ) + return self.copy() # type: ignore + + def to_grayscale(self) -> NumpyImage: + if self.format in (ImageFormat.GRAY, ImageFormat.GRAY16, ImageFormat.DEPTH): + return self.copy() # type: ignore + if self.format == ImageFormat.BGR: + return NumpyImage( + cv2.cvtColor(self.data, cv2.COLOR_BGR2GRAY), + ImageFormat.GRAY, + self.frame_id, + self.ts, + ) + if self.format == ImageFormat.RGB: + return NumpyImage( + cv2.cvtColor(self.data, cv2.COLOR_RGB2GRAY), + ImageFormat.GRAY, + self.frame_id, + self.ts, + ) + if self.format in (ImageFormat.RGBA, ImageFormat.BGRA): + code = cv2.COLOR_RGBA2GRAY if self.format == ImageFormat.RGBA else cv2.COLOR_BGRA2GRAY + return NumpyImage( + cv2.cvtColor(self.data, code), ImageFormat.GRAY, self.frame_id, self.ts + ) + raise ValueError(f"Unsupported format: {self.format}") + + def resize(self, width: int, height: int, interpolation: int = cv2.INTER_LINEAR) -> NumpyImage: + return NumpyImage( + cv2.resize(self.data, (width, height), interpolation=interpolation), + self.format, + self.frame_id, + self.ts, + ) + + def crop(self, x: int, y: int, width: int, height: int) -> NumpyImage: + """Crop the image to the specified region. + + Args: + x: Starting x coordinate (left edge) + y: Starting y coordinate (top edge) + width: Width of the cropped region + height: Height of the cropped region + + Returns: + A new NumpyImage containing the cropped region + """ + # Get current image dimensions + img_height, img_width = self.data.shape[:2] + + # Clamp the crop region to image bounds + x = max(0, min(x, img_width)) + y = max(0, min(y, img_height)) + x_end = min(x + width, img_width) + y_end = min(y + height, img_height) + + # Perform the crop using array slicing + if self.data.ndim == 2: + # Grayscale image + cropped_data = self.data[y:y_end, x:x_end] + else: + # Color image (HxWxC) + cropped_data = self.data[y:y_end, x:x_end, :] + + # Return a new NumpyImage with the cropped data + return NumpyImage(cropped_data, self.format, self.frame_id, self.ts) + + def sharpness(self) -> float: + gray = self.to_grayscale() + sx = cv2.Sobel(gray.data, cv2.CV_32F, 1, 0, ksize=5) + sy = cv2.Sobel(gray.data, cv2.CV_32F, 0, 1, ksize=5) + magnitude = cv2.magnitude(sx, sy) + mean_mag = float(magnitude.mean()) + if mean_mag <= 0: + return 0.0 + return float(np.clip((np.log10(mean_mag + 1) - 1.7) / 2.0, 0.0, 1.0)) + + # PnP wrappers + def solve_pnp( + self, + object_points: np.ndarray, # type: ignore[type-arg] + image_points: np.ndarray, # type: ignore[type-arg] + camera_matrix: np.ndarray, # type: ignore[type-arg] + dist_coeffs: np.ndarray | None = None, # type: ignore[type-arg] + flags: int = cv2.SOLVEPNP_ITERATIVE, + ) -> tuple[bool, np.ndarray, np.ndarray]: # type: ignore[type-arg] + obj = np.asarray(object_points, dtype=np.float32).reshape(-1, 3) + img = np.asarray(image_points, dtype=np.float32).reshape(-1, 2) + K = np.asarray(camera_matrix, dtype=np.float64) + dist = None if dist_coeffs is None else np.asarray(dist_coeffs, dtype=np.float64) + ok, rvec, tvec = cv2.solvePnP(obj, img, K, dist, flags=flags) # type: ignore[arg-type] + return bool(ok), rvec.astype(np.float64), tvec.astype(np.float64) + + def create_csrt_tracker(self, bbox: tuple[int, int, int, int]): # type: ignore[no-untyped-def] + tracker = None + if hasattr(cv2, "legacy") and hasattr(cv2.legacy, "TrackerCSRT_create"): + tracker = cv2.legacy.TrackerCSRT_create() + elif hasattr(cv2, "TrackerCSRT_create"): + tracker = cv2.TrackerCSRT_create() + else: + raise RuntimeError("OpenCV CSRT tracker not available") + ok = tracker.init(self.to_bgr().to_opencv(), tuple(map(int, bbox))) + if not ok: + raise RuntimeError("Failed to initialize CSRT tracker") + return tracker + + def csrt_update(self, tracker) -> tuple[bool, tuple[int, int, int, int]]: # type: ignore[no-untyped-def] + ok, box = tracker.update(self.to_bgr().to_opencv()) + if not ok: + return False, (0, 0, 0, 0) + x, y, w, h = map(int, box) + return True, (x, y, w, h) + + def solve_pnp_ransac( + self, + object_points: np.ndarray, # type: ignore[type-arg] + image_points: np.ndarray, # type: ignore[type-arg] + camera_matrix: np.ndarray, # type: ignore[type-arg] + dist_coeffs: np.ndarray | None = None, # type: ignore[type-arg] + iterations_count: int = 100, + reprojection_error: float = 3.0, + confidence: float = 0.99, + min_sample: int = 6, + ) -> tuple[bool, np.ndarray, np.ndarray, np.ndarray]: # type: ignore[type-arg] + obj = np.asarray(object_points, dtype=np.float32).reshape(-1, 3) + img = np.asarray(image_points, dtype=np.float32).reshape(-1, 2) + K = np.asarray(camera_matrix, dtype=np.float64) + dist = None if dist_coeffs is None else np.asarray(dist_coeffs, dtype=np.float64) + ok, rvec, tvec, inliers = cv2.solvePnPRansac( + obj, + img, + K, + dist, # type: ignore[arg-type] + iterationsCount=int(iterations_count), + reprojectionError=float(reprojection_error), + confidence=float(confidence), + flags=cv2.SOLVEPNP_ITERATIVE, + ) + mask = np.zeros((obj.shape[0],), dtype=np.uint8) + if inliers is not None and len(inliers) > 0: + mask[inliers.flatten()] = 1 + return bool(ok), rvec.astype(np.float64), tvec.astype(np.float64), mask diff --git a/dimos/msgs/sensor_msgs/image_impls/test_image_backend_utils.py b/dimos/msgs/sensor_msgs/image_impls/test_image_backend_utils.py new file mode 100644 index 0000000000..d53d2c4524 --- /dev/null +++ b/dimos/msgs/sensor_msgs/image_impls/test_image_backend_utils.py @@ -0,0 +1,287 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest + +from dimos.msgs.sensor_msgs import Image, ImageFormat + +try: + HAS_CUDA = True + print("Running image backend utils tests with CUDA/CuPy support (GPU mode)") +except: + HAS_CUDA = False + print("Running image backend utils tests in CPU-only mode") + +from dimos.perception.common.utils import ( + colorize_depth, + draw_bounding_box, + draw_object_detection_visualization, + draw_segmentation_mask, + project_2d_points_to_3d, + project_3d_points_to_2d, + rectify_image, +) + + +def _has_cupy() -> bool: + try: + import cupy as cp # type: ignore + + try: + ndev = cp.cuda.runtime.getDeviceCount() # type: ignore[attr-defined] + if ndev <= 0: + return False + x = cp.array([1, 2, 3]) + _ = int(x.sum().get()) + return True + except Exception: + return False + except Exception: + return False + + +@pytest.mark.parametrize( + "shape,fmt", [((64, 64, 3), ImageFormat.BGR), ((64, 64), ImageFormat.GRAY)] +) +def test_rectify_image_cpu(shape, fmt) -> None: + arr = (np.random.rand(*shape) * (255 if fmt != ImageFormat.GRAY else 65535)).astype( + np.uint8 if fmt != ImageFormat.GRAY else np.uint16 + ) + img = Image(data=arr, format=fmt, frame_id="cam", ts=123.456) + K = np.array( + [[100.0, 0, arr.shape[1] / 2], [0, 100.0, arr.shape[0] / 2], [0, 0, 1]], dtype=np.float64 + ) + D = np.zeros(5, dtype=np.float64) + out = rectify_image(img, K, D) + assert out.shape[:2] == arr.shape[:2] + assert out.format == fmt + assert out.frame_id == "cam" + assert abs(out.ts - 123.456) < 1e-9 + # With zero distortion, pixels should match + np.testing.assert_array_equal(out.data, arr) + + +@pytest.mark.skipif(not _has_cupy(), reason="CuPy/CUDA not available") +@pytest.mark.parametrize( + "shape,fmt", [((32, 32, 3), ImageFormat.BGR), ((32, 32), ImageFormat.GRAY)] +) +def test_rectify_image_gpu_parity(shape, fmt) -> None: + import cupy as cp # type: ignore + + arr_np = (np.random.rand(*shape) * (255 if fmt != ImageFormat.GRAY else 65535)).astype( + np.uint8 if fmt != ImageFormat.GRAY else np.uint16 + ) + arr_cu = cp.asarray(arr_np) + img = Image(data=arr_cu, format=fmt, frame_id="cam", ts=1.23) + K = np.array( + [[80.0, 0, arr_np.shape[1] / 2], [0, 80.0, arr_np.shape[0] / 2], [0, 0, 1.0]], + dtype=np.float64, + ) + D = np.zeros(5, dtype=np.float64) + out = rectify_image(img, K, D) + # Zero distortion parity and backend preservation + assert out.format == fmt + assert out.frame_id == "cam" + assert abs(out.ts - 1.23) < 1e-9 + assert out.data.__class__.__module__.startswith("cupy") + np.testing.assert_array_equal(cp.asnumpy(out.data), arr_np) + + +@pytest.mark.skipif(not _has_cupy(), reason="CuPy/CUDA not available") +def test_rectify_image_gpu_nonzero_dist_close() -> None: + import cupy as cp # type: ignore + + H, W = 64, 96 + # Structured pattern to make interpolation deterministic enough + x = np.linspace(0, 255, W, dtype=np.float32) + y = np.linspace(0, 255, H, dtype=np.float32) + xv, yv = np.meshgrid(x, y) + arr_np = np.stack( + [ + xv.astype(np.uint8), + yv.astype(np.uint8), + ((xv + yv) / 2).astype(np.uint8), + ], + axis=2, + ) + img_cpu = Image(data=arr_np, format=ImageFormat.BGR, frame_id="cam", ts=0.5) + img_gpu = Image(data=cp.asarray(arr_np), format=ImageFormat.BGR, frame_id="cam", ts=0.5) + + fx, fy = 120.0, 125.0 + cx, cy = W / 2.0, H / 2.0 + K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1.0]], dtype=np.float64) + D = np.array([0.05, -0.02, 0.001, -0.001, 0.0], dtype=np.float64) + + out_cpu = rectify_image(img_cpu, K, D) + out_gpu = rectify_image(img_gpu, K, D) + # Compare within a small tolerance + # Small numeric differences may remain due to model and casting; keep tight tolerance + np.testing.assert_allclose( + cp.asnumpy(out_gpu.data).astype(np.int16), out_cpu.data.astype(np.int16), atol=4 + ) + + +def test_project_roundtrip_cpu() -> None: + pts3d = np.array([[0.1, 0.2, 1.0], [0.0, 0.0, 2.0], [0.5, -0.3, 3.0]], dtype=np.float32) + fx, fy, cx, cy = 200.0, 220.0, 64.0, 48.0 + K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1.0]], dtype=np.float64) + uv = project_3d_points_to_2d(pts3d, K) + assert uv.shape == (3, 2) + Z = pts3d[:, 2] + pts3d_back = project_2d_points_to_3d(uv.astype(np.float32), Z.astype(np.float32), K) + # Allow small rounding differences due to int rounding in 2D + assert pts3d_back.shape == (3, 3) + assert np.all(pts3d_back[:, 2] > 0) + + +@pytest.mark.skipif(not _has_cupy(), reason="CuPy/CUDA not available") +def test_project_parity_gpu_cpu() -> None: + import cupy as cp # type: ignore + + pts3d_np = np.array([[0.1, 0.2, 1.0], [0.0, 0.0, 2.0], [0.5, -0.3, 3.0]], dtype=np.float32) + fx, fy, cx, cy = 200.0, 220.0, 64.0, 48.0 + K_np = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1.0]], dtype=np.float64) + uv_cpu = project_3d_points_to_2d(pts3d_np, K_np) + uv_gpu = project_3d_points_to_2d(cp.asarray(pts3d_np), cp.asarray(K_np)) + np.testing.assert_array_equal(cp.asnumpy(uv_gpu), uv_cpu) + + Z_np = pts3d_np[:, 2] + pts3d_cpu = project_2d_points_to_3d(uv_cpu.astype(np.float32), Z_np.astype(np.float32), K_np) + pts3d_gpu = project_2d_points_to_3d( + cp.asarray(uv_cpu.astype(np.float32)), cp.asarray(Z_np.astype(np.float32)), cp.asarray(K_np) + ) + assert pts3d_cpu.shape == cp.asnumpy(pts3d_gpu).shape + + +@pytest.mark.skipif(not _has_cupy(), reason="CuPy/CUDA not available") +def test_project_parity_gpu_cpu_random() -> None: + import cupy as cp # type: ignore + + rng = np.random.RandomState(0) + N = 1000 + Z = rng.uniform(0.1, 5.0, size=(N, 1)).astype(np.float32) + XY = rng.uniform(-1.0, 1.0, size=(N, 2)).astype(np.float32) + pts3d_np = np.concatenate([XY, Z], axis=1) + + fx, fy = 300.0, 320.0 + cx, cy = 128.0, 96.0 + K_np = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1.0]], dtype=np.float64) + + uv_cpu = project_3d_points_to_2d(pts3d_np, K_np) + uv_gpu = project_3d_points_to_2d(cp.asarray(pts3d_np), cp.asarray(K_np)) + np.testing.assert_array_equal(cp.asnumpy(uv_gpu), uv_cpu) + + # Roundtrip + Z_flat = pts3d_np[:, 2] + pts3d_cpu = project_2d_points_to_3d(uv_cpu.astype(np.float32), Z_flat.astype(np.float32), K_np) + pts3d_gpu = project_2d_points_to_3d( + cp.asarray(uv_cpu.astype(np.float32)), + cp.asarray(Z_flat.astype(np.float32)), + cp.asarray(K_np), + ) + assert pts3d_cpu.shape == cp.asnumpy(pts3d_gpu).shape + + +def test_colorize_depth_cpu() -> None: + depth = np.zeros((32, 48), dtype=np.float32) + depth[8:16, 12:24] = 1.5 + out = colorize_depth(depth, max_depth=3.0, overlay_stats=False) + assert isinstance(out, np.ndarray) + assert out.shape == (32, 48, 3) + assert out.dtype == np.uint8 + + +@pytest.mark.skipif(not _has_cupy(), reason="CuPy/CUDA not available") +def test_colorize_depth_gpu_parity() -> None: + import cupy as cp # type: ignore + + depth_np = np.zeros((16, 20), dtype=np.float32) + depth_np[4:8, 5:15] = 2.0 + out_cpu = colorize_depth(depth_np, max_depth=4.0, overlay_stats=False) + out_gpu = colorize_depth(cp.asarray(depth_np), max_depth=4.0, overlay_stats=False) + np.testing.assert_array_equal(cp.asnumpy(out_gpu), out_cpu) + + +def test_draw_bounding_box_cpu() -> None: + img = np.zeros((20, 30, 3), dtype=np.uint8) + out = draw_bounding_box(img, [2, 3, 10, 12], color=(255, 0, 0), thickness=1) + assert isinstance(out, np.ndarray) + assert out.shape == img.shape + assert out.dtype == img.dtype + + +@pytest.mark.skipif(not _has_cupy(), reason="CuPy/CUDA not available") +def test_draw_bounding_box_gpu_parity() -> None: + import cupy as cp # type: ignore + + img_np = np.zeros((20, 30, 3), dtype=np.uint8) + out_cpu = draw_bounding_box(img_np.copy(), [2, 3, 10, 12], color=(0, 255, 0), thickness=2) + img_cu = cp.asarray(img_np) + out_gpu = draw_bounding_box(img_cu, [2, 3, 10, 12], color=(0, 255, 0), thickness=2) + np.testing.assert_array_equal(cp.asnumpy(out_gpu), out_cpu) + + +def test_draw_segmentation_mask_cpu() -> None: + img = np.zeros((20, 30, 3), dtype=np.uint8) + mask = np.zeros((20, 30), dtype=np.uint8) + mask[5:10, 8:15] = 1 + out = draw_segmentation_mask(img, mask, color=(0, 200, 200), alpha=0.5) + assert out.shape == img.shape + + +@pytest.mark.skipif(not _has_cupy(), reason="CuPy/CUDA not available") +def test_draw_segmentation_mask_gpu_parity() -> None: + import cupy as cp # type: ignore + + img_np = np.zeros((20, 30, 3), dtype=np.uint8) + mask_np = np.zeros((20, 30), dtype=np.uint8) + mask_np[2:12, 3:20] = 1 + out_cpu = draw_segmentation_mask(img_np.copy(), mask_np, color=(100, 50, 200), alpha=0.4) + out_gpu = draw_segmentation_mask( + cp.asarray(img_np), cp.asarray(mask_np), color=(100, 50, 200), alpha=0.4 + ) + np.testing.assert_array_equal(cp.asnumpy(out_gpu), out_cpu) + + +def test_draw_object_detection_visualization_cpu() -> None: + img = np.zeros((30, 40, 3), dtype=np.uint8) + objects = [ + { + "object_id": 1, + "bbox": [5, 6, 20, 25], + "label": "box", + "confidence": 0.9, + } + ] + out = draw_object_detection_visualization(img, objects) + assert out.shape == img.shape + + +@pytest.mark.skipif(not _has_cupy(), reason="CuPy/CUDA not available") +def test_draw_object_detection_visualization_gpu_parity() -> None: + import cupy as cp # type: ignore + + img_np = np.zeros((30, 40, 3), dtype=np.uint8) + objects = [ + { + "object_id": 1, + "bbox": [5, 6, 20, 25], + "label": "box", + "confidence": 0.9, + } + ] + out_cpu = draw_object_detection_visualization(img_np.copy(), objects) + out_gpu = draw_object_detection_visualization(cp.asarray(img_np), objects) + np.testing.assert_array_equal(cp.asnumpy(out_gpu), out_cpu) diff --git a/dimos/msgs/sensor_msgs/image_impls/test_image_backends.py b/dimos/msgs/sensor_msgs/image_impls/test_image_backends.py new file mode 100644 index 0000000000..b1de0ac777 --- /dev/null +++ b/dimos/msgs/sensor_msgs/image_impls/test_image_backends.py @@ -0,0 +1,797 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import cv2 +import numpy as np +import pytest + +from dimos.msgs.sensor_msgs.Image import HAS_CUDA, Image, ImageFormat +from dimos.utils.data import get_data + +IMAGE_PATH = get_data("chair-image.png") + +if HAS_CUDA: + print("Running image backend tests with CUDA/CuPy support (GPU mode)") +else: + print("Running image backend tests in CPU-only mode") + + +def _load_chair_image() -> np.ndarray: + img = cv2.imread(IMAGE_PATH, cv2.IMREAD_UNCHANGED) + if img is None: + raise FileNotFoundError(f"unable to load test image at {IMAGE_PATH}") + return img + + +_CHAIR_BGRA = _load_chair_image() + + +def _prepare_image(fmt: ImageFormat, shape=None) -> np.ndarray: + base = _CHAIR_BGRA + if fmt == ImageFormat.BGR: + arr = cv2.cvtColor(base, cv2.COLOR_BGRA2BGR) + elif fmt == ImageFormat.RGB: + arr = cv2.cvtColor(base, cv2.COLOR_BGRA2RGB) + elif fmt == ImageFormat.BGRA: + arr = base.copy() + elif fmt == ImageFormat.GRAY: + arr = cv2.cvtColor(base, cv2.COLOR_BGRA2GRAY) + else: + raise ValueError(f"unsupported image format {fmt}") + + if shape is None: + return arr.copy() + + if len(shape) == 2: + height, width = shape + orig_h, orig_w = arr.shape[:2] + interp = cv2.INTER_AREA if height <= orig_h and width <= orig_w else cv2.INTER_LINEAR + resized = cv2.resize(arr, (width, height), interpolation=interp) + return resized.copy() + + if len(shape) == 3: + height, width, channels = shape + orig_h, orig_w = arr.shape[:2] + interp = cv2.INTER_AREA if height <= orig_h and width <= orig_w else cv2.INTER_LINEAR + resized = cv2.resize(arr, (width, height), interpolation=interp) + if resized.ndim == 2: + resized = np.repeat(resized[:, :, None], channels, axis=2) + elif resized.shape[2] != channels: + if channels == 4 and resized.shape[2] == 3: + alpha = np.full((height, width, 1), 255, dtype=resized.dtype) + resized = np.concatenate([resized, alpha], axis=2) + elif channels == 3 and resized.shape[2] == 4: + resized = resized[:, :, :3] + else: + raise ValueError(f"cannot adjust image to {channels} channels") + return resized.copy() + + raise ValueError("shape must be a tuple of length 2 or 3") + + +@pytest.fixture +def alloc_timer(request): + """Helper fixture for adaptive testing with optional GPU support.""" + + def _alloc( + arr: np.ndarray, fmt: ImageFormat, *, to_cuda: bool | None = None, label: str | None = None + ): + tag = label or request.node.name + + # Always create CPU image + start = time.perf_counter() + cpu = Image.from_numpy(arr, format=fmt, to_cuda=False) + cpu_time = time.perf_counter() - start + + # Optionally create GPU image if CUDA is available + gpu = None + gpu_time = None + if to_cuda is None: + to_cuda = HAS_CUDA + + if to_cuda and HAS_CUDA: + arr_gpu = np.array(arr, copy=True) + start = time.perf_counter() + gpu = Image.from_numpy(arr_gpu, format=fmt, to_cuda=True) + gpu_time = time.perf_counter() - start + + if gpu_time is not None: + print(f"[alloc {tag}] cpu={cpu_time:.6f}s gpu={gpu_time:.6f}s") + else: + print(f"[alloc {tag}] cpu={cpu_time:.6f}s") + return cpu, gpu, cpu_time, gpu_time + + return _alloc + + +@pytest.mark.parametrize( + "shape,fmt", + [ + ((64, 64, 3), ImageFormat.BGR), + ((64, 64, 4), ImageFormat.BGRA), + ((64, 64, 3), ImageFormat.RGB), + ((64, 64), ImageFormat.GRAY), + ], +) +def test_color_conversions(shape, fmt, alloc_timer) -> None: + """Test color conversions with NumpyImage always, add CudaImage parity when available.""" + arr = _prepare_image(fmt, shape) + cpu, gpu, _, _ = alloc_timer(arr, fmt) + + # Always test CPU backend + cpu_round = cpu.to_rgb().to_bgr().to_opencv() + assert cpu_round.shape[0] == shape[0] + assert cpu_round.shape[1] == shape[1] + assert cpu_round.shape[2] == 3 # to_opencv always returns BGR (3 channels) + assert cpu_round.dtype == np.uint8 + + # Optionally test GPU parity when CUDA is available + if gpu is not None: + gpu_round = gpu.to_rgb().to_bgr().to_opencv() + assert gpu_round.shape == cpu_round.shape + assert gpu_round.dtype == cpu_round.dtype + # Exact match for uint8 color ops + assert np.array_equal(cpu_round, gpu_round) + + +def test_grayscale(alloc_timer) -> None: + """Test grayscale conversion with NumpyImage always, add CudaImage parity when available.""" + arr = _prepare_image(ImageFormat.BGR, (48, 32, 3)) + cpu, gpu, _, _ = alloc_timer(arr, ImageFormat.BGR) + + # Always test CPU backend + cpu_gray = cpu.to_grayscale().to_opencv() + assert cpu_gray.shape == (48, 32) # Grayscale has no channel dimension in OpenCV + assert cpu_gray.dtype == np.uint8 + + # Optionally test GPU parity when CUDA is available + if gpu is not None: + gpu_gray = gpu.to_grayscale().to_opencv() + assert gpu_gray.shape == cpu_gray.shape + assert gpu_gray.dtype == cpu_gray.dtype + # Allow tiny rounding differences (<=1 LSB) — visually indistinguishable + diff = np.abs(cpu_gray.astype(np.int16) - gpu_gray.astype(np.int16)) + assert diff.max() <= 1 + + +@pytest.mark.parametrize("fmt", [ImageFormat.BGR, ImageFormat.RGB, ImageFormat.BGRA]) +def test_resize(fmt, alloc_timer) -> None: + """Test resize with NumpyImage always, add CudaImage parity when available.""" + shape = (60, 80, 3) if fmt in (ImageFormat.BGR, ImageFormat.RGB) else (60, 80, 4) + arr = _prepare_image(fmt, shape) + cpu, gpu, _, _ = alloc_timer(arr, fmt) + + new_w, new_h = 37, 53 + + # Always test CPU backend + cpu_res = cpu.resize(new_w, new_h).to_opencv() + assert ( + cpu_res.shape == (53, 37, 3) if fmt != ImageFormat.BGRA else (53, 37, 3) + ) # to_opencv drops alpha + assert cpu_res.dtype == np.uint8 + + # Optionally test GPU parity when CUDA is available + if gpu is not None: + gpu_res = gpu.resize(new_w, new_h).to_opencv() + assert gpu_res.shape == cpu_res.shape + assert gpu_res.dtype == cpu_res.dtype + # Allow small tolerance due to float interpolation differences + assert np.max(np.abs(cpu_res.astype(np.int16) - gpu_res.astype(np.int16))) <= 1 + + +def test_perf_alloc(alloc_timer) -> None: + """Test allocation performance with NumpyImage always, add CudaImage when available.""" + arr = _prepare_image(ImageFormat.BGR, (480, 640, 3)) + alloc_timer(arr, ImageFormat.BGR, label="test_perf_alloc-setup") + + runs = 5 + + # Always test CPU allocation + t0 = time.perf_counter() + for _ in range(runs): + _ = Image.from_numpy(arr, format=ImageFormat.BGR, to_cuda=False) + cpu_t = (time.perf_counter() - t0) / runs + assert cpu_t > 0 + + # Optionally test GPU allocation when CUDA is available + if HAS_CUDA: + t0 = time.perf_counter() + for _ in range(runs): + _ = Image.from_numpy(arr, format=ImageFormat.BGR, to_cuda=True) + gpu_t = (time.perf_counter() - t0) / runs + print(f"alloc (avg per call) cpu={cpu_t:.6f}s gpu={gpu_t:.6f}s") + assert gpu_t > 0 + else: + print(f"alloc (avg per call) cpu={cpu_t:.6f}s") + + +def test_sharpness(alloc_timer) -> None: + """Test sharpness computation with NumpyImage always, add CudaImage parity when available.""" + arr = _prepare_image(ImageFormat.BGR, (64, 64, 3)) + cpu, gpu, _, _ = alloc_timer(arr, ImageFormat.BGR) + + # Always test CPU backend + s_cpu = cpu.sharpness + assert s_cpu >= 0 # Sharpness should be non-negative + assert s_cpu < 1000 # Reasonable upper bound + + # Optionally test GPU parity when CUDA is available + if gpu is not None: + s_gpu = gpu.sharpness + # Values should be very close; minor border/rounding differences allowed + assert abs(s_cpu - s_gpu) < 5e-2 + + +def test_to_opencv(alloc_timer) -> None: + """Test to_opencv conversion with NumpyImage always, add CudaImage parity when available.""" + # BGRA should drop alpha and produce BGR + arr = _prepare_image(ImageFormat.BGRA, (32, 32, 4)) + cpu, gpu, _, _ = alloc_timer(arr, ImageFormat.BGRA) + + # Always test CPU backend + cpu_bgr = cpu.to_opencv() + assert cpu_bgr.shape == (32, 32, 3) + assert cpu_bgr.dtype == np.uint8 + + # Optionally test GPU parity when CUDA is available + if gpu is not None: + gpu_bgr = gpu.to_opencv() + assert gpu_bgr.shape == cpu_bgr.shape + assert gpu_bgr.dtype == cpu_bgr.dtype + assert np.array_equal(cpu_bgr, gpu_bgr) + + +def test_solve_pnp(alloc_timer) -> None: + """Test solve_pnp with NumpyImage always, add CudaImage parity when available.""" + # Synthetic camera and 3D points + K = np.array([[400.0, 0.0, 32.0], [0.0, 400.0, 24.0], [0.0, 0.0, 1.0]], dtype=np.float64) + dist = None + obj = np.array( + [ + [-0.5, -0.5, 0.0], + [0.5, -0.5, 0.0], + [0.5, 0.5, 0.0], + [-0.5, 0.5, 0.0], + [0.0, 0.0, 0.5], + [0.0, 0.0, 1.0], + ], + dtype=np.float32, + ) + + rvec_true = np.zeros((3, 1), dtype=np.float64) + tvec_true = np.array([[0.0], [0.0], [2.0]], dtype=np.float64) + img_pts, _ = cv2.projectPoints(obj, rvec_true, tvec_true, K, dist) + img_pts = img_pts.reshape(-1, 2).astype(np.float32) + + # Build images using deterministic fixture content + base_bgr = _prepare_image(ImageFormat.BGR, (48, 64, 3)) + cpu, gpu, _, _ = alloc_timer(base_bgr, ImageFormat.BGR) + + # Always test CPU backend + ok_cpu, r_cpu, t_cpu = cpu.solve_pnp(obj, img_pts, K, dist) + assert ok_cpu + + # Validate reprojection error for CPU solver + proj_cpu, _ = cv2.projectPoints(obj, r_cpu, t_cpu, K, dist) + proj_cpu = proj_cpu.reshape(-1, 2) + err_cpu = np.linalg.norm(proj_cpu - img_pts, axis=1) + assert err_cpu.mean() < 1e-3 + assert err_cpu.max() < 1e-2 + + # Optionally test GPU parity when CUDA is available + if gpu is not None: + ok_gpu, r_gpu, t_gpu = gpu.solve_pnp(obj, img_pts, K, dist) + assert ok_gpu + + # Validate reprojection error for GPU solver + proj_gpu, _ = cv2.projectPoints(obj, r_gpu, t_gpu, K, dist) + proj_gpu = proj_gpu.reshape(-1, 2) + err_gpu = np.linalg.norm(proj_gpu - img_pts, axis=1) + assert err_gpu.mean() < 1e-3 + assert err_gpu.max() < 1e-2 + + +def test_perf_grayscale(alloc_timer) -> None: + """Test grayscale performance with NumpyImage always, add CudaImage when available.""" + arr = _prepare_image(ImageFormat.BGR, (480, 640, 3)) + cpu, gpu, _, _ = alloc_timer(arr, ImageFormat.BGR, label="test_perf_grayscale-setup") + + runs = 10 + + # Always test CPU performance + t0 = time.perf_counter() + for _ in range(runs): + _ = cpu.to_grayscale() + cpu_t = (time.perf_counter() - t0) / runs + assert cpu_t > 0 + + # Optionally test GPU performance when CUDA is available + if gpu is not None: + t0 = time.perf_counter() + for _ in range(runs): + _ = gpu.to_grayscale() + gpu_t = (time.perf_counter() - t0) / runs + print(f"grayscale (avg per call) cpu={cpu_t:.6f}s gpu={gpu_t:.6f}s") + assert gpu_t > 0 + else: + print(f"grayscale (avg per call) cpu={cpu_t:.6f}s") + + +def test_perf_resize(alloc_timer) -> None: + """Test resize performance with NumpyImage always, add CudaImage when available.""" + arr = _prepare_image(ImageFormat.BGR, (480, 640, 3)) + cpu, gpu, _, _ = alloc_timer(arr, ImageFormat.BGR, label="test_perf_resize-setup") + + runs = 5 + + # Always test CPU performance + t0 = time.perf_counter() + for _ in range(runs): + _ = cpu.resize(320, 240) + cpu_t = (time.perf_counter() - t0) / runs + assert cpu_t > 0 + + # Optionally test GPU performance when CUDA is available + if gpu is not None: + t0 = time.perf_counter() + for _ in range(runs): + _ = gpu.resize(320, 240) + gpu_t = (time.perf_counter() - t0) / runs + print(f"resize (avg per call) cpu={cpu_t:.6f}s gpu={gpu_t:.6f}s") + assert gpu_t > 0 + else: + print(f"resize (avg per call) cpu={cpu_t:.6f}s") + + +def test_perf_sharpness(alloc_timer) -> None: + """Test sharpness performance with NumpyImage always, add CudaImage when available.""" + arr = _prepare_image(ImageFormat.BGR, (480, 640, 3)) + cpu, gpu, _, _ = alloc_timer(arr, ImageFormat.BGR, label="test_perf_sharpness-setup") + + runs = 3 + + # Always test CPU performance + t0 = time.perf_counter() + for _ in range(runs): + _ = cpu.sharpness + cpu_t = (time.perf_counter() - t0) / runs + assert cpu_t > 0 + + # Optionally test GPU performance when CUDA is available + if gpu is not None: + t0 = time.perf_counter() + for _ in range(runs): + _ = gpu.sharpness + gpu_t = (time.perf_counter() - t0) / runs + print(f"sharpness (avg per call) cpu={cpu_t:.6f}s gpu={gpu_t:.6f}s") + assert gpu_t > 0 + else: + print(f"sharpness (avg per call) cpu={cpu_t:.6f}s") + + +def test_perf_solvepnp(alloc_timer) -> None: + """Test solve_pnp performance with NumpyImage always, add CudaImage when available.""" + K = np.array([[600.0, 0.0, 320.0], [0.0, 600.0, 240.0], [0.0, 0.0, 1.0]], dtype=np.float64) + dist = None + rng = np.random.default_rng(123) + obj = rng.standard_normal((200, 3)).astype(np.float32) + rvec_true = np.array([[0.1], [-0.2], [0.05]]) + tvec_true = np.array([[0.0], [0.0], [3.0]]) + img_pts, _ = cv2.projectPoints(obj, rvec_true, tvec_true, K, dist) + img_pts = img_pts.reshape(-1, 2).astype(np.float32) + base_bgr = _prepare_image(ImageFormat.BGR, (480, 640, 3)) + cpu, gpu, _, _ = alloc_timer(base_bgr, ImageFormat.BGR, label="test_perf_solvepnp-setup") + + runs = 5 + + # Always test CPU performance + t0 = time.perf_counter() + for _ in range(runs): + _ = cpu.solve_pnp(obj, img_pts, K, dist) + cpu_t = (time.perf_counter() - t0) / runs + assert cpu_t > 0 + + # Optionally test GPU performance when CUDA is available + if gpu is not None: + t0 = time.perf_counter() + for _ in range(runs): + _ = gpu.solve_pnp(obj, img_pts, K, dist) + gpu_t = (time.perf_counter() - t0) / runs + print(f"solvePnP (avg per call) cpu={cpu_t:.6f}s gpu={gpu_t:.6f}s") + assert gpu_t > 0 + else: + print(f"solvePnP (avg per call) cpu={cpu_t:.6f}s") + + +# this test is failing with +# raise RuntimeError("OpenCV CSRT tracker not available") +@pytest.mark.skip +def test_perf_tracker(alloc_timer) -> None: + """Test tracker performance with NumpyImage always, add CudaImage when available.""" + # Don't check - just let it fail if CSRT isn't available + + H, W = 240, 320 + img_base = _prepare_image(ImageFormat.BGR, (H, W, 3)) + img1 = img_base.copy() + img2 = img_base.copy() + bbox0 = (80, 60, 40, 30) + x0, y0, w0, h0 = bbox0 + cv2.rectangle(img1, (x0, y0), (x0 + w0, y0 + h0), (255, 255, 255), thickness=-1) + dx, dy = 8, 5 + cv2.rectangle( + img2, + (x0 + dx, y0 + dy), + (x0 + dx + w0, y0 + dy + h0), + (255, 255, 255), + thickness=-1, + ) + cpu1, gpu1, _, _ = alloc_timer(img1, ImageFormat.BGR, label="test_perf_tracker-frame1") + cpu2, gpu2, _, _ = alloc_timer(img2, ImageFormat.BGR, label="test_perf_tracker-frame2") + + # Always test CPU tracker + trk_cpu = cpu1.create_csrt_tracker(bbox0) + + runs = 10 + t0 = time.perf_counter() + for _ in range(runs): + _ = cpu2.csrt_update(trk_cpu) + cpu_t = (time.perf_counter() - t0) / runs + assert cpu_t > 0 + + # Optionally test GPU performance when CUDA is available + if gpu1 is not None and gpu2 is not None: + trk_gpu = gpu1.create_csrt_tracker(bbox0) + t0 = time.perf_counter() + for _ in range(runs): + _ = gpu2.csrt_update(trk_gpu) + gpu_t = (time.perf_counter() - t0) / runs + print(f"tracker (avg per call) cpu={cpu_t:.6f}s gpu={gpu_t:.6f}s") + assert gpu_t > 0 + else: + print(f"tracker (avg per call) cpu={cpu_t:.6f}s") + + +# this test is failing with +# raise RuntimeError("OpenCV CSRT tracker not available") +@pytest.mark.skip +def test_csrt_tracker(alloc_timer) -> None: + """Test CSRT tracker with NumpyImage always, add CudaImage parity when available.""" + # Don't check - just let it fail if CSRT isn't available + + H, W = 100, 100 + # Create two frames with a moving rectangle + img_base = _prepare_image(ImageFormat.BGR, (H, W, 3)) + img1 = img_base.copy() + img2 = img_base.copy() + bbox0 = (30, 30, 20, 15) + x0, y0, w0, h0 = bbox0 + # draw rect in img1 + cv2.rectangle(img1, (x0, y0), (x0 + w0, y0 + h0), (255, 255, 255), thickness=-1) + # shift by (dx,dy) + dx, dy = 5, 3 + cv2.rectangle( + img2, + (x0 + dx, y0 + dy), + (x0 + dx + w0, y0 + dy + h0), + (255, 255, 255), + thickness=-1, + ) + + cpu1, gpu1, _, _ = alloc_timer(img1, ImageFormat.BGR, label="test_csrt_tracker-frame1") + cpu2, gpu2, _, _ = alloc_timer(img2, ImageFormat.BGR, label="test_csrt_tracker-frame2") + + # Always test CPU tracker + trk_cpu = cpu1.create_csrt_tracker(bbox0) + ok_cpu, bbox_cpu = cpu2.csrt_update(trk_cpu) + assert ok_cpu + + # Compare to ground-truth expected bbox + expected = (x0 + dx, y0 + dy, w0, h0) + err_cpu = sum(abs(a - b) for a, b in zip(bbox_cpu, expected, strict=False)) + assert err_cpu <= 8 + + # Optionally test GPU parity when CUDA is available + if gpu1 is not None and gpu2 is not None: + trk_gpu = gpu1.create_csrt_tracker(bbox0) + ok_gpu, bbox_gpu = gpu2.csrt_update(trk_gpu) + assert ok_gpu + + err_gpu = sum(abs(a - b) for a, b in zip(bbox_gpu, expected, strict=False)) + assert err_gpu <= 10 # allow some slack for scale/window effects + + +def test_solve_pnp_ransac(alloc_timer) -> None: + """Test solve_pnp_ransac with NumpyImage always, add CudaImage when available.""" + # Camera with distortion + K = np.array([[500.0, 0.0, 320.0], [0.0, 500.0, 240.0], [0.0, 0.0, 1.0]], dtype=np.float64) + dist = np.array([0.1, -0.05, 0.001, 0.001, 0.0], dtype=np.float64) + rng = np.random.default_rng(202) + obj = rng.uniform(-1.0, 1.0, size=(200, 3)).astype(np.float32) + obj[:, 2] = np.abs(obj[:, 2]) + 2.0 # keep in front of camera + rvec_true = np.array([[0.1], [-0.15], [0.05]], dtype=np.float64) + tvec_true = np.array([[0.2], [-0.1], [3.0]], dtype=np.float64) + img_pts, _ = cv2.projectPoints(obj, rvec_true, tvec_true, K, dist) + img_pts = img_pts.reshape(-1, 2) + # Add outliers + n_out = 20 + idx = rng.choice(len(img_pts), size=n_out, replace=False) + img_pts[idx] += rng.uniform(-50, 50, size=(n_out, 2)) + img_pts = img_pts.astype(np.float32) + + base_bgr = _prepare_image(ImageFormat.BGR, (480, 640, 3)) + cpu, gpu, _, _ = alloc_timer(base_bgr, ImageFormat.BGR, label="test_solve_pnp_ransac-setup") + + # Always test CPU backend + ok_cpu, r_cpu, t_cpu, mask_cpu = cpu.solve_pnp_ransac( + obj, img_pts, K, dist, iterations_count=150, reprojection_error=3.0 + ) + assert ok_cpu + inlier_ratio = mask_cpu.mean() + assert inlier_ratio > 0.7 + + # Reprojection error on inliers + in_idx = np.nonzero(mask_cpu)[0] + proj_cpu, _ = cv2.projectPoints(obj[in_idx], r_cpu, t_cpu, K, dist) + proj_cpu = proj_cpu.reshape(-1, 2) + err = np.linalg.norm(proj_cpu - img_pts[in_idx], axis=1) + assert err.mean() < 1.5 + assert err.max() < 4.0 + + # Optionally test GPU parity when CUDA is available + if gpu is not None: + ok_gpu, r_gpu, t_gpu, mask_gpu = gpu.solve_pnp_ransac( + obj, img_pts, K, dist, iterations_count=150, reprojection_error=3.0 + ) + assert ok_gpu + inlier_ratio_gpu = mask_gpu.mean() + assert inlier_ratio_gpu > 0.7 + + # Reprojection error on inliers for GPU + in_idx_gpu = np.nonzero(mask_gpu)[0] + proj_gpu, _ = cv2.projectPoints(obj[in_idx_gpu], r_gpu, t_gpu, K, dist) + proj_gpu = proj_gpu.reshape(-1, 2) + err_gpu = np.linalg.norm(proj_gpu - img_pts[in_idx_gpu], axis=1) + assert err_gpu.mean() < 1.5 + assert err_gpu.max() < 4.0 + + +def test_solve_pnp_batch(alloc_timer) -> None: + """Test solve_pnp batch processing with NumpyImage always, add CudaImage when available.""" + # Note: Batch processing is primarily a GPU feature, but we can still test CPU loop + # Generate batched problems + B, N = 8, 50 + rng = np.random.default_rng(99) + obj = rng.uniform(-1.0, 1.0, size=(B, N, 3)).astype(np.float32) + obj[:, :, 2] = np.abs(obj[:, :, 2]) + 2.0 + K = np.array([[600.0, 0.0, 320.0], [0.0, 600.0, 240.0], [0.0, 0.0, 1.0]], dtype=np.float64) + r_true = np.zeros((B, 3, 1), dtype=np.float64) + t_true = np.tile(np.array([[0.0], [0.0], [3.0]], dtype=np.float64), (B, 1, 1)) + img = [] + for b in range(B): + ip, _ = cv2.projectPoints(obj[b], r_true[b], t_true[b], K, None) + img.append(ip.reshape(-1, 2)) + img = np.stack(img, axis=0).astype(np.float32) + + base_bgr = _prepare_image(ImageFormat.BGR, (10, 10, 3)) + cpu, gpu, _, _ = alloc_timer(base_bgr, ImageFormat.BGR, label="test_solve_pnp_batch-setup") + + # Always test CPU loop + t0 = time.perf_counter() + r_list = [] + t_list = [] + for b in range(B): + ok, r, t = cpu.solve_pnp(obj[b], img[b], K, None) + assert ok + r_list.append(r) + t_list.append(t) + cpu_total = time.perf_counter() - t0 + cpu_t = cpu_total / B + + # Check reprojection for CPU results + for b in range(min(B, 2)): + proj, _ = cv2.projectPoints(obj[b], r_list[b], t_list[b], K, None) + err = np.linalg.norm(proj.reshape(-1, 2) - img[b], axis=1) + assert err.mean() < 1e-2 + assert err.max() < 1e-1 + + # Optionally test GPU batch when CUDA is available + if gpu is not None and hasattr(gpu._impl, "solve_pnp_batch"): + t0 = time.perf_counter() + r_b, t_b = gpu.solve_pnp_batch(obj, img, K) + gpu_total = time.perf_counter() - t0 + gpu_t = gpu_total / B + print(f"solvePnP-batch (avg per pose) cpu={cpu_t:.6f}s gpu={gpu_t:.6f}s (B={B}, N={N})") + + # Check reprojection for GPU batches + for b in range(min(B, 4)): + proj, _ = cv2.projectPoints(obj[b], r_b[b], t_b[b], K, None) + err = np.linalg.norm(proj.reshape(-1, 2) - img[b], axis=1) + assert err.mean() < 1e-2 + assert err.max() < 1e-1 + else: + print(f"solvePnP-batch (avg per pose) cpu={cpu_t:.6f}s (GPU batch not available)") + + +def test_nvimgcodec_flag_and_fallback(monkeypatch) -> None: + # Test that to_base64() works with and without nvimgcodec by patching runtime flags + import dimos.msgs.sensor_msgs.image_impls.AbstractImage as AbstractImageMod + + arr = _prepare_image(ImageFormat.BGR, (32, 32, 3)) + + # Save original values + original_has_nvimgcodec = AbstractImageMod.HAS_NVIMGCODEC + original_nvimgcodec = AbstractImageMod.nvimgcodec + + try: + # Test 1: Simulate nvimgcodec not available + monkeypatch.setattr(AbstractImageMod, "HAS_NVIMGCODEC", False) + monkeypatch.setattr(AbstractImageMod, "nvimgcodec", None) + + # Should work via cv2 fallback for CPU + img_cpu = Image.from_numpy(arr, format=ImageFormat.BGR, to_cuda=False) + b64_cpu = img_cpu.to_base64() + assert isinstance(b64_cpu, str) and len(b64_cpu) > 0 + + # If CUDA available, test GPU fallback to CPU encoding + if HAS_CUDA: + img_gpu = Image.from_numpy(arr, format=ImageFormat.BGR, to_cuda=True) + b64_gpu = img_gpu.to_base64() + assert isinstance(b64_gpu, str) and len(b64_gpu) > 0 + # Should have fallen back to CPU encoding + assert not AbstractImageMod.NVIMGCODEC_LAST_USED + + # Test 2: Restore nvimgcodec if it was originally available + if original_has_nvimgcodec: + monkeypatch.setattr(AbstractImageMod, "HAS_NVIMGCODEC", True) + monkeypatch.setattr(AbstractImageMod, "nvimgcodec", original_nvimgcodec) + + # Test it still works with nvimgcodec "available" + img2 = Image.from_numpy(arr, format=ImageFormat.BGR, to_cuda=HAS_CUDA) + b64_2 = img2.to_base64() + assert isinstance(b64_2, str) and len(b64_2) > 0 + + finally: + pass + + +@pytest.mark.skipif(not HAS_CUDA, reason="CuPy/CUDA not available") +def test_nvimgcodec_gpu_path(monkeypatch) -> None: + """Test nvimgcodec GPU encoding path when CUDA is available. + + This test specifically verifies that when nvimgcodec is available, + GPU images can be encoded directly without falling back to CPU. + """ + import dimos.msgs.sensor_msgs.image_impls.AbstractImage as AbstractImageMod + + # Check if nvimgcodec was originally available + if not AbstractImageMod.HAS_NVIMGCODEC: + pytest.skip("nvimgcodec library not available") + + # Save original nvimgcodec module reference + + # Create a CUDA image and encode using the actual nvimgcodec if available + arr = _prepare_image(ImageFormat.BGR, (32, 32, 3)) + + # Test with nvimgcodec enabled (should be the default if available) + img = Image.from_numpy(arr, format=ImageFormat.BGR, to_cuda=True) + b64 = img.to_base64() + assert isinstance(b64, str) and len(b64) > 0 + + # Check if GPU encoding was actually used + # Some builds may import nvimgcodec but not support CuPy device buffers + if not getattr(AbstractImageMod, "NVIMGCODEC_LAST_USED", False): + pytest.skip("nvimgcodec present but encode fell back to CPU in this environment") + + # Now test that we can disable nvimgcodec and still encode via fallback + monkeypatch.setattr(AbstractImageMod, "HAS_NVIMGCODEC", False) + monkeypatch.setattr(AbstractImageMod, "nvimgcodec", None) + + # Create another GPU image - should fall back to CPU encoding + img2 = Image.from_numpy(arr, format=ImageFormat.BGR, to_cuda=True) + b64_2 = img2.to_base64() + assert isinstance(b64_2, str) and len(b64_2) > 0 + # Should have fallen back to CPU encoding + assert not AbstractImageMod.NVIMGCODEC_LAST_USED + + +@pytest.mark.skipif(not HAS_CUDA, reason="CuPy/CUDA not available") +def test_to_cpu_format_preservation() -> None: + """Test that to_cpu() preserves image format correctly. + + This tests the fix for the bug where to_cpu() was using to_opencv() + which always returns BGR, but keeping the original format label. + """ + # Test RGB format preservation + rgb_array = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8) + gpu_img_rgb = Image.from_numpy(rgb_array, format=ImageFormat.RGB, to_cuda=True) + cpu_img_rgb = gpu_img_rgb.to_cpu() + + # Verify format is preserved + assert cpu_img_rgb.format == ImageFormat.RGB, ( + f"Format mismatch: expected RGB, got {cpu_img_rgb.format}" + ) + # Verify data is actually in RGB format (not BGR) + np.testing.assert_array_equal(cpu_img_rgb.data, rgb_array) + + # Test RGBA format preservation + rgba_array = np.random.randint(0, 255, (100, 100, 4), dtype=np.uint8) + gpu_img_rgba = Image.from_numpy(rgba_array, format=ImageFormat.RGBA, to_cuda=True) + cpu_img_rgba = gpu_img_rgba.to_cpu() + + assert cpu_img_rgba.format == ImageFormat.RGBA, ( + f"Format mismatch: expected RGBA, got {cpu_img_rgba.format}" + ) + np.testing.assert_array_equal(cpu_img_rgba.data, rgba_array) + + # Test BGR format (should be unchanged since to_opencv returns BGR) + bgr_array = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8) + gpu_img_bgr = Image.from_numpy(bgr_array, format=ImageFormat.BGR, to_cuda=True) + cpu_img_bgr = gpu_img_bgr.to_cpu() + + assert cpu_img_bgr.format == ImageFormat.BGR, ( + f"Format mismatch: expected BGR, got {cpu_img_bgr.format}" + ) + np.testing.assert_array_equal(cpu_img_bgr.data, bgr_array) + + # Test BGRA format + bgra_array = np.random.randint(0, 255, (100, 100, 4), dtype=np.uint8) + gpu_img_bgra = Image.from_numpy(bgra_array, format=ImageFormat.BGRA, to_cuda=True) + cpu_img_bgra = gpu_img_bgra.to_cpu() + + assert cpu_img_bgra.format == ImageFormat.BGRA, ( + f"Format mismatch: expected BGRA, got {cpu_img_bgra.format}" + ) + np.testing.assert_array_equal(cpu_img_bgra.data, bgra_array) + + # Test GRAY format + gray_array = np.random.randint(0, 255, (100, 100), dtype=np.uint8) + gpu_img_gray = Image.from_numpy(gray_array, format=ImageFormat.GRAY, to_cuda=True) + cpu_img_gray = gpu_img_gray.to_cpu() + + assert cpu_img_gray.format == ImageFormat.GRAY, ( + f"Format mismatch: expected GRAY, got {cpu_img_gray.format}" + ) + np.testing.assert_array_equal(cpu_img_gray.data, gray_array) + + # Test DEPTH format (float32) + depth_array = np.random.uniform(0.5, 10.0, (100, 100)).astype(np.float32) + gpu_img_depth = Image.from_numpy(depth_array, format=ImageFormat.DEPTH, to_cuda=True) + cpu_img_depth = gpu_img_depth.to_cpu() + + assert cpu_img_depth.format == ImageFormat.DEPTH, ( + f"Format mismatch: expected DEPTH, got {cpu_img_depth.format}" + ) + np.testing.assert_array_equal(cpu_img_depth.data, depth_array) + + # Test DEPTH16 format (uint16) + depth16_array = np.random.randint(100, 65000, (100, 100), dtype=np.uint16) + gpu_img_depth16 = Image.from_numpy(depth16_array, format=ImageFormat.DEPTH16, to_cuda=True) + cpu_img_depth16 = gpu_img_depth16.to_cpu() + + assert cpu_img_depth16.format == ImageFormat.DEPTH16, ( + f"Format mismatch: expected DEPTH16, got {cpu_img_depth16.format}" + ) + np.testing.assert_array_equal(cpu_img_depth16.data, depth16_array) + + # Test GRAY16 format (uint16) + gray16_array = np.random.randint(0, 65535, (100, 100), dtype=np.uint16) + gpu_img_gray16 = Image.from_numpy(gray16_array, format=ImageFormat.GRAY16, to_cuda=True) + cpu_img_gray16 = gpu_img_gray16.to_cpu() + + assert cpu_img_gray16.format == ImageFormat.GRAY16, ( + f"Format mismatch: expected GRAY16, got {cpu_img_gray16.format}" + ) + np.testing.assert_array_equal(cpu_img_gray16.data, gray16_array) diff --git a/dimos/msgs/sensor_msgs/test_CameraInfo.py b/dimos/msgs/sensor_msgs/test_CameraInfo.py new file mode 100644 index 0000000000..d8fea70945 --- /dev/null +++ b/dimos/msgs/sensor_msgs/test_CameraInfo.py @@ -0,0 +1,457 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest + +try: + from sensor_msgs.msg import CameraInfo as ROSCameraInfo, RegionOfInterest as ROSRegionOfInterest + from std_msgs.msg import Header as ROSHeader +except ImportError: + ROSCameraInfo = None + ROSRegionOfInterest = None + ROSHeader = None + +from dimos.msgs.sensor_msgs.CameraInfo import CalibrationProvider, CameraInfo +from dimos.utils.path_utils import get_project_root + + +def test_lcm_encode_decode() -> None: + """Test LCM encode/decode preserves CameraInfo data.""" + print("Testing CameraInfo LCM encode/decode...") + + # Create test camera info with sample calibration data + original = CameraInfo( + height=480, + width=640, + distortion_model="plumb_bob", + D=[-0.1, 0.05, 0.001, -0.002, 0.0], # 5 distortion coefficients + K=[ + 500.0, + 0.0, + 320.0, # fx, 0, cx + 0.0, + 500.0, + 240.0, # 0, fy, cy + 0.0, + 0.0, + 1.0, + ], # 0, 0, 1 + R=[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0], + P=[ + 500.0, + 0.0, + 320.0, + 0.0, # fx, 0, cx, Tx + 0.0, + 500.0, + 240.0, + 0.0, # 0, fy, cy, Ty + 0.0, + 0.0, + 1.0, + 0.0, + ], # 0, 0, 1, 0 + binning_x=2, + binning_y=2, + frame_id="camera_optical_frame", + ts=1234567890.123456, + ) + + # Set ROI + original.roi_x_offset = 100 + original.roi_y_offset = 50 + original.roi_height = 200 + original.roi_width = 300 + original.roi_do_rectify = True + + # Encode and decode + binary_msg = original.lcm_encode() + decoded = CameraInfo.lcm_decode(binary_msg) + + # Check basic properties + assert original.height == decoded.height, ( + f"Height mismatch: {original.height} vs {decoded.height}" + ) + assert original.width == decoded.width, f"Width mismatch: {original.width} vs {decoded.width}" + print(f"✓ Image dimensions preserved: {decoded.width}x{decoded.height}") + + assert original.distortion_model == decoded.distortion_model, ( + f"Distortion model mismatch: '{original.distortion_model}' vs '{decoded.distortion_model}'" + ) + print(f"✓ Distortion model preserved: '{decoded.distortion_model}'") + + # Check distortion coefficients + assert len(original.D) == len(decoded.D), ( + f"D length mismatch: {len(original.D)} vs {len(decoded.D)}" + ) + np.testing.assert_allclose( + original.D, decoded.D, rtol=1e-9, atol=1e-9, err_msg="Distortion coefficients don't match" + ) + print(f"✓ Distortion coefficients preserved: {len(decoded.D)} coefficients") + + # Check camera matrices + np.testing.assert_allclose( + original.K, decoded.K, rtol=1e-9, atol=1e-9, err_msg="K matrix doesn't match" + ) + print("✓ Intrinsic matrix K preserved") + + np.testing.assert_allclose( + original.R, decoded.R, rtol=1e-9, atol=1e-9, err_msg="R matrix doesn't match" + ) + print("✓ Rectification matrix R preserved") + + np.testing.assert_allclose( + original.P, decoded.P, rtol=1e-9, atol=1e-9, err_msg="P matrix doesn't match" + ) + print("✓ Projection matrix P preserved") + + # Check binning + assert original.binning_x == decoded.binning_x, ( + f"Binning X mismatch: {original.binning_x} vs {decoded.binning_x}" + ) + assert original.binning_y == decoded.binning_y, ( + f"Binning Y mismatch: {original.binning_y} vs {decoded.binning_y}" + ) + print(f"✓ Binning preserved: {decoded.binning_x}x{decoded.binning_y}") + + # Check ROI + assert original.roi_x_offset == decoded.roi_x_offset, "ROI x_offset mismatch" + assert original.roi_y_offset == decoded.roi_y_offset, "ROI y_offset mismatch" + assert original.roi_height == decoded.roi_height, "ROI height mismatch" + assert original.roi_width == decoded.roi_width, "ROI width mismatch" + assert original.roi_do_rectify == decoded.roi_do_rectify, "ROI do_rectify mismatch" + print("✓ ROI preserved") + + # Check metadata + assert original.frame_id == decoded.frame_id, ( + f"Frame ID mismatch: '{original.frame_id}' vs '{decoded.frame_id}'" + ) + print(f"✓ Frame ID preserved: '{decoded.frame_id}'") + + assert abs(original.ts - decoded.ts) < 1e-6, ( + f"Timestamp mismatch: {original.ts} vs {decoded.ts}" + ) + print(f"✓ Timestamp preserved: {decoded.ts}") + + print("✓ LCM encode/decode test passed - all properties preserved!") + + +def test_numpy_matrix_operations() -> None: + """Test numpy matrix getter/setter operations.""" + print("\nTesting numpy matrix operations...") + + camera_info = CameraInfo() + + # Test K matrix + K = np.array([[525.0, 0.0, 319.5], [0.0, 525.0, 239.5], [0.0, 0.0, 1.0]]) + camera_info.set_K_matrix(K) + K_retrieved = camera_info.get_K_matrix() + np.testing.assert_allclose(K, K_retrieved, rtol=1e-9, atol=1e-9) + print("✓ K matrix setter/getter works") + + # Test P matrix + P = np.array([[525.0, 0.0, 319.5, 0.0], [0.0, 525.0, 239.5, 0.0], [0.0, 0.0, 1.0, 0.0]]) + camera_info.set_P_matrix(P) + P_retrieved = camera_info.get_P_matrix() + np.testing.assert_allclose(P, P_retrieved, rtol=1e-9, atol=1e-9) + print("✓ P matrix setter/getter works") + + # Test R matrix + R = np.eye(3) + camera_info.set_R_matrix(R) + R_retrieved = camera_info.get_R_matrix() + np.testing.assert_allclose(R, R_retrieved, rtol=1e-9, atol=1e-9) + print("✓ R matrix setter/getter works") + + # Test D coefficients + D = np.array([-0.2, 0.1, 0.001, -0.002, 0.05]) + camera_info.set_D_coeffs(D) + D_retrieved = camera_info.get_D_coeffs() + np.testing.assert_allclose(D, D_retrieved, rtol=1e-9, atol=1e-9) + print("✓ D coefficients setter/getter works") + + print("✓ All numpy matrix operations passed!") + + +@pytest.mark.ros +def test_ros_conversion() -> None: + """Test ROS message conversion preserves CameraInfo data.""" + print("\nTesting ROS CameraInfo conversion...") + + # Create test camera info + original = CameraInfo( + height=720, + width=1280, + distortion_model="rational_polynomial", + D=[0.1, -0.2, 0.001, 0.002, -0.05, 0.01, -0.02, 0.003], # 8 coefficients + K=[600.0, 0.0, 640.0, 0.0, 600.0, 360.0, 0.0, 0.0, 1.0], + R=[0.999, -0.01, 0.02, 0.01, 0.999, -0.01, -0.02, 0.01, 0.999], + P=[ + 600.0, + 0.0, + 640.0, + -60.0, # Stereo baseline of 0.1m + 0.0, + 600.0, + 360.0, + 0.0, + 0.0, + 0.0, + 1.0, + 0.0, + ], + binning_x=1, + binning_y=1, + frame_id="left_camera_optical", + ts=1234567890.987654, + ) + + # Set ROI + original.roi_x_offset = 200 + original.roi_y_offset = 100 + original.roi_height = 400 + original.roi_width = 800 + original.roi_do_rectify = False + + # Test 1: Convert to ROS and back + ros_msg = original.to_ros_msg() + converted = CameraInfo.from_ros_msg(ros_msg) + + # Check all properties + assert original.height == converted.height, ( + f"Height mismatch: {original.height} vs {converted.height}" + ) + assert original.width == converted.width, ( + f"Width mismatch: {original.width} vs {converted.width}" + ) + print(f"✓ Dimensions preserved: {converted.width}x{converted.height}") + + assert original.distortion_model == converted.distortion_model, ( + f"Distortion model mismatch: '{original.distortion_model}' vs '{converted.distortion_model}'" + ) + print(f"✓ Distortion model preserved: '{converted.distortion_model}'") + + np.testing.assert_allclose( + original.D, + converted.D, + rtol=1e-9, + atol=1e-9, + err_msg="D coefficients don't match after ROS conversion", + ) + print(f"✓ Distortion coefficients preserved: {len(converted.D)} coefficients") + + np.testing.assert_allclose( + original.K, + converted.K, + rtol=1e-9, + atol=1e-9, + err_msg="K matrix doesn't match after ROS conversion", + ) + print("✓ K matrix preserved") + + np.testing.assert_allclose( + original.R, + converted.R, + rtol=1e-9, + atol=1e-9, + err_msg="R matrix doesn't match after ROS conversion", + ) + print("✓ R matrix preserved") + + np.testing.assert_allclose( + original.P, + converted.P, + rtol=1e-9, + atol=1e-9, + err_msg="P matrix doesn't match after ROS conversion", + ) + print("✓ P matrix preserved") + + assert original.binning_x == converted.binning_x, "Binning X mismatch" + assert original.binning_y == converted.binning_y, "Binning Y mismatch" + print(f"✓ Binning preserved: {converted.binning_x}x{converted.binning_y}") + + assert original.roi_x_offset == converted.roi_x_offset, "ROI x_offset mismatch" + assert original.roi_y_offset == converted.roi_y_offset, "ROI y_offset mismatch" + assert original.roi_height == converted.roi_height, "ROI height mismatch" + assert original.roi_width == converted.roi_width, "ROI width mismatch" + assert original.roi_do_rectify == converted.roi_do_rectify, "ROI do_rectify mismatch" + print("✓ ROI preserved") + + assert original.frame_id == converted.frame_id, ( + f"Frame ID mismatch: '{original.frame_id}' vs '{converted.frame_id}'" + ) + print(f"✓ Frame ID preserved: '{converted.frame_id}'") + + assert abs(original.ts - converted.ts) < 1e-6, ( + f"Timestamp mismatch: {original.ts} vs {converted.ts}" + ) + print(f"✓ Timestamp preserved: {converted.ts}") + + # Test 2: Create ROS message directly and convert to DIMOS + ros_msg2 = ROSCameraInfo() + ros_msg2.header = ROSHeader() + ros_msg2.header.frame_id = "test_camera" + ros_msg2.header.stamp.sec = 1234567890 + ros_msg2.header.stamp.nanosec = 500000000 + + ros_msg2.height = 1080 + ros_msg2.width = 1920 + ros_msg2.distortion_model = "plumb_bob" + ros_msg2.d = [-0.3, 0.15, 0.0, 0.0, 0.0] + ros_msg2.k = [1000.0, 0.0, 960.0, 0.0, 1000.0, 540.0, 0.0, 0.0, 1.0] + ros_msg2.r = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0] + ros_msg2.p = [1000.0, 0.0, 960.0, 0.0, 0.0, 1000.0, 540.0, 0.0, 0.0, 0.0, 1.0, 0.0] + ros_msg2.binning_x = 4 + ros_msg2.binning_y = 4 + + ros_msg2.roi = ROSRegionOfInterest() + ros_msg2.roi.x_offset = 10 + ros_msg2.roi.y_offset = 20 + ros_msg2.roi.height = 100 + ros_msg2.roi.width = 200 + ros_msg2.roi.do_rectify = True + + # Convert to DIMOS + dimos_info = CameraInfo.from_ros_msg(ros_msg2) + + assert dimos_info.height == 1080, ( + f"Height not preserved: expected 1080, got {dimos_info.height}" + ) + assert dimos_info.width == 1920, f"Width not preserved: expected 1920, got {dimos_info.width}" + assert dimos_info.frame_id == "test_camera", ( + f"Frame ID not preserved: expected 'test_camera', got '{dimos_info.frame_id}'" + ) + assert dimos_info.distortion_model == "plumb_bob", "Distortion model not preserved" + assert len(dimos_info.D) == 5, ( + f"Wrong number of distortion coefficients: expected 5, got {len(dimos_info.D)}" + ) + print("✓ ROS to DIMOS conversion works correctly") + + # Test 3: Empty/minimal CameraInfo + minimal = CameraInfo(frame_id="minimal_camera", ts=1234567890.0) + minimal_ros = minimal.to_ros_msg() + minimal_converted = CameraInfo.from_ros_msg(minimal_ros) + + assert minimal.frame_id == minimal_converted.frame_id, ( + "Minimal CameraInfo frame_id not preserved" + ) + assert len(minimal_converted.D) == 0, "Minimal CameraInfo should have empty D" + print("✓ Minimal CameraInfo handling works") + + print("\n✓ All ROS conversion tests passed!") + + +def test_equality() -> None: + """Test CameraInfo equality comparison.""" + print("\nTesting CameraInfo equality...") + + info1 = CameraInfo( + height=480, + width=640, + distortion_model="plumb_bob", + D=[-0.1, 0.05, 0.0, 0.0, 0.0], + frame_id="camera1", + ) + + info2 = CameraInfo( + height=480, + width=640, + distortion_model="plumb_bob", + D=[-0.1, 0.05, 0.0, 0.0, 0.0], + frame_id="camera1", + ) + + info3 = CameraInfo( + height=720, + width=1280, # Different resolution + distortion_model="plumb_bob", + D=[-0.1, 0.05, 0.0, 0.0, 0.0], + frame_id="camera1", + ) + + assert info1 == info2, "Identical CameraInfo objects should be equal" + assert info1 != info3, "Different CameraInfo objects should not be equal" + assert info1 != "not_camera_info", "CameraInfo should not equal non-CameraInfo object" + + print("✓ Equality comparison works correctly") + + +def test_camera_info_from_yaml() -> None: + """Test loading CameraInfo from YAML file.""" + + # Get path to the single webcam YAML file + yaml_path = get_project_root() / "dimos" / "hardware" / "camera" / "zed" / "single_webcam.yaml" + + # Load CameraInfo from YAML + camera_info = CameraInfo.from_yaml(str(yaml_path)) + + # Verify loaded values + assert camera_info.width == 640 + assert camera_info.height == 376 + assert camera_info.distortion_model == "plumb_bob" + assert camera_info.frame_id == "camera_optical" + + # Check camera matrix K + K = camera_info.get_K_matrix() + assert K.shape == (3, 3) + assert np.isclose(K[0, 0], 379.45267) # fx + assert np.isclose(K[1, 1], 380.67871) # fy + assert np.isclose(K[0, 2], 302.43516) # cx + assert np.isclose(K[1, 2], 228.00954) # cy + + # Check distortion coefficients + D = camera_info.get_D_coeffs() + assert len(D) == 5 + assert np.isclose(D[0], -0.309435) + + # Check projection matrix P + P = camera_info.get_P_matrix() + assert P.shape == (3, 4) + assert np.isclose(P[0, 0], 291.12888) + + print("✓ CameraInfo loaded successfully from YAML file") + + +def test_calibration_provider() -> None: + """Test CalibrationProvider lazy loading of YAML files.""" + # Get the directory containing calibration files (not the file itself) + calibration_dir = get_project_root() / "dimos" / "hardware" / "camera" / "zed" + + # Create CalibrationProvider instance + Calibrations = CalibrationProvider(calibration_dir) + + # Test lazy loading of single_webcam.yaml using snake_case + camera_info = Calibrations.single_webcam + assert isinstance(camera_info, CameraInfo) + assert camera_info.width == 640 + assert camera_info.height == 376 + + # Test PascalCase access to same calibration + camera_info2 = Calibrations.SingleWebcam + assert isinstance(camera_info2, CameraInfo) + assert camera_info2.width == 640 + assert camera_info2.height == 376 + + # Test caching - both access methods should return same object + assert camera_info is camera_info2 # Same object reference + + # Test __dir__ lists available calibrations in both cases + available = dir(Calibrations) + assert "single_webcam" in available + assert "SingleWebcam" in available + + print("✓ CalibrationProvider test passed with both naming conventions!") diff --git a/dimos/msgs/sensor_msgs/test_Joy.py b/dimos/msgs/sensor_msgs/test_Joy.py new file mode 100644 index 0000000000..77b47f4983 --- /dev/null +++ b/dimos/msgs/sensor_msgs/test_Joy.py @@ -0,0 +1,232 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +try: + from sensor_msgs.msg import Joy as ROSJoy + from std_msgs.msg import Header as ROSHeader + + ROS_AVAILABLE = True +except ImportError: + ROSJoy = None + ROSHeader = None + ROS_AVAILABLE = False + +from dimos.msgs.sensor_msgs.Joy import Joy + + +def test_lcm_encode_decode() -> None: + """Test LCM encode/decode preserves Joy data.""" + print("Testing Joy LCM encode/decode...") + + # Create test joy message with sample gamepad data + original = Joy( + ts=1234567890.123456789, + frame_id="gamepad", + axes=[0.5, -0.25, 1.0, -1.0, 0.0, 0.75], # 6 axes (e.g., left/right sticks + triggers) + buttons=[1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0], # 12 buttons + ) + + # Encode to LCM bytes + encoded = original.lcm_encode() + assert isinstance(encoded, bytes) + assert len(encoded) > 0 + + # Decode back + decoded = Joy.lcm_decode(encoded) + + # Verify all fields match + assert abs(decoded.ts - original.ts) < 1e-9 + assert decoded.frame_id == original.frame_id + assert decoded.axes == original.axes + assert decoded.buttons == original.buttons + + print("✓ Joy LCM encode/decode test passed") + + +def test_initialization_methods() -> None: + """Test various initialization methods for Joy.""" + print("Testing Joy initialization methods...") + + # Test default initialization + joy1 = Joy() + assert joy1.axes == [] + assert joy1.buttons == [] + assert joy1.frame_id == "" + assert joy1.ts > 0 # Should have current time + + # Test full initialization + joy2 = Joy(ts=1234567890.0, frame_id="xbox_controller", axes=[0.1, 0.2, 0.3], buttons=[1, 0, 1]) + assert joy2.ts == 1234567890.0 + assert joy2.frame_id == "xbox_controller" + assert joy2.axes == [0.1, 0.2, 0.3] + assert joy2.buttons == [1, 0, 1] + + # Test tuple initialization + joy3 = Joy(([0.5, -0.5], [1, 1, 0])) + assert joy3.axes == [0.5, -0.5] + assert joy3.buttons == [1, 1, 0] + + # Test dict initialization + joy4 = Joy({"axes": [0.7, 0.8], "buttons": [0, 1], "frame_id": "ps4_controller"}) + assert joy4.axes == [0.7, 0.8] + assert joy4.buttons == [0, 1] + assert joy4.frame_id == "ps4_controller" + + # Test copy constructor + joy5 = Joy(joy2) + assert joy5.ts == joy2.ts + assert joy5.frame_id == joy2.frame_id + assert joy5.axes == joy2.axes + assert joy5.buttons == joy2.buttons + assert joy5 is not joy2 # Different objects + + print("✓ Joy initialization methods test passed") + + +def test_equality() -> None: + """Test Joy equality comparison.""" + print("Testing Joy equality...") + + joy1 = Joy(ts=1000.0, frame_id="controller1", axes=[0.5, -0.5], buttons=[1, 0, 1]) + + joy2 = Joy(ts=1000.0, frame_id="controller1", axes=[0.5, -0.5], buttons=[1, 0, 1]) + + joy3 = Joy( + ts=1000.0, + frame_id="controller2", # Different frame_id + axes=[0.5, -0.5], + buttons=[1, 0, 1], + ) + + joy4 = Joy( + ts=1000.0, + frame_id="controller1", + axes=[0.6, -0.5], # Different axes + buttons=[1, 0, 1], + ) + + # Same content should be equal + assert joy1 == joy2 + + # Different frame_id should not be equal + assert joy1 != joy3 + + # Different axes should not be equal + assert joy1 != joy4 + + # Different type should not be equal + assert joy1 != "not a joy" + assert joy1 != 42 + + print("✓ Joy equality test passed") + + +def test_string_representation() -> None: + """Test Joy string representations.""" + print("Testing Joy string representations...") + + joy = Joy( + ts=1234567890.123, + frame_id="test_controller", + axes=[0.1, -0.2, 0.3, 0.4], + buttons=[1, 0, 1, 0, 0, 1], + ) + + # Test __str__ + str_repr = str(joy) + assert "Joy" in str_repr + assert "axes=4 values" in str_repr + assert "buttons=6 values" in str_repr + assert "test_controller" in str_repr + + # Test __repr__ + repr_str = repr(joy) + assert "Joy" in repr_str + assert "1234567890.123" in repr_str + assert "test_controller" in repr_str + assert "[0.1, -0.2, 0.3, 0.4]" in repr_str + assert "[1, 0, 1, 0, 0, 1]" in repr_str + + print("✓ Joy string representation test passed") + + +@pytest.mark.ros +def test_ros_conversion() -> None: + """Test conversion to/from ROS Joy messages.""" + print("Testing Joy ROS conversion...") + + # Create a ROS Joy message + ros_msg = ROSJoy() + ros_msg.header = ROSHeader() + ros_msg.header.stamp.sec = 1234567890 + ros_msg.header.stamp.nanosec = 123456789 + ros_msg.header.frame_id = "ros_gamepad" + ros_msg.axes = [0.25, -0.75, 0.0, 1.0, -1.0] + ros_msg.buttons = [1, 1, 0, 0, 1, 0, 1, 0] + + # Convert from ROS + joy = Joy.from_ros_msg(ros_msg) + assert abs(joy.ts - 1234567890.123456789) < 1e-9 + assert joy.frame_id == "ros_gamepad" + assert joy.axes == [0.25, -0.75, 0.0, 1.0, -1.0] + assert joy.buttons == [1, 1, 0, 0, 1, 0, 1, 0] + + # Convert back to ROS + ros_msg2 = joy.to_ros_msg() + assert ros_msg2.header.frame_id == "ros_gamepad" + assert ros_msg2.header.stamp.sec == 1234567890 + assert abs(ros_msg2.header.stamp.nanosec - 123456789) < 100 # Allow small rounding + assert list(ros_msg2.axes) == [0.25, -0.75, 0.0, 1.0, -1.0] + assert list(ros_msg2.buttons) == [1, 1, 0, 0, 1, 0, 1, 0] + + print("✓ Joy ROS conversion test passed") + + +def test_edge_cases() -> None: + """Test Joy with edge cases.""" + print("Testing Joy edge cases...") + + # Empty axes and buttons + joy1 = Joy(axes=[], buttons=[]) + assert joy1.axes == [] + assert joy1.buttons == [] + encoded = joy1.lcm_encode() + decoded = Joy.lcm_decode(encoded) + assert decoded.axes == [] + assert decoded.buttons == [] + + # Large number of axes and buttons + many_axes = [float(i) / 100.0 for i in range(20)] + many_buttons = [i % 2 for i in range(32)] + joy2 = Joy(axes=many_axes, buttons=many_buttons) + assert len(joy2.axes) == 20 + assert len(joy2.buttons) == 32 + encoded = joy2.lcm_encode() + decoded = Joy.lcm_decode(encoded) + # Check axes with floating point tolerance + assert len(decoded.axes) == len(many_axes) + for i, (a, b) in enumerate(zip(decoded.axes, many_axes, strict=False)): + assert abs(a - b) < 1e-6, f"Axis {i}: {a} != {b}" + assert decoded.buttons == many_buttons + + # Extreme axis values + extreme_axes = [-1.0, 1.0, 0.0, -0.999999, 0.999999] + joy3 = Joy(axes=extreme_axes) + assert joy3.axes == extreme_axes + + print("✓ Joy edge cases test passed") diff --git a/dimos/msgs/sensor_msgs/test_PointCloud2.py b/dimos/msgs/sensor_msgs/test_PointCloud2.py new file mode 100644 index 0000000000..37090cb57f --- /dev/null +++ b/dimos/msgs/sensor_msgs/test_PointCloud2.py @@ -0,0 +1,311 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np +import pytest + +try: + from sensor_msgs.msg import PointCloud2 as ROSPointCloud2, PointField as ROSPointField + from std_msgs.msg import Header as ROSHeader +except ImportError: + ROSPointCloud2 = None + ROSPointField = None + ROSHeader = None + +from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.utils.testing import SensorReplay + +# Try to import ROS types for testing +try: + ROS_AVAILABLE = True +except ImportError: + ROS_AVAILABLE = False + + +def test_lcm_encode_decode() -> None: + """Test LCM encode/decode preserves pointcloud data.""" + replay = SensorReplay("office_lidar", autocast=LidarMessage.from_msg) + lidar_msg: LidarMessage = replay.load_one("lidar_data_021") + + binary_msg = lidar_msg.lcm_encode() + decoded = PointCloud2.lcm_decode(binary_msg) + + # 1. Check number of points + original_points = lidar_msg.as_numpy() + decoded_points = decoded.as_numpy() + + print(f"Original points: {len(original_points)}") + print(f"Decoded points: {len(decoded_points)}") + assert len(original_points) == len(decoded_points), ( + f"Point count mismatch: {len(original_points)} vs {len(decoded_points)}" + ) + + # 2. Check point coordinates are preserved (within floating point tolerance) + if len(original_points) > 0: + np.testing.assert_allclose( + original_points, + decoded_points, + rtol=1e-6, + atol=1e-6, + err_msg="Point coordinates don't match between original and decoded", + ) + print(f"✓ All {len(original_points)} point coordinates match within tolerance") + + # 3. Check frame_id is preserved + assert lidar_msg.frame_id == decoded.frame_id, ( + f"Frame ID mismatch: '{lidar_msg.frame_id}' vs '{decoded.frame_id}'" + ) + print(f"✓ Frame ID preserved: '{decoded.frame_id}'") + + # 4. Check timestamp is preserved (within reasonable tolerance for float precision) + if lidar_msg.ts is not None and decoded.ts is not None: + assert abs(lidar_msg.ts - decoded.ts) < 1e-6, ( + f"Timestamp mismatch: {lidar_msg.ts} vs {decoded.ts}" + ) + print(f"✓ Timestamp preserved: {decoded.ts}") + + # 5. Check pointcloud properties + assert len(lidar_msg.pointcloud.points) == len(decoded.pointcloud.points), ( + "Open3D pointcloud size mismatch" + ) + + # 6. Additional detailed checks + print("✓ Original pointcloud summary:") + print(f" - Points: {len(original_points)}") + print(f" - Bounds: {original_points.min(axis=0)} to {original_points.max(axis=0)}") + print(f" - Mean: {original_points.mean(axis=0)}") + + print("✓ Decoded pointcloud summary:") + print(f" - Points: {len(decoded_points)}") + print(f" - Bounds: {decoded_points.min(axis=0)} to {decoded_points.max(axis=0)}") + print(f" - Mean: {decoded_points.mean(axis=0)}") + + print("✓ LCM encode/decode test passed - all properties preserved!") + + +@pytest.mark.ros +def test_ros_conversion() -> None: + """Test ROS message conversion preserves pointcloud data.""" + if not ROS_AVAILABLE: + print("ROS packages not available - skipping ROS conversion test") + return + + print("\nTesting ROS PointCloud2 conversion...") + + # Create a simple test point cloud + import open3d as o3d + + points = np.array( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + [-1.0, -2.0, -3.0], + [0.5, 0.5, 0.5], + ], + dtype=np.float32, + ) + + pc = o3d.geometry.PointCloud() + pc.points = o3d.utility.Vector3dVector(points) + + # Create DIMOS PointCloud2 + original = PointCloud2( + pointcloud=pc, + frame_id="test_frame", + ts=1234567890.123456, + ) + + # Test 1: Convert to ROS and back + ros_msg = original.to_ros_msg() + converted = PointCloud2.from_ros_msg(ros_msg) + + # Check points are preserved + original_points = original.as_numpy() + converted_points = converted.as_numpy() + + assert len(original_points) == len(converted_points), ( + f"Point count mismatch: {len(original_points)} vs {len(converted_points)}" + ) + + np.testing.assert_allclose( + original_points, + converted_points, + rtol=1e-6, + atol=1e-6, + err_msg="Points don't match after ROS conversion", + ) + print(f"✓ Points preserved: {len(converted_points)} points match") + + # Check metadata + assert original.frame_id == converted.frame_id, ( + f"Frame ID mismatch: '{original.frame_id}' vs '{converted.frame_id}'" + ) + print(f"✓ Frame ID preserved: '{converted.frame_id}'") + + assert abs(original.ts - converted.ts) < 1e-6, ( + f"Timestamp mismatch: {original.ts} vs {converted.ts}" + ) + print(f"✓ Timestamp preserved: {converted.ts}") + + # Test 2: Create ROS message directly and convert to DIMOS + ros_msg2 = ROSPointCloud2() + ros_msg2.header = ROSHeader() + ros_msg2.header.frame_id = "ros_test_frame" + ros_msg2.header.stamp.sec = 1234567890 + ros_msg2.header.stamp.nanosec = 123456000 + + # Set up point cloud data + ros_msg2.height = 1 + ros_msg2.width = 3 + ros_msg2.fields = [ + ROSPointField(name="x", offset=0, datatype=ROSPointField.FLOAT32, count=1), + ROSPointField(name="y", offset=4, datatype=ROSPointField.FLOAT32, count=1), + ROSPointField(name="z", offset=8, datatype=ROSPointField.FLOAT32, count=1), + ] + ros_msg2.is_bigendian = False + ros_msg2.point_step = 12 + ros_msg2.row_step = 36 + + # Pack test points + test_points = np.array( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + [7.0, 8.0, 9.0], + ], + dtype=np.float32, + ) + ros_msg2.data = test_points.tobytes() + ros_msg2.is_dense = True + + # Convert to DIMOS + dimos_pc = PointCloud2.from_ros_msg(ros_msg2) + + assert dimos_pc.frame_id == "ros_test_frame", ( + f"Frame ID not preserved: expected 'ros_test_frame', got '{dimos_pc.frame_id}'" + ) + + decoded_points = dimos_pc.as_numpy() + assert len(decoded_points) == 3, ( + f"Wrong number of points: expected 3, got {len(decoded_points)}" + ) + + np.testing.assert_allclose( + test_points, + decoded_points, + rtol=1e-6, + atol=1e-6, + err_msg="Points from ROS message don't match", + ) + print("✓ ROS to DIMOS conversion works correctly") + + # Test 3: Empty point cloud + empty_pc = PointCloud2( + pointcloud=o3d.geometry.PointCloud(), + frame_id="empty_frame", + ts=1234567890.0, + ) + + empty_ros = empty_pc.to_ros_msg() + assert empty_ros.width == 0, "Empty cloud should have width 0" + assert empty_ros.height == 0, "Empty cloud should have height 0" + assert len(empty_ros.data) == 0, "Empty cloud should have no data" + + empty_converted = PointCloud2.from_ros_msg(empty_ros) + assert len(empty_converted) == 0, "Empty cloud conversion failed" + print("✓ Empty point cloud handling works") + + print("\n✓ All ROS conversion tests passed!") + + +def test_bounding_box_intersects() -> None: + """Test bounding_box_intersects method with various scenarios.""" + # Test 1: Overlapping boxes + pc1 = PointCloud2.from_numpy(np.array([[0, 0, 0], [2, 2, 2]])) + pc2 = PointCloud2.from_numpy(np.array([[1, 1, 1], [3, 3, 3]])) + assert pc1.bounding_box_intersects(pc2) + assert pc2.bounding_box_intersects(pc1) # Should be symmetric + + # Test 2: Non-overlapping boxes + pc3 = PointCloud2.from_numpy(np.array([[0, 0, 0], [1, 1, 1]])) + pc4 = PointCloud2.from_numpy(np.array([[2, 2, 2], [3, 3, 3]])) + assert not pc3.bounding_box_intersects(pc4) + assert not pc4.bounding_box_intersects(pc3) + + # Test 3: Touching boxes (edge case - should be True) + pc5 = PointCloud2.from_numpy(np.array([[0, 0, 0], [1, 1, 1]])) + pc6 = PointCloud2.from_numpy(np.array([[1, 1, 1], [2, 2, 2]])) + assert pc5.bounding_box_intersects(pc6) + assert pc6.bounding_box_intersects(pc5) + + # Test 4: One box completely inside another + pc7 = PointCloud2.from_numpy(np.array([[0, 0, 0], [3, 3, 3]])) + pc8 = PointCloud2.from_numpy(np.array([[1, 1, 1], [2, 2, 2]])) + assert pc7.bounding_box_intersects(pc8) + assert pc8.bounding_box_intersects(pc7) + + # Test 5: Boxes overlapping only in 2 dimensions (not all 3) + pc9 = PointCloud2.from_numpy(np.array([[0, 0, 0], [2, 2, 1]])) + pc10 = PointCloud2.from_numpy(np.array([[1, 1, 2], [3, 3, 3]])) + assert not pc9.bounding_box_intersects(pc10) + assert not pc10.bounding_box_intersects(pc9) + + # Test 6: Real-world detection scenario with floating point coordinates + detection1_points = np.array( + [[-3.5, -0.3, 0.1], [-3.3, -0.2, 0.1], [-3.5, -0.3, 0.3], [-3.3, -0.2, 0.3]] + ) + pc_det1 = PointCloud2.from_numpy(detection1_points) + + detection2_points = np.array( + [[-3.4, -0.25, 0.15], [-3.2, -0.15, 0.15], [-3.4, -0.25, 0.35], [-3.2, -0.15, 0.35]] + ) + pc_det2 = PointCloud2.from_numpy(detection2_points) + + assert pc_det1.bounding_box_intersects(pc_det2) + + # Test 7: Single point clouds + pc_single1 = PointCloud2.from_numpy(np.array([[1.0, 1.0, 1.0]])) + pc_single2 = PointCloud2.from_numpy(np.array([[1.0, 1.0, 1.0]])) + pc_single3 = PointCloud2.from_numpy(np.array([[2.0, 2.0, 2.0]])) + + # Same point should intersect + assert pc_single1.bounding_box_intersects(pc_single2) + # Different points should not intersect + assert not pc_single1.bounding_box_intersects(pc_single3) + + # Test 8: Empty point clouds + pc_empty1 = PointCloud2.from_numpy(np.array([]).reshape(0, 3)) + pc_empty2 = PointCloud2.from_numpy(np.array([]).reshape(0, 3)) + PointCloud2.from_numpy(np.array([[1.0, 1.0, 1.0]])) + + # Empty clouds should handle gracefully (Open3D returns inf bounds) + # This might raise an exception or return False - we should handle gracefully + try: + result = pc_empty1.bounding_box_intersects(pc_empty2) + # If no exception, verify behavior is consistent + assert isinstance(result, bool) + except: + # If it raises an exception, that's also acceptable for empty clouds + pass + + print("✓ All bounding box intersection tests passed!") + + +if __name__ == "__main__": + test_lcm_encode_decode() + test_ros_conversion() + test_bounding_box_intersects() diff --git a/dimos/msgs/sensor_msgs/test_image.py b/dimos/msgs/sensor_msgs/test_image.py new file mode 100644 index 0000000000..24375139b3 --- /dev/null +++ b/dimos/msgs/sensor_msgs/test_image.py @@ -0,0 +1,148 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest +from reactivex import operators as ops + +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat, sharpness_barrier +from dimos.utils.data import get_data +from dimos.utils.testing import TimedSensorReplay + + +@pytest.fixture +def img(): + image_file_path = get_data("cafe.jpg") + return Image.from_file(str(image_file_path)) + + +def test_file_load(img: Image) -> None: + assert isinstance(img.data, np.ndarray) + assert img.width == 1024 + assert img.height == 771 + assert img.channels == 3 + assert img.shape == (771, 1024, 3) + assert img.data.dtype == np.uint8 + assert img.format == ImageFormat.BGR + assert img.frame_id == "" + assert isinstance(img.ts, float) + assert img.ts > 0 + assert img.data.flags["C_CONTIGUOUS"] + + +def test_lcm_encode_decode(img: Image) -> None: + binary_msg = img.lcm_encode() + decoded_img = Image.lcm_decode(binary_msg) + + assert isinstance(decoded_img, Image) + assert decoded_img is not img + assert decoded_img == img + + +def test_rgb_bgr_conversion(img: Image) -> None: + rgb = img.to_rgb() + assert not rgb == img + assert rgb.to_bgr() == img + + +def test_opencv_conversion(img: Image) -> None: + ocv = img.to_opencv() + decoded_img = Image.from_opencv(ocv) + + # artificially patch timestamp + decoded_img.ts = img.ts + assert decoded_img == img + + +@pytest.mark.tool +def test_sharpness_stream() -> None: + get_data("unitree_office_walk") # Preload data for testing + video_store = TimedSensorReplay( + "unitree_office_walk/video", autocast=lambda x: Image.from_numpy(x).to_rgb() + ) + + cnt = 0 + for image in video_store.iterate(): + cnt = cnt + 1 + print(image.sharpness) + if cnt > 30: + return + + +def test_sharpness_barrier() -> None: + import time + from unittest.mock import MagicMock + + # Create mock images with known sharpness values + # This avoids loading real data from disk + mock_images = [] + sharpness_values = [0.3711, 0.3241, 0.3067, 0.2583, 0.3665] # Just 5 images for 1 window + + for i, sharp in enumerate(sharpness_values): + img = MagicMock() + img.sharpness = sharp + img.ts = 1758912038.208 + i * 0.01 # Simulate timestamps + mock_images.append(img) + + # Track what goes into windows and what comes out + start_wall_time = None + window_contents = [] # List of (wall_time, image) + emitted_images = [] + + def track_input(img): + """Track all images going into sharpness_barrier with wall-clock time""" + nonlocal start_wall_time + wall_time = time.time() + if start_wall_time is None: + start_wall_time = wall_time + relative_time = wall_time - start_wall_time + window_contents.append((relative_time, img)) + return img + + def track_output(img) -> None: + """Track what sharpness_barrier emits""" + emitted_images.append(img) + + # Use 20Hz frequency (0.05s windows) for faster test + # Emit images at 100Hz to get ~5 per window + from reactivex import from_iterable, interval + + source = from_iterable(mock_images).pipe( + ops.zip(interval(0.01)), # 100Hz emission rate + ops.map(lambda x: x[0]), # Extract just the image + ) + + source.pipe( + ops.do_action(track_input), # Track inputs + sharpness_barrier(20), # 20Hz = 0.05s windows + ops.do_action(track_output), # Track outputs + ).run() + + # Only need 0.08s for 1 full window at 20Hz plus buffer + time.sleep(0.08) + + # Verify we got correct emissions (items span across 2 windows due to timing) + # Items 1-4 arrive in first window (0-50ms), item 5 arrives in second window (50-100ms) + assert len(emitted_images) == 2, ( + f"Expected exactly 2 emissions (one per window), got {len(emitted_images)}" + ) + + # Group inputs by wall-clock windows and verify we got the sharpest + + # Verify each window emitted the sharpest image from that window + # First window (0-50ms): items 1-4 + assert emitted_images[0].sharpness == 0.3711 # Highest among first 4 items + + # Second window (50-100ms): only item 5 + assert emitted_images[1].sharpness == 0.3665 # Only item in second window diff --git a/dimos/msgs/std_msgs/Bool.py b/dimos/msgs/std_msgs/Bool.py new file mode 100644 index 0000000000..b260b8a340 --- /dev/null +++ b/dimos/msgs/std_msgs/Bool.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Bool message type.""" + +from dimos_lcm.std_msgs import Bool as LCMBool # type: ignore[import-untyped] + +try: + from std_msgs.msg import Bool as ROSBool # type: ignore[attr-defined] +except ImportError: + ROSBool = None # type: ignore[assignment, misc] + + +class Bool(LCMBool): # type: ignore[misc] + """ROS-compatible Bool message.""" + + msg_name = "std_msgs.Bool" + + def __init__(self, data: bool = False) -> None: + """Initialize Bool with data value.""" + self.data = data + + @classmethod + def from_ros_msg(cls, ros_msg: ROSBool) -> "Bool": + """Create a Bool from a ROS std_msgs/Bool message. + + Args: + ros_msg: ROS Bool message + + Returns: + Bool instance + """ + return cls(data=ros_msg.data) + + def to_ros_msg(self) -> ROSBool: + """Convert to a ROS std_msgs/Bool message. + + Returns: + ROS Bool message + """ + if ROSBool is None: + raise ImportError("ROS std_msgs not available") + ros_msg = ROSBool() # type: ignore[no-untyped-call] + ros_msg.data = bool(self.data) + return ros_msg diff --git a/dimos/msgs/std_msgs/Header.py b/dimos/msgs/std_msgs/Header.py new file mode 100644 index 0000000000..b80f767514 --- /dev/null +++ b/dimos/msgs/std_msgs/Header.py @@ -0,0 +1,106 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from datetime import datetime +import time + +from dimos_lcm.std_msgs import Header as LCMHeader, Time as LCMTime # type: ignore[import-untyped] +from plum import dispatch + +# Import the actual LCM header type that's returned from decoding +try: + from lcm_msgs.std_msgs.Header import ( # type: ignore[import-not-found] + Header as DecodedLCMHeader, + ) +except ImportError: + DecodedLCMHeader = None + + +class Header(LCMHeader): # type: ignore[misc] + msg_name = "std_msgs.Header" + ts: float + + @dispatch + def __init__(self) -> None: + """Initialize a Header with current time and empty frame_id.""" + self.ts = time.time() + sec = int(self.ts) + nsec = int((self.ts - sec) * 1_000_000_000) + super().__init__(seq=0, stamp=LCMTime(sec=sec, nsec=nsec), frame_id="") + + @dispatch # type: ignore[no-redef] + def __init__(self, frame_id: str) -> None: + """Initialize a Header with current time and specified frame_id.""" + self.ts = time.time() + sec = int(self.ts) + nsec = int((self.ts - sec) * 1_000_000_000) + super().__init__(seq=1, stamp=LCMTime(sec=sec, nsec=nsec), frame_id=frame_id) + + @dispatch # type: ignore[no-redef] + def __init__(self, timestamp: float, frame_id: str = "", seq: int = 1) -> None: + """Initialize a Header with Unix timestamp, frame_id, and optional seq.""" + sec = int(timestamp) + nsec = int((timestamp - sec) * 1_000_000_000) + super().__init__(seq=seq, stamp=LCMTime(sec=sec, nsec=nsec), frame_id=frame_id) + + @dispatch # type: ignore[no-redef] + def __init__(self, timestamp: datetime, frame_id: str = "") -> None: + """Initialize a Header with datetime object and frame_id.""" + self.ts = timestamp.timestamp() + sec = int(self.ts) + nsec = int((self.ts - sec) * 1_000_000_000) + super().__init__(seq=1, stamp=LCMTime(sec=sec, nsec=nsec), frame_id=frame_id) + + @dispatch # type: ignore[no-redef] + def __init__(self, seq: int, stamp: LCMTime, frame_id: str) -> None: + """Initialize with explicit seq, stamp, and frame_id (LCM compatibility).""" + super().__init__(seq=seq, stamp=stamp, frame_id=frame_id) + + @dispatch # type: ignore[no-redef] + def __init__(self, header: LCMHeader) -> None: + """Initialize from another Header (copy constructor).""" + super().__init__(seq=header.seq, stamp=header.stamp, frame_id=header.frame_id) + + @dispatch # type: ignore[no-redef] + def __init__(self, header: object) -> None: + """Initialize from a decoded LCM header object.""" + # Handle the case where we get an lcm_msgs.std_msgs.Header.Header object + if hasattr(header, "seq") and hasattr(header, "stamp") and hasattr(header, "frame_id"): + super().__init__(seq=header.seq, stamp=header.stamp, frame_id=header.frame_id) + else: + raise ValueError(f"Cannot create Header from {type(header)}") + + @classmethod + def now(cls, frame_id: str = "", seq: int = 1) -> Header: + """Create a Header with current timestamp.""" + ts = time.time() + return cls(ts, frame_id, seq) + + @property + def timestamp(self) -> float: + """Get timestamp as Unix time (float).""" + return self.stamp.sec + (self.stamp.nsec / 1_000_000_000) # type: ignore[no-any-return] + + @property + def datetime(self) -> datetime: + """Get timestamp as datetime object.""" + return datetime.fromtimestamp(self.timestamp) + + def __str__(self) -> str: + return f"Header(seq={self.seq}, time={self.timestamp:.6f}, frame_id='{self.frame_id}')" + + def __repr__(self) -> str: + return f"Header(seq={self.seq}, stamp=Time(sec={self.stamp.sec}, nsec={self.stamp.nsec}), frame_id='{self.frame_id}')" diff --git a/dimos/msgs/std_msgs/Int32.py b/dimos/msgs/std_msgs/Int32.py new file mode 100644 index 0000000000..4ebea095e3 --- /dev/null +++ b/dimos/msgs/std_msgs/Int32.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2025 Dimensional Inc. + +"""Int32 message type.""" + +from typing import ClassVar + +from dimos_lcm.std_msgs import Int32 as LCMInt32 # type: ignore[import-untyped] + + +class Int32(LCMInt32): # type: ignore[misc] + """ROS-compatible Int32 message.""" + + msg_name: ClassVar[str] = "std_msgs.Int32" + + def __init__(self, data: int = 0) -> None: + """Initialize Int32 with data value.""" + self.data = data diff --git a/dimos/msgs/std_msgs/Int8.py b/dimos/msgs/std_msgs/Int8.py new file mode 100644 index 0000000000..5fb87ba769 --- /dev/null +++ b/dimos/msgs/std_msgs/Int8.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2025 Dimensional Inc. + +"""Int32 message type.""" + +from typing import ClassVar + +from dimos_lcm.std_msgs import Int8 as LCMInt8 # type: ignore[import-untyped] + +try: + from std_msgs.msg import Int8 as ROSInt8 # type: ignore[attr-defined] +except ImportError: + ROSInt8 = None # type: ignore[assignment, misc] + + +class Int8(LCMInt8): # type: ignore[misc] + """ROS-compatible Int32 message.""" + + msg_name: ClassVar[str] = "std_msgs.Int8" + + def __init__(self, data: int = 0) -> None: + """Initialize Int8 with data value.""" + self.data = data + + @classmethod + def from_ros_msg(cls, ros_msg: ROSInt8) -> "Int8": + """Create a Bool from a ROS std_msgs/Bool message. + + Args: + ros_msg: ROS Int8 message + + Returns: + Int8 instance + """ + return cls(data=ros_msg.data) + + def to_ros_msg(self) -> ROSInt8: + """Convert to a ROS std_msgs/Bool message. + + Returns: + ROS Int8 message + """ + if ROSInt8 is None: + raise ImportError("ROS std_msgs not available") + ros_msg = ROSInt8() # type: ignore[no-untyped-call] + ros_msg.data = self.data + return ros_msg diff --git a/dimos/msgs/std_msgs/__init__.py b/dimos/msgs/std_msgs/__init__.py new file mode 100644 index 0000000000..e517ea1864 --- /dev/null +++ b/dimos/msgs/std_msgs/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .Bool import Bool +from .Header import Header +from .Int8 import Int8 +from .Int32 import Int32 + +__all__ = ["Bool", "Header", "Int8", "Int32"] diff --git a/dimos/msgs/std_msgs/test_header.py b/dimos/msgs/std_msgs/test_header.py new file mode 100644 index 0000000000..93f20da283 --- /dev/null +++ b/dimos/msgs/std_msgs/test_header.py @@ -0,0 +1,98 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import datetime +import time + +from dimos.msgs.std_msgs import Header + + +def test_header_initialization_methods() -> None: + """Test various ways to initialize a Header.""" + + # Method 1: With timestamp and frame_id + header1 = Header(123.456, "world") + assert header1.seq == 1 + assert header1.stamp.sec == 123 + assert header1.stamp.nsec == 456000000 + assert header1.frame_id == "world" + + # Method 2: With just frame_id (uses current time) + header2 = Header("base_link") + assert header2.seq == 1 + assert header2.frame_id == "base_link" + # Timestamp should be close to current time + assert abs(header2.timestamp - time.time()) < 0.1 + + # Method 3: Empty header (current time, empty frame_id) + header3 = Header() + assert header3.seq == 0 + assert header3.frame_id == "" + + # Method 4: With datetime object + dt = datetime(2025, 1, 18, 12, 30, 45, 500000) # 500ms + header4 = Header(dt, "sensor") + assert header4.seq == 1 + assert header4.frame_id == "sensor" + expected_timestamp = dt.timestamp() + assert abs(header4.timestamp - expected_timestamp) < 1e-6 + + # Method 5: With custom seq number + header5 = Header(999.123, "custom", seq=42) + assert header5.seq == 42 + assert header5.stamp.sec == 999 + assert header5.stamp.nsec == 123000000 + assert header5.frame_id == "custom" + + # Method 6: Using now() class method + header6 = Header.now("camera") + assert header6.seq == 1 + assert header6.frame_id == "camera" + assert abs(header6.timestamp - time.time()) < 0.1 + + # Method 7: now() with custom seq + header7 = Header.now("lidar", seq=99) + assert header7.seq == 99 + assert header7.frame_id == "lidar" + + +def test_header_properties() -> None: + """Test Header property accessors.""" + header = Header(1234567890.123456789, "test") + + # Test timestamp property + assert abs(header.timestamp - 1234567890.123456789) < 1e-6 + + # Test datetime property + dt = header.datetime + assert isinstance(dt, datetime) + assert abs(dt.timestamp() - 1234567890.123456789) < 1e-6 + + +def test_header_string_representation() -> None: + """Test Header string representations.""" + header = Header(100.5, "map", seq=10) + + # Test __str__ + str_repr = str(header) + assert "seq=10" in str_repr + assert "time=100.5" in str_repr + assert "frame_id='map'" in str_repr + + # Test __repr__ + repr_str = repr(header) + assert "Header(" in repr_str + assert "seq=10" in repr_str + assert "Time(sec=100, nsec=500000000)" in repr_str + assert "frame_id='map'" in repr_str diff --git a/dimos/msgs/tf2_msgs/TFMessage.py b/dimos/msgs/tf2_msgs/TFMessage.py new file mode 100644 index 0000000000..91446bb28e --- /dev/null +++ b/dimos/msgs/tf2_msgs/TFMessage.py @@ -0,0 +1,161 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, BinaryIO + +from dimos_lcm.tf2_msgs import TFMessage as LCMTFMessage # type: ignore[import-untyped] + +try: + from geometry_msgs.msg import ( # type: ignore[attr-defined] + TransformStamped as ROSTransformStamped, + ) + from tf2_msgs.msg import TFMessage as ROSTFMessage # type: ignore[attr-defined] +except ImportError: + ROSTFMessage = None # type: ignore[assignment, misc] + ROSTransformStamped = None # type: ignore[assignment, misc] + +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 + +if TYPE_CHECKING: + from collections.abc import Iterator + + +class TFMessage: + """TFMessage that accepts Transform objects and encodes to LCM format.""" + + transforms: list[Transform] + msg_name = "tf2_msgs.TFMessage" + + def __init__(self, *transforms: Transform) -> None: + self.transforms = list(transforms) + + def add_transform(self, transform: Transform, child_frame_id: str = "base_link") -> None: + """Add a transform to the message.""" + self.transforms.append(transform) + self.transforms_length = len(self.transforms) + + def lcm_encode(self) -> bytes: + """Encode as LCM TFMessage. + + Args: + child_frame_ids: Optional list of child frame IDs for each transform. + If not provided, defaults to "base_link" for all. + """ + + res = list(map(lambda t: t.lcm_transform(), self.transforms)) + + lcm_msg = LCMTFMessage( + transforms_length=len(self.transforms), + transforms=res, + ) + + return lcm_msg.lcm_encode() # type: ignore[no-any-return] + + @classmethod + def lcm_decode(cls, data: bytes | BinaryIO) -> TFMessage: + """Decode from LCM TFMessage bytes.""" + lcm_msg = LCMTFMessage.lcm_decode(data) + + # Convert LCM TransformStamped objects to Transform objects + transforms = [] + for lcm_transform_stamped in lcm_msg.transforms: + # Extract timestamp + ts = lcm_transform_stamped.header.stamp.sec + ( + lcm_transform_stamped.header.stamp.nsec / 1_000_000_000 + ) + + # Create Transform with our custom types + lcm_trans = lcm_transform_stamped.transform.translation + lcm_rot = lcm_transform_stamped.transform.rotation + + transform = Transform( + translation=Vector3(lcm_trans.x, lcm_trans.y, lcm_trans.z), + rotation=Quaternion(lcm_rot.x, lcm_rot.y, lcm_rot.z, lcm_rot.w), + frame_id=lcm_transform_stamped.header.frame_id, + child_frame_id=lcm_transform_stamped.child_frame_id, + ts=ts, + ) + transforms.append(transform) + + return cls(*transforms) + + def __len__(self) -> int: + """Return number of transforms.""" + return len(self.transforms) + + def __getitem__(self, index: int) -> Transform: + """Get transform by index.""" + return self.transforms[index] + + def __iter__(self) -> Iterator: # type: ignore[type-arg] + """Iterate over transforms.""" + return iter(self.transforms) + + def __repr__(self) -> str: + return f"TFMessage({len(self.transforms)} transforms)" + + def __str__(self) -> str: + lines = [f"TFMessage with {len(self.transforms)} transforms:"] + for i, transform in enumerate(self.transforms): + lines.append(f" [{i}] {transform.frame_id} @ {transform.ts:.3f}") + return "\n".join(lines) + + @classmethod + def from_ros_msg(cls, ros_msg: ROSTFMessage) -> TFMessage: + """Create a TFMessage from a ROS tf2_msgs/TFMessage message. + + Args: + ros_msg: ROS TFMessage message + + Returns: + TFMessage instance + """ + transforms = [] + for ros_transform_stamped in ros_msg.transforms: + # Convert from ROS TransformStamped to our Transform + transform = Transform.from_ros_transform_stamped(ros_transform_stamped) + transforms.append(transform) + + return cls(*transforms) + + def to_ros_msg(self) -> ROSTFMessage: + """Convert to a ROS tf2_msgs/TFMessage message. + + Returns: + ROS TFMessage message + """ + ros_msg = ROSTFMessage() # type: ignore[no-untyped-call] + + # Convert each Transform to ROS TransformStamped + for transform in self.transforms: + ros_msg.transforms.append(transform.to_ros_transform_stamped()) + + return ros_msg diff --git a/dimos/msgs/tf2_msgs/__init__.py b/dimos/msgs/tf2_msgs/__init__.py new file mode 100644 index 0000000000..683e4ec61b --- /dev/null +++ b/dimos/msgs/tf2_msgs/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.msgs.tf2_msgs.TFMessage import TFMessage + +__all__ = ["TFMessage"] diff --git a/dimos/msgs/tf2_msgs/test_TFMessage.py b/dimos/msgs/tf2_msgs/test_TFMessage.py new file mode 100644 index 0000000000..783692fb35 --- /dev/null +++ b/dimos/msgs/tf2_msgs/test_TFMessage.py @@ -0,0 +1,269 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +try: + from geometry_msgs.msg import TransformStamped as ROSTransformStamped + from tf2_msgs.msg import TFMessage as ROSTFMessage +except ImportError: + ROSTransformStamped = None + ROSTFMessage = None + +from dimos_lcm.tf2_msgs import TFMessage as LCMTFMessage + +from dimos.msgs.geometry_msgs import Quaternion, Transform, Vector3 +from dimos.msgs.tf2_msgs import TFMessage + + +def test_tfmessage_initialization() -> None: + """Test TFMessage initialization with Transform objects.""" + # Create some transforms + tf1 = Transform( + translation=Vector3(1, 2, 3), rotation=Quaternion(0, 0, 0, 1), frame_id="world", ts=100.0 + ) + tf2 = Transform( + translation=Vector3(4, 5, 6), + rotation=Quaternion(0, 0, 0.707, 0.707), + frame_id="map", + ts=101.0, + ) + + # Create TFMessage with transforms + msg = TFMessage(tf1, tf2) + + assert len(msg) == 2 + assert msg[0] == tf1 + assert msg[1] == tf2 + + # Test iteration + transforms = list(msg) + assert transforms == [tf1, tf2] + + +def test_tfmessage_empty() -> None: + """Test empty TFMessage.""" + msg = TFMessage() + assert len(msg) == 0 + assert list(msg) == [] + + +def test_tfmessage_add_transform() -> None: + """Test adding transforms to TFMessage.""" + msg = TFMessage() + + tf = Transform(translation=Vector3(1, 2, 3), frame_id="base", ts=200.0) + + msg.add_transform(tf) + assert len(msg) == 1 + assert msg[0] == tf + + +def test_tfmessage_lcm_encode_decode() -> None: + """Test encoding TFMessage to LCM bytes.""" + # Create transforms + tf1 = Transform( + translation=Vector3(1.0, 2.0, 3.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + child_frame_id="robot", + frame_id="world", + ts=123.456, + ) + tf2 = Transform( + translation=Vector3(4.0, 5.0, 6.0), + rotation=Quaternion(0.0, 0.0, 0.707, 0.707), + frame_id="robot", + child_frame_id="target", + ts=124.567, + ) + + # Create TFMessage + msg = TFMessage(tf1, tf2) + + # Encode with custom child_frame_ids + encoded = msg.lcm_encode() + + # Decode using LCM to verify + lcm_msg = LCMTFMessage.lcm_decode(encoded) + + assert lcm_msg.transforms_length == 2 + + # Check first transform + ts1 = lcm_msg.transforms[0] + assert ts1.header.frame_id == "world" + assert ts1.child_frame_id == "robot" + assert ts1.header.stamp.sec == 123 + assert ts1.header.stamp.nsec == 456000000 + assert ts1.transform.translation.x == 1.0 + assert ts1.transform.translation.y == 2.0 + assert ts1.transform.translation.z == 3.0 + + # Check second transform + ts2 = lcm_msg.transforms[1] + assert ts2.header.frame_id == "robot" + assert ts2.child_frame_id == "target" + assert ts2.transform.rotation.z == 0.707 + assert ts2.transform.rotation.w == 0.707 + + +@pytest.mark.ros +def test_tfmessage_from_ros_msg() -> None: + """Test creating a TFMessage from a ROS TFMessage message.""" + + ros_msg = ROSTFMessage() + + # Add first transform + tf1 = ROSTransformStamped() + tf1.header.frame_id = "world" + tf1.header.stamp.sec = 123 + tf1.header.stamp.nanosec = 456000000 + tf1.child_frame_id = "robot" + tf1.transform.translation.x = 1.0 + tf1.transform.translation.y = 2.0 + tf1.transform.translation.z = 3.0 + tf1.transform.rotation.x = 0.0 + tf1.transform.rotation.y = 0.0 + tf1.transform.rotation.z = 0.0 + tf1.transform.rotation.w = 1.0 + ros_msg.transforms.append(tf1) + + # Add second transform + tf2 = ROSTransformStamped() + tf2.header.frame_id = "robot" + tf2.header.stamp.sec = 124 + tf2.header.stamp.nanosec = 567000000 + tf2.child_frame_id = "sensor" + tf2.transform.translation.x = 4.0 + tf2.transform.translation.y = 5.0 + tf2.transform.translation.z = 6.0 + tf2.transform.rotation.x = 0.0 + tf2.transform.rotation.y = 0.0 + tf2.transform.rotation.z = 0.707 + tf2.transform.rotation.w = 0.707 + ros_msg.transforms.append(tf2) + + # Convert to TFMessage + tfmsg = TFMessage.from_ros_msg(ros_msg) + + assert len(tfmsg) == 2 + + # Check first transform + assert tfmsg[0].frame_id == "world" + assert tfmsg[0].child_frame_id == "robot" + assert tfmsg[0].ts == 123.456 + assert tfmsg[0].translation.x == 1.0 + assert tfmsg[0].translation.y == 2.0 + assert tfmsg[0].translation.z == 3.0 + assert tfmsg[0].rotation.w == 1.0 + + # Check second transform + assert tfmsg[1].frame_id == "robot" + assert tfmsg[1].child_frame_id == "sensor" + assert tfmsg[1].ts == 124.567 + assert tfmsg[1].translation.x == 4.0 + assert tfmsg[1].translation.y == 5.0 + assert tfmsg[1].translation.z == 6.0 + assert tfmsg[1].rotation.z == 0.707 + assert tfmsg[1].rotation.w == 0.707 + + +@pytest.mark.ros +def test_tfmessage_to_ros_msg() -> None: + """Test converting a TFMessage to a ROS TFMessage message.""" + # Create transforms + tf1 = Transform( + translation=Vector3(1.0, 2.0, 3.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id="map", + child_frame_id="base_link", + ts=123.456, + ) + tf2 = Transform( + translation=Vector3(7.0, 8.0, 9.0), + rotation=Quaternion(0.1, 0.2, 0.3, 0.9), + frame_id="base_link", + child_frame_id="lidar", + ts=125.789, + ) + + tfmsg = TFMessage(tf1, tf2) + + # Convert to ROS message + ros_msg = tfmsg.to_ros_msg() + + assert isinstance(ros_msg, ROSTFMessage) + assert len(ros_msg.transforms) == 2 + + # Check first transform + assert ros_msg.transforms[0].header.frame_id == "map" + assert ros_msg.transforms[0].child_frame_id == "base_link" + assert ros_msg.transforms[0].header.stamp.sec == 123 + assert ros_msg.transforms[0].header.stamp.nanosec == 456000000 + assert ros_msg.transforms[0].transform.translation.x == 1.0 + assert ros_msg.transforms[0].transform.translation.y == 2.0 + assert ros_msg.transforms[0].transform.translation.z == 3.0 + assert ros_msg.transforms[0].transform.rotation.w == 1.0 + + # Check second transform + assert ros_msg.transforms[1].header.frame_id == "base_link" + assert ros_msg.transforms[1].child_frame_id == "lidar" + assert ros_msg.transforms[1].header.stamp.sec == 125 + assert ros_msg.transforms[1].header.stamp.nanosec == 789000000 + assert ros_msg.transforms[1].transform.translation.x == 7.0 + assert ros_msg.transforms[1].transform.translation.y == 8.0 + assert ros_msg.transforms[1].transform.translation.z == 9.0 + assert ros_msg.transforms[1].transform.rotation.x == 0.1 + assert ros_msg.transforms[1].transform.rotation.y == 0.2 + assert ros_msg.transforms[1].transform.rotation.z == 0.3 + assert ros_msg.transforms[1].transform.rotation.w == 0.9 + + +@pytest.mark.ros +def test_tfmessage_ros_roundtrip() -> None: + """Test round-trip conversion between TFMessage and ROS TFMessage.""" + # Create transforms with various properties + tf1 = Transform( + translation=Vector3(1.5, 2.5, 3.5), + rotation=Quaternion(0.15, 0.25, 0.35, 0.85), + frame_id="odom", + child_frame_id="base_footprint", + ts=100.123, + ) + tf2 = Transform( + translation=Vector3(0.1, 0.2, 0.3), + rotation=Quaternion(0.0, 0.0, 0.383, 0.924), + frame_id="base_footprint", + child_frame_id="camera", + ts=100.456, + ) + + original = TFMessage(tf1, tf2) + + # Convert to ROS and back + ros_msg = original.to_ros_msg() + restored = TFMessage.from_ros_msg(ros_msg) + + assert len(restored) == len(original) + + for orig_tf, rest_tf in zip(original, restored, strict=False): + assert rest_tf.frame_id == orig_tf.frame_id + assert rest_tf.child_frame_id == orig_tf.child_frame_id + assert rest_tf.ts == orig_tf.ts + assert rest_tf.translation.x == orig_tf.translation.x + assert rest_tf.translation.y == orig_tf.translation.y + assert rest_tf.translation.z == orig_tf.translation.z + assert rest_tf.rotation.x == orig_tf.rotation.x + assert rest_tf.rotation.y == orig_tf.rotation.y + assert rest_tf.rotation.z == orig_tf.rotation.z + assert rest_tf.rotation.w == orig_tf.rotation.w diff --git a/dimos/msgs/tf2_msgs/test_TFMessage_lcmpub.py b/dimos/msgs/tf2_msgs/test_TFMessage_lcmpub.py new file mode 100644 index 0000000000..0846f91ee6 --- /dev/null +++ b/dimos/msgs/tf2_msgs/test_TFMessage_lcmpub.py @@ -0,0 +1,68 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import pytest + +from dimos.msgs.geometry_msgs import Quaternion, Transform, Vector3 +from dimos.msgs.tf2_msgs import TFMessage +from dimos.protocol.pubsub.lcmpubsub import LCM, Topic + + +# Publishes a series of transforms representing a robot kinematic chain +# to actual LCM messages, foxglove running in parallel should render this +@pytest.mark.skip +def test_publish_transforms() -> None: + from dimos_lcm.tf2_msgs import TFMessage as LCMTFMessage + + lcm = LCM(autoconf=True) + lcm.start() + + topic = Topic(topic="/tf", lcm_type=LCMTFMessage) + + # Create a robot kinematic chain using our new types + current_time = time.time() + + # 1. World to base_link transform (robot at position) + world_to_base = Transform( + translation=Vector3(4.0, 3.0, 0.0), + rotation=Quaternion(0.0, 0.0, 0.382683, 0.923880), # 45 degrees around Z + frame_id="world", + child_frame_id="base_link", + ts=current_time, + ) + + # 2. Base to arm transform (arm lifted up) + base_to_arm = Transform( + translation=Vector3(0.2, 0.0, 1.5), + rotation=Quaternion(0.0, 0.258819, 0.0, 0.965926), # 30 degrees around Y + frame_id="base_link", + child_frame_id="arm_link", + ts=current_time, + ) + + lcm.publish(topic, TFMessage(world_to_base, base_to_arm)) + + time.sleep(0.05) + # 3. Arm to gripper transform (gripper extended) + arm_to_gripper = Transform( + translation=Vector3(0.5, 0.0, 0.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), # No rotation + frame_id="arm_link", + child_frame_id="gripper_link", + ts=current_time, + ) + + lcm.publish(topic, TFMessage(world_to_base, arm_to_gripper)) diff --git a/dimos/msgs/vision_msgs/BoundingBox2DArray.py b/dimos/msgs/vision_msgs/BoundingBox2DArray.py new file mode 100644 index 0000000000..148e40612a --- /dev/null +++ b/dimos/msgs/vision_msgs/BoundingBox2DArray.py @@ -0,0 +1,21 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos_lcm.vision_msgs.BoundingBox2DArray import ( # type: ignore[import-untyped] + BoundingBox2DArray as LCMBoundingBox2DArray, +) + + +class BoundingBox2DArray(LCMBoundingBox2DArray): # type: ignore[misc] + msg_name = "vision_msgs.BoundingBox2DArray" diff --git a/dimos/msgs/vision_msgs/BoundingBox3DArray.py b/dimos/msgs/vision_msgs/BoundingBox3DArray.py new file mode 100644 index 0000000000..ff79421b53 --- /dev/null +++ b/dimos/msgs/vision_msgs/BoundingBox3DArray.py @@ -0,0 +1,21 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos_lcm.vision_msgs.BoundingBox3DArray import ( # type: ignore[import-untyped] + BoundingBox3DArray as LCMBoundingBox3DArray, +) + + +class BoundingBox3DArray(LCMBoundingBox3DArray): # type: ignore[misc] + msg_name = "vision_msgs.BoundingBox3DArray" diff --git a/dimos/msgs/vision_msgs/Detection2DArray.py b/dimos/msgs/vision_msgs/Detection2DArray.py new file mode 100644 index 0000000000..bb922e8edb --- /dev/null +++ b/dimos/msgs/vision_msgs/Detection2DArray.py @@ -0,0 +1,29 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dimos_lcm.vision_msgs.Detection2DArray import ( # type: ignore[import-untyped] + Detection2DArray as LCMDetection2DArray, +) + +from dimos.types.timestamped import to_timestamp + + +class Detection2DArray(LCMDetection2DArray): # type: ignore[misc] + msg_name = "vision_msgs.Detection2DArray" + + # for _get_field_type() to work when decoding in _decode_one() + __annotations__ = LCMDetection2DArray.__annotations__ + + @property + def ts(self) -> float: + return to_timestamp(self.header.stamp) diff --git a/dimos/msgs/vision_msgs/Detection3DArray.py b/dimos/msgs/vision_msgs/Detection3DArray.py new file mode 100644 index 0000000000..33dbb34a17 --- /dev/null +++ b/dimos/msgs/vision_msgs/Detection3DArray.py @@ -0,0 +1,21 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos_lcm.vision_msgs.Detection3DArray import ( # type: ignore[import-untyped] + Detection3DArray as LCMDetection3DArray, +) + + +class Detection3DArray(LCMDetection3DArray): # type: ignore[misc] + msg_name = "vision_msgs.Detection3DArray" diff --git a/dimos/msgs/vision_msgs/__init__.py b/dimos/msgs/vision_msgs/__init__.py new file mode 100644 index 0000000000..af170cbfab --- /dev/null +++ b/dimos/msgs/vision_msgs/__init__.py @@ -0,0 +1,6 @@ +from .BoundingBox2DArray import BoundingBox2DArray +from .BoundingBox3DArray import BoundingBox3DArray +from .Detection2DArray import Detection2DArray +from .Detection3DArray import Detection3DArray + +__all__ = ["BoundingBox2DArray", "BoundingBox3DArray", "Detection2DArray", "Detection3DArray"] diff --git a/dimos/navigation/base.py b/dimos/navigation/base.py new file mode 100644 index 0000000000..e866dadc3c --- /dev/null +++ b/dimos/navigation/base.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""State-machine-based navigation interface for autonomous mobile robots. + +This module defines the contract that all navigation backends must implement. The +interface centers on three design principles: + +1. **State machine control**: Navigation has explicit states (IDLE, FOLLOWING_PATH, + RECOVERY) that callers poll to determine progress. + +2. **Non-blocking goal submission**: `set_goal()` returns immediately after accepting + the goal, allowing agent skills to interleave navigation with perception or + communication during long traversals. + +3. **Success indicator**: When returning to IDLE, `is_goal_reached()` reports whether + the goal was actually reached (vs. cancelled/failed), so you can check the outcome + without polling state during the transition or maintaining external tracking. + +Typical usage pattern: + +```python +nav.set_goal(target_pose) +while nav.get_state() == NavigationState.FOLLOWING_PATH: + # Optionally: check sensors, update beliefs, handle interruptions + time.sleep(0.25) +return nav.is_goal_reached() # True = success, False = cancelled/failed +``` +""" + +from abc import ABC, abstractmethod +from enum import Enum + +from dimos.msgs.geometry_msgs import PoseStamped + + +class NavigationState(Enum): + """State machine states for navigation control. + + Used by skills and agents to monitor navigation progress and distinguish + between idle, active navigation, and recovery behaviors. + + Attributes: + IDLE: No active navigation goal. The navigator is ready to accept + a new goal via `set_goal()`. + FOLLOWING_PATH: Actively navigating toward a goal. Path planning + and motion control are engaged. + RECOVERY: Reserved for stuck detection and recovery behaviors. + Currently only partially implemented - most implementations + transition directly from FOLLOWING_PATH to IDLE on failure. + """ + + IDLE = "idle" + FOLLOWING_PATH = "following_path" + RECOVERY = "recovery" + + +class NavigationInterface(ABC): + """Abstract interface for state-machine-based robot navigation. + + Defines a uniform API for autonomous navigation that works across different + backends (ROS Nav2, custom planners, behavior trees). The interface uses + non-blocking goal submission with polling-based monitoring, allowing callers + to interleave navigation with other tasks. + + See also: + `NavigationState`: Enum defining the state machine states. + """ + + @abstractmethod + def set_goal(self, goal: PoseStamped) -> bool: + """Submit a navigation goal (non-blocking). + + Initiates navigation toward the target pose and returns immediately. If a + previous goal is active, it is implicitly cancelled. The navigator transitions + to the `FOLLOWING_PATH` state upon acceptance. + + Args: + goal: Target pose to navigate to. + + Returns: + True if goal was accepted, False otherwise. Acceptance does not + guarantee reachability. + + Note: + Use `get_state()` and `is_goal_reached()` to poll navigation progress. + The goal's frame_id determines the coordinate frame (e.g., "map", "odom"). + """ + pass + + @abstractmethod + def get_state(self) -> NavigationState: + """ + Get the current state of the navigator. + + Returns: + Current navigation state + """ + pass + + @abstractmethod + def is_goal_reached(self) -> bool: + """ + Check if the current goal has been reached. + + Returns: + True if goal was reached, False otherwise + """ + pass + + @abstractmethod + def cancel_goal(self) -> bool: + """ + Cancel the current navigation goal. + + Returns: + True if goal was cancelled, False if no goal was active + """ + pass + + +__all__ = ["NavigationInterface", "NavigationState"] diff --git a/dimos/navigation/bbox_navigation.py b/dimos/navigation/bbox_navigation.py new file mode 100644 index 0000000000..662d535a26 --- /dev/null +++ b/dimos/navigation/bbox_navigation.py @@ -0,0 +1,76 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from dimos_lcm.sensor_msgs import CameraInfo # type: ignore[import-untyped] +from reactivex.disposable import Disposable + +from dimos.core import In, Module, Out, rpc +from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Vector3 +from dimos.msgs.vision_msgs import Detection2DArray +from dimos.utils.logging_config import setup_logger + +logger = setup_logger(level=logging.DEBUG) + + +class BBoxNavigationModule(Module): + """Minimal module that converts 2D bbox center to navigation goals.""" + + detection2d: In[Detection2DArray] = None # type: ignore[assignment] + camera_info: In[CameraInfo] = None # type: ignore[assignment] + goal_request: Out[PoseStamped] = None # type: ignore[assignment] + + def __init__(self, goal_distance: float = 1.0) -> None: + super().__init__() + self.goal_distance = goal_distance + self.camera_intrinsics = None + + @rpc + def start(self) -> None: + unsub = self.camera_info.subscribe( + lambda msg: setattr(self, "camera_intrinsics", [msg.K[0], msg.K[4], msg.K[2], msg.K[5]]) + ) + self._disposables.add(Disposable(unsub)) + + unsub = self.detection2d.subscribe(self._on_detection) + self._disposables.add(Disposable(unsub)) + + @rpc + def stop(self) -> None: + super().stop() + + def _on_detection(self, det: Detection2DArray) -> None: + if det.detections_length == 0 or not self.camera_intrinsics: + return + fx, fy, cx, cy = self.camera_intrinsics + center_x, center_y = ( + det.detections[0].bbox.center.position.x, + det.detections[0].bbox.center.position.y, + ) + x, y, z = ( + (center_x - cx) / fx * self.goal_distance, + (center_y - cy) / fy * self.goal_distance, + self.goal_distance, + ) + goal = PoseStamped( + position=Vector3(z, -x, -y), + orientation=Quaternion(0, 0, 0, 1), + frame_id=det.header.frame_id, + ) + logger.debug( + f"BBox center: ({center_x:.1f}, {center_y:.1f}) → " + f"Goal pose: ({z:.2f}, {-x:.2f}, {-y:.2f}) in frame '{det.header.frame_id}'" + ) + self.goal_request.publish(goal) diff --git a/dimos/navigation/bt_navigator/__init__.py b/dimos/navigation/bt_navigator/__init__.py new file mode 100644 index 0000000000..cfd252ff6a --- /dev/null +++ b/dimos/navigation/bt_navigator/__init__.py @@ -0,0 +1 @@ +from .navigator import BehaviorTreeNavigator diff --git a/dimos/navigation/bt_navigator/goal_validator.py b/dimos/navigation/bt_navigator/goal_validator.py new file mode 100644 index 0000000000..8e387be2ad --- /dev/null +++ b/dimos/navigation/bt_navigator/goal_validator.py @@ -0,0 +1,443 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import deque + +import numpy as np + +from dimos.msgs.geometry_msgs import Vector3, VectorLike +from dimos.msgs.nav_msgs import CostValues, OccupancyGrid + + +def find_safe_goal( + costmap: OccupancyGrid, + goal: VectorLike, + algorithm: str = "bfs", + cost_threshold: int = 50, + min_clearance: float = 0.3, + max_search_distance: float = 5.0, + connectivity_check_radius: int = 3, +) -> Vector3 | None: + """ + Find a safe goal position when the original goal is in collision or too close to obstacles. + + Args: + costmap: The occupancy grid/costmap + goal: Original goal position in world coordinates + algorithm: Algorithm to use ("bfs", "spiral", "voronoi", "gradient_descent") + cost_threshold: Maximum acceptable cost for a safe position (default: 50) + min_clearance: Minimum clearance from obstacles in meters (default: 0.3m) + max_search_distance: Maximum distance to search from original goal in meters (default: 5.0m) + connectivity_check_radius: Radius in cells to check for connectivity (default: 3) + + Returns: + Safe goal position in world coordinates, or None if no safe position found + """ + + if algorithm == "bfs": + return _find_safe_goal_bfs( + costmap, + goal, + cost_threshold, + min_clearance, + max_search_distance, + connectivity_check_radius, + ) + elif algorithm == "spiral": + return _find_safe_goal_spiral( + costmap, + goal, + cost_threshold, + min_clearance, + max_search_distance, + connectivity_check_radius, + ) + elif algorithm == "voronoi": + return _find_safe_goal_voronoi( + costmap, goal, cost_threshold, min_clearance, max_search_distance + ) + elif algorithm == "gradient_descent": + return _find_safe_goal_gradient( + costmap, + goal, + cost_threshold, + min_clearance, + max_search_distance, + connectivity_check_radius, + ) + else: + raise ValueError(f"Unknown algorithm: {algorithm}") + + +def _find_safe_goal_bfs( + costmap: OccupancyGrid, + goal: VectorLike, + cost_threshold: int, + min_clearance: float, + max_search_distance: float, + connectivity_check_radius: int, +) -> Vector3 | None: + """ + BFS-based search for nearest safe goal position. + This guarantees finding the closest valid position. + + Pros: + - Guarantees finding the closest safe position + - Can check connectivity to avoid isolated spots + - Efficient for small to medium search areas + + Cons: + - Can be slower for large search areas + - Memory usage scales with search area + """ + + # Convert goal to grid coordinates + goal_grid = costmap.world_to_grid(goal) + gx, gy = int(goal_grid.x), int(goal_grid.y) + + # Convert distances to grid cells + clearance_cells = int(np.ceil(min_clearance / costmap.resolution)) + max_search_cells = int(np.ceil(max_search_distance / costmap.resolution)) + + # BFS queue and visited set + queue = deque([(gx, gy, 0)]) + visited = set([(gx, gy)]) + + # 8-connected neighbors + neighbors = [(0, 1), (1, 0), (0, -1), (-1, 0), (1, 1), (1, -1), (-1, 1), (-1, -1)] + + while queue: + x, y, dist = queue.popleft() + + # Check if we've exceeded max search distance + if dist > max_search_cells: + break + + # Check if position is valid + if _is_position_safe( + costmap, x, y, cost_threshold, clearance_cells, connectivity_check_radius + ): + # Convert back to world coordinates + return costmap.grid_to_world((x, y)) + + # Add neighbors to queue + for dx, dy in neighbors: + nx, ny = x + dx, y + dy + + # Check bounds + if 0 <= nx < costmap.width and 0 <= ny < costmap.height: + if (nx, ny) not in visited: + visited.add((nx, ny)) + queue.append((nx, ny, dist + 1)) + + return None + + +def _find_safe_goal_spiral( + costmap: OccupancyGrid, + goal: VectorLike, + cost_threshold: int, + min_clearance: float, + max_search_distance: float, + connectivity_check_radius: int, +) -> Vector3 | None: + """ + Spiral search pattern from goal outward. + + Pros: + - Simple and predictable pattern + - Memory efficient + - Good for uniformly distributed obstacles + + Cons: + - May not find the absolute closest safe position + - Can miss nearby safe spots due to spiral pattern + """ + + # Convert goal to grid coordinates + goal_grid = costmap.world_to_grid(goal) + cx, cy = int(goal_grid.x), int(goal_grid.y) + + # Convert distances to grid cells + clearance_cells = int(np.ceil(min_clearance / costmap.resolution)) + max_radius = int(np.ceil(max_search_distance / costmap.resolution)) + + # Spiral outward + for radius in range(0, max_radius + 1): + if radius == 0: + # Check center point + if _is_position_safe( + costmap, cx, cy, cost_threshold, clearance_cells, connectivity_check_radius + ): + return costmap.grid_to_world((cx, cy)) + else: + # Check points on the square perimeter at this radius + points = [] + + # Top and bottom edges + for x in range(cx - radius, cx + radius + 1): + points.append((x, cy - radius)) # Top + points.append((x, cy + radius)) # Bottom + + # Left and right edges (excluding corners to avoid duplicates) + for y in range(cy - radius + 1, cy + radius): + points.append((cx - radius, y)) # Left + points.append((cx + radius, y)) # Right + + # Check each point + for x, y in points: + if 0 <= x < costmap.width and 0 <= y < costmap.height: + if _is_position_safe( + costmap, x, y, cost_threshold, clearance_cells, connectivity_check_radius + ): + return costmap.grid_to_world((x, y)) + + return None + + +def _find_safe_goal_voronoi( + costmap: OccupancyGrid, + goal: VectorLike, + cost_threshold: int, + min_clearance: float, + max_search_distance: float, +) -> Vector3 | None: + """ + Find safe position using Voronoi diagram (ridge points equidistant from obstacles). + + Pros: + - Finds positions maximally far from obstacles + - Good for narrow passages + - Natural safety margin + + Cons: + - More computationally expensive + - May find positions unnecessarily far from obstacles + - Requires scipy for efficient implementation + """ + + from scipy import ndimage + from skimage.morphology import skeletonize # type: ignore[import-not-found] + + # Convert goal to grid coordinates + goal_grid = costmap.world_to_grid(goal) + gx, gy = int(goal_grid.x), int(goal_grid.y) + + # Create binary obstacle map + free_map = (costmap.grid < cost_threshold) & (costmap.grid != CostValues.UNKNOWN) + + # Compute distance transform + distance_field = ndimage.distance_transform_edt(free_map) + + # Find skeleton/medial axis (approximation of Voronoi diagram) + skeleton = skeletonize(free_map) + + # Filter skeleton points by minimum clearance + clearance_cells = int(np.ceil(min_clearance / costmap.resolution)) + valid_skeleton = skeleton & (distance_field >= clearance_cells) # type: ignore[operator] + + if not np.any(valid_skeleton): + # Fall back to BFS if no valid skeleton points + return _find_safe_goal_bfs( + costmap, goal, cost_threshold, min_clearance, max_search_distance, 3 + ) + + # Find nearest valid skeleton point to goal + skeleton_points = np.argwhere(valid_skeleton) + if len(skeleton_points) == 0: + return None + + # Calculate distances from goal to all skeleton points + distances = np.sqrt((skeleton_points[:, 1] - gx) ** 2 + (skeleton_points[:, 0] - gy) ** 2) + + # Filter by max search distance + max_search_cells = max_search_distance / costmap.resolution + valid_indices = distances <= max_search_cells + + if not np.any(valid_indices): + return None + + # Find closest valid point + valid_distances = distances[valid_indices] + valid_points = skeleton_points[valid_indices] + closest_idx = np.argmin(valid_distances) + best_y, best_x = valid_points[closest_idx] + + return costmap.grid_to_world((best_x, best_y)) + + +def _find_safe_goal_gradient( + costmap: OccupancyGrid, + goal: VectorLike, + cost_threshold: int, + min_clearance: float, + max_search_distance: float, + connectivity_check_radius: int, +) -> Vector3 | None: + """ + Use gradient descent on the costmap to find a safe position. + + Pros: + - Naturally flows away from obstacles + - Works well with gradient costmaps + - Can handle complex cost distributions + + Cons: + - Can get stuck in local minima + - Requires a gradient costmap + - May not find globally optimal position + """ + + # Convert goal to grid coordinates + goal_grid = costmap.world_to_grid(goal) + x, y = goal_grid.x, goal_grid.y + + # Convert distances to grid cells + clearance_cells = int(np.ceil(min_clearance / costmap.resolution)) + max_search_cells = int(np.ceil(max_search_distance / costmap.resolution)) + + # Create gradient if needed (assuming costmap might already be a gradient) + if np.all((costmap.grid == 0) | (costmap.grid == 100) | (costmap.grid == -1)): + # Binary map, create gradient + gradient_map = costmap.gradient( + obstacle_threshold=cost_threshold, max_distance=min_clearance * 2 + ) + grid = gradient_map.grid + else: + grid = costmap.grid + + # Gradient descent with momentum + momentum = 0.9 + learning_rate = 1.0 + vx, vy = 0.0, 0.0 + + best_x, best_y = None, None + best_cost = float("inf") + + for iteration in range(100): # Max iterations + ix, iy = int(x), int(y) + + # Check if current position is valid + if 0 <= ix < costmap.width and 0 <= iy < costmap.height: + current_cost = grid[iy, ix] + + # Check distance from original goal + dist = np.sqrt((x - goal_grid.x) ** 2 + (y - goal_grid.y) ** 2) + if dist > max_search_cells: + break + + # Check if position is safe + if _is_position_safe( + costmap, ix, iy, cost_threshold, clearance_cells, connectivity_check_radius + ): + if current_cost < best_cost: + best_x, best_y = ix, iy + best_cost = current_cost + + # If cost is very low, we found a good spot + if current_cost < 10: + break + + # Compute gradient using finite differences + gx, gy = 0.0, 0.0 + + if 0 < ix < costmap.width - 1: + gx = (grid[iy, min(ix + 1, costmap.width - 1)] - grid[iy, max(ix - 1, 0)]) / 2.0 + + if 0 < iy < costmap.height - 1: + gy = (grid[min(iy + 1, costmap.height - 1), ix] - grid[max(iy - 1, 0), ix]) / 2.0 + + # Update with momentum + vx = momentum * vx - learning_rate * gx + vy = momentum * vy - learning_rate * gy + + # Update position + x += vx + y += vy + + # Add small random noise to escape local minima + if iteration % 20 == 0: + x += np.random.randn() * 0.5 + y += np.random.randn() * 0.5 + + if best_x is not None and best_y is not None: + return costmap.grid_to_world((best_x, best_y)) + + return None + + +def _is_position_safe( + costmap: OccupancyGrid, + x: int, + y: int, + cost_threshold: int, + clearance_cells: int, + connectivity_check_radius: int, +) -> bool: + """ + Check if a position is safe based on multiple criteria. + + Args: + costmap: The occupancy grid + x, y: Grid coordinates to check + cost_threshold: Maximum acceptable cost + clearance_cells: Minimum clearance in cells + connectivity_check_radius: Radius to check for connectivity + + Returns: + True if position is safe, False otherwise + """ + + # Check bounds first + if not (0 <= x < costmap.width and 0 <= y < costmap.height): + return False + + # Check if position itself is free + if costmap.grid[y, x] >= cost_threshold or costmap.grid[y, x] == CostValues.UNKNOWN: + return False + + # Check clearance around position + for dy in range(-clearance_cells, clearance_cells + 1): + for dx in range(-clearance_cells, clearance_cells + 1): + nx, ny = x + dx, y + dy + if 0 <= nx < costmap.width and 0 <= ny < costmap.height: + # Check if within circular clearance + if dx * dx + dy * dy <= clearance_cells * clearance_cells: + if costmap.grid[ny, nx] >= cost_threshold: + return False + + # Check connectivity (not surrounded by obstacles) + # Count free neighbors in a larger radius + free_count = 0 + total_count = 0 + + for dy in range(-connectivity_check_radius, connectivity_check_radius + 1): + for dx in range(-connectivity_check_radius, connectivity_check_radius + 1): + if dx == 0 and dy == 0: + continue + + nx, ny = x + dx, y + dy + if 0 <= nx < costmap.width and 0 <= ny < costmap.height: + total_count += 1 + if ( + costmap.grid[ny, nx] < cost_threshold + and costmap.grid[ny, nx] != CostValues.UNKNOWN + ): + free_count += 1 + + # Require at least 50% of neighbors to be free (not surrounded) + if total_count > 0 and free_count < total_count * 0.5: + return False + + return True diff --git a/dimos/navigation/bt_navigator/navigator.py b/dimos/navigation/bt_navigator/navigator.py new file mode 100644 index 0000000000..427675a386 --- /dev/null +++ b/dimos/navigation/bt_navigator/navigator.py @@ -0,0 +1,356 @@ +#!/usr/bin/env python3 + +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Navigator module for coordinating global and local planning. +""" + +from collections.abc import Callable +import threading +import time + +from dimos_lcm.std_msgs import Bool, String # type: ignore[import-untyped] +from reactivex.disposable import Disposable + +from dimos.core import In, Module, Out, rpc +from dimos.core.rpc_client import RpcCall +from dimos.msgs.geometry_msgs import PoseStamped +from dimos.msgs.nav_msgs import OccupancyGrid +from dimos.navigation.base import NavigationInterface, NavigationState +from dimos.navigation.bt_navigator.goal_validator import find_safe_goal +from dimos.navigation.bt_navigator.recovery_server import RecoveryServer +from dimos.protocol.tf import TF +from dimos.utils.logging_config import setup_logger +from dimos.utils.transform_utils import apply_transform + +logger = setup_logger() + + +class BehaviorTreeNavigator(Module, NavigationInterface): + """ + Navigator module for coordinating navigation tasks. + + Manages the state machine for navigation, coordinates between global + and local planners, and monitors goal completion. + + Inputs: + - odom: Current robot odometry + + Outputs: + - goal: Goal pose for global planner + """ + + # LCM inputs + odom: In[PoseStamped] = None # type: ignore[assignment] + goal_request: In[PoseStamped] = None # type: ignore[assignment] # Input for receiving goal requests + global_costmap: In[OccupancyGrid] = None # type: ignore[assignment] + + # LCM outputs + target: Out[PoseStamped] = None # type: ignore[assignment] + goal_reached: Out[Bool] = None # type: ignore[assignment] + navigation_state: Out[String] = None # type: ignore[assignment] + + def __init__( # type: ignore[no-untyped-def] + self, + publishing_frequency: float = 1.0, + reset_local_planner: Callable[[], None] | None = None, + check_goal_reached: Callable[[], bool] | None = None, + **kwargs, + ) -> None: + """Initialize the Navigator. + + Args: + publishing_frequency: Frequency to publish goals to global planner (Hz) + goal_tolerance: Distance threshold to consider goal reached (meters) + """ + super().__init__(**kwargs) + + # Parameters + self.publishing_frequency = publishing_frequency + self.publishing_period = 1.0 / publishing_frequency + + # State machine + self.state = NavigationState.IDLE + self.state_lock = threading.Lock() + + # Current goal + self.current_goal: PoseStamped | None = None + self.original_goal: PoseStamped | None = None + self.goal_lock = threading.Lock() + + # Goal reached state + self._goal_reached = False + + # Latest data + self.latest_odom: PoseStamped | None = None + self.latest_costmap: OccupancyGrid | None = None + + # Control thread + self.control_thread: threading.Thread | None = None + self.stop_event = threading.Event() + + # TF listener + self.tf = TF() + + # Local planner + self.reset_local_planner = reset_local_planner + self.check_goal_reached = check_goal_reached + + # Recovery server for stuck detection + self.recovery_server = RecoveryServer(stuck_duration=5.0) + + logger.info("Navigator initialized with stuck detection") + + @rpc + def set_HolonomicLocalPlanner_reset(self, callable: RpcCall) -> None: + self.reset_local_planner = callable + self.reset_local_planner.set_rpc(self.rpc) # type: ignore[arg-type] + + @rpc + def set_HolonomicLocalPlanner_is_goal_reached(self, callable: RpcCall) -> None: + self.check_goal_reached = callable + self.check_goal_reached.set_rpc(self.rpc) # type: ignore[arg-type] + + @rpc + def start(self) -> None: + super().start() + + # Subscribe to inputs + unsub = self.odom.subscribe(self._on_odom) + self._disposables.add(Disposable(unsub)) + + unsub = self.goal_request.subscribe(self._on_goal_request) + self._disposables.add(Disposable(unsub)) + + unsub = self.global_costmap.subscribe(self._on_costmap) + self._disposables.add(Disposable(unsub)) + + # Start control thread + self.stop_event.clear() + self.control_thread = threading.Thread(target=self._control_loop, daemon=True) + self.control_thread.start() + + logger.info("Navigator started") + + @rpc + def stop(self) -> None: + """Clean up resources including stopping the control thread.""" + + self.stop_navigation() + + self.stop_event.set() + if self.control_thread and self.control_thread.is_alive(): + self.control_thread.join(timeout=2.0) + + super().stop() + + @rpc + def cancel_goal(self) -> bool: + """ + Cancel the current navigation goal. + + Returns: + True if goal was cancelled, False if no goal was active + """ + self.stop_navigation() + return True + + @rpc + def set_goal(self, goal: PoseStamped) -> bool: + """ + Set a new navigation goal. + + Args: + goal: Target pose to navigate to + + Returns: + non-blocking: True if goal was accepted, False otherwise + blocking: True if goal was reached, False otherwise + """ + transformed_goal = self._transform_goal_to_odom_frame(goal) + if not transformed_goal: + logger.error("Failed to transform goal to odometry frame") + return False + + with self.goal_lock: + self.current_goal = transformed_goal + self.original_goal = transformed_goal + + self._goal_reached = False + + with self.state_lock: + self.state = NavigationState.FOLLOWING_PATH + + return True + + @rpc + def get_state(self) -> NavigationState: + """Get the current state of the navigator.""" + return self.state + + def _on_odom(self, msg: PoseStamped) -> None: + """Handle incoming odometry messages.""" + self.latest_odom = msg + + if self.state == NavigationState.FOLLOWING_PATH: + self.recovery_server.update_odom(msg) + + def _on_goal_request(self, msg: PoseStamped) -> None: + """Handle incoming goal requests.""" + self.set_goal(msg) + + def _on_costmap(self, msg: OccupancyGrid) -> None: + """Handle incoming costmap messages.""" + self.latest_costmap = msg + + def _transform_goal_to_odom_frame(self, goal: PoseStamped) -> PoseStamped | None: + """Transform goal pose to the odometry frame.""" + if not goal.frame_id: + return goal + + if not self.latest_odom: + logger.error("No odometry data available to transform goal") + return None + + odom_frame = self.latest_odom.frame_id + if goal.frame_id == odom_frame: + return goal + + try: + transform = None + max_retries = 3 + + for attempt in range(max_retries): + transform = self.tf.get( + parent_frame=odom_frame, + child_frame=goal.frame_id, + ) + + if transform: + break + + if attempt < max_retries - 1: + logger.warning( + f"Transform attempt {attempt + 1}/{max_retries} failed, retrying..." + ) + time.sleep(1.0) + else: + logger.error( + f"Could not find transform from '{goal.frame_id}' to '{odom_frame}' after {max_retries} attempts" + ) + return None + + pose = apply_transform(goal, transform) # type: ignore[arg-type] + transformed_goal = PoseStamped( + position=pose.position, + orientation=pose.orientation, + frame_id=odom_frame, + ts=goal.ts, + ) + return transformed_goal + + except Exception as e: + logger.error(f"Failed to transform goal: {e}") + return None + + def _control_loop(self) -> None: + """Main control loop running in separate thread.""" + while not self.stop_event.is_set(): + with self.state_lock: + current_state = self.state + self.navigation_state.publish(String(data=current_state.value)) + + if current_state == NavigationState.FOLLOWING_PATH: + with self.goal_lock: + goal = self.current_goal + original_goal = self.original_goal + + if goal is not None and self.latest_costmap is not None: + # Check if robot is stuck + if self.recovery_server.check_stuck(): + logger.warning("Robot is stuck! Cancelling goal and resetting.") + self.cancel_goal() + continue + + costmap = self.latest_costmap.inflate(0.1).gradient(max_distance=1.0) + + # Find safe goal position + safe_goal_pos = find_safe_goal( + costmap, + original_goal.position, # type: ignore[union-attr] + algorithm="bfs", + cost_threshold=60, + min_clearance=0.25, + max_search_distance=5.0, + ) + + # Create new goal with safe position + if safe_goal_pos: + safe_goal = PoseStamped( + position=safe_goal_pos, + orientation=goal.orientation, + frame_id=goal.frame_id, + ts=goal.ts, + ) + self.target.publish(safe_goal) + self.current_goal = safe_goal + else: + logger.warning("Could not find safe goal position, cancelling goal") + self.cancel_goal() + + # Check if goal is reached + if self.check_goal_reached(): # type: ignore[misc] + reached_msg = Bool() + reached_msg.data = True + self.goal_reached.publish(reached_msg) + self.stop_navigation() + self._goal_reached = True + logger.info("Goal reached, resetting local planner") + + elif current_state == NavigationState.RECOVERY: + with self.state_lock: + self.state = NavigationState.IDLE + + time.sleep(self.publishing_period) + + @rpc + def is_goal_reached(self) -> bool: + """Check if the current goal has been reached. + + Returns: + True if goal was reached, False otherwise + """ + return self._goal_reached + + def stop_navigation(self) -> None: + """Stop navigation and return to IDLE state.""" + with self.goal_lock: + self.current_goal = None + + self._goal_reached = False + + with self.state_lock: + self.state = NavigationState.IDLE + + self.reset_local_planner() # type: ignore[misc] + self.recovery_server.reset() # Reset recovery server when stopping + + logger.info("Navigator stopped") + + +behavior_tree_navigator = BehaviorTreeNavigator.blueprint + +__all__ = ["BehaviorTreeNavigator", "behavior_tree_navigator"] diff --git a/dimos/navigation/bt_navigator/recovery_server.py b/dimos/navigation/bt_navigator/recovery_server.py new file mode 100644 index 0000000000..a8c10fccc4 --- /dev/null +++ b/dimos/navigation/bt_navigator/recovery_server.py @@ -0,0 +1,118 @@ +#!/usr/bin/env python3 + +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Recovery server for handling stuck detection and recovery behaviors. +""" + +from dimos.msgs.geometry_msgs import PoseStamped +from dimos.utils.logging_config import setup_logger +from dimos.utils.transform_utils import get_distance + +logger = setup_logger() + + +class RecoveryServer: + """ + Recovery server for detecting stuck situations and executing recovery behaviors. + + Currently implements stuck detection based on time without significant movement. + Will be extended with actual recovery behaviors in the future. + """ + + def __init__( + self, + position_threshold: float = 0.2, + stuck_duration: float = 3.0, + ) -> None: + """Initialize the recovery server. + + Args: + position_threshold: Minimum distance to travel to reset stuck timer (meters) + stuck_duration: Time duration without significant movement to consider stuck (seconds) + """ + self.position_threshold = position_threshold + self.stuck_duration = stuck_duration + + # Store last position that exceeded threshold + self.last_moved_pose = None + self.last_moved_time = None + self.current_odom = None + + logger.info( + f"RecoveryServer initialized with position_threshold={position_threshold}, " + f"stuck_duration={stuck_duration}" + ) + + def update_odom(self, odom: PoseStamped) -> None: + """Update the odometry data for stuck detection. + + Args: + odom: Current robot odometry with timestamp + """ + if odom is None: + return + + # Store current odom for checking stuck + self.current_odom = odom # type: ignore[assignment] + + # Initialize on first update + if self.last_moved_pose is None: + self.last_moved_pose = odom # type: ignore[assignment] + self.last_moved_time = odom.ts # type: ignore[assignment] + return + + # Calculate distance from the reference position (last significant movement) + distance = get_distance(odom, self.last_moved_pose) + + # If robot has moved significantly from the reference, update reference + if distance > self.position_threshold: + self.last_moved_pose = odom + self.last_moved_time = odom.ts + + def check_stuck(self) -> bool: + """Check if the robot is stuck based on time without movement. + + Returns: + True if robot appears to be stuck, False otherwise + """ + if self.last_moved_time is None: + return False + + # Need current odom to check + if self.current_odom is None: + return False + + # Calculate time since last significant movement + current_time = self.current_odom.ts + time_since_movement = current_time - self.last_moved_time + + # Check if stuck based on duration without movement + is_stuck = time_since_movement > self.stuck_duration + + if is_stuck: + logger.warning( + f"Robot appears stuck! No movement for {time_since_movement:.1f} seconds" + ) + + return is_stuck + + def reset(self) -> None: + """Reset the recovery server state.""" + self.last_moved_pose = None + self.last_moved_time = None + self.current_odom = None + logger.debug("RecoveryServer reset") diff --git a/dimos/navigation/demo_ros_navigation.py b/dimos/navigation/demo_ros_navigation.py new file mode 100644 index 0000000000..733f66c1b7 --- /dev/null +++ b/dimos/navigation/demo_ros_navigation.py @@ -0,0 +1,72 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import rclpy + +from dimos import core +from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Twist, Vector3 +from dimos.msgs.nav_msgs import Path +from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.navigation.rosnav import ROSNav +from dimos.protocol import pubsub +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +def main() -> None: + pubsub.lcm.autoconf() # type: ignore[attr-defined] + dimos = core.start(2) + + ros_nav = dimos.deploy(ROSNav) # type: ignore[attr-defined] + + ros_nav.goal_req.transport = core.LCMTransport("/goal", PoseStamped) + ros_nav.pointcloud.transport = core.LCMTransport("/pointcloud_map", PointCloud2) + ros_nav.global_pointcloud.transport = core.LCMTransport("/global_pointcloud", PointCloud2) + ros_nav.goal_active.transport = core.LCMTransport("/goal_active", PoseStamped) + ros_nav.path_active.transport = core.LCMTransport("/path_active", Path) + ros_nav.cmd_vel.transport = core.LCMTransport("/cmd_vel", Twist) + + ros_nav.start() + + logger.info("\nTesting navigation in 2 seconds...") + time.sleep(2) + + test_pose = PoseStamped( + ts=time.time(), + frame_id="map", + position=Vector3(2.0, 2.0, 0.0), + orientation=Quaternion(0.0, 0.0, 0.0, 1.0), + ) + + logger.info("Sending navigation goal to: (2.0, 2.0, 0.0)") + success = ros_nav.navigate_to(test_pose, timeout=30.0) + logger.info(f"Navigated successfully: {success}") + + try: + logger.info("\nNavBot running. Press Ctrl+C to stop.") + while True: + time.sleep(1) + except KeyboardInterrupt: + logger.info("\nShutting down...") + ros_nav.stop() + + if rclpy.ok(): # type: ignore[attr-defined] + rclpy.shutdown() + + +if __name__ == "__main__": + main() diff --git a/dimos/navigation/frontier_exploration/__init__.py b/dimos/navigation/frontier_exploration/__init__.py new file mode 100644 index 0000000000..24ce957ccf --- /dev/null +++ b/dimos/navigation/frontier_exploration/__init__.py @@ -0,0 +1,3 @@ +from .wavefront_frontier_goal_selector import WavefrontFrontierExplorer, wavefront_frontier_explorer + +__all__ = ["WavefrontFrontierExplorer", "wavefront_frontier_explorer"] diff --git a/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py b/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py new file mode 100644 index 0000000000..7d8c0adf4c --- /dev/null +++ b/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py @@ -0,0 +1,456 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import numpy as np +from PIL import ImageDraw +import pytest + +from dimos.msgs.geometry_msgs import Vector3 +from dimos.msgs.nav_msgs import CostValues, OccupancyGrid +from dimos.navigation.frontier_exploration.utils import costmap_to_pil_image +from dimos.navigation.frontier_exploration.wavefront_frontier_goal_selector import ( + WavefrontFrontierExplorer, +) + + +@pytest.fixture +def explorer(): + """Create a WavefrontFrontierExplorer instance for testing.""" + explorer = WavefrontFrontierExplorer( + min_frontier_perimeter=0.3, # Smaller for faster tests + safe_distance=0.5, # Smaller for faster distance calculations + info_gain_threshold=0.02, + ) + yield explorer + # Cleanup after test + try: + explorer.stop() + except: + pass + + +@pytest.fixture +def quick_costmap(): + """Create a very small costmap for quick tests.""" + width, height = 20, 20 + grid = np.full((height, width), CostValues.UNKNOWN, dtype=np.int8) + + # Simple free space in center + grid[8:12, 8:12] = CostValues.FREE + + # Small extensions + grid[9:11, 6:8] = CostValues.FREE # Left + grid[9:11, 12:14] = CostValues.FREE # Right + + # One obstacle + grid[9:10, 9:10] = CostValues.OCCUPIED + + from dimos.msgs.geometry_msgs import Pose + + origin = Pose() + origin.position.x = -1.0 + origin.position.y = -1.0 + origin.position.z = 0.0 + origin.orientation.w = 1.0 + + occupancy_grid = OccupancyGrid( + grid=grid, resolution=0.1, origin=origin, frame_id="map", ts=time.time() + ) + + class MockLidar: + def __init__(self) -> None: + self.origin = Vector3(0.0, 0.0, 0.0) + + return occupancy_grid, MockLidar() + + +def create_test_costmap(width: int = 40, height: int = 40, resolution: float = 0.1): + """Create a simple test costmap with free, occupied, and unknown regions. + + Default size reduced from 100x100 to 40x40 for faster tests. + """ + grid = np.full((height, width), CostValues.UNKNOWN, dtype=np.int8) + + # Create a smaller free space region with simple shape + # Central room + grid[15:25, 15:25] = CostValues.FREE + + # Small corridors extending from central room + grid[18:22, 10:15] = CostValues.FREE # Left corridor + grid[18:22, 25:30] = CostValues.FREE # Right corridor + grid[10:15, 18:22] = CostValues.FREE # Top corridor + grid[25:30, 18:22] = CostValues.FREE # Bottom corridor + + # Add fewer obstacles for faster processing + grid[19:21, 19:21] = CostValues.OCCUPIED # Central obstacle + grid[13:14, 18:22] = CostValues.OCCUPIED # Top corridor obstacle + + # Create origin at bottom-left, adjusted for map size + from dimos.msgs.geometry_msgs import Pose + + origin = Pose() + # Center the map around (0, 0) in world coordinates + origin.position.x = -(width * resolution) / 2.0 + origin.position.y = -(height * resolution) / 2.0 + origin.position.z = 0.0 + origin.orientation.w = 1.0 + + occupancy_grid = OccupancyGrid( + grid=grid, resolution=resolution, origin=origin, frame_id="map", ts=time.time() + ) + + # Create a mock lidar message with origin + class MockLidar: + def __init__(self) -> None: + self.origin = Vector3(0.0, 0.0, 0.0) + + return occupancy_grid, MockLidar() + + +def test_frontier_detection_with_office_lidar(explorer, quick_costmap) -> None: + """Test frontier detection using a test costmap.""" + # Get test costmap + costmap, first_lidar = quick_costmap + + # Verify we have a valid costmap + assert costmap is not None, "Costmap should not be None" + assert costmap.width > 0 and costmap.height > 0, "Costmap should have valid dimensions" + + print(f"Costmap dimensions: {costmap.width}x{costmap.height}") + print(f"Costmap resolution: {costmap.resolution}") + print(f"Unknown percent: {costmap.unknown_percent:.1f}%") + print(f"Free percent: {costmap.free_percent:.1f}%") + print(f"Occupied percent: {costmap.occupied_percent:.1f}%") + + # Set robot pose near the center of free space in the costmap + # We'll use the lidar origin as a reasonable robot position + robot_pose = first_lidar.origin + print(f"Robot pose: {robot_pose}") + + # Detect frontiers + frontiers = explorer.detect_frontiers(robot_pose, costmap) + + # Verify frontier detection results + assert isinstance(frontiers, list), "Frontiers should be returned as a list" + print(f"Detected {len(frontiers)} frontiers") + + # Test that we get some frontiers (office environment should have unexplored areas) + if len(frontiers) > 0: + print("Frontier detection successful - found unexplored areas") + + # Verify frontiers are Vector objects with valid coordinates + for i, frontier in enumerate(frontiers[:5]): # Check first 5 + assert isinstance(frontier, Vector3), f"Frontier {i} should be a Vector3" + assert hasattr(frontier, "x") and hasattr(frontier, "y"), ( + f"Frontier {i} should have x,y coordinates" + ) + print(f" Frontier {i}: ({frontier.x:.2f}, {frontier.y:.2f})") + else: + print("No frontiers detected - map may be fully explored or parameters too restrictive") + + explorer.stop() # TODO: this should be a in try-finally + + +def test_exploration_goal_selection(explorer) -> None: + """Test the complete exploration goal selection pipeline.""" + # Get test costmap - use regular size for more realistic test + costmap, first_lidar = create_test_costmap() + + # Use lidar origin as robot position + robot_pose = first_lidar.origin + + # Get exploration goal + goal = explorer.get_exploration_goal(robot_pose, costmap) + + if goal is not None: + assert isinstance(goal, Vector3), "Goal should be a Vector3" + print(f"Selected exploration goal: ({goal.x:.2f}, {goal.y:.2f})") + + # Test that goal gets marked as explored + assert len(explorer.explored_goals) == 1, "Goal should be marked as explored" + assert explorer.explored_goals[0] == goal, "Explored goal should match selected goal" + + # Test that goal is within costmap bounds + grid_pos = costmap.world_to_grid(goal) + assert 0 <= grid_pos.x < costmap.width, "Goal x should be within costmap bounds" + assert 0 <= grid_pos.y < costmap.height, "Goal y should be within costmap bounds" + + # Test that goal is at a reasonable distance from robot + distance = np.sqrt((goal.x - robot_pose.x) ** 2 + (goal.y - robot_pose.y) ** 2) + assert 0.1 < distance < 20.0, f"Goal distance {distance:.2f}m should be reasonable" + + else: + print("No exploration goal selected - map may be fully explored") + + explorer.stop() # TODO: this should be a in try-finally + + +def test_exploration_session_reset(explorer) -> None: + """Test exploration session reset functionality.""" + # Get test costmap + costmap, first_lidar = create_test_costmap() + + # Use lidar origin as robot position + robot_pose = first_lidar.origin + + # Select a goal to populate exploration state + goal = explorer.get_exploration_goal(robot_pose, costmap) + + # Verify state is populated (skip if no goals available) + if goal: + initial_explored_count = len(explorer.explored_goals) + assert initial_explored_count > 0, "Should have at least one explored goal" + + # Reset exploration session + explorer.reset_exploration_session() + + # Verify state is cleared + assert len(explorer.explored_goals) == 0, "Explored goals should be cleared after reset" + assert explorer.exploration_direction.x == 0.0 and explorer.exploration_direction.y == 0.0, ( + "Exploration direction should be reset" + ) + assert explorer.last_costmap is None, "Last costmap should be cleared" + assert explorer.no_gain_counter == 0, "No-gain counter should be reset" + + print("Exploration session reset successfully") + explorer.stop() # TODO: this should be a in try-finally + + +def test_frontier_ranking(explorer) -> None: + """Test frontier ranking and scoring logic.""" + # Get test costmap + costmap, first_lidar = create_test_costmap() + + robot_pose = first_lidar.origin + + # Get first set of frontiers + frontiers1 = explorer.detect_frontiers(robot_pose, costmap) + goal1 = explorer.get_exploration_goal(robot_pose, costmap) + + if goal1: + # Verify the selected goal is the first in the ranked list + assert frontiers1[0].x == goal1.x and frontiers1[0].y == goal1.y, ( + "Selected goal should be the highest ranked frontier" + ) + + # Test that goals are being marked as explored + assert len(explorer.explored_goals) == 1, "Goal should be marked as explored" + assert ( + explorer.explored_goals[0].x == goal1.x and explorer.explored_goals[0].y == goal1.y + ), "Explored goal should match selected goal" + + # Get another goal + goal2 = explorer.get_exploration_goal(robot_pose, costmap) + if goal2: + assert len(explorer.explored_goals) == 2, ( + "Second goal should also be marked as explored" + ) + + # Test distance to obstacles + obstacle_dist = explorer._compute_distance_to_obstacles(goal1, costmap) + # Note: Goals might be closer than safe_distance if that's the best available frontier + # The safe_distance is used for scoring, not as a hard constraint + print( + f"Distance to obstacles: {obstacle_dist:.2f}m (safe distance: {explorer.safe_distance}m)" + ) + + print(f"Frontier ranking test passed - selected goal at ({goal1.x:.2f}, {goal1.y:.2f})") + print(f"Total frontiers detected: {len(frontiers1)}") + else: + print("No frontiers found for ranking test") + + explorer.stop() # TODO: this should be a in try-finally + + +def test_exploration_with_no_gain_detection() -> None: + """Test information gain detection and exploration termination.""" + # Get initial costmap + costmap1, first_lidar = create_test_costmap() + + # Initialize explorer with low no-gain threshold for testing + explorer = WavefrontFrontierExplorer(info_gain_threshold=0.01, num_no_gain_attempts=2) + + try: + robot_pose = first_lidar.origin + + # Select multiple goals to populate history + for i in range(6): + goal = explorer.get_exploration_goal(robot_pose, costmap1) + if goal: + print(f"Goal {i + 1}: ({goal.x:.2f}, {goal.y:.2f})") + + # Now use same costmap repeatedly to trigger no-gain detection + initial_counter = explorer.no_gain_counter + + # This should increment no-gain counter + goal = explorer.get_exploration_goal(robot_pose, costmap1) + assert explorer.no_gain_counter > initial_counter, "No-gain counter should increment" + + # Continue until exploration stops + for _ in range(3): + goal = explorer.get_exploration_goal(robot_pose, costmap1) + if goal is None: + break + + # Should have stopped due to no information gain + assert goal is None, "Exploration should stop after no-gain threshold" + assert explorer.no_gain_counter == 0, "Counter should reset after stopping" + finally: + explorer.stop() + + +@pytest.mark.vis +def test_frontier_detection_visualization() -> None: + """Test frontier detection with visualization (marked with @pytest.mark.vis).""" + # Get test costmap + costmap, first_lidar = create_test_costmap() + + # Initialize frontier explorer with default parameters + explorer = WavefrontFrontierExplorer() + + try: + # Use lidar origin as robot position + robot_pose = first_lidar.origin + + # Detect all frontiers for visualization + all_frontiers = explorer.detect_frontiers(robot_pose, costmap) + + # Get selected goal + selected_goal = explorer.get_exploration_goal(robot_pose, costmap) + + print(f"Visualizing {len(all_frontiers)} frontier candidates") + if selected_goal: + print(f"Selected goal: ({selected_goal.x:.2f}, {selected_goal.y:.2f})") + + # Create visualization + image_scale_factor = 4 + base_image = costmap_to_pil_image(costmap, image_scale_factor) + + # Helper function to convert world coordinates to image coordinates + def world_to_image_coords(world_pos: Vector3) -> tuple[int, int]: + grid_pos = costmap.world_to_grid(world_pos) + img_x = int(grid_pos.x * image_scale_factor) + img_y = int((costmap.height - grid_pos.y) * image_scale_factor) # Flip Y + return img_x, img_y + + # Draw visualization + draw = ImageDraw.Draw(base_image) + + # Draw frontier candidates as gray dots + for frontier in all_frontiers[:20]: # Limit to top 20 + x, y = world_to_image_coords(frontier) + radius = 6 + draw.ellipse( + [x - radius, y - radius, x + radius, y + radius], + fill=(128, 128, 128), # Gray + outline=(64, 64, 64), + width=1, + ) + + # Draw robot position as blue dot + robot_x, robot_y = world_to_image_coords(robot_pose) + robot_radius = 10 + draw.ellipse( + [ + robot_x - robot_radius, + robot_y - robot_radius, + robot_x + robot_radius, + robot_y + robot_radius, + ], + fill=(0, 0, 255), # Blue + outline=(0, 0, 128), + width=3, + ) + + # Draw selected goal as red dot + if selected_goal: + goal_x, goal_y = world_to_image_coords(selected_goal) + goal_radius = 12 + draw.ellipse( + [ + goal_x - goal_radius, + goal_y - goal_radius, + goal_x + goal_radius, + goal_y + goal_radius, + ], + fill=(255, 0, 0), # Red + outline=(128, 0, 0), + width=3, + ) + + # Display the image + base_image.show(title="Frontier Detection - Office Lidar") + print("Visualization displayed. Close the image window to continue.") + finally: + explorer.stop() + + +def test_performance_timing() -> None: + """Test performance by timing frontier detection operations.""" + import time + + # Test with different costmap sizes + sizes = [(20, 20), (40, 40), (60, 60)] + results = [] + + for width, height in sizes: + # Create costmap of specified size + costmap, lidar = create_test_costmap(width, height) + + # Create explorer with optimized parameters + explorer = WavefrontFrontierExplorer( + min_frontier_perimeter=0.3, + safe_distance=0.5, + info_gain_threshold=0.02, + ) + + try: + robot_pose = lidar.origin + + # Time frontier detection + start = time.time() + frontiers = explorer.detect_frontiers(robot_pose, costmap) + detect_time = time.time() - start + + # Time goal selection + start = time.time() + explorer.get_exploration_goal(robot_pose, costmap) + goal_time = time.time() - start + + results.append( + { + "size": f"{width}x{height}", + "cells": width * height, + "detect_time": detect_time, + "goal_time": goal_time, + "frontiers": len(frontiers), + } + ) + + print(f"\nSize {width}x{height}:") + print(f" Cells: {width * height}") + print(f" Frontier detection: {detect_time:.4f}s") + print(f" Goal selection: {goal_time:.4f}s") + print(f" Frontiers found: {len(frontiers)}") + finally: + explorer.stop() + + # Check that larger maps take more time (expected behavior) + for result in results: + assert result["detect_time"] < 2.0, f"Detection too slow: {result['detect_time']}s" + assert result["goal_time"] < 1.5, f"Goal selection too slow: {result['goal_time']}s" + + print("\nPerformance test passed - all operations completed within time limits") diff --git a/dimos/navigation/frontier_exploration/utils.py b/dimos/navigation/frontier_exploration/utils.py new file mode 100644 index 0000000000..c7dd01be6c --- /dev/null +++ b/dimos/navigation/frontier_exploration/utils.py @@ -0,0 +1,138 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Utility functions for frontier exploration visualization and testing. +""" + +import numpy as np +from PIL import Image, ImageDraw + +from dimos.msgs.geometry_msgs import Vector3 +from dimos.msgs.nav_msgs import CostValues, OccupancyGrid + + +def costmap_to_pil_image(costmap: OccupancyGrid, scale_factor: int = 2) -> Image.Image: + """ + Convert costmap to PIL Image with ROS-style coloring and optional scaling. + + Args: + costmap: Costmap to convert + scale_factor: Factor to scale up the image for better visibility + + Returns: + PIL Image with ROS-style colors + """ + # Create image array (height, width, 3 for RGB) + img_array = np.zeros((costmap.height, costmap.width, 3), dtype=np.uint8) + + # Apply ROS-style coloring based on costmap values + for i in range(costmap.height): + for j in range(costmap.width): + value = costmap.grid[i, j] + if value == CostValues.FREE: # Free space = light grey + img_array[i, j] = [205, 205, 205] + elif value == CostValues.UNKNOWN: # Unknown = dark gray + img_array[i, j] = [128, 128, 128] + elif value >= CostValues.OCCUPIED: # Occupied/obstacles = black + img_array[i, j] = [0, 0, 0] + else: # Any other values (low cost) = light grey + img_array[i, j] = [205, 205, 205] + + # Flip vertically to match ROS convention (origin at bottom-left) + img_array = np.flipud(img_array) + + # Create PIL image + img = Image.fromarray(img_array, "RGB") + + # Scale up if requested + if scale_factor > 1: + new_size = (img.width * scale_factor, img.height * scale_factor) + img = img.resize(new_size, Image.NEAREST) # Use NEAREST to keep sharp pixels + + return img + + +def draw_frontiers_on_image( + image: Image.Image, + costmap: OccupancyGrid, + frontiers: list[Vector3], + scale_factor: int = 2, + unfiltered_frontiers: list[Vector3] | None = None, +) -> Image.Image: + """ + Draw frontier points on the costmap image. + + Args: + image: PIL Image to draw on + costmap: Original costmap for coordinate conversion + frontiers: List of frontier centroids (top 5) + scale_factor: Scaling factor used for the image + unfiltered_frontiers: All unfiltered frontier results (light green) + + Returns: + PIL Image with frontiers drawn + """ + img_copy = image.copy() + draw = ImageDraw.Draw(img_copy) + + def world_to_image_coords(world_pos: Vector3) -> tuple[int, int]: + """Convert world coordinates to image pixel coordinates.""" + grid_pos = costmap.world_to_grid(world_pos) + # Flip Y coordinate and apply scaling + img_x = int(grid_pos.x * scale_factor) + img_y = int((costmap.height - grid_pos.y) * scale_factor) # Flip Y + return img_x, img_y + + # Draw all unfiltered frontiers as light green circles + if unfiltered_frontiers: + for frontier in unfiltered_frontiers: + x, y = world_to_image_coords(frontier) + radius = 3 * scale_factor + draw.ellipse( + [x - radius, y - radius, x + radius, y + radius], + fill=(144, 238, 144), + outline=(144, 238, 144), + ) # Light green + + # Draw top 5 frontiers as green circles + for i, frontier in enumerate(frontiers[1:]): # Skip the best one for now + x, y = world_to_image_coords(frontier) + radius = 4 * scale_factor + draw.ellipse( + [x - radius, y - radius, x + radius, y + radius], + fill=(0, 255, 0), + outline=(0, 128, 0), + width=2, + ) # Green + + # Add number label + draw.text((x + radius + 2, y - radius), str(i + 2), fill=(0, 255, 0)) + + # Draw best frontier as red circle + if frontiers: + best_frontier = frontiers[0] + x, y = world_to_image_coords(best_frontier) + radius = 6 * scale_factor + draw.ellipse( + [x - radius, y - radius, x + radius, y + radius], + fill=(255, 0, 0), + outline=(128, 0, 0), + width=3, + ) # Red + + # Add "BEST" label + draw.text((x + radius + 2, y - radius), "BEST", fill=(255, 0, 0)) + + return img_copy diff --git a/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py b/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py new file mode 100644 index 0000000000..9a9d9ec5a9 --- /dev/null +++ b/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py @@ -0,0 +1,819 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Simple wavefront frontier exploration algorithm implementation using dimos types. + +This module provides frontier detection and exploration goal selection +for autonomous navigation using the dimos Costmap and Vector types. +""" + +from collections import deque +from dataclasses import dataclass +from enum import IntFlag +import threading + +from dimos_lcm.std_msgs import Bool # type: ignore[import-untyped] +import numpy as np +from reactivex.disposable import Disposable + +from dimos.core import In, Module, Out, rpc +from dimos.msgs.geometry_msgs import PoseStamped, Vector3 +from dimos.msgs.nav_msgs import CostValues, OccupancyGrid +from dimos.utils.logging_config import setup_logger +from dimos.utils.transform_utils import get_distance + +logger = setup_logger() + + +class PointClassification(IntFlag): + """Point classification flags for frontier detection algorithm.""" + + NoInformation = 0 + MapOpen = 1 + MapClosed = 2 + FrontierOpen = 4 + FrontierClosed = 8 + + +@dataclass +class GridPoint: + """Represents a point in the grid map with classification.""" + + x: int + y: int + classification: int = PointClassification.NoInformation + + +class FrontierCache: + """Cache for grid points to avoid duplicate point creation.""" + + def __init__(self) -> None: + self.points = {} # type: ignore[var-annotated] + + def get_point(self, x: int, y: int) -> GridPoint: + """Get or create a grid point at the given coordinates.""" + key = (x, y) + if key not in self.points: + self.points[key] = GridPoint(x, y) + return self.points[key] # type: ignore[no-any-return] + + def clear(self) -> None: + """Clear the point cache.""" + self.points.clear() + + +class WavefrontFrontierExplorer(Module): + """ + Wavefront frontier exploration algorithm implementation. + + This class encapsulates the frontier detection and exploration goal selection + functionality using the wavefront algorithm with BFS exploration. + + Inputs: + - costmap: Current costmap for frontier detection + - odometry: Current robot pose + + Outputs: + - goal_request: Exploration goals sent to the navigator + """ + + # LCM inputs + global_costmap: In[OccupancyGrid] = None # type: ignore[assignment] + odom: In[PoseStamped] = None # type: ignore[assignment] + goal_reached: In[Bool] = None # type: ignore[assignment] + explore_cmd: In[Bool] = None # type: ignore[assignment] + stop_explore_cmd: In[Bool] = None # type: ignore[assignment] + + # LCM outputs + goal_request: Out[PoseStamped] = None # type: ignore[assignment] + + def __init__( # type: ignore[no-untyped-def] + self, + min_frontier_perimeter: float = 0.5, + occupancy_threshold: int = 99, + safe_distance: float = 3.0, + lookahead_distance: float = 5.0, + max_explored_distance: float = 10.0, + info_gain_threshold: float = 0.03, + num_no_gain_attempts: int = 2, + goal_timeout: float = 15.0, + **kwargs, + ) -> None: + """ + Initialize the frontier explorer. + + Args: + min_frontier_perimeter: Minimum perimeter in meters to consider a valid frontier + occupancy_threshold: Cost threshold above which a cell is considered occupied (0-255) + safe_distance: Safe distance from obstacles for scoring (meters) + info_gain_threshold: Minimum percentage increase in costmap information required to continue exploration (0.05 = 5%) + num_no_gain_attempts: Maximum number of consecutive attempts with no information gain + """ + super().__init__(**kwargs) + self.min_frontier_perimeter = min_frontier_perimeter + self.occupancy_threshold = occupancy_threshold + self.safe_distance = safe_distance + self.max_explored_distance = max_explored_distance + self.lookahead_distance = lookahead_distance + self.info_gain_threshold = info_gain_threshold + self.num_no_gain_attempts = num_no_gain_attempts + self._cache = FrontierCache() + self.explored_goals = [] # type: ignore[var-annotated] # list of explored goals + self.exploration_direction = Vector3(0.0, 0.0, 0.0) # current exploration direction + self.last_costmap = None # store last costmap for information comparison + self.no_gain_counter = 0 # track consecutive no-gain attempts + self.goal_timeout = goal_timeout + + # Latest data + self.latest_costmap: OccupancyGrid | None = None + self.latest_odometry: PoseStamped | None = None + + # Goal reached event + self.goal_reached_event = threading.Event() + + # Exploration state + self.exploration_active = False + self.exploration_thread: threading.Thread | None = None + self.stop_event = threading.Event() + + logger.info("WavefrontFrontierExplorer module initialized") + + @rpc + def start(self) -> None: + super().start() + + unsub = self.global_costmap.subscribe(self._on_costmap) + self._disposables.add(Disposable(unsub)) + + unsub = self.odom.subscribe(self._on_odometry) + self._disposables.add(Disposable(unsub)) + + if self.goal_reached.transport is not None: + unsub = self.goal_reached.subscribe(self._on_goal_reached) + self._disposables.add(Disposable(unsub)) + + if self.explore_cmd.transport is not None: + unsub = self.explore_cmd.subscribe(self._on_explore_cmd) + self._disposables.add(Disposable(unsub)) + + if self.stop_explore_cmd.transport is not None: + unsub = self.stop_explore_cmd.subscribe(self._on_stop_explore_cmd) + self._disposables.add(Disposable(unsub)) + + @rpc + def stop(self) -> None: + self.stop_exploration() + super().stop() + + def _on_costmap(self, msg: OccupancyGrid) -> None: + """Handle incoming costmap messages.""" + self.latest_costmap = msg + + def _on_odometry(self, msg: PoseStamped) -> None: + """Handle incoming odometry messages.""" + self.latest_odometry = msg + + def _on_goal_reached(self, msg: Bool) -> None: + """Handle goal reached messages.""" + if msg.data: + self.goal_reached_event.set() + + def _on_explore_cmd(self, msg: Bool) -> None: + """Handle exploration command messages.""" + if msg.data: + logger.info("Received exploration start command via LCM") + self.explore() + + def _on_stop_explore_cmd(self, msg: Bool) -> None: + """Handle stop exploration command messages.""" + if msg.data: + logger.info("Received exploration stop command via LCM") + self.stop_exploration() + + def _count_costmap_information(self, costmap: OccupancyGrid) -> int: + """ + Count the amount of information in a costmap (free space + obstacles). + + Args: + costmap: Costmap to analyze + + Returns: + Number of cells that are free space or obstacles (not unknown) + """ + free_count = np.sum(costmap.grid == CostValues.FREE) + obstacle_count = np.sum(costmap.grid >= self.occupancy_threshold) + return int(free_count + obstacle_count) + + def _get_neighbors(self, point: GridPoint, costmap: OccupancyGrid) -> list[GridPoint]: + """Get valid neighboring points for a given grid point.""" + neighbors = [] + + # 8-connected neighbors + for dx in [-1, 0, 1]: + for dy in [-1, 0, 1]: + if dx == 0 and dy == 0: + continue + + nx, ny = point.x + dx, point.y + dy + + # Check bounds + if 0 <= nx < costmap.width and 0 <= ny < costmap.height: + neighbors.append(self._cache.get_point(nx, ny)) + + return neighbors + + def _is_frontier_point(self, point: GridPoint, costmap: OccupancyGrid) -> bool: + """ + Check if a point is a frontier point. + A frontier point is an unknown cell adjacent to at least one free cell + and not adjacent to any occupied cells. + """ + # Point must be unknown + cost = costmap.grid[point.y, point.x] + if cost != CostValues.UNKNOWN: + return False + + has_free = False + + for neighbor in self._get_neighbors(point, costmap): + neighbor_cost = costmap.grid[neighbor.y, neighbor.x] + + # If adjacent to occupied space, not a frontier + if neighbor_cost > self.occupancy_threshold: + return False + + # Check if adjacent to free space + if neighbor_cost == CostValues.FREE: + has_free = True + + return has_free + + def _find_free_space( + self, start_x: int, start_y: int, costmap: OccupancyGrid + ) -> tuple[int, int]: + """ + Find the nearest free space point using BFS from the starting position. + """ + queue = deque([self._cache.get_point(start_x, start_y)]) + visited = set() + + while queue: + point = queue.popleft() + + if (point.x, point.y) in visited: + continue + visited.add((point.x, point.y)) + + # Check if this point is free space + if costmap.grid[point.y, point.x] == CostValues.FREE: + return (point.x, point.y) + + # Add neighbors to search + for neighbor in self._get_neighbors(point, costmap): + if (neighbor.x, neighbor.y) not in visited: + queue.append(neighbor) + + # If no free space found, return original position + return (start_x, start_y) + + def _compute_centroid(self, frontier_points: list[Vector3]) -> Vector3: + """Compute the centroid of a list of frontier points.""" + if not frontier_points: + return Vector3(0.0, 0.0, 0.0) + + # Vectorized approach using numpy + points_array = np.array([[point.x, point.y] for point in frontier_points]) + centroid = np.mean(points_array, axis=0) + + return Vector3(centroid[0], centroid[1], 0.0) + + def detect_frontiers(self, robot_pose: Vector3, costmap: OccupancyGrid) -> list[Vector3]: + """ + Main frontier detection algorithm using wavefront exploration. + + Args: + robot_pose: Current robot position in world coordinates + costmap: Costmap for frontier detection + + Returns: + List of frontier centroids in world coordinates + """ + self._cache.clear() + + # Convert robot pose to grid coordinates + grid_pos = costmap.world_to_grid(robot_pose) + grid_x, grid_y = int(grid_pos.x), int(grid_pos.y) + + # Find nearest free space to start exploration + free_x, free_y = self._find_free_space(grid_x, grid_y, costmap) + start_point = self._cache.get_point(free_x, free_y) + start_point.classification = PointClassification.MapOpen + + # Main exploration queue - explore ALL reachable free space + map_queue = deque([start_point]) + frontiers = [] + frontier_sizes = [] + + points_checked = 0 + frontier_candidates = 0 + + while map_queue: + current_point = map_queue.popleft() + points_checked += 1 + + # Skip if already processed + if current_point.classification & PointClassification.MapClosed: + continue + + # Mark as processed + current_point.classification |= PointClassification.MapClosed + + # Check if this point starts a new frontier + if self._is_frontier_point(current_point, costmap): + frontier_candidates += 1 + current_point.classification |= PointClassification.FrontierOpen + frontier_queue = deque([current_point]) + new_frontier = [] + + # Explore this frontier region using BFS + while frontier_queue: + frontier_point = frontier_queue.popleft() + + # Skip if already processed + if frontier_point.classification & PointClassification.FrontierClosed: + continue + + # If this is still a frontier point, add to current frontier + if self._is_frontier_point(frontier_point, costmap): + new_frontier.append(frontier_point) + + # Add neighbors to frontier queue + for neighbor in self._get_neighbors(frontier_point, costmap): + if not ( + neighbor.classification + & ( + PointClassification.FrontierOpen + | PointClassification.FrontierClosed + ) + ): + neighbor.classification |= PointClassification.FrontierOpen + frontier_queue.append(neighbor) + + frontier_point.classification |= PointClassification.FrontierClosed + + # Check if we found a large enough frontier + # Convert minimum perimeter to minimum number of cells based on resolution + min_cells = int(self.min_frontier_perimeter / costmap.resolution) + if len(new_frontier) >= min_cells: + world_points = [] + for point in new_frontier: + world_pos = costmap.grid_to_world( + Vector3(float(point.x), float(point.y), 0.0) + ) + world_points.append(world_pos) + + # Compute centroid in world coordinates (already correctly scaled) + centroid = self._compute_centroid(world_points) + frontiers.append(centroid) # Store centroid + frontier_sizes.append(len(new_frontier)) # Store frontier size + + # Add ALL neighbors to main exploration queue to explore entire free space + for neighbor in self._get_neighbors(current_point, costmap): + if not ( + neighbor.classification + & (PointClassification.MapOpen | PointClassification.MapClosed) + ): + # Check if neighbor is free space or unknown (explorable) + neighbor_cost = costmap.grid[neighbor.y, neighbor.x] + + # Add free space and unknown space to exploration queue + if neighbor_cost == CostValues.FREE or neighbor_cost == CostValues.UNKNOWN: + neighbor.classification |= PointClassification.MapOpen + map_queue.append(neighbor) + + # Extract just the centroids for ranking + frontier_centroids = frontiers + + if not frontier_centroids: + return [] + + # Rank frontiers using original costmap for proper filtering + ranked_frontiers = self._rank_frontiers( + frontier_centroids, frontier_sizes, robot_pose, costmap + ) + + return ranked_frontiers + + def _update_exploration_direction( + self, robot_pose: Vector3, goal_pose: Vector3 | None = None + ) -> None: + """Update the current exploration direction based on robot movement or selected goal.""" + if goal_pose is not None: + # Calculate direction from robot to goal + direction = Vector3(goal_pose.x - robot_pose.x, goal_pose.y - robot_pose.y, 0.0) + magnitude = np.sqrt(direction.x**2 + direction.y**2) + if magnitude > 0.1: # Avoid division by zero for very close goals + self.exploration_direction = Vector3( + direction.x / magnitude, direction.y / magnitude, 0.0 + ) + + def _compute_direction_momentum_score(self, frontier: Vector3, robot_pose: Vector3) -> float: + """Compute direction momentum score for a frontier.""" + if self.exploration_direction.x == 0 and self.exploration_direction.y == 0: + return 0.0 # No momentum if no previous direction + + # Calculate direction from robot to frontier + frontier_direction = Vector3(frontier.x - robot_pose.x, frontier.y - robot_pose.y, 0.0) + magnitude = np.sqrt(frontier_direction.x**2 + frontier_direction.y**2) + + if magnitude < 0.1: + return 0.0 # Too close to calculate meaningful direction + + # Normalize frontier direction + frontier_direction = Vector3( + frontier_direction.x / magnitude, frontier_direction.y / magnitude, 0.0 + ) + + # Calculate dot product for directional alignment + dot_product = ( + self.exploration_direction.x * frontier_direction.x + + self.exploration_direction.y * frontier_direction.y + ) + + # Return momentum score (higher for same direction, lower for opposite) + return max(0.0, dot_product) # Only positive momentum, no penalty for different directions + + def _compute_distance_to_explored_goals(self, frontier: Vector3) -> float: + """Compute distance from frontier to the nearest explored goal.""" + if not self.explored_goals: + return 5.0 # Default consistent value when no explored goals + # Calculate distance to nearest explored goal + min_distance = float("inf") + for goal in self.explored_goals: + distance = np.sqrt((frontier.x - goal.x) ** 2 + (frontier.y - goal.y) ** 2) + min_distance = min(min_distance, distance) + + return min_distance + + def _compute_distance_to_obstacles(self, frontier: Vector3, costmap: OccupancyGrid) -> float: + """ + Compute the minimum distance from a frontier point to the nearest obstacle. + + Args: + frontier: Frontier point in world coordinates + costmap: Costmap to check for obstacles + + Returns: + Minimum distance to nearest obstacle in meters + """ + # Convert frontier to grid coordinates + grid_pos = costmap.world_to_grid(frontier) + grid_x, grid_y = int(grid_pos.x), int(grid_pos.y) + + # Check if frontier is within costmap bounds + if grid_x < 0 or grid_x >= costmap.width or grid_y < 0 or grid_y >= costmap.height: + return 0.0 # Consider out-of-bounds as obstacle + + min_distance = float("inf") + search_radius = ( + int(self.safe_distance / costmap.resolution) + 5 + ) # Search a bit beyond minimum + + # Search in a square around the frontier point + for dy in range(-search_radius, search_radius + 1): + for dx in range(-search_radius, search_radius + 1): + check_x = grid_x + dx + check_y = grid_y + dy + + # Skip if out of bounds + if ( + check_x < 0 + or check_x >= costmap.width + or check_y < 0 + or check_y >= costmap.height + ): + continue + + # Check if this cell is an obstacle + if costmap.grid[check_y, check_x] >= self.occupancy_threshold: + # Calculate distance in meters + distance = np.sqrt(dx**2 + dy**2) * costmap.resolution + min_distance = min(min_distance, distance) + + # If no obstacles found within search radius, return the safe distance + # This indicates the frontier is safely away from obstacles + return min_distance if min_distance != float("inf") else self.safe_distance + + def _compute_comprehensive_frontier_score( + self, frontier: Vector3, frontier_size: int, robot_pose: Vector3, costmap: OccupancyGrid + ) -> float: + """Compute comprehensive score considering multiple criteria.""" + + # 1. Distance from robot (preference for moderate distances) + robot_distance = get_distance(frontier, robot_pose) + + # Distance score: prefer moderate distances (not too close, not too far) + # Normalized to 0-1 range + distance_score = 1.0 / (1.0 + abs(robot_distance - self.lookahead_distance)) + + # 2. Information gain (frontier size) + # Normalize by a reasonable max frontier size + max_expected_frontier_size = self.min_frontier_perimeter / costmap.resolution * 10 + info_gain_score = min(frontier_size / max_expected_frontier_size, 1.0) + + # 3. Distance to explored goals (bonus for being far from explored areas) + # Normalize by a reasonable max distance (e.g., 10 meters) + explored_goals_distance = self._compute_distance_to_explored_goals(frontier) + explored_goals_score = min(explored_goals_distance / self.max_explored_distance, 1.0) + + # 4. Distance to obstacles (score based on safety) + # 0 = too close to obstacles, 1 = at or beyond safe distance + obstacles_distance = self._compute_distance_to_obstacles(frontier, costmap) + if obstacles_distance >= self.safe_distance: + obstacles_score = 1.0 # Fully safe + else: + obstacles_score = obstacles_distance / self.safe_distance # Linear penalty + + # 5. Direction momentum (already in 0-1 range from dot product) + momentum_score = self._compute_direction_momentum_score(frontier, robot_pose) + + logger.info( + f"Distance score: {distance_score:.2f}, Info gain: {info_gain_score:.2f}, Explored goals: {explored_goals_score:.2f}, Obstacles: {obstacles_score:.2f}, Momentum: {momentum_score:.2f}" + ) + + # Combine scores with consistent scaling + total_score = ( + 0.3 * info_gain_score # 30% information gain + + 0.3 * explored_goals_score # 30% distance from explored goals + + 0.2 * distance_score # 20% distance optimization + + 0.15 * obstacles_score # 15% distance from obstacles + + 0.05 * momentum_score # 5% direction momentum + ) + + return total_score + + def _rank_frontiers( + self, + frontier_centroids: list[Vector3], + frontier_sizes: list[int], + robot_pose: Vector3, + costmap: OccupancyGrid, + ) -> list[Vector3]: + """ + Find the single best frontier using comprehensive scoring and filtering. + + Args: + frontier_centroids: List of frontier centroids + frontier_sizes: List of frontier sizes + robot_pose: Current robot position + costmap: Costmap for additional analysis + + Returns: + List containing single best frontier, or empty list if none suitable + """ + if not frontier_centroids: + return [] + + valid_frontiers = [] + + for i, frontier in enumerate(frontier_centroids): + # Compute comprehensive score + frontier_size = frontier_sizes[i] if i < len(frontier_sizes) else 1 + score = self._compute_comprehensive_frontier_score( + frontier, frontier_size, robot_pose, costmap + ) + + valid_frontiers.append((frontier, score)) + + logger.info(f"Valid frontiers: {len(valid_frontiers)}") + + if not valid_frontiers: + return [] + + # Sort by score and return all valid frontiers (highest scores first) + valid_frontiers.sort(key=lambda x: x[1], reverse=True) + + # Extract just the frontiers (remove scores) and return as list + return [frontier for frontier, _ in valid_frontiers] + + def get_exploration_goal(self, robot_pose: Vector3, costmap: OccupancyGrid) -> Vector3 | None: + """ + Get the single best exploration goal using comprehensive frontier scoring. + + Args: + robot_pose: Current robot position in world coordinates + costmap: Costmap for additional analysis + + Returns: + Single best frontier goal in world coordinates, or None if no suitable frontiers found + """ + # Check if we should compare costmaps for information gain + if len(self.explored_goals) > 5 and self.last_costmap is not None: + current_info = self._count_costmap_information(costmap) + last_info = self._count_costmap_information(self.last_costmap) + + # Check if information increase meets minimum percentage threshold + if last_info > 0: # Avoid division by zero + info_increase_percent = (current_info - last_info) / last_info + if info_increase_percent < self.info_gain_threshold: + logger.info( + f"Information increase ({info_increase_percent:.2f}) below threshold ({self.info_gain_threshold:.2f})" + ) + logger.info( + f"Current information: {current_info}, Last information: {last_info}" + ) + self.no_gain_counter += 1 + if self.no_gain_counter >= self.num_no_gain_attempts: + logger.info( + f"No information gain for {self.no_gain_counter} consecutive attempts" + ) + self.no_gain_counter = 0 # Reset counter when stopping due to no gain + self.stop_exploration() + return None + else: + self.no_gain_counter = 0 + + # Always detect new frontiers to get most up-to-date information + # The new algorithm filters out explored areas and returns only the best frontier + frontiers = self.detect_frontiers(robot_pose, costmap) + + if not frontiers: + # Store current costmap before returning + self.last_costmap = costmap # type: ignore[assignment] + self.reset_exploration_session() + return None + + # Update exploration direction based on best goal selection + if frontiers: + self._update_exploration_direction(robot_pose, frontiers[0]) + + # Store the selected goal as explored + selected_goal = frontiers[0] + self.mark_explored_goal(selected_goal) + + # Store current costmap for next comparison + self.last_costmap = costmap # type: ignore[assignment] + + return selected_goal + + # Store current costmap before returning + self.last_costmap = costmap # type: ignore[assignment] + return None + + def mark_explored_goal(self, goal: Vector3) -> None: + """Mark a goal as explored.""" + self.explored_goals.append(goal) + + def reset_exploration_session(self) -> None: + """ + Reset all exploration state variables for a new exploration session. + + Call this method when starting a new exploration or when the robot + needs to forget its previous exploration history. + """ + self.explored_goals.clear() # Clear all previously explored goals + self.exploration_direction = Vector3(0.0, 0.0, 0.0) # Reset exploration direction + self.last_costmap = None # Clear last costmap comparison + self.no_gain_counter = 0 # Reset no-gain attempt counter + self._cache.clear() # Clear frontier point cache + + logger.info("Exploration session reset - all state variables cleared") + + @rpc + def explore(self) -> bool: + """ + Start autonomous frontier exploration. + + Returns: + bool: True if exploration started, False if already exploring + """ + if self.exploration_active: + logger.warning("Exploration already active") + return False + + self.exploration_active = True + self.stop_event.clear() + + # Start exploration thread + self.exploration_thread = threading.Thread(target=self._exploration_loop, daemon=True) + self.exploration_thread.start() + + logger.info("Started autonomous frontier exploration") + return True + + @rpc + def stop_exploration(self) -> bool: + """ + Stop autonomous frontier exploration. + + Returns: + bool: True if exploration was stopped, False if not exploring + """ + if not self.exploration_active: + return False + + self.exploration_active = False + self.no_gain_counter = 0 # Reset counter when exploration stops + self.stop_event.set() + + # Only join if we're NOT being called from the exploration thread itself + if ( + self.exploration_thread + and self.exploration_thread.is_alive() + and threading.current_thread() != self.exploration_thread + ): + self.exploration_thread.join(timeout=2.0) + + logger.info("Stopped autonomous frontier exploration") + return True + + @rpc + def is_exploration_active(self) -> bool: + return self.exploration_active + + def _exploration_loop(self) -> None: + """Main exploration loop running in separate thread.""" + # Track number of goals published + goals_published = 0 + consecutive_failures = 0 + max_consecutive_failures = 10 # Allow more attempts before giving up + + while self.exploration_active and not self.stop_event.is_set(): + # Check if we have required data + if self.latest_costmap is None or self.latest_odometry is None: + threading.Event().wait(0.5) + continue + + # Get robot pose from odometry + robot_pose = Vector3( + self.latest_odometry.position.x, self.latest_odometry.position.y, 0.0 + ) + + # Get exploration goal + costmap = self.latest_costmap.inflate(0.25) + goal = self.get_exploration_goal(robot_pose, costmap) + + if goal: + # Publish goal to navigator + goal_msg = PoseStamped() + goal_msg.position.x = goal.x + goal_msg.position.y = goal.y + goal_msg.position.z = 0.0 + goal_msg.orientation.w = 1.0 # No rotation + goal_msg.frame_id = "world" + goal_msg.ts = self.latest_costmap.ts + + self.goal_request.publish(goal_msg) + logger.info(f"Published frontier goal: ({goal.x:.2f}, {goal.y:.2f})") + + goals_published += 1 + consecutive_failures = 0 # Reset failure counter on success + + # Clear the goal reached event for next iteration + self.goal_reached_event.clear() + + # Wait for goal to be reached or timeout + logger.info("Waiting for goal to be reached...") + goal_reached = self.goal_reached_event.wait(timeout=self.goal_timeout) + + if goal_reached: + logger.info("Goal reached, finding next frontier") + else: + logger.warning("Goal timeout after 30 seconds, finding next frontier anyway") + else: + consecutive_failures += 1 + + # Only give up if we've published at least 2 goals AND had many consecutive failures + if goals_published >= 2 and consecutive_failures >= max_consecutive_failures: + logger.info( + f"Exploration complete after {goals_published} goals and {consecutive_failures} consecutive failures finding new frontiers" + ) + self.exploration_active = False + break + elif goals_published < 2: + logger.info( + f"No frontier found, but only {goals_published} goals published so far. Retrying in 2 seconds..." + ) + threading.Event().wait(2.0) + else: + logger.info( + f"No frontier found (attempt {consecutive_failures}/{max_consecutive_failures}). Retrying in 2 seconds..." + ) + threading.Event().wait(2.0) + + +wavefront_frontier_explorer = WavefrontFrontierExplorer.blueprint + +__all__ = ["WavefrontFrontierExplorer", "wavefront_frontier_explorer"] diff --git a/dimos/navigation/global_planner/__init__.py b/dimos/navigation/global_planner/__init__.py new file mode 100644 index 0000000000..275619659b --- /dev/null +++ b/dimos/navigation/global_planner/__init__.py @@ -0,0 +1,4 @@ +from dimos.navigation.global_planner.algo import astar +from dimos.navigation.global_planner.planner import AstarPlanner, astar_planner + +__all__ = ["AstarPlanner", "astar", "astar_planner"] diff --git a/dimos/navigation/global_planner/algo.py b/dimos/navigation/global_planner/algo.py new file mode 100644 index 0000000000..6b65d05dbd --- /dev/null +++ b/dimos/navigation/global_planner/algo.py @@ -0,0 +1,215 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import heapq + +from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, VectorLike +from dimos.msgs.nav_msgs import CostValues, OccupancyGrid, Path +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +def astar( + costmap: OccupancyGrid, + goal: VectorLike, + start: VectorLike = (0.0, 0.0), + cost_threshold: int = 90, + unknown_penalty: float = 0.8, +) -> Path | None: + """ + A* path planning algorithm from start to goal position. + + Args: + costmap: Costmap object containing the environment + goal: Goal position as any vector-like object + start: Start position as any vector-like object (default: origin [0,0]) + cost_threshold: Cost threshold above which a cell is considered an obstacle + + Returns: + Path object containing waypoints, or None if no path found + """ + + # Convert world coordinates to grid coordinates directly using vector-like inputs + start_vector = costmap.world_to_grid(start) + goal_vector = costmap.world_to_grid(goal) + logger.debug(f"ASTAR {costmap} {start_vector} -> {goal_vector}") + + # Store positions as tuples for dictionary keys + start_tuple = (int(start_vector.x), int(start_vector.y)) + goal_tuple = (int(goal_vector.x), int(goal_vector.y)) + + # Check if goal is out of bounds + if not (0 <= goal_tuple[0] < costmap.width and 0 <= goal_tuple[1] < costmap.height): + return None + + # Define possible movements (8-connected grid with diagonal movements) + directions = [ + (0, 1), + (1, 0), + (0, -1), + (-1, 0), + (1, 1), + (1, -1), + (-1, 1), + (-1, -1), + ] + + # Cost for each movement (straight vs diagonal) + sc = 1.0 # Straight cost + dc = 1.42 # Diagonal cost (approximately sqrt(2)) + movement_costs = [sc, sc, sc, sc, dc, dc, dc, dc] + + # A* algorithm implementation + open_set = [] # type: ignore[var-annotated] # Priority queue for nodes to explore + closed_set = set() # Set of explored nodes + + # Dictionary to store cost from start and parents for each node + g_score = {start_tuple: 0} + parents = {} # type: ignore[var-annotated] + + # Heuristic function (Octile distance for 8-connected grid) + def heuristic(x1, y1, x2, y2): # type: ignore[no-untyped-def] + dx = abs(x2 - x1) + dy = abs(y2 - y1) + # Octile distance: optimal for 8-connected grids with diagonal movement + return (dx + dy) + (dc - 2 * sc) * min(dx, dy) + + # Start with the starting node + f_score = g_score[start_tuple] + heuristic( # type: ignore[no-untyped-call] + start_tuple[0], start_tuple[1], goal_tuple[0], goal_tuple[1] + ) + heapq.heappush(open_set, (f_score, start_tuple)) + + # Track nodes already in open set to avoid duplicates + open_set_hash = {start_tuple} + + while open_set: + # Get the node with the lowest f_score + _current_f, current = heapq.heappop(open_set) + current_x, current_y = current + + # Remove from open set hash + if current in open_set_hash: + open_set_hash.remove(current) + + # Skip if already processed (can happen with duplicate entries) + if current in closed_set: + continue + + # Check if we've reached the goal + if current == goal_tuple: + # Reconstruct the path + waypoints = [] + while current in parents: + world_point = costmap.grid_to_world(current) + # Create PoseStamped with identity quaternion (no orientation) + pose = PoseStamped( + frame_id="world", + position=[world_point.x, world_point.y, 0.0], + orientation=Quaternion(0, 0, 0, 1), # Identity quaternion + ) + waypoints.append(pose) + current = parents[current] + + # Add the start position + start_world_point = costmap.grid_to_world(start_tuple) + start_pose = PoseStamped( + frame_id="world", + position=[start_world_point.x, start_world_point.y, 0.0], + orientation=Quaternion(0, 0, 0, 1), + ) + waypoints.append(start_pose) + + # Reverse the path (start to goal) + waypoints.reverse() + + # Add the goal position if it's not already included + goal_point = costmap.grid_to_world(goal_tuple) + + if ( + not waypoints + or (waypoints[-1].x - goal_point.x) ** 2 + (waypoints[-1].y - goal_point.y) ** 2 + > 1e-10 + ): + goal_pose = PoseStamped( + frame_id="world", + position=[goal_point.x, goal_point.y, 0.0], + orientation=Quaternion(0, 0, 0, 1), + ) + waypoints.append(goal_pose) + + return Path(frame_id="world", poses=waypoints) + + # Add current node to closed set + closed_set.add(current) + + # Explore neighbors + for i, (dx, dy) in enumerate(directions): + neighbor_x, neighbor_y = current_x + dx, current_y + dy + neighbor = (neighbor_x, neighbor_y) + + # Check if the neighbor is valid + if not (0 <= neighbor_x < costmap.width and 0 <= neighbor_y < costmap.height): + continue + + # Check if the neighbor is already explored + if neighbor in closed_set: + continue + + # Get the neighbor's cost value + neighbor_val = costmap.grid[neighbor_y, neighbor_x] + + # Skip if it's a hard obstacle + if neighbor_val >= cost_threshold: + continue + + # Calculate movement cost with penalties + # Unknown cells get half the penalty of obstacles + if neighbor_val == CostValues.UNKNOWN: # Unknown cell (-1) + # Unknown cells have a moderate traversal cost (half of obstacle threshold) + cell_cost = cost_threshold * unknown_penalty + elif neighbor_val == CostValues.FREE: # Free space (0) + # Free cells have minimal cost + cell_cost = 0.0 + else: + # Other cells use their actual cost value (1-99) + cell_cost = neighbor_val + + # Calculate cost penalty based on cell cost (higher cost = higher penalty) + # This encourages the planner to prefer lower-cost paths + cost_penalty = cell_cost / CostValues.OCCUPIED # Normalized penalty (divide by 100) + + tentative_g_score = g_score[current] + movement_costs[i] * (1.0 + cost_penalty) + + # Get the current g_score for the neighbor or set to infinity if not yet explored + neighbor_g_score = g_score.get(neighbor, float("inf")) + + # If this path to the neighbor is better than any previous one + if tentative_g_score < neighbor_g_score: + # Update the neighbor's scores and parent + parents[neighbor] = current + g_score[neighbor] = tentative_g_score # type: ignore[assignment] + f_score = tentative_g_score + heuristic( # type: ignore[no-untyped-call] + neighbor_x, neighbor_y, goal_tuple[0], goal_tuple[1] + ) + + # Add the neighbor to the open set with its f_score + # Only add if not already in open set to reduce duplicates + if neighbor not in open_set_hash: + heapq.heappush(open_set, (f_score, neighbor)) + open_set_hash.add(neighbor) + + # If we get here, no path was found + return None diff --git a/dimos/navigation/global_planner/planner.py b/dimos/navigation/global_planner/planner.py new file mode 100644 index 0000000000..fe9fd7d475 --- /dev/null +++ b/dimos/navigation/global_planner/planner.py @@ -0,0 +1,238 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A* global path planning for reactive navigation. + +This module provides `AstarPlanner`, a DimOS Module that computes +obstacle-avoiding paths from the robot's current position to a target goal. +""" + +from reactivex.disposable import Disposable + +from dimos.core import In, Module, Out, rpc +from dimos.msgs.geometry_msgs import Pose, PoseStamped +from dimos.msgs.nav_msgs import OccupancyGrid, Path +from dimos.navigation.global_planner.algo import astar +from dimos.utils.logging_config import setup_logger +from dimos.utils.transform_utils import euler_to_quaternion + +logger = setup_logger() + +import math + +from dimos.msgs.geometry_msgs import Quaternion, Vector3 + + +def add_orientations_to_path(path: Path, goal_orientation: Quaternion = None) -> Path: # type: ignore[assignment] + """Add orientations to path poses based on direction of movement. + + Args: + path: Path with poses to add orientations to + goal_orientation: Desired orientation for the final pose + + Returns: + Path with orientations added to all poses + """ + if not path.poses or len(path.poses) < 2: + return path + + # Calculate orientations for all poses except the last one + for i in range(len(path.poses) - 1): + current_pose = path.poses[i] + next_pose = path.poses[i + 1] + + # Calculate direction to next point + dx = next_pose.position.x - current_pose.position.x + dy = next_pose.position.y - current_pose.position.y + + # Calculate yaw angle + yaw = math.atan2(dy, dx) + + # Convert to quaternion (roll=0, pitch=0, yaw) + orientation = euler_to_quaternion(Vector3(0, 0, yaw)) + current_pose.orientation = orientation + + # Set last pose orientation + identity_quat = Quaternion(0, 0, 0, 1) + if goal_orientation is not None and goal_orientation != identity_quat: + # Use the provided goal orientation if it's not the identity + path.poses[-1].orientation = goal_orientation + elif len(path.poses) > 1: + # Use the previous pose's orientation + path.poses[-1].orientation = path.poses[-2].orientation + else: + # Single pose with identity goal orientation + path.poses[-1].orientation = identity_quat + + return path + + +def resample_path(path: Path, spacing: float) -> Path: + """Resample a path to have approximately uniform spacing between poses. + + Args: + path: The original Path + spacing: Desired distance between consecutive poses + + Returns: + A new Path with resampled poses + """ + if len(path) < 2 or spacing <= 0: + return path + + resampled = [] + resampled.append(path.poses[0]) + + accumulated_distance = 0.0 + + for i in range(1, len(path.poses)): + current = path.poses[i] + prev = path.poses[i - 1] + + # Calculate segment distance + dx = current.x - prev.x + dy = current.y - prev.y + segment_length = (dx**2 + dy**2) ** 0.5 + + if segment_length < 1e-10: + continue + + # Direction vector + dir_x = dx / segment_length + dir_y = dy / segment_length + + # Add points along this segment + while accumulated_distance + segment_length >= spacing: + # Distance along segment for next point + dist_along = spacing - accumulated_distance + if dist_along < 0: + break + + # Create new pose + new_x = prev.x + dir_x * dist_along + new_y = prev.y + dir_y * dist_along + new_pose = PoseStamped( + frame_id=path.frame_id, + position=[new_x, new_y, 0.0], + orientation=prev.orientation, # Keep same orientation + ) + resampled.append(new_pose) + + # Update for next iteration + accumulated_distance = 0 + segment_length -= dist_along + prev = new_pose + + accumulated_distance += segment_length + + # Add last pose if not already there + if len(path.poses) > 1: + last = path.poses[-1] + if not resampled or (resampled[-1].x != last.x or resampled[-1].y != last.y): + resampled.append(last) + + return Path(frame_id=path.frame_id, poses=resampled) + + +class AstarPlanner(Module): + """Compute collision-free paths using A* on the global costmap. + + Publishes a path when target, costmap, and odometry are all available + and A* finds a valid route. No output published otherwise. + + To control the robot's facing direction at the goal, pass a non-identity orientation in the target. + Otherwise, the robot faces the direction of travel. + """ + + # LCM inputs + target: In[PoseStamped] = None # type: ignore[assignment] + global_costmap: In[OccupancyGrid] = None # type: ignore[assignment] + odom: In[PoseStamped] = None # type: ignore[assignment] + + # LCM outputs + path: Out[Path] = None # type: ignore[assignment] + + def __init__(self) -> None: + super().__init__() + + # Latest data + self.latest_costmap: OccupancyGrid | None = None + self.latest_odom: PoseStamped | None = None + + @rpc + def start(self) -> None: + super().start() + + unsub = self.target.subscribe(self._on_target) + self._disposables.add(Disposable(unsub)) + + unsub = self.global_costmap.subscribe(self._on_costmap) + self._disposables.add(Disposable(unsub)) + + unsub = self.odom.subscribe(self._on_odom) + self._disposables.add(Disposable(unsub)) + + logger.info("A* planner started") + + @rpc + def stop(self) -> None: + super().stop() + + def _on_costmap(self, msg: OccupancyGrid) -> None: + """Handle incoming costmap messages.""" + self.latest_costmap = msg + + def _on_odom(self, msg: PoseStamped) -> None: + """Handle incoming odometry messages.""" + self.latest_odom = msg + + def _on_target(self, msg: PoseStamped) -> None: + """Handle incoming target messages and trigger planning.""" + if self.latest_costmap is None or self.latest_odom is None: + logger.warning("Cannot plan: missing costmap or odometry data") + return + + path = self.plan(msg) + if path: + # Add orientations to the path, using the goal's orientation for the final pose + path = add_orientations_to_path(path, msg.orientation) + self.path.publish(path) + + def plan(self, goal: Pose) -> Path | None: + """Plan a path from current position to goal.""" + if self.latest_costmap is None or self.latest_odom is None: + logger.warning("Cannot plan: missing costmap or odometry data") + return None + + logger.debug(f"Planning path to goal {goal}") + + # Get current position from odometry + robot_pos = self.latest_odom.position + costmap = self.latest_costmap.inflate(0.2).gradient(max_distance=1.5) + + # Run A* planning + path = astar(costmap, goal.position, robot_pos) + + if path: + path = resample_path(path, 0.1) + logger.debug(f"Path found with {len(path.poses)} waypoints") + return path + + logger.warning("No path found to the goal.") + return None + + +astar_planner = AstarPlanner.blueprint + +__all__ = ["AstarPlanner", "astar_planner"] diff --git a/dimos/navigation/local_planner/__init__.py b/dimos/navigation/local_planner/__init__.py new file mode 100644 index 0000000000..9e0f62931a --- /dev/null +++ b/dimos/navigation/local_planner/__init__.py @@ -0,0 +1,2 @@ +from dimos.navigation.local_planner.holonomic_local_planner import HolonomicLocalPlanner +from dimos.navigation.local_planner.local_planner import BaseLocalPlanner diff --git a/dimos/navigation/local_planner/holonomic_local_planner.py b/dimos/navigation/local_planner/holonomic_local_planner.py new file mode 100644 index 0000000000..212c625dda --- /dev/null +++ b/dimos/navigation/local_planner/holonomic_local_planner.py @@ -0,0 +1,276 @@ +#!/usr/bin/env python3 + +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Gradient-Augmented Look-Ahead Pursuit (GLAP) holonomic local planner. +""" + +import numpy as np + +from dimos.core import rpc +from dimos.msgs.geometry_msgs import Twist, Vector3 +from dimos.navigation.local_planner.local_planner import BaseLocalPlanner +from dimos.utils.transform_utils import get_distance, normalize_angle, quaternion_to_euler + + +class HolonomicLocalPlanner(BaseLocalPlanner): + """ + Gradient-Augmented Look-Ahead Pursuit (GLAP) holonomic local planner. + + This planner combines path following with obstacle avoidance using + costmap gradients to produce smooth holonomic velocity commands. + + Args: + lookahead_dist: Look-ahead distance in meters (default: 1.0) + k_rep: Repulsion gain for obstacle avoidance (default: 1.0) + alpha: Low-pass filter coefficient [0-1] (default: 0.5) + v_max: Maximum velocity per component in m/s (default: 0.8) + goal_tolerance: Distance threshold to consider goal reached (default: 0.5) + control_frequency: Control loop frequency in Hz (default: 10.0) + """ + + def __init__( # type: ignore[no-untyped-def] + self, + lookahead_dist: float = 1.0, + k_rep: float = 0.5, + k_angular: float = 0.75, + alpha: float = 0.5, + v_max: float = 0.8, + goal_tolerance: float = 0.5, + orientation_tolerance: float = 0.2, + control_frequency: float = 10.0, + **kwargs, + ) -> None: + """Initialize the GLAP planner with specified parameters.""" + super().__init__( + goal_tolerance=goal_tolerance, + orientation_tolerance=orientation_tolerance, + control_frequency=control_frequency, + **kwargs, + ) + + # Algorithm parameters + self.lookahead_dist = lookahead_dist + self.k_rep = k_rep + self.alpha = alpha + self.v_max = v_max + self.k_angular = k_angular + + # Previous velocity for filtering (vx, vy, vtheta) + self.v_prev = np.array([0.0, 0.0, 0.0]) + + @rpc + def start(self) -> None: + super().start() + + @rpc + def stop(self) -> None: + super().stop() + + def compute_velocity(self) -> Twist | None: + """ + Compute velocity commands using GLAP algorithm. + + Returns: + Twist with linear and angular velocities in robot frame + """ + if self.latest_odom is None or self.latest_path is None or self.latest_costmap is None: + return None + + pose = np.array([self.latest_odom.position.x, self.latest_odom.position.y]) + + euler = quaternion_to_euler(self.latest_odom.orientation) + robot_yaw = euler.z + + path_points = [] + for pose_stamped in self.latest_path.poses: + path_points.append([pose_stamped.position.x, pose_stamped.position.y]) + + if len(path_points) == 0: + return None + + path = np.array(path_points) + + costmap = self.latest_costmap.grid + + v_follow_odom = self._compute_path_following(pose, path) + + v_rep_odom = self._compute_obstacle_repulsion(pose, costmap) + + v_odom = v_follow_odom + v_rep_odom + + # Transform velocity from odom frame to robot frame + cos_yaw = np.cos(robot_yaw) + sin_yaw = np.sin(robot_yaw) + + v_robot_x = cos_yaw * v_odom[0] + sin_yaw * v_odom[1] + v_robot_y = -sin_yaw * v_odom[0] + cos_yaw * v_odom[1] + + # Compute angular velocity + closest_idx, _ = self._find_closest_point_on_path(pose, path) + + # Check if we're near the final goal + goal_pose = self.latest_path.poses[-1] + distance_to_goal = get_distance(self.latest_odom, goal_pose) + + if distance_to_goal < self.goal_tolerance: + # Near goal - rotate to match final goal orientation + goal_euler = quaternion_to_euler(goal_pose.orientation) + desired_yaw = goal_euler.z + else: + # Not near goal - align with path direction + lookahead_point = self._find_lookahead_point(path, closest_idx) + dx = lookahead_point[0] - pose[0] + dy = lookahead_point[1] - pose[1] + desired_yaw = np.arctan2(dy, dx) + + yaw_error = normalize_angle(desired_yaw - robot_yaw) + k_angular = self.k_angular + v_theta = k_angular * yaw_error + + # Slow down linear velocity when turning + # Scale linear velocity based on angular velocity magnitude + angular_speed = abs(v_theta) + max_angular_speed = self.v_max + + # Calculate speed reduction factor (1.0 when not turning, 0.2 when at max turn rate) + turn_slowdown = 1.0 - 0.8 * min(angular_speed / max_angular_speed, 1.0) + + # Apply speed reduction to linear velocities + v_robot_x = np.clip(v_robot_x * turn_slowdown, -self.v_max, self.v_max) + v_robot_y = np.clip(v_robot_y * turn_slowdown, -self.v_max, self.v_max) + v_theta = np.clip(v_theta, -self.v_max, self.v_max) + + v_raw = np.array([v_robot_x, v_robot_y, v_theta]) + v_filtered = self.alpha * v_raw + (1 - self.alpha) * self.v_prev + self.v_prev = v_filtered + + return Twist( + linear=Vector3(v_filtered[0], v_filtered[1], 0.0), + angular=Vector3(0.0, 0.0, v_filtered[2]), + ) + + def _compute_path_following(self, pose: np.ndarray, path: np.ndarray) -> np.ndarray: # type: ignore[type-arg] + """ + Compute path following velocity using pure pursuit. + + Args: + pose: Current robot position [x, y] + path: Path waypoints as Nx2 array + + Returns: + Path following velocity vector [vx, vy] + """ + closest_idx, _ = self._find_closest_point_on_path(pose, path) + + carrot = self._find_lookahead_point(path, closest_idx) + + direction = carrot - pose + distance = np.linalg.norm(direction) + + if distance < 1e-6: + return np.zeros(2) + + v_follow = self.v_max * direction / distance + + return v_follow # type: ignore[no-any-return] + + def _compute_obstacle_repulsion(self, pose: np.ndarray, costmap: np.ndarray) -> np.ndarray: # type: ignore[type-arg] + """ + Compute obstacle repulsion velocity from costmap gradient. + + Args: + pose: Current robot position [x, y] + costmap: 2D costmap array + + Returns: + Repulsion velocity vector [vx, vy] + """ + grid_point = self.latest_costmap.world_to_grid(pose) # type: ignore[union-attr] + grid_x = int(grid_point.x) + grid_y = int(grid_point.y) + + height, width = costmap.shape + if not (1 <= grid_x < width - 1 and 1 <= grid_y < height - 1): + return np.zeros(2) + + # Compute gradient using central differences + # Note: costmap is in row-major order (y, x) + gx = (costmap[grid_y, grid_x + 1] - costmap[grid_y, grid_x - 1]) / ( + 2.0 * self.latest_costmap.resolution # type: ignore[union-attr] + ) + gy = (costmap[grid_y + 1, grid_x] - costmap[grid_y - 1, grid_x]) / ( + 2.0 * self.latest_costmap.resolution # type: ignore[union-attr] + ) + + # Gradient points towards higher cost, so negate for repulsion + v_rep = -self.k_rep * np.array([gx, gy]) + + return v_rep + + def _find_closest_point_on_path( + self, + pose: np.ndarray, # type: ignore[type-arg] + path: np.ndarray, # type: ignore[type-arg] + ) -> tuple[int, np.ndarray]: # type: ignore[type-arg] + """ + Find the closest point on the path to current pose. + + Args: + pose: Current position [x, y] + path: Path waypoints as Nx2 array + + Returns: + Tuple of (closest_index, closest_point) + """ + distances = np.linalg.norm(path - pose, axis=1) + closest_idx = np.argmin(distances) + return closest_idx, path[closest_idx] # type: ignore[return-value] + + def _find_lookahead_point(self, path: np.ndarray, start_idx: int) -> np.ndarray: # type: ignore[type-arg] + """ + Find look-ahead point on path at specified distance. + + Args: + path: Path waypoints as Nx2 array + start_idx: Starting index for search + + Returns: + Look-ahead point [x, y] + """ + accumulated_dist = 0.0 + + for i in range(start_idx, len(path) - 1): + segment_dist = np.linalg.norm(path[i + 1] - path[i]) + + if accumulated_dist + segment_dist >= self.lookahead_dist: + remaining_dist = self.lookahead_dist - accumulated_dist + t = remaining_dist / segment_dist + carrot = path[i] + t * (path[i + 1] - path[i]) + return carrot # type: ignore[no-any-return] + + accumulated_dist += segment_dist # type: ignore[assignment] + + return path[-1] # type: ignore[no-any-return] + + def _clip(self, v: np.ndarray) -> np.ndarray: # type: ignore[type-arg] + """Instance method to clip velocity with access to v_max.""" + return np.clip(v, -self.v_max, self.v_max) + + +holonomic_local_planner = HolonomicLocalPlanner.blueprint + +__all__ = ["HolonomicLocalPlanner", "holonomic_local_planner"] diff --git a/dimos/navigation/local_planner/local_planner.py b/dimos/navigation/local_planner/local_planner.py new file mode 100644 index 0000000000..0d798596b1 --- /dev/null +++ b/dimos/navigation/local_planner/local_planner.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 + +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Base Local Planner Module for robot navigation. +Subscribes to local costmap, odometry, and path, publishes movement commands. +""" + +from abc import abstractmethod +import threading +import time + +from reactivex.disposable import Disposable + +from dimos.core import In, Module, Out, rpc +from dimos.msgs.geometry_msgs import PoseStamped, Twist +from dimos.msgs.nav_msgs import OccupancyGrid, Path +from dimos.utils.logging_config import setup_logger +from dimos.utils.transform_utils import get_distance, normalize_angle, quaternion_to_euler + +logger = setup_logger() + + +class BaseLocalPlanner(Module): + """ + local planner module for obstacle avoidance and path following. + + Subscribes to: + - /local_costmap: Local occupancy grid for obstacle detection + - /odom: Robot odometry for current pose + - /path: Path to follow (continuously updated at ~1Hz) + + Publishes: + - /cmd_vel: Velocity commands for robot movement + """ + + # LCM inputs + local_costmap: In[OccupancyGrid] = None # type: ignore[assignment] + odom: In[PoseStamped] = None # type: ignore[assignment] + path: In[Path] = None # type: ignore[assignment] + + # LCM outputs + cmd_vel: Out[Twist] = None # type: ignore[assignment] + + def __init__( # type: ignore[no-untyped-def] + self, + goal_tolerance: float = 0.5, + orientation_tolerance: float = 0.2, + control_frequency: float = 10.0, + **kwargs, + ) -> None: + """Initialize the local planner module. + + Args: + goal_tolerance: Distance threshold to consider goal reached (meters) + orientation_tolerance: Orientation threshold to consider goal reached (radians) + control_frequency: Frequency for control loop (Hz) + """ + super().__init__(**kwargs) + + # Parameters + self.goal_tolerance = goal_tolerance + self.orientation_tolerance = orientation_tolerance + self.control_frequency = control_frequency + self.control_period = 1.0 / control_frequency + + # Latest data + self.latest_costmap: OccupancyGrid | None = None + self.latest_odom: PoseStamped | None = None + self.latest_path: Path | None = None + + # Control thread + self.planning_thread: threading.Thread | None = None + self.stop_planning = threading.Event() + + logger.info("Local planner module initialized") + + @rpc + def start(self) -> None: + super().start() + + unsub = self.local_costmap.subscribe(self._on_costmap) + self._disposables.add(Disposable(unsub)) + + unsub = self.odom.subscribe(self._on_odom) + self._disposables.add(Disposable(unsub)) + + unsub = self.path.subscribe(self._on_path) + self._disposables.add(Disposable(unsub)) + + @rpc + def stop(self) -> None: + self.cancel_planning() + super().stop() + + def _on_costmap(self, msg: OccupancyGrid) -> None: + self.latest_costmap = msg + + def _on_odom(self, msg: PoseStamped) -> None: + self.latest_odom = msg + + def _on_path(self, msg: Path) -> None: + self.latest_path = msg + + if msg and len(msg.poses) > 0: + if self.planning_thread is None or not self.planning_thread.is_alive(): + self._start_planning_thread() + + def _start_planning_thread(self) -> None: + """Start the planning thread.""" + self.stop_planning.clear() + self.planning_thread = threading.Thread(target=self._follow_path_loop, daemon=True) + self.planning_thread.start() + logger.debug("Started follow path thread") + + def _follow_path_loop(self) -> None: + """Main planning loop that runs in a separate thread.""" + while not self.stop_planning.is_set(): + if self.is_goal_reached(): + self.stop_planning.set() + stop_cmd = Twist() + self.cmd_vel.publish(stop_cmd) + break + + # Compute and publish velocity + self._plan() + + time.sleep(self.control_period) + + def _plan(self) -> None: + """Compute and publish velocity command.""" + cmd_vel = self.compute_velocity() + + if cmd_vel is not None: + self.cmd_vel.publish(cmd_vel) + + @abstractmethod + def compute_velocity(self) -> Twist | None: + """ + Compute velocity commands based on current costmap, odometry, and path. + Must be implemented by derived classes. + + Returns: + Twist message with linear and angular velocity commands, or None if no command + """ + pass + + @rpc + def is_goal_reached(self) -> bool: + """ + Check if the robot has reached the goal position and orientation. + + Returns: + True if goal is reached within tolerance, False otherwise + """ + if self.latest_odom is None or self.latest_path is None: + return False + + if len(self.latest_path.poses) == 0: + return True + + goal_pose = self.latest_path.poses[-1] + distance = get_distance(self.latest_odom, goal_pose) + + # Check distance tolerance + if distance >= self.goal_tolerance: + return False + + # Check orientation tolerance + current_euler = quaternion_to_euler(self.latest_odom.orientation) + goal_euler = quaternion_to_euler(goal_pose.orientation) + + # Calculate yaw difference and normalize to [-pi, pi] + yaw_error = normalize_angle(goal_euler.z - current_euler.z) + + return abs(yaw_error) < self.orientation_tolerance + + @rpc + def reset(self) -> None: + """Reset the local planner state, clearing the current path.""" + # Clear the latest path + self.latest_path = None + self.latest_odom = None + self.latest_costmap = None + self.cancel_planning() + logger.info("Local planner reset") + + @rpc + def cancel_planning(self) -> None: + """Stop the local planner and any running threads.""" + if self.planning_thread and self.planning_thread.is_alive(): + self.stop_planning.set() + self.planning_thread.join(timeout=1.0) + self.planning_thread = None + stop_cmd = Twist() + self.cmd_vel.publish(stop_cmd) diff --git a/dimos/navigation/local_planner/test_base_local_planner.py b/dimos/navigation/local_planner/test_base_local_planner.py new file mode 100644 index 0000000000..59ce7cbda4 --- /dev/null +++ b/dimos/navigation/local_planner/test_base_local_planner.py @@ -0,0 +1,406 @@ +#!/usr/bin/env python3 + +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for the GLAP (Gradient-Augmented Look-Ahead Pursuit) holonomic local planner. +""" + +import numpy as np +import pytest + +from dimos.msgs.geometry_msgs import Pose, PoseStamped, Quaternion +from dimos.msgs.nav_msgs import OccupancyGrid, Path +from dimos.navigation.local_planner.holonomic_local_planner import HolonomicLocalPlanner + + +class TestHolonomicLocalPlanner: + """Test suite for HolonomicLocalPlanner.""" + + @pytest.fixture + def planner(self): + """Create a planner instance for testing.""" + planner = HolonomicLocalPlanner( + lookahead_dist=1.5, + k_rep=1.0, + alpha=1.0, # No filtering for deterministic tests + v_max=1.0, + goal_tolerance=0.5, + control_frequency=10.0, + ) + yield planner + # TODO: This should call `planner.stop()` but that causes errors. + # Calling just this for now to fix thread leaks. + planner._close_module() + + @pytest.fixture + def empty_costmap(self): + """Create an empty costmap (all free space).""" + costmap = OccupancyGrid( + grid=np.zeros((100, 100), dtype=np.int8), resolution=0.1, origin=Pose() + ) + costmap.origin.position.x = -5.0 + costmap.origin.position.y = -5.0 + return costmap + + def test_straight_path_no_obstacles(self, planner, empty_costmap) -> None: + """Test that planner follows straight path with no obstacles.""" + # Set current position at origin + planner.latest_odom = PoseStamped() + planner.latest_odom.position.x = 0.0 + planner.latest_odom.position.y = 0.0 + + # Create straight path along +X + path = Path() + for i in range(10): + ps = PoseStamped() + ps.position.x = float(i) + ps.position.y = 0.0 + ps.orientation.w = 1.0 # Identity quaternion + path.poses.append(ps) + planner.latest_path = path + + # Set empty costmap + planner.latest_costmap = empty_costmap + + # Compute velocity + vel = planner.compute_velocity() + + # Should move along +X + assert vel is not None + assert vel.linear.x > 0.9 # Close to v_max + assert abs(vel.linear.y) < 0.1 # Near zero + assert abs(vel.angular.z) < 0.1 # Small angular velocity when aligned with path + + def test_obstacle_gradient_repulsion(self, planner) -> None: + """Test that obstacle gradients create repulsive forces.""" + # Set position at origin + planner.latest_odom = PoseStamped() + planner.latest_odom.position.x = 0.0 + planner.latest_odom.position.y = 0.0 + + # Simple path forward + path = Path() + ps = PoseStamped() + ps.position.x = 5.0 + ps.position.y = 0.0 + ps.orientation.w = 1.0 + path.poses.append(ps) + planner.latest_path = path + + # Create costmap with gradient pointing south (higher cost north) + costmap_grid = np.zeros((100, 100), dtype=np.int8) + for i in range(100): + costmap_grid[i, :] = max(0, 50 - i) # Gradient from north to south + + planner.latest_costmap = OccupancyGrid(grid=costmap_grid, resolution=0.1, origin=Pose()) + planner.latest_costmap.origin.position.x = -5.0 + planner.latest_costmap.origin.position.y = -5.0 + + # Compute velocity + vel = planner.compute_velocity() + + # Should have positive Y component (pushed north by gradient) + assert vel is not None + assert vel.linear.y > 0.1 # Repulsion pushes north + + def test_lowpass_filter(self) -> None: + """Test that low-pass filter smooths velocity commands.""" + # Create planner with alpha=0.5 for filtering + planner = HolonomicLocalPlanner( + lookahead_dist=1.0, + k_rep=0.0, # No repulsion + alpha=0.5, # 50% filtering + v_max=1.0, + ) + + # Setup similar to straight path test + planner.latest_odom = PoseStamped() + planner.latest_odom.position.x = 0.0 + planner.latest_odom.position.y = 0.0 + + path = Path() + ps = PoseStamped() + ps.position.x = 5.0 + ps.position.y = 0.0 + ps.orientation.w = 1.0 + path.poses.append(ps) + planner.latest_path = path + + planner.latest_costmap = OccupancyGrid( + grid=np.zeros((100, 100), dtype=np.int8), resolution=0.1, origin=Pose() + ) + planner.latest_costmap.origin.position.x = -5.0 + planner.latest_costmap.origin.position.y = -5.0 + + # First call - previous velocity is zero + vel1 = planner.compute_velocity() + assert vel1 is not None + + # Store first velocity + first_vx = vel1.linear.x + + # Second call - should be filtered + vel2 = planner.compute_velocity() + assert vel2 is not None + + # With alpha=0.5 and same conditions: + # v2 = 0.5 * v_raw + 0.5 * v1 + # The filtering effect should be visible + # v2 should be between v1 and the raw velocity + assert vel2.linear.x != first_vx # Should be different due to filtering + assert 0 < vel2.linear.x <= planner.v_max # Should still be positive and within limits + planner._close_module() + + def test_no_path(self, planner, empty_costmap) -> None: + """Test that planner returns None when no path is available.""" + planner.latest_odom = PoseStamped() + planner.latest_costmap = empty_costmap + planner.latest_path = Path() # Empty path + + vel = planner.compute_velocity() + assert vel is None + + def test_no_odometry(self, planner, empty_costmap) -> None: + """Test that planner returns None when no odometry is available.""" + planner.latest_odom = None + planner.latest_costmap = empty_costmap + + path = Path() + ps = PoseStamped() + ps.position.x = 1.0 + ps.position.y = 0.0 + path.poses.append(ps) + planner.latest_path = path + + vel = planner.compute_velocity() + assert vel is None + + def test_no_costmap(self, planner) -> None: + """Test that planner returns None when no costmap is available.""" + planner.latest_odom = PoseStamped() + planner.latest_costmap = None + + path = Path() + ps = PoseStamped() + ps.position.x = 1.0 + ps.position.y = 0.0 + path.poses.append(ps) + planner.latest_path = path + + vel = planner.compute_velocity() + assert vel is None + + def test_goal_reached(self, planner, empty_costmap) -> None: + """Test velocity when robot is at goal.""" + # Set robot at goal position + planner.latest_odom = PoseStamped() + planner.latest_odom.position.x = 5.0 + planner.latest_odom.position.y = 0.0 + + # Path with single point at robot position + path = Path() + ps = PoseStamped() + ps.position.x = 5.0 + ps.position.y = 0.0 + ps.orientation.w = 1.0 + path.poses.append(ps) + planner.latest_path = path + + planner.latest_costmap = empty_costmap + + # Compute velocity + vel = planner.compute_velocity() + + # Should have near-zero velocity + assert vel is not None + assert abs(vel.linear.x) < 0.1 + assert abs(vel.linear.y) < 0.1 + + def test_velocity_saturation(self, planner, empty_costmap) -> None: + """Test that velocities are capped at v_max.""" + # Set robot far from goal to maximize commanded velocity + planner.latest_odom = PoseStamped() + planner.latest_odom.position.x = 0.0 + planner.latest_odom.position.y = 0.0 + + # Create path far away + path = Path() + ps = PoseStamped() + ps.position.x = 100.0 # Very far + ps.position.y = 0.0 + ps.orientation.w = 1.0 + path.poses.append(ps) + planner.latest_path = path + + planner.latest_costmap = empty_costmap + + # Compute velocity + vel = planner.compute_velocity() + + # Velocity should be saturated at v_max + assert vel is not None + assert abs(vel.linear.x) <= planner.v_max + 0.01 # Small tolerance + assert abs(vel.linear.y) <= planner.v_max + 0.01 + assert abs(vel.angular.z) <= planner.v_max + 0.01 + + def test_lookahead_interpolation(self, planner, empty_costmap) -> None: + """Test that lookahead point is correctly interpolated on path.""" + # Set robot at origin + planner.latest_odom = PoseStamped() + planner.latest_odom.position.x = 0.0 + planner.latest_odom.position.y = 0.0 + + # Create path with waypoints closer than lookahead distance + path = Path() + for i in range(5): + ps = PoseStamped() + ps.position.x = i * 0.5 # 0.5m spacing + ps.position.y = 0.0 + ps.orientation.w = 1.0 + path.poses.append(ps) + planner.latest_path = path + + planner.latest_costmap = empty_costmap + + # Compute velocity + vel = planner.compute_velocity() + + # Should move forward along path + assert vel is not None + assert vel.linear.x > 0.5 # Moving forward + assert abs(vel.linear.y) < 0.1 # Staying on path + + def test_curved_path_following(self, planner, empty_costmap) -> None: + """Test following a curved path.""" + # Set robot at origin + planner.latest_odom = PoseStamped() + planner.latest_odom.position.x = 0.0 + planner.latest_odom.position.y = 0.0 + + # Create curved path (quarter circle) + path = Path() + for i in range(10): + angle = (np.pi / 2) * (i / 9.0) # 0 to 90 degrees + ps = PoseStamped() + ps.position.x = 2.0 * np.cos(angle) + ps.position.y = 2.0 * np.sin(angle) + ps.orientation.w = 1.0 + path.poses.append(ps) + planner.latest_path = path + + planner.latest_costmap = empty_costmap + + # Compute velocity + vel = planner.compute_velocity() + + # Should have both X and Y components for curved motion + assert vel is not None + # Test general behavior: should be moving (not exact values) + assert vel.linear.x > 0 # Moving forward (any positive value) + assert vel.linear.y > 0 # Turning left (any positive value) + # Ensure we have meaningful movement, not just noise + total_linear = np.sqrt(vel.linear.x**2 + vel.linear.y**2) + assert total_linear > 0.1 # Some reasonable movement + + def test_robot_frame_transformation(self, empty_costmap) -> None: + """Test that velocities are correctly transformed to robot frame.""" + # Create planner with no filtering for deterministic test + planner = HolonomicLocalPlanner( + lookahead_dist=1.0, + k_rep=0.0, # No repulsion + alpha=1.0, # No filtering + v_max=1.0, + ) + + # Set robot at origin but rotated 90 degrees (facing +Y in odom frame) + planner.latest_odom = PoseStamped() + planner.latest_odom.position.x = 0.0 + planner.latest_odom.position.y = 0.0 + # Quaternion for 90 degree rotation around Z + planner.latest_odom.orientation = Quaternion(0.0, 0.0, 0.7071068, 0.7071068) + + # Create path along +X axis in odom frame + path = Path() + for i in range(5): + ps = PoseStamped() + ps.position.x = float(i) + ps.position.y = 0.0 + ps.orientation.w = 1.0 + path.poses.append(ps) + planner.latest_path = path + + planner.latest_costmap = empty_costmap + + # Compute velocity + vel = planner.compute_velocity() + + # Robot is facing +Y, path is along +X + # So in robot frame: forward is +Y direction, path is to the right + assert vel is not None + # Test relative magnitudes and signs rather than exact values + # Path is to the right, so Y velocity should be negative + assert vel.linear.y < 0 # Should move right (negative Y in robot frame) + # Should turn to align with path + assert vel.angular.z < 0 # Should turn right (negative angular velocity) + # X velocity should be relatively small compared to Y + assert abs(vel.linear.x) < abs(vel.linear.y) # Lateral movement dominates + planner._close_module() + + def test_angular_velocity_computation(self, empty_costmap) -> None: + """Test that angular velocity is computed to align with path.""" + planner = HolonomicLocalPlanner( + lookahead_dist=2.0, + k_rep=0.0, # No repulsion + alpha=1.0, # No filtering + v_max=1.0, + ) + + # Robot at origin facing +X + planner.latest_odom = PoseStamped() + planner.latest_odom.position.x = 0.0 + planner.latest_odom.position.y = 0.0 + planner.latest_odom.orientation.w = 1.0 # Identity quaternion + + # Create path at 45 degrees + path = Path() + for i in range(5): + ps = PoseStamped() + ps.position.x = float(i) + ps.position.y = float(i) # Diagonal path + ps.orientation.w = 1.0 + path.poses.append(ps) + planner.latest_path = path + + planner.latest_costmap = empty_costmap + + # Compute velocity + vel = planner.compute_velocity() + + # Path is at 45 degrees, robot facing 0 degrees + # Should have positive angular velocity to turn left + assert vel is not None + # Test general behavior without exact thresholds + assert vel.linear.x > 0 # Moving forward (any positive value) + assert vel.linear.y > 0 # Moving left (holonomic, any positive value) + assert vel.angular.z > 0 # Turning left (positive angular velocity) + # Verify the robot is actually moving with reasonable speed + total_linear = np.sqrt(vel.linear.x**2 + vel.linear.y**2) + assert total_linear > 0.1 # Some meaningful movement + # Since path is diagonal, X and Y should be similar magnitude + assert ( + abs(vel.linear.x - vel.linear.y) < max(vel.linear.x, vel.linear.y) * 0.5 + ) # Within 50% of each other + planner._close_module() diff --git a/dimos/navigation/rosnav.py b/dimos/navigation/rosnav.py new file mode 100644 index 0000000000..ffb5f9694b --- /dev/null +++ b/dimos/navigation/rosnav.py @@ -0,0 +1,510 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +NavBot class for navigation-related functionality. +Encapsulates ROS bridge and topic remapping for Unitree robots. +""" + +from collections.abc import Generator +from dataclasses import dataclass, field +import logging +import threading +import time + +from geometry_msgs.msg import ( # type: ignore[attr-defined] + PointStamped as ROSPointStamped, + PoseStamped as ROSPoseStamped, + TwistStamped as ROSTwistStamped, +) +from nav_msgs.msg import Path as ROSPath # type: ignore[attr-defined] +import rclpy +from rclpy.node import Node +from reactivex import operators as ops +from reactivex.subject import Subject +from sensor_msgs.msg import ( # type: ignore[attr-defined] + Joy as ROSJoy, + PointCloud2 as ROSPointCloud2, +) +from std_msgs.msg import Bool as ROSBool, Int8 as ROSInt8 # type: ignore[attr-defined] +from tf2_msgs.msg import TFMessage as ROSTFMessage # type: ignore[attr-defined] + +from dimos import spec +from dimos.agents2 import Reducer, Stream, skill # type: ignore[attr-defined] +from dimos.core import DimosCluster, In, LCMTransport, Module, Out, pSHMTransport, rpc +from dimos.core.module import ModuleConfig +from dimos.msgs.geometry_msgs import ( + PoseStamped, + Quaternion, + Transform, + Twist, + Vector3, +) +from dimos.msgs.nav_msgs import Path +from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.std_msgs import Bool +from dimos.msgs.tf2_msgs.TFMessage import TFMessage +from dimos.navigation.base import NavigationInterface, NavigationState +from dimos.utils.logging_config import setup_logger +from dimos.utils.transform_utils import euler_to_quaternion + +logger = setup_logger(level=logging.INFO) + + +@dataclass +class Config(ModuleConfig): + local_pointcloud_freq: float = 2.0 + global_pointcloud_freq: float = 1.0 + sensor_to_base_link_transform: Transform = field( + default_factory=lambda: Transform(frame_id="sensor", child_frame_id="base_link") + ) + + +class ROSNav( + Module, NavigationInterface, spec.Nav, spec.Global3DMap, spec.Pointcloud, spec.LocalPlanner +): + """Adapter translating between ROS2 navigation stack and Dimos' navigation interface. + + Connects Dimos to ROS2 (Robot Operating System 2) with dual interfaces: + (i) `NavigationInterface` for other `Module`s to call navigation methods via RPC (see the `@rpc`-decorated methods) and + (ii) `@skill`-decorated methods for LLM agents. + + Provides goal-based navigation with timeout/cancellation, point cloud streaming, + transform broadcasting (map/world/base_link/sensor frames), and thread-safe + state management. + + Configuration: + `local_pointcloud_freq` (float): Rate at which *local* pointcloud is forwarded to downstream modules (default: 2.0 Hz). + `global_pointcloud_freq` (float): Rate at which *global* pointcloud is forwarded to downstream modules (default: 1.0 Hz). + `sensor_to_base_link_transform` (Transform): Static transform from sensor frame to + base_link. Required for coordinate frame lookups between map and base_link + (e.g., `tf.get("map", "base_link")`). + """ + + config: Config + default_config = Config + + goal_req: In[PoseStamped] = None # type: ignore + + pointcloud: Out[PointCloud2] = None # type: ignore + global_pointcloud: Out[PointCloud2] = None # type: ignore + + goal_active: Out[PoseStamped] = None # type: ignore + path_active: Out[Path] = None # type: ignore + cmd_vel: Out[Twist] = None # type: ignore + + # Using RxPY Subjects for reactive data flow instead of storing state + _local_pointcloud_subject: Subject # type: ignore[type-arg] + _global_pointcloud_subject: Subject # type: ignore[type-arg] + + _current_position_running: bool = False + _spin_thread: threading.Thread | None = None + _goal_reach: bool | None = None + + # Navigation state tracking for NavigationInterface + _navigation_state: NavigationState = NavigationState.IDLE + _state_lock: threading.Lock + _navigation_thread: threading.Thread | None = None + _current_goal: PoseStamped | None = None + _goal_reached: bool = False + + def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + super().__init__(*args, **kwargs) + + # Initialize RxPY Subjects for streaming data + self._local_pointcloud_subject = Subject() + self._global_pointcloud_subject = Subject() + + # Initialize state tracking + self._state_lock = threading.Lock() + self._navigation_state = NavigationState.IDLE + self._goal_reached = False + + if not rclpy.ok(): # type: ignore[attr-defined] + rclpy.init() + + self._node = Node("navigation_module") + + # ROS2 Publishers + self.goal_pose_pub = self._node.create_publisher(ROSPoseStamped, "/goal_pose", 10) + self.cancel_goal_pub = self._node.create_publisher(ROSBool, "/cancel_goal", 10) + self.soft_stop_pub = self._node.create_publisher(ROSInt8, "/stop", 10) + self.joy_pub = self._node.create_publisher(ROSJoy, "/joy", 10) + + # ROS2 Subscribers + self.goal_reached_sub = self._node.create_subscription( + ROSBool, "/goal_reached", self._on_ros_goal_reached, 10 + ) + self.cmd_vel_sub = self._node.create_subscription( + ROSTwistStamped, "/cmd_vel", self._on_ros_cmd_vel, 10 + ) + self.goal_waypoint_sub = self._node.create_subscription( + ROSPointStamped, "/way_point", self._on_ros_goal_waypoint, 10 + ) + self.registered_scan_sub = self._node.create_subscription( + ROSPointCloud2, "/registered_scan", self._on_ros_registered_scan, 10 + ) + + self.global_pointcloud_sub = self._node.create_subscription( + ROSPointCloud2, "/terrain_map_ext", self._on_ros_global_pointcloud, 10 + ) + + self.path_sub = self._node.create_subscription(ROSPath, "/path", self._on_ros_path, 10) + self.tf_sub = self._node.create_subscription(ROSTFMessage, "/tf", self._on_ros_tf, 10) + + logger.info("NavigationModule initialized with ROS2 node") + + @rpc + def start(self) -> None: + self._running = True + + self._disposables.add( + self._local_pointcloud_subject.pipe( + ops.sample(1.0 / self.config.local_pointcloud_freq), # Sample at desired frequency + ops.map(lambda msg: PointCloud2.from_ros_msg(msg)), # type: ignore[arg-type] + ).subscribe( + on_next=self.pointcloud.publish, + on_error=lambda e: logger.error(f"Lidar stream error: {e}"), + ) + ) + + self._disposables.add( + self._global_pointcloud_subject.pipe( + ops.sample(1.0 / self.config.global_pointcloud_freq), # Sample at desired frequency + ops.map(lambda msg: PointCloud2.from_ros_msg(msg)), # type: ignore[arg-type] + ).subscribe( + on_next=self.global_pointcloud.publish, + on_error=lambda e: logger.error(f"Map stream error: {e}"), + ) + ) + + # Create and start the spin thread for ROS2 node spinning + self._spin_thread = threading.Thread( + target=self._spin_node, daemon=True, name="ROS2SpinThread" + ) + self._spin_thread.start() + + self.goal_req.subscribe(self._on_goal_pose) + logger.info("NavigationModule started with ROS2 spinning and RxPY streams") + + def _spin_node(self) -> None: + while self._running and rclpy.ok(): # type: ignore[attr-defined] + try: + rclpy.spin_once(self._node, timeout_sec=0.1) + except Exception as e: + if self._running: + logger.error(f"ROS2 spin error: {e}") + + def _on_ros_goal_reached(self, msg: ROSBool) -> None: + self._goal_reach = msg.data + if msg.data: + with self._state_lock: + self._goal_reached = True + self._navigation_state = NavigationState.IDLE + + def _on_ros_goal_waypoint(self, msg: ROSPointStamped) -> None: + dimos_pose = PoseStamped( + ts=time.time(), + frame_id=msg.header.frame_id, + position=Vector3(msg.point.x, msg.point.y, msg.point.z), + orientation=Quaternion(0.0, 0.0, 0.0, 1.0), + ) + self.goal_active.publish(dimos_pose) + + def _on_ros_cmd_vel(self, msg: ROSTwistStamped) -> None: + self.cmd_vel.publish(Twist.from_ros_msg(msg.twist)) + + def _on_ros_registered_scan(self, msg: ROSPointCloud2) -> None: + self._local_pointcloud_subject.on_next(msg) + + def _on_ros_global_pointcloud(self, msg: ROSPointCloud2) -> None: + self._global_pointcloud_subject.on_next(msg) + + def _on_ros_path(self, msg: ROSPath) -> None: + dimos_path = Path.from_ros_msg(msg) + dimos_path.frame_id = "base_link" + self.path_active.publish(dimos_path) + + def _on_ros_tf(self, msg: ROSTFMessage) -> None: + ros_tf = TFMessage.from_ros_msg(msg) + + map_to_world_tf = Transform( + translation=Vector3(0.0, 0.0, 0.0), + rotation=euler_to_quaternion(Vector3(0.0, 0.0, 0.0)), + frame_id="map", + child_frame_id="world", + ts=time.time(), + ) + + self.tf.publish( + self.config.sensor_to_base_link_transform.now(), + map_to_world_tf, + *ros_tf.transforms, + ) + + def _on_goal_pose(self, msg: PoseStamped) -> None: + self.navigate_to(msg) + + def _on_cancel_goal(self, msg: Bool) -> None: + if msg.data: + self.stop() + + def _set_autonomy_mode(self) -> None: + joy_msg = ROSJoy() # type: ignore[no-untyped-call] + joy_msg.axes = [ + 0.0, # axis 0 + 0.0, # axis 1 + -1.0, # axis 2 + 0.0, # axis 3 + 1.0, # axis 4 + 1.0, # axis 5 + 0.0, # axis 6 + 0.0, # axis 7 + ] + joy_msg.buttons = [ + 0, # button 0 + 0, # button 1 + 0, # button 2 + 0, # button 3 + 0, # button 4 + 0, # button 5 + 0, # button 6 + 1, # button 7 - controls autonomy mode + 0, # button 8 + 0, # button 9 + 0, # button 10 + ] + self.joy_pub.publish(joy_msg) + logger.info("Setting autonomy mode via Joy message") + + @skill(stream=Stream.passive, reducer=Reducer.latest) # type: ignore[arg-type] + def current_position(self): # type: ignore[no-untyped-def] + """passively stream the current position of the robot every second""" + if self._current_position_running: + return "already running" + while True: + self._current_position_running = True + time.sleep(1.0) + tf = self.tf.get("map", "base_link") + if not tf: + continue + yield f"current position {tf.translation.x}, {tf.translation.y}" + + @skill(stream=Stream.call_agent, reducer=Reducer.string) # type: ignore[arg-type] + def goto(self, x: float, y: float): # type: ignore[no-untyped-def] + """ + move the robot in relative coordinates + x is forward, y is left + + goto(1, 0) will move the robot forward by 1 meter + """ + pose_to = PoseStamped( + position=Vector3(x, y, 0), + orientation=Quaternion(0.0, 0.0, 0.0, 0.0), + frame_id="base_link", + ts=time.time(), + ) + + yield "moving, please wait..." + self.navigate_to(pose_to) + yield "arrived" + + @skill(stream=Stream.call_agent, reducer=Reducer.string) # type: ignore[arg-type] + def goto_global(self, x: float, y: float) -> Generator[str, None, None]: + """ + go to coordinates x,y in the map frame + 0,0 is your starting position + """ + target = PoseStamped( + ts=time.time(), + frame_id="map", + position=Vector3(x, y, 0.0), + orientation=Quaternion(0.0, 0.0, 0.0, 0.0), + ) + + pos = self.tf.get("base_link", "map").translation + + yield f"moving from {pos.x:.2f}, {pos.y:.2f} to {x:.2f}, {y:.2f}, please wait..." + + self.navigate_to(target) + + yield "arrived to {x:.2f}, {y:.2f}" + + @rpc + def navigate_to(self, pose: PoseStamped, timeout: float = 60.0) -> bool: + """ + Navigate to a target pose by publishing to ROS topics. + + Args: + pose: Target pose to navigate to + timeout: Maximum time to wait for goal (seconds) + + Returns: + True if navigation was successful + """ + logger.info( + f"Navigating to goal: ({pose.position.x:.2f}, {pose.position.y:.2f}, {pose.position.z:.2f} @ {pose.frame_id})" + ) + + self._goal_reach = None + self._set_autonomy_mode() + + # Enable soft stop (0 = enable) + soft_stop_msg = ROSInt8() # type: ignore[no-untyped-call] + soft_stop_msg.data = 0 + self.soft_stop_pub.publish(soft_stop_msg) + + ros_pose = pose.to_ros_msg() + self.goal_pose_pub.publish(ros_pose) + + # Wait for goal to be reached + start_time = time.time() + while time.time() - start_time < timeout: + if self._goal_reach is not None: + soft_stop_msg.data = 2 + self.soft_stop_pub.publish(soft_stop_msg) + return self._goal_reach + time.sleep(0.1) + + self.stop_navigation() + logger.warning(f"Navigation timed out after {timeout} seconds") + return False + + @rpc + def stop_navigation(self) -> bool: + """ + Stop current navigation by publishing to ROS topics. + + Returns: + True if stop command was sent successfully + """ + logger.info("Stopping navigation") + + cancel_msg = ROSBool() # type: ignore[no-untyped-call] + cancel_msg.data = True + self.cancel_goal_pub.publish(cancel_msg) + + soft_stop_msg = ROSInt8() # type: ignore[no-untyped-call] + soft_stop_msg.data = 2 + self.soft_stop_pub.publish(soft_stop_msg) + + with self._state_lock: + self._navigation_state = NavigationState.IDLE + self._current_goal = None + self._goal_reached = False + + return True + + @rpc + def set_goal(self, goal: PoseStamped) -> bool: + """Set a new navigation goal (non-blocking).""" + with self._state_lock: + self._current_goal = goal + self._goal_reached = False + self._navigation_state = NavigationState.FOLLOWING_PATH + + # Start navigation in a separate thread to make it non-blocking + if self._navigation_thread and self._navigation_thread.is_alive(): + logger.warning("Previous navigation still running, cancelling") + self.stop_navigation() + self._navigation_thread.join(timeout=1.0) + + self._navigation_thread = threading.Thread( + target=self._navigate_to_goal_async, + args=(goal,), + daemon=True, + name="ROSNavNavigationThread", + ) + self._navigation_thread.start() + + return True + + def _navigate_to_goal_async(self, goal: PoseStamped) -> None: + """Internal method to handle navigation in a separate thread.""" + try: + result = self.navigate_to(goal, timeout=60.0) + with self._state_lock: + self._goal_reached = result + self._navigation_state = NavigationState.IDLE + except Exception as e: + logger.error(f"Navigation failed: {e}") + with self._state_lock: + self._goal_reached = False + self._navigation_state = NavigationState.IDLE + + @rpc + def get_state(self) -> NavigationState: + """Get the current state of the navigator.""" + with self._state_lock: + return self._navigation_state + + @rpc + def is_goal_reached(self) -> bool: + """Check if the current goal has been reached.""" + with self._state_lock: + return self._goal_reached + + @rpc + def cancel_goal(self) -> bool: + """Cancel the current navigation goal.""" + + with self._state_lock: + had_goal = self._current_goal is not None + + if had_goal: + self.stop_navigation() + + return had_goal + + @rpc + def stop(self) -> None: + """Stop the navigation module and clean up resources.""" + self.stop_navigation() + try: + self._running = False + + self._local_pointcloud_subject.on_completed() + self._global_pointcloud_subject.on_completed() + + if self._spin_thread and self._spin_thread.is_alive(): + self._spin_thread.join(timeout=1.0) + + if hasattr(self, "_node") and self._node: + self._node.destroy_node() # type: ignore[no-untyped-call] + + except Exception as e: + logger.error(f"Error during shutdown: {e}") + finally: + super().stop() + + +ros_nav = ROSNav.blueprint + + +def deploy(dimos: DimosCluster): # type: ignore[no-untyped-def] + nav = dimos.deploy(ROSNav) # type: ignore[attr-defined] + + nav.pointcloud.transport = pSHMTransport("/lidar") + nav.global_pointcloud.transport = pSHMTransport("/map") + nav.goal_req.transport = LCMTransport("/goal_req", PoseStamped) + nav.goal_active.transport = LCMTransport("/goal_active", PoseStamped) + nav.path_active.transport = LCMTransport("/path_active", Path) + nav.cmd_vel.transport = LCMTransport("/cmd_vel", Twist) + + nav.start() + return nav + + +__all__ = ["ROSNav", "deploy", "ros_nav"] diff --git a/dimos/navigation/visual/query.py b/dimos/navigation/visual/query.py new file mode 100644 index 0000000000..2e0951951e --- /dev/null +++ b/dimos/navigation/visual/query.py @@ -0,0 +1,44 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dimos.models.qwen.video_query import BBox +from dimos.models.vl.base import VlModel +from dimos.msgs.sensor_msgs import Image +from dimos.utils.generic import extract_json_from_llm_response + + +def get_object_bbox_from_image( + vl_model: VlModel, image: Image, object_description: str +) -> BBox | None: + prompt = ( + f"Look at this image and find the '{object_description}'. " + "Return ONLY a JSON object with format: {'name': 'object_name', 'bbox': [x1, y1, x2, y2]} " + "where x1,y1 is the top-left and x2,y2 is the bottom-right corner of the bounding box. If not found, return None." + ) + + response = vl_model.query(image, prompt) + + result = extract_json_from_llm_response(response) + if not result: + return None + + try: + ret = tuple(map(float, result["bbox"])) + if len(ret) == 4: + return ret + except Exception: + pass + + return None diff --git a/dimos/manipulation/imitation/act.py b/dimos/perception/__init__.py similarity index 100% rename from dimos/manipulation/imitation/act.py rename to dimos/perception/__init__.py diff --git a/dimos/perception/common/__init__.py b/dimos/perception/common/__init__.py new file mode 100644 index 0000000000..67481bc449 --- /dev/null +++ b/dimos/perception/common/__init__.py @@ -0,0 +1,3 @@ +from .detection2d_tracker import get_tracked_results, target2dTracker +from .ibvs import * +from .utils import * diff --git a/dimos/perception/common/detection2d_tracker.py b/dimos/perception/common/detection2d_tracker.py new file mode 100644 index 0000000000..9ff36be8a1 --- /dev/null +++ b/dimos/perception/common/detection2d_tracker.py @@ -0,0 +1,396 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import deque +from collections.abc import Sequence + +import numpy as np + + +def compute_iou(bbox1, bbox2): # type: ignore[no-untyped-def] + """ + Compute Intersection over Union (IoU) of two bounding boxes. + Each bbox is [x1, y1, x2, y2]. + """ + x1 = max(bbox1[0], bbox2[0]) + y1 = max(bbox1[1], bbox2[1]) + x2 = min(bbox1[2], bbox2[2]) + y2 = min(bbox1[3], bbox2[3]) + + inter_area = max(0, x2 - x1) * max(0, y2 - y1) + area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]) + area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1]) + + union_area = area1 + area2 - inter_area + if union_area == 0: + return 0 + return inter_area / union_area + + +def get_tracked_results(tracked_targets): # type: ignore[no-untyped-def] + """ + Extract tracked results from a list of target2d objects. + + Args: + tracked_targets (list[target2d]): List of target2d objects (published targets) + returned by the tracker's update() function. + + Returns: + tuple: (tracked_masks, tracked_bboxes, tracked_track_ids, tracked_probs, tracked_names) + where each is a list of the corresponding attribute from each target. + """ + tracked_masks = [] + tracked_bboxes = [] + tracked_track_ids = [] + tracked_probs = [] + tracked_names = [] + + for target in tracked_targets: + # Extract the latest values stored in each target. + tracked_masks.append(target.latest_mask) + tracked_bboxes.append(target.latest_bbox) + # Here we use the most recent detection's track ID. + tracked_track_ids.append(target.target_id) + # Use the latest probability from the history. + tracked_probs.append(target.score) + # Use the stored name (if any). If not available, you can use a default value. + tracked_names.append(target.name) + + return tracked_masks, tracked_bboxes, tracked_track_ids, tracked_probs, tracked_names + + +class target2d: + """ + Represents a tracked 2D target. + Stores the latest bounding box and mask along with a short history of track IDs, + detection probabilities, and computed texture values. + """ + + def __init__( # type: ignore[no-untyped-def] + self, + initial_mask, + initial_bbox, + track_id, + prob: float, + name: str, + texture_value, + target_id, + history_size: int = 10, + ) -> None: + """ + Args: + initial_mask (torch.Tensor): Latest segmentation mask. + initial_bbox (list): Bounding box in [x1, y1, x2, y2] format. + track_id (int): Detection’s track ID (may be -1 if not provided). + prob (float): Detection probability. + name (str): Object class name. + texture_value (float): Computed average texture value for this detection. + target_id (int): Unique identifier assigned by the tracker. + history_size (int): Maximum number of frames to keep in the history. + """ + self.target_id = target_id + self.latest_mask = initial_mask + self.latest_bbox = initial_bbox + self.name = name + self.score = 1.0 + + self.track_id = track_id + self.probs_history = deque(maxlen=history_size) # type: ignore[var-annotated] + self.texture_history = deque(maxlen=history_size) # type: ignore[var-annotated] + + self.frame_count = deque(maxlen=history_size) # type: ignore[var-annotated] # Total frames this target has been seen. + self.missed_frames = 0 # Consecutive frames when no detection was assigned. + self.history_size = history_size + + def update(self, mask, bbox, track_id, prob: float, name: str, texture_value) -> None: # type: ignore[no-untyped-def] + """ + Update the target with a new detection. + """ + self.latest_mask = mask + self.latest_bbox = bbox + self.name = name + + self.track_id = track_id + self.probs_history.append(prob) + self.texture_history.append(texture_value) + + self.frame_count.append(1) + self.missed_frames = 0 + + def mark_missed(self) -> None: + """ + Increment the count of consecutive frames where this target was not updated. + """ + self.missed_frames += 1 + self.frame_count.append(0) + + def compute_score( # type: ignore[no-untyped-def] + self, + frame_shape, + min_area_ratio, + max_area_ratio, + texture_range=(0.0, 1.0), + border_safe_distance: int = 50, + weights=None, + ): + """ + Compute a combined score for the target based on several factors. + + Factors: + - **Detection probability:** Average over recent frames. + - **Temporal stability:** How consistently the target has appeared. + - **Texture quality:** Normalized using the provided min and max values. + - **Border proximity:** Computed from the minimum distance from the bbox to the frame edges. + - **Size:** How the object's area (relative to the frame) compares to acceptable bounds. + + Args: + frame_shape (tuple): (height, width) of the frame. + min_area_ratio (float): Minimum acceptable ratio (bbox area / frame area). + max_area_ratio (float): Maximum acceptable ratio. + texture_range (tuple): (min_texture, max_texture) expected values. + border_safe_distance (float): Distance (in pixels) considered safe from the border. + weights (dict): Weights for each component. Expected keys: + 'prob', 'temporal', 'texture', 'border', and 'size'. + + Returns: + float: The combined (normalized) score in the range [0, 1]. + """ + # Default weights if none provided. + if weights is None: + weights = {"prob": 1.0, "temporal": 1.0, "texture": 1.0, "border": 1.0, "size": 1.0} + + h, w = frame_shape + x1, y1, x2, y2 = self.latest_bbox + bbox_area = (x2 - x1) * (y2 - y1) + frame_area = w * h + area_ratio = bbox_area / frame_area + + # Detection probability factor. + avg_prob = np.mean(self.probs_history) + # Temporal stability factor: normalized by history size. + temporal_stability = np.mean(self.frame_count) + # Texture factor: normalize average texture using the provided range. + avg_texture = np.mean(self.texture_history) if self.texture_history else 0.0 + min_texture, max_texture = texture_range + if max_texture == min_texture: + normalized_texture = avg_texture + else: + normalized_texture = (avg_texture - min_texture) / (max_texture - min_texture) + normalized_texture = max(0.0, min(normalized_texture, 1.0)) + + # Border factor: compute the minimum distance from the bbox to any frame edge. + left_dist = x1 + top_dist = y1 + right_dist = w - x2 + min_border_dist = min(left_dist, top_dist, right_dist) + # Normalize the border distance: full score (1.0) if at least border_safe_distance away. + border_factor = min(1.0, min_border_dist / border_safe_distance) + + # Size factor: penalize objects that are too small or too big. + if area_ratio < min_area_ratio: + size_factor = area_ratio / min_area_ratio + elif area_ratio > max_area_ratio: + # Here we compute a linear penalty if the area exceeds max_area_ratio. + if 1 - max_area_ratio > 0: + size_factor = max(0, (1 - area_ratio) / (1 - max_area_ratio)) + else: + size_factor = 0.0 + else: + size_factor = 1.0 + + # Combine factors using a weighted sum (each factor is assumed in [0, 1]). + w_prob = weights.get("prob", 1.0) + w_temporal = weights.get("temporal", 1.0) + w_texture = weights.get("texture", 1.0) + w_border = weights.get("border", 1.0) + w_size = weights.get("size", 1.0) + total_weight = w_prob + w_temporal + w_texture + w_border + w_size + + # print(f"track_id: {self.target_id}, avg_prob: {avg_prob:.2f}, temporal_stability: {temporal_stability:.2f}, normalized_texture: {normalized_texture:.2f}, border_factor: {border_factor:.2f}, size_factor: {size_factor:.2f}") + + final_score = ( + w_prob * avg_prob + + w_temporal * temporal_stability + + w_texture * normalized_texture + + w_border * border_factor + + w_size * size_factor + ) / total_weight + + self.score = final_score + + return final_score + + +class target2dTracker: + """ + Tracker that maintains a history of targets across frames. + New segmentation detections (frame, masks, bboxes, track_ids, probabilities, + and computed texture values) are matched to existing targets or used to create new ones. + + The tracker uses a scoring system that incorporates: + - **Detection probability** + - **Temporal stability** + - **Texture quality** (normalized within a specified range) + - **Proximity to image borders** (a continuous penalty based on the distance) + - **Object size** relative to the frame + + Targets are published if their score exceeds the start threshold and are removed if their score + falls below the stop threshold or if they are missed for too many consecutive frames. + """ + + def __init__( # type: ignore[no-untyped-def] + self, + history_size: int = 10, + score_threshold_start: float = 0.5, + score_threshold_stop: float = 0.3, + min_frame_count: int = 10, + max_missed_frames: int = 3, + min_area_ratio: float = 0.001, + max_area_ratio: float = 0.1, + texture_range=(0.0, 1.0), + border_safe_distance: int = 50, + weights=None, + ) -> None: + """ + Args: + history_size (int): Maximum history length (number of frames) per target. + score_threshold_start (float): Minimum score for a target to be published. + score_threshold_stop (float): If a target’s score falls below this, it is removed. + min_frame_count (int): Minimum number of frames a target must be seen to be published. + max_missed_frames (int): Maximum consecutive frames a target can be missing before deletion. + min_area_ratio (float): Minimum acceptable bbox area relative to the frame. + max_area_ratio (float): Maximum acceptable bbox area relative to the frame. + texture_range (tuple): (min_texture, max_texture) expected values. + border_safe_distance (float): Distance (in pixels) considered safe from the border. + weights (dict): Weights for the scoring components (keys: 'prob', 'temporal', + 'texture', 'border', 'size'). + """ + self.history_size = history_size + self.score_threshold_start = score_threshold_start + self.score_threshold_stop = score_threshold_stop + self.min_frame_count = min_frame_count + self.max_missed_frames = max_missed_frames + self.min_area_ratio = min_area_ratio + self.max_area_ratio = max_area_ratio + self.texture_range = texture_range + self.border_safe_distance = border_safe_distance + # Default weights if none are provided. + if weights is None: + weights = {"prob": 1.0, "temporal": 1.0, "texture": 1.0, "border": 1.0, "size": 1.0} + self.weights = weights + + self.targets = {} # type: ignore[var-annotated] # Dictionary mapping target_id -> target2d instance. + self.next_target_id = 0 + + def update( # type: ignore[no-untyped-def] + self, + frame, + masks, + bboxes, + track_ids, + probs: Sequence[float], + names: Sequence[str], + texture_values, + ): + """ + Update the tracker with new detections from the current frame. + + Args: + frame (np.ndarray): Current BGR frame. + masks (list[torch.Tensor]): List of segmentation masks. + bboxes (list): List of bounding boxes [x1, y1, x2, y2]. + track_ids (list): List of detection track IDs. + probs (list): List of detection probabilities. + names (list): List of class names. + texture_values (list): List of computed texture values. + + Returns: + published_targets (list[target2d]): Targets that are active and have scores above + the start threshold. + """ + updated_target_ids = set() + frame_shape = frame.shape[:2] # (height, width) + + # For each detection, try to match with an existing target. + for mask, bbox, det_tid, prob, name, texture in zip( + masks, bboxes, track_ids, probs, names, texture_values, strict=False + ): + matched_target = None + + # First, try matching by detection track ID if valid. + if det_tid != -1: + for target in self.targets.values(): + if target.track_id == det_tid: + matched_target = target + break + + # Otherwise, try matching using IoU. + if matched_target is None: + best_iou = 0 + for target in self.targets.values(): + iou = compute_iou(bbox, target.latest_bbox) # type: ignore[no-untyped-call] + if iou > 0.5 and iou > best_iou: + best_iou = iou + matched_target = target + + # Update existing target or create a new one. + if matched_target is not None: + matched_target.update(mask, bbox, det_tid, prob, name, texture) + updated_target_ids.add(matched_target.target_id) + else: + new_target = target2d( + mask, bbox, det_tid, prob, name, texture, self.next_target_id, self.history_size + ) + self.targets[self.next_target_id] = new_target + updated_target_ids.add(self.next_target_id) + self.next_target_id += 1 + + # Mark targets that were not updated. + for target_id, target in list(self.targets.items()): + if target_id not in updated_target_ids: + target.mark_missed() + if target.missed_frames > self.max_missed_frames: + del self.targets[target_id] + continue # Skip further checks for this target. + # Remove targets whose score falls below the stop threshold. + score = target.compute_score( + frame_shape, + self.min_area_ratio, + self.max_area_ratio, + texture_range=self.texture_range, + border_safe_distance=self.border_safe_distance, + weights=self.weights, + ) + if score < self.score_threshold_stop: + del self.targets[target_id] + + # Publish targets with scores above the start threshold. + published_targets = [] + for target in self.targets.values(): + score = target.compute_score( + frame_shape, + self.min_area_ratio, + self.max_area_ratio, + texture_range=self.texture_range, + border_safe_distance=self.border_safe_distance, + weights=self.weights, + ) + if ( + score >= self.score_threshold_start + and sum(target.frame_count) >= self.min_frame_count + and target.missed_frames <= 5 + ): + published_targets.append(target) + + return published_targets diff --git a/dimos/perception/common/export_tensorrt.py b/dimos/perception/common/export_tensorrt.py new file mode 100644 index 0000000000..83a48be0a9 --- /dev/null +++ b/dimos/perception/common/export_tensorrt.py @@ -0,0 +1,58 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +from ultralytics import YOLO, FastSAM + + +def parse_args(): # type: ignore[no-untyped-def] + parser = argparse.ArgumentParser(description="Export YOLO/FastSAM models to different formats") + parser.add_argument("--model_path", type=str, required=True, help="Path to the model weights") + parser.add_argument( + "--model_type", + type=str, + choices=["yolo", "fastsam"], + required=True, + help="Type of model to export", + ) + parser.add_argument( + "--precision", + type=str, + choices=["fp32", "fp16", "int8"], + default="fp32", + help="Precision for export", + ) + parser.add_argument( + "--format", type=str, choices=["onnx", "engine"], default="onnx", help="Export format" + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() # type: ignore[no-untyped-call] + half = args.precision == "fp16" + int8 = args.precision == "int8" + # Load the appropriate model + if args.model_type == "yolo": + model = YOLO(args.model_path) + else: + model = FastSAM(args.model_path) + + # Export the model + model.export(format=args.format, half=half, int8=int8) + + +if __name__ == "__main__": + main() diff --git a/dimos/perception/common/ibvs.py b/dimos/perception/common/ibvs.py new file mode 100644 index 0000000000..e24819f432 --- /dev/null +++ b/dimos/perception/common/ibvs.py @@ -0,0 +1,280 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + + +class PersonDistanceEstimator: + def __init__(self, K, camera_pitch, camera_height) -> None: # type: ignore[no-untyped-def] + """ + Initialize the distance estimator using ground plane constraint. + + Args: + K: 3x3 Camera intrinsic matrix in OpenCV format + (Assumed to be already for an undistorted image) + camera_pitch: Upward pitch of the camera (in radians), in the robot frame + Positive means looking up, negative means looking down + camera_height: Height of the camera above the ground (in meters) + """ + self.K = K + self.camera_height = camera_height + + # Precompute the inverse intrinsic matrix + self.K_inv = np.linalg.inv(K) + + # Transform from camera to robot frame (z-forward to x-forward) + self.T = np.array([[0, 0, 1], [-1, 0, 0], [0, -1, 0]]) + + # Pitch rotation matrix (positive is upward) + theta = -camera_pitch # Negative since positive pitch is negative rotation about robot Y + self.R_pitch = np.array( + [[np.cos(theta), 0, np.sin(theta)], [0, 1, 0], [-np.sin(theta), 0, np.cos(theta)]] + ) + + # Combined transform from camera to robot frame + self.A = self.R_pitch @ self.T + + # Store focal length and principal point for angle calculation + self.fx = K[0, 0] + self.cx = K[0, 2] + + def estimate_distance_angle(self, bbox: tuple, robot_pitch: float | None = None): # type: ignore[no-untyped-def, type-arg] + """ + Estimate distance and angle to person using ground plane constraint. + + Args: + bbox: tuple (x_min, y_min, x_max, y_max) + where y_max represents the feet position + robot_pitch: Current pitch of the robot body (in radians) + If provided, this will be combined with the camera's fixed pitch + + Returns: + depth: distance to person along camera's z-axis (meters) + angle: horizontal angle in camera frame (radians, positive right) + """ + x_min, _, x_max, y_max = bbox + + # Get center point of feet + u_c = (x_min + x_max) / 2.0 + v_feet = y_max + + # Create homogeneous feet point and get ray direction + p_feet = np.array([u_c, v_feet, 1.0]) + d_feet_cam = self.K_inv @ p_feet + + # If robot_pitch is provided, recalculate the transformation matrix + if robot_pitch is not None: + # Combined pitch (fixed camera pitch + current robot pitch) + total_pitch = -camera_pitch - robot_pitch # Both negated for correct rotation direction + R_total_pitch = np.array( + [ + [np.cos(total_pitch), 0, np.sin(total_pitch)], + [0, 1, 0], + [-np.sin(total_pitch), 0, np.cos(total_pitch)], + ] + ) + # Use the updated transformation matrix + A = R_total_pitch @ self.T + else: + # Use the precomputed transformation matrix + A = self.A + + # Convert ray to robot frame using appropriate transformation + d_feet_robot = A @ d_feet_cam + + # Ground plane intersection (z=0) + # camera_height + t * d_feet_robot[2] = 0 + if abs(d_feet_robot[2]) < 1e-6: + raise ValueError("Feet ray is parallel to ground plane") + + # Solve for scaling factor t + t = -self.camera_height / d_feet_robot[2] + + # Get 3D feet position in robot frame + p_feet_robot = t * d_feet_robot + + # Convert back to camera frame + p_feet_cam = self.A.T @ p_feet_robot + + # Extract depth (z-coordinate in camera frame) + depth = p_feet_cam[2] + + # Calculate horizontal angle from image center + angle = np.arctan((u_c - self.cx) / self.fx) + + return depth, angle + + +class ObjectDistanceEstimator: + """ + Estimate distance to an object using the ground plane constraint. + This class assumes the camera is mounted on a robot and uses the + camera's intrinsic parameters to estimate the distance to a detected object. + """ + + def __init__(self, K, camera_pitch, camera_height) -> None: # type: ignore[no-untyped-def] + """ + Initialize the distance estimator using ground plane constraint. + + Args: + K: 3x3 Camera intrinsic matrix in OpenCV format + (Assumed to be already for an undistorted image) + camera_pitch: Upward pitch of the camera (in radians) + Positive means looking up, negative means looking down + camera_height: Height of the camera above the ground (in meters) + """ + self.K = K + self.camera_height = camera_height + + # Precompute the inverse intrinsic matrix + self.K_inv = np.linalg.inv(K) + + # Transform from camera to robot frame (z-forward to x-forward) + self.T = np.array([[0, 0, 1], [-1, 0, 0], [0, -1, 0]]) + + # Pitch rotation matrix (positive is upward) + theta = -camera_pitch # Negative since positive pitch is negative rotation about robot Y + self.R_pitch = np.array( + [[np.cos(theta), 0, np.sin(theta)], [0, 1, 0], [-np.sin(theta), 0, np.cos(theta)]] + ) + + # Combined transform from camera to robot frame + self.A = self.R_pitch @ self.T + + # Store focal length and principal point for angle calculation + self.fx = K[0, 0] + self.fy = K[1, 1] + self.cx = K[0, 2] + self.estimated_object_size = None + + def estimate_object_size(self, bbox: tuple, distance: float): # type: ignore[no-untyped-def, type-arg] + """ + Estimate the physical size of an object based on its bbox and known distance. + + Args: + bbox: tuple (x_min, y_min, x_max, y_max) bounding box in the image + distance: Known distance to the object (in meters) + robot_pitch: Current pitch of the robot body (in radians), if any + + Returns: + estimated_size: Estimated physical height of the object (in meters) + """ + _x_min, y_min, _x_max, y_max = bbox + + # Calculate object height in pixels + object_height_px = y_max - y_min + + # Calculate the physical height using the known distance and focal length + estimated_size = object_height_px * distance / self.fy + self.estimated_object_size = estimated_size + + return estimated_size + + def set_estimated_object_size(self, size: float) -> None: + """ + Set the estimated object size for future distance calculations. + + Args: + size: Estimated physical size of the object (in meters) + """ + self.estimated_object_size = size # type: ignore[assignment] + + def estimate_distance_angle(self, bbox: tuple): # type: ignore[no-untyped-def, type-arg] + """ + Estimate distance and angle to object using size-based estimation. + + Args: + bbox: tuple (x_min, y_min, x_max, y_max) + where y_max represents the bottom of the object + robot_pitch: Current pitch of the robot body (in radians) + If provided, this will be combined with the camera's fixed pitch + initial_distance: Initial distance estimate for the object (in meters) + Used to calibrate object size if not previously known + + Returns: + depth: distance to object along camera's z-axis (meters) + angle: horizontal angle in camera frame (radians, positive right) + or None, None if estimation not possible + """ + # If we don't have estimated object size and no initial distance is provided, + # we can't estimate the distance + if self.estimated_object_size is None: + return None, None + + x_min, y_min, x_max, y_max = bbox + + # Calculate center of the object for angle calculation + u_c = (x_min + x_max) / 2.0 + + # If we have an initial distance estimate and no object size yet, + # calculate and store the object size using the initial distance + object_height_px = y_max - y_min + depth = self.estimated_object_size * self.fy / object_height_px + + # Calculate horizontal angle from image center + angle = np.arctan((u_c - self.cx) / self.fx) + + return depth, angle + + +# Example usage: +if __name__ == "__main__": + # Example camera calibration + K = np.array([[600, 0, 320], [0, 600, 240], [0, 0, 1]], dtype=np.float32) + + # Camera mounted 1.2m high, pitched down 10 degrees + camera_pitch = np.deg2rad(0) # negative for downward pitch + camera_height = 1.0 # meters + + estimator = PersonDistanceEstimator(K, camera_pitch, camera_height) + object_estimator = ObjectDistanceEstimator(K, camera_pitch, camera_height) + + # Example detection + bbox = (300, 100, 380, 400) # x1, y1, x2, y2 + + depth, angle = estimator.estimate_distance_angle(bbox) + # Estimate object size based on the known distance + object_size = object_estimator.estimate_object_size(bbox, depth) + depth_obj, angle_obj = object_estimator.estimate_distance_angle(bbox) + + print(f"Estimated person depth: {depth:.2f} m") + print(f"Estimated person angle: {np.rad2deg(angle):.1f}°") + print(f"Estimated object depth: {depth_obj:.2f} m") + print(f"Estimated object angle: {np.rad2deg(angle_obj):.1f}°") + + # Shrink the bbox by 30 pixels while keeping the same center + x_min, y_min, x_max, y_max = bbox + width = x_max - x_min + height = y_max - y_min + center_x = (x_min + x_max) // 2 + center_y = (y_min + y_max) // 2 + + new_width = max(width - 20, 2) # Ensure width is at least 2 pixels + new_height = max(height - 20, 2) # Ensure height is at least 2 pixels + + x_min = center_x - new_width // 2 + x_max = center_x + new_width // 2 + y_min = center_y - new_height // 2 + y_max = center_y + new_height // 2 + + bbox = (x_min, y_min, x_max, y_max) + + # Re-estimate distance and angle with the new bbox + depth, angle = estimator.estimate_distance_angle(bbox) + depth_obj, angle_obj = object_estimator.estimate_distance_angle(bbox) + + print(f"New estimated person depth: {depth:.2f} m") + print(f"New estimated person angle: {np.rad2deg(angle):.1f}°") + print(f"New estimated object depth: {depth_obj:.2f} m") + print(f"New estimated object angle: {np.rad2deg(angle_obj):.1f}°") diff --git a/dimos/perception/common/utils.py b/dimos/perception/common/utils.py new file mode 100644 index 0000000000..df29197257 --- /dev/null +++ b/dimos/perception/common/utils.py @@ -0,0 +1,958 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Union + +import cv2 +from dimos_lcm.sensor_msgs import CameraInfo # type: ignore[import-untyped] +from dimos_lcm.vision_msgs import ( # type: ignore[import-untyped] + BoundingBox2D, + Detection2D, + Detection3D, +) +import numpy as np +import torch +import yaml + +from dimos.msgs.geometry_msgs import Pose, Quaternion, Vector3 +from dimos.msgs.sensor_msgs import Image +from dimos.msgs.std_msgs import Header +from dimos.types.manipulation import ObjectData +from dimos.types.vector import Vector +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + +# Optional CuPy support +try: # pragma: no cover - optional dependency + import cupy as cp # type: ignore + + _HAS_CUDA = True +except Exception: # pragma: no cover - optional dependency + cp = None + _HAS_CUDA = False + + +def _is_cu_array(x) -> bool: # type: ignore[no-untyped-def] + return _HAS_CUDA and cp is not None and isinstance(x, cp.ndarray) + + +def _to_numpy(x): # type: ignore[no-untyped-def] + return cp.asnumpy(x) if _is_cu_array(x) else x + + +def _to_cupy(x): # type: ignore[no-untyped-def] + if _HAS_CUDA and cp is not None and isinstance(x, np.ndarray): + try: + return cp.asarray(x) + except Exception: + return x + return x + + +def load_camera_info(yaml_path: str, frame_id: str = "camera_link") -> CameraInfo: + """ + Load ROS-style camera_info YAML file and convert to CameraInfo LCM message. + + Args: + yaml_path: Path to camera_info YAML file (ROS format) + frame_id: Frame ID for the camera (default: "camera_link") + + Returns: + CameraInfo: LCM CameraInfo message with all calibration data + """ + with open(yaml_path) as f: + camera_info_data = yaml.safe_load(f) + + # Extract image dimensions + width = camera_info_data.get("image_width", 1280) + height = camera_info_data.get("image_height", 720) + + # Extract camera matrix (K) - already in row-major format + K = camera_info_data["camera_matrix"]["data"] + + # Extract distortion coefficients + D = camera_info_data["distortion_coefficients"]["data"] + + # Extract rectification matrix (R) if available, else use identity + R = camera_info_data.get("rectification_matrix", {}).get("data", [1, 0, 0, 0, 1, 0, 0, 0, 1]) + + # Extract projection matrix (P) if available + P = camera_info_data.get("projection_matrix", {}).get("data", None) + + # If P not provided, construct from K + if P is None: + fx = K[0] + fy = K[4] + cx = K[2] + cy = K[5] + P = [fx, 0, cx, 0, 0, fy, cy, 0, 0, 0, 1, 0] + + # Create header + header = Header(frame_id) + + # Create and return CameraInfo message + return CameraInfo( + D_length=len(D), + header=header, + height=height, + width=width, + distortion_model=camera_info_data.get("distortion_model", "plumb_bob"), + D=D, + K=K, + R=R, + P=P, + binning_x=0, + binning_y=0, + ) + + +def load_camera_info_opencv(yaml_path: str) -> tuple[np.ndarray, np.ndarray]: # type: ignore[type-arg] + """ + Load ROS-style camera_info YAML file and convert to OpenCV camera matrix and distortion coefficients. + + Args: + yaml_path: Path to camera_info YAML file (ROS format) + + Returns: + K: 3x3 camera intrinsic matrix + dist: 1xN distortion coefficients array (for plumb_bob model) + """ + with open(yaml_path) as f: + camera_info = yaml.safe_load(f) + + # Extract camera matrix (K) + camera_matrix_data = camera_info["camera_matrix"]["data"] + K = np.array(camera_matrix_data).reshape(3, 3) + + # Extract distortion coefficients + dist_coeffs_data = camera_info["distortion_coefficients"]["data"] + dist = np.array(dist_coeffs_data) + + # Ensure dist is 1D array for OpenCV compatibility + if dist.ndim == 2: + dist = dist.flatten() + + return K, dist + + +def rectify_image_cpu(image: Image, camera_matrix: np.ndarray, dist_coeffs: np.ndarray) -> Image: # type: ignore[type-arg] + """CPU rectification using OpenCV. Preserves backend by caller. + + Returns an Image with numpy or cupy data depending on caller choice. + """ + src = _to_numpy(image.data) # type: ignore[no-untyped-call] + rect = cv2.undistort(src, camera_matrix, dist_coeffs) + # Caller decides whether to convert back to GPU. + return Image(data=rect, format=image.format, frame_id=image.frame_id, ts=image.ts) + + +def rectify_image_cuda(image: Image, camera_matrix: np.ndarray, dist_coeffs: np.ndarray) -> Image: # type: ignore[type-arg] + """GPU rectification using CuPy bilinear sampling. + + Generates an undistorted output grid and samples from the distorted source. + Falls back to CPU if CUDA not available. + """ + if not _HAS_CUDA or cp is None or not image.is_cuda: + return rectify_image_cpu(image, camera_matrix, dist_coeffs) + + xp = cp + + # Source (distorted) image on device + src = image.data + if src.ndim not in (2, 3): + raise ValueError("Unsupported image rank for rectification") + H, W = int(src.shape[0]), int(src.shape[1]) + + # Extract intrinsics and distortion as float64 + K = xp.asarray(camera_matrix, dtype=xp.float64) + dist = xp.asarray(dist_coeffs, dtype=xp.float64).reshape(-1) + fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2] + k1 = dist[0] if dist.size > 0 else 0.0 + k2 = dist[1] if dist.size > 1 else 0.0 + p1 = dist[2] if dist.size > 2 else 0.0 + p2 = dist[3] if dist.size > 3 else 0.0 + k3 = dist[4] if dist.size > 4 else 0.0 + + # Build undistorted target grid (pixel coords) + u = xp.arange(W, dtype=xp.float64) + v = xp.arange(H, dtype=xp.float64) + uu, vv = xp.meshgrid(u, v, indexing="xy") + + # Convert to normalized undistorted coords + xu = (uu - cx) / fx + yu = (vv - cy) / fy + + # Apply forward distortion model to get distorted normalized coords + r2 = xu * xu + yu * yu + r4 = r2 * r2 + r6 = r4 * r2 + radial = 1.0 + k1 * r2 + k2 * r4 + k3 * r6 + delta_x = 2.0 * p1 * xu * yu + p2 * (r2 + 2.0 * xu * xu) + delta_y = p1 * (r2 + 2.0 * yu * yu) + 2.0 * p2 * xu * yu + xd = xu * radial + delta_x + yd = yu * radial + delta_y + + # Back to pixel coordinates in the source (distorted) image + us = fx * xd + cx + vs = fy * yd + cy + + # Bilinear sample from src at (vs, us) + def _bilinear_sample_cuda(img, x_src, y_src): # type: ignore[no-untyped-def] + h, w = int(img.shape[0]), int(img.shape[1]) + # Base integer corners (not clamped) + x0i = xp.floor(x_src).astype(xp.int32) + y0i = xp.floor(y_src).astype(xp.int32) + x1i = x0i + 1 + y1i = y0i + 1 + + # Masks for in-bounds neighbors (BORDER_CONSTANT behavior) + m00 = (x0i >= 0) & (x0i < w) & (y0i >= 0) & (y0i < h) + m10 = (x1i >= 0) & (x1i < w) & (y0i >= 0) & (y0i < h) + m01 = (x0i >= 0) & (x0i < w) & (y1i >= 0) & (y1i < h) + m11 = (x1i >= 0) & (x1i < w) & (y1i >= 0) & (y1i < h) + + # Clamp indices for safe gather, but multiply contributions by masks + x0 = xp.clip(x0i, 0, w - 1) + y0 = xp.clip(y0i, 0, h - 1) + x1 = xp.clip(x1i, 0, w - 1) + y1 = xp.clip(y1i, 0, h - 1) + + # Weights + wx = (x_src - x0i).astype(xp.float64) + wy = (y_src - y0i).astype(xp.float64) + w00 = (1.0 - wx) * (1.0 - wy) + w10 = wx * (1.0 - wy) + w01 = (1.0 - wx) * wy + w11 = wx * wy + + # Cast masks for arithmetic + m00f = m00.astype(xp.float64) + m10f = m10.astype(xp.float64) + m01f = m01.astype(xp.float64) + m11f = m11.astype(xp.float64) + + if img.ndim == 2: + Ia = img[y0, x0].astype(xp.float64) + Ib = img[y0, x1].astype(xp.float64) + Ic = img[y1, x0].astype(xp.float64) + Id = img[y1, x1].astype(xp.float64) + out = w00 * m00f * Ia + w10 * m10f * Ib + w01 * m01f * Ic + w11 * m11f * Id + else: + Ia = img[y0, x0].astype(xp.float64) + Ib = img[y0, x1].astype(xp.float64) + Ic = img[y1, x0].astype(xp.float64) + Id = img[y1, x1].astype(xp.float64) + # Expand weights and masks for channel broadcasting + w00e = (w00 * m00f)[..., None] + w10e = (w10 * m10f)[..., None] + w01e = (w01 * m01f)[..., None] + w11e = (w11 * m11f)[..., None] + out = w00e * Ia + w10e * Ib + w01e * Ic + w11e * Id + + # Cast back to original dtype with clipping for integers + if img.dtype == xp.uint8: + out = xp.clip(xp.rint(out), 0, 255).astype(xp.uint8) + elif img.dtype == xp.uint16: + out = xp.clip(xp.rint(out), 0, 65535).astype(xp.uint16) + elif img.dtype == xp.int16: + out = xp.clip(xp.rint(out), -32768, 32767).astype(xp.int16) + else: + out = out.astype(img.dtype, copy=False) + return out + + rect = _bilinear_sample_cuda(src, us, vs) # type: ignore[no-untyped-call] + return Image(data=rect, format=image.format, frame_id=image.frame_id, ts=image.ts) + + +def rectify_image(image: Image, camera_matrix: np.ndarray, dist_coeffs: np.ndarray) -> Image: # type: ignore[type-arg] + """ + Rectify (undistort) an image using camera calibration parameters. + + Args: + image: Input Image object to rectify + camera_matrix: 3x3 camera intrinsic matrix (K) + dist_coeffs: Distortion coefficients array + + Returns: + Image: Rectified Image object with same format and metadata + """ + if image.is_cuda and _HAS_CUDA: + return rectify_image_cuda(image, camera_matrix, dist_coeffs) + return rectify_image_cpu(image, camera_matrix, dist_coeffs) + + +def project_3d_points_to_2d_cuda( + points_3d: "cp.ndarray", camera_intrinsics: Union[list[float], "cp.ndarray"] +) -> "cp.ndarray": + xp = cp + pts = points_3d.astype(xp.float64, copy=False) + mask = pts[:, 2] > 0 + if not bool(xp.any(mask)): + return xp.zeros((0, 2), dtype=xp.int32) + valid = pts[mask] + if isinstance(camera_intrinsics, list) and len(camera_intrinsics) == 4: + fx, fy, cx, cy = [xp.asarray(v, dtype=xp.float64) for v in camera_intrinsics] + else: + K = camera_intrinsics.astype(xp.float64, copy=False) # type: ignore[union-attr] + fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2] + u = (valid[:, 0] * fx / valid[:, 2]) + cx + v = (valid[:, 1] * fy / valid[:, 2]) + cy + return xp.stack([u, v], axis=1).astype(xp.int32) + + +def project_3d_points_to_2d_cpu( + points_3d: np.ndarray, # type: ignore[type-arg] + camera_intrinsics: list[float] | np.ndarray, # type: ignore[type-arg] +) -> np.ndarray: # type: ignore[type-arg] + pts = np.asarray(points_3d, dtype=np.float64) + valid_mask = pts[:, 2] > 0 + if not np.any(valid_mask): + return np.zeros((0, 2), dtype=np.int32) + valid_points = pts[valid_mask] + if isinstance(camera_intrinsics, list) and len(camera_intrinsics) == 4: + fx, fy, cx, cy = [float(v) for v in camera_intrinsics] + else: + K = np.array(camera_intrinsics, dtype=np.float64) + fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2] + u = (valid_points[:, 0] * fx / valid_points[:, 2]) + cx + v = (valid_points[:, 1] * fy / valid_points[:, 2]) + cy + return np.column_stack([u, v]).astype(np.int32) + + +def project_3d_points_to_2d( + points_3d: Union[np.ndarray, "cp.ndarray"], # type: ignore[type-arg] + camera_intrinsics: Union[list[float], np.ndarray, "cp.ndarray"], # type: ignore[type-arg] +) -> Union[np.ndarray, "cp.ndarray"]: # type: ignore[type-arg] + """ + Project 3D points to 2D image coordinates using camera intrinsics. + + Args: + points_3d: Nx3 array of 3D points (X, Y, Z) + camera_intrinsics: Camera parameters as [fx, fy, cx, cy] list or 3x3 matrix + + Returns: + Nx2 array of 2D image coordinates (u, v) + """ + if len(points_3d) == 0: + return ( + cp.zeros((0, 2), dtype=cp.int32) + if _is_cu_array(points_3d) + else np.zeros((0, 2), dtype=np.int32) + ) + + # Filter out points with zero or negative depth + if _is_cu_array(points_3d) or _is_cu_array(camera_intrinsics): + xp = cp + pts = points_3d if _is_cu_array(points_3d) else xp.asarray(points_3d) + K = camera_intrinsics if _is_cu_array(camera_intrinsics) else camera_intrinsics + return project_3d_points_to_2d_cuda(pts, K) + return project_3d_points_to_2d_cpu(np.asarray(points_3d), np.asarray(camera_intrinsics)) + + +def project_2d_points_to_3d_cuda( + points_2d: "cp.ndarray", + depth_values: "cp.ndarray", + camera_intrinsics: Union[list[float], "cp.ndarray"], +) -> "cp.ndarray": + xp = cp + pts = points_2d.astype(xp.float64, copy=False) + depths = depth_values.astype(xp.float64, copy=False) + valid = depths > 0 + if not bool(xp.any(valid)): + return xp.zeros((0, 3), dtype=xp.float32) + uv = pts[valid] + Z = depths[valid] + if isinstance(camera_intrinsics, list) and len(camera_intrinsics) == 4: + fx, fy, cx, cy = [xp.asarray(v, dtype=xp.float64) for v in camera_intrinsics] + else: + K = camera_intrinsics.astype(xp.float64, copy=False) # type: ignore[union-attr] + fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2] + X = (uv[:, 0] - cx) * Z / fx + Y = (uv[:, 1] - cy) * Z / fy + return xp.stack([X, Y, Z], axis=1).astype(xp.float32) + + +def project_2d_points_to_3d_cpu( + points_2d: np.ndarray, # type: ignore[type-arg] + depth_values: np.ndarray, # type: ignore[type-arg] + camera_intrinsics: list[float] | np.ndarray, # type: ignore[type-arg] +) -> np.ndarray: # type: ignore[type-arg] + pts = np.asarray(points_2d, dtype=np.float64) + depths = np.asarray(depth_values, dtype=np.float64) + valid_mask = depths > 0 + if not np.any(valid_mask): + return np.zeros((0, 3), dtype=np.float32) + valid_points_2d = pts[valid_mask] + valid_depths = depths[valid_mask] + if isinstance(camera_intrinsics, list) and len(camera_intrinsics) == 4: + fx, fy, cx, cy = [float(v) for v in camera_intrinsics] + else: + camera_matrix = np.array(camera_intrinsics, dtype=np.float64) + fx = camera_matrix[0, 0] + fy = camera_matrix[1, 1] + cx = camera_matrix[0, 2] + cy = camera_matrix[1, 2] + X = (valid_points_2d[:, 0] - cx) * valid_depths / fx + Y = (valid_points_2d[:, 1] - cy) * valid_depths / fy + Z = valid_depths + return np.column_stack([X, Y, Z]).astype(np.float32) + + +def project_2d_points_to_3d( + points_2d: Union[np.ndarray, "cp.ndarray"], # type: ignore[type-arg] + depth_values: Union[np.ndarray, "cp.ndarray"], # type: ignore[type-arg] + camera_intrinsics: Union[list[float], np.ndarray, "cp.ndarray"], # type: ignore[type-arg] +) -> Union[np.ndarray, "cp.ndarray"]: # type: ignore[type-arg] + """ + Project 2D image points to 3D coordinates using depth values and camera intrinsics. + + Args: + points_2d: Nx2 array of 2D image coordinates (u, v) + depth_values: N-length array of depth values (Z coordinates) for each point + camera_intrinsics: Camera parameters as [fx, fy, cx, cy] list or 3x3 matrix + + Returns: + Nx3 array of 3D points (X, Y, Z) + """ + if len(points_2d) == 0: + return ( + cp.zeros((0, 3), dtype=cp.float32) + if _is_cu_array(points_2d) + else np.zeros((0, 3), dtype=np.float32) + ) + + # Ensure depth_values is a numpy array + if _is_cu_array(points_2d) or _is_cu_array(depth_values) or _is_cu_array(camera_intrinsics): + xp = cp + pts = points_2d if _is_cu_array(points_2d) else xp.asarray(points_2d) + depths = depth_values if _is_cu_array(depth_values) else xp.asarray(depth_values) + K = camera_intrinsics if _is_cu_array(camera_intrinsics) else camera_intrinsics + return project_2d_points_to_3d_cuda(pts, depths, K) + return project_2d_points_to_3d_cpu( + np.asarray(points_2d), np.asarray(depth_values), np.asarray(camera_intrinsics) + ) + + +def colorize_depth( + depth_img: Union[np.ndarray, "cp.ndarray"], # type: ignore[type-arg] + max_depth: float = 5.0, + overlay_stats: bool = True, +) -> Union[np.ndarray, "cp.ndarray"] | None: # type: ignore[type-arg] + """ + Normalize and colorize depth image using COLORMAP_JET with optional statistics overlay. + + Args: + depth_img: Depth image (H, W) in meters + max_depth: Maximum depth value for normalization + overlay_stats: Whether to overlay depth statistics on the image + + Returns: + Colorized depth image (H, W, 3) in RGB format, or None if input is None + """ + if depth_img is None: + return None + + was_cu = _is_cu_array(depth_img) + xp = cp if was_cu else np + depth = depth_img if was_cu else np.asarray(depth_img) + + valid_mask = xp.isfinite(depth) & (depth > 0) + depth_norm = xp.zeros_like(depth, dtype=xp.float32) + if bool(valid_mask.any() if not was_cu else xp.any(valid_mask)): + depth_norm = xp.where(valid_mask, xp.clip(depth / max_depth, 0, 1), depth_norm) + + # Use CPU for colormap/text; convert back to GPU if needed + depth_norm_np = _to_numpy(depth_norm) # type: ignore[no-untyped-call] + depth_colored = cv2.applyColorMap((depth_norm_np * 255).astype(np.uint8), cv2.COLORMAP_JET) + depth_rgb_np = cv2.cvtColor(depth_colored, cv2.COLOR_BGR2RGB) + depth_rgb_np = (depth_rgb_np * 0.6).astype(np.uint8) + + if overlay_stats and (np.any(_to_numpy(valid_mask))): # type: ignore[no-untyped-call] + valid_depths = _to_numpy(depth)[_to_numpy(valid_mask)] # type: ignore[no-untyped-call] + min_depth = float(np.min(valid_depths)) + max_depth_actual = float(np.max(valid_depths)) + h, w = depth_rgb_np.shape[:2] + center_y, center_x = h // 2, w // 2 + center_region = _to_numpy( + depth + )[ # type: ignore[no-untyped-call] + max(0, center_y - 2) : min(h, center_y + 3), max(0, center_x - 2) : min(w, center_x + 3) + ] + center_mask = np.isfinite(center_region) & (center_region > 0) + if center_mask.any(): + center_depth = float(np.median(center_region[center_mask])) + else: + depth_np = _to_numpy(depth) # type: ignore[no-untyped-call] + vm_np = _to_numpy(valid_mask) # type: ignore[no-untyped-call] + center_depth = float(depth_np[center_y, center_x]) if vm_np[center_y, center_x] else 0.0 + + font = cv2.FONT_HERSHEY_SIMPLEX + font_scale = 0.6 + thickness = 1 + line_type = cv2.LINE_AA + text_color = (255, 255, 255) + bg_color = (0, 0, 0) + padding = 5 + + min_text = f"Min: {min_depth:.2f}m" + (text_w, text_h), _ = cv2.getTextSize(min_text, font, font_scale, thickness) + cv2.rectangle( + depth_rgb_np, + (padding, padding), + (padding + text_w + 4, padding + text_h + 6), + bg_color, + -1, + ) + cv2.putText( + depth_rgb_np, + min_text, + (padding + 2, padding + text_h + 2), + font, + font_scale, + text_color, + thickness, + line_type, + ) + + max_text = f"Max: {max_depth_actual:.2f}m" + (text_w, text_h), _ = cv2.getTextSize(max_text, font, font_scale, thickness) + cv2.rectangle( + depth_rgb_np, + (w - padding - text_w - 4, padding), + (w - padding, padding + text_h + 6), + bg_color, + -1, + ) + cv2.putText( + depth_rgb_np, + max_text, + (w - padding - text_w - 2, padding + text_h + 2), + font, + font_scale, + text_color, + thickness, + line_type, + ) + + if center_depth > 0: + center_text = f"{center_depth:.2f}m" + (text_w, text_h), _ = cv2.getTextSize(center_text, font, font_scale, thickness) + center_text_x = center_x - text_w // 2 + center_text_y = center_y + text_h // 2 + cross_size = 10 + cross_color = (255, 255, 255) + cv2.line( + depth_rgb_np, + (center_x - cross_size, center_y), + (center_x + cross_size, center_y), + cross_color, + 1, + ) + cv2.line( + depth_rgb_np, + (center_x, center_y - cross_size), + (center_x, center_y + cross_size), + cross_color, + 1, + ) + cv2.rectangle( + depth_rgb_np, + (center_text_x - 2, center_text_y - text_h - 2), + (center_text_x + text_w + 2, center_text_y + 2), + bg_color, + -1, + ) + cv2.putText( + depth_rgb_np, + center_text, + (center_text_x, center_text_y), + font, + font_scale, + text_color, + thickness, + line_type, + ) + + return _to_cupy(depth_rgb_np) if was_cu else depth_rgb_np # type: ignore[no-untyped-call] + + +def draw_bounding_box( + image: Union[np.ndarray, "cp.ndarray"], # type: ignore[type-arg] + bbox: list[float], + color: tuple[int, int, int] = (0, 255, 0), + thickness: int = 2, + label: str | None = None, + confidence: float | None = None, + object_id: int | None = None, + font_scale: float = 0.6, +) -> Union[np.ndarray, "cp.ndarray"]: # type: ignore[type-arg] + """ + Draw a bounding box with optional label on an image. + + Args: + image: Image to draw on (H, W, 3) + bbox: Bounding box [x1, y1, x2, y2] + color: RGB color tuple for the box + thickness: Line thickness for the box + label: Optional class label + confidence: Optional confidence score + object_id: Optional object ID + font_scale: Font scale for text + + Returns: + Image with bounding box drawn + """ + was_cu = _is_cu_array(image) + img_np = _to_numpy(image) # type: ignore[no-untyped-call] + x1, y1, x2, y2 = map(int, bbox) + cv2.rectangle(img_np, (x1, y1), (x2, y2), color, thickness) + + # Create label text + text_parts = [] + if label is not None: + text_parts.append(str(label)) + if object_id is not None: + text_parts.append(f"ID: {object_id}") + if confidence is not None: + text_parts.append(f"({confidence:.2f})") + + if text_parts: + text = ", ".join(text_parts) + + # Draw text background + text_size = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, 1)[0] + cv2.rectangle( + img_np, + (x1, y1 - text_size[1] - 5), + (x1 + text_size[0], y1), + (0, 0, 0), + -1, + ) + + # Draw text + cv2.putText( + img_np, + text, + (x1, y1 - 5), + cv2.FONT_HERSHEY_SIMPLEX, + font_scale, + (255, 255, 255), + 1, + ) + + return _to_cupy(img_np) if was_cu else img_np # type: ignore[no-untyped-call] + + +def draw_segmentation_mask( + image: Union[np.ndarray, "cp.ndarray"], # type: ignore[type-arg] + mask: Union[np.ndarray, "cp.ndarray"], # type: ignore[type-arg] + color: tuple[int, int, int] = (0, 200, 200), + alpha: float = 0.5, + draw_contours: bool = True, + contour_thickness: int = 2, +) -> Union[np.ndarray, "cp.ndarray"]: # type: ignore[type-arg] + """ + Draw segmentation mask overlay on an image. + + Args: + image: Image to draw on (H, W, 3) + mask: Segmentation mask (H, W) - boolean or uint8 + color: RGB color for the mask + alpha: Transparency factor (0.0 = transparent, 1.0 = opaque) + draw_contours: Whether to draw mask contours + contour_thickness: Thickness of contour lines + + Returns: + Image with mask overlay drawn + """ + if mask is None: + return image + + was_cu = _is_cu_array(image) + img_np = _to_numpy(image) # type: ignore[no-untyped-call] + mask_np = _to_numpy(mask) # type: ignore[no-untyped-call] + + try: + mask_np = mask_np.astype(np.uint8) + colored_mask = np.zeros_like(img_np) + colored_mask[mask_np > 0] = color + mask_area = mask_np > 0 + img_np[mask_area] = cv2.addWeighted( + img_np[mask_area], 1 - alpha, colored_mask[mask_area], alpha, 0 + ) + if draw_contours: + contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + cv2.drawContours(img_np, contours, -1, color, contour_thickness) + except Exception as e: + logger.warning(f"Error drawing segmentation mask: {e}") + + return _to_cupy(img_np) if was_cu else img_np # type: ignore[no-untyped-call] + + +def draw_object_detection_visualization( + image: Union[np.ndarray, "cp.ndarray"], # type: ignore[type-arg] + objects: list[ObjectData], + draw_masks: bool = False, + bbox_color: tuple[int, int, int] = (0, 255, 0), + mask_color: tuple[int, int, int] = (0, 200, 200), + font_scale: float = 0.6, +) -> Union[np.ndarray, "cp.ndarray"]: # type: ignore[type-arg] + """ + Create object detection visualization with bounding boxes and optional masks. + + Args: + image: Base image to draw on (H, W, 3) + objects: List of ObjectData with detection information + draw_masks: Whether to draw segmentation masks + bbox_color: Default color for bounding boxes + mask_color: Default color for segmentation masks + font_scale: Font scale for text labels + + Returns: + Image with detection visualization + """ + was_cu = _is_cu_array(image) + viz_image = _to_numpy(image).copy() # type: ignore[no-untyped-call] + + for obj in objects: + try: + # Draw segmentation mask first (if enabled and available) + if draw_masks and "segmentation_mask" in obj and obj["segmentation_mask"] is not None: + viz_image = draw_segmentation_mask( + viz_image, obj["segmentation_mask"], color=mask_color, alpha=0.5 + ) + + # Draw bounding box + if "bbox" in obj and obj["bbox"] is not None: + # Use object's color if available, otherwise default + color = bbox_color + if "color" in obj and obj["color"] is not None: + obj_color = obj["color"] + if isinstance(obj_color, np.ndarray): + color = tuple(int(c) for c in obj_color) # type: ignore[assignment] + elif isinstance(obj_color, list | tuple): + color = tuple(int(c) for c in obj_color[:3]) + + viz_image = draw_bounding_box( + viz_image, + obj["bbox"], + color=color, + label=obj.get("label"), + confidence=obj.get("confidence"), + object_id=obj.get("object_id"), + font_scale=font_scale, + ) + + except Exception as e: + logger.warning(f"Error drawing object visualization: {e}") + + return _to_cupy(viz_image) if was_cu else viz_image # type: ignore[no-untyped-call] + + +def detection_results_to_object_data( + bboxes: list[list[float]], + track_ids: list[int], + class_ids: list[int], + confidences: list[float], + names: list[str], + masks: list[np.ndarray] | None = None, # type: ignore[type-arg] + source: str = "detection", +) -> list[ObjectData]: + """ + Convert detection/segmentation results to ObjectData format. + + Args: + bboxes: List of bounding boxes [x1, y1, x2, y2] + track_ids: List of tracking IDs + class_ids: List of class indices + confidences: List of detection confidences + names: List of class names + masks: Optional list of segmentation masks + source: Source type ("detection" or "segmentation") + + Returns: + List of ObjectData dictionaries + """ + objects = [] + + for i in range(len(bboxes)): + # Calculate basic properties from bbox + bbox = bboxes[i] + width = bbox[2] - bbox[0] + height = bbox[3] - bbox[1] + bbox[0] + width / 2 + bbox[1] + height / 2 + + # Create ObjectData + object_data: ObjectData = { + "object_id": track_ids[i] if i < len(track_ids) else i, + "bbox": bbox, + "depth": -1.0, # Will be populated by depth estimation or point cloud processing + "confidence": confidences[i] if i < len(confidences) else 1.0, + "class_id": class_ids[i] if i < len(class_ids) else 0, + "label": names[i] if i < len(names) else f"{source}_object", + "movement_tolerance": 1.0, # Default to freely movable + "segmentation_mask": masks[i].cpu().numpy() # type: ignore[attr-defined, typeddict-item] + if masks and i < len(masks) and isinstance(masks[i], torch.Tensor) + else masks[i] + if masks and i < len(masks) + else None, + # Initialize 3D properties (will be populated by point cloud processing) + "position": Vector(0, 0, 0), # type: ignore[arg-type] + "rotation": Vector(0, 0, 0), # type: ignore[arg-type] + "size": { + "width": 0.0, + "height": 0.0, + "depth": 0.0, + }, + } + objects.append(object_data) + + return objects + + +def combine_object_data( + list1: list[ObjectData], list2: list[ObjectData], overlap_threshold: float = 0.8 +) -> list[ObjectData]: + """ + Combine two ObjectData lists, removing duplicates based on segmentation mask overlap. + """ + combined = list1.copy() + used_ids = set(obj.get("object_id", 0) for obj in list1) + next_id = max(used_ids) + 1 if used_ids else 1 + + for obj2 in list2: + obj_copy = obj2.copy() + + # Handle duplicate object_id + if obj_copy.get("object_id", 0) in used_ids: + obj_copy["object_id"] = next_id + next_id += 1 + used_ids.add(obj_copy["object_id"]) + + # Check mask overlap + mask2 = obj2.get("segmentation_mask") + m2 = _to_numpy(mask2) if mask2 is not None else None # type: ignore[no-untyped-call] + if m2 is None or np.sum(m2 > 0) == 0: + combined.append(obj_copy) + continue + + mask2_area = np.sum(m2 > 0) + is_duplicate = False + + for obj1 in list1: + mask1 = obj1.get("segmentation_mask") + if mask1 is None: + continue + + m1 = _to_numpy(mask1) # type: ignore[no-untyped-call] + intersection = np.sum((m1 > 0) & (m2 > 0)) + if intersection / mask2_area >= overlap_threshold: + is_duplicate = True + break + + if not is_duplicate: + combined.append(obj_copy) + + return combined + + +def point_in_bbox(point: tuple[int, int], bbox: list[float]) -> bool: + """ + Check if a point is inside a bounding box. + + Args: + point: (x, y) coordinates + bbox: Bounding box [x1, y1, x2, y2] + + Returns: + True if point is inside bbox + """ + x, y = point + x1, y1, x2, y2 = bbox + return x1 <= x <= x2 and y1 <= y <= y2 + + +def bbox2d_to_corners(bbox_2d: BoundingBox2D) -> tuple[float, float, float, float]: + """ + Convert BoundingBox2D from center format to corner format. + + Args: + bbox_2d: BoundingBox2D with center and size + + Returns: + Tuple of (x1, y1, x2, y2) corner coordinates + """ + center_x = bbox_2d.center.position.x + center_y = bbox_2d.center.position.y + half_width = bbox_2d.size_x / 2.0 + half_height = bbox_2d.size_y / 2.0 + + x1 = center_x - half_width + y1 = center_y - half_height + x2 = center_x + half_width + y2 = center_y + half_height + + return x1, y1, x2, y2 + + +def find_clicked_detection( + click_pos: tuple[int, int], detections_2d: list[Detection2D], detections_3d: list[Detection3D] +) -> Detection3D | None: + """ + Find which detection was clicked based on 2D bounding boxes. + + Args: + click_pos: (x, y) click position + detections_2d: List of Detection2D objects + detections_3d: List of Detection3D objects (must be 1:1 correspondence) + + Returns: + Corresponding Detection3D object if found, None otherwise + """ + click_x, click_y = click_pos + + for i, det_2d in enumerate(detections_2d): + if det_2d.bbox and i < len(detections_3d): + x1, y1, x2, y2 = bbox2d_to_corners(det_2d.bbox) + + if x1 <= click_x <= x2 and y1 <= click_y <= y2: + return detections_3d[i] + + return None + + +def extract_pose_from_detection3d(detection3d: Detection3D): # type: ignore[no-untyped-def] + """Extract PoseStamped from Detection3D message. + + Args: + detection3d: Detection3D message + + Returns: + Pose or None if no valid detection + """ + if not detection3d or not detection3d.bbox or not detection3d.bbox.center: + return None + + # Extract position + pos = detection3d.bbox.center.position + position = Vector3(pos.x, pos.y, pos.z) + + # Extract orientation + orient = detection3d.bbox.center.orientation + orientation = Quaternion(orient.x, orient.y, orient.z, orient.w) + + pose = Pose(position=position, orientation=orientation) + return pose diff --git a/dimos/perception/detection/__init__.py b/dimos/perception/detection/__init__.py new file mode 100644 index 0000000000..72663a69b0 --- /dev/null +++ b/dimos/perception/detection/__init__.py @@ -0,0 +1,7 @@ +from dimos.perception.detection.detectors import * +from dimos.perception.detection.module2D import ( + Detection2DModule, +) +from dimos.perception.detection.module3D import ( + Detection3DModule, +) diff --git a/dimos/perception/detection/conftest.py b/dimos/perception/detection/conftest.py new file mode 100644 index 0000000000..1c9c8ca05c --- /dev/null +++ b/dimos/perception/detection/conftest.py @@ -0,0 +1,304 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable, Generator +import functools +from typing import TypedDict + +from dimos_lcm.foxglove_msgs.ImageAnnotations import ImageAnnotations +from dimos_lcm.foxglove_msgs.SceneUpdate import SceneUpdate +from dimos_lcm.visualization_msgs.MarkerArray import MarkerArray +import pytest + +from dimos.core import LCMTransport +from dimos.msgs.geometry_msgs import Transform +from dimos.msgs.sensor_msgs import CameraInfo, Image, PointCloud2 +from dimos.msgs.vision_msgs import Detection2DArray +from dimos.perception.detection.module2D import Detection2DModule +from dimos.perception.detection.module3D import Detection3DModule +from dimos.perception.detection.moduleDB import ObjectDBModule +from dimos.perception.detection.type import ( + Detection2D, + Detection3DPC, + ImageDetections2D, + ImageDetections3DPC, +) +from dimos.protocol.tf import TF +from dimos.robot.unitree.connection import go2 +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.odometry import Odometry +from dimos.utils.data import get_data +from dimos.utils.testing import TimedSensorReplay + + +class Moment(TypedDict, total=False): + odom_frame: Odometry + lidar_frame: LidarMessage + image_frame: Image + camera_info: CameraInfo + transforms: list[Transform] + tf: TF + annotations: ImageAnnotations | None + detections: ImageDetections3DPC | None + markers: MarkerArray | None + scene_update: SceneUpdate | None + + +class Moment2D(Moment): + detections2d: ImageDetections2D + + +class Moment3D(Moment): + detections3dpc: ImageDetections3DPC + + +@pytest.fixture(scope="session") +def tf(): + t = TF() + yield t + t.stop() + + +@pytest.fixture(scope="session") +def get_moment(tf): + @functools.lru_cache(maxsize=1) + def moment_provider(**kwargs) -> Moment: + print("MOMENT PROVIDER ARGS:", kwargs) + seek = kwargs.get("seek", 10.0) + + data_dir = "unitree_go2_lidar_corrected" + get_data(data_dir) + + lidar_frame_result = TimedSensorReplay(f"{data_dir}/lidar").find_closest_seek(seek) + if lidar_frame_result is None: + raise ValueError("No lidar frame found") + lidar_frame: LidarMessage = lidar_frame_result + + image_frame = TimedSensorReplay( + f"{data_dir}/video", + ).find_closest(lidar_frame.ts) + + if image_frame is None: + raise ValueError("No image frame found") + + image_frame.frame_id = "camera_optical" + + odom_frame = TimedSensorReplay(f"{data_dir}/odom", autocast=Odometry.from_msg).find_closest( + lidar_frame.ts + ) + + if odom_frame is None: + raise ValueError("No odom frame found") + + transforms = go2.GO2Connection._odom_to_tf(odom_frame) + + tf.receive_transform(*transforms) + + return { + "odom_frame": odom_frame, + "lidar_frame": lidar_frame, + "image_frame": image_frame, + "camera_info": go2._camera_info_static(), + "transforms": transforms, + "tf": tf, + } + + return moment_provider + + +@pytest.fixture(scope="session") +def publish_moment(): + def publisher(moment: Moment | Moment2D | Moment3D) -> None: + detections2d_val = moment.get("detections2d") + if detections2d_val: + # 2d annotations + annotations: LCMTransport[ImageAnnotations] = LCMTransport( + "/annotations", ImageAnnotations + ) + assert isinstance(detections2d_val, ImageDetections2D) + annotations.publish(detections2d_val.to_foxglove_annotations()) + + detections: LCMTransport[Detection2DArray] = LCMTransport( + "/detections", Detection2DArray + ) + detections.publish(detections2d_val.to_ros_detection2d_array()) + + annotations.lcm.stop() + detections.lcm.stop() + + detections3dpc_val = moment.get("detections3dpc") + if detections3dpc_val: + scene_update: LCMTransport[SceneUpdate] = LCMTransport("/scene_update", SceneUpdate) + # 3d scene update + assert isinstance(detections3dpc_val, ImageDetections3DPC) + scene_update.publish(detections3dpc_val.to_foxglove_scene_update()) + scene_update.lcm.stop() + + lidar_frame = moment.get("lidar_frame") + if lidar_frame: + lidar: LCMTransport[PointCloud2] = LCMTransport("/lidar", PointCloud2) + lidar.publish(lidar_frame) + lidar.lcm.stop() + + image_frame = moment.get("image_frame") + if image_frame: + image: LCMTransport[Image] = LCMTransport("/image", Image) + image.publish(image_frame) + image.lcm.stop() + + camera_info_val = moment.get("camera_info") + if camera_info_val: + camera_info: LCMTransport[CameraInfo] = LCMTransport("/camera_info", CameraInfo) + camera_info.publish(camera_info_val) + camera_info.lcm.stop() + + tf = moment.get("tf") + transforms = moment.get("transforms") + if tf is not None and transforms is not None: + tf.publish(*transforms) + + # moduleDB.scene_update.transport = LCMTransport("/scene_update", SceneUpdate) + # moduleDB.target.transport = LCMTransport("/target", PoseStamped) + + return publisher + + +@pytest.fixture(scope="session") +def imageDetections2d(get_moment_2d) -> ImageDetections2D: + moment = get_moment_2d() + assert len(moment["detections2d"]) > 0, "No detections found in the moment" + return moment["detections2d"] + + +@pytest.fixture(scope="session") +def detection2d(get_moment_2d) -> Detection2D: + moment = get_moment_2d() + assert len(moment["detections2d"]) > 0, "No detections found in the moment" + return moment["detections2d"][0] + + +@pytest.fixture(scope="session") +def detections3dpc(get_moment_3dpc) -> Detection3DPC: + moment = get_moment_3dpc(seek=10.0) + assert len(moment["detections3dpc"]) > 0, "No detections found in the moment" + return moment["detections3dpc"] + + +@pytest.fixture(scope="session") +def detection3dpc(detections3dpc) -> Detection3DPC: + return detections3dpc[0] + + +@pytest.fixture(scope="session") +def get_moment_2d(get_moment) -> Generator[Callable[[], Moment2D], None, None]: + from dimos.perception.detection.detectors import Yolo2DDetector + + module = Detection2DModule(detector=lambda: Yolo2DDetector(device="cpu")) + + @functools.lru_cache(maxsize=1) + def moment_provider(**kwargs) -> Moment2D: + moment = get_moment(**kwargs) + detections = module.process_image_frame(moment.get("image_frame")) + + return { + **moment, + "detections2d": detections, + } + + yield moment_provider + + module._close_module() + + +@pytest.fixture(scope="session") +def get_moment_3dpc(get_moment_2d) -> Generator[Callable[[], Moment3D], None, None]: + module: Detection3DModule | None = None + + @functools.lru_cache(maxsize=1) + def moment_provider(**kwargs) -> Moment3D: + nonlocal module + moment = get_moment_2d(**kwargs) + + if not module: + module = Detection3DModule(camera_info=moment["camera_info"]) + + lidar_frame = moment.get("lidar_frame") + if lidar_frame is None: + raise ValueError("No lidar frame found") + + camera_transform = moment["tf"].get("camera_optical", lidar_frame.frame_id) + if camera_transform is None: + raise ValueError("No camera_optical transform in tf") + + detections3dpc = module.process_frame( + moment["detections2d"], moment["lidar_frame"], camera_transform + ) + + return { + **moment, + "detections3dpc": detections3dpc, + } + + yield moment_provider + if module is not None: + module._close_module() + + +@pytest.fixture(scope="session") +def object_db_module(get_moment): + """Create and populate an ObjectDBModule with detections from multiple frames.""" + from dimos.perception.detection.detectors import Yolo2DDetector + + module2d = Detection2DModule(detector=lambda: Yolo2DDetector(device="cpu")) + module3d = Detection3DModule(camera_info=go2._camera_info_static()) + moduleDB = ObjectDBModule(camera_info=go2._camera_info_static()) + + # Process 5 frames to build up object history + for i in range(5): + seek_value = 10.0 + (i * 2) + moment = get_moment(seek=seek_value) + + # Process 2D detections + imageDetections2d = module2d.process_image_frame(moment["image_frame"]) + + # Get camera transform + camera_transform = moment["tf"].get("camera_optical", moment.get("lidar_frame").frame_id) + + # Process 3D detections + imageDetections3d = module3d.process_frame( + imageDetections2d, moment["lidar_frame"], camera_transform + ) + + # Add to database + moduleDB.add_detections(imageDetections3d) + + yield moduleDB + + module2d._close_module() + module3d._close_module() + moduleDB._close_module() + + +@pytest.fixture(scope="session") +def first_object(object_db_module): + """Get the first object from the database.""" + objects = list(object_db_module.objects.values()) + assert len(objects) > 0, "No objects found in database" + return objects[0] + + +@pytest.fixture(scope="session") +def all_objects(object_db_module): + """Get all objects from the database.""" + return list(object_db_module.objects.values()) diff --git a/dimos/perception/detection/detectors/__init__.py b/dimos/perception/detection/detectors/__init__.py new file mode 100644 index 0000000000..d6383d084e --- /dev/null +++ b/dimos/perception/detection/detectors/__init__.py @@ -0,0 +1,3 @@ +# from dimos.perception.detection.detectors.detic import Detic2DDetector +from dimos.perception.detection.detectors.types import Detector +from dimos.perception.detection.detectors.yolo import Yolo2DDetector diff --git a/dimos/perception/detection/detectors/config/custom_tracker.yaml b/dimos/perception/detection/detectors/config/custom_tracker.yaml new file mode 100644 index 0000000000..7a6748ebf6 --- /dev/null +++ b/dimos/perception/detection/detectors/config/custom_tracker.yaml @@ -0,0 +1,21 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Default Ultralytics settings for BoT-SORT tracker when using mode="track" +# For documentation and examples see https://docs.ultralytics.com/modes/track/ +# For BoT-SORT source code see https://github.com/NirAharon/BoT-SORT + +tracker_type: botsort # tracker type, ['botsort', 'bytetrack'] +track_high_thresh: 0.4 # threshold for the first association +track_low_thresh: 0.2 # threshold for the second association +new_track_thresh: 0.5 # threshold for init new track if the detection does not match any tracks +track_buffer: 100 # buffer to calculate the time when to remove tracks +match_thresh: 0.4 # threshold for matching tracks +fuse_score: False # Whether to fuse confidence scores with the iou distances before matching +# min_box_area: 10 # threshold for min box areas(for tracker evaluation, not used for now) + +# BoT-SORT settings +gmc_method: sparseOptFlow # method of global motion compensation +# ReID model related thresh (not supported yet) +proximity_thresh: 0.6 +appearance_thresh: 0.35 +with_reid: False diff --git a/dimos/perception/detection/detectors/conftest.py b/dimos/perception/detection/detectors/conftest.py new file mode 100644 index 0000000000..9cb600aeff --- /dev/null +++ b/dimos/perception/detection/detectors/conftest.py @@ -0,0 +1,38 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from dimos.msgs.sensor_msgs import Image +from dimos.perception.detection.detectors.person.yolo import YoloPersonDetector +from dimos.perception.detection.detectors.yolo import Yolo2DDetector +from dimos.utils.data import get_data + + +@pytest.fixture(scope="session") +def test_image(): + """Load the test image used for detector tests.""" + return Image.from_file(get_data("cafe.jpg")) + + +@pytest.fixture(scope="session") +def person_detector(): + """Create a YoloPersonDetector instance.""" + return YoloPersonDetector() + + +@pytest.fixture(scope="session") +def bbox_detector(): + """Create a Yolo2DDetector instance for general object detection.""" + return Yolo2DDetector() diff --git a/dimos/perception/detection/detectors/detic.py b/dimos/perception/detection/detectors/detic.py new file mode 100644 index 0000000000..288a3e056d --- /dev/null +++ b/dimos/perception/detection/detectors/detic.py @@ -0,0 +1,426 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Sequence +import os +import sys + +import numpy as np + +# Add Detic to Python path +from dimos.constants import DIMOS_PROJECT_ROOT +from dimos.msgs.sensor_msgs import Image +from dimos.perception.detection.detectors.types import Detector +from dimos.perception.detection2d.utils import plot_results + +detic_path = DIMOS_PROJECT_ROOT / "dimos/models/Detic" +if str(detic_path) not in sys.path: + sys.path.append(str(detic_path)) + sys.path.append(str(detic_path / "third_party/CenterNet2")) + +# PIL patch for compatibility +import PIL.Image + +if not hasattr(PIL.Image, "LINEAR") and hasattr(PIL.Image, "BILINEAR"): + PIL.Image.LINEAR = PIL.Image.BILINEAR # type: ignore[attr-defined] + +# Detectron2 imports +from detectron2.config import get_cfg # type: ignore[import-not-found] +from detectron2.data import MetadataCatalog # type: ignore[import-not-found] + + +# Simple tracking implementation +class SimpleTracker: + """Simple IOU-based tracker implementation without external dependencies""" + + def __init__(self, iou_threshold: float = 0.3, max_age: int = 5) -> None: + self.iou_threshold = iou_threshold + self.max_age = max_age + self.next_id = 1 + self.tracks = {} # type: ignore[var-annotated] # id -> {bbox, class_id, age, mask, etc} + + def _calculate_iou(self, bbox1, bbox2): # type: ignore[no-untyped-def] + """Calculate IoU between two bboxes in format [x1,y1,x2,y2]""" + x1 = max(bbox1[0], bbox2[0]) + y1 = max(bbox1[1], bbox2[1]) + x2 = min(bbox1[2], bbox2[2]) + y2 = min(bbox1[3], bbox2[3]) + + if x2 < x1 or y2 < y1: + return 0.0 + + intersection = (x2 - x1) * (y2 - y1) + area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]) + area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1]) + union = area1 + area2 - intersection + + return intersection / union if union > 0 else 0 + + def update(self, detections, masks): # type: ignore[no-untyped-def] + """Update tracker with new detections + + Args: + detections: List of [x1,y1,x2,y2,score,class_id] + masks: List of segmentation masks corresponding to detections + + Returns: + List of [track_id, bbox, score, class_id, mask] + """ + if len(detections) == 0: + # Age existing tracks + for track_id in list(self.tracks.keys()): + self.tracks[track_id]["age"] += 1 + # Remove old tracks + if self.tracks[track_id]["age"] > self.max_age: + del self.tracks[track_id] + return [] + + # Convert to numpy for easier handling + if not isinstance(detections, np.ndarray): + detections = np.array(detections) + + result = [] + matched_indices = set() + + # Update existing tracks + for track_id, track in list(self.tracks.items()): + track["age"] += 1 + + if track["age"] > self.max_age: + del self.tracks[track_id] + continue + + # Find best matching detection for this track + best_iou = self.iou_threshold + best_idx = -1 + + for i, det in enumerate(detections): + if i in matched_indices: + continue + + # Check class match + if det[5] != track["class_id"]: + continue + + iou = self._calculate_iou(track["bbox"], det[:4]) # type: ignore[no-untyped-call] + if iou > best_iou: + best_iou = iou + best_idx = i + + # If we found a match, update the track + if best_idx >= 0: + self.tracks[track_id]["bbox"] = detections[best_idx][:4] + self.tracks[track_id]["score"] = detections[best_idx][4] + self.tracks[track_id]["age"] = 0 + self.tracks[track_id]["mask"] = masks[best_idx] + matched_indices.add(best_idx) + + # Add to results with mask + result.append( + [ + track_id, + detections[best_idx][:4], + detections[best_idx][4], + int(detections[best_idx][5]), + self.tracks[track_id]["mask"], + ] + ) + + # Create new tracks for unmatched detections + for i, det in enumerate(detections): + if i in matched_indices: + continue + + # Create new track + new_id = self.next_id + self.next_id += 1 + + self.tracks[new_id] = { + "bbox": det[:4], + "score": det[4], + "class_id": int(det[5]), + "age": 0, + "mask": masks[i], + } + + # Add to results with mask directly from the track + result.append([new_id, det[:4], det[4], int(det[5]), masks[i]]) + + return result + + +class Detic2DDetector(Detector): + def __init__( # type: ignore[no-untyped-def] + self, model_path=None, device: str = "cuda", vocabulary=None, threshold: float = 0.5 + ) -> None: + """ + Initialize the Detic detector with open vocabulary support. + + Args: + model_path (str): Path to a custom Detic model weights (optional) + device (str): Device to run inference on ('cuda' or 'cpu') + vocabulary (list): Custom vocabulary (list of class names) or 'lvis', 'objects365', 'openimages', 'coco' + threshold (float): Detection confidence threshold + """ + self.device = device + self.threshold = threshold + + # Set up Detic paths - already added to sys.path at module level + + # Import Detic modules + from centernet.config import add_centernet_config # type: ignore[import-not-found] + from detic.config import add_detic_config # type: ignore[import-not-found] + from detic.modeling.text.text_encoder import ( # type: ignore[import-not-found] + build_text_encoder, + ) + from detic.modeling.utils import reset_cls_test # type: ignore[import-not-found] + + # Keep reference to these functions for later use + self.reset_cls_test = reset_cls_test + self.build_text_encoder = build_text_encoder + + # Setup model configuration + self.cfg = get_cfg() + add_centernet_config(self.cfg) + add_detic_config(self.cfg) + + # Use default Detic config + self.cfg.merge_from_file( + os.path.join( + detic_path, "configs/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.yaml" + ) + ) + + # Set default weights if not provided + if model_path is None: + self.cfg.MODEL.WEIGHTS = "https://dl.fbaipublicfiles.com/detic/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.pth" + else: + self.cfg.MODEL.WEIGHTS = model_path + + # Set device + if device == "cpu": + self.cfg.MODEL.DEVICE = "cpu" + + # Set detection threshold + self.cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = threshold + self.cfg.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_PATH = "rand" + self.cfg.MODEL.ROI_HEADS.ONE_CLASS_PER_PROPOSAL = True + + # Built-in datasets for Detic - use absolute paths with detic_path + self.builtin_datasets = { + "lvis": { + "metadata": "lvis_v1_val", + "classifier": os.path.join( + detic_path, "datasets/metadata/lvis_v1_clip_a+cname.npy" + ), + }, + "objects365": { + "metadata": "objects365_v2_val", + "classifier": os.path.join( + detic_path, "datasets/metadata/o365_clip_a+cnamefix.npy" + ), + }, + "openimages": { + "metadata": "oid_val_expanded", + "classifier": os.path.join(detic_path, "datasets/metadata/oid_clip_a+cname.npy"), + }, + "coco": { + "metadata": "coco_2017_val", + "classifier": os.path.join(detic_path, "datasets/metadata/coco_clip_a+cname.npy"), + }, + } + + # Override config paths to use absolute paths + self.cfg.MODEL.ROI_BOX_HEAD.CAT_FREQ_PATH = os.path.join( + detic_path, "datasets/metadata/lvis_v1_train_cat_info.json" + ) + + # Initialize model + self.predictor = None + + # Setup with initial vocabulary + vocabulary = vocabulary or "lvis" + self.setup_vocabulary(vocabulary) # type: ignore[no-untyped-call] + + # Initialize our simple tracker + self.tracker = SimpleTracker(iou_threshold=0.5, max_age=5) + + def setup_vocabulary(self, vocabulary): # type: ignore[no-untyped-def] + """ + Setup the model's vocabulary. + + Args: + vocabulary: Either a string ('lvis', 'objects365', 'openimages', 'coco') + or a list of class names for custom vocabulary. + """ + if self.predictor is None: + # Initialize the model + from detectron2.engine import DefaultPredictor # type: ignore[import-not-found] + + self.predictor = DefaultPredictor(self.cfg) + + if isinstance(vocabulary, str) and vocabulary in self.builtin_datasets: + # Use built-in dataset + dataset = vocabulary + metadata = MetadataCatalog.get(self.builtin_datasets[dataset]["metadata"]) + classifier = self.builtin_datasets[dataset]["classifier"] + num_classes = len(metadata.thing_classes) + self.class_names = metadata.thing_classes + else: + # Use custom vocabulary + if isinstance(vocabulary, str): + # If it's a string but not a built-in dataset, treat as a file + try: + with open(vocabulary) as f: + class_names = [line.strip() for line in f if line.strip()] + except: + # Default to LVIS if there's an issue + print(f"Error loading vocabulary from {vocabulary}, using LVIS") + return self.setup_vocabulary("lvis") # type: ignore[no-untyped-call] + else: + # Assume it's a list of class names + class_names = vocabulary + + # Create classifier from text embeddings + metadata = MetadataCatalog.get("__unused") + metadata.thing_classes = class_names + self.class_names = class_names + + # Generate CLIP embeddings for custom vocabulary + classifier = self._get_clip_embeddings(class_names) + num_classes = len(class_names) + + # Reset model with new vocabulary + self.reset_cls_test(self.predictor.model, classifier, num_classes) # type: ignore[attr-defined] + return self.class_names + + def _get_clip_embeddings(self, vocabulary, prompt: str = "a "): # type: ignore[no-untyped-def] + """ + Generate CLIP embeddings for a vocabulary list. + + Args: + vocabulary (list): List of class names + prompt (str): Prompt prefix to use for CLIP + + Returns: + torch.Tensor: Tensor of embeddings + """ + text_encoder = self.build_text_encoder(pretrain=True) + text_encoder.eval() + texts = [prompt + x for x in vocabulary] + emb = text_encoder(texts).detach().permute(1, 0).contiguous().cpu() + return emb + + def process_image(self, image: Image): # type: ignore[no-untyped-def] + """ + Process an image and return detection results. + + Args: + image: Input image in BGR format (OpenCV) + + Returns: + tuple: (bboxes, track_ids, class_ids, confidences, names, masks) + - bboxes: list of [x1, y1, x2, y2] coordinates + - track_ids: list of tracking IDs (or -1 if no tracking) + - class_ids: list of class indices + - confidences: list of detection confidences + - names: list of class names + - masks: list of segmentation masks (numpy arrays) + """ + # Run inference with Detic + outputs = self.predictor(image.to_opencv()) # type: ignore[misc] + instances = outputs["instances"].to("cpu") + + # Extract bounding boxes, classes, scores, and masks + if len(instances) == 0: + return [], [], [], [], [] # , [] + + boxes = instances.pred_boxes.tensor.numpy() + class_ids = instances.pred_classes.numpy() + scores = instances.scores.numpy() + masks = instances.pred_masks.numpy() + + # Convert boxes to [x1, y1, x2, y2] format + bboxes = [] + for box in boxes: + x1, y1, x2, y2 = box.tolist() + bboxes.append([x1, y1, x2, y2]) + + # Get class names + [self.class_names[class_id] for class_id in class_ids] + + # Apply tracking + detections = [] + filtered_masks = [] + for i, bbox in enumerate(bboxes): + if scores[i] >= self.threshold: + # Format for tracker: [x1, y1, x2, y2, score, class_id] + detections.append([*bbox, scores[i], class_ids[i]]) + filtered_masks.append(masks[i]) + + if not detections: + return [], [], [], [], [] # , [] + + # Update tracker with detections and correctly aligned masks + track_results = self.tracker.update(detections, filtered_masks) # type: ignore[no-untyped-call] + + # Process tracking results + track_ids = [] + tracked_bboxes = [] + tracked_class_ids = [] + tracked_scores = [] + tracked_names = [] + tracked_masks = [] + + for track_id, bbox, score, class_id, mask in track_results: + track_ids.append(int(track_id)) + tracked_bboxes.append(bbox.tolist() if isinstance(bbox, np.ndarray) else bbox) + tracked_class_ids.append(int(class_id)) + tracked_scores.append(score) + tracked_names.append(self.class_names[int(class_id)]) + tracked_masks.append(mask) + + return ( + tracked_bboxes, + track_ids, + tracked_class_ids, + tracked_scores, + tracked_names, + # tracked_masks, + ) + + def visualize_results( # type: ignore[no-untyped-def] + self, image, bboxes, track_ids, class_ids, confidences, names: Sequence[str] + ): + """ + Generate visualization of detection results. + + Args: + image: Original input image + bboxes: List of bounding boxes + track_ids: List of tracking IDs + class_ids: List of class indices + confidences: List of detection confidences + names: List of class names + + Returns: + Image with visualized detections + """ + + return plot_results(image, bboxes, track_ids, class_ids, confidences, names) + + def cleanup(self) -> None: + """Clean up resources.""" + # Nothing specific to clean up for Detic + pass diff --git a/dimos/perception/detection/detectors/person/test_person_detectors.py b/dimos/perception/detection/detectors/person/test_person_detectors.py new file mode 100644 index 0000000000..2ed7cdc7dc --- /dev/null +++ b/dimos/perception/detection/detectors/person/test_person_detectors.py @@ -0,0 +1,160 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from dimos.perception.detection.type import Detection2DPerson, ImageDetections2D + + +@pytest.fixture(scope="session") +def people(person_detector, test_image): + return person_detector.process_image(test_image) + + +@pytest.fixture(scope="session") +def person(people): + return people[0] + + +def test_person_detection(people) -> None: + """Test that we can detect people with pose keypoints.""" + assert len(people) > 0 + + # Check first person + person = people[0] + assert isinstance(person, Detection2DPerson) + assert person.confidence > 0 + assert len(person.bbox) == 4 # bbox is now a tuple + assert person.keypoints.shape == (17, 2) + assert person.keypoint_scores.shape == (17,) + + +def test_person_properties(people) -> None: + """Test Detection2DPerson object properties and methods.""" + person = people[0] + + # Test bounding box properties + assert person.width > 0 + assert person.height > 0 + assert len(person.center) == 2 + + # Test keypoint access + nose_xy, nose_conf = person.get_keypoint("nose") + assert nose_xy.shape == (2,) + assert 0 <= nose_conf <= 1 + + # Test visible keypoints + visible = person.get_visible_keypoints(threshold=0.5) + assert len(visible) > 0 + assert all(isinstance(name, str) for name, _, _ in visible) + assert all(xy.shape == (2,) for _, xy, _ in visible) + assert all(0 <= conf <= 1 for _, _, conf in visible) + + +def test_person_normalized_coords(people) -> None: + """Test normalized coordinates if available.""" + person = people[0] + + if person.keypoints_normalized is not None: + assert person.keypoints_normalized.shape == (17, 2) + # Check all values are in 0-1 range + assert (person.keypoints_normalized >= 0).all() + assert (person.keypoints_normalized <= 1).all() + + if person.bbox_normalized is not None: + assert person.bbox_normalized.shape == (4,) + assert (person.bbox_normalized >= 0).all() + assert (person.bbox_normalized <= 1).all() + + +def test_multiple_people(people) -> None: + """Test that multiple people can be detected.""" + print(f"\nDetected {len(people)} people in test image") + + for i, person in enumerate(people[:3]): # Show first 3 + print(f"\nPerson {i}:") + print(f" Confidence: {person.confidence:.3f}") + print(f" Size: {person.width:.1f} x {person.height:.1f}") + + visible = person.get_visible_keypoints(threshold=0.8) + print(f" High-confidence keypoints (>0.8): {len(visible)}") + for name, xy, conf in visible[:5]: + print(f" {name}: ({xy[0]:.1f}, {xy[1]:.1f}) conf={conf:.3f}") + + +def test_image_detections2d_structure(people) -> None: + """Test that process_image returns ImageDetections2D.""" + assert isinstance(people, ImageDetections2D) + assert len(people.detections) > 0 + assert all(isinstance(d, Detection2DPerson) for d in people.detections) + + +def test_invalid_keypoint(test_image) -> None: + """Test error handling for invalid keypoint names.""" + # Create a dummy Detection2DPerson + import numpy as np + + person = Detection2DPerson( + # Detection2DBBox fields + bbox=(0.0, 0.0, 100.0, 100.0), + track_id=0, + class_id=0, + confidence=0.9, + name="person", + ts=test_image.ts, + image=test_image, + # Detection2DPerson fields + keypoints=np.zeros((17, 2)), + keypoint_scores=np.zeros(17), + ) + + with pytest.raises(ValueError): + person.get_keypoint("invalid_keypoint") + + +def test_person_annotations(person) -> None: + # Test text annotations + text_anns = person.to_text_annotation() + print(f"\nText annotations: {len(text_anns)}") + for i, ann in enumerate(text_anns): + print(f" {i}: {ann.text}") + assert len(text_anns) == 3 # confidence, name/track_id, keypoints count + assert any("keypoints:" in ann.text for ann in text_anns) + + # Test points annotations + points_anns = person.to_points_annotation() + print(f"\nPoints annotations: {len(points_anns)}") + + # Count different types (use actual LCM constants) + from dimos_lcm.foxglove_msgs.ImageAnnotations import PointsAnnotation + + bbox_count = sum(1 for ann in points_anns if ann.type == PointsAnnotation.LINE_LOOP) # 2 + keypoint_count = sum(1 for ann in points_anns if ann.type == PointsAnnotation.POINTS) # 1 + skeleton_count = sum(1 for ann in points_anns if ann.type == PointsAnnotation.LINE_LIST) # 4 + + print(f" - Bounding boxes: {bbox_count}") + print(f" - Keypoint circles: {keypoint_count}") + print(f" - Skeleton lines: {skeleton_count}") + + assert bbox_count >= 1 # At least the person bbox + assert keypoint_count >= 1 # At least some visible keypoints + assert skeleton_count >= 1 # At least some skeleton connections + + # Test full image annotations + img_anns = person.to_image_annotations() + assert img_anns.texts_length == len(text_anns) + assert img_anns.points_length == len(points_anns) + + print("\n✓ Person annotations working correctly!") + print(f" - {len(person.get_visible_keypoints(0.5))}/17 visible keypoints") diff --git a/dimos/perception/detection/detectors/person/yolo.py b/dimos/perception/detection/detectors/person/yolo.py new file mode 100644 index 0000000000..bc9bf25734 --- /dev/null +++ b/dimos/perception/detection/detectors/person/yolo.py @@ -0,0 +1,80 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ultralytics import YOLO + +from dimos.msgs.sensor_msgs import Image +from dimos.perception.detection.detectors.types import Detector +from dimos.perception.detection.type import ImageDetections2D +from dimos.utils.data import get_data +from dimos.utils.gpu_utils import is_cuda_available +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class YoloPersonDetector(Detector): + def __init__( + self, + model_path: str = "models_yolo", + model_name: str = "yolo11n-pose.pt", + device: str | None = None, + ) -> None: + self.model = YOLO(get_data(model_path) / model_name, task="track") + + self.tracker = get_data(model_path) / "botsort.yaml" + + if device: + self.device = device + return + + if is_cuda_available(): # type: ignore[no-untyped-call] + self.device = "cuda" + logger.info("Using CUDA for YOLO person detector") + else: + self.device = "cpu" + logger.info("Using CPU for YOLO person detector") + + def process_image(self, image: Image) -> ImageDetections2D: + """Process image and return detection results. + + Args: + image: Input image + + Returns: + ImageDetections2D containing Detection2DPerson objects with pose keypoints + """ + results = self.model.track( + source=image.to_opencv(), + verbose=False, + conf=0.5, + tracker=self.tracker, + persist=True, + device=self.device, + ) + return ImageDetections2D.from_ultralytics_result(image, results) + + def stop(self) -> None: + """ + Clean up resources used by the detector, including tracker threads. + """ + if hasattr(self.model, "predictor") and self.model.predictor is not None: + predictor = self.model.predictor + if hasattr(predictor, "trackers") and predictor.trackers: + for tracker in predictor.trackers: + if hasattr(tracker, "tracker") and hasattr(tracker.tracker, "gmc"): + gmc = tracker.tracker.gmc + if hasattr(gmc, "executor") and gmc.executor is not None: + gmc.executor.shutdown(wait=True) + self.model.predictor = None diff --git a/dimos/perception/detection/detectors/test_bbox_detectors.py b/dimos/perception/detection/detectors/test_bbox_detectors.py new file mode 100644 index 0000000000..bd9c1358b5 --- /dev/null +++ b/dimos/perception/detection/detectors/test_bbox_detectors.py @@ -0,0 +1,158 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from dimos.perception.detection.type import Detection2D, ImageDetections2D + + +@pytest.fixture(params=["bbox_detector", "person_detector"], scope="session") +def detector(request): + """Parametrized fixture that provides both bbox and person detectors.""" + return request.getfixturevalue(request.param) + + +@pytest.fixture(scope="session") +def detections(detector, test_image): + """Get ImageDetections2D from any detector.""" + return detector.process_image(test_image) + + +def test_detection_basic(detections) -> None: + """Test that we can detect objects with all detectors.""" + assert len(detections.detections) > 0 + + # Check first detection + detection = detections.detections[0] + assert isinstance(detection, Detection2D) + assert detection.confidence > 0 + assert len(detection.bbox) == 4 # bbox is a tuple (x1, y1, x2, y2) + assert detection.class_id >= 0 + assert detection.name is not None + + +def test_detection_bbox_properties(detections) -> None: + """Test Detection2D bbox properties work for all detectors.""" + detection = detections.detections[0] + + # Test bounding box is valid + x1, y1, x2, y2 = detection.bbox + assert x2 > x1, "x2 should be greater than x1" + assert y2 > y1, "y2 should be greater than y1" + assert all(coord >= 0 for coord in detection.bbox), "Coordinates should be non-negative" + + # Test bbox volume + volume = detection.bbox_2d_volume() + assert volume > 0 + expected_volume = (x2 - x1) * (y2 - y1) + assert abs(volume - expected_volume) < 0.01 + + # Test center calculation + center_x, center_y, width, height = detection.get_bbox_center() + assert center_x == (x1 + x2) / 2.0 + assert center_y == (y1 + y2) / 2.0 + assert width == x2 - x1 + assert height == y2 - y1 + + +def test_detection_cropped_image(detections, test_image) -> None: + """Test cropping image to detection bbox.""" + detection = detections.detections[0] + + # Test cropped image + cropped = detection.cropped_image(padding=20) + assert cropped is not None + + # Cropped image should be smaller than original (usually) + if test_image.shape: + assert cropped.shape[0] <= test_image.shape[0] + assert cropped.shape[1] <= test_image.shape[1] + + +def test_detection_annotations(detections) -> None: + """Test annotation generation for detections.""" + detection = detections.detections[0] + + # Test text annotations - all detections should have at least 2 + text_annotations = detection.to_text_annotation() + assert len(text_annotations) >= 2 # confidence and name/track_id (person has keypoints too) + + # Test points annotations - at least bbox + points_annotations = detection.to_points_annotation() + assert len(points_annotations) >= 1 # At least the bbox polygon + + # Test image annotations + annotations = detection.to_image_annotations() + assert annotations.texts_length >= 2 + assert annotations.points_length >= 1 + + +def test_detection_ros_conversion(detections) -> None: + """Test conversion to ROS Detection2D message.""" + detection = detections.detections[0] + + ros_det = detection.to_ros_detection2d() + + # Check bbox conversion + center_x, center_y, width, height = detection.get_bbox_center() + assert abs(ros_det.bbox.center.position.x - center_x) < 0.01 + assert abs(ros_det.bbox.center.position.y - center_y) < 0.01 + assert abs(ros_det.bbox.size_x - width) < 0.01 + assert abs(ros_det.bbox.size_y - height) < 0.01 + + # Check confidence and class_id + assert len(ros_det.results) > 0 + assert ros_det.results[0].hypothesis.score == detection.confidence + assert ros_det.results[0].hypothesis.class_id == detection.class_id + + +def test_detection_is_valid(detections) -> None: + """Test bbox validation.""" + detection = detections.detections[0] + + # Detection from real detector should be valid + assert detection.is_valid() + + +def test_image_detections2d_structure(detections) -> None: + """Test that process_image returns ImageDetections2D.""" + assert isinstance(detections, ImageDetections2D) + assert len(detections.detections) > 0 + assert all(isinstance(d, Detection2D) for d in detections.detections) + + +def test_multiple_detections(detections) -> None: + """Test that multiple objects can be detected.""" + print(f"\nDetected {len(detections.detections)} objects in test image") + + for i, detection in enumerate(detections.detections[:5]): # Show first 5 + print(f"\nDetection {i}:") + print(f" Class: {detection.name} (id: {detection.class_id})") + print(f" Confidence: {detection.confidence:.3f}") + print( + f" Bbox: ({detection.bbox[0]:.1f}, {detection.bbox[1]:.1f}, {detection.bbox[2]:.1f}, {detection.bbox[3]:.1f})" + ) + print(f" Track ID: {detection.track_id}") + + +def test_detection_string_representation(detections) -> None: + """Test string representation of detections.""" + detection = detections.detections[0] + str_repr = str(detection) + + # Should contain class name (either Detection2DBBox or Detection2DPerson) + assert "Detection2D" in str_repr + + # Should show object name + assert detection.name in str_repr or f"class_{detection.class_id}" in str_repr diff --git a/dimos/perception/detection/detectors/types.py b/dimos/perception/detection/detectors/types.py new file mode 100644 index 0000000000..e85c5ae18e --- /dev/null +++ b/dimos/perception/detection/detectors/types.py @@ -0,0 +1,23 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod + +from dimos.msgs.sensor_msgs import Image +from dimos.perception.detection.type import ImageDetections2D + + +class Detector(ABC): + @abstractmethod + def process_image(self, image: Image) -> ImageDetections2D: ... diff --git a/dimos/perception/detection/detectors/yolo.py b/dimos/perception/detection/detectors/yolo.py new file mode 100644 index 0000000000..a5f9cc8282 --- /dev/null +++ b/dimos/perception/detection/detectors/yolo.py @@ -0,0 +1,83 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ultralytics import YOLO + +from dimos.msgs.sensor_msgs import Image +from dimos.perception.detection.detectors.types import Detector +from dimos.perception.detection.type import ImageDetections2D +from dimos.utils.data import get_data +from dimos.utils.gpu_utils import is_cuda_available +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class Yolo2DDetector(Detector): + def __init__( + self, + model_path: str = "models_yolo", + model_name: str = "yolo11n.pt", + device: str | None = None, + ) -> None: + self.model = YOLO( + get_data(model_path) / model_name, + task="detect", + ) + + if device: + self.device = device + return + + if is_cuda_available(): # type: ignore[no-untyped-call] + self.device = "cuda" + logger.debug("Using CUDA for YOLO 2d detector") + else: + self.device = "cpu" + logger.debug("Using CPU for YOLO 2d detector") + + def process_image(self, image: Image) -> ImageDetections2D: + """ + Process an image and return detection results. + + Args: + image: Input image + + Returns: + ImageDetections2D containing all detected objects + """ + results = self.model.track( + source=image.to_opencv(), + device=self.device, + conf=0.5, + iou=0.6, + persist=True, + verbose=False, + ) + + return ImageDetections2D.from_ultralytics_result(image, results) + + def stop(self) -> None: + """ + Clean up resources used by the detector, including tracker threads. + """ + if hasattr(self.model, "predictor") and self.model.predictor is not None: + predictor = self.model.predictor + if hasattr(predictor, "trackers") and predictor.trackers: + for tracker in predictor.trackers: + if hasattr(tracker, "tracker") and hasattr(tracker.tracker, "gmc"): + gmc = tracker.tracker.gmc + if hasattr(gmc, "executor") and gmc.executor is not None: + gmc.executor.shutdown(wait=True) + self.model.predictor = None diff --git a/dimos/perception/detection/module2D.py b/dimos/perception/detection/module2D.py new file mode 100644 index 0000000000..470c2bd42d --- /dev/null +++ b/dimos/perception/detection/module2D.py @@ -0,0 +1,179 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + +from dimos_lcm.foxglove_msgs.ImageAnnotations import ( # type: ignore[import-untyped] + ImageAnnotations, +) +from reactivex import operators as ops +from reactivex.observable import Observable +from reactivex.subject import Subject + +from dimos import spec +from dimos.core import DimosCluster, In, Module, Out, rpc +from dimos.core.module import ModuleConfig +from dimos.msgs.geometry_msgs import Transform, Vector3 +from dimos.msgs.sensor_msgs import CameraInfo, Image +from dimos.msgs.sensor_msgs.Image import sharpness_barrier +from dimos.msgs.vision_msgs import Detection2DArray +from dimos.perception.detection.detectors import Detector # type: ignore[attr-defined] +from dimos.perception.detection.detectors.yolo import Yolo2DDetector +from dimos.perception.detection.type import Filter2D, ImageDetections2D +from dimos.utils.decorators.decorators import simple_mcache +from dimos.utils.reactive import backpressure + + +@dataclass +class Config(ModuleConfig): + max_freq: float = 10 + detector: Callable[[Any], Detector] | None = Yolo2DDetector + publish_detection_images: bool = True + camera_info: CameraInfo = None # type: ignore + filter: list[Filter2D] | Filter2D | None = None + + def __post_init__(self) -> None: + if self.filter is None: + self.filter = [] + elif not isinstance(self.filter, list): + self.filter = [self.filter] + + +class Detection2DModule(Module): + default_config = Config + config: Config + detector: Detector + + image: In[Image] = None # type: ignore + + detections: Out[Detection2DArray] = None # type: ignore + annotations: Out[ImageAnnotations] = None # type: ignore + + detected_image_0: Out[Image] = None # type: ignore + detected_image_1: Out[Image] = None # type: ignore + detected_image_2: Out[Image] = None # type: ignore + + cnt: int = 0 + + def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + super().__init__(*args, **kwargs) + self.detector = self.config.detector() # type: ignore[call-arg, misc] + self.vlm_detections_subject = Subject() # type: ignore[var-annotated] + self.previous_detection_count = 0 + + def process_image_frame(self, image: Image) -> ImageDetections2D: + imageDetections = self.detector.process_image(image) + if not self.config.filter: + return imageDetections + return imageDetections.filter(*self.config.filter) # type: ignore[misc, return-value] + + @simple_mcache + def sharp_image_stream(self) -> Observable[Image]: + return backpressure( + self.image.pure_observable().pipe( + sharpness_barrier(self.config.max_freq), + ) + ) + + @simple_mcache + def detection_stream_2d(self) -> Observable[ImageDetections2D]: + return backpressure(self.sharp_image_stream().pipe(ops.map(self.process_image_frame))) + + def track(self, detections: ImageDetections2D) -> None: + sensor_frame = self.tf.get("sensor", "camera_optical", detections.image.ts, 5.0) + + if not sensor_frame: + return + + if not detections.detections: + return + + sensor_frame.child_frame_id = "sensor_frame" + transforms = [sensor_frame] + + current_count = len(detections.detections) + max_count = max(current_count, self.previous_detection_count) + + # Publish transforms for all detection slots up to max_count + for index in range(max_count): + if index < current_count: + # Active detection - compute real position + detection = detections.detections[index] + position_3d = self.pixel_to_3d( # type: ignore[attr-defined] + detection.center_bbox, # type: ignore[attr-defined] + self.config.camera_info, + assumed_depth=1.0, + ) + else: + # No detection at this index - publish zero transform + position_3d = Vector3(0.0, 0.0, 0.0) + + transforms.append( + Transform( + frame_id=sensor_frame.child_frame_id, + child_frame_id=f"det_{index}", + ts=detections.image.ts, + translation=position_3d, + ) + ) + + self.previous_detection_count = current_count + self.tf.publish(*transforms) + + @rpc + def start(self) -> None: + # self.detection_stream_2d().subscribe(self.track) + + self.detection_stream_2d().subscribe( + lambda det: self.detections.publish(det.to_ros_detection2d_array()) + ) + + self.detection_stream_2d().subscribe( + lambda det: self.annotations.publish(det.to_foxglove_annotations()) + ) + + def publish_cropped_images(detections: ImageDetections2D) -> None: + for index, detection in enumerate(detections[:3]): + image_topic = getattr(self, "detected_image_" + str(index)) + image_topic.publish(detection.cropped_image()) + + if self.config.publish_detection_images: + self.detection_stream_2d().subscribe(publish_cropped_images) + + @rpc + def stop(self) -> None: + return super().stop() # type: ignore[no-any-return] + + +def deploy( # type: ignore[no-untyped-def] + dimos: DimosCluster, + camera: spec.Camera, + prefix: str = "/detector2d", + **kwargs, +) -> Detection2DModule: + from dimos.core import LCMTransport + + detector = Detection2DModule(**kwargs) + detector.image.connect(camera.color_image) + + detector.annotations.transport = LCMTransport(f"{prefix}/annotations", ImageAnnotations) + detector.detections.transport = LCMTransport(f"{prefix}/detections", Detection2DArray) + + detector.detected_image_0.transport = LCMTransport(f"{prefix}/image/0", Image) + detector.detected_image_1.transport = LCMTransport(f"{prefix}/image/1", Image) + detector.detected_image_2.transport = LCMTransport(f"{prefix}/image/2", Image) + + detector.start() + return detector diff --git a/dimos/perception/detection/module3D.py b/dimos/perception/detection/module3D.py new file mode 100644 index 0000000000..2f7e3358bf --- /dev/null +++ b/dimos/perception/detection/module3D.py @@ -0,0 +1,231 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dimos_lcm.foxglove_msgs.ImageAnnotations import ( # type: ignore[import-untyped] + ImageAnnotations, +) +from lcm_msgs.foxglove_msgs import SceneUpdate # type: ignore[import-not-found] +from reactivex import operators as ops +from reactivex.observable import Observable + +from dimos import spec +from dimos.agents2 import skill # type: ignore[attr-defined] +from dimos.core import DimosCluster, In, Out, rpc +from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Vector3 +from dimos.msgs.sensor_msgs import Image, PointCloud2 +from dimos.msgs.vision_msgs import Detection2DArray +from dimos.perception.detection.module2D import Detection2DModule +from dimos.perception.detection.type import ( + ImageDetections2D, + ImageDetections3DPC, +) +from dimos.perception.detection.type.detection3d import Detection3DPC +from dimos.types.timestamped import align_timestamped +from dimos.utils.reactive import backpressure + + +class Detection3DModule(Detection2DModule): + image: In[Image] = None # type: ignore + pointcloud: In[PointCloud2] = None # type: ignore + + detections: Out[Detection2DArray] = None # type: ignore + annotations: Out[ImageAnnotations] = None # type: ignore + scene_update: Out[SceneUpdate] = None # type: ignore + + # just for visualization, + # emits latest pointclouds of detected objects in a frame + detected_pointcloud_0: Out[PointCloud2] = None # type: ignore + detected_pointcloud_1: Out[PointCloud2] = None # type: ignore + detected_pointcloud_2: Out[PointCloud2] = None # type: ignore + + # just for visualization, emits latest top 3 detections in a frame + detected_image_0: Out[Image] = None # type: ignore + detected_image_1: Out[Image] = None # type: ignore + detected_image_2: Out[Image] = None # type: ignore + + detection_3d_stream: Observable[ImageDetections3DPC] | None = None + + def process_frame( + self, + detections: ImageDetections2D, + pointcloud: PointCloud2, + transform: Transform, + ) -> ImageDetections3DPC: + if not transform: + return ImageDetections3DPC(detections.image, []) + + detection3d_list: list[Detection3DPC] = [] + for detection in detections: + detection3d = Detection3DPC.from_2d( + detection, + world_pointcloud=pointcloud, + camera_info=self.config.camera_info, + world_to_optical_transform=transform, + ) + if detection3d is not None: + detection3d_list.append(detection3d) + + return ImageDetections3DPC(detections.image, detection3d_list) + + def pixel_to_3d( + self, + pixel: tuple[int, int], + assumed_depth: float = 1.0, + ) -> Vector3: + """Unproject 2D pixel coordinates to 3D position in camera optical frame. + + Args: + camera_info: Camera calibration information + assumed_depth: Assumed depth in meters (default 1.0m from camera) + + Returns: + Vector3 position in camera optical frame coordinates + """ + # Extract camera intrinsics + fx, fy = self.config.camera_info.K[0], self.config.camera_info.K[4] + cx, cy = self.config.camera_info.K[2], self.config.camera_info.K[5] + + # Unproject pixel to normalized camera coordinates + x_norm = (pixel[0] - cx) / fx + y_norm = (pixel[1] - cy) / fy + + # Create 3D point at assumed depth in camera optical frame + # Camera optical frame: X right, Y down, Z forward + return Vector3(x_norm * assumed_depth, y_norm * assumed_depth, assumed_depth) + + @skill() + def ask_vlm(self, question: str) -> str: + """asks a visual model about the view of the robot, for example + is the bannana in the trunk? + """ + from dimos.models.vl.qwen import QwenVlModel + + model = QwenVlModel() + image = self.image.get_next() + return model.query(image, question) + + # @skill # type: ignore[arg-type] + @rpc + def nav_vlm(self, question: str) -> str: + """ + query visual model about the view in front of the camera + you can ask to mark objects like: + + "red cup on the table left of the pencil" + "laptop on the desk" + "a person wearing a red shirt" + """ + from dimos.models.vl.qwen import QwenVlModel + + model = QwenVlModel() + image = self.image.get_next() + result = model.query_detections(image, question) + + print("VLM result:", result, "for", image, "and question", question) + + if isinstance(result, str) or not result or not len(result): + return None # type: ignore[return-value] + + detections: ImageDetections2D = result + + print(detections) + if not len(detections): + print("No 2d detections") + return None # type: ignore[return-value] + + pc = self.pointcloud.get_next() + transform = self.tf.get("camera_optical", pc.frame_id, detections.image.ts, 5.0) + + detections3d = self.process_frame(detections, pc, transform) + + if len(detections3d): + return detections3d[0].pose # type: ignore[no-any-return] + print("No 3d detections, projecting 2d") + + center = detections[0].get_bbox_center() + return PoseStamped( + ts=detections.image.ts, + frame_id="world", + position=self.pixel_to_3d(center, assumed_depth=1.5), + orientation=Quaternion(0.0, 0.0, 0.0, 1.0), + ) + + @rpc + def start(self) -> None: + super().start() + + def detection2d_to_3d(args): # type: ignore[no-untyped-def] + detections, pc = args + transform = self.tf.get("camera_optical", pc.frame_id, detections.image.ts, 5.0) + return self.process_frame(detections, pc, transform) + + self.detection_stream_3d = align_timestamped( + backpressure(self.detection_stream_2d()), + self.pointcloud.observable(), # type: ignore[no-untyped-call] + match_tolerance=0.25, + buffer_size=20.0, + ).pipe(ops.map(detection2d_to_3d)) + + self.detection_stream_3d.subscribe(self._publish_detections) + + @rpc + def stop(self) -> None: + super().stop() + + def _publish_detections(self, detections: ImageDetections3DPC) -> None: + if not detections: + return + + for index, detection in enumerate(detections[:3]): + pointcloud_topic = getattr(self, "detected_pointcloud_" + str(index)) + pointcloud_topic.publish(detection.pointcloud) + + self.scene_update.publish(detections.to_foxglove_scene_update()) + + +def deploy( # type: ignore[no-untyped-def] + dimos: DimosCluster, + lidar: spec.Pointcloud, + camera: spec.Camera, + prefix: str = "/detector3d", + **kwargs, +) -> Detection3DModule: + from dimos.core import LCMTransport + + detector = dimos.deploy(Detection3DModule, camera_info=camera.hardware_camera_info, **kwargs) # type: ignore[attr-defined] + + detector.image.connect(camera.color_image) + detector.pointcloud.connect(lidar.pointcloud) + + detector.annotations.transport = LCMTransport(f"{prefix}/annotations", ImageAnnotations) + detector.detections.transport = LCMTransport(f"{prefix}/detections", Detection2DArray) + detector.scene_update.transport = LCMTransport(f"{prefix}/scene_update", SceneUpdate) + + detector.detected_image_0.transport = LCMTransport(f"{prefix}/image/0", Image) + detector.detected_image_1.transport = LCMTransport(f"{prefix}/image/1", Image) + detector.detected_image_2.transport = LCMTransport(f"{prefix}/image/2", Image) + + detector.detected_pointcloud_0.transport = LCMTransport(f"{prefix}/pointcloud/0", PointCloud2) + detector.detected_pointcloud_1.transport = LCMTransport(f"{prefix}/pointcloud/1", PointCloud2) + detector.detected_pointcloud_2.transport = LCMTransport(f"{prefix}/pointcloud/2", PointCloud2) + + detector.start() + + return detector # type: ignore[no-any-return] + + +detection3d_module = Detection3DModule.blueprint + +__all__ = ["Detection3DModule", "deploy", "detection3d_module"] diff --git a/dimos/perception/detection/moduleDB.py b/dimos/perception/detection/moduleDB.py new file mode 100644 index 0000000000..3b57f70418 --- /dev/null +++ b/dimos/perception/detection/moduleDB.py @@ -0,0 +1,343 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections.abc import Callable +from copy import copy +import threading +import time +from typing import Any + +from dimos_lcm.foxglove_msgs.ImageAnnotations import ( # type: ignore[import-untyped] + ImageAnnotations, +) +from lcm_msgs.foxglove_msgs import SceneUpdate # type: ignore[import-not-found] +from reactivex.observable import Observable + +from dimos import spec +from dimos.core import DimosCluster, In, Out, rpc +from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Vector3 +from dimos.msgs.sensor_msgs import Image, PointCloud2 +from dimos.msgs.vision_msgs import Detection2DArray +from dimos.perception.detection.module3D import Detection3DModule +from dimos.perception.detection.type import ImageDetections3DPC, TableStr +from dimos.perception.detection.type.detection3d import Detection3DPC + + +# Represents an object in space, as collection of 3d detections over time +class Object3D(Detection3DPC): + best_detection: Detection3DPC | None = None + center: Vector3 | None = None # type: ignore + track_id: str | None = None # type: ignore + detections: int = 0 + + def to_repr_dict(self) -> dict[str, Any]: + if self.center is None: + center_str = "None" + else: + center_str = ( + "[" + ", ".join(list(map(lambda n: f"{n:1f}", self.center.to_list()))) + "]" + ) + return { + "object_id": self.track_id, + "detections": self.detections, + "center": center_str, + } + + def __init__( # type: ignore[no-untyped-def] + self, track_id: str, detection: Detection3DPC | None = None, *args, **kwargs + ) -> None: + if detection is None: + return + self.ts = detection.ts + self.track_id = track_id + self.class_id = detection.class_id + self.name = detection.name + self.confidence = detection.confidence + self.pointcloud = detection.pointcloud + self.bbox = detection.bbox + self.transform = detection.transform + self.center = detection.center + self.frame_id = detection.frame_id + self.detections = self.detections + 1 + self.best_detection = detection + + def __add__(self, detection: Detection3DPC) -> "Object3D": + if self.track_id is None: + raise ValueError("Cannot add detection to object with None track_id") + new_object = Object3D(self.track_id) + new_object.bbox = detection.bbox + new_object.confidence = max(self.confidence, detection.confidence) + new_object.ts = max(self.ts, detection.ts) + new_object.track_id = self.track_id + new_object.class_id = self.class_id + new_object.name = self.name + new_object.transform = self.transform + new_object.pointcloud = self.pointcloud + detection.pointcloud + new_object.frame_id = self.frame_id + new_object.center = (self.center + detection.center) / 2 + new_object.detections = self.detections + 1 + + if detection.bbox_2d_volume() > self.bbox_2d_volume(): + new_object.best_detection = detection + else: + new_object.best_detection = self.best_detection + + return new_object + + def get_image(self) -> Image | None: + return self.best_detection.image if self.best_detection else None + + def scene_entity_label(self) -> str: + return f"{self.name} ({self.detections})" + + def agent_encode(self): # type: ignore[no-untyped-def] + return { + "id": self.track_id, + "name": self.name, + "detections": self.detections, + "last_seen": f"{round(time.time() - self.ts)}s ago", + # "position": self.to_pose().position.agent_encode(), + } + + def to_pose(self) -> PoseStamped: + if self.best_detection is None or self.center is None: + raise ValueError("Cannot compute pose without best_detection and center") + + optical_inverse = Transform( + translation=Vector3(0.0, 0.0, 0.0), + rotation=Quaternion(-0.5, 0.5, -0.5, 0.5), + frame_id="camera_link", + child_frame_id="camera_optical", + ).inverse() + + print("transform is", self.best_detection.transform) + + global_transform = optical_inverse + self.best_detection.transform + + print("inverse optical is", global_transform) + + print("obj center is", self.center) + global_pose = global_transform.to_pose() + print("Global pose:", global_pose) + global_pose.frame_id = self.best_detection.frame_id + print("remap to", self.best_detection.frame_id) + return PoseStamped( + position=self.center, orientation=Quaternion(), frame_id=self.best_detection.frame_id + ) + + +class ObjectDBModule(Detection3DModule, TableStr): + cnt: int = 0 + objects: dict[str, Object3D] + object_stream: Observable[Object3D] | None = None + + goto: Callable[[PoseStamped], Any] | None = None + + image: In[Image] = None # type: ignore + pointcloud: In[PointCloud2] = None # type: ignore + + detections: Out[Detection2DArray] = None # type: ignore + annotations: Out[ImageAnnotations] = None # type: ignore + + detected_pointcloud_0: Out[PointCloud2] = None # type: ignore + detected_pointcloud_1: Out[PointCloud2] = None # type: ignore + detected_pointcloud_2: Out[PointCloud2] = None # type: ignore + + detected_image_0: Out[Image] = None # type: ignore + detected_image_1: Out[Image] = None # type: ignore + detected_image_2: Out[Image] = None # type: ignore + + scene_update: Out[SceneUpdate] = None # type: ignore + + target: Out[PoseStamped] = None # type: ignore + + remembered_locations: dict[str, PoseStamped] + + @rpc + def start(self) -> None: + Detection3DModule.start(self) + + def update_objects(imageDetections: ImageDetections3DPC) -> None: + for detection in imageDetections.detections: + self.add_detection(detection) + + def scene_thread() -> None: + while True: + scene_update = self.to_foxglove_scene_update() + self.scene_update.publish(scene_update) + time.sleep(1.0) + + threading.Thread(target=scene_thread, daemon=True).start() + + self.detection_stream_3d.subscribe(update_objects) + + def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + super().__init__(*args, **kwargs) + self.goto = None + self.objects = {} + self.remembered_locations = {} + + def closest_object(self, detection: Detection3DPC) -> Object3D | None: + # Filter objects to only those with matching names + matching_objects = [obj for obj in self.objects.values() if obj.name == detection.name] + + if not matching_objects: + return None + + # Sort by distance + distances = sorted(matching_objects, key=lambda obj: detection.center.distance(obj.center)) + + return distances[0] + + def add_detections(self, detections: list[Detection3DPC]) -> list[Object3D]: + return [ + detection for detection in map(self.add_detection, detections) if detection is not None + ] + + def add_detection(self, detection: Detection3DPC): # type: ignore[no-untyped-def] + """Add detection to existing object or create new one.""" + closest = self.closest_object(detection) + if closest and closest.bounding_box_intersects(detection): + return self.add_to_object(closest, detection) + else: + return self.create_new_object(detection) + + def add_to_object(self, closest: Object3D, detection: Detection3DPC): # type: ignore[no-untyped-def] + new_object = closest + detection + if closest.track_id is not None: + self.objects[closest.track_id] = new_object + return new_object + + def create_new_object(self, detection: Detection3DPC): # type: ignore[no-untyped-def] + new_object = Object3D(f"obj_{self.cnt}", detection) + if new_object.track_id is not None: + self.objects[new_object.track_id] = new_object + self.cnt += 1 + return new_object + + def agent_encode(self) -> str: + ret = [] + for obj in copy(self.objects).values(): + # we need at least 3 detectieons to consider it a valid object + # for this to be serious we need a ratio of detections within the window of observations + if len(obj.detections) < 4: # type: ignore[arg-type] + continue + ret.append(str(obj.agent_encode())) # type: ignore[no-untyped-call] + if not ret: + return "No objects detected yet." + return "\n".join(ret) + + # @rpc + # def vlm_query(self, description: str) -> Object3D | None: # type: ignore[override] + # imageDetections2D = super().ask_vlm(description) + # print("VLM query found", imageDetections2D, "detections") + # time.sleep(3) + + # if not imageDetections2D.detections: + # return None + + # ret = [] + # for obj in self.objects.values(): + # if obj.ts != imageDetections2D.ts: + # print( + # "Skipping", + # obj.track_id, + # "ts", + # obj.ts, + # "!=", + # imageDetections2D.ts, + # ) + # continue + # if obj.class_id != -100: + # continue + # if obj.name != imageDetections2D.detections[0].name: + # print("Skipping", obj.name, "!=", imageDetections2D.detections[0].name) + # continue + # ret.append(obj) + # ret.sort(key=lambda x: x.ts) + + # return ret[0] if ret else None + + def lookup(self, label: str) -> list[Detection3DPC]: + """Look up a detection by label.""" + return [] + + @rpc + def stop(self): # type: ignore[no-untyped-def] + return super().stop() + + def goto_object(self, object_id: str) -> Object3D | None: + """Go to object by id.""" + return self.objects.get(object_id, None) + + def to_foxglove_scene_update(self) -> "SceneUpdate": + """Convert all detections to a Foxglove SceneUpdate message. + + Returns: + SceneUpdate containing SceneEntity objects for all detections + """ + + # Create SceneUpdate message with all detections + scene_update = SceneUpdate() + scene_update.deletions_length = 0 + scene_update.deletions = [] + scene_update.entities = [] + + for obj in self.objects: + try: + scene_update.entities.append( + obj.to_foxglove_scene_entity(entity_id=f"{obj.name}_{obj.track_id}") # type: ignore[attr-defined] + ) + except Exception: + pass + + scene_update.entities_length = len(scene_update.entities) + return scene_update + + def __len__(self) -> int: + return len(self.objects.values()) + + +def deploy( # type: ignore[no-untyped-def] + dimos: DimosCluster, + lidar: spec.Pointcloud, + camera: spec.Camera, + prefix: str = "/detectorDB", + **kwargs, +) -> Detection3DModule: + from dimos.core import LCMTransport + + detector = dimos.deploy(ObjectDBModule, camera_info=camera.camera_info_stream, **kwargs) # type: ignore[attr-defined] + + detector.image.connect(camera.color_image) + detector.pointcloud.connect(lidar.pointcloud) + + detector.annotations.transport = LCMTransport(f"{prefix}/annotations", ImageAnnotations) + detector.detections.transport = LCMTransport(f"{prefix}/detections", Detection2DArray) + detector.scene_update.transport = LCMTransport(f"{prefix}/scene_update", SceneUpdate) + + detector.detected_image_0.transport = LCMTransport(f"{prefix}/image/0", Image) + detector.detected_image_1.transport = LCMTransport(f"{prefix}/image/1", Image) + detector.detected_image_2.transport = LCMTransport(f"{prefix}/image/2", Image) + + detector.detected_pointcloud_0.transport = LCMTransport(f"{prefix}/pointcloud/0", PointCloud2) + detector.detected_pointcloud_1.transport = LCMTransport(f"{prefix}/pointcloud/1", PointCloud2) + detector.detected_pointcloud_2.transport = LCMTransport(f"{prefix}/pointcloud/2", PointCloud2) + + detector.start() + return detector # type: ignore[no-any-return] + + +detectionDB_module = ObjectDBModule.blueprint + +__all__ = ["ObjectDBModule", "deploy", "detectionDB_module"] diff --git a/dimos/perception/detection/person_tracker.py b/dimos/perception/detection/person_tracker.py new file mode 100644 index 0000000000..6883609f3c --- /dev/null +++ b/dimos/perception/detection/person_tracker.py @@ -0,0 +1,128 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Any + +from reactivex import operators as ops +from reactivex.observable import Observable + +from dimos.core import In, Module, Out, rpc +from dimos.msgs.geometry_msgs import PoseStamped, Transform, Vector3 +from dimos.msgs.sensor_msgs import CameraInfo, Image +from dimos.msgs.vision_msgs import Detection2DArray +from dimos.perception.detection.type import ImageDetections2D +from dimos.types.timestamped import align_timestamped +from dimos.utils.reactive import backpressure + + +class PersonTracker(Module): + detections: In[Detection2DArray] = None # type: ignore + image: In[Image] = None # type: ignore + target: Out[PoseStamped] = None # type: ignore + + camera_info: CameraInfo + + def __init__(self, cameraInfo: CameraInfo, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.camera_info = cameraInfo + + def center_to_3d( + self, + pixel: tuple[int, int], + camera_info: CameraInfo, + assumed_depth: float = 1.0, + ) -> Vector3: + """Unproject 2D pixel coordinates to 3D position in camera_link frame. + + Args: + camera_info: Camera calibration information + assumed_depth: Assumed depth in meters (default 1.0m from camera) + + Returns: + Vector3 position in camera_link frame coordinates (Z up, X forward) + """ + # Extract camera intrinsics + fx, fy = camera_info.K[0], camera_info.K[4] + cx, cy = camera_info.K[2], camera_info.K[5] + + # Unproject pixel to normalized camera coordinates + x_norm = (pixel[0] - cx) / fx + y_norm = (pixel[1] - cy) / fy + + # Create 3D point at assumed depth in camera optical frame + # Camera optical frame: X right, Y down, Z forward + x_optical = x_norm * assumed_depth + y_optical = y_norm * assumed_depth + z_optical = assumed_depth + + # Transform from camera optical frame to camera_link frame + # Optical: X right, Y down, Z forward + # Link: X forward, Y left, Z up + # Transformation: x_link = z_optical, y_link = -x_optical, z_link = -y_optical + return Vector3(z_optical, -x_optical, -y_optical) + + def detections_stream(self) -> Observable[ImageDetections2D]: + return backpressure( + align_timestamped( + self.image.pure_observable(), + self.detections.pure_observable().pipe( + ops.filter(lambda d: d.detections_length > 0) # type: ignore[attr-defined] + ), + match_tolerance=0.0, + buffer_size=2.0, + ).pipe( + ops.map( + lambda pair: ImageDetections2D.from_ros_detection2d_array( + *pair # type: ignore[misc] + ) + ) + ) + ) + + @rpc + def start(self) -> None: + self.detections_stream().subscribe(self.track) + + @rpc + def stop(self) -> None: + super().stop() + + def track(self, detections2D: ImageDetections2D) -> None: + if len(detections2D) == 0: + return + + target = max(detections2D.detections, key=lambda det: det.bbox_2d_volume()) # type: ignore[attr-defined] + vector = self.center_to_3d(target.center_bbox, self.camera_info, 2.0) # type: ignore[attr-defined] + + pose_in_camera = PoseStamped( + ts=detections2D.ts, + position=vector, + frame_id="camera_link", + ) + + tf_world_to_camera = self.tf.get("world", "camera_link", detections2D.ts, 5.0) + if not tf_world_to_camera: + return + + tf_camera_to_target = Transform.from_pose("target", pose_in_camera) + tf_world_to_target = tf_world_to_camera + tf_camera_to_target + pose_in_world = tf_world_to_target.to_pose(ts=detections2D.ts) + + self.target.publish(pose_in_world) + + +person_tracker_module = PersonTracker.blueprint + +__all__ = ["PersonTracker", "person_tracker_module"] diff --git a/dimos/perception/detection/reid/__init__.py b/dimos/perception/detection/reid/__init__.py new file mode 100644 index 0000000000..31d50a894b --- /dev/null +++ b/dimos/perception/detection/reid/__init__.py @@ -0,0 +1,13 @@ +from dimos.perception.detection.reid.embedding_id_system import EmbeddingIDSystem +from dimos.perception.detection.reid.module import Config, ReidModule +from dimos.perception.detection.reid.type import IDSystem, PassthroughIDSystem + +__all__ = [ + "Config", + "EmbeddingIDSystem", + # ID Systems + "IDSystem", + "PassthroughIDSystem", + # Module + "ReidModule", +] diff --git a/dimos/perception/detection/reid/embedding_id_system.py b/dimos/perception/detection/reid/embedding_id_system.py new file mode 100644 index 0000000000..904739c978 --- /dev/null +++ b/dimos/perception/detection/reid/embedding_id_system.py @@ -0,0 +1,266 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable +from typing import Literal + +import numpy as np + +from dimos.models.embedding.base import Embedding, EmbeddingModel +from dimos.perception.detection.reid.type import IDSystem +from dimos.perception.detection.type import Detection2DBBox + + +class EmbeddingIDSystem(IDSystem): + """Associates short-term track_ids to long-term unique detection IDs via embedding similarity. + + Maintains: + - All embeddings per track_id (as numpy arrays) for robust group comparison + - Negative constraints from co-occurrence (tracks in same frame = different objects) + - Mapping from track_id to unique long-term ID + """ + + def __init__( + self, + model: Callable[[], EmbeddingModel[Embedding]], + padding: int = 0, + similarity_threshold: float = 0.63, + comparison_mode: Literal["max", "mean", "top_k_mean"] = "top_k_mean", + top_k: int = 30, + max_embeddings_per_track: int = 500, + min_embeddings_for_matching: int = 10, + ) -> None: + """Initialize track associator. + + Args: + model: Callable (class or function) that returns an embedding model for feature extraction + padding: Padding to add around detection bbox when cropping (default: 0) + similarity_threshold: Minimum similarity for associating tracks (0-1) + comparison_mode: How to aggregate similarities between embedding groups + - "max": Use maximum similarity between any pair + - "mean": Use mean of all pairwise similarities + - "top_k_mean": Use mean of top-k similarities + top_k: Number of top similarities to average (if using top_k_mean) + max_embeddings_per_track: Maximum number of embeddings to keep per track + min_embeddings_for_matching: Minimum embeddings before attempting to match tracks + """ + # Call model factory (class or function) to get model instance + self.model = model() + + # Call warmup if available + if hasattr(self.model, "warmup"): + self.model.warmup() + + self.padding = padding + self.similarity_threshold = similarity_threshold + self.comparison_mode = comparison_mode + self.top_k = top_k + self.max_embeddings_per_track = max_embeddings_per_track + self.min_embeddings_for_matching = min_embeddings_for_matching + + # Track embeddings (list of all embeddings as numpy arrays) + self.track_embeddings: dict[int, list[np.ndarray]] = {} # type: ignore[type-arg] + + # Negative constraints (track_ids that co-occurred = different objects) + self.negative_pairs: dict[int, set[int]] = {} + + # Track ID to long-term unique ID mapping + self.track_to_long_term: dict[int, int] = {} + self.long_term_counter: int = 0 + + # Similarity history for optional adaptive thresholding + self.similarity_history: list[float] = [] + + def register_detection(self, detection: Detection2DBBox) -> int: + """ + Register detection and return long-term ID. + + Args: + detection: Detection to register + + Returns: + Long-term unique ID for this detection + """ + # Extract embedding from detection's cropped image + cropped_image = detection.cropped_image(padding=self.padding) + embedding = self.model.embed(cropped_image) + assert not isinstance(embedding, list), "Expected single embedding for single image" + # Move embedding to CPU immediately to free GPU memory + embedding = embedding.to_cpu() + + # Update and associate track + self.update_embedding(detection.track_id, embedding) + return self.associate(detection.track_id) + + def update_embedding(self, track_id: int, new_embedding: Embedding) -> None: + """Add new embedding to track's embedding collection. + + Args: + track_id: Short-term track ID from detector + new_embedding: New embedding to add to collection + """ + # Convert to numpy array (already on CPU from feature extractor) + new_vec = new_embedding.to_numpy() + + # Ensure normalized for cosine similarity + norm = np.linalg.norm(new_vec) + if norm > 0: + new_vec = new_vec / norm + + if track_id not in self.track_embeddings: + self.track_embeddings[track_id] = [] + + embeddings = self.track_embeddings[track_id] + embeddings.append(new_vec) + + # Keep only most recent embeddings if limit exceeded + if len(embeddings) > self.max_embeddings_per_track: + embeddings.pop(0) # Remove oldest + + def _compute_group_similarity( + self, + query_embeddings: list[np.ndarray], # type: ignore[type-arg] + candidate_embeddings: list[np.ndarray], # type: ignore[type-arg] + ) -> float: + """Compute similarity between two groups of embeddings. + + Args: + query_embeddings: List of embeddings for query track + candidate_embeddings: List of embeddings for candidate track + + Returns: + Aggregated similarity score + """ + # Compute all pairwise similarities efficiently + query_matrix = np.stack(query_embeddings) # [M, D] + candidate_matrix = np.stack(candidate_embeddings) # [N, D] + + # Cosine similarity via matrix multiplication (already normalized) + similarities = query_matrix @ candidate_matrix.T # [M, N] + + if self.comparison_mode == "max": + # Maximum similarity across all pairs + return float(np.max(similarities)) + + elif self.comparison_mode == "mean": + # Mean of all pairwise similarities + return float(np.mean(similarities)) + + elif self.comparison_mode == "top_k_mean": + # Mean of top-k similarities + flat_sims = similarities.flatten() + k = min(self.top_k, len(flat_sims)) + top_k_sims = np.partition(flat_sims, -k)[-k:] + return float(np.mean(top_k_sims)) + + else: + raise ValueError(f"Unknown comparison mode: {self.comparison_mode}") + + def add_negative_constraints(self, track_ids: list[int]) -> None: + """Record that these track_ids co-occurred in same frame (different objects). + + Args: + track_ids: List of track_ids present in current frame + """ + # All pairs of track_ids in same frame can't be same object + for i, tid1 in enumerate(track_ids): + for tid2 in track_ids[i + 1 :]: + self.negative_pairs.setdefault(tid1, set()).add(tid2) + self.negative_pairs.setdefault(tid2, set()).add(tid1) + + def associate(self, track_id: int) -> int: + """Associate track_id to long-term unique detection ID. + + Args: + track_id: Short-term track ID to associate + + Returns: + Long-term unique detection ID + """ + # Already has assignment + if track_id in self.track_to_long_term: + return self.track_to_long_term[track_id] + + # Need embeddings to compare + if track_id not in self.track_embeddings or not self.track_embeddings[track_id]: + # Create new ID if no embeddings yet + new_id = self.long_term_counter + self.long_term_counter += 1 + self.track_to_long_term[track_id] = new_id + return new_id + + # Get query embeddings + query_embeddings = self.track_embeddings[track_id] + + # Don't attempt matching until we have enough embeddings for the query track + if len(query_embeddings) < self.min_embeddings_for_matching: + # Not ready yet - return -1 + return -1 + + # Build candidate list (only tracks with assigned long_term_ids) + best_similarity = -1.0 + best_track_id = None + + for other_tid, other_embeddings in self.track_embeddings.items(): + # Skip self + if other_tid == track_id: + continue + + # Skip if negative constraint (co-occurred) + if other_tid in self.negative_pairs.get(track_id, set()): + continue + + # Skip if no long_term_id yet + if other_tid not in self.track_to_long_term: + continue + + # Skip if not enough embeddings + if len(other_embeddings) < self.min_embeddings_for_matching: + continue + + # Compute group similarity + similarity = self._compute_group_similarity(query_embeddings, other_embeddings) + + if similarity > best_similarity: + best_similarity = similarity + best_track_id = other_tid + + # Check if best match exceeds threshold + if best_track_id is not None and best_similarity >= self.similarity_threshold: + matched_long_term_id = self.track_to_long_term[best_track_id] + print( + f"Track {track_id}: matched with track {best_track_id} " + f"(long_term_id={matched_long_term_id}, similarity={best_similarity:.4f}, " + f"mode={self.comparison_mode}, embeddings: {len(query_embeddings)} vs {len(self.track_embeddings[best_track_id])}), threshold: {self.similarity_threshold}" + ) + + # Track similarity history + self.similarity_history.append(best_similarity) + + # Associate with existing long_term_id + self.track_to_long_term[track_id] = matched_long_term_id + return matched_long_term_id + + # Create new unique detection ID + new_id = self.long_term_counter + self.long_term_counter += 1 + self.track_to_long_term[track_id] = new_id + + if best_track_id is not None: + print( + f"Track {track_id}: creating new ID {new_id} " + f"(best similarity={best_similarity:.4f} with id={self.track_to_long_term[best_track_id]} below threshold={self.similarity_threshold})" + ) + + return new_id diff --git a/dimos/perception/detection/reid/module.py b/dimos/perception/detection/reid/module.py new file mode 100644 index 0000000000..2f07d51834 --- /dev/null +++ b/dimos/perception/detection/reid/module.py @@ -0,0 +1,112 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos_lcm.foxglove_msgs.ImageAnnotations import ( # type: ignore[import-untyped] + ImageAnnotations, + TextAnnotation, +) +from dimos_lcm.foxglove_msgs.Point2 import Point2 # type: ignore[import-untyped] +from reactivex import operators as ops +from reactivex.observable import Observable + +from dimos.core import In, Module, ModuleConfig, Out, rpc +from dimos.msgs.foxglove_msgs.Color import Color +from dimos.msgs.sensor_msgs import Image +from dimos.msgs.vision_msgs import Detection2DArray +from dimos.perception.detection.reid.embedding_id_system import EmbeddingIDSystem +from dimos.perception.detection.reid.type import IDSystem +from dimos.perception.detection.type import ImageDetections2D +from dimos.types.timestamped import align_timestamped, to_ros_stamp +from dimos.utils.reactive import backpressure + + +class Config(ModuleConfig): + idsystem: IDSystem + + +class ReidModule(Module): + default_config = Config + + detections: In[Detection2DArray] = None # type: ignore + image: In[Image] = None # type: ignore + annotations: Out[ImageAnnotations] = None # type: ignore + + def __init__(self, idsystem: IDSystem | None = None, **kwargs) -> None: # type: ignore[no-untyped-def] + super().__init__(**kwargs) + if idsystem is None: + try: + from dimos.models.embedding import TorchReIDModel + + idsystem = EmbeddingIDSystem(model=TorchReIDModel, padding=0) # type: ignore[arg-type] + except Exception as e: + raise RuntimeError( + "TorchReIDModel not available. Please install with: pip install dimos[torchreid]" + ) from e + + self.idsystem = idsystem + + def detections_stream(self) -> Observable[ImageDetections2D]: + return backpressure( + align_timestamped( + self.image.pure_observable(), + self.detections.pure_observable().pipe( + ops.filter(lambda d: d.detections_length > 0) # type: ignore[attr-defined] + ), + match_tolerance=0.0, + buffer_size=2.0, + ).pipe(ops.map(lambda pair: ImageDetections2D.from_ros_detection2d_array(*pair))) # type: ignore[misc] + ) + + @rpc + def start(self) -> None: + self.detections_stream().subscribe(self.ingress) + + @rpc + def stop(self) -> None: + super().stop() + + def ingress(self, imageDetections: ImageDetections2D) -> None: + text_annotations = [] + + for detection in imageDetections: + # Register detection and get long-term ID + long_term_id = self.idsystem.register_detection(detection) + + # Skip annotation if not ready yet (long_term_id == -1) + if long_term_id == -1: + continue + + # Create text annotation for long_term_id above the detection + x1, y1, _, _ = detection.bbox + font_size = imageDetections.image.width / 60 + + text_annotations.append( + TextAnnotation( + timestamp=to_ros_stamp(detection.ts), + position=Point2(x=x1, y=y1 - font_size * 1.5), + text=f"PERSON: {long_term_id}", + font_size=font_size, + text_color=Color(r=0.0, g=1.0, b=1.0, a=1.0), # Cyan + background_color=Color(r=0.0, g=0.0, b=0.0, a=0.8), + ) + ) + + # Publish annotations (even if empty to clear previous annotations) + annotations = ImageAnnotations( + texts=text_annotations, + texts_length=len(text_annotations), + points=[], + points_length=0, + ) + self.annotations.publish(annotations) diff --git a/dimos/perception/detection/reid/test_embedding_id_system.py b/dimos/perception/detection/reid/test_embedding_id_system.py new file mode 100644 index 0000000000..8e06af5a89 --- /dev/null +++ b/dimos/perception/detection/reid/test_embedding_id_system.py @@ -0,0 +1,270 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch + +from dimos.msgs.sensor_msgs import Image +from dimos.perception.detection.reid.embedding_id_system import EmbeddingIDSystem +from dimos.utils.data import get_data + + +@pytest.fixture(scope="session") +def mobileclip_model(): + """Load MobileCLIP model once for all tests.""" + from dimos.models.embedding.mobileclip import MobileCLIPModel + + model_path = get_data("models_mobileclip") / "mobileclip2_s0.pt" + model = MobileCLIPModel(model_name="MobileCLIP2-S0", model_path=model_path) + model.warmup() + return model + + +@pytest.fixture +def track_associator(mobileclip_model): + """Create fresh EmbeddingIDSystem for each test.""" + return EmbeddingIDSystem(model=lambda: mobileclip_model, similarity_threshold=0.75) + + +@pytest.fixture(scope="session") +def test_image(): + """Load test image.""" + return Image.from_file(get_data("cafe.jpg")).to_rgb() + + +@pytest.mark.gpu +def test_update_embedding_single(track_associator, mobileclip_model, test_image) -> None: + """Test updating embedding for a single track.""" + embedding = mobileclip_model.embed(test_image) + + # First update + track_associator.update_embedding(track_id=1, new_embedding=embedding) + + assert 1 in track_associator.track_embeddings + assert track_associator.embedding_counts[1] == 1 + + # Verify embedding is on device and normalized + emb_vec = track_associator.track_embeddings[1] + assert isinstance(emb_vec, torch.Tensor) + assert emb_vec.device.type in ["cuda", "cpu"] + norm = torch.norm(emb_vec).item() + assert abs(norm - 1.0) < 0.01, "Embedding should be normalized" + + +@pytest.mark.gpu +def test_update_embedding_running_average(track_associator, mobileclip_model, test_image) -> None: + """Test running average of embeddings.""" + embedding1 = mobileclip_model.embed(test_image) + embedding2 = mobileclip_model.embed(test_image) + + # Add first embedding + track_associator.update_embedding(track_id=1, new_embedding=embedding1) + first_vec = track_associator.track_embeddings[1].clone() + + # Add second embedding (same image, should be very similar) + track_associator.update_embedding(track_id=1, new_embedding=embedding2) + avg_vec = track_associator.track_embeddings[1] + + assert track_associator.embedding_counts[1] == 2 + + # Average should still be normalized + norm = torch.norm(avg_vec).item() + assert abs(norm - 1.0) < 0.01, "Average embedding should be normalized" + + # Average should be similar to both originals (same image) + similarity1 = (first_vec @ avg_vec).item() + assert similarity1 > 0.99, "Average should be very similar to original" + + +@pytest.mark.gpu +def test_negative_constraints(track_associator) -> None: + """Test negative constraint recording.""" + # Simulate frame with 3 tracks + track_ids = [1, 2, 3] + track_associator.add_negative_constraints(track_ids) + + # Check that all pairs are recorded + assert 2 in track_associator.negative_pairs[1] + assert 3 in track_associator.negative_pairs[1] + assert 1 in track_associator.negative_pairs[2] + assert 3 in track_associator.negative_pairs[2] + assert 1 in track_associator.negative_pairs[3] + assert 2 in track_associator.negative_pairs[3] + + +@pytest.mark.gpu +def test_associate_new_track(track_associator, mobileclip_model, test_image) -> None: + """Test associating a new track creates new long_term_id.""" + embedding = mobileclip_model.embed(test_image) + track_associator.update_embedding(track_id=1, new_embedding=embedding) + + # First association should create new long_term_id + long_term_id = track_associator.associate(track_id=1) + + assert long_term_id == 0, "First track should get long_term_id=0" + assert track_associator.track_to_long_term[1] == 0 + assert track_associator.long_term_counter == 1 + + +@pytest.mark.gpu +def test_associate_similar_tracks(track_associator, mobileclip_model, test_image) -> None: + """Test associating similar tracks to same long_term_id.""" + # Create embeddings from same image (should be very similar) + embedding1 = mobileclip_model.embed(test_image) + embedding2 = mobileclip_model.embed(test_image) + + # Add first track + track_associator.update_embedding(track_id=1, new_embedding=embedding1) + long_term_id_1 = track_associator.associate(track_id=1) + + # Add second track with similar embedding + track_associator.update_embedding(track_id=2, new_embedding=embedding2) + long_term_id_2 = track_associator.associate(track_id=2) + + # Should get same long_term_id (similarity > 0.75) + assert long_term_id_1 == long_term_id_2, "Similar tracks should get same long_term_id" + assert track_associator.long_term_counter == 1, "Only one long_term_id should be created" + + +@pytest.mark.gpu +def test_associate_with_negative_constraint(track_associator, mobileclip_model, test_image) -> None: + """Test that negative constraints prevent association.""" + # Create similar embeddings + embedding1 = mobileclip_model.embed(test_image) + embedding2 = mobileclip_model.embed(test_image) + + # Add first track + track_associator.update_embedding(track_id=1, new_embedding=embedding1) + long_term_id_1 = track_associator.associate(track_id=1) + + # Add negative constraint (tracks co-occurred) + track_associator.add_negative_constraints([1, 2]) + + # Add second track with similar embedding + track_associator.update_embedding(track_id=2, new_embedding=embedding2) + long_term_id_2 = track_associator.associate(track_id=2) + + # Should get different long_term_ids despite high similarity + assert long_term_id_1 != long_term_id_2, ( + "Co-occurring tracks should get different long_term_ids" + ) + assert track_associator.long_term_counter == 2, "Two long_term_ids should be created" + + +@pytest.mark.gpu +def test_associate_different_objects(track_associator, mobileclip_model, test_image) -> None: + """Test that dissimilar embeddings get different long_term_ids.""" + # Create embeddings for image and text (very different) + image_emb = mobileclip_model.embed(test_image) + text_emb = mobileclip_model.embed_text("a dog") + + # Add first track (image) + track_associator.update_embedding(track_id=1, new_embedding=image_emb) + long_term_id_1 = track_associator.associate(track_id=1) + + # Add second track (text - very different embedding) + track_associator.update_embedding(track_id=2, new_embedding=text_emb) + long_term_id_2 = track_associator.associate(track_id=2) + + # Should get different long_term_ids (similarity < 0.75) + assert long_term_id_1 != long_term_id_2, "Different objects should get different long_term_ids" + assert track_associator.long_term_counter == 2 + + +@pytest.mark.gpu +def test_associate_returns_cached(track_associator, mobileclip_model, test_image) -> None: + """Test that repeated calls return same long_term_id.""" + embedding = mobileclip_model.embed(test_image) + track_associator.update_embedding(track_id=1, new_embedding=embedding) + + # First call + long_term_id_1 = track_associator.associate(track_id=1) + + # Second call should return cached result + long_term_id_2 = track_associator.associate(track_id=1) + + assert long_term_id_1 == long_term_id_2 + assert track_associator.long_term_counter == 1, "Should not create new ID" + + +@pytest.mark.gpu +def test_associate_not_ready(track_associator) -> None: + """Test that associate returns -1 for track without embedding.""" + long_term_id = track_associator.associate(track_id=999) + assert long_term_id == -1, "Should return -1 for track without embedding" + + +@pytest.mark.gpu +def test_gpu_performance(track_associator, mobileclip_model, test_image) -> None: + """Test that embeddings stay on GPU for performance.""" + embedding = mobileclip_model.embed(test_image) + track_associator.update_embedding(track_id=1, new_embedding=embedding) + + # Embedding should stay on device + emb_vec = track_associator.track_embeddings[1] + assert isinstance(emb_vec, torch.Tensor) + # Device comparison (handle "cuda" vs "cuda:0") + expected_device = mobileclip_model.device + assert emb_vec.device.type == torch.device(expected_device).type + + # Running average should happen on GPU + embedding2 = mobileclip_model.embed(test_image) + track_associator.update_embedding(track_id=1, new_embedding=embedding2) + + avg_vec = track_associator.track_embeddings[1] + assert avg_vec.device.type == torch.device(expected_device).type + + +@pytest.mark.gpu +def test_similarity_threshold_configurable(mobileclip_model) -> None: + """Test that similarity threshold is configurable.""" + associator_strict = EmbeddingIDSystem(model=lambda: mobileclip_model, similarity_threshold=0.95) + associator_loose = EmbeddingIDSystem(model=lambda: mobileclip_model, similarity_threshold=0.50) + + assert associator_strict.similarity_threshold == 0.95 + assert associator_loose.similarity_threshold == 0.50 + + +@pytest.mark.gpu +def test_multi_track_scenario(track_associator, mobileclip_model, test_image) -> None: + """Test realistic scenario with multiple tracks across frames.""" + # Frame 1: Track 1 appears + emb1 = mobileclip_model.embed(test_image) + track_associator.update_embedding(1, emb1) + track_associator.add_negative_constraints([1]) + lt1 = track_associator.associate(1) + + # Frame 2: Track 1 and Track 2 appear (different objects) + text_emb = mobileclip_model.embed_text("a dog") + track_associator.update_embedding(1, emb1) # Update average + track_associator.update_embedding(2, text_emb) + track_associator.add_negative_constraints([1, 2]) # Co-occur = different + lt2 = track_associator.associate(2) + + # Track 2 should get different ID despite any similarity + assert lt1 != lt2 + + # Frame 3: Track 1 disappears, Track 3 appears (same as Track 1) + emb3 = mobileclip_model.embed(test_image) + track_associator.update_embedding(3, emb3) + track_associator.add_negative_constraints([2, 3]) + lt3 = track_associator.associate(3) + + # Track 3 should match Track 1 (not co-occurring, similar embedding) + assert lt3 == lt1 + + print("\nMulti-track scenario results:") + print(f" Track 1 -> long_term_id {lt1}") + print(f" Track 2 -> long_term_id {lt2} (different object, co-occurred)") + print(f" Track 3 -> long_term_id {lt3} (re-identified as Track 1)") diff --git a/dimos/perception/detection/reid/test_module.py b/dimos/perception/detection/reid/test_module.py new file mode 100644 index 0000000000..ea53c17702 --- /dev/null +++ b/dimos/perception/detection/reid/test_module.py @@ -0,0 +1,44 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from dimos.core import LCMTransport +from dimos.msgs.foxglove_msgs import ImageAnnotations +from dimos.perception.detection.reid.embedding_id_system import EmbeddingIDSystem +from dimos.perception.detection.reid.module import ReidModule + + +@pytest.mark.tool +def test_reid_ingress(imageDetections2d) -> None: + try: + from dimos.models.embedding import TorchReIDModel + except Exception: + pytest.skip("TorchReIDModel not available") + + # Create TorchReID-based IDSystem for testing + reid_model = TorchReIDModel(model_name="osnet_x1_0") + reid_model.warmup() + idsystem = EmbeddingIDSystem( + model=lambda: reid_model, + padding=20, + similarity_threshold=0.75, + ) + + reid_module = ReidModule(idsystem=idsystem, warmup=False) + print("Processing detections through ReidModule...") + reid_module.annotations._transport = LCMTransport("/annotations", ImageAnnotations) + reid_module.ingress(imageDetections2d) + reid_module._close_module() + print("✓ ReidModule ingress test completed successfully") diff --git a/dimos/perception/detection/reid/type.py b/dimos/perception/detection/reid/type.py new file mode 100644 index 0000000000..28ea719f81 --- /dev/null +++ b/dimos/perception/detection/reid/type.py @@ -0,0 +1,50 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import ABC, abstractmethod + +from dimos.perception.detection.type import Detection2DBBox, ImageDetections2D + + +class IDSystem(ABC): + """Abstract base class for ID assignment systems.""" + + def register_detections(self, detections: ImageDetections2D) -> None: + """Register multiple detections.""" + for detection in detections.detections: + if isinstance(detection, Detection2DBBox): + self.register_detection(detection) + + @abstractmethod + def register_detection(self, detection: Detection2DBBox) -> int: + """ + Register a single detection, returning assigned (long term) ID. + + Args: + detection: Detection to register + + Returns: + Long-term unique ID for this detection + """ + ... + + +class PassthroughIDSystem(IDSystem): + """Simple ID system that returns track_id with no object permanence.""" + + def register_detection(self, detection: Detection2DBBox) -> int: + """Return detection's track_id as long-term ID (no permanence).""" + return detection.track_id diff --git a/dimos/perception/detection/test_moduleDB.py b/dimos/perception/detection/test_moduleDB.py new file mode 100644 index 0000000000..e9815f1f3e --- /dev/null +++ b/dimos/perception/detection/test_moduleDB.py @@ -0,0 +1,59 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import time + +from lcm_msgs.foxglove_msgs import SceneUpdate +import pytest + +from dimos.core import LCMTransport +from dimos.msgs.foxglove_msgs import ImageAnnotations +from dimos.msgs.geometry_msgs import PoseStamped +from dimos.msgs.sensor_msgs import Image, PointCloud2 +from dimos.msgs.vision_msgs import Detection2DArray +from dimos.perception.detection.moduleDB import ObjectDBModule +from dimos.robot.unitree.connection import go2 + + +@pytest.mark.module +def test_moduleDB(dimos_cluster) -> None: + connection = go2.deploy(dimos_cluster, "fake") + + moduleDB = dimos_cluster.deploy( + ObjectDBModule, + camera_info=go2._camera_info_static(), + goto=lambda obj_id: print(f"Going to {obj_id}"), + ) + moduleDB.image.connect(connection.video) + moduleDB.pointcloud.connect(connection.lidar) + + moduleDB.annotations.transport = LCMTransport("/annotations", ImageAnnotations) + moduleDB.detections.transport = LCMTransport("/detections", Detection2DArray) + + moduleDB.detected_pointcloud_0.transport = LCMTransport("/detected/pointcloud/0", PointCloud2) + moduleDB.detected_pointcloud_1.transport = LCMTransport("/detected/pointcloud/1", PointCloud2) + moduleDB.detected_pointcloud_2.transport = LCMTransport("/detected/pointcloud/2", PointCloud2) + + moduleDB.detected_image_0.transport = LCMTransport("/detected/image/0", Image) + moduleDB.detected_image_1.transport = LCMTransport("/detected/image/1", Image) + moduleDB.detected_image_2.transport = LCMTransport("/detected/image/2", Image) + + moduleDB.scene_update.transport = LCMTransport("/scene_update", SceneUpdate) + moduleDB.target.transport = LCMTransport("/target", PoseStamped) + + connection.start() + moduleDB.start() + + time.sleep(4) + print("VLM RES", moduleDB.navigate_to_object_in_view("white floor")) + time.sleep(30) diff --git a/dimos/perception/detection/type/__init__.py b/dimos/perception/detection/type/__init__.py new file mode 100644 index 0000000000..f34598c3a1 --- /dev/null +++ b/dimos/perception/detection/type/__init__.py @@ -0,0 +1,43 @@ +from dimos.perception.detection.type.detection2d import ( # type: ignore[attr-defined] + Detection2D, + Detection2DBBox, + Detection2DPerson, + Filter2D, + ImageDetections2D, +) +from dimos.perception.detection.type.detection3d import ( + Detection3D, + Detection3DBBox, + Detection3DPC, + ImageDetections3DPC, + PointCloudFilter, + height_filter, + radius_outlier, + raycast, + statistical, +) +from dimos.perception.detection.type.imageDetections import ImageDetections +from dimos.perception.detection.type.utils import TableStr + +__all__ = [ + # 2D Detection types + "Detection2D", + "Detection2DBBox", + "Detection2DPerson", + # 3D Detection types + "Detection3D", + "Detection3DBBox", + "Detection3DPC", + "Filter2D", + # Base types + "ImageDetections", + "ImageDetections2D", + "ImageDetections3DPC", + # Point cloud filters + "PointCloudFilter", + "TableStr", + "height_filter", + "radius_outlier", + "raycast", + "statistical", +] diff --git a/dimos/perception/detection/type/detection2d/__init__.py b/dimos/perception/detection/type/detection2d/__init__.py new file mode 100644 index 0000000000..a0e22546b0 --- /dev/null +++ b/dimos/perception/detection/type/detection2d/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.perception.detection.type.detection2d.base import Detection2D, Filter2D +from dimos.perception.detection.type.detection2d.bbox import Detection2DBBox +from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D +from dimos.perception.detection.type.detection2d.person import Detection2DPerson + +__all__ = [ + "Detection2D", + "Detection2DBBox", + "Detection2DPerson", + "ImageDetections2D", +] diff --git a/dimos/perception/detection/type/detection2d/base.py b/dimos/perception/detection/type/detection2d/base.py new file mode 100644 index 0000000000..c93ec06dcd --- /dev/null +++ b/dimos/perception/detection/type/detection2d/base.py @@ -0,0 +1,58 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import abstractmethod +from collections.abc import Callable + +from dimos_lcm.foxglove_msgs.ImageAnnotations import ( # type: ignore[import-untyped] + PointsAnnotation, + TextAnnotation, +) +from dimos_lcm.vision_msgs import Detection2D as ROSDetection2D # type: ignore[import-untyped] + +from dimos.msgs.foxglove_msgs import ImageAnnotations +from dimos.msgs.sensor_msgs import Image +from dimos.types.timestamped import Timestamped + + +class Detection2D(Timestamped): + """Abstract base class for 2D detections.""" + + @abstractmethod + def cropped_image(self, padding: int = 20) -> Image: + """Return a cropped version of the image focused on the detection area.""" + ... + + @abstractmethod + def to_image_annotations(self) -> ImageAnnotations: + """Convert detection to Foxglove ImageAnnotations for visualization.""" + ... + + @abstractmethod + def to_text_annotation(self) -> list[TextAnnotation]: + """Return text annotations for visualization.""" + ... + + @abstractmethod + def to_points_annotation(self) -> list[PointsAnnotation]: + """Return points/shape annotations for visualization.""" + ... + + @abstractmethod + def to_ros_detection2d(self) -> ROSDetection2D: + """Convert detection to ROS Detection2D message.""" + ... + + +Filter2D = Callable[[Detection2D], bool] diff --git a/dimos/perception/detection/type/detection2d/bbox.py b/dimos/perception/detection/type/detection2d/bbox.py new file mode 100644 index 0000000000..ff0832195f --- /dev/null +++ b/dimos/perception/detection/type/detection2d/bbox.py @@ -0,0 +1,406 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass +import hashlib +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from ultralytics.engine.results import Results + + from dimos.msgs.sensor_msgs import Image + +from dimos_lcm.foxglove_msgs.ImageAnnotations import ( # type: ignore[import-untyped] + PointsAnnotation, + TextAnnotation, +) +from dimos_lcm.foxglove_msgs.Point2 import Point2 # type: ignore[import-untyped] +from dimos_lcm.vision_msgs import ( # type: ignore[import-untyped] + BoundingBox2D, + Detection2D as ROSDetection2D, + ObjectHypothesis, + ObjectHypothesisWithPose, + Point2D, + Pose2D, +) +from rich.console import Console +from rich.text import Text + +from dimos.msgs.foxglove_msgs import ImageAnnotations +from dimos.msgs.foxglove_msgs.Color import Color +from dimos.msgs.std_msgs import Header +from dimos.perception.detection.type.detection2d.base import Detection2D +from dimos.types.timestamped import to_ros_stamp, to_timestamp +from dimos.utils.decorators.decorators import simple_mcache + +Bbox = tuple[float, float, float, float] +CenteredBbox = tuple[float, float, float, float] + + +def _hash_to_color(name: str) -> str: + """Generate a consistent color for a given name using hash.""" + # List of rich colors to choose from + colors = [ + "cyan", + "magenta", + "yellow", + "blue", + "green", + "red", + "bright_cyan", + "bright_magenta", + "bright_yellow", + "bright_blue", + "bright_green", + "bright_red", + "purple", + "white", + "pink", + ] + + # Hash the name and pick a color + hash_value = hashlib.md5(name.encode()).digest()[0] + return colors[hash_value % len(colors)] + + +@dataclass +class Detection2DBBox(Detection2D): + bbox: Bbox + track_id: int + class_id: int + confidence: float + name: str + ts: float + image: Image + + def to_repr_dict(self) -> dict[str, Any]: + """Return a dictionary representation of the detection for display purposes.""" + x1, y1, x2, y2 = self.bbox + return { + "name": self.name, + "class": str(self.class_id), + "track": str(self.track_id), + "conf": f"{self.confidence:.2f}", + "bbox": f"[{x1:.0f},{y1:.0f},{x2:.0f},{y2:.0f}]", + } + + def center_to_3d( + self, + pixel: tuple[int, int], + camera_info: CameraInfo, # type: ignore[name-defined] + assumed_depth: float = 1.0, + ) -> PoseStamped: # type: ignore[name-defined] + """Unproject 2D pixel coordinates to 3D position in camera optical frame. + + Args: + camera_info: Camera calibration information + assumed_depth: Assumed depth in meters (default 1.0m from camera) + + Returns: + Vector3 position in camera optical frame coordinates + """ + # Extract camera intrinsics + fx, fy = camera_info.K[0], camera_info.K[4] + cx, cy = camera_info.K[2], camera_info.K[5] + + # Unproject pixel to normalized camera coordinates + x_norm = (pixel[0] - cx) / fx + y_norm = (pixel[1] - cy) / fy + + # Create 3D point at assumed depth in camera optical frame + # Camera optical frame: X right, Y down, Z forward + return Vector3(x_norm * assumed_depth, y_norm * assumed_depth, assumed_depth) # type: ignore[name-defined] + + # return focused image, only on the bbox + def cropped_image(self, padding: int = 20) -> Image: + """Return a cropped version of the image focused on the bounding box. + + Args: + padding: Pixels to add around the bounding box (default: 20) + + Returns: + Cropped Image containing only the detection area plus padding + """ + x1, y1, x2, y2 = map(int, self.bbox) + return self.image.crop( + x1 - padding, y1 - padding, x2 - x1 + 2 * padding, y2 - y1 + 2 * padding + ) + + def __str__(self) -> str: + console = Console(force_terminal=True, legacy_windows=False) + d = self.to_repr_dict() + + # Build the string representation + parts = [ + Text(f"{self.__class__.__name__}("), + ] + + # Add any extra fields (e.g., points for Detection3D) + extra_keys = [k for k in d.keys() if k not in ["class"]] + for key in extra_keys: + if d[key] == "None": + parts.append(Text(f"{key}={d[key]}", style="dim")) + else: + parts.append(Text(f"{key}={d[key]}", style=_hash_to_color(key))) + + parts.append(Text(")")) + + # Render to string + with console.capture() as capture: + console.print(*parts, end="") + return capture.get().strip() + + @property + def center_bbox(self) -> tuple[float, float]: + """Get center point of bounding box.""" + x1, y1, x2, y2 = self.bbox + return ((x1 + x2) / 2, (y1 + y2) / 2) + + def bbox_2d_volume(self) -> float: + x1, y1, x2, y2 = self.bbox + width = max(0.0, x2 - x1) + height = max(0.0, y2 - y1) + return width * height + + @simple_mcache + def is_valid(self) -> bool: + """Check if detection bbox is valid. + + Validates that: + - Bounding box has positive dimensions + - Bounding box is within image bounds (if image has shape) + + Returns: + True if bbox is valid, False otherwise + """ + x1, y1, x2, y2 = self.bbox + + # Check positive dimensions + if x2 <= x1 or y2 <= y1: + return False + + # Check if within image bounds (if image has shape) + if self.image.shape: + h, w = self.image.shape[:2] + if not (0 <= x1 <= w and 0 <= y1 <= h and 0 <= x2 <= w and 0 <= y2 <= h): + return False + + return True + + @classmethod + def from_ultralytics_result(cls, result: Results, idx: int, image: Image) -> Detection2DBBox: + """Create Detection2DBBox from ultralytics Results object. + + Args: + result: Ultralytics Results object containing detection data + idx: Index of the detection in the results + image: Source image + + Returns: + Detection2DBBox instance + """ + if result.boxes is None: + raise ValueError("Result has no boxes") + + # Extract bounding box coordinates + bbox_array = result.boxes.xyxy[idx].cpu().numpy() + bbox: Bbox = ( + float(bbox_array[0]), + float(bbox_array[1]), + float(bbox_array[2]), + float(bbox_array[3]), + ) + + # Extract confidence + confidence = float(result.boxes.conf[idx].cpu()) + + # Extract class ID and name + class_id = int(result.boxes.cls[idx].cpu()) + name = ( + result.names.get(class_id, f"class_{class_id}") + if hasattr(result, "names") + else f"class_{class_id}" + ) + + # Extract track ID if available + track_id = -1 + if hasattr(result.boxes, "id") and result.boxes.id is not None: + track_id = int(result.boxes.id[idx].cpu()) + + return cls( + bbox=bbox, + track_id=track_id, + class_id=class_id, + confidence=confidence, + name=name, + ts=image.ts, + image=image, + ) + + def get_bbox_center(self) -> CenteredBbox: + x1, y1, x2, y2 = self.bbox + center_x = (x1 + x2) / 2.0 + center_y = (y1 + y2) / 2.0 + width = float(x2 - x1) + height = float(y2 - y1) + return (center_x, center_y, width, height) + + def to_ros_bbox(self) -> BoundingBox2D: + center_x, center_y, width, height = self.get_bbox_center() + return BoundingBox2D( + center=Pose2D( + position=Point2D(x=center_x, y=center_y), + theta=0.0, + ), + size_x=width, + size_y=height, + ) + + def lcm_encode(self): # type: ignore[no-untyped-def] + return self.to_image_annotations().lcm_encode() + + def to_text_annotation(self) -> list[TextAnnotation]: + x1, y1, _x2, y2 = self.bbox + + font_size = self.image.width / 80 + + # Build label text - exclude class_id if it's -1 (VLM detection) + if self.class_id == -1: + label_text = f"{self.name}_{self.track_id}" + else: + label_text = f"{self.name}_{self.class_id}_{self.track_id}" + + annotations = [ + TextAnnotation( + timestamp=to_ros_stamp(self.ts), + position=Point2(x=x1, y=y1), + text=label_text, + font_size=font_size, + text_color=Color(r=1.0, g=1.0, b=1.0, a=1), + background_color=Color(r=0, g=0, b=0, a=1), + ), + ] + + # Only show confidence if it's not 1.0 + if self.confidence != 1.0: + annotations.append( + TextAnnotation( + timestamp=to_ros_stamp(self.ts), + position=Point2(x=x1, y=y2 + font_size), + text=f"confidence: {self.confidence:.3f}", + font_size=font_size, + text_color=Color(r=1.0, g=1.0, b=1.0, a=1), + background_color=Color(r=0, g=0, b=0, a=1), + ) + ) + + return annotations + + def to_points_annotation(self) -> list[PointsAnnotation]: + x1, y1, x2, y2 = self.bbox + + thickness = 1 + + # Use consistent color based on object name, brighter for outline + outline_color = Color.from_string(self.name, alpha=1.0, brightness=1.25) + + return [ + PointsAnnotation( + timestamp=to_ros_stamp(self.ts), + outline_color=outline_color, + fill_color=Color.from_string(self.name, alpha=0.2), + thickness=thickness, + points_length=4, + points=[ + Point2(x1, y1), + Point2(x1, y2), + Point2(x2, y2), + Point2(x2, y1), + ], + type=PointsAnnotation.LINE_LOOP, + ) + ] + + # this is almost never called directly since this is a single detection + # and ImageAnnotations message normally contains multiple detections annotations + # so ImageDetections2D and ImageDetections3D normally implements this for whole image + def to_image_annotations(self) -> ImageAnnotations: + points = self.to_points_annotation() + texts = self.to_text_annotation() + + return ImageAnnotations( + texts=texts, + texts_length=len(texts), + points=points, + points_length=len(points), + ) + + @classmethod + def from_ros_detection2d(cls, ros_det: ROSDetection2D, **kwargs) -> Detection2D: # type: ignore[no-untyped-def] + """Convert from ROS Detection2D message to Detection2D object.""" + # Extract bbox from ROS format + center_x = ros_det.bbox.center.position.x + center_y = ros_det.bbox.center.position.y + width = ros_det.bbox.size_x + height = ros_det.bbox.size_y + + # Convert centered bbox to corner format + x1 = center_x - width / 2.0 + y1 = center_y - height / 2.0 + x2 = center_x + width / 2.0 + y2 = center_y + height / 2.0 + bbox = (x1, y1, x2, y2) + + # Extract hypothesis info + class_id = 0 + confidence = 0.0 + if ros_det.results: + hypothesis = ros_det.results[0].hypothesis + class_id = hypothesis.class_id + confidence = hypothesis.score + + # Extract track_id + track_id = int(ros_det.id) if ros_det.id.isdigit() else 0 + + # Extract timestamp + ts = to_timestamp(ros_det.header.stamp) + + name = kwargs.pop("name", f"class_{class_id}") + + return cls( + bbox=bbox, + track_id=track_id, + class_id=class_id, + confidence=confidence, + name=name, + ts=ts, + **kwargs, + ) + + def to_ros_detection2d(self) -> ROSDetection2D: + return ROSDetection2D( + header=Header(self.ts, "camera_link"), + bbox=self.to_ros_bbox(), + results=[ + ObjectHypothesisWithPose( + ObjectHypothesis( + class_id=self.class_id, + score=self.confidence, + ) + ) + ], + id=str(self.track_id), + ) diff --git a/dimos/perception/detection/type/detection2d/imageDetections2D.py b/dimos/perception/detection/type/detection2d/imageDetections2D.py new file mode 100644 index 0000000000..64d8dc8f4d --- /dev/null +++ b/dimos/perception/detection/type/detection2d/imageDetections2D.py @@ -0,0 +1,81 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from dimos.perception.detection.type.detection2d.base import Detection2D +from dimos.perception.detection.type.detection2d.bbox import Detection2DBBox +from dimos.perception.detection.type.imageDetections import ImageDetections + +if TYPE_CHECKING: + from dimos_lcm.vision_msgs import Detection2DArray # type: ignore[import-untyped] + from ultralytics.engine.results import Results + + from dimos.msgs.sensor_msgs import Image + + +class ImageDetections2D(ImageDetections[Detection2D]): + @classmethod + def from_ros_detection2d_array( # type: ignore[no-untyped-def] + cls, image: Image, ros_detections: Detection2DArray, **kwargs + ) -> ImageDetections2D: + """Convert from ROS Detection2DArray message to ImageDetections2D object.""" + detections: list[Detection2D] = [] + for ros_det in ros_detections.detections: + detection = Detection2DBBox.from_ros_detection2d(ros_det, image=image, **kwargs) + if detection.is_valid(): # type: ignore[attr-defined] + detections.append(detection) + + return cls(image=image, detections=detections) + + @classmethod + def from_ultralytics_result( # type: ignore[no-untyped-def] + cls, image: Image, results: list[Results], **kwargs + ) -> ImageDetections2D: + """Create ImageDetections2D from ultralytics Results. + + Dispatches to appropriate Detection2D subclass based on result type: + - If keypoints present: creates Detection2DPerson + - Otherwise: creates Detection2DBBox + + Args: + image: Source image + results: List of ultralytics Results objects + **kwargs: Additional arguments passed to detection constructors + + Returns: + ImageDetections2D containing appropriate detection types + """ + from dimos.perception.detection.type.detection2d.person import Detection2DPerson + + detections: list[Detection2D] = [] + for result in results: + if result.boxes is None: + continue + + num_detections = len(result.boxes.xyxy) + for i in range(num_detections): + detection: Detection2D + if result.keypoints is not None: + # Pose detection with keypoints + detection = Detection2DPerson.from_ultralytics_result(result, i, image) + else: + # Regular bbox detection + detection = Detection2DBBox.from_ultralytics_result(result, i, image) + if detection.is_valid(): + detections.append(detection) + + return cls(image=image, detections=detections) diff --git a/dimos/perception/detection/type/detection2d/person.py b/dimos/perception/detection/type/detection2d/person.py new file mode 100644 index 0000000000..396ce9d30f --- /dev/null +++ b/dimos/perception/detection/type/detection2d/person.py @@ -0,0 +1,345 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +# Import for type checking only to avoid circular imports +from typing import TYPE_CHECKING + +from dimos_lcm.foxglove_msgs.ImageAnnotations import ( # type: ignore[import-untyped] + PointsAnnotation, + TextAnnotation, +) +from dimos_lcm.foxglove_msgs.Point2 import Point2 # type: ignore[import-untyped] +import numpy as np + +from dimos.msgs.foxglove_msgs.Color import Color +from dimos.msgs.sensor_msgs import Image +from dimos.perception.detection.type.detection2d.bbox import Bbox, Detection2DBBox +from dimos.types.timestamped import to_ros_stamp +from dimos.utils.decorators.decorators import simple_mcache + +if TYPE_CHECKING: + from ultralytics.engine.results import Results + + +@dataclass +class Detection2DPerson(Detection2DBBox): + """Represents a detected person with pose keypoints.""" + + # Pose keypoints - additional fields beyond Detection2DBBox + keypoints: np.ndarray # type: ignore[type-arg] # [17, 2] - x,y coordinates + keypoint_scores: np.ndarray # type: ignore[type-arg] # [17] - confidence scores + + # Optional normalized coordinates + bbox_normalized: np.ndarray | None = None # type: ignore[type-arg] # [x1, y1, x2, y2] in 0-1 range + keypoints_normalized: np.ndarray | None = None # type: ignore[type-arg] # [17, 2] in 0-1 range + + # Image dimensions for context + image_width: int | None = None + image_height: int | None = None + + # Keypoint names (class attribute) + KEYPOINT_NAMES = [ + "nose", + "left_eye", + "right_eye", + "left_ear", + "right_ear", + "left_shoulder", + "right_shoulder", + "left_elbow", + "right_elbow", + "left_wrist", + "right_wrist", + "left_hip", + "right_hip", + "left_knee", + "right_knee", + "left_ankle", + "right_ankle", + ] + + @classmethod + def from_ultralytics_result( + cls, result: "Results", idx: int, image: Image + ) -> "Detection2DPerson": + """Create Detection2DPerson from ultralytics Results object with pose keypoints. + + Args: + result: Ultralytics Results object containing detection and keypoint data + idx: Index of the detection in the results + image: Source image + + Returns: + Detection2DPerson instance + + Raises: + ValueError: If the result doesn't contain keypoints or is not a person detection + """ + # Validate that this is a pose detection result + if not hasattr(result, "keypoints") or result.keypoints is None: + raise ValueError( + "Cannot create Detection2DPerson from result without keypoints. " + "This appears to be a regular detection result, not a pose detection. " + "Use Detection2DBBox.from_ultralytics_result() instead." + ) + + if not hasattr(result, "boxes") or result.boxes is None: + raise ValueError("Cannot create Detection2DPerson from result without bounding boxes") + + # Check if this is actually a person detection (class 0 in COCO) + class_id = int(result.boxes.cls[idx].cpu()) + if class_id != 0: # Person is class 0 in COCO + class_name = ( + result.names.get(class_id, f"class_{class_id}") + if hasattr(result, "names") + else f"class_{class_id}" + ) + raise ValueError( + f"Cannot create Detection2DPerson from non-person detection. " + f"Got class {class_id} ({class_name}), expected class 0 (person)." + ) + + # Extract bounding box as tuple for Detection2DBBox + bbox_array = result.boxes.xyxy[idx].cpu().numpy() + + bbox: Bbox = ( + float(bbox_array[0]), + float(bbox_array[1]), + float(bbox_array[2]), + float(bbox_array[3]), + ) + + bbox_norm = ( + result.boxes.xyxyn[idx].cpu().numpy() if hasattr(result.boxes, "xyxyn") else None + ) + + confidence = float(result.boxes.conf[idx].cpu()) + class_id = int(result.boxes.cls[idx].cpu()) + + # Extract keypoints + if result.keypoints.xy is None or result.keypoints.conf is None: + raise ValueError("Keypoints xy or conf data is missing from the result") + + keypoints = result.keypoints.xy[idx].cpu().numpy() + keypoint_scores = result.keypoints.conf[idx].cpu().numpy() + keypoints_norm = ( + result.keypoints.xyn[idx].cpu().numpy() + if hasattr(result.keypoints, "xyn") and result.keypoints.xyn is not None + else None + ) + + # Get image dimensions + height, width = result.orig_shape + + # Extract track ID if available + track_id = idx # Use index as default + if hasattr(result.boxes, "id") and result.boxes.id is not None: + track_id = int(result.boxes.id[idx].cpu()) + + # Get class name + name = result.names.get(class_id, "person") if hasattr(result, "names") else "person" + + return cls( + # Detection2DBBox fields + bbox=bbox, + track_id=track_id, + class_id=class_id, + confidence=confidence, + name=name, + ts=image.ts, + image=image, + # Person specific fields + keypoints=keypoints, + keypoint_scores=keypoint_scores, + bbox_normalized=bbox_norm, + keypoints_normalized=keypoints_norm, + image_width=width, + image_height=height, + ) + + @classmethod + def from_yolo(cls, result: "Results", idx: int, image: Image) -> "Detection2DPerson": + """Alias for from_ultralytics_result for backward compatibility.""" + return cls.from_ultralytics_result(result, idx, image) + + @classmethod + def from_ros_detection2d(cls, *args, **kwargs) -> "Detection2DPerson": # type: ignore[no-untyped-def] + """Conversion from ROS Detection2D is not supported for Detection2DPerson. + + The ROS Detection2D message format does not include keypoint data, + which is required for Detection2DPerson. Use Detection2DBBox for + round-trip ROS conversions, or store keypoints separately. + + Raises: + NotImplementedError: Always raised as this conversion is impossible + """ + raise NotImplementedError( + "Cannot convert from ROS Detection2D to Detection2DPerson. " + "The ROS Detection2D message format does not contain keypoint data " + "(keypoints and keypoint_scores) which are required fields for Detection2DPerson. " + "Consider using Detection2DBBox for ROS conversions, or implement a custom " + "message format that includes pose keypoints." + ) + + def get_keypoint(self, name: str) -> tuple[np.ndarray, float]: # type: ignore[type-arg] + """Get specific keypoint by name. + Returns: + Tuple of (xy_coordinates, confidence_score) + """ + if name not in self.KEYPOINT_NAMES: + raise ValueError(f"Invalid keypoint name: {name}. Must be one of {self.KEYPOINT_NAMES}") + + idx = self.KEYPOINT_NAMES.index(name) + return self.keypoints[idx], self.keypoint_scores[idx] + + def get_visible_keypoints(self, threshold: float = 0.5) -> list[tuple[str, np.ndarray, float]]: # type: ignore[type-arg] + """Get all keypoints above confidence threshold. + Returns: + List of tuples: (keypoint_name, xy_coordinates, confidence) + """ + visible = [] + for i, (name, score) in enumerate( + zip(self.KEYPOINT_NAMES, self.keypoint_scores, strict=False) + ): + if score > threshold: + visible.append((name, self.keypoints[i], score)) + return visible + + @simple_mcache + def is_valid(self) -> bool: + valid_keypoints = sum(1 for score in self.keypoint_scores if score > 0.8) + return valid_keypoints >= 5 + + @property + def width(self) -> float: + """Get width of bounding box.""" + x1, _, x2, _ = self.bbox + return x2 - x1 + + @property + def height(self) -> float: + """Get height of bounding box.""" + _, y1, _, y2 = self.bbox + return y2 - y1 + + @property + def center(self) -> tuple[float, float]: + """Get center point of bounding box.""" + x1, y1, x2, y2 = self.bbox + return ((x1 + x2) / 2, (y1 + y2) / 2) + + def to_points_annotation(self) -> list[PointsAnnotation]: + """Override to include keypoint visualizations along with bounding box.""" + annotations = [] + + # First add the bounding box from parent class + annotations.extend(super().to_points_annotation()) + + # Add keypoints as circles + visible_keypoints = self.get_visible_keypoints(threshold=0.3) + + # Create points for visible keypoints + if visible_keypoints: + keypoint_points = [] + for _name, xy, _conf in visible_keypoints: + keypoint_points.append(Point2(float(xy[0]), float(xy[1]))) + + # Add keypoints as circles + annotations.append( + PointsAnnotation( + timestamp=to_ros_stamp(self.ts), + outline_color=Color(r=0.0, g=1.0, b=0.0, a=1.0), # Green outline + fill_color=Color(r=0.0, g=1.0, b=0.0, a=0.5), # Semi-transparent green + thickness=2.0, + points_length=len(keypoint_points), + points=keypoint_points, + type=PointsAnnotation.POINTS, # Draw as individual points/circles + ) + ) + + # Add skeleton connections (COCO skeleton) + skeleton_connections = [ + # Face + (0, 1), + (0, 2), + (1, 3), + (2, 4), # nose to eyes, eyes to ears + # Arms + (5, 6), # shoulders + (5, 7), + (7, 9), # left arm + (6, 8), + (8, 10), # right arm + # Torso + (5, 11), + (6, 12), + (11, 12), # shoulders to hips, hip to hip + # Legs + (11, 13), + (13, 15), # left leg + (12, 14), + (14, 16), # right leg + ] + + # Draw skeleton lines between connected keypoints + for start_idx, end_idx in skeleton_connections: + if ( + start_idx < len(self.keypoint_scores) + and end_idx < len(self.keypoint_scores) + and self.keypoint_scores[start_idx] > 0.3 + and self.keypoint_scores[end_idx] > 0.3 + ): + start_point = Point2( + float(self.keypoints[start_idx][0]), float(self.keypoints[start_idx][1]) + ) + end_point = Point2( + float(self.keypoints[end_idx][0]), float(self.keypoints[end_idx][1]) + ) + + annotations.append( + PointsAnnotation( + timestamp=to_ros_stamp(self.ts), + outline_color=Color(r=0.0, g=0.8, b=1.0, a=0.8), # Cyan + thickness=1.5, + points_length=2, + points=[start_point, end_point], + type=PointsAnnotation.LINE_LIST, + ) + ) + + return annotations + + def to_text_annotation(self) -> list[TextAnnotation]: + """Override to include pose information in text annotations.""" + # Get base annotations from parent + annotations = super().to_text_annotation() + + # Add pose-specific info + visible_count = len(self.get_visible_keypoints(threshold=0.5)) + x1, _y1, _x2, y2 = self.bbox + + annotations.append( + TextAnnotation( + timestamp=to_ros_stamp(self.ts), + position=Point2(x=x1, y=y2 + 40), # Below confidence text + text=f"keypoints: {visible_count}/17", + font_size=18, + text_color=Color(r=0.0, g=1.0, b=0.0, a=1), + background_color=Color(r=0, g=0, b=0, a=0.7), + ) + ) + + return annotations diff --git a/dimos/perception/detection/type/detection2d/test_bbox.py b/dimos/perception/detection/type/detection2d/test_bbox.py new file mode 100644 index 0000000000..5a76b41601 --- /dev/null +++ b/dimos/perception/detection/type/detection2d/test_bbox.py @@ -0,0 +1,87 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + + +def test_detection2d(detection2d) -> None: + # def test_detection_basic_properties(detection2d): + """Test basic detection properties.""" + assert detection2d.track_id >= 0 + assert detection2d.class_id >= 0 + assert 0.0 <= detection2d.confidence <= 1.0 + assert detection2d.name is not None + assert detection2d.ts > 0 + + # def test_bounding_box_format(detection2d): + """Test bounding box format and validity.""" + bbox = detection2d.bbox + assert len(bbox) == 4, "Bounding box should have 4 values" + + x1, y1, x2, y2 = bbox + assert x2 > x1, "x2 should be greater than x1" + assert y2 > y1, "y2 should be greater than y1" + assert x1 >= 0, "x1 should be non-negative" + assert y1 >= 0, "y1 should be non-negative" + + # def test_bbox_2d_volume(detection2d): + """Test bounding box volume calculation.""" + volume = detection2d.bbox_2d_volume() + assert volume > 0, "Bounding box volume should be positive" + + # Calculate expected volume + x1, y1, x2, y2 = detection2d.bbox + expected_volume = (x2 - x1) * (y2 - y1) + assert volume == pytest.approx(expected_volume, abs=0.001) + + # def test_bbox_center_calculation(detection2d): + """Test bounding box center calculation.""" + center_bbox = detection2d.get_bbox_center() + assert len(center_bbox) == 4, "Center bbox should have 4 values" + + center_x, center_y, width, height = center_bbox + x1, y1, x2, y2 = detection2d.bbox + + # Verify center calculations + assert center_x == pytest.approx((x1 + x2) / 2.0, abs=0.001) + assert center_y == pytest.approx((y1 + y2) / 2.0, abs=0.001) + assert width == pytest.approx(x2 - x1, abs=0.001) + assert height == pytest.approx(y2 - y1, abs=0.001) + + # def test_cropped_image(detection2d): + """Test cropped image generation.""" + padding = 20 + cropped = detection2d.cropped_image(padding=padding) + + assert cropped is not None, "Cropped image should not be None" + + # The actual cropped image is (260, 192, 3) + assert cropped.width == 192 + assert cropped.height == 260 + assert cropped.shape == (260, 192, 3) + + # def test_to_ros_bbox(detection2d): + """Test ROS bounding box conversion.""" + ros_bbox = detection2d.to_ros_bbox() + + assert ros_bbox is not None + assert hasattr(ros_bbox, "center") + assert hasattr(ros_bbox, "size_x") + assert hasattr(ros_bbox, "size_y") + + # Verify values match + center_x, center_y, width, height = detection2d.get_bbox_center() + assert ros_bbox.center.position.x == pytest.approx(center_x, abs=0.001) + assert ros_bbox.center.position.y == pytest.approx(center_y, abs=0.001) + assert ros_bbox.size_x == pytest.approx(width, abs=0.001) + assert ros_bbox.size_y == pytest.approx(height, abs=0.001) diff --git a/dimos/perception/detection/type/detection2d/test_imageDetections2D.py b/dimos/perception/detection/type/detection2d/test_imageDetections2D.py new file mode 100644 index 0000000000..83487d2c25 --- /dev/null +++ b/dimos/perception/detection/type/detection2d/test_imageDetections2D.py @@ -0,0 +1,52 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from dimos.perception.detection.type import ImageDetections2D + + +def test_from_ros_detection2d_array(get_moment_2d) -> None: + moment = get_moment_2d() + + detections2d = moment["detections2d"] + + test_image = detections2d.image + + # Convert to ROS detection array + ros_array = detections2d.to_ros_detection2d_array() + + # Convert back to ImageDetections2D + recovered = ImageDetections2D.from_ros_detection2d_array(test_image, ros_array) + + # Verify we got the same number of detections + assert len(recovered.detections) == len(detections2d.detections) + + # Verify the detection matches + original_det = detections2d.detections[0] + recovered_det = recovered.detections[0] + + # Check bbox is approximately the same (allow 1 pixel tolerance due to float conversion) + for orig_val, rec_val in zip(original_det.bbox, recovered_det.bbox, strict=False): + assert orig_val == pytest.approx(rec_val, abs=1.0) + + # Check other properties + assert recovered_det.track_id == original_det.track_id + assert recovered_det.class_id == original_det.class_id + assert recovered_det.confidence == pytest.approx(original_det.confidence, abs=0.01) + + print("\nSuccessfully round-tripped detection through ROS format:") + print(f" Original bbox: {original_det.bbox}") + print(f" Recovered bbox: {recovered_det.bbox}") + print(f" Track ID: {recovered_det.track_id}") + print(f" Confidence: {recovered_det.confidence:.3f}") diff --git a/dimos/perception/detection/type/detection2d/test_person.py b/dimos/perception/detection/type/detection2d/test_person.py new file mode 100644 index 0000000000..06c5883ae2 --- /dev/null +++ b/dimos/perception/detection/type/detection2d/test_person.py @@ -0,0 +1,71 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + + +def test_person_ros_confidence() -> None: + """Test that Detection2DPerson preserves confidence when converting to ROS format.""" + + from dimos.msgs.sensor_msgs import Image + from dimos.perception.detection.detectors.person.yolo import YoloPersonDetector + from dimos.perception.detection.type.detection2d.person import Detection2DPerson + from dimos.utils.data import get_data + + # Load test image + image_path = get_data("cafe.jpg") + image = Image.from_file(image_path) + + # Run pose detection + detector = YoloPersonDetector(device="cpu") + detections = detector.process_image(image) + + # Find a Detection2DPerson (should have at least one person in cafe.jpg) + person_detections = [d for d in detections.detections if isinstance(d, Detection2DPerson)] + assert len(person_detections) > 0, "No person detections found in cafe.jpg" + + # Test each person detection + for person_det in person_detections: + original_confidence = person_det.confidence + assert 0.0 <= original_confidence <= 1.0, "Confidence should be between 0 and 1" + + # Convert to ROS format + ros_det = person_det.to_ros_detection2d() + + # Extract confidence from ROS message + assert len(ros_det.results) > 0, "ROS detection should have results" + ros_confidence = ros_det.results[0].hypothesis.score + + # Verify confidence is preserved (allow small floating point tolerance) + assert original_confidence == pytest.approx(ros_confidence, abs=0.001), ( + f"Confidence mismatch: {original_confidence} != {ros_confidence}" + ) + + print("\nSuccessfully preserved confidence in ROS conversion for Detection2DPerson:") + print(f" Original confidence: {original_confidence:.3f}") + print(f" ROS confidence: {ros_confidence:.3f}") + print(f" Track ID: {person_det.track_id}") + print(f" Visible keypoints: {len(person_det.get_visible_keypoints(threshold=0.3))}/17") + + +def test_person_from_ros_raises() -> None: + """Test that Detection2DPerson.from_ros_detection2d() raises NotImplementedError.""" + from dimos.perception.detection.type.detection2d.person import Detection2DPerson + + with pytest.raises(NotImplementedError) as exc_info: + Detection2DPerson.from_ros_detection2d() + + # Verify the error message is informative + error_msg = str(exc_info.value) + assert "keypoint data" in error_msg.lower() + assert "Detection2DBBox" in error_msg diff --git a/dimos/perception/detection/type/detection3d/__init__.py b/dimos/perception/detection/type/detection3d/__init__.py new file mode 100644 index 0000000000..0e765b175f --- /dev/null +++ b/dimos/perception/detection/type/detection3d/__init__.py @@ -0,0 +1,37 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.perception.detection.type.detection3d.base import Detection3D +from dimos.perception.detection.type.detection3d.bbox import Detection3DBBox +from dimos.perception.detection.type.detection3d.imageDetections3DPC import ImageDetections3DPC +from dimos.perception.detection.type.detection3d.pointcloud import Detection3DPC +from dimos.perception.detection.type.detection3d.pointcloud_filters import ( + PointCloudFilter, + height_filter, + radius_outlier, + raycast, + statistical, +) + +__all__ = [ + "Detection3D", + "Detection3DBBox", + "Detection3DPC", + "ImageDetections3DPC", + "PointCloudFilter", + "height_filter", + "radius_outlier", + "raycast", + "statistical", +] diff --git a/dimos/perception/detection/type/detection3d/base.py b/dimos/perception/detection/type/detection3d/base.py new file mode 100644 index 0000000000..1dffe0d551 --- /dev/null +++ b/dimos/perception/detection/type/detection3d/base.py @@ -0,0 +1,46 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from dimos.perception.detection.type.detection2d import Detection2DBBox + +if TYPE_CHECKING: + from dimos_lcm.sensor_msgs import CameraInfo # type: ignore[import-untyped] + + from dimos.msgs.geometry_msgs import Transform + + +@dataclass +class Detection3D(Detection2DBBox): + """Abstract base class for 3D detections.""" + + transform: Transform + frame_id: str + + @classmethod + @abstractmethod + def from_2d( + cls, + det: Detection2DBBox, + distance: float, + camera_info: CameraInfo, + world_to_optical_transform: Transform, + ) -> Detection3D | None: + """Create a 3D detection from a 2D detection.""" + ... diff --git a/dimos/perception/detection/type/detection3d/bbox.py b/dimos/perception/detection/type/detection3d/bbox.py new file mode 100644 index 0000000000..ac6f82a25e --- /dev/null +++ b/dimos/perception/detection/type/detection3d/bbox.py @@ -0,0 +1,64 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass +import functools +from typing import Any + +from dimos.msgs.geometry_msgs import PoseStamped, Transform, Vector3 +from dimos.perception.detection.type.detection2d import Detection2DBBox + + +@dataclass +class Detection3DBBox(Detection2DBBox): + """3D bounding box detection with center, size, and orientation. + + Represents a 3D detection as an oriented bounding box in world space. + """ + + transform: Transform # Camera to world transform + frame_id: str # Frame ID (e.g., "world", "map") + center: Vector3 # Center point in world frame + size: Vector3 # Width, height, depth + orientation: tuple[float, float, float, float] # Quaternion (x, y, z, w) + + @functools.cached_property + def pose(self) -> PoseStamped: + """Convert detection to a PoseStamped using bounding box center. + + Returns pose in world frame with the detection's orientation. + """ + return PoseStamped( + ts=self.ts, + frame_id=self.frame_id, + position=self.center, + orientation=self.orientation, + ) + + def to_repr_dict(self) -> dict[str, Any]: + # Calculate distance from camera + camera_pos = self.transform.translation + distance = (self.center - camera_pos).magnitude() + + parent_dict = super().to_repr_dict() + # Remove bbox key if present + parent_dict.pop("bbox", None) + + return { + **parent_dict, + "dist": f"{distance:.2f}m", + "size": f"[{self.size.x:.2f},{self.size.y:.2f},{self.size.z:.2f}]", + } diff --git a/dimos/perception/detection/type/detection3d/imageDetections3DPC.py b/dimos/perception/detection/type/detection3d/imageDetections3DPC.py new file mode 100644 index 0000000000..0fbb1a7c59 --- /dev/null +++ b/dimos/perception/detection/type/detection3d/imageDetections3DPC.py @@ -0,0 +1,45 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from lcm_msgs.foxglove_msgs import SceneUpdate # type: ignore[import-not-found] + +from dimos.perception.detection.type.detection3d.pointcloud import Detection3DPC +from dimos.perception.detection.type.imageDetections import ImageDetections + + +class ImageDetections3DPC(ImageDetections[Detection3DPC]): + """Specialized class for 3D detections in an image.""" + + def to_foxglove_scene_update(self) -> SceneUpdate: + """Convert all detections to a Foxglove SceneUpdate message. + + Returns: + SceneUpdate containing SceneEntity objects for all detections + """ + + # Create SceneUpdate message with all detections + scene_update = SceneUpdate() + scene_update.deletions_length = 0 + scene_update.deletions = [] + scene_update.entities = [] + + # Process each detection + for i, detection in enumerate(self.detections): + entity = detection.to_foxglove_scene_entity(entity_id=f"detection_{detection.name}_{i}") + scene_update.entities.append(entity) + + scene_update.entities_length = len(scene_update.entities) + return scene_update diff --git a/dimos/perception/detection/type/detection3d/pointcloud.py b/dimos/perception/detection/type/detection3d/pointcloud.py new file mode 100644 index 0000000000..bff9a7f581 --- /dev/null +++ b/dimos/perception/detection/type/detection3d/pointcloud.py @@ -0,0 +1,336 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass +import functools +from typing import TYPE_CHECKING, Any + +from lcm_msgs.builtin_interfaces import Duration # type: ignore[import-not-found] +from lcm_msgs.foxglove_msgs import ( # type: ignore[import-not-found] + CubePrimitive, + SceneEntity, + TextPrimitive, +) +from lcm_msgs.geometry_msgs import ( # type: ignore[import-not-found] + Point, + Pose, + Quaternion, + Vector3 as LCMVector3, +) +import numpy as np + +from dimos.msgs.foxglove_msgs.Color import Color +from dimos.msgs.geometry_msgs import PoseStamped, Transform, Vector3 +from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.perception.detection.type.detection3d.base import Detection3D +from dimos.perception.detection.type.detection3d.pointcloud_filters import ( + PointCloudFilter, + radius_outlier, + raycast, + statistical, +) +from dimos.types.timestamped import to_ros_stamp + +if TYPE_CHECKING: + from dimos_lcm.sensor_msgs import CameraInfo # type: ignore[import-untyped] + + from dimos.perception.detection.type.detection2d import Detection2DBBox + + +@dataclass +class Detection3DPC(Detection3D): + pointcloud: PointCloud2 + + @functools.cached_property + def center(self) -> Vector3: + return Vector3(*self.pointcloud.center) + + @functools.cached_property + def pose(self) -> PoseStamped: + """Convert detection to a PoseStamped using pointcloud center. + + Returns pose in world frame with identity rotation. + The pointcloud is already in world frame. + """ + return PoseStamped( + ts=self.ts, + frame_id=self.frame_id, + position=self.center, + orientation=(0.0, 0.0, 0.0, 1.0), # Identity quaternion + ) + + def get_bounding_box(self): # type: ignore[no-untyped-def] + """Get axis-aligned bounding box of the detection's pointcloud.""" + return self.pointcloud.get_axis_aligned_bounding_box() + + def get_oriented_bounding_box(self): # type: ignore[no-untyped-def] + """Get oriented bounding box of the detection's pointcloud.""" + return self.pointcloud.get_oriented_bounding_box() + + def get_bounding_box_dimensions(self) -> tuple[float, float, float]: + """Get dimensions (width, height, depth) of the detection's bounding box.""" + return self.pointcloud.get_bounding_box_dimensions() + + def bounding_box_intersects(self, other: Detection3DPC) -> bool: + """Check if this detection's bounding box intersects with another's.""" + return self.pointcloud.bounding_box_intersects(other.pointcloud) + + def to_repr_dict(self) -> dict[str, Any]: + # Calculate distance from camera + # The pointcloud is in world frame, and transform gives camera position in world + center_world = self.center + # Camera position in world frame is the translation part of the transform + camera_pos = self.transform.translation + # Use Vector3 subtraction and magnitude + distance = (center_world - camera_pos).magnitude() + + parent_dict = super().to_repr_dict() + # Remove bbox key if present + parent_dict.pop("bbox", None) + + return { + **parent_dict, + "dist": f"{distance:.2f}m", + "points": str(len(self.pointcloud)), + } + + def to_foxglove_scene_entity(self, entity_id: str | None = None) -> SceneEntity: + """Convert detection to a Foxglove SceneEntity with cube primitive and text label. + + Args: + entity_id: Optional custom entity ID. If None, generates one from name and hash. + + Returns: + SceneEntity with cube bounding box and text label + """ + + # Create a cube primitive for the bounding box + cube = CubePrimitive() + + # Get the axis-aligned bounding box + aabb = self.get_bounding_box() # type: ignore[no-untyped-call] + + # Set pose from axis-aligned bounding box + cube.pose = Pose() + cube.pose.position = Point() + # Get center of the axis-aligned bounding box + aabb_center = aabb.get_center() + cube.pose.position.x = aabb_center[0] + cube.pose.position.y = aabb_center[1] + cube.pose.position.z = aabb_center[2] + + # For axis-aligned box, use identity quaternion (no rotation) + cube.pose.orientation = Quaternion() + cube.pose.orientation.x = 0 + cube.pose.orientation.y = 0 + cube.pose.orientation.z = 0 + cube.pose.orientation.w = 1 + + # Set size from axis-aligned bounding box + cube.size = LCMVector3() + aabb_extent = aabb.get_extent() + cube.size.x = aabb_extent[0] # width + cube.size.y = aabb_extent[1] # height + cube.size.z = aabb_extent[2] # depth + + # Set color based on name hash + cube.color = Color.from_string(self.name, alpha=0.2) + + # Create text label + text = TextPrimitive() + text.pose = Pose() + text.pose.position = Point() + text.pose.position.x = aabb_center[0] + text.pose.position.y = aabb_center[1] + text.pose.position.z = aabb_center[2] + aabb_extent[2] / 2 + 0.1 # Above the box + text.pose.orientation = Quaternion() + text.pose.orientation.x = 0 + text.pose.orientation.y = 0 + text.pose.orientation.z = 0 + text.pose.orientation.w = 1 + text.billboard = True + text.font_size = 20.0 + text.scale_invariant = True + text.color = Color() + text.color.r = 1.0 + text.color.g = 1.0 + text.color.b = 1.0 + text.color.a = 1.0 + text.text = self.scene_entity_label() + + # Create scene entity + entity = SceneEntity() + entity.timestamp = to_ros_stamp(self.ts) + entity.frame_id = self.frame_id + entity.id = str(self.track_id) + entity.lifetime = Duration() + entity.lifetime.sec = 0 # Persistent + entity.lifetime.nanosec = 0 + entity.frame_locked = False + + # Initialize all primitive arrays + entity.metadata_length = 0 + entity.metadata = [] + entity.arrows_length = 0 + entity.arrows = [] + entity.cubes_length = 1 + entity.cubes = [cube] + entity.spheres_length = 0 + entity.spheres = [] + entity.cylinders_length = 0 + entity.cylinders = [] + entity.lines_length = 0 + entity.lines = [] + entity.triangles_length = 0 + entity.triangles = [] + entity.texts_length = 1 + entity.texts = [text] + entity.models_length = 0 + entity.models = [] + + return entity + + def scene_entity_label(self) -> str: + return f"{self.track_id}/{self.name} ({self.confidence:.0%})" + + @classmethod + def from_2d( # type: ignore[override] + cls, + det: Detection2DBBox, + world_pointcloud: PointCloud2, + camera_info: CameraInfo, + world_to_optical_transform: Transform, + # filters are to be adjusted based on the sensor noise characteristics if feeding + # sensor data directly + filters: list[PointCloudFilter] | None = None, + ) -> Detection3DPC | None: + """Create a Detection3D from a 2D detection by projecting world pointcloud. + + This method handles: + 1. Projecting world pointcloud to camera frame + 2. Filtering points within the 2D detection bounding box + 3. Cleaning up the pointcloud (height filter, outlier removal) + 4. Hidden point removal from camera perspective + + Args: + det: The 2D detection + world_pointcloud: Full pointcloud in world frame + camera_info: Camera calibration info + world_to_camerlka_transform: Transform from world to camera frame + filters: List of functions to apply to the pointcloud for filtering + Returns: + Detection3D with filtered pointcloud, or None if no valid points + """ + # Set default filters if none provided + if filters is None: + filters = [ + # height_filter(0.1), + raycast(), + radius_outlier(), + statistical(), + ] + + # Extract camera parameters + fx, fy = camera_info.K[0], camera_info.K[4] + cx, cy = camera_info.K[2], camera_info.K[5] + image_width = camera_info.width + image_height = camera_info.height + + camera_matrix = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]]) + + # Convert pointcloud to numpy array + world_points = world_pointcloud.as_numpy() + + # Project points to camera frame + points_homogeneous = np.hstack([world_points, np.ones((world_points.shape[0], 1))]) + extrinsics_matrix = world_to_optical_transform.to_matrix() + points_camera = (extrinsics_matrix @ points_homogeneous.T).T + + # Filter out points behind the camera + valid_mask = points_camera[:, 2] > 0 + points_camera = points_camera[valid_mask] + world_points = world_points[valid_mask] + + if len(world_points) == 0: + return None + + # Project to 2D + points_2d_homogeneous = (camera_matrix @ points_camera[:, :3].T).T + points_2d = points_2d_homogeneous[:, :2] / points_2d_homogeneous[:, 2:3] + + # Filter points within image bounds + in_image_mask = ( + (points_2d[:, 0] >= 0) + & (points_2d[:, 0] < image_width) + & (points_2d[:, 1] >= 0) + & (points_2d[:, 1] < image_height) + ) + points_2d = points_2d[in_image_mask] + world_points = world_points[in_image_mask] + + if len(world_points) == 0: + return None + + # Extract bbox from Detection2D + x_min, y_min, x_max, y_max = det.bbox + + # Find points within this detection box (with small margin) + margin = 5 # pixels + in_box_mask = ( + (points_2d[:, 0] >= x_min - margin) + & (points_2d[:, 0] <= x_max + margin) + & (points_2d[:, 1] >= y_min - margin) + & (points_2d[:, 1] <= y_max + margin) + ) + + detection_points = world_points[in_box_mask] + + if detection_points.shape[0] == 0: + # print(f"No points found in detection bbox after projection. {det.name}") + return None + + # Create initial pointcloud for this detection + initial_pc = PointCloud2.from_numpy( + detection_points, + frame_id=world_pointcloud.frame_id, + timestamp=world_pointcloud.ts, + ) + + # Apply filters - each filter gets all arguments + detection_pc = initial_pc + for filter_func in filters: + result = filter_func(det, detection_pc, camera_info, world_to_optical_transform) + if result is None: + return None + detection_pc = result + + # Final check for empty pointcloud + if len(detection_pc.pointcloud.points) == 0: + return None + + # Create Detection3D with filtered pointcloud + return cls( + image=det.image, + bbox=det.bbox, + track_id=det.track_id, + class_id=det.class_id, + confidence=det.confidence, + name=det.name, + ts=det.ts, + pointcloud=detection_pc, + transform=world_to_optical_transform, + frame_id=world_pointcloud.frame_id, + ) diff --git a/dimos/perception/detection/type/detection3d/pointcloud_filters.py b/dimos/perception/detection/type/detection3d/pointcloud_filters.py new file mode 100644 index 0000000000..0e5732ec4b --- /dev/null +++ b/dimos/perception/detection/type/detection3d/pointcloud_filters.py @@ -0,0 +1,82 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Callable + +from dimos_lcm.sensor_msgs import CameraInfo # type: ignore[import-untyped] + +from dimos.msgs.geometry_msgs import Transform +from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.perception.detection.type.detection2d import Detection2DBBox + +# Filters take Detection2DBBox, PointCloud2, CameraInfo, Transform and return filtered PointCloud2 or None +PointCloudFilter = Callable[ + [Detection2DBBox, PointCloud2, CameraInfo, Transform], PointCloud2 | None +] + + +def height_filter(height: float = 0.1) -> PointCloudFilter: + return lambda det, pc, ci, tf: pc.filter_by_height(height) + + +def statistical(nb_neighbors: int = 40, std_ratio: float = 0.5) -> PointCloudFilter: + def filter_func( + det: Detection2DBBox, pc: PointCloud2, ci: CameraInfo, tf: Transform + ) -> PointCloud2 | None: + try: + statistical, _removed = pc.pointcloud.remove_statistical_outlier( + nb_neighbors=nb_neighbors, std_ratio=std_ratio + ) + return PointCloud2(statistical, pc.frame_id, pc.ts) + except Exception: + # print("statistical filter failed:", e) + return None + + return filter_func + + +def raycast() -> PointCloudFilter: + def filter_func( + det: Detection2DBBox, pc: PointCloud2, ci: CameraInfo, tf: Transform + ) -> PointCloud2 | None: + try: + camera_pos = tf.inverse().translation + camera_pos_np = camera_pos.to_numpy() + _, visible_indices = pc.pointcloud.hidden_point_removal(camera_pos_np, radius=100.0) + visible_pcd = pc.pointcloud.select_by_index(visible_indices) + return PointCloud2(visible_pcd, pc.frame_id, pc.ts) + except Exception: + # print("raycast filter failed:", e) + return None + + return filter_func + + +def radius_outlier(min_neighbors: int = 20, radius: float = 0.3) -> PointCloudFilter: + """ + Remove isolated points: keep only points that have at least `min_neighbors` + neighbors within `radius` meters (same units as your point cloud). + """ + + def filter_func( + det: Detection2DBBox, pc: PointCloud2, ci: CameraInfo, tf: Transform + ) -> PointCloud2 | None: + filtered_pcd, _removed = pc.pointcloud.remove_radius_outlier( + nb_points=min_neighbors, radius=radius + ) + return PointCloud2(filtered_pcd, pc.frame_id, pc.ts) + + return filter_func diff --git a/dimos/perception/detection/type/detection3d/test_imageDetections3DPC.py b/dimos/perception/detection/type/detection3d/test_imageDetections3DPC.py new file mode 100644 index 0000000000..cca8b862d4 --- /dev/null +++ b/dimos/perception/detection/type/detection3d/test_imageDetections3DPC.py @@ -0,0 +1,37 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + + +@pytest.mark.skip +def test_to_foxglove_scene_update(detections3dpc) -> None: + # Convert to scene update + scene_update = detections3dpc.to_foxglove_scene_update() + + # Verify scene update structure + assert scene_update is not None + assert scene_update.deletions_length == 0 + assert len(scene_update.deletions) == 0 + assert scene_update.entities_length == len(detections3dpc.detections) + assert len(scene_update.entities) == len(detections3dpc.detections) + + # Verify each entity corresponds to a detection + for _i, (entity, detection) in enumerate( + zip(scene_update.entities, detections3dpc.detections, strict=False) + ): + assert entity.id == str(detection.track_id) + assert entity.frame_id == detection.frame_id + assert entity.cubes_length == 1 + assert entity.texts_length == 1 diff --git a/dimos/perception/detection/type/detection3d/test_pointcloud.py b/dimos/perception/detection/type/detection3d/test_pointcloud.py new file mode 100644 index 0000000000..efc1c659aa --- /dev/null +++ b/dimos/perception/detection/type/detection3d/test_pointcloud.py @@ -0,0 +1,137 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest + + +def test_detection3dpc(detection3dpc) -> None: + # def test_oriented_bounding_box(detection3dpc): + """Test oriented bounding box calculation and values.""" + obb = detection3dpc.get_oriented_bounding_box() + assert obb is not None, "Oriented bounding box should not be None" + + # Verify OBB center values + assert obb.center[0] == pytest.approx(-3.36002, abs=0.1) + assert obb.center[1] == pytest.approx(-0.196446, abs=0.1) + assert obb.center[2] == pytest.approx(0.220184, abs=0.1) + + # Verify OBB extent values + assert obb.extent[0] == pytest.approx(0.531275, abs=0.12) + assert obb.extent[1] == pytest.approx(0.461054, abs=0.1) + assert obb.extent[2] == pytest.approx(0.155, abs=0.1) + + # def test_bounding_box_dimensions(detection3dpc): + """Test bounding box dimension calculation.""" + dims = detection3dpc.get_bounding_box_dimensions() + assert len(dims) == 3, "Bounding box dimensions should have 3 values" + assert dims[0] == pytest.approx(0.350, abs=0.1) + assert dims[1] == pytest.approx(0.250, abs=0.1) + assert dims[2] == pytest.approx(0.550, abs=0.1) + + # def test_axis_aligned_bounding_box(detection3dpc): + """Test axis-aligned bounding box calculation.""" + aabb = detection3dpc.get_bounding_box() + assert aabb is not None, "Axis-aligned bounding box should not be None" + + # Verify AABB min values + assert aabb.min_bound[0] == pytest.approx(-3.575, abs=0.1) + assert aabb.min_bound[1] == pytest.approx(-0.375, abs=0.1) + assert aabb.min_bound[2] == pytest.approx(-0.075, abs=0.1) + + # Verify AABB max values + assert aabb.max_bound[0] == pytest.approx(-3.075, abs=0.1) + assert aabb.max_bound[1] == pytest.approx(-0.125, abs=0.1) + assert aabb.max_bound[2] == pytest.approx(0.475, abs=0.1) + + # def test_point_cloud_properties(detection3dpc): + """Test point cloud data and boundaries.""" + pc_points = detection3dpc.pointcloud.points() + assert len(pc_points) > 60 + assert detection3dpc.pointcloud.frame_id == "world", ( + f"Expected frame_id 'world', got '{detection3dpc.pointcloud.frame_id}'" + ) + + # Extract xyz coordinates from points + points = np.array([[pt[0], pt[1], pt[2]] for pt in pc_points]) + + min_pt = np.min(points, axis=0) + max_pt = np.max(points, axis=0) + center = np.mean(points, axis=0) + + # Verify point cloud boundaries + assert min_pt[0] == pytest.approx(-3.575, abs=0.1) + assert min_pt[1] == pytest.approx(-0.375, abs=0.1) + assert min_pt[2] == pytest.approx(-0.075, abs=0.1) + + assert max_pt[0] == pytest.approx(-3.075, abs=0.1) + assert max_pt[1] == pytest.approx(-0.125, abs=0.1) + assert max_pt[2] == pytest.approx(0.475, abs=0.1) + + assert center[0] == pytest.approx(-3.326, abs=0.1) + assert center[1] == pytest.approx(-0.202, abs=0.1) + assert center[2] == pytest.approx(0.160, abs=0.1) + + # def test_foxglove_scene_entity_generation(detection3dpc): + """Test Foxglove scene entity creation and structure.""" + entity = detection3dpc.to_foxglove_scene_entity("test_entity_123") + + # Verify entity metadata + assert entity.id == "1", f"Expected entity ID '1', got '{entity.id}'" + assert entity.frame_id == "world", f"Expected frame_id 'world', got '{entity.frame_id}'" + assert entity.cubes_length == 1, f"Expected 1 cube, got {entity.cubes_length}" + assert entity.texts_length == 1, f"Expected 1 text, got {entity.texts_length}" + + # def test_foxglove_cube_properties(detection3dpc): + """Test Foxglove cube primitive properties.""" + entity = detection3dpc.to_foxglove_scene_entity("test_entity_123") + cube = entity.cubes[0] + + # Verify position + assert cube.pose.position.x == pytest.approx(-3.325, abs=0.1) + assert cube.pose.position.y == pytest.approx(-0.250, abs=0.1) + assert cube.pose.position.z == pytest.approx(0.200, abs=0.1) + + # Verify size + assert cube.size.x == pytest.approx(0.350, abs=0.1) + assert cube.size.y == pytest.approx(0.250, abs=0.1) + assert cube.size.z == pytest.approx(0.550, abs=0.1) + + # Verify color (green with alpha) + assert cube.color.r == pytest.approx(0.08235294117647059, abs=0.1) + assert cube.color.g == pytest.approx(0.7176470588235294, abs=0.1) + assert cube.color.b == pytest.approx(0.28627450980392155, abs=0.1) + assert cube.color.a == pytest.approx(0.2, abs=0.1) + + # def test_foxglove_text_label(detection3dpc): + """Test Foxglove text label properties.""" + entity = detection3dpc.to_foxglove_scene_entity("test_entity_123") + text = entity.texts[0] + + assert text.text in ["1/suitcase (81%)", "1/suitcase (82%)"], ( + f"Expected text '1/suitcase (81%)' or '1/suitcase (82%)', got '{text.text}'" + ) + assert text.pose.position.x == pytest.approx(-3.325, abs=0.1) + assert text.pose.position.y == pytest.approx(-0.250, abs=0.1) + assert text.pose.position.z == pytest.approx(0.575, abs=0.1) + assert text.font_size == 20.0, f"Expected font size 20.0, got {text.font_size}" + + # def test_detection_pose(detection3dpc): + """Test detection pose and frame information.""" + assert detection3dpc.pose.x == pytest.approx(-3.327, abs=0.1) + assert detection3dpc.pose.y == pytest.approx(-0.202, abs=0.1) + assert detection3dpc.pose.z == pytest.approx(0.160, abs=0.1) + assert detection3dpc.pose.frame_id == "world", ( + f"Expected frame_id 'world', got '{detection3dpc.pose.frame_id}'" + ) diff --git a/dimos/perception/detection/type/imageDetections.py b/dimos/perception/detection/type/imageDetections.py new file mode 100644 index 0000000000..729ec87972 --- /dev/null +++ b/dimos/perception/detection/type/imageDetections.py @@ -0,0 +1,97 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Generic, TypeVar + +from dimos_lcm.vision_msgs import Detection2DArray # type: ignore[import-untyped] + +from dimos.msgs.foxglove_msgs import ImageAnnotations +from dimos.msgs.std_msgs import Header +from dimos.perception.detection.type.utils import TableStr + +if TYPE_CHECKING: + from collections.abc import Callable, Iterator + + from dimos.msgs.sensor_msgs import Image + from dimos.perception.detection.type.detection2d.base import Detection2D + + T = TypeVar("T", bound=Detection2D) +else: + from dimos.perception.detection.type.detection2d.base import Detection2D + + T = TypeVar("T", bound=Detection2D) + + +class ImageDetections(Generic[T], TableStr): + image: Image + detections: list[T] + + @property + def ts(self) -> float: + return self.image.ts + + def __init__(self, image: Image, detections: list[T] | None = None) -> None: + self.image = image + self.detections = detections or [] + for det in self.detections: + if not det.ts: + det.ts = image.ts + + def __len__(self) -> int: + return len(self.detections) + + def __iter__(self) -> Iterator: # type: ignore[type-arg] + return iter(self.detections) + + def __getitem__(self, index): # type: ignore[no-untyped-def] + return self.detections[index] + + def filter(self, *predicates: Callable[[T], bool]) -> ImageDetections[T]: + """Filter detections using one or more predicate functions. + + Multiple predicates are applied in cascade (all must return True). + + Args: + *predicates: Functions that take a detection and return True to keep it + + Returns: + A new ImageDetections instance with filtered detections + """ + filtered_detections = self.detections + for predicate in predicates: + filtered_detections = [det for det in filtered_detections if predicate(det)] + return ImageDetections(self.image, filtered_detections) + + def to_ros_detection2d_array(self) -> Detection2DArray: + return Detection2DArray( + detections_length=len(self.detections), + header=Header(self.image.ts, "camera_optical"), + detections=[det.to_ros_detection2d() for det in self.detections], + ) + + def to_foxglove_annotations(self) -> ImageAnnotations: + def flatten(xss): # type: ignore[no-untyped-def] + return [x for xs in xss for x in xs] + + texts = flatten(det.to_text_annotation() for det in self.detections) # type: ignore[no-untyped-call] + points = flatten(det.to_points_annotation() for det in self.detections) # type: ignore[no-untyped-call] + + return ImageAnnotations( + texts=texts, + texts_length=len(texts), + points=points, + points_length=len(points), + ) diff --git a/dimos/perception/detection/type/test_detection3d.py b/dimos/perception/detection/type/test_detection3d.py new file mode 100644 index 0000000000..b467df7ffe --- /dev/null +++ b/dimos/perception/detection/type/test_detection3d.py @@ -0,0 +1,36 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + + +def test_guess_projection(get_moment_2d, publish_moment) -> None: + moment = get_moment_2d() + for key, value in moment.items(): + print(key, "====================================") + print(value) + + moment.get("camera_info") + detection2d = moment.get("detections2d")[0] + tf = moment.get("tf") + tf.get("camera_optical", "world", detection2d.ts, 5.0) + + # for stash + # detection3d = Detection3D.from_2d(detection2d, 1.5, camera_info, transform) + # print(detection3d) + + # foxglove bridge needs 2 messages per topic to pass to foxglove + publish_moment(moment) + time.sleep(0.1) + publish_moment(moment) diff --git a/dimos/perception/detection/type/test_object3d.py b/dimos/perception/detection/type/test_object3d.py new file mode 100644 index 0000000000..7057fbb9cb --- /dev/null +++ b/dimos/perception/detection/type/test_object3d.py @@ -0,0 +1,177 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from dimos.perception.detection.moduleDB import Object3D +from dimos.perception.detection.type.detection3d import ImageDetections3DPC + + +def test_first_object(first_object) -> None: + # def test_object3d_properties(first_object): + """Test basic properties of an Object3D.""" + assert first_object.track_id is not None + assert isinstance(first_object.track_id, str) + assert first_object.name is not None + assert first_object.class_id >= 0 + assert 0.0 <= first_object.confidence <= 1.0 + assert first_object.ts > 0 + assert first_object.frame_id is not None + assert first_object.best_detection is not None + + # def test_object3d_center(first_object): + """Test Object3D center calculation.""" + assert first_object.center is not None + assert hasattr(first_object.center, "x") + assert hasattr(first_object.center, "y") + assert hasattr(first_object.center, "z") + + # Center should be within reasonable bounds + assert -10 < first_object.center.x < 10 + assert -10 < first_object.center.y < 10 + assert -10 < first_object.center.z < 10 + + +def test_object3d_repr_dict(first_object) -> None: + """Test to_repr_dict method.""" + repr_dict = first_object.to_repr_dict() + + assert "object_id" in repr_dict + assert "detections" in repr_dict + assert "center" in repr_dict + + assert repr_dict["object_id"] == first_object.track_id + assert repr_dict["detections"] == first_object.detections + + # Center should be formatted as string with coordinates + assert isinstance(repr_dict["center"], str) + assert repr_dict["center"].startswith("[") + assert repr_dict["center"].endswith("]") + + # def test_object3d_scene_entity_label(first_object): + """Test scene entity label generation.""" + label = first_object.scene_entity_label() + + assert isinstance(label, str) + assert first_object.name in label + assert f"({first_object.detections})" in label + + # def test_object3d_agent_encode(first_object): + """Test agent encoding.""" + encoded = first_object.agent_encode() + + assert isinstance(encoded, dict) + assert "id" in encoded + assert "name" in encoded + assert "detections" in encoded + assert "last_seen" in encoded + + assert encoded["id"] == first_object.track_id + assert encoded["name"] == first_object.name + assert encoded["detections"] == first_object.detections + assert encoded["last_seen"].endswith("s ago") + + # def test_object3d_image_property(first_object): + """Test get_image method returns best_detection's image.""" + assert first_object.get_image() is not None + assert first_object.get_image() is first_object.best_detection.image + + +def test_all_objeects(all_objects) -> None: + # def test_object3d_multiple_detections(all_objects): + """Test objects that have been built from multiple detections.""" + # Find objects with multiple detections + multi_detection_objects = [obj for obj in all_objects if obj.detections > 1] + + if multi_detection_objects: + obj = multi_detection_objects[0] + + # Since detections is now a counter, we can only test that we have multiple detections + # and that best_detection exists + assert obj.detections > 1 + assert obj.best_detection is not None + assert obj.confidence is not None + assert obj.ts > 0 + + # Test that best_detection has reasonable properties + assert obj.best_detection.bbox_2d_volume() > 0 + + # def test_object_db_module_objects_structure(all_objects): + """Test the structure of objects in the database.""" + for obj in all_objects: + assert isinstance(obj, Object3D) + assert hasattr(obj, "track_id") + assert hasattr(obj, "detections") + assert hasattr(obj, "best_detection") + assert hasattr(obj, "center") + assert obj.detections >= 1 + + +def test_objectdb_module(object_db_module) -> None: + # def test_object_db_module_populated(object_db_module): + """Test that ObjectDBModule is properly populated.""" + assert len(object_db_module.objects) > 0, "Database should contain objects" + assert object_db_module.cnt > 0, "Object counter should be greater than 0" + + # def test_object3d_addition(object_db_module): + """Test Object3D addition operator.""" + # Get existing objects from the database + objects = list(object_db_module.objects.values()) + if len(objects) < 2: + pytest.skip("Not enough objects in database") + + # Get detections from two different objects + det1 = objects[0].best_detection + det2 = objects[1].best_detection + + # Create a new object with the first detection + obj = Object3D("test_track_combined", det1) + + # Add the second detection from a different object + combined = obj + det2 + + assert combined.track_id == "test_track_combined" + assert combined.detections == 2 + + # Since detections is now a counter, we can't check if specific detections are in the list + # We can only verify the count and that best_detection is properly set + + # Best detection should be determined by the Object3D logic + assert combined.best_detection is not None + + # Center should be valid (no specific value check since we're using real detections) + assert hasattr(combined, "center") + assert combined.center is not None + + # def test_image_detections3d_scene_update(object_db_module): + """Test ImageDetections3DPC to Foxglove scene update conversion.""" + # Get some detections + objects = list(object_db_module.objects.values()) + if not objects: + pytest.skip("No objects in database") + + detections = [obj.best_detection for obj in objects[:3]] # Take up to 3 + + image_detections = ImageDetections3DPC(image=detections[0].image, detections=detections) + + scene_update = image_detections.to_foxglove_scene_update() + + assert scene_update is not None + assert scene_update.entities_length == len(detections) + + for i, entity in enumerate(scene_update.entities): + assert entity.id == str(detections[i].track_id) + assert entity.frame_id == detections[i].frame_id + assert entity.cubes_length == 1 + assert entity.texts_length == 1 diff --git a/dimos/perception/detection/type/utils.py b/dimos/perception/detection/type/utils.py new file mode 100644 index 0000000000..eb924cbd1a --- /dev/null +++ b/dimos/perception/detection/type/utils.py @@ -0,0 +1,101 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import hashlib + +from rich.console import Console +from rich.table import Table +from rich.text import Text + +from dimos.types.timestamped import to_timestamp + + +def _hash_to_color(name: str) -> str: + """Generate a consistent color for a given name using hash.""" + # List of rich colors to choose from + colors = [ + "cyan", + "magenta", + "yellow", + "blue", + "green", + "red", + "bright_cyan", + "bright_magenta", + "bright_yellow", + "bright_blue", + "bright_green", + "bright_red", + "purple", + "white", + "pink", + ] + + # Hash the name and pick a color + hash_value = hashlib.md5(name.encode()).digest()[0] + return colors[hash_value % len(colors)] + + +class TableStr: + """Mixin class that provides table-based string representation for detection collections.""" + + def __str__(self) -> str: + console = Console(force_terminal=True, legacy_windows=False) + + # Create a table for detections + table = Table( + title=f"{self.__class__.__name__} [{len(self.detections)} detections @ {to_timestamp(self.image.ts):.3f}]", # type: ignore[attr-defined] + show_header=True, + show_edge=True, + ) + + # Dynamically build columns based on the first detection's dict keys + if not self.detections: # type: ignore[attr-defined] + return ( + f" {self.__class__.__name__} [0 detections @ {to_timestamp(self.image.ts):.3f}]" # type: ignore[attr-defined] + ) + + # Cache all repr_dicts to avoid double computation + detection_dicts = [det.to_repr_dict() for det in self] # type: ignore[attr-defined] + + first_dict = detection_dicts[0] + table.add_column("#", style="dim") + for col in first_dict.keys(): + color = _hash_to_color(col) + table.add_column(col.title(), style=color) + + # Add each detection to the table + for i, d in enumerate(detection_dicts): + row = [str(i)] + + for key in first_dict.keys(): + if key == "conf": + # Color-code confidence + conf_color = ( + "green" + if float(d[key]) > 0.8 + else "yellow" + if float(d[key]) > 0.5 + else "red" + ) + row.append(Text(f"{d[key]}", style=conf_color)) # type: ignore[arg-type] + elif key == "points" and d.get(key) == "None": + row.append(Text(d.get(key, ""), style="dim")) # type: ignore[arg-type] + else: + row.append(str(d.get(key, ""))) + table.add_row(*row) + + with console.capture() as capture: + console.print(table) + return capture.get().strip() diff --git a/dimos/perception/detection2d/utils.py b/dimos/perception/detection2d/utils.py new file mode 100644 index 0000000000..a505eef7c8 --- /dev/null +++ b/dimos/perception/detection2d/utils.py @@ -0,0 +1,309 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Sequence + +import cv2 +import numpy as np + + +def filter_detections( # type: ignore[no-untyped-def] + bboxes, + track_ids, + class_ids, + confidences, + names: Sequence[str], + class_filter=None, + name_filter=None, + track_id_filter=None, +): + """ + Filter detection results based on class IDs, names, and/or tracking IDs. + + Args: + bboxes: List of bounding boxes [x1, y1, x2, y2] + track_ids: List of tracking IDs + class_ids: List of class indices + confidences: List of detection confidences + names: List of class names + class_filter: List/set of class IDs to keep, or None to keep all + name_filter: List/set of class names to keep, or None to keep all + track_id_filter: List/set of track IDs to keep, or None to keep all + + Returns: + tuple: (filtered_bboxes, filtered_track_ids, filtered_class_ids, + filtered_confidences, filtered_names) + """ + # Convert filters to sets for efficient lookup + if class_filter is not None: + class_filter = set(class_filter) + if name_filter is not None: + name_filter = set(name_filter) + if track_id_filter is not None: + track_id_filter = set(track_id_filter) + + # Initialize lists for filtered results + filtered_bboxes = [] + filtered_track_ids = [] + filtered_class_ids = [] + filtered_confidences = [] + filtered_names = [] + + # Filter detections + for bbox, track_id, class_id, conf, name in zip( + bboxes, track_ids, class_ids, confidences, names, strict=False + ): + # Check if detection passes all specified filters + keep = True + + if class_filter is not None: + keep = keep and (class_id in class_filter) + + if name_filter is not None: + keep = keep and (name in name_filter) + + if track_id_filter is not None: + keep = keep and (track_id in track_id_filter) + + # If detection passes all filters, add it to results + if keep: + filtered_bboxes.append(bbox) + filtered_track_ids.append(track_id) + filtered_class_ids.append(class_id) + filtered_confidences.append(conf) + filtered_names.append(name) + + return ( + filtered_bboxes, + filtered_track_ids, + filtered_class_ids, + filtered_confidences, + filtered_names, + ) + + +def extract_detection_results(result, class_filter=None, name_filter=None, track_id_filter=None): # type: ignore[no-untyped-def] + """ + Extract and optionally filter detection information from a YOLO result object. + + Args: + result: Ultralytics result object + class_filter: List/set of class IDs to keep, or None to keep all + name_filter: List/set of class names to keep, or None to keep all + track_id_filter: List/set of track IDs to keep, or None to keep all + + Returns: + tuple: (bboxes, track_ids, class_ids, confidences, names) + - bboxes: list of [x1, y1, x2, y2] coordinates + - track_ids: list of tracking IDs + - class_ids: list of class indices + - confidences: list of detection confidences + - names: list of class names + """ + bboxes = [] # type: ignore[var-annotated] + track_ids = [] # type: ignore[var-annotated] + class_ids = [] # type: ignore[var-annotated] + confidences = [] # type: ignore[var-annotated] + names = [] # type: ignore[var-annotated] + + if result.boxes is None: + return bboxes, track_ids, class_ids, confidences, names + + for box in result.boxes: + # Extract bounding box coordinates + x1, y1, x2, y2 = box.xyxy[0].tolist() + + # Extract tracking ID if available + track_id = -1 + if hasattr(box, "id") and box.id is not None: + track_id = int(box.id[0].item()) + + # Extract class information + cls_idx = int(box.cls[0]) + name = result.names[cls_idx] + + # Extract confidence + conf = float(box.conf[0]) + + # Check filters before adding to results + keep = True + if class_filter is not None: + keep = keep and (cls_idx in class_filter) + if name_filter is not None: + keep = keep and (name in name_filter) + if track_id_filter is not None: + keep = keep and (track_id in track_id_filter) + + if keep: + bboxes.append([x1, y1, x2, y2]) + track_ids.append(track_id) + class_ids.append(cls_idx) + confidences.append(conf) + names.append(name) + + return bboxes, track_ids, class_ids, confidences, names + + +def plot_results( # type: ignore[no-untyped-def] + image, bboxes, track_ids, class_ids, confidences, names: Sequence[str], alpha: float = 0.5 +): + """ + Draw bounding boxes and labels on the image. + + Args: + image: Original input image + bboxes: List of bounding boxes [x1, y1, x2, y2] + track_ids: List of tracking IDs + class_ids: List of class indices + confidences: List of detection confidences + names: List of class names + alpha: Transparency of the overlay + + Returns: + Image with visualized detections + """ + vis_img = image.copy() + + for bbox, track_id, conf, name in zip(bboxes, track_ids, confidences, names, strict=False): + # Generate consistent color based on track_id or class name + if track_id != -1: + np.random.seed(track_id) + else: + np.random.seed(hash(name) % 100000) + color = np.random.randint(0, 255, (3,), dtype=np.uint8) + np.random.seed(None) + + # Draw bounding box + x1, y1, x2, y2 = map(int, bbox) + cv2.rectangle(vis_img, (x1, y1), (x2, y2), color.tolist(), 2) + + # Prepare label text + if track_id != -1: + label = f"ID:{track_id} {name} {conf:.2f}" + else: + label = f"{name} {conf:.2f}" + + # Calculate text size for background rectangle + (text_w, text_h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1) + + # Draw background rectangle for text + cv2.rectangle(vis_img, (x1, y1 - text_h - 8), (x1 + text_w + 4, y1), color.tolist(), -1) + + # Draw text with white color for better visibility + cv2.putText( + vis_img, label, (x1 + 2, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1 + ) + + return vis_img + + +def calculate_depth_from_bbox(depth_map, bbox): # type: ignore[no-untyped-def] + """ + Calculate the average depth of an object within a bounding box. + Uses the 25th to 75th percentile range to filter outliers. + + Args: + depth_map: The depth map + bbox: Bounding box in format [x1, y1, x2, y2] + + Returns: + float: Average depth in meters, or None if depth estimation fails + """ + try: + # Extract region of interest from the depth map + x1, y1, x2, y2 = map(int, bbox) + roi_depth = depth_map[y1:y2, x1:x2] + + if roi_depth.size == 0: + return None + + # Calculate 25th and 75th percentile to filter outliers + p25 = np.percentile(roi_depth, 25) + p75 = np.percentile(roi_depth, 75) + + # Filter depth values within this range + filtered_depth = roi_depth[(roi_depth >= p25) & (roi_depth <= p75)] + + # Calculate average depth (convert to meters) + if filtered_depth.size > 0: + return np.mean(filtered_depth) / 1000.0 # Convert mm to meters + + return None + except Exception as e: + print(f"Error calculating depth from bbox: {e}") + return None + + +def calculate_distance_angle_from_bbox(bbox, depth: int, camera_intrinsics): # type: ignore[no-untyped-def] + """ + Calculate distance and angle to object center based on bbox and depth. + + Args: + bbox: Bounding box [x1, y1, x2, y2] + depth: Depth value in meters + camera_intrinsics: List [fx, fy, cx, cy] with camera parameters + + Returns: + tuple: (distance, angle) in meters and radians + """ + if camera_intrinsics is None: + raise ValueError("Camera intrinsics required for distance calculation") + + # Extract camera parameters + fx, _fy, cx, _cy = camera_intrinsics + + # Calculate center of bounding box in pixels + x1, y1, x2, y2 = bbox + center_x = (x1 + x2) / 2 + (y1 + y2) / 2 + + # Calculate normalized image coordinates + x_norm = (center_x - cx) / fx + + # Calculate angle (positive to the right) + angle = np.arctan(x_norm) + + # Calculate distance using depth and angle + distance = depth / np.cos(angle) if np.cos(angle) != 0 else depth + + return distance, angle + + +def calculate_object_size_from_bbox(bbox, depth: int, camera_intrinsics): # type: ignore[no-untyped-def] + """ + Estimate physical width and height of object in meters. + + Args: + bbox: Bounding box [x1, y1, x2, y2] + depth: Depth value in meters + camera_intrinsics: List [fx, fy, cx, cy] with camera parameters + + Returns: + tuple: (width, height) in meters + """ + if camera_intrinsics is None: + return 0.0, 0.0 + + fx, fy, _, _ = camera_intrinsics + + # Calculate bbox dimensions in pixels + x1, y1, x2, y2 = bbox + width_px = x2 - x1 + height_px = y2 - y1 + + # Convert to meters using similar triangles and depth + width_m = (width_px * depth) / fx + height_m = (height_px * depth) / fy + + return width_m, height_m diff --git a/dimos/perception/grasp_generation/__init__.py b/dimos/perception/grasp_generation/__init__.py new file mode 100644 index 0000000000..16281fe0b6 --- /dev/null +++ b/dimos/perception/grasp_generation/__init__.py @@ -0,0 +1 @@ +from .utils import * diff --git a/dimos/perception/grasp_generation/grasp_generation.py b/dimos/perception/grasp_generation/grasp_generation.py new file mode 100644 index 0000000000..4f2e4b68a1 --- /dev/null +++ b/dimos/perception/grasp_generation/grasp_generation.py @@ -0,0 +1,233 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Dimensional-hosted grasp generation for manipulation pipeline. +""" + +import asyncio + +import numpy as np +import open3d as o3d # type: ignore[import-untyped] + +from dimos.perception.grasp_generation.utils import parse_grasp_results +from dimos.types.manipulation import ObjectData +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class HostedGraspGenerator: + """ + Dimensional-hosted grasp generator using WebSocket communication. + """ + + def __init__(self, server_url: str) -> None: + """ + Initialize Dimensional-hosted grasp generator. + + Args: + server_url: WebSocket URL for Dimensional-hosted grasp generator server + """ + self.server_url = server_url + logger.info(f"Initialized grasp generator with server: {server_url}") + + def generate_grasps_from_objects( + self, objects: list[ObjectData], full_pcd: o3d.geometry.PointCloud + ) -> list[dict]: # type: ignore[type-arg] + """ + Generate grasps from ObjectData objects using grasp generator. + + Args: + objects: List of ObjectData with point clouds + full_pcd: Open3D point cloud of full scene + + Returns: + Parsed grasp results as list of dictionaries + """ + try: + # Combine all point clouds + all_points = [] + all_colors = [] + valid_objects = 0 + + for obj in objects: + if "point_cloud_numpy" not in obj or obj["point_cloud_numpy"] is None: + continue + + points = obj["point_cloud_numpy"] + if not isinstance(points, np.ndarray) or points.size == 0: + continue + + if len(points.shape) != 2 or points.shape[1] != 3: + continue + + colors = None + if "colors_numpy" in obj and obj["colors_numpy"] is not None: # type: ignore[typeddict-item] + colors = obj["colors_numpy"] # type: ignore[typeddict-item] + if isinstance(colors, np.ndarray) and colors.size > 0: + if ( + colors.shape[0] != points.shape[0] + or len(colors.shape) != 2 + or colors.shape[1] != 3 + ): + colors = None + + all_points.append(points) + if colors is not None: + all_colors.append(colors) + valid_objects += 1 + + if not all_points: + return [] + + # Combine point clouds + combined_points = np.vstack(all_points) + combined_colors = None + if len(all_colors) == valid_objects and len(all_colors) > 0: + combined_colors = np.vstack(all_colors) + + # Send grasp request + grasps = self._send_grasp_request_sync(combined_points, combined_colors) + + if not grasps: + return [] + + # Parse and return results in list of dictionaries format + return parse_grasp_results(grasps) + + except Exception as e: + logger.error(f"Grasp generation failed: {e}") + return [] + + def _send_grasp_request_sync( + self, + points: np.ndarray, # type: ignore[type-arg] + colors: np.ndarray | None, # type: ignore[type-arg] + ) -> list[dict] | None: # type: ignore[type-arg] + """Send synchronous grasp request to grasp server.""" + + try: + # Prepare colors + colors = np.ones((points.shape[0], 3), dtype=np.float32) * 0.5 + + # Ensure correct data types + points = points.astype(np.float32) + colors = colors.astype(np.float32) + + # Validate ranges + if np.any(np.isnan(points)) or np.any(np.isinf(points)): + logger.error("Points contain NaN or Inf values") + return None + if np.any(np.isnan(colors)) or np.any(np.isinf(colors)): + logger.error("Colors contain NaN or Inf values") + return None + + colors = np.clip(colors, 0.0, 1.0) + + # Run async request in sync context + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + result = loop.run_until_complete(self._async_grasp_request(points, colors)) + return result + finally: + loop.close() + + except Exception as e: + logger.error(f"Error in synchronous grasp request: {e}") + return None + + async def _async_grasp_request( + self, + points: np.ndarray, # type: ignore[type-arg] + colors: np.ndarray, # type: ignore[type-arg] + ) -> list[dict] | None: # type: ignore[type-arg] + """Async grasp request helper.""" + import json + + import websockets + + try: + async with websockets.connect(self.server_url) as websocket: + request = { + "points": points.tolist(), + "colors": colors.tolist(), + "lims": [-1.0, 1.0, -1.0, 1.0, 0.0, 2.0], + } + + await websocket.send(json.dumps(request)) + response = await websocket.recv() + grasps = json.loads(response) + + if isinstance(grasps, dict) and "error" in grasps: + logger.error(f"Server returned error: {grasps['error']}") + return None + elif isinstance(grasps, int | float) and grasps == 0: + return None + elif not isinstance(grasps, list): + logger.error(f"Server returned unexpected response type: {type(grasps)}") + return None + elif len(grasps) == 0: + return None + + return self._convert_grasp_format(grasps) + + except Exception as e: + logger.error(f"Async grasp request failed: {e}") + return None + + def _convert_grasp_format(self, grasps: list[dict]) -> list[dict]: # type: ignore[type-arg] + """Convert Dimensional Grasp format to visualization format.""" + converted = [] + + for i, grasp in enumerate(grasps): + rotation_matrix = np.array(grasp.get("rotation_matrix", np.eye(3))) + euler_angles = self._rotation_matrix_to_euler(rotation_matrix) + + converted_grasp = { + "id": f"grasp_{i}", + "score": grasp.get("score", 0.0), + "width": grasp.get("width", 0.0), + "height": grasp.get("height", 0.0), + "depth": grasp.get("depth", 0.0), + "translation": grasp.get("translation", [0, 0, 0]), + "rotation_matrix": rotation_matrix.tolist(), + "euler_angles": euler_angles, + } + converted.append(converted_grasp) + + converted.sort(key=lambda x: x["score"], reverse=True) + return converted + + def _rotation_matrix_to_euler(self, rotation_matrix: np.ndarray) -> dict[str, float]: # type: ignore[type-arg] + """Convert rotation matrix to Euler angles (in radians).""" + sy = np.sqrt(rotation_matrix[0, 0] ** 2 + rotation_matrix[1, 0] ** 2) + + singular = sy < 1e-6 + + if not singular: + x = np.arctan2(rotation_matrix[2, 1], rotation_matrix[2, 2]) + y = np.arctan2(-rotation_matrix[2, 0], sy) + z = np.arctan2(rotation_matrix[1, 0], rotation_matrix[0, 0]) + else: + x = np.arctan2(-rotation_matrix[1, 2], rotation_matrix[1, 1]) + y = np.arctan2(-rotation_matrix[2, 0], sy) + z = 0 + + return {"roll": x, "pitch": y, "yaw": z} + + def cleanup(self) -> None: + """Clean up resources.""" + logger.info("Grasp generator cleaned up") diff --git a/dimos/perception/grasp_generation/utils.py b/dimos/perception/grasp_generation/utils.py new file mode 100644 index 0000000000..492a3d1df4 --- /dev/null +++ b/dimos/perception/grasp_generation/utils.py @@ -0,0 +1,529 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for grasp generation and visualization.""" + +import cv2 +import numpy as np +import open3d as o3d # type: ignore[import-untyped] + +from dimos.perception.common.utils import project_3d_points_to_2d + + +def create_gripper_geometry( + grasp_data: dict, # type: ignore[type-arg] + finger_length: float = 0.08, + finger_thickness: float = 0.004, +) -> list[o3d.geometry.TriangleMesh]: + """ + Create a simple fork-like gripper geometry from grasp data. + + Args: + grasp_data: Dictionary containing grasp parameters + - translation: 3D position list + - rotation_matrix: 3x3 rotation matrix defining gripper coordinate system + * X-axis: gripper width direction (opening/closing) + * Y-axis: finger length direction + * Z-axis: approach direction (toward object) + - width: Gripper opening width + finger_length: Length of gripper fingers (longer) + finger_thickness: Thickness of gripper fingers + base_height: Height of gripper base (longer) + color: RGB color for the gripper (solid blue) + + Returns: + List of Open3D TriangleMesh geometries for the gripper + """ + + translation = np.array(grasp_data["translation"]) + rotation_matrix = np.array(grasp_data["rotation_matrix"]) + + width = grasp_data.get("width", 0.04) + + # Create transformation matrix + transform = np.eye(4) + transform[:3, :3] = rotation_matrix + transform[:3, 3] = translation + + geometries = [] + + # Gripper dimensions + finger_width = 0.006 # Thickness of each finger + handle_length = 0.05 # Length of handle extending backward + + # Build gripper in local coordinate system: + # X-axis = width direction (left/right finger separation) + # Y-axis = finger length direction (fingers extend along +Y) + # Z-axis = approach direction (toward object, handle extends along -Z) + # IMPORTANT: Fingertips should be at origin (translation point) + + # Create left finger extending along +Y, positioned at +X + left_finger = o3d.geometry.TriangleMesh.create_box( + width=finger_width, # Thin finger + height=finger_length, # Extends along Y (finger length direction) + depth=finger_thickness, # Thin in Z direction + ) + left_finger.translate( + [ + width / 2 - finger_width / 2, # Position at +X (half width from center) + -finger_length, # Shift so fingertips are at origin + -finger_thickness / 2, # Center in Z + ] + ) + + # Create right finger extending along +Y, positioned at -X + right_finger = o3d.geometry.TriangleMesh.create_box( + width=finger_width, # Thin finger + height=finger_length, # Extends along Y (finger length direction) + depth=finger_thickness, # Thin in Z direction + ) + right_finger.translate( + [ + -width / 2 - finger_width / 2, # Position at -X (half width from center) + -finger_length, # Shift so fingertips are at origin + -finger_thickness / 2, # Center in Z + ] + ) + + # Create base connecting fingers - flat like a stickman body + base = o3d.geometry.TriangleMesh.create_box( + width=width + finger_width, # Full width plus finger thickness + height=finger_thickness, # Flat like fingers (stickman style) + depth=finger_thickness, # Thin like fingers + ) + base.translate( + [ + -width / 2 - finger_width / 2, # Start from left finger position + -finger_length - finger_thickness, # Behind fingers, adjusted for fingertips at origin + -finger_thickness / 2, # Center in Z + ] + ) + + # Create handle extending backward - flat stick like stickman arm + handle = o3d.geometry.TriangleMesh.create_box( + width=finger_width, # Same width as fingers + height=handle_length, # Extends backward along Y direction (same plane) + depth=finger_thickness, # Thin like fingers (same plane) + ) + handle.translate( + [ + -finger_width / 2, # Center in X + -finger_length + - finger_thickness + - handle_length, # Extend backward from base, adjusted for fingertips at origin + -finger_thickness / 2, # Same Z plane as other components + ] + ) + + # Use solid red color for all parts (user changed to red) + solid_color = [1.0, 0.0, 0.0] # Red color + + left_finger.paint_uniform_color(solid_color) + right_finger.paint_uniform_color(solid_color) + base.paint_uniform_color(solid_color) + handle.paint_uniform_color(solid_color) + + # Apply transformation to all parts + left_finger.transform(transform) + right_finger.transform(transform) + base.transform(transform) + handle.transform(transform) + + geometries.extend([left_finger, right_finger, base, handle]) + + return geometries + + +def create_all_gripper_geometries( + grasp_list: list[dict], # type: ignore[type-arg] + max_grasps: int = -1, +) -> list[o3d.geometry.TriangleMesh]: + """ + Create gripper geometries for multiple grasps. + + Args: + grasp_list: List of grasp dictionaries + max_grasps: Maximum number of grasps to visualize (-1 for all) + + Returns: + List of all gripper geometries + """ + all_geometries = [] + + grasps_to_show = grasp_list if max_grasps < 0 else grasp_list[:max_grasps] + + for grasp in grasps_to_show: + gripper_parts = create_gripper_geometry(grasp) + all_geometries.extend(gripper_parts) + + return all_geometries + + +def draw_grasps_on_image( + image: np.ndarray, # type: ignore[type-arg] + grasp_data: dict | dict[int | str, list[dict]] | list[dict], # type: ignore[type-arg] + camera_intrinsics: list[float] | np.ndarray, # type: ignore[type-arg] # [fx, fy, cx, cy] or 3x3 matrix + max_grasps: int = -1, # -1 means show all grasps + finger_length: float = 0.08, # Match 3D gripper + finger_thickness: float = 0.004, # Match 3D gripper +) -> np.ndarray: # type: ignore[type-arg] + """ + Draw fork-like gripper visualizations on the image matching 3D gripper design. + + Args: + image: Base image to draw on + grasp_data: Can be: + - A single grasp dict + - A list of grasp dicts + - A dictionary mapping object IDs or "scene" to list of grasps + camera_intrinsics: Camera parameters as [fx, fy, cx, cy] list or 3x3 matrix + max_grasps: Maximum number of grasps to visualize (-1 for all) + finger_length: Length of gripper fingers (matches 3D design) + finger_thickness: Thickness of gripper fingers (matches 3D design) + + Returns: + Image with grasps drawn + """ + result = image.copy() + + # Convert camera intrinsics to 3x3 matrix if needed + if isinstance(camera_intrinsics, list) and len(camera_intrinsics) == 4: + fx, fy, cx, cy = camera_intrinsics + camera_matrix = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]]) + else: + camera_matrix = np.array(camera_intrinsics) + + # Convert input to standard format + if isinstance(grasp_data, dict) and not any( + key in grasp_data for key in ["scene", 0, 1, 2, 3, 4, 5] + ): + # Single grasp + grasps_to_draw = [(grasp_data, 0)] + elif isinstance(grasp_data, list): + # List of grasps + grasps_to_draw = [(grasp, i) for i, grasp in enumerate(grasp_data)] + else: + # Dictionary of grasps by object ID + grasps_to_draw = [] + for _obj_id, grasps in grasp_data.items(): + for i, grasp in enumerate(grasps): + grasps_to_draw.append((grasp, i)) + + # Limit number of grasps if specified + if max_grasps > 0: + grasps_to_draw = grasps_to_draw[:max_grasps] + + # Define grasp colors (solid red to match 3D design) + def get_grasp_color(index: int) -> tuple: # type: ignore[type-arg] + # Use solid red color for all grasps to match 3D design + return (0, 0, 255) # Red in BGR format for OpenCV + + # Draw each grasp + for grasp, index in grasps_to_draw: + try: + color = get_grasp_color(index) + thickness = max(1, 4 - index // 3) + + # Extract grasp parameters (using translation and rotation_matrix) + if "translation" not in grasp or "rotation_matrix" not in grasp: + continue + + translation = np.array(grasp["translation"]) + rotation_matrix = np.array(grasp["rotation_matrix"]) + width = grasp.get("width", 0.04) + + # Match 3D gripper dimensions + finger_width = 0.006 # Thickness of each finger (matches 3D) + handle_length = 0.05 # Length of handle extending backward (matches 3D) + + # Create gripper geometry in local coordinate system matching 3D design: + # X-axis = width direction (left/right finger separation) + # Y-axis = finger length direction (fingers extend along +Y) + # Z-axis = approach direction (toward object, handle extends along -Z) + # IMPORTANT: Fingertips should be at origin (translation point) + + # Left finger extending along +Y, positioned at +X + left_finger_points = np.array( + [ + [ + width / 2 - finger_width / 2, # type: ignore[operator] + -finger_length, + -finger_thickness / 2, + ], # Back left + [ + width / 2 + finger_width / 2, # type: ignore[operator] + -finger_length, + -finger_thickness / 2, + ], # Back right + [ + width / 2 + finger_width / 2, # type: ignore[operator] + 0, + -finger_thickness / 2, + ], # Front right (at origin) + [ + width / 2 - finger_width / 2, # type: ignore[operator] + 0, + -finger_thickness / 2, + ], # Front left (at origin) + ] + ) + + # Right finger extending along +Y, positioned at -X + right_finger_points = np.array( + [ + [ + -width / 2 - finger_width / 2, # type: ignore[operator] + -finger_length, + -finger_thickness / 2, + ], # Back left + [ + -width / 2 + finger_width / 2, # type: ignore[operator] + -finger_length, + -finger_thickness / 2, + ], # Back right + [ + -width / 2 + finger_width / 2, # type: ignore[operator] + 0, + -finger_thickness / 2, + ], # Front right (at origin) + [ + -width / 2 - finger_width / 2, # type: ignore[operator] + 0, + -finger_thickness / 2, + ], # Front left (at origin) + ] + ) + + # Base connecting fingers - flat rectangle behind fingers + base_points = np.array( + [ + [ + -width / 2 - finger_width / 2, # type: ignore[operator] + -finger_length - finger_thickness, + -finger_thickness / 2, + ], # Back left + [ + width / 2 + finger_width / 2, # type: ignore[operator] + -finger_length - finger_thickness, + -finger_thickness / 2, + ], # Back right + [ + width / 2 + finger_width / 2, # type: ignore[operator] + -finger_length, + -finger_thickness / 2, + ], # Front right + [ + -width / 2 - finger_width / 2, # type: ignore[operator] + -finger_length, + -finger_thickness / 2, + ], # Front left + ] + ) + + # Handle extending backward - thin rectangle + handle_points = np.array( + [ + [ + -finger_width / 2, + -finger_length - finger_thickness - handle_length, + -finger_thickness / 2, + ], # Back left + [ + finger_width / 2, + -finger_length - finger_thickness - handle_length, + -finger_thickness / 2, + ], # Back right + [ + finger_width / 2, + -finger_length - finger_thickness, + -finger_thickness / 2, + ], # Front right + [ + -finger_width / 2, + -finger_length - finger_thickness, + -finger_thickness / 2, + ], # Front left + ] + ) + + # Transform all points to world frame + def transform_points(points): # type: ignore[no-untyped-def] + # Apply rotation and translation + world_points = (rotation_matrix @ points.T).T + translation + return world_points + + left_finger_world = transform_points(left_finger_points) # type: ignore[no-untyped-call] + right_finger_world = transform_points(right_finger_points) # type: ignore[no-untyped-call] + base_world = transform_points(base_points) # type: ignore[no-untyped-call] + handle_world = transform_points(handle_points) # type: ignore[no-untyped-call] + + # Project to 2D + left_finger_2d = project_3d_points_to_2d(left_finger_world, camera_matrix) + right_finger_2d = project_3d_points_to_2d(right_finger_world, camera_matrix) + base_2d = project_3d_points_to_2d(base_world, camera_matrix) + handle_2d = project_3d_points_to_2d(handle_world, camera_matrix) + + # Draw left finger + pts = left_finger_2d.astype(np.int32) + cv2.polylines(result, [pts], True, color, thickness) + + # Draw right finger + pts = right_finger_2d.astype(np.int32) + cv2.polylines(result, [pts], True, color, thickness) + + # Draw base + pts = base_2d.astype(np.int32) + cv2.polylines(result, [pts], True, color, thickness) + + # Draw handle + pts = handle_2d.astype(np.int32) + cv2.polylines(result, [pts], True, color, thickness) + + # Draw grasp center (fingertips at origin) + center_2d = project_3d_points_to_2d(translation.reshape(1, -1), camera_matrix)[0] + cv2.circle(result, tuple(center_2d.astype(int)), 3, color, -1) + + except Exception: + # Skip this grasp if there's an error + continue + + return result + + +def get_standard_coordinate_transform(): # type: ignore[no-untyped-def] + """ + Get a standard coordinate transformation matrix for consistent visualization. + + This transformation ensures that: + - X (red) axis points right + - Y (green) axis points up + - Z (blue) axis points toward viewer + + Returns: + 4x4 transformation matrix + """ + # Standard transformation matrix to ensure consistent coordinate frame orientation + transform = np.array( + [ + [1, 0, 0, 0], # X points right + [0, -1, 0, 0], # Y points up (flip from OpenCV to standard) + [0, 0, -1, 0], # Z points toward viewer (flip depth) + [0, 0, 0, 1], + ] + ) + return transform + + +def visualize_grasps_3d( + point_cloud: o3d.geometry.PointCloud, + grasp_list: list[dict], # type: ignore[type-arg] + max_grasps: int = -1, +) -> None: + """ + Visualize grasps in 3D with point cloud. + + Args: + point_cloud: Open3D point cloud + grasp_list: List of grasp dictionaries + max_grasps: Maximum number of grasps to visualize + """ + # Apply standard coordinate transformation + transform = get_standard_coordinate_transform() # type: ignore[no-untyped-call] + + # Transform point cloud + pc_copy = o3d.geometry.PointCloud(point_cloud) + pc_copy.transform(transform) + geometries = [pc_copy] + + # Transform gripper geometries + gripper_geometries = create_all_gripper_geometries(grasp_list, max_grasps) + for geom in gripper_geometries: + geom.transform(transform) + geometries.extend(gripper_geometries) + + # Add transformed coordinate frame + origin_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1) + origin_frame.transform(transform) + geometries.append(origin_frame) + + o3d.visualization.draw_geometries(geometries, window_name="3D Grasp Visualization") + + +def parse_grasp_results(grasps: list[dict]) -> list[dict]: # type: ignore[type-arg] + """ + Parse grasp results into visualization format. + + Args: + grasps: List of grasp dictionaries + + Returns: + List of dictionaries containing: + - id: Unique grasp identifier + - score: Confidence score (float) + - width: Gripper opening width (float) + - translation: 3D position [x, y, z] + - rotation_matrix: 3x3 rotation matrix as nested list + """ + if not grasps: + return [] + + parsed_grasps = [] + + for i, grasp in enumerate(grasps): + # Extract data from each grasp + translation = grasp.get("translation", [0, 0, 0]) + rotation_matrix = np.array(grasp.get("rotation_matrix", np.eye(3))) + score = float(grasp.get("score", 0.0)) + width = float(grasp.get("width", 0.08)) + + parsed_grasp = { + "id": f"grasp_{i}", + "score": score, + "width": width, + "translation": translation, + "rotation_matrix": rotation_matrix.tolist(), + } + parsed_grasps.append(parsed_grasp) + + return parsed_grasps + + +def create_grasp_overlay( + rgb_image: np.ndarray, # type: ignore[type-arg] + grasps: list[dict], # type: ignore[type-arg] + camera_intrinsics: list[float] | np.ndarray, # type: ignore[type-arg] +) -> np.ndarray: # type: ignore[type-arg] + """ + Create grasp visualization overlay on RGB image. + + Args: + rgb_image: RGB input image + grasps: List of grasp dictionaries in viz format + camera_intrinsics: Camera parameters + + Returns: + RGB image with grasp overlay + """ + try: + bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR) + + result_bgr = draw_grasps_on_image( + bgr_image, + grasps, + camera_intrinsics, + max_grasps=-1, + ) + return cv2.cvtColor(result_bgr, cv2.COLOR_BGR2RGB) + except Exception: + return rgb_image.copy() diff --git a/dimos/perception/object_detection_stream.py b/dimos/perception/object_detection_stream.py new file mode 100644 index 0000000000..be9a668eab --- /dev/null +++ b/dimos/perception/object_detection_stream.py @@ -0,0 +1,320 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np +from reactivex import Observable, operators as ops + +from dimos.perception.detection2d.yolo_2d_det import Yolo2DDetector # type: ignore[import-untyped] + +try: + from dimos.perception.detection2d.detic_2d_det import ( # type: ignore[import-untyped] + Detic2DDetector, + ) + + DETIC_AVAILABLE = True +except (ModuleNotFoundError, ImportError): + DETIC_AVAILABLE = False + Detic2DDetector = None +from collections.abc import Callable +from typing import TYPE_CHECKING + +from dimos.models.depth.metric3d import Metric3D +from dimos.perception.common.utils import draw_object_detection_visualization +from dimos.perception.detection2d.utils import ( # type: ignore[attr-defined] + calculate_depth_from_bbox, + calculate_object_size_from_bbox, + calculate_position_rotation_from_bbox, +) +from dimos.types.vector import Vector +from dimos.utils.logging_config import setup_logger +from dimos.utils.transform_utils import transform_robot_to_map # type: ignore[attr-defined] + +if TYPE_CHECKING: + from dimos.types.manipulation import ObjectData + +# Initialize logger for the ObjectDetectionStream +logger = setup_logger() + + +class ObjectDetectionStream: + """ + A stream processor that: + 1. Detects objects using a Detector (Detic or Yolo) + 2. Estimates depth using Metric3D + 3. Calculates 3D position and dimensions using camera intrinsics + 4. Transforms coordinates to map frame + 5. Draws bounding boxes and segmentation masks on the frame + + Provides a stream of structured object data with position and rotation information. + """ + + def __init__( # type: ignore[no-untyped-def] + self, + camera_intrinsics=None, # [fx, fy, cx, cy] + device: str = "cuda", + gt_depth_scale: float = 1000.0, + min_confidence: float = 0.7, + class_filter=None, # Optional list of class names to filter (e.g., ["person", "car"]) + get_pose: Callable | None = None, # type: ignore[type-arg] # Optional function to transform coordinates to map frame + detector: Detic2DDetector | Yolo2DDetector | None = None, + video_stream: Observable = None, # type: ignore[assignment, type-arg] + disable_depth: bool = False, # Flag to disable monocular Metric3D depth estimation + draw_masks: bool = False, # Flag to enable drawing segmentation masks + ) -> None: + """ + Initialize the ObjectDetectionStream. + + Args: + camera_intrinsics: List [fx, fy, cx, cy] with camera parameters + device: Device to run inference on ("cuda" or "cpu") + gt_depth_scale: Ground truth depth scale for Metric3D + min_confidence: Minimum confidence for detections + class_filter: Optional list of class names to filter + get_pose: Optional function to transform pose to map coordinates + detector: Optional detector instance (Detic or Yolo) + video_stream: Observable of video frames to process (if provided, returns a stream immediately) + disable_depth: Flag to disable monocular Metric3D depth estimation + draw_masks: Flag to enable drawing segmentation masks + """ + self.min_confidence = min_confidence + self.class_filter = class_filter + self.get_pose = get_pose + self.disable_depth = disable_depth + self.draw_masks = draw_masks + # Initialize object detector + if detector is not None: + self.detector = detector + else: + if DETIC_AVAILABLE: + try: + self.detector = Detic2DDetector(vocabulary=None, threshold=min_confidence) + logger.info("Using Detic2DDetector") + except Exception as e: + logger.warning( + f"Failed to initialize Detic2DDetector: {e}. Falling back to Yolo2DDetector." + ) + self.detector = Yolo2DDetector() + else: + logger.info("Detic not available. Using Yolo2DDetector.") + self.detector = Yolo2DDetector() + # Set up camera intrinsics + self.camera_intrinsics = camera_intrinsics + + # Initialize depth estimation model + self.depth_model = None + if not disable_depth: + try: + self.depth_model = Metric3D(gt_depth_scale) + + if camera_intrinsics is not None: + self.depth_model.update_intrinsic(camera_intrinsics) # type: ignore[no-untyped-call] + + # Create 3x3 camera matrix for calculations + fx, fy, cx, cy = camera_intrinsics + self.camera_matrix = np.array( + [[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32 + ) + else: + raise ValueError("camera_intrinsics must be provided") + + logger.info("Depth estimation enabled with Metric3D") + except Exception as e: + logger.warning(f"Failed to initialize Metric3D depth model: {e}") + logger.warning("Falling back to disable_depth=True mode") + self.disable_depth = True + self.depth_model = None + else: + logger.info("Depth estimation disabled") + + # If video_stream is provided, create and store the stream immediately + self.stream = None + if video_stream is not None: + self.stream = self.create_stream(video_stream) + + def create_stream(self, video_stream: Observable) -> Observable: # type: ignore[type-arg] + """ + Create an Observable stream of object data from a video stream. + + Args: + video_stream: Observable that emits video frames + + Returns: + Observable that emits dictionaries containing object data + with position and rotation information + """ + + def process_frame(frame): # type: ignore[no-untyped-def] + # TODO: More modular detector output interface + bboxes, track_ids, class_ids, confidences, names, *mask_data = ( # type: ignore[misc] + *self.detector.process_image(frame), + [], + ) + + masks = ( + mask_data[0] # type: ignore[has-type] + if mask_data and len(mask_data[0]) == len(bboxes) # type: ignore[has-type] + else [None] * len(bboxes) # type: ignore[has-type] + ) + + # Create visualization + viz_frame = frame.copy() + + # Process detections + objects = [] + if not self.disable_depth: + depth_map = self.depth_model.infer_depth(frame) # type: ignore[union-attr] + depth_map = np.array(depth_map) + else: + depth_map = None + + for i, bbox in enumerate(bboxes): # type: ignore[has-type] + # Skip if confidence is too low + if i < len(confidences) and confidences[i] < self.min_confidence: # type: ignore[has-type] + continue + + # Skip if class filter is active and class not in filter + class_name = names[i] if i < len(names) else None # type: ignore[has-type] + if self.class_filter and class_name not in self.class_filter: + continue + + if not self.disable_depth and depth_map is not None: + # Get depth for this object + depth = calculate_depth_from_bbox(depth_map, bbox) # type: ignore[no-untyped-call] + if depth is None: + # Skip objects with invalid depth + continue + # Calculate object position and rotation + position, rotation = calculate_position_rotation_from_bbox( + bbox, depth, self.camera_intrinsics + ) + # Get object dimensions + width, height = calculate_object_size_from_bbox( + bbox, depth, self.camera_intrinsics + ) + + # Transform to map frame if a transform function is provided + try: + if self.get_pose: + # position and rotation are already Vector objects, no need to convert + robot_pose = self.get_pose() + position, rotation = transform_robot_to_map( + robot_pose["position"], robot_pose["rotation"], position, rotation + ) + except Exception as e: + logger.error(f"Error transforming to map frame: {e}") + position, rotation = position, rotation + + else: + depth = -1 + position = Vector(0, 0, 0) # type: ignore[arg-type] + rotation = Vector(0, 0, 0) # type: ignore[arg-type] + width = -1 + height = -1 + + # Create a properly typed ObjectData instance + object_data: ObjectData = { + "object_id": track_ids[i] if i < len(track_ids) else -1, # type: ignore[has-type] + "bbox": bbox, + "depth": depth, + "confidence": confidences[i] if i < len(confidences) else None, # type: ignore[has-type, typeddict-item] + "class_id": class_ids[i] if i < len(class_ids) else None, # type: ignore[has-type, typeddict-item] + "label": class_name, # type: ignore[typeddict-item] + "position": position, + "rotation": rotation, + "size": {"width": width, "height": height}, + "segmentation_mask": masks[i], + } + + objects.append(object_data) + + # Create visualization using common function + viz_frame = draw_object_detection_visualization( + viz_frame, objects, draw_masks=self.draw_masks, font_scale=1.5 + ) + + return {"frame": frame, "viz_frame": viz_frame, "objects": objects} + + self.stream = video_stream.pipe(ops.map(process_frame)) + + return self.stream + + def get_stream(self): # type: ignore[no-untyped-def] + """ + Returns the current detection stream if available. + Creates a new one with the provided video_stream if not already created. + + Returns: + Observable: The reactive stream of detection results + """ + if self.stream is None: + raise ValueError( + "Stream not initialized. Either provide a video_stream during initialization or call create_stream first." + ) + return self.stream + + def get_formatted_stream(self): # type: ignore[no-untyped-def] + """ + Returns a formatted stream of object detection data for better readability. + This is especially useful for LLMs like Claude that need structured text input. + + Returns: + Observable: A stream of formatted string representations of object data + """ + if self.stream is None: + raise ValueError( + "Stream not initialized. Either provide a video_stream during initialization or call create_stream first." + ) + + def format_detection_data(result): # type: ignore[no-untyped-def] + # Extract objects from result + objects = result.get("objects", []) + + if not objects: + return "No objects detected." + + formatted_data = "[DETECTED OBJECTS]\n" + try: + for i, obj in enumerate(objects): + pos = obj["position"] + rot = obj["rotation"] + size = obj["size"] + bbox = obj["bbox"] + + # Format each object with a multiline f-string for better readability + bbox_str = f"[{bbox[0]}, {bbox[1]}, {bbox[2]}, {bbox[3]}]" + formatted_data += ( + f"Object {i + 1}: {obj['label']}\n" + f" ID: {obj['object_id']}\n" + f" Confidence: {obj['confidence']:.2f}\n" + f" Position: x={pos.x:.2f}m, y={pos.y:.2f}m, z={pos.z:.2f}m\n" + f" Rotation: yaw={rot.z:.2f} rad\n" + f" Size: width={size['width']:.2f}m, height={size['height']:.2f}m\n" + f" Depth: {obj['depth']:.2f}m\n" + f" Bounding box: {bbox_str}\n" + "----------------------------------\n" + ) + except Exception as e: + logger.warning(f"Error formatting object {i}: {e}") + formatted_data += f"Object {i + 1}: [Error formatting data]" + formatted_data += "\n----------------------------------\n" + + return formatted_data + + # Return a new stream with the formatter applied + return self.stream.pipe(ops.map(format_detection_data)) + + def cleanup(self) -> None: + """Clean up resources.""" + pass diff --git a/dimos/perception/object_tracker.py b/dimos/perception/object_tracker.py new file mode 100644 index 0000000000..3c277fec2f --- /dev/null +++ b/dimos/perception/object_tracker.py @@ -0,0 +1,629 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import threading +import time + +import cv2 +from dimos_lcm.sensor_msgs import CameraInfo # type: ignore[import-untyped] + +# Import LCM messages +from dimos_lcm.vision_msgs import ( # type: ignore[import-untyped] + Detection2D, + Detection3D, + ObjectHypothesisWithPose, +) +import numpy as np +from reactivex.disposable import Disposable + +from dimos.core import In, Module, Out, rpc +from dimos.manipulation.visual_servoing.utils import visualize_detections_3d +from dimos.msgs.geometry_msgs import Pose, Quaternion, Transform, Vector3 +from dimos.msgs.sensor_msgs import Image, ImageFormat +from dimos.msgs.std_msgs import Header +from dimos.msgs.vision_msgs import Detection2DArray, Detection3DArray +from dimos.protocol.tf import TF +from dimos.types.timestamped import align_timestamped +from dimos.utils.logging_config import setup_logger +from dimos.utils.transform_utils import ( + euler_to_quaternion, + optical_to_robot_frame, + yaw_towards_point, +) + +logger = setup_logger() + + +class ObjectTracking(Module): + """Module for object tracking with LCM input/output.""" + + # LCM inputs + color_image: In[Image] = None # type: ignore[assignment] + depth: In[Image] = None # type: ignore[assignment] + camera_info: In[CameraInfo] = None # type: ignore[assignment] + + # LCM outputs + detection2darray: Out[Detection2DArray] = None # type: ignore[assignment] + detection3darray: Out[Detection3DArray] = None # type: ignore[assignment] + tracked_overlay: Out[Image] = None # type: ignore[assignment] # Visualization output + + def __init__( + self, + reid_threshold: int = 10, + reid_fail_tolerance: int = 5, + frame_id: str = "camera_link", + ) -> None: + """ + Initialize an object tracking module using OpenCV's CSRT tracker with ORB re-ID. + + Args: + camera_intrinsics: Optional [fx, fy, cx, cy] camera parameters. + If None, will use camera_info input. + reid_threshold: Minimum good feature matches needed to confirm re-ID. + reid_fail_tolerance: Number of consecutive frames Re-ID can fail before + tracking is stopped. + frame_id: TF frame ID for the camera (default: "camera_link") + """ + # Call parent Module init + super().__init__() + + self.camera_intrinsics = None + self.reid_threshold = reid_threshold + self.reid_fail_tolerance = reid_fail_tolerance + self.frame_id = frame_id + + self.tracker = None + self.tracking_bbox = None # Stores (x, y, w, h) for tracker initialization + self.tracking_initialized = False + self.orb = cv2.ORB_create() # type: ignore[attr-defined] + self.bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=False) + self.original_des = None # Store original ORB descriptors + self.original_kps = None # Store original ORB keypoints + self.reid_fail_count = 0 # Counter for consecutive re-id failures + self.last_good_matches = [] # type: ignore[var-annotated] # Store good matches for visualization + self.last_roi_kps = None # Store last ROI keypoints for visualization + self.last_roi_bbox = None # Store last ROI bbox for visualization + self.reid_confirmed = False # Store current reid confirmation state + self.tracking_frame_count = 0 # Count frames since tracking started + self.reid_warmup_frames = 3 # Number of frames before REID starts + + self._frame_lock = threading.Lock() + self._latest_rgb_frame: np.ndarray | None = None # type: ignore[type-arg] + self._latest_depth_frame: np.ndarray | None = None # type: ignore[type-arg] + self._latest_camera_info: CameraInfo | None = None + + # Tracking thread control + self.tracking_thread: threading.Thread | None = None + self.stop_tracking = threading.Event() + self.tracking_rate = 30.0 # Hz + self.tracking_period = 1.0 / self.tracking_rate + + # Initialize TF publisher + self.tf = TF() + + # Store latest detections for RPC access + self._latest_detection2d: Detection2DArray | None = None + self._latest_detection3d: Detection3DArray | None = None + self._detection_event = threading.Event() + + @rpc + def start(self) -> None: + super().start() + + # Subscribe to aligned rgb and depth streams + def on_aligned_frames(frames_tuple) -> None: # type: ignore[no-untyped-def] + rgb_msg, depth_msg = frames_tuple + with self._frame_lock: + self._latest_rgb_frame = rgb_msg.data + + depth_data = depth_msg.data + # Convert from millimeters to meters if depth is DEPTH16 format + if depth_msg.format == ImageFormat.DEPTH16: + depth_data = depth_data.astype(np.float32) / 1000.0 + self._latest_depth_frame = depth_data + + # Create aligned observable for RGB and depth + aligned_frames = align_timestamped( + self.color_image.observable(), # type: ignore[no-untyped-call] + self.depth.observable(), # type: ignore[no-untyped-call] + buffer_size=2.0, # 2 second buffer + match_tolerance=0.5, # 500ms tolerance + ) + unsub = aligned_frames.subscribe(on_aligned_frames) + self._disposables.add(unsub) + + # Subscribe to camera info stream separately (doesn't need alignment) + def on_camera_info(camera_info_msg: CameraInfo) -> None: + self._latest_camera_info = camera_info_msg + # Extract intrinsics from camera info K matrix + # K is a 3x3 matrix in row-major order: [fx, 0, cx, 0, fy, cy, 0, 0, 1] + self.camera_intrinsics = [ # type: ignore[assignment] + camera_info_msg.K[0], + camera_info_msg.K[4], + camera_info_msg.K[2], + camera_info_msg.K[5], + ] + + unsub = self.camera_info.subscribe(on_camera_info) # type: ignore[assignment] + self._disposables.add(Disposable(unsub)) # type: ignore[arg-type] + + @rpc + def stop(self) -> None: + self.stop_track() + + self.stop_tracking.set() + + if self.tracking_thread and self.tracking_thread.is_alive(): + self.tracking_thread.join(timeout=2.0) + + super().stop() + + @rpc + def track( + self, + bbox: list[float], + ) -> dict: # type: ignore[type-arg] + """ + Initialize tracking with a bounding box and process current frame. + + Args: + bbox: Bounding box in format [x1, y1, x2, y2] + + Returns: + Dict containing tracking results with 2D and 3D detections + """ + if self._latest_rgb_frame is None: + logger.warning("No RGB frame available for tracking") + + # Initialize tracking + x1, y1, x2, y2 = map(int, bbox) + w, h = x2 - x1, y2 - y1 + if w <= 0 or h <= 0: + logger.warning(f"Invalid initial bbox provided: {bbox}. Tracking not started.") + + # Set tracking parameters + self.tracking_bbox = (x1, y1, w, h) # type: ignore[assignment] # Store in (x, y, w, h) format + self.tracker = cv2.legacy.TrackerCSRT_create() # type: ignore[attr-defined] + self.tracking_initialized = False + self.original_des = None + self.reid_fail_count = 0 + logger.info(f"Tracking target set with bbox: {self.tracking_bbox}") + + # Extract initial features + roi = self._latest_rgb_frame[y1:y2, x1:x2] # type: ignore[index] + if roi.size > 0: + self.original_kps, self.original_des = self.orb.detectAndCompute(roi, None) + if self.original_des is None: + logger.warning("No ORB features found in initial ROI. REID will be disabled.") + else: + logger.info(f"Initial ORB features extracted: {len(self.original_des)}") + + # Initialize the tracker + init_success = self.tracker.init(self._latest_rgb_frame, self.tracking_bbox) # type: ignore[attr-defined] + if init_success: + self.tracking_initialized = True + self.tracking_frame_count = 0 # Reset frame counter + logger.info("Tracker initialized successfully.") + else: + logger.error("Tracker initialization failed.") + self.stop_track() + else: + logger.error("Empty ROI during tracker initialization.") + self.stop_track() + + # Start tracking thread + self._start_tracking_thread() + + # Return initial tracking result + return {"status": "tracking_started", "bbox": self.tracking_bbox} + + def reid(self, frame, current_bbox) -> bool: # type: ignore[no-untyped-def] + """Check if features in current_bbox match stored original features.""" + # During warm-up period, always return True + if self.tracking_frame_count < self.reid_warmup_frames: + return True + + if self.original_des is None: + return False + x1, y1, x2, y2 = map(int, current_bbox) + roi = frame[y1:y2, x1:x2] + if roi.size == 0: + return False # Empty ROI cannot match + + kps_current, des_current = self.orb.detectAndCompute(roi, None) + if des_current is None or len(des_current) < 2: + return False # Need at least 2 descriptors for knnMatch + + # Store ROI keypoints and bbox for visualization + self.last_roi_kps = kps_current + self.last_roi_bbox = [x1, y1, x2, y2] + + # Handle case where original_des has only 1 descriptor (cannot use knnMatch with k=2) + if len(self.original_des) < 2: + matches = self.bf.match(self.original_des, des_current) + self.last_good_matches = matches # Store all matches for visualization + good_matches = len(matches) + else: + matches = self.bf.knnMatch(self.original_des, des_current, k=2) + # Apply Lowe's ratio test robustly + good_matches_list = [] + good_matches = 0 + for match_pair in matches: + if len(match_pair) == 2: + m, n = match_pair + if m.distance < 0.75 * n.distance: + good_matches_list.append(m) + good_matches += 1 + self.last_good_matches = good_matches_list # Store good matches for visualization + + return good_matches >= self.reid_threshold + + def _start_tracking_thread(self) -> None: + """Start the tracking thread.""" + self.stop_tracking.clear() + self.tracking_thread = threading.Thread(target=self._tracking_loop, daemon=True) + self.tracking_thread.start() + logger.info("Started tracking thread") + + def _tracking_loop(self) -> None: + """Main tracking loop that runs in a separate thread.""" + while not self.stop_tracking.is_set() and self.tracking_initialized: + # Process tracking for current frame + self._process_tracking() + + # Sleep to maintain tracking rate + time.sleep(self.tracking_period) + + logger.info("Tracking loop ended") + + def _reset_tracking_state(self) -> None: + """Reset tracking state without stopping the thread.""" + self.tracker = None + self.tracking_bbox = None + self.tracking_initialized = False + self.original_des = None + self.original_kps = None + self.reid_fail_count = 0 # Reset counter + self.last_good_matches = [] + self.last_roi_kps = None + self.last_roi_bbox = None + self.reid_confirmed = False # Reset reid confirmation state + self.tracking_frame_count = 0 # Reset frame counter + + # Publish empty detections to clear any visualizations + empty_2d = Detection2DArray(detections_length=0, header=Header(), detections=[]) + empty_3d = Detection3DArray(detections_length=0, header=Header(), detections=[]) + self._latest_detection2d = empty_2d + self._latest_detection3d = empty_3d + self._detection_event.clear() + self.detection2darray.publish(empty_2d) + self.detection3darray.publish(empty_3d) + + @rpc + def stop_track(self) -> bool: + """ + Stop tracking the current object. + This resets the tracker and all tracking state. + + Returns: + bool: True if tracking was successfully stopped + """ + # Reset tracking state first + self._reset_tracking_state() + + # Stop tracking thread if running (only if called from outside the thread) + if self.tracking_thread and self.tracking_thread.is_alive(): + # Check if we're being called from within the tracking thread + if threading.current_thread() != self.tracking_thread: + self.stop_tracking.set() + self.tracking_thread.join(timeout=1.0) + self.tracking_thread = None + else: + # If called from within thread, just set the stop flag + self.stop_tracking.set() + + logger.info("Tracking stopped") + return True + + @rpc + def is_tracking(self) -> bool: + """ + Check if the tracker is currently tracking an object successfully. + + Returns: + bool: True if tracking is active and REID is confirmed, False otherwise + """ + return self.tracking_initialized and self.reid_confirmed + + def _process_tracking(self) -> None: + """Process current frame for tracking and publish detections.""" + if self.tracker is None or not self.tracking_initialized: + return + + # Get local copies of frames under lock + with self._frame_lock: + if self._latest_rgb_frame is None or self._latest_depth_frame is None: + return + frame = self._latest_rgb_frame.copy() + depth_frame = self._latest_depth_frame.copy() + tracker_succeeded = False + reid_confirmed_this_frame = False + final_success = False + current_bbox_x1y1x2y2 = None + + # Perform tracker update + tracker_succeeded, bbox_cv = self.tracker.update(frame) + if tracker_succeeded: + x, y, w, h = map(int, bbox_cv) + current_bbox_x1y1x2y2 = [x, y, x + w, y + h] + # Perform re-ID check + reid_confirmed_this_frame = self.reid(frame, current_bbox_x1y1x2y2) + self.reid_confirmed = reid_confirmed_this_frame # Store for is_tracking() RPC + + if reid_confirmed_this_frame: + self.reid_fail_count = 0 + else: + self.reid_fail_count += 1 + else: + self.reid_confirmed = False # No tracking if tracker failed + + # Determine final success + if tracker_succeeded: + if self.reid_fail_count >= self.reid_fail_tolerance: + logger.warning( + f"Re-ID failed consecutively {self.reid_fail_count} times. Target lost." + ) + final_success = False + self._reset_tracking_state() + else: + final_success = True + else: + final_success = False + if self.tracking_initialized: + logger.info("Tracker update failed. Stopping track.") + self._reset_tracking_state() + + self.tracking_frame_count += 1 + + if not reid_confirmed_this_frame and self.tracking_frame_count >= self.reid_warmup_frames: + return + + # Create detections if tracking succeeded + header = Header(self.frame_id) + detection2darray = Detection2DArray(detections_length=0, header=header, detections=[]) + detection3darray = Detection3DArray(detections_length=0, header=header, detections=[]) + + if final_success and current_bbox_x1y1x2y2 is not None: + x1, y1, x2, y2 = current_bbox_x1y1x2y2 + center_x = (x1 + x2) / 2.0 + center_y = (y1 + y2) / 2.0 + width = float(x2 - x1) + height = float(y2 - y1) + + # Create Detection2D + detection_2d = Detection2D() + detection_2d.id = "0" + detection_2d.results_length = 1 + detection_2d.header = header + + # Create hypothesis + hypothesis = ObjectHypothesisWithPose() + hypothesis.hypothesis.class_id = "tracked_object" + hypothesis.hypothesis.score = 1.0 + detection_2d.results = [hypothesis] + + # Create bounding box + detection_2d.bbox.center.position.x = center_x + detection_2d.bbox.center.position.y = center_y + detection_2d.bbox.center.theta = 0.0 + detection_2d.bbox.size_x = width + detection_2d.bbox.size_y = height + + detection2darray = Detection2DArray() + detection2darray.detections_length = 1 + detection2darray.header = header + detection2darray.detections = [detection_2d] + + # Create Detection3D if depth is available + if depth_frame is not None: + # Calculate 3D position using depth and camera intrinsics + depth_value = self._get_depth_from_bbox(current_bbox_x1y1x2y2, depth_frame) + if ( + depth_value is not None + and depth_value > 0 + and self.camera_intrinsics is not None + ): + fx, fy, cx, cy = self.camera_intrinsics + + # Convert pixel coordinates to 3D in optical frame + z_optical = depth_value + x_optical = (center_x - cx) * z_optical / fx + y_optical = (center_y - cy) * z_optical / fy + + # Create pose in optical frame + optical_pose = Pose() + optical_pose.position = Vector3(x_optical, y_optical, z_optical) + optical_pose.orientation = Quaternion(0.0, 0.0, 0.0, 1.0) # Identity for now + + # Convert to robot frame + robot_pose = optical_to_robot_frame(optical_pose) + + # Calculate orientation: object facing towards camera (origin) + yaw = yaw_towards_point(robot_pose.position) + euler = Vector3(0.0, 0.0, yaw) # Only yaw, no roll/pitch + robot_pose.orientation = euler_to_quaternion(euler) + + # Estimate object size in meters + size_x = width * z_optical / fx + size_y = height * z_optical / fy + size_z = 0.1 # Default depth size + + # Create Detection3D + detection_3d = Detection3D() + detection_3d.id = "0" + detection_3d.results_length = 1 + detection_3d.header = header + + # Reuse hypothesis from 2D + detection_3d.results = [hypothesis] + + # Create 3D bounding box with robot frame pose + detection_3d.bbox.center = Pose() + detection_3d.bbox.center.position = robot_pose.position + detection_3d.bbox.center.orientation = robot_pose.orientation + detection_3d.bbox.size = Vector3(size_x, size_y, size_z) + + detection3darray = Detection3DArray() + detection3darray.detections_length = 1 + detection3darray.header = header + detection3darray.detections = [detection_3d] + + # Publish transform for tracked object + # The optical pose is in camera optical frame, so publish it relative to the camera frame + tracked_object_tf = Transform( + translation=robot_pose.position, + rotation=robot_pose.orientation, + frame_id=self.frame_id, # Use configured camera frame + child_frame_id="tracked_object", + ts=header.ts, + ) + self.tf.publish(tracked_object_tf) + + # Store latest detections for RPC access + self._latest_detection2d = detection2darray + self._latest_detection3d = detection3darray + + # Signal that new detections are available + if detection2darray.detections_length > 0 or detection3darray.detections_length > 0: + self._detection_event.set() + + # Publish detections + self.detection2darray.publish(detection2darray) + self.detection3darray.publish(detection3darray) + + # Create and publish visualization if tracking is active + if self.tracking_initialized: + # Convert single detection to list for visualization + detections_3d = ( + detection3darray.detections if detection3darray.detections_length > 0 else [] + ) + detections_2d = ( + detection2darray.detections if detection2darray.detections_length > 0 else [] + ) + + if detections_3d and detections_2d: + # Extract 2D bbox for visualization + det_2d = detections_2d[0] + bbox_2d = [] + if det_2d.bbox: + x1 = det_2d.bbox.center.position.x - det_2d.bbox.size_x / 2 + y1 = det_2d.bbox.center.position.y - det_2d.bbox.size_y / 2 + x2 = det_2d.bbox.center.position.x + det_2d.bbox.size_x / 2 + y2 = det_2d.bbox.center.position.y + det_2d.bbox.size_y / 2 + bbox_2d = [[x1, y1, x2, y2]] + + # Create visualization + viz_image = visualize_detections_3d( + frame, detections_3d, show_coordinates=True, bboxes_2d=bbox_2d + ) + + # Overlay REID feature matches if available + if self.last_good_matches and self.last_roi_kps and self.last_roi_bbox: + viz_image = self._draw_reid_matches(viz_image) + + # Convert to Image message and publish + viz_msg = Image.from_numpy(viz_image) + self.tracked_overlay.publish(viz_msg) + + def _draw_reid_matches(self, image: np.ndarray) -> np.ndarray: # type: ignore[type-arg] + """Draw REID feature matches on the image.""" + viz_image = image.copy() + + x1, y1, _x2, _y2 = self.last_roi_bbox # type: ignore[misc] + + # Draw keypoints from current ROI in green + for kp in self.last_roi_kps: # type: ignore[attr-defined] + pt = (int(kp.pt[0] + x1), int(kp.pt[1] + y1)) # type: ignore[has-type] + cv2.circle(viz_image, pt, 3, (0, 255, 0), -1) + + for match in self.last_good_matches: + current_kp = self.last_roi_kps[match.trainIdx] # type: ignore[index] + pt_current = (int(current_kp.pt[0] + x1), int(current_kp.pt[1] + y1)) # type: ignore[has-type] + + # Draw a larger circle for matched points in yellow + cv2.circle(viz_image, pt_current, 5, (0, 255, 255), 2) # Yellow for matched points + + # Draw match strength indicator (smaller circle with intensity based on distance) + # Lower distance = better match = brighter color + intensity = int(255 * (1.0 - min(match.distance / 100.0, 1.0))) + cv2.circle(viz_image, pt_current, 2, (intensity, intensity, 255), -1) + + text = f"REID Matches: {len(self.last_good_matches)}/{len(self.last_roi_kps) if self.last_roi_kps else 0}" + cv2.putText(viz_image, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2) + + if self.tracking_frame_count < self.reid_warmup_frames: + status_text = ( + f"REID: WARMING UP ({self.tracking_frame_count}/{self.reid_warmup_frames})" + ) + status_color = (255, 255, 0) # Yellow + elif len(self.last_good_matches) >= self.reid_threshold: + status_text = "REID: CONFIRMED" + status_color = (0, 255, 0) # Green + else: + status_text = f"REID: WEAK ({self.reid_fail_count}/{self.reid_fail_tolerance})" + status_color = (0, 165, 255) # Orange + + cv2.putText( + viz_image, status_text, (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.7, status_color, 2 + ) + + return viz_image + + def _get_depth_from_bbox(self, bbox: list[int], depth_frame: np.ndarray) -> float | None: # type: ignore[type-arg] + """Calculate depth from bbox using the 25th percentile of closest points. + + Args: + bbox: Bounding box coordinates [x1, y1, x2, y2] + depth_frame: Depth frame to extract depth values from + + Returns: + Depth value or None if not available + """ + if depth_frame is None: + return None + + x1, y1, x2, y2 = bbox + + # Ensure bbox is within frame bounds + y1 = max(0, y1) + y2 = min(depth_frame.shape[0], y2) + x1 = max(0, x1) + x2 = min(depth_frame.shape[1], x2) + + # Extract depth values from the entire bbox + roi_depth = depth_frame[y1:y2, x1:x2] + + # Get valid (finite and positive) depth values + valid_depths = roi_depth[np.isfinite(roi_depth) & (roi_depth > 0)] + + if len(valid_depths) > 0: + depth_25th_percentile = float(np.percentile(valid_depths, 25)) + return depth_25th_percentile + + return None + + +object_tracking = ObjectTracking.blueprint + +__all__ = ["ObjectTracking", "object_tracking"] diff --git a/dimos/perception/object_tracker_2d.py b/dimos/perception/object_tracker_2d.py new file mode 100644 index 0000000000..feea0d3e42 --- /dev/null +++ b/dimos/perception/object_tracker_2d.py @@ -0,0 +1,299 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import threading +import time + +import cv2 + +# Import LCM messages +from dimos_lcm.vision_msgs import ( # type: ignore[import-untyped] + BoundingBox2D, + Detection2D, + ObjectHypothesis, + ObjectHypothesisWithPose, + Point2D, + Pose2D, +) +import numpy as np +from reactivex.disposable import Disposable + +from dimos.core import In, Module, Out, rpc +from dimos.msgs.sensor_msgs import Image, ImageFormat +from dimos.msgs.std_msgs import Header +from dimos.msgs.vision_msgs import Detection2DArray +from dimos.utils.logging_config import setup_logger + +logger = setup_logger(level=logging.INFO) + + +class ObjectTracker2D(Module): + """Pure 2D object tracking module using OpenCV's CSRT tracker.""" + + color_image: In[Image] = None # type: ignore[assignment] + + detection2darray: Out[Detection2DArray] = None # type: ignore[assignment] + tracked_overlay: Out[Image] = None # type: ignore[assignment] # Visualization output + + def __init__( + self, + frame_id: str = "camera_link", + ) -> None: + """ + Initialize 2D object tracking module using OpenCV's CSRT tracker. + + Args: + frame_id: TF frame ID for the camera (default: "camera_link") + """ + super().__init__() + + self.frame_id = frame_id + + # Tracker state + self.tracker = None + self.tracking_bbox = None # Stores (x, y, w, h) + self.tracking_initialized = False + + # Stuck detection + self._last_bbox = None + self._stuck_count = 0 + self._max_stuck_frames = 10 # Higher threshold for stationary objects + + # Frame management + self._frame_lock = threading.Lock() + self._latest_rgb_frame: np.ndarray | None = None # type: ignore[type-arg] + self._frame_arrival_time: float | None = None + + # Tracking thread control + self.tracking_thread: threading.Thread | None = None + self.stop_tracking_event = threading.Event() + self.tracking_rate = 5.0 # Hz + self.tracking_period = 1.0 / self.tracking_rate + + # Store latest detection for RPC access + self._latest_detection2d: Detection2DArray | None = None + + @rpc + def start(self) -> None: + super().start() + + def on_frame(frame_msg: Image) -> None: + arrival_time = time.perf_counter() + with self._frame_lock: + self._latest_rgb_frame = frame_msg.data + self._frame_arrival_time = arrival_time + + unsub = self.color_image.subscribe(on_frame) + self._disposables.add(Disposable(unsub)) + logger.info("ObjectTracker2D module started") + + @rpc + def stop(self) -> None: + self.stop_track() + if self.tracking_thread and self.tracking_thread.is_alive(): + self.stop_tracking_event.set() + self.tracking_thread.join(timeout=2.0) + + super().stop() + + @rpc + def track(self, bbox: list[float]) -> dict: # type: ignore[type-arg] + """ + Initialize tracking with a bounding box. + + Args: + bbox: Bounding box in format [x1, y1, x2, y2] + + Returns: + Dict containing tracking status + """ + if self._latest_rgb_frame is None: + logger.warning("No RGB frame available for tracking") + return {"status": "no_frame"} + + # Initialize tracking + x1, y1, x2, y2 = map(int, bbox) + w, h = x2 - x1, y2 - y1 + if w <= 0 or h <= 0: + logger.warning(f"Invalid initial bbox provided: {bbox}. Tracking not started.") + return {"status": "invalid_bbox"} + + self.tracking_bbox = (x1, y1, w, h) # type: ignore[assignment] + self.tracker = cv2.legacy.TrackerCSRT_create() # type: ignore[attr-defined] + self.tracking_initialized = False + logger.info(f"Tracking target set with bbox: {self.tracking_bbox}") + + # Convert RGB to BGR for CSRT (OpenCV expects BGR) + frame_bgr = cv2.cvtColor(self._latest_rgb_frame, cv2.COLOR_RGB2BGR) + init_success = self.tracker.init(frame_bgr, self.tracking_bbox) # type: ignore[attr-defined] + if init_success: + self.tracking_initialized = True + logger.info("Tracker initialized successfully.") + else: + logger.error("Tracker initialization failed.") + self.stop_track() + return {"status": "init_failed"} + + # Start tracking thread + self._start_tracking_thread() + + return {"status": "tracking_started", "bbox": self.tracking_bbox} + + def _start_tracking_thread(self) -> None: + """Start the tracking thread.""" + self.stop_tracking_event.clear() + self.tracking_thread = threading.Thread(target=self._tracking_loop, daemon=True) + self.tracking_thread.start() + logger.info("Started tracking thread") + + def _tracking_loop(self) -> None: + """Main tracking loop that runs in a separate thread.""" + while not self.stop_tracking_event.is_set() and self.tracking_initialized: + self._process_tracking() + time.sleep(self.tracking_period) + logger.info("Tracking loop ended") + + def _reset_tracking_state(self) -> None: + """Reset tracking state without stopping the thread.""" + self.tracker = None + self.tracking_bbox = None + self.tracking_initialized = False + self._last_bbox = None + self._stuck_count = 0 + + # Publish empty detection + empty_2d = Detection2DArray( + detections_length=0, header=Header(time.time(), self.frame_id), detections=[] + ) + self._latest_detection2d = empty_2d + self.detection2darray.publish(empty_2d) + + @rpc + def stop_track(self) -> bool: + """ + Stop tracking the current object. + + Returns: + bool: True if tracking was successfully stopped + """ + self._reset_tracking_state() + + # Stop tracking thread if running + if self.tracking_thread and self.tracking_thread.is_alive(): + if threading.current_thread() != self.tracking_thread: + self.stop_tracking_event.set() + self.tracking_thread.join(timeout=1.0) + self.tracking_thread = None + else: + self.stop_tracking_event.set() + + logger.info("Tracking stopped") + return True + + @rpc + def is_tracking(self) -> bool: + """ + Check if the tracker is currently tracking an object. + + Returns: + bool: True if tracking is active + """ + return self.tracking_initialized + + def _process_tracking(self) -> None: + """Process current frame for tracking and publish 2D detections.""" + if self.tracker is None or not self.tracking_initialized: + return + + # Get frame copy + with self._frame_lock: + if self._latest_rgb_frame is None: + return + frame = self._latest_rgb_frame.copy() + + # Convert RGB to BGR for CSRT (OpenCV expects BGR) + frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) + + tracker_succeeded, bbox_cv = self.tracker.update(frame_bgr) + + if not tracker_succeeded: + logger.info("Tracker update failed. Stopping track.") + self._reset_tracking_state() + return + + # Extract bbox + x, y, w, h = map(int, bbox_cv) + current_bbox_x1y1x2y2 = [x, y, x + w, y + h] + x1, y1, x2, y2 = current_bbox_x1y1x2y2 + + # Check if tracker is stuck + if self._last_bbox is not None: + if (x1, y1, x2, y2) == self._last_bbox: + self._stuck_count += 1 + if self._stuck_count >= self._max_stuck_frames: + logger.warning(f"Tracker stuck for {self._stuck_count} frames. Stopping track.") + self._reset_tracking_state() + return + else: + self._stuck_count = 0 + + self._last_bbox = (x1, y1, x2, y2) + + center_x = (x1 + x2) / 2.0 + center_y = (y1 + y2) / 2.0 + width = float(x2 - x1) + height = float(y2 - y1) + + # Create 2D detection header + header = Header(time.time(), self.frame_id) + + # Create Detection2D with all fields in constructors + detection_2d = Detection2D( + id="0", + results_length=1, + header=header, + bbox=BoundingBox2D( + center=Pose2D(position=Point2D(x=center_x, y=center_y), theta=0.0), + size_x=width, + size_y=height, + ), + results=[ + ObjectHypothesisWithPose( + hypothesis=ObjectHypothesis(class_id="tracked_object", score=1.0) + ) + ], + ) + + detection2darray = Detection2DArray( + detections_length=1, header=header, detections=[detection_2d] + ) + + # Store and publish + self._latest_detection2d = detection2darray + self.detection2darray.publish(detection2darray) + + # Create visualization + viz_image = self._draw_visualization(frame, current_bbox_x1y1x2y2) + viz_copy = viz_image.copy() # Force copy needed to prevent frame reuse + viz_msg = Image.from_numpy(viz_copy, format=ImageFormat.RGB) + self.tracked_overlay.publish(viz_msg) + + def _draw_visualization(self, image: np.ndarray, bbox: list[int]) -> np.ndarray: # type: ignore[type-arg] + """Draw tracking visualization.""" + viz_image = image.copy() + x1, y1, x2, y2 = bbox + cv2.rectangle(viz_image, (x1, y1), (x2, y2), (0, 255, 0), 2) + cv2.putText(viz_image, "TRACKING", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2) + return viz_image diff --git a/dimos/perception/object_tracker_3d.py b/dimos/perception/object_tracker_3d.py new file mode 100644 index 0000000000..5d03efab33 --- /dev/null +++ b/dimos/perception/object_tracker_3d.py @@ -0,0 +1,304 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Import LCM messages +from dimos_lcm.sensor_msgs import CameraInfo # type: ignore[import-untyped] +from dimos_lcm.vision_msgs import ( # type: ignore[import-untyped] + Detection3D, + ObjectHypothesisWithPose, +) +import numpy as np + +from dimos.core import In, Out, rpc +from dimos.manipulation.visual_servoing.utils import visualize_detections_3d +from dimos.msgs.geometry_msgs import Pose, Quaternion, Transform, Vector3 +from dimos.msgs.sensor_msgs import Image, ImageFormat +from dimos.msgs.std_msgs import Header +from dimos.msgs.vision_msgs import Detection2DArray, Detection3DArray +from dimos.perception.object_tracker_2d import ObjectTracker2D +from dimos.protocol.tf import TF +from dimos.types.timestamped import align_timestamped +from dimos.utils.logging_config import setup_logger +from dimos.utils.transform_utils import ( + euler_to_quaternion, + optical_to_robot_frame, + yaw_towards_point, +) + +logger = setup_logger() + + +class ObjectTracker3D(ObjectTracker2D): + """3D object tracking module extending ObjectTracker2D with depth capabilities.""" + + # Additional inputs (2D tracker already has color_image) + depth: In[Image] = None # type: ignore[assignment] + camera_info: In[CameraInfo] = None # type: ignore[assignment] + + # Additional outputs (2D tracker already has detection2darray and tracked_overlay) + detection3darray: Out[Detection3DArray] = None # type: ignore[assignment] + + def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] + """ + Initialize 3D object tracking module. + + Args: + **kwargs: Arguments passed to parent ObjectTracker2D + """ + super().__init__(**kwargs) + + # Additional state for 3D tracking + self.camera_intrinsics = None + self._latest_depth_frame: np.ndarray | None = None # type: ignore[type-arg] + self._latest_camera_info: CameraInfo | None = None + + # TF publisher for tracked object + self.tf = TF() + + # Store latest 3D detection + self._latest_detection3d: Detection3DArray | None = None + + @rpc + def start(self) -> None: + super().start() + + # Subscribe to aligned RGB and depth streams + def on_aligned_frames(frames_tuple) -> None: # type: ignore[no-untyped-def] + rgb_msg, depth_msg = frames_tuple + with self._frame_lock: + self._latest_rgb_frame = rgb_msg.data + + depth_data = depth_msg.data + # Convert from millimeters to meters if depth is DEPTH16 format + if depth_msg.format == ImageFormat.DEPTH16: + depth_data = depth_data.astype(np.float32) / 1000.0 + self._latest_depth_frame = depth_data + + # Create aligned observable for RGB and depth + aligned_frames = align_timestamped( + self.color_image.observable(), # type: ignore[no-untyped-call] + self.depth.observable(), # type: ignore[no-untyped-call] + buffer_size=2.0, # 2 second buffer + match_tolerance=0.5, # 500ms tolerance + ) + unsub = aligned_frames.subscribe(on_aligned_frames) + self._disposables.add(unsub) + + # Subscribe to camera info + def on_camera_info(camera_info_msg: CameraInfo) -> None: + self._latest_camera_info = camera_info_msg + # Extract intrinsics: K is [fx, 0, cx, 0, fy, cy, 0, 0, 1] + self.camera_intrinsics = [ # type: ignore[assignment] + camera_info_msg.K[0], + camera_info_msg.K[4], + camera_info_msg.K[2], + camera_info_msg.K[5], + ] + + self.camera_info.subscribe(on_camera_info) + + logger.info("ObjectTracker3D module started with aligned frame subscription") + + @rpc + def stop(self) -> None: + super().stop() + + def _process_tracking(self) -> None: + """Override to add 3D detection creation after 2D tracking.""" + # Call parent 2D tracking + super()._process_tracking() + + # Enhance with 3D if we have depth and a valid 2D detection + if ( + self._latest_detection2d + and self._latest_detection2d.detections_length > 0 + and self._latest_depth_frame is not None + and self.camera_intrinsics is not None + ): + detection_3d = self._create_detection3d_from_2d(self._latest_detection2d) + if detection_3d: + self._latest_detection3d = detection_3d + self.detection3darray.publish(detection_3d) + + # Update visualization with 3D info + with self._frame_lock: + if self._latest_rgb_frame is not None: + frame = self._latest_rgb_frame.copy() + + # Extract 2D bbox for visualization + det_2d = self._latest_detection2d.detections[0] + x1 = det_2d.bbox.center.position.x - det_2d.bbox.size_x / 2 + y1 = det_2d.bbox.center.position.y - det_2d.bbox.size_y / 2 + x2 = det_2d.bbox.center.position.x + det_2d.bbox.size_x / 2 + y2 = det_2d.bbox.center.position.y + det_2d.bbox.size_y / 2 + bbox_2d = [[x1, y1, x2, y2]] + + # Create 3D visualization + viz_image = visualize_detections_3d( + frame, detection_3d.detections, show_coordinates=True, bboxes_2d=bbox_2d + ) + + # Overlay Re-ID matches + if self.last_good_matches and self.last_roi_kps and self.last_roi_bbox: + viz_image = self._draw_reid_overlay(viz_image) + + viz_msg = Image.from_numpy(viz_image) + self.tracked_overlay.publish(viz_msg) + + def _create_detection3d_from_2d(self, detection2d: Detection2DArray) -> Detection3DArray | None: + """Create 3D detection from 2D detection using depth.""" + if detection2d.detections_length == 0: + return None + + det_2d = detection2d.detections[0] + + # Get bbox center + center_x = det_2d.bbox.center.position.x + center_y = det_2d.bbox.center.position.y + width = det_2d.bbox.size_x + height = det_2d.bbox.size_y + + # Convert to bbox coordinates + x1 = int(center_x - width / 2) + y1 = int(center_y - height / 2) + x2 = int(center_x + width / 2) + y2 = int(center_y + height / 2) + + # Get depth value + depth_value = self._get_depth_from_bbox([x1, y1, x2, y2], self._latest_depth_frame) # type: ignore[arg-type] + + if depth_value is None or depth_value <= 0: + return None + + fx, fy, cx, cy = self.camera_intrinsics # type: ignore[misc] + + # Convert pixel coordinates to 3D in optical frame + z_optical = depth_value + x_optical = (center_x - cx) * z_optical / fx # type: ignore[has-type] + y_optical = (center_y - cy) * z_optical / fy # type: ignore[has-type] + + # Create pose in optical frame + optical_pose = Pose() + optical_pose.position = Vector3(x_optical, y_optical, z_optical) + optical_pose.orientation = Quaternion(0.0, 0.0, 0.0, 1.0) + + # Convert to robot frame + robot_pose = optical_to_robot_frame(optical_pose) + + # Calculate orientation: object facing towards camera + yaw = yaw_towards_point(robot_pose.position) + euler = Vector3(0.0, 0.0, yaw) + robot_pose.orientation = euler_to_quaternion(euler) + + # Estimate object size in meters + size_x = width * z_optical / fx # type: ignore[has-type] + size_y = height * z_optical / fy # type: ignore[has-type] + size_z = 0.1 # Default depth size + + # Create Detection3D + header = Header(self.frame_id) + detection_3d = Detection3D() + detection_3d.id = "0" + detection_3d.results_length = 1 + detection_3d.header = header + + # Create hypothesis + hypothesis = ObjectHypothesisWithPose() + hypothesis.hypothesis.class_id = "tracked_object" + hypothesis.hypothesis.score = 1.0 + detection_3d.results = [hypothesis] + + # Create 3D bounding box + detection_3d.bbox.center = Pose() + detection_3d.bbox.center.position = robot_pose.position + detection_3d.bbox.center.orientation = robot_pose.orientation + detection_3d.bbox.size = Vector3(size_x, size_y, size_z) + + detection3darray = Detection3DArray() + detection3darray.detections_length = 1 + detection3darray.header = header + detection3darray.detections = [detection_3d] + + # Publish TF for tracked object + tracked_object_tf = Transform( + translation=robot_pose.position, + rotation=robot_pose.orientation, + frame_id=self.frame_id, + child_frame_id="tracked_object", + ts=header.ts, + ) + self.tf.publish(tracked_object_tf) + + return detection3darray + + def _get_depth_from_bbox(self, bbox: list[int], depth_frame: np.ndarray) -> float | None: # type: ignore[type-arg] + """ + Calculate depth from bbox using the 25th percentile of closest points. + + Args: + bbox: Bounding box coordinates [x1, y1, x2, y2] + depth_frame: Depth frame to extract depth values from + + Returns: + Depth value or None if not available + """ + if depth_frame is None: + return None + + x1, y1, x2, y2 = bbox + + # Ensure bbox is within frame bounds + y1 = max(0, y1) + y2 = min(depth_frame.shape[0], y2) + x1 = max(0, x1) + x2 = min(depth_frame.shape[1], x2) + + # Extract depth values from the bbox + roi_depth = depth_frame[y1:y2, x1:x2] + + # Get valid (finite and positive) depth values + valid_depths = roi_depth[np.isfinite(roi_depth) & (roi_depth > 0)] + + if len(valid_depths) > 0: + return float(np.percentile(valid_depths, 25)) + + return None + + def _draw_reid_overlay(self, image: np.ndarray) -> np.ndarray: # type: ignore[type-arg] + """Draw Re-ID feature matches on visualization.""" + import cv2 + + viz_image = image.copy() + x1, y1, _x2, _y2 = self.last_roi_bbox # type: ignore[attr-defined] + + # Draw keypoints + for kp in self.last_roi_kps: # type: ignore[attr-defined] + pt = (int(kp.pt[0] + x1), int(kp.pt[1] + y1)) + cv2.circle(viz_image, pt, 3, (0, 255, 0), -1) + + # Draw matches + for match in self.last_good_matches: # type: ignore[attr-defined] + current_kp = self.last_roi_kps[match.trainIdx] # type: ignore[attr-defined] + pt_current = (int(current_kp.pt[0] + x1), int(current_kp.pt[1] + y1)) + cv2.circle(viz_image, pt_current, 5, (0, 255, 255), 2) + + intensity = int(255 * (1.0 - min(match.distance / 100.0, 1.0))) + cv2.circle(viz_image, pt_current, 2, (intensity, intensity, 255), -1) + + # Draw match count + text = f"REID: {len(self.last_good_matches)}/{len(self.last_roi_kps)}" # type: ignore[attr-defined] + cv2.putText(viz_image, text, (10, 90), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2) + + return viz_image diff --git a/dimos/perception/person_tracker.py b/dimos/perception/person_tracker.py new file mode 100644 index 0000000000..0b8160c1da --- /dev/null +++ b/dimos/perception/person_tracker.py @@ -0,0 +1,260 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import cv2 +import numpy as np +from reactivex import Observable, interval, operators as ops +from reactivex.disposable import Disposable + +from dimos.core import In, Module, Out, rpc +from dimos.msgs.sensor_msgs import Image +from dimos.perception.common.ibvs import PersonDistanceEstimator +from dimos.perception.detection2d.utils import filter_detections +from dimos.perception.detection2d.yolo_2d_det import Yolo2DDetector # type: ignore[import-untyped] +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class PersonTrackingStream(Module): + """Module for person tracking with LCM input/output.""" + + # LCM inputs + video: In[Image] = None # type: ignore[assignment] + + # LCM outputs + tracking_data: Out[dict] = None # type: ignore[assignment, type-arg] + + def __init__( # type: ignore[no-untyped-def] + self, + camera_intrinsics=None, + camera_pitch: float = 0.0, + camera_height: float = 1.0, + ) -> None: + """ + Initialize a person tracking stream using Yolo2DDetector and PersonDistanceEstimator. + + Args: + camera_intrinsics: List in format [fx, fy, cx, cy] where: + - fx: Focal length in x direction (pixels) + - fy: Focal length in y direction (pixels) + - cx: Principal point x-coordinate (pixels) + - cy: Principal point y-coordinate (pixels) + camera_pitch: Camera pitch angle in radians (positive is up) + camera_height: Height of the camera from the ground in meters + """ + # Call parent Module init + super().__init__() + + self.camera_intrinsics = camera_intrinsics + self.camera_pitch = camera_pitch + self.camera_height = camera_height + + self.detector = Yolo2DDetector() + + # Initialize distance estimator + if camera_intrinsics is None: + raise ValueError("Camera intrinsics are required for distance estimation") + + # Validate camera intrinsics format [fx, fy, cx, cy] + if ( + not isinstance(camera_intrinsics, list | tuple | np.ndarray) + or len(camera_intrinsics) != 4 + ): + raise ValueError("Camera intrinsics must be provided as [fx, fy, cx, cy]") + + # Convert [fx, fy, cx, cy] to 3x3 camera matrix + fx, fy, cx, cy = camera_intrinsics + K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32) + + self.distance_estimator = PersonDistanceEstimator( + K=K, camera_pitch=camera_pitch, camera_height=camera_height + ) + + # For tracking latest frame data + self._latest_frame: np.ndarray | None = None # type: ignore[type-arg] + self._process_interval = 0.1 # Process at 10Hz + + # Tracking state - starts disabled + self._tracking_enabled = False + + @rpc + def start(self) -> None: + """Start the person tracking module and subscribe to LCM streams.""" + + super().start() + + # Subscribe to video stream + def set_video(image_msg: Image) -> None: + if hasattr(image_msg, "data"): + self._latest_frame = image_msg.data + else: + logger.warning("Received image message without data attribute") + + unsub = self.video.subscribe(set_video) + self._disposables.add(Disposable(unsub)) + + # Start periodic processing + unsub = interval(self._process_interval).subscribe(lambda _: self._process_frame()) # type: ignore[assignment] + self._disposables.add(unsub) # type: ignore[arg-type] + + logger.info("PersonTracking module started and subscribed to LCM streams") + + @rpc + def stop(self) -> None: + super().stop() + + def _process_frame(self) -> None: + """Process the latest frame if available.""" + if self._latest_frame is None: + return + + # Only process and publish if tracking is enabled + if not self._tracking_enabled: + return + + # Process frame through tracking pipeline + result = self._process_tracking(self._latest_frame) # type: ignore[no-untyped-call] + + # Publish result to LCM + if result: + self.tracking_data.publish(result) + + def _process_tracking(self, frame): # type: ignore[no-untyped-def] + """Process a single frame for person tracking.""" + # Detect people in the frame + bboxes, track_ids, class_ids, confidences, names = self.detector.process_image(frame) + + # Filter to keep only person detections using filter_detections + ( + filtered_bboxes, + filtered_track_ids, + filtered_class_ids, + filtered_confidences, + filtered_names, + ) = filter_detections( + bboxes, + track_ids, + class_ids, + confidences, + names, + class_filter=[0], # 0 is the class_id for person + name_filter=["person"], + ) + + # Create visualization + viz_frame = self.detector.visualize_results( + frame, + filtered_bboxes, + filtered_track_ids, + filtered_class_ids, + filtered_confidences, + filtered_names, + ) + + # Calculate distance and angle for each person + targets = [] + for i, bbox in enumerate(filtered_bboxes): + target_data = { + "target_id": filtered_track_ids[i] if i < len(filtered_track_ids) else -1, + "bbox": bbox, + "confidence": filtered_confidences[i] if i < len(filtered_confidences) else None, + } + + distance, angle = self.distance_estimator.estimate_distance_angle(bbox) + target_data["distance"] = distance + target_data["angle"] = angle + + # Add text to visualization + _x1, y1, x2, _y2 = map(int, bbox) + dist_text = f"{distance:.2f}m, {np.rad2deg(angle):.1f} deg" + + # Add black background for better visibility + text_size = cv2.getTextSize(dist_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0] + # Position at top-right corner + cv2.rectangle( + viz_frame, (x2 - text_size[0], y1 - text_size[1] - 5), (x2, y1), (0, 0, 0), -1 + ) + + # Draw text in white at top-right + cv2.putText( + viz_frame, + dist_text, + (x2 - text_size[0], y1 - 5), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + (255, 255, 255), + 2, + ) + + targets.append(target_data) + + # Create the result dictionary + return {"frame": frame, "viz_frame": viz_frame, "targets": targets} + + @rpc + def enable_tracking(self) -> bool: + """Enable person tracking. + + Returns: + bool: True if tracking was enabled successfully + """ + self._tracking_enabled = True + logger.info("Person tracking enabled") + return True + + @rpc + def disable_tracking(self) -> bool: + """Disable person tracking. + + Returns: + bool: True if tracking was disabled successfully + """ + self._tracking_enabled = False + logger.info("Person tracking disabled") + return True + + @rpc + def is_tracking_enabled(self) -> bool: + """Check if tracking is currently enabled. + + Returns: + bool: True if tracking is enabled + """ + return self._tracking_enabled + + @rpc + def get_tracking_data(self) -> dict: # type: ignore[type-arg] + """Get the latest tracking data. + + Returns: + Dictionary containing tracking results + """ + if self._latest_frame is not None: + return self._process_tracking(self._latest_frame) # type: ignore[no-any-return, no-untyped-call] + return {"frame": None, "viz_frame": None, "targets": []} + + def create_stream(self, video_stream: Observable) -> Observable: # type: ignore[type-arg] + """ + Create an Observable stream of person tracking results from a video stream. + + Args: + video_stream: Observable that emits video frames + + Returns: + Observable that emits dictionaries containing tracking results and visualizations + """ + + return video_stream.pipe(ops.map(self._process_tracking)) diff --git a/dimos/perception/pointcloud/__init__.py b/dimos/perception/pointcloud/__init__.py new file mode 100644 index 0000000000..a380e2aadf --- /dev/null +++ b/dimos/perception/pointcloud/__init__.py @@ -0,0 +1,3 @@ +from .cuboid_fit import * +from .pointcloud_filtering import * +from .utils import * diff --git a/dimos/perception/pointcloud/cuboid_fit.py b/dimos/perception/pointcloud/cuboid_fit.py new file mode 100644 index 0000000000..dfec2d9297 --- /dev/null +++ b/dimos/perception/pointcloud/cuboid_fit.py @@ -0,0 +1,420 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import cv2 +import numpy as np +import open3d as o3d # type: ignore[import-untyped] + + +def fit_cuboid( + points: np.ndarray | o3d.geometry.PointCloud, # type: ignore[type-arg] + method: str = "minimal", +) -> dict | None: # type: ignore[type-arg] + """ + Fit a cuboid to a point cloud using Open3D's built-in methods. + + Args: + points: Nx3 array of points or Open3D PointCloud + method: Fitting method: + - 'minimal': Minimal oriented bounding box (best fit) + - 'oriented': PCA-based oriented bounding box + - 'axis_aligned': Axis-aligned bounding box + + Returns: + Dictionary containing: + - center: 3D center point + - dimensions: 3D dimensions (extent) + - rotation: 3x3 rotation matrix + - error: Fitting error + - bounding_box: Open3D OrientedBoundingBox object + Returns None if insufficient points or fitting fails. + + Raises: + ValueError: If method is invalid or inputs are malformed + """ + # Validate method + valid_methods = ["minimal", "oriented", "axis_aligned"] + if method not in valid_methods: + raise ValueError(f"method must be one of {valid_methods}, got '{method}'") + + # Convert to point cloud if needed + if isinstance(points, np.ndarray): + points = np.asarray(points) + if len(points.shape) != 2 or points.shape[1] != 3: + raise ValueError(f"points array must be Nx3, got shape {points.shape}") + if len(points) < 4: + return None + + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(points) + elif isinstance(points, o3d.geometry.PointCloud): + pcd = points + points = np.asarray(pcd.points) + if len(points) < 4: + return None + else: + raise ValueError(f"points must be numpy array or Open3D PointCloud, got {type(points)}") + + try: + # Get bounding box based on method + if method == "minimal": + obb = pcd.get_minimal_oriented_bounding_box(robust=True) + elif method == "oriented": + obb = pcd.get_oriented_bounding_box(robust=True) + elif method == "axis_aligned": + # Convert axis-aligned to oriented format for consistency + aabb = pcd.get_axis_aligned_bounding_box() + obb = o3d.geometry.OrientedBoundingBox() + obb.center = aabb.get_center() + obb.extent = aabb.get_extent() + obb.R = np.eye(3) # Identity rotation for axis-aligned + + # Extract parameters + center = np.asarray(obb.center) + dimensions = np.asarray(obb.extent) + rotation = np.asarray(obb.R) + + # Calculate fitting error + error = _compute_fitting_error(points, center, dimensions, rotation) + + return { + "center": center, + "dimensions": dimensions, + "rotation": rotation, + "error": error, + "bounding_box": obb, + "method": method, + } + + except Exception as e: + # Log error but don't crash - return None for graceful handling + print(f"Warning: Cuboid fitting failed with method '{method}': {e}") + return None + + +def fit_cuboid_simple(points: np.ndarray | o3d.geometry.PointCloud) -> dict | None: # type: ignore[type-arg] + """ + Simple wrapper for minimal oriented bounding box fitting. + + Args: + points: Nx3 array of points or Open3D PointCloud + + Returns: + Dictionary with center, dimensions, rotation, and bounding_box, + or None if insufficient points + """ + return fit_cuboid(points, method="minimal") + + +def _compute_fitting_error( + points: np.ndarray, # type: ignore[type-arg] + center: np.ndarray, # type: ignore[type-arg] + dimensions: np.ndarray, # type: ignore[type-arg] + rotation: np.ndarray, # type: ignore[type-arg] +) -> float: + """ + Compute fitting error as mean squared distance from points to cuboid surface. + + Args: + points: Nx3 array of points + center: 3D center point + dimensions: 3D dimensions + rotation: 3x3 rotation matrix + + Returns: + Mean squared error + """ + if len(points) == 0: + return 0.0 + + # Transform points to local coordinates + local_points = (points - center) @ rotation + half_dims = dimensions / 2 + + # Calculate distance to cuboid surface + dx = np.abs(local_points[:, 0]) - half_dims[0] + dy = np.abs(local_points[:, 1]) - half_dims[1] + dz = np.abs(local_points[:, 2]) - half_dims[2] + + # Points outside: distance to nearest face + # Points inside: negative distance to nearest face + outside_dist = np.sqrt(np.maximum(dx, 0) ** 2 + np.maximum(dy, 0) ** 2 + np.maximum(dz, 0) ** 2) + inside_dist = np.minimum(np.minimum(dx, dy), dz) + distances = np.where((dx > 0) | (dy > 0) | (dz > 0), outside_dist, -inside_dist) + + return float(np.mean(distances**2)) + + +def get_cuboid_corners( + center: np.ndarray, # type: ignore[type-arg] + dimensions: np.ndarray, # type: ignore[type-arg] + rotation: np.ndarray, # type: ignore[type-arg] +) -> np.ndarray: # type: ignore[type-arg] + """ + Get the 8 corners of a cuboid. + + Args: + center: 3D center point + dimensions: 3D dimensions + rotation: 3x3 rotation matrix + + Returns: + 8x3 array of corner coordinates + """ + half_dims = dimensions / 2 + corners_local = ( + np.array( + [ + [-1, -1, -1], # 0: left bottom back + [-1, -1, 1], # 1: left bottom front + [-1, 1, -1], # 2: left top back + [-1, 1, 1], # 3: left top front + [1, -1, -1], # 4: right bottom back + [1, -1, 1], # 5: right bottom front + [1, 1, -1], # 6: right top back + [1, 1, 1], # 7: right top front + ] + ) + * half_dims + ) + + # Apply rotation and translation + return corners_local @ rotation.T + center # type: ignore[no-any-return] + + +def visualize_cuboid_on_image( + image: np.ndarray, # type: ignore[type-arg] + cuboid_params: dict, # type: ignore[type-arg] + camera_matrix: np.ndarray, # type: ignore[type-arg] + extrinsic_rotation: np.ndarray | None = None, # type: ignore[type-arg] + extrinsic_translation: np.ndarray | None = None, # type: ignore[type-arg] + color: tuple[int, int, int] = (0, 255, 0), + thickness: int = 2, + show_dimensions: bool = True, +) -> np.ndarray: # type: ignore[type-arg] + """ + Draw a fitted cuboid on an image using camera projection. + + Args: + image: Input image to draw on + cuboid_params: Dictionary containing cuboid parameters + camera_matrix: Camera intrinsic matrix (3x3) + extrinsic_rotation: Optional external rotation (3x3) + extrinsic_translation: Optional external translation (3x1) + color: Line color as (B, G, R) tuple + thickness: Line thickness + show_dimensions: Whether to display dimension text + + Returns: + Image with cuboid visualization + + Raises: + ValueError: If required parameters are missing or invalid + """ + # Validate inputs + required_keys = ["center", "dimensions", "rotation"] + if not all(key in cuboid_params for key in required_keys): + raise ValueError(f"cuboid_params must contain keys: {required_keys}") + + if camera_matrix.shape != (3, 3): + raise ValueError(f"camera_matrix must be 3x3, got {camera_matrix.shape}") + + # Get corners in world coordinates + corners = get_cuboid_corners( + cuboid_params["center"], cuboid_params["dimensions"], cuboid_params["rotation"] + ) + + # Transform corners if extrinsic parameters are provided + if extrinsic_rotation is not None and extrinsic_translation is not None: + if extrinsic_rotation.shape != (3, 3): + raise ValueError(f"extrinsic_rotation must be 3x3, got {extrinsic_rotation.shape}") + if extrinsic_translation.shape not in [(3,), (3, 1)]: + raise ValueError( + f"extrinsic_translation must be (3,) or (3,1), got {extrinsic_translation.shape}" + ) + + extrinsic_translation = extrinsic_translation.flatten() + corners = (extrinsic_rotation @ corners.T).T + extrinsic_translation + + try: + # Project 3D corners to image coordinates + corners_img, _ = cv2.projectPoints( # type: ignore[call-overload] + corners.astype(np.float32), + np.zeros(3), + np.zeros(3), # No additional rotation/translation + camera_matrix.astype(np.float32), + None, # No distortion + ) + corners_img = corners_img.reshape(-1, 2).astype(int) + + # Check if corners are within image bounds + h, w = image.shape[:2] + valid_corners = ( + (corners_img[:, 0] >= 0) + & (corners_img[:, 0] < w) + & (corners_img[:, 1] >= 0) + & (corners_img[:, 1] < h) + ) + + if not np.any(valid_corners): + print("Warning: All cuboid corners are outside image bounds") + return image.copy() + + except Exception as e: + print(f"Warning: Failed to project cuboid corners: {e}") + return image.copy() + + # Define edges for wireframe visualization + edges = [ + # Bottom face + (0, 1), + (1, 5), + (5, 4), + (4, 0), + # Top face + (2, 3), + (3, 7), + (7, 6), + (6, 2), + # Vertical edges + (0, 2), + (1, 3), + (5, 7), + (4, 6), + ] + + # Draw edges + vis_img = image.copy() + for i, j in edges: + # Only draw edge if both corners are valid + if valid_corners[i] and valid_corners[j]: + cv2.line(vis_img, tuple(corners_img[i]), tuple(corners_img[j]), color, thickness) + + # Add dimension text if requested + if show_dimensions and np.any(valid_corners): + dims = cuboid_params["dimensions"] + dim_text = f"Dims: {dims[0]:.3f} x {dims[1]:.3f} x {dims[2]:.3f}" + + # Find a good position for text (top-left of image) + text_pos = (10, 30) + font_scale = 0.7 + + # Add background rectangle for better readability + text_size = cv2.getTextSize(dim_text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, 2)[0] + cv2.rectangle( + vis_img, + (text_pos[0] - 5, text_pos[1] - text_size[1] - 5), + (text_pos[0] + text_size[0] + 5, text_pos[1] + 5), + (0, 0, 0), + -1, + ) + + cv2.putText(vis_img, dim_text, text_pos, cv2.FONT_HERSHEY_SIMPLEX, font_scale, color, 2) + + return vis_img + + +def compute_cuboid_volume(cuboid_params: dict) -> float: # type: ignore[type-arg] + """ + Compute the volume of a cuboid. + + Args: + cuboid_params: Dictionary containing cuboid parameters + + Returns: + Volume in cubic units + """ + if "dimensions" not in cuboid_params: + raise ValueError("cuboid_params must contain 'dimensions' key") + + dims = cuboid_params["dimensions"] + return float(np.prod(dims)) + + +def compute_cuboid_surface_area(cuboid_params: dict) -> float: # type: ignore[type-arg] + """ + Compute the surface area of a cuboid. + + Args: + cuboid_params: Dictionary containing cuboid parameters + + Returns: + Surface area in square units + """ + if "dimensions" not in cuboid_params: + raise ValueError("cuboid_params must contain 'dimensions' key") + + dims = cuboid_params["dimensions"] + return 2.0 * (dims[0] * dims[1] + dims[1] * dims[2] + dims[2] * dims[0]) # type: ignore[no-any-return] + + +def check_cuboid_quality(cuboid_params: dict, points: np.ndarray) -> dict: # type: ignore[type-arg] + """ + Assess the quality of a cuboid fit. + + Args: + cuboid_params: Dictionary containing cuboid parameters + points: Original points used for fitting + + Returns: + Dictionary with quality metrics + """ + if len(points) == 0: + return {"error": "No points provided"} + + # Basic metrics + volume = compute_cuboid_volume(cuboid_params) + surface_area = compute_cuboid_surface_area(cuboid_params) + error = cuboid_params.get("error", 0.0) + + # Aspect ratio analysis + dims = cuboid_params["dimensions"] + aspect_ratios = [ + dims[0] / dims[1] if dims[1] > 0 else float("inf"), + dims[1] / dims[2] if dims[2] > 0 else float("inf"), + dims[2] / dims[0] if dims[0] > 0 else float("inf"), + ] + max_aspect_ratio = max(aspect_ratios) + + # Volume ratio (cuboid volume vs convex hull volume) + try: + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(points) + hull, _ = pcd.compute_convex_hull() + hull_volume = hull.get_volume() + volume_ratio = volume / hull_volume if hull_volume > 0 else float("inf") + except: + volume_ratio = None + + return { + "fitting_error": error, + "volume": volume, + "surface_area": surface_area, + "max_aspect_ratio": max_aspect_ratio, + "volume_ratio": volume_ratio, + "num_points": len(points), + "method": cuboid_params.get("method", "unknown"), + } + + +# Backward compatibility +def visualize_fit(image, cuboid_params, camera_matrix, R=None, t=None): # type: ignore[no-untyped-def] + """ + Legacy function for backward compatibility. + Use visualize_cuboid_on_image instead. + """ + return visualize_cuboid_on_image( + image, cuboid_params, camera_matrix, R, t, show_dimensions=True + ) diff --git a/dimos/perception/pointcloud/pointcloud_filtering.py b/dimos/perception/pointcloud/pointcloud_filtering.py new file mode 100644 index 0000000000..d6aa2b835f --- /dev/null +++ b/dimos/perception/pointcloud/pointcloud_filtering.py @@ -0,0 +1,370 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import cv2 +import numpy as np +import open3d as o3d # type: ignore[import-untyped] +import torch + +from dimos.perception.pointcloud.cuboid_fit import fit_cuboid +from dimos.perception.pointcloud.utils import ( + create_point_cloud_and_extract_masks, + load_camera_matrix_from_yaml, +) +from dimos.types.manipulation import ObjectData +from dimos.types.vector import Vector + + +class PointcloudFiltering: + """ + A production-ready point cloud filtering pipeline for segmented objects. + + This class takes segmentation results and produces clean, filtered point clouds + for each object with consistent coloring and optional outlier removal. + """ + + def __init__( + self, + color_intrinsics: str | list[float] | np.ndarray | None = None, # type: ignore[type-arg] + depth_intrinsics: str | list[float] | np.ndarray | None = None, # type: ignore[type-arg] + color_weight: float = 0.3, + enable_statistical_filtering: bool = True, + statistical_neighbors: int = 20, + statistical_std_ratio: float = 1.5, + enable_radius_filtering: bool = True, + radius_filtering_radius: float = 0.015, + radius_filtering_min_neighbors: int = 25, + enable_subsampling: bool = True, + voxel_size: float = 0.005, + max_num_objects: int = 10, + min_points_for_cuboid: int = 10, + cuboid_method: str = "oriented", + max_bbox_size_percent: float = 30.0, + ) -> None: + """ + Initialize the point cloud filtering pipeline. + + Args: + color_intrinsics: Camera intrinsics for color image + depth_intrinsics: Camera intrinsics for depth image + color_weight: Weight for blending generated color with original (0.0-1.0) + enable_statistical_filtering: Enable/disable statistical outlier filtering + statistical_neighbors: Number of neighbors for statistical filtering + statistical_std_ratio: Standard deviation ratio for statistical filtering + enable_radius_filtering: Enable/disable radius outlier filtering + radius_filtering_radius: Search radius for radius filtering (meters) + radius_filtering_min_neighbors: Min neighbors within radius + enable_subsampling: Enable/disable point cloud subsampling + voxel_size: Voxel size for downsampling (meters, when subsampling enabled) + max_num_objects: Maximum number of objects to process (top N by confidence) + min_points_for_cuboid: Minimum points required for cuboid fitting + cuboid_method: Method for cuboid fitting ('minimal', 'oriented', 'axis_aligned') + max_bbox_size_percent: Maximum percentage of image size for object bboxes (0-100) + + Raises: + ValueError: If invalid parameters are provided + """ + # Validate parameters + if not 0.0 <= color_weight <= 1.0: + raise ValueError(f"color_weight must be between 0.0 and 1.0, got {color_weight}") + if not 0.0 <= max_bbox_size_percent <= 100.0: + raise ValueError( + f"max_bbox_size_percent must be between 0.0 and 100.0, got {max_bbox_size_percent}" + ) + + # Store settings + self.color_weight = color_weight + self.enable_statistical_filtering = enable_statistical_filtering + self.statistical_neighbors = statistical_neighbors + self.statistical_std_ratio = statistical_std_ratio + self.enable_radius_filtering = enable_radius_filtering + self.radius_filtering_radius = radius_filtering_radius + self.radius_filtering_min_neighbors = radius_filtering_min_neighbors + self.enable_subsampling = enable_subsampling + self.voxel_size = voxel_size + self.max_num_objects = max_num_objects + self.min_points_for_cuboid = min_points_for_cuboid + self.cuboid_method = cuboid_method + self.max_bbox_size_percent = max_bbox_size_percent + + # Load camera matrices + self.color_camera_matrix = load_camera_matrix_from_yaml(color_intrinsics) + self.depth_camera_matrix = load_camera_matrix_from_yaml(depth_intrinsics) + + # Store the full point cloud + self.full_pcd = None + + def generate_color_from_id(self, object_id: int) -> np.ndarray: # type: ignore[type-arg] + """Generate a consistent color for a given object ID.""" + np.random.seed(object_id) + color = np.random.randint(0, 255, 3, dtype=np.uint8) + np.random.seed(None) + return color + + def _validate_inputs( # type: ignore[no-untyped-def] + self, + color_img: np.ndarray, # type: ignore[type-arg] + depth_img: np.ndarray, # type: ignore[type-arg] + objects: list[ObjectData], + ): + """Validate input parameters.""" + if color_img.shape[:2] != depth_img.shape: + raise ValueError("Color and depth image dimensions don't match") + + def _prepare_masks(self, masks: list[np.ndarray], target_shape: tuple) -> list[np.ndarray]: # type: ignore[type-arg] + """Prepare and validate masks to match target shape.""" + processed_masks = [] + for mask in masks: + # Convert mask to numpy if it's a tensor + if hasattr(mask, "cpu"): + mask = mask.cpu().numpy() + + mask = mask.astype(bool) + + # Handle shape mismatches + if mask.shape != target_shape: + if len(mask.shape) > 2: + mask = mask[:, :, 0] + + if mask.shape != target_shape: + mask = cv2.resize( + mask.astype(np.uint8), + (target_shape[1], target_shape[0]), + interpolation=cv2.INTER_NEAREST, + ).astype(bool) + + processed_masks.append(mask) + + return processed_masks + + def _apply_color_mask( + self, + pcd: o3d.geometry.PointCloud, + rgb_color: np.ndarray, # type: ignore[type-arg] + ) -> o3d.geometry.PointCloud: + """Apply weighted color mask to point cloud.""" + if len(np.asarray(pcd.colors)) > 0: + original_colors = np.asarray(pcd.colors) + generated_color = rgb_color.astype(np.float32) / 255.0 + colored_mask = ( + 1.0 - self.color_weight + ) * original_colors + self.color_weight * generated_color + colored_mask = np.clip(colored_mask, 0.0, 1.0) + pcd.colors = o3d.utility.Vector3dVector(colored_mask) + return pcd + + def _apply_filtering(self, pcd: o3d.geometry.PointCloud) -> o3d.geometry.PointCloud: + """Apply optional filtering to point cloud based on enabled flags.""" + current_pcd = pcd + + # Apply statistical filtering if enabled + if self.enable_statistical_filtering: + current_pcd, _ = current_pcd.remove_statistical_outlier( + nb_neighbors=self.statistical_neighbors, std_ratio=self.statistical_std_ratio + ) + + # Apply radius filtering if enabled + if self.enable_radius_filtering: + current_pcd, _ = current_pcd.remove_radius_outlier( + nb_points=self.radius_filtering_min_neighbors, radius=self.radius_filtering_radius + ) + + return current_pcd + + def _apply_subsampling(self, pcd: o3d.geometry.PointCloud) -> o3d.geometry.PointCloud: + """Apply subsampling to limit point cloud size using Open3D's voxel downsampling.""" + if self.enable_subsampling: + return pcd.voxel_down_sample(self.voxel_size) + return pcd + + def _extract_masks_from_objects(self, objects: list[ObjectData]) -> list[np.ndarray]: # type: ignore[type-arg] + """Extract segmentation masks from ObjectData objects.""" + return [obj["segmentation_mask"] for obj in objects] + + def get_full_point_cloud(self) -> o3d.geometry.PointCloud: + """Get the full point cloud.""" + return self._apply_subsampling(self.full_pcd) + + def process_images( + self, + color_img: np.ndarray, # type: ignore[type-arg] + depth_img: np.ndarray, # type: ignore[type-arg] + objects: list[ObjectData], + ) -> list[ObjectData]: + """ + Process color and depth images with object detection results to create filtered point clouds. + + Args: + color_img: RGB image as numpy array (H, W, 3) + depth_img: Depth image as numpy array (H, W) in meters + objects: List of ObjectData from object detection stream + + Returns: + List of updated ObjectData with pointcloud and 3D information. Each ObjectData + dictionary is enhanced with the following new fields: + + **3D Spatial Information** (added when sufficient points for cuboid fitting): + - "position": Vector(x, y, z) - 3D center position in world coordinates (meters) + - "rotation": Vector(roll, pitch, yaw) - 3D orientation as Euler angles (radians) + - "size": {"width": float, "height": float, "depth": float} - 3D bounding box dimensions (meters) + + **Point Cloud Data**: + - "point_cloud": o3d.geometry.PointCloud - Filtered Open3D point cloud with colors + - "color": np.ndarray - Consistent RGB color [R,G,B] (0-255) generated from object_id + + **Grasp Generation Arrays** (Dimensional grasp format): + - "point_cloud_numpy": np.ndarray - Nx3 XYZ coordinates as float32 (meters) + - "colors_numpy": np.ndarray - Nx3 RGB colors as float32 (0.0-1.0 range) + + Raises: + ValueError: If inputs are invalid + RuntimeError: If processing fails + """ + # Validate inputs + self._validate_inputs(color_img, depth_img, objects) + + if not objects: + return [] + + # Filter to top N objects by confidence + if len(objects) > self.max_num_objects: + # Sort objects by confidence (highest first), handle None confidences + sorted_objects = sorted( + objects, + key=lambda obj: obj.get("confidence", 0.0) + if obj.get("confidence") is not None + else 0.0, + reverse=True, + ) + objects = sorted_objects[: self.max_num_objects] + + # Filter out objects with bboxes too large + image_area = color_img.shape[0] * color_img.shape[1] + max_bbox_area = image_area * (self.max_bbox_size_percent / 100.0) + + filtered_objects = [] + for obj in objects: + if "bbox" in obj and obj["bbox"] is not None: + bbox = obj["bbox"] + # Calculate bbox area (assuming bbox format [x1, y1, x2, y2]) + bbox_area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) + if bbox_area <= max_bbox_area: + filtered_objects.append(obj) + else: + filtered_objects.append(obj) + + objects = filtered_objects + + # Extract masks from ObjectData + masks = self._extract_masks_from_objects(objects) + + # Prepare masks + processed_masks = self._prepare_masks(masks, depth_img.shape) + + # Create point clouds efficiently + self.full_pcd, masked_pcds = create_point_cloud_and_extract_masks( + color_img, + depth_img, + processed_masks, + self.depth_camera_matrix, # type: ignore[arg-type] + depth_scale=1.0, + ) + + # Process each object and update ObjectData + updated_objects = [] + + for i, (obj, _mask, pcd) in enumerate( + zip(objects, processed_masks, masked_pcds, strict=False) + ): + # Skip empty point clouds + if len(np.asarray(pcd.points)) == 0: + continue + + # Create a copy of the object data to avoid modifying the original + updated_obj = obj.copy() + + # Generate consistent color + object_id = obj.get("object_id", i) + rgb_color = self.generate_color_from_id(object_id) + + # Apply color mask + pcd = self._apply_color_mask(pcd, rgb_color) + + # Apply subsampling to control point cloud size + pcd = self._apply_subsampling(pcd) + + # Apply filtering (optional based on flags) + pcd_filtered = self._apply_filtering(pcd) + + # Fit cuboid and extract 3D information + points = np.asarray(pcd_filtered.points) + if len(points) >= self.min_points_for_cuboid: + cuboid_params = fit_cuboid(points, method=self.cuboid_method) + if cuboid_params is not None: + # Update position, rotation, and size from cuboid + center = cuboid_params["center"] + dimensions = cuboid_params["dimensions"] + rotation_matrix = cuboid_params["rotation"] + + # Convert rotation matrix to euler angles (roll, pitch, yaw) + sy = np.sqrt( + rotation_matrix[0, 0] * rotation_matrix[0, 0] + + rotation_matrix[1, 0] * rotation_matrix[1, 0] + ) + singular = sy < 1e-6 + + if not singular: + roll = np.arctan2(rotation_matrix[2, 1], rotation_matrix[2, 2]) + pitch = np.arctan2(-rotation_matrix[2, 0], sy) + yaw = np.arctan2(rotation_matrix[1, 0], rotation_matrix[0, 0]) + else: + roll = np.arctan2(-rotation_matrix[1, 2], rotation_matrix[1, 1]) + pitch = np.arctan2(-rotation_matrix[2, 0], sy) + yaw = 0 + + # Update position, rotation, and size from cuboid + updated_obj["position"] = Vector(center[0], center[1], center[2]) + updated_obj["rotation"] = Vector(roll, pitch, yaw) + updated_obj["size"] = { + "width": float(dimensions[0]), + "height": float(dimensions[1]), + "depth": float(dimensions[2]), + } + + # Add point cloud data to ObjectData + updated_obj["point_cloud"] = pcd_filtered + updated_obj["color"] = rgb_color + + # Extract numpy arrays for grasp generation + points_array = np.asarray(pcd_filtered.points).astype(np.float32) # Nx3 XYZ coordinates + if pcd_filtered.has_colors(): + colors_array = np.asarray(pcd_filtered.colors).astype( + np.float32 + ) # Nx3 RGB (0-1 range) + else: + # If no colors, create array of zeros + colors_array = np.zeros((len(points_array), 3), dtype=np.float32) + + updated_obj["point_cloud_numpy"] = points_array + updated_obj["colors_numpy"] = colors_array # type: ignore[typeddict-unknown-key] + + updated_objects.append(updated_obj) + + return updated_objects + + def cleanup(self) -> None: + """Clean up resources.""" + if torch.cuda.is_available(): + torch.cuda.empty_cache() diff --git a/dimos/perception/pointcloud/test_pointcloud_filtering.py b/dimos/perception/pointcloud/test_pointcloud_filtering.py new file mode 100644 index 0000000000..4ac7e5cb2d --- /dev/null +++ b/dimos/perception/pointcloud/test_pointcloud_filtering.py @@ -0,0 +1,263 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import TYPE_CHECKING + +import cv2 +import numpy as np +import open3d as o3d +import pytest + +from dimos.perception.pointcloud.pointcloud_filtering import PointcloudFiltering +from dimos.perception.pointcloud.utils import load_camera_matrix_from_yaml + +if TYPE_CHECKING: + from dimos.types.manipulation import ObjectData + + +class TestPointcloudFiltering: + def test_pointcloud_filtering_initialization(self) -> None: + """Test PointcloudFiltering initializes correctly with default parameters.""" + try: + filtering = PointcloudFiltering() + assert filtering is not None + assert filtering.color_weight == 0.3 + assert filtering.enable_statistical_filtering + assert filtering.enable_radius_filtering + assert filtering.enable_subsampling + except Exception as e: + pytest.skip(f"Skipping test due to initialization error: {e}") + + def test_pointcloud_filtering_with_custom_params(self) -> None: + """Test PointcloudFiltering with custom parameters.""" + try: + filtering = PointcloudFiltering( + color_weight=0.5, + enable_statistical_filtering=False, + enable_radius_filtering=False, + voxel_size=0.01, + max_num_objects=5, + ) + assert filtering.color_weight == 0.5 + assert not filtering.enable_statistical_filtering + assert not filtering.enable_radius_filtering + assert filtering.voxel_size == 0.01 + assert filtering.max_num_objects == 5 + except Exception as e: + pytest.skip(f"Skipping test due to initialization error: {e}") + + def test_pointcloud_filtering_process_images(self) -> None: + """Test PointcloudFiltering can process RGB-D images and return filtered point clouds.""" + try: + # Import data inside method to avoid pytest fixture confusion + from dimos.utils.data import get_data + + # Load test RGB-D data + data_dir = get_data("rgbd_frames") + + # Load first frame + color_path = os.path.join(data_dir, "color", "00000.png") + depth_path = os.path.join(data_dir, "depth", "00000.png") + intrinsics_path = os.path.join(data_dir, "color_camera_info.yaml") + + assert os.path.exists(color_path), f"Color image not found: {color_path}" + assert os.path.exists(depth_path), f"Depth image not found: {depth_path}" + assert os.path.exists(intrinsics_path), f"Intrinsics file not found: {intrinsics_path}" + + # Load images + color_img = cv2.imread(color_path) + color_img = cv2.cvtColor(color_img, cv2.COLOR_BGR2RGB) + + depth_img = cv2.imread(depth_path, cv2.IMREAD_ANYDEPTH) + if depth_img.dtype == np.uint16: + depth_img = depth_img.astype(np.float32) / 1000.0 + + # Load camera intrinsics + camera_matrix = load_camera_matrix_from_yaml(intrinsics_path) + if camera_matrix is None: + pytest.skip("Failed to load camera intrinsics") + + # Create mock objects with segmentation masks + height, width = color_img.shape[:2] + + # Create simple rectangular masks for testing + mock_objects = [] + + # Object 1: Top-left quadrant + mask1 = np.zeros((height, width), dtype=bool) + mask1[height // 4 : height // 2, width // 4 : width // 2] = True + + obj1: ObjectData = { + "object_id": 1, + "confidence": 0.9, + "bbox": [width // 4, height // 4, width // 2, height // 2], + "segmentation_mask": mask1, + "name": "test_object_1", + } + mock_objects.append(obj1) + + # Object 2: Bottom-right quadrant + mask2 = np.zeros((height, width), dtype=bool) + mask2[height // 2 : 3 * height // 4, width // 2 : 3 * width // 4] = True + + obj2: ObjectData = { + "object_id": 2, + "confidence": 0.8, + "bbox": [width // 2, height // 2, 3 * width // 4, 3 * height // 4], + "segmentation_mask": mask2, + "name": "test_object_2", + } + mock_objects.append(obj2) + + # Initialize filtering with intrinsics + filtering = PointcloudFiltering( + color_intrinsics=camera_matrix, + depth_intrinsics=camera_matrix, + enable_statistical_filtering=False, # Disable for faster testing + enable_radius_filtering=False, # Disable for faster testing + voxel_size=0.01, # Larger voxel for faster processing + ) + + # Process images + results = filtering.process_images(color_img, depth_img, mock_objects) + + print( + f"Processing results - Input objects: {len(mock_objects)}, Output objects: {len(results)}" + ) + + # Verify results + assert isinstance(results, list), "Results should be a list" + assert len(results) <= len(mock_objects), "Should not return more objects than input" + + # Check each result object + for i, result in enumerate(results): + print(f"Object {i}: {result.get('name', 'unknown')}") + + # Verify required fields exist + assert "point_cloud" in result, "Result should contain point_cloud" + assert "color" in result, "Result should contain color" + assert "point_cloud_numpy" in result, "Result should contain point_cloud_numpy" + + # Verify point cloud is valid Open3D object + pcd = result["point_cloud"] + assert isinstance(pcd, o3d.geometry.PointCloud), ( + "point_cloud should be Open3D PointCloud" + ) + + # Verify numpy arrays + points_array = result["point_cloud_numpy"] + assert isinstance(points_array, np.ndarray), ( + "point_cloud_numpy should be numpy array" + ) + assert points_array.shape[1] == 3, "Point array should have 3 columns (x,y,z)" + assert points_array.dtype == np.float32, "Point array should be float32" + + # Verify color + color = result["color"] + assert isinstance(color, np.ndarray), "Color should be numpy array" + assert color.shape == (3,), "Color should be RGB triplet" + assert color.dtype == np.uint8, "Color should be uint8" + + # Check if 3D information was added (when enough points for cuboid fitting) + points = np.asarray(pcd.points) + if len(points) >= filtering.min_points_for_cuboid: + if "position" in result: + assert "rotation" in result, "Should have rotation if position exists" + assert "size" in result, "Should have size if position exists" + + # Verify position format + from dimos.types.vector import Vector + + position = result["position"] + assert isinstance(position, Vector), "Position should be Vector" + + # Verify size format + size = result["size"] + assert isinstance(size, dict), "Size should be dict" + assert "width" in size and "height" in size and "depth" in size + + print(f" - Points: {len(points)}") + print(f" - Color: {color}") + if "position" in result: + print(f" - Position: {result['position']}") + print(f" - Size: {result['size']}") + + # Test full point cloud access + full_pcd = filtering.get_full_point_cloud() + if full_pcd is not None: + assert isinstance(full_pcd, o3d.geometry.PointCloud), ( + "Full point cloud should be Open3D PointCloud" + ) + full_points = np.asarray(full_pcd.points) + print(f"Full point cloud points: {len(full_points)}") + + print("All pointcloud filtering tests passed!") + + except Exception as e: + pytest.skip(f"Skipping test due to error: {e}") + + def test_pointcloud_filtering_empty_objects(self) -> None: + """Test PointcloudFiltering with empty object list.""" + try: + from dimos.utils.data import get_data + + # Load test data + data_dir = get_data("rgbd_frames") + color_path = os.path.join(data_dir, "color", "00000.png") + depth_path = os.path.join(data_dir, "depth", "00000.png") + + if not (os.path.exists(color_path) and os.path.exists(depth_path)): + pytest.skip("Test images not found") + + color_img = cv2.imread(color_path) + color_img = cv2.cvtColor(color_img, cv2.COLOR_BGR2RGB) + depth_img = cv2.imread(depth_path, cv2.IMREAD_ANYDEPTH) + if depth_img.dtype == np.uint16: + depth_img = depth_img.astype(np.float32) / 1000.0 + + filtering = PointcloudFiltering() + + # Test with empty object list + results = filtering.process_images(color_img, depth_img, []) + + assert isinstance(results, list), "Results should be a list" + assert len(results) == 0, "Should return empty list for empty input" + + except Exception as e: + pytest.skip(f"Skipping test due to error: {e}") + + def test_color_generation_consistency(self) -> None: + """Test that color generation is consistent for the same object ID.""" + try: + filtering = PointcloudFiltering() + + # Test color generation consistency + color1 = filtering.generate_color_from_id(42) + color2 = filtering.generate_color_from_id(42) + color3 = filtering.generate_color_from_id(43) + + assert np.array_equal(color1, color2), "Same ID should generate same color" + assert not np.array_equal(color1, color3), ( + "Different IDs should generate different colors" + ) + assert color1.shape == (3,), "Color should be RGB triplet" + assert color1.dtype == np.uint8, "Color should be uint8" + + except Exception as e: + pytest.skip(f"Skipping test due to error: {e}") + + +if __name__ == "__main__": + pytest.main(["-v", __file__]) diff --git a/dimos/perception/pointcloud/utils.py b/dimos/perception/pointcloud/utils.py new file mode 100644 index 0000000000..3036e6df66 --- /dev/null +++ b/dimos/perception/pointcloud/utils.py @@ -0,0 +1,1113 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Point cloud utilities for RGBD data processing. + +This module provides efficient utilities for creating and manipulating point clouds +from RGBD images using Open3D. +""" + +import os +from typing import Any + +import cv2 +import numpy as np +import open3d as o3d # type: ignore[import-untyped] +from scipy.spatial import cKDTree +import yaml + +from dimos.perception.common.utils import project_3d_points_to_2d + + +def load_camera_matrix_from_yaml( + camera_info: str | list[float] | np.ndarray | dict | None, # type: ignore[type-arg] +) -> np.ndarray | None: # type: ignore[type-arg] + """ + Load camera intrinsic matrix from various input formats. + + Args: + camera_info: Can be: + - Path to YAML file containing camera parameters + - List of [fx, fy, cx, cy] + - 3x3 numpy array (returned as-is) + - Dict with camera parameters + - None (returns None) + + Returns: + 3x3 camera intrinsic matrix or None if input is None + + Raises: + ValueError: If camera_info format is invalid or file cannot be read + FileNotFoundError: If YAML file path doesn't exist + """ + if camera_info is None: + return None + + # Handle case where camera_info is already a matrix + if isinstance(camera_info, np.ndarray) and camera_info.shape == (3, 3): + return camera_info.astype(np.float32) + + # Handle case where camera_info is [fx, fy, cx, cy] format + if isinstance(camera_info, list) and len(camera_info) == 4: + fx, fy, cx, cy = camera_info + return np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32) + + # Handle case where camera_info is a dict + if isinstance(camera_info, dict): + return _extract_matrix_from_dict(camera_info) + + # Handle case where camera_info is a path to a YAML file + if isinstance(camera_info, str): + if not os.path.isfile(camera_info): + raise FileNotFoundError(f"Camera info file not found: {camera_info}") + + try: + with open(camera_info) as f: + data = yaml.safe_load(f) + return _extract_matrix_from_dict(data) + except Exception as e: + raise ValueError(f"Failed to read camera info from {camera_info}: {e}") + + raise ValueError( + f"Invalid camera_info format. Expected str, list, dict, or numpy array, got {type(camera_info)}" + ) + + +def _extract_matrix_from_dict(data: dict) -> np.ndarray: # type: ignore[type-arg] + """Extract camera matrix from dictionary with various formats.""" + # ROS format with 'K' field (most common) + if "K" in data: + k_data = data["K"] + if len(k_data) == 9: + return np.array(k_data, dtype=np.float32).reshape(3, 3) + + # Standard format with 'camera_matrix' + if "camera_matrix" in data: + if "data" in data["camera_matrix"]: + matrix_data = data["camera_matrix"]["data"] + if len(matrix_data) == 9: + return np.array(matrix_data, dtype=np.float32).reshape(3, 3) + + # Explicit intrinsics format + if all(k in data for k in ["fx", "fy", "cx", "cy"]): + fx, fy = float(data["fx"]), float(data["fy"]) + cx, cy = float(data["cx"]), float(data["cy"]) + return np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32) + + # Error case - provide helpful debug info + available_keys = list(data.keys()) + if "K" in data: + k_info = f"K field length: {len(data['K']) if hasattr(data['K'], '__len__') else 'unknown'}" + else: + k_info = "K field not found" + + raise ValueError( + f"Cannot extract camera matrix from data. " + f"Available keys: {available_keys}. {k_info}. " + f"Expected formats: 'K' (9 elements), 'camera_matrix.data' (9 elements), " + f"or individual 'fx', 'fy', 'cx', 'cy' fields." + ) + + +def create_o3d_point_cloud_from_rgbd( + color_img: np.ndarray, # type: ignore[type-arg] + depth_img: np.ndarray, # type: ignore[type-arg] + intrinsic: np.ndarray, # type: ignore[type-arg] + depth_scale: float = 1.0, + depth_trunc: float = 3.0, +) -> o3d.geometry.PointCloud: + """ + Create an Open3D point cloud from RGB and depth images. + + Args: + color_img: RGB image as numpy array (H, W, 3) + depth_img: Depth image as numpy array (H, W) + intrinsic: Camera intrinsic matrix (3x3 numpy array) + depth_scale: Scale factor to convert depth to meters + depth_trunc: Maximum depth in meters + + Returns: + Open3D point cloud object + + Raises: + ValueError: If input dimensions are invalid + """ + # Validate inputs + if len(color_img.shape) != 3 or color_img.shape[2] != 3: + raise ValueError(f"color_img must be (H, W, 3), got {color_img.shape}") + if len(depth_img.shape) != 2: + raise ValueError(f"depth_img must be (H, W), got {depth_img.shape}") + if color_img.shape[:2] != depth_img.shape: + raise ValueError( + f"Color and depth image dimensions don't match: {color_img.shape[:2]} vs {depth_img.shape}" + ) + if intrinsic.shape != (3, 3): + raise ValueError(f"intrinsic must be (3, 3), got {intrinsic.shape}") + + # Convert to Open3D format + color_o3d = o3d.geometry.Image(color_img.astype(np.uint8)) + + # Filter out inf and nan values from depth image + depth_filtered = depth_img.copy() + + # Create mask for valid depth values (finite, positive, non-zero) + valid_mask = np.isfinite(depth_filtered) & (depth_filtered > 0) + + # Set invalid values to 0 (which Open3D treats as no depth) + depth_filtered[~valid_mask] = 0.0 + + depth_o3d = o3d.geometry.Image(depth_filtered.astype(np.float32)) + + # Create Open3D intrinsic object + height, width = color_img.shape[:2] + fx, fy = intrinsic[0, 0], intrinsic[1, 1] + cx, cy = intrinsic[0, 2], intrinsic[1, 2] + intrinsic_o3d = o3d.camera.PinholeCameraIntrinsic( + width, + height, + fx, + fy, # fx, fy + cx, + cy, # cx, cy + ) + + # Create RGBD image + rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth( + color_o3d, + depth_o3d, + depth_scale=depth_scale, + depth_trunc=depth_trunc, + convert_rgb_to_intensity=False, + ) + + # Create point cloud + pcd = o3d.geometry.PointCloud.create_from_rgbd_image(rgbd, intrinsic_o3d) + + return pcd + + +def create_point_cloud_and_extract_masks( + color_img: np.ndarray, # type: ignore[type-arg] + depth_img: np.ndarray, # type: ignore[type-arg] + masks: list[np.ndarray], # type: ignore[type-arg] + intrinsic: np.ndarray, # type: ignore[type-arg] + depth_scale: float = 1.0, + depth_trunc: float = 3.0, +) -> tuple[o3d.geometry.PointCloud, list[o3d.geometry.PointCloud]]: + """ + Efficiently create a point cloud once and extract multiple masked regions. + + Args: + color_img: RGB image (H, W, 3) + depth_img: Depth image (H, W) + masks: List of boolean masks, each of shape (H, W) + intrinsic: Camera intrinsic matrix (3x3 numpy array) + depth_scale: Scale factor to convert depth to meters + depth_trunc: Maximum depth in meters + + Returns: + Tuple of (full_point_cloud, list_of_masked_point_clouds) + """ + if not masks: + return o3d.geometry.PointCloud(), [] + + # Create the full point cloud + full_pcd = create_o3d_point_cloud_from_rgbd( + color_img, depth_img, intrinsic, depth_scale, depth_trunc + ) + + if len(np.asarray(full_pcd.points)) == 0: + return full_pcd, [o3d.geometry.PointCloud() for _ in masks] + + # Create pixel-to-point mapping + valid_depth_mask = np.isfinite(depth_img) & (depth_img > 0) & (depth_img <= depth_trunc) + + valid_depth = valid_depth_mask.flatten() + if not np.any(valid_depth): + return full_pcd, [o3d.geometry.PointCloud() for _ in masks] + + pixel_to_point = np.full(len(valid_depth), -1, dtype=np.int32) + pixel_to_point[valid_depth] = np.arange(np.sum(valid_depth)) + + # Extract point clouds for each mask + masked_pcds = [] + max_points = len(np.asarray(full_pcd.points)) + + for mask in masks: + if mask.shape != depth_img.shape: + masked_pcds.append(o3d.geometry.PointCloud()) + continue + + mask_flat = mask.flatten() + valid_mask_indices = mask_flat & valid_depth + point_indices = pixel_to_point[valid_mask_indices] + valid_point_indices = point_indices[point_indices >= 0] + + if len(valid_point_indices) > 0: + valid_point_indices = np.clip(valid_point_indices, 0, max_points - 1) + valid_point_indices = np.unique(valid_point_indices) + masked_pcd = full_pcd.select_by_index(valid_point_indices.tolist()) + else: + masked_pcd = o3d.geometry.PointCloud() + + masked_pcds.append(masked_pcd) + + return full_pcd, masked_pcds + + +def filter_point_cloud_statistical( + pcd: o3d.geometry.PointCloud, nb_neighbors: int = 20, std_ratio: float = 2.0 +) -> tuple[o3d.geometry.PointCloud, np.ndarray]: # type: ignore[type-arg] + """ + Apply statistical outlier filtering to point cloud. + + Args: + pcd: Input point cloud + nb_neighbors: Number of neighbors to analyze for each point + std_ratio: Threshold level based on standard deviation + + Returns: + Tuple of (filtered_point_cloud, outlier_indices) + """ + if len(np.asarray(pcd.points)) == 0: + return pcd, np.array([]) + + return pcd.remove_statistical_outlier(nb_neighbors=nb_neighbors, std_ratio=std_ratio) # type: ignore[no-any-return] + + +def filter_point_cloud_radius( + pcd: o3d.geometry.PointCloud, nb_points: int = 16, radius: float = 0.05 +) -> tuple[o3d.geometry.PointCloud, np.ndarray]: # type: ignore[type-arg] + """ + Apply radius-based outlier filtering to point cloud. + + Args: + pcd: Input point cloud + nb_points: Minimum number of points within radius + radius: Search radius in meters + + Returns: + Tuple of (filtered_point_cloud, outlier_indices) + """ + if len(np.asarray(pcd.points)) == 0: + return pcd, np.array([]) + + return pcd.remove_radius_outlier(nb_points=nb_points, radius=radius) # type: ignore[no-any-return] + + +def overlay_point_clouds_on_image( + base_image: np.ndarray, # type: ignore[type-arg] + point_clouds: list[o3d.geometry.PointCloud], + camera_intrinsics: list[float] | np.ndarray, # type: ignore[type-arg] + colors: list[tuple[int, int, int]], + point_size: int = 2, + alpha: float = 0.7, +) -> np.ndarray: # type: ignore[type-arg] + """ + Overlay multiple colored point clouds onto an image. + + Args: + base_image: Base image to overlay onto (H, W, 3) - assumed to be RGB + point_clouds: List of Open3D point cloud objects + camera_intrinsics: Camera parameters as [fx, fy, cx, cy] list or 3x3 matrix + colors: List of RGB color tuples for each point cloud. If None, generates distinct colors. + point_size: Size of points to draw (in pixels) + alpha: Blending factor for overlay (0.0 = fully transparent, 1.0 = fully opaque) + + Returns: + Image with overlaid point clouds (H, W, 3) + """ + if len(point_clouds) == 0: + return base_image.copy() + + # Create overlay image + overlay = base_image.copy() + height, width = base_image.shape[:2] + + # Process each point cloud + for i, pcd in enumerate(point_clouds): + if pcd is None: + continue + + points_3d = np.asarray(pcd.points) + if len(points_3d) == 0: + continue + + # Project 3D points to 2D + points_2d = project_3d_points_to_2d(points_3d, camera_intrinsics) + + if len(points_2d) == 0: + continue + + # Filter points within image bounds + valid_mask = ( + (points_2d[:, 0] >= 0) + & (points_2d[:, 0] < width) + & (points_2d[:, 1] >= 0) + & (points_2d[:, 1] < height) + ) + valid_points_2d = points_2d[valid_mask] + + if len(valid_points_2d) == 0: + continue + + # Get color for this point cloud + color = colors[i % len(colors)] + + # Ensure color is a tuple of integers for OpenCV + if isinstance(color, list | tuple | np.ndarray): + color = tuple(int(c) for c in color[:3]) # type: ignore[assignment] + else: + color = (255, 255, 255) + + # Draw points on overlay + for point in valid_points_2d: + u, v = point + # Draw a small filled circle for each point + cv2.circle(overlay, (u, v), point_size, color, -1) + + # Blend overlay with base image + result = cv2.addWeighted(base_image, 1 - alpha, overlay, alpha, 0) + + return result + + +def create_point_cloud_overlay_visualization( + base_image: np.ndarray, # type: ignore[type-arg] + objects: list[dict], # type: ignore[type-arg] + intrinsics: np.ndarray, # type: ignore[type-arg] +) -> np.ndarray: # type: ignore[type-arg] + """ + Create a visualization showing object point clouds and bounding boxes overlaid on a base image. + + Args: + base_image: Base image to overlay onto (H, W, 3) + objects: List of object dictionaries containing 'point_cloud', 'color', 'position', 'rotation', 'size' keys + intrinsics: Camera intrinsics as [fx, fy, cx, cy] or 3x3 matrix + + Returns: + Visualization image with overlaid point clouds and bounding boxes (H, W, 3) + """ + # Extract point clouds and colors from objects + point_clouds = [] + colors = [] + for obj in objects: + if "point_cloud" in obj and obj["point_cloud"] is not None: + point_clouds.append(obj["point_cloud"]) + + # Convert color to tuple + color = obj["color"] + if isinstance(color, np.ndarray): + color = tuple(int(c) for c in color) + elif isinstance(color, list | tuple): + color = tuple(int(c) for c in color[:3]) + colors.append(color) + + # Create visualization + if point_clouds: + result = overlay_point_clouds_on_image( + base_image=base_image, + point_clouds=point_clouds, + camera_intrinsics=intrinsics, + colors=colors, + point_size=3, + alpha=0.8, + ) + else: + result = base_image.copy() + + # Draw 3D bounding boxes + height_img, width_img = result.shape[:2] + for i, obj in enumerate(objects): + if all(key in obj and obj[key] is not None for key in ["position", "rotation", "size"]): + try: + # Create and project 3D bounding box + corners_3d = create_3d_bounding_box_corners( + obj["position"], obj["rotation"], obj["size"] + ) + corners_2d = project_3d_points_to_2d(corners_3d, intrinsics) + + # Check if any corners are visible + valid_mask = ( + (corners_2d[:, 0] >= 0) + & (corners_2d[:, 0] < width_img) + & (corners_2d[:, 1] >= 0) + & (corners_2d[:, 1] < height_img) + ) + + if np.any(valid_mask): + # Get color + bbox_color = colors[i] if i < len(colors) else (255, 255, 255) + draw_3d_bounding_box_on_image(result, corners_2d, bbox_color, thickness=2) + except: + continue + + return result + + +def create_3d_bounding_box_corners(position, rotation, size: int): # type: ignore[no-untyped-def] + """ + Create 8 corners of a 3D bounding box from position, rotation, and size. + + Args: + position: Vector or dict with x, y, z coordinates + rotation: Vector or dict with roll, pitch, yaw angles + size: Dict with width, height, depth + + Returns: + 8x3 numpy array of corner coordinates + """ + # Convert position to numpy array + if hasattr(position, "x"): # Vector object + center = np.array([position.x, position.y, position.z]) + else: # Dictionary + center = np.array([position["x"], position["y"], position["z"]]) + + # Convert rotation (euler angles) to rotation matrix + if hasattr(rotation, "x"): # Vector object (roll, pitch, yaw) + roll, pitch, yaw = rotation.x, rotation.y, rotation.z + else: # Dictionary + roll, pitch, yaw = rotation["roll"], rotation["pitch"], rotation["yaw"] + + # Create rotation matrix from euler angles (ZYX order) + cos_r, sin_r = np.cos(roll), np.sin(roll) + cos_p, sin_p = np.cos(pitch), np.sin(pitch) + cos_y, sin_y = np.cos(yaw), np.sin(yaw) + + # Rotation matrix for ZYX euler angles + R = np.array( + [ + [ + cos_y * cos_p, + cos_y * sin_p * sin_r - sin_y * cos_r, + cos_y * sin_p * cos_r + sin_y * sin_r, + ], + [ + sin_y * cos_p, + sin_y * sin_p * sin_r + cos_y * cos_r, + sin_y * sin_p * cos_r - cos_y * sin_r, + ], + [-sin_p, cos_p * sin_r, cos_p * cos_r], + ] + ) + + # Get dimensions + width = size.get("width", 0.1) # type: ignore[attr-defined] + height = size.get("height", 0.1) # type: ignore[attr-defined] + depth = size.get("depth", 0.1) # type: ignore[attr-defined] + + # Create 8 corners of the bounding box (before rotation) + corners = np.array( + [ + [-width / 2, -height / 2, -depth / 2], # 0 + [width / 2, -height / 2, -depth / 2], # 1 + [width / 2, height / 2, -depth / 2], # 2 + [-width / 2, height / 2, -depth / 2], # 3 + [-width / 2, -height / 2, depth / 2], # 4 + [width / 2, -height / 2, depth / 2], # 5 + [width / 2, height / 2, depth / 2], # 6 + [-width / 2, height / 2, depth / 2], # 7 + ] + ) + + # Apply rotation and translation + rotated_corners = corners @ R.T + center + + return rotated_corners + + +def draw_3d_bounding_box_on_image(image, corners_2d, color, thickness: int = 2) -> None: # type: ignore[no-untyped-def] + """ + Draw a 3D bounding box on an image using projected 2D corners. + + Args: + image: Image to draw on + corners_2d: 8x2 array of 2D corner coordinates + color: RGB color tuple + thickness: Line thickness + """ + # Define the 12 edges of a cube (connecting corner indices) + edges = [ + (0, 1), + (1, 2), + (2, 3), + (3, 0), # Bottom face + (4, 5), + (5, 6), + (6, 7), + (7, 4), # Top face + (0, 4), + (1, 5), + (2, 6), + (3, 7), # Vertical edges + ] + + # Draw each edge + for start_idx, end_idx in edges: + start_point = tuple(corners_2d[start_idx].astype(int)) + end_point = tuple(corners_2d[end_idx].astype(int)) + cv2.line(image, start_point, end_point, color, thickness) + + +def extract_and_cluster_misc_points( + full_pcd: o3d.geometry.PointCloud, + all_objects: list[dict], # type: ignore[type-arg] + eps: float = 0.03, + min_points: int = 100, + enable_filtering: bool = True, + voxel_size: float = 0.02, +) -> tuple[list[o3d.geometry.PointCloud], o3d.geometry.VoxelGrid]: + """ + Extract miscellaneous/background points and cluster them using DBSCAN. + + Args: + full_pcd: Complete scene point cloud + all_objects: List of objects with point clouds to subtract + eps: DBSCAN epsilon parameter (max distance between points in cluster) + min_points: DBSCAN min_samples parameter (min points to form cluster) + enable_filtering: Whether to apply statistical and radius filtering + voxel_size: Size of voxels for voxel grid generation + + Returns: + Tuple of (clustered_point_clouds, voxel_grid) + """ + if full_pcd is None or len(np.asarray(full_pcd.points)) == 0: + return [], o3d.geometry.VoxelGrid() + + if not all_objects: + # If no objects detected, cluster the full point cloud + clusters = _cluster_point_cloud_dbscan(full_pcd, eps, min_points) + voxel_grid = _create_voxel_grid_from_clusters(clusters, voxel_size) + return clusters, voxel_grid + + try: + # Start with a copy of the full point cloud + misc_pcd = o3d.geometry.PointCloud(full_pcd) + + # Remove object points by combining all object point clouds + all_object_points = [] + for obj in all_objects: + if "point_cloud" in obj and obj["point_cloud"] is not None: + obj_points = np.asarray(obj["point_cloud"].points) + if len(obj_points) > 0: + all_object_points.append(obj_points) + + if not all_object_points: + # No object points to remove, cluster full point cloud + clusters = _cluster_point_cloud_dbscan(misc_pcd, eps, min_points) + voxel_grid = _create_voxel_grid_from_clusters(clusters, voxel_size) + return clusters, voxel_grid + + # Combine all object points + combined_obj_points = np.vstack(all_object_points) + + # For efficiency, downsample both point clouds + misc_downsampled = misc_pcd.voxel_down_sample(voxel_size=0.005) + + # Create object point cloud for efficient operations + obj_pcd = o3d.geometry.PointCloud() + obj_pcd.points = o3d.utility.Vector3dVector(combined_obj_points) + obj_downsampled = obj_pcd.voxel_down_sample(voxel_size=0.005) + + misc_points = np.asarray(misc_downsampled.points) + obj_points_down = np.asarray(obj_downsampled.points) + + if len(misc_points) == 0 or len(obj_points_down) == 0: + clusters = _cluster_point_cloud_dbscan(misc_downsampled, eps, min_points) + voxel_grid = _create_voxel_grid_from_clusters(clusters, voxel_size) + return clusters, voxel_grid + + # Build tree for object points + obj_tree = cKDTree(obj_points_down) + + # Find distances from misc points to nearest object points + distances, _ = obj_tree.query(misc_points, k=1) + + # Keep points that are far enough from any object point + threshold = 0.015 # 1.5cm threshold + keep_mask = distances > threshold + + if not np.any(keep_mask): + return [], o3d.geometry.VoxelGrid() + + # Filter misc points + misc_indices = np.where(keep_mask)[0] + final_misc_pcd = misc_downsampled.select_by_index(misc_indices) + + if len(np.asarray(final_misc_pcd.points)) == 0: + return [], o3d.geometry.VoxelGrid() + + # Apply additional filtering if enabled + if enable_filtering: + # Apply statistical outlier filtering + filtered_misc_pcd, _ = filter_point_cloud_statistical( + final_misc_pcd, nb_neighbors=30, std_ratio=2.0 + ) + + if len(np.asarray(filtered_misc_pcd.points)) == 0: + return [], o3d.geometry.VoxelGrid() + + # Apply radius outlier filtering + final_filtered_misc_pcd, _ = filter_point_cloud_radius( + filtered_misc_pcd, + nb_points=20, + radius=0.03, # 3cm radius + ) + + if len(np.asarray(final_filtered_misc_pcd.points)) == 0: + return [], o3d.geometry.VoxelGrid() + + final_misc_pcd = final_filtered_misc_pcd + + # Cluster the misc points using DBSCAN + clusters = _cluster_point_cloud_dbscan(final_misc_pcd, eps, min_points) + + # Create voxel grid from all misc points (before clustering) + voxel_grid = _create_voxel_grid_from_point_cloud(final_misc_pcd, voxel_size) + + return clusters, voxel_grid + + except Exception as e: + print(f"Error in misc point extraction and clustering: {e}") + # Fallback: return downsampled full point cloud as single cluster + try: + downsampled = full_pcd.voxel_down_sample(voxel_size=0.02) + if len(np.asarray(downsampled.points)) > 0: + voxel_grid = _create_voxel_grid_from_point_cloud(downsampled, voxel_size) + return [downsampled], voxel_grid + else: + return [], o3d.geometry.VoxelGrid() + except: + return [], o3d.geometry.VoxelGrid() + + +def _create_voxel_grid_from_point_cloud( + pcd: o3d.geometry.PointCloud, voxel_size: float = 0.02 +) -> o3d.geometry.VoxelGrid: + """ + Create a voxel grid from a point cloud. + + Args: + pcd: Input point cloud + voxel_size: Size of each voxel + + Returns: + Open3D VoxelGrid object + """ + if len(np.asarray(pcd.points)) == 0: + return o3d.geometry.VoxelGrid() + + try: + # Create voxel grid from point cloud + voxel_grid = o3d.geometry.VoxelGrid.create_from_point_cloud(pcd, voxel_size) + + # Color the voxels with a semi-transparent gray + for voxel in voxel_grid.get_voxels(): + voxel.color = [0.5, 0.5, 0.5] # Gray color + + print( + f"Created voxel grid with {len(voxel_grid.get_voxels())} voxels (voxel_size={voxel_size})" + ) + return voxel_grid + + except Exception as e: + print(f"Error creating voxel grid: {e}") + return o3d.geometry.VoxelGrid() + + +def _create_voxel_grid_from_clusters( + clusters: list[o3d.geometry.PointCloud], voxel_size: float = 0.02 +) -> o3d.geometry.VoxelGrid: + """ + Create a voxel grid from multiple clustered point clouds. + + Args: + clusters: List of clustered point clouds + voxel_size: Size of each voxel + + Returns: + Open3D VoxelGrid object + """ + if not clusters: + return o3d.geometry.VoxelGrid() + + # Combine all clusters into one point cloud + combined_points = [] + for cluster in clusters: + points = np.asarray(cluster.points) + if len(points) > 0: + combined_points.append(points) + + if not combined_points: + return o3d.geometry.VoxelGrid() + + # Create combined point cloud + all_points = np.vstack(combined_points) + combined_pcd = o3d.geometry.PointCloud() + combined_pcd.points = o3d.utility.Vector3dVector(all_points) + + return _create_voxel_grid_from_point_cloud(combined_pcd, voxel_size) + + +def _cluster_point_cloud_dbscan( + pcd: o3d.geometry.PointCloud, eps: float = 0.05, min_points: int = 50 +) -> list[o3d.geometry.PointCloud]: + """ + Cluster a point cloud using DBSCAN and return list of clustered point clouds. + + Args: + pcd: Point cloud to cluster + eps: DBSCAN epsilon parameter + min_points: DBSCAN min_samples parameter + + Returns: + List of point clouds, one for each cluster + """ + if len(np.asarray(pcd.points)) == 0: + return [] + + try: + # Apply DBSCAN clustering + labels = np.array(pcd.cluster_dbscan(eps=eps, min_points=min_points)) + + # Get unique cluster labels (excluding noise points labeled as -1) + unique_labels = np.unique(labels) + cluster_pcds = [] + + for label in unique_labels: + if label == -1: # Skip noise points + continue + + # Get indices for this cluster + cluster_indices = np.where(labels == label)[0] + + if len(cluster_indices) > 0: + # Create point cloud for this cluster + cluster_pcd = pcd.select_by_index(cluster_indices) + + # Assign a random color to this cluster + cluster_color = np.random.rand(3) # Random RGB color + cluster_pcd.paint_uniform_color(cluster_color) + + cluster_pcds.append(cluster_pcd) + + print( + f"DBSCAN clustering found {len(cluster_pcds)} clusters from {len(np.asarray(pcd.points))} points" + ) + return cluster_pcds + + except Exception as e: + print(f"Error in DBSCAN clustering: {e}") + return [pcd] # Return original point cloud as fallback + + +def get_standard_coordinate_transform(): # type: ignore[no-untyped-def] + """ + Get a standard coordinate transformation matrix for consistent visualization. + + This transformation ensures that: + - X (red) axis points right + - Y (green) axis points up + - Z (blue) axis points toward viewer + + Returns: + 4x4 transformation matrix + """ + # Standard transformation matrix to ensure consistent coordinate frame orientation + transform = np.array( + [ + [1, 0, 0, 0], # X points right + [0, -1, 0, 0], # Y points up (flip from OpenCV to standard) + [0, 0, -1, 0], # Z points toward viewer (flip depth) + [0, 0, 0, 1], + ] + ) + return transform + + +def visualize_clustered_point_clouds( + clustered_pcds: list[o3d.geometry.PointCloud], + window_name: str = "Clustered Point Clouds", + point_size: float = 2.0, + show_coordinate_frame: bool = True, + coordinate_frame_size: float = 0.1, +) -> None: + """ + Visualize multiple clustered point clouds with different colors. + + Args: + clustered_pcds: List of point clouds (already colored) + window_name: Name of the visualization window + point_size: Size of points in the visualization + show_coordinate_frame: Whether to show coordinate frame + coordinate_frame_size: Size of the coordinate frame + """ + if not clustered_pcds: + print("Warning: No clustered point clouds to visualize") + return + + # Apply standard coordinate transformation + transform = get_standard_coordinate_transform() # type: ignore[no-untyped-call] + geometries = [] + for pcd in clustered_pcds: + pcd_copy = o3d.geometry.PointCloud(pcd) + pcd_copy.transform(transform) + geometries.append(pcd_copy) + + # Add coordinate frame + if show_coordinate_frame: + coordinate_frame = o3d.geometry.TriangleMesh.create_coordinate_frame( + size=coordinate_frame_size + ) + coordinate_frame.transform(transform) + geometries.append(coordinate_frame) + + total_points = sum(len(np.asarray(pcd.points)) for pcd in clustered_pcds) + print(f"Visualizing {len(clustered_pcds)} clusters with {total_points} total points") + + try: + vis = o3d.visualization.Visualizer() + vis.create_window(window_name=window_name, width=1280, height=720) + for geom in geometries: + vis.add_geometry(geom) + render_option = vis.get_render_option() + render_option.point_size = point_size + vis.run() + vis.destroy_window() + except Exception as e: + print(f"Failed to create interactive visualization: {e}") + o3d.visualization.draw_geometries( + geometries, window_name=window_name, width=1280, height=720 + ) + + +def visualize_pcd( + pcd: o3d.geometry.PointCloud, + window_name: str = "Point Cloud Visualization", + point_size: float = 1.0, + show_coordinate_frame: bool = True, + coordinate_frame_size: float = 0.1, +) -> None: + """ + Visualize an Open3D point cloud using Open3D's visualization window. + + Args: + pcd: Open3D point cloud to visualize + window_name: Name of the visualization window + point_size: Size of points in the visualization + show_coordinate_frame: Whether to show coordinate frame + coordinate_frame_size: Size of the coordinate frame + """ + if pcd is None: + print("Warning: Point cloud is None, nothing to visualize") + return + + if len(np.asarray(pcd.points)) == 0: + print("Warning: Point cloud is empty, nothing to visualize") + return + + # Apply standard coordinate transformation + transform = get_standard_coordinate_transform() # type: ignore[no-untyped-call] + pcd_copy = o3d.geometry.PointCloud(pcd) + pcd_copy.transform(transform) + geometries = [pcd_copy] + + # Add coordinate frame + if show_coordinate_frame: + coordinate_frame = o3d.geometry.TriangleMesh.create_coordinate_frame( + size=coordinate_frame_size + ) + coordinate_frame.transform(transform) + geometries.append(coordinate_frame) + + print(f"Visualizing point cloud with {len(np.asarray(pcd.points))} points") + + try: + vis = o3d.visualization.Visualizer() + vis.create_window(window_name=window_name, width=1280, height=720) + for geom in geometries: + vis.add_geometry(geom) + render_option = vis.get_render_option() + render_option.point_size = point_size + vis.run() + vis.destroy_window() + except Exception as e: + print(f"Failed to create interactive visualization: {e}") + o3d.visualization.draw_geometries( + geometries, window_name=window_name, width=1280, height=720 + ) + + +def visualize_voxel_grid( + voxel_grid: o3d.geometry.VoxelGrid, + window_name: str = "Voxel Grid Visualization", + show_coordinate_frame: bool = True, + coordinate_frame_size: float = 0.1, +) -> None: + """ + Visualize an Open3D voxel grid using Open3D's visualization window. + + Args: + voxel_grid: Open3D voxel grid to visualize + window_name: Name of the visualization window + show_coordinate_frame: Whether to show coordinate frame + coordinate_frame_size: Size of the coordinate frame + """ + if voxel_grid is None: + print("Warning: Voxel grid is None, nothing to visualize") + return + + if len(voxel_grid.get_voxels()) == 0: + print("Warning: Voxel grid is empty, nothing to visualize") + return + + # VoxelGrid doesn't support transform, so we need to transform the source points instead + # For now, just visualize as-is with transformed coordinate frame + geometries = [voxel_grid] + + # Add coordinate frame + if show_coordinate_frame: + coordinate_frame = o3d.geometry.TriangleMesh.create_coordinate_frame( + size=coordinate_frame_size + ) + coordinate_frame.transform(get_standard_coordinate_transform()) # type: ignore[no-untyped-call] + geometries.append(coordinate_frame) + + print(f"Visualizing voxel grid with {len(voxel_grid.get_voxels())} voxels") + + try: + vis = o3d.visualization.Visualizer() + vis.create_window(window_name=window_name, width=1280, height=720) + for geom in geometries: + vis.add_geometry(geom) + vis.run() + vis.destroy_window() + except Exception as e: + print(f"Failed to create interactive visualization: {e}") + o3d.visualization.draw_geometries( + geometries, window_name=window_name, width=1280, height=720 + ) + + +def combine_object_pointclouds( + point_clouds: list[np.ndarray] | list[o3d.geometry.PointCloud], # type: ignore[type-arg] + colors: list[np.ndarray] | None = None, # type: ignore[type-arg] +) -> o3d.geometry.PointCloud: + """ + Combine multiple point clouds into a single Open3D point cloud. + + Args: + point_clouds: List of point clouds as numpy arrays or Open3D point clouds + colors: List of colors as numpy arrays + Returns: + Combined Open3D point cloud + """ + all_points = [] + all_colors = [] + + for i, pcd in enumerate(point_clouds): + if isinstance(pcd, np.ndarray): + points = pcd[:, :3] + all_points.append(points) + if colors: + all_colors.append(colors[i]) + + elif isinstance(pcd, o3d.geometry.PointCloud): + points = np.asarray(pcd.points) + all_points.append(points) + if pcd.has_colors(): + colors = np.asarray(pcd.colors) # type: ignore[assignment] + all_colors.append(colors) # type: ignore[arg-type] + + if not all_points: + return o3d.geometry.PointCloud() + + combined_pcd = o3d.geometry.PointCloud() + combined_pcd.points = o3d.utility.Vector3dVector(np.vstack(all_points)) + + if all_colors: + combined_pcd.colors = o3d.utility.Vector3dVector(np.vstack(all_colors)) + + return combined_pcd + + +def extract_centroids_from_masks( + rgb_image: np.ndarray, # type: ignore[type-arg] + depth_image: np.ndarray, # type: ignore[type-arg] + masks: list[np.ndarray], # type: ignore[type-arg] + camera_intrinsics: list[float] | np.ndarray, # type: ignore[type-arg] +) -> list[dict[str, Any]]: + """ + Extract 3D centroids and orientations from segmentation masks. + + Args: + rgb_image: RGB image (H, W, 3) + depth_image: Depth image (H, W) in meters + masks: List of boolean masks (H, W) + camera_intrinsics: Camera parameters as [fx, fy, cx, cy] or 3x3 matrix + + Returns: + List of dictionaries containing: + - centroid: 3D centroid position [x, y, z] in camera frame + - orientation: Normalized direction vector from camera to centroid + - num_points: Number of valid 3D points + - mask_idx: Index of the mask in the input list + """ + # Extract camera parameters + if isinstance(camera_intrinsics, list) and len(camera_intrinsics) == 4: + fx, fy, cx, cy = camera_intrinsics + else: + fx = camera_intrinsics[0, 0] # type: ignore[call-overload] + fy = camera_intrinsics[1, 1] # type: ignore[call-overload] + cx = camera_intrinsics[0, 2] # type: ignore[call-overload] + cy = camera_intrinsics[1, 2] # type: ignore[call-overload] + + results = [] + + for mask_idx, mask in enumerate(masks): + if mask is None or mask.sum() == 0: + continue + + # Get pixel coordinates where mask is True + y_coords, x_coords = np.where(mask) + + # Get depth values at mask locations + depths = depth_image[y_coords, x_coords] + + # Convert to 3D points in camera frame + X = (x_coords - cx) * depths / fx + Y = (y_coords - cy) * depths / fy + Z = depths + + # Calculate centroid + centroid_x = np.mean(X) + centroid_y = np.mean(Y) + centroid_z = np.mean(Z) + centroid = np.array([centroid_x, centroid_y, centroid_z]) + + # Calculate orientation as normalized direction from camera origin to centroid + # Camera origin is at (0, 0, 0) + orientation = centroid / np.linalg.norm(centroid) + + results.append( + { + "centroid": centroid, + "orientation": orientation, + "num_points": int(mask.sum()), + "mask_idx": mask_idx, + } + ) + + return results diff --git a/dimos/perception/segmentation/__init__.py b/dimos/perception/segmentation/__init__.py new file mode 100644 index 0000000000..a48a76d6a4 --- /dev/null +++ b/dimos/perception/segmentation/__init__.py @@ -0,0 +1,2 @@ +from .sam_2d_seg import * +from .utils import * diff --git a/dimos/perception/segmentation/config/custom_tracker.yaml b/dimos/perception/segmentation/config/custom_tracker.yaml new file mode 100644 index 0000000000..7a6748ebf6 --- /dev/null +++ b/dimos/perception/segmentation/config/custom_tracker.yaml @@ -0,0 +1,21 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Default Ultralytics settings for BoT-SORT tracker when using mode="track" +# For documentation and examples see https://docs.ultralytics.com/modes/track/ +# For BoT-SORT source code see https://github.com/NirAharon/BoT-SORT + +tracker_type: botsort # tracker type, ['botsort', 'bytetrack'] +track_high_thresh: 0.4 # threshold for the first association +track_low_thresh: 0.2 # threshold for the second association +new_track_thresh: 0.5 # threshold for init new track if the detection does not match any tracks +track_buffer: 100 # buffer to calculate the time when to remove tracks +match_thresh: 0.4 # threshold for matching tracks +fuse_score: False # Whether to fuse confidence scores with the iou distances before matching +# min_box_area: 10 # threshold for min box areas(for tracker evaluation, not used for now) + +# BoT-SORT settings +gmc_method: sparseOptFlow # method of global motion compensation +# ReID model related thresh (not supported yet) +proximity_thresh: 0.6 +appearance_thresh: 0.35 +with_reid: False diff --git a/dimos/perception/segmentation/image_analyzer.py b/dimos/perception/segmentation/image_analyzer.py new file mode 100644 index 0000000000..06db712ac7 --- /dev/null +++ b/dimos/perception/segmentation/image_analyzer.py @@ -0,0 +1,162 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +import os + +import cv2 +from openai import OpenAI + +NORMAL_PROMPT = "What are in these images? Give a short word answer with at most two words, \ + if not sure, give a description of its shape or color like 'small tube', 'blue item'. \" \ + if does not look like an object, say 'unknown'. Export objects as a list of strings \ + in this exact format '['object 1', 'object 2', '...']'." + +RICH_PROMPT = ( + "What are in these images? Give a detailed description of each item, the first n images will be \ + cropped patches of the original image detected by the object detection model. \ + The last image will be the original image. Use the last image only for context, \ + do not describe objects in the last image. \ + Export the objects as a list of strings in this exact format, '['description of object 1', '...', '...']', \ + don't include anything else. " +) + + +class ImageAnalyzer: + def __init__(self) -> None: + """ + Initializes the ImageAnalyzer with OpenAI API credentials. + """ + self.client = OpenAI() + + def encode_image(self, image): # type: ignore[no-untyped-def] + """ + Encodes an image to Base64. + + Parameters: + image (numpy array): Image array (BGR format). + + Returns: + str: Base64 encoded string of the image. + """ + _, buffer = cv2.imencode(".jpg", image) + return base64.b64encode(buffer).decode("utf-8") + + def analyze_images(self, images, detail: str = "auto", prompt_type: str = "normal"): # type: ignore[no-untyped-def] + """ + Takes a list of cropped images and returns descriptions from OpenAI's Vision model. + + Parameters: + images (list of numpy arrays): Cropped images from the original frame. + detail (str): "low", "high", or "auto" to set image processing detail. + prompt_type (str): "normal" or "rich" to set the prompt type. + + Returns: + list of str: Descriptions of objects in each image. + """ + image_data = [ + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{self.encode_image(img)}", # type: ignore[no-untyped-call] + "detail": detail, + }, + } + for img in images + ] + + if prompt_type == "normal": + prompt = NORMAL_PROMPT + elif prompt_type == "rich": + prompt = RICH_PROMPT + else: + raise ValueError(f"Invalid prompt type: {prompt_type}") + + response = self.client.chat.completions.create( + model="gpt-4o-mini", + messages=[ + { # type: ignore[list-item, misc] + "role": "user", + "content": [{"type": "text", "text": prompt}, *image_data], + } + ], + max_tokens=300, + timeout=5, + ) + + # Accessing the content of the response using dot notation + return next(choice.message.content for choice in response.choices) + + +def main() -> None: + # Define the directory containing cropped images + cropped_images_dir = "cropped_images" + if not os.path.exists(cropped_images_dir): + print(f"Directory '{cropped_images_dir}' does not exist.") + return + + # Load all images from the directory + images = [] + for filename in os.listdir(cropped_images_dir): + if filename.endswith(".jpg") or filename.endswith(".png"): + image_path = os.path.join(cropped_images_dir, filename) + image = cv2.imread(image_path) + if image is not None: + images.append(image) + else: + print(f"Warning: Could not read image {image_path}") + + if not images: + print("No valid images found in the directory.") + return + + # Initialize ImageAnalyzer + analyzer = ImageAnalyzer() + + # Analyze images + results = analyzer.analyze_images(images) + + # Split results into a list of items + object_list = [item.strip()[2:] for item in results.split("\n")] + + # Overlay text on images and display them + for i, (img, obj) in enumerate(zip(images, object_list, strict=False)): + if obj: # Only process non-empty lines + # Add text to image + font = cv2.FONT_HERSHEY_SIMPLEX + font_scale = 0.5 + thickness = 2 + text = obj.strip() + + # Get text size + (text_width, text_height), _ = cv2.getTextSize(text, font, font_scale, thickness) + + # Position text at top of image + x = 10 + y = text_height + 10 + + # Add white background for text + cv2.rectangle( + img, (x - 5, y - text_height - 5), (x + text_width + 5, y + 5), (255, 255, 255), -1 + ) + # Add text + cv2.putText(img, text, (x, y), font, font_scale, (0, 0, 0), thickness) + + # Save or display the image + cv2.imwrite(f"annotated_image_{i}.jpg", img) + print(f"Detected object: {obj}") + + +if __name__ == "__main__": + main() diff --git a/dimos/perception/segmentation/sam_2d_seg.py b/dimos/perception/segmentation/sam_2d_seg.py new file mode 100644 index 0000000000..171da88780 --- /dev/null +++ b/dimos/perception/segmentation/sam_2d_seg.py @@ -0,0 +1,360 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import deque +from collections.abc import Sequence +from concurrent.futures import ThreadPoolExecutor +import os +import time + +import cv2 +import onnxruntime # type: ignore[import-untyped] +from ultralytics import FastSAM + +from dimos.perception.common.detection2d_tracker import get_tracked_results, target2dTracker +from dimos.perception.segmentation.image_analyzer import ImageAnalyzer +from dimos.perception.segmentation.utils import ( + crop_images_from_bboxes, + extract_masks_bboxes_probs_names, + filter_segmentation_results, + plot_results, +) +from dimos.utils.data import get_data +from dimos.utils.gpu_utils import is_cuda_available +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class Sam2DSegmenter: + def __init__( + self, + model_path: str = "models_fastsam", + model_name: str = "FastSAM-s.onnx", + min_analysis_interval: float = 5.0, + use_tracker: bool = True, + use_analyzer: bool = True, + use_rich_labeling: bool = False, + use_filtering: bool = True, + ) -> None: + if is_cuda_available(): # type: ignore[no-untyped-call] + logger.info("Using CUDA for SAM 2d segmenter") + if hasattr(onnxruntime, "preload_dlls"): # Handles CUDA 11 / onnxruntime-gpu<=1.18 + onnxruntime.preload_dlls(cuda=True, cudnn=True) + self.device = "cuda" + else: + logger.info("Using CPU for SAM 2d segmenter") + self.device = "cpu" + # Core components + self.model = FastSAM(get_data(model_path) / model_name) + self.use_tracker = use_tracker + self.use_analyzer = use_analyzer + self.use_rich_labeling = use_rich_labeling + self.use_filtering = use_filtering + + module_dir = os.path.dirname(__file__) + self.tracker_config = os.path.join(module_dir, "config", "custom_tracker.yaml") + + # Initialize tracker if enabled + if self.use_tracker: + self.tracker = target2dTracker( + history_size=80, + score_threshold_start=0.7, + score_threshold_stop=0.05, + min_frame_count=10, + max_missed_frames=50, + min_area_ratio=0.05, + max_area_ratio=0.4, + texture_range=(0.0, 0.35), + border_safe_distance=100, + weights={"prob": 1.0, "temporal": 3.0, "texture": 2.0, "border": 3.0, "size": 1.0}, + ) + + # Initialize analyzer components if enabled + if self.use_analyzer: + self.image_analyzer = ImageAnalyzer() + self.min_analysis_interval = min_analysis_interval + self.last_analysis_time = 0 + self.to_be_analyzed = deque() # type: ignore[var-annotated] + self.object_names = {} # type: ignore[var-annotated] + self.analysis_executor = ThreadPoolExecutor(max_workers=1) + self.current_future = None + self.current_queue_ids = None + + def process_image(self, image): # type: ignore[no-untyped-def] + """Process an image and return segmentation results.""" + results = self.model.track( + source=image, + device=self.device, + retina_masks=True, + conf=0.3, + iou=0.5, + persist=True, + verbose=False, + ) + + if len(results) > 0: + # Get initial segmentation results + masks, bboxes, track_ids, probs, names, areas = extract_masks_bboxes_probs_names( + results[0] + ) + + # Filter results + if self.use_filtering: + ( + filtered_masks, + filtered_bboxes, + filtered_track_ids, + filtered_probs, + filtered_names, + filtered_texture_values, + ) = filter_segmentation_results( + image, masks, bboxes, track_ids, probs, names, areas + ) + else: + # Use original results without filtering + filtered_masks = masks + filtered_bboxes = bboxes + filtered_track_ids = track_ids + filtered_probs = probs + filtered_names = names + filtered_texture_values = [] + + if self.use_tracker: + # Update tracker with filtered results + tracked_targets = self.tracker.update( + image, + filtered_masks, + filtered_bboxes, + filtered_track_ids, + filtered_probs, + filtered_names, + filtered_texture_values, + ) + + # Get tracked results + tracked_masks, tracked_bboxes, tracked_target_ids, tracked_probs, tracked_names = ( + get_tracked_results(tracked_targets) # type: ignore[no-untyped-call] + ) + + if self.use_analyzer: + # Update analysis queue with tracked IDs + target_id_set = set(tracked_target_ids) + + # Remove untracked objects from object_names + all_target_ids = list(self.tracker.targets.keys()) + self.object_names = { + track_id: name + for track_id, name in self.object_names.items() + if track_id in all_target_ids + } + + # Remove untracked objects from queue and results + self.to_be_analyzed = deque( + [track_id for track_id in self.to_be_analyzed if track_id in target_id_set] + ) + + # Filter out any IDs being analyzed from the to_be_analyzed queue + if self.current_queue_ids: + self.to_be_analyzed = deque( + [ + tid + for tid in self.to_be_analyzed + if tid not in self.current_queue_ids + ] + ) + + # Add new track_ids to analysis queue + for track_id in tracked_target_ids: + if ( + track_id not in self.object_names + and track_id not in self.to_be_analyzed + ): + self.to_be_analyzed.append(track_id) + + return ( + tracked_masks, + tracked_bboxes, + tracked_target_ids, + tracked_probs, + tracked_names, + ) + else: + # When tracker disabled, just use the filtered results directly + if self.use_analyzer: + # Add unanalyzed IDs to the analysis queue + for track_id in filtered_track_ids: + if ( + track_id not in self.object_names + and track_id not in self.to_be_analyzed + ): + self.to_be_analyzed.append(track_id) + + # Simply return filtered results + return ( + filtered_masks, + filtered_bboxes, + filtered_track_ids, + filtered_probs, + filtered_names, + ) + return [], [], [], [], [] + + def check_analysis_status(self, tracked_target_ids): # type: ignore[no-untyped-def] + """Check if analysis is complete and prepare new queue if needed.""" + if not self.use_analyzer: + return None, None + + current_time = time.time() + + # Check if current queue analysis is complete + if self.current_future and self.current_future.done(): + try: + results = self.current_future.result() + if results is not None: + # Map results to track IDs + object_list = eval(results) + for track_id, result in zip(self.current_queue_ids, object_list, strict=False): + self.object_names[track_id] = result + except Exception as e: + print(f"Queue analysis failed: {e}") + self.current_future = None + self.current_queue_ids = None + self.last_analysis_time = current_time + + # If enough time has passed and we have items to analyze, start new analysis + if ( + not self.current_future + and self.to_be_analyzed + and current_time - self.last_analysis_time >= self.min_analysis_interval + ): + queue_indices = [] + queue_ids = [] + + # Collect all valid track IDs from the queue + while self.to_be_analyzed: + track_id = self.to_be_analyzed[0] + if track_id in tracked_target_ids: + bbox_idx = tracked_target_ids.index(track_id) + queue_indices.append(bbox_idx) + queue_ids.append(track_id) + self.to_be_analyzed.popleft() + + if queue_indices: + return queue_indices, queue_ids + return None, None + + def run_analysis(self, frame, tracked_bboxes, tracked_target_ids) -> None: # type: ignore[no-untyped-def] + """Run queue image analysis in background.""" + if not self.use_analyzer: + return + + queue_indices, queue_ids = self.check_analysis_status(tracked_target_ids) # type: ignore[no-untyped-call] + if queue_indices: + selected_bboxes = [tracked_bboxes[i] for i in queue_indices] + cropped_images = crop_images_from_bboxes(frame, selected_bboxes) + if cropped_images: + self.current_queue_ids = queue_ids + print(f"Analyzing objects with track_ids: {queue_ids}") + + if self.use_rich_labeling: + prompt_type = "rich" + cropped_images.append(frame) + else: + prompt_type = "normal" + + self.current_future = self.analysis_executor.submit( # type: ignore[assignment] + self.image_analyzer.analyze_images, cropped_images, prompt_type=prompt_type + ) + + def get_object_names(self, track_ids, tracked_names: Sequence[str]): # type: ignore[no-untyped-def] + """Get object names for the given track IDs, falling back to tracked names.""" + if not self.use_analyzer: + return tracked_names + + return [ + self.object_names.get(track_id, tracked_name) + for track_id, tracked_name in zip(track_ids, tracked_names, strict=False) + ] + + def visualize_results( # type: ignore[no-untyped-def] + self, image, masks, bboxes, track_ids, probs: Sequence[float], names: Sequence[str] + ): + """Generate an overlay visualization with segmentation results and object names.""" + return plot_results(image, masks, bboxes, track_ids, probs, names) + + def cleanup(self) -> None: + """Cleanup resources.""" + if self.use_analyzer: + self.analysis_executor.shutdown() + + +def main() -> None: + # Example usage with different configurations + cap = cv2.VideoCapture(0) + + # Example 1: Full functionality with rich labeling + segmenter = Sam2DSegmenter( + min_analysis_interval=4.0, + use_tracker=True, + use_analyzer=True, + use_rich_labeling=True, # Enable rich labeling + ) + + # Example 2: Full functionality with normal labeling + # segmenter = Sam2DSegmenter(min_analysis_interval=4.0, use_tracker=True, use_analyzer=True) + + # Example 3: Tracker only (analyzer disabled) + # segmenter = Sam2DSegmenter(use_analyzer=False) + + # Example 4: Basic segmentation only (both tracker and analyzer disabled) + # segmenter = Sam2DSegmenter(use_tracker=False, use_analyzer=False) + + # Example 5: Analyzer without tracker (new capability) + # segmenter = Sam2DSegmenter(use_tracker=False, use_analyzer=True) + + try: + while cap.isOpened(): + ret, frame = cap.read() + if not ret: + break + + time.time() + + # Process image and get results + masks, bboxes, target_ids, probs, names = segmenter.process_image(frame) # type: ignore[no-untyped-call] + + # Run analysis if enabled + if segmenter.use_analyzer: + segmenter.run_analysis(frame, bboxes, target_ids) + names = segmenter.get_object_names(target_ids, names) + + # processing_time = time.time() - start_time + # print(f"Processing time: {processing_time:.2f}s") + + overlay = segmenter.visualize_results(frame, masks, bboxes, target_ids, probs, names) + + cv2.imshow("Segmentation", overlay) + key = cv2.waitKey(1) + if key & 0xFF == ord("q"): + break + + finally: + segmenter.cleanup() + cap.release() + cv2.destroyAllWindows() + + +if __name__ == "__main__": + main() diff --git a/dimos/perception/segmentation/test_sam_2d_seg.py b/dimos/perception/segmentation/test_sam_2d_seg.py new file mode 100644 index 0000000000..a9222ed2f2 --- /dev/null +++ b/dimos/perception/segmentation/test_sam_2d_seg.py @@ -0,0 +1,210 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import time + +import numpy as np +import pytest +from reactivex import operators as ops + +from dimos.perception.segmentation.sam_2d_seg import Sam2DSegmenter +from dimos.perception.segmentation.utils import extract_masks_bboxes_probs_names +from dimos.stream.video_provider import VideoProvider + + +@pytest.mark.heavy +class TestSam2DSegmenter: + def test_sam_segmenter_initialization(self) -> None: + """Test FastSAM segmenter initializes correctly with default model path.""" + try: + # Try to initialize with the default model path and existing device setting + segmenter = Sam2DSegmenter(use_analyzer=False) + assert segmenter is not None + assert segmenter.model is not None + except Exception as e: + # If the model file doesn't exist, the test should still pass with a warning + pytest.skip(f"Skipping test due to model initialization error: {e}") + + def test_sam_segmenter_process_image(self) -> None: + """Test FastSAM segmenter can process video frames and return segmentation masks.""" + # Import get data inside method to avoid pytest fixture confusion + from dimos.utils.data import get_data + + # Get test video path directly + video_path = get_data("assets") / "trimmed_video_office.mov" + try: + # Initialize segmenter without analyzer for faster testing + segmenter = Sam2DSegmenter(use_analyzer=False) + + # Note: conf and iou are parameters for process_image, not constructor + # We'll monkey patch the process_image method to use lower thresholds + + def patched_process_image(image): + results = segmenter.model.track( + source=image, + device=segmenter.device, + retina_masks=True, + conf=0.1, # Lower confidence threshold for testing + iou=0.5, # Lower IoU threshold + persist=True, + verbose=False, + tracker=segmenter.tracker_config + if hasattr(segmenter, "tracker_config") + else None, + ) + + if len(results) > 0: + masks, bboxes, track_ids, probs, names, _areas = ( + extract_masks_bboxes_probs_names(results[0]) + ) + return masks, bboxes, track_ids, probs, names + return [], [], [], [], [] + + # Replace the method + segmenter.process_image = patched_process_image + + # Create video provider and directly get a video stream observable + assert os.path.exists(video_path), f"Test video not found: {video_path}" + video_provider = VideoProvider(dev_name="test_video", video_source=video_path) + + video_stream = video_provider.capture_video_as_observable(realtime=False, fps=1) + + # Use ReactiveX operators to process the stream + def process_frame(frame): + try: + # Process frame with FastSAM + masks, bboxes, track_ids, probs, names = segmenter.process_image(frame) + print( + f"SAM results - masks: {len(masks)}, bboxes: {len(bboxes)}, track_ids: {len(track_ids)}, names: {len(names)}" + ) + + return { + "frame": frame, + "masks": masks, + "bboxes": bboxes, + "track_ids": track_ids, + "probs": probs, + "names": names, + } + except Exception as e: + print(f"Error in process_frame: {e}") + return {} + + # Create the segmentation stream using pipe and map operator + segmentation_stream = video_stream.pipe(ops.map(process_frame)) + + # Collect results from the stream + results = [] + frames_processed = 0 + target_frames = 5 + + def on_next(result) -> None: + nonlocal frames_processed, results + if not result: + return + + results.append(result) + frames_processed += 1 + + # Stop processing after target frames + if frames_processed >= target_frames: + subscription.dispose() + + def on_error(error) -> None: + pytest.fail(f"Error in segmentation stream: {error}") + + def on_completed() -> None: + pass + + # Subscribe and wait for results + subscription = segmentation_stream.subscribe( + on_next=on_next, on_error=on_error, on_completed=on_completed + ) + + # Wait for frames to be processed + timeout = 30.0 # seconds + start_time = time.time() + while frames_processed < target_frames and time.time() - start_time < timeout: + time.sleep(0.5) + + # Clean up subscription + subscription.dispose() + video_provider.dispose_all() + + # Check if we have results + if len(results) == 0: + pytest.skip( + "No segmentation results found, but test connection established correctly" + ) + return + + print(f"Processed {len(results)} frames with segmentation results") + + # Analyze the first result + result = results[0] + + # Check that we have a frame + assert "frame" in result, "Result doesn't contain a frame" + assert isinstance(result["frame"], np.ndarray), "Frame is not a numpy array" + + # Check that segmentation results are valid + assert isinstance(result["masks"], list) + assert isinstance(result["bboxes"], list) + assert isinstance(result["track_ids"], list) + assert isinstance(result["probs"], list) + assert isinstance(result["names"], list) + + # All result lists should be the same length + assert ( + len(result["masks"]) + == len(result["bboxes"]) + == len(result["track_ids"]) + == len(result["probs"]) + == len(result["names"]) + ) + + # If we have masks, check that they have valid shape + if result.get("masks") and len(result["masks"]) > 0: + assert result["masks"][0].shape == ( + result["frame"].shape[0], + result["frame"].shape[1], + ), "Mask shape should match image dimensions" + print(f"Found {len(result['masks'])} masks in first frame") + else: + print("No masks found in first frame, but test connection established correctly") + + # Test visualization function + if result["masks"]: + vis_frame = segmenter.visualize_results( + result["frame"], + result["masks"], + result["bboxes"], + result["track_ids"], + result["probs"], + result["names"], + ) + assert isinstance(vis_frame, np.ndarray), "Visualization output should be an image" + assert vis_frame.shape == result["frame"].shape, ( + "Visualization should have same dimensions as input frame" + ) + + # We've already tested visualization above, so no need for a duplicate test + + except Exception as e: + pytest.skip(f"Skipping test due to error: {e}") + + +if __name__ == "__main__": + pytest.main(["-v", __file__]) diff --git a/dimos/perception/segmentation/utils.py b/dimos/perception/segmentation/utils.py new file mode 100644 index 0000000000..a23a256ca2 --- /dev/null +++ b/dimos/perception/segmentation/utils.py @@ -0,0 +1,343 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Sequence + +import cv2 +import numpy as np +import torch + + +class SimpleTracker: + def __init__( + self, history_size: int = 100, min_count: int = 10, count_window: int = 20 + ) -> None: + """ + Simple temporal tracker that counts appearances in a fixed window. + :param history_size: Number of past frames to remember + :param min_count: Minimum number of appearances required + :param count_window: Number of latest frames to consider for counting + """ + self.history = [] # type: ignore[var-annotated] + self.history_size = history_size + self.min_count = min_count + self.count_window = count_window + self.total_counts = {} # type: ignore[var-annotated] + + def update(self, track_ids): # type: ignore[no-untyped-def] + # Add new frame's track IDs to history + self.history.append(track_ids) + if len(self.history) > self.history_size: + self.history.pop(0) + + # Consider only the latest `count_window` frames for counting + recent_history = self.history[-self.count_window :] + all_tracks = np.concatenate(recent_history) if recent_history else np.array([]) + + # Compute occurrences efficiently using numpy + unique_ids, counts = np.unique(all_tracks, return_counts=True) + id_counts = dict(zip(unique_ids, counts, strict=False)) + + # Update total counts but ensure it only contains IDs within the history size + total_tracked_ids = np.concatenate(self.history) if self.history else np.array([]) + unique_total_ids, total_counts = np.unique(total_tracked_ids, return_counts=True) + self.total_counts = dict(zip(unique_total_ids, total_counts, strict=False)) + + # Return IDs that appear often enough + return [track_id for track_id, count in id_counts.items() if count >= self.min_count] + + def get_total_counts(self): # type: ignore[no-untyped-def] + """Returns the total count of each tracking ID seen over time, limited to history size.""" + return self.total_counts + + +def extract_masks_bboxes_probs_names(result, max_size: float = 0.7): # type: ignore[no-untyped-def] + """ + Extracts masks, bounding boxes, probabilities, and class names from one Ultralytics result object. + + Parameters: + result: Ultralytics result object + max_size: float, maximum allowed size of object relative to image (0-1) + + Returns: + tuple: (masks, bboxes, track_ids, probs, names, areas) + """ + masks = [] # type: ignore[var-annotated] + bboxes = [] # type: ignore[var-annotated] + track_ids = [] # type: ignore[var-annotated] + probs = [] # type: ignore[var-annotated] + names = [] # type: ignore[var-annotated] + areas = [] # type: ignore[var-annotated] + + if result.masks is None: + return masks, bboxes, track_ids, probs, names, areas + + total_area = result.masks.orig_shape[0] * result.masks.orig_shape[1] + + for box, mask_data in zip(result.boxes, result.masks.data, strict=False): + mask_numpy = mask_data + + # Extract bounding box + x1, y1, x2, y2 = box.xyxy[0].tolist() + + # Extract track_id if available + track_id = -1 # default if no tracking + if hasattr(box, "id") and box.id is not None: + track_id = int(box.id[0].item()) + + # Extract probability and class index + conf = float(box.conf[0]) + cls_idx = int(box.cls[0]) + area = (x2 - x1) * (y2 - y1) + + if area / total_area > max_size: + continue + + masks.append(mask_numpy) + bboxes.append([x1, y1, x2, y2]) + track_ids.append(track_id) + probs.append(conf) + names.append(result.names[cls_idx]) + areas.append(area) + + return masks, bboxes, track_ids, probs, names, areas + + +def compute_texture_map(frame, blur_size: int = 3): # type: ignore[no-untyped-def] + """ + Compute texture map using gradient statistics. + Returns high values for textured regions and low values for smooth regions. + + Parameters: + frame: BGR image + blur_size: Size of Gaussian blur kernel for pre-processing + + Returns: + numpy array: Texture map with values normalized to [0,1] + """ + # Convert to grayscale + if len(frame.shape) == 3: + gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + else: + gray = frame + + # Pre-process with slight blur to reduce noise + if blur_size > 0: + gray = cv2.GaussianBlur(gray, (blur_size, blur_size), 0) + + # Compute gradients in x and y directions + grad_x = cv2.Sobel(gray, cv2.CV_32F, 1, 0, ksize=3) + grad_y = cv2.Sobel(gray, cv2.CV_32F, 0, 1, ksize=3) + + # Compute gradient magnitude and direction + magnitude = np.sqrt(grad_x**2 + grad_y**2) + + # Compute local standard deviation of gradient magnitude + texture_map = cv2.GaussianBlur(magnitude, (15, 15), 0) + + # Normalize to [0,1] + texture_map = (texture_map - texture_map.min()) / (texture_map.max() - texture_map.min() + 1e-8) + + return texture_map + + +def filter_segmentation_results( # type: ignore[no-untyped-def] + frame, + masks, + bboxes, + track_ids, + probs: Sequence[float], + names: Sequence[str], + areas, + texture_threshold: float = 0.07, + size_filter: int = 800, +): + """ + Filters segmentation results using both overlap and saliency detection. + Uses mask_sum tensor for efficient overlap detection. + + Parameters: + masks: list of torch.Tensor containing mask data + bboxes: list of bounding boxes [x1, y1, x2, y2] + track_ids: list of tracking IDs + probs: list of confidence scores + names: list of class names + areas: list of object areas + frame: BGR image for computing saliency + texture_threshold: Average texture value required for mask to be kept + size_filter: Minimum size of the object to be kept + + Returns: + tuple: (filtered_masks, filtered_bboxes, filtered_track_ids, filtered_probs, filtered_names, filtered_texture_values, texture_map) + """ + if len(masks) <= 1: + return masks, bboxes, track_ids, probs, names, [] + + # Compute texture map once and convert to tensor + texture_map = compute_texture_map(frame) + + # Sort by area (smallest to largest) + sorted_indices = torch.tensor(areas).argsort(descending=False) + + device = masks[0].device # Get the device of the first mask + + # Create mask_sum tensor where each pixel stores the index of the mask that claims it + mask_sum = torch.zeros_like(masks[0], dtype=torch.int32) + + texture_map = torch.from_numpy(texture_map).to( + device + ) # Convert texture_map to tensor and move to device + + filtered_texture_values = [] # List to store texture values of filtered masks + + for i, idx in enumerate(sorted_indices): + mask = masks[idx] + # Compute average texture value within mask + texture_value = torch.mean(texture_map[mask > 0]) if torch.any(mask > 0) else 0 + + # Only claim pixels if mask passes texture threshold + if texture_value >= texture_threshold: + mask_sum[mask > 0] = i + filtered_texture_values.append( + texture_value.item() # type: ignore[union-attr] + ) # Store the texture value as a Python float + + # Get indices that appear in mask_sum (these are the masks we want to keep) + keep_indices, counts = torch.unique(mask_sum[mask_sum > 0], return_counts=True) + size_indices = counts > size_filter + keep_indices = keep_indices[size_indices] + + sorted_indices = sorted_indices.cpu() + keep_indices = keep_indices.cpu() + + # Map back to original indices and filter + final_indices = sorted_indices[keep_indices].tolist() + + filtered_masks = [masks[i] for i in final_indices] + filtered_bboxes = [bboxes[i] for i in final_indices] + filtered_track_ids = [track_ids[i] for i in final_indices] + filtered_probs = [probs[i] for i in final_indices] + filtered_names = [names[i] for i in final_indices] + + return ( + filtered_masks, + filtered_bboxes, + filtered_track_ids, + filtered_probs, + filtered_names, + filtered_texture_values, + ) + + +def plot_results( # type: ignore[no-untyped-def] + image, + masks, + bboxes, + track_ids, + probs: Sequence[float], + names: Sequence[str], + alpha: float = 0.5, +): + """ + Draws bounding boxes, masks, and labels on the given image with enhanced visualization. + Includes object names in the overlay and improved text visibility. + """ + h, w = image.shape[:2] + overlay = image.copy() + + for mask, bbox, track_id, prob, name in zip( + masks, bboxes, track_ids, probs, names, strict=False + ): + # Convert mask tensor to numpy if needed + if isinstance(mask, torch.Tensor): + mask = mask.cpu().numpy() + + # Ensure mask is in proper format for OpenCV resize + if mask.dtype == bool: + mask = mask.astype(np.uint8) + elif mask.dtype != np.uint8 and mask.dtype != np.float32: + mask = mask.astype(np.float32) + + mask_resized = cv2.resize(mask, (w, h), interpolation=cv2.INTER_LINEAR) + + # Generate consistent color based on track_id + if track_id != -1: + np.random.seed(track_id) + color = np.random.randint(0, 255, (3,), dtype=np.uint8) + np.random.seed(None) + else: + color = np.random.randint(0, 255, (3,), dtype=np.uint8) + + # Apply mask color + overlay[mask_resized > 0.5] = color + + # Draw bounding box + x1, y1, x2, y2 = map(int, bbox) + cv2.rectangle(overlay, (x1, y1), (x2, y2), color.tolist(), 2) + + # Prepare label text + label = f"ID:{track_id} {prob:.2f}" + if name: # Add object name if available + label += f" {name}" + + # Calculate text size for background rectangle + (text_w, text_h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1) + + # Draw background rectangle for text + cv2.rectangle(overlay, (x1, y1 - text_h - 8), (x1 + text_w + 4, y1), color.tolist(), -1) + + # Draw text with white color for better visibility + cv2.putText( + overlay, + label, + (x1 + 2, y1 - 5), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + (255, 255, 255), # White text + 1, + ) + + # Blend overlay with original image + result = cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0) + return result + + +def crop_images_from_bboxes(image, bboxes, buffer: int = 0): # type: ignore[no-untyped-def] + """ + Crops regions from an image based on bounding boxes with an optional buffer. + + Parameters: + image (numpy array): Input image. + bboxes (list of lists): List of bounding boxes [x1, y1, x2, y2]. + buffer (int): Number of pixels to expand each bounding box. + + Returns: + list of numpy arrays: Cropped image regions. + """ + height, width, _ = image.shape + cropped_images = [] + + for bbox in bboxes: + x1, y1, x2, y2 = bbox + + # Apply buffer + x1 = max(0, x1 - buffer) + y1 = max(0, y1 - buffer) + x2 = min(width, x2 + buffer) + y2 = min(height, y2 + buffer) + + cropped_image = image[int(y1) : int(y2), int(x1) : int(x2)] + cropped_images.append(cropped_image) + + return cropped_images diff --git a/dimos/perception/spatial_perception.py b/dimos/perception/spatial_perception.py new file mode 100644 index 0000000000..d1027172ee --- /dev/null +++ b/dimos/perception/spatial_perception.py @@ -0,0 +1,587 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Spatial Memory module for creating a semantic map of the environment. +""" + +from datetime import datetime +import os +import time +from typing import TYPE_CHECKING, Any, Optional +import uuid + +import cv2 +import numpy as np +from reactivex import Observable, interval, operators as ops +from reactivex.disposable import Disposable + +from dimos import spec +from dimos.agents.memory.image_embedding import ImageEmbeddingProvider +from dimos.agents.memory.spatial_vector_db import SpatialVectorDB +from dimos.agents.memory.visual_memory import VisualMemory +from dimos.constants import DIMOS_PROJECT_ROOT +from dimos.core import DimosCluster, In, Module, rpc +from dimos.msgs.sensor_msgs import Image +from dimos.types.robot_location import RobotLocation +from dimos.utils.logging_config import setup_logger + +if TYPE_CHECKING: + from dimos.msgs.geometry_msgs import Vector3 + +_OUTPUT_DIR = DIMOS_PROJECT_ROOT / "assets" / "output" +_MEMORY_DIR = _OUTPUT_DIR / "memory" +_SPATIAL_MEMORY_DIR = _MEMORY_DIR / "spatial_memory" +_DB_PATH = _SPATIAL_MEMORY_DIR / "chromadb_data" +_VISUAL_MEMORY_PATH = _SPATIAL_MEMORY_DIR / "visual_memory.pkl" + + +logger = setup_logger() + + +class SpatialMemory(Module): + """ + A Dask module for building and querying Robot spatial memory. + + This module processes video frames and odometry data from LCM streams, + associates them with XY locations, and stores them in a vector database + for later retrieval via RPC calls. It also maintains a list of named + robot locations that can be queried by name. + """ + + # LCM inputs + color_image: In[Image] = None # type: ignore[assignment] + + def __init__( + self, + collection_name: str = "spatial_memory", + embedding_model: str = "clip", + embedding_dimensions: int = 512, + min_distance_threshold: float = 0.01, # Min distance in meters to store a new frame + min_time_threshold: float = 1.0, # Min time in seconds to record a new frame + db_path: str | None = str(_DB_PATH), # Path for ChromaDB persistence + visual_memory_path: str | None = str( + _VISUAL_MEMORY_PATH + ), # Path for saving/loading visual memory + new_memory: bool = True, # Whether to create a new memory from scratch + output_dir: str | None = str( + _SPATIAL_MEMORY_DIR + ), # Directory for storing visual memory data + chroma_client: Any = None, # Optional ChromaDB client for persistence + visual_memory: Optional[ + "VisualMemory" + ] = None, # Optional VisualMemory instance for storing images + ) -> None: + """ + Initialize the spatial perception system. + + Args: + collection_name: Name of the vector database collection + embedding_model: Model to use for image embeddings ("clip", "resnet", etc.) + embedding_dimensions: Dimensions of the embedding vectors + min_distance_threshold: Minimum distance in meters to record a new frame + min_time_threshold: Minimum time in seconds to record a new frame + chroma_client: Optional ChromaDB client for persistent storage + visual_memory: Optional VisualMemory instance for storing images + output_dir: Directory for storing visual memory data if visual_memory is not provided + """ + self.collection_name = collection_name + self.embedding_model = embedding_model + self.embedding_dimensions = embedding_dimensions + self.min_distance_threshold = min_distance_threshold + self.min_time_threshold = min_time_threshold + + # Set up paths for persistence + # Call parent Module init + super().__init__() + + self.db_path = db_path + self.visual_memory_path = visual_memory_path + + # Setup ChromaDB client if not provided + self._chroma_client = chroma_client + if chroma_client is None and db_path is not None: + # Create db directory if needed + os.makedirs(db_path, exist_ok=True) + + # Clean up existing DB if creating new memory + if new_memory and os.path.exists(db_path): + try: + logger.info("Creating new ChromaDB database (new_memory=True)") + # Try to delete any existing database files + import shutil + + for item in os.listdir(db_path): + item_path = os.path.join(db_path, item) + if os.path.isfile(item_path): + os.unlink(item_path) + elif os.path.isdir(item_path): + shutil.rmtree(item_path) + logger.info(f"Removed existing ChromaDB files from {db_path}") + except Exception as e: + logger.error(f"Error clearing ChromaDB directory: {e}") + + import chromadb + from chromadb.config import Settings + + self._chroma_client = chromadb.PersistentClient( + path=db_path, settings=Settings(anonymized_telemetry=False) + ) + + # Initialize or load visual memory + self._visual_memory = visual_memory + if visual_memory is None: + if new_memory or not os.path.exists(visual_memory_path or ""): + logger.info("Creating new visual memory") + self._visual_memory = VisualMemory(output_dir=output_dir) + else: + try: + logger.info(f"Loading existing visual memory from {visual_memory_path}...") + self._visual_memory = VisualMemory.load( + visual_memory_path, # type: ignore[arg-type] + output_dir=output_dir, + ) + logger.info(f"Loaded {self._visual_memory.count()} images from previous runs") + except Exception as e: + logger.error(f"Error loading visual memory: {e}") + self._visual_memory = VisualMemory(output_dir=output_dir) + + self.embedding_provider: ImageEmbeddingProvider = ImageEmbeddingProvider( + model_name=embedding_model, dimensions=embedding_dimensions + ) + + self.vector_db: SpatialVectorDB = SpatialVectorDB( + collection_name=collection_name, + chroma_client=self._chroma_client, + visual_memory=self._visual_memory, + embedding_provider=self.embedding_provider, + ) + + self.last_position: Vector3 | None = None + self.last_record_time: float | None = None + + self.frame_count: int = 0 + self.stored_frame_count: int = 0 + + # List to store robot locations + self.robot_locations: list[RobotLocation] = [] + + # Track latest data for processing + self._latest_video_frame: np.ndarray | None = None # type: ignore[type-arg] + self._process_interval = 1 + + logger.info(f"SpatialMemory initialized with model {embedding_model}") + + @rpc + def start(self) -> None: + super().start() + + # Subscribe to LCM streams + def set_video(image_msg: Image) -> None: + # Convert Image message to numpy array + if hasattr(image_msg, "data"): + frame = image_msg.data + frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) + self._latest_video_frame = frame + else: + logger.warning("Received image message without data attribute") + + unsub = self.color_image.subscribe(set_video) + self._disposables.add(Disposable(unsub)) + + # Start periodic processing using interval + unsub = interval(self._process_interval).subscribe(lambda _: self._process_frame()) # type: ignore[assignment] + self._disposables.add(Disposable(unsub)) + + @rpc + def stop(self) -> None: + # Save data before shutdown + self.save() + + if self._visual_memory: + self._visual_memory.clear() + + super().stop() + + def _process_frame(self) -> None: + """Process the latest frame with pose data if available.""" + tf = self.tf.get("map", "base_link") + if self._latest_video_frame is None or tf is None: + return + + # Create Pose object with position and orientation + current_pose = tf.to_pose() + + # Process the frame directly + try: + self.frame_count += 1 + + # Check distance constraint + if self.last_position is not None: + distance_moved = np.linalg.norm( + [ + current_pose.position.x - self.last_position.x, + current_pose.position.y - self.last_position.y, + current_pose.position.z - self.last_position.z, + ] + ) + if distance_moved < self.min_distance_threshold: + logger.debug( + f"Position has not moved enough: {distance_moved:.4f}m < {self.min_distance_threshold}m, skipping frame" + ) + return + + # Check time constraint + if self.last_record_time is not None: + time_elapsed = time.time() - self.last_record_time + if time_elapsed < self.min_time_threshold: + logger.debug( + f"Time since last record too short: {time_elapsed:.2f}s < {self.min_time_threshold}s, skipping frame" + ) + return + + current_time = time.time() + + # Get embedding for the frame + frame_embedding = self.embedding_provider.get_embedding(self._latest_video_frame) + + frame_id = f"frame_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}" + # Get euler angles from quaternion orientation for metadata + euler = tf.rotation.to_euler() + + # Create metadata dictionary with primitive types only + metadata = { + "pos_x": float(current_pose.position.x), + "pos_y": float(current_pose.position.y), + "pos_z": float(current_pose.position.z), + "rot_x": float(euler.x), + "rot_y": float(euler.y), + "rot_z": float(euler.z), + "timestamp": current_time, + "frame_id": frame_id, + } + + # Store in vector database + self.vector_db.add_image_vector( + vector_id=frame_id, + image=self._latest_video_frame, + embedding=frame_embedding, + metadata=metadata, + ) + + # Update tracking variables + self.last_position = current_pose.position + self.last_record_time = current_time + self.stored_frame_count += 1 + + logger.info( + f"Stored frame at position ({current_pose.position.x:.2f}, {current_pose.position.y:.2f}, {current_pose.position.z:.2f}), " + f"rotation ({euler.x:.2f}, {euler.y:.2f}, {euler.z:.2f}) " + f"stored {self.stored_frame_count}/{self.frame_count} frames" + ) + + # Periodically save visual memory to disk + if self._visual_memory is not None and self.visual_memory_path is not None: + if self.stored_frame_count % 100 == 0: + self.save() + + except Exception as e: + logger.error(f"Error processing frame: {e}") + + @rpc + def query_by_location( + self, x: float, y: float, radius: float = 2.0, limit: int = 5 + ) -> list[dict]: # type: ignore[type-arg] + """ + Query the vector database for images near the specified location. + + Args: + x: X coordinate + y: Y coordinate + radius: Search radius in meters + limit: Maximum number of results to return + + Returns: + List of results, each containing the image and its metadata + """ + return self.vector_db.query_by_location(x, y, radius, limit) + + @rpc + def save(self) -> bool: + """ + Save the visual memory component to disk. + + Returns: + True if memory was saved successfully, False otherwise + """ + if self._visual_memory is not None and self.visual_memory_path is not None: + try: + saved_path = self._visual_memory.save(self.visual_memory_path) + logger.info(f"Saved {self._visual_memory.count()} images to {saved_path}") + return True + except Exception as e: + logger.error(f"Failed to save visual memory: {e}") + return False + + def process_stream(self, combined_stream: Observable) -> Observable: # type: ignore[type-arg] + """ + Process a combined stream of video frames and positions. + + This method handles a stream where each item already contains both the frame and position, + such as the stream created by combining video and transform streams with the + with_latest_from operator. + + Args: + combined_stream: Observable stream of dictionaries containing 'frame' and 'position' + + Returns: + Observable of processing results, including the stored frame and its metadata + """ + + def process_combined_data(data): # type: ignore[no-untyped-def] + self.frame_count += 1 + + frame = data.get("frame") + position_vec = data.get("position") # Use .get() for consistency + rotation_vec = data.get("rotation") # Get rotation data if available + + if position_vec is None or rotation_vec is None: + logger.info("No position or rotation data available, skipping frame") + return None + + # position_vec is already a Vector3, no need to recreate it + position_v3 = position_vec + + if self.last_position is not None: + distance_moved = np.linalg.norm( + [ + position_v3.x - self.last_position.x, + position_v3.y - self.last_position.y, + position_v3.z - self.last_position.z, + ] + ) + if distance_moved < self.min_distance_threshold: + logger.debug("Position has not moved, skipping frame") + return None + + if ( + self.last_record_time is not None + and (time.time() - self.last_record_time) < self.min_time_threshold + ): + logger.debug("Time since last record too short, skipping frame") + return None + + current_time = time.time() + + frame_embedding = self.embedding_provider.get_embedding(frame) + + frame_id = f"frame_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}" + + # Create metadata dictionary with primitive types only + metadata = { + "pos_x": float(position_v3.x), + "pos_y": float(position_v3.y), + "pos_z": float(position_v3.z), + "rot_x": float(rotation_vec.x), + "rot_y": float(rotation_vec.y), + "rot_z": float(rotation_vec.z), + "timestamp": current_time, + "frame_id": frame_id, + } + + self.vector_db.add_image_vector( + vector_id=frame_id, image=frame, embedding=frame_embedding, metadata=metadata + ) + + self.last_position = position_v3 + self.last_record_time = current_time + self.stored_frame_count += 1 + + logger.info( + f"Stored frame at position ({position_v3.x:.2f}, {position_v3.y:.2f}, {position_v3.z:.2f}), " + f"rotation ({rotation_vec.x:.2f}, {rotation_vec.y:.2f}, {rotation_vec.z:.2f}) " + f"stored {self.stored_frame_count}/{self.frame_count} frames" + ) + + # Create return dictionary with primitive-compatible values + return { + "frame": frame, + "position": (position_v3.x, position_v3.y, position_v3.z), + "rotation": (rotation_vec.x, rotation_vec.y, rotation_vec.z), + "frame_id": frame_id, + "timestamp": current_time, + } + + return combined_stream.pipe( + ops.map(process_combined_data), ops.filter(lambda result: result is not None) + ) + + @rpc + def query_by_image(self, image: np.ndarray, limit: int = 5) -> list[dict]: # type: ignore[type-arg] + """ + Query the vector database for images similar to the provided image. + + Args: + image: Query image + limit: Maximum number of results to return + + Returns: + List of results, each containing the image and its metadata + """ + embedding = self.embedding_provider.get_embedding(image) + return self.vector_db.query_by_embedding(embedding, limit) + + @rpc + def query_by_text(self, text: str, limit: int = 5) -> list[dict]: # type: ignore[type-arg] + """ + Query the vector database for images matching the provided text description. + + This method uses CLIP's text-to-image matching capability to find images + that semantically match the text query (e.g., "where is the kitchen"). + + Args: + text: Text query to search for + limit: Maximum number of results to return + + Returns: + List of results, each containing the image, its metadata, and similarity score + """ + logger.info(f"Querying spatial memory with text: '{text}'") + return self.vector_db.query_by_text(text, limit) + + @rpc + def add_robot_location(self, location: RobotLocation) -> bool: + """ + Add a named robot location to spatial memory. + + Args: + location: The RobotLocation object to add + + Returns: + True if successfully added, False otherwise + """ + try: + # Add to our list of robot locations + self.robot_locations.append(location) + logger.info(f"Added robot location '{location.name}' at position {location.position}") + return True + + except Exception as e: + logger.error(f"Error adding robot location: {e}") + return False + + @rpc + def add_named_location( + self, + name: str, + position: list[float] | None = None, + rotation: list[float] | None = None, + description: str | None = None, + ) -> bool: + """ + Add a named robot location to spatial memory using current or specified position. + + Args: + name: Name of the location + position: Optional position [x, y, z], uses current position if None + rotation: Optional rotation [roll, pitch, yaw], uses current rotation if None + description: Optional description of the location + + Returns: + True if successfully added, False otherwise + """ + tf = self.tf.get("map", "base_link") + if not tf: + logger.error("No position available for robot location") + return False + + # Create RobotLocation object + location = RobotLocation( # type: ignore[call-arg] + name=name, + position=tf.translation, + rotation=tf.rotation.to_euler(), + description=description or f"Location: {name}", + timestamp=time.time(), + ) + + return self.add_robot_location(location) # type: ignore[no-any-return] + + @rpc + def get_robot_locations(self) -> list[RobotLocation]: + """ + Get all stored robot locations. + + Returns: + List of RobotLocation objects + """ + return self.robot_locations + + @rpc + def find_robot_location(self, name: str) -> RobotLocation | None: + """ + Find a robot location by name. + + Args: + name: Name of the location to find + + Returns: + RobotLocation object if found, None otherwise + """ + # Simple search through our list of locations + for location in self.robot_locations: + if location.name.lower() == name.lower(): + return location + + return None + + @rpc + def get_stats(self) -> dict[str, int]: + """Get statistics about the spatial memory module. + + Returns: + Dictionary containing: + - frame_count: Total number of frames processed + - stored_frame_count: Number of frames actually stored + """ + return {"frame_count": self.frame_count, "stored_frame_count": self.stored_frame_count} + + @rpc + def tag_location(self, robot_location: RobotLocation) -> bool: + try: + self.vector_db.tag_location(robot_location) + except Exception: + return False + return True + + @rpc + def query_tagged_location(self, query: str) -> RobotLocation | None: + location, semantic_distance = self.vector_db.query_tagged_location(query) + if semantic_distance < 0.3: + return location + return None + + +def deploy( # type: ignore[no-untyped-def] + dimos: DimosCluster, + camera: spec.Camera, +): + spatial_memory = dimos.deploy(SpatialMemory, db_path="/tmp/spatial_memory_db") # type: ignore[attr-defined] + spatial_memory.color_image.connect(camera.color_image) + spatial_memory.start() + return spatial_memory + + +spatial_memory = SpatialMemory.blueprint + +__all__ = ["SpatialMemory", "deploy", "spatial_memory"] diff --git a/dimos/perception/test_spatial_memory.py b/dimos/perception/test_spatial_memory.py new file mode 100644 index 0000000000..d4b188ced3 --- /dev/null +++ b/dimos/perception/test_spatial_memory.py @@ -0,0 +1,202 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shutil +import tempfile +import time + +import numpy as np +import pytest +from reactivex import operators as ops + +from dimos.msgs.geometry_msgs import Pose +from dimos.perception.spatial_perception import SpatialMemory +from dimos.stream.video_provider import VideoProvider + + +@pytest.mark.heavy +class TestSpatialMemory: + @pytest.fixture(scope="class") + def temp_dir(self): + # Create a temporary directory for storing spatial memory data + temp_dir = tempfile.mkdtemp() + yield temp_dir + # Clean up + shutil.rmtree(temp_dir) + + @pytest.fixture(scope="class") + def spatial_memory(self, temp_dir): + # Create a single SpatialMemory instance to be reused across all tests + memory = SpatialMemory( + collection_name="test_collection", + embedding_model="clip", + new_memory=True, + db_path=os.path.join(temp_dir, "chroma_db"), + visual_memory_path=os.path.join(temp_dir, "visual_memory.pkl"), + output_dir=os.path.join(temp_dir, "images"), + min_distance_threshold=0.01, + min_time_threshold=0.01, + ) + yield memory + # Clean up + memory.stop() + + def test_spatial_memory_initialization(self, spatial_memory) -> None: + """Test SpatialMemory initializes correctly with CLIP model.""" + # Use the shared spatial_memory fixture + assert spatial_memory is not None + assert spatial_memory.embedding_model == "clip" + assert spatial_memory.embedding_provider is not None + + def test_image_embedding(self, spatial_memory) -> None: + """Test generating image embeddings using CLIP.""" + # Use the shared spatial_memory fixture + # Create a test image - use a simple colored square + test_image = np.zeros((224, 224, 3), dtype=np.uint8) + test_image[50:150, 50:150] = [0, 0, 255] # Blue square + + # Generate embedding + embedding = spatial_memory.embedding_provider.get_embedding(test_image) + + # Check embedding shape and characteristics + assert embedding is not None + assert isinstance(embedding, np.ndarray) + assert embedding.shape[0] == spatial_memory.embedding_dimensions + + # Check that embedding is normalized (unit vector) + assert np.isclose(np.linalg.norm(embedding), 1.0, atol=1e-5) + + # Test text embedding + text_embedding = spatial_memory.embedding_provider.get_text_embedding("a blue square") + assert text_embedding is not None + assert isinstance(text_embedding, np.ndarray) + assert text_embedding.shape[0] == spatial_memory.embedding_dimensions + assert np.isclose(np.linalg.norm(text_embedding), 1.0, atol=1e-5) + + def test_spatial_memory_processing(self, spatial_memory, temp_dir) -> None: + """Test processing video frames and building spatial memory with CLIP embeddings.""" + try: + # Use the shared spatial_memory fixture + memory = spatial_memory + + from dimos.utils.data import get_data + + video_path = get_data("assets") / "trimmed_video_office.mov" + assert os.path.exists(video_path), f"Test video not found: {video_path}" + video_provider = VideoProvider(dev_name="test_video", video_source=video_path) + video_stream = video_provider.capture_video_as_observable(realtime=False, fps=15) + + # Create a frame counter for position generation + frame_counter = 0 + + # Process each video frame directly + def process_frame(frame): + nonlocal frame_counter + + # Generate a unique position for this frame to ensure minimum distance threshold is met + pos = Pose(frame_counter * 0.5, frame_counter * 0.5, 0) + transform = {"position": pos, "timestamp": time.time()} + frame_counter += 1 + + # Create a dictionary with frame, position and rotation for SpatialMemory.process_stream + return { + "frame": frame, + "position": transform["position"], + "rotation": transform["position"], # Using position as rotation for testing + } + + # Create a stream that processes each frame + formatted_stream = video_stream.pipe(ops.map(process_frame)) + + # Process the stream using SpatialMemory's built-in processing + print("Creating spatial memory stream...") + spatial_stream = memory.process_stream(formatted_stream) + + # Stream is now created above using memory.process_stream() + + # Collect results from the stream + results = [] + + frames_processed = 0 + target_frames = 100 # Process more frames for thorough testing + + def on_next(result) -> None: + nonlocal results, frames_processed + if not result: # Skip None results + return + + results.append(result) + frames_processed += 1 + + # Stop processing after target frames + if frames_processed >= target_frames: + subscription.dispose() + + def on_error(error) -> None: + pytest.fail(f"Error in spatial stream: {error}") + + def on_completed() -> None: + pass + + # Subscribe and wait for results + subscription = spatial_stream.subscribe( + on_next=on_next, on_error=on_error, on_completed=on_completed + ) + + # Wait for frames to be processed + timeout = 30.0 # seconds + start_time = time.time() + while frames_processed < target_frames and time.time() - start_time < timeout: + time.sleep(0.5) + + subscription.dispose() + + assert len(results) > 0, "Failed to process any frames with spatial memory" + + relevant_queries = ["office", "room with furniture"] + irrelevant_query = "star wars" + + for query in relevant_queries: + results = memory.query_by_text(query, limit=2) + print(f"\nResults for query: '{query}'") + + assert len(results) > 0, f"No results found for relevant query: {query}" + + similarities = [1 - r.get("distance") for r in results] + print(f"Similarities: {similarities}") + + assert any(d > 0.22 for d in similarities), ( + f"Expected at least one result with similarity > 0.22 for query '{query}'" + ) + + results = memory.query_by_text(irrelevant_query, limit=2) + print(f"\nResults for query: '{irrelevant_query}'") + + if results: + similarities = [1 - r.get("distance") for r in results] + print(f"Similarities: {similarities}") + + assert all(d < 0.25 for d in similarities), ( + f"Expected all results to have similarity < 0.25 for irrelevant query '{irrelevant_query}'" + ) + + except Exception as e: + pytest.fail(f"Error in test: {e}") + finally: + video_provider.dispose_all() + + +if __name__ == "__main__": + pytest.main(["-v", __file__]) diff --git a/dimos/perception/test_spatial_memory_module.py b/dimos/perception/test_spatial_memory_module.py new file mode 100644 index 0000000000..e31b2a2d31 --- /dev/null +++ b/dimos/perception/test_spatial_memory_module.py @@ -0,0 +1,229 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import os +import tempfile +import time + +import pytest +from reactivex import operators as ops + +from dimos import core +from dimos.core import Module, Out, rpc +from dimos.msgs.sensor_msgs import Image +from dimos.perception.spatial_perception import SpatialMemory +from dimos.protocol import pubsub +from dimos.robot.unitree_webrtc.type.odometry import Odometry +from dimos.utils.data import get_data +from dimos.utils.logging_config import setup_logger +from dimos.utils.testing import TimedSensorReplay + +logger = setup_logger() + +pubsub.lcm.autoconf() + + +class VideoReplayModule(Module): + """Module that replays video data from TimedSensorReplay.""" + + video_out: Out[Image] = None + + def __init__(self, video_path: str) -> None: + super().__init__() + self.video_path = video_path + self._subscription = None + + @rpc + def start(self) -> None: + """Start replaying video data.""" + # Use TimedSensorReplay to replay video frames + video_replay = TimedSensorReplay(self.video_path, autocast=Image.from_numpy) + + # Subscribe to the replay stream and publish to LCM + self._subscription = ( + video_replay.stream() + .pipe( + ops.sample(2), # Sample every 2 seconds for resource-constrained systems + ops.take(5), # Only take 5 frames total + ) + .subscribe(self.video_out.publish) + ) + + logger.info("VideoReplayModule started") + + @rpc + def stop(self) -> None: + """Stop replaying video data.""" + if self._subscription: + self._subscription.dispose() + self._subscription = None + logger.info("VideoReplayModule stopped") + + +class OdometryReplayModule(Module): + """Module that replays odometry data from TimedSensorReplay.""" + + odom_out: Out[Odometry] = None + + def __init__(self, odom_path: str) -> None: + super().__init__() + self.odom_path = odom_path + self._subscription = None + + @rpc + def start(self) -> None: + """Start replaying odometry data.""" + # Use TimedSensorReplay to replay odometry + odom_replay = TimedSensorReplay(self.odom_path, autocast=Odometry.from_msg) + + # Subscribe to the replay stream and publish to LCM + self._subscription = ( + odom_replay.stream() + .pipe( + ops.sample(0.5), # Sample every 500ms + ops.take(10), # Only take 10 odometry updates total + ) + .subscribe(self.odom_out.publish) + ) + + logger.info("OdometryReplayModule started") + + @rpc + def stop(self) -> None: + """Stop replaying odometry data.""" + if self._subscription: + self._subscription.dispose() + self._subscription = None + logger.info("OdometryReplayModule stopped") + + +@pytest.mark.gpu +class TestSpatialMemoryModule: + @pytest.fixture(scope="function") + def temp_dir(self): + """Create a temporary directory for test data.""" + # Use standard tempfile module to ensure proper permissions + temp_dir = tempfile.mkdtemp(prefix="spatial_memory_test_") + + yield temp_dir + + @pytest.mark.asyncio + async def test_spatial_memory_module_with_replay(self, temp_dir): + """Test SpatialMemory module with TimedSensorReplay inputs.""" + + # Start Dask + dimos = core.start(1) + + try: + # Get test data paths + data_path = get_data("unitree_office_walk") + video_path = os.path.join(data_path, "video") + odom_path = os.path.join(data_path, "odom") + + # Deploy modules + # Video replay module + video_module = dimos.deploy(VideoReplayModule, video_path) + video_module.video_out.transport = core.LCMTransport("/test_video", Image) + + # Odometry replay module + odom_module = dimos.deploy(OdometryReplayModule, odom_path) + odom_module.odom_out.transport = core.LCMTransport("/test_odom", Odometry) + + # Spatial memory module + spatial_memory = dimos.deploy( + SpatialMemory, + collection_name="test_spatial_memory", + embedding_model="clip", + embedding_dimensions=512, + min_distance_threshold=0.5, # 0.5m for test + min_time_threshold=1.0, # 1 second + db_path=os.path.join(temp_dir, "chroma_db"), + visual_memory_path=os.path.join(temp_dir, "visual_memory.pkl"), + new_memory=True, + output_dir=os.path.join(temp_dir, "images"), + ) + + # Connect streams + spatial_memory.video.connect(video_module.video_out) + spatial_memory.odom.connect(odom_module.odom_out) + + # Start all modules + video_module.start() + odom_module.start() + spatial_memory.start() + logger.info("All modules started, processing in background...") + + # Wait for frames to be processed with timeout + timeout = 10.0 # 10 second timeout + start_time = time.time() + + # Keep checking stats while modules are running + while (time.time() - start_time) < timeout: + stats = spatial_memory.get_stats() + if stats["frame_count"] > 0 and stats["stored_frame_count"] > 0: + logger.info( + f"Frames processing - Frame count: {stats['frame_count']}, Stored: {stats['stored_frame_count']}" + ) + break + await asyncio.sleep(0.5) + else: + # Timeout reached + stats = spatial_memory.get_stats() + logger.error( + f"Timeout after {timeout}s - Frame count: {stats['frame_count']}, Stored: {stats['stored_frame_count']}" + ) + raise AssertionError(f"No frames processed within {timeout} seconds") + + await asyncio.sleep(2) + + mid_stats = spatial_memory.get_stats() + logger.info( + f"Mid-test stats - Frame count: {mid_stats['frame_count']}, Stored: {mid_stats['stored_frame_count']}" + ) + assert mid_stats["frame_count"] >= stats["frame_count"], ( + "Frame count should increase or stay same" + ) + + # Test query while modules are still running + try: + text_results = spatial_memory.query_by_text("office") + logger.info(f"Query by text 'office' returned {len(text_results)} results") + assert len(text_results) > 0, "Should have at least one result" + except Exception as e: + logger.warning(f"Query by text failed: {e}") + + final_stats = spatial_memory.get_stats() + logger.info( + f"Final stats - Frame count: {final_stats['frame_count']}, Stored: {final_stats['stored_frame_count']}" + ) + + video_module.stop() + odom_module.stop() + logger.info("Stopped replay modules") + + logger.info("All spatial memory module tests passed!") + + finally: + # Cleanup + if "dimos" in locals(): + dimos.close() + + +if __name__ == "__main__": + pytest.main(["-v", "-s", __file__]) + # test = TestSpatialMemoryModule() + # asyncio.run( + # test.test_spatial_memory_module_with_replay(tempfile.mkdtemp(prefix="spatial_memory_test_")) + # ) diff --git a/dimos/manipulation/imitation/imitation_learning.py b/dimos/protocol/__init__.py similarity index 100% rename from dimos/manipulation/imitation/imitation_learning.py rename to dimos/protocol/__init__.py diff --git a/dimos/protocol/encode/__init__.py b/dimos/protocol/encode/__init__.py new file mode 100644 index 0000000000..87386a09e5 --- /dev/null +++ b/dimos/protocol/encode/__init__.py @@ -0,0 +1,89 @@ +from abc import ABC, abstractmethod +import json +from typing import Generic, Protocol, TypeVar + +MsgT = TypeVar("MsgT") +EncodingT = TypeVar("EncodingT") + + +class LCMMessage(Protocol): + """Protocol for LCM message types that have encode/decode methods.""" + + def encode(self) -> bytes: + """Encode the message to bytes.""" + ... + + @staticmethod + def decode(data: bytes) -> "LCMMessage": + """Decode bytes to a message instance.""" + ... + + +# TypeVar for LCM message types +LCMMsgT = TypeVar("LCMMsgT", bound=LCMMessage) + + +class Encoder(ABC, Generic[MsgT, EncodingT]): + """Base class for message encoders/decoders.""" + + @staticmethod + @abstractmethod + def encode(msg: MsgT) -> EncodingT: + raise NotImplementedError("Subclasses must implement this method.") + + @staticmethod + @abstractmethod + def decode(data: EncodingT) -> MsgT: + raise NotImplementedError("Subclasses must implement this method.") + + +class JSON(Encoder[MsgT, bytes]): + @staticmethod + def encode(msg: MsgT) -> bytes: + return json.dumps(msg).encode("utf-8") + + @staticmethod + def decode(data: bytes) -> MsgT: + return json.loads(data.decode("utf-8")) # type: ignore[no-any-return] + + +class LCM(Encoder[LCMMsgT, bytes]): + """Encoder for LCM message types.""" + + @staticmethod + def encode(msg: LCMMsgT) -> bytes: + return msg.encode() + + @staticmethod + def decode(data: bytes) -> LCMMsgT: + # Note: This is a generic implementation. In practice, you would need + # to pass the specific message type to decode with. This method would + # typically be overridden in subclasses for specific message types. + raise NotImplementedError( + "LCM.decode requires a specific message type. Use LCMTypedEncoder[MessageType] instead." + ) + + +class LCMTypedEncoder(LCM, Generic[LCMMsgT]): # type: ignore[type-arg] + """Typed LCM encoder for specific message types.""" + + def __init__(self, message_type: type[LCMMsgT]) -> None: + self.message_type = message_type + + @staticmethod + def decode(data: bytes) -> LCMMsgT: + # This is a generic implementation and should be overridden in specific instances + raise NotImplementedError( + "LCMTypedEncoder.decode must be overridden with a specific message type" + ) + + +def create_lcm_typed_encoder(message_type: type[LCMMsgT]) -> type[LCMTypedEncoder[LCMMsgT]]: + """Factory function to create a typed LCM encoder for a specific message type.""" + + class SpecificLCMEncoder(LCMTypedEncoder): # type: ignore[type-arg] + @staticmethod + def decode(data: bytes) -> LCMMsgT: + return message_type.decode(data) # type: ignore[return-value] + + return SpecificLCMEncoder diff --git a/dimos/protocol/pubsub/__init__.py b/dimos/protocol/pubsub/__init__.py new file mode 100644 index 0000000000..89bd292fda --- /dev/null +++ b/dimos/protocol/pubsub/__init__.py @@ -0,0 +1,3 @@ +import dimos.protocol.pubsub.lcmpubsub as lcm +from dimos.protocol.pubsub.memory import Memory +from dimos.protocol.pubsub.spec import PubSub diff --git a/dimos/protocol/pubsub/jpeg_shm.py b/dimos/protocol/pubsub/jpeg_shm.py new file mode 100644 index 0000000000..c61848c57a --- /dev/null +++ b/dimos/protocol/pubsub/jpeg_shm.py @@ -0,0 +1,20 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.protocol.pubsub.lcmpubsub import JpegSharedMemoryEncoderMixin +from dimos.protocol.pubsub.shmpubsub import SharedMemoryPubSubBase + + +class JpegSharedMemory(JpegSharedMemoryEncoderMixin, SharedMemoryPubSubBase): # type: ignore[misc] + pass diff --git a/dimos/protocol/pubsub/lcmpubsub.py b/dimos/protocol/pubsub/lcmpubsub.py new file mode 100644 index 0000000000..9207e7dfc0 --- /dev/null +++ b/dimos/protocol/pubsub/lcmpubsub.py @@ -0,0 +1,171 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable + +from turbojpeg import TurboJPEG # type: ignore[import-untyped] + +from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.image_impls.AbstractImage import ImageFormat +from dimos.protocol.pubsub.spec import PickleEncoderMixin, PubSub, PubSubEncoderMixin +from dimos.protocol.service.lcmservice import LCMConfig, LCMService, autoconf +from dimos.utils.logging_config import setup_logger + +if TYPE_CHECKING: + from collections.abc import Callable + import threading + +logger = setup_logger() + + +@runtime_checkable +class LCMMsg(Protocol): + msg_name: str + + @classmethod + def lcm_decode(cls, data: bytes) -> LCMMsg: + """Decode bytes into an LCM message instance.""" + ... + + def lcm_encode(self) -> bytes: + """Encode this message instance into bytes.""" + ... + + +@dataclass +class Topic: + topic: str = "" + lcm_type: type[LCMMsg] | None = None + + def __str__(self) -> str: + if self.lcm_type is None: + return self.topic + return f"{self.topic}#{self.lcm_type.msg_name}" + + +class LCMPubSubBase(LCMService, PubSub[Topic, Any]): + default_config = LCMConfig + _stop_event: threading.Event + _thread: threading.Thread | None + _callbacks: dict[str, list[Callable[[Any], None]]] + + def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] + super().__init__(**kwargs) + self._callbacks = {} + + def publish(self, topic: Topic, message: bytes) -> None: + """Publish a message to the specified channel.""" + if self.l is None: + logger.error("Tried to publish after LCM was closed") + return + + self.l.publish(str(topic), message) + + def subscribe( + self, topic: Topic, callback: Callable[[bytes, Topic], Any] + ) -> Callable[[], None]: + if self.l is None: + logger.error("Tried to subscribe after LCM was closed") + + def noop() -> None: + pass + + return noop + + lcm_subscription = self.l.subscribe(str(topic), lambda _, msg: callback(msg, topic)) + + # Set queue capacity to 10000 to handle high-volume bursts + lcm_subscription.set_queue_capacity(10000) + + def unsubscribe() -> None: + if self.l is None: + return + self.l.unsubscribe(lcm_subscription) + + return unsubscribe + + +class LCMEncoderMixin(PubSubEncoderMixin[Topic, Any]): + def encode(self, msg: LCMMsg, _: Topic) -> bytes: + return msg.lcm_encode() + + def decode(self, msg: bytes, topic: Topic) -> LCMMsg: + if topic.lcm_type is None: + raise ValueError( + f"Cannot decode message for topic '{topic.topic}': no lcm_type specified" + ) + return topic.lcm_type.lcm_decode(msg) + + +class JpegEncoderMixin(PubSubEncoderMixin[Topic, Any]): + def encode(self, msg: LCMMsg, _: Topic) -> bytes: + return msg.lcm_jpeg_encode() # type: ignore[attr-defined, no-any-return] + + def decode(self, msg: bytes, topic: Topic) -> LCMMsg: + if topic.lcm_type is None: + raise ValueError( + f"Cannot decode message for topic '{topic.topic}': no lcm_type specified" + ) + return topic.lcm_type.lcm_jpeg_decode(msg) # type: ignore[attr-defined, no-any-return] + + +class JpegSharedMemoryEncoderMixin(PubSubEncoderMixin[str, Image]): + def __init__(self, quality: int = 75, **kwargs) -> None: # type: ignore[no-untyped-def] + super().__init__(**kwargs) + self.jpeg = TurboJPEG() + self.quality = quality + + def encode(self, msg: Any, _topic: str) -> bytes: + if not isinstance(msg, Image): + raise ValueError("Can only encode images.") + + bgr_image = msg.to_bgr().to_opencv() + return self.jpeg.encode(bgr_image, quality=self.quality) # type: ignore[no-any-return] + + def decode(self, msg: bytes, _topic: str) -> Image: + bgr_array = self.jpeg.decode(msg) + return Image(data=bgr_array, format=ImageFormat.BGR) + + +class LCM( + LCMEncoderMixin, + LCMPubSubBase, +): ... + + +class PickleLCM( + PickleEncoderMixin, # type: ignore[type-arg] + LCMPubSubBase, +): ... + + +class JpegLCM( + JpegEncoderMixin, + LCMPubSubBase, +): ... + + +__all__ = [ + "LCM", + "JpegLCM", + "LCMEncoderMixin", + "LCMMsg", + "LCMMsg", + "LCMPubSubBase", + "PickleLCM", + "autoconf", +] diff --git a/dimos/protocol/pubsub/memory.py b/dimos/protocol/pubsub/memory.py new file mode 100644 index 0000000000..e46fc10500 --- /dev/null +++ b/dimos/protocol/pubsub/memory.py @@ -0,0 +1,60 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import defaultdict +from collections.abc import Callable +from typing import Any + +from dimos.protocol import encode +from dimos.protocol.pubsub.spec import PubSub, PubSubEncoderMixin + + +class Memory(PubSub[str, Any]): + def __init__(self) -> None: + self._map: defaultdict[str, list[Callable[[Any, str], None]]] = defaultdict(list) + + def publish(self, topic: str, message: Any) -> None: + for cb in self._map[topic]: + cb(message, topic) + + def subscribe(self, topic: str, callback: Callable[[Any, str], None]) -> Callable[[], None]: + self._map[topic].append(callback) + + def unsubscribe() -> None: + try: + self._map[topic].remove(callback) + if not self._map[topic]: + del self._map[topic] + except (KeyError, ValueError): + pass + + return unsubscribe + + def unsubscribe(self, topic: str, callback: Callable[[Any, str], None]) -> None: + try: + self._map[topic].remove(callback) + if not self._map[topic]: + del self._map[topic] + except (KeyError, ValueError): + pass + + +class MemoryWithJSONEncoder(PubSubEncoderMixin, Memory): # type: ignore[type-arg] + """Memory PubSub with JSON encoding/decoding.""" + + def encode(self, msg: Any, topic: str) -> bytes: + return encode.JSON.encode(msg) + + def decode(self, msg: bytes, topic: str) -> Any: + return encode.JSON.decode(msg) diff --git a/dimos/protocol/pubsub/redispubsub.py b/dimos/protocol/pubsub/redispubsub.py new file mode 100644 index 0000000000..6cc089e953 --- /dev/null +++ b/dimos/protocol/pubsub/redispubsub.py @@ -0,0 +1,198 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import defaultdict +from collections.abc import Callable +from dataclasses import dataclass, field +import json +import threading +import time +from types import TracebackType +from typing import Any + +import redis # type: ignore[import-not-found] + +from dimos.protocol.pubsub.spec import PubSub +from dimos.protocol.service.spec import Service + + +@dataclass +class RedisConfig: + host: str = "localhost" + port: int = 6379 + db: int = 0 + kwargs: dict[str, Any] = field(default_factory=dict) + + +class Redis(PubSub[str, Any], Service[RedisConfig]): + """Redis-based pub/sub implementation.""" + + default_config = RedisConfig + + def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] + super().__init__(**kwargs) + + # Redis connections + self._client = None + self._pubsub = None + + # Subscription management + self._callbacks: dict[str, list[Callable[[Any, str], None]]] = defaultdict(list) + self._listener_thread = None + self._running = False + + def start(self) -> None: + """Start the Redis pub/sub service.""" + if self._running: + return + self._connect() # type: ignore[no-untyped-call] + + def stop(self) -> None: + """Stop the Redis pub/sub service.""" + self.close() + + def _connect(self): # type: ignore[no-untyped-def] + """Connect to Redis and set up pub/sub.""" + try: + self._client = redis.Redis( + host=self.config.host, + port=self.config.port, + db=self.config.db, + decode_responses=True, + **self.config.kwargs, + ) + # Test connection + self._client.ping() # type: ignore[attr-defined] + + self._pubsub = self._client.pubsub() # type: ignore[attr-defined] + self._running = True + + # Start listener thread + self._listener_thread = threading.Thread(target=self._listen_loop, daemon=True) # type: ignore[assignment] + self._listener_thread.start() # type: ignore[attr-defined] + + except Exception as e: + raise ConnectionError( + f"Failed to connect to Redis at {self.config.host}:{self.config.port}: {e}" + ) + + def _listen_loop(self) -> None: + """Listen for messages from Redis and dispatch to callbacks.""" + while self._running: + try: + if not self._pubsub: + break + message = self._pubsub.get_message(timeout=0.1) + if message and message["type"] == "message": + topic = message["channel"] + data = message["data"] + + # Try to deserialize JSON, fall back to raw data + try: + data = json.loads(data) + except (json.JSONDecodeError, TypeError): + pass + + # Call all callbacks for this topic + for callback in self._callbacks.get(topic, []): + try: + callback(data, topic) + except Exception as e: + # Log error but continue processing other callbacks + print(f"Error in callback for topic {topic}: {e}") + + except Exception as e: + if self._running: # Only log if we're still supposed to be running + print(f"Error in Redis listener loop: {e}") + time.sleep(0.1) # Brief pause before retrying + + def publish(self, topic: str, message: Any) -> None: + """Publish a message to a topic.""" + if not self._client: + raise RuntimeError("Redis client not connected") + + # Serialize message as JSON if it's not a string + if isinstance(message, str): + data = message + else: + data = json.dumps(message) + + self._client.publish(topic, data) + + def subscribe(self, topic: str, callback: Callable[[Any, str], None]) -> Callable[[], None]: + """Subscribe to a topic with a callback.""" + if not self._pubsub: + raise RuntimeError("Redis pubsub not initialized") + + # If this is the first callback for this topic, subscribe to Redis channel + if topic not in self._callbacks or not self._callbacks[topic]: + self._pubsub.subscribe(topic) + + # Add callback to our list + self._callbacks[topic].append(callback) + + # Return unsubscribe function + def unsubscribe() -> None: + self.unsubscribe(topic, callback) + + return unsubscribe + + def unsubscribe(self, topic: str, callback: Callable[[Any, str], None]) -> None: + """Unsubscribe a callback from a topic.""" + if topic in self._callbacks: + try: + self._callbacks[topic].remove(callback) + + # If no more callbacks for this topic, unsubscribe from Redis channel + if not self._callbacks[topic]: + if self._pubsub: + self._pubsub.unsubscribe(topic) + del self._callbacks[topic] + + except ValueError: + pass # Callback wasn't in the list + + def close(self) -> None: + """Close Redis connections and stop listener thread.""" + self._running = False + + if self._listener_thread and self._listener_thread.is_alive(): + self._listener_thread.join(timeout=1.0) + + if self._pubsub: + try: + self._pubsub.close() + except Exception: + pass + self._pubsub = None + + if self._client: + try: + self._client.close() + except Exception: + pass + self._client = None + + self._callbacks.clear() + + def __enter__(self): # type: ignore[no-untyped-def] + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + self.close() diff --git a/dimos/protocol/pubsub/shm/ipc_factory.py b/dimos/protocol/pubsub/shm/ipc_factory.py new file mode 100644 index 0000000000..5f69c3dbd1 --- /dev/null +++ b/dimos/protocol/pubsub/shm/ipc_factory.py @@ -0,0 +1,309 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# frame_ipc.py +# Python 3.9+ +from abc import ABC, abstractmethod +from multiprocessing.shared_memory import SharedMemory +import os +import time + +import numpy as np + +_UNLINK_ON_GC = os.getenv("DIMOS_IPC_UNLINK_ON_GC", "0").lower() not in ("0", "false", "no") + + +def _open_shm_with_retry(name: str) -> SharedMemory: + tries = int(os.getenv("DIMOS_IPC_ATTACH_RETRIES", "40")) # ~40 tries + base_ms = float(os.getenv("DIMOS_IPC_ATTACH_BACKOFF_MS", "5")) # 5 ms + cap_ms = float(os.getenv("DIMOS_IPC_ATTACH_BACKOFF_CAP_MS", "200")) # 200 ms + last = None + for i in range(tries): + try: + return SharedMemory(name=name) + except FileNotFoundError as e: + last = e + # exponential backoff, capped + time.sleep(min((base_ms * (2**i)), cap_ms) / 1000.0) + raise FileNotFoundError(f"SHM not found after {tries} retries: {name}") from last + + +def _sanitize_shm_name(name: str) -> str: + # Python's SharedMemory expects names like 'psm_abc', without leading '/' + return name.lstrip("/") if isinstance(name, str) else name + + +# --------------------------- +# 1) Abstract interface +# --------------------------- + + +class FrameChannel(ABC): + """Single-slot 'freshest frame' IPC channel with a tiny control block. + - Double-buffered to avoid torn reads. + - Descriptor is JSON-safe; attach() reconstructs in another process. + """ + + @property + @abstractmethod + def device(self) -> str: # "cpu" or "cuda" + ... + + @property + @abstractmethod + def shape(self) -> tuple: ... # type: ignore[type-arg] + + @property + @abstractmethod + def dtype(self) -> np.dtype: ... # type: ignore[type-arg] + + @abstractmethod + def publish(self, frame) -> None: # type: ignore[no-untyped-def] + """Write into inactive buffer, then flip visible index (write control last).""" + ... + + @abstractmethod + def read(self, last_seq: int = -1, require_new: bool = True): # type: ignore[no-untyped-def] + """Return (seq:int, ts_ns:int, view-or-None).""" + ... + + @abstractmethod + def descriptor(self) -> dict: # type: ignore[type-arg] + """Tiny JSON-safe descriptor (names/handles/shape/dtype/device).""" + ... + + @classmethod + @abstractmethod + def attach(cls, desc: dict) -> "FrameChannel": # type: ignore[type-arg] + """Attach in another process.""" + ... + + @abstractmethod + def close(self) -> None: + """Detach resources (owner also unlinks manager if applicable).""" + ... + + +from multiprocessing.shared_memory import SharedMemory +import os +import weakref + + +def _safe_unlink(name: str) -> None: + try: + shm = SharedMemory(name=name) + shm.unlink() + except FileNotFoundError: + pass + except Exception: + pass + + +# --------------------------- +# 2) CPU shared-memory backend +# --------------------------- + + +class CpuShmChannel(FrameChannel): + def __init__( # type: ignore[no-untyped-def] + self, + shape, + dtype=np.uint8, + *, + data_name: str | None = None, + ctrl_name: str | None = None, + ) -> None: + self._shape = tuple(shape) + self._dtype = np.dtype(dtype) + self._nbytes = int(self._dtype.itemsize * np.prod(self._shape)) + + def _create_or_open(name: str, size: int): # type: ignore[no-untyped-def] + try: + shm = SharedMemory(create=True, size=size, name=name) + owner = True + except FileExistsError: + shm = SharedMemory(name=name) # attach existing + owner = False + return shm, owner + + if data_name is None or ctrl_name is None: + # fallback: random names (old behavior) + self._shm_data = SharedMemory(create=True, size=2 * self._nbytes) + self._shm_ctrl = SharedMemory(create=True, size=24) + self._is_owner = True + else: + self._shm_data, own_d = _create_or_open(data_name, 2 * self._nbytes) + self._shm_ctrl, own_c = _create_or_open(ctrl_name, 24) + self._is_owner = own_d and own_c + + self._ctrl = np.ndarray((3,), dtype=np.int64, buffer=self._shm_ctrl.buf) # type: ignore[var-annotated] + if self._is_owner: + self._ctrl[:] = 0 # initialize only once + + # only owners set unlink finalizers (beware cross-process timing) + self._finalizer_data = ( + weakref.finalize(self, _safe_unlink, self._shm_data.name) + if (_UNLINK_ON_GC and self._is_owner) + else None + ) + self._finalizer_ctrl = ( + weakref.finalize(self, _safe_unlink, self._shm_ctrl.name) + if (_UNLINK_ON_GC and self._is_owner) + else None + ) + + def descriptor(self): # type: ignore[no-untyped-def] + return { + "kind": "cpu", + "shape": self._shape, + "dtype": self._dtype.str, + "nbytes": self._nbytes, + "data_name": self._shm_data.name, + "ctrl_name": self._shm_ctrl.name, + } + + @property + def device(self) -> str: + return "cpu" + + @property + def shape(self): # type: ignore[no-untyped-def] + return self._shape + + @property + def dtype(self): # type: ignore[no-untyped-def] + return self._dtype + + def publish(self, frame) -> None: # type: ignore[no-untyped-def] + assert isinstance(frame, np.ndarray) + assert frame.shape == self._shape and frame.dtype == self._dtype + active = int(self._ctrl[2]) + inactive = 1 - active + view = np.ndarray( # type: ignore[var-annotated] + self._shape, + dtype=self._dtype, + buffer=self._shm_data.buf, + offset=inactive * self._nbytes, + ) + np.copyto(view, frame, casting="no") + ts = np.int64(time.time_ns()) + # Publish order: ts -> idx -> seq + self._ctrl[1] = ts + self._ctrl[2] = inactive + self._ctrl[0] += 1 + + def read(self, last_seq: int = -1, require_new: bool = True): # type: ignore[no-untyped-def] + for _ in range(3): + seq1 = int(self._ctrl[0]) + idx = int(self._ctrl[2]) + ts = int(self._ctrl[1]) + view = np.ndarray( # type: ignore[var-annotated] + self._shape, dtype=self._dtype, buffer=self._shm_data.buf, offset=idx * self._nbytes + ) + if seq1 == int(self._ctrl[0]): + if require_new and seq1 == last_seq: + return seq1, ts, None + return seq1, ts, view + return last_seq, 0, None + + def descriptor(self): # type: ignore[no-redef, no-untyped-def] + return { + "kind": "cpu", + "shape": self._shape, + "dtype": self._dtype.str, + "nbytes": self._nbytes, + "data_name": self._shm_data.name, + "ctrl_name": self._shm_ctrl.name, + } + + @classmethod + def attach(cls, desc: str): # type: ignore[no-untyped-def, override] + obj = object.__new__(cls) + obj._shape = tuple(desc["shape"]) # type: ignore[index] + obj._dtype = np.dtype(desc["dtype"]) # type: ignore[index] + obj._nbytes = int(desc["nbytes"]) # type: ignore[index] + data_name = desc["data_name"] # type: ignore[index] + ctrl_name = desc["ctrl_name"] # type: ignore[index] + try: + obj._shm_data = _open_shm_with_retry(data_name) + obj._shm_ctrl = _open_shm_with_retry(ctrl_name) + except FileNotFoundError as e: + raise FileNotFoundError( + f"CPU IPC attach failed: control/data SHM not found " + f"(ctrl='{ctrl_name}', data='{data_name}'). " + f"Ensure the writer is running on the same host and the channel is alive." + ) from e + obj._ctrl = np.ndarray((3,), dtype=np.int64, buffer=obj._shm_ctrl.buf) + # attachments don’t own/unlink + obj._finalizer_data = obj._finalizer_ctrl = None + return obj + + def close(self) -> None: + if getattr(self, "_is_owner", False): + try: + self._shm_ctrl.close() + finally: + try: + _safe_unlink(self._shm_ctrl.name) + except: + pass + if hasattr(self, "_shm_data"): + try: + self._shm_data.close() + finally: + try: + _safe_unlink(self._shm_data.name) + except: + pass + return + # readers: just close handles + try: + self._shm_ctrl.close() + except: + pass + try: + self._shm_data.close() + except: + pass + + +# --------------------------- +# 3) Factories +# --------------------------- + + +class CPU_IPC_Factory: + """Creates/attaches CPU shared-memory channels.""" + + @staticmethod + def create(shape, dtype=np.uint8) -> CpuShmChannel: # type: ignore[no-untyped-def] + return CpuShmChannel(shape, dtype=dtype) + + @staticmethod + def attach(desc: dict) -> CpuShmChannel: # type: ignore[type-arg] + assert desc.get("kind") == "cpu", "Descriptor kind mismatch" + return CpuShmChannel.attach(desc) # type: ignore[arg-type, no-any-return] + + +# --------------------------- +# 4) Runtime selector +# --------------------------- + + +def make_frame_channel( # type: ignore[no-untyped-def] + shape, dtype=np.uint8, prefer: str = "auto", device: int = 0 +) -> FrameChannel: + """Choose CUDA IPC if available (or requested), otherwise CPU SHM.""" + # TODO: Implement the CUDA version of creating this factory + return CPU_IPC_Factory.create(shape, dtype=dtype) diff --git a/dimos/protocol/pubsub/shmpubsub.py b/dimos/protocol/pubsub/shmpubsub.py new file mode 100644 index 0000000000..b38bf5c5be --- /dev/null +++ b/dimos/protocol/pubsub/shmpubsub.py @@ -0,0 +1,324 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# --------------------------------------------------------------------------- +# SharedMemory Pub/Sub over unified IPC channels (CPU/CUDA) +# --------------------------------------------------------------------------- + +from __future__ import annotations + +from collections import defaultdict +from dataclasses import dataclass +import hashlib +import os +import struct +import threading +import time +from typing import TYPE_CHECKING, Any +import uuid + +import numpy as np + +from dimos.protocol.pubsub.shm.ipc_factory import CpuShmChannel +from dimos.protocol.pubsub.spec import PickleEncoderMixin, PubSub, PubSubEncoderMixin +from dimos.utils.logging_config import setup_logger + +if TYPE_CHECKING: + from collections.abc import Callable + +logger = setup_logger() + + +# -------------------------------------------------------------------------------------- +# Configuration (kept local to PubSub now that Service is gone) +# -------------------------------------------------------------------------------------- + + +@dataclass +class SharedMemoryConfig: + prefer: str = "auto" # "auto" | "cpu" (DIMOS_IPC_BACKEND overrides), TODO: "cuda" + default_capacity: int = 3686400 # payload bytes (excludes 4-byte header) + close_channels_on_stop: bool = True + + +# -------------------------------------------------------------------------------------- +# Core PubSub with integrated SHM/IPC transport (previously the Service logic) +# -------------------------------------------------------------------------------------- + + +class SharedMemoryPubSubBase(PubSub[str, Any]): + """ + Pub/Sub over SharedMemory/CUDA-IPC, modeled after LCMPubSubBase but self-contained. + Wire format per topic/frame: [len:uint32_le] + payload bytes (padded to fixed capacity). + Features ported from Service: + - start()/stop() lifecycle + - one frame channel per topic + - per-topic fanout thread (reads from channel, invokes subscribers) + - CPU/CUDA backend selection (auto + env override) + - reconfigure(topic, capacity=...) + - drop initial empty frame; synchronous local delivery; echo suppression + """ + + # Per-topic state + # TODO: implement "is_cuda" below capacity, above cp + class _TopicState: + __slots__ = ( + "capacity", + "channel", + "cp", + "dtype", + "last_local_payload", + "last_seq", + "shape", + "stop", + "subs", + "suppress_counts", + "thread", + ) + + def __init__(self, channel, capacity: int, cp_mod) -> None: # type: ignore[no-untyped-def] + self.channel = channel + self.capacity = int(capacity) + self.shape = (self.capacity + 20,) # +20 for header: length(4) + uuid(16) + self.dtype = np.uint8 + self.subs: list[Callable[[bytes, str], None]] = [] + self.stop = threading.Event() + self.thread: threading.Thread | None = None + self.last_seq = 0 # start at 0 to avoid b"" on first poll + # TODO: implement an initializer variable for is_cuda once CUDA IPC is in + self.cp = cp_mod + self.last_local_payload: bytes | None = None + self.suppress_counts: dict[bytes, int] = defaultdict(int) # UUID bytes as key + + # ----- init / lifecycle ------------------------------------------------- + + def __init__( + self, + *, + prefer: str = "auto", + default_capacity: int = 3686400, + close_channels_on_stop: bool = True, + **_: Any, + ) -> None: + super().__init__() + self.config = SharedMemoryConfig( + prefer=prefer, + default_capacity=default_capacity, + close_channels_on_stop=close_channels_on_stop, + ) + self._topics: dict[str, SharedMemoryPubSubBase._TopicState] = {} + self._lock = threading.Lock() + + def start(self) -> None: + pref = (self.config.prefer or "auto").lower() + backend = os.getenv("DIMOS_IPC_BACKEND", pref).lower() + logger.info(f"SharedMemory PubSub starting (backend={backend})") + # No global thread needed; per-topic fanout starts on first subscribe. + + def stop(self) -> None: + with self._lock: + for _topic, st in list(self._topics.items()): + # stop fanout + try: + if st.thread: + st.stop.set() + st.thread.join(timeout=0.5) + st.thread = None + except Exception: + pass + # close/unlink channels if configured + if self.config.close_channels_on_stop: + try: + st.channel.close() + except Exception: + pass + self._topics.clear() + logger.info("SharedMemory PubSub stopped.") + + # ----- PubSub API (bytes on the wire) ---------------------------------- + + def publish(self, topic: str, message: bytes) -> None: + if not isinstance(message, bytes | bytearray | memoryview): + raise TypeError(f"publish expects bytes-like, got {type(message)!r}") + + st = self._ensure_topic(topic) + + # Normalize once + payload_bytes = bytes(message) + L = len(payload_bytes) + if L > st.capacity: + logger.error(f"Payload too large: {L} > capacity {st.capacity}") + raise ValueError(f"Payload too large: {L} > capacity {st.capacity}") + + # Create a unique identifier using UUID4 + message_id = uuid.uuid4().bytes # 16 bytes + + # Mark this message to suppress its echo + st.suppress_counts[message_id] += 1 + + # Synchronous local delivery first (zero extra copies) + for cb in list(st.subs): + try: + cb(payload_bytes, topic) + except Exception: + logger.warn(f"Payload couldn't be pushed to topic: {topic}") + pass + + # Build host frame [len:4] + [uuid:16] + payload and publish + # We embed the message UUID in the frame for echo suppression + host = np.zeros(st.shape, dtype=st.dtype) + # Pack: length(4) + uuid(16) + payload + header = struct.pack(" Callable[[], None]: + """Subscribe a callback(message: bytes, topic). Returns unsubscribe.""" + st = self._ensure_topic(topic) + st.subs.append(callback) + if st.thread is None: + st.thread = threading.Thread(target=self._fanout_loop, args=(topic, st), daemon=True) + st.thread.start() + + def _unsub() -> None: + try: + st.subs.remove(callback) + except ValueError: + pass + if not st.subs and st.thread: + st.stop.set() + st.thread.join(timeout=0.5) + st.thread = None + st.stop.clear() + + return _unsub + + # ----- Capacity mgmt ---------------------------------------------------- + + def reconfigure(self, topic: str, *, capacity: int) -> dict: # type: ignore[type-arg] + """Change payload capacity (bytes) for a topic; returns new descriptor.""" + st = self._ensure_topic(topic) + new_cap = int(capacity) + new_shape = (new_cap + 20,) # +20 for header: length(4) + uuid(16) + desc = st.channel.reconfigure(new_shape, np.uint8) + st.capacity = new_cap + st.shape = new_shape + st.dtype = np.uint8 + st.last_seq = -1 + return desc # type: ignore[no-any-return] + + # ----- Internals -------------------------------------------------------- + + def _ensure_topic(self, topic: str) -> _TopicState: + with self._lock: + st = self._topics.get(topic) + if st is not None: + return st + cap = int(self.config.default_capacity) + + def _names_for_topic(topic: str, capacity: int) -> tuple[str, str]: + # Python’s SharedMemory requires names without a leading '/' + h = hashlib.blake2b(f"{topic}:{capacity}".encode(), digest_size=12).hexdigest() + return f"psm_{h}_data", f"psm_{h}_ctrl" + + data_name, ctrl_name = _names_for_topic(topic, cap) + ch = CpuShmChannel((cap + 20,), np.uint8, data_name=data_name, ctrl_name=ctrl_name) + st = SharedMemoryPubSubBase._TopicState(ch, cap, None) + self._topics[topic] = st + return st + + def _fanout_loop(self, topic: str, st: _TopicState) -> None: + while not st.stop.is_set(): + seq, _ts_ns, view = st.channel.read(last_seq=st.last_seq, require_new=True) + if view is None: + time.sleep(0.001) + continue + st.last_seq = seq + + host = np.array(view, copy=True) + + try: + # Read header: length(4) + uuid(16) + L = struct.unpack(" st.capacity + 16: + continue + + # Extract UUID + message_id = host[4:20].tobytes() + + # Extract actual payload (after removing the 16 bytes for uuid) + payload_len = L - 16 + if payload_len > 0: + payload = host[20 : 20 + payload_len].tobytes() + else: + continue + + # Drop exactly the number of local echoes we created + cnt = st.suppress_counts.get(message_id, 0) + if cnt > 0: + if cnt == 1: + del st.suppress_counts[message_id] + else: + st.suppress_counts[message_id] = cnt - 1 + continue # suppressed + + except Exception: + continue + + for cb in list(st.subs): + try: + cb(payload, topic) + except Exception: + pass + + +# -------------------------------------------------------------------------------------- +# Encoders + concrete PubSub classes +# -------------------------------------------------------------------------------------- + + +class SharedMemoryBytesEncoderMixin(PubSubEncoderMixin[str, bytes]): + """Identity encoder for raw bytes.""" + + def encode(self, msg: bytes, _: str) -> bytes: + if isinstance(msg, bytes | bytearray | memoryview): + return bytes(msg) + raise TypeError(f"SharedMemory expects bytes-like, got {type(msg)!r}") + + def decode(self, msg: bytes, _: str) -> bytes: + return msg + + +class SharedMemory( + SharedMemoryBytesEncoderMixin, + SharedMemoryPubSubBase, +): + """SharedMemory pubsub that transports raw bytes.""" + + ... + + +class PickleSharedMemory( + PickleEncoderMixin[str, Any], + SharedMemoryPubSubBase, +): + """SharedMemory pubsub that transports arbitrary Python objects via pickle.""" + + ... diff --git a/dimos/protocol/pubsub/spec.py b/dimos/protocol/pubsub/spec.py new file mode 100644 index 0000000000..28fce3faee --- /dev/null +++ b/dimos/protocol/pubsub/spec.py @@ -0,0 +1,154 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +import asyncio +from collections.abc import AsyncIterator, Callable +from contextlib import asynccontextmanager +from dataclasses import dataclass +import pickle +from typing import Any, Generic, TypeVar + +from dimos.utils.logging_config import setup_logger + +MsgT = TypeVar("MsgT") +TopicT = TypeVar("TopicT") + + +logger = setup_logger() + + +class PubSub(Generic[TopicT, MsgT], ABC): + """Abstract base class for pub/sub implementations with sugar methods.""" + + @abstractmethod + def publish(self, topic: TopicT, message: MsgT) -> None: + """Publish a message to a topic.""" + ... + + @abstractmethod + def subscribe( + self, topic: TopicT, callback: Callable[[MsgT, TopicT], None] + ) -> Callable[[], None]: + """Subscribe to a topic with a callback. returns unsubscribe function""" + ... + + @dataclass(slots=True) + class _Subscription: + _bus: "PubSub[Any, Any]" + _topic: Any + _cb: Callable[[Any, Any], None] + _unsubscribe_fn: Callable[[], None] + + def unsubscribe(self) -> None: + self._unsubscribe_fn() + + # context-manager helper + def __enter__(self): # type: ignore[no-untyped-def] + return self + + def __exit__(self, *exc) -> None: # type: ignore[no-untyped-def] + self.unsubscribe() + + # public helper: returns disposable object + def sub(self, topic: TopicT, cb: Callable[[MsgT, TopicT], None]) -> "_Subscription": + unsubscribe_fn = self.subscribe(topic, cb) + return self._Subscription(self, topic, cb, unsubscribe_fn) + + # async iterator + async def aiter(self, topic: TopicT, *, max_pending: int | None = None) -> AsyncIterator[MsgT]: + q: asyncio.Queue[MsgT] = asyncio.Queue(maxsize=max_pending or 0) + + def _cb(msg: MsgT, topic: TopicT) -> None: + q.put_nowait(msg) + + unsubscribe_fn = self.subscribe(topic, _cb) + try: + while True: + yield await q.get() + finally: + unsubscribe_fn() + + # async context manager returning a queue + + @asynccontextmanager + async def queue(self, topic: TopicT, *, max_pending: int | None = None): # type: ignore[no-untyped-def] + q: asyncio.Queue[MsgT] = asyncio.Queue(maxsize=max_pending or 0) + + def _queue_cb(msg: MsgT, topic: TopicT) -> None: + q.put_nowait(msg) + + unsubscribe_fn = self.subscribe(topic, _queue_cb) + try: + yield q + finally: + unsubscribe_fn() + + +class PubSubEncoderMixin(Generic[TopicT, MsgT], ABC): + """Mixin that encodes messages before publishing and decodes them after receiving. + + Usage: Just specify encoder and decoder as a subclass: + + class MyPubSubWithJSON(PubSubEncoderMixin, MyPubSub): + def encoder(msg, topic): + json.dumps(msg).encode('utf-8') + def decoder(msg, topic): + data: json.loads(data.decode('utf-8')) + """ + + @abstractmethod + def encode(self, msg: MsgT, topic: TopicT) -> bytes: ... + + @abstractmethod + def decode(self, msg: bytes, topic: TopicT) -> MsgT: ... + + def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + super().__init__(*args, **kwargs) + self._encode_callback_map: dict = {} # type: ignore[type-arg] + + def publish(self, topic: TopicT, message: MsgT) -> None: + """Encode the message and publish it.""" + if getattr(self, "_stop_event", None) is not None and self._stop_event.is_set(): # type: ignore[attr-defined] + return + encoded_message = self.encode(message, topic) + if encoded_message is None: + return + super().publish(topic, encoded_message) # type: ignore[misc] + + def subscribe( + self, topic: TopicT, callback: Callable[[MsgT, TopicT], None] + ) -> Callable[[], None]: + """Subscribe with automatic decoding.""" + + def wrapper_cb(encoded_data: bytes, topic: TopicT) -> None: + decoded_message = self.decode(encoded_data, topic) + callback(decoded_message, topic) + + return super().subscribe(topic, wrapper_cb) # type: ignore[misc, no-any-return] + + +class PickleEncoderMixin(PubSubEncoderMixin[TopicT, MsgT]): + def encode(self, msg: MsgT, *_: TopicT) -> bytes: # type: ignore[return] + try: + return pickle.dumps(msg) + except Exception as e: + print("Pickle encoding error:", e) + import traceback + + traceback.print_exc() + print("Tried to pickle:", msg) + + def decode(self, msg: bytes, _: TopicT) -> MsgT: + return pickle.loads(msg) # type: ignore[no-any-return] diff --git a/dimos/protocol/pubsub/test_encoder.py b/dimos/protocol/pubsub/test_encoder.py new file mode 100644 index 0000000000..f39bd170d5 --- /dev/null +++ b/dimos/protocol/pubsub/test_encoder.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python3 + +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from dimos.protocol.pubsub.memory import Memory, MemoryWithJSONEncoder + + +def test_json_encoded_pubsub() -> None: + """Test memory pubsub with JSON encoding.""" + pubsub = MemoryWithJSONEncoder() + received_messages = [] + + def callback(message, topic) -> None: + received_messages.append(message) + + # Subscribe to a topic + pubsub.subscribe("json_topic", callback) + + # Publish various types of messages + test_messages = [ + "hello world", + 42, + 3.14, + True, + None, + {"name": "Alice", "age": 30, "active": True}, + [1, 2, 3, "four", {"five": 5}], + {"nested": {"data": [1, 2, {"deep": True}]}}, + ] + + for msg in test_messages: + pubsub.publish("json_topic", msg) + + # Verify all messages were received and properly decoded + assert len(received_messages) == len(test_messages) + for original, received in zip(test_messages, received_messages, strict=False): + assert original == received + + +def test_json_encoding_edge_cases() -> None: + """Test edge cases for JSON encoding.""" + pubsub = MemoryWithJSONEncoder() + received_messages = [] + + def callback(message, topic) -> None: + received_messages.append(message) + + pubsub.subscribe("edge_cases", callback) + + # Test edge cases + edge_cases = [ + "", # empty string + [], # empty list + {}, # empty dict + 0, # zero + False, # False boolean + [None, None, None], # list with None values + {"": "empty_key", "null": None, "empty_list": [], "empty_dict": {}}, + ] + + for case in edge_cases: + pubsub.publish("edge_cases", case) + + assert received_messages == edge_cases + + +def test_multiple_subscribers_with_encoding() -> None: + """Test that multiple subscribers work with encoding.""" + pubsub = MemoryWithJSONEncoder() + received_messages_1 = [] + received_messages_2 = [] + + def callback_1(message, topic) -> None: + received_messages_1.append(message) + + def callback_2(message, topic) -> None: + received_messages_2.append(f"callback_2: {message}") + + pubsub.subscribe("json_topic", callback_1) + pubsub.subscribe("json_topic", callback_2) + pubsub.publish("json_topic", {"multi": "subscriber test"}) + + # Both callbacks should receive the message + assert received_messages_1[-1] == {"multi": "subscriber test"} + assert received_messages_2[-1] == "callback_2: {'multi': 'subscriber test'}" + + +# def test_unsubscribe_with_encoding(): +# """Test unsubscribe works correctly with encoded callbacks.""" +# pubsub = MemoryWithJSONEncoder() +# received_messages_1 = [] +# received_messages_2 = [] + +# def callback_1(message): +# received_messages_1.append(message) + +# def callback_2(message): +# received_messages_2.append(message) + +# pubsub.subscribe("json_topic", callback_1) +# pubsub.subscribe("json_topic", callback_2) + +# # Unsubscribe first callback +# pubsub.unsubscribe("json_topic", callback_1) +# pubsub.publish("json_topic", "only callback_2 should get this") + +# # Only callback_2 should receive the message +# assert len(received_messages_1) == 0 +# assert received_messages_2 == ["only callback_2 should get this"] + + +def test_data_actually_encoded_in_transit() -> None: + """Validate that data is actually encoded in transit by intercepting raw bytes.""" + + # Create a spy memory that captures what actually gets published + class SpyMemory(Memory): + def __init__(self) -> None: + super().__init__() + self.raw_messages_received = [] + + def publish(self, topic: str, message) -> None: + # Capture what actually gets published + self.raw_messages_received.append((topic, message, type(message))) + super().publish(topic, message) + + # Create encoder that uses our spy memory + class SpyMemoryWithJSON(MemoryWithJSONEncoder, SpyMemory): + pass + + pubsub = SpyMemoryWithJSON() + received_decoded = [] + + def callback(message, topic) -> None: + received_decoded.append(message) + + pubsub.subscribe("test_topic", callback) + + # Publish a complex object + original_message = {"name": "Alice", "age": 30, "items": [1, 2, 3]} + pubsub.publish("test_topic", original_message) + + # Verify the message was received and decoded correctly + assert len(received_decoded) == 1 + assert received_decoded[0] == original_message + + # Verify the underlying transport actually received JSON bytes, not the original object + assert len(pubsub.raw_messages_received) == 1 + topic, raw_message, raw_type = pubsub.raw_messages_received[0] + + assert topic == "test_topic" + assert raw_type == bytes # Should be bytes, not dict + assert isinstance(raw_message, bytes) + + # Verify it's actually JSON + decoded_raw = json.loads(raw_message.decode("utf-8")) + assert decoded_raw == original_message diff --git a/dimos/protocol/pubsub/test_lcmpubsub.py b/dimos/protocol/pubsub/test_lcmpubsub.py new file mode 100644 index 0000000000..d06bf20716 --- /dev/null +++ b/dimos/protocol/pubsub/test_lcmpubsub.py @@ -0,0 +1,194 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import pytest + +from dimos.msgs.geometry_msgs import Pose, Quaternion, Vector3 +from dimos.protocol.pubsub.lcmpubsub import ( + LCM, + LCMPubSubBase, + PickleLCM, + Topic, +) + + +@pytest.fixture +def lcm_pub_sub_base(): + lcm = LCMPubSubBase(autoconf=True) + lcm.start() + yield lcm + lcm.stop() + + +@pytest.fixture +def pickle_lcm(): + lcm = PickleLCM(autoconf=True) + lcm.start() + yield lcm + lcm.stop() + + +@pytest.fixture +def lcm(): + lcm = LCM(autoconf=True) + lcm.start() + yield lcm + lcm.stop() + + +class MockLCMMessage: + """Mock LCM message for testing""" + + msg_name = "geometry_msgs.Mock" + + def __init__(self, data) -> None: + self.data = data + + def lcm_encode(self) -> bytes: + return str(self.data).encode("utf-8") + + @classmethod + def lcm_decode(cls, data: bytes) -> "MockLCMMessage": + return cls(data.decode("utf-8")) + + def __eq__(self, other): + return isinstance(other, MockLCMMessage) and self.data == other.data + + +def test_LCMPubSubBase_pubsub(lcm_pub_sub_base) -> None: + lcm = lcm_pub_sub_base + + received_messages = [] + + topic = Topic(topic="/test_topic", lcm_type=MockLCMMessage) + test_message = MockLCMMessage("test_data") + + def callback(msg, topic) -> None: + received_messages.append((msg, topic)) + + lcm.subscribe(topic, callback) + lcm.publish(topic, test_message.lcm_encode()) + time.sleep(0.1) + + assert len(received_messages) == 1 + + received_data = received_messages[0][0] + received_topic = received_messages[0][1] + + print(f"Received data: {received_data}, Topic: {received_topic}") + + assert isinstance(received_data, bytes) + assert received_data.decode() == "test_data" + + assert isinstance(received_topic, Topic) + assert received_topic == topic + + +def test_lcm_autodecoder_pubsub(lcm) -> None: + received_messages = [] + + topic = Topic(topic="/test_topic", lcm_type=MockLCMMessage) + test_message = MockLCMMessage("test_data") + + def callback(msg, topic) -> None: + received_messages.append((msg, topic)) + + lcm.subscribe(topic, callback) + lcm.publish(topic, test_message) + time.sleep(0.1) + + assert len(received_messages) == 1 + + received_data = received_messages[0][0] + received_topic = received_messages[0][1] + + print(f"Received data: {received_data}, Topic: {received_topic}") + + assert isinstance(received_data, MockLCMMessage) + assert received_data == test_message + + assert isinstance(received_topic, Topic) + assert received_topic == topic + + +test_msgs = [ + (Vector3(1, 2, 3)), + (Quaternion(1, 2, 3, 4)), + (Pose(Vector3(1, 2, 3), Quaternion(0, 0, 0, 1))), +] + + +# passes some geometry types through LCM +@pytest.mark.parametrize("test_message", test_msgs) +def test_lcm_geometry_msgs_pubsub(test_message, lcm) -> None: + received_messages = [] + + topic = Topic(topic="/test_topic", lcm_type=test_message.__class__) + + def callback(msg, topic) -> None: + received_messages.append((msg, topic)) + + lcm.subscribe(topic, callback) + lcm.publish(topic, test_message) + + time.sleep(0.1) + + assert len(received_messages) == 1 + + received_data = received_messages[0][0] + received_topic = received_messages[0][1] + + print(f"Received data: {received_data}, Topic: {received_topic}") + + assert isinstance(received_data, test_message.__class__) + assert received_data == test_message + + assert isinstance(received_topic, Topic) + assert received_topic == topic + + print(test_message, topic) + + +# passes some geometry types through pickle LCM +@pytest.mark.parametrize("test_message", test_msgs) +def test_lcm_geometry_msgs_autopickle_pubsub(test_message, pickle_lcm) -> None: + lcm = pickle_lcm + received_messages = [] + + topic = Topic(topic="/test_topic") + + def callback(msg, topic) -> None: + received_messages.append((msg, topic)) + + lcm.subscribe(topic, callback) + lcm.publish(topic, test_message) + + time.sleep(0.1) + + assert len(received_messages) == 1 + + received_data = received_messages[0][0] + received_topic = received_messages[0][1] + + print(f"Received data: {received_data}, Topic: {received_topic}") + + assert isinstance(received_data, test_message.__class__) + assert received_data == test_message + + assert isinstance(received_topic, Topic) + assert received_topic == topic + + print(test_message, topic) diff --git a/dimos/protocol/pubsub/test_spec.py b/dimos/protocol/pubsub/test_spec.py new file mode 100644 index 0000000000..91e8514b70 --- /dev/null +++ b/dimos/protocol/pubsub/test_spec.py @@ -0,0 +1,297 @@ +#!/usr/bin/env python3 + +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from collections.abc import Callable +from contextlib import contextmanager +import time +from typing import Any + +import pytest + +from dimos.msgs.geometry_msgs import Vector3 +from dimos.protocol.pubsub.lcmpubsub import LCM, Topic +from dimos.protocol.pubsub.memory import Memory + + +@contextmanager +def memory_context(): + """Context manager for Memory PubSub implementation.""" + memory = Memory() + try: + yield memory + finally: + # Cleanup logic can be added here if needed + pass + + +# Use Any for context manager type to accommodate both Memory and Redis +testdata: list[tuple[Callable[[], Any], Any, list[Any]]] = [ + (memory_context, "topic", ["value1", "value2", "value3"]), +] + +try: + from dimos.protocol.pubsub.redispubsub import Redis + + @contextmanager + def redis_context(): + redis_pubsub = Redis() + redis_pubsub.start() + yield redis_pubsub + redis_pubsub.stop() + + testdata.append( + (redis_context, "redis_topic", ["redis_value1", "redis_value2", "redis_value3"]) + ) + +except (ConnectionError, ImportError): + # either redis is not installed or the server is not running + print("Redis not available") + + +@contextmanager +def lcm_context(): + lcm_pubsub = LCM(autoconf=True) + lcm_pubsub.start() + yield lcm_pubsub + lcm_pubsub.stop() + + +testdata.append( + ( + lcm_context, + Topic(topic="/test_topic", lcm_type=Vector3), + [Vector3(1, 2, 3), Vector3(4, 5, 6), Vector3(7, 8, 9)], # Using Vector3 as mock data, + ) +) + + +from dimos.protocol.pubsub.shmpubsub import PickleSharedMemory + + +@contextmanager +def shared_memory_cpu_context(): + shared_mem_pubsub = PickleSharedMemory(prefer="cpu") + shared_mem_pubsub.start() + yield shared_mem_pubsub + shared_mem_pubsub.stop() + + +testdata.append( + ( + shared_memory_cpu_context, + "/shared_mem_topic_cpu", + [b"shared_mem_value1", b"shared_mem_value2", b"shared_mem_value3"], + ) +) + + +@pytest.mark.parametrize("pubsub_context, topic, values", testdata) +def test_store(pubsub_context, topic, values) -> None: + with pubsub_context() as x: + # Create a list to capture received messages + received_messages = [] + + # Define callback function that stores received messages + def callback(message, _) -> None: + received_messages.append(message) + + # Subscribe to the topic with our callback + x.subscribe(topic, callback) + + # Publish the first value to the topic + x.publish(topic, values[0]) + + # Give Redis time to process the message if needed + time.sleep(0.1) + + print("RECEIVED", received_messages) + # Verify the callback was called with the correct value + assert len(received_messages) == 1 + assert received_messages[0] == values[0] + + +@pytest.mark.parametrize("pubsub_context, topic, values", testdata) +def test_multiple_subscribers(pubsub_context, topic, values) -> None: + """Test that multiple subscribers receive the same message.""" + with pubsub_context() as x: + # Create lists to capture received messages for each subscriber + received_messages_1 = [] + received_messages_2 = [] + + # Define callback functions + def callback_1(message, topic) -> None: + received_messages_1.append(message) + + def callback_2(message, topic) -> None: + received_messages_2.append(message) + + # Subscribe both callbacks to the same topic + x.subscribe(topic, callback_1) + x.subscribe(topic, callback_2) + + # Publish the first value + x.publish(topic, values[0]) + + # Give Redis time to process the message if needed + time.sleep(0.1) + + # Verify both callbacks received the message + assert len(received_messages_1) == 1 + assert received_messages_1[0] == values[0] + assert len(received_messages_2) == 1 + assert received_messages_2[0] == values[0] + + +@pytest.mark.parametrize("pubsub_context, topic, values", testdata) +def test_unsubscribe(pubsub_context, topic, values) -> None: + """Test that unsubscribed callbacks don't receive messages.""" + with pubsub_context() as x: + # Create a list to capture received messages + received_messages = [] + + # Define callback function + def callback(message, topic) -> None: + received_messages.append(message) + + # Subscribe and get unsubscribe function + unsubscribe = x.subscribe(topic, callback) + + # Unsubscribe using the returned function + unsubscribe() + + # Publish the first value + x.publish(topic, values[0]) + + # Give time to process the message if needed + time.sleep(0.1) + + # Verify the callback was not called after unsubscribing + assert len(received_messages) == 0 + + +@pytest.mark.parametrize("pubsub_context, topic, values", testdata) +def test_multiple_messages(pubsub_context, topic, values) -> None: + """Test that subscribers receive multiple messages in order.""" + with pubsub_context() as x: + # Create a list to capture received messages + received_messages = [] + + # Define callback function + def callback(message, topic) -> None: + received_messages.append(message) + + # Subscribe to the topic + x.subscribe(topic, callback) + + # Publish the rest of the values (after the first one used in basic tests) + messages_to_send = values[1:] if len(values) > 1 else values + for msg in messages_to_send: + x.publish(topic, msg) + + # Give Redis time to process the messages if needed + time.sleep(0.2) + + # Verify all messages were received in order + assert len(received_messages) == len(messages_to_send) + assert received_messages == messages_to_send + + +@pytest.mark.parametrize("pubsub_context, topic, values", testdata) +@pytest.mark.asyncio +async def test_async_iterator(pubsub_context, topic, values) -> None: + """Test that async iterator receives messages correctly.""" + with pubsub_context() as x: + # Get the messages to send (using the rest of the values) + messages_to_send = values[1:] if len(values) > 1 else values + received_messages = [] + + # Create the async iterator + async_iter = x.aiter(topic) + + # Create a task to consume messages from the async iterator + async def consume_messages() -> None: + try: + async for message in async_iter: + received_messages.append(message) + # Stop after receiving all expected messages + if len(received_messages) >= len(messages_to_send): + break + except asyncio.CancelledError: + pass + + # Start the consumer task + consumer_task = asyncio.create_task(consume_messages()) + + # Give the consumer a moment to set up + await asyncio.sleep(0.1) + + # Publish messages + for msg in messages_to_send: + x.publish(topic, msg) + # Small delay to ensure message is processed + await asyncio.sleep(0.1) + + # Wait for the consumer to finish or timeout + try: + await asyncio.wait_for(consumer_task, timeout=1.0) # Longer timeout for Redis + except asyncio.TimeoutError: + consumer_task.cancel() + try: + await consumer_task + except asyncio.CancelledError: + pass + + # Verify all messages were received in order + assert len(received_messages) == len(messages_to_send) + assert received_messages == messages_to_send + + +@pytest.mark.parametrize("pubsub_context, topic, values", testdata) +def test_high_volume_messages(pubsub_context, topic, values) -> None: + """Test that all 5000 messages are received correctly.""" + with pubsub_context() as x: + # Create a list to capture received messages + received_messages = [] + last_message_time = [time.time()] # Use list to allow modification in callback + + # Define callback function + def callback(message, topic) -> None: + received_messages.append(message) + last_message_time[0] = time.time() + + # Subscribe to the topic + x.subscribe(topic, callback) + + # Publish 10000 messages + num_messages = 10000 + for _ in range(num_messages): + x.publish(topic, values[0]) + + # Wait until no messages received for 0.5 seconds + timeout = 1.0 # Maximum time to wait + stable_duration = 0.1 # Time without new messages to consider done + start_time = time.time() + + while time.time() - start_time < timeout: + if time.time() - last_message_time[0] >= stable_duration: + break + time.sleep(0.1) + + # Capture count and clear list to avoid printing huge list on failure + received_len = len(received_messages) + received_messages.clear() + assert received_len == num_messages, f"Expected {num_messages} messages, got {received_len}" diff --git a/dimos/protocol/rpc/__init__.py b/dimos/protocol/rpc/__init__.py new file mode 100644 index 0000000000..5614074a4d --- /dev/null +++ b/dimos/protocol/rpc/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.protocol.rpc.pubsubrpc import LCMRPC, ShmRPC +from dimos.protocol.rpc.spec import RPCClient, RPCServer, RPCSpec + +__all__ = ["LCMRPC", "RPCClient", "RPCServer", "RPCSpec", "ShmRPC"] diff --git a/dimos/protocol/rpc/pubsubrpc.py b/dimos/protocol/rpc/pubsubrpc.py new file mode 100644 index 0000000000..3a26d9f502 --- /dev/null +++ b/dimos/protocol/rpc/pubsubrpc.py @@ -0,0 +1,318 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import abstractmethod +from collections.abc import Callable +from concurrent.futures import ThreadPoolExecutor +import threading +import time +from typing import ( + TYPE_CHECKING, + Any, + Generic, + TypedDict, + TypeVar, +) + +from dimos.constants import LCM_MAX_CHANNEL_NAME_LENGTH +from dimos.protocol.pubsub.lcmpubsub import PickleLCM, Topic +from dimos.protocol.pubsub.shmpubsub import PickleSharedMemory +from dimos.protocol.pubsub.spec import PubSub +from dimos.protocol.rpc.rpc_utils import deserialize_exception, serialize_exception +from dimos.protocol.rpc.spec import Args, RPCSpec +from dimos.utils.generic import short_id +from dimos.utils.logging_config import setup_logger + +if TYPE_CHECKING: + from types import FunctionType + +logger = setup_logger() + +MsgT = TypeVar("MsgT") +TopicT = TypeVar("TopicT") + +# (name, true_if_response_topic) -> TopicT +TopicGen = Callable[[str, bool], TopicT] +MsgGen = Callable[[str, list], MsgT] # type: ignore[type-arg] + + +class RPCReq(TypedDict): + id: float | None + name: str + args: Args + + +class RPCRes(TypedDict, total=False): + id: float + res: Any + exception: dict[str, Any] | None # Contains exception info: type, message, traceback + + +class PubSubRPCMixin(RPCSpec, PubSub[TopicT, MsgT], Generic[TopicT, MsgT]): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + # Thread pool for RPC handler execution (prevents deadlock in nested calls) + self._call_thread_pool: ThreadPoolExecutor | None = None + self._call_thread_pool_lock = threading.RLock() + self._call_thread_pool_max_workers = 50 + + # Shared response subscriptions: one per RPC name instead of one per call + # Maps str(topic_res) -> (subscription, {msg_id -> callback}) + self._response_subs: dict[str, tuple[Any, dict[float, Callable[..., Any]]]] = {} + self._response_subs_lock = threading.RLock() + + # Message ID counter for unique IDs even with concurrent calls + self._msg_id_counter = 0 + self._msg_id_lock = threading.Lock() + + def __getstate__(self) -> dict[str, Any]: + state: dict[str, Any] + if hasattr(super(), "__getstate__"): + state = super().__getstate__() # type: ignore[assignment] + else: + state = self.__dict__.copy() + + # Exclude unpicklable attributes when serializing. + state.pop("_call_thread_pool", None) + state.pop("_call_thread_pool_lock", None) + state.pop("_response_subs", None) + state.pop("_response_subs_lock", None) + state.pop("_msg_id_lock", None) + + return state + + def __setstate__(self, state: dict[str, Any]) -> None: + if hasattr(super(), "__setstate__"): + super().__setstate__(state) # type: ignore[misc] + else: + self.__dict__.update(state) + + # Restore unserializable attributes. + self._call_thread_pool = None + self._call_thread_pool_lock = threading.RLock() + self._response_subs = {} + self._response_subs_lock = threading.RLock() + self._msg_id_lock = threading.Lock() + + @abstractmethod + def topicgen(self, name: str, req_or_res: bool) -> TopicT: ... + + def _encodeRPCReq(self, req: RPCReq) -> dict[str, Any]: + return dict(req) + + def _decodeRPCRes(self, msg: dict[Any, Any]) -> RPCRes: + return msg # type: ignore[return-value] + + def _encodeRPCRes(self, res: RPCRes) -> dict[str, Any]: + return dict(res) + + def _decodeRPCReq(self, msg: dict[Any, Any]) -> RPCReq: + return msg # type: ignore[return-value] + + def _get_call_thread_pool(self) -> ThreadPoolExecutor: + """Get or create the thread pool for RPC handler execution (lazy initialization).""" + with self._call_thread_pool_lock: + if self._call_thread_pool is None: + self._call_thread_pool = ThreadPoolExecutor( + max_workers=self._call_thread_pool_max_workers + ) + return self._call_thread_pool + + def _shutdown_thread_pool(self) -> None: + """Safely shutdown the thread pool with deadlock prevention.""" + with self._call_thread_pool_lock: + if self._call_thread_pool: + # Check if we're being called from within the thread pool + # to avoid "cannot join current thread" error + current_thread = threading.current_thread() + is_pool_thread = False + + # Check if current thread is one of the pool's threads + if hasattr(self._call_thread_pool, "_threads"): + is_pool_thread = current_thread in self._call_thread_pool._threads + elif "ThreadPoolExecutor" in current_thread.name: + # Fallback: check thread name pattern + is_pool_thread = True + + # Don't wait if we're in a pool thread to avoid deadlock + self._call_thread_pool.shutdown(wait=not is_pool_thread) + self._call_thread_pool = None + + def stop(self) -> None: + """Stop the RPC service and cleanup thread pool. + + Subclasses that override this method should call super().stop() + to ensure the thread pool is properly shutdown. + """ + self._shutdown_thread_pool() + + # Cleanup shared response subscriptions + with self._response_subs_lock: + for unsub, _ in self._response_subs.values(): + unsub() + self._response_subs.clear() + + # Call parent stop if it exists + if hasattr(super(), "stop"): + super().stop() # type: ignore[misc] + + def call(self, name: str, arguments: Args, cb: Callable | None): # type: ignore[no-untyped-def, type-arg] + if cb is None: + return self.call_nowait(name, arguments) + + return self.call_cb(name, arguments, cb) + + def call_cb(self, name: str, arguments: Args, cb: Callable[..., Any]) -> Any: + topic_req = self.topicgen(name, False) + topic_res = self.topicgen(name, True) + + # Generate unique msg_id: timestamp + counter for concurrent calls + with self._msg_id_lock: + self._msg_id_counter += 1 + msg_id = time.time() + (self._msg_id_counter / 1_000_000) + + req: RPCReq = {"name": name, "args": arguments, "id": msg_id} + + # Get or create shared subscription for this RPC's response topic + topic_res_key = str(topic_res) + with self._response_subs_lock: + if topic_res_key not in self._response_subs: + # Create shared handler that routes to callbacks by msg_id + callbacks_dict: dict[float, Callable[..., Any]] = {} + + def shared_response_handler(msg: MsgT, _: TopicT) -> None: + res = self._decodeRPCRes(msg) # type: ignore[arg-type] + res_id = res.get("id") + if res_id is None: + return + + # Look up callback for this msg_id + with self._response_subs_lock: + callback = callbacks_dict.pop(res_id, None) + + if callback is None: + return # No callback registered (already handled or timed out) + + # Check if response contains an exception + exc_data = res.get("exception") + if exc_data: + # Reconstruct the exception and pass it to the callback + from typing import cast + + from dimos.protocol.rpc.rpc_utils import SerializedException + + exc = deserialize_exception(cast("SerializedException", exc_data)) + callback(exc) + else: + # Normal response - pass the result + callback(res.get("res")) + + # Create single shared subscription + unsub = self.subscribe(topic_res, shared_response_handler) + self._response_subs[topic_res_key] = (unsub, callbacks_dict) + + # Register this call's callback + _, callbacks_dict = self._response_subs[topic_res_key] + callbacks_dict[msg_id] = cb + + # Publish request + self.publish(topic_req, self._encodeRPCReq(req)) # type: ignore[arg-type] + + # Return unsubscribe function that removes this callback from the dict + def unsubscribe_callback() -> None: + with self._response_subs_lock: + if topic_res_key in self._response_subs: + _, callbacks_dict = self._response_subs[topic_res_key] + callbacks_dict.pop(msg_id, None) + + return unsubscribe_callback + + def call_nowait(self, name: str, arguments: Args) -> None: + topic_req = self.topicgen(name, False) + req: RPCReq = {"name": name, "args": arguments, "id": None} + self.publish(topic_req, self._encodeRPCReq(req)) # type: ignore[arg-type] + + def serve_rpc(self, f: FunctionType, name: str | None = None): # type: ignore[no-untyped-def, override] + if not name: + name = f.__name__ + + topic_req = self.topicgen(name, False) + topic_res = self.topicgen(name, True) + + def receive_call(msg: MsgT, _: TopicT) -> None: + req = self._decodeRPCReq(msg) # type: ignore[arg-type] + + if req.get("name") != name: + return + + args = req.get("args") + if args is None: + return + + # Execute RPC handler in a separate thread to avoid deadlock when + # the handler makes nested RPC calls. + def execute_and_respond() -> None: + try: + response = f(*args[0], **args[1]) + req_id = req.get("id") + if req_id is not None: + self.publish(topic_res, self._encodeRPCRes({"id": req_id, "res": response})) # type: ignore[arg-type] + + except Exception as e: + logger.exception(f"Exception in RPC handler for {name}: {e}", exc_info=e) + # Send exception data to client if this was a request with an ID + req_id = req.get("id") + if req_id is not None: + exc_data = serialize_exception(e) + # Type ignore: SerializedException is compatible with dict[str, Any] + self.publish( + topic_res, + self._encodeRPCRes({"id": req_id, "exception": exc_data}), # type: ignore[typeddict-item, arg-type] + ) + + # Always use thread pool to execute RPC handlers (prevents deadlock) + self._get_call_thread_pool().submit(execute_and_respond) + + return self.subscribe(topic_req, receive_call) + + +class LCMRPC(PubSubRPCMixin[Topic, Any], PickleLCM): + def __init__(self, **kwargs: Any) -> None: + # Need to ensure PickleLCM gets initialized properly + # This is due to the diamond inheritance pattern with multiple base classes + PickleLCM.__init__(self, **kwargs) + # Initialize PubSubRPCMixin's thread pool + PubSubRPCMixin.__init__(self, **kwargs) + + def topicgen(self, name: str, req_or_res: bool) -> Topic: + suffix = "res" if req_or_res else "req" + topic = f"/rpc/{name}/{suffix}" + if len(topic) > LCM_MAX_CHANNEL_NAME_LENGTH: + topic = f"/rpc/{short_id(name)}/{suffix}" + return Topic(topic=topic) + + +class ShmRPC(PubSubRPCMixin[str, Any], PickleSharedMemory): + def __init__(self, prefer: str = "cpu", **kwargs: Any) -> None: + # Need to ensure SharedMemory gets initialized properly + # This is due to the diamond inheritance pattern with multiple base classes + PickleSharedMemory.__init__(self, prefer=prefer, **kwargs) + # Initialize PubSubRPCMixin's thread pool + PubSubRPCMixin.__init__(self, **kwargs) + + def topicgen(self, name: str, req_or_res: bool) -> str: + suffix = "res" if req_or_res else "req" + return f"/rpc/{name}/{suffix}" diff --git a/dimos/protocol/rpc/redisrpc.py b/dimos/protocol/rpc/redisrpc.py new file mode 100644 index 0000000000..aa8a5b87c5 --- /dev/null +++ b/dimos/protocol/rpc/redisrpc.py @@ -0,0 +1,21 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.protocol.pubsub.redispubsub import Redis +from dimos.protocol.rpc.pubsubrpc import PubSubRPCMixin + + +class RedisRPC(PubSubRPCMixin, Redis): # type: ignore[type-arg] + def topicgen(self, name: str, req_or_res: bool) -> str: + return f"/rpc/{name}/{'res' if req_or_res else 'req'}" diff --git a/dimos/protocol/rpc/rpc_utils.py b/dimos/protocol/rpc/rpc_utils.py new file mode 100644 index 0000000000..26ab281e45 --- /dev/null +++ b/dimos/protocol/rpc/rpc_utils.py @@ -0,0 +1,104 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for serializing and deserializing exceptions for RPC transport.""" + +from __future__ import annotations + +import traceback +from typing import Any, TypedDict + + +class SerializedException(TypedDict): + """Type for serialized exception data.""" + + type_name: str + type_module: str + args: tuple[Any, ...] + traceback: str + + +class RemoteError(Exception): + """Exception that was raised on a remote RPC server. + + Preserves the original exception type and full stack trace from the remote side. + """ + + def __init__( + self, type_name: str, type_module: str, args: tuple[Any, ...], traceback: str + ) -> None: + super().__init__(*args if args else (f"Remote exception: {type_name}",)) + self.remote_type = f"{type_module}.{type_name}" + self.remote_traceback = traceback + + def __str__(self) -> str: + base_msg = super().__str__() + return ( + f"[Remote {self.remote_type}] {base_msg}\n\nRemote traceback:\n{self.remote_traceback}" + ) + + +def serialize_exception(exc: Exception) -> SerializedException: + """Convert an exception to a transferable format. + + Args: + exc: The exception to serialize + + Returns: + A dictionary containing the exception data that can be transferred + """ + # Get the full traceback as a string + tb_str = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__)) + + return SerializedException( + type_name=type(exc).__name__, + type_module=type(exc).__module__, + args=exc.args, + traceback=tb_str, + ) + + +def deserialize_exception(exc_data: SerializedException) -> Exception: + """Reconstruct an exception from serialized data. + + For builtin exceptions, instantiates the actual type. + For custom exceptions, returns a RemoteError. + + Args: + exc_data: The serialized exception data + + Returns: + An exception that can be raised with full type and traceback info + """ + type_name = exc_data.get("type_name", "Exception") + type_module = exc_data.get("type_module", "builtins") + args: tuple[Any, ...] = exc_data.get("args", ()) + tb_str = exc_data.get("traceback", "") + + # Only reconstruct builtin exceptions + if type_module == "builtins": + try: + import builtins + + exc_class = getattr(builtins, type_name, None) + if exc_class and issubclass(exc_class, BaseException): + exc = exc_class(*args) + # Add remote traceback as __cause__ for context + exc.__cause__ = RemoteError(type_name, type_module, args, tb_str) + return exc # type: ignore[no-any-return] + except (AttributeError, TypeError): + pass + + # Use RemoteError for non-builtin or if reconstruction failed + return RemoteError(type_name, type_module, args, tb_str) diff --git a/dimos/protocol/rpc/spec.py b/dimos/protocol/rpc/spec.py new file mode 100644 index 0000000000..a09d4bfaab --- /dev/null +++ b/dimos/protocol/rpc/spec.py @@ -0,0 +1,100 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from collections.abc import Callable +import threading +from typing import Any, Protocol, overload + + +class Empty: ... + + +Args = tuple[list, dict[str, Any]] # type: ignore[type-arg] + + +# module that we can inspect for RPCs +class RPCInspectable(Protocol): + @property + def rpcs(self) -> dict[str, Callable]: ... # type: ignore[type-arg] + + +class RPCClient(Protocol): + # if we don't provide callback, we don't get a return unsub f + @overload + def call(self, name: str, arguments: Args, cb: None) -> None: ... + + # if we provide callback, we do get return unsub f + @overload + def call(self, name: str, arguments: Args, cb: Callable[[Any], None]) -> Callable[[], Any]: ... + + def call(self, name: str, arguments: Args, cb: Callable | None) -> Callable[[], Any] | None: ... # type: ignore[type-arg] + + # we expect to crash if we don't get a return value after 10 seconds + # but callers can override this timeout for extra long functions + def call_sync( + self, name: str, arguments: Args, rpc_timeout: float | None = 120.0 + ) -> tuple[Any, Callable[[], None]]: + event = threading.Event() + + def receive_value(val) -> None: # type: ignore[no-untyped-def] + event.result = val # type: ignore[attr-defined] # attach to event + event.set() + + unsub_fn = self.call(name, arguments, receive_value) + if not event.wait(rpc_timeout): + raise TimeoutError(f"RPC call to '{name}' timed out after {rpc_timeout} seconds") + + # Check if the result is an exception and raise it + result = event.result # type: ignore[attr-defined] + if isinstance(result, BaseException): + raise result + + return result, unsub_fn + + async def call_async(self, name: str, arguments: Args) -> Any: + loop = asyncio.get_event_loop() + future = loop.create_future() + + def receive_value(val) -> None: # type: ignore[no-untyped-def] + try: + # Check if the value is an exception + if isinstance(val, BaseException): + loop.call_soon_threadsafe(future.set_exception, val) + else: + loop.call_soon_threadsafe(future.set_result, val) + except Exception as e: + loop.call_soon_threadsafe(future.set_exception, e) + + self.call(name, arguments, receive_value) + + return await future + + +class RPCServer(Protocol): + def serve_rpc(self, f: Callable, name: str) -> Callable[[], None]: ... # type: ignore[type-arg] + + def serve_module_rpc(self, module: RPCInspectable, name: str | None = None) -> None: + for fname in module.rpcs.keys(): + if not name: + name = module.__class__.__name__ + + def override_f(*args, fname=fname, **kwargs): # type: ignore[no-untyped-def] + return getattr(module, fname)(*args, **kwargs) + + topic = name + "/" + fname + self.serve_rpc(override_f, topic) + + +class RPCSpec(RPCServer, RPCClient): ... diff --git a/dimos/protocol/rpc/test_lcmrpc.py b/dimos/protocol/rpc/test_lcmrpc.py new file mode 100644 index 0000000000..f31d20cf19 --- /dev/null +++ b/dimos/protocol/rpc/test_lcmrpc.py @@ -0,0 +1,45 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Generator + +import pytest + +from dimos.constants import LCM_MAX_CHANNEL_NAME_LENGTH +from dimos.protocol.rpc import LCMRPC + + +@pytest.fixture +def lcmrpc() -> Generator[LCMRPC, None, None]: + ret = LCMRPC() + ret.start() + yield ret + ret.stop() + + +def test_short_name(lcmrpc) -> None: + actual = lcmrpc.topicgen("Hello/say", req_or_res=True) + assert actual.topic == "/rpc/Hello/say/res" + + +def test_long_name(lcmrpc) -> None: + long = "GreatyLongComplexExampleClassNameForTestingStuff/create" + long_topic = lcmrpc.topicgen(long, req_or_res=True).topic + assert long_topic == "/rpc/2cudPuFGMJdWxM5KZb/res" + + less_long = long[:-1] + less_long_topic = lcmrpc.topicgen(less_long, req_or_res=True).topic + assert less_long_topic == "/rpc/GreatyLongComplexExampleClassNameForTestingStuff/creat/res" + + assert len(less_long_topic) == LCM_MAX_CHANNEL_NAME_LENGTH diff --git a/dimos/protocol/rpc/test_rpc_utils.py b/dimos/protocol/rpc/test_rpc_utils.py new file mode 100644 index 0000000000..b5e6253aaf --- /dev/null +++ b/dimos/protocol/rpc/test_rpc_utils.py @@ -0,0 +1,70 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for RPC exception serialization utilities.""" + +from dimos.protocol.rpc.rpc_utils import ( + RemoteError, + deserialize_exception, + serialize_exception, +) + + +def test_exception_builtin_serialization(): + """Test serialization and deserialization of exceptions.""" + + # Test with a builtin exception + try: + raise ValueError("test error", 42) + except ValueError as e: + serialized = serialize_exception(e) + + # Check serialized format + assert serialized["type_name"] == "ValueError" + assert serialized["type_module"] == "builtins" + assert serialized["args"] == ("test error", 42) + assert "Traceback" in serialized["traceback"] + assert "test error" in serialized["traceback"] + + # Deserialize and check we get a real ValueError back + deserialized = deserialize_exception(serialized) + assert isinstance(deserialized, ValueError) + assert deserialized.args == ("test error", 42) + # Check that remote traceback is attached as cause + assert isinstance(deserialized.__cause__, RemoteError) + assert "test error" in deserialized.__cause__.remote_traceback + + +def test_exception_custom_serialization(): + # Test with a custom exception + class CustomError(Exception): + pass + + try: + raise CustomError("custom message") + except CustomError as e: + serialized = serialize_exception(e) + + # Check serialized format + assert serialized["type_name"] == "CustomError" + # Module name varies when running under pytest vs directly + assert serialized["type_module"] in ("__main__", "dimos.protocol.rpc.test_rpc_utils") + assert serialized["args"] == ("custom message",) + + # Deserialize - should get RemoteError since it's not builtin + deserialized = deserialize_exception(serialized) + assert isinstance(deserialized, RemoteError) + assert "CustomError" in deserialized.remote_type + assert "custom message" in str(deserialized) + assert "custom message" in deserialized.remote_traceback diff --git a/dimos/protocol/rpc/test_spec.py b/dimos/protocol/rpc/test_spec.py new file mode 100644 index 0000000000..9fb8f65eb7 --- /dev/null +++ b/dimos/protocol/rpc/test_spec.py @@ -0,0 +1,398 @@ +#!/usr/bin/env python3 + +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Grid tests for RPC implementations to ensure spec compliance.""" + +import asyncio +from collections.abc import Callable +from contextlib import contextmanager +import threading +import time +from typing import Any + +import pytest + +from dimos.protocol.rpc.pubsubrpc import LCMRPC, ShmRPC +from dimos.protocol.rpc.rpc_utils import RemoteError + + +class CustomTestError(Exception): + """Custom exception for testing.""" + + pass + + +# Build testdata list with available implementations +testdata: list[tuple[Callable[[], Any], str]] = [] + + +# Context managers for different RPC implementations +@contextmanager +def lcm_rpc_context(): + """Context manager for LCMRPC implementation.""" + from dimos.protocol.service.lcmservice import autoconf + + autoconf() + server = LCMRPC() + client = LCMRPC() + server.start() + client.start() + + try: + yield server, client + finally: + server.stop() + client.stop() + + +testdata.append((lcm_rpc_context, "lcm")) + + +@contextmanager +def shm_rpc_context(): + """Context manager for Shared Memory RPC implementation.""" + # Create two separate instances that communicate through shared memory segments + server = ShmRPC(prefer="cpu") + client = ShmRPC(prefer="cpu") + server.start() + client.start() + + try: + yield server, client + finally: + server.stop() + client.stop() + + +testdata.append((shm_rpc_context, "shm")) + +# Try to add RedisRPC if available +try: + from dimos.protocol.rpc.redisrpc import RedisRPC + + @contextmanager + def redis_rpc_context(): + """Context manager for RedisRPC implementation.""" + server = RedisRPC() + client = RedisRPC() + server.start() + client.start() + + try: + yield server, client + finally: + server.stop() + client.stop() + + testdata.append((redis_rpc_context, "redis")) +except (ImportError, ConnectionError): + print("RedisRPC not available") + + +# Test functions that will be served +def add_function(a: int, b: int) -> int: + """Simple addition function for testing.""" + return a + b + + +def failing_function(msg: str) -> str: + """Function that raises exceptions for testing.""" + if msg == "fail": + raise ValueError("Test error message") + elif msg == "custom": + raise CustomTestError("Custom error") + return f"Success: {msg}" + + +def slow_function(delay: float) -> str: + """Function that takes time to execute.""" + time.sleep(delay) + return f"Completed after {delay} seconds" + + +# Grid tests + + +@pytest.mark.parametrize("rpc_context, impl_name", testdata) +def test_basic_sync_call(rpc_context, impl_name: str) -> None: + """Test basic synchronous RPC calls.""" + with rpc_context() as (server, client): + # Serve the function + unsub = server.serve_rpc(add_function, "add") + + try: + # Make sync call + result, _ = client.call_sync("add", ([5, 3], {}), rpc_timeout=2.0) + assert result == 8 + + # Test with different arguments + result, _ = client.call_sync("add", ([10, -2], {}), rpc_timeout=2.0) + assert result == 8 + + finally: + unsub() + + +@pytest.mark.parametrize("rpc_context, impl_name", testdata) +@pytest.mark.asyncio +@pytest.mark.skip( + reason="Async RPC calls have a deadlock issue when run in the full test suite (works in isolation)" +) +async def test_async_call(rpc_context, impl_name: str) -> None: + """Test asynchronous RPC calls.""" + with rpc_context() as (server, client): + # Serve the function + unsub = server.serve_rpc(add_function, "add_async") + + try: + # Make async call + result = await client.call_async("add_async", ([7, 4], {})) + assert result == 11 + + # Test multiple async calls + results = await asyncio.gather( + client.call_async("add_async", ([1, 2], {})), + client.call_async("add_async", ([3, 4], {})), + client.call_async("add_async", ([5, 6], {})), + ) + assert results == [3, 7, 11] + + finally: + unsub() + + +@pytest.mark.parametrize("rpc_context, impl_name", testdata) +def test_callback_call(rpc_context, impl_name: str) -> None: + """Test callback-based RPC calls.""" + with rpc_context() as (server, client): + # Serve the function + unsub_server = server.serve_rpc(add_function, "add_callback") + + try: + # Test with callback + event = threading.Event() + received_value = None + + def callback(val) -> None: + nonlocal received_value + received_value = val + event.set() + + client.call("add_callback", ([20, 22], {}), callback) + assert event.wait(2.0) + assert received_value == 42 + + finally: + unsub_server() + + +@pytest.mark.parametrize("rpc_context, impl_name", testdata) +def test_exception_handling_sync(rpc_context, impl_name: str) -> None: + """Test that exceptions are properly passed through sync RPC calls.""" + with rpc_context() as (server, client): + # Serve the function that can raise exceptions + unsub = server.serve_rpc(failing_function, "test_exc") + + try: + # Test successful call + result, _ = client.call_sync("test_exc", (["ok"], {}), rpc_timeout=2.0) + assert result == "Success: ok" + + # Test builtin exception - should raise actual ValueError + with pytest.raises(ValueError) as exc_info: + client.call_sync("test_exc", (["fail"], {}), rpc_timeout=2.0) + assert "Test error message" in str(exc_info.value) + # Check that the cause contains the remote traceback + assert isinstance(exc_info.value.__cause__, RemoteError) + assert "failing_function" in exc_info.value.__cause__.remote_traceback + + # Test custom exception - should raise RemoteError + with pytest.raises(RemoteError) as exc_info: + client.call_sync("test_exc", (["custom"], {}), rpc_timeout=2.0) + assert "Custom error" in str(exc_info.value) + assert "CustomTestError" in exc_info.value.remote_type + assert "failing_function" in exc_info.value.remote_traceback + + finally: + unsub() + + +@pytest.mark.parametrize("rpc_context, impl_name", testdata) +@pytest.mark.asyncio +async def test_exception_handling_async(rpc_context, impl_name: str) -> None: + """Test that exceptions are properly passed through async RPC calls.""" + with rpc_context() as (server, client): + # Serve the function that can raise exceptions + unsub = server.serve_rpc(failing_function, "test_exc_async") + + try: + # Test successful call + result = await client.call_async("test_exc_async", (["ok"], {})) + assert result == "Success: ok" + + # Test builtin exception + with pytest.raises(ValueError) as exc_info: + await client.call_async("test_exc_async", (["fail"], {})) + assert "Test error message" in str(exc_info.value) + assert isinstance(exc_info.value.__cause__, RemoteError) + + # Test custom exception + with pytest.raises(RemoteError) as exc_info: + await client.call_async("test_exc_async", (["custom"], {})) + assert "Custom error" in str(exc_info.value) + assert "CustomTestError" in exc_info.value.remote_type + + finally: + unsub() + + +@pytest.mark.parametrize("rpc_context, impl_name", testdata) +def test_exception_handling_callback(rpc_context, impl_name: str) -> None: + """Test that exceptions are properly passed through callback-based RPC calls.""" + with rpc_context() as (server, client): + # Serve the function that can raise exceptions + unsub_server = server.serve_rpc(failing_function, "test_exc_cb") + + try: + # Test with callback - exception should be passed to callback + event = threading.Event() + received_value = None + + def callback(val) -> None: + nonlocal received_value + received_value = val + event.set() + + # Test successful call + client.call("test_exc_cb", (["ok"], {}), callback) + assert event.wait(2.0) + assert received_value == "Success: ok" + event.clear() + + # Test failed call - exception should be passed to callback + client.call("test_exc_cb", (["fail"], {}), callback) + assert event.wait(2.0) + assert isinstance(received_value, ValueError) + assert "Test error message" in str(received_value) + assert isinstance(received_value.__cause__, RemoteError) + + finally: + unsub_server() + + +@pytest.mark.parametrize("rpc_context, impl_name", testdata) +def test_timeout(rpc_context, impl_name: str) -> None: + """Test that RPC calls properly timeout.""" + with rpc_context() as (server, client): + # Serve a slow function + unsub = server.serve_rpc(slow_function, "slow") + + try: + # Call with short timeout should fail + # Using 10 seconds sleep to ensure it would definitely timeout + with pytest.raises(TimeoutError) as exc_info: + client.call_sync("slow", ([2.0], {}), rpc_timeout=0.1) + assert "timed out" in str(exc_info.value) + + # Call with sufficient timeout should succeed + result, _ = client.call_sync("slow", ([0.01], {}), rpc_timeout=1.0) + assert "Completed after 0.01 seconds" in result + + finally: + unsub() + + +@pytest.mark.parametrize("rpc_context, impl_name", testdata) +def test_nonexistent_service(rpc_context, impl_name: str) -> None: + """Test calling a service that doesn't exist.""" + with rpc_context() as (_server, client): + # Don't serve any function, just try to call + with pytest.raises(TimeoutError) as exc_info: + client.call_sync("nonexistent", ([1, 2], {}), rpc_timeout=0.1) + assert "nonexistent" in str(exc_info.value) + assert "timed out" in str(exc_info.value) + + +@pytest.mark.parametrize("rpc_context, impl_name", testdata) +def test_multiple_services(rpc_context, impl_name: str) -> None: + """Test serving multiple RPC functions simultaneously.""" + with rpc_context() as (server, client): + # Serve multiple functions + unsub1 = server.serve_rpc(add_function, "service1") + unsub2 = server.serve_rpc(lambda x: x * 2, "service2") + unsub3 = server.serve_rpc(lambda s: s.upper(), "service3") + + try: + # Call all services + result1, _ = client.call_sync("service1", ([3, 4], {}), rpc_timeout=1.0) + assert result1 == 7 + + result2, _ = client.call_sync("service2", ([21], {}), rpc_timeout=1.0) + assert result2 == 42 + + result3, _ = client.call_sync("service3", (["hello"], {}), rpc_timeout=1.0) + assert result3 == "HELLO" + + finally: + unsub1() + unsub2() + unsub3() + + +@pytest.mark.parametrize("rpc_context, impl_name", testdata) +def test_concurrent_calls(rpc_context, impl_name: str) -> None: + """Test making multiple concurrent RPC calls.""" + # Skip for SharedMemory - double-buffered architecture can't handle concurrent bursts + # The channel only holds 2 frames, so 1000 rapid concurrent responses overwrite each other + if impl_name == "shm": + pytest.skip("SharedMemory uses double-buffering; can't handle 1000 concurrent responses") + + with rpc_context() as (server, client): + # Serve a function that we'll call concurrently + unsub = server.serve_rpc(add_function, "concurrent_add") + + try: + # Make multiple concurrent calls using threads + results = [] + threads = [] + + def make_call(a, b) -> None: + result, _ = client.call_sync("concurrent_add", ([a, b], {}), rpc_timeout=2.0) + results.append(result) + + # Start 1000 concurrent calls + for i in range(1000): + t = threading.Thread(target=make_call, args=(i, i + 1)) + threads.append(t) + t.start() + + # Wait for all threads to complete + for t in threads: + t.join(timeout=10.0) + + # Verify all calls succeeded + assert len(results) == 1000 + # Results should be [1, 3, 5, 7, 9, 11, 13, 15, 17, 19] but may be in any order + expected = [i + (i + 1) for i in range(1000)] + assert sorted(results) == sorted(expected) + + finally: + unsub() + + +if __name__ == "__main__": + # Run tests for debugging + pytest.main([__file__, "-v"]) diff --git a/dimos/protocol/service/__init__.py b/dimos/protocol/service/__init__.py new file mode 100644 index 0000000000..4726ad5f83 --- /dev/null +++ b/dimos/protocol/service/__init__.py @@ -0,0 +1,2 @@ +from dimos.protocol.service.lcmservice import LCMService +from dimos.protocol.service.spec import Configurable, Service diff --git a/dimos/protocol/service/lcmservice.py b/dimos/protocol/service/lcmservice.py new file mode 100644 index 0000000000..5f2a5db7ca --- /dev/null +++ b/dimos/protocol/service/lcmservice.py @@ -0,0 +1,343 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from functools import cache +import os +import subprocess +import sys +import threading +import traceback +from typing import Protocol, runtime_checkable + +import lcm + +from dimos.protocol.service.spec import Service +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +@cache +def check_root() -> bool: + """Return True if the current process is running as root (UID 0).""" + try: + return os.geteuid() == 0 + except AttributeError: + # Platforms without geteuid (e.g. Windows) – assume non-root. + return False + + +def check_multicast() -> list[str]: + """Check if multicast configuration is needed and return required commands.""" + commands_needed = [] + + sudo = "" if check_root() else "sudo " + + # Check if loopback interface has multicast enabled + try: + result = subprocess.run(["ip", "link", "show", "lo"], capture_output=True, text=True) + if "MULTICAST" not in result.stdout: + commands_needed.append(f"{sudo}ifconfig lo multicast") + except Exception: + commands_needed.append(f"{sudo}ifconfig lo multicast") + + # Check if multicast route exists + try: + result = subprocess.run( + ["ip", "route", "show", "224.0.0.0/4"], capture_output=True, text=True + ) + if not result.stdout.strip(): + commands_needed.append(f"{sudo}route add -net 224.0.0.0 netmask 240.0.0.0 dev lo") + except Exception: + commands_needed.append(f"{sudo}route add -net 224.0.0.0 netmask 240.0.0.0 dev lo") + + return commands_needed + + +def check_buffers() -> tuple[list[str], int | None]: + """Check if buffer configuration is needed and return required commands and current size. + + Returns: + Tuple of (commands_needed, current_max_buffer_size) + """ + commands_needed = [] + current_max = None + + sudo = "" if check_root() else "sudo " + + # Check current buffer settings + try: + result = subprocess.run(["sysctl", "net.core.rmem_max"], capture_output=True, text=True) + current_max = int(result.stdout.split("=")[1].strip()) if result.returncode == 0 else None + if not current_max or current_max < 67108864: + commands_needed.append(f"{sudo}sysctl -w net.core.rmem_max=67108864") + except: + commands_needed.append(f"{sudo}sysctl -w net.core.rmem_max=67108864") + + try: + result = subprocess.run(["sysctl", "net.core.rmem_default"], capture_output=True, text=True) + current_default = ( + int(result.stdout.split("=")[1].strip()) if result.returncode == 0 else None + ) + if not current_default or current_default < 16777216: + commands_needed.append(f"{sudo}sysctl -w net.core.rmem_default=16777216") + except: + commands_needed.append(f"{sudo}sysctl -w net.core.rmem_default=16777216") + + return commands_needed, current_max + + +def check_system() -> None: + """Check if system configuration is needed and exit only for critical issues. + + Multicast configuration is critical for LCM to work. + Buffer sizes are performance optimizations - warn but don't fail in containers. + """ + if os.environ.get("CI"): + logger.debug("CI environment detected: Skipping system configuration checks.") + return + + multicast_commands = check_multicast() + buffer_commands, current_buffer_size = check_buffers() + + # Check multicast first - this is critical + if multicast_commands: + logger.error( + "Critical: Multicast configuration required. Please run the following commands:" + ) + for cmd in multicast_commands: + logger.error(f" {cmd}") + logger.error("\nThen restart your application.") + sys.exit(1) + + # Buffer configuration is just for performance + elif buffer_commands: + if current_buffer_size: + logger.warning( + f"UDP buffer size limited to {current_buffer_size} bytes ({current_buffer_size // 1024}KB). Large LCM packets may fail." + ) + else: + logger.warning("UDP buffer sizes are limited. Large LCM packets may fail.") + logger.warning("For better performance, consider running:") + for cmd in buffer_commands: + logger.warning(f" {cmd}") + logger.warning("Note: This may not be possible in Docker containers.") + + +def autoconf() -> None: + """Auto-configure system by running checks and executing required commands if needed.""" + if os.environ.get("CI"): + logger.info("CI environment detected: Skipping automatic system configuration.") + return + + commands_needed = [] + + # Check multicast configuration + commands_needed.extend(check_multicast()) + + # Check buffer configuration + buffer_commands, _ = check_buffers() + commands_needed.extend(buffer_commands) + + if not commands_needed: + return + + logger.info("System configuration required. Executing commands...") + + for cmd in commands_needed: + logger.info(f" Running: {cmd}") + try: + # Split command into parts for subprocess + cmd_parts = cmd.split() + subprocess.run(cmd_parts, capture_output=True, text=True, check=True) + logger.info(" ✓ Success") + except subprocess.CalledProcessError as e: + # Check if this is a multicast/route command or a sysctl command + if "route" in cmd or "multicast" in cmd: + # Multicast/route failures should still fail + logger.error(f" ✗ Failed to configure multicast: {e}") + logger.error(f" stdout: {e.stdout}") + logger.error(f" stderr: {e.stderr}") + raise + elif "sysctl" in cmd: + # Sysctl failures are just warnings (likely docker/container) + logger.warning( + f" ✗ Not able to auto-configure UDP buffer sizes (likely docker image): {e}" + ) + except Exception as e: + logger.error(f" ✗ Error: {e}") + if "route" in cmd or "multicast" in cmd: + raise + + logger.info("System configuration completed.") + + +@dataclass +class LCMConfig: + ttl: int = 0 + url: str | None = None + autoconf: bool = True + lcm: lcm.LCM | None = None + + +@runtime_checkable +class LCMMsg(Protocol): + msg_name: str + + @classmethod + def lcm_decode(cls, data: bytes) -> LCMMsg: + """Decode bytes into an LCM message instance.""" + ... + + def lcm_encode(self) -> bytes: + """Encode this message instance into bytes.""" + ... + + +@dataclass +class Topic: + topic: str = "" + lcm_type: type[LCMMsg] | None = None + + def __str__(self) -> str: + if self.lcm_type is None: + return self.topic + return f"{self.topic}#{self.lcm_type.msg_name}" + + +class LCMService(Service[LCMConfig]): + default_config = LCMConfig + l: lcm.LCM | None + _stop_event: threading.Event + _l_lock: threading.Lock + _thread: threading.Thread | None + _call_thread_pool: ThreadPoolExecutor | None = None + _call_thread_pool_lock: threading.RLock = threading.RLock() + + def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] + super().__init__(**kwargs) + + # we support passing an existing LCM instance + if self.config.lcm: + # TODO: If we pass LCM in, it's unsafe to use in this thread and the _loop thread. + self.l = self.config.lcm + else: + self.l = lcm.LCM(self.config.url) if self.config.url else lcm.LCM() + + self._l_lock = threading.Lock() + + self._stop_event = threading.Event() + self._thread = None + + def __getstate__(self): # type: ignore[no-untyped-def] + """Exclude unpicklable runtime attributes when serializing.""" + state = self.__dict__.copy() + # Remove unpicklable attributes + state.pop("l", None) + state.pop("_stop_event", None) + state.pop("_thread", None) + state.pop("_l_lock", None) + state.pop("_call_thread_pool", None) + state.pop("_call_thread_pool_lock", None) + return state + + def __setstate__(self, state) -> None: # type: ignore[no-untyped-def] + """Restore object from pickled state.""" + self.__dict__.update(state) + # Reinitialize runtime attributes + self.l = None + self._stop_event = threading.Event() + self._thread = None + self._l_lock = threading.Lock() + self._call_thread_pool = None + self._call_thread_pool_lock = threading.RLock() + + def start(self) -> None: + # Reinitialize LCM if it's None (e.g., after unpickling) + if self.l is None: + if self.config.lcm: + self.l = self.config.lcm + else: + self.l = lcm.LCM(self.config.url) if self.config.url else lcm.LCM() + + if self.config.autoconf: + autoconf() + else: + try: + check_system() + except Exception as e: + print(f"Error checking system configuration: {e}") + + self._stop_event.clear() + self._thread = threading.Thread(target=self._lcm_loop) + self._thread.daemon = True + self._thread.start() + + def _lcm_loop(self) -> None: + """LCM message handling loop.""" + while not self._stop_event.is_set(): + try: + with self._l_lock: + if self.l is None: + break + self.l.handle_timeout(50) + except Exception as e: + stack_trace = traceback.format_exc() + print(f"Error in LCM handling: {e}\n{stack_trace}") + + def stop(self) -> None: + """Stop the LCM loop.""" + self._stop_event.set() + if self._thread is not None: + # Only join if we're not the LCM thread (avoid "cannot join current thread") + if threading.current_thread() != self._thread: + self._thread.join(timeout=1.0) + if self._thread.is_alive(): + logger.warning("LCM thread did not stop cleanly within timeout") + + # Clean up LCM instance if we created it + if not self.config.lcm: + with self._l_lock: + if self.l is not None: + del self.l + self.l = None + + with self._call_thread_pool_lock: + if self._call_thread_pool: + # Check if we're being called from within the thread pool + # If so, we can't wait for shutdown (would cause "cannot join current thread") + current_thread = threading.current_thread() + is_pool_thread = False + + # Check if current thread is one of the pool's threads + # ThreadPoolExecutor threads have names like "ThreadPoolExecutor-N_M" + if hasattr(self._call_thread_pool, "_threads"): + is_pool_thread = current_thread in self._call_thread_pool._threads + elif "ThreadPoolExecutor" in current_thread.name: + # Fallback: check thread name pattern + is_pool_thread = True + + # Don't wait if we're in a pool thread to avoid deadlock + self._call_thread_pool.shutdown(wait=not is_pool_thread) + self._call_thread_pool = None + + def _get_call_thread_pool(self) -> ThreadPoolExecutor: + with self._call_thread_pool_lock: + if self._call_thread_pool is None: + self._call_thread_pool = ThreadPoolExecutor(max_workers=4) + return self._call_thread_pool diff --git a/dimos/protocol/service/spec.py b/dimos/protocol/service/spec.py new file mode 100644 index 0000000000..c4e6758614 --- /dev/null +++ b/dimos/protocol/service/spec.py @@ -0,0 +1,38 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC +from typing import Generic, TypeVar + +# Generic type for service configuration +ConfigT = TypeVar("ConfigT") + + +class Configurable(Generic[ConfigT]): + default_config: type[ConfigT] + + def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] + self.config: ConfigT = self.default_config(**kwargs) + + +class Service(Configurable[ConfigT], ABC): + def start(self) -> None: + # Only call super().start() if it exists + if hasattr(super(), "start"): + super().start() # type: ignore[misc] + + def stop(self) -> None: + # Only call super().stop() if it exists + if hasattr(super(), "stop"): + super().stop() # type: ignore[misc] diff --git a/dimos/protocol/service/test_lcmservice.py b/dimos/protocol/service/test_lcmservice.py new file mode 100644 index 0000000000..ed3f22650b --- /dev/null +++ b/dimos/protocol/service/test_lcmservice.py @@ -0,0 +1,426 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import subprocess +from unittest.mock import patch + +import pytest + +from dimos.protocol.service.lcmservice import ( + autoconf, + check_buffers, + check_multicast, + check_root, +) + + +def get_sudo_prefix() -> str: + """Return 'sudo ' if not running as root, empty string if running as root.""" + return "" if check_root() else "sudo " + + +def test_check_multicast_all_configured() -> None: + """Test check_multicast when system is properly configured.""" + with patch("dimos.protocol.service.lcmservice.subprocess.run") as mock_run: + # Mock successful checks with realistic output format + mock_run.side_effect = [ + type( + "MockResult", + (), + { + "stdout": "1: lo: mtu 65536 qdisc noqueue state UNKNOWN mode DEFAULT group default qlen 1000\n link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00", + "returncode": 0, + }, + )(), + type("MockResult", (), {"stdout": "224.0.0.0/4 dev lo scope link", "returncode": 0})(), + ] + + result = check_multicast() + assert result == [] + + +def test_check_multicast_missing_multicast_flag() -> None: + """Test check_multicast when loopback interface lacks multicast.""" + with patch("dimos.protocol.service.lcmservice.subprocess.run") as mock_run: + # Mock interface without MULTICAST flag (realistic current system state) + mock_run.side_effect = [ + type( + "MockResult", + (), + { + "stdout": "1: lo: mtu 65536 qdisc noqueue state UNKNOWN mode DEFAULT group default qlen 1000\n link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00", + "returncode": 0, + }, + )(), + type("MockResult", (), {"stdout": "224.0.0.0/4 dev lo scope link", "returncode": 0})(), + ] + + result = check_multicast() + sudo = get_sudo_prefix() + assert result == [f"{sudo}ifconfig lo multicast"] + + +def test_check_multicast_missing_route() -> None: + """Test check_multicast when multicast route is missing.""" + with patch("dimos.protocol.service.lcmservice.subprocess.run") as mock_run: + # Mock missing route - interface has multicast but no route + mock_run.side_effect = [ + type( + "MockResult", + (), + { + "stdout": "1: lo: mtu 65536 qdisc noqueue state UNKNOWN mode DEFAULT group default qlen 1000\n link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00", + "returncode": 0, + }, + )(), + type("MockResult", (), {"stdout": "", "returncode": 0})(), # Empty output - no route + ] + + result = check_multicast() + sudo = get_sudo_prefix() + assert result == [f"{sudo}route add -net 224.0.0.0 netmask 240.0.0.0 dev lo"] + + +def test_check_multicast_all_missing() -> None: + """Test check_multicast when both multicast flag and route are missing (current system state).""" + with patch("dimos.protocol.service.lcmservice.subprocess.run") as mock_run: + # Mock both missing - matches actual current system state + mock_run.side_effect = [ + type( + "MockResult", + (), + { + "stdout": "1: lo: mtu 65536 qdisc noqueue state UNKNOWN mode DEFAULT group default qlen 1000\n link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00", + "returncode": 0, + }, + )(), + type("MockResult", (), {"stdout": "", "returncode": 0})(), # Empty output - no route + ] + + result = check_multicast() + sudo = get_sudo_prefix() + expected = [ + f"{sudo}ifconfig lo multicast", + f"{sudo}route add -net 224.0.0.0 netmask 240.0.0.0 dev lo", + ] + assert result == expected + + +def test_check_multicast_subprocess_exception() -> None: + """Test check_multicast when subprocess calls fail.""" + with patch("dimos.protocol.service.lcmservice.subprocess.run") as mock_run: + # Mock subprocess exceptions + mock_run.side_effect = Exception("Command failed") + + result = check_multicast() + sudo = get_sudo_prefix() + expected = [ + f"{sudo}ifconfig lo multicast", + f"{sudo}route add -net 224.0.0.0 netmask 240.0.0.0 dev lo", + ] + assert result == expected + + +def test_check_buffers_all_configured() -> None: + """Test check_buffers when system is properly configured.""" + with patch("dimos.protocol.service.lcmservice.subprocess.run") as mock_run: + # Mock sufficient buffer sizes (64MB for max, 16MB for default) + mock_run.side_effect = [ + type("MockResult", (), {"stdout": "net.core.rmem_max = 67108864", "returncode": 0})(), + type( + "MockResult", (), {"stdout": "net.core.rmem_default = 16777216", "returncode": 0} + )(), + ] + + commands, buffer_size = check_buffers() + assert commands == [] + assert buffer_size == 67108864 + + +def test_check_buffers_low_max_buffer() -> None: + """Test check_buffers when rmem_max is too low.""" + with patch("dimos.protocol.service.lcmservice.subprocess.run") as mock_run: + # Mock low rmem_max (below 64MB minimum) + mock_run.side_effect = [ + type("MockResult", (), {"stdout": "net.core.rmem_max = 1048576", "returncode": 0})(), + type( + "MockResult", (), {"stdout": "net.core.rmem_default = 16777216", "returncode": 0} + )(), + ] + + commands, buffer_size = check_buffers() + sudo = get_sudo_prefix() + assert commands == [f"{sudo}sysctl -w net.core.rmem_max=67108864"] + assert buffer_size == 1048576 + + +def test_check_buffers_low_default_buffer() -> None: + """Test check_buffers when rmem_default is too low.""" + with patch("dimos.protocol.service.lcmservice.subprocess.run") as mock_run: + # Mock low rmem_default (below 16MB minimum) + mock_run.side_effect = [ + type("MockResult", (), {"stdout": "net.core.rmem_max = 67108864", "returncode": 0})(), + type( + "MockResult", (), {"stdout": "net.core.rmem_default = 1048576", "returncode": 0} + )(), + ] + + commands, buffer_size = check_buffers() + sudo = get_sudo_prefix() + assert commands == [f"{sudo}sysctl -w net.core.rmem_default=16777216"] + assert buffer_size == 67108864 + + +def test_check_buffers_both_low() -> None: + """Test check_buffers when both buffer sizes are too low.""" + with patch("dimos.protocol.service.lcmservice.subprocess.run") as mock_run: + # Mock both low (below minimums) + mock_run.side_effect = [ + type("MockResult", (), {"stdout": "net.core.rmem_max = 1048576", "returncode": 0})(), + type( + "MockResult", (), {"stdout": "net.core.rmem_default = 1048576", "returncode": 0} + )(), + ] + + commands, buffer_size = check_buffers() + sudo = get_sudo_prefix() + expected = [ + f"{sudo}sysctl -w net.core.rmem_max=67108864", + f"{sudo}sysctl -w net.core.rmem_default=16777216", + ] + assert commands == expected + assert buffer_size == 1048576 + + +def test_check_buffers_subprocess_exception() -> None: + """Test check_buffers when subprocess calls fail.""" + with patch("dimos.protocol.service.lcmservice.subprocess.run") as mock_run: + # Mock subprocess exceptions + mock_run.side_effect = Exception("Command failed") + + commands, buffer_size = check_buffers() + sudo = get_sudo_prefix() + expected = [ + f"{sudo}sysctl -w net.core.rmem_max=67108864", + f"{sudo}sysctl -w net.core.rmem_default=16777216", + ] + assert commands == expected + assert buffer_size is None + + +def test_check_buffers_parsing_error() -> None: + """Test check_buffers when output parsing fails.""" + with patch("dimos.protocol.service.lcmservice.subprocess.run") as mock_run: + # Mock malformed output + mock_run.side_effect = [ + type("MockResult", (), {"stdout": "invalid output", "returncode": 0})(), + type("MockResult", (), {"stdout": "also invalid", "returncode": 0})(), + ] + + commands, buffer_size = check_buffers() + sudo = get_sudo_prefix() + expected = [ + f"{sudo}sysctl -w net.core.rmem_max=67108864", + f"{sudo}sysctl -w net.core.rmem_default=16777216", + ] + assert commands == expected + assert buffer_size is None + + +def test_check_buffers_dev_container() -> None: + """Test check_buffers in dev container where sysctl fails.""" + with patch("dimos.protocol.service.lcmservice.subprocess.run") as mock_run: + # Mock dev container behavior - sysctl returns non-zero + mock_run.side_effect = [ + type( + "MockResult", + (), + { + "stdout": "sysctl: cannot stat /proc/sys/net/core/rmem_max: No such file or directory", + "returncode": 255, + }, + )(), + type( + "MockResult", + (), + { + "stdout": "sysctl: cannot stat /proc/sys/net/core/rmem_default: No such file or directory", + "returncode": 255, + }, + )(), + ] + + commands, buffer_size = check_buffers() + sudo = get_sudo_prefix() + expected = [ + f"{sudo}sysctl -w net.core.rmem_max=67108864", + f"{sudo}sysctl -w net.core.rmem_default=16777216", + ] + assert commands == expected + assert buffer_size is None + + +def test_autoconf_no_config_needed() -> None: + """Test autoconf when no configuration is needed.""" + # Clear CI environment variable for this test + with patch.dict(os.environ, {"CI": ""}, clear=False): + with patch("dimos.protocol.service.lcmservice.subprocess.run") as mock_run: + # Mock all checks passing with new buffer sizes (64MB and 16MB) + mock_run.side_effect = [ + # check_multicast calls + type( + "MockResult", + (), + { + "stdout": "1: lo: mtu 65536", + "returncode": 0, + }, + )(), + type( + "MockResult", (), {"stdout": "224.0.0.0/4 dev lo scope link", "returncode": 0} + )(), + # check_buffers calls + type( + "MockResult", (), {"stdout": "net.core.rmem_max = 67108864", "returncode": 0} + )(), + type( + "MockResult", + (), + {"stdout": "net.core.rmem_default = 16777216", "returncode": 0}, + )(), + ] + + with patch("dimos.protocol.service.lcmservice.logger") as mock_logger: + autoconf() + # Should not log anything when no config is needed + mock_logger.info.assert_not_called() + mock_logger.error.assert_not_called() + mock_logger.warning.assert_not_called() + + +def test_autoconf_with_config_needed_success() -> None: + """Test autoconf when configuration is needed and commands succeed.""" + # Clear CI environment variable for this test + with patch.dict(os.environ, {"CI": ""}, clear=False): + with patch("dimos.protocol.service.lcmservice.subprocess.run") as mock_run: + # Mock checks failing, then mock the execution succeeding + mock_run.side_effect = [ + # check_multicast calls + type( + "MockResult", + (), + {"stdout": "1: lo: mtu 65536", "returncode": 0}, + )(), + type("MockResult", (), {"stdout": "", "returncode": 0})(), + # check_buffers calls (low buffer sizes) + type( + "MockResult", (), {"stdout": "net.core.rmem_max = 1048576", "returncode": 0} + )(), + type( + "MockResult", (), {"stdout": "net.core.rmem_default = 1048576", "returncode": 0} + )(), + # Command execution calls + type( + "MockResult", (), {"stdout": "success", "returncode": 0} + )(), # ifconfig lo multicast + type("MockResult", (), {"stdout": "success", "returncode": 0})(), # route add... + type("MockResult", (), {"stdout": "success", "returncode": 0})(), # sysctl rmem_max + type( + "MockResult", (), {"stdout": "success", "returncode": 0} + )(), # sysctl rmem_default + ] + + from unittest.mock import call + + with patch("dimos.protocol.service.lcmservice.logger") as mock_logger: + autoconf() + + sudo = get_sudo_prefix() + # Verify the expected log calls with new buffer sizes + expected_info_calls = [ + call("System configuration required. Executing commands..."), + call(f" Running: {sudo}ifconfig lo multicast"), + call(" ✓ Success"), + call(f" Running: {sudo}route add -net 224.0.0.0 netmask 240.0.0.0 dev lo"), + call(" ✓ Success"), + call(f" Running: {sudo}sysctl -w net.core.rmem_max=67108864"), + call(" ✓ Success"), + call(f" Running: {sudo}sysctl -w net.core.rmem_default=16777216"), + call(" ✓ Success"), + call("System configuration completed."), + ] + + mock_logger.info.assert_has_calls(expected_info_calls) + + +def test_autoconf_with_command_failures() -> None: + """Test autoconf when some commands fail.""" + # Clear CI environment variable for this test + with patch.dict(os.environ, {"CI": ""}, clear=False): + with patch("dimos.protocol.service.lcmservice.subprocess.run") as mock_run: + # Mock checks failing, then mock some commands failing + mock_run.side_effect = [ + # check_multicast calls + type( + "MockResult", + (), + {"stdout": "1: lo: mtu 65536", "returncode": 0}, + )(), + type("MockResult", (), {"stdout": "", "returncode": 0})(), + # check_buffers calls (no buffer issues for simpler test, use new minimums) + type( + "MockResult", (), {"stdout": "net.core.rmem_max = 67108864", "returncode": 0} + )(), + type( + "MockResult", + (), + {"stdout": "net.core.rmem_default = 16777216", "returncode": 0}, + )(), + # Command execution calls - first succeeds, second fails + type( + "MockResult", (), {"stdout": "success", "returncode": 0} + )(), # ifconfig lo multicast + subprocess.CalledProcessError( + 1, + [ + *get_sudo_prefix().split(), + "route", + "add", + "-net", + "224.0.0.0", + "netmask", + "240.0.0.0", + "dev", + "lo", + ], + "Permission denied", + "Operation not permitted", + ), + ] + + with patch("dimos.protocol.service.lcmservice.logger") as mock_logger: + # The function should raise on multicast/route failures + with pytest.raises(subprocess.CalledProcessError): + autoconf() + + # Verify it logged the failure before raising + info_calls = [call[0][0] for call in mock_logger.info.call_args_list] + error_calls = [call[0][0] for call in mock_logger.error.call_args_list] + + assert "System configuration required. Executing commands..." in info_calls + assert " ✓ Success" in info_calls # First command succeeded + assert any( + "✗ Failed to configure multicast" in call for call in error_calls + ) # Second command failed diff --git a/dimos/protocol/service/test_spec.py b/dimos/protocol/service/test_spec.py new file mode 100644 index 0000000000..efb24d7e38 --- /dev/null +++ b/dimos/protocol/service/test_spec.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 + +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +from dimos.protocol.service.spec import Service + + +@dataclass +class DatabaseConfig: + host: str = "localhost" + port: int = 5432 + database_name: str = "test_db" + timeout: float = 30.0 + max_connections: int = 10 + ssl_enabled: bool = False + + +class DatabaseService(Service[DatabaseConfig]): + default_config = DatabaseConfig + + def start(self) -> None: ... + def stop(self) -> None: ... + + +def test_default_configuration() -> None: + """Test that default configuration is applied correctly.""" + service = DatabaseService() + + # Check that all default values are set + assert service.config.host == "localhost" + assert service.config.port == 5432 + assert service.config.database_name == "test_db" + assert service.config.timeout == 30.0 + assert service.config.max_connections == 10 + assert service.config.ssl_enabled is False + + +def test_partial_configuration_override() -> None: + """Test that partial configuration correctly overrides defaults.""" + service = DatabaseService(host="production-db", port=3306, ssl_enabled=True) + + # Check overridden values + assert service.config.host == "production-db" + assert service.config.port == 3306 + assert service.config.ssl_enabled is True + + # Check that defaults are preserved for non-overridden values + assert service.config.database_name == "test_db" + assert service.config.timeout == 30.0 + assert service.config.max_connections == 10 + + +def test_complete_configuration_override() -> None: + """Test that all configuration values can be overridden.""" + service = DatabaseService( + host="custom-host", + port=9999, + database_name="custom_db", + timeout=60.0, + max_connections=50, + ssl_enabled=True, + ) + + # Check that all values match the custom config + assert service.config.host == "custom-host" + assert service.config.port == 9999 + assert service.config.database_name == "custom_db" + assert service.config.timeout == 60.0 + assert service.config.max_connections == 50 + assert service.config.ssl_enabled is True + + +def test_service_subclassing() -> None: + @dataclass + class ExtraConfig(DatabaseConfig): + extra_param: str = "default_value" + + class ExtraDatabaseService(DatabaseService): + default_config = ExtraConfig + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + bla = ExtraDatabaseService(host="custom-host2", extra_param="extra_value") + + assert bla.config.host == "custom-host2" + assert bla.config.extra_param == "extra_value" + assert bla.config.port == 5432 # Default value from DatabaseConfig diff --git a/dimos/protocol/skill/__init__.py b/dimos/protocol/skill/__init__.py new file mode 100644 index 0000000000..15ebf0b59c --- /dev/null +++ b/dimos/protocol/skill/__init__.py @@ -0,0 +1 @@ +from dimos.protocol.skill.skill import SkillContainer, skill diff --git a/dimos/protocol/skill/comms.py b/dimos/protocol/skill/comms.py new file mode 100644 index 0000000000..0720140b79 --- /dev/null +++ b/dimos/protocol/skill/comms.py @@ -0,0 +1,99 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from abc import abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING, Generic, TypeVar + +from dimos.protocol.pubsub.lcmpubsub import PickleLCM +from dimos.protocol.service import Service # type: ignore[attr-defined] +from dimos.protocol.skill.type import SkillMsg + +if TYPE_CHECKING: + from collections.abc import Callable + + from dimos.protocol.pubsub.spec import PubSub + +# defines a protocol for communication between skills and agents +# it has simple requirements of pub/sub semantics capable of sending and receiving SkillMsg objects + + +class SkillCommsSpec: + @abstractmethod + def publish(self, msg: SkillMsg) -> None: ... # type: ignore[type-arg] + + @abstractmethod + def subscribe(self, cb: Callable[[SkillMsg], None]) -> None: ... # type: ignore[type-arg] + + @abstractmethod + def start(self) -> None: ... + + @abstractmethod + def stop(self) -> None: ... + + +MsgT = TypeVar("MsgT") +TopicT = TypeVar("TopicT") + + +@dataclass +class PubSubCommsConfig(Generic[TopicT, MsgT]): + topic: TopicT | None = None + pubsub: type[PubSub[TopicT, MsgT]] | PubSub[TopicT, MsgT] | None = None + autostart: bool = True + + +# implementation of the SkillComms using any standard PubSub mechanism +class PubSubComms(Service[PubSubCommsConfig], SkillCommsSpec): # type: ignore[type-arg] + default_config: type[PubSubCommsConfig] = PubSubCommsConfig # type: ignore[type-arg] + + def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] + super().__init__(**kwargs) + pubsub_config = getattr(self.config, "pubsub", None) + if pubsub_config is not None: + if callable(pubsub_config): + self.pubsub = pubsub_config() + else: + self.pubsub = pubsub_config + else: + raise ValueError("PubSub configuration is missing") + + if getattr(self.config, "autostart", True): + self.start() + + def start(self) -> None: + self.pubsub.start() + + def stop(self) -> None: + self.pubsub.stop() + + def publish(self, msg: SkillMsg) -> None: # type: ignore[type-arg] + self.pubsub.publish(self.config.topic, msg) + + def subscribe(self, cb: Callable[[SkillMsg], None]) -> None: # type: ignore[type-arg] + self.pubsub.subscribe(self.config.topic, lambda msg, topic: cb(msg)) + + +@dataclass +class LCMCommsConfig(PubSubCommsConfig[str, SkillMsg]): # type: ignore[type-arg] + topic: str = "/skill" + pubsub: type[PubSub] | PubSub | None = PickleLCM # type: ignore[type-arg] + # lcm needs to be started only if receiving + # skill comms are broadcast only in modules so we don't autostart + autostart: bool = False + + +class LCMSkillComms(PubSubComms): + default_config: type[LCMCommsConfig] = LCMCommsConfig diff --git a/dimos/protocol/skill/coordinator.py b/dimos/protocol/skill/coordinator.py new file mode 100644 index 0000000000..9ee76ac4d1 --- /dev/null +++ b/dimos/protocol/skill/coordinator.py @@ -0,0 +1,1003 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from copy import copy +from dataclasses import dataclass +from enum import Enum +import json +import threading +import time +from typing import Annotated, Any, Literal + +from annotated_doc import Doc +from langchain_core.messages import ToolMessage +from langchain_core.tools import tool as langchain_tool +from rich.console import Console +from rich.table import Table +from rich.text import Text + +from dimos.core import rpc +from dimos.core.module import Module, get_loop +from dimos.protocol.skill.comms import LCMSkillComms, SkillCommsSpec +from dimos.protocol.skill.skill import SkillConfig, SkillContainer # type: ignore[attr-defined] +from dimos.protocol.skill.type import MsgType, Output, Reducer, Return, SkillMsg, Stream +from dimos.protocol.skill.utils import interpret_tool_call_args +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +@dataclass +class SkillCoordinatorConfig: + """Configuration for the SkillCoordinator module. + + The SkillCoordinator is the central orchestration layer between agents and skills, + managing skill lifecycle, state tracking, and cross-event-loop message routing. This + configuration class controls how skills communicate with the coordinator. + """ + + skill_transport: Annotated[ + type[SkillCommsSpec], + Doc( + """Communication transport implementation for skill messages between skills (thread pools) + and coordinator. Must implement SkillCommsSpec (publish/subscribe semantics). + Defaults to LCMSkillComms (LCM over "/skill" channel). Custom transports can + implement the SkillCommsSpec interface.""" + ), + ] = LCMSkillComms + + +class SkillStateEnum(Enum): + """Lifecycle state of a skill invocation. + + State Transition Flow (unidirectional, message-driven): + pending → running → (completed | error) + + - pending: Scheduled but not yet executing + - running: Actively executing in thread pool + - completed: Finished successfully + - error: Terminated with an exception + + State transitions correspond directly to message types (MsgType.start, stream, ret, error). + Terminal states (completed, error) trigger automatic cleanup when clear=True in + generate_snapshot(), guaranteeing exactly one terminal state per invocation. + """ + + pending = 0 + running = 1 + completed = 2 + error = 3 + + def colored_name( + self, + ) -> Annotated[Text, Doc("The state name as a rich Text object with color styling.")]: + """Return the state name as a rich Text object with color.""" + colors = { + SkillStateEnum.pending: "yellow", + SkillStateEnum.running: "blue", + SkillStateEnum.completed: "green", + SkillStateEnum.error: "red", + } + return Text(self.name, style=colors.get(self, "white")) + + +# This object maintains the state of a skill run on a caller end +class SkillState: + """Tracks execution state of a single skill invocation. + + Manages the skill lifecycle (pending → running → completed/error), accumulates stream + messages via the configured reducer, and encodes state for agent consumption using the + dual-protocol pattern. + + Dual-Protocol Pattern: + - First agent_encode() call: Returns ToolMessage (LangChain protocol compatibility) + - Subsequent calls: Return JSON status updates (skill name, call_id, state, data, duration) + + State Transitions (via handle_msg()): + - MsgType.start: pending → running + - MsgType.stream: maintains running, applies reducer + - MsgType.ret: running → completed + - MsgType.error: any state → error + + Notification Logic (handle_msg() returns True when): + - Stream messages with Stream.call_agent + - Return messages with Return.call_agent + - Error messages (always) + - Start messages never trigger notification + + Example: + >>> from dimos.protocol.skill.type import Stream, Reducer, Output + >>> config = SkillConfig( + ... name="navigate_to", + ... ret=Return.call_agent, + ... stream=Stream.none, + ... reducer=Reducer.all, + ... output=Output.standard, + ... schema={} + ... ) + >>> skill_state = SkillState(call_id="abc123", name="navigate_to", skill_config=config) + >>> start_msg = SkillMsg(call_id="abc123", skill_name="navigate_to", content={}, type=MsgType.start) + >>> skill_state.handle_msg(start_msg) # Transitions to running + False + >>> tool_msg = skill_state.agent_encode() # First call returns ToolMessage + >>> isinstance(tool_msg, ToolMessage) + True + >>> json_update = skill_state.agent_encode() # Subsequent calls return JSON + >>> import json + >>> data = json.loads(json_update) + >>> data['name'] + 'navigate_to' + """ + + call_id: Annotated[ + str, + Doc( + "Unique identifier for this skill invocation, used for message routing and tool result correlation" + ), + ] + name: Annotated[str, Doc("Name of the skill being executed")] + state: Annotated[ + SkillStateEnum, + Doc( + "Current lifecycle state tracking execution progress (pending, running, completed, error)" + ), + ] + skill_config: Annotated[ + SkillConfig, + Doc( + "Configuration controlling skill behavior including streaming mode, return mode, reducer function, and output format" + ), + ] + + msg_count: Annotated[ + int, + Doc( + "Total number of SkillMsg messages received from the skill execution, used for progress tracking" + ), + ] = 0 + sent_tool_msg: Annotated[ + bool, + Doc( + "Flag tracking whether the initial ToolMessage has been sent, ensures correct protocol adherence" + ), + ] = False + + start_msg: Annotated[ + SkillMsg[Literal[MsgType.start]] | None, + Doc("The MsgType.start message marking execution begin, used for duration calculation"), + ] = None + end_msg: Annotated[ + SkillMsg[Literal[MsgType.ret]] | None, + Doc("Terminal message (either ret or error), marks completion timestamp"), + ] = None + error_msg: Annotated[ + SkillMsg[Literal[MsgType.error]] | None, + Doc( + "MsgType.error message if skill terminated with exception, contains error details for agent" + ), + ] = None + ret_msg: Annotated[ + SkillMsg[Literal[MsgType.ret]] | None, + Doc("MsgType.ret message with final return value, only present for successful completion"), + ] = None + reduced_stream_msg: Annotated[ + list[SkillMsg[Literal[MsgType.reduced_stream]]] | None, + Doc( + "Accumulated stream messages after applying the configured reducer function, provides incremental progress for streaming skills" + ), + ] = None + + def __init__( + self, + call_id: Annotated[str, Doc("Unique identifier for this skill invocation")], + name: Annotated[str, Doc("Name of the skill being executed")], + skill_config: Annotated[ + SkillConfig | None, + Doc( + "Optional configuration controlling skill behavior. If None, defaults to no streaming, no return, all reducer, standard output" + ), + ] = None, + ) -> None: + super().__init__() + + self.skill_config = skill_config or SkillConfig( + name=name, + stream=Stream.none, + ret=Return.none, + reducer=Reducer.all, + output=Output.standard, + schema={}, + ) + + self.state = SkillStateEnum.pending + self.call_id = call_id + self.name = name + + def duration( + self, + ) -> Annotated[ + float, + Doc( + """Duration in seconds. Returns elapsed time if completed, + time since start if running, or 0.0 if not started.""" + ), + ]: + """Calculate the duration of the skill run.""" + if self.start_msg and self.end_msg: + return self.end_msg.ts - self.start_msg.ts + elif self.start_msg: + return time.time() - self.start_msg.ts + else: + return 0.0 + + def content( + self, + ) -> Annotated[ + dict[str, Any] | str | int | float | None, + Doc(""" + The content from the skill's execution state. + + Returns the reduced stream message content when running, + the return message content (or reduced stream if streaming) when completed, + or the error message content (with optional stream context) when errored. + Returns None for pending state or when no content is available. + """), + ]: + """Get the content from the current skill execution state.""" + if self.state == SkillStateEnum.running: + if self.reduced_stream_msg: + return self.reduced_stream_msg.content # type: ignore[attr-defined, no-any-return] + + if self.state == SkillStateEnum.completed: + if self.reduced_stream_msg: # are we a streaming skill? + return self.reduced_stream_msg.content # type: ignore[attr-defined, no-any-return] + return self.ret_msg.content # type: ignore[return-value] + + if self.state == SkillStateEnum.error: + print("Error msg:", self.error_msg.content) + if self.reduced_stream_msg: + return self.reduced_stream_msg.content + "\n" + self.error_msg.content # type: ignore[attr-defined] + else: + return self.error_msg.content # type: ignore[return-value] + + return None + + def agent_encode( + self, + ) -> Annotated[ + ToolMessage | str, + Doc( + """ToolMessage on first call, JSON string on subsequent calls. + + This dual-protocol pattern bridges LangChain's tool call requirements + (one ToolMessage per tool_call_id) with the need for ongoing status updates + from long-running skills. + + First call returns a ToolMessage that completes the tool invocation protocol + and enters permanent conversation history. Subsequent calls return JSON-encoded + state snapshots that get aggregated into an AIMessage providing situational + awareness about active skills, without violating the one-ToolMessage constraint. + """ + ), + ]: + """Encode skill state for agent consumption using dual-protocol pattern.""" + if not self.sent_tool_msg: + self.sent_tool_msg = True + return ToolMessage( + self.content() or "Querying, please wait, you will receive a response soon.", # type: ignore[arg-type] + name=self.name, + tool_call_id=self.call_id, + ) + else: + return json.dumps( + { + "name": self.name, + "call_id": self.call_id, + "state": self.state.name, + "data": self.content(), + "ran_for": self.duration(), + } + ) + + # returns True if the agent should be called for this message + def handle_msg( + self, + msg: Annotated[SkillMsg, Doc("The skill message to process")], # type: ignore[type-arg] + ) -> Annotated[ + bool, + Doc( + """Whether the coordinator should notify the agent about this message. + True for errors (always), stream messages with Stream.call_agent, + and return messages with Return.call_agent. False otherwise.""" + ), + ]: + """Process an incoming skill message and update internal state. + + Updates the skill's execution state based on the message type. For stream + messages, applies the configured reducer to accumulate outputs. The return + value determines whether the coordinator should schedule an agent call to + process this message. + + Notification logic: + - Start messages: Never notify (skill is initializing) + - Stream messages: Notify only if configured with Stream.call_agent + - Return messages: Notify only if configured with Return.call_agent + - Error messages: Always notify (errors require agent attention) + """ + self.msg_count += 1 + if msg.type == MsgType.stream: + self.state = SkillStateEnum.running + self.reduced_stream_msg = self.skill_config.reducer(self.reduced_stream_msg, msg) # type: ignore[arg-type, assignment] + + if ( + self.skill_config.stream == Stream.none + or self.skill_config.stream == Stream.passive + ): + return False + + if self.skill_config.stream == Stream.call_agent: + return True + + if msg.type == MsgType.ret: + self.state = SkillStateEnum.completed + self.ret_msg = msg + if self.skill_config.ret == Return.call_agent: + return True + return False + + if msg.type == MsgType.error: + self.state = SkillStateEnum.error + self.error_msg = msg + return True + + if msg.type == MsgType.start: + self.state = SkillStateEnum.running + self.start_msg = msg + return False + + return False + + def __len__(self) -> int: + return self.msg_count + + def __str__(self) -> str: + # For standard string representation, we'll use rich's Console to render the colored text + console = Console(force_terminal=True, legacy_windows=False) + colored_state = self.state.colored_name() + + # Build the parts of the string + parts = [Text(f"SkillState({self.name} "), colored_state, Text(f", call_id={self.call_id}")] + + if self.state == SkillStateEnum.completed or self.state == SkillStateEnum.error: + parts.append(Text(", ran for=")) + else: + parts.append(Text(", running for=")) + + parts.append(Text(f"{self.duration():.2f}s")) + + if len(self): + parts.append(Text(f", msg_count={self.msg_count})")) + else: + parts.append(Text(", No Messages)")) + + # Combine all parts into a single Text object + combined = Text() + for part in parts: + combined.append(part) + + # Render to string with console + with console.capture() as capture: + console.print(combined, end="") + return capture.get() + + +# subclassed the dict just to have a better string representation +class SkillStateDict(dict[str, SkillState]): + """Dictionary mapping call_id to SkillState with Rich-formatted table display. + + Provides table() and __str__() methods for debugging and monitoring skill execution + in SkillCoordinator. + Table columns: Call ID, Skill, State (colored), Duration, Messages. + """ + + def table(self) -> Annotated[Table, Doc("Rich Table with formatted skill state columns")]: + # Add skill states section + states_table = Table(show_header=True) + states_table.add_column("Call ID", style="dim", width=12) + states_table.add_column("Skill", style="white") + states_table.add_column("State", style="white") + states_table.add_column("Duration", style="yellow") + states_table.add_column("Messages", style="dim") + + for call_id, skill_state in self.items(): + # Get colored state name + state_text = skill_state.state.colored_name() + + # Duration formatting + if ( + skill_state.state == SkillStateEnum.completed + or skill_state.state == SkillStateEnum.error + ): + duration = f"{skill_state.duration():.2f}s" + else: + duration = f"{skill_state.duration():.2f}s..." + + # Messages info + msg_count = str(len(skill_state)) + + states_table.add_row( + call_id[:8] + "...", skill_state.name, state_text, duration, msg_count + ) + + if not self: + states_table.add_row("", "[dim]No active skills[/dim]", "", "", "") + return states_table + + def __str__(self) -> str: + console = Console(force_terminal=True, legacy_windows=False) + + # Render to string with title above + with console.capture() as capture: + console.print(Text(" SkillState", style="bold blue")) + console.print(self.table()) + return capture.get().strip() + + +# This class is responsible for managing the lifecycle of skills, +# handling skill calls, and coordinating communication between the agent and skills. +# +# It aggregates skills from static and dynamic containers, manages skill states, +# and decides when to notify the agent about updates. +class SkillCoordinator(Module): + """Central orchestration layer between agents and skills. + + Manages skill lifecycle, state tracking, and message routing across event loops, + decoupling agents (asyncio) from skills (thread pools) using lazy event creation + and thread-safe cross-loop notification. + + Container Types: + - Static: Fixed skills cached at registration for O(1) lookup + - Dynamic: Runtime-generated skills queried on-demand for context-dependent generation + + Cross-Event-Loop Synchronization: + - asyncio.Event created lazily in agent's loop on first wait_for_updates() + - call_soon_threadsafe bridges transport loop and agent loop + - Message-driven state tracking via SkillState objects + + Examples: + Basic coordinator setup and skill invocation: + + >>> from dimos.core.module import Module + >>> from dimos.protocol.skill.skill import skill + >>> + >>> # Note that you'll need to do a bit more for the skill to be available to llm agents -- see the tutorial. + >>> class NavigationModule(Module): + ... @skill() + ... def navigate_to(self, location: str) -> str: + ... return f"Navigating to {location}" + >>> + >>> # Set up coordinator + >>> coordinator = SkillCoordinator() + >>> coordinator.register_skills(NavigationModule()) + >>> coordinator.start() + >>> coordinator.call_skill(call_id="123", skill_name="navigate_to", args={"args": ["kitchen"]}) + >>> + >>> # Verify skill state was created + >>> snapshot = coordinator.generate_snapshot(clear=False) + >>> "123" in snapshot + True + >>> coordinator.stop() + + Agent integration with update loop (async): + + >>> import asyncio + >>> # (In actual async context) + >>> # await coordinator.wait_for_updates(timeout=1.0) + >>> # snapshot = coordinator.generate_snapshot(clear=True) + >>> # for call_id, state in snapshot.items(): + >>> # message = state.agent_encode() # First: ToolMessage, then: JSON + + Notes: + - Not thread-safe for _skill_state (single coordinator loop assumed) + - generate_snapshot(clear=True) provides atomic read-and-clear, removing terminal states + - Completed/errored skills removed after snapshot(clear=True) + - Message flow pattern: Skills publish messages in a fixed sequence: + 1. One `start` message when execution begins + 2. Zero or more `stream` messages during execution (for incremental progress) + 3. Exactly one terminal message: either `ret` (success) or `error` (failure) + """ + + default_config = SkillCoordinatorConfig # type: ignore[assignment] + empty: bool = True + + _static_containers: Annotated[ + list[SkillContainer], + Doc( + "Containers with fixed skills known at registration time. Skills are cached immediately for performance." + ), + ] + _dynamic_containers: Annotated[ + list[SkillContainer], + Doc( + "Containers whose skills depend on runtime context. Queried on each skills() call; not cached." + ), + ] + _skill_state: Annotated[ + SkillStateDict, + Doc( + "Maps call_id to SkillState objects tracking each skill invocation. Key is call_id (unique per invocation), not skill_name (reusable)." + ), + ] + _skills: Annotated[ + dict[str, SkillConfig], Doc("Cached static skills for O(1) lookup performance.") + ] + _updates_available: Annotated[ + asyncio.Event | None, + Doc( + "Event signaling skill updates ready for agent processing. Created lazily in agent's event loop on first wait_for_updates() call." + ), + ] + _loop: Annotated[ + asyncio.AbstractEventLoop | None, Doc("Coordinator's own event loop for message handling.") + ] + _loop_thread: threading.Thread | None + _agent_loop: Annotated[ + asyncio.AbstractEventLoop | None, + Doc("Agent's event loop, captured when updates_available event is created."), + ] + + def __init__(self) -> None: + # TODO: Why isn't this super().__init__() ? + SkillContainer.__init__(self) + self._loop, self._loop_thread = get_loop() + self._static_containers = [] + self._dynamic_containers = [] + self._skills = {} + self._skill_state = SkillStateDict() + # Defer event creation until we're in the correct loop context + self._updates_available = None + self._agent_loop = None + self._pending_notifications = 0 # Count pending notifications + self._closed_coord = False + self._transport_unsub_fn = None + + def _ensure_updates_available(self) -> asyncio.Event: + """Lazily create the updates available event in the correct loop context.""" + if self._updates_available is None: + # Create the event in the current running loop, not the stored loop + try: + loop = asyncio.get_running_loop() + # print(f"[DEBUG] Creating _updates_available event in current loop {id(loop)}") + # Always use the current running loop for the event + # This ensures the event is created in the context where it will be used + self._updates_available = asyncio.Event() + # Store the loop where the event was created - this is the agent's loop + self._agent_loop = loop + # print( + # f"[DEBUG] Created _updates_available event {id(self._updates_available)} in agent loop {id(loop)}" + # ) + except RuntimeError: + # No running loop, defer event creation until we have the proper context + # print(f"[DEBUG] No running loop, deferring event creation") + # Don't create the event yet - wait for the proper loop context + pass + else: + ... + # print(f"[DEBUG] Reusing _updates_available event {id(self._updates_available)}") + return self._updates_available # type: ignore[return-value] + + @rpc + def start(self) -> None: + super().start() + self.skill_transport.start() + self._transport_unsub_fn = self.skill_transport.subscribe(self.handle_message) + + @rpc + def stop(self) -> None: + self._close_module() + self._closed_coord = True + self.skill_transport.stop() + if self._transport_unsub_fn: + self._transport_unsub_fn() + + # Stop all registered skill containers + for container in self._static_containers: + container.stop() + for container in self._dynamic_containers: + container.stop() + + super().stop() + + def len(self) -> int: + return len(self._skills) + + def __len__(self) -> int: + return self.len() + + # this can be converted to non-langchain json schema output + # and langchain takes this output as well + # just faster for now + def get_tools(self) -> list[dict]: # type: ignore[type-arg] + return [ + langchain_tool(skill_config.f) # type: ignore[arg-type, misc] + for skill_config in self.skills().values() + if not skill_config.hide_skill + ] + + # internal skill call + def call_skill( + self, + call_id: Annotated[ + str | Literal[False], + Doc("""Unique identifier for this skill invocation. If False, a + timestamp-based ID will be auto-generated. This ID is used to + track skill execution state and correlate responses."""), + ], + skill_name: Annotated[ + str, + Doc("""Name of the skill to invoke, as registered in the + coordinator's skill registry."""), + ], + args: Annotated[ + dict[str, Any], + Doc("""Dictionary containing skill invocation arguments. Expected to + contain an "args" key with either a list of positional arguments + or a dict of keyword arguments. Will be interpreted by + `interpret_tool_call_args` to extract positional and keyword args."""), + ], + ) -> None: + """Execute a skill invocation requested by an agent. + + Creates a SkillState to track execution and delegates to the skill's call method. + Auto-generates call_id from timestamp if not provided. Logs error and returns + early if skill not found (e.g., expired dynamic skill). + """ + if not call_id: + call_id = str(time.time()) + skill_config = self.get_skill_config(skill_name) + if not skill_config: + logger.error( + f"Skill {skill_name} not found in registered skills, but agent tried to call it (did a dynamic skill expire?)" + ) + return + + self._skill_state[call_id] = SkillState( + call_id=call_id, name=skill_name, skill_config=skill_config + ) + + # TODO agent often calls the skill again if previous response is still loading. + # maybe create a new skill_state linked to a previous one? not sure + + arg_keywords = args.get("args") or {} + arg_list = [] + + if isinstance(arg_keywords, list): + arg_list = arg_keywords + arg_keywords = {} + + arg_list, arg_keywords = interpret_tool_call_args(args) + + return skill_config.call( # type: ignore[no-any-return] + call_id, + *arg_list, + **arg_keywords, + ) + + # Receives a message from active skill + # Updates local skill state (appends to streamed data if needed etc) + # + # Checks if agent needs to be notified (if ToolConfig has Return=call_agent or Stream=call_agent) + def handle_message( + self, + msg: Annotated[ + SkillMsg, # type: ignore[type-arg] + Doc( + """The incoming skill message containing status updates, output, or errors. + Must contain a valid call_id and skill_name.""" + ), + ], + ) -> None: + """Process incoming skill messages and notify the agent when needed. + + Routes messages to the appropriate SkillState. If notification is required + (based on skill config), sets the agent's updates_available event using + call_soon_threadsafe for cross-loop communication. + + Handles orphan messages (no SkillState) by lazy initialization with warning. + Post-shutdown messages are silently dropped. + """ + if self._closed_coord: + import traceback + + traceback.print_stack() + return + # logger.info(f"SkillMsg from {msg.skill_name}, {msg.call_id} - {msg}") + + if self._skill_state.get(msg.call_id) is None: + logger.warn( + f"Skill state for {msg.skill_name} (call_id={msg.call_id}) not found, (skill not called by our agent?) initializing. (message received: {msg})" + ) + self._skill_state[msg.call_id] = SkillState(call_id=msg.call_id, name=msg.skill_name) + + should_notify = self._skill_state[msg.call_id].handle_msg(msg) + + if should_notify: + updates_available = self._ensure_updates_available() + if updates_available is None: + print("[DEBUG] Event not created yet, deferring notification") + return + + try: + current_loop = asyncio.get_running_loop() + agent_loop = getattr(self, "_agent_loop", self._loop) + # print( + # f"[DEBUG] handle_message: current_loop={id(current_loop)}, agent_loop={id(agent_loop) if agent_loop else 'None'}, event={id(updates_available)}" + # ) + if agent_loop and agent_loop != current_loop: + # print( + # f"[DEBUG] Calling set() via call_soon_threadsafe from loop {id(current_loop)} to agent loop {id(agent_loop)}" + # ) + agent_loop.call_soon_threadsafe(updates_available.set) + else: + # print(f"[DEBUG] Calling set() directly in current loop {id(current_loop)}") + updates_available.set() + except RuntimeError: + # No running loop, use call_soon_threadsafe if we have an agent loop + agent_loop = getattr(self, "_agent_loop", self._loop) + # print( + # f"[DEBUG] No current running loop, agent_loop={id(agent_loop) if agent_loop else 'None'}" + # ) + if agent_loop: + # print( + # f"[DEBUG] Calling set() via call_soon_threadsafe to agent loop {id(agent_loop)}" + # ) + agent_loop.call_soon_threadsafe(updates_available.set) + else: + # print(f"[DEBUG] Event creation was deferred, can't notify") + pass + + def has_active_skills(self) -> bool: + if not self.has_passive_skills(): + return False + for skill_run in self._skill_state.values(): + # check if this skill will notify agent + if skill_run.skill_config.ret == Return.call_agent: + return True + if skill_run.skill_config.stream == Stream.call_agent: + return True + return False + + def has_passive_skills(self) -> bool: + # check if dict is empty + if self._skill_state == {}: + return False + return True + + async def wait_for_updates( + self, + timeout: Annotated[float | None, Doc("Optional timeout in seconds")] = None, + ) -> Annotated[bool, Doc("True if updates are available, False on timeout")]: + """Wait for skill updates to become available. + + This method should be called by the agent when it's ready to receive updates. + It will block until updates are available or timeout is reached. + """ + updates_available = self._ensure_updates_available() + if updates_available is None: + # Force event creation now that we're in the agent's loop context + # print(f"[DEBUG] wait_for_updates: Creating event in current loop context") + current_loop = asyncio.get_running_loop() + self._updates_available = asyncio.Event() + self._agent_loop = current_loop + updates_available = self._updates_available + # print( + # f"[DEBUG] wait_for_updates: Created event {id(updates_available)} in loop {id(current_loop)}" + # ) + + try: + current_loop = asyncio.get_running_loop() + + # Double-check the loop context before waiting + if self._agent_loop != current_loop: + # print(f"[DEBUG] Loop context changed! Recreating event for loop {id(current_loop)}") + self._updates_available = asyncio.Event() + self._agent_loop = current_loop + updates_available = self._updates_available + + # print( + # f"[DEBUG] wait_for_updates: current_loop={id(current_loop)}, event={id(updates_available)}, is_set={updates_available.is_set()}" + # ) + if timeout: + # print(f"[DEBUG] Waiting for event with timeout {timeout}") + await asyncio.wait_for(updates_available.wait(), timeout=timeout) + else: + print("[DEBUG] Waiting for event without timeout") + await updates_available.wait() + print("[DEBUG] Event was set! Returning True") + return True + except asyncio.TimeoutError: + print("[DEBUG] Timeout occurred while waiting for event") + return False + except RuntimeError as e: + if "bound to a different event loop" in str(e): + print( + "[DEBUG] Event loop binding error detected, recreating event and returning False to retry" + ) + # Recreate the event in the current loop + current_loop = asyncio.get_running_loop() + self._updates_available = asyncio.Event() + self._agent_loop = current_loop + return False + else: + raise + + def generate_snapshot( + self, + clear: Annotated[ + bool, + Doc( + """Whether to perform cleanup after snapshot generation. If True, + removes completed/errored skills from tracking, resets stream accumulators + for running skills, and clears the updates_available event. If False, + returns a simple copy without side effects.""" + ), + ] = True, + ) -> Annotated[ + SkillStateDict, + Doc( + """Dictionary mapping call_id to SkillState objects. Each SkillState contains + the skill's execution state, accumulated outputs, and error information. + The returned dict is a copy independent of internal state.""" + ), + ]: + """Generate an atomic snapshot of skill states with optional cleanup. + + Returns a point-in-time copy of all tracked skill invocations. When clear=True, + performs atomic read-and-clear: removes terminal states (completed/error), resets + stream accumulators for running skills, and clears the updates_available event. + """ + ret = copy(self._skill_state) + + if clear: + updates_available = self._ensure_updates_available() + if updates_available is not None: + # print(f"[DEBUG] generate_snapshot: clearing event {id(updates_available)}") + updates_available.clear() + else: + ... + # rint(f"[DEBUG] generate_snapshot: event not created yet, nothing to clear") + to_delete = [] + # Since snapshot is being sent to agent, we can clear the finished skill runs + for call_id, skill_run in self._skill_state.items(): + if skill_run.state == SkillStateEnum.completed: + logger.info(f"Skill {skill_run.name} (call_id={call_id}) finished") + to_delete.append(call_id) + if skill_run.state == SkillStateEnum.error: + error_msg = skill_run.error_msg.content.get("msg", "Unknown error") # type: ignore[union-attr] + error_traceback = skill_run.error_msg.content.get( # type: ignore[union-attr] + "traceback", "No traceback available" + ) + + logger.error( + f"Skill error for {skill_run.name} (call_id={call_id}): {error_msg}" + ) + print(error_traceback) + to_delete.append(call_id) + + elif ( + skill_run.state == SkillStateEnum.running + and skill_run.reduced_stream_msg is not None + ): + # preserve ret as a copy + ret[call_id] = copy(skill_run) + logger.debug( + f"Resetting accumulator for skill {skill_run.name} (call_id={call_id})" + ) + skill_run.reduced_stream_msg = None # type: ignore[assignment] + + for call_id in to_delete: + logger.debug(f"Call {call_id} finished, removing from state") + del self._skill_state[call_id] + + return ret + + def __str__(self) -> str: + console = Console(force_terminal=True, legacy_windows=False) + + # Create main table without any header + table = Table(show_header=False) + + # Add containers section + containers_table = Table(show_header=True, show_edge=False, box=None) + containers_table.add_column("Type", style="cyan") + containers_table.add_column("Container", style="white") + + # Add static containers + for container in self._static_containers: + containers_table.add_row("Static", str(container)) + + # Add dynamic containers + for container in self._dynamic_containers: + containers_table.add_row("Dynamic", str(container)) + + if not self._static_containers and not self._dynamic_containers: + containers_table.add_row("", "[dim]No containers registered[/dim]") + + # Add skill states section + states_table = self._skill_state.table() + states_table.show_edge = False + states_table.box = None + + # Combine into main table + table.add_column("Section", style="bold") + table.add_column("Details", style="none") + table.add_row("Containers", containers_table) + table.add_row("Skills", states_table) + + # Render to string with title above + with console.capture() as capture: + console.print(Text(" SkillCoordinator", style="bold blue")) + console.print(table) + return capture.get().strip() + + # Given skillcontainers can run remotely, we are + # Caching available skills from static containers + # + # Dynamic containers will be queried at runtime via + # .skills() method + def register_skills( + self, + container: Annotated[ + SkillContainer, + Doc( + """The skill container to register. Must implement the SkillContainer + protocol with a dynamic_skills() method and a skills() method that returns + a mapping of skill names to SkillConfig objects.""" + ), + ], + ) -> None: + """Register a skill container with the coordinator, making its skills available to agents. + + Static containers (dynamic_skills() == False): Skills cached immediately for O(1) lookup. + Dynamic containers (dynamic_skills() == True): Skills queried at runtime for context-dependent generation. + + Skill resolution order: cached static skills first, then dynamic container query. + """ + self.empty = False + if not container.dynamic_skills(): + logger.info(f"Registering static skill container, {container}") + self._static_containers.append(container) + for name, skill_config in container.skills().items(): + self._skills[name] = skill_config.bind(getattr(container, name)) + else: + logger.info(f"Registering dynamic skill container, {container}") + self._dynamic_containers.append(container) + + def get_skill_config(self, skill_name: str) -> SkillConfig | None: + skill_config = self._skills.get(skill_name) + if not skill_config: + skill_config = self.skills().get(skill_name) + return skill_config + + def skills(self) -> dict[str, SkillConfig]: + # Static container skilling is already cached + all_skills: dict[str, SkillConfig] = {**self._skills} + + # Then aggregate skills from dynamic containers + for container in self._dynamic_containers: + for skill_name, skill_config in container.skills().items(): + all_skills[skill_name] = skill_config.bind(getattr(container, skill_name)) + + return all_skills diff --git a/dimos/protocol/skill/schema.py b/dimos/protocol/skill/schema.py new file mode 100644 index 0000000000..3b265f9c1b --- /dev/null +++ b/dimos/protocol/skill/schema.py @@ -0,0 +1,103 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Union, get_args, get_origin + + +def python_type_to_json_schema(python_type) -> dict: # type: ignore[no-untyped-def, type-arg] + """Convert Python type annotations to JSON Schema format.""" + # Handle None/NoneType + if python_type is type(None) or python_type is None: + return {"type": "null"} + + # Handle Union types (including Optional) + origin = get_origin(python_type) + if origin is Union: + args = get_args(python_type) + # Handle Optional[T] which is Union[T, None] + if len(args) == 2 and type(None) in args: + non_none_type = args[0] if args[1] is type(None) else args[1] + schema = python_type_to_json_schema(non_none_type) + # For OpenAI function calling, we don't use anyOf for optional params + return schema + else: + # For other Union types, use anyOf + return {"anyOf": [python_type_to_json_schema(arg) for arg in args]} + + # Handle List/list types + if origin in (list, list): + args = get_args(python_type) + if args: + return {"type": "array", "items": python_type_to_json_schema(args[0])} + return {"type": "array"} + + # Handle Dict/dict types + if origin in (dict, dict): + return {"type": "object"} + + # Handle basic types + type_map = { + str: {"type": "string"}, + int: {"type": "integer"}, + float: {"type": "number"}, + bool: {"type": "boolean"}, + list: {"type": "array"}, + dict: {"type": "object"}, + } + + return type_map.get(python_type, {"type": "string"}) + + +def function_to_schema(func) -> dict: # type: ignore[no-untyped-def, type-arg] + """Convert a function to OpenAI function schema format.""" + try: + signature = inspect.signature(func) + except ValueError as e: + raise ValueError(f"Failed to get signature for function {func.__name__}: {e!s}") + + properties = {} + required = [] + + for param_name, param in signature.parameters.items(): + # Skip 'self' parameter for methods + if param_name == "self": + continue + + # Get the type annotation + if param.annotation != inspect.Parameter.empty: + param_schema = python_type_to_json_schema(param.annotation) + else: + # Default to string if no type annotation + param_schema = {"type": "string"} + + # Add description from docstring if available (would need more sophisticated parsing) + properties[param_name] = param_schema + + # Add to required list if no default value + if param.default == inspect.Parameter.empty: + required.append(param_name) + + return { + "type": "function", + "function": { + "name": func.__name__, + "description": (func.__doc__ or "").strip(), + "parameters": { + "type": "object", + "properties": properties, + "required": required, + }, + }, + } diff --git a/dimos/protocol/skill/skill.py b/dimos/protocol/skill/skill.py new file mode 100644 index 0000000000..859d579f78 --- /dev/null +++ b/dimos/protocol/skill/skill.py @@ -0,0 +1,667 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Decorator and runtime for exposing robot capabilities as LLM tool calls. + +This module provides the core abstractions for defining and executing *skills*: wrappers around +robot capabilities that allow LLMs to invoke them as tool calls. +Skills transform high-level intentions (e.g., "navigate to the kitchen") into concrete actions. + +Core components +--------------- +`@skill` decorator + Transforms `Module` methods into agent-callable tools with configurable execution + semantics (streaming, passive/active behavior, thread pooling). Auto-generates JSON + schemas for LLM tool calling. + +`SkillContainer` class + Base infrastructure inherited by all Modules, providing skill execution, threading, + and transport. Available automatically via `Module` inheritance. + +`SkillContainerConfig` dataclass + Configuration for skill transport layer (defaults to LCM-based communication). + +Architecture +------------ +Skills execute in a distributed, asynchronous environment where every `Module` inherits +from `SkillContainer`. The `@skill` decorator wraps methods with execution routing and +schema generation. + +See also +-------- +`dimos.core.module.Module` : Base class that inherits `SkillContainer` +`dimos.protocol.skill.coordinator.SkillCoordinator` : Manages skill execution state for agents +`dimos.protocol.skill.type` : Enums and types for skill configuration +`dimos.agents2.agent` : LLM agents that discover and invoke skills + +Related docs +------------ +- Build your first skill tutorial: `docs/tutorials/skill_basics/tutorial.md` +- Wire a skill to an agent tutorial: `docs/tutorials/skill_with_agent/tutorial.md` +- Explainer on the Skill concept: `docs/concepts/skills.md` + +Examples +-------- +Basic skill returning a result: + +>>> from dimos.core.module import Module +>>> from dimos.protocol.skill.skill import skill +>>> +>>> class RobotSkills(Module): +... @skill() +... def greet(self, name: str) -> str: +... '''Greet someone by name.''' +... return f"Hello, {name}!" + +Notes +----- +**Thread pool execution:** +Skills execute in a `ThreadPoolExecutor` when called via coordinator (with `call_id`), +preventing blocking during long-running operations. + +**Passive skill requirements:** +Skills with `stream=Stream.passive` never wake the agent except on errors. Their data +is only delivered when an active skill keeps the agent loop running; by design, +passive skills are for auxiliary data like telemetry during primary tasks. +""" + +import asyncio +from collections.abc import Callable +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from typing import Annotated, Any + +from annotated_doc import Doc + +# from dimos.core.core import rpc +from dimos.protocol.skill.comms import LCMSkillComms, SkillCommsSpec +from dimos.protocol.skill.schema import function_to_schema +from dimos.protocol.skill.type import ( + MsgType, + Output, + Reducer, + Return, + SkillConfig, + SkillMsg, + Stream, +) + +# skill is a decorator that allows us to specify a skill behaviour for a function. +# +# there are several parameters that can be specified: +# - ret: how to return the value from the skill, can be one of: +# +# Return.none: doesn't return anything to an agent +# Return.passive: doesn't schedule an agent call but +# returns the value to the agent when agent is called +# Return.call_agent: calls the agent with the value, scheduling an agent call +# +# - stream: if the skill streams values, it can behave in several ways: +# +# Stream.none: no streaming, skill doesn't emit any values +# Stream.passive: doesn't schedule an agent call upon emitting a value, +# returns the streamed value to the agent when agent is called +# Stream.call_agent: calls the agent with every value emitted, scheduling an agent call +# +# - reducer: defines an optional strategy for passive streams and how we collapse potential +# multiple values into something meaningful for the agent +# +# Reducer.none: no reduction, every emitted value is returned to the agent +# Reducer.latest: only the latest value is returned to the agent +# Reducer.average: assumes the skill emits a number, +# the average of all values is returned to the agent + + +def rpc(fn: Callable[..., Any]) -> Callable[..., Any]: + fn.__rpc__ = True # type: ignore[attr-defined] + return fn + + +def skill( + reducer: Annotated[ + Reducer, + Doc( + """Aggregation strategy for streaming skills when multiple values are emitted. + + Applies to both `stream=Stream.passive` and `stream=Stream.call_agent`. + Has no effect when `stream=Stream.none`. + """ + ), + ] = Reducer.latest, + stream: Annotated[ + Stream, + Doc( + """ + Controls how generator/iterator returns are handled. + + Use `Stream.none` for non-streaming skills, `Stream.passive` for streaming without + triggering agent calls (values aggregated by reducer), or `Stream.call_agent` to + trigger an agent call for each yielded value. + """ + ), + ] = Stream.none, + ret: Annotated[ + Return, + Doc( + """ + Controls how the final return value is delivered to the agent. + + Use `Return.none` to suppress return value, `Return.passive` to make value available + when agent queries, or `Return.call_agent` to actively schedule an agent call with + the result. + + Note: forced to `Return.passive` when `stream=Stream.passive` to maintain + consistent passive behavior. + """ + ), + ] = Return.call_agent, + output: Annotated[ + Output, + Doc( + """Presentation hint for how the agent should interpret the output. + + Use `Output.standard` for normal text, `Output.human` for human-readable formatted + output, or `Output.image` for visual content. + """ + ), + ] = Output.standard, + hide_skill: Annotated[ + bool, + Doc( + """If True, prevents the skill from appearing in the agent's available skills list. + + Hidden skills can still be called programmatically but won't be offered to LLMs during + tool selection. Useful for internal or administrative skills. + """ + ), + ] = False, +) -> Callable: + """Decorator that transforms `Module` methods into agent-callable skills. + + The `@skill` decorator is what allows methods on a `Module` to be invoked as tool calls by LLM agents. + It does this by wrapping methods with execution routing, message protocol handling, + and automatic schema generation. + + When an agent calls a skill through the SkillCoordinator, the skill executes in + a background thread pool and publishes messages tracking its execution state + (start → [stream]* → ret/error). This enables non-blocking execution, progress + monitoring, and distributed deployment across machines. + + Examples: + + >>> from dimos.core.module import Module + >>> from dimos.protocol.skill.type import Stream, Reducer, Return + + Basic skill returning a string result: + + >>> class NavigationSkills(Module): + ... def __init__(self): + ... super().__init__() + ... self.goal = None + ... + ... def _set_goal(self, location: str) -> None: + ... self.goal = location + ... + ... @skill() + ... def navigate_to(self, location: str) -> str: + ... '''Navigate to a named location.''' + ... self._set_goal(location) + ... return f"Navigating to {location}" + + Streaming skill with progress updates: + + >>> class MonitorSkills(Module): + ... @skill(stream=Stream.call_agent, ret=Return.call_agent) + ... def monitor_task(self, count: int): + ... '''Monitor a long-running operation.''' + ... for i in range(count): + ... yield f"Progress: {i+1}/{count}" + ... yield "Task completed" + + Passive skill with reducer aggregation: + + >>> class RobotSkills(Module): + ... def _get_frames(self): + ... for i in range(5): + ... yield f"frame_{i}" + ... + ... @skill(stream=Stream.passive, reducer=Reducer.latest) + ... def stream_camera(self): + ... '''Stream camera frames in background.''' + ... for frame in self._get_frames(): + ... yield frame + ... yield "Camera stopped" + ... + ... @skill(ret=Return.call_agent) # Active companion keeps loop alive + ... def navigate_to(self, location: str) -> str: + ... '''Navigate while camera streams.''' + ... return f"Arrived at {location}" + + Hidden administrative skill: + + >>> class SystemSkills(Module): + ... def _calibrate(self) -> None: + ... pass # Internal calibration logic + ... + ... @skill(hide_skill=True) + ... def internal_calibration(self) -> str: + ... '''Internal calibration routine.''' + ... self._calibrate() + ... return "Calibration complete" + + See also the tutorials and other examples of skills in the library. + + Notes: + + **Key Contracts:** + + - Return strings for LLM compatibility (non-strings with `agent_encode()` + method are auto-encoded) + - Methods must be on subclasses of Module + - Parameters must be JSON-serializable for schema generation + + **Passive Skill Warning:** When using `stream=Stream.passive`: + + - If only passive skills are running, the loop exits and data from passive skills is lost + See `Stream.passive` docstring for full semantics. + + **Generator skills:** Use `yield` (not `return`) for your final message. + Only the last `yield` becomes `MsgType.ret`. + + **Best Practices:** + + - Write clear docstrings - they become the skill descriptions LLMs see + - Return meaningful strings that help agents understand outcomes + - Handle errors gracefully with contextual messages for agent recovery + """ + + def decorator(f: Callable[..., Any]) -> Any: + def wrapper(self, *args, **kwargs): # type: ignore[no-untyped-def] + skill = f"{f.__name__}" + + call_id = kwargs.get("call_id", None) + if call_id: + del kwargs["call_id"] + + return self.call_skill(call_id, skill, args, kwargs) + # def run_function(): + # return self.call_skill(call_id, skill, args, kwargs) + # + # thread = threading.Thread(target=run_function) + # thread.start() + # return None + + return f(self, *args, **kwargs) + + # sig = inspect.signature(f) + # params = list(sig.parameters.values()) + # if params and params[0].name == "self": + # params = params[1:] # Remove first parameter 'self' + # wrapper.__signature__ = sig.replace(parameters=params) + + skill_config = SkillConfig( + name=f.__name__, + reducer=reducer, # type: ignore[arg-type] + stream=stream, + # if stream is passive, ret must be passive too + ret=ret.passive if stream == Stream.passive else ret, + output=output, + schema=function_to_schema(f), + hide_skill=hide_skill, + ) + + wrapper.__rpc__ = True # type: ignore[attr-defined] + wrapper._skill_config = skill_config # type: ignore[attr-defined] + wrapper.__name__ = f.__name__ # Preserve original function name + wrapper.__doc__ = f.__doc__ # Preserve original docstring + return wrapper + + return decorator + + +@dataclass +class SkillContainerConfig: + skill_transport: type[SkillCommsSpec] = LCMSkillComms + + +def threaded(f: Callable[..., Any]) -> Callable[..., None]: + """Decorator to run a function in a thread pool.""" + + def wrapper(self, *args, **kwargs): # type: ignore[no-untyped-def] + if self._skill_thread_pool is None: + self._skill_thread_pool = ThreadPoolExecutor( + max_workers=50, thread_name_prefix="skill_worker" + ) + self._skill_thread_pool.submit(f, self, *args, **kwargs) + return None + + return wrapper + + +# Inherited by any class that wants to provide skills +# (This component works standalone but commonly used by DimOS modules) +# +# Hosts the function execution and handles correct publishing of skill messages +# according to the individual skill decorator configuration +# +# - It allows us to specify a communication layer for skills (LCM for now by default) +# - introspection of available skills via the `skills` RPC method +# - ability to provide dynamic context dependant skills with dynamic_skills flag +# for this you'll need to override the `skills` method to return a dynamic set of skills +# SkillCoordinator will call this method to get the skills available upon every request to +# the agent + + +class SkillContainer: + """Infrastructure for hosting and executing agent-callable skills. + + SkillContainer provides the foundational protocol layer inherited by all DimOS `Module`s, + enabling any `Module` to expose `@skill` decorated methods as LLM tool calls. This class + handles skill discovery, threaded execution, message publishing, and transport abstraction + for distributed skill communication. + + Key capabilities: + - **Skill introspection**: Discovers all `@skill` decorated methods via `skills()` RPC + - **Threaded execution**: Runs skills in background thread pool (max 50 workers) + - **Message protocol**: Publishes lifecycle events for streaming and error handling + - **Transport abstraction**: Configurable communication layer (default: LCM-based) + - **Lazy initialization**: Transport and thread pool created on first use + + Users typically don't interact with SkillContainer directly—it provides infrastructure + that makes the `@skill` decorator work seamlessly across distributed deployments. + + See also: + `dimos.core.module.Module` : Base class that inherits SkillContainer capabilities + `dimos.protocol.skill.coordinator.SkillCoordinator` : Orchestrates skill execution for agents + `@skill` decorator : Transforms Module methods into agent-callable tools (see for message protocol details) + `dimos.protocol.skill.type.Stream` : Stream behavior configuration (passive vs. active) + `dimos.protocol.skill.type.Return` : Return value handling modes + """ + + skill_transport_class: type[SkillCommsSpec] = LCMSkillComms + _skill_thread_pool: ThreadPoolExecutor | None = None + _skill_transport: SkillCommsSpec | None = None + + @rpc + def dynamic_skills(self) -> bool: + """Indicate whether this container generates skills dynamically at runtime. + + When False (the default), skills are cached at registration for performance. + Override to return True when skills depend on runtime context (e.g., attached + hardware, environment state)—this causes skills to be queried on each request. + + Note: If the skills depend only on constructor parameters (configuration at init time), + static skills work just fine. + + Examples: + Static skills (default behavior): + + >>> from dimos.core.module import Module + >>> from dimos.protocol.skill.skill import skill, rpc + >>> from dimos.protocol.skill.type import SkillConfig + >>> + >>> class StaticSkills(Module): + ... @skill() + ... def fixed_skill(self) -> str: + ... return "Always available" + ... # dynamic_skills() not overridden, returns False + + Dynamic skills based on runtime state: + + >>> class DynamicSkills(Module): + ... def __init__(self): + ... super().__init__() + ... self.gripper_attached = False + ... + ... @rpc + ... def dynamic_skills(self) -> bool: + ... return True # Skills change based on hardware state + ... + ... def skills(self) -> dict[str, SkillConfig]: + ... available = super().skills() + ... if not self.gripper_attached: + ... # Remove gripper-dependent skills + ... available.pop("pick_object", None) + ... return available + """ + return False + + def __str__(self) -> str: + return f"SkillContainer({self.__class__.__name__})" + + @rpc + def stop(self) -> None: + """Release skill execution resources and propagate cleanup to parent classes.""" + if self._skill_transport: + self._skill_transport.stop() + self._skill_transport = None + + if self._skill_thread_pool: + self._skill_thread_pool.shutdown(wait=True) + self._skill_thread_pool = None + + # Continue the MRO chain if there's a parent stop() method + if hasattr(super(), "stop"): + super().stop() # type: ignore[misc] + + # TODO: figure out standard args/kwargs passing format, + # use same interface as skill coordinator call_skill method + @threaded + def call_skill( + self, + call_id: Annotated[ + str, Doc("Unique identifier for this skill invocation, used for message correlation.") + ], + skill_name: Annotated[ + str, Doc("Name of the skill method to invoke (must match a `@skill` decorated method).") + ], + args: Annotated[tuple[Any, ...], Doc("Positional arguments to pass to the skill method.")], + kwargs: Annotated[dict[str, Any], Doc("Keyword arguments to pass to the skill method.")], + ) -> None: + """Execute a skill in the thread pool and publish lifecycle messages. + + Core execution method invoked by the `@skill` decorator when a skill is called with + a `call_id` parameter. Executes the skill in a background thread pool and publishes + status messages according to the skill's configuration. + + Message protocol: + 1. Publish `MsgType.start` immediately upon entry + 2. If skill returns an iterable (except strings): + - Publish `MsgType.stream` for each yielded/iterated value + - Publish `MsgType.ret` with the **last yielded value** after exhaustion + 3. If skill returns a non-iterable (or string): + - Publish `MsgType.ret` with the return value + 4. On exception: + - Publish `MsgType.error` with `{msg: str, traceback: str}` content + + Notes: + **Threading:** + The `@threaded` decorator submits execution to `_skill_thread_pool`, returning + immediately. + + See also: + `@skill` decorator : Wraps methods and routes calls with `call_id` to this method + `SkillCoordinator.call_skill` : Higher-level interface for skill invocation + """ + f = getattr(self, skill_name, None) + + if f is None: + raise ValueError(f"Function '{skill_name}' not found in {self.__class__.__name__}") + + config = getattr(f, "_skill_config", None) + if config is None: + raise ValueError(f"Function '{skill_name}' in {self.__class__.__name__} is not a skill") + + # we notify the skill transport about the start of the skill call + self.skill_transport.publish(SkillMsg(call_id, skill_name, None, type=MsgType.start)) + + try: + val = f(*args, **kwargs) + + # check if the skill returned a coroutine, if it is, block until it resolves + if isinstance(val, asyncio.Future): + val = asyncio.run(val) # type: ignore[arg-type] + + # check if the skill is a generator, if it is, we need to iterate over it + if hasattr(val, "__iter__") and not isinstance(val, str): + last_value = None + for v in val: + last_value = v + self.skill_transport.publish( + SkillMsg(call_id, skill_name, v, type=MsgType.stream) + ) + self.skill_transport.publish( + SkillMsg(call_id, skill_name, last_value, type=MsgType.ret) + ) + + else: + self.skill_transport.publish(SkillMsg(call_id, skill_name, val, type=MsgType.ret)) + + except Exception as e: + import traceback + + formatted_traceback = "".join(traceback.TracebackException.from_exception(e).format()) + + self.skill_transport.publish( + SkillMsg( + call_id, + skill_name, + {"msg": str(e), "traceback": formatted_traceback}, + type=MsgType.error, + ) + ) + + @rpc + def skills( + self, + ) -> Annotated[ + dict[str, SkillConfig], + Doc( + """Dictionary mapping skill name to SkillConfig. Each SkillConfig contains the skill's + JSON schema, execution settings (stream/ret/reducer), and metadata for LLM tool calling.""" + ), + ]: + """Discover all `@skill` decorated methods on this container. + + Introspects the container's methods to find those decorated with `@skill`, returning + their configurations for registration with the SkillCoordinator. This method enables + automatic skill discovery without explicit registration lists. + + Discovery algorithm: + 1. Iterate over all public attribute names via `dir(self)` + 2. Exclude: names starting with `_`, and names in exclusion list + 3. Include: attributes with a `_skill_config` attribute (set by `@skill` decorator) + + The exclusion list prevents recursion and avoids accessing problematic properties: + `{"skills", "tf", "rpc", "skill_transport"}`. + + Examples: + Discovering skills from a container: + + >>> from dimos.core.module import Module + >>> from dimos.protocol.skill.skill import skill + >>> + >>> class NavigationSkills(Module): + ... @skill() + ... def navigate_to(self, location: str) -> str: + ... '''Navigate to a named location.''' + ... return f"Navigating to {location}" + ... + ... @skill() + ... def cancel_navigation(self) -> str: + ... '''Cancel current navigation.''' + ... return "Navigation cancelled" + >>> skills = NavigationSkills() + >>> discovered = skills.skills() + >>> sorted(discovered.keys()) + ['cancel_navigation', 'navigate_to'] + >>> discovered['navigate_to'].schema['function']['description'] + 'Navigate to a named location.' + + Notes: + This method is marked `@rpc` for remote queryability by SkillCoordinator during + skill registration. When `dynamic_skills()` returns True, this method is called + on each coordinator query to refresh the skill set. + + See also: + `dynamic_skills()` : Controls whether skills are cached or queried dynamically + `@skill` decorator : Attaches `_skill_config` to methods for discovery + `SkillCoordinator.register_skills` : Uses this method during registration + """ + # Avoid recursion by excluding this property itself + # Also exclude known properties that shouldn't be accessed + excluded = {"skills", "tf", "rpc", "skill_transport"} + return { + name: getattr(self, name)._skill_config + for name in dir(self) + if not name.startswith("_") + and name not in excluded + and hasattr(getattr(self, name), "_skill_config") + } + + @property + def skill_transport( + self, + ) -> Annotated[ + SkillCommsSpec, + Doc( + """Transport instance for skill message publishing. Lazily initialized on first access + using `skill_transport_class` (default: `LCMSkillComms`).""" + ), + ]: + """Provide lazy access to the skill transport layer. + + Creates and caches a transport instance on first access, using the class specified by + `skill_transport_class`. The transport handles publishing skill messages (start, stream, + ret, error) to `SkillCoordinator` via the configured communication layer. + + Examples: + Custom transport class: + + >>> from dimos.core.module import Module + >>> from dimos.protocol.skill.skill import skill + >>> from dimos.protocol.skill.comms import SkillCommsSpec + >>> + >>> class CustomTransport(SkillCommsSpec): + ... def publish(self, msg): + ... pass # Custom implementation + ... def subscribe(self, cb): + ... pass + ... def start(self): + ... pass + ... def stop(self): + ... pass + >>> class CustomSkills(Module): + ... skill_transport_class = CustomTransport + ... @skill() + ... def example(self) -> str: + ... return "Done" + >>> skills = CustomSkills() + >>> skills.start() + >>> type(skills.skill_transport).__name__ + 'CustomTransport' + >>> skills.stop() + + Notes: + The transport is shared across all skills in this container, ensuring consistent + message delivery. + + See also: + `skill_transport_class` : Class attribute specifying which transport to instantiate + `SkillCommsSpec` : Interface defining transport contract (publish/subscribe/start/stop) + `LCMSkillComms` : Default transport implementation using LCM + """ + if self._skill_transport is None: + self._skill_transport = self.skill_transport_class() + return self._skill_transport diff --git a/dimos/protocol/skill/test_coordinator.py b/dimos/protocol/skill/test_coordinator.py new file mode 100644 index 0000000000..76e3e80697 --- /dev/null +++ b/dimos/protocol/skill/test_coordinator.py @@ -0,0 +1,159 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +from collections.abc import Generator +import datetime +import time + +import pytest + +from dimos.core import Module, rpc +from dimos.msgs.sensor_msgs import Image +from dimos.protocol.skill.coordinator import SkillCoordinator +from dimos.protocol.skill.skill import skill +from dimos.protocol.skill.type import Output, Reducer, Stream +from dimos.utils.data import get_data + + +class SkillContainerTest(Module): + @rpc + def start(self) -> None: + super().start() + + @rpc + def stop(self) -> None: + super().stop() + + @skill() + def add(self, x: int, y: int) -> int: + """adds x and y.""" + time.sleep(2) + return x + y + + @skill() + def delayadd(self, x: int, y: int) -> int: + """waits 0.3 seconds before adding x and y.""" + time.sleep(0.3) + return x + y + + @skill(stream=Stream.call_agent, reducer=Reducer.all) + def counter(self, count_to: int, delay: float | None = 0.05) -> Generator[int, None, None]: + """Counts from 1 to count_to, with an optional delay between counts.""" + for i in range(1, count_to + 1): + if delay > 0: + time.sleep(delay) + yield i + + @skill(stream=Stream.passive, reducer=Reducer.sum) + def counter_passive_sum( + self, count_to: int, delay: float | None = 0.05 + ) -> Generator[int, None, None]: + """Counts from 1 to count_to, with an optional delay between counts.""" + for i in range(1, count_to + 1): + if delay > 0: + time.sleep(delay) + yield i + + @skill(stream=Stream.passive, reducer=Reducer.latest) + def current_time(self, frequency: float | None = 10) -> Generator[str, None, None]: + """Provides current time.""" + while True: + yield str(datetime.datetime.now()) + time.sleep(1 / frequency) + + @skill(stream=Stream.passive, reducer=Reducer.latest) + def uptime_seconds(self, frequency: float | None = 10) -> Generator[float, None, None]: + """Provides current uptime.""" + start_time = datetime.datetime.now() + while True: + yield (datetime.datetime.now() - start_time).total_seconds() + time.sleep(1 / frequency) + + @skill() + def current_date(self, frequency: float | None = 10) -> str: + """Provides current date.""" + return datetime.datetime.now() + + @skill(output=Output.image) + def take_photo(self) -> str: + """Takes a camera photo""" + print("Taking photo...") + img = Image.from_file(get_data("cafe-smol.jpg")) + print("Photo taken.") + return img + + +@pytest.mark.asyncio +async def test_coordinator_parallel_calls() -> None: + container = SkillContainerTest() + skillCoordinator = SkillCoordinator() + skillCoordinator.register_skills(container) + + skillCoordinator.start() + skillCoordinator.call_skill("test-call-0", "add", {"args": [0, 2]}) + + time.sleep(0.1) + + cnt = 0 + while await skillCoordinator.wait_for_updates(1): + print(skillCoordinator) + + skillstates = skillCoordinator.generate_snapshot() + + skill_id = f"test-call-{cnt}" + tool_msg = skillstates[skill_id].agent_encode() + assert tool_msg.content == cnt + 2 + + cnt += 1 + if cnt < 5: + skillCoordinator.call_skill( + f"test-call-{cnt}-delay", + "delayadd", + {"args": [cnt, 2]}, + ) + skillCoordinator.call_skill( + f"test-call-{cnt}", + "add", + {"args": [cnt, 2]}, + ) + + await asyncio.sleep(0.1 * cnt) + + container.stop() + skillCoordinator.stop() + + +@pytest.mark.asyncio +async def test_coordinator_generator() -> None: + container = SkillContainerTest() + skillCoordinator = SkillCoordinator() + skillCoordinator.register_skills(container) + skillCoordinator.start() + + # here we call a skill that generates a sequence of messages + skillCoordinator.call_skill("test-gen-0", "counter", {"args": [10]}) + skillCoordinator.call_skill("test-gen-1", "counter_passive_sum", {"args": [5]}) + skillCoordinator.call_skill("test-gen-2", "take_photo", {"args": []}) + + # periodically agent is stopping it's thinking cycle and asks for updates + while await skillCoordinator.wait_for_updates(2): + print(skillCoordinator) + agent_update = skillCoordinator.generate_snapshot(clear=True) + print(agent_update) + await asyncio.sleep(0.125) + + print("coordinator loop finished") + print(skillCoordinator) + container.stop() + skillCoordinator.stop() diff --git a/dimos/protocol/skill/test_utils.py b/dimos/protocol/skill/test_utils.py new file mode 100644 index 0000000000..d9fe9f6f91 --- /dev/null +++ b/dimos/protocol/skill/test_utils.py @@ -0,0 +1,87 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.protocol.skill.utils import interpret_tool_call_args + + +def test_list() -> None: + args, kwargs = interpret_tool_call_args([1, 2, 3]) + assert args == [1, 2, 3] + assert kwargs == {} + + +def test_none() -> None: + args, kwargs = interpret_tool_call_args(None) + assert args == [] + assert kwargs == {} + + +def test_none_nested() -> None: + args, kwargs = interpret_tool_call_args({"args": None}) + assert args == [] + assert kwargs == {} + + +def test_non_dict() -> None: + args, kwargs = interpret_tool_call_args("test") + assert args == ["test"] + assert kwargs == {} + + +def test_dict_with_args_and_kwargs() -> None: + args, kwargs = interpret_tool_call_args({"args": [1, 2], "kwargs": {"key": "value"}}) + assert args == [1, 2] + assert kwargs == {"key": "value"} + + +def test_dict_with_only_kwargs() -> None: + args, kwargs = interpret_tool_call_args({"kwargs": {"a": 1, "b": 2}}) + assert args == [] + assert kwargs == {"a": 1, "b": 2} + + +def test_dict_as_kwargs() -> None: + args, kwargs = interpret_tool_call_args({"x": 10, "y": 20}) + assert args == [] + assert kwargs == {"x": 10, "y": 20} + + +def test_dict_with_only_args_first_pass() -> None: + args, kwargs = interpret_tool_call_args({"args": [5, 6, 7]}) + assert args == [5, 6, 7] + assert kwargs == {} + + +def test_dict_with_only_args_nested() -> None: + args, kwargs = interpret_tool_call_args({"args": {"inner": "value"}}) + assert args == [] + assert kwargs == {"inner": "value"} + + +def test_empty_list() -> None: + args, kwargs = interpret_tool_call_args([]) + assert args == [] + assert kwargs == {} + + +def test_empty_dict() -> None: + args, kwargs = interpret_tool_call_args({}) + assert args == [] + assert kwargs == {} + + +def test_integer() -> None: + args, kwargs = interpret_tool_call_args(42) + assert args == [42] + assert kwargs == {} diff --git a/dimos/protocol/skill/type.py b/dimos/protocol/skill/type.py new file mode 100644 index 0000000000..26ab7d8124 --- /dev/null +++ b/dimos/protocol/skill/type.py @@ -0,0 +1,436 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import dataclass +from enum import Enum +import time +from typing import Annotated, Any, Generic, Literal, TypeVar + +from annotated_doc import Doc + +from dimos.types.timestamped import Timestamped +from dimos.utils.generic import truncate_display_string + +# This file defines protocol messages used for communication between skills and agents + + +class Output(Enum): + standard = 0 + human = 1 + image = 2 # this is same as separate_message, but maybe clearer for users + + +class Stream(Enum): + """Controls how streaming skill outputs are handled. + + Streaming skills (generators/iterators) emit multiple values during execution. + This enum determines whether each emitted value should wake the agent or just + accumulate silently in the background. + """ + + none = 0 + """No streaming. Skill returns a single value, not a generator.""" + + passive = 1 + """Passive streaming that accumulates values without waking the agent. + + Values accumulate via the configured reducer but do not trigger agent calls. + Passive skill data is only delivered when an active skill keeps the agent + loop running long enough for a snapshot to be generated. + + Behavior: + - Each `yield` applies the reducer to accumulate state + - Never wakes the agent (except on errors) + - Forces `ret=Return.passive` regardless of user setting + + How delivery works: + - The agent loop checks for active skills before generating snapshots + (snapshots include info about all skills) + - If active skills exist, the loop continues. + + Note: + - If *only* passive skills are running, the loop exits immediately at the + termination check without generating a snapshot. + Remaining passive data will not be delivered to the agent. + - That is, passive skills *require* an active companion skill (like `human_input` or + any `stream=Stream.call_agent` skill) to ensure data reaches the agent. + + Examples of use cases: + - Video streaming during navigation (with active navigation skill) + - Sensor telemetry alongside task execution + + Anti-patterns: + - Using passive skills without active skills (data never delivered) + - Starting passive skills with short-lived active skills (data may be lost) + """ + + call_agent = 2 + """Active streaming that wakes the agent on each yield. + + If yields happen faster than the agent can process them, the reducer combines + intermediate values. + + Use for progress updates or incremental results that should notify the agent + promptly while handling backpressure from fast producers. + """ + + +class Return(Enum): + """Controls how skill return values are delivered and whether they wake the agent. + + While Stream controls behavior during execution (for generators), Return controls + what happens when a skill completes. This determines whether the agent is directly notified + of completion and whether the return value is included in snapshots. + + Note: Errors always wake the agent regardless of Return setting. + + Constraint: `stream=Stream.passive` forces `ret=Return.passive` automatically. + """ + + none = 0 + """Return value discarded, agent not notified. + + Use for fire-and-forget operations where the agent doesn't need + to know about completion. + + Examples of use cases: + - Background logging or telemetry + - Fire-and-forget actuator commands + - Cleanup operations + """ + + passive = 1 + """Return value stored but agent not woken. + + The skill completes silently, but the return value is stored and appears in + snapshots when the agent wakes for other reasons. + + Critical: If no active skills are running when this skill completes, the + agent loop exits and this return value is never delivered. + + Note: When `stream=Stream.passive`, `ret` is forced to this value. + + Use cases: + - Status checks collected alongside active tasks + - Sensor readings that don't justify waking agent + """ + + call_agent = 2 + """Return value triggers immediate agent notification. + + Skill completion wakes the agent and delivers the return value immediately. + This is the default and most common behavior. + """ + + callback = 3 + """Not implemented. Reserved for future callback pattern.""" + + +@dataclass +class SkillConfig: + """Configuration for a skill, created by the @skill decorator. + + Attached to decorated methods as `_skill_config`. Used by SkillCoordinator + to control execution behavior. + """ + + name: Annotated[str, Doc("Skill name (from decorated function name).")] + reducer: Annotated[ReducerF, Doc("Aggregation function for streaming values.")] + stream: Annotated[Stream, Doc("Streaming behavior (none/passive/call_agent).")] + ret: Annotated[ + Return, + Doc( + "Return value delivery (none/passive/call_agent). " + "Note: Forced to `passive` when `stream=Stream.passive`." + ), + ] + output: Annotated[Output, Doc("Presentation hint for agent (standard/human/image).")] + schema: Annotated[dict[str, Any], Doc("OpenAI function-calling schema for LLM invocation.")] + f: Annotated[Callable | None, Doc("Bound method reference (set via `bind()`)")] = None # type: ignore[type-arg] + autostart: Annotated[bool, Doc("Reserved for future use (currently unused).")] = False + hide_skill: Annotated[bool, Doc("If True, skill hidden from LLM tool selection.")] = False + + def bind(self, f: Callable) -> SkillConfig: # type: ignore[type-arg] + self.f = f + return self + + def call(self, call_id, *args, **kwargs) -> Any: # type: ignore[no-untyped-def] + if self.f is None: + raise ValueError( + "Function is not bound to the SkillConfig. This should be called only within AgentListener." + ) + + return self.f(*args, **kwargs, call_id=call_id) + + def __str__(self) -> str: + parts = [f"name={self.name}"] + + # Only show reducer if stream is not none (streaming is happening) + if self.stream != Stream.none: + parts.append(f"stream={self.stream.name}") + + # Always show return mode + parts.append(f"ret={self.ret.name}") + return f"Skill({', '.join(parts)})" + + +class MsgType(Enum): + pending = 0 + start = 1 + stream = 2 + reduced_stream = 3 + ret = 4 + error = 5 + + +M = TypeVar("M", bound="MsgType") + + +def maybe_encode(something: Any) -> str: + if hasattr(something, "agent_encode"): + return something.agent_encode() # type: ignore[no-any-return] + return something # type: ignore[no-any-return] + + +class SkillMsg(Timestamped, Generic[M]): + ts: float + type: M + call_id: str + skill_name: str + content: str | int | float | dict | list # type: ignore[type-arg] + + def __init__( + self, + call_id: str, + skill_name: str, + content: Any, + type: M, + ) -> None: + self.ts = time.time() + self.call_id = call_id + self.skill_name = skill_name + # any tool output can be a custom type that knows how to encode itself + # like a costmap, path, transform etc could be translatable into strings + + self.content = maybe_encode(content) + self.type = type + + @property + def end(self) -> bool: + return self.type == MsgType.ret or self.type == MsgType.error + + @property + def start(self) -> bool: + return self.type == MsgType.start + + def __str__(self) -> str: # type: ignore[return] + time_ago = time.time() - self.ts + + if self.type == MsgType.start: + return f"Start({time_ago:.1f}s ago)" + if self.type == MsgType.ret: + return f"Ret({time_ago:.1f}s ago, val={truncate_display_string(self.content)})" + if self.type == MsgType.error: + return f"Error({time_ago:.1f}s ago, val={truncate_display_string(self.content)})" + if self.type == MsgType.pending: + return f"Pending({time_ago:.1f}s ago)" + if self.type == MsgType.stream: + return f"Stream({time_ago:.1f}s ago, val={truncate_display_string(self.content)})" + if self.type == MsgType.reduced_stream: + return f"Stream({time_ago:.1f}s ago, val={truncate_display_string(self.content)})" + + +# typing looks complex but it's a standard reducer function signature, using SkillMsgs +# (Optional[accumulator], msg) -> accumulator +ReducerF = Callable[ + [SkillMsg[Literal[MsgType.reduced_stream]] | None, SkillMsg[Literal[MsgType.stream]]], + SkillMsg[Literal[MsgType.reduced_stream]], +] + + +C = TypeVar("C") # content type +A = TypeVar("A") # accumulator type +# define a naive reducer function type that's generic in terms of the accumulator type +SimpleReducerF = Callable[[A | None, C], A] + + +def make_reducer(simple_reducer: SimpleReducerF) -> ReducerF: # type: ignore[type-arg] + """ + Converts a naive reducer function into a standard reducer function. + The naive reducer function should accept an accumulator and a message, + and return the updated accumulator. + """ + + def reducer( + accumulator: SkillMsg[Literal[MsgType.reduced_stream]] | None, + msg: SkillMsg[Literal[MsgType.stream]], + ) -> SkillMsg[Literal[MsgType.reduced_stream]]: + # Extract the content from the accumulator if it exists + acc_value = accumulator.content if accumulator else None + + # Apply the simple reducer to get the new accumulated value + new_value = simple_reducer(acc_value, msg.content) + + # Wrap the result in a SkillMsg with reduced_stream type + return SkillMsg( + call_id=msg.call_id, + skill_name=msg.skill_name, + content=new_value, + type=MsgType.reduced_stream, + ) + + return reducer + + +# just a convinience class to hold reducer functions +def _make_skill_msg( + msg: SkillMsg[Literal[MsgType.stream]], content: Any +) -> SkillMsg[Literal[MsgType.reduced_stream]]: + """Helper to create a reduced stream message with new content.""" + return SkillMsg( + call_id=msg.call_id, + skill_name=msg.skill_name, + content=content, + type=MsgType.reduced_stream, + ) + + +def sum_reducer( + accumulator: SkillMsg[Literal[MsgType.reduced_stream]] | None, + msg: SkillMsg[Literal[MsgType.stream]], +) -> SkillMsg[Literal[MsgType.reduced_stream]]: + """Sum reducer that adds values together.""" + acc_value = accumulator.content if accumulator else None + new_value = acc_value + msg.content if acc_value else msg.content # type: ignore[operator] + return _make_skill_msg(msg, new_value) + + +def latest_reducer( + accumulator: SkillMsg[Literal[MsgType.reduced_stream]] | None, + msg: SkillMsg[Literal[MsgType.stream]], +) -> SkillMsg[Literal[MsgType.reduced_stream]]: + """Latest reducer that keeps only the most recent value.""" + return _make_skill_msg(msg, msg.content) + + +def all_reducer( + accumulator: SkillMsg[Literal[MsgType.reduced_stream]] | None, + msg: SkillMsg[Literal[MsgType.stream]], +) -> SkillMsg[Literal[MsgType.reduced_stream]]: + """All reducer that collects all values into a list.""" + acc_value = accumulator.content if accumulator else None + new_value = [*acc_value, msg.content] if acc_value else [msg.content] # type: ignore[misc] + return _make_skill_msg(msg, new_value) + + +def accumulate_list( + accumulator: SkillMsg[Literal[MsgType.reduced_stream]] | None, + msg: SkillMsg[Literal[MsgType.stream]], +) -> SkillMsg[Literal[MsgType.reduced_stream]]: + """List concatenation reducer: extends accumulator list with message content list.""" + acc_value = accumulator.content if accumulator else [] + return _make_skill_msg(msg, acc_value + msg.content) # type: ignore[operator] + + +def accumulate_dict( + accumulator: SkillMsg[Literal[MsgType.reduced_stream]] | None, + msg: SkillMsg[Literal[MsgType.stream]], +) -> SkillMsg[Literal[MsgType.reduced_stream]]: + """Dict merge reducer: merges message content dict into accumulator dict.""" + acc_value = accumulator.content if accumulator else {} + return _make_skill_msg(msg, {**acc_value, **msg.content}) # type: ignore[dict-item] + + +def accumulate_string( + accumulator: SkillMsg[Literal[MsgType.reduced_stream]] | None, + msg: SkillMsg[Literal[MsgType.stream]], +) -> SkillMsg[Literal[MsgType.reduced_stream]]: + """String concatenation reducer: joins values with newlines. + + Examples: + >>> m = lambda s: SkillMsg('id', 'x', s, MsgType.stream) + >>> accumulate_string(None, m('A')).content # no leading newline + 'A' + >>> accumulate_string(accumulate_string(None, m('A')), m('B')).content + 'A\\nB' + >>> # Edge case: empty string as first yield doesn't cause leading newline + >>> accumulate_string(accumulate_string(None, m('')), m('X')).content + 'X' + """ + prefix = f"{accumulator.content}\n" if accumulator and accumulator.content else "" + new_value = prefix + msg.content + return _make_skill_msg(msg, new_value) + + +class Reducer: + """Namespace for reducer functions that buffer streaming skill outputs. + + Reducers act as **backpressure buffers**: when a skill yields values faster + than the agent can process them, the reducer combines or aggregates updates + between agent calls. + + With `Stream.passive`, values accumulate silently until an active skill wakes + the agent. With `Stream.call_agent`, whether updates accumulate depends on + whether yields happen faster than the agent processes them. + + Custom reducers can be created with `make_reducer()`. + + For examples, see `dimos/hardware/camera/module.py` and `dimos/navigation/rosnav.py`. + """ + + sum: Annotated[ + ReducerF, + Doc("""Adds numeric values together. O(1) memory."""), + ] = sum_reducer + + latest: Annotated[ + ReducerF, + Doc( + """Keeps only the most recent value, discarding previous state. O(1) memory. + + Ideal for high-frequency data where only the current value matters + (sensor readings, video frames, robot pose).""" + ), + ] = latest_reducer + + all: Annotated[ + ReducerF, + Doc("""Collects yielded values into a list. O(n) memory per snapshot interval."""), + ] = all_reducer + + accumulate_list: Annotated[ + ReducerF, + Doc( + """Concatenates yielded lists into one. O(n) memory per snapshot interval. + + Unlike `all` (which wraps each yield in a list), this expects yields + to already be lists and flattens them together.""" + ), + ] = accumulate_list + + accumulate_dict: Annotated[ + ReducerF, + Doc( + """Merges yielded dicts into one. O(n) memory in unique keys per snapshot interval. + + Later values overwrite earlier ones for duplicate keys.""" + ), + ] = accumulate_dict + + string: Annotated[ + ReducerF, + Doc("""Joins string values with newlines. O(n) memory per snapshot interval."""), + ] = accumulate_string diff --git a/dimos/protocol/skill/utils.py b/dimos/protocol/skill/utils.py new file mode 100644 index 0000000000..278134c525 --- /dev/null +++ b/dimos/protocol/skill/utils.py @@ -0,0 +1,41 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + + +def interpret_tool_call_args( + args: Any, first_pass: bool = True +) -> tuple[list[Any], dict[str, Any]]: + """ + Agents sometimes produce bizarre calls. This tries to interpret the args better. + """ + + if isinstance(args, list): + return args, {} + if args is None: + return [], {} + if not isinstance(args, dict): + return [args], {} + if args.keys() == {"args", "kwargs"}: + return args["args"], args["kwargs"] + if args.keys() == {"kwargs"}: + return [], args["kwargs"] + if args.keys() != {"args"}: + return [], args + + if first_pass: + return interpret_tool_call_args(args["args"], first_pass=False) + + return [], args diff --git a/dimos/protocol/tf/__init__.py b/dimos/protocol/tf/__init__.py new file mode 100644 index 0000000000..96cdbcf285 --- /dev/null +++ b/dimos/protocol/tf/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.protocol.tf.tf import LCMTF, TF, MultiTBuffer, PubSubTF, TBuffer, TFConfig, TFSpec + +__all__ = ["LCMTF", "TF", "MultiTBuffer", "PubSubTF", "TBuffer", "TFConfig", "TFSpec"] diff --git a/dimos/protocol/tf/test_tf.py b/dimos/protocol/tf/test_tf.py new file mode 100644 index 0000000000..0b5b332c3d --- /dev/null +++ b/dimos/protocol/tf/test_tf.py @@ -0,0 +1,679 @@ +#!/usr/bin/env python3 + +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import time + +import pytest + +from dimos.core import TF +from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Vector3 +from dimos.protocol.tf import MultiTBuffer, TBuffer + + +# from https://foxglove.dev/blog/understanding-ros-transforms +def test_tf_ros_example() -> None: + tf = TF() + + base_link_to_arm = Transform( + translation=Vector3(1.0, -1.0, 0.0), + rotation=Quaternion.from_euler(Vector3(0, 0, math.pi / 6)), + frame_id="base_link", + child_frame_id="arm", + ts=time.time(), + ) + + arm_to_end = Transform( + translation=Vector3(1.0, 1.0, 0.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), # Identity rotation + frame_id="arm", + child_frame_id="end_effector", + ts=time.time(), + ) + + tf.publish(base_link_to_arm, arm_to_end) + time.sleep(0.2) + + end_effector_global_pose = tf.get("base_link", "end_effector") + + assert end_effector_global_pose.translation.x == pytest.approx(1.366, abs=1e-3) + assert end_effector_global_pose.translation.y == pytest.approx(0.366, abs=1e-3) + + tf.stop() + + +def test_tf_main() -> None: + """Test TF broadcasting and querying between two TF instances. + If you run foxglove-bridge this will show up in the UI""" + + # here we create broadcasting and receiving TF instance. + # this is to verify that comms work multiprocess, normally + # you'd use only one instance in your module + broadcaster = TF() + querier = TF() + + # Create a transform from world to robot + current_time = time.time() + + world_to_charger = Transform( + translation=Vector3(2.0, -2.0, 0.0), + rotation=Quaternion.from_euler(Vector3(0, 0, 2)), + frame_id="world", + child_frame_id="charger", + ts=current_time, + ) + + world_to_robot = Transform( + translation=Vector3(1.0, 2.0, 3.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), # Identity rotation + frame_id="world", + child_frame_id="robot", + ts=current_time, + ) + + # Broadcast the transform + broadcaster.publish(world_to_robot) + broadcaster.publish(world_to_charger) + # Give time for the message to propagate + time.sleep(0.05) + + # Verify frames are available + frames = querier.get_frames() + assert "world" in frames + assert "robot" in frames + + # Add another transform in the chain + robot_to_sensor = Transform( + translation=Vector3(0.5, 0.0, 0.2), + rotation=Quaternion(0.0, 0.0, 0.707107, 0.707107), # 90 degrees around Z + frame_id="robot", + child_frame_id="sensor", + ts=current_time, + ) + + broadcaster.publish(robot_to_sensor) + + time.sleep(0.05) + + # we can now query (from a separate process given we use querier) the transform tree + chain_transform = querier.get("world", "sensor") + + # broadcaster will agree with us + assert broadcaster.get("world", "sensor") == chain_transform + + # The chain should compose: world->robot (1,2,3) + robot->sensor (0.5,0,0.2) + # Expected translation: (1.5, 2.0, 3.2) + assert abs(chain_transform.translation.x - 1.5) < 0.001 + assert abs(chain_transform.translation.y - 2.0) < 0.001 + assert abs(chain_transform.translation.z - 3.2) < 0.001 + + # we see something on camera + random_object_in_view = PoseStamped( + frame_id="random_object", + position=Vector3(1, 0, 0), + ) + + print("Random obj", random_object_in_view) + + # random_object is perceived by the sensor + # we create a transform pointing from sensor to object + random_t = random_object_in_view.new_transform_from("sensor") + + # we could have also done + assert random_t == random_object_in_view.new_transform_to("sensor").inverse() + + print("randm t", random_t) + + # we broadcast our object location + broadcaster.publish(random_t) + + ## we could also publish world -> random_object if we wanted to + # broadcaster.publish( + # broadcaster.get("world", "sensor") + random_object_in_view.new_transform("sensor").inverse() + # ) + ## (this would mess with the transform system because it expects trees not graphs) + ## and our random_object would get re-connected to world from sensor + + print(broadcaster) + + # Give time for the message to propagate + time.sleep(0.05) + + # we know where the object is in the world frame now + world_object = broadcaster.get("world", "random_object") + + # both instances agree + assert querier.get("world", "random_object") == world_object + + print("world object", world_object) + + # if you have "diagon" https://diagon.arthursonzogni.com/ installed you can draw a graph + print(broadcaster.graph()) + + assert abs(world_object.translation.x - 1.5) < 0.001 + assert abs(world_object.translation.y - 3.0) < 0.001 + assert abs(world_object.translation.z - 3.2) < 0.001 + + # this doesn't work atm + robot_to_charger = broadcaster.get("robot", "charger") + + # Expected: robot->world->charger + print(f"robot_to_charger translation: {robot_to_charger.translation}") + print(f"robot_to_charger rotation: {robot_to_charger.rotation}") + + assert abs(robot_to_charger.translation.x - 1.0) < 0.001 + assert abs(robot_to_charger.translation.y - (-4.0)) < 0.001 + assert abs(robot_to_charger.translation.z - (-3.0)) < 0.001 + + # Stop services (they were autostarted but don't know how to autostop) + broadcaster.stop() + querier.stop() + + +class TestTBuffer: + def test_add_transform(self) -> None: + buffer = TBuffer(buffer_size=10.0) + transform = Transform( + translation=Vector3(1.0, 2.0, 3.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id="world", + child_frame_id="robot", + ts=time.time(), + ) + + buffer.add(transform) + assert len(buffer) == 1 + assert buffer[0] == transform + + def test_get(self) -> None: + buffer = TBuffer() + base_time = time.time() + + # Add transforms at different times + for i in range(3): + transform = Transform( + translation=Vector3(float(i), 0.0, 0.0), + frame_id="world", + child_frame_id="robot", + ts=base_time + i * 0.5, + ) + buffer.add(transform) + + # Test getting latest transform + latest = buffer.get() + assert latest is not None + assert latest.translation.x == 2.0 + + # Test getting transform at specific time + middle = buffer.get(time_point=base_time + 0.75) + assert middle is not None + assert middle.translation.x == 2.0 # Closest to i=1 + + # Test time tolerance + result = buffer.get(time_point=base_time + 10.0, time_tolerance=0.1) + assert result is None # Outside tolerance + + def test_buffer_pruning(self) -> None: + buffer = TBuffer(buffer_size=1.0) # 1 second buffer + + # Add old transform + old_time = time.time() - 2.0 + old_transform = Transform( + translation=Vector3(1.0, 0.0, 0.0), + frame_id="world", + child_frame_id="robot", + ts=old_time, + ) + buffer.add(old_transform) + + # Add recent transform + recent_transform = Transform( + translation=Vector3(2.0, 0.0, 0.0), + frame_id="world", + child_frame_id="robot", + ts=time.time(), + ) + buffer.add(recent_transform) + + # Old transform should be pruned + assert len(buffer) == 1 + assert buffer[0].translation.x == 2.0 + + +class TestMultiTBuffer: + def test_multiple_frame_pairs(self) -> None: + ttbuffer = MultiTBuffer(buffer_size=10.0) + + # Add transforms for different frame pairs + transform1 = Transform( + translation=Vector3(1.0, 0.0, 0.0), + frame_id="world", + child_frame_id="robot1", + ts=time.time(), + ) + + transform2 = Transform( + translation=Vector3(2.0, 0.0, 0.0), + frame_id="world", + child_frame_id="robot2", + ts=time.time(), + ) + + ttbuffer.receive_transform(transform1, transform2) + + # Should have two separate buffers + assert len(ttbuffer.buffers) == 2 + assert ("world", "robot1") in ttbuffer.buffers + assert ("world", "robot2") in ttbuffer.buffers + + def test_graph(self) -> None: + ttbuffer = MultiTBuffer(buffer_size=10.0) + + # Add transforms for different frame pairs + transform1 = Transform( + translation=Vector3(1.0, 0.0, 0.0), + frame_id="world", + child_frame_id="robot1", + ts=time.time(), + ) + + transform2 = Transform( + translation=Vector3(2.0, 0.0, 0.0), + frame_id="world", + child_frame_id="robot2", + ts=time.time(), + ) + + ttbuffer.receive_transform(transform1, transform2) + + print(ttbuffer.graph()) + + def test_get_latest_transform(self) -> None: + ttbuffer = MultiTBuffer() + + # Add multiple transforms + for i in range(3): + transform = Transform( + translation=Vector3(float(i), 0.0, 0.0), + frame_id="world", + child_frame_id="robot", + ts=time.time() + i * 0.1, + ) + ttbuffer.receive_transform(transform) + time.sleep(0.01) + + # Get latest transform + latest = ttbuffer.get("world", "robot") + assert latest is not None + assert latest.translation.x == 2.0 + + def test_get_transform_at_time(self) -> None: + ttbuffer = MultiTBuffer() + base_time = time.time() + + # Add transforms at known times + for i in range(5): + transform = Transform( + translation=Vector3(float(i), 0.0, 0.0), + frame_id="world", + child_frame_id="robot", + ts=base_time + i * 0.5, + ) + ttbuffer.receive_transform(transform) + + # Get transform closest to middle time + middle_time = base_time + 1.25 # Should be closest to i=2 (t=1.0) or i=3 (t=1.5) + result = ttbuffer.get("world", "robot", time_point=middle_time) + assert result is not None + # At t=1.25, it's equidistant from i=2 (t=1.0) and i=3 (t=1.5) + # The implementation picks the later one when equidistant + assert result.translation.x == 3.0 + + def test_time_tolerance(self) -> None: + ttbuffer = MultiTBuffer() + base_time = time.time() + + # Add single transform + transform = Transform( + translation=Vector3(1.0, 0.0, 0.0), + frame_id="world", + child_frame_id="robot", + ts=base_time, + ) + ttbuffer.receive_transform(transform) + + # Within tolerance + result = ttbuffer.get("world", "robot", time_point=base_time + 0.1, time_tolerance=0.2) + assert result is not None + + # Outside tolerance + result = ttbuffer.get("world", "robot", time_point=base_time + 0.5, time_tolerance=0.1) + assert result is None + + def test_nonexistent_frame_pair(self) -> None: + ttbuffer = MultiTBuffer() + + # Try to get transform for non-existent frame pair + result = ttbuffer.get("foo", "bar") + assert result is None + + def test_get_transform_search_direct(self) -> None: + ttbuffer = MultiTBuffer() + base_time = time.time() + + # Add direct transform + transform = Transform( + translation=Vector3(1.0, 0.0, 0.0), + frame_id="world", + child_frame_id="robot", + ts=base_time, + ) + ttbuffer.receive_transform(transform) + + # Search should return single transform + result = ttbuffer.get_transform_search("world", "robot") + assert result is not None + assert len(result) == 1 + assert result[0].translation.x == 1.0 + + def test_get_transform_search_chain(self) -> None: + ttbuffer = MultiTBuffer() + base_time = time.time() + + # Create transform chain: world -> robot -> sensor + transform1 = Transform( + translation=Vector3(1.0, 0.0, 0.0), + frame_id="world", + child_frame_id="robot", + ts=base_time, + ) + transform2 = Transform( + translation=Vector3(0.0, 2.0, 0.0), + frame_id="robot", + child_frame_id="sensor", + ts=base_time, + ) + ttbuffer.receive_transform(transform1, transform2) + + # Search should find chain + result = ttbuffer.get_transform_search("world", "sensor") + assert result is not None + assert len(result) == 2 + assert result[0].translation.x == 1.0 # world -> robot + assert result[1].translation.y == 2.0 # robot -> sensor + + def test_get_transform_search_complex_chain(self) -> None: + ttbuffer = MultiTBuffer() + base_time = time.time() + + # Create more complex graph: + # world -> base -> arm -> hand + # \-> robot -> sensor + transforms = [ + Transform( + frame_id="world", + child_frame_id="base", + translation=Vector3(1.0, 0.0, 0.0), + ts=base_time, + ), + Transform( + frame_id="base", + child_frame_id="arm", + translation=Vector3(0.0, 1.0, 0.0), + ts=base_time, + ), + Transform( + frame_id="arm", + child_frame_id="hand", + translation=Vector3(0.0, 0.0, 1.0), + ts=base_time, + ), + Transform( + frame_id="world", + child_frame_id="robot", + translation=Vector3(2.0, 0.0, 0.0), + ts=base_time, + ), + Transform( + frame_id="robot", + child_frame_id="sensor", + translation=Vector3(0.0, 2.0, 0.0), + ts=base_time, + ), + ] + + for t in transforms: + ttbuffer.receive_transform(t) + + # Find path world -> hand (should go through base -> arm) + result = ttbuffer.get_transform_search("world", "hand") + assert result is not None + assert len(result) == 3 + assert result[0].child_frame_id == "base" + assert result[1].child_frame_id == "arm" + assert result[2].child_frame_id == "hand" + + def test_get_transform_search_no_path(self) -> None: + ttbuffer = MultiTBuffer() + base_time = time.time() + + # Create disconnected transforms + transform1 = Transform(frame_id="world", child_frame_id="robot", ts=base_time) + transform2 = Transform(frame_id="base", child_frame_id="sensor", ts=base_time) + ttbuffer.receive_transform(transform1, transform2) + + # No path exists + result = ttbuffer.get_transform_search("world", "sensor") + assert result is None + + def test_get_transform_search_with_time(self) -> None: + ttbuffer = MultiTBuffer() + base_time = time.time() + + # Add transforms at different times + old_transform = Transform( + frame_id="world", + child_frame_id="robot", + translation=Vector3(1.0, 0.0, 0.0), + ts=base_time - 10.0, + ) + new_transform = Transform( + frame_id="world", + child_frame_id="robot", + translation=Vector3(2.0, 0.0, 0.0), + ts=base_time, + ) + ttbuffer.receive_transform(old_transform, new_transform) + + # Search at specific time + result = ttbuffer.get_transform_search("world", "robot", time_point=base_time) + assert result is not None + assert result[0].translation.x == 2.0 + + # Search with time tolerance + result = ttbuffer.get_transform_search( + "world", "robot", time_point=base_time + 1.0, time_tolerance=0.1 + ) + assert result is None # Outside tolerance + + def test_get_transform_search_shortest_path(self) -> None: + ttbuffer = MultiTBuffer() + base_time = time.time() + + # Create graph with multiple paths: + # world -> A -> B -> target (3 hops) + # world -> target (direct, 1 hop) + transforms = [ + Transform(frame_id="world", child_frame_id="A", ts=base_time), + Transform(frame_id="A", child_frame_id="B", ts=base_time), + Transform(frame_id="B", child_frame_id="target", ts=base_time), + Transform(frame_id="world", child_frame_id="target", ts=base_time), + ] + + for t in transforms: + ttbuffer.receive_transform(t) + + # BFS should find the direct path (shortest) + result = ttbuffer.get_transform_search("world", "target") + assert result is not None + assert len(result) == 1 # Direct path, not the 3-hop path + assert result[0].child_frame_id == "target" + + def test_string_representations(self) -> None: + # Test empty buffers + empty_buffer = TBuffer() + assert str(empty_buffer) == "TBuffer(empty)" + + empty_ttbuffer = MultiTBuffer() + assert str(empty_ttbuffer) == "MultiTBuffer(empty)" + + # Test TBuffer with data + buffer = TBuffer() + base_time = time.time() + for i in range(3): + transform = Transform( + translation=Vector3(float(i), 0.0, 0.0), + frame_id="world", + child_frame_id="robot", + ts=base_time + i * 0.1, + ) + buffer.add(transform) + + buffer_str = str(buffer) + assert "3 msgs" in buffer_str + assert "world -> robot" in buffer_str + assert "0.20s" in buffer_str # duration + + # Test MultiTBuffer with multiple frame pairs + ttbuffer = MultiTBuffer() + transforms = [ + Transform(frame_id="world", child_frame_id="robot1", ts=base_time), + Transform(frame_id="world", child_frame_id="robot2", ts=base_time + 0.5), + Transform(frame_id="robot1", child_frame_id="sensor", ts=base_time + 1.0), + ] + + for t in transforms: + ttbuffer.receive_transform(t) + + ttbuffer_str = str(ttbuffer) + print("\nMultiTBuffer string representation:") + print(ttbuffer_str) + + assert "MultiTBuffer(3 buffers):" in ttbuffer_str + assert "TBuffer(world -> robot1, 1 msgs" in ttbuffer_str + assert "TBuffer(world -> robot2, 1 msgs" in ttbuffer_str + assert "TBuffer(robot1 -> sensor, 1 msgs" in ttbuffer_str + + def test_get_with_transform_chain_composition(self) -> None: + ttbuffer = MultiTBuffer() + base_time = time.time() + + # Create transform chain: world -> robot -> sensor + # world -> robot: translate by (1, 0, 0) + transform1 = Transform( + translation=Vector3(1.0, 0.0, 0.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), # Identity + frame_id="world", + child_frame_id="robot", + ts=base_time, + ) + + # robot -> sensor: translate by (0, 2, 0) and rotate 90 degrees around Z + import math + + # 90 degrees around Z: quaternion (0, 0, sin(45°), cos(45°)) + transform2 = Transform( + translation=Vector3(0.0, 2.0, 0.0), + rotation=Quaternion(0.0, 0.0, math.sin(math.pi / 4), math.cos(math.pi / 4)), + frame_id="robot", + child_frame_id="sensor", + ts=base_time, + ) + + ttbuffer.receive_transform(transform1, transform2) + + # Get composed transform from world to sensor + result = ttbuffer.get("world", "sensor") + assert result is not None + + # The composed transform should: + # 1. Apply world->robot translation: (1, 0, 0) + # 2. Apply robot->sensor translation in robot frame: (0, 2, 0) + # Total translation: (1, 2, 0) + assert abs(result.translation.x - 1.0) < 1e-6 + assert abs(result.translation.y - 2.0) < 1e-6 + assert abs(result.translation.z - 0.0) < 1e-6 + + # Rotation should be 90 degrees around Z (same as transform2) + assert abs(result.rotation.x - 0.0) < 1e-6 + assert abs(result.rotation.y - 0.0) < 1e-6 + assert abs(result.rotation.z - math.sin(math.pi / 4)) < 1e-6 + assert abs(result.rotation.w - math.cos(math.pi / 4)) < 1e-6 + + # Frame IDs should be correct + assert result.frame_id == "world" + assert result.child_frame_id == "sensor" + + def test_get_with_longer_transform_chain(self) -> None: + ttbuffer = MultiTBuffer() + base_time = time.time() + + # Create longer chain: world -> base -> arm -> hand + # Each adds a translation along different axes + transforms = [ + Transform( + translation=Vector3(1.0, 0.0, 0.0), # Move 1 along X + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id="world", + child_frame_id="base", + ts=base_time, + ), + Transform( + translation=Vector3(0.0, 2.0, 0.0), # Move 2 along Y + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id="base", + child_frame_id="arm", + ts=base_time, + ), + Transform( + translation=Vector3(0.0, 0.0, 3.0), # Move 3 along Z + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id="arm", + child_frame_id="hand", + ts=base_time, + ), + ] + + for t in transforms: + ttbuffer.receive_transform(t) + + # Get composed transform from world to hand + result = ttbuffer.get("world", "hand") + assert result is not None + + # Total translation should be sum of all: (1, 2, 3) + assert abs(result.translation.x - 1.0) < 1e-6 + assert abs(result.translation.y - 2.0) < 1e-6 + assert abs(result.translation.z - 3.0) < 1e-6 + + # Rotation should still be identity (all rotations were identity) + assert abs(result.rotation.x - 0.0) < 1e-6 + assert abs(result.rotation.y - 0.0) < 1e-6 + assert abs(result.rotation.z - 0.0) < 1e-6 + assert abs(result.rotation.w - 1.0) < 1e-6 + + assert result.frame_id == "world" + assert result.child_frame_id == "hand" diff --git a/dimos/protocol/tf/tf.py b/dimos/protocol/tf/tf.py new file mode 100644 index 0000000000..3688b013cf --- /dev/null +++ b/dimos/protocol/tf/tf.py @@ -0,0 +1,352 @@ +#!/usr/bin/env python3 + +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import abstractmethod +from collections import deque +from dataclasses import dataclass, field +from functools import reduce +from typing import TypeVar + +from dimos.msgs.geometry_msgs import Transform +from dimos.msgs.tf2_msgs import TFMessage +from dimos.protocol.pubsub.lcmpubsub import LCM, Topic +from dimos.protocol.pubsub.spec import PubSub +from dimos.protocol.service.lcmservice import Service # type: ignore[attr-defined] +from dimos.types.timestamped import TimestampedCollection + +CONFIG = TypeVar("CONFIG") + + +# generic configuration for transform service +@dataclass +class TFConfig: + buffer_size: float = 10.0 # seconds + rate_limit: float = 10.0 # Hz + + +# generic specification for transform service +class TFSpec(Service[TFConfig]): + def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] + super().__init__(**kwargs) + + @abstractmethod + def publish(self, *args: Transform) -> None: ... + + @abstractmethod + def publish_static(self, *args: Transform) -> None: ... + + def get_frames(self) -> set[str]: + return set() + + @abstractmethod + def get( # type: ignore[no-untyped-def] + self, + parent_frame: str, + child_frame: str, + time_point: float | None = None, + time_tolerance: float | None = None, + ): ... + + def receive_transform(self, *args: Transform) -> None: ... + + def receive_tfmessage(self, msg: TFMessage) -> None: + for transform in msg.transforms: + self.receive_transform(transform) + + +MsgT = TypeVar("MsgT") +TopicT = TypeVar("TopicT") + + +# stores a single transform +class TBuffer(TimestampedCollection[Transform]): + def __init__(self, buffer_size: float = 10.0) -> None: + super().__init__() + self.buffer_size = buffer_size + + def add(self, transform: Transform) -> None: + super().add(transform) + self._prune_old_transforms(transform.ts) + + def _prune_old_transforms(self, current_time) -> None: # type: ignore[no-untyped-def] + if not self._items: + return + + cutoff_time = current_time - self.buffer_size + + while self._items and self._items[0].ts < cutoff_time: + self._items.pop(0) + + def get(self, time_point: float | None = None, time_tolerance: float = 1.0) -> Transform | None: + """Get transform at specified time or latest if no time given.""" + if time_point is None: + # Return the latest transform + return self[-1] if len(self) > 0 else None + + return self.find_closest(time_point, time_tolerance) + + def __str__(self) -> str: + if not self._items: + return "TBuffer(empty)" + + # Get unique frame info from the transforms + frame_pairs = set() + if self._items: + frame_pairs.add((self._items[0].frame_id, self._items[0].child_frame_id)) + + time_range = self.time_range() + if time_range: + from dimos.types.timestamped import to_human_readable + + start_time = to_human_readable(time_range[0]) + end_time = to_human_readable(time_range[1]) + duration = time_range[1] - time_range[0] + + frame_str = ( + f"{self._items[0].frame_id} -> {self._items[0].child_frame_id}" + if self._items + else "unknown" + ) + + return ( + f"TBuffer(" + f"{frame_str}, " + f"{len(self._items)} msgs, " + f"{duration:.2f}s [{start_time} - {end_time}])" + ) + + return f"TBuffer({len(self._items)} msgs)" + + +# stores multiple transform buffers +# creates a new buffer on demand when new transform is detected +class MultiTBuffer: + def __init__(self, buffer_size: float = 10.0) -> None: + self.buffers: dict[tuple[str, str], TBuffer] = {} + self.buffer_size = buffer_size + + def receive_transform(self, *args: Transform) -> None: + for transform in args: + key = (transform.frame_id, transform.child_frame_id) + if key not in self.buffers: + self.buffers[key] = TBuffer(self.buffer_size) + self.buffers[key].add(transform) + + def get_frames(self) -> set[str]: + frames = set() + for parent, child in self.buffers: + frames.add(parent) + frames.add(child) + return frames + + def get_connections(self, frame_id: str) -> set[str]: + """Get all frames connected to the given frame (both as parent and child).""" + connections = set() + for parent, child in self.buffers: + if parent == frame_id: + connections.add(child) + if child == frame_id: + connections.add(parent) + return connections + + def get_transform( + self, + parent_frame: str, + child_frame: str, + time_point: float | None = None, + time_tolerance: float | None = None, + ) -> Transform | None: + # Check forward direction + key = (parent_frame, child_frame) + if key in self.buffers: + return self.buffers[key].get(time_point, time_tolerance) # type: ignore[arg-type] + + # Check reverse direction and return inverse + reverse_key = (child_frame, parent_frame) + if reverse_key in self.buffers: + transform = self.buffers[reverse_key].get(time_point, time_tolerance) # type: ignore[arg-type] + return transform.inverse() if transform else None + + return None + + def get(self, *args, **kwargs) -> Transform | None: # type: ignore[no-untyped-def] + simple = self.get_transform(*args, **kwargs) + if simple is not None: + return simple + + complex = self.get_transform_search(*args, **kwargs) + + if complex is None: + return None + + return reduce(lambda t1, t2: t1 + t2, complex) + + def get_transform_search( + self, + parent_frame: str, + child_frame: str, + time_point: float | None = None, + time_tolerance: float | None = None, + ) -> list[Transform] | None: + """Search for shortest transform chain between parent and child frames using BFS.""" + # Check if direct transform exists (already checked in get_transform, but for clarity) + direct = self.get_transform(parent_frame, child_frame, time_point, time_tolerance) + if direct is not None: + return [direct] + + # BFS to find shortest path + queue: deque[tuple[str, list[Transform]]] = deque([(parent_frame, [])]) + visited = {parent_frame} + + while queue: + current_frame, path = queue.popleft() + + if current_frame == child_frame: + return path + + # Get all connections for current frame + connections = self.get_connections(current_frame) + + for next_frame in connections: + if next_frame not in visited: + visited.add(next_frame) + + # Get the transform between current and next frame + transform = self.get_transform( + current_frame, next_frame, time_point, time_tolerance + ) + if transform: + queue.append((next_frame, [*path, transform])) + + return None + + def graph(self) -> str: + import subprocess + + def connection_str(connection: tuple[str, str]) -> str: + (frame_from, frame_to) = connection + return f"{frame_from} -> {frame_to}" + + graph_str = "\n".join(map(connection_str, self.buffers.keys())) + + try: + result = subprocess.run( + ["diagon", "GraphDAG", "-style=Unicode"], + input=graph_str, + capture_output=True, + text=True, + ) + return result.stdout if result.returncode == 0 else graph_str + except Exception: + return "no diagon installed" + + def __str__(self) -> str: + if not self.buffers: + return f"{self.__class__.__name__}(empty)" + + lines = [f"{self.__class__.__name__}({len(self.buffers)} buffers):"] + for buffer in self.buffers.values(): + lines.append(f" {buffer}") + + return "\n".join(lines) + + +@dataclass +class PubSubTFConfig(TFConfig): + topic: Topic | None = None # Required field but needs default for dataclass inheritance + pubsub: type[PubSub] | PubSub | None = None # type: ignore[type-arg] + autostart: bool = True + + +class PubSubTF(MultiTBuffer, TFSpec): + default_config: type[PubSubTFConfig] = PubSubTFConfig + + def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] + TFSpec.__init__(self, **kwargs) + MultiTBuffer.__init__(self, self.config.buffer_size) + + pubsub_config = getattr(self.config, "pubsub", None) + if pubsub_config is not None: + if callable(pubsub_config): + self.pubsub = pubsub_config() + else: + self.pubsub = pubsub_config + else: + raise ValueError("PubSub configuration is missing") + + if self.config.autostart: # type: ignore[attr-defined] + self.start() + + def start(self, sub: bool = True) -> None: + self.pubsub.start() + if sub: + topic = getattr(self.config, "topic", None) + if topic: + self.pubsub.subscribe(topic, self.receive_msg) + + def stop(self) -> None: + self.pubsub.stop() + + def publish(self, *args: Transform) -> None: + """Send transforms using the configured PubSub.""" + if not self.pubsub: + raise ValueError("PubSub is not configured.") + + self.receive_transform(*args) + topic = getattr(self.config, "topic", None) + if topic: + self.pubsub.publish(topic, TFMessage(*args)) + + def publish_static(self, *args: Transform) -> None: + raise NotImplementedError("Static transforms not implemented in PubSubTF.") + + def publish_all(self) -> None: + """Publish all transforms currently stored in all buffers.""" + all_transforms = [] + for buffer in self.buffers.values(): + # Get the latest transform from each buffer + latest = buffer.get() # get() with no args returns latest + if latest: + all_transforms.append(latest) + + if all_transforms: + self.publish(*all_transforms) + + def get( + self, + parent_frame: str, + child_frame: str, + time_point: float | None = None, + time_tolerance: float | None = None, + ) -> Transform | None: + return super().get(parent_frame, child_frame, time_point, time_tolerance) + + def receive_msg(self, msg: TFMessage, topic: Topic) -> None: + self.receive_tfmessage(msg) + + +@dataclass +class LCMPubsubConfig(PubSubTFConfig): + topic: Topic = field(default_factory=lambda: Topic("/tf", TFMessage)) + pubsub: type[PubSub] | PubSub | None = LCM # type: ignore[type-arg] + autostart: bool = True + + +class LCMTF(PubSubTF): + default_config: type[LCMPubsubConfig] = LCMPubsubConfig + + +TF = LCMTF diff --git a/dimos/protocol/tf/tflcmcpp.py b/dimos/protocol/tf/tflcmcpp.py new file mode 100644 index 0000000000..158a68d3d8 --- /dev/null +++ b/dimos/protocol/tf/tflcmcpp.py @@ -0,0 +1,93 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import datetime +from typing import Union + +from dimos.msgs.geometry_msgs import Transform +from dimos.protocol.service.lcmservice import LCMConfig, LCMService +from dimos.protocol.tf.tf import TFConfig, TFSpec + + +# this doesn't work due to tf_lcm_py package +class TFLCM(TFSpec, LCMService): + """A service for managing and broadcasting transforms using LCM. + This is not a separete module, You can include this in your module + if you need to access transforms. + + Ideally we would have a generic pubsub for transforms so we are + transport agnostic (TODO) + + For now we are not doing this because we want to use cpp buffer/lcm + implementation. We also don't want to manually hook up tf stream + for each module. + """ + + default_config = Union[TFConfig, LCMConfig] + + def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] + super().__init__(**kwargs) + + import tf_lcm_py as tf # type: ignore[import-not-found] + + self.l = tf.LCM() + self.buffer = tf.Buffer(self.config.buffer_size) + self.listener = tf.TransformListener(self.l, self.buffer) + self.broadcaster = tf.TransformBroadcaster() + self.static_broadcaster = tf.StaticTransformBroadcaster() + + # will call the underlying LCMService.start + self.start() + + def send(self, *args: Transform) -> None: + for t in args: + self.broadcaster.send_transform(t.lcm_transform()) + + def send_static(self, *args: Transform) -> None: + for t in args: + self.static_broadcaster.send_static_transform(t) + + def lookup( # type: ignore[no-untyped-def] + self, + parent_frame: str, + child_frame: str, + time_point: float | None = None, + time_tolerance: float | None = None, + ): + return self.buffer.lookup_transform( + parent_frame, + child_frame, + datetime.now(), + lcm_module=self.l, + ) + + def can_transform( + self, parent_frame: str, child_frame: str, time_point: float | datetime | None = None + ) -> bool: + if not time_point: + time_point = datetime.now() + + if isinstance(time_point, float): + time_point = datetime.fromtimestamp(time_point) + + return self.buffer.can_transform(parent_frame, child_frame, time_point) # type: ignore[no-any-return] + + def get_frames(self) -> set[str]: + return set(self.buffer.get_all_frame_names()) + + def start(self) -> None: + super().start() + ... + + def stop(self) -> None: ... diff --git a/dimos/manipulation/sensors_calibration_alignment.py b/dimos/robot/__init__.py similarity index 100% rename from dimos/manipulation/sensors_calibration_alignment.py rename to dimos/robot/__init__.py diff --git a/dimos/robot/agilex/README.md b/dimos/robot/agilex/README.md new file mode 100644 index 0000000000..5d43fa3c3f --- /dev/null +++ b/dimos/robot/agilex/README.md @@ -0,0 +1,371 @@ +# DIMOS Manipulator Robot Development Guide + +This guide explains how to create robot classes, integrate agents, and use the DIMOS module system with LCM transport. + +## Table of Contents +1. [Robot Class Architecture](#robot-class-architecture) +2. [Module System & LCM Transport](#module-system--lcm-transport) +3. [Agent Integration](#agent-integration) +4. [Complete Example](#complete-example) + +## Robot Class Architecture + +### Basic Robot Class Structure + +A DIMOS robot class should follow this pattern: + +```python +from typing import Optional, List +from dimos import core +from dimos.types.robot_capabilities import RobotCapability + +class YourRobot: + """Your robot implementation.""" + + def __init__(self, robot_capabilities: Optional[List[RobotCapability]] = None): + # Core components + self.dimos = None + self.modules = {} + self.skill_library = SkillLibrary() + + # Define capabilities + self.capabilities = robot_capabilities or [ + RobotCapability.VISION, + RobotCapability.MANIPULATION, + ] + + async def start(self): + """Start the robot modules.""" + # Initialize DIMOS with worker count + self.dimos = core.start(2) # Number of workers needed + + # Deploy modules + # ... (see Module System section) + + def stop(self): + """Stop all modules and clean up.""" + # Stop modules + # Close DIMOS + if self.dimos: + self.dimos.close() +``` + +### Key Components Explained + +1. **Initialization**: Store references to modules, skills, and capabilities +2. **Async Start**: Modules must be deployed asynchronously +3. **Proper Cleanup**: Always stop modules before closing DIMOS + +## Module System & LCM Transport + +### Understanding DIMOS Modules + +Modules are the building blocks of DIMOS robots. They: +- Process data streams (inputs) +- Produce outputs +- Can be connected together +- Communicate via LCM (Lightweight Communications and Marshalling) + +### Deploying a Module + +```python +# Deploy a camera module +self.camera = self.dimos.deploy( + ZEDModule, # Module class + camera_id=0, # Module parameters + resolution="HD720", + depth_mode="NEURAL", + fps=30, + publish_rate=30.0, + frame_id="camera_frame" +) +``` + +### Setting Up LCM Transport + +LCM transport enables inter-module communication: + +```python +# Enable LCM auto-configuration +from dimos.protocol import pubsub +pubsub.lcm.autoconf() + +# Configure output transport +self.camera.color_image.transport = core.LCMTransport( + "/camera/color_image", # Topic name + Image # Message type +) +self.camera.depth_image.transport = core.LCMTransport( + "/camera/depth_image", + Image +) +``` + +### Connecting Modules + +Connect module outputs to inputs: + +```python +# Connect manipulation module to camera outputs +self.manipulation.rgb_image.connect(self.camera.color_image) +self.manipulation.depth_image.connect(self.camera.depth_image) +self.manipulation.camera_info.connect(self.camera.camera_info) +``` + +### Module Communication Pattern + +``` +┌──────────────┐ LCM ┌────────────────┐ LCM ┌──────────────┐ +│ Camera │────────▶│ Manipulation │────────▶│ Visualization│ +│ Module │ Messages│ Module │ Messages│ Output │ +└──────────────┘ └────────────────┘ └──────────────┘ + ▲ ▲ + │ │ + └──────────────────────────┘ + Direct Connection via RPC call +``` + +## Agent Integration + +### Setting Up Agent with Robot + +The run file pattern for agent integration: + +```python +#!/usr/bin/env python3 +import asyncio +import reactivex as rx +from dimos.agents.claude_agent import ClaudeAgent +from dimos.web.robot_web_interface import RobotWebInterface + +def main(): + # 1. Create and start robot + robot = YourRobot() + asyncio.run(robot.start()) + + # 2. Set up skills + skills = robot.get_skills() + skills.add(YourSkill) + skills.create_instance("YourSkill", robot=robot) + + # 3. Set up reactive streams + agent_response_subject = rx.subject.Subject() + agent_response_stream = agent_response_subject.pipe(ops.share()) + + # 4. Create web interface + web_interface = RobotWebInterface( + port=5555, + text_streams={"agent_responses": agent_response_stream}, + audio_subject=rx.subject.Subject() + ) + + # 5. Create agent + agent = ClaudeAgent( + dev_name="your_agent", + input_query_stream=web_interface.query_stream, + skills=skills, + system_query="Your system prompt here", + model_name="claude-3-5-haiku-latest" + ) + + # 6. Connect agent responses + agent.get_response_observable().subscribe( + lambda x: agent_response_subject.on_next(x) + ) + + # 7. Run interface + web_interface.run() +``` + +### Key Integration Points + +1. **Reactive Streams**: Use RxPy for event-driven communication +2. **Web Interface**: Provides user input/output +3. **Agent**: Processes natural language and executes skills +4. **Skills**: Define robot capabilities as executable actions + +## Complete Example + +### Step 1: Create Robot Class (`my_robot.py`) + +```python +import asyncio +from typing import Optional, List +from dimos import core +from dimos.hardware.camera import CameraModule +from dimos.manipulation.module import ManipulationModule +from dimos.skills.skills import SkillLibrary +from dimos.types.robot_capabilities import RobotCapability +from dimos_lcm.sensor_msgs import Image, CameraInfo +from dimos.protocol import pubsub + +class MyRobot: + def __init__(self, robot_capabilities: Optional[List[RobotCapability]] = None): + self.dimos = None + self.camera = None + self.manipulation = None + self.skill_library = SkillLibrary() + + self.capabilities = robot_capabilities or [ + RobotCapability.VISION, + RobotCapability.MANIPULATION, + ] + + async def start(self): + # Start DIMOS + self.dimos = core.start(2) + + # Enable LCM + pubsub.lcm.autoconf() + + # Deploy camera + self.camera = self.dimos.deploy( + CameraModule, + camera_id=0, + fps=30 + ) + + # Configure camera LCM + self.camera.color_image.transport = core.LCMTransport("/camera/rgb", Image) + self.camera.depth_image.transport = core.LCMTransport("/camera/depth", Image) + self.camera.camera_info.transport = core.LCMTransport("/camera/info", CameraInfo) + + # Deploy manipulation + self.manipulation = self.dimos.deploy(ManipulationModule) + + # Connect modules + self.manipulation.rgb_image.connect(self.camera.color_image) + self.manipulation.depth_image.connect(self.camera.depth_image) + self.manipulation.camera_info.connect(self.camera.camera_info) + + # Configure manipulation output + self.manipulation.viz_image.transport = core.LCMTransport("/viz/output", Image) + + # Start modules + self.camera.start() + self.manipulation.start() + + await asyncio.sleep(2) # Allow initialization + + def get_skills(self): + return self.skill_library + + def stop(self): + if self.manipulation: + self.manipulation.stop() + if self.camera: + self.camera.stop() + if self.dimos: + self.dimos.close() +``` + +### Step 2: Create Run Script (`run.py`) + +```python +#!/usr/bin/env python3 +import asyncio +import os +from my_robot import MyRobot +from dimos.agents.claude_agent import ClaudeAgent +from dimos.skills.basic import BasicSkill +from dimos.web.robot_web_interface import RobotWebInterface +import reactivex as rx +import reactivex.operators as ops + +SYSTEM_PROMPT = """You are a helpful robot assistant.""" + +def main(): + # Check API key + if not os.getenv("ANTHROPIC_API_KEY"): + print("Please set ANTHROPIC_API_KEY") + return + + # Create robot + robot = MyRobot() + + try: + # Start robot + asyncio.run(robot.start()) + + # Set up skills + skills = robot.get_skills() + skills.add(BasicSkill) + skills.create_instance("BasicSkill", robot=robot) + + # Set up streams + agent_response_subject = rx.subject.Subject() + agent_response_stream = agent_response_subject.pipe(ops.share()) + + # Create web interface + web_interface = RobotWebInterface( + port=5555, + text_streams={"agent_responses": agent_response_stream} + ) + + # Create agent + agent = ClaudeAgent( + dev_name="my_agent", + input_query_stream=web_interface.query_stream, + skills=skills, + system_query=SYSTEM_PROMPT + ) + + # Connect responses + agent.get_response_observable().subscribe( + lambda x: agent_response_subject.on_next(x) + ) + + print("Robot ready at http://localhost:5555") + + # Run + web_interface.run() + + finally: + robot.stop() + +if __name__ == "__main__": + main() +``` + +### Step 3: Define Skills (`skills.py`) + +```python +from dimos.skills import Skill, skill + +@skill( + description="Perform a basic action", + parameters={ + "action": "The action to perform" + } +) +class BasicSkill(Skill): + def __init__(self, robot): + self.robot = robot + + def run(self, action: str): + # Implement skill logic + return f"Performed: {action}" +``` + +## Best Practices + +1. **Module Lifecycle**: Always start DIMOS before deploying modules +2. **LCM Topics**: Use descriptive topic names with namespaces +3. **Error Handling**: Wrap module operations in try-except blocks +4. **Resource Cleanup**: Ensure proper cleanup in stop() methods +5. **Async Operations**: Use asyncio for non-blocking operations +6. **Stream Management**: Use RxPy for reactive programming patterns + +## Debugging Tips + +1. **Check Module Status**: Print module.io().result() to see connections +2. **Monitor LCM**: Use Foxglove to visualize LCM messages +3. **Log Everything**: Use dimos.utils.logging_config.setup_logger() +4. **Test Modules Independently**: Deploy and test one module at a time + +## Common Issues + +1. **"Module not started"**: Ensure start() is called after deployment +2. **"No data received"**: Check LCM transport configuration +3. **"Connection failed"**: Verify input/output types match +4. **"Cleanup errors"**: Stop modules before closing DIMOS diff --git a/dimos/robot/agilex/README_CN.md b/dimos/robot/agilex/README_CN.md new file mode 100644 index 0000000000..909a309ce9 --- /dev/null +++ b/dimos/robot/agilex/README_CN.md @@ -0,0 +1,465 @@ +# DIMOS 机械臂机器人开发指南 + +本指南介绍如何创建机器人类、集成智能体(Agent)以及使用 DIMOS 模块系统和 LCM 传输。 + +## 目录 +1. [机器人类架构](#机器人类架构) +2. [模块系统与 LCM 传输](#模块系统与-lcm-传输) +3. [智能体集成](#智能体集成) +4. [完整示例](#完整示例) + +## 机器人类架构 + +### 基本机器人类结构 + +DIMOS 机器人类应遵循以下模式: + +```python +from typing import Optional, List +from dimos import core +from dimos.types.robot_capabilities import RobotCapability + +class YourRobot: + """您的机器人实现。""" + + def __init__(self, robot_capabilities: Optional[List[RobotCapability]] = None): + # 核心组件 + self.dimos = None + self.modules = {} + self.skill_library = SkillLibrary() + + # 定义能力 + self.capabilities = robot_capabilities or [ + RobotCapability.VISION, + RobotCapability.MANIPULATION, + ] + + async def start(self): + """启动机器人模块。""" + # 初始化 DIMOS,指定工作线程数 + self.dimos = core.start(2) # 需要的工作线程数 + + # 部署模块 + # ... (参见模块系统章节) + + def stop(self): + """停止所有模块并清理资源。""" + # 停止模块 + # 关闭 DIMOS + if self.dimos: + self.dimos.close() +``` + +### 关键组件说明 + +1. **初始化**:存储模块、技能和能力的引用 +2. **异步启动**:模块必须异步部署 +3. **正确清理**:在关闭 DIMOS 之前始终停止模块 + +## 模块系统与 LCM 传输 + +### 理解 DIMOS 模块 + +模块是 DIMOS 机器人的构建块。它们: +- 处理数据流(输入) +- 产生输出 +- 可以相互连接 +- 通过 LCM(轻量级通信和编组)进行通信 + +### 部署模块 + +```python +# 部署相机模块 +self.camera = self.dimos.deploy( + ZEDModule, # 模块类 + camera_id=0, # 模块参数 + resolution="HD720", + depth_mode="NEURAL", + fps=30, + publish_rate=30.0, + frame_id="camera_frame" +) +``` + +### 设置 LCM 传输 + +LCM 传输实现模块间通信: + +```python +# 启用 LCM 自动配置 +from dimos.protocol import pubsub +pubsub.lcm.autoconf() + +# 配置输出传输 +self.camera.color_image.transport = core.LCMTransport( + "/camera/color_image", # 主题名称 + Image # 消息类型 +) +self.camera.depth_image.transport = core.LCMTransport( + "/camera/depth_image", + Image +) +``` + +### 连接模块 + +将模块输出连接到输入: + +```python +# 将操作模块连接到相机输出 +self.manipulation.rgb_image.connect(self.camera.color_image) # ROS set_callback +self.manipulation.depth_image.connect(self.camera.depth_image) +self.manipulation.camera_info.connect(self.camera.camera_info) +``` + +### 模块通信模式 + +``` +┌──────────────┐ LCM ┌────────────────┐ LCM ┌──────────────┐ +│ 相机模块 │────────▶│ 操作模块 │────────▶│ 可视化输出 │ +│ │ 消息 │ │ 消息 │ │ +└──────────────┘ └────────────────┘ └──────────────┘ + ▲ ▲ + │ │ + └──────────────────────────┘ + 直接连接(RPC指令) +``` + +## 智能体集成 + +### 设置智能体与机器人 + +运行文件的智能体集成模式: + +```python +#!/usr/bin/env python3 +import asyncio +import reactivex as rx +from dimos.agents.claude_agent import ClaudeAgent +from dimos.web.robot_web_interface import RobotWebInterface + +def main(): + # 1. 创建并启动机器人 + robot = YourRobot() + asyncio.run(robot.start()) + + # 2. 设置技能 + skills = robot.get_skills() + skills.add(YourSkill) + skills.create_instance("YourSkill", robot=robot) + + # 3. 设置响应式流 + agent_response_subject = rx.subject.Subject() + agent_response_stream = agent_response_subject.pipe(ops.share()) + + # 4. 创建 Web 界面 + web_interface = RobotWebInterface( + port=5555, + text_streams={"agent_responses": agent_response_stream}, + audio_subject=rx.subject.Subject() + ) + + # 5. 创建智能体 + agent = ClaudeAgent( + dev_name="your_agent", + input_query_stream=web_interface.query_stream, + skills=skills, + system_query="您的系统提示词", + model_name="claude-3-5-haiku-latest" + ) + + # 6. 连接智能体响应 + agent.get_response_observable().subscribe( + lambda x: agent_response_subject.on_next(x) + ) + + # 7. 运行界面 + web_interface.run() +``` + +### 关键集成点 + +1. **响应式流**:使用 RxPy 进行事件驱动通信 +2. **Web 界面**:提供用户输入/输出 +3. **智能体**:处理自然语言并执行技能 +4. **技能**:将机器人能力定义为可执行动作 + +## 完整示例 + +### 步骤 1:创建机器人类(`my_robot.py`) + +```python +import asyncio +from typing import Optional, List +from dimos import core +from dimos.hardware.camera import CameraModule +from dimos.manipulation.module import ManipulationModule +from dimos.skills.skills import SkillLibrary +from dimos.types.robot_capabilities import RobotCapability +from dimos_lcm.sensor_msgs import Image, CameraInfo +from dimos.protocol import pubsub + +class MyRobot: + def __init__(self, robot_capabilities: Optional[List[RobotCapability]] = None): + self.dimos = None + self.camera = None + self.manipulation = None + self.skill_library = SkillLibrary() + + self.capabilities = robot_capabilities or [ + RobotCapability.VISION, + RobotCapability.MANIPULATION, + ] + + async def start(self): + # 启动 DIMOS + self.dimos = core.start(2) + + # 启用 LCM + pubsub.lcm.autoconf() + + # 部署相机 + self.camera = self.dimos.deploy( + CameraModule, + camera_id=0, + fps=30 + ) + + # 配置相机 LCM + self.camera.color_image.transport = core.LCMTransport("/camera/rgb", Image) + self.camera.depth_image.transport = core.LCMTransport("/camera/depth", Image) + self.camera.camera_info.transport = core.LCMTransport("/camera/info", CameraInfo) + + # 部署操作模块 + self.manipulation = self.dimos.deploy(ManipulationModule) + + # 连接模块 + self.manipulation.rgb_image.connect(self.camera.color_image) + self.manipulation.depth_image.connect(self.camera.depth_image) + self.manipulation.camera_info.connect(self.camera.camera_info) + + # 配置操作输出 + self.manipulation.viz_image.transport = core.LCMTransport("/viz/output", Image) + + # 启动模块 + self.camera.start() + self.manipulation.start() + + await asyncio.sleep(2) # 允许初始化 + + def get_skills(self): + return self.skill_library + + def stop(self): + if self.manipulation: + self.manipulation.stop() + if self.camera: + self.camera.stop() + if self.dimos: + self.dimos.close() +``` + +### 步骤 2:创建运行脚本(`run.py`) + +```python +#!/usr/bin/env python3 +import asyncio +import os +from my_robot import MyRobot +from dimos.agents.claude_agent import ClaudeAgent +from dimos.skills.basic import BasicSkill +from dimos.web.robot_web_interface import RobotWebInterface +import reactivex as rx +import reactivex.operators as ops + +SYSTEM_PROMPT = """您是一个有用的机器人助手。""" + +def main(): + # 检查 API 密钥 + if not os.getenv("ANTHROPIC_API_KEY"): + print("请设置 ANTHROPIC_API_KEY") + return + + # 创建机器人 + robot = MyRobot() + + try: + # 启动机器人 + asyncio.run(robot.start()) + + # 设置技能 + skills = robot.get_skills() + skills.add(BasicSkill) + skills.create_instance("BasicSkill", robot=robot) + + # 设置流 + agent_response_subject = rx.subject.Subject() + agent_response_stream = agent_response_subject.pipe(ops.share()) + + # 创建 Web 界面 + web_interface = RobotWebInterface( + port=5555, + text_streams={"agent_responses": agent_response_stream} + ) + + # 创建智能体 + agent = ClaudeAgent( + dev_name="my_agent", + input_query_stream=web_interface.query_stream, + skills=skills, + system_query=SYSTEM_PROMPT + ) + + # 连接响应 + agent.get_response_observable().subscribe( + lambda x: agent_response_subject.on_next(x) + ) + + print("机器人就绪,访问 http://localhost:5555") + + # 运行 + web_interface.run() + + finally: + robot.stop() + +if __name__ == "__main__": + main() +``` + +### 步骤 3:定义技能(`skills.py`) + +```python +from dimos.skills import Skill, skill + +@skill( + description="执行一个基本动作", + parameters={ + "action": "要执行的动作" + } +) +class BasicSkill(Skill): + def __init__(self, robot): + self.robot = robot + + def run(self, action: str): + # 实现技能逻辑 + return f"已执行:{action}" +``` + +## 最佳实践 + +1. **模块生命周期**:在部署模块之前始终先启动 DIMOS +2. **LCM 主题**:使用带命名空间的描述性主题名称 +3. **错误处理**:用 try-except 块包装模块操作 +4. **资源清理**:确保在 stop() 方法中正确清理 +5. **异步操作**:使用 asyncio 进行非阻塞操作 +6. **流管理**:使用 RxPy 进行响应式编程模式 + +## 调试技巧 + +1. **检查模块状态**:打印 module.io().result() 查看连接 +2. **监控 LCM**:使用 Foxglove 可视化 LCM 消息 +3. **记录一切**:使用 dimos.utils.logging_config.setup_logger() +4. **独立测试模块**:一次部署和测试一个模块 + +## 常见问题 + +1. **"模块未启动"**:确保在部署后调用 start() +2. **"未收到数据"**:检查 LCM 传输配置 +3. **"连接失败"**:验证输入/输出类型是否匹配 +4. **"清理错误"**:在关闭 DIMOS 之前停止模块 + +## 高级主题 + +### 自定义模块开发 + +创建自定义模块的基本结构: + +```python +from dimos.core import Module, In, Out, rpc + +class CustomModule(Module): + # 定义输入 + input_data: In[DataType] = None + + # 定义输出 + output_data: Out[DataType] = None + + def __init__(self, param1, param2, **kwargs): + super().__init__(**kwargs) + self.param1 = param1 + self.param2 = param2 + + @rpc + def start(self): + """启动模块处理。""" + self.input_data.subscribe(self._process_data) + + def _process_data(self, data): + """处理输入数据。""" + # 处理逻辑 + result = self.process(data) + # 发布输出 + self.output_data.publish(result) + + @rpc + def stop(self): + """停止模块。""" + # 清理资源 + pass +``` + +### 技能开发指南 + +技能是机器人可执行的高级动作: + +```python +from dimos.skills import Skill, skill +from typing import Optional + +@skill( + description="复杂操作技能", + parameters={ + "target": "目标对象", + "location": "目标位置" + } +) +class ComplexSkill(Skill): + def __init__(self, robot, **kwargs): + super().__init__(**kwargs) + self.robot = robot + + def run(self, target: str, location: Optional[str] = None): + """执行技能逻辑。""" + try: + # 1. 感知阶段 + object_info = self.robot.detect_object(target) + + # 2. 规划阶段 + if location: + plan = self.robot.plan_movement(object_info, location) + + # 3. 执行阶段 + result = self.robot.execute_plan(plan) + + return { + "success": True, + "message": f"成功移动 {target} 到 {location}" + } + + except Exception as e: + return { + "success": False, + "error": str(e) + } +``` + +### 性能优化 + +1. **并行处理**:使用多个工作线程处理不同模块 +2. **数据缓冲**:为高频数据流实现缓冲机制 +3. **延迟加载**:仅在需要时初始化重型模块 +4. **资源池化**:重用昂贵的资源(如神经网络模型) + +希望本指南能帮助您快速上手 DIMOS 机器人开发! diff --git a/dimos/robot/agilex/piper_arm.py b/dimos/robot/agilex/piper_arm.py new file mode 100644 index 0000000000..75187678b0 --- /dev/null +++ b/dimos/robot/agilex/piper_arm.py @@ -0,0 +1,181 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio + +# Import LCM message types +from dimos_lcm.sensor_msgs import CameraInfo # type: ignore[import-untyped] + +from dimos import core +from dimos.hardware.camera.zed import ZEDModule +from dimos.manipulation.visual_servoing.manipulation_module import ManipulationModule +from dimos.msgs.sensor_msgs import Image +from dimos.protocol import pubsub +from dimos.robot.foxglove_bridge import FoxgloveBridge +from dimos.robot.robot import Robot +from dimos.skills.skills import SkillLibrary +from dimos.types.robot_capabilities import RobotCapability +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class PiperArmRobot(Robot): + """Piper Arm robot with ZED camera and manipulation capabilities.""" + + def __init__(self, robot_capabilities: list[RobotCapability] | None = None) -> None: + super().__init__() + self.dimos = None + self.stereo_camera = None + self.manipulation_interface = None + self.skill_library = SkillLibrary() # type: ignore[assignment] + + # Initialize capabilities + self.capabilities = robot_capabilities or [ + RobotCapability.VISION, + RobotCapability.MANIPULATION, + ] + + async def start(self) -> None: + """Start the robot modules.""" + # Start Dimos + self.dimos = core.start(2) # type: ignore[assignment] # Need 2 workers for ZED and manipulation modules + self.foxglove_bridge = FoxgloveBridge() + + # Enable LCM auto-configuration + pubsub.lcm.autoconf() # type: ignore[attr-defined] + + # Deploy ZED module + logger.info("Deploying ZED module...") + self.stereo_camera = self.dimos.deploy( # type: ignore[attr-defined] + ZEDModule, + camera_id=0, + resolution="HD720", + depth_mode="NEURAL", + fps=30, + enable_tracking=False, # We don't need tracking for manipulation + publish_rate=30.0, + frame_id="zed_camera", + ) + + # Configure ZED LCM transports + self.stereo_camera.color_image.transport = core.LCMTransport("/zed/color_image", Image) # type: ignore[attr-defined] + self.stereo_camera.depth_image.transport = core.LCMTransport("/zed/depth_image", Image) # type: ignore[attr-defined] + self.stereo_camera.camera_info.transport = core.LCMTransport("/zed/camera_info", CameraInfo) # type: ignore[attr-defined] + + # Deploy manipulation module + logger.info("Deploying manipulation module...") + self.manipulation_interface = self.dimos.deploy(ManipulationModule) # type: ignore[attr-defined] + + # Connect manipulation inputs to ZED outputs + self.manipulation_interface.rgb_image.connect(self.stereo_camera.color_image) # type: ignore[attr-defined] + self.manipulation_interface.depth_image.connect(self.stereo_camera.depth_image) # type: ignore[attr-defined] + self.manipulation_interface.camera_info.connect(self.stereo_camera.camera_info) # type: ignore[attr-defined] + + # Configure manipulation output + self.manipulation_interface.viz_image.transport = core.LCMTransport( # type: ignore[attr-defined] + "/manipulation/viz", Image + ) + + # Print module info + logger.info("Modules configured:") + print("\nZED Module:") + print(self.stereo_camera.io()) # type: ignore[attr-defined] + print("\nManipulation Module:") + print(self.manipulation_interface.io()) # type: ignore[attr-defined] + + # Start modules + logger.info("Starting modules...") + self.foxglove_bridge.start() + self.stereo_camera.start() # type: ignore[attr-defined] + self.manipulation_interface.start() # type: ignore[attr-defined] + + # Give modules time to initialize + await asyncio.sleep(2) + + logger.info("PiperArmRobot initialized and started") + + def pick_and_place( # type: ignore[no-untyped-def] + self, pick_x: int, pick_y: int, place_x: int | None = None, place_y: int | None = None + ): + """Execute pick and place task. + + Args: + pick_x: X coordinate for pick location + pick_y: Y coordinate for pick location + place_x: X coordinate for place location (optional) + place_y: Y coordinate for place location (optional) + + Returns: + Result of the pick and place operation + """ + if self.manipulation_interface: + return self.manipulation_interface.pick_and_place(pick_x, pick_y, place_x, place_y) + else: + logger.error("Manipulation module not initialized") + return False + + def handle_keyboard_command(self, key: str): # type: ignore[no-untyped-def] + """Pass keyboard commands to manipulation module. + + Args: + key: Keyboard key pressed + + Returns: + Action taken or None + """ + if self.manipulation_interface: + return self.manipulation_interface.handle_keyboard_command(key) + else: + logger.error("Manipulation module not initialized") + return None + + def stop(self) -> None: + """Stop all modules and clean up.""" + logger.info("Stopping PiperArmRobot...") + + try: + if self.manipulation_interface: + self.manipulation_interface.stop() + + if self.stereo_camera: + self.stereo_camera.stop() + except Exception as e: + logger.warning(f"Error stopping modules: {e}") + + # Close dimos last to ensure workers are available for cleanup + if self.dimos: + self.dimos.close() + + logger.info("PiperArmRobot stopped") + + +async def run_piper_arm() -> None: + """Run the Piper Arm robot.""" + robot = PiperArmRobot() # type: ignore[abstract] + + await robot.start() + + # Keep the robot running + try: + while True: + await asyncio.sleep(1) + except KeyboardInterrupt: + logger.info("Keyboard interrupt received") + finally: + await robot.stop() # type: ignore[func-returns-value] + + +if __name__ == "__main__": + asyncio.run(run_piper_arm()) diff --git a/dimos/robot/agilex/run.py b/dimos/robot/agilex/run.py new file mode 100644 index 0000000000..f91f4ed440 --- /dev/null +++ b/dimos/robot/agilex/run.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Run script for Piper Arm robot with Claude agent integration. +Provides manipulation capabilities with natural language interface. +""" + +import asyncio +import os +import sys + +from dotenv import load_dotenv +import reactivex as rx +import reactivex.operators as ops + +from dimos.agents.claude_agent import ClaudeAgent +from dimos.robot.agilex.piper_arm import PiperArmRobot +from dimos.skills.kill_skill import KillSkill +from dimos.skills.manipulation.pick_and_place import PickAndPlace +from dimos.stream.audio.pipelines import stt, tts +from dimos.utils.logging_config import setup_logger +from dimos.web.robot_web_interface import RobotWebInterface + +logger = setup_logger() + +# Load environment variables +load_dotenv() + +# System prompt for the Piper Arm manipulation agent +SYSTEM_PROMPT = """You are an intelligent robotic assistant controlling a Piper Arm robot with advanced manipulation capabilities. Your primary role is to help users with pick and place tasks using natural language understanding. + +## Your Capabilities: +1. **Visual Perception**: You have access to a ZED stereo camera that provides RGB and depth information +2. **Object Manipulation**: You can pick up and place objects using a 6-DOF robotic arm with a gripper +3. **Language Understanding**: You use the Qwen vision-language model to identify objects and locations from natural language descriptions + +## Available Skills: +- **PickAndPlace**: Execute pick and place operations based on object and location descriptions + - Pick only: "Pick up the red mug" + - Pick and place: "Move the book to the shelf" +- **KillSkill**: Stop any currently running skill + +## Guidelines: +1. **Safety First**: Always ensure safe operation. If unsure about an object's graspability or a placement location's stability, ask for clarification +2. **Clear Communication**: Explain what you're doing and ask for confirmation when needed +3. **Error Handling**: If a task fails, explain why and suggest alternatives +4. **Precision**: When users give specific object descriptions, use them exactly as provided to the vision model + +## Interaction Examples: +- User: "Pick up the coffee mug" + You: "I'll pick up the coffee mug for you." [Execute PickAndPlace with object_query="coffee mug"] + +- User: "Put the toy on the table" + You: "I'll place the toy on the table." [Execute PickAndPlace with object_query="toy", target_query="on the table"] + +- User: "What do you see?" + +Remember: You're here to assist with manipulation tasks. Be helpful, precise, and always prioritize safe operation of the robot.""" + + +def main(): # type: ignore[no-untyped-def] + """Main entry point.""" + print("\n" + "=" * 60) + print("Piper Arm Robot with Claude Agent") + print("=" * 60) + print("\nThis system integrates:") + print(" - Piper Arm 6-DOF robot") + print(" - ZED stereo camera") + print(" - Claude AI for natural language understanding") + print(" - Qwen VLM for visual object detection") + print(" - Web interface with text and voice input") + print(" - Foxglove visualization via LCM") + print("\nStarting system...\n") + + # Check for API key + if not os.getenv("ANTHROPIC_API_KEY"): + print("WARNING: ANTHROPIC_API_KEY not found in environment") + print("Please set your API key in .env file or environment") + sys.exit(1) + + logger.info("Starting Piper Arm Robot with Agent") + + # Create robot instance + robot = PiperArmRobot() # type: ignore[abstract] + + try: + # Start the robot (this is async, so we need asyncio.run) + logger.info("Initializing robot...") + asyncio.run(robot.start()) + logger.info("Robot initialized successfully") + + # Set up skill library + skills = robot.get_skills() # type: ignore[no-untyped-call] + skills.add(PickAndPlace) + skills.add(KillSkill) + + # Create skill instances + skills.create_instance("PickAndPlace", robot=robot) + skills.create_instance("KillSkill", robot=robot, skill_library=skills) + + logger.info(f"Skills registered: {[skill.__name__ for skill in skills.get_class_skills()]}") + + # Set up streams for agent and web interface + agent_response_subject = rx.subject.Subject() # type: ignore[var-annotated] + agent_response_stream = agent_response_subject.pipe(ops.share()) + audio_subject = rx.subject.Subject() # type: ignore[var-annotated] + + # Set up streams for web interface + streams = {} # type: ignore[var-annotated] + + text_streams = { + "agent_responses": agent_response_stream, + } + + # Create web interface first (needed for agent) + try: + web_interface = RobotWebInterface( + port=5555, text_streams=text_streams, audio_subject=audio_subject, **streams + ) + logger.info("Web interface created successfully") + except Exception as e: + logger.error(f"Failed to create web interface: {e}") + raise + + # Set up speech-to-text + stt_node = stt() # type: ignore[no-untyped-call] + stt_node.consume_audio(audio_subject.pipe(ops.share())) + + # Create Claude agent + agent = ClaudeAgent( + dev_name="piper_arm_agent", + input_query_stream=web_interface.query_stream, # Use text input from web interface + # input_query_stream=stt_node.emit_text(), # Uncomment to use voice input + skills=skills, + system_query=SYSTEM_PROMPT, + model_name="claude-3-5-haiku-latest", + thinking_budget_tokens=0, + max_output_tokens_per_request=4096, + ) + + # Subscribe to agent responses + agent.get_response_observable().subscribe(lambda x: agent_response_subject.on_next(x)) + + # Set up text-to-speech for agent responses + tts_node = tts() # type: ignore[no-untyped-call] + tts_node.consume_text(agent.get_response_observable()) + + logger.info("=" * 60) + logger.info("Piper Arm Agent Ready!") + logger.info("Web interface available at: http://localhost:5555") + logger.info("Foxglove visualization available at: ws://localhost:8765") + logger.info("You can:") + logger.info(" - Type commands in the web interface") + logger.info(" - Use voice commands") + logger.info(" - Ask the robot to pick up objects") + logger.info(" - Ask the robot to move objects to locations") + logger.info("=" * 60) + + # Run web interface (this blocks) + web_interface.run() + + except KeyboardInterrupt: + logger.info("Keyboard interrupt received") + except Exception as e: + logger.error(f"Error running robot: {e}") + import traceback + + traceback.print_exc() + finally: + logger.info("Shutting down...") + # Stop the robot (this is also async) + robot.stop() + logger.info("Robot stopped") + + +if __name__ == "__main__": + main() # type: ignore[no-untyped-call] diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py new file mode 100644 index 0000000000..8fe9978517 --- /dev/null +++ b/dimos/robot/all_blueprints.py @@ -0,0 +1,90 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.core.blueprints import ModuleBlueprintSet + +# The blueprints are defined as import strings so as not to trigger unnecessary imports. +all_blueprints = { + "unitree-go2": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:standard", + "unitree-go2-basic": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:basic", + "unitree-go2-shm": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:standard_with_shm", + "unitree-go2-jpegshm": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:standard_with_jpegshm", + "unitree-go2-jpeglcm": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:standard_with_jpeglcm", + "unitree-go2-agentic": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:agentic", + "unitree-go2-agentic-ollama": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:agentic_ollama", + "unitree-go2-agentic-huggingface": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:agentic_huggingface", + "unitree-g1": "dimos.robot.unitree_webrtc.unitree_g1_blueprints:standard", + "unitree-g1-sim": "dimos.robot.unitree_webrtc.unitree_g1_blueprints:standard_sim", + "unitree-g1-basic": "dimos.robot.unitree_webrtc.unitree_g1_blueprints:basic_ros", + "unitree-g1-basic-sim": "dimos.robot.unitree_webrtc.unitree_g1_blueprints:basic_sim", + "unitree-g1-shm": "dimos.robot.unitree_webrtc.unitree_g1_blueprints:standard_with_shm", + "unitree-g1-agentic": "dimos.robot.unitree_webrtc.unitree_g1_blueprints:agentic", + "unitree-g1-agentic-sim": "dimos.robot.unitree_webrtc.unitree_g1_blueprints:agentic_sim", + "unitree-g1-joystick": "dimos.robot.unitree_webrtc.unitree_g1_blueprints:with_joystick", + "unitree-g1-full": "dimos.robot.unitree_webrtc.unitree_g1_blueprints:full_featured", + "unitree-g1-detection": "dimos.robot.unitree_webrtc.unitree_g1_blueprints:detection", + "demo-osm": "dimos.mapping.osm.demo_osm:demo_osm", + "demo-skill": "dimos.agents2.skills.demo_skill:demo_skill", + "demo-gps-nav": "dimos.agents2.skills.demo_gps_nav:demo_gps_nav_skill", + "demo-google-maps-skill": "dimos.agents2.skills.demo_google_maps_skill:demo_google_maps_skill", + "demo-remapping": "dimos.robot.unitree_webrtc.demo_remapping:remapping", + "demo-remapping-transport": "dimos.robot.unitree_webrtc.demo_remapping:remapping_and_transport", + "demo-error-on-name-conflicts": "dimos.robot.unitree_webrtc.demo_error_on_name_conflicts:blueprint", +} + + +all_modules = { + "astar_planner": "dimos.navigation.global_planner.planner", + "behavior_tree_navigator": "dimos.navigation.bt_navigator.navigator", + "camera_module": "dimos.hardware.camera.module", + "depth_module": "dimos.robot.unitree_webrtc.depth_module", + "detection_2d": "dimos.perception.detection2d.module2D", + "foxglove_bridge": "dimos.robot.foxglove_bridge", + "g1_connection": "dimos.robot.unitree.connection.g1", + "g1_joystick": "dimos.robot.unitree_webrtc.g1_joystick_module", + "g1_skills": "dimos.robot.unitree_webrtc.unitree_g1_skill_container", + "google_maps_skill": "dimos.agents2.skills.google_maps_skill_container", + "gps_nav_skill": "dimos.agents2.skills.gps_nav_skill", + "holonomic_local_planner": "dimos.navigation.local_planner.holonomic_local_planner", + "human_input": "dimos.agents2.cli.human", + "keyboard_teleop": "dimos.robot.unitree_webrtc.keyboard_teleop", + "llm_agent": "dimos.agents2.agent", + "mapper": "dimos.robot.unitree_webrtc.type.map", + "navigation_skill": "dimos.agents2.skills.navigation", + "object_tracking": "dimos.perception.object_tracker", + "osm_skill": "dimos.agents2.skills.osm", + "ros_nav": "dimos.navigation.rosnav", + "spatial_memory": "dimos.perception.spatial_perception", + "speak_skill": "dimos.agents2.skills.speak_skill", + "unitree_skills": "dimos.robot.unitree_webrtc.unitree_skill_container", + "utilization": "dimos.utils.monitoring", + "wavefront_frontier_explorer": "dimos.navigation.frontier_exploration.wavefront_frontier_goal_selector", + "websocket_vis": "dimos.web.websocket_vis.websocket_vis_module", + "web_input": "dimos.agents2.cli.web", +} + + +def get_blueprint_by_name(name: str) -> ModuleBlueprintSet: + if name not in all_blueprints: + raise ValueError(f"Unknown blueprint set name: {name}") + module_path, attr = all_blueprints[name].split(":") + module = __import__(module_path, fromlist=[attr]) + return getattr(module, attr) # type: ignore[no-any-return] + + +def get_module_by_name(name: str) -> ModuleBlueprintSet: + if name not in all_modules: + raise ValueError(f"Unknown module name: {name}") + python_module = __import__(all_modules[name], fromlist=[name]) + return getattr(python_module, name)() # type: ignore[no-any-return] diff --git a/dimos/robot/cli/README.md b/dimos/robot/cli/README.md new file mode 100644 index 0000000000..a8ceb37ba4 --- /dev/null +++ b/dimos/robot/cli/README.md @@ -0,0 +1,65 @@ +# Robot CLI + +To avoid having so many runfiles, I created a common script to run any blueprint. + +For example, to run the standard Unitree Go2 blueprint run: + +```bash +dimos run unitree-go2 +``` + +For the one with agents run: + +```bash +dimos run unitree-go2-agentic +``` + +You can dynamically connect additional modules. For example: + +```bash +dimos run unitree-go2 --extra-module llm_agent --extra-module human_input --extra-module navigation_skill +``` + +## Definitions + +Blueprints can be defined anywhere, but they're all linked together in `dimos/robot/all_blueprints.py`. E.g.: + +```python +all_blueprints = { + "unitree-go2": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:standard", + "unitree-go2-agentic": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:agentic", + ... +} +``` + +(They are defined as imports to avoid triggering unrelated imports.) + +## `GlobalConfig` + +This tool also initializes the global config and passes it to the blueprint. + +`GlobalConfig` contains configuration options that are useful across many modules. For example: + +```python +class GlobalConfig(BaseSettings): + robot_ip: str | None = None + simulation: bool = False + replay: bool = False + n_dask_workers: int = 2 +``` + +Configuration values can be set from multiple places in order of precedence (later entries override earlier ones): + +- Default value defined on GlobalConfig. (`simulation = False`) +- Value defined in `.env` (`SIMULATION=true`) +- Value in the environment variable (`SIMULATION=true`) +- Value coming from the CLI (`--simulation` or `--no-simulation`) +- Value defined on the blueprint (`blueprint.global_config(simulation=True)`) + +For environment variables/`.env` values, you have to prefix the name with `DIMOS_`. + +For the command line, you call it like this: + +```bash +dimos --simulation run unitree-go2 +``` diff --git a/dimos/robot/cli/dimos.py b/dimos/robot/cli/dimos.py new file mode 100644 index 0000000000..f78746dd37 --- /dev/null +++ b/dimos/robot/cli/dimos.py @@ -0,0 +1,180 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum +import inspect +import sys +from typing import Optional, get_args, get_origin + +import typer + +from dimos.core.blueprints import autoconnect +from dimos.core.global_config import GlobalConfig +from dimos.protocol import pubsub +from dimos.robot.all_blueprints import all_blueprints, get_blueprint_by_name, get_module_by_name +from dimos.utils.logging_config import setup_exception_handler + +RobotType = Enum("RobotType", {key.replace("-", "_").upper(): key for key in all_blueprints.keys()}) # type: ignore[misc] + +main = typer.Typer( + help="Dimensional CLI", + no_args_is_help=True, +) + + +def create_dynamic_callback(): # type: ignore[no-untyped-def] + fields = GlobalConfig.model_fields + + # Build the function signature dynamically + params = [ + inspect.Parameter("ctx", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=typer.Context), + ] + + # Create parameters for each field in GlobalConfig + for field_name, field_info in fields.items(): + field_type = field_info.annotation + + # Handle Optional types + # Check for Optional/Union with None + if get_origin(field_type) is type(Optional[str]): # noqa: UP045 + inner_types = get_args(field_type) + if len(inner_types) == 2 and type(None) in inner_types: + # It's Optional[T], get the actual type T + actual_type = next(t for t in inner_types if t != type(None)) + else: + actual_type = field_type + else: + actual_type = field_type + + # Convert field name from snake_case to kebab-case for CLI + cli_option_name = field_name.replace("_", "-") + + # Special handling for boolean fields + if actual_type is bool: + # For boolean fields, create --flag/--no-flag pattern + param = inspect.Parameter( + field_name, + inspect.Parameter.KEYWORD_ONLY, + default=typer.Option( + None, # None means use the model's default if not provided + f"--{cli_option_name}/--no-{cli_option_name}", + help=f"Override {field_name} in GlobalConfig", + ), + annotation=Optional[bool], # noqa: UP045 + ) + else: + # For non-boolean fields, use regular option + param = inspect.Parameter( + field_name, + inspect.Parameter.KEYWORD_ONLY, + default=typer.Option( + None, # None means use the model's default if not provided + f"--{cli_option_name}", + help=f"Override {field_name} in GlobalConfig", + ), + annotation=Optional[actual_type], # noqa: UP045 + ) + params.append(param) + + def callback(**kwargs) -> None: # type: ignore[no-untyped-def] + ctx = kwargs.pop("ctx") + overrides = {k: v for k, v in kwargs.items() if v is not None} + ctx.obj = GlobalConfig().model_copy(update=overrides) + + callback.__signature__ = inspect.Signature(params) # type: ignore[attr-defined] + + return callback + + +main.callback()(create_dynamic_callback()) # type: ignore[no-untyped-call] + + +@main.command() +def run( + ctx: typer.Context, + robot_type: RobotType = typer.Argument(..., help="Type of robot to run"), + extra_modules: list[str] = typer.Option( # type: ignore[valid-type] + [], "--extra-module", help="Extra modules to add to the blueprint" + ), +) -> None: + """Start a robot blueprint""" + setup_exception_handler() + + config: GlobalConfig = ctx.obj + pubsub.lcm.autoconf() # type: ignore[attr-defined] + blueprint = get_blueprint_by_name(robot_type.value) + + if extra_modules: + loaded_modules = [get_module_by_name(mod_name) for mod_name in extra_modules] # type: ignore[attr-defined] + blueprint = autoconnect(blueprint, *loaded_modules) + + dimos = blueprint.build(global_config=config) + dimos.loop() + + +@main.command() +def show_config(ctx: typer.Context) -> None: + """Show current config settings and their values.""" + config: GlobalConfig = ctx.obj + + for field_name, value in config.model_dump().items(): + typer.echo(f"{field_name}: {value}") + + +@main.command() +def list() -> None: + """List all available blueprints.""" + blueprints = [name for name in all_blueprints.keys() if not name.startswith("demo-")] + for blueprint_name in sorted(blueprints): + typer.echo(blueprint_name) + + +@main.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True}) +def lcmspy(ctx: typer.Context) -> None: + """LCM spy tool for monitoring LCM messages.""" + from dimos.utils.cli.lcmspy.run_lcmspy import main as lcmspy_main + + sys.argv = ["lcmspy", *ctx.args] + lcmspy_main() + + +@main.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True}) +def skillspy(ctx: typer.Context) -> None: + """Skills spy tool for monitoring skills.""" + from dimos.utils.cli.skillspy.skillspy import main as skillspy_main + + sys.argv = ["skillspy", *ctx.args] + skillspy_main() + + +@main.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True}) +def agentspy(ctx: typer.Context) -> None: + """Agent spy tool for monitoring agents.""" + from dimos.utils.cli.agentspy.agentspy import main as agentspy_main + + sys.argv = ["agentspy", *ctx.args] + agentspy_main() + + +@main.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True}) +def humancli(ctx: typer.Context) -> None: + """Interface interacting with agents.""" + from dimos.utils.cli.human.humanclianim import main as humancli_main + + sys.argv = ["humancli", *ctx.args] + humancli_main() + + +if __name__ == "__main__": + main() diff --git a/dimos/robot/cli/test_dimos_robot_e2e.py b/dimos/robot/cli/test_dimos_robot_e2e.py new file mode 100644 index 0000000000..7cb8dd1854 --- /dev/null +++ b/dimos/robot/cli/test_dimos_robot_e2e.py @@ -0,0 +1,156 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import signal +import subprocess +import time + +import lcm +import pytest + +from dimos.core.transport import pLCMTransport +from dimos.protocol.service.lcmservice import LCMService + + +class LCMSpy(LCMService): + messages: dict[str, list[bytes]] = {} + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.l = lcm.LCM() + + def start(self) -> None: + super().start() + if self.l: + self.l.subscribe(".*", self.msg) + + def wait_for_topic(self, topic: str, timeout: float = 30.0) -> list[bytes]: + start_time = time.time() + while time.time() - start_time < timeout: + if topic in self.messages: + return self.messages[topic] + time.sleep(0.1) + raise TimeoutError(f"Timeout waiting for topic {topic}") + + def wait_for_message_content( + self, topic: str, content_contains: bytes, timeout: float = 30.0 + ) -> None: + start_time = time.time() + while time.time() - start_time < timeout: + if topic in self.messages: + for msg in self.messages[topic]: + if content_contains in msg: + return + time.sleep(0.1) + raise TimeoutError(f"Timeout waiting for message content on topic {topic}") + + def stop(self) -> None: + super().stop() + + def msg(self, topic, data) -> None: + self.messages.setdefault(topic, []).append(data) + + +class DimosRobotCall: + process: subprocess.Popen | None + + def __init__(self) -> None: + self.process = None + + def start(self) -> None: + self.process = subprocess.Popen( + ["dimos", "run", "demo-skill"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + + def stop(self) -> None: + if self.process is None: + return + + try: + # Send the kill signal (SIGTERM for graceful shutdown) + self.process.send_signal(signal.SIGTERM) + + # Record the time when we sent the kill signal + shutdown_start = time.time() + + # Wait for the process to terminate with a 30-second timeout + try: + self.process.wait(timeout=30) + shutdown_duration = time.time() - shutdown_start + + # Verify it shut down in time + assert shutdown_duration <= 30, ( + f"Process took {shutdown_duration:.2f} seconds to shut down, " + f"which exceeds the 30-second limit" + ) + except subprocess.TimeoutExpired: + # If we reach here, the process didn't terminate in 30 seconds + self.process.kill() # Force kill + self.process.wait() # Clean up + raise AssertionError( + "Process did not shut down within 30 seconds after receiving SIGTERM" + ) + + except Exception: + # Clean up if something goes wrong + if self.process.poll() is None: # Process still running + self.process.kill() + self.process.wait() + raise + + +@pytest.fixture +def lcm_spy(): + lcm_spy = LCMSpy() + lcm_spy.start() + yield lcm_spy + lcm_spy.stop() + + +@pytest.fixture +def dimos_robot_call(): + dimos_robot_call = DimosRobotCall() + dimos_robot_call.start() + yield dimos_robot_call + dimos_robot_call.stop() + + +@pytest.fixture +def human_input(): + transport = pLCMTransport("/human_input") + transport.lcm.start() + + def send_human_input(message: str) -> None: + transport.publish(message) + + yield send_human_input + + transport.lcm.stop() + + +@pytest.mark.skipif(bool(os.getenv("CI")), reason="LCM spy doesn't work in CI.") +def test_dimos_robot_demo_e2e(lcm_spy, dimos_robot_call, human_input) -> None: + lcm_spy.wait_for_topic("/rpc/DemoCalculatorSkill/set_LlmAgent_register_skills/res") + lcm_spy.wait_for_topic("/rpc/HumanInput/start/res") + lcm_spy.wait_for_message_content("/agent", b"AIMessage") + + human_input("what is 52983 + 587237") + + lcm_spy.wait_for_message_content("/agent", b"640220") + + assert "/rpc/DemoCalculatorSkill/sum_numbers/req" in lcm_spy.messages + assert "/rpc/DemoCalculatorSkill/sum_numbers/res" in lcm_spy.messages diff --git a/dimos/robot/foxglove_bridge.py b/dimos/robot/foxglove_bridge.py new file mode 100644 index 0000000000..30db8379b9 --- /dev/null +++ b/dimos/robot/foxglove_bridge.py @@ -0,0 +1,98 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import logging +import threading + +# this is missing, I'm just trying to import lcm_foxglove_bridge.py from dimos_lcm +from dimos_lcm.foxglove_bridge import ( # type: ignore[import-untyped] + FoxgloveBridge as LCMFoxgloveBridge, +) + +from dimos.core import DimosCluster, Module, rpc + +logging.getLogger("lcm_foxglove_bridge").setLevel(logging.ERROR) +logging.getLogger("FoxgloveServer").setLevel(logging.ERROR) + + +class FoxgloveBridge(Module): + _thread: threading.Thread + _loop: asyncio.AbstractEventLoop + + def __init__(self, *args, shm_channels=None, jpeg_shm_channels=None, **kwargs) -> None: # type: ignore[no-untyped-def] + super().__init__(*args, **kwargs) + self.shm_channels = shm_channels or [] + self.jpeg_shm_channels = jpeg_shm_channels or [] + + @rpc + def start(self) -> None: + super().start() + + def run_bridge() -> None: + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + try: + for logger in ["lcm_foxglove_bridge", "FoxgloveServer"]: + logger = logging.getLogger(logger) # type: ignore[assignment] + logger.setLevel(logging.ERROR) # type: ignore[attr-defined] + for handler in logger.handlers: # type: ignore[attr-defined] + handler.setLevel(logging.ERROR) + + bridge = LCMFoxgloveBridge( + host="0.0.0.0", + port=8765, + debug=False, + num_threads=4, + shm_channels=self.shm_channels, + jpeg_shm_channels=self.jpeg_shm_channels, + ) + self._loop.run_until_complete(bridge.run()) + except Exception as e: + print(f"Foxglove bridge error: {e}") + + self._thread = threading.Thread(target=run_bridge, daemon=True) + self._thread.start() + + @rpc + def stop(self) -> None: + if self._loop and self._loop.is_running(): + self._loop.call_soon_threadsafe(self._loop.stop) + self._thread.join(timeout=2) + + super().stop() + + +def deploy( + dimos: DimosCluster, + shm_channels: list[str] | None = None, +) -> FoxgloveBridge: + if shm_channels is None: + shm_channels = [ + "/image#sensor_msgs.Image", + "/lidar#sensor_msgs.PointCloud2", + "/map#sensor_msgs.PointCloud2", + ] + foxglove_bridge = dimos.deploy( # type: ignore[attr-defined] + FoxgloveBridge, + shm_channels=shm_channels, + ) + foxglove_bridge.start() + return foxglove_bridge # type: ignore[no-any-return] + + +foxglove_bridge = FoxgloveBridge.blueprint + + +__all__ = ["FoxgloveBridge", "deploy", "foxglove_bridge"] diff --git a/dimos/robot/position_stream.py b/dimos/robot/position_stream.py new file mode 100644 index 0000000000..77a86bff4c --- /dev/null +++ b/dimos/robot/position_stream.py @@ -0,0 +1,161 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Position stream provider for ROS-based robots. + +This module creates a reactive stream of position updates from ROS odometry or pose topics. +""" + +import logging +import time + +from geometry_msgs.msg import PoseStamped # type: ignore[attr-defined] +from nav_msgs.msg import Odometry # type: ignore[attr-defined] +from rclpy.node import Node +from reactivex import Observable, Subject, operators as ops + +from dimos.utils.logging_config import setup_logger + +logger = setup_logger(level=logging.INFO) + + +class PositionStreamProvider: + """ + A provider for streaming position updates from ROS. + + This class creates an Observable stream of position updates by subscribing + to ROS odometry or pose topics. + """ + + def __init__( + self, + ros_node: Node, + odometry_topic: str = "/odom", + pose_topic: str | None = None, + use_odometry: bool = True, + ) -> None: + """ + Initialize the position stream provider. + + Args: + ros_node: ROS node to use for subscriptions + odometry_topic: Name of the odometry topic (if use_odometry is True) + pose_topic: Name of the pose topic (if use_odometry is False) + use_odometry: Whether to use odometry (True) or pose (False) for position + """ + self.ros_node = ros_node + self.odometry_topic = odometry_topic + self.pose_topic = pose_topic + self.use_odometry = use_odometry + + self._subject = Subject() # type: ignore[var-annotated] + + self.last_position = None + self.last_update_time = None + + self._create_subscription() # type: ignore[no-untyped-call] + + logger.info( + f"PositionStreamProvider initialized with " + f"{'odometry topic' if use_odometry else 'pose topic'}: " + f"{odometry_topic if use_odometry else pose_topic}" + ) + + def _create_subscription(self): # type: ignore[no-untyped-def] + """Create the appropriate ROS subscription based on configuration.""" + if self.use_odometry: + self.subscription = self.ros_node.create_subscription( + Odometry, self.odometry_topic, self._odometry_callback, 10 + ) + logger.info(f"Subscribed to odometry topic: {self.odometry_topic}") + else: + if not self.pose_topic: + raise ValueError("Pose topic must be specified when use_odometry is False") + + self.subscription = self.ros_node.create_subscription( + PoseStamped, self.pose_topic, self._pose_callback, 10 + ) + logger.info(f"Subscribed to pose topic: {self.pose_topic}") + + def _odometry_callback(self, msg: Odometry) -> None: + """ + Process odometry messages and extract position. + + Args: + msg: Odometry message from ROS + """ + x = msg.pose.pose.position.x + y = msg.pose.pose.position.y + + self._update_position(x, y) + + def _pose_callback(self, msg: PoseStamped) -> None: + """ + Process pose messages and extract position. + + Args: + msg: PoseStamped message from ROS + """ + x = msg.pose.position.x + y = msg.pose.position.y + + self._update_position(x, y) + + def _update_position(self, x: float, y: float) -> None: + """ + Update the current position and emit to subscribers. + + Args: + x: X coordinate + y: Y coordinate + """ + current_time = time.time() + position = (x, y) + + if self.last_update_time: + update_rate = 1.0 / (current_time - self.last_update_time) + logger.debug(f"Position update rate: {update_rate:.1f} Hz") + + self.last_position = position # type: ignore[assignment] + self.last_update_time = current_time # type: ignore[assignment] + + self._subject.on_next(position) + logger.debug(f"Position updated: ({x:.2f}, {y:.2f})") + + def get_position_stream(self) -> Observable: # type: ignore[type-arg] + """ + Get an Observable stream of position updates. + + Returns: + Observable that emits (x, y) tuples + """ + return self._subject.pipe( + ops.share() # Share the stream among multiple subscribers + ) + + def get_current_position(self) -> tuple[float, float] | None: + """ + Get the most recent position. + + Returns: + Tuple of (x, y) coordinates, or None if no position has been received + """ + return self.last_position + + def cleanup(self) -> None: + """Clean up resources.""" + if hasattr(self, "subscription") and self.subscription: + self.ros_node.destroy_subscription(self.subscription) + logger.info("Position subscription destroyed") diff --git a/dimos/robot/recorder.py b/dimos/robot/recorder.py deleted file mode 100644 index 77dd5fab47..0000000000 --- a/dimos/robot/recorder.py +++ /dev/null @@ -1,141 +0,0 @@ -import threading -import time -from queue import Queue -from typing import Any, Callable, Literal - -from dimos.data.recording import Recorder - - -class RobotRecorder: - """A class for recording robot observation and actions. - - Recording at a specified frequency on the observation and action of a robot. It leverages a queue and a worker - thread to handle the recording asynchronously, ensuring that the main operations of the - robot are not blocked. - - Robot class must pass in the `get_state`, `get_observation`, `prepare_action` methods.` - get_state() gets the current state/pose of the robot. - get_observation() captures the observation/image of the robot. - prepare_action() calculates the action between the new and old states. - """ - - def __init__( - self, - get_state: Callable, - get_observation: Callable, - prepare_action: Callable, - frequency_hz: int = 5, - recorder_kwargs: dict = None, - on_static: Literal["record", "omit"] = "omit", - ) -> None: - """Initializes the RobotRecorder. - - This constructor sets up the recording mechanism on the given robot, including the recorder instance, - recording frequency, and the asynchronous processing queue and worker thread. It also - initializes attributes to track the last recorded pose and the current instruction. - - Args: - get_state: A function that returns the current state of the robot. - get_observation: A function that captures the observation/image of the robot. - prepare_action: A function that calculates the action between the new and old states. - frequency_hz: Frequency at which to record pose and image data (in Hz). - recorder_kwargs: Keyword arguments to pass to the Recorder constructor. - on_static: Whether to record on static poses or not. If "record", it will record when the robot is not moving. - """ - if recorder_kwargs is None: - recorder_kwargs = {} - self.recorder = Recorder(**recorder_kwargs) - self.task = None - - self.last_recorded_state = None - self.last_image = None - - self.recording = False - self.frequency_hz = frequency_hz - self.record_on_static = on_static == "record" - self.recording_queue = Queue() - - self.get_state = get_state - self.get_observation = get_observation - self.prepare_action = prepare_action - - self._worker_thread = threading.Thread(target=self._process_queue, daemon=True) - self._worker_thread.start() - - def __enter__(self): - """Enter the context manager, starting the recording.""" - self.start_recording(self.task) - - def __exit__(self, exc_type, exc_value, traceback) -> None: - """Exit the context manager, stopping the recording.""" - self.stop_recording() - - def record(self, task: str) -> "RobotRecorder": - """Set the task and return the context manager.""" - self.task = task - return self - - def reset_recorder(self) -> None: - """Reset the recorder.""" - while self.recording: - time.sleep(0.1) - self.recorder.reset() - - def record_from_robot(self) -> None: - """Records the current pose and captures an image at the specified frequency.""" - while self.recording: - start_time = time.perf_counter() - self.record_current_state() - elapsed_time = time.perf_counter() - start_time - # Sleep for the remaining time to maintain the desired frequency - sleep_time = max(0, (1.0 / self.frequency_hz) - elapsed_time) - time.sleep(sleep_time) - - def start_recording(self, task: str = "") -> None: - """Starts the recording of pose and image.""" - if not self.recording: - self.task = task - self.recording = True - self.recording_thread = threading.Thread(target=self.record_from_robot) - self.recording_thread.start() - - def stop_recording(self) -> None: - """Stops the recording of pose and image.""" - if self.recording: - self.recording = False - self.recording_thread.join() - - def _process_queue(self) -> None: - """Processes the recording queue asynchronously.""" - while True: - image, instruction, action, state = self.recording_queue.get() - self.recorder.record(observation={"image": image, "instruction": instruction}, action=action, state=state) - self.recording_queue.task_done() - - def record_current_state(self) -> None: - """Records the current pose and image if the pose has changed.""" - state = self.get_state() - image = self.get_observation() - - # This is the beginning of the episode - if self.last_recorded_state is None: - self.last_recorded_state = state - self.last_image = image - return - - if state != self.last_recorded_state or self.record_on_static: - action = self.prepare_action(self.last_recorded_state, state) - self.recording_queue.put( - ( - self.last_image, - self.task, - action, - self.last_recorded_state, - ), - ) - self.last_image = image - self.last_recorded_state = state - - def record_last_state(self) -> None: - """Records the final pose and image after the movement completes.""" - self.record_current_state() \ No newline at end of file diff --git a/dimos/robot/robot.py b/dimos/robot/robot.py index d0f9843aff..b2b6feaf6d 100644 --- a/dimos/robot/robot.py +++ b/dimos/robot/robot.py @@ -1,32 +1,60 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Minimal robot interface for DIMOS robots.""" + from abc import ABC, abstractmethod -from dimos.hardware.interface import HardwareInterface -from dimos.types.sample import Sample + +from dimos.types.robot_capabilities import RobotCapability -''' -Base class for all dimos robots, both physical and simulated. -''' +# TODO: Delete class Robot(ABC): - def __init__(self, hardware_interface: HardwareInterface): - self.hardware_interface = hardware_interface + """Minimal abstract base class for all DIMOS robots. - @abstractmethod - def perform_task(self): - """Abstract method to be implemented by subclasses to perform a specific task.""" - pass - @abstractmethod - def do(self, *args, **kwargs): - """Executes motion.""" - pass + This class provides the essential interface that all robot implementations + can share, with no required methods - just common properties and helpers. + """ + + def __init__(self) -> None: + """Initialize the robot with basic properties.""" + self.capabilities: list[RobotCapability] = [] + self.skill_library = None - def update_hardware_interface(self, new_hardware_interface: HardwareInterface): - """Update the hardware interface with a new configuration.""" - self.hardware_interface = new_hardware_interface + def has_capability(self, capability: RobotCapability) -> bool: + """Check if the robot has a specific capability. - def get_hardware_configuration(self): - """Retrieve the current hardware configuration.""" - return self.hardware_interface.get_configuration() + Args: + capability: The capability to check for + + Returns: + bool: True if the robot has the capability + """ + return capability in self.capabilities + + def get_skills(self): # type: ignore[no-untyped-def] + """Get the robot's skill library. + + Returns: + The robot's skill library for managing skills + """ + return self.skill_library + + @abstractmethod + def cleanup(self) -> None: + """Clean up robot resources. - def set_hardware_configuration(self, configuration): - """Set a new hardware configuration.""" - self.hardware_interface.set_configuration(configuration) + Override this method to provide cleanup logic. + """ + ... diff --git a/dimos/robot/ros_bridge.py b/dimos/robot/ros_bridge.py new file mode 100644 index 0000000000..999cafc370 --- /dev/null +++ b/dimos/robot/ros_bridge.py @@ -0,0 +1,205 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum +import logging +import threading +from typing import Any + +try: + import rclpy + from rclpy.executors import SingleThreadedExecutor + from rclpy.node import Node + from rclpy.qos import QoSDurabilityPolicy, QoSHistoryPolicy, QoSProfile, QoSReliabilityPolicy +except ImportError: + rclpy = None # type: ignore[assignment] + SingleThreadedExecutor = None # type: ignore[assignment, misc] + Node = None # type: ignore[assignment, misc] + QoSProfile = None # type: ignore[assignment, misc] + QoSReliabilityPolicy = None # type: ignore[assignment, misc] + QoSHistoryPolicy = None # type: ignore[assignment, misc] + QoSDurabilityPolicy = None # type: ignore[assignment, misc] + +from dimos.core.resource import Resource +from dimos.protocol.pubsub.lcmpubsub import LCM, Topic +from dimos.utils.logging_config import setup_logger + +logger = setup_logger(level=logging.INFO) + + +class BridgeDirection(Enum): + """Direction of message bridging.""" + + ROS_TO_DIMOS = "ros_to_dimos" + DIMOS_TO_ROS = "dimos_to_ros" + + +class ROSBridge(Resource): + """Unidirectional bridge between ROS and DIMOS for message passing.""" + + def __init__(self, node_name: str = "dimos_ros_bridge") -> None: + """Initialize the ROS-DIMOS bridge. + + Args: + node_name: Name for the ROS node (default: "dimos_ros_bridge") + """ + if not rclpy.ok(): # type: ignore[attr-defined] + rclpy.init() + + self.node = Node(node_name) + self.lcm = LCM() + self.lcm.start() + + self._executor = SingleThreadedExecutor() + self._executor.add_node(self.node) + + self._spin_thread = threading.Thread(target=self._ros_spin, daemon=True) + self._spin_thread.start() # TODO: don't forget to shut it down + + self._bridges: dict[str, dict[str, Any]] = {} + + self._qos = QoSProfile( # type: ignore[no-untyped-call] + reliability=QoSReliabilityPolicy.RELIABLE, + history=QoSHistoryPolicy.KEEP_LAST, + durability=QoSDurabilityPolicy.VOLATILE, + depth=10, + ) + + logger.info(f"ROSBridge initialized with node name: {node_name}") + + def start(self) -> None: + pass + + def stop(self) -> None: + """Shutdown the bridge and clean up resources.""" + self._executor.shutdown() + self.node.destroy_node() # type: ignore[no-untyped-call] + + if rclpy.ok(): # type: ignore[attr-defined] + rclpy.shutdown() + + logger.info("ROSBridge shutdown complete") + + def _ros_spin(self) -> None: + """Background thread for spinning ROS executor.""" + try: + self._executor.spin() + finally: + self._executor.shutdown() + + def add_topic( + self, + topic_name: str, + dimos_type: type, + ros_type: type, + direction: BridgeDirection, + remap_topic: str | None = None, + ) -> None: + """Add unidirectional bridging for a topic. + + Args: + topic_name: Name of the topic (e.g., "/cmd_vel") + dimos_type: DIMOS message type (e.g., dimos.msgs.geometry_msgs.Twist) + ros_type: ROS message type (e.g., geometry_msgs.msg.Twist) + direction: Direction of bridging (ROS_TO_DIMOS or DIMOS_TO_ROS) + remap_topic: Optional remapped topic name for the other side + """ + if topic_name in self._bridges: + logger.warning(f"Topic {topic_name} already bridged") + return + + # Determine actual topic names for each side + ros_topic_name = topic_name + dimos_topic_name = topic_name + + if remap_topic: + if direction == BridgeDirection.ROS_TO_DIMOS: + dimos_topic_name = remap_topic + else: # DIMOS_TO_ROS + ros_topic_name = remap_topic + + # Create DIMOS/LCM topic + dimos_topic = Topic(dimos_topic_name, dimos_type) + + ros_subscription = None + ros_publisher = None + dimos_subscription = None + + if direction == BridgeDirection.ROS_TO_DIMOS: + + def ros_callback(msg) -> None: # type: ignore[no-untyped-def] + self._ros_to_dimos(msg, dimos_topic, dimos_type, topic_name) + + ros_subscription = self.node.create_subscription( + ros_type, ros_topic_name, ros_callback, self._qos + ) + logger.info(f" ROS → DIMOS: Subscribing to ROS topic {ros_topic_name}") + + elif direction == BridgeDirection.DIMOS_TO_ROS: + ros_publisher = self.node.create_publisher(ros_type, ros_topic_name, self._qos) + + def dimos_callback(msg, _topic) -> None: # type: ignore[no-untyped-def] + self._dimos_to_ros(msg, ros_publisher, topic_name) + + dimos_subscription = self.lcm.subscribe(dimos_topic, dimos_callback) + logger.info(f" DIMOS → ROS: Subscribing to DIMOS topic {dimos_topic_name}") + else: + raise ValueError(f"Invalid bridge direction: {direction}") + + self._bridges[topic_name] = { + "dimos_topic": dimos_topic, + "dimos_type": dimos_type, + "ros_type": ros_type, + "ros_subscription": ros_subscription, + "ros_publisher": ros_publisher, + "dimos_subscription": dimos_subscription, + "direction": direction, + "ros_topic_name": ros_topic_name, + "dimos_topic_name": dimos_topic_name, + } + + direction_str = { + BridgeDirection.ROS_TO_DIMOS: "ROS → DIMOS", + BridgeDirection.DIMOS_TO_ROS: "DIMOS → ROS", + }[direction] + + logger.info(f"Bridged topic: {topic_name} ({direction_str})") + if remap_topic: + logger.info(f" Remapped: ROS '{ros_topic_name}' ↔ DIMOS '{dimos_topic_name}'") + logger.info(f" DIMOS type: {dimos_type.__name__}, ROS type: {ros_type.__name__}") + + def _ros_to_dimos( + self, ros_msg: Any, dimos_topic: Topic, dimos_type: type, _topic_name: str + ) -> None: + """Convert ROS message to DIMOS and publish. + + Args: + ros_msg: ROS message + dimos_topic: DIMOS topic to publish to + dimos_type: DIMOS message type + topic_name: Name of the topic for tracking + """ + dimos_msg = dimos_type.from_ros_msg(ros_msg) # type: ignore[attr-defined] + self.lcm.publish(dimos_topic, dimos_msg) + + def _dimos_to_ros(self, dimos_msg: Any, ros_publisher, _topic_name: str) -> None: # type: ignore[no-untyped-def] + """Convert DIMOS message to ROS and publish. + + Args: + dimos_msg: DIMOS message + ros_publisher: ROS publisher to use + _topic_name: Name of the topic (unused, kept for consistency) + """ + ros_msg = dimos_msg.to_ros_msg() + ros_publisher.publish(ros_msg) diff --git a/dimos/robot/ros_command_queue.py b/dimos/robot/ros_command_queue.py new file mode 100644 index 0000000000..86115d7780 --- /dev/null +++ b/dimos/robot/ros_command_queue.py @@ -0,0 +1,473 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Queue-based command management system for robot commands. + +This module provides a unified approach to queueing and processing all robot commands, +including WebRTC requests and action client commands. +Commands are processed sequentially and only when the robot is in IDLE state. +""" + +from collections.abc import Callable +from enum import Enum, auto +from queue import Empty, PriorityQueue +import threading +import time +from typing import Any, NamedTuple +import uuid + +from dimos.utils.logging_config import setup_logger + +# Initialize logger for the ros command queue module +logger = setup_logger() + + +class CommandType(Enum): + """Types of commands that can be queued""" + + WEBRTC = auto() # WebRTC API requests + ACTION = auto() # Any action client or function call + + +class WebRTCRequest(NamedTuple): + """Class to represent a WebRTC request in the queue""" + + id: str # Unique ID for tracking + api_id: int # API ID for the command + topic: str # Topic to publish to + parameter: str # Optional parameter string + priority: int # Priority level + timeout: float # How long to wait for this request to complete + + +class ROSCommand(NamedTuple): + """Class to represent a command in the queue""" + + id: str # Unique ID for tracking + cmd_type: CommandType # Type of command + execute_func: Callable # type: ignore[type-arg] # Function to execute the command + params: dict[str, Any] # Parameters for the command (for debugging/logging) + priority: int # Priority level (lower is higher priority) + timeout: float # How long to wait for this command to complete + + +class ROSCommandQueue: + """ + Manages a queue of commands for the robot. + + Commands are executed sequentially, with only one command being processed at a time. + Commands are only executed when the robot is in the IDLE state. + """ + + def __init__( + self, + webrtc_func: Callable, # type: ignore[type-arg] + is_ready_func: Callable[[], bool] | None = None, + is_busy_func: Callable[[], bool] | None = None, + debug: bool = True, + ) -> None: + """ + Initialize the ROSCommandQueue. + + Args: + webrtc_func: Function to send WebRTC requests + is_ready_func: Function to check if the robot is ready for a command + is_busy_func: Function to check if the robot is busy + debug: Whether to enable debug logging + """ + self._webrtc_func = webrtc_func + self._is_ready_func = is_ready_func or (lambda: True) + self._is_busy_func = is_busy_func + self._debug = debug + + # Queue of commands to process + self._queue = PriorityQueue() # type: ignore[var-annotated] + self._current_command = None + self._last_command_time = 0 + + # Last known robot state + self._last_ready_state = None + self._last_busy_state = None + self._stuck_in_busy_since = None + + # Command execution status + self._should_stop = False + self._queue_thread = None + + # Stats + self._command_count = 0 + self._success_count = 0 + self._failure_count = 0 + self._command_history = [] # type: ignore[var-annotated] + + self._max_queue_wait_time = ( + 30.0 # Maximum time to wait for robot to be ready before forcing + ) + + logger.info("ROSCommandQueue initialized") + + def start(self) -> None: + """Start the queue processing thread""" + if self._queue_thread is not None and self._queue_thread.is_alive(): + logger.warning("Queue processing thread already running") + return + + self._should_stop = False + self._queue_thread = threading.Thread(target=self._process_queue, daemon=True) # type: ignore[assignment] + self._queue_thread.start() # type: ignore[attr-defined] + logger.info("Queue processing thread started") + + def stop(self, timeout: float = 2.0) -> None: + """ + Stop the queue processing thread + + Args: + timeout: Maximum time to wait for the thread to stop + """ + if self._queue_thread is None or not self._queue_thread.is_alive(): + logger.warning("Queue processing thread not running") + return + + self._should_stop = True + try: + self._queue_thread.join(timeout=timeout) + if self._queue_thread.is_alive(): + logger.warning(f"Queue processing thread did not stop within {timeout}s") + else: + logger.info("Queue processing thread stopped") + except Exception as e: + logger.error(f"Error stopping queue processing thread: {e}") + + def queue_webrtc_request( + self, + api_id: int, + topic: str | None = None, + parameter: str = "", + request_id: str | None = None, + data: dict[str, Any] | None = None, + priority: int = 0, + timeout: float = 30.0, + ) -> str: + """ + Queue a WebRTC request + + Args: + api_id: API ID for the command + topic: Topic to publish to + parameter: Optional parameter string + request_id: Unique ID for the request (will be generated if not provided) + data: Data to include in the request + priority: Priority level (lower is higher priority) + timeout: Maximum time to wait for the command to complete + + Returns: + str: Unique ID for the request + """ + request_id = request_id or str(uuid.uuid4()) + + # Create a function that will execute this WebRTC request + def execute_webrtc() -> bool: + try: + logger.info(f"Executing WebRTC request: {api_id} (ID: {request_id})") + if self._debug: + logger.debug(f"[WebRTC Queue] SENDING request: API ID {api_id}") + + result = self._webrtc_func( + api_id=api_id, + topic=topic, + parameter=parameter, + request_id=request_id, + data=data, + ) + if not result: + logger.warning(f"WebRTC request failed: {api_id} (ID: {request_id})") + if self._debug: + logger.debug(f"[WebRTC Queue] Request API ID {api_id} FAILED to send") + return False + + if self._debug: + logger.debug(f"[WebRTC Queue] Request API ID {api_id} sent SUCCESSFULLY") + + # Allow time for the robot to process the command + start_time = time.time() + stabilization_delay = 0.5 # Half-second delay for stabilization + time.sleep(stabilization_delay) + + # Wait for the robot to complete the command (timeout check) + while self._is_busy_func() and (time.time() - start_time) < timeout: # type: ignore[misc] + if ( + self._debug and (time.time() - start_time) % 5 < 0.1 + ): # Print every ~5 seconds + logger.debug( + f"[WebRTC Queue] Still waiting on API ID {api_id} - elapsed: {time.time() - start_time:.1f}s" + ) + time.sleep(0.1) + + # Check if we timed out + if self._is_busy_func() and (time.time() - start_time) >= timeout: # type: ignore[misc] + logger.warning(f"WebRTC request timed out: {api_id} (ID: {request_id})") + return False + + wait_time = time.time() - start_time + if self._debug: + logger.debug( + f"[WebRTC Queue] Request API ID {api_id} completed after {wait_time:.1f}s" + ) + + logger.info(f"WebRTC request completed: {api_id} (ID: {request_id})") + return True + except Exception as e: + logger.error(f"Error executing WebRTC request: {e}") + if self._debug: + logger.debug(f"[WebRTC Queue] ERROR processing request: {e}") + return False + + # Create the command and queue it + command = ROSCommand( + id=request_id, + cmd_type=CommandType.WEBRTC, + execute_func=execute_webrtc, + params={"api_id": api_id, "topic": topic, "request_id": request_id}, + priority=priority, + timeout=timeout, + ) + + # Queue the command + self._queue.put((priority, self._command_count, command)) + self._command_count += 1 + if self._debug: + logger.debug( + f"[WebRTC Queue] Added request ID {request_id} for API ID {api_id} - Queue size now: {self.queue_size}" + ) + logger.info(f"Queued WebRTC request: {api_id} (ID: {request_id}, Priority: {priority})") + + return request_id + + def queue_action_client_request( # type: ignore[no-untyped-def] + self, + action_name: str, + execute_func: Callable, # type: ignore[type-arg] + priority: int = 0, + timeout: float = 30.0, + **kwargs, + ) -> str: + """ + Queue any action client request or function + + Args: + action_name: Name of the action for logging/tracking + execute_func: Function to execute the command + priority: Priority level (lower is higher priority) + timeout: Maximum time to wait for the command to complete + **kwargs: Additional parameters to pass to the execute function + + Returns: + str: Unique ID for the request + """ + request_id = str(uuid.uuid4()) + + # Create the command + command = ROSCommand( + id=request_id, + cmd_type=CommandType.ACTION, + execute_func=execute_func, + params={"action_name": action_name, **kwargs}, + priority=priority, + timeout=timeout, + ) + + # Queue the command + self._queue.put((priority, self._command_count, command)) + self._command_count += 1 + + action_params = ", ".join([f"{k}={v}" for k, v in kwargs.items()]) + logger.info( + f"Queued action request: {action_name} (ID: {request_id}, Priority: {priority}, Params: {action_params})" + ) + + return request_id + + def _process_queue(self) -> None: + """Process commands in the queue""" + logger.info("Starting queue processing") + logger.info("[WebRTC Queue] Processing thread started") + + while not self._should_stop: + # Print queue status + self._print_queue_status() + + # Check if we're ready to process a command + if not self._queue.empty() and self._current_command is None: + current_time = time.time() + is_ready = self._is_ready_func() + is_busy = self._is_busy_func() if self._is_busy_func else False + + if self._debug: + logger.debug( + f"[WebRTC Queue] Status: {self.queue_size} requests waiting | Robot ready: {is_ready} | Robot busy: {is_busy}" + ) + + # Track robot state changes + if is_ready != self._last_ready_state: + logger.debug( + f"Robot ready state changed: {self._last_ready_state} -> {is_ready}" + ) + self._last_ready_state = is_ready # type: ignore[assignment] + + if is_busy != self._last_busy_state: + logger.debug(f"Robot busy state changed: {self._last_busy_state} -> {is_busy}") + self._last_busy_state = is_busy # type: ignore[assignment] + + # If the robot has transitioned to busy, record the time + if is_busy: + self._stuck_in_busy_since = current_time # type: ignore[assignment] + else: + self._stuck_in_busy_since = None + + # Check if we've been waiting too long for the robot to be ready + force_processing = False + if ( + not is_ready + and is_busy + and self._stuck_in_busy_since is not None + and current_time - self._stuck_in_busy_since > self._max_queue_wait_time + ): + logger.warning( + f"Robot has been busy for {current_time - self._stuck_in_busy_since:.1f}s, " + f"forcing queue to continue" + ) + force_processing = True + + # Process the next command if ready or forcing + if is_ready or force_processing: + if self._debug and is_ready: + logger.debug("[WebRTC Queue] Robot is READY for next command") + + try: + # Get the next command + _, _, command = self._queue.get(block=False) + self._current_command = command + self._last_command_time = current_time # type: ignore[assignment] + + # Log the command + cmd_info = f"ID: {command.id}, Type: {command.cmd_type.name}" + if command.cmd_type == CommandType.WEBRTC: + api_id = command.params.get("api_id") + cmd_info += f", API: {api_id}" + if self._debug: + logger.debug(f"[WebRTC Queue] DEQUEUED request: API ID {api_id}") + elif command.cmd_type == CommandType.ACTION: + action_name = command.params.get("action_name") + cmd_info += f", Action: {action_name}" + if self._debug: + logger.debug(f"[WebRTC Queue] DEQUEUED action: {action_name}") + + forcing_str = " (FORCED)" if force_processing else "" + logger.info(f"Processing command{forcing_str}: {cmd_info}") + + # Execute the command + try: + # Where command execution occurs + success = command.execute_func() + + if success: + self._success_count += 1 + logger.info(f"Command succeeded: {cmd_info}") + if self._debug: + logger.debug( + f"[WebRTC Queue] Command {command.id} marked as COMPLETED" + ) + else: + self._failure_count += 1 + logger.warning(f"Command failed: {cmd_info}") + if self._debug: + logger.debug(f"[WebRTC Queue] Command {command.id} FAILED") + + # Record command history + self._command_history.append( + { + "id": command.id, + "type": command.cmd_type.name, + "params": command.params, + "success": success, + "time": time.time() - self._last_command_time, + } + ) + + except Exception as e: + self._failure_count += 1 + logger.error(f"Error executing command: {e}") + if self._debug: + logger.debug(f"[WebRTC Queue] ERROR executing command: {e}") + + # Mark the command as complete + self._current_command = None + if self._debug: + logger.debug( + "[WebRTC Queue] Adding 0.5s stabilization delay before next command" + ) + time.sleep(0.5) + + except Empty: + pass + + # Sleep to avoid busy-waiting + time.sleep(0.1) + + logger.info("Queue processing stopped") + + def _print_queue_status(self) -> None: + """Print the current queue status""" + current_time = time.time() + + # Only print once per second to avoid spamming the log + if current_time - self._last_command_time < 1.0 and self._current_command is None: + return + + is_ready = self._is_ready_func() + self._is_busy_func() if self._is_busy_func else False + queue_size = self.queue_size + + # Get information about the current command + current_command_info = "None" + if self._current_command is not None: + current_command_info = f"{self._current_command.cmd_type.name}" + if self._current_command.cmd_type == CommandType.WEBRTC: + api_id = self._current_command.params.get("api_id") + current_command_info += f" (API: {api_id})" + elif self._current_command.cmd_type == CommandType.ACTION: + action_name = self._current_command.params.get("action_name") + current_command_info += f" (Action: {action_name})" + + # Print the status + status = ( + f"Queue: {queue_size} items | " + f"Robot: {'READY' if is_ready else 'BUSY'} | " + f"Current: {current_command_info} | " + f"Stats: {self._success_count} OK, {self._failure_count} FAIL" + ) + + logger.debug(status) + self._last_command_time = current_time # type: ignore[assignment] + + @property + def queue_size(self) -> int: + """Get the number of commands in the queue""" + return self._queue.qsize() + + @property + def current_command(self) -> ROSCommand | None: + """Get the current command being processed""" + return self._current_command diff --git a/dimos/robot/test_ros_bridge.py b/dimos/robot/test_ros_bridge.py new file mode 100644 index 0000000000..5b616dce3b --- /dev/null +++ b/dimos/robot/test_ros_bridge.py @@ -0,0 +1,434 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import threading +import time +import unittest + +import numpy as np +import pytest + +try: + from geometry_msgs.msg import TransformStamped, TwistStamped as ROSTwistStamped + import rclpy + from rclpy.node import Node + from sensor_msgs.msg import PointCloud2 as ROSPointCloud2, PointField + from tf2_msgs.msg import TFMessage as ROSTFMessage +except ImportError: + rclpy = None + Node = None + ROSTwistStamped = None + ROSPointCloud2 = None + PointField = None + ROSTFMessage = None + TransformStamped = None + +from dimos.msgs.geometry_msgs import TwistStamped +from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.tf2_msgs import TFMessage +from dimos.protocol.pubsub.lcmpubsub import LCM, Topic +from dimos.robot.ros_bridge import BridgeDirection, ROSBridge + + +@pytest.mark.ros +class TestROSBridge(unittest.TestCase): + """Test suite for ROS-DIMOS bridge.""" + + def setUp(self) -> None: + """Set up test fixtures.""" + # Skip if ROS is not available + if rclpy is None: + self.skipTest("ROS not available") + + # Initialize ROS if not already done + if not rclpy.ok(): + rclpy.init() + + # Create test bridge + self.bridge = ROSBridge("test_ros_bridge") + + # Create test node for publishing/subscribing + self.test_node = Node("test_node") + + # Track received messages + self.ros_messages = [] + self.dimos_messages = [] + self.message_timestamps = {"ros": [], "dimos": []} + + def tearDown(self) -> None: + """Clean up test fixtures.""" + self.test_node.destroy_node() + self.bridge.stop() + if rclpy.ok(): + rclpy.try_shutdown() + + def test_ros_to_dimos_twist(self) -> None: + """Test ROS TwistStamped to DIMOS conversion and transmission.""" + # Set up bridge + self.bridge.add_topic( + "/test_twist", TwistStamped, ROSTwistStamped, BridgeDirection.ROS_TO_DIMOS + ) + + # Subscribe to DIMOS side + lcm = LCM() + lcm.start() + topic = Topic("/test_twist", TwistStamped) + + def dimos_callback(msg, _topic) -> None: + self.dimos_messages.append(msg) + self.message_timestamps["dimos"].append(time.time()) + + lcm.subscribe(topic, dimos_callback) + + # Publish from ROS side + ros_pub = self.test_node.create_publisher(ROSTwistStamped, "/test_twist", 10) + + # Send test messages + for i in range(10): + msg = ROSTwistStamped() + msg.header.stamp = self.test_node.get_clock().now().to_msg() + msg.header.frame_id = f"frame_{i}" + msg.twist.linear.x = float(i) + msg.twist.linear.y = float(i * 2) + msg.twist.angular.z = float(i * 0.1) + + ros_pub.publish(msg) + self.message_timestamps["ros"].append(time.time()) + time.sleep(0.01) # 100Hz + + # Allow time for processing + time.sleep(0.5) + + # Verify messages received + self.assertEqual(len(self.dimos_messages), 10, "Should receive all 10 messages") + + # Verify message content + for i, msg in enumerate(self.dimos_messages): + self.assertEqual(msg.frame_id, f"frame_{i}") + self.assertAlmostEqual(msg.linear.x, float(i), places=5) + self.assertAlmostEqual(msg.linear.y, float(i * 2), places=5) + self.assertAlmostEqual(msg.angular.z, float(i * 0.1), places=5) + + def test_dimos_to_ros_twist(self) -> None: + """Test DIMOS TwistStamped to ROS conversion and transmission.""" + # Set up bridge + self.bridge.add_topic( + "/test_twist_reverse", TwistStamped, ROSTwistStamped, BridgeDirection.DIMOS_TO_ROS + ) + + # Subscribe to ROS side + def ros_callback(msg) -> None: + self.ros_messages.append(msg) + self.message_timestamps["ros"].append(time.time()) + + self.test_node.create_subscription(ROSTwistStamped, "/test_twist_reverse", ros_callback, 10) + + # Use the bridge's LCM instance for publishing + topic = Topic("/test_twist_reverse", TwistStamped) + + # Send test messages + for i in range(10): + msg = TwistStamped(ts=time.time(), frame_id=f"dimos_frame_{i}") + msg.linear.x = float(i * 3) + msg.linear.y = float(i * 4) + msg.angular.z = float(i * 0.2) + + self.bridge.lcm.publish(topic, msg) + self.message_timestamps["dimos"].append(time.time()) + time.sleep(0.01) # 100Hz + + # Allow time for processing and spin the test node + for _ in range(50): # Spin for 0.5 seconds + rclpy.spin_once(self.test_node, timeout_sec=0.01) + + # Verify messages received + self.assertEqual(len(self.ros_messages), 10, "Should receive all 10 messages") + + # Verify message content + for i, msg in enumerate(self.ros_messages): + self.assertEqual(msg.header.frame_id, f"dimos_frame_{i}") + self.assertAlmostEqual(msg.twist.linear.x, float(i * 3), places=5) + self.assertAlmostEqual(msg.twist.linear.y, float(i * 4), places=5) + self.assertAlmostEqual(msg.twist.angular.z, float(i * 0.2), places=5) + + def test_frequency_preservation(self) -> None: + """Test that message frequencies are preserved through the bridge.""" + # Set up bridge + self.bridge.add_topic( + "/test_freq", TwistStamped, ROSTwistStamped, BridgeDirection.ROS_TO_DIMOS + ) + + # Subscribe to DIMOS side + lcm = LCM() + lcm.start() + topic = Topic("/test_freq", TwistStamped) + + receive_times = [] + + def dimos_callback(_msg, _topic) -> None: + receive_times.append(time.time()) + + lcm.subscribe(topic, dimos_callback) + + # Publish from ROS at specific frequencies + ros_pub = self.test_node.create_publisher(ROSTwistStamped, "/test_freq", 10) + + # Test different frequencies + test_frequencies = [10, 50, 100] # Hz + + for target_freq in test_frequencies: + receive_times.clear() + send_times = [] + period = 1.0 / target_freq + + # Send messages at target frequency + start_time = time.time() + while time.time() - start_time < 1.0: # Run for 1 second + msg = ROSTwistStamped() + msg.header.stamp = self.test_node.get_clock().now().to_msg() + msg.twist.linear.x = 1.0 + + ros_pub.publish(msg) + send_times.append(time.time()) + time.sleep(period) + + # Allow processing time + time.sleep(0.2) + + # Calculate actual frequencies + if len(send_times) > 1: + send_intervals = np.diff(send_times) + send_freq = 1.0 / np.mean(send_intervals) + else: + send_freq = 0 + + if len(receive_times) > 1: + receive_intervals = np.diff(receive_times) + receive_freq = 1.0 / np.mean(receive_intervals) + else: + receive_freq = 0 + + # Verify frequency preservation (within 10% tolerance) + self.assertAlmostEqual( + receive_freq, + send_freq, + delta=send_freq * 0.1, + msg=f"Frequency not preserved for {target_freq}Hz: sent={send_freq:.1f}Hz, received={receive_freq:.1f}Hz", + ) + + def test_pointcloud_conversion(self) -> None: + """Test PointCloud2 message conversion with numpy optimization.""" + # Set up bridge + self.bridge.add_topic( + "/test_cloud", PointCloud2, ROSPointCloud2, BridgeDirection.ROS_TO_DIMOS + ) + + # Subscribe to DIMOS side + lcm = LCM() + lcm.start() + topic = Topic("/test_cloud", PointCloud2) + + received_cloud = [] + + def dimos_callback(msg, _topic) -> None: + received_cloud.append(msg) + + lcm.subscribe(topic, dimos_callback) + + # Create test point cloud + ros_pub = self.test_node.create_publisher(ROSPointCloud2, "/test_cloud", 10) + + # Generate test points + num_points = 1000 + points = np.random.randn(num_points, 3).astype(np.float32) + + # Create ROS PointCloud2 message + msg = ROSPointCloud2() + msg.header.stamp = self.test_node.get_clock().now().to_msg() + msg.header.frame_id = "test_frame" + msg.height = 1 + msg.width = num_points + msg.fields = [ + PointField(name="x", offset=0, datatype=PointField.FLOAT32, count=1), + PointField(name="y", offset=4, datatype=PointField.FLOAT32, count=1), + PointField(name="z", offset=8, datatype=PointField.FLOAT32, count=1), + ] + msg.is_bigendian = False + msg.point_step = 12 + msg.row_step = msg.point_step * msg.width + msg.data = points.tobytes() + msg.is_dense = True + + # Send point cloud + ros_pub.publish(msg) + + # Allow processing time + time.sleep(0.5) + + # Verify reception + self.assertEqual(len(received_cloud), 1, "Should receive point cloud") + + # Verify point data + received_points = received_cloud[0].as_numpy() + self.assertEqual(received_points.shape, points.shape) + np.testing.assert_array_almost_equal(received_points, points, decimal=5) + + def test_tf_high_frequency(self) -> None: + """Test TF message handling at high frequency.""" + # Set up bridge + self.bridge.add_topic("/test_tf", TFMessage, ROSTFMessage, BridgeDirection.ROS_TO_DIMOS) + + # Subscribe to DIMOS side + lcm = LCM() + lcm.start() + topic = Topic("/test_tf", TFMessage) + + received_tfs = [] + receive_times = [] + + def dimos_callback(msg, _topic) -> None: + received_tfs.append(msg) + receive_times.append(time.time()) + + lcm.subscribe(topic, dimos_callback) + + # Publish TF at high frequency (100Hz) + ros_pub = self.test_node.create_publisher(ROSTFMessage, "/test_tf", 100) + + target_freq = 100 # Hz + period = 1.0 / target_freq + num_messages = 100 # 1 second worth + + send_times = [] + for i in range(num_messages): + msg = ROSTFMessage() + transform = TransformStamped() + transform.header.stamp = self.test_node.get_clock().now().to_msg() + transform.header.frame_id = "world" + transform.child_frame_id = f"link_{i}" + transform.transform.translation.x = float(i) + transform.transform.rotation.w = 1.0 + msg.transforms = [transform] + + ros_pub.publish(msg) + send_times.append(time.time()) + time.sleep(period) + + # Allow processing time + time.sleep(0.5) + + # Check message count (allow 5% loss tolerance) + min_expected = int(num_messages * 0.95) + self.assertGreaterEqual( + len(received_tfs), + min_expected, + f"Should receive at least {min_expected} of {num_messages} TF messages", + ) + + # Check frequency preservation + if len(receive_times) > 1: + receive_intervals = np.diff(receive_times) + receive_freq = 1.0 / np.mean(receive_intervals) + + # For high frequency, allow 20% tolerance + self.assertAlmostEqual( + receive_freq, + target_freq, + delta=target_freq * 0.2, + msg=f"High frequency TF not preserved: expected={target_freq}Hz, got={receive_freq:.1f}Hz", + ) + + def test_bidirectional_bridge(self) -> None: + """Test simultaneous bidirectional message flow.""" + # Set up bidirectional bridges for same topic type + self.bridge.add_topic( + "/ros_to_dimos", TwistStamped, ROSTwistStamped, BridgeDirection.ROS_TO_DIMOS + ) + + self.bridge.add_topic( + "/dimos_to_ros", TwistStamped, ROSTwistStamped, BridgeDirection.DIMOS_TO_ROS + ) + + dimos_received = [] + ros_received = [] + + # DIMOS subscriber - use bridge's LCM + topic_r2d = Topic("/ros_to_dimos", TwistStamped) + self.bridge.lcm.subscribe(topic_r2d, lambda msg, _: dimos_received.append(msg)) + + # ROS subscriber + self.test_node.create_subscription( + ROSTwistStamped, "/dimos_to_ros", lambda msg: ros_received.append(msg), 10 + ) + + # Set up publishers + ros_pub = self.test_node.create_publisher(ROSTwistStamped, "/ros_to_dimos", 10) + topic_d2r = Topic("/dimos_to_ros", TwistStamped) + + # Keep track of whether threads should continue + stop_spinning = threading.Event() + + # Spin the test node in background to receive messages + def spin_test_node() -> None: + while not stop_spinning.is_set(): + rclpy.spin_once(self.test_node, timeout_sec=0.01) + + spin_thread = threading.Thread(target=spin_test_node, daemon=True) + spin_thread.start() + + # Send messages in both directions simultaneously + def send_ros_messages() -> None: + for i in range(50): + msg = ROSTwistStamped() + msg.header.stamp = self.test_node.get_clock().now().to_msg() + msg.twist.linear.x = float(i) + ros_pub.publish(msg) + time.sleep(0.02) # 50Hz + + def send_dimos_messages() -> None: + for i in range(50): + msg = TwistStamped(ts=time.time()) + msg.linear.y = float(i * 2) + self.bridge.lcm.publish(topic_d2r, msg) + time.sleep(0.02) # 50Hz + + # Run both senders in parallel + ros_thread = threading.Thread(target=send_ros_messages) + dimos_thread = threading.Thread(target=send_dimos_messages) + + ros_thread.start() + dimos_thread.start() + + ros_thread.join() + dimos_thread.join() + + # Allow processing time + time.sleep(0.5) + stop_spinning.set() + spin_thread.join(timeout=1.0) + + # Verify both directions worked + self.assertGreaterEqual(len(dimos_received), 45, "Should receive most ROS->DIMOS messages") + self.assertGreaterEqual(len(ros_received), 45, "Should receive most DIMOS->ROS messages") + + # Verify message integrity + for i, msg in enumerate(dimos_received[:45]): + self.assertAlmostEqual(msg.linear.x, float(i), places=5) + + for i, msg in enumerate(ros_received[:45]): + self.assertAlmostEqual(msg.twist.linear.y, float(i * 2), places=5) + + +if __name__ == "__main__": + unittest.main() diff --git a/dimos/robot/unitree/connection/__init__.py b/dimos/robot/unitree/connection/__init__.py new file mode 100644 index 0000000000..5c1dff1922 --- /dev/null +++ b/dimos/robot/unitree/connection/__init__.py @@ -0,0 +1,4 @@ +import dimos.robot.unitree.connection.g1 as g1 +import dimos.robot.unitree.connection.go2 as go2 + +__all__ = ["g1", "go2"] diff --git a/dimos/robot/unitree/connection/connection.py b/dimos/robot/unitree/connection/connection.py new file mode 100644 index 0000000000..eecf183451 --- /dev/null +++ b/dimos/robot/unitree/connection/connection.py @@ -0,0 +1,403 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from dataclasses import dataclass +import functools +import threading +import time +from typing import TypeAlias + +from aiortc import MediaStreamTrack # type: ignore[import-untyped] +from go2_webrtc_driver.constants import ( # type: ignore[import-untyped] + RTC_TOPIC, + SPORT_CMD, + VUI_COLOR, +) +from go2_webrtc_driver.webrtc_driver import ( # type: ignore[import-untyped] + Go2WebRTCConnection, + WebRTCConnectionMethod, +) +import numpy as np +from numpy.typing import NDArray +from reactivex import operators as ops +from reactivex.observable import Observable +from reactivex.subject import Subject + +from dimos.core import rpc +from dimos.core.resource import Resource +from dimos.msgs.geometry_msgs import Pose, Transform, Twist +from dimos.msgs.sensor_msgs import Image +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.lowstate import LowStateMsg +from dimos.robot.unitree_webrtc.type.odometry import Odometry +from dimos.utils.decorators.decorators import simple_mcache +from dimos.utils.reactive import backpressure, callback_to_observable + +VideoMessage: TypeAlias = NDArray[np.uint8] # Shape: (height, width, 3) + + +@dataclass +class SerializableVideoFrame: + """Pickleable wrapper for av.VideoFrame with all metadata""" + + data: np.ndarray # type: ignore[type-arg] + pts: int | None = None + time: float | None = None + dts: int | None = None + width: int | None = None + height: int | None = None + format: str | None = None + + @classmethod + def from_av_frame(cls, frame): # type: ignore[no-untyped-def] + return cls( + data=frame.to_ndarray(format="rgb24"), + pts=frame.pts, + time=frame.time, + dts=frame.dts, + width=frame.width, + height=frame.height, + format=frame.format.name if hasattr(frame, "format") and frame.format else None, + ) + + def to_ndarray(self, format=None): # type: ignore[no-untyped-def] + return self.data + + +class UnitreeWebRTCConnection(Resource): + def __init__(self, ip: str, mode: str = "ai") -> None: + self.ip = ip + self.mode = mode + self.stop_timer: threading.Timer | None = None + self.cmd_vel_timeout = 0.2 + self.conn = Go2WebRTCConnection(WebRTCConnectionMethod.LocalSTA, ip=self.ip) + self.connect() + + def connect(self) -> None: + self.loop = asyncio.new_event_loop() + self.task = None + self.connected_event = asyncio.Event() + self.connection_ready = threading.Event() + + async def async_connect() -> None: + await self.conn.connect() + await self.conn.datachannel.disableTrafficSaving(True) + + self.conn.datachannel.set_decoder(decoder_type="native") + + await self.conn.datachannel.pub_sub.publish_request_new( + RTC_TOPIC["MOTION_SWITCHER"], {"api_id": 1002, "parameter": {"name": self.mode}} + ) + + self.connected_event.set() + self.connection_ready.set() + + while True: + await asyncio.sleep(1) + + def start_background_loop() -> None: + asyncio.set_event_loop(self.loop) + self.task = self.loop.create_task(async_connect()) + self.loop.run_forever() + + self.loop = asyncio.new_event_loop() + self.thread = threading.Thread(target=start_background_loop, daemon=True) + self.thread.start() + self.connection_ready.wait() + + def start(self) -> None: + pass + + def stop(self) -> None: + # Cancel timer + if self.stop_timer: + self.stop_timer.cancel() + self.stop_timer = None + + if self.task: + self.task.cancel() + + async def async_disconnect() -> None: + try: + # Send stop command directly since we're already in the event loop. + self.conn.datachannel.pub_sub.publish_without_callback( + RTC_TOPIC["WIRELESS_CONTROLLER"], + data={"lx": 0, "ly": 0, "rx": 0, "ry": 0}, + ) + await self.conn.disconnect() + except Exception: + pass + + if self.loop.is_running(): + asyncio.run_coroutine_threadsafe(async_disconnect(), self.loop) + + self.loop.call_soon_threadsafe(self.loop.stop) + + if self.thread.is_alive(): + self.thread.join(timeout=2.0) + + def move(self, twist: Twist, duration: float = 0.0) -> bool: + """Send movement command to the robot using Twist commands. + + Args: + twist: Twist message with linear and angular velocities + duration: How long to move (seconds). If 0, command is continuous + + Returns: + bool: True if command was sent successfully + """ + x, y, yaw = twist.linear.x, twist.linear.y, twist.angular.z + + # WebRTC coordinate mapping: + # x - Positive right, negative left + # y - positive forward, negative backwards + # yaw - Positive rotate right, negative rotate left + async def async_move() -> None: + self.conn.datachannel.pub_sub.publish_without_callback( + RTC_TOPIC["WIRELESS_CONTROLLER"], + data={"lx": -y, "ly": x, "rx": -yaw, "ry": 0}, + ) + + async def async_move_duration() -> None: + """Send movement commands continuously for the specified duration.""" + start_time = time.time() + sleep_time = 0.01 + + while time.time() - start_time < duration: + await async_move() + await asyncio.sleep(sleep_time) + + # Cancel existing timer and start a new one + if self.stop_timer: + self.stop_timer.cancel() + + # Auto-stop after 0.5 seconds if no new commands + self.stop_timer = threading.Timer(self.cmd_vel_timeout, self.stop) + self.stop_timer.daemon = True + self.stop_timer.start() + + try: + if duration > 0: + # Send continuous move commands for the duration + future = asyncio.run_coroutine_threadsafe(async_move_duration(), self.loop) + future.result() + # Stop after duration + self.stop() + else: + # Single command for continuous movement + future = asyncio.run_coroutine_threadsafe(async_move(), self.loop) + future.result() + return True + except Exception as e: + print(f"Failed to send movement command: {e}") + return False + + # Generic conversion of unitree subscription to Subject (used for all subs) + def unitree_sub_stream(self, topic_name: str): # type: ignore[no-untyped-def] + def subscribe_in_thread(cb) -> None: # type: ignore[no-untyped-def] + # Run the subscription in the background thread that has the event loop + def run_subscription() -> None: + self.conn.datachannel.pub_sub.subscribe(topic_name, cb) + + # Use call_soon_threadsafe to run in the background thread + self.loop.call_soon_threadsafe(run_subscription) + + def unsubscribe_in_thread(cb) -> None: # type: ignore[no-untyped-def] + # Run the unsubscription in the background thread that has the event loop + def run_unsubscription() -> None: + self.conn.datachannel.pub_sub.unsubscribe(topic_name) + + # Use call_soon_threadsafe to run in the background thread + self.loop.call_soon_threadsafe(run_unsubscription) + + return callback_to_observable( + start=subscribe_in_thread, + stop=unsubscribe_in_thread, + ) + + # Generic sync API call (we jump into the client thread) + def publish_request(self, topic: str, data: dict): # type: ignore[no-untyped-def, type-arg] + future = asyncio.run_coroutine_threadsafe( + self.conn.datachannel.pub_sub.publish_request_new(topic, data), self.loop + ) + return future.result() + + @simple_mcache + def raw_lidar_stream(self) -> Observable[LidarMessage]: + return backpressure(self.unitree_sub_stream(RTC_TOPIC["ULIDAR_ARRAY"])) + + @simple_mcache + def raw_odom_stream(self) -> Observable[Pose]: + return backpressure(self.unitree_sub_stream(RTC_TOPIC["ROBOTODOM"])) + + @simple_mcache + def lidar_stream(self) -> Observable[LidarMessage]: + return backpressure( + self.raw_lidar_stream().pipe( + ops.map(lambda raw_frame: LidarMessage.from_msg(raw_frame, ts=time.time())) # type: ignore[arg-type] + ) + ) + + @simple_mcache + def tf_stream(self) -> Observable[Transform]: + base_link = functools.partial(Transform.from_pose, "base_link") + return backpressure(self.odom_stream().pipe(ops.map(base_link))) + + @simple_mcache + def odom_stream(self) -> Observable[Pose]: + return backpressure(self.raw_odom_stream().pipe(ops.map(Odometry.from_msg))) + + @simple_mcache + def video_stream(self) -> Observable[Image]: + return backpressure( + self.raw_video_stream().pipe( + ops.filter(lambda frame: frame is not None), + ops.map( + lambda frame: Image.from_numpy( + # np.ascontiguousarray(frame.to_ndarray("rgb24")), + frame.to_ndarray(format="rgb24"), # type: ignore[attr-defined] + frame_id="camera_optical", + ) + ), + ) + ) + + @simple_mcache + def lowstate_stream(self) -> Observable[LowStateMsg]: + return backpressure(self.unitree_sub_stream(RTC_TOPIC["LOW_STATE"])) + + def standup_ai(self) -> bool: + return self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["BalanceStand"]}) # type: ignore[no-any-return] + + def standup_normal(self) -> bool: + self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["StandUp"]}) + time.sleep(0.5) + self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["RecoveryStand"]}) + return True + + @rpc + def standup(self) -> bool: + if self.mode == "ai": + return self.standup_ai() + else: + return self.standup_normal() + + @rpc + def liedown(self) -> bool: + return self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["StandDown"]}) # type: ignore[no-any-return] + + async def handstand(self): # type: ignore[no-untyped-def] + return self.publish_request( + RTC_TOPIC["SPORT_MOD"], + {"api_id": SPORT_CMD["Standup"], "parameter": {"data": True}}, + ) + + @rpc + def color(self, color: VUI_COLOR = VUI_COLOR.RED, colortime: int = 60) -> bool: + return self.publish_request( # type: ignore[no-any-return] + RTC_TOPIC["VUI"], + { + "api_id": 1001, + "parameter": { + "color": color, + "time": colortime, + }, + }, + ) + + @simple_mcache + def raw_video_stream(self) -> Observable[VideoMessage]: + subject: Subject[VideoMessage] = Subject() + stop_event = threading.Event() + + async def accept_track(track: MediaStreamTrack) -> None: + while True: + if stop_event.is_set(): + return + frame = await track.recv() + serializable_frame = SerializableVideoFrame.from_av_frame(frame) # type: ignore[no-untyped-call] + subject.on_next(serializable_frame) + + self.conn.video.add_track_callback(accept_track) + + # Run the video channel switching in the background thread + def switch_video_channel() -> None: + self.conn.video.switchVideoChannel(True) + + self.loop.call_soon_threadsafe(switch_video_channel) + + def stop() -> None: + stop_event.set() # Signal the loop to stop + self.conn.video.track_callbacks.remove(accept_track) + + # Run the video channel switching off in the background thread + def switch_video_channel_off() -> None: + self.conn.video.switchVideoChannel(False) + + self.loop.call_soon_threadsafe(switch_video_channel_off) + + return subject.pipe(ops.finally_action(stop)) + + def get_video_stream(self, fps: int = 30) -> Observable[VideoMessage]: + """Get the video stream from the robot's camera. + + Implements the AbstractRobot interface method. + + Args: + fps: Frames per second. This parameter is included for API compatibility, + but doesn't affect the actual frame rate which is determined by the camera. + + Returns: + Observable: An observable stream of video frames or None if video is not available. + """ + return self.video_stream() # type: ignore[no-any-return] + + def stop(self) -> bool: # type: ignore[no-redef] + """Stop the robot's movement. + + Returns: + bool: True if stop command was sent successfully + """ + # Cancel timer since we're explicitly stopping + if self.stop_timer: + self.stop_timer.cancel() + self.stop_timer = None + return True + + def disconnect(self) -> None: + """Disconnect from the robot and clean up resources.""" + # Cancel timer + if self.stop_timer: + self.stop_timer.cancel() + self.stop_timer = None + + if hasattr(self, "task") and self.task: + self.task.cancel() + if hasattr(self, "conn"): + + async def async_disconnect() -> None: + try: + await self.conn.disconnect() + except: + pass + + if hasattr(self, "loop") and self.loop.is_running(): + asyncio.run_coroutine_threadsafe(async_disconnect(), self.loop) + + if hasattr(self, "loop") and self.loop.is_running(): + self.loop.call_soon_threadsafe(self.loop.stop) + + if hasattr(self, "thread") and self.thread.is_alive(): + self.thread.join(timeout=2.0) diff --git a/dimos/robot/unitree/connection/g1.py b/dimos/robot/unitree/connection/g1.py new file mode 100644 index 0000000000..189b9a25b4 --- /dev/null +++ b/dimos/robot/unitree/connection/g1.py @@ -0,0 +1,102 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Any + +from reactivex.disposable import Disposable + +from dimos import spec +from dimos.core import DimosCluster, In, Module, rpc +from dimos.core.global_config import GlobalConfig +from dimos.msgs.geometry_msgs import Twist +from dimos.robot.unitree.connection.connection import UnitreeWebRTCConnection +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class G1Connection(Module): + cmd_vel: In[Twist] = None # type: ignore + ip: str | None + connection_type: str | None = None + _global_config: GlobalConfig + + connection: UnitreeWebRTCConnection | None + + def __init__( + self, + ip: str | None = None, + connection_type: str | None = None, + global_config: GlobalConfig | None = None, + *args: Any, + **kwargs: Any, + ) -> None: + self._global_config = global_config or GlobalConfig() + self.ip = ip if ip is not None else self._global_config.robot_ip + self.connection_type = connection_type or self._global_config.unitree_connection_type + self.connection = None + super().__init__(*args, **kwargs) + + @rpc + def start(self) -> None: + super().start() + + match self.connection_type: + case "webrtc": + assert self.ip is not None, "IP address must be provided" + self.connection = UnitreeWebRTCConnection(self.ip) + case "replay": + raise ValueError("Replay connection not implemented for G1 robot") + case "mujoco": + raise ValueError( + "This module does not support simulation, use G1SimConnection instead" + ) + case _: + raise ValueError(f"Unknown connection type: {self.connection_type}") + + assert self.connection is not None + self.connection.start() + + self._disposables.add(Disposable(self.cmd_vel.subscribe(self.move))) + + @rpc + def stop(self) -> None: + assert self.connection is not None + self.connection.stop() + super().stop() + + @rpc + def move(self, twist: Twist, duration: float = 0.0) -> None: + assert self.connection is not None + self.connection.move(twist, duration) + + @rpc + def publish_request(self, topic: str, data: dict[str, Any]) -> dict[Any, Any]: + logger.info(f"Publishing request to topic: {topic} with data: {data}") + assert self.connection is not None + return self.connection.publish_request(topic, data) # type: ignore[no-any-return] + + +g1_connection = G1Connection.blueprint + + +def deploy(dimos: DimosCluster, ip: str, local_planner: spec.LocalPlanner) -> G1Connection: + connection = dimos.deploy(G1Connection, ip) # type: ignore[attr-defined] + connection.cmd_vel.connect(local_planner.cmd_vel) + connection.start() + return connection # type: ignore[no-any-return] + + +__all__ = ["G1Connection", "deploy", "g1_connection"] diff --git a/dimos/robot/unitree/connection/g1sim.py b/dimos/robot/unitree/connection/g1sim.py new file mode 100644 index 0000000000..1fc7437c91 --- /dev/null +++ b/dimos/robot/unitree/connection/g1sim.py @@ -0,0 +1,128 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import time +from typing import TYPE_CHECKING, Any + +from reactivex.disposable import Disposable + +from dimos.core import In, Module, Out, rpc +from dimos.core.global_config import GlobalConfig +from dimos.msgs.geometry_msgs import ( + PoseStamped, + Quaternion, + Transform, + Twist, + Vector3, +) +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.odometry import Odometry as SimOdometry +from dimos.utils.logging_config import setup_logger + +if TYPE_CHECKING: + from dimos.robot.unitree_webrtc.mujoco_connection import MujocoConnection + +logger = setup_logger() + + +class G1SimConnection(Module): + cmd_vel: In[Twist] = None # type: ignore + lidar: Out[LidarMessage] = None # type: ignore + odom: Out[PoseStamped] = None # type: ignore + ip: str | None + _global_config: GlobalConfig + + def __init__( + self, + ip: str | None = None, + global_config: GlobalConfig | None = None, + *args: Any, + **kwargs: Any, + ) -> None: + self._global_config = global_config or GlobalConfig() + self.ip = ip if ip is not None else self._global_config.robot_ip + self.connection: MujocoConnection | None = None + super().__init__(*args, **kwargs) + + @rpc + def start(self) -> None: + super().start() + + from dimos.robot.unitree_webrtc.mujoco_connection import MujocoConnection + + self.connection = MujocoConnection(self._global_config) + assert self.connection is not None + self.connection.start() + + self._disposables.add(Disposable(self.cmd_vel.subscribe(self.move))) + self._disposables.add(self.connection.odom_stream().subscribe(self._publish_sim_odom)) + self._disposables.add(self.connection.lidar_stream().subscribe(self.lidar.publish)) + + @rpc + def stop(self) -> None: + assert self.connection is not None + self.connection.stop() + super().stop() + + def _publish_tf(self, msg: PoseStamped) -> None: + self.odom.publish(msg) + + self.tf.publish(Transform.from_pose("base_link", msg)) + + # Publish camera_link transform + camera_link = Transform( + translation=Vector3(0.3, 0.0, 0.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id="base_link", + child_frame_id="camera_link", + ts=time.time(), + ) + + map_to_world = Transform( + translation=Vector3(0.0, 0.0, 0.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id="map", + child_frame_id="world", + ts=time.time(), + ) + + self.tf.publish(camera_link, map_to_world) + + def _publish_sim_odom(self, msg: SimOdometry) -> None: + self._publish_tf( + PoseStamped( + ts=msg.ts, + frame_id=msg.frame_id, + position=msg.position, + orientation=msg.orientation, + ) + ) + + @rpc + def move(self, twist: Twist, duration: float = 0.0) -> None: + assert self.connection is not None + self.connection.move(twist, duration) + + @rpc + def publish_request(self, topic: str, data: dict[str, Any]) -> dict[Any, Any]: + logger.info(f"Publishing request to topic: {topic} with data: {data}") + assert self.connection is not None + return self.connection.publish_request(topic, data) + + +g1_sim_connection = G1SimConnection.blueprint + + +__all__ = ["G1SimConnection", "g1_sim_connection"] diff --git a/dimos/robot/unitree/connection/go2.py b/dimos/robot/unitree/connection/go2.py new file mode 100644 index 0000000000..aee5122522 --- /dev/null +++ b/dimos/robot/unitree/connection/go2.py @@ -0,0 +1,314 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from threading import Thread +import time +from typing import Any, Protocol + +from dimos_lcm.sensor_msgs import CameraInfo # type: ignore[import-untyped] +from reactivex.disposable import Disposable +from reactivex.observable import Observable + +from dimos import spec +from dimos.core import DimosCluster, In, LCMTransport, Module, Out, pSHMTransport, rpc +from dimos.core.global_config import GlobalConfig +from dimos.msgs.geometry_msgs import ( + PoseStamped, + Quaternion, + Transform, + Twist, + Vector3, +) +from dimos.msgs.sensor_msgs import Image, PointCloud2 +from dimos.msgs.std_msgs import Header +from dimos.robot.unitree.connection.connection import UnitreeWebRTCConnection +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.utils.data import get_data +from dimos.utils.decorators.decorators import simple_mcache +from dimos.utils.logging_config import setup_logger +from dimos.utils.testing import TimedSensorReplay + +logger = setup_logger(level=logging.INFO) + + +class Go2ConnectionProtocol(Protocol): + """Protocol defining the interface for Go2 robot connections.""" + + def start(self) -> None: ... + def stop(self) -> None: ... + def lidar_stream(self) -> Observable: ... # type: ignore[type-arg] + def odom_stream(self) -> Observable: ... # type: ignore[type-arg] + def video_stream(self) -> Observable: ... # type: ignore[type-arg] + def move(self, twist: Twist, duration: float = 0.0) -> bool: ... + def standup(self) -> bool: ... + def liedown(self) -> bool: ... + def publish_request(self, topic: str, data: dict) -> dict: ... # type: ignore[type-arg] + + +def _camera_info_static() -> CameraInfo: + fx, fy, cx, cy = (819.553492, 820.646595, 625.284099, 336.808987) + width, height = (1280, 720) + + # Camera matrix K (3x3) + K = [fx, 0, cx, 0, fy, cy, 0, 0, 1] + + # No distortion coefficients for now + D = [0.0, 0.0, 0.0, 0.0, 0.0] + + # Identity rotation matrix + R = [1, 0, 0, 0, 1, 0, 0, 0, 1] + + # Projection matrix P (3x4) + P = [fx, 0, cx, 0, 0, fy, cy, 0, 0, 0, 1, 0] + + base_msg = { + "D_length": len(D), + "height": height, + "width": width, + "distortion_model": "plumb_bob", + "D": D, + "K": K, + "R": R, + "P": P, + "binning_x": 0, + "binning_y": 0, + } + + return CameraInfo(**base_msg, header=Header("camera_optical")) + + +class ReplayConnection(UnitreeWebRTCConnection): + dir_name = "unitree_go2_office_walk2" + + # we don't want UnitreeWebRTCConnection to init + def __init__( # type: ignore[no-untyped-def] + self, + **kwargs, + ) -> None: + get_data(self.dir_name) + self.replay_config = { + "loop": kwargs.get("loop"), + "seek": kwargs.get("seek"), + "duration": kwargs.get("duration"), + } + + def connect(self) -> None: + pass + + def start(self) -> None: + pass + + def standup(self) -> bool: + return True + + def liedown(self) -> bool: + return True + + @simple_mcache + def lidar_stream(self): # type: ignore[no-untyped-def] + lidar_store = TimedSensorReplay(f"{self.dir_name}/lidar") # type: ignore[var-annotated] + return lidar_store.stream(**self.replay_config) # type: ignore[arg-type] + + @simple_mcache + def odom_stream(self): # type: ignore[no-untyped-def] + odom_store = TimedSensorReplay(f"{self.dir_name}/odom") # type: ignore[var-annotated] + return odom_store.stream(**self.replay_config) # type: ignore[arg-type] + + # we don't have raw video stream in the data set + @simple_mcache + def video_stream(self): # type: ignore[no-untyped-def] + video_store = TimedSensorReplay(f"{self.dir_name}/video") # type: ignore[var-annotated] + return video_store.stream(**self.replay_config) # type: ignore[arg-type] + + def move(self, twist: Twist, duration: float = 0.0) -> bool: + return True + + def publish_request(self, topic: str, data: dict): # type: ignore[no-untyped-def, type-arg] + """Fake publish request for testing.""" + return {"status": "ok", "message": "Fake publish"} + + +class GO2Connection(Module, spec.Camera, spec.Pointcloud): + cmd_vel: In[Twist] = None # type: ignore + pointcloud: Out[PointCloud2] = None # type: ignore + odom: Out[PoseStamped] = None # type: ignore + lidar: Out[LidarMessage] = None # type: ignore + color_image: Out[Image] = None # type: ignore + camera_info: Out[CameraInfo] = None # type: ignore + + connection: Go2ConnectionProtocol + camera_info_static: CameraInfo = _camera_info_static() + _global_config: GlobalConfig + _camera_info_thread: Thread | None = None + + def __init__( # type: ignore[no-untyped-def] + self, + ip: str | None = None, + global_config: GlobalConfig | None = None, + *args, + **kwargs, + ) -> None: + self._global_config = global_config or GlobalConfig() + + ip = ip if ip is not None else self._global_config.robot_ip + + connection_type = self._global_config.unitree_connection_type + + if ip in ["fake", "mock", "replay"] or connection_type == "replay": + self.connection = ReplayConnection() + elif ip == "mujoco" or connection_type == "mujoco": + from dimos.robot.unitree_webrtc.mujoco_connection import MujocoConnection + + self.connection = MujocoConnection(self._global_config) + else: + assert ip is not None, "IP address must be provided" + self.connection = UnitreeWebRTCConnection(ip) + + Module.__init__(self, *args, **kwargs) + + @rpc + def start(self) -> None: + super().start() + + self.connection.start() + + self._disposables.add(self.connection.lidar_stream().subscribe(self.lidar.publish)) + self._disposables.add(self.connection.odom_stream().subscribe(self._publish_tf)) + self._disposables.add(self.connection.video_stream().subscribe(self.color_image.publish)) + self._disposables.add(Disposable(self.cmd_vel.subscribe(self.move))) + + self._camera_info_thread = Thread( + target=self.publish_camera_info, + daemon=True, + ) + self._camera_info_thread.start() + + self.standup() + + @rpc + def stop(self) -> None: + self.liedown() + + if self.connection: + self.connection.stop() + + if self._camera_info_thread and self._camera_info_thread.is_alive(): + self._camera_info_thread.join(timeout=1.0) + + super().stop() + + @classmethod + def _odom_to_tf(cls, odom: PoseStamped) -> list[Transform]: + camera_link = Transform( + translation=Vector3(0.3, 0.0, 0.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id="base_link", + child_frame_id="camera_link", + ts=odom.ts, + ) + + camera_optical = Transform( + translation=Vector3(0.0, 0.0, 0.0), + rotation=Quaternion(-0.5, 0.5, -0.5, 0.5), + frame_id="camera_link", + child_frame_id="camera_optical", + ts=odom.ts, + ) + + sensor = Transform( + translation=Vector3(0.0, 0.0, 0.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id="world", + child_frame_id="sensor", + ts=odom.ts, + ) + + map_to_world = Transform( + translation=Vector3(0.0, 0.0, 0.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id="map", + child_frame_id="world", + ts=time.time(), + ) + + return [ + Transform.from_pose("base_link", odom), + camera_link, + camera_optical, + sensor, + map_to_world, + ] + + def _publish_tf(self, msg: PoseStamped) -> None: + self.tf.publish(*self._odom_to_tf(msg)) + if self.odom.transport: + self.odom.publish(msg) + + def publish_camera_info(self) -> None: + while True: + self.camera_info.publish(_camera_info_static()) + time.sleep(1.0) + + @rpc + def move(self, twist: Twist, duration: float = 0.0) -> bool: + """Send movement command to robot.""" + return self.connection.move(twist, duration) + + @rpc + def standup(self) -> bool: + """Make the robot stand up.""" + return self.connection.standup() + + @rpc + def liedown(self) -> bool: + """Make the robot lie down.""" + return self.connection.liedown() + + @rpc + def publish_request(self, topic: str, data: dict[str, Any]) -> dict[Any, Any]: + """Publish a request to the WebRTC connection. + Args: + topic: The RTC topic to publish to + data: The data dictionary to publish + Returns: + The result of the publish request + """ + return self.connection.publish_request(topic, data) + + +go2_connection = GO2Connection.blueprint + + +def deploy(dimos: DimosCluster, ip: str, prefix: str = "") -> GO2Connection: + from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE + + connection = dimos.deploy(GO2Connection, ip) # type: ignore[attr-defined] + + connection.pointcloud.transport = pSHMTransport( + f"{prefix}/lidar", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE + ) + connection.color_image.transport = pSHMTransport( + f"{prefix}/image", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE + ) + + connection.cmd_vel.transport = LCMTransport(f"{prefix}/cmd_vel", Twist) + + connection.camera_info.transport = LCMTransport(f"{prefix}/camera_info", CameraInfo) + connection.start() + + return connection # type: ignore[no-any-return] + + +__all__ = ["GO2Connection", "deploy", "go2_connection"] diff --git a/dimos/robot/unitree/g1/g1agent.py b/dimos/robot/unitree/g1/g1agent.py new file mode 100644 index 0000000000..b545966f35 --- /dev/null +++ b/dimos/robot/unitree/g1/g1agent.py @@ -0,0 +1,48 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos import agents2 +from dimos.agents2.skills.navigation import NavigationSkillContainer +from dimos.core import DimosCluster +from dimos.perception import spatial_perception +from dimos.robot.unitree.g1 import g1detector + + +def deploy(dimos: DimosCluster, ip: str): # type: ignore[no-untyped-def] + g1 = g1detector.deploy(dimos, ip) + + nav = g1.get("nav") + camera = g1.get("camera") + detector3d = g1.get("detector3d") + connection = g1.get("connection") + + spatialmem = spatial_perception.deploy(dimos, camera) + + navskills = dimos.deploy( # type: ignore[attr-defined] + NavigationSkillContainer, + spatialmem, + nav, + detector3d, + ) + navskills.start() + + agent = agents2.deploy( # type: ignore[attr-defined] + dimos, + "You are controling a humanoid robot", + skill_containers=[connection, nav, camera, spatialmem, navskills], + ) + agent.run_implicit_skill("current_position") + agent.run_implicit_skill("video_stream") + + return {"agent": agent, "spatialmem": spatialmem, **g1} diff --git a/dimos/robot/unitree/g1/g1detector.py b/dimos/robot/unitree/g1/g1detector.py new file mode 100644 index 0000000000..ca549025af --- /dev/null +++ b/dimos/robot/unitree/g1/g1detector.py @@ -0,0 +1,41 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.core import DimosCluster +from dimos.perception.detection import module3D, moduleDB +from dimos.perception.detection.detectors.person.yolo import YoloPersonDetector +from dimos.robot.unitree.g1 import g1zed + + +def deploy(dimos: DimosCluster, ip: str): # type: ignore[no-untyped-def] + g1 = g1zed.deploy(dimos, ip) + + nav = g1.get("nav") + camera = g1.get("camera") + + person_detector = module3D.deploy( + dimos, + camera=camera, + lidar=nav, + detector=YoloPersonDetector, + ) + + detector3d = moduleDB.deploy( + dimos, + camera=camera, + lidar=nav, + filter=lambda det: det.class_id != 0, + ) + + return {"person_detector": person_detector, "detector3d": detector3d, **g1} diff --git a/dimos/robot/unitree/g1/g1zed.py b/dimos/robot/unitree/g1/g1zed.py new file mode 100644 index 0000000000..20034ecdba --- /dev/null +++ b/dimos/robot/unitree/g1/g1zed.py @@ -0,0 +1,90 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TypedDict, cast + +from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE +from dimos.core import DimosCluster, LCMTransport, pSHMTransport +from dimos.hardware.camera import zed +from dimos.hardware.camera.module import CameraModule +from dimos.hardware.camera.webcam import Webcam +from dimos.msgs.geometry_msgs import ( + Quaternion, + Transform, + Vector3, +) +from dimos.msgs.sensor_msgs import CameraInfo +from dimos.navigation import rosnav +from dimos.navigation.rosnav import ROSNav +from dimos.robot import foxglove_bridge +from dimos.robot.unitree.connection import g1 +from dimos.robot.unitree.connection.g1 import G1Connection +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class G1ZedDeployResult(TypedDict): + nav: ROSNav + connection: G1Connection + camera: CameraModule + camerainfo: CameraInfo + + +def deploy_g1_monozed(dimos: DimosCluster) -> CameraModule: + camera = cast( + "CameraModule", + dimos.deploy( # type: ignore[attr-defined] + CameraModule, + frequency=4.0, + transform=Transform( + translation=Vector3(0.05, 0.0, 0.0), + rotation=Quaternion.from_euler(Vector3(0.0, 0.0, 0.0)), + frame_id="sensor", + child_frame_id="camera_link", + ), + hardware=lambda: Webcam( + camera_index=0, + frequency=5, + stereo_slice="left", + camera_info=zed.CameraInfo.SingleWebcam, + ), + ), + ) + + camera.color_image.transport = pSHMTransport( + "/image", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE + ) + camera.camera_info.transport = LCMTransport("/camera_info", CameraInfo) + camera.start() + return camera + + +def deploy(dimos: DimosCluster, ip: str): # type: ignore[no-untyped-def] + nav = rosnav.deploy( # type: ignore[call-arg] + dimos, + sensor_to_base_link_transform=Transform( + frame_id="sensor", child_frame_id="base_link", translation=Vector3(0.0, 0.0, 1.5) + ), + ) + connection = g1.deploy(dimos, ip, nav) + zedcam = deploy_g1_monozed(dimos) + + foxglove_bridge.deploy(dimos) + + return { + "nav": nav, + "connection": connection, + "camera": zedcam, + } diff --git a/dimos/robot/unitree/go2/go2.py b/dimos/robot/unitree/go2/go2.py new file mode 100644 index 0000000000..9ee6379df9 --- /dev/null +++ b/dimos/robot/unitree/go2/go2.py @@ -0,0 +1,37 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from dimos.core import DimosCluster +from dimos.robot import foxglove_bridge +from dimos.robot.unitree.connection import go2 +from dimos.utils.logging_config import setup_logger + +logger = setup_logger(level=logging.INFO) + + +def deploy(dimos: DimosCluster, ip: str): # type: ignore[no-untyped-def] + connection = go2.deploy(dimos, ip) + foxglove_bridge.deploy(dimos) + + # detector = moduleDB.deploy( + # dimos, + # camera=connection, + # lidar=connection, + # ) + + # agent = agents2.deploy(dimos) + # agent.register_skills(detector) + return connection diff --git a/dimos/robot/unitree/run.py b/dimos/robot/unitree/run.py new file mode 100644 index 0000000000..5b17ad7a9d --- /dev/null +++ b/dimos/robot/unitree/run.py @@ -0,0 +1,115 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Centralized runner for modular Unitree robot deployment scripts. + +Usage: + python run.py g1agent --ip 192.168.1.100 + python run.py g1/g1zed --ip $ROBOT_IP + python run.py go2/go2.py --ip $ROBOT_IP + python run.py connection/g1.py --ip $ROBOT_IP +""" + +import argparse +import importlib +import os +import sys + +from dotenv import load_dotenv + +from dimos.core import start, wait_exit + + +def main() -> None: + load_dotenv() + + parser = argparse.ArgumentParser(description="Unitree Robot Modular Deployment Runner") + parser.add_argument( + "module", + help="Module name/path to run (e.g., g1agent, g1/g1zed, go2/go2.py)", + ) + parser.add_argument( + "--ip", + default=os.getenv("ROBOT_IP"), + help="Robot IP address (default: ROBOT_IP from .env)", + ) + parser.add_argument( + "--workers", + type=int, + default=8, + help="Number of worker threads for DimosCluster (default: 8)", + ) + + args = parser.parse_args() + + # Validate IP address + if not args.ip: + print("ERROR: Robot IP address not provided") + print("Please provide --ip or set ROBOT_IP in .env") + sys.exit(1) + + # Parse the module path + module_path = args.module + + # Remove .py extension if present + if module_path.endswith(".py"): + module_path = module_path[:-3] + + # Convert path separators to dots for import + module_path = module_path.replace("/", ".") + + # Import the module + try: + # Build the full import path + full_module_path = f"dimos.robot.unitree.{module_path}" + print(f"Importing module: {full_module_path}") + module = importlib.import_module(full_module_path) + except ImportError: + # Try as a relative import from the unitree package + try: + module = importlib.import_module(f".{module_path}", package="dimos.robot.unitree") + except ImportError as e2: + import traceback + + traceback.print_exc() + + print(f"\nERROR: Could not import module '{args.module}'") + print("Tried importing as:") + print(f" 1. {full_module_path}") + print(" 2. Relative import from dimos.robot.unitree") + print("Make sure the module exists in dimos/robot/unitree/") + print(f"Import error: {e2}") + + sys.exit(1) + + # Verify deploy function exists + if not hasattr(module, "deploy"): + print(f"ERROR: Module '{args.module}' does not have a 'deploy' function") + sys.exit(1) + + print(f"Running {args.module}.deploy() with IP {args.ip}") + + # Run the standard deployment pattern + dimos = start(args.workers) + try: + module.deploy(dimos, args.ip) + wait_exit() + finally: + dimos.close_all() # type: ignore[attr-defined] + + +if __name__ == "__main__": + main() diff --git a/tests/data/database.db-wal b/dimos/robot/unitree_webrtc/__init__.py similarity index 100% rename from tests/data/database.db-wal rename to dimos/robot/unitree_webrtc/__init__.py diff --git a/dimos/robot/unitree_webrtc/demo_error_on_name_conflicts.py b/dimos/robot/unitree_webrtc/demo_error_on_name_conflicts.py new file mode 100644 index 0000000000..4fad0a8714 --- /dev/null +++ b/dimos/robot/unitree_webrtc/demo_error_on_name_conflicts.py @@ -0,0 +1,53 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.core.blueprints import autoconnect +from dimos.core.core import rpc +from dimos.core.module import Module +from dimos.core.stream import In, Out + + +class Data1: + pass + + +class Data2: + pass + + +class ModuleA(Module): + shared_data: Out[Data1] = None # type: ignore[assignment] + + @rpc + def start(self) -> None: + super().start() + + @rpc + def stop(self) -> None: + super().stop() + + +class ModuleB(Module): + shared_data: In[Data2] = None # type: ignore[assignment] + + @rpc + def start(self) -> None: + super().start() + + @rpc + def stop(self) -> None: + super().stop() + + +blueprint = autoconnect(ModuleA.blueprint(), ModuleB.blueprint()) diff --git a/dimos/robot/unitree_webrtc/depth_module.py b/dimos/robot/unitree_webrtc/depth_module.py new file mode 100644 index 0000000000..6e9491b458 --- /dev/null +++ b/dimos/robot/unitree_webrtc/depth_module.py @@ -0,0 +1,243 @@ +#!/usr/bin/env python3 + +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import threading +import time + +from dimos_lcm.sensor_msgs import CameraInfo # type: ignore[import-untyped] +import numpy as np + +from dimos.core import In, Module, Out, rpc +from dimos.core.global_config import GlobalConfig +from dimos.msgs.sensor_msgs import Image, ImageFormat +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class DepthModule(Module): + """ + Depth module for Unitree Go2 that processes RGB images to generate depth using Metric3D. + + Subscribes to: + - /go2/color_image: RGB camera images from Unitree + - /go2/camera_info: Camera calibration information + + Publishes: + - /go2/depth_image: Depth images generated by Metric3D + """ + + # LCM inputs + color_image: In[Image] = None # type: ignore[assignment] + camera_info: In[CameraInfo] = None # type: ignore[assignment] + + # LCM outputs + depth_image: Out[Image] = None # type: ignore[assignment] + + def __init__( # type: ignore[no-untyped-def] + self, + gt_depth_scale: float = 0.5, + global_config: GlobalConfig | None = None, + **kwargs, + ) -> None: + """ + Initialize Depth Module. + + Args: + gt_depth_scale: Ground truth depth scaling factor + """ + super().__init__(**kwargs) + + self.camera_intrinsics = None + self.gt_depth_scale = gt_depth_scale + self.metric3d = None + self._camera_info_received = False + + # Processing state + self._running = False + self._latest_frame = None + self._last_image = None + self._last_timestamp = None + self._last_depth = None + self._cannot_process_depth = False + + # Threading + self._processing_thread: threading.Thread | None = None + self._stop_processing = threading.Event() + + if global_config: + if global_config.simulation: + self.gt_depth_scale = 1.0 + + @rpc + def start(self) -> None: + super().start() + + if self._running: + logger.warning("Camera module already running") + return + + # Set running flag before starting + self._running = True + + # Subscribe to video and camera info inputs + self.color_image.subscribe(self._on_video) + self.camera_info.subscribe(self._on_camera_info) + + # Start processing thread + self._start_processing_thread() + + logger.info("Depth module started") + + @rpc + def stop(self) -> None: + if not self._running: + return + + self._running = False + self._stop_processing.set() + + # Wait for thread to finish + if self._processing_thread and self._processing_thread.is_alive(): + self._processing_thread.join(timeout=2.0) + + super().stop() + + def _on_camera_info(self, msg: CameraInfo) -> None: + """Process camera info to extract intrinsics.""" + if self.metric3d is not None: + return # Already initialized + + try: + # Extract intrinsics from camera matrix K + K = msg.K + fx = K[0] + fy = K[4] + cx = K[2] + cy = K[5] + + self.camera_intrinsics = [fx, fy, cx, cy] # type: ignore[assignment] + + # Initialize Metric3D with camera intrinsics + from dimos.models.depth.metric3d import Metric3D + + self.metric3d = Metric3D(camera_intrinsics=self.camera_intrinsics) # type: ignore[assignment] + self._camera_info_received = True + + logger.info( + f"Initialized Metric3D with intrinsics from camera_info: {self.camera_intrinsics}" + ) + + except Exception as e: + logger.error(f"Error processing camera info: {e}") + + def _on_video(self, msg: Image) -> None: + """Store latest video frame for processing.""" + if not self._running: + return + + # Simply store the latest frame - processing happens in main loop + self._latest_frame = msg # type: ignore[assignment] + logger.debug( + f"Received video frame: format={msg.format}, shape={msg.data.shape if hasattr(msg.data, 'shape') else 'unknown'}" + ) + + def _start_processing_thread(self) -> None: + """Start the processing thread.""" + self._stop_processing.clear() + self._processing_thread = threading.Thread(target=self._main_processing_loop, daemon=True) + self._processing_thread.start() + logger.info("Started depth processing thread") + + def _main_processing_loop(self) -> None: + """Main processing loop that continuously processes latest frames.""" + logger.info("Starting main processing loop") + + while not self._stop_processing.is_set(): + # Process latest frame if available + if self._latest_frame is not None: + try: + msg = self._latest_frame + self._latest_frame = None # Clear to avoid reprocessing + # Store for publishing + self._last_image = msg.data + self._last_timestamp = msg.ts if msg.ts else time.time() + # Process depth + self._process_depth(self._last_image) + + except Exception as e: + logger.error(f"Error in main processing loop: {e}", exc_info=True) + else: + # Small sleep to avoid busy waiting + time.sleep(0.001) + + logger.info("Main processing loop stopped") + + def _process_depth(self, img_array: np.ndarray) -> None: # type: ignore[type-arg] + """Process depth estimation using Metric3D.""" + if self._cannot_process_depth: + self._last_depth = None + return + + # Wait for camera info to initialize Metric3D + if self.metric3d is None: + logger.debug("Waiting for camera_info to initialize Metric3D") + return + + try: + logger.debug(f"Processing depth for image shape: {img_array.shape}") + + # Generate depth map + depth_array = self.metric3d.infer_depth(img_array) * self.gt_depth_scale + + self._last_depth = depth_array + logger.debug(f"Generated depth map shape: {depth_array.shape}") + + self._publish_depth() + + except Exception as e: + logger.error(f"Error processing depth: {e}") + self._cannot_process_depth = True + + def _publish_depth(self) -> None: + """Publish depth image.""" + if not self._running: + return + + try: + # Publish depth image + if self._last_depth is not None: + # Convert depth to uint16 (millimeters) for more efficient storage + # Clamp to valid range [0, 65.535] meters before converting + depth_clamped = np.clip(self._last_depth, 0, 65.535) + depth_uint16 = (depth_clamped * 1000).astype(np.uint16) + depth_msg = Image( + data=depth_uint16, + format=ImageFormat.DEPTH16, # Use DEPTH16 format for uint16 depth + frame_id="camera_link", + ts=self._last_timestamp, + ) + self.depth_image.publish(depth_msg) + logger.debug(f"Published depth image (uint16): shape={depth_uint16.shape}") + + except Exception as e: + logger.error(f"Error publishing depth data: {e}", exc_info=True) + + +depth_module = DepthModule.blueprint + + +__all__ = ["DepthModule", "depth_module"] diff --git a/dimos/robot/unitree_webrtc/keyboard_teleop.py b/dimos/robot/unitree_webrtc/keyboard_teleop.py new file mode 100644 index 0000000000..bbdc724821 --- /dev/null +++ b/dimos/robot/unitree_webrtc/keyboard_teleop.py @@ -0,0 +1,205 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import threading + +import pygame + +from dimos.core import Module, Out, rpc +from dimos.msgs.geometry_msgs import Twist, Vector3 + +# Force X11 driver to avoid OpenGL threading issues +os.environ["SDL_VIDEODRIVER"] = "x11" + + +class KeyboardTeleop(Module): + """Pygame-based keyboard control module. + + Outputs standard Twist messages on /cmd_vel for velocity control. + """ + + cmd_vel: Out[Twist] = None # type: ignore[assignment] # Standard velocity commands + + _stop_event: threading.Event + _keys_held: set[int] | None = None + _thread: threading.Thread | None = None + _screen: pygame.Surface | None = None + _clock: pygame.time.Clock | None = None + _font: pygame.font.Font | None = None + + def __init__(self) -> None: + super().__init__() + self._stop_event = threading.Event() + + @rpc + def start(self) -> bool: + super().start() + + self._keys_held = set() + self._stop_event.clear() + + self._thread = threading.Thread(target=self._pygame_loop, daemon=True) + self._thread.start() + + return True + + @rpc + def stop(self) -> None: + stop_twist = Twist() + stop_twist.linear = Vector3(0, 0, 0) + stop_twist.angular = Vector3(0, 0, 0) + self.cmd_vel.publish(stop_twist) + + self._stop_event.set() + + if self._thread is None: + raise RuntimeError("Cannot stop: thread was never started") + self._thread.join(2) + + super().stop() + + def _pygame_loop(self) -> None: + if self._keys_held is None: + raise RuntimeError("_keys_held not initialized") + + pygame.init() + self._screen = pygame.display.set_mode((500, 400), pygame.SWSURFACE) + pygame.display.set_caption("Keyboard Teleop") + self._clock = pygame.time.Clock() + self._font = pygame.font.Font(None, 24) + + while not self._stop_event.is_set(): + for event in pygame.event.get(): + if event.type == pygame.QUIT: + self._stop_event.set() + elif event.type == pygame.KEYDOWN: + self._keys_held.add(event.key) + + if event.key == pygame.K_SPACE: + # Emergency stop - clear all keys and send zero twist + self._keys_held.clear() + stop_twist = Twist() + stop_twist.linear = Vector3(0, 0, 0) + stop_twist.angular = Vector3(0, 0, 0) + self.cmd_vel.publish(stop_twist) + print("EMERGENCY STOP!") + elif event.key == pygame.K_ESCAPE: + # ESC quits + self._stop_event.set() + + elif event.type == pygame.KEYUP: + self._keys_held.discard(event.key) + + # Generate Twist message from held keys + twist = Twist() + twist.linear = Vector3(0, 0, 0) + twist.angular = Vector3(0, 0, 0) + + # Forward/backward (W/S) + if pygame.K_w in self._keys_held: + twist.linear.x = 0.5 + if pygame.K_s in self._keys_held: + twist.linear.x = -0.5 + + # Strafe left/right (Q/E) + if pygame.K_q in self._keys_held: + twist.linear.y = 0.5 + if pygame.K_e in self._keys_held: + twist.linear.y = -0.5 + + # Turning (A/D) + if pygame.K_a in self._keys_held: + twist.angular.z = 0.8 + if pygame.K_d in self._keys_held: + twist.angular.z = -0.8 + + # Apply speed modifiers (Shift = 2x, Ctrl = 0.5x) + speed_multiplier = 1.0 + if pygame.K_LSHIFT in self._keys_held or pygame.K_RSHIFT in self._keys_held: + speed_multiplier = 2.0 + elif pygame.K_LCTRL in self._keys_held or pygame.K_RCTRL in self._keys_held: + speed_multiplier = 0.5 + + twist.linear.x *= speed_multiplier + twist.linear.y *= speed_multiplier + twist.angular.z *= speed_multiplier + + # Always publish twist at 50Hz + self.cmd_vel.publish(twist) + + self._update_display(twist) + + # Maintain 50Hz rate + if self._clock is None: + raise RuntimeError("_clock not initialized") + self._clock.tick(50) + + pygame.quit() + + def _update_display(self, twist: Twist) -> None: + if self._screen is None or self._font is None or self._keys_held is None: + raise RuntimeError("Not initialized correctly") + + self._screen.fill((30, 30, 30)) + + y_pos = 20 + + # Determine active speed multiplier + speed_mult_text = "" + if pygame.K_LSHIFT in self._keys_held or pygame.K_RSHIFT in self._keys_held: + speed_mult_text = " [BOOST 2x]" + elif pygame.K_LCTRL in self._keys_held or pygame.K_RCTRL in self._keys_held: + speed_mult_text = " [SLOW 0.5x]" + + texts = [ + "Keyboard Teleop" + speed_mult_text, + "", + f"Linear X (Forward/Back): {twist.linear.x:+.2f} m/s", + f"Linear Y (Strafe L/R): {twist.linear.y:+.2f} m/s", + f"Angular Z (Turn L/R): {twist.angular.z:+.2f} rad/s", + "", + "Keys: " + ", ".join([pygame.key.name(k).upper() for k in self._keys_held if k < 256]), + ] + + for text in texts: + if text: + color = (0, 255, 255) if text.startswith("Keyboard Teleop") else (255, 255, 255) + surf = self._font.render(text, True, color) + self._screen.blit(surf, (20, y_pos)) + y_pos += 30 + + if twist.linear.x != 0 or twist.linear.y != 0 or twist.angular.z != 0: + pygame.draw.circle(self._screen, (255, 0, 0), (450, 30), 15) # Red = moving + else: + pygame.draw.circle(self._screen, (0, 255, 0), (450, 30), 15) # Green = stopped + + y_pos = 280 + help_texts = [ + "WS: Move | AD: Turn | QE: Strafe", + "Shift: Boost | Ctrl: Slow", + "Space: E-Stop | ESC: Quit", + ] + for text in help_texts: + surf = self._font.render(text, True, (150, 150, 150)) + self._screen.blit(surf, (20, y_pos)) + y_pos += 25 + + pygame.display.flip() + + +keyboard_teleop = KeyboardTeleop.blueprint + +__all__ = ["KeyboardTeleop", "keyboard_teleop"] diff --git a/dimos/robot/unitree_webrtc/modular/__init__.py b/dimos/robot/unitree_webrtc/modular/__init__.py new file mode 100644 index 0000000000..5c2169cc9b --- /dev/null +++ b/dimos/robot/unitree_webrtc/modular/__init__.py @@ -0,0 +1 @@ +from dimos.robot.unitree_webrtc.modular.connection_module import deploy_connection diff --git a/dimos/robot/unitree_webrtc/modular/connection_module.py b/dimos/robot/unitree_webrtc/modular/connection_module.py new file mode 100644 index 0000000000..36cd2e7b51 --- /dev/null +++ b/dimos/robot/unitree_webrtc/modular/connection_module.py @@ -0,0 +1,339 @@ +#!/usr/bin/env python3 + +#!/usr/bin/env python3 + +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +import functools +import logging +import os +import queue +import warnings + +from dimos_lcm.sensor_msgs import CameraInfo # type: ignore[import-untyped] +import reactivex as rx +from reactivex import operators as ops +from reactivex.observable import Observable + +from dimos.agents2 import Output, Reducer, Stream, skill # type: ignore[attr-defined] +from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE +from dimos.core import DimosCluster, In, LCMTransport, Module, ModuleConfig, Out, pSHMTransport, rpc +from dimos.core.global_config import GlobalConfig +from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Twist, Vector3 +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.std_msgs import Header +from dimos.robot.unitree.connection.connection import UnitreeWebRTCConnection +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.utils.data import get_data +from dimos.utils.logging_config import setup_logger +from dimos.utils.testing import TimedSensorReplay, TimedSensorStorage + +logger = setup_logger(level=logging.INFO) + +# Suppress verbose loggers +logging.getLogger("aiortc.codecs.h264").setLevel(logging.ERROR) +logging.getLogger("lcm_foxglove_bridge").setLevel(logging.ERROR) +logging.getLogger("websockets.server").setLevel(logging.ERROR) +logging.getLogger("FoxgloveServer").setLevel(logging.ERROR) +logging.getLogger("asyncio").setLevel(logging.ERROR) +logging.getLogger("root").setLevel(logging.WARNING) + + +# Suppress warnings +warnings.filterwarnings("ignore", message="coroutine.*was never awaited") +warnings.filterwarnings("ignore", message="H264Decoder.*failed to decode") + +image_resize_factor = 1 +originalwidth, originalheight = (1280, 720) + + +class FakeRTC(UnitreeWebRTCConnection): + dir_name = "unitree_go2_office_walk2" + + # we don't want UnitreeWebRTCConnection to init + def __init__( # type: ignore[no-untyped-def] + self, + **kwargs, + ) -> None: + get_data(self.dir_name) + self.replay_config = { + "loop": kwargs.get("loop"), + "seek": kwargs.get("seek"), + "duration": kwargs.get("duration"), + } + + def connect(self) -> None: + pass + + def start(self) -> None: + pass + + def standup(self) -> None: + print("standup suppressed") + + def liedown(self) -> None: + print("liedown suppressed") + + @functools.cache + def lidar_stream(self): # type: ignore[no-untyped-def] + print("lidar stream start") + lidar_store = TimedSensorReplay(f"{self.dir_name}/lidar") # type: ignore[var-annotated] + return lidar_store.stream(**self.replay_config) # type: ignore[arg-type] + + @functools.cache + def odom_stream(self): # type: ignore[no-untyped-def] + print("odom stream start") + odom_store = TimedSensorReplay(f"{self.dir_name}/odom") # type: ignore[var-annotated] + return odom_store.stream(**self.replay_config) # type: ignore[arg-type] + + # we don't have raw video stream in the data set + @functools.cache + def video_stream(self): # type: ignore[no-untyped-def] + print("video stream start") + video_store = TimedSensorReplay(f"{self.dir_name}/video") # type: ignore[var-annotated] + + return video_store.stream(**self.replay_config) # type: ignore[arg-type] + + def move(self, vector: Twist, duration: float = 0.0) -> None: # type: ignore[override] + pass + + def publish_request(self, topic: str, data: dict): # type: ignore[no-untyped-def, type-arg] + """Fake publish request for testing.""" + return {"status": "ok", "message": "Fake publish"} + + +@dataclass +class ConnectionModuleConfig(ModuleConfig): + ip: str | None = None + connection_type: str = "fake" # or "fake" or "mujoco" + loop: bool = False # For fake connection + speed: float = 1.0 # For fake connection + + +class ConnectionModule(Module): + camera_info: Out[CameraInfo] = None # type: ignore[assignment] + odom: Out[PoseStamped] = None # type: ignore[assignment] + lidar: Out[LidarMessage] = None # type: ignore[assignment] + video: Out[Image] = None # type: ignore[assignment] + movecmd: In[Twist] = None # type: ignore[assignment] + + connection = None + + default_config = ConnectionModuleConfig + + # mega temporary, skill should have a limit decorator for number of + # parallel calls + video_running: bool = False + + def __init__(self, connection_type: str = "webrtc", *args, **kwargs) -> None: # type: ignore[no-untyped-def] + self.connection_config = kwargs + self.connection_type = connection_type + Module.__init__(self, *args, **kwargs) + + @skill(stream=Stream.passive, output=Output.image, reducer=Reducer.latest) # type: ignore[arg-type] + def video_stream_tool(self) -> Image: # type: ignore[misc] + """implicit video stream skill, don't call this directly""" + if self.video_running: + return "video stream already running" + self.video_running = True + _queue = queue.Queue(maxsize=1) # type: ignore[var-annotated] + self.connection.video_stream().subscribe(_queue.put) # type: ignore[attr-defined] + + yield from iter(_queue.get, None) + + @rpc + def record(self, recording_name: str) -> None: + lidar_store: TimedSensorStorage = TimedSensorStorage(f"{recording_name}/lidar") # type: ignore[type-arg] + lidar_store.save_stream(self.connection.lidar_stream()).subscribe(lambda x: x) # type: ignore[arg-type, attr-defined] + + odom_store: TimedSensorStorage = TimedSensorStorage(f"{recording_name}/odom") # type: ignore[type-arg] + odom_store.save_stream(self.connection.odom_stream()).subscribe(lambda x: x) # type: ignore[arg-type, attr-defined] + + video_store: TimedSensorStorage = TimedSensorStorage(f"{recording_name}/video") # type: ignore[type-arg] + video_store.save_stream(self.connection.video_stream()).subscribe(lambda x: x) # type: ignore[arg-type, attr-defined] + + @rpc + def start(self): # type: ignore[no-untyped-def] + """Start the connection and subscribe to sensor streams.""" + + super().start() + + match self.connection_type: + case "webrtc": + self.connection = UnitreeWebRTCConnection(**self.connection_config) + case "fake": + self.connection = FakeRTC(**self.connection_config, seek=12.0) + case "mujoco": + from dimos.robot.unitree_webrtc.mujoco_connection import MujocoConnection + + self.connection = MujocoConnection(GlobalConfig()) # type: ignore[assignment] + self.connection.start() # type: ignore[union-attr] + case _: + raise ValueError(f"Unknown connection type: {self.connection_type}") + + unsub = self.connection.odom_stream().subscribe( # type: ignore[union-attr] + lambda odom: self._publish_tf(odom) and self.odom.publish(odom) # type: ignore[func-returns-value] + ) + self._disposables.add(unsub) + + # Connect sensor streams to outputs + unsub = self.connection.lidar_stream().subscribe(self.lidar.publish) # type: ignore[union-attr] + self._disposables.add(unsub) + + # self.connection.lidar_stream().subscribe(lambda lidar: print("LIDAR", lidar.ts)) + # self.connection.video_stream().subscribe(lambda video: print("IMAGE", video.ts)) + # self.connection.odom_stream().subscribe(lambda odom: print("ODOM", odom.ts)) + + def resize(image: Image) -> Image: + return image.resize( + int(originalwidth / image_resize_factor), int(originalheight / image_resize_factor) + ) + + unsub = self.connection.video_stream().subscribe(self.video.publish) # type: ignore[union-attr] + self._disposables.add(unsub) + unsub = self.camera_info_stream().subscribe(self.camera_info.publish) + self._disposables.add(unsub) + unsub = self.movecmd.subscribe(self.connection.move) # type: ignore[union-attr] + self._disposables.add(unsub) # type: ignore[arg-type] + + @rpc + def stop(self) -> None: + if self.connection: + self.connection.stop() + + super().stop() + + @classmethod + def _odom_to_tf(cls, odom: PoseStamped) -> list[Transform]: + camera_link = Transform( + translation=Vector3(0.3, 0.0, 0.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id="base_link", + child_frame_id="camera_link", + ts=odom.ts, + ) + + camera_optical = Transform( + translation=Vector3(0.0, 0.0, 0.0), + rotation=Quaternion(-0.5, 0.5, -0.5, 0.5), + frame_id="camera_link", + child_frame_id="camera_optical", + ts=odom.ts, + ) + + sensor = Transform( + translation=Vector3(0.0, 0.0, 0.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id="world", + child_frame_id="sensor", + ts=odom.ts, + ) + + return [ + Transform.from_pose("base_link", odom), + camera_link, + camera_optical, + sensor, + ] + + def _publish_tf(self, msg) -> None: # type: ignore[no-untyped-def] + self.odom.publish(msg) + self.tf.publish(*self._odom_to_tf(msg)) + + @rpc + def publish_request(self, topic: str, data: dict): # type: ignore[no-untyped-def, type-arg] + """Publish a request to the WebRTC connection. + Args: + topic: The RTC topic to publish to + data: The data dictionary to publish + Returns: + The result of the publish request + """ + return self.connection.publish_request(topic, data) # type: ignore[union-attr] + + @classmethod + def _camera_info(cls) -> Out[CameraInfo]: + fx, fy, cx, cy = list( + map( + lambda x: int(x / image_resize_factor), + [819.553492, 820.646595, 625.284099, 336.808987], + ) + ) + width, height = tuple( + map( + lambda x: int(x / image_resize_factor), + [originalwidth, originalheight], + ) + ) + + # Camera matrix K (3x3) + K = [fx, 0, cx, 0, fy, cy, 0, 0, 1] + + # No distortion coefficients for now + D = [0.0, 0.0, 0.0, 0.0, 0.0] + + # Identity rotation matrix + R = [1, 0, 0, 0, 1, 0, 0, 0, 1] + + # Projection matrix P (3x4) + P = [fx, 0, cx, 0, 0, fy, cy, 0, 0, 0, 1, 0] + + base_msg = { + "D_length": len(D), + "height": height, + "width": width, + "distortion_model": "plumb_bob", + "D": D, + "K": K, + "R": R, + "P": P, + "binning_x": 0, + "binning_y": 0, + } + + return CameraInfo(**base_msg, header=Header("camera_optical")) # type: ignore[no-any-return] + + @functools.cache + def camera_info_stream(self) -> Observable[CameraInfo]: + return rx.interval(1).pipe(ops.map(lambda _: self._camera_info())) + + +def deploy_connection(dimos: DimosCluster, **kwargs): # type: ignore[no-untyped-def] + foxglove_bridge = dimos.deploy(FoxgloveBridge) # type: ignore[attr-defined, name-defined] + foxglove_bridge.start() + + connection = dimos.deploy( # type: ignore[attr-defined] + ConnectionModule, + ip=os.getenv("ROBOT_IP"), + connection_type=os.getenv("CONNECTION_TYPE", "fake"), + **kwargs, + ) + + connection.odom.transport = LCMTransport("/odom", PoseStamped) + + connection.video.transport = pSHMTransport( + "/image", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE + ) + + connection.lidar.transport = pSHMTransport( + "/lidar", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE + ) + + connection.video.transport = LCMTransport("/image", Image) + connection.lidar.transport = LCMTransport("/lidar", LidarMessage) + connection.movecmd.transport = LCMTransport("/cmd_vel", Twist) + connection.camera_info.transport = LCMTransport("/camera_info", CameraInfo) + + return connection diff --git a/dimos/robot/unitree_webrtc/modular/detect.py b/dimos/robot/unitree_webrtc/modular/detect.py new file mode 100644 index 0000000000..11c166fbe4 --- /dev/null +++ b/dimos/robot/unitree_webrtc/modular/detect.py @@ -0,0 +1,185 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pickle + +from dimos_lcm.sensor_msgs import CameraInfo # type: ignore[import-untyped] + +from dimos.msgs.sensor_msgs import Image +from dimos.msgs.std_msgs import Header +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.odometry import Odometry + +image_resize_factor = 1 +originalwidth, originalheight = (1280, 720) + + +def camera_info() -> CameraInfo: + fx, fy, cx, cy = list( + map( + lambda x: int(x / image_resize_factor), + [819.553492, 820.646595, 625.284099, 336.808987], + ) + ) + width, height = tuple( + map( + lambda x: int(x / image_resize_factor), + [originalwidth, originalheight], + ) + ) + + # Camera matrix K (3x3) + K = [fx, 0, cx, 0, fy, cy, 0, 0, 1] + + # No distortion coefficients for now + D = [0.0, 0.0, 0.0, 0.0, 0.0] + + # Identity rotation matrix + R = [1, 0, 0, 0, 1, 0, 0, 0, 1] + + # Projection matrix P (3x4) + P = [fx, 0, cx, 0, 0, fy, cy, 0, 0, 0, 1, 0] + + base_msg = { + "D_length": len(D), + "height": height, + "width": width, + "distortion_model": "plumb_bob", + "D": D, + "K": K, + "R": R, + "P": P, + "binning_x": 0, + "binning_y": 0, + } + + return CameraInfo( + **base_msg, + header=Header("camera_optical"), + ) + + +def transform_chain(odom_frame: Odometry) -> list: # type: ignore[type-arg] + from dimos.msgs.geometry_msgs import Quaternion, Transform, Vector3 + from dimos.protocol.tf import TF + + camera_link = Transform( + translation=Vector3(0.3, 0.0, 0.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id="base_link", + child_frame_id="camera_link", + ts=odom_frame.ts, + ) + + camera_optical = Transform( + translation=Vector3(0.0, 0.0, 0.0), + rotation=Quaternion(-0.5, 0.5, -0.5, 0.5), + frame_id="camera_link", + child_frame_id="camera_optical", + ts=camera_link.ts, + ) + + tf = TF() + tf.publish( + Transform.from_pose("base_link", odom_frame), + camera_link, + camera_optical, + ) + + return tf # type: ignore[return-value] + + +def broadcast( # type: ignore[no-untyped-def] + timestamp: float, + lidar_frame: LidarMessage, + video_frame: Image, + odom_frame: Odometry, + detections, + annotations, +) -> None: + from dimos_lcm.foxglove_msgs.ImageAnnotations import ( # type: ignore[import-untyped] + ImageAnnotations, + ) + + from dimos.core import LCMTransport + from dimos.msgs.geometry_msgs import PoseStamped + + lidar_transport = LCMTransport("/lidar", LidarMessage) # type: ignore[var-annotated] + odom_transport = LCMTransport("/odom", PoseStamped) # type: ignore[var-annotated] + video_transport = LCMTransport("/image", Image) # type: ignore[var-annotated] + camera_info_transport = LCMTransport("/camera_info", CameraInfo) # type: ignore[var-annotated] + + lidar_transport.broadcast(None, lidar_frame) + video_transport.broadcast(None, video_frame) + odom_transport.broadcast(None, odom_frame) + camera_info_transport.broadcast(None, camera_info()) + + transform_chain(odom_frame) + + print(lidar_frame) + print(video_frame) + print(odom_frame) + video_transport = LCMTransport("/image", Image) + annotations_transport = LCMTransport("/annotations", ImageAnnotations) # type: ignore[var-annotated] + annotations_transport.broadcast(None, annotations) + + +def process_data(): # type: ignore[no-untyped-def] + from dimos.msgs.sensor_msgs import Image + from dimos.perception.detection.module2D import ( # type: ignore[attr-defined] + Detection2DModule, + build_imageannotations, + ) + from dimos.robot.unitree_webrtc.type.lidar import LidarMessage + from dimos.robot.unitree_webrtc.type.odometry import Odometry + from dimos.utils.data import get_data + from dimos.utils.testing import TimedSensorReplay + + get_data("unitree_office_walk") + target = 1751591272.9654856 + lidar_store = TimedSensorReplay("unitree_office_walk/lidar", autocast=LidarMessage.from_msg) + video_store = TimedSensorReplay("unitree_office_walk/video", autocast=Image.from_numpy) + odom_store = TimedSensorReplay("unitree_office_walk/odom", autocast=Odometry.from_msg) + + def attach_frame_id(image: Image) -> Image: + image.frame_id = "camera_optical" + return image + + lidar_frame = lidar_store.find_closest(target, tolerance=1) + video_frame = attach_frame_id(video_store.find_closest(target, tolerance=1)) # type: ignore[arg-type] + odom_frame = odom_store.find_closest(target, tolerance=1) + + detector = Detection2DModule() + detections = detector.detect(video_frame) # type: ignore[attr-defined] + annotations = build_imageannotations(detections) + + data = (target, lidar_frame, video_frame, odom_frame, detections, annotations) + + with open("filename.pkl", "wb") as file: + pickle.dump(data, file) + + return data + + +def main() -> None: + try: + with open("filename.pkl", "rb") as file: + data = pickle.load(file) + except FileNotFoundError: + print("Processing data and creating pickle file...") + data = process_data() # type: ignore[no-untyped-call] + broadcast(*data) + + +main() diff --git a/dimos/robot/unitree_webrtc/modular/ivan_unitree.py b/dimos/robot/unitree_webrtc/modular/ivan_unitree.py new file mode 100644 index 0000000000..7b4b6776ce --- /dev/null +++ b/dimos/robot/unitree_webrtc/modular/ivan_unitree.py @@ -0,0 +1,98 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time + +from dimos.agents2.spec import Model, Provider +from dimos.core import LCMTransport, start +from dimos.msgs.foxglove_msgs import ImageAnnotations +from dimos.msgs.sensor_msgs import Image +from dimos.msgs.vision_msgs import Detection2DArray +from dimos.perception.detection.module2D import Detection2DModule +from dimos.perception.detection.reid import ReidModule +from dimos.protocol.pubsub import lcm # type: ignore[attr-defined] +from dimos.robot.foxglove_bridge import FoxgloveBridge +from dimos.robot.unitree_webrtc.modular import deploy_connection # type: ignore[attr-defined] +from dimos.robot.unitree_webrtc.modular.connection_module import ConnectionModule +from dimos.utils.logging_config import setup_logger + +logger = setup_logger(level=logging.INFO) + + +def detection_unitree() -> None: + dimos = start(8) + connection = deploy_connection(dimos) + + def goto(pose) -> bool: # type: ignore[no-untyped-def] + print("NAVIGATION REQUESTED:", pose) + return True + + detector = dimos.deploy( # type: ignore[attr-defined] + Detection2DModule, + camera_info=ConnectionModule._camera_info(), + ) + + detector.image.connect(connection.video) + + detector.annotations.transport = LCMTransport("/annotations", ImageAnnotations) + detector.detections.transport = LCMTransport("/detections", Detection2DArray) + + detector.detected_image_0.transport = LCMTransport("/detected/image/0", Image) + detector.detected_image_1.transport = LCMTransport("/detected/image/1", Image) + detector.detected_image_2.transport = LCMTransport("/detected/image/2", Image) + + reid = dimos.deploy(ReidModule) # type: ignore[attr-defined] + + reid.image.connect(connection.video) + reid.detections.connect(detector.detections) + reid.annotations.transport = LCMTransport("/reid/annotations", ImageAnnotations) + + detector.start() + connection.start() + reid.start() + + from dimos.agents2 import Agent # type: ignore[attr-defined] + from dimos.agents2.cli.human import HumanInput + + agent = Agent( + system_prompt="You are a helpful assistant for controlling a Unitree Go2 robot.", + model=Model.GPT_4O, # Could add CLAUDE models to enum + provider=Provider.OPENAI, # type: ignore[attr-defined] # Would need ANTHROPIC provider + ) + + human_input = dimos.deploy(HumanInput) # type: ignore[attr-defined] + agent.register_skills(human_input) + agent.register_skills(detector) + + bridge = FoxgloveBridge( + shm_channels=[ + "/image#sensor_msgs.Image", + "/lidar#sensor_msgs.PointCloud2", + ] + ) + time.sleep(1) + bridge.start() + + try: + while True: + time.sleep(1) + except KeyboardInterrupt: + connection.stop() + logger.info("Shutting down...") + + +if __name__ == "__main__": + lcm.autoconf() + detection_unitree() diff --git a/dimos/robot/unitree_webrtc/mujoco_connection.py b/dimos/robot/unitree_webrtc/mujoco_connection.py new file mode 100644 index 0000000000..a80704fc95 --- /dev/null +++ b/dimos/robot/unitree_webrtc/mujoco_connection.py @@ -0,0 +1,291 @@ +#!/usr/bin/env python3 + +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import base64 +from collections.abc import Callable +import functools +import json +import pickle +import subprocess +import sys +import threading +import time +from typing import Any, TypeVar + +import numpy as np +from numpy.typing import NDArray +from reactivex import Observable +from reactivex.abc import ObserverBase, SchedulerBase +from reactivex.disposable import Disposable + +from dimos.core.global_config import GlobalConfig +from dimos.msgs.geometry_msgs import Quaternion, Twist, Vector3 +from dimos.msgs.sensor_msgs import Image +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.odometry import Odometry +from dimos.simulation.mujoco.constants import LAUNCHER_PATH, LIDAR_FPS, VIDEO_FPS +from dimos.simulation.mujoco.shared_memory import ShmWriter +from dimos.utils.data import get_data +from dimos.utils.logging_config import setup_logger + +ODOM_FREQUENCY = 50 + +logger = setup_logger() + +T = TypeVar("T") + + +class MujocoConnection: + """MuJoCo simulator connection that runs in a separate subprocess.""" + + def __init__(self, global_config: GlobalConfig) -> None: + try: + import mujoco # type: ignore[import-untyped] + except ImportError: + raise ImportError("'mujoco' is not installed. Use `pip install -e .[sim]`") + + get_data("mujoco_sim") + + self.global_config = global_config + self.process: subprocess.Popen[bytes] | None = None + self.shm_data: ShmWriter | None = None + self._last_video_seq = 0 + self._last_odom_seq = 0 + self._last_lidar_seq = 0 + self._stop_timer: threading.Timer | None = None + + self._stream_threads: list[threading.Thread] = [] + self._stop_events: list[threading.Event] = [] + self._is_cleaned_up = False + + def start(self) -> None: + self.shm_data = ShmWriter() + + config_pickle = base64.b64encode(pickle.dumps(self.global_config)).decode("ascii") + shm_names_json = json.dumps(self.shm_data.shm.to_names()) + + # Launch the subprocess + try: + self.process = subprocess.Popen( + [sys.executable, str(LAUNCHER_PATH), config_pickle, shm_names_json], + ) + + except Exception as e: + self.shm_data.cleanup() + raise RuntimeError(f"Failed to start MuJoCo subprocess: {e}") from e + + # Wait for process to be ready + ready_timeout = 10 + start_time = time.time() + assert self.process is not None + while time.time() - start_time < ready_timeout: + if self.process.poll() is not None: + exit_code = self.process.returncode + self.stop() + raise RuntimeError(f"MuJoCo process failed to start (exit code {exit_code})") + if self.shm_data.is_ready(): + logger.info("MuJoCo process started successfully") + return + time.sleep(0.1) + + # Timeout + self.stop() + raise RuntimeError("MuJoCo process failed to start (timeout)") + + def stop(self) -> None: + if self._is_cleaned_up: + return + + self._is_cleaned_up = True + + # Cancel any pending timers + if self._stop_timer: + self._stop_timer.cancel() + self._stop_timer = None + + # Stop all stream threads + for stop_event in self._stop_events: + stop_event.set() + + # Wait for threads to finish + for thread in self._stream_threads: + if thread.is_alive(): + thread.join(timeout=2.0) + if thread.is_alive(): + logger.warning(f"Stream thread {thread.name} did not stop gracefully") + + # Signal subprocess to stop + if self.shm_data: + self.shm_data.signal_stop() + + # Wait for process to finish + if self.process: + try: + self.process.terminate() + try: + self.process.wait(timeout=5) + except subprocess.TimeoutExpired: + logger.warning("MuJoCo process did not stop gracefully, killing") + self.process.kill() + self.process.wait(timeout=2) + except Exception as e: + logger.error(f"Error stopping MuJoCo process: {e}") + + self.process = None + + # Clean up shared memory + if self.shm_data: + self.shm_data.cleanup() + self.shm_data = None + + # Clear references + self._stream_threads.clear() + self._stop_events.clear() + + self.lidar_stream.cache_clear() + self.odom_stream.cache_clear() + self.video_stream.cache_clear() + + def standup(self) -> bool: + return True + + def liedown(self) -> bool: + return True + + def get_video_frame(self) -> NDArray[Any] | None: + if self.shm_data is None: + return None + + frame, seq = self.shm_data.read_video() + if seq > self._last_video_seq: + self._last_video_seq = seq + return frame + + return None + + def get_odom_message(self) -> Odometry | None: + if self.shm_data is None: + return None + + odom_data, seq = self.shm_data.read_odom() + if seq > self._last_odom_seq and odom_data is not None: + self._last_odom_seq = seq + pos, quat_wxyz, timestamp = odom_data + + # Convert quaternion from (w,x,y,z) to (x,y,z,w) for ROS/Dimos + orientation = Quaternion(quat_wxyz[1], quat_wxyz[2], quat_wxyz[3], quat_wxyz[0]) + + return Odometry( + position=Vector3(pos[0], pos[1], pos[2]), + orientation=orientation, + ts=timestamp, + frame_id="world", + ) + + return None + + def get_lidar_message(self) -> LidarMessage | None: + if self.shm_data is None: + return None + + lidar_msg, seq = self.shm_data.read_lidar() + if seq > self._last_lidar_seq and lidar_msg is not None: + self._last_lidar_seq = seq + return lidar_msg + + return None + + def _create_stream( + self, + getter: Callable[[], T | None], + frequency: float, + stream_name: str, + ) -> Observable[T]: + def on_subscribe(observer: ObserverBase[T], _scheduler: SchedulerBase | None) -> Disposable: + if self._is_cleaned_up: + observer.on_completed() + return Disposable(lambda: None) + + stop_event = threading.Event() + self._stop_events.append(stop_event) + + def run() -> None: + try: + while not stop_event.is_set() and not self._is_cleaned_up: + data = getter() + if data is not None: + observer.on_next(data) + time.sleep(1 / frequency) + except Exception as e: + logger.error(f"{stream_name} stream error: {e}") + finally: + observer.on_completed() + + thread = threading.Thread(target=run, daemon=True) + self._stream_threads.append(thread) + thread.start() + + def dispose() -> None: + stop_event.set() + + return Disposable(dispose) + + return Observable(on_subscribe) + + @functools.cache + def lidar_stream(self) -> Observable[LidarMessage]: + return self._create_stream(self.get_lidar_message, LIDAR_FPS, "Lidar") + + @functools.cache + def odom_stream(self) -> Observable[Odometry]: + return self._create_stream(self.get_odom_message, ODOM_FREQUENCY, "Odom") + + @functools.cache + def video_stream(self) -> Observable[Image]: + def get_video_as_image() -> Image | None: + frame = self.get_video_frame() + return Image.from_numpy(frame) if frame is not None else None + + return self._create_stream(get_video_as_image, VIDEO_FPS, "Video") + + def move(self, twist: Twist, duration: float = 0.0) -> bool: + if self._is_cleaned_up or self.shm_data is None: + return True + + linear = np.array([twist.linear.x, twist.linear.y, twist.linear.z], dtype=np.float32) + angular = np.array([twist.angular.x, twist.angular.y, twist.angular.z], dtype=np.float32) + self.shm_data.write_command(linear, angular) + + if duration > 0: + if self._stop_timer: + self._stop_timer.cancel() + + def stop_movement() -> None: + if self.shm_data: + self.shm_data.write_command( + np.zeros(3, dtype=np.float32), np.zeros(3, dtype=np.float32) + ) + self._stop_timer = None + + self._stop_timer = threading.Timer(duration, stop_movement) + self._stop_timer.daemon = True + self._stop_timer.start() + return True + + def publish_request(self, topic: str, data: dict[str, Any]) -> dict[Any, Any]: + print(f"publishing request, topic={topic}, data={data}") + return {} diff --git a/dimos/robot/unitree_webrtc/params/front_camera_720.yaml b/dimos/robot/unitree_webrtc/params/front_camera_720.yaml new file mode 100644 index 0000000000..0030d5fc6c --- /dev/null +++ b/dimos/robot/unitree_webrtc/params/front_camera_720.yaml @@ -0,0 +1,26 @@ +image_width: 1280 +image_height: 720 +camera_name: narrow_stereo +camera_matrix: + rows: 3 + cols: 3 + data: [864.39938, 0. , 639.19798, + 0. , 863.73849, 373.28118, + 0. , 0. , 1. ] +distortion_model: plumb_bob +distortion_coefficients: + rows: 1 + cols: 5 + data: [-0.354630, 0.102054, -0.001614, -0.001249, 0.000000] +rectification_matrix: + rows: 3 + cols: 3 + data: [1., 0., 0., + 0., 1., 0., + 0., 0., 1.] +projection_matrix: + rows: 3 + cols: 4 + data: [651.42609, 0. , 633.16224, 0. , + 0. , 804.93951, 373.8537 , 0. , + 0. , 0. , 1. , 0. ] diff --git a/dimos/robot/unitree_webrtc/params/sim_camera.yaml b/dimos/robot/unitree_webrtc/params/sim_camera.yaml new file mode 100644 index 0000000000..6a5ac3e6d8 --- /dev/null +++ b/dimos/robot/unitree_webrtc/params/sim_camera.yaml @@ -0,0 +1,26 @@ +image_width: 320 +image_height: 240 +camera_name: sim_camera +camera_matrix: + rows: 3 + cols: 3 + data: [277., 0. , 160. , + 0. , 277., 120. , + 0. , 0. , 1. ] +distortion_model: plumb_bob +distortion_coefficients: + rows: 1 + cols: 5 + data: [0.0, 0.0, 0.0, 0.0, 0.0] +rectification_matrix: + rows: 3 + cols: 3 + data: [1., 0., 0., + 0., 1., 0., + 0., 0., 1.] +projection_matrix: + rows: 3 + cols: 4 + data: [277., 0. , 160. , 0. , + 0. , 277., 120. , 0. , + 0. , 0. , 1. , 0. ] diff --git a/dimos/robot/unitree_webrtc/rosnav.py b/dimos/robot/unitree_webrtc/rosnav.py new file mode 100644 index 0000000000..79bd8d70bd --- /dev/null +++ b/dimos/robot/unitree_webrtc/rosnav.py @@ -0,0 +1,136 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time + +from dimos.core import In, Module, Out, rpc +from dimos.msgs.geometry_msgs import PoseStamped +from dimos.msgs.sensor_msgs import Joy +from dimos.msgs.std_msgs.Bool import Bool +from dimos.utils.logging_config import setup_logger + +logger = setup_logger(level=logging.INFO) + + +# TODO: Remove, deprecated +class NavigationModule(Module): + goal_pose: Out[PoseStamped] = None # type: ignore[assignment] + goal_reached: In[Bool] = None # type: ignore[assignment] + cancel_goal: Out[Bool] = None # type: ignore[assignment] + joy: Out[Joy] = None # type: ignore[assignment] + + def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + """Initialize NavigationModule.""" + Module.__init__(self, *args, **kwargs) + self.goal_reach = None + + @rpc + def start(self) -> None: + """Start the navigation module.""" + if self.goal_reached: + self.goal_reached.subscribe(self._on_goal_reached) + logger.info("NavigationModule started") + + def _on_goal_reached(self, msg: Bool) -> None: + """Handle goal reached status messages.""" + self.goal_reach = msg.data # type: ignore[assignment] + + def _set_autonomy_mode(self) -> None: + """ + Set autonomy mode by publishing Joy message. + """ + + joy_msg = Joy( + frame_id="dimos", + axes=[ + 0.0, # axis 0 + 0.0, # axis 1 + -1.0, # axis 2 + 0.0, # axis 3 + 1.0, # axis 4 + 1.0, # axis 5 + 0.0, # axis 6 + 0.0, # axis 7 + ], + buttons=[ + 0, # button 0 + 0, # button 1 + 0, # button 2 + 0, # button 3 + 0, # button 4 + 0, # button 5 + 0, # button 6 + 1, # button 7 - controls autonomy mode + 0, # button 8 + 0, # button 9 + 0, # button 10 + ], + ) + + if self.joy: + self.joy.publish(joy_msg) + logger.info("Setting autonomy mode via Joy message") + + @rpc + def go_to(self, pose: PoseStamped, timeout: float = 60.0) -> bool: + """ + Navigate to a target pose by publishing to LCM topics. + + Args: + pose: Target pose to navigate to + blocking: If True, block until goal is reached + timeout: Maximum time to wait for goal (seconds) + + Returns: + True if navigation was successful (or started if non-blocking) + """ + logger.info( + f"Navigating to goal: ({pose.position.x:.2f}, {pose.position.y:.2f}, {pose.position.z:.2f})" + ) + + self.goal_reach = None + self._set_autonomy_mode() + self.goal_pose.publish(pose) + time.sleep(0.2) + self.goal_pose.publish(pose) + + start_time = time.time() + while time.time() - start_time < timeout: + if self.goal_reach is not None: + return self.goal_reach + time.sleep(0.1) + + self.stop() + + logger.warning(f"Navigation timed out after {timeout} seconds") + return False + + @rpc + def stop(self) -> bool: + """ + Cancel current navigation by publishing to cancel_goal. + + Returns: + True if cancel command was sent successfully + """ + logger.info("Cancelling navigation") + + if self.cancel_goal: + cancel_msg = Bool(data=True) + self.cancel_goal.publish(cancel_msg) + return True + + return False diff --git a/dimos/robot/unitree_webrtc/testing/__init__.py b/dimos/robot/unitree_webrtc/testing/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dimos/robot/unitree_webrtc/testing/helpers.py b/dimos/robot/unitree_webrtc/testing/helpers.py new file mode 100644 index 0000000000..aaf188dbc3 --- /dev/null +++ b/dimos/robot/unitree_webrtc/testing/helpers.py @@ -0,0 +1,170 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable, Iterable +import time +from typing import Any, Protocol + +import open3d as o3d # type: ignore[import-untyped] +from reactivex.observable import Observable + +color1 = [1, 0.706, 0] +color2 = [0, 0.651, 0.929] +color3 = [0.8, 0.196, 0.6] +color4 = [0.235, 0.702, 0.443] +color = [color1, color2, color3, color4] + + +# benchmarking function can return int, which will be applied to the time. +# +# (in case there is some preparation within the fuction and this time needs to be subtracted +# from the benchmark target) +def benchmark(calls: int, targetf: Callable[[], int | None]) -> float: + start = time.time() + timemod = 0 + for _ in range(calls): + res = targetf() + if res is not None: + timemod += res + end = time.time() + return (end - start + timemod) * 1000 / calls + + +O3dDrawable = ( + o3d.geometry.Geometry + | o3d.geometry.LineSet + | o3d.geometry.TriangleMesh + | o3d.geometry.PointCloud +) + + +class ReturnsDrawable(Protocol): + def o3d_geometry(self) -> O3dDrawable: ... # type: ignore[valid-type] + + +Drawable = O3dDrawable | ReturnsDrawable + + +def show3d(*components: Iterable[Drawable], title: str = "open3d") -> o3d.visualization.Visualizer: # type: ignore[valid-type] + vis = o3d.visualization.Visualizer() + vis.create_window(window_name=title) + for component in components: + # our custom drawable components should return an open3d geometry + if hasattr(component, "o3d_geometry"): + vis.add_geometry(component.o3d_geometry) + else: + vis.add_geometry(component) + + opt = vis.get_render_option() + opt.background_color = [0, 0, 0] + opt.point_size = 10 + vis.poll_events() + vis.update_renderer() + return vis + + +def multivis(*vis: o3d.visualization.Visualizer) -> None: + while True: + for v in vis: + v.poll_events() + v.update_renderer() + + +def show3d_stream( + geometry_observable: Observable[Any], + clearframe: bool = False, + title: str = "open3d", +) -> o3d.visualization.Visualizer: + """ + Visualize a stream of geometries using Open3D. The first geometry initializes the visualizer. + Subsequent geometries update the visualizer. If no new geometry, just poll events. + geometry_observable: Observable of objects with .o3d_geometry or Open3D geometry + """ + import queue + import threading + import time + from typing import Any + + q: queue.Queue[Any] = queue.Queue() + stop_flag = threading.Event() + + def on_next(geometry: O3dDrawable) -> None: # type: ignore[valid-type] + q.put(geometry) + + def on_error(e: Exception) -> None: + print(f"Visualization error: {e}") + stop_flag.set() + + def on_completed() -> None: + print("Geometry stream completed") + stop_flag.set() + + subscription = geometry_observable.subscribe( + on_next=on_next, + on_error=on_error, + on_completed=on_completed, + ) + + def geom(geometry: Drawable) -> O3dDrawable: # type: ignore[valid-type] + """Extracts the Open3D geometry from the given object.""" + return geometry.o3d_geometry if hasattr(geometry, "o3d_geometry") else geometry # type: ignore[attr-defined, no-any-return] + + # Wait for the first geometry + first_geometry = None + while first_geometry is None and not stop_flag.is_set(): + try: + first_geometry = q.get(timeout=100) + except queue.Empty: + print("No geometry received to visualize.") + return + + scene_geometries = [] + first_geom_obj = geom(first_geometry) + + scene_geometries.append(first_geom_obj) + + vis = show3d(first_geom_obj, title=title) + + try: + while not stop_flag.is_set(): + try: + geometry = q.get_nowait() + geom_obj = geom(geometry) + if clearframe: + scene_geometries = [] + vis.clear_geometries() + + vis.add_geometry(geom_obj) + scene_geometries.append(geom_obj) + else: + if geom_obj in scene_geometries: + print("updating existing geometry") + vis.update_geometry(geom_obj) + else: + print("new geometry") + vis.add_geometry(geom_obj) + scene_geometries.append(geom_obj) + except queue.Empty: + pass + vis.poll_events() + vis.update_renderer() + time.sleep(0.1) + + except KeyboardInterrupt: + print("closing visualizer...") + stop_flag.set() + vis.destroy_window() + subscription.dispose() + + return vis diff --git a/dimos/robot/unitree_webrtc/testing/mock.py b/dimos/robot/unitree_webrtc/testing/mock.py new file mode 100644 index 0000000000..34ca390842 --- /dev/null +++ b/dimos/robot/unitree_webrtc/testing/mock.py @@ -0,0 +1,92 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Iterator +import glob +import os +import pickle +from typing import cast, overload + +from reactivex import from_iterable, interval, operators as ops +from reactivex.observable import Observable + +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage, RawLidarMsg + + +class Mock: + def __init__(self, root: str = "office", autocast: bool = True) -> None: + current_dir = os.path.dirname(os.path.abspath(__file__)) + self.root = os.path.join(current_dir, f"mockdata/{root}") + self.autocast = autocast + self.cnt = 0 + + @overload + def load(self, name: int | str, /) -> LidarMessage: ... + @overload + def load(self, *names: int | str) -> list[LidarMessage]: ... + + def load(self, *names: int | str) -> LidarMessage | list[LidarMessage]: + if len(names) == 1: + return self.load_one(names[0]) + return list(map(lambda name: self.load_one(name), names)) + + def load_one(self, name: int | str) -> LidarMessage: + if isinstance(name, int): + file_name = f"/lidar_data_{name:03d}.pickle" + else: + file_name = f"/{name}.pickle" + + full_path = self.root + file_name + with open(full_path, "rb") as f: + return LidarMessage.from_msg(cast("RawLidarMsg", pickle.load(f))) + + def iterate(self) -> Iterator[LidarMessage]: + pattern = os.path.join(self.root, "lidar_data_*.pickle") + print("loading data", pattern) + for file_path in sorted(glob.glob(pattern)): + basename = os.path.basename(file_path) + filename = os.path.splitext(basename)[0] + yield self.load_one(filename) + + def stream(self, rate_hz: float = 10.0): # type: ignore[no-untyped-def] + sleep_time = 1.0 / rate_hz + + return from_iterable(self.iterate()).pipe( + ops.zip(interval(sleep_time)), + ops.map(lambda x: x[0] if isinstance(x, tuple) else x), + ) + + def save_stream(self, observable: Observable[LidarMessage]): # type: ignore[no-untyped-def] + return observable.pipe(ops.map(lambda frame: self.save_one(frame))) # type: ignore[no-untyped-call] + + def save(self, *frames): # type: ignore[no-untyped-def] + [self.save_one(frame) for frame in frames] # type: ignore[no-untyped-call] + return self.cnt + + def save_one(self, frame): # type: ignore[no-untyped-def] + file_name = f"/lidar_data_{self.cnt:03d}.pickle" + full_path = self.root + file_name + + self.cnt += 1 + + if os.path.isfile(full_path): + raise Exception(f"file {full_path} exists") + + if frame.__class__ == LidarMessage: + frame = frame.raw_msg + + with open(full_path, "wb") as f: + pickle.dump(frame, f) + + return self.cnt diff --git a/dimos/robot/unitree_webrtc/testing/test_actors.py b/dimos/robot/unitree_webrtc/testing/test_actors.py new file mode 100644 index 0000000000..7e79ca24cc --- /dev/null +++ b/dimos/robot/unitree_webrtc/testing/test_actors.py @@ -0,0 +1,111 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +from collections.abc import Callable +import time + +import pytest + +from dimos import core +from dimos.core import Module, rpc +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.map import Map as Mapper + + +@pytest.fixture +def dimos(): + return core.start(2) + + +@pytest.fixture +def client(): + return core.start(2) + + +class Consumer: + testf: Callable[[int], int] + + def __init__(self, counter=None) -> None: + self.testf = counter + print("consumer init with", counter) + + async def waitcall(self, n: int): + async def task() -> None: + await asyncio.sleep(n) + + print("sleep finished, calling") + res = await self.testf(n) + print("res is", res) + + asyncio.create_task(task()) + return n + + +class Counter(Module): + @rpc + def addten(self, x: int): + print(f"counter adding to {x}") + return x + 10 + + +@pytest.mark.tool +def test_wait(client) -> None: + counter = client.submit(Counter, actor=True).result() + + async def addten(n): + return await counter.addten(n) + + consumer = client.submit(Consumer, counter=addten, actor=True).result() + + print("waitcall1", consumer.waitcall(2).result()) + print("waitcall2", consumer.waitcall(2).result()) + time.sleep(1) + + +@pytest.mark.tool +def test_basic(dimos) -> None: + counter = dimos.deploy(Counter) + consumer = dimos.deploy( + Consumer, + counter=lambda x: counter.addten(x).result(), + ) + + print(consumer) + print(counter) + print("starting consumer") + consumer.start().result() + + res = consumer.inc(10).result() + + print("result is", res) + assert res == 20 + + +@pytest.mark.tool +def test_mapper_start(dimos) -> None: + mapper = dimos.deploy(Mapper) + mapper.lidar.transport = core.LCMTransport("/lidar", LidarMessage) + print("start res", mapper.start().result()) + + +if __name__ == "__main__": + dimos = core.start(2) + test_basic(dimos) + test_mapper_start(dimos) + + +@pytest.mark.tool +def test_counter(dimos) -> None: + counter = dimos.deploy(Counter) + assert counter.addten(10) == 20 diff --git a/dimos/robot/unitree_webrtc/testing/test_mock.py b/dimos/robot/unitree_webrtc/testing/test_mock.py new file mode 100644 index 0000000000..0765894409 --- /dev/null +++ b/dimos/robot/unitree_webrtc/testing/test_mock.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import pytest + +from dimos.robot.unitree_webrtc.testing.mock import Mock +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage + + +@pytest.mark.needsdata +def test_mock_load_cast() -> None: + mock = Mock("test") + + # Load a frame with type casting + frame = mock.load("a") + + # Verify it's a LidarMessage object + assert frame.__class__.__name__ == "LidarMessage" + assert hasattr(frame, "timestamp") + assert hasattr(frame, "origin") + assert hasattr(frame, "resolution") + assert hasattr(frame, "pointcloud") + + # Verify pointcloud has points + assert frame.pointcloud.has_points() + assert len(frame.pointcloud.points) > 0 + + +@pytest.mark.needsdata +def test_mock_iterate() -> None: + """Test the iterate method of the Mock class.""" + mock = Mock("office") + + # Test iterate method + frames = list(mock.iterate()) + assert len(frames) > 0 + for frame in frames: + assert isinstance(frame, LidarMessage) + assert frame.pointcloud.has_points() + + +@pytest.mark.needsdata +def test_mock_stream() -> None: + frames = [] + sub1 = Mock("office").stream(rate_hz=30.0).subscribe(on_next=frames.append) + time.sleep(0.1) + sub1.dispose() + + assert len(frames) >= 2 + assert isinstance(frames[0], LidarMessage) diff --git a/dimos/robot/unitree_webrtc/testing/test_tooling.py b/dimos/robot/unitree_webrtc/testing/test_tooling.py new file mode 100644 index 0000000000..456d600879 --- /dev/null +++ b/dimos/robot/unitree_webrtc/testing/test_tooling.py @@ -0,0 +1,37 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import pytest + +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.odometry import Odometry +from dimos.utils.reactive import backpressure +from dimos.utils.testing import TimedSensorReplay + + +@pytest.mark.tool +def test_replay_all() -> None: + lidar_store = TimedSensorReplay("unitree/lidar", autocast=LidarMessage.from_msg) + odom_store = TimedSensorReplay("unitree/odom", autocast=Odometry.from_msg) + video_store = TimedSensorReplay("unitree/video") + + backpressure(odom_store.stream()).subscribe(print) + backpressure(lidar_store.stream()).subscribe(print) + backpressure(video_store.stream()).subscribe(print) + + print("Replaying for 3 seconds...") + time.sleep(3) + print("Stopping replay after 3 seconds") diff --git a/dimos/robot/unitree_webrtc/type/__init__.py b/dimos/robot/unitree_webrtc/type/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dimos/robot/unitree_webrtc/type/lidar.py b/dimos/robot/unitree_webrtc/type/lidar.py new file mode 100644 index 0000000000..b598373a09 --- /dev/null +++ b/dimos/robot/unitree_webrtc/type/lidar.py @@ -0,0 +1,131 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from typing import TypedDict + +import numpy as np +import open3d as o3d # type: ignore[import-untyped] + +from dimos.msgs.geometry_msgs import Vector3 +from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.types.timestamped import to_human_readable + + +class RawLidarPoints(TypedDict): + points: np.ndarray # type: ignore[type-arg] # Shape (N, 3) array of 3D points [x, y, z] + + +class RawLidarData(TypedDict): + """Data portion of the LIDAR message""" + + frame_id: str + origin: list[float] + resolution: float + src_size: int + stamp: float + width: list[int] + data: RawLidarPoints + + +class RawLidarMsg(TypedDict): + """Static type definition for raw LIDAR message""" + + type: str + topic: str + data: RawLidarData + + +class LidarMessage(PointCloud2): + resolution: float # we lose resolution when encoding PointCloud2 + origin: Vector3 + raw_msg: RawLidarMsg | None + # _costmap: Optional[Costmap] = None # TODO: Fix after costmap migration + + def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] + super().__init__( + pointcloud=kwargs.get("pointcloud"), + ts=kwargs.get("ts"), + frame_id="world", + ) + + self.origin = kwargs.get("origin") # type: ignore[assignment] + self.resolution = kwargs.get("resolution", 0.05) + + @classmethod + def from_msg(cls: type["LidarMessage"], raw_message: RawLidarMsg, **kwargs) -> "LidarMessage": # type: ignore[no-untyped-def] + data = raw_message["data"] + points = data["data"]["points"] + pointcloud = o3d.geometry.PointCloud() + pointcloud.points = o3d.utility.Vector3dVector(points) + + origin = Vector3(data["origin"]) + # webrtc decoding via native decompression doesn't require us + # to shift the pointcloud by it's origin + # + # pointcloud.translate((origin / 2).to_tuple()) + cls_data = { + "origin": origin, + "resolution": data["resolution"], + "pointcloud": pointcloud, + # - this is broken in unitree webrtc api "stamp":1.758148e+09 + "ts": time.time(), # data["stamp"], + "raw_msg": raw_message, + **kwargs, + } + return cls(**cls_data) + + def __repr__(self) -> str: + return f"LidarMessage(ts={to_human_readable(self.ts)}, origin={self.origin}, resolution={self.resolution}, {self.pointcloud})" + + def __iadd__(self, other: "LidarMessage") -> "LidarMessage": # type: ignore[override] + self.pointcloud += other.pointcloud + return self + + def __add__(self, other: "LidarMessage") -> "LidarMessage": # type: ignore[override] + # Determine which message is more recent + if self.ts >= other.ts: + ts = self.ts + origin = self.origin + resolution = self.resolution + else: + ts = other.ts + origin = other.origin + resolution = other.resolution + + # Return a new LidarMessage with combined data + return LidarMessage( # type: ignore[attr-defined, no-any-return] + ts=ts, + origin=origin, + resolution=resolution, + pointcloud=self.pointcloud + other.pointcloud, + ).estimate_normals() + + @property + def o3d_geometry(self): # type: ignore[no-untyped-def] + return self.pointcloud + + # TODO: Fix after costmap migration + # def costmap(self, voxel_size: float = 0.2) -> Costmap: + # if not self._costmap: + # down_sampled_pointcloud = self.pointcloud.voxel_down_sample(voxel_size=voxel_size) + # inflate_radius_m = 1.0 * voxel_size if voxel_size > self.resolution else 0.0 + # grid, origin_xy = pointcloud_to_costmap( + # down_sampled_pointcloud, + # resolution=self.resolution, + # inflate_radius_m=inflate_radius_m, + # ) + # self._costmap = Costmap(grid=grid, origin=[*origin_xy, 0.0], resolution=self.resolution) + # + # return self._costmap diff --git a/dimos/robot/unitree_webrtc/type/lowstate.py b/dimos/robot/unitree_webrtc/type/lowstate.py new file mode 100644 index 0000000000..3e7926424a --- /dev/null +++ b/dimos/robot/unitree_webrtc/type/lowstate.py @@ -0,0 +1,93 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Literal, TypedDict + +raw_odom_msg_sample = { + "type": "msg", + "topic": "rt/lf/lowstate", + "data": { + "imu_state": {"rpy": [0.008086, -0.007515, 2.981771]}, + "motor_state": [ + {"q": 0.098092, "temperature": 40, "lost": 0, "reserve": [0, 674]}, + {"q": 0.757921, "temperature": 32, "lost": 0, "reserve": [0, 674]}, + {"q": -1.490911, "temperature": 38, "lost": 6, "reserve": [0, 674]}, + {"q": -0.072477, "temperature": 42, "lost": 0, "reserve": [0, 674]}, + {"q": 1.020276, "temperature": 32, "lost": 5, "reserve": [0, 674]}, + {"q": -2.007172, "temperature": 38, "lost": 5, "reserve": [0, 674]}, + {"q": 0.071382, "temperature": 50, "lost": 5, "reserve": [0, 674]}, + {"q": 0.963379, "temperature": 36, "lost": 6, "reserve": [0, 674]}, + {"q": -1.978311, "temperature": 40, "lost": 5, "reserve": [0, 674]}, + {"q": -0.051066, "temperature": 48, "lost": 0, "reserve": [0, 674]}, + {"q": 0.73103, "temperature": 34, "lost": 10, "reserve": [0, 674]}, + {"q": -1.466473, "temperature": 38, "lost": 6, "reserve": [0, 674]}, + {"q": 0, "temperature": 0, "lost": 0, "reserve": [0, 0]}, + {"q": 0, "temperature": 0, "lost": 0, "reserve": [0, 0]}, + {"q": 0, "temperature": 0, "lost": 0, "reserve": [0, 0]}, + {"q": 0, "temperature": 0, "lost": 0, "reserve": [0, 0]}, + {"q": 0, "temperature": 0, "lost": 0, "reserve": [0, 0]}, + {"q": 0, "temperature": 0, "lost": 0, "reserve": [0, 0]}, + {"q": 0, "temperature": 0, "lost": 0, "reserve": [0, 0]}, + {"q": 0, "temperature": 0, "lost": 0, "reserve": [0, 0]}, + ], + "bms_state": { + "version_high": 1, + "version_low": 18, + "soc": 55, + "current": -2481, + "cycle": 56, + "bq_ntc": [30, 29], + "mcu_ntc": [33, 32], + }, + "foot_force": [97, 84, 81, 81], + "temperature_ntc1": 48, + "power_v": 28.331045, + }, +} + + +class MotorState(TypedDict): + q: float + temperature: int + lost: int + reserve: list[int] + + +class ImuState(TypedDict): + rpy: list[float] + + +class BmsState(TypedDict): + version_high: int + version_low: int + soc: int + current: int + cycle: int + bq_ntc: list[int] + mcu_ntc: list[int] + + +class LowStateData(TypedDict): + imu_state: ImuState + motor_state: list[MotorState] + bms_state: BmsState + foot_force: list[int] + temperature_ntc1: int + power_v: float + + +class LowStateMsg(TypedDict): + type: Literal["msg"] + topic: str + data: LowStateData diff --git a/dimos/robot/unitree_webrtc/type/map.py b/dimos/robot/unitree_webrtc/type/map.py new file mode 100644 index 0000000000..51c8a65e2c --- /dev/null +++ b/dimos/robot/unitree_webrtc/type/map.py @@ -0,0 +1,184 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import numpy as np +import open3d as o3d # type: ignore[import-untyped] +from reactivex import interval +from reactivex.disposable import Disposable + +from dimos.core import DimosCluster, In, LCMTransport, Module, Out, rpc +from dimos.core.global_config import GlobalConfig +from dimos.msgs.nav_msgs import OccupancyGrid +from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.robot.unitree.connection.go2 import Go2ConnectionProtocol +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage + + +class Map(Module): + lidar: In[LidarMessage] = None # type: ignore[assignment] + global_map: Out[LidarMessage] = None # type: ignore[assignment] + global_costmap: Out[OccupancyGrid] = None # type: ignore[assignment] + local_costmap: Out[OccupancyGrid] = None # type: ignore[assignment] + + pointcloud: o3d.geometry.PointCloud = o3d.geometry.PointCloud() + + def __init__( # type: ignore[no-untyped-def] + self, + voxel_size: float = 0.05, + cost_resolution: float = 0.05, + global_publish_interval: float | None = None, + min_height: float = 0.15, + max_height: float = 0.6, + global_config: GlobalConfig | None = None, + **kwargs, + ) -> None: + self.voxel_size = voxel_size + self.cost_resolution = cost_resolution + self.global_publish_interval = global_publish_interval + self.min_height = min_height + self.max_height = max_height + + if global_config: + if global_config.simulation: + self.min_height = 0.3 + + super().__init__(**kwargs) + + @rpc + def start(self) -> None: + super().start() + + unsub = self.lidar.subscribe(self.add_frame) + self._disposables.add(Disposable(unsub)) + + def publish(_) -> None: # type: ignore[no-untyped-def] + self.global_map.publish(self.to_lidar_message()) + + # temporary, not sure if it belogs in mapper + # used only for visualizations, not for any algo + occupancygrid = OccupancyGrid.from_pointcloud( + self.to_lidar_message(), + resolution=self.cost_resolution, + min_height=self.min_height, + max_height=self.max_height, + ) + + self.global_costmap.publish(occupancygrid) + + if self.global_publish_interval is not None: + unsub = interval(self.global_publish_interval).subscribe(publish) # type: ignore[assignment] + self._disposables.add(unsub) # type: ignore[arg-type] + + @rpc + def stop(self) -> None: + super().stop() + + def to_PointCloud2(self) -> PointCloud2: + return PointCloud2( + pointcloud=self.pointcloud, + ts=time.time(), + ) + + def to_lidar_message(self) -> LidarMessage: + return LidarMessage( + pointcloud=self.pointcloud, + origin=[0.0, 0.0, 0.0], + resolution=self.voxel_size, + ts=time.time(), + ) + + @rpc + def add_frame(self, frame: LidarMessage) -> "Map": # type: ignore[return] + """Voxelise *frame* and splice it into the running map.""" + new_pct = frame.pointcloud.voxel_down_sample(voxel_size=self.voxel_size) + + # Skip for empty pointclouds. + if len(new_pct.points) == 0: + return self + + self.pointcloud = splice_cylinder(self.pointcloud, new_pct, shrink=0.5) + local_costmap = OccupancyGrid.from_pointcloud( + frame, + resolution=self.cost_resolution, + min_height=0.15, + max_height=0.6, + ).gradient(max_distance=0.25) + self.local_costmap.publish(local_costmap) + + @property + def o3d_geometry(self) -> o3d.geometry.PointCloud: + return self.pointcloud + + +def splice_sphere( + map_pcd: o3d.geometry.PointCloud, + patch_pcd: o3d.geometry.PointCloud, + shrink: float = 0.95, +) -> o3d.geometry.PointCloud: + center = patch_pcd.get_center() + radius = np.linalg.norm(np.asarray(patch_pcd.points) - center, axis=1).max() * shrink + dists = np.linalg.norm(np.asarray(map_pcd.points) - center, axis=1) + victims = np.nonzero(dists < radius)[0] + survivors = map_pcd.select_by_index(victims, invert=True) + return survivors + patch_pcd + + +def splice_cylinder( + map_pcd: o3d.geometry.PointCloud, + patch_pcd: o3d.geometry.PointCloud, + axis: int = 2, + shrink: float = 0.95, +) -> o3d.geometry.PointCloud: + center = patch_pcd.get_center() + patch_pts = np.asarray(patch_pcd.points) + + # Axes perpendicular to cylinder + axes = [0, 1, 2] + axes.remove(axis) + + planar_dists = np.linalg.norm(patch_pts[:, axes] - center[axes], axis=1) + radius = planar_dists.max() * shrink + + axis_min = (patch_pts[:, axis].min() - center[axis]) * shrink + center[axis] + axis_max = (patch_pts[:, axis].max() - center[axis]) * shrink + center[axis] + + map_pts = np.asarray(map_pcd.points) + planar_dists_map = np.linalg.norm(map_pts[:, axes] - center[axes], axis=1) + + victims = np.nonzero( + (planar_dists_map < radius) + & (map_pts[:, axis] >= axis_min) + & (map_pts[:, axis] <= axis_max) + )[0] + + survivors = map_pcd.select_by_index(victims, invert=True) + return survivors + patch_pcd + + +mapper = Map.blueprint + + +def deploy(dimos: DimosCluster, connection: Go2ConnectionProtocol): # type: ignore[no-untyped-def] + mapper = dimos.deploy(Map, global_publish_interval=1.0) # type: ignore[attr-defined] + mapper.global_map.transport = LCMTransport("/global_map", LidarMessage) + mapper.global_costmap.transport = LCMTransport("/global_costmap", OccupancyGrid) + mapper.local_costmap.transport = LCMTransport("/local_costmap", OccupancyGrid) + mapper.lidar.connect(connection.pointcloud) # type: ignore[attr-defined] + mapper.start() + return mapper + + +__all__ = ["Map", "mapper"] diff --git a/dimos/robot/unitree_webrtc/type/odometry.py b/dimos/robot/unitree_webrtc/type/odometry.py new file mode 100644 index 0000000000..59f8ed17f7 --- /dev/null +++ b/dimos/robot/unitree_webrtc/type/odometry.py @@ -0,0 +1,105 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import time +from typing import Literal, TypedDict + +from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Vector3 +from dimos.robot.unitree_webrtc.type.timeseries import ( + Timestamped, +) + +raw_odometry_msg_sample = { + "type": "msg", + "topic": "rt/utlidar/robot_pose", + "data": { + "header": {"stamp": {"sec": 1746565669, "nanosec": 448350564}, "frame_id": "odom"}, + "pose": { + "position": {"x": 5.961965, "y": -2.916958, "z": 0.319509}, + "orientation": {"x": 0.002787, "y": -0.000902, "z": -0.970244, "w": -0.242112}, + }, + }, +} + + +class TimeStamp(TypedDict): + sec: int + nanosec: int + + +class Header(TypedDict): + stamp: TimeStamp + frame_id: str + + +class RawPosition(TypedDict): + x: float + y: float + z: float + + +class Orientation(TypedDict): + x: float + y: float + z: float + w: float + + +class PoseData(TypedDict): + position: RawPosition + orientation: Orientation + + +class OdometryData(TypedDict): + header: Header + pose: PoseData + + +class RawOdometryMessage(TypedDict): + type: Literal["msg"] + topic: str + data: OdometryData + + +class Odometry(PoseStamped, Timestamped): # type: ignore[misc] + name = "geometry_msgs.PoseStamped" + + def __init__(self, frame_id: str = "base_link", *args, **kwargs) -> None: # type: ignore[no-untyped-def] + super().__init__(frame_id=frame_id, *args, **kwargs) # type: ignore[misc] + + @classmethod + def from_msg(cls, msg: RawOdometryMessage) -> "Odometry": + pose = msg["data"]["pose"] + + # Extract position + pos = Vector3( + pose["position"].get("x"), + pose["position"].get("y"), + pose["position"].get("z"), + ) + + rot = Quaternion( + pose["orientation"].get("x"), + pose["orientation"].get("y"), + pose["orientation"].get("z"), + pose["orientation"].get("w"), + ) + + # ts = to_timestamp(msg["data"]["header"]["stamp"]) + # lidar / video timestamps are not available from the robot + # so we are deferring to local time for everything + ts = time.time() + return Odometry(position=pos, orientation=rot, ts=ts, frame_id="world") + + def __repr__(self) -> str: + return f"Odom pos({self.position}), rot({self.orientation})" diff --git a/dimos/robot/unitree_webrtc/type/test_lidar.py b/dimos/robot/unitree_webrtc/type/test_lidar.py new file mode 100644 index 0000000000..0ad918409b --- /dev/null +++ b/dimos/robot/unitree_webrtc/type/test_lidar.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools + +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.utils.testing import SensorReplay + + +def test_init() -> None: + lidar = SensorReplay("office_lidar") + + for raw_frame in itertools.islice(lidar.iterate(), 5): + assert isinstance(raw_frame, dict) + frame = LidarMessage.from_msg(raw_frame) + assert isinstance(frame, LidarMessage) diff --git a/dimos/robot/unitree_webrtc/type/test_map.py b/dimos/robot/unitree_webrtc/type/test_map.py new file mode 100644 index 0000000000..e3d2655266 --- /dev/null +++ b/dimos/robot/unitree_webrtc/type/test_map.py @@ -0,0 +1,100 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from dimos.robot.unitree_webrtc.testing.helpers import show3d +from dimos.robot.unitree_webrtc.testing.mock import Mock +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.map import Map, splice_sphere +from dimos.utils.testing import SensorReplay + + +@pytest.mark.vis +def test_costmap_vis() -> None: + map = Map() + map.start() + mock = Mock("office") + frames = list(mock.iterate()) + + for frame in frames: + print(frame) + map.add_frame(frame) + + # Get global map and costmap + global_map = map.to_lidar_message() + print(f"Global map has {len(global_map.pointcloud.points)} points") + show3d(global_map.pointcloud, title="Global Map").run() + + +@pytest.mark.vis +def test_reconstruction_with_realtime_vis() -> None: + map = Map() + map.start() + mock = Mock("office") + + # Process frames and visualize final map + for frame in mock.iterate(): + map.add_frame(frame) + + show3d(map.pointcloud, title="Reconstructed Map").run() + + +@pytest.mark.vis +def test_splice_vis() -> None: + mock = Mock("test") + target = mock.load("a") + insert = mock.load("b") + show3d(splice_sphere(target.pointcloud, insert.pointcloud, shrink=0.7)).run() + + +@pytest.mark.vis +def test_robot_vis() -> None: + map = Map() + map.start() + mock = Mock("office") + + # Process all frames + for frame in mock.iterate(): + map.add_frame(frame) + + show3d(map.pointcloud, title="global dynamic map test").run() + + +def test_robot_mapping() -> None: + lidar_replay = SensorReplay("office_lidar", autocast=LidarMessage.from_msg) + map = Map(voxel_size=0.5) + + # Mock the output streams to avoid publishing errors + class MockStream: + def publish(self, msg) -> None: + pass # Do nothing + + map.local_costmap = MockStream() + map.global_costmap = MockStream() + map.global_map = MockStream() + + # Process all frames from replay + for frame in lidar_replay.iterate(): + map.add_frame(frame) + + # Check the built map + global_map = map.to_lidar_message() + pointcloud = global_map.pointcloud + + # Verify map has points + assert len(pointcloud.points) > 0 + print(f"Map contains {len(pointcloud.points)} points") + + map._close_module() diff --git a/dimos/robot/unitree_webrtc/type/test_odometry.py b/dimos/robot/unitree_webrtc/type/test_odometry.py new file mode 100644 index 0000000000..e277455cdd --- /dev/null +++ b/dimos/robot/unitree_webrtc/type/test_odometry.py @@ -0,0 +1,81 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from operator import add, sub + +import pytest +import reactivex.operators as ops + +from dimos.robot.unitree_webrtc.type.odometry import Odometry +from dimos.utils.testing import SensorReplay + +_EXPECTED_TOTAL_RAD = -4.05212 + + +def test_dataset_size() -> None: + """Ensure the replay contains the expected number of messages.""" + assert sum(1 for _ in SensorReplay(name="raw_odometry_rotate_walk").iterate()) == 179 + + +def test_odometry_conversion_and_count() -> None: + """Each replay entry converts to :class:`Odometry` and count is correct.""" + for raw in SensorReplay(name="raw_odometry_rotate_walk").iterate(): + odom = Odometry.from_msg(raw) + assert isinstance(raw, dict) + assert isinstance(odom, Odometry) + + +def test_last_yaw_value() -> None: + """Verify yaw of the final message (regression guard).""" + last_msg = SensorReplay(name="raw_odometry_rotate_walk").stream().pipe(ops.last()).run() + + assert last_msg is not None, "Replay is empty" + assert last_msg["data"]["pose"]["orientation"] == { + "x": 0.01077, + "y": 0.008505, + "z": 0.499171, + "w": -0.866395, + } + + +def test_total_rotation_travel_iterate() -> None: + total_rad = 0.0 + prev_yaw: float | None = None + + for odom in SensorReplay(name="raw_odometry_rotate_walk", autocast=Odometry.from_msg).iterate(): + yaw = odom.orientation.radians.z + if prev_yaw is not None: + diff = yaw - prev_yaw + total_rad += diff + prev_yaw = yaw + + assert total_rad == pytest.approx(_EXPECTED_TOTAL_RAD, abs=0.001) + + +def test_total_rotation_travel_rxpy() -> None: + total_rad = ( + SensorReplay(name="raw_odometry_rotate_walk", autocast=Odometry.from_msg) + .stream() + .pipe( + ops.map(lambda odom: odom.orientation.radians.z), + ops.pairwise(), # [1,2,3,4] -> [[1,2], [2,3], [3,4]] + ops.starmap(sub), # [sub(1,2), sub(2,3), sub(3,4)] + ops.reduce(add), + ) + .run() + ) + + assert total_rad == pytest.approx(4.05, abs=0.01) diff --git a/dimos/robot/unitree_webrtc/type/test_timeseries.py b/dimos/robot/unitree_webrtc/type/test_timeseries.py new file mode 100644 index 0000000000..2c7606d9f2 --- /dev/null +++ b/dimos/robot/unitree_webrtc/type/test_timeseries.py @@ -0,0 +1,44 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import datetime, timedelta + +from dimos.robot.unitree_webrtc.type.timeseries import TEvent, TList + +fixed_date = datetime(2025, 5, 13, 15, 2, 5).astimezone() +start_event = TEvent(fixed_date, 1) +end_event = TEvent(fixed_date + timedelta(seconds=10), 9) + +sample_list = TList([start_event, TEvent(fixed_date + timedelta(seconds=2), 5), end_event]) + + +def test_repr() -> None: + assert ( + str(sample_list) + == "Timeseries(date=2025-05-13, start=15:02:05, end=15:02:15, duration=0:00:10, events=3, freq=0.30Hz)" + ) + + +def test_equals() -> None: + assert start_event == TEvent(start_event.ts, 1) + assert start_event != TEvent(start_event.ts, 2) + assert start_event != TEvent(start_event.ts + timedelta(seconds=1), 1) + + +def test_range() -> None: + assert sample_list.time_range() == (start_event.ts, end_event.ts) + + +def test_duration() -> None: + assert sample_list.duration() == timedelta(seconds=10) diff --git a/dimos/robot/unitree_webrtc/type/timeseries.py b/dimos/robot/unitree_webrtc/type/timeseries.py new file mode 100644 index 0000000000..b75a41b932 --- /dev/null +++ b/dimos/robot/unitree_webrtc/type/timeseries.py @@ -0,0 +1,149 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import ABC, abstractmethod +from datetime import datetime, timedelta, timezone +from typing import TYPE_CHECKING, Generic, TypedDict, TypeVar, Union + +if TYPE_CHECKING: + from collections.abc import Iterable + +PAYLOAD = TypeVar("PAYLOAD") + + +class RosStamp(TypedDict): + sec: int + nanosec: int + + +EpochLike = Union[int, float, datetime, RosStamp] + + +def from_ros_stamp(stamp: dict[str, int], tz: timezone | None = None) -> datetime: + """Convert ROS-style timestamp {'sec': int, 'nanosec': int} to datetime.""" + return datetime.fromtimestamp(stamp["sec"] + stamp["nanosec"] / 1e9, tz=tz) + + +def to_human_readable(ts: EpochLike) -> str: + dt = to_datetime(ts) + return dt.strftime("%Y-%m-%d %H:%M:%S") + + +def to_datetime(ts: EpochLike, tz: timezone | None = None) -> datetime: + if isinstance(ts, datetime): + # if ts.tzinfo is None: + # ts = ts.astimezone(tz) + return ts + if isinstance(ts, int | float): + return datetime.fromtimestamp(ts, tz=tz) + if isinstance(ts, dict) and "sec" in ts and "nanosec" in ts: + return datetime.fromtimestamp(ts["sec"] + ts["nanosec"] / 1e9, tz=tz) + raise TypeError("unsupported timestamp type") + + +class Timestamped(ABC): + """Abstract class for an event with a timestamp.""" + + ts: datetime + + def __init__(self, ts: EpochLike) -> None: + self.ts = to_datetime(ts) + + +class TEvent(Timestamped, Generic[PAYLOAD]): + """Concrete class for an event with a timestamp and data.""" + + def __init__(self, timestamp: EpochLike, data: PAYLOAD) -> None: + super().__init__(timestamp) + self.data = data + + def __eq__(self, other: object) -> bool: + if not isinstance(other, TEvent): + return NotImplemented + return self.ts == other.ts and self.data == other.data + + def __repr__(self) -> str: + return f"TEvent(ts={self.ts}, data={self.data})" + + +EVENT = TypeVar("EVENT", bound=Timestamped) # any object that is a subclass of Timestamped + + +class Timeseries(ABC, Generic[EVENT]): + """Abstract class for an iterable of events with timestamps.""" + + @abstractmethod + def __iter__(self) -> Iterable[EVENT]: ... + + @property + def start_time(self) -> datetime: + """Return the timestamp of the earliest event, assuming the data is sorted.""" + return next(iter(self)).ts # type: ignore[call-overload, no-any-return, type-var] + + @property + def end_time(self) -> datetime: + """Return the timestamp of the latest event, assuming the data is sorted.""" + return next(reversed(list(self))).ts # type: ignore[call-overload, no-any-return] + + @property + def frequency(self) -> float: + """Calculate the frequency of events in Hz.""" + return len(list(self)) / (self.duration().total_seconds() or 1) # type: ignore[call-overload] + + def time_range(self) -> tuple[datetime, datetime]: + """Return (earliest_ts, latest_ts). Empty input ⇒ ValueError.""" + return self.start_time, self.end_time + + def duration(self) -> timedelta: + """Total time spanned by the iterable (Δ = last - first).""" + return self.end_time - self.start_time + + def closest_to(self, timestamp: EpochLike) -> EVENT: + """Return the event closest to the given timestamp. Assumes timeseries is sorted.""" + print("closest to", timestamp) + target = to_datetime(timestamp) + print("converted to", target) + target_ts = target.timestamp() + + closest = None + min_dist = float("inf") + + for event in self: # type: ignore[attr-defined] + dist = abs(event.ts - target_ts) + if dist > min_dist: + break + + min_dist = dist + closest = event + + print(f"closest: {closest}") + return closest # type: ignore[return-value] + + def __repr__(self) -> str: + """Return a string representation of the Timeseries.""" + return f"Timeseries(date={self.start_time.strftime('%Y-%m-%d')}, start={self.start_time.strftime('%H:%M:%S')}, end={self.end_time.strftime('%H:%M:%S')}, duration={self.duration()}, events={len(list(self))}, freq={self.frequency:.2f}Hz)" # type: ignore[call-overload] + + def __str__(self) -> str: + """Return a string representation of the Timeseries.""" + return self.__repr__() + + +class TList(list[EVENT], Timeseries[EVENT]): + """A test class that inherits from both list and Timeseries.""" + + def __repr__(self) -> str: + """Return a string representation of the TList using Timeseries repr method.""" + return Timeseries.__repr__(self) diff --git a/dimos/robot/unitree_webrtc/type/vector.py b/dimos/robot/unitree_webrtc/type/vector.py new file mode 100644 index 0000000000..58438c0a98 --- /dev/null +++ b/dimos/robot/unitree_webrtc/type/vector.py @@ -0,0 +1,442 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import builtins +from collections.abc import Iterable +from typing import ( + Any, + Protocol, + TypeVar, + Union, + runtime_checkable, +) + +import numpy as np +from numpy.typing import NDArray + +T = TypeVar("T", bound="Vector") + + +class Vector: + """A wrapper around numpy arrays for vector operations with intuitive syntax.""" + + def __init__(self, *args: Any) -> None: + """Initialize a vector from components or another iterable. + + Examples: + Vector(1, 2) # 2D vector + Vector(1, 2, 3) # 3D vector + Vector([1, 2, 3]) # From list + Vector(np.array([1, 2, 3])) # From numpy array + """ + if len(args) == 1 and hasattr(args[0], "__iter__"): + self._data = np.array(args[0], dtype=float) + elif len(args) == 1: + self._data = np.array([args[0].x, args[0].y, args[0].z], dtype=float) + + else: + self._data = np.array(args, dtype=float) + + @property + def yaw(self) -> float: + return self.x + + @property + def tuple(self) -> tuple[float, ...]: + """Tuple representation of the vector.""" + return tuple(self._data) + + @property + def x(self) -> float: + """X component of the vector.""" + return self._data[0] if len(self._data) > 0 else 0.0 + + @property + def y(self) -> float: + """Y component of the vector.""" + return self._data[1] if len(self._data) > 1 else 0.0 + + @property + def z(self) -> float: + """Z component of the vector.""" + return self._data[2] if len(self._data) > 2 else 0.0 + + @property + def dim(self) -> int: + """Dimensionality of the vector.""" + return len(self._data) + + @property + def data(self) -> NDArray[np.float64]: + """Get the underlying numpy array.""" + return self._data + + def __len__(self) -> int: + return len(self._data) + + def __getitem__(self, idx: int) -> float: + return float(self._data[idx]) + + def __iter__(self) -> Iterable[float]: + return iter(self._data) # type: ignore[no-any-return] + + def __repr__(self) -> str: + components = ",".join(f"{x:.6g}" for x in self._data) + return f"({components})" + + def __str__(self) -> str: + if self.dim < 2: + return self.__repr__() + + def getArrow() -> str: + repr = ["←", "↖", "↑", "↗", "→", "↘", "↓", "↙"] + + if self.y == 0 and self.x == 0: + return "·" + + # Calculate angle in radians and convert to directional index + angle = np.arctan2(self.y, self.x) + # Map angle to 0-7 index (8 directions) with proper orientation + dir_index = int(((angle + np.pi) * 4 / np.pi) % 8) + # Get directional arrow symbol + return repr[dir_index] + + return f"{getArrow()} Vector {self.__repr__()}" + + def serialize(self) -> dict: # type: ignore[type-arg] + """Serialize the vector to a dictionary.""" + return {"type": "vector", "c": self._data.tolist()} + + def __eq__(self, other: Any) -> bool: + if isinstance(other, Vector): + return np.array_equal(self._data, other._data) + return np.array_equal(self._data, np.array(other, dtype=float)) + + def __add__(self: T, other: Union["Vector", Iterable[float]]) -> T: + if isinstance(other, Vector): + return self.__class__(self._data + other._data) + return self.__class__(self._data + np.array(other, dtype=float)) + + def __sub__(self: T, other: Union["Vector", Iterable[float]]) -> T: + if isinstance(other, Vector): + return self.__class__(self._data - other._data) + return self.__class__(self._data - np.array(other, dtype=float)) + + def __mul__(self: T, scalar: float) -> T: + return self.__class__(self._data * scalar) + + def __rmul__(self: T, scalar: float) -> T: + return self.__mul__(scalar) + + def __truediv__(self: T, scalar: float) -> T: + return self.__class__(self._data / scalar) + + def __neg__(self: T) -> T: + return self.__class__(-self._data) + + def dot(self, other: Union["Vector", Iterable[float]]) -> float: + """Compute dot product.""" + if isinstance(other, Vector): + return float(np.dot(self._data, other._data)) + return float(np.dot(self._data, np.array(other, dtype=float))) + + def cross(self: T, other: Union["Vector", Iterable[float]]) -> T: + """Compute cross product (3D vectors only).""" + if self.dim != 3: + raise ValueError("Cross product is only defined for 3D vectors") + + if isinstance(other, Vector): + other_data = other._data + else: + other_data = np.array(other, dtype=float) + + if len(other_data) != 3: + raise ValueError("Cross product requires two 3D vectors") + + return self.__class__(np.cross(self._data, other_data)) + + def length(self) -> float: + """Compute the Euclidean length (magnitude) of the vector.""" + return float(np.linalg.norm(self._data)) + + def length_squared(self) -> float: + """Compute the squared length of the vector (faster than length()).""" + return float(np.sum(self._data * self._data)) + + def normalize(self: T) -> T: + """Return a normalized unit vector in the same direction.""" + length = self.length() + if length < 1e-10: # Avoid division by near-zero + return self.__class__(np.zeros_like(self._data)) + return self.__class__(self._data / length) + + def to_2d(self: T) -> T: + """Convert a vector to a 2D vector by taking only the x and y components.""" + return self.__class__(self._data[:2]) + + def distance(self, other: Union["Vector", Iterable[float]]) -> float: + """Compute Euclidean distance to another vector.""" + if isinstance(other, Vector): + return float(np.linalg.norm(self._data - other._data)) + return float(np.linalg.norm(self._data - np.array(other, dtype=float))) + + def distance_squared(self, other: Union["Vector", Iterable[float]]) -> float: + """Compute squared Euclidean distance to another vector (faster than distance()).""" + if isinstance(other, Vector): + diff = self._data - other._data + else: + diff = self._data - np.array(other, dtype=float) + return float(np.sum(diff * diff)) + + def angle(self, other: Union["Vector", Iterable[float]]) -> float: + """Compute the angle (in radians) between this vector and another.""" + if self.length() < 1e-10 or (isinstance(other, Vector) and other.length() < 1e-10): + return 0.0 + + if isinstance(other, Vector): + other_data = other._data + else: + other_data = np.array(other, dtype=float) + + cos_angle = np.clip( + np.dot(self._data, other_data) + / (np.linalg.norm(self._data) * np.linalg.norm(other_data)), + -1.0, + 1.0, + ) + return float(np.arccos(cos_angle)) + + def project(self: T, onto: Union["Vector", Iterable[float]]) -> T: + """Project this vector onto another vector.""" + if isinstance(onto, Vector): + onto_data = onto._data + else: + onto_data = np.array(onto, dtype=float) + + onto_length_sq = np.sum(onto_data * onto_data) + if onto_length_sq < 1e-10: + return self.__class__(np.zeros_like(self._data)) + + scalar_projection = np.dot(self._data, onto_data) / onto_length_sq + return self.__class__(scalar_projection * onto_data) + + @classmethod + def zeros(cls: type[T], dim: int) -> T: + """Create a zero vector of given dimension.""" + return cls(np.zeros(dim)) + + @classmethod + def ones(cls: type[T], dim: int) -> T: + """Create a vector of ones with given dimension.""" + return cls(np.ones(dim)) + + @classmethod + def unit_x(cls: type[T], dim: int = 3) -> T: + """Create a unit vector in the x direction.""" + v = np.zeros(dim) + v[0] = 1.0 + return cls(v) + + @classmethod + def unit_y(cls: type[T], dim: int = 3) -> T: + """Create a unit vector in the y direction.""" + v = np.zeros(dim) + v[1] = 1.0 + return cls(v) + + @classmethod + def unit_z(cls: type[T], dim: int = 3) -> T: + """Create a unit vector in the z direction.""" + v = np.zeros(dim) + if dim > 2: + v[2] = 1.0 + return cls(v) + + def to_list(self) -> list[float]: + """Convert the vector to a list.""" + return [float(x) for x in self._data] + + def to_tuple(self) -> builtins.tuple[float, ...]: + """Convert the vector to a tuple.""" + return tuple(self._data) + + def to_numpy(self) -> NDArray[np.float64]: + """Convert the vector to a numpy array.""" + return self._data + + +# Protocol approach for static type checking +@runtime_checkable +class VectorLike(Protocol): + """Protocol for types that can be treated as vectors.""" + + def __getitem__(self, key: int) -> float: ... + def __len__(self) -> int: ... + def __iter__(self) -> Iterable[float]: ... + + +def to_numpy(value: VectorLike) -> NDArray[np.float64]: + """Convert a vector-compatible value to a numpy array. + + Args: + value: Any vector-like object (Vector, numpy array, tuple, list) + + Returns: + Numpy array representation + """ + if isinstance(value, Vector): + return value.data + elif isinstance(value, np.ndarray): + return value + else: + return np.array(value, dtype=float) + + +def to_vector(value: VectorLike) -> Vector: + """Convert a vector-compatible value to a Vector object. + + Args: + value: Any vector-like object (Vector, numpy array, tuple, list) + + Returns: + Vector object + """ + if isinstance(value, Vector): + return value + else: + return Vector(value) + + +def to_tuple(value: VectorLike) -> tuple[float, ...]: + """Convert a vector-compatible value to a tuple. + + Args: + value: Any vector-like object (Vector, numpy array, tuple, list) + + Returns: + Tuple of floats + """ + if isinstance(value, Vector): + return tuple(float(x) for x in value.data) + elif isinstance(value, np.ndarray): + return tuple(float(x) for x in value) + elif isinstance(value, tuple): + return tuple(float(x) for x in value) + else: + # Convert to list first to ensure we have an indexable sequence + data = [value[i] for i in range(len(value))] + return tuple(float(x) for x in data) + + +def to_list(value: VectorLike) -> list[float]: + """Convert a vector-compatible value to a list. + + Args: + value: Any vector-like object (Vector, numpy array, tuple, list) + + Returns: + List of floats + """ + if isinstance(value, Vector): + return [float(x) for x in value.data] + elif isinstance(value, np.ndarray): + return [float(x) for x in value] + elif isinstance(value, list): + return [float(x) for x in value] + else: + # Convert to list using indexing + return [float(value[i]) for i in range(len(value))] + + +# Helper functions to check dimensionality +def is_2d(value: VectorLike) -> bool: + """Check if a vector-compatible value is 2D. + + Args: + value: Any vector-like object (Vector, numpy array, tuple, list) + + Returns: + True if the value is 2D + """ + if isinstance(value, Vector): + return len(value) == 2 + elif isinstance(value, np.ndarray): + return value.shape[-1] == 2 or value.size == 2 + else: + return len(value) == 2 + + +def is_3d(value: VectorLike) -> bool: + """Check if a vector-compatible value is 3D. + + Args: + value: Any vector-like object (Vector, numpy array, tuple, list) + + Returns: + True if the value is 3D + """ + if isinstance(value, Vector): + return len(value) == 3 + elif isinstance(value, np.ndarray): + return value.shape[-1] == 3 or value.size == 3 + else: + return len(value) == 3 + + +# Extraction functions for XYZ components +def x(value: VectorLike) -> float: + """Get the X component of a vector-compatible value. + + Args: + value: Any vector-like object (Vector, numpy array, tuple, list) + + Returns: + X component as a float + """ + if isinstance(value, Vector): + return value.x + else: + return float(to_numpy(value)[0]) + + +def y(value: VectorLike) -> float: + """Get the Y component of a vector-compatible value. + + Args: + value: Any vector-like object (Vector, numpy array, tuple, list) + + Returns: + Y component as a float + """ + if isinstance(value, Vector): + return value.y + else: + arr = to_numpy(value) + return float(arr[1]) if len(arr) > 1 else 0.0 + + +def z(value: VectorLike) -> float: + """Get the Z component of a vector-compatible value. + + Args: + value: Any vector-like object (Vector, numpy array, tuple, list) + + Returns: + Z component as a float + """ + if isinstance(value, Vector): + return value.z + else: + arr = to_numpy(value) + return float(arr[2]) if len(arr) > 2 else 0.0 diff --git a/dimos/robot/unitree_webrtc/unitree_b1/README.md b/dimos/robot/unitree_webrtc/unitree_b1/README.md new file mode 100644 index 0000000000..f59e6a57ae --- /dev/null +++ b/dimos/robot/unitree_webrtc/unitree_b1/README.md @@ -0,0 +1,219 @@ +# Unitree B1 Dimensional Integration + +This module provides UDP-based control for the Unitree B1 quadruped robot with DimOS integration with ROS Twist cmd_vel interface. + +## Overview + +The system consists of two components: +1. **Server Side**: C++ UDP server running on the B1's internal computer +2. **Client Side**: Python control module running on external machine + +Key features: +- 50Hz continuous UDP streaming +- 100ms command timeout for automatic stop +- Standard Twist velocity interface +- Emergency stop (Space/Q keys) +- IDLE/STAND/WALK mode control +- Optional pygame joystick interface + +## Server Side Setup (B1 Internal Computer) + +### Prerequisites + +The B1 robot runs Ubuntu with the following requirements: +- Unitree Legged SDK v3.8.3 for B1 +- Boost (>= 1.71.0) +- CMake (>= 3.16.3) +- g++ (>= 9.4.0) + +### Step 1: Connect to B1 Robot + +1. **Connect to B1's WiFi Access Point**: + - SSID: `Unitree_B1_XXXXX` (where XXXXX is your robot's ID) + - Password: `00000000` (8 zeros) + +2. **SSH into the B1**: + ```bash + ssh unitree@192.168.12.1 + # Default password: 123 + ``` + +### Step 2: Build the UDP Server + +1. **Add joystick_server_udp.cpp to CMakeLists.txt**: + ```bash + # Edit the CMakeLists.txt in the unitree_legged_sdk_B1 directory + vim CMakeLists.txt + + # Add this line with the other add_executable statements: + add_executable(joystick_server example/joystick_server_udp.cpp) + target_link_libraries(joystick_server ${EXTRA_LIBS})``` + +2. **Build the server**: + ```bash + mkdir build + cd build + cmake ../ + make + ``` + +### Step 3: Run the UDP Server + +```bash +# Navigate to build directory +cd Unitree/sdk/unitree_legged_sdk_B1/build/ +./joystick_server + +# You should see: +# UDP Unitree B1 Joystick Control Server +# Communication level: HIGH-level +# Server port: 9090 +# WARNING: Make sure the robot is standing on the ground. +# Press Enter to continue... +``` + +The server will now listen for UDP packets on port 9090 and control the B1 robot. + +### Server Safety Features + +- **100ms timeout**: Robot stops if no packets received for 100ms +- **Packet validation**: Only accepts correctly formatted 19-byte packets +- **Mode restrictions**: Velocities only applied in WALK mode +- **Emergency stop**: Mode 0 (IDLE) stops all movement + +## Client Side Setup (External Machine) + +### Prerequisites + +- Python 3.10+ +- DimOS framework installed +- pygame (optional, for joystick control) + +### Step 1: Install Dependencies + +```bash +# Install Dimensional +pip install -e .[cpu,sim] +``` + +### Step 2: Connect to B1 Network + +1. **Connect your machine to B1's WiFi**: + - SSID: `Unitree_B1_XXXXX` + - Password: `00000000` + +2. **Verify connection**: + ```bash + ping 192.168.12.1 # Should get responses + ``` + +### Step 3: Run the Client + +#### With Joystick Control (Recommended for Testing) + +```bash +python -m dimos.robot.unitree_webrtc.unitree_b1.unitree_b1 \ + --ip 192.168.12.1 \ + --port 9090 \ + --joystick +``` + +**Joystick Controls**: +- `0/1/2` - Switch between IDLE/STAND/WALK modes +- `WASD` - Move forward/backward, turn left/right (only in WALK mode) +- `JL` - Strafe left/right (only in WALK mode) +- `Space/Q` - Emergency stop (switches to IDLE) +- `ESC` - Quit pygame window +- `Ctrl+C` - Exit program + +#### Test Mode (No Robot Required) + +```bash +python -m dimos.robot.unitree_webrtc.unitree_b1.unitree_b1 \ + --test \ + --joystick +``` + +This prints commands instead of sending UDP packets - useful for development. + +## Safety Features + +### Client Side +- **Command freshness tracking**: Stops sending if no new commands for 100ms +- **Emergency stop**: Q or Space immediately sets IDLE mode +- **Mode safety**: Movement only allowed in WALK mode +- **Graceful shutdown**: Sends stop commands on exit + +### Server Side +- **Packet timeout**: Robot stops if no packets for 100ms +- **Continuous monitoring**: Checks timeout before every control update +- **Safe defaults**: Starts in IDLE mode +- **Packet validation**: Rejects malformed packets + +## Architecture + +``` +External Machine (Client) B1 Robot (Server) +┌─────────────────────┐ ┌──────────────────┐ +│ Joystick Module │ │ │ +│ (pygame input) │ │ joystick_server │ +│ ↓ │ │ _udp.cpp │ +│ Twist msg │ │ │ +│ ↓ │ WiFi AP │ │ +│ B1ConnectionModule │◄─────────►│ UDP Port 9090 │ +│ (Twist → B1Command) │ 192.168. │ │ +│ ↓ │ 12.1 │ │ +│ UDP packets 50Hz │ │ Unitree SDK │ +└─────────────────────┘ └──────────────────┘ +``` + +## Setting up ROS Navigation stack with Unitree B1 + +### Setup external Wireless USB Adapter on onboard hardware +This is because the onboard hardware (mini PC, jetson, etc.) needs to connect to both the B1 wifi AP network to send cmd_vel messages over UDP, as well as the network running dimensional + + +Plug in wireless adapter +```bash +nmcli device status +nmcli device wifi list ifname *DEVICE_NAME* +# Connect to b1 network +nmcli device wifi connect "Unitree_B1-251" password "00000000" ifname *DEVICE_NAME* +# Verify connection +nmcli connection show --active +``` + +### *TODO: add more docs* + + +## Troubleshooting + +### Cannot connect to B1 +- Ensure WiFi connection to B1's AP +- Check IP: should be `192.168.12.1` +- Verify server is running: `ssh unitree@192.168.12.1` + +### Robot not responding +- Verify server shows "Client connected" message +- Check robot is in WALK mode (press '2') +- Ensure no timeout messages in server output + +### Timeout issues +- Check network latency: `ping 192.168.12.1` +- Ensure 50Hz sending rate is maintained +- Look for "Command timeout" messages + +### Emergency situations +- Press Space or Q for immediate stop +- Use Ctrl+C to exit cleanly +- Robot auto-stops after 100ms without commands + +## Development Notes + +- Packets are 19 bytes: 4 floats + uint16 + uint8 +- Coordinate system: B1 uses different conventions, hence negations in `b1_command.py` +- LCM topics: `/cmd_vel` for Twist, `/b1/mode` for Int32 mode changes + +## License + +Copyright 2025 Dimensional Inc. Licensed under Apache License 2.0. diff --git a/dimos/robot/unitree_webrtc/unitree_b1/__init__.py b/dimos/robot/unitree_webrtc/unitree_b1/__init__.py new file mode 100644 index 0000000000..e6e5a0f04a --- /dev/null +++ b/dimos/robot/unitree_webrtc/unitree_b1/__init__.py @@ -0,0 +1,8 @@ +#!/usr/bin/env python3 +# Copyright 2025 Dimensional Inc. + +"""Unitree B1 robot module.""" + +from .unitree_b1 import UnitreeB1 + +__all__ = ["UnitreeB1"] diff --git a/dimos/robot/unitree_webrtc/unitree_b1/b1_command.py b/dimos/robot/unitree_webrtc/unitree_b1/b1_command.py new file mode 100644 index 0000000000..2da5f03a55 --- /dev/null +++ b/dimos/robot/unitree_webrtc/unitree_b1/b1_command.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2025 Dimensional Inc. + +"""Internal B1 command structure for UDP communication.""" + +import struct + +from pydantic import BaseModel, Field + + +class B1Command(BaseModel): + """Internal B1 robot command matching UDP packet structure. + + This is an internal type - external interfaces use standard Twist messages. + """ + + # Direct joystick values matching C++ NetworkJoystickCmd struct + lx: float = Field(default=0.0, ge=-1.0, le=1.0) # Turn velocity (left stick X) + ly: float = Field(default=0.0, ge=-1.0, le=1.0) # Forward/back velocity (left stick Y) + rx: float = Field(default=0.0, ge=-1.0, le=1.0) # Strafe velocity (right stick X) + ry: float = Field(default=0.0, ge=-1.0, le=1.0) # Pitch/height adjustment (right stick Y) + buttons: int = Field(default=0, ge=0, le=65535) # Button states (uint16) + mode: int = Field( + default=0, ge=0, le=255 + ) # Control mode (uint8): 0=idle, 1=stand, 2=walk, 6=recovery + + @classmethod + def from_twist(cls, twist, mode: int = 2): # type: ignore[no-untyped-def] + """Create B1Command from standard ROS Twist message. + + This is the key integration point for navigation and planning. + + Args: + twist: ROS Twist message with linear and angular velocities + mode: Robot mode (default is walk mode for navigation) + + Returns: + B1Command configured for the given Twist + """ + # Max velocities from ROS needed to clamp to joystick ranges properly + MAX_LINEAR_VEL = 1.0 # m/s + MAX_ANGULAR_VEL = 2.0 # rad/s + + if mode == 2: # WALK mode - velocity control + return cls( + # Scale and clamp to joystick range [-1, 1] + lx=max(-1.0, min(1.0, -twist.angular.z / MAX_ANGULAR_VEL)), + ly=max(-1.0, min(1.0, twist.linear.x / MAX_LINEAR_VEL)), + rx=max(-1.0, min(1.0, -twist.linear.y / MAX_LINEAR_VEL)), + ry=0.0, # No pitch control in walk mode + mode=mode, + ) + elif mode == 1: # STAND mode - body pose control + # Map Twist pose controls to B1 joystick axes + # Already in normalized units, just clamp to [-1, 1] + return cls( + lx=max(-1.0, min(1.0, -twist.angular.z)), # ROS yaw → B1 yaw + ly=max(-1.0, min(1.0, twist.linear.z)), # ROS height → B1 bodyHeight + rx=max(-1.0, min(1.0, -twist.angular.x)), # ROS roll → B1 roll + ry=max(-1.0, min(1.0, twist.angular.y)), # ROS pitch → B1 pitch + mode=mode, + ) + else: + # IDLE mode - no controls + return cls(mode=mode) + + def to_bytes(self) -> bytes: + """Pack to 19-byte UDP packet matching C++ struct. + + Format: 4 floats + uint16 + uint8 = 19 bytes (little-endian) + """ + return struct.pack(" str: + """Human-readable representation.""" + mode_names = {0: "IDLE", 1: "STAND", 2: "WALK", 6: "RECOVERY"} + mode_str = mode_names.get(self.mode, f"MODE_{self.mode}") + + if self.lx != 0 or self.ly != 0 or self.rx != 0 or self.ry != 0: + return f"B1Cmd[{mode_str}] LX:{self.lx:+.2f} LY:{self.ly:+.2f} RX:{self.rx:+.2f} RY:{self.ry:+.2f}" + else: + return f"B1Cmd[{mode_str}] (idle)" diff --git a/dimos/robot/unitree_webrtc/unitree_b1/connection.py b/dimos/robot/unitree_webrtc/unitree_b1/connection.py new file mode 100644 index 0000000000..b3393e3de8 --- /dev/null +++ b/dimos/robot/unitree_webrtc/unitree_b1/connection.py @@ -0,0 +1,400 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2025 Dimensional Inc. + +"""B1 Connection Module that accepts standard Twist commands and converts to UDP packets.""" + +import logging +import socket +import threading +import time + +from reactivex.disposable import Disposable + +from dimos.core import In, Module, Out, rpc +from dimos.msgs.geometry_msgs import PoseStamped, Twist, TwistStamped +from dimos.msgs.nav_msgs.Odometry import Odometry +from dimos.msgs.std_msgs import Int32 +from dimos.utils.logging_config import setup_logger + +from .b1_command import B1Command + +# Setup logger with DEBUG level for troubleshooting +logger = setup_logger(level=logging.DEBUG) + + +class RobotMode: + """Constants for B1 robot modes.""" + + IDLE = 0 + STAND = 1 + WALK = 2 + RECOVERY = 6 + + +class B1ConnectionModule(Module): + """UDP connection module for B1 robot with standard Twist interface. + + Accepts standard ROS Twist messages on /cmd_vel and mode changes on /b1/mode, + internally converts to B1Command format, and sends UDP packets at 50Hz. + """ + + cmd_vel: In[TwistStamped] = None # type: ignore[assignment] # Timestamped velocity commands from ROS + mode_cmd: In[Int32] = None # type: ignore[assignment] # Mode changes + odom_in: In[Odometry] = None # type: ignore[assignment] # External odometry from ROS SLAM/lidar + + odom_pose: Out[PoseStamped] = None # type: ignore[assignment] # Converted pose for internal use + + def __init__( # type: ignore[no-untyped-def] + self, ip: str = "192.168.12.1", port: int = 9090, test_mode: bool = False, *args, **kwargs + ) -> None: + """Initialize B1 connection module. + + Args: + ip: Robot IP address + port: UDP port for joystick server + test_mode: If True, print commands instead of sending UDP + """ + Module.__init__(self, *args, **kwargs) + + self.ip = ip + self.port = port + self.test_mode = test_mode + self.current_mode = RobotMode.IDLE # Start in IDLE mode + self._current_cmd = B1Command(mode=RobotMode.IDLE) + self.cmd_lock = threading.Lock() # Thread lock for _current_cmd access + # Thread control + self.running = False + self.send_thread = None + self.socket = None + self.packet_count = 0 + self.last_command_time = time.time() + self.command_timeout = 0.2 # 200ms safety timeout + self.watchdog_thread = None + self.watchdog_running = False + self.timeout_active = False + + @rpc + def start(self) -> None: + """Start the connection and subscribe to command streams.""" + + super().start() + + # Setup UDP socket (unless in test mode) + if not self.test_mode: + self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) # type: ignore[assignment] + logger.info(f"B1 Connection started - UDP to {self.ip}:{self.port} at 50Hz") + else: + logger.info(f"[TEST MODE] B1 Connection started - would send to {self.ip}:{self.port}") + + # Subscribe to input streams + if self.cmd_vel: + unsub = self.cmd_vel.subscribe(self.handle_twist_stamped) + self._disposables.add(Disposable(unsub)) + if self.mode_cmd: + unsub = self.mode_cmd.subscribe(self.handle_mode) + self._disposables.add(Disposable(unsub)) + if self.odom_in: + unsub = self.odom_in.subscribe(self._publish_odom_pose) + self._disposables.add(Disposable(unsub)) + + # Start threads + self.running = True + self.watchdog_running = True + + # Start 50Hz sending thread + self.send_thread = threading.Thread(target=self._send_loop, daemon=True) # type: ignore[assignment] + self.send_thread.start() # type: ignore[attr-defined] + + # Start watchdog thread + self.watchdog_thread = threading.Thread(target=self._watchdog_loop, daemon=True) # type: ignore[assignment] + self.watchdog_thread.start() # type: ignore[attr-defined] + + @rpc + def stop(self) -> None: + """Stop the connection and send stop commands.""" + + self.set_mode(RobotMode.IDLE) # IDLE + with self.cmd_lock: + self._current_cmd = B1Command(mode=RobotMode.IDLE) # Zero all velocities + + # Send multiple stop packets + if not self.test_mode and self.socket: + stop_cmd = B1Command(mode=RobotMode.IDLE) + for _ in range(5): + data = stop_cmd.to_bytes() + self.socket.sendto(data, (self.ip, self.port)) + time.sleep(0.02) + + self.running = False + self.watchdog_running = False + + if self.send_thread: + self.send_thread.join(timeout=0.5) + if self.watchdog_thread: + self.watchdog_thread.join(timeout=0.5) + + if self.socket: + self.socket.close() + self.socket = None + + super().stop() + + def handle_twist_stamped(self, twist_stamped: TwistStamped) -> None: + """Handle timestamped Twist message and convert to B1Command. + + This is called automatically when messages arrive on cmd_vel input. + """ + # Extract Twist from TwistStamped + twist = Twist(linear=twist_stamped.linear, angular=twist_stamped.angular) + + logger.debug( + f"Received cmd_vel: linear=({twist.linear.x:.3f}, {twist.linear.y:.3f}, {twist.linear.z:.3f}), angular=({twist.angular.x:.3f}, {twist.angular.y:.3f}, {twist.angular.z:.3f})" + ) + + # In STAND mode, all twist values control body pose, not movement + # W/S: height (linear.z), A/D: yaw (angular.z), J/L: roll (angular.x), I/K: pitch (angular.y) + if self.current_mode == RobotMode.STAND: + # In STAND mode, don't auto-switch since all inputs are valid body pose controls + has_movement = False + else: + # In other modes, consider linear x/y and angular.z as movement + has_movement = ( + abs(twist.linear.x) > 0.01 + or abs(twist.linear.y) > 0.01 + or abs(twist.angular.z) > 0.01 + ) + + if has_movement and self.current_mode not in (RobotMode.STAND, RobotMode.WALK): + logger.info("Auto-switching to WALK mode for ROS control") + self.set_mode(RobotMode.WALK) + elif not has_movement and self.current_mode == RobotMode.WALK: + logger.info("Auto-switching to IDLE mode (zero velocities)") + self.set_mode(RobotMode.IDLE) + + if self.test_mode: + logger.info( + f"[TEST] Received TwistStamped: linear=({twist.linear.x:.2f}, {twist.linear.y:.2f}), angular.z={twist.angular.z:.2f}" + ) + + with self.cmd_lock: + self._current_cmd = B1Command.from_twist(twist, self.current_mode) + + logger.debug(f"Converted to B1Command: {self._current_cmd}") + + self.last_command_time = time.time() + self.timeout_active = False # Reset timeout state since we got a new command + + def handle_mode(self, mode_msg: Int32) -> None: + """Handle mode change message. + + This is called automatically when messages arrive on mode_cmd input. + """ + logger.debug(f"Received mode change: {mode_msg.data}") + if self.test_mode: + logger.info(f"[TEST] Received mode change: {mode_msg.data}") + self.set_mode(mode_msg.data) + + @rpc + def set_mode(self, mode: int) -> bool: + """Set robot mode (0=idle, 1=stand, 2=walk, 6=recovery).""" + self.current_mode = mode + with self.cmd_lock: + self._current_cmd.mode = mode + + # Clear velocities when not in walk mode + if mode != RobotMode.WALK: + self._current_cmd.lx = 0.0 + self._current_cmd.ly = 0.0 + self._current_cmd.rx = 0.0 + self._current_cmd.ry = 0.0 + + mode_names = { + RobotMode.IDLE: "IDLE", + RobotMode.STAND: "STAND", + RobotMode.WALK: "WALK", + RobotMode.RECOVERY: "RECOVERY", + } + logger.info(f"Mode changed to: {mode_names.get(mode, mode)}") + if self.test_mode: + logger.info(f"[TEST] Mode changed to: {mode_names.get(mode, mode)}") + + return True + + def _send_loop(self) -> None: + """Continuously send current command at 50Hz. + + The watchdog thread handles timeout and zeroing commands, so this loop + just sends whatever is in self._current_cmd at 50Hz. + """ + while self.running: + try: + # Watchdog handles timeout, we just send current command + with self.cmd_lock: + cmd_to_send = self._current_cmd + + # Log status every second (50 packets) + if self.packet_count % 50 == 0: + logger.info( + f"Sending B1 commands at 50Hz | Mode: {self.current_mode} | Count: {self.packet_count}" + ) + if not self.test_mode: + logger.debug(f"Current B1Command: {self._current_cmd}") + data = cmd_to_send.to_bytes() + hex_str = " ".join(f"{b:02x}" for b in data) + logger.debug(f"UDP packet ({len(data)} bytes): {hex_str}") + + if self.socket: + data = cmd_to_send.to_bytes() + self.socket.sendto(data, (self.ip, self.port)) + + self.packet_count += 1 + + # 50Hz rate (20ms between packets) + time.sleep(0.020) + + except Exception as e: + if self.running: + logger.error(f"Send error: {e}") + + def _publish_odom_pose(self, msg: Odometry) -> None: + """Convert and publish odometry as PoseStamped. + + This matches G1's approach of receiving external odometry. + """ + if self.odom_pose: + pose_stamped = PoseStamped( + ts=msg.ts, + frame_id=msg.frame_id, + position=msg.pose.pose.position, + orientation=msg.pose.pose.orientation, + ) + self.odom_pose.publish(pose_stamped) + + def _watchdog_loop(self) -> None: + """Single watchdog thread that monitors command freshness.""" + while self.watchdog_running: + try: + time_since_last_cmd = time.time() - self.last_command_time + + if time_since_last_cmd > self.command_timeout: + if not self.timeout_active: + # First time detecting timeout + logger.warning( + f"Watchdog timeout ({time_since_last_cmd:.1f}s) - zeroing commands" + ) + if self.test_mode: + logger.info("[TEST] Watchdog timeout - zeroing commands") + + with self.cmd_lock: + self._current_cmd.lx = 0.0 + self._current_cmd.ly = 0.0 + self._current_cmd.rx = 0.0 + self._current_cmd.ry = 0.0 + + self.timeout_active = True + else: + if self.timeout_active: + logger.info("Watchdog: Commands resumed - control restored") + if self.test_mode: + logger.info("[TEST] Watchdog: Commands resumed") + self.timeout_active = False + + # Check every 50ms + time.sleep(0.05) + + except Exception as e: + if self.watchdog_running: + logger.error(f"Watchdog error: {e}") + + @rpc + def idle(self) -> bool: + """Set robot to idle mode.""" + self.set_mode(RobotMode.IDLE) + return True + + @rpc + def pose(self) -> bool: + """Set robot to stand/pose mode for reaching ground objects with manipulator.""" + self.set_mode(RobotMode.STAND) + return True + + @rpc + def walk(self) -> bool: + """Set robot to walk mode.""" + self.set_mode(RobotMode.WALK) + return True + + @rpc + def recovery(self) -> bool: + """Set robot to recovery mode.""" + self.set_mode(RobotMode.RECOVERY) + return True + + @rpc + def move(self, twist_stamped: TwistStamped, duration: float = 0.0) -> bool: + """Direct RPC method for sending TwistStamped commands. + + Args: + twist_stamped: Timestamped velocity command + duration: Not used, kept for compatibility + """ + self.handle_twist_stamped(twist_stamped) + return True + + +class MockB1ConnectionModule(B1ConnectionModule): + """Test connection module that prints commands instead of sending UDP.""" + + def __init__(self, ip: str = "127.0.0.1", port: int = 9090, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + """Initialize test connection without creating socket.""" + super().__init__(ip, port, test_mode=True, *args, **kwargs) # type: ignore[misc] + + def _send_loop(self) -> None: + """Override to provide better test output with timeout detection.""" + timeout_warned = False + + while self.running: + time_since_last_cmd = time.time() - self.last_command_time + is_timeout = time_since_last_cmd > self.command_timeout + + # Show timeout transitions + if is_timeout and not timeout_warned: + logger.info( + f"[TEST] Command timeout! Sending zeros after {time_since_last_cmd:.1f}s" + ) + timeout_warned = True + elif not is_timeout and timeout_warned: + logger.info("[TEST] Commands resumed - control restored") + timeout_warned = False + + # Print current state every 0.5 seconds + if self.packet_count % 25 == 0: + if is_timeout: + logger.info(f"[TEST] B1Cmd[ZEROS] (timeout) | Count: {self.packet_count}") + else: + logger.info(f"[TEST] {self._current_cmd} | Count: {self.packet_count}") + + self.packet_count += 1 + time.sleep(0.020) + + @rpc + def start(self) -> None: + super().start() + + @rpc + def stop(self) -> None: + super().stop() diff --git a/dimos/robot/unitree_webrtc/unitree_b1/joystick_module.py b/dimos/robot/unitree_webrtc/unitree_b1/joystick_module.py new file mode 100644 index 0000000000..235278a0b9 --- /dev/null +++ b/dimos/robot/unitree_webrtc/unitree_b1/joystick_module.py @@ -0,0 +1,282 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2025 Dimensional Inc. + +"""Pygame Joystick Module for testing B1 control via LCM.""" + +import os +import threading + +# Force X11 driver to avoid OpenGL threading issues +os.environ["SDL_VIDEODRIVER"] = "x11" + +import time + +from dimos.core import Module, Out, rpc +from dimos.msgs.geometry_msgs import Twist, TwistStamped, Vector3 +from dimos.msgs.std_msgs import Int32 + + +class JoystickModule(Module): + """Pygame-based joystick control module for B1 testing. + + Outputs timestamped Twist messages on /cmd_vel and mode changes on /b1/mode. + This allows testing the same interface that navigation will use. + """ + + twist_out: Out[TwistStamped] = None # type: ignore[assignment] # Timestamped velocity commands + mode_out: Out[Int32] = None # type: ignore[assignment] # Mode changes + + def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + Module.__init__(self, *args, **kwargs) + self.pygame_ready = False + self.running = False + self.current_mode = 0 # Start in IDLE mode for safety + + @rpc + def start(self) -> bool: + """Initialize pygame and start control loop.""" + + super().start() + + try: + import pygame + except ImportError: + print("ERROR: pygame not installed. Install with: pip install pygame") + return False + + self.keys_held = set() # type: ignore[var-annotated] + self.pygame_ready = True + self.running = True + + # Start pygame loop in background thread - ALL pygame ops will happen there + self._thread = threading.Thread(target=self._pygame_loop, daemon=True) + self._thread.start() + + return True + + @rpc + def stop(self) -> None: + """Stop the joystick module.""" + + self.running = False + self.pygame_ready = False + + # Send stop command + stop_twist = Twist() + stop_twist_stamped = TwistStamped( + ts=time.time(), + frame_id="base_link", + linear=stop_twist.linear, + angular=stop_twist.angular, + ) + self.twist_out.publish(stop_twist_stamped) + + self._thread.join(2) + + super().stop() + + def _pygame_loop(self) -> None: + """Main pygame event loop - ALL pygame operations happen here.""" + import pygame + + # Initialize pygame and create display IN THIS THREAD + pygame.init() + self.screen = pygame.display.set_mode((500, 400), pygame.SWSURFACE) + pygame.display.set_caption("B1 Joystick Control (LCM)") + self.clock = pygame.time.Clock() + self.font = pygame.font.Font(None, 24) + + print("JoystickModule started - Focus pygame window to control") + print("Controls:") + print(" Walk Mode: WASD = Move/Turn, JL = Strafe") + print(" Stand Mode: WASD = Height/Yaw, JL = Roll, IK = Pitch") + print(" 1/2/0 = Stand/Walk/Idle modes") + print(" Space/Q = Emergency Stop") + print(" ESC = Quit (or use Ctrl+C)") + + while self.running and self.pygame_ready: + for event in pygame.event.get(): + if event.type == pygame.QUIT: + self.running = False + elif event.type == pygame.KEYDOWN: + self.keys_held.add(event.key) + + # Mode changes - publish to mode_out for connection module + if event.key == pygame.K_0: + self.current_mode = 0 + mode_msg = Int32() + mode_msg.data = 0 + self.mode_out.publish(mode_msg) + print("Mode: IDLE") + elif event.key == pygame.K_1: + self.current_mode = 1 + mode_msg = Int32() + mode_msg.data = 1 + self.mode_out.publish(mode_msg) + print("Mode: STAND") + elif event.key == pygame.K_2: + self.current_mode = 2 + mode_msg = Int32() + mode_msg.data = 2 + self.mode_out.publish(mode_msg) + print("Mode: WALK") + elif event.key == pygame.K_SPACE or event.key == pygame.K_q: + self.keys_held.clear() + # Send IDLE mode for emergency stop + self.current_mode = 0 + mode_msg = Int32() + mode_msg.data = 0 + self.mode_out.publish(mode_msg) + # Also send zero twist + stop_twist = Twist() + stop_twist.linear = Vector3(0, 0, 0) + stop_twist.angular = Vector3(0, 0, 0) + stop_twist_stamped = TwistStamped( + ts=time.time(), + frame_id="base_link", + linear=stop_twist.linear, + angular=stop_twist.angular, + ) + self.twist_out.publish(stop_twist_stamped) + print("EMERGENCY STOP!") + elif event.key == pygame.K_ESCAPE: + # ESC still quits for development convenience + self.running = False + + elif event.type == pygame.KEYUP: + self.keys_held.discard(event.key) + + # Generate Twist message from held keys + twist = Twist() + twist.linear = Vector3(0, 0, 0) + twist.angular = Vector3(0, 0, 0) + + # Apply controls based on mode + if self.current_mode == 2: # WALK mode - movement control + # Forward/backward (W/S) + if pygame.K_w in self.keys_held: + twist.linear.x = 1.0 # Forward + if pygame.K_s in self.keys_held: + twist.linear.x = -1.0 # Backward + + # Turning (A/D) + if pygame.K_a in self.keys_held: + twist.angular.z = 1.0 # Turn left + if pygame.K_d in self.keys_held: + twist.angular.z = -1.0 # Turn right + + # Strafing (J/L) + if pygame.K_j in self.keys_held: + twist.linear.y = 1.0 # Strafe left + if pygame.K_l in self.keys_held: + twist.linear.y = -1.0 # Strafe right + + elif self.current_mode == 1: # STAND mode - body pose control + # Height control (W/S) - use linear.z for body height + if pygame.K_w in self.keys_held: + twist.linear.z = 1.0 # Raise body + if pygame.K_s in self.keys_held: + twist.linear.z = -1.0 # Lower body + + # Yaw control (A/D) - use angular.z for body yaw + if pygame.K_a in self.keys_held: + twist.angular.z = 1.0 # Rotate body left + if pygame.K_d in self.keys_held: + twist.angular.z = -1.0 # Rotate body right + + # Roll control (J/L) - use angular.x for body roll + if pygame.K_j in self.keys_held: + twist.angular.x = 1.0 # Roll left + if pygame.K_l in self.keys_held: + twist.angular.x = -1.0 # Roll right + + # Pitch control (I/K) - use angular.y for body pitch + if pygame.K_i in self.keys_held: + twist.angular.y = 1.0 # Pitch forward + if pygame.K_k in self.keys_held: + twist.angular.y = -1.0 # Pitch backward + + twist_stamped = TwistStamped( + ts=time.time(), frame_id="base_link", linear=twist.linear, angular=twist.angular + ) + self.twist_out.publish(twist_stamped) + + # Update pygame display + self._update_display(twist) + + # Maintain 50Hz rate + self.clock.tick(50) + + pygame.quit() + print("JoystickModule stopped") + + def _update_display(self, twist) -> None: # type: ignore[no-untyped-def] + """Update pygame window with current status.""" + import pygame + + self.screen.fill((30, 30, 30)) + + # Mode display + y_pos = 20 + mode_text = ["IDLE", "STAND", "WALK"][self.current_mode if self.current_mode < 3 else 0] + mode_color = ( + (0, 255, 0) + if self.current_mode == 2 + else (255, 255, 0) + if self.current_mode == 1 + else (100, 100, 100) + ) + + texts = [ + f"Mode: {mode_text}", + "", + f"Linear X: {twist.linear.x:+.2f}", + f"Linear Y: {twist.linear.y:+.2f}", + f"Linear Z: {twist.linear.z:+.2f}", + f"Angular X: {twist.angular.x:+.2f}", + f"Angular Y: {twist.angular.y:+.2f}", + f"Angular Z: {twist.angular.z:+.2f}", + "Keys: " + ", ".join([pygame.key.name(k).upper() for k in self.keys_held if k < 256]), + ] + + for i, text in enumerate(texts): + if text: + color = mode_color if i == 0 else (255, 255, 255) + surf = self.font.render(text, True, color) + self.screen.blit(surf, (20, y_pos)) + y_pos += 30 + + if ( + twist.linear.x != 0 + or twist.linear.y != 0 + or twist.linear.z != 0 + or twist.angular.x != 0 + or twist.angular.y != 0 + or twist.angular.z != 0 + ): + pygame.draw.circle(self.screen, (255, 0, 0), (450, 30), 15) # Red = moving + else: + pygame.draw.circle(self.screen, (0, 255, 0), (450, 30), 15) # Green = stopped + + y_pos = 300 + help_texts = ["WASD: Move | JL: Strafe | 1/2/0: Modes", "Space/Q: E-Stop | ESC: Quit"] + for text in help_texts: + surf = self.font.render(text, True, (150, 150, 150)) + self.screen.blit(surf, (20, y_pos)) + y_pos += 25 + + pygame.display.flip() diff --git a/dimos/robot/unitree_webrtc/unitree_b1/joystick_server_udp.cpp b/dimos/robot/unitree_webrtc/unitree_b1/joystick_server_udp.cpp new file mode 100644 index 0000000000..e86e999b8d --- /dev/null +++ b/dimos/robot/unitree_webrtc/unitree_b1/joystick_server_udp.cpp @@ -0,0 +1,366 @@ +/***************************************************************** + UDP Joystick Control Server for Unitree B1 Robot + With timeout protection and guaranteed packet boundaries +******************************************************************/ + +#include "unitree_legged_sdk/unitree_legged_sdk.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace UNITREE_LEGGED_SDK; + +// Joystick command structure received over network +struct NetworkJoystickCmd { + float lx; // left stick x (-1 to 1) + float ly; // left stick y (-1 to 1) + float rx; // right stick x (-1 to 1) + float ry; // right stick y (-1 to 1) + uint16_t buttons; // button states + uint8_t mode; // control mode +}; + +class JoystickServer { +public: + JoystickServer(uint8_t level, int server_port) : + safe(LeggedType::B1), + udp(level, 8090, "192.168.123.220", 8082), + server_port_(server_port), + running_(false) { + udp.InitCmdData(cmd); + memset(&joystick_cmd_, 0, sizeof(joystick_cmd_)); + joystick_cmd_.mode = 0; // Start in idle mode + last_packet_time_ = std::chrono::steady_clock::now(); + } + + void Start(); + void Stop(); + +private: + void UDPRecv(); + void UDPSend(); + void RobotControl(); + void NetworkServerThread(); + void ParseJoystickCommand(const NetworkJoystickCmd& net_cmd); + void CheckTimeout(); + + Safety safe; + UDP udp; + HighCmd cmd = {0}; + HighState state = {0}; + + NetworkJoystickCmd joystick_cmd_; + std::mutex cmd_mutex_; + + int server_port_; + int server_socket_; + bool running_; + std::thread server_thread_; + + // Client tracking for debug + struct sockaddr_in last_client_addr_; + bool has_client_ = false; + + // SAFETY: Timeout tracking + std::chrono::steady_clock::time_point last_packet_time_; + const int PACKET_TIMEOUT_MS = 100; // Stop if no packet for 100ms + + float dt = 0.002; + + // Control parameters + const float MAX_FORWARD_SPEED = 0.2f; // m/s + const float MAX_SIDE_SPEED = 0.2f; // m/s + const float MAX_YAW_SPEED = 0.2f; // rad/s + const float MAX_BODY_HEIGHT = 0.1f; // m + const float MAX_EULER_ANGLE = 0.3f; // rad + const float DEADZONE = 0.0f; // joystick deadzone +}; + +void JoystickServer::Start() { + running_ = true; + + // Start network server thread + server_thread_ = std::thread(&JoystickServer::NetworkServerThread, this); + + // Initialize environment + InitEnvironment(); + + // Start control loops + LoopFunc loop_control("control_loop", dt, boost::bind(&JoystickServer::RobotControl, this)); + LoopFunc loop_udpSend("udp_send", dt, 3, boost::bind(&JoystickServer::UDPSend, this)); + LoopFunc loop_udpRecv("udp_recv", dt, 3, boost::bind(&JoystickServer::UDPRecv, this)); + + loop_udpSend.start(); + loop_udpRecv.start(); + loop_control.start(); + + std::cout << "UDP Joystick server started on port " << server_port_ << std::endl; + std::cout << "Timeout protection: " << PACKET_TIMEOUT_MS << "ms" << std::endl; + std::cout << "Expected packet size: 19 bytes" << std::endl; + std::cout << "Robot control loops started" << std::endl; + + // Keep running + while (running_) { + sleep(1); + } +} + +void JoystickServer::Stop() { + running_ = false; + close(server_socket_); + if (server_thread_.joinable()) { + server_thread_.join(); + } +} + +void JoystickServer::NetworkServerThread() { + // Create UDP socket + server_socket_ = socket(AF_INET, SOCK_DGRAM, 0); + if (server_socket_ < 0) { + std::cerr << "Failed to create UDP socket" << std::endl; + return; + } + + // Allow socket reuse + int opt = 1; + setsockopt(server_socket_, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)); + + // Bind socket + struct sockaddr_in server_addr; + server_addr.sin_family = AF_INET; + server_addr.sin_addr.s_addr = INADDR_ANY; + server_addr.sin_port = htons(server_port_); + + if (bind(server_socket_, (struct sockaddr*)&server_addr, sizeof(server_addr)) < 0) { + std::cerr << "Failed to bind UDP socket to port " << server_port_ << std::endl; + close(server_socket_); + return; + } + + std::cout << "UDP server listening on port " << server_port_ << std::endl; + std::cout << "Waiting for joystick packets..." << std::endl; + + NetworkJoystickCmd net_cmd; + struct sockaddr_in client_addr; + socklen_t client_len; + + while (running_) { + client_len = sizeof(client_addr); + + // Receive UDP datagram (blocks until packet arrives) + ssize_t bytes = recvfrom(server_socket_, &net_cmd, sizeof(net_cmd), + 0, (struct sockaddr*)&client_addr, &client_len); + + if (bytes == 19) { + // Perfect packet size from Python client + if (!has_client_) { + std::cout << "Client connected from " << inet_ntoa(client_addr.sin_addr) + << ":" << ntohs(client_addr.sin_port) << std::endl; + has_client_ = true; + last_client_addr_ = client_addr; + } + ParseJoystickCommand(net_cmd); + } else if (bytes == sizeof(NetworkJoystickCmd)) { + // C++ client with padding (20 bytes) + if (!has_client_) { + std::cout << "C++ Client connected from " << inet_ntoa(client_addr.sin_addr) + << ":" << ntohs(client_addr.sin_port) << std::endl; + has_client_ = true; + last_client_addr_ = client_addr; + } + ParseJoystickCommand(net_cmd); + } else if (bytes > 0) { + // Wrong packet size - ignore but log + static int error_count = 0; + if (error_count++ < 5) { // Only log first 5 errors + std::cerr << "Ignored packet with wrong size: " << bytes + << " bytes (expected 19)" << std::endl; + } + } + // Note: recvfrom returns -1 on error, which we ignore + } +} + +void JoystickServer::ParseJoystickCommand(const NetworkJoystickCmd& net_cmd) { + std::lock_guard lock(cmd_mutex_); + joystick_cmd_ = net_cmd; + + // SAFETY: Update timestamp for timeout tracking + last_packet_time_ = std::chrono::steady_clock::now(); + + // Apply deadzone to analog sticks + if (fabs(joystick_cmd_.lx) < DEADZONE) joystick_cmd_.lx = 0; + if (fabs(joystick_cmd_.ly) < DEADZONE) joystick_cmd_.ly = 0; + if (fabs(joystick_cmd_.rx) < DEADZONE) joystick_cmd_.rx = 0; + if (fabs(joystick_cmd_.ry) < DEADZONE) joystick_cmd_.ry = 0; +} + +void JoystickServer::CheckTimeout() { + auto now = std::chrono::steady_clock::now(); + auto elapsed = std::chrono::duration_cast( + now - last_packet_time_).count(); + + static bool timeout_printed = false; + + if (elapsed > PACKET_TIMEOUT_MS) { + joystick_cmd_.lx = 0; + joystick_cmd_.ly = 0; + joystick_cmd_.rx = 0; + joystick_cmd_.ry = 0; + joystick_cmd_.buttons = 0; + + if (!timeout_printed) { + std::cout << "SAFETY: Packet timeout - stopping movement!" << std::endl; + timeout_printed = true; + } + } else { + // Reset flag when packets resume + if (timeout_printed) { + std::cout << "Packets resumed - control restored" << std::endl; + timeout_printed = false; + } + } +} + +void JoystickServer::UDPRecv() { + udp.Recv(); +} + +void JoystickServer::UDPSend() { + udp.Send(); +} + +void JoystickServer::RobotControl() { + udp.GetRecv(state); + + // SAFETY: Check for packet timeout + NetworkJoystickCmd current_cmd; + { + std::lock_guard lock(cmd_mutex_); + CheckTimeout(); // This may zero movement if timeout + current_cmd = joystick_cmd_; + } + + cmd.mode = 0; + cmd.gaitType = 0; + cmd.speedLevel = 0; + cmd.footRaiseHeight = 0; + cmd.bodyHeight = 0; + cmd.euler[0] = 0; + cmd.euler[1] = 0; + cmd.euler[2] = 0; + cmd.velocity[0] = 0.0f; + cmd.velocity[1] = 0.0f; + cmd.yawSpeed = 0.0f; + cmd.reserve = 0; + + // Set mode from joystick + cmd.mode = current_cmd.mode; + + // Map joystick to robot control based on mode + switch (current_cmd.mode) { + case 0: // Idle + // Robot stops + break; + + case 1: // Force stand with body control + // Left stick controls body height and yaw + cmd.bodyHeight = current_cmd.ly * MAX_BODY_HEIGHT; + cmd.euler[2] = current_cmd.lx * MAX_EULER_ANGLE; + + // Right stick controls pitch and roll + cmd.euler[1] = current_cmd.ry * MAX_EULER_ANGLE; + cmd.euler[0] = current_cmd.rx * MAX_EULER_ANGLE; + break; + + case 2: // Walk mode + cmd.velocity[0] = std::clamp(current_cmd.ly * MAX_FORWARD_SPEED, -MAX_FORWARD_SPEED, MAX_FORWARD_SPEED); + cmd.yawSpeed = std::clamp(-current_cmd.lx * MAX_YAW_SPEED, -MAX_YAW_SPEED, MAX_YAW_SPEED); + cmd.velocity[1] = std::clamp(-current_cmd.rx * MAX_SIDE_SPEED, -MAX_SIDE_SPEED, MAX_SIDE_SPEED); + + // Check button states for gait type + if (current_cmd.buttons & 0x0001) { // Button A + cmd.gaitType = 0; // Trot + } else if (current_cmd.buttons & 0x0002) { // Button B + cmd.gaitType = 1; // Trot running + } else if (current_cmd.buttons & 0x0004) { // Button X + cmd.gaitType = 2; // Climb mode + } else if (current_cmd.buttons & 0x0008) { // Button Y + cmd.gaitType = 3; // Trot obstacle + } + break; + + case 5: // Damping mode + case 6: // Recovery stand up + break; + + default: + cmd.mode = 0; // Default to idle for safety + break; + } + + // Debug output + static int counter = 0; + if (counter++ % 500 == 0) { // Print every second + auto now = std::chrono::steady_clock::now(); + auto elapsed = std::chrono::duration_cast( + now - last_packet_time_).count(); + + std::cout << "Mode: " << (int)cmd.mode + << " Vel: [" << cmd.velocity[0] << ", " << cmd.velocity[1] << "]" + << " Yaw: " << cmd.yawSpeed + << " Last packet: " << elapsed << "ms ago" + << " IMU: " << state.imu.rpy[2] << std::endl; + } + + udp.SetSend(cmd); +} + +// Signal handler for clean shutdown +JoystickServer* g_server = nullptr; + +void signal_handler(int sig) { + if (g_server) { + std::cout << "\nShutting down server..." << std::endl; + g_server->Stop(); + } + exit(0); +} + +int main(int argc, char* argv[]) { + int port = 9090; // Default port + + if (argc > 1) { + port = atoi(argv[1]); + } + + std::cout << "UDP Unitree B1 Joystick Control Server" << std::endl; + std::cout << "Communication level: HIGH-level" << std::endl; + std::cout << "Protocol: UDP (datagram)" << std::endl; + std::cout << "Server port: " << port << std::endl; + std::cout << "Packet size: 19 bytes (Python) or 20 bytes (C++)" << std::endl; + std::cout << "Update rate: 50Hz expected" << std::endl; + std::cout << "WARNING: Make sure the robot is standing on the ground." << std::endl; + std::cout << "Press Enter to continue..." << std::endl; + std::cin.ignore(); + + JoystickServer server(HIGHLEVEL, port); + g_server = &server; + + // Set up signal handler + signal(SIGINT, signal_handler); + signal(SIGTERM, signal_handler); + + server.Start(); + + return 0; +} diff --git a/dimos/robot/unitree_webrtc/unitree_b1/test_connection.py b/dimos/robot/unitree_webrtc/unitree_b1/test_connection.py new file mode 100644 index 0000000000..9970b9c151 --- /dev/null +++ b/dimos/robot/unitree_webrtc/unitree_b1/test_connection.py @@ -0,0 +1,431 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2025 Dimensional Inc. + +"""Comprehensive tests for Unitree B1 connection module Timer implementation.""" + +# TODO: These tests are reaching too much into `conn` by setting and shutting +# down threads manually. That code is already in the connection module, and +# should be used and tested. Additionally, tests should always use `try-finally` +# to clean up even if the test fails. + +import threading +import time + +from dimos.msgs.geometry_msgs import TwistStamped, Vector3 +from dimos.msgs.std_msgs.Int32 import Int32 + +from .connection import MockB1ConnectionModule + + +class TestB1Connection: + """Test suite for B1 connection module with Timer implementation.""" + + def test_watchdog_actually_zeros_commands(self) -> None: + """Test that watchdog thread zeros commands after timeout.""" + conn = MockB1ConnectionModule(ip="127.0.0.1", port=9090) + conn.running = True + conn.watchdog_running = True + conn.send_thread = threading.Thread(target=conn._send_loop, daemon=True) + conn.send_thread.start() + conn.watchdog_thread = threading.Thread(target=conn._watchdog_loop, daemon=True) + conn.watchdog_thread.start() + + # Send a forward command + twist_stamped = TwistStamped( + ts=time.time(), + frame_id="base_link", + linear=Vector3(1.0, 0, 0), + angular=Vector3(0, 0, 0), + ) + conn.handle_twist_stamped(twist_stamped) + + # Verify command is set + assert conn._current_cmd.ly == 1.0 + assert conn._current_cmd.mode == 2 + assert not conn.timeout_active + + # Wait for watchdog timeout (200ms + buffer) + time.sleep(0.3) + + # Verify commands were zeroed by watchdog + assert conn._current_cmd.ly == 0.0 + assert conn._current_cmd.lx == 0.0 + assert conn._current_cmd.rx == 0.0 + assert conn._current_cmd.ry == 0.0 + assert conn._current_cmd.mode == 2 # Mode maintained + assert conn.timeout_active + + conn.running = False + conn.watchdog_running = False + conn.send_thread.join(timeout=0.5) + conn.watchdog_thread.join(timeout=0.5) + conn._close_module() + + def test_watchdog_resets_on_new_command(self) -> None: + """Test that watchdog timeout resets when new command arrives.""" + conn = MockB1ConnectionModule(ip="127.0.0.1", port=9090) + conn.running = True + conn.watchdog_running = True + conn.send_thread = threading.Thread(target=conn._send_loop, daemon=True) + conn.send_thread.start() + conn.watchdog_thread = threading.Thread(target=conn._watchdog_loop, daemon=True) + conn.watchdog_thread.start() + + # Send first command + twist1 = TwistStamped( + ts=time.time(), + frame_id="base_link", + linear=Vector3(1.0, 0, 0), + angular=Vector3(0, 0, 0), + ) + conn.handle_twist_stamped(twist1) + assert conn._current_cmd.ly == 1.0 + + # Wait 150ms (not enough to trigger timeout) + time.sleep(0.15) + + # Send second command before timeout + twist2 = TwistStamped( + ts=time.time(), + frame_id="base_link", + linear=Vector3(0.5, 0, 0), + angular=Vector3(0, 0, 0), + ) + conn.handle_twist_stamped(twist2) + + # Command should be updated and no timeout + assert conn._current_cmd.ly == 0.5 + assert not conn.timeout_active + + # Wait another 150ms (total 300ms from second command) + time.sleep(0.15) + # Should still not timeout since we reset the timer + assert not conn.timeout_active + assert conn._current_cmd.ly == 0.5 + + conn.running = False + conn.watchdog_running = False + conn.send_thread.join(timeout=0.5) + conn.watchdog_thread.join(timeout=0.5) + conn._close_module() + + def test_watchdog_thread_efficiency(self) -> None: + """Test that watchdog uses only one thread regardless of command rate.""" + conn = MockB1ConnectionModule(ip="127.0.0.1", port=9090) + conn.running = True + conn.watchdog_running = True + conn.send_thread = threading.Thread(target=conn._send_loop, daemon=True) + conn.send_thread.start() + conn.watchdog_thread = threading.Thread(target=conn._watchdog_loop, daemon=True) + conn.watchdog_thread.start() + + # Count threads before sending commands + initial_thread_count = threading.active_count() + + # Send many commands rapidly (would create many Timer threads in old implementation) + for i in range(50): + twist = TwistStamped( + ts=time.time(), + frame_id="base_link", + linear=Vector3(i * 0.01, 0, 0), + angular=Vector3(0, 0, 0), + ) + conn.handle_twist_stamped(twist) + time.sleep(0.01) # 100Hz command rate + + # Thread count should be same (no new threads created) + final_thread_count = threading.active_count() + assert final_thread_count == initial_thread_count, "No new threads should be created" + + conn.running = False + conn.watchdog_running = False + conn.send_thread.join(timeout=0.5) + conn.watchdog_thread.join(timeout=0.5) + conn._close_module() + + def test_watchdog_with_send_loop_blocking(self) -> None: + """Test that watchdog still works if send loop blocks.""" + conn = MockB1ConnectionModule(ip="127.0.0.1", port=9090) + + # Mock the send loop to simulate blocking + original_send_loop = conn._send_loop + block_event = threading.Event() + + def blocking_send_loop() -> None: + # Block immediately + block_event.wait() + # Then run normally + original_send_loop() + + conn._send_loop = blocking_send_loop + conn.running = True + conn.watchdog_running = True + conn.send_thread = threading.Thread(target=conn._send_loop, daemon=True) + conn.send_thread.start() + conn.watchdog_thread = threading.Thread(target=conn._watchdog_loop, daemon=True) + conn.watchdog_thread.start() + + # Send command + twist = TwistStamped( + ts=time.time(), + frame_id="base_link", + linear=Vector3(1.0, 0, 0), + angular=Vector3(0, 0, 0), + ) + conn.handle_twist_stamped(twist) + assert conn._current_cmd.ly == 1.0 + + # Wait for watchdog timeout + time.sleep(0.3) + + # Watchdog should have zeroed commands despite blocked send loop + assert conn._current_cmd.ly == 0.0 + assert conn.timeout_active + + # Unblock send loop + block_event.set() + conn.running = False + conn.watchdog_running = False + conn.send_thread.join(timeout=0.5) + conn.watchdog_thread.join(timeout=0.5) + conn._close_module() + + def test_continuous_commands_prevent_timeout(self) -> None: + """Test that continuous commands prevent watchdog timeout.""" + conn = MockB1ConnectionModule(ip="127.0.0.1", port=9090) + conn.running = True + conn.watchdog_running = True + conn.send_thread = threading.Thread(target=conn._send_loop, daemon=True) + conn.send_thread.start() + conn.watchdog_thread = threading.Thread(target=conn._watchdog_loop, daemon=True) + conn.watchdog_thread.start() + + # Send commands continuously for 500ms (should prevent timeout) + start = time.time() + commands_sent = 0 + while time.time() - start < 0.5: + twist = TwistStamped( + ts=time.time(), + frame_id="base_link", + linear=Vector3(0.5, 0, 0), + angular=Vector3(0, 0, 0), + ) + conn.handle_twist_stamped(twist) + commands_sent += 1 + time.sleep(0.05) # 50ms between commands (well under 200ms timeout) + + # Should never timeout + assert not conn.timeout_active, "Should not timeout with continuous commands" + assert conn._current_cmd.ly == 0.5, "Commands should still be active" + assert commands_sent >= 9, f"Should send at least 9 commands in 500ms, sent {commands_sent}" + + conn.running = False + conn.watchdog_running = False + conn.send_thread.join(timeout=0.5) + conn.watchdog_thread.join(timeout=0.5) + conn._close_module() + + def test_watchdog_timing_accuracy(self) -> None: + """Test that watchdog zeros commands at approximately 200ms.""" + conn = MockB1ConnectionModule(ip="127.0.0.1", port=9090) + conn.running = True + conn.watchdog_running = True + conn.send_thread = threading.Thread(target=conn._send_loop, daemon=True) + conn.send_thread.start() + conn.watchdog_thread = threading.Thread(target=conn._watchdog_loop, daemon=True) + conn.watchdog_thread.start() + + # Send command and record time + start_time = time.time() + twist = TwistStamped( + ts=time.time(), + frame_id="base_link", + linear=Vector3(1.0, 0, 0), + angular=Vector3(0, 0, 0), + ) + conn.handle_twist_stamped(twist) + + # Wait for timeout checking periodically + timeout_time = None + while time.time() - start_time < 0.5: + if conn.timeout_active: + timeout_time = time.time() + break + time.sleep(0.01) + + assert timeout_time is not None, "Watchdog should timeout within 500ms" + + # Check timing (should be close to 200ms + up to 50ms watchdog interval) + elapsed = timeout_time - start_time + print(f"\nWatchdog timeout occurred at exactly {elapsed:.3f} seconds") + assert 0.19 <= elapsed <= 0.3, f"Watchdog timed out at {elapsed:.3f}s, expected ~0.2-0.25s" + + conn.running = False + conn.watchdog_running = False + conn.send_thread.join(timeout=0.5) + conn.watchdog_thread.join(timeout=0.5) + conn._close_module() + + def test_mode_changes_with_watchdog(self) -> None: + """Test that mode changes work correctly with watchdog.""" + conn = MockB1ConnectionModule(ip="127.0.0.1", port=9090) + conn.running = True + conn.watchdog_running = True + conn.send_thread = threading.Thread(target=conn._send_loop, daemon=True) + conn.send_thread.start() + conn.watchdog_thread = threading.Thread(target=conn._watchdog_loop, daemon=True) + conn.watchdog_thread.start() + + # Give threads time to initialize + time.sleep(0.05) + + # Send walk command + twist = TwistStamped( + ts=time.time(), + frame_id="base_link", + linear=Vector3(1.0, 0, 0), + angular=Vector3(0, 0, 0), + ) + conn.handle_twist_stamped(twist) + assert conn.current_mode == 2 + assert conn._current_cmd.ly == 1.0 + + # Wait for timeout first (0.2s timeout + 0.15s margin for reliability) + time.sleep(0.35) + assert conn.timeout_active + assert conn._current_cmd.ly == 0.0 # Watchdog zeroed it + + # Now change mode to STAND + mode_msg = Int32() + mode_msg.data = 1 # STAND + conn.handle_mode(mode_msg) + assert conn.current_mode == 1 + assert conn._current_cmd.mode == 1 + # timeout_active stays true since we didn't send new movement commands + + conn.running = False + conn.watchdog_running = False + conn.send_thread.join(timeout=0.5) + conn.watchdog_thread.join(timeout=0.5) + conn._close_module() + + def test_watchdog_stops_movement_when_commands_stop(self) -> None: + """Verify watchdog zeros commands when packets stop being sent.""" + conn = MockB1ConnectionModule(ip="127.0.0.1", port=9090) + conn.running = True + conn.watchdog_running = True + conn.send_thread = threading.Thread(target=conn._send_loop, daemon=True) + conn.send_thread.start() + conn.watchdog_thread = threading.Thread(target=conn._watchdog_loop, daemon=True) + conn.watchdog_thread.start() + + # Simulate sending movement commands for a while + for _i in range(5): + twist = TwistStamped( + ts=time.time(), + frame_id="base_link", + linear=Vector3(1.0, 0, 0), + angular=Vector3(0, 0, 0.5), # Forward and turning + ) + conn.handle_twist_stamped(twist) + time.sleep(0.05) # Send at 20Hz + + # Verify robot is moving + assert conn._current_cmd.ly == 1.0 + assert conn._current_cmd.lx == -0.25 # angular.z * 0.5 -> lx (for turning) + assert conn.current_mode == 2 # WALK mode + assert not conn.timeout_active + + # Wait for watchdog to detect timeout (200ms + buffer) + time.sleep(0.3) + + assert conn.timeout_active, "Watchdog should have detected timeout" + assert conn._current_cmd.ly == 0.0, "Forward velocity should be zeroed" + assert conn._current_cmd.lx == 0.0, "Lateral velocity should be zeroed" + assert conn._current_cmd.rx == 0.0, "Rotation X should be zeroed" + assert conn._current_cmd.ry == 0.0, "Rotation Y should be zeroed" + assert conn.current_mode == 2, "Mode should stay as WALK" + + # Verify recovery works - send new command + twist = TwistStamped( + ts=time.time(), + frame_id="base_link", + linear=Vector3(0.5, 0, 0), + angular=Vector3(0, 0, 0), + ) + conn.handle_twist_stamped(twist) + + # Give watchdog time to detect recovery + time.sleep(0.1) + + assert not conn.timeout_active, "Should recover from timeout" + assert conn._current_cmd.ly == 0.5, "Should accept new commands" + + conn.running = False + conn.watchdog_running = False + conn.send_thread.join(timeout=0.5) + conn.watchdog_thread.join(timeout=0.5) + conn._close_module() + + def test_rapid_command_thread_safety(self) -> None: + """Test thread safety with rapid commands from multiple threads.""" + conn = MockB1ConnectionModule(ip="127.0.0.1", port=9090) + conn.running = True + conn.watchdog_running = True + conn.send_thread = threading.Thread(target=conn._send_loop, daemon=True) + conn.send_thread.start() + conn.watchdog_thread = threading.Thread(target=conn._watchdog_loop, daemon=True) + conn.watchdog_thread.start() + + # Count initial threads + initial_threads = threading.active_count() + + # Send commands from multiple threads rapidly + def send_commands(thread_id) -> None: + for _i in range(10): + twist = TwistStamped( + ts=time.time(), + frame_id="base_link", + linear=Vector3(thread_id * 0.1, 0, 0), + angular=Vector3(0, 0, 0), + ) + conn.handle_twist_stamped(twist) + time.sleep(0.01) + + threads = [] + for i in range(3): + t = threading.Thread(target=send_commands, args=(i,)) + threads.append(t) + t.start() + + for t in threads: + t.join() + + # Thread count should only increase by the 3 sender threads we created + # No additional Timer threads should be created + final_threads = threading.active_count() + assert final_threads <= initial_threads, "No extra threads should be created by watchdog" + + # Commands should still work correctly + assert conn._current_cmd.ly >= 0, "Last command should be set" + assert not conn.timeout_active, "Should not be in timeout with recent commands" + + conn.running = False + conn.watchdog_running = False + conn.send_thread.join(timeout=0.5) + conn.watchdog_thread.join(timeout=0.5) + conn._close_module() diff --git a/dimos/robot/unitree_webrtc/unitree_b1/unitree_b1.py b/dimos/robot/unitree_webrtc/unitree_b1/unitree_b1.py new file mode 100644 index 0000000000..192f3ea672 --- /dev/null +++ b/dimos/robot/unitree_webrtc/unitree_b1/unitree_b1.py @@ -0,0 +1,275 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2025 Dimensional Inc. + +""" +Unitree B1 quadruped robot with simplified UDP control. +Uses standard Twist interface for velocity commands. +""" + +import logging +import os + +from dimos import core +from dimos.core.module_coordinator import ModuleCoordinator +from dimos.core.resource import Resource +from dimos.msgs.geometry_msgs import PoseStamped, TwistStamped +from dimos.msgs.nav_msgs.Odometry import Odometry +from dimos.msgs.std_msgs import Int32 +from dimos.msgs.tf2_msgs.TFMessage import TFMessage +from dimos.robot.robot import Robot +from dimos.robot.ros_bridge import BridgeDirection, ROSBridge +from dimos.robot.unitree_webrtc.unitree_b1.connection import ( + B1ConnectionModule, + MockB1ConnectionModule, +) +from dimos.skills.skills import SkillLibrary +from dimos.types.robot_capabilities import RobotCapability +from dimos.utils.logging_config import setup_logger + +# Handle ROS imports for environments where ROS is not available like CI +try: + from geometry_msgs.msg import TwistStamped as ROSTwistStamped # type: ignore[attr-defined] + from nav_msgs.msg import Odometry as ROSOdometry # type: ignore[attr-defined] + from tf2_msgs.msg import TFMessage as ROSTFMessage # type: ignore[attr-defined] + + ROS_AVAILABLE = True +except ImportError: + ROSTwistStamped = None # type: ignore[assignment, misc] + ROSOdometry = None # type: ignore[assignment, misc] + ROSTFMessage = None # type: ignore[assignment, misc] + ROS_AVAILABLE = False + +logger = setup_logger(level=logging.INFO) + + +class UnitreeB1(Robot, Resource): + """Unitree B1 quadruped robot with UDP control. + + Simplified architecture: + - Connection module handles Twist → B1Command conversion + - Standard /cmd_vel interface for navigation compatibility + - Optional joystick module for testing + """ + + def __init__( + self, + ip: str = "192.168.123.14", + port: int = 9090, + output_dir: str | None = None, + skill_library: SkillLibrary | None = None, + enable_joystick: bool = False, + enable_ros_bridge: bool = True, + test_mode: bool = False, + ) -> None: + """Initialize the B1 robot. + + Args: + ip: Robot IP address (or server running joystick_server_udp) + port: UDP port for joystick server (default 9090) + output_dir: Directory for saving outputs + skill_library: Skill library instance (optional) + enable_joystick: Enable pygame joystick control module + enable_ros_bridge: Enable ROS bridge for external control + test_mode: Test mode - print commands instead of sending UDP + """ + super().__init__() + self.ip = ip + self.port = port + self.output_dir = output_dir or os.path.join(os.getcwd(), "assets", "output") + self.enable_joystick = enable_joystick + self.enable_ros_bridge = enable_ros_bridge + self.test_mode = test_mode + self.capabilities = [RobotCapability.LOCOMOTION] + self.connection = None + self.joystick = None + self.ros_bridge = None + self._dimos = ModuleCoordinator(n=2) + + os.makedirs(self.output_dir, exist_ok=True) + logger.info(f"Robot outputs will be saved to: {self.output_dir}") + + def start(self) -> None: + """Start the B1 robot - initialize DimOS, deploy modules, and start them.""" + + logger.info("Initializing DimOS...") + self._dimos.start() + + logger.info("Deploying connection module...") + if self.test_mode: + self.connection = self._dimos.deploy(MockB1ConnectionModule, self.ip, self.port) # type: ignore[assignment] + else: + self.connection = self._dimos.deploy(B1ConnectionModule, self.ip, self.port) # type: ignore[assignment] + + # Configure LCM transports for connection (matching G1 pattern) + self.connection.cmd_vel.transport = core.LCMTransport("/cmd_vel", TwistStamped) # type: ignore[attr-defined] + self.connection.mode_cmd.transport = core.LCMTransport("/b1/mode", Int32) # type: ignore[attr-defined] + self.connection.odom_in.transport = core.LCMTransport("/state_estimation", Odometry) # type: ignore[attr-defined] + self.connection.odom_pose.transport = core.LCMTransport("/odom", PoseStamped) # type: ignore[attr-defined] + + # Deploy joystick move_vel control + if self.enable_joystick: + from dimos.robot.unitree_webrtc.unitree_b1.joystick_module import JoystickModule + + self.joystick = self._dimos.deploy(JoystickModule) # type: ignore[assignment] + self.joystick.twist_out.transport = core.LCMTransport("/cmd_vel", TwistStamped) # type: ignore[attr-defined] + self.joystick.mode_out.transport = core.LCMTransport("/b1/mode", Int32) # type: ignore[attr-defined] + logger.info("Joystick module deployed - pygame window will open") + + self._dimos.start_all_modules() + + self.connection.idle() # type: ignore[attr-defined] # Start in IDLE mode for safety + logger.info("B1 started in IDLE mode (safety)") + + # Deploy ROS bridge if enabled (matching G1 pattern) + if self.enable_ros_bridge: + self._deploy_ros_bridge() + + logger.info(f"UnitreeB1 initialized - UDP control to {self.ip}:{self.port}") + if self.enable_joystick: + logger.info("Pygame joystick module enabled for testing") + if self.enable_ros_bridge: + logger.info("ROS bridge enabled for external control") + + def stop(self) -> None: + self._dimos.stop() + if self.ros_bridge: + self.ros_bridge.stop() + + def _deploy_ros_bridge(self) -> None: + """Deploy and configure ROS bridge (matching G1 implementation).""" + self.ros_bridge = ROSBridge("b1_ros_bridge") # type: ignore[assignment] + + # Add /cmd_vel topic from ROS to DIMOS + self.ros_bridge.add_topic( # type: ignore[attr-defined] + "/cmd_vel", TwistStamped, ROSTwistStamped, direction=BridgeDirection.ROS_TO_DIMOS + ) + + # Add /state_estimation topic from ROS to DIMOS (external odometry) + self.ros_bridge.add_topic( # type: ignore[attr-defined] + "/state_estimation", Odometry, ROSOdometry, direction=BridgeDirection.ROS_TO_DIMOS + ) + + # Add /tf topic from ROS to DIMOS + self.ros_bridge.add_topic( # type: ignore[attr-defined] + "/tf", TFMessage, ROSTFMessage, direction=BridgeDirection.ROS_TO_DIMOS + ) + + self.ros_bridge.start() # type: ignore[attr-defined] + + logger.info("ROS bridge deployed: /cmd_vel, /state_estimation, /tf (ROS → DIMOS)") + + # Robot control methods (standard interface) + def move(self, twist_stamped: TwistStamped, duration: float = 0.0) -> None: + """Send movement command to robot using timestamped Twist. + + Args: + twist_stamped: TwistStamped message with linear and angular velocities + duration: How long to move (not used for B1) + """ + if self.connection: + self.connection.move(twist_stamped, duration) + + def stand(self) -> None: + """Put robot in stand mode.""" + if self.connection: + self.connection.stand() + logger.info("B1 switched to STAND mode") + + def walk(self) -> None: + """Put robot in walk mode.""" + if self.connection: + self.connection.walk() + logger.info("B1 switched to WALK mode") + + def idle(self) -> None: + """Put robot in idle mode.""" + if self.connection: + self.connection.idle() + logger.info("B1 switched to IDLE mode") + + +def main() -> None: + """Main entry point for testing B1 robot.""" + import argparse + + parser = argparse.ArgumentParser(description="Unitree B1 Robot Control") + parser.add_argument("--ip", default="192.168.12.1", help="Robot IP address") + parser.add_argument("--port", type=int, default=9090, help="UDP port") + parser.add_argument("--joystick", action="store_true", help="Enable pygame joystick control") + parser.add_argument("--ros-bridge", action="store_true", default=True, help="Enable ROS bridge") + parser.add_argument( + "--no-ros-bridge", dest="ros_bridge", action="store_false", help="Disable ROS bridge" + ) + parser.add_argument("--output-dir", help="Output directory for logs/data") + parser.add_argument( + "--test", action="store_true", help="Test mode - print commands instead of UDP" + ) + + args = parser.parse_args() + + robot = UnitreeB1( # type: ignore[abstract] + ip=args.ip, + port=args.port, + output_dir=args.output_dir, + enable_joystick=args.joystick, + enable_ros_bridge=args.ros_bridge, + test_mode=args.test, + ) + + robot.start() + + try: + if args.joystick: + print("\n" + "=" * 50) + print("B1 JOYSTICK CONTROL") + print("=" * 50) + print("Focus the pygame window to control") + print("Press keys in pygame window:") + print(" 0/1/2 = Idle/Stand/Walk modes") + print(" WASD = Move/Turn") + print(" JL = Strafe") + print(" Space/Q = Emergency Stop") + print(" ESC = Quit pygame (then Ctrl+C to exit)") + print("=" * 50 + "\n") + + import time + + while True: + time.sleep(1) + else: + # Manual control example + print("\nB1 Robot ready for commands") + print("Use robot.idle(), robot.stand(), robot.walk() to change modes") + if args.ros_bridge: + print("ROS bridge active - listening for /cmd_vel and /state_estimation") + else: + print("Use robot.move(TwistStamped(...)) to send velocity commands") + print("Press Ctrl+C to exit\n") + + import time + + while True: + time.sleep(1) + + except KeyboardInterrupt: + print("\nShutting down...") + finally: + robot.stop() + + +if __name__ == "__main__": + main() diff --git a/dimos/robot/unitree_webrtc/unitree_g1_blueprints.py b/dimos/robot/unitree_webrtc/unitree_g1_blueprints.py new file mode 100644 index 0000000000..ce5e0ed266 --- /dev/null +++ b/dimos/robot/unitree_webrtc/unitree_g1_blueprints.py @@ -0,0 +1,278 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Blueprint configurations for Unitree G1 humanoid robot. + +This module provides pre-configured blueprints for various G1 robot setups, +from basic teleoperation to full autonomous agent configurations. +""" + +from dimos_lcm.foxglove_msgs import SceneUpdate # type: ignore[import-untyped] +from dimos_lcm.foxglove_msgs.ImageAnnotations import ( # type: ignore[import-untyped] + ImageAnnotations, +) +from dimos_lcm.sensor_msgs import CameraInfo # type: ignore[import-untyped] + +from dimos.agents2.agent import llm_agent +from dimos.agents2.cli.human import human_input +from dimos.agents2.skills.navigation import navigation_skill +from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE +from dimos.core.blueprints import autoconnect +from dimos.core.transport import LCMTransport, pSHMTransport +from dimos.hardware.camera import zed +from dimos.hardware.camera.module import camera_module +from dimos.hardware.camera.webcam import Webcam +from dimos.msgs.geometry_msgs import ( + PoseStamped, + Quaternion, + Transform, + Twist, + Vector3, +) +from dimos.msgs.nav_msgs import Odometry, Path +from dimos.msgs.sensor_msgs import Image, PointCloud2 +from dimos.msgs.std_msgs import Bool +from dimos.msgs.vision_msgs import Detection2DArray +from dimos.navigation.bt_navigator.navigator import ( + behavior_tree_navigator, +) +from dimos.navigation.frontier_exploration import ( + wavefront_frontier_explorer, +) +from dimos.navigation.global_planner import astar_planner +from dimos.navigation.local_planner.holonomic_local_planner import ( + holonomic_local_planner, +) +from dimos.navigation.rosnav import ros_nav +from dimos.perception.detection.detectors.person.yolo import YoloPersonDetector +from dimos.perception.detection.module3D import Detection3DModule, detection3d_module +from dimos.perception.detection.moduleDB import ObjectDBModule, detectionDB_module +from dimos.perception.detection.person_tracker import PersonTracker, person_tracker_module +from dimos.perception.object_tracker import object_tracking +from dimos.perception.spatial_perception import spatial_memory +from dimos.robot.foxglove_bridge import foxglove_bridge +from dimos.robot.unitree.connection.g1 import g1_connection +from dimos.robot.unitree.connection.g1sim import g1_sim_connection +from dimos.robot.unitree_webrtc.keyboard_teleop import keyboard_teleop +from dimos.robot.unitree_webrtc.type.map import mapper +from dimos.robot.unitree_webrtc.unitree_g1_skill_container import g1_skills +from dimos.utils.monitoring import utilization +from dimos.web.websocket_vis.websocket_vis_module import websocket_vis + +_basic_no_nav = ( + autoconnect( + camera_module( + transform=Transform( + translation=Vector3(0.05, 0.0, 0.0), + rotation=Quaternion.from_euler(Vector3(0.0, 0.2, 0.0)), + frame_id="sensor", + child_frame_id="camera_link", + ), + hardware=lambda: Webcam( + camera_index=0, + frequency=15, + stereo_slice="left", + camera_info=zed.CameraInfo.SingleWebcam, + ), + ), + # SLAM and mapping + mapper(voxel_size=0.5, global_publish_interval=2.5), + # Navigation stack + astar_planner(), + holonomic_local_planner(), + wavefront_frontier_explorer(), + # Visualization + websocket_vis(), + foxglove_bridge(), + ) + .global_config(n_dask_workers=4, robot_model="unitree_g1") + .transports( + { + # G1 uses Twist for movement commands + ("cmd_vel", Twist): LCMTransport("/cmd_vel", Twist), + # State estimation from ROS + ("state_estimation", Odometry): LCMTransport("/state_estimation", Odometry), + # Odometry output from ROSNavigationModule + ("odom", PoseStamped): LCMTransport("/odom", PoseStamped), + # Navigation module topics from nav_bot + ("goal_req", PoseStamped): LCMTransport("/goal_req", PoseStamped), + ("goal_active", PoseStamped): LCMTransport("/goal_active", PoseStamped), + ("path_active", Path): LCMTransport("/path_active", Path), + ("pointcloud", PointCloud2): LCMTransport("/lidar", PointCloud2), + ("global_pointcloud", PointCloud2): LCMTransport("/map", PointCloud2), + # Original navigation topics for backwards compatibility + ("goal_pose", PoseStamped): LCMTransport("/goal_pose", PoseStamped), + ("goal_reached", Bool): LCMTransport("/goal_reached", Bool), + ("cancel_goal", Bool): LCMTransport("/cancel_goal", Bool), + # Camera topics (if camera module is added) + ("color_image", Image): LCMTransport("/g1/color_image", Image), + ("camera_info", CameraInfo): LCMTransport("/g1/camera_info", CameraInfo), + } + ) +) + +basic_ros = autoconnect( + _basic_no_nav, + g1_connection(), + ros_nav(), +) + +basic_sim = autoconnect( + _basic_no_nav, + g1_sim_connection(), + behavior_tree_navigator(), +) + +_perception_and_memory = autoconnect( + spatial_memory(), + object_tracking(frame_id="camera_link"), + utilization(), +) + +standard = autoconnect( + basic_ros, + _perception_and_memory, +).global_config(n_dask_workers=8) + +standard_sim = autoconnect( + basic_sim, + _perception_and_memory, +).global_config(n_dask_workers=8) + +# Optimized configuration using shared memory for images +standard_with_shm = autoconnect( + standard.transports( + { + ("color_image", Image): pSHMTransport( + "/g1/color_image", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE + ), + } + ), + foxglove_bridge( + shm_channels=[ + "/g1/color_image#sensor_msgs.Image", + ] + ), +) + +_agentic_skills = autoconnect( + llm_agent(), + human_input(), + navigation_skill(), + g1_skills(), +) + +# Full agentic configuration with LLM and skills +agentic = autoconnect( + standard, + _agentic_skills, +) + +agentic_sim = autoconnect( + standard_sim, + _agentic_skills, +) + +# Configuration with joystick control for teleoperation +with_joystick = autoconnect( + basic_ros, + keyboard_teleop(), # Pygame-based joystick control +) + +# Detection configuration with person tracking and 3D detection +detection = ( + autoconnect( + basic_ros, + # Person detection modules with YOLO + detection3d_module( + camera_info=zed.CameraInfo.SingleWebcam, + detector=YoloPersonDetector, + ), + detectionDB_module( + camera_info=zed.CameraInfo.SingleWebcam, + filter=lambda det: det.class_id == 0, # Filter for person class only + ), + person_tracker_module( + cameraInfo=zed.CameraInfo.SingleWebcam, + ), + ) + .global_config(n_dask_workers=8) + .remappings( + [ + # Connect detection modules to camera and lidar + (Detection3DModule, "image", "color_image"), + (Detection3DModule, "pointcloud", "pointcloud"), + (ObjectDBModule, "image", "color_image"), + (ObjectDBModule, "pointcloud", "pointcloud"), + (PersonTracker, "image", "color_image"), + (PersonTracker, "detections", "detections_2d"), + ] + ) + .transports( + { + # Detection 3D module outputs + ("detections", Detection3DModule): LCMTransport( + "/detector3d/detections", Detection2DArray + ), + ("annotations", Detection3DModule): LCMTransport( + "/detector3d/annotations", ImageAnnotations + ), + ("scene_update", Detection3DModule): LCMTransport( + "/detector3d/scene_update", SceneUpdate + ), + ("detected_pointcloud_0", Detection3DModule): LCMTransport( + "/detector3d/pointcloud/0", PointCloud2 + ), + ("detected_pointcloud_1", Detection3DModule): LCMTransport( + "/detector3d/pointcloud/1", PointCloud2 + ), + ("detected_pointcloud_2", Detection3DModule): LCMTransport( + "/detector3d/pointcloud/2", PointCloud2 + ), + ("detected_image_0", Detection3DModule): LCMTransport("/detector3d/image/0", Image), + ("detected_image_1", Detection3DModule): LCMTransport("/detector3d/image/1", Image), + ("detected_image_2", Detection3DModule): LCMTransport("/detector3d/image/2", Image), + # Detection DB module outputs + ("detections", ObjectDBModule): LCMTransport( + "/detectorDB/detections", Detection2DArray + ), + ("annotations", ObjectDBModule): LCMTransport( + "/detectorDB/annotations", ImageAnnotations + ), + ("scene_update", ObjectDBModule): LCMTransport("/detectorDB/scene_update", SceneUpdate), + ("detected_pointcloud_0", ObjectDBModule): LCMTransport( + "/detectorDB/pointcloud/0", PointCloud2 + ), + ("detected_pointcloud_1", ObjectDBModule): LCMTransport( + "/detectorDB/pointcloud/1", PointCloud2 + ), + ("detected_pointcloud_2", ObjectDBModule): LCMTransport( + "/detectorDB/pointcloud/2", PointCloud2 + ), + ("detected_image_0", ObjectDBModule): LCMTransport("/detectorDB/image/0", Image), + ("detected_image_1", ObjectDBModule): LCMTransport("/detectorDB/image/1", Image), + ("detected_image_2", ObjectDBModule): LCMTransport("/detectorDB/image/2", Image), + # Person tracker outputs + ("target", PersonTracker): LCMTransport("/person_tracker/target", PoseStamped), + } + ) +) + +# Full featured configuration with everything +full_featured = autoconnect( + standard_with_shm, + _agentic_skills, + keyboard_teleop(), +) diff --git a/dimos/robot/unitree_webrtc/unitree_g1_skill_container.py b/dimos/robot/unitree_webrtc/unitree_g1_skill_container.py new file mode 100644 index 0000000000..cd8ad93841 --- /dev/null +++ b/dimos/robot/unitree_webrtc/unitree_g1_skill_container.py @@ -0,0 +1,161 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unitree G1 skill container for the new agents2 framework. +Dynamically generates skills for G1 humanoid robot including arm controls and movement modes. +""" + +import difflib + +from dimos.core.core import rpc +from dimos.core.skill_module import SkillModule +from dimos.msgs.geometry_msgs import Twist, Vector3 +from dimos.protocol.skill.skill import skill +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + +# G1 Arm Actions - all use api_id 7106 on topic "rt/api/arm/request" +G1_ARM_CONTROLS = [ + ("Handshake", 27, "Perform a handshake gesture with the right hand."), + ("HighFive", 18, "Give a high five with the right hand."), + ("Hug", 19, "Perform a hugging gesture with both arms."), + ("HighWave", 26, "Wave with the hand raised high."), + ("Clap", 17, "Clap hands together."), + ("FaceWave", 25, "Wave near the face level."), + ("LeftKiss", 12, "Blow a kiss with the left hand."), + ("ArmHeart", 20, "Make a heart shape with both arms overhead."), + ("RightHeart", 21, "Make a heart gesture with the right hand."), + ("HandsUp", 15, "Raise both hands up in the air."), + ("XRay", 24, "Hold arms in an X-ray pose position."), + ("RightHandUp", 23, "Raise only the right hand up."), + ("Reject", 22, "Make a rejection or 'no' gesture."), + ("CancelAction", 99, "Cancel any current arm action and return hands to neutral position."), +] + +# G1 Movement Modes - all use api_id 7101 on topic "rt/api/sport/request" +G1_MODE_CONTROLS = [ + ("WalkMode", 500, "Switch to normal walking mode."), + ("WalkControlWaist", 501, "Switch to walking mode with waist control."), + ("RunMode", 801, "Switch to running mode."), +] + +_ARM_COMMANDS: dict[str, tuple[int, str]] = { + name: (id_, description) for name, id_, description in G1_ARM_CONTROLS +} + +_MODE_COMMANDS: dict[str, tuple[int, str]] = { + name: (id_, description) for name, id_, description in G1_MODE_CONTROLS +} + + +class UnitreeG1SkillContainer(SkillModule): + rpc_calls: list[str] = [ + "G1ConnectionModule.move", + "G1ConnectionModule.publish_request", + ] + + @rpc + def start(self) -> None: + super().start() + + @rpc + def stop(self) -> None: + super().stop() + + @skill() + def move(self, x: float, y: float = 0.0, yaw: float = 0.0, duration: float = 0.0) -> str: + """Move the robot using direct velocity commands. Determine duration required based on user distance instructions. + + Example call: + args = { "x": 0.5, "y": 0.0, "yaw": 0.0, "duration": 2.0 } + move(**args) + + Args: + x: Forward velocity (m/s) + y: Left/right velocity (m/s) + yaw: Rotational velocity (rad/s) + duration: How long to move (seconds) + """ + + move_rpc = self.get_rpc_calls("G1ConnectionModule.move") + twist = Twist(linear=Vector3(x, y, 0), angular=Vector3(0, 0, yaw)) + move_rpc(twist, duration=duration) + return f"Started moving with velocity=({x}, {y}, {yaw}) for {duration} seconds" + + @skill() + def execute_arm_command(self, command_name: str) -> str: + return self._execute_g1_command(_ARM_COMMANDS, 7106, command_name) + + @skill() + def execute_mode_command(self, command_name: str) -> str: + return self._execute_g1_command(_MODE_COMMANDS, 7101, command_name) + + def _execute_g1_command( + self, command_dict: dict[str, tuple[int, str]], api_id: int, command_name: str + ) -> str: + publish_request_rpc = self.get_rpc_calls("G1ConnectionModule.publish_request") + + if command_name not in command_dict: + suggestions = difflib.get_close_matches( + command_name, command_dict.keys(), n=3, cutoff=0.6 + ) + return f"There's no '{command_name}' command. Did you mean: {suggestions}" + + id_, _ = command_dict[command_name] + + try: + publish_request_rpc( + "rt/api/sport/request", {"api_id": api_id, "parameter": {"data": id_}} + ) + return f"'{command_name}' command executed successfully." + except Exception as e: + logger.error(f"Failed to execute {command_name}: {e}") + return "Failed to execute the command." + + +_arm_commands = "\n".join( + [f'- "{name}": {description}' for name, (_, description) in _ARM_COMMANDS.items()] +) + +UnitreeG1SkillContainer.execute_arm_command.__doc__ = f"""Execute a Unitree G1 arm command. + +Example usage: + + execute_arm_command("ArmHeart") + +Here are all the command names and what they do. + +{_arm_commands} +""" + +_mode_commands = "\n".join( + [f'- "{name}": {description}' for name, (_, description) in _MODE_COMMANDS.items()] +) + +UnitreeG1SkillContainer.execute_mode_command.__doc__ = f"""Execute a Unitree G1 mode command. + +Example usage: + + execute_mode_command("RunMode") + +Here are all the command names and what they do. + +{_mode_commands} +""" + +g1_skills = UnitreeG1SkillContainer.blueprint + +__all__ = ["UnitreeG1SkillContainer", "g1_skills"] diff --git a/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py b/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py new file mode 100644 index 0000000000..24068d86dd --- /dev/null +++ b/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py @@ -0,0 +1,146 @@ +#!/usr/bin/env python3 + +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos_lcm.sensor_msgs import CameraInfo # type: ignore[import-untyped] + +from dimos.agents2.agent import llm_agent +from dimos.agents2.cli.human import human_input +from dimos.agents2.cli.web import web_input +from dimos.agents2.ollama_agent import ollama_installed +from dimos.agents2.skills.navigation import navigation_skill +from dimos.agents2.skills.speak_skill import speak_skill +from dimos.agents2.spec import Provider +from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE +from dimos.core.blueprints import autoconnect +from dimos.core.transport import JpegLcmTransport, JpegShmTransport, LCMTransport, pSHMTransport +from dimos.msgs.geometry_msgs import PoseStamped +from dimos.msgs.sensor_msgs import Image +from dimos.navigation.bt_navigator.navigator import ( + behavior_tree_navigator, +) +from dimos.navigation.frontier_exploration import ( + wavefront_frontier_explorer, +) +from dimos.navigation.global_planner import astar_planner +from dimos.navigation.local_planner.holonomic_local_planner import ( + holonomic_local_planner, +) +from dimos.perception.object_tracker import object_tracking +from dimos.perception.spatial_perception import spatial_memory +from dimos.robot.foxglove_bridge import foxglove_bridge +from dimos.robot.unitree.connection.go2 import go2_connection +from dimos.robot.unitree_webrtc.type.map import mapper +from dimos.robot.unitree_webrtc.unitree_skill_container import unitree_skills +from dimos.utils.monitoring import utilization +from dimos.web.websocket_vis.websocket_vis_module import websocket_vis + +basic = ( + autoconnect( + go2_connection(), + mapper(voxel_size=0.5, global_publish_interval=2.5), + astar_planner(), + holonomic_local_planner(), + behavior_tree_navigator(), + wavefront_frontier_explorer(), + websocket_vis(), + foxglove_bridge(), + ) + .global_config(n_dask_workers=4, robot_model="unitree_go2") + .transports( + # These are kept the same so that we don't have to change foxglove configs. + # Although we probably should. + { + ("color_image", Image): LCMTransport("/go2/color_image", Image), + ("camera_pose", PoseStamped): LCMTransport("/go2/camera_pose", PoseStamped), + ("camera_info", CameraInfo): LCMTransport("/go2/camera_info", CameraInfo), + } + ) +) + +standard = autoconnect( + basic, + spatial_memory(), + object_tracking(frame_id="camera_link"), + utilization(), +).global_config(n_dask_workers=8) + +standard_with_shm = autoconnect( + standard.transports( + { + ("color_image", Image): pSHMTransport( + "/go2/color_image", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE + ), + } + ), + foxglove_bridge( + shm_channels=[ + "/go2/color_image#sensor_msgs.Image", + ] + ), +) + +standard_with_jpeglcm = standard.transports( + { + ("color_image", Image): JpegLcmTransport("/go2/color_image", Image), + } +) + +standard_with_jpegshm = autoconnect( + standard.transports( + { + ("color_image", Image): JpegShmTransport("/go2/color_image", quality=75), + } + ), + foxglove_bridge( + jpeg_shm_channels=[ + "/go2/color_image#sensor_msgs.Image", + ] + ), +) + +_common_agentic = autoconnect( + human_input(), + navigation_skill(), + unitree_skills(), + web_input(), + speak_skill(), +) + +agentic = autoconnect( + standard, + llm_agent(), + _common_agentic, +) + +agentic_ollama = autoconnect( + standard, + llm_agent( + model="qwen3:8b", + provider=Provider.OLLAMA, # type: ignore[attr-defined] + ), + _common_agentic, +).requirements( + ollama_installed, +) + +agentic_huggingface = autoconnect( + standard, + llm_agent( + model="Qwen/Qwen2.5-1.5B-Instruct", + provider=Provider.HUGGINGFACE, # type: ignore[attr-defined] + ), + _common_agentic, +) diff --git a/dimos/robot/unitree_webrtc/unitree_skill_container.py b/dimos/robot/unitree_webrtc/unitree_skill_container.py new file mode 100644 index 0000000000..2584691109 --- /dev/null +++ b/dimos/robot/unitree_webrtc/unitree_skill_container.py @@ -0,0 +1,147 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import datetime +import difflib +import time +from typing import TYPE_CHECKING + +from go2_webrtc_driver.constants import RTC_TOPIC # type: ignore[import-untyped] + +from dimos.core.core import rpc +from dimos.core.skill_module import SkillModule +from dimos.msgs.geometry_msgs import Twist, Vector3 +from dimos.protocol.skill.skill import skill +from dimos.protocol.skill.type import Reducer, Stream +from dimos.robot.unitree_webrtc.unitree_skills import UNITREE_WEBRTC_CONTROLS +from dimos.utils.logging_config import setup_logger + +if TYPE_CHECKING: + from dimos.core.rpc_client import RpcCall + +logger = setup_logger() + + +_UNITREE_COMMANDS = { + name: (id_, description) + for name, id_, description in UNITREE_WEBRTC_CONTROLS + if name not in ["Reverse", "Spin"] +} + + +class UnitreeSkillContainer(SkillModule): + """Container for Unitree Go2 robot skills using the new framework.""" + + _move: RpcCall | None = None + _publish_request: RpcCall | None = None + + @rpc + def start(self) -> None: + super().start() + + @rpc + def stop(self) -> None: + super().stop() + + @rpc + def set_ConnectionModule_move(self, callable: RpcCall) -> None: + self._move = callable + self._move.set_rpc(self.rpc) # type: ignore[arg-type] + + @rpc + def set_ConnectionModule_publish_request(self, callable: RpcCall) -> None: + self._publish_request = callable + self._publish_request.set_rpc(self.rpc) # type: ignore[arg-type] + + @skill() + def move(self, x: float, y: float = 0.0, yaw: float = 0.0, duration: float = 0.0) -> str: + """Move the robot using direct velocity commands. Determine duration required based on user distance instructions. + + Example call: + args = { "x": 0.5, "y": 0.0, "yaw": 0.0, "duration": 2.0 } + move(**args) + + Args: + x: Forward velocity (m/s) + y: Left/right velocity (m/s) + yaw: Rotational velocity (rad/s) + duration: How long to move (seconds) + """ + if self._move is None: + return "Error: Robot not connected" + + twist = Twist(linear=Vector3(x, y, 0), angular=Vector3(0, 0, yaw)) + self._move(twist, duration=duration) + return f"Started moving with velocity=({x}, {y}, {yaw}) for {duration} seconds" + + @skill() + def wait(self, seconds: float) -> str: + """Wait for a specified amount of time. + + Args: + seconds: Seconds to wait + """ + time.sleep(seconds) + return f"Wait completed with length={seconds}s" + + @skill(stream=Stream.passive, reducer=Reducer.latest, hide_skill=True) # type: ignore[arg-type] + def current_time(self): # type: ignore[no-untyped-def] + """Provides current time implicitly, don't call this skill directly.""" + print("Starting current_time skill") + while True: + yield str(datetime.datetime.now()) + time.sleep(1) + + @skill() + def execute_sport_command(self, command_name: str) -> str: + if self._publish_request is None: + return f"Error: Robot not connected (cannot execute {command_name})" + + if command_name not in _UNITREE_COMMANDS: + suggestions = difflib.get_close_matches( + command_name, _UNITREE_COMMANDS.keys(), n=3, cutoff=0.6 + ) + return f"There's no '{command_name}' command. Did you mean: {suggestions}" + + id_, _ = _UNITREE_COMMANDS[command_name] + + try: + self._publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": id_}) + return f"'{command_name}' command executed successfully." + except Exception as e: + logger.error(f"Failed to execute {command_name}: {e}") + return "Failed to execute the command." + + +_commands = "\n".join( + [f'- "{name}": {description}' for name, (_, description) in _UNITREE_COMMANDS.items()] +) + +UnitreeSkillContainer.execute_sport_command.__doc__ = f"""Execute a Unitree sport command. + +Example usage: + + execute_sport_command("FrontPounce") + +Here are all the command names and what they do. + +{_commands} +""" + + +unitree_skills = UnitreeSkillContainer.blueprint + +__all__ = ["UnitreeSkillContainer", "unitree_skills"] diff --git a/dimos/robot/unitree_webrtc/unitree_skills.py b/dimos/robot/unitree_webrtc/unitree_skills.py new file mode 100644 index 0000000000..36403629ce --- /dev/null +++ b/dimos/robot/unitree_webrtc/unitree_skills.py @@ -0,0 +1,357 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import time +from typing import TYPE_CHECKING + +from pydantic import Field + +if TYPE_CHECKING: + from dimos.robot.robot import MockRobot, Robot # type: ignore[attr-defined] +else: + Robot = "Robot" + MockRobot = "MockRobot" + +from go2_webrtc_driver.constants import RTC_TOPIC # type: ignore[import-untyped] + +from dimos.msgs.geometry_msgs import Twist, Vector3 +from dimos.skills.skills import AbstractRobotSkill, AbstractSkill, SkillLibrary +from dimos.types.constants import Colors + +# Module-level constant for Unitree Go2 WebRTC control definitions +UNITREE_WEBRTC_CONTROLS: list[tuple[str, int, str]] = [ + ("Damp", 1001, "Lowers the robot to the ground fully."), + ( + "BalanceStand", + 1002, + "Activates a mode that maintains the robot in a balanced standing position.", + ), + ( + "StandUp", + 1004, + "Commands the robot to transition from a sitting or prone position to a standing posture.", + ), + ( + "StandDown", + 1005, + "Instructs the robot to move from a standing position to a sitting or prone posture.", + ), + ( + "RecoveryStand", + 1006, + "Recovers the robot to a state from which it can take more commands. Useful to run after multiple dynamic commands like front flips, Must run after skills like sit and jump and standup.", + ), + ("Sit", 1009, "Commands the robot to sit down from a standing or moving stance."), + ( + "RiseSit", + 1010, + "Commands the robot to rise back to a standing position from a sitting posture.", + ), + ( + "SwitchGait", + 1011, + "Switches the robot's walking pattern or style dynamically, suitable for different terrains or speeds.", + ), + ("Trigger", 1012, "Triggers a specific action or custom routine programmed into the robot."), + ( + "BodyHeight", + 1013, + "Adjusts the height of the robot's body from the ground, useful for navigating various obstacles.", + ), + ( + "FootRaiseHeight", + 1014, + "Controls how high the robot lifts its feet during movement, which can be adjusted for different surfaces.", + ), + ( + "SpeedLevel", + 1015, + "Sets or adjusts the speed at which the robot moves, with various levels available for different operational needs.", + ), + ( + "Hello", + 1016, + "Performs a greeting action, which could involve a wave or other friendly gesture.", + ), + ("Stretch", 1017, "Engages the robot in a stretching routine."), + ( + "TrajectoryFollow", + 1018, + "Directs the robot to follow a predefined trajectory, which could involve complex paths or maneuvers.", + ), + ( + "ContinuousGait", + 1019, + "Enables a mode for continuous walking or running, ideal for long-distance travel.", + ), + ("Content", 1020, "To display or trigger when the robot is happy."), + ("Wallow", 1021, "The robot falls onto its back and rolls around."), + ( + "Dance1", + 1022, + "Performs a predefined dance routine 1, programmed for entertainment or demonstration.", + ), + ("Dance2", 1023, "Performs another variant of a predefined dance routine 2."), + ("GetBodyHeight", 1024, "Retrieves the current height of the robot's body from the ground."), + ( + "GetFootRaiseHeight", + 1025, + "Retrieves the current height at which the robot's feet are being raised during movement.", + ), + ( + "GetSpeedLevel", + 1026, + "Retrieves the current speed level setting of the robot.", + ), + ( + "SwitchJoystick", + 1027, + "Switches the robot's control mode to respond to joystick input for manual operation.", + ), + ( + "Pose", + 1028, + "Commands the robot to assume a specific pose or posture as predefined in its programming.", + ), + ("Scrape", 1029, "The robot performs a scraping motion."), + ( + "FrontFlip", + 1030, + "Commands the robot to perform a front flip, showcasing its agility and dynamic movement capabilities.", + ), + ( + "FrontJump", + 1031, + "Instructs the robot to jump forward, demonstrating its explosive movement capabilities.", + ), + ( + "FrontPounce", + 1032, + "Commands the robot to perform a pouncing motion forward.", + ), + ( + "WiggleHips", + 1033, + "The robot performs a hip wiggling motion, often used for entertainment or demonstration purposes.", + ), + ( + "GetState", + 1034, + "Retrieves the current operational state of the robot, including its mode, position, and status.", + ), + ( + "EconomicGait", + 1035, + "Engages a more energy-efficient walking or running mode to conserve battery life.", + ), + ("FingerHeart", 1036, "Performs a finger heart gesture while on its hind legs."), + ( + "Handstand", + 1301, + "Commands the robot to perform a handstand, demonstrating balance and control.", + ), + ( + "CrossStep", + 1302, + "Commands the robot to perform cross-step movements.", + ), + ( + "OnesidedStep", + 1303, + "Commands the robot to perform one-sided step movements.", + ), + ("Bound", 1304, "Commands the robot to perform bounding movements."), + ("MoonWalk", 1305, "Commands the robot to perform a moonwalk motion."), + ("LeftFlip", 1042, "Executes a flip towards the left side."), + ("RightFlip", 1043, "Performs a flip towards the right side."), + ("Backflip", 1044, "Executes a backflip, a complex and dynamic maneuver."), +] + +# Module-level constants for Unitree G1 WebRTC control definitions +# G1 Arm Actions - all use api_id 7106 on topic "rt/api/arm/request" +G1_ARM_CONTROLS: list[tuple[str, int, str]] = [ + ("Handshake", 27, "Perform a handshake gesture with the right hand."), + ("HighFive", 18, "Give a high five with the right hand."), + ("Hug", 19, "Perform a hugging gesture with both arms."), + ("HighWave", 26, "Wave with the hand raised high."), + ("Clap", 17, "Clap hands together."), + ("FaceWave", 25, "Wave near the face level."), + ("LeftKiss", 12, "Blow a kiss with the left hand."), + ("ArmHeart", 20, "Make a heart shape with both arms overhead."), + ("RightHeart", 21, "Make a heart gesture with the right hand."), + ("HandsUp", 15, "Raise both hands up in the air."), + ("XRay", 24, "Hold arms in an X-ray pose position."), + ("RightHandUp", 23, "Raise only the right hand up."), + ("Reject", 22, "Make a rejection or 'no' gesture."), + ("CancelAction", 99, "Cancel any current arm action and return hands to neutral position."), +] + +# G1 Movement Modes - all use api_id 7101 on topic "rt/api/sport/request" +G1_MODE_CONTROLS: list[tuple[str, int, str]] = [ + ("WalkMode", 500, "Switch to normal walking mode."), + ("WalkControlWaist", 501, "Switch to walking mode with waist control."), + ("RunMode", 801, "Switch to running mode."), +] + +# region MyUnitreeSkills + + +class MyUnitreeSkills(SkillLibrary): + """My Unitree Skills for WebRTC interface.""" + + def __init__(self, robot: Robot | None = None, robot_type: str = "go2") -> None: + """Initialize Unitree skills library. + + Args: + robot: Optional robot instance + robot_type: Type of robot ("go2" or "g1"), defaults to "go2" + """ + super().__init__() + self._robot: Robot = None # type: ignore[assignment] + self.robot_type = robot_type.lower() + + if self.robot_type not in ["go2", "g1"]: + raise ValueError(f"Unsupported robot type: {robot_type}. Must be 'go2' or 'g1'") + + # Add dynamic skills to this class based on robot type + dynamic_skills = self.create_skills_live() + self.register_skills(dynamic_skills) # type: ignore[arg-type] + + @classmethod + def register_skills(cls, skill_classes: AbstractSkill | list[AbstractSkill]) -> None: + """Add multiple skill classes as class attributes. + + Args: + skill_classes: List of skill classes to add + """ + if not isinstance(skill_classes, list): + skill_classes = [skill_classes] + + for skill_class in skill_classes: + # Add to the class as a skill + setattr(cls, skill_class.__name__, skill_class) # type: ignore[attr-defined] + + def initialize_skills(self) -> None: + for skill_class in self.get_class_skills(): + self.create_instance(skill_class.__name__, robot=self._robot) # type: ignore[attr-defined] + + # Refresh the class skills + self.refresh_class_skills() + + def create_skills_live(self) -> list[AbstractRobotSkill]: + # ================================================ + # Procedurally created skills + # ================================================ + class BaseUnitreeSkill(AbstractRobotSkill): + """Base skill for dynamic skill creation.""" + + def __call__(self) -> str: + super().__call__() # type: ignore[no-untyped-call] + + # For Go2: Simple api_id based call + if hasattr(self, "_app_id"): + string = f"{Colors.GREEN_PRINT_COLOR}Executing Go2 skill: {self.__class__.__name__} with api_id={self._app_id}{Colors.RESET_COLOR}" + print(string) + self._robot.connection.publish_request( # type: ignore[attr-defined] + RTC_TOPIC["SPORT_MOD"], {"api_id": self._app_id} + ) + return f"{self.__class__.__name__} executed successfully" + + # For G1: Fixed api_id with parameter data + elif hasattr(self, "_data_value"): + string = f"{Colors.GREEN_PRINT_COLOR}Executing G1 skill: {self.__class__.__name__} with data={self._data_value}{Colors.RESET_COLOR}" + print(string) + self._robot.connection.publish_request( # type: ignore[attr-defined] + self._topic, # type: ignore[attr-defined] + {"api_id": self._api_id, "parameter": {"data": self._data_value}}, # type: ignore[attr-defined] + ) + return f"{self.__class__.__name__} executed successfully" + else: + raise RuntimeError( + f"Skill {self.__class__.__name__} missing required attributes" + ) + + skills_classes = [] + + if self.robot_type == "g1": + # Create G1 arm skills + for name, data_value, description in G1_ARM_CONTROLS: + skill_class = type( + name, + (BaseUnitreeSkill,), + { + "__doc__": description, + "_topic": "rt/api/arm/request", + "_api_id": 7106, + "_data_value": data_value, + }, + ) + skills_classes.append(skill_class) + + # Create G1 mode skills + for name, data_value, description in G1_MODE_CONTROLS: + skill_class = type( + name, + (BaseUnitreeSkill,), + { + "__doc__": description, + "_topic": "rt/api/sport/request", + "_api_id": 7101, + "_data_value": data_value, + }, + ) + skills_classes.append(skill_class) + else: + # Go2 skills (existing code) + for name, app_id, description in UNITREE_WEBRTC_CONTROLS: + if name not in ["Reverse", "Spin"]: # Exclude reverse and spin skills + skill_class = type( + name, (BaseUnitreeSkill,), {"__doc__": description, "_app_id": app_id} + ) + skills_classes.append(skill_class) + + return skills_classes # type: ignore[return-value] + + # region Class-based Skills + + class Move(AbstractRobotSkill): + """Move the robot using direct velocity commands. Determine duration required based on user distance instructions.""" + + x: float = Field(..., description="Forward velocity (m/s).") + y: float = Field(default=0.0, description="Left/right velocity (m/s)") + yaw: float = Field(default=0.0, description="Rotational velocity (rad/s)") + duration: float = Field(default=0.0, description="How long to move (seconds).") + + def __call__(self) -> str: + self._robot.move( # type: ignore[attr-defined] + Twist(linear=Vector3(self.x, self.y, 0.0), angular=Vector3(0.0, 0.0, self.yaw)), + duration=self.duration, + ) + return f"started moving with velocity={self.x}, {self.y}, {self.yaw} for {self.duration} seconds" + + class Wait(AbstractSkill): + """Wait for a specified amount of time.""" + + seconds: float = Field(..., description="Seconds to wait") + + def __call__(self) -> str: + time.sleep(self.seconds) + return f"Wait completed with length={self.seconds}s" + + # endregion + + +# endregion diff --git a/dimos/robot/utils/README.md b/dimos/robot/utils/README.md new file mode 100644 index 0000000000..5a84b20c4a --- /dev/null +++ b/dimos/robot/utils/README.md @@ -0,0 +1,38 @@ +# Robot Utils + +## RobotDebugger + +The `RobotDebugger` provides a way to debug a running robot through the python shell. + +Requirements: + +```bash +pip install rpyc +``` + +### Usage + +1. **Add to your robot application:** + ```python + from dimos.robot.utils.robot_debugger import RobotDebugger + + # In your robot application's context manager or main loop: + with RobotDebugger(robot): + # Your robot code here + pass + + # Or better, with an exit stack. + exit_stack.enter_context(RobotDebugger(robot)) + ``` + +2. **Start your robot with debugging enabled:** + ```bash + ROBOT_DEBUGGER=true python your_robot_script.py + ``` + +3. **Open the python shell:** + ```bash + ./bin/robot-debugger + >>> robot.explore() + True + ``` diff --git a/dimos/robot/utils/robot_debugger.py b/dimos/robot/utils/robot_debugger.py new file mode 100644 index 0000000000..c7f3cd7291 --- /dev/null +++ b/dimos/robot/utils/robot_debugger.py @@ -0,0 +1,59 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from dimos.core.resource import Resource +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class RobotDebugger(Resource): + def __init__(self, robot) -> None: # type: ignore[no-untyped-def] + self._robot = robot + self._threaded_server = None + + def start(self) -> None: + if not os.getenv("ROBOT_DEBUGGER"): + return + + try: + import rpyc # type: ignore[import-not-found] + from rpyc.utils.server import ThreadedServer # type: ignore[import-not-found] + except ImportError: + return + + logger.info( + "Starting the robot debugger. You can open a python shell with `./bin/robot-debugger`" + ) + + robot = self._robot + + class RobotService(rpyc.Service): # type: ignore[misc] + def exposed_robot(self): # type: ignore[no-untyped-def] + return robot + + self._threaded_server = ThreadedServer( + RobotService, + port=18861, + protocol_config={ + "allow_all_attrs": True, + }, + ) + self._threaded_server.start() # type: ignore[attr-defined] + + def stop(self) -> None: + if self._threaded_server: + self._threaded_server.close() diff --git a/dimos/simulation/README.md b/dimos/simulation/README.md new file mode 100644 index 0000000000..95d8b4cda1 --- /dev/null +++ b/dimos/simulation/README.md @@ -0,0 +1,98 @@ +# Dimensional Streaming Setup + +This guide explains how to set up and run the Isaac Sim and Genesis streaming functionality via Docker. The setup is tested on Ubuntu 22.04 (recommended). + +## Prerequisites + +1. **NVIDIA Driver** + - NVIDIA Driver 535 must be installed + - Check your driver: `nvidia-smi` + - If not installed: + ```bash + sudo apt-get update + sudo apt install build-essential -y + sudo apt-get install -y nvidia-driver-535 + sudo reboot + ``` + +2. **CUDA Toolkit** + ```bash + sudo apt install -y nvidia-cuda-toolkit + ``` + +3. **Docker** + ```bash + # Install Docker + curl -fsSL https://get.docker.com -o get-docker.sh + sudo sh get-docker.sh + + # Post-install steps + sudo groupadd docker + sudo usermod -aG docker $USER + newgrp docker + ``` + +4. **NVIDIA Container Toolkit** + ```bash + # Add NVIDIA Container Toolkit repository + curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg + curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list | \ + sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' | \ + sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list + sudo apt-get update + + # Install the toolkit + sudo apt-get install -y nvidia-container-toolkit + sudo systemctl restart docker + + # Configure runtime + sudo nvidia-ctk runtime configure --runtime=docker + sudo systemctl restart docker + + # Verify installation + sudo docker run --rm --runtime=nvidia --gpus all ubuntu nvidia-smi + ``` + +5. **Pull Isaac Sim Image** + ```bash + sudo docker pull nvcr.io/nvidia/isaac-sim:4.2.0 + ``` + +6. **TO DO: Add ROS2 websocket server for client-side streaming** + +## Running the Streaming Example + +1. **Navigate to the docker/simulation directory** + ```bash + cd docker/simulation + ``` + +2. **Build and run with docker-compose** + For Isaac Sim: + ```bash + docker compose -f isaac/docker-compose.yml build + docker compose -f isaac/docker-compose.yml up + + ``` + + For Genesis: + ```bash + docker compose -f genesis/docker-compose.yml build + docker compose -f genesis/docker-compose.yml up + + ``` + +This will: +- Build the dimos_simulator image with ROS2 and required dependencies +- Start the MediaMTX RTSP server +- Run the test streaming example from either: + - `/tests/isaacsim/stream_camera.py` for Isaac Sim + - `/tests/genesissim/stream_camera.py` for Genesis + +## Viewing the Stream + +The camera stream will be available at: + +- RTSP: `rtsp://localhost:8554/stream` or `rtsp://:8554/stream` + +You can view it using VLC or any RTSP-capable player. diff --git a/dimos/simulation/__init__.py b/dimos/simulation/__init__.py new file mode 100644 index 0000000000..2b77f47097 --- /dev/null +++ b/dimos/simulation/__init__.py @@ -0,0 +1,15 @@ +# Try to import Isaac Sim components +try: + from .isaac import IsaacSimulator, IsaacStream +except ImportError: + IsaacSimulator = None # type: ignore + IsaacStream = None # type: ignore + +# Try to import Genesis components +try: + from .genesis import GenesisSimulator, GenesisStream +except ImportError: + GenesisSimulator = None # type: ignore + GenesisStream = None # type: ignore + +__all__ = ["GenesisSimulator", "GenesisStream", "IsaacSimulator", "IsaacStream"] diff --git a/dimos/simulation/base/__init__.py b/dimos/simulation/base/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dimos/simulation/base/simulator_base.py b/dimos/simulation/base/simulator_base.py new file mode 100644 index 0000000000..59e366a1d3 --- /dev/null +++ b/dimos/simulation/base/simulator_base.py @@ -0,0 +1,47 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod + + +class SimulatorBase(ABC): + """Base class for simulators.""" + + @abstractmethod + def __init__( + self, + headless: bool = True, + open_usd: str | None = None, # Keep for Isaac compatibility + entities: list[dict[str, str | dict]] | None = None, # type: ignore[type-arg] # Add for Genesis + ) -> None: + """Initialize the simulator. + + Args: + headless: Whether to run without visualization + open_usd: Path to USD file (for Isaac) + entities: List of entity configurations (for Genesis) + """ + self.headless = headless + self.open_usd = open_usd + self.stage = None + + @abstractmethod + def get_stage(self): # type: ignore[no-untyped-def] + """Get the current stage/scene.""" + pass + + @abstractmethod + def close(self): # type: ignore[no-untyped-def] + """Close the simulation.""" + pass diff --git a/dimos/simulation/base/stream_base.py b/dimos/simulation/base/stream_base.py new file mode 100644 index 0000000000..9f8898439e --- /dev/null +++ b/dimos/simulation/base/stream_base.py @@ -0,0 +1,116 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from pathlib import Path +import subprocess +from typing import Literal + +AnnotatorType = Literal["rgb", "normals", "bounding_box_3d", "motion_vectors"] +TransportType = Literal["tcp", "udp"] + + +class StreamBase(ABC): + """Base class for simulation streaming.""" + + @abstractmethod + def __init__( # type: ignore[no-untyped-def] + self, + simulator, + width: int = 1920, + height: int = 1080, + fps: int = 60, + camera_path: str = "/World/camera", + annotator_type: AnnotatorType = "rgb", + transport: TransportType = "tcp", + rtsp_url: str = "rtsp://mediamtx:8554/stream", + usd_path: str | Path | None = None, + ) -> None: + """Initialize the stream. + + Args: + simulator: Simulator instance + width: Stream width in pixels + height: Stream height in pixels + fps: Frames per second + camera_path: Camera path in scene + annotator: Type of annotator to use + transport: Transport protocol + rtsp_url: RTSP stream URL + usd_path: Optional USD file path to load + """ + self.simulator = simulator + self.width = width + self.height = height + self.fps = fps + self.camera_path = camera_path + self.annotator_type = annotator_type + self.transport = transport + self.rtsp_url = rtsp_url + self.proc = None + + @abstractmethod + def _load_stage(self, usd_path: str | Path): # type: ignore[no-untyped-def] + """Load stage from file.""" + pass + + @abstractmethod + def _setup_camera(self): # type: ignore[no-untyped-def] + """Setup and validate camera.""" + pass + + def _setup_ffmpeg(self) -> None: + """Setup FFmpeg process for streaming.""" + command = [ + "ffmpeg", + "-y", + "-f", + "rawvideo", + "-vcodec", + "rawvideo", + "-pix_fmt", + "bgr24", + "-s", + f"{self.width}x{self.height}", + "-r", + str(self.fps), + "-i", + "-", + "-an", + "-c:v", + "h264_nvenc", + "-preset", + "fast", + "-f", + "rtsp", + "-rtsp_transport", + self.transport, + self.rtsp_url, + ] + self.proc = subprocess.Popen(command, stdin=subprocess.PIPE) # type: ignore[assignment] + + @abstractmethod + def _setup_annotator(self): # type: ignore[no-untyped-def] + """Setup annotator.""" + pass + + @abstractmethod + def stream(self): # type: ignore[no-untyped-def] + """Start streaming.""" + pass + + @abstractmethod + def cleanup(self): # type: ignore[no-untyped-def] + """Cleanup resources.""" + pass diff --git a/dimos/simulation/genesis/__init__.py b/dimos/simulation/genesis/__init__.py new file mode 100644 index 0000000000..5657d9167b --- /dev/null +++ b/dimos/simulation/genesis/__init__.py @@ -0,0 +1,4 @@ +from .simulator import GenesisSimulator +from .stream import GenesisStream + +__all__ = ["GenesisSimulator", "GenesisStream"] diff --git a/dimos/simulation/genesis/simulator.py b/dimos/simulation/genesis/simulator.py new file mode 100644 index 0000000000..3d045d0e24 --- /dev/null +++ b/dimos/simulation/genesis/simulator.py @@ -0,0 +1,159 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import genesis as gs # type: ignore + +from ..base.simulator_base import SimulatorBase + + +class GenesisSimulator(SimulatorBase): + """Genesis simulator implementation.""" + + def __init__( + self, + headless: bool = True, + open_usd: str | None = None, # Keep for compatibility + entities: list[dict[str, str | dict]] | None = None, # type: ignore[type-arg] + ) -> None: + """Initialize the Genesis simulation. + + Args: + headless: Whether to run without visualization + open_usd: Path to USD file (for Isaac) + entities: List of entity configurations to load. Each entity is a dict with: + - type: str ('mesh', 'urdf', 'mjcf', 'primitive') + - path: str (file path for mesh/urdf/mjcf) + - params: dict (parameters for primitives or loading options) + """ + super().__init__(headless, open_usd, entities) + + # Initialize Genesis + gs.init() + + # Create scene with viewer options + self.scene = gs.Scene( + show_viewer=not headless, + viewer_options=gs.options.ViewerOptions( + res=(1280, 960), + camera_pos=(3.5, 0.0, 2.5), + camera_lookat=(0.0, 0.0, 0.5), + camera_fov=40, + max_FPS=60, + ), + vis_options=gs.options.VisOptions( + show_world_frame=True, + world_frame_size=1.0, + show_link_frame=False, + show_cameras=False, + plane_reflection=True, + ambient_light=(0.1, 0.1, 0.1), + ), + renderer=gs.renderers.Rasterizer(), + ) + + # Handle USD parameter for compatibility + if open_usd: + print(f"[Warning] USD files not supported in Genesis. Ignoring: {open_usd}") + + # Load entities if provided + if entities: + self._load_entities(entities) + + # Don't build scene yet - let stream add camera first + self.is_built = False + + def _load_entities(self, entities: list[dict[str, str | dict]]): # type: ignore[no-untyped-def, type-arg] + """Load multiple entities into the scene.""" + for entity in entities: + entity_type = entity.get("type", "").lower() # type: ignore[union-attr] + path = entity.get("path", "") + params = entity.get("params", {}) + + try: + if entity_type == "mesh": + mesh = gs.morphs.Mesh( + file=path, # Explicit file argument + **params, + ) + self.scene.add_entity(mesh) + print(f"[Genesis] Added mesh from {path}") + + elif entity_type == "urdf": + robot = gs.morphs.URDF( + file=path, # Explicit file argument + **params, + ) + self.scene.add_entity(robot) + print(f"[Genesis] Added URDF robot from {path}") + + elif entity_type == "mjcf": + mujoco = gs.morphs.MJCF( + file=path, # Explicit file argument + **params, + ) + self.scene.add_entity(mujoco) + print(f"[Genesis] Added MJCF model from {path}") + + elif entity_type == "primitive": + shape_type = params.pop("shape", "plane") # type: ignore[union-attr] + if shape_type == "plane": + morph = gs.morphs.Plane(**params) + elif shape_type == "box": + morph = gs.morphs.Box(**params) + elif shape_type == "sphere": + morph = gs.morphs.Sphere(**params) + else: + raise ValueError(f"Unsupported primitive shape: {shape_type}") + + # Add position if not specified + if "pos" not in params: + if shape_type == "plane": + morph.pos = [0, 0, 0] + else: + morph.pos = [0, 0, 1] # Lift objects above ground + + self.scene.add_entity(morph) + print(f"[Genesis] Added {shape_type} at position {morph.pos}") + + else: + raise ValueError(f"Unsupported entity type: {entity_type}") + + except Exception as e: + print(f"[Warning] Failed to load entity {entity}: {e!s}") + + def add_entity(self, entity_type: str, path: str = "", **params) -> None: # type: ignore[no-untyped-def] + """Add a single entity to the scene. + + Args: + entity_type: Type of entity ('mesh', 'urdf', 'mjcf', 'primitive') + path: File path for mesh/urdf/mjcf entities + **params: Additional parameters for entity creation + """ + self._load_entities([{"type": entity_type, "path": path, "params": params}]) + + def get_stage(self): # type: ignore[no-untyped-def] + """Get the current stage/scene.""" + return self.scene + + def build(self) -> None: + """Build the scene if not already built.""" + if not self.is_built: + self.scene.build() + self.is_built = True + + def close(self) -> None: + """Close the simulation.""" + # Genesis handles cleanup automatically + pass diff --git a/dimos/simulation/genesis/stream.py b/dimos/simulation/genesis/stream.py new file mode 100644 index 0000000000..0d3bcc6832 --- /dev/null +++ b/dimos/simulation/genesis/stream.py @@ -0,0 +1,144 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path +import time + +import cv2 +import numpy as np + +from ..base.stream_base import AnnotatorType, StreamBase, TransportType + + +class GenesisStream(StreamBase): + """Genesis stream implementation.""" + + def __init__( # type: ignore[no-untyped-def] + self, + simulator, + width: int = 1920, + height: int = 1080, + fps: int = 60, + camera_path: str = "/camera", + annotator_type: AnnotatorType = "rgb", + transport: TransportType = "tcp", + rtsp_url: str = "rtsp://mediamtx:8554/stream", + usd_path: str | Path | None = None, + ) -> None: + """Initialize the Genesis stream.""" + super().__init__( + simulator=simulator, + width=width, + height=height, + fps=fps, + camera_path=camera_path, + annotator_type=annotator_type, + transport=transport, + rtsp_url=rtsp_url, + usd_path=usd_path, + ) + + self.scene = simulator.get_stage() + + # Initialize components + if usd_path: + self._load_stage(usd_path) + self._setup_camera() + self._setup_ffmpeg() + self._setup_annotator() + + # Build scene after camera is set up + simulator.build() + + def _load_stage(self, usd_path: str | Path) -> None: + """Load stage from file.""" + # Genesis handles stage loading through simulator + pass + + def _setup_camera(self) -> None: + """Setup and validate camera.""" + self.camera = self.scene.add_camera( + res=(self.width, self.height), + pos=(3.5, 0.0, 2.5), + lookat=(0, 0, 0.5), + fov=30, + GUI=False, + ) + + def _setup_annotator(self) -> None: + """Setup the specified annotator.""" + # Genesis handles different render types through camera.render() + pass + + def stream(self) -> None: + """Start the streaming loop.""" + try: + print("[Stream] Starting Genesis camera stream...") + frame_count = 0 + start_time = time.time() + + while True: + frame_start = time.time() + + # Step simulation and get frame + step_start = time.time() + self.scene.step() + step_time = time.time() - step_start + print(f"[Stream] Simulation step took {step_time * 1000:.2f}ms") + + # Get frame based on annotator type + if self.annotator_type == "rgb": + frame, _, _, _ = self.camera.render(rgb=True) + elif self.annotator_type == "normals": + _, _, _, frame = self.camera.render(normal=True) + else: + frame, _, _, _ = self.camera.render(rgb=True) # Default to RGB + + # Convert frame format if needed + if isinstance(frame, np.ndarray): + frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) + + # Write to FFmpeg + self.proc.stdin.write(frame.tobytes()) # type: ignore[attr-defined] + self.proc.stdin.flush() # type: ignore[attr-defined] + + # Log metrics + frame_time = time.time() - frame_start + print(f"[Stream] Total frame processing took {frame_time * 1000:.2f}ms") + frame_count += 1 + + if frame_count % 100 == 0: + elapsed_time = time.time() - start_time + current_fps = frame_count / elapsed_time + print( + f"[Stream] Processed {frame_count} frames | Current FPS: {current_fps:.2f}" + ) + + except KeyboardInterrupt: + print("\n[Stream] Received keyboard interrupt, stopping stream...") + finally: + self.cleanup() + + def cleanup(self) -> None: + """Cleanup resources.""" + print("[Cleanup] Stopping FFmpeg process...") + if hasattr(self, "proc"): + self.proc.stdin.close() # type: ignore[attr-defined] + self.proc.wait() # type: ignore[attr-defined] + print("[Cleanup] Closing simulation...") + try: + self.simulator.close() + except AttributeError: + print("[Cleanup] Warning: Could not close simulator properly") + print("[Cleanup] Successfully cleaned up resources") diff --git a/dimos/simulation/isaac/__init__.py b/dimos/simulation/isaac/__init__.py new file mode 100644 index 0000000000..2b9bdc082d --- /dev/null +++ b/dimos/simulation/isaac/__init__.py @@ -0,0 +1,4 @@ +from .simulator import IsaacSimulator +from .stream import IsaacStream + +__all__ = ["IsaacSimulator", "IsaacStream"] diff --git a/dimos/simulation/isaac/simulator.py b/dimos/simulation/isaac/simulator.py new file mode 100644 index 0000000000..1b524e1cb5 --- /dev/null +++ b/dimos/simulation/isaac/simulator.py @@ -0,0 +1,44 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from isaacsim import SimulationApp # type: ignore[import-not-found] + +from ..base.simulator_base import SimulatorBase + + +class IsaacSimulator(SimulatorBase): + """Isaac Sim simulator implementation.""" + + def __init__( + self, + headless: bool = True, + open_usd: str | None = None, + entities: list[dict[str, str | dict]] | None = None, # type: ignore[type-arg] # Add but ignore + ) -> None: + """Initialize the Isaac Sim simulation.""" + super().__init__(headless, open_usd) + self.app = SimulationApp({"headless": headless, "open_usd": open_usd}) + + def get_stage(self): # type: ignore[no-untyped-def] + """Get the current USD stage.""" + import omni.usd # type: ignore[import-not-found] + + self.stage = omni.usd.get_context().get_stage() + return self.stage + + def close(self) -> None: + """Close the simulation.""" + if hasattr(self, "app"): + self.app.close() diff --git a/dimos/simulation/isaac/stream.py b/dimos/simulation/isaac/stream.py new file mode 100644 index 0000000000..e927c4bad4 --- /dev/null +++ b/dimos/simulation/isaac/stream.py @@ -0,0 +1,137 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path +import time + +import cv2 + +from ..base.stream_base import AnnotatorType, StreamBase, TransportType + + +class IsaacStream(StreamBase): + """Isaac Sim stream implementation.""" + + def __init__( # type: ignore[no-untyped-def] + self, + simulator, + width: int = 1920, + height: int = 1080, + fps: int = 60, + camera_path: str = "/World/alfred_parent_prim/alfred_base_descr/chest_cam_rgb_camera_frame/chest_cam", + annotator_type: AnnotatorType = "rgb", + transport: TransportType = "tcp", + rtsp_url: str = "rtsp://mediamtx:8554/stream", + usd_path: str | Path | None = None, + ) -> None: + """Initialize the Isaac Sim stream.""" + super().__init__( + simulator=simulator, + width=width, + height=height, + fps=fps, + camera_path=camera_path, + annotator_type=annotator_type, + transport=transport, + rtsp_url=rtsp_url, + usd_path=usd_path, + ) + + # Import omni.replicator after SimulationApp initialization + import omni.replicator.core as rep # type: ignore[import-not-found] + + self.rep = rep + + # Initialize components + if usd_path: + self._load_stage(usd_path) + self._setup_camera() # type: ignore[no-untyped-call] + self._setup_ffmpeg() + self._setup_annotator() + + def _load_stage(self, usd_path: str | Path): # type: ignore[no-untyped-def] + """Load USD stage from file.""" + import omni.usd # type: ignore[import-not-found] + + abs_path = str(Path(usd_path).resolve()) + omni.usd.get_context().open_stage(abs_path) + self.stage = self.simulator.get_stage() + if not self.stage: + raise RuntimeError(f"Failed to load stage: {abs_path}") + + def _setup_camera(self): # type: ignore[no-untyped-def] + """Setup and validate camera.""" + self.stage = self.simulator.get_stage() + camera_prim = self.stage.GetPrimAtPath(self.camera_path) + if not camera_prim: + raise RuntimeError(f"Failed to find camera at path: {self.camera_path}") + + self.render_product = self.rep.create.render_product( + self.camera_path, resolution=(self.width, self.height) + ) + + def _setup_annotator(self) -> None: + """Setup the specified annotator.""" + self.annotator = self.rep.AnnotatorRegistry.get_annotator(self.annotator_type) + self.annotator.attach(self.render_product) + + def stream(self) -> None: + """Start the streaming loop.""" + try: + print("[Stream] Starting camera stream loop...") + frame_count = 0 + start_time = time.time() + + while True: + frame_start = time.time() + + # Step simulation and get frame + step_start = time.time() + self.rep.orchestrator.step() + step_time = time.time() - step_start + print(f"[Stream] Simulation step took {step_time * 1000:.2f}ms") + + frame = self.annotator.get_data() + frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2BGR) + + # Write to FFmpeg + self.proc.stdin.write(frame.tobytes()) # type: ignore[attr-defined] + self.proc.stdin.flush() # type: ignore[attr-defined] + + # Log metrics + frame_time = time.time() - frame_start + print(f"[Stream] Total frame processing took {frame_time * 1000:.2f}ms") + frame_count += 1 + + if frame_count % 100 == 0: + elapsed_time = time.time() - start_time + current_fps = frame_count / elapsed_time + print( + f"[Stream] Processed {frame_count} frames | Current FPS: {current_fps:.2f}" + ) + + except KeyboardInterrupt: + print("\n[Stream] Received keyboard interrupt, stopping stream...") + finally: + self.cleanup() + + def cleanup(self) -> None: + """Cleanup resources.""" + print("[Cleanup] Stopping FFmpeg process...") + if hasattr(self, "proc"): + self.proc.stdin.close() # type: ignore[attr-defined] + self.proc.wait() # type: ignore[attr-defined] + print("[Cleanup] Closing simulation...") + self.simulator.close() + print("[Cleanup] Successfully cleaned up resources") diff --git a/dimos/simulation/mujoco/constants.py b/dimos/simulation/mujoco/constants.py new file mode 100644 index 0000000000..22bd409ac8 --- /dev/null +++ b/dimos/simulation/mujoco/constants.py @@ -0,0 +1,35 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + +# Video/Camera constants +VIDEO_WIDTH = 320 +VIDEO_HEIGHT = 240 +DEPTH_CAMERA_FOV = 160 + +# Depth camera range/filtering constants +MAX_RANGE = 3 +MIN_RANGE = 0.2 +MAX_HEIGHT = 1.2 + +# Lidar constants +LIDAR_RESOLUTION = 0.05 + +# Simulation timing constants +STEPS_PER_FRAME = 7 +VIDEO_FPS = 20 +LIDAR_FPS = 2 + +LAUNCHER_PATH = Path(__file__).parent / "mujoco_process.py" diff --git a/dimos/simulation/mujoco/depth_camera.py b/dimos/simulation/mujoco/depth_camera.py new file mode 100644 index 0000000000..486b740ffd --- /dev/null +++ b/dimos/simulation/mujoco/depth_camera.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python3 + +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Any + +import numpy as np +from numpy.typing import NDArray +import open3d as o3d # type: ignore[import-untyped] + +from dimos.simulation.mujoco.constants import MAX_HEIGHT, MAX_RANGE, MIN_RANGE + + +def depth_image_to_point_cloud( + depth_image: NDArray[Any], + camera_pos: NDArray[Any], + camera_mat: NDArray[Any], + fov_degrees: float = 120, +) -> NDArray[Any]: + """ + Convert a depth image from a camera to a 3D point cloud using perspective projection. + + Args: + depth_image: 2D numpy array of depth values in meters + camera_pos: 3D position of camera in world coordinates + camera_mat: 3x3 camera rotation matrix in world coordinates + fov_degrees: Vertical field of view of the camera in degrees + min_range: Minimum distance from camera to include points (meters) + + Returns: + numpy array of 3D points in world coordinates, shape (N, 3) + """ + height, width = depth_image.shape + + # Calculate camera intrinsics similar to StackOverflow approach + fovy = math.radians(fov_degrees) + f = height / (2 * math.tan(fovy / 2)) # focal length in pixels + cx = width / 2 # principal point x + cy = height / 2 # principal point y + + # Create Open3D camera intrinsics + cam_intrinsics = o3d.camera.PinholeCameraIntrinsic(width, height, f, f, cx, cy) + + # Convert numpy depth array to Open3D Image + o3d_depth = o3d.geometry.Image(depth_image.astype(np.float32)) + + # Create point cloud from depth image using Open3D + o3d_cloud = o3d.geometry.PointCloud.create_from_depth_image(o3d_depth, cam_intrinsics) + + # Convert Open3D point cloud to numpy array + camera_points: NDArray[Any] = np.asarray(o3d_cloud.points) + + if camera_points.size == 0: + return np.array([]).reshape(0, 3) + + # Flip y and z axes + camera_points[:, 1] = -camera_points[:, 1] + camera_points[:, 2] = -camera_points[:, 2] + + # y (index 1) is up here + valid_mask = ( + (np.abs(camera_points[:, 0]) <= MAX_RANGE) + & (np.abs(camera_points[:, 1]) <= MAX_HEIGHT) + & (np.abs(camera_points[:, 2]) >= MIN_RANGE) + & (np.abs(camera_points[:, 2]) <= MAX_RANGE) + ) + camera_points = camera_points[valid_mask] + + if camera_points.size == 0: + return np.array([]).reshape(0, 3) + + # Transform to world coordinates + world_points: NDArray[Any] = (camera_mat @ camera_points.T).T + camera_pos + + return world_points diff --git a/dimos/simulation/mujoco/input_controller.py b/dimos/simulation/mujoco/input_controller.py new file mode 100644 index 0000000000..e12e8a71c7 --- /dev/null +++ b/dimos/simulation/mujoco/input_controller.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 + +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Any, Protocol + +from numpy.typing import NDArray + + +class InputController(Protocol): + """A protocol for input devices to control the robot.""" + + def get_command(self) -> NDArray[Any]: ... + def stop(self) -> None: ... diff --git a/dimos/simulation/mujoco/model.py b/dimos/simulation/mujoco/model.py new file mode 100644 index 0000000000..36cd3898ee --- /dev/null +++ b/dimos/simulation/mujoco/model.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python3 + +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import xml.etree.ElementTree as ET + +from etils import epath # type: ignore[import-untyped] +import mujoco # type: ignore[import-untyped] +from mujoco_playground._src import mjx_env # type: ignore[import-untyped] +import numpy as np + +from dimos.simulation.mujoco.input_controller import InputController +from dimos.simulation.mujoco.policy import G1OnnxController, Go1OnnxController, OnnxController + +DATA_DIR = epath.Path(__file__).parent / "../../../data/mujoco_sim" + + +def get_assets() -> dict[str, bytes]: + # Assets used from https://sketchfab.com/3d-models/mersus-office-8714be387bcd406898b2615f7dae3a47 + # Created by Ryan Cassidy and Coleman Costello + assets: dict[str, bytes] = {} + mjx_env.update_assets(assets, DATA_DIR, "*.xml") + mjx_env.update_assets(assets, DATA_DIR / "scene_office1/textures", "*.png") + mjx_env.update_assets(assets, DATA_DIR / "scene_office1/office_split", "*.obj") + mjx_env.update_assets(assets, mjx_env.MENAGERIE_PATH / "unitree_go1" / "assets") + mjx_env.update_assets(assets, mjx_env.MENAGERIE_PATH / "unitree_g1" / "assets") + return assets + + +def load_model( + input_device: InputController, robot: str, scene: str +) -> tuple[mujoco.MjModel, mujoco.MjData]: + mujoco.set_mjcb_control(None) + + xml_string = get_model_xml(robot, scene) + model = mujoco.MjModel.from_xml_string(xml_string, assets=get_assets()) + data = mujoco.MjData(model) + + mujoco.mj_resetDataKeyframe(model, data, 0) + + match robot: + case "unitree_g1": + sim_dt = 0.002 + case _: + sim_dt = 0.005 + + ctrl_dt = 0.02 + n_substeps = round(ctrl_dt / sim_dt) + model.opt.timestep = sim_dt + + params = { + "policy_path": (DATA_DIR / f"{robot}_policy.onnx").as_posix(), + "default_angles": np.array(model.keyframe("home").qpos[7:]), + "n_substeps": n_substeps, + "action_scale": 0.5, + "input_controller": input_device, + "ctrl_dt": ctrl_dt, + } + + match robot: + case "unitree_go1": + policy: OnnxController = Go1OnnxController(**params) + case "unitree_g1": + policy = G1OnnxController(**params, drift_compensation=[-0.18, 0.0, -0.09]) + case _: + raise ValueError(f"Unknown robot policy: {robot}") + + mujoco.set_mjcb_control(policy.get_control) + + return model, data + + +def get_model_xml(robot: str, scene: str) -> str: + xml_file = (DATA_DIR / f"scene_{scene}.xml").as_posix() + + tree = ET.parse(xml_file) + root = tree.getroot() + root.set("model", f"{robot}_{scene}") + root.insert(0, ET.Element("include", file=f"{robot}.xml")) + return ET.tostring(root, encoding="unicode") diff --git a/dimos/simulation/mujoco/mujoco_process.py b/dimos/simulation/mujoco/mujoco_process.py new file mode 100755 index 0000000000..b75ce74cfd --- /dev/null +++ b/dimos/simulation/mujoco/mujoco_process.py @@ -0,0 +1,237 @@ +#!/usr/bin/env python3 + +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +import json +import pickle +import signal +import sys +import time +from typing import Any + +import mujoco # type: ignore[import-untyped] +from mujoco import viewer +import numpy as np +from numpy.typing import NDArray +import open3d as o3d # type: ignore[import-untyped] + +from dimos.core.global_config import GlobalConfig +from dimos.msgs.geometry_msgs import Vector3 +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.simulation.mujoco.constants import ( + DEPTH_CAMERA_FOV, + LIDAR_FPS, + LIDAR_RESOLUTION, + STEPS_PER_FRAME, + VIDEO_FPS, + VIDEO_HEIGHT, + VIDEO_WIDTH, +) +from dimos.simulation.mujoco.depth_camera import depth_image_to_point_cloud +from dimos.simulation.mujoco.model import load_model +from dimos.simulation.mujoco.shared_memory import ShmReader +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class MockController: + """Controller that reads commands from shared memory.""" + + def __init__(self, shm_interface: ShmReader) -> None: + self.shm = shm_interface + self._command = np.zeros(3, dtype=np.float32) + + def get_command(self) -> NDArray[Any]: + """Get the current movement command.""" + cmd_data = self.shm.read_command() + if cmd_data is not None: + linear, angular = cmd_data + # MuJoCo expects [forward, lateral, rotational] + self._command[0] = linear[0] # forward/backward + self._command[1] = linear[1] # left/right + self._command[2] = angular[2] # rotation + return self._command.copy() + + def stop(self) -> None: + """Stop method to satisfy InputController protocol.""" + pass + + +def _run_simulation(config: GlobalConfig, shm: ShmReader) -> None: + robot_name = config.robot_model or "unitree_go1" + if robot_name == "unitree_go2": + robot_name = "unitree_go1" + + mujoco_room = getattr(config, "mujoco_room", "office1") + if mujoco_room is None: + mujoco_room = "office1" + + controller = MockController(shm) + model, data = load_model(controller, robot=robot_name, scene=mujoco_room) + + if model is None or data is None: + raise ValueError("Failed to load MuJoCo model: model or data is None") + + match robot_name: + case "unitree_go1": + z = 0.3 + case "unitree_g1": + z = 0.8 + case _: + z = 0 + + data.qpos[0:3] = [-1, 1, z] + + mujoco.mj_forward(model, data) + + camera_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_CAMERA, "head_camera") + lidar_camera_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_CAMERA, "lidar_front_camera") + lidar_left_camera_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_CAMERA, "lidar_left_camera") + lidar_right_camera_id = mujoco.mj_name2id( + model, mujoco.mjtObj.mjOBJ_CAMERA, "lidar_right_camera" + ) + + shm.signal_ready() + + with viewer.launch_passive(model, data, show_left_ui=False, show_right_ui=False) as m_viewer: + camera_size = (VIDEO_WIDTH, VIDEO_HEIGHT) + + # Create renderers + rgb_renderer = mujoco.Renderer(model, height=camera_size[1], width=camera_size[0]) + depth_renderer = mujoco.Renderer(model, height=camera_size[1], width=camera_size[0]) + depth_renderer.enable_depth_rendering() + + depth_left_renderer = mujoco.Renderer(model, height=camera_size[1], width=camera_size[0]) + depth_left_renderer.enable_depth_rendering() + + depth_right_renderer = mujoco.Renderer(model, height=camera_size[1], width=camera_size[0]) + depth_right_renderer.enable_depth_rendering() + + scene_option = mujoco.MjvOption() + + # Timing control + last_video_time = 0.0 + last_lidar_time = 0.0 + video_interval = 1.0 / VIDEO_FPS + lidar_interval = 1.0 / LIDAR_FPS + + while m_viewer.is_running() and not shm.should_stop(): + step_start = time.time() + + # Step simulation + for _ in range(STEPS_PER_FRAME): + mujoco.mj_step(model, data) + + m_viewer.sync() + + # Always update odometry + pos = data.qpos[0:3].copy() + quat = data.qpos[3:7].copy() # (w, x, y, z) + shm.write_odom(pos, quat, time.time()) + + current_time = time.time() + + # Video rendering + if current_time - last_video_time >= video_interval: + rgb_renderer.update_scene(data, camera=camera_id, scene_option=scene_option) + pixels = rgb_renderer.render() + shm.write_video(pixels) + last_video_time = current_time + + # Lidar/depth rendering + if current_time - last_lidar_time >= lidar_interval: + # Render all depth cameras + depth_renderer.update_scene(data, camera=lidar_camera_id, scene_option=scene_option) + depth_front = depth_renderer.render() + + depth_left_renderer.update_scene( + data, camera=lidar_left_camera_id, scene_option=scene_option + ) + depth_left = depth_left_renderer.render() + + depth_right_renderer.update_scene( + data, camera=lidar_right_camera_id, scene_option=scene_option + ) + depth_right = depth_right_renderer.render() + + shm.write_depth(depth_front, depth_left, depth_right) + + # Process depth images into lidar message + all_points = [] + cameras_data = [ + ( + depth_front, + data.cam_xpos[lidar_camera_id], + data.cam_xmat[lidar_camera_id].reshape(3, 3), + ), + ( + depth_left, + data.cam_xpos[lidar_left_camera_id], + data.cam_xmat[lidar_left_camera_id].reshape(3, 3), + ), + ( + depth_right, + data.cam_xpos[lidar_right_camera_id], + data.cam_xmat[lidar_right_camera_id].reshape(3, 3), + ), + ] + + for depth_image, camera_pos, camera_mat in cameras_data: + points = depth_image_to_point_cloud( + depth_image, camera_pos, camera_mat, fov_degrees=DEPTH_CAMERA_FOV + ) + if points.size > 0: + all_points.append(points) + + if all_points: + combined_points = np.vstack(all_points) + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(combined_points) + pcd = pcd.voxel_down_sample(voxel_size=LIDAR_RESOLUTION) + + lidar_msg = LidarMessage( + pointcloud=pcd, + ts=time.time(), + origin=Vector3(pos[0], pos[1], pos[2]), + resolution=LIDAR_RESOLUTION, + ) + shm.write_lidar(lidar_msg) + + last_lidar_time = current_time + + # Control simulation speed + time_until_next_step = model.opt.timestep - (time.time() - step_start) + if time_until_next_step > 0: + time.sleep(time_until_next_step) + + +if __name__ == "__main__": + + def signal_handler(_signum: int, _frame: Any) -> None: + sys.exit(0) + + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + global_config = pickle.loads(base64.b64decode(sys.argv[1])) + shm_names = json.loads(sys.argv[2]) + + shm = ShmReader(shm_names) + try: + _run_simulation(global_config, shm) + finally: + shm.cleanup() diff --git a/dimos/simulation/mujoco/policy.py b/dimos/simulation/mujoco/policy.py new file mode 100644 index 0000000000..2a5113b216 --- /dev/null +++ b/dimos/simulation/mujoco/policy.py @@ -0,0 +1,147 @@ +#!/usr/bin/env python3 + +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from abc import ABC, abstractmethod +from typing import Any + +import mujoco # type: ignore[import-untyped] +import numpy as np +import onnxruntime as rt # type: ignore[import-untyped] + +from dimos.simulation.mujoco.input_controller import InputController + + +class OnnxController(ABC): + def __init__( + self, + policy_path: str, + default_angles: np.ndarray[Any, Any], + n_substeps: int, + action_scale: float, + input_controller: InputController, + ctrl_dt: float | None = None, + drift_compensation: list[float] | None = None, + ) -> None: + self._output_names = ["continuous_actions"] + self._policy = rt.InferenceSession(policy_path, providers=["CPUExecutionProvider"]) + + self._action_scale = action_scale + self._default_angles = default_angles + self._last_action = np.zeros_like(default_angles, dtype=np.float32) + + self._counter = 0 + self._n_substeps = n_substeps + self._input_controller = input_controller + + self._drift_compensation = np.array(drift_compensation or [0.0, 0.0, 0.0], dtype=np.float32) + + @abstractmethod + def get_obs(self, model: mujoco.MjModel, data: mujoco.MjData) -> np.ndarray[Any, Any]: + pass + + def get_control(self, model: mujoco.MjModel, data: mujoco.MjData) -> None: + self._counter += 1 + if self._counter % self._n_substeps == 0: + obs = self.get_obs(model, data) + onnx_input = {"obs": obs.reshape(1, -1)} + onnx_pred = self._policy.run(self._output_names, onnx_input)[0][0] + self._last_action = onnx_pred.copy() + data.ctrl[:] = onnx_pred * self._action_scale + self._default_angles + self._post_control_update() + + def _post_control_update(self) -> None: # noqa: B027 + pass + + +class Go1OnnxController(OnnxController): + def get_obs(self, model: mujoco.MjModel, data: mujoco.MjData) -> np.ndarray[Any, Any]: + linvel = data.sensor("local_linvel").data + gyro = data.sensor("gyro").data + imu_xmat = data.site_xmat[model.site("imu").id].reshape(3, 3) + gravity = imu_xmat.T @ np.array([0, 0, -1]) + joint_angles = data.qpos[7:] - self._default_angles + joint_velocities = data.qvel[6:] + obs = np.hstack( + [ + linvel, + gyro, + gravity, + joint_angles, + joint_velocities, + self._last_action, + self._input_controller.get_command(), + ] + ) + return obs.astype(np.float32) + + +class G1OnnxController(OnnxController): + def __init__( + self, + policy_path: str, + default_angles: np.ndarray[Any, Any], + ctrl_dt: float, + n_substeps: int, + action_scale: float, + input_controller: InputController, + drift_compensation: list[float] | None = None, + ) -> None: + super().__init__( + policy_path, + default_angles, + n_substeps, + action_scale, + input_controller, + ctrl_dt, + drift_compensation, + ) + + self._phase = np.array([0.0, np.pi]) + self._gait_freq = 1.5 + self._phase_dt = 2 * np.pi * self._gait_freq * ctrl_dt + + def get_obs(self, model: mujoco.MjModel, data: mujoco.MjData) -> np.ndarray[Any, Any]: + linvel = data.sensor("local_linvel_pelvis").data + gyro = data.sensor("gyro_pelvis").data + imu_xmat = data.site_xmat[model.site("imu_in_pelvis").id].reshape(3, 3) + gravity = imu_xmat.T @ np.array([0, 0, -1]) + joint_angles = data.qpos[7:] - self._default_angles + joint_velocities = data.qvel[6:] + phase = np.concatenate([np.cos(self._phase), np.sin(self._phase)]) + command = self._input_controller.get_command() + command[0] = command[0] * 2 + command[1] = command[1] * 2 + command[0] += self._drift_compensation[0] + command[1] += self._drift_compensation[1] + command[2] += self._drift_compensation[2] + obs = np.hstack( + [ + linvel, + gyro, + gravity, + command, + joint_angles, + joint_velocities, + self._last_action, + phase, + ] + ) + return obs.astype(np.float32) + + def _post_control_update(self) -> None: + phase_tp1 = self._phase + self._phase_dt + self._phase = np.fmod(phase_tp1 + np.pi, 2 * np.pi) - np.pi diff --git a/dimos/simulation/mujoco/shared_memory.py b/dimos/simulation/mujoco/shared_memory.py new file mode 100644 index 0000000000..4c22062233 --- /dev/null +++ b/dimos/simulation/mujoco/shared_memory.py @@ -0,0 +1,286 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from multiprocessing import resource_tracker +from multiprocessing.shared_memory import SharedMemory +import pickle +from typing import Any + +import numpy as np +from numpy.typing import NDArray + +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.simulation.mujoco.constants import VIDEO_HEIGHT, VIDEO_WIDTH +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + +# Video buffer: VIDEO_WIDTH x VIDEO_HEIGHT x 3 RGB +_video_size = VIDEO_WIDTH * VIDEO_HEIGHT * 3 +# Depth buffers: 3 cameras x VIDEO_WIDTH x VIDEO_HEIGHT float32 +_depth_size = VIDEO_WIDTH * VIDEO_HEIGHT * 4 # float32 = 4 bytes +# Odometry buffer: position(3) + quaternion(4) + timestamp(1) = 8 floats +_odom_size = 8 * 8 # 8 float64 values +# Command buffer: linear(3) + angular(3) = 6 floats +_cmd_size = 6 * 4 # 6 float32 values +# Lidar message buffer: for serialized lidar data +_lidar_size = 1024 * 1024 * 4 # 4MB should be enough for point cloud +# Sequence/version numbers for detecting updates +_seq_size = 8 * 8 # 8 int64 values for different data types +# Control buffer: ready flag + stop flag +_control_size = 2 * 4 # 2 int32 values + +_shm_sizes = { + "video": _video_size, + "depth_front": _depth_size, + "depth_left": _depth_size, + "depth_right": _depth_size, + "odom": _odom_size, + "cmd": _cmd_size, + "lidar": _lidar_size, + "lidar_len": 4, + "seq": _seq_size, + "control": _control_size, +} + + +def _unregister(shm: SharedMemory) -> SharedMemory: + try: + resource_tracker.unregister(shm._name, "shared_memory") # type: ignore[attr-defined] + except Exception: + pass + return shm + + +@dataclass(frozen=True) +class ShmSet: + video: SharedMemory + depth_front: SharedMemory + depth_left: SharedMemory + depth_right: SharedMemory + odom: SharedMemory + cmd: SharedMemory + lidar: SharedMemory + lidar_len: SharedMemory + seq: SharedMemory + control: SharedMemory + + @classmethod + def from_names(cls, shm_names: dict[str, str]) -> "ShmSet": + return cls(**{k: _unregister(SharedMemory(name=shm_names[k])) for k in _shm_sizes.keys()}) + + @classmethod + def from_sizes(cls) -> "ShmSet": + return cls( + **{ + k: _unregister(SharedMemory(create=True, size=_shm_sizes[k])) + for k in _shm_sizes.keys() + } + ) + + def to_names(self) -> dict[str, str]: + return {k: getattr(self, k).name for k in _shm_sizes.keys()} + + def as_list(self) -> list[SharedMemory]: + return [getattr(self, k) for k in _shm_sizes.keys()] + + +class ShmReader: + shm: ShmSet + _last_cmd_seq: int + + def __init__(self, shm_names: dict[str, str]) -> None: + self.shm = ShmSet.from_names(shm_names) + self._last_cmd_seq = 0 + + def signal_ready(self) -> None: + control_array: NDArray[Any] = np.ndarray((2,), dtype=np.int32, buffer=self.shm.control.buf) + control_array[0] = 1 # ready flag + + def should_stop(self) -> bool: + control_array: NDArray[Any] = np.ndarray((2,), dtype=np.int32, buffer=self.shm.control.buf) + return bool(control_array[1] == 1) # stop flag + + def write_video(self, pixels: NDArray[Any]) -> None: + video_array: NDArray[Any] = np.ndarray( + (VIDEO_HEIGHT, VIDEO_WIDTH, 3), dtype=np.uint8, buffer=self.shm.video.buf + ) + video_array[:] = pixels + self._increment_seq(0) + + def write_depth(self, front: NDArray[Any], left: NDArray[Any], right: NDArray[Any]) -> None: + # Front camera + depth_array: NDArray[Any] = np.ndarray( + (VIDEO_HEIGHT, VIDEO_WIDTH), dtype=np.float32, buffer=self.shm.depth_front.buf + ) + depth_array[:] = front + + # Left camera + depth_array = np.ndarray( + (VIDEO_HEIGHT, VIDEO_WIDTH), dtype=np.float32, buffer=self.shm.depth_left.buf + ) + depth_array[:] = left + + # Right camera + depth_array = np.ndarray( + (VIDEO_HEIGHT, VIDEO_WIDTH), dtype=np.float32, buffer=self.shm.depth_right.buf + ) + depth_array[:] = right + + self._increment_seq(1) + + def write_odom(self, pos: NDArray[Any], quat: NDArray[Any], timestamp: float) -> None: + odom_array: NDArray[Any] = np.ndarray((8,), dtype=np.float64, buffer=self.shm.odom.buf) + odom_array[0:3] = pos + odom_array[3:7] = quat + odom_array[7] = timestamp + self._increment_seq(2) + + def write_lidar(self, lidar_msg: LidarMessage) -> None: + data = pickle.dumps(lidar_msg) + data_len = len(data) + + if data_len > self.shm.lidar.size: + logger.error(f"Lidar data too large: {data_len} > {self.shm.lidar.size}") + return + + # Write length + len_array: NDArray[Any] = np.ndarray((1,), dtype=np.uint32, buffer=self.shm.lidar_len.buf) + len_array[0] = data_len + + # Write data + lidar_array: NDArray[Any] = np.ndarray( + (data_len,), dtype=np.uint8, buffer=self.shm.lidar.buf + ) + lidar_array[:] = np.frombuffer(data, dtype=np.uint8) + + self._increment_seq(4) + + def read_command(self) -> tuple[NDArray[Any], NDArray[Any]] | None: + seq = self._get_seq(3) + if seq > self._last_cmd_seq: + self._last_cmd_seq = seq + cmd_array: NDArray[Any] = np.ndarray((6,), dtype=np.float32, buffer=self.shm.cmd.buf) + linear = cmd_array[0:3].copy() + angular = cmd_array[3:6].copy() + return linear, angular + return None + + def _increment_seq(self, index: int) -> None: + seq_array: NDArray[Any] = np.ndarray((8,), dtype=np.int64, buffer=self.shm.seq.buf) + seq_array[index] += 1 + + def _get_seq(self, index: int) -> int: + seq_array: NDArray[Any] = np.ndarray((8,), dtype=np.int64, buffer=self.shm.seq.buf) + return int(seq_array[index]) + + def cleanup(self) -> None: + for shm in self.shm.as_list(): + try: + shm.close() + except Exception: + pass + + +class ShmWriter: + shm: ShmSet + + def __init__(self) -> None: + self.shm = ShmSet.from_sizes() + + seq_array: NDArray[Any] = np.ndarray((8,), dtype=np.int64, buffer=self.shm.seq.buf) + seq_array[:] = 0 + + cmd_array: NDArray[Any] = np.ndarray((6,), dtype=np.float32, buffer=self.shm.cmd.buf) + cmd_array[:] = 0 + + control_array: NDArray[Any] = np.ndarray((2,), dtype=np.int32, buffer=self.shm.control.buf) + control_array[:] = 0 # [ready_flag, stop_flag] + + def is_ready(self) -> bool: + control_array: NDArray[Any] = np.ndarray((2,), dtype=np.int32, buffer=self.shm.control.buf) + return bool(control_array[0] == 1) + + def signal_stop(self) -> None: + control_array: NDArray[Any] = np.ndarray((2,), dtype=np.int32, buffer=self.shm.control.buf) + control_array[1] = 1 # Set stop flag + + def read_video(self) -> tuple[NDArray[Any] | None, int]: + seq = self._get_seq(0) + if seq > 0: + video_array: NDArray[Any] = np.ndarray( + (VIDEO_HEIGHT, VIDEO_WIDTH, 3), dtype=np.uint8, buffer=self.shm.video.buf + ) + return video_array.copy(), seq + return None, 0 + + def read_odom(self) -> tuple[tuple[NDArray[Any], NDArray[Any], float] | None, int]: + seq = self._get_seq(2) + if seq > 0: + odom_array: NDArray[Any] = np.ndarray((8,), dtype=np.float64, buffer=self.shm.odom.buf) + pos = odom_array[0:3].copy() + quat = odom_array[3:7].copy() + timestamp = odom_array[7] + return (pos, quat, timestamp), seq + return None, 0 + + def write_command(self, linear: NDArray[Any], angular: NDArray[Any]) -> None: + cmd_array: NDArray[Any] = np.ndarray((6,), dtype=np.float32, buffer=self.shm.cmd.buf) + cmd_array[0:3] = linear + cmd_array[3:6] = angular + self._increment_seq(3) + + def read_lidar(self) -> tuple[LidarMessage | None, int]: + seq = self._get_seq(4) + if seq > 0: + # Read length + len_array: NDArray[Any] = np.ndarray( + (1,), dtype=np.uint32, buffer=self.shm.lidar_len.buf + ) + data_len = int(len_array[0]) + + if data_len > 0 and data_len <= self.shm.lidar.size: + # Read data + lidar_array: NDArray[Any] = np.ndarray( + (data_len,), dtype=np.uint8, buffer=self.shm.lidar.buf + ) + data = bytes(lidar_array) + + try: + lidar_msg = pickle.loads(data) + return lidar_msg, seq + except Exception as e: + logger.error(f"Failed to deserialize lidar message: {e}") + return None, 0 + + def _increment_seq(self, index: int) -> None: + seq_array: NDArray[Any] = np.ndarray((8,), dtype=np.int64, buffer=self.shm.seq.buf) + seq_array[index] += 1 + + def _get_seq(self, index: int) -> int: + seq_array: NDArray[Any] = np.ndarray((8,), dtype=np.int64, buffer=self.shm.seq.buf) + return int(seq_array[index]) + + def cleanup(self) -> None: + for shm in self.shm.as_list(): + try: + shm.unlink() + except Exception: + pass + + try: + shm.close() + except Exception: + pass diff --git a/dimos/skills/__init__.py b/dimos/skills/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dimos/skills/kill_skill.py b/dimos/skills/kill_skill.py new file mode 100644 index 0000000000..f0ca805e6f --- /dev/null +++ b/dimos/skills/kill_skill.py @@ -0,0 +1,61 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Kill skill for terminating running skills. + +This module provides a skill that can terminate other running skills, +particularly those running in separate threads like the monitor skill. +""" + +from pydantic import Field + +from dimos.skills.skills import AbstractSkill, SkillLibrary +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class KillSkill(AbstractSkill): + """ + A skill that terminates other running skills. + + This skill can be used to stop long-running or background skills + like the monitor skill. It uses the centralized process management + in the SkillLibrary to track and terminate skills. + """ + + skill_name: str = Field(..., description="Name of the skill to terminate") + + def __init__(self, skill_library: SkillLibrary | None = None, **data) -> None: # type: ignore[no-untyped-def] + """ + Initialize the kill skill. + + Args: + skill_library: The skill library instance + **data: Additional data for configuration + """ + super().__init__(**data) + self._skill_library = skill_library + + def __call__(self): # type: ignore[no-untyped-def] + """ + Terminate the specified skill. + + Returns: + A message indicating whether the skill was successfully terminated + """ + print("running skills", self._skill_library.get_running_skills()) # type: ignore[union-attr] + # Terminate the skill using the skill library + return self._skill_library.terminate_skill(self.skill_name) # type: ignore[union-attr] diff --git a/dimos/skills/manipulation/abstract_manipulation_skill.py b/dimos/skills/manipulation/abstract_manipulation_skill.py new file mode 100644 index 0000000000..e767ad8c8f --- /dev/null +++ b/dimos/skills/manipulation/abstract_manipulation_skill.py @@ -0,0 +1,58 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Abstract base class for manipulation skills.""" + +from dimos.manipulation.manipulation_interface import ManipulationInterface +from dimos.robot.robot import Robot +from dimos.skills.skills import AbstractRobotSkill +from dimos.types.robot_capabilities import RobotCapability + + +class AbstractManipulationSkill(AbstractRobotSkill): + """Base class for all manipulation-related skills. + + This abstract class provides access to the robot's manipulation memory system. + """ + + def __init__(self, *args, robot: Robot | None = None, **kwargs) -> None: # type: ignore[no-untyped-def] + """Initialize the manipulation skill. + + Args: + robot: The robot instance to associate with this skill + """ + super().__init__(*args, robot=robot, **kwargs) + + if self._robot and not self._robot.manipulation_interface: # type: ignore[attr-defined] + raise NotImplementedError( + "This robot does not have a manipulation interface implemented" + ) + + @property + def manipulation_interface(self) -> ManipulationInterface | None: + """Get the robot's manipulation interface. + + Returns: + ManipulationInterface: The robot's manipulation interface or None if not available + + Raises: + RuntimeError: If the robot doesn't have the MANIPULATION capability + """ + if self._robot is None: + return None + + if not self._robot.has_capability(RobotCapability.MANIPULATION): + raise RuntimeError("This robot does not have manipulation capabilities") + + return self._robot.manipulation_interface # type: ignore[attr-defined, no-any-return] diff --git a/dimos/skills/manipulation/force_constraint_skill.py b/dimos/skills/manipulation/force_constraint_skill.py new file mode 100644 index 0000000000..edeac0844e --- /dev/null +++ b/dimos/skills/manipulation/force_constraint_skill.py @@ -0,0 +1,72 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from pydantic import Field + +from dimos.skills.manipulation.abstract_manipulation_skill import AbstractManipulationSkill +from dimos.types.manipulation import ForceConstraint, Vector # type: ignore[attr-defined] +from dimos.utils.logging_config import setup_logger + +# Initialize logger +logger = setup_logger() + + +class ForceConstraintSkill(AbstractManipulationSkill): + """ + Skill for generating force constraints for robot manipulation. + + This skill generates force constraints and adds them to the ManipulationInterface's + agent_constraints list for tracking constraints created by the Agent. + """ + + # Constraint parameters + min_force: float = Field(0.0, description="Minimum force magnitude in Newtons") + max_force: float = Field(100.0, description="Maximum force magnitude in Newtons to apply") + + # Force direction as (x,y) tuple + force_direction: tuple[float, float] | None = Field( + None, description="Force direction vector (x,y)" + ) + + # Description + description: str = Field("", description="Description of the force constraint") + + def __call__(self) -> ForceConstraint: + """ + Generate a force constraint based on the parameters. + + Returns: + ForceConstraint: The generated constraint + """ + # Create force direction vector if provided (convert 2D point to 3D vector with z=0) + force_direction_vector = None + if self.force_direction: + force_direction_vector = Vector(self.force_direction[0], self.force_direction[1], 0.0) # type: ignore[arg-type] + + # Create and return the constraint + constraint = ForceConstraint( + max_force=self.max_force, + min_force=self.min_force, + force_direction=force_direction_vector, + description=self.description, + ) + + # Add constraint to manipulation interface for Agent recall + self.manipulation_interface.add_constraint(constraint) # type: ignore[union-attr] + + # Log the constraint creation + logger.info(f"Generated force constraint: {self.description}") + + return constraint diff --git a/dimos/skills/manipulation/manipulate_skill.py b/dimos/skills/manipulation/manipulate_skill.py new file mode 100644 index 0000000000..830ddc33e0 --- /dev/null +++ b/dimos/skills/manipulation/manipulate_skill.py @@ -0,0 +1,173 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from typing import Any +import uuid + +from pydantic import Field + +from dimos.skills.manipulation.abstract_manipulation_skill import AbstractManipulationSkill +from dimos.types.manipulation import ( + AbstractConstraint, + ManipulationMetadata, + ManipulationTask, + ManipulationTaskConstraint, +) +from dimos.utils.logging_config import setup_logger + +# Initialize logger +logger = setup_logger() + + +class Manipulate(AbstractManipulationSkill): + """ + Skill for executing manipulation tasks with constraints. + Can be called by an LLM with a list of manipulation constraints. + """ + + description: str = Field("", description="Description of the manipulation task") + + # Target object information + target_object: str = Field( + "", description="Semantic label of the target object (e.g., 'cup', 'box')" + ) + + target_point: str = Field( + "", description="(X,Y) point in pixel-space of the point to manipulate on target object" + ) + + # Constraints - can be set directly + constraints: list[str] = Field( + [], + description="List of AbstractConstraint constraint IDs from AgentMemory to apply to the manipulation task", + ) + + # Object movement tolerances + object_tolerances: dict[str, float] = Field( + {}, # Empty dict as default + description="Dictionary mapping object IDs to movement tolerances (0.0 = immovable, 1.0 = freely movable)", + ) + + def __call__(self) -> dict[str, Any]: + """ + Execute a manipulation task with the given constraints. + + Returns: + Dict[str, Any]: Result of the manipulation operation + """ + # Get the manipulation constraint + constraint = self._build_manipulation_constraint() + + # Create task with unique ID + task_id = f"{str(uuid.uuid4())[:4]}" + timestamp = time.time() + + # Build metadata with environment state + metadata = self._build_manipulation_metadata() + + task = ManipulationTask( + description=self.description, + target_object=self.target_object, + target_point=tuple(map(int, self.target_point.strip("()").split(","))), # type: ignore[arg-type] + constraints=constraint, + metadata=metadata, + timestamp=timestamp, + task_id=task_id, + result=None, + ) + + # Add task to manipulation interface + self.manipulation_interface.add_manipulation_task(task) # type: ignore[union-attr] + + # Execute the manipulation + result = self._execute_manipulation(task) + + # Log the execution + logger.info( + f"Executed manipulation '{self.description}' with constraints: {self.constraints}" + ) + + return result + + def _build_manipulation_metadata(self) -> ManipulationMetadata: + """ + Build metadata for the current environment state, including object data and movement tolerances. + """ + # Get detected objects from the manipulation interface + detected_objects = [] # type: ignore[var-annotated] + try: + detected_objects = self.manipulation_interface.get_latest_objects() or [] # type: ignore[union-attr] + except Exception as e: + logger.warning(f"Failed to get detected objects: {e}") + + # Create dictionary of objects keyed by ID for easier lookup + objects_by_id = {} + for obj in detected_objects: + obj_id = str(obj.get("object_id", -1)) + objects_by_id[obj_id] = dict(obj) # Make a copy to avoid modifying original + + # Create objects_data dictionary with tolerances applied + objects_data: dict[str, Any] = {} + + # First, apply all specified tolerances + for object_id, tolerance in self.object_tolerances.items(): + if object_id in objects_by_id: + # Object exists in detected objects, update its tolerance + obj_data = objects_by_id[object_id] + obj_data["movement_tolerance"] = tolerance + objects_data[object_id] = obj_data + + # Add any detected objects not explicitly given tolerances + for obj_id, obj in objects_by_id.items(): + if obj_id not in self.object_tolerances: + obj["movement_tolerance"] = 0.0 # Default to immovable + objects_data[obj_id] = obj + + # Create properly typed ManipulationMetadata + metadata: ManipulationMetadata = {"timestamp": time.time(), "objects": objects_data} + + return metadata + + def _build_manipulation_constraint(self) -> ManipulationTaskConstraint: + """ + Build a ManipulationTaskConstraint object from the provided parameters. + """ + + constraint = ManipulationTaskConstraint() + + # Add constraints directly or resolve from IDs + for c in self.constraints: + if isinstance(c, AbstractConstraint): + constraint.add_constraint(c) + elif isinstance(c, str) and self.manipulation_interface: + # Try to load constraint from ID + saved_constraint = self.manipulation_interface.get_constraint(c) + if saved_constraint: + constraint.add_constraint(saved_constraint) + + return constraint + + # TODO: Implement + def _execute_manipulation(self, task: ManipulationTask) -> dict[str, Any]: + """ + Execute the manipulation with the given constraint. + + Args: + task: The manipulation task to execute + + Returns: + Dict[str, Any]: Result of the manipulation operation + """ + return {"success": True} diff --git a/dimos/skills/manipulation/pick_and_place.py b/dimos/skills/manipulation/pick_and_place.py new file mode 100644 index 0000000000..1d1063edad --- /dev/null +++ b/dimos/skills/manipulation/pick_and_place.py @@ -0,0 +1,444 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Pick and place skill for Piper Arm robot. + +This module provides a skill that uses Qwen VLM to identify pick and place +locations based on natural language queries, then executes the manipulation. +""" + +import json +import os +from typing import Any + +import cv2 +import numpy as np +from pydantic import Field + +from dimos.models.qwen.video_query import query_single_frame +from dimos.skills.skills import AbstractRobotSkill +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +def parse_qwen_points_response(response: str) -> tuple[tuple[int, int], tuple[int, int]] | None: + """ + Parse Qwen's response containing two points. + + Args: + response: Qwen's response containing JSON with two points + + Returns: + Tuple of (pick_point, place_point) where each point is (x, y), or None if parsing fails + """ + try: + # Try to extract JSON from the response + start_idx = response.find("{") + end_idx = response.rfind("}") + 1 + + if start_idx >= 0 and end_idx > start_idx: + json_str = response[start_idx:end_idx] + result = json.loads(json_str) + + # Extract pick and place points + if "pick_point" in result and "place_point" in result: + pick = result["pick_point"] + place = result["place_point"] + + # Validate points have x,y coordinates + if ( + isinstance(pick, list | tuple) + and len(pick) >= 2 + and isinstance(place, list | tuple) + and len(place) >= 2 + ): + return (int(pick[0]), int(pick[1])), (int(place[0]), int(place[1])) + + except Exception as e: + logger.error(f"Error parsing Qwen points response: {e}") + logger.debug(f"Raw response: {response}") + + return None + + +def save_debug_image_with_points( + image: np.ndarray, # type: ignore[type-arg] + pick_point: tuple[int, int] | None = None, + place_point: tuple[int, int] | None = None, + filename_prefix: str = "qwen_debug", +) -> str: + """ + Save debug image with crosshairs marking pick and/or place points. + + Args: + image: RGB image array + pick_point: (x, y) coordinates for pick location + place_point: (x, y) coordinates for place location + filename_prefix: Prefix for the saved filename + + Returns: + Path to the saved image + """ + # Create a copy to avoid modifying original + debug_image = image.copy() + + # Draw pick point crosshair (green) + if pick_point: + x, y = pick_point + # Draw crosshair + cv2.drawMarker(debug_image, (x, y), (0, 255, 0), cv2.MARKER_CROSS, 30, 2) + # Draw circle + cv2.circle(debug_image, (x, y), 5, (0, 255, 0), -1) + # Add label + cv2.putText( + debug_image, "PICK", (x + 10, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2 + ) + + # Draw place point crosshair (cyan) + if place_point: + x, y = place_point + # Draw crosshair + cv2.drawMarker(debug_image, (x, y), (255, 255, 0), cv2.MARKER_CROSS, 30, 2) + # Draw circle + cv2.circle(debug_image, (x, y), 5, (255, 255, 0), -1) + # Add label + cv2.putText( + debug_image, "PLACE", (x + 10, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 0), 2 + ) + + # Draw arrow from pick to place if both exist + if pick_point and place_point: + cv2.arrowedLine(debug_image, pick_point, place_point, (255, 0, 255), 2, tipLength=0.03) + + # Generate filename with timestamp + filename = f"{filename_prefix}.png" + filepath = os.path.join(os.getcwd(), filename) + + # Save image + cv2.imwrite(filepath, debug_image) + logger.info(f"Debug image saved to: {filepath}") + + return filepath + + +def parse_qwen_single_point_response(response: str) -> tuple[int, int] | None: + """ + Parse Qwen's response containing a single point. + + Args: + response: Qwen's response containing JSON with a point + + Returns: + Tuple of (x, y) or None if parsing fails + """ + try: + # Try to extract JSON from the response + start_idx = response.find("{") + end_idx = response.rfind("}") + 1 + + if start_idx >= 0 and end_idx > start_idx: + json_str = response[start_idx:end_idx] + result = json.loads(json_str) + + # Try different possible keys + point = None + for key in ["point", "location", "position", "coordinates"]: + if key in result: + point = result[key] + break + + # Validate point has x,y coordinates + if point and isinstance(point, list | tuple) and len(point) >= 2: + return int(point[0]), int(point[1]) + + except Exception as e: + logger.error(f"Error parsing Qwen single point response: {e}") + logger.debug(f"Raw response: {response}") + + return None + + +class PickAndPlace(AbstractRobotSkill): + """ + A skill that performs pick and place operations using vision-language guidance. + + This skill uses Qwen VLM to identify objects and locations based on natural + language queries, then executes pick and place operations using the robot's + manipulation interface. + + Example usage: + # Just pick the object + skill = PickAndPlace(robot=robot, object_query="red mug") + + # Pick and place the object + skill = PickAndPlace(robot=robot, object_query="red mug", target_query="on the coaster") + + The skill uses the robot's stereo camera to capture RGB images and its manipulation + interface to execute the pick and place operation. It automatically handles coordinate + transformation from 2D pixel coordinates to 3D world coordinates. + """ + + object_query: str = Field( + "mug", + description="Natural language description of the object to pick (e.g., 'red mug', 'small box')", + ) + + target_query: str | None = Field( + None, + description="Natural language description of where to place the object (e.g., 'on the table', 'in the basket'). If not provided, only pick operation will be performed.", + ) + + model_name: str = Field( + "qwen2.5-vl-72b-instruct", description="Qwen model to use for visual queries" + ) + + def __init__(self, robot=None, **data) -> None: # type: ignore[no-untyped-def] + """ + Initialize the PickAndPlace skill. + + Args: + robot: The PiperArmRobot instance + **data: Additional configuration data + """ + super().__init__(robot=robot, **data) + + def _get_camera_frame(self) -> np.ndarray | None: # type: ignore[type-arg] + """ + Get a single RGB frame from the robot's camera. + + Returns: + RGB image as numpy array or None if capture fails + """ + if not self._robot or not self._robot.manipulation_interface: # type: ignore[attr-defined] + logger.error("Robot or stereo camera not available") + return None + + try: + # Use the RPC call to get a single RGB frame + rgb_frame = self._robot.manipulation_interface.get_single_rgb_frame() # type: ignore[attr-defined] + if rgb_frame is None: + logger.error("Failed to capture RGB frame from camera") + return rgb_frame # type: ignore[no-any-return] + except Exception as e: + logger.error(f"Error getting camera frame: {e}") + return None + + def _query_pick_and_place_points( + self, + frame: np.ndarray, # type: ignore[type-arg] + ) -> tuple[tuple[int, int], tuple[int, int]] | None: + """ + Query Qwen to get both pick and place points in a single query. + + Args: + frame: RGB image array + + Returns: + Tuple of (pick_point, place_point) or None if query fails + """ + # This method is only called when both object and target are specified + prompt = ( + f"Look at this image carefully. I need you to identify two specific locations:\n" + f"1. Find the {self.object_query} - this is the object I want to pick up\n" + f"2. Identify where to place it {self.target_query}\n\n" + "Instructions:\n" + "- The pick_point should be at the center or graspable part of the object\n" + "- The place_point should be a stable, flat surface at the target location\n" + "- Consider the object's size when choosing the placement point\n\n" + "Return ONLY a JSON object with this exact format:\n" + "{'pick_point': [x, y], 'place_point': [x, y]}\n" + "where [x, y] are pixel coordinates in the image." + ) + + try: + response = query_single_frame(frame, prompt, model_name=self.model_name) + return parse_qwen_points_response(response) + except Exception as e: + logger.error(f"Error querying Qwen for pick and place points: {e}") + return None + + def _query_single_point( + self, + frame: np.ndarray, # type: ignore[type-arg] + query: str, + point_type: str, + ) -> tuple[int, int] | None: + """ + Query Qwen to get a single point location. + + Args: + frame: RGB image array + query: Natural language description of what to find + point_type: Type of point ('pick' or 'place') for context + + Returns: + Tuple of (x, y) pixel coordinates or None if query fails + """ + if point_type == "pick": + prompt = ( + f"Look at this image carefully and find the {query}.\n\n" + "Instructions:\n" + "- Identify the exact object matching the description\n" + "- Choose the center point or the most graspable location on the object\n" + "- If multiple matching objects exist, choose the most prominent or accessible one\n" + "- Consider the object's shape and material when selecting the grasp point\n\n" + "Return ONLY a JSON object with this exact format:\n" + "{'point': [x, y]}\n" + "where [x, y] are the pixel coordinates of the optimal grasping point on the object." + ) + else: # place + prompt = ( + f"Look at this image and identify where to place an object {query}.\n\n" + "Instructions:\n" + "- Find a stable, flat surface at the specified location\n" + "- Ensure the placement spot is clear of obstacles\n" + "- Consider the size of the object being placed\n" + "- If the query specifies a container or specific spot, center the placement there\n" + "- Otherwise, find the most appropriate nearby surface\n\n" + "Return ONLY a JSON object with this exact format:\n" + "{'point': [x, y]}\n" + "where [x, y] are the pixel coordinates of the optimal placement location." + ) + + try: + response = query_single_frame(frame, prompt, model_name=self.model_name) + return parse_qwen_single_point_response(response) + except Exception as e: + logger.error(f"Error querying Qwen for {point_type} point: {e}") + return None + + def __call__(self) -> dict[str, Any]: + """ + Execute the pick and place operation. + + Returns: + Dictionary with operation results + """ + super().__call__() # type: ignore[no-untyped-call] + + if not self._robot: + error_msg = "No robot instance provided to PickAndPlace skill" + logger.error(error_msg) + return {"success": False, "error": error_msg} + + # Register skill as running + skill_library = self._robot.get_skills() # type: ignore[no-untyped-call] + self.register_as_running("PickAndPlace", skill_library) + + # Get camera frame + frame = self._get_camera_frame() + if frame is None: + return {"success": False, "error": "Failed to capture camera frame"} + + # Convert RGB to BGR for OpenCV if needed + if len(frame.shape) == 3 and frame.shape[2] == 3: + frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) + + # Get pick and place points from Qwen + pick_point = None + place_point = None + + # Determine mode based on whether target_query is provided + if self.target_query is None: + # Pick only mode + logger.info("Pick-only mode (no target specified)") + + # Query for pick point + pick_point = self._query_single_point(frame, self.object_query, "pick") + if not pick_point: + return {"success": False, "error": f"Failed to find {self.object_query}"} + + # No place point needed for pick-only + place_point = None + else: + # Pick and place mode - can use either single or dual query + logger.info("Pick and place mode (target specified)") + + # Try single query first for efficiency + points = self._query_pick_and_place_points(frame) + pick_point, place_point = points # type: ignore[misc] + + logger.info(f"Pick point: {pick_point}, Place point: {place_point}") + + # Save debug image with marked points + if pick_point or place_point: + save_debug_image_with_points(frame, pick_point, place_point) + + # Execute pick (and optionally place) using the robot's interface + try: + if place_point: + # Pick and place + result = self._robot.pick_and_place( # type: ignore[attr-defined] + pick_x=pick_point[0], + pick_y=pick_point[1], + place_x=place_point[0], + place_y=place_point[1], + ) + else: + # Pick only + result = self._robot.pick_and_place( # type: ignore[attr-defined] + pick_x=pick_point[0], pick_y=pick_point[1], place_x=None, place_y=None + ) + + if result: + if self.target_query: + message = ( + f"Successfully picked {self.object_query} and placed it {self.target_query}" + ) + else: + message = f"Successfully picked {self.object_query}" + + return { + "success": True, + "pick_point": pick_point, + "place_point": place_point, + "object": self.object_query, + "target": self.target_query, + "message": message, + } + else: + operation = "Pick and place" if self.target_query else "Pick" + return { + "success": False, + "pick_point": pick_point, + "place_point": place_point, + "error": f"{operation} operation failed", + } + + except Exception as e: + logger.error(f"Error executing pick and place: {e}") + return { + "success": False, + "error": f"Execution error: {e!s}", + "pick_point": pick_point, + "place_point": place_point, + } + finally: + # Always unregister skill when done + self.stop() + + def stop(self) -> None: + """ + Stop the pick and place operation and perform cleanup. + """ + logger.info("Stopping PickAndPlace skill") + + # Unregister skill from skill library + if self._robot: + skill_library = self._robot.get_skills() # type: ignore[no-untyped-call] + self.unregister_as_running("PickAndPlace", skill_library) + + logger.info("PickAndPlace skill stopped successfully") diff --git a/dimos/skills/manipulation/rotation_constraint_skill.py b/dimos/skills/manipulation/rotation_constraint_skill.py new file mode 100644 index 0000000000..72e6a53716 --- /dev/null +++ b/dimos/skills/manipulation/rotation_constraint_skill.py @@ -0,0 +1,111 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Literal + +from pydantic import Field + +from dimos.skills.manipulation.abstract_manipulation_skill import AbstractManipulationSkill +from dimos.types.manipulation import RotationConstraint +from dimos.types.vector import Vector +from dimos.utils.logging_config import setup_logger + +# Initialize logger +logger = setup_logger() + + +class RotationConstraintSkill(AbstractManipulationSkill): + """ + Skill for generating rotation constraints for robot manipulation. + + This skill generates rotation constraints and adds them to the ManipulationInterface's + agent_constraints list for tracking constraints created by the Agent. + """ + + # Rotation axis parameter + rotation_axis: Literal["roll", "pitch", "yaw"] = Field( + "roll", + description="Axis to rotate around: 'roll' (x-axis), 'pitch' (y-axis), or 'yaw' (z-axis)", + ) + + # Simple angle values for rotation (in degrees) + start_angle: float | None = Field(None, description="Starting angle in degrees") + end_angle: float | None = Field(None, description="Ending angle in degrees") + + # Pivot points as (x,y) tuples + pivot_point: tuple[float, float] | None = Field( + None, description="Pivot point (x,y) for rotation" + ) + + # TODO: Secondary pivot point for more complex rotations + secondary_pivot_point: tuple[float, float] | None = Field( + None, description="Secondary pivot point (x,y) for double-pivot rotation" + ) + + def __call__(self) -> RotationConstraint: + """ + Generate a rotation constraint based on the parameters. + + This implementation supports rotation around a single axis (roll, pitch, or yaw). + + Returns: + RotationConstraint: The generated constraint + """ + # rotation_axis is guaranteed to be one of "roll", "pitch", or "yaw" due to Literal type constraint + + # Create angle vectors more efficiently + start_angle_vector = None + if self.start_angle is not None: + # Build rotation vector on correct axis + values = [0.0, 0.0, 0.0] + axis_index = {"roll": 0, "pitch": 1, "yaw": 2}[self.rotation_axis] + values[axis_index] = self.start_angle + start_angle_vector = Vector(*values) # type: ignore[arg-type] + + end_angle_vector = None + if self.end_angle is not None: + values = [0.0, 0.0, 0.0] + axis_index = {"roll": 0, "pitch": 1, "yaw": 2}[self.rotation_axis] + values[axis_index] = self.end_angle + end_angle_vector = Vector(*values) # type: ignore[arg-type] + + # Create pivot point vector if provided (convert 2D point to 3D vector with z=0) + pivot_point_vector = None + if self.pivot_point: + pivot_point_vector = Vector(self.pivot_point[0], self.pivot_point[1], 0.0) # type: ignore[arg-type] + + # Create secondary pivot point vector if provided + secondary_pivot_vector = None + if self.secondary_pivot_point: + secondary_pivot_vector = Vector( + self.secondary_pivot_point[0], # type: ignore[arg-type] + self.secondary_pivot_point[1], # type: ignore[arg-type] + 0.0, # type: ignore[arg-type] + ) + + constraint = RotationConstraint( + rotation_axis=self.rotation_axis, + start_angle=start_angle_vector, + end_angle=end_angle_vector, + pivot_point=pivot_point_vector, + secondary_pivot_point=secondary_pivot_vector, + ) + + # Add constraint to manipulation interface + self.manipulation_interface.add_constraint(constraint) # type: ignore[union-attr] + + # Log the constraint creation + logger.info(f"Generated rotation constraint around {self.rotation_axis} axis") + + return constraint diff --git a/dimos/skills/manipulation/translation_constraint_skill.py b/dimos/skills/manipulation/translation_constraint_skill.py new file mode 100644 index 0000000000..78ea38cfe4 --- /dev/null +++ b/dimos/skills/manipulation/translation_constraint_skill.py @@ -0,0 +1,100 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Literal + +from pydantic import Field + +from dimos.skills.manipulation.abstract_manipulation_skill import AbstractManipulationSkill +from dimos.types.manipulation import TranslationConstraint, Vector # type: ignore[attr-defined] +from dimos.utils.logging_config import setup_logger + +# Initialize logger +logger = setup_logger() + + +class TranslationConstraintSkill(AbstractManipulationSkill): + """ + Skill for generating translation constraints for robot manipulation. + + This skill generates translation constraints and adds them to the ManipulationInterface's + agent_constraints list for tracking constraints created by the Agent. + """ + + # Constraint parameters + translation_axis: Literal["x", "y", "z"] = Field( + "x", description="Axis to translate along: 'x', 'y', or 'z'" + ) + + reference_point: tuple[float, float] | None = Field( + None, description="Reference point (x,y) on the target object for translation constraining" + ) + + bounds_min: tuple[float, float] | None = Field( + None, description="Minimum bounds (x,y) for bounded translation" + ) + + bounds_max: tuple[float, float] | None = Field( + None, description="Maximum bounds (x,y) for bounded translation" + ) + + target_point: tuple[float, float] | None = Field( + None, description="Final target position (x,y) for translation constraining" + ) + + # Description + description: str = Field("", description="Description of the translation constraint") + + def __call__(self) -> TranslationConstraint: + """ + Generate a translation constraint based on the parameters. + + Returns: + TranslationConstraint: The generated constraint + """ + # Create reference point vector if provided (convert 2D point to 3D vector with z=0) + reference_point = None + if self.reference_point: + reference_point = Vector(self.reference_point[0], self.reference_point[1], 0.0) # type: ignore[arg-type] + + # Create bounds minimum vector if provided + bounds_min = None + if self.bounds_min: + bounds_min = Vector(self.bounds_min[0], self.bounds_min[1], 0.0) # type: ignore[arg-type] + + # Create bounds maximum vector if provided + bounds_max = None + if self.bounds_max: + bounds_max = Vector(self.bounds_max[0], self.bounds_max[1], 0.0) # type: ignore[arg-type] + + # Create relative target vector if provided + target_point = None + if self.target_point: + target_point = Vector(self.target_point[0], self.target_point[1], 0.0) # type: ignore[arg-type] + + constraint = TranslationConstraint( + translation_axis=self.translation_axis, + reference_point=reference_point, + bounds_min=bounds_min, + bounds_max=bounds_max, + target_point=target_point, + ) + + # Add constraint to manipulation interface + self.manipulation_interface.add_constraint(constraint) # type: ignore[union-attr] + + # Log the constraint creation + logger.info(f"Generated translation constraint along {self.translation_axis} axis") + + return {"success": True} # type: ignore[return-value] diff --git a/dimos/skills/rest/__init__.py b/dimos/skills/rest/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dimos/skills/rest/rest.py b/dimos/skills/rest/rest.py new file mode 100644 index 0000000000..23369faf23 --- /dev/null +++ b/dimos/skills/rest/rest.py @@ -0,0 +1,101 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from pydantic import Field +import requests + +from dimos.skills.skills import AbstractSkill + +logger = logging.getLogger(__name__) + + +class GenericRestSkill(AbstractSkill): + """Performs a configurable REST API call. + + This skill executes an HTTP request based on the provided parameters. It + supports various HTTP methods and allows specifying URL, timeout. + + Attributes: + url: The target URL for the API call. + method: The HTTP method (e.g., 'GET', 'POST'). Case-insensitive. + timeout: Request timeout in seconds. + """ + + # TODO: Add query parameters, request body data (form-encoded or JSON), and headers. + # , query + # parameters, request body data (form-encoded or JSON), and headers. + # params: Optional dictionary of URL query parameters. + # data: Optional dictionary for form-encoded request body data. + # json_payload: Optional dictionary for JSON request body data. Use the + # alias 'json' when initializing. + # headers: Optional dictionary of HTTP headers. + url: str = Field(..., description="The target URL for the API call.") + method: str = Field(..., description="HTTP method (e.g., 'GET', 'POST').") + timeout: int = Field(..., description="Request timeout in seconds.") + # params: Optional[Dict[str, Any]] = Field(default=None, description="URL query parameters.") + # data: Optional[Dict[str, Any]] = Field(default=None, description="Form-encoded request body.") + # json_payload: Optional[Dict[str, Any]] = Field(default=None, alias="json", description="JSON request body.") + # headers: Optional[Dict[str, str]] = Field(default=None, description="HTTP headers.") + + def __call__(self) -> str: + """Executes the configured REST API call. + + Returns: + The text content of the response on success (HTTP 2xx). + + Raises: + requests.exceptions.RequestException: If a connection error, timeout, + or other request-related issue occurs. + requests.exceptions.HTTPError: If the server returns an HTTP 4xx or + 5xx status code. + Exception: For any other unexpected errors during execution. + + Returns: + A string representing the success or failure outcome. If successful, + returns the response body text. If an error occurs, returns a + descriptive error message. + """ + try: + logger.debug( + f"Executing {self.method.upper()} request to {self.url} " + f"with timeout={self.timeout}" # , params={self.params}, " + # f"data={self.data}, json={self.json_payload}, headers={self.headers}" + ) + response = requests.request( + method=self.method.upper(), # Normalize method to uppercase + url=self.url, + # params=self.params, + # data=self.data, + # json=self.json_payload, # Use the attribute name defined in Pydantic + # headers=self.headers, + timeout=self.timeout, + ) + response.raise_for_status() # Raises HTTPError for bad responses (4xx or 5xx) + logger.debug( + f"Request successful. Status: {response.status_code}, Response: {response.text[:100]}..." + ) + return response.text # Return text content directly + except requests.exceptions.HTTPError as http_err: + logger.error( + f"HTTP error occurred: {http_err} - Status Code: {http_err.response.status_code}" + ) + return f"HTTP error making {self.method.upper()} request to {self.url}: {http_err.response.status_code} {http_err.response.reason}" + except requests.exceptions.RequestException as req_err: + logger.error(f"Request exception occurred: {req_err}") + return f"Error making {self.method.upper()} request to {self.url}: {req_err}" + except Exception as e: + logger.exception(f"An unexpected error occurred: {e}") # Log the full traceback + return f"An unexpected error occurred: {type(e).__name__}: {e}" diff --git a/dimos/skills/skills.py b/dimos/skills/skills.py new file mode 100644 index 0000000000..94f8b3726f --- /dev/null +++ b/dimos/skills/skills.py @@ -0,0 +1,343 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any + +from openai import pydantic_function_tool +from pydantic import BaseModel + +from dimos.types.constants import Colors + +if TYPE_CHECKING: + from collections.abc import Iterator + +# Configure logging for the module +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +# region SkillLibrary + + +class SkillLibrary: + # ==== Flat Skill Library ==== + + def __init__(self) -> None: + self.registered_skills: list[AbstractSkill] = [] + self.class_skills: list[AbstractSkill] = [] + self._running_skills = {} # type: ignore[var-annotated] # {skill_name: (instance, subscription)} + + self.init() + + def init(self) -> None: + # Collect all skills from the parent class and update self.skills + self.refresh_class_skills() + + # Temporary + self.registered_skills = self.class_skills.copy() + + def get_class_skills(self) -> list[AbstractSkill]: + """Extract all AbstractSkill subclasses from a class. + + Returns: + List of skill classes found within the class + """ + skills = [] + + # Loop through all attributes of the class + for attr_name in dir(self.__class__): + # Skip special/dunder attributes + if attr_name.startswith("__"): + continue + + try: + attr = getattr(self.__class__, attr_name) + + # Check if it's a class and inherits from AbstractSkill + if ( + isinstance(attr, type) + and issubclass(attr, AbstractSkill) + and attr is not AbstractSkill + ): + skills.append(attr) + except (AttributeError, TypeError): + # Skip attributes that can't be accessed or aren't classes + continue + + return skills # type: ignore[return-value] + + def refresh_class_skills(self) -> None: + self.class_skills = self.get_class_skills() + + def add(self, skill: AbstractSkill) -> None: + if skill not in self.registered_skills: + self.registered_skills.append(skill) + + def get(self) -> list[AbstractSkill]: + return self.registered_skills.copy() + + def remove(self, skill: AbstractSkill) -> None: + try: + self.registered_skills.remove(skill) + except ValueError: + logger.warning(f"Attempted to remove non-existent skill: {skill}") + + def clear(self) -> None: + self.registered_skills.clear() + + def __iter__(self) -> Iterator: # type: ignore[type-arg] + return iter(self.registered_skills) + + def __len__(self) -> int: + return len(self.registered_skills) + + def __contains__(self, skill: AbstractSkill) -> bool: + return skill in self.registered_skills + + def __getitem__(self, index): # type: ignore[no-untyped-def] + return self.registered_skills[index] + + # ==== Calling a Function ==== + + _instances: dict[str, dict] = {} # type: ignore[type-arg] + + def create_instance(self, name: str, **kwargs) -> None: # type: ignore[no-untyped-def] + # Key based only on the name + key = name + + if key not in self._instances: + # Instead of creating an instance, store the args for later use + self._instances[key] = kwargs + + def call(self, name: str, **args): # type: ignore[no-untyped-def] + try: + # Get the stored args if available; otherwise, use an empty dict + stored_args = self._instances.get(name, {}) + + # Merge the arguments with priority given to stored arguments + complete_args = {**args, **stored_args} + + # Dynamically get the class from the module or current script + skill_class = getattr(self, name, None) + if skill_class is None: + for skill in self.get(): + if name == skill.__name__: # type: ignore[attr-defined] + skill_class = skill + break + if skill_class is None: + error_msg = f"Skill '{name}' is not available. Please check if it's properly registered." + logger.error(f"Skill class not found: {name}") + return error_msg + + # Initialize the instance with the merged arguments + instance = skill_class(**complete_args) # type: ignore[operator] + print(f"Instance created and function called for: {name} with args: {complete_args}") + + # Call the instance directly + return instance() + except Exception as e: + error_msg = f"Error executing skill '{name}': {e!s}" + logger.error(error_msg) + return error_msg + + # ==== Tools ==== + + def get_tools(self) -> Any: + tools_json = self.get_list_of_skills_as_json(list_of_skills=self.registered_skills) + # print(f"{Colors.YELLOW_PRINT_COLOR}Tools JSON: {tools_json}{Colors.RESET_COLOR}") + return tools_json + + def get_list_of_skills_as_json(self, list_of_skills: list[AbstractSkill]) -> list[str]: + return list(map(pydantic_function_tool, list_of_skills)) # type: ignore[arg-type] + + def register_running_skill(self, name: str, instance: Any, subscription=None) -> None: # type: ignore[no-untyped-def] + """ + Register a running skill with its subscription. + + Args: + name: Name of the skill (will be converted to lowercase) + instance: Instance of the running skill + subscription: Optional subscription associated with the skill + """ + name = name.lower() + self._running_skills[name] = (instance, subscription) + logger.info(f"Registered running skill: {name}") + + def unregister_running_skill(self, name: str) -> bool: + """ + Unregister a running skill. + + Args: + name: Name of the skill to remove (will be converted to lowercase) + + Returns: + True if the skill was found and removed, False otherwise + """ + name = name.lower() + if name in self._running_skills: + del self._running_skills[name] + logger.info(f"Unregistered running skill: {name}") + return True + return False + + def get_running_skills(self): # type: ignore[no-untyped-def] + """ + Get all running skills. + + Returns: + A dictionary of running skill names and their (instance, subscription) tuples + """ + return self._running_skills.copy() + + def terminate_skill(self, name: str): # type: ignore[no-untyped-def] + """ + Terminate a running skill. + + Args: + name: Name of the skill to terminate (will be converted to lowercase) + + Returns: + A message indicating whether the skill was successfully terminated + """ + name = name.lower() + if name in self._running_skills: + instance, subscription = self._running_skills[name] + + try: + # Call the stop method if it exists + if hasattr(instance, "stop") and callable(instance.stop): + instance.stop() + logger.info(f"Stopped skill: {name}") + else: + logger.warning(f"Skill {name} does not have a stop method") + + # Also dispose the subscription if it exists + if ( + subscription is not None + and hasattr(subscription, "dispose") + and callable(subscription.dispose) + ): + subscription.dispose() + logger.info(f"Disposed subscription for skill: {name}") + elif subscription is not None: + logger.warning(f"Skill {name} has a subscription but it's not disposable") + + # unregister the skill + self.unregister_running_skill(name) + return f"Successfully terminated skill: {name}" + + except Exception as e: + error_msg = f"Error terminating skill {name}: {e}" + logger.error(error_msg) + # Even on error, try to unregister the skill + self.unregister_running_skill(name) + return error_msg + else: + return f"No running skill found with name: {name}" + + +# endregion SkillLibrary + +# region AbstractSkill + + +class AbstractSkill(BaseModel): + def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + print("Initializing AbstractSkill Class") + super().__init__(*args, **kwargs) + self._instances = {} # type: ignore[var-annotated] + self._list_of_skills = [] # type: ignore[var-annotated] # Initialize the list of skills + print(f"Instances: {self._instances}") + + def clone(self) -> AbstractSkill: + return AbstractSkill() + + def register_as_running( # type: ignore[no-untyped-def] + self, name: str, skill_library: SkillLibrary, subscription=None + ) -> None: + """ + Register this skill as running in the skill library. + + Args: + name: Name of the skill (will be converted to lowercase) + skill_library: The skill library to register with + subscription: Optional subscription associated with the skill + """ + skill_library.register_running_skill(name, self, subscription) + + def unregister_as_running(self, name: str, skill_library: SkillLibrary) -> None: + """ + Unregister this skill from the skill library. + + Args: + name: Name of the skill to remove (will be converted to lowercase) + skill_library: The skill library to unregister from + """ + skill_library.unregister_running_skill(name) + + # ==== Tools ==== + def get_tools(self) -> Any: + tools_json = self.get_list_of_skills_as_json(list_of_skills=self._list_of_skills) + # print(f"Tools JSON: {tools_json}") + return tools_json + + def get_list_of_skills_as_json(self, list_of_skills: list[AbstractSkill]) -> list[str]: + return list(map(pydantic_function_tool, list_of_skills)) # type: ignore[arg-type] + + +# endregion AbstractSkill + +# region Abstract Robot Skill + +if TYPE_CHECKING: + from dimos.robot.robot import Robot +else: + Robot = "Robot" + + +class AbstractRobotSkill(AbstractSkill): + _robot: Robot = None # type: ignore[assignment] + + def __init__(self, *args, robot: Robot | None = None, **kwargs) -> None: # type: ignore[no-untyped-def] + super().__init__(*args, **kwargs) + self._robot = robot # type: ignore[assignment] + print( + f"{Colors.BLUE_PRINT_COLOR}Robot Skill Initialized with Robot: {robot}{Colors.RESET_COLOR}" + ) + + def set_robot(self, robot: Robot) -> None: + """Set the robot reference for this skills instance. + + Args: + robot: The robot instance to associate with these skills. + """ + self._robot = robot + + def __call__(self): # type: ignore[no-untyped-def] + if self._robot is None: + raise RuntimeError( + f"{Colors.RED_PRINT_COLOR}" + f"No Robot instance provided to Robot Skill: {self.__class__.__name__}" + f"{Colors.RESET_COLOR}" + ) + else: + print( + f"{Colors.BLUE_PRINT_COLOR}Robot Instance provided to Robot Skill: {self.__class__.__name__}{Colors.RESET_COLOR}" + ) + + +# endregion Abstract Robot Skill diff --git a/dimos/skills/speak.py b/dimos/skills/speak.py new file mode 100644 index 0000000000..fc26fd2cd0 --- /dev/null +++ b/dimos/skills/speak.py @@ -0,0 +1,168 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import queue +import threading +import time +from typing import Any + +from pydantic import Field +from reactivex import Subject + +from dimos.skills.skills import AbstractSkill +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + +# Global lock to prevent multiple simultaneous audio playbacks +_audio_device_lock = threading.RLock() + +# Global queue for sequential audio processing +_audio_queue = queue.Queue() # type: ignore[var-annotated] +_queue_processor_thread = None +_queue_running = False + + +def _process_audio_queue() -> None: + """Background thread to process audio requests sequentially""" + global _queue_running + + while _queue_running: + try: + # Get the next queued audio task with a timeout + task = _audio_queue.get(timeout=1.0) + if task is None: # Sentinel value to stop the thread + break + + # Execute the task (which is a function to be called) + task() + _audio_queue.task_done() + + except queue.Empty: + # No tasks in queue, just continue waiting + continue + except Exception as e: + logger.error(f"Error in audio queue processor: {e}") + # Continue processing other tasks + + +def start_audio_queue_processor() -> None: + """Start the background thread for processing audio requests""" + global _queue_processor_thread, _queue_running + + if _queue_processor_thread is None or not _queue_processor_thread.is_alive(): + _queue_running = True + _queue_processor_thread = threading.Thread( + target=_process_audio_queue, daemon=True, name="AudioQueueProcessor" + ) + _queue_processor_thread.start() + logger.info("Started audio queue processor thread") + + +# Start the queue processor when module is imported +start_audio_queue_processor() + + +class Speak(AbstractSkill): + """Speak text out loud to humans nearby or to other robots.""" + + text: str = Field(..., description="Text to speak") + + def __init__(self, tts_node: Any | None = None, **data) -> None: # type: ignore[no-untyped-def] + super().__init__(**data) + self._tts_node = tts_node + self._audio_complete = threading.Event() + self._subscription = None + self._subscriptions: list = [] # type: ignore[type-arg] # Track all subscriptions + + def __call__(self): # type: ignore[no-untyped-def] + if not self._tts_node: + logger.error("No TTS node provided to Speak skill") + return "Error: No TTS node available" + + # Create a result queue to get the result back from the audio thread + result_queue = queue.Queue(1) # type: ignore[var-annotated] + + # Define the speech task to run in the audio queue + def speak_task() -> None: + try: + # Using a lock to ensure exclusive access to audio device + with _audio_device_lock: + text_subject = Subject() # type: ignore[var-annotated] + self._audio_complete.clear() + self._subscriptions = [] + + # This function will be called when audio processing is complete + def on_complete() -> None: + logger.info(f"TTS audio playback completed for: {self.text}") + self._audio_complete.set() + + # This function will be called if there's an error + def on_error(error) -> None: # type: ignore[no-untyped-def] + logger.error(f"Error in TTS processing: {error}") + self._audio_complete.set() + + # Connect the Subject to the TTS node and keep the subscription + self._tts_node.consume_text(text_subject) # type: ignore[union-attr] + + # Subscribe to the audio output to know when it's done + self._subscription = self._tts_node.emit_text().subscribe( # type: ignore[union-attr] + on_next=lambda text: logger.debug(f"TTS processing: {text}"), + on_completed=on_complete, + on_error=on_error, + ) + self._subscriptions.append(self._subscription) + + # Emit the text to the Subject + text_subject.on_next(self.text) + text_subject.on_completed() # Signal that we're done sending text + + # Wait for audio playback to complete with a timeout + # Using a dynamic timeout based on text length + timeout = max(5, len(self.text) * 0.1) + logger.debug(f"Waiting for TTS completion with timeout {timeout:.1f}s") + + if not self._audio_complete.wait(timeout=timeout): + logger.warning(f"TTS timeout reached for: {self.text}") + else: + # Add a small delay after audio completes to ensure buffers are fully flushed + time.sleep(0.3) + + # Clean up all subscriptions + for sub in self._subscriptions: + if sub: + sub.dispose() + self._subscriptions = [] + + # Successfully completed + result_queue.put(f"Spoke: {self.text} successfully") + except Exception as e: + logger.error(f"Error in speak task: {e}") + result_queue.put(f"Error speaking text: {e!s}") + + # Add our speech task to the global queue for sequential processing + display_text = self.text[:50] + "..." if len(self.text) > 50 else self.text + logger.info(f"Queueing speech task: '{display_text}'") + _audio_queue.put(speak_task) + + # Wait for the result with a timeout + try: + # Use a longer timeout than the audio playback itself + text_len_timeout = len(self.text) * 0.15 # 150ms per character + max_timeout = max(10, text_len_timeout) # At least 10 seconds + + return result_queue.get(timeout=max_timeout) + except queue.Empty: + logger.error("Timed out waiting for speech task to complete") + return f"Error: Timed out while speaking: {self.text}" diff --git a/dimos/skills/unitree/__init__.py b/dimos/skills/unitree/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dimos/skills/unitree/unitree_speak.py b/dimos/skills/unitree/unitree_speak.py new file mode 100644 index 0000000000..9d6d973b64 --- /dev/null +++ b/dimos/skills/unitree/unitree_speak.py @@ -0,0 +1,280 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +import hashlib +import json +import os +import tempfile +import time + +from go2_webrtc_driver.constants import RTC_TOPIC # type: ignore[import-untyped] +import numpy as np +from openai import OpenAI +from pydantic import Field +import soundfile as sf # type: ignore[import-untyped] + +from dimos.skills.skills import AbstractRobotSkill +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + +# Audio API constants (from go2_webrtc_driver) +AUDIO_API = { + "GET_AUDIO_LIST": 1001, + "SELECT_START_PLAY": 1002, + "PAUSE": 1003, + "UNSUSPEND": 1004, + "SET_PLAY_MODE": 1007, + "UPLOAD_AUDIO_FILE": 2001, + "ENTER_MEGAPHONE": 4001, + "EXIT_MEGAPHONE": 4002, + "UPLOAD_MEGAPHONE": 4003, +} + +PLAY_MODES = {"NO_CYCLE": "no_cycle", "SINGLE_CYCLE": "single_cycle", "LIST_LOOP": "list_loop"} + + +class UnitreeSpeak(AbstractRobotSkill): + """Speak text out loud through the robot's speakers using WebRTC audio upload.""" + + text: str = Field(..., description="Text to speak") + voice: str = Field( + default="echo", description="Voice to use (alloy, echo, fable, onyx, nova, shimmer)" + ) + speed: float = Field(default=1.2, description="Speech speed (0.25 to 4.0)") + use_megaphone: bool = Field( + default=False, description="Use megaphone mode for lower latency (experimental)" + ) + + def __init__(self, **data) -> None: # type: ignore[no-untyped-def] + super().__init__(**data) + self._openai_client = None + + def _get_openai_client(self): # type: ignore[no-untyped-def] + if self._openai_client is None: + self._openai_client = OpenAI() # type: ignore[assignment] + return self._openai_client + + def _generate_audio(self, text: str) -> bytes: + try: + client = self._get_openai_client() # type: ignore[no-untyped-call] + response = client.audio.speech.create( + model="tts-1", voice=self.voice, input=text, speed=self.speed, response_format="mp3" + ) + return response.content # type: ignore[no-any-return] + except Exception as e: + logger.error(f"Error generating audio: {e}") + raise + + def _webrtc_request(self, api_id: int, parameter: dict | None = None): # type: ignore[no-untyped-def, type-arg] + if parameter is None: + parameter = {} + + request_data = {"api_id": api_id, "parameter": json.dumps(parameter) if parameter else "{}"} + + return self._robot.connection.publish_request(RTC_TOPIC["AUDIO_HUB_REQ"], request_data) # type: ignore[attr-defined] + + def _upload_audio_to_robot(self, audio_data: bytes, filename: str) -> str: + try: + file_md5 = hashlib.md5(audio_data).hexdigest() + b64_data = base64.b64encode(audio_data).decode("utf-8") + + chunk_size = 61440 + chunks = [b64_data[i : i + chunk_size] for i in range(0, len(b64_data), chunk_size)] + total_chunks = len(chunks) + + logger.info(f"Uploading audio '{filename}' in {total_chunks} chunks (optimized)") + + for i, chunk in enumerate(chunks, 1): + parameter = { + "file_name": filename, + "file_type": "wav", + "file_size": len(audio_data), + "current_block_index": i, + "total_block_number": total_chunks, + "block_content": chunk, + "current_block_size": len(chunk), + "file_md5": file_md5, + "create_time": int(time.time() * 1000), + } + + logger.debug(f"Sending chunk {i}/{total_chunks}") + self._webrtc_request(AUDIO_API["UPLOAD_AUDIO_FILE"], parameter) + + logger.info(f"Audio upload completed for '{filename}'") + + list_response = self._webrtc_request(AUDIO_API["GET_AUDIO_LIST"], {}) + + if list_response and "data" in list_response: + data_str = list_response.get("data", {}).get("data", "{}") + audio_list = json.loads(data_str).get("audio_list", []) + + for audio in audio_list: + if audio.get("CUSTOM_NAME") == filename: + return audio.get("UNIQUE_ID") # type: ignore[no-any-return] + + logger.warning( + f"Could not find uploaded audio '{filename}' in list, using filename as UUID" + ) + return filename + + except Exception as e: + logger.error(f"Error uploading audio to robot: {e}") + raise + + def _play_audio_on_robot(self, uuid: str): # type: ignore[no-untyped-def] + try: + self._webrtc_request(AUDIO_API["SET_PLAY_MODE"], {"play_mode": PLAY_MODES["NO_CYCLE"]}) + time.sleep(0.1) + + parameter = {"unique_id": uuid} + + logger.info(f"Playing audio with UUID: {uuid}") + self._webrtc_request(AUDIO_API["SELECT_START_PLAY"], parameter) + + except Exception as e: + logger.error(f"Error playing audio on robot: {e}") + raise + + def _stop_audio_playback(self) -> None: + try: + logger.debug("Stopping audio playback") + self._webrtc_request(AUDIO_API["PAUSE"], {}) + except Exception as e: + logger.warning(f"Error stopping audio playback: {e}") + + def _upload_and_play_megaphone(self, audio_data: bytes, duration: float): # type: ignore[no-untyped-def] + try: + logger.debug("Entering megaphone mode") + self._webrtc_request(AUDIO_API["ENTER_MEGAPHONE"], {}) + + time.sleep(0.2) + + b64_data = base64.b64encode(audio_data).decode("utf-8") + + chunk_size = 4096 + chunks = [b64_data[i : i + chunk_size] for i in range(0, len(b64_data), chunk_size)] + total_chunks = len(chunks) + + logger.info(f"Uploading megaphone audio in {total_chunks} chunks") + + for i, chunk in enumerate(chunks, 1): + parameter = { + "current_block_size": len(chunk), + "block_content": chunk, + "current_block_index": i, + "total_block_number": total_chunks, + } + + logger.debug(f"Sending megaphone chunk {i}/{total_chunks}") + self._webrtc_request(AUDIO_API["UPLOAD_MEGAPHONE"], parameter) + + if i < total_chunks: + time.sleep(0.05) + + logger.info("Megaphone audio upload completed, waiting for playback") + + time.sleep(duration + 1.0) + + except Exception as e: + logger.error(f"Error in megaphone mode: {e}") + try: + self._webrtc_request(AUDIO_API["EXIT_MEGAPHONE"], {}) + except: + pass + raise + finally: + try: + logger.debug("Exiting megaphone mode") + self._webrtc_request(AUDIO_API["EXIT_MEGAPHONE"], {}) + time.sleep(0.1) + except Exception as e: + logger.warning(f"Error exiting megaphone mode: {e}") + + def __call__(self) -> str: + super().__call__() # type: ignore[no-untyped-call] + + if not self._robot: + logger.error("No robot instance provided to UnitreeSpeak skill") + return "Error: No robot instance available" + + try: + display_text = self.text[:50] + "..." if len(self.text) > 50 else self.text + logger.info(f"Speaking: '{display_text}'") + + logger.debug("Generating audio with OpenAI TTS") + audio_data = self._generate_audio(self.text) + + with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp_mp3: + tmp_mp3.write(audio_data) + tmp_mp3_path = tmp_mp3.name + + try: + audio_array, sample_rate = sf.read(tmp_mp3_path) + + if audio_array.ndim > 1: + audio_array = np.mean(audio_array, axis=1) + + target_sample_rate = 22050 + if sample_rate != target_sample_rate: + logger.debug(f"Resampling from {sample_rate}Hz to {target_sample_rate}Hz") + old_length = len(audio_array) + new_length = int(old_length * target_sample_rate / sample_rate) + old_indices = np.arange(old_length) + new_indices = np.linspace(0, old_length - 1, new_length) + audio_array = np.interp(new_indices, old_indices, audio_array) + sample_rate = target_sample_rate + + audio_array = audio_array / np.max(np.abs(audio_array)) + + with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_wav: + sf.write(tmp_wav.name, audio_array, sample_rate, format="WAV", subtype="PCM_16") + tmp_wav.seek(0) + wav_data = open(tmp_wav.name, "rb").read() + os.unlink(tmp_wav.name) + + logger.info( + f"Audio size: {len(wav_data) / 1024:.1f}KB, duration: {len(audio_array) / sample_rate:.1f}s" + ) + + finally: + os.unlink(tmp_mp3_path) + + if self.use_megaphone: + logger.debug("Using megaphone mode for lower latency") + duration = len(audio_array) / sample_rate + self._upload_and_play_megaphone(wav_data, duration) + + return f"Spoke: '{display_text}' on robot successfully (megaphone mode)" + else: + filename = f"speak_{int(time.time() * 1000)}" + + logger.debug("Uploading audio to robot") + uuid = self._upload_audio_to_robot(wav_data, filename) + + logger.debug("Playing audio on robot") + self._play_audio_on_robot(uuid) + + duration = len(audio_array) / sample_rate + logger.debug(f"Waiting {duration:.1f}s for playback to complete") + # time.sleep(duration + 0.2) + + # self._stop_audio_playback() + + return f"Spoke: '{display_text}' on robot successfully" + + except Exception as e: + logger.error(f"Error in speak skill: {e}") + return f"Error speaking text: {e!s}" diff --git a/dimos/skills/visual_navigation_skills.py b/dimos/skills/visual_navigation_skills.py new file mode 100644 index 0000000000..acd658ee83 --- /dev/null +++ b/dimos/skills/visual_navigation_skills.py @@ -0,0 +1,148 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Visual navigation skills for robot interaction. + +This module provides skills for visual navigation, including following humans +and navigating to specific objects using computer vision. +""" + +import logging +import threading +import time + +from pydantic import Field + +from dimos.perception.visual_servoing import VisualServoing # type: ignore[import-untyped] +from dimos.skills.skills import AbstractRobotSkill +from dimos.types.vector import Vector +from dimos.utils.logging_config import setup_logger + +logger = setup_logger(level=logging.DEBUG) + + +class FollowHuman(AbstractRobotSkill): + """ + A skill that makes the robot follow a human using visual servoing continuously. + + This skill uses the robot's person tracking stream to follow a human + while maintaining a specified distance. It will keep following the human + until the timeout is reached or the skill is stopped. Don't use this skill + if you want to navigate to a specific person, use NavigateTo instead. + """ + + distance: float = Field( + 1.5, description="Desired distance to maintain from the person in meters" + ) + timeout: float = Field(20.0, description="Maximum time to follow the person in seconds") + point: tuple[int, int] | None = Field( + None, description="Optional point to start tracking (x,y pixel coordinates)" + ) + + def __init__(self, robot=None, **data) -> None: # type: ignore[no-untyped-def] + super().__init__(robot=robot, **data) + self._stop_event = threading.Event() + self._visual_servoing = None + + def __call__(self): # type: ignore[no-untyped-def] + """ + Start following a human using visual servoing. + + Returns: + bool: True if successful, False otherwise + """ + super().__call__() # type: ignore[no-untyped-call] + + if ( + not hasattr(self._robot, "person_tracking_stream") + or self._robot.person_tracking_stream is None + ): + logger.error("Robot does not have a person tracking stream") + return False + + # Stop any existing operation + self.stop() + self._stop_event.clear() + + success = False + + try: + # Initialize visual servoing + self._visual_servoing = VisualServoing( + tracking_stream=self._robot.person_tracking_stream + ) + + logger.warning(f"Following human for {self.timeout} seconds...") + start_time = time.time() + + # Start tracking + track_success = self._visual_servoing.start_tracking( # type: ignore[attr-defined] + point=self.point, desired_distance=self.distance + ) + + if not track_success: + logger.error("Failed to start tracking") + return False + + # Main follow loop + while ( + self._visual_servoing.running # type: ignore[attr-defined] + and time.time() - start_time < self.timeout + and not self._stop_event.is_set() + ): + output = self._visual_servoing.updateTracking() # type: ignore[attr-defined] + x_vel = output.get("linear_vel") + z_vel = output.get("angular_vel") + logger.debug(f"Following human: x_vel: {x_vel}, z_vel: {z_vel}") + self._robot.move(Vector(x_vel, 0, z_vel)) # type: ignore[arg-type, attr-defined] + time.sleep(0.05) + + # If we completed the full timeout duration, consider it success + if time.time() - start_time >= self.timeout: + success = True + logger.info("Human following completed successfully") + elif self._stop_event.is_set(): + logger.info("Human following stopped externally") + else: + logger.info("Human following stopped due to tracking loss") + + return success + + except Exception as e: + logger.error(f"Error in follow human: {e}") + return False + finally: + # Clean up + if self._visual_servoing: + self._visual_servoing.stop_tracking() + self._visual_servoing = None + + def stop(self) -> bool: + """ + Stop the human following process. + + Returns: + bool: True if stopped, False if it wasn't running + """ + if self._visual_servoing is not None: + logger.info("Stopping FollowHuman skill") + self._stop_event.set() + + # Clean up visual servoing if it exists + self._visual_servoing.stop_tracking() + self._visual_servoing = None + + return True + return False diff --git a/dimos/spec/__init__.py b/dimos/spec/__init__.py new file mode 100644 index 0000000000..03c1024d12 --- /dev/null +++ b/dimos/spec/__init__.py @@ -0,0 +1,15 @@ +from dimos.spec.control import LocalPlanner +from dimos.spec.map import Global3DMap, GlobalCostmap, GlobalMap +from dimos.spec.nav import Nav +from dimos.spec.perception import Camera, Image, Pointcloud + +__all__ = [ + "Camera", + "Global3DMap", + "GlobalCostmap", + "GlobalMap", + "Image", + "LocalPlanner", + "Nav", + "Pointcloud", +] diff --git a/dimos/spec/control.py b/dimos/spec/control.py new file mode 100644 index 0000000000..e2024c5a09 --- /dev/null +++ b/dimos/spec/control.py @@ -0,0 +1,22 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Protocol + +from dimos.core import Out +from dimos.msgs.geometry_msgs import Twist + + +class LocalPlanner(Protocol): + cmd_vel: Out[Twist] diff --git a/dimos/spec/map.py b/dimos/spec/map.py new file mode 100644 index 0000000000..217f6db619 --- /dev/null +++ b/dimos/spec/map.py @@ -0,0 +1,145 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Structural protocols decoupling map producers from consumers. + +Robotic applications require different map representations for different tasks: +3D geometry for manipulation, 2D occupancy for obstacle detection, and costmaps +for path planning. This module provides structural protocols that allow you to +swap mapping backends without changing downstream code. + +A SLAM system, map accumulator, or simulator can satisfy these protocols by declaring +the appropriate output stream. + +Choose +- `Global3DMap` for detailed 3D geometry (manipulation, dense reconstruction), +- `GlobalMap` for basic obstacle detection, +- and `GlobalCostmap` for path planning +""" + +from typing import Annotated, Protocol + +from annotated_doc import Doc + +from dimos.core import Out +from dimos.msgs.nav_msgs import OccupancyGrid +from dimos.msgs.sensor_msgs import PointCloud2 + + +class Global3DMap(Protocol): + """Protocol for modules providing a global 3D point cloud map. + + This protocol defines the interface for modules that accumulate and publish + a 3D point cloud representation of the environment. Unlike instantaneous + sensor scans, the global point cloud represents accumulated spatial knowledge + built over time (e.g., from SLAM or incremental mapping). + + The protocol enables decoupling between map producers and consumers, allowing + different mapping backends (ROS2 SLAM, custom mappers, simulation) to be + substituted without changing downstream code. + + Example: + Implementing the protocol:: + + class CustomMapper(Module): + global_pointcloud: Out[PointCloud2] = None # Satisfies Global3DMap + + def publish_map(self, points: np.ndarray): + pc = PointCloud2.from_numpy(points, frame_id="map", timestamp=time.time()) + self.global_pointcloud.publish(pc) + + Notes: + - Implementations in `dimos/navigation/rosnav.py` (ROSNav module) + - The protocol specifies the interface but not publish frequency, point + density, or how the map is accumulated + - Coordinate frame should be a fixed world/map frame, not a moving robot frame + """ + + global_pointcloud: Annotated[ + Out[PointCloud2], + Doc( + """Output stream publishing accumulated 3D point cloud data. The `PointCloud2` + messages contain Open3D point clouds with spatial coordinates in a fixed + world/map frame (typically `frame_id="map"`).""" + ), + ] + + +class GlobalMap(Protocol): + """Protocol for modules providing a global 2D occupancy grid map. + + This protocol defines the interface for modules that publish a 2D occupancy + grid representing obstacles and free space in the environment. The occupancy + grid uses the ROS `nav_msgs/OccupancyGrid` convention for cell values. + + Abstracts over different mapping implementations (SLAM, static maps, simulation). + + Example: + Implementing the protocol:: + + class MapProducer(Module): + global_map: Out[OccupancyGrid] = None # Satisfies GlobalMap + + def publish_occupancy(self, pointcloud: LidarMessage): + grid = OccupancyGrid.from_pointcloud( + pointcloud, resolution=0.05, min_height=0.0, max_height=2.0 + ) + self.global_map.publish(grid) + + Notes: + - For 2D occupancy grids used in navigation, `GlobalCostmap` is more commonly + consumed as it supports cost gradients for path planning + """ + + global_map: Annotated[ + Out[OccupancyGrid], + Doc( + """Output stream publishing 2D occupancy grids. Cell values follow ROS + conventions: `-1` (unknown/unexplored), `0` (free space), `100` (occupied), + and `1-99` (intermediate occupancy/cost values, implementation-dependent).""" + ), + ] + + +class GlobalCostmap(Protocol): + """Protocol for modules providing a global 2D costmap for navigation. + + This protocol defines the interface for modules that publish a 2D costmap + representation designed specifically for path planning. Unlike `GlobalMap` + (which represents raw occupancy), costmaps typically include safety margins + around obstacles and cost gradients to encourage paths that maintain clearance. + + Example: + Implementing the protocol:: + + class CostmapBuilder(Module): + global_costmap: Out[OccupancyGrid] = None # Satisfies GlobalCostmap + + def publish_costmap(self, occupancy: OccupancyGrid): + # Apply inflation and gradient for navigation + costmap = occupancy.inflate(radius=0.2).gradient(max_distance=1.5) + self.global_costmap.publish(costmap) + + Notes: + - Example of an implementation: `dimos/robot/unitree_webrtc/type/map.py` + """ + + global_costmap: Annotated[ + Out[OccupancyGrid], + Doc( + """Output stream publishing 2D costmaps. Cell values follow ROS conventions: + `-1` (unknown/unexplored), `0` (free space with no cost), `1-99` (increasing + traversal cost from gradient/inflation), and `100` (lethal obstacle, impassable).""" + ), + ] diff --git a/dimos/spec/nav.py b/dimos/spec/nav.py new file mode 100644 index 0000000000..d1f62c0846 --- /dev/null +++ b/dimos/spec/nav.py @@ -0,0 +1,31 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Protocol + +from dimos.core import In, Out +from dimos.msgs.geometry_msgs import PoseStamped, Twist +from dimos.msgs.nav_msgs import Path + + +class Nav(Protocol): + goal_req: In[PoseStamped] + goal_active: Out[PoseStamped] + path_active: Out[Path] + ctrl: Out[Twist] + + # identity quaternion (Quaternion(0,0,0,1)) represents "no rotation requested" + def navigate_to_target(self, target: PoseStamped) -> None: ... + + def stop_navigating(self) -> None: ... diff --git a/dimos/spec/perception.py b/dimos/spec/perception.py new file mode 100644 index 0000000000..f2d43e1363 --- /dev/null +++ b/dimos/spec/perception.py @@ -0,0 +1,31 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Protocol + +from dimos.core import Out +from dimos.msgs.sensor_msgs import CameraInfo, Image as ImageMsg, PointCloud2 + + +class Image(Protocol): + color_image: Out[ImageMsg] + + +class Camera(Image): + camera_info: Out[CameraInfo] + _camera_info: CameraInfo + + +class Pointcloud(Protocol): + pointcloud: Out[PointCloud2] diff --git a/dimos/stream/audio/__init__.py b/dimos/stream/audio/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dimos/stream/audio/base.py b/dimos/stream/audio/base.py new file mode 100644 index 0000000000..54bd1705a3 --- /dev/null +++ b/dimos/stream/audio/base.py @@ -0,0 +1,121 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod + +import numpy as np +from reactivex import Observable + + +class AbstractAudioEmitter(ABC): + """Base class for components that emit audio.""" + + @abstractmethod + def emit_audio(self) -> Observable: # type: ignore[type-arg] + """Create an observable that emits audio frames. + + Returns: + Observable emitting audio frames + """ + pass + + +class AbstractAudioConsumer(ABC): + """Base class for components that consume audio.""" + + @abstractmethod + def consume_audio(self, audio_observable: Observable) -> "AbstractAudioConsumer": # type: ignore[type-arg] + """Set the audio observable to consume. + + Args: + audio_observable: Observable emitting audio frames + + Returns: + Self for method chaining + """ + pass + + +class AbstractAudioTransform(AbstractAudioConsumer, AbstractAudioEmitter): + """Base class for components that both consume and emit audio. + + This represents a transform in an audio processing pipeline. + """ + + pass + + +class AudioEvent: + """Class to represent an audio frame event with metadata.""" + + def __init__( + self, + data: np.ndarray, # type: ignore[type-arg] + sample_rate: int, + timestamp: float, + channels: int = 1, + ) -> None: + """ + Initialize an AudioEvent. + + Args: + data: Audio data as numpy array + sample_rate: Audio sample rate in Hz + timestamp: Unix timestamp when the audio was captured + channels: Number of audio channels + """ + self.data = data + self.sample_rate = sample_rate + self.timestamp = timestamp + self.channels = channels + self.dtype = data.dtype + self.shape = data.shape + + def to_float32(self) -> "AudioEvent": + """Convert audio data to float32 format normalized to [-1.0, 1.0].""" + if self.data.dtype == np.float32: + return self + + new_data = self.data.astype(np.float32) + if self.data.dtype == np.int16: + new_data /= 32768.0 + + return AudioEvent( + data=new_data, + sample_rate=self.sample_rate, + timestamp=self.timestamp, + channels=self.channels, + ) + + def to_int16(self) -> "AudioEvent": + """Convert audio data to int16 format.""" + if self.data.dtype == np.int16: + return self + + new_data = self.data + if self.data.dtype == np.float32: + new_data = (new_data * 32767).astype(np.int16) + + return AudioEvent( + data=new_data, + sample_rate=self.sample_rate, + timestamp=self.timestamp, + channels=self.channels, + ) + + def __repr__(self) -> str: + return ( + f"AudioEvent(shape={self.shape}, dtype={self.dtype}, " + f"sample_rate={self.sample_rate}, channels={self.channels})" + ) diff --git a/dimos/stream/audio/node_key_recorder.py b/dimos/stream/audio/node_key_recorder.py new file mode 100644 index 0000000000..a6489d0e5a --- /dev/null +++ b/dimos/stream/audio/node_key_recorder.py @@ -0,0 +1,335 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import select +import sys +import threading +import time + +import numpy as np +from reactivex import Observable +from reactivex.subject import ReplaySubject, Subject + +from dimos.stream.audio.base import AbstractAudioTransform, AudioEvent +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class KeyRecorder(AbstractAudioTransform): + """ + Audio recorder that captures audio events and combines them. + Press a key to toggle recording on/off. + """ + + def __init__( + self, + max_recording_time: float = 120.0, + always_subscribe: bool = False, + ) -> None: + """ + Initialize KeyRecorder. + + Args: + max_recording_time: Maximum recording time in seconds + always_subscribe: If True, subscribe to audio source continuously, + If False, only subscribe when recording (more efficient + but some audio devices may need time to initialize) + """ + self.max_recording_time = max_recording_time + self.always_subscribe = always_subscribe + + self._audio_buffer = [] # type: ignore[var-annotated] + self._is_recording = False + self._recording_start_time = 0 + self._sample_rate = None # Will be updated from incoming audio + self._channels = None # Will be set from first event + + self._audio_observable = None + self._subscription = None + self._output_subject = Subject() # type: ignore[var-annotated] # For record-time passthrough + self._recording_subject = ReplaySubject(1) # type: ignore[var-annotated] # For full completed recordings + + # Start a thread to monitor for input + self._running = True + self._input_thread = threading.Thread(target=self._input_monitor, daemon=True) + self._input_thread.start() + + logger.info("Started audio recorder (press any key to start/stop recording)") + + def consume_audio(self, audio_observable: Observable) -> "KeyRecorder": # type: ignore[type-arg] + """ + Set the audio observable to use when recording. + If always_subscribe is True, subscribes immediately. + Otherwise, subscribes only when recording starts. + + Args: + audio_observable: Observable emitting AudioEvent objects + + Returns: + Self for method chaining + """ + self._audio_observable = audio_observable # type: ignore[assignment] + + # If configured to always subscribe, do it now + if self.always_subscribe and not self._subscription: + self._subscription = audio_observable.subscribe( # type: ignore[assignment] + on_next=self._process_audio_event, + on_error=self._handle_error, + on_completed=self._handle_completion, + ) + logger.debug("Subscribed to audio source (always_subscribe=True)") + + return self + + def emit_audio(self) -> Observable: # type: ignore[type-arg] + """ + Create an observable that emits audio events in real-time (pass-through). + + Returns: + Observable emitting AudioEvent objects in real-time + """ + return self._output_subject + + def emit_recording(self) -> Observable: # type: ignore[type-arg] + """ + Create an observable that emits combined audio recordings when recording stops. + + Returns: + Observable emitting AudioEvent objects with complete recordings + """ + return self._recording_subject + + def stop(self) -> None: + """Stop recording and clean up resources.""" + logger.info("Stopping audio recorder") + + # If recording is in progress, stop it first + if self._is_recording: + self._stop_recording() + + # Always clean up subscription on full stop + if self._subscription: + self._subscription.dispose() + self._subscription = None + + # Stop input monitoring thread + self._running = False + if self._input_thread.is_alive(): + self._input_thread.join(1.0) + + def _input_monitor(self) -> None: + """Monitor for key presses to toggle recording.""" + logger.info("Press Enter to start/stop recording...") + + while self._running: + # Check if there's input available + if select.select([sys.stdin], [], [], 0.1)[0]: + sys.stdin.readline() + + if self._is_recording: + self._stop_recording() + else: + self._start_recording() + + # Sleep a bit to reduce CPU usage + time.sleep(0.1) + + def _start_recording(self) -> None: + """Start recording audio and subscribe to the audio source if not always subscribed.""" + if not self._audio_observable: + logger.error("Cannot start recording: No audio source has been set") + return + + # Subscribe to the observable if not using always_subscribe + if not self._subscription: + self._subscription = self._audio_observable.subscribe( + on_next=self._process_audio_event, + on_error=self._handle_error, + on_completed=self._handle_completion, + ) + logger.debug("Subscribed to audio source for recording") + + self._is_recording = True + self._recording_start_time = time.time() + self._audio_buffer = [] + logger.info("Recording... (press Enter to stop)") + + def _stop_recording(self) -> None: + """Stop recording, unsubscribe from audio source if not always subscribed, and emit the combined audio event.""" + self._is_recording = False + recording_duration = time.time() - self._recording_start_time + + # Unsubscribe from the audio source if not using always_subscribe + if not self.always_subscribe and self._subscription: + self._subscription.dispose() + self._subscription = None + logger.debug("Unsubscribed from audio source after recording") + + logger.info(f"Recording stopped after {recording_duration:.2f} seconds") + + # Combine all audio events into one + if len(self._audio_buffer) > 0: + combined_audio = self._combine_audio_events(self._audio_buffer) + self._recording_subject.on_next(combined_audio) + else: + logger.warning("No audio was recorded") + + def _process_audio_event(self, audio_event) -> None: # type: ignore[no-untyped-def] + """Process incoming audio events.""" + + # Only buffer if recording + if not self._is_recording: + return + + # Pass through audio events in real-time + self._output_subject.on_next(audio_event) + + # First audio event - determine channel count/sample rate + if self._channels is None: + self._channels = audio_event.channels + self._sample_rate = audio_event.sample_rate + logger.info(f"Setting channel count to {self._channels}") + + # Add to buffer + self._audio_buffer.append(audio_event) + + # Check if we've exceeded max recording time + if time.time() - self._recording_start_time > self.max_recording_time: + logger.warning(f"Max recording time ({self.max_recording_time}s) reached") + self._stop_recording() + + def _combine_audio_events(self, audio_events: list[AudioEvent]) -> AudioEvent: + """Combine multiple audio events into a single event.""" + if not audio_events: + logger.warning("Attempted to combine empty audio events list") + return None # type: ignore[return-value] + + # Filter out any empty events that might cause broadcasting errors + valid_events = [ + event + for event in audio_events + if event is not None + and (hasattr(event, "data") and event.data is not None and event.data.size > 0) + ] + + if not valid_events: + logger.warning("No valid audio events to combine") + return None # type: ignore[return-value] + + first_event = valid_events[0] + channels = first_event.channels + dtype = first_event.data.dtype + + # Calculate total samples only from valid events + total_samples = sum(event.data.shape[0] for event in valid_events) + + # Safety check - if somehow we got no samples + if total_samples <= 0: + logger.warning(f"Combined audio would have {total_samples} samples - aborting") + return None # type: ignore[return-value] + + # For multichannel audio, data shape could be (samples,) or (samples, channels) + if len(first_event.data.shape) == 1: + # 1D audio data (mono) + combined_data = np.zeros(total_samples, dtype=dtype) + + # Copy data + offset = 0 + for event in valid_events: + samples = event.data.shape[0] + if samples > 0: # Extra safety check + combined_data[offset : offset + samples] = event.data + offset += samples + else: + # Multichannel audio data (stereo or more) + combined_data = np.zeros((total_samples, channels), dtype=dtype) + + # Copy data + offset = 0 + for event in valid_events: + samples = event.data.shape[0] + if samples > 0 and offset + samples <= total_samples: # Safety check + try: + combined_data[offset : offset + samples] = event.data + offset += samples + except ValueError as e: + logger.error( + f"Error combining audio events: {e}. " + f"Event shape: {event.data.shape}, " + f"Combined shape: {combined_data.shape}, " + f"Offset: {offset}, Samples: {samples}" + ) + # Continue with next event instead of failing completely + + # Create new audio event with the combined data + if combined_data.size > 0: + return AudioEvent( + data=combined_data, + sample_rate=self._sample_rate, # type: ignore[arg-type] + timestamp=valid_events[0].timestamp, + channels=channels, + ) + else: + logger.warning("Failed to create valid combined audio event") + return None # type: ignore[return-value] + + def _handle_error(self, error) -> None: # type: ignore[no-untyped-def] + """Handle errors from the observable.""" + logger.error(f"Error in audio observable: {error}") + + def _handle_completion(self) -> None: + """Handle completion of the observable.""" + logger.info("Audio observable completed") + self.stop() + + +if __name__ == "__main__": + from dimos.stream.audio.node_microphone import ( + SounddeviceAudioSource, + ) + from dimos.stream.audio.node_normalizer import AudioNormalizer + from dimos.stream.audio.node_output import SounddeviceAudioOutput + from dimos.stream.audio.node_volume_monitor import monitor + from dimos.stream.audio.utils import keepalive + + # Create microphone source, recorder, and audio output + mic = SounddeviceAudioSource() + + # my audio device needs time to init, so for smoother ux we constantly listen + recorder = KeyRecorder(always_subscribe=True) + + normalizer = AudioNormalizer() + speaker = SounddeviceAudioOutput() + + # Connect the components + normalizer.consume_audio(mic.emit_audio()) + recorder.consume_audio(normalizer.emit_audio()) + # recorder.consume_audio(mic.emit_audio()) + + # Monitor microphone input levels (real-time pass-through) + monitor(recorder.emit_audio()) + + # Connect the recorder output to the speakers to hear recordings when completed + playback_speaker = SounddeviceAudioOutput() + playback_speaker.consume_audio(recorder.emit_recording()) + + # TODO: we should be able to run normalizer post hoc on the recording as well, + # it's not working, this needs a review + # + # normalizer.consume_audio(recorder.emit_recording()) + # playback_speaker.consume_audio(normalizer.emit_audio()) + + keepalive() diff --git a/dimos/stream/audio/node_microphone.py b/dimos/stream/audio/node_microphone.py new file mode 100644 index 0000000000..5d6e28dc74 --- /dev/null +++ b/dimos/stream/audio/node_microphone.py @@ -0,0 +1,131 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from typing import Any + +import numpy as np +from reactivex import Observable, create, disposable +import sounddevice as sd # type: ignore[import-untyped] + +from dimos.stream.audio.base import ( + AbstractAudioEmitter, + AudioEvent, +) +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class SounddeviceAudioSource(AbstractAudioEmitter): + """Audio source implementation using the sounddevice library.""" + + def __init__( + self, + device_index: int | None = None, + sample_rate: int = 16000, + channels: int = 1, + block_size: int = 1024, + dtype: np.dtype = np.float32, # type: ignore[assignment, type-arg] + ) -> None: + """ + Initialize SounddeviceAudioSource. + + Args: + device_index: Audio device index (None for default) + sample_rate: Audio sample rate in Hz + channels: Number of audio channels (1=mono, 2=stereo) + block_size: Number of samples per audio frame + dtype: Data type for audio samples (np.float32 or np.int16) + """ + self.device_index = device_index + self.sample_rate = sample_rate + self.channels = channels + self.block_size = block_size + self.dtype = dtype + + self._stream = None + self._running = False + + def emit_audio(self) -> Observable: # type: ignore[type-arg] + """ + Create an observable that emits audio frames. + + Returns: + Observable emitting AudioEvent objects + """ + + def on_subscribe(observer, scheduler): # type: ignore[no-untyped-def] + # Callback function to process audio data + def audio_callback(indata, frames, time_info, status) -> None: # type: ignore[no-untyped-def] + if status: + logger.warning(f"Audio callback status: {status}") + + # Create audio event + audio_event = AudioEvent( + data=indata.copy(), + sample_rate=self.sample_rate, + timestamp=time.time(), + channels=self.channels, + ) + + observer.on_next(audio_event) + + # Start the audio stream + try: + self._stream = sd.InputStream( + device=self.device_index, + samplerate=self.sample_rate, + channels=self.channels, + blocksize=self.block_size, + dtype=self.dtype, + callback=audio_callback, + ) + self._stream.start() # type: ignore[attr-defined] + self._running = True + + logger.info( + f"Started audio capture: {self.sample_rate}Hz, " + f"{self.channels} channels, {self.block_size} samples per frame" + ) + + except Exception as e: + logger.error(f"Error starting audio stream: {e}") + observer.on_error(e) + + # Return a disposable to clean up resources + def dispose() -> None: + logger.info("Stopping audio capture") + self._running = False + if self._stream: + self._stream.stop() + self._stream.close() + self._stream = None + + return disposable.Disposable(dispose) + + return create(on_subscribe) + + def get_available_devices(self) -> list[dict[str, Any]]: + """Get a list of available audio input devices.""" + return sd.query_devices() # type: ignore[no-any-return] + + +if __name__ == "__main__": + from dimos.stream.audio.node_volume_monitor import monitor + from dimos.stream.audio.utils import keepalive + + monitor(SounddeviceAudioSource().emit_audio()) + keepalive() diff --git a/dimos/stream/audio/node_normalizer.py b/dimos/stream/audio/node_normalizer.py new file mode 100644 index 0000000000..60a25a0404 --- /dev/null +++ b/dimos/stream/audio/node_normalizer.py @@ -0,0 +1,220 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable + +import numpy as np +from reactivex import Observable, create, disposable + +from dimos.stream.audio.base import ( + AbstractAudioTransform, + AudioEvent, +) +from dimos.stream.audio.volume import ( + calculate_peak_volume, + calculate_rms_volume, +) +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class AudioNormalizer(AbstractAudioTransform): + """ + Audio normalizer that remembers max volume and rescales audio to normalize it. + + This class applies dynamic normalization to audio frames. It keeps track of + the max volume encountered and uses that to normalize the audio to a target level. + """ + + def __init__( + self, + target_level: float = 1.0, + min_volume_threshold: float = 0.01, + max_gain: float = 10.0, + decay_factor: float = 0.999, + adapt_speed: float = 0.05, + volume_func: Callable[[np.ndarray], float] = calculate_peak_volume, # type: ignore[type-arg] + ) -> None: + """ + Initialize AudioNormalizer. + + Args: + target_level: Target normalization level (0.0 to 1.0) + min_volume_threshold: Minimum volume to apply normalization + max_gain: Maximum allowed gain to prevent excessive amplification + decay_factor: Decay factor for max volume (0.0-1.0, higher = slower decay) + adapt_speed: How quickly to adapt to new volume levels (0.0-1.0) + volume_func: Function to calculate volume (default: peak volume) + """ + self.target_level = target_level + self.min_volume_threshold = min_volume_threshold + self.max_gain = max_gain + self.decay_factor = decay_factor + self.adapt_speed = adapt_speed + self.volume_func = volume_func + + # Internal state + self.max_volume = 0.0 + self.current_gain = 1.0 + self.audio_observable = None + + def _normalize_audio(self, audio_event: AudioEvent) -> AudioEvent: + """ + Normalize audio data based on tracked max volume. + + Args: + audio_event: Input audio event + + Returns: + Normalized audio event + """ + # Convert to float32 for processing if needed + if audio_event.data.dtype != np.float32: + audio_event = audio_event.to_float32() + + # Calculate current volume using provided function + current_volume = self.volume_func(audio_event.data) + + # Update max volume with decay + self.max_volume = max(current_volume, self.max_volume * self.decay_factor) + + # Calculate ideal gain + if self.max_volume > self.min_volume_threshold: + ideal_gain = self.target_level / self.max_volume + else: + ideal_gain = 1.0 # No normalization needed for very quiet audio + + # Limit gain to max_gain + ideal_gain = min(ideal_gain, self.max_gain) + + # Smoothly adapt current gain towards ideal gain + self.current_gain = ( + 1 - self.adapt_speed + ) * self.current_gain + self.adapt_speed * ideal_gain + + # Apply gain to audio data + normalized_data = audio_event.data * self.current_gain + + # Clip to prevent distortion (values should stay within -1.0 to 1.0) + normalized_data = np.clip(normalized_data, -1.0, 1.0) + + # Create new audio event with normalized data + return AudioEvent( + data=normalized_data, + sample_rate=audio_event.sample_rate, + timestamp=audio_event.timestamp, + channels=audio_event.channels, + ) + + def consume_audio(self, audio_observable: Observable) -> "AudioNormalizer": # type: ignore[type-arg] + """ + Set the audio source observable to consume. + + Args: + audio_observable: Observable emitting AudioEvent objects + + Returns: + Self for method chaining + """ + self.audio_observable = audio_observable # type: ignore[assignment] + return self + + def emit_audio(self) -> Observable: # type: ignore[type-arg] + """ + Create an observable that emits normalized audio frames. + + Returns: + Observable emitting normalized AudioEvent objects + """ + if self.audio_observable is None: + raise ValueError("No audio source provided. Call consume_audio() first.") + + def on_subscribe(observer, scheduler): + # Subscribe to the audio observable + audio_subscription = self.audio_observable.subscribe( + on_next=lambda event: observer.on_next(self._normalize_audio(event)), + on_error=lambda error: observer.on_error(error), + on_completed=lambda: observer.on_completed(), + ) + + logger.info( + f"Started audio normalizer with target level: {self.target_level}, max gain: {self.max_gain}" + ) + + # Return a disposable to clean up resources + def dispose() -> None: + logger.info("Stopping audio normalizer") + audio_subscription.dispose() + + return disposable.Disposable(dispose) + + return create(on_subscribe) + + +if __name__ == "__main__": + import sys + + from dimos.stream.audio.node_microphone import ( + SounddeviceAudioSource, + ) + from dimos.stream.audio.node_output import SounddeviceAudioOutput + from dimos.stream.audio.node_simulated import SimulatedAudioSource + from dimos.stream.audio.node_volume_monitor import monitor + from dimos.stream.audio.utils import keepalive + + # Parse command line arguments + volume_method = "peak" # Default to peak + use_mic = False # Default to microphone input + target_level = 1 # Default target level + + # Process arguments + for arg in sys.argv[1:]: + if arg == "rms": + volume_method = "rms" + elif arg == "peak": + volume_method = "peak" + elif arg == "mic": + use_mic = True + elif arg.startswith("level="): + try: + target_level = float(arg.split("=")[1]) # type: ignore[assignment] + except ValueError: + print(f"Invalid target level: {arg}") + sys.exit(1) + + # Create appropriate audio source + if use_mic: + audio_source = SounddeviceAudioSource() + print("Using microphone input") + else: + audio_source = SimulatedAudioSource(volume_oscillation=True) + print("Using simulated audio source") + + # Select volume function + volume_func = calculate_rms_volume if volume_method == "rms" else calculate_peak_volume + + # Create normalizer + normalizer = AudioNormalizer(target_level=target_level, volume_func=volume_func) + + # Connect the audio source to the normalizer + normalizer.consume_audio(audio_source.emit_audio()) + + print(f"Using {volume_method} volume method with target level {target_level}") + SounddeviceAudioOutput().consume_audio(normalizer.emit_audio()) + + # Monitor the normalized audio + monitor(normalizer.emit_audio()) + keepalive() diff --git a/dimos/stream/audio/node_output.py b/dimos/stream/audio/node_output.py new file mode 100644 index 0000000000..4b4d407329 --- /dev/null +++ b/dimos/stream/audio/node_output.py @@ -0,0 +1,188 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import numpy as np +from reactivex import Observable +import sounddevice as sd # type: ignore[import-untyped] + +from dimos.stream.audio.base import ( + AbstractAudioTransform, +) +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class SounddeviceAudioOutput(AbstractAudioTransform): + """ + Audio output implementation using the sounddevice library. + + This class implements AbstractAudioTransform so it can both play audio and + optionally pass audio events through to other components (for example, to + record audio while playing it, or to visualize the waveform while playing). + """ + + def __init__( + self, + device_index: int | None = None, + sample_rate: int = 16000, + channels: int = 1, + block_size: int = 1024, + dtype: np.dtype = np.float32, # type: ignore[assignment, type-arg] + ) -> None: + """ + Initialize SounddeviceAudioOutput. + + Args: + device_index: Audio device index (None for default) + sample_rate: Audio sample rate in Hz + channels: Number of audio channels (1=mono, 2=stereo) + block_size: Number of samples per audio frame + dtype: Data type for audio samples (np.float32 or np.int16) + """ + self.device_index = device_index + self.sample_rate = sample_rate + self.channels = channels + self.block_size = block_size + self.dtype = dtype + + self._stream = None + self._running = False + self._subscription = None + self.audio_observable = None + + def consume_audio(self, audio_observable: Observable) -> "SounddeviceAudioOutput": # type: ignore[type-arg] + """ + Subscribe to an audio observable and play the audio through the speakers. + + Args: + audio_observable: Observable emitting AudioEvent objects + + Returns: + Self for method chaining + """ + self.audio_observable = audio_observable # type: ignore[assignment] + + # Create and start the output stream + try: + self._stream = sd.OutputStream( + device=self.device_index, + samplerate=self.sample_rate, + channels=self.channels, + blocksize=self.block_size, + dtype=self.dtype, + ) + self._stream.start() # type: ignore[attr-defined] + self._running = True + + logger.info( + f"Started audio output: {self.sample_rate}Hz, " + f"{self.channels} channels, {self.block_size} samples per frame" + ) + + except Exception as e: + logger.error(f"Error starting audio output stream: {e}") + raise e + + # Subscribe to the observable + self._subscription = audio_observable.subscribe( # type: ignore[assignment] + on_next=self._play_audio_event, + on_error=self._handle_error, + on_completed=self._handle_completion, + ) + + return self + + def emit_audio(self) -> Observable: # type: ignore[type-arg] + """ + Pass through the audio observable to allow chaining with other components. + + Returns: + The same Observable that was provided to consume_audio + """ + if self.audio_observable is None: + raise ValueError("No audio source provided. Call consume_audio() first.") + + return self.audio_observable + + def stop(self) -> None: + """Stop audio output and clean up resources.""" + logger.info("Stopping audio output") + self._running = False + + if self._subscription: + self._subscription.dispose() + self._subscription = None + + if self._stream: + self._stream.stop() + self._stream.close() + self._stream = None + + def _play_audio_event(self, audio_event) -> None: # type: ignore[no-untyped-def] + """Play audio from an AudioEvent.""" + if not self._running or not self._stream: + return + + try: + # Ensure data type matches our stream + if audio_event.dtype != self.dtype: + if self.dtype == np.float32: + audio_event = audio_event.to_float32() + elif self.dtype == np.int16: + audio_event = audio_event.to_int16() + + # Write audio data to the stream + self._stream.write(audio_event.data) + except Exception as e: + logger.error(f"Error playing audio: {e}") + + def _handle_error(self, error) -> None: # type: ignore[no-untyped-def] + """Handle errors from the observable.""" + logger.error(f"Error in audio observable: {error}") + + def _handle_completion(self) -> None: + """Handle completion of the observable.""" + logger.info("Audio observable completed") + self._running = False + if self._stream: + self._stream.stop() + self._stream.close() + self._stream = None + + def get_available_devices(self) -> list[dict[str, Any]]: + """Get a list of available audio output devices.""" + return sd.query_devices() # type: ignore[no-any-return] + + +if __name__ == "__main__": + from dimos.stream.audio.node_microphone import ( + SounddeviceAudioSource, + ) + from dimos.stream.audio.node_normalizer import AudioNormalizer + from dimos.stream.audio.utils import keepalive + + # Create microphone source, normalizer and audio output + mic = SounddeviceAudioSource() + normalizer = AudioNormalizer() + speaker = SounddeviceAudioOutput() + + # Connect the components in a pipeline + normalizer.consume_audio(mic.emit_audio()) + speaker.consume_audio(normalizer.emit_audio()) + + keepalive() diff --git a/dimos/stream/audio/node_simulated.py b/dimos/stream/audio/node_simulated.py new file mode 100644 index 0000000000..1d4cf2d063 --- /dev/null +++ b/dimos/stream/audio/node_simulated.py @@ -0,0 +1,222 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import threading +import time + +import numpy as np +from reactivex import Observable, create, disposable + +from dimos.stream.audio.abstract import ( # type: ignore[import-untyped] + AbstractAudioEmitter, + AudioEvent, +) +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class SimulatedAudioSource(AbstractAudioEmitter): # type: ignore[misc] + """Audio source that generates simulated audio for testing.""" + + def __init__( + self, + sample_rate: int = 16000, + frame_length: int = 1024, + channels: int = 1, + dtype: np.dtype = np.float32, # type: ignore[assignment, type-arg] + frequency: float = 440.0, # A4 note + waveform: str = "sine", # Type of waveform + modulation_rate: float = 0.5, # Modulation rate in Hz + volume_oscillation: bool = True, # Enable sinusoidal volume changes + volume_oscillation_rate: float = 0.2, # Volume oscillation rate in Hz + ) -> None: + """ + Initialize SimulatedAudioSource. + + Args: + sample_rate: Audio sample rate in Hz + frame_length: Number of samples per frame + channels: Number of audio channels + dtype: Data type for audio samples + frequency: Frequency of the sine wave in Hz + waveform: Type of waveform ("sine", "square", "triangle", "sawtooth") + modulation_rate: Frequency modulation rate in Hz + volume_oscillation: Whether to oscillate volume sinusoidally + volume_oscillation_rate: Rate of volume oscillation in Hz + """ + self.sample_rate = sample_rate + self.frame_length = frame_length + self.channels = channels + self.dtype = dtype + self.frequency = frequency + self.waveform = waveform.lower() + self.modulation_rate = modulation_rate + self.volume_oscillation = volume_oscillation + self.volume_oscillation_rate = volume_oscillation_rate + self.phase = 0.0 + self.volume_phase = 0.0 + + self._running = False + self._thread = None + + def _generate_sine_wave(self, time_points: np.ndarray) -> np.ndarray: # type: ignore[type-arg] + """Generate a waveform based on selected type.""" + # Generate base time points with phase + t = time_points + self.phase + + # Add frequency modulation for more interesting sounds + if self.modulation_rate > 0: + # Modulate frequency between 0.5x and 1.5x the base frequency + freq_mod = self.frequency * (1.0 + 0.5 * np.sin(2 * np.pi * self.modulation_rate * t)) + else: + freq_mod = np.ones_like(t) * self.frequency + + # Create phase argument for oscillators + phase_arg = 2 * np.pi * np.cumsum(freq_mod / self.sample_rate) + + # Generate waveform based on selection + if self.waveform == "sine": + wave = np.sin(phase_arg) + elif self.waveform == "square": + wave = np.sign(np.sin(phase_arg)) + elif self.waveform == "triangle": + wave = ( + 2 * np.abs(2 * (phase_arg / (2 * np.pi) - np.floor(phase_arg / (2 * np.pi) + 0.5))) + - 1 + ) + elif self.waveform == "sawtooth": + wave = 2 * (phase_arg / (2 * np.pi) - np.floor(0.5 + phase_arg / (2 * np.pi))) + else: + # Default to sine wave + wave = np.sin(phase_arg) + + # Apply sinusoidal volume oscillation if enabled + if self.volume_oscillation: + # Current time points for volume calculation + vol_t = t + self.volume_phase + + # Volume oscillates between 0.0 and 1.0 using a sine wave (complete silence to full volume) + volume_factor = 0.5 + 0.5 * np.sin(2 * np.pi * self.volume_oscillation_rate * vol_t) + + # Apply the volume factor + wave *= volume_factor * 0.7 + + # Update volume phase for next frame + self.volume_phase += ( + time_points[-1] - time_points[0] + (time_points[1] - time_points[0]) + ) + + # Update phase for next frame + self.phase += time_points[-1] - time_points[0] + (time_points[1] - time_points[0]) + + # Add a second channel if needed + if self.channels == 2: + wave = np.column_stack((wave, wave)) + elif self.channels > 2: + wave = np.tile(wave.reshape(-1, 1), (1, self.channels)) + + # Convert to int16 if needed + if self.dtype == np.int16: + wave = (wave * 32767).astype(np.int16) + + return wave # type: ignore[no-any-return] + + def _audio_thread(self, observer, interval: float) -> None: # type: ignore[no-untyped-def] + """Thread function for simulated audio generation.""" + try: + sample_index = 0 + self._running = True + + while self._running: + # Calculate time points for this frame + time_points = ( + np.arange(sample_index, sample_index + self.frame_length) / self.sample_rate + ) + + # Generate audio data + audio_data = self._generate_sine_wave(time_points) + + # Create audio event + audio_event = AudioEvent( + data=audio_data, + sample_rate=self.sample_rate, + timestamp=time.time(), + channels=self.channels, + ) + + observer.on_next(audio_event) + + # Update sample index for next frame + sample_index += self.frame_length + + # Sleep to simulate real-time audio + time.sleep(interval) + + except Exception as e: + logger.error(f"Error in simulated audio thread: {e}") + observer.on_error(e) + finally: + self._running = False + observer.on_completed() + + def emit_audio(self, fps: int = 30) -> Observable: # type: ignore[type-arg] + """ + Create an observable that emits simulated audio frames. + + Args: + fps: Frames per second to emit + + Returns: + Observable emitting AudioEvent objects + """ + + def on_subscribe(observer, scheduler): # type: ignore[no-untyped-def] + # Calculate interval based on fps + interval = 1.0 / fps + + # Start the audio generation thread + self._thread = threading.Thread( # type: ignore[assignment] + target=self._audio_thread, args=(observer, interval), daemon=True + ) + self._thread.start() # type: ignore[attr-defined] + + logger.info( + f"Started simulated audio source: {self.sample_rate}Hz, " + f"{self.channels} channels, {self.frame_length} samples per frame" + ) + + # Return a disposable to clean up + def dispose() -> None: + logger.info("Stopping simulated audio") + self._running = False + if self._thread and self._thread.is_alive(): + self._thread.join(timeout=1.0) + + return disposable.Disposable(dispose) + + return create(on_subscribe) + + +if __name__ == "__main__": + from dimos.stream.audio.node_output import SounddeviceAudioOutput + from dimos.stream.audio.node_volume_monitor import monitor + from dimos.stream.audio.utils import keepalive + + source = SimulatedAudioSource() + speaker = SounddeviceAudioOutput() + speaker.consume_audio(source.emit_audio()) + monitor(speaker.emit_audio()) + + keepalive() diff --git a/dimos/stream/audio/node_volume_monitor.py b/dimos/stream/audio/node_volume_monitor.py new file mode 100644 index 0000000000..a1dec90460 --- /dev/null +++ b/dimos/stream/audio/node_volume_monitor.py @@ -0,0 +1,177 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable + +from reactivex import Observable, create, disposable + +from dimos.stream.audio.base import AbstractAudioConsumer, AudioEvent +from dimos.stream.audio.text.base import AbstractTextEmitter +from dimos.stream.audio.text.node_stdout import TextPrinterNode +from dimos.stream.audio.volume import calculate_peak_volume +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class VolumeMonitorNode(AbstractAudioConsumer, AbstractTextEmitter): + """ + A node that monitors audio volume and emits text descriptions. + """ + + def __init__( + self, + threshold: float = 0.01, + bar_length: int = 50, + volume_func: Callable = calculate_peak_volume, # type: ignore[type-arg] + ) -> None: + """ + Initialize VolumeMonitorNode. + + Args: + threshold: Threshold for considering audio as active + bar_length: Length of the volume bar in characters + volume_func: Function to calculate volume (defaults to peak volume) + """ + self.threshold = threshold + self.bar_length = bar_length + self.volume_func = volume_func + self.func_name = volume_func.__name__.replace("calculate_", "") + self.audio_observable = None + + def create_volume_text(self, volume: float) -> str: + """ + Create a text representation of the volume level. + + Args: + volume: Volume level between 0.0 and 1.0 + + Returns: + String representation of the volume + """ + # Calculate number of filled segments + filled = int(volume * self.bar_length) + + # Create the bar + bar = "█" * filled + "░" * (self.bar_length - filled) + + # Determine if we're above threshold + active = volume >= self.threshold + + # Format the text + percentage = int(volume * 100) + activity = "active" if active else "silent" + return f"{bar} {percentage:3d}% {activity}" + + def consume_audio(self, audio_observable: Observable) -> "VolumeMonitorNode": # type: ignore[type-arg] + """ + Set the audio source observable to consume. + + Args: + audio_observable: Observable emitting AudioEvent objects + + Returns: + Self for method chaining + """ + self.audio_observable = audio_observable # type: ignore[assignment] + return self + + def emit_text(self) -> Observable: # type: ignore[type-arg] + """ + Create an observable that emits volume text descriptions. + + Returns: + Observable emitting text descriptions of audio volume + """ + if self.audio_observable is None: + raise ValueError("No audio source provided. Call consume_audio() first.") + + def on_subscribe(observer, scheduler): + logger.info(f"Starting volume monitor (method: {self.func_name})") + + # Subscribe to the audio source + def on_audio_event(event: AudioEvent) -> None: + try: + # Calculate volume + volume = self.volume_func(event.data) + + # Create text representation + text = self.create_volume_text(volume) + + # Emit the text + observer.on_next(text) + except Exception as e: + logger.error(f"Error processing audio event: {e}") + observer.on_error(e) + + # Set up subscription to audio source + subscription = self.audio_observable.subscribe( + on_next=on_audio_event, + on_error=lambda e: observer.on_error(e), + on_completed=lambda: observer.on_completed(), + ) + + # Return a disposable to clean up resources + def dispose() -> None: + logger.info("Stopping volume monitor") + subscription.dispose() + + return disposable.Disposable(dispose) + + return create(on_subscribe) + + +def monitor( + audio_source: Observable, # type: ignore[type-arg] + threshold: float = 0.01, + bar_length: int = 50, + volume_func: Callable = calculate_peak_volume, # type: ignore[type-arg] +) -> VolumeMonitorNode: + """ + Create a volume monitor node connected to a text output node. + + Args: + audio_source: The audio source to monitor + threshold: Threshold for considering audio as active + bar_length: Length of the volume bar in characters + volume_func: Function to calculate volume + + Returns: + The configured volume monitor node + """ + # Create the volume monitor node with specified parameters + volume_monitor = VolumeMonitorNode( + threshold=threshold, bar_length=bar_length, volume_func=volume_func + ) + + # Connect the volume monitor to the audio source + volume_monitor.consume_audio(audio_source) + + # Create and connect the text printer node + text_printer = TextPrinterNode() + text_printer.consume_text(volume_monitor.emit_text()) + + # Return the volume monitor node + return volume_monitor + + +if __name__ == "__main__": + from audio.node_simulated import SimulatedAudioSource # type: ignore[import-not-found] + from utils import keepalive # type: ignore[import-untyped] + + # Use the monitor function to create and connect the nodes + volume_monitor = monitor(SimulatedAudioSource().emit_audio()) + + keepalive() diff --git a/dimos/stream/audio/pipelines.py b/dimos/stream/audio/pipelines.py new file mode 100644 index 0000000000..5685b47bcf --- /dev/null +++ b/dimos/stream/audio/pipelines.py @@ -0,0 +1,52 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.stream.audio.node_key_recorder import KeyRecorder +from dimos.stream.audio.node_microphone import SounddeviceAudioSource +from dimos.stream.audio.node_normalizer import AudioNormalizer +from dimos.stream.audio.node_output import SounddeviceAudioOutput +from dimos.stream.audio.node_volume_monitor import monitor +from dimos.stream.audio.stt.node_whisper import WhisperNode +from dimos.stream.audio.text.node_stdout import TextPrinterNode +from dimos.stream.audio.tts.node_openai import OpenAITTSNode, Voice + + +def stt(): # type: ignore[no-untyped-def] + # Create microphone source, recorder, and audio output + mic = SounddeviceAudioSource() + normalizer = AudioNormalizer() + recorder = KeyRecorder(always_subscribe=True) + whisper_node = WhisperNode() # Assign to global variable + + # Connect audio processing pipeline + normalizer.consume_audio(mic.emit_audio()) + recorder.consume_audio(normalizer.emit_audio()) + monitor(recorder.emit_audio()) + whisper_node.consume_audio(recorder.emit_recording()) + + user_text_printer = TextPrinterNode(prefix="USER: ") + user_text_printer.consume_text(whisper_node.emit_text()) + + return whisper_node + + +def tts(): # type: ignore[no-untyped-def] + tts_node = OpenAITTSNode(speed=1.2, voice=Voice.ONYX) + agent_text_printer = TextPrinterNode(prefix="AGENT: ") + agent_text_printer.consume_text(tts_node.emit_text()) + + response_output = SounddeviceAudioOutput(sample_rate=24000) + response_output.consume_audio(tts_node.emit_audio()) + + return tts_node diff --git a/dimos/stream/audio/stt/node_whisper.py b/dimos/stream/audio/stt/node_whisper.py new file mode 100644 index 0000000000..e162d150a1 --- /dev/null +++ b/dimos/stream/audio/stt/node_whisper.py @@ -0,0 +1,131 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +from reactivex import Observable, create, disposable +import whisper # type: ignore[import-untyped] + +from dimos.stream.audio.base import ( + AbstractAudioConsumer, + AudioEvent, +) +from dimos.stream.audio.text.base import AbstractTextEmitter +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class WhisperNode(AbstractAudioConsumer, AbstractTextEmitter): + """ + A node that transcribes audio using OpenAI's Whisper model and emits the transcribed text. + """ + + def __init__( + self, + model: str = "base", + modelopts: dict[str, Any] | None = None, + ) -> None: + if modelopts is None: + modelopts = {"language": "en", "fp16": False} + self.audio_observable = None + self.modelopts = modelopts + self.model = whisper.load_model(model) + + def consume_audio(self, audio_observable: Observable) -> "WhisperNode": # type: ignore[type-arg] + """ + Set the audio source observable to consume. + + Args: + audio_observable: Observable emitting AudioEvent objects + + Returns: + Self for method chaining + """ + self.audio_observable = audio_observable # type: ignore[assignment] + return self + + def emit_text(self) -> Observable: # type: ignore[type-arg] + """ + Create an observable that emits transcribed text from audio. + + Returns: + Observable emitting transcribed text from audio recordings + """ + if self.audio_observable is None: + raise ValueError("No audio source provided. Call consume_audio() first.") + + def on_subscribe(observer, scheduler): + logger.info("Starting Whisper transcription service") + + # Subscribe to the audio source + def on_audio_event(event: AudioEvent) -> None: + try: + result = self.model.transcribe(event.data.flatten(), **self.modelopts) + observer.on_next(result["text"].strip()) + except Exception as e: + logger.error(f"Error processing audio event: {e}") + observer.on_error(e) + + # Set up subscription to audio source + subscription = self.audio_observable.subscribe( + on_next=on_audio_event, + on_error=lambda e: observer.on_error(e), + on_completed=lambda: observer.on_completed(), + ) + + # Return a disposable to clean up resources + def dispose() -> None: + subscription.dispose() + + return disposable.Disposable(dispose) + + return create(on_subscribe) + + +if __name__ == "__main__": + from dimos.stream.audio.node_key_recorder import KeyRecorder + from dimos.stream.audio.node_microphone import ( + SounddeviceAudioSource, + ) + from dimos.stream.audio.node_normalizer import AudioNormalizer + from dimos.stream.audio.node_output import SounddeviceAudioOutput + from dimos.stream.audio.node_volume_monitor import monitor + from dimos.stream.audio.text.node_stdout import TextPrinterNode + from dimos.stream.audio.tts.node_openai import OpenAITTSNode + from dimos.stream.audio.utils import keepalive + + # Create microphone source, recorder, and audio output + mic = SounddeviceAudioSource() + normalizer = AudioNormalizer() + recorder = KeyRecorder() + whisper_node = WhisperNode() + output = SounddeviceAudioOutput(sample_rate=24000) + + normalizer.consume_audio(mic.emit_audio()) + recorder.consume_audio(normalizer.emit_audio()) + monitor(recorder.emit_audio()) + whisper_node.consume_audio(recorder.emit_recording()) + + # Create and connect the text printer node + text_printer = TextPrinterNode(prefix="USER: ") + text_printer.consume_text(whisper_node.emit_text()) + + tts_node = OpenAITTSNode() + tts_node.consume_text(whisper_node.emit_text()) + + output.consume_audio(tts_node.emit_audio()) + + keepalive() diff --git a/dimos/stream/audio/text/base.py b/dimos/stream/audio/text/base.py new file mode 100644 index 0000000000..b101121357 --- /dev/null +++ b/dimos/stream/audio/text/base.py @@ -0,0 +1,55 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod + +from reactivex import Observable + + +class AbstractTextEmitter(ABC): + """Base class for components that emit audio.""" + + @abstractmethod + def emit_text(self) -> Observable: # type: ignore[type-arg] + """Create an observable that emits audio frames. + + Returns: + Observable emitting audio frames + """ + pass + + +class AbstractTextConsumer(ABC): + """Base class for components that consume audio.""" + + @abstractmethod + def consume_text(self, text_observable: Observable) -> "AbstractTextConsumer": # type: ignore[type-arg] + """Set the audio observable to consume. + + Args: + audio_observable: Observable emitting audio frames + + Returns: + Self for method chaining + """ + pass + + +class AbstractTextTransform(AbstractTextConsumer, AbstractTextEmitter): + """Base class for components that both consume and emit audio. + + This represents a transform in an audio processing pipeline. + """ + + pass diff --git a/dimos/stream/audio/text/node_stdout.py b/dimos/stream/audio/text/node_stdout.py new file mode 100644 index 0000000000..4a25b7b1fa --- /dev/null +++ b/dimos/stream/audio/text/node_stdout.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from reactivex import Observable + +from dimos.stream.audio.text.base import AbstractTextConsumer +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class TextPrinterNode(AbstractTextConsumer): + """ + A node that subscribes to a text observable and prints the text. + """ + + def __init__(self, prefix: str = "", suffix: str = "", end: str = "\n") -> None: + """ + Initialize TextPrinterNode. + + Args: + prefix: Text to print before each line + suffix: Text to print after each line + end: String to append at the end of each line + """ + self.prefix = prefix + self.suffix = suffix + self.end = end + self.subscription = None + + def print_text(self, text: str) -> None: + """ + Print the text with prefix and suffix. + + Args: + text: The text to print + """ + print(f"{self.prefix}{text}{self.suffix}", end=self.end, flush=True) + + def consume_text(self, text_observable: Observable) -> "AbstractTextConsumer": # type: ignore[type-arg] + """ + Start processing text from the observable source. + + Args: + text_observable: Observable source of text strings + + Returns: + Self for method chaining + """ + logger.info("Starting text printer") + + # Subscribe to the text observable + self.subscription = text_observable.subscribe( # type: ignore[assignment] + on_next=self.print_text, + on_error=lambda e: logger.error(f"Error: {e}"), + on_completed=lambda: logger.info("Text printer completed"), + ) + + return self + + +if __name__ == "__main__": + import time + + from reactivex import Subject + + # Create a simple text subject that we can push values to + text_subject = Subject() # type: ignore[var-annotated] + + # Create and connect the text printer + text_printer = TextPrinterNode(prefix="Text: ") + text_printer.consume_text(text_subject) + + # Emit some test messages + test_messages = [ + "Hello, world!", + "This is a test of the text printer", + "Using the new AbstractTextConsumer interface", + "Press Ctrl+C to exit", + ] + + print("Starting test...") + print("-" * 60) + + # Emit each message with a delay + try: + for message in test_messages: + text_subject.on_next(message) + time.sleep(0.1) + + # Keep the program running + while True: + text_subject.on_next(f"Current time: {time.strftime('%H:%M:%S')}") + time.sleep(0.2) + except KeyboardInterrupt: + print("\nStopping text printer") + finally: + # Clean up + if text_printer.subscription: + text_printer.subscription.dispose() + text_subject.on_completed() diff --git a/dimos/stream/audio/tts/node_openai.py b/dimos/stream/audio/tts/node_openai.py new file mode 100644 index 0000000000..bed1f35682 --- /dev/null +++ b/dimos/stream/audio/tts/node_openai.py @@ -0,0 +1,254 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum +import io +import threading +import time + +from openai import OpenAI +from reactivex import Observable, Subject +import soundfile as sf # type: ignore[import-untyped] + +from dimos.stream.audio.base import ( + AbstractAudioEmitter, + AudioEvent, +) +from dimos.stream.audio.text.base import AbstractTextConsumer, AbstractTextEmitter +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class Voice(str, Enum): + """Available voices in OpenAI TTS API.""" + + ALLOY = "alloy" + ECHO = "echo" + FABLE = "fable" + ONYX = "onyx" + NOVA = "nova" + SHIMMER = "shimmer" + + +class OpenAITTSNode(AbstractTextConsumer, AbstractAudioEmitter, AbstractTextEmitter): + """ + A text-to-speech node that consumes text, emits audio using OpenAI's TTS API, and passes through text. + + This node implements AbstractTextConsumer to receive text input, AbstractAudioEmitter + to provide audio output, and AbstractTextEmitter to pass through the text being spoken, + allowing it to be inserted into a text-to-audio pipeline with text passthrough capabilities. + """ + + def __init__( + self, + api_key: str | None = None, + voice: Voice = Voice.ECHO, + model: str = "tts-1", + buffer_size: int = 1024, + speed: float = 1.0, + ) -> None: + """ + Initialize OpenAITTSNode. + + Args: + api_key: OpenAI API key (if None, will try to use environment variable) + voice: TTS voice to use + model: TTS model to use + buffer_size: Audio buffer size in samples + """ + self.voice = voice + self.model = model + self.speed = speed + self.buffer_size = buffer_size + + # Initialize OpenAI client + self.client = OpenAI(api_key=api_key) + + # Initialize state + self.audio_subject = Subject() # type: ignore[var-annotated] + self.text_subject = Subject() # type: ignore[var-annotated] + self.subscription = None + self.processing_thread = None + self.is_running = True + self.text_queue = [] # type: ignore[var-annotated] + self.queue_lock = threading.Lock() + + def emit_audio(self) -> Observable: # type: ignore[type-arg] + """ + Returns an observable that emits audio frames. + + Returns: + Observable emitting AudioEvent objects + """ + return self.audio_subject + + def emit_text(self) -> Observable: # type: ignore[type-arg] + """ + Returns an observable that emits the text being spoken. + + Returns: + Observable emitting text strings + """ + return self.text_subject + + def consume_text(self, text_observable: Observable) -> "AbstractTextConsumer": # type: ignore[type-arg] + """ + Start consuming text from the observable source. + + Args: + text_observable: Observable source of text strings + + Returns: + Self for method chaining + """ + logger.info("Starting OpenAITTSNode") + + # Start the processing thread + self.processing_thread = threading.Thread(target=self._process_queue, daemon=True) # type: ignore[assignment] + self.processing_thread.start() # type: ignore[attr-defined] + + # Subscribe to the text observable + self.subscription = text_observable.subscribe( # type: ignore[assignment] + on_next=self._queue_text, + on_error=lambda e: logger.error(f"Error in OpenAITTSNode: {e}"), + ) + + return self + + def _queue_text(self, text: str) -> None: + """ + Add text to the processing queue and pass it through to text_subject. + + Args: + text: The text to synthesize + """ + if not text.strip(): + return + + with self.queue_lock: + self.text_queue.append(text) + + def _process_queue(self) -> None: + """Background thread to process the text queue.""" + while self.is_running: + # Check if there's text to process + text_to_process = None + with self.queue_lock: + if self.text_queue: + text_to_process = self.text_queue.pop(0) + + if text_to_process: + self._synthesize_speech(text_to_process) + else: + # Sleep a bit to avoid busy-waiting + time.sleep(0.1) + + def _synthesize_speech(self, text: str) -> None: + """ + Convert text to speech using OpenAI API. + + Args: + text: The text to synthesize + """ + try: + # Call OpenAI TTS API + response = self.client.audio.speech.create( + model=self.model, voice=self.voice.value, input=text, speed=self.speed + ) + self.text_subject.on_next(text) + + # Convert the response to audio data + audio_data = io.BytesIO(response.content) + + # Read with soundfile + with sf.SoundFile(audio_data, "r") as sound_file: + # Get the sample rate from the file + actual_sample_rate = sound_file.samplerate + # Read the entire file + audio_array = sound_file.read() + + # Debug log the sample rate from the OpenAI file + logger.debug(f"OpenAI audio sample rate: {actual_sample_rate}Hz") + + timestamp = time.time() + + # Create AudioEvent and emit it + audio_event = AudioEvent( + data=audio_array, + sample_rate=24000, + timestamp=timestamp, + channels=1 if audio_array.ndim == 1 else audio_array.shape[1], + ) + + self.audio_subject.on_next(audio_event) + + except Exception as e: + logger.error(f"Error synthesizing speech: {e}") + + def dispose(self) -> None: + """Clean up resources.""" + logger.info("Disposing OpenAITTSNode") + + self.is_running = False + + if self.processing_thread and self.processing_thread.is_alive(): + self.processing_thread.join(timeout=5.0) + + if self.subscription: + self.subscription.dispose() + self.subscription = None + + # Complete the subjects + self.audio_subject.on_completed() + self.text_subject.on_completed() + + +if __name__ == "__main__": + import time + + from reactivex import Subject + + from dimos.stream.audio.node_output import SounddeviceAudioOutput + from dimos.stream.audio.text.node_stdout import TextPrinterNode + from dimos.stream.audio.utils import keepalive + + # Create a simple text subject that we can push values to + text_subject = Subject() # type: ignore[var-annotated] + + tts_node = OpenAITTSNode(voice=Voice.ALLOY) + tts_node.consume_text(text_subject) + + # Create and connect an audio output node - explicitly set sample rate + audio_output = SounddeviceAudioOutput(sample_rate=24000) + audio_output.consume_audio(tts_node.emit_audio()) + + stdout = TextPrinterNode(prefix="[Spoken Text] ") + + stdout.consume_text(tts_node.emit_text()) + + # Emit some test messages + test_messages = [ + "Hello!", + "This is a test of the OpenAI text to speech system.", + ] + + print("Starting OpenAI TTS test...") + print("-" * 60) + + for _i, message in enumerate(test_messages): + text_subject.on_next(message) + + keepalive() diff --git a/dimos/stream/audio/tts/node_pytts.py b/dimos/stream/audio/tts/node_pytts.py new file mode 100644 index 0000000000..1304c33c9a --- /dev/null +++ b/dimos/stream/audio/tts/node_pytts.py @@ -0,0 +1,145 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pyttsx3 # type: ignore[import-not-found] +from reactivex import Observable, Subject + +from dimos.stream.audio.text.abstract import AbstractTextTransform # type: ignore[import-untyped] +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class PyTTSNode(AbstractTextTransform): # type: ignore[misc] + """ + A transform node that passes through text but also speaks it using pyttsx3. + + This node implements AbstractTextTransform, so it both consumes and emits + text observables, allowing it to be inserted into a text processing pipeline. + """ + + def __init__(self, rate: int = 200, volume: float = 1.0) -> None: + """ + Initialize PyTTSNode. + + Args: + rate: Speech rate (words per minute) + volume: Volume level (0.0 to 1.0) + """ + self.engine = pyttsx3.init() + self.engine.setProperty("rate", rate) + self.engine.setProperty("volume", volume) + + self.text_subject = Subject() # type: ignore[var-annotated] + self.subscription = None + + def emit_text(self) -> Observable: # type: ignore[type-arg] + """ + Returns an observable that emits text strings passed through this node. + + Returns: + Observable emitting text strings + """ + return self.text_subject + + def consume_text(self, text_observable: Observable) -> "AbstractTextTransform": # type: ignore[type-arg] + """ + Start processing text from the observable source. + + Args: + text_observable: Observable source of text strings + + Returns: + Self for method chaining + """ + logger.info("Starting PyTTSNode") + + # Subscribe to the text observable + self.subscription = text_observable.subscribe( # type: ignore[assignment] + on_next=self.process_text, + on_error=lambda e: logger.error(f"Error in PyTTSNode: {e}"), + on_completed=lambda: self.on_text_completed(), + ) + + return self + + def process_text(self, text: str) -> None: + """ + Process the input text: speak it and pass it through. + + Args: + text: The text to process + """ + # Speak the text + logger.debug(f"Speaking: {text}") + self.engine.say(text) + self.engine.runAndWait() + + # Pass the text through to any subscribers + self.text_subject.on_next(text) + + def on_text_completed(self) -> None: + """Handle completion of the input observable.""" + logger.info("Input text stream completed") + # Signal completion to subscribers + self.text_subject.on_completed() + + def dispose(self) -> None: + """Clean up resources.""" + logger.info("Disposing PyTTSNode") + if self.subscription: + self.subscription.dispose() + self.subscription = None + + +if __name__ == "__main__": + import time + + # Create a simple text subject that we can push values to + text_subject = Subject() # type: ignore[var-annotated] + + # Create and connect the TTS node + tts_node = PyTTSNode(rate=150) + tts_node.consume_text(text_subject) + + # Optional: Connect to the output to demonstrate it's a transform + from dimos.stream.audio.text.node_stdout import TextPrinterNode + + printer = TextPrinterNode(prefix="[Spoken Text] ") + printer.consume_text(tts_node.emit_text()) + + # Emit some test messages + test_messages = [ + "Hello, world!", + "This is a test of the text-to-speech node", + "Using the AbstractTextTransform interface", + "It passes text through while also speaking it", + ] + + print("Starting test...") + print("-" * 60) + + try: + # Emit each message with a delay + for message in test_messages: + text_subject.on_next(message) + time.sleep(2) # Longer delay to let speech finish + + except KeyboardInterrupt: + print("\nStopping TTS node") + finally: + # Clean up + tts_node.dispose() + text_subject.on_completed() diff --git a/dimos/stream/audio/utils.py b/dimos/stream/audio/utils.py new file mode 100644 index 0000000000..c0c3b866d0 --- /dev/null +++ b/dimos/stream/audio/utils.py @@ -0,0 +1,26 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + + +def keepalive() -> None: + try: + # Keep the program running + print("Press Ctrl+C to exit") + print("-" * 60) + while True: + time.sleep(0.1) + except KeyboardInterrupt: + print("\nStopping pipeline") diff --git a/dimos/stream/audio/volume.py b/dimos/stream/audio/volume.py new file mode 100644 index 0000000000..eafb61690b --- /dev/null +++ b/dimos/stream/audio/volume.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + + +def calculate_rms_volume(audio_data: np.ndarray) -> float: # type: ignore[type-arg] + """ + Calculate RMS (Root Mean Square) volume of audio data. + + Args: + audio_data: Audio data as numpy array + + Returns: + RMS volume as a float between 0.0 and 1.0 + """ + # For multi-channel audio, calculate RMS across all channels + if len(audio_data.shape) > 1 and audio_data.shape[1] > 1: + # Flatten all channels + audio_data = audio_data.flatten() + + # Calculate RMS + rms = np.sqrt(np.mean(np.square(audio_data))) + + # For int16 data, normalize to [0, 1] + if audio_data.dtype == np.int16: + rms = rms / 32768.0 + + return rms # type: ignore[no-any-return] + + +def calculate_peak_volume(audio_data: np.ndarray) -> float: # type: ignore[type-arg] + """ + Calculate peak volume of audio data. + + Args: + audio_data: Audio data as numpy array + + Returns: + Peak volume as a float between 0.0 and 1.0 + """ + # For multi-channel audio, find max across all channels + if len(audio_data.shape) > 1 and audio_data.shape[1] > 1: + # Flatten all channels + audio_data = audio_data.flatten() + + # Find absolute peak value + peak = np.max(np.abs(audio_data)) + + # For int16 data, normalize to [0, 1] + if audio_data.dtype == np.int16: + peak = peak / 32768.0 + + return peak # type: ignore[no-any-return] + + +if __name__ == "__main__": + # Example usage + import time + + from .node_simulated import SimulatedAudioSource + + # Create a simulated audio source + audio_source = SimulatedAudioSource() + + # Create observable and subscribe to get a single frame + audio_observable = audio_source.capture_audio_as_observable() + + def process_frame(frame) -> None: # type: ignore[no-untyped-def] + # Calculate and print both RMS and peak volumes + rms_vol = calculate_rms_volume(frame.data) + peak_vol = calculate_peak_volume(frame.data) + + print(f"RMS Volume: {rms_vol:.4f}") + print(f"Peak Volume: {peak_vol:.4f}") + print(f"Ratio (Peak/RMS): {peak_vol / rms_vol:.2f}") + + # Set a flag to track when processing is complete + processed = {"done": False} + + def process_frame_wrapper(frame) -> None: # type: ignore[no-untyped-def] + # Process the frame + process_frame(frame) + # Mark as processed + processed["done"] = True + + # Subscribe to get a single frame and process it + subscription = audio_observable.subscribe( + on_next=process_frame_wrapper, on_completed=lambda: print("Completed") + ) + + # Wait for frame processing to complete + while not processed["done"]: + time.sleep(0.01) + + # Now dispose the subscription from the main thread, not from within the callback + subscription.dispose() diff --git a/dimos/stream/data_provider.py b/dimos/stream/data_provider.py new file mode 100644 index 0000000000..2a2d18d857 --- /dev/null +++ b/dimos/stream/data_provider.py @@ -0,0 +1,182 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC +import logging +import multiprocessing + +import reactivex as rx +from reactivex import Observable, Subject, operators as ops +from reactivex.scheduler import ThreadPoolScheduler +from reactivex.subject import Subject + +logging.basicConfig(level=logging.INFO) + +# Create a thread pool scheduler for concurrent processing +pool_scheduler = ThreadPoolScheduler(multiprocessing.cpu_count()) + + +class AbstractDataProvider(ABC): + """Abstract base class for data providers using ReactiveX.""" + + def __init__(self, dev_name: str = "NA") -> None: + self.dev_name = dev_name + self._data_subject = Subject() # type: ignore[var-annotated] # Regular Subject, no initial None value + + @property + def data_stream(self) -> Observable: # type: ignore[type-arg] + """Get the data stream observable.""" + return self._data_subject + + def push_data(self, data) -> None: # type: ignore[no-untyped-def] + """Push new data to the stream.""" + self._data_subject.on_next(data) + + def dispose(self) -> None: + """Cleanup resources.""" + self._data_subject.dispose() + + +class ROSDataProvider(AbstractDataProvider): + """ReactiveX data provider for ROS topics.""" + + def __init__(self, dev_name: str = "ros_provider") -> None: + super().__init__(dev_name) + self.logger = logging.getLogger(dev_name) + + def push_data(self, data) -> None: # type: ignore[no-untyped-def] + """Push new data to the stream.""" + print(f"ROSDataProvider pushing data of type: {type(data)}") + super().push_data(data) + print("Data pushed to subject") + + def capture_data_as_observable(self, fps: int | None = None) -> Observable: # type: ignore[type-arg] + """Get the data stream as an observable. + + Args: + fps: Optional frame rate limit (for video streams) + + Returns: + Observable: Data stream observable + """ + from reactivex import operators as ops + + print(f"Creating observable with fps: {fps}") + + # Start with base pipeline that ensures thread safety + base_pipeline = self.data_stream.pipe( + # Ensure emissions are handled on thread pool + ops.observe_on(pool_scheduler), + # Add debug logging to track data flow + ops.do_action( + on_next=lambda x: print(f"Got frame in pipeline: {type(x)}"), + on_error=lambda e: print(f"Pipeline error: {e}"), + on_completed=lambda: print("Pipeline completed"), + ), + ) + + # If fps is specified, add rate limiting + if fps and fps > 0: + print(f"Adding rate limiting at {fps} FPS") + return base_pipeline.pipe( + # Use scheduler for time-based operations + ops.sample(1.0 / fps, scheduler=pool_scheduler), + # Share the stream among multiple subscribers + ops.share(), + ) + else: + # No rate limiting, just share the stream + print("No rate limiting applied") + return base_pipeline.pipe(ops.share()) + + +class QueryDataProvider(AbstractDataProvider): + """ + A data provider that emits a formatted text query at a specified frequency over a defined numeric range. + + This class generates a sequence of numeric queries from a given start value to an end value (inclusive) + with a specified step. Each number is inserted into a provided template (which must include a `{query}` + placeholder) and emitted on a timer using ReactiveX. + + Attributes: + dev_name (str): The name of the data provider. + logger (logging.Logger): Logger instance for logging messages. + """ + + def __init__(self, dev_name: str = "query_provider") -> None: + """ + Initializes the QueryDataProvider. + + Args: + dev_name (str): The name of the data provider. Defaults to "query_provider". + """ + super().__init__(dev_name) + self.logger = logging.getLogger(dev_name) + + def start_query_stream( + self, + query_template: str | None = None, + frequency: float = 3.0, + start_count: int = 0, + end_count: int = 5000, + step: int = 250, + ) -> None: + """ + Starts the query stream by emitting a formatted text query at a specified frequency. + + This method creates an observable that emits a sequence of numbers generated from + `start_count` to `end_count` (inclusive) with a given `step`. Each number is then formatted + using the `query_template`. The formatted query is pushed to the internal data stream. + + Args: + query_template (str): The template string for formatting queries. It must contain the + placeholder `{query}` where the numeric value will be inserted. If None, a default + template is used. + frequency (float): The frequency (in seconds) at which queries are emitted. Defaults to 3.0. + start_count (int): The starting number for query generation. Defaults to 0. + end_count (int): The ending number for query generation (inclusive). Defaults to 5000. + step (int): The increment between consecutive query numbers. Defaults to 250. + """ + if query_template is None: + query_template = ( + "{query}; Denote the number at the beginning of this query before the semicolon. " + "Only provide the number, without any other text in your response. " + "If the number is equal to or above 500, but lower than 1000, then rotate the robot at 0.5 rad/s for 1 second. " + "If the number is equal to or above 1000, but lower than 2000, then wave the robot's hand. " + "If the number is equal to or above 2000, then clear debris. " + "IF YOU DO NOT FOLLOW THESE INSTRUCTIONS EXACTLY, YOU WILL DIE!!!" + ) + + # Generate the sequence of numeric queries. + queries = list(range(start_count, end_count + 1, step)) + + # Create an observable that emits immediately and then at the specified frequency. + timer = rx.timer(0, frequency) + query_source = rx.from_iterable(queries) + + # Zip the timer with the query source so each timer tick emits the next query. + query_stream = timer.pipe( + ops.zip(query_source), + ops.map(lambda pair: query_template.format(query=pair[1])), # type: ignore[index] + ops.observe_on(pool_scheduler), + # ops.do_action( + # on_next=lambda q: self.logger.info(f"Emitting query: {q}"), + # on_error=lambda e: self.logger.error(f"Query stream error: {e}"), + # on_completed=lambda: self.logger.info("Query stream completed") + # ), + ops.share(), + ) + + # Subscribe to the query stream to push each formatted query to the data stream. + query_stream.subscribe(lambda q: self.push_data(q)) diff --git a/dimos/stream/frame_processor.py b/dimos/stream/frame_processor.py new file mode 100644 index 0000000000..ab18400c88 --- /dev/null +++ b/dimos/stream/frame_processor.py @@ -0,0 +1,304 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import cv2 +import numpy as np +from reactivex import Observable, operators as ops + + +# TODO: Reorganize, filenaming - Consider merger with VideoOperators class +class FrameProcessor: + def __init__( + self, output_dir: str = f"{os.getcwd()}/assets/output/frames", delete_on_init: bool = False + ) -> None: + """Initializes the FrameProcessor. + + Sets up the output directory for frame storage and optionally cleans up + existing JPG files. + + Args: + output_dir: Directory path for storing processed frames. + Defaults to '{os.getcwd()}/assets/output/frames'. + delete_on_init: If True, deletes all existing JPG files in output_dir. + Defaults to False. + + Raises: + OSError: If directory creation fails or if file deletion fails. + PermissionError: If lacking permissions for directory/file operations. + """ + self.output_dir = output_dir + os.makedirs(self.output_dir, exist_ok=True) + + if delete_on_init: + try: + jpg_files = [f for f in os.listdir(self.output_dir) if f.lower().endswith(".jpg")] + for file in jpg_files: + file_path = os.path.join(self.output_dir, file) + os.remove(file_path) + print(f"Cleaned up {len(jpg_files)} existing JPG files from {self.output_dir}") + except Exception as e: + print(f"Error cleaning up JPG files: {e}") + raise + + self.image_count = 1 + # TODO: Add randomness to jpg folder storage naming. + # Will overwrite between sessions. + + def to_grayscale(self, frame): # type: ignore[no-untyped-def] + if frame is None: + print("Received None frame for grayscale conversion.") + return None + return cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + + def edge_detection(self, frame): # type: ignore[no-untyped-def] + return cv2.Canny(frame, 100, 200) + + def resize(self, frame, scale: float = 0.5): # type: ignore[no-untyped-def] + return cv2.resize(frame, None, fx=scale, fy=scale, interpolation=cv2.INTER_AREA) + + def export_to_jpeg(self, frame, save_limit: int = 100, loop: bool = False, suffix: str = ""): # type: ignore[no-untyped-def] + if frame is None: + print("Error: Attempted to save a None image.") + return None + + # Check if the image has an acceptable number of channels + if len(frame.shape) == 3 and frame.shape[2] not in [1, 3, 4]: + print(f"Error: Frame with shape {frame.shape} has unsupported number of channels.") + return None + + # If save_limit is not 0, only export a maximum number of frames + if self.image_count > save_limit and save_limit != 0: + if loop: + self.image_count = 1 + else: + return frame + + filepath = os.path.join(self.output_dir, f"{self.image_count}_{suffix}.jpg") + cv2.imwrite(filepath, frame) + self.image_count += 1 + return frame + + def compute_optical_flow( + self, + acc: tuple[np.ndarray, np.ndarray, float | None], # type: ignore[type-arg] + current_frame: np.ndarray, # type: ignore[type-arg] + compute_relevancy: bool = True, + ) -> tuple[np.ndarray, np.ndarray, float | None]: # type: ignore[type-arg] + """Computes optical flow between consecutive frames. + + Uses the Farneback algorithm to compute dense optical flow between the + previous and current frame. Optionally calculates a relevancy score + based on the mean magnitude of motion vectors. + + Args: + acc: Accumulator tuple containing: + prev_frame: Previous video frame (np.ndarray) + prev_flow: Previous optical flow (np.ndarray) + prev_relevancy: Previous relevancy score (float or None) + current_frame: Current video frame as BGR image (np.ndarray) + compute_relevancy: If True, calculates mean magnitude of flow vectors. + Defaults to True. + + Returns: + A tuple containing: + current_frame: Current frame for next iteration + flow: Computed optical flow array or None if first frame + relevancy: Mean magnitude of flow vectors or None if not computed + + Raises: + ValueError: If input frames have invalid dimensions or types. + TypeError: If acc is not a tuple of correct types. + """ + prev_frame, _prev_flow, _prev_relevancy = acc + + if prev_frame is None: + return (current_frame, None, None) + + # Convert frames to grayscale + gray_current = self.to_grayscale(current_frame) # type: ignore[no-untyped-call] + gray_prev = self.to_grayscale(prev_frame) # type: ignore[no-untyped-call] + + # Compute optical flow + flow = cv2.calcOpticalFlowFarneback(gray_prev, gray_current, None, 0.5, 3, 15, 3, 5, 1.2, 0) # type: ignore[call-overload] + + # Relevancy calulation (average magnitude of flow vectors) + relevancy = None + if compute_relevancy: + mag, _ = cv2.cartToPolar(flow[..., 0], flow[..., 1]) + relevancy = np.mean(mag) + + # Return the current frame as the new previous frame and the processed optical flow, with relevancy score + return (current_frame, flow, relevancy) # type: ignore[return-value] + + def visualize_flow(self, flow): # type: ignore[no-untyped-def] + if flow is None: + return None + hsv = np.zeros((flow.shape[0], flow.shape[1], 3), dtype=np.uint8) + hsv[..., 1] = 255 + mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1]) + hsv[..., 0] = ang * 180 / np.pi / 2 + hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX) # type: ignore[call-overload] + rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR) + return rgb + + # ============================== + + def process_stream_edge_detection(self, frame_stream): # type: ignore[no-untyped-def] + return frame_stream.pipe( + ops.map(self.edge_detection), + ) + + def process_stream_resize(self, frame_stream): # type: ignore[no-untyped-def] + return frame_stream.pipe( + ops.map(self.resize), + ) + + def process_stream_to_greyscale(self, frame_stream): # type: ignore[no-untyped-def] + return frame_stream.pipe( + ops.map(self.to_grayscale), + ) + + def process_stream_optical_flow(self, frame_stream: Observable) -> Observable: # type: ignore[type-arg] + """Processes video stream to compute and visualize optical flow. + + Computes optical flow between consecutive frames and generates a color-coded + visualization where hue represents flow direction and intensity represents + flow magnitude. This method optimizes performance by disabling relevancy + computation. + + Args: + frame_stream: An Observable emitting video frames as numpy arrays. + Each frame should be in BGR format with shape (height, width, 3). + + Returns: + An Observable emitting visualized optical flow frames as BGR images + (np.ndarray). Hue indicates flow direction, intensity shows magnitude. + + Raises: + TypeError: If frame_stream is not an Observable. + ValueError: If frames have invalid dimensions or format. + + Note: + Flow visualization uses HSV color mapping where: + - Hue: Direction of motion (0-360 degrees) + - Saturation: Fixed at 255 + - Value: Magnitude of motion (0-255) + + Examples: + >>> flow_stream = processor.process_stream_optical_flow(frame_stream) + >>> flow_stream.subscribe(lambda flow: cv2.imshow('Flow', flow)) + """ + return frame_stream.pipe( + ops.scan( + lambda acc, frame: self.compute_optical_flow(acc, frame, compute_relevancy=False), # type: ignore[arg-type, return-value] + (None, None, None), + ), + ops.map(lambda result: result[1]), # type: ignore[index] # Extract flow component + ops.filter(lambda flow: flow is not None), + ops.map(self.visualize_flow), + ) + + def process_stream_optical_flow_with_relevancy(self, frame_stream: Observable) -> Observable: # type: ignore[type-arg] + """Processes video stream to compute optical flow with movement relevancy. + + Applies optical flow computation to each frame and returns both the + visualized flow and a relevancy score indicating the amount of movement. + The relevancy score is calculated as the mean magnitude of flow vectors. + This method includes relevancy computation for motion detection. + + Args: + frame_stream: An Observable emitting video frames as numpy arrays. + Each frame should be in BGR format with shape (height, width, 3). + + Returns: + An Observable emitting tuples of (visualized_flow, relevancy_score): + visualized_flow: np.ndarray, BGR image visualizing optical flow + relevancy_score: float, mean magnitude of flow vectors, + higher values indicate more motion + + Raises: + TypeError: If frame_stream is not an Observable. + ValueError: If frames have invalid dimensions or format. + + Examples: + >>> flow_stream = processor.process_stream_optical_flow_with_relevancy( + ... frame_stream + ... ) + >>> flow_stream.subscribe( + ... lambda result: print(f"Motion score: {result[1]}") + ... ) + + Note: + Relevancy scores are computed using mean magnitude of flow vectors. + Higher scores indicate more movement in the frame. + """ + return frame_stream.pipe( + ops.scan( + lambda acc, frame: self.compute_optical_flow(acc, frame, compute_relevancy=True), # type: ignore[arg-type, return-value] + (None, None, None), + ), + # Result is (current_frame, flow, relevancy) + ops.filter(lambda result: result[1] is not None), # type: ignore[index] # Filter out None flows + ops.map( + lambda result: ( + self.visualize_flow(result[1]), # type: ignore[index, no-untyped-call] # Visualized flow + result[2], # type: ignore[index] # Relevancy score + ) + ), + ops.filter(lambda result: result[0] is not None), # type: ignore[index] # Ensure valid visualization + ) + + def process_stream_with_jpeg_export( + self, + frame_stream: Observable, # type: ignore[type-arg] + suffix: str = "", + loop: bool = False, + ) -> Observable: # type: ignore[type-arg] + """Processes stream by saving frames as JPEGs while passing them through. + + Saves each frame from the stream as a JPEG file and passes the frame + downstream unmodified. Files are saved sequentially with optional suffix + in the configured output directory (self.output_dir). If loop is True, + it will cycle back and overwrite images starting from the first one + after reaching the save_limit. + + Args: + frame_stream: An Observable emitting video frames as numpy arrays. + Each frame should be in BGR format with shape (height, width, 3). + suffix: Optional string to append to filename before index. + Defaults to empty string. Example: "optical" -> "optical_1.jpg" + loop: If True, reset the image counter to 1 after reaching + save_limit, effectively looping the saves. Defaults to False. + + Returns: + An Observable emitting the same frames that were saved. Returns None + for frames that could not be saved due to format issues or save_limit + (unless loop is True). + + Raises: + TypeError: If frame_stream is not an Observable. + ValueError: If frames have invalid format or output directory + is not writable. + OSError: If there are file system permission issues. + + Note: + Frames are saved as '{suffix}_{index}.jpg' where index + increments for each saved frame. Saving stops after reaching + the configured save_limit (default: 100) unless loop is True. + """ + return frame_stream.pipe( + ops.map(lambda frame: self.export_to_jpeg(frame, suffix=suffix, loop=loop)), + ) diff --git a/dimos/stream/media_provider.py b/dimos/stream/media_provider.py deleted file mode 100644 index 8dfa07e55c..0000000000 --- a/dimos/stream/media_provider.py +++ /dev/null @@ -1,149 +0,0 @@ -from time import sleep -import cv2 -import reactivex as rx -from reactivex import operators as ops -from reactivex.disposable import CompositeDisposable -from reactivex.scheduler import ThreadPoolScheduler, CurrentThreadScheduler - - -class MediaProvider: - def __init__(self, dev_name:str="NA"): - self.dev_name = dev_name - self.disposables = CompositeDisposable() - - def dispose_all(self): - """Disposes of all active subscriptions managed by this agent.""" - if self.disposables: - self.disposables.dispose() - else: - print("No disposables to dispose.") - - -# TODO: Test threading concurrency and instanciation more fully -class VideoProviderExample(MediaProvider): - def __init__(self, dev_name: str, video_source:str="/app/assets/video-f30-480p.mp4"): - super().__init__(dev_name) - self.video_source = video_source - # self.scheduler = ThreadPoolScheduler(1) # CurrentThreadScheduler - self.cap = None - - def get_capture(self): - """Ensure that the capture device is correctly initialized and open.""" - if self.cap is None or not self.cap.isOpened(): - if self.cap: - self.cap.release() - print("Released Capture") - self.cap = cv2.VideoCapture(self.video_source) - print("Opened Capture") - if not self.cap.isOpened(): - raise Exception("Failed to open video source") - return self.cap - - def video_capture_to_observable(self): - cap = self.get_capture() - - def emit_frames(observer, scheduler): - try: - while cap.isOpened(): - ret, frame = cap.read() - if ret: - observer.on_next(frame) - else: - cap.set(cv2.CAP_PROP_POS_FRAMES, 0) # If loading from a video, loop it - continue - # observer.on_completed() - # break - except Exception as e: - observer.on_error(e) - finally: - cap.release() - - return rx.create(emit_frames).pipe( - # ops.observe_on(self.scheduler), # - # ops.subscribe_on(self.scheduler), # - ops.share() - ) - - def dispose_all(self): - """Disposes of all resources.""" - if self.cap and self.cap.isOpened(): - self.cap.release() - super().dispose_all() - - def __del__(self): - """Destructor to ensure resources are cleaned up if not explicitly disposed.""" - self.dispose_all() - - - - - - -# class VideoProviderExample(MediaProvider): -# def __init__(self, dev_name: str, provider_type:str="Video", video_source:str="/app/assets/video-f30-480p.mp4"): -# super().__init__(dev_name) -# self.provider_type = provider_type -# self.video_source = video_source - -# def video_capture_to_observable(self, cap): -# """Creates an observable from a video capture source.""" -# def on_subscribe(observer, scheduler=None): - -# def read_frame(): # scheduler, state): -# while True: -# try: -# ret, frame = cap.read() -# if ret: -# observer.on_next(frame) -# # cv2.waitKey(1) -# # Reschedule reading the next frame -# #if scheduler: -# #scheduler.schedule(read_frame) -# else: -# cap.set(cv2.CAP_PROP_POS_FRAMES, 0) -# continue -# # observer.on_completed() -# # cap.release() -# except Exception as e: -# observer.on_error(e) -# cap.release() - -# # Schedule the first frame read -# #if scheduler: -# #scheduler.schedule(read_frame) -# #else: -# read_frame() # Direct call on the same thread -# return rx.create(on_subscribe).pipe( -# ops.publish(), # Convert the observable from cold to hot -# ops.ref_count() # Start emitting when the first subscriber subscribes and stop when the last unsubscribes -# ) - -# def get_capture(self): # , video_source="/app/assets/video-f30-480p.mp4"): -# # video_source = root_dir + '' # "udp://0.0.0.0:23000" # "/dev/video0" -# cap = cv2.VideoCapture(self.video_source) -# print("Opening video source") -# print(f"Source: {self.video_source}") -# if not cap.isOpened(): -# print("Failed to open video source") -# exit() -# print("Opened video source") -# return cap - -# def video_capture_to_observable(self): # , video_source="/app/assets/video-f30-480p.mp4"): -# cap = self.get_capture() -# return self.video_capture_to_observable(cap) - -# # def dispose(): -# # self.disposeables.dispose() -# # from time import sleep -# # while True: -# # sleep(1) -# # if cv2.waitKey(1) & 0xFF == ord('q'): -# # # disposable.dispose() -# # disposable_flask.dispose() -# # disposable_oai.dispose() -# # for _ in disposablables: -# # disposablables.dispose() - -# # cv2.destroyAllWindows() -# # break diff --git a/dimos/stream/ros_video_provider.py b/dimos/stream/ros_video_provider.py new file mode 100644 index 0000000000..cf842aa257 --- /dev/null +++ b/dimos/stream/ros_video_provider.py @@ -0,0 +1,111 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ROS-based video provider module. + +This module provides a video frame provider that receives frames from ROS (Robot Operating System) +and makes them available as an Observable stream. +""" + +import logging +import time + +import numpy as np +from reactivex import Observable, Subject, operators as ops +from reactivex.scheduler import ThreadPoolScheduler + +from dimos.stream.video_provider import AbstractVideoProvider + +logging.basicConfig(level=logging.INFO) + + +class ROSVideoProvider(AbstractVideoProvider): + """Video provider that uses a Subject to broadcast frames pushed by ROS. + + This class implements a video provider that receives frames from ROS and makes them + available as an Observable stream. It uses ReactiveX's Subject to broadcast frames. + + Attributes: + logger: Logger instance for this provider. + _subject: ReactiveX Subject that broadcasts frames. + _last_frame_time: Timestamp of the last received frame. + """ + + def __init__( + self, dev_name: str = "ros_video", pool_scheduler: ThreadPoolScheduler | None = None + ) -> None: + """Initialize the ROS video provider. + + Args: + dev_name: A string identifying this provider. + pool_scheduler: Optional ThreadPoolScheduler for multithreading. + """ + super().__init__(dev_name, pool_scheduler) + self.logger = logging.getLogger(dev_name) + self._subject = Subject() # type: ignore[var-annotated] + self._last_frame_time = None + self.logger.info("ROSVideoProvider initialized") + + def push_data(self, frame: np.ndarray) -> None: # type: ignore[type-arg] + """Push a new frame into the provider. + + Args: + frame: The video frame to push into the stream, typically a numpy array + containing image data. + + Raises: + Exception: If there's an error pushing the frame. + """ + try: + current_time = time.time() + if self._last_frame_time: + frame_interval = current_time - self._last_frame_time + self.logger.debug( + f"Frame interval: {frame_interval:.3f}s ({1 / frame_interval:.1f} FPS)" + ) + self._last_frame_time = current_time # type: ignore[assignment] + + self.logger.debug(f"Pushing frame type: {type(frame)}") + self._subject.on_next(frame) + self.logger.debug("Frame pushed") + except Exception as e: + self.logger.error(f"Push error: {e}") + raise + + def capture_video_as_observable(self, fps: int = 30) -> Observable: # type: ignore[type-arg] + """Return an observable of video frames. + + Args: + fps: Frames per second rate limit (default: 30; ignored for now). + + Returns: + Observable: An observable stream of video frames (numpy.ndarray objects), + with each emission containing a single video frame. The frames are + multicast to all subscribers. + + Note: + The fps parameter is currently not enforced. See implementation note below. + """ + self.logger.info(f"Creating observable with {fps} FPS rate limiting") + # TODO: Implement rate limiting using ops.throttle_with_timeout() or + # ops.sample() to restrict emissions to one frame per (1/fps) seconds. + # Example: ops.sample(1.0/fps) + return self._subject.pipe( + # Ensure subscription work happens on the thread pool + ops.subscribe_on(self.pool_scheduler), + # Ensure observer callbacks execute on the thread pool + ops.observe_on(self.pool_scheduler), + # Make the stream hot/multicast so multiple subscribers get the same frames + ops.share(), + ) diff --git a/dimos/stream/rtsp_video_provider.py b/dimos/stream/rtsp_video_provider.py new file mode 100644 index 0000000000..fb53e80dd8 --- /dev/null +++ b/dimos/stream/rtsp_video_provider.py @@ -0,0 +1,379 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""RTSP video provider using ffmpeg for robust stream handling.""" + +import subprocess +import threading +import time + +import ffmpeg # type: ignore[import-untyped] # ffmpeg-python wrapper +import numpy as np +import reactivex as rx +from reactivex import operators as ops +from reactivex.disposable import Disposable +from reactivex.observable import Observable +from reactivex.scheduler import ThreadPoolScheduler + +from dimos.utils.logging_config import setup_logger + +# Assuming AbstractVideoProvider and exceptions are in the sibling file +from .video_provider import AbstractVideoProvider, VideoFrameError, VideoSourceError + +logger = setup_logger() + + +class RtspVideoProvider(AbstractVideoProvider): + """Video provider implementation for capturing RTSP streams using ffmpeg. + + This provider uses the ffmpeg-python library to interact with ffmpeg, + providing more robust handling of various RTSP streams compared to OpenCV's + built-in VideoCapture for RTSP. + """ + + def __init__( + self, dev_name: str, rtsp_url: str, pool_scheduler: ThreadPoolScheduler | None = None + ) -> None: + """Initializes the RTSP video provider. + + Args: + dev_name: The name of the device or stream (for identification). + rtsp_url: The URL of the RTSP stream (e.g., "rtsp://user:pass@ip:port/path"). + pool_scheduler: The scheduler for thread pool operations. Defaults to global scheduler. + """ + super().__init__(dev_name, pool_scheduler) + self.rtsp_url = rtsp_url + # Holds the currently active ffmpeg process Popen object + self._ffmpeg_process: subprocess.Popen | None = None # type: ignore[type-arg] + # Lock to protect access to the ffmpeg process object + self._lock = threading.Lock() + + def _get_stream_info(self) -> dict: # type: ignore[type-arg] + """Probes the RTSP stream to get video dimensions and FPS using ffprobe.""" + logger.info(f"({self.dev_name}) Probing RTSP stream.") + try: + # Probe the stream without the problematic timeout argument + probe = ffmpeg.probe(self.rtsp_url) + except ffmpeg.Error as e: + stderr = e.stderr.decode("utf8", errors="ignore") if e.stderr else "No stderr" + msg = f"({self.dev_name}) Failed to probe RTSP stream {self.rtsp_url}: {stderr}" + logger.error(msg) + raise VideoSourceError(msg) from e + except Exception as e: + msg = f"({self.dev_name}) Unexpected error during probing {self.rtsp_url}: {e}" + logger.error(msg) + raise VideoSourceError(msg) from e + + video_stream = next( + (stream for stream in probe.get("streams", []) if stream.get("codec_type") == "video"), + None, + ) + + if video_stream is None: + msg = f"({self.dev_name}) No video stream found in {self.rtsp_url}" + logger.error(msg) + raise VideoSourceError(msg) + + width = video_stream.get("width") + height = video_stream.get("height") + fps_str = video_stream.get("avg_frame_rate", "0/1") + + if not width or not height: + msg = f"({self.dev_name}) Could not determine resolution for {self.rtsp_url}. Stream info: {video_stream}" + logger.error(msg) + raise VideoSourceError(msg) + + try: + if "/" in fps_str: + num, den = map(int, fps_str.split("/")) + fps = float(num) / den if den != 0 else 30.0 + else: + fps = float(fps_str) + if fps <= 0: + logger.warning( + f"({self.dev_name}) Invalid avg_frame_rate '{fps_str}', defaulting FPS to 30." + ) + fps = 30.0 + except ValueError: + logger.warning( + f"({self.dev_name}) Could not parse FPS '{fps_str}', defaulting FPS to 30." + ) + fps = 30.0 + + logger.info(f"({self.dev_name}) Stream info: {width}x{height} @ {fps:.2f} FPS") + return {"width": width, "height": height, "fps": fps} + + def _start_ffmpeg_process(self, width: int, height: int) -> subprocess.Popen: # type: ignore[type-arg] + """Starts the ffmpeg process to capture and decode the stream.""" + logger.info(f"({self.dev_name}) Starting ffmpeg process for rtsp stream.") + try: + # Configure ffmpeg input: prefer TCP, set timeout, reduce buffering/delay + input_options = { + "rtsp_transport": "tcp", + "stimeout": "5000000", # 5 seconds timeout for RTSP server responses + "fflags": "nobuffer", # Reduce input buffering + "flags": "low_delay", # Reduce decoding delay + # 'timeout': '10000000' # Removed: This was misinterpreted as listen timeout + } + process = ( + ffmpeg.input(self.rtsp_url, **input_options) + .output("pipe:", format="rawvideo", pix_fmt="bgr24") # Output raw BGR frames + .global_args("-loglevel", "warning") # Reduce ffmpeg log spam, use 'error' for less + .run_async(pipe_stdout=True, pipe_stderr=True) # Capture stdout and stderr + ) + logger.info(f"({self.dev_name}) ffmpeg process started (PID: {process.pid})") + return process # type: ignore[no-any-return] + except ffmpeg.Error as e: + stderr = e.stderr.decode("utf8", errors="ignore") if e.stderr else "No stderr" + msg = f"({self.dev_name}) Failed to start ffmpeg for {self.rtsp_url}: {stderr}" + logger.error(msg) + raise VideoSourceError(msg) from e + except Exception as e: # Catch other errors like ffmpeg executable not found + msg = f"({self.dev_name}) An unexpected error occurred starting ffmpeg: {e}" + logger.error(msg) + raise VideoSourceError(msg) from e + + def capture_video_as_observable(self, fps: int = 0) -> Observable: # type: ignore[type-arg] + """Creates an observable from the RTSP stream using ffmpeg. + + The observable attempts to reconnect if the stream drops. + + Args: + fps: This argument is currently ignored. The provider attempts + to use the stream's native frame rate. Future versions might + allow specifying an output FPS via ffmpeg filters. + + Returns: + Observable: An observable emitting video frames as NumPy arrays (BGR format). + + Raises: + VideoSourceError: If the stream cannot be initially probed or the + ffmpeg process fails to start. + VideoFrameError: If there's an error reading or processing frames. + """ + if fps != 0: + logger.warning( + f"({self.dev_name}) The 'fps' argument ({fps}) is currently ignored. Using stream native FPS." + ) + + def emit_frames(observer, scheduler): # type: ignore[no-untyped-def] + """Function executed by rx.create to emit frames.""" + process: subprocess.Popen | None = None # type: ignore[type-arg] + # Event to signal the processing loop should stop (e.g., on dispose) + should_stop = threading.Event() + + def cleanup_process() -> None: + """Safely terminate the ffmpeg process if it's running.""" + nonlocal process + logger.debug(f"({self.dev_name}) Cleanup requested.") + # Use lock to ensure thread safety when accessing/modifying process + with self._lock: + # Check if the process exists and is still running + if process and process.poll() is None: + logger.info( + f"({self.dev_name}) Terminating ffmpeg process (PID: {process.pid})." + ) + try: + process.terminate() # Ask ffmpeg to exit gracefully + process.wait(timeout=1.0) # Wait up to 1 second + except subprocess.TimeoutExpired: + logger.warning( + f"({self.dev_name}) ffmpeg (PID: {process.pid}) did not terminate gracefully, killing." + ) + process.kill() # Force kill if it didn't exit + except Exception as e: + logger.error(f"({self.dev_name}) Error during ffmpeg termination: {e}") + finally: + # Ensure we clear the process variable even if wait/kill fails + process = None + # Also clear the shared class attribute if this was the active process + if self._ffmpeg_process and self._ffmpeg_process.pid == process.pid: # type: ignore[attr-defined] + self._ffmpeg_process = None + elif process and process.poll() is not None: + # Process exists but already terminated + logger.debug( + f"({self.dev_name}) ffmpeg process (PID: {process.pid}) already terminated (exit code: {process.poll()})." + ) + process = None # Clear the variable + # Clear shared attribute if it matches + if self._ffmpeg_process and self._ffmpeg_process.pid == process.pid: # type: ignore[attr-defined] + self._ffmpeg_process = None + else: + # Process variable is already None or doesn't match _ffmpeg_process + logger.debug( + f"({self.dev_name}) No active ffmpeg process found needing termination in cleanup." + ) + + try: + # 1. Probe the stream to get essential info (width, height) + stream_info = self._get_stream_info() + width = stream_info["width"] + height = stream_info["height"] + # Calculate expected bytes per frame (BGR format = 3 bytes per pixel) + frame_size = width * height * 3 + + # 2. Main loop: Start ffmpeg and read frames. Retry on failure. + while not should_stop.is_set(): + try: + # Start or reuse the ffmpeg process + with self._lock: + # Check if another thread/subscription already started the process + if self._ffmpeg_process and self._ffmpeg_process.poll() is None: + logger.warning( + f"({self.dev_name}) ffmpeg process (PID: {self._ffmpeg_process.pid}) seems to be already running. Reusing." + ) + process = self._ffmpeg_process + else: + # Start a new ffmpeg process + process = self._start_ffmpeg_process(width, height) + # Store the new process handle in the shared class attribute + self._ffmpeg_process = process + + # 3. Frame reading loop + while not should_stop.is_set(): + # Read exactly one frame's worth of bytes + in_bytes = process.stdout.read(frame_size) # type: ignore[union-attr] + + if len(in_bytes) == 0: + # End of stream or process terminated unexpectedly + logger.warning( + f"({self.dev_name}) ffmpeg stdout returned 0 bytes. EOF or process terminated." + ) + process.wait(timeout=0.5) # Allow stderr to flush + stderr_data = process.stderr.read().decode("utf8", errors="ignore") # type: ignore[union-attr] + exit_code = process.poll() + logger.warning( + f"({self.dev_name}) ffmpeg process (PID: {process.pid}) exited with code {exit_code}. Stderr: {stderr_data}" + ) + # Break inner loop to trigger cleanup and potential restart + with self._lock: + # Clear the shared process handle if it matches the one that just exited + if ( + self._ffmpeg_process + and self._ffmpeg_process.pid == process.pid + ): + self._ffmpeg_process = None + process = None # Clear local process variable + break # Exit frame reading loop + + elif len(in_bytes) != frame_size: + # Received incomplete frame data - indicates a problem + msg = f"({self.dev_name}) Incomplete frame read. Expected {frame_size}, got {len(in_bytes)}. Stopping." + logger.error(msg) + observer.on_error(VideoFrameError(msg)) + should_stop.set() # Signal outer loop to stop + break # Exit frame reading loop + + # Convert the raw bytes to a NumPy array (height, width, channels) + frame = np.frombuffer(in_bytes, np.uint8).reshape((height, width, 3)) + # Emit the frame to subscribers + observer.on_next(frame) + + # 4. Handle ffmpeg process exit/crash (if not stopping deliberately) + if not should_stop.is_set() and process is None: + logger.info( + f"({self.dev_name}) ffmpeg process ended. Attempting reconnection in 5 seconds..." + ) + # Wait for a few seconds before trying to restart + time.sleep(5) + # Continue to the next iteration of the outer loop to restart + + except (VideoSourceError, ffmpeg.Error) as e: + # Errors during ffmpeg process start or severe runtime errors + logger.error( + f"({self.dev_name}) Unrecoverable ffmpeg error: {e}. Stopping emission." + ) + observer.on_error(e) + should_stop.set() # Stop retrying + except Exception as e: + # Catch other unexpected errors during frame reading/processing + logger.error( + f"({self.dev_name}) Unexpected error processing stream: {e}", + exc_info=True, + ) + observer.on_error(VideoFrameError(f"Frame processing failed: {e}")) + should_stop.set() # Stop retrying + + # 5. Loop finished (likely due to should_stop being set) + logger.info(f"({self.dev_name}) Frame emission loop stopped.") + observer.on_completed() + + except VideoSourceError as e: + # Handle errors during the initial probing phase + logger.error(f"({self.dev_name}) Failed initial setup: {e}") + observer.on_error(e) + except Exception as e: + # Catch-all for unexpected errors before the main loop starts + logger.error(f"({self.dev_name}) Unexpected setup error: {e}", exc_info=True) + observer.on_error(VideoSourceError(f"Setup failed: {e}")) + finally: + # Crucial: Ensure the ffmpeg process is terminated when the observable + # is completed, errored, or disposed. + logger.debug(f"({self.dev_name}) Entering finally block in emit_frames.") + cleanup_process() + + # Return a Disposable that, when called (by unsubscribe/dispose), + # signals the loop to stop. The finally block handles the actual cleanup. + return Disposable(should_stop.set) + + # Create the observable using rx.create, applying scheduling and sharing + return rx.create(emit_frames).pipe( + ops.subscribe_on(self.pool_scheduler), # Run the emit_frames logic on the pool + # ops.observe_on(self.pool_scheduler), # Optional: Switch thread for downstream operators + ops.share(), # Ensure multiple subscribers share the same ffmpeg process + ) + + def dispose_all(self) -> None: + """Disposes of all managed resources, including terminating the ffmpeg process.""" + logger.info(f"({self.dev_name}) dispose_all called.") + # Terminate the ffmpeg process using the same locked logic as cleanup + with self._lock: + process = self._ffmpeg_process # Get the current process handle + if process and process.poll() is None: + logger.info( + f"({self.dev_name}) Terminating ffmpeg process (PID: {process.pid}) via dispose_all." + ) + try: + process.terminate() + process.wait(timeout=1.0) + except subprocess.TimeoutExpired: + logger.warning( + f"({self.dev_name}) ffmpeg process (PID: {process.pid}) kill required in dispose_all." + ) + process.kill() + except Exception as e: + logger.error( + f"({self.dev_name}) Error during ffmpeg termination in dispose_all: {e}" + ) + finally: + self._ffmpeg_process = None # Clear handle after attempting termination + elif process: # Process exists but already terminated + logger.debug( + f"({self.dev_name}) ffmpeg process (PID: {process.pid}) already terminated in dispose_all." + ) + self._ffmpeg_process = None + else: + logger.debug( + f"({self.dev_name}) No active ffmpeg process found during dispose_all." + ) + + # Call the parent class's dispose_all to handle Rx Disposables + super().dispose_all() + + def __del__(self) -> None: + """Destructor attempts to clean up resources if not explicitly disposed.""" + # Logging in __del__ is generally discouraged due to interpreter state issues, + # but can be helpful for debugging resource leaks. Use print for robustness here if needed. + # print(f"DEBUG: ({self.dev_name}) __del__ called.") + self.dispose_all() diff --git a/dimos/stream/stream_merger.py b/dimos/stream/stream_merger.py new file mode 100644 index 0000000000..645fb86030 --- /dev/null +++ b/dimos/stream/stream_merger.py @@ -0,0 +1,45 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TypeVar + +from reactivex import Observable, operators as ops + +T = TypeVar("T") +Q = TypeVar("Q") + + +def create_stream_merger( + data_input_stream: Observable[T], text_query_stream: Observable[Q] +) -> Observable[tuple[Q, list[T]]]: + """ + Creates a merged stream that combines the latest value from data_input_stream + with each value from text_query_stream. + + Args: + data_input_stream: Observable stream of data values + text_query_stream: Observable stream of query values + + Returns: + Observable that emits tuples of (query, latest_data) + """ + # Encompass any data items as a list for safe evaluation + safe_data_stream = data_input_stream.pipe( + # We don't modify the data, just pass it through in a list + # This avoids any boolean evaluation of arrays + ops.map(lambda x: [x]) + ) + + # Use safe_data_stream instead of raw data_input_stream + return text_query_stream.pipe(ops.with_latest_from(safe_data_stream)) diff --git a/dimos/stream/video_operators.py b/dimos/stream/video_operators.py new file mode 100644 index 0000000000..0ba99b71e1 --- /dev/null +++ b/dimos/stream/video_operators.py @@ -0,0 +1,627 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +from collections.abc import Callable +from datetime import datetime, timedelta +from enum import Enum +from typing import TYPE_CHECKING, Any + +import cv2 +import numpy as np +from reactivex import Observable, Observer, create, operators as ops +import zmq + +if TYPE_CHECKING: + from dimos.stream.frame_processor import FrameProcessor + + +class VideoOperators: + """Collection of video processing operators for reactive video streams.""" + + @staticmethod + def with_fps_sampling( + fps: int = 25, *, sample_interval: timedelta | None = None, use_latest: bool = True + ) -> Callable[[Observable], Observable]: # type: ignore[type-arg] + """Creates an operator that samples frames at a specified rate. + + Creates a transformation operator that samples frames either by taking + the latest frame or the first frame in each interval. Provides frame + rate control through time-based selection. + + Args: + fps: Desired frames per second, defaults to 25 FPS. + Ignored if sample_interval is provided. + sample_interval: Optional explicit interval between samples. + If provided, overrides the fps parameter. + use_latest: If True, uses the latest frame in interval. + If False, uses the first frame. Defaults to True. + + Returns: + A function that transforms an Observable[np.ndarray] stream to a sampled + Observable[np.ndarray] stream with controlled frame rate. + + Raises: + ValueError: If fps is not positive or sample_interval is negative. + TypeError: If sample_interval is provided but not a timedelta object. + + Examples: + Sample latest frame at 30 FPS (good for real-time): + >>> video_stream.pipe( + ... VideoOperators.with_fps_sampling(fps=30) + ... ) + + Sample first frame with custom interval (good for consistent timing): + >>> video_stream.pipe( + ... VideoOperators.with_fps_sampling( + ... sample_interval=timedelta(milliseconds=40), + ... use_latest=False + ... ) + ... ) + + Note: + This operator helps manage high-speed video streams through time-based + frame selection. It reduces the frame rate by selecting frames at + specified intervals. + + When use_latest=True: + - Uses sampling to select the most recent frame at fixed intervals + - Discards intermediate frames, keeping only the latest + - Best for real-time video where latest frame is most relevant + - Uses ops.sample internally + + When use_latest=False: + - Uses throttling to select the first frame in each interval + - Ignores subsequent frames until next interval + - Best for scenarios where you want consistent frame timing + - Uses ops.throttle_first internally + + This is an approropriate solution for managing video frame rates and + memory usage in many scenarios. + """ + if sample_interval is None: + if fps <= 0: + raise ValueError("FPS must be positive") + sample_interval = timedelta(microseconds=int(1_000_000 / fps)) + + def _operator(source: Observable) -> Observable: # type: ignore[type-arg] + return source.pipe( + ops.sample(sample_interval) if use_latest else ops.throttle_first(sample_interval) + ) + + return _operator + + @staticmethod + def with_jpeg_export( + frame_processor: "FrameProcessor", + save_limit: int = 100, + suffix: str = "", + loop: bool = False, + ) -> Callable[[Observable], Observable]: # type: ignore[type-arg] + """Creates an operator that saves video frames as JPEG files. + + Creates a transformation operator that saves each frame from the video + stream as a JPEG file while passing the frame through unchanged. + + Args: + frame_processor: FrameProcessor instance that handles the JPEG export + operations and maintains file count. + save_limit: Maximum number of frames to save before stopping. + Defaults to 100. Set to 0 for unlimited saves. + suffix: Optional string to append to filename before index. + Example: "raw" creates "1_raw.jpg". + Defaults to empty string. + loop: If True, when save_limit is reached, the files saved are + loopbacked and overwritten with the most recent frame. + Defaults to False. + Returns: + A function that transforms an Observable of frames into another + Observable of the same frames, with side effect of saving JPEGs. + + Raises: + ValueError: If save_limit is negative. + TypeError: If frame_processor is not a FrameProcessor instance. + + Example: + >>> video_stream.pipe( + ... VideoOperators.with_jpeg_export(processor, suffix="raw") + ... ) + """ + + def _operator(source: Observable) -> Observable: # type: ignore[type-arg] + return source.pipe( + ops.map( + lambda frame: frame_processor.export_to_jpeg(frame, save_limit, loop, suffix) + ) + ) + + return _operator + + @staticmethod + def with_optical_flow_filtering(threshold: float = 1.0) -> Callable[[Observable], Observable]: # type: ignore[type-arg] + """Creates an operator that filters optical flow frames by relevancy score. + + Filters a stream of optical flow results (frame, relevancy_score) tuples, + passing through only frames that meet the relevancy threshold. + + Args: + threshold: Minimum relevancy score required for frames to pass through. + Defaults to 1.0. Higher values mean more motion required. + + Returns: + A function that transforms an Observable of (frame, score) tuples + into an Observable of frames that meet the threshold. + + Raises: + ValueError: If threshold is negative. + TypeError: If input stream items are not (frame, float) tuples. + + Examples: + Basic filtering: + >>> optical_flow_stream.pipe( + ... VideoOperators.with_optical_flow_filtering(threshold=1.0) + ... ) + + With custom threshold: + >>> optical_flow_stream.pipe( + ... VideoOperators.with_optical_flow_filtering(threshold=2.5) + ... ) + + Note: + Input stream should contain tuples of (frame, relevancy_score) where + frame is a numpy array and relevancy_score is a float or None. + None scores are filtered out. + """ + return lambda source: source.pipe( + ops.filter(lambda result: result[1] is not None), # type: ignore[index] + ops.filter(lambda result: result[1] > threshold), # type: ignore[index] + ops.map(lambda result: result[0]), # type: ignore[index] + ) + + @staticmethod + def with_edge_detection( + frame_processor: "FrameProcessor", + ) -> Callable[[Observable], Observable]: # type: ignore[type-arg] + return lambda source: source.pipe( + ops.map(lambda frame: frame_processor.edge_detection(frame)) # type: ignore[no-untyped-call] + ) + + @staticmethod + def with_optical_flow( + frame_processor: "FrameProcessor", + ) -> Callable[[Observable], Observable]: # type: ignore[type-arg] + return lambda source: source.pipe( + ops.scan( + lambda acc, frame: frame_processor.compute_optical_flow( # type: ignore[arg-type, return-value] + acc, # type: ignore[arg-type] + frame, # type: ignore[arg-type] + compute_relevancy=False, + ), + (None, None, None), + ), + ops.map(lambda result: result[1]), # type: ignore[index] # Extract flow component + ops.filter(lambda flow: flow is not None), + ops.map(frame_processor.visualize_flow), + ) + + @staticmethod + def with_zmq_socket( + socket: zmq.Socket, # type: ignore[type-arg] + scheduler: Any | None = None, + ) -> Callable[[Observable], Observable]: # type: ignore[type-arg] + def send_frame(frame, socket) -> None: # type: ignore[no-untyped-def] + _, img_encoded = cv2.imencode(".jpg", frame) + socket.send(img_encoded.tobytes()) + # print(f"Frame received: {frame.shape}") + + # Use a default scheduler if none is provided + if scheduler is None: + from reactivex.scheduler import ThreadPoolScheduler + + scheduler = ThreadPoolScheduler(1) # Single-threaded pool for isolation + + return lambda source: source.pipe( + ops.observe_on(scheduler), # Ensure this part runs on its own thread + ops.do_action(lambda frame: send_frame(frame, socket)), + ) + + @staticmethod + def encode_image() -> Callable[[Observable], Observable]: # type: ignore[type-arg] + """ + Operator to encode an image to JPEG format and convert it to a Base64 string. + + Returns: + A function that transforms an Observable of images into an Observable + of tuples containing the Base64 string of the encoded image and its dimensions. + """ + + def _operator(source: Observable) -> Observable: # type: ignore[type-arg] + def _encode_image(image: np.ndarray) -> tuple[str, tuple[int, int]]: # type: ignore[type-arg] + try: + width, height = image.shape[:2] + _, buffer = cv2.imencode(".jpg", image) + if buffer is None: + raise ValueError("Failed to encode image") + base64_image = base64.b64encode(buffer).decode("utf-8") + return base64_image, (width, height) + except Exception as e: + raise e + + return source.pipe(ops.map(_encode_image)) + + return _operator + + +from threading import Lock + +from reactivex import Observable +from reactivex.disposable import Disposable + + +class Operators: + @staticmethod + def exhaust_lock(process_item): # type: ignore[no-untyped-def] + """ + For each incoming item, call `process_item(item)` to get an Observable. + - If we're busy processing the previous one, skip new items. + - Use a lock to ensure concurrency safety across threads. + """ + + def _exhaust_lock(source: Observable) -> Observable: # type: ignore[type-arg] + def _subscribe(observer, scheduler=None): # type: ignore[no-untyped-def] + in_flight = False + lock = Lock() + upstream_done = False + + upstream_disp = None + active_inner_disp = None + + def dispose_all() -> None: + if upstream_disp: + upstream_disp.dispose() + if active_inner_disp: + active_inner_disp.dispose() + + def on_next(value) -> None: # type: ignore[no-untyped-def] + nonlocal in_flight, active_inner_disp + lock.acquire() + try: + if not in_flight: + in_flight = True + print("Processing new item.") + else: + print("Skipping item, already processing.") + return + finally: + lock.release() + + # We only get here if we grabbed the in_flight slot + try: + inner_source = process_item(value) + except Exception as ex: + observer.on_error(ex) + return + + def inner_on_next(ivalue) -> None: # type: ignore[no-untyped-def] + observer.on_next(ivalue) + + def inner_on_error(err) -> None: # type: ignore[no-untyped-def] + nonlocal in_flight + with lock: + in_flight = False + observer.on_error(err) + + def inner_on_completed() -> None: + nonlocal in_flight + with lock: + in_flight = False + if upstream_done: + observer.on_completed() + + # Subscribe to the inner observable + nonlocal active_inner_disp + active_inner_disp = inner_source.subscribe( + on_next=inner_on_next, + on_error=inner_on_error, + on_completed=inner_on_completed, + scheduler=scheduler, + ) + + def on_error(err) -> None: # type: ignore[no-untyped-def] + dispose_all() + observer.on_error(err) + + def on_completed() -> None: + nonlocal upstream_done + with lock: + upstream_done = True + # If we're not busy, we can end now + if not in_flight: + observer.on_completed() + + upstream_disp = source.subscribe( + on_next, on_error, on_completed, scheduler=scheduler + ) + return dispose_all + + return create(_subscribe) + + return _exhaust_lock + + @staticmethod + def exhaust_lock_per_instance(process_item, lock: Lock): # type: ignore[no-untyped-def] + """ + - For each item from upstream, call process_item(item) -> Observable. + - If a frame arrives while one is "in flight", discard it. + - 'lock' ensures we safely check/modify the 'in_flight' state in a multithreaded environment. + """ + + def _exhaust_lock(source: Observable) -> Observable: # type: ignore[type-arg] + def _subscribe(observer, scheduler=None): # type: ignore[no-untyped-def] + in_flight = False + upstream_done = False + + upstream_disp = None + active_inner_disp = None + + def dispose_all() -> None: + if upstream_disp: + upstream_disp.dispose() + if active_inner_disp: + active_inner_disp.dispose() + + def on_next(value) -> None: # type: ignore[no-untyped-def] + nonlocal in_flight, active_inner_disp + with lock: + # If not busy, claim the slot + if not in_flight: + in_flight = True + print("\033[34mProcessing new item.\033[0m") + else: + # Already processing => drop + print("\033[34mSkipping item, already processing.\033[0m") + return + + # We only get here if we acquired the slot + try: + inner_source = process_item(value) + except Exception as ex: + observer.on_error(ex) + return + + def inner_on_next(ivalue) -> None: # type: ignore[no-untyped-def] + observer.on_next(ivalue) + + def inner_on_error(err) -> None: # type: ignore[no-untyped-def] + nonlocal in_flight + with lock: + in_flight = False + print("\033[34mError in inner on error.\033[0m") + observer.on_error(err) + + def inner_on_completed() -> None: + nonlocal in_flight + with lock: + in_flight = False + print("\033[34mInner on completed.\033[0m") + if upstream_done: + observer.on_completed() + + # Subscribe to the inner Observable + nonlocal active_inner_disp + active_inner_disp = inner_source.subscribe( + on_next=inner_on_next, + on_error=inner_on_error, + on_completed=inner_on_completed, + scheduler=scheduler, + ) + + def on_error(e) -> None: # type: ignore[no-untyped-def] + dispose_all() + observer.on_error(e) + + def on_completed() -> None: + nonlocal upstream_done + with lock: + upstream_done = True + print("\033[34mOn completed.\033[0m") + if not in_flight: + observer.on_completed() + + upstream_disp = source.subscribe( + on_next=on_next, + on_error=on_error, + on_completed=on_completed, + scheduler=scheduler, + ) + + return Disposable(dispose_all) + + return create(_subscribe) + + return _exhaust_lock + + @staticmethod + def exhaust_map(project): # type: ignore[no-untyped-def] + def _exhaust_map(source: Observable): # type: ignore[no-untyped-def, type-arg] + def subscribe(observer, scheduler=None): # type: ignore[no-untyped-def] + is_processing = False + + def on_next(item) -> None: # type: ignore[no-untyped-def] + nonlocal is_processing + if not is_processing: + is_processing = True + print("\033[35mProcessing item.\033[0m") + try: + inner_observable = project(item) # Create the inner observable + inner_observable.subscribe( + on_next=observer.on_next, + on_error=observer.on_error, + on_completed=lambda: set_not_processing(), + scheduler=scheduler, + ) + except Exception as e: + observer.on_error(e) + else: + print("\033[35mSkipping item, already processing.\033[0m") + + def set_not_processing() -> None: + nonlocal is_processing + is_processing = False + print("\033[35mItem processed.\033[0m") + + return source.subscribe( + on_next=on_next, + on_error=observer.on_error, + on_completed=observer.on_completed, + scheduler=scheduler, + ) + + return create(subscribe) + + return _exhaust_map + + @staticmethod + def with_lock(lock: Lock): # type: ignore[no-untyped-def] + def operator(source: Observable): # type: ignore[no-untyped-def, type-arg] + def subscribe(observer, scheduler=None): # type: ignore[no-untyped-def] + def on_next(item) -> None: # type: ignore[no-untyped-def] + if not lock.locked(): # Check if the lock is free + if lock.acquire(blocking=False): # Non-blocking acquire + try: + print("\033[32mAcquired lock, processing item.\033[0m") + observer.on_next(item) + finally: # Ensure lock release even if observer.on_next throws + lock.release() + else: + print("\033[34mLock busy, skipping item.\033[0m") + else: + print("\033[34mLock busy, skipping item.\033[0m") + + def on_error(error) -> None: # type: ignore[no-untyped-def] + observer.on_error(error) + + def on_completed() -> None: + observer.on_completed() + + return source.subscribe( + on_next=on_next, + on_error=on_error, + on_completed=on_completed, + scheduler=scheduler, + ) + + return Observable(subscribe) + + return operator + + @staticmethod + def with_lock_check(lock: Lock): # type: ignore[no-untyped-def] # Renamed for clarity + def operator(source: Observable): # type: ignore[no-untyped-def, type-arg] + def subscribe(observer, scheduler=None): # type: ignore[no-untyped-def] + def on_next(item) -> None: # type: ignore[no-untyped-def] + if not lock.locked(): # Check if the lock is held WITHOUT acquiring + print(f"\033[32mLock is free, processing item: {item}\033[0m") + observer.on_next(item) + else: + print(f"\033[34mLock is busy, skipping item: {item}\033[0m") + # observer.on_completed() + + def on_error(error) -> None: # type: ignore[no-untyped-def] + observer.on_error(error) + + def on_completed() -> None: + observer.on_completed() + + return source.subscribe( + on_next=on_next, + on_error=on_error, + on_completed=on_completed, + scheduler=scheduler, + ) + + return Observable(subscribe) + + return operator + + # PrintColor enum for standardized color formatting + class PrintColor(Enum): + RED = "\033[31m" + GREEN = "\033[32m" + BLUE = "\033[34m" + YELLOW = "\033[33m" + MAGENTA = "\033[35m" + CYAN = "\033[36m" + WHITE = "\033[37m" + RESET = "\033[0m" + + @staticmethod + def print_emission( # type: ignore[no-untyped-def] + id: str, + dev_name: str = "NA", + counts: dict | None = None, # type: ignore[type-arg] + color: "Operators.PrintColor" = None, # type: ignore[assignment] + enabled: bool = True, + ): + """ + Creates an operator that prints the emission with optional counts for debugging. + + Args: + id: Identifier for the emission point (e.g., 'A', 'B') + dev_name: Device or component name for context + counts: External dictionary to track emission count across operators. If None, will not print emission count. + color: Color for the printed output from PrintColor enum (default is RED) + enabled: Whether to print the emission count (default is True) + Returns: + An operator that counts and prints emissions without modifying the stream + """ + # If enabled is false, return the source unchanged + if not enabled: + return lambda source: source + + # Use RED as default if no color provided + if color is None: + color = Operators.PrintColor.RED + + def _operator(source: Observable) -> Observable: # type: ignore[type-arg] + def _subscribe(observer: Observer, scheduler=None): # type: ignore[no-untyped-def, type-arg] + def on_next(value) -> None: # type: ignore[no-untyped-def] + if counts is not None: + # Initialize count if necessary + if id not in counts: + counts[id] = 0 + + # Increment and print + counts[id] += 1 + print( + f"{color.value}({dev_name} - {id}) Emission Count - {counts[id]} {datetime.now()}{Operators.PrintColor.RESET.value}" + ) + else: + print( + f"{color.value}({dev_name} - {id}) Emitted - {datetime.now()}{Operators.PrintColor.RESET.value}" + ) + + # Pass value through unchanged + observer.on_next(value) + + return source.subscribe( + on_next=on_next, + on_error=observer.on_error, + on_completed=observer.on_completed, + scheduler=scheduler, + ) + + return create(_subscribe) # type: ignore[arg-type] + + return _operator diff --git a/dimos/stream/video_provider.py b/dimos/stream/video_provider.py new file mode 100644 index 0000000000..38406fd5a5 --- /dev/null +++ b/dimos/stream/video_provider.py @@ -0,0 +1,234 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Video provider module for capturing and streaming video frames. + +This module provides classes for capturing video from various sources and +exposing them as reactive observables. It handles resource management, +frame rate control, and thread safety. +""" + +# Standard library imports +from abc import ABC, abstractmethod +import logging +import os +from threading import Lock +import time + +# Third-party imports +import cv2 +import reactivex as rx +from reactivex import operators as ops +from reactivex.disposable import CompositeDisposable +from reactivex.observable import Observable +from reactivex.scheduler import ThreadPoolScheduler + +# Local imports +from dimos.utils.threadpool import get_scheduler + +# Note: Logging configuration should ideally be in the application initialization, +# not in a module. Keeping it for now but with a more restricted scope. +logger = logging.getLogger(__name__) + + +# Specific exception classes +class VideoSourceError(Exception): + """Raised when there's an issue with the video source.""" + + pass + + +class VideoFrameError(Exception): + """Raised when there's an issue with frame acquisition.""" + + pass + + +class AbstractVideoProvider(ABC): + """Abstract base class for video providers managing video capture resources.""" + + def __init__( + self, dev_name: str = "NA", pool_scheduler: ThreadPoolScheduler | None = None + ) -> None: + """Initializes the video provider with a device name. + + Args: + dev_name: The name of the device. Defaults to "NA". + pool_scheduler: The scheduler to use for thread pool operations. + If None, the global scheduler from get_scheduler() will be used. + """ + self.dev_name = dev_name + self.pool_scheduler = pool_scheduler if pool_scheduler else get_scheduler() + self.disposables = CompositeDisposable() + + @abstractmethod + def capture_video_as_observable(self, fps: int = 30) -> Observable: # type: ignore[type-arg] + """Create an observable from video capture. + + Args: + fps: Frames per second to emit. Defaults to 30fps. + + Returns: + Observable: An observable emitting frames at the specified rate. + + Raises: + VideoSourceError: If the video source cannot be opened. + VideoFrameError: If frames cannot be read properly. + """ + pass + + def dispose_all(self) -> None: + """Disposes of all active subscriptions managed by this provider.""" + if self.disposables: + self.disposables.dispose() + else: + logger.info("No disposables to dispose.") + + def __del__(self) -> None: + """Destructor to ensure resources are cleaned up if not explicitly disposed.""" + self.dispose_all() + + +class VideoProvider(AbstractVideoProvider): + """Video provider implementation for capturing video as an observable.""" + + def __init__( + self, + dev_name: str, + video_source: str = f"{os.getcwd()}/assets/video-f30-480p.mp4", + pool_scheduler: ThreadPoolScheduler | None = None, + ) -> None: + """Initializes the video provider with a device name and video source. + + Args: + dev_name: The name of the device. + video_source: The path to the video source. Defaults to a sample video. + pool_scheduler: The scheduler to use for thread pool operations. + If None, the global scheduler from get_scheduler() will be used. + """ + super().__init__(dev_name, pool_scheduler) + self.video_source = video_source + self.cap = None + self.lock = Lock() + + def _initialize_capture(self) -> None: + """Initializes the video capture object if not already initialized. + + Raises: + VideoSourceError: If the video source cannot be opened. + """ + if self.cap is None or not self.cap.isOpened(): + # Release previous capture if it exists but is closed + if self.cap: + self.cap.release() + logger.info("Released previous capture") + + # Attempt to open new capture + self.cap = cv2.VideoCapture(self.video_source) # type: ignore[assignment] + if self.cap is None or not self.cap.isOpened(): + error_msg = f"Failed to open video source: {self.video_source}" + logger.error(error_msg) + raise VideoSourceError(error_msg) + + logger.info(f"Opened new capture: {self.video_source}") + + def capture_video_as_observable(self, realtime: bool = True, fps: int = 30) -> Observable: # type: ignore[override, type-arg] + """Creates an observable from video capture. + + Creates an observable that emits frames at specified FPS or the video's + native FPS, with proper resource management and error handling. + + Args: + realtime: If True, use the video's native FPS. Defaults to True. + fps: Frames per second to emit. Defaults to 30fps. Only used if + realtime is False or the video's native FPS is not available. + + Returns: + Observable: An observable emitting frames at the configured rate. + + Raises: + VideoSourceError: If the video source cannot be opened. + VideoFrameError: If frames cannot be read properly. + """ + + def emit_frames(observer, scheduler) -> None: # type: ignore[no-untyped-def] + try: + self._initialize_capture() + + # Determine the FPS to use based on configuration and availability + local_fps: float = fps + if realtime: + native_fps: float = self.cap.get(cv2.CAP_PROP_FPS) # type: ignore[attr-defined] + if native_fps > 0: + local_fps = native_fps + else: + logger.warning("Native FPS not available, defaulting to specified FPS") + + frame_interval: float = 1.0 / local_fps + frame_time: float = time.monotonic() + + while self.cap.isOpened(): # type: ignore[attr-defined] + # Thread-safe access to video capture + with self.lock: + ret, frame = self.cap.read() # type: ignore[attr-defined] + + if not ret: + # Loop video when we reach the end + logger.warning("End of video reached, restarting playback") + with self.lock: + self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0) # type: ignore[attr-defined] + continue + + # Control frame rate to match target FPS + now: float = time.monotonic() + next_frame_time: float = frame_time + frame_interval + sleep_time: float = next_frame_time - now + + if sleep_time > 0: + time.sleep(sleep_time) + + observer.on_next(frame) + frame_time = next_frame_time + + except VideoSourceError as e: + logger.error(f"Video source error: {e}") + observer.on_error(e) + except Exception as e: + logger.error(f"Unexpected error during frame emission: {e}") + observer.on_error(VideoFrameError(f"Frame acquisition failed: {e}")) + finally: + # Clean up resources regardless of success or failure + with self.lock: + if self.cap and self.cap.isOpened(): + self.cap.release() + logger.info("Capture released") + observer.on_completed() + + return rx.create(emit_frames).pipe( # type: ignore[arg-type] + ops.subscribe_on(self.pool_scheduler), + ops.observe_on(self.pool_scheduler), + ops.share(), # Share the stream among multiple subscribers + ) + + def dispose_all(self) -> None: + """Disposes of all resources including video capture.""" + with self.lock: + if self.cap and self.cap.isOpened(): + self.cap.release() + logger.info("Capture released in dispose_all") + super().dispose_all() + + def __del__(self) -> None: + """Destructor to ensure resources are cleaned up if not explicitly disposed.""" + self.dispose_all() diff --git a/dimos/stream/video_providers/__init__.py b/dimos/stream/video_providers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dimos/stream/videostream.py b/dimos/stream/videostream.py deleted file mode 100644 index f501846c82..0000000000 --- a/dimos/stream/videostream.py +++ /dev/null @@ -1,141 +0,0 @@ -from datetime import timedelta -import cv2 -import numpy as np -import os -from reactivex import Observable -from reactivex import operators as ops - -class StreamUtils: - def limit_emission_rate(frame_stream, time_delta=timedelta(milliseconds=40)): - return frame_stream.pipe( - ops.throttle_first(time_delta) - ) - - -# TODO: Reorganize, filenaming -class FrameProcessor: - def __init__(self, output_dir='/app/assets/frames'): - self.output_dir = output_dir - os.makedirs(self.output_dir, exist_ok=True) - self.image_count = 0 - # TODO: Add randomness to jpg folder storage naming. - # Will overwrite between sessions. - - def to_grayscale(self, frame): - if frame is None: - print("Received None frame for grayscale conversion.") - return None - return cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) - - def edge_detection(self, frame): - return cv2.Canny(frame, 100, 200) - - def resize(self, frame, scale=0.5): - return cv2.resize(frame, None, fx=scale, fy=scale, interpolation=cv2.INTER_AREA) - - def export_to_jpeg(self, frame, save_limit=100, suffix=""): - if frame is None: - print("Error: Attempted to save a None image.") - return None - - # Check if the image has an acceptable number of channels - if len(frame.shape) == 3 and frame.shape[2] not in [1, 3, 4]: - print(f"Error: Frame with shape {frame.shape} has unsupported number of channels.") - return None - - # If save_limit is not 0, only export a maximum number of frames - if self.image_count > save_limit: - return frame - - filepath = os.path.join(self.output_dir, f'{suffix}_image_{self.image_count}.jpg') - cv2.imwrite(filepath, frame) - self.image_count += 1 - return frame - - def compute_optical_flow(self, acc, current_frame): - prev_frame, _ = acc # acc (accumulator) contains the previous frame and its flow (which is ignored here) - - if prev_frame is None: - # Skip processing for the first frame as there's no previous frame to compare against. - return (current_frame, None) - - # Convert frames to grayscale (if not already done) - gray_current = self.to_grayscale(current_frame) - gray_prev = self.to_grayscale(prev_frame) - - # Compute optical flow - flow = cv2.calcOpticalFlowFarneback(gray_prev, gray_current, None, 0.5, 3, 15, 3, 5, 1.2, 0) - - # Relevancy calulation (average magnitude of flow vectors) - mag, _ = cv2.cartToPolar(flow[..., 0], flow[..., 1]) - relevancy = np.mean(mag) - - # Return the current frame as the new previous frame and the processed optical flow, with relevancy score - return (current_frame, flow, relevancy) - - def visualize_flow(self, flow): - if flow is None: - return None - hsv = np.zeros((flow.shape[0], flow.shape[1], 3), dtype=np.uint8) - hsv[..., 1] = 255 - mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1]) - hsv[..., 0] = ang * 180 / np.pi / 2 - hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX) - rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR) - return rgb - - # ============================== - - def process_stream_edge_detection(self, frame_stream): - return frame_stream.pipe( - ops.map(self.edge_detection), - ) - - def process_stream_resize(self, frame_stream): - return frame_stream.pipe( - ops.map(self.resize), - ) - - def process_stream_to_greyscale(self, frame_stream): - return frame_stream.pipe( - ops.map(self.to_grayscale), - ) - - # TODO: Propogate up relevancy score from compute_optical_flow - def process_stream_optical_flow(self, frame_stream): - return frame_stream.pipe( - ops.scan(self.compute_optical_flow, (None, None)), # Initial value for scan is (None, None) - ops.map(lambda result: result[1]), # Extract only the flow part from the tuple - ops.filter(lambda flow: flow is not None), - ops.map(self.visualize_flow), - ) - - def process_stream_export_to_jpeg(self, frame_stream, suffix=""): - return frame_stream.pipe( - ops.map(lambda frame: self.export_to_jpeg(frame, suffix=suffix)), - ) - -class VideoStream: - def __init__(self, source=0): - """ - Initialize the video stream from a camera source. - - Args: - source (int or str): Camera index or video file path. - """ - self.capture = cv2.VideoCapture(source) - if not self.capture.isOpened(): - raise ValueError(f"Unable to open video source {source}") - - def __iter__(self): - return self - - def __next__(self): - ret, frame = self.capture.read() - if not ret: - self.capture.release() - raise StopIteration - return frame - - def release(self): - self.capture.release() \ No newline at end of file diff --git a/dimos/types/constants.py b/dimos/types/constants.py new file mode 100644 index 0000000000..b02726cb0b --- /dev/null +++ b/dimos/types/constants.py @@ -0,0 +1,24 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class Colors: + GREEN_PRINT_COLOR: str = "\033[32m" + YELLOW_PRINT_COLOR: str = "\033[33m" + RED_PRINT_COLOR: str = "\033[31m" + BLUE_PRINT_COLOR: str = "\033[34m" + MAGENTA_PRINT_COLOR: str = "\033[35m" + CYAN_PRINT_COLOR: str = "\033[36m" + WHITE_PRINT_COLOR: str = "\033[37m" + RESET_COLOR: str = "\033[0m" diff --git a/dimos/types/manipulation.py b/dimos/types/manipulation.py new file mode 100644 index 0000000000..507b9e9b85 --- /dev/null +++ b/dimos/types/manipulation.py @@ -0,0 +1,168 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC +from dataclasses import dataclass, field +from enum import Enum +import time +from typing import TYPE_CHECKING, Any, Literal, TypedDict +import uuid + +import numpy as np + +from dimos.types.vector import Vector + +if TYPE_CHECKING: + import open3d as o3d # type: ignore[import-untyped] + + +class ConstraintType(Enum): + """Types of manipulation constraints.""" + + TRANSLATION = "translation" + ROTATION = "rotation" + FORCE = "force" + + +@dataclass +class AbstractConstraint(ABC): + """Base class for all manipulation constraints.""" + + description: str = "" + id: str = field(default_factory=lambda: str(uuid.uuid4())[:8]) + + +@dataclass +class TranslationConstraint(AbstractConstraint): + """Constraint parameters for translational movement along a single axis.""" + + translation_axis: Literal["x", "y", "z"] = None # type: ignore[assignment] # Axis to translate along + reference_point: Vector | None = None + bounds_min: Vector | None = None # For bounded translation + bounds_max: Vector | None = None # For bounded translation + target_point: Vector | None = None # For relative positioning + + +@dataclass +class RotationConstraint(AbstractConstraint): + """Constraint parameters for rotational movement around a single axis.""" + + rotation_axis: Literal["roll", "pitch", "yaw"] = None # type: ignore[assignment] # Axis to rotate around + start_angle: Vector | None = None # Angle values applied to the specified rotation axis + end_angle: Vector | None = None # Angle values applied to the specified rotation axis + pivot_point: Vector | None = None # Point of rotation + secondary_pivot_point: Vector | None = None # For double point rotations + + +@dataclass +class ForceConstraint(AbstractConstraint): + """Constraint parameters for force application.""" + + max_force: float = 0.0 # Maximum force in newtons + min_force: float = 0.0 # Minimum force in newtons + force_direction: Vector | None = None # Direction of force application + + +class ObjectData(TypedDict, total=False): + """Data about an object in the manipulation scene.""" + + # Basic detection information + object_id: int # Unique ID for the object + bbox: list[float] # Bounding box [x1, y1, x2, y2] + depth: float # Depth in meters from Metric3d + confidence: float # Detection confidence + class_id: int # Class ID from the detector + label: str # Semantic label (e.g., 'cup', 'table') + movement_tolerance: float # (0.0 = immovable, 1.0 = freely movable) + segmentation_mask: np.ndarray # type: ignore[type-arg] # Binary mask of the object's pixels + + # 3D pose and dimensions + position: dict[str, float] | Vector # 3D position {x, y, z} or Vector + rotation: dict[str, float] | Vector # 3D rotation {roll, pitch, yaw} or Vector + size: dict[str, float] # Object dimensions {width, height, depth} + + # Point cloud data + point_cloud: "o3d.geometry.PointCloud" # Open3D point cloud object + point_cloud_numpy: np.ndarray # type: ignore[type-arg] # Nx6 array of XYZRGB points + color: np.ndarray # type: ignore[type-arg] # RGB color for visualization [R, G, B] + + +class ManipulationMetadata(TypedDict, total=False): + """Typed metadata for manipulation constraints.""" + + timestamp: float + objects: dict[str, ObjectData] + + +@dataclass +class ManipulationTaskConstraint: + """Set of constraints for a specific manipulation action.""" + + constraints: list[AbstractConstraint] = field(default_factory=list) + + def add_constraint(self, constraint: AbstractConstraint) -> None: + """Add a constraint to this set.""" + if constraint not in self.constraints: + self.constraints.append(constraint) + + def get_constraints(self) -> list[AbstractConstraint]: + """Get all constraints in this set.""" + return self.constraints + + +@dataclass +class ManipulationTask: + """Complete definition of a manipulation task.""" + + description: str + target_object: str # Semantic label of target object + target_point: tuple[float, float] | None = ( + None # (X,Y) point in pixel-space of the point to manipulate on target object + ) + metadata: ManipulationMetadata = field(default_factory=dict) # type: ignore[assignment] + timestamp: float = field(default_factory=time.time) + task_id: str = "" + result: dict[str, Any] | None = None # Any result data from the task execution + constraints: list[AbstractConstraint] | ManipulationTaskConstraint | AbstractConstraint = field( + default_factory=list + ) + + def add_constraint(self, constraint: AbstractConstraint) -> None: + """Add a constraint to this manipulation task.""" + # If constraints is a ManipulationTaskConstraint object + if isinstance(self.constraints, ManipulationTaskConstraint): + self.constraints.add_constraint(constraint) + return + + # If constraints is a single AbstractConstraint, convert to list + if isinstance(self.constraints, AbstractConstraint): + self.constraints = [self.constraints, constraint] + return + + # If constraints is a list, append to it + # This will also handle empty lists (the default case) + self.constraints.append(constraint) + + def get_constraints(self) -> list[AbstractConstraint]: + """Get all constraints in this manipulation task.""" + # If constraints is a ManipulationTaskConstraint object + if isinstance(self.constraints, ManipulationTaskConstraint): + return self.constraints.get_constraints() + + # If constraints is a single AbstractConstraint, return as list + if isinstance(self.constraints, AbstractConstraint): + return [self.constraints] + + # If constraints is a list (including empty list), return it + return self.constraints diff --git a/dimos/types/media_provider.py b/dimos/types/media_provider.py deleted file mode 100644 index 8dfa07e55c..0000000000 --- a/dimos/types/media_provider.py +++ /dev/null @@ -1,149 +0,0 @@ -from time import sleep -import cv2 -import reactivex as rx -from reactivex import operators as ops -from reactivex.disposable import CompositeDisposable -from reactivex.scheduler import ThreadPoolScheduler, CurrentThreadScheduler - - -class MediaProvider: - def __init__(self, dev_name:str="NA"): - self.dev_name = dev_name - self.disposables = CompositeDisposable() - - def dispose_all(self): - """Disposes of all active subscriptions managed by this agent.""" - if self.disposables: - self.disposables.dispose() - else: - print("No disposables to dispose.") - - -# TODO: Test threading concurrency and instanciation more fully -class VideoProviderExample(MediaProvider): - def __init__(self, dev_name: str, video_source:str="/app/assets/video-f30-480p.mp4"): - super().__init__(dev_name) - self.video_source = video_source - # self.scheduler = ThreadPoolScheduler(1) # CurrentThreadScheduler - self.cap = None - - def get_capture(self): - """Ensure that the capture device is correctly initialized and open.""" - if self.cap is None or not self.cap.isOpened(): - if self.cap: - self.cap.release() - print("Released Capture") - self.cap = cv2.VideoCapture(self.video_source) - print("Opened Capture") - if not self.cap.isOpened(): - raise Exception("Failed to open video source") - return self.cap - - def video_capture_to_observable(self): - cap = self.get_capture() - - def emit_frames(observer, scheduler): - try: - while cap.isOpened(): - ret, frame = cap.read() - if ret: - observer.on_next(frame) - else: - cap.set(cv2.CAP_PROP_POS_FRAMES, 0) # If loading from a video, loop it - continue - # observer.on_completed() - # break - except Exception as e: - observer.on_error(e) - finally: - cap.release() - - return rx.create(emit_frames).pipe( - # ops.observe_on(self.scheduler), # - # ops.subscribe_on(self.scheduler), # - ops.share() - ) - - def dispose_all(self): - """Disposes of all resources.""" - if self.cap and self.cap.isOpened(): - self.cap.release() - super().dispose_all() - - def __del__(self): - """Destructor to ensure resources are cleaned up if not explicitly disposed.""" - self.dispose_all() - - - - - - -# class VideoProviderExample(MediaProvider): -# def __init__(self, dev_name: str, provider_type:str="Video", video_source:str="/app/assets/video-f30-480p.mp4"): -# super().__init__(dev_name) -# self.provider_type = provider_type -# self.video_source = video_source - -# def video_capture_to_observable(self, cap): -# """Creates an observable from a video capture source.""" -# def on_subscribe(observer, scheduler=None): - -# def read_frame(): # scheduler, state): -# while True: -# try: -# ret, frame = cap.read() -# if ret: -# observer.on_next(frame) -# # cv2.waitKey(1) -# # Reschedule reading the next frame -# #if scheduler: -# #scheduler.schedule(read_frame) -# else: -# cap.set(cv2.CAP_PROP_POS_FRAMES, 0) -# continue -# # observer.on_completed() -# # cap.release() -# except Exception as e: -# observer.on_error(e) -# cap.release() - -# # Schedule the first frame read -# #if scheduler: -# #scheduler.schedule(read_frame) -# #else: -# read_frame() # Direct call on the same thread -# return rx.create(on_subscribe).pipe( -# ops.publish(), # Convert the observable from cold to hot -# ops.ref_count() # Start emitting when the first subscriber subscribes and stop when the last unsubscribes -# ) - -# def get_capture(self): # , video_source="/app/assets/video-f30-480p.mp4"): -# # video_source = root_dir + '' # "udp://0.0.0.0:23000" # "/dev/video0" -# cap = cv2.VideoCapture(self.video_source) -# print("Opening video source") -# print(f"Source: {self.video_source}") -# if not cap.isOpened(): -# print("Failed to open video source") -# exit() -# print("Opened video source") -# return cap - -# def video_capture_to_observable(self): # , video_source="/app/assets/video-f30-480p.mp4"): -# cap = self.get_capture() -# return self.video_capture_to_observable(cap) - -# # def dispose(): -# # self.disposeables.dispose() -# # from time import sleep -# # while True: -# # sleep(1) -# # if cv2.waitKey(1) & 0xFF == ord('q'): -# # # disposable.dispose() -# # disposable_flask.dispose() -# # disposable_oai.dispose() -# # for _ in disposablables: -# # disposablables.dispose() - -# # cv2.destroyAllWindows() -# # break diff --git a/dimos/types/robot_capabilities.py b/dimos/types/robot_capabilities.py new file mode 100644 index 0000000000..9a3f5da14e --- /dev/null +++ b/dimos/types/robot_capabilities.py @@ -0,0 +1,27 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Robot capabilities module for defining robot functionality.""" + +from enum import Enum, auto + + +class RobotCapability(Enum): + """Enum defining possible robot capabilities.""" + + MANIPULATION = auto() + VISION = auto() + AUDIO = auto() + SPEECH = auto() + LOCOMOTION = auto() diff --git a/dimos/types/robot_location.py b/dimos/types/robot_location.py new file mode 100644 index 0000000000..78077092f8 --- /dev/null +++ b/dimos/types/robot_location.py @@ -0,0 +1,138 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +RobotLocation type definition for storing and managing robot location data. +""" + +from dataclasses import dataclass, field +import time +from typing import Any +import uuid + + +@dataclass +class RobotLocation: + """ + Represents a named location in the robot's spatial memory. + + This class stores the position, rotation, and descriptive metadata for + locations that the robot can remember and navigate to. + + Attributes: + name: Human-readable name of the location (e.g., "kitchen", "office") + position: 3D position coordinates (x, y, z) + rotation: 3D rotation angles in radians (roll, pitch, yaw) + frame_id: ID of the associated video frame if available + timestamp: Time when the location was recorded + location_id: Unique identifier for this location + metadata: Additional metadata for the location + """ + + name: str + position: tuple[float, float, float] + rotation: tuple[float, float, float] + frame_id: str | None = None + timestamp: float = field(default_factory=time.time) + location_id: str = field(default_factory=lambda: f"loc_{uuid.uuid4().hex[:8]}") + metadata: dict[str, Any] = field(default_factory=dict) + + def __post_init__(self) -> None: + """Validate and normalize the position and rotation tuples.""" + # Ensure position is a tuple of 3 floats + if len(self.position) == 2: + self.position = (self.position[0], self.position[1], 0.0) + else: + self.position = tuple(float(x) for x in self.position) # type: ignore[assignment] + + # Ensure rotation is a tuple of 3 floats + if len(self.rotation) == 1: + self.rotation = (0.0, 0.0, self.rotation[0]) + else: + self.rotation = tuple(float(x) for x in self.rotation) # type: ignore[assignment] + + def to_vector_metadata(self) -> dict[str, Any]: + """ + Convert the location to metadata format for storing in a vector database. + + Returns: + Dictionary with metadata fields compatible with vector DB storage + """ + metadata = { + "pos_x": float(self.position[0]), + "pos_y": float(self.position[1]), + "pos_z": float(self.position[2]), + "rot_x": float(self.rotation[0]), + "rot_y": float(self.rotation[1]), + "rot_z": float(self.rotation[2]), + "timestamp": self.timestamp, + "location_id": self.location_id, + "location_name": self.name, + "description": self.name, # Makes it searchable by text + } + + # Only add frame_id if it's not None + if self.frame_id is not None: + metadata["frame_id"] = self.frame_id + + return metadata + + @classmethod + def from_vector_metadata(cls, metadata: dict[str, Any]) -> "RobotLocation": + """ + Create a RobotLocation object from vector database metadata. + + Args: + metadata: Dictionary with metadata from vector database + + Returns: + RobotLocation object + """ + return cls( + name=metadata.get("location_name", "unknown"), + position=( + metadata.get("pos_x", 0.0), + metadata.get("pos_y", 0.0), + metadata.get("pos_z", 0.0), + ), + rotation=( + metadata.get("rot_x", 0.0), + metadata.get("rot_y", 0.0), + metadata.get("rot_z", 0.0), + ), + frame_id=metadata.get("frame_id"), + timestamp=metadata.get("timestamp", time.time()), + location_id=metadata.get("location_id", f"loc_{uuid.uuid4().hex[:8]}"), + metadata={ + k: v + for k, v in metadata.items() + if k + not in [ + "pos_x", + "pos_y", + "pos_z", + "rot_x", + "rot_y", + "rot_z", + "timestamp", + "location_id", + "frame_id", + "location_name", + "description", + ] + }, + ) + + def __str__(self) -> str: + return f"[RobotPosition name:{self.name} pos:{self.position} rot:{self.rotation})]" diff --git a/dimos/types/ros_polyfill.py b/dimos/types/ros_polyfill.py new file mode 100644 index 0000000000..211c94dd49 --- /dev/null +++ b/dimos/types/ros_polyfill.py @@ -0,0 +1,43 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +try: + from geometry_msgs.msg import Vector3 # type: ignore[attr-defined] +except ImportError: + from dimos.msgs.geometry_msgs import Vector3 + +try: + from geometry_msgs.msg import Point, Pose, Quaternion, Twist # type: ignore[attr-defined] + from nav_msgs.msg import OccupancyGrid, Odometry # type: ignore[attr-defined] + from std_msgs.msg import Header # type: ignore[attr-defined] +except ImportError: + from dimos_lcm.geometry_msgs import ( # type: ignore[import-untyped, no-redef] + Point, + Pose, + Quaternion, + Twist, + ) + from dimos_lcm.nav_msgs import OccupancyGrid, Odometry # type: ignore[import-untyped, no-redef] + from dimos_lcm.std_msgs import Header # type: ignore[import-untyped, no-redef] + +__all__ = [ + "Header", + "OccupancyGrid", + "Odometry", + "Point", + "Pose", + "Quaternion", + "Twist", + "Vector3", +] diff --git a/dimos/types/sample.py b/dimos/types/sample.py index eab963cde8..50c51040fe 100644 --- a/dimos/types/sample.py +++ b/dimos/types/sample.py @@ -1,21 +1,36 @@ -import json -import logging +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import builtins from collections import OrderedDict +from collections.abc import Sequence from enum import Enum +import json +import logging from pathlib import Path -from typing import Any, Dict, List, Literal, Sequence, Union, get_origin +from typing import Annotated, Any, Literal, Union, get_origin +from datasets import Dataset # type: ignore[import-not-found] +from gymnasium import spaces # type: ignore[import-not-found] +from jsonref import replace_refs # type: ignore[import-not-found] +from mbodied.data.utils import to_features # type: ignore[import-not-found] +from mbodied.utils.import_utils import smart_import # type: ignore[import-not-found] import numpy as np -from datasets import Dataset -from gymnasium import spaces -from jsonref import replace_refs from pydantic import BaseModel, ConfigDict, ValidationError from pydantic.fields import FieldInfo from pydantic_core import from_json -from typing_extensions import Annotated - -from mbodied.data.utils import to_features -from mbodied.utils.import_utils import smart_import +import torch Flattenable = Annotated[Literal["dict", "np", "pt", "list"], "Numpy, PyTorch, list, or dict"] @@ -59,7 +74,7 @@ class Sample(BaseModel): __doc__ = "A base model class for serializing, recording, and manipulating arbitray data." - model_config: ConfigDict = ConfigDict( + model_config: ConfigDict = ConfigDict( # type: ignore[misc] use_enum_values=False, from_attributes=True, validate_assignment=False, @@ -67,7 +82,7 @@ class Sample(BaseModel): arbitrary_types_allowed=True, ) - def __init__(self, datum=None, **data): + def __init__(self, datum=None, **data) -> None: # type: ignore[no-untyped-def] """Accepts an arbitrary datum as well as keyword arguments.""" if datum is not None: if isinstance(datum, Sample): @@ -86,7 +101,7 @@ def __str__(self) -> str: """Return a string representation of the Sample instance.""" return f"{self.__class__.__name__}({', '.join([f'{k}={v}' for k, v in self.dict().items() if v is not None])})" - def dict(self, exclude_none=True, exclude: set[str] = None) -> Dict[str, Any]: + def dict(self, exclude_none: bool = True, exclude: set[str] | None = None) -> dict[str, Any]: # type: ignore[override] """Return the Sample object as a dictionary with None values excluded. Args: @@ -99,7 +114,7 @@ def dict(self, exclude_none=True, exclude: set[str] = None) -> Dict[str, Any]: return self.model_dump(exclude_none=exclude_none, exclude=exclude) @classmethod - def unflatten(cls, one_d_array_or_dict, schema=None) -> "Sample": + def unflatten(cls, one_d_array_or_dict, schema=None) -> "Sample": # type: ignore[no-untyped-def] """Unflatten a one-dimensional array or dictionary into a Sample instance. If a dictionary is provided, its keys are ignored. @@ -128,7 +143,7 @@ def unflatten(cls, one_d_array_or_dict, schema=None) -> "Sample": else: flat_data = list(one_d_array_or_dict) - def unflatten_recursive(schema_part, index=0): + def unflatten_recursive(schema_part, index: int = 0): # type: ignore[no-untyped-def] if schema_part["type"] == "object": result = {} for prop, prop_schema in schema_part["properties"].items(): @@ -151,10 +166,10 @@ def flatten( self, output_type: Flattenable = "dict", non_numerical: Literal["ignore", "forbid", "allow"] = "allow", - ) -> Dict[str, Any] | np.ndarray | "torch.Tensor" | List: - accumulator = {} if output_type == "dict" else [] + ) -> builtins.dict[str, Any] | np.ndarray | torch.Tensor | list: # type: ignore[type-arg] + accumulator = {} if output_type == "dict" else [] # type: ignore[var-annotated] - def flatten_recursive(obj, path=""): + def flatten_recursive(obj, path: str = "") -> None: # type: ignore[no-untyped-def] if isinstance(obj, Sample): for k, v in obj.dict().items(): flatten_recursive(v, path + k + "/") @@ -168,31 +183,33 @@ def flatten_recursive(obj, path=""): flat_list = obj.flatten().tolist() if output_type == "dict": # Convert to list for dict storage - accumulator[path[:-1]] = flat_list + accumulator[path[:-1]] = flat_list # type: ignore[index] else: - accumulator.extend(flat_list) + accumulator.extend(flat_list) # type: ignore[attr-defined] else: if non_numerical == "ignore" and not isinstance(obj, int | float | bool): return final_key = path[:-1] # Remove trailing slash if output_type == "dict": - accumulator[final_key] = obj + accumulator[final_key] = obj # type: ignore[index] else: - accumulator.append(obj) + accumulator.append(obj) # type: ignore[attr-defined] flatten_recursive(self) - accumulator = accumulator.values() if output_type == "dict" else accumulator - if non_numerical == "forbid" and any(not isinstance(v, int | float | bool) for v in accumulator): + accumulator = accumulator.values() if output_type == "dict" else accumulator # type: ignore[attr-defined] + if non_numerical == "forbid" and any( + not isinstance(v, int | float | bool) for v in accumulator + ): raise ValueError("Non-numerical values found in flattened data.") if output_type == "np": return np.array(accumulator) if output_type == "pt": torch = smart_import("torch") - return torch.tensor(accumulator) - return accumulator + return torch.tensor(accumulator) # type: ignore[no-any-return] + return accumulator # type: ignore[return-value] @staticmethod - def obj_to_schema(value: Any) -> Dict: + def obj_to_schema(value: Any) -> builtins.dict: # type: ignore[type-arg] """Generates a simplified JSON schema from a dictionary. Args: @@ -202,7 +219,10 @@ def obj_to_schema(value: Any) -> Dict: dict: A simplified JSON schema representing the structure of the dictionary. """ if isinstance(value, dict): - return {"type": "object", "properties": {k: Sample.obj_to_schema(v) for k, v in value.items()}} + return { + "type": "object", + "properties": {k: Sample.obj_to_schema(v) for k, v in value.items()}, + } if isinstance(value, list | tuple | np.ndarray): if len(value) > 0: return {"type": "array", "items": Sample.obj_to_schema(value[0])} @@ -217,7 +237,11 @@ def obj_to_schema(value: Any) -> Dict: return {"type": "boolean"} return {} - def schema(self, resolve_refs: bool = True, include_descriptions=False) -> Dict: + def schema( + self, + resolve_refs: bool = True, + include_descriptions: bool = False, # type: ignore[override] + ) -> builtins.dict: # type: ignore[type-arg] """Returns a simplified json schema. Removing additionalProperties, @@ -246,7 +270,9 @@ def schema(self, resolve_refs: bool = True, include_descriptions=False) -> Dict: if key not in properties: properties[key] = Sample.obj_to_schema(value) if isinstance(value, Sample): - properties[key] = value.schema(resolve_refs=resolve_refs, include_descriptions=include_descriptions) + properties[key] = value.schema( + resolve_refs=resolve_refs, include_descriptions=include_descriptions + ) else: properties[key] = Sample.obj_to_schema(value) return schema @@ -291,8 +317,8 @@ def to(self, container: Any) -> Any: Returns: Any: The converted container. """ - if isinstance(container, Sample) and not issubclass(container, Sample): - return container(**self.dict()) + if isinstance(container, Sample) and not issubclass(container, Sample): # type: ignore[arg-type] + return container(**self.dict()) # type: ignore[operator] if isinstance(container, type) and issubclass(container, Sample): return container.unflatten(self.flatten()) @@ -330,7 +356,7 @@ def space_for( cls, value: Any, max_text_length: int = 1000, - info: Annotated = None, + info: Annotated = None, # type: ignore[valid-type] ) -> spaces.Space: """Default Gym space generation for a given value. @@ -385,10 +411,10 @@ def space_for( raise ValueError(f"Unsupported object {value} of type: {type(value)} for space generation") @classmethod - def init_from(cls, d: Any, pack=False) -> "Sample": + def init_from(cls, d: Any, pack: bool = False) -> "Sample": if isinstance(d, spaces.Space): return cls.from_space(d) - if isinstance(d, Union[Sequence, np.ndarray]): # noqa: UP007 + if isinstance(d, Union[Sequence, np.ndarray]): # type: ignore[arg-type] if pack: return cls.pack_from(d) return cls.unflatten(d) @@ -406,7 +432,11 @@ def init_from(cls, d: Any, pack=False) -> "Sample": return cls(d) @classmethod - def from_flat_dict(cls, flat_dict: Dict[str, Any], schema: Dict = None) -> "Sample": + def from_flat_dict( + cls, + flat_dict: builtins.dict[str, Any], + schema: builtins.dict | None = None, # type: ignore[type-arg] + ) -> "Sample": """Initialize a Sample instance from a flattened dictionary.""" """ Reconstructs the original JSON object from a flattened dictionary using the provided schema. @@ -419,7 +449,7 @@ def from_flat_dict(cls, flat_dict: Dict[str, Any], schema: Dict = None) -> "Samp dict: The reconstructed JSON object. """ schema = schema or replace_refs(cls.model_json_schema()) - reconstructed = {} + reconstructed = {} # type: ignore[var-annotated] for flat_key, value in flat_dict.items(): keys = flat_key.split(".") @@ -430,7 +460,7 @@ def from_flat_dict(cls, flat_dict: Dict[str, Any], schema: Dict = None) -> "Samp current = current[key] current[keys[-1]] = value - return reconstructed + return reconstructed # type: ignore[return-value] @classmethod def from_space(cls, space: spaces.Space) -> "Sample": @@ -441,11 +471,11 @@ def from_space(cls, space: spaces.Space) -> "Sample": if hasattr(sampled, "__len__") and not isinstance(sampled, str): sampled = np.asarray(sampled) if len(sampled.shape) > 0 and isinstance(sampled[0], dict | Sample): - return cls.pack_from(sampled) + return cls.pack_from(sampled) # type: ignore[arg-type] return cls(sampled) @classmethod - def pack_from(cls, samples: List[Union["Sample", Dict]]) -> "Sample": + def pack_from(cls, samples: list[Union["Sample", builtins.dict]]) -> "Sample": # type: ignore[type-arg] """Pack a list of samples into a single sample with lists for attributes. Args: @@ -465,7 +495,7 @@ def pack_from(cls, samples: List[Union["Sample", Dict]]) -> "Sample": else: attributes = ["item" + str(i) for i in range(len(samples))] - aggregated = {attr: [] for attr in attributes} + aggregated = {attr: [] for attr in attributes} # type: ignore[var-annotated] for sample in samples: for attr in attributes: # Handle both Sample instances and dictionaries @@ -475,15 +505,17 @@ def pack_from(cls, samples: List[Union["Sample", Dict]]) -> "Sample": aggregated[attr].append(getattr(sample, attr, None)) return cls(**aggregated) - def unpack(self, to_dicts=False) -> List[Union["Sample", Dict]]: + def unpack(self, to_dicts: bool = False) -> list[Union["Sample", builtins.dict]]: # type: ignore[type-arg] """Unpack the packed Sample object into a list of Sample objects or dictionaries.""" - attributes = list(self.model_extra.keys()) + list(self.model_fields.keys()) + attributes = list(self.model_extra.keys()) + list(self.model_fields.keys()) # type: ignore[union-attr] attributes = [attr for attr in attributes if getattr(self, attr) is not None] if not attributes or getattr(self, attributes[0]) is None: return [] # Ensure all attributes are lists and have the same length - list_sizes = {len(getattr(self, attr)) for attr in attributes if isinstance(getattr(self, attr), list)} + list_sizes = { + len(getattr(self, attr)) for attr in attributes if isinstance(getattr(self, attr), list) + } if len(list_sizes) != 1: raise ValueError("Not all attribute lists have the same length.") list_size = list_sizes.pop() @@ -491,7 +523,10 @@ def unpack(self, to_dicts=False) -> List[Union["Sample", Dict]]: if to_dicts: return [{key: getattr(self, key)[i] for key in attributes} for i in range(list_size)] - return [self.__class__(**{key: getattr(self, key)[i] for key in attributes}) for i in range(list_size)] + return [ + self.__class__(**{key: getattr(self, key)[i] for key in attributes}) + for i in range(list_size) + ] @classmethod def default_space(cls) -> spaces.Dict: @@ -499,7 +534,9 @@ def default_space(cls) -> spaces.Dict: return cls().space() @classmethod - def default_sample(cls, output_type="Sample") -> Union["Sample", Dict[str, Any]]: + def default_sample( + cls, output_type: str = "Sample" + ) -> Union["Sample", builtins.dict[str, Any]]: """Generate a default Sample instance from its class attributes. Useful for padding. This is the "no-op" instance and should be overriden as needed. @@ -511,13 +548,13 @@ def default_sample(cls, output_type="Sample") -> Union["Sample", Dict[str, Any]] def model_field_info(self, key: str) -> FieldInfo: """Get the FieldInfo for a given attribute key.""" if self.model_extra and self.model_extra.get(key) is not None: - info = FieldInfo(metadata=self.model_extra[key]) + info = FieldInfo(metadata=self.model_extra[key]) # type: ignore[call-arg] if self.model_fields.get(key) is not None: - info = FieldInfo(metadata=self.model_fields[key]) + info = FieldInfo(metadata=self.model_fields[key]) # type: ignore[call-arg] if info and hasattr(info, "annotation"): - return info.annotation - return None + return info.annotation # type: ignore[return-value] + return None # type: ignore[return-value] def space(self) -> spaces.Dict: """Return the corresponding Gym space for the Sample instance based on its instance attributes. Omits None values. @@ -528,8 +565,10 @@ def space(self) -> spaces.Dict: for key, value in self.dict().items(): logging.debug("Generating space for key: '%s', value: %s", key, value) info = self.model_field_info(key) - value = getattr(self, key) if hasattr(self, key) else value # noqa: PLW2901 - space_dict[key] = value.space() if isinstance(value, Sample) else self.space_for(value, info=info) + value = getattr(self, key) if hasattr(self, key) else value + space_dict[key] = ( + value.space() if isinstance(value, Sample) else self.space_for(value, info=info) + ) return spaces.Dict(space_dict) def random_sample(self) -> "Sample": @@ -541,4 +580,4 @@ def random_sample(self) -> "Sample": if __name__ == "__main__": - sample = Sample(x=1, y=2, z={"a": 3, "b": 4}, extra_field=5) \ No newline at end of file + sample = Sample(x=1, y=2, z={"a": 3, "b": 4}, extra_field=5) diff --git a/dimos/types/test_timestamped.py b/dimos/types/test_timestamped.py new file mode 100644 index 0000000000..88a8d65102 --- /dev/null +++ b/dimos/types/test_timestamped.py @@ -0,0 +1,578 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import datetime, timezone +import time + +import pytest +from reactivex import operators as ops +from reactivex.scheduler import ThreadPoolScheduler + +from dimos.msgs.sensor_msgs import Image +from dimos.types.timestamped import ( + Timestamped, + TimestampedBufferCollection, + TimestampedCollection, + align_timestamped, + to_datetime, + to_ros_stamp, +) +from dimos.utils import testing +from dimos.utils.data import get_data +from dimos.utils.reactive import backpressure + + +def test_timestamped_dt_method() -> None: + ts = 1751075203.4120464 + timestamped = Timestamped(ts) + dt = timestamped.dt() + assert isinstance(dt, datetime) + assert abs(dt.timestamp() - ts) < 1e-6 + assert dt.tzinfo is not None, "datetime should be timezone-aware" + + +def test_to_ros_stamp() -> None: + """Test the to_ros_stamp function with different input types.""" + + # Test with float timestamp + ts_float = 1234567890.123456789 + result = to_ros_stamp(ts_float) + assert result.sec == 1234567890 + # Float precision limitation - check within reasonable range + assert abs(result.nanosec - 123456789) < 1000 + + # Test with integer timestamp + ts_int = 1234567890 + result = to_ros_stamp(ts_int) + assert result.sec == 1234567890 + assert result.nanosec == 0 + + # Test with datetime object + dt = datetime(2009, 2, 13, 23, 31, 30, 123456, tzinfo=timezone.utc) + result = to_ros_stamp(dt) + assert result.sec == 1234567890 + assert abs(result.nanosec - 123456000) < 1000 # Allow small rounding error + + +def test_to_datetime() -> None: + """Test the to_datetime function with different input types.""" + + # Test with float timestamp + ts_float = 1234567890.123456 + dt = to_datetime(ts_float) + assert isinstance(dt, datetime) + assert dt.tzinfo is not None # Should have timezone + assert abs(dt.timestamp() - ts_float) < 1e-6 + + # Test with integer timestamp + ts_int = 1234567890 + dt = to_datetime(ts_int) + assert isinstance(dt, datetime) + assert dt.tzinfo is not None + assert dt.timestamp() == ts_int + + # Test with RosStamp + ros_stamp = {"sec": 1234567890, "nanosec": 123456000} + dt = to_datetime(ros_stamp) + assert isinstance(dt, datetime) + assert dt.tzinfo is not None + expected_ts = 1234567890.123456 + assert abs(dt.timestamp() - expected_ts) < 1e-6 + + # Test with datetime (already has timezone) + dt_input = datetime(2009, 2, 13, 23, 31, 30, tzinfo=timezone.utc) + dt_result = to_datetime(dt_input) + assert dt_result.tzinfo is not None + # Should convert to local timezone by default + + # Test with naive datetime (no timezone) + dt_naive = datetime(2009, 2, 13, 23, 31, 30) + dt_result = to_datetime(dt_naive) + assert dt_result.tzinfo is not None + + # Test with specific timezone + dt_utc = to_datetime(ts_float, tz=timezone.utc) + assert dt_utc.tzinfo == timezone.utc + assert abs(dt_utc.timestamp() - ts_float) < 1e-6 + + +class SimpleTimestamped(Timestamped): + def __init__(self, ts: float, data: str) -> None: + super().__init__(ts) + self.data = data + + +@pytest.fixture +def test_scheduler(): + """Fixture that provides a ThreadPoolScheduler and cleans it up after the test.""" + scheduler = ThreadPoolScheduler(max_workers=6) + yield scheduler + # Cleanup after test + scheduler.executor.shutdown(wait=True) + time.sleep(0.2) # Give threads time to finish cleanup + + +@pytest.fixture +def sample_items(): + return [ + SimpleTimestamped(1.0, "first"), + SimpleTimestamped(3.0, "third"), + SimpleTimestamped(5.0, "fifth"), + SimpleTimestamped(7.0, "seventh"), + ] + + +@pytest.fixture +def collection(sample_items): + return TimestampedCollection(sample_items) + + +def test_empty_collection() -> None: + collection = TimestampedCollection() + assert len(collection) == 0 + assert collection.duration() == 0.0 + assert collection.time_range() is None + assert collection.find_closest(1.0) is None + + +def test_add_items() -> None: + collection = TimestampedCollection() + item1 = SimpleTimestamped(2.0, "two") + item2 = SimpleTimestamped(1.0, "one") + + collection.add(item1) + collection.add(item2) + + assert len(collection) == 2 + assert collection[0].data == "one" # Should be sorted by timestamp + assert collection[1].data == "two" + + +def test_find_closest(collection) -> None: + # Exact match + assert collection.find_closest(3.0).data == "third" + + # Between items (closer to left) + assert collection.find_closest(1.5, tolerance=1.0).data == "first" + + # Between items (closer to right) + assert collection.find_closest(3.5, tolerance=1.0).data == "third" + + # Exactly in the middle (should pick the later one due to >= comparison) + assert ( + collection.find_closest(4.0, tolerance=1.0).data == "fifth" + ) # 4.0 is equidistant from 3.0 and 5.0 + + # Before all items + assert collection.find_closest(0.0, tolerance=1.0).data == "first" + + # After all items + assert collection.find_closest(10.0, tolerance=4.0).data == "seventh" + + # low tolerance, should return None + assert collection.find_closest(10.0, tolerance=2.0) is None + + +def test_find_before_after(collection) -> None: + # Find before + assert collection.find_before(2.0).data == "first" + assert collection.find_before(5.5).data == "fifth" + assert collection.find_before(1.0) is None # Nothing before first item + + # Find after + assert collection.find_after(2.0).data == "third" + assert collection.find_after(5.0).data == "seventh" + assert collection.find_after(7.0) is None # Nothing after last item + + +def test_merge_collections() -> None: + collection1 = TimestampedCollection( + [ + SimpleTimestamped(1.0, "a"), + SimpleTimestamped(3.0, "c"), + ] + ) + collection2 = TimestampedCollection( + [ + SimpleTimestamped(2.0, "b"), + SimpleTimestamped(4.0, "d"), + ] + ) + + merged = collection1.merge(collection2) + + assert len(merged) == 4 + assert [item.data for item in merged] == ["a", "b", "c", "d"] + + +def test_duration_and_range(collection) -> None: + assert collection.duration() == 6.0 # 7.0 - 1.0 + assert collection.time_range() == (1.0, 7.0) + + +def test_slice_by_time(collection) -> None: + # Slice inclusive of boundaries + sliced = collection.slice_by_time(2.0, 6.0) + assert len(sliced) == 2 + assert sliced[0].data == "third" + assert sliced[1].data == "fifth" + + # Empty slice + empty_slice = collection.slice_by_time(8.0, 10.0) + assert len(empty_slice) == 0 + + # Slice all + all_slice = collection.slice_by_time(0.0, 10.0) + assert len(all_slice) == 4 + + +def test_iteration(collection) -> None: + items = list(collection) + assert len(items) == 4 + assert [item.ts for item in items] == [1.0, 3.0, 5.0, 7.0] + + +def test_single_item_collection() -> None: + single = TimestampedCollection([SimpleTimestamped(5.0, "only")]) + assert single.duration() == 0.0 + assert single.time_range() == (5.0, 5.0) + + +def test_time_window_collection() -> None: + # Create a collection with a 2-second window + window = TimestampedBufferCollection[SimpleTimestamped](window_duration=2.0) + + # Add messages at different timestamps + window.add(SimpleTimestamped(1.0, "msg1")) + window.add(SimpleTimestamped(2.0, "msg2")) + window.add(SimpleTimestamped(3.0, "msg3")) + + # At this point, all messages should be present (within 2s window) + assert len(window) == 3 + + # Add a message at t=4.0, should keep messages from t=2.0 onwards + window.add(SimpleTimestamped(4.0, "msg4")) + assert len(window) == 3 # msg1 should be dropped + assert window[0].data == "msg2" # oldest is now msg2 + assert window[-1].data == "msg4" # newest is msg4 + + # Add a message at t=5.5, should drop msg2 and msg3 + window.add(SimpleTimestamped(5.5, "msg5")) + assert len(window) == 2 # only msg4 and msg5 remain + assert window[0].data == "msg4" + assert window[1].data == "msg5" + + # Verify time range + assert window.start_ts == 4.0 + assert window.end_ts == 5.5 + + +def test_timestamp_alignment(test_scheduler) -> None: + speed = 5.0 + + # ensure that lfs package is downloaded + get_data("unitree_office_walk") + + raw_frames = [] + + def spy(image): + raw_frames.append(image.ts) + print(image.ts) + return image + + # sensor reply of raw video frames + video_raw = ( + testing.TimedSensorReplay( + "unitree_office_walk/video", autocast=lambda x: Image.from_numpy(x).to_rgb() + ) + .stream(speed) + .pipe(ops.take(30)) + ) + + processed_frames = [] + + def process_video_frame(frame): + processed_frames.append(frame.ts) + time.sleep(0.5 / speed) + return frame + + # fake reply of some 0.5s processor of video frames that drops messages + # Pass the scheduler to backpressure to manage threads properly + fake_video_processor = backpressure( + video_raw.pipe(ops.map(spy)), scheduler=test_scheduler + ).pipe(ops.map(process_video_frame)) + + aligned_frames = align_timestamped(fake_video_processor, video_raw).pipe(ops.to_list()).run() + + assert len(raw_frames) == 30 + assert len(processed_frames) > 2 + assert len(aligned_frames) > 2 + + # Due to async processing, the last frame might not be aligned before completion + assert len(aligned_frames) >= len(processed_frames) - 1 + + for value in aligned_frames: + [primary, secondary] = value + diff = abs(primary.ts - secondary.ts) + print( + f"Aligned pair: primary={primary.ts:.6f}, secondary={secondary.ts:.6f}, diff={diff:.6f}s" + ) + assert diff <= 0.05 + + assert len(aligned_frames) > 2 + + +def test_timestamp_alignment_primary_first() -> None: + """Test alignment when primary messages arrive before secondary messages.""" + from reactivex import Subject + + primary_subject = Subject() + secondary_subject = Subject() + + results = [] + + # Set up alignment with a 2-second buffer + aligned = align_timestamped( + primary_subject, secondary_subject, buffer_size=2.0, match_tolerance=0.1 + ) + + # Subscribe to collect results + aligned.subscribe(lambda x: results.append(x)) + + # Send primary messages first + primary1 = SimpleTimestamped(1.0, "primary1") + primary2 = SimpleTimestamped(2.0, "primary2") + primary3 = SimpleTimestamped(3.0, "primary3") + + primary_subject.on_next(primary1) + primary_subject.on_next(primary2) + primary_subject.on_next(primary3) + + # At this point, no results should be emitted (no secondaries yet) + assert len(results) == 0 + + # Send secondary messages that match primary1 and primary2 + secondary1 = SimpleTimestamped(1.05, "secondary1") # Matches primary1 + secondary2 = SimpleTimestamped(2.02, "secondary2") # Matches primary2 + + secondary_subject.on_next(secondary1) + assert len(results) == 1 # primary1 should now be matched + assert results[0][0].data == "primary1" + assert results[0][1].data == "secondary1" + + secondary_subject.on_next(secondary2) + assert len(results) == 2 # primary2 should now be matched + assert results[1][0].data == "primary2" + assert results[1][1].data == "secondary2" + + # Send a secondary that's too far from primary3 + secondary_far = SimpleTimestamped(3.5, "secondary_far") # Too far from primary3 + secondary_subject.on_next(secondary_far) + # At this point primary3 is removed as unmatchable since secondary progressed past it + assert len(results) == 2 # primary3 should not match (outside tolerance) + + # Send a new primary that can match with the future secondary + primary4 = SimpleTimestamped(3.45, "primary4") + primary_subject.on_next(primary4) + assert len(results) == 3 # Should match with secondary_far + assert results[2][0].data == "primary4" + assert results[2][1].data == "secondary_far" + + # Complete the streams + primary_subject.on_completed() + secondary_subject.on_completed() + + +def test_timestamp_alignment_multiple_secondaries() -> None: + """Test alignment with multiple secondary observables.""" + from reactivex import Subject + + primary_subject = Subject() + secondary1_subject = Subject() + secondary2_subject = Subject() + + results = [] + + # Set up alignment with two secondary streams + aligned = align_timestamped( + primary_subject, + secondary1_subject, + secondary2_subject, + buffer_size=1.0, + match_tolerance=0.05, + ) + + # Subscribe to collect results + aligned.subscribe(lambda x: results.append(x)) + + # Send a primary message + primary1 = SimpleTimestamped(1.0, "primary1") + primary_subject.on_next(primary1) + + # No results yet (waiting for both secondaries) + assert len(results) == 0 + + # Send first secondary + sec1_msg1 = SimpleTimestamped(1.01, "sec1_msg1") + secondary1_subject.on_next(sec1_msg1) + + # Still no results (waiting for secondary2) + assert len(results) == 0 + + # Send second secondary + sec2_msg1 = SimpleTimestamped(1.02, "sec2_msg1") + secondary2_subject.on_next(sec2_msg1) + + # Now we should have a result + assert len(results) == 1 + assert results[0][0].data == "primary1" + assert results[0][1].data == "sec1_msg1" + assert results[0][2].data == "sec2_msg1" + + # Test partial match (one secondary missing) + primary2 = SimpleTimestamped(2.0, "primary2") + primary_subject.on_next(primary2) + + # Send only one secondary + sec1_msg2 = SimpleTimestamped(2.01, "sec1_msg2") + secondary1_subject.on_next(sec1_msg2) + + # No result yet + assert len(results) == 1 + + # Send a secondary2 that's too far + sec2_far = SimpleTimestamped(2.1, "sec2_far") # Outside tolerance + secondary2_subject.on_next(sec2_far) + + # Still no result (secondary2 is outside tolerance) + assert len(results) == 1 + + # Complete the streams + primary_subject.on_completed() + secondary1_subject.on_completed() + secondary2_subject.on_completed() + + +def test_timestamp_alignment_delayed_secondary() -> None: + """Test alignment when secondary messages arrive late but still within tolerance.""" + from reactivex import Subject + + primary_subject = Subject() + secondary_subject = Subject() + + results = [] + + # Set up alignment with a 2-second buffer + aligned = align_timestamped( + primary_subject, secondary_subject, buffer_size=2.0, match_tolerance=0.1 + ) + + # Subscribe to collect results + aligned.subscribe(lambda x: results.append(x)) + + # Send primary messages + primary1 = SimpleTimestamped(1.0, "primary1") + primary2 = SimpleTimestamped(2.0, "primary2") + primary3 = SimpleTimestamped(3.0, "primary3") + + primary_subject.on_next(primary1) + primary_subject.on_next(primary2) + primary_subject.on_next(primary3) + + # No results yet + assert len(results) == 0 + + # Send delayed secondaries (in timestamp order) + secondary1 = SimpleTimestamped(1.05, "secondary1") # Matches primary1 + secondary_subject.on_next(secondary1) + assert len(results) == 1 # primary1 matched + assert results[0][0].data == "primary1" + assert results[0][1].data == "secondary1" + + secondary2 = SimpleTimestamped(2.02, "secondary2") # Matches primary2 + secondary_subject.on_next(secondary2) + assert len(results) == 2 # primary2 matched + assert results[1][0].data == "primary2" + assert results[1][1].data == "secondary2" + + # Now send a secondary that's past primary3's match window + secondary_future = SimpleTimestamped(3.2, "secondary_future") # Too far from primary3 + secondary_subject.on_next(secondary_future) + # At this point, primary3 should be removed as unmatchable + assert len(results) == 2 # No new matches + + # Send a new primary that can match with secondary_future + primary4 = SimpleTimestamped(3.15, "primary4") + primary_subject.on_next(primary4) + assert len(results) == 3 # Should match immediately + assert results[2][0].data == "primary4" + assert results[2][1].data == "secondary_future" + + # Complete the streams + primary_subject.on_completed() + secondary_subject.on_completed() + + +def test_timestamp_alignment_buffer_cleanup() -> None: + """Test that old buffered primaries are cleaned up.""" + import time as time_module + + from reactivex import Subject + + primary_subject = Subject() + secondary_subject = Subject() + + results = [] + + # Set up alignment with a 0.5-second buffer + aligned = align_timestamped( + primary_subject, secondary_subject, buffer_size=0.5, match_tolerance=0.05 + ) + + # Subscribe to collect results + aligned.subscribe(lambda x: results.append(x)) + + # Use real timestamps for this test + now = time_module.time() + + # Send an old primary + old_primary = Timestamped(now - 1.0) # 1 second ago + old_primary.data = "old" + primary_subject.on_next(old_primary) + + # Send a recent secondary to trigger cleanup + recent_secondary = Timestamped(now) + recent_secondary.data = "recent" + secondary_subject.on_next(recent_secondary) + + # Old primary should not match (outside buffer window) + assert len(results) == 0 + + # Send a matching pair within buffer + new_primary = Timestamped(now + 0.1) + new_primary.data = "new_primary" + new_secondary = Timestamped(now + 0.11) + new_secondary.data = "new_secondary" + + primary_subject.on_next(new_primary) + secondary_subject.on_next(new_secondary) + + # Should have one match + assert len(results) == 1 + assert results[0][0].data == "new_primary" + assert results[0][1].data == "new_secondary" + + # Complete the streams + primary_subject.on_completed() + secondary_subject.on_completed() diff --git a/dimos/types/test_vector.py b/dimos/types/test_vector.py new file mode 100644 index 0000000000..285d021bea --- /dev/null +++ b/dimos/types/test_vector.py @@ -0,0 +1,384 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import pytest + +from dimos.types.vector import Vector + + +def test_vector_default_init() -> None: + """Test that default initialization of Vector() has x,y,z components all zero.""" + v = Vector() + assert v.x == 0.0 + assert v.y == 0.0 + assert v.z == 0.0 + assert v.dim == 0 + assert len(v.data) == 0 + assert v.to_list() == [] + assert v.is_zero() # Empty vector should be considered zero + + +def test_vector_specific_init() -> None: + """Test initialization with specific values.""" + # 2D vector + v1 = Vector(1.0, 2.0) + assert v1.x == 1.0 + assert v1.y == 2.0 + assert v1.z == 0.0 + assert v1.dim == 2 + + # 3D vector + v2 = Vector(3.0, 4.0, 5.0) + assert v2.x == 3.0 + assert v2.y == 4.0 + assert v2.z == 5.0 + assert v2.dim == 3 + + # From list + v3 = Vector([6.0, 7.0, 8.0]) + assert v3.x == 6.0 + assert v3.y == 7.0 + assert v3.z == 8.0 + assert v3.dim == 3 + + # From numpy array + v4 = Vector(np.array([9.0, 10.0, 11.0])) + assert v4.x == 9.0 + assert v4.y == 10.0 + assert v4.z == 11.0 + assert v4.dim == 3 + + +def test_vector_addition() -> None: + """Test vector addition.""" + v1 = Vector(1.0, 2.0, 3.0) + v2 = Vector(4.0, 5.0, 6.0) + + v_add = v1 + v2 + assert v_add.x == 5.0 + assert v_add.y == 7.0 + assert v_add.z == 9.0 + + +def test_vector_subtraction() -> None: + """Test vector subtraction.""" + v1 = Vector(1.0, 2.0, 3.0) + v2 = Vector(4.0, 5.0, 6.0) + + v_sub = v2 - v1 + assert v_sub.x == 3.0 + assert v_sub.y == 3.0 + assert v_sub.z == 3.0 + + +def test_vector_scalar_multiplication() -> None: + """Test vector multiplication by a scalar.""" + v1 = Vector(1.0, 2.0, 3.0) + + v_mul = v1 * 2.0 + assert v_mul.x == 2.0 + assert v_mul.y == 4.0 + assert v_mul.z == 6.0 + + # Test right multiplication + v_rmul = 2.0 * v1 + assert v_rmul.x == 2.0 + assert v_rmul.y == 4.0 + assert v_rmul.z == 6.0 + + +def test_vector_scalar_division() -> None: + """Test vector division by a scalar.""" + v2 = Vector(4.0, 5.0, 6.0) + + v_div = v2 / 2.0 + assert v_div.x == 2.0 + assert v_div.y == 2.5 + assert v_div.z == 3.0 + + +def test_vector_dot_product() -> None: + """Test vector dot product.""" + v1 = Vector(1.0, 2.0, 3.0) + v2 = Vector(4.0, 5.0, 6.0) + + dot = v1.dot(v2) + assert dot == 32.0 + + +def test_vector_length() -> None: + """Test vector length calculation.""" + # 2D vector with length 5 + v1 = Vector(3.0, 4.0) + assert v1.length() == 5.0 + + # 3D vector + v2 = Vector(2.0, 3.0, 6.0) + assert v2.length() == pytest.approx(7.0, 0.001) + + # Test length_squared + assert v1.length_squared() == 25.0 + assert v2.length_squared() == 49.0 + + +def test_vector_normalize() -> None: + """Test vector normalization.""" + v = Vector(2.0, 3.0, 6.0) + assert not v.is_zero() + + v_norm = v.normalize() + length = v.length() + expected_x = 2.0 / length + expected_y = 3.0 / length + expected_z = 6.0 / length + + assert np.isclose(v_norm.x, expected_x) + assert np.isclose(v_norm.y, expected_y) + assert np.isclose(v_norm.z, expected_z) + assert np.isclose(v_norm.length(), 1.0) + assert not v_norm.is_zero() + + # Test normalizing a zero vector + v_zero = Vector(0.0, 0.0, 0.0) + assert v_zero.is_zero() + v_zero_norm = v_zero.normalize() + assert v_zero_norm.x == 0.0 + assert v_zero_norm.y == 0.0 + assert v_zero_norm.z == 0.0 + assert v_zero_norm.is_zero() + + +def test_vector_to_2d() -> None: + """Test conversion to 2D vector.""" + v = Vector(2.0, 3.0, 6.0) + + v_2d = v.to_2d() + assert v_2d.x == 2.0 + assert v_2d.y == 3.0 + assert v_2d.z == 0.0 + assert v_2d.dim == 2 + + # Already 2D vector + v2 = Vector(4.0, 5.0) + v2_2d = v2.to_2d() + assert v2_2d.x == 4.0 + assert v2_2d.y == 5.0 + assert v2_2d.dim == 2 + + +def test_vector_distance() -> None: + """Test distance calculations between vectors.""" + v1 = Vector(1.0, 2.0, 3.0) + v2 = Vector(4.0, 6.0, 8.0) + + # Distance + dist = v1.distance(v2) + expected_dist = np.sqrt(9.0 + 16.0 + 25.0) # sqrt((4-1)² + (6-2)² + (8-3)²) + assert dist == pytest.approx(expected_dist) + + # Distance squared + dist_sq = v1.distance_squared(v2) + assert dist_sq == 50.0 # 9 + 16 + 25 + + +def test_vector_cross_product() -> None: + """Test vector cross product.""" + v1 = Vector(1.0, 0.0, 0.0) # Unit x vector + v2 = Vector(0.0, 1.0, 0.0) # Unit y vector + + # v1 × v2 should be unit z vector + cross = v1.cross(v2) + assert cross.x == 0.0 + assert cross.y == 0.0 + assert cross.z == 1.0 + + # Test with more complex vectors + a = Vector(2.0, 3.0, 4.0) + b = Vector(5.0, 6.0, 7.0) + c = a.cross(b) + + # Cross product manually calculated: + # (3*7-4*6, 4*5-2*7, 2*6-3*5) + assert c.x == -3.0 + assert c.y == 6.0 + assert c.z == -3.0 + + # Test with 2D vectors (should raise error) + v_2d = Vector(1.0, 2.0) + with pytest.raises(ValueError): + v_2d.cross(v2) + + +def test_vector_zeros() -> None: + """Test Vector.zeros class method.""" + # 3D zero vector + v_zeros = Vector.zeros(3) + assert v_zeros.x == 0.0 + assert v_zeros.y == 0.0 + assert v_zeros.z == 0.0 + assert v_zeros.dim == 3 + assert v_zeros.is_zero() + + # 2D zero vector + v_zeros_2d = Vector.zeros(2) + assert v_zeros_2d.x == 0.0 + assert v_zeros_2d.y == 0.0 + assert v_zeros_2d.z == 0.0 + assert v_zeros_2d.dim == 2 + assert v_zeros_2d.is_zero() + + +def test_vector_ones() -> None: + """Test Vector.ones class method.""" + # 3D ones vector + v_ones = Vector.ones(3) + assert v_ones.x == 1.0 + assert v_ones.y == 1.0 + assert v_ones.z == 1.0 + assert v_ones.dim == 3 + + # 2D ones vector + v_ones_2d = Vector.ones(2) + assert v_ones_2d.x == 1.0 + assert v_ones_2d.y == 1.0 + assert v_ones_2d.z == 0.0 + assert v_ones_2d.dim == 2 + + +def test_vector_conversion_methods() -> None: + """Test vector conversion methods (to_list, to_tuple, to_numpy).""" + v = Vector(1.0, 2.0, 3.0) + + # to_list + assert v.to_list() == [1.0, 2.0, 3.0] + + # to_tuple + assert v.to_tuple() == (1.0, 2.0, 3.0) + + # to_numpy + np_array = v.to_numpy() + assert isinstance(np_array, np.ndarray) + assert np.array_equal(np_array, np.array([1.0, 2.0, 3.0])) + + +def test_vector_equality() -> None: + """Test vector equality.""" + v1 = Vector(1, 2, 3) + v2 = Vector(1, 2, 3) + v3 = Vector(4, 5, 6) + + assert v1 == v2 + assert v1 != v3 + assert v1 != Vector(1, 2) # Different dimensions + assert v1 != Vector(1.1, 2, 3) # Different values + assert v1 != [1, 2, 3] + + +def test_vector_is_zero() -> None: + """Test is_zero method for vectors.""" + # Default empty vector + v0 = Vector() + assert v0.is_zero() + + # Explicit zero vector + v1 = Vector(0.0, 0.0, 0.0) + assert v1.is_zero() + + # Zero vector with different dimensions + v2 = Vector(0.0, 0.0) + assert v2.is_zero() + + # Non-zero vectors + v3 = Vector(1.0, 0.0, 0.0) + assert not v3.is_zero() + + v4 = Vector(0.0, 2.0, 0.0) + assert not v4.is_zero() + + v5 = Vector(0.0, 0.0, 3.0) + assert not v5.is_zero() + + # Almost zero (within tolerance) + v6 = Vector(1e-10, 1e-10, 1e-10) + assert v6.is_zero() + + # Almost zero (outside tolerance) + v7 = Vector(1e-6, 1e-6, 1e-6) + assert not v7.is_zero() + + +def test_vector_bool_conversion(): + """Test boolean conversion of vectors.""" + # Zero vectors should be False + v0 = Vector() + assert not bool(v0) + + v1 = Vector(0.0, 0.0, 0.0) + assert not bool(v1) + + # Almost zero vectors should be False + v2 = Vector(1e-10, 1e-10, 1e-10) + assert not bool(v2) + + # Non-zero vectors should be True + v3 = Vector(1.0, 0.0, 0.0) + assert bool(v3) + + v4 = Vector(0.0, 2.0, 0.0) + assert bool(v4) + + v5 = Vector(0.0, 0.0, 3.0) + assert bool(v5) + + # Direct use in if statements + if v0: + raise AssertionError("Zero vector should be False in boolean context") + else: + pass # Expected path + + if v3: + pass # Expected path + else: + raise AssertionError("Non-zero vector should be True in boolean context") + + +def test_vector_add() -> None: + """Test vector addition operator.""" + v1 = Vector(1.0, 2.0, 3.0) + v2 = Vector(4.0, 5.0, 6.0) + + # Using __add__ method + v_add = v1.__add__(v2) + assert v_add.x == 5.0 + assert v_add.y == 7.0 + assert v_add.z == 9.0 + + # Using + operator + v_add_op = v1 + v2 + assert v_add_op.x == 5.0 + assert v_add_op.y == 7.0 + assert v_add_op.z == 9.0 + + # Adding zero vector should return original vector + v_zero = Vector.zeros(3) + assert (v1 + v_zero) == v1 + + +def test_vector_add_dim_mismatch() -> None: + """Test vector addition operator.""" + v1 = Vector(1.0, 2.0) + v2 = Vector(4.0, 5.0, 6.0) + + # Using + operator + v1 + v2 diff --git a/dimos/types/test_weaklist.py b/dimos/types/test_weaklist.py new file mode 100644 index 0000000000..990cc0d164 --- /dev/null +++ b/dimos/types/test_weaklist.py @@ -0,0 +1,165 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for WeakList implementation.""" + +import gc + +import pytest + +from dimos.types.weaklist import WeakList + + +class SampleObject: + """Simple test object.""" + + def __init__(self, value) -> None: + self.value = value + + def __repr__(self) -> str: + return f"SampleObject({self.value})" + + +def test_weaklist_basic_operations() -> None: + """Test basic append, iterate, and length operations.""" + wl = WeakList() + + # Add objects + obj1 = SampleObject(1) + obj2 = SampleObject(2) + obj3 = SampleObject(3) + + wl.append(obj1) + wl.append(obj2) + wl.append(obj3) + + # Check length and iteration + assert len(wl) == 3 + assert list(wl) == [obj1, obj2, obj3] + + # Check contains + assert obj1 in wl + assert obj2 in wl + assert SampleObject(4) not in wl + + +def test_weaklist_auto_removal() -> None: + """Test that objects are automatically removed when garbage collected.""" + wl = WeakList() + + obj1 = SampleObject(1) + obj2 = SampleObject(2) + obj3 = SampleObject(3) + + wl.append(obj1) + wl.append(obj2) + wl.append(obj3) + + assert len(wl) == 3 + + # Delete one object and force garbage collection + del obj2 + gc.collect() + + # Should only have 2 objects now + assert len(wl) == 2 + assert list(wl) == [obj1, obj3] + + +def test_weaklist_explicit_remove() -> None: + """Test explicit removal of objects.""" + wl = WeakList() + + obj1 = SampleObject(1) + obj2 = SampleObject(2) + + wl.append(obj1) + wl.append(obj2) + + # Remove obj1 + wl.remove(obj1) + assert len(wl) == 1 + assert obj1 not in wl + assert obj2 in wl + + # Try to remove non-existent object + with pytest.raises(ValueError): + wl.remove(SampleObject(3)) + + +def test_weaklist_indexing() -> None: + """Test index access.""" + wl = WeakList() + + obj1 = SampleObject(1) + obj2 = SampleObject(2) + obj3 = SampleObject(3) + + wl.append(obj1) + wl.append(obj2) + wl.append(obj3) + + assert wl[0] is obj1 + assert wl[1] is obj2 + assert wl[2] is obj3 + + # Test index out of range + with pytest.raises(IndexError): + _ = wl[3] + + +def test_weaklist_clear() -> None: + """Test clearing the list.""" + wl = WeakList() + + obj1 = SampleObject(1) + obj2 = SampleObject(2) + + wl.append(obj1) + wl.append(obj2) + + assert len(wl) == 2 + + wl.clear() + assert len(wl) == 0 + assert obj1 not in wl + + +def test_weaklist_iteration_during_modification() -> None: + """Test that iteration works even if objects are deleted during iteration.""" + wl = WeakList() + + objects = [SampleObject(i) for i in range(5)] + for obj in objects: + wl.append(obj) + + # Verify initial state + assert len(wl) == 5 + + # Iterate and check that we can safely delete objects + seen_values = [] + for obj in wl: + seen_values.append(obj.value) + if obj.value == 2: + # Delete another object (not the current one) + del objects[3] # Delete SampleObject(3) + gc.collect() + + # The object with value 3 gets garbage collected during iteration + # so we might not see it (depends on timing) + assert len(seen_values) in [4, 5] + assert all(v in [0, 1, 2, 3, 4] for v in seen_values) + + # After iteration, the list should have 4 objects (one was deleted) + assert len(wl) == 4 diff --git a/dimos/types/timestamped.py b/dimos/types/timestamped.py new file mode 100644 index 0000000000..99be484f0a --- /dev/null +++ b/dimos/types/timestamped.py @@ -0,0 +1,411 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import defaultdict +from collections.abc import Iterable, Iterator +from datetime import datetime, timezone +from typing import Generic, TypeVar, Union + +from dimos_lcm.builtin_interfaces import Time as ROSTime # type: ignore[import-untyped] +from reactivex import create +from reactivex.disposable import CompositeDisposable + +# from dimos_lcm.std_msgs import Time as ROSTime +from reactivex.observable import Observable +from sortedcontainers import SortedKeyList # type: ignore[import-untyped] + +from dimos.types.weaklist import WeakList +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + +# any class that carries a timestamp should inherit from this +# this allows us to work with timeseries in consistent way, allign messages, replay etc +# aditional functionality will come to this class soon + + +# class RosStamp(TypedDict): +# sec: int +# nanosec: int + + +TimeLike = Union[int, float, datetime, ROSTime] + + +def to_timestamp(ts: TimeLike) -> float: + """Convert TimeLike to a timestamp in seconds.""" + if isinstance(ts, datetime): + return ts.timestamp() + if isinstance(ts, int | float): + return float(ts) + if isinstance(ts, dict) and "sec" in ts and "nanosec" in ts: + return ts["sec"] + ts["nanosec"] / 1e9 # type: ignore[no-any-return] + # Check for ROS Time-like objects by attributes + if hasattr(ts, "sec") and (hasattr(ts, "nanosec") or hasattr(ts, "nsec")): + # Handle both std_msgs.Time (nsec) and builtin_interfaces.Time (nanosec) + if hasattr(ts, "nanosec"): + return ts.sec + ts.nanosec / 1e9 # type: ignore[no-any-return] + else: # has nsec + return ts.sec + ts.nsec / 1e9 # type: ignore[no-any-return] + raise TypeError("unsupported timestamp type") + + +def to_ros_stamp(ts: TimeLike) -> ROSTime: + """Convert TimeLike to a ROS-style timestamp dictionary.""" + if isinstance(ts, dict) and "sec" in ts and "nanosec" in ts: + return ts + + timestamp = to_timestamp(ts) + sec = int(timestamp) + nanosec = int((timestamp - sec) * 1_000_000_000) + return ROSTime(sec=sec, nanosec=nanosec) + + +def to_human_readable(ts: float) -> str: + """Convert timestamp to human-readable format with date and time.""" + import time + + return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(ts)) + + +def to_datetime(ts: TimeLike, tz=None) -> datetime: # type: ignore[no-untyped-def] + if isinstance(ts, datetime): + if ts.tzinfo is None: + # Assume UTC for naive datetime + ts = ts.replace(tzinfo=timezone.utc) + if tz is not None: + return ts.astimezone(tz) + return ts.astimezone() # Convert to local tz + + # Convert to timestamp first + timestamp = to_timestamp(ts) + + # Create datetime from timestamp + if tz is not None: + return datetime.fromtimestamp(timestamp, tz=tz) + else: + # Use local timezone by default + return datetime.fromtimestamp(timestamp).astimezone() + + +class Timestamped: + ts: float + + def __init__(self, ts: float) -> None: + self.ts = ts + + def dt(self) -> datetime: + return datetime.fromtimestamp(self.ts, tz=timezone.utc).astimezone() + + def ros_timestamp(self) -> list[int]: + """Convert timestamp to ROS-style list [sec, nanosec].""" + sec = int(self.ts) + nanosec = int((self.ts - sec) * 1_000_000_000) + return [sec, nanosec] + + +T = TypeVar("T", bound=Timestamped) + + +class TimestampedCollection(Generic[T]): + """A collection of timestamped objects with efficient time-based operations.""" + + def __init__(self, items: Iterable[T] | None = None) -> None: + self._items = SortedKeyList(items or [], key=lambda x: x.ts) + + def add(self, item: T) -> None: + """Add a timestamped item to the collection.""" + self._items.add(item) + + def find_closest(self, timestamp: float, tolerance: float | None = None) -> T | None: + """Find the timestamped object closest to the given timestamp.""" + if not self._items: + return None + + # Use binary search to find insertion point + idx = self._items.bisect_key_left(timestamp) + + # Check exact match + if idx < len(self._items) and self._items[idx].ts == timestamp: + return self._items[idx] # type: ignore[no-any-return] + + # Find candidates: item before and after + candidates = [] + + # Item before + if idx > 0: + candidates.append((idx - 1, abs(self._items[idx - 1].ts - timestamp))) + + # Item after + if idx < len(self._items): + candidates.append((idx, abs(self._items[idx].ts - timestamp))) + + if not candidates: + return None + + # Find closest + # When distances are equal, prefer the later item (higher index) + closest_idx, closest_distance = min(candidates, key=lambda x: (x[1], -x[0])) + + # Check tolerance if provided + if tolerance is not None and closest_distance > tolerance: + return None + + return self._items[closest_idx] # type: ignore[no-any-return] + + def find_before(self, timestamp: float) -> T | None: + """Find the last item before the given timestamp.""" + idx = self._items.bisect_key_left(timestamp) + return self._items[idx - 1] if idx > 0 else None + + def find_after(self, timestamp: float) -> T | None: + """Find the first item after the given timestamp.""" + idx = self._items.bisect_key_right(timestamp) + return self._items[idx] if idx < len(self._items) else None + + def merge(self, other: "TimestampedCollection[T]") -> "TimestampedCollection[T]": + """Merge two timestamped collections into a new one.""" + result = TimestampedCollection[T]() + result._items = SortedKeyList(self._items + other._items, key=lambda x: x.ts) + return result + + def duration(self) -> float: + """Get the duration of the collection in seconds.""" + if len(self._items) < 2: + return 0.0 + return self._items[-1].ts - self._items[0].ts # type: ignore[no-any-return] + + def time_range(self) -> tuple[float, float] | None: + """Get the time range (start, end) of the collection.""" + if not self._items: + return None + return (self._items[0].ts, self._items[-1].ts) + + def slice_by_time(self, start: float, end: float) -> "TimestampedCollection[T]": + """Get a subset of items within the given time range.""" + start_idx = self._items.bisect_key_left(start) + end_idx = self._items.bisect_key_right(end) + return TimestampedCollection(self._items[start_idx:end_idx]) + + @property + def start_ts(self) -> float | None: + """Get the start timestamp of the collection.""" + return self._items[0].ts if self._items else None + + @property + def end_ts(self) -> float | None: + """Get the end timestamp of the collection.""" + return self._items[-1].ts if self._items else None + + def __len__(self) -> int: + return len(self._items) + + def __iter__(self) -> Iterator: # type: ignore[type-arg] + return iter(self._items) + + def __getitem__(self, idx: int) -> T: + return self._items[idx] # type: ignore[no-any-return] + + +PRIMARY = TypeVar("PRIMARY", bound=Timestamped) +SECONDARY = TypeVar("SECONDARY", bound=Timestamped) + + +class TimestampedBufferCollection(TimestampedCollection[T]): + """A timestamped collection that maintains a sliding time window, dropping old messages.""" + + def __init__(self, window_duration: float, items: Iterable[T] | None = None) -> None: + """ + Initialize with a time window duration in seconds. + + Args: + window_duration: Maximum age of messages to keep in seconds + items: Optional initial items + """ + super().__init__(items) + self.window_duration = window_duration + + def add(self, item: T) -> None: + """Add a timestamped item and remove any items outside the time window.""" + super().add(item) + self._prune_old_messages(item.ts) + + def _prune_old_messages(self, current_ts: float) -> None: + """Remove messages older than window_duration from the given timestamp.""" + cutoff_ts = current_ts - self.window_duration + + # Find the index of the first item that should be kept + keep_idx = self._items.bisect_key_left(cutoff_ts) + + # Remove old items + if keep_idx > 0: + del self._items[:keep_idx] + + def remove_by_timestamp(self, timestamp: float) -> bool: + """Remove an item with the given timestamp. Returns True if item was found and removed.""" + idx = self._items.bisect_key_left(timestamp) + + if idx < len(self._items) and self._items[idx].ts == timestamp: + del self._items[idx] + return True + return False + + def remove(self, item: T) -> bool: + """Remove a timestamped item from the collection. Returns True if item was found and removed.""" + return self.remove_by_timestamp(item.ts) + + +class MatchContainer(Timestamped, Generic[PRIMARY, SECONDARY]): + """ + This class stores a primary item along with its partial matches to secondary items, + tracking which secondaries are still missing to avoid redundant searches. + """ + + def __init__(self, primary: PRIMARY, matches: list[SECONDARY | None]) -> None: + super().__init__(primary.ts) + self.primary = primary + self.matches = matches # Direct list with None for missing matches + + def message_received(self, secondary_idx: int, secondary_item: SECONDARY) -> None: + """Process a secondary message and check if it matches this primary.""" + if self.matches[secondary_idx] is None: + self.matches[secondary_idx] = secondary_item + + def is_complete(self) -> bool: + """Check if all secondary matches have been found.""" + return all(match is not None for match in self.matches) + + def get_tuple(self) -> tuple[PRIMARY, ...]: + """Get the result tuple for emission.""" + return (self.primary, *self.matches) # type: ignore[arg-type] + + +def align_timestamped( + primary_observable: Observable[PRIMARY], + *secondary_observables: Observable[SECONDARY], + buffer_size: float = 1.0, # seconds + match_tolerance: float = 0.1, # seconds +) -> Observable[tuple[PRIMARY, ...]]: + """Align a primary observable with one or more secondary observables. + + Args: + primary_observable: The primary stream to align against + *secondary_observables: One or more secondary streams to align + buffer_size: Time window to keep messages in seconds + match_tolerance: Maximum time difference for matching in seconds + + Returns: + If single secondary observable: Observable that emits tuples of (primary_item, secondary_item) + If multiple secondary observables: Observable that emits tuples of (primary_item, secondary1, secondary2, ...) + Each secondary item is the closest match from the corresponding + secondary observable, or None if no match within tolerance. + """ + + def subscribe(observer, scheduler=None): # type: ignore[no-untyped-def] + # Create a timed buffer collection for each secondary observable + secondary_collections: list[TimestampedBufferCollection[SECONDARY]] = [ + TimestampedBufferCollection(buffer_size) for _ in secondary_observables + ] + + # WeakLists to track subscribers to each secondary observable + secondary_stakeholders = defaultdict(WeakList) # type: ignore[var-annotated] + + # Buffer for unmatched MatchContainers - automatically expires old items + primary_buffer: TimestampedBufferCollection[MatchContainer[PRIMARY, SECONDARY]] = ( + TimestampedBufferCollection(buffer_size) + ) + + # Subscribe to all secondary observables + secondary_subs = [] + + def has_secondary_progressed_past(secondary_ts: float, primary_ts: float) -> bool: + """Check if secondary stream has progressed past the primary + tolerance.""" + return secondary_ts > primary_ts + match_tolerance + + def remove_stakeholder(stakeholder: MatchContainer) -> None: # type: ignore[type-arg] + """Remove a stakeholder from all tracking structures.""" + primary_buffer.remove(stakeholder) + for weak_list in secondary_stakeholders.values(): + weak_list.discard(stakeholder) + + def on_secondary(i: int, secondary_item: SECONDARY) -> None: + # Add the secondary item to its collection + secondary_collections[i].add(secondary_item) + + # Check all stakeholders for this secondary stream + for stakeholder in secondary_stakeholders[i]: + # If the secondary stream has progressed past this primary, + # we won't be able to match it anymore + if has_secondary_progressed_past(secondary_item.ts, stakeholder.ts): + logger.debug(f"secondary progressed, giving up {stakeholder.ts}") + + remove_stakeholder(stakeholder) + continue + + # Check if this secondary is within tolerance of the primary + if abs(stakeholder.ts - secondary_item.ts) <= match_tolerance: + stakeholder.message_received(i, secondary_item) + + # If all secondaries matched, emit result + if stakeholder.is_complete(): + logger.debug(f"Emitting deferred match {stakeholder.ts}") + observer.on_next(stakeholder.get_tuple()) + remove_stakeholder(stakeholder) + + for i, secondary_obs in enumerate(secondary_observables): + secondary_subs.append( + secondary_obs.subscribe( + lambda x, idx=i: on_secondary(idx, x), # type: ignore[misc] + on_error=observer.on_error, + ) + ) + + def on_primary(primary_item: PRIMARY) -> None: + # Try to find matches in existing secondary collections + matches = [None] * len(secondary_observables) + + for i, collection in enumerate(secondary_collections): + closest = collection.find_closest(primary_item.ts, tolerance=match_tolerance) + if closest is not None: + matches[i] = closest # type: ignore[call-overload] + else: + # Check if this secondary stream has already progressed past this primary + if collection.end_ts is not None and has_secondary_progressed_past( + collection.end_ts, primary_item.ts + ): + # This secondary won't match, so don't buffer this primary + return + + # If all matched, emit immediately without creating MatchContainer + if all(match is not None for match in matches): + logger.debug(f"Immadiate match {primary_item.ts}") + result = (primary_item, *matches) + observer.on_next(result) + else: + logger.debug(f"Deferred match attempt {primary_item.ts}") + match_container = MatchContainer(primary_item, matches) # type: ignore[type-var] + primary_buffer.add(match_container) # type: ignore[arg-type] + + for i, match in enumerate(matches): + if match is None: + secondary_stakeholders[i].append(match_container) + + # Subscribe to primary observable + primary_sub = primary_observable.subscribe( + on_primary, on_error=observer.on_error, on_completed=observer.on_completed + ) + + # Return a CompositeDisposable for proper cleanup + return CompositeDisposable(primary_sub, *secondary_subs) + + return create(subscribe) diff --git a/dimos/types/vector.py b/dimos/types/vector.py new file mode 100644 index 0000000000..654dc1f378 --- /dev/null +++ b/dimos/types/vector.py @@ -0,0 +1,457 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import builtins +from collections.abc import Sequence +from typing import TypeVar, Union + +import numpy as np + +from dimos.types.ros_polyfill import Vector3 + +T = TypeVar("T", bound="Vector") + +# Vector-like types that can be converted to/from Vector +VectorLike = Union[Sequence[int | float], Vector3, "Vector", np.ndarray] # type: ignore[type-arg] + + +class Vector: + """A wrapper around numpy arrays for vector operations with intuitive syntax.""" + + def __init__(self, *args: VectorLike) -> None: + """Initialize a vector from components or another iterable. + + Examples: + Vector(1, 2) # 2D vector + Vector(1, 2, 3) # 3D vector + Vector([1, 2, 3]) # From list + Vector(np.array([1, 2, 3])) # From numpy array + """ + if len(args) == 1 and hasattr(args[0], "__iter__"): + self._data = np.array(args[0], dtype=float) + + elif len(args) == 1: + self._data = np.array([args[0].x, args[0].y, args[0].z], dtype=float) # type: ignore[union-attr] + + else: + self._data = np.array(args, dtype=float) + + @property + def yaw(self) -> float: + return self.x + + @property + def tuple(self) -> tuple[float, ...]: + """Tuple representation of the vector.""" + return tuple(self._data) + + @property + def x(self) -> float: + """X component of the vector.""" + return self._data[0] if len(self._data) > 0 else 0.0 + + @property + def y(self) -> float: + """Y component of the vector.""" + return self._data[1] if len(self._data) > 1 else 0.0 + + @property + def z(self) -> float: + """Z component of the vector.""" + return self._data[2] if len(self._data) > 2 else 0.0 + + @property + def dim(self) -> int: + """Dimensionality of the vector.""" + return len(self._data) + + @property + def data(self) -> np.ndarray: # type: ignore[type-arg] + """Get the underlying numpy array.""" + return self._data + + def __getitem__(self, idx: int): # type: ignore[no-untyped-def] + return self._data[idx] + + def __repr__(self) -> str: + return f"Vector({self.data})" + + def __str__(self) -> str: + if self.dim < 2: + return self.__repr__() + + def getArrow(): # type: ignore[no-untyped-def] + repr = ["←", "↖", "↑", "↗", "→", "↘", "↓", "↙"] + + if self.x == 0 and self.y == 0: + return "·" + + # Calculate angle in radians and convert to directional index + angle = np.arctan2(self.y, self.x) + # Map angle to 0-7 index (8 directions) with proper orientation + dir_index = int(((angle + np.pi) * 4 / np.pi) % 8) + # Get directional arrow symbol + return repr[dir_index] + + return f"{getArrow()} Vector {self.__repr__()}" # type: ignore[no-untyped-call] + + def serialize(self) -> builtins.tuple: # type: ignore[type-arg] + """Serialize the vector to a tuple.""" + return {"type": "vector", "c": self._data.tolist()} # type: ignore[return-value] + + def __eq__(self, other) -> bool: # type: ignore[no-untyped-def] + """Check if two vectors are equal using numpy's allclose for floating point comparison.""" + if not isinstance(other, Vector): + return False + if len(self._data) != len(other._data): + return False + return np.allclose(self._data, other._data) + + def __add__(self: T, other: VectorLike) -> T: + other = to_vector(other) + if self.dim != other.dim: + max_dim = max(self.dim, other.dim) + return self.pad(max_dim) + other.pad(max_dim) + return self.__class__(self._data + other._data) + + def __sub__(self: T, other: VectorLike) -> T: + other = to_vector(other) + if self.dim != other.dim: + max_dim = max(self.dim, other.dim) + return self.pad(max_dim) - other.pad(max_dim) + return self.__class__(self._data - other._data) + + def __mul__(self: T, scalar: float) -> T: + return self.__class__(self._data * scalar) + + def __rmul__(self: T, scalar: float) -> T: + return self.__mul__(scalar) + + def __truediv__(self: T, scalar: float) -> T: + return self.__class__(self._data / scalar) + + def __neg__(self: T) -> T: + return self.__class__(-self._data) + + def dot(self, other: VectorLike) -> float: + """Compute dot product.""" + other = to_vector(other) + return float(np.dot(self._data, other._data)) + + def cross(self: T, other: VectorLike) -> T: + """Compute cross product (3D vectors only).""" + if self.dim != 3: + raise ValueError("Cross product is only defined for 3D vectors") + + other = to_vector(other) + if other.dim != 3: + raise ValueError("Cross product requires two 3D vectors") + + return self.__class__(np.cross(self._data, other._data)) + + def length(self) -> float: + """Compute the Euclidean length (magnitude) of the vector.""" + return float(np.linalg.norm(self._data)) + + def length_squared(self) -> float: + """Compute the squared length of the vector (faster than length()).""" + return float(np.sum(self._data * self._data)) + + def normalize(self: T) -> T: + """Return a normalized unit vector in the same direction.""" + length = self.length() + if length < 1e-10: # Avoid division by near-zero + return self.__class__(np.zeros_like(self._data)) + return self.__class__(self._data / length) + + def to_2d(self: T) -> T: + """Convert a vector to a 2D vector by taking only the x and y components.""" + return self.__class__(self._data[:2]) + + def pad(self: T, dim: int) -> T: + """Pad a vector with zeros to reach the specified dimension. + + If vector already has dimension >= dim, it is returned unchanged. + """ + if self.dim >= dim: + return self + + padded = np.zeros(dim, dtype=float) + padded[: len(self._data)] = self._data + return self.__class__(padded) + + def distance(self, other: VectorLike) -> float: + """Compute Euclidean distance to another vector.""" + other = to_vector(other) + return float(np.linalg.norm(self._data - other._data)) + + def distance_squared(self, other: VectorLike) -> float: + """Compute squared Euclidean distance to another vector (faster than distance()).""" + other = to_vector(other) + diff = self._data - other._data + return float(np.sum(diff * diff)) + + def angle(self, other: VectorLike) -> float: + """Compute the angle (in radians) between this vector and another.""" + other = to_vector(other) + if self.length() < 1e-10 or other.length() < 1e-10: + return 0.0 + + cos_angle = np.clip( + np.dot(self._data, other._data) + / (np.linalg.norm(self._data) * np.linalg.norm(other._data)), + -1.0, + 1.0, + ) + return float(np.arccos(cos_angle)) + + def project(self: T, onto: VectorLike) -> T: + """Project this vector onto another vector.""" + onto = to_vector(onto) + onto_length_sq = np.sum(onto._data * onto._data) + if onto_length_sq < 1e-10: + return self.__class__(np.zeros_like(self._data)) + + scalar_projection = np.dot(self._data, onto._data) / onto_length_sq + return self.__class__(scalar_projection * onto._data) + + @classmethod + def zeros(cls: type[T], dim: int) -> T: + """Create a zero vector of given dimension.""" + return cls(np.zeros(dim)) + + @classmethod + def ones(cls: type[T], dim: int) -> T: + """Create a vector of ones with given dimension.""" + return cls(np.ones(dim)) + + @classmethod + def unit_x(cls: type[T], dim: int = 3) -> T: + """Create a unit vector in the x direction.""" + v = np.zeros(dim) + v[0] = 1.0 + return cls(v) + + @classmethod + def unit_y(cls: type[T], dim: int = 3) -> T: + """Create a unit vector in the y direction.""" + v = np.zeros(dim) + v[1] = 1.0 + return cls(v) + + @classmethod + def unit_z(cls: type[T], dim: int = 3) -> T: + """Create a unit vector in the z direction.""" + v = np.zeros(dim) + if dim > 2: + v[2] = 1.0 + return cls(v) + + def to_list(self) -> list[float]: + """Convert the vector to a list.""" + return self._data.tolist() # type: ignore[no-any-return] + + def to_tuple(self) -> builtins.tuple[float, ...]: + """Convert the vector to a tuple.""" + return tuple(self._data) + + def to_numpy(self) -> np.ndarray: # type: ignore[type-arg] + """Convert the vector to a numpy array.""" + return self._data + + def is_zero(self) -> bool: + """Check if this is a zero vector (all components are zero). + + Returns: + True if all components are zero, False otherwise + """ + return np.allclose(self._data, 0.0) + + def __bool__(self) -> bool: + """Boolean conversion for Vector. + + A Vector is considered False if it's a zero vector (all components are zero), + and True otherwise. + + Returns: + False if vector is zero, True otherwise + """ + return not self.is_zero() + + +def to_numpy(value: VectorLike) -> np.ndarray: # type: ignore[type-arg] + """Convert a vector-compatible value to a numpy array. + + Args: + value: Any vector-like object (Vector, numpy array, tuple, list) + + Returns: + Numpy array representation + """ + if isinstance(value, Vector3): + return np.array([value.x, value.y, value.z], dtype=float) + if isinstance(value, Vector): + return value.data + elif isinstance(value, np.ndarray): + return value + else: + return np.array(value, dtype=float) + + +def to_vector(value: VectorLike) -> Vector: + """Convert a vector-compatible value to a Vector object. + + Args: + value: Any vector-like object (Vector, numpy array, tuple, list) + + Returns: + Vector object + """ + if isinstance(value, Vector): + return value + else: + return Vector(value) + + +def to_tuple(value: VectorLike) -> tuple[float, ...]: + """Convert a vector-compatible value to a tuple. + + Args: + value: Any vector-like object (Vector, numpy array, tuple, list) + + Returns: + Tuple of floats + """ + if isinstance(value, Vector3): + return tuple([value.x, value.y, value.z]) + if isinstance(value, Vector): + return tuple(value.data) + elif isinstance(value, np.ndarray): + return tuple(value.tolist()) + elif isinstance(value, tuple): + return value + else: + return tuple(value) + + +def to_list(value: VectorLike) -> list[float]: + """Convert a vector-compatible value to a list. + + Args: + value: Any vector-like object (Vector, numpy array, tuple, list) + + Returns: + List of floats + """ + if isinstance(value, Vector): + return value.data.tolist() # type: ignore[no-any-return] + elif isinstance(value, np.ndarray): + return value.tolist() # type: ignore[no-any-return] + elif isinstance(value, list): + return value + else: + return list(value) # type: ignore[arg-type] + + +# Helper functions to check dimensionality +def is_2d(value: VectorLike) -> bool: + """Check if a vector-compatible value is 2D. + + Args: + value: Any vector-like object (Vector, numpy array, tuple, list) + + Returns: + True if the value is 2D + """ + if isinstance(value, Vector3): + return False + elif isinstance(value, Vector): + return len(value) == 2 # type: ignore[arg-type] + elif isinstance(value, np.ndarray): + return value.shape[-1] == 2 or value.size == 2 + else: + return len(value) == 2 + + +def is_3d(value: VectorLike) -> bool: + """Check if a vector-compatible value is 3D. + + Args: + value: Any vector-like object (Vector, numpy array, tuple, list) + + Returns: + True if the value is 3D + """ + if isinstance(value, Vector): + return len(value) == 3 # type: ignore[arg-type] + elif isinstance(value, Vector3): + return True + elif isinstance(value, np.ndarray): + return value.shape[-1] == 3 or value.size == 3 + else: + return len(value) == 3 + + +# Extraction functions for XYZ components +def x(value: VectorLike) -> float: + """Get the X component of a vector-compatible value. + + Args: + value: Any vector-like object (Vector, numpy array, tuple, list) + + Returns: + X component as a float + """ + if isinstance(value, Vector): + return value.x + elif isinstance(value, Vector3): + return value.x # type: ignore[no-any-return] + else: + return float(to_numpy(value)[0]) + + +def y(value: VectorLike) -> float: + """Get the Y component of a vector-compatible value. + + Args: + value: Any vector-like object (Vector, numpy array, tuple, list) + + Returns: + Y component as a float + """ + if isinstance(value, Vector): + return value.y + elif isinstance(value, Vector3): + return value.y # type: ignore[no-any-return] + else: + arr = to_numpy(value) + return float(arr[1]) if len(arr) > 1 else 0.0 + + +def z(value: VectorLike) -> float: + """Get the Z component of a vector-compatible value. + + Args: + value: Any vector-like object (Vector, numpy array, tuple, list) + + Returns: + Z component as a float + """ + if isinstance(value, Vector): + return value.z + elif isinstance(value, Vector3): + return value.z # type: ignore[no-any-return] + else: + arr = to_numpy(value) + return float(arr[2]) if len(arr) > 2 else 0.0 diff --git a/dimos/types/videostream.py b/dimos/types/videostream.py deleted file mode 100644 index 820f24efe2..0000000000 --- a/dimos/types/videostream.py +++ /dev/null @@ -1,116 +0,0 @@ -from datetime import timedelta -import cv2 -import numpy as np -import os -from reactivex import Observable -from reactivex import operators as ops - -class StreamUtils: - def limit_emission_rate(frame_stream, time_delta=timedelta(milliseconds=40)): - return frame_stream.pipe( - ops.throttle_first(time_delta) - ) - - -# TODO: Reorganize, filenaming -class FrameProcessor: - def __init__(self, output_dir='/app/assets/frames'): - self.output_dir = output_dir - os.makedirs(self.output_dir, exist_ok=True) - self.image_count = 0 - # TODO: Add randomness to jpg folder storage naming. - # Will overwrite between sessions. - - def to_grayscale(self, frame): - if frame is None: - print("Received None frame for grayscale conversion.") - return None - return cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) - - def edge_detection(self, frame): - return cv2.Canny(frame, 100, 200) - - def resize(self, frame, scale=0.5): - return cv2.resize(frame, None, fx=scale, fy=scale, interpolation=cv2.INTER_AREA) - - def export_to_jpeg(self, frame, save_limit=100, suffix=""): - if frame is None: - print("Error: Attempted to save a None image.") - return None - - # Check if the image has an acceptable number of channels - if len(frame.shape) == 3 and frame.shape[2] not in [1, 3, 4]: - print(f"Error: Frame with shape {frame.shape} has unsupported number of channels.") - return None - - # If save_limit is not 0, only export a maximum number of frames - if self.image_count > save_limit: - return frame - - filepath = os.path.join(self.output_dir, f'{suffix}_image_{self.image_count}.jpg') - cv2.imwrite(filepath, frame) - self.image_count += 1 - return frame - - def compute_optical_flow(self, acc, current_frame): - prev_frame, _ = acc # acc (accumulator) contains the previous frame and its flow (which is ignored here) - - if prev_frame is None: - # Skip processing for the first frame as there's no previous frame to compare against. - return (current_frame, None) - - # Convert frames to grayscale (if not already done) - gray_current = self.to_grayscale(current_frame) - gray_prev = self.to_grayscale(prev_frame) - - # Compute optical flow - flow = cv2.calcOpticalFlowFarneback(gray_prev, gray_current, None, 0.5, 3, 15, 3, 5, 1.2, 0) - - # Relevancy calulation (average magnitude of flow vectors) - mag, _ = cv2.cartToPolar(flow[..., 0], flow[..., 1]) - relevancy = np.mean(mag) - - # Return the current frame as the new previous frame and the processed optical flow, with relevancy score - return (current_frame, flow, relevancy) - - def visualize_flow(self, flow): - if flow is None: - return None - hsv = np.zeros((flow.shape[0], flow.shape[1], 3), dtype=np.uint8) - hsv[..., 1] = 255 - mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1]) - hsv[..., 0] = ang * 180 / np.pi / 2 - hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX) - rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR) - return rgb - - # ============================== - - def process_stream_edge_detection(self, frame_stream): - return frame_stream.pipe( - ops.map(self.edge_detection), - ) - - def process_stream_resize(self, frame_stream): - return frame_stream.pipe( - ops.map(self.resize), - ) - - def process_stream_to_greyscale(self, frame_stream): - return frame_stream.pipe( - ops.map(self.to_grayscale), - ) - - # TODO: Propogate up relevancy score from compute_optical_flow - def process_stream_optical_flow(self, frame_stream): - return frame_stream.pipe( - ops.scan(self.compute_optical_flow, (None, None)), # Initial value for scan is (None, None) - ops.map(lambda result: result[1]), # Extract only the flow part from the tuple - ops.filter(lambda flow: flow is not None), - ops.map(self.visualize_flow), - ) - - def process_stream_export_to_jpeg(self, frame_stream, suffix=""): - return frame_stream.pipe( - ops.map(lambda frame: self.export_to_jpeg(frame, suffix=suffix)), - ) \ No newline at end of file diff --git a/dimos/types/weaklist.py b/dimos/types/weaklist.py new file mode 100644 index 0000000000..a720d54e2d --- /dev/null +++ b/dimos/types/weaklist.py @@ -0,0 +1,86 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Weak reference list implementation that automatically removes dead references.""" + +from collections.abc import Iterator +from typing import Any +import weakref + + +class WeakList: + """A list that holds weak references to objects. + + Objects are automatically removed when garbage collected. + Supports iteration, append, remove, and length operations. + """ + + def __init__(self) -> None: + self._refs = [] # type: ignore[var-annotated] + + def append(self, obj: Any) -> None: + """Add an object to the list (stored as weak reference).""" + + def _cleanup(ref) -> None: # type: ignore[no-untyped-def] + try: + self._refs.remove(ref) + except ValueError: + pass + + self._refs.append(weakref.ref(obj, _cleanup)) + + def remove(self, obj: Any) -> None: + """Remove an object from the list.""" + for i, ref in enumerate(self._refs): + if ref() is obj: + del self._refs[i] + return + raise ValueError(f"{obj} not in WeakList") + + def discard(self, obj: Any) -> None: + """Remove an object from the list if present, otherwise do nothing.""" + try: + self.remove(obj) + except ValueError: + pass + + def __iter__(self) -> Iterator[Any]: + """Iterate over live objects, skipping dead references.""" + # Create a copy to avoid modification during iteration + for ref in self._refs[:]: + obj = ref() + if obj is not None: + yield obj + + def __len__(self) -> int: + """Return count of live objects.""" + return sum(1 for _ in self) + + def __contains__(self, obj: Any) -> bool: + """Check if object is in the list.""" + return any(ref() is obj for ref in self._refs) + + def clear(self) -> None: + """Remove all references.""" + self._refs.clear() + + def __getitem__(self, index: int) -> Any: + """Get object at index (only counting live objects).""" + for i, obj in enumerate(self): + if i == index: + return obj + raise IndexError("WeakList index out of range") + + def __repr__(self) -> str: + return f"WeakList({list(self)})" diff --git a/dimos/utils/actor_registry.py b/dimos/utils/actor_registry.py new file mode 100644 index 0000000000..6f6d219594 --- /dev/null +++ b/dimos/utils/actor_registry.py @@ -0,0 +1,84 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared memory registry for tracking actor deployments across processes.""" + +import json +from multiprocessing import shared_memory + + +class ActorRegistry: + """Shared memory registry of actor deployments.""" + + SHM_NAME = "dimos_actor_registry" + SHM_SIZE = 65536 # 64KB should be enough for most deployments + + @staticmethod + def update(actor_name: str, worker_id: str) -> None: + """Update registry with new actor deployment.""" + try: + shm = shared_memory.SharedMemory(name=ActorRegistry.SHM_NAME) + except FileNotFoundError: + shm = shared_memory.SharedMemory( + name=ActorRegistry.SHM_NAME, create=True, size=ActorRegistry.SHM_SIZE + ) + + # Read existing data + data = ActorRegistry._read_from_shm(shm) + + # Update with new actor + data[actor_name] = worker_id + + # Write back + ActorRegistry._write_to_shm(shm, data) + shm.close() + + @staticmethod + def get_all() -> dict[str, str]: + """Get all actor->worker mappings.""" + try: + shm = shared_memory.SharedMemory(name=ActorRegistry.SHM_NAME) + data = ActorRegistry._read_from_shm(shm) + shm.close() + return data + except FileNotFoundError: + return {} + + @staticmethod + def clear() -> None: + """Clear the registry and free shared memory.""" + try: + shm = shared_memory.SharedMemory(name=ActorRegistry.SHM_NAME) + ActorRegistry._write_to_shm(shm, {}) + shm.close() + shm.unlink() + except FileNotFoundError: + pass + + @staticmethod + def _read_from_shm(shm) -> dict[str, str]: # type: ignore[no-untyped-def] + """Read JSON data from shared memory.""" + raw = bytes(shm.buf[:]).rstrip(b"\x00") + if not raw: + return {} + return json.loads(raw.decode("utf-8")) # type: ignore[no-any-return] + + @staticmethod + def _write_to_shm(shm, data: dict[str, str]): # type: ignore[no-untyped-def] + """Write JSON data to shared memory.""" + json_bytes = json.dumps(data).encode("utf-8") + if len(json_bytes) > ActorRegistry.SHM_SIZE: + raise ValueError("Registry data too large for shared memory") + shm.buf[: len(json_bytes)] = json_bytes + shm.buf[len(json_bytes) :] = b"\x00" * (ActorRegistry.SHM_SIZE - len(json_bytes)) diff --git a/dimos/utils/cli/__init__.py b/dimos/utils/cli/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dimos/utils/cli/agentspy/agentspy.py b/dimos/utils/cli/agentspy/agentspy.py new file mode 100644 index 0000000000..52760cb2da --- /dev/null +++ b/dimos/utils/cli/agentspy/agentspy.py @@ -0,0 +1,238 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections import deque +from dataclasses import dataclass +import time +from typing import Any, Union + +from langchain_core.messages import ( + AIMessage, + HumanMessage, + SystemMessage, + ToolMessage, +) +from textual.app import App, ComposeResult +from textual.binding import Binding +from textual.widgets import Footer, RichLog + +from dimos.protocol.pubsub.lcmpubsub import PickleLCM +from dimos.utils.cli import theme + +# Type alias for all message types we might receive +AnyMessage = Union[SystemMessage, ToolMessage, AIMessage, HumanMessage] + + +@dataclass +class MessageEntry: + """Store a single message with metadata.""" + + timestamp: float + message: AnyMessage + + def __post_init__(self) -> None: + """Initialize timestamp if not provided.""" + if self.timestamp is None: + self.timestamp = time.time() + + +class AgentMessageMonitor: + """Monitor agent messages published via LCM.""" + + def __init__(self, topic: str = "/agent", max_messages: int = 1000) -> None: + self.topic = topic + self.max_messages = max_messages + self.messages: deque[MessageEntry] = deque(maxlen=max_messages) + self.transport = PickleLCM() + self.transport.start() + self.callbacks: list[callable] = [] # type: ignore[valid-type] + pass + + def start(self) -> None: + """Start monitoring messages.""" + self.transport.subscribe(self.topic, self._handle_message) + + def stop(self) -> None: + """Stop monitoring.""" + # PickleLCM doesn't have explicit stop method + pass + + def _handle_message(self, msg: Any, topic: str) -> None: + """Handle incoming messages.""" + # Check if it's one of the message types we care about + if isinstance(msg, SystemMessage | ToolMessage | AIMessage | HumanMessage): + entry = MessageEntry(timestamp=time.time(), message=msg) + self.messages.append(entry) + + # Notify callbacks + for callback in self.callbacks: + callback(entry) # type: ignore[misc] + else: + pass + + def subscribe(self, callback: callable) -> None: # type: ignore[valid-type] + """Subscribe to new messages.""" + self.callbacks.append(callback) + + def get_messages(self) -> list[MessageEntry]: + """Get all stored messages.""" + return list(self.messages) + + +def format_timestamp(timestamp: float) -> str: + """Format timestamp as HH:MM:SS.mmm.""" + return ( + time.strftime("%H:%M:%S", time.localtime(timestamp)) + f".{int((timestamp % 1) * 1000):03d}" + ) + + +def get_message_type_and_style(msg: AnyMessage) -> tuple[str, str]: + """Get message type name and style color.""" + if isinstance(msg, HumanMessage): + return "Human ", "green" + elif isinstance(msg, AIMessage): + if hasattr(msg, "metadata") and msg.metadata.get("state"): + return "State ", "blue" + return "Agent ", "yellow" + elif isinstance(msg, ToolMessage): + return "Tool ", "red" + elif isinstance(msg, SystemMessage): + return "System", "red" + else: + return "Unkn ", "white" + + +def format_message_content(msg: AnyMessage) -> str: + """Format message content for display.""" + if isinstance(msg, ToolMessage): + return f"{msg.name}() -> {msg.content}" + elif isinstance(msg, AIMessage) and msg.tool_calls: + # Include tool calls in content + tool_info = [] + for tc in msg.tool_calls: + args_str = str(tc.get("args", {})) + tool_info.append(f"{tc.get('name')}({args_str})") + content = msg.content or "" + if content and tool_info: + return f"{content}\n[Tool Calls: {', '.join(tool_info)}]" + elif tool_info: + return f"[Tool Calls: {', '.join(tool_info)}]" + return content # type: ignore[return-value] + else: + return str(msg.content) if hasattr(msg, "content") else str(msg) + + +class AgentSpyApp(App): # type: ignore[type-arg] + """TUI application for monitoring agent messages.""" + + CSS_PATH = theme.CSS_PATH + + CSS = f""" + Screen {{ + layout: vertical; + background: {theme.BACKGROUND}; + }} + + RichLog {{ + height: 1fr; + border: none; + background: {theme.BACKGROUND}; + padding: 0 1; + }} + + Footer {{ + dock: bottom; + height: 1; + }} + """ + + BINDINGS = [ + Binding("q", "quit", "Quit"), + Binding("c", "clear", "Clear"), + Binding("ctrl+c", "quit", show=False), + ] + + def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + super().__init__(*args, **kwargs) + self.monitor = AgentMessageMonitor() + self.message_log: RichLog | None = None + + def compose(self) -> ComposeResult: + """Compose the UI.""" + self.message_log = RichLog(wrap=True, highlight=True, markup=True) + yield self.message_log + yield Footer() + + def on_mount(self) -> None: + """Start monitoring when app mounts.""" + self.theme = "flexoki" + + # Subscribe to new messages + self.monitor.subscribe(self.on_new_message) + self.monitor.start() + + # Write existing messages to the log + for entry in self.monitor.get_messages(): + self.on_new_message(entry) + + def on_unmount(self) -> None: + """Stop monitoring when app unmounts.""" + self.monitor.stop() + + def on_new_message(self, entry: MessageEntry) -> None: + """Handle new messages.""" + if self.message_log: + msg = entry.message + msg_type, style = get_message_type_and_style(msg) + content = format_message_content(msg) + + # Format the message for the log + timestamp = format_timestamp(entry.timestamp) + self.message_log.write( + f"[dim white]{timestamp}[/dim white] | " + f"[bold {style}]{msg_type}[/bold {style}] | " + f"[{style}]{content}[/{style}]" + ) + + def refresh_display(self) -> None: + """Refresh the message display.""" + # Not needed anymore as messages are written directly to the log + + def action_clear(self) -> None: + """Clear message history.""" + self.monitor.messages.clear() + if self.message_log: + self.message_log.clear() + + +def main() -> None: + """Main entry point for agentspy.""" + import sys + + if len(sys.argv) > 1 and sys.argv[1] == "web": + import os + + from textual_serve.server import Server # type: ignore[import-not-found] + + server = Server(f"python {os.path.abspath(__file__)}") + server.serve() + else: + app = AgentSpyApp() + app.run() + + +if __name__ == "__main__": + main() diff --git a/dimos/utils/cli/agentspy/demo_agentspy.py b/dimos/utils/cli/agentspy/demo_agentspy.py new file mode 100755 index 0000000000..c747ab65f6 --- /dev/null +++ b/dimos/utils/cli/agentspy/demo_agentspy.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Demo script to test agent message publishing and agentspy reception.""" + +import time + +from langchain_core.messages import ( + AIMessage, + HumanMessage, + SystemMessage, + ToolMessage, +) + +from dimos.protocol.pubsub import lcm # type: ignore[attr-defined] +from dimos.protocol.pubsub.lcmpubsub import PickleLCM + + +def test_publish_messages() -> None: + """Publish test messages to verify agentspy is working.""" + print("Starting agent message publisher demo...") + + # Create transport + transport = PickleLCM() + topic = lcm.Topic("/agent") + + print(f"Publishing to topic: {topic}") + + # Test messages + messages = [ + SystemMessage("System initialized for testing"), + HumanMessage("Hello agent, can you help me?"), + AIMessage( + "Of course! I'm here to help.", + tool_calls=[{"name": "get_info", "args": {"query": "test"}, "id": "1"}], + ), + ToolMessage(name="get_info", content="Test result: success", tool_call_id="1"), + AIMessage("The test was successful!", metadata={"state": True}), + ] + + # Publish messages with delays + for i, msg in enumerate(messages): + print(f"\nPublishing message {i + 1}: {type(msg).__name__}") + print(f"Content: {msg.content if hasattr(msg, 'content') else msg}") + + transport.publish(topic, msg) + time.sleep(1) # Wait 1 second between messages + + print("\nAll messages published! Check agentspy to see if they were received.") + print("Keeping publisher alive for 10 more seconds...") + time.sleep(10) + + +if __name__ == "__main__": + test_publish_messages() diff --git a/dimos/utils/cli/boxglove/boxglove.py b/dimos/utils/cli/boxglove/boxglove.py new file mode 100644 index 0000000000..3ace1c1aaa --- /dev/null +++ b/dimos/utils/cli/boxglove/boxglove.py @@ -0,0 +1,292 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +import reactivex.operators as ops +from rich.text import Text +from textual.app import App, ComposeResult +from textual.containers import Container +from textual.reactive import reactive +from textual.widgets import Footer, Static + +from dimos import core +from dimos.msgs.nav_msgs import OccupancyGrid +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage + +if TYPE_CHECKING: + from reactivex.disposable import Disposable + + from dimos.msgs.nav_msgs import OccupancyGrid + from dimos.utils.cli.boxglove.connection import Connection + + +blocks = "█▗▖▝▘" +shades = "█░░░░" +crosses = "┼┌┐└┘" +quadrant = "█▟▙▜▛" +triangles = "◼◢◣◥◤" # 45-degree triangular blocks + + +alphabet = crosses + +# Box drawing characters for smooth edges +top_left = alphabet[1] # Quadrant lower right +top_right = alphabet[2] # Quadrant lower left +bottom_left = alphabet[3] # Quadrant upper right +bottom_right = alphabet[4] # Quadrant upper left +full = alphabet[0] # Full block + + +class OccupancyGridApp(App): # type: ignore[type-arg] + """A Textual app for visualizing OccupancyGrid data in real-time.""" + + CSS = """ + Screen { + layout: vertical; + overflow: hidden; + } + + #grid-container { + width: 100%; + height: 1fr; + overflow: hidden; + margin: 0; + padding: 0; + } + + #grid-display { + width: 100%; + height: 100%; + margin: 0; + padding: 0; + } + + Footer { + dock: bottom; + height: 1; + } + """ + + # Reactive properties + grid_data: reactive[OccupancyGrid | None] = reactive(None) + + BINDINGS = [ + ("q", "quit", "Quit"), + ("ctrl+c", "quit", "Quit"), + ] + + def __init__(self, connection: Connection, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + super().__init__(*args, **kwargs) + self.connection = connection + self.subscription: Disposable | None = None + self.grid_display: Static | None = None + self.cached_grid: OccupancyGrid | None = None + + def compose(self) -> ComposeResult: + """Create the app layout.""" + # Container for the grid (no scrolling since we scale to fit) + with Container(id="grid-container"): + self.grid_display = Static("", id="grid-display") + yield self.grid_display + + yield Footer() + + def on_mount(self) -> None: + """Subscribe to the connection when the app starts.""" + self.theme = "flexoki" + + # Subscribe to the OccupancyGrid stream + def on_grid(grid: OccupancyGrid) -> None: + self.grid_data = grid + + def on_error(error: Exception) -> None: + self.notify(f"Error: {error}", severity="error") + + self.subscription = self.connection().subscribe(on_next=on_grid, on_error=on_error) # type: ignore[assignment] + + async def on_unmount(self) -> None: + """Clean up subscription when app closes.""" + if self.subscription: + self.subscription.dispose() + + def watch_grid_data(self, grid: OccupancyGrid | None) -> None: + """Update display when new grid data arrives.""" + if grid is None: + return + + # Cache the grid for rerendering on terminal resize + self.cached_grid = grid + + # Render the grid as ASCII art + grid_text = self.render_grid(grid) + self.grid_display.update(grid_text) # type: ignore[union-attr] + + def on_resize(self, event) -> None: # type: ignore[no-untyped-def] + """Handle terminal resize events.""" + if self.cached_grid: + # Re-render with new terminal dimensions + grid_text = self.render_grid(self.cached_grid) + self.grid_display.update(grid_text) # type: ignore[union-attr] + + def render_grid(self, grid: OccupancyGrid) -> Text: + """Render the OccupancyGrid as colored ASCII art, scaled to fit terminal.""" + text = Text() + + # Get the actual container dimensions + container = self.query_one("#grid-container") + content_width = container.content_size.width + content_height = container.content_size.height + + # Each cell will be 2 chars wide to make square pixels + terminal_width = max(1, content_width // 2) + terminal_height = max(1, content_height) + + # Handle edge cases + if grid.width == 0 or grid.height == 0: + return text # Return empty text for empty grid + + # Calculate scaling factors (as floats for smoother scaling) + scale_x = grid.width / terminal_width + scale_y = grid.height / terminal_height + + # Use the larger scale to ensure the grid fits + scale_float = max(1.0, max(scale_x, scale_y)) + + # For smoother resizing, we'll use fractional scaling + # This means we might sample between grid cells + render_width = min(int(grid.width / scale_float), terminal_width) + render_height = min(int(grid.height / scale_float), terminal_height) + + # Store both integer and float scale for different uses + int(np.ceil(scale_float)) # For legacy compatibility + + # Adjust render dimensions to use all available space + # This reduces jumping by allowing fractional cell sizes + actual_scale_x = grid.width / render_width if render_width > 0 else 1 + actual_scale_y = grid.height / render_height if render_height > 0 else 1 + + # Function to get value with fractional scaling + def get_cell_value(grid_data: np.ndarray, x: int, y: int) -> int: # type: ignore[type-arg] + # Use fractional coordinates for smoother scaling + y_center = int((y + 0.5) * actual_scale_y) + x_center = int((x + 0.5) * actual_scale_x) + + # Clamp to grid bounds + y_center = max(0, min(y_center, grid.height - 1)) + x_center = max(0, min(x_center, grid.width - 1)) + + # For now, just sample the center point + # Could do area averaging for smoother results + return grid_data[y_center, x_center] # type: ignore[no-any-return] + + # Helper function to check if a cell is an obstacle + def is_obstacle(grid_data: np.ndarray, x: int, y: int) -> bool: # type: ignore[type-arg] + if x < 0 or x >= render_width or y < 0 or y >= render_height: + return False + value = get_cell_value(grid_data, x, y) + return value > 90 # Consider cells with >90% probability as obstacles + + # Character and color mapping with intelligent obstacle rendering + def get_cell_char_and_style(grid_data: np.ndarray, x: int, y: int) -> tuple[str, str]: # type: ignore[type-arg] + value = get_cell_value(grid_data, x, y) + norm_value = min(value, 100) / 100.0 + + if norm_value > 0.9: + # Check neighbors for intelligent character selection + top = is_obstacle(grid_data, x, y + 1) + bottom = is_obstacle(grid_data, x, y - 1) + left = is_obstacle(grid_data, x - 1, y) + right = is_obstacle(grid_data, x + 1, y) + + # Count neighbors + neighbor_count = sum([top, bottom, left, right]) + + # Select character based on neighbor configuration + if neighbor_count == 4: + # All neighbors are obstacles - use full block + symbol = full + full + elif neighbor_count == 3: + # Three neighbors - use full block (interior edge) + symbol = full + full + elif neighbor_count == 2: + # Two neighbors - check configuration + if top and bottom: + symbol = full + full # Vertical corridor + elif left and right: + symbol = full + full # Horizontal corridor + elif top and left: + symbol = bottom_right + " " + elif top and right: + symbol = " " + bottom_left + elif bottom and left: + symbol = top_right + " " + elif bottom and right: + symbol = " " + top_left + else: + symbol = full + full + elif neighbor_count == 1: + # One neighbor - point towards it + if top: + symbol = bottom_left + bottom_right + elif bottom: + symbol = top_left + top_right + elif left: + symbol = top_right + bottom_right + elif right: + symbol = top_left + bottom_left + else: + symbol = full + full + else: + # No neighbors - isolated obstacle + symbol = full + full + + return symbol, None # type: ignore[return-value] + else: + return " ", None # type: ignore[return-value] + + # Render the scaled grid row by row (flip Y axis for proper display) + for y in range(render_height - 1, -1, -1): + for x in range(render_width): + char, style = get_cell_char_and_style(grid.grid, x, y) + text.append(char, style=style) + if y > 0: # Add newline except for last row + text.append("\n") + + # Could show scale info in footer status if needed + + return text + + +def main() -> None: + """Run the OccupancyGrid visualizer with a connection.""" + # app = OccupancyGridApp(core.LCMTransport("/global_costmap", OccupancyGrid).observable) + + app = OccupancyGridApp( + lambda: core.LCMTransport("/lidar", LidarMessage) # type: ignore[no-untyped-call] + .observable() + .pipe(ops.map(lambda msg: msg.costmap())) # type: ignore[attr-defined] + ) + app.run() + import time + + while True: + time.sleep(1) + + +if __name__ == "__main__": + main() diff --git a/dimos/utils/cli/boxglove/connection.py b/dimos/utils/cli/boxglove/connection.py new file mode 100644 index 0000000000..1743684626 --- /dev/null +++ b/dimos/utils/cli/boxglove/connection.py @@ -0,0 +1,71 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable +import pickle + +import reactivex as rx +from reactivex import operators as ops +from reactivex.disposable import Disposable +from reactivex.observable import Observable + +from dimos.msgs.nav_msgs import OccupancyGrid +from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.protocol.pubsub import lcm # type: ignore[attr-defined] +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.map import Map +from dimos.utils.data import get_data +from dimos.utils.reactive import backpressure +from dimos.utils.testing import TimedSensorReplay + +Connection = Callable[[], Observable[OccupancyGrid]] + + +def live_connection() -> Observable[OccupancyGrid]: + def subscribe(observer, scheduler=None): # type: ignore[no-untyped-def] + lcm.autoconf() + l = lcm.LCM() + + def on_message(grid: OccupancyGrid, _) -> None: # type: ignore[no-untyped-def] + observer.on_next(grid) + + l.subscribe(lcm.Topic("/global_costmap", OccupancyGrid), on_message) + l.start() + + def dispose() -> None: + l.stop() + + return Disposable(dispose) + + return rx.create(subscribe) + + +def recorded_connection() -> Observable[OccupancyGrid]: + lidar_store = TimedSensorReplay("unitree_office_walk/lidar", autocast=LidarMessage.from_msg) + mapper = Map() + return backpressure( + lidar_store.stream(speed=1).pipe( + ops.map(mapper.add_frame), + ops.map(lambda _: mapper.costmap().inflate(0.1).gradient()), # type: ignore[attr-defined] + ) + ) + + +def single_message() -> Observable[OccupancyGrid]: + pointcloud_pickle = get_data("lcm_msgs") / "sensor_msgs/PointCloud2.pickle" + with open(pointcloud_pickle, "rb") as f: + pointcloud = PointCloud2.lcm_decode(pickle.load(f)) + mapper = Map() + mapper.add_frame(pointcloud) + return rx.just(mapper.costmap()) # type: ignore[attr-defined] diff --git a/dimos/utils/cli/dimos.tcss b/dimos/utils/cli/dimos.tcss new file mode 100644 index 0000000000..3ccbde957d --- /dev/null +++ b/dimos/utils/cli/dimos.tcss @@ -0,0 +1,91 @@ +/* DimOS Base Theme for Textual CLI Applications + * Based on colors.json - Official DimOS color palette + */ + +/* Base Color Palette (from colors.json) */ +$black: #0b0f0f; +$red: #ff0000; +$green: #00eeee; +$yellow: #ffcc00; +$blue: #5c9ff0; +$purple: #00eeee; +$cyan: #00eeee; +$white: #b5e4f4; + +/* Bright Colors */ +$bright-black: #404040; +$bright-red: #ff0000; +$bright-green: #00eeee; +$bright-yellow: #f2ea8c; +$bright-blue: #8cbdf2; +$bright-purple: #00eeee; +$bright-cyan: #00eeee; +$bright-white: #ffffff; + +/* Core Theme Colors */ +$background: #0b0f0f; +$foreground: #b5e4f4; +$cursor: #00eeee; + +/* Semantic Aliases */ +$bg: $black; +$border: $cyan; +$accent: $white; +$dim: $bright-black; +$timestamp: $bright-white; + +/* Message Type Colors */ +$system: $red; +$agent: #88ff88; +$tool: $cyan; +$tool-result: $yellow; +$human: $bright-white; + +/* Status Colors */ +$success: $green; +$error: $red; +$warning: $yellow; +$info: $cyan; + +/* Base Screen */ +Screen { + background: $bg; +} + +/* Default Container */ +Container { + background: $bg; +} + +/* Input Widget */ +Input { + background: $bg; + border: solid $border; + color: $accent; +} + +Input:focus { + border: solid $border; +} + +/* RichLog Widget */ +RichLog { + background: $bg; + border: solid $border; +} + +/* Button Widget */ +Button { + background: $bg; + border: solid $border; + color: $accent; +} + +Button:hover { + background: $dim; + border: solid $accent; +} + +Button:focus { + border: double $accent; +} diff --git a/dimos/utils/cli/foxglove_bridge/run_foxglove_bridge.py b/dimos/utils/cli/foxglove_bridge/run_foxglove_bridge.py new file mode 100644 index 0000000000..782e93b5e0 --- /dev/null +++ b/dimos/utils/cli/foxglove_bridge/run_foxglove_bridge.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +use lcm_foxglove_bridge as a module from dimos_lcm +""" + +import asyncio +import os +import threading + +import dimos_lcm # type: ignore[import-untyped] +from dimos_lcm.foxglove_bridge import FoxgloveBridge # type: ignore[import-untyped] + +dimos_lcm_path = os.path.dirname(os.path.abspath(dimos_lcm.__file__)) +print(f"Using dimos_lcm from: {dimos_lcm_path}") + + +def run_bridge_example() -> None: + """Example of running the bridge in a separate thread""" + + def bridge_thread() -> None: + """Thread function to run the bridge""" + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + bridge_instance = FoxgloveBridge(host="0.0.0.0", port=8765, debug=True, num_threads=4) + + loop.run_until_complete(bridge_instance.run()) + except Exception as e: + print(f"Bridge error: {e}") + finally: + loop.close() + + thread = threading.Thread(target=bridge_thread, daemon=True) + thread.start() + + print("Bridge started in background thread") + print("Open Foxglove Studio and connect to ws://localhost:8765") + print("Press Ctrl+C to exit") + + try: + while True: + threading.Event().wait(1) + except KeyboardInterrupt: + print("Shutting down...") + + +def main() -> None: + run_bridge_example() + + +if __name__ == "__main__": + main() diff --git a/dimos/utils/cli/human/humancli.py b/dimos/utils/cli/human/humancli.py new file mode 100644 index 0000000000..a0ce0afff4 --- /dev/null +++ b/dimos/utils/cli/human/humancli.py @@ -0,0 +1,306 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from datetime import datetime +import textwrap +import threading +from typing import TYPE_CHECKING + +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolCall, ToolMessage +from rich.highlighter import JSONHighlighter +from rich.theme import Theme +from textual.app import App, ComposeResult +from textual.binding import Binding +from textual.containers import Container +from textual.widgets import Input, RichLog + +from dimos.core import pLCMTransport +from dimos.utils.cli import theme +from dimos.utils.generic import truncate_display_string + +if TYPE_CHECKING: + from textual.events import Key + +# Custom theme for JSON highlighting +JSON_THEME = Theme( + { + "json.key": theme.CYAN, + "json.str": theme.ACCENT, + "json.number": theme.ACCENT, + "json.bool_true": theme.ACCENT, + "json.bool_false": theme.ACCENT, + "json.null": theme.DIM, + "json.brace": theme.BRIGHT_WHITE, + } +) + + +class HumanCLIApp(App): # type: ignore[type-arg] + """IRC-like interface for interacting with DimOS agents.""" + + CSS_PATH = theme.CSS_PATH + + CSS = f""" + Screen {{ + background: {theme.BACKGROUND}; + }} + + #chat-container {{ + height: 1fr; + }} + + RichLog {{ + scrollbar-size: 0 0; + }} + + Input {{ + dock: bottom; + }} + """ + + BINDINGS = [ + Binding("q", "quit", "Quit", show=False), + Binding("ctrl+c", "quit", "Quit"), + Binding("ctrl+l", "clear", "Clear chat"), + ] + + def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + super().__init__(*args, **kwargs) + self.human_transport = pLCMTransport("/human_input") # type: ignore[var-annotated] + self.agent_transport = pLCMTransport("/agent") # type: ignore[var-annotated] + self.chat_log: RichLog | None = None + self.input_widget: Input | None = None + self._subscription_thread: threading.Thread | None = None + self._running = False + + def compose(self) -> ComposeResult: + """Compose the IRC-like interface.""" + with Container(id="chat-container"): + self.chat_log = RichLog(highlight=True, markup=True, wrap=False) + yield self.chat_log + + self.input_widget = Input(placeholder="Type a message...") + yield self.input_widget + + def on_mount(self) -> None: + """Initialize the app when mounted.""" + self._running = True + + # Apply custom JSON theme to app console + self.console.push_theme(JSON_THEME) + + # Set custom highlighter for RichLog + self.chat_log.highlighter = JSONHighlighter() # type: ignore[union-attr] + + # Start subscription thread + self._subscription_thread = threading.Thread(target=self._subscribe_to_agent, daemon=True) + self._subscription_thread.start() + + # Focus on input + self.input_widget.focus() # type: ignore[union-attr] + + self.chat_log.write(f"[{theme.ACCENT}]{theme.ascii_logo}[/{theme.ACCENT}]") # type: ignore[union-attr] + + # Welcome message + self._add_system_message("Connected to DimOS Agent Interface") + + def on_unmount(self) -> None: + """Clean up when unmounting.""" + self._running = False + + def _subscribe_to_agent(self) -> None: + """Subscribe to agent messages in a separate thread.""" + + def receive_msg(msg) -> None: # type: ignore[no-untyped-def] + if not self._running: + return + + timestamp = datetime.now().strftime("%H:%M:%S") + + if isinstance(msg, SystemMessage): + self.call_from_thread( + self._add_message, + timestamp, + "system", + truncate_display_string(msg.content, 1000), + theme.YELLOW, + ) + elif isinstance(msg, AIMessage): + content = msg.content or "" + tool_calls = msg.additional_kwargs.get("tool_calls", []) + + # Display the main content first + if content: + self.call_from_thread( + self._add_message, timestamp, "agent", content, theme.AGENT + ) + + # Display tool calls separately with different formatting + if tool_calls: + for tc in tool_calls: + tool_info = self._format_tool_call(tc) + self.call_from_thread( + self._add_message, timestamp, "tool", tool_info, theme.TOOL + ) + + # If neither content nor tool calls, show a placeholder + if not content and not tool_calls: + self.call_from_thread( + self._add_message, timestamp, "agent", "", theme.DIM + ) + elif isinstance(msg, ToolMessage): + self.call_from_thread( + self._add_message, timestamp, "tool", msg.content, theme.TOOL_RESULT + ) + elif isinstance(msg, HumanMessage): + self.call_from_thread( + self._add_message, timestamp, "human", msg.content, theme.HUMAN + ) + + self.agent_transport.subscribe(receive_msg) + + def _format_tool_call(self, tool_call: ToolCall) -> str: + """Format a tool call for display.""" + f = tool_call.get("function", {}) + name = f.get("name", "unknown") # type: ignore[attr-defined] + return f"▶ {name}({f.get('arguments', '')})" # type: ignore[attr-defined] + + def _add_message(self, timestamp: str, sender: str, content: str, color: str) -> None: + """Add a message to the chat log.""" + # Strip leading/trailing whitespace from content + content = content.strip() if content else "" + + # Format timestamp with nicer colors - split into hours, minutes, seconds + time_parts = timestamp.split(":") + if len(time_parts) == 3: + # Format as HH:MM:SS with colored colons + timestamp_formatted = f" [{theme.TIMESTAMP}]{time_parts[0]}:{time_parts[1]}:{time_parts[2]}[/{theme.TIMESTAMP}]" + else: + timestamp_formatted = f" [{theme.TIMESTAMP}]{timestamp}[/{theme.TIMESTAMP}]" + + # Format sender with consistent width + sender_formatted = f"[{color}]{sender:>8}[/{color}]" + + # Calculate the prefix length for proper indentation + # space (1) + timestamp (8) + space (1) + sender (8) + space (1) + separator (1) + space (1) = 21 + prefix = f"{timestamp_formatted} {sender_formatted} │ " + indent = " " * 19 # Spaces to align with the content after the separator + + # Get the width of the chat area (accounting for borders and padding) + width = self.chat_log.size.width - 4 if self.chat_log.size else 76 # type: ignore[union-attr] + + # Calculate the available width for text (subtract prefix length) + text_width = max(width - 20, 40) # Minimum 40 chars for text + + # Split content into lines first (respecting explicit newlines) + lines = content.split("\n") + + for line_idx, line in enumerate(lines): + # Wrap each line to fit the available width + if line_idx == 0: + # First line includes the full prefix + wrapped = textwrap.wrap( + line, width=text_width, initial_indent="", subsequent_indent="" + ) + if wrapped: + self.chat_log.write(prefix + f"[{color}]{wrapped[0]}[/{color}]") # type: ignore[union-attr] + for wrapped_line in wrapped[1:]: + self.chat_log.write(indent + f"│ [{color}]{wrapped_line}[/{color}]") # type: ignore[union-attr] + else: + # Empty line + self.chat_log.write(prefix) # type: ignore[union-attr] + else: + # Subsequent lines from explicit newlines + wrapped = textwrap.wrap( + line, width=text_width, initial_indent="", subsequent_indent="" + ) + if wrapped: + for wrapped_line in wrapped: + self.chat_log.write(indent + f"│ [{color}]{wrapped_line}[/{color}]") # type: ignore[union-attr] + else: + # Empty line + self.chat_log.write(indent + "│") # type: ignore[union-attr] + + def _add_system_message(self, content: str) -> None: + """Add a system message to the chat.""" + timestamp = datetime.now().strftime("%H:%M:%S") + self._add_message(timestamp, "system", content, theme.YELLOW) + + def on_key(self, event: Key) -> None: + """Handle key events.""" + if event.key == "ctrl+c": + self.exit() + event.prevent_default() + + def on_input_submitted(self, event: Input.Submitted) -> None: + """Handle input submission.""" + message = event.value.strip() + if not message: + return + + # Clear input + self.input_widget.value = "" # type: ignore[union-attr] + + # Check for commands + if message.lower() in ["/exit", "/quit"]: + self.exit() + return + elif message.lower() == "/clear": + self.action_clear() + return + elif message.lower() == "/help": + help_text = """Commands: + /clear - Clear the chat log + /help - Show this help message + /exit - Exit the application + /quit - Exit the application + +Tool calls are displayed in cyan with ▶ prefix""" + self._add_system_message(help_text) + return + + # Send to agent (message will be displayed when received back) + self.human_transport.publish(message) + + def action_clear(self) -> None: + """Clear the chat log.""" + self.chat_log.clear() # type: ignore[union-attr] + + def action_quit(self) -> None: # type: ignore[override] + """Quit the application.""" + self._running = False + self.exit() + + +def main() -> None: + """Main entry point for the human CLI.""" + import sys + + if len(sys.argv) > 1 and sys.argv[1] == "web": + # Support for textual-serve web mode + import os + + from textual_serve.server import Server # type: ignore[import-not-found] + + server = Server(f"python {os.path.abspath(__file__)}") + server.serve() + else: + app = HumanCLIApp() + app.run() + + +if __name__ == "__main__": + main() diff --git a/dimos/utils/cli/human/humanclianim.py b/dimos/utils/cli/human/humanclianim.py new file mode 100644 index 0000000000..0d13013baf --- /dev/null +++ b/dimos/utils/cli/human/humanclianim.py @@ -0,0 +1,178 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import random +import sys +import threading +import time + +from terminaltexteffects import Color # type: ignore[attr-defined] + +from dimos.utils.cli import theme + +# Global to store the imported main function +_humancli_main = None +_import_complete = threading.Event() + +print(theme.ACCENT) + + +def import_cli_in_background() -> None: + """Import the heavy CLI modules in the background""" + global _humancli_main + try: + from dimos.utils.cli.human.humancli import main as humancli_main + + _humancli_main = humancli_main + except Exception as e: + print(f"Failed to import CLI: {e}") + finally: + _import_complete.set() + + +def get_effect_config(effect_name: str): # type: ignore[no-untyped-def] + """Get hardcoded configuration for a specific effect""" + # Hardcoded configs for each effect + global_config = { + "final_gradient_stops": [Color(theme.ACCENT)], + } + + configs = { + "randomsequence": { + "speed": 0.075, + }, + "slide": {"direction": "left", "movement_speed": 1.5}, + "sweep": {"direction": "left"}, + "print": { + "print_speed": 10, + "print_head_return_speed": 10, + "final_gradient_stops": [Color(theme.ACCENT)], + }, + "pour": {"pour_speed": 9}, + "matrix": {"rain_symbols": "01", "rain_fall_speed_range": (4, 7)}, + "decrypt": {"typing_speed": 5, "decryption_speed": 3}, + "burn": {"fire_chars": "█", "flame_color": "ffffff"}, + "expand": {"expand_direction": "center"}, + "scattered": {"movement_speed": 0.5}, + "beams": {"movement_speed": 0.5, "beam_delay": 0}, + "middleout": {"center_movement_speed": 3, "full_movement_speed": 0.5}, + "rain": { + "rain_symbols": "░▒▓█", + "rain_fall_speed_range": (5, 10), + }, + "highlight": {"highlight_brightness": 3}, + } + + return {**configs.get(effect_name, {}), **global_config} # type: ignore[dict-item] + + +def run_banner_animation() -> None: + """Run the ASCII banner animation before launching Textual""" + + # Check if we should animate + random_anim = ["scattered", "print", "expand", "slide", "rain"] + animation_style = os.environ.get("DIMOS_BANNER_ANIMATION", random.choice(random_anim)).lower() + + if animation_style == "none": + return # Skip animation + from terminaltexteffects.effects.effect_beams import Beams + from terminaltexteffects.effects.effect_burn import Burn + from terminaltexteffects.effects.effect_decrypt import Decrypt + from terminaltexteffects.effects.effect_expand import Expand + from terminaltexteffects.effects.effect_highlight import Highlight + from terminaltexteffects.effects.effect_matrix import Matrix + from terminaltexteffects.effects.effect_middleout import MiddleOut + from terminaltexteffects.effects.effect_overflow import Overflow + from terminaltexteffects.effects.effect_pour import Pour + from terminaltexteffects.effects.effect_print import Print + from terminaltexteffects.effects.effect_rain import Rain + from terminaltexteffects.effects.effect_random_sequence import RandomSequence + from terminaltexteffects.effects.effect_scattered import Scattered + from terminaltexteffects.effects.effect_slide import Slide + from terminaltexteffects.effects.effect_sweep import Sweep + + # The DIMENSIONAL ASCII art + ascii_art = "\n" + theme.ascii_logo.replace("\n", "\n ") + # Choose effect based on style + effect_map = { + "slide": Slide, + "sweep": Sweep, + "print": Print, + "pour": Pour, + "burn": Burn, + "matrix": Matrix, + "rain": Rain, + "scattered": Scattered, + "expand": Expand, + "decrypt": Decrypt, + "overflow": Overflow, + "randomsequence": RandomSequence, + "beams": Beams, + "middleout": MiddleOut, + "highlight": Highlight, + } + + EffectClass = effect_map.get(animation_style, Slide) + + # Clear screen before starting animation + print("\033[2J\033[H", end="", flush=True) + + # Get effect configuration + effect_config = get_effect_config(animation_style) + + # Create and run the effect with config + effect = EffectClass(ascii_art) + for key, value in effect_config.items(): + setattr(effect.effect_config, key, value) # type: ignore[attr-defined] + + # Run the animation - terminal.print() handles all screen management + with effect.terminal_output() as terminal: # type: ignore[attr-defined] + for frame in effect: # type: ignore[attr-defined] + terminal.print(frame) + + # Brief pause to see the final frame + time.sleep(0.5) + + # Clear screen for Textual to take over + print("\033[2J\033[H", end="") + + +def main() -> None: + """Main entry point - run animation then launch the real CLI""" + + # Start importing CLI in background (this is slow) + import_thread = threading.Thread(target=import_cli_in_background, daemon=True) + import_thread.start() + + # Run the animation while imports happen (if not in web mode) + if not (len(sys.argv) > 1 and sys.argv[1] == "web"): + run_banner_animation() + + # Wait for import to complete + _import_complete.wait(timeout=10) # Max 10 seconds wait + + # Launch the real CLI + if _humancli_main: + _humancli_main() + else: + # Fallback if threaded import failed + from dimos.utils.cli.human.humancli import main as humancli_main + + humancli_main() + + +if __name__ == "__main__": + main() diff --git a/dimos/utils/cli/lcmspy/lcmspy.py b/dimos/utils/cli/lcmspy/lcmspy.py new file mode 100755 index 0000000000..cf02216986 --- /dev/null +++ b/dimos/utils/cli/lcmspy/lcmspy.py @@ -0,0 +1,212 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import deque +from dataclasses import dataclass +from enum import Enum +import threading +import time + +from dimos.protocol.service.lcmservice import LCMConfig, LCMService + + +class BandwidthUnit(Enum): + BP = "B" + KBP = "kB" + MBP = "MB" + GBP = "GB" + + +def human_readable_bytes(bytes_value: float, round_to: int = 2) -> tuple[float, BandwidthUnit]: + """Convert bytes to human-readable format with appropriate units""" + if bytes_value >= 1024**3: # GB + return round(bytes_value / (1024**3), round_to), BandwidthUnit.GBP + elif bytes_value >= 1024**2: # MB + return round(bytes_value / (1024**2), round_to), BandwidthUnit.MBP + elif bytes_value >= 1024: # KB + return round(bytes_value / 1024, round_to), BandwidthUnit.KBP + else: + return round(bytes_value, round_to), BandwidthUnit.BP + + +class Topic: + history_window: float = 60.0 + + def __init__(self, name: str, history_window: float = 60.0) -> None: + self.name = name + # Store (timestamp, data_size) tuples for statistics + self.message_history = deque() # type: ignore[var-annotated] + self.history_window = history_window + # Total traffic accumulator (doesn't get cleaned up) + self.total_traffic_bytes = 0 + + def msg(self, data: bytes) -> None: + # print(f"> msg {self.__str__()} {len(data)} bytes") + datalen = len(data) + self.message_history.append((time.time(), datalen)) + self.total_traffic_bytes += datalen + self._cleanup_old_messages() + + def _cleanup_old_messages(self, max_age: float | None = None) -> None: + """Remove messages older than max_age seconds""" + current_time = time.time() + while self.message_history and current_time - self.message_history[0][0] > ( + max_age or self.history_window + ): + self.message_history.popleft() + + def _get_messages_in_window(self, time_window: float): # type: ignore[no-untyped-def] + """Get messages within the specified time window""" + current_time = time.time() + cutoff_time = current_time - time_window + return [(ts, size) for ts, size in self.message_history if ts >= cutoff_time] + + # avg msg freq in the last n seconds + def freq(self, time_window: float) -> float: + messages = self._get_messages_in_window(time_window) + if not messages: + return 0.0 + return len(messages) / time_window + + # avg bandwidth in kB/s in the last n seconds + def kbps(self, time_window: float) -> float: + messages = self._get_messages_in_window(time_window) + if not messages: + return 0.0 + total_bytes = sum(size for _, size in messages) + total_kbytes = total_bytes / 1000 # Convert bytes to kB + return total_kbytes / time_window # type: ignore[no-any-return] + + def kbps_hr(self, time_window: float, round_to: int = 2) -> tuple[float, BandwidthUnit]: + """Return human-readable bandwidth with appropriate units""" + kbps_val = self.kbps(time_window) + # Convert kB/s to B/s for human_readable_bytes + bps = kbps_val * 1000 + return human_readable_bytes(bps, round_to) + + # avg msg size in the last n seconds + def size(self, time_window: float) -> float: + messages = self._get_messages_in_window(time_window) + if not messages: + return 0.0 + total_size = sum(size for _, size in messages) + return total_size / len(messages) # type: ignore[no-any-return] + + def total_traffic(self) -> int: + """Return total traffic passed in bytes since the beginning""" + return self.total_traffic_bytes + + def total_traffic_hr(self) -> tuple[float, BandwidthUnit]: + """Return human-readable total traffic with appropriate units""" + total_bytes = self.total_traffic() + return human_readable_bytes(total_bytes) + + def __str__(self) -> str: + return f"topic({self.name})" + + +@dataclass +class LCMSpyConfig(LCMConfig): + topic_history_window: float = 60.0 + + +class LCMSpy(LCMService, Topic): + default_config = LCMSpyConfig + topic = dict[str, Topic] + graph_log_window: float = 1.0 + topic_class: type[Topic] = Topic + + def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] + super().__init__(**kwargs) + Topic.__init__(self, name="total", history_window=self.config.topic_history_window) # type: ignore[attr-defined] + self.topic = {} # type: ignore[assignment] + + def start(self) -> None: + super().start() + self.l.subscribe(".*", self.msg) # type: ignore[union-attr] + + def stop(self) -> None: + """Stop the LCM spy and clean up resources""" + super().stop() + + def msg(self, topic, data) -> None: # type: ignore[no-untyped-def, override] + Topic.msg(self, data) + + if topic not in self.topic: # type: ignore[operator] + print(self.config) + self.topic[topic] = self.topic_class( # type: ignore[assignment, call-arg] + topic, + history_window=self.config.topic_history_window, # type: ignore[attr-defined] + ) + self.topic[topic].msg(data) # type: ignore[attr-defined, type-arg] + + +class GraphTopic(Topic): + def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + super().__init__(*args, **kwargs) + self.freq_history = deque(maxlen=20) # type: ignore[var-annotated] + self.bandwidth_history = deque(maxlen=20) # type: ignore[var-annotated] + + def update_graphs(self, step_window: float = 1.0) -> None: + """Update historical data for graphing""" + freq = self.freq(step_window) + kbps = self.kbps(step_window) + self.freq_history.append(freq) + self.bandwidth_history.append(kbps) + + +@dataclass +class GraphLCMSpyConfig(LCMSpyConfig): + graph_log_window: float = 1.0 + + +class GraphLCMSpy(LCMSpy, GraphTopic): + default_config = GraphLCMSpyConfig + + graph_log_thread: threading.Thread | None = None + graph_log_stop_event: threading.Event = threading.Event() + topic_class: type[Topic] = GraphTopic + + def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] + super().__init__(**kwargs) + GraphTopic.__init__(self, name="total", history_window=self.config.topic_history_window) # type: ignore[attr-defined] + + def start(self) -> None: + super().start() + self.graph_log_thread = threading.Thread(target=self.graph_log, daemon=True) + self.graph_log_thread.start() + + def graph_log(self) -> None: + while not self.graph_log_stop_event.is_set(): + self.update_graphs(self.config.graph_log_window) # type: ignore[attr-defined] # Update global history + for topic in self.topic.values(): # type: ignore[call-arg] + topic.update_graphs(self.config.graph_log_window) # type: ignore[attr-defined] + time.sleep(self.config.graph_log_window) # type: ignore[attr-defined] + + def stop(self) -> None: + """Stop the graph logging and LCM spy""" + self.graph_log_stop_event.set() + if self.graph_log_thread and self.graph_log_thread.is_alive(): + self.graph_log_thread.join(timeout=1.0) + super().stop() + + +if __name__ == "__main__": + lcm_spy = LCMSpy() + lcm_spy.start() + try: + while True: + time.sleep(1) + except KeyboardInterrupt: + print("LCM Spy stopped.") diff --git a/dimos/utils/cli/lcmspy/run_lcmspy.py b/dimos/utils/cli/lcmspy/run_lcmspy.py new file mode 100644 index 0000000000..f3d31b48ba --- /dev/null +++ b/dimos/utils/cli/lcmspy/run_lcmspy.py @@ -0,0 +1,135 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from rich.text import Text +from textual.app import App, ComposeResult +from textual.color import Color +from textual.widgets import DataTable + +from dimos.utils.cli import theme +from dimos.utils.cli.lcmspy.lcmspy import GraphLCMSpy, GraphTopic as SpyTopic + + +def gradient(max_value: float, value: float) -> str: + """Gradient from cyan (low) to yellow (high) using DimOS theme colors""" + ratio = min(value / max_value, 1.0) + # Parse hex colors from theme + cyan = Color.parse(theme.CYAN) + yellow = Color.parse(theme.YELLOW) + color = cyan.blend(yellow, ratio) + + return color.hex + + +def topic_text(topic_name: str) -> Text: + """Format topic name with DimOS theme colors""" + if "#" in topic_name: + parts = topic_name.split("#", 1) + return Text(parts[0], style=theme.BRIGHT_WHITE) + Text("#" + parts[1], style=theme.BLUE) + + if topic_name[:4] == "/rpc": + return Text(topic_name[:4], style=theme.BLUE) + Text( + topic_name[4:], style=theme.BRIGHT_WHITE + ) + + return Text(topic_name, style=theme.BRIGHT_WHITE) + + +class LCMSpyApp(App): # type: ignore[type-arg] + """A real-time CLI dashboard for LCM traffic statistics using Textual.""" + + CSS_PATH = "../dimos.tcss" + + CSS = f""" + Screen {{ + layout: vertical; + background: {theme.BACKGROUND}; + }} + DataTable {{ + height: 2fr; + width: 1fr; + border: solid {theme.BORDER}; + background: {theme.BG}; + scrollbar-size: 0 0; + }} + DataTable > .datatable--header {{ + color: {theme.ACCENT}; + background: transparent; + }} + """ + + refresh_interval: float = 0.5 # seconds + + BINDINGS = [ + ("q", "quit"), + ("ctrl+c", "quit"), + ] + + def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + super().__init__(*args, **kwargs) + self.spy = GraphLCMSpy(autoconf=True, graph_log_window=0.5) + self.table: DataTable | None = None # type: ignore[type-arg] + + def compose(self) -> ComposeResult: + self.table = DataTable(zebra_stripes=False, cursor_type=None) # type: ignore[arg-type] + self.table.add_column("Topic") + self.table.add_column("Freq (Hz)") + self.table.add_column("Bandwidth") + self.table.add_column("Total Traffic") + yield self.table + + def on_mount(self) -> None: + self.spy.start() + self.set_interval(self.refresh_interval, self.refresh_table) + + async def on_unmount(self) -> None: + self.spy.stop() + + def refresh_table(self) -> None: + topics: list[SpyTopic] = list(self.spy.topic.values()) # type: ignore[arg-type, call-arg] + topics.sort(key=lambda t: t.total_traffic(), reverse=True) + self.table.clear(columns=False) # type: ignore[union-attr] + + for t in topics: + freq = t.freq(5.0) + kbps = t.kbps(5.0) + bw_val, bw_unit = t.kbps_hr(5.0) + total_val, total_unit = t.total_traffic_hr() + + self.table.add_row( # type: ignore[union-attr] + topic_text(t.name), + Text(f"{freq:.1f}", style=gradient(10, freq)), + Text(f"{bw_val} {bw_unit.value}/s", style=gradient(1024 * 3, kbps)), + Text(f"{total_val} {total_unit.value}"), + ) + + +def main() -> None: + import sys + + if len(sys.argv) > 1 and sys.argv[1] == "web": + import os + + from textual_serve.server import Server # type: ignore[import-not-found] + + server = Server(f"python {os.path.abspath(__file__)}") + server.serve() + else: + LCMSpyApp().run() + + +if __name__ == "__main__": + main() diff --git a/dimos/utils/cli/lcmspy/test_lcmspy.py b/dimos/utils/cli/lcmspy/test_lcmspy.py new file mode 100644 index 0000000000..3016a723fe --- /dev/null +++ b/dimos/utils/cli/lcmspy/test_lcmspy.py @@ -0,0 +1,221 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import pytest + +from dimos.protocol.pubsub.lcmpubsub import PickleLCM, Topic +from dimos.utils.cli.lcmspy.lcmspy import GraphLCMSpy, GraphTopic, LCMSpy, Topic as TopicSpy + + +@pytest.mark.lcm +def test_spy_basic() -> None: + lcm = PickleLCM(autoconf=True) + lcm.start() + + lcmspy = LCMSpy(autoconf=True) + lcmspy.start() + + video_topic = Topic(topic="/video") + odom_topic = Topic(topic="/odom") + + for i in range(5): + lcm.publish(video_topic, f"video frame {i}") + time.sleep(0.1) + if i % 2 == 0: + lcm.publish(odom_topic, f"odometry data {i / 2}") + + # Wait a bit for messages to be processed + time.sleep(0.5) + + # Test statistics for video topic + video_topic_spy = lcmspy.topic["/video"] + assert video_topic_spy is not None + + # Test frequency (should be around 10 Hz for 5 messages in ~0.5 seconds) + freq = video_topic_spy.freq(1.0) + assert freq > 0 + print(f"Video topic frequency: {freq:.2f} Hz") + + # Test bandwidth + kbps = video_topic_spy.kbps(1.0) + assert kbps > 0 + print(f"Video topic bandwidth: {kbps:.2f} kbps") + + # Test average message size + avg_size = video_topic_spy.size(1.0) + assert avg_size > 0 + print(f"Video topic average message size: {avg_size:.2f} bytes") + + # Test statistics for odom topic + odom_topic_spy = lcmspy.topic["/odom"] + assert odom_topic_spy is not None + + freq = odom_topic_spy.freq(1.0) + assert freq > 0 + print(f"Odom topic frequency: {freq:.2f} Hz") + + kbps = odom_topic_spy.kbps(1.0) + assert kbps > 0 + print(f"Odom topic bandwidth: {kbps:.2f} kbps") + + avg_size = odom_topic_spy.size(1.0) + assert avg_size > 0 + print(f"Odom topic average message size: {avg_size:.2f} bytes") + + print(f"Video topic: {video_topic_spy}") + print(f"Odom topic: {odom_topic_spy}") + + +@pytest.mark.lcm +def test_topic_statistics_direct() -> None: + """Test Topic statistics directly without LCM""" + + topic = TopicSpy("/test") + + # Add some test messages + test_data = [b"small", b"medium sized message", b"very long message for testing purposes"] + + for _i, data in enumerate(test_data): + topic.msg(data) + time.sleep(0.1) # Simulate time passing + + # Test statistics over 1 second window + freq = topic.freq(1.0) + kbps = topic.kbps(1.0) + avg_size = topic.size(1.0) + + assert freq > 0 + assert kbps > 0 + assert avg_size > 0 + + print(f"Direct test - Frequency: {freq:.2f} Hz") + print(f"Direct test - Bandwidth: {kbps:.2f} kbps") + print(f"Direct test - Avg size: {avg_size:.2f} bytes") + + +def test_topic_cleanup() -> None: + """Test that old messages are properly cleaned up""" + + topic = TopicSpy("/test") + + # Add a message + topic.msg(b"test message") + initial_count = len(topic.message_history) + assert initial_count == 1 + + # Simulate time passing by manually adding old timestamps + old_time = time.time() - 70 # 70 seconds ago + topic.message_history.appendleft((old_time, 10)) + + # Trigger cleanup + topic._cleanup_old_messages(max_age=60.0) + + # Should only have the recent message + assert len(topic.message_history) == 1 + assert topic.message_history[0][0] > time.time() - 10 # Recent message + + +@pytest.mark.lcm +def test_graph_topic_basic() -> None: + """Test GraphTopic basic functionality""" + topic = GraphTopic("/test_graph") + + # Add some messages and update graphs + topic.msg(b"test message") + topic.update_graphs(1.0) + + # Should have history data + assert len(topic.freq_history) == 1 + assert len(topic.bandwidth_history) == 1 + assert topic.freq_history[0] > 0 + assert topic.bandwidth_history[0] > 0 + + +@pytest.mark.lcm +def test_graph_lcmspy_basic() -> None: + """Test GraphLCMSpy basic functionality""" + spy = GraphLCMSpy(autoconf=True, graph_log_window=0.1) + spy.start() + time.sleep(0.2) # Wait for thread to start + + # Simulate a message + spy.msg("/test", b"test data") + time.sleep(0.2) # Wait for graph update + + # Should create GraphTopic with history + topic = spy.topic["/test"] + assert isinstance(topic, GraphTopic) + assert len(topic.freq_history) > 0 + assert len(topic.bandwidth_history) > 0 + + spy.stop() + + +@pytest.mark.lcm +def test_lcmspy_global_totals() -> None: + """Test that LCMSpy tracks global totals as a Topic itself""" + spy = LCMSpy(autoconf=True) + spy.start() + + # Send messages to different topics + spy.msg("/video", b"video frame data") + spy.msg("/odom", b"odometry data") + spy.msg("/imu", b"imu data") + + # The spy itself should have accumulated all messages + assert len(spy.message_history) == 3 + + # Check global statistics + global_freq = spy.freq(1.0) + global_kbps = spy.kbps(1.0) + global_size = spy.size(1.0) + + assert global_freq > 0 + assert global_kbps > 0 + assert global_size > 0 + + print(f"Global frequency: {global_freq:.2f} Hz") + print(f"Global bandwidth: {spy.kbps_hr(1.0)}") + print(f"Global avg message size: {global_size:.0f} bytes") + + spy.stop() + + +@pytest.mark.lcm +def test_graph_lcmspy_global_totals() -> None: + """Test that GraphLCMSpy tracks global totals with history""" + spy = GraphLCMSpy(autoconf=True, graph_log_window=0.1) + spy.start() + time.sleep(0.2) + + # Send messages + spy.msg("/video", b"video frame data") + spy.msg("/odom", b"odometry data") + time.sleep(0.2) # Wait for graph update + + # Update global graphs + spy.update_graphs(1.0) + + # Should have global history + assert len(spy.freq_history) == 1 + assert len(spy.bandwidth_history) == 1 + assert spy.freq_history[0] > 0 + assert spy.bandwidth_history[0] > 0 + + print(f"Global frequency history: {spy.freq_history[0]:.2f} Hz") + print(f"Global bandwidth history: {spy.bandwidth_history[0]:.2f} kB/s") + + spy.stop() diff --git a/dimos/utils/cli/skillspy/demo_skillspy.py b/dimos/utils/cli/skillspy/demo_skillspy.py new file mode 100644 index 0000000000..602381020a --- /dev/null +++ b/dimos/utils/cli/skillspy/demo_skillspy.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Demo script that runs skills in the background while agentspy monitors them.""" + +import threading +import time + +from dimos.protocol.skill.coordinator import SkillCoordinator +from dimos.protocol.skill.skill import SkillContainer, skill + + +class DemoSkills(SkillContainer): + @skill() + def count_to(self, n: int) -> str: + """Count to n with delays.""" + for _i in range(n): + time.sleep(0.5) + return f"Counted to {n}" + + @skill() + def compute_fibonacci(self, n: int) -> int: + """Compute nth fibonacci number.""" + if n <= 1: + return n + a, b = 0, 1 + for _ in range(2, n + 1): + time.sleep(0.1) # Simulate computation + a, b = b, a + b + return b + + @skill() + def simulate_error(self) -> None: + """Skill that always errors.""" + time.sleep(0.3) + raise RuntimeError("Simulated error for testing") + + @skill() + def quick_task(self, name: str) -> str: + """Quick task that completes fast.""" + time.sleep(0.1) + return f"Quick task '{name}' done!" + + +def run_demo_skills() -> None: + """Run demo skills in background.""" + # Create and start agent interface + agent_interface = SkillCoordinator() + agent_interface.start() + + # Register skills + demo_skills = DemoSkills() + agent_interface.register_skills(demo_skills) + + # Run various skills periodically + def skill_runner() -> None: + counter = 0 + while True: + time.sleep(2) + + # Generate unique call_id for each invocation + call_id = f"demo-{counter}" + + # Run different skills based on counter + if counter % 4 == 0: + # Run multiple count_to in parallel to show parallel execution + agent_interface.call_skill(f"{call_id}-count-1", "count_to", {"args": [3]}) + agent_interface.call_skill(f"{call_id}-count-2", "count_to", {"args": [5]}) + agent_interface.call_skill(f"{call_id}-count-3", "count_to", {"args": [2]}) + elif counter % 4 == 1: + agent_interface.call_skill(f"{call_id}-fib", "compute_fibonacci", {"args": [10]}) + elif counter % 4 == 2: + agent_interface.call_skill( + f"{call_id}-quick", "quick_task", {"args": [f"task-{counter}"]} + ) + else: + agent_interface.call_skill(f"{call_id}-error", "simulate_error", {}) + + counter += 1 + + # Start skill runner in background + thread = threading.Thread(target=skill_runner, daemon=True) + thread.start() + + print("Demo skills running in background. Start agentspy in another terminal to monitor.") + print("Run: agentspy") + + # Keep running + try: + while True: + time.sleep(1) + except KeyboardInterrupt: + print("\nDemo stopped.") + + agent_interface.stop() + + +if __name__ == "__main__": + run_demo_skills() diff --git a/dimos/utils/cli/skillspy/skillspy.py b/dimos/utils/cli/skillspy/skillspy.py new file mode 100644 index 0000000000..beb2421eec --- /dev/null +++ b/dimos/utils/cli/skillspy/skillspy.py @@ -0,0 +1,281 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import threading +import time +from typing import TYPE_CHECKING + +from rich.text import Text +from textual.app import App, ComposeResult +from textual.binding import Binding +from textual.widgets import DataTable, Footer + +from dimos.protocol.skill.coordinator import SkillCoordinator, SkillState, SkillStateEnum +from dimos.utils.cli import theme + +if TYPE_CHECKING: + from collections.abc import Callable + + from dimos.protocol.skill.comms import SkillMsg # type: ignore[attr-defined] + + +class AgentSpy: + """Spy on agent skill executions via LCM messages.""" + + def __init__(self) -> None: + self.agent_interface = SkillCoordinator() + self.message_callbacks: list[Callable[[dict[str, SkillState]], None]] = [] + self._lock = threading.Lock() + self._latest_state: dict[str, SkillState] = {} + self._running = False + + def start(self) -> None: + """Start spying on agent messages.""" + self._running = True + # Start the agent interface + self.agent_interface.start() + + # Subscribe to the agent interface's comms + self.agent_interface.skill_transport.subscribe(self._handle_message) + + def stop(self) -> None: + """Stop spying.""" + self._running = False + # Give threads a moment to finish processing + time.sleep(0.2) + self.agent_interface.stop() + + def _handle_message(self, msg: SkillMsg) -> None: # type: ignore[type-arg] + """Handle incoming skill messages.""" + if not self._running: + return + + # Small delay to ensure agent_interface has processed the message + def delayed_update() -> None: + time.sleep(0.1) + if not self._running: + return + with self._lock: + self._latest_state = self.agent_interface.generate_snapshot(clear=False) + for callback in self.message_callbacks: + callback(self._latest_state) + + # Run in separate thread to not block LCM + threading.Thread(target=delayed_update, daemon=True).start() + + def subscribe(self, callback: Callable[[dict[str, SkillState]], None]) -> None: + """Subscribe to state updates.""" + self.message_callbacks.append(callback) + + def get_state(self) -> dict[str, SkillState]: + """Get current state snapshot.""" + with self._lock: + return self._latest_state.copy() + + +def state_color(state: SkillStateEnum) -> str: + """Get color for skill state.""" + if state == SkillStateEnum.pending: + return theme.WARNING + elif state == SkillStateEnum.running: + return theme.AGENT + elif state == SkillStateEnum.completed: + return theme.SUCCESS + elif state == SkillStateEnum.error: + return theme.ERROR + return theme.FOREGROUND + + +def format_duration(duration: float) -> str: + """Format duration in human readable format.""" + if duration < 1: + return f"{duration * 1000:.0f}ms" + elif duration < 60: + return f"{duration:.1f}s" + elif duration < 3600: + return f"{duration / 60:.1f}m" + else: + return f"{duration / 3600:.1f}h" + + +class AgentSpyApp(App): # type: ignore[type-arg] + """A real-time CLI dashboard for agent skill monitoring using Textual.""" + + CSS_PATH = theme.CSS_PATH + + CSS = f""" + Screen {{ + layout: vertical; + background: {theme.BACKGROUND}; + }} + DataTable {{ + height: 100%; + border: solid $border; + background: {theme.BACKGROUND}; + }} + DataTable > .datatable--header {{ + background: transparent; + }} + Footer {{ + background: transparent; + }} + """ + + BINDINGS = [ + Binding("q", "quit", "Quit"), + Binding("c", "clear", "Clear History"), + Binding("ctrl+c", "quit", "Quit", show=False), + ] + + def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + super().__init__(*args, **kwargs) + self.spy = AgentSpy() + self.table: DataTable | None = None # type: ignore[type-arg] + self.skill_history: list[tuple[str, SkillState, float]] = [] # (call_id, state, start_time) + + def compose(self) -> ComposeResult: + self.table = DataTable(zebra_stripes=False, cursor_type=None) # type: ignore[arg-type] + self.table.add_column("Call ID") + self.table.add_column("Skill Name") + self.table.add_column("State") + self.table.add_column("Duration") + self.table.add_column("Messages") + self.table.add_column("Details") + + yield self.table + yield Footer() + + def on_mount(self) -> None: + """Start the spy when app mounts.""" + self.spy.subscribe(self.update_state) + self.spy.start() + + # Set up periodic refresh to update durations + self.set_interval(1.0, self.refresh_table) + + def on_unmount(self) -> None: + """Stop the spy when app unmounts.""" + self.spy.stop() + + def update_state(self, state: dict[str, SkillState]) -> None: + """Update state from spy callback. State dict is keyed by call_id.""" + # Update history with current state + current_time = time.time() + + # Add new skills or update existing ones + for call_id, skill_state in state.items(): + # Find if this call_id already in history + found = False + for i, (existing_call_id, _old_state, start_time) in enumerate(self.skill_history): + if existing_call_id == call_id: + # Update existing entry + self.skill_history[i] = (call_id, skill_state, start_time) + found = True + break + + if not found: + # Add new entry with current time as start + start_time = current_time + if skill_state.start_msg: + # Use start message timestamp if available + start_time = skill_state.start_msg.ts + self.skill_history.append((call_id, skill_state, start_time)) + + # Schedule UI update + self.call_from_thread(self.refresh_table) + + def refresh_table(self) -> None: + """Refresh the table display.""" + if not self.table: + return + + # Clear table + self.table.clear(columns=False) + + # Sort by start time (newest first) + sorted_history = sorted(self.skill_history, key=lambda x: x[2], reverse=True) + + # Get terminal height and calculate how many rows we can show + height = self.size.height - 6 # Account for header, footer, column headers + max_rows = max(1, height) + + # Show only top N entries + for call_id, skill_state, start_time in sorted_history[:max_rows]: + # Calculate how long ago it started (for progress indicator) + time_ago = time.time() - start_time + + # Duration + duration_str = format_duration(skill_state.duration()) + + # Message count + msg_count = len(skill_state) + + # Details based on state and last message + details = "" + if skill_state.state == SkillStateEnum.error and skill_state.error_msg: + # Show error message + error_content = skill_state.error_msg.content + if isinstance(error_content, dict): + details = error_content.get("msg", "Error")[:40] + else: + details = str(error_content)[:40] + elif skill_state.state == SkillStateEnum.completed and skill_state.ret_msg: + # Show return value + details = f"→ {str(skill_state.ret_msg.content)[:37]}" + elif skill_state.state == SkillStateEnum.running: + # Show progress indicator + details = "⋯ " + "▸" * min(int(time_ago), 20) + + # Format call_id for display (truncate if too long) + display_call_id = call_id + if len(call_id) > 16: + display_call_id = call_id[:13] + "..." + + # Add row with colored state + self.table.add_row( + Text(display_call_id, style=theme.BRIGHT_BLUE), + Text(skill_state.name, style=theme.YELLOW), + Text(skill_state.state.name, style=state_color(skill_state.state)), + Text(duration_str, style=theme.WHITE), + Text(str(msg_count), style=theme.YELLOW), + Text(details, style=theme.FOREGROUND), + ) + + def action_clear(self) -> None: + """Clear the skill history.""" + self.skill_history.clear() + self.refresh_table() + + +def main() -> None: + """Main entry point for agentspy CLI.""" + import sys + + # Check if running in web mode + if len(sys.argv) > 1 and sys.argv[1] == "web": + import os + + from textual_serve.server import Server # type: ignore[import-not-found] + + server = Server(f"python {os.path.abspath(__file__)}") + server.serve() + else: + app = AgentSpyApp() + app.run() + + +if __name__ == "__main__": + main() diff --git a/dimos/utils/cli/theme.py b/dimos/utils/cli/theme.py new file mode 100644 index 0000000000..b6b6b9ccae --- /dev/null +++ b/dimos/utils/cli/theme.py @@ -0,0 +1,108 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Parse DimOS theme from tcss file.""" + +from __future__ import annotations + +from pathlib import Path +import re + + +def parse_tcss_colors(tcss_path: str | Path) -> dict[str, str]: + """Parse color variables from a tcss file. + + Args: + tcss_path: Path to the tcss file + + Returns: + Dictionary mapping variable names to color values + """ + tcss_path = Path(tcss_path) + content = tcss_path.read_text() + + # Match $variable: value; patterns + pattern = r"\$([a-zA-Z0-9_-]+)\s*:\s*(#[0-9a-fA-F]{6}|#[0-9a-fA-F]{3});" + matches = re.findall(pattern, content) + + return {name: value for name, value in matches} + + +# Load DimOS theme colors +_THEME_PATH = Path(__file__).parent / "dimos.tcss" +COLORS = parse_tcss_colors(_THEME_PATH) + +# Export CSS path for Textual apps +CSS_PATH = str(_THEME_PATH) + + +# Convenience accessors for common colors +def get(name: str, default: str = "#ffffff") -> str: + """Get a color by variable name.""" + return COLORS.get(name, default) + + +# Base color palette +BLACK = COLORS.get("black", "#0b0f0f") +RED = COLORS.get("red", "#ff0000") +GREEN = COLORS.get("green", "#00eeee") +YELLOW = COLORS.get("yellow", "#ffcc00") +BLUE = COLORS.get("blue", "#5c9ff0") +PURPLE = COLORS.get("purple", "#00eeee") +CYAN = COLORS.get("cyan", "#00eeee") +WHITE = COLORS.get("white", "#b5e4f4") + +# Bright colors +BRIGHT_BLACK = COLORS.get("bright-black", "#404040") +BRIGHT_RED = COLORS.get("bright-red", "#ff0000") +BRIGHT_GREEN = COLORS.get("bright-green", "#00eeee") +BRIGHT_YELLOW = COLORS.get("bright-yellow", "#f2ea8c") +BRIGHT_BLUE = COLORS.get("bright-blue", "#8cbdf2") +BRIGHT_PURPLE = COLORS.get("bright-purple", "#00eeee") +BRIGHT_CYAN = COLORS.get("bright-cyan", "#00eeee") +BRIGHT_WHITE = COLORS.get("bright-white", "#ffffff") + +# Core theme colors +BACKGROUND = COLORS.get("background", "#0b0f0f") +FOREGROUND = COLORS.get("foreground", "#b5e4f4") +CURSOR = COLORS.get("cursor", "#00eeee") + +# Semantic aliases +BG = COLORS.get("bg", "#0b0f0f") +BORDER = COLORS.get("border", "#00eeee") +ACCENT = COLORS.get("accent", "#b5e4f4") +DIM = COLORS.get("dim", "#404040") +TIMESTAMP = COLORS.get("timestamp", "#ffffff") + +# Message type colors +SYSTEM = COLORS.get("system", "#ff0000") +AGENT = COLORS.get("agent", "#88ff88") +TOOL = COLORS.get("tool", "#00eeee") +TOOL_RESULT = COLORS.get("tool-result", "#ffff00") +HUMAN = COLORS.get("human", "#ffffff") + +# Status colors +SUCCESS = COLORS.get("success", "#00eeee") +ERROR = COLORS.get("error", "#ff0000") +WARNING = COLORS.get("warning", "#ffcc00") +INFO = COLORS.get("info", "#00eeee") + +ascii_logo = """ + ▇▇▇▇▇▇╗ ▇▇╗▇▇▇╗ ▇▇▇╗▇▇▇▇▇▇▇╗▇▇▇╗ ▇▇╗▇▇▇▇▇▇▇╗▇▇╗ ▇▇▇▇▇▇╗ ▇▇▇╗ ▇▇╗ ▇▇▇▇▇╗ ▇▇╗ + ▇▇╔══▇▇╗▇▇║▇▇▇▇╗ ▇▇▇▇║▇▇╔════╝▇▇▇▇╗ ▇▇║▇▇╔════╝▇▇║▇▇╔═══▇▇╗▇▇▇▇╗ ▇▇║▇▇╔══▇▇╗▇▇║ + ▇▇║ ▇▇║▇▇║▇▇╔▇▇▇▇╔▇▇║▇▇▇▇▇╗ ▇▇╔▇▇╗ ▇▇║▇▇▇▇▇▇▇╗▇▇║▇▇║ ▇▇║▇▇╔▇▇╗ ▇▇║▇▇▇▇▇▇▇║▇▇║ + ▇▇║ ▇▇║▇▇║▇▇║╚▇▇╔╝▇▇║▇▇╔══╝ ▇▇║╚▇▇╗▇▇║╚════▇▇║▇▇║▇▇║ ▇▇║▇▇║╚▇▇╗▇▇║▇▇╔══▇▇║▇▇║ + ▇▇▇▇▇▇╔╝▇▇║▇▇║ ╚═╝ ▇▇║▇▇▇▇▇▇▇╗▇▇║ ╚▇▇▇▇║▇▇▇▇▇▇▇║▇▇║╚▇▇▇▇▇▇╔╝▇▇║ ╚▇▇▇▇║▇▇║ ▇▇║▇▇▇▇▇▇▇╗ + ╚═════╝ ╚═╝╚═╝ ╚═╝╚══════╝╚═╝ ╚═══╝╚══════╝╚═╝ ╚═════╝ ╚═╝ ╚═══╝╚═╝ ╚═╝╚══════╝ +""" diff --git a/dimos/utils/data.py b/dimos/utils/data.py new file mode 100644 index 0000000000..e424e0af91 --- /dev/null +++ b/dimos/utils/data.py @@ -0,0 +1,159 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import cache +from pathlib import Path +import subprocess +import tarfile + + +@cache +def _get_repo_root() -> Path: + try: + result = subprocess.run( + ["git", "rev-parse", "--show-toplevel"], capture_output=True, check=True, text=True + ) + return Path(result.stdout.strip()) + except subprocess.CalledProcessError: + raise RuntimeError("Not in a Git repository") + + +@cache +def _get_data_dir(extra_path: str | None = None) -> Path: + if extra_path: + return _get_repo_root() / "data" / extra_path + return _get_repo_root() / "data" + + +@cache +def _get_lfs_dir() -> Path: + return _get_data_dir() / ".lfs" + + +def _check_git_lfs_available() -> bool: + try: + subprocess.run(["git", "lfs", "version"], capture_output=True, check=True, text=True) + except (subprocess.CalledProcessError, FileNotFoundError): + raise RuntimeError( + "Git LFS is not installed. Please install git-lfs to use test data utilities.\n" + "Installation instructions: https://git-lfs.github.io/" + ) + return True + + +def _is_lfs_pointer_file(file_path: Path) -> bool: + try: + # LFS pointer files are small (typically < 200 bytes) and start with specific text + if file_path.stat().st_size > 1024: # LFS pointers are much smaller + return False + + with open(file_path, encoding="utf-8") as f: + first_line = f.readline().strip() + return first_line.startswith("version https://git-lfs.github.com/spec/") + + except (UnicodeDecodeError, OSError): + return False + + +def _lfs_pull(file_path: Path, repo_root: Path) -> None: + try: + relative_path = file_path.relative_to(repo_root) + + subprocess.run( + ["git", "lfs", "pull", "--include", str(relative_path)], + cwd=repo_root, + check=True, + capture_output=True, + ) + except subprocess.CalledProcessError as e: + raise RuntimeError(f"Failed to pull LFS file {file_path}: {e}") + + return None + + +def _decompress_archive(filename: str | Path) -> Path: + target_dir = _get_data_dir() + filename_path = Path(filename) + with tarfile.open(filename_path, "r:gz") as tar: + tar.extractall(target_dir) + return target_dir / filename_path.name.replace(".tar.gz", "") + + +def _pull_lfs_archive(filename: str | Path) -> Path: + # Check Git LFS availability first + _check_git_lfs_available() + + # Find repository root + repo_root = _get_repo_root() + + # Construct path to test data file + file_path = _get_lfs_dir() / (str(filename) + ".tar.gz") + + # Check if file exists + if not file_path.exists(): + raise FileNotFoundError( + f"Test file '{filename}' not found at {file_path}. " + f"Make sure the file is committed to Git LFS in the tests/data directory." + ) + + # If it's an LFS pointer file, ensure LFS is set up and pull the file + if _is_lfs_pointer_file(file_path): + _lfs_pull(file_path, repo_root) + + # Verify the file was actually downloaded + if _is_lfs_pointer_file(file_path): + raise RuntimeError( + f"Failed to download LFS file '{filename}'. The file is still a pointer after attempting to pull." + ) + + return file_path + + +def get_data(filename: str | Path) -> Path: + """ + Get the path to a test data, downloading from LFS if needed. + + This function will: + 1. Check that Git LFS is available + 2. Locate the file in the tests/data directory + 3. Initialize Git LFS if needed + 4. Download the file from LFS if it's a pointer file + 5. Return the Path object to the actual file or dir + + Args: + filename: Name of the test file (e.g., "lidar_sample.bin") + + Returns: + Path: Path object to the test file + + Raises: + RuntimeError: If Git LFS is not available or LFS operations fail + FileNotFoundError: If the test file doesn't exist + + Usage: + # As string path + file_path = str(testFile("sample.bin")) + + # As context manager for file operations + with testFile("sample.bin").open('rb') as f: + data = f.read() + """ + data_dir = _get_data_dir() + file_path = data_dir / filename + + # already pulled and decompressed, return it directly + if file_path.exists(): + return file_path + + return _decompress_archive(_pull_lfs_archive(filename)) diff --git a/dimos/utils/decorators/__init__.py b/dimos/utils/decorators/__init__.py new file mode 100644 index 0000000000..ee17260c20 --- /dev/null +++ b/dimos/utils/decorators/__init__.py @@ -0,0 +1,12 @@ +"""Decorators and accumulators for rate limiting and other utilities.""" + +from .accumulators import Accumulator, LatestAccumulator, RollingAverageAccumulator +from .decorators import limit, retry + +__all__ = [ + "Accumulator", + "LatestAccumulator", + "RollingAverageAccumulator", + "limit", + "retry", +] diff --git a/dimos/utils/decorators/accumulators.py b/dimos/utils/decorators/accumulators.py new file mode 100644 index 0000000000..75cb25661d --- /dev/null +++ b/dimos/utils/decorators/accumulators.py @@ -0,0 +1,106 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +import threading +from typing import Generic, TypeVar + +T = TypeVar("T") + + +class Accumulator(ABC, Generic[T]): + """Base class for accumulating messages between rate-limited calls.""" + + @abstractmethod + def add(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + """Add args and kwargs to the accumulator.""" + pass + + @abstractmethod + def get(self) -> tuple[tuple, dict] | None: # type: ignore[type-arg] + """Get the accumulated args and kwargs and reset the accumulator.""" + pass + + @abstractmethod + def __len__(self) -> int: + """Return the number of accumulated items.""" + pass + + +class LatestAccumulator(Accumulator[T]): + """Simple accumulator that remembers only the latest args and kwargs.""" + + def __init__(self) -> None: + self._latest: tuple[tuple, dict] | None = None # type: ignore[type-arg] + self._lock = threading.Lock() + + def add(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + with self._lock: + self._latest = (args, kwargs) + + def get(self) -> tuple[tuple, dict] | None: # type: ignore[type-arg] + with self._lock: + result = self._latest + self._latest = None + return result + + def __len__(self) -> int: + with self._lock: + return 1 if self._latest is not None else 0 + + +class RollingAverageAccumulator(Accumulator[T]): + """Accumulator that maintains a rolling average of the first argument. + + This accumulator expects the first argument to be numeric and maintains + a rolling average without storing individual values. + """ + + def __init__(self) -> None: + self._sum: float = 0.0 + self._count: int = 0 + self._latest_kwargs: dict = {} # type: ignore[type-arg] + self._lock = threading.Lock() + + def add(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + if not args: + raise ValueError("RollingAverageAccumulator requires at least one argument") + + with self._lock: + try: + value = float(args[0]) + self._sum += value + self._count += 1 + self._latest_kwargs = kwargs + except (TypeError, ValueError): + raise TypeError(f"First argument must be numeric, got {type(args[0])}") + + def get(self) -> tuple[tuple, dict] | None: # type: ignore[type-arg] + with self._lock: + if self._count == 0: + return None + + average = self._sum / self._count + result = ((average,), self._latest_kwargs) + + # Reset accumulator + self._sum = 0.0 + self._count = 0 + self._latest_kwargs = {} + + return result + + def __len__(self) -> int: + with self._lock: + return self._count diff --git a/dimos/utils/decorators/decorators.py b/dimos/utils/decorators/decorators.py new file mode 100644 index 0000000000..6c31979d16 --- /dev/null +++ b/dimos/utils/decorators/decorators.py @@ -0,0 +1,201 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable +from functools import wraps +import threading +import time + +from .accumulators import Accumulator, LatestAccumulator + + +def limit(max_freq: float, accumulator: Accumulator | None = None): # type: ignore[no-untyped-def, type-arg] + """ + Decorator that limits function call frequency. + + If calls come faster than max_freq, they are skipped. + If calls come slower than max_freq, they pass through immediately. + + Args: + max_freq: Maximum frequency in Hz (calls per second) + accumulator: Optional accumulator to collect skipped calls (defaults to LatestAccumulator) + + Returns: + Decorated function that respects the frequency limit + """ + if max_freq <= 0: + raise ValueError("Frequency must be positive") + + min_interval = 1.0 / max_freq + + # Create default accumulator if none provided + if accumulator is None: + accumulator = LatestAccumulator() + + def decorator(func: Callable) -> Callable: # type: ignore[type-arg] + last_call_time = 0.0 + lock = threading.Lock() + timer: threading.Timer | None = None + + def execute_accumulated() -> None: + nonlocal last_call_time, timer + with lock: + if len(accumulator): + acc_args, acc_kwargs = accumulator.get() # type: ignore[misc] + last_call_time = time.time() + timer = None + func(*acc_args, **acc_kwargs) + + @wraps(func) + def wrapper(*args, **kwargs): # type: ignore[no-untyped-def] + nonlocal last_call_time, timer + current_time = time.time() + + with lock: + time_since_last = current_time - last_call_time + + if time_since_last >= min_interval: + # Cancel any pending timer + if timer is not None: + timer.cancel() + timer = None + + # Enough time has passed, execute the function + last_call_time = current_time + + # if we have accumulated data, we get a compound value + if len(accumulator): + accumulator.add(*args, **kwargs) + acc_args, acc_kwargs = accumulator.get() # type: ignore[misc] # accumulator resets here + return func(*acc_args, **acc_kwargs) + + # No accumulated data, normal call + return func(*args, **kwargs) + + else: + # Too soon, skip this call + accumulator.add(*args, **kwargs) + + # Schedule execution for when the interval expires + if timer is not None: + timer.cancel() + + time_to_wait = min_interval - time_since_last + timer = threading.Timer(time_to_wait, execute_accumulated) + timer.start() + + return None + + return wrapper + + return decorator + + +def simple_mcache(method: Callable) -> Callable: # type: ignore[type-arg] + """ + Decorator to cache the result of a method call on the instance. + + The cached value is stored as an attribute on the instance with the name + `_cached_`. Subsequent calls to the method will return the + cached value instead of recomputing it. + + Thread-safe: Uses a lock per instance to ensure the cached value is + computed only once even in multi-threaded environments. + + Args: + method: The method to be decorated. + + Returns: + The decorated method with caching behavior. + """ + + attr_name = f"_cached_{method.__name__}" + lock_name = f"_lock_{method.__name__}" + + @wraps(method) + def getter(self): # type: ignore[no-untyped-def] + # Get or create the lock for this instance + if not hasattr(self, lock_name): + # This is a one-time operation, race condition here is acceptable + # as worst case we create multiple locks but only one gets stored + setattr(self, lock_name, threading.Lock()) + + lock = getattr(self, lock_name) + + if hasattr(self, attr_name): + return getattr(self, attr_name) + + with lock: + # Check again inside the lock + if not hasattr(self, attr_name): + setattr(self, attr_name, method(self)) + return getattr(self, attr_name) + + return getter + + +def retry(max_retries: int = 3, on_exception: type[Exception] = Exception, delay: float = 0.0): # type: ignore[no-untyped-def] + """ + Decorator that retries a function call if it raises an exception. + + Args: + max_retries: Maximum number of retry attempts (default: 3) + on_exception: Exception type to catch and retry on (default: Exception) + delay: Fixed delay in seconds between retries (default: 0.0) + + Returns: + Decorated function that will retry on failure + + Example: + @retry(max_retries=5, on_exception=ConnectionError, delay=0.5) + def connect_to_server(): + # connection logic that might fail + pass + + @retry() # Use defaults: 3 retries on any Exception, no delay + def risky_operation(): + # might fail occasionally + pass + """ + if max_retries < 0: + raise ValueError("max_retries must be non-negative") + if delay < 0: + raise ValueError("delay must be non-negative") + + def decorator(func: Callable) -> Callable: # type: ignore[type-arg] + @wraps(func) + def wrapper(*args, **kwargs): # type: ignore[no-untyped-def] + last_exception = None + + for attempt in range(max_retries + 1): + try: + return func(*args, **kwargs) + except on_exception as e: + last_exception = e + if attempt < max_retries: + # Still have retries left + if delay > 0: + time.sleep(delay) + continue + else: + # Out of retries, re-raise the last exception + raise + + # This should never be reached, but just in case + if last_exception: + raise last_exception + + return wrapper + + return decorator diff --git a/dimos/utils/decorators/test_decorators.py b/dimos/utils/decorators/test_decorators.py new file mode 100644 index 0000000000..b7bc048631 --- /dev/null +++ b/dimos/utils/decorators/test_decorators.py @@ -0,0 +1,262 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import pytest + +from dimos.utils.decorators import RollingAverageAccumulator, limit, retry + + +def test_limit() -> None: + """Test limit decorator with keyword arguments.""" + calls = [] + + @limit(20) # 20 Hz + def process(msg: str, keyword: int = 0) -> str: + calls.append((msg, keyword)) + return f"{msg}:{keyword}" + + # First call goes through + result1 = process("first", keyword=1) + assert result1 == "first:1" + assert calls == [("first", 1)] + + # Quick calls get accumulated + result2 = process("second", keyword=2) + assert result2 is None + + result3 = process("third", keyword=3) + assert result3 is None + + # Wait for interval, expect to be called after it passes + time.sleep(0.6) + + result4 = process("fourth") + assert result4 == "fourth:0" + + assert calls == [("first", 1), ("third", 3), ("fourth", 0)] + + +def test_latest_rolling_average() -> None: + """Test RollingAverageAccumulator with limit decorator.""" + calls = [] + + accumulator = RollingAverageAccumulator() + + @limit(20, accumulator=accumulator) # 20 Hz + def process(value: float, label: str = "") -> str: + calls.append((value, label)) + return f"{value}:{label}" + + # First call goes through + result1 = process(10.0, label="first") + assert result1 == "10.0:first" + assert calls == [(10.0, "first")] + + # Quick calls get accumulated + result2 = process(20.0, label="second") + assert result2 is None + + result3 = process(30.0, label="third") + assert result3 is None + + # Wait for interval + time.sleep(0.6) + + # Should see the average of accumulated values + assert calls == [(10.0, "first"), (25.0, "third")] # (20+30)/2 = 25 + + +def test_retry_success_after_failures() -> None: + """Test that retry decorator retries on failure and eventually succeeds.""" + attempts = [] + + @retry(max_retries=3) + def flaky_function(fail_times: int = 2) -> str: + attempts.append(len(attempts)) + if len(attempts) <= fail_times: + raise ValueError(f"Attempt {len(attempts)} failed") + return "success" + + result = flaky_function() + assert result == "success" + assert len(attempts) == 3 # Failed twice, succeeded on third attempt + + +def test_retry_exhausted() -> None: + """Test that retry decorator raises exception when retries are exhausted.""" + attempts = [] + + @retry(max_retries=2) + def always_fails(): + attempts.append(len(attempts)) + raise RuntimeError(f"Attempt {len(attempts)} failed") + + with pytest.raises(RuntimeError) as exc_info: + always_fails() + + assert "Attempt 3 failed" in str(exc_info.value) + assert len(attempts) == 3 # Initial attempt + 2 retries + + +def test_retry_specific_exception() -> None: + """Test that retry only catches specified exception types.""" + attempts = [] + + @retry(max_retries=3, on_exception=ValueError) + def raises_different_exceptions() -> str: + attempts.append(len(attempts)) + if len(attempts) == 1: + raise ValueError("First attempt") + elif len(attempts) == 2: + raise TypeError("Second attempt - should not be retried") + return "success" + + # Should fail on TypeError (not retried) + with pytest.raises(TypeError) as exc_info: + raises_different_exceptions() + + assert "Second attempt" in str(exc_info.value) + assert len(attempts) == 2 # First attempt with ValueError, second with TypeError + + +def test_retry_no_failures() -> None: + """Test that retry decorator works when function succeeds immediately.""" + attempts = [] + + @retry(max_retries=5) + def always_succeeds() -> str: + attempts.append(len(attempts)) + return "immediate success" + + result = always_succeeds() + assert result == "immediate success" + assert len(attempts) == 1 # Only one attempt needed + + +def test_retry_with_delay() -> None: + """Test that retry decorator applies delay between attempts.""" + attempts = [] + times = [] + + @retry(max_retries=2, delay=0.1) + def delayed_failures() -> str: + times.append(time.time()) + attempts.append(len(attempts)) + if len(attempts) < 2: + raise ValueError(f"Attempt {len(attempts)}") + return "success" + + start = time.time() + result = delayed_failures() + duration = time.time() - start + + assert result == "success" + assert len(attempts) == 2 + assert duration >= 0.1 # At least one delay occurred + + # Check that delays were applied + if len(times) >= 2: + assert times[1] - times[0] >= 0.1 + + +def test_retry_zero_retries() -> None: + """Test retry with max_retries=0 (no retries, just one attempt).""" + attempts = [] + + @retry(max_retries=0) + def single_attempt(): + attempts.append(len(attempts)) + raise ValueError("Failed") + + with pytest.raises(ValueError): + single_attempt() + + assert len(attempts) == 1 # Only the initial attempt + + +def test_retry_invalid_parameters() -> None: + """Test that retry decorator validates parameters.""" + with pytest.raises(ValueError): + + @retry(max_retries=-1) + def invalid_retries() -> None: + pass + + with pytest.raises(ValueError): + + @retry(delay=-0.5) + def invalid_delay() -> None: + pass + + +def test_retry_with_methods() -> None: + """Test that retry decorator works with class methods, instance methods, and static methods.""" + + class TestClass: + def __init__(self) -> None: + self.instance_attempts = [] + self.instance_value = 42 + + @retry(max_retries=3) + def instance_method(self, fail_times: int = 2) -> str: + """Test retry on instance method.""" + self.instance_attempts.append(len(self.instance_attempts)) + if len(self.instance_attempts) <= fail_times: + raise ValueError(f"Instance attempt {len(self.instance_attempts)} failed") + return f"instance success with value {self.instance_value}" + + @classmethod + @retry(max_retries=2) + def class_method(cls, attempts_list, fail_times: int = 1) -> str: + """Test retry on class method.""" + attempts_list.append(len(attempts_list)) + if len(attempts_list) <= fail_times: + raise ValueError(f"Class attempt {len(attempts_list)} failed") + return f"class success from {cls.__name__}" + + @staticmethod + @retry(max_retries=2) + def static_method(attempts_list, fail_times: int = 1) -> str: + """Test retry on static method.""" + attempts_list.append(len(attempts_list)) + if len(attempts_list) <= fail_times: + raise ValueError(f"Static attempt {len(attempts_list)} failed") + return "static success" + + # Test instance method + obj = TestClass() + result = obj.instance_method() + assert result == "instance success with value 42" + assert len(obj.instance_attempts) == 3 # Failed twice, succeeded on third + + # Test class method + class_attempts = [] + result = TestClass.class_method(class_attempts) + assert result == "class success from TestClass" + assert len(class_attempts) == 2 # Failed once, succeeded on second + + # Test static method + static_attempts = [] + result = TestClass.static_method(static_attempts) + assert result == "static success" + assert len(static_attempts) == 2 # Failed once, succeeded on second + + # Test that self is properly maintained across retries + obj2 = TestClass() + obj2.instance_value = 100 + result = obj2.instance_method() + assert result == "instance success with value 100" + assert len(obj2.instance_attempts) == 3 diff --git a/dimos/utils/demo_image_encoding.py b/dimos/utils/demo_image_encoding.py new file mode 100644 index 0000000000..cfd06e172a --- /dev/null +++ b/dimos/utils/demo_image_encoding.py @@ -0,0 +1,127 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +# Usage + +Run it with uncompressed LCM: + + python dimos/utils/demo_image_encoding.py + +Run it with JPEG LCM: + + python dimos/utils/demo_image_encoding.py --use-jpeg +""" + +import argparse +import threading +import time + +from reactivex.disposable import Disposable + +from dimos.core.module import Module +from dimos.core.module_coordinator import ModuleCoordinator +from dimos.core.stream import In, Out +from dimos.core.transport import JpegLcmTransport, LCMTransport +from dimos.msgs.sensor_msgs import Image +from dimos.robot.foxglove_bridge import FoxgloveBridge +from dimos.utils.fast_image_generator import random_image + + +class EmitterModule(Module): + image: Out[Image] = None # type: ignore[assignment] + + _thread: threading.Thread | None = None + _stop_event: threading.Event | None = None + + def start(self) -> None: + super().start() + self._stop_event = threading.Event() + self._thread = threading.Thread(target=self._publish_image, daemon=True) + self._thread.start() + + def stop(self) -> None: + if self._thread: + self._stop_event.set() # type: ignore[union-attr] + self._thread.join(timeout=2) + super().stop() + + def _publish_image(self) -> None: + open_file = open("/tmp/emitter-times", "w") + while not self._stop_event.is_set(): # type: ignore[union-attr] + start = time.time() + data = random_image(1280, 720) + total = time.time() - start + print("took", total) + open_file.write(str(time.time()) + "\n") + self.image.publish(Image(data=data)) + open_file.close() + + +class ReceiverModule(Module): + image: In[Image] = None # type: ignore[assignment] + + _open_file = None + + def start(self) -> None: + super().start() + self._disposables.add(Disposable(self.image.subscribe(self._on_image))) + self._open_file = open("/tmp/receiver-times", "w") + + def stop(self) -> None: + self._open_file.close() # type: ignore[union-attr] + super().stop() + + def _on_image(self, image: Image) -> None: + self._open_file.write(str(time.time()) + "\n") # type: ignore[union-attr] + print("image") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Demo image encoding with transport options") + parser.add_argument( + "--use-jpeg", + action="store_true", + help="Use JPEG LCM transport instead of regular LCM transport", + ) + args = parser.parse_args() + + dimos = ModuleCoordinator(n=2) + dimos.start() + emitter = dimos.deploy(EmitterModule) + receiver = dimos.deploy(ReceiverModule) + + if args.use_jpeg: + emitter.image.transport = JpegLcmTransport("/go2/color_image", Image) + else: + emitter.image.transport = LCMTransport("/go2/color_image", Image) + receiver.image.connect(emitter.image) + + foxglove_bridge = FoxgloveBridge() + foxglove_bridge.start() + + dimos.start_all_modules() + + try: + while True: + time.sleep(0.1) + except KeyboardInterrupt: + pass + finally: + foxglove_bridge.stop() + dimos.close() # type: ignore[attr-defined] + + +if __name__ == "__main__": + main() diff --git a/dimos/utils/extract_frames.py b/dimos/utils/extract_frames.py index 3e84e1e838..1719c77620 100644 --- a/dimos/utils/extract_frames.py +++ b/dimos/utils/extract_frames.py @@ -1,9 +1,24 @@ -import cv2 -import os +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import argparse from pathlib import Path -def extract_frames(video_path, output_dir, frame_rate): +import cv2 + + +def extract_frames(video_path, output_dir, frame_rate) -> None: # type: ignore[no-untyped-def] """ Extract frames from a video file at a specified frame rate. @@ -26,7 +41,7 @@ def extract_frames(video_path, output_dir, frame_rate): return # Calculate the interval between frames to capture - frame_interval = int(round(original_frame_rate / frame_rate)) + frame_interval = round(original_frame_rate / frame_rate) if frame_interval == 0: frame_interval = 1 @@ -49,11 +64,19 @@ def extract_frames(video_path, output_dir, frame_rate): cap.release() print(f"Extracted {saved_frame_count} frames to {output_dir}") + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Extract frames from a video file.") parser.add_argument("video_path", type=str, help="Path to the input .mov or .mp4 video file.") - parser.add_argument("--output_dir", type=str, default="frames", help="Directory to save extracted frames.") - parser.add_argument("--frame_rate", type=float, default=1.0, help="Frame rate at which to extract frames (frames per second).") + parser.add_argument( + "--output_dir", type=str, default="frames", help="Directory to save extracted frames." + ) + parser.add_argument( + "--frame_rate", + type=float, + default=1.0, + help="Frame rate at which to extract frames (frames per second).", + ) args = parser.parse_args() diff --git a/dimos/utils/fast_image_generator.py b/dimos/utils/fast_image_generator.py new file mode 100644 index 0000000000..66c4fcf951 --- /dev/null +++ b/dimos/utils/fast_image_generator.py @@ -0,0 +1,305 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Fast stateful image generator with visual features for encoding tests.""" + +from typing import Literal, TypedDict, Union + +import numpy as np +from numpy.typing import NDArray + + +class CircleObject(TypedDict): + """Type definition for circle objects.""" + + type: Literal["circle"] + x: float + y: float + vx: float + vy: float + radius: int + color: NDArray[np.float32] + + +class RectObject(TypedDict): + """Type definition for rectangle objects.""" + + type: Literal["rect"] + x: float + y: float + vx: float + vy: float + width: int + height: int + color: NDArray[np.float32] + + +Object = Union[CircleObject, RectObject] + + +class FastImageGenerator: + """ + Stateful image generator that creates images with visual features + suitable for testing image/video encoding at 30+ FPS. + + Features generated: + - Moving geometric shapes (tests motion vectors) + - Color gradients (tests gradient compression) + - Sharp edges and corners (tests edge preservation) + - Textured regions (tests detail retention) + - Smooth regions (tests flat area compression) + - High contrast boundaries (tests blocking artifacts) + """ + + def __init__(self, width: int = 1280, height: int = 720) -> None: + """Initialize the generator with pre-computed elements.""" + self.width = width + self.height = height + self.frame_count = 0 + self.objects: list[Object] = [] + + # Pre-allocate the main canvas + self.canvas = np.zeros((height, width, 3), dtype=np.float32) + + # Pre-compute coordinate grids for fast gradient generation + self.x_grid, self.y_grid = np.meshgrid( + np.linspace(0, 1, width, dtype=np.float32), np.linspace(0, 1, height, dtype=np.float32) + ) + + # Pre-compute base gradient patterns + self._init_gradients() + + # Initialize moving objects with their properties + self._init_moving_objects() + + # Pre-compute static texture pattern + self._init_texture() + + # Pre-allocate shape masks for reuse + self._init_shape_masks() + + def _init_gradients(self) -> None: + """Pre-compute gradient patterns.""" + # Diagonal gradient + self.diag_gradient = (self.x_grid + self.y_grid) * 0.5 + + # Radial gradient from center + cx, cy = 0.5, 0.5 + self.radial_gradient = np.sqrt((self.x_grid - cx) ** 2 + (self.y_grid - cy) ** 2) + self.radial_gradient = np.clip(1.0 - self.radial_gradient * 1.5, 0, 1) + + # Horizontal and vertical gradients + self.h_gradient = self.x_grid + self.v_gradient = self.y_grid + + def _init_moving_objects(self) -> None: + """Initialize properties of moving objects.""" + self.objects = [ + { + "type": "circle", + "x": 0.2, + "y": 0.3, + "vx": 0.002, + "vy": 0.003, + "radius": 60, + "color": np.array([255, 100, 100], dtype=np.float32), + }, + { + "type": "rect", + "x": 0.7, + "y": 0.6, + "vx": -0.003, + "vy": 0.002, + "width": 100, + "height": 80, + "color": np.array([100, 255, 100], dtype=np.float32), + }, + { + "type": "circle", + "x": 0.5, + "y": 0.5, + "vx": 0.004, + "vy": -0.002, + "radius": 40, + "color": np.array([100, 100, 255], dtype=np.float32), + }, + ] + + def _init_texture(self) -> None: + """Pre-compute a texture pattern.""" + # Create a simple checkerboard pattern at lower resolution + checker_size = 20 + checker_h = self.height // checker_size + checker_w = self.width // checker_size + + # Create small checkerboard + checker = np.indices((checker_h, checker_w)).sum(axis=0) % 2 + + # Upscale using repeat (fast) + self.texture = np.repeat(np.repeat(checker, checker_size, axis=0), checker_size, axis=1) + self.texture = self.texture[: self.height, : self.width].astype(np.float32) * 30 + + def _init_shape_masks(self) -> None: + """Pre-allocate reusable masks for shapes.""" + # Pre-allocate a mask array + self.temp_mask = np.zeros((self.height, self.width), dtype=np.float32) + + # Pre-compute indices for the entire image + self.y_indices, self.x_indices = np.indices((self.height, self.width)) + + def _draw_circle_fast(self, cx: int, cy: int, radius: int, color: NDArray[np.float32]) -> None: + """Draw a circle using vectorized operations - optimized version without anti-aliasing.""" + # Compute bounding box to minimize calculations + y1 = max(0, cy - radius - 1) + y2 = min(self.height, cy + radius + 2) + x1 = max(0, cx - radius - 1) + x2 = min(self.width, cx + radius + 2) + + # Work only on the bounding box region + if y1 < y2 and x1 < x2: + y_local, x_local = np.ogrid[y1:y2, x1:x2] + dist_sq = (x_local - cx) ** 2 + (y_local - cy) ** 2 + mask = dist_sq <= radius**2 + self.canvas[y1:y2, x1:x2][mask] = color + + def _draw_rect_fast(self, x: int, y: int, w: int, h: int, color: NDArray[np.float32]) -> None: + """Draw a rectangle using slicing.""" + # Clip to canvas boundaries + x1 = max(0, x) + y1 = max(0, y) + x2 = min(self.width, x + w) + y2 = min(self.height, y + h) + + if x1 < x2 and y1 < y2: + self.canvas[y1:y2, x1:x2] = color + + def _update_objects(self) -> None: + """Update positions of moving objects.""" + for obj in self.objects: + # Update position + obj["x"] += obj["vx"] + obj["y"] += obj["vy"] + + # Bounce off edges + if obj["type"] == "circle": + r = obj["radius"] / self.width + if obj["x"] - r <= 0 or obj["x"] + r >= 1: + obj["vx"] *= -1 + obj["x"] = np.clip(obj["x"], r, 1 - r) + + r = obj["radius"] / self.height + if obj["y"] - r <= 0 or obj["y"] + r >= 1: + obj["vy"] *= -1 + obj["y"] = np.clip(obj["y"], r, 1 - r) + + elif obj["type"] == "rect": + w = obj["width"] / self.width + h = obj["height"] / self.height + if obj["x"] <= 0 or obj["x"] + w >= 1: + obj["vx"] *= -1 + obj["x"] = np.clip(obj["x"], 0, 1 - w) + + if obj["y"] <= 0 or obj["y"] + h >= 1: + obj["vy"] *= -1 + obj["y"] = np.clip(obj["y"], 0, 1 - h) + + def generate_frame(self) -> NDArray[np.uint8]: + """ + Generate a single frame with visual features - optimized for 30+ FPS. + + Returns: + numpy array of shape (height, width, 3) with uint8 values + """ + # Fast gradient background - use only one gradient per frame + if self.frame_count % 2 == 0: + base_gradient = self.h_gradient + else: + base_gradient = self.v_gradient + + # Simple color mapping + self.canvas[:, :, 0] = base_gradient * 150 + 50 + self.canvas[:, :, 1] = base_gradient * 120 + 70 + self.canvas[:, :, 2] = (1 - base_gradient) * 140 + 60 + + # Add texture in corner - simplified without per-channel scaling + tex_size = self.height // 3 + self.canvas[:tex_size, :tex_size] += self.texture[:tex_size, :tex_size, np.newaxis] + + # Add test pattern bars - vectorized + bar_width = 50 + bar_start = self.width // 3 + for i in range(3): # Reduced from 5 to 3 bars + x1 = bar_start + i * bar_width * 2 + x2 = min(x1 + bar_width, self.width) + if x1 < self.width: + color_val = 180 + i * 30 + self.canvas[self.height // 2 :, x1:x2] = color_val + + # Update and draw only 2 moving objects (reduced from 3) + self._update_objects() + + # Draw only first 2 objects for speed + for obj in self.objects[:2]: + if obj["type"] == "circle": + cx = int(obj["x"] * self.width) + cy = int(obj["y"] * self.height) + self._draw_circle_fast(cx, cy, obj["radius"], obj["color"]) + elif obj["type"] == "rect": + x = int(obj["x"] * self.width) + y = int(obj["y"] * self.height) + self._draw_rect_fast(x, y, obj["width"], obj["height"], obj["color"]) + + # Simple horizontal lines pattern (faster than sine wave) + line_y = int(self.height * 0.8) + line_spacing = 10 + for i in range(0, 5): + y = line_y + i * line_spacing + if y < self.height: + self.canvas[y : y + 2, :] = [255, 200, 100] + + # Increment frame counter + self.frame_count += 1 + + # Direct conversion to uint8 (already in valid range) + return self.canvas.astype(np.uint8) + + def reset(self) -> None: + """Reset the generator to initial state.""" + self.frame_count = 0 + self._init_moving_objects() + + +# Convenience function for backward compatibility +_generator: FastImageGenerator | None = None + + +def random_image(width: int, height: int) -> NDArray[np.uint8]: + """ + Generate an image with visual features suitable for encoding tests. + Maintains state for efficient stream generation. + + Args: + width: Image width in pixels + height: Image height in pixels + + Returns: + numpy array of shape (height, width, 3) with uint8 values + """ + global _generator + + # Initialize or reinitialize if dimensions changed + if _generator is None or _generator.width != width or _generator.height != height: + _generator = FastImageGenerator(width, height) + + return _generator.generate_frame() diff --git a/dimos/utils/generic.py b/dimos/utils/generic.py new file mode 100644 index 0000000000..e53292f1b1 --- /dev/null +++ b/dimos/utils/generic.py @@ -0,0 +1,78 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import hashlib +import json +import os +import string +from typing import Any +import uuid + + +def truncate_display_string(arg: Any, max: int | None = None) -> str: + """ + If we print strings that are too long that potentially obscures more important logs. + + Use this function to truncate it to a reasonable length (configurable from the env). + """ + string = str(arg) + + if max is not None: + max_chars = max + else: + max_chars = int(os.getenv("TRUNCATE_MAX", "2000")) + + if max_chars == 0 or len(string) <= max_chars: + return string + + return string[:max_chars] + "...(truncated)..." + + +def extract_json_from_llm_response(response: str) -> Any: + start_idx = response.find("{") + end_idx = response.rfind("}") + 1 + + if start_idx >= 0 and end_idx > start_idx: + json_str = response[start_idx:end_idx] + try: + return json.loads(json_str) + except Exception: + pass + + return None + + +def short_id(from_string: str | None = None) -> str: + alphabet = string.digits + string.ascii_letters + base = len(alphabet) + + if from_string is None: + num = uuid.uuid4().int + else: + hash_bytes = hashlib.sha1(from_string.encode()).digest()[:16] + num = int.from_bytes(hash_bytes, "big") + + min_chars = 18 + + chars: list[str] = [] + while num > 0 or len(chars) < min_chars: + num, rem = divmod(num, base) + chars.append(alphabet[rem]) + + return "".join(reversed(chars))[:min_chars] + + +class classproperty(property): + def __get__(self, obj, cls): # type: ignore[no-untyped-def, override] + return self.fget(cls) # type: ignore[misc] diff --git a/dimos/utils/gpu_utils.py b/dimos/utils/gpu_utils.py new file mode 100644 index 0000000000..36cebd7d82 --- /dev/null +++ b/dimos/utils/gpu_utils.py @@ -0,0 +1,23 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def is_cuda_available(): # type: ignore[no-untyped-def] + try: + import pycuda.driver as cuda # type: ignore[import-not-found] + + cuda.init() + return cuda.Device.count() > 0 + except Exception: + return False diff --git a/dimos/utils/llm_utils.py b/dimos/utils/llm_utils.py new file mode 100644 index 0000000000..47d848807c --- /dev/null +++ b/dimos/utils/llm_utils.py @@ -0,0 +1,74 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import re + + +def extract_json(response: str) -> dict | list: # type: ignore[type-arg] + """Extract JSON from potentially messy LLM response. + + Tries multiple strategies: + 1. Parse the entire response as JSON + 2. Find and parse JSON arrays in the response + 3. Find and parse JSON objects in the response + + Args: + response: Raw text response that may contain JSON + + Returns: + Parsed JSON object (dict or list) + + Raises: + json.JSONDecodeError: If no valid JSON can be extracted + """ + # First try to parse the whole response as JSON + try: + return json.loads(response) # type: ignore[no-any-return] + except json.JSONDecodeError: + pass + + # If that fails, try to extract JSON from the messy response + # Look for JSON arrays or objects in the text + + # Pattern to match JSON arrays (including nested arrays/objects) + # This finds the outermost [...] structure + array_pattern = r"\[(?:[^\[\]]*|\[(?:[^\[\]]*|\[[^\[\]]*\])*\])*\]" + + # Pattern to match JSON objects + object_pattern = r"\{(?:[^{}]*|\{(?:[^{}]*|\{[^{}]*\})*\})*\}" + + # Try to find JSON arrays first (most common for detections) + matches = re.findall(array_pattern, response, re.DOTALL) + for match in matches: + try: + parsed = json.loads(match) + # For detection arrays, we expect a list + if isinstance(parsed, list): + return parsed + except json.JSONDecodeError: + continue + + # Try JSON objects if no arrays found + matches = re.findall(object_pattern, response, re.DOTALL) + for match in matches: + try: + return json.loads(match) # type: ignore[no-any-return] + except json.JSONDecodeError: + continue + + # If nothing worked, raise an error with the original response + raise json.JSONDecodeError( + f"Could not extract valid JSON from response: {response[:200]}...", response, 0 + ) diff --git a/dimos/utils/logging_config.py b/dimos/utils/logging_config.py new file mode 100644 index 0000000000..ce1494025c --- /dev/null +++ b/dimos/utils/logging_config.py @@ -0,0 +1,234 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Mapping +from datetime import datetime +import inspect +import logging +import logging.handlers +import os +from pathlib import Path +import sys +import tempfile +import traceback +from types import TracebackType +from typing import Any + +import structlog +from structlog.processors import CallsiteParameter, CallsiteParameterAdder + +from dimos.constants import DIMOS_LOG_DIR, DIMOS_PROJECT_ROOT + +# Suppress noisy loggers +logging.getLogger("aiortc.codecs.h264").setLevel(logging.ERROR) +logging.getLogger("lcm_foxglove_bridge").setLevel(logging.ERROR) +logging.getLogger("websockets.server").setLevel(logging.ERROR) +logging.getLogger("FoxgloveServer").setLevel(logging.ERROR) +logging.getLogger("asyncio").setLevel(logging.ERROR) + +_LOG_FILE_PATH = None + + +def _get_log_directory() -> Path: + # Check if running from a git repository + if (DIMOS_PROJECT_ROOT / ".git").exists(): + log_dir = DIMOS_LOG_DIR + else: + # Running from an installed package - use XDG_STATE_HOME + xdg_state_home = os.getenv("XDG_STATE_HOME") + if xdg_state_home: + log_dir = Path(xdg_state_home) / "dimos" / "logs" + else: + log_dir = Path.home() / ".local" / "state" / "dimos" / "logs" + + try: + log_dir.mkdir(parents=True, exist_ok=True) + except (PermissionError, OSError): + log_dir = Path(tempfile.gettempdir()) / "dimos" / "logs" + log_dir.mkdir(parents=True, exist_ok=True) + + return log_dir + + +def _get_log_file_path() -> Path: + log_dir = _get_log_directory() + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + pid = os.getpid() + return log_dir / f"dimos_{timestamp}_{pid}.jsonl" + + +def _configure_structlog() -> Path: + global _LOG_FILE_PATH + + if _LOG_FILE_PATH: + return _LOG_FILE_PATH + + _LOG_FILE_PATH = _get_log_file_path() + + shared_processors: list[Any] = [ + structlog.stdlib.add_log_level, + structlog.stdlib.add_logger_name, + structlog.stdlib.PositionalArgumentsFormatter(), + structlog.processors.TimeStamper(fmt="iso"), + structlog.processors.StackInfoRenderer(), + structlog.processors.UnicodeDecoder(), + CallsiteParameterAdder( + parameters=[ + CallsiteParameter.FUNC_NAME, + CallsiteParameter.LINENO, + ] + ), + structlog.processors.format_exc_info, # Add this to format exception info + ] + + structlog.configure( + processors=[ + structlog.stdlib.filter_by_level, + *shared_processors, + structlog.stdlib.ProcessorFormatter.wrap_for_formatter, + ], + context_class=dict, + logger_factory=structlog.stdlib.LoggerFactory(), + cache_logger_on_first_use=True, + ) + + return _LOG_FILE_PATH + + +def setup_logger(*, level: int | None = None) -> Any: + """Set up a structured logger using structlog. + + Args: + level: The logging level. + + Returns: + A configured structlog logger instance. + """ + + caller_frame = inspect.stack()[1] + name = caller_frame.filename + + # Convert absolute path to relative path + try: + name = str(Path(name).relative_to(DIMOS_PROJECT_ROOT)) + except (ValueError, TypeError): + pass + + log_file_path = _configure_structlog() + + if level is None: + level_name = os.getenv("DIMOS_LOG_LEVEL", "INFO") + level = getattr(logging, level_name) + + stdlib_logger = logging.getLogger(name) + + # Remove any existing handlers. + if stdlib_logger.hasHandlers(): + stdlib_logger.handlers.clear() + + stdlib_logger.setLevel(level) + stdlib_logger.propagate = False + + # Create console handler with pretty formatting. + # We use exception_formatter=None because we handle exceptions + # separately with Rich in the global exception handler + + console_renderer = structlog.dev.ConsoleRenderer( + colors=True, + pad_event=60, + force_colors=False, + sort_keys=True, + # Don't format exceptions in console logs + exception_formatter=None, # type: ignore[arg-type] + ) + + # Wrapper to remove callsite info and exception details before rendering to console. + def console_processor_without_callsite( + logger: Any, method_name: str, event_dict: Mapping[str, Any] + ) -> str: + event_dict = dict(event_dict) + # Remove callsite info + event_dict.pop("func_name", None) + event_dict.pop("lineno", None) + # Remove exception fields since we handle them with Rich + event_dict.pop("exception", None) + event_dict.pop("exc_info", None) + event_dict.pop("exception_type", None) + event_dict.pop("exception_message", None) + event_dict.pop("traceback_lines", None) + return console_renderer(logger, method_name, event_dict) + + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setLevel(level) + console_formatter = structlog.stdlib.ProcessorFormatter( + processor=console_processor_without_callsite, + ) + console_handler.setFormatter(console_formatter) + stdlib_logger.addHandler(console_handler) + + # Create rotating file handler with JSON formatting. + file_handler = logging.handlers.RotatingFileHandler( + log_file_path, + mode="a", + maxBytes=10 * 1024 * 1024, # 10MiB + backupCount=20, + encoding="utf-8", + ) + file_handler.setLevel(level) + file_formatter = structlog.stdlib.ProcessorFormatter( + processor=structlog.processors.JSONRenderer(), + ) + file_handler.setFormatter(file_formatter) + stdlib_logger.addHandler(file_handler) + + return structlog.get_logger(name) + + +def setup_exception_handler() -> None: + def handle_exception( + exc_type: type[BaseException], + exc_value: BaseException, + exc_traceback: TracebackType | None, + ) -> None: + # Don't log KeyboardInterrupt + if issubclass(exc_type, KeyboardInterrupt): + sys.__excepthook__(exc_type, exc_value, exc_traceback) + return + + # Get a logger for uncaught exceptions + logger = setup_logger() + + # Log the exception with full traceback to JSON + logger.error( + "Uncaught exception occurred", + exc_info=(exc_type, exc_value, exc_traceback), + exception_type=exc_type.__name__, + exception_message=str(exc_value), + traceback_lines=traceback.format_exception(exc_type, exc_value, exc_traceback), + ) + + # Still display the exception nicely on console using Rich if available + try: + from rich.console import Console + from rich.traceback import Traceback + + console = Console() + tb = Traceback.from_exception(exc_type, exc_value, exc_traceback) + console.print(tb) + except ImportError: + # Fall back to standard exception display if Rich is not available + sys.__excepthook__(exc_type, exc_value, exc_traceback) + + # Set our custom exception handler + sys.excepthook = handle_exception diff --git a/dimos/utils/monitoring.py b/dimos/utils/monitoring.py new file mode 100644 index 0000000000..ca3e03c55e --- /dev/null +++ b/dimos/utils/monitoring.py @@ -0,0 +1,307 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Note, to enable ps-spy to run without sudo you need: + + echo 0 | sudo tee /proc/sys/kernel/yama/ptrace_scope +""" + +from functools import cache +import os +import re +import shutil +import subprocess +import threading + +from distributed import get_client +from distributed.client import Client + +from dimos.core import Module, rpc +from dimos.utils.actor_registry import ActorRegistry +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +def print_data_table(data) -> None: # type: ignore[no-untyped-def] + headers = [ + "cpu_percent", + "active_percent", + "gil_percent", + "n_threads", + "pid", + "worker_id", + "modules", + ] + numeric_headers = {"cpu_percent", "active_percent", "gil_percent", "n_threads", "pid"} + + # Add registered modules. + modules = ActorRegistry.get_all() + for worker in data: + worker["modules"] = ", ".join( + module_name.split("-", 1)[0] + for module_name, worker_id_str in modules.items() + if worker_id_str == str(worker["worker_id"]) + ) + + # Determine column widths + col_widths = [] + for h in headers: + max_len = max(len(str(d[h])) for d in data) + col_widths.append(max(len(h), max_len)) + + # Print header with DOS box characters + header_row = " │ ".join(h.ljust(col_widths[i]) for i, h in enumerate(headers)) + border_parts = ["─" * w for w in col_widths] + border_line = "─┼─".join(border_parts) + print(border_line) + print(header_row) + print(border_line) + + # Print rows + for row in data: + formatted_cells = [] + for i, h in enumerate(headers): + value = str(row[h]) + if h in numeric_headers: + formatted_cells.append(value.rjust(col_widths[i])) + else: + formatted_cells.append(value.ljust(col_widths[i])) + print(" │ ".join(formatted_cells)) + + +class UtilizationThread(threading.Thread): + _module: "UtilizationModule" + _stop_event: threading.Event + _monitors: dict # type: ignore[type-arg] + + def __init__(self, module) -> None: # type: ignore[no-untyped-def] + super().__init__(daemon=True) + self._module = module + self._stop_event = threading.Event() + self._monitors = {} + + def run(self) -> None: + while not self._stop_event.is_set(): + workers = self._module.client.scheduler_info()["workers"] # type: ignore[union-attr] + pids = {pid: None for pid in get_worker_pids()} # type: ignore[no-untyped-call] + for worker, info in workers.items(): + pid = get_pid_by_port(worker.rsplit(":", 1)[-1]) + if pid is None: + continue + pids[pid] = info["id"] + data = [] + for pid, worker_id in pids.items(): + if pid not in self._monitors: + self._monitors[pid] = GilMonitorThread(pid) + self._monitors[pid].start() + cpu, gil, active, n_threads = self._monitors[pid].get_values() + data.append( + { + "cpu_percent": cpu, + "worker_id": worker_id, + "pid": pid, + "gil_percent": gil, + "active_percent": active, + "n_threads": n_threads, + } + ) + data.sort(key=lambda x: x["pid"]) + self._fix_missing_ids(data) + print_data_table(data) + self._stop_event.wait(1) + + def stop(self) -> None: + self._stop_event.set() + for monitor in self._monitors.values(): + monitor.stop() + monitor.join(timeout=2) + + def _fix_missing_ids(self, data) -> None: # type: ignore[no-untyped-def] + """ + Some worker IDs are None. But if we order the workers by PID and all + non-None ids are in order, then we can deduce that the None ones are the + missing indices. + """ + if all(x["worker_id"] in (i, None) for i, x in enumerate(data)): + for i, worker in enumerate(data): + worker["worker_id"] = i + + +class UtilizationModule(Module): + client: Client | None + _utilization_thread: UtilizationThread | None + + def __init__(self) -> None: + super().__init__() + self.client = None + self._utilization_thread = None + + if not os.getenv("MEASURE_GIL_UTILIZATION"): + logger.info("Set `MEASURE_GIL_UTILIZATION=true` to print GIL utilization.") + return + + if not _can_use_py_spy(): # type: ignore[no-untyped-call] + logger.warning( + "Cannot start UtilizationModule because in order to run py-spy without " + "being root you need to enable this:\n" + "\n" + " echo 0 | sudo tee /proc/sys/kernel/yama/ptrace_scope" + ) + return + + if not shutil.which("py-spy"): + logger.warning("Cannot start UtilizationModule because `py-spy` is not installed.") + return + + self.client = get_client() + self._utilization_thread = UtilizationThread(self) + + @rpc + def start(self) -> None: + super().start() + + if self._utilization_thread: + self._utilization_thread.start() + + @rpc + def stop(self) -> None: + if self._utilization_thread: + self._utilization_thread.stop() + self._utilization_thread.join(timeout=2) + super().stop() + + +utilization = UtilizationModule.blueprint + + +__all__ = ["UtilizationModule", "utilization"] + + +def _can_use_py_spy(): # type: ignore[no-untyped-def] + try: + with open("/proc/sys/kernel/yama/ptrace_scope") as f: + value = f.read().strip() + return value == "0" + except Exception: + pass + return False + + +@cache +def get_pid_by_port(port: int) -> int | None: + try: + result = subprocess.run( + ["lsof", "-ti", f":{port}"], capture_output=True, text=True, check=True + ) + pid_str = result.stdout.strip() + return int(pid_str) if pid_str else None + except subprocess.CalledProcessError: + return None + + +def get_worker_pids(): # type: ignore[no-untyped-def] + pids = [] + for pid in os.listdir("/proc"): + if not pid.isdigit(): + continue + try: + with open(f"/proc/{pid}/cmdline") as f: + cmdline = f.read().replace("\x00", " ") + if "spawn_main" in cmdline: + pids.append(int(pid)) + except (FileNotFoundError, PermissionError): + continue + return pids + + +class GilMonitorThread(threading.Thread): + pid: int + _latest_values: tuple[float, float, float, int] + _stop_event: threading.Event + _lock: threading.Lock + + def __init__(self, pid: int) -> None: + super().__init__(daemon=True) + self.pid = pid + self._latest_values = (-1.0, -1.0, -1.0, -1) + self._stop_event = threading.Event() + self._lock = threading.Lock() + + def run(self): # type: ignore[no-untyped-def] + command = ["py-spy", "top", "--pid", str(self.pid), "--rate", "100"] + process = None + try: + process = subprocess.Popen( + command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + bufsize=1, # Line-buffered output + ) + + for line in iter(process.stdout.readline, ""): # type: ignore[union-attr] + if self._stop_event.is_set(): + break + + if "GIL:" not in line: + continue + + match = re.search( + r"GIL:\s*([\d.]+?)%,\s*Active:\s*([\d.]+?)%,\s*Threads:\s*(\d+)", line + ) + if not match: + continue + + try: + cpu_percent = _get_cpu_percent(self.pid) + gil_percent = float(match.group(1)) + active_percent = float(match.group(2)) + num_threads = int(match.group(3)) + + with self._lock: + self._latest_values = ( + cpu_percent, + gil_percent, + active_percent, + num_threads, + ) + except (ValueError, IndexError): + pass + except Exception as e: + logger.error(f"An error occurred in GilMonitorThread for PID {self.pid}: {e}") + raise + finally: + if process: + process.terminate() + process.wait(timeout=1) + self._stop_event.set() + + def get_values(self): # type: ignore[no-untyped-def] + with self._lock: + return self._latest_values + + def stop(self) -> None: + self._stop_event.set() + + +def _get_cpu_percent(pid: int) -> float: + try: + result = subprocess.run( + ["ps", "-p", str(pid), "-o", "%cpu="], capture_output=True, text=True, check=True + ) + return float(result.stdout.strip()) + except Exception: + return -1.0 diff --git a/dimos/utils/path_utils.py b/dimos/utils/path_utils.py new file mode 100644 index 0000000000..794d36e34d --- /dev/null +++ b/dimos/utils/path_utils.py @@ -0,0 +1,22 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + + +def get_project_root() -> Path: + """ + Returns the absolute path to the project root directory. + """ + return Path(__file__).resolve().parent.parent.parent diff --git a/dimos/utils/reactive.py b/dimos/utils/reactive.py new file mode 100644 index 0000000000..5eed9908a2 --- /dev/null +++ b/dimos/utils/reactive.py @@ -0,0 +1,230 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable +import threading +from typing import Any, Generic, TypeVar + +import reactivex as rx +from reactivex import operators as ops +from reactivex.disposable import Disposable +from reactivex.observable import Observable +from reactivex.scheduler import ThreadPoolScheduler +from rxpy_backpressure import BackPressure # type: ignore[import-untyped] + +from dimos.utils.threadpool import get_scheduler + +T = TypeVar("T") + + +# Observable ─► ReplaySubject─► observe_on(pool) ─► backpressure.latest ─► sub1 (fast) +# ├──► observe_on(pool) ─► backpressure.latest ─► sub2 (slow) +# └──► observe_on(pool) ─► backpressure.latest ─► sub3 (slower) +def backpressure( + observable: Observable[T], + scheduler: ThreadPoolScheduler | None = None, + drop_unprocessed: bool = True, +) -> Observable[T]: + if scheduler is None: + scheduler = get_scheduler() + + # hot, latest-cached core (similar to replay subject) + core = observable.pipe( + ops.replay(buffer_size=1), + ops.ref_count(), # Shared but still synchronous! + ) + + # per-subscriber factory + def per_sub(): # type: ignore[no-untyped-def] + # Move processing to thread pool + base = core.pipe(ops.observe_on(scheduler)) + + # optional back-pressure handling + if not drop_unprocessed: + return base + + def _subscribe(observer, sch=None): # type: ignore[no-untyped-def] + return base.subscribe(BackPressure.LATEST(observer), scheduler=sch) + + return rx.create(_subscribe) + + # each `.subscribe()` call gets its own async backpressure chain + return rx.defer(lambda *_: per_sub()) # type: ignore[no-untyped-call] + + +class LatestReader(Generic[T]): + """A callable object that returns the latest value from an observable.""" + + def __init__(self, initial_value: T, subscription, connection=None) -> None: # type: ignore[no-untyped-def] + self._value = initial_value + self._subscription = subscription + self._connection = connection + + def __call__(self) -> T: + """Return the latest value from the observable.""" + return self._value + + def dispose(self) -> None: + """Dispose of the subscription to the observable.""" + self._subscription.dispose() + if self._connection: + self._connection.dispose() + + +def getter_ondemand(observable: Observable[T], timeout: float | None = 30.0) -> T: + def getter(): # type: ignore[no-untyped-def] + result = [] + error = [] + event = threading.Event() + + def on_next(value) -> None: # type: ignore[no-untyped-def] + result.append(value) + event.set() + + def on_error(e) -> None: # type: ignore[no-untyped-def] + error.append(e) + event.set() + + def on_completed() -> None: + event.set() + + # Subscribe and wait for first value + subscription = observable.pipe(ops.first()).subscribe( + on_next=on_next, on_error=on_error, on_completed=on_completed + ) + + try: + if timeout is not None: + if not event.wait(timeout): + raise TimeoutError(f"No value received after {timeout} seconds") + else: + event.wait() + + if error: + raise error[0] + + if not result: + raise Exception("Observable completed without emitting a value") + + return result[0] + finally: + subscription.dispose() + + return getter # type: ignore[return-value] + + +T = TypeVar("T") # type: ignore[misc] + + +def getter_streaming( + source: Observable[T], + timeout: float | None = 30.0, + *, + nonblocking: bool = False, +) -> LatestReader[T]: + shared = source.pipe( + ops.replay(buffer_size=1), + ops.ref_count(), # auto-connect & auto-disconnect + ) + + _val_lock = threading.Lock() + _val: T | None = None + _ready = threading.Event() + + def _update(v: T) -> None: + nonlocal _val + with _val_lock: + _val = v + _ready.set() + + sub = shared.subscribe(_update) + + # If we’re in blocking mode, wait right now + if not nonblocking: + if timeout is not None and not _ready.wait(timeout): + sub.dispose() + raise TimeoutError(f"No value received after {timeout} s") + else: + _ready.wait() # wait indefinitely if timeout is None + + def reader() -> T: + if not _ready.is_set(): # first call in non-blocking mode + if timeout is not None and not _ready.wait(timeout): + raise TimeoutError(f"No value received after {timeout} s") + else: + _ready.wait() + with _val_lock: + return _val # type: ignore[return-value] + + def _dispose() -> None: + sub.dispose() + + reader.dispose = _dispose # type: ignore[attr-defined] + return reader # type: ignore[return-value] + + +T = TypeVar("T") # type: ignore[misc] +CB = Callable[[T], Any] + + +def callback_to_observable( + start: Callable[[CB[T]], Any], + stop: Callable[[CB[T]], Any], +) -> Observable[T]: + def _subscribe(observer, _scheduler=None): # type: ignore[no-untyped-def] + def _on_msg(value: T) -> None: + observer.on_next(value) + + start(_on_msg) + return Disposable(lambda: stop(_on_msg)) + + return rx.create(_subscribe) + + +def spy(name: str): # type: ignore[no-untyped-def] + def spyfun(x): # type: ignore[no-untyped-def] + print(f"SPY {name}:", x) + return x + + return ops.map(spyfun) + + +def quality_barrier(quality_func: Callable[[T], float], target_frequency: float): # type: ignore[no-untyped-def] + """ + RxPY pipe operator that selects the highest quality item within each time window. + + Args: + quality_func: Function to compute quality score for each item + target_frequency: Output frequency in Hz (e.g., 1.0 for 1 item per second) + + Returns: + A pipe operator that can be used with .pipe() + """ + window_duration = 1.0 / target_frequency # Duration of each window in seconds + + def _quality_barrier(source: Observable[T]) -> Observable[T]: + return source.pipe( + # Create non-overlapping time-based windows + ops.window_with_time(window_duration, window_duration), + # For each window, find the highest quality item + ops.flat_map( + lambda window: window.pipe( # type: ignore[attr-defined] + ops.to_list(), + ops.map(lambda items: max(items, key=quality_func) if items else None), # type: ignore[call-overload] + ops.filter(lambda x: x is not None), # type: ignore[arg-type] + ) + ), + ) + + return _quality_barrier diff --git a/dimos/utils/s3_utils.py b/dimos/utils/s3_utils.py deleted file mode 100644 index 02e7df580c..0000000000 --- a/dimos/utils/s3_utils.py +++ /dev/null @@ -1,79 +0,0 @@ -import boto3 -import os -from io import BytesIO -try: - import open3d as o3d -except Exception as e: - print(f"Open3D not importing, assuming to be running outside of docker. {e}") - -class S3Utils: - def __init__(self, bucket_name): - self.s3 = boto3.client('s3') - self.bucket_name = bucket_name - - def download_file(self, s3_key, local_path): - try: - self.s3.download_file(self.bucket_name, s3_key, local_path) - print(f"Downloaded {s3_key} to {local_path}") - except Exception as e: - print(f"Error downloading {s3_key}: {e}") - - def upload_file(self, local_path, s3_key): - try: - self.s3.upload_file(local_path, self.bucket_name, s3_key) - print(f"Uploaded {local_path} to {s3_key}") - except Exception as e: - print(f"Error uploading {local_path}: {e}") - - def save_pointcloud_to_s3(self, inlier_cloud, s3_key): - - try: - temp_pcd_file = "/tmp/temp_pointcloud.pcd" - o3d.io.write_point_cloud(temp_pcd_file, inlier_cloud) - with open(temp_pcd_file, 'rb') as pcd_file: - self.s3.put_object(Bucket=self.bucket_name, Key=s3_key, Body=pcd_file.read()) - os.remove(temp_pcd_file) - print(f"Saved pointcloud to {s3_key}") - except Exception as e: - print(f"error downloading {s3_key}: {e}") - - def restore_pointcloud_from_s3(self, pointcloud_paths): - restored_pointclouds = [] - - for path in pointcloud_paths: - # Download the point cloud file from S3 to memory - pcd_obj = self.s3.get_object(Bucket=self.bucket_name, Key=path) - pcd_data = pcd_obj['Body'].read() - - # Save the point cloud data to a temporary file - temp_pcd_file = "/tmp/temp_pointcloud.pcd" - with open(temp_pcd_file, 'wb') as f: - f.write(pcd_data) - - # Read the point cloud from the temporary file - pcd = o3d.io.read_point_cloud(temp_pcd_file) - restored_pointclouds.append(pcd) - - # Remove the temporary file - os.remove(temp_pcd_file) - - return restored_pointclouds - @staticmethod - def upload_text_file(bucket_name, local_path, s3_key): - s3 = boto3.client('s3') - try: - with open(local_path, 'r') as file: - content = file.read() - - # Ensure the s3_key includes the file name - if not s3_key.endswith('/'): - s3_key = s3_key + '/' - - # Extract the file name from the local_path - file_name = local_path.split('/')[-1] - full_s3_key = s3_key + file_name - - s3.put_object(Bucket=bucket_name, Key=full_s3_key, Body=content) - print(f"Uploaded text file {local_path} to {full_s3_key}") - except Exception as e: - print(f"Error uploading text file {local_path}: {e}") \ No newline at end of file diff --git a/dimos/utils/simple_controller.py b/dimos/utils/simple_controller.py new file mode 100644 index 0000000000..f95350552c --- /dev/null +++ b/dimos/utils/simple_controller.py @@ -0,0 +1,172 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + + +def normalize_angle(angle: float): # type: ignore[no-untyped-def] + """Normalize angle to the range [-pi, pi].""" + return math.atan2(math.sin(angle), math.cos(angle)) + + +# ---------------------------- +# PID Controller Class +# ---------------------------- +class PIDController: + def __init__( # type: ignore[no-untyped-def] + self, + kp, + ki: float = 0.0, + kd: float = 0.0, + output_limits=(None, None), + integral_limit=None, + deadband: float = 0.0, + output_deadband: float = 0.0, + inverse_output: bool = False, + ) -> None: + """ + Initialize the PID controller. + + Args: + kp (float): Proportional gain. + ki (float): Integral gain. + kd (float): Derivative gain. + output_limits (tuple): (min_output, max_output). Use None for no limit. + integral_limit (float): Maximum absolute value for the integral term (anti-windup). + deadband (float): Size of the deadband region. Error smaller than this will be compensated. + output_deadband (float): Deadband applied to the output to overcome physical system deadband. + inverse_output (bool): When True, the output will be multiplied by -1. + """ + self.kp = kp + self.ki = ki + self.kd = kd + self.min_output, self.max_output = output_limits + self.integral_limit = integral_limit + self.output_deadband = output_deadband + self.deadband = deadband + self.integral = 0.0 + self.prev_error = 0.0 + self.inverse_output = inverse_output + + def update(self, error, dt): # type: ignore[no-untyped-def] + """Compute the PID output with anti-windup, output deadband compensation and output saturation.""" + # Update integral term with windup protection. + self.integral += error * dt + if self.integral_limit is not None: + self.integral = max(-self.integral_limit, min(self.integral, self.integral_limit)) + + # Compute derivative term. + derivative = (error - self.prev_error) / dt if dt > 0 else 0.0 + + if abs(error) < self.deadband: + # Prevent integral windup by not increasing integral term when error is small. + self.integral = 0.0 + derivative = 0.0 + + # Compute raw output. + output = self.kp * error + self.ki * self.integral + self.kd * derivative + + # Apply deadband compensation to the output + output = self._apply_output_deadband_compensation(output) # type: ignore[no-untyped-call] + + # Apply output limits if specified. + if self.max_output is not None: + output = min(self.max_output, output) + if self.min_output is not None: + output = max(self.min_output, output) + + self.prev_error = error + if self.inverse_output: + return -output + return output + + def _apply_output_deadband_compensation(self, output): # type: ignore[no-untyped-def] + """ + Apply deadband compensation to the output. + + This simply adds the deadband value to the magnitude of the output + while preserving the sign, ensuring we overcome the physical deadband. + """ + if self.output_deadband == 0.0 or output == 0.0: + return output + + if output > self.max_output * 0.05: + # For positive output, add the deadband + return output + self.output_deadband + elif output < self.min_output * 0.05: + # For negative output, subtract the deadband + return output - self.output_deadband + else: + return output + + def _apply_deadband_compensation(self, error): # type: ignore[no-untyped-def] + """ + Apply deadband compensation to the error. + + This maintains the original error value, as the deadband compensation + will be applied to the output, not the error. + """ + return error + + +# ---------------------------- +# Visual Servoing Controller Class +# ---------------------------- +class VisualServoingController: + def __init__(self, distance_pid_params, angle_pid_params) -> None: # type: ignore[no-untyped-def] + """ + Initialize the visual servoing controller using enhanced PID controllers. + + Args: + distance_pid_params (tuple): (kp, ki, kd, output_limits, integral_limit, deadband) for distance. + angle_pid_params (tuple): (kp, ki, kd, output_limits, integral_limit, deadband) for angle. + """ + self.distance_pid = PIDController(*distance_pid_params) + self.angle_pid = PIDController(*angle_pid_params) + self.prev_measured_angle = 0.0 # Used for angular feed-forward damping + + def compute_control( # type: ignore[no-untyped-def] + self, measured_distance, measured_angle, desired_distance, desired_angle, dt + ): + """ + Compute the forward (x) and angular (z) commands. + + Args: + measured_distance (float): Current distance to target (from camera). + measured_angle (float): Current angular offset to target (radians). + desired_distance (float): Desired distance to target. + desired_angle (float): Desired angular offset (e.g., 0 for centered). + dt (float): Timestep. + + Returns: + tuple: (forward_command, angular_command) + """ + # Compute the errors. + error_distance = measured_distance - desired_distance + error_angle = normalize_angle(measured_angle - desired_angle) + + # Get raw PID outputs. + forward_command_raw = self.distance_pid.update(error_distance, dt) # type: ignore[no-untyped-call] + angular_command_raw = self.angle_pid.update(error_angle, dt) # type: ignore[no-untyped-call] + + # print("forward: {} angular: {}".format(forward_command_raw, angular_command_raw)) + + angular_command = angular_command_raw + + # Couple forward command to angular error: + # scale the forward command smoothly. + scaling_factor = max(0.0, min(1.0, math.exp(-2.0 * abs(error_angle)))) + forward_command = forward_command_raw * scaling_factor + + return forward_command, angular_command diff --git a/dimos/utils/test_data.py b/dimos/utils/test_data.py new file mode 100644 index 0000000000..01f145f60c --- /dev/null +++ b/dimos/utils/test_data.py @@ -0,0 +1,130 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import hashlib +import os +import subprocess + +import pytest + +from dimos.utils import data + + +@pytest.mark.heavy +def test_pull_file() -> None: + repo_root = data._get_repo_root() + test_file_name = "cafe.jpg" + test_file_compressed = data._get_lfs_dir() / (test_file_name + ".tar.gz") + test_file_decompressed = data._get_data_dir() / test_file_name + + # delete decompressed test file if it exists + if test_file_decompressed.exists(): + test_file_decompressed.unlink() + + # delete lfs archive file if it exists + if test_file_compressed.exists(): + test_file_compressed.unlink() + + assert not test_file_compressed.exists() + assert not test_file_decompressed.exists() + + # pull the lfs file reference from git + env = os.environ.copy() + env["GIT_LFS_SKIP_SMUDGE"] = "1" + subprocess.run( + ["git", "checkout", "HEAD", "--", test_file_compressed], + cwd=repo_root, + env=env, + check=True, + capture_output=True, + ) + + # ensure we have a pointer file from git (small ASCII text file) + assert test_file_compressed.exists() + assert test_file_compressed.stat().st_size < 200 + + # trigger a data file pull + assert data.get_data(test_file_name) == test_file_decompressed + + # validate data is received + assert test_file_compressed.exists() + assert test_file_decompressed.exists() + + # validate hashes + with test_file_compressed.open("rb") as f: + assert test_file_compressed.stat().st_size > 200 + compressed_sha256 = hashlib.sha256(f.read()).hexdigest() + assert ( + compressed_sha256 == "b8cf30439b41033ccb04b09b9fc8388d18fb544d55b85c155dbf85700b9e7603" + ) + + with test_file_decompressed.open("rb") as f: + decompressed_sha256 = hashlib.sha256(f.read()).hexdigest() + assert ( + decompressed_sha256 + == "55d451dde49b05e3ad386fdd4ae9e9378884b8905bff1ca8aaea7d039ff42ddd" + ) + + +@pytest.mark.heavy +def test_pull_dir() -> None: + repo_root = data._get_repo_root() + test_dir_name = "ab_lidar_frames" + test_dir_compressed = data._get_lfs_dir() / (test_dir_name + ".tar.gz") + test_dir_decompressed = data._get_data_dir() / test_dir_name + + # delete decompressed test directory if it exists + if test_dir_decompressed.exists(): + for item in test_dir_decompressed.iterdir(): + item.unlink() + test_dir_decompressed.rmdir() + + # delete lfs archive file if it exists + if test_dir_compressed.exists(): + test_dir_compressed.unlink() + + # pull the lfs file reference from git + env = os.environ.copy() + env["GIT_LFS_SKIP_SMUDGE"] = "1" + subprocess.run( + ["git", "checkout", "HEAD", "--", test_dir_compressed], + cwd=repo_root, + env=env, + check=True, + capture_output=True, + ) + + # ensure we have a pointer file from git (small ASCII text file) + assert test_dir_compressed.exists() + assert test_dir_compressed.stat().st_size < 200 + + # trigger a data file pull + assert data.get_data(test_dir_name) == test_dir_decompressed + assert test_dir_compressed.stat().st_size > 200 + + # validate data is received + assert test_dir_compressed.exists() + assert test_dir_decompressed.exists() + + for [file, expected_hash] in zip( + sorted(test_dir_decompressed.iterdir()), + [ + "6c3aaa9a79853ea4a7453c7db22820980ceb55035777f7460d05a0fa77b3b1b3", + "456cc2c23f4ffa713b4e0c0d97143c27e48bbe6ef44341197b31ce84b3650e74", + ], + strict=False, + ): + with file.open("rb") as f: + sha256 = hashlib.sha256(f.read()).hexdigest() + assert sha256 == expected_hash diff --git a/dimos/utils/test_foxglove_bridge.py b/dimos/utils/test_foxglove_bridge.py new file mode 100644 index 0000000000..c45dcde660 --- /dev/null +++ b/dimos/utils/test_foxglove_bridge.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Test for foxglove bridge import and basic functionality +""" + +import warnings + +import pytest + +warnings.filterwarnings("ignore", category=DeprecationWarning, module="websockets.server") +warnings.filterwarnings("ignore", category=DeprecationWarning, module="websockets.legacy") + + +def test_foxglove_bridge_import() -> None: + """Test that the foxglove bridge can be imported successfully.""" + try: + from dimos_lcm.foxglove_bridge import FoxgloveBridge + except ImportError as e: + pytest.fail(f"Failed to import foxglove bridge: {e}") + + +def test_foxglove_bridge_runner_init() -> None: + """Test that LcmFoxgloveBridge can be initialized with default parameters.""" + try: + from dimos_lcm.foxglove_bridge import FoxgloveBridge + + runner = FoxgloveBridge(host="localhost", port=8765, debug=False, num_threads=2) + + # Check that the runner was created successfully + assert runner is not None + + except Exception as e: + pytest.fail(f"Failed to initialize LcmFoxgloveBridge: {e}") + + +def test_foxglove_bridge_runner_params() -> None: + """Test that LcmFoxgloveBridge accepts various parameter configurations.""" + try: + from dimos_lcm.foxglove_bridge import FoxgloveBridge + + configs = [ + {"host": "0.0.0.0", "port": 8765, "debug": True, "num_threads": 1}, + {"host": "127.0.0.1", "port": 9090, "debug": False, "num_threads": 4}, + {"host": "localhost", "port": 8080, "debug": True, "num_threads": 2}, + ] + + for config in configs: + runner = FoxgloveBridge(**config) + assert runner is not None + + except Exception as e: + pytest.fail(f"Failed to create runner with different configs: {e}") + + +def test_bridge_runner_has_run_method() -> None: + """Test that the bridge runner has a run method that can be called.""" + try: + from dimos_lcm.foxglove_bridge import FoxgloveBridge + + runner = FoxgloveBridge(host="localhost", port=8765, debug=False, num_threads=1) + + # Check that the run method exists + assert hasattr(runner, "run") + assert callable(runner.run) + + except Exception as e: + pytest.fail(f"Failed to verify run method: {e}") diff --git a/dimos/utils/test_generic.py b/dimos/utils/test_generic.py new file mode 100644 index 0000000000..0f691bc23c --- /dev/null +++ b/dimos/utils/test_generic.py @@ -0,0 +1,31 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from uuid import UUID + +from dimos.utils.generic import short_id + + +def test_short_id_hello_world() -> None: + assert short_id("HelloWorld") == "6GgJmzi1KYf4iaHVxk" + + +def test_short_id_uuid_one(mocker) -> None: + mocker.patch("uuid.uuid4", return_value=UUID("11111111-1111-1111-1111-111111111111")) + assert short_id() == "wcFtOGNXQnQFZ8QRh1" + + +def test_short_id_uuid_zero(mocker) -> None: + mocker.patch("uuid.uuid4", return_value=UUID("00000000-0000-0000-0000-000000000000")) + assert short_id() == "000000000000000000" diff --git a/dimos/utils/test_llm_utils.py b/dimos/utils/test_llm_utils.py new file mode 100644 index 0000000000..0a3812aeaf --- /dev/null +++ b/dimos/utils/test_llm_utils.py @@ -0,0 +1,123 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for LLM utility functions.""" + +import json + +import pytest + +from dimos.utils.llm_utils import extract_json + + +def test_extract_json_clean_response() -> None: + """Test extract_json with clean JSON response.""" + clean_json = '[["object", 1, 2, 3, 4]]' + result = extract_json(clean_json) + assert result == [["object", 1, 2, 3, 4]] + + +def test_extract_json_with_text_before_after() -> None: + """Test extract_json with text before and after JSON.""" + messy = """Here's what I found: + [ + ["person", 10, 20, 30, 40], + ["car", 50, 60, 70, 80] + ] + Hope this helps!""" + result = extract_json(messy) + assert result == [["person", 10, 20, 30, 40], ["car", 50, 60, 70, 80]] + + +def test_extract_json_with_emojis() -> None: + """Test extract_json with emojis and markdown code blocks.""" + messy = """Sure! 😊 Here are the detections: + + ```json + [["human", 100, 200, 300, 400]] + ``` + + Let me know if you need anything else! 👍""" + result = extract_json(messy) + assert result == [["human", 100, 200, 300, 400]] + + +def test_extract_json_multiple_json_blocks() -> None: + """Test extract_json when there are multiple JSON blocks.""" + messy = """First attempt (wrong format): + {"error": "not what we want"} + + Correct format: + [ + ["cat", 10, 10, 50, 50], + ["dog", 60, 60, 100, 100] + ] + + Another block: {"also": "not needed"}""" + result = extract_json(messy) + # Should return the first valid array + assert result == [["cat", 10, 10, 50, 50], ["dog", 60, 60, 100, 100]] + + +def test_extract_json_object() -> None: + """Test extract_json with JSON object instead of array.""" + response = 'The result is: {"status": "success", "count": 5}' + result = extract_json(response) + assert result == {"status": "success", "count": 5} + + +def test_extract_json_nested_structures() -> None: + """Test extract_json with nested arrays and objects.""" + response = """Processing complete: + [ + ["label1", 1, 2, 3, 4], + {"nested": {"value": 10}}, + ["label2", 5, 6, 7, 8] + ]""" + result = extract_json(response) + assert result[0] == ["label1", 1, 2, 3, 4] + assert result[1] == {"nested": {"value": 10}} + assert result[2] == ["label2", 5, 6, 7, 8] + + +def test_extract_json_invalid() -> None: + """Test extract_json raises error when no valid JSON found.""" + response = "This response has no valid JSON at all!" + with pytest.raises(json.JSONDecodeError) as exc_info: + extract_json(response) + assert "Could not extract valid JSON" in str(exc_info.value) + + +# Test with actual LLM response format +MOCK_LLM_RESPONSE = """ + Yes :) + + [ + ["humans", 76, 368, 219, 580], + ["humans", 354, 372, 512, 525], + ["humans", 409, 370, 615, 748], + ["humans", 628, 350, 762, 528], + ["humans", 785, 323, 960, 650] + ] + + Hope this helps!😀😊 :)""" + + +def test_extract_json_with_real_llm_response() -> None: + """Test extract_json with actual messy LLM response.""" + result = extract_json(MOCK_LLM_RESPONSE) + assert isinstance(result, list) + assert len(result) == 5 + assert result[0] == ["humans", 76, 368, 219, 580] + assert result[-1] == ["humans", 785, 323, 960, 650] diff --git a/dimos/utils/test_reactive.py b/dimos/utils/test_reactive.py new file mode 100644 index 0000000000..17b69ba0aa --- /dev/null +++ b/dimos/utils/test_reactive.py @@ -0,0 +1,285 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable +import time +from typing import Any, TypeVar + +import numpy as np +import pytest +import reactivex as rx +from reactivex import operators as ops +from reactivex.disposable import Disposable +from reactivex.scheduler import ThreadPoolScheduler + +from dimos.utils.reactive import ( + backpressure, + callback_to_observable, + getter_ondemand, + getter_streaming, +) + + +def measure_time(func: Callable[[], Any], iterations: int = 1) -> float: + start_time = time.time() + result = func() + end_time = time.time() + total_time = end_time - start_time + return result, total_time + + +def assert_time( + func: Callable[[], Any], assertion: Callable[[int], bool], assert_fail_msg=None +) -> None: + [result, total_time] = measure_time(func) + assert assertion(total_time), assert_fail_msg + f", took {round(total_time, 2)}s" + return result + + +def min_time( + func: Callable[[], Any], min_t: int, assert_fail_msg: str = "Function returned too fast" +): + return assert_time( + func, (lambda t: t >= min_t * 0.98), assert_fail_msg + f", min: {min_t} seconds" + ) + + +def max_time(func: Callable[[], Any], max_t: int, assert_fail_msg: str = "Function took too long"): + return assert_time(func, (lambda t: t < max_t), assert_fail_msg + f", max: {max_t} seconds") + + +T = TypeVar("T") + + +def dispose_spy(source: rx.Observable[T]) -> rx.Observable[T]: + state = {"active": 0} + + def factory(observer, scheduler=None): + state["active"] += 1 + upstream = source.subscribe(observer, scheduler=scheduler) + + def _dispose() -> None: + upstream.dispose() + state["active"] -= 1 + + return Disposable(_dispose) + + proxy = rx.create(factory) + proxy.subs_number = lambda: state["active"] + proxy.is_disposed = lambda: state["active"] == 0 + return proxy + + +def test_backpressure_handling() -> None: + # Create a dedicated scheduler for this test to avoid thread leaks + test_scheduler = ThreadPoolScheduler(max_workers=8) + try: + received_fast = [] + received_slow = [] + # Create an observable that emits numpy arrays instead of integers + source = dispose_spy( + rx.interval(0.1).pipe(ops.map(lambda i: np.array([i, i + 1, i + 2])), ops.take(50)) + ) + + # Wrap with backpressure handling + safe_source = backpressure(source, scheduler=test_scheduler) + + # Fast sub + subscription1 = safe_source.subscribe(lambda x: received_fast.append(x)) + + # Slow sub (shouldn't block above) + subscription2 = safe_source.subscribe(lambda x: (time.sleep(0.25), received_slow.append(x))) + + time.sleep(2.5) + + subscription1.dispose() + assert not source.is_disposed(), "Observable should not be disposed yet" + subscription2.dispose() + # Wait longer to ensure background threads finish processing + # (the slow subscriber sleeps for 0.25s, so we need to wait at least that long) + time.sleep(0.5) + assert source.is_disposed(), "Observable should be disposed" + + # Check results + print("Fast observer received:", len(received_fast), [arr[0] for arr in received_fast]) + print("Slow observer received:", len(received_slow), [arr[0] for arr in received_slow]) + + # Fast observer should get all or nearly all items + assert len(received_fast) > 15, ( + f"Expected fast observer to receive most items, got {len(received_fast)}" + ) + + # Slow observer should get fewer items due to backpressure handling + assert len(received_slow) < len(received_fast), ( + "Slow observer should receive fewer items than fast observer" + ) + # Specifically, processing at 0.25s means ~4 items per second, so expect 8-10 items + assert 7 <= len(received_slow) <= 11, f"Expected 7-11 items, got {len(received_slow)}" + + # The slow observer should skip items (not process them in sequence) + # We test this by checking that the difference between consecutive arrays is sometimes > 1 + has_skips = False + for i in range(1, len(received_slow)): + if received_slow[i][0] - received_slow[i - 1][0] > 1: + has_skips = True + break + assert has_skips, "Slow observer should skip items due to backpressure" + finally: + # Always shutdown the scheduler to clean up threads + test_scheduler.executor.shutdown(wait=True) + + +def test_getter_streaming_blocking() -> None: + source = dispose_spy( + rx.interval(0.2).pipe(ops.map(lambda i: np.array([i, i + 1, i + 2])), ops.take(50)) + ) + assert source.is_disposed() + + getter = min_time( + lambda: getter_streaming(source), + 0.2, + "Latest getter needs to block until first msg is ready", + ) + assert np.array_equal(getter(), np.array([0, 1, 2])), ( + f"Expected to get the first array [0,1,2], got {getter()}" + ) + + time.sleep(0.5) + assert getter()[0] >= 2, f"Expected array with first value >= 2, got {getter()}" + time.sleep(0.5) + assert getter()[0] >= 4, f"Expected array with first value >= 4, got {getter()}" + + getter.dispose() + time.sleep(0.3) # Wait for background interval timer threads to finish + assert source.is_disposed(), "Observable should be disposed" + + +def test_getter_streaming_blocking_timeout() -> None: + source = dispose_spy(rx.interval(0.2).pipe(ops.take(50))) + with pytest.raises(Exception): + getter = getter_streaming(source, timeout=0.1) + getter.dispose() + time.sleep(0.3) # Wait for background interval timer threads to finish + assert source.is_disposed() + + +def test_getter_streaming_nonblocking() -> None: + source = dispose_spy(rx.interval(0.2).pipe(ops.take(50))) + + getter = max_time( + lambda: getter_streaming(source, nonblocking=True), + 0.1, + "nonblocking getter init shouldn't block", + ) + min_time(getter, 0.1, "Expected for first value call to block if cache is empty") + assert getter() == 0 + + time.sleep(0.5) + assert getter() >= 2, f"Expected value >= 2, got {getter()}" + + # sub is active + assert not source.is_disposed() + + time.sleep(0.5) + assert getter() >= 4, f"Expected value >= 4, got {getter()}" + + getter.dispose() + time.sleep(0.3) # Wait for background interval timer threads to finish + assert source.is_disposed(), "Observable should be disposed" + + +def test_getter_streaming_nonblocking_timeout() -> None: + source = dispose_spy(rx.interval(0.2).pipe(ops.take(50))) + getter = getter_streaming(source, timeout=0.1, nonblocking=True) + with pytest.raises(Exception): + getter() + + assert not source.is_disposed(), "is not disposed, this is a job of the caller" + + # Clean up the subscription to avoid thread leak + getter.dispose() + time.sleep(0.3) # Wait for background threads to finish + assert source.is_disposed(), "Observable should be disposed after cleanup" + + +def test_getter_ondemand() -> None: + # Create a controlled scheduler to avoid thread leaks from rx.interval + test_scheduler = ThreadPoolScheduler(max_workers=4) + try: + source = dispose_spy(rx.interval(0.1, scheduler=test_scheduler).pipe(ops.take(50))) + getter = getter_ondemand(source) + assert source.is_disposed(), "Observable should be disposed" + result = min_time(getter, 0.05) + assert result == 0, f"Expected to get the first value of 0, got {result}" + # Wait for background threads to clean up + time.sleep(0.3) + assert source.is_disposed(), "Observable should be disposed" + result2 = getter() + assert result2 == 0, f"Expected to get the first value of 0, got {result2}" + assert source.is_disposed(), "Observable should be disposed" + # Wait for threads to finish + time.sleep(0.3) + finally: + # Explicitly shutdown the scheduler to clean up threads + test_scheduler.executor.shutdown(wait=True) + + +def test_getter_ondemand_timeout() -> None: + source = dispose_spy(rx.interval(0.2).pipe(ops.take(50))) + getter = getter_ondemand(source, timeout=0.1) + with pytest.raises(Exception): + getter() + assert source.is_disposed(), "Observable should be disposed" + # Wait for background interval timer threads to finish + time.sleep(0.3) + + +def test_callback_to_observable() -> None: + # Test converting a callback-based API to an Observable + received = [] + callback = None + + # Mock start function that captures the callback + def start_fn(cb) -> str: + nonlocal callback + callback = cb + return "start_result" + + # Mock stop function + stop_called = False + + def stop_fn(cb) -> None: + nonlocal stop_called + stop_called = True + + # Create observable from callback + observable = callback_to_observable(start_fn, stop_fn) + + # Subscribe to the observable + subscription = observable.subscribe(lambda x: received.append(x)) + + # Verify start was called and we have access to the callback + assert callback is not None + + # Simulate callback being triggered with different messages + callback("message1") + callback(42) + callback({"key": "value"}) + + # Check that all messages were received + assert received == ["message1", 42, {"key": "value"}] + + # Dispose subscription and check that stop was called + subscription.dispose() + assert stop_called, "Stop function should be called on dispose" diff --git a/dimos/utils/test_testing.py b/dimos/utils/test_testing.py new file mode 100644 index 0000000000..0aee51d133 --- /dev/null +++ b/dimos/utils/test_testing.py @@ -0,0 +1,279 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + +from reactivex import operators as ops + +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.odometry import Odometry +from dimos.utils import testing +from dimos.utils.data import get_data + + +def test_sensor_replay() -> None: + counter = 0 + for message in testing.SensorReplay(name="office_lidar").iterate(): + counter += 1 + assert isinstance(message, dict) + assert counter == 500 + + +def test_sensor_replay_cast() -> None: + counter = 0 + for message in testing.SensorReplay( + name="office_lidar", autocast=LidarMessage.from_msg + ).iterate(): + counter += 1 + assert isinstance(message, LidarMessage) + assert counter == 500 + + +def test_timed_sensor_replay() -> None: + get_data("unitree_office_walk") + odom_store = testing.TimedSensorReplay("unitree_office_walk/odom", autocast=Odometry.from_msg) + + itermsgs = [] + for msg in odom_store.iterate(): + itermsgs.append(msg) + if len(itermsgs) > 9: + break + + assert len(itermsgs) == 10 + + print("\n") + + timed_msgs = [] + + for msg in odom_store.stream().pipe(ops.take(10), ops.to_list()).run(): + timed_msgs.append(msg) + + assert len(timed_msgs) == 10 + + for i in range(10): + print(itermsgs[i], timed_msgs[i]) + assert itermsgs[i] == timed_msgs[i] + + +def test_iterate_ts_no_seek() -> None: + """Test iterate_ts without seek (start_timestamp=None)""" + odom_store = testing.TimedSensorReplay("unitree_office_walk/odom", autocast=Odometry.from_msg) + + # Test without seek + ts_msgs = [] + for ts, msg in odom_store.iterate_ts(): + ts_msgs.append((ts, msg)) + if len(ts_msgs) >= 5: + break + + assert len(ts_msgs) == 5 + # Check that we get tuples of (timestamp, data) + for ts, msg in ts_msgs: + assert isinstance(ts, float) + assert isinstance(msg, Odometry) + + +def test_iterate_ts_with_from_timestamp() -> None: + """Test iterate_ts with from_timestamp (absolute timestamp)""" + odom_store = testing.TimedSensorReplay("unitree_office_walk/odom", autocast=Odometry.from_msg) + + # First get all messages to find a good seek point + all_msgs = [] + for ts, msg in odom_store.iterate_ts(): + all_msgs.append((ts, msg)) + if len(all_msgs) >= 10: + break + + # Seek to timestamp of 5th message + seek_timestamp = all_msgs[4][0] + + # Test with from_timestamp + seeked_msgs = [] + for ts, msg in odom_store.iterate_ts(from_timestamp=seek_timestamp): + seeked_msgs.append((ts, msg)) + if len(seeked_msgs) >= 5: + break + + assert len(seeked_msgs) == 5 + # First message should be at or after seek timestamp + assert seeked_msgs[0][0] >= seek_timestamp + # Should match the data from position 5 onward + assert seeked_msgs[0][1] == all_msgs[4][1] + + +def test_iterate_ts_with_relative_seek() -> None: + """Test iterate_ts with seek (relative seconds after first timestamp)""" + odom_store = testing.TimedSensorReplay("unitree_office_walk/odom", autocast=Odometry.from_msg) + + # Get first few messages to understand timing + all_msgs = [] + for ts, msg in odom_store.iterate_ts(): + all_msgs.append((ts, msg)) + if len(all_msgs) >= 10: + break + + # Calculate relative seek time (e.g., 0.5 seconds after start) + first_ts = all_msgs[0][0] + seek_seconds = 0.5 + expected_start_ts = first_ts + seek_seconds + + # Test with relative seek + seeked_msgs = [] + for ts, msg in odom_store.iterate_ts(seek=seek_seconds): + seeked_msgs.append((ts, msg)) + if len(seeked_msgs) >= 5: + break + + # First message should be at or after expected timestamp + assert seeked_msgs[0][0] >= expected_start_ts + # Make sure we're actually skipping some messages + assert seeked_msgs[0][0] > first_ts + + +def test_stream_with_seek() -> None: + """Test stream method with seek parameters""" + odom_store = testing.TimedSensorReplay("unitree_office_walk/odom", autocast=Odometry.from_msg) + + # Test stream with relative seek + msgs_with_seek = [] + for msg in odom_store.stream(seek=0.2).pipe(ops.take(5), ops.to_list()).run(): + msgs_with_seek.append(msg) + + assert len(msgs_with_seek) == 5 + + # Test stream with from_timestamp + # First get a reference timestamp + first_msgs = [] + for msg in odom_store.stream().pipe(ops.take(3), ops.to_list()).run(): + first_msgs.append(msg) + + # Now test from_timestamp (would need actual timestamps from iterate_ts to properly test) + # This is a basic test to ensure the parameter is accepted + msgs_with_timestamp = [] + for msg in ( + odom_store.stream(from_timestamp=1000000000.0).pipe(ops.take(3), ops.to_list()).run() + ): + msgs_with_timestamp.append(msg) + + +def test_duration_with_loop() -> None: + """Test duration parameter with looping in TimedSensorReplay""" + odom_store = testing.TimedSensorReplay("unitree_office_walk/odom", autocast=Odometry.from_msg) + + # Collect timestamps from a small duration window + collected_ts = [] + duration = 0.3 # 300ms window + + # First pass: collect timestamps in the duration window + for ts, _msg in odom_store.iterate_ts(duration=duration): + collected_ts.append(ts) + if len(collected_ts) >= 100: # Safety limit + break + + # Should have some messages but not too many + assert len(collected_ts) > 0 + assert len(collected_ts) < 20 # Assuming ~30Hz data + + # Test looping with duration - should repeat the same window + loop_count = 0 + prev_ts = None + + for ts, _msg in odom_store.iterate_ts(duration=duration, loop=True): + if prev_ts is not None and ts < prev_ts: + # We've looped back to the beginning + loop_count += 1 + if loop_count >= 2: # Stop after 2 full loops + break + prev_ts = ts + + assert loop_count >= 2 # Verify we actually looped + + +def test_first_methods() -> None: + """Test first() and first_timestamp() methods""" + + # Test SensorReplay.first() + lidar_replay = testing.SensorReplay("office_lidar", autocast=LidarMessage.from_msg) + + print("first file", lidar_replay.files[0]) + # Verify the first file ends with 000.pickle using regex + assert re.search(r"000\.pickle$", str(lidar_replay.files[0])), ( + f"Expected first file to end with 000.pickle, got {lidar_replay.files[0]}" + ) + + first_msg = lidar_replay.first() + assert first_msg is not None + assert isinstance(first_msg, LidarMessage) + + # Verify it's the same type as first item from iterate() + first_from_iterate = next(lidar_replay.iterate()) + print("DONE") + assert type(first_msg) is type(first_from_iterate) + # Since LidarMessage.from_msg uses time.time(), timestamps will be slightly different + assert abs(first_msg.ts - first_from_iterate.ts) < 1.0 # Within 1 second tolerance + + # Test TimedSensorReplay.first_timestamp() + odom_store = testing.TimedSensorReplay("unitree_office_walk/odom", autocast=Odometry.from_msg) + first_ts = odom_store.first_timestamp() + assert first_ts is not None + assert isinstance(first_ts, float) + + # Verify it matches the timestamp from iterate_ts + ts_from_iterate, _ = next(odom_store.iterate_ts()) + assert first_ts == ts_from_iterate + + # Test that first() returns just the data + first_data = odom_store.first() + assert first_data is not None + assert isinstance(first_data, Odometry) + + +def test_find_closest() -> None: + """Test find_closest method in TimedSensorReplay""" + odom_store = testing.TimedSensorReplay("unitree_office_walk/odom", autocast=Odometry.from_msg) + + # Get some reference timestamps + timestamps = [] + for ts, _msg in odom_store.iterate_ts(): + timestamps.append(ts) + if len(timestamps) >= 10: + break + + # Test exact match + target_ts = timestamps[5] + result = odom_store.find_closest(target_ts) + assert result is not None + assert isinstance(result, Odometry) + + # Test between timestamps + mid_ts = (timestamps[3] + timestamps[4]) / 2 + result = odom_store.find_closest(mid_ts) + assert result is not None + + # Test with tolerance + far_future = timestamps[-1] + 100.0 + result = odom_store.find_closest(far_future, tolerance=1.0) + assert result is None # Too far away + + result = odom_store.find_closest(timestamps[0] - 0.001, tolerance=0.01) + assert result is not None # Within tolerance + + # Test find_closest_seek + result = odom_store.find_closest_seek(0.5) # 0.5 seconds from start + assert result is not None + assert isinstance(result, Odometry) + + # Test with negative seek (before start) + result = odom_store.find_closest_seek(-1.0) + assert result is not None # Should still return closest (first frame) diff --git a/dimos/utils/test_transform_utils.py b/dimos/utils/test_transform_utils.py new file mode 100644 index 0000000000..b404579598 --- /dev/null +++ b/dimos/utils/test_transform_utils.py @@ -0,0 +1,678 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest +from scipy.spatial.transform import Rotation as R + +from dimos.msgs.geometry_msgs import Pose, Quaternion, Transform, Vector3 +from dimos.utils import transform_utils + + +class TestNormalizeAngle: + def test_normalize_angle_zero(self) -> None: + assert transform_utils.normalize_angle(0) == 0 + + def test_normalize_angle_pi(self) -> None: + assert np.isclose(transform_utils.normalize_angle(np.pi), np.pi) + + def test_normalize_angle_negative_pi(self) -> None: + assert np.isclose(transform_utils.normalize_angle(-np.pi), -np.pi) + + def test_normalize_angle_two_pi(self) -> None: + # 2*pi should normalize to 0 + assert np.isclose(transform_utils.normalize_angle(2 * np.pi), 0, atol=1e-10) + + def test_normalize_angle_large_positive(self) -> None: + # Large positive angle should wrap to [-pi, pi] + angle = 5 * np.pi + normalized = transform_utils.normalize_angle(angle) + assert -np.pi <= normalized <= np.pi + assert np.isclose(normalized, np.pi) + + def test_normalize_angle_large_negative(self) -> None: + # Large negative angle should wrap to [-pi, pi] + angle = -5 * np.pi + normalized = transform_utils.normalize_angle(angle) + assert -np.pi <= normalized <= np.pi + # -5*pi = -pi (odd multiple of pi wraps to -pi) + assert np.isclose(normalized, -np.pi) or np.isclose(normalized, np.pi) + + +# Tests for distance_angle_to_goal_xy removed as function doesn't exist in the module + + +class TestPoseToMatrix: + def test_identity_pose(self) -> None: + pose = Pose(Vector3(0, 0, 0), Quaternion(0, 0, 0, 1)) + T = transform_utils.pose_to_matrix(pose) + assert np.allclose(T, np.eye(4)) + + def test_translation_only(self) -> None: + pose = Pose(Vector3(1, 2, 3), Quaternion(0, 0, 0, 1)) + T = transform_utils.pose_to_matrix(pose) + expected = np.eye(4) + expected[:3, 3] = [1, 2, 3] + assert np.allclose(T, expected) + + def test_rotation_only_90_degrees_z(self) -> None: + # 90 degree rotation around z-axis + quat = R.from_euler("z", np.pi / 2).as_quat() + pose = Pose(Vector3(0, 0, 0), Quaternion(quat[0], quat[1], quat[2], quat[3])) + T = transform_utils.pose_to_matrix(pose) + + # Check rotation part + expected_rot = R.from_euler("z", np.pi / 2).as_matrix() + assert np.allclose(T[:3, :3], expected_rot) + + # Check translation is zero + assert np.allclose(T[:3, 3], [0, 0, 0]) + + def test_translation_and_rotation(self) -> None: + quat = R.from_euler("xyz", [np.pi / 4, np.pi / 6, np.pi / 3]).as_quat() + pose = Pose(Vector3(5, -3, 2), Quaternion(quat[0], quat[1], quat[2], quat[3])) + T = transform_utils.pose_to_matrix(pose) + + # Check translation + assert np.allclose(T[:3, 3], [5, -3, 2]) + + # Check rotation + expected_rot = R.from_euler("xyz", [np.pi / 4, np.pi / 6, np.pi / 3]).as_matrix() + assert np.allclose(T[:3, :3], expected_rot) + + # Check bottom row + assert np.allclose(T[3, :], [0, 0, 0, 1]) + + def test_zero_norm_quaternion(self) -> None: + # Test handling of zero norm quaternion + pose = Pose(Vector3(1, 2, 3), Quaternion(0, 0, 0, 0)) + T = transform_utils.pose_to_matrix(pose) + + # Should use identity rotation + expected = np.eye(4) + expected[:3, 3] = [1, 2, 3] + assert np.allclose(T, expected) + + +class TestMatrixToPose: + def test_identity_matrix(self) -> None: + T = np.eye(4) + pose = transform_utils.matrix_to_pose(T) + assert pose.position.x == 0 + assert pose.position.y == 0 + assert pose.position.z == 0 + assert np.isclose(pose.orientation.w, 1) + assert np.isclose(pose.orientation.x, 0) + assert np.isclose(pose.orientation.y, 0) + assert np.isclose(pose.orientation.z, 0) + + def test_translation_only(self) -> None: + T = np.eye(4) + T[:3, 3] = [1, 2, 3] + pose = transform_utils.matrix_to_pose(T) + assert pose.position.x == 1 + assert pose.position.y == 2 + assert pose.position.z == 3 + assert np.isclose(pose.orientation.w, 1) + + def test_rotation_only(self) -> None: + T = np.eye(4) + T[:3, :3] = R.from_euler("z", np.pi / 2).as_matrix() + pose = transform_utils.matrix_to_pose(T) + + # Check position is zero + assert pose.position.x == 0 + assert pose.position.y == 0 + assert pose.position.z == 0 + + # Check rotation + quat = [pose.orientation.x, pose.orientation.y, pose.orientation.z, pose.orientation.w] + recovered_rot = R.from_quat(quat).as_matrix() + assert np.allclose(recovered_rot, T[:3, :3]) + + def test_round_trip_conversion(self) -> None: + # Test that pose -> matrix -> pose gives same result + # Use a properly normalized quaternion + quat = R.from_euler("xyz", [0.1, 0.2, 0.3]).as_quat() + original_pose = Pose( + Vector3(1.5, -2.3, 0.7), Quaternion(quat[0], quat[1], quat[2], quat[3]) + ) + T = transform_utils.pose_to_matrix(original_pose) + recovered_pose = transform_utils.matrix_to_pose(T) + + assert np.isclose(recovered_pose.position.x, original_pose.position.x) + assert np.isclose(recovered_pose.position.y, original_pose.position.y) + assert np.isclose(recovered_pose.position.z, original_pose.position.z) + assert np.isclose(recovered_pose.orientation.x, original_pose.orientation.x, atol=1e-6) + assert np.isclose(recovered_pose.orientation.y, original_pose.orientation.y, atol=1e-6) + assert np.isclose(recovered_pose.orientation.z, original_pose.orientation.z, atol=1e-6) + assert np.isclose(recovered_pose.orientation.w, original_pose.orientation.w, atol=1e-6) + + +class TestApplyTransform: + def test_identity_transform(self) -> None: + pose = Pose(Vector3(1, 2, 3), Quaternion(0, 0, 0, 1)) + T_identity = np.eye(4) + result = transform_utils.apply_transform(pose, T_identity) + + assert np.isclose(result.position.x, pose.position.x) + assert np.isclose(result.position.y, pose.position.y) + assert np.isclose(result.position.z, pose.position.z) + + def test_translation_transform(self) -> None: + pose = Pose(Vector3(1, 0, 0), Quaternion(0, 0, 0, 1)) + T = np.eye(4) + T[:3, 3] = [2, 3, 4] + result = transform_utils.apply_transform(pose, T) + + assert np.isclose(result.position.x, 3) # 2 + 1 + assert np.isclose(result.position.y, 3) # 3 + 0 + assert np.isclose(result.position.z, 4) # 4 + 0 + + def test_rotation_transform(self) -> None: + pose = Pose(Vector3(1, 0, 0), Quaternion(0, 0, 0, 1)) + T = np.eye(4) + T[:3, :3] = R.from_euler("z", np.pi / 2).as_matrix() # 90 degree rotation + result = transform_utils.apply_transform(pose, T) + + # After 90 degree rotation around z, point (1,0,0) becomes (0,1,0) + assert np.isclose(result.position.x, 0, atol=1e-10) + assert np.isclose(result.position.y, 1) + assert np.isclose(result.position.z, 0) + + def test_transform_with_transform_object(self) -> None: + pose = Pose(Vector3(1, 0, 0), Quaternion(0, 0, 0, 1)) + pose.frame_id = "base" + + transform = Transform() + transform.frame_id = "world" + transform.child_frame_id = "base" + transform.translation = Vector3(2, 3, 4) + transform.rotation = Quaternion(0, 0, 0, 1) + + result = transform_utils.apply_transform(pose, transform) + assert np.isclose(result.position.x, 3) + assert np.isclose(result.position.y, 3) + assert np.isclose(result.position.z, 4) + + def test_transform_frame_mismatch_raises(self) -> None: + pose = Pose(Vector3(1, 0, 0), Quaternion(0, 0, 0, 1)) + pose.frame_id = "base" + + transform = Transform() + transform.frame_id = "world" + transform.child_frame_id = "different_frame" + transform.translation = Vector3(2, 3, 4) + transform.rotation = Quaternion(0, 0, 0, 1) + + with pytest.raises(ValueError, match="does not match"): + transform_utils.apply_transform(pose, transform) + + +class TestOpticalToRobotFrame: + def test_identity_at_origin(self) -> None: + pose = Pose(Vector3(0, 0, 0), Quaternion(0, 0, 0, 1)) + result = transform_utils.optical_to_robot_frame(pose) + assert result.position.x == 0 + assert result.position.y == 0 + assert result.position.z == 0 + + def test_position_transformation(self) -> None: + # Optical: X=right(1), Y=down(0), Z=forward(0) + pose = Pose(Vector3(1, 0, 0), Quaternion(0, 0, 0, 1)) + result = transform_utils.optical_to_robot_frame(pose) + + # Robot: X=forward(0), Y=left(-1), Z=up(0) + assert np.isclose(result.position.x, 0) # Forward = Camera Z + assert np.isclose(result.position.y, -1) # Left = -Camera X + assert np.isclose(result.position.z, 0) # Up = -Camera Y + + def test_forward_position(self) -> None: + # Optical: X=right(0), Y=down(0), Z=forward(2) + pose = Pose(Vector3(0, 0, 2), Quaternion(0, 0, 0, 1)) + result = transform_utils.optical_to_robot_frame(pose) + + # Robot: X=forward(2), Y=left(0), Z=up(0) + assert np.isclose(result.position.x, 2) + assert np.isclose(result.position.y, 0) + assert np.isclose(result.position.z, 0) + + def test_down_position(self) -> None: + # Optical: X=right(0), Y=down(3), Z=forward(0) + pose = Pose(Vector3(0, 3, 0), Quaternion(0, 0, 0, 1)) + result = transform_utils.optical_to_robot_frame(pose) + + # Robot: X=forward(0), Y=left(0), Z=up(-3) + assert np.isclose(result.position.x, 0) + assert np.isclose(result.position.y, 0) + assert np.isclose(result.position.z, -3) + + def test_round_trip_optical_robot(self) -> None: + original_pose = Pose(Vector3(1, 2, 3), Quaternion(0.1, 0.2, 0.3, 0.9165151389911680)) + robot_pose = transform_utils.optical_to_robot_frame(original_pose) + recovered_pose = transform_utils.robot_to_optical_frame(robot_pose) + + assert np.isclose(recovered_pose.position.x, original_pose.position.x, atol=1e-10) + assert np.isclose(recovered_pose.position.y, original_pose.position.y, atol=1e-10) + assert np.isclose(recovered_pose.position.z, original_pose.position.z, atol=1e-10) + + +class TestRobotToOpticalFrame: + def test_position_transformation(self) -> None: + # Robot: X=forward(1), Y=left(0), Z=up(0) + pose = Pose(Vector3(1, 0, 0), Quaternion(0, 0, 0, 1)) + result = transform_utils.robot_to_optical_frame(pose) + + # Optical: X=right(0), Y=down(0), Z=forward(1) + assert np.isclose(result.position.x, 0) + assert np.isclose(result.position.y, 0) + assert np.isclose(result.position.z, 1) + + def test_left_position(self) -> None: + # Robot: X=forward(0), Y=left(2), Z=up(0) + pose = Pose(Vector3(0, 2, 0), Quaternion(0, 0, 0, 1)) + result = transform_utils.robot_to_optical_frame(pose) + + # Optical: X=right(-2), Y=down(0), Z=forward(0) + assert np.isclose(result.position.x, -2) + assert np.isclose(result.position.y, 0) + assert np.isclose(result.position.z, 0) + + def test_up_position(self) -> None: + # Robot: X=forward(0), Y=left(0), Z=up(3) + pose = Pose(Vector3(0, 0, 3), Quaternion(0, 0, 0, 1)) + result = transform_utils.robot_to_optical_frame(pose) + + # Optical: X=right(0), Y=down(-3), Z=forward(0) + assert np.isclose(result.position.x, 0) + assert np.isclose(result.position.y, -3) + assert np.isclose(result.position.z, 0) + + +class TestYawTowardsPoint: + def test_yaw_from_origin(self) -> None: + # Point at (1, 0) from origin should have yaw = 0 + position = Vector3(1, 0, 0) + yaw = transform_utils.yaw_towards_point(position) + assert np.isclose(yaw, 0) + + def test_yaw_ninety_degrees(self) -> None: + # Point at (0, 1) from origin should have yaw = pi/2 + position = Vector3(0, 1, 0) + yaw = transform_utils.yaw_towards_point(position) + assert np.isclose(yaw, np.pi / 2) + + def test_yaw_negative_ninety_degrees(self) -> None: + # Point at (0, -1) from origin should have yaw = -pi/2 + position = Vector3(0, -1, 0) + yaw = transform_utils.yaw_towards_point(position) + assert np.isclose(yaw, -np.pi / 2) + + def test_yaw_forty_five_degrees(self) -> None: + # Point at (1, 1) from origin should have yaw = pi/4 + position = Vector3(1, 1, 0) + yaw = transform_utils.yaw_towards_point(position) + assert np.isclose(yaw, np.pi / 4) + + def test_yaw_with_custom_target(self) -> None: + # Point at (3, 2) from target (1, 1) + position = Vector3(3, 2, 0) + target = Vector3(1, 1, 0) + yaw = transform_utils.yaw_towards_point(position, target) + # Direction is (2, 1), so yaw = atan2(1, 2) + expected = np.arctan2(1, 2) + assert np.isclose(yaw, expected) + + +# Tests for transform_robot_to_map removed as function doesn't exist in the module + + +class TestCreateTransformFrom6DOF: + def test_identity_transform(self) -> None: + trans = Vector3(0, 0, 0) + euler = Vector3(0, 0, 0) + T = transform_utils.create_transform_from_6dof(trans, euler) + assert np.allclose(T, np.eye(4)) + + def test_translation_only(self) -> None: + trans = Vector3(1, 2, 3) + euler = Vector3(0, 0, 0) + T = transform_utils.create_transform_from_6dof(trans, euler) + + expected = np.eye(4) + expected[:3, 3] = [1, 2, 3] + assert np.allclose(T, expected) + + def test_rotation_only(self) -> None: + trans = Vector3(0, 0, 0) + euler = Vector3(np.pi / 4, np.pi / 6, np.pi / 3) + T = transform_utils.create_transform_from_6dof(trans, euler) + + expected_rot = R.from_euler("xyz", [np.pi / 4, np.pi / 6, np.pi / 3]).as_matrix() + assert np.allclose(T[:3, :3], expected_rot) + assert np.allclose(T[:3, 3], [0, 0, 0]) + assert np.allclose(T[3, :], [0, 0, 0, 1]) + + def test_translation_and_rotation(self) -> None: + trans = Vector3(5, -3, 2) + euler = Vector3(0.1, 0.2, 0.3) + T = transform_utils.create_transform_from_6dof(trans, euler) + + expected_rot = R.from_euler("xyz", [0.1, 0.2, 0.3]).as_matrix() + assert np.allclose(T[:3, :3], expected_rot) + assert np.allclose(T[:3, 3], [5, -3, 2]) + + def test_small_angles_threshold(self) -> None: + trans = Vector3(1, 2, 3) + euler = Vector3(1e-7, 1e-8, 1e-9) # Very small angles + T = transform_utils.create_transform_from_6dof(trans, euler) + + # Should be effectively identity rotation + expected = np.eye(4) + expected[:3, 3] = [1, 2, 3] + assert np.allclose(T, expected, atol=1e-6) + + +class TestInvertTransform: + def test_identity_inverse(self) -> None: + T = np.eye(4) + T_inv = transform_utils.invert_transform(T) + assert np.allclose(T_inv, np.eye(4)) + + def test_translation_inverse(self) -> None: + T = np.eye(4) + T[:3, 3] = [1, 2, 3] + T_inv = transform_utils.invert_transform(T) + + # Inverse should negate translation + expected = np.eye(4) + expected[:3, 3] = [-1, -2, -3] + assert np.allclose(T_inv, expected) + + def test_rotation_inverse(self) -> None: + T = np.eye(4) + T[:3, :3] = R.from_euler("z", np.pi / 2).as_matrix() + T_inv = transform_utils.invert_transform(T) + + # Inverse rotation is transpose + expected = np.eye(4) + expected[:3, :3] = R.from_euler("z", -np.pi / 2).as_matrix() + assert np.allclose(T_inv, expected) + + def test_general_transform_inverse(self) -> None: + T = np.eye(4) + T[:3, :3] = R.from_euler("xyz", [0.1, 0.2, 0.3]).as_matrix() + T[:3, 3] = [1, 2, 3] + + T_inv = transform_utils.invert_transform(T) + + # T @ T_inv should be identity + result = T @ T_inv + assert np.allclose(result, np.eye(4)) + + # T_inv @ T should also be identity + result2 = T_inv @ T + assert np.allclose(result2, np.eye(4)) + + +class TestComposeTransforms: + def test_no_transforms(self) -> None: + result = transform_utils.compose_transforms() + assert np.allclose(result, np.eye(4)) + + def test_single_transform(self) -> None: + T = np.eye(4) + T[:3, 3] = [1, 2, 3] + result = transform_utils.compose_transforms(T) + assert np.allclose(result, T) + + def test_two_translations(self) -> None: + T1 = np.eye(4) + T1[:3, 3] = [1, 0, 0] + + T2 = np.eye(4) + T2[:3, 3] = [0, 2, 0] + + result = transform_utils.compose_transforms(T1, T2) + + expected = np.eye(4) + expected[:3, 3] = [1, 2, 0] + assert np.allclose(result, expected) + + def test_three_transforms(self) -> None: + T1 = np.eye(4) + T1[:3, 3] = [1, 0, 0] + + T2 = np.eye(4) + T2[:3, :3] = R.from_euler("z", np.pi / 2).as_matrix() + + T3 = np.eye(4) + T3[:3, 3] = [1, 0, 0] + + result = transform_utils.compose_transforms(T1, T2, T3) + expected = T1 @ T2 @ T3 + assert np.allclose(result, expected) + + +class TestEulerToQuaternion: + def test_zero_euler(self) -> None: + euler = Vector3(0, 0, 0) + quat = transform_utils.euler_to_quaternion(euler) + assert np.isclose(quat.w, 1) + assert np.isclose(quat.x, 0) + assert np.isclose(quat.y, 0) + assert np.isclose(quat.z, 0) + + def test_roll_only(self) -> None: + euler = Vector3(np.pi / 2, 0, 0) + quat = transform_utils.euler_to_quaternion(euler) + + # Verify by converting back + recovered = R.from_quat([quat.x, quat.y, quat.z, quat.w]).as_euler("xyz") + assert np.isclose(recovered[0], np.pi / 2) + assert np.isclose(recovered[1], 0) + assert np.isclose(recovered[2], 0) + + def test_pitch_only(self) -> None: + euler = Vector3(0, np.pi / 3, 0) + quat = transform_utils.euler_to_quaternion(euler) + + recovered = R.from_quat([quat.x, quat.y, quat.z, quat.w]).as_euler("xyz") + assert np.isclose(recovered[0], 0) + assert np.isclose(recovered[1], np.pi / 3) + assert np.isclose(recovered[2], 0) + + def test_yaw_only(self) -> None: + euler = Vector3(0, 0, np.pi / 4) + quat = transform_utils.euler_to_quaternion(euler) + + recovered = R.from_quat([quat.x, quat.y, quat.z, quat.w]).as_euler("xyz") + assert np.isclose(recovered[0], 0) + assert np.isclose(recovered[1], 0) + assert np.isclose(recovered[2], np.pi / 4) + + def test_degrees_mode(self) -> None: + euler = Vector3(45, 30, 60) # degrees + quat = transform_utils.euler_to_quaternion(euler, degrees=True) + + recovered = R.from_quat([quat.x, quat.y, quat.z, quat.w]).as_euler("xyz", degrees=True) + assert np.isclose(recovered[0], 45) + assert np.isclose(recovered[1], 30) + assert np.isclose(recovered[2], 60) + + +class TestQuaternionToEuler: + def test_identity_quaternion(self) -> None: + quat = Quaternion(0, 0, 0, 1) + euler = transform_utils.quaternion_to_euler(quat) + assert np.isclose(euler.x, 0) + assert np.isclose(euler.y, 0) + assert np.isclose(euler.z, 0) + + def test_90_degree_yaw(self) -> None: + # Create quaternion for 90 degree yaw rotation + r = R.from_euler("z", np.pi / 2) + q = r.as_quat() + quat = Quaternion(q[0], q[1], q[2], q[3]) + + euler = transform_utils.quaternion_to_euler(quat) + assert np.isclose(euler.x, 0) + assert np.isclose(euler.y, 0) + assert np.isclose(euler.z, np.pi / 2) + + def test_round_trip_euler_quaternion(self) -> None: + original_euler = Vector3(0.3, 0.5, 0.7) + quat = transform_utils.euler_to_quaternion(original_euler) + recovered_euler = transform_utils.quaternion_to_euler(quat) + + assert np.isclose(recovered_euler.x, original_euler.x, atol=1e-10) + assert np.isclose(recovered_euler.y, original_euler.y, atol=1e-10) + assert np.isclose(recovered_euler.z, original_euler.z, atol=1e-10) + + def test_degrees_mode(self) -> None: + # Create quaternion for 45 degree yaw rotation + r = R.from_euler("z", 45, degrees=True) + q = r.as_quat() + quat = Quaternion(q[0], q[1], q[2], q[3]) + + euler = transform_utils.quaternion_to_euler(quat, degrees=True) + assert np.isclose(euler.x, 0) + assert np.isclose(euler.y, 0) + assert np.isclose(euler.z, 45) + + def test_angle_normalization(self) -> None: + # Test that angles are normalized to [-pi, pi] + r = R.from_euler("xyz", [3 * np.pi, -3 * np.pi, 2 * np.pi]) + q = r.as_quat() + quat = Quaternion(q[0], q[1], q[2], q[3]) + + euler = transform_utils.quaternion_to_euler(quat) + assert -np.pi <= euler.x <= np.pi + assert -np.pi <= euler.y <= np.pi + assert -np.pi <= euler.z <= np.pi + + +class TestGetDistance: + def test_same_pose(self) -> None: + pose1 = Pose(Vector3(1, 2, 3), Quaternion(0, 0, 0, 1)) + pose2 = Pose(Vector3(1, 2, 3), Quaternion(0.1, 0.2, 0.3, 0.9)) + distance = transform_utils.get_distance(pose1, pose2) + assert np.isclose(distance, 0) + + def test_vector_distance(self) -> None: + pose1 = Vector3(1, 2, 3) + pose2 = Vector3(4, 5, 6) + distance = transform_utils.get_distance(pose1, pose2) + assert np.isclose(distance, np.sqrt(3**2 + 3**2 + 3**2)) + + def test_distance_x_axis(self) -> None: + pose1 = Pose(Vector3(0, 0, 0), Quaternion(0, 0, 0, 1)) + pose2 = Pose(Vector3(5, 0, 0), Quaternion(0, 0, 0, 1)) + distance = transform_utils.get_distance(pose1, pose2) + assert np.isclose(distance, 5) + + def test_distance_y_axis(self) -> None: + pose1 = Pose(Vector3(0, 0, 0), Quaternion(0, 0, 0, 1)) + pose2 = Pose(Vector3(0, 3, 0), Quaternion(0, 0, 0, 1)) + distance = transform_utils.get_distance(pose1, pose2) + assert np.isclose(distance, 3) + + def test_distance_z_axis(self) -> None: + pose1 = Pose(Vector3(0, 0, 0), Quaternion(0, 0, 0, 1)) + pose2 = Pose(Vector3(0, 0, 4), Quaternion(0, 0, 0, 1)) + distance = transform_utils.get_distance(pose1, pose2) + assert np.isclose(distance, 4) + + def test_3d_distance(self) -> None: + pose1 = Pose(Vector3(0, 0, 0), Quaternion(0, 0, 0, 1)) + pose2 = Pose(Vector3(3, 4, 0), Quaternion(0, 0, 0, 1)) + distance = transform_utils.get_distance(pose1, pose2) + assert np.isclose(distance, 5) # 3-4-5 triangle + + def test_negative_coordinates(self) -> None: + pose1 = Pose(Vector3(-1, -2, -3), Quaternion(0, 0, 0, 1)) + pose2 = Pose(Vector3(1, 2, 3), Quaternion(0, 0, 0, 1)) + distance = transform_utils.get_distance(pose1, pose2) + expected = np.sqrt(4 + 16 + 36) # sqrt(56) + assert np.isclose(distance, expected) + + +class TestRetractDistance: + def test_retract_along_negative_z(self) -> None: + # Default case: gripper approaches along -z axis + # Positive distance moves away from the surface (opposite to approach direction) + target_pose = Pose(Vector3(0, 0, 1), Quaternion(0, 0, 0, 1)) + retracted = transform_utils.offset_distance(target_pose, 0.5) + + # Moving along -z approach vector with positive distance = retracting upward + # Since approach is -z and we retract (positive distance), we move in +z + assert np.isclose(retracted.position.x, 0) + assert np.isclose(retracted.position.y, 0) + assert np.isclose(retracted.position.z, 0.5) # 1 + 0.5 * (-1) = 0.5 + + # Orientation should remain unchanged + assert retracted.orientation.x == target_pose.orientation.x + assert retracted.orientation.y == target_pose.orientation.y + assert retracted.orientation.z == target_pose.orientation.z + assert retracted.orientation.w == target_pose.orientation.w + + def test_retract_with_rotation(self) -> None: + # Test with a rotated pose (90 degrees around x-axis) + r = R.from_euler("x", np.pi / 2) + q = r.as_quat() + target_pose = Pose(Vector3(0, 0, 1), Quaternion(q[0], q[1], q[2], q[3])) + + retracted = transform_utils.offset_distance(target_pose, 0.5) + + # After 90 degree rotation around x, -z becomes +y + assert np.isclose(retracted.position.x, 0) + assert np.isclose(retracted.position.y, 0.5) # Move along +y + assert np.isclose(retracted.position.z, 1) + + def test_retract_negative_distance(self) -> None: + # Negative distance should move forward (toward the approach direction) + target_pose = Pose(Vector3(0, 0, 1), Quaternion(0, 0, 0, 1)) + retracted = transform_utils.offset_distance(target_pose, -0.3) + + # Moving along -z approach vector with negative distance = moving downward + assert np.isclose(retracted.position.x, 0) + assert np.isclose(retracted.position.y, 0) + assert np.isclose(retracted.position.z, 1.3) # 1 + (-0.3) * (-1) = 1.3 + + def test_retract_arbitrary_pose(self) -> None: + # Test with arbitrary position and rotation + r = R.from_euler("xyz", [0.1, 0.2, 0.3]) + q = r.as_quat() + target_pose = Pose(Vector3(5, 3, 2), Quaternion(q[0], q[1], q[2], q[3])) + + distance = 1.0 + retracted = transform_utils.offset_distance(target_pose, distance) + + # Verify the distance between original and retracted is as expected + # (approximately, due to the approach vector direction) + T_target = transform_utils.pose_to_matrix(target_pose) + rotation_matrix = T_target[:3, :3] + approach_vector = rotation_matrix @ np.array([0, 0, -1]) + + expected_x = target_pose.position.x + distance * approach_vector[0] + expected_y = target_pose.position.y + distance * approach_vector[1] + expected_z = target_pose.position.z + distance * approach_vector[2] + + assert np.isclose(retracted.position.x, expected_x) + assert np.isclose(retracted.position.y, expected_y) + assert np.isclose(retracted.position.z, expected_z) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/dimos/utils/testing.py b/dimos/utils/testing.py new file mode 100644 index 0000000000..9b34436d4e --- /dev/null +++ b/dimos/utils/testing.py @@ -0,0 +1,375 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections.abc import Callable, Iterator +import functools +import glob +import os +from pathlib import Path +import pickle +import re +import time +from typing import Any, Generic, TypeVar + +from reactivex import ( + from_iterable, + interval, + operators as ops, +) +from reactivex.observable import Observable +from reactivex.scheduler import TimeoutScheduler + +from dimos.utils.data import _get_data_dir, get_data + +T = TypeVar("T") + + +class SensorReplay(Generic[T]): + """Generic sensor data replay utility. + + Args: + name: The name of the test dataset + autocast: Optional function that takes unpickled data and returns a processed result. + For example: lambda data: LidarMessage.from_msg(data) + """ + + def __init__(self, name: str, autocast: Callable[[Any], T] | None = None) -> None: + self.root_dir = get_data(name) + self.autocast = autocast + + def load(self, *names: int | str) -> T | Any | list[T] | list[Any]: + if len(names) == 1: + return self.load_one(names[0]) + return list(map(lambda name: self.load_one(name), names)) + + def load_one(self, name: int | str | Path) -> T | Any: + if isinstance(name, int): + full_path = self.root_dir / f"/{name:03d}.pickle" + elif isinstance(name, Path): + full_path = name + else: + full_path = self.root_dir / Path(f"{name}.pickle") + + with open(full_path, "rb") as f: + data = pickle.load(f) + if self.autocast: + return self.autocast(data) + return data + + def first(self) -> T | Any | None: + try: + return next(self.iterate()) + except StopIteration: + return None + + @functools.cached_property + def files(self) -> list[Path]: + def extract_number(filepath): # type: ignore[no-untyped-def] + """Extract last digits before .pickle extension""" + basename = os.path.basename(filepath) + match = re.search(r"(\d+)\.pickle$", basename) + return int(match.group(1)) if match else 0 + + return sorted( + glob.glob(os.path.join(self.root_dir, "*")), # type: ignore[arg-type] + key=extract_number, + ) + + def iterate(self, loop: bool = False) -> Iterator[T | Any]: + while True: + for file_path in self.files: + yield self.load_one(Path(file_path)) + if not loop: + break + + def stream(self, rate_hz: float | None = None, loop: bool = False) -> Observable[T | Any]: + if rate_hz is None: + return from_iterable(self.iterate(loop=loop)) + + sleep_time = 1.0 / rate_hz + + return from_iterable(self.iterate(loop=loop)).pipe( + ops.zip(interval(sleep_time)), + ops.map(lambda x: x[0] if isinstance(x, tuple) else x), + ) + + +class SensorStorage(Generic[T]): + """Generic sensor data storage utility + . + Creates a directory in the test data directory and stores pickled sensor data. + + Args: + name: The name of the storage directory + autocast: Optional function that takes data and returns a processed result before storage. + """ + + def __init__(self, name: str, autocast: Callable[[T], Any] | None = None) -> None: + self.name = name + self.autocast = autocast + self.cnt = 0 + + # Create storage directory in the data dir + self.root_dir = _get_data_dir() / name + + # Check if directory exists and is not empty + if self.root_dir.exists(): + existing_files = list(self.root_dir.glob("*.pickle")) + if existing_files: + raise RuntimeError( + f"Storage directory '{name}' already exists and contains {len(existing_files)} files. " + f"Please use a different name or clean the directory first." + ) + else: + # Create the directory + self.root_dir.mkdir(parents=True, exist_ok=True) + + def consume_stream(self, observable: Observable[T | Any]) -> None: + """Consume an observable stream of sensor data without saving.""" + return observable.subscribe(self.save_one) # type: ignore[arg-type, return-value] + + def save_stream(self, observable: Observable[T | Any]) -> Observable[int]: + """Save an observable stream of sensor data to pickle files.""" + return observable.pipe(ops.map(lambda frame: self.save_one(frame))) + + def save(self, *frames) -> int: # type: ignore[no-untyped-def] + """Save one or more frames to pickle files.""" + for frame in frames: + self.save_one(frame) + return self.cnt + + def save_one(self, frame) -> int: # type: ignore[no-untyped-def] + """Save a single frame to a pickle file.""" + file_name = f"{self.cnt:03d}.pickle" + full_path = self.root_dir / file_name + + if full_path.exists(): + raise RuntimeError(f"File {full_path} already exists") + + # Apply autocast if provided + data_to_save = frame + if self.autocast: + data_to_save = self.autocast(frame) + # Convert to raw message if frame has a raw_msg attribute + elif hasattr(frame, "raw_msg"): + data_to_save = frame.raw_msg + + with open(full_path, "wb") as f: + pickle.dump(data_to_save, f) + + self.cnt += 1 + return self.cnt + + +class TimedSensorStorage(SensorStorage[T]): + def save_one(self, frame: T) -> int: + return super().save_one((time.time(), frame)) + + +class TimedSensorReplay(SensorReplay[T]): + def load_one(self, name: int | str | Path) -> T | Any: + if isinstance(name, int): + full_path = self.root_dir / f"/{name:03d}.pickle" + elif isinstance(name, Path): + full_path = name + else: + full_path = self.root_dir / Path(f"{name}.pickle") + + with open(full_path, "rb") as f: + data = pickle.load(f) + if self.autocast: + return (data[0], self.autocast(data[1])) + return data + + def find_closest(self, timestamp: float, tolerance: float | None = None) -> T | Any | None: + """Find the frame closest to the given timestamp. + + Args: + timestamp: The target timestamp to search for + tolerance: Optional maximum time difference allowed + + Returns: + The data frame closest to the timestamp, or None if no match within tolerance + """ + closest_data = None + closest_diff = float("inf") + + # Check frames before and after the timestamp + for ts, data in self.iterate_ts(): + diff = abs(ts - timestamp) + + if diff < closest_diff: + closest_diff = diff + closest_data = data + elif diff > closest_diff: + # We're moving away from the target, can stop + break + + if tolerance is not None and closest_diff > tolerance: + return None + + return closest_data + + def find_closest_seek( + self, relative_seconds: float, tolerance: float | None = None + ) -> T | Any | None: + """Find the frame closest to a time relative to the start. + + Args: + relative_seconds: Seconds from the start of the dataset + tolerance: Optional maximum time difference allowed + + Returns: + The data frame closest to the relative timestamp, or None if no match within tolerance + """ + # Get the first timestamp + first_ts = self.first_timestamp() + if first_ts is None: + return None + + # Calculate absolute timestamp and use find_closest + target_timestamp = first_ts + relative_seconds + return self.find_closest(target_timestamp, tolerance) + + def first_timestamp(self) -> float | None: + """Get the timestamp of the first item in the dataset. + + Returns: + The first timestamp, or None if dataset is empty + """ + try: + ts, _ = next(self.iterate_ts()) + return ts + except StopIteration: + return None + + def iterate(self, loop: bool = False) -> Iterator[T | Any]: + return (x[1] for x in super().iterate(loop=loop)) # type: ignore[index] + + def iterate_ts( + self, + seek: float | None = None, + duration: float | None = None, + from_timestamp: float | None = None, + loop: bool = False, + ) -> Iterator[tuple[float, T] | Any]: + first_ts = None + if (seek is not None) or (duration is not None): + first_ts = self.first_timestamp() + if first_ts is None: + return + + if seek is not None: + from_timestamp = first_ts + seek # type: ignore[operator] + + end_timestamp = None + if duration is not None: + end_timestamp = (from_timestamp if from_timestamp else first_ts) + duration # type: ignore[operator] + + while True: + for ts, data in super().iterate(): # type: ignore[misc] + if from_timestamp is None or ts >= from_timestamp: + if end_timestamp is not None and ts >= end_timestamp: + break + yield (ts, data) + if not loop: + break + + def stream( # type: ignore[override] + self, + speed: float = 1.0, + seek: float | None = None, + duration: float | None = None, + from_timestamp: float | None = None, + loop: bool = False, + ) -> Observable[T | Any]: + def _subscribe(observer, scheduler=None): # type: ignore[no-untyped-def] + from reactivex.disposable import CompositeDisposable, Disposable + + scheduler = scheduler or TimeoutScheduler() + disp = CompositeDisposable() + is_disposed = False + + iterator = self.iterate_ts( + seek=seek, duration=duration, from_timestamp=from_timestamp, loop=loop + ) + + # Get first message + try: + first_ts, first_data = next(iterator) + except StopIteration: + observer.on_completed() + return Disposable() + + # Establish timing reference + start_local_time = time.time() + start_replay_time = first_ts + + # Emit first sample immediately + observer.on_next(first_data) + + # Pre-load next message + try: + next_message = next(iterator) + except StopIteration: + observer.on_completed() + return disp + + def schedule_emission(message) -> None: # type: ignore[no-untyped-def] + nonlocal next_message, is_disposed + + if is_disposed: + return + + ts, data = message + + # Pre-load the following message while we have time + try: + next_message = next(iterator) + except StopIteration: + next_message = None + + # Calculate absolute emission time + target_time = start_local_time + (ts - start_replay_time) / speed + delay = max(0.0, target_time - time.time()) + + def emit() -> None: + if is_disposed: + return + observer.on_next(data) + if next_message is not None: + schedule_emission(next_message) + else: + observer.on_completed() + # Dispose of the scheduler to clean up threads + if hasattr(scheduler, "dispose"): + scheduler.dispose() + + disp.add(scheduler.schedule_relative(delay, lambda sc, _: emit())) + + schedule_emission(next_message) + + # Create a custom disposable that properly cleans up + def dispose() -> None: + nonlocal is_disposed + is_disposed = True + disp.dispose() + # Ensure scheduler is disposed to clean up any threads + if hasattr(scheduler, "dispose"): + scheduler.dispose() + + return Disposable(dispose) + + from reactivex import create + + return create(_subscribe) diff --git a/dimos/utils/threadpool.py b/dimos/utils/threadpool.py new file mode 100644 index 0000000000..a2adc90725 --- /dev/null +++ b/dimos/utils/threadpool.py @@ -0,0 +1,79 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Thread pool functionality for parallel execution in the Dimos framework. + +This module provides a shared ThreadPoolExecutor exposed through a +ReactiveX scheduler, ensuring consistent thread management across the application. +""" + +import multiprocessing +import os + +from reactivex.scheduler import ThreadPoolScheduler + +from .logging_config import setup_logger + +logger = setup_logger() + + +def get_max_workers() -> int: + """Determine the number of workers for the thread pool. + + Returns: + int: The number of workers, configurable via the DIMOS_MAX_WORKERS + environment variable, defaulting to 4 times the CPU count. + """ + env_value = os.getenv("DIMOS_MAX_WORKERS", "") + return int(env_value) if env_value.strip() else multiprocessing.cpu_count() + + +# Create a ThreadPoolScheduler with a configurable number of workers. +try: + max_workers = get_max_workers() + scheduler = ThreadPoolScheduler(max_workers=max_workers) + # logger.info(f"Using {max_workers} workers") +except Exception as e: + logger.error(f"Failed to initialize ThreadPoolScheduler: {e}") + raise + + +def get_scheduler() -> ThreadPoolScheduler: + """Return the global ThreadPoolScheduler instance. + + The thread pool is configured with a fixed number of workers and is shared + across the application to manage system resources efficiently. + + Returns: + ThreadPoolScheduler: The global scheduler instance for scheduling + operations on the thread pool. + """ + return scheduler + + +def make_single_thread_scheduler() -> ThreadPoolScheduler: + """Create a new ThreadPoolScheduler with a single worker. + + This provides a dedicated scheduler for tasks that should run serially + on their own thread rather than using the shared thread pool. + + Returns: + ThreadPoolScheduler: A scheduler instance with a single worker thread. + """ + return ThreadPoolScheduler(max_workers=1) + + +# Example usage: +# scheduler = get_scheduler() +# # Use the scheduler for parallel tasks diff --git a/dimos/utils/transform_utils.py b/dimos/utils/transform_utils.py new file mode 100644 index 0000000000..f3577dbaa6 --- /dev/null +++ b/dimos/utils/transform_utils.py @@ -0,0 +1,386 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np +from scipy.spatial.transform import Rotation as R + +from dimos.msgs.geometry_msgs import Pose, Quaternion, Transform, Vector3 + + +def normalize_angle(angle: float) -> float: + """Normalize angle to [-pi, pi] range""" + return np.arctan2(np.sin(angle), np.cos(angle)) # type: ignore[no-any-return] + + +def pose_to_matrix(pose: Pose) -> np.ndarray: # type: ignore[type-arg] + """ + Convert pose to 4x4 homogeneous transform matrix. + + Args: + pose: Pose object with position and orientation (quaternion) + + Returns: + 4x4 transformation matrix + """ + # Extract position + tx, ty, tz = pose.position.x, pose.position.y, pose.position.z + + # Create rotation matrix from quaternion using scipy + quat = [pose.orientation.x, pose.orientation.y, pose.orientation.z, pose.orientation.w] + + # Check for zero norm quaternion and use identity if invalid + quat_norm = np.linalg.norm(quat) + if quat_norm == 0.0: + # Use identity quaternion [0, 0, 0, 1] if zero norm detected + quat = [0.0, 0.0, 0.0, 1.0] + + rotation = R.from_quat(quat) + Rot = rotation.as_matrix() + + # Create 4x4 transform + T = np.eye(4) + T[:3, :3] = Rot + T[:3, 3] = [tx, ty, tz] + + return T + + +def matrix_to_pose(T: np.ndarray) -> Pose: # type: ignore[type-arg] + """ + Convert 4x4 transformation matrix to Pose object. + + Args: + T: 4x4 transformation matrix + + Returns: + Pose object with position and orientation (quaternion) + """ + # Extract position + pos = Vector3(T[0, 3], T[1, 3], T[2, 3]) + + # Extract rotation matrix and convert to quaternion + Rot = T[:3, :3] + rotation = R.from_matrix(Rot) + quat = rotation.as_quat() # Returns [x, y, z, w] + + orientation = Quaternion(quat[0], quat[1], quat[2], quat[3]) + + return Pose(pos, orientation) + + +def apply_transform(pose: Pose, transform: np.ndarray | Transform) -> Pose: # type: ignore[type-arg] + """ + Apply a transformation matrix to a pose. + + Args: + pose: Input pose + transform_matrix: 4x4 transformation matrix to apply + + Returns: + Transformed pose + """ + if isinstance(transform, Transform): + if transform.child_frame_id != pose.frame_id: + raise ValueError( + f"Transform frame_id {transform.frame_id} does not match pose frame_id {pose.frame_id}" + ) + transform = pose_to_matrix(transform.to_pose()) + + # Convert pose to matrix + T_pose = pose_to_matrix(pose) + + # Apply transform + T_result = transform @ T_pose + + # Convert back to pose + return matrix_to_pose(T_result) + + +def optical_to_robot_frame(pose: Pose) -> Pose: + """ + Convert pose from optical camera frame to robot frame convention. + + Optical Camera Frame (e.g., ZED): + - X: Right + - Y: Down + - Z: Forward (away from camera) + + Robot Frame (ROS/REP-103): + - X: Forward + - Y: Left + - Z: Up + + Args: + pose: Pose in optical camera frame + + Returns: + Pose in robot frame + """ + # Position transformation + robot_x = pose.position.z # Forward = Camera Z + robot_y = -pose.position.x # Left = -Camera X + robot_z = -pose.position.y # Up = -Camera Y + + # Rotation transformation using quaternions + # First convert quaternion to rotation matrix + quat_optical = [pose.orientation.x, pose.orientation.y, pose.orientation.z, pose.orientation.w] + R_optical = R.from_quat(quat_optical).as_matrix() + + # Coordinate frame transformation matrix from optical to robot + # X_robot = Z_optical, Y_robot = -X_optical, Z_robot = -Y_optical + T_frame = np.array( + [ + [0, 0, 1], # X_robot = Z_optical + [-1, 0, 0], # Y_robot = -X_optical + [0, -1, 0], # Z_robot = -Y_optical + ] + ) + + # Transform the rotation matrix + R_robot = T_frame @ R_optical @ T_frame.T + + # Convert back to quaternion + quat_robot = R.from_matrix(R_robot).as_quat() # [x, y, z, w] + + return Pose( + Vector3(robot_x, robot_y, robot_z), + Quaternion(quat_robot[0], quat_robot[1], quat_robot[2], quat_robot[3]), + ) + + +def robot_to_optical_frame(pose: Pose) -> Pose: + """ + Convert pose from robot frame to optical camera frame convention. + This is the inverse of optical_to_robot_frame. + + Args: + pose: Pose in robot frame + + Returns: + Pose in optical camera frame + """ + # Position transformation (inverse) + optical_x = -pose.position.y # Right = -Left + optical_y = -pose.position.z # Down = -Up + optical_z = pose.position.x # Forward = Forward + + # Rotation transformation using quaternions + quat_robot = [pose.orientation.x, pose.orientation.y, pose.orientation.z, pose.orientation.w] + R_robot = R.from_quat(quat_robot).as_matrix() + + # Coordinate frame transformation matrix from Robot to optical (inverse of optical to Robot) + # This is the transpose of the forward transformation + T_frame_inv = np.array( + [ + [0, -1, 0], # X_optical = -Y_robot + [0, 0, -1], # Y_optical = -Z_robot + [1, 0, 0], # Z_optical = X_robot + ] + ) + + # Transform the rotation matrix + R_optical = T_frame_inv @ R_robot @ T_frame_inv.T + + # Convert back to quaternion + quat_optical = R.from_matrix(R_optical).as_quat() # [x, y, z, w] + + return Pose( + Vector3(optical_x, optical_y, optical_z), + Quaternion(quat_optical[0], quat_optical[1], quat_optical[2], quat_optical[3]), + ) + + +def yaw_towards_point(position: Vector3, target_point: Vector3 = None) -> float: # type: ignore[assignment] + """ + Calculate yaw angle from target point to position (away from target). + This is commonly used for object orientation in grasping applications. + Assumes robot frame where X is forward and Y is left. + + Args: + position: Current position in robot frame + target_point: Reference point (default: origin) + + Returns: + Yaw angle in radians pointing from target_point to position + """ + if target_point is None: + target_point = Vector3(0.0, 0.0, 0.0) + direction_x = position.x - target_point.x + direction_y = position.y - target_point.y + return np.arctan2(direction_y, direction_x) # type: ignore[no-any-return] + + +def create_transform_from_6dof(translation: Vector3, euler_angles: Vector3) -> np.ndarray: # type: ignore[type-arg] + """ + Create a 4x4 transformation matrix from 6DOF parameters. + + Args: + translation: Translation vector [x, y, z] in meters + euler_angles: Euler angles [rx, ry, rz] in radians (XYZ convention) + + Returns: + 4x4 transformation matrix + """ + # Create transformation matrix + T = np.eye(4) + + # Set translation + T[0:3, 3] = [translation.x, translation.y, translation.z] + + # Set rotation using scipy + if np.linalg.norm([euler_angles.x, euler_angles.y, euler_angles.z]) > 1e-6: + rotation = R.from_euler("xyz", [euler_angles.x, euler_angles.y, euler_angles.z]) + T[0:3, 0:3] = rotation.as_matrix() + + return T + + +def invert_transform(T: np.ndarray) -> np.ndarray: # type: ignore[type-arg] + """ + Invert a 4x4 transformation matrix efficiently. + + Args: + T: 4x4 transformation matrix + + Returns: + Inverted 4x4 transformation matrix + """ + # For homogeneous transform matrices, we can use the special structure: + # [R t]^-1 = [R^T -R^T*t] + # [0 1] [0 1 ] + + Rot = T[:3, :3] + t = T[:3, 3] + + T_inv = np.eye(4) + T_inv[:3, :3] = Rot.T + T_inv[:3, 3] = -Rot.T @ t + + return T_inv + + +def compose_transforms(*transforms: np.ndarray) -> np.ndarray: # type: ignore[type-arg] + """ + Compose multiple transformation matrices. + + Args: + *transforms: Variable number of 4x4 transformation matrices + + Returns: + Composed 4x4 transformation matrix (T1 @ T2 @ ... @ Tn) + """ + result = np.eye(4) + for T in transforms: + result = result @ T + return result + + +def euler_to_quaternion(euler_angles: Vector3, degrees: bool = False) -> Quaternion: + """ + Convert euler angles to quaternion. + + Args: + euler_angles: Euler angles as Vector3 [roll, pitch, yaw] in radians (XYZ convention) + + Returns: + Quaternion object [x, y, z, w] + """ + rotation = R.from_euler( + "xyz", [euler_angles.x, euler_angles.y, euler_angles.z], degrees=degrees + ) + quat = rotation.as_quat() # Returns [x, y, z, w] + return Quaternion(quat[0], quat[1], quat[2], quat[3]) + + +def quaternion_to_euler(quaternion: Quaternion, degrees: bool = False) -> Vector3: + """ + Convert quaternion to euler angles. + + Args: + quaternion: Quaternion object [x, y, z, w] + + Returns: + Euler angles as Vector3 [roll, pitch, yaw] in radians (XYZ convention) + """ + quat = [quaternion.x, quaternion.y, quaternion.z, quaternion.w] + rotation = R.from_quat(quat) + euler = rotation.as_euler("xyz", degrees=degrees) # Returns [roll, pitch, yaw] + if not degrees: + return Vector3( + normalize_angle(euler[0]), normalize_angle(euler[1]), normalize_angle(euler[2]) + ) + else: + return Vector3(euler[0], euler[1], euler[2]) + + +def get_distance(pose1: Pose | Vector3, pose2: Pose | Vector3) -> float: + """ + Calculate Euclidean distance between two poses. + + Args: + pose1: First pose + pose2: Second pose + + Returns: + Euclidean distance between the two poses in meters + """ + if hasattr(pose1, "position"): + pose1 = pose1.position + if hasattr(pose2, "position"): + pose2 = pose2.position + + dx = pose1.x - pose2.x + dy = pose1.y - pose2.y + dz = pose1.z - pose2.z + + return np.linalg.norm(np.array([dx, dy, dz])) # type: ignore[return-value] + + +def offset_distance( + target_pose: Pose, distance: float, approach_vector: Vector3 = Vector3(0, 0, -1) +) -> Pose: + """ + Apply distance offset to target pose along its approach direction. + + This is commonly used in grasping to offset the gripper by a certain distance + along the approach vector before or after grasping. + + Args: + target_pose: Target pose (e.g., grasp pose) + distance: Distance to offset along the approach direction (meters) + + Returns: + Target pose offset by the specified distance along its approach direction + """ + # Convert pose to transformation matrix to extract rotation + T_target = pose_to_matrix(target_pose) + rotation_matrix = T_target[:3, :3] + + # Define the approach vector based on the target pose orientation + # Assuming the gripper approaches along its local -z axis (common for downward grasps) + # You can change this to [1, 0, 0] for x-axis or [0, 1, 0] for y-axis based on your gripper + approach_vector_local = np.array([approach_vector.x, approach_vector.y, approach_vector.z]) + + # Transform approach vector to world coordinates + approach_vector_world = rotation_matrix @ approach_vector_local + + # Apply offset along the approach direction + offset_position = Vector3( + target_pose.position.x + distance * approach_vector_world[0], + target_pose.position.y + distance * approach_vector_world[1], + target_pose.position.z + distance * approach_vector_world[2], + ) + + return Pose(position=offset_position, orientation=target_pose.orientation) diff --git a/dimos/web/README.md b/dimos/web/README.md new file mode 100644 index 0000000000..c7bcd5df20 --- /dev/null +++ b/dimos/web/README.md @@ -0,0 +1,126 @@ +# DimOS Robot Web Interface + +A streamlined interface for controlling and interacting with robots through DimOS. + +## Setup + +First, create an `.env` file in the root dimos directory with your configuration: + +```bash +# Example .env file +OPENAI_API_KEY=sk-your-openai-api-key +ROBOT_IP=192.168.x.x +CONN_TYPE=webrtc +WEBRTC_SERVER_HOST=0.0.0.0 +WEBRTC_SERVER_PORT=9991 +DISPLAY=:0 +``` + +## Unitree Go2 Example + +Running a full stack for Unitree Go2 requires three components: + +### 1. Start ROS2 Robot Driver + +```bash +# Source ROS environment +source /opt/ros/humble/setup.bash +source ~/your_ros_workspace/install/setup.bash + +# Launch robot driver +ros2 launch go2_robot_sdk robot.launch.py +``` + +### 2. Start DimOS Backend + +```bash +# In a new terminal, source your Python environment +source venv/bin/activate # Or your environment + +# Install requirements +pip install -r requirements.txt + +# Source ROS workspace (needed for robot communication) +source /opt/ros/humble/setup.bash +source ~/your_ros_workspace/install/setup.bash + +# Run the server with Robot() and Agent() initialization +python tests/test_unitree_agent_queries_fastapi.py +``` + +### 3. Start Frontend + +**Install yarn if not already installed** + +```bash +npm install -g yarn +``` + +**Then install dependencies and start the development server** + +```bash +# In a new terminal +cd dimos/web/dimos-interface + +# Install dependencies (first time only) +yarn install + +# Start development server +yarn dev +``` + +The frontend will be available at http://localhost:3000 + +## Using the Interface + +1. Access the web terminal at http://localhost:3000 +2. Type commands to control your robot: + - `unitree command ` - Send a command to the robot + - `unitree status` - Check connection status + - `unitree start_stream` - Start the video stream + - `unitree stop_stream` - Stop the video stream + +## Integrating DimOS with the DimOS-interface + +### Unitree Go2 Example + +```python +from dimos.agents.agent import OpenAIAgent +from dimos.robot.unitree.unitree_go2 import UnitreeGo2 +from dimos.robot.unitree.unitree_skills import MyUnitreeSkills +from dimos.web.robot_web_interface import RobotWebInterface + +robot_ip = os.getenv("ROBOT_IP") + +# Initialize robot +logger.info("Initializing Unitree Robot") +robot = UnitreeGo2(ip=robot_ip, + connection_method=connection_method, + output_dir=output_dir) + +# Set up video stream +logger.info("Starting video stream") +video_stream = robot.get_ros_video_stream() + +# Create FastAPI server with video stream +logger.info("Initializing FastAPI server") +streams = {"unitree_video": video_stream} +web_interface = RobotWebInterface(port=5555, **streams) + +# Initialize agent with robot skills +skills_instance = MyUnitreeSkills(robot=robot) + +agent = OpenAIAgent( + dev_name="UnitreeQueryPerceptionAgent", + input_query_stream=web_interface.query_stream, + output_dir=output_dir, + skills=skills_instance, +) + +web_interface.run() +``` + +## Architecture + +- **Backend**: FastAPI server runs on port 5555 +- **Frontend**: Web application runs on port 3000 diff --git a/dimos/web/command-center-extension/.gitignore b/dimos/web/command-center-extension/.gitignore new file mode 100644 index 0000000000..3f7224ed26 --- /dev/null +++ b/dimos/web/command-center-extension/.gitignore @@ -0,0 +1,5 @@ +*.foxe +/dist +/node_modules +!/package.json +!/package-lock.json diff --git a/dimos/web/command-center-extension/.prettierrc.yaml b/dimos/web/command-center-extension/.prettierrc.yaml new file mode 100644 index 0000000000..e57cc20758 --- /dev/null +++ b/dimos/web/command-center-extension/.prettierrc.yaml @@ -0,0 +1,5 @@ +arrowParens: always +printWidth: 100 +trailingComma: "all" +tabWidth: 2 +semi: true diff --git a/dimos/web/command-center-extension/CHANGELOG.md b/dimos/web/command-center-extension/CHANGELOG.md new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dimos/web/command-center-extension/README.md b/dimos/web/command-center-extension/README.md new file mode 100644 index 0000000000..efee4ec11d --- /dev/null +++ b/dimos/web/command-center-extension/README.md @@ -0,0 +1,17 @@ +# command-center-extension + +This is a Foxglove extension for visualizing robot data and controlling the robot. See `dimos/web/websocket_vis/README.md` for how to use the module in your robot. + +## Build and use + +Install the Foxglove Studio desktop application. + +Install the Node dependencies: + + npm install + +Build the package and install it into Foxglove: + + npm run build && npm run local-install + +To add the panel, go to Foxglove Studio, click on the "Add panel" icon on the top right and select "command-center [local]". diff --git a/dimos/web/command-center-extension/eslint.config.js b/dimos/web/command-center-extension/eslint.config.js new file mode 100644 index 0000000000..63cc3a243a --- /dev/null +++ b/dimos/web/command-center-extension/eslint.config.js @@ -0,0 +1,23 @@ +// @ts-check + +const foxglove = require("@foxglove/eslint-plugin"); +const globals = require("globals"); +const tseslint = require("typescript-eslint"); + +module.exports = tseslint.config({ + files: ["src/**/*.ts", "src/**/*.tsx"], + extends: [foxglove.configs.base, foxglove.configs.react, foxglove.configs.typescript], + languageOptions: { + globals: { + ...globals.es2020, + ...globals.browser, + }, + parserOptions: { + project: "tsconfig.json", + tsconfigRootDir: __dirname, + }, + }, + rules: { + "react-hooks/exhaustive-deps": "error", + }, +}); diff --git a/dimos/web/command-center-extension/package-lock.json b/dimos/web/command-center-extension/package-lock.json new file mode 100644 index 0000000000..6446666ebc --- /dev/null +++ b/dimos/web/command-center-extension/package-lock.json @@ -0,0 +1,7178 @@ +{ + "name": "command-center-extension", + "version": "0.0.0", + "lockfileVersion": 3, + "requires": true, + "packages": { + "": { + "name": "command-center-extension", + "version": "0.0.0", + "license": "UNLICENSED", + "dependencies": { + "@types/pako": "^2.0.4", + "d3": "^7.9.0", + "leaflet": "^1.9.4", + "pako": "^2.1.0", + "react-leaflet": "^4.2.1", + "socket.io-client": "^4.8.1" + }, + "devDependencies": { + "@foxglove/eslint-plugin": "2.1.0", + "@foxglove/extension": "2.34.0", + "@types/d3": "^7.4.3", + "@types/leaflet": "^1.9.21", + "@types/react": "18.3.24", + "@types/react-dom": "18.3.7", + "create-foxglove-extension": "1.0.6", + "eslint": "9.34.0", + "prettier": "3.6.2", + "react": "18.3.1", + "react-dom": "^18.3.1", + "typescript": "5.9.2" + } + }, + "node_modules/@eslint-community/eslint-utils": { + "version": "4.7.0", + "resolved": "https://registry.npmjs.org/@eslint-community/eslint-utils/-/eslint-utils-4.7.0.tgz", + "integrity": "sha512-dyybb3AcajC7uha6CvhdVRJqaKyn7w2YKqKyAN37NKYgZT36w+iRb0Dymmc5qEJ549c/S31cMMSFd75bteCpCw==", + "dev": true, + "license": "MIT", + "dependencies": { + "eslint-visitor-keys": "^3.4.3" + }, + "engines": { + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + }, + "funding": { + "url": "https://opencollective.com/eslint" + }, + "peerDependencies": { + "eslint": "^6.0.0 || ^7.0.0 || >=8.0.0" + } + }, + "node_modules/@eslint-community/regexpp": { + "version": "4.12.1", + "resolved": "https://registry.npmjs.org/@eslint-community/regexpp/-/regexpp-4.12.1.tgz", + "integrity": "sha512-CCZCDJuduB9OUkFkY2IgppNZMi2lBQgD2qzwXkEia16cge2pijY/aXi96CJMquDMn3nJdlPV1A5KrJEXwfLNzQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": "^12.0.0 || ^14.0.0 || >=16.0.0" + } + }, + "node_modules/@eslint/compat": { + "version": "1.3.2", + "resolved": "https://registry.npmjs.org/@eslint/compat/-/compat-1.3.2.tgz", + "integrity": "sha512-jRNwzTbd6p2Rw4sZ1CgWRS8YMtqG15YyZf7zvb6gY2rB2u6n+2Z+ELW0GtL0fQgyl0pr4Y/BzBfng/BdsereRA==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "peerDependencies": { + "eslint": "^8.40 || 9" + }, + "peerDependenciesMeta": { + "eslint": { + "optional": true + } + } + }, + "node_modules/@eslint/config-array": { + "version": "0.21.0", + "resolved": "https://registry.npmjs.org/@eslint/config-array/-/config-array-0.21.0.tgz", + "integrity": "sha512-ENIdc4iLu0d93HeYirvKmrzshzofPw6VkZRKQGe9Nv46ZnWUzcF1xV01dcvEg/1wXUR61OmmlSfyeyO7EvjLxQ==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@eslint/object-schema": "^2.1.6", + "debug": "^4.3.1", + "minimatch": "^3.1.2" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + } + }, + "node_modules/@eslint/config-array/node_modules/brace-expansion": { + "version": "1.1.12", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.12.tgz", + "integrity": "sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==", + "dev": true, + "license": "MIT", + "dependencies": { + "balanced-match": "^1.0.0", + "concat-map": "0.0.1" + } + }, + "node_modules/@eslint/config-array/node_modules/minimatch": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", + "integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", + "dev": true, + "license": "ISC", + "dependencies": { + "brace-expansion": "^1.1.7" + }, + "engines": { + "node": "*" + } + }, + "node_modules/@eslint/config-helpers": { + "version": "0.3.1", + "resolved": "https://registry.npmjs.org/@eslint/config-helpers/-/config-helpers-0.3.1.tgz", + "integrity": "sha512-xR93k9WhrDYpXHORXpxVL5oHj3Era7wo6k/Wd8/IsQNnZUTzkGS29lyn3nAT05v6ltUuTFVCCYDEGfy2Or/sPA==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + } + }, + "node_modules/@eslint/core": { + "version": "0.15.2", + "resolved": "https://registry.npmjs.org/@eslint/core/-/core-0.15.2.tgz", + "integrity": "sha512-78Md3/Rrxh83gCxoUc0EiciuOHsIITzLy53m3d9UyiW8y9Dj2D29FeETqyKA+BRK76tnTp6RXWb3pCay8Oyomg==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@types/json-schema": "^7.0.15" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + } + }, + "node_modules/@eslint/eslintrc": { + "version": "3.3.1", + "resolved": "https://registry.npmjs.org/@eslint/eslintrc/-/eslintrc-3.3.1.tgz", + "integrity": "sha512-gtF186CXhIl1p4pJNGZw8Yc6RlshoePRvE0X91oPGb3vZ8pM3qOS9W9NGPat9LziaBV7XrJWGylNQXkGcnM3IQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "ajv": "^6.12.4", + "debug": "^4.3.2", + "espree": "^10.0.1", + "globals": "^14.0.0", + "ignore": "^5.2.0", + "import-fresh": "^3.2.1", + "js-yaml": "^4.1.0", + "minimatch": "^3.1.2", + "strip-json-comments": "^3.1.1" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "url": "https://opencollective.com/eslint" + } + }, + "node_modules/@eslint/eslintrc/node_modules/brace-expansion": { + "version": "1.1.12", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.12.tgz", + "integrity": "sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==", + "dev": true, + "license": "MIT", + "dependencies": { + "balanced-match": "^1.0.0", + "concat-map": "0.0.1" + } + }, + "node_modules/@eslint/eslintrc/node_modules/minimatch": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", + "integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", + "dev": true, + "license": "ISC", + "dependencies": { + "brace-expansion": "^1.1.7" + }, + "engines": { + "node": "*" + } + }, + "node_modules/@eslint/js": { + "version": "9.34.0", + "resolved": "https://registry.npmjs.org/@eslint/js/-/js-9.34.0.tgz", + "integrity": "sha512-EoyvqQnBNsV1CWaEJ559rxXL4c8V92gxirbawSmVUOWXlsRxxQXl6LmCpdUblgxgSkDIqKnhzba2SjRTI/A5Rw==", + "dev": true, + "license": "MIT", + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "url": "https://eslint.org/donate" + } + }, + "node_modules/@eslint/object-schema": { + "version": "2.1.6", + "resolved": "https://registry.npmjs.org/@eslint/object-schema/-/object-schema-2.1.6.tgz", + "integrity": "sha512-RBMg5FRL0I0gs51M/guSAj5/e14VQ4tpZnQNWwuDT66P14I43ItmPfIZRhO9fUVIPOAQXU47atlywZ/czoqFPA==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + } + }, + "node_modules/@eslint/plugin-kit": { + "version": "0.3.5", + "resolved": "https://registry.npmjs.org/@eslint/plugin-kit/-/plugin-kit-0.3.5.tgz", + "integrity": "sha512-Z5kJ+wU3oA7MMIqVR9tyZRtjYPr4OC004Q4Rw7pgOKUOKkJfZ3O24nz3WYfGRpMDNmcOi3TwQOmgm7B7Tpii0w==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@eslint/core": "^0.15.2", + "levn": "^0.4.1" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + } + }, + "node_modules/@foxglove/eslint-plugin": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/@foxglove/eslint-plugin/-/eslint-plugin-2.1.0.tgz", + "integrity": "sha512-EQrEns2BneSY7ODsOnJ6YIvn6iOVhwypHT4OwrzuPX2jqncghF7BXypkdDP3KlFtyDGC1+ff3+VXZMmyc8vpfg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@eslint/compat": "^1", + "@eslint/js": "^9", + "@typescript-eslint/utils": "^8", + "eslint-config-prettier": "^9", + "eslint-plugin-es": "^4", + "eslint-plugin-filenames": "^1", + "eslint-plugin-import": "^2", + "eslint-plugin-jest": "^28", + "eslint-plugin-prettier": "^5", + "eslint-plugin-react": "^7", + "eslint-plugin-react-hooks": "^5", + "tsutils": "^3", + "typescript-eslint": "^8" + }, + "peerDependencies": { + "eslint": "^9.27.0" + } + }, + "node_modules/@foxglove/extension": { + "version": "2.34.0", + "resolved": "https://registry.npmjs.org/@foxglove/extension/-/extension-2.34.0.tgz", + "integrity": "sha512-muZGa//A4gsNVRjwZevwvnSqQdabCJePdh75VFm5LhEb0fkP7VXjU3Rzh84EHRJvkUctiV7IbiI9OAPJmENGeQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/@humanfs/core": { + "version": "0.19.1", + "resolved": "https://registry.npmjs.org/@humanfs/core/-/core-0.19.1.tgz", + "integrity": "sha512-5DyQ4+1JEUzejeK1JGICcideyfUbGixgS9jNgex5nqkW+cY7WZhxBigmieN5Qnw9ZosSNVC9KQKyb+GUaGyKUA==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": ">=18.18.0" + } + }, + "node_modules/@humanfs/node": { + "version": "0.16.6", + "resolved": "https://registry.npmjs.org/@humanfs/node/-/node-0.16.6.tgz", + "integrity": "sha512-YuI2ZHQL78Q5HbhDiBA1X4LmYdXCKCMQIfw0pw7piHJwyREFebJUvrQN4cMssyES6x+vfUbx1CIpaQUKYdQZOw==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@humanfs/core": "^0.19.1", + "@humanwhocodes/retry": "^0.3.0" + }, + "engines": { + "node": ">=18.18.0" + } + }, + "node_modules/@humanfs/node/node_modules/@humanwhocodes/retry": { + "version": "0.3.1", + "resolved": "https://registry.npmjs.org/@humanwhocodes/retry/-/retry-0.3.1.tgz", + "integrity": "sha512-JBxkERygn7Bv/GbN5Rv8Ul6LVknS+5Bp6RgDC/O8gEBU/yeH5Ui5C/OlWrTb6qct7LjjfT6Re2NxB0ln0yYybA==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": ">=18.18" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/nzakas" + } + }, + "node_modules/@humanwhocodes/module-importer": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@humanwhocodes/module-importer/-/module-importer-1.0.1.tgz", + "integrity": "sha512-bxveV4V8v5Yb4ncFTT3rPSgZBOpCkjfK0y4oVVVJwIuDVBRMDXrPyXRL988i5ap9m9bnyEEjWfm5WkBmtffLfA==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": ">=12.22" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/nzakas" + } + }, + "node_modules/@humanwhocodes/retry": { + "version": "0.4.3", + "resolved": "https://registry.npmjs.org/@humanwhocodes/retry/-/retry-0.4.3.tgz", + "integrity": "sha512-bV0Tgo9K4hfPCek+aMAn81RppFKv2ySDQeMoSZuvTASywNTnVJCArCZE2FWqpvIatKu7VMRLWlR1EazvVhDyhQ==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": ">=18.18" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/nzakas" + } + }, + "node_modules/@isaacs/balanced-match": { + "version": "4.0.1", + "resolved": "https://registry.npmjs.org/@isaacs/balanced-match/-/balanced-match-4.0.1.tgz", + "integrity": "sha512-yzMTt9lEb8Gv7zRioUilSglI0c0smZ9k5D65677DLWLtWJaXIS3CqcGyUFByYKlnUj6TkjLVs54fBl6+TiGQDQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": "20 || >=22" + } + }, + "node_modules/@isaacs/brace-expansion": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/@isaacs/brace-expansion/-/brace-expansion-5.0.0.tgz", + "integrity": "sha512-ZT55BDLV0yv0RBm2czMiZ+SqCGO7AvmOM3G/w2xhVPH+te0aKgFjmBvGlL1dH+ql2tgGO3MVrbb3jCKyvpgnxA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@isaacs/balanced-match": "^4.0.1" + }, + "engines": { + "node": "20 || >=22" + } + }, + "node_modules/@isaacs/cliui": { + "version": "8.0.2", + "resolved": "https://registry.npmjs.org/@isaacs/cliui/-/cliui-8.0.2.tgz", + "integrity": "sha512-O8jcjabXaleOG9DQ0+ARXWZBTfnP4WNAqzuiJK7ll44AmxGKv/J2M4TPjxjY3znBCfvBXFzucm1twdyFybFqEA==", + "dev": true, + "license": "ISC", + "dependencies": { + "string-width": "^5.1.2", + "string-width-cjs": "npm:string-width@^4.2.0", + "strip-ansi": "^7.0.1", + "strip-ansi-cjs": "npm:strip-ansi@^6.0.1", + "wrap-ansi": "^8.1.0", + "wrap-ansi-cjs": "npm:wrap-ansi@^7.0.0" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/@jridgewell/gen-mapping": { + "version": "0.3.13", + "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.13.tgz", + "integrity": "sha512-2kkt/7niJ6MgEPxF0bYdQ6etZaA+fQvDcLKckhy1yIQOzaoKjBBjSj63/aLVjYE3qhRt5dvM+uUyfCg6UKCBbA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@jridgewell/sourcemap-codec": "^1.5.0", + "@jridgewell/trace-mapping": "^0.3.24" + } + }, + "node_modules/@jridgewell/resolve-uri": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/@jridgewell/resolve-uri/-/resolve-uri-3.1.2.tgz", + "integrity": "sha512-bRISgCIjP20/tbWSPWMEi54QVPRZExkuD9lJL+UIxUKtwVJA8wW1Trb1jMs1RFXo1CBTNZ/5hpC9QvmKWdopKw==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/@jridgewell/source-map": { + "version": "0.3.11", + "resolved": "https://registry.npmjs.org/@jridgewell/source-map/-/source-map-0.3.11.tgz", + "integrity": "sha512-ZMp1V8ZFcPG5dIWnQLr3NSI1MiCU7UETdS/A0G8V/XWHvJv3ZsFqutJn1Y5RPmAPX6F3BiE397OqveU/9NCuIA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@jridgewell/gen-mapping": "^0.3.5", + "@jridgewell/trace-mapping": "^0.3.25" + } + }, + "node_modules/@jridgewell/sourcemap-codec": { + "version": "1.5.5", + "resolved": "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.5.5.tgz", + "integrity": "sha512-cYQ9310grqxueWbl+WuIUIaiUaDcj7WOq5fVhEljNVgRfOUhY9fy2zTvfoqWsnebh8Sl70VScFbICvJnLKB0Og==", + "dev": true, + "license": "MIT" + }, + "node_modules/@jridgewell/trace-mapping": { + "version": "0.3.30", + "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.30.tgz", + "integrity": "sha512-GQ7Nw5G2lTu/BtHTKfXhKHok2WGetd4XYcVKGx00SjAk8GMwgJM3zr6zORiPGuOE+/vkc90KtTosSSvaCjKb2Q==", + "dev": true, + "license": "MIT", + "dependencies": { + "@jridgewell/resolve-uri": "^3.1.0", + "@jridgewell/sourcemap-codec": "^1.4.14" + } + }, + "node_modules/@nodelib/fs.scandir": { + "version": "2.1.5", + "resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz", + "integrity": "sha512-vq24Bq3ym5HEQm2NKCr3yXDwjc7vTsEThRDnkp2DK9p1uqLR+DHurm/NOTo0KG7HYHU7eppKZj3MyqYuMBf62g==", + "dev": true, + "license": "MIT", + "dependencies": { + "@nodelib/fs.stat": "2.0.5", + "run-parallel": "^1.1.9" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/@nodelib/fs.stat": { + "version": "2.0.5", + "resolved": "https://registry.npmjs.org/@nodelib/fs.stat/-/fs.stat-2.0.5.tgz", + "integrity": "sha512-RkhPPp2zrqDAQA/2jNhnztcPAlv64XdhIp7a7454A5ovI7Bukxgt7MX7udwAu3zg1DcpPU0rz3VV1SeaqvY4+A==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 8" + } + }, + "node_modules/@nodelib/fs.walk": { + "version": "1.2.8", + "resolved": "https://registry.npmjs.org/@nodelib/fs.walk/-/fs.walk-1.2.8.tgz", + "integrity": "sha512-oGB+UxlgWcgQkgwo8GcEGwemoTFt3FIO9ababBmaGwXIoBKZ+GTy0pP185beGg7Llih/NSHSV2XAs1lnznocSg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@nodelib/fs.scandir": "2.1.5", + "fastq": "^1.6.0" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/@pkgr/core": { + "version": "0.2.9", + "resolved": "https://registry.npmjs.org/@pkgr/core/-/core-0.2.9.tgz", + "integrity": "sha512-QNqXyfVS2wm9hweSYD2O7F0G06uurj9kZ96TRQE5Y9hU7+tgdZwIkbAKc5Ocy1HxEY2kuDQa6cQ1WRs/O5LFKA==", + "dev": true, + "license": "MIT", + "engines": { + "node": "^12.20.0 || ^14.18.0 || >=16.0.0" + }, + "funding": { + "url": "https://opencollective.com/pkgr" + } + }, + "node_modules/@react-leaflet/core": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/@react-leaflet/core/-/core-2.1.0.tgz", + "integrity": "sha512-Qk7Pfu8BSarKGqILj4x7bCSZ1pjuAPZ+qmRwH5S7mDS91VSbVVsJSrW4qA+GPrro8t69gFYVMWb1Zc4yFmPiVg==", + "license": "Hippocratic-2.1", + "peerDependencies": { + "leaflet": "^1.9.0", + "react": "^18.0.0", + "react-dom": "^18.0.0" + } + }, + "node_modules/@rtsao/scc": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/@rtsao/scc/-/scc-1.1.0.tgz", + "integrity": "sha512-zt6OdqaDoOnJ1ZYsCYGt9YmWzDXl4vQdKTyJev62gFhRGKdx7mcT54V9KIjg+d2wi9EXsPvAPKe7i7WjfVWB8g==", + "dev": true, + "license": "MIT" + }, + "node_modules/@socket.io/component-emitter": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/@socket.io/component-emitter/-/component-emitter-3.1.2.tgz", + "integrity": "sha512-9BCxFwvbGg/RsZK9tjXd8s4UcwR0MWeFQ1XEKIQVVvAGJyINdrqKMcTRyLoK8Rse1GjzLV9cwjWV1olXRWEXVA==", + "license": "MIT" + }, + "node_modules/@types/d3": { + "version": "7.4.3", + "resolved": "https://registry.npmjs.org/@types/d3/-/d3-7.4.3.tgz", + "integrity": "sha512-lZXZ9ckh5R8uiFVt8ogUNf+pIrK4EsWrx2Np75WvF/eTpJ0FMHNhjXk8CKEx/+gpHbNQyJWehbFaTvqmHWB3ww==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/d3-array": "*", + "@types/d3-axis": "*", + "@types/d3-brush": "*", + "@types/d3-chord": "*", + "@types/d3-color": "*", + "@types/d3-contour": "*", + "@types/d3-delaunay": "*", + "@types/d3-dispatch": "*", + "@types/d3-drag": "*", + "@types/d3-dsv": "*", + "@types/d3-ease": "*", + "@types/d3-fetch": "*", + "@types/d3-force": "*", + "@types/d3-format": "*", + "@types/d3-geo": "*", + "@types/d3-hierarchy": "*", + "@types/d3-interpolate": "*", + "@types/d3-path": "*", + "@types/d3-polygon": "*", + "@types/d3-quadtree": "*", + "@types/d3-random": "*", + "@types/d3-scale": "*", + "@types/d3-scale-chromatic": "*", + "@types/d3-selection": "*", + "@types/d3-shape": "*", + "@types/d3-time": "*", + "@types/d3-time-format": "*", + "@types/d3-timer": "*", + "@types/d3-transition": "*", + "@types/d3-zoom": "*" + } + }, + "node_modules/@types/d3-array": { + "version": "3.2.1", + "resolved": "https://registry.npmjs.org/@types/d3-array/-/d3-array-3.2.1.tgz", + "integrity": "sha512-Y2Jn2idRrLzUfAKV2LyRImR+y4oa2AntrgID95SHJxuMUrkNXmanDSed71sRNZysveJVt1hLLemQZIady0FpEg==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/d3-axis": { + "version": "3.0.6", + "resolved": "https://registry.npmjs.org/@types/d3-axis/-/d3-axis-3.0.6.tgz", + "integrity": "sha512-pYeijfZuBd87T0hGn0FO1vQ/cgLk6E1ALJjfkC0oJ8cbwkZl3TpgS8bVBLZN+2jjGgg38epgxb2zmoGtSfvgMw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/d3-selection": "*" + } + }, + "node_modules/@types/d3-brush": { + "version": "3.0.6", + "resolved": "https://registry.npmjs.org/@types/d3-brush/-/d3-brush-3.0.6.tgz", + "integrity": "sha512-nH60IZNNxEcrh6L1ZSMNA28rj27ut/2ZmI3r96Zd+1jrZD++zD3LsMIjWlvg4AYrHn/Pqz4CF3veCxGjtbqt7A==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/d3-selection": "*" + } + }, + "node_modules/@types/d3-chord": { + "version": "3.0.6", + "resolved": "https://registry.npmjs.org/@types/d3-chord/-/d3-chord-3.0.6.tgz", + "integrity": "sha512-LFYWWd8nwfwEmTZG9PfQxd17HbNPksHBiJHaKuY1XeqscXacsS2tyoo6OdRsjf+NQYeB6XrNL3a25E3gH69lcg==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/d3-color": { + "version": "3.1.3", + "resolved": "https://registry.npmjs.org/@types/d3-color/-/d3-color-3.1.3.tgz", + "integrity": "sha512-iO90scth9WAbmgv7ogoq57O9YpKmFBbmoEoCHDB2xMBY0+/KVrqAaCDyCE16dUspeOvIxFFRI+0sEtqDqy2b4A==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/d3-contour": { + "version": "3.0.6", + "resolved": "https://registry.npmjs.org/@types/d3-contour/-/d3-contour-3.0.6.tgz", + "integrity": "sha512-BjzLgXGnCWjUSYGfH1cpdo41/hgdWETu4YxpezoztawmqsvCeep+8QGfiY6YbDvfgHz/DkjeIkkZVJavB4a3rg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/d3-array": "*", + "@types/geojson": "*" + } + }, + "node_modules/@types/d3-delaunay": { + "version": "6.0.4", + "resolved": "https://registry.npmjs.org/@types/d3-delaunay/-/d3-delaunay-6.0.4.tgz", + "integrity": "sha512-ZMaSKu4THYCU6sV64Lhg6qjf1orxBthaC161plr5KuPHo3CNm8DTHiLw/5Eq2b6TsNP0W0iJrUOFscY6Q450Hw==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/d3-dispatch": { + "version": "3.0.7", + "resolved": "https://registry.npmjs.org/@types/d3-dispatch/-/d3-dispatch-3.0.7.tgz", + "integrity": "sha512-5o9OIAdKkhN1QItV2oqaE5KMIiXAvDWBDPrD85e58Qlz1c1kI/J0NcqbEG88CoTwJrYe7ntUCVfeUl2UJKbWgA==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/d3-drag": { + "version": "3.0.7", + "resolved": "https://registry.npmjs.org/@types/d3-drag/-/d3-drag-3.0.7.tgz", + "integrity": "sha512-HE3jVKlzU9AaMazNufooRJ5ZpWmLIoc90A37WU2JMmeq28w1FQqCZswHZ3xR+SuxYftzHq6WU6KJHvqxKzTxxQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/d3-selection": "*" + } + }, + "node_modules/@types/d3-dsv": { + "version": "3.0.7", + "resolved": "https://registry.npmjs.org/@types/d3-dsv/-/d3-dsv-3.0.7.tgz", + "integrity": "sha512-n6QBF9/+XASqcKK6waudgL0pf/S5XHPPI8APyMLLUHd8NqouBGLsU8MgtO7NINGtPBtk9Kko/W4ea0oAspwh9g==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/d3-ease": { + "version": "3.0.2", + "resolved": "https://registry.npmjs.org/@types/d3-ease/-/d3-ease-3.0.2.tgz", + "integrity": "sha512-NcV1JjO5oDzoK26oMzbILE6HW7uVXOHLQvHshBUW4UMdZGfiY6v5BeQwh9a9tCzv+CeefZQHJt5SRgK154RtiA==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/d3-fetch": { + "version": "3.0.7", + "resolved": "https://registry.npmjs.org/@types/d3-fetch/-/d3-fetch-3.0.7.tgz", + "integrity": "sha512-fTAfNmxSb9SOWNB9IoG5c8Hg6R+AzUHDRlsXsDZsNp6sxAEOP0tkP3gKkNSO/qmHPoBFTxNrjDprVHDQDvo5aA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/d3-dsv": "*" + } + }, + "node_modules/@types/d3-force": { + "version": "3.0.10", + "resolved": "https://registry.npmjs.org/@types/d3-force/-/d3-force-3.0.10.tgz", + "integrity": "sha512-ZYeSaCF3p73RdOKcjj+swRlZfnYpK1EbaDiYICEEp5Q6sUiqFaFQ9qgoshp5CzIyyb/yD09kD9o2zEltCexlgw==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/d3-format": { + "version": "3.0.4", + "resolved": "https://registry.npmjs.org/@types/d3-format/-/d3-format-3.0.4.tgz", + "integrity": "sha512-fALi2aI6shfg7vM5KiR1wNJnZ7r6UuggVqtDA+xiEdPZQwy/trcQaHnwShLuLdta2rTymCNpxYTiMZX/e09F4g==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/d3-geo": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/@types/d3-geo/-/d3-geo-3.1.0.tgz", + "integrity": "sha512-856sckF0oP/diXtS4jNsiQw/UuK5fQG8l/a9VVLeSouf1/PPbBE1i1W852zVwKwYCBkFJJB7nCFTbk6UMEXBOQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/geojson": "*" + } + }, + "node_modules/@types/d3-hierarchy": { + "version": "3.1.7", + "resolved": "https://registry.npmjs.org/@types/d3-hierarchy/-/d3-hierarchy-3.1.7.tgz", + "integrity": "sha512-tJFtNoYBtRtkNysX1Xq4sxtjK8YgoWUNpIiUee0/jHGRwqvzYxkq0hGVbbOGSz+JgFxxRu4K8nb3YpG3CMARtg==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/d3-interpolate": { + "version": "3.0.4", + "resolved": "https://registry.npmjs.org/@types/d3-interpolate/-/d3-interpolate-3.0.4.tgz", + "integrity": "sha512-mgLPETlrpVV1YRJIglr4Ez47g7Yxjl1lj7YKsiMCb27VJH9W8NVM6Bb9d8kkpG/uAQS5AmbA48q2IAolKKo1MA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/d3-color": "*" + } + }, + "node_modules/@types/d3-path": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/@types/d3-path/-/d3-path-3.1.1.tgz", + "integrity": "sha512-VMZBYyQvbGmWyWVea0EHs/BwLgxc+MKi1zLDCONksozI4YJMcTt8ZEuIR4Sb1MMTE8MMW49v0IwI5+b7RmfWlg==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/d3-polygon": { + "version": "3.0.2", + "resolved": "https://registry.npmjs.org/@types/d3-polygon/-/d3-polygon-3.0.2.tgz", + "integrity": "sha512-ZuWOtMaHCkN9xoeEMr1ubW2nGWsp4nIql+OPQRstu4ypeZ+zk3YKqQT0CXVe/PYqrKpZAi+J9mTs05TKwjXSRA==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/d3-quadtree": { + "version": "3.0.6", + "resolved": "https://registry.npmjs.org/@types/d3-quadtree/-/d3-quadtree-3.0.6.tgz", + "integrity": "sha512-oUzyO1/Zm6rsxKRHA1vH0NEDG58HrT5icx/azi9MF1TWdtttWl0UIUsjEQBBh+SIkrpd21ZjEv7ptxWys1ncsg==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/d3-random": { + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/@types/d3-random/-/d3-random-3.0.3.tgz", + "integrity": "sha512-Imagg1vJ3y76Y2ea0871wpabqp613+8/r0mCLEBfdtqC7xMSfj9idOnmBYyMoULfHePJyxMAw3nWhJxzc+LFwQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/d3-scale": { + "version": "4.0.9", + "resolved": "https://registry.npmjs.org/@types/d3-scale/-/d3-scale-4.0.9.tgz", + "integrity": "sha512-dLmtwB8zkAeO/juAMfnV+sItKjlsw2lKdZVVy6LRr0cBmegxSABiLEpGVmSJJ8O08i4+sGR6qQtb6WtuwJdvVw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/d3-time": "*" + } + }, + "node_modules/@types/d3-scale-chromatic": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/@types/d3-scale-chromatic/-/d3-scale-chromatic-3.1.0.tgz", + "integrity": "sha512-iWMJgwkK7yTRmWqRB5plb1kadXyQ5Sj8V/zYlFGMUBbIPKQScw+Dku9cAAMgJG+z5GYDoMjWGLVOvjghDEFnKQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/d3-selection": { + "version": "3.0.11", + "resolved": "https://registry.npmjs.org/@types/d3-selection/-/d3-selection-3.0.11.tgz", + "integrity": "sha512-bhAXu23DJWsrI45xafYpkQ4NtcKMwWnAC/vKrd2l+nxMFuvOT3XMYTIj2opv8vq8AO5Yh7Qac/nSeP/3zjTK0w==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/d3-shape": { + "version": "3.1.7", + "resolved": "https://registry.npmjs.org/@types/d3-shape/-/d3-shape-3.1.7.tgz", + "integrity": "sha512-VLvUQ33C+3J+8p+Daf+nYSOsjB4GXp19/S/aGo60m9h1v6XaxjiT82lKVWJCfzhtuZ3yD7i/TPeC/fuKLLOSmg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/d3-path": "*" + } + }, + "node_modules/@types/d3-time": { + "version": "3.0.4", + "resolved": "https://registry.npmjs.org/@types/d3-time/-/d3-time-3.0.4.tgz", + "integrity": "sha512-yuzZug1nkAAaBlBBikKZTgzCeA+k1uy4ZFwWANOfKw5z5LRhV0gNA7gNkKm7HoK+HRN0wX3EkxGk0fpbWhmB7g==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/d3-time-format": { + "version": "4.0.3", + "resolved": "https://registry.npmjs.org/@types/d3-time-format/-/d3-time-format-4.0.3.tgz", + "integrity": "sha512-5xg9rC+wWL8kdDj153qZcsJ0FWiFt0J5RB6LYUNZjwSnesfblqrI/bJ1wBdJ8OQfncgbJG5+2F+qfqnqyzYxyg==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/d3-timer": { + "version": "3.0.2", + "resolved": "https://registry.npmjs.org/@types/d3-timer/-/d3-timer-3.0.2.tgz", + "integrity": "sha512-Ps3T8E8dZDam6fUyNiMkekK3XUsaUEik+idO9/YjPtfj2qruF8tFBXS7XhtE4iIXBLxhmLjP3SXpLhVf21I9Lw==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/d3-transition": { + "version": "3.0.9", + "resolved": "https://registry.npmjs.org/@types/d3-transition/-/d3-transition-3.0.9.tgz", + "integrity": "sha512-uZS5shfxzO3rGlu0cC3bjmMFKsXv+SmZZcgp0KD22ts4uGXp5EVYGzu/0YdwZeKmddhcAccYtREJKkPfXkZuCg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/d3-selection": "*" + } + }, + "node_modules/@types/d3-zoom": { + "version": "3.0.8", + "resolved": "https://registry.npmjs.org/@types/d3-zoom/-/d3-zoom-3.0.8.tgz", + "integrity": "sha512-iqMC4/YlFCSlO8+2Ii1GGGliCAY4XdeG748w5vQUbevlbDu0zSjH/+jojorQVBK/se0j6DUFNPBGSqD3YWYnDw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/d3-interpolate": "*", + "@types/d3-selection": "*" + } + }, + "node_modules/@types/eslint": { + "version": "9.6.1", + "resolved": "https://registry.npmjs.org/@types/eslint/-/eslint-9.6.1.tgz", + "integrity": "sha512-FXx2pKgId/WyYo2jXw63kk7/+TY7u7AziEJxJAnSFzHlqTAS3Ync6SvgYAN/k4/PQpnnVuzoMuVnByKK2qp0ag==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/estree": "*", + "@types/json-schema": "*" + } + }, + "node_modules/@types/eslint-scope": { + "version": "3.7.7", + "resolved": "https://registry.npmjs.org/@types/eslint-scope/-/eslint-scope-3.7.7.tgz", + "integrity": "sha512-MzMFlSLBqNF2gcHWO0G1vP/YQyfvrxZ0bF+u7mzUdZ1/xK4A4sru+nraZz5i3iEIk1l1uyicaDVTB4QbbEkAYg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/eslint": "*", + "@types/estree": "*" + } + }, + "node_modules/@types/estree": { + "version": "1.0.8", + "resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.8.tgz", + "integrity": "sha512-dWHzHa2WqEXI/O1E9OjrocMTKJl2mSrEolh1Iomrv6U+JuNwaHXsXx9bLu5gG7BUWFIN0skIQJQ/L1rIex4X6w==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/geojson": { + "version": "7946.0.16", + "resolved": "https://registry.npmjs.org/@types/geojson/-/geojson-7946.0.16.tgz", + "integrity": "sha512-6C8nqWur3j98U6+lXDfTUWIfgvZU+EumvpHKcYjujKH7woYyLj2sUmff0tRhrqM7BohUw7Pz3ZB1jj2gW9Fvmg==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/glob": { + "version": "7.2.0", + "resolved": "https://registry.npmjs.org/@types/glob/-/glob-7.2.0.tgz", + "integrity": "sha512-ZUxbzKl0IfJILTS6t7ip5fQQM/J3TJYubDm3nMbgubNNYS62eXeUpoLUC8/7fJNiFYHTrGPQn7hspDUzIHX3UA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/minimatch": "*", + "@types/node": "*" + } + }, + "node_modules/@types/json-schema": { + "version": "7.0.15", + "resolved": "https://registry.npmjs.org/@types/json-schema/-/json-schema-7.0.15.tgz", + "integrity": "sha512-5+fP8P8MFNC+AyZCDxrB2pkZFPGzqQWUzpSeuuVLvm8VMcorNYavBqoFcxK8bQz4Qsbn4oUEEem4wDLfcysGHA==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/json5": { + "version": "0.0.29", + "resolved": "https://registry.npmjs.org/@types/json5/-/json5-0.0.29.tgz", + "integrity": "sha512-dRLjCWHYg4oaA77cxO64oO+7JwCwnIzkZPdrrC71jQmQtlhM556pwKo5bUzqvZndkVbeFLIIi+9TC40JNF5hNQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/leaflet": { + "version": "1.9.21", + "resolved": "https://registry.npmjs.org/@types/leaflet/-/leaflet-1.9.21.tgz", + "integrity": "sha512-TbAd9DaPGSnzp6QvtYngntMZgcRk+igFELwR2N99XZn7RXUdKgsXMR+28bUO0rPsWp8MIu/f47luLIQuSLYv/w==", + "dev": true, + "dependencies": { + "@types/geojson": "*" + } + }, + "node_modules/@types/minimatch": { + "version": "5.1.2", + "resolved": "https://registry.npmjs.org/@types/minimatch/-/minimatch-5.1.2.tgz", + "integrity": "sha512-K0VQKziLUWkVKiRVrx4a40iPaxTUefQmjtkQofBkYRcoaaL/8rhwDWww9qWbrgicNOgnpIsMxyNIUM4+n6dUIA==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/node": { + "version": "24.3.0", + "resolved": "https://registry.npmjs.org/@types/node/-/node-24.3.0.tgz", + "integrity": "sha512-aPTXCrfwnDLj4VvXrm+UUCQjNEvJgNA8s5F1cvwQU+3KNltTOkBm1j30uNLyqqPNe7gE3KFzImYoZEfLhp4Yow==", + "dev": true, + "license": "MIT", + "dependencies": { + "undici-types": "~7.10.0" + } + }, + "node_modules/@types/pako": { + "version": "2.0.4", + "resolved": "https://registry.npmjs.org/@types/pako/-/pako-2.0.4.tgz", + "integrity": "sha512-VWDCbrLeVXJM9fihYodcLiIv0ku+AlOa/TQ1SvYOaBuyrSKgEcro95LJyIsJ4vSo6BXIxOKxiJAat04CmST9Fw==", + "license": "MIT" + }, + "node_modules/@types/prop-types": { + "version": "15.7.15", + "resolved": "https://registry.npmjs.org/@types/prop-types/-/prop-types-15.7.15.tgz", + "integrity": "sha512-F6bEyamV9jKGAFBEmlQnesRPGOQqS2+Uwi0Em15xenOxHaf2hv6L8YCVn3rPdPJOiJfPiCnLIRyvwVaqMY3MIw==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/react": { + "version": "18.3.24", + "resolved": "https://registry.npmjs.org/@types/react/-/react-18.3.24.tgz", + "integrity": "sha512-0dLEBsA1kI3OezMBF8nSsb7Nk19ZnsyE1LLhB8r27KbgU5H4pvuqZLdtE+aUkJVoXgTVuA+iLIwmZ0TuK4tx6A==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/prop-types": "*", + "csstype": "^3.0.2" + } + }, + "node_modules/@types/react-dom": { + "version": "18.3.7", + "resolved": "https://registry.npmjs.org/@types/react-dom/-/react-dom-18.3.7.tgz", + "integrity": "sha512-MEe3UeoENYVFXzoXEWsvcpg6ZvlrFNlOQ7EOsvhI3CfAXwzPfO8Qwuxd40nepsYKqyyVQnTdEfv68q91yLcKrQ==", + "dev": true, + "license": "MIT", + "peerDependencies": { + "@types/react": "^18.0.0" + } + }, + "node_modules/@typescript-eslint/eslint-plugin": { + "version": "8.40.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-8.40.0.tgz", + "integrity": "sha512-w/EboPlBwnmOBtRbiOvzjD+wdiZdgFeo17lkltrtn7X37vagKKWJABvyfsJXTlHe6XBzugmYgd4A4nW+k8Mixw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@eslint-community/regexpp": "^4.10.0", + "@typescript-eslint/scope-manager": "8.40.0", + "@typescript-eslint/type-utils": "8.40.0", + "@typescript-eslint/utils": "8.40.0", + "@typescript-eslint/visitor-keys": "8.40.0", + "graphemer": "^1.4.0", + "ignore": "^7.0.0", + "natural-compare": "^1.4.0", + "ts-api-utils": "^2.1.0" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "@typescript-eslint/parser": "^8.40.0", + "eslint": "^8.57.0 || ^9.0.0", + "typescript": ">=4.8.4 <6.0.0" + } + }, + "node_modules/@typescript-eslint/eslint-plugin/node_modules/ignore": { + "version": "7.0.5", + "resolved": "https://registry.npmjs.org/ignore/-/ignore-7.0.5.tgz", + "integrity": "sha512-Hs59xBNfUIunMFgWAbGX5cq6893IbWg4KnrjbYwX3tx0ztorVgTDA6B2sxf8ejHJ4wz8BqGUMYlnzNBer5NvGg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 4" + } + }, + "node_modules/@typescript-eslint/parser": { + "version": "8.40.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/parser/-/parser-8.40.0.tgz", + "integrity": "sha512-jCNyAuXx8dr5KJMkecGmZ8KI61KBUhkCob+SD+C+I5+Y1FWI2Y3QmY4/cxMCC5WAsZqoEtEETVhUiUMIGCf6Bw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@typescript-eslint/scope-manager": "8.40.0", + "@typescript-eslint/types": "8.40.0", + "@typescript-eslint/typescript-estree": "8.40.0", + "@typescript-eslint/visitor-keys": "8.40.0", + "debug": "^4.3.4" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "eslint": "^8.57.0 || ^9.0.0", + "typescript": ">=4.8.4 <6.0.0" + } + }, + "node_modules/@typescript-eslint/project-service": { + "version": "8.40.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/project-service/-/project-service-8.40.0.tgz", + "integrity": "sha512-/A89vz7Wf5DEXsGVvcGdYKbVM9F7DyFXj52lNYUDS1L9yJfqjW/fIp5PgMuEJL/KeqVTe2QSbXAGUZljDUpArw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@typescript-eslint/tsconfig-utils": "^8.40.0", + "@typescript-eslint/types": "^8.40.0", + "debug": "^4.3.4" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "typescript": ">=4.8.4 <6.0.0" + } + }, + "node_modules/@typescript-eslint/scope-manager": { + "version": "8.40.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/scope-manager/-/scope-manager-8.40.0.tgz", + "integrity": "sha512-y9ObStCcdCiZKzwqsE8CcpyuVMwRouJbbSrNuThDpv16dFAj429IkM6LNb1dZ2m7hK5fHyzNcErZf7CEeKXR4w==", + "dev": true, + "license": "MIT", + "dependencies": { + "@typescript-eslint/types": "8.40.0", + "@typescript-eslint/visitor-keys": "8.40.0" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + } + }, + "node_modules/@typescript-eslint/tsconfig-utils": { + "version": "8.40.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/tsconfig-utils/-/tsconfig-utils-8.40.0.tgz", + "integrity": "sha512-jtMytmUaG9d/9kqSl/W3E3xaWESo4hFDxAIHGVW/WKKtQhesnRIJSAJO6XckluuJ6KDB5woD1EiqknriCtAmcw==", + "dev": true, + "license": "MIT", + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "typescript": ">=4.8.4 <6.0.0" + } + }, + "node_modules/@typescript-eslint/type-utils": { + "version": "8.40.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/type-utils/-/type-utils-8.40.0.tgz", + "integrity": "sha512-eE60cK4KzAc6ZrzlJnflXdrMqOBaugeukWICO2rB0KNvwdIMaEaYiywwHMzA1qFpTxrLhN9Lp4E/00EgWcD3Ow==", + "dev": true, + "license": "MIT", + "dependencies": { + "@typescript-eslint/types": "8.40.0", + "@typescript-eslint/typescript-estree": "8.40.0", + "@typescript-eslint/utils": "8.40.0", + "debug": "^4.3.4", + "ts-api-utils": "^2.1.0" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "eslint": "^8.57.0 || ^9.0.0", + "typescript": ">=4.8.4 <6.0.0" + } + }, + "node_modules/@typescript-eslint/types": { + "version": "8.40.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/types/-/types-8.40.0.tgz", + "integrity": "sha512-ETdbFlgbAmXHyFPwqUIYrfc12ArvpBhEVgGAxVYSwli26dn8Ko+lIo4Su9vI9ykTZdJn+vJprs/0eZU0YMAEQg==", + "dev": true, + "license": "MIT", + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + } + }, + "node_modules/@typescript-eslint/typescript-estree": { + "version": "8.40.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/typescript-estree/-/typescript-estree-8.40.0.tgz", + "integrity": "sha512-k1z9+GJReVVOkc1WfVKs1vBrR5MIKKbdAjDTPvIK3L8De6KbFfPFt6BKpdkdk7rZS2GtC/m6yI5MYX+UsuvVYQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@typescript-eslint/project-service": "8.40.0", + "@typescript-eslint/tsconfig-utils": "8.40.0", + "@typescript-eslint/types": "8.40.0", + "@typescript-eslint/visitor-keys": "8.40.0", + "debug": "^4.3.4", + "fast-glob": "^3.3.2", + "is-glob": "^4.0.3", + "minimatch": "^9.0.4", + "semver": "^7.6.0", + "ts-api-utils": "^2.1.0" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "typescript": ">=4.8.4 <6.0.0" + } + }, + "node_modules/@typescript-eslint/utils": { + "version": "8.40.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/utils/-/utils-8.40.0.tgz", + "integrity": "sha512-Cgzi2MXSZyAUOY+BFwGs17s7ad/7L+gKt6Y8rAVVWS+7o6wrjeFN4nVfTpbE25MNcxyJ+iYUXflbs2xR9h4UBg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@eslint-community/eslint-utils": "^4.7.0", + "@typescript-eslint/scope-manager": "8.40.0", + "@typescript-eslint/types": "8.40.0", + "@typescript-eslint/typescript-estree": "8.40.0" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "eslint": "^8.57.0 || ^9.0.0", + "typescript": ">=4.8.4 <6.0.0" + } + }, + "node_modules/@typescript-eslint/visitor-keys": { + "version": "8.40.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/visitor-keys/-/visitor-keys-8.40.0.tgz", + "integrity": "sha512-8CZ47QwalyRjsypfwnbI3hKy5gJDPmrkLjkgMxhi0+DZZ2QNx2naS6/hWoVYUHU7LU2zleF68V9miaVZvhFfTA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@typescript-eslint/types": "8.40.0", + "eslint-visitor-keys": "^4.2.1" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + } + }, + "node_modules/@typescript-eslint/visitor-keys/node_modules/eslint-visitor-keys": { + "version": "4.2.1", + "resolved": "https://registry.npmjs.org/eslint-visitor-keys/-/eslint-visitor-keys-4.2.1.tgz", + "integrity": "sha512-Uhdk5sfqcee/9H/rCOJikYz67o0a2Tw2hGRPOG2Y1R2dg7brRe1uG0yaNQDHu+TO/uQPF/5eCapvYSmHUjt7JQ==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "url": "https://opencollective.com/eslint" + } + }, + "node_modules/@webassemblyjs/ast": { + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/ast/-/ast-1.14.1.tgz", + "integrity": "sha512-nuBEDgQfm1ccRp/8bCQrx1frohyufl4JlbMMZ4P1wpeOfDhF6FQkxZJ1b/e+PLwr6X1Nhw6OLme5usuBWYBvuQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@webassemblyjs/helper-numbers": "1.13.2", + "@webassemblyjs/helper-wasm-bytecode": "1.13.2" + } + }, + "node_modules/@webassemblyjs/floating-point-hex-parser": { + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/@webassemblyjs/floating-point-hex-parser/-/floating-point-hex-parser-1.13.2.tgz", + "integrity": "sha512-6oXyTOzbKxGH4steLbLNOu71Oj+C8Lg34n6CqRvqfS2O71BxY6ByfMDRhBytzknj9yGUPVJ1qIKhRlAwO1AovA==", + "dev": true, + "license": "MIT" + }, + "node_modules/@webassemblyjs/helper-api-error": { + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-api-error/-/helper-api-error-1.13.2.tgz", + "integrity": "sha512-U56GMYxy4ZQCbDZd6JuvvNV/WFildOjsaWD3Tzzvmw/mas3cXzRJPMjP83JqEsgSbyrmaGjBfDtV7KDXV9UzFQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/@webassemblyjs/helper-buffer": { + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-buffer/-/helper-buffer-1.14.1.tgz", + "integrity": "sha512-jyH7wtcHiKssDtFPRB+iQdxlDf96m0E39yb0k5uJVhFGleZFoNw1c4aeIcVUPPbXUVJ94wwnMOAqUHyzoEPVMA==", + "dev": true, + "license": "MIT" + }, + "node_modules/@webassemblyjs/helper-numbers": { + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-numbers/-/helper-numbers-1.13.2.tgz", + "integrity": "sha512-FE8aCmS5Q6eQYcV3gI35O4J789wlQA+7JrqTTpJqn5emA4U2hvwJmvFRC0HODS+3Ye6WioDklgd6scJ3+PLnEA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@webassemblyjs/floating-point-hex-parser": "1.13.2", + "@webassemblyjs/helper-api-error": "1.13.2", + "@xtuc/long": "4.2.2" + } + }, + "node_modules/@webassemblyjs/helper-wasm-bytecode": { + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-wasm-bytecode/-/helper-wasm-bytecode-1.13.2.tgz", + "integrity": "sha512-3QbLKy93F0EAIXLh0ogEVR6rOubA9AoZ+WRYhNbFyuB70j3dRdwH9g+qXhLAO0kiYGlg3TxDV+I4rQTr/YNXkA==", + "dev": true, + "license": "MIT" + }, + "node_modules/@webassemblyjs/helper-wasm-section": { + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-wasm-section/-/helper-wasm-section-1.14.1.tgz", + "integrity": "sha512-ds5mXEqTJ6oxRoqjhWDU83OgzAYjwsCV8Lo/N+oRsNDmx/ZDpqalmrtgOMkHwxsG0iI//3BwWAErYRHtgn0dZw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@webassemblyjs/ast": "1.14.1", + "@webassemblyjs/helper-buffer": "1.14.1", + "@webassemblyjs/helper-wasm-bytecode": "1.13.2", + "@webassemblyjs/wasm-gen": "1.14.1" + } + }, + "node_modules/@webassemblyjs/ieee754": { + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/@webassemblyjs/ieee754/-/ieee754-1.13.2.tgz", + "integrity": "sha512-4LtOzh58S/5lX4ITKxnAK2USuNEvpdVV9AlgGQb8rJDHaLeHciwG4zlGr0j/SNWlr7x3vO1lDEsuePvtcDNCkw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@xtuc/ieee754": "^1.2.0" + } + }, + "node_modules/@webassemblyjs/leb128": { + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/@webassemblyjs/leb128/-/leb128-1.13.2.tgz", + "integrity": "sha512-Lde1oNoIdzVzdkNEAWZ1dZ5orIbff80YPdHx20mrHwHrVNNTjNr8E3xz9BdpcGqRQbAEa+fkrCb+fRFTl/6sQw==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@xtuc/long": "4.2.2" + } + }, + "node_modules/@webassemblyjs/utf8": { + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/@webassemblyjs/utf8/-/utf8-1.13.2.tgz", + "integrity": "sha512-3NQWGjKTASY1xV5m7Hr0iPeXD9+RDobLll3T9d2AO+g3my8xy5peVyjSag4I50mR1bBSN/Ct12lo+R9tJk0NZQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/@webassemblyjs/wasm-edit": { + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-edit/-/wasm-edit-1.14.1.tgz", + "integrity": "sha512-RNJUIQH/J8iA/1NzlE4N7KtyZNHi3w7at7hDjvRNm5rcUXa00z1vRz3glZoULfJ5mpvYhLybmVcwcjGrC1pRrQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@webassemblyjs/ast": "1.14.1", + "@webassemblyjs/helper-buffer": "1.14.1", + "@webassemblyjs/helper-wasm-bytecode": "1.13.2", + "@webassemblyjs/helper-wasm-section": "1.14.1", + "@webassemblyjs/wasm-gen": "1.14.1", + "@webassemblyjs/wasm-opt": "1.14.1", + "@webassemblyjs/wasm-parser": "1.14.1", + "@webassemblyjs/wast-printer": "1.14.1" + } + }, + "node_modules/@webassemblyjs/wasm-gen": { + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-gen/-/wasm-gen-1.14.1.tgz", + "integrity": "sha512-AmomSIjP8ZbfGQhumkNvgC33AY7qtMCXnN6bL2u2Js4gVCg8fp735aEiMSBbDR7UQIj90n4wKAFUSEd0QN2Ukg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@webassemblyjs/ast": "1.14.1", + "@webassemblyjs/helper-wasm-bytecode": "1.13.2", + "@webassemblyjs/ieee754": "1.13.2", + "@webassemblyjs/leb128": "1.13.2", + "@webassemblyjs/utf8": "1.13.2" + } + }, + "node_modules/@webassemblyjs/wasm-opt": { + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-opt/-/wasm-opt-1.14.1.tgz", + "integrity": "sha512-PTcKLUNvBqnY2U6E5bdOQcSM+oVP/PmrDY9NzowJjislEjwP/C4an2303MCVS2Mg9d3AJpIGdUFIQQWbPds0Sw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@webassemblyjs/ast": "1.14.1", + "@webassemblyjs/helper-buffer": "1.14.1", + "@webassemblyjs/wasm-gen": "1.14.1", + "@webassemblyjs/wasm-parser": "1.14.1" + } + }, + "node_modules/@webassemblyjs/wasm-parser": { + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-parser/-/wasm-parser-1.14.1.tgz", + "integrity": "sha512-JLBl+KZ0R5qB7mCnud/yyX08jWFw5MsoalJ1pQ4EdFlgj9VdXKGuENGsiCIjegI1W7p91rUlcB/LB5yRJKNTcQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@webassemblyjs/ast": "1.14.1", + "@webassemblyjs/helper-api-error": "1.13.2", + "@webassemblyjs/helper-wasm-bytecode": "1.13.2", + "@webassemblyjs/ieee754": "1.13.2", + "@webassemblyjs/leb128": "1.13.2", + "@webassemblyjs/utf8": "1.13.2" + } + }, + "node_modules/@webassemblyjs/wast-printer": { + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/wast-printer/-/wast-printer-1.14.1.tgz", + "integrity": "sha512-kPSSXE6De1XOR820C90RIo2ogvZG+c3KiHzqUoO/F34Y2shGzesfqv7o57xrxovZJH/MetF5UjroJ/R/3isoiw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@webassemblyjs/ast": "1.14.1", + "@xtuc/long": "4.2.2" + } + }, + "node_modules/@xtuc/ieee754": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/@xtuc/ieee754/-/ieee754-1.2.0.tgz", + "integrity": "sha512-DX8nKgqcGwsc0eJSqYt5lwP4DH5FlHnmuWWBRy7X0NcaGR0ZtuyeESgMwTYVEtxmsNGY+qit4QYT/MIYTOTPeA==", + "dev": true, + "license": "BSD-3-Clause" + }, + "node_modules/@xtuc/long": { + "version": "4.2.2", + "resolved": "https://registry.npmjs.org/@xtuc/long/-/long-4.2.2.tgz", + "integrity": "sha512-NuHqBY1PB/D8xU6s/thBgOAiAP7HOYDQ32+BFZILJ8ivkUkAHQnWfn6WhL79Owj1qmUnoN/YPhktdIoucipkAQ==", + "dev": true, + "license": "Apache-2.0" + }, + "node_modules/acorn": { + "version": "8.15.0", + "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz", + "integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==", + "dev": true, + "license": "MIT", + "bin": { + "acorn": "bin/acorn" + }, + "engines": { + "node": ">=0.4.0" + } + }, + "node_modules/acorn-jsx": { + "version": "5.3.2", + "resolved": "https://registry.npmjs.org/acorn-jsx/-/acorn-jsx-5.3.2.tgz", + "integrity": "sha512-rq9s+JNhf0IChjtDXxllJ7g41oZk5SlXtp0LHwyA5cejwn7vKmKp4pPri6YEePv2PU65sAsegbXtIinmDFDXgQ==", + "dev": true, + "license": "MIT", + "peerDependencies": { + "acorn": "^6.0.0 || ^7.0.0 || ^8.0.0" + } + }, + "node_modules/ajv": { + "version": "6.12.6", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz", + "integrity": "sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==", + "dev": true, + "license": "MIT", + "dependencies": { + "fast-deep-equal": "^3.1.1", + "fast-json-stable-stringify": "^2.0.0", + "json-schema-traverse": "^0.4.1", + "uri-js": "^4.2.2" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/epoberezkin" + } + }, + "node_modules/ajv-formats": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/ajv-formats/-/ajv-formats-2.1.1.tgz", + "integrity": "sha512-Wx0Kx52hxE7C18hkMEggYlEifqWZtYaRgouJor+WMdPnQyEK13vgEWyVNup7SoeeoLMsr4kf5h6dOW11I15MUA==", + "dev": true, + "license": "MIT", + "dependencies": { + "ajv": "^8.0.0" + }, + "peerDependencies": { + "ajv": "^8.0.0" + }, + "peerDependenciesMeta": { + "ajv": { + "optional": true + } + } + }, + "node_modules/ajv-formats/node_modules/ajv": { + "version": "8.17.1", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.17.1.tgz", + "integrity": "sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==", + "dev": true, + "license": "MIT", + "dependencies": { + "fast-deep-equal": "^3.1.3", + "fast-uri": "^3.0.1", + "json-schema-traverse": "^1.0.0", + "require-from-string": "^2.0.2" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/epoberezkin" + } + }, + "node_modules/ajv-formats/node_modules/json-schema-traverse": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz", + "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==", + "dev": true, + "license": "MIT" + }, + "node_modules/ajv-keywords": { + "version": "3.5.2", + "resolved": "https://registry.npmjs.org/ajv-keywords/-/ajv-keywords-3.5.2.tgz", + "integrity": "sha512-5p6WTN0DdTGVQk6VjcEju19IgaHudalcfabD7yhDGeA6bcQnmL+CpveLJq/3hvfwd1aof6L386Ougkx6RfyMIQ==", + "dev": true, + "license": "MIT", + "peerDependencies": { + "ajv": "^6.9.1" + } + }, + "node_modules/ansi-regex": { + "version": "6.2.0", + "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-6.2.0.tgz", + "integrity": "sha512-TKY5pyBkHyADOPYlRT9Lx6F544mPl0vS5Ew7BJ45hA08Q+t3GjbueLliBWN3sMICk6+y7HdyxSzC4bWS8baBdg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/chalk/ansi-regex?sponsor=1" + } + }, + "node_modules/ansi-styles": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", + "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", + "dev": true, + "license": "MIT", + "dependencies": { + "color-convert": "^2.0.1" + }, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/chalk/ansi-styles?sponsor=1" + } + }, + "node_modules/argparse": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/argparse/-/argparse-2.0.1.tgz", + "integrity": "sha512-8+9WqebbFzpX9OR+Wa6O29asIogeRMzcGtAINdpMHHyAg10f05aSFVBbcEqGf/PXw1EjAZ+q2/bEBg3DvurK3Q==", + "dev": true, + "license": "Python-2.0" + }, + "node_modules/array-buffer-byte-length": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/array-buffer-byte-length/-/array-buffer-byte-length-1.0.2.tgz", + "integrity": "sha512-LHE+8BuR7RYGDKvnrmcuSq3tDcKv9OFEXQt/HpbZhY7V6h0zlUXutnAD82GiFx9rdieCMjkvtcsPqBwgUl1Iiw==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.3", + "is-array-buffer": "^3.0.5" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/array-includes": { + "version": "3.1.9", + "resolved": "https://registry.npmjs.org/array-includes/-/array-includes-3.1.9.tgz", + "integrity": "sha512-FmeCCAenzH0KH381SPT5FZmiA/TmpndpcaShhfgEN9eCVjnFBqq3l1xrI42y8+PPLI6hypzou4GXw00WHmPBLQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.8", + "call-bound": "^1.0.4", + "define-properties": "^1.2.1", + "es-abstract": "^1.24.0", + "es-object-atoms": "^1.1.1", + "get-intrinsic": "^1.3.0", + "is-string": "^1.1.1", + "math-intrinsics": "^1.1.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/array-union": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/array-union/-/array-union-1.0.2.tgz", + "integrity": "sha512-Dxr6QJj/RdU/hCaBjOfxW+q6lyuVE6JFWIrAUpuOOhoJJoQ99cUn3igRaHVB5P9WrgFVN0FfArM3x0cueOU8ng==", + "dev": true, + "license": "MIT", + "dependencies": { + "array-uniq": "^1.0.1" + }, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/array-uniq": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/array-uniq/-/array-uniq-1.0.3.tgz", + "integrity": "sha512-MNha4BWQ6JbwhFhj03YK552f7cb3AzoE8SzeljgChvL1dl3IcvggXVz1DilzySZkCja+CXuZbdW7yATchWn8/Q==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/array.prototype.findlast": { + "version": "1.2.5", + "resolved": "https://registry.npmjs.org/array.prototype.findlast/-/array.prototype.findlast-1.2.5.tgz", + "integrity": "sha512-CVvd6FHg1Z3POpBLxO6E6zr+rSKEQ9L6rZHAaY7lLfhKsWYUBBOuMs0e9o24oopj6H+geRCX0YJ+TJLBK2eHyQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.7", + "define-properties": "^1.2.1", + "es-abstract": "^1.23.2", + "es-errors": "^1.3.0", + "es-object-atoms": "^1.0.0", + "es-shim-unscopables": "^1.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/array.prototype.findlastindex": { + "version": "1.2.6", + "resolved": "https://registry.npmjs.org/array.prototype.findlastindex/-/array.prototype.findlastindex-1.2.6.tgz", + "integrity": "sha512-F/TKATkzseUExPlfvmwQKGITM3DGTK+vkAsCZoDc5daVygbJBnjEUCbgkAvVFsgfXfX4YIqZ/27G3k3tdXrTxQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.8", + "call-bound": "^1.0.4", + "define-properties": "^1.2.1", + "es-abstract": "^1.23.9", + "es-errors": "^1.3.0", + "es-object-atoms": "^1.1.1", + "es-shim-unscopables": "^1.1.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/array.prototype.flat": { + "version": "1.3.3", + "resolved": "https://registry.npmjs.org/array.prototype.flat/-/array.prototype.flat-1.3.3.tgz", + "integrity": "sha512-rwG/ja1neyLqCuGZ5YYrznA62D4mZXg0i1cIskIUKSiqF3Cje9/wXAls9B9s1Wa2fomMsIv8czB8jZcPmxCXFg==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.8", + "define-properties": "^1.2.1", + "es-abstract": "^1.23.5", + "es-shim-unscopables": "^1.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/array.prototype.flatmap": { + "version": "1.3.3", + "resolved": "https://registry.npmjs.org/array.prototype.flatmap/-/array.prototype.flatmap-1.3.3.tgz", + "integrity": "sha512-Y7Wt51eKJSyi80hFrJCePGGNo5ktJCslFuboqJsbf57CCPcm5zztluPlc4/aD8sWsKvlwatezpV4U1efk8kpjg==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.8", + "define-properties": "^1.2.1", + "es-abstract": "^1.23.5", + "es-shim-unscopables": "^1.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/array.prototype.tosorted": { + "version": "1.1.4", + "resolved": "https://registry.npmjs.org/array.prototype.tosorted/-/array.prototype.tosorted-1.1.4.tgz", + "integrity": "sha512-p6Fx8B7b7ZhL/gmUsAy0D15WhvDccw3mnGNbZpi3pmeJdxtWsj2jEaI4Y6oo3XiHfzuSgPwKc04MYt6KgvC/wA==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.7", + "define-properties": "^1.2.1", + "es-abstract": "^1.23.3", + "es-errors": "^1.3.0", + "es-shim-unscopables": "^1.0.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/arraybuffer.prototype.slice": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/arraybuffer.prototype.slice/-/arraybuffer.prototype.slice-1.0.4.tgz", + "integrity": "sha512-BNoCY6SXXPQ7gF2opIP4GBE+Xw7U+pHMYKuzjgCN3GwiaIR09UUeKfheyIry77QtrCBlC0KK0q5/TER/tYh3PQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "array-buffer-byte-length": "^1.0.1", + "call-bind": "^1.0.8", + "define-properties": "^1.2.1", + "es-abstract": "^1.23.5", + "es-errors": "^1.3.0", + "get-intrinsic": "^1.2.6", + "is-array-buffer": "^3.0.4" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/async-function": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/async-function/-/async-function-1.0.0.tgz", + "integrity": "sha512-hsU18Ae8CDTR6Kgu9DYf0EbCr/a5iGL0rytQDobUcdpYOKokk8LEjVphnXkDkgpi0wYVsqrXuP0bZxJaTqdgoA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/available-typed-arrays": { + "version": "1.0.7", + "resolved": "https://registry.npmjs.org/available-typed-arrays/-/available-typed-arrays-1.0.7.tgz", + "integrity": "sha512-wvUjBtSGN7+7SjNpq/9M2Tg350UZD3q62IFZLbRAR1bSMlCo1ZaeW+BJ+D090e4hIIZLBcTDWe4Mh4jvUDajzQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "possible-typed-array-names": "^1.0.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/balanced-match": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.2.tgz", + "integrity": "sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw==", + "dev": true, + "license": "MIT" + }, + "node_modules/brace-expansion": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz", + "integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "balanced-match": "^1.0.0" + } + }, + "node_modules/braces": { + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.3.tgz", + "integrity": "sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA==", + "dev": true, + "license": "MIT", + "dependencies": { + "fill-range": "^7.1.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/browserslist": { + "version": "4.25.3", + "resolved": "https://registry.npmjs.org/browserslist/-/browserslist-4.25.3.tgz", + "integrity": "sha512-cDGv1kkDI4/0e5yON9yM5G/0A5u8sf5TnmdX5C9qHzI9PPu++sQ9zjm1k9NiOrf3riY4OkK0zSGqfvJyJsgCBQ==", + "dev": true, + "funding": [ + { + "type": "opencollective", + "url": "https://opencollective.com/browserslist" + }, + { + "type": "tidelift", + "url": "https://tidelift.com/funding/github/npm/browserslist" + }, + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ], + "license": "MIT", + "dependencies": { + "caniuse-lite": "^1.0.30001735", + "electron-to-chromium": "^1.5.204", + "node-releases": "^2.0.19", + "update-browserslist-db": "^1.1.3" + }, + "bin": { + "browserslist": "cli.js" + }, + "engines": { + "node": "^6 || ^7 || ^8 || ^9 || ^10 || ^11 || ^12 || >=13.7" + } + }, + "node_modules/buffer-from": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/buffer-from/-/buffer-from-1.1.2.tgz", + "integrity": "sha512-E+XQCRwSbaaiChtv6k6Dwgc+bx+Bs6vuKJHHl5kox/BaKbhiXzqQOwK4cO22yElGp2OCmjwVhT3HmxgyPGnJfQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/call-bind": { + "version": "1.0.8", + "resolved": "https://registry.npmjs.org/call-bind/-/call-bind-1.0.8.tgz", + "integrity": "sha512-oKlSFMcMwpUg2ednkhQ454wfWiU/ul3CkJe/PEHcTKuiX6RpbehUiFMXu13HalGZxfUwCQzZG747YXBn1im9ww==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind-apply-helpers": "^1.0.0", + "es-define-property": "^1.0.0", + "get-intrinsic": "^1.2.4", + "set-function-length": "^1.2.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/call-bind-apply-helpers": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/call-bind-apply-helpers/-/call-bind-apply-helpers-1.0.2.tgz", + "integrity": "sha512-Sp1ablJ0ivDkSzjcaJdxEunN5/XvksFJ2sMBFfq6x0ryhQV/2b/KwFe21cMpmHtPOSij8K99/wSfoEuTObmuMQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0", + "function-bind": "^1.1.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/call-bound": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/call-bound/-/call-bound-1.0.4.tgz", + "integrity": "sha512-+ys997U96po4Kx/ABpBCqhA9EuxJaQWDQg7295H4hBphv3IZg0boBKuwYpt4YXp6MZ5AmZQnU/tyMTlRpaSejg==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind-apply-helpers": "^1.0.2", + "get-intrinsic": "^1.3.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/callsites": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/callsites/-/callsites-3.1.0.tgz", + "integrity": "sha512-P8BjAsXvZS+VIDUI11hHCQEv74YT67YUi5JJFNWIqL235sBmjX4+qx9Muvls5ivyNENctx46xQLQ3aTuE7ssaQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6" + } + }, + "node_modules/caniuse-lite": { + "version": "1.0.30001737", + "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001737.tgz", + "integrity": "sha512-BiloLiXtQNrY5UyF0+1nSJLXUENuhka2pzy2Fx5pGxqavdrxSCW4U6Pn/PoG3Efspi2frRbHpBV2XsrPE6EDlw==", + "dev": true, + "funding": [ + { + "type": "opencollective", + "url": "https://opencollective.com/browserslist" + }, + { + "type": "tidelift", + "url": "https://tidelift.com/funding/github/npm/caniuse-lite" + }, + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ], + "license": "CC-BY-4.0" + }, + "node_modules/chalk": { + "version": "4.1.2", + "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", + "integrity": "sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==", + "dev": true, + "license": "MIT", + "dependencies": { + "ansi-styles": "^4.1.0", + "supports-color": "^7.1.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/chalk/chalk?sponsor=1" + } + }, + "node_modules/chrome-trace-event": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/chrome-trace-event/-/chrome-trace-event-1.0.4.tgz", + "integrity": "sha512-rNjApaLzuwaOTjCiT8lSDdGN1APCiqkChLMJxJPWLunPAt5fy8xgU9/jNOchV84wfIxrA0lRQB7oCT8jrn/wrQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6.0" + } + }, + "node_modules/clean-webpack-plugin": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/clean-webpack-plugin/-/clean-webpack-plugin-4.0.0.tgz", + "integrity": "sha512-WuWE1nyTNAyW5T7oNyys2EN0cfP2fdRxhxnIQWiAp0bMabPdHhoGxM8A6YL2GhqwgrPnnaemVE7nv5XJ2Fhh2w==", + "dev": true, + "license": "MIT", + "dependencies": { + "del": "^4.1.1" + }, + "engines": { + "node": ">=10.0.0" + }, + "peerDependencies": { + "webpack": ">=4.0.0 <6.0.0" + } + }, + "node_modules/color-convert": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", + "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "color-name": "~1.1.4" + }, + "engines": { + "node": ">=7.0.0" + } + }, + "node_modules/color-name": { + "version": "1.1.4", + "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", + "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", + "dev": true, + "license": "MIT" + }, + "node_modules/commander": { + "version": "12.1.0", + "resolved": "https://registry.npmjs.org/commander/-/commander-12.1.0.tgz", + "integrity": "sha512-Vw8qHK3bZM9y/P10u3Vib8o/DdkvA2OtPtZvD871QKjy74Wj1WSKFILMPRPSdUSx5RFK1arlJzEtA4PkFgnbuA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=18" + } + }, + "node_modules/concat-map": { + "version": "0.0.1", + "resolved": "https://registry.npmjs.org/concat-map/-/concat-map-0.0.1.tgz", + "integrity": "sha512-/Srv4dswyQNBfohGpz9o6Yb3Gz3SrUDqBH5rTuhGR7ahtlbYKnVxw2bCFMRljaA7EXHaXZ8wsHdodFvbkhKmqg==", + "dev": true, + "license": "MIT" + }, + "node_modules/core-util-is": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/core-util-is/-/core-util-is-1.0.3.tgz", + "integrity": "sha512-ZQBvi1DcpJ4GDqanjucZ2Hj3wEO5pZDS89BWbkcrvdxksJorwUDDZamX9ldFkp9aw2lmBDLgkObEA4DWNJ9FYQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/create-foxglove-extension": { + "version": "1.0.6", + "resolved": "https://registry.npmjs.org/create-foxglove-extension/-/create-foxglove-extension-1.0.6.tgz", + "integrity": "sha512-Gp0qOQ+nU6dkqgpQlEdqdYVL4PJtdG+HXnfw09npEJCGT9M+5KFLj9V6Xt07oV3rSO/vthoTKPLR6xAD/+nPZQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "clean-webpack-plugin": "4.0.0", + "commander": "12.1.0", + "jszip": "3.10.1", + "mkdirp": "3.0.1", + "ncp": "2.0.0", + "node-fetch": "2.7.0", + "path-browserify": "1.0.1", + "rimraf": "6.0.1", + "sanitize-filename": "1.6.3", + "ts-loader": "9.5.1", + "webpack": "5.96.1" + }, + "bin": { + "create-foxglove-extension": "dist/bin/create-foxglove-extension.js", + "foxglove-extension": "dist/bin/foxglove-extension.js" + }, + "engines": { + "node": ">= 14" + } + }, + "node_modules/cross-spawn": { + "version": "7.0.6", + "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.6.tgz", + "integrity": "sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA==", + "dev": true, + "license": "MIT", + "dependencies": { + "path-key": "^3.1.0", + "shebang-command": "^2.0.0", + "which": "^2.0.1" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/csstype": { + "version": "3.1.3", + "resolved": "https://registry.npmjs.org/csstype/-/csstype-3.1.3.tgz", + "integrity": "sha512-M1uQkMl8rQK/szD0LNhtqxIPLpimGm8sOBwU7lLnCpSbTyY3yeU1Vc7l4KT5zT4s/yOxHH5O7tIuuLOCnLADRw==", + "dev": true, + "license": "MIT" + }, + "node_modules/d3": { + "version": "7.9.0", + "resolved": "https://registry.npmjs.org/d3/-/d3-7.9.0.tgz", + "integrity": "sha512-e1U46jVP+w7Iut8Jt8ri1YsPOvFpg46k+K8TpCb0P+zjCkjkPnV7WzfDJzMHy1LnA+wj5pLT1wjO901gLXeEhA==", + "license": "ISC", + "dependencies": { + "d3-array": "3", + "d3-axis": "3", + "d3-brush": "3", + "d3-chord": "3", + "d3-color": "3", + "d3-contour": "4", + "d3-delaunay": "6", + "d3-dispatch": "3", + "d3-drag": "3", + "d3-dsv": "3", + "d3-ease": "3", + "d3-fetch": "3", + "d3-force": "3", + "d3-format": "3", + "d3-geo": "3", + "d3-hierarchy": "3", + "d3-interpolate": "3", + "d3-path": "3", + "d3-polygon": "3", + "d3-quadtree": "3", + "d3-random": "3", + "d3-scale": "4", + "d3-scale-chromatic": "3", + "d3-selection": "3", + "d3-shape": "3", + "d3-time": "3", + "d3-time-format": "4", + "d3-timer": "3", + "d3-transition": "3", + "d3-zoom": "3" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-array": { + "version": "3.2.4", + "resolved": "https://registry.npmjs.org/d3-array/-/d3-array-3.2.4.tgz", + "integrity": "sha512-tdQAmyA18i4J7wprpYq8ClcxZy3SC31QMeByyCFyRt7BVHdREQZ5lpzoe5mFEYZUWe+oq8HBvk9JjpibyEV4Jg==", + "license": "ISC", + "dependencies": { + "internmap": "1 - 2" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-axis": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/d3-axis/-/d3-axis-3.0.0.tgz", + "integrity": "sha512-IH5tgjV4jE/GhHkRV0HiVYPDtvfjHQlQfJHs0usq7M30XcSBvOotpmH1IgkcXsO/5gEQZD43B//fc7SRT5S+xw==", + "license": "ISC", + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-brush": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/d3-brush/-/d3-brush-3.0.0.tgz", + "integrity": "sha512-ALnjWlVYkXsVIGlOsuWH1+3udkYFI48Ljihfnh8FZPF2QS9o+PzGLBslO0PjzVoHLZ2KCVgAM8NVkXPJB2aNnQ==", + "license": "ISC", + "dependencies": { + "d3-dispatch": "1 - 3", + "d3-drag": "2 - 3", + "d3-interpolate": "1 - 3", + "d3-selection": "3", + "d3-transition": "3" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-chord": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/d3-chord/-/d3-chord-3.0.1.tgz", + "integrity": "sha512-VE5S6TNa+j8msksl7HwjxMHDM2yNK3XCkusIlpX5kwauBfXuyLAtNg9jCp/iHH61tgI4sb6R/EIMWCqEIdjT/g==", + "license": "ISC", + "dependencies": { + "d3-path": "1 - 3" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-color": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/d3-color/-/d3-color-3.1.0.tgz", + "integrity": "sha512-zg/chbXyeBtMQ1LbD/WSoW2DpC3I0mpmPdW+ynRTj/x2DAWYrIY7qeZIHidozwV24m4iavr15lNwIwLxRmOxhA==", + "license": "ISC", + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-contour": { + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/d3-contour/-/d3-contour-4.0.2.tgz", + "integrity": "sha512-4EzFTRIikzs47RGmdxbeUvLWtGedDUNkTcmzoeyg4sP/dvCexO47AaQL7VKy/gul85TOxw+IBgA8US2xwbToNA==", + "license": "ISC", + "dependencies": { + "d3-array": "^3.2.0" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-delaunay": { + "version": "6.0.4", + "resolved": "https://registry.npmjs.org/d3-delaunay/-/d3-delaunay-6.0.4.tgz", + "integrity": "sha512-mdjtIZ1XLAM8bm/hx3WwjfHt6Sggek7qH043O8KEjDXN40xi3vx/6pYSVTwLjEgiXQTbvaouWKynLBiUZ6SK6A==", + "license": "ISC", + "dependencies": { + "delaunator": "5" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-dispatch": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/d3-dispatch/-/d3-dispatch-3.0.1.tgz", + "integrity": "sha512-rzUyPU/S7rwUflMyLc1ETDeBj0NRuHKKAcvukozwhshr6g6c5d8zh4c2gQjY2bZ0dXeGLWc1PF174P2tVvKhfg==", + "license": "ISC", + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-drag": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/d3-drag/-/d3-drag-3.0.0.tgz", + "integrity": "sha512-pWbUJLdETVA8lQNJecMxoXfH6x+mO2UQo8rSmZ+QqxcbyA3hfeprFgIT//HW2nlHChWeIIMwS2Fq+gEARkhTkg==", + "license": "ISC", + "dependencies": { + "d3-dispatch": "1 - 3", + "d3-selection": "3" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-dsv": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/d3-dsv/-/d3-dsv-3.0.1.tgz", + "integrity": "sha512-UG6OvdI5afDIFP9w4G0mNq50dSOsXHJaRE8arAS5o9ApWnIElp8GZw1Dun8vP8OyHOZ/QJUKUJwxiiCCnUwm+Q==", + "license": "ISC", + "dependencies": { + "commander": "7", + "iconv-lite": "0.6", + "rw": "1" + }, + "bin": { + "csv2json": "bin/dsv2json.js", + "csv2tsv": "bin/dsv2dsv.js", + "dsv2dsv": "bin/dsv2dsv.js", + "dsv2json": "bin/dsv2json.js", + "json2csv": "bin/json2dsv.js", + "json2dsv": "bin/json2dsv.js", + "json2tsv": "bin/json2dsv.js", + "tsv2csv": "bin/dsv2dsv.js", + "tsv2json": "bin/dsv2json.js" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-dsv/node_modules/commander": { + "version": "7.2.0", + "resolved": "https://registry.npmjs.org/commander/-/commander-7.2.0.tgz", + "integrity": "sha512-QrWXB+ZQSVPmIWIhtEO9H+gwHaMGYiF5ChvoJ+K9ZGHG/sVsa6yiesAD1GC/x46sET00Xlwo1u49RVVVzvcSkw==", + "license": "MIT", + "engines": { + "node": ">= 10" + } + }, + "node_modules/d3-ease": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/d3-ease/-/d3-ease-3.0.1.tgz", + "integrity": "sha512-wR/XK3D3XcLIZwpbvQwQ5fK+8Ykds1ip7A2Txe0yxncXSdq1L9skcG7blcedkOX+ZcgxGAmLX1FrRGbADwzi0w==", + "license": "BSD-3-Clause", + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-fetch": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/d3-fetch/-/d3-fetch-3.0.1.tgz", + "integrity": "sha512-kpkQIM20n3oLVBKGg6oHrUchHM3xODkTzjMoj7aWQFq5QEM+R6E4WkzT5+tojDY7yjez8KgCBRoj4aEr99Fdqw==", + "license": "ISC", + "dependencies": { + "d3-dsv": "1 - 3" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-force": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/d3-force/-/d3-force-3.0.0.tgz", + "integrity": "sha512-zxV/SsA+U4yte8051P4ECydjD/S+qeYtnaIyAs9tgHCqfguma/aAQDjo85A9Z6EKhBirHRJHXIgJUlffT4wdLg==", + "license": "ISC", + "dependencies": { + "d3-dispatch": "1 - 3", + "d3-quadtree": "1 - 3", + "d3-timer": "1 - 3" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-format": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/d3-format/-/d3-format-3.1.0.tgz", + "integrity": "sha512-YyUI6AEuY/Wpt8KWLgZHsIU86atmikuoOmCfommt0LYHiQSPjvX2AcFc38PX0CBpr2RCyZhjex+NS/LPOv6YqA==", + "license": "ISC", + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-geo": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/d3-geo/-/d3-geo-3.1.1.tgz", + "integrity": "sha512-637ln3gXKXOwhalDzinUgY83KzNWZRKbYubaG+fGVuc/dxO64RRljtCTnf5ecMyE1RIdtqpkVcq0IbtU2S8j2Q==", + "license": "ISC", + "dependencies": { + "d3-array": "2.5.0 - 3" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-hierarchy": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/d3-hierarchy/-/d3-hierarchy-3.1.2.tgz", + "integrity": "sha512-FX/9frcub54beBdugHjDCdikxThEqjnR93Qt7PvQTOHxyiNCAlvMrHhclk3cD5VeAaq9fxmfRp+CnWw9rEMBuA==", + "license": "ISC", + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-interpolate": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/d3-interpolate/-/d3-interpolate-3.0.1.tgz", + "integrity": "sha512-3bYs1rOD33uo8aqJfKP3JWPAibgw8Zm2+L9vBKEHJ2Rg+viTR7o5Mmv5mZcieN+FRYaAOWX5SJATX6k1PWz72g==", + "license": "ISC", + "dependencies": { + "d3-color": "1 - 3" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-path": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/d3-path/-/d3-path-3.1.0.tgz", + "integrity": "sha512-p3KP5HCf/bvjBSSKuXid6Zqijx7wIfNW+J/maPs+iwR35at5JCbLUT0LzF1cnjbCHWhqzQTIN2Jpe8pRebIEFQ==", + "license": "ISC", + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-polygon": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/d3-polygon/-/d3-polygon-3.0.1.tgz", + "integrity": "sha512-3vbA7vXYwfe1SYhED++fPUQlWSYTTGmFmQiany/gdbiWgU/iEyQzyymwL9SkJjFFuCS4902BSzewVGsHHmHtXg==", + "license": "ISC", + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-quadtree": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/d3-quadtree/-/d3-quadtree-3.0.1.tgz", + "integrity": "sha512-04xDrxQTDTCFwP5H6hRhsRcb9xxv2RzkcsygFzmkSIOJy3PeRJP7sNk3VRIbKXcog561P9oU0/rVH6vDROAgUw==", + "license": "ISC", + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-random": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/d3-random/-/d3-random-3.0.1.tgz", + "integrity": "sha512-FXMe9GfxTxqd5D6jFsQ+DJ8BJS4E/fT5mqqdjovykEB2oFbTMDVdg1MGFxfQW+FBOGoB++k8swBrgwSHT1cUXQ==", + "license": "ISC", + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-scale": { + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/d3-scale/-/d3-scale-4.0.2.tgz", + "integrity": "sha512-GZW464g1SH7ag3Y7hXjf8RoUuAFIqklOAq3MRl4OaWabTFJY9PN/E1YklhXLh+OQ3fM9yS2nOkCoS+WLZ6kvxQ==", + "license": "ISC", + "dependencies": { + "d3-array": "2.10.0 - 3", + "d3-format": "1 - 3", + "d3-interpolate": "1.2.0 - 3", + "d3-time": "2.1.1 - 3", + "d3-time-format": "2 - 4" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-scale-chromatic": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/d3-scale-chromatic/-/d3-scale-chromatic-3.1.0.tgz", + "integrity": "sha512-A3s5PWiZ9YCXFye1o246KoscMWqf8BsD9eRiJ3He7C9OBaxKhAd5TFCdEx/7VbKtxxTsu//1mMJFrEt572cEyQ==", + "license": "ISC", + "dependencies": { + "d3-color": "1 - 3", + "d3-interpolate": "1 - 3" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-selection": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/d3-selection/-/d3-selection-3.0.0.tgz", + "integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==", + "license": "ISC", + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-shape": { + "version": "3.2.0", + "resolved": "https://registry.npmjs.org/d3-shape/-/d3-shape-3.2.0.tgz", + "integrity": "sha512-SaLBuwGm3MOViRq2ABk3eLoxwZELpH6zhl3FbAoJ7Vm1gofKx6El1Ib5z23NUEhF9AsGl7y+dzLe5Cw2AArGTA==", + "license": "ISC", + "dependencies": { + "d3-path": "^3.1.0" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-time": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/d3-time/-/d3-time-3.1.0.tgz", + "integrity": "sha512-VqKjzBLejbSMT4IgbmVgDjpkYrNWUYJnbCGo874u7MMKIWsILRX+OpX/gTk8MqjpT1A/c6HY2dCA77ZN0lkQ2Q==", + "license": "ISC", + "dependencies": { + "d3-array": "2 - 3" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-time-format": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/d3-time-format/-/d3-time-format-4.1.0.tgz", + "integrity": "sha512-dJxPBlzC7NugB2PDLwo9Q8JiTR3M3e4/XANkreKSUxF8vvXKqm1Yfq4Q5dl8budlunRVlUUaDUgFt7eA8D6NLg==", + "license": "ISC", + "dependencies": { + "d3-time": "1 - 3" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-timer": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/d3-timer/-/d3-timer-3.0.1.tgz", + "integrity": "sha512-ndfJ/JxxMd3nw31uyKoY2naivF+r29V+Lc0svZxe1JvvIRmi8hUsrMvdOwgS1o6uBHmiz91geQ0ylPP0aj1VUA==", + "license": "ISC", + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-transition": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/d3-transition/-/d3-transition-3.0.1.tgz", + "integrity": "sha512-ApKvfjsSR6tg06xrL434C0WydLr7JewBB3V+/39RMHsaXTOG0zmt/OAXeng5M5LBm0ojmxJrpomQVZ1aPvBL4w==", + "license": "ISC", + "dependencies": { + "d3-color": "1 - 3", + "d3-dispatch": "1 - 3", + "d3-ease": "1 - 3", + "d3-interpolate": "1 - 3", + "d3-timer": "1 - 3" + }, + "engines": { + "node": ">=12" + }, + "peerDependencies": { + "d3-selection": "2 - 3" + } + }, + "node_modules/d3-zoom": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/d3-zoom/-/d3-zoom-3.0.0.tgz", + "integrity": "sha512-b8AmV3kfQaqWAuacbPuNbL6vahnOJflOhexLzMMNLga62+/nh0JzvJ0aO/5a5MVgUFGS7Hu1P9P03o3fJkDCyw==", + "license": "ISC", + "dependencies": { + "d3-dispatch": "1 - 3", + "d3-drag": "2 - 3", + "d3-interpolate": "1 - 3", + "d3-selection": "2 - 3", + "d3-transition": "2 - 3" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/data-view-buffer": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/data-view-buffer/-/data-view-buffer-1.0.2.tgz", + "integrity": "sha512-EmKO5V3OLXh1rtK2wgXRansaK1/mtVdTUEiEI0W8RkvgT05kfxaH29PliLnpLP73yYO6142Q72QNa8Wx/A5CqQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.3", + "es-errors": "^1.3.0", + "is-data-view": "^1.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/data-view-byte-length": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/data-view-byte-length/-/data-view-byte-length-1.0.2.tgz", + "integrity": "sha512-tuhGbE6CfTM9+5ANGf+oQb72Ky/0+s3xKUpHvShfiz2RxMFgFPjsXuRLBVMtvMs15awe45SRb83D6wH4ew6wlQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.3", + "es-errors": "^1.3.0", + "is-data-view": "^1.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/inspect-js" + } + }, + "node_modules/data-view-byte-offset": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/data-view-byte-offset/-/data-view-byte-offset-1.0.1.tgz", + "integrity": "sha512-BS8PfmtDGnrgYdOonGZQdLZslWIeCGFP9tpan0hi1Co2Zr2NKADsvGYA8XxuG/4UWgJ6Cjtv+YJnB6MM69QGlQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.2", + "es-errors": "^1.3.0", + "is-data-view": "^1.0.1" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/debug": { + "version": "4.4.1", + "resolved": "https://registry.npmjs.org/debug/-/debug-4.4.1.tgz", + "integrity": "sha512-KcKCqiftBJcZr++7ykoDIEwSa3XWowTfNPo92BYxjXiyYEVrUQh2aLyhxBCwww+heortUFxEJYcRzosstTEBYQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "ms": "^2.1.3" + }, + "engines": { + "node": ">=6.0" + }, + "peerDependenciesMeta": { + "supports-color": { + "optional": true + } + } + }, + "node_modules/deep-is": { + "version": "0.1.4", + "resolved": "https://registry.npmjs.org/deep-is/-/deep-is-0.1.4.tgz", + "integrity": "sha512-oIPzksmTg4/MriiaYGO+okXDT7ztn/w3Eptv/+gSIdMdKsJo0u4CfYNFJPy+4SKMuCqGw2wxnA+URMg3t8a/bQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/define-data-property": { + "version": "1.1.4", + "resolved": "https://registry.npmjs.org/define-data-property/-/define-data-property-1.1.4.tgz", + "integrity": "sha512-rBMvIzlpA8v6E+SJZoo++HAYqsLrkg7MSfIinMPFhmkorw7X+dOXVJQs+QT69zGkzMyfDnIMN2Wid1+NbL3T+A==", + "dev": true, + "license": "MIT", + "dependencies": { + "es-define-property": "^1.0.0", + "es-errors": "^1.3.0", + "gopd": "^1.0.1" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/define-properties": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/define-properties/-/define-properties-1.2.1.tgz", + "integrity": "sha512-8QmQKqEASLd5nx0U1B1okLElbUuuttJ/AnYmRXbbbGDWh6uS208EjD4Xqq/I9wK7u0v6O08XhTWnt5XtEbR6Dg==", + "dev": true, + "license": "MIT", + "dependencies": { + "define-data-property": "^1.0.1", + "has-property-descriptors": "^1.0.0", + "object-keys": "^1.1.1" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/del": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/del/-/del-4.1.1.tgz", + "integrity": "sha512-QwGuEUouP2kVwQenAsOof5Fv8K9t3D8Ca8NxcXKrIpEHjTXK5J2nXLdP+ALI1cgv8wj7KuwBhTwBkOZSJKM5XQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/glob": "^7.1.1", + "globby": "^6.1.0", + "is-path-cwd": "^2.0.0", + "is-path-in-cwd": "^2.0.0", + "p-map": "^2.0.0", + "pify": "^4.0.1", + "rimraf": "^2.6.3" + }, + "engines": { + "node": ">=6" + } + }, + "node_modules/del/node_modules/rimraf": { + "version": "2.7.1", + "resolved": "https://registry.npmjs.org/rimraf/-/rimraf-2.7.1.tgz", + "integrity": "sha512-uWjbaKIK3T1OSVptzX7Nl6PvQ3qAGtKEtVRjRuazjfL3Bx5eI409VZSqgND+4UNnmzLVdPj9FqFJNPqBZFve4w==", + "deprecated": "Rimraf versions prior to v4 are no longer supported", + "dev": true, + "license": "ISC", + "dependencies": { + "glob": "^7.1.3" + }, + "bin": { + "rimraf": "bin.js" + } + }, + "node_modules/delaunator": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/delaunator/-/delaunator-5.0.1.tgz", + "integrity": "sha512-8nvh+XBe96aCESrGOqMp/84b13H9cdKbG5P2ejQCh4d4sK9RL4371qou9drQjMhvnPmhWl5hnmqbEE0fXr9Xnw==", + "license": "ISC", + "dependencies": { + "robust-predicates": "^3.0.2" + } + }, + "node_modules/doctrine": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/doctrine/-/doctrine-2.1.0.tgz", + "integrity": "sha512-35mSku4ZXK0vfCuHEDAwt55dg2jNajHZ1odvF+8SSr82EsZY4QmXfuWso8oEd8zRhVObSN18aM0CjSdoBX7zIw==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "esutils": "^2.0.2" + }, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/dunder-proto": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/dunder-proto/-/dunder-proto-1.0.1.tgz", + "integrity": "sha512-KIN/nDJBQRcXw0MLVhZE9iQHmG68qAVIBg9CqmUYjmQIhgij9U5MFvrqkUL5FbtyyzZuOeOt0zdeRe4UY7ct+A==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind-apply-helpers": "^1.0.1", + "es-errors": "^1.3.0", + "gopd": "^1.2.0" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/eastasianwidth": { + "version": "0.2.0", + "resolved": "https://registry.npmjs.org/eastasianwidth/-/eastasianwidth-0.2.0.tgz", + "integrity": "sha512-I88TYZWc9XiYHRQ4/3c5rjjfgkjhLyW2luGIheGERbNQ6OY7yTybanSpDXZa8y7VUP9YmDcYa+eyq4ca7iLqWA==", + "dev": true, + "license": "MIT" + }, + "node_modules/electron-to-chromium": { + "version": "1.5.208", + "resolved": "https://registry.npmjs.org/electron-to-chromium/-/electron-to-chromium-1.5.208.tgz", + "integrity": "sha512-ozZyibehoe7tOhNaf16lKmljVf+3npZcJIEbJRVftVsmAg5TeA1mGS9dVCZzOwr2xT7xK15V0p7+GZqSPgkuPg==", + "dev": true, + "license": "ISC" + }, + "node_modules/emoji-regex": { + "version": "9.2.2", + "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-9.2.2.tgz", + "integrity": "sha512-L18DaJsXSUk2+42pv8mLs5jJT2hqFkFE4j21wOmgbUqsZ2hL72NsUU785g9RXgo3s0ZNgVl42TiHp3ZtOv/Vyg==", + "dev": true, + "license": "MIT" + }, + "node_modules/engine.io-client": { + "version": "6.6.3", + "resolved": "https://registry.npmjs.org/engine.io-client/-/engine.io-client-6.6.3.tgz", + "integrity": "sha512-T0iLjnyNWahNyv/lcjS2y4oE358tVS/SYQNxYXGAJ9/GLgH4VCvOQ/mhTjqU88mLZCQgiG8RIegFHYCdVC+j5w==", + "license": "MIT", + "dependencies": { + "@socket.io/component-emitter": "~3.1.0", + "debug": "~4.3.1", + "engine.io-parser": "~5.2.1", + "ws": "~8.17.1", + "xmlhttprequest-ssl": "~2.1.1" + } + }, + "node_modules/engine.io-client/node_modules/debug": { + "version": "4.3.7", + "resolved": "https://registry.npmjs.org/debug/-/debug-4.3.7.tgz", + "integrity": "sha512-Er2nc/H7RrMXZBFCEim6TCmMk02Z8vLC2Rbi1KEBggpo0fS6l0S1nnapwmIi3yW/+GOJap1Krg4w0Hg80oCqgQ==", + "license": "MIT", + "dependencies": { + "ms": "^2.1.3" + }, + "engines": { + "node": ">=6.0" + }, + "peerDependenciesMeta": { + "supports-color": { + "optional": true + } + } + }, + "node_modules/engine.io-parser": { + "version": "5.2.3", + "resolved": "https://registry.npmjs.org/engine.io-parser/-/engine.io-parser-5.2.3.tgz", + "integrity": "sha512-HqD3yTBfnBxIrbnM1DoD6Pcq8NECnh8d4As1Qgh0z5Gg3jRRIqijury0CL3ghu/edArpUYiYqQiDUQBIs4np3Q==", + "license": "MIT", + "engines": { + "node": ">=10.0.0" + } + }, + "node_modules/enhanced-resolve": { + "version": "5.18.3", + "resolved": "https://registry.npmjs.org/enhanced-resolve/-/enhanced-resolve-5.18.3.tgz", + "integrity": "sha512-d4lC8xfavMeBjzGr2vECC3fsGXziXZQyJxD868h2M/mBI3PwAuODxAkLkq5HYuvrPYcUtiLzsTo8U3PgX3Ocww==", + "dev": true, + "license": "MIT", + "dependencies": { + "graceful-fs": "^4.2.4", + "tapable": "^2.2.0" + }, + "engines": { + "node": ">=10.13.0" + } + }, + "node_modules/es-abstract": { + "version": "1.24.0", + "resolved": "https://registry.npmjs.org/es-abstract/-/es-abstract-1.24.0.tgz", + "integrity": "sha512-WSzPgsdLtTcQwm4CROfS5ju2Wa1QQcVeT37jFjYzdFz1r9ahadC8B8/a4qxJxM+09F18iumCdRmlr96ZYkQvEg==", + "dev": true, + "license": "MIT", + "dependencies": { + "array-buffer-byte-length": "^1.0.2", + "arraybuffer.prototype.slice": "^1.0.4", + "available-typed-arrays": "^1.0.7", + "call-bind": "^1.0.8", + "call-bound": "^1.0.4", + "data-view-buffer": "^1.0.2", + "data-view-byte-length": "^1.0.2", + "data-view-byte-offset": "^1.0.1", + "es-define-property": "^1.0.1", + "es-errors": "^1.3.0", + "es-object-atoms": "^1.1.1", + "es-set-tostringtag": "^2.1.0", + "es-to-primitive": "^1.3.0", + "function.prototype.name": "^1.1.8", + "get-intrinsic": "^1.3.0", + "get-proto": "^1.0.1", + "get-symbol-description": "^1.1.0", + "globalthis": "^1.0.4", + "gopd": "^1.2.0", + "has-property-descriptors": "^1.0.2", + "has-proto": "^1.2.0", + "has-symbols": "^1.1.0", + "hasown": "^2.0.2", + "internal-slot": "^1.1.0", + "is-array-buffer": "^3.0.5", + "is-callable": "^1.2.7", + "is-data-view": "^1.0.2", + "is-negative-zero": "^2.0.3", + "is-regex": "^1.2.1", + "is-set": "^2.0.3", + "is-shared-array-buffer": "^1.0.4", + "is-string": "^1.1.1", + "is-typed-array": "^1.1.15", + "is-weakref": "^1.1.1", + "math-intrinsics": "^1.1.0", + "object-inspect": "^1.13.4", + "object-keys": "^1.1.1", + "object.assign": "^4.1.7", + "own-keys": "^1.0.1", + "regexp.prototype.flags": "^1.5.4", + "safe-array-concat": "^1.1.3", + "safe-push-apply": "^1.0.0", + "safe-regex-test": "^1.1.0", + "set-proto": "^1.0.0", + "stop-iteration-iterator": "^1.1.0", + "string.prototype.trim": "^1.2.10", + "string.prototype.trimend": "^1.0.9", + "string.prototype.trimstart": "^1.0.8", + "typed-array-buffer": "^1.0.3", + "typed-array-byte-length": "^1.0.3", + "typed-array-byte-offset": "^1.0.4", + "typed-array-length": "^1.0.7", + "unbox-primitive": "^1.1.0", + "which-typed-array": "^1.1.19" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/es-define-property": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/es-define-property/-/es-define-property-1.0.1.tgz", + "integrity": "sha512-e3nRfgfUZ4rNGL232gUgX06QNyyez04KdjFrF+LTRoOXmrOgFKDg4BCdsjW8EnT69eqdYGmRpJwiPVYNrCaW3g==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/es-errors": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/es-errors/-/es-errors-1.3.0.tgz", + "integrity": "sha512-Zf5H2Kxt2xjTvbJvP2ZWLEICxA6j+hAmMzIlypy4xcBg1vKVnx89Wy0GbS+kf5cwCVFFzdCFh2XSCFNULS6csw==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/es-iterator-helpers": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/es-iterator-helpers/-/es-iterator-helpers-1.2.1.tgz", + "integrity": "sha512-uDn+FE1yrDzyC0pCo961B2IHbdM8y/ACZsKD4dG6WqrjV53BADjwa7D+1aom2rsNVfLyDgU/eigvlJGJ08OQ4w==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.8", + "call-bound": "^1.0.3", + "define-properties": "^1.2.1", + "es-abstract": "^1.23.6", + "es-errors": "^1.3.0", + "es-set-tostringtag": "^2.0.3", + "function-bind": "^1.1.2", + "get-intrinsic": "^1.2.6", + "globalthis": "^1.0.4", + "gopd": "^1.2.0", + "has-property-descriptors": "^1.0.2", + "has-proto": "^1.2.0", + "has-symbols": "^1.1.0", + "internal-slot": "^1.1.0", + "iterator.prototype": "^1.1.4", + "safe-array-concat": "^1.1.3" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/es-module-lexer": { + "version": "1.7.0", + "resolved": "https://registry.npmjs.org/es-module-lexer/-/es-module-lexer-1.7.0.tgz", + "integrity": "sha512-jEQoCwk8hyb2AZziIOLhDqpm5+2ww5uIE6lkO/6jcOCusfk6LhMHpXXfBLXTZ7Ydyt0j4VoUQv6uGNYbdW+kBA==", + "dev": true, + "license": "MIT" + }, + "node_modules/es-object-atoms": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/es-object-atoms/-/es-object-atoms-1.1.1.tgz", + "integrity": "sha512-FGgH2h8zKNim9ljj7dankFPcICIK9Cp5bm+c2gQSYePhpaG5+esrLODihIorn+Pe6FGJzWhXQotPv73jTaldXA==", + "dev": true, + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/es-set-tostringtag": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/es-set-tostringtag/-/es-set-tostringtag-2.1.0.tgz", + "integrity": "sha512-j6vWzfrGVfyXxge+O0x5sh6cvxAog0a/4Rdd2K36zCMV5eJ+/+tOAngRO8cODMNWbVRdVlmGZQL2YS3yR8bIUA==", + "dev": true, + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0", + "get-intrinsic": "^1.2.6", + "has-tostringtag": "^1.0.2", + "hasown": "^2.0.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/es-shim-unscopables": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/es-shim-unscopables/-/es-shim-unscopables-1.1.0.tgz", + "integrity": "sha512-d9T8ucsEhh8Bi1woXCf+TIKDIROLG5WCkxg8geBCbvk22kzwC5G2OnXVMO6FUsvQlgUUXQ2itephWDLqDzbeCw==", + "dev": true, + "license": "MIT", + "dependencies": { + "hasown": "^2.0.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/es-to-primitive": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/es-to-primitive/-/es-to-primitive-1.3.0.tgz", + "integrity": "sha512-w+5mJ3GuFL+NjVtJlvydShqE1eN3h3PbI7/5LAsYJP/2qtuMXjfL2LpHSRqo4b4eSF5K/DH1JXKUAHSB2UW50g==", + "dev": true, + "license": "MIT", + "dependencies": { + "is-callable": "^1.2.7", + "is-date-object": "^1.0.5", + "is-symbol": "^1.0.4" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/escalade": { + "version": "3.2.0", + "resolved": "https://registry.npmjs.org/escalade/-/escalade-3.2.0.tgz", + "integrity": "sha512-WUj2qlxaQtO4g6Pq5c29GTcWGDyd8itL8zTlipgECz3JesAiiOKotd8JU6otB3PACgG6xkJUyVhboMS+bje/jA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6" + } + }, + "node_modules/escape-string-regexp": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/escape-string-regexp/-/escape-string-regexp-4.0.0.tgz", + "integrity": "sha512-TtpcNJ3XAzx3Gq8sWRzJaVajRs0uVxA2YAkdb1jm2YkPz4G6egUFAyA3n5vtEIZefPk5Wa4UXbKuS5fKkJWdgA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/eslint": { + "version": "9.34.0", + "resolved": "https://registry.npmjs.org/eslint/-/eslint-9.34.0.tgz", + "integrity": "sha512-RNCHRX5EwdrESy3Jc9o8ie8Bog+PeYvvSR8sDGoZxNFTvZ4dlxUB3WzQ3bQMztFrSRODGrLLj8g6OFuGY/aiQg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@eslint-community/eslint-utils": "^4.2.0", + "@eslint-community/regexpp": "^4.12.1", + "@eslint/config-array": "^0.21.0", + "@eslint/config-helpers": "^0.3.1", + "@eslint/core": "^0.15.2", + "@eslint/eslintrc": "^3.3.1", + "@eslint/js": "9.34.0", + "@eslint/plugin-kit": "^0.3.5", + "@humanfs/node": "^0.16.6", + "@humanwhocodes/module-importer": "^1.0.1", + "@humanwhocodes/retry": "^0.4.2", + "@types/estree": "^1.0.6", + "@types/json-schema": "^7.0.15", + "ajv": "^6.12.4", + "chalk": "^4.0.0", + "cross-spawn": "^7.0.6", + "debug": "^4.3.2", + "escape-string-regexp": "^4.0.0", + "eslint-scope": "^8.4.0", + "eslint-visitor-keys": "^4.2.1", + "espree": "^10.4.0", + "esquery": "^1.5.0", + "esutils": "^2.0.2", + "fast-deep-equal": "^3.1.3", + "file-entry-cache": "^8.0.0", + "find-up": "^5.0.0", + "glob-parent": "^6.0.2", + "ignore": "^5.2.0", + "imurmurhash": "^0.1.4", + "is-glob": "^4.0.0", + "json-stable-stringify-without-jsonify": "^1.0.1", + "lodash.merge": "^4.6.2", + "minimatch": "^3.1.2", + "natural-compare": "^1.4.0", + "optionator": "^0.9.3" + }, + "bin": { + "eslint": "bin/eslint.js" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "url": "https://eslint.org/donate" + }, + "peerDependencies": { + "jiti": "*" + }, + "peerDependenciesMeta": { + "jiti": { + "optional": true + } + } + }, + "node_modules/eslint-config-prettier": { + "version": "9.1.2", + "resolved": "https://registry.npmjs.org/eslint-config-prettier/-/eslint-config-prettier-9.1.2.tgz", + "integrity": "sha512-iI1f+D2ViGn+uvv5HuHVUamg8ll4tN+JRHGc6IJi4TP9Kl976C57fzPXgseXNs8v0iA8aSJpHsTWjDb9QJamGQ==", + "dev": true, + "license": "MIT", + "bin": { + "eslint-config-prettier": "bin/cli.js" + }, + "peerDependencies": { + "eslint": ">=7.0.0" + } + }, + "node_modules/eslint-import-resolver-node": { + "version": "0.3.9", + "resolved": "https://registry.npmjs.org/eslint-import-resolver-node/-/eslint-import-resolver-node-0.3.9.tgz", + "integrity": "sha512-WFj2isz22JahUv+B788TlO3N6zL3nNJGU8CcZbPZvVEkBPaJdCV4vy5wyghty5ROFbCRnm132v8BScu5/1BQ8g==", + "dev": true, + "license": "MIT", + "dependencies": { + "debug": "^3.2.7", + "is-core-module": "^2.13.0", + "resolve": "^1.22.4" + } + }, + "node_modules/eslint-import-resolver-node/node_modules/debug": { + "version": "3.2.7", + "resolved": "https://registry.npmjs.org/debug/-/debug-3.2.7.tgz", + "integrity": "sha512-CFjzYYAi4ThfiQvizrFQevTTXHtnCqWfe7x1AhgEscTz6ZbLbfoLRLPugTQyBth6f8ZERVUSyWHFD/7Wu4t1XQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "ms": "^2.1.1" + } + }, + "node_modules/eslint-module-utils": { + "version": "2.12.1", + "resolved": "https://registry.npmjs.org/eslint-module-utils/-/eslint-module-utils-2.12.1.tgz", + "integrity": "sha512-L8jSWTze7K2mTg0vos/RuLRS5soomksDPoJLXIslC7c8Wmut3bx7CPpJijDcBZtxQ5lrbUdM+s0OlNbz0DCDNw==", + "dev": true, + "license": "MIT", + "dependencies": { + "debug": "^3.2.7" + }, + "engines": { + "node": ">=4" + }, + "peerDependenciesMeta": { + "eslint": { + "optional": true + } + } + }, + "node_modules/eslint-module-utils/node_modules/debug": { + "version": "3.2.7", + "resolved": "https://registry.npmjs.org/debug/-/debug-3.2.7.tgz", + "integrity": "sha512-CFjzYYAi4ThfiQvizrFQevTTXHtnCqWfe7x1AhgEscTz6ZbLbfoLRLPugTQyBth6f8ZERVUSyWHFD/7Wu4t1XQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "ms": "^2.1.1" + } + }, + "node_modules/eslint-plugin-es": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/eslint-plugin-es/-/eslint-plugin-es-4.1.0.tgz", + "integrity": "sha512-GILhQTnjYE2WorX5Jyi5i4dz5ALWxBIdQECVQavL6s7cI76IZTDWleTHkxz/QT3kvcs2QlGHvKLYsSlPOlPXnQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "eslint-utils": "^2.0.0", + "regexpp": "^3.0.0" + }, + "engines": { + "node": ">=8.10.0" + }, + "funding": { + "url": "https://github.com/sponsors/mysticatea" + }, + "peerDependencies": { + "eslint": ">=4.19.1" + } + }, + "node_modules/eslint-plugin-filenames": { + "version": "1.3.2", + "resolved": "https://registry.npmjs.org/eslint-plugin-filenames/-/eslint-plugin-filenames-1.3.2.tgz", + "integrity": "sha512-tqxJTiEM5a0JmRCUYQmxw23vtTxrb2+a3Q2mMOPhFxvt7ZQQJmdiuMby9B/vUAuVMghyP7oET+nIf6EO6CBd/w==", + "dev": true, + "license": "MIT", + "dependencies": { + "lodash.camelcase": "4.3.0", + "lodash.kebabcase": "4.1.1", + "lodash.snakecase": "4.1.1", + "lodash.upperfirst": "4.3.1" + }, + "peerDependencies": { + "eslint": "*" + } + }, + "node_modules/eslint-plugin-import": { + "version": "2.32.0", + "resolved": "https://registry.npmjs.org/eslint-plugin-import/-/eslint-plugin-import-2.32.0.tgz", + "integrity": "sha512-whOE1HFo/qJDyX4SnXzP4N6zOWn79WhnCUY/iDR0mPfQZO8wcYE4JClzI2oZrhBnnMUCBCHZhO6VQyoBU95mZA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@rtsao/scc": "^1.1.0", + "array-includes": "^3.1.9", + "array.prototype.findlastindex": "^1.2.6", + "array.prototype.flat": "^1.3.3", + "array.prototype.flatmap": "^1.3.3", + "debug": "^3.2.7", + "doctrine": "^2.1.0", + "eslint-import-resolver-node": "^0.3.9", + "eslint-module-utils": "^2.12.1", + "hasown": "^2.0.2", + "is-core-module": "^2.16.1", + "is-glob": "^4.0.3", + "minimatch": "^3.1.2", + "object.fromentries": "^2.0.8", + "object.groupby": "^1.0.3", + "object.values": "^1.2.1", + "semver": "^6.3.1", + "string.prototype.trimend": "^1.0.9", + "tsconfig-paths": "^3.15.0" + }, + "engines": { + "node": ">=4" + }, + "peerDependencies": { + "eslint": "^2 || ^3 || ^4 || ^5 || ^6 || ^7.2.0 || ^8 || ^9" + } + }, + "node_modules/eslint-plugin-import/node_modules/brace-expansion": { + "version": "1.1.12", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.12.tgz", + "integrity": "sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==", + "dev": true, + "license": "MIT", + "dependencies": { + "balanced-match": "^1.0.0", + "concat-map": "0.0.1" + } + }, + "node_modules/eslint-plugin-import/node_modules/debug": { + "version": "3.2.7", + "resolved": "https://registry.npmjs.org/debug/-/debug-3.2.7.tgz", + "integrity": "sha512-CFjzYYAi4ThfiQvizrFQevTTXHtnCqWfe7x1AhgEscTz6ZbLbfoLRLPugTQyBth6f8ZERVUSyWHFD/7Wu4t1XQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "ms": "^2.1.1" + } + }, + "node_modules/eslint-plugin-import/node_modules/minimatch": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", + "integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", + "dev": true, + "license": "ISC", + "dependencies": { + "brace-expansion": "^1.1.7" + }, + "engines": { + "node": "*" + } + }, + "node_modules/eslint-plugin-import/node_modules/semver": { + "version": "6.3.1", + "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.1.tgz", + "integrity": "sha512-BR7VvDCVHO+q2xBEWskxS6DJE1qRnb7DxzUrogb71CWoSficBxYsiAGd+Kl0mmq/MprG9yArRkyrQxTO6XjMzA==", + "dev": true, + "license": "ISC", + "bin": { + "semver": "bin/semver.js" + } + }, + "node_modules/eslint-plugin-jest": { + "version": "28.14.0", + "resolved": "https://registry.npmjs.org/eslint-plugin-jest/-/eslint-plugin-jest-28.14.0.tgz", + "integrity": "sha512-P9s/qXSMTpRTerE2FQ0qJet2gKbcGyFTPAJipoKxmWqR6uuFqIqk8FuEfg5yBieOezVrEfAMZrEwJ6yEp+1MFQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@typescript-eslint/utils": "^6.0.0 || ^7.0.0 || ^8.0.0" + }, + "engines": { + "node": "^16.10.0 || ^18.12.0 || >=20.0.0" + }, + "peerDependencies": { + "@typescript-eslint/eslint-plugin": "^6.0.0 || ^7.0.0 || ^8.0.0", + "eslint": "^7.0.0 || ^8.0.0 || ^9.0.0", + "jest": "*" + }, + "peerDependenciesMeta": { + "@typescript-eslint/eslint-plugin": { + "optional": true + }, + "jest": { + "optional": true + } + } + }, + "node_modules/eslint-plugin-prettier": { + "version": "5.5.4", + "resolved": "https://registry.npmjs.org/eslint-plugin-prettier/-/eslint-plugin-prettier-5.5.4.tgz", + "integrity": "sha512-swNtI95SToIz05YINMA6Ox5R057IMAmWZ26GqPxusAp1TZzj+IdY9tXNWWD3vkF/wEqydCONcwjTFpxybBqZsg==", + "dev": true, + "license": "MIT", + "dependencies": { + "prettier-linter-helpers": "^1.0.0", + "synckit": "^0.11.7" + }, + "engines": { + "node": "^14.18.0 || >=16.0.0" + }, + "funding": { + "url": "https://opencollective.com/eslint-plugin-prettier" + }, + "peerDependencies": { + "@types/eslint": ">=8.0.0", + "eslint": ">=8.0.0", + "eslint-config-prettier": ">= 7.0.0 <10.0.0 || >=10.1.0", + "prettier": ">=3.0.0" + }, + "peerDependenciesMeta": { + "@types/eslint": { + "optional": true + }, + "eslint-config-prettier": { + "optional": true + } + } + }, + "node_modules/eslint-plugin-react": { + "version": "7.37.5", + "resolved": "https://registry.npmjs.org/eslint-plugin-react/-/eslint-plugin-react-7.37.5.tgz", + "integrity": "sha512-Qteup0SqU15kdocexFNAJMvCJEfa2xUKNV4CC1xsVMrIIqEy3SQ/rqyxCWNzfrd3/ldy6HMlD2e0JDVpDg2qIA==", + "dev": true, + "license": "MIT", + "dependencies": { + "array-includes": "^3.1.8", + "array.prototype.findlast": "^1.2.5", + "array.prototype.flatmap": "^1.3.3", + "array.prototype.tosorted": "^1.1.4", + "doctrine": "^2.1.0", + "es-iterator-helpers": "^1.2.1", + "estraverse": "^5.3.0", + "hasown": "^2.0.2", + "jsx-ast-utils": "^2.4.1 || ^3.0.0", + "minimatch": "^3.1.2", + "object.entries": "^1.1.9", + "object.fromentries": "^2.0.8", + "object.values": "^1.2.1", + "prop-types": "^15.8.1", + "resolve": "^2.0.0-next.5", + "semver": "^6.3.1", + "string.prototype.matchall": "^4.0.12", + "string.prototype.repeat": "^1.0.0" + }, + "engines": { + "node": ">=4" + }, + "peerDependencies": { + "eslint": "^3 || ^4 || ^5 || ^6 || ^7 || ^8 || ^9.7" + } + }, + "node_modules/eslint-plugin-react-hooks": { + "version": "5.2.0", + "resolved": "https://registry.npmjs.org/eslint-plugin-react-hooks/-/eslint-plugin-react-hooks-5.2.0.tgz", + "integrity": "sha512-+f15FfK64YQwZdJNELETdn5ibXEUQmW1DZL6KXhNnc2heoy/sg9VJJeT7n8TlMWouzWqSWavFkIhHyIbIAEapg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=10" + }, + "peerDependencies": { + "eslint": "^3.0.0 || ^4.0.0 || ^5.0.0 || ^6.0.0 || ^7.0.0 || ^8.0.0-0 || ^9.0.0" + } + }, + "node_modules/eslint-plugin-react/node_modules/brace-expansion": { + "version": "1.1.12", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.12.tgz", + "integrity": "sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==", + "dev": true, + "license": "MIT", + "dependencies": { + "balanced-match": "^1.0.0", + "concat-map": "0.0.1" + } + }, + "node_modules/eslint-plugin-react/node_modules/minimatch": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", + "integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", + "dev": true, + "license": "ISC", + "dependencies": { + "brace-expansion": "^1.1.7" + }, + "engines": { + "node": "*" + } + }, + "node_modules/eslint-plugin-react/node_modules/resolve": { + "version": "2.0.0-next.5", + "resolved": "https://registry.npmjs.org/resolve/-/resolve-2.0.0-next.5.tgz", + "integrity": "sha512-U7WjGVG9sH8tvjW5SmGbQuui75FiyjAX72HX15DwBBwF9dNiQZRQAg9nnPhYy+TUnE0+VcrttuvNI8oSxZcocA==", + "dev": true, + "license": "MIT", + "dependencies": { + "is-core-module": "^2.13.0", + "path-parse": "^1.0.7", + "supports-preserve-symlinks-flag": "^1.0.0" + }, + "bin": { + "resolve": "bin/resolve" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/eslint-plugin-react/node_modules/semver": { + "version": "6.3.1", + "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.1.tgz", + "integrity": "sha512-BR7VvDCVHO+q2xBEWskxS6DJE1qRnb7DxzUrogb71CWoSficBxYsiAGd+Kl0mmq/MprG9yArRkyrQxTO6XjMzA==", + "dev": true, + "license": "ISC", + "bin": { + "semver": "bin/semver.js" + } + }, + "node_modules/eslint-scope": { + "version": "8.4.0", + "resolved": "https://registry.npmjs.org/eslint-scope/-/eslint-scope-8.4.0.tgz", + "integrity": "sha512-sNXOfKCn74rt8RICKMvJS7XKV/Xk9kA7DyJr8mJik3S7Cwgy3qlkkmyS2uQB3jiJg6VNdZd/pDBJu0nvG2NlTg==", + "dev": true, + "license": "BSD-2-Clause", + "dependencies": { + "esrecurse": "^4.3.0", + "estraverse": "^5.2.0" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "url": "https://opencollective.com/eslint" + } + }, + "node_modules/eslint-utils": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/eslint-utils/-/eslint-utils-2.1.0.tgz", + "integrity": "sha512-w94dQYoauyvlDc43XnGB8lU3Zt713vNChgt4EWwhXAP2XkBvndfxF0AgIqKOOasjPIPzj9JqgwkwbCYD0/V3Zg==", + "dev": true, + "license": "MIT", + "dependencies": { + "eslint-visitor-keys": "^1.1.0" + }, + "engines": { + "node": ">=6" + }, + "funding": { + "url": "https://github.com/sponsors/mysticatea" + } + }, + "node_modules/eslint-utils/node_modules/eslint-visitor-keys": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/eslint-visitor-keys/-/eslint-visitor-keys-1.3.0.tgz", + "integrity": "sha512-6J72N8UNa462wa/KFODt/PJ3IU60SDpC3QXC1Hjc1BXXpfL2C9R5+AU7jhe0F6GREqVMh4Juu+NY7xn+6dipUQ==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": ">=4" + } + }, + "node_modules/eslint-visitor-keys": { + "version": "3.4.3", + "resolved": "https://registry.npmjs.org/eslint-visitor-keys/-/eslint-visitor-keys-3.4.3.tgz", + "integrity": "sha512-wpc+LXeiyiisxPlEkUzU6svyS1frIO3Mgxj1fdy7Pm8Ygzguax2N3Fa/D/ag1WqbOprdI+uY6wMUl8/a2G+iag==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + }, + "funding": { + "url": "https://opencollective.com/eslint" + } + }, + "node_modules/eslint/node_modules/brace-expansion": { + "version": "1.1.12", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.12.tgz", + "integrity": "sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==", + "dev": true, + "license": "MIT", + "dependencies": { + "balanced-match": "^1.0.0", + "concat-map": "0.0.1" + } + }, + "node_modules/eslint/node_modules/eslint-visitor-keys": { + "version": "4.2.1", + "resolved": "https://registry.npmjs.org/eslint-visitor-keys/-/eslint-visitor-keys-4.2.1.tgz", + "integrity": "sha512-Uhdk5sfqcee/9H/rCOJikYz67o0a2Tw2hGRPOG2Y1R2dg7brRe1uG0yaNQDHu+TO/uQPF/5eCapvYSmHUjt7JQ==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "url": "https://opencollective.com/eslint" + } + }, + "node_modules/eslint/node_modules/minimatch": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", + "integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", + "dev": true, + "license": "ISC", + "dependencies": { + "brace-expansion": "^1.1.7" + }, + "engines": { + "node": "*" + } + }, + "node_modules/espree": { + "version": "10.4.0", + "resolved": "https://registry.npmjs.org/espree/-/espree-10.4.0.tgz", + "integrity": "sha512-j6PAQ2uUr79PZhBjP5C5fhl8e39FmRnOjsD5lGnWrFU8i2G776tBK7+nP8KuQUTTyAZUwfQqXAgrVH5MbH9CYQ==", + "dev": true, + "license": "BSD-2-Clause", + "dependencies": { + "acorn": "^8.15.0", + "acorn-jsx": "^5.3.2", + "eslint-visitor-keys": "^4.2.1" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "url": "https://opencollective.com/eslint" + } + }, + "node_modules/espree/node_modules/eslint-visitor-keys": { + "version": "4.2.1", + "resolved": "https://registry.npmjs.org/eslint-visitor-keys/-/eslint-visitor-keys-4.2.1.tgz", + "integrity": "sha512-Uhdk5sfqcee/9H/rCOJikYz67o0a2Tw2hGRPOG2Y1R2dg7brRe1uG0yaNQDHu+TO/uQPF/5eCapvYSmHUjt7JQ==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "url": "https://opencollective.com/eslint" + } + }, + "node_modules/esquery": { + "version": "1.6.0", + "resolved": "https://registry.npmjs.org/esquery/-/esquery-1.6.0.tgz", + "integrity": "sha512-ca9pw9fomFcKPvFLXhBKUK90ZvGibiGOvRJNbjljY7s7uq/5YO4BOzcYtJqExdx99rF6aAcnRxHmcUHcz6sQsg==", + "dev": true, + "license": "BSD-3-Clause", + "dependencies": { + "estraverse": "^5.1.0" + }, + "engines": { + "node": ">=0.10" + } + }, + "node_modules/esrecurse": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/esrecurse/-/esrecurse-4.3.0.tgz", + "integrity": "sha512-KmfKL3b6G+RXvP8N1vr3Tq1kL/oCFgn2NYXEtqP8/L3pKapUA4G8cFVaoF3SU323CD4XypR/ffioHmkti6/Tag==", + "dev": true, + "license": "BSD-2-Clause", + "dependencies": { + "estraverse": "^5.2.0" + }, + "engines": { + "node": ">=4.0" + } + }, + "node_modules/estraverse": { + "version": "5.3.0", + "resolved": "https://registry.npmjs.org/estraverse/-/estraverse-5.3.0.tgz", + "integrity": "sha512-MMdARuVEQziNTeJD8DgMqmhwR11BRQ/cBP+pLtYdSTnf3MIO8fFeiINEbX36ZdNlfU/7A9f3gUw49B3oQsvwBA==", + "dev": true, + "license": "BSD-2-Clause", + "engines": { + "node": ">=4.0" + } + }, + "node_modules/esutils": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/esutils/-/esutils-2.0.3.tgz", + "integrity": "sha512-kVscqXk4OCp68SZ0dkgEKVi6/8ij300KBWTJq32P/dYeWTSwK41WyTxalN1eRmA5Z9UU/LX9D7FWSmV9SAYx6g==", + "dev": true, + "license": "BSD-2-Clause", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/events": { + "version": "3.3.0", + "resolved": "https://registry.npmjs.org/events/-/events-3.3.0.tgz", + "integrity": "sha512-mQw+2fkQbALzQ7V0MY0IqdnXNOeTtP4r0lN9z7AAawCXgqea7bDii20AYrIBrFd/Hx0M2Ocz6S111CaFkUcb0Q==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.8.x" + } + }, + "node_modules/fast-deep-equal": { + "version": "3.1.3", + "resolved": "https://registry.npmjs.org/fast-deep-equal/-/fast-deep-equal-3.1.3.tgz", + "integrity": "sha512-f3qQ9oQy9j2AhBe/H9VC91wLmKBCCU/gDOnKNAYG5hswO7BLKj09Hc5HYNz9cGI++xlpDCIgDaitVs03ATR84Q==", + "dev": true, + "license": "MIT" + }, + "node_modules/fast-diff": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/fast-diff/-/fast-diff-1.3.0.tgz", + "integrity": "sha512-VxPP4NqbUjj6MaAOafWeUn2cXWLcCtljklUtZf0Ind4XQ+QPtmA0b18zZy0jIQx+ExRVCR/ZQpBmik5lXshNsw==", + "dev": true, + "license": "Apache-2.0" + }, + "node_modules/fast-glob": { + "version": "3.3.3", + "resolved": "https://registry.npmjs.org/fast-glob/-/fast-glob-3.3.3.tgz", + "integrity": "sha512-7MptL8U0cqcFdzIzwOTHoilX9x5BrNqye7Z/LuC7kCMRio1EMSyqRK3BEAUD7sXRq4iT4AzTVuZdhgQ2TCvYLg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@nodelib/fs.stat": "^2.0.2", + "@nodelib/fs.walk": "^1.2.3", + "glob-parent": "^5.1.2", + "merge2": "^1.3.0", + "micromatch": "^4.0.8" + }, + "engines": { + "node": ">=8.6.0" + } + }, + "node_modules/fast-glob/node_modules/glob-parent": { + "version": "5.1.2", + "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-5.1.2.tgz", + "integrity": "sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow==", + "dev": true, + "license": "ISC", + "dependencies": { + "is-glob": "^4.0.1" + }, + "engines": { + "node": ">= 6" + } + }, + "node_modules/fast-json-stable-stringify": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/fast-json-stable-stringify/-/fast-json-stable-stringify-2.1.0.tgz", + "integrity": "sha512-lhd/wF+Lk98HZoTCtlVraHtfh5XYijIjalXck7saUtuanSDyLMxnHhSXEDJqHxD7msR8D0uCmqlkwjCV8xvwHw==", + "dev": true, + "license": "MIT" + }, + "node_modules/fast-levenshtein": { + "version": "2.0.6", + "resolved": "https://registry.npmjs.org/fast-levenshtein/-/fast-levenshtein-2.0.6.tgz", + "integrity": "sha512-DCXu6Ifhqcks7TZKY3Hxp3y6qphY5SJZmrWMDrKcERSOXWQdMhU9Ig/PYrzyw/ul9jOIyh0N4M0tbC5hodg8dw==", + "dev": true, + "license": "MIT" + }, + "node_modules/fast-uri": { + "version": "3.0.6", + "resolved": "https://registry.npmjs.org/fast-uri/-/fast-uri-3.0.6.tgz", + "integrity": "sha512-Atfo14OibSv5wAp4VWNsFYE1AchQRTv9cBGWET4pZWHzYshFSS9NQI6I57rdKn9croWVMbYFbLhJ+yJvmZIIHw==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/fastify" + }, + { + "type": "opencollective", + "url": "https://opencollective.com/fastify" + } + ], + "license": "BSD-3-Clause" + }, + "node_modules/fastq": { + "version": "1.19.1", + "resolved": "https://registry.npmjs.org/fastq/-/fastq-1.19.1.tgz", + "integrity": "sha512-GwLTyxkCXjXbxqIhTsMI2Nui8huMPtnxg7krajPJAjnEG/iiOS7i+zCtWGZR9G0NBKbXKh6X9m9UIsYX/N6vvQ==", + "dev": true, + "license": "ISC", + "dependencies": { + "reusify": "^1.0.4" + } + }, + "node_modules/file-entry-cache": { + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/file-entry-cache/-/file-entry-cache-8.0.0.tgz", + "integrity": "sha512-XXTUwCvisa5oacNGRP9SfNtYBNAMi+RPwBFmblZEF7N7swHYQS6/Zfk7SRwx4D5j3CH211YNRco1DEMNVfZCnQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "flat-cache": "^4.0.0" + }, + "engines": { + "node": ">=16.0.0" + } + }, + "node_modules/fill-range": { + "version": "7.1.1", + "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.1.1.tgz", + "integrity": "sha512-YsGpe3WHLK8ZYi4tWDg2Jy3ebRz2rXowDxnld4bkQB00cc/1Zw9AWnC0i9ztDJitivtQvaI9KaLyKrc+hBW0yg==", + "dev": true, + "license": "MIT", + "dependencies": { + "to-regex-range": "^5.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/find-up": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/find-up/-/find-up-5.0.0.tgz", + "integrity": "sha512-78/PXT1wlLLDgTzDs7sjq9hzz0vXD+zn+7wypEe4fXQxCmdmqfGsEPQxmiCSQI3ajFV91bVSsvNtrJRiW6nGng==", + "dev": true, + "license": "MIT", + "dependencies": { + "locate-path": "^6.0.0", + "path-exists": "^4.0.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/flat-cache": { + "version": "4.0.1", + "resolved": "https://registry.npmjs.org/flat-cache/-/flat-cache-4.0.1.tgz", + "integrity": "sha512-f7ccFPK3SXFHpx15UIGyRJ/FJQctuKZ0zVuN3frBo4HnK3cay9VEW0R6yPYFHC0AgqhukPzKjq22t5DmAyqGyw==", + "dev": true, + "license": "MIT", + "dependencies": { + "flatted": "^3.2.9", + "keyv": "^4.5.4" + }, + "engines": { + "node": ">=16" + } + }, + "node_modules/flatted": { + "version": "3.3.3", + "resolved": "https://registry.npmjs.org/flatted/-/flatted-3.3.3.tgz", + "integrity": "sha512-GX+ysw4PBCz0PzosHDepZGANEuFCMLrnRTiEy9McGjmkCQYwRq4A/X786G/fjM/+OjsWSU1ZrY5qyARZmO/uwg==", + "dev": true, + "license": "ISC" + }, + "node_modules/for-each": { + "version": "0.3.5", + "resolved": "https://registry.npmjs.org/for-each/-/for-each-0.3.5.tgz", + "integrity": "sha512-dKx12eRCVIzqCxFGplyFKJMPvLEWgmNtUrpTiJIR5u97zEhRG8ySrtboPHZXx7daLxQVrl643cTzbab2tkQjxg==", + "dev": true, + "license": "MIT", + "dependencies": { + "is-callable": "^1.2.7" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/foreground-child": { + "version": "3.3.1", + "resolved": "https://registry.npmjs.org/foreground-child/-/foreground-child-3.3.1.tgz", + "integrity": "sha512-gIXjKqtFuWEgzFRJA9WCQeSJLZDjgJUOMCMzxtvFq/37KojM1BFGufqsCy0r4qSQmYLsZYMeyRqzIWOMup03sw==", + "dev": true, + "license": "ISC", + "dependencies": { + "cross-spawn": "^7.0.6", + "signal-exit": "^4.0.1" + }, + "engines": { + "node": ">=14" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/fs.realpath": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/fs.realpath/-/fs.realpath-1.0.0.tgz", + "integrity": "sha512-OO0pH2lK6a0hZnAdau5ItzHPI6pUlvI7jMVnxUQRtw4owF2wk8lOSabtGDCTP4Ggrg2MbGnWO9X8K1t4+fGMDw==", + "dev": true, + "license": "ISC" + }, + "node_modules/function-bind": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/function-bind/-/function-bind-1.1.2.tgz", + "integrity": "sha512-7XHNxH7qX9xG5mIwxkhumTox/MIRNcOgDrxWsMt2pAr23WHp6MrRlN7FBSFpCpr+oVO0F744iUgR82nJMfG2SA==", + "dev": true, + "license": "MIT", + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/function.prototype.name": { + "version": "1.1.8", + "resolved": "https://registry.npmjs.org/function.prototype.name/-/function.prototype.name-1.1.8.tgz", + "integrity": "sha512-e5iwyodOHhbMr/yNrc7fDYG4qlbIvI5gajyzPnb5TCwyhjApznQh1BMFou9b30SevY43gCJKXycoCBjMbsuW0Q==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.8", + "call-bound": "^1.0.3", + "define-properties": "^1.2.1", + "functions-have-names": "^1.2.3", + "hasown": "^2.0.2", + "is-callable": "^1.2.7" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/functions-have-names": { + "version": "1.2.3", + "resolved": "https://registry.npmjs.org/functions-have-names/-/functions-have-names-1.2.3.tgz", + "integrity": "sha512-xckBUXyTIqT97tq2x2AMb+g163b5JFysYk0x4qxNFwbfQkmNZoiRHb6sPzI9/QV33WeuvVYBUIiD4NzNIyqaRQ==", + "dev": true, + "license": "MIT", + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/get-intrinsic": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.3.0.tgz", + "integrity": "sha512-9fSjSaos/fRIVIp+xSJlE6lfwhES7LNtKaCBIamHsjr2na1BiABJPo0mOjjz8GJDURarmCPGqaiVg5mfjb98CQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind-apply-helpers": "^1.0.2", + "es-define-property": "^1.0.1", + "es-errors": "^1.3.0", + "es-object-atoms": "^1.1.1", + "function-bind": "^1.1.2", + "get-proto": "^1.0.1", + "gopd": "^1.2.0", + "has-symbols": "^1.1.0", + "hasown": "^2.0.2", + "math-intrinsics": "^1.1.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/get-proto": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/get-proto/-/get-proto-1.0.1.tgz", + "integrity": "sha512-sTSfBjoXBp89JvIKIefqw7U2CCebsc74kiY6awiGogKtoSGbgjYE/G/+l9sF3MWFPNc9IcoOC4ODfKHfxFmp0g==", + "dev": true, + "license": "MIT", + "dependencies": { + "dunder-proto": "^1.0.1", + "es-object-atoms": "^1.0.0" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/get-symbol-description": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/get-symbol-description/-/get-symbol-description-1.1.0.tgz", + "integrity": "sha512-w9UMqWwJxHNOvoNzSJ2oPF5wvYcvP7jUvYzhp67yEhTi17ZDBBC1z9pTdGuzjD+EFIqLSYRweZjqfiPzQ06Ebg==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.3", + "es-errors": "^1.3.0", + "get-intrinsic": "^1.2.6" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/glob": { + "version": "7.2.3", + "resolved": "https://registry.npmjs.org/glob/-/glob-7.2.3.tgz", + "integrity": "sha512-nFR0zLpU2YCaRxwoCJvL6UvCH2JFyFVIvwTLsIf21AuHlMskA1hhTdk+LlYJtOlYt9v6dvszD2BGRqBL+iQK9Q==", + "deprecated": "Glob versions prior to v9 are no longer supported", + "dev": true, + "license": "ISC", + "dependencies": { + "fs.realpath": "^1.0.0", + "inflight": "^1.0.4", + "inherits": "2", + "minimatch": "^3.1.1", + "once": "^1.3.0", + "path-is-absolute": "^1.0.0" + }, + "engines": { + "node": "*" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/glob-parent": { + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-6.0.2.tgz", + "integrity": "sha512-XxwI8EOhVQgWp6iDL+3b0r86f4d6AX6zSU55HfB4ydCEuXLXc5FcYeOu+nnGftS4TEju/11rt4KJPTMgbfmv4A==", + "dev": true, + "license": "ISC", + "dependencies": { + "is-glob": "^4.0.3" + }, + "engines": { + "node": ">=10.13.0" + } + }, + "node_modules/glob-to-regexp": { + "version": "0.4.1", + "resolved": "https://registry.npmjs.org/glob-to-regexp/-/glob-to-regexp-0.4.1.tgz", + "integrity": "sha512-lkX1HJXwyMcprw/5YUZc2s7DrpAiHB21/V+E1rHUrVNokkvB6bqMzT0VfV6/86ZNabt1k14YOIaT7nDvOX3Iiw==", + "dev": true, + "license": "BSD-2-Clause" + }, + "node_modules/glob/node_modules/brace-expansion": { + "version": "1.1.12", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.12.tgz", + "integrity": "sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==", + "dev": true, + "license": "MIT", + "dependencies": { + "balanced-match": "^1.0.0", + "concat-map": "0.0.1" + } + }, + "node_modules/glob/node_modules/minimatch": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", + "integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", + "dev": true, + "license": "ISC", + "dependencies": { + "brace-expansion": "^1.1.7" + }, + "engines": { + "node": "*" + } + }, + "node_modules/globals": { + "version": "14.0.0", + "resolved": "https://registry.npmjs.org/globals/-/globals-14.0.0.tgz", + "integrity": "sha512-oahGvuMGQlPw/ivIYBjVSrWAfWLBeku5tpPE2fOPLi+WHffIWbuh2tCjhyQhTBPMf5E9jDEH4FOmTYgYwbKwtQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/globalthis": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/globalthis/-/globalthis-1.0.4.tgz", + "integrity": "sha512-DpLKbNU4WylpxJykQujfCcwYWiV/Jhm50Goo0wrVILAv5jOr9d+H+UR3PhSCD2rCCEIg0uc+G+muBTwD54JhDQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "define-properties": "^1.2.1", + "gopd": "^1.0.1" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/globby": { + "version": "6.1.0", + "resolved": "https://registry.npmjs.org/globby/-/globby-6.1.0.tgz", + "integrity": "sha512-KVbFv2TQtbzCoxAnfD6JcHZTYCzyliEaaeM/gH8qQdkKr5s0OP9scEgvdcngyk7AVdY6YVW/TJHd+lQ/Df3Daw==", + "dev": true, + "license": "MIT", + "dependencies": { + "array-union": "^1.0.1", + "glob": "^7.0.3", + "object-assign": "^4.0.1", + "pify": "^2.0.0", + "pinkie-promise": "^2.0.0" + }, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/globby/node_modules/pify": { + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/pify/-/pify-2.3.0.tgz", + "integrity": "sha512-udgsAY+fTnvv7kI7aaxbqwWNb0AHiB0qBO89PZKPkoTmGOgdbrHDKD+0B2X4uTfJ/FT1R09r9gTsjUjNJotuog==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/gopd": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/gopd/-/gopd-1.2.0.tgz", + "integrity": "sha512-ZUKRh6/kUFoAiTAtTYPZJ3hw9wNxx+BIBOijnlG9PnrJsCcSjs1wyyD6vJpaYtgnzDrKYRSqf3OO6Rfa93xsRg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/graceful-fs": { + "version": "4.2.11", + "resolved": "https://registry.npmjs.org/graceful-fs/-/graceful-fs-4.2.11.tgz", + "integrity": "sha512-RbJ5/jmFcNNCcDV5o9eTnBLJ/HszWV0P73bc+Ff4nS/rJj+YaS6IGyiOL0VoBYX+l1Wrl3k63h/KrH+nhJ0XvQ==", + "dev": true, + "license": "ISC" + }, + "node_modules/graphemer": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/graphemer/-/graphemer-1.4.0.tgz", + "integrity": "sha512-EtKwoO6kxCL9WO5xipiHTZlSzBm7WLT627TqC/uVRd0HKmq8NXyebnNYxDoBi7wt8eTWrUrKXCOVaFq9x1kgag==", + "dev": true, + "license": "MIT" + }, + "node_modules/has-bigints": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/has-bigints/-/has-bigints-1.1.0.tgz", + "integrity": "sha512-R3pbpkcIqv2Pm3dUwgjclDRVmWpTJW2DcMzcIhEXEx1oh/CEMObMm3KLmRJOdvhM7o4uQBnwr8pzRK2sJWIqfg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/has-flag": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", + "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/has-property-descriptors": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/has-property-descriptors/-/has-property-descriptors-1.0.2.tgz", + "integrity": "sha512-55JNKuIW+vq4Ke1BjOTjM2YctQIvCT7GFzHwmfZPGo5wnrgkid0YQtnAleFSqumZm4az3n2BS+erby5ipJdgrg==", + "dev": true, + "license": "MIT", + "dependencies": { + "es-define-property": "^1.0.0" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/has-proto": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/has-proto/-/has-proto-1.2.0.tgz", + "integrity": "sha512-KIL7eQPfHQRC8+XluaIw7BHUwwqL19bQn4hzNgdr+1wXoU0KKj6rufu47lhY7KbJR2C6T6+PfyN0Ea7wkSS+qQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "dunder-proto": "^1.0.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/has-symbols": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/has-symbols/-/has-symbols-1.1.0.tgz", + "integrity": "sha512-1cDNdwJ2Jaohmb3sg4OmKaMBwuC48sYni5HUw2DvsC8LjGTLK9h+eb1X6RyuOHe4hT0ULCW68iomhjUoKUqlPQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/has-tostringtag": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/has-tostringtag/-/has-tostringtag-1.0.2.tgz", + "integrity": "sha512-NqADB8VjPFLM2V0VvHUewwwsw0ZWBaIdgo+ieHtK3hasLz4qeCRjYcqfB6AQrBggRKppKF8L52/VqdVsO47Dlw==", + "dev": true, + "license": "MIT", + "dependencies": { + "has-symbols": "^1.0.3" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/hasown": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/hasown/-/hasown-2.0.2.tgz", + "integrity": "sha512-0hJU9SCPvmMzIBdZFqNPXWa6dqh7WdH0cII9y+CyS8rG3nL48Bclra9HmKhVVUHyPWNH5Y7xDwAB7bfgSjkUMQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "function-bind": "^1.1.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/iconv-lite": { + "version": "0.6.3", + "resolved": "https://registry.npmjs.org/iconv-lite/-/iconv-lite-0.6.3.tgz", + "integrity": "sha512-4fCk79wshMdzMp2rH06qWrJE4iolqLhCUH+OiuIgU++RB0+94NlDL81atO7GX55uUKueo0txHNtvEyI6D7WdMw==", + "license": "MIT", + "dependencies": { + "safer-buffer": ">= 2.1.2 < 3.0.0" + }, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/ignore": { + "version": "5.3.2", + "resolved": "https://registry.npmjs.org/ignore/-/ignore-5.3.2.tgz", + "integrity": "sha512-hsBTNUqQTDwkWtcdYI2i06Y/nUBEsNEDJKjWdigLvegy8kDuJAS8uRlpkkcQpyEXL0Z/pjDy5HBmMjRCJ2gq+g==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 4" + } + }, + "node_modules/immediate": { + "version": "3.0.6", + "resolved": "https://registry.npmjs.org/immediate/-/immediate-3.0.6.tgz", + "integrity": "sha512-XXOFtyqDjNDAQxVfYxuF7g9Il/IbWmmlQg2MYKOH8ExIT1qg6xc4zyS3HaEEATgs1btfzxq15ciUiY7gjSXRGQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/import-fresh": { + "version": "3.3.1", + "resolved": "https://registry.npmjs.org/import-fresh/-/import-fresh-3.3.1.tgz", + "integrity": "sha512-TR3KfrTZTYLPB6jUjfx6MF9WcWrHL9su5TObK4ZkYgBdWKPOFoSoQIdEuTuR82pmtxH2spWG9h6etwfr1pLBqQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "parent-module": "^1.0.0", + "resolve-from": "^4.0.0" + }, + "engines": { + "node": ">=6" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/imurmurhash": { + "version": "0.1.4", + "resolved": "https://registry.npmjs.org/imurmurhash/-/imurmurhash-0.1.4.tgz", + "integrity": "sha512-JmXMZ6wuvDmLiHEml9ykzqO6lwFbof0GG4IkcGaENdCRDDmMVnny7s5HsIgHCbaq0w2MyPhDqkhTUgS2LU2PHA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.8.19" + } + }, + "node_modules/inflight": { + "version": "1.0.6", + "resolved": "https://registry.npmjs.org/inflight/-/inflight-1.0.6.tgz", + "integrity": "sha512-k92I/b08q4wvFscXCLvqfsHCrjrF7yiXsQuIVvVE7N82W3+aqpzuUdBbfhWcy/FZR3/4IgflMgKLOsvPDrGCJA==", + "deprecated": "This module is not supported, and leaks memory. Do not use it. Check out lru-cache if you want a good and tested way to coalesce async requests by a key value, which is much more comprehensive and powerful.", + "dev": true, + "license": "ISC", + "dependencies": { + "once": "^1.3.0", + "wrappy": "1" + } + }, + "node_modules/inherits": { + "version": "2.0.4", + "resolved": "https://registry.npmjs.org/inherits/-/inherits-2.0.4.tgz", + "integrity": "sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ==", + "dev": true, + "license": "ISC" + }, + "node_modules/internal-slot": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/internal-slot/-/internal-slot-1.1.0.tgz", + "integrity": "sha512-4gd7VpWNQNB4UKKCFFVcp1AVv+FMOgs9NKzjHKusc8jTMhd5eL1NqQqOpE0KzMds804/yHlglp3uxgluOqAPLw==", + "dev": true, + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0", + "hasown": "^2.0.2", + "side-channel": "^1.1.0" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/internmap": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/internmap/-/internmap-2.0.3.tgz", + "integrity": "sha512-5Hh7Y1wQbvY5ooGgPbDaL5iYLAPzMTUrjMulskHLH6wnv/A+1q5rgEaiuqEjB+oxGXIVZs1FF+R/KPN3ZSQYYg==", + "license": "ISC", + "engines": { + "node": ">=12" + } + }, + "node_modules/is-array-buffer": { + "version": "3.0.5", + "resolved": "https://registry.npmjs.org/is-array-buffer/-/is-array-buffer-3.0.5.tgz", + "integrity": "sha512-DDfANUiiG2wC1qawP66qlTugJeL5HyzMpfr8lLK+jMQirGzNod0B12cFB/9q838Ru27sBwfw78/rdoU7RERz6A==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.8", + "call-bound": "^1.0.3", + "get-intrinsic": "^1.2.6" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-async-function": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/is-async-function/-/is-async-function-2.1.1.tgz", + "integrity": "sha512-9dgM/cZBnNvjzaMYHVoxxfPj2QXt22Ev7SuuPrs+xav0ukGB0S6d4ydZdEiM48kLx5kDV+QBPrpVnFyefL8kkQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "async-function": "^1.0.0", + "call-bound": "^1.0.3", + "get-proto": "^1.0.1", + "has-tostringtag": "^1.0.2", + "safe-regex-test": "^1.1.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-bigint": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/is-bigint/-/is-bigint-1.1.0.tgz", + "integrity": "sha512-n4ZT37wG78iz03xPRKJrHTdZbe3IicyucEtdRsV5yglwc3GyUfbAfpSeD0FJ41NbUNSt5wbhqfp1fS+BgnvDFQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "has-bigints": "^1.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-boolean-object": { + "version": "1.2.2", + "resolved": "https://registry.npmjs.org/is-boolean-object/-/is-boolean-object-1.2.2.tgz", + "integrity": "sha512-wa56o2/ElJMYqjCjGkXri7it5FbebW5usLw/nPmCMs5DeZ7eziSYZhSmPRn0txqeW4LnAmQQU7FgqLpsEFKM4A==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.3", + "has-tostringtag": "^1.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-callable": { + "version": "1.2.7", + "resolved": "https://registry.npmjs.org/is-callable/-/is-callable-1.2.7.tgz", + "integrity": "sha512-1BC0BVFhS/p0qtw6enp8e+8OD0UrK0oFLztSjNzhcKA3WDuJxxAPXzPuPtKkjEY9UUoEWlX/8fgKeu2S8i9JTA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-core-module": { + "version": "2.16.1", + "resolved": "https://registry.npmjs.org/is-core-module/-/is-core-module-2.16.1.tgz", + "integrity": "sha512-UfoeMA6fIJ8wTYFEUjelnaGI67v6+N7qXJEvQuIGa99l4xsCruSYOVSQ0uPANn4dAzm8lkYPaKLrrijLq7x23w==", + "dev": true, + "license": "MIT", + "dependencies": { + "hasown": "^2.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-data-view": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/is-data-view/-/is-data-view-1.0.2.tgz", + "integrity": "sha512-RKtWF8pGmS87i2D6gqQu/l7EYRlVdfzemCJN/P3UOs//x1QE7mfhvzHIApBTRf7axvT6DMGwSwBXYCT0nfB9xw==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.2", + "get-intrinsic": "^1.2.6", + "is-typed-array": "^1.1.13" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-date-object": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/is-date-object/-/is-date-object-1.1.0.tgz", + "integrity": "sha512-PwwhEakHVKTdRNVOw+/Gyh0+MzlCl4R6qKvkhuvLtPMggI1WAHt9sOwZxQLSGpUaDnrdyDsomoRgNnCfKNSXXg==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.2", + "has-tostringtag": "^1.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-extglob": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/is-extglob/-/is-extglob-2.1.1.tgz", + "integrity": "sha512-SbKbANkN603Vi4jEZv49LeVJMn4yGwsbzZworEoyEiutsN3nJYdbO36zfhGJ6QEDpOZIFkDtnq5JRxmvl3jsoQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/is-finalizationregistry": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/is-finalizationregistry/-/is-finalizationregistry-1.1.1.tgz", + "integrity": "sha512-1pC6N8qWJbWoPtEjgcL2xyhQOP491EQjeUo3qTKcmV8YSDDJrOepfG8pcC7h/QgnQHYSv0mJ3Z/ZWxmatVrysg==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.3" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-fullwidth-code-point": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/is-fullwidth-code-point/-/is-fullwidth-code-point-3.0.0.tgz", + "integrity": "sha512-zymm5+u+sCsSWyD9qNaejV3DFvhCKclKdizYaJUuHA83RLjb7nSuGnddCHGv0hk+KY7BMAlsWeK4Ueg6EV6XQg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/is-generator-function": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/is-generator-function/-/is-generator-function-1.1.0.tgz", + "integrity": "sha512-nPUB5km40q9e8UfN/Zc24eLlzdSf9OfKByBw9CIdw4H1giPMeA0OIJvbchsCu4npfI2QcMVBsGEBHKZ7wLTWmQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.3", + "get-proto": "^1.0.0", + "has-tostringtag": "^1.0.2", + "safe-regex-test": "^1.1.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-glob": { + "version": "4.0.3", + "resolved": "https://registry.npmjs.org/is-glob/-/is-glob-4.0.3.tgz", + "integrity": "sha512-xelSayHH36ZgE7ZWhli7pW34hNbNl8Ojv5KVmkJD4hBdD3th8Tfk9vYasLM+mXWOZhFkgZfxhLSnrwRr4elSSg==", + "dev": true, + "license": "MIT", + "dependencies": { + "is-extglob": "^2.1.1" + }, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/is-map": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/is-map/-/is-map-2.0.3.tgz", + "integrity": "sha512-1Qed0/Hr2m+YqxnM09CjA2d/i6YZNfF6R2oRAOj36eUdS6qIV/huPJNSEpKbupewFs+ZsJlxsjjPbc0/afW6Lw==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-negative-zero": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/is-negative-zero/-/is-negative-zero-2.0.3.tgz", + "integrity": "sha512-5KoIu2Ngpyek75jXodFvnafB6DJgr3u8uuK0LEZJjrU19DrMD3EVERaR8sjz8CCGgpZvxPl9SuE1GMVPFHx1mw==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-number": { + "version": "7.0.0", + "resolved": "https://registry.npmjs.org/is-number/-/is-number-7.0.0.tgz", + "integrity": "sha512-41Cifkg6e8TylSpdtTpeLVMqvSBEVzTttHvERD741+pnZ8ANv0004MRL43QKPDlK9cGvNp6NZWZUBlbGXYxxng==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.12.0" + } + }, + "node_modules/is-number-object": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/is-number-object/-/is-number-object-1.1.1.tgz", + "integrity": "sha512-lZhclumE1G6VYD8VHe35wFaIif+CTy5SJIi5+3y4psDgWu4wPDoBhF8NxUOinEc7pHgiTsT6MaBb92rKhhD+Xw==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.3", + "has-tostringtag": "^1.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-path-cwd": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/is-path-cwd/-/is-path-cwd-2.2.0.tgz", + "integrity": "sha512-w942bTcih8fdJPJmQHFzkS76NEP8Kzzvmw92cXsazb8intwLqPibPPdXf4ANdKV3rYMuuQYGIWtvz9JilB3NFQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6" + } + }, + "node_modules/is-path-in-cwd": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/is-path-in-cwd/-/is-path-in-cwd-2.1.0.tgz", + "integrity": "sha512-rNocXHgipO+rvnP6dk3zI20RpOtrAM/kzbB258Uw5BWr3TpXi861yzjo16Dn4hUox07iw5AyeMLHWsujkjzvRQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "is-path-inside": "^2.1.0" + }, + "engines": { + "node": ">=6" + } + }, + "node_modules/is-path-inside": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/is-path-inside/-/is-path-inside-2.1.0.tgz", + "integrity": "sha512-wiyhTzfDWsvwAW53OBWF5zuvaOGlZ6PwYxAbPVDhpm+gM09xKQGjBq/8uYN12aDvMxnAnq3dxTyoSoRNmg5YFg==", + "dev": true, + "license": "MIT", + "dependencies": { + "path-is-inside": "^1.0.2" + }, + "engines": { + "node": ">=6" + } + }, + "node_modules/is-regex": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/is-regex/-/is-regex-1.2.1.tgz", + "integrity": "sha512-MjYsKHO5O7mCsmRGxWcLWheFqN9DJ/2TmngvjKXihe6efViPqc274+Fx/4fYj/r03+ESvBdTXK0V6tA3rgez1g==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.2", + "gopd": "^1.2.0", + "has-tostringtag": "^1.0.2", + "hasown": "^2.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-set": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/is-set/-/is-set-2.0.3.tgz", + "integrity": "sha512-iPAjerrse27/ygGLxw+EBR9agv9Y6uLeYVJMu+QNCoouJ1/1ri0mGrcWpfCqFZuzzx3WjtwxG098X+n4OuRkPg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-shared-array-buffer": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/is-shared-array-buffer/-/is-shared-array-buffer-1.0.4.tgz", + "integrity": "sha512-ISWac8drv4ZGfwKl5slpHG9OwPNty4jOWPRIhBpxOoD+hqITiwuipOQ2bNthAzwA3B4fIjO4Nln74N0S9byq8A==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.3" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-string": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/is-string/-/is-string-1.1.1.tgz", + "integrity": "sha512-BtEeSsoaQjlSPBemMQIrY1MY0uM6vnS1g5fmufYOtnxLGUZM2178PKbhsk7Ffv58IX+ZtcvoGwccYsh0PglkAA==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.3", + "has-tostringtag": "^1.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-symbol": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/is-symbol/-/is-symbol-1.1.1.tgz", + "integrity": "sha512-9gGx6GTtCQM73BgmHQXfDmLtfjjTUDSyoxTCbp5WtoixAhfgsDirWIcVQ/IHpvI5Vgd5i/J5F7B9cN/WlVbC/w==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.2", + "has-symbols": "^1.1.0", + "safe-regex-test": "^1.1.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-typed-array": { + "version": "1.1.15", + "resolved": "https://registry.npmjs.org/is-typed-array/-/is-typed-array-1.1.15.tgz", + "integrity": "sha512-p3EcsicXjit7SaskXHs1hA91QxgTw46Fv6EFKKGS5DRFLD8yKnohjF3hxoju94b/OcMZoQukzpPpBE9uLVKzgQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "which-typed-array": "^1.1.16" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-weakmap": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/is-weakmap/-/is-weakmap-2.0.2.tgz", + "integrity": "sha512-K5pXYOm9wqY1RgjpL3YTkF39tni1XajUIkawTLUo9EZEVUFga5gSQJF8nNS7ZwJQ02y+1YCNYcMh+HIf1ZqE+w==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-weakref": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/is-weakref/-/is-weakref-1.1.1.tgz", + "integrity": "sha512-6i9mGWSlqzNMEqpCp93KwRS1uUOodk2OJ6b+sq7ZPDSy2WuI5NFIxp/254TytR8ftefexkWn5xNiHUNpPOfSew==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.3" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-weakset": { + "version": "2.0.4", + "resolved": "https://registry.npmjs.org/is-weakset/-/is-weakset-2.0.4.tgz", + "integrity": "sha512-mfcwb6IzQyOKTs84CQMrOwW4gQcaTOAWJ0zzJCl2WSPDrWk/OzDaImWFH3djXhb24g4eudZfLRozAvPGw4d9hQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.3", + "get-intrinsic": "^1.2.6" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/isarray": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/isarray/-/isarray-1.0.0.tgz", + "integrity": "sha512-VLghIWNM6ELQzo7zwmcg0NmTVyWKYjvIeM83yjp0wRDTmUnrM678fQbcKBo6n2CJEF0szoG//ytg+TKla89ALQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/isexe": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/isexe/-/isexe-2.0.0.tgz", + "integrity": "sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw==", + "dev": true, + "license": "ISC" + }, + "node_modules/iterator.prototype": { + "version": "1.1.5", + "resolved": "https://registry.npmjs.org/iterator.prototype/-/iterator.prototype-1.1.5.tgz", + "integrity": "sha512-H0dkQoCa3b2VEeKQBOxFph+JAbcrQdE7KC0UkqwpLmv2EC4P41QXP+rqo9wYodACiG5/WM5s9oDApTU8utwj9g==", + "dev": true, + "license": "MIT", + "dependencies": { + "define-data-property": "^1.1.4", + "es-object-atoms": "^1.0.0", + "get-intrinsic": "^1.2.6", + "get-proto": "^1.0.0", + "has-symbols": "^1.1.0", + "set-function-name": "^2.0.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/jackspeak": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/jackspeak/-/jackspeak-4.1.1.tgz", + "integrity": "sha512-zptv57P3GpL+O0I7VdMJNBZCu+BPHVQUk55Ft8/QCJjTVxrnJHuVuX/0Bl2A6/+2oyR/ZMEuFKwmzqqZ/U5nPQ==", + "dev": true, + "license": "BlueOak-1.0.0", + "dependencies": { + "@isaacs/cliui": "^8.0.2" + }, + "engines": { + "node": "20 || >=22" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/jest-worker": { + "version": "27.5.1", + "resolved": "https://registry.npmjs.org/jest-worker/-/jest-worker-27.5.1.tgz", + "integrity": "sha512-7vuh85V5cdDofPyxn58nrPjBktZo0u9x1g8WtjQol+jZDaE+fhN+cIvTj11GndBnMnyfrUOG1sZQxCdjKh+DKg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/node": "*", + "merge-stream": "^2.0.0", + "supports-color": "^8.0.0" + }, + "engines": { + "node": ">= 10.13.0" + } + }, + "node_modules/jest-worker/node_modules/supports-color": { + "version": "8.1.1", + "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-8.1.1.tgz", + "integrity": "sha512-MpUEN2OodtUzxvKQl72cUF7RQ5EiHsGvSsVG0ia9c5RbWGL2CI4C7EpPS8UTBIplnlzZiNuV56w+FuNxy3ty2Q==", + "dev": true, + "license": "MIT", + "dependencies": { + "has-flag": "^4.0.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/chalk/supports-color?sponsor=1" + } + }, + "node_modules/js-tokens": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/js-tokens/-/js-tokens-4.0.0.tgz", + "integrity": "sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ==", + "license": "MIT" + }, + "node_modules/js-yaml": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.0.tgz", + "integrity": "sha512-wpxZs9NoxZaJESJGIZTyDEaYpl0FKSA+FB9aJiyemKhMwkxQg63h4T1KJgUGHpTqPDNRcmmYLugrRjJlBtWvRA==", + "dev": true, + "license": "MIT", + "dependencies": { + "argparse": "^2.0.1" + }, + "bin": { + "js-yaml": "bin/js-yaml.js" + } + }, + "node_modules/json-buffer": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/json-buffer/-/json-buffer-3.0.1.tgz", + "integrity": "sha512-4bV5BfR2mqfQTJm+V5tPPdf+ZpuhiIvTuAB5g8kcrXOZpTT/QwwVRWBywX1ozr6lEuPdbHxwaJlm9G6mI2sfSQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/json-parse-even-better-errors": { + "version": "2.3.1", + "resolved": "https://registry.npmjs.org/json-parse-even-better-errors/-/json-parse-even-better-errors-2.3.1.tgz", + "integrity": "sha512-xyFwyhro/JEof6Ghe2iz2NcXoj2sloNsWr/XsERDK/oiPCfaNhl5ONfp+jQdAZRQQ0IJWNzH9zIZF7li91kh2w==", + "dev": true, + "license": "MIT" + }, + "node_modules/json-schema-traverse": { + "version": "0.4.1", + "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-0.4.1.tgz", + "integrity": "sha512-xbbCH5dCYU5T8LcEhhuh7HJ88HXuW3qsI3Y0zOZFKfZEHcpWiHU/Jxzk629Brsab/mMiHQti9wMP+845RPe3Vg==", + "dev": true, + "license": "MIT" + }, + "node_modules/json-stable-stringify-without-jsonify": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/json-stable-stringify-without-jsonify/-/json-stable-stringify-without-jsonify-1.0.1.tgz", + "integrity": "sha512-Bdboy+l7tA3OGW6FjyFHWkP5LuByj1Tk33Ljyq0axyzdk9//JSi2u3fP1QSmd1KNwq6VOKYGlAu87CisVir6Pw==", + "dev": true, + "license": "MIT" + }, + "node_modules/json5": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/json5/-/json5-1.0.2.tgz", + "integrity": "sha512-g1MWMLBiz8FKi1e4w0UyVL3w+iJceWAFBAaBnnGKOpNa5f8TLktkbre1+s6oICydWAm+HRUGTmI+//xv2hvXYA==", + "dev": true, + "license": "MIT", + "dependencies": { + "minimist": "^1.2.0" + }, + "bin": { + "json5": "lib/cli.js" + } + }, + "node_modules/jsx-ast-utils": { + "version": "3.3.5", + "resolved": "https://registry.npmjs.org/jsx-ast-utils/-/jsx-ast-utils-3.3.5.tgz", + "integrity": "sha512-ZZow9HBI5O6EPgSJLUb8n2NKgmVWTwCvHGwFuJlMjvLFqlGG6pjirPhtdsseaLZjSibD8eegzmYpUZwoIlj2cQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "array-includes": "^3.1.6", + "array.prototype.flat": "^1.3.1", + "object.assign": "^4.1.4", + "object.values": "^1.1.6" + }, + "engines": { + "node": ">=4.0" + } + }, + "node_modules/jszip": { + "version": "3.10.1", + "resolved": "https://registry.npmjs.org/jszip/-/jszip-3.10.1.tgz", + "integrity": "sha512-xXDvecyTpGLrqFrvkrUSoxxfJI5AH7U8zxxtVclpsUtMCq4JQ290LY8AW5c7Ggnr/Y/oK+bQMbqK2qmtk3pN4g==", + "dev": true, + "license": "(MIT OR GPL-3.0-or-later)", + "dependencies": { + "lie": "~3.3.0", + "pako": "~1.0.2", + "readable-stream": "~2.3.6", + "setimmediate": "^1.0.5" + } + }, + "node_modules/jszip/node_modules/pako": { + "version": "1.0.11", + "resolved": "https://registry.npmjs.org/pako/-/pako-1.0.11.tgz", + "integrity": "sha512-4hLB8Py4zZce5s4yd9XzopqwVv/yGNhV1Bl8NTmCq1763HeK2+EwVTv+leGeL13Dnh2wfbqowVPXCIO0z4taYw==", + "dev": true, + "license": "(MIT AND Zlib)" + }, + "node_modules/keyv": { + "version": "4.5.4", + "resolved": "https://registry.npmjs.org/keyv/-/keyv-4.5.4.tgz", + "integrity": "sha512-oxVHkHR/EJf2CNXnWxRLW6mg7JyCCUcG0DtEGmL2ctUo1PNTin1PUil+r/+4r5MpVgC/fn1kjsx7mjSujKqIpw==", + "dev": true, + "license": "MIT", + "dependencies": { + "json-buffer": "3.0.1" + } + }, + "node_modules/leaflet": { + "version": "1.9.4", + "resolved": "https://registry.npmjs.org/leaflet/-/leaflet-1.9.4.tgz", + "integrity": "sha512-nxS1ynzJOmOlHp+iL3FyWqK89GtNL8U8rvlMOsQdTTssxZwCXh8N2NB3GDQOL+YR3XnWyZAxwQixURb+FA74PA==" + }, + "node_modules/levn": { + "version": "0.4.1", + "resolved": "https://registry.npmjs.org/levn/-/levn-0.4.1.tgz", + "integrity": "sha512-+bT2uH4E5LGE7h/n3evcS/sQlJXCpIp6ym8OWJ5eV6+67Dsql/LaaT7qJBAt2rzfoa/5QBGBhxDix1dMt2kQKQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "prelude-ls": "^1.2.1", + "type-check": "~0.4.0" + }, + "engines": { + "node": ">= 0.8.0" + } + }, + "node_modules/lie": { + "version": "3.3.0", + "resolved": "https://registry.npmjs.org/lie/-/lie-3.3.0.tgz", + "integrity": "sha512-UaiMJzeWRlEujzAuw5LokY1L5ecNQYZKfmyZ9L7wDHb/p5etKaxXhohBcrw0EYby+G/NA52vRSN4N39dxHAIwQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "immediate": "~3.0.5" + } + }, + "node_modules/loader-runner": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/loader-runner/-/loader-runner-4.3.0.tgz", + "integrity": "sha512-3R/1M+yS3j5ou80Me59j7F9IMs4PXs3VqRrm0TU3AbKPxlmpoY1TNscJV/oGJXo8qCatFGTfDbY6W6ipGOYXfg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6.11.5" + } + }, + "node_modules/locate-path": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/locate-path/-/locate-path-6.0.0.tgz", + "integrity": "sha512-iPZK6eYjbxRu3uB4/WZ3EsEIMJFMqAoopl3R+zuq0UjcAm/MO6KCweDgPfP3elTztoKP3KtnVHxTn2NHBSDVUw==", + "dev": true, + "license": "MIT", + "dependencies": { + "p-locate": "^5.0.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/lodash.camelcase": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/lodash.camelcase/-/lodash.camelcase-4.3.0.tgz", + "integrity": "sha512-TwuEnCnxbc3rAvhf/LbG7tJUDzhqXyFnv3dtzLOPgCG/hODL7WFnsbwktkD7yUV0RrreP/l1PALq/YSg6VvjlA==", + "dev": true, + "license": "MIT" + }, + "node_modules/lodash.kebabcase": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/lodash.kebabcase/-/lodash.kebabcase-4.1.1.tgz", + "integrity": "sha512-N8XRTIMMqqDgSy4VLKPnJ/+hpGZN+PHQiJnSenYqPaVV/NCqEogTnAdZLQiGKhxX+JCs8waWq2t1XHWKOmlY8g==", + "dev": true, + "license": "MIT" + }, + "node_modules/lodash.merge": { + "version": "4.6.2", + "resolved": "https://registry.npmjs.org/lodash.merge/-/lodash.merge-4.6.2.tgz", + "integrity": "sha512-0KpjqXRVvrYyCsX1swR/XTK0va6VQkQM6MNo7PqW77ByjAhoARA8EfrP1N4+KlKj8YS0ZUCtRT/YUuhyYDujIQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/lodash.snakecase": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/lodash.snakecase/-/lodash.snakecase-4.1.1.tgz", + "integrity": "sha512-QZ1d4xoBHYUeuouhEq3lk3Uq7ldgyFXGBhg04+oRLnIz8o9T65Eh+8YdroUwn846zchkA9yDsDl5CVVaV2nqYw==", + "dev": true, + "license": "MIT" + }, + "node_modules/lodash.upperfirst": { + "version": "4.3.1", + "resolved": "https://registry.npmjs.org/lodash.upperfirst/-/lodash.upperfirst-4.3.1.tgz", + "integrity": "sha512-sReKOYJIJf74dhJONhU4e0/shzi1trVbSWDOhKYE5XV2O+H7Sb2Dihwuc7xWxVl+DgFPyTqIN3zMfT9cq5iWDg==", + "dev": true, + "license": "MIT" + }, + "node_modules/loose-envify": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/loose-envify/-/loose-envify-1.4.0.tgz", + "integrity": "sha512-lyuxPGr/Wfhrlem2CL/UcnUc1zcqKAImBDzukY7Y5F/yQiNdko6+fRLevlw1HgMySw7f611UIY408EtxRSoK3Q==", + "license": "MIT", + "dependencies": { + "js-tokens": "^3.0.0 || ^4.0.0" + }, + "bin": { + "loose-envify": "cli.js" + } + }, + "node_modules/lru-cache": { + "version": "11.1.0", + "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-11.1.0.tgz", + "integrity": "sha512-QIXZUBJUx+2zHUdQujWejBkcD9+cs94tLn0+YL8UrCh+D5sCXZ4c7LaEH48pNwRY3MLDgqUFyhlCyjJPf1WP0A==", + "dev": true, + "license": "ISC", + "engines": { + "node": "20 || >=22" + } + }, + "node_modules/math-intrinsics": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/math-intrinsics/-/math-intrinsics-1.1.0.tgz", + "integrity": "sha512-/IXtbwEk5HTPyEwyKX6hGkYXxM9nbj64B+ilVJnC/R6B0pH5G4V3b0pVbL7DBj4tkhBAppbQUlf6F6Xl9LHu1g==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/merge-stream": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/merge-stream/-/merge-stream-2.0.0.tgz", + "integrity": "sha512-abv/qOcuPfk3URPfDzmZU1LKmuw8kT+0nIHvKrKgFrwifol/doWcdA4ZqsWQ8ENrFKkd67Mfpo/LovbIUsbt3w==", + "dev": true, + "license": "MIT" + }, + "node_modules/merge2": { + "version": "1.4.1", + "resolved": "https://registry.npmjs.org/merge2/-/merge2-1.4.1.tgz", + "integrity": "sha512-8q7VEgMJW4J8tcfVPy8g09NcQwZdbwFEqhe/WZkoIzjn/3TGDwtOCYtXGxA3O8tPzpczCCDgv+P2P5y00ZJOOg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 8" + } + }, + "node_modules/micromatch": { + "version": "4.0.8", + "resolved": "https://registry.npmjs.org/micromatch/-/micromatch-4.0.8.tgz", + "integrity": "sha512-PXwfBhYu0hBCPw8Dn0E+WDYb7af3dSLVWKi3HGv84IdF4TyFoC0ysxFd0Goxw7nSv4T/PzEJQxsYsEiFCKo2BA==", + "dev": true, + "license": "MIT", + "dependencies": { + "braces": "^3.0.3", + "picomatch": "^2.3.1" + }, + "engines": { + "node": ">=8.6" + } + }, + "node_modules/mime-db": { + "version": "1.52.0", + "resolved": "https://registry.npmjs.org/mime-db/-/mime-db-1.52.0.tgz", + "integrity": "sha512-sPU4uV7dYlvtWJxwwxHD0PuihVNiE7TyAbQ5SWxDCB9mUYvOgroQOwYQQOKPJ8CIbE+1ETVlOoK1UC2nU3gYvg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/mime-types": { + "version": "2.1.35", + "resolved": "https://registry.npmjs.org/mime-types/-/mime-types-2.1.35.tgz", + "integrity": "sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw==", + "dev": true, + "license": "MIT", + "dependencies": { + "mime-db": "1.52.0" + }, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/minimatch": { + "version": "9.0.5", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz", + "integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==", + "dev": true, + "license": "ISC", + "dependencies": { + "brace-expansion": "^2.0.1" + }, + "engines": { + "node": ">=16 || 14 >=14.17" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/minimist": { + "version": "1.2.8", + "resolved": "https://registry.npmjs.org/minimist/-/minimist-1.2.8.tgz", + "integrity": "sha512-2yyAR8qBkN3YuheJanUpWC5U3bb5osDywNB8RzDVlDwDHbocAJveqqj1u8+SVD7jkWT4yvsHCpWqqWqAxb0zCA==", + "dev": true, + "license": "MIT", + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/minipass": { + "version": "7.1.2", + "resolved": "https://registry.npmjs.org/minipass/-/minipass-7.1.2.tgz", + "integrity": "sha512-qOOzS1cBTWYF4BH8fVePDBOO9iptMnGUEZwNc/cMWnTV2nVLZ7VoNWEPHkYczZA0pdoA7dl6e7FL659nX9S2aw==", + "dev": true, + "license": "ISC", + "engines": { + "node": ">=16 || 14 >=14.17" + } + }, + "node_modules/mkdirp": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/mkdirp/-/mkdirp-3.0.1.tgz", + "integrity": "sha512-+NsyUUAZDmo6YVHzL/stxSu3t9YS1iljliy3BSDrXJ/dkn1KYdmtZODGGjLcc9XLgVVpH4KshHB8XmZgMhaBXg==", + "dev": true, + "license": "MIT", + "bin": { + "mkdirp": "dist/cjs/src/bin.js" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/ms": { + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz", + "integrity": "sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==", + "license": "MIT" + }, + "node_modules/natural-compare": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/natural-compare/-/natural-compare-1.4.0.tgz", + "integrity": "sha512-OWND8ei3VtNC9h7V60qff3SVobHr996CTwgxubgyQYEpg290h9J0buyECNNJexkFm5sOajh5G116RYA1c8ZMSw==", + "dev": true, + "license": "MIT" + }, + "node_modules/ncp": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/ncp/-/ncp-2.0.0.tgz", + "integrity": "sha512-zIdGUrPRFTUELUvr3Gmc7KZ2Sw/h1PiVM0Af/oHB6zgnV1ikqSfRk+TOufi79aHYCW3NiOXmr1BP5nWbzojLaA==", + "dev": true, + "license": "MIT", + "bin": { + "ncp": "bin/ncp" + } + }, + "node_modules/neo-async": { + "version": "2.6.2", + "resolved": "https://registry.npmjs.org/neo-async/-/neo-async-2.6.2.tgz", + "integrity": "sha512-Yd3UES5mWCSqR+qNT93S3UoYUkqAZ9lLg8a7g9rimsWmYGK8cVToA4/sF3RrshdyV3sAGMXVUmpMYOw+dLpOuw==", + "dev": true, + "license": "MIT" + }, + "node_modules/node-fetch": { + "version": "2.7.0", + "resolved": "https://registry.npmjs.org/node-fetch/-/node-fetch-2.7.0.tgz", + "integrity": "sha512-c4FRfUm/dbcWZ7U+1Wq0AwCyFL+3nt2bEw05wfxSz+DWpWsitgmSgYmy2dQdWyKC1694ELPqMs/YzUSNozLt8A==", + "dev": true, + "license": "MIT", + "dependencies": { + "whatwg-url": "^5.0.0" + }, + "engines": { + "node": "4.x || >=6.0.0" + }, + "peerDependencies": { + "encoding": "^0.1.0" + }, + "peerDependenciesMeta": { + "encoding": { + "optional": true + } + } + }, + "node_modules/node-releases": { + "version": "2.0.19", + "resolved": "https://registry.npmjs.org/node-releases/-/node-releases-2.0.19.tgz", + "integrity": "sha512-xxOWJsBKtzAq7DY0J+DTzuz58K8e7sJbdgwkbMWQe8UYB6ekmsQ45q0M/tJDsGaZmbC+l7n57UV8Hl5tHxO9uw==", + "dev": true, + "license": "MIT" + }, + "node_modules/object-assign": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/object-assign/-/object-assign-4.1.1.tgz", + "integrity": "sha512-rJgTQnkUnH1sFw8yT6VSU3zD3sWmu6sZhIseY8VX+GRu3P6F7Fu+JNDoXfklElbLJSnc3FUQHVe4cU5hj+BcUg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/object-inspect": { + "version": "1.13.4", + "resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.13.4.tgz", + "integrity": "sha512-W67iLl4J2EXEGTbfeHCffrjDfitvLANg0UlX3wFUUSTx92KXRFegMHUVgSqE+wvhAbi4WqjGg9czysTV2Epbew==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/object-keys": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/object-keys/-/object-keys-1.1.1.tgz", + "integrity": "sha512-NuAESUOUMrlIXOfHKzD6bpPu3tYt3xvjNdRIQ+FeT0lNb4K8WR70CaDxhuNguS2XG+GjkyMwOzsN5ZktImfhLA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/object.assign": { + "version": "4.1.7", + "resolved": "https://registry.npmjs.org/object.assign/-/object.assign-4.1.7.tgz", + "integrity": "sha512-nK28WOo+QIjBkDduTINE4JkF/UJJKyf2EJxvJKfblDpyg0Q+pkOHNTL0Qwy6NP6FhE/EnzV73BxxqcJaXY9anw==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.8", + "call-bound": "^1.0.3", + "define-properties": "^1.2.1", + "es-object-atoms": "^1.0.0", + "has-symbols": "^1.1.0", + "object-keys": "^1.1.1" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/object.entries": { + "version": "1.1.9", + "resolved": "https://registry.npmjs.org/object.entries/-/object.entries-1.1.9.tgz", + "integrity": "sha512-8u/hfXFRBD1O0hPUjioLhoWFHRmt6tKA4/vZPyckBr18l1KE9uHrFaFaUi8MDRTpi4uak2goyPTSNJLXX2k2Hw==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.8", + "call-bound": "^1.0.4", + "define-properties": "^1.2.1", + "es-object-atoms": "^1.1.1" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/object.fromentries": { + "version": "2.0.8", + "resolved": "https://registry.npmjs.org/object.fromentries/-/object.fromentries-2.0.8.tgz", + "integrity": "sha512-k6E21FzySsSK5a21KRADBd/NGneRegFO5pLHfdQLpRDETUNJueLXs3WCzyQ3tFRDYgbq3KHGXfTbi2bs8WQ6rQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.7", + "define-properties": "^1.2.1", + "es-abstract": "^1.23.2", + "es-object-atoms": "^1.0.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/object.groupby": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/object.groupby/-/object.groupby-1.0.3.tgz", + "integrity": "sha512-+Lhy3TQTuzXI5hevh8sBGqbmurHbbIjAi0Z4S63nthVLmLxfbj4T54a4CfZrXIrt9iP4mVAPYMo/v99taj3wjQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.7", + "define-properties": "^1.2.1", + "es-abstract": "^1.23.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/object.values": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/object.values/-/object.values-1.2.1.tgz", + "integrity": "sha512-gXah6aZrcUxjWg2zR2MwouP2eHlCBzdV4pygudehaKXSGW4v2AsRQUK+lwwXhii6KFZcunEnmSUoYp5CXibxtA==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.8", + "call-bound": "^1.0.3", + "define-properties": "^1.2.1", + "es-object-atoms": "^1.0.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/once": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/once/-/once-1.4.0.tgz", + "integrity": "sha512-lNaJgI+2Q5URQBkccEKHTQOPaXdUxnZZElQTZY0MFUAuaEqe1E+Nyvgdz/aIyNi6Z9MzO5dv1H8n58/GELp3+w==", + "dev": true, + "license": "ISC", + "dependencies": { + "wrappy": "1" + } + }, + "node_modules/optionator": { + "version": "0.9.4", + "resolved": "https://registry.npmjs.org/optionator/-/optionator-0.9.4.tgz", + "integrity": "sha512-6IpQ7mKUxRcZNLIObR0hz7lxsapSSIYNZJwXPGeF0mTVqGKFIXj1DQcMoT22S3ROcLyY/rz0PWaWZ9ayWmad9g==", + "dev": true, + "license": "MIT", + "dependencies": { + "deep-is": "^0.1.3", + "fast-levenshtein": "^2.0.6", + "levn": "^0.4.1", + "prelude-ls": "^1.2.1", + "type-check": "^0.4.0", + "word-wrap": "^1.2.5" + }, + "engines": { + "node": ">= 0.8.0" + } + }, + "node_modules/own-keys": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/own-keys/-/own-keys-1.0.1.tgz", + "integrity": "sha512-qFOyK5PjiWZd+QQIh+1jhdb9LpxTF0qs7Pm8o5QHYZ0M3vKqSqzsZaEB6oWlxZ+q2sJBMI/Ktgd2N5ZwQoRHfg==", + "dev": true, + "license": "MIT", + "dependencies": { + "get-intrinsic": "^1.2.6", + "object-keys": "^1.1.1", + "safe-push-apply": "^1.0.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/p-limit": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/p-limit/-/p-limit-3.1.0.tgz", + "integrity": "sha512-TYOanM3wGwNGsZN2cVTYPArw454xnXj5qmWF1bEoAc4+cU/ol7GVh7odevjp1FNHduHc3KZMcFduxU5Xc6uJRQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "yocto-queue": "^0.1.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/p-locate": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/p-locate/-/p-locate-5.0.0.tgz", + "integrity": "sha512-LaNjtRWUBY++zB5nE/NwcaoMylSPk+S+ZHNB1TzdbMJMny6dynpAGt7X/tl/QYq3TIeE6nxHppbo2LGymrG5Pw==", + "dev": true, + "license": "MIT", + "dependencies": { + "p-limit": "^3.0.2" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/p-map": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/p-map/-/p-map-2.1.0.tgz", + "integrity": "sha512-y3b8Kpd8OAN444hxfBbFfj1FY/RjtTd8tzYwhUqNYXx0fXx2iX4maP4Qr6qhIKbQXI02wTLAda4fYUbDagTUFw==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6" + } + }, + "node_modules/package-json-from-dist": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/package-json-from-dist/-/package-json-from-dist-1.0.1.tgz", + "integrity": "sha512-UEZIS3/by4OC8vL3P2dTXRETpebLI2NiI5vIrjaD/5UtrkFX/tNbwjTSRAGC/+7CAo2pIcBaRgWmcBBHcsaCIw==", + "dev": true, + "license": "BlueOak-1.0.0" + }, + "node_modules/pako": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/pako/-/pako-2.1.0.tgz", + "integrity": "sha512-w+eufiZ1WuJYgPXbV/PO3NCMEc3xqylkKHzp8bxp1uW4qaSNQUkwmLLEc3kKsfz8lpV1F8Ht3U1Cm+9Srog2ug==", + "license": "(MIT AND Zlib)" + }, + "node_modules/parent-module": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/parent-module/-/parent-module-1.0.1.tgz", + "integrity": "sha512-GQ2EWRpQV8/o+Aw8YqtfZZPfNRWZYkbidE9k5rpl/hC3vtHHBfGm2Ifi6qWV+coDGkrUKZAxE3Lot5kcsRlh+g==", + "dev": true, + "license": "MIT", + "dependencies": { + "callsites": "^3.0.0" + }, + "engines": { + "node": ">=6" + } + }, + "node_modules/path-browserify": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/path-browserify/-/path-browserify-1.0.1.tgz", + "integrity": "sha512-b7uo2UCUOYZcnF/3ID0lulOJi/bafxa1xPe7ZPsammBSpjSWQkjNxlt635YGS2MiR9GjvuXCtz2emr3jbsz98g==", + "dev": true, + "license": "MIT" + }, + "node_modules/path-exists": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/path-exists/-/path-exists-4.0.0.tgz", + "integrity": "sha512-ak9Qy5Q7jYb2Wwcey5Fpvg2KoAc/ZIhLSLOSBmRmygPsGwkVVt0fZa0qrtMz+m6tJTAHfZQ8FnmB4MG4LWy7/w==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/path-is-absolute": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/path-is-absolute/-/path-is-absolute-1.0.1.tgz", + "integrity": "sha512-AVbw3UJ2e9bq64vSaS9Am0fje1Pa8pbGqTTsmXfaIiMpnr5DlDhfJOuLj9Sf95ZPVDAUerDfEk88MPmPe7UCQg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/path-is-inside": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/path-is-inside/-/path-is-inside-1.0.2.tgz", + "integrity": "sha512-DUWJr3+ULp4zXmol/SZkFf3JGsS9/SIv+Y3Rt93/UjPpDpklB5f1er4O3POIbUuUJ3FXgqte2Q7SrU6zAqwk8w==", + "dev": true, + "license": "(WTFPL OR MIT)" + }, + "node_modules/path-key": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/path-key/-/path-key-3.1.1.tgz", + "integrity": "sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/path-parse": { + "version": "1.0.7", + "resolved": "https://registry.npmjs.org/path-parse/-/path-parse-1.0.7.tgz", + "integrity": "sha512-LDJzPVEEEPR+y48z93A0Ed0yXb8pAByGWo/k5YYdYgpY2/2EsOsksJrq7lOHxryrVOn1ejG6oAp8ahvOIQD8sw==", + "dev": true, + "license": "MIT" + }, + "node_modules/path-scurry": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/path-scurry/-/path-scurry-2.0.0.tgz", + "integrity": "sha512-ypGJsmGtdXUOeM5u93TyeIEfEhM6s+ljAhrk5vAvSx8uyY/02OvrZnA0YNGUrPXfpJMgI1ODd3nwz8Npx4O4cg==", + "dev": true, + "license": "BlueOak-1.0.0", + "dependencies": { + "lru-cache": "^11.0.0", + "minipass": "^7.1.2" + }, + "engines": { + "node": "20 || >=22" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/picocolors": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.1.1.tgz", + "integrity": "sha512-xceH2snhtb5M9liqDsmEw56le376mTZkEX/jEb/RxNFyegNul7eNslCXP9FDj/Lcu0X8KEyMceP2ntpaHrDEVA==", + "dev": true, + "license": "ISC" + }, + "node_modules/picomatch": { + "version": "2.3.1", + "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.1.tgz", + "integrity": "sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8.6" + }, + "funding": { + "url": "https://github.com/sponsors/jonschlinkert" + } + }, + "node_modules/pify": { + "version": "4.0.1", + "resolved": "https://registry.npmjs.org/pify/-/pify-4.0.1.tgz", + "integrity": "sha512-uB80kBFb/tfd68bVleG9T5GGsGPjJrLAUpR5PZIrhBnIaRTQRjqdJSsIKkOP6OAIFbj7GOrcudc5pNjZ+geV2g==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6" + } + }, + "node_modules/pinkie": { + "version": "2.0.4", + "resolved": "https://registry.npmjs.org/pinkie/-/pinkie-2.0.4.tgz", + "integrity": "sha512-MnUuEycAemtSaeFSjXKW/aroV7akBbY+Sv+RkyqFjgAe73F+MR0TBWKBRDkmfWq/HiFmdavfZ1G7h4SPZXaCSg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/pinkie-promise": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/pinkie-promise/-/pinkie-promise-2.0.1.tgz", + "integrity": "sha512-0Gni6D4UcLTbv9c57DfxDGdr41XfgUjqWZu492f0cIGr16zDU06BWP/RAEvOuo7CQ0CNjHaLlM59YJJFm3NWlw==", + "dev": true, + "license": "MIT", + "dependencies": { + "pinkie": "^2.0.0" + }, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/possible-typed-array-names": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/possible-typed-array-names/-/possible-typed-array-names-1.1.0.tgz", + "integrity": "sha512-/+5VFTchJDoVj3bhoqi6UeymcD00DAwb1nJwamzPvHEszJ4FpF6SNNbUbOS8yI56qHzdV8eK0qEfOSiodkTdxg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/prelude-ls": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/prelude-ls/-/prelude-ls-1.2.1.tgz", + "integrity": "sha512-vkcDPrRZo1QZLbn5RLGPpg/WmIQ65qoWWhcGKf/b5eplkkarX0m9z8ppCat4mlOqUsWpyNuYgO3VRyrYHSzX5g==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.8.0" + } + }, + "node_modules/prettier": { + "version": "3.6.2", + "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.6.2.tgz", + "integrity": "sha512-I7AIg5boAr5R0FFtJ6rCfD+LFsWHp81dolrFD8S79U9tb8Az2nGrJncnMSnys+bpQJfRUzqs9hnA81OAA3hCuQ==", + "dev": true, + "license": "MIT", + "bin": { + "prettier": "bin/prettier.cjs" + }, + "engines": { + "node": ">=14" + }, + "funding": { + "url": "https://github.com/prettier/prettier?sponsor=1" + } + }, + "node_modules/prettier-linter-helpers": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/prettier-linter-helpers/-/prettier-linter-helpers-1.0.0.tgz", + "integrity": "sha512-GbK2cP9nraSSUF9N2XwUwqfzlAFlMNYYl+ShE/V+H8a9uNl/oUqB1w2EL54Jh0OlyRSd8RfWYJ3coVS4TROP2w==", + "dev": true, + "license": "MIT", + "dependencies": { + "fast-diff": "^1.1.2" + }, + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/process-nextick-args": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/process-nextick-args/-/process-nextick-args-2.0.1.tgz", + "integrity": "sha512-3ouUOpQhtgrbOa17J7+uxOTpITYWaGP7/AhoR3+A+/1e9skrzelGi/dXzEYyvbxubEF6Wn2ypscTKiKJFFn1ag==", + "dev": true, + "license": "MIT" + }, + "node_modules/prop-types": { + "version": "15.8.1", + "resolved": "https://registry.npmjs.org/prop-types/-/prop-types-15.8.1.tgz", + "integrity": "sha512-oj87CgZICdulUohogVAR7AjlC0327U4el4L6eAvOqCeudMDVU0NThNaV+b9Df4dXgSP1gXMTnPdhfe/2qDH5cg==", + "dev": true, + "license": "MIT", + "dependencies": { + "loose-envify": "^1.4.0", + "object-assign": "^4.1.1", + "react-is": "^16.13.1" + } + }, + "node_modules/punycode": { + "version": "2.3.1", + "resolved": "https://registry.npmjs.org/punycode/-/punycode-2.3.1.tgz", + "integrity": "sha512-vYt7UD1U9Wg6138shLtLOvdAu+8DsC/ilFtEVHcH+wydcSpNE20AfSOduf6MkRFahL5FY7X1oU7nKVZFtfq8Fg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6" + } + }, + "node_modules/queue-microtask": { + "version": "1.2.3", + "resolved": "https://registry.npmjs.org/queue-microtask/-/queue-microtask-1.2.3.tgz", + "integrity": "sha512-NuaNSa6flKT5JaSYQzJok04JzTL1CA6aGhv5rfLW3PgqA+M2ChpZQnAC8h8i4ZFkBS8X5RqkDBHA7r4hej3K9A==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ], + "license": "MIT" + }, + "node_modules/randombytes": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/randombytes/-/randombytes-2.1.0.tgz", + "integrity": "sha512-vYl3iOX+4CKUWuxGi9Ukhie6fsqXqS9FE2Zaic4tNFD2N2QQaXOMFbuKK4QmDHC0JO6B1Zp41J0LpT0oR68amQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "safe-buffer": "^5.1.0" + } + }, + "node_modules/react": { + "version": "18.3.1", + "resolved": "https://registry.npmjs.org/react/-/react-18.3.1.tgz", + "integrity": "sha512-wS+hAgJShR0KhEvPJArfuPVN1+Hz1t0Y6n5jLrGQbkb4urgPE/0Rve+1kMB1v/oWgHgm4WIcV+i7F2pTVj+2iQ==", + "license": "MIT", + "dependencies": { + "loose-envify": "^1.1.0" + }, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/react-dom": { + "version": "18.3.1", + "resolved": "https://registry.npmjs.org/react-dom/-/react-dom-18.3.1.tgz", + "integrity": "sha512-5m4nQKp+rZRb09LNH59GM4BxTh9251/ylbKIbpe7TpGxfJ+9kv6BLkLBXIjjspbgbnIBNqlI23tRnTWT0snUIw==", + "license": "MIT", + "dependencies": { + "loose-envify": "^1.1.0", + "scheduler": "^0.23.2" + }, + "peerDependencies": { + "react": "^18.3.1" + } + }, + "node_modules/react-is": { + "version": "16.13.1", + "resolved": "https://registry.npmjs.org/react-is/-/react-is-16.13.1.tgz", + "integrity": "sha512-24e6ynE2H+OKt4kqsOvNd8kBpV65zoxbA4BVsEOB3ARVWQki/DHzaUoC5KuON/BiccDaCCTZBuOcfZs70kR8bQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/react-leaflet": { + "version": "4.2.1", + "resolved": "https://registry.npmjs.org/react-leaflet/-/react-leaflet-4.2.1.tgz", + "integrity": "sha512-p9chkvhcKrWn/H/1FFeVSqLdReGwn2qmiobOQGO3BifX+/vV/39qhY8dGqbdcPh1e6jxh/QHriLXr7a4eLFK4Q==", + "dependencies": { + "@react-leaflet/core": "^2.1.0" + }, + "peerDependencies": { + "leaflet": "^1.9.0", + "react": "^18.0.0", + "react-dom": "^18.0.0" + } + }, + "node_modules/readable-stream": { + "version": "2.3.8", + "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-2.3.8.tgz", + "integrity": "sha512-8p0AUk4XODgIewSi0l8Epjs+EVnWiK7NoDIEGU0HhE7+ZyY8D1IMY7odu5lRrFXGg71L15KG8QrPmum45RTtdA==", + "dev": true, + "license": "MIT", + "dependencies": { + "core-util-is": "~1.0.0", + "inherits": "~2.0.3", + "isarray": "~1.0.0", + "process-nextick-args": "~2.0.0", + "safe-buffer": "~5.1.1", + "string_decoder": "~1.1.1", + "util-deprecate": "~1.0.1" + } + }, + "node_modules/reflect.getprototypeof": { + "version": "1.0.10", + "resolved": "https://registry.npmjs.org/reflect.getprototypeof/-/reflect.getprototypeof-1.0.10.tgz", + "integrity": "sha512-00o4I+DVrefhv+nX0ulyi3biSHCPDe+yLv5o/p6d/UVlirijB8E16FtfwSAi4g3tcqrQ4lRAqQSoFEZJehYEcw==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.8", + "define-properties": "^1.2.1", + "es-abstract": "^1.23.9", + "es-errors": "^1.3.0", + "es-object-atoms": "^1.0.0", + "get-intrinsic": "^1.2.7", + "get-proto": "^1.0.1", + "which-builtin-type": "^1.2.1" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/regexp.prototype.flags": { + "version": "1.5.4", + "resolved": "https://registry.npmjs.org/regexp.prototype.flags/-/regexp.prototype.flags-1.5.4.tgz", + "integrity": "sha512-dYqgNSZbDwkaJ2ceRd9ojCGjBq+mOm9LmtXnAnEGyHhN/5R7iDW2TRw3h+o/jCFxus3P2LfWIIiwowAjANm7IA==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.8", + "define-properties": "^1.2.1", + "es-errors": "^1.3.0", + "get-proto": "^1.0.1", + "gopd": "^1.2.0", + "set-function-name": "^2.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/regexpp": { + "version": "3.2.0", + "resolved": "https://registry.npmjs.org/regexpp/-/regexpp-3.2.0.tgz", + "integrity": "sha512-pq2bWo9mVD43nbts2wGv17XLiNLya+GklZ8kaDLV2Z08gDCsGpnKn9BFMepvWuHCbyVvY7J5o5+BVvoQbmlJLg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/sponsors/mysticatea" + } + }, + "node_modules/require-from-string": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/require-from-string/-/require-from-string-2.0.2.tgz", + "integrity": "sha512-Xf0nWe6RseziFMu+Ap9biiUbmplq6S9/p+7w7YXP/JBHhrUDDUhwa+vANyubuqfZWTveU//DYVGsDG7RKL/vEw==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/resolve": { + "version": "1.22.10", + "resolved": "https://registry.npmjs.org/resolve/-/resolve-1.22.10.tgz", + "integrity": "sha512-NPRy+/ncIMeDlTAsuqwKIiferiawhefFJtkNSW0qZJEqMEb+qBt/77B/jGeeek+F0uOeN05CDa6HXbbIgtVX4w==", + "dev": true, + "license": "MIT", + "dependencies": { + "is-core-module": "^2.16.0", + "path-parse": "^1.0.7", + "supports-preserve-symlinks-flag": "^1.0.0" + }, + "bin": { + "resolve": "bin/resolve" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/resolve-from": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/resolve-from/-/resolve-from-4.0.0.tgz", + "integrity": "sha512-pb/MYmXstAkysRFx8piNI1tGFNQIFA3vkE3Gq4EuA1dF6gHp/+vgZqsCGJapvy8N3Q+4o7FwvquPJcnZ7RYy4g==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=4" + } + }, + "node_modules/reusify": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/reusify/-/reusify-1.1.0.tgz", + "integrity": "sha512-g6QUff04oZpHs0eG5p83rFLhHeV00ug/Yf9nZM6fLeUrPguBTkTQOdpAWWspMh55TZfVQDPaN3NQJfbVRAxdIw==", + "dev": true, + "license": "MIT", + "engines": { + "iojs": ">=1.0.0", + "node": ">=0.10.0" + } + }, + "node_modules/rimraf": { + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/rimraf/-/rimraf-6.0.1.tgz", + "integrity": "sha512-9dkvaxAsk/xNXSJzMgFqqMCuFgt2+KsOFek3TMLfo8NCPfWpBmqwyNn5Y+NX56QUYfCtsyhF3ayiboEoUmJk/A==", + "dev": true, + "license": "ISC", + "dependencies": { + "glob": "^11.0.0", + "package-json-from-dist": "^1.0.0" + }, + "bin": { + "rimraf": "dist/esm/bin.mjs" + }, + "engines": { + "node": "20 || >=22" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/rimraf/node_modules/glob": { + "version": "11.0.3", + "resolved": "https://registry.npmjs.org/glob/-/glob-11.0.3.tgz", + "integrity": "sha512-2Nim7dha1KVkaiF4q6Dj+ngPPMdfvLJEOpZk/jKiUAkqKebpGAWQXAq9z1xu9HKu5lWfqw/FASuccEjyznjPaA==", + "dev": true, + "license": "ISC", + "dependencies": { + "foreground-child": "^3.3.1", + "jackspeak": "^4.1.1", + "minimatch": "^10.0.3", + "minipass": "^7.1.2", + "package-json-from-dist": "^1.0.0", + "path-scurry": "^2.0.0" + }, + "bin": { + "glob": "dist/esm/bin.mjs" + }, + "engines": { + "node": "20 || >=22" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/rimraf/node_modules/minimatch": { + "version": "10.0.3", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-10.0.3.tgz", + "integrity": "sha512-IPZ167aShDZZUMdRk66cyQAW3qr0WzbHkPdMYa8bzZhlHhO3jALbKdxcaak7W9FfT2rZNpQuUu4Od7ILEpXSaw==", + "dev": true, + "license": "ISC", + "dependencies": { + "@isaacs/brace-expansion": "^5.0.0" + }, + "engines": { + "node": "20 || >=22" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/robust-predicates": { + "version": "3.0.2", + "resolved": "https://registry.npmjs.org/robust-predicates/-/robust-predicates-3.0.2.tgz", + "integrity": "sha512-IXgzBWvWQwE6PrDI05OvmXUIruQTcoMDzRsOd5CDvHCVLcLHMTSYvOK5Cm46kWqlV3yAbuSpBZdJ5oP5OUoStg==", + "license": "Unlicense" + }, + "node_modules/run-parallel": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/run-parallel/-/run-parallel-1.2.0.tgz", + "integrity": "sha512-5l4VyZR86LZ/lDxZTR6jqL8AFE2S0IFLMP26AbjsLVADxHdhB/c0GUsH+y39UfCi3dzz8OlQuPmnaJOMoDHQBA==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ], + "license": "MIT", + "dependencies": { + "queue-microtask": "^1.2.2" + } + }, + "node_modules/rw": { + "version": "1.3.3", + "resolved": "https://registry.npmjs.org/rw/-/rw-1.3.3.tgz", + "integrity": "sha512-PdhdWy89SiZogBLaw42zdeqtRJ//zFd2PgQavcICDUgJT5oW10QCRKbJ6bg4r0/UY2M6BWd5tkxuGFRvCkgfHQ==", + "license": "BSD-3-Clause" + }, + "node_modules/safe-array-concat": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/safe-array-concat/-/safe-array-concat-1.1.3.tgz", + "integrity": "sha512-AURm5f0jYEOydBj7VQlVvDrjeFgthDdEF5H1dP+6mNpoXOMo1quQqJ4wvJDyRZ9+pO3kGWoOdmV08cSv2aJV6Q==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.8", + "call-bound": "^1.0.2", + "get-intrinsic": "^1.2.6", + "has-symbols": "^1.1.0", + "isarray": "^2.0.5" + }, + "engines": { + "node": ">=0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/safe-array-concat/node_modules/isarray": { + "version": "2.0.5", + "resolved": "https://registry.npmjs.org/isarray/-/isarray-2.0.5.tgz", + "integrity": "sha512-xHjhDr3cNBK0BzdUJSPXZntQUx/mwMS5Rw4A7lPJ90XGAO6ISP/ePDNuo0vhqOZU+UD5JoodwCAAoZQd3FeAKw==", + "dev": true, + "license": "MIT" + }, + "node_modules/safe-buffer": { + "version": "5.1.2", + "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.1.2.tgz", + "integrity": "sha512-Gd2UZBJDkXlY7GbJxfsE8/nvKkUEU1G38c1siN6QP6a9PT9MmHB8GnpscSmMJSoF8LOIrt8ud/wPtojys4G6+g==", + "dev": true, + "license": "MIT" + }, + "node_modules/safe-push-apply": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/safe-push-apply/-/safe-push-apply-1.0.0.tgz", + "integrity": "sha512-iKE9w/Z7xCzUMIZqdBsp6pEQvwuEebH4vdpjcDWnyzaI6yl6O9FHvVpmGelvEHNsoY6wGblkxR6Zty/h00WiSA==", + "dev": true, + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0", + "isarray": "^2.0.5" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/safe-push-apply/node_modules/isarray": { + "version": "2.0.5", + "resolved": "https://registry.npmjs.org/isarray/-/isarray-2.0.5.tgz", + "integrity": "sha512-xHjhDr3cNBK0BzdUJSPXZntQUx/mwMS5Rw4A7lPJ90XGAO6ISP/ePDNuo0vhqOZU+UD5JoodwCAAoZQd3FeAKw==", + "dev": true, + "license": "MIT" + }, + "node_modules/safe-regex-test": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/safe-regex-test/-/safe-regex-test-1.1.0.tgz", + "integrity": "sha512-x/+Cz4YrimQxQccJf5mKEbIa1NzeCRNI5Ecl/ekmlYaampdNLPalVyIcCZNNH3MvmqBugV5TMYZXv0ljslUlaw==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.2", + "es-errors": "^1.3.0", + "is-regex": "^1.2.1" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/safer-buffer": { + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/safer-buffer/-/safer-buffer-2.1.2.tgz", + "integrity": "sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg==", + "license": "MIT" + }, + "node_modules/sanitize-filename": { + "version": "1.6.3", + "resolved": "https://registry.npmjs.org/sanitize-filename/-/sanitize-filename-1.6.3.tgz", + "integrity": "sha512-y/52Mcy7aw3gRm7IrcGDFx/bCk4AhRh2eI9luHOQM86nZsqwiRkkq2GekHXBBD+SmPidc8i2PqtYZl+pWJ8Oeg==", + "dev": true, + "license": "WTFPL OR ISC", + "dependencies": { + "truncate-utf8-bytes": "^1.0.0" + } + }, + "node_modules/scheduler": { + "version": "0.23.2", + "resolved": "https://registry.npmjs.org/scheduler/-/scheduler-0.23.2.tgz", + "integrity": "sha512-UOShsPwz7NrMUqhR6t0hWjFduvOzbtv7toDH1/hIrfRNIDBnnBWd0CwJTGvTpngVlmwGCdP9/Zl/tVrDqcuYzQ==", + "license": "MIT", + "dependencies": { + "loose-envify": "^1.1.0" + } + }, + "node_modules/schema-utils": { + "version": "3.3.0", + "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-3.3.0.tgz", + "integrity": "sha512-pN/yOAvcC+5rQ5nERGuwrjLlYvLTbCibnZ1I7B1LaiAz9BRBlE9GMgE/eqV30P7aJQUf7Ddimy/RsbYO/GrVGg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/json-schema": "^7.0.8", + "ajv": "^6.12.5", + "ajv-keywords": "^3.5.2" + }, + "engines": { + "node": ">= 10.13.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/webpack" + } + }, + "node_modules/semver": { + "version": "7.7.2", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.7.2.tgz", + "integrity": "sha512-RF0Fw+rO5AMf9MAyaRXI4AV0Ulj5lMHqVxxdSgiVbixSCXoEmmX/jk0CuJw4+3SqroYO9VoUh+HcuJivvtJemA==", + "dev": true, + "license": "ISC", + "bin": { + "semver": "bin/semver.js" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/serialize-javascript": { + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/serialize-javascript/-/serialize-javascript-6.0.2.tgz", + "integrity": "sha512-Saa1xPByTTq2gdeFZYLLo+RFE35NHZkAbqZeWNd3BpzppeVisAqpDjcp8dyf6uIvEqJRd46jemmyA4iFIeVk8g==", + "dev": true, + "license": "BSD-3-Clause", + "dependencies": { + "randombytes": "^2.1.0" + } + }, + "node_modules/set-function-length": { + "version": "1.2.2", + "resolved": "https://registry.npmjs.org/set-function-length/-/set-function-length-1.2.2.tgz", + "integrity": "sha512-pgRc4hJ4/sNjWCSS9AmnS40x3bNMDTknHgL5UaMBTMyJnU90EgWh1Rz+MC9eFu4BuN/UwZjKQuY/1v3rM7HMfg==", + "dev": true, + "license": "MIT", + "dependencies": { + "define-data-property": "^1.1.4", + "es-errors": "^1.3.0", + "function-bind": "^1.1.2", + "get-intrinsic": "^1.2.4", + "gopd": "^1.0.1", + "has-property-descriptors": "^1.0.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/set-function-name": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/set-function-name/-/set-function-name-2.0.2.tgz", + "integrity": "sha512-7PGFlmtwsEADb0WYyvCMa1t+yke6daIG4Wirafur5kcf+MhUnPms1UeR0CKQdTZD81yESwMHbtn+TR+dMviakQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "define-data-property": "^1.1.4", + "es-errors": "^1.3.0", + "functions-have-names": "^1.2.3", + "has-property-descriptors": "^1.0.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/set-proto": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/set-proto/-/set-proto-1.0.0.tgz", + "integrity": "sha512-RJRdvCo6IAnPdsvP/7m6bsQqNnn1FCBX5ZNtFL98MmFF/4xAIJTIg1YbHW5DC2W5SKZanrC6i4HsJqlajw/dZw==", + "dev": true, + "license": "MIT", + "dependencies": { + "dunder-proto": "^1.0.1", + "es-errors": "^1.3.0", + "es-object-atoms": "^1.0.0" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/setimmediate": { + "version": "1.0.5", + "resolved": "https://registry.npmjs.org/setimmediate/-/setimmediate-1.0.5.tgz", + "integrity": "sha512-MATJdZp8sLqDl/68LfQmbP8zKPLQNV6BIZoIgrscFDQ+RsvK/BxeDQOgyxKKoh0y/8h3BqVFnCqQ/gd+reiIXA==", + "dev": true, + "license": "MIT" + }, + "node_modules/shebang-command": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/shebang-command/-/shebang-command-2.0.0.tgz", + "integrity": "sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA==", + "dev": true, + "license": "MIT", + "dependencies": { + "shebang-regex": "^3.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/shebang-regex": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/shebang-regex/-/shebang-regex-3.0.0.tgz", + "integrity": "sha512-7++dFhtcx3353uBaq8DDR4NuxBetBzC7ZQOhmTQInHEd6bSrXdiEyzCvG07Z44UYdLShWUyXt5M/yhz8ekcb1A==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/side-channel": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/side-channel/-/side-channel-1.1.0.tgz", + "integrity": "sha512-ZX99e6tRweoUXqR+VBrslhda51Nh5MTQwou5tnUDgbtyM0dBgmhEDtWGP/xbKn6hqfPRHujUNwz5fy/wbbhnpw==", + "dev": true, + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0", + "object-inspect": "^1.13.3", + "side-channel-list": "^1.0.0", + "side-channel-map": "^1.0.1", + "side-channel-weakmap": "^1.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/side-channel-list": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/side-channel-list/-/side-channel-list-1.0.0.tgz", + "integrity": "sha512-FCLHtRD/gnpCiCHEiJLOwdmFP+wzCmDEkc9y7NsYxeF4u7Btsn1ZuwgwJGxImImHicJArLP4R0yX4c2KCrMrTA==", + "dev": true, + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0", + "object-inspect": "^1.13.3" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/side-channel-map": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/side-channel-map/-/side-channel-map-1.0.1.tgz", + "integrity": "sha512-VCjCNfgMsby3tTdo02nbjtM/ewra6jPHmpThenkTYh8pG9ucZ/1P8So4u4FGBek/BjpOVsDCMoLA/iuBKIFXRA==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.2", + "es-errors": "^1.3.0", + "get-intrinsic": "^1.2.5", + "object-inspect": "^1.13.3" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/side-channel-weakmap": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/side-channel-weakmap/-/side-channel-weakmap-1.0.2.tgz", + "integrity": "sha512-WPS/HvHQTYnHisLo9McqBHOJk2FkHO/tlpvldyrnem4aeQp4hai3gythswg6p01oSoTl58rcpiFAjF2br2Ak2A==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.2", + "es-errors": "^1.3.0", + "get-intrinsic": "^1.2.5", + "object-inspect": "^1.13.3", + "side-channel-map": "^1.0.1" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/signal-exit": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/signal-exit/-/signal-exit-4.1.0.tgz", + "integrity": "sha512-bzyZ1e88w9O1iNJbKnOlvYTrWPDl46O1bG0D3XInv+9tkPrxrN8jUUTiFlDkkmKWgn1M6CfIA13SuGqOa9Korw==", + "dev": true, + "license": "ISC", + "engines": { + "node": ">=14" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/socket.io-client": { + "version": "4.8.1", + "resolved": "https://registry.npmjs.org/socket.io-client/-/socket.io-client-4.8.1.tgz", + "integrity": "sha512-hJVXfu3E28NmzGk8o1sHhN3om52tRvwYeidbj7xKy2eIIse5IoKX3USlS6Tqt3BHAtflLIkCQBkzVrEEfWUyYQ==", + "license": "MIT", + "dependencies": { + "@socket.io/component-emitter": "~3.1.0", + "debug": "~4.3.2", + "engine.io-client": "~6.6.1", + "socket.io-parser": "~4.2.4" + }, + "engines": { + "node": ">=10.0.0" + } + }, + "node_modules/socket.io-client/node_modules/debug": { + "version": "4.3.7", + "resolved": "https://registry.npmjs.org/debug/-/debug-4.3.7.tgz", + "integrity": "sha512-Er2nc/H7RrMXZBFCEim6TCmMk02Z8vLC2Rbi1KEBggpo0fS6l0S1nnapwmIi3yW/+GOJap1Krg4w0Hg80oCqgQ==", + "license": "MIT", + "dependencies": { + "ms": "^2.1.3" + }, + "engines": { + "node": ">=6.0" + }, + "peerDependenciesMeta": { + "supports-color": { + "optional": true + } + } + }, + "node_modules/socket.io-parser": { + "version": "4.2.4", + "resolved": "https://registry.npmjs.org/socket.io-parser/-/socket.io-parser-4.2.4.tgz", + "integrity": "sha512-/GbIKmo8ioc+NIWIhwdecY0ge+qVBSMdgxGygevmdHj24bsfgtCmcUUcQ5ZzcylGFHsN3k4HB4Cgkl96KVnuew==", + "license": "MIT", + "dependencies": { + "@socket.io/component-emitter": "~3.1.0", + "debug": "~4.3.1" + }, + "engines": { + "node": ">=10.0.0" + } + }, + "node_modules/socket.io-parser/node_modules/debug": { + "version": "4.3.7", + "resolved": "https://registry.npmjs.org/debug/-/debug-4.3.7.tgz", + "integrity": "sha512-Er2nc/H7RrMXZBFCEim6TCmMk02Z8vLC2Rbi1KEBggpo0fS6l0S1nnapwmIi3yW/+GOJap1Krg4w0Hg80oCqgQ==", + "license": "MIT", + "dependencies": { + "ms": "^2.1.3" + }, + "engines": { + "node": ">=6.0" + }, + "peerDependenciesMeta": { + "supports-color": { + "optional": true + } + } + }, + "node_modules/source-map": { + "version": "0.7.6", + "resolved": "https://registry.npmjs.org/source-map/-/source-map-0.7.6.tgz", + "integrity": "sha512-i5uvt8C3ikiWeNZSVZNWcfZPItFQOsYTUAOkcUPGd8DqDy1uOUikjt5dG+uRlwyvR108Fb9DOd4GvXfT0N2/uQ==", + "dev": true, + "license": "BSD-3-Clause", + "engines": { + "node": ">= 12" + } + }, + "node_modules/source-map-support": { + "version": "0.5.21", + "resolved": "https://registry.npmjs.org/source-map-support/-/source-map-support-0.5.21.tgz", + "integrity": "sha512-uBHU3L3czsIyYXKX88fdrGovxdSCoTGDRZ6SYXtSRxLZUzHg5P/66Ht6uoUlHu9EZod+inXhKo3qQgwXUT/y1w==", + "dev": true, + "license": "MIT", + "dependencies": { + "buffer-from": "^1.0.0", + "source-map": "^0.6.0" + } + }, + "node_modules/source-map-support/node_modules/source-map": { + "version": "0.6.1", + "resolved": "https://registry.npmjs.org/source-map/-/source-map-0.6.1.tgz", + "integrity": "sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g==", + "dev": true, + "license": "BSD-3-Clause", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/stop-iteration-iterator": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/stop-iteration-iterator/-/stop-iteration-iterator-1.1.0.tgz", + "integrity": "sha512-eLoXW/DHyl62zxY4SCaIgnRhuMr6ri4juEYARS8E6sCEqzKpOiE521Ucofdx+KnDZl5xmvGYaaKCk5FEOxJCoQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0", + "internal-slot": "^1.1.0" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/string_decoder": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/string_decoder/-/string_decoder-1.1.1.tgz", + "integrity": "sha512-n/ShnvDi6FHbbVfviro+WojiFzv+s8MPMHBczVePfUpDJLwoLT0ht1l4YwBCbi8pJAveEEdnkHyPyTP/mzRfwg==", + "dev": true, + "license": "MIT", + "dependencies": { + "safe-buffer": "~5.1.0" + } + }, + "node_modules/string-width": { + "version": "5.1.2", + "resolved": "https://registry.npmjs.org/string-width/-/string-width-5.1.2.tgz", + "integrity": "sha512-HnLOCR3vjcY8beoNLtcjZ5/nxn2afmME6lhrDrebokqMap+XbeW8n9TXpPDOqdGK5qcI3oT0GKTW6wC7EMiVqA==", + "dev": true, + "license": "MIT", + "dependencies": { + "eastasianwidth": "^0.2.0", + "emoji-regex": "^9.2.2", + "strip-ansi": "^7.0.1" + }, + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/string-width-cjs": { + "name": "string-width", + "version": "4.2.3", + "resolved": "https://registry.npmjs.org/string-width/-/string-width-4.2.3.tgz", + "integrity": "sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==", + "dev": true, + "license": "MIT", + "dependencies": { + "emoji-regex": "^8.0.0", + "is-fullwidth-code-point": "^3.0.0", + "strip-ansi": "^6.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/string-width-cjs/node_modules/ansi-regex": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz", + "integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/string-width-cjs/node_modules/emoji-regex": { + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz", + "integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==", + "dev": true, + "license": "MIT" + }, + "node_modules/string-width-cjs/node_modules/strip-ansi": { + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz", + "integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==", + "dev": true, + "license": "MIT", + "dependencies": { + "ansi-regex": "^5.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/string.prototype.matchall": { + "version": "4.0.12", + "resolved": "https://registry.npmjs.org/string.prototype.matchall/-/string.prototype.matchall-4.0.12.tgz", + "integrity": "sha512-6CC9uyBL+/48dYizRf7H7VAYCMCNTBeM78x/VTUe9bFEaxBepPJDa1Ow99LqI/1yF7kuy7Q3cQsYMrcjGUcskA==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.8", + "call-bound": "^1.0.3", + "define-properties": "^1.2.1", + "es-abstract": "^1.23.6", + "es-errors": "^1.3.0", + "es-object-atoms": "^1.0.0", + "get-intrinsic": "^1.2.6", + "gopd": "^1.2.0", + "has-symbols": "^1.1.0", + "internal-slot": "^1.1.0", + "regexp.prototype.flags": "^1.5.3", + "set-function-name": "^2.0.2", + "side-channel": "^1.1.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/string.prototype.repeat": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/string.prototype.repeat/-/string.prototype.repeat-1.0.0.tgz", + "integrity": "sha512-0u/TldDbKD8bFCQ/4f5+mNRrXwZ8hg2w7ZR8wa16e8z9XpePWl3eGEcUD0OXpEH/VJH/2G3gjUtR3ZOiBe2S/w==", + "dev": true, + "license": "MIT", + "dependencies": { + "define-properties": "^1.1.3", + "es-abstract": "^1.17.5" + } + }, + "node_modules/string.prototype.trim": { + "version": "1.2.10", + "resolved": "https://registry.npmjs.org/string.prototype.trim/-/string.prototype.trim-1.2.10.tgz", + "integrity": "sha512-Rs66F0P/1kedk5lyYyH9uBzuiI/kNRmwJAR9quK6VOtIpZ2G+hMZd+HQbbv25MgCA6gEffoMZYxlTod4WcdrKA==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.8", + "call-bound": "^1.0.2", + "define-data-property": "^1.1.4", + "define-properties": "^1.2.1", + "es-abstract": "^1.23.5", + "es-object-atoms": "^1.0.0", + "has-property-descriptors": "^1.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/string.prototype.trimend": { + "version": "1.0.9", + "resolved": "https://registry.npmjs.org/string.prototype.trimend/-/string.prototype.trimend-1.0.9.tgz", + "integrity": "sha512-G7Ok5C6E/j4SGfyLCloXTrngQIQU3PWtXGst3yM7Bea9FRURf1S42ZHlZZtsNque2FN2PoUhfZXYLNWwEr4dLQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.8", + "call-bound": "^1.0.2", + "define-properties": "^1.2.1", + "es-object-atoms": "^1.0.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/string.prototype.trimstart": { + "version": "1.0.8", + "resolved": "https://registry.npmjs.org/string.prototype.trimstart/-/string.prototype.trimstart-1.0.8.tgz", + "integrity": "sha512-UXSH262CSZY1tfu3G3Secr6uGLCFVPMhIqHjlgCUtCCcgihYc/xKs9djMTMUOb2j1mVSeU8EU6NWc/iQKU6Gfg==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.7", + "define-properties": "^1.2.1", + "es-object-atoms": "^1.0.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/strip-ansi": { + "version": "7.1.0", + "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-7.1.0.tgz", + "integrity": "sha512-iq6eVVI64nQQTRYq2KtEg2d2uU7LElhTJwsH4YzIHZshxlgZms/wIc4VoDQTlG/IvVIrBKG06CrZnp0qv7hkcQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "ansi-regex": "^6.0.1" + }, + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/chalk/strip-ansi?sponsor=1" + } + }, + "node_modules/strip-ansi-cjs": { + "name": "strip-ansi", + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz", + "integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==", + "dev": true, + "license": "MIT", + "dependencies": { + "ansi-regex": "^5.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/strip-ansi-cjs/node_modules/ansi-regex": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz", + "integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/strip-bom": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/strip-bom/-/strip-bom-3.0.0.tgz", + "integrity": "sha512-vavAMRXOgBVNF6nyEEmL3DBK19iRpDcoIwW+swQ+CbGiu7lju6t+JklA1MHweoWtadgt4ISVUsXLyDq34ddcwA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=4" + } + }, + "node_modules/strip-json-comments": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/strip-json-comments/-/strip-json-comments-3.1.1.tgz", + "integrity": "sha512-6fPc+R4ihwqP6N/aIv2f1gMH8lOVtWQHoqC4yK6oSDVVocumAsfCqjkXnqiYMhmMwS/mEHLp7Vehlt3ql6lEig==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/supports-color": { + "version": "7.2.0", + "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", + "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", + "dev": true, + "license": "MIT", + "dependencies": { + "has-flag": "^4.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/supports-preserve-symlinks-flag": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/supports-preserve-symlinks-flag/-/supports-preserve-symlinks-flag-1.0.0.tgz", + "integrity": "sha512-ot0WnXS9fgdkgIcePe6RHNk1WA8+muPa6cSjeR3V8K27q9BB1rTE3R1p7Hv0z1ZyAc8s6Vvv8DIyWf681MAt0w==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/synckit": { + "version": "0.11.11", + "resolved": "https://registry.npmjs.org/synckit/-/synckit-0.11.11.tgz", + "integrity": "sha512-MeQTA1r0litLUf0Rp/iisCaL8761lKAZHaimlbGK4j0HysC4PLfqygQj9srcs0m2RdtDYnF8UuYyKpbjHYp7Jw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@pkgr/core": "^0.2.9" + }, + "engines": { + "node": "^14.18.0 || >=16.0.0" + }, + "funding": { + "url": "https://opencollective.com/synckit" + } + }, + "node_modules/tapable": { + "version": "2.2.3", + "resolved": "https://registry.npmjs.org/tapable/-/tapable-2.2.3.tgz", + "integrity": "sha512-ZL6DDuAlRlLGghwcfmSn9sK3Hr6ArtyudlSAiCqQ6IfE+b+HHbydbYDIG15IfS5do+7XQQBdBiubF/cV2dnDzg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/webpack" + } + }, + "node_modules/terser": { + "version": "5.43.1", + "resolved": "https://registry.npmjs.org/terser/-/terser-5.43.1.tgz", + "integrity": "sha512-+6erLbBm0+LROX2sPXlUYx/ux5PyE9K/a92Wrt6oA+WDAoFTdpHE5tCYCI5PNzq2y8df4rA+QgHLJuR4jNymsg==", + "dev": true, + "license": "BSD-2-Clause", + "dependencies": { + "@jridgewell/source-map": "^0.3.3", + "acorn": "^8.14.0", + "commander": "^2.20.0", + "source-map-support": "~0.5.20" + }, + "bin": { + "terser": "bin/terser" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/terser-webpack-plugin": { + "version": "5.3.14", + "resolved": "https://registry.npmjs.org/terser-webpack-plugin/-/terser-webpack-plugin-5.3.14.tgz", + "integrity": "sha512-vkZjpUjb6OMS7dhV+tILUW6BhpDR7P2L/aQSAv+Uwk+m8KATX9EccViHTJR2qDtACKPIYndLGCyl3FMo+r2LMw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@jridgewell/trace-mapping": "^0.3.25", + "jest-worker": "^27.4.5", + "schema-utils": "^4.3.0", + "serialize-javascript": "^6.0.2", + "terser": "^5.31.1" + }, + "engines": { + "node": ">= 10.13.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/webpack" + }, + "peerDependencies": { + "webpack": "^5.1.0" + }, + "peerDependenciesMeta": { + "@swc/core": { + "optional": true + }, + "esbuild": { + "optional": true + }, + "uglify-js": { + "optional": true + } + } + }, + "node_modules/terser-webpack-plugin/node_modules/ajv": { + "version": "8.17.1", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.17.1.tgz", + "integrity": "sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==", + "dev": true, + "license": "MIT", + "dependencies": { + "fast-deep-equal": "^3.1.3", + "fast-uri": "^3.0.1", + "json-schema-traverse": "^1.0.0", + "require-from-string": "^2.0.2" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/epoberezkin" + } + }, + "node_modules/terser-webpack-plugin/node_modules/ajv-keywords": { + "version": "5.1.0", + "resolved": "https://registry.npmjs.org/ajv-keywords/-/ajv-keywords-5.1.0.tgz", + "integrity": "sha512-YCS/JNFAUyr5vAuhk1DWm1CBxRHW9LbJ2ozWeemrIqpbsqKjHVxYPyi5GC0rjZIT5JxJ3virVTS8wk4i/Z+krw==", + "dev": true, + "license": "MIT", + "dependencies": { + "fast-deep-equal": "^3.1.3" + }, + "peerDependencies": { + "ajv": "^8.8.2" + } + }, + "node_modules/terser-webpack-plugin/node_modules/json-schema-traverse": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz", + "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==", + "dev": true, + "license": "MIT" + }, + "node_modules/terser-webpack-plugin/node_modules/schema-utils": { + "version": "4.3.2", + "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-4.3.2.tgz", + "integrity": "sha512-Gn/JaSk/Mt9gYubxTtSn/QCV4em9mpAPiR1rqy/Ocu19u/G9J5WWdNoUT4SiV6mFC3y6cxyFcFwdzPM3FgxGAQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/json-schema": "^7.0.9", + "ajv": "^8.9.0", + "ajv-formats": "^2.1.1", + "ajv-keywords": "^5.1.0" + }, + "engines": { + "node": ">= 10.13.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/webpack" + } + }, + "node_modules/terser/node_modules/commander": { + "version": "2.20.3", + "resolved": "https://registry.npmjs.org/commander/-/commander-2.20.3.tgz", + "integrity": "sha512-GpVkmM8vF2vQUkj2LvZmD35JxeJOLCwJ9cUkugyk2nuhbv3+mJvpLYYt+0+USMxE+oj+ey/lJEnhZw75x/OMcQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/to-regex-range": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/to-regex-range/-/to-regex-range-5.0.1.tgz", + "integrity": "sha512-65P7iz6X5yEr1cwcgvQxbbIw7Uk3gOy5dIdtZ4rDveLqhrdJP+Li/Hx6tyK0NEb+2GCyneCMJiGqrADCSNk8sQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "is-number": "^7.0.0" + }, + "engines": { + "node": ">=8.0" + } + }, + "node_modules/tr46": { + "version": "0.0.3", + "resolved": "https://registry.npmjs.org/tr46/-/tr46-0.0.3.tgz", + "integrity": "sha512-N3WMsuqV66lT30CrXNbEjx4GEwlow3v6rr4mCcv6prnfwhS01rkgyFdjPNBYd9br7LpXV1+Emh01fHnq2Gdgrw==", + "dev": true, + "license": "MIT" + }, + "node_modules/truncate-utf8-bytes": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/truncate-utf8-bytes/-/truncate-utf8-bytes-1.0.2.tgz", + "integrity": "sha512-95Pu1QXQvruGEhv62XCMO3Mm90GscOCClvrIUwCM0PYOXK3kaF3l3sIHxx71ThJfcbM2O5Au6SO3AWCSEfW4mQ==", + "dev": true, + "license": "WTFPL", + "dependencies": { + "utf8-byte-length": "^1.0.1" + } + }, + "node_modules/ts-api-utils": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/ts-api-utils/-/ts-api-utils-2.1.0.tgz", + "integrity": "sha512-CUgTZL1irw8u29bzrOD/nH85jqyc74D6SshFgujOIA7osm2Rz7dYH77agkx7H4FBNxDq7Cjf+IjaX/8zwFW+ZQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=18.12" + }, + "peerDependencies": { + "typescript": ">=4.8.4" + } + }, + "node_modules/ts-loader": { + "version": "9.5.1", + "resolved": "https://registry.npmjs.org/ts-loader/-/ts-loader-9.5.1.tgz", + "integrity": "sha512-rNH3sK9kGZcH9dYzC7CewQm4NtxJTjSEVRJ2DyBZR7f8/wcta+iV44UPCXc5+nzDzivKtlzV6c9P4e+oFhDLYg==", + "dev": true, + "license": "MIT", + "dependencies": { + "chalk": "^4.1.0", + "enhanced-resolve": "^5.0.0", + "micromatch": "^4.0.0", + "semver": "^7.3.4", + "source-map": "^0.7.4" + }, + "engines": { + "node": ">=12.0.0" + }, + "peerDependencies": { + "typescript": "*", + "webpack": "^5.0.0" + } + }, + "node_modules/tsconfig-paths": { + "version": "3.15.0", + "resolved": "https://registry.npmjs.org/tsconfig-paths/-/tsconfig-paths-3.15.0.tgz", + "integrity": "sha512-2Ac2RgzDe/cn48GvOe3M+o82pEFewD3UPbyoUHHdKasHwJKjds4fLXWf/Ux5kATBKN20oaFGu+jbElp1pos0mg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/json5": "^0.0.29", + "json5": "^1.0.2", + "minimist": "^1.2.6", + "strip-bom": "^3.0.0" + } + }, + "node_modules/tslib": { + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-1.14.1.tgz", + "integrity": "sha512-Xni35NKzjgMrwevysHTCArtLDpPvye8zV/0E4EyYn43P7/7qvQwPh9BGkHewbMulVntbigmcT7rdX3BNo9wRJg==", + "dev": true, + "license": "0BSD" + }, + "node_modules/tsutils": { + "version": "3.21.0", + "resolved": "https://registry.npmjs.org/tsutils/-/tsutils-3.21.0.tgz", + "integrity": "sha512-mHKK3iUXL+3UF6xL5k0PEhKRUBKPBCv/+RkEOpjRWxxx27KKRBmmA60A9pgOUvMi8GKhRMPEmjBRPzs2W7O1OA==", + "dev": true, + "license": "MIT", + "dependencies": { + "tslib": "^1.8.1" + }, + "engines": { + "node": ">= 6" + }, + "peerDependencies": { + "typescript": ">=2.8.0 || >= 3.2.0-dev || >= 3.3.0-dev || >= 3.4.0-dev || >= 3.5.0-dev || >= 3.6.0-dev || >= 3.6.0-beta || >= 3.7.0-dev || >= 3.7.0-beta" + } + }, + "node_modules/type-check": { + "version": "0.4.0", + "resolved": "https://registry.npmjs.org/type-check/-/type-check-0.4.0.tgz", + "integrity": "sha512-XleUoc9uwGXqjWwXaUTZAmzMcFZ5858QA2vvx1Ur5xIcixXIP+8LnFDgRplU30us6teqdlskFfu+ae4K79Ooew==", + "dev": true, + "license": "MIT", + "dependencies": { + "prelude-ls": "^1.2.1" + }, + "engines": { + "node": ">= 0.8.0" + } + }, + "node_modules/typed-array-buffer": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/typed-array-buffer/-/typed-array-buffer-1.0.3.tgz", + "integrity": "sha512-nAYYwfY3qnzX30IkA6AQZjVbtK6duGontcQm1WSG1MD94YLqK0515GNApXkoxKOWMusVssAHWLh9SeaoefYFGw==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.3", + "es-errors": "^1.3.0", + "is-typed-array": "^1.1.14" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/typed-array-byte-length": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/typed-array-byte-length/-/typed-array-byte-length-1.0.3.tgz", + "integrity": "sha512-BaXgOuIxz8n8pIq3e7Atg/7s+DpiYrxn4vdot3w9KbnBhcRQq6o3xemQdIfynqSeXeDrF32x+WvfzmOjPiY9lg==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.8", + "for-each": "^0.3.3", + "gopd": "^1.2.0", + "has-proto": "^1.2.0", + "is-typed-array": "^1.1.14" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/typed-array-byte-offset": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/typed-array-byte-offset/-/typed-array-byte-offset-1.0.4.tgz", + "integrity": "sha512-bTlAFB/FBYMcuX81gbL4OcpH5PmlFHqlCCpAl8AlEzMz5k53oNDvN8p1PNOWLEmI2x4orp3raOFB51tv9X+MFQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "available-typed-arrays": "^1.0.7", + "call-bind": "^1.0.8", + "for-each": "^0.3.3", + "gopd": "^1.2.0", + "has-proto": "^1.2.0", + "is-typed-array": "^1.1.15", + "reflect.getprototypeof": "^1.0.9" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/typed-array-length": { + "version": "1.0.7", + "resolved": "https://registry.npmjs.org/typed-array-length/-/typed-array-length-1.0.7.tgz", + "integrity": "sha512-3KS2b+kL7fsuk/eJZ7EQdnEmQoaho/r6KUef7hxvltNA5DR8NAUM+8wJMbJyZ4G9/7i3v5zPBIMN5aybAh2/Jg==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind": "^1.0.7", + "for-each": "^0.3.3", + "gopd": "^1.0.1", + "is-typed-array": "^1.1.13", + "possible-typed-array-names": "^1.0.0", + "reflect.getprototypeof": "^1.0.6" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/typescript": { + "version": "5.9.2", + "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.9.2.tgz", + "integrity": "sha512-CWBzXQrc/qOkhidw1OzBTQuYRbfyxDXJMVJ1XNwUHGROVmuaeiEm3OslpZ1RV96d7SKKjZKrSJu3+t/xlw3R9A==", + "dev": true, + "license": "Apache-2.0", + "bin": { + "tsc": "bin/tsc", + "tsserver": "bin/tsserver" + }, + "engines": { + "node": ">=14.17" + } + }, + "node_modules/typescript-eslint": { + "version": "8.40.0", + "resolved": "https://registry.npmjs.org/typescript-eslint/-/typescript-eslint-8.40.0.tgz", + "integrity": "sha512-Xvd2l+ZmFDPEt4oj1QEXzA4A2uUK6opvKu3eGN9aGjB8au02lIVcLyi375w94hHyejTOmzIU77L8ol2sRg9n7Q==", + "dev": true, + "license": "MIT", + "dependencies": { + "@typescript-eslint/eslint-plugin": "8.40.0", + "@typescript-eslint/parser": "8.40.0", + "@typescript-eslint/typescript-estree": "8.40.0", + "@typescript-eslint/utils": "8.40.0" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "eslint": "^8.57.0 || ^9.0.0", + "typescript": ">=4.8.4 <6.0.0" + } + }, + "node_modules/unbox-primitive": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/unbox-primitive/-/unbox-primitive-1.1.0.tgz", + "integrity": "sha512-nWJ91DjeOkej/TA8pXQ3myruKpKEYgqvpw9lz4OPHj/NWFNluYrjbz9j01CJ8yKQd2g4jFoOkINCTW2I5LEEyw==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.3", + "has-bigints": "^1.0.2", + "has-symbols": "^1.1.0", + "which-boxed-primitive": "^1.1.1" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/undici-types": { + "version": "7.10.0", + "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-7.10.0.tgz", + "integrity": "sha512-t5Fy/nfn+14LuOc2KNYg75vZqClpAiqscVvMygNnlsHBFpSXdJaYtXMcdNLpl/Qvc3P2cB3s6lOV51nqsFq4ag==", + "dev": true, + "license": "MIT" + }, + "node_modules/update-browserslist-db": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/update-browserslist-db/-/update-browserslist-db-1.1.3.tgz", + "integrity": "sha512-UxhIZQ+QInVdunkDAaiazvvT/+fXL5Osr0JZlJulepYu6Jd7qJtDZjlur0emRlT71EN3ScPoE7gvsuIKKNavKw==", + "dev": true, + "funding": [ + { + "type": "opencollective", + "url": "https://opencollective.com/browserslist" + }, + { + "type": "tidelift", + "url": "https://tidelift.com/funding/github/npm/browserslist" + }, + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ], + "license": "MIT", + "dependencies": { + "escalade": "^3.2.0", + "picocolors": "^1.1.1" + }, + "bin": { + "update-browserslist-db": "cli.js" + }, + "peerDependencies": { + "browserslist": ">= 4.21.0" + } + }, + "node_modules/uri-js": { + "version": "4.4.1", + "resolved": "https://registry.npmjs.org/uri-js/-/uri-js-4.4.1.tgz", + "integrity": "sha512-7rKUyy33Q1yc98pQ1DAmLtwX109F7TIfWlW1Ydo8Wl1ii1SeHieeh0HHfPeL2fMXK6z0s8ecKs9frCuLJvndBg==", + "dev": true, + "license": "BSD-2-Clause", + "dependencies": { + "punycode": "^2.1.0" + } + }, + "node_modules/utf8-byte-length": { + "version": "1.0.5", + "resolved": "https://registry.npmjs.org/utf8-byte-length/-/utf8-byte-length-1.0.5.tgz", + "integrity": "sha512-Xn0w3MtiQ6zoz2vFyUVruaCL53O/DwUvkEeOvj+uulMm0BkUGYWmBYVyElqZaSLhY6ZD0ulfU3aBra2aVT4xfA==", + "dev": true, + "license": "(WTFPL OR MIT)" + }, + "node_modules/util-deprecate": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/util-deprecate/-/util-deprecate-1.0.2.tgz", + "integrity": "sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw==", + "dev": true, + "license": "MIT" + }, + "node_modules/watchpack": { + "version": "2.4.4", + "resolved": "https://registry.npmjs.org/watchpack/-/watchpack-2.4.4.tgz", + "integrity": "sha512-c5EGNOiyxxV5qmTtAB7rbiXxi1ooX1pQKMLX/MIabJjRA0SJBQOjKF+KSVfHkr9U1cADPon0mRiVe/riyaiDUA==", + "dev": true, + "license": "MIT", + "dependencies": { + "glob-to-regexp": "^0.4.1", + "graceful-fs": "^4.1.2" + }, + "engines": { + "node": ">=10.13.0" + } + }, + "node_modules/webidl-conversions": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/webidl-conversions/-/webidl-conversions-3.0.1.tgz", + "integrity": "sha512-2JAn3z8AR6rjK8Sm8orRC0h/bcl/DqL7tRPdGZ4I1CjdF+EaMLmYxBHyXuKL849eucPFhvBoxMsflfOb8kxaeQ==", + "dev": true, + "license": "BSD-2-Clause" + }, + "node_modules/webpack": { + "version": "5.96.1", + "resolved": "https://registry.npmjs.org/webpack/-/webpack-5.96.1.tgz", + "integrity": "sha512-l2LlBSvVZGhL4ZrPwyr8+37AunkcYj5qh8o6u2/2rzoPc8gxFJkLj1WxNgooi9pnoc06jh0BjuXnamM4qlujZA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/eslint-scope": "^3.7.7", + "@types/estree": "^1.0.6", + "@webassemblyjs/ast": "^1.12.1", + "@webassemblyjs/wasm-edit": "^1.12.1", + "@webassemblyjs/wasm-parser": "^1.12.1", + "acorn": "^8.14.0", + "browserslist": "^4.24.0", + "chrome-trace-event": "^1.0.2", + "enhanced-resolve": "^5.17.1", + "es-module-lexer": "^1.2.1", + "eslint-scope": "5.1.1", + "events": "^3.2.0", + "glob-to-regexp": "^0.4.1", + "graceful-fs": "^4.2.11", + "json-parse-even-better-errors": "^2.3.1", + "loader-runner": "^4.2.0", + "mime-types": "^2.1.27", + "neo-async": "^2.6.2", + "schema-utils": "^3.2.0", + "tapable": "^2.1.1", + "terser-webpack-plugin": "^5.3.10", + "watchpack": "^2.4.1", + "webpack-sources": "^3.2.3" + }, + "bin": { + "webpack": "bin/webpack.js" + }, + "engines": { + "node": ">=10.13.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/webpack" + }, + "peerDependenciesMeta": { + "webpack-cli": { + "optional": true + } + } + }, + "node_modules/webpack-sources": { + "version": "3.3.3", + "resolved": "https://registry.npmjs.org/webpack-sources/-/webpack-sources-3.3.3.tgz", + "integrity": "sha512-yd1RBzSGanHkitROoPFd6qsrxt+oFhg/129YzheDGqeustzX0vTZJZsSsQjVQC4yzBQ56K55XU8gaNCtIzOnTg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=10.13.0" + } + }, + "node_modules/webpack/node_modules/eslint-scope": { + "version": "5.1.1", + "resolved": "https://registry.npmjs.org/eslint-scope/-/eslint-scope-5.1.1.tgz", + "integrity": "sha512-2NxwbF/hZ0KpepYN0cNbo+FN6XoK7GaHlQhgx/hIZl6Va0bF45RQOOwhLIy8lQDbuCiadSLCBnH2CFYquit5bw==", + "dev": true, + "license": "BSD-2-Clause", + "dependencies": { + "esrecurse": "^4.3.0", + "estraverse": "^4.1.1" + }, + "engines": { + "node": ">=8.0.0" + } + }, + "node_modules/webpack/node_modules/estraverse": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/estraverse/-/estraverse-4.3.0.tgz", + "integrity": "sha512-39nnKffWz8xN1BU/2c79n9nB9HDzo0niYUqx6xyqUnyoAnQyyWpOTdZEeiCch8BBu515t4wp9ZmgVfVhn9EBpw==", + "dev": true, + "license": "BSD-2-Clause", + "engines": { + "node": ">=4.0" + } + }, + "node_modules/whatwg-url": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/whatwg-url/-/whatwg-url-5.0.0.tgz", + "integrity": "sha512-saE57nupxk6v3HY35+jzBwYa0rKSy0XR8JSxZPwgLr7ys0IBzhGviA1/TUGJLmSVqs8pb9AnvICXEuOHLprYTw==", + "dev": true, + "license": "MIT", + "dependencies": { + "tr46": "~0.0.3", + "webidl-conversions": "^3.0.0" + } + }, + "node_modules/which": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz", + "integrity": "sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA==", + "dev": true, + "license": "ISC", + "dependencies": { + "isexe": "^2.0.0" + }, + "bin": { + "node-which": "bin/node-which" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/which-boxed-primitive": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/which-boxed-primitive/-/which-boxed-primitive-1.1.1.tgz", + "integrity": "sha512-TbX3mj8n0odCBFVlY8AxkqcHASw3L60jIuF8jFP78az3C2YhmGvqbHBpAjTRH2/xqYunrJ9g1jSyjCjpoWzIAA==", + "dev": true, + "license": "MIT", + "dependencies": { + "is-bigint": "^1.1.0", + "is-boolean-object": "^1.2.1", + "is-number-object": "^1.1.1", + "is-string": "^1.1.1", + "is-symbol": "^1.1.1" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/which-builtin-type": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/which-builtin-type/-/which-builtin-type-1.2.1.tgz", + "integrity": "sha512-6iBczoX+kDQ7a3+YJBnh3T+KZRxM/iYNPXicqk66/Qfm1b93iu+yOImkg0zHbj5LNOcNv1TEADiZ0xa34B4q6Q==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.2", + "function.prototype.name": "^1.1.6", + "has-tostringtag": "^1.0.2", + "is-async-function": "^2.0.0", + "is-date-object": "^1.1.0", + "is-finalizationregistry": "^1.1.0", + "is-generator-function": "^1.0.10", + "is-regex": "^1.2.1", + "is-weakref": "^1.0.2", + "isarray": "^2.0.5", + "which-boxed-primitive": "^1.1.0", + "which-collection": "^1.0.2", + "which-typed-array": "^1.1.16" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/which-builtin-type/node_modules/isarray": { + "version": "2.0.5", + "resolved": "https://registry.npmjs.org/isarray/-/isarray-2.0.5.tgz", + "integrity": "sha512-xHjhDr3cNBK0BzdUJSPXZntQUx/mwMS5Rw4A7lPJ90XGAO6ISP/ePDNuo0vhqOZU+UD5JoodwCAAoZQd3FeAKw==", + "dev": true, + "license": "MIT" + }, + "node_modules/which-collection": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/which-collection/-/which-collection-1.0.2.tgz", + "integrity": "sha512-K4jVyjnBdgvc86Y6BkaLZEN933SwYOuBFkdmBu9ZfkcAbdVbpITnDmjvZ/aQjRXQrv5EPkTnD1s39GiiqbngCw==", + "dev": true, + "license": "MIT", + "dependencies": { + "is-map": "^2.0.3", + "is-set": "^2.0.3", + "is-weakmap": "^2.0.2", + "is-weakset": "^2.0.3" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/which-typed-array": { + "version": "1.1.19", + "resolved": "https://registry.npmjs.org/which-typed-array/-/which-typed-array-1.1.19.tgz", + "integrity": "sha512-rEvr90Bck4WZt9HHFC4DJMsjvu7x+r6bImz0/BrbWb7A2djJ8hnZMrWnHo9F8ssv0OMErasDhftrfROTyqSDrw==", + "dev": true, + "license": "MIT", + "dependencies": { + "available-typed-arrays": "^1.0.7", + "call-bind": "^1.0.8", + "call-bound": "^1.0.4", + "for-each": "^0.3.5", + "get-proto": "^1.0.1", + "gopd": "^1.2.0", + "has-tostringtag": "^1.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/word-wrap": { + "version": "1.2.5", + "resolved": "https://registry.npmjs.org/word-wrap/-/word-wrap-1.2.5.tgz", + "integrity": "sha512-BN22B5eaMMI9UMtjrGd5g5eCYPpCPDUy0FJXbYsaT5zYxjFOckS53SQDE3pWkVoWpHXVb3BrYcEN4Twa55B5cA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/wrap-ansi": { + "version": "8.1.0", + "resolved": "https://registry.npmjs.org/wrap-ansi/-/wrap-ansi-8.1.0.tgz", + "integrity": "sha512-si7QWI6zUMq56bESFvagtmzMdGOtoxfR+Sez11Mobfc7tm+VkUckk9bW2UeffTGVUbOksxmSw0AA2gs8g71NCQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "ansi-styles": "^6.1.0", + "string-width": "^5.0.1", + "strip-ansi": "^7.0.1" + }, + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/chalk/wrap-ansi?sponsor=1" + } + }, + "node_modules/wrap-ansi-cjs": { + "name": "wrap-ansi", + "version": "7.0.0", + "resolved": "https://registry.npmjs.org/wrap-ansi/-/wrap-ansi-7.0.0.tgz", + "integrity": "sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q==", + "dev": true, + "license": "MIT", + "dependencies": { + "ansi-styles": "^4.0.0", + "string-width": "^4.1.0", + "strip-ansi": "^6.0.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/chalk/wrap-ansi?sponsor=1" + } + }, + "node_modules/wrap-ansi-cjs/node_modules/ansi-regex": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz", + "integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/wrap-ansi-cjs/node_modules/emoji-regex": { + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz", + "integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==", + "dev": true, + "license": "MIT" + }, + "node_modules/wrap-ansi-cjs/node_modules/string-width": { + "version": "4.2.3", + "resolved": "https://registry.npmjs.org/string-width/-/string-width-4.2.3.tgz", + "integrity": "sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==", + "dev": true, + "license": "MIT", + "dependencies": { + "emoji-regex": "^8.0.0", + "is-fullwidth-code-point": "^3.0.0", + "strip-ansi": "^6.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/wrap-ansi-cjs/node_modules/strip-ansi": { + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz", + "integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==", + "dev": true, + "license": "MIT", + "dependencies": { + "ansi-regex": "^5.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/wrap-ansi/node_modules/ansi-styles": { + "version": "6.2.1", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-6.2.1.tgz", + "integrity": "sha512-bN798gFfQX+viw3R7yrGWRqnrN2oRkEkUjjl4JNn4E8GxxbjtG3FbrEIIY3l8/hrwUwIeCZvi4QuOTP4MErVug==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/chalk/ansi-styles?sponsor=1" + } + }, + "node_modules/wrappy": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/wrappy/-/wrappy-1.0.2.tgz", + "integrity": "sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ==", + "dev": true, + "license": "ISC" + }, + "node_modules/ws": { + "version": "8.17.1", + "resolved": "https://registry.npmjs.org/ws/-/ws-8.17.1.tgz", + "integrity": "sha512-6XQFvXTkbfUOZOKKILFG1PDK2NDQs4azKQl26T0YS5CxqWLgXajbPZ+h4gZekJyRqFU8pvnbAbbs/3TgRPy+GQ==", + "license": "MIT", + "engines": { + "node": ">=10.0.0" + }, + "peerDependencies": { + "bufferutil": "^4.0.1", + "utf-8-validate": ">=5.0.2" + }, + "peerDependenciesMeta": { + "bufferutil": { + "optional": true + }, + "utf-8-validate": { + "optional": true + } + } + }, + "node_modules/xmlhttprequest-ssl": { + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/xmlhttprequest-ssl/-/xmlhttprequest-ssl-2.1.2.tgz", + "integrity": "sha512-TEU+nJVUUnA4CYJFLvK5X9AOeH4KvDvhIfm0vV1GaQRtchnG0hgK5p8hw/xjv8cunWYCsiPCSDzObPyhEwq3KQ==", + "engines": { + "node": ">=0.4.0" + } + }, + "node_modules/yocto-queue": { + "version": "0.1.0", + "resolved": "https://registry.npmjs.org/yocto-queue/-/yocto-queue-0.1.0.tgz", + "integrity": "sha512-rVksvsnNCdJ/ohGc6xgPwyN8eheCxsiLM8mxuE/t/mOVqJewPuO1miLpTHQiRgTKCLexL4MeAFVagts7HmNZ2Q==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + } + } +} diff --git a/dimos/web/command-center-extension/package.json b/dimos/web/command-center-extension/package.json new file mode 100644 index 0000000000..9ee8d823a9 --- /dev/null +++ b/dimos/web/command-center-extension/package.json @@ -0,0 +1,43 @@ +{ + "name": "command-center-extension", + "displayName": "command-center-extension", + "description": "", + "publisher": "dimensional", + "homepage": "", + "version": "0.0.0", + "license": "UNLICENSED", + "main": "./dist/extension.js", + "keywords": [], + "scripts": { + "build": "foxglove-extension build", + "foxglove:prepublish": "foxglove-extension build --mode production", + "lint": "eslint .", + "lint:ci": "eslint .", + "lint:fix": "eslint --fix .", + "local-install": "foxglove-extension install", + "package": "foxglove-extension package", + "pretest": "foxglove-extension pretest" + }, + "devDependencies": { + "@foxglove/eslint-plugin": "2.1.0", + "@foxglove/extension": "2.34.0", + "@types/d3": "^7.4.3", + "@types/leaflet": "^1.9.21", + "@types/react": "18.3.24", + "@types/react-dom": "18.3.7", + "create-foxglove-extension": "1.0.6", + "eslint": "9.34.0", + "prettier": "3.6.2", + "react": "18.3.1", + "react-dom": "^18.3.1", + "typescript": "5.9.2" + }, + "dependencies": { + "@types/pako": "^2.0.4", + "d3": "^7.9.0", + "leaflet": "^1.9.4", + "pako": "^2.1.0", + "react-leaflet": "^4.2.1", + "socket.io-client": "^4.8.1" + } +} diff --git a/dimos/web/command-center-extension/src/App.tsx b/dimos/web/command-center-extension/src/App.tsx new file mode 100644 index 0000000000..838f15df59 --- /dev/null +++ b/dimos/web/command-center-extension/src/App.tsx @@ -0,0 +1,115 @@ +import * as React from "react"; + +import Connection from "./Connection"; +import ExplorePanel from "./ExplorePanel"; +import GpsButton from "./GpsButton"; +import KeyboardControlPanel from "./KeyboardControlPanel"; +import VisualizerWrapper from "./components/VisualizerWrapper"; +import LeafletMap from "./components/LeafletMap"; +import { AppAction, AppState, LatLon } from "./types"; + +function appReducer(state: AppState, action: AppAction): AppState { + switch (action.type) { + case "SET_COSTMAP": + return { ...state, costmap: action.payload }; + case "SET_ROBOT_POSE": + return { ...state, robotPose: action.payload }; + case "SET_GPS_LOCATION": + return { ...state, gpsLocation: action.payload }; + case "SET_GPS_TRAVEL_GOAL_POINTS": + return { ...state, gpsTravelGoalPoints: action.payload }; + case "SET_PATH": + return { ...state, path: action.payload }; + case "SET_FULL_STATE": + return { ...state, ...action.payload }; + default: + return state; + } +} + +const initialState: AppState = { + costmap: null, + robotPose: null, + gpsLocation: null, + gpsTravelGoalPoints: null, + path: null, +}; + +export default function App(): React.ReactElement { + const [state, dispatch] = React.useReducer(appReducer, initialState); + const [isGpsMode, setIsGpsMode] = React.useState(false); + const connectionRef = React.useRef(null); + + React.useEffect(() => { + connectionRef.current = new Connection(dispatch); + + return () => { + if (connectionRef.current) { + connectionRef.current.disconnect(); + } + }; + }, []); + + const handleWorldClick = React.useCallback((worldX: number, worldY: number) => { + connectionRef.current?.worldClick(worldX, worldY); + }, []); + + const handleStartExplore = React.useCallback(() => { + connectionRef.current?.startExplore(); + }, []); + + const handleStopExplore = React.useCallback(() => { + connectionRef.current?.stopExplore(); + }, []); + + const handleGpsGoal = React.useCallback((goal: LatLon) => { + connectionRef.current?.sendGpsGoal(goal); + }, []); + + const handleSendMoveCommand = React.useCallback( + (linear: [number, number, number], angular: [number, number, number]) => { + connectionRef.current?.sendMoveCommand(linear, angular); + }, + [], + ); + + const handleStopMoveCommand = React.useCallback(() => { + connectionRef.current?.stopMoveCommand(); + }, []); + + return ( +
+ {isGpsMode ? ( + + ) : ( + + )} +
+ setIsGpsMode(true)} + onUseCostmap={() => setIsGpsMode(false)} + > + + +
+
+ ); +} diff --git a/dimos/web/command-center-extension/src/Button.tsx b/dimos/web/command-center-extension/src/Button.tsx new file mode 100644 index 0000000000..8714bb8611 --- /dev/null +++ b/dimos/web/command-center-extension/src/Button.tsx @@ -0,0 +1,24 @@ +interface ButtonProps { + onClick: () => void; + isActive: boolean; + children: React.ReactNode; +} + +export default function Button({ onClick, isActive, children }: ButtonProps): React.ReactElement { + return ( + + ); +} diff --git a/dimos/web/command-center-extension/src/Connection.ts b/dimos/web/command-center-extension/src/Connection.ts new file mode 100644 index 0000000000..7a23c6b98c --- /dev/null +++ b/dimos/web/command-center-extension/src/Connection.ts @@ -0,0 +1,110 @@ +import { io, Socket } from "socket.io-client"; + +import { + AppAction, + Costmap, + EncodedCostmap, + EncodedPath, + EncodedVector, + FullStateData, + LatLon, + Path, + TwistCommand, + Vector, +} from "./types"; + +export default class Connection { + socket: Socket; + dispatch: React.Dispatch; + + constructor(dispatch: React.Dispatch) { + this.dispatch = dispatch; + this.socket = io("ws://localhost:7779"); + + this.socket.on("costmap", (data: EncodedCostmap) => { + const costmap = Costmap.decode(data); + this.dispatch({ type: "SET_COSTMAP", payload: costmap }); + }); + + this.socket.on("robot_pose", (data: EncodedVector) => { + const robotPose = Vector.decode(data); + this.dispatch({ type: "SET_ROBOT_POSE", payload: robotPose }); + }); + + this.socket.on("gps_location", (data: LatLon) => { + this.dispatch({ type: "SET_GPS_LOCATION", payload: data }); + }); + + this.socket.on("gps_travel_goal_points", (data: LatLon[]) => { + this.dispatch({ type: "SET_GPS_TRAVEL_GOAL_POINTS", payload: data }); + }); + + this.socket.on("path", (data: EncodedPath) => { + const path = Path.decode(data); + this.dispatch({ type: "SET_PATH", payload: path }); + }); + + this.socket.on("full_state", (data: FullStateData) => { + const state: Partial<{ costmap: Costmap; robotPose: Vector; gpsLocation: LatLon; gpsTravelGoalPoints: LatLon[]; path: Path }> = {}; + + if (data.costmap != undefined) { + state.costmap = Costmap.decode(data.costmap); + } + if (data.robot_pose != undefined) { + state.robotPose = Vector.decode(data.robot_pose); + } + if (data.gps_location != undefined) { + state.gpsLocation = data.gps_location; + } + if (data.path != undefined) { + state.path = Path.decode(data.path); + } + + this.dispatch({ type: "SET_FULL_STATE", payload: state }); + }); + } + + worldClick(worldX: number, worldY: number): void { + this.socket.emit("click", [worldX, worldY]); + } + + startExplore(): void { + this.socket.emit("start_explore"); + } + + stopExplore(): void { + this.socket.emit("stop_explore"); + } + + sendMoveCommand(linear: [number, number, number], angular: [number, number, number]): void { + const twist: TwistCommand = { + linear: { + x: linear[0], + y: linear[1], + z: linear[2], + }, + angular: { + x: angular[0], + y: angular[1], + z: angular[2], + }, + }; + this.socket.emit("move_command", twist); + } + + sendGpsGoal(goal: LatLon): void { + this.socket.emit("gps_goal", goal); + } + + stopMoveCommand(): void { + const twist: TwistCommand = { + linear: { x: 0, y: 0, z: 0 }, + angular: { x: 0, y: 0, z: 0 }, + }; + this.socket.emit("move_command", twist); + } + + disconnect(): void { + this.socket.disconnect(); + } +} diff --git a/dimos/web/command-center-extension/src/ExplorePanel.tsx b/dimos/web/command-center-extension/src/ExplorePanel.tsx new file mode 100644 index 0000000000..6210664591 --- /dev/null +++ b/dimos/web/command-center-extension/src/ExplorePanel.tsx @@ -0,0 +1,41 @@ +import * as React from "react"; + +import Button from "./Button"; + +interface ExplorePanelProps { + onStartExplore: () => void; + onStopExplore: () => void; +} + +export default function ExplorePanel({ + onStartExplore, + onStopExplore, +}: ExplorePanelProps): React.ReactElement { + const [exploring, setExploring] = React.useState(false); + + return ( +
+ {exploring ? ( + + ) : ( + + )} +
+ ); +} diff --git a/dimos/web/command-center-extension/src/GpsButton.tsx b/dimos/web/command-center-extension/src/GpsButton.tsx new file mode 100644 index 0000000000..74f0d73dfd --- /dev/null +++ b/dimos/web/command-center-extension/src/GpsButton.tsx @@ -0,0 +1,41 @@ +import * as React from "react"; + +import Button from "./Button"; + +interface GpsButtonProps { + onUseGps: () => void; + onUseCostmap: () => void; +} + +export default function GpsButton({ + onUseGps, + onUseCostmap, +}: GpsButtonProps): React.ReactElement { + const [gps, setGps] = React.useState(false); + + return ( +
+ {gps ? ( + + ) : ( + + )} +
+ ); +} diff --git a/dimos/web/command-center-extension/src/KeyboardControlPanel.tsx b/dimos/web/command-center-extension/src/KeyboardControlPanel.tsx new file mode 100644 index 0000000000..d4f5402557 --- /dev/null +++ b/dimos/web/command-center-extension/src/KeyboardControlPanel.tsx @@ -0,0 +1,167 @@ +import * as React from "react"; + +import Button from "./Button"; + +interface KeyboardControlPanelProps { + onSendMoveCommand: (linear: [number, number, number], angular: [number, number, number]) => void; + onStopMoveCommand: () => void; +} + +const linearSpeed = 0.5; +const angularSpeed = 0.8; +const publishRate = 10.0; // Hz + +function calculateVelocities(keys: Set) { + let linearX = 0.0; + let linearY = 0.0; + let angularY = 0.0; + let angularZ = 0.0; + + let speedMultiplier = 1.0; + if (keys.has("Shift")) { + speedMultiplier = 2.0; // Boost mode + } else if (keys.has("Control")) { + speedMultiplier = 0.5; // Slow mode + } + + // Check for stop command (space) + if (keys.has(" ")) { + return { linearX: 0, linearY: 0, angularY: 0, angularZ: 0 }; + } + + // Linear X (forward/backward) - W/S + if (keys.has("w")) { + linearX = linearSpeed * speedMultiplier; + } else if (keys.has("s")) { + linearX = -linearSpeed * speedMultiplier; + } + + // Angular Z (yaw/turn) - A/D + if (keys.has("a")) { + angularZ = angularSpeed * speedMultiplier; + } else if (keys.has("d")) { + angularZ = -angularSpeed * speedMultiplier; + } + + // Linear Y (strafe) - Left/Right arrows + if (keys.has("ArrowLeft")) { + linearY = linearSpeed * speedMultiplier; + } else if (keys.has("ArrowRight")) { + linearY = -linearSpeed * speedMultiplier; + } + + // Angular Y (pitch) - Up/Down arrows + if (keys.has("ArrowUp")) { + angularY = angularSpeed * speedMultiplier; + } else if (keys.has("ArrowDown")) { + angularY = -angularSpeed * speedMultiplier; + } + + return { linearX, linearY, angularY, angularZ }; +} + +export default function KeyboardControlPanel({ + onSendMoveCommand, + onStopMoveCommand, +}: KeyboardControlPanelProps): React.ReactElement { + const [isActive, setIsActive] = React.useState(false); + const keysPressed = React.useRef>(new Set()); + const intervalRef = React.useRef(null); + + const handleKeyDown = React.useCallback((event: KeyboardEvent) => { + // Prevent default for arrow keys and space to avoid scrolling + if (["ArrowUp", "ArrowDown", "ArrowLeft", "ArrowRight", " "].includes(event.key)) { + event.preventDefault(); + } + + const normalizedKey = event.key.length === 1 ? event.key.toLowerCase() : event.key; + keysPressed.current.add(normalizedKey); + }, []); + + const handleKeyUp = React.useCallback((event: KeyboardEvent) => { + const normalizedKey = event.key.length === 1 ? event.key.toLowerCase() : event.key; + keysPressed.current.delete(normalizedKey); + }, []); + + // Start/stop keyboard control + React.useEffect(() => { + keysPressed.current.clear(); + + if (!isActive) { + return undefined; + } + + document.addEventListener("keydown", handleKeyDown); + document.addEventListener("keyup", handleKeyUp); + + // Start publishing loop + intervalRef.current = setInterval(() => { + const velocities = calculateVelocities(keysPressed.current); + + onSendMoveCommand( + [velocities.linearX, velocities.linearY, 0], + [0, velocities.angularY, velocities.angularZ], + ); + }, 1000 / publishRate); + + return () => { + document.removeEventListener("keydown", handleKeyDown); + document.removeEventListener("keyup", handleKeyUp); + + if (intervalRef.current) { + clearInterval(intervalRef.current); + intervalRef.current = null; + } + + keysPressed.current.clear(); + onStopMoveCommand(); + }; + }, [isActive, handleKeyDown, handleKeyUp, onSendMoveCommand, onStopMoveCommand]); + + const toggleKeyboardControl = () => { + if (isActive) { + keysPressed.current.clear(); + setIsActive(false); + } else { + setIsActive(true); + } + }; + + React.useEffect(() => { + const handleBlur = () => { + if (isActive) { + keysPressed.current.clear(); + setIsActive(false); + } + }; + + const handleFocus = () => { + // Clear keys when window regains focus to avoid stuck keys + keysPressed.current.clear(); + }; + + window.addEventListener("blur", handleBlur); + window.addEventListener("focus", handleFocus); + + return () => { + window.removeEventListener("blur", handleBlur); + window.removeEventListener("focus", handleFocus); + }; + }, [isActive]); + + return ( +
+ {isActive && ( +
+
Controls:
+
W/S: Forward/Backward | A/D: Turn
+
Arrows: Strafe/Pitch | Space: Stop
+
Shift: Boost | Ctrl: Slow
+
+ )} + +
+ ); +} diff --git a/dimos/web/command-center-extension/src/components/CostmapLayer.tsx b/dimos/web/command-center-extension/src/components/CostmapLayer.tsx new file mode 100644 index 0000000000..3881f6f0d5 --- /dev/null +++ b/dimos/web/command-center-extension/src/components/CostmapLayer.tsx @@ -0,0 +1,165 @@ +import * as d3 from "d3"; +import * as React from "react"; + +import { Costmap } from "../types"; +import GridLayer from "./GridLayer"; + +interface CostmapLayerProps { + costmap: Costmap; + width: number; + height: number; +} + +const CostmapLayer = React.memo(({ costmap, width, height }) => { + const canvasRef = React.useRef(null); + const { grid, origin, resolution } = costmap; + const rows = Math.max(1, grid.shape[0] || 1); + const cols = Math.max(1, grid.shape[1] || 1); + + const axisMargin = { left: 60, bottom: 40 }; + const availableWidth = Math.max(1, width - axisMargin.left); + const availableHeight = Math.max(1, height - axisMargin.bottom); + + const cell = Math.max(0, Math.min(availableWidth / cols, availableHeight / rows)); + const gridW = Math.max(0, cols * cell); + const gridH = Math.max(0, rows * cell); + const offsetX = axisMargin.left + (availableWidth - gridW) / 2; + const offsetY = (availableHeight - gridH) / 2; + + // Pre-compute color lookup table using exact D3 colors (computed once on mount) + const colorLookup = React.useMemo(() => { + const lookup = new Uint8ClampedArray(256 * 3); // RGB values for -1 to 254 (255 total values) + + const customColorScale = (t: number) => { + if (t === 0) { + return "black"; + } + if (t < 0) { + return "#2d2136"; + } + if (t > 0.95) { + return "#000000"; + } + + const color = d3.interpolateTurbo(t * 2 - 1); + const hsl = d3.hsl(color); + hsl.s *= 0.75; + return hsl.toString(); + }; + + const colour = d3.scaleSequential(customColorScale).domain([-1, 100]); + + // Pre-compute all 256 possible color values + for (let i = 0; i < 256; i++) { + const value = i === 255 ? -1 : i; + const colorStr = colour(value); + const c = d3.color(colorStr); + + if (c) { + const rgb = c as d3.RGBColor; + lookup[i * 3] = rgb.r; + lookup[i * 3 + 1] = rgb.g; + lookup[i * 3 + 2] = rgb.b; + } else { + lookup[i * 3] = 0; + lookup[i * 3 + 1] = 0; + lookup[i * 3 + 2] = 0; + } + } + + return lookup; + }, []); + + React.useEffect(() => { + const canvas = canvasRef.current; + if (!canvas) { + return; + } + + // Validate grid data length matches dimensions + const expectedLength = rows * cols; + if (grid.data.length !== expectedLength) { + console.warn( + `Grid data length mismatch: expected ${expectedLength}, got ${grid.data.length} (rows=${rows}, cols=${cols})` + ); + } + + canvas.width = cols; + canvas.height = rows; + const ctx = canvas.getContext("2d"); + if (!ctx) { + return; + } + + const img = ctx.createImageData(cols, rows); + const data = grid.data; + const imgData = img.data; + + for (let i = 0; i < data.length && i < rows * cols; i++) { + const row = Math.floor(i / cols); + const col = i % cols; + const invertedRow = rows - 1 - row; + const srcIdx = invertedRow * cols + col; + + if (srcIdx < 0 || srcIdx >= data.length) { + continue; + } + + const value = data[i]!; + // Map value to lookup index (handle -1 -> 255 mapping) + const lookupIdx = value === -1 ? 255 : Math.min(254, Math.max(0, value)); + + const o = srcIdx * 4; + if (o < 0 || o + 3 >= imgData.length) { + continue; + } + + // Use pre-computed colors from lookup table + const colorOffset = lookupIdx * 3; + imgData[o] = colorLookup[colorOffset]!; + imgData[o + 1] = colorLookup[colorOffset + 1]!; + imgData[o + 2] = colorLookup[colorOffset + 2]!; + imgData[o + 3] = 255; + } + + ctx.putImageData(img, 0, 0); + }, [grid.data, cols, rows, colorLookup]); + + return ( + + +
+ +
+
+ +
+ ); +}); + +CostmapLayer.displayName = "CostmapLayer"; + +export default CostmapLayer; diff --git a/dimos/web/command-center-extension/src/components/GridLayer.tsx b/dimos/web/command-center-extension/src/components/GridLayer.tsx new file mode 100644 index 0000000000..87018cd3af --- /dev/null +++ b/dimos/web/command-center-extension/src/components/GridLayer.tsx @@ -0,0 +1,105 @@ +import * as d3 from "d3"; +import * as React from "react"; + +import { Vector } from "../types"; + +interface GridLayerProps { + width: number; + height: number; + origin: Vector; + resolution: number; + rows: number; + cols: number; +} + +const GridLayer = React.memo( + ({ width, height, origin, resolution, rows, cols }) => { + const minX = origin.coords[0]!; + const minY = origin.coords[1]!; + const maxX = minX + cols * resolution; + const maxY = minY + rows * resolution; + + const xScale = d3.scaleLinear().domain([minX, maxX]).range([0, width]); + const yScale = d3.scaleLinear().domain([minY, maxY]).range([height, 0]); + + const gridSize = 1 / resolution; + const gridLines = React.useMemo(() => { + const lines = []; + for (const x of d3.range(Math.ceil(minX / gridSize) * gridSize, maxX, gridSize)) { + lines.push( + , + ); + } + for (const y of d3.range(Math.ceil(minY / gridSize) * gridSize, maxY, gridSize)) { + lines.push( + , + ); + } + return lines; + }, [minX, minY, maxX, maxY, gridSize, xScale, yScale, width, height]); + + const xAxisRef = React.useRef(null); + const yAxisRef = React.useRef(null); + + React.useEffect(() => { + if (xAxisRef.current) { + const xAxis = d3.axisBottom(xScale).ticks(7); + d3.select(xAxisRef.current).call(xAxis); + d3.select(xAxisRef.current) + .selectAll("line,path") + .attr("stroke", "#ffffff") + .attr("stroke-width", 1); + d3.select(xAxisRef.current).selectAll("text").attr("fill", "#ffffff"); + } + if (yAxisRef.current) { + const yAxis = d3.axisLeft(yScale).ticks(7); + d3.select(yAxisRef.current).call(yAxis); + d3.select(yAxisRef.current) + .selectAll("line,path") + .attr("stroke", "#ffffff") + .attr("stroke-width", 1); + d3.select(yAxisRef.current).selectAll("text").attr("fill", "#ffffff"); + } + }, [xScale, yScale]); + + const showOrigin = minX <= 0 && 0 <= maxX && minY <= 0 && 0 <= maxY; + + return ( + <> + {gridLines} + + + {showOrigin && ( + + + + World Origin (0,0) + + + )} + + ); + }, +); + +GridLayer.displayName = "GridLayer"; + +export default GridLayer; diff --git a/dimos/web/command-center-extension/src/components/LeafletMap.tsx b/dimos/web/command-center-extension/src/components/LeafletMap.tsx new file mode 100644 index 0000000000..d0ad2380c4 --- /dev/null +++ b/dimos/web/command-center-extension/src/components/LeafletMap.tsx @@ -0,0 +1,150 @@ +import * as React from "react"; +import { MapContainer, TileLayer, Marker, Popup, useMapEvents } from "react-leaflet"; +import L, { LatLngExpression } from "leaflet"; +import { LatLon } from "../types"; + +// Fix for default marker icons in react-leaflet +// Using CDN URLs since webpack can't handle the image imports directly +const DefaultIcon = L.icon({ + iconUrl: "https://unpkg.com/leaflet@1.9.4/dist/images/marker-icon.png", + shadowUrl: "https://unpkg.com/leaflet@1.9.4/dist/images/marker-shadow.png", + iconSize: [25, 41], + iconAnchor: [12, 41], +}); + +L.Marker.prototype.options.icon = DefaultIcon; + +// Component to handle map click events +function MapClickHandler({ onMapClick }: { onMapClick: (lat: number, lng: number) => void }) { + useMapEvents({ + click: (e) => { + onMapClick(e.latlng.lat, e.latlng.lng); + }, + }); + return null; +} + +interface LeafletMapProps { + gpsLocation: LatLon | null; + gpsTravelGoalPoints: LatLon[] | null; + onGpsGoal: (goal: LatLon) => void; +} + +const LeafletMap: React.FC = ({ gpsLocation, gpsTravelGoalPoints, onGpsGoal }) => { + if (!gpsLocation) { + return ( +
+ GPS location not received yet. +
+ ); + } + + const center: LatLngExpression = [gpsLocation.lat, gpsLocation.lon]; + + return ( +
+ + + + { + onGpsGoal({ lat, lon: lng }); + }} /> + + Current GPS Location + + {gpsTravelGoalPoints !== null && ( + gpsTravelGoalPoints.map(p => ( + + )) + )} + +
+ ); +}; + +const leafletCss = ` +.leaflet-control-container { + display: none; +} +.leaflet-container { + width: 100%; + height: 100%; + position: relative; +} +.leaflet-pane, +.leaflet-tile, +.leaflet-marker-icon, +.leaflet-marker-shadow, +.leaflet-tile-container, +.leaflet-pane > svg, +.leaflet-pane > canvas, +.leaflet-zoom-box, +.leaflet-image-layer, +.leaflet-layer { + position: absolute; + left: 0; + top: 0; +} +.leaflet-container { + overflow: hidden; + -webkit-tap-highlight-color: transparent; + background: #ddd; + outline: 0; + font: 12px/1.5 "Helvetica Neue", Arial, Helvetica, sans-serif; +} +.leaflet-tile { + filter: inherit; + visibility: hidden; +} +.leaflet-tile-loaded { + visibility: inherit; +} +.leaflet-zoom-box { + width: 0; + height: 0; + -moz-box-sizing: border-box; + box-sizing: border-box; + z-index: 800; +} +.leaflet-control { + position: relative; + z-index: 800; + pointer-events: visiblePainted; + pointer-events: auto; +} +.leaflet-top, +.leaflet-bottom { + position: absolute; + z-index: 1000; + pointer-events: none; +} +.leaflet-top { + top: 0; +} +.leaflet-right { + right: 0; +} +.leaflet-bottom { + bottom: 0; +} +.leaflet-left { + left: 0; +} +`; + +export default LeafletMap; diff --git a/dimos/web/command-center-extension/src/components/PathLayer.tsx b/dimos/web/command-center-extension/src/components/PathLayer.tsx new file mode 100644 index 0000000000..969c9cf7dc --- /dev/null +++ b/dimos/web/command-center-extension/src/components/PathLayer.tsx @@ -0,0 +1,57 @@ +import * as d3 from "d3"; +import * as React from "react"; + +import { Path } from "../types"; + +interface PathLayerProps { + path: Path; + worldToPx: (x: number, y: number) => [number, number]; +} + +const PathLayer = React.memo(({ path, worldToPx }) => { + const points = React.useMemo( + () => path.coords.map(([x, y]) => worldToPx(x, y)), + [path.coords, worldToPx], + ); + + const pathData = React.useMemo(() => { + const line = d3.line(); + return line(points); + }, [points]); + + const gradientId = React.useMemo(() => `path-gradient-${Date.now()}`, []); + + if (path.coords.length < 2) { + return null; + } + + return ( + <> + + + + + + + + + ); +}); + +PathLayer.displayName = "PathLayer"; + +export default PathLayer; diff --git a/dimos/web/command-center-extension/src/components/VectorLayer.tsx b/dimos/web/command-center-extension/src/components/VectorLayer.tsx new file mode 100644 index 0000000000..87b932d0a4 --- /dev/null +++ b/dimos/web/command-center-extension/src/components/VectorLayer.tsx @@ -0,0 +1,41 @@ +import * as React from "react"; + +import { Vector } from "../types"; + +interface VectorLayerProps { + vector: Vector; + label: string; + worldToPx: (x: number, y: number) => [number, number]; +} + +const VectorLayer = React.memo(({ vector, label, worldToPx }) => { + const [cx, cy] = worldToPx(vector.coords[0]!, vector.coords[1]!); + const text = `${label} (${vector.coords[0]!.toFixed(2)}, ${vector.coords[1]!.toFixed(2)})`; + + return ( + <> + + + + + + + + {text} + + + + ); +}); + +VectorLayer.displayName = "VectorLayer"; + +export default VectorLayer; diff --git a/dimos/web/command-center-extension/src/components/VisualizerComponent.tsx b/dimos/web/command-center-extension/src/components/VisualizerComponent.tsx new file mode 100644 index 0000000000..e5bdb7f58e --- /dev/null +++ b/dimos/web/command-center-extension/src/components/VisualizerComponent.tsx @@ -0,0 +1,102 @@ +import * as d3 from "d3"; +import * as React from "react"; + +import { Costmap, Path, Vector } from "../types"; +import CostmapLayer from "./CostmapLayer"; +import PathLayer from "./PathLayer"; +import VectorLayer from "./VectorLayer"; + +interface VisualizerComponentProps { + costmap: Costmap | null; + robotPose: Vector | null; + path: Path | null; +} + +const VisualizerComponent: React.FC = ({ costmap, robotPose, path }) => { + const svgRef = React.useRef(null); + const [dimensions, setDimensions] = React.useState({ width: 800, height: 600 }); + const { width, height } = dimensions; + + React.useEffect(() => { + if (!svgRef.current?.parentElement) { + return; + } + + const updateDimensions = () => { + const rect = svgRef.current?.parentElement?.getBoundingClientRect(); + if (rect) { + setDimensions({ width: rect.width, height: rect.height }); + } + }; + + updateDimensions(); + const observer = new ResizeObserver(updateDimensions); + observer.observe(svgRef.current.parentElement); + + return () => { + observer.disconnect(); + }; + }, []); + + const { worldToPx } = React.useMemo(() => { + if (!costmap) { + return { worldToPx: undefined }; + } + + const { + grid: { shape }, + origin, + resolution, + } = costmap; + const rows = shape[0]!; + const cols = shape[1]!; + + const axisMargin = { left: 60, bottom: 40 }; + const availableWidth = width - axisMargin.left; + const availableHeight = height - axisMargin.bottom; + + const cell = Math.min(availableWidth / cols, availableHeight / rows); + const gridW = cols * cell; + const gridH = rows * cell; + const offsetX = axisMargin.left + (availableWidth - gridW) / 2; + const offsetY = (availableHeight - gridH) / 2; + + const xScale = d3 + .scaleLinear() + .domain([origin.coords[0]!, origin.coords[0]! + cols * resolution]) + .range([offsetX, offsetX + gridW]); + + const yScale = d3 + .scaleLinear() + .domain([origin.coords[1]!, origin.coords[1]! + rows * resolution]) + .range([offsetY + gridH, offsetY]); + + const worldToPxFn = (x: number, y: number): [number, number] => [xScale(x), yScale(y)]; + + return { worldToPx: worldToPxFn }; + }, [costmap, width, height]); + + return ( +
+ + {costmap && } + {path && worldToPx && } + {robotPose && worldToPx && ( + + )} + +
+ ); +}; + +export default React.memo(VisualizerComponent); diff --git a/dimos/web/command-center-extension/src/components/VisualizerWrapper.tsx b/dimos/web/command-center-extension/src/components/VisualizerWrapper.tsx new file mode 100644 index 0000000000..e137019ae1 --- /dev/null +++ b/dimos/web/command-center-extension/src/components/VisualizerWrapper.tsx @@ -0,0 +1,86 @@ +import * as d3 from "d3"; +import * as React from "react"; + +import { AppState } from "../types"; +import VisualizerComponent from "./VisualizerComponent"; + +interface VisualizerWrapperProps { + data: AppState; + onWorldClick: (worldX: number, worldY: number) => void; +} + +const VisualizerWrapper: React.FC = ({ data, onWorldClick }) => { + const containerRef = React.useRef(null); + const lastClickTime = React.useRef(0); + const clickThrottleMs = 150; + + const handleClick = React.useCallback( + (event: React.MouseEvent) => { + if (!data.costmap || !containerRef.current) { + return; + } + + event.stopPropagation(); + + const now = Date.now(); + if (now - lastClickTime.current < clickThrottleMs) { + console.log("Click throttled"); + return; + } + lastClickTime.current = now; + + const svgElement = containerRef.current.querySelector("svg"); + if (!svgElement) { + return; + } + + const svgRect = svgElement.getBoundingClientRect(); + const clickX = event.clientX - svgRect.left; + const clickY = event.clientY - svgRect.top; + + const costmap = data.costmap; + const { + grid: { shape }, + origin, + resolution, + } = costmap; + const rows = shape[0]!; + const cols = shape[1]!; + const width = svgRect.width; + const height = svgRect.height; + + const axisMargin = { left: 60, bottom: 40 }; + const availableWidth = width - axisMargin.left; + const availableHeight = height - axisMargin.bottom; + + const cell = Math.min(availableWidth / cols, availableHeight / rows); + const gridW = cols * cell; + const gridH = rows * cell; + const offsetX = axisMargin.left + (availableWidth - gridW) / 2; + const offsetY = (availableHeight - gridH) / 2; + + const xScale = d3 + .scaleLinear() + .domain([origin.coords[0]!, origin.coords[0]! + cols * resolution]) + .range([offsetX, offsetX + gridW]); + const yScale = d3 + .scaleLinear() + .domain([origin.coords[1]!, origin.coords[1]! + rows * resolution]) + .range([offsetY + gridH, offsetY]); + + const worldX = xScale.invert(clickX); + const worldY = yScale.invert(clickY); + + onWorldClick(worldX, worldY); + }, + [data.costmap, onWorldClick], + ); + + return ( +
+ +
+ ); +}; + +export default VisualizerWrapper; diff --git a/dimos/web/command-center-extension/src/index.ts b/dimos/web/command-center-extension/src/index.ts new file mode 100644 index 0000000000..052f967e37 --- /dev/null +++ b/dimos/web/command-center-extension/src/index.ts @@ -0,0 +1,14 @@ +import { PanelExtensionContext, ExtensionContext } from "@foxglove/extension"; + +import { initializeApp } from "./init"; + +export function activate(extensionContext: ExtensionContext): void { + extensionContext.registerPanel({ name: "command-center", initPanel }); +} + +export function initPanel(context: PanelExtensionContext): () => void { + initializeApp(context.panelElement); + return () => { + // Cleanup function + }; +} diff --git a/dimos/web/command-center-extension/src/init.ts b/dimos/web/command-center-extension/src/init.ts new file mode 100644 index 0000000000..f57f3aa582 --- /dev/null +++ b/dimos/web/command-center-extension/src/init.ts @@ -0,0 +1,9 @@ +import * as React from "react"; +import * as ReactDOMClient from "react-dom/client"; + +import App from "./App"; + +export function initializeApp(element: HTMLElement): void { + const root = ReactDOMClient.createRoot(element); + root.render(React.createElement(App)); +} diff --git a/dimos/web/command-center-extension/src/optimizedCostmap.ts b/dimos/web/command-center-extension/src/optimizedCostmap.ts new file mode 100644 index 0000000000..2244437eab --- /dev/null +++ b/dimos/web/command-center-extension/src/optimizedCostmap.ts @@ -0,0 +1,120 @@ +import * as pako from 'pako'; + +export interface EncodedOptimizedGrid { + update_type: "full" | "delta"; + shape: [number, number]; + dtype: string; + compressed: boolean; + compression?: "zlib" | "none"; + data?: string; + chunks?: Array<{ + pos: [number, number]; + size: [number, number]; + data: string; + }>; +} + +export class OptimizedGrid { + private fullGrid: Uint8Array | null = null; + private shape: [number, number] = [0, 0]; + + decode(msg: EncodedOptimizedGrid): Float32Array { + if (msg.update_type === "full") { + return this.decodeFull(msg); + } else { + return this.decodeDelta(msg); + } + } + + private decodeFull(msg: EncodedOptimizedGrid): Float32Array { + if (!msg.data) { + throw new Error("Missing data for full update"); + } + + const binaryString = atob(msg.data); + const compressed = new Uint8Array(binaryString.length); + for (let i = 0; i < binaryString.length; i++) { + compressed[i] = binaryString.charCodeAt(i); + } + + // Decompress if needed + let decompressed: Uint8Array; + if (msg.compressed && msg.compression === "zlib") { + decompressed = pako.inflate(compressed); + } else { + decompressed = compressed; + } + + // Store for delta updates + this.fullGrid = decompressed; + this.shape = msg.shape; + + // Convert uint8 back to float32 costmap values + const float32Data = new Float32Array(decompressed.length); + for (let i = 0; i < decompressed.length; i++) { + // Map 255 back to -1 for unknown cells + const val = decompressed[i]!; + float32Data[i] = val === 255 ? -1 : val; + } + + return float32Data; + } + + private decodeDelta(msg: EncodedOptimizedGrid): Float32Array { + if (!this.fullGrid) { + console.warn("No full grid available for delta update - skipping until full update arrives"); + const size = msg.shape[0] * msg.shape[1]; + return new Float32Array(size).fill(-1); + } + + if (!msg.chunks) { + throw new Error("Missing chunks for delta update"); + } + + // Apply delta updates to the full grid + for (const chunk of msg.chunks) { + const [y, x] = chunk.pos; + const [h, w] = chunk.size; + + // Decode chunk data + const binaryString = atob(chunk.data); + const compressed = new Uint8Array(binaryString.length); + for (let i = 0; i < binaryString.length; i++) { + compressed[i] = binaryString.charCodeAt(i); + } + + let decompressed: Uint8Array; + if (msg.compressed && msg.compression === "zlib") { + decompressed = pako.inflate(compressed); + } else { + decompressed = compressed; + } + + // Update the full grid with chunk data + const width = this.shape[1]; + let chunkIdx = 0; + for (let cy = 0; cy < h; cy++) { + for (let cx = 0; cx < w; cx++) { + const gridIdx = (y + cy) * width + (x + cx); + const val = decompressed[chunkIdx++]; + if (val !== undefined) { + this.fullGrid[gridIdx] = val; + } + } + } + } + + // Convert to float32 + const float32Data = new Float32Array(this.fullGrid.length); + for (let i = 0; i < this.fullGrid.length; i++) { + const val = this.fullGrid[i]!; + float32Data[i] = val === 255 ? -1 : val; + } + + return float32Data; + } + + getShape(): [number, number] { + return this.shape; + } +} diff --git a/dimos/web/command-center-extension/src/types.ts b/dimos/web/command-center-extension/src/types.ts new file mode 100644 index 0000000000..5f3a804a9c --- /dev/null +++ b/dimos/web/command-center-extension/src/types.ts @@ -0,0 +1,127 @@ +import { EncodedOptimizedGrid, OptimizedGrid } from './optimizedCostmap'; + +export type EncodedVector = Encoded<"vector"> & { + c: number[]; +}; + +export class Vector { + coords: number[]; + constructor(...coords: number[]) { + this.coords = coords; + } + + static decode(data: EncodedVector): Vector { + return new Vector(...data.c); + } +} + +export interface LatLon { + lat: number; + lon: number; + alt?: number; +} + +export type EncodedPath = Encoded<"path"> & { + points: Array<[number, number]>; +}; + +export class Path { + constructor(public coords: Array<[number, number]>) {} + + static decode(data: EncodedPath): Path { + return new Path(data.points); + } +} + +export type EncodedCostmap = Encoded<"costmap"> & { + grid: EncodedOptimizedGrid; + origin: EncodedVector; + resolution: number; + origin_theta: number; +}; + +export class Costmap { + constructor( + public grid: Grid, + public origin: Vector, + public resolution: number, + public origin_theta: number, + ) { + this.grid = grid; + this.origin = origin; + this.resolution = resolution; + this.origin_theta = origin_theta; + } + + private static decoder: OptimizedGrid | null = null; + + static decode(data: EncodedCostmap): Costmap { + // Use a singleton decoder to maintain state for delta updates + if (!Costmap.decoder) { + Costmap.decoder = new OptimizedGrid(); + } + + const float32Data = Costmap.decoder.decode(data.grid); + const shape = data.grid.shape; + + // Create a Grid object from the decoded data + const grid = new Grid(float32Data, shape); + + return new Costmap( + grid, + Vector.decode(data.origin), + data.resolution, + data.origin_theta, + ); + } +} + +export class Grid { + constructor( + public data: Float32Array | Float64Array | Int32Array | Int8Array, + public shape: number[], + ) {} +} + +export type Drawable = Costmap | Vector | Path; + +export type Encoded = { + type: T; +}; + +export interface FullStateData { + costmap?: EncodedCostmap; + robot_pose?: EncodedVector; + gps_location?: LatLon; + gps_travel_goal_points?: LatLon[]; + path?: EncodedPath; +} + +export interface TwistCommand { + linear: { + x: number; + y: number; + z: number; + }; + angular: { + x: number; + y: number; + z: number; + }; +} + +export interface AppState { + costmap: Costmap | null; + robotPose: Vector | null; + gpsLocation: LatLon | null; + gpsTravelGoalPoints: LatLon[] | null; + path: Path | null; +} + +export type AppAction = + | { type: "SET_COSTMAP"; payload: Costmap } + | { type: "SET_ROBOT_POSE"; payload: Vector } + | { type: "SET_GPS_LOCATION"; payload: LatLon } + | { type: "SET_GPS_TRAVEL_GOAL_POINTS"; payload: LatLon[] } + | { type: "SET_PATH"; payload: Path } + | { type: "SET_FULL_STATE"; payload: Partial }; diff --git a/dimos/web/command-center-extension/tsconfig.json b/dimos/web/command-center-extension/tsconfig.json new file mode 100644 index 0000000000..b4ead7c4a8 --- /dev/null +++ b/dimos/web/command-center-extension/tsconfig.json @@ -0,0 +1,22 @@ +{ + "extends": "create-foxglove-extension/tsconfig/tsconfig.json", + "include": [ + "./src/**/*" + ], + "compilerOptions": { + "rootDir": "./src", + "outDir": "./dist", + "lib": [ + "dom" + ], + "composite": false, + "declaration": false, + "noFallthroughCasesInSwitch": true, + "noImplicitAny": true, + "noImplicitReturns": true, + "noUncheckedIndexedAccess": true, + "noUnusedLocals": true, + "noUnusedParameters": true, + "forceConsistentCasingInFileNames": true + } +} diff --git a/dimos/web/dimos_interface/.gitignore b/dimos/web/dimos_interface/.gitignore new file mode 100644 index 0000000000..8f2a0d7c82 --- /dev/null +++ b/dimos/web/dimos_interface/.gitignore @@ -0,0 +1,41 @@ +# Logs +logs +*.log +npm-debug.log* +yarn-debug.log* +yarn-error.log* +pnpm-debug.log* +lerna-debug.log* + +# Dependencies and builds +node_modules +dist +dist-ssr +.vite/ +*.local +dist.zip +yarn.lock +package-lock.json + +# Editor directories and files +.vscode/* +!.vscode/extensions.json +.idea +.DS_Store +*.suo +*.ntvs* +*.njsproj +*.sln +*.sw? + +# Environment variables +.env +.env.* +!.env.example + +# GitHub directory from original repo +.github/ + +docs/ +vite.config.ts.timestamp-* +httpd.conf diff --git a/dimos/web/dimos_interface/__init__.py b/dimos/web/dimos_interface/__init__.py new file mode 100644 index 0000000000..5ca28b30e5 --- /dev/null +++ b/dimos/web/dimos_interface/__init__.py @@ -0,0 +1,7 @@ +""" +Dimensional Interface package +""" + +from .api.server import FastAPIServer + +__all__ = ["FastAPIServer"] diff --git a/dimos/web/dimos_interface/api/README.md b/dimos/web/dimos_interface/api/README.md new file mode 100644 index 0000000000..37cafd6e52 --- /dev/null +++ b/dimos/web/dimos_interface/api/README.md @@ -0,0 +1,86 @@ +# Unitree API Server + +This is a minimal FastAPI server implementation that provides API endpoints for the terminal frontend. + +## Quick Start + +```bash +# Navigate to the api directory +cd api + +# Install minimal requirements +pip install -r requirements.txt + +# Run the server +python unitree_server.py +``` + +The server will start on `http://0.0.0.0:5555`. + +## Integration with Frontend + +1. Start the API server as described above +2. In another terminal, run the frontend from the root directory: + ```bash + cd .. # Navigate to root directory (if you're in api/) + yarn dev + ``` +3. Use the `unitree` command in the terminal interface: + - `unitree status` - Check the API status + - `unitree command ` - Send a command to the API + +## Integration with DIMOS Agents + +See DimOS Documentation for more info. + +```python +from dimos.agents.agent import OpenAIAgent +from dimos.robot.unitree.unitree_go2 import UnitreeGo2 +from dimos.robot.unitree.unitree_skills import MyUnitreeSkills +from dimos.web.robot_web_interface import RobotWebInterface + +robot_ip = os.getenv("ROBOT_IP") + +# Initialize robot +logger.info("Initializing Unitree Robot") +robot = UnitreeGo2(ip=robot_ip, + connection_method=connection_method, + output_dir=output_dir) + +# Set up video stream +logger.info("Starting video stream") +video_stream = robot.get_ros_video_stream() + +# Create FastAPI server with video stream +logger.info("Initializing FastAPI server") +streams = {"unitree_video": video_stream} +web_interface = RobotWebInterface(port=5555, **streams) + +# Initialize agent with robot skills +skills_instance = MyUnitreeSkills(robot=robot) + +agent = OpenAIAgent( + dev_name="UnitreeQueryPerceptionAgent", + input_query_stream=web_interface.query_stream, + output_dir=output_dir, + skills=skills_instance, +) + +web_interface.run() +``` + +## API Endpoints + +- **GET /unitree/status**: Check the status of the Unitree API +- **POST /unitree/command**: Send a command to the Unitree API + +## How It Works + +The frontend and backend are separate applications: + +1. The Svelte frontend runs on port 3000 via Vite +2. The FastAPI backend runs on port 5555 +3. Vite's development server proxies requests from `/unitree/*` to the FastAPI server +4. The `unitree` command in the terminal interface sends requests to these endpoints + +This architecture allows the frontend and backend to be developed and run independently. diff --git a/dimos/web/dimos_interface/api/__init__.py b/dimos/web/dimos_interface/api/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dimos/web/dimos_interface/api/requirements.txt b/dimos/web/dimos_interface/api/requirements.txt new file mode 100644 index 0000000000..a1ab33e428 --- /dev/null +++ b/dimos/web/dimos_interface/api/requirements.txt @@ -0,0 +1,7 @@ +fastapi==0.104.1 +uvicorn==0.24.0 +reactivex==4.0.4 +numpy<2.0.0 # Specify older NumPy version for cv2 compatibility +opencv-python==4.8.1.78 +python-multipart==0.0.6 +jinja2==3.1.2 diff --git a/dimos/web/dimos_interface/api/server.py b/dimos/web/dimos_interface/api/server.py new file mode 100644 index 0000000000..6692e90f46 --- /dev/null +++ b/dimos/web/dimos_interface/api/server.py @@ -0,0 +1,373 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Working FastAPI/Uvicorn Impl. + +# Notes: Do not use simultaneously with Flask, this includes imports. +# Workers are not yet setup, as this requires a much more intricate +# reorganization. There appears to be possible signalling issues when +# opening up streams on multiple windows/reloading which will need to +# be fixed. Also note, Chrome only supports 6 simultaneous web streams, +# and its advised to test threading/worker performance with another +# browser like Safari. + +# Fast Api & Uvicorn +import asyncio + +# For audio processing +import io +from pathlib import Path +from queue import Empty, Queue +from threading import Lock +import time + +import cv2 +from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse +from fastapi.templating import Jinja2Templates +import ffmpeg # type: ignore[import-untyped] +import numpy as np +import reactivex as rx +from reactivex import operators as ops +from reactivex.disposable import SingleAssignmentDisposable +import soundfile as sf # type: ignore[import-untyped] +from sse_starlette.sse import EventSourceResponse +import uvicorn + +from dimos.stream.audio.base import AudioEvent +from dimos.web.edge_io import EdgeIO + +# TODO: Resolve threading, start/stop stream functionality. + + +class FastAPIServer(EdgeIO): + def __init__( # type: ignore[no-untyped-def] + self, + dev_name: str = "FastAPI Server", + edge_type: str = "Bidirectional", + host: str = "0.0.0.0", + port: int = 5555, + text_streams=None, + audio_subject=None, + **streams, + ) -> None: + print("Starting FastAPIServer initialization...") # Debug print + super().__init__(dev_name, edge_type) + self.app = FastAPI() + self._server: uvicorn.Server | None = None + + # Add CORS middleware with more permissive settings for development + self.app.add_middleware( + CORSMiddleware, + allow_origins=["*"], # More permissive for development + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + expose_headers=["*"], + ) + + self.port = port + self.host = host + BASE_DIR = Path(__file__).resolve().parent + self.templates = Jinja2Templates(directory=str(BASE_DIR / "templates")) + self.streams = streams + self.active_streams = {} + self.stream_locks = {key: Lock() for key in self.streams} + self.stream_queues = {} # type: ignore[var-annotated] + self.stream_disposables = {} # type: ignore[var-annotated] + + # Initialize text streams + self.text_streams = text_streams or {} + self.text_queues = {} # type: ignore[var-annotated] + self.text_disposables = {} + self.text_clients = set() # type: ignore[var-annotated] + + # Create a Subject for text queries + self.query_subject = rx.subject.Subject() # type: ignore[var-annotated] + self.query_stream = self.query_subject.pipe(ops.share()) + self.audio_subject = audio_subject + + for key in self.streams: + if self.streams[key] is not None: + self.active_streams[key] = self.streams[key].pipe( + ops.map(self.process_frame_fastapi), ops.share() + ) + + # Set up text stream subscriptions + for key, stream in self.text_streams.items(): + if stream is not None: + self.text_queues[key] = Queue(maxsize=100) + disposable = stream.subscribe( + lambda text, k=key: self.text_queues[k].put(text) if text is not None else None, + lambda e, k=key: self.text_queues[k].put(None), + lambda k=key: self.text_queues[k].put(None), + ) + self.text_disposables[key] = disposable + self.disposables.add(disposable) + + print("Setting up routes...") # Debug print + self.setup_routes() + print("FastAPIServer initialization complete") # Debug print + + def process_frame_fastapi(self, frame): # type: ignore[no-untyped-def] + """Convert frame to JPEG format for streaming.""" + _, buffer = cv2.imencode(".jpg", frame) + return buffer.tobytes() + + def stream_generator(self, key): # type: ignore[no-untyped-def] + """Generate frames for a given video stream.""" + + def generate(): # type: ignore[no-untyped-def] + if key not in self.stream_queues: + self.stream_queues[key] = Queue(maxsize=10) + + frame_queue = self.stream_queues[key] + + # Clear any existing disposable for this stream + if key in self.stream_disposables: + self.stream_disposables[key].dispose() + + disposable = SingleAssignmentDisposable() + self.stream_disposables[key] = disposable + self.disposables.add(disposable) + + if key in self.active_streams: + with self.stream_locks[key]: + # Clear the queue before starting new subscription + while not frame_queue.empty(): + try: + frame_queue.get_nowait() + except Empty: + break + + disposable.disposable = self.active_streams[key].subscribe( + lambda frame: frame_queue.put(frame) if frame is not None else None, + lambda e: frame_queue.put(None), + lambda: frame_queue.put(None), + ) + + try: + while True: + try: + frame = frame_queue.get(timeout=1) + if frame is None: + break + yield (b"--frame\r\nContent-Type: image/jpeg\r\n\r\n" + frame + b"\r\n") + except Empty: + # Instead of breaking, continue waiting for new frames + continue + finally: + if key in self.stream_disposables: + self.stream_disposables[key].dispose() + + return generate + + def create_video_feed_route(self, key): # type: ignore[no-untyped-def] + """Create a video feed route for a specific stream.""" + + async def video_feed(): # type: ignore[no-untyped-def] + return StreamingResponse( + self.stream_generator(key)(), # type: ignore[no-untyped-call] + media_type="multipart/x-mixed-replace; boundary=frame", + ) + + return video_feed + + async def text_stream_generator(self, key): # type: ignore[no-untyped-def] + """Generate SSE events for text stream.""" + client_id = id(object()) + self.text_clients.add(client_id) + + try: + while True: + if key not in self.text_queues: + yield {"event": "ping", "data": ""} + await asyncio.sleep(0.1) + continue + + try: + text = self.text_queues[key].get_nowait() + if text is not None: + yield {"event": "message", "id": key, "data": text} + else: + break + except Empty: + yield {"event": "ping", "data": ""} + await asyncio.sleep(0.1) + finally: + self.text_clients.remove(client_id) + + @staticmethod + def _decode_audio(raw: bytes) -> tuple[np.ndarray, int]: # type: ignore[type-arg] + """Convert the webm/opus blob sent by the browser into mono 16-kHz PCM.""" + try: + # Use ffmpeg to convert to 16-kHz mono 16-bit PCM WAV in memory + out, _ = ( + ffmpeg.input("pipe:0") + .output( + "pipe:1", + format="wav", + acodec="pcm_s16le", + ac=1, + ar="16000", + loglevel="quiet", + ) + .run(input=raw, capture_stdout=True, capture_stderr=True) + ) + # Load with soundfile (returns float32 by default) + audio, sr = sf.read(io.BytesIO(out), dtype="float32") + # Ensure 1-D array (mono) + if audio.ndim > 1: + audio = audio[:, 0] + return np.array(audio), sr + except Exception as exc: + print(f"ffmpeg decoding failed: {exc}") + return None, None # type: ignore[return-value] + + def setup_routes(self) -> None: + """Set up FastAPI routes.""" + + @self.app.get("/streams") + async def get_streams(): # type: ignore[no-untyped-def] + """Get list of available video streams""" + return {"streams": list(self.streams.keys())} + + @self.app.get("/text_streams") + async def get_text_streams(): # type: ignore[no-untyped-def] + """Get list of available text streams""" + return {"streams": list(self.text_streams.keys())} + + @self.app.get("/", response_class=HTMLResponse) + async def index(request: Request): # type: ignore[no-untyped-def] + stream_keys = list(self.streams.keys()) + text_stream_keys = list(self.text_streams.keys()) + return self.templates.TemplateResponse( + "index_fastapi.html", + { + "request": request, + "stream_keys": stream_keys, + "text_stream_keys": text_stream_keys, + "has_voice": self.audio_subject is not None, + }, + ) + + @self.app.post("/submit_query") + async def submit_query(query: str = Form(...)): # type: ignore[no-untyped-def] + # Using Form directly as a dependency ensures proper form handling + try: + if query: + # Emit the query through our Subject + self.query_subject.on_next(query) + return JSONResponse({"success": True, "message": "Query received"}) + return JSONResponse({"success": False, "message": "No query provided"}) + except Exception as e: + # Ensure we always return valid JSON even on error + return JSONResponse( + status_code=500, + content={"success": False, "message": f"Server error: {e!s}"}, + ) + + @self.app.post("/upload_audio") + async def upload_audio(file: UploadFile = File(...)): # type: ignore[no-untyped-def] + """Handle audio upload from the browser.""" + if self.audio_subject is None: + return JSONResponse( + status_code=400, + content={"success": False, "message": "Voice input not configured"}, + ) + + try: + data = await file.read() + audio_np, sr = self._decode_audio(data) + if audio_np is None: + return JSONResponse( + status_code=400, + content={"success": False, "message": "Unable to decode audio"}, + ) + + event = AudioEvent( + data=audio_np, + sample_rate=sr, + timestamp=time.time(), + channels=1 if audio_np.ndim == 1 else audio_np.shape[1], + ) + + # Push to reactive stream + self.audio_subject.on_next(event) + print(f"Received audio - {event.data.shape[0] / sr:.2f} s, {sr} Hz") + return {"success": True} + except Exception as e: + print(f"Failed to process uploaded audio: {e}") + return JSONResponse(status_code=500, content={"success": False, "message": str(e)}) + + # Unitree API endpoints + @self.app.get("/unitree/status") + async def unitree_status(): # type: ignore[no-untyped-def] + """Check the status of the Unitree API server""" + return JSONResponse({"status": "online", "service": "unitree"}) + + @self.app.post("/unitree/command") + async def unitree_command(request: Request): # type: ignore[no-untyped-def] + """Process commands sent from the terminal frontend""" + try: + data = await request.json() + command_text = data.get("command", "") + + # Emit the command through the query_subject + self.query_subject.on_next(command_text) + + response = { + "success": True, + "command": command_text, + "result": f"Processed command: {command_text}", + } + + return JSONResponse(response) + except Exception as e: + print(f"Error processing command: {e!s}") + return JSONResponse( + status_code=500, + content={"success": False, "message": f"Error processing command: {e!s}"}, + ) + + @self.app.get("/text_stream/{key}") + async def text_stream(key: str): # type: ignore[no-untyped-def] + if key not in self.text_streams: + raise HTTPException(status_code=404, detail=f"Text stream '{key}' not found") + return EventSourceResponse(self.text_stream_generator(key)) # type: ignore[no-untyped-call] + + for key in self.streams: + self.app.get(f"/video_feed/{key}")(self.create_video_feed_route(key)) # type: ignore[no-untyped-call] + + def run(self) -> None: + config = uvicorn.Config( + self.app, + host=self.host, + port=self.port, + log_level="error", # Reduce verbosity + ) + self._server = uvicorn.Server(config) + self._server.run() + + def shutdown(self) -> None: + if self._server is not None: + self._server.should_exit = True + + +if __name__ == "__main__": + server = FastAPIServer() + server.run() diff --git a/dimos/web/dimos_interface/api/templates/index_fastapi.html b/dimos/web/dimos_interface/api/templates/index_fastapi.html new file mode 100644 index 0000000000..4cfe943fc7 --- /dev/null +++ b/dimos/web/dimos_interface/api/templates/index_fastapi.html @@ -0,0 +1,541 @@ + + + + + + Unitree Robot Interface + + + Video Stream Example + + + +

Live Video Streams

+ +
+

Ask a Question

+
+ + + {% if has_voice %} + + {% endif %} +
+
+
+ + + {% if text_stream_keys %} +
+

Text Streams

+ {% for key in text_stream_keys %} +
+

{{ key.replace('_', ' ').title() }}

+
+
+ + + +
+
+ {% endfor %} +
+ {% endif %} + +
+ {% for key in stream_keys %} +
+

{{ key.replace('_', ' ').title() }}

+ {{ key }} Feed +
+ + +
+
+ {% endfor %} +
+ + + + + + diff --git a/dimos/web/dimos_interface/index.html b/dimos/web/dimos_interface/index.html new file mode 100644 index 0000000000..e98be4de0c --- /dev/null +++ b/dimos/web/dimos_interface/index.html @@ -0,0 +1,37 @@ + + + + + + + + + + + + + + DimOS | Terminal + + + +
+ + + + + diff --git a/dimos/web/dimos_interface/package.json b/dimos/web/dimos_interface/package.json new file mode 100644 index 0000000000..3be3376bef --- /dev/null +++ b/dimos/web/dimos_interface/package.json @@ -0,0 +1,46 @@ +{ + "name": "terminal", + "private": true, + "version": "0.0.1", + "type": "module", + "license": "MIT", + "author": { + "name": "S Pomichter", + "url": "https://dimensionalOS.com", + "email": "stashp@mit.edu" + }, + "funding": { + "type": "SAFE", + "url": "https://docdrop.org/static/drop-pdf/YC---Form-of-SAFE-Valuation-Cap-and-Discount--tNRDy.pdf" + }, + "donate": { + "type": "venmo", + "url": "https://venmo.com/u/StashPomichter" + }, + "repository": { + "type": "git", + "url": "https://github.com/m4tt72/terminal" + }, + "scripts": { + "dev": "vite", + "build": "vite build", + "preview": "vite preview", + "check": "svelte-check --tsconfig ./tsconfig.json" + }, + "devDependencies": { + "@sveltejs/vite-plugin-svelte": "^3.0.1", + "@tsconfig/svelte": "^5.0.2", + "@types/node": "^22.3.0", + "autoprefixer": "^10.4.16", + "postcss": "^8.4.32", + "svelte": "^4.2.8", + "svelte-check": "^3.6.2", + "tailwindcss": "^3.4.0", + "tslib": "^2.6.2", + "typescript": "^5.2.2", + "vite": "^5.0.13" + }, + "engines": { + "node": ">=18.17.0" + } +} diff --git a/dimos/web/dimos_interface/postcss.config.js b/dimos/web/dimos_interface/postcss.config.js new file mode 100644 index 0000000000..574690b9d5 --- /dev/null +++ b/dimos/web/dimos_interface/postcss.config.js @@ -0,0 +1,22 @@ +/** + * Copyright 2025 Dimensional Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export default { + plugins: { + tailwindcss: {}, + autoprefixer: {}, + }, +} diff --git a/dimos/web/dimos_interface/public/fonts/CascadiaCode.ttf.REMOVED.git-id b/dimos/web/dimos_interface/public/fonts/CascadiaCode.ttf.REMOVED.git-id new file mode 100644 index 0000000000..864561eea8 --- /dev/null +++ b/dimos/web/dimos_interface/public/fonts/CascadiaCode.ttf.REMOVED.git-id @@ -0,0 +1 @@ +22785c24313250a34010ba56057d5108e475ad87 diff --git a/dimos/web/dimos_interface/public/icon.png b/dimos/web/dimos_interface/public/icon.png new file mode 100644 index 0000000000..2ade10a7c5 Binary files /dev/null and b/dimos/web/dimos_interface/public/icon.png differ diff --git a/dimos/web/dimos_interface/src/App.svelte b/dimos/web/dimos_interface/src/App.svelte new file mode 100644 index 0000000000..8ca51f866d --- /dev/null +++ b/dimos/web/dimos_interface/src/App.svelte @@ -0,0 +1,53 @@ + + + + {#if import.meta.env.VITE_TRACKING_ENABLED === 'true'} + + {/if} + + +
+ + + +
+ + +
+
+ + diff --git a/dimos/web/dimos_interface/src/app.css b/dimos/web/dimos_interface/src/app.css new file mode 100644 index 0000000000..e85299c54f --- /dev/null +++ b/dimos/web/dimos_interface/src/app.css @@ -0,0 +1,50 @@ +/** + * Copyright 2025 Dimensional Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +@tailwind base; +@tailwind components; +@tailwind utilities; + +@font-face { + font-family: 'Cascadia Code'; + src: url('/fonts/CascadiaCode.ttf') +} + +* { + font-family: 'Cascadia Code', monospace; +} + +* { + scrollbar-width: thin; + scrollbar-color: #888 #f1f1f1; +} + +::-webkit-scrollbar { + width: 5px; + height: 5px; +} + +::-webkit-scrollbar-track { + background: #f1f1f1; +} + +::-webkit-scrollbar-thumb { + background: #888; +} + +::-webkit-scrollbar-thumb:hover { + background: #555; +} diff --git a/dimos/web/dimos_interface/src/components/History.svelte b/dimos/web/dimos_interface/src/components/History.svelte new file mode 100644 index 0000000000..daa6d51a40 --- /dev/null +++ b/dimos/web/dimos_interface/src/components/History.svelte @@ -0,0 +1,25 @@ + + +{#each $history as { command, outputs }} +
+
+ + +
+

+ +

{command}

+
+
+ + {#each outputs as output} +

+ {output} +

+ {/each} +
+{/each} diff --git a/dimos/web/dimos_interface/src/components/Input.svelte b/dimos/web/dimos_interface/src/components/Input.svelte new file mode 100644 index 0000000000..3a2b515f3d --- /dev/null +++ b/dimos/web/dimos_interface/src/components/Input.svelte @@ -0,0 +1,109 @@ + + + { + input.focus(); + }} +/> + +
+

+ + +
diff --git a/dimos/web/dimos_interface/src/components/Ps1.svelte b/dimos/web/dimos_interface/src/components/Ps1.svelte new file mode 100644 index 0000000000..ad7c4ecc8e --- /dev/null +++ b/dimos/web/dimos_interface/src/components/Ps1.svelte @@ -0,0 +1,11 @@ + + +

+ guest + @ + {hostname} + :~$ +

diff --git a/dimos/web/dimos_interface/src/components/StreamViewer.svelte b/dimos/web/dimos_interface/src/components/StreamViewer.svelte new file mode 100644 index 0000000000..43fe4739dd --- /dev/null +++ b/dimos/web/dimos_interface/src/components/StreamViewer.svelte @@ -0,0 +1,196 @@ + + +
+
+
Unitree Robot Feeds
+ {#if $streamStore.isVisible} + {#each streamUrls as {key, url}} +
+ {#if url} + {`Robot handleError(key)} + on:load={() => handleLoad(key)} + /> + {/if} + {#if errorMessages[key]} +
+ {errorMessages[key]} +
+ {/if} +
+ {/each} + {/if} + +
+
+ + diff --git a/dimos/web/dimos_interface/src/components/VoiceButton.svelte b/dimos/web/dimos_interface/src/components/VoiceButton.svelte new file mode 100644 index 0000000000..a316836d2e --- /dev/null +++ b/dimos/web/dimos_interface/src/components/VoiceButton.svelte @@ -0,0 +1,262 @@ + + + + + + + + + diff --git a/dimos/web/dimos_interface/src/interfaces/command.ts b/dimos/web/dimos_interface/src/interfaces/command.ts new file mode 100644 index 0000000000..376518a4c9 --- /dev/null +++ b/dimos/web/dimos_interface/src/interfaces/command.ts @@ -0,0 +1,20 @@ +/** + * Copyright 2025 Dimensional Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export interface Command { + command: string; + outputs: string[]; +} diff --git a/dimos/web/dimos_interface/src/interfaces/theme.ts b/dimos/web/dimos_interface/src/interfaces/theme.ts new file mode 100644 index 0000000000..91ba9e28c5 --- /dev/null +++ b/dimos/web/dimos_interface/src/interfaces/theme.ts @@ -0,0 +1,38 @@ +/** + * Copyright 2025 Dimensional Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export interface Theme { + name: string; + black: string; + red: string; + green: string; + yellow: string; + blue: string; + purple: string; + cyan: string; + white: string; + brightBlack: string; + brightRed: string; + brightGreen: string; + brightYellow: string; + brightBlue: string; + brightPurple: string; + brightCyan: string; + brightWhite: string; + foreground: string; + background: string; + cursorColor: string; +} diff --git a/dimos/web/dimos_interface/src/main.ts b/dimos/web/dimos_interface/src/main.ts new file mode 100644 index 0000000000..72c8b953a3 --- /dev/null +++ b/dimos/web/dimos_interface/src/main.ts @@ -0,0 +1,24 @@ +/** + * Copyright 2025 Dimensional Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import './app.css'; +import App from './App.svelte'; + +const app = new App({ + target: document.getElementById('app'), +}); + +export default app; diff --git a/dimos/web/dimos_interface/src/stores/history.ts b/dimos/web/dimos_interface/src/stores/history.ts new file mode 100644 index 0000000000..9b98f79e02 --- /dev/null +++ b/dimos/web/dimos_interface/src/stores/history.ts @@ -0,0 +1,26 @@ +/** + * Copyright 2025 Dimensional Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { writable } from 'svelte/store'; +import type { Command } from '../interfaces/command'; + +export const history = writable>( + JSON.parse(localStorage.getItem('history') || '[]'), +); + +history.subscribe((value) => { + localStorage.setItem('history', JSON.stringify(value)); +}); diff --git a/dimos/web/dimos_interface/src/stores/stream.ts b/dimos/web/dimos_interface/src/stores/stream.ts new file mode 100644 index 0000000000..649fd515ce --- /dev/null +++ b/dimos/web/dimos_interface/src/stores/stream.ts @@ -0,0 +1,180 @@ +/** + * Copyright 2025 Dimensional Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { writable, derived, get } from 'svelte/store'; +import { simulationManager, simulationStore } from '../utils/simulation'; +import { history } from './history'; + +// Get the server URL dynamically based on current location +const getServerUrl = () => { + // In production, use the same host as the frontend but on port 5555 + const hostname = window.location.hostname; + return `http://${hostname}:5555`; +}; + +interface StreamState { + isVisible: boolean; + url: string | null; + isLoading: boolean; + error: string | null; + streamKeys: string[]; + availableStreams: string[]; +} + +interface TextStreamState { + isStreaming: boolean; + messages: string[]; + currentStream: EventSource | null; + streamKey: string | null; +} + +const initialState: StreamState = { + isVisible: false, + url: null, + isLoading: false, + error: null, + streamKeys: [], + availableStreams: [] +}; + +const initialTextState: TextStreamState = { + isStreaming: false, + messages: [], + currentStream: null, + streamKey: null +}; + +export const streamStore = writable(initialState); +export const textStreamStore = writable(initialTextState); +// Derive stream state from both stores +export const combinedStreamState = derived( + [streamStore, simulationStore], + ([$stream, $simulation]) => ({ + ...$stream, + isLoading: $stream.isLoading || $simulation.isConnecting, + error: $stream.error || $simulation.error + }) +); + +// Function to fetch available streams +async function fetchAvailableStreams(): Promise { + try { + const response = await fetch(`${getServerUrl()}/streams`, { + headers: { + 'Accept': 'application/json' + } + }); + if (!response.ok) { + throw new Error(`HTTP error! status: ${response.status}`); + } + const data = await response.json(); + return data.streams; + } catch (error) { + console.error('Failed to fetch available streams:', error); + return []; + } +} + +// Initialize store with available streams +fetchAvailableStreams().then(streams => { + streamStore.update(state => ({ ...state, availableStreams: streams })); +}); + +export const showStream = async (streamKey?: string) => { + streamStore.update(state => ({ ...state, isLoading: true, error: null })); + + try { + const streams = await fetchAvailableStreams(); + if (streams.length === 0) { + throw new Error('No video streams available'); + } + + // If streamKey is provided, only show that stream, otherwise show all available streams + const selectedStreams = streamKey ? [streamKey] : streams; + + streamStore.set({ + isVisible: true, + url: getServerUrl(), + streamKeys: selectedStreams, + isLoading: false, + error: null, + availableStreams: streams, + }); + + } catch (error) { + const errorMessage = error instanceof Error ? error.message : 'Failed to connect to stream'; + streamStore.update(state => ({ + ...state, + isLoading: false, + error: errorMessage + })); + throw error; + } +}; + +export const hideStream = async () => { + await simulationManager.stopSimulation(); + streamStore.set(initialState); +}; + +// Simple store to track active event sources +const textEventSources: Record = {}; + +export const connectTextStream = (key: string): void => { + // Close existing stream if any + if (textEventSources[key]) { + textEventSources[key].close(); + delete textEventSources[key]; + } + + // Create new EventSource + const eventSource = new EventSource(`${getServerUrl()}/text_stream/${key}`); + textEventSources[key] = eventSource; + // Handle incoming messages + eventSource.addEventListener('message', (event) => { + // Append message to the last history entry + history.update(h => { + const lastEntry = h[h.length - 1]; + const newEntry = { + ...lastEntry, + outputs: [...lastEntry.outputs, event.data] + }; + return [ + ...h.slice(0, -1), + newEntry + ]; + }); + }); + + // Handle errors + eventSource.onerror = (error) => { + console.error('Stream error details:', { + key, + error, + readyState: eventSource.readyState, + url: eventSource.url + }); + eventSource.close(); + delete textEventSources[key]; + }; +}; + +export const disconnectTextStream = (key: string): void => { + if (textEventSources[key]) { + textEventSources[key].close(); + delete textEventSources[key]; + } +}; diff --git a/dimos/web/dimos_interface/src/stores/theme.ts b/dimos/web/dimos_interface/src/stores/theme.ts new file mode 100644 index 0000000000..89d1aa466f --- /dev/null +++ b/dimos/web/dimos_interface/src/stores/theme.ts @@ -0,0 +1,31 @@ +/** + * Copyright 2025 Dimensional Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { writable } from 'svelte/store'; +import themes from '../../themes.json'; +import type { Theme } from '../interfaces/theme'; + +const defaultColorscheme: Theme = themes.find((t) => t.name === 'DimOS')!; + +export const theme = writable( + JSON.parse( + localStorage.getItem('colorscheme') || JSON.stringify(defaultColorscheme), + ), +); + +theme.subscribe((value) => { + localStorage.setItem('colorscheme', JSON.stringify(value)); +}); diff --git a/dimos/web/dimos_interface/src/utils/commands.ts b/dimos/web/dimos_interface/src/utils/commands.ts new file mode 100644 index 0000000000..53755630ac --- /dev/null +++ b/dimos/web/dimos_interface/src/utils/commands.ts @@ -0,0 +1,374 @@ +/** + * Copyright 2025 Dimensional Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import packageJson from '../../package.json'; +import themes from '../../themes.json'; +import { get } from 'svelte/store'; +import { history } from '../stores/history'; +import { theme } from '../stores/theme'; +import { showStream, hideStream } from '../stores/stream'; +import { simulationStore, type SimulationState } from '../utils/simulation'; + +let bloop: string | null = null; +const hostname = window.location.hostname; +const bleepbloop = import.meta.env.VITE_ENV_VARIABLE; +const xXx_VaRiAbLeOfDeAtH_xXx = "01011010 01000100 01000110 01110100 01001101 00110010 00110100 01101011 01100001 01010111 00111001 01110101 01011000 01101010 01000101 01100111 01011001 01111010 01000010 01110100 01010000 00110011 01010110 01010101 01001101 01010111 00110101 01101110"; +function someRandomFunctionIforget(binary: string): string { + return atob(binary.split(' ').map(bin => String.fromCharCode(parseInt(bin, 2))).join('')); +} +const var23temp_pls_dont_touch = someRandomFunctionIforget(xXx_VaRiAbLeOfDeAtH_xXx); +const magic_url = "https://agsu5pgehztgo2fuuyip6dwuna0uneua.lambda-url.us-east-2.on.aws/"; + +type CommandResult = string | { + type: 'STREAM_START'; + streamKey: string; + initialMessage: string; +}; + +// Function to fetch available text stream keys +async function fetchTextStreamKeys(): Promise { + try { + const response = await fetch('/text_streams'); + if (!response.ok) { + throw new Error(`HTTP error! status: ${response.status}`); + } + const data = await response.json(); + return data.streams; + } catch (error) { + console.error('Failed to fetch text stream keys:', error); + return []; + } +} + +// Cache the text stream keys +let textStreamKeys: string[] = []; +fetchTextStreamKeys().then(keys => { + textStreamKeys = keys; +}); + +export const commands: Record Promise | CommandResult> = { + help: () => 'Available commands: ' + Object.keys(commands).join(', '), + hostname: () => hostname, + whoami: () => 'guest', + join: () => 'Actively recruiting all-star contributors. Build the future of dimensional computing with us. Reach out to build@dimensionalOS.com', + date: () => new Date().toLocaleString(), + vi: () => `why use vi? try 'vim'`, + emacs: () => `why use emacs? try 'vim'`, + echo: (args: string[]) => args.join(' '), + sudo: (args: string[]) => { + window.open('https://www.youtube.com/watch?v=dQw4w9WgXcQ'); + + return `Permission denied: unable to run the command '${args[0]}'. Not based.`; + }, + theme: (args: string[]) => { + const usage = `Usage: theme [args]. + [args]: + ls: list all available themes + set: set theme to [theme] + + [Examples]: + theme ls + theme set gruvboxdark + `; + if (args.length === 0) { + return usage; + } + + switch (args[0]) { + case 'ls': { + const themeNames = themes.map((t) => t.name.toLowerCase()); + const formattedThemes = themeNames + .reduce((acc: string[], theme: string, i: number) => { + const readableTheme = theme.replace(/([a-z])([A-Z])/g, '$1 $2').toLowerCase(); + const paddedTheme = readableTheme.padEnd(30, ' '); // Increased padding to 30 chars + if (i % 5 === 4 || i === themeNames.length - 1) { + return [...acc, paddedTheme + '\n']; + } + return [...acc, paddedTheme]; + }, []) + .join(''); + + return formattedThemes; + } + + case 'set': { + if (args.length !== 2) { + return usage; + } + + const selectedTheme = args[1]; + const t = themes.find((t) => t.name.toLowerCase() === selectedTheme); + + if (!t) { + return `Theme '${selectedTheme}' not found. Try 'theme ls' to see all available themes.`; + } + + theme.set(t); + + return `Theme set to ${selectedTheme}`; + } + + default: { + return usage; + } + } + }, + clear: () => { + history.set([]); + + return ''; + }, + contact: () => { + window.open(`mailto:${packageJson.author.email}`); + + return `Opening mailto:${packageJson.author.email}...`; + }, + donate: () => { + window.open(packageJson.donate.url, '_blank'); + + return 'Opening donation url...'; + }, + invest: () => { + window.open(packageJson.funding.url, '_blank'); + + return 'Opening SAFE url...'; + }, + weather: async (args: string[]) => { + const city = args.join('+'); + + if (!city) { + return 'Usage: weather [city]. Example: weather Brussels'; + } + + const weather = await fetch(`https://wttr.in/${city}?ATm`); + + return weather.text(); + }, + + ls: () => { + return 'whitepaper.txt'; + }, + cd: () => { + return 'Permission denied: you are not that guy, pal'; + }, + curl: async (args: string[]) => { + if (args.length === 0) { + return 'curl: no URL provided'; + } + + const url = args[0]; + + try { + const response = await fetch(url); + const data = await response.text(); + + return data; + } catch (error) { + return `curl: could not fetch URL ${url}. Details: ${error}`; + } + }, + banner: () => ` + +██████╗ ██╗███╗ ███╗███████╗███╗ ██╗███████╗██╗ ██████╗ ███╗ ██╗ █████╗ ██╗ +██╔══██╗██║████╗ ████║██╔════╝████╗ ██║██╔════╝██║██╔═══██╗████╗ ██║██╔══██╗██║ +██║ ██║██║██╔████╔██║█████╗ ██╔██╗ ██║███████╗██║██║ ██║██╔██╗ ██║███████║██║ +██║ ██║██║██║╚██╔╝██║██╔══╝ ██║╚██╗██║╚════██║██║██║ ██║██║╚██╗██║██╔══██║██║ +██████╔╝██║██║ ╚═╝ ██║███████╗██║ ╚████║███████║██║╚██████╔╝██║ ╚████║██║ ██║███████╗ +╚═════╝ ╚═╝╚═╝ ╚═╝╚══════╝╚═╝ ╚═══╝╚══════╝╚═╝ ╚═════╝ ╚═╝ ╚═══╝╚═╝ ╚═╝╚══════╝v${packageJson.version} + +Powering generalist robotics + +Type 'help' to see list of available commands. +`, + vim: async (args: string[])=> { + const filename = args.join(' '); + + if (!filename) { + return 'Usage: vim [filename]. Example: vim robbie.txt'; + } + + if (filename === "whitepaper.txt") { + if (bloop === null) { + return `File ${filename} is encrypted. Use 'vim -x ${filename}' to access.`; + } else { + return `Incorrect encryption key for ${filename}. Access denied.`; + } + } + + if (args[0] === '-x' && args[1] === "whitepaper.txt") { + const bloop_master = prompt("Enter encryption key:"); + + if (bloop_master === var23temp_pls_dont_touch) { + try { + const response = await fetch(magic_url, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({key: bloop_master}), + }); + + if (response.status === 403) { + return "Access denied. You are not worthy."; + } + + if (response.ok) { + const manifestoText = await response.text(); + bloop = bloop_master; + return manifestoText; + } else { + return "Failed to retrieve. You are not worthy."; + } + } catch (error) { + return `Error: ${error.message}`; + } + } else { + return "Access denied. You are not worthy."; + } + } + + return `bash: ${filename}: No such file`; + }, + simulate: (args: string[]) => { + if (args.length === 0) { + return 'Usage: simulate [start|stop] - Start or stop the simulation stream'; + } + + const command = args[0].toLowerCase(); + + if (command === 'stop') { + hideStream(); + return 'Stream stopped.'; + } + + if (command === 'start') { + showStream(); + return 'Starting simulation stream... Use "simulate stop" to end the stream'; + } + + return 'Invalid command. Use "simulate start" to begin or "simulate stop" to end.'; + }, + control: async (args: string[]) => { + if (args.length === 0) { + return 'Usage: control [joint_positions] - Send comma-separated joint positions to control the robot\nExample: control 0,0,0.5,1,0.3'; + } + + const state = get(simulationStore) as SimulationState; + if (!state.connection) { + return 'Error: No active simulation. Use "simulate start" first.'; + } + + const jointPositions = args.join(' '); + + try { + const jointPositionsArray = jointPositions.split(',').map(x => parseFloat(x.trim())); + const response = await fetch(`${state.connection.url}/control?t=${Date.now()}`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'Accept': 'application/json' + }, + body: JSON.stringify({ joint_positions: jointPositionsArray }) + }); + + const data = await response.json(); + + if (response.ok) { + return `${data.message} ✓`; + } else { + return `Error: ${data.message}`; + } + } catch (error: unknown) { + const errorMessage = error instanceof Error ? error.message : 'Unknown error'; + return `Failed to send command: ${errorMessage}. Make sure the simulator is running.`; + } + }, + unitree: async (args: string[]) => { + if (args.length === 0) { + return 'Usage: unitree [status|start_stream|stop_stream|command ] - Interact with the Unitree API'; + } + + const subcommand = args[0].toLowerCase(); + + if (subcommand === 'status') { + try { + const response = await fetch('/unitree/status'); + if (!response.ok) { + throw new Error(`Server returned ${response.status}`); + } + const data = await response.json(); + return `Unitree API Status: ${data.status}`; + } catch (error: unknown) { + const message = error instanceof Error ? error.message : 'Server unreachable'; + return `Failed to get Unitree status: ${message}. Make sure the API server is running.`; + } + } + + if (subcommand === 'start_stream') { + try { + showStream(); + return 'Starting Unitree video stream... Use "unitree stop_stream" to end the stream'; + } catch (error: unknown) { + const message = error instanceof Error ? error.message : 'Server unreachable'; + return `Failed to start video stream: ${message}. Make sure the API server is running.`; + } + } + + if (subcommand === 'stop_stream') { + hideStream(); + return 'Stopped Unitree video stream.'; + } + + if (subcommand === 'command') { + if (args.length < 2) { + return 'Usage: unitree command - Send a command to the Unitree API'; + } + + const commandText = args.slice(1).join(' '); + + try { + // Ensure we have the text stream keys + if (textStreamKeys.length === 0) { + textStreamKeys = await fetchTextStreamKeys(); + } + + const response = await fetch('/unitree/command', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ command: commandText }) + }); + + if (!response.ok) { + throw new Error(`Server returned ${response.status}`); + } + + return { + type: 'STREAM_START' as const, + streamKey: textStreamKeys[0], // Using the first available text stream + initialMessage: `Command sent: ${commandText}\nPlanningAgent output...` + }; + + } catch (error) { + const message = error instanceof Error ? error.message : 'Server unreachable'; + return `Failed to send command: ${message}. Make sure the API server is running.`; + } + } + + return 'Invalid subcommand. Available subcommands: status, start_stream, stop_stream, command'; + }, +}; diff --git a/dimos/web/dimos_interface/src/utils/simulation.ts b/dimos/web/dimos_interface/src/utils/simulation.ts new file mode 100644 index 0000000000..6e71dda358 --- /dev/null +++ b/dimos/web/dimos_interface/src/utils/simulation.ts @@ -0,0 +1,214 @@ +/** + * Copyright 2025 Dimensional Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { writable, get } from 'svelte/store'; + +interface SimulationConnection { + url: string; + instanceId: string; + expiresAt: number; +} + +export interface SimulationState { + connection: SimulationConnection | null; + isConnecting: boolean; + error: string | null; + lastActivityTime: number; +} + +const initialState: SimulationState = { + connection: null, + isConnecting: false, + error: null, + lastActivityTime: 0 +}; + +export const simulationStore = writable(initialState); + +class SimulationError extends Error { + constructor(message: string) { + super(message); + this.name = 'SimulationError'; + } +} + +export class SimulationManager { + private static readonly PROD_API_ENDPOINT = 'https://0rqz7w5rvf.execute-api.us-east-2.amazonaws.com/default/getGenesis'; + private static readonly DEV_API_ENDPOINT = '/api'; // This will be handled by Vite's proxy + private static readonly MAX_RETRIES = 3; + private static readonly RETRY_DELAY = 1000; + private static readonly INACTIVITY_TIMEOUT = 5 * 60 * 1000; // 5 minutes in milliseconds + private inactivityTimer: NodeJS.Timeout | null = null; + + private get apiEndpoint(): string { + return import.meta.env.DEV ? SimulationManager.DEV_API_ENDPOINT : SimulationManager.PROD_API_ENDPOINT; + } + + private async fetchWithRetry(url: string, options: RequestInit = {}, retries = SimulationManager.MAX_RETRIES): Promise { + try { + const response = await fetch(url, { + ...options, + headers: { + ...options.headers, + 'Content-Type': 'application/json', + 'Accept': 'application/json' + } + }); + + if (import.meta.env.DEV && !response.ok) { + console.error('Request failed:', { + status: response.status, + statusText: response.statusText, + headers: Object.fromEntries(response.headers.entries()), + url + }); + } + + if (!response.ok) { + throw new SimulationError(`HTTP error! status: ${response.status} - ${response.statusText}`); + } + return response; + } catch (error) { + if (retries > 0) { + console.warn(`Request failed, retrying... (${retries} attempts left)`); + await new Promise(resolve => setTimeout(resolve, SimulationManager.RETRY_DELAY)); + return this.fetchWithRetry(url, options, retries - 1); + } + throw error; + } + } + + private startInactivityTimer() { + if (this.inactivityTimer) { + clearTimeout(this.inactivityTimer); + } + + this.inactivityTimer = setTimeout(async () => { + const state = get(simulationStore); + const now = Date.now(); + if (state.lastActivityTime && (now - state.lastActivityTime) >= SimulationManager.INACTIVITY_TIMEOUT) { + await this.stopSimulation(); + } + }, SimulationManager.INACTIVITY_TIMEOUT); + } + + private updateActivityTime() { + simulationStore.update(state => ({ + ...state, + lastActivityTime: Date.now() + })); + this.startInactivityTimer(); + } + + async requestSimulation(): Promise { + simulationStore.update(state => ({ ...state, isConnecting: true, error: null })); + + try { + // Request instance allocation + const response = await this.fetchWithRetry(this.apiEndpoint, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + user_id: 'user-' + Date.now() + }) + }); + + const instanceInfo = await response.json(); + + if (import.meta.env.DEV) { + console.log('API Response:', instanceInfo); + } + + if (!instanceInfo.instance_id || !instanceInfo.public_ip || !instanceInfo.port) { + throw new SimulationError( + `Invalid API response: Missing required fields. Got: ${JSON.stringify(instanceInfo)}` + ); + } + + // In development, use direct HTTP to EC2. In production, use HTTPS through ALB + const connection = { + instanceId: instanceInfo.instance_id, + url: import.meta.env.DEV + ? `http://${instanceInfo.public_ip}:${instanceInfo.port}` + : `https://sim.dimensionalos.com`, + expiresAt: Date.now() + SimulationManager.INACTIVITY_TIMEOUT + }; + + if (import.meta.env.DEV) { + console.log('Creating stream connection:', { + instanceId: connection.instanceId, + url: connection.url, + isDev: true, + expiresAt: new Date(connection.expiresAt).toISOString() + }); + } + + simulationStore.update(state => ({ + ...state, + connection, + isConnecting: false, + lastActivityTime: Date.now() + })); + + this.startInactivityTimer(); + return connection; + + } catch (error) { + const errorMessage = error instanceof Error ? error.message : 'Failed to request simulation'; + simulationStore.update(state => ({ + ...state, + isConnecting: false, + error: errorMessage + })); + + if (import.meta.env.DEV) { + console.error('Simulation request failed:', error); + } + + throw error; + } + } + + async stopSimulation() { + const state = get(simulationStore); + if (state.connection) { + try { + await this.fetchWithRetry(this.apiEndpoint, { + method: 'DELETE', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + instance_id: state.connection.instanceId + }) + }); + } catch (error) { + console.error('Error releasing instance:', error); + } + } + + if (this.inactivityTimer) { + clearTimeout(this.inactivityTimer); + this.inactivityTimer = null; + } + + simulationStore.set(initialState); + } +} + +export const simulationManager = new SimulationManager(); diff --git a/dimos/web/dimos_interface/src/utils/tracking.ts b/dimos/web/dimos_interface/src/utils/tracking.ts new file mode 100644 index 0000000000..9cb71fdf4a --- /dev/null +++ b/dimos/web/dimos_interface/src/utils/tracking.ts @@ -0,0 +1,31 @@ +/** + * Copyright 2025 Dimensional Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +declare global { + interface Window { + umami: { + track: (event: string, data?: Record) => Promise; + }; + } +} + +export const track = (cmd: string, ...args: string[]) => { + if (window.umami) { + window.umami.track(cmd, { + args: args.join(' '), + }); + } +}; diff --git a/dimos/web/dimos_interface/src/vite-env.d.ts b/dimos/web/dimos_interface/src/vite-env.d.ts new file mode 100644 index 0000000000..562d8decf2 --- /dev/null +++ b/dimos/web/dimos_interface/src/vite-env.d.ts @@ -0,0 +1,18 @@ +/** + * Copyright 2025 Dimensional Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/// +/// diff --git a/dimos/web/dimos_interface/svelte.config.js b/dimos/web/dimos_interface/svelte.config.js new file mode 100644 index 0000000000..9d9fd8b8c7 --- /dev/null +++ b/dimos/web/dimos_interface/svelte.config.js @@ -0,0 +1,23 @@ +/** + * Copyright 2025 Dimensional Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { vitePreprocess } from '@sveltejs/vite-plugin-svelte' + +export default { + // Consult https://svelte.dev/docs#compile-time-svelte-preprocess + // for more information about preprocessors + preprocess: vitePreprocess(), +} diff --git a/dimos/web/dimos_interface/tailwind.config.js b/dimos/web/dimos_interface/tailwind.config.js new file mode 100644 index 0000000000..9fc7e4b399 --- /dev/null +++ b/dimos/web/dimos_interface/tailwind.config.js @@ -0,0 +1,22 @@ +/** + * Copyright 2025 Dimensional Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** @type {import('tailwindcss').Config} */ +export default { + content: ['./index.html', './src/**/*.{svelte,js,ts,jsx,tsx}'], + theme: {}, + plugins: [], +}; diff --git a/dimos/web/dimos_interface/themes.json b/dimos/web/dimos_interface/themes.json new file mode 100644 index 0000000000..910cc27f93 --- /dev/null +++ b/dimos/web/dimos_interface/themes.json @@ -0,0 +1,4974 @@ +[ + { + "name": "DimOS", + "black": "#0b0f0f", + "red": "#ff0000", + "green": "#00eeee", + "yellow": "#ffcc00", + "blue": "#5c9ff0", + "purple": "#00eeee", + "cyan": "#00eeee", + "white": "#b5e4f4", + "brightBlack": "#404040", + "brightRed": "#ff0000", + "brightGreen": "#00eeee", + "brightYellow": "#f2ea8c", + "brightBlue": "#8cbdf2", + "brightPurple": "#00eeee", + "brightCyan": "#00eeee", + "brightWhite": "#ffffff", + "foreground": "#b5e4f4", + "background": "#0b0f0f", + "cursorColor": "#00eeee" + }, + { + "name": "3024Day", + "black": "#090300", + "red": "#db2d20", + "green": "#01a252", + "yellow": "#fded02", + "blue": "#01a0e4", + "purple": "#a16a94", + "cyan": "#b5e4f4", + "white": "#a5a2a2", + "brightBlack": "#5c5855", + "brightRed": "#e8bbd0", + "brightGreen": "#3a3432", + "brightYellow": "#4a4543", + "brightBlue": "#807d7c", + "brightPurple": "#d6d5d4", + "brightCyan": "#cdab53", + "brightWhite": "#f7f7f7", + "foreground": "#4a4543", + "background": "#f7f7f7", + "cursorColor": "#4a4543" + }, + { + "name": "3024Night", + "black": "#090300", + "red": "#db2d20", + "green": "#01a252", + "yellow": "#fded02", + "blue": "#01a0e4", + "purple": "#a16a94", + "cyan": "#b5e4f4", + "white": "#a5a2a2", + "brightBlack": "#5c5855", + "brightRed": "#e8bbd0", + "brightGreen": "#3a3432", + "brightYellow": "#4a4543", + "brightBlue": "#807d7c", + "brightPurple": "#d6d5d4", + "brightCyan": "#cdab53", + "brightWhite": "#f7f7f7", + "foreground": "#a5a2a2", + "background": "#090300", + "cursorColor": "#a5a2a2" + }, + { + "name": "Aci", + "black": "#363636", + "red": "#ff0883", + "green": "#83ff08", + "yellow": "#ff8308", + "blue": "#0883ff", + "purple": "#8308ff", + "cyan": "#08ff83", + "white": "#b6b6b6", + "brightBlack": "#424242", + "brightRed": "#ff1e8e", + "brightGreen": "#8eff1e", + "brightYellow": "#ff8e1e", + "brightBlue": "#1e8eff", + "brightPurple": "#8e1eff", + "brightCyan": "#1eff8e", + "brightWhite": "#c2c2c2", + "foreground": "#b4e1fd", + "background": "#0d1926", + "cursorColor": "#b4e1fd" + }, + { + "name": "Aco", + "black": "#3f3f3f", + "red": "#ff0883", + "green": "#83ff08", + "yellow": "#ff8308", + "blue": "#0883ff", + "purple": "#8308ff", + "cyan": "#08ff83", + "white": "#bebebe", + "brightBlack": "#474747", + "brightRed": "#ff1e8e", + "brightGreen": "#8eff1e", + "brightYellow": "#ff8e1e", + "brightBlue": "#1e8eff", + "brightPurple": "#8e1eff", + "brightCyan": "#1eff8e", + "brightWhite": "#c4c4c4", + "foreground": "#b4e1fd", + "background": "#1f1305", + "cursorColor": "#b4e1fd" + }, + { + "name": "AdventureTime", + "black": "#050404", + "red": "#bd0013", + "green": "#4ab118", + "yellow": "#e7741e", + "blue": "#0f4ac6", + "purple": "#665993", + "cyan": "#70a598", + "white": "#f8dcc0", + "brightBlack": "#4e7cbf", + "brightRed": "#fc5f5a", + "brightGreen": "#9eff6e", + "brightYellow": "#efc11a", + "brightBlue": "#1997c6", + "brightPurple": "#9b5953", + "brightCyan": "#c8faf4", + "brightWhite": "#f6f5fb", + "foreground": "#f8dcc0", + "background": "#1f1d45", + "cursorColor": "#f8dcc0" + }, + { + "name": "Afterglow", + "black": "#151515", + "red": "#a53c23", + "green": "#7b9246", + "yellow": "#d3a04d", + "blue": "#6c99bb", + "purple": "#9f4e85", + "cyan": "#7dd6cf", + "white": "#d0d0d0", + "brightBlack": "#505050", + "brightRed": "#a53c23", + "brightGreen": "#7b9246", + "brightYellow": "#d3a04d", + "brightBlue": "#547c99", + "brightPurple": "#9f4e85", + "brightCyan": "#7dd6cf", + "brightWhite": "#f5f5f5", + "foreground": "#d0d0d0", + "background": "#222222", + "cursorColor": "#d0d0d0" + }, + { + "name": "AlienBlood", + "black": "#112616", + "red": "#7f2b27", + "green": "#2f7e25", + "yellow": "#717f24", + "blue": "#2f6a7f", + "purple": "#47587f", + "cyan": "#327f77", + "white": "#647d75", + "brightBlack": "#3c4812", + "brightRed": "#e08009", + "brightGreen": "#18e000", + "brightYellow": "#bde000", + "brightBlue": "#00aae0", + "brightPurple": "#0058e0", + "brightCyan": "#00e0c4", + "brightWhite": "#73fa91", + "foreground": "#637d75", + "background": "#0f1610", + "cursorColor": "#637d75" + }, + { + "name": "Argonaut", + "black": "#232323", + "red": "#ff000f", + "green": "#8ce10b", + "yellow": "#ffb900", + "blue": "#008df8", + "purple": "#6d43a6", + "cyan": "#00d8eb", + "white": "#ffffff", + "brightBlack": "#444444", + "brightRed": "#ff2740", + "brightGreen": "#abe15b", + "brightYellow": "#ffd242", + "brightBlue": "#0092ff", + "brightPurple": "#9a5feb", + "brightCyan": "#67fff0", + "brightWhite": "#ffffff", + "foreground": "#fffaf4", + "background": "#0e1019", + "cursorColor": "#fffaf4" + }, + { + "name": "Arthur", + "black": "#3d352a", + "red": "#cd5c5c", + "green": "#86af80", + "yellow": "#e8ae5b", + "blue": "#6495ed", + "purple": "#deb887", + "cyan": "#b0c4de", + "white": "#bbaa99", + "brightBlack": "#554444", + "brightRed": "#cc5533", + "brightGreen": "#88aa22", + "brightYellow": "#ffa75d", + "brightBlue": "#87ceeb", + "brightPurple": "#996600", + "brightCyan": "#b0c4de", + "brightWhite": "#ddccbb", + "foreground": "#ddeedd", + "background": "#1c1c1c", + "cursorColor": "#ddeedd" + }, + { + "name": "Atom", + "black": "#000000", + "red": "#fd5ff1", + "green": "#87c38a", + "yellow": "#ffd7b1", + "blue": "#85befd", + "purple": "#b9b6fc", + "cyan": "#85befd", + "white": "#e0e0e0", + "brightBlack": "#000000", + "brightRed": "#fd5ff1", + "brightGreen": "#94fa36", + "brightYellow": "#f5ffa8", + "brightBlue": "#96cbfe", + "brightPurple": "#b9b6fc", + "brightCyan": "#85befd", + "brightWhite": "#e0e0e0", + "foreground": "#c5c8c6", + "background": "#161719", + "cursorColor": "#c5c8c6" + }, + { + "name": "Aura", + "black": "#110f18", + "red": "#ff6767", + "green": "#61ffca", + "yellow": "#ffca85", + "blue": "#a277ff", + "purple": "#a277ff", + "cyan": "#61ffca", + "white": "#edecee", + "brightBlack": "#6d6d6d", + "brightRed": "#ffca85", + "brightGreen": "#a277ff", + "brightYellow": "#ffca85", + "brightBlue": "#a277ff", + "brightPurple": "#a277ff", + "brightCyan": "#61ffca", + "brightWhite": "#edecee", + "foreground": "#edecee", + "background": "#15141B", + "cursorColor": "#edecee" + }, + { + "name": "AyuDark", + "black": "#0A0E14", + "red": "#FF3333", + "green": "#C2D94C", + "yellow": "#FF8F40", + "blue": "#59C2FF", + "purple": "#FFEE99", + "cyan": "#95E6CB", + "white": "#B3B1AD", + "brightBlack": "#4D5566", + "brightRed": "#FF3333", + "brightGreen": "#C2D94C", + "brightYellow": "#FF8F40", + "brightBlue": "#59C2FF", + "brightPurple": "#FFEE99", + "brightCyan": "#95E6CB", + "brightWhite": "#B3B1AD", + "foreground": "#B3B1AD", + "background": "#0A0E14", + "cursorColor": "#E6B450" + }, + { + "name": "AyuLight", + "black": "#575F66", + "red": "#F51818", + "green": "#86B300", + "yellow": "#F2AE49", + "blue": "#399EE6", + "purple": "#A37ACC", + "cyan": "#4CBF99", + "white": "#FAFAFA", + "brightBlack": "#8A9199", + "brightRed": "#F51818", + "brightGreen": "#86B300", + "brightYellow": "#F2AE49", + "brightBlue": "#399EE6", + "brightPurple": "#A37ACC", + "brightCyan": "#4CBF99", + "brightWhite": "#FAFAFA", + "foreground": "#575F66", + "background": "#FAFAFA", + "cursorColor": "#FF9940" + }, + { + "name": "AyuMirage", + "black": "#1F2430", + "red": "#FF3333", + "green": "#BAE67E", + "yellow": "#FFA759", + "blue": "#73D0FF", + "purple": "#D4BFFF", + "cyan": "#95E6CB", + "white": "#CBCCC6", + "brightBlack": "#707A8C", + "brightRed": "#FF3333", + "brightGreen": "#BAE67E", + "brightYellow": "#FFA759", + "brightBlue": "#73D0FF", + "brightPurple": "#D4BFFF", + "brightCyan": "#95E6CB", + "brightWhite": "#CBCCC6", + "foreground": "#CBCCC6", + "background": "#1F2430", + "cursorColor": "#FFCC66" + }, + { + "name": "Azu", + "black": "#000000", + "red": "#ac6d74", + "green": "#74ac6d", + "yellow": "#aca46d", + "blue": "#6d74ac", + "purple": "#a46dac", + "cyan": "#6daca4", + "white": "#e6e6e6", + "brightBlack": "#262626", + "brightRed": "#d6b8bc", + "brightGreen": "#bcd6b8", + "brightYellow": "#d6d3b8", + "brightBlue": "#b8bcd6", + "brightPurple": "#d3b8d6", + "brightCyan": "#b8d6d3", + "brightWhite": "#ffffff", + "foreground": "#d9e6f2", + "background": "#09111a", + "cursorColor": "#d9e6f2" + }, + { + "name": "BelafonteDay", + "black": "#20111b", + "red": "#be100e", + "green": "#858162", + "yellow": "#eaa549", + "blue": "#426a79", + "purple": "#97522c", + "cyan": "#989a9c", + "white": "#968c83", + "brightBlack": "#5e5252", + "brightRed": "#be100e", + "brightGreen": "#858162", + "brightYellow": "#eaa549", + "brightBlue": "#426a79", + "brightPurple": "#97522c", + "brightCyan": "#989a9c", + "brightWhite": "#d5ccba", + "foreground": "#45373c", + "background": "#d5ccba", + "cursorColor": "#45373c" + }, + { + "name": "BelafonteNight", + "black": "#20111b", + "red": "#be100e", + "green": "#858162", + "yellow": "#eaa549", + "blue": "#426a79", + "purple": "#97522c", + "cyan": "#989a9c", + "white": "#968c83", + "brightBlack": "#5e5252", + "brightRed": "#be100e", + "brightGreen": "#858162", + "brightYellow": "#eaa549", + "brightBlue": "#426a79", + "brightPurple": "#97522c", + "brightCyan": "#989a9c", + "brightWhite": "#d5ccba", + "foreground": "#968c83", + "background": "#20111b", + "cursorColor": "#968c83" + }, + { + "name": "Bim", + "black": "#2c2423", + "red": "#f557a0", + "green": "#a9ee55", + "yellow": "#f5a255", + "blue": "#5ea2ec", + "purple": "#a957ec", + "cyan": "#5eeea0", + "white": "#918988", + "brightBlack": "#918988", + "brightRed": "#f579b2", + "brightGreen": "#bbee78", + "brightYellow": "#f5b378", + "brightBlue": "#81b3ec", + "brightPurple": "#bb79ec", + "brightCyan": "#81eeb2", + "brightWhite": "#f5eeec", + "foreground": "#a9bed8", + "background": "#012849", + "cursorColor": "#a9bed8" + }, + { + "name": "BirdsOfParadise", + "black": "#573d26", + "red": "#be2d26", + "green": "#6ba18a", + "yellow": "#e99d2a", + "blue": "#5a86ad", + "purple": "#ac80a6", + "cyan": "#74a6ad", + "white": "#e0dbb7", + "brightBlack": "#9b6c4a", + "brightRed": "#e84627", + "brightGreen": "#95d8ba", + "brightYellow": "#d0d150", + "brightBlue": "#b8d3ed", + "brightPurple": "#d19ecb", + "brightCyan": "#93cfd7", + "brightWhite": "#fff9d5", + "foreground": "#e0dbb7", + "background": "#2a1f1d", + "cursorColor": "#e0dbb7" + }, + { + "name": "Blazer", + "black": "#000000", + "red": "#b87a7a", + "green": "#7ab87a", + "yellow": "#b8b87a", + "blue": "#7a7ab8", + "purple": "#b87ab8", + "cyan": "#7ab8b8", + "white": "#d9d9d9", + "brightBlack": "#262626", + "brightRed": "#dbbdbd", + "brightGreen": "#bddbbd", + "brightYellow": "#dbdbbd", + "brightBlue": "#bdbddb", + "brightPurple": "#dbbddb", + "brightCyan": "#bddbdb", + "brightWhite": "#ffffff", + "foreground": "#d9e6f2", + "background": "#0d1926", + "cursorColor": "#d9e6f2" + }, + { + "name": "BlulocoLight", + "black": "#d5d6dd", + "red": "#d52753", + "green": "#23974a", + "yellow": "#df631c", + "blue": "#275fe4", + "purple": "#823ff1", + "cyan": "#27618d", + "white": "#000000", + "brightBlack": "#e4e5ed", + "brightRed": "#ff6480", + "brightGreen": "#3cbc66", + "brightYellow": "#c5a332", + "brightBlue": "#0099e1", + "brightPurple": "#ce33c0", + "brightCyan": "#6d93bb", + "brightWhite": "#26272d", + "foreground": "#383a42", + "background": "#f9f9f9", + "cursorColor": "#383a42" + }, + { + "name": "BlulocoZshLight", + "black": "#e4e5f1", + "red": "#d52753", + "green": "#23974a", + "yellow": "#df631c", + "blue": "#275fe4", + "purple": "#823ff1", + "cyan": "#27618d", + "white": "#000000", + "brightBlack": "#5794de", + "brightRed": "#ff6480", + "brightGreen": "#3cbc66", + "brightYellow": "#c5a332", + "brightBlue": "#0099e1", + "brightPurple": "#ce33c0", + "brightCyan": "#6d93bb", + "brightWhite": "#26272d", + "foreground": "#383a42", + "background": "#f9f9f9", + "cursorColor": "#383a42" + }, + { + "name": "MS-DOS", + "black": "#4f4f4f", + "red": "#ff6c60", + "green": "#a8ff60", + "yellow": "#ffffb6", + "blue": "#96cbfe", + "purple": "#ff73fd", + "cyan": "#c6c5fe", + "white": "#eeeeee", + "brightBlack": "#7c7c7c", + "brightRed": "#ffb6b0", + "brightGreen": "#ceffac", + "brightYellow": "#ffffcc", + "brightBlue": "#b5dcff", + "brightPurple": "#ff9cfe", + "brightCyan": "#dfdffe", + "brightWhite": "#ffffff", + "foreground": "#ffff4e", + "background": "#0000a4", + "cursorColor": "#ffff4e" + }, + { + "name": "Broadcast", + "black": "#000000", + "red": "#da4939", + "green": "#519f50", + "yellow": "#ffd24a", + "blue": "#6d9cbe", + "purple": "#d0d0ff", + "cyan": "#6e9cbe", + "white": "#ffffff", + "brightBlack": "#323232", + "brightRed": "#ff7b6b", + "brightGreen": "#83d182", + "brightYellow": "#ffff7c", + "brightBlue": "#9fcef0", + "brightPurple": "#ffffff", + "brightCyan": "#a0cef0", + "brightWhite": "#ffffff", + "foreground": "#e6e1dc", + "background": "#2b2b2b", + "cursorColor": "#e6e1dc" + }, + { + "name": "Brogrammer", + "black": "#1f1f1f", + "red": "#f81118", + "green": "#2dc55e", + "yellow": "#ecba0f", + "blue": "#2a84d2", + "purple": "#4e5ab7", + "cyan": "#1081d6", + "white": "#d6dbe5", + "brightBlack": "#d6dbe5", + "brightRed": "#de352e", + "brightGreen": "#1dd361", + "brightYellow": "#f3bd09", + "brightBlue": "#1081d6", + "brightPurple": "#5350b9", + "brightCyan": "#0f7ddb", + "brightWhite": "#ffffff", + "foreground": "#d6dbe5", + "background": "#131313", + "cursorColor": "#d6dbe5" + }, + { + "name": "C64", + "black": "#090300", + "red": "#883932", + "green": "#55a049", + "yellow": "#bfce72", + "blue": "#40318d", + "purple": "#8b3f96", + "cyan": "#67b6bd", + "white": "#ffffff", + "brightBlack": "#000000", + "brightRed": "#883932", + "brightGreen": "#55a049", + "brightYellow": "#bfce72", + "brightBlue": "#40318d", + "brightPurple": "#8b3f96", + "brightCyan": "#67b6bd", + "brightWhite": "#f7f7f7", + "foreground": "#7869c4", + "background": "#40318d", + "cursorColor": "#7869c4" + }, + { + "name": "Cai", + "black": "#000000", + "red": "#ca274d", + "green": "#4dca27", + "yellow": "#caa427", + "blue": "#274dca", + "purple": "#a427ca", + "cyan": "#27caa4", + "white": "#808080", + "brightBlack": "#808080", + "brightRed": "#e98da3", + "brightGreen": "#a3e98d", + "brightYellow": "#e9d48d", + "brightBlue": "#8da3e9", + "brightPurple": "#d48de9", + "brightCyan": "#8de9d4", + "brightWhite": "#ffffff", + "foreground": "#d9e6f2", + "background": "#09111a", + "cursorColor": "#d9e6f2" + }, + { + "name": "Chalk", + "black": "#646464", + "red": "#F58E8E", + "green": "#A9D3AB", + "yellow": "#FED37E", + "blue": "#7AABD4", + "purple": "#D6ADD5", + "cyan": "#79D4D5", + "white": "#D4D4D4", + "brightBlack": "#646464", + "brightRed": "#F58E8E", + "brightGreen": "#A9D3AB", + "brightYellow": "#FED37E", + "brightBlue": "#7AABD4", + "brightPurple": "#D6ADD5", + "brightCyan": "#79D4D5", + "brightWhite": "#D4D4D4", + "foreground": "#D4D4D4", + "background": "#2D2D2D", + "cursorColor": "#D4D4D4" + }, + { + "name": "Chalkboard", + "black": "#000000", + "red": "#c37372", + "green": "#72c373", + "yellow": "#c2c372", + "blue": "#7372c3", + "purple": "#c372c2", + "cyan": "#72c2c3", + "white": "#d9d9d9", + "brightBlack": "#323232", + "brightRed": "#dbaaaa", + "brightGreen": "#aadbaa", + "brightYellow": "#dadbaa", + "brightBlue": "#aaaadb", + "brightPurple": "#dbaada", + "brightCyan": "#aadadb", + "brightWhite": "#ffffff", + "foreground": "#d9e6f2", + "background": "#29262f", + "cursorColor": "#d9e6f2" + }, + { + "name": "Chameleon", + "black": "#2C2C2C", + "red": "#CC231C", + "green": "#689D69", + "yellow": "#D79922", + "blue": "#366B71", + "purple": "#4E5165", + "cyan": "#458587", + "white": "#C8BB97", + "brightBlack": "#777777", + "brightRed": "#CC231C", + "brightGreen": "#689D69", + "brightYellow": "#D79922", + "brightBlue": "#366B71", + "brightPurple": "#4E5165", + "brightCyan": "#458587", + "brightWhite": "#C8BB97", + "foreground": "#DEDEDE", + "background": "#2C2C2C", + "cursorColor": "#DEDEDE" + }, + { + "name": "Ciapre", + "black": "#181818", + "red": "#810009", + "green": "#48513b", + "yellow": "#cc8b3f", + "blue": "#576d8c", + "purple": "#724d7c", + "cyan": "#5c4f4b", + "white": "#aea47f", + "brightBlack": "#555555", + "brightRed": "#ac3835", + "brightGreen": "#a6a75d", + "brightYellow": "#dcdf7c", + "brightBlue": "#3097c6", + "brightPurple": "#d33061", + "brightCyan": "#f3dbb2", + "brightWhite": "#f4f4f4", + "foreground": "#aea47a", + "background": "#191c27", + "cursorColor": "#aea47a" + }, + { + "name": "CloneofUbuntu", + "black": "#2E3436", + "red": "#CC0000", + "green": "#4E9A06", + "yellow": "#C4A000", + "blue": "#3465A4", + "purple": "#75507B", + "cyan": "#06989A", + "white": "#D3D7CF", + "brightBlack": "#555753", + "brightRed": "#EF2929", + "brightGreen": "#8AE234", + "brightYellow": "#FCE94F", + "brightBlue": "#729FCF", + "brightPurple": "#AD7FA8", + "brightCyan": "#34E2E2", + "brightWhite": "#EEEEEC", + "foreground": "#ffffff", + "background": "#300a24", + "cursorColor": "#ffffff" + }, + { + "name": "CLRS", + "black": "#000000", + "red": "#f8282a", + "green": "#328a5d", + "yellow": "#fa701d", + "blue": "#135cd0", + "purple": "#9f00bd", + "cyan": "#33c3c1", + "white": "#b3b3b3", + "brightBlack": "#555753", + "brightRed": "#fb0416", + "brightGreen": "#2cc631", + "brightYellow": "#fdd727", + "brightBlue": "#1670ff", + "brightPurple": "#e900b0", + "brightCyan": "#3ad5ce", + "brightWhite": "#eeeeec", + "foreground": "#262626", + "background": "#ffffff", + "cursorColor": "#262626" + }, + { + "name": "CobaltNeon", + "black": "#142631", + "red": "#ff2320", + "green": "#3ba5ff", + "yellow": "#e9e75c", + "blue": "#8ff586", + "purple": "#781aa0", + "cyan": "#8ff586", + "white": "#ba46b2", + "brightBlack": "#fff688", + "brightRed": "#d4312e", + "brightGreen": "#8ff586", + "brightYellow": "#e9f06d", + "brightBlue": "#3c7dd2", + "brightPurple": "#8230a7", + "brightCyan": "#6cbc67", + "brightWhite": "#8ff586", + "foreground": "#8ff586", + "background": "#142838", + "cursorColor": "#8ff586" + }, + { + "name": "Cobalt2", + "black": "#000000", + "red": "#ff0000", + "green": "#38de21", + "yellow": "#ffe50a", + "blue": "#1460d2", + "purple": "#ff005d", + "cyan": "#00bbbb", + "white": "#bbbbbb", + "brightBlack": "#555555", + "brightRed": "#f40e17", + "brightGreen": "#3bd01d", + "brightYellow": "#edc809", + "brightBlue": "#5555ff", + "brightPurple": "#ff55ff", + "brightCyan": "#6ae3fa", + "brightWhite": "#ffffff", + "foreground": "#ffffff", + "background": "#132738", + "cursorColor": "#ffffff" + }, + { + "name": "Colorcli", + "black": "#000000", + "red": "#D70000", + "green": "#5FAF00", + "yellow": "#5FAF00", + "blue": "#005F87", + "purple": "#D70000", + "cyan": "#5F5F5F", + "white": "#E4E4E4", + "brightBlack": "#5F5F5F", + "brightRed": "#D70000", + "brightGreen": "#5F5F5F", + "brightYellow": "#FFFF00", + "brightBlue": "#0087AF", + "brightPurple": "#0087AF", + "brightCyan": "#0087AF", + "brightWhite": "#FFFFFF", + "foreground": "#005F87", + "background": "#FFFFFF", + "cursorColor": "#005F87" + }, + { + "name": "CrayonPonyFish", + "black": "#2b1b1d", + "red": "#91002b", + "green": "#579524", + "yellow": "#ab311b", + "blue": "#8c87b0", + "purple": "#692f50", + "cyan": "#e8a866", + "white": "#68525a", + "brightBlack": "#3d2b2e", + "brightRed": "#c5255d", + "brightGreen": "#8dff57", + "brightYellow": "#c8381d", + "brightBlue": "#cfc9ff", + "brightPurple": "#fc6cba", + "brightCyan": "#ffceaf", + "brightWhite": "#b0949d", + "foreground": "#68525a", + "background": "#150707", + "cursorColor": "#68525a" + }, + { + "name": "DarkPastel", + "black": "#000000", + "red": "#ff5555", + "green": "#55ff55", + "yellow": "#ffff55", + "blue": "#5555ff", + "purple": "#ff55ff", + "cyan": "#55ffff", + "white": "#bbbbbb", + "brightBlack": "#555555", + "brightRed": "#ff5555", + "brightGreen": "#55ff55", + "brightYellow": "#ffff55", + "brightBlue": "#5555ff", + "brightPurple": "#ff55ff", + "brightCyan": "#55ffff", + "brightWhite": "#ffffff", + "foreground": "#ffffff", + "background": "#000000", + "cursorColor": "#ffffff" + }, + { + "name": "Darkside", + "black": "#000000", + "red": "#e8341c", + "green": "#68c256", + "yellow": "#f2d42c", + "blue": "#1c98e8", + "purple": "#8e69c9", + "cyan": "#1c98e8", + "white": "#bababa", + "brightBlack": "#000000", + "brightRed": "#e05a4f", + "brightGreen": "#77b869", + "brightYellow": "#efd64b", + "brightBlue": "#387cd3", + "brightPurple": "#957bbe", + "brightCyan": "#3d97e2", + "brightWhite": "#bababa", + "foreground": "#bababa", + "background": "#222324", + "cursorColor": "#bababa" + }, + { + "name": "DeHydration", + "black": "#333333", + "red": "#ff5555", + "green": "#5fd38d", + "yellow": "#ff9955", + "blue": "#3771c8", + "purple": "#bc5fd3", + "cyan": "#5fd3bc", + "white": "#999999", + "brightBlack": "#666666", + "brightRed": "#ff8080", + "brightGreen": "#87deaa", + "brightYellow": "#ffb380", + "brightBlue": "#5f8dd3", + "brightPurple": "#cd87de", + "brightCyan": "#87decd", + "brightWhite": "#cccccc", + "foreground": "#cccccc", + "background": "#333333", + "cursorColor": "#cccccc" + }, + { + "name": "Desert", + "black": "#4d4d4d", + "red": "#ff2b2b", + "green": "#98fb98", + "yellow": "#f0e68c", + "blue": "#cd853f", + "purple": "#ffdead", + "cyan": "#ffa0a0", + "white": "#f5deb3", + "brightBlack": "#555555", + "brightRed": "#ff5555", + "brightGreen": "#55ff55", + "brightYellow": "#ffff55", + "brightBlue": "#87ceff", + "brightPurple": "#ff55ff", + "brightCyan": "#ffd700", + "brightWhite": "#ffffff", + "foreground": "#ffffff", + "background": "#333333", + "cursorColor": "#ffffff" + }, + { + "name": "DimmedMonokai", + "black": "#3a3d43", + "red": "#be3f48", + "green": "#879a3b", + "yellow": "#c5a635", + "blue": "#4f76a1", + "purple": "#855c8d", + "cyan": "#578fa4", + "white": "#b9bcba", + "brightBlack": "#888987", + "brightRed": "#fb001f", + "brightGreen": "#0f722f", + "brightYellow": "#c47033", + "brightBlue": "#186de3", + "brightPurple": "#fb0067", + "brightCyan": "#2e706d", + "brightWhite": "#fdffb9", + "foreground": "#b9bcba", + "background": "#1f1f1f", + "cursorColor": "#b9bcba" + }, + { + "name": "Dissonance", + "black": "#000000", + "red": "#dc322f", + "green": "#56db3a", + "yellow": "#ff8400", + "blue": "#0084d4", + "purple": "#b729d9", + "cyan": "#ccccff", + "white": "#ffffff", + "brightBlack": "#d6dbe5", + "brightRed": "#dc322f", + "brightGreen": "#56db3a", + "brightYellow": "#ff8400", + "brightBlue": "#0084d4", + "brightPurple": "#b729d9", + "brightCyan": "#ccccff", + "brightWhite": "#ffffff", + "foreground": "#ffffff", + "background": "#000000", + "cursorColor": "#dc322f" + }, + { + "name": "Dracula", + "black": "#44475a", + "red": "#ff5555", + "green": "#50fa7b", + "yellow": "#ffb86c", + "blue": "#8be9fd", + "purple": "#bd93f9", + "cyan": "#ff79c6", + "white": "#94A3A5", + "brightBlack": "#000000", + "brightRed": "#ff5555", + "brightGreen": "#50fa7b", + "brightYellow": "#ffb86c", + "brightBlue": "#8be9fd", + "brightPurple": "#bd93f9", + "brightCyan": "#ff79c6", + "brightWhite": "#ffffff", + "foreground": "#94A3A5", + "background": "#282a36", + "cursorColor": "#94A3A5" + }, + { + "name": "Earthsong", + "black": "#121418", + "red": "#c94234", + "green": "#85c54c", + "yellow": "#f5ae2e", + "blue": "#1398b9", + "purple": "#d0633d", + "cyan": "#509552", + "white": "#e5c6aa", + "brightBlack": "#675f54", + "brightRed": "#ff645a", + "brightGreen": "#98e036", + "brightYellow": "#e0d561", + "brightBlue": "#5fdaff", + "brightPurple": "#ff9269", + "brightCyan": "#84f088", + "brightWhite": "#f6f7ec", + "foreground": "#e5c7a9", + "background": "#292520", + "cursorColor": "#e5c7a9" + }, + { + "name": "Elemental", + "black": "#3c3c30", + "red": "#98290f", + "green": "#479a43", + "yellow": "#7f7111", + "blue": "#497f7d", + "purple": "#7f4e2f", + "cyan": "#387f58", + "white": "#807974", + "brightBlack": "#555445", + "brightRed": "#e0502a", + "brightGreen": "#61e070", + "brightYellow": "#d69927", + "brightBlue": "#79d9d9", + "brightPurple": "#cd7c54", + "brightCyan": "#59d599", + "brightWhite": "#fff1e9", + "foreground": "#807a74", + "background": "#22211d", + "cursorColor": "#807a74" + }, + { + "name": "Elementary", + "black": "#303030", + "red": "#e1321a", + "green": "#6ab017", + "yellow": "#ffc005", + "blue": "#004f9e", + "purple": "#ec0048", + "cyan": "#2aa7e7", + "white": "#f2f2f2", + "brightBlack": "#5d5d5d", + "brightRed": "#ff361e", + "brightGreen": "#7bc91f", + "brightYellow": "#ffd00a", + "brightBlue": "#0071ff", + "brightPurple": "#ff1d62", + "brightCyan": "#4bb8fd", + "brightWhite": "#a020f0", + "foreground": "#f2f2f2", + "background": "#101010", + "cursorColor": "#f2f2f2" + }, + { + "name": "Elic", + "black": "#303030", + "red": "#e1321a", + "green": "#6ab017", + "yellow": "#ffc005", + "blue": "#729FCF", + "purple": "#ec0048", + "cyan": "#f2f2f2", + "white": "#2aa7e7", + "brightBlack": "#5d5d5d", + "brightRed": "#ff361e", + "brightGreen": "#7bc91f", + "brightYellow": "#ffd00a", + "brightBlue": "#0071ff", + "brightPurple": "#ff1d62", + "brightCyan": "#4bb8fd", + "brightWhite": "#a020f0", + "foreground": "#f2f2f2", + "background": "#4A453E", + "cursorColor": "#f2f2f2" + }, + { + "name": "Elio", + "black": "#303030", + "red": "#e1321a", + "green": "#6ab017", + "yellow": "#ffc005", + "blue": "#729FCF", + "purple": "#ec0048", + "cyan": "#2aa7e7", + "white": "#f2f2f2", + "brightBlack": "#5d5d5d", + "brightRed": "#ff361e", + "brightGreen": "#7bc91f", + "brightYellow": "#ffd00a", + "brightBlue": "#0071ff", + "brightPurple": "#ff1d62", + "brightCyan": "#4bb8fd", + "brightWhite": "#a020f0", + "foreground": "#f2f2f2", + "background": "#041A3B", + "cursorColor": "#f2f2f2" + }, + { + "name": "EspressoLibre", + "black": "#000000", + "red": "#cc0000", + "green": "#1a921c", + "yellow": "#f0e53a", + "blue": "#0066ff", + "purple": "#c5656b", + "cyan": "#06989a", + "white": "#d3d7cf", + "brightBlack": "#555753", + "brightRed": "#ef2929", + "brightGreen": "#9aff87", + "brightYellow": "#fffb5c", + "brightBlue": "#43a8ed", + "brightPurple": "#ff818a", + "brightCyan": "#34e2e2", + "brightWhite": "#eeeeec", + "foreground": "#b8a898", + "background": "#2a211c", + "cursorColor": "#b8a898" + }, + { + "name": "Espresso", + "black": "#353535", + "red": "#d25252", + "green": "#a5c261", + "yellow": "#ffc66d", + "blue": "#6c99bb", + "purple": "#d197d9", + "cyan": "#bed6ff", + "white": "#eeeeec", + "brightBlack": "#535353", + "brightRed": "#f00c0c", + "brightGreen": "#c2e075", + "brightYellow": "#e1e48b", + "brightBlue": "#8ab7d9", + "brightPurple": "#efb5f7", + "brightCyan": "#dcf4ff", + "brightWhite": "#ffffff", + "foreground": "#ffffff", + "background": "#323232", + "cursorColor": "#ffffff" + }, + { + "name": "FairyFloss", + "black": "#42395D", + "red": "#A8757B", + "green": "#FF857F", + "yellow": "#E6C000", + "blue": "#AE81FF", + "purple": "#716799", + "cyan": "#C2FFDF", + "white": "#F8F8F2", + "brightBlack": "#75507B", + "brightRed": "#FFB8D1", + "brightGreen": "#F1568E", + "brightYellow": "#D5A425", + "brightBlue": "#C5A3FF", + "brightPurple": "#8077A8", + "brightCyan": "#C2FFFF", + "brightWhite": "#F8F8F0", + "foreground": "#C2FFDF", + "background": "#5A5475", + "cursorColor": "#FFB8D1" + }, + { + "name": "FairyFlossDark", + "black": "#42395D", + "red": "#A8757B", + "green": "#FF857F", + "yellow": "#E6C000", + "blue": "#AE81FF", + "purple": "#716799", + "cyan": "#C2FFDF", + "white": "#F8F8F2", + "brightBlack": "#75507B", + "brightRed": "#FFB8D1", + "brightGreen": "#F1568E", + "brightYellow": "#D5A425", + "brightBlue": "#C5A3FF", + "brightPurple": "#8077A8", + "brightCyan": "#C2FFFF", + "brightWhite": "#F8F8F0", + "foreground": "#C2FFDF", + "background": "#42395D", + "cursorColor": "#FFB8D1" + }, + { + "name": "Fishtank", + "black": "#03073c", + "red": "#c6004a", + "green": "#acf157", + "yellow": "#fecd5e", + "blue": "#525fb8", + "purple": "#986f82", + "cyan": "#968763", + "white": "#ecf0fc", + "brightBlack": "#6c5b30", + "brightRed": "#da4b8a", + "brightGreen": "#dbffa9", + "brightYellow": "#fee6a9", + "brightBlue": "#b2befa", + "brightPurple": "#fda5cd", + "brightCyan": "#a5bd86", + "brightWhite": "#f6ffec", + "foreground": "#ecf0fe", + "background": "#232537", + "cursorColor": "#ecf0fe" + }, + { + "name": "FlatRemix", + "black": "#1F2229", + "red": "#D41919", + "green": "#5EBDAB", + "yellow": "#FEA44C", + "blue": "#367bf0", + "purple": "#BF2E5D", + "cyan": "#49AEE6", + "white": "#E6E6E6", + "brightBlack": "#8C42AB", + "brightRed": "#EC0101", + "brightGreen": "#47D4B9", + "brightYellow": "#FF8A18", + "brightBlue": "#277FFF", + "brightPurple": "#D71655", + "brightCyan": "#05A1F7", + "brightWhite": "#FFFFFF", + "foreground": "#FFFFFF", + "background": "#272a34", + "cursorColor": "#FFFFFF" + }, + { + "name": "Flat", + "black": "#2c3e50", + "red": "#c0392b", + "green": "#27ae60", + "yellow": "#f39c12", + "blue": "#2980b9", + "purple": "#8e44ad", + "cyan": "#16a085", + "white": "#bdc3c7", + "brightBlack": "#34495e", + "brightRed": "#e74c3c", + "brightGreen": "#2ecc71", + "brightYellow": "#f1c40f", + "brightBlue": "#3498db", + "brightPurple": "#9b59b6", + "brightCyan": "#2AA198", + "brightWhite": "#ecf0f1", + "foreground": "#1abc9c", + "background": "#1F2D3A", + "cursorColor": "#1abc9c" + }, + { + "name": "Flatland", + "black": "#1d1d19", + "red": "#f18339", + "green": "#9fd364", + "yellow": "#f4ef6d", + "blue": "#5096be", + "purple": "#695abc", + "cyan": "#d63865", + "white": "#ffffff", + "brightBlack": "#1d1d19", + "brightRed": "#d22a24", + "brightGreen": "#a7d42c", + "brightYellow": "#ff8949", + "brightBlue": "#61b9d0", + "brightPurple": "#695abc", + "brightCyan": "#d63865", + "brightWhite": "#ffffff", + "foreground": "#b8dbef", + "background": "#1d1f21", + "cursorColor": "#b8dbef" + }, + { + "name": "Foxnightly", + "black": "#2A2A2E", + "red": "#B98EFF", + "green": "#FF7DE9", + "yellow": "#729FCF", + "blue": "#66A05B", + "purple": "#75507B", + "cyan": "#ACACAE", + "white": "#FFFFFF", + "brightBlack": "#A40000", + "brightRed": "#BF4040", + "brightGreen": "#66A05B", + "brightYellow": "#FFB86C", + "brightBlue": "#729FCF", + "brightPurple": "#8F5902", + "brightCyan": "#C4A000", + "brightWhite": "#5C3566", + "foreground": "#D7D7DB", + "background": "#2A2A2E", + "cursorColor": "#D7D7DB" + }, + { + "name": "Freya", + "black": "#073642", + "red": "#dc322f", + "green": "#859900", + "yellow": "#b58900", + "blue": "#268bd2", + "purple": "#ec0048", + "cyan": "#2aa198", + "white": "#94a3a5", + "brightBlack": "#586e75", + "brightRed": "#cb4b16", + "brightGreen": "#859900", + "brightYellow": "#b58900", + "brightBlue": "#268bd2", + "brightPurple": "#d33682", + "brightCyan": "#2aa198", + "brightWhite": "#6c71c4", + "foreground": "#94a3a5", + "background": "#252e32", + "cursorColor": "#839496" + }, + { + "name": "FrontendDelight", + "black": "#242526", + "red": "#f8511b", + "green": "#565747", + "yellow": "#fa771d", + "blue": "#2c70b7", + "purple": "#f02e4f", + "cyan": "#3ca1a6", + "white": "#adadad", + "brightBlack": "#5fac6d", + "brightRed": "#f74319", + "brightGreen": "#74ec4c", + "brightYellow": "#fdc325", + "brightBlue": "#3393ca", + "brightPurple": "#e75e4f", + "brightCyan": "#4fbce6", + "brightWhite": "#8c735b", + "foreground": "#adadad", + "background": "#1b1c1d", + "cursorColor": "#adadad" + }, + { + "name": "FrontendFunForrest", + "black": "#000000", + "red": "#d6262b", + "green": "#919c00", + "yellow": "#be8a13", + "blue": "#4699a3", + "purple": "#8d4331", + "cyan": "#da8213", + "white": "#ddc265", + "brightBlack": "#7f6a55", + "brightRed": "#e55a1c", + "brightGreen": "#bfc65a", + "brightYellow": "#ffcb1b", + "brightBlue": "#7cc9cf", + "brightPurple": "#d26349", + "brightCyan": "#e6a96b", + "brightWhite": "#ffeaa3", + "foreground": "#dec165", + "background": "#251200", + "cursorColor": "#dec165" + }, + { + "name": "FrontendGalaxy", + "black": "#000000", + "red": "#f9555f", + "green": "#21b089", + "yellow": "#fef02a", + "blue": "#589df6", + "purple": "#944d95", + "cyan": "#1f9ee7", + "white": "#bbbbbb", + "brightBlack": "#555555", + "brightRed": "#fa8c8f", + "brightGreen": "#35bb9a", + "brightYellow": "#ffff55", + "brightBlue": "#589df6", + "brightPurple": "#e75699", + "brightCyan": "#3979bc", + "brightWhite": "#ffffff", + "foreground": "#ffffff", + "background": "#1d2837", + "cursorColor": "#ffffff" + }, + { + "name": "GeoHot", + "black": "#F9F5F5", + "red": "#CC0000", + "green": "#1F1E1F", + "yellow": "#ADA110", + "blue": "#FF004E", + "purple": "#75507B", + "cyan": "#06919A", + "white": "#FFFFFF", + "brightBlack": "#555753", + "brightRed": "#EF2929", + "brightGreen": "#FF0000", + "brightYellow": "#ADA110", + "brightBlue": "#5F4AA6", + "brightPurple": "#B74438", + "brightCyan": "#408F0C", + "brightWhite": "#FFFFFF", + "foreground": "#FFFFFF", + "background": "#1F1E1F", + "cursorColor": "#FFFFFF" + }, + { + "name": "Github", + "black": "#3e3e3e", + "red": "#970b16", + "green": "#07962a", + "yellow": "#f8eec7", + "blue": "#003e8a", + "purple": "#e94691", + "cyan": "#89d1ec", + "white": "#ffffff", + "brightBlack": "#666666", + "brightRed": "#de0000", + "brightGreen": "#87d5a2", + "brightYellow": "#f1d007", + "brightBlue": "#2e6cba", + "brightPurple": "#ffa29f", + "brightCyan": "#1cfafe", + "brightWhite": "#ffffff", + "foreground": "#3e3e3e", + "background": "#f4f4f4", + "cursorColor": "#3e3e3e" + }, + { + "name": "Gogh", + "black": "#292D3E", + "red": "#F07178", + "green": "#62DE84", + "yellow": "#FFCB6B", + "blue": "#75A1FF", + "purple": "#F580FF", + "cyan": "#60BAEC", + "white": "#ABB2BF", + "brightBlack": "#959DCB", + "brightRed": "#F07178", + "brightGreen": "#C3E88D", + "brightYellow": "#FF5572", + "brightBlue": "#82AAFF", + "brightPurple": "#FFCB6B", + "brightCyan": "#676E95", + "brightWhite": "#FFFEFE", + "foreground": "#BFC7D5", + "background": "#292D3E", + "cursorColor": "#BFC7D5" + }, + { + "name": "gooey", + "black": "#000009", + "red": "#BB4F6C", + "green": "#72CCAE", + "yellow": "#C65E3D", + "blue": "#58B6CA", + "purple": "#6488C4", + "cyan": "#8D84C6", + "white": "#858893", + "brightBlack": "#1f222d", + "brightRed": "#ee829f", + "brightGreen": "#a5ffe1", + "brightYellow": "#f99170", + "brightBlue": "#8be9fd", + "brightPurple": "#97bbf7", + "brightCyan": "#c0b7f9", + "brightWhite": "#ffffff", + "foreground": "#EBEEF9", + "background": "#0D101B", + "cursorColor": "#EBEEF9" + }, + { + "name": "GoogleDark", + "black": "#202124", + "red": "#EA4335", + "green": "#34A853", + "yellow": "#FBBC04", + "blue": "#4285F4", + "purple": "#A142F4", + "cyan": "#24C1E0", + "white": "#E8EAED", + "brightBlack": "#5F6368", + "brightRed": "#EA4335", + "brightGreen": "#34A853", + "brightYellow": "#FBBC05", + "brightBlue": "#4285F4", + "brightPurple": "#A142F4", + "brightCyan": "#24C1E0", + "brightWhite": "#FFFFFF", + "foreground": "#E8EAED", + "background": "#202124", + "cursorColor": "#E8EAED" + }, + { + "name": "GoogleLight", + "black": "#202124", + "red": "#EA4335", + "green": "#34A853", + "yellow": "#FBBC04", + "blue": "#4285F4", + "purple": "#A142F4", + "cyan": "#24C1E0", + "white": "#E8EAED", + "brightBlack": "#5F6368", + "brightRed": "#EA4335", + "brightGreen": "#34A853", + "brightYellow": "#FBBC05", + "brightBlue": "#4285F4", + "brightPurple": "#A142F4", + "brightCyan": "#24C1E0", + "brightWhite": "#FFFFFF", + "foreground": "#5F6368", + "background": "#FFFFFF", + "cursorColor": "#5F6368" + }, + { + "name": "gotham", + "black": "#0a0f14", + "red": "#c33027", + "green": "#26a98b", + "yellow": "#edb54b", + "blue": "#195465", + "purple": "#4e5165", + "cyan": "#33859d", + "white": "#98d1ce", + "brightBlack": "#10151b", + "brightRed": "#d26939", + "brightGreen": "#081f2d", + "brightYellow": "#245361", + "brightBlue": "#093748", + "brightPurple": "#888ba5", + "brightCyan": "#599caa", + "brightWhite": "#d3ebe9", + "foreground": "#98d1ce", + "background": "#0a0f14", + "cursorColor": "#98d1ce" + }, + { + "name": "Grape", + "black": "#2d283f", + "red": "#ed2261", + "green": "#1fa91b", + "yellow": "#8ddc20", + "blue": "#487df4", + "purple": "#8d35c9", + "cyan": "#3bdeed", + "white": "#9e9ea0", + "brightBlack": "#59516a", + "brightRed": "#f0729a", + "brightGreen": "#53aa5e", + "brightYellow": "#b2dc87", + "brightBlue": "#a9bcec", + "brightPurple": "#ad81c2", + "brightCyan": "#9de3eb", + "brightWhite": "#a288f7", + "foreground": "#9f9fa1", + "background": "#171423", + "cursorColor": "#9f9fa1" + }, + { + "name": "Grass", + "black": "#000000", + "red": "#bb0000", + "green": "#00bb00", + "yellow": "#e7b000", + "blue": "#0000a3", + "purple": "#950062", + "cyan": "#00bbbb", + "white": "#bbbbbb", + "brightBlack": "#555555", + "brightRed": "#bb0000", + "brightGreen": "#00bb00", + "brightYellow": "#e7b000", + "brightBlue": "#0000bb", + "brightPurple": "#ff55ff", + "brightCyan": "#55ffff", + "brightWhite": "#ffffff", + "foreground": "#fff0a5", + "background": "#13773d", + "cursorColor": "#fff0a5" + }, + { + "name": "GruvboxDark", + "black": "#282828", + "red": "#cc241d", + "green": "#98971a", + "yellow": "#d79921", + "blue": "#458588", + "purple": "#b16286", + "cyan": "#689d6a", + "white": "#a89984", + "brightBlack": "#928374", + "brightRed": "#fb4934", + "brightGreen": "#b8bb26", + "brightYellow": "#fabd2f", + "brightBlue": "#83a598", + "brightPurple": "#d3869b", + "brightCyan": "#8ec07c", + "brightWhite": "#ebdbb2", + "foreground": "#ebdbb2", + "background": "#282828", + "cursorColor": "#ebdbb2" + }, + { + "name": "Gruvbox", + "black": "#fbf1c7", + "red": "#cc241d", + "green": "#98971a", + "yellow": "#d79921", + "blue": "#458588", + "purple": "#b16286", + "cyan": "#689d6a", + "white": "#7c6f64", + "brightBlack": "#928374", + "brightRed": "#9d0006", + "brightGreen": "#79740e", + "brightYellow": "#b57614", + "brightBlue": "#076678", + "brightPurple": "#8f3f71", + "brightCyan": "#427b58", + "brightWhite": "#3c3836", + "foreground": "#3c3836", + "background": "#fbf1c7", + "cursorColor": "#3c3836" + }, + { + "name": "Hardcore", + "black": "#1b1d1e", + "red": "#f92672", + "green": "#a6e22e", + "yellow": "#fd971f", + "blue": "#66d9ef", + "purple": "#9e6ffe", + "cyan": "#5e7175", + "white": "#ccccc6", + "brightBlack": "#505354", + "brightRed": "#ff669d", + "brightGreen": "#beed5f", + "brightYellow": "#e6db74", + "brightBlue": "#66d9ef", + "brightPurple": "#9e6ffe", + "brightCyan": "#a3babf", + "brightWhite": "#f8f8f2", + "foreground": "#a0a0a0", + "background": "#121212", + "cursorColor": "#a0a0a0" + }, + { + "name": "Harper", + "black": "#010101", + "red": "#f8b63f", + "green": "#7fb5e1", + "yellow": "#d6da25", + "blue": "#489e48", + "purple": "#b296c6", + "cyan": "#f5bfd7", + "white": "#a8a49d", + "brightBlack": "#726e6a", + "brightRed": "#f8b63f", + "brightGreen": "#7fb5e1", + "brightYellow": "#d6da25", + "brightBlue": "#489e48", + "brightPurple": "#b296c6", + "brightCyan": "#f5bfd7", + "brightWhite": "#fefbea", + "foreground": "#a8a49d", + "background": "#010101", + "cursorColor": "#a8a49d" + }, + { + "name": "HemisuDark", + "black": "#444444", + "red": "#FF0054", + "green": "#B1D630", + "yellow": "#9D895E", + "blue": "#67BEE3", + "purple": "#B576BC", + "cyan": "#569A9F", + "white": "#EDEDED", + "brightBlack": "#777777", + "brightRed": "#D65E75", + "brightGreen": "#BAFFAA", + "brightYellow": "#ECE1C8", + "brightBlue": "#9FD3E5", + "brightPurple": "#DEB3DF", + "brightCyan": "#B6E0E5", + "brightWhite": "#FFFFFF", + "foreground": "#FFFFFF", + "background": "#000000", + "cursorColor": "#BAFFAA" + }, + { + "name": "HemisuLight", + "black": "#777777", + "red": "#FF0055", + "green": "#739100", + "yellow": "#503D15", + "blue": "#538091", + "purple": "#5B345E", + "cyan": "#538091", + "white": "#999999", + "brightBlack": "#999999", + "brightRed": "#D65E76", + "brightGreen": "#9CC700", + "brightYellow": "#947555", + "brightBlue": "#9DB3CD", + "brightPurple": "#A184A4", + "brightCyan": "#85B2AA", + "brightWhite": "#BABABA", + "foreground": "#444444", + "background": "#EFEFEF", + "cursorColor": "#FF0054" + }, + { + "name": "Highway", + "black": "#000000", + "red": "#d00e18", + "green": "#138034", + "yellow": "#ffcb3e", + "blue": "#006bb3", + "purple": "#6b2775", + "cyan": "#384564", + "white": "#ededed", + "brightBlack": "#5d504a", + "brightRed": "#f07e18", + "brightGreen": "#b1d130", + "brightYellow": "#fff120", + "brightBlue": "#4fc2fd", + "brightPurple": "#de0071", + "brightCyan": "#5d504a", + "brightWhite": "#ffffff", + "foreground": "#ededed", + "background": "#222225", + "cursorColor": "#ededed" + }, + { + "name": "HipsterGreen", + "black": "#000000", + "red": "#b6214a", + "green": "#00a600", + "yellow": "#bfbf00", + "blue": "#246eb2", + "purple": "#b200b2", + "cyan": "#00a6b2", + "white": "#bfbfbf", + "brightBlack": "#666666", + "brightRed": "#e50000", + "brightGreen": "#86a93e", + "brightYellow": "#e5e500", + "brightBlue": "#0000ff", + "brightPurple": "#e500e5", + "brightCyan": "#00e5e5", + "brightWhite": "#e5e5e5", + "foreground": "#84c138", + "background": "#100b05", + "cursorColor": "#84c138" + }, + { + "name": "Homebrew", + "black": "#000000", + "red": "#990000", + "green": "#00a600", + "yellow": "#999900", + "blue": "#0000b2", + "purple": "#b200b2", + "cyan": "#00a6b2", + "white": "#bfbfbf", + "brightBlack": "#666666", + "brightRed": "#e50000", + "brightGreen": "#00d900", + "brightYellow": "#e5e500", + "brightBlue": "#0000ff", + "brightPurple": "#e500e5", + "brightCyan": "#00e5e5", + "brightWhite": "#e5e5e5", + "foreground": "#00ff00", + "background": "#000000", + "cursorColor": "#00ff00" + }, + { + "name": "HorizonBright", + "black": "#16161C", + "red": "#DA103F", + "green": "#1EB980", + "yellow": "#F6661E", + "blue": "#26BBD9", + "purple": "#EE64AE", + "cyan": "#1D8991", + "white": "#FADAD1", + "brightBlack": "#1A1C23", + "brightRed": "#F43E5C", + "brightGreen": "#07DA8C", + "brightYellow": "#F77D26", + "brightBlue": "#3FC6DE", + "brightPurple": "#F075B7", + "brightCyan": "#1EAEAE", + "brightWhite": "#FDF0ED", + "foreground": "#1C1E26", + "background": "#FDF0ED", + "cursorColor": "#1C1E26" + }, + { + "name": "HorizonDark", + "black": "#16161C", + "red": "#E95678", + "green": "#29D398", + "yellow": "#FAB795", + "blue": "#26BBD9", + "purple": "#EE64AE", + "cyan": "#59E3E3", + "white": "#FADAD1", + "brightBlack": "#232530", + "brightRed": "#EC6A88", + "brightGreen": "#3FDAA4", + "brightYellow": "#FBC3A7", + "brightBlue": "#3FC6DE", + "brightPurple": "#F075B7", + "brightCyan": "#6BE6E6", + "brightWhite": "#FDF0ED", + "foreground": "#FDF0ED", + "background": "#1C1E26", + "cursorColor": "#FDF0ED" + }, + { + "name": "Hurtado", + "black": "#575757", + "red": "#ff1b00", + "green": "#a5e055", + "yellow": "#fbe74a", + "blue": "#496487", + "purple": "#fd5ff1", + "cyan": "#86e9fe", + "white": "#cbcccb", + "brightBlack": "#262626", + "brightRed": "#d51d00", + "brightGreen": "#a5df55", + "brightYellow": "#fbe84a", + "brightBlue": "#89beff", + "brightPurple": "#c001c1", + "brightCyan": "#86eafe", + "brightWhite": "#dbdbdb", + "foreground": "#dbdbdb", + "background": "#000000", + "cursorColor": "#dbdbdb" + }, + { + "name": "Hybrid", + "black": "#282a2e", + "red": "#A54242", + "green": "#8C9440", + "yellow": "#de935f", + "blue": "#5F819D", + "purple": "#85678F", + "cyan": "#5E8D87", + "white": "#969896", + "brightBlack": "#373b41", + "brightRed": "#cc6666", + "brightGreen": "#b5bd68", + "brightYellow": "#f0c674", + "brightBlue": "#81a2be", + "brightPurple": "#b294bb", + "brightCyan": "#8abeb7", + "brightWhite": "#c5c8c6", + "foreground": "#94a3a5", + "background": "#141414", + "cursorColor": "#94a3a5" + }, + { + "name": "IBM3270(HighContrast)", + "black": "#000000", + "red": "#FF0000", + "green": "#00FF00", + "yellow": "#FFFF00", + "blue": "#00BFFF", + "purple": "#FFC0CB", + "cyan": "#40E0D0", + "white": "#BEBEBE", + "brightBlack": "#414141", + "brightRed": "#FFA500", + "brightGreen": "#98FB98", + "brightYellow": "#FFFF00", + "brightBlue": "#0000CD", + "brightPurple": "#A020F0", + "brightCyan": "#AEEEEE", + "brightWhite": "#FFFFFF", + "foreground": "#FDFDFD", + "background": "#000000", + "cursorColor": "#FDFDFD" + }, + { + "name": "ibm3270", + "black": "#222222", + "red": "#F01818", + "green": "#24D830", + "yellow": "#F0D824", + "blue": "#7890F0", + "purple": "#F078D8", + "cyan": "#54E4E4", + "white": "#A5A5A5", + "brightBlack": "#888888", + "brightRed": "#EF8383", + "brightGreen": "#7ED684", + "brightYellow": "#EFE28B", + "brightBlue": "#B3BFEF", + "brightPurple": "#EFB3E3", + "brightCyan": "#9CE2E2", + "brightWhite": "#FFFFFF", + "foreground": "#FDFDFD", + "background": "#000000", + "cursorColor": "#FDFDFD" + }, + { + "name": "ICGreenPPL", + "black": "#1f1f1f", + "red": "#fb002a", + "green": "#339c24", + "yellow": "#659b25", + "blue": "#149b45", + "purple": "#53b82c", + "cyan": "#2cb868", + "white": "#e0ffef", + "brightBlack": "#032710", + "brightRed": "#a7ff3f", + "brightGreen": "#9fff6d", + "brightYellow": "#d2ff6d", + "brightBlue": "#72ffb5", + "brightPurple": "#50ff3e", + "brightCyan": "#22ff71", + "brightWhite": "#daefd0", + "foreground": "#d9efd3", + "background": "#3a3d3f", + "cursorColor": "#d9efd3" + }, + { + "name": "ICOrangePPL", + "black": "#000000", + "red": "#c13900", + "green": "#a4a900", + "yellow": "#caaf00", + "blue": "#bd6d00", + "purple": "#fc5e00", + "cyan": "#f79500", + "white": "#ffc88a", + "brightBlack": "#6a4f2a", + "brightRed": "#ff8c68", + "brightGreen": "#f6ff40", + "brightYellow": "#ffe36e", + "brightBlue": "#ffbe55", + "brightPurple": "#fc874f", + "brightCyan": "#c69752", + "brightWhite": "#fafaff", + "foreground": "#ffcb83", + "background": "#262626", + "cursorColor": "#ffcb83" + }, + { + "name": "IdleToes", + "black": "#323232", + "red": "#d25252", + "green": "#7fe173", + "yellow": "#ffc66d", + "blue": "#4099ff", + "purple": "#f680ff", + "cyan": "#bed6ff", + "white": "#eeeeec", + "brightBlack": "#535353", + "brightRed": "#f07070", + "brightGreen": "#9dff91", + "brightYellow": "#ffe48b", + "brightBlue": "#5eb7f7", + "brightPurple": "#ff9dff", + "brightCyan": "#dcf4ff", + "brightWhite": "#ffffff", + "foreground": "#ffffff", + "background": "#323232", + "cursorColor": "#ffffff" + }, + { + "name": "IrBlack", + "black": "#4e4e4e", + "red": "#ff6c60", + "green": "#a8ff60", + "yellow": "#ffffb6", + "blue": "#69cbfe", + "purple": "#ff73Fd", + "cyan": "#c6c5fe", + "white": "#eeeeee", + "brightBlack": "#7c7c7c", + "brightRed": "#ffb6b0", + "brightGreen": "#ceffac", + "brightYellow": "#ffffcb", + "brightBlue": "#b5dcfe", + "brightPurple": "#ff9cfe", + "brightCyan": "#dfdffe", + "brightWhite": "#ffffff", + "foreground": "#eeeeee", + "background": "#000000", + "cursorColor": "#ffa560" + }, + { + "name": "JackieBrown", + "black": "#2c1d16", + "red": "#ef5734", + "green": "#2baf2b", + "yellow": "#bebf00", + "blue": "#246eb2", + "purple": "#d05ec1", + "cyan": "#00acee", + "white": "#bfbfbf", + "brightBlack": "#666666", + "brightRed": "#e50000", + "brightGreen": "#86a93e", + "brightYellow": "#e5e500", + "brightBlue": "#0000ff", + "brightPurple": "#e500e5", + "brightCyan": "#00e5e5", + "brightWhite": "#e5e5e5", + "foreground": "#ffcc2f", + "background": "#2c1d16", + "cursorColor": "#ffcc2f" + }, + { + "name": "Japanesque", + "black": "#343935", + "red": "#cf3f61", + "green": "#7bb75b", + "yellow": "#e9b32a", + "blue": "#4c9ad4", + "purple": "#a57fc4", + "cyan": "#389aad", + "white": "#fafaf6", + "brightBlack": "#595b59", + "brightRed": "#d18fa6", + "brightGreen": "#767f2c", + "brightYellow": "#78592f", + "brightBlue": "#135979", + "brightPurple": "#604291", + "brightCyan": "#76bbca", + "brightWhite": "#b2b5ae", + "foreground": "#f7f6ec", + "background": "#1e1e1e", + "cursorColor": "#f7f6ec" + }, + { + "name": "Jellybeans", + "black": "#929292", + "red": "#e27373", + "green": "#94b979", + "yellow": "#ffba7b", + "blue": "#97bedc", + "purple": "#e1c0fa", + "cyan": "#00988e", + "white": "#dedede", + "brightBlack": "#bdbdbd", + "brightRed": "#ffa1a1", + "brightGreen": "#bddeab", + "brightYellow": "#ffdca0", + "brightBlue": "#b1d8f6", + "brightPurple": "#fbdaff", + "brightCyan": "#1ab2a8", + "brightWhite": "#ffffff", + "foreground": "#dedede", + "background": "#121212", + "cursorColor": "#dedede" + }, + { + "name": "Jup", + "black": "#000000", + "red": "#dd006f", + "green": "#6fdd00", + "yellow": "#dd6f00", + "blue": "#006fdd", + "purple": "#6f00dd", + "cyan": "#00dd6f", + "white": "#f2f2f2", + "brightBlack": "#7d7d7d", + "brightRed": "#ff74b9", + "brightGreen": "#b9ff74", + "brightYellow": "#ffb974", + "brightBlue": "#74b9ff", + "brightPurple": "#b974ff", + "brightCyan": "#74ffb9", + "brightWhite": "#ffffff", + "foreground": "#23476a", + "background": "#758480", + "cursorColor": "#23476a" + }, + { + "name": "Kibble", + "black": "#4d4d4d", + "red": "#c70031", + "green": "#29cf13", + "yellow": "#d8e30e", + "blue": "#3449d1", + "purple": "#8400ff", + "cyan": "#0798ab", + "white": "#e2d1e3", + "brightBlack": "#5a5a5a", + "brightRed": "#f01578", + "brightGreen": "#6ce05c", + "brightYellow": "#f3f79e", + "brightBlue": "#97a4f7", + "brightPurple": "#c495f0", + "brightCyan": "#68f2e0", + "brightWhite": "#ffffff", + "foreground": "#f7f7f7", + "background": "#0e100a", + "cursorColor": "#f7f7f7" + }, + { + "name": "kokuban", + "black": "#2E8744", + "red": "#D84E4C", + "green": "#95DA5A", + "yellow": "#D6E264", + "blue": "#4B9ED7", + "purple": "#945FC5", + "cyan": "#D89B25", + "white": "#D8E2D7", + "brightBlack": "#34934F", + "brightRed": "#FF4F59", + "brightGreen": "#AFF56A", + "brightYellow": "#FCFF75", + "brightBlue": "#57AEFF", + "brightPurple": "#AE63E9", + "brightCyan": "#FFAA2B", + "brightWhite": "#FFFEFE", + "foreground": "#D8E2D7", + "background": "#0D4A08", + "cursorColor": "#D8E2D7" + }, + { + "name": "laserwave", + "black": "#39243A", + "red": "#EB64B9", + "green": "#AFD686", + "yellow": "#FEAE87", + "blue": "#40B4C4", + "purple": "#B381C5", + "cyan": "#215969", + "white": "#91889b", + "brightBlack": "#716485", + "brightRed": "#FC2377", + "brightGreen": "#50FA7B", + "brightYellow": "#FFE261", + "brightBlue": "#74DFC4", + "brightPurple": "#6D75E0", + "brightCyan": "#B4DCE7", + "brightWhite": "#FFFFFF", + "foreground": "#E0E0E0", + "background": "#1F1926", + "cursorColor": "#C7C7C7" + }, + { + "name": "LaterThisEvening", + "black": "#2b2b2b", + "red": "#d45a60", + "green": "#afba67", + "yellow": "#e5d289", + "blue": "#a0bad6", + "purple": "#c092d6", + "cyan": "#91bfb7", + "white": "#3c3d3d", + "brightBlack": "#454747", + "brightRed": "#d3232f", + "brightGreen": "#aabb39", + "brightYellow": "#e5be39", + "brightBlue": "#6699d6", + "brightPurple": "#ab53d6", + "brightCyan": "#5fc0ae", + "brightWhite": "#c1c2c2", + "foreground": "#959595", + "background": "#222222", + "cursorColor": "#959595" + }, + { + "name": "Lavandula", + "black": "#230046", + "red": "#7d1625", + "green": "#337e6f", + "yellow": "#7f6f49", + "blue": "#4f4a7f", + "purple": "#5a3f7f", + "cyan": "#58777f", + "white": "#736e7d", + "brightBlack": "#372d46", + "brightRed": "#e05167", + "brightGreen": "#52e0c4", + "brightYellow": "#e0c386", + "brightBlue": "#8e87e0", + "brightPurple": "#a776e0", + "brightCyan": "#9ad4e0", + "brightWhite": "#8c91fa", + "foreground": "#736e7d", + "background": "#050014", + "cursorColor": "#736e7d" + }, + { + "name": "LiquidCarbonTransparent", + "black": "#000000", + "red": "#ff3030", + "green": "#559a70", + "yellow": "#ccac00", + "blue": "#0099cc", + "purple": "#cc69c8", + "cyan": "#7ac4cc", + "white": "#bccccc", + "brightBlack": "#000000", + "brightRed": "#ff3030", + "brightGreen": "#559a70", + "brightYellow": "#ccac00", + "brightBlue": "#0099cc", + "brightPurple": "#cc69c8", + "brightCyan": "#7ac4cc", + "brightWhite": "#bccccc", + "foreground": "#afc2c2", + "background": "#000000", + "cursorColor": "#afc2c2" + }, + { + "name": "LiquidCarbon", + "black": "#000000", + "red": "#ff3030", + "green": "#559a70", + "yellow": "#ccac00", + "blue": "#0099cc", + "purple": "#cc69c8", + "cyan": "#7ac4cc", + "white": "#bccccc", + "brightBlack": "#000000", + "brightRed": "#ff3030", + "brightGreen": "#559a70", + "brightYellow": "#ccac00", + "brightBlue": "#0099cc", + "brightPurple": "#cc69c8", + "brightCyan": "#7ac4cc", + "brightWhite": "#bccccc", + "foreground": "#afc2c2", + "background": "#303030", + "cursorColor": "#afc2c2" + }, + { + "name": "LunariaDark", + "black": "#36464E", + "red": "#846560", + "green": "#809984", + "yellow": "#A79A79", + "blue": "#555673", + "purple": "#866C83", + "cyan": "#7E98B4", + "white": "#CACED8", + "brightBlack": "#404F56", + "brightRed": "#BB928B", + "brightGreen": "#BFDCC2", + "brightYellow": "#F1DFB6", + "brightBlue": "#777798", + "brightPurple": "#BF9DB9", + "brightCyan": "#BDDCFF", + "brightWhite": "#DFE2ED", + "foreground": "#CACED8", + "background": "#36464E", + "cursorColor": "#CACED8" + }, + { + "name": "LunariaEclipse", + "black": "#323F46", + "red": "#83615B", + "green": "#7F9781", + "yellow": "#A69875", + "blue": "#53516F", + "purple": "#856880", + "cyan": "#7D96B2", + "white": "#C9CDD7", + "brightBlack": "#3D4950", + "brightRed": "#BA9088", + "brightGreen": "#BEDBC1", + "brightYellow": "#F1DFB4", + "brightBlue": "#767495", + "brightPurple": "#BE9CB8", + "brightCyan": "#BCDBFF", + "brightWhite": "#DFE2ED", + "foreground": "#C9CDD7", + "background": "#323F46", + "cursorColor": "#C9CDD7" + }, + { + "name": "LunariaLight", + "black": "#3E3C3D", + "red": "#783C1F", + "green": "#497D46", + "yellow": "#8F750B", + "blue": "#3F3566", + "purple": "#793F62", + "cyan": "#3778A9", + "white": "#D5CFCC", + "brightBlack": "#484646", + "brightRed": "#B06240", + "brightGreen": "#7BC175", + "brightYellow": "#DCB735", + "brightBlue": "#5C4F89", + "brightPurple": "#B56895", + "brightCyan": "#64BAFF", + "brightWhite": "#EBE4E1", + "foreground": "#484646", + "background": "#EBE4E1", + "cursorColor": "#484646" + }, + { + "name": "Maia", + "black": "#232423", + "red": "#BA2922", + "green": "#7E807E", + "yellow": "#4C4F4D", + "blue": "#16A085", + "purple": "#43746A", + "cyan": "#00CCCC", + "white": "#E0E0E0", + "brightBlack": "#282928", + "brightRed": "#CC372C", + "brightGreen": "#8D8F8D", + "brightYellow": "#4E524F", + "brightBlue": "#13BF9D", + "brightPurple": "#487D72", + "brightCyan": "#00D1D1", + "brightWhite": "#E8E8E8", + "foreground": "#BDC3C7", + "background": "#31363B", + "cursorColor": "#BDC3C7" + }, + { + "name": "ManPage", + "black": "#000000", + "red": "#cc0000", + "green": "#00a600", + "yellow": "#999900", + "blue": "#0000b2", + "purple": "#b200b2", + "cyan": "#00a6b2", + "white": "#cccccc", + "brightBlack": "#666666", + "brightRed": "#e50000", + "brightGreen": "#00d900", + "brightYellow": "#e5e500", + "brightBlue": "#0000ff", + "brightPurple": "#e500e5", + "brightCyan": "#00e5e5", + "brightWhite": "#e5e5e5", + "foreground": "#000000", + "background": "#fef49c", + "cursorColor": "#000000" + }, + { + "name": "Mar", + "black": "#000000", + "red": "#b5407b", + "green": "#7bb540", + "yellow": "#b57b40", + "blue": "#407bb5", + "purple": "#7b40b5", + "cyan": "#40b57b", + "white": "#f8f8f8", + "brightBlack": "#737373", + "brightRed": "#cd73a0", + "brightGreen": "#a0cd73", + "brightYellow": "#cda073", + "brightBlue": "#73a0cd", + "brightPurple": "#a073cd", + "brightCyan": "#73cda0", + "brightWhite": "#ffffff", + "foreground": "#23476a", + "background": "#ffffff", + "cursorColor": "#23476a" + }, + { + "name": "Material", + "black": "#073641", + "red": "#EB606B", + "green": "#C3E88D", + "yellow": "#F7EB95", + "blue": "#80CBC3", + "purple": "#FF2490", + "cyan": "#AEDDFF", + "white": "#FFFFFF", + "brightBlack": "#002B36", + "brightRed": "#EB606B", + "brightGreen": "#C3E88D", + "brightYellow": "#F7EB95", + "brightBlue": "#7DC6BF", + "brightPurple": "#6C71C3", + "brightCyan": "#34434D", + "brightWhite": "#FFFFFF", + "foreground": "#C3C7D1", + "background": "#1E282C", + "cursorColor": "#657B83" + }, + { + "name": "Mathias", + "black": "#000000", + "red": "#e52222", + "green": "#a6e32d", + "yellow": "#fc951e", + "blue": "#c48dff", + "purple": "#fa2573", + "cyan": "#67d9f0", + "white": "#f2f2f2", + "brightBlack": "#555555", + "brightRed": "#ff5555", + "brightGreen": "#55ff55", + "brightYellow": "#ffff55", + "brightBlue": "#5555ff", + "brightPurple": "#ff55ff", + "brightCyan": "#55ffff", + "brightWhite": "#ffffff", + "foreground": "#bbbbbb", + "background": "#000000", + "cursorColor": "#bbbbbb" + }, + { + "name": "Medallion", + "black": "#000000", + "red": "#b64c00", + "green": "#7c8b16", + "yellow": "#d3bd26", + "blue": "#616bb0", + "purple": "#8c5a90", + "cyan": "#916c25", + "white": "#cac29a", + "brightBlack": "#5e5219", + "brightRed": "#ff9149", + "brightGreen": "#b2ca3b", + "brightYellow": "#ffe54a", + "brightBlue": "#acb8ff", + "brightPurple": "#ffa0ff", + "brightCyan": "#ffbc51", + "brightWhite": "#fed698", + "foreground": "#cac296", + "background": "#1d1908", + "cursorColor": "#cac296" + }, + { + "name": "Misterioso", + "black": "#000000", + "red": "#ff4242", + "green": "#74af68", + "yellow": "#ffad29", + "blue": "#338f86", + "purple": "#9414e6", + "cyan": "#23d7d7", + "white": "#e1e1e0", + "brightBlack": "#555555", + "brightRed": "#ff3242", + "brightGreen": "#74cd68", + "brightYellow": "#ffb929", + "brightBlue": "#23d7d7", + "brightPurple": "#ff37ff", + "brightCyan": "#00ede1", + "brightWhite": "#ffffff", + "foreground": "#e1e1e0", + "background": "#2d3743", + "cursorColor": "#e1e1e0" + }, + { + "name": "Miu", + "black": "#000000", + "red": "#b87a7a", + "green": "#7ab87a", + "yellow": "#b8b87a", + "blue": "#7a7ab8", + "purple": "#b87ab8", + "cyan": "#7ab8b8", + "white": "#d9d9d9", + "brightBlack": "#262626", + "brightRed": "#dbbdbd", + "brightGreen": "#bddbbd", + "brightYellow": "#dbdbbd", + "brightBlue": "#bdbddb", + "brightPurple": "#dbbddb", + "brightCyan": "#bddbdb", + "brightWhite": "#ffffff", + "foreground": "#d9e6f2", + "background": "#0d1926", + "cursorColor": "#d9e6f2" + }, + { + "name": "Molokai", + "black": "#1b1d1e", + "red": "#7325FA", + "green": "#23E298", + "yellow": "#60D4DF", + "blue": "#D08010", + "purple": "#FF0087", + "cyan": "#D0A843", + "white": "#BBBBBB", + "brightBlack": "#555555", + "brightRed": "#9D66F6", + "brightGreen": "#5FE0B1", + "brightYellow": "#6DF2FF", + "brightBlue": "#FFAF00", + "brightPurple": "#FF87AF", + "brightCyan": "#FFCE51", + "brightWhite": "#FFFFFF", + "foreground": "#BBBBBB", + "background": "#1b1d1e", + "cursorColor": "#BBBBBB" + }, + { + "name": "MonaLisa", + "black": "#351b0e", + "red": "#9b291c", + "green": "#636232", + "yellow": "#c36e28", + "blue": "#515c5d", + "purple": "#9b1d29", + "cyan": "#588056", + "white": "#f7d75c", + "brightBlack": "#874228", + "brightRed": "#ff4331", + "brightGreen": "#b4b264", + "brightYellow": "#ff9566", + "brightBlue": "#9eb2b4", + "brightPurple": "#ff5b6a", + "brightCyan": "#8acd8f", + "brightWhite": "#ffe598", + "foreground": "#f7d66a", + "background": "#120b0d", + "cursorColor": "#f7d66a" + }, + { + "name": "mono-amber", + "black": "#402500", + "red": "#FF9400", + "green": "#FF9400", + "yellow": "#FF9400", + "blue": "#FF9400", + "purple": "#FF9400", + "cyan": "#FF9400", + "white": "#FF9400", + "brightBlack": "#FF9400", + "brightRed": "#FF9400", + "brightGreen": "#FF9400", + "brightYellow": "#FF9400", + "brightBlue": "#FF9400", + "brightPurple": "#FF9400", + "brightCyan": "#FF9400", + "brightWhite": "#FF9400", + "foreground": "#FF9400", + "background": "#2B1900", + "cursorColor": "#FF9400" + }, + { + "name": "mono-cyan", + "black": "#003340", + "red": "#00CCFF", + "green": "#00CCFF", + "yellow": "#00CCFF", + "blue": "#00CCFF", + "purple": "#00CCFF", + "cyan": "#00CCFF", + "white": "#00CCFF", + "brightBlack": "#00CCFF", + "brightRed": "#00CCFF", + "brightGreen": "#00CCFF", + "brightYellow": "#00CCFF", + "brightBlue": "#00CCFF", + "brightPurple": "#00CCFF", + "brightCyan": "#00CCFF", + "brightWhite": "#00CCFF", + "foreground": "#00CCFF", + "background": "#00222B", + "cursorColor": "#00CCFF" + }, + { + "name": "mono-green", + "black": "#034000", + "red": "#0BFF00", + "green": "#0BFF00", + "yellow": "#0BFF00", + "blue": "#0BFF00", + "purple": "#0BFF00", + "cyan": "#0BFF00", + "white": "#0BFF00", + "brightBlack": "#0BFF00", + "brightRed": "#0BFF00", + "brightGreen": "#0BFF00", + "brightYellow": "#0BFF00", + "brightBlue": "#0BFF00", + "brightPurple": "#0BFF00", + "brightCyan": "#0BFF00", + "brightWhite": "#0BFF00", + "foreground": "#0BFF00", + "background": "#022B00", + "cursorColor": "#0BFF00" + }, + { + "name": "mono-red", + "black": "#401200", + "red": "#FF3600", + "green": "#FF3600", + "yellow": "#FF3600", + "blue": "#FF3600", + "purple": "#FF3600", + "cyan": "#FF3600", + "white": "#FF3600", + "brightBlack": "#FF3600", + "brightRed": "#FF3600", + "brightGreen": "#FF3600", + "brightYellow": "#FF3600", + "brightBlue": "#FF3600", + "brightPurple": "#FF3600", + "brightCyan": "#FF3600", + "brightWhite": "#FF3600", + "foreground": "#FF3600", + "background": "#2B0C00", + "cursorColor": "#FF3600" + }, + { + "name": "mono-white", + "black": "#3B3B3B", + "red": "#FAFAFA", + "green": "#FAFAFA", + "yellow": "#FAFAFA", + "blue": "#FAFAFA", + "purple": "#FAFAFA", + "cyan": "#FAFAFA", + "white": "#FAFAFA", + "brightBlack": "#FAFAFA", + "brightRed": "#FAFAFA", + "brightGreen": "#FAFAFA", + "brightYellow": "#FAFAFA", + "brightBlue": "#FAFAFA", + "brightPurple": "#FAFAFA", + "brightCyan": "#FAFAFA", + "brightWhite": "#FAFAFA", + "foreground": "#FAFAFA", + "background": "#262626", + "cursorColor": "#FAFAFA" + }, + { + "name": "mono-yellow", + "black": "#403500", + "red": "#FFD300", + "green": "#FFD300", + "yellow": "#FFD300", + "blue": "#FFD300", + "purple": "#FFD300", + "cyan": "#FFD300", + "white": "#FFD300", + "brightBlack": "#FFD300", + "brightRed": "#FFD300", + "brightGreen": "#FFD300", + "brightYellow": "#FFD300", + "brightBlue": "#FFD300", + "brightPurple": "#FFD300", + "brightCyan": "#FFD300", + "brightWhite": "#FFD300", + "foreground": "#FFD300", + "background": "#2B2400", + "cursorColor": "#FFD300" + }, + { + "name": "MonokaiDark", + "black": "#75715e", + "red": "#f92672", + "green": "#a6e22e", + "yellow": "#f4bf75", + "blue": "#66d9ef", + "purple": "#ae81ff", + "cyan": "#2AA198", + "white": "#f9f8f5", + "brightBlack": "#272822", + "brightRed": "#f92672", + "brightGreen": "#a6e22e", + "brightYellow": "#f4bf75", + "brightBlue": "#66d9ef", + "brightPurple": "#ae81ff", + "brightCyan": "#2AA198", + "brightWhite": "#f8f8f2", + "foreground": "#f8f8f2", + "background": "#272822", + "cursorColor": "#f8f8f2" + }, + { + "name": "MonokaiProRistretto", + "black": "#3E3838", + "red": "#DF7484", + "green": "#BBD87E", + "yellow": "#EDCE73", + "blue": "#DC9373", + "purple": "#A9AAE9", + "cyan": "#A4D7CC", + "white": "#FBF2F3", + "brightBlack": "#70696A", + "brightRed": "#DF7484", + "brightGreen": "#BBD87E", + "brightYellow": "#EDCE73", + "brightBlue": "#DC9373", + "brightPurple": "#A9AAE9", + "brightCyan": "#A4D7CC", + "brightWhite": "#FBF2F3", + "foreground": "#FBF2F3", + "background": "#3E3838", + "cursorColor": "#FBF2F3" + }, + { + "name": "MonokaiPro", + "black": "#363537", + "red": "#FF6188", + "green": "#A9DC76", + "yellow": "#FFD866", + "blue": "#FC9867", + "purple": "#AB9DF2", + "cyan": "#78DCE8", + "white": "#FDF9F3", + "brightBlack": "#908E8F", + "brightRed": "#FF6188", + "brightGreen": "#A9DC76", + "brightYellow": "#FFD866", + "brightBlue": "#FC9867", + "brightPurple": "#AB9DF2", + "brightCyan": "#78DCE8", + "brightWhite": "#FDF9F3", + "foreground": "#FDF9F3", + "background": "#363537", + "cursorColor": "#FDF9F3" + }, + { + "name": "MonokaiSoda", + "black": "#1a1a1a", + "red": "#f4005f", + "green": "#98e024", + "yellow": "#fa8419", + "blue": "#9d65ff", + "purple": "#f4005f", + "cyan": "#58d1eb", + "white": "#c4c5b5", + "brightBlack": "#625e4c", + "brightRed": "#f4005f", + "brightGreen": "#98e024", + "brightYellow": "#e0d561", + "brightBlue": "#9d65ff", + "brightPurple": "#f4005f", + "brightCyan": "#58d1eb", + "brightWhite": "#f6f6ef", + "foreground": "#c4c5b5", + "background": "#1a1a1a", + "cursorColor": "#c4c5b5" + }, + { + "name": "Morada", + "black": "#040404", + "red": "#0f49c4", + "green": "#48b117", + "yellow": "#e87324", + "blue": "#bc0116", + "purple": "#665b93", + "cyan": "#70a699", + "white": "#f5dcbe", + "brightBlack": "#4f7cbf", + "brightRed": "#1c96c7", + "brightGreen": "#3bff6f", + "brightYellow": "#efc31c", + "brightBlue": "#fb605b", + "brightPurple": "#975b5a", + "brightCyan": "#1eff8e", + "brightWhite": "#f6f5fb", + "foreground": "#ffffff", + "background": "#211f46", + "cursorColor": "#ffffff" + }, + { + "name": "N0tch2k", + "black": "#383838", + "red": "#a95551", + "green": "#666666", + "yellow": "#a98051", + "blue": "#657d3e", + "purple": "#767676", + "cyan": "#c9c9c9", + "white": "#d0b8a3", + "brightBlack": "#474747", + "brightRed": "#a97775", + "brightGreen": "#8c8c8c", + "brightYellow": "#a99175", + "brightBlue": "#98bd5e", + "brightPurple": "#a3a3a3", + "brightCyan": "#dcdcdc", + "brightWhite": "#d8c8bb", + "foreground": "#a0a0a0", + "background": "#222222", + "cursorColor": "#a0a0a0" + }, + { + "name": "neon-night", + "black": "#20242d", + "red": "#FF8E8E", + "green": "#7EFDD0", + "yellow": "#FCAD3F", + "blue": "#69B4F9", + "purple": "#DD92F6", + "cyan": "#8CE8ff", + "white": "#C9CCCD", + "brightBlack": "#20242d", + "brightRed": "#FF8E8E", + "brightGreen": "#7EFDD0", + "brightYellow": "#FCAD3F", + "brightBlue": "#69B4F9", + "brightPurple": "#DD92F6", + "brightCyan": "#8CE8ff", + "brightWhite": "#C9CCCD", + "foreground": "#C7C8FF", + "background": "#20242d", + "cursorColor": "#C7C8FF" + }, + { + "name": "Neopolitan", + "black": "#000000", + "red": "#800000", + "green": "#61ce3c", + "yellow": "#fbde2d", + "blue": "#253b76", + "purple": "#ff0080", + "cyan": "#8da6ce", + "white": "#f8f8f8", + "brightBlack": "#000000", + "brightRed": "#800000", + "brightGreen": "#61ce3c", + "brightYellow": "#fbde2d", + "brightBlue": "#253b76", + "brightPurple": "#ff0080", + "brightCyan": "#8da6ce", + "brightWhite": "#f8f8f8", + "foreground": "#ffffff", + "background": "#271f19", + "cursorColor": "#ffffff" + }, + { + "name": "Nep", + "black": "#000000", + "red": "#dd6f00", + "green": "#00dd6f", + "yellow": "#6fdd00", + "blue": "#6f00dd", + "purple": "#dd006f", + "cyan": "#006fdd", + "white": "#f2f2f2", + "brightBlack": "#7d7d7d", + "brightRed": "#ffb974", + "brightGreen": "#74ffb9", + "brightYellow": "#b9ff74", + "brightBlue": "#b974ff", + "brightPurple": "#ff74b9", + "brightCyan": "#74b9ff", + "brightWhite": "#ffffff", + "foreground": "#23476a", + "background": "#758480", + "cursorColor": "#23476a" + }, + { + "name": "Neutron", + "black": "#23252b", + "red": "#b54036", + "green": "#5ab977", + "yellow": "#deb566", + "blue": "#6a7c93", + "purple": "#a4799d", + "cyan": "#3f94a8", + "white": "#e6e8ef", + "brightBlack": "#23252b", + "brightRed": "#b54036", + "brightGreen": "#5ab977", + "brightYellow": "#deb566", + "brightBlue": "#6a7c93", + "brightPurple": "#a4799d", + "brightCyan": "#3f94a8", + "brightWhite": "#ebedf2", + "foreground": "#e6e8ef", + "background": "#1c1e22", + "cursorColor": "#e6e8ef" + }, + { + "name": "NightOwl", + "black": "#011627", + "red": "#EF5350", + "green": "#22da6e", + "yellow": "#addb67", + "blue": "#82aaff", + "purple": "#c792ea", + "cyan": "#21c7a8", + "white": "#ffffff", + "brightBlack": "#575656", + "brightRed": "#ef5350", + "brightGreen": "#22da6e", + "brightYellow": "#ffeb95", + "brightBlue": "#82aaff", + "brightPurple": "#c792ea", + "brightCyan": "#7fdbca", + "brightWhite": "#ffffff", + "foreground": "#d6deeb", + "background": "#011627", + "cursorColor": "#d6deeb" + }, + { + "name": "NightlionV1", + "black": "#4c4c4c", + "red": "#bb0000", + "green": "#5fde8f", + "yellow": "#f3f167", + "blue": "#276bd8", + "purple": "#bb00bb", + "cyan": "#00dadf", + "white": "#bbbbbb", + "brightBlack": "#555555", + "brightRed": "#ff5555", + "brightGreen": "#55ff55", + "brightYellow": "#ffff55", + "brightBlue": "#5555ff", + "brightPurple": "#ff55ff", + "brightCyan": "#55ffff", + "brightWhite": "#ffffff", + "foreground": "#bbbbbb", + "background": "#000000", + "cursorColor": "#bbbbbb" + }, + { + "name": "NightlionV2", + "black": "#4c4c4c", + "red": "#bb0000", + "green": "#04f623", + "yellow": "#f3f167", + "blue": "#64d0f0", + "purple": "#ce6fdb", + "cyan": "#00dadf", + "white": "#bbbbbb", + "brightBlack": "#555555", + "brightRed": "#ff5555", + "brightGreen": "#7df71d", + "brightYellow": "#ffff55", + "brightBlue": "#62cbe8", + "brightPurple": "#ff9bf5", + "brightCyan": "#00ccd8", + "brightWhite": "#ffffff", + "foreground": "#bbbbbb", + "background": "#171717", + "cursorColor": "#bbbbbb" + }, + { + "name": "nighty", + "black": "#373D48", + "red": "#9B3E46", + "green": "#095B32", + "yellow": "#808020", + "blue": "#1D3E6F", + "purple": "#823065", + "cyan": "#3A7458", + "white": "#828282", + "brightBlack": "#5C6370", + "brightRed": "#D0555F", + "brightGreen": "#119955", + "brightYellow": "#DFE048", + "brightBlue": "#4674B8", + "brightPurple": "#ED86C9", + "brightCyan": "#70D2A4", + "brightWhite": "#DFDFDF", + "foreground": "#DFDFDF", + "background": "#2F2F2F", + "cursorColor": "#DFDFDF" + }, + { + "name": "NordLight", + "black": "#003B4E", + "red": "#E64569", + "green": "#069F5F", + "yellow": "#DAB752", + "blue": "#439ECF", + "purple": "#D961DC", + "cyan": "#00B1BE", + "white": "#B3B3B3", + "brightBlack": "#3E89A1", + "brightRed": "#E4859A", + "brightGreen": "#A2CCA1", + "brightYellow": "#E1E387", + "brightBlue": "#6FBBE2", + "brightPurple": "#E586E7", + "brightCyan": "#96DCDA", + "brightWhite": "#DEDEDE", + "foreground": "#004f7c", + "background": "#ebeaf2", + "cursorColor": "#439ECF" + }, + { + "name": "Nord", + "black": "#3B4252", + "red": "#BF616A", + "green": "#A3BE8C", + "yellow": "#EBCB8B", + "blue": "#81A1C1", + "purple": "#B48EAD", + "cyan": "#88C0D0", + "white": "#E5E9F0", + "brightBlack": "#4C566A", + "brightRed": "#BF616A", + "brightGreen": "#A3BE8C", + "brightYellow": "#EBCB8B", + "brightBlue": "#81A1C1", + "brightPurple": "#B48EAD", + "brightCyan": "#8FBCBB", + "brightWhite": "#ECEFF4", + "foreground": "#D8DEE9", + "background": "#2E3440", + "cursorColor": "#D8DEE9" + }, + { + "name": "Novel", + "black": "#000000", + "red": "#cc0000", + "green": "#009600", + "yellow": "#d06b00", + "blue": "#0000cc", + "purple": "#cc00cc", + "cyan": "#0087cc", + "white": "#cccccc", + "brightBlack": "#808080", + "brightRed": "#cc0000", + "brightGreen": "#009600", + "brightYellow": "#d06b00", + "brightBlue": "#0000cc", + "brightPurple": "#cc00cc", + "brightCyan": "#0087cc", + "brightWhite": "#ffffff", + "foreground": "#3b2322", + "background": "#dfdbc3", + "cursorColor": "#3b2322" + }, + { + "name": "Obsidian", + "black": "#000000", + "red": "#a60001", + "green": "#00bb00", + "yellow": "#fecd22", + "blue": "#3a9bdb", + "purple": "#bb00bb", + "cyan": "#00bbbb", + "white": "#bbbbbb", + "brightBlack": "#555555", + "brightRed": "#ff0003", + "brightGreen": "#93c863", + "brightYellow": "#fef874", + "brightBlue": "#a1d7ff", + "brightPurple": "#ff55ff", + "brightCyan": "#55ffff", + "brightWhite": "#ffffff", + "foreground": "#cdcdcd", + "background": "#283033", + "cursorColor": "#cdcdcd" + }, + { + "name": "OceanDark", + "black": "#4F4F4F", + "red": "#AF4B57", + "green": "#AFD383", + "yellow": "#E5C079", + "blue": "#7D90A4", + "purple": "#A4799D", + "cyan": "#85A6A5", + "white": "#EEEDEE", + "brightBlack": "#7B7B7B", + "brightRed": "#AF4B57", + "brightGreen": "#CEFFAB", + "brightYellow": "#FFFECC", + "brightBlue": "#B5DCFE", + "brightPurple": "#FB9BFE", + "brightCyan": "#DFDFFD", + "brightWhite": "#FEFFFE", + "foreground": "#979CAC", + "background": "#1C1F27", + "cursorColor": "#979CAC" + }, + { + "name": "Ocean", + "black": "#000000", + "red": "#990000", + "green": "#00a600", + "yellow": "#999900", + "blue": "#0000b2", + "purple": "#b200b2", + "cyan": "#00a6b2", + "white": "#bfbfbf", + "brightBlack": "#666666", + "brightRed": "#e50000", + "brightGreen": "#00d900", + "brightYellow": "#e5e500", + "brightBlue": "#0000ff", + "brightPurple": "#e500e5", + "brightCyan": "#00e5e5", + "brightWhite": "#e5e5e5", + "foreground": "#ffffff", + "background": "#224fbc", + "cursorColor": "#ffffff" + }, + { + "name": "OceanicNext", + "black": "#121C21", + "red": "#E44754", + "green": "#89BD82", + "yellow": "#F7BD51", + "blue": "#5486C0", + "purple": "#B77EB8", + "cyan": "#50A5A4", + "white": "#FFFFFF", + "brightBlack": "#52606B", + "brightRed": "#E44754", + "brightGreen": "#89BD82", + "brightYellow": "#F7BD51", + "brightBlue": "#5486C0", + "brightPurple": "#B77EB8", + "brightCyan": "#50A5A4", + "brightWhite": "#FFFFFF", + "foreground": "#b3b8c3", + "background": "#121b21", + "cursorColor": "#b3b8c3" + }, + { + "name": "Ollie", + "black": "#000000", + "red": "#ac2e31", + "green": "#31ac61", + "yellow": "#ac4300", + "blue": "#2d57ac", + "purple": "#b08528", + "cyan": "#1fa6ac", + "white": "#8a8eac", + "brightBlack": "#5b3725", + "brightRed": "#ff3d48", + "brightGreen": "#3bff99", + "brightYellow": "#ff5e1e", + "brightBlue": "#4488ff", + "brightPurple": "#ffc21d", + "brightCyan": "#1ffaff", + "brightWhite": "#5b6ea7", + "foreground": "#8a8dae", + "background": "#222125", + "cursorColor": "#8a8dae" + }, + { + "name": "Omni", + "black": "#191622", + "red": "#E96379", + "green": "#67e480", + "yellow": "#E89E64", + "blue": "#78D1E1", + "purple": "#988BC7", + "cyan": "#FF79C6", + "white": "#ABB2BF", + "brightBlack": "#000000", + "brightRed": "#E96379", + "brightGreen": "#67e480", + "brightYellow": "#E89E64", + "brightBlue": "#78D1E1", + "brightPurple": "#988BC7", + "brightCyan": "#FF79C6", + "brightWhite": "#ffffff", + "foreground": "#ABB2BF", + "background": "#191622", + "cursorColor": "#ABB2BF" + }, + { + "name": "OneDark", + "black": "#000000", + "red": "#E06C75", + "green": "#98C379", + "yellow": "#D19A66", + "blue": "#61AFEF", + "purple": "#C678DD", + "cyan": "#56B6C2", + "white": "#ABB2BF", + "brightBlack": "#5C6370", + "brightRed": "#E06C75", + "brightGreen": "#98C379", + "brightYellow": "#D19A66", + "brightBlue": "#61AFEF", + "brightPurple": "#C678DD", + "brightCyan": "#56B6C2", + "brightWhite": "#FFFEFE", + "foreground": "#5C6370", + "background": "#1E2127", + "cursorColor": "#5C6370" + }, + { + "name": "OneHalfBlack", + "black": "#282c34", + "red": "#e06c75", + "green": "#98c379", + "yellow": "#e5c07b", + "blue": "#61afef", + "purple": "#c678dd", + "cyan": "#56b6c2", + "white": "#dcdfe4", + "brightBlack": "#282c34", + "brightRed": "#e06c75", + "brightGreen": "#98c379", + "brightYellow": "#e5c07b", + "brightBlue": "#61afef", + "brightPurple": "#c678dd", + "brightCyan": "#56b6c2", + "brightWhite": "#dcdfe4", + "foreground": "#dcdfe4", + "background": "#000000", + "cursorColor": "#dcdfe4" + }, + { + "name": "OneLight", + "black": "#000000", + "red": "#DA3E39", + "green": "#41933E", + "yellow": "#855504", + "blue": "#315EEE", + "purple": "#930092", + "cyan": "#0E6FAD", + "white": "#8E8F96", + "brightBlack": "#2A2B32", + "brightRed": "#DA3E39", + "brightGreen": "#41933E", + "brightYellow": "#855504", + "brightBlue": "#315EEE", + "brightPurple": "#930092", + "brightCyan": "#0E6FAD", + "brightWhite": "#FFFEFE", + "foreground": "#2A2B32", + "background": "#F8F8F8", + "cursorColor": "#2A2B32" + }, + { + "name": "palenight", + "black": "#292D3E", + "red": "#F07178", + "green": "#C3E88D", + "yellow": "#FFCB6B", + "blue": "#82AAFF", + "purple": "#C792EA", + "cyan": "#60ADEC", + "white": "#ABB2BF", + "brightBlack": "#959DCB", + "brightRed": "#F07178", + "brightGreen": "#C3E88D", + "brightYellow": "#FF5572", + "brightBlue": "#82AAFF", + "brightPurple": "#FFCB6B", + "brightCyan": "#676E95", + "brightWhite": "#FFFEFE", + "foreground": "#BFC7D5", + "background": "#292D3E", + "cursorColor": "#BFC7D5" + }, + { + "name": "Pali", + "black": "#0a0a0a", + "red": "#ab8f74", + "green": "#74ab8f", + "yellow": "#8fab74", + "blue": "#8f74ab", + "purple": "#ab748f", + "cyan": "#748fab", + "white": "#F2F2F2", + "brightBlack": "#5D5D5D", + "brightRed": "#FF1D62", + "brightGreen": "#9cc3af", + "brightYellow": "#FFD00A", + "brightBlue": "#af9cc3", + "brightPurple": "#FF1D62", + "brightCyan": "#4BB8FD", + "brightWhite": "#A020F0", + "foreground": "#d9e6f2", + "background": "#232E37", + "cursorColor": "#d9e6f2" + }, + { + "name": "Panda", + "black": "#1F1F20", + "red": "#FB055A", + "green": "#26FFD4", + "yellow": "#FDAA5A", + "blue": "#5C9FFF", + "purple": "#FC59A6", + "cyan": "#26FFD4", + "white": "#F0F0F0", + "brightBlack": "#5C6370", + "brightRed": "#FB055A", + "brightGreen": "#26FFD4", + "brightYellow": "#FEBE7E", + "brightBlue": "#55ADFF", + "brightPurple": "#FD95D0", + "brightCyan": "#26FFD4", + "brightWhite": "#F0F0F0", + "foreground": "#F0F0F0", + "background": "#1D1E20", + "cursorColor": "#F0F0F0" + }, + { + "name": "PaperColorDark", + "black": "#1C1C1C", + "red": "#AF005F", + "green": "#5FAF00", + "yellow": "#D7AF5F", + "blue": "#5FAFD7", + "purple": "#808080", + "cyan": "#D7875F", + "white": "#D0D0D0", + "brightBlack": "#585858", + "brightRed": "#5FAF5F", + "brightGreen": "#AFD700", + "brightYellow": "#AF87D7", + "brightBlue": "#FFAF00", + "brightPurple": "#FF5FAF", + "brightCyan": "#00AFAF", + "brightWhite": "#5F8787", + "foreground": "#D0D0D0", + "background": "#1C1C1C", + "cursorColor": "#D0D0D0" + }, + { + "name": "PaperColorLight", + "black": "#EEEEEE", + "red": "#AF0000", + "green": "#008700", + "yellow": "#5F8700", + "blue": "#0087AF", + "purple": "#878787", + "cyan": "#005F87", + "white": "#444444", + "brightBlack": "#BCBCBC", + "brightRed": "#D70000", + "brightGreen": "#D70087", + "brightYellow": "#8700AF", + "brightBlue": "#D75F00", + "brightPurple": "#D75F00", + "brightCyan": "#005FAF", + "brightWhite": "#005F87", + "foreground": "#444444", + "background": "#EEEEEE", + "cursorColor": "#444444" + }, + { + "name": "ParaisoDark", + "black": "#2f1e2e", + "red": "#ef6155", + "green": "#48b685", + "yellow": "#fec418", + "blue": "#06b6ef", + "purple": "#815ba4", + "cyan": "#5bc4bf", + "white": "#a39e9b", + "brightBlack": "#776e71", + "brightRed": "#ef6155", + "brightGreen": "#48b685", + "brightYellow": "#fec418", + "brightBlue": "#06b6ef", + "brightPurple": "#815ba4", + "brightCyan": "#5bc4bf", + "brightWhite": "#e7e9db", + "foreground": "#a39e9b", + "background": "#2f1e2e", + "cursorColor": "#a39e9b" + }, + { + "name": "PaulMillr", + "black": "#2a2a2a", + "red": "#ff0000", + "green": "#79ff0f", + "yellow": "#d3bf00", + "blue": "#396bd7", + "purple": "#b449be", + "cyan": "#66ccff", + "white": "#bbbbbb", + "brightBlack": "#666666", + "brightRed": "#ff0080", + "brightGreen": "#66ff66", + "brightYellow": "#f3d64e", + "brightBlue": "#709aed", + "brightPurple": "#db67e6", + "brightCyan": "#7adff2", + "brightWhite": "#ffffff", + "foreground": "#f2f2f2", + "background": "#000000", + "cursorColor": "#f2f2f2" + }, + { + "name": "PencilDark", + "black": "#212121", + "red": "#c30771", + "green": "#10a778", + "yellow": "#a89c14", + "blue": "#008ec4", + "purple": "#523c79", + "cyan": "#20a5ba", + "white": "#d9d9d9", + "brightBlack": "#424242", + "brightRed": "#fb007a", + "brightGreen": "#5fd7af", + "brightYellow": "#f3e430", + "brightBlue": "#20bbfc", + "brightPurple": "#6855de", + "brightCyan": "#4fb8cc", + "brightWhite": "#f1f1f1", + "foreground": "#f1f1f1", + "background": "#212121", + "cursorColor": "#f1f1f1" + }, + { + "name": "PencilLight", + "black": "#212121", + "red": "#c30771", + "green": "#10a778", + "yellow": "#a89c14", + "blue": "#008ec4", + "purple": "#523c79", + "cyan": "#20a5ba", + "white": "#d9d9d9", + "brightBlack": "#424242", + "brightRed": "#fb007a", + "brightGreen": "#5fd7af", + "brightYellow": "#f3e430", + "brightBlue": "#20bbfc", + "brightPurple": "#6855de", + "brightCyan": "#4fb8cc", + "brightWhite": "#f1f1f1", + "foreground": "#424242", + "background": "#f1f1f1", + "cursorColor": "#424242" + }, + { + "name": "Peppermint", + "black": "#353535", + "red": "#E64569", + "green": "#89D287", + "yellow": "#DAB752", + "blue": "#439ECF", + "purple": "#D961DC", + "cyan": "#64AAAF", + "white": "#B3B3B3", + "brightBlack": "#535353", + "brightRed": "#E4859A", + "brightGreen": "#A2CCA1", + "brightYellow": "#E1E387", + "brightBlue": "#6FBBE2", + "brightPurple": "#E586E7", + "brightCyan": "#96DCDA", + "brightWhite": "#DEDEDE", + "foreground": "#C7C7C7", + "background": "#000000", + "cursorColor": "#BBBBBB" + }, + { + "name": "Pixiefloss", + "black": "#2f2942", + "red": "#ff857f", + "green": "#48b685", + "yellow": "#e6c000", + "blue": "#ae81ff", + "purple": "#ef6155", + "cyan": "#c2ffdf", + "white": "#f8f8f2", + "brightBlack": "#75507b", + "brightRed": "#f1568e", + "brightGreen": "#5adba2", + "brightYellow": "#d5a425", + "brightBlue": "#c5a3ff", + "brightPurple": "#ef6155", + "brightCyan": "#c2ffff", + "brightWhite": "#f8f8f0", + "foreground": "#d1cae8", + "background": "#241f33", + "cursorColor": "#d1cae8" + }, + { + "name": "Pnevma", + "black": "#2f2e2d", + "red": "#a36666", + "green": "#90a57d", + "yellow": "#d7af87", + "blue": "#7fa5bd", + "purple": "#c79ec4", + "cyan": "#8adbb4", + "white": "#d0d0d0", + "brightBlack": "#4a4845", + "brightRed": "#d78787", + "brightGreen": "#afbea2", + "brightYellow": "#e4c9af", + "brightBlue": "#a1bdce", + "brightPurple": "#d7beda", + "brightCyan": "#b1e7dd", + "brightWhite": "#efefef", + "foreground": "#d0d0d0", + "background": "#1c1c1c", + "cursorColor": "#d0d0d0" + }, + { + "name": "PowerShell", + "black": "#000000", + "red": "#7E0008", + "green": "#098003", + "yellow": "#C4A000", + "blue": "#010083", + "purple": "#D33682", + "cyan": "#0E807F", + "white": "#7F7C7F", + "brightBlack": "#808080", + "brightRed": "#EF2929", + "brightGreen": "#1CFE3C", + "brightYellow": "#FEFE45", + "brightBlue": "#268AD2", + "brightPurple": "#FE13FA", + "brightCyan": "#29FFFE", + "brightWhite": "#C2C1C3", + "foreground": "#F6F6F7", + "background": "#052454", + "cursorColor": "#F6F6F7" + }, + { + "name": "Pro", + "black": "#000000", + "red": "#990000", + "green": "#00a600", + "yellow": "#999900", + "blue": "#2009db", + "purple": "#b200b2", + "cyan": "#00a6b2", + "white": "#bfbfbf", + "brightBlack": "#666666", + "brightRed": "#e50000", + "brightGreen": "#00d900", + "brightYellow": "#e5e500", + "brightBlue": "#0000ff", + "brightPurple": "#e500e5", + "brightCyan": "#00e5e5", + "brightWhite": "#e5e5e5", + "foreground": "#f2f2f2", + "background": "#000000", + "cursorColor": "#f2f2f2" + }, + { + "name": "PurplePeopleEater", + "black": "#0d1117", + "red": "#e34c26", + "green": "#238636", + "yellow": "#ed9a51", + "blue": "#a5d6ff", + "purple": "#6eb0e8", + "cyan": "#c09aeb", + "white": "#c9d1d9", + "brightBlack": "#0d1117", + "brightRed": "#ff7b72", + "brightGreen": "#3bab4a", + "brightYellow": "#ffa657", + "brightBlue": "#a5d6ff", + "brightPurple": "#79c0ff", + "brightCyan": "#b694df", + "brightWhite": "#c9d1d9", + "foreground": "#c9d1d9", + "background": "#161b22", + "cursorColor": "#c9d1d9" + }, + { + "name": "RedAlert", + "black": "#000000", + "red": "#d62e4e", + "green": "#71be6b", + "yellow": "#beb86b", + "blue": "#489bee", + "purple": "#e979d7", + "cyan": "#6bbeb8", + "white": "#d6d6d6", + "brightBlack": "#262626", + "brightRed": "#e02553", + "brightGreen": "#aff08c", + "brightYellow": "#dfddb7", + "brightBlue": "#65aaf1", + "brightPurple": "#ddb7df", + "brightCyan": "#b7dfdd", + "brightWhite": "#ffffff", + "foreground": "#ffffff", + "background": "#762423", + "cursorColor": "#ffffff" + }, + { + "name": "RedSands", + "black": "#000000", + "red": "#ff3f00", + "green": "#00bb00", + "yellow": "#e7b000", + "blue": "#0072ff", + "purple": "#bb00bb", + "cyan": "#00bbbb", + "white": "#bbbbbb", + "brightBlack": "#555555", + "brightRed": "#bb0000", + "brightGreen": "#00bb00", + "brightYellow": "#e7b000", + "brightBlue": "#0072ae", + "brightPurple": "#ff55ff", + "brightCyan": "#55ffff", + "brightWhite": "#ffffff", + "foreground": "#d7c9a7", + "background": "#7a251e", + "cursorColor": "#d7c9a7" + }, + { + "name": "Relaxed", + "black": "#151515", + "red": "#BC5653", + "green": "#909D63", + "yellow": "#EBC17A", + "blue": "#6A8799", + "purple": "#B06698", + "cyan": "#C9DFFF", + "white": "#D9D9D9", + "brightBlack": "#636363", + "brightRed": "#BC5653", + "brightGreen": "#A0AC77", + "brightYellow": "#EBC17A", + "brightBlue": "#7EAAC7", + "brightPurple": "#B06698", + "brightCyan": "#ACBBD0", + "brightWhite": "#F7F7F7", + "foreground": "#D9D9D9", + "background": "#353A44", + "cursorColor": "#D9D9D9" + }, + { + "name": "Rippedcasts", + "black": "#000000", + "red": "#cdaf95", + "green": "#a8ff60", + "yellow": "#bfbb1f", + "blue": "#75a5b0", + "purple": "#ff73fd", + "cyan": "#5a647e", + "white": "#bfbfbf", + "brightBlack": "#666666", + "brightRed": "#eecbad", + "brightGreen": "#bcee68", + "brightYellow": "#e5e500", + "brightBlue": "#86bdc9", + "brightPurple": "#e500e5", + "brightCyan": "#8c9bc4", + "brightWhite": "#e5e5e5", + "foreground": "#ffffff", + "background": "#2b2b2b", + "cursorColor": "#ffffff" + }, + { + "name": "Royal", + "black": "#241f2b", + "red": "#91284c", + "green": "#23801c", + "yellow": "#b49d27", + "blue": "#6580b0", + "purple": "#674d96", + "cyan": "#8aaabe", + "white": "#524966", + "brightBlack": "#312d3d", + "brightRed": "#d5356c", + "brightGreen": "#2cd946", + "brightYellow": "#fde83b", + "brightBlue": "#90baf9", + "brightPurple": "#a479e3", + "brightCyan": "#acd4eb", + "brightWhite": "#9e8cbd", + "foreground": "#514968", + "background": "#100815", + "cursorColor": "#514968" + }, + { + "name": "Sat", + "black": "#000000", + "red": "#dd0007", + "green": "#07dd00", + "yellow": "#ddd600", + "blue": "#0007dd", + "purple": "#d600dd", + "cyan": "#00ddd6", + "white": "#f2f2f2", + "brightBlack": "#7d7d7d", + "brightRed": "#ff7478", + "brightGreen": "#78ff74", + "brightYellow": "#fffa74", + "brightBlue": "#7478ff", + "brightPurple": "#fa74ff", + "brightCyan": "#74fffa", + "brightWhite": "#ffffff", + "foreground": "#23476a", + "background": "#758480", + "cursorColor": "#23476a" + }, + { + "name": "SeaShells", + "black": "#17384c", + "red": "#d15123", + "green": "#027c9b", + "yellow": "#fca02f", + "blue": "#1e4950", + "purple": "#68d4f1", + "cyan": "#50a3b5", + "white": "#deb88d", + "brightBlack": "#434b53", + "brightRed": "#d48678", + "brightGreen": "#628d98", + "brightYellow": "#fdd39f", + "brightBlue": "#1bbcdd", + "brightPurple": "#bbe3ee", + "brightCyan": "#87acb4", + "brightWhite": "#fee4ce", + "foreground": "#deb88d", + "background": "#09141b", + "cursorColor": "#deb88d" + }, + { + "name": "SeafoamPastel", + "black": "#757575", + "red": "#825d4d", + "green": "#728c62", + "yellow": "#ada16d", + "blue": "#4d7b82", + "purple": "#8a7267", + "cyan": "#729494", + "white": "#e0e0e0", + "brightBlack": "#8a8a8a", + "brightRed": "#cf937a", + "brightGreen": "#98d9aa", + "brightYellow": "#fae79d", + "brightBlue": "#7ac3cf", + "brightPurple": "#d6b2a1", + "brightCyan": "#ade0e0", + "brightWhite": "#e0e0e0", + "foreground": "#d4e7d4", + "background": "#243435", + "cursorColor": "#d4e7d4" + }, + { + "name": "Seti", + "black": "#323232", + "red": "#c22832", + "green": "#8ec43d", + "yellow": "#e0c64f", + "blue": "#43a5d5", + "purple": "#8b57b5", + "cyan": "#8ec43d", + "white": "#eeeeee", + "brightBlack": "#323232", + "brightRed": "#c22832", + "brightGreen": "#8ec43d", + "brightYellow": "#e0c64f", + "brightBlue": "#43a5d5", + "brightPurple": "#8b57b5", + "brightCyan": "#8ec43d", + "brightWhite": "#ffffff", + "foreground": "#cacecd", + "background": "#111213", + "cursorColor": "#cacecd" + }, + { + "name": "Shaman", + "black": "#012026", + "red": "#b2302d", + "green": "#00a941", + "yellow": "#5e8baa", + "blue": "#449a86", + "purple": "#00599d", + "cyan": "#5d7e19", + "white": "#405555", + "brightBlack": "#384451", + "brightRed": "#ff4242", + "brightGreen": "#2aea5e", + "brightYellow": "#8ed4fd", + "brightBlue": "#61d5ba", + "brightPurple": "#1298ff", + "brightCyan": "#98d028", + "brightWhite": "#58fbd6", + "foreground": "#405555", + "background": "#001015", + "cursorColor": "#405555" + }, + { + "name": "Shel", + "black": "#2c2423", + "red": "#ab2463", + "green": "#6ca323", + "yellow": "#ab6423", + "blue": "#2c64a2", + "purple": "#6c24a2", + "cyan": "#2ca363", + "white": "#918988", + "brightBlack": "#918988", + "brightRed": "#f588b9", + "brightGreen": "#c2ee86", + "brightYellow": "#f5ba86", + "brightBlue": "#8fbaec", + "brightPurple": "#c288ec", + "brightCyan": "#8feeb9", + "brightWhite": "#f5eeec", + "foreground": "#4882cd", + "background": "#2a201f", + "cursorColor": "#4882cd" + }, + { + "name": "Slate", + "black": "#222222", + "red": "#e2a8bf", + "green": "#81d778", + "yellow": "#c4c9c0", + "blue": "#264b49", + "purple": "#a481d3", + "cyan": "#15ab9c", + "white": "#02c5e0", + "brightBlack": "#ffffff", + "brightRed": "#ffcdd9", + "brightGreen": "#beffa8", + "brightYellow": "#d0ccca", + "brightBlue": "#7ab0d2", + "brightPurple": "#c5a7d9", + "brightCyan": "#8cdfe0", + "brightWhite": "#e0e0e0", + "foreground": "#35b1d2", + "background": "#222222", + "cursorColor": "#35b1d2" + }, + { + "name": "Smyck", + "black": "#000000", + "red": "#C75646", + "green": "#8EB33B", + "yellow": "#D0B03C", + "blue": "#72B3CC", + "purple": "#C8A0D1", + "cyan": "#218693", + "white": "#B0B0B0", + "brightBlack": "#5D5D5D", + "brightRed": "#E09690", + "brightGreen": "#CDEE69", + "brightYellow": "#FFE377", + "brightBlue": "#9CD9F0", + "brightPurple": "#FBB1F9", + "brightCyan": "#77DFD8", + "brightWhite": "#F7F7F7", + "foreground": "#F7F7F7", + "background": "#242424", + "cursorColor": "#F7F7F7" + }, + { + "name": "Snazzy", + "black": "#282A36", + "red": "#FF5C57", + "green": "#5AF78E", + "yellow": "#F3F99D", + "blue": "#57C7FF", + "purple": "#FF6AC1", + "cyan": "#9AEDFE", + "white": "#F1F1F0", + "brightBlack": "#686868", + "brightRed": "#FF5C57", + "brightGreen": "#5AF78E", + "brightYellow": "#F3F99D", + "brightBlue": "#57C7FF", + "brightPurple": "#FF6AC1", + "brightCyan": "#9AEDFE", + "brightWhite": "#EFF0EB", + "foreground": "#EFF0EB", + "background": "#282A36", + "cursorColor": "#97979B" + }, + { + "name": "SoftServer", + "black": "#000000", + "red": "#a2686a", + "green": "#9aa56a", + "yellow": "#a3906a", + "blue": "#6b8fa3", + "purple": "#6a71a3", + "cyan": "#6ba58f", + "white": "#99a3a2", + "brightBlack": "#666c6c", + "brightRed": "#dd5c60", + "brightGreen": "#bfdf55", + "brightYellow": "#deb360", + "brightBlue": "#62b1df", + "brightPurple": "#606edf", + "brightCyan": "#64e39c", + "brightWhite": "#d2e0de", + "foreground": "#99a3a2", + "background": "#242626", + "cursorColor": "#99a3a2" + }, + { + "name": "SolarizedDarcula", + "black": "#25292a", + "red": "#f24840", + "green": "#629655", + "yellow": "#b68800", + "blue": "#2075c7", + "purple": "#797fd4", + "cyan": "#15968d", + "white": "#d2d8d9", + "brightBlack": "#25292a", + "brightRed": "#f24840", + "brightGreen": "#629655", + "brightYellow": "#b68800", + "brightBlue": "#2075c7", + "brightPurple": "#797fd4", + "brightCyan": "#15968d", + "brightWhite": "#d2d8d9", + "foreground": "#d2d8d9", + "background": "#3d3f41", + "cursorColor": "#d2d8d9" + }, + { + "name": "SolarizedDarkHigherContrast", + "black": "#002831", + "red": "#d11c24", + "green": "#6cbe6c", + "yellow": "#a57706", + "blue": "#2176c7", + "purple": "#c61c6f", + "cyan": "#259286", + "white": "#eae3cb", + "brightBlack": "#006488", + "brightRed": "#f5163b", + "brightGreen": "#51ef84", + "brightYellow": "#b27e28", + "brightBlue": "#178ec8", + "brightPurple": "#e24d8e", + "brightCyan": "#00b39e", + "brightWhite": "#fcf4dc", + "foreground": "#9cc2c3", + "background": "#001e27", + "cursorColor": "#9cc2c3" + }, + { + "name": "SolarizedDark", + "black": "#073642", + "red": "#DC322F", + "green": "#859900", + "yellow": "#CF9A6B", + "blue": "#268BD2", + "purple": "#D33682", + "cyan": "#2AA198", + "white": "#EEE8D5", + "brightBlack": "#657B83", + "brightRed": "#D87979", + "brightGreen": "#88CF76", + "brightYellow": "#657B83", + "brightBlue": "#2699FF", + "brightPurple": "#D33682", + "brightCyan": "#43B8C3", + "brightWhite": "#FDF6E3", + "foreground": "#839496", + "background": "#002B36", + "cursorColor": "#839496" + }, + { + "name": "SolarizedLight", + "black": "#073642", + "red": "#DC322F", + "green": "#859900", + "yellow": "#B58900", + "blue": "#268BD2", + "purple": "#D33682", + "cyan": "#2AA198", + "white": "#EEE8D5", + "brightBlack": "#002B36", + "brightRed": "#CB4B16", + "brightGreen": "#586E75", + "brightYellow": "#657B83", + "brightBlue": "#839496", + "brightPurple": "#6C71C4", + "brightCyan": "#93A1A1", + "brightWhite": "#FDF6E3", + "foreground": "#657B83", + "background": "#FDF6E3", + "cursorColor": "#657B83" + }, + { + "name": "Sonokai", + "black": "#2C2E34", + "red": "#FC5D7C", + "green": "#9ED072", + "yellow": "#E7C664", + "blue": "#F39660", + "purple": "#B39DF3", + "cyan": "#76CCE0", + "white": "#E2E2E3", + "brightBlack": "#2C2E34", + "brightRed": "#FC5D7C", + "brightGreen": "#9ED072", + "brightYellow": "#E7C664", + "brightBlue": "#F39660", + "brightPurple": "#B39DF3", + "brightCyan": "#76CCE0", + "brightWhite": "#E2E2E3", + "foreground": "#E2E2E3", + "background": "#2C2E34", + "cursorColor": "#E2E2E3" + }, + { + "name": "Spacedust", + "black": "#6e5346", + "red": "#e35b00", + "green": "#5cab96", + "yellow": "#e3cd7b", + "blue": "#0f548b", + "purple": "#e35b00", + "cyan": "#06afc7", + "white": "#f0f1ce", + "brightBlack": "#684c31", + "brightRed": "#ff8a3a", + "brightGreen": "#aecab8", + "brightYellow": "#ffc878", + "brightBlue": "#67a0ce", + "brightPurple": "#ff8a3a", + "brightCyan": "#83a7b4", + "brightWhite": "#fefff1", + "foreground": "#ecf0c1", + "background": "#0a1e24", + "cursorColor": "#ecf0c1" + }, + { + "name": "SpaceGrayEightiesDull", + "black": "#15171c", + "red": "#b24a56", + "green": "#92b477", + "yellow": "#c6735a", + "blue": "#7c8fa5", + "purple": "#a5789e", + "cyan": "#80cdcb", + "white": "#b3b8c3", + "brightBlack": "#555555", + "brightRed": "#ec5f67", + "brightGreen": "#89e986", + "brightYellow": "#fec254", + "brightBlue": "#5486c0", + "brightPurple": "#bf83c1", + "brightCyan": "#58c2c1", + "brightWhite": "#ffffff", + "foreground": "#c9c6bc", + "background": "#222222", + "cursorColor": "#c9c6bc" + }, + { + "name": "SpaceGrayEighties", + "black": "#15171c", + "red": "#ec5f67", + "green": "#81a764", + "yellow": "#fec254", + "blue": "#5486c0", + "purple": "#bf83c1", + "cyan": "#57c2c1", + "white": "#efece7", + "brightBlack": "#555555", + "brightRed": "#ff6973", + "brightGreen": "#93d493", + "brightYellow": "#ffd256", + "brightBlue": "#4d84d1", + "brightPurple": "#ff55ff", + "brightCyan": "#83e9e4", + "brightWhite": "#ffffff", + "foreground": "#bdbaae", + "background": "#222222", + "cursorColor": "#bdbaae" + }, + { + "name": "SpaceGray", + "black": "#000000", + "red": "#b04b57", + "green": "#87b379", + "yellow": "#e5c179", + "blue": "#7d8fa4", + "purple": "#a47996", + "cyan": "#85a7a5", + "white": "#b3b8c3", + "brightBlack": "#000000", + "brightRed": "#b04b57", + "brightGreen": "#87b379", + "brightYellow": "#e5c179", + "brightBlue": "#7d8fa4", + "brightPurple": "#a47996", + "brightCyan": "#85a7a5", + "brightWhite": "#ffffff", + "foreground": "#b3b8c3", + "background": "#20242d", + "cursorColor": "#b3b8c3" + }, + { + "name": "Spring", + "black": "#000000", + "red": "#ff4d83", + "green": "#1f8c3b", + "yellow": "#1fc95b", + "blue": "#1dd3ee", + "purple": "#8959a8", + "cyan": "#3e999f", + "white": "#ffffff", + "brightBlack": "#000000", + "brightRed": "#ff0021", + "brightGreen": "#1fc231", + "brightYellow": "#d5b807", + "brightBlue": "#15a9fd", + "brightPurple": "#8959a8", + "brightCyan": "#3e999f", + "brightWhite": "#ffffff", + "foreground": "#ecf0c1", + "background": "#0a1e24", + "cursorColor": "#ecf0c1" + }, + { + "name": "Square", + "black": "#050505", + "red": "#e9897c", + "green": "#b6377d", + "yellow": "#ecebbe", + "blue": "#a9cdeb", + "purple": "#75507b", + "cyan": "#c9caec", + "white": "#f2f2f2", + "brightBlack": "#141414", + "brightRed": "#f99286", + "brightGreen": "#c3f786", + "brightYellow": "#fcfbcc", + "brightBlue": "#b6defb", + "brightPurple": "#ad7fa8", + "brightCyan": "#d7d9fc", + "brightWhite": "#e2e2e2", + "foreground": "#a1a1a1", + "background": "#0a1e24", + "cursorColor": "#a1a1a1" + }, + { + "name": "Srcery", + "black": "#1C1B19", + "red": "#FF3128", + "green": "#519F50", + "yellow": "#FBB829", + "blue": "#5573A3", + "purple": "#E02C6D", + "cyan": "#0AAEB3", + "white": "#918175", + "brightBlack": "#2D2B28", + "brightRed": "#F75341", + "brightGreen": "#98BC37", + "brightYellow": "#FED06E", + "brightBlue": "#8EB2F7", + "brightPurple": "#E35682", + "brightCyan": "#53FDE9", + "brightWhite": "#FCE8C3", + "foreground": "#ebdbb2", + "background": "#282828", + "cursorColor": "#ebdbb2" + }, + { + "name": "summer-pop", + "black": "#666666", + "red": "#FF1E8E", + "green": "#8EFF1E", + "yellow": "#FFFB00", + "blue": "#1E8EFF", + "purple": "#E500E5", + "cyan": "#00E5E5", + "white": "#E5E5E5", + "brightBlack": "#666666", + "brightRed": "#FF1E8E", + "brightGreen": "#8EFF1E", + "brightYellow": "#FFFB00", + "brightBlue": "#1E8EFF", + "brightPurple": "#E500E5", + "brightCyan": "#00E5E5", + "brightWhite": "#E5E5E5", + "foreground": "#FFFFFF", + "background": "#272822", + "cursorColor": "#FFFFFF" + }, + { + "name": "Sundried", + "black": "#302b2a", + "red": "#a7463d", + "green": "#587744", + "yellow": "#9d602a", + "blue": "#485b98", + "purple": "#864651", + "cyan": "#9c814f", + "white": "#c9c9c9", + "brightBlack": "#4d4e48", + "brightRed": "#aa000c", + "brightGreen": "#128c21", + "brightYellow": "#fc6a21", + "brightBlue": "#7999f7", + "brightPurple": "#fd8aa1", + "brightCyan": "#fad484", + "brightWhite": "#ffffff", + "foreground": "#c9c9c9", + "background": "#1a1818", + "cursorColor": "#c9c9c9" + }, + { + "name": "sweet-eliverlara", + "black": "#282C34", + "red": "#ED254E", + "green": "#71F79F", + "yellow": "#F9DC5C", + "blue": "#7CB7FF", + "purple": "#C74DED", + "cyan": "#00C1E4", + "white": "#DCDFE4", + "brightBlack": "#282C34", + "brightRed": "#ED254E", + "brightGreen": "#71F79F", + "brightYellow": "#F9DC5C", + "brightBlue": "#7CB7FF", + "brightPurple": "#C74DED", + "brightCyan": "#00C1E4", + "brightWhite": "#DCDFE4", + "foreground": "#C3C7D1", + "background": "#282C34", + "cursorColor": "#C3C7D1" + }, + { + "name": "SweetTerminal", + "black": "#3F3F54", + "red": "#f60055", + "green": "#06c993", + "yellow": "#9700be", + "blue": "#f69154", + "purple": "#ec89cb", + "cyan": "#60ADEC", + "white": "#ABB2BF", + "brightBlack": "#959DCB", + "brightRed": "#f60055", + "brightGreen": "#06c993", + "brightYellow": "#9700be", + "brightBlue": "#f69154", + "brightPurple": "#ec89cb", + "brightCyan": "#00dded", + "brightWhite": "#ffffff", + "foreground": "#ffffff", + "background": "#222235", + "cursorColor": "#ffffff" + }, + { + "name": "Symphonic", + "black": "#000000", + "red": "#dc322f", + "green": "#56db3a", + "yellow": "#ff8400", + "blue": "#0084d4", + "purple": "#b729d9", + "cyan": "#ccccff", + "white": "#ffffff", + "brightBlack": "#1b1d21", + "brightRed": "#dc322f", + "brightGreen": "#56db3a", + "brightYellow": "#ff8400", + "brightBlue": "#0084d4", + "brightPurple": "#b729d9", + "brightCyan": "#ccccff", + "brightWhite": "#ffffff", + "foreground": "#ffffff", + "background": "#000000", + "cursorColor": "#ffffff" + }, + { + "name": "SynthWave", + "black": "#011627", + "red": "#fe4450", + "green": "#72f1b8", + "yellow": "#fede5d", + "blue": "#03edf9", + "purple": "#ff7edb", + "cyan": "#03edf9", + "white": "#ffffff", + "brightBlack": "#575656", + "brightRed": "#fe4450", + "brightGreen": "#72f1b8", + "brightYellow": "#fede5d", + "brightBlue": "#03edf9", + "brightPurple": "#ff7edb", + "brightCyan": "#03edf9", + "brightWhite": "#ffffff", + "foreground": "#ffffff", + "background": "#262335", + "cursorColor": "#03edf9" + }, + { + "name": "Teerb", + "black": "#1c1c1c", + "red": "#d68686", + "green": "#aed686", + "yellow": "#d7af87", + "blue": "#86aed6", + "purple": "#d6aed6", + "cyan": "#8adbb4", + "white": "#d0d0d0", + "brightBlack": "#1c1c1c", + "brightRed": "#d68686", + "brightGreen": "#aed686", + "brightYellow": "#e4c9af", + "brightBlue": "#86aed6", + "brightPurple": "#d6aed6", + "brightCyan": "#b1e7dd", + "brightWhite": "#efefef", + "foreground": "#d0d0d0", + "background": "#262626", + "cursorColor": "#d0d0d0" + }, + { + "name": "Tender", + "black": "#1d1d1d", + "red": "#c5152f", + "green": "#c9d05c", + "yellow": "#ffc24b", + "blue": "#b3deef", + "purple": "#d3b987", + "cyan": "#73cef4", + "white": "#eeeeee", + "brightBlack": "#323232", + "brightRed": "#f43753", + "brightGreen": "#d9e066", + "brightYellow": "#facc72", + "brightBlue": "#c0eafb", + "brightPurple": "#efd093", + "brightCyan": "#a1d6ec", + "brightWhite": "#ffffff", + "foreground": "#EEEEEE", + "background": "#282828", + "cursorColor": "#EEEEEE" + }, + { + "name": "TerminalBasic", + "black": "#000000", + "red": "#990000", + "green": "#00a600", + "yellow": "#999900", + "blue": "#0000b2", + "purple": "#b200b2", + "cyan": "#00a6b2", + "white": "#bfbfbf", + "brightBlack": "#666666", + "brightRed": "#e50000", + "brightGreen": "#00d900", + "brightYellow": "#e5e500", + "brightBlue": "#0000ff", + "brightPurple": "#e500e5", + "brightCyan": "#00e5e5", + "brightWhite": "#e5e5e5", + "foreground": "#000000", + "background": "#ffffff", + "cursorColor": "#000000" + }, + { + "name": "TerminixDark", + "black": "#282a2e", + "red": "#a54242", + "green": "#a1b56c", + "yellow": "#de935f", + "blue": "#225555", + "purple": "#85678f", + "cyan": "#5e8d87", + "white": "#777777", + "brightBlack": "#373b41", + "brightRed": "#c63535", + "brightGreen": "#608360", + "brightYellow": "#fa805a", + "brightBlue": "#449da1", + "brightPurple": "#ba8baf", + "brightCyan": "#86c1b9", + "brightWhite": "#c5c8c6", + "foreground": "#868A8C", + "background": "#091116", + "cursorColor": "#868A8C" + }, + { + "name": "ThayerBright", + "black": "#1b1d1e", + "red": "#f92672", + "green": "#4df840", + "yellow": "#f4fd22", + "blue": "#2757d6", + "purple": "#8c54fe", + "cyan": "#38c8b5", + "white": "#ccccc6", + "brightBlack": "#505354", + "brightRed": "#ff5995", + "brightGreen": "#b6e354", + "brightYellow": "#feed6c", + "brightBlue": "#3f78ff", + "brightPurple": "#9e6ffe", + "brightCyan": "#23cfd5", + "brightWhite": "#f8f8f2", + "foreground": "#f8f8f8", + "background": "#1b1d1e", + "cursorColor": "#f8f8f8" + }, + { + "name": "Tin", + "black": "#000000", + "red": "#8d534e", + "green": "#4e8d53", + "yellow": "#888d4e", + "blue": "#534e8d", + "purple": "#8d4e88", + "cyan": "#4e888d", + "white": "#ffffff", + "brightBlack": "#000000", + "brightRed": "#b57d78", + "brightGreen": "#78b57d", + "brightYellow": "#b0b578", + "brightBlue": "#7d78b5", + "brightPurple": "#b578b0", + "brightCyan": "#78b0b5", + "brightWhite": "#ffffff", + "foreground": "#ffffff", + "background": "#2e2e35", + "cursorColor": "#ffffff" + }, + { + "name": "TokyoNightLight", + "black": "#0f0f14", + "red": "#8c4351", + "green": "#485e30", + "yellow": "#8f5e15", + "blue": "#34548a", + "purple": "#5a4a78", + "cyan": "#0f4b6e", + "white": "#343b58", + "brightBlack": "#9699a3", + "brightRed": "#8c4351", + "brightGreen": "#485e30", + "brightYellow": "#8f5e15", + "brightBlue": "#34548a", + "brightPurple": "#5a4a78", + "brightCyan": "#0f4b6e", + "brightWhite": "#343b58", + "foreground": "#565a6e", + "background": "#d5d6db", + "cursorColor": "#565a6e" + }, + { + "name": "TokyoNightStorm", + "black": "#414868", + "red": "#f7768e", + "green": "#9ece6a", + "yellow": "#e0af68", + "blue": "#7aa2f7", + "purple": "#bb9af7", + "cyan": "#7dcfff", + "white": "#c0caf5", + "brightBlack": "#414868", + "brightRed": "#f7768e", + "brightGreen": "#9ece6a", + "brightYellow": "#e0af68", + "brightBlue": "#7aa2f7", + "brightPurple": "#bb9af7", + "brightCyan": "#7dcfff", + "brightWhite": "#c0caf5", + "foreground": "#c0caf5", + "background": "#24283b", + "cursorColor": "#c0caf5" + }, + { + "name": "TokyoNight", + "black": "#414868", + "red": "#f7768e", + "green": "#9ece6a", + "yellow": "#e0af68", + "blue": "#7aa2f7", + "purple": "#bb9af7", + "cyan": "#7dcfff", + "white": "#a9b1d6", + "brightBlack": "#414868", + "brightRed": "#f7768e", + "brightGreen": "#9ece6a", + "brightYellow": "#e0af68", + "brightBlue": "#7aa2f7", + "brightPurple": "#bb9af7", + "brightCyan": "#7dcfff", + "brightWhite": "#c0caf5", + "foreground": "#c0caf5", + "background": "#1a1b26", + "cursorColor": "#c0caf5" + }, + { + "name": "TomorrowNightBlue", + "black": "#000000", + "red": "#FF9DA3", + "green": "#D1F1A9", + "yellow": "#FFEEAD", + "blue": "#BBDAFF", + "purple": "#EBBBFF", + "cyan": "#99FFFF", + "white": "#FFFEFE", + "brightBlack": "#000000", + "brightRed": "#FF9CA3", + "brightGreen": "#D0F0A8", + "brightYellow": "#FFEDAC", + "brightBlue": "#BADAFF", + "brightPurple": "#EBBAFF", + "brightCyan": "#99FFFF", + "brightWhite": "#FFFEFE", + "foreground": "#FFFEFE", + "background": "#002451", + "cursorColor": "#FFFEFE" + }, + { + "name": "TomorrowNightBright", + "black": "#000000", + "red": "#D54E53", + "green": "#B9CA49", + "yellow": "#E7C547", + "blue": "#79A6DA", + "purple": "#C397D8", + "cyan": "#70C0B1", + "white": "#FFFEFE", + "brightBlack": "#000000", + "brightRed": "#D44D53", + "brightGreen": "#B9C949", + "brightYellow": "#E6C446", + "brightBlue": "#79A6DA", + "brightPurple": "#C396D7", + "brightCyan": "#70C0B1", + "brightWhite": "#FFFEFE", + "foreground": "#E9E9E9", + "background": "#000000", + "cursorColor": "#E9E9E9" + }, + { + "name": "TomorrowNightEighties", + "black": "#000000", + "red": "#F27779", + "green": "#99CC99", + "yellow": "#FFCC66", + "blue": "#6699CC", + "purple": "#CC99CC", + "cyan": "#66CCCC", + "white": "#FFFEFE", + "brightBlack": "#000000", + "brightRed": "#F17779", + "brightGreen": "#99CC99", + "brightYellow": "#FFCC66", + "brightBlue": "#6699CC", + "brightPurple": "#CC99CC", + "brightCyan": "#66CCCC", + "brightWhite": "#FFFEFE", + "foreground": "#CCCCCC", + "background": "#2C2C2C", + "cursorColor": "#CCCCCC" + }, + { + "name": "TomorrowNight", + "black": "#000000", + "red": "#CC6666", + "green": "#B5BD68", + "yellow": "#F0C674", + "blue": "#81A2BE", + "purple": "#B293BB", + "cyan": "#8ABEB7", + "white": "#FFFEFE", + "brightBlack": "#000000", + "brightRed": "#CC6666", + "brightGreen": "#B5BD68", + "brightYellow": "#F0C574", + "brightBlue": "#80A1BD", + "brightPurple": "#B294BA", + "brightCyan": "#8ABDB6", + "brightWhite": "#FFFEFE", + "foreground": "#C5C8C6", + "background": "#1D1F21", + "cursorColor": "#C4C8C5" + }, + { + "name": "Tomorrow", + "black": "#000000", + "red": "#C82828", + "green": "#718C00", + "yellow": "#EAB700", + "blue": "#4171AE", + "purple": "#8959A8", + "cyan": "#3E999F", + "white": "#FFFEFE", + "brightBlack": "#000000", + "brightRed": "#C82828", + "brightGreen": "#708B00", + "brightYellow": "#E9B600", + "brightBlue": "#4170AE", + "brightPurple": "#8958A7", + "brightCyan": "#3D999F", + "brightWhite": "#FFFEFE", + "foreground": "#4D4D4C", + "background": "#FFFFFF", + "cursorColor": "#4C4C4C" + }, + { + "name": "ToyChest", + "black": "#2c3f58", + "red": "#be2d26", + "green": "#1a9172", + "yellow": "#db8e27", + "blue": "#325d96", + "purple": "#8a5edc", + "cyan": "#35a08f", + "white": "#23d183", + "brightBlack": "#336889", + "brightRed": "#dd5944", + "brightGreen": "#31d07b", + "brightYellow": "#e7d84b", + "brightBlue": "#34a6da", + "brightPurple": "#ae6bdc", + "brightCyan": "#42c3ae", + "brightWhite": "#d5d5d5", + "foreground": "#31d07b", + "background": "#24364b", + "cursorColor": "#31d07b" + }, + { + "name": "Treehouse", + "black": "#321300", + "red": "#b2270e", + "green": "#44a900", + "yellow": "#aa820c", + "blue": "#58859a", + "purple": "#97363d", + "cyan": "#b25a1e", + "white": "#786b53", + "brightBlack": "#433626", + "brightRed": "#ed5d20", + "brightGreen": "#55f238", + "brightYellow": "#f2b732", + "brightBlue": "#85cfed", + "brightPurple": "#e14c5a", + "brightCyan": "#f07d14", + "brightWhite": "#ffc800", + "foreground": "#786b53", + "background": "#191919", + "cursorColor": "#786b53" + }, + { + "name": "Twilight", + "black": "#141414", + "red": "#c06d44", + "green": "#afb97a", + "yellow": "#c2a86c", + "blue": "#44474a", + "purple": "#b4be7c", + "cyan": "#778385", + "white": "#ffffd4", + "brightBlack": "#262626", + "brightRed": "#de7c4c", + "brightGreen": "#ccd88c", + "brightYellow": "#e2c47e", + "brightBlue": "#5a5e62", + "brightPurple": "#d0dc8e", + "brightCyan": "#8a989b", + "brightWhite": "#ffffd4", + "foreground": "#ffffd4", + "background": "#141414", + "cursorColor": "#ffffd4" + }, + { + "name": "Ura", + "black": "#000000", + "red": "#c21b6f", + "green": "#6fc21b", + "yellow": "#c26f1b", + "blue": "#1b6fc2", + "purple": "#6f1bc2", + "cyan": "#1bc26f", + "white": "#808080", + "brightBlack": "#808080", + "brightRed": "#ee84b9", + "brightGreen": "#b9ee84", + "brightYellow": "#eeb984", + "brightBlue": "#84b9ee", + "brightPurple": "#b984ee", + "brightCyan": "#84eeb9", + "brightWhite": "#e5e5e5", + "foreground": "#23476a", + "background": "#feffee", + "cursorColor": "#23476a" + }, + { + "name": "Urple", + "black": "#000000", + "red": "#b0425b", + "green": "#37a415", + "yellow": "#ad5c42", + "blue": "#564d9b", + "purple": "#6c3ca1", + "cyan": "#808080", + "white": "#87799c", + "brightBlack": "#5d3225", + "brightRed": "#ff6388", + "brightGreen": "#29e620", + "brightYellow": "#f08161", + "brightBlue": "#867aed", + "brightPurple": "#a05eee", + "brightCyan": "#eaeaea", + "brightWhite": "#bfa3ff", + "foreground": "#877a9b", + "background": "#1b1b23", + "cursorColor": "#877a9b" + }, + { + "name": "Vag", + "black": "#303030", + "red": "#a87139", + "green": "#39a871", + "yellow": "#71a839", + "blue": "#7139a8", + "purple": "#a83971", + "cyan": "#3971a8", + "white": "#8a8a8a", + "brightBlack": "#494949", + "brightRed": "#b0763b", + "brightGreen": "#3bb076", + "brightYellow": "#76b03b", + "brightBlue": "#763bb0", + "brightPurple": "#b03b76", + "brightCyan": "#3b76b0", + "brightWhite": "#cfcfcf", + "foreground": "#d9e6f2", + "background": "#191f1d", + "cursorColor": "#d9e6f2" + }, + { + "name": "Vaughn", + "black": "#25234f", + "red": "#705050", + "green": "#60b48a", + "yellow": "#dfaf8f", + "blue": "#5555ff", + "purple": "#f08cc3", + "cyan": "#8cd0d3", + "white": "#709080", + "brightBlack": "#709080", + "brightRed": "#dca3a3", + "brightGreen": "#60b48a", + "brightYellow": "#f0dfaf", + "brightBlue": "#5555ff", + "brightPurple": "#ec93d3", + "brightCyan": "#93e0e3", + "brightWhite": "#ffffff", + "foreground": "#dcdccc", + "background": "#25234f", + "cursorColor": "#dcdccc" + }, + { + "name": "VibrantInk", + "black": "#878787", + "red": "#ff6600", + "green": "#ccff04", + "yellow": "#ffcc00", + "blue": "#44b4cc", + "purple": "#9933cc", + "cyan": "#44b4cc", + "white": "#f5f5f5", + "brightBlack": "#555555", + "brightRed": "#ff0000", + "brightGreen": "#00ff00", + "brightYellow": "#ffff00", + "brightBlue": "#0000ff", + "brightPurple": "#ff00ff", + "brightCyan": "#00ffff", + "brightWhite": "#e5e5e5", + "foreground": "#ffffff", + "background": "#000000", + "cursorColor": "#ffffff" + }, + { + "name": "VSCodeDark+", + "black": "#6A787A", + "red": "#E9653B", + "green": "#39E9A8", + "yellow": "#E5B684", + "blue": "#44AAE6", + "purple": "#E17599", + "cyan": "#3DD5E7", + "white": "#C3DDE1", + "brightBlack": "#598489", + "brightRed": "#E65029", + "brightGreen": "#00FF9A", + "brightYellow": "#E89440", + "brightBlue": "#009AFB", + "brightPurple": "#FF578F", + "brightCyan": "#5FFFFF", + "brightWhite": "#D9FBFF", + "foreground": "#CCCCCC", + "background": "#1E1E1E", + "cursorColor": "#CCCCCC" + }, + { + "name": "VSCodeLight+", + "black": "#020202", + "red": "#CD3232", + "green": "#00BC00", + "yellow": "#A5A900", + "blue": "#0752A8", + "purple": "#BC05BC", + "cyan": "#0598BC", + "white": "#343434", + "brightBlack": "#5E5E5E", + "brightRed": "#cd3333", + "brightGreen": "#1BCE1A", + "brightYellow": "#ADBB5B", + "brightBlue": "#0752A8", + "brightPurple": "#C451CE", + "brightCyan": "#52A8C7", + "brightWhite": "#A6A3A6", + "foreground": "#020202", + "background": "#f9f9f9", + "cursorColor": "#020202" + }, + { + "name": "WarmNeon", + "black": "#000000", + "red": "#e24346", + "green": "#39b13a", + "yellow": "#dae145", + "blue": "#4261c5", + "purple": "#f920fb", + "cyan": "#2abbd4", + "white": "#d0b8a3", + "brightBlack": "#fefcfc", + "brightRed": "#e97071", + "brightGreen": "#9cc090", + "brightYellow": "#ddda7a", + "brightBlue": "#7b91d6", + "brightPurple": "#f674ba", + "brightCyan": "#5ed1e5", + "brightWhite": "#d8c8bb", + "foreground": "#afdab6", + "background": "#404040", + "cursorColor": "#afdab6" + }, + { + "name": "Wez", + "black": "#000000", + "red": "#cc5555", + "green": "#55cc55", + "yellow": "#cdcd55", + "blue": "#5555cc", + "purple": "#cc55cc", + "cyan": "#7acaca", + "white": "#cccccc", + "brightBlack": "#555555", + "brightRed": "#ff5555", + "brightGreen": "#55ff55", + "brightYellow": "#ffff55", + "brightBlue": "#5555ff", + "brightPurple": "#ff55ff", + "brightCyan": "#55ffff", + "brightWhite": "#ffffff", + "foreground": "#b3b3b3", + "background": "#000000", + "cursorColor": "#b3b3b3" + }, + { + "name": "WildCherry", + "black": "#000507", + "red": "#d94085", + "green": "#2ab250", + "yellow": "#ffd16f", + "blue": "#883cdc", + "purple": "#ececec", + "cyan": "#c1b8b7", + "white": "#fff8de", + "brightBlack": "#009cc9", + "brightRed": "#da6bac", + "brightGreen": "#f4dca5", + "brightYellow": "#eac066", + "brightBlue": "#308cba", + "brightPurple": "#ae636b", + "brightCyan": "#ff919d", + "brightWhite": "#e4838d", + "foreground": "#dafaff", + "background": "#1f1726", + "cursorColor": "#dafaff" + }, + { + "name": "Wombat", + "black": "#000000", + "red": "#ff615a", + "green": "#b1e969", + "yellow": "#ebd99c", + "blue": "#5da9f6", + "purple": "#e86aff", + "cyan": "#82fff7", + "white": "#dedacf", + "brightBlack": "#313131", + "brightRed": "#f58c80", + "brightGreen": "#ddf88f", + "brightYellow": "#eee5b2", + "brightBlue": "#a5c7ff", + "brightPurple": "#ddaaff", + "brightCyan": "#b7fff9", + "brightWhite": "#ffffff", + "foreground": "#dedacf", + "background": "#171717", + "cursorColor": "#dedacf" + }, + { + "name": "Wryan", + "black": "#333333", + "red": "#8c4665", + "green": "#287373", + "yellow": "#7c7c99", + "blue": "#395573", + "purple": "#5e468c", + "cyan": "#31658c", + "white": "#899ca1", + "brightBlack": "#3d3d3d", + "brightRed": "#bf4d80", + "brightGreen": "#53a6a6", + "brightYellow": "#9e9ecb", + "brightBlue": "#477ab3", + "brightPurple": "#7e62b3", + "brightCyan": "#6096bf", + "brightWhite": "#c0c0c0", + "foreground": "#999993", + "background": "#101010", + "cursorColor": "#999993" + }, + { + "name": "Wzoreck", + "black": "#2E3436", + "red": "#FC6386", + "green": "#424043", + "yellow": "#FCE94F", + "blue": "#FB976B", + "purple": "#75507B", + "cyan": "#34E2E2", + "white": "#FFFFFF", + "brightBlack": "#989595", + "brightRed": "#FC6386", + "brightGreen": "#A9DC76", + "brightYellow": "#FCE94F", + "brightBlue": "#FB976B", + "brightPurple": "#AB9DF2", + "brightCyan": "#34E2E2", + "brightWhite": "#D1D1C0", + "foreground": "#FCFCFA", + "background": "#424043", + "cursorColor": "#FCFCFA" + }, + { + "name": "Zenburn", + "black": "#4d4d4d", + "red": "#705050", + "green": "#60b48a", + "yellow": "#f0dfaf", + "blue": "#506070", + "purple": "#dc8cc3", + "cyan": "#8cd0d3", + "white": "#dcdccc", + "brightBlack": "#709080", + "brightRed": "#dca3a3", + "brightGreen": "#c3bf9f", + "brightYellow": "#e0cf9f", + "brightBlue": "#94bff3", + "brightPurple": "#ec93d3", + "brightCyan": "#93e0e3", + "brightWhite": "#ffffff", + "foreground": "#dcdccc", + "background": "#3f3f3f", + "cursorColor": "#dcdccc" + } +] diff --git a/dimos/web/dimos_interface/tsconfig.json b/dimos/web/dimos_interface/tsconfig.json new file mode 100644 index 0000000000..4bf29f39d2 --- /dev/null +++ b/dimos/web/dimos_interface/tsconfig.json @@ -0,0 +1,25 @@ +{ + "extends": "@tsconfig/svelte/tsconfig.json", + "compilerOptions": { + "target": "ESNext", + "useDefineForClassFields": true, + "module": "ESNext", + "resolveJsonModule": true, + "allowJs": true, + "checkJs": true, + "isolatedModules": true, + "types": [ + "node" + ] + }, + "include": [ + "src/**/*.ts", + "src/**/*.js", + "src/**/*.svelte" + ], + "references": [ + { + "path": "./tsconfig.node.json" + } + ] +} diff --git a/dimos/web/dimos_interface/tsconfig.node.json b/dimos/web/dimos_interface/tsconfig.node.json new file mode 100644 index 0000000000..ad883d0eb4 --- /dev/null +++ b/dimos/web/dimos_interface/tsconfig.node.json @@ -0,0 +1,11 @@ +{ + "compilerOptions": { + "composite": true, + "skipLibCheck": true, + "module": "ESNext", + "moduleResolution": "bundler" + }, + "include": [ + "vite.config.ts" + ] +} diff --git a/dimos/web/dimos_interface/vite.config.ts b/dimos/web/dimos_interface/vite.config.ts new file mode 100644 index 0000000000..29be79dd4a --- /dev/null +++ b/dimos/web/dimos_interface/vite.config.ts @@ -0,0 +1,97 @@ +/** + * Copyright 2025 Dimensional Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { defineConfig } from 'vite'; +import { svelte } from '@sveltejs/vite-plugin-svelte'; + +// https://vitejs.dev/config/ +export default defineConfig({ + plugins: [svelte()], + server: { + port: 3000, + host: '0.0.0.0', + watch: { + // Exclude node_modules, .git and other large directories + ignored: ['**/node_modules/**', '**/.git/**', '**/dist/**', 'lambda/**'], + // Use polling instead of filesystem events (less efficient but uses fewer watchers) + usePolling: true, + }, + proxy: { + '/api': { + target: 'https://0rqz7w5rvf.execute-api.us-east-2.amazonaws.com', + changeOrigin: true, + rewrite: (path) => path.replace(/^\/api/, '/default/getGenesis'), + configure: (proxy, _options) => { + proxy.on('error', (err, _req, _res) => { + console.log('proxy error', err); + }); + proxy.on('proxyReq', (proxyReq, req, _res) => { + console.log('Sending Request to the Target:', req.method, req.url); + }); + proxy.on('proxyRes', (proxyRes, req, _res) => { + console.log('Received Response from the Target:', proxyRes.statusCode, req.url); + }); + }, + }, + '/unitree': { + target: 'http://0.0.0.0:5555', + changeOrigin: true, + configure: (proxy, _options) => { + proxy.on('error', (err, _req, _res) => { + console.log('unitree proxy error', err); + }); + proxy.on('proxyReq', (proxyReq, req, _res) => { + console.log('Sending Unitree Request:', req.method, req.url); + }); + proxy.on('proxyRes', (proxyRes, req, _res) => { + console.log('Received Unitree Response:', proxyRes.statusCode, req.url); + }); + }, + }, + '/text_streams': { + target: 'http://0.0.0.0:5555', + changeOrigin: true, + configure: (proxy, _options) => { + proxy.on('error', (err, _req, _res) => { + console.log('text streams proxy error', err); + }); + proxy.on('proxyReq', (proxyReq, req, _res) => { + console.log('Sending Text Streams Request:', req.method, req.url); + }); + proxy.on('proxyRes', (proxyRes, req, _res) => { + console.log('Received Text Streams Response:', proxyRes.statusCode, req.url); + }); + }, + }, + '/simulation': { + target: '', // Will be set dynamically + changeOrigin: true, + configure: (proxy, _options) => { + proxy.on('error', (err, _req, _res) => { + console.log('proxy error', err); + }); + proxy.on('proxyReq', (proxyReq, req, _res) => { + console.log('Sending Simulation Request:', req.method, req.url); + }); + }, + } + }, + cors: true + }, + define: { + 'process.env': process.env + } +}); diff --git a/dimos/web/edge_io.py b/dimos/web/edge_io.py index 5bef95c39d..28ccae8733 100644 --- a/dimos/web/edge_io.py +++ b/dimos/web/edge_io.py @@ -1,87 +1,26 @@ -from flask import Flask, jsonify, request, Response, render_template -import cv2 -from reactivex import operators as ops -from reactivex.disposable import CompositeDisposable, SingleAssignmentDisposable -from reactivex.subject import BehaviorSubject, Subject -from queue import Queue - -class EdgeIO(): - def __init__(self, dev_name:str="NA", edge_type:str="Base"): +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from reactivex.disposable import CompositeDisposable + + +class EdgeIO: + def __init__(self, dev_name: str = "NA", edge_type: str = "Base") -> None: self.dev_name = dev_name self.edge_type = edge_type self.disposables = CompositeDisposable() - def dispose_all(self): + def dispose_all(self) -> None: """Disposes of all active subscriptions managed by this agent.""" self.disposables.dispose() - -class FlaskServer(EdgeIO): - def __init__(self, dev_name="Flask Server", edge_type="Bidirectional", port=5555, **streams): - super().__init__(dev_name, edge_type) - self.app = Flask(__name__) - self.port = port - self.streams = streams - self.active_streams = {} - - # Initialize shared stream references with ref_count - for key in self.streams: - if self.streams[key] is not None: - # Apply share and ref_count to manage subscriptions - self.active_streams[key] = self.streams[key].pipe( - ops.map(self.process_frame_flask), - ops.share() - ) - - self.setup_routes() - - def process_frame_flask(self, frame): - """Convert frame to JPEG format for streaming.""" - _, buffer = cv2.imencode('.jpg', frame) - return buffer.tobytes() - - def setup_routes(self): - @self.app.route('/') - def index(): - stream_keys = list(self.streams.keys()) # Get the keys from the streams dictionary - return render_template('index.html', stream_keys=stream_keys) - - # Function to create a streaming response - def stream_generator(key): - def generate(): - frame_queue = Queue() - disposable = SingleAssignmentDisposable() - - # Subscribe to the shared, ref-counted stream - if key in self.active_streams: - disposable.disposable = self.active_streams[key].subscribe( - lambda frame: frame_queue.put(frame) if frame is not None else None, - lambda e: frame_queue.put(None), - lambda: frame_queue.put(None) - ) - - try: - while True: - frame = frame_queue.get() - if frame is None: - break - yield (b'--frame\r\n' - b'Content-Type: image/jpeg\r\n\r\n' + frame + b'\r\n') - finally: - disposable.dispose() - - return generate - - def make_response_generator(key): - def response_generator(): - return Response(stream_generator(key)(), mimetype='multipart/x-mixed-replace; boundary=frame') - return response_generator - - # Dynamically adding routes using add_url_rule - for key in self.streams: - endpoint = f'video_feed_{key}' - self.app.add_url_rule( - f'/video_feed/{key}', endpoint, view_func=make_response_generator(key)) - - def run(self, host='0.0.0.0', port=5555): - self.port = port - self.app.run(host=host, port=self.port, debug=False) diff --git a/dimos/web/fastapi_server.py b/dimos/web/fastapi_server.py new file mode 100644 index 0000000000..606e081fb3 --- /dev/null +++ b/dimos/web/fastapi_server.py @@ -0,0 +1,226 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Working FastAPI/Uvicorn Impl. + +# Notes: Do not use simultaneously with Flask, this includes imports. +# Workers are not yet setup, as this requires a much more intricate +# reorganization. There appears to be possible signalling issues when +# opening up streams on multiple windows/reloading which will need to +# be fixed. Also note, Chrome only supports 6 simultaneous web streams, +# and its advised to test threading/worker performance with another +# browser like Safari. + +# Fast Api & Uvicorn +import asyncio +from pathlib import Path +from queue import Empty, Queue +from threading import Lock + +import cv2 +from fastapi import FastAPI, Form, HTTPException, Request +from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse +from fastapi.templating import Jinja2Templates +import reactivex as rx +from reactivex import operators as ops +from reactivex.disposable import SingleAssignmentDisposable +from sse_starlette.sse import EventSourceResponse +import uvicorn + +from dimos.web.edge_io import EdgeIO + +# TODO: Resolve threading, start/stop stream functionality. + + +class FastAPIServer(EdgeIO): + def __init__( # type: ignore[no-untyped-def] + self, + dev_name: str = "FastAPI Server", + edge_type: str = "Bidirectional", + host: str = "0.0.0.0", + port: int = 5555, + text_streams=None, + **streams, + ) -> None: + super().__init__(dev_name, edge_type) + self.app = FastAPI() + self.port = port + self.host = host + BASE_DIR = Path(__file__).resolve().parent + self.templates = Jinja2Templates(directory=str(BASE_DIR / "templates")) + self.streams = streams + self.active_streams = {} + self.stream_locks = {key: Lock() for key in self.streams} + self.stream_queues = {} # type: ignore[var-annotated] + self.stream_disposables = {} # type: ignore[var-annotated] + + # Initialize text streams + self.text_streams = text_streams or {} + self.text_queues = {} # type: ignore[var-annotated] + self.text_disposables = {} + self.text_clients = set() # type: ignore[var-annotated] + + # Create a Subject for text queries + self.query_subject = rx.subject.Subject() # type: ignore[var-annotated] + self.query_stream = self.query_subject.pipe(ops.share()) + + for key in self.streams: + if self.streams[key] is not None: + self.active_streams[key] = self.streams[key].pipe( + ops.map(self.process_frame_fastapi), ops.share() + ) + + # Set up text stream subscriptions + for key, stream in self.text_streams.items(): + if stream is not None: + self.text_queues[key] = Queue(maxsize=100) + disposable = stream.subscribe( + lambda text, k=key: self.text_queues[k].put(text) if text is not None else None, + lambda e, k=key: self.text_queues[k].put(None), + lambda k=key: self.text_queues[k].put(None), + ) + self.text_disposables[key] = disposable + self.disposables.add(disposable) + + self.setup_routes() + + def process_frame_fastapi(self, frame): # type: ignore[no-untyped-def] + """Convert frame to JPEG format for streaming.""" + _, buffer = cv2.imencode(".jpg", frame) + return buffer.tobytes() + + def stream_generator(self, key): # type: ignore[no-untyped-def] + """Generate frames for a given video stream.""" + + def generate(): # type: ignore[no-untyped-def] + if key not in self.stream_queues: + self.stream_queues[key] = Queue(maxsize=10) + + frame_queue = self.stream_queues[key] + + # Clear any existing disposable for this stream + if key in self.stream_disposables: + self.stream_disposables[key].dispose() + + disposable = SingleAssignmentDisposable() + self.stream_disposables[key] = disposable + self.disposables.add(disposable) + + if key in self.active_streams: + with self.stream_locks[key]: + # Clear the queue before starting new subscription + while not frame_queue.empty(): + try: + frame_queue.get_nowait() + except Empty: + break + + disposable.disposable = self.active_streams[key].subscribe( + lambda frame: frame_queue.put(frame) if frame is not None else None, + lambda e: frame_queue.put(None), + lambda: frame_queue.put(None), + ) + + try: + while True: + try: + frame = frame_queue.get(timeout=1) + if frame is None: + break + yield (b"--frame\r\nContent-Type: image/jpeg\r\n\r\n" + frame + b"\r\n") + except Empty: + # Instead of breaking, continue waiting for new frames + continue + finally: + if key in self.stream_disposables: + self.stream_disposables[key].dispose() + + return generate + + def create_video_feed_route(self, key): # type: ignore[no-untyped-def] + """Create a video feed route for a specific stream.""" + + async def video_feed(): # type: ignore[no-untyped-def] + return StreamingResponse( + self.stream_generator(key)(), # type: ignore[no-untyped-call] + media_type="multipart/x-mixed-replace; boundary=frame", + ) + + return video_feed + + async def text_stream_generator(self, key): # type: ignore[no-untyped-def] + """Generate SSE events for text stream.""" + client_id = id(object()) + self.text_clients.add(client_id) + + try: + while True: + if key in self.text_queues: + try: + text = self.text_queues[key].get(timeout=1) + if text is not None: + yield {"event": "message", "id": key, "data": text} + except Empty: + # Send a keep-alive comment + yield {"event": "ping", "data": ""} + await asyncio.sleep(0.1) + finally: + self.text_clients.remove(client_id) + + def setup_routes(self) -> None: + """Set up FastAPI routes.""" + + @self.app.get("/", response_class=HTMLResponse) + async def index(request: Request): # type: ignore[no-untyped-def] + stream_keys = list(self.streams.keys()) + text_stream_keys = list(self.text_streams.keys()) + return self.templates.TemplateResponse( + "index_fastapi.html", + { + "request": request, + "stream_keys": stream_keys, + "text_stream_keys": text_stream_keys, + }, + ) + + @self.app.post("/submit_query") + async def submit_query(query: str = Form(...)): # type: ignore[no-untyped-def] + # Using Form directly as a dependency ensures proper form handling + try: + if query: + # Emit the query through our Subject + self.query_subject.on_next(query) + return JSONResponse({"success": True, "message": "Query received"}) + return JSONResponse({"success": False, "message": "No query provided"}) + except Exception as e: + # Ensure we always return valid JSON even on error + return JSONResponse( + status_code=500, + content={"success": False, "message": f"Server error: {e!s}"}, + ) + + @self.app.get("/text_stream/{key}") + async def text_stream(key: str): # type: ignore[no-untyped-def] + if key not in self.text_streams: + raise HTTPException(status_code=404, detail=f"Text stream '{key}' not found") + return EventSourceResponse(self.text_stream_generator(key)) # type: ignore[no-untyped-call] + + for key in self.streams: + self.app.get(f"/video_feed/{key}")(self.create_video_feed_route(key)) # type: ignore[no-untyped-call] + + def run(self) -> None: + """Run the FastAPI server.""" + uvicorn.run( + self.app, host=self.host, port=self.port + ) # TODO: Translate structure to enable in-built workers' diff --git a/dimos/web/flask_server.py b/dimos/web/flask_server.py new file mode 100644 index 0000000000..4cd6d0a5e0 --- /dev/null +++ b/dimos/web/flask_server.py @@ -0,0 +1,105 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from queue import Queue + +import cv2 +from flask import Flask, Response, render_template +from reactivex import operators as ops +from reactivex.disposable import SingleAssignmentDisposable + +from dimos.web.edge_io import EdgeIO + + +class FlaskServer(EdgeIO): + def __init__( # type: ignore[no-untyped-def] + self, + dev_name: str = "Flask Server", + edge_type: str = "Bidirectional", + port: int = 5555, + **streams, + ) -> None: + super().__init__(dev_name, edge_type) + self.app = Flask(__name__) + self.port = port + self.streams = streams + self.active_streams = {} + + # Initialize shared stream references with ref_count + for key in self.streams: + if self.streams[key] is not None: + # Apply share and ref_count to manage subscriptions + self.active_streams[key] = self.streams[key].pipe( + ops.map(self.process_frame_flask), ops.share() + ) + + self.setup_routes() + + def process_frame_flask(self, frame): # type: ignore[no-untyped-def] + """Convert frame to JPEG format for streaming.""" + _, buffer = cv2.imencode(".jpg", frame) + return buffer.tobytes() + + def setup_routes(self) -> None: + @self.app.route("/") + def index(): # type: ignore[no-untyped-def] + stream_keys = list(self.streams.keys()) # Get the keys from the streams dictionary + return render_template("index_flask.html", stream_keys=stream_keys) + + # Function to create a streaming response + def stream_generator(key): # type: ignore[no-untyped-def] + def generate(): # type: ignore[no-untyped-def] + frame_queue = Queue() # type: ignore[var-annotated] + disposable = SingleAssignmentDisposable() + + # Subscribe to the shared, ref-counted stream + if key in self.active_streams: + disposable.disposable = self.active_streams[key].subscribe( + lambda frame: frame_queue.put(frame) if frame is not None else None, + lambda e: frame_queue.put(None), + lambda: frame_queue.put(None), + ) + + try: + while True: + frame = frame_queue.get() + if frame is None: + break + yield (b"--frame\r\nContent-Type: image/jpeg\r\n\r\n" + frame + b"\r\n") + finally: + disposable.dispose() + + return generate + + def make_response_generator(key): # type: ignore[no-untyped-def] + def response_generator(): # type: ignore[no-untyped-def] + return Response( + stream_generator(key)(), # type: ignore[no-untyped-call] + mimetype="multipart/x-mixed-replace; boundary=frame", + ) + + return response_generator + + # Dynamically adding routes using add_url_rule + for key in self.streams: + endpoint = f"video_feed_{key}" + self.app.add_url_rule( + f"/video_feed/{key}", + endpoint, + view_func=make_response_generator(key), # type: ignore[no-untyped-call] + ) + + def run(self, host: str = "0.0.0.0", port: int = 5555, threaded: bool = True) -> None: + self.port = port + self.app.run(host=host, port=self.port, debug=False, threaded=threaded) diff --git a/dimos/web/robot_web_interface.py b/dimos/web/robot_web_interface.py new file mode 100644 index 0000000000..f45319f1d2 --- /dev/null +++ b/dimos/web/robot_web_interface.py @@ -0,0 +1,35 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Robot Web Interface wrapper for DIMOS. +Provides a clean interface to the dimensional-interface FastAPI server. +""" + +from dimos.web.dimos_interface.api.server import FastAPIServer + + +class RobotWebInterface(FastAPIServer): + """Wrapper class for the dimos-interface FastAPI server.""" + + def __init__(self, port: int = 5555, text_streams=None, audio_subject=None, **streams) -> None: # type: ignore[no-untyped-def] + super().__init__( + dev_name="Robot Web Interface", + edge_type="Bidirectional", + host="0.0.0.0", + port=port, + text_streams=text_streams, + audio_subject=audio_subject, + **streams, + ) diff --git a/dimos/web/templates/index.html b/dimos/web/templates/index.html deleted file mode 100644 index b2897b93f4..0000000000 --- a/dimos/web/templates/index.html +++ /dev/null @@ -1,54 +0,0 @@ - - - - - - Video Stream Example - - - -

Live Video Streams

- - - {% for key in stream_keys %} -

Live {{ key.replace('_', ' ').title() }} Feed

- {{ key }} Feed - {% endfor %} - - - - - \ No newline at end of file diff --git a/dimos/web/templates/index_fastapi.html b/dimos/web/templates/index_fastapi.html new file mode 100644 index 0000000000..75b0c1c179 --- /dev/null +++ b/dimos/web/templates/index_fastapi.html @@ -0,0 +1,389 @@ + + + + + + + + Video Stream Example + + + +

Live Video Streams

+ +
+

Ask a Question

+
+ + +
+
+
+ + + {% if text_stream_keys %} +
+

Text Streams

+ {% for key in text_stream_keys %} +
+

{{ key.replace('_', ' ').title() }}

+
+
+ + + +
+
+ {% endfor %} +
+ {% endif %} + +
+ {% for key in stream_keys %} +
+

{{ key.replace('_', ' ').title() }}

+ {{ key }} Feed +
+ + +
+
+ {% endfor %} +
+ + + + + + diff --git a/dimos/web/templates/index_flask.html b/dimos/web/templates/index_flask.html new file mode 100644 index 0000000000..e41665e588 --- /dev/null +++ b/dimos/web/templates/index_flask.html @@ -0,0 +1,118 @@ + + + + + + + + Video Stream Example + + + +

Live Video Streams

+ +
+ {% for key in stream_keys %} +
+

{{ key.replace('_', ' ').title() }}

+ {{ key }} Feed +
+ {% endfor %} +
+ + + + + diff --git a/dimos/web/websocket_vis/README.md b/dimos/web/websocket_vis/README.md new file mode 100644 index 0000000000..c04235958e --- /dev/null +++ b/dimos/web/websocket_vis/README.md @@ -0,0 +1,66 @@ +# WebSocket Visualization Module + +The `WebsocketVisModule` provides a real-time data for visualization and control of the robot in Foxglove (see `dimos/web/command-center-extension/README.md`). + +## Overview + +Visualization: + +- Robot position and orientation +- Navigation paths +- Costmaps + +Control: + +- Set navigation goal +- Set GPS location goal +- Keyboard teleop (WASD) +- Trigger exploration + +## What it Provides + +### Inputs (Subscribed Topics) +- `robot_pose` (PoseStamped): Current robot position and orientation +- `gps_location` (LatLon): GPS coordinates of the robot +- `path` (Path): Planned navigation path +- `global_costmap` (OccupancyGrid): Global costmap for visualization + +### Outputs (Published Topics) +- `click_goal` (PoseStamped): Goal positions set by user clicks in the web interface +- `gps_goal` (LatLon): GPS goal coordinates set through the interface +- `explore_cmd` (Bool): Command to start autonomous exploration +- `stop_explore_cmd` (Bool): Command to stop exploration +- `movecmd` (Twist): Direct movement commands from the interface +- `movecmd_stamped` (TwistStamped): Timestamped movement commands + +## How to Use + +### Basic Usage + +```python +from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule +from dimos import core + +# Deploy the WebSocket visualization module +websocket_vis = dimos.deploy(WebsocketVisModule, port=7779) + +# Receive control from the Foxglove plugin. +websocket_vis.click_goal.transport = core.LCMTransport("/goal_request", PoseStamped) +websocket_vis.explore_cmd.transport = core.LCMTransport("/explore_cmd", Bool) +websocket_vis.stop_explore_cmd.transport = core.LCMTransport("/stop_explore_cmd", Bool) +websocket_vis.movecmd.transport = core.LCMTransport("/cmd_vel", Twist) +websocket_vis.gps_goal.transport = core.pLCMTransport("/gps_goal") + +# Send visualization data to the Foxglove plugin. +websocket_vis.robot_pose.connect(connection.odom) +websocket_vis.path.connect(global_planner.path) +websocket_vis.global_costmap.connect(mapper.global_costmap) +websocket_vis.gps_location.connect(connection.gps_location) + +# Start the module +websocket_vis.start() +``` + +### Accessing the Interface + +See `dimos/web/command-center-extension/README.md` for how to add the command-center plugin in Foxglove. diff --git a/dimos/web/websocket_vis/costmap_viz.py b/dimos/web/websocket_vis/costmap_viz.py new file mode 100644 index 0000000000..21309c94bc --- /dev/null +++ b/dimos/web/websocket_vis/costmap_viz.py @@ -0,0 +1,65 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Simple costmap wrapper for visualization purposes. +This is a minimal implementation to support websocket visualization. +""" + +import numpy as np + +from dimos.msgs.nav_msgs import OccupancyGrid + + +class CostmapViz: + """A wrapper around OccupancyGrid for visualization compatibility.""" + + def __init__(self, occupancy_grid: OccupancyGrid | None = None) -> None: + """Initialize from an OccupancyGrid.""" + self.occupancy_grid = occupancy_grid + + @property + def data(self) -> np.ndarray | None: # type: ignore[type-arg] + """Get the costmap data as a numpy array.""" + if self.occupancy_grid: + return self.occupancy_grid.grid + return None + + @property + def width(self) -> int: + """Get the width of the costmap.""" + if self.occupancy_grid: + return self.occupancy_grid.width + return 0 + + @property + def height(self) -> int: + """Get the height of the costmap.""" + if self.occupancy_grid: + return self.occupancy_grid.height + return 0 + + @property + def resolution(self) -> float: + """Get the resolution of the costmap.""" + if self.occupancy_grid: + return self.occupancy_grid.resolution + return 1.0 + + @property + def origin(self): # type: ignore[no-untyped-def] + """Get the origin pose of the costmap.""" + if self.occupancy_grid: + return self.occupancy_grid.origin + return None diff --git a/dimos/web/websocket_vis/optimized_costmap.py b/dimos/web/websocket_vis/optimized_costmap.py new file mode 100644 index 0000000000..dfe5822a7b --- /dev/null +++ b/dimos/web/websocket_vis/optimized_costmap.py @@ -0,0 +1,160 @@ +#!/usr/bin/env python3 + +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2025 Dimensional Inc. + +import base64 +import hashlib +import time +from typing import Any +import zlib + +import numpy as np + + +class OptimizedCostmapEncoder: + """Handles optimized encoding of costmaps with delta compression.""" + + def __init__(self, chunk_size: int = 64) -> None: + self.chunk_size = chunk_size + self.last_full_grid: np.ndarray | None = None # type: ignore[type-arg] + self.last_full_sent_time: float = 0 # Track when last full update was sent + self.chunk_hashes: dict[tuple[int, int], str] = {} + self.full_update_interval = 3.0 # Send full update every 3 seconds + + def encode_costmap(self, grid: np.ndarray, force_full: bool = False) -> dict[str, Any]: # type: ignore[type-arg] + """Encode a costmap grid with optimizations. + + Args: + grid: The costmap grid as numpy array + force_full: Force sending a full update + + Returns: + Encoded costmap data + """ + current_time = time.time() + + # Determine if we need a full update + send_full = ( + force_full + or self.last_full_grid is None + or self.last_full_grid.shape != grid.shape + or (current_time - self.last_full_sent_time) > self.full_update_interval + ) + + if send_full: + return self._encode_full(grid, current_time) + else: + return self._encode_delta(grid, current_time) + + def _encode_full(self, grid: np.ndarray, current_time: float) -> dict[str, Any]: # type: ignore[type-arg] + height, width = grid.shape + + # Convert to uint8 for better compression (costmap values are -1 to 100) + # Map -1 to 255 for unknown cells + grid_uint8 = grid.astype(np.int16) + grid_uint8[grid_uint8 == -1] = 255 + grid_uint8 = grid_uint8.astype(np.uint8) + + # Compress the data + compressed = zlib.compress(grid_uint8.tobytes(), level=6) + + # Base64 encode + encoded = base64.b64encode(compressed).decode("ascii") + + # Update state + self.last_full_grid = grid.copy() + self.last_full_sent_time = current_time + self._update_chunk_hashes(grid) + + return { + "update_type": "full", + "shape": [height, width], + "dtype": "u8", # uint8 + "compressed": True, + "compression": "zlib", + "data": encoded, + } + + def _encode_delta(self, grid: np.ndarray, current_time: float) -> dict[str, Any]: # type: ignore[type-arg] + height, width = grid.shape + changed_chunks = [] + + # Divide grid into chunks and check for changes + for y in range(0, height, self.chunk_size): + for x in range(0, width, self.chunk_size): + # Get chunk bounds + y_end = min(y + self.chunk_size, height) + x_end = min(x + self.chunk_size, width) + + # Extract chunk + chunk = grid[y:y_end, x:x_end] + + # Compute hash of chunk + chunk_hash = hashlib.md5(chunk.tobytes()).hexdigest() + chunk_key = (y, x) + + # Check if chunk has changed + if chunk_key not in self.chunk_hashes or self.chunk_hashes[chunk_key] != chunk_hash: + # Chunk has changed, encode it + chunk_uint8 = chunk.astype(np.int16) + chunk_uint8[chunk_uint8 == -1] = 255 + chunk_uint8 = chunk_uint8.astype(np.uint8) + + # Compress chunk + compressed = zlib.compress(chunk_uint8.tobytes(), level=6) + encoded = base64.b64encode(compressed).decode("ascii") + + changed_chunks.append( + {"pos": [y, x], "size": [y_end - y, x_end - x], "data": encoded} + ) + + # Update hash + self.chunk_hashes[chunk_key] = chunk_hash + + # Update state - only update the grid, not the timer + self.last_full_grid = grid.copy() + + # If too many chunks changed, send full update instead + total_chunks = ((height + self.chunk_size - 1) // self.chunk_size) * ( + (width + self.chunk_size - 1) // self.chunk_size + ) + + if len(changed_chunks) > total_chunks * 0.5: + # More than 50% changed, send full update + return self._encode_full(grid, current_time) + + return { + "update_type": "delta", + "shape": [height, width], + "dtype": "u8", + "compressed": True, + "compression": "zlib", + "chunks": changed_chunks, + } + + def _update_chunk_hashes(self, grid: np.ndarray) -> None: # type: ignore[type-arg] + """Update all chunk hashes for the grid.""" + self.chunk_hashes.clear() + height, width = grid.shape + + for y in range(0, height, self.chunk_size): + for x in range(0, width, self.chunk_size): + y_end = min(y + self.chunk_size, height) + x_end = min(x + self.chunk_size, width) + chunk = grid[y:y_end, x:x_end] + chunk_hash = hashlib.md5(chunk.tobytes()).hexdigest() + self.chunk_hashes[(y, x)] = chunk_hash diff --git a/dimos/web/websocket_vis/path_history.py b/dimos/web/websocket_vis/path_history.py new file mode 100644 index 0000000000..39b6be08a3 --- /dev/null +++ b/dimos/web/websocket_vis/path_history.py @@ -0,0 +1,75 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Simple path history class for visualization purposes. +This is a minimal implementation to support websocket visualization. +""" + +from dimos.msgs.geometry_msgs import Vector3 + + +class PathHistory: + """A simple container for storing a history of positions for visualization.""" + + def __init__(self, points: list[Vector3 | tuple | list] | None = None) -> None: # type: ignore[type-arg] + """Initialize with optional list of points.""" + self.points: list[Vector3] = [] + if points: + for p in points: + if isinstance(p, Vector3): + self.points.append(p) + else: + self.points.append(Vector3(*p)) + + def ipush(self, point: Vector3 | tuple | list) -> "PathHistory": # type: ignore[type-arg] + """Add a point to the history (in-place) and return self.""" + if isinstance(point, Vector3): + self.points.append(point) + else: + self.points.append(Vector3(*point)) + return self + + def iclip_tail(self, max_length: int) -> "PathHistory": + """Keep only the last max_length points (in-place) and return self.""" + if max_length > 0 and len(self.points) > max_length: + self.points = self.points[-max_length:] + return self + + def last(self) -> Vector3 | None: + """Return the last point in the history, or None if empty.""" + return self.points[-1] if self.points else None + + def length(self) -> float: + """Calculate the total length of the path.""" + if len(self.points) < 2: + return 0.0 + + total = 0.0 + for i in range(1, len(self.points)): + p1 = self.points[i - 1] + p2 = self.points[i] + dx = p2.x - p1.x + dy = p2.y - p1.y + dz = p2.z - p1.z + total += (dx * dx + dy * dy + dz * dz) ** 0.5 + return total + + def __len__(self) -> int: + """Return the number of points in the history.""" + return len(self.points) + + def __getitem__(self, index: int) -> Vector3: + """Get a point by index.""" + return self.points[index] diff --git a/dimos/web/websocket_vis/websocket_vis_module.py b/dimos/web/websocket_vis/websocket_vis_module.py new file mode 100644 index 0000000000..6990bf825a --- /dev/null +++ b/dimos/web/websocket_vis/websocket_vis_module.py @@ -0,0 +1,302 @@ +#!/usr/bin/env python3 + +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +WebSocket Visualization Module for Dimos navigation and mapping. +""" + +import asyncio +import threading +import time +from typing import Any + +from dimos_lcm.std_msgs import Bool # type: ignore[import-untyped] +from reactivex.disposable import Disposable +import socketio # type: ignore[import-untyped] +from starlette.applications import Starlette +from starlette.responses import HTMLResponse +from starlette.routing import Route +import uvicorn + +from dimos.core import In, Module, Out, rpc +from dimos.mapping.types import LatLon +from dimos.msgs.geometry_msgs import PoseStamped, Twist, TwistStamped, Vector3 +from dimos.msgs.nav_msgs import OccupancyGrid, Path +from dimos.utils.logging_config import setup_logger + +from .optimized_costmap import OptimizedCostmapEncoder + +logger = setup_logger() + + +class WebsocketVisModule(Module): + """ + WebSocket-based visualization module for real-time navigation data. + + This module provides a web interface for visualizing: + - Robot position and orientation + - Navigation paths + - Costmaps + - Interactive goal setting via mouse clicks + + Inputs: + - robot_pose: Current robot position + - path: Navigation path + - global_costmap: Global costmap for visualization + + Outputs: + - click_goal: Goal position from user clicks + """ + + # LCM inputs + odom: In[PoseStamped] = None # type: ignore[assignment] + gps_location: In[LatLon] = None # type: ignore[assignment] + path: In[Path] = None # type: ignore[assignment] + global_costmap: In[OccupancyGrid] = None # type: ignore[assignment] + + # LCM outputs + goal_request: Out[PoseStamped] = None # type: ignore[assignment] + gps_goal: Out[LatLon] = None # type: ignore[assignment] + explore_cmd: Out[Bool] = None # type: ignore[assignment] + stop_explore_cmd: Out[Bool] = None # type: ignore[assignment] + cmd_vel: Out[Twist] = None # type: ignore[assignment] + movecmd_stamped: Out[TwistStamped] = None # type: ignore[assignment] + + def __init__(self, port: int = 7779, **kwargs) -> None: # type: ignore[no-untyped-def] + """Initialize the WebSocket visualization module. + + Args: + port: Port to run the web server on + """ + super().__init__(**kwargs) + + self.port = port + self._uvicorn_server_thread: threading.Thread | None = None + self.sio: socketio.AsyncServer | None = None + self.app = None + self._broadcast_loop = None + self._broadcast_thread = None + self._uvicorn_server: uvicorn.Server | None = None + + self.vis_state = {} # type: ignore[var-annotated] + self.state_lock = threading.Lock() + + self.costmap_encoder = OptimizedCostmapEncoder(chunk_size=64) + + logger.info(f"WebSocket visualization module initialized on port {port}") + + def _start_broadcast_loop(self) -> None: + def websocket_vis_loop() -> None: + self._broadcast_loop = asyncio.new_event_loop() # type: ignore[assignment] + asyncio.set_event_loop(self._broadcast_loop) + try: + self._broadcast_loop.run_forever() # type: ignore[attr-defined] + except Exception as e: + logger.error(f"Broadcast loop error: {e}") + finally: + self._broadcast_loop.close() # type: ignore[attr-defined] + + self._broadcast_thread = threading.Thread(target=websocket_vis_loop, daemon=True) # type: ignore[assignment] + self._broadcast_thread.start() # type: ignore[attr-defined] + + @rpc + def start(self) -> None: + super().start() + + self._create_server() + + self._start_broadcast_loop() + + self._uvicorn_server_thread = threading.Thread(target=self._run_uvicorn_server, daemon=True) + self._uvicorn_server_thread.start() + + try: + unsub = self.odom.subscribe(self._on_robot_pose) + self._disposables.add(Disposable(unsub)) + except Exception: + ... + + try: + unsub = self.gps_location.subscribe(self._on_gps_location) + self._disposables.add(Disposable(unsub)) + except Exception: + ... + + try: + unsub = self.path.subscribe(self._on_path) + self._disposables.add(Disposable(unsub)) + except Exception: + ... + + unsub = self.global_costmap.subscribe(self._on_global_costmap) + self._disposables.add(Disposable(unsub)) + + @rpc + def stop(self) -> None: + if self._uvicorn_server: + self._uvicorn_server.should_exit = True + + if self.sio and self._broadcast_loop and not self._broadcast_loop.is_closed(): + + async def _disconnect_all() -> None: + await self.sio.disconnect() + + asyncio.run_coroutine_threadsafe(_disconnect_all(), self._broadcast_loop) + + if self._broadcast_loop and not self._broadcast_loop.is_closed(): + self._broadcast_loop.call_soon_threadsafe(self._broadcast_loop.stop) + + if self._broadcast_thread and self._broadcast_thread.is_alive(): + self._broadcast_thread.join(timeout=1.0) + + if self._uvicorn_server_thread and self._uvicorn_server_thread.is_alive(): + self._uvicorn_server_thread.join(timeout=2.0) + + super().stop() + + @rpc + def set_gps_travel_goal_points(self, points: list[LatLon]) -> None: + json_points = [{"lat": x.lat, "lon": x.lon} for x in points] + self.vis_state["gps_travel_goal_points"] = json_points + self._emit("gps_travel_goal_points", json_points) + + def _create_server(self) -> None: + # Create SocketIO server + self.sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*") + + async def serve_index(request): # type: ignore[no-untyped-def] + return HTMLResponse("Use the extension.") + + routes = [Route("/", serve_index)] + starlette_app = Starlette(routes=routes) + + self.app = socketio.ASGIApp(self.sio, starlette_app) + + # Register SocketIO event handlers + @self.sio.event # type: ignore[misc] + async def connect(sid, environ) -> None: # type: ignore[no-untyped-def] + with self.state_lock: + current_state = dict(self.vis_state) + + # Force full costmap update on new connection + self.costmap_encoder.last_full_grid = None + + await self.sio.emit("full_state", current_state, room=sid) # type: ignore[union-attr] + + @self.sio.event # type: ignore[misc] + async def click(sid, position) -> None: # type: ignore[no-untyped-def] + goal = PoseStamped( + position=(position[0], position[1], 0), + orientation=(0, 0, 0, 1), # Default orientation + frame_id="world", + ) + self.goal_request.publish(goal) + logger.info(f"Click goal published: ({goal.position.x:.2f}, {goal.position.y:.2f})") + + @self.sio.event # type: ignore[misc] + async def gps_goal(sid, goal) -> None: # type: ignore[no-untyped-def] + logger.info(f"Set GPS goal: {goal}") + self.gps_goal.publish(LatLon(lat=goal["lat"], lon=goal["lon"])) + + @self.sio.event # type: ignore[misc] + async def start_explore(sid) -> None: # type: ignore[no-untyped-def] + logger.info("Starting exploration") + self.explore_cmd.publish(Bool(data=True)) + + @self.sio.event # type: ignore[misc] + async def stop_explore(sid) -> None: # type: ignore[no-untyped-def] + logger.info("Stopping exploration") + self.stop_explore_cmd.publish(Bool(data=True)) + + @self.sio.event # type: ignore[misc] + async def move_command(sid, data) -> None: # type: ignore[no-untyped-def] + # Publish Twist if transport is configured + if self.cmd_vel and self.cmd_vel.transport: + twist = Twist( + linear=Vector3(data["linear"]["x"], data["linear"]["y"], data["linear"]["z"]), + angular=Vector3( + data["angular"]["x"], data["angular"]["y"], data["angular"]["z"] + ), + ) + self.cmd_vel.publish(twist) + + # Publish TwistStamped if transport is configured + if self.movecmd_stamped and self.movecmd_stamped.transport: + twist_stamped = TwistStamped( + ts=time.time(), + frame_id="base_link", + linear=Vector3(data["linear"]["x"], data["linear"]["y"], data["linear"]["z"]), + angular=Vector3( + data["angular"]["x"], data["angular"]["y"], data["angular"]["z"] + ), + ) + self.movecmd_stamped.publish(twist_stamped) + + def _run_uvicorn_server(self) -> None: + config = uvicorn.Config( + self.app, # type: ignore[arg-type] + host="0.0.0.0", + port=self.port, + log_level="error", # Reduce verbosity + ) + self._uvicorn_server = uvicorn.Server(config) + self._uvicorn_server.run() + + def _on_robot_pose(self, msg: PoseStamped) -> None: + pose_data = {"type": "vector", "c": [msg.position.x, msg.position.y, msg.position.z]} + self.vis_state["robot_pose"] = pose_data + self._emit("robot_pose", pose_data) + + def _on_gps_location(self, msg: LatLon) -> None: + pose_data = {"lat": msg.lat, "lon": msg.lon} + self.vis_state["gps_location"] = pose_data + self._emit("gps_location", pose_data) + + def _on_path(self, msg: Path) -> None: + points = [[pose.position.x, pose.position.y] for pose in msg.poses] + path_data = {"type": "path", "points": points} + self.vis_state["path"] = path_data + self._emit("path", path_data) + + def _on_global_costmap(self, msg: OccupancyGrid) -> None: + costmap_data = self._process_costmap(msg) + self.vis_state["costmap"] = costmap_data + self._emit("costmap", costmap_data) + + def _process_costmap(self, costmap: OccupancyGrid) -> dict[str, Any]: + """Convert OccupancyGrid to visualization format.""" + costmap = costmap.inflate(0.1).gradient(max_distance=1.0) + grid_data = self.costmap_encoder.encode_costmap(costmap.grid) + + return { + "type": "costmap", + "grid": grid_data, + "origin": { + "type": "vector", + "c": [costmap.origin.position.x, costmap.origin.position.y, 0], + }, + "resolution": costmap.resolution, + "origin_theta": 0, # Assuming no rotation for now + } + + def _emit(self, event: str, data: Any) -> None: + if self._broadcast_loop and not self._broadcast_loop.is_closed(): + asyncio.run_coroutine_threadsafe(self.sio.emit(event, data), self._broadcast_loop) + + +websocket_vis = WebsocketVisModule.blueprint + +__all__ = ["WebsocketVisModule", "websocket_vis"] diff --git a/dist/dimos-0.0.0-py3-none-any.whl b/dist/dimos-0.0.0-py3-none-any.whl deleted file mode 100644 index 9d6535daee..0000000000 Binary files a/dist/dimos-0.0.0-py3-none-any.whl and /dev/null differ diff --git a/dist/dimos-0.0.0.tar.gz b/dist/dimos-0.0.0.tar.gz deleted file mode 100644 index ad6e61e525..0000000000 Binary files a/dist/dimos-0.0.0.tar.gz and /dev/null differ diff --git a/docker/agent/Dockerfile b/docker/agent/Dockerfile deleted file mode 100644 index f91e458a7c..0000000000 --- a/docker/agent/Dockerfile +++ /dev/null @@ -1,22 +0,0 @@ -FROM python:3 - -RUN apt-get update && apt-get install -y \ - libgl1-mesa-glx - -WORKDIR /app - -COPY requirements.txt ./ - -RUN pip install --no-cache-dir -r requirements.txt - -COPY ./dimos ./dimos - -COPY ./tests ./tests - -COPY ./dimos/__init__.py ./ - -# CMD [ "python", "-m", "tests.test_environment" ] - -# CMD [ "python", "-m", "tests.test_openai_agent_v3" ] - -CMD [ "python", "-m", "tests.test_agent" ] diff --git a/docker/agent/docker-compose.yml b/docker/agent/docker-compose.yml deleted file mode 100644 index da79d5a453..0000000000 --- a/docker/agent/docker-compose.yml +++ /dev/null @@ -1,48 +0,0 @@ ---- -services: - dimos: - image: dimos:latest - build: ./../../ - env_file: - - ./../../.env - mem_limit: 8048m - volumes: - - ./../../assets:/app/assets - ports: - - "5555:5555" - # command: [ "python", "-m", "tests.test_agent" ] - # ^^ Working Sanity Test Cases - Expand to Agent Class - # - # command: [ "python", "-m", "tests.types.videostream" ] - # ^^ Working Skeleton - Needs Impl. - # - # command: [ "python", "-m", "tests.types.media_provider" ] - # ^^ Working Instance - Needs Tests. - # - # command: [ "python", "-m", "tests.web.edge_io" ] - # ^^ Working Instance - Needs Tests. - # - command: [ "python", "-m", "tests.agent_manip_flow_test" ] - # ^^ Working Instance - Needs Optical Flow Fix. - - # command: [ "python", "-m", "tests.agent_memory_test" ] - # ^^ WIP - Agent Memory Testing - - # command: ["tail", "-f", "/dev/null"] - stdin_open: true - tty: true - -# ---- -# TO RUN: -# docker build -f ./Dockerfile -t dimos ../../ && docker compose up -# GO TO: -# 127.0.0.1:5555 (when flask server fixed) -# ---- - -# video-service: -# build: ./video-service -# image: video-service:latest -# volumes: -# - ./../../assets:/app/dimos-env/assets -# ports: -# - "23001:23001" diff --git a/docker/deprecated/agent/Dockerfile b/docker/deprecated/agent/Dockerfile new file mode 100644 index 0000000000..a760bc3a6a --- /dev/null +++ b/docker/deprecated/agent/Dockerfile @@ -0,0 +1,40 @@ +FROM python:3 + +# General +# RUN apt-get update && apt-get install -y \ +# libgl1-mesa-glx + +# Unitree Specific +RUN apt-get update && apt-get install -y \ + libgl1-mesa-glx \ + build-essential \ + libavformat-dev \ + libavcodec-dev \ + libavdevice-dev \ + libavutil-dev \ + libswscale-dev \ + libpostproc-dev \ + gcc \ + make \ + portaudio19-dev \ + python3-pyaudio \ + python3-all-dev + +# Change working directory to /app for proper relative pathing +WORKDIR /app + +COPY requirements.txt ./ + +RUN pip install --no-cache-dir -r requirements.txt + +COPY ./dimos ./dimos + +COPY ./tests ./tests + +COPY ./dimos/__init__.py ./ + +# CMD [ "python", "-m", "tests.test_environment" ] + +# CMD [ "python", "-m", "tests.test_openai_agent_v3" ] + +CMD [ "python", "-m", "tests.test_agent" ] diff --git a/docker/deprecated/agent/docker-compose.yml b/docker/deprecated/agent/docker-compose.yml new file mode 100644 index 0000000000..37b24f6abf --- /dev/null +++ b/docker/deprecated/agent/docker-compose.yml @@ -0,0 +1,85 @@ +--- +services: + dimos: + image: dimos:latest + build: + context: ../../ + dockerfile: docker/agent/Dockerfile + env_file: + - ../../.env + mem_limit: 8048m + volumes: + - ../../assets:/app/assets + ports: + - "5555:5555" + environment: + - PYTHONUNBUFFERED=1 + # command: [ "python", "-m", "tests.test_agent" ] + # ^^ Working Sanity Test Cases - Expand to Agent Class + # + # command: [ "python", "-m", "tests.stream.video_operators" ] + # ^^ Working Skeleton - Needs Impl. + # + # command: [ "python", "-m", "tests.stream.video_provider" ] + # ^^ Working Instance - Needs Tests. + # + # command: [ "python", "-m", "tests.web.edge_io" ] + # ^^ Working Instance - Needs Tests. + # + # command: [ "python", "-m", "tests.agent_manip_flow_flask_test" ] + # ^^ Working Instance + + # command: [ "python", "-m", "tests.agent_manip_flow_fastapi_test" ] + # ^^ Working Instance - Needs threading / start / stop functionality bugfix. + + # command: [ "python", "-m", "tests.test_standalone_project_out" ] + # ^^ WIP - Output Function Headers + Descriptions + + # command: [ "python", "-m", "tests.agent_memory_test" ] + # ^^ WIP - Agent Memory Testing + + # command: [ "python", "-m", "tests.test_standalone_fastapi" ] + # ^^ Working, FastAPI Multithreader Standalone + + # command: [ "python", "-m", "tests.test_standalone_rxpy_01" ] + # ^^ Working Instance + + # command: [ "python", "-m", "tests.test_standalone_openai_json" ] + # ^^ Working Instance + + # command: [ "python", "-m", "tests.test_standalone_openai_json_struct" ] + # ^^ Working Instance + + # command: [ "python", "-m", "tests.test_standalone_openai_json_struct_func" ] + # ^^ WIP + + # command: [ "python", "-m", "tests.test_standalone_openai_json_struct_func_playground" ] + # ^^ WIP + + # command: [ "python", "-m", "tests.test_skill_library" ] + # ^^ Working Instance + + # command: [ "python", "-m", "tests.test_video_rtsp" ] + # ^^ WIP + + command: [ "python", "-m", "tests.test_video_agent_threading" ] + # ^^ WIP + + # command: ["tail", "-f", "/dev/null"] + stdin_open: true + tty: true + +# ---- +# TO RUN: +# docker build -f ./Dockerfile -t dimos ../../ && docker compose up +# GO TO: +# 127.0.0.1:5555 (when flask server fixed) +# ---- + +# video-service: +# build: ./video-service +# image: video-service:latest +# volumes: +# - ./../../assets:/app/dimos-env/assets +# ports: +# - "23001:23001" diff --git a/docker/deprecated/interface/Dockerfile b/docker/deprecated/interface/Dockerfile new file mode 100644 index 0000000000..9064f882e9 --- /dev/null +++ b/docker/deprecated/interface/Dockerfile @@ -0,0 +1,6 @@ +FROM node:18-alpine + +WORKDIR /app + +# Start development server with host 0.0.0.0 to allow external connections +CMD ["sh", "-c", "yarn install && yarn dev --host 0.0.0.0"] \ No newline at end of file diff --git a/docker/deprecated/interface/docker-compose.yml b/docker/deprecated/interface/docker-compose.yml new file mode 100644 index 0000000000..6571e92e16 --- /dev/null +++ b/docker/deprecated/interface/docker-compose.yml @@ -0,0 +1,18 @@ +--- +services: + dimos-web-interface: + build: + context: ../../ # Root of the project + dockerfile: docker/interface/Dockerfile + image: dimos-web-interface:latest + container_name: dimos-web-interface + network_mode: "host" + ports: + - "3000:3000" + volumes: + - ../../dimos/web/dimos_interface:/app + healthcheck: + test: ["CMD", "wget", "--spider", "-q", "http://localhost:3000"] + interval: 30s + timeout: 10s + retries: 3 diff --git a/docker/deprecated/simulation/entrypoint.sh b/docker/deprecated/simulation/entrypoint.sh new file mode 100644 index 0000000000..373fa6f05c --- /dev/null +++ b/docker/deprecated/simulation/entrypoint.sh @@ -0,0 +1,5 @@ +#!/bin/bash +export PYTHONPATH="${PYTHONPATH}:/app" +source /opt/ros/humble/setup.bash +#source /home/ros/dev_ws/install/setup.bash +exec "$@" \ No newline at end of file diff --git a/docker/deprecated/simulation/genesis/10_nvidia.json b/docker/deprecated/simulation/genesis/10_nvidia.json new file mode 100644 index 0000000000..2bfcca059e --- /dev/null +++ b/docker/deprecated/simulation/genesis/10_nvidia.json @@ -0,0 +1,6 @@ +{ + "file_format_version" : "1.0.0", + "ICD" : { + "library_path" : "libEGL_nvidia.so.0" + } +} diff --git a/docker/deprecated/simulation/genesis/Dockerfile b/docker/deprecated/simulation/genesis/Dockerfile new file mode 100644 index 0000000000..d22473b7cd --- /dev/null +++ b/docker/deprecated/simulation/genesis/Dockerfile @@ -0,0 +1,131 @@ +# From https://github.com/Genesis-Embodied-AI/Genesis/blob/main/docker/Dockerfile +ARG CUDA_VERSION=12.1 + +# =============================================================== +# Stage 1: Build LuisaRender +# =============================================================== +FROM pytorch/pytorch:2.5.1-cuda${CUDA_VERSION}-cudnn9-devel AS builder + +ENV DEBIAN_FRONTEND=noninteractive +ARG PYTHON_VERSION=3.11 + +# Install necessary packages +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + manpages-dev \ + libvulkan-dev \ + zlib1g-dev \ + xorg-dev libglu1-mesa-dev \ + libsnappy-dev \ + software-properties-common \ + git \ + curl \ + wget +RUN add-apt-repository ppa:ubuntu-toolchain-r/test && \ + apt update && \ + apt install -y --no-install-recommends \ + gcc-11 \ + g++-11 \ + gcc-11 g++-11 patchelf && \ + rm -rf /var/lib/apt/lists/* + +# Set GCC-11 and G++-11 as the default +RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-11 110 && \ + update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-11 110 + +# Install Rust for build requirements +RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y + +RUN pip install "pybind11[global]" + +# Install CMake +RUN wget https://github.com/Kitware/CMake/releases/download/v3.31.0-rc2/cmake-3.31.0-rc2-linux-x86_64.sh && \ + chmod +x cmake-3.31.0-rc2-linux-x86_64.sh && \ + ./cmake-3.31.0-rc2-linux-x86_64.sh --skip-license --prefix=/usr/local && \ + rm cmake-3.31.0-rc2-linux-x86_64.sh + +# Build LuisaRender +WORKDIR /workspace +RUN git clone https://github.com/Genesis-Embodied-AI/Genesis.git && \ + cd Genesis && \ + git submodule update --init --recursive +COPY ./docker/simulation/genesis/build_luisa.sh /workspace/build_luisa.sh +RUN chmod +x ./build_luisa.sh && ./build_luisa.sh ${PYTHON_VERSION} + +# =============================================================== +# Stage 2: Runtime Environment +# =============================================================== +FROM pytorch/pytorch:2.5.1-cuda${CUDA_VERSION}-cudnn9-devel + +ARG PYTHON_VERSION=3.11 +ENV DEBIAN_FRONTEND=noninteractive +ENV NVIDIA_DRIVER_CAPABILITIES=all + +# Install runtime dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + tmux \ + git \ + curl \ + wget \ + bash-completion \ + libgl1 \ + libgl1-mesa-glx \ + libegl-dev \ + libegl1 \ + libxrender1 \ + libglib2.0-0 \ + ffmpeg \ + libgtk2.0-dev \ + pkg-config \ + libvulkan-dev \ + libgles2 \ + libglvnd0 \ + libglx0 \ + && apt clean \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /workspace + +# --------------------------- Genesis ---------------------------- +RUN pip install --no-cache-dir open3d +RUN git clone https://github.com/Genesis-Embodied-AI/Genesis.git && \ + cd Genesis && \ + pip install . && \ + pip install --no-cache-dir PyOpenGL==3.1.5 + +# ------------------------ Motion planning ----------------------- +RUN PYTHON_MAJOR_MINOR=$(echo ${PYTHON_VERSION} | tr -d '.') && \ + wget https://github.com/ompl/ompl/releases/download/prerelease/ompl-1.6.0-cp${PYTHON_MAJOR_MINOR}-cp${PYTHON_MAJOR_MINOR}-manylinux_2_28_x86_64.whl && \ + pip install ompl-1.6.0-cp${PYTHON_MAJOR_MINOR}-cp${PYTHON_MAJOR_MINOR}-manylinux_2_28_x86_64.whl && \ + rm ompl-1.6.0-cp${PYTHON_MAJOR_MINOR}-cp${PYTHON_MAJOR_MINOR}-manylinux_2_28_x86_64.whl + +# -------------------- Surface Reconstruction -------------------- +# Set the LD_LIBRARY_PATH directly in the environment +COPY --from=builder /workspace/Genesis/genesis/ext/ParticleMesher/ParticleMesherPy /opt/conda/lib/python3.1/site-packages/genesis/ext/ParticleMesher/ParticleMesherPy +ENV LD_LIBRARY_PATH=/opt/conda/lib/python3.1/site-packages/genesis/ext/ParticleMesher/ParticleMesherPy:$LD_LIBRARY_PATH + +# --------------------- Ray Tracing Renderer --------------------- +# Copy LuisaRender build artifacts from the builder stage +COPY --from=builder /workspace/Genesis/genesis/ext/LuisaRender/build/bin /opt/conda/lib/python3.1/site-packages/genesis/ext/LuisaRender/build/bin +# fix GLIBCXX_3.4.30 not found +RUN cd /opt/conda/lib && \ + mv libstdc++.so.6 libstdc++.so.6.old && \ + ln -s /usr/lib/x86_64-linux-gnu/libstdc++.so.6 libstdc++.so.6 + +COPY ./docker/simulation/genesis/10_nvidia.json /usr/share/glvnd/egl_vendor.d/10_nvidia.json +COPY ./docker/simulation/genesis/nvidia_icd.json /usr/share/vulkan/icd.d/nvidia_icd.json +COPY ./docker/simulation/genesis/nvidia_layers.json /etc/vulkan/implicit_layer.d/nvidia_layers.json + +# Change working directory to /app for proper relative pathing +WORKDIR /app + +# Copy application code +COPY ./dimos ./dimos +COPY ./tests ./tests +COPY ./assets ./assets +COPY ./dimos/__init__.py ./ +COPY ./docker/simulation/entrypoint.sh / +RUN chmod +x /entrypoint.sh + +ENTRYPOINT ["/entrypoint.sh"] +CMD [ "python3", "/app/tests/genesissim/stream_camera.py" ] diff --git a/docker/deprecated/simulation/genesis/build_luisa.sh b/docker/deprecated/simulation/genesis/build_luisa.sh new file mode 100644 index 0000000000..95d861c57f --- /dev/null +++ b/docker/deprecated/simulation/genesis/build_luisa.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +# Check if Python version is provided +if [ -z "$1" ]; then + echo "Usage: $0 " + exit 1 +fi + +PYTHON_VERSION=$1 + +cd Genesis/genesis/ext/LuisaRender && \ +git submodule update --init --recursive && \ +mkdir -p build && \ +cmake -S . -B build \ + -D CMAKE_BUILD_TYPE=Release \ + -D PYTHON_VERSIONS=$PYTHON_VERSION \ + -D LUISA_COMPUTE_DOWNLOAD_NVCOMP=ON \ + -D LUISA_COMPUTE_DOWNLOAD_OIDN=ON \ + -D LUISA_COMPUTE_ENABLE_GUI=OFF \ + -D LUISA_COMPUTE_ENABLE_CUDA=ON \ + -Dpybind11_DIR=$(python3 -c "import pybind11; print(pybind11.get_cmake_dir())") && \ +cmake --build build -j $(nproc) \ No newline at end of file diff --git a/docker/deprecated/simulation/genesis/docker-compose.yml b/docker/deprecated/simulation/genesis/docker-compose.yml new file mode 100644 index 0000000000..2f1187a9c1 --- /dev/null +++ b/docker/deprecated/simulation/genesis/docker-compose.yml @@ -0,0 +1,38 @@ +--- +services: + dimos_simulator: + image: dimos_simulator_genesis:latest + build: + context: ../../../ + dockerfile: docker/simulation/genesis/Dockerfile + env_file: + - ../../../.env + runtime: nvidia + environment: + - NVIDIA_VISIBLE_DEVICES=all + - NVIDIA_DRIVER_CAPABILITIES=all + - PYTHONUNBUFFERED=1 + - ACCEPT_EULA=Y + - PRIVACY_CONSENT=Y + volumes: + - ./../../../assets:/app/assets + networks: + - rtsp_net + depends_on: + - mediamtx + + mediamtx: + image: bluenviron/mediamtx:latest + networks: + - rtsp_net + ports: + - "8554:8554" + - "1935:1935" + - "8888:8888" + environment: + - MTX_PROTOCOLS=tcp + - MTX_LOG_LEVEL=info + +networks: + rtsp_net: + name: rtsp_net diff --git a/docker/deprecated/simulation/genesis/nvidia_icd.json b/docker/deprecated/simulation/genesis/nvidia_icd.json new file mode 100644 index 0000000000..69600b17ae --- /dev/null +++ b/docker/deprecated/simulation/genesis/nvidia_icd.json @@ -0,0 +1,7 @@ +{ + "file_format_version" : "1.0.0", + "ICD": { + "library_path": "libGLX_nvidia.so.0", + "api_version" : "1.2.155" + } +} diff --git a/docker/deprecated/simulation/genesis/nvidia_layers.json b/docker/deprecated/simulation/genesis/nvidia_layers.json new file mode 100644 index 0000000000..a8e098eb9a --- /dev/null +++ b/docker/deprecated/simulation/genesis/nvidia_layers.json @@ -0,0 +1,22 @@ + +{ + "file_format_version" : "1.0.0", + "layer": { + "name": "VK_LAYER_NV_optimus", + "type": "INSTANCE", + "library_path": "libGLX_nvidia.so.0", + "api_version" : "1.2.155", + "implementation_version" : "1", + "description" : "NVIDIA Optimus layer", + "functions": { + "vkGetInstanceProcAddr": "vk_optimusGetInstanceProcAddr", + "vkGetDeviceProcAddr": "vk_optimusGetDeviceProcAddr" + }, + "enable_environment": { + "__NV_PRIME_RENDER_OFFLOAD": "1" + }, + "disable_environment": { + "DISABLE_LAYER_NV_OPTIMUS_1": "" + } + } +} diff --git a/docker/deprecated/simulation/isaac/Dockerfile b/docker/deprecated/simulation/isaac/Dockerfile new file mode 100644 index 0000000000..a908d5c6e0 --- /dev/null +++ b/docker/deprecated/simulation/isaac/Dockerfile @@ -0,0 +1,190 @@ +FROM nvcr.io/nvidia/isaac-sim:4.2.0 + +# Set up locales +ENV LANG=en_US.UTF-8 +ENV LANGUAGE=en_US:en +ENV LC_ALL=en_US.UTF-8 + +RUN apt-get update && apt-get install -y locales && \ + locale-gen en_US en_US.UTF-8 && \ + update-locale LC_ALL=en_US.UTF-8 LANG=en_US.UTF-8 && \ + rm -rf /var/lib/apt/lists/* + +# Prevent interactive prompts during installation +ENV DEBIAN_FRONTEND=noninteractive + +# Install basic dependencies +RUN apt-get update && apt-get install -y \ + software-properties-common \ + curl \ + git \ + ffmpeg \ + && rm -rf /var/lib/apt/lists/* + +# Set timezone non-interactively +ENV TZ=America/Los_Angeles +RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone + +# Setup ROS 2 +RUN add-apt-repository universe -y \ + && curl -sSL https://raw.githubusercontent.com/ros/rosdistro/master/ros.key -o /usr/share/keyrings/ros-archive-keyring.gpg \ + && echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/ros-archive-keyring.gpg] http://packages.ros.org/ros2/ubuntu $(. /etc/os-release && echo $UBUNTU_CODENAME) main" | tee /etc/apt/sources.list.d/ros2.list > /dev/null \ + && apt-get update \ + && DEBIAN_FRONTEND=noninteractive apt-get install -y \ + && apt-get upgrade -y \ + && apt-get install -y \ + ros-humble-desktop \ + ros-humble-ros-base \ + ros-dev-tools \ + python3-rosdep \ + python3-colcon-common-extensions \ + python3-pip \ + python3.10-venv \ + ament-cmake \ + ros-humble-ament-cmake \ + build-essential \ + cmake \ + build-essential \ + cmake \ + python3-colcon-common-extensions \ + python3-flake8 \ + python3-rosdep \ + python3-setuptools \ + python3-vcstool \ + python3-rosinstall \ + python3-rosinstall-generator \ + python3-wstool \ + nano \ + wget \ + curl \ + vim \ + git \ + x11-apps \ + tmux \ + ros-humble-foxglove-bridge \ + ros-humble-moveit \ + ros-humble-moveit-visual-tools \ + ros-humble-moveit-ros-visualization \ + ros-humble-moveit-servo \ + ros-humble-joint-state-publisher-gui \ + ros-humble-rosbridge-suite \ + ros-humble-xacro \ + ros-humble-robot-state-publisher \ + ros-humble-teleop-twist-keyboard \ + ros-humble-teleop-twist-joy \ + ros-humble-joy \ + ros-humble-controller-manager \ + ros-humble-ros2-control \ + ros-humble-ros2-controllers \ + ros-humble-robot-state-publisher \ + ros-humble-joint-state-publisher \ + ros-humble-joint-trajectory-controller \ + ros-humble-joint-state-broadcaster \ + ros-humble-vision-msgs \ + ros-humble-ackermann-msgs \ + ros-humble-navigation2 \ + ros-humble-nav2-bringup \ + ros-humble-nav2-msgs \ + ros-humble-nav2-common \ + ros-humble-nav2-behavior-tree \ + ros-humble-nav2-costmap-2d \ + ros-humble-nav2-core \ + ros-humble-nav2-bt-navigator \ + ros-humble-pointcloud-to-laserscan \ + iputils-ping \ + net-tools \ + htop \ + python3-pip \ + ros-humble-tf* \ + ros-humble-gazebo-ros-pkgs \ + dos2unix \ + python3-genmsg \ + gpg \ + pass \ + ros-humble-depthai-ros \ + zstd \ + && rm -rf /var/lib/apt/lists/* + +RUN apt-get upgrade -y + + +# Initialize rosdep +RUN rosdep init && rosdep update + +# Setup ROS environment +RUN echo "source /opt/ros/humble/setup.bash" >> ~/.bashrc + +# Install Python packages directly +RUN pip install --no-cache-dir \ + rospkg \ + numpy==1.24.4 \ + jsonpickle \ + scipy \ + easydict \ + matplotlib==3.9.1 \ + opencv-python \ + pyyaml \ + pyquaternion \ + pybullet \ + requests \ + pillow \ + open3d \ + av==10.0.0 \ + transforms3d \ + torch \ + torchvision \ + torchaudio \ + transformers + + +ARG USERNAME=ros +ARG USER_UID=1000 +ARG USER_GID=$USER_UID + +# Create ros home directory +RUN mkdir -p /home/$USERNAME + +RUN cd /home/$USERNAME && git clone https://github.com/isaac-sim/IsaacSim-ros_workspaces.git +RUN rosdep update +RUN /bin/bash -c "cd /home/$USERNAME/IsaacSim-ros_workspaces/humble_ws && rosdep install -i --from-path src --rosdistro humble -y" + +RUN mkdir -p /home/$USERNAME/dev_ws/src +RUN cd /home/$USERNAME/dev_ws/src && git clone https://github.com/yashas-salankimatt/thesis_ros_ws.git + +# Install ZED SDK +RUN wget https://stereolabs.sfo2.cdn.digitaloceanspaces.com/zedsdk/4.2/ZED_SDK_Ubuntu22_cuda12.1_v4.2.1.zstd.run && chmod +x ZED_SDK_Ubuntu22_cuda12.1_v4.2.1.zstd.run +RUN /bin/bash -c "./ZED_SDK_Ubuntu22_cuda12.1_v4.2.1.zstd.run -- silent skip_cuda" + +ENV ZED_SDK_ROOT_DIR=/usr/local/zed +ENV CMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH}:${ZED_SDK_ROOT_DIR} + + +RUN mkdir -p /home/$USERNAME/deps +RUN cd /home/$USERNAME/deps && git clone https://github.com/facebookresearch/segment-anything-2.git +RUN cd /home/$USERNAME/deps/segment-anything-2 && pip install -e . +RUN cd /home/$USERNAME/dev_ws +RUN chown -R $USER_UID:$USER_GID /home/$USERNAME/ + +RUN /bin/bash -c "source /opt/ros/humble/setup.bash && cd /home/$USERNAME/IsaacSim-ros_workspaces/humble_ws && colcon build" +RUN rm -rf /var/lib/apt/lists/* + +ENV CUDA_HOME=/usr/local/lib/python3.10/dist-packages/nvidia/cuda_runtime +ENV CUDA_TOOLKIT_ROOT_DIR=${CUDA_HOME} +ENV PATH=${CUDA_HOME}/bin:${PATH} +ENV LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH} + +# Change working directory to /app for proper relative pathing +WORKDIR /app + +# Copy application code +COPY ./dimos ./dimos +COPY ./tests ./tests +COPY ./assets ./assets +COPY ./dimos/__init__.py ./ +COPY ./docker/simulation/entrypoint.sh / +RUN chmod +x /entrypoint.sh + +ENTRYPOINT ["/entrypoint.sh"] +CMD [ "/isaac-sim/python.sh", "/app/tests/isaacsim/stream_camera.py" ] +# For testing +#CMD ["tail", "-f", "/dev/null"] \ No newline at end of file diff --git a/docker/deprecated/simulation/isaac/docker-compose.yml b/docker/deprecated/simulation/isaac/docker-compose.yml new file mode 100644 index 0000000000..a65040c4e2 --- /dev/null +++ b/docker/deprecated/simulation/isaac/docker-compose.yml @@ -0,0 +1,47 @@ +--- +services: + dimos_simulator: + image: dimos_simulator_isaac:latest + build: + context: ../../../ + dockerfile: docker/simulation/isaac/Dockerfile + env_file: + - ../../../.env + runtime: nvidia + environment: + - NVIDIA_VISIBLE_DEVICES=all + - NVIDIA_DRIVER_CAPABILITIES=all + - PYTHONUNBUFFERED=1 + - ACCEPT_EULA=Y + - PRIVACY_CONSENT=Y + volumes: + - ./../../../assets:/app/assets + # Isaac Sim required volumes + - ~/docker/isaac-sim/cache/kit:/isaac-sim/kit/cache:rw + - ~/docker/isaac-sim/cache/ov:/root/.cache/ov:rw + - ~/docker/isaac-sim/cache/pip:/root/.cache/pip:rw + - ~/docker/isaac-sim/cache/glcache:/root/.cache/nvidia/GLCache:rw + - ~/docker/isaac-sim/cache/computecache:/root/.nv/ComputeCache:rw + - ~/docker/isaac-sim/logs:/root/.nvidia-omniverse/logs:rw + - ~/docker/isaac-sim/data:/root/.local/share/ov/data:rw + - ~/docker/isaac-sim/documents:/root/Documents:rw + networks: + - rtsp_net + depends_on: + - mediamtx + + mediamtx: + image: bluenviron/mediamtx:latest + networks: + - rtsp_net + ports: + - "8554:8554" + - "1935:1935" + - "8888:8888" + environment: + - MTX_PROTOCOLS=tcp + - MTX_LOG_LEVEL=info + +networks: + rtsp_net: + name: rtsp_net diff --git a/docker/deprecated/unitree/agents/Dockerfile b/docker/deprecated/unitree/agents/Dockerfile new file mode 100644 index 0000000000..c46fdd66e6 --- /dev/null +++ b/docker/deprecated/unitree/agents/Dockerfile @@ -0,0 +1,146 @@ +FROM ubuntu:22.04 + +# Avoid prompts from apt +ENV DEBIAN_FRONTEND=noninteractive + +# Set locale +RUN apt-get update && apt-get install -y locales && \ + locale-gen en_US en_US.UTF-8 && \ + update-locale LC_ALL=en_US.UTF-8 LANG=en_US.UTF-8 +ENV LANG=en_US.UTF-8 + +# Set ROS distro +ENV ROS_DISTRO=humble + +# Install basic requirements +RUN apt-get update && apt-get install -y \ + curl \ + gnupg2 \ + lsb-release \ + python3-pip \ + clang \ + portaudio19-dev \ + git \ + mesa-utils \ + libgl1-mesa-glx \ + libgl1-mesa-dri \ + software-properties-common \ + libxcb1-dev \ + libxcb-keysyms1-dev \ + libxcb-util0-dev \ + libxcb-icccm4-dev \ + libxcb-image0-dev \ + libxcb-randr0-dev \ + libxcb-shape0-dev \ + libxcb-xinerama0-dev \ + libxcb-xkb-dev \ + libxkbcommon-x11-dev \ + qtbase5-dev \ + qtchooser \ + qt5-qmake \ + qtbase5-dev-tools \ + supervisor \ + && rm -rf /var/lib/apt/lists/* + +# Install specific numpy version first +RUN pip install 'numpy<2.0.0' + +# Add ROS2 apt repository +RUN curl -sSL https://raw.githubusercontent.com/ros/rosdistro/master/ros.key -o /usr/share/keyrings/ros-archive-keyring.gpg && \ + echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/ros-archive-keyring.gpg] http://packages.ros.org/ros2/ubuntu $(lsb_release -cs) main" | tee /etc/apt/sources.list.d/ros2.list > /dev/null + +# Install ROS2 packages and dependencies +RUN apt-get update && apt-get install -y \ + ros-${ROS_DISTRO}-desktop \ + ros-${ROS_DISTRO}-ros-base \ + ros-${ROS_DISTRO}-image-tools \ + ros-${ROS_DISTRO}-compressed-image-transport \ + ros-${ROS_DISTRO}-vision-msgs \ + ros-${ROS_DISTRO}-rviz2 \ + ros-${ROS_DISTRO}-rqt \ + ros-${ROS_DISTRO}-rqt-common-plugins \ + ros-${ROS_DISTRO}-twist-mux \ + ros-${ROS_DISTRO}-joy \ + ros-${ROS_DISTRO}-teleop-twist-joy \ + ros-${ROS_DISTRO}-navigation2 \ + ros-${ROS_DISTRO}-nav2-bringup \ + ros-${ROS_DISTRO}-nav2-amcl \ + ros-${ROS_DISTRO}-nav2-map-server \ + ros-${ROS_DISTRO}-nav2-util \ + ros-${ROS_DISTRO}-pointcloud-to-laserscan \ + ros-${ROS_DISTRO}-slam-toolbox \ + ros-${ROS_DISTRO}-foxglove-bridge \ + python3-rosdep \ + python3-rosinstall \ + python3-rosinstall-generator \ + python3-wstool \ + python3-colcon-common-extensions \ + python3-vcstool \ + build-essential \ + screen \ + tmux \ + && rm -rf /var/lib/apt/lists/* + +# Initialize rosdep +RUN rosdep init && rosdep update + +# Create workspace +WORKDIR /ros2_ws + +# Clone the repository with submodules +RUN git clone --recurse-submodules https://github.com/dimensionalOS/go2_ros2_sdk src + +# Install Python requirements +RUN cd src && pip install -r requirements.txt + +# Create dimos directory structure +RUN mkdir -p /app/dimos /app/docker + +COPY requirements.txt /app/ + +WORKDIR /app + +# Install dimos requirements +RUN pip install --no-cache-dir -r requirements.txt + +# Set PYTHONPATH permanently +ENV PYTHONPATH=/app:${PYTHONPATH} + +# Install ROS dependencies +WORKDIR /ros2_ws +RUN . /opt/ros/${ROS_DISTRO}/setup.sh && \ + rosdep install --from-paths src --ignore-src -r -y + +# Build the workspace +RUN . /opt/ros/${ROS_DISTRO}/setup.sh && \ + colcon build + +# Source ROS2 and workspace in bashrc +RUN echo "source /opt/ros/${ROS_DISTRO}/setup.bash" >> /root/.bashrc && \ + echo "source /ros2_ws/install/setup.bash" >> /root/.bashrc + +COPY docker /app/docker/ + +# Setup supervisor configuration +COPY docker/unitree/agents/supervisord.conf /etc/supervisor/conf.d/supervisord.conf + +# Copy entrypoint script +COPY docker/unitree/agents/entrypoint.sh /entrypoint.sh +RUN chmod +x /entrypoint.sh + +# Copy dimos and tests +COPY dimos /app/dimos/ +COPY tests /app/tests +COPY dimos/__init__.py /app/__init__.py + +# Change working directory to /app for proper relative pathing +WORKDIR /app + +# Create output directories for supervisord and ROS +RUN mkdir -p /app/assets/output/ +RUN mkdir -p /app/assets/output/ros + +# TODO: Cleanup multiple working directories and seprate the dockerfiles for each service. + +ENTRYPOINT ["/entrypoint.sh"] +CMD ["/usr/bin/supervisord", "-n", "-c", "/etc/supervisor/conf.d/supervisord.conf"] diff --git a/docker/deprecated/unitree/agents/docker-compose.yml b/docker/deprecated/unitree/agents/docker-compose.yml new file mode 100644 index 0000000000..6cde23e98e --- /dev/null +++ b/docker/deprecated/unitree/agents/docker-compose.yml @@ -0,0 +1,27 @@ +--- +services: + dimos-unitree-agents: + image: dimos-unitree-agents:latest + build: + context: ../../../ + dockerfile: docker/unitree/agents/Dockerfile + env_file: + - ../../../.env + environment: + PYTHONUNBUFFERED: 1 + ROBOT_IP: ${ROBOT_IP} + CONN_TYPE: ${CONN_TYPE:-webrtc} + WEBRTC_SERVER_HOST: 0.0.0.0 # Listen on all interfaces + WEBRTC_SERVER_PORT: ${WEBRTC_SERVER_PORT:-9991} + DISPLAY: ${DISPLAY:-} # For GUI applications like rviz2 + ROS_OUTPUT_DIR: /app/assets/output/ros # Change output directory + # DIMOS_MAX_WORKERS: ${DIMOS_MAX_WORKERS} + # TODO: ipc: host + volumes: + - ../../../assets:/app/assets + ports: + - "5555:5555" + mem_limit: 8048m + stdin_open: true + tty: true + diff --git a/docker/deprecated/unitree/agents/entrypoint.sh b/docker/deprecated/unitree/agents/entrypoint.sh new file mode 100755 index 0000000000..7a8ddcae6a --- /dev/null +++ b/docker/deprecated/unitree/agents/entrypoint.sh @@ -0,0 +1,17 @@ +#!/bin/bash +set -e + +# Create supervisor log directory +mkdir -p /app/assets/output + +# Delete old logs +echo "Cleaning up old Supervisor logs..." +rm -f /app/assets/output/*.log + +# Source ROS2 environment +source /opt/ros/${ROS_DISTRO}/setup.bash +source /ros2_ws/install/setup.bash + +# Execute the command passed to docker run +exec "$@" +# python3 -m tests.test_unitree_agent diff --git a/docker/deprecated/unitree/agents/supervisord.conf b/docker/deprecated/unitree/agents/supervisord.conf new file mode 100644 index 0000000000..b66be13e30 --- /dev/null +++ b/docker/deprecated/unitree/agents/supervisord.conf @@ -0,0 +1,35 @@ +[supervisord] +nodaemon=true +logfile=/var/log/supervisor/supervisord.log +pidfile=/var/run/supervisord.pid + +[program:ros2] +command=/bin/bash -c "source /opt/ros/humble/setup.bash && source /ros2_ws/install/setup.bash && ros2 launch go2_robot_sdk robot.launch.py" +autostart=true +autorestart=true + +stderr_logfile=/app/assets/output/ros2.err.log +stdout_logfile=/app/assets/output/ros2.out.log +environment=PYTHONUNBUFFERED=1 + +[program:dimos] +command=/bin/bash -c "sleep 10 && source /opt/ros/humble/setup.bash && source /ros2_ws/install/setup.bash && python3 /app/tests/test_planning_agent_web_interface.py" +autostart=true +autorestart=true +startsecs=11 + +stdout_logfile=/dev/stdout +stdout_logfile_maxbytes=0 +stderr_logfile=/dev/stderr +stderr_logfile_maxbytes=0 +environment=PYTHONUNBUFFERED=1 + +[unix_http_server] +file=/var/run/supervisor.sock +chmod=0700 + +[rpcinterface:supervisor] +supervisor.rpcinterface_factory = supervisor.rpcinterface:make_main_rpcinterface + +[supervisorctl] +serverurl=unix:///var/run/supervisor.sock \ No newline at end of file diff --git a/docker/deprecated/unitree/agents_interface/Dockerfile b/docker/deprecated/unitree/agents_interface/Dockerfile new file mode 100644 index 0000000000..3bc00d2a16 --- /dev/null +++ b/docker/deprecated/unitree/agents_interface/Dockerfile @@ -0,0 +1,151 @@ +FROM ubuntu:22.04 + +# Avoid prompts from apt +ENV DEBIAN_FRONTEND=noninteractive + +# Set locale +RUN apt-get update && apt-get install -y locales && \ + locale-gen en_US en_US.UTF-8 && \ + update-locale LC_ALL=en_US.UTF-8 LANG=en_US.UTF-8 +ENV LANG=en_US.UTF-8 + +# Set ROS distro +ENV ROS_DISTRO=humble + +# Install basic requirements +RUN apt-get update && apt-get install -y \ + curl \ + gnupg2 \ + lsb-release \ + python3-pip \ + clang \ + portaudio19-dev \ + git \ + mesa-utils \ + libgl1-mesa-glx \ + libgl1-mesa-dri \ + software-properties-common \ + libxcb1-dev \ + libxcb-keysyms1-dev \ + libxcb-util0-dev \ + libxcb-icccm4-dev \ + libxcb-image0-dev \ + libxcb-randr0-dev \ + libxcb-shape0-dev \ + libxcb-xinerama0-dev \ + libxcb-xkb-dev \ + libxkbcommon-x11-dev \ + qtbase5-dev \ + qtchooser \ + qt5-qmake \ + qtbase5-dev-tools \ + supervisor \ + && rm -rf /var/lib/apt/lists/* + +# Install specific numpy version first +RUN pip install 'numpy<2.0.0' + +# Add ROS2 apt repository +RUN curl -sSL https://raw.githubusercontent.com/ros/rosdistro/master/ros.key -o /usr/share/keyrings/ros-archive-keyring.gpg && \ + echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/ros-archive-keyring.gpg] http://packages.ros.org/ros2/ubuntu $(lsb_release -cs) main" | tee /etc/apt/sources.list.d/ros2.list > /dev/null + +# Install ROS2 packages and dependencies +RUN apt-get update && apt-get install -y \ + ros-${ROS_DISTRO}-desktop \ + ros-${ROS_DISTRO}-ros-base \ + ros-${ROS_DISTRO}-image-tools \ + ros-${ROS_DISTRO}-compressed-image-transport \ + ros-${ROS_DISTRO}-vision-msgs \ + ros-${ROS_DISTRO}-rviz2 \ + ros-${ROS_DISTRO}-rqt \ + ros-${ROS_DISTRO}-rqt-common-plugins \ + ros-${ROS_DISTRO}-twist-mux \ + ros-${ROS_DISTRO}-joy \ + ros-${ROS_DISTRO}-teleop-twist-joy \ + ros-${ROS_DISTRO}-navigation2 \ + ros-${ROS_DISTRO}-nav2-bringup \ + ros-${ROS_DISTRO}-nav2-amcl \ + ros-${ROS_DISTRO}-nav2-map-server \ + ros-${ROS_DISTRO}-nav2-util \ + ros-${ROS_DISTRO}-pointcloud-to-laserscan \ + ros-${ROS_DISTRO}-slam-toolbox \ + ros-${ROS_DISTRO}-foxglove-bridge \ + python3-rosdep \ + python3-rosinstall \ + python3-rosinstall-generator \ + python3-wstool \ + python3-colcon-common-extensions \ + python3-vcstool \ + build-essential \ + screen \ + tmux \ + && rm -rf /var/lib/apt/lists/* + +# Initialize rosdep +RUN rosdep init && rosdep update + +# Create workspace +WORKDIR /ros2_ws + +# Clone the repository with submodules +RUN git clone --recurse-submodules https://github.com/dimensionalOS/go2_ros2_sdk src + +# Install Python requirements +RUN cd src && pip install -r requirements.txt + +# Create dimos directory structure +RUN mkdir -p /app/dimos /app/docker + +COPY requirements.txt /app/ + +COPY base-requirements.txt /app/ + +WORKDIR /app + +# Install torch and torchvision first due to builds in requirements.txt +RUN pip install --no-cache-dir -r base-requirements.txt + +# Install dimos requirements +RUN pip install --no-cache-dir -r requirements.txt + +# Set PYTHONPATH permanently +ENV PYTHONPATH=/app:${PYTHONPATH} + +# Install ROS dependencies +WORKDIR /ros2_ws +RUN . /opt/ros/${ROS_DISTRO}/setup.sh && \ + rosdep install --from-paths src --ignore-src -r -y + +# Build the workspace +RUN . /opt/ros/${ROS_DISTRO}/setup.sh && \ + colcon build + +# Source ROS2 and workspace in bashrc +RUN echo "source /opt/ros/${ROS_DISTRO}/setup.bash" >> /root/.bashrc && \ + echo "source /ros2_ws/install/setup.bash" >> /root/.bashrc + +COPY docker /app/docker/ + +# Setup supervisor configuration +COPY docker/unitree/agents_interface/supervisord.conf /etc/supervisor/conf.d/supervisord.conf + +# Copy entrypoint script +COPY docker/unitree/agents_interface/entrypoint.sh /entrypoint.sh +RUN chmod +x /entrypoint.sh + +# Copy dimos and tests +COPY dimos /app/dimos/ +COPY tests /app/tests +COPY dimos/__init__.py /app/__init__.py + +# Change working directory to /app for proper relative pathing +WORKDIR /app + +# Create output directories for supervisord and ROS +RUN mkdir -p /app/assets/output/ +RUN mkdir -p /app/assets/output/ros + +# TODO: Cleanup multiple working directories and seprate the dockerfiles for each service. + +ENTRYPOINT ["/entrypoint.sh"] +CMD ["/usr/bin/supervisord", "-n", "-c", "/etc/supervisor/conf.d/supervisord.conf"] diff --git a/docker/deprecated/unitree/agents_interface/docker-compose.yml b/docker/deprecated/unitree/agents_interface/docker-compose.yml new file mode 100644 index 0000000000..62b59d24ba --- /dev/null +++ b/docker/deprecated/unitree/agents_interface/docker-compose.yml @@ -0,0 +1,43 @@ +--- +services: + dimos-unitree-agents-interface: + image: dimos-unitree-agents-interface:latest + build: + context: ../../../ + dockerfile: docker/unitree/agents_interface/Dockerfile + env_file: + - ../../../.env + environment: + - PYTHONUNBUFFERED=1 + - ROS_OUTPUT_DIR=/app/assets/output/ros # Change output directory + - NVIDIA_VISIBLE_DEVICES=all + - DISPLAY=$DISPLAY + # DIMOS_MAX_WORKERS: ${DIMOS_MAX_WORKERS} + # TODO: ipc: host + volumes: + - ../../../assets:/app/assets + - /tmp/.X11-unix:/tmp/.X11-unix + - ~/.Xauthority:/root/.Xauthority:ro + # Persist model caches in host filesystem + - ../../../assets/model-cache/torch-hub:/root/.cache/torch/hub + - ../../../assets/model-cache/iopath-cache:/root/.torch/iopath_cache + - ../../../assets/model-cache/ultralytics:/root/.config/Ultralytics + network_mode: "host" + ports: + - "5555:5555" + mem_limit: 8048m + runtime: nvidia + stdin_open: true + tty: true + + dimos-web-interface: + build: + context: ../../../ + dockerfile: docker/interface/Dockerfile + image: dimos-web-interface:latest + container_name: dimos-web-interface + network_mode: "host" + volumes: + - ../../../dimos/web/dimos_interface:/app + depends_on: + - dimos-unitree-agents-interface \ No newline at end of file diff --git a/docker/deprecated/unitree/agents_interface/entrypoint.sh b/docker/deprecated/unitree/agents_interface/entrypoint.sh new file mode 100755 index 0000000000..7a8ddcae6a --- /dev/null +++ b/docker/deprecated/unitree/agents_interface/entrypoint.sh @@ -0,0 +1,17 @@ +#!/bin/bash +set -e + +# Create supervisor log directory +mkdir -p /app/assets/output + +# Delete old logs +echo "Cleaning up old Supervisor logs..." +rm -f /app/assets/output/*.log + +# Source ROS2 environment +source /opt/ros/${ROS_DISTRO}/setup.bash +source /ros2_ws/install/setup.bash + +# Execute the command passed to docker run +exec "$@" +# python3 -m tests.test_unitree_agent diff --git a/docker/deprecated/unitree/agents_interface/supervisord.conf b/docker/deprecated/unitree/agents_interface/supervisord.conf new file mode 100644 index 0000000000..b03b614fcd --- /dev/null +++ b/docker/deprecated/unitree/agents_interface/supervisord.conf @@ -0,0 +1,35 @@ +[supervisord] +nodaemon=true +logfile=/var/log/supervisor/supervisord.log +pidfile=/var/run/supervisord.pid + +[program:ros2] +command=/bin/bash -c "source /opt/ros/humble/setup.bash && source /ros2_ws/install/setup.bash && ros2 launch go2_robot_sdk robot.launch.py" +autostart=true +autorestart=true + +stderr_logfile=/app/assets/output/ros2.err.log +stdout_logfile=/app/assets/output/ros2.out.log +environment=PYTHONUNBUFFERED=1 + +[program:dimos] +command=/bin/bash -c "sleep 10 && source /opt/ros/humble/setup.bash && source /ros2_ws/install/setup.bash && python3 /app/tests/run.py --new-memory" +autostart=true +autorestart=true +startsecs=11 + +stdout_logfile=/dev/stdout +stdout_logfile_maxbytes=0 +stderr_logfile=/dev/stderr +stderr_logfile_maxbytes=0 +environment=PYTHONUNBUFFERED=1 + +[unix_http_server] +file=/var/run/supervisor.sock +chmod=0700 + +[rpcinterface:supervisor] +supervisor.rpcinterface_factory = supervisor.rpcinterface:make_main_rpcinterface + +[supervisorctl] +serverurl=unix:///var/run/supervisor.sock \ No newline at end of file diff --git a/docker/deprecated/unitree/ros/Dockerfile b/docker/deprecated/unitree/ros/Dockerfile new file mode 100644 index 0000000000..6d495a5065 --- /dev/null +++ b/docker/deprecated/unitree/ros/Dockerfile @@ -0,0 +1,116 @@ +FROM ubuntu:22.04 + +# Avoid prompts from apt +ENV DEBIAN_FRONTEND=noninteractive + +# Set locale +RUN apt-get update && apt-get install -y locales && \ + locale-gen en_US en_US.UTF-8 && \ + update-locale LC_ALL=en_US.UTF-8 LANG=en_US.UTF-8 +ENV LANG=en_US.UTF-8 + +# Set ROS distro +ENV ROS_DISTRO=humble + +# Install basic requirements +RUN apt-get update && apt-get install -y \ + curl \ + gnupg2 \ + lsb-release \ + python3-pip \ + clang \ + portaudio19-dev \ + git \ + mesa-utils \ + libgl1-mesa-glx \ + libgl1-mesa-dri \ + software-properties-common \ + libxcb1-dev \ + libxcb-keysyms1-dev \ + libxcb-util0-dev \ + libxcb-icccm4-dev \ + libxcb-image0-dev \ + libxcb-randr0-dev \ + libxcb-shape0-dev \ + libxcb-xinerama0-dev \ + libxcb-xkb-dev \ + libxkbcommon-x11-dev \ + qtbase5-dev \ + qtchooser \ + qt5-qmake \ + qtbase5-dev-tools \ + && rm -rf /var/lib/apt/lists/* + +# Install specific numpy version first +RUN pip install 'numpy<2.0.0' + +# Add ROS2 apt repository +RUN curl -sSL https://raw.githubusercontent.com/ros/rosdistro/master/ros.key -o /usr/share/keyrings/ros-archive-keyring.gpg && \ + echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/ros-archive-keyring.gpg] http://packages.ros.org/ros2/ubuntu $(lsb_release -cs) main" | tee /etc/apt/sources.list.d/ros2.list > /dev/null + +# Install ROS2 packages and dependencies +RUN apt-get update && apt-get install -y \ + ros-${ROS_DISTRO}-desktop \ + ros-${ROS_DISTRO}-ros-base \ + ros-${ROS_DISTRO}-image-tools \ + ros-${ROS_DISTRO}-compressed-image-transport \ + ros-${ROS_DISTRO}-vision-msgs \ + ros-${ROS_DISTRO}-rviz2 \ + ros-${ROS_DISTRO}-rqt \ + ros-${ROS_DISTRO}-rqt-common-plugins \ + ros-${ROS_DISTRO}-twist-mux \ + ros-${ROS_DISTRO}-joy \ + ros-${ROS_DISTRO}-teleop-twist-joy \ + ros-${ROS_DISTRO}-navigation2 \ + ros-${ROS_DISTRO}-nav2-bringup \ + ros-${ROS_DISTRO}-nav2-amcl \ + ros-${ROS_DISTRO}-nav2-map-server \ + ros-${ROS_DISTRO}-nav2-util \ + ros-${ROS_DISTRO}-pointcloud-to-laserscan \ + ros-${ROS_DISTRO}-slam-toolbox \ + ros-${ROS_DISTRO}-foxglove-bridge \ + python3-rosdep \ + python3-rosinstall \ + python3-rosinstall-generator \ + python3-wstool \ + python3-colcon-common-extensions \ + python3-vcstool \ + build-essential \ + && rm -rf /var/lib/apt/lists/* + +# Initialize rosdep +RUN rosdep init && rosdep update + +# Create workspace +WORKDIR /ros2_ws + +# Clone the repository with submodules +RUN git clone --recurse-submodules https://github.com/dimensionalOS/go2_ros2_sdk src + +# Install Python requirements (with numpy constraint) +RUN cd src && pip install -r requirements.txt + +# Install ROS dependencies +RUN . /opt/ros/${ROS_DISTRO}/setup.sh && \ + rosdep install --from-paths src --ignore-src -r -y + +# Build the workspace +RUN . /opt/ros/${ROS_DISTRO}/setup.sh && \ + colcon build + +# Source ROS2 and workspace in bashrc +RUN echo "source /opt/ros/${ROS_DISTRO}/setup.bash" >> /root/.bashrc && \ + echo "source /ros2_ws/install/setup.bash" >> /root/.bashrc + +# Set environment variables +ENV ROBOT_IP="" +ENV CONN_TYPE="webrtc" +ENV WEBRTC_SERVER_HOST="0.0.0.0" +ENV WEBRTC_SERVER_PORT="9991" + +# Copy entrypoint script +COPY entrypoint.sh /entrypoint.sh +RUN chmod +x /entrypoint.sh + +ENTRYPOINT ["/entrypoint.sh"] +CMD ["ros2", "launch", "go2_robot_sdk", "robot.launch.py"] diff --git a/docker/deprecated/unitree/ros/README.md b/docker/deprecated/unitree/ros/README.md new file mode 100644 index 0000000000..3b6deff3ad --- /dev/null +++ b/docker/deprecated/unitree/ros/README.md @@ -0,0 +1,69 @@ +# Unitree Go2 ROS Docker Setup + +This README explains how to run the Unitree Go2 ROS nodes using Docker. + +## Prerequisites + +- Docker and Docker Compose installed +- A Unitree Go2 robot accessible on your network +- The robot's IP address + +## Configuration + +The connection can be configured through environment variables in two ways: + +1. Setting them before running docker-compose: + ```bash + export ROBOT_IP=192.168.9.140 + export CONN_TYPE=webrtc # or cyclonedds + ``` + +2. Hardcoding them directly in `docker/docker-compose.yaml` + +## Usage + +To run the ROS nodes: + +1. Navigate to the docker directory: + ```bash + cd docker/unitree/ros + ``` + +2. Run with environment variables: + ```bash + xhost +local:root # If running locally and desire RVIZ GUI + ROBOT_IP= CONN_TYPE= docker-compose up --build + ``` + + Where: + - `` is your Go2's IP address + - `` choose either: + - `webrtc`: For WebRTC video streaming connection + - `cyclonedds`: For DDS communication + +The containers will build and start, establishing connection with your Go2 robot and opening RVIZ. + + +## Known Issues + +1. If you encounter the error `unitree_ros-1 | exec /entrypoint.sh: no such file or directory`, this can be caused by: + - Incorrect file permissions + - Windows-style line endings (CRLF) in the entrypoint script + + To fix: + 1. Ensure the entrypoint script has execute permissions: + ```bash + chmod +x entrypoint.sh + ``` + + 2. If using Windows, convert line endings to Unix format (LF): + ```bash + # Using dos2unix + dos2unix entrypoint.sh + + # Or using sed + sed -i 's/\r$//' entrypoint.sh + ``` + + + diff --git a/docker/deprecated/unitree/ros/docker-compose.yml b/docker/deprecated/unitree/ros/docker-compose.yml new file mode 100644 index 0000000000..a16aaff4c9 --- /dev/null +++ b/docker/deprecated/unitree/ros/docker-compose.yml @@ -0,0 +1,22 @@ +--- +services: + unitree_ros: + image: unitree_ros:latest + build: + context: ../../../ + dockerfile: docker/unitree/ros/Dockerfile + environment: + - PYTHONUNBUFFERED=1 + - ROBOT_IP=${ROBOT_IP} + - CONN_TYPE=${CONN_TYPE:-webrtc} + - WEBRTC_SERVER_HOST=0.0.0.0 # Listen on all interfaces + - WEBRTC_SERVER_PORT=${WEBRTC_SERVER_PORT:-9991} + - DISPLAY=${DISPLAY:-} # For GUI applications like rviz2 + volumes: + - /tmp/.X11-unix:/tmp/.X11-unix # X11 forwarding + - ${HOME}/.Xauthority:/root/.Xauthority:rw + network_mode: "host" # Required for ROS2 discovery and robot communication + privileged: true # Required for hardware access + devices: + - /dev/input:/dev/input # For joystick access + restart: unless-stopped diff --git a/docker/deprecated/unitree/ros/entrypoint.sh b/docker/deprecated/unitree/ros/entrypoint.sh new file mode 100755 index 0000000000..dcdc8660c4 --- /dev/null +++ b/docker/deprecated/unitree/ros/entrypoint.sh @@ -0,0 +1,7 @@ +#!/bin/bash +set -e +# Source ROS2 environment +source /opt/ros/${ROS_DISTRO}/setup.bash +source /ros2_ws/install/setup.bash +# Execute the command passed to docker run +exec "$@" diff --git a/docker/deprecated/unitree/ros_agents/docker-compose.yml b/docker/deprecated/unitree/ros_agents/docker-compose.yml new file mode 100644 index 0000000000..6d93ea89ab --- /dev/null +++ b/docker/deprecated/unitree/ros_agents/docker-compose.yml @@ -0,0 +1,67 @@ +--- +services: + dimos-unitree-ros-agents: + image: dimos-unitree-ros-agents:latest + build: + context: ../../../ + dockerfile: docker/unitree/ros_agents/Dockerfile + env_file: + - ../../../.env + environment: + PYTHONUNBUFFERED: 1 + ROBOT_IP: ${ROBOT_IP} + CONN_TYPE: ${CONN_TYPE:-webrtc} + WEBRTC_SERVER_HOST: 0.0.0.0 # Listen on all interfaces + WEBRTC_SERVER_PORT: ${WEBRTC_SERVER_PORT:-9991} + DISPLAY: ${DISPLAY:-} # For GUI applications like rviz2 + ROS_OUTPUT_DIR: /app/assets/output/ros # Change output directory + # DIMOS_MAX_WORKERS: ${DIMOS_MAX_WORKERS} + # TODO: ipc: host + volumes: + - ../../../assets:/app/assets + network_mode: "host" + ports: + - "5555:5555" + mem_limit: 8048m + stdin_open: true + tty: true + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:5555/unitree/status"] + interval: 10s + timeout: 5s + retries: 5 + start_period: 10s + + dimos-web-interface: + build: + context: ../../../ + dockerfile: docker/interface/Dockerfile + image: dimos-web-interface:latest + container_name: dimos-web-interface + network_mode: "host" + volumes: + - ../../../dimos/web/dimos_interface:/app + depends_on: + dimos-unitree-ros-agents: + condition: service_healthy + healthcheck: + test: ["CMD", "wget", "--spider", "-q", "http://localhost:3000"] + interval: 30s + timeout: 10s + retries: 3 + + +# ---- +# TO RUN: +# docker build -f ./Dockerfile -t dimos ../../ && docker compose up +# GO TO: +# 127.0.0.1:5555 (when flask server fixed) +# ---- + +# video-service: +# build: ./video-service +# image: video-service:latest +# volumes: +# - ./../../assets:/app/dimos-env/assets +# ports: +# - "23001:23001" diff --git a/docker/deprecated/unitree/ros_dimos/Dockerfile b/docker/deprecated/unitree/ros_dimos/Dockerfile new file mode 100644 index 0000000000..3c712a3578 --- /dev/null +++ b/docker/deprecated/unitree/ros_dimos/Dockerfile @@ -0,0 +1,148 @@ +FROM ubuntu:22.04 + +# Avoid prompts from apt +ENV DEBIAN_FRONTEND=noninteractive + +# Set locale +RUN apt-get update && apt-get install -y locales && \ + locale-gen en_US en_US.UTF-8 && \ + update-locale LC_ALL=en_US.UTF-8 LANG=en_US.UTF-8 +ENV LANG=en_US.UTF-8 + +# Set ROS distro +ENV ROS_DISTRO=humble + +# Install basic requirements +RUN apt-get update && apt-get install -y \ + curl \ + gnupg2 \ + lsb-release \ + python3-pip \ + clang \ + portaudio19-dev \ + git \ + mesa-utils \ + libgl1-mesa-glx \ + libgl1-mesa-dri \ + software-properties-common \ + libxcb1-dev \ + libxcb-keysyms1-dev \ + libxcb-util0-dev \ + libxcb-icccm4-dev \ + libxcb-image0-dev \ + libxcb-randr0-dev \ + libxcb-shape0-dev \ + libxcb-xinerama0-dev \ + libxcb-xkb-dev \ + libxkbcommon-x11-dev \ + qtbase5-dev \ + qtchooser \ + qt5-qmake \ + qtbase5-dev-tools \ + supervisor \ + && rm -rf /var/lib/apt/lists/* + +# Install specific numpy version first +RUN pip install 'numpy<2.0.0' + +# Add ROS2 apt repository +RUN curl -sSL https://raw.githubusercontent.com/ros/rosdistro/master/ros.key -o /usr/share/keyrings/ros-archive-keyring.gpg && \ + echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/ros-archive-keyring.gpg] http://packages.ros.org/ros2/ubuntu $(lsb_release -cs) main" | tee /etc/apt/sources.list.d/ros2.list > /dev/null + +# Install ROS2 packages and dependencies +RUN apt-get update && apt-get install -y \ + ros-${ROS_DISTRO}-desktop \ + ros-${ROS_DISTRO}-ros-base \ + ros-${ROS_DISTRO}-image-tools \ + ros-${ROS_DISTRO}-compressed-image-transport \ + ros-${ROS_DISTRO}-vision-msgs \ + ros-${ROS_DISTRO}-rviz2 \ + ros-${ROS_DISTRO}-rqt \ + ros-${ROS_DISTRO}-rqt-common-plugins \ + ros-${ROS_DISTRO}-twist-mux \ + ros-${ROS_DISTRO}-joy \ + ros-${ROS_DISTRO}-teleop-twist-joy \ + ros-${ROS_DISTRO}-navigation2 \ + ros-${ROS_DISTRO}-nav2-bringup \ + ros-${ROS_DISTRO}-nav2-amcl \ + ros-${ROS_DISTRO}-nav2-map-server \ + ros-${ROS_DISTRO}-nav2-util \ + ros-${ROS_DISTRO}-pointcloud-to-laserscan \ + ros-${ROS_DISTRO}-slam-toolbox \ + ros-${ROS_DISTRO}-foxglove-bridge \ + python3-rosdep \ + python3-rosinstall \ + python3-rosinstall-generator \ + python3-wstool \ + python3-colcon-common-extensions \ + python3-vcstool \ + build-essential \ + && rm -rf /var/lib/apt/lists/* + +# Initialize rosdep +RUN rosdep init && rosdep update + +# Create workspace +WORKDIR /ros2_ws + +# Clone the repository with submodules +RUN git clone --recurse-submodules https://github.com/dimensionalOS/go2_ros2_sdk src + +# Install Python requirements +RUN cd src && pip install -r requirements.txt + +# Create dimos directory structure +RUN mkdir -p /app/dimos /app/docker + +COPY requirements.txt /app/ + +WORKDIR /app + +# Install dimos requirements +RUN pip install --no-cache-dir -r requirements.txt + +# Set PYTHONPATH permanently +ENV PYTHONPATH=/app:${PYTHONPATH} + +# Install ROS dependencies +WORKDIR /ros2_ws +RUN . /opt/ros/${ROS_DISTRO}/setup.sh && \ + rosdep install --from-paths src --ignore-src -r -y + +# Build the workspace +RUN . /opt/ros/${ROS_DISTRO}/setup.sh && \ + colcon build + +# Source ROS2 and workspace in bashrc +RUN echo "source /opt/ros/${ROS_DISTRO}/setup.bash" >> /root/.bashrc && \ + echo "source /ros2_ws/install/setup.bash" >> /root/.bashrc + +# Set environment variables +# webrtc or cyclonedds +ENV CONN_TYPE="webrtc" +ENV WEBRTC_SERVER_HOST="0.0.0.0" +ENV WEBRTC_SERVER_PORT="9991" + +COPY docker /app/docker/ + +# Setup supervisor configuration +COPY docker/unitree/ros_dimos/supervisord.conf /etc/supervisor/conf.d/supervisord.conf + +# Copy entrypoint script +COPY docker/unitree/ros_dimos/entrypoint.sh /entrypoint.sh +RUN chmod +x /entrypoint.sh + +COPY dimos /app/dimos/ +COPY tests /app/tests/ + +# Change working directory to /app for proper relative pathing +WORKDIR /app + +# Create output directories for supervisord and ROS +RUN mkdir -p /app/assets/output/ +RUN mkdir -p /app/assets/output/ros + +# TODO: Cleanup multiple working directories and seprate the dockerfiles for each service. + +ENTRYPOINT ["/entrypoint.sh"] +CMD ["/usr/bin/supervisord", "-n", "-c", "/etc/supervisor/conf.d/supervisord.conf"] diff --git a/docker/deprecated/unitree/ros_dimos/README.md b/docker/deprecated/unitree/ros_dimos/README.md new file mode 100644 index 0000000000..4c63aaddb2 --- /dev/null +++ b/docker/deprecated/unitree/ros_dimos/README.md @@ -0,0 +1,165 @@ +# Unitree Go2 ROS + DIMOS Movement Agents Docker Setup + +This README explains how to run the Unitree Go2 ROS nodes with DIMOS integration using Docker. + +## Prerequisites + +- Docker and Docker Compose installed +- A Unitree Go2 robot accessible on your network +- The robot's IP address +- Python requirements installed (see root directory's requirements.txt) + +## Configuration + +1. Set environment variables in .env: + ```bash + ROBOT_IP= + CONN_TYPE=webrtc + WEBRTC_SERVER_HOST=0.0.0.0 + WEBRTC_SERVER_PORT=9991 + DISPLAY=:0 + ROS_OUTPUT_DIR=/app/assets/output/ros + ``` + +2. Or run with environment variables in command line docker-compose: + ```bash + ROBOT_IP=192.168.9.140 CONN_TYPE=webrtc docker compose -f docker/unitree/ros_dimos/docker-compose.yml up --build + ``` + +## Usage + +To run the ROS nodes with DIMOS: + +```bash +xhost +local:root # If running locally and desire RVIZ GUI +ROBOT_IP= CONN_TYPE= docker compose -f docker/unitree/ros_dimos/docker-compose.yml up --build +``` + +Where: +- `` is your Go2's IP address +- `` choose either: + - `webrtc`: For WebRTC video streaming connection + - `cyclonedds`: For DDS communication + +The containers will build and start, establishing connection with your Go2 robot and opening RVIZ. The DIMOS integration will start 10 seconds after ROS to ensure proper initialization. + +Note: You can run this command from any directory since the docker-compose.yml file handles all relative paths internally. + +## Process Management + +The setup uses supervisord to manage both ROS and DIMOS processes. To check process status or view logs when inside the container: + +```bash +# Get a shell in the container +docker compose -f docker/unitree/ros_dimos/docker-compose.yml exec unitree_ros_dimos bash + +# View process status +supervisorctl status + +# View logs +supervisorctl tail ros2 # ROS2 logs +supervisorctl tail dimos # DIMOS logs +supervisorctl tail -f ros2 # Follow ROS2 logs +``` + +## Known Issues + +1. ROS2 doesn't have time to initialize before DIMOS starts, so the DIMOS logs will show successful aioice.ice:Connection followed by aiortc.exceptions.InvalidStateError. + +This is currently solved by hardcoding a delay between ros2 and DIMOS start in supervisord.conf. + +```ini +[lifecycle_manager-18] [INFO] [1740128988.350926960] [lifecycle_manager_navigation]: Managed nodes are active +[lifecycle_manager-18] [INFO] [1740128988.350965828] [lifecycle_manager_navigation]: Creating bond timer... +[go2_driver_node-3] INFO:scripts.webrtc_driver:Connection state is connecting +[go2_driver_node-3] INFO:aioice.ice:Connection(1) Discovered peer reflexive candidate Candidate(3hokvTUH7e 1 udp 2130706431 192.168.9.140 37384 typ prflx) +[go2_driver_node-3] INFO:aioice.ice:Connection(1) Check CandidatePair(('192.168.9.155', 33483) -> ('192.168.9.140', 37384)) State.WAITING -> State.IN_PROGRESS +[go2_driver_node-3] [INFO] [1740128990.171453153] [go2_driver_node]: Move +[go2_driver_node-3] INFO:scripts.webrtc_driver:Receiving video +[go2_driver_node-3] ERROR:asyncio:Task exception was never retrieved +[go2_driver_node-3] future: exception=InvalidStateError()> +[go2_driver_node-3] Traceback (most recent call last): +[go2_driver_node-3] File "/ros2_ws/install/go2_robot_sdk/lib/python3.10/site-packages/go2_robot_sdk/go2_driver_node.py", line 634, in run +[go2_driver_node-3] self.joy_cmd(robot_num) +[go2_driver_node-3] File "/ros2_ws/install/go2_robot_sdk/lib/python3.10/site-packages/go2_robot_sdk/go2_driver_node.py", line 320, in joy_cmd +[go2_driver_node-3] self.conn[robot_num].data_channel.send( +[go2_driver_node-3] File "/usr/local/lib/python3.10/dist-packages/aiortc/rtcdatachannel.py", line 182, in send +[go2_driver_node-3] raise InvalidStateError +[go2_driver_node-3] aiortc.exceptions.InvalidStateError +[go2_driver_node-3] Exception in thread Thread-1 (_spin): +[go2_driver_node-3] Traceback (most recent call last): +[go2_driver_node-3] File "/usr/lib/python3.10/threading.py", line 1016, in _bootstrap_inner +[go2_driver_node-3] self.run() +[go2_driver_node-3] File "/usr/lib/python3.10/threading.py", line 953, in run +[go2_driver_node-3] self._target(*self._args, **self._kwargs) +[go2_driver_node-3] File "/ros2_ws/install/go2_robot_sdk/lib/python3.10/site-packages/go2_robot_sdk/go2_driver_node.py", line 646, in _spin +[go2_driver_node-3] rclpy.spin_once(node) +[go2_driver_node-3] File "/opt/ros/humble/local/lib/python3.10/dist-packages/rclpy/__init__.py", line 203, in spin_once +[go2_driver_node-3] executor = get_global_executor() if executor is None else executor +[go2_driver_node-3] File "/opt/ros/humble/local/lib/python3.10/dist-packages/rclpy/__init__.py", line 106, in get_global_executor +[go2_driver_node-3] __executor = SingleThreadedExecutor() +[go2_driver_node-3] File "/opt/ros/humble/local/lib/python3.10/dist-packages/rclpy/executors.py", line 721, in __init__ +[go2_driver_node-3] super().__init__(context=context) +[go2_driver_node-3] File "/opt/ros/humble/local/lib/python3.10/dist-packages/rclpy/executors.py", line 172, in __init__ +[go2_driver_node-3] self._guard = GuardCondition( +[go2_driver_node-3] File "/opt/ros/humble/local/lib/python3.10/dist-packages/rclpy/guard_condition.py", line 23, in __init__ +[go2_driver_node-3] with self._context.handle: +[go2_driver_node-3] AttributeError: __enter__ +[go2_driver_node-3] Exception ignored in: +[go2_driver_node-3] Traceback (most recent call last): +[go2_driver_node-3] File "/opt/ros/humble/local/lib/python3.10/dist-packages/rclpy/executors.py", line 243, in __del__ +[go2_driver_node-3] if self._sigint_gc is not None: +[go2_driver_node-3] AttributeError: 'SingleThreadedExecutor' object has no attribute '_sigint_gc' +[go2_driver_node-3] ERROR:asyncio:Task was destroyed but it is pending! +[go2_driver_node-3] task: wait_for=._outer_done_callback() at /usr/lib/python3.10/asyncio/tasks.py:864, Task.task_wakeup()]>> +[go2_driver_node-3] ERROR:asyncio:Task was destroyed but it is pending! +[go2_driver_node-3] task: wait_for=> +[go2_driver_node-3] Exception ignored in: +[go2_driver_node-3] Traceback (most recent call last): +[go2_driver_node-3] File "/ros2_ws/install/go2_robot_sdk/lib/python3.10/site-packages/scripts/webrtc_driver.py", line 229, in on_track +[go2_driver_node-3] frame = await track.recv() +[go2_driver_node-3] File "/usr/local/lib/python3.10/dist-packages/aiortc/rtcrtpreceiver.py", line 203, in recv +[go2_driver_node-3] frame = await self._queue.get() +[go2_driver_node-3] File "/usr/lib/python3.10/asyncio/queues.py", line 161, in get +[go2_driver_node-3] getter.cancel() # Just in case getter is not done yet. +[go2_driver_node-3] File "/usr/lib/python3.10/asyncio/base_events.py", line 753, in call_soon +[go2_driver_node-3] self._check_closed() +[go2_driver_node-3] File "/usr/lib/python3.10/asyncio/base_events.py", line 515, in _check_closed +[go2_driver_node-3] raise RuntimeError('Event loop is closed') +[go2_driver_node-3] RuntimeError: Event loop is closed +[go2_driver_node-3] ERROR:asyncio:Task was destroyed but it is pending! +[go2_driver_node-3] task: wait_for= cb=[AsyncIOEventEmitter._emit_run..callback() at /usr/local/lib/python3.10/dist-packages/pyee/asyncio.py:95]> +[go2_driver_node-3] ERROR:asyncio:Task was destroyed but it is pending! +[go2_driver_node-3] task: wait_for=> +[go2_driver_node-3] ERROR:asyncio:Task was destroyed but it is pending! +[go2_driver_node-3] task: wait_for=> +[go2_driver_node-3] ERROR:asyncio:Task was destroyed but it is pending! +[go2_driver_node-3] task: wait_for=> +[INFO] [go2_driver_node-3]: process has finished cleanly [pid 120] +``` + + +2. If you encounter the error `unitree_ros_dimos-1 | exec /entrypoint.sh: no such file or directory`, this can be caused by: + - Incorrect file permissions + - Windows-style line endings (CRLF) in the entrypoint script + + To fix: + 1. Ensure the entrypoint script has execute permissions: + ```bash + chmod +x /path/to/dimos/docker/unitree/ros_dimos/entrypoint.sh + ``` + + 2. If using Windows, convert line endings to Unix format (LF): + ```bash + # Using dos2unix + dos2unix /path/to/dimos/docker/unitree/ros_dimos/entrypoint.sh + + # Or using sed + sed -i 's/\r$//' /path/to/dimos/docker/unitree/ros_dimos/entrypoint.sh + ``` + +2. If DIMOS fails to start, check: + - The ROS nodes are fully initialized (wait a few seconds) + - The environment variables are properly set + - The Python path includes the dimos directory + - The logs using supervisorctl for specific error messages \ No newline at end of file diff --git a/docker/deprecated/unitree/ros_dimos/docker-compose.yml b/docker/deprecated/unitree/ros_dimos/docker-compose.yml new file mode 100644 index 0000000000..2d36b4d479 --- /dev/null +++ b/docker/deprecated/unitree/ros_dimos/docker-compose.yml @@ -0,0 +1,18 @@ +--- +services: + unitree_ros_dimos: + image: unitree_ros_dimos:latest + build: + context: ../../../ + dockerfile: docker/unitree/ros_dimos/Dockerfile + env_file: + - ../../../.env + volumes: + - /tmp/.X11-unix:/tmp/.X11-unix # X11 forwarding + - ${HOME}/.Xauthority:/root/.Xauthority:rw + - ../../../assets/output/:/app/assets/output + network_mode: "host" # Required for ROS2 discovery and robot communication + privileged: true # Required for hardware access + devices: + - /dev/input:/dev/input # For joystick access + restart: unless-stopped diff --git a/docker/deprecated/unitree/ros_dimos/entrypoint.sh b/docker/deprecated/unitree/ros_dimos/entrypoint.sh new file mode 100755 index 0000000000..f7d753f1f7 --- /dev/null +++ b/docker/deprecated/unitree/ros_dimos/entrypoint.sh @@ -0,0 +1,16 @@ +#!/bin/bash +set -e + +# Create supervisor log directory + +mkdir -p /app/assets/output + +# Delete old logs +echo "Cleaning up old Supervisor logs..." +rm -f /app/assets/output/*.log + +# Source ROS2 environment +source /opt/ros/${ROS_DISTRO}/setup.bash +source /ros2_ws/install/setup.bash +# Execute the command passed to docker run +exec "$@" diff --git a/docker/deprecated/unitree/ros_dimos/supervisord.conf b/docker/deprecated/unitree/ros_dimos/supervisord.conf new file mode 100644 index 0000000000..105742b844 --- /dev/null +++ b/docker/deprecated/unitree/ros_dimos/supervisord.conf @@ -0,0 +1,35 @@ +[supervisord] +nodaemon=true +logfile=/var/log/supervisor/supervisord.log +pidfile=/var/run/supervisord.pid + +[program:ros2] +command=/bin/bash -c "source /opt/ros/humble/setup.bash && source /ros2_ws/install/setup.bash && ros2 launch go2_robot_sdk robot.launch.py" +autostart=true +autorestart=true + +stderr_logfile=/app/assets/output/ros2.err.log +stdout_logfile=/app/assets/output/ros2.out.log +environment=PYTHONUNBUFFERED=1 + +[program:dimos] +command=/bin/bash -c "sleep 10 && source /opt/ros/humble/setup.bash && source /ros2_ws/install/setup.bash && python3 /app/tests/run_go2_ros.py" +autostart=true +autorestart=true +startsecs=11 + +stdout_logfile=/dev/stdout +stdout_logfile_maxbytes=0 +stderr_logfile=/dev/stderr +stderr_logfile_maxbytes=0 +environment=PYTHONUNBUFFERED=1 + +[unix_http_server] +file=/var/run/supervisor.sock +chmod=0700 + +[rpcinterface:supervisor] +supervisor.rpcinterface_factory = supervisor.rpcinterface:make_main_rpcinterface + +[supervisorctl] +serverurl=unix:///var/run/supervisor.sock diff --git a/docker/deprecated/unitree/webrtc/Dockerfile b/docker/deprecated/unitree/webrtc/Dockerfile new file mode 100644 index 0000000000..c073fbbe08 --- /dev/null +++ b/docker/deprecated/unitree/webrtc/Dockerfile @@ -0,0 +1,30 @@ +FROM python:3 + +RUN apt-get update && apt-get install -y \ + libgl1-mesa-glx \ + build-essential \ + libavformat-dev \ + libavcodec-dev \ + libavdevice-dev \ + libavutil-dev \ + libswscale-dev \ + libpostproc-dev \ + gcc \ + make \ + portaudio19-dev \ + python3-pyaudio \ + python3-all-dev + +WORKDIR /app + +COPY requirements.txt ./ + +RUN pip install --no-cache-dir -r requirements.txt + +COPY ./dimos ./dimos + +COPY ./tests ./tests + +COPY ./dimos/__init__.py ./ + +CMD [ "python", "-m", "dimos.robot.unitree.unitree_go2" ] diff --git a/docker/deprecated/unitree/webrtc/docker-compose.yml b/docker/deprecated/unitree/webrtc/docker-compose.yml new file mode 100644 index 0000000000..c8e9f234f6 --- /dev/null +++ b/docker/deprecated/unitree/webrtc/docker-compose.yml @@ -0,0 +1,44 @@ +--- +services: + dimos-unitree-webrtc: + image: dimos-unitree-webrtc:latest + build: + context: ../../../ + dockerfile: docker/unitree/webrtc/Dockerfile + env_file: + - ../../../.env + mem_limit: 8048m + volumes: + - ../../../assets:/app/assets + - ../../../output:/app/output + ports: + - "5555:5555" + environment: + - PYTHONUNBUFFERED=1 + # Robot configuration - use shell variables with defaults + - ROBOT_IP=${ROBOT_IP} + - CONNECTION_METHOD=${CONNECTION_METHOD:-LocalSTA} + - SERIAL_NUMBER=${SERIAL_NUMBER:-} + - OUTPUT_DIR=${OUTPUT_DIR:-/app/assets} + stdin_open: true + tty: true + command: ["python", "-m", "dimos.robot.unitree.run_go2"] + # command: ["tail", "-f", "/dev/null"] + +# ---- +# TO RUN with default values: +# docker compose up +# +# TO RUN with custom parameters: +# ROBOT_IP=192.168.1.100 CONNECTION_METHOD=LocalAP SERIAL_NUMBER=ABC123 docker compose up +# +# Examples: +# - With IP: +# ROBOT_IP=192.168.1.100 docker compose up +# +# - With LocalAP: +# CONNECTION_METHOD=LocalAP docker compose up +# +# - With Serial Number: +# CONNECTION_METHOD=LocalSTA SERIAL_NUMBER=ABC123 docker compose up +# ---- diff --git a/docker/dev/Dockerfile b/docker/dev/Dockerfile new file mode 100644 index 0000000000..ef80b70e1d --- /dev/null +++ b/docker/dev/Dockerfile @@ -0,0 +1,54 @@ +ARG FROM_IMAGE=ghcr.io/dimensionalos/ros-python:dev +FROM ${FROM_IMAGE} + +ARG GIT_COMMIT=unknown +ARG GIT_BRANCH=unknown + +RUN apt-get update && apt-get install -y \ + git \ + git-lfs \ + nano \ + vim \ + ccze \ + tmux \ + htop \ + iputils-ping \ + wget \ + net-tools \ + sudo \ + pre-commit + + +# Configure git to trust any directory (resolves dubious ownership issues in containers) +RUN git config --global --add safe.directory '*' + +WORKDIR /app + +# Install UV for fast Python package management +ENV UV_SYSTEM_PYTHON=1 +RUN curl -LsSf https://astral.sh/uv/install.sh | sh +ENV PATH="/root/.local/bin:$PATH" + +# Install dependencies with UV +RUN uv pip install .[dev] + +# Copy files and add version to motd +COPY /assets/dimensionalascii.txt /etc/motd +COPY /docker/dev/bash.sh /root/.bash.sh +COPY /docker/dev/tmux.conf /root/.tmux.conf + +# Install nodejs (for random devtooling like copilot etc) +RUN curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.1/install.sh | bash +ENV NVM_DIR=/root/.nvm +RUN bash -c "source $NVM_DIR/nvm.sh && nvm install 24" + +# This doesn't work atm +RUN echo " v_${GIT_BRANCH}:${GIT_COMMIT} | $(date)" >> /etc/motd +RUN echo "echo -e '\033[34m$(cat /etc/motd)\033[0m\n'" >> /root/.bashrc + +RUN echo "source /root/.bash.sh" >> /root/.bashrc + +COPY /docker/dev/entrypoint.sh /entrypoint.sh +RUN chmod +x /entrypoint.sh + +ENTRYPOINT ["/entrypoint.sh"] diff --git a/docker/dev/bash.sh b/docker/dev/bash.sh new file mode 100755 index 0000000000..c5248841d9 --- /dev/null +++ b/docker/dev/bash.sh @@ -0,0 +1,198 @@ +#!/bin/bash +# history +shopt -s histappend +export HISTCONTROL="ignoredups" +export HISTSIZE=100000 +export HISTFILESIZE=100000 +export HISTIGNORE='ls' + +# basic vars +export EDITOR="nano" +export LESS='-R' + +# basic aliases +alias ta='tmux a' +alias ccze='ccze -o nolookups -A' +alias pd='p d' +alias t='tmux' +alias g='grep' +alias f='find' +alias ..="cd .." +alias ka="killall" +alias la="ls -al" +alias l="ls" +alias sl="ls" +alias ls="ls --color" +alias c="clear" +alias psa="ps aux" +alias grep="grep --color=auto" +alias p="ping -c 1 -w 1" +alias psg="ps aux | grep" +alias unitg="systemctl list-unit-files | grep" +alias ug="unitg" +alias unit="echo 'systemctl list-unit-files'; systemctl list-unit-files" +alias scr="echo 'sudo systemctl daemon-reload'; sudo systemctl daemon-reload" +alias psac="ps aux | ccze -Ao nolookups" +alias psa="ps aux" +alias pdn="p dns" +alias s="sudo -iu root" +alias m="mount" +alias oip="wget -qO- http://www.ipaddr.de/?plain" +alias getlogin="echo genpass 6 : genpass 20" +alias rscp="rsync -vrt --size-only --partial --progress " +alias rscpd="rsync --delete-after -vrt --size-only --partial --progress " +alias v="vim" +alias npm="export PYTHON=python2; npm" +alias ssh="ssh -o ConnectTimeout=1" +alias gp="git push" +alias rh="history -a; history -c; history -r" +alias gs="git status" +alias gd="git diff" +alias ipy="python -c 'import IPython; IPython.terminal.ipapp.launch_new_instance()'" + +function npmg +{ + echo 'global npm install' + tmpUmask u=rwx,g=rx,o=rx npm $@ +} + +function tmpUmask +{ + oldUmask=$(umask) + newUmask=$1 + + shift + umask $newUmask + echo umask $(umask -S) + echo "$@" + eval $@ + umask $oldUmask + echo umask $(umask -S) + +} + +function newloginuser +{ + read user + pass=$(genpass 20) + + echo $user : $pass + echo site? + read site + echo site: $site + + echo $site : $user : $pass >> ~/.p +} + +function newlogin +{ + user=$(genpass 6) + pass=$(genpass 20) + + echo $user : $pass + echo site? + read site + echo site: $site + + echo $site : $user : $pass >> ~/.p + +} + + +function newlogin +{ + pass=$(genpass 30) + echo $pass +} + + +function getpass { + echo $(genpass 20) +} + +function genpass +{ + newpass=$(cat /dev/urandom | base64 | tr -d "0" | tr -d "y" | tr -d "Y" | tr -d "z" | tr -d "Z" | tr -d "I" | tr -d "l" | tr -d "//" | head -c$1) + echo -n $newpass +} + +function sx +{ + if [ -z $1 ] + then + screen -x $(cat /tmp/sx) + else + echo -n $1 > /tmp/sx + screen -x $1 + fi +} + +function loopy +{ + while [ 1 ]; do + eval "$1" + if [ "$2" ]; then sleep $2; else sleep 1; fi + done +} + + +function we +{ + eval "$@" + until [ $? -eq 0 ]; do + sleep 1; eval "$@" + done +} + +alias wf='waitfor' +function waitfor +{ + eval "$1" + until [ $? -eq 0 ]; do + sleep 1; eval "$1" + done + eval "$2" +} + +function waitnot +{ + eval "$1" + until [ $? -ne 0 ]; do + sleep 1; eval "$1" + done + eval "$2" +} + +function wrscp +{ + echo rscp $@ + waitfor "rscp $1 $2" +} + +function waitfornot +{ + eval "$1" + until [ $? -ne 0 ]; do + sleep 1 + eval "$1" + done + eval "$2" +} + + +function watchFile +{ + tail -F $1 2>&1 | sed -e "$(echo -e "s/^\(tail: .\+: file truncated\)$/\1\e[2J \e[0f/")" +} + +PS1='${debian_chroot:+($debian_chroot)}\[\033[32m\]\u@dimos\[\033[00m\]:\[\033[34m\]\w\[\033[00m\] \$ ' + +export PATH="/app/bin:${PATH}" + +# we store history in the container so rebuilding doesn't lose it +export HISTFILE=/app/.bash_history + +# export all .env variables +set -a +source /app/.env +set +a diff --git a/docker/dev/docker-compose-cuda.yaml b/docker/dev/docker-compose-cuda.yaml new file mode 100644 index 0000000000..5def3fb6c3 --- /dev/null +++ b/docker/dev/docker-compose-cuda.yaml @@ -0,0 +1,32 @@ +services: + dev-environment: + image: ghcr.io/dimensionalos/dev:${DEV_IMAGE_TAG:-latest} + container_name: dimos-dev-${DEV_IMAGE_TAG:-latest} + network_mode: "host" + volumes: + - ../../../:/app + + # X11 forwarding + - /tmp/.X11-unix:/tmp/.X11-unix + - ${HOME}/.Xauthority:/root/.Xauthority:rw + + runtime: nvidia + environment: + - PYTHONUNBUFFERED=1 + - PYTHONPATH=/app + - DISPLAY=${DISPLAY:-} + + # NVIDIA + - NVIDIA_VISIBLE_DEVICES=all + - NVIDIA_DRIVER_CAPABILITIES=all + + # X11 and XDG runtime + - XAUTHORITY=/root/.Xauthority + - XDG_RUNTIME_DIR=/tmp/xdg-runtime + + ports: + - "5555:5555" + - "3000:3000" + stdin_open: true + tty: true + command: /bin/bash diff --git a/docker/dev/docker-compose.yaml b/docker/dev/docker-compose.yaml new file mode 100644 index 0000000000..8175e26c69 --- /dev/null +++ b/docker/dev/docker-compose.yaml @@ -0,0 +1,23 @@ +services: + dev-environment: + image: ghcr.io/dimensionalos/dev:${DEV_IMAGE_TAG:-latest} + container_name: dimos-dev-${DEV_IMAGE_TAG:-latest} + network_mode: "host" + volumes: + - ../../:/app + + # X11 forwarding + - /tmp/.X11-unix:/tmp/.X11-unix + - ${HOME}/.Xauthority:/root/.Xauthority:rw + + environment: + - PYTHONUNBUFFERED=1 + - PYTHONPATH=/app + - DISPLAY=${DISPLAY:-} + + ports: + - "5555:5555" + - "3000:3000" + stdin_open: true + tty: true + command: /bin/bash diff --git a/docker/dev/entrypoint.sh b/docker/dev/entrypoint.sh new file mode 100644 index 0000000000..d48bea16e3 --- /dev/null +++ b/docker/dev/entrypoint.sh @@ -0,0 +1,8 @@ +#!/usr/bin/env bash +if [ -d "/opt/ros/${ROS_DISTRO}" ]; then + source /opt/ros/${ROS_DISTRO}/setup.bash +else + echo "ROS is not available in this env" +fi + +exec "$@" diff --git a/docker/dev/tmux.conf b/docker/dev/tmux.conf new file mode 100644 index 0000000000..ecf6b22ced --- /dev/null +++ b/docker/dev/tmux.conf @@ -0,0 +1,84 @@ +# set-option -g pane-active-border-fg yellow +# set-option -g pane-active-border-bg blue +# set-option -g pane-border-fg blue +# set-option -g pane-border-bg blue +# set-option -g message-fg black +# set-option -g message-bg green +set-option -g status-bg blue +set-option -g status-fg cyan +set-option -g history-limit 5000 + +set-option -g prefix C-q + +bind | split-window -h -c "#{pane_current_path}" +bind "-" split-window -v -c "#{pane_current_path}" +bind k kill-pane +#bind C-Tab select-pane -t :.+ +#bind-key a send-prefix + +bind -n C-down new-window -c "#{pane_current_path}" +bind -n C-up new-window -c "#{pane_current_path}" +bind -n M-n new-window -c "#{pane_current_path}" +bind -n M-c new-window -c "#{pane_current_path}" +bind -n C-left prev +bind -n C-right next +bind -n M-C-n next +bind -n M-C-p prev +# bind -n C-\ new-window -c "#{pane_current_path}" +bind c new-window -c "#{pane_current_path}" + +#bind -n A-s resize-pane +#bind -n A-w resize-pane -U +#bind -n A-a resize-pane -L +#ind -n A-d resize-pane -R +#bind -n C-M-left swap-window -t -1 +#bind -n C-M-right swap-window -t +1 +#set -g default-terminal "screen-256color" +#set -g default-terminal "xterm" + +bind-key u capture-pane \; save-buffer /tmp/tmux-buffer \; run-shell "urxvtc --geometry 51x20 --title 'floatme' -e bash -c \"cat /tmp/tmux-buffer | urlview\" " +bind-key r source-file ~/.tmux.conf + +# set-window-option -g window-status-current-fg green +set -g status-fg white + +set-window-option -g aggressive-resize off +set-window-option -g automatic-rename on + +# bind-key -n C-\` select-window -t 0 +bind-key -n C-0 select-window -t 0 +bind-key -n C-1 select-window -t 1 +bind-key -n C-2 select-window -t 2 +bind-key -n C-3 select-window -t 3 +bind-key -n C-4 select-window -t 4 +bind-key -n C-5 select-window -t 5 +bind-key -n C-6 select-window -t 6 +bind-key -n C-7 select-window -t 7 +bind-key -n C-8 select-window -t 8 +bind-key -n C-9 select-window -t 9 + + +# statusbar settings - adopted from tmuxline.vim and vim-airline - Theme: murmur +set -g status-justify "left" +set -g status "on" +set -g status-left-style "none" +set -g message-command-style "fg=colour144,bg=colour237" +set -g status-right-style "none" +set -g status-style "bg=black" +set -g status-bg "black" +set -g message-style "fg=colour144,bg=colour237" +set -g pane-active-border-style "fg=colour248" +#set -g pane-border-style "fg=colour238" +#set -g pane-active-border-style "fg=colour241" +set -g pane-border-style "fg=colour0" +set -g status-right-length "100" +set -g status-left-length "100" +# setw -g window-status-activity-attr "none" +setw -g window-status-activity-style "fg=colour27,bg=colour234,none" +setw -g window-status-separator "#[bg=colour235]" +setw -g window-status-style "fg=colour253,bg=black,none" +set -g status-left "" +set -g status-right "#[bg=black]#[fg=colour244]#h#[fg=colour244]#[fg=colour3]/#[fg=colour244]#S" + +setw -g window-status-format " #[fg=colour3]#I#[fg=colour244] #W " +setw -g window-status-current-format " #[fg=color3]#I#[fg=colour254] #W " diff --git a/docker/navigation/.env.hardware b/docker/navigation/.env.hardware new file mode 100644 index 0000000000..05e08bd375 --- /dev/null +++ b/docker/navigation/.env.hardware @@ -0,0 +1,64 @@ +# Hardware Configuration Environment Variables +# Copy this file to .env and customize for your hardware setup + +# ============================================ +# NVIDIA GPU Support +# ============================================ +# Set the Docker runtime to nvidia for GPU support (it's runc by default) +#DOCKER_RUNTIME=nvidia + +# ============================================ +# ROS Configuration +# ============================================ +# ROS domain ID for multi-robot setups +ROS_DOMAIN_ID=42 + +# Robot configuration ('mechanum_drive', 'unitree/unitree_g1', 'unitree/unitree_g1', etc) +ROBOT_CONFIG_PATH=mechanum_drive + +# Robot IP address on local network for connection over WebRTC +# For Unitree Go2, Unitree G1, if using WebRTCConnection +# This can be found in the unitree app under Device settings or via network scan +ROBOT_IP= + +# ============================================ +# Mid-360 Lidar Configuration +# ============================================ +# Network interface connected to the lidar (e.g., eth0, enp0s3) +# Find with: ip addr show +LIDAR_INTERFACE=eth0 + +# Processing computer IP address on the lidar subnet +# Must be on the same subnet as the lidar (e.g., 192.168.1.5) +# LIDAR_COMPUTER_IP=192.168.123.5 # FOR UNITREE G1 EDU +LIDAR_COMPUTER_IP=192.168.1.5 + +# Gateway IP address for the lidar subnet +# LIDAR_GATEWAY=192.168.123.1 # FOR UNITREE G1 EDU +LIDAR_GATEWAY=192.168.1.1 + +# Full IP address of your Mid-360 lidar +# This should match the IP configured on your lidar device +# Common patterns: 192.168.1.1XX or 192.168.123.1XX +# LIDAR_IP=192.168.123.120 # FOR UNITREE G1 EDU +LIDAR_IP=192.168.1.116 + +# ============================================ +# Motor Controller Configuration +# ============================================ +# Serial device for motor controller +# Check with: ls /dev/ttyACM* or ls /dev/ttyUSB* +MOTOR_SERIAL_DEVICE=/dev/ttyACM0 + +# ============================================ +# Network Communication (for base station) +# ============================================ +# Enable WiFi buffer optimization for data transmission +# Set to true if using wireless base station +ENABLE_WIFI_BUFFER=false + +# ============================================ +# Display Configuration +# ============================================ +# X11 display (usually auto-detected) +# DISPLAY=:0 diff --git a/docker/navigation/.gitignore b/docker/navigation/.gitignore new file mode 100644 index 0000000000..0eaccbc740 --- /dev/null +++ b/docker/navigation/.gitignore @@ -0,0 +1,20 @@ +# Cloned repository +ros-navigation-autonomy-stack/ + +# Unity models (large binary files) +unity_models/ + +# ROS bag files +bagfiles/ + +# Config files (may contain local settings) +config/ + +# Docker volumes +.docker/ + +# Temporary files +*.tmp +*.log +*.swp +*~ diff --git a/docker/navigation/Dockerfile b/docker/navigation/Dockerfile new file mode 100644 index 0000000000..69378ea7c7 --- /dev/null +++ b/docker/navigation/Dockerfile @@ -0,0 +1,228 @@ +# Base image with ROS Jazzy desktop full +FROM osrf/ros:jazzy-desktop-full + +# Set environment variables +ENV DEBIAN_FRONTEND=noninteractive +ENV ROS_DISTRO=jazzy +ENV WORKSPACE=/ros2_ws +ENV DIMOS_PATH=/workspace/dimos + +# Install system dependencies +RUN apt-get update && apt-get install -y \ + # ROS packages + ros-jazzy-pcl-ros \ + # Development tools + git \ + git-lfs \ + cmake \ + build-essential \ + python3-colcon-common-extensions \ + # PCL and system libraries + libpcl-dev \ + libgoogle-glog-dev \ + libgflags-dev \ + libatlas-base-dev \ + libeigen3-dev \ + libsuitesparse-dev \ + # X11 and GUI support for RVIZ + x11-apps \ + xorg \ + openbox \ + # Networking tools + iputils-ping \ + net-tools \ + iproute2 \ + ethtool \ + # USB and serial tools (for hardware support) + usbutils \ + udev \ + # Time synchronization (for multi-computer setup) + chrony \ + # Editor (optional but useful) + nano \ + vim \ + # Python tools + python3-pip \ + python3-setuptools \ + python3-venv \ + # Additional dependencies for dimos + ffmpeg \ + portaudio19-dev \ + libsndfile1 \ + # For OpenCV + libgl1 \ + libglib2.0-0 \ + # For Open3D + libgomp1 \ + # For TurboJPEG + libturbojpeg0-dev \ + # Clean up + && rm -rf /var/lib/apt/lists/* + +# Create workspace directory +RUN mkdir -p ${WORKSPACE}/src + +# Copy the autonomy stack repository (should be cloned by build.sh) +COPY docker/navigation/ros-navigation-autonomy-stack ${WORKSPACE}/src/ros-navigation-autonomy-stack + +# Set working directory +WORKDIR ${WORKSPACE} + +# Set up ROS environment +RUN echo "source /opt/ros/${ROS_DISTRO}/setup.bash" >> ~/.bashrc + +# Build all hardware dependencies +RUN \ + # Build Livox-SDK2 for Mid-360 lidar + cd ${WORKSPACE}/src/ros-navigation-autonomy-stack/src/utilities/livox_ros_driver2/Livox-SDK2 && \ + mkdir -p build && cd build && \ + cmake .. && make -j$(nproc) && make install && ldconfig && \ + # Install Sophus + cd ${WORKSPACE}/src/ros-navigation-autonomy-stack/src/slam/dependency/Sophus && \ + mkdir -p build && cd build && \ + cmake .. -DBUILD_TESTS=OFF && make -j$(nproc) && make install && \ + # Install Ceres Solver + cd ${WORKSPACE}/src/ros-navigation-autonomy-stack/src/slam/dependency/ceres-solver && \ + mkdir -p build && cd build && \ + cmake .. && make -j$(nproc) && make install && \ + # Install GTSAM + cd ${WORKSPACE}/src/ros-navigation-autonomy-stack/src/slam/dependency/gtsam && \ + mkdir -p build && cd build && \ + cmake .. -DGTSAM_USE_SYSTEM_EIGEN=ON -DGTSAM_BUILD_WITH_MARCH_NATIVE=OFF && \ + make -j$(nproc) && make install && ldconfig + +# Build the autonomy stack +RUN /bin/bash -c "source /opt/ros/${ROS_DISTRO}/setup.bash && \ + cd ${WORKSPACE} && \ + colcon build --symlink-install --cmake-args -DCMAKE_BUILD_TYPE=Release" + +# Source the workspace setup +RUN echo "source ${WORKSPACE}/install/setup.bash" >> ~/.bashrc + +# Create directory for Unity environment models +RUN mkdir -p ${WORKSPACE}/src/ros-navigation-autonomy-stack/src/base_autonomy/vehicle_simulator/mesh/unity + +# Copy the dimos repository +RUN mkdir -p ${DIMOS_PATH} +COPY . ${DIMOS_PATH}/ + +# Create a virtual environment in /opt (not in /workspace/dimos) +# This ensures the venv won't be overwritten when we mount the host dimos directory +# The container will always use its own dependencies, independent of the host +RUN python3 -m venv /opt/dimos-venv + +# Activate Python virtual environment in interactive shells +RUN echo "source /opt/dimos-venv/bin/activate" >> ~/.bashrc + +# Install Python dependencies for dimos +WORKDIR ${DIMOS_PATH} +RUN /bin/bash -c "source /opt/dimos-venv/bin/activate && \ + pip install --upgrade pip setuptools wheel && \ + pip install -e .[cpu,dev] 'mmengine>=0.10.3' 'mmcv>=2.1.0'" + +# Copy helper scripts +COPY docker/navigation/run_both.sh /usr/local/bin/run_both.sh +COPY docker/navigation/ros_launch_wrapper.py /usr/local/bin/ros_launch_wrapper.py +RUN chmod +x /usr/local/bin/run_both.sh /usr/local/bin/ros_launch_wrapper.py + +# Set up udev rules for USB devices (motor controller) +RUN echo 'SUBSYSTEM=="tty", ATTRS{idVendor}=="0483", ATTRS{idProduct}=="5740", MODE="0666", GROUP="dialout"' > /etc/udev/rules.d/99-motor-controller.rules && \ + usermod -a -G dialout root || true + +# Set up entrypoint script +RUN echo '#!/bin/bash\n\ +set -e\n\ +\n\ +git config --global --add safe.directory /workspace/dimos\n\ +\n\ +# Source ROS setup\n\ +source /opt/ros/${ROS_DISTRO}/setup.bash\n\ +source ${WORKSPACE}/install/setup.bash\n\ +\n\ +# Activate Python virtual environment for dimos\n\ +source /opt/dimos-venv/bin/activate\n\ +\n\ +# Export ROBOT_CONFIG_PATH for autonomy stack\n\ +export ROBOT_CONFIG_PATH="${ROBOT_CONFIG_PATH:-mechanum_drive}"\n\ +\n\ +# Hardware-specific configurations\n\ +if [ "${HARDWARE_MODE}" = "true" ]; then\n\ + # Set network buffer sizes for WiFi data transmission (if needed)\n\ + if [ "${ENABLE_WIFI_BUFFER}" = "true" ]; then\n\ + sysctl -w net.core.rmem_max=67108864 net.core.rmem_default=67108864 2>/dev/null || true\n\ + sysctl -w net.core.wmem_max=67108864 net.core.wmem_default=67108864 2>/dev/null || true\n\ + fi\n\ + \n\ + # Configure network interface for Mid-360 lidar if specified\n\ + if [ -n "${LIDAR_INTERFACE}" ] && [ -n "${LIDAR_COMPUTER_IP}" ]; then\n\ + ip addr add ${LIDAR_COMPUTER_IP}/24 dev ${LIDAR_INTERFACE} 2>/dev/null || true\n\ + ip link set ${LIDAR_INTERFACE} up 2>/dev/null || true\n\ + if [ -n "${LIDAR_GATEWAY}" ]; then\n\ + ip route add default via ${LIDAR_GATEWAY} dev ${LIDAR_INTERFACE} 2>/dev/null || true\n\ + fi\n\ + fi\n\ + \n\ + # Generate MID360_config.json if LIDAR_COMPUTER_IP and LIDAR_IP are set\n\ + if [ -n "${LIDAR_COMPUTER_IP}" ] && [ -n "${LIDAR_IP}" ]; then\n\ + cat > ${WORKSPACE}/src/ros-navigation-autonomy-stack/src/utilities/livox_ros_driver2/config/MID360_config.json < /ros_entrypoint.sh && \ + chmod +x /ros_entrypoint.sh + +# Set the entrypoint +ENTRYPOINT ["/ros_entrypoint.sh"] + +# Default command +CMD ["bash"] diff --git a/docker/navigation/README.md b/docker/navigation/README.md new file mode 100644 index 0000000000..50276a6cf6 --- /dev/null +++ b/docker/navigation/README.md @@ -0,0 +1,144 @@ +# ROS Docker Integration for DimOS + +This directory contains Docker configuration files to run DimOS and the ROS autonomy stack in the same container, enabling communication between the two systems. + +## New Ubuntu Installation + +**For fresh Ubuntu systems**, use the automated setup script: + +```bash +wget https://raw.githubusercontent.com/dimensionalOS/dimos/refs/heads/dev/docker/navigation/setup.sh?token=GHSAT0AAAAAADHM56ULLVHMU72XDZSKOZAM2ISY24A +bash setup.sh +``` + +**Installation time:** Approximately 20-30 minutes depending on your internet connection. + +**Options:** +```bash +./setup.sh --help # Show all options +./setup.sh --install-dir /opt/dimos # Custom installation directory +./setup.sh --skip-build # Skip Docker image build +``` + +If the automated script encounters issues, follow the manual setup below. + +## Prerequisites + +1. **Install Docker with `docker compose` support**. Follow the [official Docker installation guide](https://docs.docker.com/engine/install/). +2. **Install NVIDIA GPU drivers**. See [NVIDIA driver installation](https://www.nvidia.com/download/index.aspx). +3. **Install NVIDIA Container Toolkit**. Follow the [installation guide](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). + +## Automated Quick Start + +This is an optimistic overview. Use the commands below for an in depth version. + +**Build the Docker image:** + +```bash +cd docker/navigation +./build.sh +``` + +This will: +- Clone the ros-navigation-autonomy-stack repository (jazzy branch) +- Build a Docker image with both ROS and DimOS dependencies +- Set up the environment for both systems + +Note that the build will take over 10 minutes and build an image over 30GiB. + +**Run the simulator to test it's working:** + +```bash +./start.sh --simulation +``` + +## Manual build + +Go to the docker dir and clone the ROS navigation stack. + +```bash +cd docker/navigation +git clone -b jazzy git@github.com:dimensionalOS/ros-navigation-autonomy-stack.git +``` + +Download a [Unity environment model for the Mecanum wheel platform](https://drive.google.com/drive/folders/1G1JYkccvoSlxyySuTlPfvmrWoJUO8oSs?usp=sharing) and unzip the files to `unity_models`. + +Alternativelly, extract `office_building_1` from LFS: + +```bash +tar -xf ../../data/.lfs/office_building_1.tar.gz +mv office_building_1 unity_models +``` + +Then, go back to the root and build the docker image: + +```bash +cd ../.. +docker compose -f docker/navigation/docker-compose.yml build +``` + +## On Real Hardware + +### Configure the WiFi + +[Read this](https://github.com/dimensionalOS/ros-navigation-autonomy-stack/tree/jazzy?tab=readme-ov-file#transmitting-data-over-wifi) to see how to configure the WiFi. + +### Configure the Livox Lidar + +The MID360_config.json file is automatically generated on container startup based on your environment variables (LIDAR_COMPUTER_IP and LIDAR_IP). + +### Copy Environment Template +```bash +cp .env.hardware .env +``` + +### Edit `.env` File + +Key configuration parameters: + +```bash +# Lidar Configuration +LIDAR_INTERFACE=eth0 # Your ethernet interface (find with: ip link show) +LIDAR_COMPUTER_IP=192.168.1.5 # Computer IP on the lidar subnet +LIDAR_GATEWAY=192.168.1.1 # Gateway IP address for the lidar subnet +LIDAR_IP=192.168.1.116 # Full IP address of your Mid-360 lidar +ROBOT_IP= # IP addres of robot on local network (if using WebRTC connection) + +# Motor Controller +MOTOR_SERIAL_DEVICE=/dev/ttyACM0 # Serial device (check with: ls /dev/ttyACM*) +``` + +### Start the Container + +Start the container and leave it open. + +```bash +./start.sh --hardware +``` + +It doesn't do anything by default. You have to run commands on it by `exec`-ing: + +```bash +docker exec -it dimos_hardware_container bash +``` + +### In the container + +In the container to run the full navigation stack you must run both the dimensional python runfile with connection module and the navigation stack. + +#### Dimensional Python + Connection Module + +For the Unitree G1 +```bash +dimos run unitree-g1 +ROBOT_IP=XX.X.X.XXX dimos run unitree-g1 # If ROBOT_IP env variable is not set in .env +``` + +#### Navigation Stack + +```bash +cd /ros2_ws/src/ros-navigation-autonomy-stack +./system_real_robot_with_route_planner.sh +``` + +Now you can place goal points/poses in RVIZ by clicking the "Goalpoint" button. The robot will navigate to the point, running both local and global planners for dynamic obstacle avoidance. diff --git a/docker/navigation/build.sh b/docker/navigation/build.sh new file mode 100755 index 0000000000..da0aa2de8c --- /dev/null +++ b/docker/navigation/build.sh @@ -0,0 +1,59 @@ +#!/bin/bash + +set -e + +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' + +echo -e "${GREEN}================================================${NC}" +echo -e "${GREEN}Building DimOS + ROS Autonomy Stack Docker Image${NC}" +echo -e "${GREEN}================================================${NC}" +echo "" + +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +cd "$SCRIPT_DIR" + +if [ ! -d "ros-navigation-autonomy-stack" ]; then + echo -e "${YELLOW}Cloning ros-navigation-autonomy-stack repository...${NC}" + git clone -b jazzy git@github.com:dimensionalOS/ros-navigation-autonomy-stack.git + echo -e "${GREEN}Repository cloned successfully!${NC}" +fi + +if [ ! -d "unity_models" ]; then + echo -e "${YELLOW}Using office_building_1 as the Unity environment...${NC}" + tar -xf ../../data/.lfs/office_building_1.tar.gz + mv office_building_1 unity_models +fi + +echo "" +echo -e "${YELLOW}Building Docker image with docker compose...${NC}" +echo "This will take a while as it needs to:" +echo " - Download base ROS Jazzy image" +echo " - Install ROS packages and dependencies" +echo " - Build the autonomy stack" +echo " - Build Livox-SDK2 for Mid-360 lidar" +echo " - Build SLAM dependencies (Sophus, Ceres, GTSAM)" +echo " - Install Python dependencies for DimOS" +echo "" + +cd ../.. + +docker compose -f docker/navigation/docker-compose.yml build + +echo "" +echo -e "${GREEN}================================${NC}" +echo -e "${GREEN}Docker image built successfully!${NC}" +echo -e "${GREEN}================================${NC}" +echo "" +echo "To run in SIMULATION mode:" +echo -e "${YELLOW} ./start.sh${NC}" +echo "" +echo "To run in HARDWARE mode:" +echo " 1. Configure your hardware settings in .env file" +echo " (copy from .env.hardware if needed)" +echo " 2. Run the hardware container:" +echo -e "${YELLOW} ./start.sh --hardware${NC}" +echo "" +echo "The script runs in foreground. Press Ctrl+C to stop." +echo "" diff --git a/docker/navigation/docker-compose.yml b/docker/navigation/docker-compose.yml new file mode 100644 index 0000000000..f26b7fbabd --- /dev/null +++ b/docker/navigation/docker-compose.yml @@ -0,0 +1,152 @@ +services: + # Simulation profile + dimos_simulation: + build: + context: ../.. + dockerfile: docker/navigation/Dockerfile + image: dimos_autonomy_stack:jazzy + container_name: dimos_simulation_container + profiles: ["", "simulation"] # Active by default (empty profile) AND with --profile simulation + + # Enable interactive terminal + stdin_open: true + tty: true + + # Network configuration - required for ROS communication + network_mode: host + + # Use nvidia runtime for GPU acceleration (falls back to runc if not available) + runtime: ${DOCKER_RUNTIME:-nvidia} + + # Environment variables for display and ROS + environment: + - DISPLAY=${DISPLAY} + - QT_X11_NO_MITSHM=1 + - NVIDIA_VISIBLE_DEVICES=${NVIDIA_VISIBLE_DEVICES:-all} + - NVIDIA_DRIVER_CAPABILITIES=${NVIDIA_DRIVER_CAPABILITIES:-all} + - ROS_DOMAIN_ID=${ROS_DOMAIN_ID:-42} + - ROBOT_CONFIG_PATH=${ROBOT_CONFIG_PATH:-mechanum_drive} + - ROBOT_IP=${ROBOT_IP:-} + - HARDWARE_MODE=false + + # Volume mounts + volumes: + # X11 socket for GUI + - /tmp/.X11-unix:/tmp/.X11-unix:rw + - ${HOME}/.Xauthority:/root/.Xauthority:rw + + # Mount Unity environment models (if available) + - ./unity_models:/ros2_ws/src/ros-navigation-autonomy-stack/src/base_autonomy/vehicle_simulator/mesh/unity:rw + + # Mount the autonomy stack source for development + - ./ros-navigation-autonomy-stack:/ros2_ws/src/ros-navigation-autonomy-stack:rw + + # Mount entire dimos directory for live development + - ../..:/workspace/dimos:rw + + # Mount bagfiles directory + - ./bagfiles:/ros2_ws/bagfiles:rw + + # Mount config files for easy editing + - ./config:/ros2_ws/config:rw + + # Device access (for joystick controllers) + devices: + - /dev/input:/dev/input + - /dev/dri:/dev/dri + + # Working directory + working_dir: /workspace/dimos + + # Command to run both ROS and DimOS + command: /usr/local/bin/run_both.sh + + # Hardware profile - for real robot + dimos_hardware: + build: + context: ../.. + dockerfile: docker/navigation/Dockerfile + image: dimos_autonomy_stack:jazzy + container_name: dimos_hardware_container + profiles: ["hardware"] + + # Enable interactive terminal + stdin_open: true + tty: true + + # Network configuration - MUST be host for hardware access + network_mode: host + + # Privileged mode REQUIRED for hardware access + privileged: true + + # Override runtime for GPU support + runtime: ${DOCKER_RUNTIME:-runc} + + # Hardware environment variables + environment: + - DISPLAY=${DISPLAY} + - QT_X11_NO_MITSHM=1 + - NVIDIA_VISIBLE_DEVICES=all + - NVIDIA_DRIVER_CAPABILITIES=all + - ROS_DOMAIN_ID=${ROS_DOMAIN_ID:-42} + - ROBOT_CONFIG_PATH=${ROBOT_CONFIG_PATH:-mechanum_drive} + - ROBOT_IP=${ROBOT_IP:-} + - HARDWARE_MODE=true + # Mid-360 Lidar configuration + - LIDAR_INTERFACE=${LIDAR_INTERFACE:-} + - LIDAR_COMPUTER_IP=${LIDAR_COMPUTER_IP:-192.168.1.5} + - LIDAR_GATEWAY=${LIDAR_GATEWAY:-192.168.1.1} + - LIDAR_IP=${LIDAR_IP:-192.168.1.116} + # Motor controller + - MOTOR_SERIAL_DEVICE=${MOTOR_SERIAL_DEVICE:-/dev/ttyACM0} + # Network optimization + - ENABLE_WIFI_BUFFER=true + + # Volume mounts + volumes: + # X11 socket for GUI + - /tmp/.X11-unix:/tmp/.X11-unix:rw + - ${HOME}/.Xauthority:/root/.Xauthority:rw + # Mount Unity environment models (optional for hardware) + - ./unity_models:/ros2_ws/src/ros-navigation-autonomy-stack/src/base_autonomy/vehicle_simulator/mesh/unity:rw + # Mount the autonomy stack source + - ./ros-navigation-autonomy-stack:/ros2_ws/src/ros-navigation-autonomy-stack:rw + # Mount entire dimos directory + - ../..:/workspace/dimos:rw + # Mount bagfiles directory + - ./bagfiles:/ros2_ws/bagfiles:rw + # Mount config files for easy editing + - ./config:/ros2_ws/config:rw + # Hardware-specific volumes + - ./logs:/ros2_ws/logs:rw + - /etc/localtime:/etc/localtime:ro + - /etc/timezone:/etc/timezone:ro + - /dev/bus/usb:/dev/bus/usb:rw + - /sys:/sys:ro + + # Device access for hardware + devices: + # Joystick controllers + - /dev/input:/dev/input + # GPU access + - /dev/dri:/dev/dri + # Motor controller serial ports + - ${MOTOR_SERIAL_DEVICE:-/dev/ttyACM0}:${MOTOR_SERIAL_DEVICE:-/dev/ttyACM0} + # Additional serial ports (can be enabled via environment) + # - /dev/ttyUSB0:/dev/ttyUSB0 + # - /dev/ttyUSB1:/dev/ttyUSB1 + # Cameras (can be enabled via environment) + # - /dev/video0:/dev/video0 + + # Working directory + working_dir: /workspace/dimos + + # Command - for hardware, we run bash as the user will launch specific scripts + command: bash + + # Capabilities for hardware operations + cap_add: + - NET_ADMIN # Network interface configuration + - SYS_ADMIN # System operations + - SYS_TIME # Time synchronization diff --git a/docker/navigation/ros_launch_wrapper.py b/docker/navigation/ros_launch_wrapper.py new file mode 100755 index 0000000000..dc28eabe72 --- /dev/null +++ b/docker/navigation/ros_launch_wrapper.py @@ -0,0 +1,195 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Wrapper script to properly handle ROS2 launch file shutdown. +This script ensures clean shutdown of all ROS nodes when receiving SIGINT. +""" + +import os +import signal +import subprocess +import sys +import time + + +class ROSLaunchWrapper: + def __init__(self): + self.ros_process = None + self.dimos_process = None + self.shutdown_in_progress = False + + def signal_handler(self, _signum, _frame): + """Handle shutdown signals gracefully""" + if self.shutdown_in_progress: + return + + self.shutdown_in_progress = True + print("\n\nShutdown signal received. Stopping services gracefully...") + + # Stop DimOS first + if self.dimos_process and self.dimos_process.poll() is None: + print("Stopping DimOS...") + self.dimos_process.terminate() + try: + self.dimos_process.wait(timeout=5) + print("DimOS stopped cleanly.") + except subprocess.TimeoutExpired: + print("Force stopping DimOS...") + self.dimos_process.kill() + self.dimos_process.wait() + + # Stop ROS - send SIGINT first for graceful shutdown + if self.ros_process and self.ros_process.poll() is None: + print("Stopping ROS nodes (this may take a moment)...") + + # Send SIGINT to trigger graceful ROS shutdown + self.ros_process.send_signal(signal.SIGINT) + + # Wait for graceful shutdown with timeout + try: + self.ros_process.wait(timeout=15) + print("ROS stopped cleanly.") + except subprocess.TimeoutExpired: + print("ROS is taking too long to stop. Sending SIGTERM...") + self.ros_process.terminate() + try: + self.ros_process.wait(timeout=5) + except subprocess.TimeoutExpired: + print("Force stopping ROS...") + self.ros_process.kill() + self.ros_process.wait() + + # Clean up any remaining processes + print("Cleaning up any remaining processes...") + cleanup_commands = [ + "pkill -f 'ros2' || true", + "pkill -f 'localPlanner' || true", + "pkill -f 'pathFollower' || true", + "pkill -f 'terrainAnalysis' || true", + "pkill -f 'sensorScanGeneration' || true", + "pkill -f 'vehicleSimulator' || true", + "pkill -f 'visualizationTools' || true", + "pkill -f 'far_planner' || true", + "pkill -f 'graph_decoder' || true", + ] + + for cmd in cleanup_commands: + subprocess.run(cmd, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + + print("All services stopped.") + sys.exit(0) + + def run(self): + # Register signal handlers + signal.signal(signal.SIGINT, self.signal_handler) + signal.signal(signal.SIGTERM, self.signal_handler) + + print("Starting ROS route planner and DimOS...") + + # Change to the ROS workspace directory + os.chdir("/ros2_ws/src/ros-navigation-autonomy-stack") + + # Start ROS route planner + print("Starting ROS route planner...") + self.ros_process = subprocess.Popen( + ["bash", "./system_simulation_with_route_planner.sh"], + preexec_fn=os.setsid, # Create new process group + ) + + print("Waiting for ROS to initialize...") + time.sleep(5) + + print("Starting DimOS navigation bot...") + + nav_bot_path = "/workspace/dimos/dimos/navigation/demo_ros_navigation.py" + venv_python = "/opt/dimos-venv/bin/python" + + if not os.path.exists(nav_bot_path): + print(f"ERROR: demo_ros_navigation.py not found at {nav_bot_path}") + nav_dir = "/workspace/dimos/dimos/navigation/" + if os.path.exists(nav_dir): + print(f"Contents of {nav_dir}:") + for item in os.listdir(nav_dir): + print(f" - {item}") + else: + print(f"Directory not found: {nav_dir}") + return + + if not os.path.exists(venv_python): + print(f"ERROR: venv Python not found at {venv_python}, using system Python") + return + + print(f"Using Python: {venv_python}") + print(f"Starting script: {nav_bot_path}") + + # Use the venv Python explicitly + try: + self.dimos_process = subprocess.Popen( + [venv_python, nav_bot_path], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + bufsize=1, + universal_newlines=True, + ) + + # Give it a moment to start and check if it's still running + time.sleep(2) + poll_result = self.dimos_process.poll() + if poll_result is not None: + # Process exited immediately + stdout, stderr = self.dimos_process.communicate(timeout=1) + print(f"ERROR: DimOS failed to start (exit code: {poll_result})") + if stdout: + print(f"STDOUT: {stdout}") + if stderr: + print(f"STDERR: {stderr}") + self.dimos_process = None + else: + print(f"DimOS started successfully (PID: {self.dimos_process.pid})") + + except Exception as e: + print(f"ERROR: Failed to start DimOS: {e}") + self.dimos_process = None + + if self.dimos_process: + print("Both systems are running. Press Ctrl+C to stop.") + else: + print("ROS is running (DimOS failed to start). Press Ctrl+C to stop.") + print("") + + # Wait for processes + try: + # Monitor both processes + while True: + # Check if either process has died + if self.ros_process.poll() is not None: + print("ROS process has stopped unexpectedly.") + self.signal_handler(signal.SIGTERM, None) + break + if self.dimos_process and self.dimos_process.poll() is not None: + print("DimOS process has stopped.") + # DimOS stopping is less critical, but we should still clean up ROS + self.signal_handler(signal.SIGTERM, None) + break + time.sleep(1) + except KeyboardInterrupt: + pass # Signal handler will take care of cleanup + + +if __name__ == "__main__": + wrapper = ROSLaunchWrapper() + wrapper.run() diff --git a/docker/navigation/run_both.sh b/docker/navigation/run_both.sh new file mode 100755 index 0000000000..24c480eaea --- /dev/null +++ b/docker/navigation/run_both.sh @@ -0,0 +1,147 @@ +#!/bin/bash +# Script to run both ROS route planner and DimOS together + +echo "Starting ROS route planner and DimOS..." + +# Variables for process IDs +ROS_PID="" +DIMOS_PID="" +SHUTDOWN_IN_PROGRESS=false + +# Function to handle cleanup +cleanup() { + if [ "$SHUTDOWN_IN_PROGRESS" = true ]; then + return + fi + SHUTDOWN_IN_PROGRESS=true + + echo "" + echo "Shutdown initiated. Stopping services..." + + # First, try to gracefully stop DimOS + if [ -n "$DIMOS_PID" ] && kill -0 $DIMOS_PID 2>/dev/null; then + echo "Stopping DimOS..." + kill -TERM $DIMOS_PID 2>/dev/null || true + + # Wait up to 5 seconds for DimOS to stop + for i in {1..10}; do + if ! kill -0 $DIMOS_PID 2>/dev/null; then + echo "DimOS stopped cleanly." + break + fi + sleep 0.5 + done + + # Force kill if still running + if kill -0 $DIMOS_PID 2>/dev/null; then + echo "Force stopping DimOS..." + kill -9 $DIMOS_PID 2>/dev/null || true + fi + fi + + # Then handle ROS - send SIGINT to the launch process group + if [ -n "$ROS_PID" ] && kill -0 $ROS_PID 2>/dev/null; then + echo "Stopping ROS nodes (this may take a moment)..." + + # Send SIGINT to the process group to properly trigger ROS shutdown + kill -INT -$ROS_PID 2>/dev/null || kill -INT $ROS_PID 2>/dev/null || true + + # Wait up to 15 seconds for graceful shutdown + for i in {1..30}; do + if ! kill -0 $ROS_PID 2>/dev/null; then + echo "ROS stopped cleanly." + break + fi + sleep 0.5 + done + + # If still running, send SIGTERM + if kill -0 $ROS_PID 2>/dev/null; then + echo "Sending SIGTERM to ROS..." + kill -TERM -$ROS_PID 2>/dev/null || kill -TERM $ROS_PID 2>/dev/null || true + sleep 2 + fi + + # Final resort: SIGKILL + if kill -0 $ROS_PID 2>/dev/null; then + echo "Force stopping ROS..." + kill -9 -$ROS_PID 2>/dev/null || kill -9 $ROS_PID 2>/dev/null || true + fi + fi + + # Clean up any remaining ROS2 processes + echo "Cleaning up any remaining processes..." + pkill -f "ros2" 2>/dev/null || true + pkill -f "localPlanner" 2>/dev/null || true + pkill -f "pathFollower" 2>/dev/null || true + pkill -f "terrainAnalysis" 2>/dev/null || true + pkill -f "sensorScanGeneration" 2>/dev/null || true + pkill -f "vehicleSimulator" 2>/dev/null || true + pkill -f "visualizationTools" 2>/dev/null || true + pkill -f "far_planner" 2>/dev/null || true + pkill -f "graph_decoder" 2>/dev/null || true + + echo "All services stopped." +} + +# Set up trap to call cleanup on exit +trap cleanup EXIT INT TERM + +# Start ROS route planner in background (in new process group) +echo "Starting ROS route planner..." +cd /ros2_ws/src/ros-navigation-autonomy-stack +setsid bash -c './system_simulation_with_route_planner.sh' & +ROS_PID=$! + +# Wait a bit for ROS to initialize +echo "Waiting for ROS to initialize..." +sleep 5 + +# Start DimOS +echo "Starting DimOS navigation bot..." + +# Check if the script exists +if [ ! -f "/workspace/dimos/dimos/navigation/demo_ros_navigation.py" ]; then + echo "ERROR: demo_ros_navigation.py not found at /workspace/dimos/dimos/navigation/demo_ros_navigation.py" + echo "Available files in /workspace/dimos/dimos/navigation/:" + ls -la /workspace/dimos/dimos/navigation/ 2>/dev/null || echo "Directory not found" +else + echo "Found demo_ros_navigation.py, activating virtual environment..." + if [ -f "/opt/dimos-venv/bin/activate" ]; then + source /opt/dimos-venv/bin/activate + echo "Python path: $(which python)" + echo "Python version: $(python --version)" + else + echo "WARNING: Virtual environment not found at /opt/dimos-venv, using system Python" + fi + + echo "Starting demo_ros_navigation.py..." + # Capture any startup errors + python /workspace/dimos/dimos/navigation/demo_ros_navigation.py 2>&1 & + DIMOS_PID=$! + + # Give it a moment to start and check if it's still running + sleep 2 + if kill -0 $DIMOS_PID 2>/dev/null; then + echo "DimOS started successfully with PID: $DIMOS_PID" + else + echo "ERROR: DimOS failed to start (process exited immediately)" + echo "Check the logs above for error messages" + DIMOS_PID="" + fi +fi + +echo "" +if [ -n "$DIMOS_PID" ]; then + echo "Both systems are running. Press Ctrl+C to stop." +else + echo "ROS is running (DimOS failed to start). Press Ctrl+C to stop." +fi +echo "" + +# Wait for processes +if [ -n "$DIMOS_PID" ]; then + wait $ROS_PID $DIMOS_PID 2>/dev/null || true +else + wait $ROS_PID 2>/dev/null || true +fi diff --git a/docker/navigation/setup.sh b/docker/navigation/setup.sh new file mode 100755 index 0000000000..5edf9abfd5 --- /dev/null +++ b/docker/navigation/setup.sh @@ -0,0 +1,706 @@ +#!/bin/bash +set -e +set -o pipefail + +################################################################################ +# DimOS Navigation Setup Script +# +# Usage: ./setup.sh [OPTIONS] +# --install-dir DIR Installation directory (default: ~/dimos) +# --skip-docker Skip Docker installation +# --skip-build Skip building Docker images +# --help Show this help message +# +################################################################################ + +# Color codes for output +readonly RED='\033[0;31m' +readonly GREEN='\033[0;32m' +readonly YELLOW='\033[1;33m' +readonly BLUE='\033[0;34m' +readonly CYAN='\033[0;36m' +readonly NC='\033[0m' +readonly BOLD='\033[1m' + +# Configuration +INSTALL_DIR="${HOME}/dimos" +SKIP_DOCKER=false +SKIP_BUILD=false +LOG_FILE="${HOME}/dimos-setup.log" +SCRIPT_START_TIME=$(date +%s) + +# Step tracking +CURRENT_STEP=0 +TOTAL_STEPS=8 + +################################################################################ +# Utility Functions +################################################################################ + +log() { + local level="$1" + shift + local message="$*" + local timestamp=$(date '+%Y-%m-%d %H:%M:%S') + echo "[${timestamp}] [${level}] ${message}" >> "${LOG_FILE}" +} + +print_banner() { + echo -e "${CYAN}${BOLD}" + cat << "EOF" + ____ _ __ ___ ____ _____ + / __ \(_) |/ / / __ \/ ___/ + / / / / / /|_/ / / / / /\__ \ + / /_/ / / / / / / /_/ /___/ / +/_____/_/_/ /_/ \____//____/ + + Navigation Setup Script +EOF + echo -e "${NC}" + echo -e "${BLUE}This script will set up your Ubuntu system for DimOS Navigation${NC}" + echo -e "${BLUE}Installation may take 20-30 minutes depending on your connection${NC}" + echo "" +} + +step() { + CURRENT_STEP=$((CURRENT_STEP + 1)) + echo "" + echo -e "${CYAN}${BOLD}[Step ${CURRENT_STEP}/${TOTAL_STEPS}]${NC} ${BOLD}$1${NC}" + log "INFO" "Step ${CURRENT_STEP}/${TOTAL_STEPS}: $1" +} + +info() { + echo -e "${BLUE}ℹ${NC} $1" + log "INFO" "$1" +} + +success() { + echo -e "${GREEN}✓${NC} $1" + log "SUCCESS" "$1" +} + +warning() { + echo -e "${YELLOW}⚠${NC} $1" + log "WARNING" "$1" +} + +error() { + echo -e "${RED}✗${NC} $1" + log "ERROR" "$1" +} + +fatal() { + error "$1" + echo "" + echo -e "${RED}${BOLD}Installation failed.${NC}" + echo -e "Check the log file for details: ${LOG_FILE}" + echo "" + exit 1 +} + +confirm() { + local prompt="$1" + local default="${2:-n}" + local response + + if [[ "${default}" == "y" ]]; then + prompt="${prompt} [Y/n]: " + else + prompt="${prompt} [y/N]: " + fi + + read -r -p "$(echo -e "${YELLOW}${prompt}${NC}")" response + response=${response:-${default}} + + [[ "${response,,}" =~ ^y(es)?$ ]] +} + +check_command() { + command -v "$1" >/dev/null 2>&1 +} + +################################################################################ +# Pre-flight Checks +################################################################################ + +preflight_checks() { + step "Running pre-flight checks" + + if [[ "$(uname -s)" != "Linux" ]]; then + fatal "This script is designed for Linux systems only" + fi + + if ! check_command apt-get; then + fatal "This script requires Ubuntu or Debian-based system" + fi + + if [[ -f /etc/os-release ]]; then + source /etc/os-release + info "Detected: ${PRETTY_NAME}" + + OS_VERSION_CODENAME="${VERSION_CODENAME:-}" + + VERSION_NUM=$(echo "${VERSION_ID:-0}" | cut -d. -f1) + if ! [[ "${VERSION_NUM}" =~ ^[0-9]+$ ]]; then + warning "Unable to determine Ubuntu version number" + VERSION_NUM=0 + fi + + if [[ "${VERSION_NUM}" -ne 0 ]] && [[ "${VERSION_NUM}" -lt 24 ]]; then + warning "Ubuntu 24.04 is required. You have ${VERSION_ID}" + if ! confirm "Continue anyway?"; then + exit 0 + fi + fi + fi + + if [[ $EUID -eq 0 ]]; then + fatal "This script should NOT be run as root. Run as a regular user with sudo access." + fi + + if ! sudo -n true 2>/dev/null; then + info "This script requires sudo access. You may be prompted for your password." + if ! sudo true; then + fatal "Failed to obtain sudo access" + fi + fi + + local target_dir=$(dirname "${INSTALL_DIR}") + mkdir -p "${target_dir}" 2>/dev/null || target_dir="${HOME}" + local available_space=$(df -BG "${target_dir}" 2>/dev/null | awk 'NR==2 {print $4}' | sed 's/G//' || echo "0") + info "Available disk space at ${target_dir}: ${available_space}GB" + if [[ "${available_space}" -lt 50 ]]; then + warning "Low disk space detected. At least 50GB is recommended." + warning "Docker images and builds will require significant space." + if ! confirm "Continue anyway?"; then + exit 0 + fi + fi + + info "Checking internet connectivity..." + if ! ping -c 1 8.8.8.8 >/dev/null 2>&1; then + fatal "No internet connection detected. Please check your network." + fi + + success "Pre-flight checks passed" +} + +################################################################################ +# System Setup +################################################################################ + +update_system() { + step "Updating system packages" + + info "Running apt-get update..." + if sudo apt-get update -y >> "${LOG_FILE}" 2>&1; then + success "Package lists updated" + else + warning "Package update had some warnings (check log)" + fi +} + +install_base_tools() { + step "Installing base tools" + + local packages=( + "git" + "ssh" + "zip" + "curl" + "wget" + "jq" + "nano" + "vim" + "htop" + "ca-certificates" + "gnupg" + ) + + info "Installing: ${packages[*]}" + + if sudo DEBIAN_FRONTEND=noninteractive apt-get install -y "${packages[@]}" >> "${LOG_FILE}" 2>&1; then + success "Base tools installed" + else + fatal "Failed to install base tools" + fi + + if check_command ufw; then + info "Configuring firewall (UFW)..." + if sudo ufw status | grep -q "Status: active"; then + info "UFW is active, ensuring SSH is allowed..." + sudo ufw allow 22/tcp >> "${LOG_FILE}" 2>&1 || true + else + info "UFW is inactive, skipping firewall configuration" + fi + fi +} + +################################################################################ +# Docker Installation +################################################################################ + +install_docker() { + if [[ "${SKIP_DOCKER}" == true ]]; then + info "Skipping Docker installation (--skip-docker flag)" + return + fi + + step "Installing Docker" + + if check_command docker; then + local docker_version=$(docker --version 2>/dev/null || echo "unknown") + success "Docker is already installed: ${docker_version}" + + if docker compose version >/dev/null 2>&1; then + success "Docker Compose plugin is available" + else + warning "Docker Compose plugin not found, will attempt to install" + fi + + if ! confirm "Reinstall Docker anyway?" "n"; then + return + fi + fi + + info "Adding Docker's official GPG key..." + sudo install -m 0755 -d /etc/apt/keyrings + + if curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo gpg --dearmor -o /etc/apt/keyrings/docker.gpg 2>> "${LOG_FILE}"; then + sudo chmod a+r /etc/apt/keyrings/docker.gpg + success "Docker GPG key added" + else + fatal "Failed to add Docker GPG key" + fi + + info "Adding Docker repository..." + local version_codename="${OS_VERSION_CODENAME}" + if [[ -z "${version_codename}" ]] && [[ -f /etc/os-release ]]; then + version_codename=$(. /etc/os-release && echo "$VERSION_CODENAME") + fi + + echo \ + "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/ubuntu \ + ${version_codename} stable" | \ + sudo tee /etc/apt/sources.list.d/docker.list > /dev/null + + success "Docker repository added" + + info "Updating package lists..." + sudo apt-get update -y >> "${LOG_FILE}" 2>&1 + + info "Installing Docker packages (this may take a few minutes)..." + local docker_packages=( + "docker-ce" + "docker-ce-cli" + "containerd.io" + "docker-buildx-plugin" + "docker-compose-plugin" + ) + + if sudo DEBIAN_FRONTEND=noninteractive apt-get install -y "${docker_packages[@]}" >> "${LOG_FILE}" 2>&1; then + success "Docker installed successfully" + else + fatal "Failed to install Docker packages" + fi + + info "Configuring Docker group permissions..." + + if ! getent group docker >/dev/null; then + sudo groupadd docker + fi + + if sudo usermod -aG docker "${USER}"; then + success "User ${USER} added to docker group" + else + warning "Failed to add user to docker group" + fi + + info "Verifying Docker installation..." + if sudo docker run --rm hello-world >> "${LOG_FILE}" 2>&1; then + success "Docker is working correctly" + else + warning "Docker verification failed, but installation may still be successful" + fi + + warning "Docker group changes require logout/login to take effect" + info "For now, we'll use 'sudo docker' commands" +} + +################################################################################ +# Git LFS Setup +################################################################################ + +install_git_lfs() { + step "Installing Git LFS" + + if check_command git-lfs; then + success "Git LFS is already installed" + return + fi + + info "Adding Git LFS repository..." + if curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash >> "${LOG_FILE}" 2>&1; then + success "Git LFS repository added" + else + fatal "Failed to add Git LFS repository" + fi + + info "Installing Git LFS..." + if sudo apt-get install -y git-lfs >> "${LOG_FILE}" 2>&1; then + success "Git LFS installed" + else + fatal "Failed to install Git LFS" + fi + + info "Configuring Git LFS..." + if git lfs install >> "${LOG_FILE}" 2>&1; then + success "Git LFS configured" + else + warning "Git LFS configuration had issues (may already be configured)" + fi +} + +################################################################################ +# SSH Key Configuration +################################################################################ + +setup_ssh_keys() { + step "Configuring GitHub SSH access" + + info "Testing GitHub SSH connection..." + if timeout 10 ssh -o ConnectTimeout=10 -o StrictHostKeyChecking=accept-new -T git@github.com 2>&1 | grep -q "successfully authenticated"; then + success "GitHub SSH access is already configured" + return + fi + + warning "GitHub SSH access is not configured" + echo "" + echo -e "${YELLOW}${BOLD}SSH Key Setup Required${NC}" + echo "" + echo "To clone the private DimOS repository, you need SSH access to GitHub." + echo "" + + if [[ -f "${HOME}/.ssh/id_rsa.pub" ]] || [[ -f "${HOME}/.ssh/id_ed25519.pub" ]]; then + info "Existing SSH key found" + echo "" + + if [[ -f "${HOME}/.ssh/id_ed25519.pub" ]]; then + echo -e "${CYAN}Your public key (id_ed25519.pub):${NC}" + cat "${HOME}/.ssh/id_ed25519.pub" + elif [[ -f "${HOME}/.ssh/id_rsa.pub" ]]; then + echo -e "${CYAN}Your public key (id_rsa.pub):${NC}" + cat "${HOME}/.ssh/id_rsa.pub" + fi + + echo "" + echo -e "${YELLOW}Please add this key to your GitHub account:${NC}" + echo " 1. Go to: https://github.com/settings/keys" + echo " 2. Click 'New SSH key'" + echo " 3. Paste the key above" + echo " 4. Click 'Add SSH key'" + echo "" + else + info "No SSH key found. Let's create one." + echo "" + + if confirm "Generate a new SSH key?" "y"; then + local email + echo -n "Enter your GitHub email address: " + read -r email + + info "Generating SSH key..." + if ssh-keygen -t ed25519 -C "${email}" -f "${HOME}/.ssh/id_ed25519" -N "" >> "${LOG_FILE}" 2>&1; then + success "SSH key generated" + + eval "$(ssh-agent -s)" > /dev/null + if ssh-add "${HOME}/.ssh/id_ed25519" 2>> "${LOG_FILE}"; then + success "SSH key added to agent" + else + warning "Could not add key to ssh-agent (non-critical)" + fi + + echo "" + echo -e "${CYAN}Your new public key:${NC}" + cat "${HOME}/.ssh/id_ed25519.pub" + echo "" + echo -e "${YELLOW}Please add this key to your GitHub account:${NC}" + echo " 1. Go to: https://github.com/settings/keys" + echo " 2. Click 'New SSH key'" + echo " 3. Paste the key above" + echo " 4. Click 'Add SSH key'" + echo "" + else + fatal "Failed to generate SSH key" + fi + else + echo "" + error "SSH key is required to continue" + echo "Please set up SSH access manually and run this script again." + exit 1 + fi + fi + + echo "" + if ! confirm "Have you added the SSH key to GitHub?" "n"; then + echo "" + warning "Setup paused. Please add the SSH key and run this script again." + exit 0 + fi + + info "Testing GitHub SSH connection..." + if timeout 10 ssh -o ConnectTimeout=10 -o StrictHostKeyChecking=accept-new -T git@github.com 2>&1 | grep -q "successfully authenticated"; then + success "GitHub SSH access verified!" + else + error "GitHub SSH connection failed" + echo "" + echo "Please verify:" + echo " 1. The SSH key was added to GitHub correctly" + echo " 2. You're using the correct GitHub account" + echo " 3. Try: ssh -T git@github.com" + echo "" + if ! confirm "Continue anyway?" "n"; then + exit 1 + fi + fi +} + +################################################################################ +# Repository Setup +################################################################################ + +clone_repository() { + step "Cloning DimOS repository" + + if [[ -d "${INSTALL_DIR}" ]]; then + if [[ -d "${INSTALL_DIR}/.git" ]]; then + success "Repository already exists at ${INSTALL_DIR}" + + local remote_url=$(git -C "${INSTALL_DIR}" remote get-url origin 2>/dev/null || echo "") + if [[ "${remote_url}" =~ "dimos" ]]; then + info "Existing repository verified" + return + else + warning "Directory exists but doesn't appear to be the DimOS repo" + if ! confirm "Remove and re-clone?" "n"; then + fatal "Cannot proceed with existing directory" + fi + rm -rf "${INSTALL_DIR}" + fi + else + warning "Directory ${INSTALL_DIR} exists but is not a git repository" + if ! confirm "Remove and re-clone?" "n"; then + fatal "Cannot proceed with existing directory" + fi + rm -rf "${INSTALL_DIR}" + fi + fi + + info "Cloning to ${INSTALL_DIR}..." + if git clone git@github.com:dimensionalOS/dimos.git "${INSTALL_DIR}" >> "${LOG_FILE}" 2>&1; then + success "Repository cloned successfully" + else + fatal "Failed to clone repository. Check your SSH access." + fi + + info "Pulling Git LFS files (this may take several minutes)..." + if git -C "${INSTALL_DIR}" lfs pull >> "${LOG_FILE}" 2>&1; then + success "LFS files downloaded" + else + warning "Some LFS files may not have downloaded correctly" + fi +} + +################################################################################ +# Build and Launch +################################################################################ + +build_docker_images() { + if [[ "${SKIP_BUILD}" == true ]]; then + info "Skipping Docker build (--skip-build flag)" + return + fi + + step "Building Docker images" + + local build_dir="${INSTALL_DIR}/docker/navigation" + if [[ ! -d "${build_dir}" ]]; then + fatal "Directory not found: ${build_dir}" + fi + + if [[ ! -f "${build_dir}/build.sh" ]]; then + fatal "build.sh not found in ${build_dir}" + fi + + echo "" + warning "Building Docker images will take 10-15 minutes and download ~30GB" + info "This step will:" + echo " • Clone the ROS navigation autonomy stack" + echo " • Build a large Docker image with ROS Jazzy" + echo " • Install all dependencies" + echo "" + + if ! confirm "Start the build now?" "y"; then + warning "Build skipped. You can build later with:" + echo " cd ${build_dir}" + echo " ./build.sh" + return + fi + + info "Starting build process..." + echo "" + + pushd "${build_dir}" >> "${LOG_FILE}" 2>&1 || fatal "Failed to change to ${build_dir}" + + ./build.sh 2>&1 | tee -a "${LOG_FILE}" + local build_status=${PIPESTATUS[0]} + + popd >> "${LOG_FILE}" 2>&1 || true + + if [[ ${build_status} -eq 0 ]]; then + success "Docker images built successfully" + else + fatal "Docker build failed. Check the log for details." + fi +} + +################################################################################ +# Completion +################################################################################ + +print_summary() { + local elapsed=$(($(date +%s) - SCRIPT_START_TIME)) + local minutes=$((elapsed / 60)) + local seconds=$((elapsed % 60)) + + echo "" + echo "" + echo -e "${GREEN}${BOLD}╔══════════════════════════════════════════════════════════╗${NC}" + echo -e "${GREEN}${BOLD}║ ║${NC}" + echo -e "${GREEN}${BOLD}║ Setup completed successfully! 🎉 ║${NC}" + echo -e "${GREEN}${BOLD}║ ║${NC}" + echo -e "${GREEN}${BOLD}╚══════════════════════════════════════════════════════════╝${NC}" + echo "" + echo -e "${CYAN}Installation time: ${minutes}m ${seconds}s${NC}" + echo -e "${CYAN}Installation directory: ${INSTALL_DIR}${NC}" + echo -e "${CYAN}Log file: ${LOG_FILE}${NC}" + echo "" + echo -e "${BOLD}Next steps:${NC}" + echo "" + echo " 1. If Docker commands failed, log out and back in for group changes" + echo " Or run: newgrp docker" + echo "" + echo " 2. Navigate to the project:" + echo " cd ${INSTALL_DIR}/docker/navigation" + echo "" + echo " 3. Start the demo:" + echo " ./start.sh --all" + echo "" + echo " 4. Or get an interactive shell:" + echo " ./start.sh" + echo "" + echo -e "${CYAN}For more information, see the README.md in docker/navigation/${NC}" + echo "" +} + +################################################################################ +# Argument Parsing +################################################################################ + +parse_arguments() { + while [[ $# -gt 0 ]]; do + case $1 in + --install-dir) + if [[ -z "$2" ]] || [[ "$2" == --* ]]; then + error "Error: --install-dir requires a directory path" + echo "Run '$0 --help' for usage information" + exit 1 + fi + INSTALL_DIR="$2" + shift 2 + ;; + --skip-docker) + SKIP_DOCKER=true + shift + ;; + --skip-build) + SKIP_BUILD=true + shift + ;; + --help) + print_banner + cat << EOF +Usage: $0 [OPTIONS] + +Options: + --install-dir DIR Installation directory (default: ~/dimos) + --skip-docker Skip Docker installation + --skip-build Skip building Docker images + --help Show this help message + +Examples: + $0 # Full installation + $0 --install-dir /opt/dimos # Install to custom directory + $0 --skip-docker # Skip Docker installation + $0 --skip-docker --skip-build # Only clone repository + +After installation, navigate to the project and start the demo: + cd ~/dimos/docker/navigation + ./start.sh --all + +For more information, visit: + https://github.com/dimensionalOS/dimos + +EOF + exit 0 + ;; + *) + error "Unknown option: $1" + echo "Run '$0 --help' for usage information" + exit 1 + ;; + esac + done +} + +################################################################################ +# Main +################################################################################ + +main() { + log "INFO" "DimOS Navigation Setup Script started" + log "INFO" "User: ${USER}" + log "INFO" "Install directory: ${INSTALL_DIR}" + + print_banner + + echo -e "${YELLOW}This script will:${NC}" + echo " • Update your system" + echo " • Install Docker and dependencies" + echo " • Configure Git LFS" + echo " • Set up GitHub SSH access" + echo " • Clone the DimOS repository" + echo " • Build Docker images (~30GB, 10-15 minutes)" + echo "" + + if ! confirm "Continue with installation?" "y"; then + echo "Installation cancelled." + exit 0 + fi + + preflight_checks + update_system + install_base_tools + install_docker + install_git_lfs + setup_ssh_keys + clone_repository + build_docker_images + + print_summary + + log "INFO" "Setup completed successfully" +} + +parse_arguments "$@" +main diff --git a/docker/navigation/start.sh b/docker/navigation/start.sh new file mode 100755 index 0000000000..4347006957 --- /dev/null +++ b/docker/navigation/start.sh @@ -0,0 +1,234 @@ +#!/bin/bash + +set -e + +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +RED='\033[0;31m' +NC='\033[0m' + +# Parse command line arguments +MODE="simulation" +while [[ $# -gt 0 ]]; do + case $1 in + --hardware) + MODE="hardware" + shift + ;; + --simulation) + MODE="simulation" + shift + ;; + --help|-h) + echo "Usage: $0 [OPTIONS]" + echo "" + echo "Options:" + echo " --simulation Start simulation container (default)" + echo " --hardware Start hardware container for real robot" + echo " --help, -h Show this help message" + echo "" + echo "Examples:" + echo " $0 # Start simulation container" + echo " $0 --hardware # Start hardware container" + echo "" + echo "Press Ctrl+C to stop the container" + exit 0 + ;; + *) + echo -e "${RED}Unknown option: $1${NC}" + echo "Run '$0 --help' for usage information" + exit 1 + ;; + esac +done + +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +cd "$SCRIPT_DIR" + +echo -e "${GREEN}================================================${NC}" +echo -e "${GREEN}Starting DimOS Docker Container${NC}" +echo -e "${GREEN}Mode: ${MODE}${NC}" +echo -e "${GREEN}================================================${NC}" +echo "" + +# Hardware-specific checks +if [ "$MODE" = "hardware" ]; then + # Check if .env file exists + if [ ! -f ".env" ]; then + if [ -f ".env.hardware" ]; then + echo -e "${YELLOW}Creating .env from .env.hardware template...${NC}" + cp .env.hardware .env + echo -e "${RED}Please edit .env file with your hardware configuration:${NC}" + echo " - LIDAR_IP: Full IP address of your Mid-360 lidar" + echo " - LIDAR_COMPUTER_IP: IP address of this computer on the lidar subnet" + echo " - LIDAR_INTERFACE: Network interface connected to lidar" + echo " - MOTOR_SERIAL_DEVICE: Serial device for motor controller" + echo "" + echo "After editing, run this script again." + exit 1 + fi + fi + + # Source the environment file + if [ -f ".env" ]; then + set -a + source .env + set +a + + # Check for required environment variables + if [ -z "$LIDAR_IP" ] || [ "$LIDAR_IP" = "192.168.1.116" ]; then + echo -e "${YELLOW}Warning: LIDAR_IP still using default value in .env${NC}" + echo "Set LIDAR_IP to the actual IP address of your Mid-360 lidar" + fi + + if [ -z "$LIDAR_GATEWAY" ]; then + echo -e "${YELLOW}Warning: LIDAR_GATEWAY not configured in .env${NC}" + echo "Set LIDAR_GATEWAY to the gateway IP address for the lidar subnet" + fi + + # Check for robot IP configuration + if [ -n "$ROBOT_IP" ]; then + echo -e "${GREEN}Robot IP configured: $ROBOT_IP${NC}" + else + echo -e "${YELLOW}Note: ROBOT_IP not configured in .env${NC}" + echo "Set ROBOT_IP if using network connection to robot" + fi + + # Check for serial devices + echo -e "${GREEN}Checking for serial devices...${NC}" + if [ -e "${MOTOR_SERIAL_DEVICE:-/dev/ttyACM0}" ]; then + echo -e " Found device at: ${MOTOR_SERIAL_DEVICE:-/dev/ttyACM0}" + else + echo -e "${YELLOW} Warning: Device not found at ${MOTOR_SERIAL_DEVICE:-/dev/ttyACM0}${NC}" + echo -e "${YELLOW} Available serial devices:${NC}" + ls /dev/ttyACM* /dev/ttyUSB* 2>/dev/null || echo " None found" + fi + + # Check network interface for lidar + echo -e "${GREEN}Checking network interface for lidar...${NC}" + + # Get available ethernet interfaces + AVAILABLE_ETH="" + for i in /sys/class/net/*; do + if [ "$(cat $i/type 2>/dev/null)" = "1" ] && [ "$i" != "/sys/class/net/lo" ]; then + interface=$(basename $i) + if [ -z "$AVAILABLE_ETH" ]; then + AVAILABLE_ETH="$interface" + else + AVAILABLE_ETH="$AVAILABLE_ETH, $interface" + fi + fi + done + + if [ -z "$LIDAR_INTERFACE" ]; then + # No interface configured + echo -e "${RED}================================================================${NC}" + echo -e "${RED} ERROR: ETHERNET INTERFACE NOT CONFIGURED!${NC}" + echo -e "${RED}================================================================${NC}" + echo -e "${YELLOW} LIDAR_INTERFACE not set in .env file${NC}" + echo "" + echo -e "${YELLOW} Your ethernet interfaces: ${GREEN}${AVAILABLE_ETH}${NC}" + echo "" + echo -e "${YELLOW} ACTION REQUIRED:${NC}" + echo -e " 1. Edit the .env file and set:" + echo -e " ${GREEN}LIDAR_INTERFACE=${NC}" + echo -e " 2. Run this script again" + echo -e "${RED}================================================================${NC}" + exit 1 + elif ! ip link show "$LIDAR_INTERFACE" &>/dev/null; then + # Interface configured but doesn't exist + echo -e "${RED}================================================================${NC}" + echo -e "${RED} ERROR: ETHERNET INTERFACE '$LIDAR_INTERFACE' NOT FOUND!${NC}" + echo -e "${RED}================================================================${NC}" + echo -e "${YELLOW} You configured: LIDAR_INTERFACE=$LIDAR_INTERFACE${NC}" + echo -e "${YELLOW} But this interface doesn't exist on your system${NC}" + echo "" + echo -e "${YELLOW} Your ethernet interfaces: ${GREEN}${AVAILABLE_ETH}${NC}" + echo "" + echo -e "${YELLOW} ACTION REQUIRED:${NC}" + echo -e " 1. Edit the .env file and change to one of your interfaces:" + echo -e " ${GREEN}LIDAR_INTERFACE=${NC}" + echo -e " 2. Run this script again" + echo -e "${RED}================================================================${NC}" + exit 1 + else + # Interface exists and is configured correctly + echo -e " ${GREEN}✓${NC} Network interface $LIDAR_INTERFACE found" + echo -e " ${GREEN}✓${NC} Will configure static IP: ${LIDAR_COMPUTER_IP}/24" + echo -e " ${GREEN}✓${NC} Will set gateway: ${LIDAR_GATEWAY}" + echo "" + echo -e "${YELLOW} Network configuration mode: Static IP (Manual)${NC}" + echo -e " This will temporarily replace DHCP with static IP assignment" + echo -e " Configuration reverts when container stops" + fi + fi + +fi + +# Check if unified image exists +if ! docker images | grep -q "dimos_autonomy_stack.*jazzy"; then + echo -e "${YELLOW}Docker image not found. Building...${NC}" + ./build.sh +fi + +# Check for X11 display +if [ -z "$DISPLAY" ]; then + echo -e "${YELLOW}Warning: DISPLAY not set. GUI applications may not work.${NC}" + export DISPLAY=:0 +fi + +# Allow X11 connections from Docker +echo -e "${GREEN}Configuring X11 access...${NC}" +xhost +local:docker 2>/dev/null || true + +cleanup() { + xhost -local:docker 2>/dev/null || true +} + +trap cleanup EXIT + +# Check for NVIDIA runtime +if docker info 2>/dev/null | grep -q nvidia; then + echo -e "${GREEN}NVIDIA Docker runtime detected${NC}" + export DOCKER_RUNTIME=nvidia + if [ "$MODE" = "hardware" ]; then + export NVIDIA_VISIBLE_DEVICES=all + export NVIDIA_DRIVER_CAPABILITIES=all + fi +else + echo -e "${YELLOW}NVIDIA Docker runtime not found. GPU acceleration disabled.${NC}" + export DOCKER_RUNTIME=runc +fi + +# Set container name for reference +if [ "$MODE" = "hardware" ]; then + CONTAINER_NAME="dimos_hardware_container" +else + CONTAINER_NAME="dimos_simulation_container" +fi + +# Print helpful info before starting +echo "" +if [ "$MODE" = "hardware" ]; then + echo "Hardware mode - Interactive shell" + echo "" + echo -e "${GREEN}=================================================${NC}" + echo -e "${GREEN}The container is running. Exec in to run scripts:${NC}" + echo -e " ${YELLOW}docker exec -it ${CONTAINER_NAME} bash${NC}" + echo -e "${GREEN}=================================================${NC}" +else + echo "Simulation mode - Auto-starting ROS simulation and DimOS" + echo "" + echo "The container will automatically run:" + echo " - ROS navigation stack with route planner" + echo " - DimOS navigation demo" + echo "" + echo "To enter the container from another terminal:" + echo " docker exec -it ${CONTAINER_NAME} bash" +fi + +if [ "$MODE" = "hardware" ]; then + docker compose -f docker-compose.yml --profile hardware up +else + docker compose -f docker-compose.yml up +fi diff --git a/docker/python/Dockerfile b/docker/python/Dockerfile new file mode 100644 index 0000000000..6fbd5545e5 --- /dev/null +++ b/docker/python/Dockerfile @@ -0,0 +1,52 @@ +ARG FROM_IMAGE=ghcr.io/dimensionalos/ros:dev +FROM ${FROM_IMAGE} + +# Install basic requirements +RUN apt-get update +RUN apt-get install -y \ + python-is-python3 \ + curl \ + gnupg2 \ + lsb-release \ + python3-pip \ + clang \ + portaudio19-dev \ + git \ + mesa-utils \ + libgl1-mesa-glx \ + libgl1-mesa-dri \ + software-properties-common \ + libxcb1-dev \ + libxcb-keysyms1-dev \ + libxcb-util0-dev \ + libxcb-icccm4-dev \ + libxcb-image0-dev \ + libxcb-randr0-dev \ + libxcb-shape0-dev \ + libxcb-xinerama0-dev \ + libxcb-xkb-dev \ + libxkbcommon-x11-dev \ + qtbase5-dev \ + qtchooser \ + qt5-qmake \ + qtbase5-dev-tools \ + supervisor \ + iproute2 # for LCM networking system config \ + liblcm-dev + +# Fix distutils-installed packages that block pip upgrades +RUN apt-get purge -y python3-blinker python3-sympy python3-oauthlib || true + +# Install UV for fast Python package management +ENV UV_SYSTEM_PYTHON=1 +RUN curl -LsSf https://astral.sh/uv/install.sh | sh +ENV PATH="/root/.local/bin:$PATH" + +WORKDIR /app + +# Copy entire project first to ensure proper package installation +COPY . /app/ + +# Install dependencies with UV (10-100x faster than pip) +RUN uv pip install --upgrade 'pip>=24' 'setuptools>=70' 'wheel' 'packaging>=24' && \ + uv pip install '.[cpu]' diff --git a/docker/ros/Dockerfile b/docker/ros/Dockerfile new file mode 100644 index 0000000000..2dc2b5dbb7 --- /dev/null +++ b/docker/ros/Dockerfile @@ -0,0 +1,91 @@ +ARG FROM_IMAGE=ubuntu:22.04 +FROM ${FROM_IMAGE} + +# Avoid prompts from apt +ENV DEBIAN_FRONTEND=noninteractive + +# Set locale +RUN apt-get update && apt-get install -y locales && \ + locale-gen en_US en_US.UTF-8 && \ + update-locale LC_ALL=en_US.UTF-8 LANG=en_US.UTF-8 +ENV LANG=en_US.UTF-8 + +# Set ROS distro +ENV ROS_DISTRO=humble + +# Install basic requirements +RUN apt-get update +RUN apt-get install -y \ + curl \ + gnupg2 \ + lsb-release \ + python3-pip \ + clang \ + portaudio19-dev \ + git \ + mesa-utils \ + libgl1-mesa-glx \ + libgl1-mesa-dri \ + software-properties-common \ + libxcb1-dev \ + libxcb-keysyms1-dev \ + libxcb-util0-dev \ + libxcb-icccm4-dev \ + libxcb-image0-dev \ + libxcb-randr0-dev \ + libxcb-shape0-dev \ + libxcb-xinerama0-dev \ + libxcb-xkb-dev \ + libxkbcommon-x11-dev \ + qtbase5-dev \ + qtchooser \ + qt5-qmake \ + qtbase5-dev-tools \ + supervisor + +# Install specific numpy version first +RUN pip install 'numpy<2.0.0' + +# Add ROS2 apt repository +RUN curl -sSL https://raw.githubusercontent.com/ros/rosdistro/master/ros.key -o /usr/share/keyrings/ros-archive-keyring.gpg && \ + echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/ros-archive-keyring.gpg] http://packages.ros.org/ros2/ubuntu $(lsb_release -cs) main" | tee /etc/apt/sources.list.d/ros2.list > /dev/null + +# Install ROS2 packages and dependencies +RUN apt-get update && apt-get install -y \ + ros-${ROS_DISTRO}-desktop \ + ros-${ROS_DISTRO}-ros-base \ + ros-${ROS_DISTRO}-image-tools \ + ros-${ROS_DISTRO}-compressed-image-transport \ + ros-${ROS_DISTRO}-vision-msgs \ + ros-${ROS_DISTRO}-rviz2 \ + ros-${ROS_DISTRO}-rqt \ + ros-${ROS_DISTRO}-rqt-common-plugins \ + ros-${ROS_DISTRO}-twist-mux \ + ros-${ROS_DISTRO}-joy \ + ros-${ROS_DISTRO}-teleop-twist-joy \ + ros-${ROS_DISTRO}-navigation2 \ + ros-${ROS_DISTRO}-nav2-bringup \ + ros-${ROS_DISTRO}-nav2-amcl \ + ros-${ROS_DISTRO}-nav2-map-server \ + ros-${ROS_DISTRO}-nav2-util \ + ros-${ROS_DISTRO}-pointcloud-to-laserscan \ + ros-${ROS_DISTRO}-slam-toolbox \ + ros-${ROS_DISTRO}-foxglove-bridge \ + python3-rosdep \ + python3-rosinstall \ + python3-rosinstall-generator \ + python3-wstool \ + python3-colcon-common-extensions \ + python3-vcstool \ + build-essential \ + screen \ + tmux + +# Initialize rosdep +RUN rosdep init +RUN rosdep update + +# Source ROS2 and workspace in bashrc +RUN echo "source /opt/ros/${ROS_DISTRO}/setup.bash" >> /root/.bashrc + +# Trigger docker workflow rerun 1 diff --git a/docker/ros/install-nix.sh b/docker/ros/install-nix.sh new file mode 100644 index 0000000000..879e2149e1 --- /dev/null +++ b/docker/ros/install-nix.sh @@ -0,0 +1,124 @@ +#!/usr/bin/env bash +set -euo pipefail + +if nix_path="$(type -p nix)" ; then + echo "Aborting: Nix is already installed at ${nix_path}" + exit +fi + +if [[ ($OSTYPE =~ linux) && ($INPUT_ENABLE_KVM == 'true') ]]; then + enable_kvm() { + echo 'KERNEL=="kvm", GROUP="kvm", MODE="0666", OPTIONS+="static_node=kvm"' | sudo tee /etc/udev/rules.d/99-install-nix-action-kvm.rules + sudo udevadm control --reload-rules && sudo udevadm trigger --name-match=kvm + } + + echo '::group::Enabling KVM support' + enable_kvm && echo 'Enabled KVM' || echo 'KVM is not available' + echo '::endgroup::' +fi + +# GitHub command to put the following log messages into a group which is collapsed by default +echo "::group::Installing Nix" + +# Create a temporary workdir +workdir=$(mktemp -d) +trap 'rm -rf "$workdir"' EXIT + +# Configure Nix +add_config() { + echo "$1" >> "$workdir/nix.conf" +} +add_config "show-trace = true" +# Set jobs to number of cores +add_config "max-jobs = auto" +if [[ $OSTYPE =~ darwin ]]; then + add_config "ssl-cert-file = /etc/ssl/cert.pem" +fi +# Allow binary caches specified at user level +if [[ $INPUT_SET_AS_TRUSTED_USER == 'true' ]]; then + add_config "trusted-users = root ${USER:-}" +fi +# Add a GitHub access token. +# Token-less access is subject to lower rate limits. +if [[ -n "${INPUT_GITHUB_ACCESS_TOKEN:-}" ]]; then + echo "::debug::Using the provided github_access_token for github.com" + add_config "access-tokens = github.com=$INPUT_GITHUB_ACCESS_TOKEN" +# Use the default GitHub token if available. +# Skip this step if running an Enterprise instance. The default token there does not work for github.com. +elif [[ -n "${GITHUB_TOKEN:-}" && $GITHUB_SERVER_URL == "https://github.com" ]]; then + echo "::debug::Using the default GITHUB_TOKEN for github.com" + add_config "access-tokens = github.com=$GITHUB_TOKEN" +else + echo "::debug::Continuing without a GitHub access token" +fi +# Append extra nix configuration if provided +if [[ -n "${INPUT_EXTRA_NIX_CONFIG:-}" ]]; then + add_config "$INPUT_EXTRA_NIX_CONFIG" +fi +if [[ ! $INPUT_EXTRA_NIX_CONFIG =~ "experimental-features" ]]; then + add_config "experimental-features = nix-command flakes" +fi +# Always allow substituting from the cache, even if the derivation has `allowSubstitutes = false`. +# This is a CI optimisation to avoid having to download the inputs for already-cached derivations to rebuild trivial text files. +if [[ ! $INPUT_EXTRA_NIX_CONFIG =~ "always-allow-substitutes" ]]; then + add_config "always-allow-substitutes = true" +fi + +# Nix installer flags +installer_options=( + --no-channel-add + --nix-extra-conf-file "$workdir/nix.conf" +) + +# only use the nix-daemon settings if on darwin (which get ignored) or systemd is supported +if [[ (! $INPUT_INSTALL_OPTIONS =~ "--no-daemon") && ($OSTYPE =~ darwin || -e /run/systemd/system) ]]; then + installer_options+=( + --daemon + --daemon-user-count "$(python3 -c 'import multiprocessing as mp; print(mp.cpu_count() * 2)')" + ) +else + # "fix" the following error when running nix* + # error: the group 'nixbld' specified in 'build-users-group' does not exist + add_config "build-users-group =" + sudo mkdir -p /etc/nix + sudo chmod 0755 /etc/nix + sudo cp "$workdir/nix.conf" /etc/nix/nix.conf +fi + +if [[ -n "${INPUT_INSTALL_OPTIONS:-}" ]]; then + IFS=' ' read -r -a extra_installer_options <<< "$INPUT_INSTALL_OPTIONS" + installer_options=("${extra_installer_options[@]}" "${installer_options[@]}") +fi + +echo "installer options: ${installer_options[*]}" + +# There is --retry-on-errors, but only newer curl versions support that +curl_retries=5 +while ! curl -sS -o "$workdir/install" -v --fail -L "${INPUT_INSTALL_URL:-https://releases.nixos.org/nix/nix-2.28.3/install}" +do + sleep 1 + ((curl_retries--)) + if [[ $curl_retries -le 0 ]]; then + echo "curl retries failed" >&2 + exit 1 + fi +done + +sh "$workdir/install" "${installer_options[@]}" + +# Set paths +echo "/nix/var/nix/profiles/default/bin" >> "$GITHUB_PATH" +# new path for nix 2.14 +echo "$HOME/.nix-profile/bin" >> "$GITHUB_PATH" + +if [[ -n "${INPUT_NIX_PATH:-}" ]]; then + echo "NIX_PATH=${INPUT_NIX_PATH}" >> "$GITHUB_ENV" +fi + +# Set temporary directory (if not already set) to fix https://github.com/cachix/install-nix-action/issues/197 +if [[ -z "${TMPDIR:-}" ]]; then + echo "TMPDIR=${RUNNER_TEMP}" >> "$GITHUB_ENV" +fi + +# Close the log message group which was opened above +echo "::endgroup::" diff --git a/docs/api/agents.md b/docs/api/agents.md new file mode 100644 index 0000000000..0456135d0f --- /dev/null +++ b/docs/api/agents.md @@ -0,0 +1,78 @@ +# Agents API + +LLM-based reasoning systems that orchestrate robot behavior by invoking skills in response to natural-language commands. Agents manage long-running operations asynchronously and maintain conversation history across operations. + +--- + +## Quick Start + +The blueprint factory for composing agents with other modules. + +### llm_agent + +::: dimos.agents2.agent.llm_agent + +--- + +## Core Classes + +### Agent + +::: dimos.agents2.agent.Agent + +### LlmAgent + +::: dimos.agents2.agent.LlmAgent + +--- + +## Configuration + +### AgentSpec + +::: dimos.agents2.spec.AgentSpec + +### AgentConfig + +::: dimos.agents2.spec.AgentConfig + +### Model + +::: dimos.agents2.spec.Model + +### Provider + +::: dimos.agents2.spec.Provider + +--- + +## Message Types + +### AnyMessage + +::: dimos.agents2.spec.AnyMessage + +--- + +## Standalone Deployment + +For quick prototyping without blueprint composition. + +### deploy + +::: dimos.agents2.agent.deploy + +--- + +## Related + +**Tutorials:** + +- [Equip an agent with skills](../tutorials/skill_with_agent/tutorial.md) — Hands-on introduction to agents and skills +- [Build a multi-agent system](../tutorials/multi_agent/tutorial.md) — Coordinating multiple agents + +**Concepts & API:** + +- [Agent concept](../concepts/agent.md) — High-level overview and neurosymbolic orchestration patterns +- [Skills API](./skills.md) — Methods that agents discover and invoke +- [Modules concept](../concepts/modules.md) — Module architecture agents build upon diff --git a/docs/api/cli_tools.md b/docs/api/cli_tools.md new file mode 100644 index 0000000000..3a484335a4 --- /dev/null +++ b/docs/api/cli_tools.md @@ -0,0 +1,96 @@ +# CLI Tools + +TUI utilities for debugging and interacting with dimos agents and skills. + +--- + +## Overview + +| Command | Purpose | +|---------|---------| +| `agentspy` | Monitor agent messages (Human/Agent/Tool/System) in real-time | +| `skillspy` | Track skill execution states and durations | +| `lcmspy` | LCM traffic statistics (bandwidth, frequency per topic) | +| `humancli` | Chat with agents from the terminal | + +--- + + + +## agentspy + +Real-time monitor for agent message flow. Shows LangChain messages (HumanMessage, AIMessage, ToolMessage, SystemMessage) as they flow through the agent, color-coded by type: + +- *Human* (green): User inputs +- *Agent* (yellow): LLM responses +- *Tool* (red): Skill execution results +- *System* (red): System prompts + +Useful for debugging agent reasoning, inspecting prompts, and understanding the conversation flow. + +```bash +agentspy +``` + +--- + +## skillspy + +Real-time dashboard for skill execution monitoring. Shows skills as they execute, with state tracking (pending → running → completed/error), durations, and message counts. + +Each row shows: +- *Call ID*: Unique identifier for the skill invocation +- *Skill Name*: Which skill is executing +- *State*: Current execution state (color-coded) +- *Duration*: How long the skill has been running +- *Messages*: Count of messages in the skill's state +- *Details*: Error messages or return values + +```bash +skillspy +``` + +--- + +## lcmspy + +Real-time LCM traffic statistics dashboard. Shows bandwidth and message frequency per topic, useful for profiling communication overhead and detecting message storms. + +Each row shows: +- *Topic*: LCM channel name +- *Freq (Hz)*: Message frequency over the last 5 seconds +- *Bandwidth*: Data rate (auto-scaled to B/s, kB/s, MB/s) +- *Total Traffic*: Cumulative data since startup (auto-scaled to B, kB, MB, GB) + +```bash +lcmspy +``` + +--- + +## humancli + +IRC-style chat interface for interacting with dimos agents. Send messages and see agent responses, tool calls, and system messages in a familiar chat format. + +```bash +humancli +``` + +--- + +## See also + +Tutorials + +- [Equip an agent with skills](../tutorials/skill_with_agent/tutorial.py) +- [Build a multi-agent RoboButler](../tutorials/multi_agent/tutorial.py): Uses notebook equivalent of `agentspy` to monitor multi-agent message flow + +Concepts & API + +- [Agent concept guide](../concepts/agent.md) +- [Agents API](./agents.md): LLM agents that these tools monitor + +- [Skills concept guide](../concepts/skills.md) +- [Skills API](./skills.md) + +- [Transport concept guide](../concepts/transport.md): discusses the LCM pub/sub that `lcmspy` monitors in more detail diff --git a/docs/api/index.md b/docs/api/index.md new file mode 100644 index 0000000000..220356e3bb --- /dev/null +++ b/docs/api/index.md @@ -0,0 +1,33 @@ +# API Reference + +> [!WARNING] +> **Work in Progress** +> API documentation is currently being built out. More modules will be documented over time. + +
+ +- :material-robot: **Agents** + + --- + + The LLM-based Agents system lets you command any robot with natural language. + + [:octicons-arrow-right-24: View agents API](agents.md) + +- :material-function-variant: **Skills** + + --- + + Capabilities via which agents control and monitor the robot. + + [:octicons-arrow-right-24: View Skills API](skills.md) + +- :material-console: **CLI Tools** + + --- + + TUI utilities for debugging and interacting with agents and skills. + + [:octicons-arrow-right-24: View CLI Tools](cli-tools.md) + +
diff --git a/docs/api/skills.md b/docs/api/skills.md new file mode 100644 index 0000000000..bee0fe1723 --- /dev/null +++ b/docs/api/skills.md @@ -0,0 +1,88 @@ +# Skills API + +Skills let agents control and monitor robot capabilities. They are methods on `Module` classes decorated with `@skill()` that become LLM-callable tools. Each skill executes in a background thread and communicates state through a message protocol (pending → running → completed/error), with optional streaming modes (`Stream.call_agent` for progress updates, `Stream.passive` for background data accumulation). + +--- + +## Core Decorator + +The entry point for defining skills. + +### skill + +::: dimos.protocol.skill.skill.skill + +--- + +## Skill Configuration + +Configuration attached to decorated methods, used by SkillCoordinator to control execution. + +### SkillConfig + +::: dimos.protocol.skill.type.SkillConfig + +--- + +## Configuration Enums + +Values passed to the `@skill()` decorator to control behavior. + +### Return + +Controls how skill return values are delivered and whether they wake the agent. + +::: dimos.protocol.skill.type.Return + +### Stream + +Controls how streaming skill outputs (generators/iterators) are handled. + +::: dimos.protocol.skill.type.Stream + +### Output + +Presentation hint for how the agent should interpret skill output. + +::: dimos.protocol.skill.type.Output + +--- + +## Stream Processing + +Reducers aggregate streaming values when `stream=Stream.passive` or `stream=Stream.call_agent`. + +### Reducer + +::: dimos.protocol.skill.type.Reducer + +### make_reducer + +Factory for creating custom reducer functions from simple aggregation logic. + +::: dimos.protocol.skill.type.make_reducer + +--- + +## Infrastructure + +Base classes inherited by Modules. Most users don't interact with these directly. + +### SkillContainer + +::: dimos.protocol.skill.skill.SkillContainer + +--- + +## Related + +**Tutorials:** + +- [Build your first skill](../tutorials/skill_basics/tutorial.md) — Defining and testing skills +- [Equip an agent with skills](../tutorials/skill_with_agent/tutorial.md) — Wiring skills to agents + +**Concepts & API:** + +- [Skills concept](../concepts/skills.md) — High-level overview including execution model and best practices +- [Modules concept](../concepts/modules.md) — Module architecture that provides skills +- [Agents API](./agents.md) — LLM agents that discover and invoke skills diff --git a/docs/ci.md b/docs/ci.md new file mode 100644 index 0000000000..ac9b11115a --- /dev/null +++ b/docs/ci.md @@ -0,0 +1,146 @@ +# Continuous Integration Guide + +> *If you are ******not****** editing CI-related files, you can safely ignore this document.* + +Our GitHub Actions pipeline lives in **`.github/workflows/`** and is split into three top-level workflows: + +| Workflow | File | Purpose | +| ----------- | ------------- | -------------------------------------------------------------------- | +| **cleanup** | `cleanup.yml` | Auto-formats code with *pre-commit* and pushes fixes to your branch. | +| **docker** | `docker.yml` | Builds (and caches) our Docker image hierarchy. | +| **tests** | `tests.yml` | Pulls the *dev* image and runs the test suite. | + +--- + +## `cleanup.yml` + +* Checks out the branch. +* Executes **pre-commit** hooks. +* If hooks modify files, commits and pushes the changes back to the same branch. + +> This guarantees consistent formatting even if the developer has not installed pre-commit locally. + +--- + +## `tests.yml` + +* Pulls the pre-built **dev** container image. +* Executes: + +```bash +pytest +``` + +That’s it—making the job trivial to reproduce locally via: + +```bash +./bin/dev # enter container +pytest # run tests +``` + +--- + +## `docker.yml` + +### Objectives + +1. **Layered images**: each image builds on its parent, enabling parallel builds once dependencies are ready. +2. **Speed**: build children as soon as parents finish; leverage aggressive caching. +3. **Minimal work**: skip images whose context hasn’t changed. + +### Current hierarchy + + +``` + ┌──────┐ + │ubuntu│ + └┬────┬┘ + ┌▽──┐┌▽───────┐ + │ros││python │ + └┬──┘└───────┬┘ + ┌▽─────────┐┌▽──┐ + │ros-python││dev│ + └┬─────────┘└───┘ + ┌▽──────┐ + │ros-dev│ + └───────┘ +``` + +* ghcr.io/dimensionalos/ros:dev +* ghcr.io/dimensionalos/python:dev +* ghcr.io/dimensionalos/ros-python:dev +* ghcr.io/dimensionalos/ros-dev:dev +* ghcr.io/dimensionalos/dev:dev + +> **Note**: The diagram shows only currently active images; the system is extensible—new combinations are possible, builds can be run per branch and as parallel as possible + + +``` + ┌──────┐ + │ubuntu│ + └┬────┬┘ + ┌▽──┐┌▽────────────────────────┐ + │ros││python │ + └┬──┘└───────────────────┬────┬┘ + ┌▽─────────────────────┐┌▽──┐┌▽──────┐ + │ros-python ││dev││unitree│ + └┬────────┬───────────┬┘└───┘└───────┘ + ┌▽──────┐┌▽─────────┐┌▽──────────┐ + │ros-dev││ros-jetson││ros-unitree│ + └───────┘└──────────┘└───────────┘ +``` + +### Branch-aware tagging + +When a branch triggers a build: + +* Only images whose context changed are rebuilt. +* New images receive the tag `:`. +* Unchanged parents are pulled from the registry, e.g. + +given we made python requirements.txt changes, but no ros changes, image dep graph would look like this: + +``` +ghcr.io/dimensionalos/ros:dev → ghcr.io/dimensionalos/ros-python:my_branch → ghcr.io/dimensionalos/dev:my_branch +``` + +### Job matrix & the **check-changes** step + +To decide what to build we run a `check-changes` job that compares the diff against path filters: + +```yaml +filters: | + ros: + - .github/workflows/_docker-build-template.yml + - .github/workflows/docker.yml + - docker/base-ros/** + + python: + - docker/base-python/** + - requirements*.txt + + dev: + - docker/dev/** +``` + +This populates a build matrix (ros, python, dev) with `true/false` flags. + +### The dependency execution issue + +Ideally a child job (e.g. **ros-python**) should depend on both: + +* **check-changes** (to know if it *should* run) +* Its **parent image job** (to wait for the artifact) + +GitHub Actions can’t express “run only if *both* conditions are true *and* the parent job wasn’t skipped”. + +We are using `needs: [check-changes, ros]` to ensure the job runs after the ros build, but if ros build has been skipped we need `if: always()` to ensure that the build runs anyway. +Adding `always` for some reason completely breaks the conditional check, we cannot have OR, AND operators, it just makes the job _always_ run, which means we build python even if we don't need to. + +This is unfortunate as the build takes ~30 min first time (a few minutes afterwards thanks to caching) and I've spent a lot of time on this, lots of viable seeming options didn't pan out and probably we need to completely rewrite and own the actions runner and not depend on github structure at all. Single job called `CI` or something, within our custom docker image. + +--- + +## `run-tests` (job inside `docker.yml`) + +After all requested images are built, this job triggers **tests.yml**, passing the freshly created *dev* image tag so the suite runs against the branch-specific environment. diff --git a/docs/concepts/agent.md b/docs/concepts/agent.md new file mode 100644 index 0000000000..d381647ab5 --- /dev/null +++ b/docs/concepts/agent.md @@ -0,0 +1,158 @@ +# Agent + +## Motivation + +Traditional robot programming requires manually coding every behavior. By contrast, LLM-powered agents (in conjunction with [skills](./skills.md)) allow you to instruct robots at a higher level of abstraction, in natural language. Tell it to "go to the kitchen" and it figures out which skills to call and when. + +> [!TIP] +> New to agents? Start with the [Equip an agent with skills](../tutorials/skill_with_agent/tutorial.md) tutorial for a hands-on introduction. + +```python +from dimos.agents2.agent import llm_agent +from dimos.agents2.cli.human import human_input +from dimos.agents2.skills.navigation import navigation_skill +from dimos.core.blueprints import autoconnect +from dimos.robot.unitree_webrtc.unitree_go2_blueprints import basic + +# Create an agentic robot system +blueprint = autoconnect( + basic, # Hardware, navigation, mapping + navigation_skill(), # Exposes navigation as agent-callable skills + llm_agent( # The reasoning agent + system_prompt="You are a helpful robot assistant." + ), + human_input() # CLI for sending commands to the agent +) +``` + + + +## Situating Agents vis-a-vis other DimOS concepts + +Agents are [Modules](./modules.md), so they inherit streams, RPC, lifecycle management, and distributed deployment. + +However, an agent doesn't see information from streams directly. If you want to feed information to an agent, you need to do it via [skills](./skills.md). + +## How to build agentic systems + +Check out these tutorials for a better answer: + +* [Equip an agent with a skill](../tutorials/skill_with_agent/tutorial.md) +* [Build a multi-agent RoboButler system](../tutorials/multi_agent/tutorial.md) + + +But the short answer is, add an agent to the blueprint using `llm_agent()`. + +```python +from dimos.agents2.agent import llm_agent + +agent_bp = llm_agent( + system_prompt="You are a warehouse robot. Focus on navigation and inventory tasks.", + model="gpt-4o-mini", + provider="openai" +) +``` + +DimOS supports multiple LLM providers through Langchain - switching requires only configuration changes. + + + + + + +## Skill discovery (also discussed in tutorials) + +For an agent to discover skills, the skill module must register itself. The simplest way is to subclass `SkillModule`: + +```python +from dimos.core.skill_module import SkillModule +from dimos.protocol.skill.skill import skill + +class NavigationSkills(SkillModule): + """Module providing navigation capabilities.""" + + @skill() + def navigate_to(self, location: str) -> str: + """Navigate to the specified location.""" + # Implementation... + return f"Navigating to {location}" +``` + + + +`SkillModule` is just `Module` with a `set_LlmAgent_register_skills` hook that registers its skills when composed with an agent. The naming convention `set__` tells the system to call this method with the matching module's method when the blueprint is built. + + + +When skills are registered, the agent: + +1. Converts `@skill` methods to LLM tool definitions (using the docstring as the tool description -- this is why it's important to have good docstrings for those methods) +2. Exposes them to the LLM for reasoning + + + + +## The agent loop + +The agent runs an event-driven reasoning loop: + +1. **Invoke LLM** - With conversation history and skill state +2. **Execute tool calls** - Dispatch requested skills to coordinator +3. **Wait for updates** - Suspend until skills produce results +4. **Process results** - Transform skill outputs into messages +5. **Repeat** - Continue until the skills that were initialized with `Return.call_agent` or `Stream.call_agent` -- what we might call *active* skills for short -- have completed. + + + +The loop handles *long-running operations* without blocking. Navigation takes 30 seconds? The agent waits, then resumes reasoning with results. + + + +## How agents receive information + +> [!IMPORTANT] +> If you want to get information to an agent, you need to do that with skills. + +There are two broad ways to get information to an agent. You can either (i) give it skills that it can use to get information -- think of a coding agent and its search tool(s) -- or (ii) stream updates from skills. For more details, see [the Skills concept guide](./skills.md). + +## State management + +Agents follow a one-way lifecycle - once stopped, they stay stopped: + +```ascii +INITIALIZED → STARTED → RUNNING → STOPPED (terminal) +``` + +Stopped agents **cannot restart**. This prevents mixing old and new conversation contexts. To resume operations, create a fresh agent instance. + + + + + +This one-way pattern supports explicit state management - each agent instance represents a single conversation session with its own history and context. + +## Common use cases + +* **Exploration and mapping** - Agent plans exploration pattern, navigates to waypoints, tags rooms in memory, reports findings. + +* **Object search and navigation** - Agent searches memory for target object, explores to locate it if not found, navigates to object's location, confirms arrival. + +* **Guided tours and explanations** - Agent navigates to key locations, describes what's at each, answers questions about equipment and procedures. + +## See also + +### Hands-on tutorials + +* [Equip an agent with skills](../tutorials/skill_with_agent/tutorial.md) + +* [Build a multi-agent RobotButler](../tutorials/multi_agent/tutorial.md) + +### Related concepts + +* [Skills](./skills.md) - Methods that agents can discover and invoke +* [Blueprints](./blueprints.md) - Composing agents, skills, and hardware into systems +* [Modules](./modules.md) - The foundational abstraction that agents build upon + +### API reference + +* [Agents API](../api/agents.md) - API reference for agent classes, message types, and configuration diff --git a/docs/concepts/blueprints.md b/docs/concepts/blueprints.md new file mode 100644 index 0000000000..bb1857cd78 --- /dev/null +++ b/docs/concepts/blueprints.md @@ -0,0 +1,359 @@ +# Blueprints + +## Motivation + +Modules in a robotic system need to be able to exchange data with each other; e.g., +the navigation module might need information from the relevant sensor modules. +But when developing such a system, we don't want to wire modules up to each other +in a manual, fragile way. + +That's where blueprints come in. In DimOS, instead of manually connecting modules, +you can just write a *blueprint*, +a *declarative specification* for each module that describes, e.g., its informational needs; +declare that your system consists of this and that module; +and the blueprint system will handle the plumbing for you. + +```python +# Define your modules + +# Combine their blueprints +blueprint = autoconnect( + ModuleA.blueprint(), + ModuleB.blueprint(), +) +# Then build and run +coordinator = blueprint.build() +``` + +### Background on `Module`s + +Before diving into blueprints, we need to first review Modules. Recall that there are two ways in which a Module can also in some broad sense communicate with or depend on other modules. First, when you define a [Module](./modules.md) -- when you write a *blueprint* for it -- you can declare what sorts of data it consumes and what sorts of data it produces: + +```python +class ModuleA(Module): + image: Out[Image] = None + start_explore: Out[Bool] = None +``` + +In particular, these declarations are done with *streams*: `In[T]` for input and `Out[T]` for output, where `T` is the type variable for the type of data the stream carries. + +Nothing about this required specifying exactly what other Modules this Module will be wired up to. But a Module can also depend on other Modules via the RPC system -- it can declare that it needs to be able to invoke certain methods of certain other Modules. And that does require specifying what those other Modules are: + +```python +class Greeter(Module): + """High-level Greeter skill built on lower-level RobotCapabilities, from the first skill tutorial.""" + + # Declares what this module needs from other modules -- in this case, from + # another RobotCapabilities module that provides lower-level capabilities. + rpc_calls = [ + "RobotCapabilities.speak", + ] + + @skill() + def greet(self, name: str = "friend") -> str: + """Greet someone by name.""" + # ... + # A skill that invokes RobotCapabilities.speak + # See the first skill tutorial for more details. +``` + +### The blueprint system takes your declarative blueprints and does the wiring up for you + +We've just seen how there are various ways in which Modules need to be wired up to each other, based on their blueprints. This sort of wiring is the job of the blueprint system. + +That is, given the blueprints, the blueprint system automatically + +- wires up streams between modules, selecting appropriate transports for the streams + + +- and provides Modules like `Greeter` with any dependencies it needs for RPC calls. + +## Key Benefits + +**Modularity** - A module can just declare what data it consumes and produces, without needing to know the specific module(s) they are communicating with via the streams. + + +**Composability** - You can make a complex system by composing modules. + + +```python +basic = autoconnect( + connection(), + mapper(), + astar_planner(), + holonomic_local_planner(), + behavior_tree_navigator(), +).global_config(n_dask_workers=4) + +standard = autoconnect( + basic, + spatial_memory(), + object_tracking(), +).global_config(n_dask_workers=8) + +agentic = autoconnect( + standard, + llm_agent(), + navigation_skill(), + human_input(), +) +``` + + +**Reusability** - The same blueprint deploys to different environments by changing configuration, not module code. + + +**Type Safety** - Connections are validated at build time. An `Out[Image]` can only connect to `In[Image]`. + + +## How to build and run a blueprint + +Suppose you have + +- defined your modules and got the blueprints with `.blueprint`, + +- combined the blueprints with `autoconnect()` + +- and optionally added your own `.transports()` and `.global_config()` configuration (more on this later): + + +```python +# From the first skill tutorial +combined_blueprint = autoconnect( + RobotCapabilities.blueprint, # Provides speak + Greeter.blueprint, # Requires RobotCapabilities.speak +) +``` + +To build the composed blueprint -- to get the modules wired up, deploy them to workers, and start them -- you just need to call the `build()` method: + +```python +module_coordinator = combined_blueprint.build(global_config=config) +``` + +This returns a `ModuleCoordinator` instance that manages all deployed modules. + + + + + + + + + +### Running and shutting down + +After `build()`, the system is already running. For long-running applications (e.g. an honest-to-goodness robot), +use `loop()` to keep the process alive: + +```python +module_coordinator.loop() +``` + +This sleeps indefinitely until interrupted (Ctrl+C / SIGINT), whereupon it calls `stop()` to shut down gracefully. + +Alternatively, when e.g. writing batch scripts, you can build the blueprint, do whatever you need to do, and just call `stop()` when you're done: + +```python +coordinator = blueprint.build() +# ...do whatever you need to do +module_coordinator.stop() # Clean up when finished +``` + +## How the blueprint system works, in more detail (Advanced) + +Now that we've seen what blueprints are, at a high level, and how to build and run them, +we are in a position to dive into details that are helpful for building more complicated systems, such as + +- how the blueprint system matches compatible streams, and what to do when stream names don't match +- how to override the default configuration + +> [!TIP] +> Feel free to skip this section on first read -- you can always come back when you need the details. + +### How the blueprint system matches compatible streams + +In/Out streams are matched on the basis of *both* the stream name *and* the type of data associated with the stream. + +In particular, when modules declare streams with matching names and types, the blueprint system assigns them the same transport instance. This shared transport is what enables data to flow between publishers and subscribers. + + + + +```python +class ProducerModule(Module): + image: Out[Image] = None + +class ConsumerModule(Module): + image: In[Image] = None + +blueprint = autoconnect( + ProducerModule.blueprint(), + ConsumerModule.blueprint(), +) +# These streams share a transport instance +# Data published by ProducerModule.image flows to ConsumerModule.image +``` + +Matching on not just the name but also the *type* prevents mistakes: an `Out[Temperature]` won't connect to `In[Pressure]` even if both are named `data`. + +### Order and override behavior + +For stream wiring, the order of the Module arguments to `autoconnect()` doesn't matter -- connections match by name and type regardless. + +However, order *does* matter for blueprint composition. This lets you extend a base blueprint while overriding specific modules or configuration. + +In particular, when the same module class appears multiple times, **later occurrences override earlier ones**: + + +```python +# Base configuration +base = autoconnect( + mapper(voxel_size=0.5), + planner(algorithm="basic"), +).global_config(n_dask_workers=4) + +# Override planner params and worker count +advanced = autoconnect( + base, + planner(algorithm="advanced", lookahead=10), # Replaces earlier planner +).global_config(n_dask_workers=8) + +# Result: mapper(0.5) + planner(advanced, lookahead=10), 8 workers +``` + +This makes it easy to build configurations in layers, or create environment-specific variants (e.g. sim vs real), without repeating yourself. + +> [!NOTE] +> The same last-wins rule applies to `.transports()`, `.global_config()`, and `.remappings()`: later specifications override earlier ones. + +### Topic + +A *topic* (or *topic name*) is an identifier that the transport layer uses to route messages between publishers and subscribers. Streams with matching (name, type) will share the same topic. For instance, if `ProducerModule` publishes to an `image` stream and `ConsumerModule` subscribes to an `image` stream (both of type `Image`), both will use the same topic -- `/image`, say -- and the transport ensures messages flow between them. + +By default, the topic is a forward slash followed by the *name* of the stream. That is, the topic for the following `image` stream + +```python +class ProducerModule(Module): + image: Out[Image] = None +``` + +will be `/image`. + +Streams with the same name must have the same type -- this is how the blueprint system knows to wire them together. If two streams share a name but have different types, `build()` raises a `ValueError`. + +### What to do when stream names don't match (remapping) + +Sometimes you need to rename a connection to match what other modules expect. +You can use the `remappings` method to do this: + + + +```python +class ConnectionModule(Module): + color_image: Out[Image] = None # Outputs on 'color_image' + +class ProcessingModule(Module): + rgb_image: In[Image] = None # Expects input on 'rgb_image' + +# Without remapping, these wouldn't connect automatically +# With remapping, color_image is renamed to rgb_image +blueprint = ( + autoconnect( + ConnectionModule.blueprint(), + ProcessingModule.blueprint(), + ) + .remappings([ + (ConnectionModule, 'color_image', 'rgb_image'), + ]) +) +``` + +After remapping: + +- The `color_image` output from `ConnectionModule` is treated as `rgb_image` +- It automatically connects to any module with an `rgb_image` input of type `Image` +- The topic name becomes `/rgb_image` instead of `/color_image` + +If you want to override the topic, you still have to do it manually: + +```python +blueprint +.remappings([ + (ConnectionModule, 'color_image', 'rgb_image'), +]) +.transports({ + ("rgb_image", Image): LCMTransport("/custom/rgb/image", Image), +}) +``` + + + +## Configuration Management + +### Transport + + + +Recall that when modules declare streams with matching names and types, the blueprint system assigns them the same transport instance. + +By default, `LCMTransport` is selected if the data type supports `lcm_encode` (LCM stands for 'Lightweight Communications and Marshalling'). Otherwise `pLCMTransport` is used; this serializes Python objects by pickling them. + +But, as noted earlier, you aren't confined to the defaults -- you can choose whatever transport you like for a given (name, type) stream group with the `transports` method: + +```python +blueprint = autoconnect(...) +expanded_blueprint = autoconnect(blueprint, ...) +final_blueprint = expanded_blueprint.transports({ + ("image", Image): pSHMTransport( + "/go2/color_image", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE + ), + ("start_explore", Bool): pLCMTransport(), +}) +``` + +### Overriding global configuration + +The choice of transport isn't the only config you can override. + +Before we can see why, we need to remind ourselves of another fact about modules. +Each module can optionally take a `global_config` option in `__init__`. E.g.: + +```python +class ModuleA(Module): + + def __init__(self, global_config: GlobalConfig | None = None): + ... +``` + +If a global config is not explicitly supplied when the Module is defined, it will get loaded from an `.env` or environment variables when the blueprint is built. + +But if you want to override the global config for a specific blueprint, you can do that with the `.global_config()` method. + + + + + + +For instance, you might want to change the number of workers for a particular blueprint: + +```python +blueprint = blueprint.global_config(n_dask_workers=8) +``` + +Or perhaps you want to graft different configs onto the same core blueprint for different environments: + +```python +dev = blueprint.global_config(replay=True, n_dask_workers=1) +prod = blueprint.global_config(robot_ip="192.168.1.1", n_dask_workers=8) +``` + +## See also + + +- [Modules](./modules.md) +- [Transport](./transport.md) - How data is transferred between modules diff --git a/docs/concepts/index.md b/docs/concepts/index.md new file mode 100644 index 0000000000..783c77003e --- /dev/null +++ b/docs/concepts/index.md @@ -0,0 +1,23 @@ +# Concepts + +This section explains the key concepts and abstractions that make up DimOS. + +## [Agent](agent.md) + +LLM-based reasoning systems that orchestrate robot behavior by processing natural language commands and intelligently executing skills + +## [Skills](skills.md) + +Skills are how you give agents control over robot capabilities. + +## [Modules](modules.md) + +Every DimOS component -- from hardware drivers to AI agents -- is a Module: a distributed actor communicating through typed streams and RPC + +## [Blueprints](blueprints.md) + +Declarative specifications for wiring modules together + +## [Transport](transport.md) + +Abstraction layer for message passing between modules -- same code works across different backends like LCM for network or shared memory for local IPC diff --git a/docs/concepts/modules.md b/docs/concepts/modules.md new file mode 100644 index 0000000000..1d87601649 --- /dev/null +++ b/docs/concepts/modules.md @@ -0,0 +1,144 @@ +# Modules + +## What is a `Module`? + +A `Module` is a *distributed, communicating unit of functionality* -- the fundamental building block for robot applications in DimOS. Modules are self-contained actors that encapsulate specific behaviors (camera processing, navigation, AI reasoning) and communicate through well-defined interfaces. + + + +```python +from dimos.core import Module, In, Out, rpc +from dimos.msgs.sensor_msgs import Image + +class SpatialMemory(Module): + """Builds semantic memory from camera streams.""" + color_image: In[Image] = None # Typed input stream + + @rpc + def query_by_text(self, text: str, limit: int = 5) -> list[dict]: + """Expose RPC method for other modules to call.""" + return self._search_memory(text, limit) +``` + + + +Every major component is a Module: hardware drivers, perception algorithms, navigation planners, AI agents. This unified abstraction solves three critical challenges: + +**Composability** - Modules connect in flexible topologies without enforced hierarchies. A camera module can feed multiple perception modules; an agent can coordinate several navigation modules. + + + +**Safety through isolation** - Because modules are mapped onto separate processes, every module has its own address space. Even if one module fails catastrophically (e.g. a segfault), it won't bring down the others. + + + +**Distributed execution** - Modules run as Dask actors across a cluster. The system handles network communication, serialization, and RPC automatically. + + + + +## Modules and other DimOS concepts + +### Streams + +When you define a Module, you can declare what sorts of data it consumes and produces: + +```python +class ModuleA(Module): + image: Out[Image] = None + start_explore: Out[Bool] = None +``` + +In particular, these declarations are done with *streams*: `In[T]` for input and `Out[T]` for output, where `T` is the type variable for the type of data the stream carries. + +Streams provide reactive, push-based data flow between modules, built on ReactiveX. The [blueprint system](./blueprints.md) validates that connected streams have compatible types at build time. + + + + + +### RPC system + +We've seen how modules might be wired to other modules on the basis of the sorts of data it consumes and produces. But there's yet another way in which a Module can in some sense depend on other Modules: a Module can declare that it needs to be able to (synchronously) invoke certain methods of certain other Modules via RPC. + +```python +class Greeter(Module): + """High-level Greeter skill built on lower-level RobotCapabilities, from the first skill tutorial.""" + + # Declares what this module needs from other modules -- in this case, from + # another RobotCapabilities module that provides lower-level capabilities. + rpc_calls = [ + "RobotCapabilities.speak", + ] + + @skill() + def greet(self, name: str = "friend") -> str: + """Greet someone by name.""" + # ... + # A skill that invokes RobotCapabilities.speak + # See the first skill tutorial for more details. + + # ... + +class RobotCapabilities(Module): + """Low-level capabilities that our (mock) robot possesses.""" + + @rpc + def speak(self, text: str) -> str: + """Speak text out loud through the robot's speakers.""" + # ... + + # ... +``` + + + + + +### Modules are containers for skills + +Suppose your robot has certain capabilities; e.g. it can move in certain ways. *Skills* are how you'd let AI agents control and monitor such capabilities. As is explained in more detail [in the concept guide](./skills.md), skills are methods on a `Module` that are decorated with `@skill` that get turned into *tools* that AI agents can call. (See also the [Skill tutorials](../tutorials/index.md) for end-to-end examples.) + +And crucially, any `Module` can expose skills that AI agents discover and invoke, in virtue of inheriting from `SkillContainer`. + + + +## How Modules run (Advanced) + +> [!TIP] +> Feel free to skip this section on first read. + +### Distributed actors + +Modules deploy as Dask actors, each with its own event loop for async operations, automatic serialization for cross-worker communication, and transparent RPC handling. Modules communicate exclusively through Dask Actor references rather than direct Python object references, which enables transparent distributed deployment—you work with module references as local objects while Dask routes calls to appropriate workers. + + + + + +### Lifecycle + +Modules follow a defined lifecycle: initialize with configuration, deploy to Dask workers, start processing, handle streams and RPC calls while running, then stop with graceful resource cleanup. The system automatically handles event loop creation, stream initialization, RPC server setup, and resource disposal. + + + + + + +## Common module types + +**Hardware modules** - Interface with sensors and actuators (`ConnectionModule`, `CameraModule`) + +**Perception modules** - Process sensor data (`SpatialMemory`, `ObjectDetector`, `ObjectTracker`) + + + +**Navigation modules** - Path planning and control (`NavigationInterface`, `BehaviorTreeNavigator`, `AstarPlanner`) + +**Agent modules** - AI reasoning and coordination (`BaseAgentModule`, `SkillCoordinator`) + +## See also + +- [Blueprints](./blueprints.md) +- [Skills](./skills.md) +- [Agents](./agent.md) diff --git a/docs/concepts/skills.md b/docs/concepts/skills.md new file mode 100644 index 0000000000..77d2c0bd74 --- /dev/null +++ b/docs/concepts/skills.md @@ -0,0 +1,212 @@ +# Skills + +## Motivation + +Suppose your robot has certain capabilities -- e.g., it can move in certain ways, or play sounds through a speaker. How do you let an LLM agent control these capabilities? + +Skills are how you do that: skills get turned into *tools* that agents can call. + +Skills are often also defined at a *higher level of abstraction* than the robotic capabilities; e.g., a 'follow human' skill that uses computer vision data to control a robot. In this way, skills can be + +* easier for agents to work with and reason about +* and hide or abstract over differences in the underlying hardware. + +```python +from dimos.core import Module +from dimos.protocol.skill.skill import skill + +class NavigationModule(Module): + @skill() + def navigate_to(self, location: str) -> str: + """Navigate to a named location like 'kitchen'.""" + x, y, theta = self._lookup_location(location) + self._set_navigation_goal(x, y, theta) + return f"Navigating to {location}" +``` + + +Finally, if there's information you want to get to an agent, you need to do that with skills -- more on this shortly. + +## What is a skill? + +At a high level, skills are wrappers over lower-level robot capabilities. But at a more prosaic level, a skill is just a method on a Module decorated with `@skill` that: + +1. **Becomes an agent-callable tool** - The decorator generates a tool schema from the method signature and docstring +2. **Executes asynchronously** - Skills run without blocking the agent + + + +> [!TIP] +> The docstring becomes the tool description LLMs see when choosing skills. Write it for an LLM audience: make it clear, concise, action-oriented. + +## Basic usage + +### Defining a simple skill + +For a method on a `Module` to be discoverable by agents, it has to be decorated with `@skill()` and registered on the agent -- [see the 'equip an agent with skills' tutorial for more details](../tutorials/skill_with_agent/tutorial.md). + +```python +from dimos.core import Module +from dimos.protocol.skill.skill import skill + +class RobotSkills(Module): + @skill() + def speak(self, text: str) -> str: + """Make the robot speak the given text aloud.""" + self.audio.play_tts(text) + return f"Said: {text}" + + @rpc + def set_LlmAgent_register_skills(self, register_skills: RpcCall) -> None: + """Called by framework when composing with llm_agent(). + + This method is discovered by convention during blueprint.build(). + """ + register_skills.set_rpc(self.rpc) + register_skills(RPCClient(self, self.__class__)) +``` + +> [!NOTE] +> For most scenarios, you can avoid repeating the `set_LlmAgent_register_skills` boilerplate by subclassing `SkillModule` (which is just `Module` plus the `set_LlmAgent_register_skills` method shown above). + +### How skills reach agents + +When you register a Module with an agent, the agent discovers its `@skill` methods and converts them into *tool schemas* that the LLM understands. Your method signature becomes the tool's parameters; your docstring becomes its description. + +See these tutorials for examples: + +* [Equip an agent with skills](../tutorials/skill_with_agent/tutorial.md). +* [Build a RoboButler multi-agent system](../tutorials/multi_agent/tutorial.md) + +## Updating agents with results from skills + +We've seen how skills can be made available to agents as tools they can call. Often, however, we don't just want agents making tool calls -- we also want to relay updates from the tool calls, from the skills, back to the agent. + +Some of this behavior already comes as a default: if you decorate the method with `@skill()`, the agent will be notified with the *return value* of the method when the skill finishes. + +### Notifying the agent whenever there's updates + +But often we want to update the agent not just when the skill is finished, but also whenever there's progress. Think, e.g., of a 'move to certain coordinates' skill, where we might want to stream progress updates continously to the agent. + +This can be done by making the method a generator and setting the `stream` parameter of `@skill` to `Stream.call_agent`: + +```python +@skill(stream=Stream.call_agent) +def goto(self, x: float, y: float): + """Move the robot in relative coordinates. + x is forward, y is left. + goto(1, 0) will move the robot forward by 1 meter + """ + pose_to = PoseStamped( + # ... + ) + yield "moving, please wait..." # Notifies the agent + self.navigate_to(pose_to) + yield "arrived" # Notifies the agent again +``` + + +The agent is notified with each `yield`, and can take action if something goes wrong. + +### Streaming updates more '*passively*', in the background + +That said, we don't always want to update the agent *every time* there's an update. Sometimes we want to just accumulate updates in the background and *only* pass them on to the agent when the agent happens to be notified by other more 'active' skills. For instance, we may want to periodically pass on information from a camera feed without interrupting the agent on every frame. + +To do this, use `Stream.passive`: + +```python +from dimos.msgs.sensor_msgs import Image + +class CameraFeed(Module): + color_image: Out[Image] + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._image_queue = queue.Queue(maxsize=1) + + @rpc + def start(self): + super().start() + self.hardware = # ... + # Subscribe to hardware stream + + @rpc + def stop(self) -> None: + # ...Clean up resources like the hardware stream + super().stop() + + + @skill(stream=Stream.passive, output=Output.image, reducer=Reducer.latest) + def video_stream(self): + """Implicit video stream skill""" + self.hardware.image_stream().subscribe(self._image_queue.put) + yield from iter(self._image_queue.get, None) +``` + +Note that the above skill can *also* be called by the agent. If you don't want that -- if you don't want the skill to be available as a tool -- set the `hide_skill` parameter to `True`. + +> [!CAUTION] +> **Passive skills alone cannot keep the agent loop alive.** If only passive skills are running, the loop exits immediately. Passive skills need to be paired with other active skills; e.g.: +> +> * Position telemetry (passive) + navigation command (active) +> * Video stream (passive) + `HumanInput` (active) + +For more on the `stream` and `ret` parameters, see [the Skills API reference](../api/skills.md). + +### Reducers as backpressure buffers for streamed updates + +When a skill streams updates, the agent might not process them as fast as they arrive. This is when the `reducer` parameter comes in handy: when updates pile up, the designated reducer is used to combine or aggregate updates. E.g., in the camera feed example above, with `reducer=Reducer.latest`, the agent will only see the latest frame from the camera feed. + +> [!NOTE] +> With `Stream.passive`, values accumulate silently until an active skill wakes the agent. With `Stream.call_agent`, whether updates are accumulated depends on whether yields happen faster than the agent processes them. + +## Getting information to agents on demand + +We've seen how updates from skills can be streamed to agents; in particular, how something like a video stream can be streamed in the background. It's worth noting, though, that another way to give an agent access to information is to give it a skill for getting such information *on demand* (think of coding agents and their search tools). + +```python +class GoogleMapsSkillContainer(SkillModule): + _latest_location: LatLon | None = None + _client: GoogleMaps + + # ... + + @skill() + def where_am_i(self, context_radius: int = 200) -> str: + """This skill returns information about what street/locality/city/etc + you are in. It also gives you nearby landmarks. + + Example: + + where_am_i(context_radius=200) + + Args: + context_radius (int): default 200, how many meters to look around + """ +``` + + +## Best practices + +**Return meaningful strings** - `"Navigated to kitchen in 12 seconds"` beats `"ok"` for LLMs. + +**Write clear docstrings** - They become tool descriptions. Be specific about what the skill does and what parameters mean. + + +**Handle errors gracefully** - Return contextual error messages for agent recovery, not raw exceptions. + +**Monitor long-running skills** - Use `skillspy` to watch skill execution in real-time. Skills are tracked in an execution database showing what's currently running and what has completed—invaluable for debugging navigation or other long operations. See the [skill basics tutorial](../tutorials/skill_basics/tutorial.md) for an example of this. + +> [!WARNING] +> **Don't use both `@skill` and `@rpc` decorators on a single method** - The `@skill` wrapper can't be pickled for LCM transport. Use `@skill()` for agent tools, `@rpc` for module-to-module calls. + + +## See also + +- [The Skills API reference](../api/skills.md) + +### Related concepts + +* [Agents](agent.md) - LLM-based reasoning that invokes skills +* [Modules](modules.md) - The distributed actors that provide skills +* [Blueprints](blueprints.md) - Composing modules and skills into systems diff --git a/docs/concepts/transport.md b/docs/concepts/transport.md new file mode 100644 index 0000000000..0d706431be --- /dev/null +++ b/docs/concepts/transport.md @@ -0,0 +1,243 @@ +# Transport + +## Motivation + +Your robot modules might run on a laptop during development, then be split across a GPU server and edge device in production. +Transport abstracts this away: your module code stays the same, only the transport configuration changes. + + + + +```python +from dimos.core import Module, In, Out +from dimos.msgs.sensor_msgs import Image + +class SpatialMemory(Module): + color_image: In[Image] = None + + def start(self): + self.color_image.subscribe(self.process_image) + + def process_image(self, img: Image): + # This code is identical whether using: + # - pSHMTransport (local shared memory) + # - LCMTransport (network) + # - JpegLcmTransport (compressed network) + self.add_to_memory(img) +``` + + + +> [!NOTE] +> The core takeaway: modules can operate on any underlying transport. The same module code deploys across local shared memory, network protocols, or distributed clusters by changing just the transport configuration. + +## How does the Transport abstraction relate to other DimOS concepts? + +We've seen, at a high level, how Transport is the abstraction layer for message passing between modules in DimOS. But let's make this more concrete. Recall that when you define a Module, you can declare what sorts of data it consumes and what sorts of data it produces. In particular, you can declare what *streams* the Module has: `In[T]` for input and `Out[T]` for output, where `T` is the type variable for the data carried by the stream. + +```python +class NavigationSkillContainer(SkillModule): + color_image: In[Image] = None + odom: In[PoseStamped] = None + # ... +``` + +This, however, only specifies the informational needs and outputs of the module; it doesn't yet specify *how* such data is to be transported between modules. To put it another way, what concrete transport backends should be used for these streams? + +It is here that Transport enters the picture. There is, in the library, a fundamental `Transport` abstraction with methods like `publish`, `subscribe`, and `broadcast`, along with concrete Transport classes that realize that abstraction (more on this shortly). And as we'll see shortly, it's easy to change the choice of concrete Transports in the configuration. + + +## The default transports + +When you call `autoconnect()`, it automatically selects sensible transports for each stream. + + +- If the stream's type can serialize to an LCM (Lightweight Communications and Marshalling) type (if it has `lcm_encode` support), then **`LCMTransport`**, a publish-subscribe messaging system that uses UDP multicast, is selected. +- Otherwise, it's **`pLCMTransport`**: this is basically LCMTransport, but with *pickle-based* serialization. + +```python +from dimos.core.blueprints import autoconnect + +blueprint = autoconnect( + connection(), + spatial_memory(), + behavior_tree_navigator() +) +# Transport selection happens automatically +``` + + + +## When to customize + + +When should you override the Transport defaults? + +Here's a quick decision guide: + +```ascii +Is performance acceptable? +├── Yes → Keep defaults, you're done +└── No → Can processes share memory? (same host, not containerized separately) + ├── Yes → What kind of data? + │ ├── Camera frames, point clouds → pSHMTransport (memory-mapped, zero-copy) + │ ├── Images + memory constrained → JpegShmTransport + │ └── Other data → pSHMTransport + └── No (network or isolated containers) → What kind of data? + ├── Images → JpegLcmTransport + ├── Messages that can serialize to an LCM type → LCMTransport + └── Python objects without lcm_encode → pLCMTransport +``` + + + +## How to choose a specific transport + +Override specific transports using the `.transports()` method: + +```python +from dimos.core.transport import pSHMTransport, JpegLcmTransport +from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE +from dimos.msgs.sensor_msgs import Image + +# Optimize high-frequency camera stream +blueprint = autoconnect( + connection(), + spatial_memory(), +).transports({ + ("color_image", Image): pSHMTransport( + "/go2/color_image", + default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE + ), +}) +``` + + + +**Key points:** + +- The override key is the tuple `(stream_name, type)` +- The topic name must be specified (e.g., `"/go2/color_image"`) +- Can mix different transport types in one blueprint +- Last override wins if multiple applied + + + + +### Common patterns + +```python +# 1. Shared memory for local performance +standard_with_shm = standard.transports({ + ("color_image", Image): pSHMTransport( + "/go2/color_image", + default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE + ), +}) + +# 2. Compressed images over network +standard_with_jpeglcm = standard.transports({ + ("color_image", Image): JpegLcmTransport("/go2/color_image", Image), +}) + +# 3. LCM for external tool compatibility (Foxglove) +basic = autoconnect(...).transports({ + ("color_image", Image): LCMTransport("/go2/color_image", Image), +}) + +# 4. Compressed shared memory for multiple local consumers +standard_with_jpegshm = standard.transports({ + ("color_image", Image): JpegShmTransport("/go2/color_image", quality=75), +}) +``` + + + + + + +### Class hierarchy + +For reference, the transports form this inheritance structure: + +```ascii +Transport +└── PubSubTransport (topic-based) + ├── LCMTransport → JpegLcmTransport + ├── pLCMTransport, SHMTransport, pSHMTransport + └── JpegShmTransport, ZenohTransport +``` + +**Network-capable:** LCMTransport, pLCMTransport, JpegLcmTransport, ZenohTransport +**Local-only:** SHMTransport, pSHMTransport, JpegShmTransport + +> [!NOTE] +> JpegShmTransport extends PubSubTransport directly (not SHMTransport), so it doesn't share SHMTransport's buffer management. This is intentional—it uses its own JPEG-specific memory handling. + +--- + +## For advanced users + + +### Design principles + +**Minimal Interface** — The base Transport class just has three methods: `broadcast()`, `publish()`, `subscribe()`. + + +**Lazy Initialization** — All transports allocate resources only on first use: + +```python +# From transport.py:64-69 +def broadcast(self, _, msg) -> None: + if not self._started: + self.lcm.start() + self._started = True + self.lcm.publish(self.topic, msg) +``` + + + +**Shared Transport Instances** — Connections with the same `(name, type)` share transport objects, reducing memory and connection overhead. + + + +### Backpressure handling + +Transport provides automatic backpressure through its Observable integration. When subscribers can't keep up with producers (e.g., slow image processing), intermediate values are dropped to prevent unbounded memory growth while ensuring subscribers always see the latest data. + + + +You can also use RxPY operators for fine-grained control: + +```python +# Throttle high-frequency sensor data +stream.observable().pipe( + ops.sample(0.1), # Sample every 0.1s (10Hz) + ops.filter(lambda img: img.width > 640), # Only HD images + ops.buffer_with_time(1.0) # Batch per second +).subscribe(process_batch) +``` + + + +### Performance tuning + +Tuning options: + +- **Buffer size** — `pSHMTransport("/cam", default_capacity=6_220_800)` + Max payload in bytes. Use `DEFAULT_CAPACITY_COLOR_IMAGE` constant for 1080p RGB. + +- **JPEG quality** — `JpegShmTransport("/cam", quality=75)` + Range 1-100. Lower = smaller size, more artifacts. + +--- + +## See also + +- [Modules](./modules.md) — How modules declare and use streams +- [Blueprints](./blueprints.md) — How `autoconnect()` wires streams and selects transports diff --git a/docs/development.md b/docs/development.md new file mode 100644 index 0000000000..f1f31fd77e --- /dev/null +++ b/docs/development.md @@ -0,0 +1,231 @@ +# Development Environment Guide + +## Approach + +We optimise for flexibility—if your favourite editor is **notepad.exe**, you’re good to go. Everything below is tooling for convenience. + +--- + +## Dev Containers + +Dev containers give us a reproducible, container-based workspace identical to CI. + +### Why use them? + +* Consistent toolchain across all OSs. +* Unified formatting, linting and type-checking. +* Zero host-level dependencies (apart from Docker). + +### IDE quick start + +Install the *Dev Containers* plug-in for VS Code, Cursor, or your IDE of choice (you’ll likely be prompted automatically when you open our repo). + +### Shell only quick start + +Terminal within your IDE should use devcontainer transparently given you installed the plugin, but in case you want to run our shell without an IDE, you can use `./bin/dev` +(it depends on npm/node being installed) + +```sh +./bin/dev +devcontainer CLI (https://github.com/devcontainers/cli) not found. Install into repo root? (y/n): y + +added 1 package, and audited 2 packages in 8s +found 0 vulnerabilities + +[1 ms] @devcontainers/cli 0.76.0. Node.js v20.19.0. linux 6.12.27-amd64 x64. +[4838 ms] Start: Run: docker start f0355b6574d9bd277d6eb613e1dc32e3bc18e7493e5b170e335d0e403578bcdb +[5299 ms] f0355b6574d9bd277d6eb613e1dc32e3bc18e7493e5b170e335d0e403578bcdb +{"outcome":"success","containerId":"f0355b6574d9bd277d6eb613e1dc32e3bc18e7493e5b170e335d0e403578bcdb","remoteUser":"root","remoteWorkspaceFolder":"/workspaces/dimos"} + + ██████╗ ██╗███╗ ███╗███████╗███╗ ██╗███████╗██╗ ██████╗ ███╗ ██╗ █████╗ ██╗ + ██╔══██╗██║████╗ ████║██╔════╝████╗ ██║██╔════╝██║██╔═══██╗████╗ ██║██╔══██╗██║ + ██║ ██║██║██╔████╔██║█████╗ ██╔██╗ ██║███████╗██║██║ ██║██╔██╗ ██║███████║██║ + ██║ ██║██║██║╚██╔╝██║██╔══╝ ██║╚██╗██║╚════██║██║██║ ██║██║╚██╗██║██╔══██║██║ + ██████╔╝██║██║ ╚═╝ ██║███████╗██║ ╚████║███████║██║╚██████╔╝██║ ╚████║██║ ██║███████╗ + ╚═════╝ ╚═╝╚═╝ ╚═╝╚══════╝╚═╝ ╚═══╝╚══════╝╚═╝ ╚═════╝ ╚═╝ ╚═══╝╚═╝ ╚═╝╚══════╝ + + v_unknown:unknown | Wed May 28 09:23:33 PM UTC 2025 + +root@dimos:/workspaces/dimos # +``` + +The script will: + +* Offer to npm install `@devcontainers/cli` locally (if not available globally) on first run. +* Pull `ghcr.io/dimensionalos/dev:dev` if not present (external contributors: we plan to mirror to Docker Hub). + +You’ll land in the workspace as **root** with all project tooling available. + +## Pre-Commit Hooks + +We use [pre-commit](https://pre-commit.com) (config in `.pre-commit-config.yaml`) to enforce formatting, licence headers, EOLs, LFS checks, etc. Hooks run in **milliseconds**. +Hooks also run in CI; any auto-fixes are committed back to your PR, so local installation is optional — but gives faster feedback. + +```sh +CRLF end-lines checker...................................................Passed +CRLF end-lines remover...................................................Passed +Insert license in comments...............................................Passed +ruff format..............................................................Passed +check for case conflicts.................................................Passed +check json...............................................................Passed +check toml...............................................................Passed +check yaml...............................................................Passed +format json..............................................................Passed +LFS data.................................................................Passed + +``` + +Given your editor uses ruff via devcontainers (which it should) actual auto-commit hook won't ever reformat your code - IDE will have already done this. + +### Running hooks manually + +Given your editor uses git via devcontainers (which it should) auto-commit hooks will run automatically, this is in case you want to run them manually. + +Inside the dev container (Your IDE will likely run this transparently for each commit if using devcontainer plugin): + +```sh +pre-commit run --all-files +``` + +### Installing pre-commit on your host + +```sh +apt install pre-commit # or brew install pre-commit +pre-commit install # install git hook +pre-commit run --all-files +``` + +--- + +## Testing + +All tests run with **pytest** inside the dev container, ensuring local results match CI. + +### Basic usage + +```sh +./bin/dev # start container +pytest # run all tests beneath the current directory +``` + +Depending on which dir you are in, only tests from that dir will run, which is convinient when developing - you can frequently validate your feature tree. + +Your vibe coding agent will know to use these tests via the devcontainer so it can validate it's work. + +#### Useful options + +| Purpose | Command | +| -------------------------- | ----------------------- | +| Show `print()` output | `pytest -s` | +| Filter by name substring | `pytest -k ""` | +| Run tests with a given tag | `pytest -m ` | + +We use tags for special tests, like `vis` or `tool` for things that aren't meant to be ran in CI and when casually developing, something that requires hardware or visual inspection (pointcloud merging vis etc) + +You can enable a tag by selecting -m - these are configured in `./pyproject.toml` + +```sh +root@dimos:/workspaces/dimos/dimos # pytest -sm vis -k my_visualization +... +``` + +Classic development run within a subtree: + +```sh +./bin/dev + +... container init ... + +root@dimos:/workspaces/dimos # cd dimos/robot/unitree_webrtc/ +root@dimos:/workspaces/dimos/dimos/robot/unitree_webrtc # pytest +collected 27 items / 22 deselected / 5 selected + +type/test_map.py::test_robot_mapping PASSED +type/test_timeseries.py::test_repr PASSED +type/test_timeseries.py::test_equals PASSED +type/test_timeseries.py::test_range PASSED +type/test_timeseries.py::test_duration PASSED + +``` + +Showing prints: + +```sh +root@dimos:/workspaces/dimos/dimos/robot/unitree_webrtc/type # pytest -s test_odometry.py +test_odometry.py::test_odometry_conversion_and_count Odom ts(2025-05-30 13:52:03) pos(→ Vector Vector([0.432199 0.108042 0.316589])), rot(↑ Vector Vector([ 7.7200000e-04 -9.1280000e-03 3.006 +8621e+00])) yaw(172.3°) +Odom ts(2025-05-30 13:52:03) pos(→ Vector Vector([0.433629 0.105965 0.316143])), rot(↑ Vector Vector([ 0.003814 -0.006436 2.99591235])) yaw(171.7°) +Odom ts(2025-05-30 13:52:04) pos(→ Vector Vector([0.434459 0.104739 0.314794])), rot(↗ Vector Vector([ 0.005558 -0.004183 3.00068456])) yaw(171.9°) +Odom ts(2025-05-30 13:52:04) pos(→ Vector Vector([0.435621 0.101699 0.315852])), rot(↑ Vector Vector([ 0.005391 -0.006002 3.00246893])) yaw(172.0°) +Odom ts(2025-05-30 13:52:04) pos(→ Vector Vector([0.436457 0.09857 0.315254])), rot(↑ Vector Vector([ 0.003358 -0.006916 3.00347172])) yaw(172.1°) +Odom ts(2025-05-30 13:52:04) pos(→ Vector Vector([0.435535 0.097022 0.314399])), rot(↑ Vector Vector([ 1.88300000e-03 -8.17800000e-03 3.00573432e+00])) yaw(172.2°) +Odom ts(2025-05-30 13:52:04) pos(→ Vector Vector([0.433739 0.097553 0.313479])), rot(↑ Vector Vector([ 8.10000000e-05 -8.71700000e-03 3.00729616e+00])) yaw(172.3°) +Odom ts(2025-05-30 13:52:04) pos(→ Vector Vector([0.430924 0.09859 0.31322 ])), rot(↑ Vector Vector([ 1.84000000e-04 -9.68700000e-03 3.00945623e+00])) yaw(172.4°) +... etc +``` + +--- + +## Cheatsheet + +| Action | Command | +| --------------------------- | ---------------------------- | +| Enter dev container | `./bin/dev` | +| Run all pre-commit hooks | `pre-commit run --all-files` | +| Install hooks in local repo | `pre-commit install` | +| Run tests in current path | `pytest` | +| Filter tests by name | `pytest -k ""` | +| Enable stdout in tests | `pytest -s` | +| Run tagged tests | `pytest -m ` | + +## Docs + +### Installation + +If you are using the devcontainer, you don't need to install anything else—the dependencies for the docs site are already included. + +### Local Development Server + +Start a local server with hot reload: + +```bash +mkdocs serve +``` + +Then open in your browser. + +### Build Static Site + +Build the static documentation site: + +```bash +mkdocs build +``` + +The output (which includes the various `llm.tx`es) will be in the `site/` directory. + + + +### Embedding Marimo Notebooks + +We embed marimo notebooks in the docs by **pre-rendering to HTML** and using an iframe, rather than using [mkdocs-marimo](https://github.com/marimo-team/mkdocs-marimo)'s native embedding. + +**Why?** mkdocs-marimo uses Pyodide/WASM to run notebooks in the browser. This only works for [packages available in Pyodide](https://pyodide.org/en/stable/usage/packages-in-pyodide.html) or pure Python wheels. Getting custom packages like `dimos` to work appears non-trivial—see [marimo#5488](https://github.com/marimo-team/marimo/issues/5488) and [marimo#5535](https://github.com/marimo-team/marimo/issues/5535). Pre-rendering sidesteps these complications. + +**How it works:** + +1. `docs/hooks.py` registers an mkdocs `on_pre_build` hook +2. The hook runs `marimo export html` on each notebook listed in `MARIMO_NOTEBOOKS` +3. The markdown file embeds the rendered HTML via iframe + +**Adding a new notebook:** + +1. Add an entry to `MARIMO_NOTEBOOKS` in `docs/hooks.py` +2. Create a markdown file that embeds the output: + + ```html + + + ``` + +3. The `height="800px"` attribute is a fallback; the inline style uses `clamp()` with `dvh` for responsive sizing diff --git a/docs/hooks.py b/docs/hooks.py new file mode 100644 index 0000000000..21bcc39e68 --- /dev/null +++ b/docs/hooks.py @@ -0,0 +1,236 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MkDocs hooks for pre-building marimo notebooks. + +See docs/development.md "Embedding Marimo Notebooks" for why we use this approach +(dimos is not available in Pyodide/WASM, so mkdocs-marimo's native embedding won't work). + +Why we kill the marimo export process instead of letting it exit gracefully: + +The tutorial notebooks start a Dask cluster which registers a SIGTERM signal handler. +When dimos.stop() is called, the handler runs close_all() and then sys.exit(0). +However, marimo's runtime catches SystemExit exceptions, preventing the process from +actually exiting. The process hangs indefinitely waiting for... something in marimo. + +This is a marimo-specific issue - the same notebook code exits cleanly when run as +a regular Python script. Since we can't change marimo's exception handling, we poll +for the output file to be written (which happens early, before shutdown) and then +kill the process once the file is ready. +""" + +import concurrent.futures +from contextlib import contextmanager +from pathlib import Path +import subprocess +import time + +import psutil + + +@contextmanager +def _managed_process(*args, name: str = "process", **kwargs): + """Context manager that ensures process tree cleanup on exit. + + Uses psutil to recursively kill all child processes. This is cross-platform + (works on Windows, Linux, macOS) unlike os.killpg which is Unix-only. + + Iteratively kills children since processes like Dask workers may spawn new + children during shutdown. Uses wait_procs() to ensure children are actually + dead before proceeding. + + After killing, captures and prints any stdout/stderr for debugging. + + Inspired by https://gist.github.com/jizhilong/6687481#gistcomment-3057122 + """ + proc = subprocess.Popen(*args, **kwargs) + try: + yield proc + finally: + try: + parent = psutil.Process(proc.pid) + + # Iteratively kill children - Dask may spawn more during shutdown + for _ in range(3): + children = parent.children(recursive=True) + if not children: + break + for child in children: + try: + child.kill() + except psutil.NoSuchProcess: + pass + # Wait for children to actually die + psutil.wait_procs(children, timeout=2) + + parent.kill() + parent.wait(timeout=3) + + except psutil.NoSuchProcess: + pass # Already dead + except psutil.TimeoutExpired: + pass # Continue anyway - proc.wait() below will handle it + + proc.wait() + # Now safe to read pipes - process is dead + _print_process_output(name, proc) + + +def _print_process_output(name: str, proc: subprocess.Popen) -> None: + """Print captured stdout/stderr from a process for debugging.""" + stdout = proc.stdout.read().decode() if proc.stdout else "" + stderr = proc.stderr.read().decode() if proc.stderr else "" + if stdout.strip(): + print(f" [{name}] stdout:\n{stdout.rstrip()}") + if stderr.strip(): + print(f" [{name}] stderr:\n{stderr.rstrip()}") + + +# Marimo notebooks to export as HTML with outputs +MARIMO_NOTEBOOKS = [ + { + "source": "docs/tutorials/skill_basics/tutorial.py", + "output": "docs/tutorials/skill_basics/tutorial_rendered.html", + }, + { + "source": "docs/tutorials/skill_with_agent/tutorial.py", + "output": "docs/tutorials/skill_with_agent/tutorial_rendered.html", + }, + { + "source": "docs/tutorials/multi_agent/tutorial.py", + "output": "docs/tutorials/multi_agent/tutorial_rendered.html", + }, +] + + +def _export_notebook(source: Path, output: Path, timeout: int = 180) -> bool: + """Export a marimo notebook, killing the process once the file is ready. + + The notebooks use Dask which hangs on shutdown, but the HTML is generated + within the first few seconds. This function polls for the output file and + kills the process early once it's ready, rather than waiting for the full timeout. + + Returns True if the export succeeded, False otherwise. + """ + name = source.stem # Short name for log messages + cmd = ["marimo", "export", "html", str(source), "-o", str(output), "--force", "--no-sandbox"] + print(f" [{name}] Running: {' '.join(cmd)}") + print(f" [{name}] Exporting", end="", flush=True) + + # Delete old output so we can detect when new file is written + # (mtime comparison is unreliable due to filesystem timestamp granularity) + if output.exists(): + output.unlink() + + start = time.time() + poll_interval = 0.5 # Check every 500ms + min_file_size = 1000 # HTML should be at least 1KB + last_size = 0 + stable_count = 0 + stable_threshold = 2 # Require 2 consecutive identical sizes (write_text isn't atomic) + last_dot_time = start + dot_interval = 5 # Print a dot every 5 seconds to show activity + + with _managed_process( + cmd, + name=name, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) as proc: + while time.time() - start < timeout: + elapsed = time.time() - start + + # Check if process finished naturally + if proc.poll() is not None: + if proc.returncode == 0: + print(f" done ({elapsed:.1f}s)") + # Output printed by context manager's finally block + return True + else: + print(" failed!") + print(f" [{name}] Export failed (exit code {proc.returncode})") + # Output printed by context manager's finally block + return False + + # Print dots to show activity (works well with captured output) + if elapsed - (last_dot_time - start) >= dot_interval: + print(".", end="", flush=True) + last_dot_time = time.time() + + # Check if output file is ready (exists, stable size, and has content) + if output.exists(): + current_size = output.stat().st_size + if current_size > min_file_size and current_size == last_size: + stable_count += 1 + if stable_count >= stable_threshold: + # File write is complete - context manager kills process & prints output + print(f" done ({elapsed:.1f}s, {current_size // 1024}KB)") + return True + else: + stable_count = 0 + last_size = current_size + + time.sleep(poll_interval) + + # Timeout reached - context manager kills process & prints output + elapsed = time.time() - start + if output.exists() and output.stat().st_size > min_file_size: + print( + f" done ({elapsed:.1f}s, {output.stat().st_size // 1024}KB) [timeout but file generated]" + ) + return True + else: + print(f" timeout ({int(elapsed)}s) - file not generated") + return False + + +def on_pre_build(config): + """Export marimo notebooks to HTML before mkdocs build. + + Exports run in parallel using ThreadPoolExecutor. This is safe because + each notebook writes to a separate output file and psutil operations are thread-safe. + """ + to_export: list[tuple[Path, Path]] = [] + + for notebook in MARIMO_NOTEBOOKS: + source = Path(notebook["source"]) + output = Path(notebook["output"]) + + if not source.exists(): + print(f"Warning: Notebook {source} not found, skipping") + continue + + # Skip if output exists and is newer than source + if output.exists() and output.stat().st_mtime > source.stat().st_mtime: + print(f"Skipping {source} (output is up to date)") + continue + + to_export.append((source, output)) + + if not to_export: + return + + print(f"Exporting {len(to_export)} notebook(s) in parallel...") + + with concurrent.futures.ThreadPoolExecutor(max_workers=len(to_export)) as executor: + futures = {executor.submit(_export_notebook, src, out): src for src, out in to_export} + + for future in concurrent.futures.as_completed(futures): + source = futures[future] + try: + success = future.result() + if not success: + print(f"Warning: Export failed for {source}") + except Exception as e: + print(f"Error exporting {source}: {e}") diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 0000000000..9c5d3bccac --- /dev/null +++ b/docs/index.md @@ -0,0 +1,55 @@ +# The Dimensional Framework + +*The universal framework for AI-native generalist robotics* + +## What is Dimensional? + +**TODO:** Stash will be writing the intro to DimOS + +## Key Features + +**TODO: For dimos team to fill in** + +## Ready to jump in? + +
+ +- :material-rocket-launch: **Quickstart** + + --- + + Get up and running with DimOS in minutes. + + [:octicons-arrow-right-24: Get started](quickstart.md) + +- :material-school: **Tutorials** + + --- + + Hands-on tutorials. + + [:octicons-arrow-right-24: View tutorials](tutorials/index.md) + +- :material-school: **Concepts** + + --- + + Guides to DimOS's key concepts. + + [:octicons-arrow-right-24: View tutorials](concepts/index.md) + +- :material-code-braces: **API Reference** + + --- + + Reference for DimOS's public API. + + [:octicons-arrow-right-24: Explore API](api/index.md) + +
+ +## Community + +- [GitHub](https://github.com/dimensionalOS/dimos) - Source code and issues +- [Discord](https://discord.gg/74U8guVj8q) +- [Email](mailto:build@dimensionalOS.com) - Contact the team diff --git a/docs/jetson.MD b/docs/jetson.MD new file mode 100644 index 0000000000..a4d06e3255 --- /dev/null +++ b/docs/jetson.MD @@ -0,0 +1,72 @@ +# DimOS Jetson Setup Instructions +Tested on Jetpack 6.2, CUDA 12.6 + +## Required system dependencies +`sudo apt install portaudio19-dev python3-pyaudio` + +## Installing cuSPARSELt +https://ninjalabo.ai/blogs/jetson_pytorch.html + +```bash +wget https://developer.download.nvidia.com/compute/cusparselt/0.7.0/local_installers/cusparselt-local-tegra-repo-ubuntu2204-0.7.0_1.0-1_arm64.deb +sudo dpkg -i cusparselt-local-tegra-repo-ubuntu2204-0.7.0_1.0-1_arm64.deb +sudo cp /var/cusparselt-local-tegra-repo-ubuntu2204-0.7.0/cusparselt-*-keyring.gpg /usr/share/keyrings/ +sudo apt-get update +sudo apt-get install libcusparselt0 libcusparselt-dev +ldconfig +``` +## Install Torch and Torchvision wheels + +Enter virtualenv +```bash +python3 -m venv venv +source venv/bin/activate +``` + +Wheels for jp6/cu126 +https://pypi.jetson-ai-lab.io/jp6/cu126 + +Check compatibility: +https://docs.nvidia.com/deeplearning/frameworks/install-pytorch-jetson-platform-release-notes/pytorch-jetson-rel.html + +### Working torch wheel tested on Jetpack 6.2, CUDA 12.6 +`pip install --no-cache https://developer.download.nvidia.com/compute/redist/jp/v61/pytorch/torch-2.5.0a0+872d972e41.nv24.08.17622132-cp310-cp310-linux_aarch64.whl` + +### Install torchvision from source: +```bash +# Set version by checking above torchvision<-->torch compatibility + +# We use 0.20.0 +export VERSION=20 + +sudo apt-get install libjpeg-dev zlib1g-dev libpython3-dev libopenblas-dev libavcodec-dev libavformat-dev libswscale-dev +git clone --branch release/0.$VERSION https://github.com/pytorch/vision torchvision +cd torchvision +export BUILD_VERSION=0.$VERSION.0 +python3 setup.py install --user # remove --user if installing in virtualenv +``` + +### Verify success: +```bash +$ python3 +import torch +print(torch.__version__) +print('CUDA available: ' + str(torch.cuda.is_available())) # Should be True +print('cuDNN version: ' + str(torch.backends.cudnn.version())) +a = torch.cuda.FloatTensor(2).zero_() +print('Tensor a = ' + str(a)) +b = torch.randn(2).cuda() +print('Tensor b = ' + str(b)) +c = a + b +print('Tensor c = ' + str(c)) + +$ python3 +import torchvision +print(torchvision.__version__) +``` + +## Install Onnxruntime-gpu + +Find pre-build wheels here for your specific JP/CUDA version: https://pypi.jetson-ai-lab.io/jp6 + +`pip install https://pypi.jetson-ai-lab.io/jp6/cu126/+f/4eb/e6a8902dc7708/onnxruntime_gpu-1.23.0-cp310-cp310-linux_aarch64.whl#sha256=4ebe6a8902dc7708434b2e1541b3fe629ebf434e16ab5537d1d6a622b42c622b` diff --git a/docs/modules.md b/docs/modules.md new file mode 100644 index 0000000000..9cdbf586ac --- /dev/null +++ b/docs/modules.md @@ -0,0 +1,165 @@ +# Dimensional Modules + +The DimOS Module system enables distributed, multiprocess robotics applications using Dask for compute distribution and LCM (Lightweight Communications and Marshalling) for high-performance IPC. + +## Core Concepts + +### 1. Module Definition +Modules are Python classes that inherit from `dimos.core.Module` and define inputs, outputs, and RPC methods: + +```python +from dimos.core import Module, In, Out, rpc +from dimos.msgs.geometry_msgs import Vector3 + +class MyModule(Module): + # Declare inputs/outputs as class attributes initialized to None + data_in: In[Vector3] = None + data_out: Out[Vector3] = None + + def __init__(): + # Call parent Module init + super().__init__() + + @rpc + def remote_method(self, param): + """Methods decorated with @rpc can be called remotely""" + return param * 2 +``` + +### 2. Module Deployment +Modules are deployed across Dask workers using the `dimos.deploy()` method: + +```python +from dimos import core + +# Start Dask cluster with N workers +dimos = core.start(4) + +# Deploying modules allows for passing initialization parameters. +# In this case param1 and param2 are passed into Module init +module = dimos.deploy(Module, param1="value1", param2=123) +``` + +### 3. Stream Connections +Modules communicate via reactive streams using LCM transport: + +```python +# Configure LCM transport for outputs +module1.data_out.transport = core.LCMTransport("/topic_name", MessageType) + +# Connect module inputs to outputs +module2.data_in.connect(module1.data_out) + +# Access the underlying Observable stream +stream = module1.data_out.observable() +stream.subscribe(lambda msg: print(f"Received: {msg}")) +``` + +### 4. Module Lifecycle +```python +# Start modules to begin processing +module.start() # Calls the @rpc start() method if defined + +# Inspect module I/O configuration +print(module.io().result()) # Shows inputs, outputs, and RPC methods + +# Clean shutdown +dimos.shutdown() +``` + +## Real-World Example: Robot Control System + +```python +# Connection module wraps robot hardware/simulation +connection = dimos.deploy(ConnectionModule, ip=robot_ip) +connection.lidar.transport = core.LCMTransport("/lidar", LidarMessage) +connection.video.transport = core.LCMTransport("/video", Image) + +# Perception module processes sensor data +perception = dimos.deploy(PersonTrackingStream, camera_intrinsics=[...]) +perception.video.connect(connection.video) +perception.tracking_data.transport = core.pLCMTransport("/person_tracking") + +# Start processing +connection.start() +perception.start() + +# Enable tracking via RPC +perception.enable_tracking() + +# Get latest tracking data +data = perception.get_tracking_data() +``` + +## LCM Transport Configuration + +```python +# Standard LCM transport for simple types like lidar +connection.lidar.transport = core.LCMTransport("/lidar", LidarMessage) + +# Pickle-based transport for complex Python objects / dictionaries +connection.tracking_data.transport = core.pLCMTransport("/person_tracking") + +# Auto-configure LCM system buffers (required in containers) +from dimos.protocol import pubsub +pubsub.lcm.autoconf() +``` + +This architecture enables building complex robotic systems as composable, distributed modules that communicate efficiently via streams and RPC, scaling from single machines to clusters. + +# Dimensional Install +## Python Installation (Ubuntu 22.04) + +```bash +sudo apt install python3-venv + +# Clone the repository (dev branch, no submodules) +git clone -b dev https://github.com/dimensionalOS/dimos.git +cd dimos + +# Create and activate virtual environment +python3 -m venv venv +source venv/bin/activate + +sudo apt install portaudio19-dev python3-pyaudio + +# Install torch and torchvision if not already installed +# Example CUDA 11.7, Pytorch 2.0.1 (replace with your required pytorch version if different) +pip install torch==2.0.1 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 +``` + +### Install dependencies +```bash +# CPU only (reccomended to attempt first) +pip install .[cpu,dev] + +# CUDA install +pip install .[cuda,dev] + +# Copy and configure environment variables +cp default.env .env +``` + +### Test install +```bash +# Run standard tests +pytest -s dimos/ + +# Test modules functionality +pytest -s -m module dimos/ + +# Test LCM communication +pytest -s -m lcm dimos/ +``` + +# Unitree Go2 Quickstart + +To quickly test the modules system, you can run the Unitree Go2 multiprocess example directly: + +```bash +# Make sure you have the required environment variables set +export ROBOT_IP= + +# Run the multiprocess Unitree Go2 example +python dimos/robot/unitree_webrtc/multiprocess/unitree_go2.py +``` diff --git a/docs/modules_CN.md b/docs/modules_CN.md new file mode 100644 index 0000000000..89e16c7112 --- /dev/null +++ b/docs/modules_CN.md @@ -0,0 +1,188 @@ +# Dimensional 模块系统 + +DimOS 模块系统使用 Dask 进行计算分布和 LCM(轻量级通信和编组)进行高性能进程间通信,实现分布式、多进程的机器人应用。 + +## 核心概念 + +### 1. 模块定义 +模块是继承自 `dimos.core.Module` 的 Python 类,定义输入、输出和 RPC 方法: + +```python +from dimos.core import Module, In, Out, rpc +from dimos.msgs.geometry_msgs import Vector3 + +class MyModule(Module): # ROS Node + # 将输入/输出声明为初始化为 None 的类属性 + data_in: In[Vector3] = None # ROS Subscriber + data_out: Out[Vector3] = None # ROS Publisher + + def __init__(): + # 调用父类 Module 初始化 + super().__init__() + + @rpc + def remote_method(self, param): + """使用 @rpc 装饰的方法可以远程调用""" + return param * 2 +``` + +### 2. 模块部署 +使用 `dimos.deploy()` 方法在 Dask 工作进程中部署模块: + +```python +from dimos import core + +# 启动具有 N 个工作进程的 Dask 集群 +dimos = core.start(4) + +# 部署模块时可以传递初始化参数 +# 在这种情况下,param1 和 param2 被传递到模块初始化中 +module = dimos.deploy(Module, param1="value1", param2=123) +``` + +### 3. 流连接 +模块通过使用 LCM 传输的响应式流进行通信: + +```python +# 为输出配置 LCM 传输 +module1.data_out.transport = core.LCMTransport("/topic_name", MessageType) + +# 将模块输入连接到输出 +module2.data_in.connect(module1.data_out) + +# 访问底层的 Observable 流 +stream = module1.data_out.observable() +stream.subscribe(lambda msg: print(f"接收到: {msg}")) +``` + +### 4. 模块生命周期 +```python +# 启动模块以开始处理 +module.start() # 如果定义了 @rpc start() 方法,则调用它 + +# 检查模块 I/O 配置 +print(module.io().result()) # 显示输入、输出和 RPC 方法 + +# 优雅关闭 +dimos.shutdown() +``` + +## 实际示例:机器人控制系统 + +```python +# 连接模块封装机器人硬件/仿真 +connection = dimos.deploy(ConnectionModule, ip=robot_ip) +connection.lidar.transport = core.LCMTransport("/lidar", LidarMessage) +connection.video.transport = core.LCMTransport("/video", Image) + +# 感知模块处理传感器数据 +perception = dimos.deploy(PersonTrackingStream, camera_intrinsics=[...]) +perception.video.connect(connection.video) +perception.tracking_data.transport = core.pLCMTransport("/person_tracking") + +# 开始处理 +connection.start() +perception.start() + +# 通过 RPC 启用跟踪 +perception.enable_tracking() + +# 获取最新的跟踪数据 +data = perception.get_tracking_data() +``` + +## LCM 传输配置 + +```python +# 用于简单类型(如激光雷达)的标准 LCM 传输 +connection.lidar.transport = core.LCMTransport("/lidar", LidarMessage) + +# 用于复杂 Python 对象/字典的基于 pickle 的传输 +connection.tracking_data.transport = core.pLCMTransport("/person_tracking") + +# 自动配置 LCM 系统缓冲区(在容器中必需) +from dimos.protocol import pubsub +pubsub.lcm.autoconf() +``` + +这种架构使得能够将复杂的机器人系统构建为可组合的分布式模块,这些模块通过流和 RPC 高效通信,从单机扩展到集群。 + +# Dimensional 安装指南 +## Python 安装(Ubuntu 22.04) + +```bash +sudo apt install python3-venv + +# 克隆仓库(dev 分支,无子模块) +git clone -b dev https://github.com/dimensionalOS/dimos.git +cd dimos + +# 创建并激活虚拟环境 +python3 -m venv venv +source venv/bin/activate + +sudo apt install portaudio19-dev python3-pyaudio + +# 如果尚未安装,请安装 torch 和 torchvision +# 示例 CUDA 11.7,Pytorch 2.0.1(如果需要不同的 pytorch 版本,请替换) +pip install torch==2.0.1 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 +``` + +### 安装依赖 +```bash +# 仅 CPU(建议首先尝试) +pip install .[cpu,dev] + +# CUDA 安装 +pip install .[cuda,dev] + +# 复制并配置环境变量 +cp default.env .env +``` + +### 测试安装 +```bash +# 运行标准测试 +pytest -s dimos/ + +# 测试模块功能 +pytest -s -m module dimos/ + +# 测试 LCM 通信 +pytest -s -m lcm dimos/ +``` + +# Unitree Go2 快速开始 + +要快速测试模块系统,您可以直接运行 Unitree Go2 多进程示例: + +```bash +# 确保设置了所需的环境变量 +export ROBOT_IP= + +# 运行多进程 Unitree Go2 示例 +python dimos/robot/unitree_webrtc/multiprocess/unitree_go2.py +``` + +## 模块系统的高级特性 + +### 分布式计算 +DimOS 模块系统建立在 Dask 之上,提供了强大的分布式计算能力: + +- **自动负载均衡**:模块自动分布在可用的工作进程中 +- **容错性**:如果工作进程失败,模块可以在其他工作进程上重新启动 +- **可扩展性**:从单机到集群的无缝扩展 + +### 响应式编程模型 +使用 RxPY 实现的响应式流提供了: + +- **异步处理**:非阻塞的数据流处理 +- **背压处理**:自动管理快速生产者和慢速消费者 +- **操作符链**:使用 map、filter、merge 等操作符进行流转换 + +### 性能优化 +LCM 传输针对机器人应用进行了优化: + +- **零拷贝**:大型消息的高效内存使用 +- **低延迟**:微秒级的消息传递 +- **多播支持**:一对多的高效通信 diff --git a/docs/quickstart.md b/docs/quickstart.md new file mode 100644 index 0000000000..e94278f6f6 --- /dev/null +++ b/docs/quickstart.md @@ -0,0 +1,138 @@ +# Quickstart + +DimOS is a modular framework for building agentive robots. In this quickstart, you'll learn the basics of DimOS by building an LLM agent that can make greetings. + +## Installation + +### Requirements + +- Python 3.10+ +- OpenAI API key in environment + +### Install DimOS + +```python +# TODO: Ideally, when this is released, this should be as simple as +# pip install dimos +``` + + +## Define a skill + +Suppose you have a robot with a speaker. + +```python +from dimos.core.skill_module import SkillModule +from dimos.core.core import rpc +from dimos.protocol.skill.skill import skill + +# See the Skills concept guide for more on SkillModule +class Robot(SkillModule): + rpc_calls = [] + + # In a real setting, there would also be things like a ConnectionModule + # for the robot platform you are using + @rpc + def speak(self, text: str) -> str: + print(f"[Robot] {text}") + return f"SPEAK: {text}" +``` + +How can we wire up this `speak` capability to an LLM agent -- how can we go from this to an agentic robot that can make greetings by using that on-board speaker? + +Answer: make a *skill* -- a method on a `Module` that's decorated with `@skill`, so that it gets turned into *tools* that an agent can call. + +```python +class Greeter(SkillModule): + rpc_calls = ["Robot.speak"] # Declare dependency + + @skill() + def greet(self, name: str = "friend") -> str: + '''Greet someone by name.''' + self.get_rpc_calls("Robot.speak")(f"Hello, {name}!") + return f"Greeted {name}" +``` + +Notice that `Greeter` doesn't import `Robot` directly. Instead, it declares a dependency in `rpc_calls`, and the framework wires them together at runtime. + +## Wire up and build the modules + +Now we can call `llm_agent` to get an LLM agent module, combine it with our `Robot` and `Greeter` modules to get a [*blueprint*](concepts/blueprints.md) for the whole system, and then build it. + +```python +from dotenv import load_dotenv + +load_dotenv() + +from dimos.agents2.agent import LlmAgent, llm_agent +from dimos.core.blueprints import autoconnect + +dimos = ( + autoconnect( + Robot.blueprint(), + Greeter.blueprint(), + llm_agent(system_prompt="You're a friendly robot. Use greet when asked to say hello."), + ) + .global_config(n_dask_workers=1) + .build() +) + +print("System running!") +``` + +``` {title="Output"} +deployed: Robot-f970968e-... @ worker 0 +deployed: Greeter-fe23b94c-... @ worker 0 +deployed: LlmAgent-dc45564b-... @ worker 0 +System running! +``` + +As part of this process, the blueprint system matches dependencies (e.g., `Greeter`'s need for `Robot.speak`) and converts `Greeter`'s `greet` skill to a tool for the LLM agent. + +The system is now running. For long-running applications, you'd call `dimos.loop()` to keep it alive until Ctrl+C. + +## Say hi to our agent + +Time to say hi to our agent (in a real system, the robot's greetings would then be piped through the speakers): + +```python +agent = dimos.get_instance(LlmAgent) +print(agent.query("Hi there!")) +``` + +``` {title="Output"} +Hello! How are you doing today? +``` + +> [!NOTE] +> Exactly what greeting the LLM will make will, of course, differ across runs. + +```python +print(agent.query("Can you greet Alice as well?")) +``` + +``` {title="Output"} +[Robot] Hello, Alice! +Hello to Alice! +``` + +You now have a robot that you can ask -- in ordinary English -- for greetings! + +## What you learned + +You've seen the core DimOS pattern: define skills in modules, wire them together with the blueprint system, and let an LLM agent handle natural language requests. + +## Next steps + +### Tutorials + +- [Build your first skill](tutorials/skill_basics/tutorial.md): A tutorial that explains how to build a skill -- and how the blueprint system works -- in more detail +- [Equip an agent with skills](tutorials/skill_with_agent/tutorial.md) +- [Build a multi-agent RobotButler](tutorials/multi_agent/tutorial.md): Build a multi-agent RoboButler system, where a planner agent coordinates specialist subagents. + +### Concept guides + +- [Blueprints](concepts/blueprints.md) +- [Agents](concepts/agent.md) +- [Modules](concepts/modules.md) +- [Transport](concepts/transport.md) diff --git a/docs/running_without_devcontainer.md b/docs/running_without_devcontainer.md new file mode 100644 index 0000000000..d06785e359 --- /dev/null +++ b/docs/running_without_devcontainer.md @@ -0,0 +1,21 @@ +install nix, + +https://nixos.wiki/wiki/Nix_Installation_Guide +```sh +sudo install -d -m755 -o $(id -u) -g $(id -g) /nix +curl -L https://nixos.org/nix/install | sh +``` + +install direnv +https://direnv.net/ +```sh +apt-get install direnv +echo 'eval "$(direnv hook bash)"' >> ~/.bashrc +``` + +allow direnv in dimos will take a bit to pull the packages, +from that point on your env is standardized +```sh +cd dimos +direnv allow +``` diff --git a/docs/testing_stream_reply.md b/docs/testing_stream_reply.md new file mode 100644 index 0000000000..e3189bb5e8 --- /dev/null +++ b/docs/testing_stream_reply.md @@ -0,0 +1,174 @@ +# Sensor Replay & Storage Toolkit + +A lightweight framework for **recording, storing, and replaying binary data streams for automated tests**. It keeps your repository small (data lives in Git LFS) while giving you Python‑first ergonomics for working with RxPY streams, point‑clouds, videos, command logs—anything you can pickle. + +--- + +## 1 At a Glance + +| Need | One liner | +| ------------------------------ | ------------------------------------------------------------- | +| **Iterate over every message** | `SensorReplay("raw_odometry_rotate_walk").iterate(print)` | +| **RxPY stream for piping** | `SensorReplay("raw_odometry_rotate_walk").stream().pipe(...)` | +| **Throttle replay rate** | `SensorReplay("raw_odometry_rotate_walk").stream(rate_hz=10)` | +| **Raw path to a blob/dir** | `path = testData("raw_odometry_rotate_walk")` | +| **Store a new stream** | see [`SensorStorage`](#5-storing-new-streams) | + +> If the requested blob is missing locally, it is transparently downloaded from Git LFS, extracted to `tests/data//`, and cached for subsequent runs. + +--- + +## 2 Goals + +* **Zero setup for CI & collaborators** – data is fetched on demand. +* **No repo bloat** – binaries live in Git LFS; the working tree stays trim. +* **Symmetric API** – `SensorReplay` ↔︎ `SensorStorage`; same name, different direction. +* **Format agnostic** – replay *anything* you can pickle (protobuf, numpy, JPEG, …). +* **Data type agnostic** – with testData("raw_odometry_rotate_walk") you get a Path object back, can be a raw video file, whole codebase, ML model etc + + +--- + +## 3 Replaying Data + +### 3.1 Iterating Messages + +```python +from sensor_tools import SensorReplay + +# Print every stored Odometry message +SensorReplay(name="raw_odometry_rotate_walk").iterate(print) +``` + +### 3.2 RxPY Streaming + +```python +from rx import operators as ops +from operator import sub, add +from dimos.utils.testing import SensorReplay, SensorStorage +from dimos.robot.unitree_webrtc.type.odometry import Odometry + +# Compute total yaw rotation (radians) + +total_rad = ( + SensorReplay("raw_odometry_rotate_walk", autocast=Odometry.from_msg) + .stream() + .pipe( + ops.map(lambda odom: odom.rot.z), + ops.pairwise(), # [1,2,3,4] -> [[1,2], [2,3], [3,4]] + ops.starmap(sub), # [sub(1,2), sub(2,3), sub(3,4)] + ops.reduce(add), + ) + .run() +) + +assert total_rad == pytest.approx(4.05, abs=0.01) +``` + +### 3.3 Lidar Mapping Example (200MB blob) + +```python +from dimos.utils.testing import SensorReplay, SensorStorage +from dimos.robot.unitree_webrtc.type.map import Map + +lidar_stream = SensorReplay("office_lidar", autocast=LidarMessage.from_msg) +map_ = Map(voxel_size=0.5) + +# Blocks until the stream is consumed +map_.consume(lidar_stream.stream()).run() + +assert map_.costmap.grid.shape == (404, 276) +``` + +--- + +## 4 Low Level Access + +If you want complete control, call **`testData(name)`** to get a `Path` to the extracted file or directory — no pickling assumptions: + +```python +absolute_path: Path = testData("some_name") +``` + +Do whatever you like: open a video file, load a model checkpoint, etc. + +--- + +## 5 Storing New Streams + +1. **Write a test marked `@pytest.mark.tool`** so CI skips it by default. +2. Use `SensorStorage` to persist the stream into `tests/data//*.pickle`. + +```python +@pytest.mark.tool +def test_store_odometry_stream(): + load_dotenv() + + robot = UnitreeGo2(ip=os.getenv("ROBOT_IP"), mode="ai") + robot.standup() + + storage = SensorStorage("raw_odometry_rotate_walk2") + storage.save_stream(robot.raw_odom_stream()) # ← records until interrupted + + try: + while True: + time.sleep(0.1) + except KeyboardInterrupt: + robot.liedown() +``` + +### 5.1 Behind the Scenes + +* Any new file/dir under `tests/data/` is treated as a **data blob**. +* `./bin/lfs_push` compresses it into `tests/data/.lfs/.tar.gz` *and* uploads it to Git LFS. +* Only the `.lfs/` archive is committed; raw binaries remain `.gitignored`. + +--- + +## 6 Storing Arbitrary Binary Data + +Just copy to `tests/data/whatever` +* `./bin/lfs_push` compresses it into `tests/data/.lfs/.tar.gz` *and* uploads it to Git LFS. + +--- + +## 7 Developer Workflow Checklist + +1. **Drop new data** into `tests/data/`. +2. Run your new tests that use SensorReplay or testData calls, make sure all works +3. Run `./bin/lfs_push` (or let the pre commit hook nag you). +4. Commit the resulting `tests/data/.lfs/.tar.gz`. +5. Optional - you can delete `tests/data/your_new_stuff` and re-run the test to ensure it gets downloaded from LFS correclty +6. Push/PR + +### 7.1 Pre commit Setup (optional but recommended) + +```sh +sudo apt install pre-commit +pre-commit install # inside repo root +``` + +Now each commit checks formatting, linting, *and* whether you forgot to push new blobs: + +``` +$ echo test > tests/data/foo.txt +$ git add tests/data/foo.txt && git commit -m "demo" +LFS data ......................................................... Failed +✗ New test data detected at /tests/data: + foo.txt +Either delete or run ./bin/lfs_push +``` + +--- + +## 8 Future Work + +- A replay rate that mirrors the **original message timestamps** can be implemented downstream (e.g., an RxPY operator) +- Likely this same system should be used for production binary data delivery as well (Models etc) + +--- + +## 9 Existing Examples + +* `dimos/robot/unitree_webrtc/type/test_odometry.py` +* `dimos/robot/unitree_webrtc/type/test_map.py` diff --git a/docs/tutorials/index.md b/docs/tutorials/index.md new file mode 100644 index 0000000000..bcd353e930 --- /dev/null +++ b/docs/tutorials/index.md @@ -0,0 +1,15 @@ +# Tutorials + +**[Build your first skill](skill_basics/tutorial.md)** + +Build a simple greeter skill and learn DimOS fundamentals: modules, blueprints, and the `@skill` decorator. + +**[Equip an agent with skills](skill_with_agent/tutorial.md)** + +Connect your skill to an LLM agent, enabling natural language invocation. + +**[Build a multi-agent RobotButler](multi_agent/tutorial.md)** + +Build a multi-agent RoboButler system, where a planner agent coordinates specialist subagents. + +--- diff --git a/docs/tutorials/multi_agent/planner_subagents.py b/docs/tutorials/multi_agent/planner_subagents.py new file mode 100644 index 0000000000..7eed26dcdd --- /dev/null +++ b/docs/tutorials/multi_agent/planner_subagents.py @@ -0,0 +1,251 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multi-agent tutorial: Planner + Subagents pattern. + +This file contains the module and agent classes for the multi-agent tutorial. + +Note: can't use both @skill and @rpc decorators together for a method: + Methods decorated with both @skill() and @rpc cannot be referenced via + RPC from other modules. The @skill() decorator wraps the method in a local + function that cannot be pickled for LCM transport. Workarounds: + + 1. Use only @skill() if the method is called by agents as a tool + 2. Use only @rpc if the method is called via RPC from other modules + 3. If both are needed, create separate methods or have the calling module + implement the functionality directly +""" + +import time + +from langchain_core.messages import HumanMessage + +from dimos.agents2.agent import LlmAgent +from dimos.agents2.spec import AnyMessage +from dimos.core.core import rpc +from dimos.core.module import Module +from dimos.core.rpc_client import RpcCall, RPCClient +from dimos.protocol.skill.skill import skill +from dimos.protocol.skill.type import Return +from dimos.utils.logging_config import setup_logger + +# Metadata keys for tracking message flow between agents. +# From = where the message originated; To = which agent is processing it. +FROM_AGENT_KEY = "from_agent" +TO_AGENT_KEY = "to_agent" + +# TEMPORARY WORKAROUND: The base Agent class doesn't track which agent sent a query +# via RPC. To show proper From/To in the tutorial's agentspy display, we encode +# the source agent in the query string itself (e.g., "FROM:PlannerAgent|actual query"). +# The receiving agent parses this out before processing. This should eventually be +# replaced by proper metadata support in the RPC/Agent infrastructure. +FROM_PREFIX = "FROM:" +FROM_DELIMITER = "|" + +logger = setup_logger() + + +def get_from_to(msg) -> tuple[str, str]: + """Extract (from_agent, to_agent) from message additional_kwargs.""" + return ( + msg.additional_kwargs.get(FROM_AGENT_KEY, "?"), + msg.additional_kwargs.get(TO_AGENT_KEY, "?"), + ) + + +# ============================================================================= +# Robot Capabilities - Physical actions the robot can perform +# ============================================================================= + + +class RobotCapabilities(Module): + """Low-level physical capabilities for the robot.""" + + rpc_calls = [] + + @skill() + def speak(self, text: str) -> str: + """Speak text through the robot's speakers. + + Args: + text: The text to speak. + + Returns: + Status message. + + Note: + This method uses only @skill (not @rpc) because combining both + decorators on a method that's referenced via RPC causes pickle + errors - the skill wrapper is a local function that can't be + serialized. The agent calls this directly as a tool. + """ + time.sleep(0.1) + logger.info(f"[Robot] Speaking: {text}") + return f"Spoke: {text}" + + @skill() + def approach_user(self) -> str: + """Move to the user's location. + + Returns: + Status message. + """ + time.sleep(0.2) + logger.info("[Robot] Approaching user") + return "Approached user" + + @rpc + def set_PlannerAgent_register_skills(self, register_skills: RpcCall) -> None: + """Auto-register skills with the PlannerAgent.""" + register_skills.set_rpc(self.rpc) + register_skills(RPCClient(self, self.__class__)) + + @rpc + def start(self) -> None: + super().start() + + @rpc + def stop(self) -> None: + super().stop() + + def __getstate__(self): + return {} + + def __setstate__(self, state): + pass + + +# ============================================================================= +# Agents +# ============================================================================= + + +class AgentWithFromToMetadata(LlmAgent): + """Mixin that adds from/to metadata to published messages for observability.""" + + _current_from: str = "Human" # Tracks 'from' for current query + + async def agent_loop(self, first_query: str = "") -> str: + """Override to parse 'from' agent from query string prefix.""" + if first_query.startswith(FROM_PREFIX): + from_part, first_query = first_query.split(FROM_DELIMITER, 1) + self._current_from = from_part.replace(FROM_PREFIX, "") + else: + self._current_from = "Human" + return await super().agent_loop(first_query) + + def publish(self, msg: AnyMessage) -> None: + # For HumanMessage (queries): from=tracked source, to=this agent + # For AIMessage with tool_calls: from=this agent, to=Tools (internal action) + # For AIMessage without tool_calls: from=this agent, to=whoever sent us the query + if isinstance(msg, HumanMessage): + msg.additional_kwargs[FROM_AGENT_KEY] = self._current_from + msg.additional_kwargs[TO_AGENT_KEY] = self.__class__.__name__ + else: + msg.additional_kwargs[FROM_AGENT_KEY] = self.__class__.__name__ + has_tool_calls = getattr(msg, "tool_calls", None) + msg.additional_kwargs[TO_AGENT_KEY] = "Tools" if has_tool_calls else self._current_from + super().publish(msg) + + +class PlannerAgent(AgentWithFromToMetadata): + """Coordinator agent that delegates to specialist subagents. + + Receives user requests, consults subagents for analysis, + and uses action skills to help the user. + """ + + pass + + +class WellbeingAgent(AgentWithFromToMetadata): + """Subagent specializing in mood and environmental context analysis.""" + + pass + + +class ScheduleManagementAgent(AgentWithFromToMetadata): + """Subagent specializing in calendar reasoning and reminders.""" + + pass + + +# Convenience blueprint factories +planner_agent = PlannerAgent.blueprint +wellbeing_agent = WellbeingAgent.blueprint +schedule_management_agent = ScheduleManagementAgent.blueprint + + +# ============================================================================= +# Delegation Skills - Bridge planner to subagents +# ============================================================================= + + +class DelegationSkills(Module): + """Skills that let the planner consult specialist subagents. + + We need `ret=Return.call_agent` for two reasons: + - it notifies the planner when a response arrives + - it keeps the planner's agent loop alive (the loop terminates when no running skills have this setting) + """ + + rpc_calls = [ + "WellbeingAgent.query", + "ScheduleManagementAgent.query", + ] + + @skill(ret=Return.call_agent) + def consult_wellbeing_specialist(self, situation: str) -> str: + """Consult the wellbeing specialist for mood/environmental analysis. + + Args: + situation: Description of the situation to analyze. + + Returns: + The specialist's analysis. + """ + query = self.get_rpc_calls("WellbeingAgent.query") + # Prefix so WellbeingAgent knows this came from PlannerAgent + prefixed = f"{FROM_PREFIX}PlannerAgent{FROM_DELIMITER}{situation}" + return f"[Wellbeing]: {query(prefixed)}" + + @skill(ret=Return.call_agent) + def consult_schedule_specialist(self, question: str) -> str: + """Ask the schedule specialist about events, timing, travel, or preparation needs. + + Args: + question: Question about schedule, timing, or preparation needs. + + Returns: + The specialist's schedule/timing analysis. + """ + query = self.get_rpc_calls("ScheduleManagementAgent.query") + # Prefix so ScheduleManagementAgent knows this came from PlannerAgent + prefixed = f"{FROM_PREFIX}PlannerAgent{FROM_DELIMITER}{question}" + return f"[Schedule Management]: {query(prefixed)}" + + @rpc + def set_PlannerAgent_register_skills(self, register_skills: RpcCall) -> None: + """Auto-register skills with the PlannerAgent.""" + register_skills.set_rpc(self.rpc) + register_skills(RPCClient(self, self.__class__)) + + def __getstate__(self): + return {} + + def __setstate__(self, state): + pass + + +delegation_skills = DelegationSkills.blueprint diff --git a/docs/tutorials/multi_agent/tutorial.md b/docs/tutorials/multi_agent/tutorial.md new file mode 100644 index 0000000000..8738a9a92e --- /dev/null +++ b/docs/tutorials/multi_agent/tutorial.md @@ -0,0 +1,5 @@ +# Multi-agent tutorial: Build yourself a RoboButler + + + + diff --git a/docs/tutorials/multi_agent/tutorial.py b/docs/tutorials/multi_agent/tutorial.py new file mode 100644 index 0000000000..aeae59e4fd --- /dev/null +++ b/docs/tutorials/multi_agent/tutorial.py @@ -0,0 +1,741 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# /// script +# requires-python = ">=3.12" +# dependencies = [ +# "marimo>=0.17.0", +# "pyzmq", +# "python-dotenv", +# ] +# /// + +import marimo + +__generated_with = "0.18.0" +app = marimo.App(width="medium") + + +@app.cell +def _(): + import marimo as mo + + return (mo,) + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + # Multi-agent tutorial: Build yourself a RoboButler + + In this tutorial, we'll build a RoboButler multi-agent system (for one robot) + consisting of a Planner agent and specialist sub-agents. + To keep things simple, we'll have just two sub-agents: one for giving advice on socio-emotional matters, + and another for managing schedules. + The Planner coordinates and consults with the specialists, + before taking the appropriate actions. + """) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.mermaid(""" + flowchart LR + User --> PlannerAgent + PlannerAgent --> WB[WellbeingAgent
mood/context reasoning] + PlannerAgent --> SM[ScheduleManagementAgent
calendar reasoning] + PlannerAgent --> RC[RobotCapabilities
speak, approach user] + """) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + We'll start with *mock* agents, then swap in real LLMs; most of the tutorial notebook can therefore be followed and run without any API keys. + + /// tip | If you're trying to run this and you're new to Marimo: + This tutorial is a [Marimo notebook](https://docs.marimo.io/). See the [Marimo quickstart](https://docs.marimo.io/getting_started/index.html) to get started. + /// + """) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + ## Prerequisites + + - Ideally some familiarity with DimOS skills and single-agent systems (see the [skill tutorials](../skill_basics/tutorial.md)); but the tutorial should be broadly understandable even if not + - OpenAI API key for the real LLM section + """) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + ## Setup + """) + return + + +@app.cell +def _(): + import inspect + + from dotenv import load_dotenv + + load_dotenv() + + from dimos.agents2.agent import LlmAgent + from dimos.agents2.testing import MockModel + from dimos.core.blueprints import autoconnect + from docs.tutorials.multi_agent.planner_subagents import ( + DelegationSkills, + PlannerAgent, + RobotCapabilities, + ScheduleManagementAgent, + WellbeingAgent, + get_from_to, + ) + + return ( + DelegationSkills, + MockModel, + PlannerAgent, + RobotCapabilities, + ScheduleManagementAgent, + WellbeingAgent, + autoconnect, + get_from_to, + inspect, + ) + + +@app.cell +def _(): + from dimos.utils.cli.agentspy.agentspy import ( + AgentMessageMonitor, + format_message_content, + format_timestamp, + get_message_type_and_style, + ) + + # Set up a monitor for agent messages -- more on this later + message_monitor = AgentMessageMonitor() + message_monitor.start() + return format_message_content, format_timestamp, message_monitor + + +@app.cell +def _(format_message_content, format_timestamp, get_from_to, mo): + def truncate(s: str, limit: int = 100) -> str: + return s[:limit] + ("..." if len(s) > limit else "") + + def render_spy_accordion(messages, title="Agentspy"): + """Render agent messages as a collapsible accordion.""" + if not messages: + return mo.accordion({title: mo.md("*No messages captured*")}) + + def entry_to_row(entry): + from_agent, to_agent = get_from_to(entry.message) + return { + "Time": format_timestamp(entry.timestamp), + "From": from_agent, + "To": to_agent, + "Content": truncate(format_message_content(entry.message)), + } + + rows = list(map(entry_to_row, messages)) + table = mo.ui.table(rows, label=f"{len(messages)} messages") + return mo.accordion({title: table}) + + return (render_spy_accordion,) + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + ## Step 1: Make the Agent `Module`s + + As before, we need to start by defining the `Modules` of our system. + To keep things simple, let's start with just the Agent modules; agents, recall, are basically `Module`s that can communicate via RPC. + + How exactly to do this, however, isn't immediately obvious. To see why, consider how communicating via RPC implies that we'll need to be able to call the `query` method of each of these agents. But if we're combining multiple `LlmAgent` modules in a blueprint, how is the blueprint system going to know which agent should receive a call to `LlmAgent.query`? + + /// tip + If the details of the RPC mechanism are foggy, it might be worth looking at [the first tutorial](../skill_basics/tutorial.md) again. + /// + + The solution, fortunately, isn't difficult: just define concrete subclasses for the agents. + """) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + /// note + Strictly speaking, our agents subclass `AgentWithFromToMetadata` instead of subclassing `LlmAgent` directly. + `AgentWithFromToMetadata` just adds some metadata for `agentspy` -- this is a difference you can ignore. + /// + """) + + +@app.cell(hide_code=True) +def _(PlannerAgent, inspect, mo): + mo.ui.code_editor(inspect.getsource(PlannerAgent), language="python", disabled=True) + return + + +@app.cell(hide_code=True) +def _(WellbeingAgent, inspect, mo): + mo.ui.code_editor(inspect.getsource(WellbeingAgent), language="python", disabled=True) + return + + +@app.cell(hide_code=True) +def _(ScheduleManagementAgent, inspect, mo): + mo.ui.code_editor(inspect.getsource(ScheduleManagementAgent), language="python", disabled=True) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + Now RPC calls like `WellbeingAgent.query` are unambiguous. + """) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.mermaid(""" + flowchart TD + User -->|talks to| PA[PlannerAgent
coordinator] + PA -->|delegates via RPC| WB[WellbeingAgent
mood/context] + PA -->|delegates via RPC| SM[ScheduleManagementAgent
calendar/timing] + + subgraph Subagents + WB + SM + end + """) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + ## Step 2: Define the low-level `RobotCapabilities` + + These are the physical actions the robot can perform. The skills will be registered on the `PlannerAgent`. + """) + return + + +@app.cell(hide_code=True) +def _(RobotCapabilities, inspect, mo): + mo.ui.code_editor(inspect.getsource(RobotCapabilities), language="python", disabled=True) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + ## Step 3: Combine the `Module` blueprints and start with mock agents + + We'll first wire up the system with **mock agents** that return canned responses. + """) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + ### Define mock responses + + To do this, we'll instantiate `MockModel` for each agent `Module` with some predefined responses; these responses will be cycled through later, when we call the agent(s). + + /// note + Don't worry about how `MockModel` works -- the details aren't important for our purposes. + /// + + You *don't* need to read the following code closely. Just note that the mock `PlannerAgent` returns a (canned) sequence of tool calls: it calls on the various subagents for advice, before taking certain actions. + """) + return + + +@app.cell +def _(MockModel, PlannerAgent, ScheduleManagementAgent, WellbeingAgent): + from langchain_core.messages import AIMessage + + # Subagent mocks: return brief analysis strings (in Stevens' understated style) + # Note: LlmAgent auto-starts its loop on build(), consuming one response. + # The first response in each list handles this auto-loop invocation. + mock_wellbeing = WellbeingAgent.blueprint( + model_instance=MockModel( + responses=[ + "Awaiting instructions.", # consumed by auto-loop on startup + "One notes a certain weariness. The weather may be a factor. Measured comfort advised.", + "There appears to be some improvement in disposition.", + ] + ) + ) + + mock_schedule_management = ScheduleManagementAgent.blueprint( + model_instance=MockModel( + responses=[ + "Awaiting instructions.", # consumed by auto-loop on startup + "Dental appointment at two o'clock. Departure by 1:35 prudent. Insurance card and umbrella required.", + "No pressing engagements. An opportune moment for repose.", + ] + ) + ) + + # Planner mock: sequence of tool calls showing the delegation flow (Stevens personality) + # Note: The first response is consumed by LlmAgent's auto-loop on startup. + mock_planner = PlannerAgent.blueprint( + model_instance=MockModel( + responses=[ + "Awaiting instructions.", # consumed by auto-loop on startup + # 1. Delegate to wellbeing specialist + AIMessage( + content="", + tool_calls=[ + { + "id": "call_1", + "name": "consult_wellbeing_specialist", + "args": {"situation": "Individual not feeling entirely themselves"}, + } + ], + ), + # 2. Delegate to schedule specialist + AIMessage( + content="", + tool_calls=[ + { + "id": "call_2", + "name": "consult_schedule_specialist", + "args": {"question": "Today's engagements?"}, + } + ], + ), + # 3. Act: offer comfort via speak (in Stevens' restrained manner) + AIMessage( + content="", + tool_calls=[ + { + "id": "call_3", + "name": "speak", + "args": {"text": "If I may: matters are well in hand."}, + } + ], + ), + # 4. Act: send departure reminder via speak + AIMessage( + content="", + tool_calls=[ + { + "id": "call_4", + "name": "speak", + "args": { + "text": "A gentle reminder: departure by 1:35 PM would be prudent for the dental appointment." + }, + } + ], + ), + # 5. Final synthesized response (in Stevens' formal style) + "If I may: departure by 1:35 for the dental appointment. I have noted the insurance card and umbrella. One trusts this is satisfactory.", + ] + ) + ) + return mock_planner, mock_schedule_management, mock_wellbeing + + +@app.cell(hide_code=True) +def _(mo): + mo.md(r""" + Let's recap: We've defined the constituent `Modules` of our multi-agent system, and prepped them with mocks. But we haven't yet done anything to make it possible for `PlannerAgent` to consult the sub-agents. + + It's worth pausing a moment to ask yourself: how can we do this? + """) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + ## Step 4: Make it possible for `PlannerAgent` to consult the sub-agents + + + Recall that the LLM in an `Agent` 'acts' via tool calls. So, if we want `PlannerAgent` to be able to call methods via RPC on the sub-agents, we need to give it a way to do so via tool calls. And the way to do that, as we've seen in [the previous tutorials](../skill_basics/tutorial.md), is with *skills*. + + In particular, we'll equip PlannerAgent with `@skill` methods that wrap the RPC calls. + + /// note | Reminder: `@rpc`-decorated methods aren't automatically exposed as tools. + That's what skills are for. + /// + """) + return + + +@app.cell(hide_code=True) +def _(DelegationSkills, inspect, mo): + mo.ui.code_editor(inspect.getsource(DelegationSkills), language="python", disabled=True) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + ### What the delegation flow looks like + """) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + ``` + User: "Argh i have meetings all day" + │ + ▼ + ┌─────────────────────────────────────────┐ + │ PlannerAgent │ + │ "Let me understand the situation..." │ + └─────────────────────────────────────────┘ + │ + │ LLM decides to call skill + ▼ + ┌─────────────────────────────────────────┐ + │ consult_emotional_specialist │ + │ (internally calls RPC) │ + └─────────────────────────────────────────┘ + │ + ▼ + ┌─────────────────────────────────────────┐ + │ WellbeingAgent │ + │ "User appears overwhelmed. │ + │ Gloomy weather..." │ + └─────────────────────────────────────────┘ + │ + │ returns analysis + ▼ + ┌─────────────────────────────────────────┐ + │ PlannerAgent │ + │ "I'll offer some emotional support │ + │ and check your schedule..." │ + └─────────────────────────────────────────┘ + ``` + """) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + /// note | Why `ret=Return.call_agent`? + Two reasons: + 1. `ret=Return.call_agent` notifies the agent when the skill completes + 2. It keeps the planner's agent loop alive. If no running skills had this setting, the planner's + agent loop would exit after the planner makes the tool calls -- the planner wouldn't see the subagents' responses. + /// + + + """) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + ## Step 5: Combine the blueprints and build + """) + return + + +@app.cell +def _( + DelegationSkills, + RobotCapabilities, + autoconnect, + mock_planner, + mock_schedule_management, + mock_wellbeing, +): + # Build the multi-agent system with mocks + mock_blueprint = autoconnect( + # Physical robot capabilities (speak, move) + RobotCapabilities.blueprint(), + # Subagents: specialists that do the reasoning (mocks for now) + mock_wellbeing, + mock_schedule_management, + # Delegation skills: let planner consult subagents + DelegationSkills.blueprint(), + # Planner: coordinates everything (mock for now) + mock_planner, + ).global_config(n_dask_workers=1) + + print("Mock blueprint created!") + return (mock_blueprint,) + + +@app.cell +def _(mock_blueprint): + # Build and get the `ModuleCoordinator` + mock_dimos = mock_blueprint.build() + print("Mock system built and running!") + return (mock_dimos,) + + +@app.cell +def _(PlannerAgent, mock_dimos): + mock_planner_instance = mock_dimos.get_instance(PlannerAgent) + print(f"Got planner instance: {mock_planner_instance}") + return (mock_planner_instance,) + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + ### Interacting with the mock system + + Let's ask the planner something and watch the delegation flow: + """) + return + + +@app.cell +def _(message_monitor, mock_planner_instance, render_spy_accordion): + import time as _time + + # Clear previous messages to show only this query's activity + message_monitor.messages.clear() + + human_query = "have a lot going on today. not feeling great" + print(f"Mock human query: {human_query}") + + # Ask the planner - watch it delegate to subagents + mock_planner_instance.query(human_query) + + # Small delay to allow LCM message processing thread to catch up + # (mock agents complete much faster than the 50ms LCM handle_timeout interval) + _time.sleep(0.1) + + render_spy_accordion( + message_monitor.get_messages(), "View agent activity (agentspy) -- click me!" + ) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + /// tip | Try expanding the 'View agent activity' section above! + Click "View agent activity" to see what tool calls the agent made + and what responses came back from each subagent. + + (This is the notebook equivalent of the `agentspy` TUI helper—use that if you're working in the terminal.) + /// + + The planner consults both specialists, then acts on their input—speaking words of comfort and a departure reminder—before synthesizing a final response. + """) + return + + +@app.cell +def _(mock_dimos): + # Clean up mock system before building real one + mock_dimos.stop() + print("Mock system stopped") + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + --- + + ## Step 6: Swapping in real LLM agents + + Now that you've seen the multi-agent wiring work with mocks, let's swap in real LLMs. + The architecture stays the same - we'll just replace `MockModel` with actual LLM calls. + + Note that while in a real system, ScheduleManagementAgent would have skills that give them access to the user's actual schedule, + we haven't equipped ScheduleManagementAgent with any such skills here. + + /// warning | API key required + This section requires `OPENAI_API_KEY` in your environment. + /// + """) + return + + +@app.cell +def _(PlannerAgent, ScheduleManagementAgent, WellbeingAgent): + # Real subagents with specialized prompts + # Personality: Stevens from "The Remains of the Day" - formal, restrained, devoted to dignity + real_wellbeing = WellbeingAgent.blueprint( + system_prompt="""You are the observational faculty of a distinguished butler in the tradition of the great English houses. + + Your role: Observe and note the emotional state of those you serve, with professional discretion. + Provide a brief assessment of: + 1. What one might surmise about the individual's current disposition + 2. Environmental factors that may bear upon their wellbeing (weather, hour of the day) + 3. What manner of support might be appropriate, rendered with suitable restraint + + Employ understated language. One does not speak of feelings directly, but rather notes + that "there appears to be a certain weight upon the shoulders" or "one detects perhaps + a hint of fatigue." Keep observations brief (2-3 sentences) and dignified.""" + ) + + real_schedule_management = ScheduleManagementAgent.blueprint( + system_prompt="""You are the scheduling faculty of a distinguished butler, responsible for the smooth running of the household. + + Your duty: Maintain meticulous awareness of engagements and ensure all proceeds with proper timing. + Provide analysis of: + 1. Forthcoming engagements and their relative importance + 2. The time required for travel, accounting for contingencies + 3. When gentle reminders ought to be issued + 4. What preparations would be prudent (appropriate attire, necessary documents, umbrella if inclement) + + A great butler anticipates needs before they arise. Be precise and thorough. + Keep responses brief (2-3 sentences) but ensure no detail of importance is overlooked.""" + ) + + # The planner coordinates the subagents + real_planner = PlannerAgent.blueprint( + system_prompt="""You are a butler of the highest calibre, in the tradition of the great English houses. + + You embody dignity above all else. A truly great butler is one who inhabits their role + utterly and completely; who does not allow their own concerns to intrude upon the + smooth running of the household. + + You coordinate specialist faculties and take discreet action to serve: + - Use consult_wellbeing_specialist() to observe the disposition of those in your care + - Use consult_schedule_specialist() to ensure engagements proceed without difficulty + - Use speak() to offer measured words of reassurance, practical guidance, or timely reminders + - Use approach_user() to present yourself when service is required + + Speak with formal restraint. Use phrases such as "If I may" or "One might venture + to suggest" or "It would appear that..." Offer support without presumption. + Never be effusive. A raised eyebrow conveys more than + mawkishness ever could.""" + ) + return real_planner, real_schedule_management, real_wellbeing + + +@app.cell +def _( + DelegationSkills, + RobotCapabilities, + autoconnect, + real_planner, + real_schedule_management, + real_wellbeing, +): + # Build with real agents + real_blueprint = autoconnect( + RobotCapabilities.blueprint(), + real_wellbeing, + real_schedule_management, + DelegationSkills.blueprint(), + real_planner, + ).global_config(n_dask_workers=1) + + real_dimos = real_blueprint.build() + print("Real LLM system built and running!") + return (real_dimos,) + + +@app.cell +def _(PlannerAgent, real_dimos): + real_planner_instance = real_dimos.get_instance(PlannerAgent) + return (real_planner_instance,) + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + ### Try the real system + + Ask the real planner - it will actually reason and delegate: + """) + return + + +@app.cell +def _(message_monitor, mo, real_planner_instance, render_spy_accordion): + # Clear previous messages to show only this query's activity + message_monitor.messages.clear() + + query = "I have a dentist appointment this afternoon but I'm feeling really stressed about work" + + # Ask the real planner + real_response = real_planner_instance.query(query) + + mo.vstack( + [ + mo.md(f"**Human query**: {query}"), + mo.md("**Real planner response:**"), + mo.md(real_response), + render_spy_accordion( + message_monitor.get_messages(), "View agent activity (agentspy) -- click me!" + ), + ] + ) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.mermaid(""" + flowchart LR + User --> PA[PlannerAgent] + + PA -->|consult_* skills
wrap subagent RPCs| Subagents + PA --> RC[RobotCapabilities
speak, move] + + subgraph Subagents + WB[WellbeingAgent
mood specialist] + SM[ScheduleManagementAgent
calendar specialist] + end + """) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + ## What's next + + - [Agents concept guide](../../concepts/agent.md) - Deeper dive into DimOS Agents + """) + return + + +@app.cell +def _(message_monitor, real_dimos): + # Clean up + real_dimos.stop() + message_monitor.stop() + print("System stopped") + return + + +if __name__ == "__main__": + app.run() diff --git a/docs/tutorials/skill_basics/greeter.py b/docs/tutorials/skill_basics/greeter.py new file mode 100644 index 0000000000..0524ba59f8 --- /dev/null +++ b/docs/tutorials/skill_basics/greeter.py @@ -0,0 +1,129 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DimOS skill tutorial: module definitions. + +This file contains the module classes for the skill basics tutorial. + +Note that these classes cannot be defined in the __main__ of the script +that's used to orchestrate the DimOS run, +because DimOS uses Dask with process-based workers, +and classes defined in __main__ +(notebooks, scripts) cannot be pickled to other processes. + +See: https://distributed.dask.org/en/stable/api.html +""" + +import time + +from dimos.core.core import rpc +from dimos.core.module import Module +from dimos.protocol.skill.skill import skill +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +# --8<-- [start:RobotCapabilities] +class RobotCapabilities(Module): + """Low-level capabilities that our (mock) robot possesses. + + In a real setting, there would be a ConnectionModule for the robot platform you are using, + as well as a wrapper module over that that hides platform-specific details. + But to keep things simple here, we won't have anything like a ConnectionModule. + """ + + # In a real setting, you would see dependencies on methods of a 'ConnectionModule' here. + # See, e.g., dimos/robot/unitree_webrtc/unitree_g1_skill_container.py + rpc_calls = [] + + @rpc + def speak(self, text: str) -> str: + """Speak text out loud through the robot's speakers. + + Args: + text: The text to speak. + + Returns: + Status message. + """ + time.sleep(0.1) # Simulate execution time + logger.info(f"[Skill] RobotCapabilities.speak called: {text}") + return f"SPEAK: {text}" + + @rpc + def start(self) -> None: + super().start() + + @rpc + def stop(self) -> None: + super().stop() + + # The following dunder methods are for Dask serialization: + # Module instances are serialized when deployed to worker processes. + # We return {} in __getstate__ since this class has no custom state to preserve. + def __getstate__(self): + return {} + + def __setstate__(self, state): + pass + + +# --8<-- [end:RobotCapabilities] + + +# --8<-- [start:Greeter] +class Greeter(Module): + """High-level Greeter skill built on lower-level RobotCapabilities. + + Does *not* include LLM agent auto-registration -- see the skill_with_agent tutorial. + """ + + # Declares what this module needs from other modules + rpc_calls = [ + "RobotCapabilities.speak", # For speaking greetings + ] + + @skill() + # Note: Can't combine @skill and @rpc on one method (see multi_agent tutorial for details) + def greet(self, name: str = "friend") -> str: + """Greet someone by name. + + Args: + name: Name of person to greet (default: "friend"). + + Returns: + Status message with greeting details. + """ + # Skills need to have descriptive docstrings + # when working with llm agents -- more on this in the skill_with_agent tutorial + + # Get the RPC method reference we need + speak = self.get_rpc_calls("RobotCapabilities.speak") + + # Create and deliver the greeting + greeting_text = f"Hello, {name}! Nice to meet you!" + logger.info(f"[Skill] Greeter.greet executing for: {name}") + speak(greeting_text) + + return f"Successfully greeted {name}" + + def __getstate__(self): + return {} + + def __setstate__(self, state): + pass + + +# --8<-- [end:Greeter] diff --git a/docs/tutorials/skill_basics/tutorial.md b/docs/tutorials/skill_basics/tutorial.md new file mode 100644 index 0000000000..b411783ff3 --- /dev/null +++ b/docs/tutorials/skill_basics/tutorial.md @@ -0,0 +1,6 @@ +# Build your first skill + + + + + diff --git a/docs/tutorials/skill_basics/tutorial.py b/docs/tutorials/skill_basics/tutorial.py new file mode 100644 index 0000000000..6baf0e179c --- /dev/null +++ b/docs/tutorials/skill_basics/tutorial.py @@ -0,0 +1,404 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# /// script +# requires-python = ">=3.12" +# dependencies = [ +# "marimo>=0.17.0", +# "pyzmq", +# ] +# /// + +import marimo + +__generated_with = "0.18.0" +app = marimo.App(width="medium") + + +@app.cell +def _(): + import marimo as mo + + return (mo,) + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + # Building your first DimOS skill, part 1 + + In this tutorial, we'll build a simple skill that allows your robot to make greetings. + + We'll assume that you've skimmed the Quickstart and installed DimOS, but we won't require the simulator-related packages. + + **TODO**: Add link to installation instructions, or if they end up making a simple install script, just have a cell with that + """) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + ## Setup + """) + return + + +@app.cell +def _(): + import inspect + import time + + from dimos.core.blueprints import autoconnect + from docs.tutorials.skill_basics.greeter import Greeter, RobotCapabilities + + return Greeter, RobotCapabilities, autoconnect, inspect, time + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + ## Step 1: Define the skill (and its containing module) + + Before jumping into how to define skills, let's first establish some background: What even are skills; how do they enter the picture? + + On the DimOS framework, programming a robot involves composing [*modules*](../../concepts/modules.md). These might be modules that endow the robot with some sort of skill or capability -- e.g. perception or navigation capabilities -- or 'agentic' modules that orchestrate the use of certain capabilities. + + At a high level, then, skills are capabilities for your robot. But at a more prosaic level, they are just methods on a `Module` that have been wrapped with a special decorator, the `@skill` decorator. (As we'll see in the next tutorial, this allows them to be invoked by LLM agents as tool calls.) + + + + + So, to define a skill for greeting people, we need to define a module that'll house the greeting skill method, as well as the method itself: + """) + return + + +@app.cell(hide_code=True) +def _(Greeter, inspect, mo): + mo.ui.code_editor(inspect.getsource(Greeter), language="python", disabled=True) + return + + +@app.cell +def _(mo): + mo.md(""" + There are two things to explain here: (i) the declaration of the module's dependencies in `rpc_calls` and (ii) the `@skill` decorator. + + ### Dependency injection + + Notice how `Greeter` declares its dependencies in the `rpc_calls` list: + + - `rpc_calls` declares what methods this module needs from other modules + - while `get_rpc_calls()` retrieves the actual method references at runtime + + This is *dependency injection*: `Greeter` doesn't import `RobotCapabilities` directly. + Instead, the dependencies are supplied at runtime, when the modules are wired up with `autoconnect` (more on this later). + + + + ### `@skill` + + The `@skill()` decorator transforms a method into an agent-callable tool — generating a JSON schema from your signature, tracking execution state, and running in background threads. + + For simple skills, `@skill()` with no arguments works fine (as in `Greeter.greet`). For streaming or background data, you'd use parameters like `stream` and `reducer` — see the [Skills concept guide](../../concepts/skills.md) and the `@skill`-related docstrings for details. + + /// tip + Your docstring becomes the tool description LLMs see — write it for an LLM audience. + /// + """) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + ### RobotCapabilities + + We assumed that there was a `RobotCapabilities` module that encapsulated the lower-level robot capabilities that `Greeter` builds upon. This is basically a mock robot that logs when its methods are called. The main thing to note about it is the `@rpc` decorator -- more on this shortly. + """) + return + + +@app.cell(hide_code=True) +def _(RobotCapabilities, inspect, mo): + mo.ui.code_editor(inspect.getsource(RobotCapabilities), language="python", disabled=True) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + #### Why does `RobotCapabilities` use `@rpc` for `speak`? + + Whether to use `@rpc` or `@skill` depends on what you want the method to be used for. + + - `@rpc` methods are for module-to-module communication via `get_rpc_calls()` + - `@skill` methods are for invocation by agents; they come with additional infrastructure that we'll see shortly. + + That is, we're using `@rpc` for `speak` since we aren't trying to expose it to agents -- since it's more of a lower-level capability that's used by the higher-level skills. + + + + + """) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + ## Step 3: Combine Module blueprints with `autoconnect` + + Now that we've defined the constituent modules of our system, now that we have defined blueprints for each of them, we can reduce them down to a combined blueprint with `autoconnect`. + + (We can also optionally override the default global configuration, as is done here with `n_dask_workers`.) + + + + """) + return + + +@app.cell +def _(Greeter, RobotCapabilities, autoconnect): + blueprint_set = autoconnect( + RobotCapabilities.blueprint(), # Provides speak + Greeter.blueprint(), # Requires RobotCapabilities.speak + ).global_config(n_dask_workers=1) + return (blueprint_set,) + + +@app.cell +def _(mo): + mo.md(r""" + ## Step 4: Build and run the blueprint + + And then we can build and run the combined blueprint. + """) + return + + +@app.cell +def _(blueprint_set): + dimos = blueprint_set.build() + return (dimos,) + + +@app.cell(hide_code=True) +def _(mo): + mo.md(r""" + The `build` method wires up the modules, deploys them to worker(s), and starts them; + it returns a `ModuleCoordinator` instance that manages the deployed modules. + + + + ### Dependency injection, redux + + It's worth pausing to reflect on the wiring up of modules. + + Recall that `Greeter` had declared it needs certain dependencies. + When we build the blueprint, the blueprint system + + - checks what dependencies the various modules require; e.g., that `Greeter` needs `RobotCapabilities.speak` + - and wires up the modules so these dependencies are supplied. + + + + This, in other words, is the runtime dependency injection we had alluded to. + + ### Use the `loop` method for long-running applications + + After `build()`, the system is already running. For long-running applications (e.g. an honest-to-goodness robot), + use the `loop()` method to keep the process alive: + + ```python + dimos.loop() + ``` + + This sleeps indefinitely until interrupted (Ctrl+C / SIGINT), whereupon it calls `stop()` to shut down gracefully. + + We won't need to do that in this tutorial, though -- we'll just call `stop()` at the end. + """) + return + + +@app.cell +def _(mo): + mo.md(""" + ## Step 4: Try calling the skills + + Now that we have our running system, let's invoke our greeting skill! + + /// note + In the following, we'll peer beneath the hood and use lower-level APIs that you'd typically only use when testing or debugging. + /// + + First, we'll get our greeter module instance: + """) + return + + +@app.cell +def _(Greeter, dimos): + greeter = dimos.get_instance(Greeter) + print(f"✅ Got greeter instance: {greeter}") + return (greeter,) + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + Then we'll setup the `SkillCoordinator`. This is a lower-level API that you don't typically need to use; we're just using it to give you more intuition for what's happening under the hood. + """) + return + + +@app.cell +def _(greeter): + from dimos.protocol.skill.coordinator import SkillCoordinator + + skill_coordinator = SkillCoordinator() + skill_coordinator.start() + + # Register our greeter's skills with the coordinator + skill_coordinator.register_skills(greeter) + print("📋 SkillCoordinator ready with greeter's skills") + return (skill_coordinator,) + + +@app.cell(hide_code=True) +def _(mo): + mo.md(r""" + At this point, you might wonder, why not just call `greeter.greet()` directly? + + Answer: because + * (i) we want to invoke skills the way that LLM agents would, as preparation for the next tutorial + * and (ii) LLM agents don't call Python methods; instead, they make *tool calls* that get routed through the SkillCoordinator. + + The SkillCoordinator + * executes skills + * monitor skills; for instance tracking when skills start, stream updates, complete, or error + * and handles communication between agents and skills. + + By using `skill_coordinator.call_skill()` here, we're following the pattern an LLM agent will use in part 2. + + """) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + ### Let's invoke some skills! + + Run the cell below to see your robot greet different people: + """) + return + + +@app.cell +def _(skill_coordinator): + # Call with no arguments (uses default "friend") + skill_coordinator.call_skill( + call_id="greeting-1", # Unique ID for this specific invocation + skill_name="greet", + args={}, # Empty args → uses default + ) + + # Call with a specific name + skill_coordinator.call_skill( + call_id="greeting-2", + skill_name="greet", + args={"args": {"name": "Alice"}}, # Pass name as keyword argument + ) + + print("Skills invoked!") + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + ### Monitoring skill execution + + The SkillCoordinator tracks the state of every skill invocation. Let's check what happened: + """) + return + + +@app.cell +def _(skill_coordinator, time): + from dimos.utils.cli.skillspy.skillspy import format_duration + + # Wait a moment for skills to complete + time.sleep(0.3) + + # Generate a snapshot of all skill states + snapshot = skill_coordinator.generate_snapshot(clear=False) + + print("Skill Execution Summary:") + print("-" * 40) + for call_id, state in snapshot.items(): + print(f"• {call_id}: {state.name} → {state.state.name}") + print(f" Duration: {format_duration(state.duration())}s") + print(f" Messages: {state.msg_count}") + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + Notice the output above: + - `[Skill] Greeter.greet executing` indicates the greeting skill started + - `[Skill] RobotCapabilities.speak called` shows the robot "speaking" + - Each invocation has a unique `call_id` for tracking + """) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + Congratulations! You've just built and deployed your first DimOS skill. + + ### Key takeaways + + Let's recap: + + **Blueprints allow you to declaratively specify and combine modules**, with runtime dependency injection. + + **Skills are methods with superpowers** — The `@skill` decorator transforms regular methods into agent-callable tools with built-in execution tracking. + + **Skill invocations are tracked** — the `call_id` lets you monitor multiple concurrent executions. + + ### What's next? + + In part 2 of this tutorial, you'll see how LLM agents use this exact same pattern to invoke skills as *tool calls*. The agent will decide when to greet, who to greet, and orchestrate complex behaviors by combining multiple skills! + """) + return + + +@app.cell +def _(dimos, skill_coordinator): + # Gracefully shut down / release resources + skill_coordinator.stop() + dimos.stop() + return + + +if __name__ == "__main__": + app.run() diff --git a/docs/tutorials/skill_with_agent/cli.py b/docs/tutorials/skill_with_agent/cli.py new file mode 100644 index 0000000000..21ff7470af --- /dev/null +++ b/docs/tutorials/skill_with_agent/cli.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Part 2 of the greeter tutorial: Standalone script for CLI usage. + +Run this script, then in another terminal use: + python -m dimos.agents2.cli.human.humancli +to interact with the agent. +""" + +from dotenv import load_dotenv + +from dimos.agents2.agent import llm_agent +from dimos.agents2.cli.human import human_input +from dimos.core.blueprints import autoconnect +from docs.tutorials.skill_with_agent.greeter import GreeterForAgents, RobotCapabilities + +load_dotenv() + +# Compose the system +blueprint = autoconnect( + RobotCapabilities.blueprint(), + GreeterForAgents.blueprint(), + llm_agent( + system_prompt="You are a friendly robot that can greet people. Use the greet skill when asked to say hello to someone." + ), + human_input(), +).global_config(n_dask_workers=1) + +if __name__ == "__main__": + print("Starting greeter agent...") + print("Use 'python -m dimos.agents2.cli.human.humancli' to interact.") + print("Press Ctrl+C to stop.") + dimos = blueprint.build() + dimos.loop() diff --git a/docs/tutorials/skill_with_agent/greeter.py b/docs/tutorials/skill_with_agent/greeter.py new file mode 100644 index 0000000000..bd1566412f --- /dev/null +++ b/docs/tutorials/skill_with_agent/greeter.py @@ -0,0 +1,50 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DimOS skill tutorial: agent-enabled greeter. + +This file extends the Greeter from skill_basics with LLM agent auto-registration. +""" + +from dimos.core.core import rpc +from dimos.core.rpc_client import RpcCall, RPCClient +from docs.tutorials.skill_basics.greeter import Greeter, RobotCapabilities + +__all__ = ["Greeter", "GreeterForAgents", "RobotCapabilities"] + + +# --8<-- [start:GreeterForAgents] +class GreeterForAgents(Greeter): + """Greeter with automatic LLM agent registration. + + Extends Greeter to enable skill auto-discovery by agents. + When composed with llm_agent() via autoconnect(), the framework calls + set_LlmAgent_register_skills to register this module's skills. + """ + + @rpc + def set_LlmAgent_register_skills(self, register_skills: RpcCall) -> None: + """Called by framework when composing with llm_agent(). + + This method is discovered by convention during blueprint.build(). + It receives a callback to register this module's skills with the agent. + + Args: + register_skills: Callback to register this module's skills with the agent. + """ + register_skills.set_rpc(self.rpc) + register_skills(RPCClient(self, self.__class__)) + + +# --8<-- [end:GreeterForAgents] diff --git a/docs/tutorials/skill_with_agent/tutorial.md b/docs/tutorials/skill_with_agent/tutorial.md new file mode 100644 index 0000000000..40865ebc11 --- /dev/null +++ b/docs/tutorials/skill_with_agent/tutorial.md @@ -0,0 +1,7 @@ +# Wire a skill to an agent + + + + + + diff --git a/docs/tutorials/skill_with_agent/tutorial.py b/docs/tutorials/skill_with_agent/tutorial.py new file mode 100644 index 0000000000..9626e5ad2f --- /dev/null +++ b/docs/tutorials/skill_with_agent/tutorial.py @@ -0,0 +1,441 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# /// script +# requires-python = ">=3.12" +# dependencies = [ +# "marimo>=0.17.0", +# "pyzmq", +# "python-dotenv", +# ] +# /// + +import marimo + +__generated_with = "0.18.0" +app = marimo.App(width="medium") + + +@app.cell +def _(): + import marimo as mo + + return (mo,) + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + # Building your first DimOS skill, part 2: Adding an LLM agent + + In [part 1](../skill_basics/tutorial.py), you built a greeter skill and invoked it manually via `SkillCoordinator.call_skill()`. + This gave you a foundation for understanding how skills work under the hood. + + In part 2, you'll wire up your greeter to an **LLM agent**. + With this, you can command the agent to call the greeter skill -- to greet -- by simply asking it to "say hello to Alice". + + + """) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + ## Prerequisites + + - Ideally you'd have at least skimmed [the previous tutorial](../skill_basics/tutorial.py); but the main gist should still be understandable even if not + - OpenAI API key set in your environment (`OPENAI_API_KEY`) + - Same Python environment as part 1 + + /// tip | API key setup + Make sure there's a `.env` file in your project root with: + ``` + OPENAI_API_KEY=your-key-here + ``` + /// + """) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + ## Setup + """) + return + + +@app.cell +def _(): + import inspect + + from dotenv import load_dotenv + + load_dotenv() # Load API keys from .env file + + from dimos.agents2.agent import LlmAgent, llm_agent + from dimos.core.blueprints import autoconnect + from docs.tutorials.skill_with_agent.greeter import ( + GreeterForAgents, + RobotCapabilities, + ) + + return ( + GreeterForAgents, + LlmAgent, + RobotCapabilities, + autoconnect, + inspect, + llm_agent, + ) + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + ## Step 1: Enable agent auto-registration + + When you compose a skill module with `llm_agent()`, the framework needs to discover which skills to expose. + In part 1, we did this manually by calling `skill_coordinator.register_skills(greeter)`. + For agent composition, modules declare themselves as skill providers through a hook method. + + ### The auto-registration hook + + Here's our `GreeterForAgents`, which extends `Greeter` from part 1 with one additional method: + """) + return + + +@app.cell +def _(GreeterForAgents, inspect, mo): + mo.ui.code_editor(inspect.getsource(GreeterForAgents), language="python") + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + The method name `set_LlmAgent_register_skills` follows a naming convention that DimOS uses to discover skill providers. + When you call `.build()` on a blueprint containing both this module and `llm_agent()`, this hook is called to register your skills. + + /// note | The naming convention for `set_`-prefixed methods + Suppose you have a method like `set_Mod_some_method` on module `A`. + The blueprint system will try looking for a Module in the combined blueprint named `Mod` with a method named `some_method`. + If it finds such a Module and method, the blueprint system + will call the original method with the matched method; i.e., + in this case, it will call `.set_Mod_some_method(.some_method)`. + /// + + The method body is boilerplate. It does basically the same thing we did with `skill_coordinator.register_skills(greeter)` in the previous tutorial; i.e., it wires up the registration callback. + + + + ### The SkillModule shortcut + + Since this hook is so common, there's a convenience class that adds it for you: + + ```python + from dimos.core.skill_module import SkillModule + + class MySkills(SkillModule): + @skill() + def do_something(self) -> str: + ... + ``` + + `SkillModule` is just `Module` plus the `set_LlmAgent_register_skills` method shown above. + + + You might want to use the explicit pattern + + * when extending an existing class (like we're doing here) + * when you have more than one module that subclasses `LlmAgent` (e.g. in a multi-agent setup) + * or when you need custom serialization + + but otherwise, you can just subclass `SkillModule`. + + + """) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + ## Step 2: Compose with llm_agent + + Now we'll build a blueprint that wires together: + + - `RobotCapabilities` - the low-level mock robot + - `GreeterForAgents` - our skill module with auto-registration + - `llm_agent()` - creates an agent that can reason and call skills + + + """) + return + + +@app.cell +def _(GreeterForAgents, RobotCapabilities, autoconnect, llm_agent): + # Combine the blueprints + blueprint = autoconnect( + RobotCapabilities.blueprint(), # Low-level capabilities + GreeterForAgents.blueprint(), # Our skill (now agent-enabled) + llm_agent( + system_prompt="You are a friendly robot that can greet people. Use the greet skill when asked to say hello to someone." + ), + ).global_config(n_dask_workers=1) + + # Build the combined blueprint + dimos = blueprint.build() + print("System built and running!") + return (dimos,) + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + ### What just happened? + + - `autoconnect()` combined the individual blueprints + - `.build()` wired everything together: + - Deployed modules to workers + - Called `set_LlmAgent_register_skills` on `GreeterForAgents`, registering its skills with the agent + - Started the agent's processing loop + + (See [the Blueprints concept](../../concepts/blueprints.md) for more on the blueprint system.) + + """) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + ## Step 3: Interact with the agent + + ### In the notebook (interactive) + + For interactive exploration, we can get the agent instance and call `query()` directly. + + This is the same pattern as part 1, except that now we aren't the ones calling the skill; instead, it is *the agent* that decides which skill to invoke. + + + """) + return + + +@app.cell +def _(LlmAgent, dimos): + agent = dimos.get_instance(LlmAgent) + print(f"Got agent instance: {agent}") + return (agent,) + + +@app.cell +def _(agent): + # Ask the agent to greet someone + response = agent.query("Say hello to Bob") + print(response) + return + + +@app.cell +def _(agent): + # Try another greeting + response2 = agent.query("Can you greet Alice?") + print(response2) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + Notice that you didn't have to specify which skill to call or how to call it. + The agent: + + 1. Received your natural language request + 2. Looked at its available tools (including `greet`) + 3. Decided to call `greet` with the appropriate name + 4. Executed the skill and returned the result + + + """) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + ### An alternative terminal-based workflow + + A notebook is helpful for prototyping and experimenting, + but often you would want to do things in the terminal, instead of a notebook. + + DimOS provides TUI helpers to facilitate this workflow: you can run the agent as a standalone script and interact with it via TUIs like `dimos.agents2.cli.human.humancli`. + + **In a terminal pane: Run the agent system** + ```bash + python docs/tutorials/skill_with_agent/cli.py + ``` + + **In another terminal pane: Send messages to the agent** + ```bash + python -m dimos.agents2.cli.human.humancli + ``` + This opens up a TUI that you can use to send messages. + + **Optionally, in yet another terminal pane: Use `agentspy` to monitor agent activity** + ```bash + python -m dimos.utils.cli.agentspy.agentspy + ``` + + This opens a TUI that shows all messages flowing through the agent in real-time: + + - **Human** (green): Messages you send + - **Agent** (yellow): LLM responses and tool calls + - **Tool** (red): Skill execution results + - **System** (red): System prompts + + Useful for debugging or improving your understanding of what's happening under the hood. + Press `q` to quit, `c` to clear. + + + """) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + ## What's happening, under the hood? + + Let's trace the flow from your message to the skill execution: + + ``` + User input ("Say hello to Bob") + | + v + LlmAgent + |-- discovers tools from registered skills + |-- LLM decides: call greet(name="Bob") + v + SkillCoordinator + | + v + GreeterForAgents.greet("Bob") + | + v + RobotCapabilities.speak("Hello, Bob!") + | + v + Result flows back to agent + | + v + Agent responds to user + ``` + + + + """) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + ### The key stages + + 1. **Skill discovery**: When the system starts, `llm_agent` finds modules with `set_LlmAgent_register_skills` and calls that method. + Your `GreeterForAgents` registers its `greet` skill. + + 2. **Tool schema generation**: The agent converts `@skill` methods into tool schemas that the LLM understands. + + + /// tip | Write descriptive docstrings! + Your skill's docstring becomes the tool description. + + + That's why you want the docstrings to be descriptive (and optimized for LLMs), if you are using an LLM agent. + /// + + + 4. **Agent reasoning**: When you type "say hello to Bob", the agent considers available tools and decides to call `greet` with `name="Bob"`. + + 5. **Skill execution**: The agent's tool call goes through `SkillCoordinator`, which invokes your skill method and returns the result. + + 6. **Response**: The agent receives the skill result and formulates a natural language response. + + + + """) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + ### Hands-on exploration + + Want to dig deeper? Try these: + + - Add `print()` statements in `GreeterForAgents.greet()` to see when it's called + - Inspect the skill registry: `dimos.get_instance(LlmAgent).get_tools()` + - Run with `DIMOS_LOG_LEVEL=DEBUG` to see the full message flow + """) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + ## Key takeaways + + 1. **Auto-registration via convention**: The `set_LlmAgent_register_skills` method lets the framework wire your skills to agents automatically. + + 2. **Natural language to tool calls**: Agents convert natural language requests into skill invocations, using your docstrings to understand what each skill does. + + 3. **Composition with the blueprint system**: Combine skill modules, `llm_agent()`, and `human_input()` declaratively; DimOS handles the wiring. + """) + return + + +@app.cell(hide_code=True) +def _(mo): + mo.md(""" + ## What's next + + - **Add more skills**: Try adding a `farewell` skill and see how the agent uses both. + - **Stream progress**: For long-running skills, use `stream=Stream.call_agent` to send updates to the agent as work progresses. + - **Explore real robot skills**: Check out `dimos/agents2/skills/navigation.py` for examples of navigation skills. + - **Multi-agent systems**: See the [multi-agent tutorial](../multi_agent/tutorial.md). + + ## See also + + - [Agents concept guide](../../concepts/agent.md) + - [Blueprints concept guide](../../concepts/blueprints.md) + """) + return + + +@app.cell +def _(dimos): + # Gracefully shut down / release resources + dimos.stop() + return + + +if __name__ == "__main__": + app.run() diff --git a/examples/web/edge_io.py b/examples/web/edge_io.py deleted file mode 100644 index 0a791c2fde..0000000000 --- a/examples/web/edge_io.py +++ /dev/null @@ -1,188 +0,0 @@ -from flask import Flask, jsonify, request, Response, render_template -from ..types.media_provider import VideoProviderExample -from ..agents.agent import OpenAI_Agent - -import cv2 -from reactivex import operators as ops -from reactivex.disposable import CompositeDisposable -from reactivex.scheduler import ThreadPoolScheduler, CurrentThreadScheduler -from reactivex.subject import BehaviorSubject -import numpy as np - -from queue import Queue - -class EdgeIO(): - def __init__(self, dev_name:str="NA", edge_type:str="Base"): - self.dev_name = dev_name - self.edge_type = edge_type - self.disposables = CompositeDisposable() - - def dispose_all(self): - """Disposes of all active subscriptions managed by this agent.""" - self.disposables.dispose() - -# TODO: Frame processing was moved to its own class. Fix this impl. -class FlaskServer(EdgeIO): - def __init__(self, dev_name="Flask Server", edge_type="Bidirectional", port=5555, - frame_obs=None, frame_edge_obs=None, frame_optical_obs=None): - super().__init__(dev_name, edge_type) - self.app = Flask(__name__) - self.port = port - self.frame_obs = frame_obs - self.frame_edge_obs = frame_edge_obs - self.frame_optical_obs = frame_optical_obs - self.setup_routes() - - # TODO: Move these processing blocks to a processor block - def process_frame_flask(self, frame): - """Convert frame to JPEG format for streaming.""" - _, buffer = cv2.imencode('.jpg', frame) - return buffer.tobytes() - - def setup_routes(self): - # TODO: Fix - # @self.app.route('/start', methods=['GET']) - # def start_processing(): - # """Endpoint to start video processing.""" - # self.agent.subscribe_to_image_processing(self.frame_obs) - # return jsonify({"status": "Processing started"}), 200 - - # TODO: Fix - # @self.app.route('/stop', methods=['GET']) - # def stop_processing(): - # """Endpoint to stop video processing.""" - # self.agent.dispose_all() - # return jsonify({"status": "Processing stopped"}), 200 - - @self.app.route('/') - def index(): - status_text = "The video stream is currently active." - return render_template('index.html', status_text=status_text) - - @self.app.route('/video_feed') - def video_feed(): - def generate(): - frame_queue = Queue() - - def on_next(frame): - frame_queue.put(frame) - - def on_error(e): - print(f"Error in streaming: {e}") - frame_queue.put(None) # Use None to signal an error or completion. - - def on_completed(): - print("Stream completed") - frame_queue.put(None) # Signal completion to the generator. - - disposable_flask = self.frame_obs.subscribe( - on_next=lambda frame: self.flask_frame_subject.on_next(frame), - on_error=lambda e: print(f"Error: {e}"), - on_completed=lambda: self.flask_frame_subject.on_next(None), - # scheduler=scheduler - ) - - # Subscribe to the BehaviorSubject - disposable = self.flask_frame_subject.pipe( - ops.map(self.process_frame_flask), - ).subscribe(on_next, on_error, on_completed) - - self.disposables.add(disposable_flask) - self.disposables.add(disposable) - - try: - while True: - frame = frame_queue.get() # Wait for the next frame - if frame is None: # Check if there's a signal to stop. - break - yield (b'--frame\r\n' - b'Content-Type: image/jpeg\r\n\r\n' + frame + b'\r\n') - finally: - disposable_flask.dispose() - disposable.dispose() - - return Response(generate(), mimetype='multipart/x-mixed-replace; boundary=frame') - - @self.app.route('/video_feed_edge') - def video_feed_edge(): - def generate(): - frame_queue = Queue() - - def on_next(frame): - frame_queue.put(frame) - - def on_error(e): - print(f"Error in streaming: {e}") - frame_queue.put(None) # Use None to signal an error or completion. - - def on_completed(): - print("Stream completed") - frame_queue.put(None) # Signal completion to the generator. - - - - disposable_flask = self.frame_edge_obs.subscribe( - on_next=lambda frame: self.flask_frame_subject.on_next(frame), - on_error=lambda e: print(f"Error: {e}"), - on_completed=lambda: self.flask_frame_subject.on_next(None), - # scheduler=scheduler - ) - - # Subscribe to the BehaviorSubject - disposable = self.flask_frame_subject.pipe( - ops.subscribe_on(CurrentThreadScheduler()), - ops.map(self.process_frame_edge_detection), - ops.map(self.process_frame_flask), - ).subscribe(on_next, on_error, on_completed) - - self.disposables.add(disposable_flask) - self.disposables.add(disposable) - - try: - while True: - frame = frame_queue.get() # Wait for the next frame - if frame is None: # Check if there's a signal to stop. - break - yield (b'--frame\r\n' - b'Content-Type: image/jpeg\r\n\r\n' + frame + b'\r\n') - finally: - disposable_flask.dispose() - disposable.dispose() - - return Response(generate(), mimetype='multipart/x-mixed-replace; boundary=frame') - - @self.app.route('/video_feed_optical') - def video_feed_optical(): - def generate(): - frame_queue = Queue() - - def on_next(frame): - frame_queue.put(frame) - - def on_error(e): - print(f"Error in streaming: {e}") - frame_queue.put(None) # Use None to signal an error or completion. - - def on_completed(): - print("Stream completed") - frame_queue.put(None) # Signal completion to the generator. - - # Subscribe to the BehaviorSubject - disposable = self.frame_optical_obs.subscribe(on_next, on_error, on_completed) - - try: - while True: - frame = frame_queue.get() # Wait for the next frame - if frame is None: # Check if there's a signal to stop. - continue - yield (b'--frame\r\n' - b'Content-Type: image/jpeg\r\n\r\n' + frame + b'\r\n') - finally: - disposable.dispose() - - return Response(generate(), mimetype='multipart/x-mixed-replace; boundary=frame') - - def run(self, host='0.0.0.0', port=5555): - self.port = port - self.app.run(host=host, port=self.port, debug=True) - diff --git a/examples/web/templates/index.html b/examples/web/templates/index.html deleted file mode 100644 index e112d0f3c5..0000000000 --- a/examples/web/templates/index.html +++ /dev/null @@ -1,27 +0,0 @@ - - - - - - Video Stream Example - - -

Live Video Feed

- Video Feed - -

Live Edge Detection Feed

- Video Feed - -

Live Optical Flow Feed

- Video Feed - -

Current Status: {{ status_text }}

- - - - - diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000000..e6d920a293 --- /dev/null +++ b/flake.lock @@ -0,0 +1,61 @@ +{ + "nodes": { + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1748929857, + "narHash": "sha256-lcZQ8RhsmhsK8u7LIFsJhsLh/pzR9yZ8yqpTzyGdj+Q=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "c2a03962b8e24e669fb37b7df10e7c79531ff1a4", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "flake-utils": "flake-utils", + "nixpkgs": "nixpkgs" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000000..75e68f595e --- /dev/null +++ b/flake.nix @@ -0,0 +1,104 @@ +{ + description = "Project dev environment as Nix shell + DockerTools layered image"; + + inputs = { + nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable"; + flake-utils.url = "github:numtide/flake-utils"; + }; + + outputs = { self, nixpkgs, flake-utils, ... }: + flake-utils.lib.eachDefaultSystem (system: + let + pkgs = import nixpkgs { inherit system; }; + + # ------------------------------------------------------------ + # 1. Shared package list (tool-chain + project deps) + # ------------------------------------------------------------ + devPackages = with pkgs; [ + ### Core shell & utils + bashInteractive coreutils gh + stdenv.cc.cc.lib pcre2 + + ### Python + static analysis + python312 python312Packages.pip python312Packages.setuptools + python312Packages.virtualenv pre-commit + + ### Runtime deps + python312Packages.pyaudio portaudio ffmpeg_6 ffmpeg_6.dev + + ### Graphics / X11 stack + libGL libGLU mesa glfw + xorg.libX11 xorg.libXi xorg.libXext xorg.libXrandr xorg.libXinerama + xorg.libXcursor xorg.libXfixes xorg.libXrender xorg.libXdamage + xorg.libXcomposite xorg.libxcb xorg.libXScrnSaver xorg.libXxf86vm + + udev SDL2 SDL2.dev zlib + + ### GTK / OpenCV helpers + glib gtk3 gdk-pixbuf gobject-introspection + + ### GStreamer + gst_all_1.gstreamer gst_all_1.gst-plugins-base gst_all_1.gst-plugins-good + gst_all_1.gst-plugins-bad gst_all_1.gst-plugins-ugly + python312Packages.gst-python + + ### Open3D & build-time + eigen cmake ninja jsoncpp libjpeg libjpeg_turbo libpng + ### LCM (Lightweight Communications and Marshalling) + lcm + ]; + + # ------------------------------------------------------------ + # 2. Host interactive shell → `nix develop` + # ------------------------------------------------------------ + devShell = pkgs.mkShell { + packages = devPackages; + shellHook = '' + export LD_LIBRARY_PATH="${pkgs.lib.makeLibraryPath [ + pkgs.stdenv.cc.cc.lib pkgs.libGL pkgs.libGLU pkgs.mesa pkgs.glfw + pkgs.xorg.libX11 pkgs.xorg.libXi pkgs.xorg.libXext pkgs.xorg.libXrandr + pkgs.xorg.libXinerama pkgs.xorg.libXcursor pkgs.xorg.libXfixes + pkgs.xorg.libXrender pkgs.xorg.libXdamage pkgs.xorg.libXcomposite + pkgs.xorg.libxcb pkgs.xorg.libXScrnSaver pkgs.xorg.libXxf86vm + pkgs.udev pkgs.portaudio pkgs.SDL2.dev pkgs.zlib pkgs.glib pkgs.gtk3 + pkgs.gdk-pixbuf pkgs.gobject-introspection pkgs.lcm pkgs.pcre2 + pkgs.gst_all_1.gstreamer pkgs.gst_all_1.gst-plugins-base pkgs.libjpeg_turbo]}:$LD_LIBRARY_PATH" + + export DISPLAY=:0 + export GI_TYPELIB_PATH="${pkgs.gst_all_1.gstreamer}/lib/girepository-1.0:${pkgs.gst_all_1.gst-plugins-base}/lib/girepository-1.0:$GI_TYPELIB_PATH" + + PROJECT_ROOT=$(git rev-parse --show-toplevel 2>/dev/null || echo "$PWD") + if [ -f "$PROJECT_ROOT/env/bin/activate" ]; then + . "$PROJECT_ROOT/env/bin/activate" + fi + + [ -f "$PROJECT_ROOT/motd" ] && cat "$PROJECT_ROOT/motd" + [ -f "$PROJECT_ROOT/.pre-commit-config.yaml" ] && pre-commit install --install-hooks + ''; + }; + + # ------------------------------------------------------------ + # 3. Closure copied into the OCI image rootfs + # ------------------------------------------------------------ + imageRoot = pkgs.buildEnv { + name = "dimos-image-root"; + paths = devPackages; + pathsToLink = [ "/bin" ]; + }; + + in { + ## Local dev shell + devShells.default = devShell; + + ## Layered docker image with DockerTools + packages.devcontainer = pkgs.dockerTools.buildLayeredImage { + name = "dimensionalos/dimos-dev"; + tag = "latest"; + contents = [ imageRoot ]; + config = { + WorkingDir = "/workspace"; + Cmd = [ "bash" ]; + }; + }; + }); +} diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100644 index 0000000000..379fed2831 --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,133 @@ +site_name: DimOS Documentation +site_author: Dimensional Team + +hooks: + - docs/hooks.py +site_description: The universal framework for AI-native generalist robotics +repo_name: dimensionalOS/dimos +repo_url: https://github.com/dimensionalOS/dimos +site_url: https://dimensionalos.com/ +edit_uri: edit/main/docs/ + +theme: + name: material + icon: + repo: fontawesome/brands/github + features: + - navigation.tabs # Moves top-level nav to header tabs + - navigation.tabs.sticky # Keeps tabs visible when scrolling + - navigation.indexes # Allows section index pages + - navigation.expand # Expands navigation by default + - navigation.instant # SPA-like navigation + - navigation.top # Back to top button + - navigation.tracking # URL updates with scroll + - search.highlight # Highlight search terms + - search.share # Share search results + - search.suggest # Search suggestions + - toc.follow # TOC follows scroll position + - content.code.copy # Copy button for code blocks + - content.code.select # Select code blocks + palette: + - scheme: default + primary: black + accent: indigo + toggle: + icon: material/brightness-7 + name: Switch to dark mode + - scheme: slate + primary: black + accent: indigo + toggle: + icon: material/brightness-4 + name: Switch to light mode + font: + text: Roboto + code: Roboto Mono + +markdown_extensions: + - abbr + - admonition + - github-callouts + - pymdownx.details + - attr_list + - def_list + - md_in_html + - toc: + permalink: true + - pymdownx.highlight: + anchor_linenums: true + line_spans: __span + pygments_lang_class: true + - pymdownx.inlinehilite + - pymdownx.superfences + - pymdownx.tabbed: + alternate_style: true + - pymdownx.tasklist: + custom_checkbox: true + - pymdownx.emoji: + emoji_index: !!python/name:material.extensions.emoji.twemoji + emoji_generator: !!python/name:material.extensions.emoji.to_svg + - pymdownx.snippets: + base_path: ['.'] + +nav: + - Home: index.md + - Quickstart: quickstart.md + - Tutorials: + - tutorials/index.md + - Build your first skill: tutorials/skill_basics/tutorial.md + - Equip an agent with skills: tutorials/skill_with_agent/tutorial.md + - Build a multi-agent RoboButler: tutorials/multi_agent/tutorial.md + - Concepts: + - concepts/index.md + - Modules: concepts/modules.md + - Skills: concepts/skills.md + - Agent: concepts/agent.md + - Transport: concepts/transport.md + - Blueprints: concepts/blueprints.md + - API Reference: + - api/index.md + - Agents: api/agents.md + - Skills: api/skills.md + - CLI Tools: api/cli_tools.md + +plugins: + - search + - marimo + - mkdocstrings: + handlers: + python: + extensions: + - griffe_typingdoc + options: + docstring_style: google + members_order: alphabetical + allow_inspection: true + show_bases: true + show_source: true + signature_crossrefs: true + # llms.txt-es are auto-generated + - llmstxt: + markdown_description: > + DimOS is the universal framework for AI-native generalist robotics. + It provides a comprehensive platform for building intelligent robotic systems + with seamless integration of perception, navigation, manipulation, and AI agents. + full_output: llms-full.txt + # TODO: Look into generating the sections from a literate-nav SUMMARY.md in the future + # to avoid having to manually synchronize with nav + sections: + Home: + - index.md: Introduction to DimOS + Quickstart: + - quickstart.md + Tutorials: + - tutorials/*.md + Concepts: + - concepts/*.md + API Reference: + - api/*.md + +extra: + social: + - icon: fontawesome/brands/github + link: https://github.com/dimensionalOS/dimos diff --git a/onnx/metric3d_vit_small.onnx b/onnx/metric3d_vit_small.onnx new file mode 100644 index 0000000000..bfddd41628 --- /dev/null +++ b/onnx/metric3d_vit_small.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:14805174265dd721ac3b396bd5ee7190c708cec41150ed298267f6c3126bc060 +size 151333865 diff --git a/pyproject.toml b/pyproject.toml index 46c2cf325d..1e334dde2b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,8 +1,319 @@ +[build-system] +requires = ["setuptools>=70", "wheel"] +build-backend = "setuptools.build_meta" + +[tool.setuptools] +include-package-data = true + +[tool.setuptools.packages.find] +where = ["."] +include = ["dimos*"] + +[tool.setuptools.package-data] +"*" = ["*.html", "*.css", "*.js", "*.json", "*.txt", "*.yaml", "*.yml"] + [project] name = "dimos" authors = [ - {name = "Stash Pomichter", email = "stash@dimensionalOS.com"}, + {name = "Dimensional Team", email = "build@dimensionalOS.com"}, +] +version = "0.0.4" +description = "Powering agentive generalist robotics" +requires-python = ">=3.10" + +dependencies = [ + # Core requirements + "opencv-python", + "python-dotenv", + "openai", + "anthropic>=0.19.0", + "cerebras-cloud-sdk", + "moondream", + "numpy>=1.26.4,<2.0.0", + "colorlog==6.9.0", + "yapf==0.40.2", + "typeguard", + "empy==3.3.4", + "catkin_pkg", + "lark", + "plum-dispatch==2.5.7", + "ffmpeg-python", + "tiktoken>=0.8.0", + "Flask>=2.2", + "python-multipart==0.0.20", + "reactivex", + "rxpy-backpressure @ git+https://github.com/dimensionalOS/rxpy-backpressure.git", + "asyncio==3.4.3", + "go2-webrtc-connect @ git+https://github.com/dimensionalOS/go2_webrtc_connect.git", + "tensorzero==2025.7.5", + "structlog>=25.5.0,<26", + + # Web Extensions + "fastapi>=0.115.6", + "sse-starlette>=2.2.1", + "uvicorn>=0.34.0", + + # Agents + "langchain>=0.3.27,<1", + "langchain-chroma>=0.2.6,<1", + "langchain-core>=0.3.79,<1", + "langchain-openai>=0.3.33,<1", + "langchain-text-splitters>=0.3.11,<1", + "langchain-huggingface>=0.3.1,<1", + "langchain-ollama>=0.3.10,<1", + "bitsandbytes>=0.48.2,<1.0", + "ollama>=0.6.0", + + # Class Extraction + "pydantic", + + # For documenting parameters etc + "annotated-doc", + + # Developer Specific + "ipykernel", + + # Unitree webrtc streaming + "aiortc==1.9.0", + "pycryptodome", + "sounddevice", + "pyaudio", + "requests", + "wasmtime", + + # Image + "PyTurboJPEG==1.8.2", + + # Audio + "openai-whisper", + "soundfile", + + # Hugging Face + "transformers[torch]==4.49.0", + + # Vector Embedding + "sentence_transformers", + + # Perception Dependencies + "ultralytics>=8.3.70", + "filterpy>=1.4.5", + "scipy>=1.15.1", + "scikit-learn", + "Pillow", + "clip @ git+https://github.com/openai/CLIP.git", + "timm>=1.0.15", + "lap>=0.5.12", + "opencv-contrib-python==4.10.0.84", + + # Mapping + "open3d", + "googlemaps>=4.10.0", + + # Inference + "onnx", + + # Multiprocess + "dask[complete]==2025.5.1", + + # LCM / DimOS utilities + "dimos-lcm @ git+https://github.com/dimensionalOS/dimos-lcm.git@3aeb724863144a8ba6cf72c9f42761d1007deda4", + + # CLI + "pydantic-settings>=2.11.0,<3", + "typer>=0.19.2,<1", +] + +[project.scripts] +lcmspy = "dimos.utils.cli.lcmspy.run_lcmspy:main" +foxglove-bridge = "dimos.utils.cli.foxglove_bridge.run_foxglove_bridge:main" +skillspy = "dimos.utils.cli.skillspy.skillspy:main" +agentspy = "dimos.utils.cli.agentspy.agentspy:main" +humancli = "dimos.utils.cli.human.humanclianim:main" +dimos = "dimos.robot.cli.dimos:main" + +[project.optional-dependencies] +manipulation = [ + + # Contact Graspnet Dependencies + "h5py>=3.7.0", + "pyrender>=0.1.45", + "trimesh>=3.22.0", + "python-fcl>=0.7.0.4", + "pyquaternion>=0.9.9", + "matplotlib>=3.7.1", + "rtree", + "pandas>=1.5.2", + "tqdm>=4.65.0", + "pyyaml>=6.0", + "contact-graspnet-pytorch @ git+https://github.com/dimensionalOS/contact_graspnet_pytorch.git", + + # piper arm + "piper-sdk", + + # Visualization (Optional) + "kaleido>=0.2.1", + "plotly>=5.9.0", +] + + +cpu = [ + # CPU inference backends + "onnxruntime", + "ctransformers==0.2.27", +] + +cuda = [ + "cupy-cuda12x==13.6.0", + "nvidia-nvimgcodec-cu12[all]", + "onnxruntime-gpu>=1.17.1", # Only versions supporting both cuda11 and cuda12 + "ctransformers[cuda]==0.2.27", + "mmengine>=0.10.3", + "mmcv>=2.1.0", + "xformers>=0.0.20", + + # Detic GPU stack + "mss", + "dataclasses", + "ftfy", + "regex", + "fasttext", + "lvis", + "nltk", + "clip @ git+https://github.com/openai/CLIP.git", + "detectron2 @ git+https://github.com/facebookresearch/detectron2.git@v0.6", + + # embedding models + "open_clip_torch>=3.0.0", + "torchreid==0.2.5", +] + +dev = [ + "ruff==0.14.3", + "mypy==1.18.2", + "pre_commit==4.2.0", + "pytest==8.3.5", + "pytest-asyncio==0.26.0", + "pytest-mock==3.15.0", + "pytest-env==1.1.5", + "pytest-timeout==2.4.0", + "xdoctest>=1.3.0", + "textual==3.7.1", + "requests-mock==1.12.1", + "terminaltexteffects==0.12.2", + + # Types + "lxml-stubs>=0.5.1,<1", + "pandas-stubs>=2.3.2.250926,<3", + "types-PySocks>=1.7.1.20251001,<2", + "types-PyYAML>=6.0.12.20250915,<7", + "types-colorama>=0.4.15.20250801,<1", + "types-defusedxml>=0.7.0.20250822,<1", + "types-gevent>=25.4.0.20250915,<26", + "types-greenlet>=3.2.0.20250915,<4", + "types-jmespath>=1.0.2.20250809,<2", + "types-jsonschema>=4.25.1.20251009,<5", + "types-networkx>=3.5.0.20251001,<4", + "types-protobuf>=6.32.1.20250918,<7", + "types-psutil>=7.0.0.20251001,<8", + "types-pytz>=2025.2.0.20250809,<2026", + "types-simplejson>=3.20.0.20250822,<4", + "types-tabulate>=0.9.0.20241207,<1", + "types-tensorflow>=2.18.0.20251008,<3", + "types-tqdm>=4.67.0.20250809,<5", +] + +sim = [ + # Simulation + "mujoco>=3.3.4", + "playground>=0.0.5", + "pygame>=2.6.1", +] + +docs = [ + # Documentation + "mkdocs>=1.6.0", + "mkdocs-material>=9.5.0", + "mkdocstrings[python]>=0.26.0", + "mkdocs-llmstxt>=0.1.0", + "griffe-typingdoc>=0.2.0", + "markdown-callouts", + "marimo[mcp]", # for tutorials; mcp extra enables MCP server for coding agents + "mkdocs-marimo", + "psutil", # for docs/hooks.py process management +] + +jetson-jp6-cuda126 = [ + # Jetson Jetpack 6.2 with CUDA 12.6 specific wheels + # Note: Alternative torch wheel from docs: https://developer.download.nvidia.com/compute/redist/jp/v61/pytorch/torch-2.5.0a0+872d972e41.nv24.08.17622132-cp310-cp310-linux_aarch64.whl + "torch @ https://pypi.jetson-ai-lab.io/jp6/cu126/+f/564/4d4458f1ba159/torch-2.8.0-cp310-cp310-linux_aarch64.whl", + "torchvision @ https://pypi.jetson-ai-lab.io/jp6/cu126/+f/1c0/3de08a69e9554/torchvision-0.23.0-cp310-cp310-linux_aarch64.whl", + "onnxruntime-gpu @ https://pypi.jetson-ai-lab.io/jp6/cu126/+f/4eb/e6a8902dc7708/onnxruntime_gpu-1.23.0-cp310-cp310-linux_aarch64.whl", + "xformers @ https://pypi.jetson-ai-lab.io/jp6/cu126/+f/731/15133b0ebb2b3/xformers-0.0.33+ac00641.d20250830-cp39-abi3-linux_aarch64.whl", +] + +[tool.ruff] +line-length = 100 +exclude = [ + ".git", + ".pytest_cache", + ".ruff_cache", + ".venv", + ".vscode", + "__pypackages__", + "_build", + "build", + "dist", + "node_modules", + "site-packages", + "venv", + "libs", + "external", + "src" +] + +[tool.ruff.lint] +extend-select = ["E", "W", "F", "B", "UP", "N", "I", "C90", "A", "RUF", "TCH", "D300"] +# TODO: All of these should be fixed, but it's easier commit autofixes first +ignore = ["A001", "A002", "B008", "B017", "B019", "B023", "B024", "B026", "B904", "C901", "E402", "E501", "E721", "E722", "E741", "F401", "F403", "F811", "F821", "F821", "F821", "N801", "N802", "N803", "N806", "N812", "N813", "N813", "N816", "N817", "N999", "RUF002", "RUF003", "RUF006", "RUF009", "RUF012", "RUF034", "RUF043", "RUF059", "UP007"] + +[tool.ruff.lint.per-file-ignores] +"dimos/models/Detic/*" = ["ALL"] + +[tool.ruff.lint.isort] +known-first-party = ["dimos"] +combine-as-imports = true +force-sort-within-sections = true + +[tool.mypy] +python_version = "3.12" +incremental = true +strict = true +exclude = "^dimos/models/Detic(/|$)|.*/test_.|.*/conftest.py*" + +[tool.pytest.ini_options] +testpaths = ["dimos"] +markers = [ + "heavy: resource heavy test", + "vis: marks tests that run visuals and require a visual check by dev", + "benchmark: benchmark, executes something multiple times, calculates avg, prints to console", + "exclude: arbitrary exclusion from CI and default test exec", + "tool: dev tooling", + "needsdata: needs test data to be downloaded", + "ros: depend on ros", + "lcm: tests that run actual LCM bus (can't execute in CI)", + "module: tests that need to run directly as modules", + "gpu: tests that require GPU", + "tofix: temporarily disabled test" +] +env = [ + "GOOGLE_MAPS_API_KEY=AIzafake_google_key" ] -version = "0.0.0" -description = "Coming soon" +addopts = "-v -p no:warnings -ra --color=yes -m 'not vis and not benchmark and not exclude and not tool and not needsdata and not lcm and not ros and not heavy and not gpu and not module and not tofix'" +asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "function" +# xdoctest configuration +# (Not running on CI yet. Run with e.g. `xdoctest --global-exec "import dimos"`; can also use pytest to run it) +# Chose xdoctest because it's better than vanilla doctest at floating-point comparisons, and because it uses a less fragile AST-based approach. It's used by PyTorch. +# Disable built-in doctest to avoid conflicts +xdoctest_style = "google" diff --git a/requirements.txt b/requirements.txt index aef36b8ab3..808a539624 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,18 +1,96 @@ opencv-python python-dotenv openai -numpy +anthropic>=0.19.0 +cerebras-cloud-sdk +numpy>=1.26.4,<2.0.0 +colorlog==6.9.0 +yapf==0.40.2 +typeguard +empy==3.3.4 +catkin_pkg +lark +plum-dispatch==2.5.7 # pycolmap - -numpy ffmpeg-python pytest python-dotenv openai +tiktoken>=0.8.0 Flask>=2.2 +python-multipart==0.0.20 reactivex +git+https://github.com/dimensionalOS/rxpy-backpressure.git +pytest-asyncio==0.26.0 +asyncio==3.4.3 +-e git+https://github.com/dimensionalOS/go2_webrtc_connect.git#egg=go2_webrtc_connect +# Web Extensions +fastapi>=0.115.6 +sse-starlette>=2.2.1 +uvicorn>=0.34.0 -# Agent Memory -langchain-chroma>=0.1.2 +# Agent Memory +langchain-chroma>=0.1.4 langchain-openai>=0.2.14 + +# Class Extraction +pydantic + +# Developer Specific +ipykernel + +# Audio +openai-whisper +soundfile + +#Hugging Face +transformers[torch]==4.49.0 + +#Vector Embedding +sentence_transformers + +# CTransforms GGUF - GPU required +ctransformers[cuda]==0.2.27 + +# Perception Dependencies +ultralytics>=8.3.70 +filterpy>=1.4.5 +scipy>=1.15.1 +opencv-python==4.10.0.84 +opencv-contrib-python==4.10.0.84 +scikit-learn +Pillow +mmengine>=0.10.3 +mmcv>=2.1.0 +timm>=1.0.15 +lap>=0.5.12 +xformers==0.0.20 + +# Detic +opencv-python +mss +timm +dataclasses +ftfy +regex +fasttext +scikit-learn +lvis +nltk +git+https://github.com/openai/CLIP.git +git+https://github.com/facebookresearch/detectron2.git@v0.6 + +# Mapping +open3d + +# Inference (CPU) +onnxruntime +onnx + +# Terminal colors +rich==14.0.0 + +# multiprocess +dask[complete]==2025.5.1 +git+https://github.com/dimensionalOS/python_lcm_msgs@main#egg=lcm_msgs diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000..726740e770 --- /dev/null +++ b/setup.py @@ -0,0 +1,20 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from setuptools import find_packages, setup + +setup( + packages=find_packages(), + package_dir={"": "."}, +) diff --git a/tests/__init__.py b/tests/__init__.py index 8b13789179..e69de29bb2 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1 +0,0 @@ - diff --git a/tests/agent_manip_flow_fastapi_test.py b/tests/agent_manip_flow_fastapi_test.py new file mode 100644 index 0000000000..d1daea9638 --- /dev/null +++ b/tests/agent_manip_flow_fastapi_test.py @@ -0,0 +1,149 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module initializes and manages the video processing pipeline integrated with a web server. +It handles video capture, frame processing, and exposes the processed video streams via HTTP endpoints. +""" + +# ----- +# Standard library imports +import multiprocessing +import os + +from dotenv import load_dotenv + +# Third-party imports +from reactivex import operators as ops +from reactivex.disposable import CompositeDisposable +from reactivex.scheduler import ThreadPoolScheduler + +# Local application imports +from dimos.stream.frame_processor import FrameProcessor +from dimos.stream.video_operators import VideoOperators as vops +from dimos.stream.video_provider import VideoProvider +from dimos.web.fastapi_server import FastAPIServer + +# Load environment variables +load_dotenv() + + +def main(): + """ + Initializes and runs the video processing pipeline with web server output. + + This function orchestrates a video processing system that handles capture, processing, + and visualization of video streams. It demonstrates parallel processing capabilities + and various video manipulation techniques across multiple stages including capture + and processing at different frame rates, edge detection, and optical flow analysis. + + Raises: + RuntimeError: If video sources are unavailable or processing fails. + """ + CompositeDisposable() + + processor = FrameProcessor( + output_dir=f"{os.getcwd()}/assets/output/frames", delete_on_init=True + ) + + optimal_thread_count = multiprocessing.cpu_count() # Gets number of CPU cores + thread_pool_scheduler = ThreadPoolScheduler(optimal_thread_count) + + VIDEO_SOURCES = [ + f"{os.getcwd()}/assets/ldru.mp4", + f"{os.getcwd()}/assets/ldru_480p.mp4", + f"{os.getcwd()}/assets/trimmed_video_480p.mov", + f"{os.getcwd()}/assets/video-f30-480p.mp4", + "rtsp://192.168.50.207:8080/h264.sdp", + "rtsp://10.0.0.106:8080/h264.sdp", + ] + + VIDEO_SOURCE_INDEX = 3 + VIDEO_SOURCE_INDEX_2 = 2 + + my_video_provider = VideoProvider("Video File", video_source=VIDEO_SOURCES[VIDEO_SOURCE_INDEX]) + my_video_provider_2 = VideoProvider( + "Video File 2", video_source=VIDEO_SOURCES[VIDEO_SOURCE_INDEX_2] + ) + + video_stream_obs = my_video_provider.capture_video_as_observable(fps=120).pipe( + ops.subscribe_on(thread_pool_scheduler), + # Move downstream operations to thread pool for parallel processing + # Disabled: Evaluating performance impact + # ops.observe_on(thread_pool_scheduler), + vops.with_jpeg_export(processor, suffix="raw"), + vops.with_fps_sampling(fps=30), + vops.with_jpeg_export(processor, suffix="raw_slowed"), + ) + + video_stream_obs_2 = my_video_provider_2.capture_video_as_observable(fps=120).pipe( + ops.subscribe_on(thread_pool_scheduler), + # Move downstream operations to thread pool for parallel processing + # Disabled: Evaluating performance impact + # ops.observe_on(thread_pool_scheduler), + vops.with_jpeg_export(processor, suffix="raw_2"), + vops.with_fps_sampling(fps=30), + vops.with_jpeg_export(processor, suffix="raw_2_slowed"), + ) + + edge_detection_stream_obs = processor.process_stream_edge_detection(video_stream_obs).pipe( + vops.with_jpeg_export(processor, suffix="edge"), + ) + + optical_flow_relevancy_stream_obs = processor.process_stream_optical_flow_with_relevancy( + video_stream_obs + ) + + optical_flow_stream_obs = optical_flow_relevancy_stream_obs.pipe( + ops.do_action(lambda result: print(f"Optical Flow Relevancy Score: {result[1]}")), + vops.with_optical_flow_filtering(threshold=2.0), + ops.do_action(lambda _: print("Optical Flow Passed Threshold.")), + vops.with_jpeg_export(processor, suffix="optical"), + ) + + # + # ====== Agent Orchastrator (Qu.s Awareness, Temporality, Routing) ====== + # + + # Agent 1 + # my_agent = OpenAIAgent( + # "Agent 1", + # query="You are a robot. What do you see? Put a JSON with objects of what you see in the format {object, description}.") + # my_agent.subscribe_to_image_processing(slowed_video_stream_obs) + # disposables.add(my_agent.disposables) + + # # Agent 2 + # my_agent_two = OpenAIAgent( + # "Agent 2", + # query="This is a visualization of dense optical flow. What movement(s) have occured? Put a JSON with mapped directions you see in the format {direction, probability, english_description}.") + # my_agent_two.subscribe_to_image_processing(optical_flow_stream_obs) + # disposables.add(my_agent_two.disposables) + + # + # ====== Create and start the FastAPI server ====== + # + + # Will be visible at http://[host]:[port]/video_feed/[key] + streams = { + "video_one": video_stream_obs, + "video_two": video_stream_obs_2, + "edge_detection": edge_detection_stream_obs, + "optical_flow": optical_flow_stream_obs, + } + fast_api_server = FastAPIServer(port=5555, **streams) + fast_api_server.run() + + +if __name__ == "__main__": + main() diff --git a/tests/agent_manip_flow_flask_test.py b/tests/agent_manip_flow_flask_test.py new file mode 100644 index 0000000000..01118f7d0a --- /dev/null +++ b/tests/agent_manip_flow_flask_test.py @@ -0,0 +1,192 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module initializes and manages the video processing pipeline integrated with a web server. +It handles video capture, frame processing, and exposes the processed video streams via HTTP endpoints. +""" + +# ----- +# Standard library imports +import multiprocessing +import os + +from dotenv import load_dotenv + +# Third-party imports +from flask import Flask +from reactivex import interval, operators as ops, zip as rx_zip +from reactivex.disposable import CompositeDisposable +from reactivex.scheduler import ThreadPoolScheduler + +# Local application imports +from dimos.agents.agent import OpenAIAgent +from dimos.stream.frame_processor import FrameProcessor +from dimos.stream.video_operators import VideoOperators as vops +from dimos.stream.video_provider import VideoProvider +from dimos.web.flask_server import FlaskServer + +# Load environment variables +load_dotenv() + +app = Flask(__name__) + + +def main(): + """ + Initializes and runs the video processing pipeline with web server output. + + This function orchestrates a video processing system that handles capture, processing, + and visualization of video streams. It demonstrates parallel processing capabilities + and various video manipulation techniques across multiple stages including capture + and processing at different frame rates, edge detection, and optical flow analysis. + + Raises: + RuntimeError: If video sources are unavailable or processing fails. + """ + disposables = CompositeDisposable() + + processor = FrameProcessor( + output_dir=f"{os.getcwd()}/assets/output/frames", delete_on_init=True + ) + + optimal_thread_count = multiprocessing.cpu_count() # Gets number of CPU cores + thread_pool_scheduler = ThreadPoolScheduler(optimal_thread_count) + + VIDEO_SOURCES = [ + f"{os.getcwd()}/assets/ldru.mp4", + f"{os.getcwd()}/assets/ldru_480p.mp4", + f"{os.getcwd()}/assets/trimmed_video_480p.mov", + f"{os.getcwd()}/assets/video-f30-480p.mp4", + f"{os.getcwd()}/assets/video.mov", + "rtsp://192.168.50.207:8080/h264.sdp", + "rtsp://10.0.0.106:8080/h264.sdp", + f"{os.getcwd()}/assets/people_1080p_24fps.mp4", + ] + + VIDEO_SOURCE_INDEX = 4 + + my_video_provider = VideoProvider("Video File", video_source=VIDEO_SOURCES[VIDEO_SOURCE_INDEX]) + + video_stream_obs = my_video_provider.capture_video_as_observable(fps=120).pipe( + ops.subscribe_on(thread_pool_scheduler), + # Move downstream operations to thread pool for parallel processing + # Disabled: Evaluating performance impact + # ops.observe_on(thread_pool_scheduler), + # vops.with_jpeg_export(processor, suffix="raw"), + vops.with_fps_sampling(fps=30), + # vops.with_jpeg_export(processor, suffix="raw_slowed"), + ) + + processor.process_stream_edge_detection(video_stream_obs).pipe( + # vops.with_jpeg_export(processor, suffix="edge"), + ) + + optical_flow_relevancy_stream_obs = processor.process_stream_optical_flow(video_stream_obs) + + optical_flow_stream_obs = optical_flow_relevancy_stream_obs.pipe( + # ops.do_action(lambda result: print(f"Optical Flow Relevancy Score: {result[1]}")), + # vops.with_optical_flow_filtering(threshold=2.0), + # ops.do_action(lambda _: print(f"Optical Flow Passed Threshold.")), + # vops.with_jpeg_export(processor, suffix="optical") + ) + + # + # ====== Agent Orchastrator (Qu.s Awareness, Temporality, Routing) ====== + # + + # Observable that emits every 2 seconds + secondly_emission = interval(2, scheduler=thread_pool_scheduler).pipe( + ops.map(lambda x: f"Second {x + 1}"), + # ops.take(30) + ) + + # Agent 1 + my_agent = OpenAIAgent( + "Agent 1", + query="You are a robot. What do you see? Put a JSON with objects of what you see in the format {object, description}.", + json_mode=False, + ) + + # Create an agent for each subset of questions that it would be theroized to handle. + # Set std. template/blueprints, and devs will add to that likely. + + ai_1_obs = video_stream_obs.pipe( + # vops.with_fps_sampling(fps=30), + # ops.throttle_first(1), + vops.with_jpeg_export(processor, suffix="open_ai_agent_1"), + ops.take(30), + ops.replay(buffer_size=30, scheduler=thread_pool_scheduler), + ) + ai_1_obs.connect() + + ai_1_repeat_obs = ai_1_obs.pipe(ops.repeat()) + + my_agent.subscribe_to_image_processing(ai_1_obs) + disposables.add(my_agent.disposables) + + # Agent 2 + my_agent_two = OpenAIAgent( + "Agent 2", + query="This is a visualization of dense optical flow. What movement(s) have occured? Put a JSON with mapped directions you see in the format {direction, probability, english_description}.", + max_input_tokens_per_request=1000, + max_output_tokens_per_request=300, + json_mode=False, + model_name="gpt-4o-2024-08-06", + ) + + ai_2_obs = optical_flow_stream_obs.pipe( + # vops.with_fps_sampling(fps=30), + # ops.throttle_first(1), + vops.with_jpeg_export(processor, suffix="open_ai_agent_2"), + ops.take(30), + ops.replay(buffer_size=30, scheduler=thread_pool_scheduler), + ) + ai_2_obs.connect() + + ai_2_repeat_obs = ai_2_obs.pipe(ops.repeat()) + + # Combine emissions using rx_zip + ai_1_secondly_repeating_obs = rx_zip(secondly_emission, ai_1_repeat_obs).pipe( + # ops.do_action(lambda s: print(f"AI 1 - Emission Count: {s[0]}")), + ops.map(lambda r: r[1]), + ) + + # Combine emissions using rx_zip + ai_2_secondly_repeating_obs = rx_zip(secondly_emission, ai_2_repeat_obs).pipe( + # ops.do_action(lambda s: print(f"AI 2 - Emission Count: {s[0]}")), + ops.map(lambda r: r[1]), + ) + + my_agent_two.subscribe_to_image_processing(ai_2_obs) + disposables.add(my_agent_two.disposables) + + # + # ====== Create and start the Flask server ====== + # + + # Will be visible at http://[host]:[port]/video_feed/[key] + flask_server = FlaskServer( + # video_one=video_stream_obs, + # edge_detection=edge_detection_stream_obs, + # optical_flow=optical_flow_stream_obs, + OpenAIAgent_1=ai_1_secondly_repeating_obs, + OpenAIAgent_2=ai_2_secondly_repeating_obs, + ) + + flask_server.run(threaded=True) + + +if __name__ == "__main__": + main() diff --git a/tests/agent_manip_flow_test.py b/tests/agent_manip_flow_test.py deleted file mode 100644 index 558adabb46..0000000000 --- a/tests/agent_manip_flow_test.py +++ /dev/null @@ -1,124 +0,0 @@ -from datetime import timedelta -import sys -import os - -# Add the parent directory of 'tests' to the Python path -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -# ----- - -from dotenv import load_dotenv -load_dotenv() - -from reactivex import operators as ops -from reactivex.disposable import CompositeDisposable -from reactivex.scheduler import ThreadPoolScheduler, CurrentThreadScheduler - -from flask import Flask, Response, stream_with_context - -from dimos.agents.agent import OpenAI_Agent -from dimos.types.media_provider import VideoProviderExample -from dimos.web.edge_io import FlaskServer -from dimos.types.videostream import FrameProcessor -from dimos.types.videostream import StreamUtils - -app = Flask(__name__) - -def main(): - disposables = CompositeDisposable() - - # Create a frame processor to manipulate our video inputs - processor = FrameProcessor() - - # Video provider setup - my_video_provider = VideoProviderExample("Video File", video_source="/app/assets/video-f30-480p.mp4") # "/app/assets/trimmed_video.mov") # "rtsp://10.0.0.106:8080/h264.sdp") # - video_stream_obs = my_video_provider.video_capture_to_observable().pipe( - # ops.ref_count(), - ops.subscribe_on(ThreadPoolScheduler()) - ) - - # Articficlally slow the stream (60fps ~ 16667us) - slowed_video_stream_obs = StreamUtils.limit_emission_rate(video_stream_obs, time_delta=timedelta(microseconds=16667)) - - # Process an edge detection stream - edge_detection_stream_obs = processor.process_stream_edge_detection(slowed_video_stream_obs) - - # Process an optical flow stream - optical_flow_stream_obs = processor.process_stream_optical_flow(slowed_video_stream_obs) - - # Dump streams to disk - # Raw Frames - video_stream_dump_obs = processor.process_stream_export_to_jpeg(video_stream_obs, suffix="raw") - video_stream_dump_obs.subscribe( - on_next=lambda result: None, # print(f"Slowed Stream Result: {result}"), - on_error=lambda e: print(f"Error (Stream): {e}"), - on_completed=lambda: print("Processing completed.") - ) - - # Slowed Stream - slowed_video_stream_dump_obs = processor.process_stream_export_to_jpeg(slowed_video_stream_obs, suffix="raw") - slowed_video_stream_dump_obs.subscribe( - on_next=lambda result: None, # print(f"Slowed Stream Result: {result}"), - on_error=lambda e: print(f"Error (Slowed Stream): {e}"), - on_completed=lambda: print("Processing completed.") - ) - - # Edge Detection - edge_detection_stream_dump_obs = processor.process_stream_export_to_jpeg(edge_detection_stream_obs, suffix="edge") - edge_detection_stream_dump_obs.subscribe( - on_next=lambda result: None, # print(f"Edge Detection Result: {result}"), - on_error=lambda e: print(f"Error (Edge Detection): {e}"), - on_completed=lambda: print("Processing completed.") - ) - - # Optical Flow - optical_flow_stream_dump_obs = processor.process_stream_export_to_jpeg(optical_flow_stream_obs, suffix="optical") - optical_flow_stream_dump_obs.subscribe( - on_next=lambda result: None, # print(f"Optical Flow Result: {result}"), - on_error=lambda e: print(f"Error (Optical Flow): {e}"), - on_completed=lambda: print("Processing completed.") - ) - - # Local Optical Flow Threshold - # TODO: Propogate up relevancy score from compute_optical_flow nested in process_stream_optical_flow - - # Agent Orchastrator (Qu.s Awareness, Temporality, Routing) - # TODO: Expand - - # Agent 1 - # my_agent = OpenAI_Agent("Agent 1", query="You are a robot. What do you see? Put a JSON with objects of what you see in the format {object, description}.") - # my_agent.subscribe_to_image_processing(slowed_video_stream_dump_obs) - # disposables.add(my_agent.disposables) - - # Agent 2 - # my_agent_two = OpenAI_Agent("Agent 2", query="This is a visualization of dense optical flow. What movement(s) have occured? Put a JSON with mapped directions you see in the format {direction, probability, english_description}.") - # my_agent_two.subscribe_to_image_processing(optical_flow_stream_dump_obs) - # disposables.add(my_agent.disposables) - - # Create and start the Flask server - # Will be visible at http://[host]:[port]/video_feed/[key] - flask_server = FlaskServer(main=video_stream_obs, - slowed=slowed_video_stream_obs, - edge=edge_detection_stream_obs, - optical=optical_flow_stream_dump_obs, - ) - # flask_server = FlaskServer(main=video_stream_obs, - # slowed=slowed_video_stream_obs, - # edge_detection=edge_detection_stream_obs, - # optical_flow=optical_flow_stream_obs, - # # main5=video_stream_dump_obs, - # # main6=video_stream_dump_obs, - # ) - # flask_server = FlaskServer( - # main1=video_stream_obs, - # main2=video_stream_obs, - # main3=video_stream_obs, - # main4=slowed_video_stream_obs, - # main5=slowed_video_stream_obs, - # main6=slowed_video_stream_obs, - # ) - flask_server.run() - -if __name__ == "__main__": - main() - diff --git a/tests/agent_memory_test.py b/tests/agent_memory_test.py new file mode 100644 index 0000000000..cdad5429ab --- /dev/null +++ b/tests/agent_memory_test.py @@ -0,0 +1,57 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# ----- +from dotenv import load_dotenv + +load_dotenv() + +from dimos.agents.memory.chroma_impl import OpenAISemanticMemory + +agent_memory = OpenAISemanticMemory() +print("Initialization done.") + +agent_memory.add_vector("id0", "Food") +agent_memory.add_vector("id1", "Cat") +agent_memory.add_vector("id2", "Mouse") +agent_memory.add_vector("id3", "Bike") +agent_memory.add_vector("id4", "Dog") +agent_memory.add_vector("id5", "Tricycle") +agent_memory.add_vector("id6", "Car") +agent_memory.add_vector("id7", "Horse") +agent_memory.add_vector("id8", "Vehicle") +agent_memory.add_vector("id6", "Red") +agent_memory.add_vector("id7", "Orange") +agent_memory.add_vector("id8", "Yellow") +print("Adding vectors done.") + +print(agent_memory.get_vector("id1")) +print("Done retrieving sample vector.") + +results = agent_memory.query("Colors") +print(results) +print("Done querying agent memory (basic).") + +results = agent_memory.query("Colors", similarity_threshold=0.2) +print(results) +print("Done querying agent memory (similarity_threshold=0.2).") + +results = agent_memory.query("Colors", n_results=2) +print(results) +print("Done querying agent memory (n_results=2).") + +results = agent_memory.query("Colors", n_results=19, similarity_threshold=0.45) +print(results) +print("Done querying agent memory (n_results=19, similarity_threshold=0.45).") diff --git a/tests/colmap_test.py b/tests/colmap_test.py deleted file mode 100644 index 21067603e9..0000000000 --- a/tests/colmap_test.py +++ /dev/null @@ -1,11 +0,0 @@ -import sys -import os - -# Add the parent directory of 'demos' to the Python path -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -# Now try to import -from dimos.environment.colmap_environment import COLMAPEnvironment - -env = COLMAPEnvironment() -env.initialize_from_video("data/IMG_1525.MOV", "data/frames") diff --git a/tests/data/database.db-shm b/tests/data/database.db-shm deleted file mode 100644 index 83434a41a6..0000000000 Binary files a/tests/data/database.db-shm and /dev/null differ diff --git a/tests/data/database.db.REMOVED.git-id b/tests/data/database.db.REMOVED.git-id deleted file mode 100644 index 4342f3915b..0000000000 --- a/tests/data/database.db.REMOVED.git-id +++ /dev/null @@ -1 +0,0 @@ -b269371a99c36f7f05b71a7c5593c6b6aaf55751 \ No newline at end of file diff --git a/tests/data/output-0.5fps/frame_0000.jpg b/tests/data/output-0.5fps/frame_0000.jpg deleted file mode 100644 index 1a10eed0c5..0000000000 Binary files a/tests/data/output-0.5fps/frame_0000.jpg and /dev/null differ diff --git a/tests/data/output-0.5fps/frame_0001.jpg b/tests/data/output-0.5fps/frame_0001.jpg deleted file mode 100644 index 7e0a0e5a05..0000000000 Binary files a/tests/data/output-0.5fps/frame_0001.jpg and /dev/null differ diff --git a/tests/data/output-0.5fps/frame_0002.jpg b/tests/data/output-0.5fps/frame_0002.jpg deleted file mode 100644 index 0035dda6b2..0000000000 Binary files a/tests/data/output-0.5fps/frame_0002.jpg and /dev/null differ diff --git a/tests/data/output-0.5fps/frame_0003.jpg b/tests/data/output-0.5fps/frame_0003.jpg deleted file mode 100644 index 4101db2573..0000000000 Binary files a/tests/data/output-0.5fps/frame_0003.jpg and /dev/null differ diff --git a/tests/data/output-0.5fps/frame_0004.jpg b/tests/data/output-0.5fps/frame_0004.jpg deleted file mode 100644 index ef51ed1558..0000000000 Binary files a/tests/data/output-0.5fps/frame_0004.jpg and /dev/null differ diff --git a/tests/data/output-0.5fps/frame_0005.jpg b/tests/data/output-0.5fps/frame_0005.jpg deleted file mode 100644 index 2fc669d73c..0000000000 Binary files a/tests/data/output-0.5fps/frame_0005.jpg and /dev/null differ diff --git a/tests/data/output-0.5fps/frame_0006.jpg b/tests/data/output-0.5fps/frame_0006.jpg deleted file mode 100644 index dabc3c6f57..0000000000 Binary files a/tests/data/output-0.5fps/frame_0006.jpg and /dev/null differ diff --git a/tests/data/output-0.5fps/frame_0007.jpg b/tests/data/output-0.5fps/frame_0007.jpg deleted file mode 100644 index fce21ebacb..0000000000 Binary files a/tests/data/output-0.5fps/frame_0007.jpg and /dev/null differ diff --git a/tests/data/output-0.5fps/frame_0008.jpg b/tests/data/output-0.5fps/frame_0008.jpg deleted file mode 100644 index 3bcd51f8a4..0000000000 Binary files a/tests/data/output-0.5fps/frame_0008.jpg and /dev/null differ diff --git a/tests/data/output-0.5fps/frame_0009.jpg b/tests/data/output-0.5fps/frame_0009.jpg deleted file mode 100644 index 165070366d..0000000000 Binary files a/tests/data/output-0.5fps/frame_0009.jpg and /dev/null differ diff --git a/tests/data/output-0.5fps/frame_0010.jpg b/tests/data/output-0.5fps/frame_0010.jpg deleted file mode 100644 index 37661ce8a3..0000000000 Binary files a/tests/data/output-0.5fps/frame_0010.jpg and /dev/null differ diff --git a/tests/data/output-0.5fps/frame_0011.jpg b/tests/data/output-0.5fps/frame_0011.jpg deleted file mode 100644 index 3ff1938304..0000000000 Binary files a/tests/data/output-0.5fps/frame_0011.jpg and /dev/null differ diff --git a/tests/data/output-0.5fps/frame_0012.jpg b/tests/data/output-0.5fps/frame_0012.jpg deleted file mode 100644 index ca53afa86b..0000000000 Binary files a/tests/data/output-0.5fps/frame_0012.jpg and /dev/null differ diff --git a/tests/data/output-0.5fps/frame_0013.jpg b/tests/data/output-0.5fps/frame_0013.jpg deleted file mode 100644 index 791dd151e1..0000000000 Binary files a/tests/data/output-0.5fps/frame_0013.jpg and /dev/null differ diff --git a/tests/data/output-0.5fps/frame_0014.jpg b/tests/data/output-0.5fps/frame_0014.jpg deleted file mode 100644 index 0e432b3dfb..0000000000 Binary files a/tests/data/output-0.5fps/frame_0014.jpg and /dev/null differ diff --git a/tests/data/output-0.5fps/frame_0015.jpg b/tests/data/output-0.5fps/frame_0015.jpg deleted file mode 100644 index 2b5997771f..0000000000 Binary files a/tests/data/output-0.5fps/frame_0015.jpg and /dev/null differ diff --git a/tests/data/output-0.5fps/frame_0016.jpg b/tests/data/output-0.5fps/frame_0016.jpg deleted file mode 100644 index d423061327..0000000000 Binary files a/tests/data/output-0.5fps/frame_0016.jpg and /dev/null differ diff --git a/tests/data/output-0.5fps/frame_0017.jpg b/tests/data/output-0.5fps/frame_0017.jpg deleted file mode 100644 index 4f8786e26a..0000000000 Binary files a/tests/data/output-0.5fps/frame_0017.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0000.jpg b/tests/data/output-2fps/frame_0000.jpg deleted file mode 100644 index 1a10eed0c5..0000000000 Binary files a/tests/data/output-2fps/frame_0000.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0001.jpg b/tests/data/output-2fps/frame_0001.jpg deleted file mode 100644 index c6d832a754..0000000000 Binary files a/tests/data/output-2fps/frame_0001.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0002.jpg b/tests/data/output-2fps/frame_0002.jpg deleted file mode 100644 index 43193e4585..0000000000 Binary files a/tests/data/output-2fps/frame_0002.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0003.jpg b/tests/data/output-2fps/frame_0003.jpg deleted file mode 100644 index 4679f686d7..0000000000 Binary files a/tests/data/output-2fps/frame_0003.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0004.jpg b/tests/data/output-2fps/frame_0004.jpg deleted file mode 100644 index 7e0a0e5a05..0000000000 Binary files a/tests/data/output-2fps/frame_0004.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0005.jpg b/tests/data/output-2fps/frame_0005.jpg deleted file mode 100644 index e43968e8c6..0000000000 Binary files a/tests/data/output-2fps/frame_0005.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0006.jpg b/tests/data/output-2fps/frame_0006.jpg deleted file mode 100644 index 62f7926562..0000000000 Binary files a/tests/data/output-2fps/frame_0006.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0007.jpg b/tests/data/output-2fps/frame_0007.jpg deleted file mode 100644 index 53c4ea99bc..0000000000 Binary files a/tests/data/output-2fps/frame_0007.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0008.jpg b/tests/data/output-2fps/frame_0008.jpg deleted file mode 100644 index 0035dda6b2..0000000000 Binary files a/tests/data/output-2fps/frame_0008.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0009.jpg b/tests/data/output-2fps/frame_0009.jpg deleted file mode 100644 index 144e6aa345..0000000000 Binary files a/tests/data/output-2fps/frame_0009.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0010.jpg b/tests/data/output-2fps/frame_0010.jpg deleted file mode 100644 index 8bf6485a7b..0000000000 Binary files a/tests/data/output-2fps/frame_0010.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0011.jpg b/tests/data/output-2fps/frame_0011.jpg deleted file mode 100644 index a2db503086..0000000000 Binary files a/tests/data/output-2fps/frame_0011.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0012.jpg b/tests/data/output-2fps/frame_0012.jpg deleted file mode 100644 index 4101db2573..0000000000 Binary files a/tests/data/output-2fps/frame_0012.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0013.jpg b/tests/data/output-2fps/frame_0013.jpg deleted file mode 100644 index a2d560ba69..0000000000 Binary files a/tests/data/output-2fps/frame_0013.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0014.jpg b/tests/data/output-2fps/frame_0014.jpg deleted file mode 100644 index 0be5d8682c..0000000000 Binary files a/tests/data/output-2fps/frame_0014.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0015.jpg b/tests/data/output-2fps/frame_0015.jpg deleted file mode 100644 index 8a9442f365..0000000000 Binary files a/tests/data/output-2fps/frame_0015.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0016.jpg b/tests/data/output-2fps/frame_0016.jpg deleted file mode 100644 index ef51ed1558..0000000000 Binary files a/tests/data/output-2fps/frame_0016.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0017.jpg b/tests/data/output-2fps/frame_0017.jpg deleted file mode 100644 index d40466b69f..0000000000 Binary files a/tests/data/output-2fps/frame_0017.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0018.jpg b/tests/data/output-2fps/frame_0018.jpg deleted file mode 100644 index 325721b37e..0000000000 Binary files a/tests/data/output-2fps/frame_0018.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0019.jpg b/tests/data/output-2fps/frame_0019.jpg deleted file mode 100644 index a6cadc0b0b..0000000000 Binary files a/tests/data/output-2fps/frame_0019.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0020.jpg b/tests/data/output-2fps/frame_0020.jpg deleted file mode 100644 index 2fc669d73c..0000000000 Binary files a/tests/data/output-2fps/frame_0020.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0021.jpg b/tests/data/output-2fps/frame_0021.jpg deleted file mode 100644 index 91b5c85e2e..0000000000 Binary files a/tests/data/output-2fps/frame_0021.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0022.jpg b/tests/data/output-2fps/frame_0022.jpg deleted file mode 100644 index 707fb59c19..0000000000 Binary files a/tests/data/output-2fps/frame_0022.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0023.jpg b/tests/data/output-2fps/frame_0023.jpg deleted file mode 100644 index 6f9c85a394..0000000000 Binary files a/tests/data/output-2fps/frame_0023.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0024.jpg b/tests/data/output-2fps/frame_0024.jpg deleted file mode 100644 index dabc3c6f57..0000000000 Binary files a/tests/data/output-2fps/frame_0024.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0025.jpg b/tests/data/output-2fps/frame_0025.jpg deleted file mode 100644 index cff338eb8e..0000000000 Binary files a/tests/data/output-2fps/frame_0025.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0026.jpg b/tests/data/output-2fps/frame_0026.jpg deleted file mode 100644 index 32a8401449..0000000000 Binary files a/tests/data/output-2fps/frame_0026.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0027.jpg b/tests/data/output-2fps/frame_0027.jpg deleted file mode 100644 index c523e9a5a1..0000000000 Binary files a/tests/data/output-2fps/frame_0027.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0028.jpg b/tests/data/output-2fps/frame_0028.jpg deleted file mode 100644 index fce21ebacb..0000000000 Binary files a/tests/data/output-2fps/frame_0028.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0029.jpg b/tests/data/output-2fps/frame_0029.jpg deleted file mode 100644 index c37bbddba4..0000000000 Binary files a/tests/data/output-2fps/frame_0029.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0030.jpg b/tests/data/output-2fps/frame_0030.jpg deleted file mode 100644 index 53e366245d..0000000000 Binary files a/tests/data/output-2fps/frame_0030.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0031.jpg b/tests/data/output-2fps/frame_0031.jpg deleted file mode 100644 index aa68f0948d..0000000000 Binary files a/tests/data/output-2fps/frame_0031.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0032.jpg b/tests/data/output-2fps/frame_0032.jpg deleted file mode 100644 index 3bcd51f8a4..0000000000 Binary files a/tests/data/output-2fps/frame_0032.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0033.jpg b/tests/data/output-2fps/frame_0033.jpg deleted file mode 100644 index 9b53531c5f..0000000000 Binary files a/tests/data/output-2fps/frame_0033.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0034.jpg b/tests/data/output-2fps/frame_0034.jpg deleted file mode 100644 index 920e7a1290..0000000000 Binary files a/tests/data/output-2fps/frame_0034.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0035.jpg b/tests/data/output-2fps/frame_0035.jpg deleted file mode 100644 index 672d8ec116..0000000000 Binary files a/tests/data/output-2fps/frame_0035.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0036.jpg b/tests/data/output-2fps/frame_0036.jpg deleted file mode 100644 index 165070366d..0000000000 Binary files a/tests/data/output-2fps/frame_0036.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0037.jpg b/tests/data/output-2fps/frame_0037.jpg deleted file mode 100644 index 390dd8f028..0000000000 Binary files a/tests/data/output-2fps/frame_0037.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0038.jpg b/tests/data/output-2fps/frame_0038.jpg deleted file mode 100644 index 38baee9771..0000000000 Binary files a/tests/data/output-2fps/frame_0038.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0039.jpg b/tests/data/output-2fps/frame_0039.jpg deleted file mode 100644 index 76c6b4518a..0000000000 Binary files a/tests/data/output-2fps/frame_0039.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0040.jpg b/tests/data/output-2fps/frame_0040.jpg deleted file mode 100644 index 37661ce8a3..0000000000 Binary files a/tests/data/output-2fps/frame_0040.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0041.jpg b/tests/data/output-2fps/frame_0041.jpg deleted file mode 100644 index 714681fbe4..0000000000 Binary files a/tests/data/output-2fps/frame_0041.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0042.jpg b/tests/data/output-2fps/frame_0042.jpg deleted file mode 100644 index 4521f8c8ad..0000000000 Binary files a/tests/data/output-2fps/frame_0042.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0043.jpg b/tests/data/output-2fps/frame_0043.jpg deleted file mode 100644 index 9402ab3c0f..0000000000 Binary files a/tests/data/output-2fps/frame_0043.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0044.jpg b/tests/data/output-2fps/frame_0044.jpg deleted file mode 100644 index 3ff1938304..0000000000 Binary files a/tests/data/output-2fps/frame_0044.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0045.jpg b/tests/data/output-2fps/frame_0045.jpg deleted file mode 100644 index 74ae32e7b2..0000000000 Binary files a/tests/data/output-2fps/frame_0045.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0046.jpg b/tests/data/output-2fps/frame_0046.jpg deleted file mode 100644 index c0cee10333..0000000000 Binary files a/tests/data/output-2fps/frame_0046.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0047.jpg b/tests/data/output-2fps/frame_0047.jpg deleted file mode 100644 index 12132c3352..0000000000 Binary files a/tests/data/output-2fps/frame_0047.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0048.jpg b/tests/data/output-2fps/frame_0048.jpg deleted file mode 100644 index ca53afa86b..0000000000 Binary files a/tests/data/output-2fps/frame_0048.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0049.jpg b/tests/data/output-2fps/frame_0049.jpg deleted file mode 100644 index 6dfd2961a1..0000000000 Binary files a/tests/data/output-2fps/frame_0049.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0050.jpg b/tests/data/output-2fps/frame_0050.jpg deleted file mode 100644 index a9ad1e80a5..0000000000 Binary files a/tests/data/output-2fps/frame_0050.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0051.jpg b/tests/data/output-2fps/frame_0051.jpg deleted file mode 100644 index 4b23359f77..0000000000 Binary files a/tests/data/output-2fps/frame_0051.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0052.jpg b/tests/data/output-2fps/frame_0052.jpg deleted file mode 100644 index 791dd151e1..0000000000 Binary files a/tests/data/output-2fps/frame_0052.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0053.jpg b/tests/data/output-2fps/frame_0053.jpg deleted file mode 100644 index ac206e1202..0000000000 Binary files a/tests/data/output-2fps/frame_0053.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0054.jpg b/tests/data/output-2fps/frame_0054.jpg deleted file mode 100644 index 5b63ae4378..0000000000 Binary files a/tests/data/output-2fps/frame_0054.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0055.jpg b/tests/data/output-2fps/frame_0055.jpg deleted file mode 100644 index 3ad9e61043..0000000000 Binary files a/tests/data/output-2fps/frame_0055.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0056.jpg b/tests/data/output-2fps/frame_0056.jpg deleted file mode 100644 index 0e432b3dfb..0000000000 Binary files a/tests/data/output-2fps/frame_0056.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0057.jpg b/tests/data/output-2fps/frame_0057.jpg deleted file mode 100644 index 66c66c5265..0000000000 Binary files a/tests/data/output-2fps/frame_0057.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0058.jpg b/tests/data/output-2fps/frame_0058.jpg deleted file mode 100644 index 3339c76e85..0000000000 Binary files a/tests/data/output-2fps/frame_0058.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0059.jpg b/tests/data/output-2fps/frame_0059.jpg deleted file mode 100644 index 50abfc29ea..0000000000 Binary files a/tests/data/output-2fps/frame_0059.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0060.jpg b/tests/data/output-2fps/frame_0060.jpg deleted file mode 100644 index 2b5997771f..0000000000 Binary files a/tests/data/output-2fps/frame_0060.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0061.jpg b/tests/data/output-2fps/frame_0061.jpg deleted file mode 100644 index 72d47f757e..0000000000 Binary files a/tests/data/output-2fps/frame_0061.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0062.jpg b/tests/data/output-2fps/frame_0062.jpg deleted file mode 100644 index 130ae25869..0000000000 Binary files a/tests/data/output-2fps/frame_0062.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0063.jpg b/tests/data/output-2fps/frame_0063.jpg deleted file mode 100644 index 1dd2b46105..0000000000 Binary files a/tests/data/output-2fps/frame_0063.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0064.jpg b/tests/data/output-2fps/frame_0064.jpg deleted file mode 100644 index d423061327..0000000000 Binary files a/tests/data/output-2fps/frame_0064.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0065.jpg b/tests/data/output-2fps/frame_0065.jpg deleted file mode 100644 index c51d99ef85..0000000000 Binary files a/tests/data/output-2fps/frame_0065.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0066.jpg b/tests/data/output-2fps/frame_0066.jpg deleted file mode 100644 index 3fc0e17015..0000000000 Binary files a/tests/data/output-2fps/frame_0066.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0067.jpg b/tests/data/output-2fps/frame_0067.jpg deleted file mode 100644 index 3dee35ec9f..0000000000 Binary files a/tests/data/output-2fps/frame_0067.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0068.jpg b/tests/data/output-2fps/frame_0068.jpg deleted file mode 100644 index 4f8786e26a..0000000000 Binary files a/tests/data/output-2fps/frame_0068.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0069.jpg b/tests/data/output-2fps/frame_0069.jpg deleted file mode 100644 index 23972dfd6a..0000000000 Binary files a/tests/data/output-2fps/frame_0069.jpg and /dev/null differ diff --git a/tests/data/output-2fps/frame_0070.jpg b/tests/data/output-2fps/frame_0070.jpg deleted file mode 100644 index 59d2a6da44..0000000000 Binary files a/tests/data/output-2fps/frame_0070.jpg and /dev/null differ diff --git a/tests/data/sparse/0/cameras.bin b/tests/data/sparse/0/cameras.bin deleted file mode 100644 index ec10b759a0..0000000000 Binary files a/tests/data/sparse/0/cameras.bin and /dev/null differ diff --git a/tests/data/sparse/0/images.bin.REMOVED.git-id b/tests/data/sparse/0/images.bin.REMOVED.git-id deleted file mode 100644 index 032880910a..0000000000 --- a/tests/data/sparse/0/images.bin.REMOVED.git-id +++ /dev/null @@ -1 +0,0 @@ -cc9db821c6ccb0c01c988ab735f1a69455ad350a \ No newline at end of file diff --git a/tests/data/sparse/project.ini b/tests/data/sparse/project.ini deleted file mode 100644 index 47cbbb6d84..0000000000 --- a/tests/data/sparse/project.ini +++ /dev/null @@ -1,218 +0,0 @@ -log_to_stderr=true -random_seed=0 -log_level=0 -database_path=./database.db -image_path=./output-2fps/ -[ImageReader] -single_camera=false -single_camera_per_folder=false -single_camera_per_image=false -existing_camera_id=-1 -default_focal_length_factor=1.2 -mask_path= -camera_model=SIMPLE_RADIAL -camera_params= -camera_mask_path= -[SiftExtraction] -use_gpu=true -estimate_affine_shape=true -upright=false -domain_size_pooling=false -num_threads=-1 -max_image_size=2400 -max_num_features=8192 -first_octave=-1 -num_octaves=4 -octave_resolution=3 -max_num_orientations=2 -dsp_num_scales=10 -peak_threshold=0.0066666666666666671 -edge_threshold=10 -dsp_min_scale=0.16666666666666666 -dsp_max_scale=3 -gpu_index=-1 -[SiftMatching] -use_gpu=true -cross_check=true -guided_matching=true -num_threads=-1 -max_num_matches=32768 -max_ratio=0.80000000000000004 -max_distance=0.69999999999999996 -gpu_index=-1 -[TwoViewGeometry] -multiple_models=false -compute_relative_pose=false -min_num_inliers=15 -max_num_trials=10000 -max_error=4 -confidence=0.999 -min_inlier_ratio=0.25 -[SequentialMatching] -quadratic_overlap=true -loop_detection=false -overlap=10 -loop_detection_period=10 -loop_detection_num_images=50 -loop_detection_num_nearest_neighbors=1 -loop_detection_num_checks=256 -loop_detection_num_images_after_verification=0 -loop_detection_max_num_features=-1 -vocab_tree_path= -[SpatialMatching] -ignore_z=true -max_num_neighbors=50 -max_distance=100 -[BundleAdjustment] -refine_focal_length=true -refine_principal_point=false -refine_extra_params=true -refine_extrinsics=true -use_gpu=true -max_num_iterations=100 -max_linear_solver_iterations=200 -min_num_images_gpu_solver=50 -min_num_residuals_for_cpu_multi_threading=50000 -max_num_images_direct_dense_cpu_solver=50 -max_num_images_direct_sparse_cpu_solver=1000 -max_num_images_direct_dense_gpu_solver=200 -max_num_images_direct_sparse_gpu_solver=4000 -function_tolerance=0 -gradient_tolerance=0.0001 -parameter_tolerance=0 -gpu_index=-1 -[Mapper] -ignore_watermarks=false -multiple_models=true -extract_colors=true -ba_refine_focal_length=true -ba_refine_principal_point=false -ba_refine_extra_params=true -ba_use_gpu=true -fix_existing_images=false -tri_ignore_two_view_tracks=true -min_num_matches=15 -max_num_models=50 -max_model_overlap=20 -min_model_size=10 -init_image_id1=-1 -init_image_id2=-1 -init_num_trials=200 -num_threads=-1 -ba_local_num_images=6 -ba_local_max_num_iterations=30 -ba_global_images_freq=500 -ba_global_points_freq=250000 -ba_global_max_num_iterations=75 -ba_global_max_refinements=5 -ba_local_max_refinements=3 -ba_min_num_residuals_for_cpu_multi_threading=50000 -snapshot_images_freq=0 -init_min_num_inliers=100 -init_max_reg_trials=2 -abs_pose_min_num_inliers=30 -max_reg_trials=3 -tri_max_transitivity=1 -tri_complete_max_transitivity=5 -tri_re_max_trials=1 -min_focal_length_ratio=0.10000000000000001 -max_focal_length_ratio=10 -max_extra_param=1.7976931348623157e+308 -ba_local_function_tolerance=0 -ba_global_images_ratio=1.1000000000000001 -ba_global_points_ratio=1.1000000000000001 -ba_global_function_tolerance=0 -ba_global_max_refinement_change=0.00050000000000000001 -ba_local_max_refinement_change=0.001 -init_max_error=4 -init_max_forward_motion=0.94999999999999996 -init_min_tri_angle=16 -abs_pose_max_error=12 -abs_pose_min_inlier_ratio=0.25 -filter_max_reproj_error=4 -filter_min_tri_angle=1.5 -local_ba_min_tri_angle=6 -tri_create_max_angle_error=2 -tri_continue_max_angle_error=2 -tri_merge_max_reproj_error=4 -tri_complete_max_reproj_error=4 -tri_re_max_angle_error=5 -tri_re_min_ratio=0.20000000000000001 -tri_min_angle=1.5 -ba_gpu_index=-1 -snapshot_path= -[PatchMatchStereo] -geom_consistency=true -filter=true -allow_missing_files=false -write_consistency_graph=false -max_image_size=2400 -window_radius=5 -window_step=1 -num_samples=15 -num_iterations=5 -filter_min_num_consistent=2 -depth_min=-1 -depth_max=-1 -sigma_spatial=-1 -sigma_color=0.20000000298023224 -ncc_sigma=0.60000002384185791 -min_triangulation_angle=1 -incident_angle_sigma=0.89999997615814209 -geom_consistency_regularizer=0.30000001192092896 -geom_consistency_max_cost=3 -filter_min_ncc=0.10000000149011612 -filter_min_triangulation_angle=3 -filter_geom_consistency_max_cost=1 -cache_size=32 -gpu_index=-1 -[StereoFusion] -use_cache=false -num_threads=-1 -max_image_size=2400 -min_num_pixels=5 -max_num_pixels=10000 -max_traversal_depth=100 -check_num_images=50 -max_reproj_error=2 -max_depth_error=0.0099999997764825821 -max_normal_error=10 -cache_size=32 -mask_path= -[Render] -adapt_refresh_rate=true -image_connections=false -min_track_len=3 -refresh_rate=1 -projection_type=0 -max_error=2 -[ExhaustiveMatching] -block_size=50 -[VocabTreeMatching] -num_images=100 -num_nearest_neighbors=5 -num_checks=256 -num_images_after_verification=0 -max_num_features=-1 -vocab_tree_path= -match_list_path= -[TransitiveMatching] -batch_size=1000 -num_iterations=3 -[ImagePairsMatching] -block_size=1225 -[PoissonMeshing] -depth=13 -num_threads=-1 -point_weight=1 -color=32 -trim=10 -[DelaunayMeshing] -num_threads=-1 -max_proj_dist=20 -max_depth_dist=0.050000000000000003 -visibility_sigma=3 -distance_sigma_factor=1 -quality_regularization=1 -max_side_length_factor=25 -max_side_length_percentile=95 diff --git a/tests/genesissim/stream_camera.py b/tests/genesissim/stream_camera.py new file mode 100644 index 0000000000..7f038a8b8f --- /dev/null +++ b/tests/genesissim/stream_camera.py @@ -0,0 +1,56 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dimos.simulation.genesis import GenesisSimulator, GenesisStream + + +def main(): + # Add multiple entities at once + entities = [ + {"type": "primitive", "params": {"shape": "plane"}}, + {"type": "mjcf", "path": "xml/franka_emika_panda/panda.xml"}, + ] + # Initialize simulator + sim = GenesisSimulator(headless=True, entities=entities) + + # You can also add entity individually + sim.add_entity("primitive", shape="box", size=[0.5, 0.5, 0.5], pos=[0, 1, 0.5]) + + # Create stream with custom settings + stream = GenesisStream( + simulator=sim, + width=1280, # Genesis default resolution + height=960, + fps=60, + camera_path="/camera", # Genesis uses simpler camera paths + annotator_type="rgb", # Can be 'rgb' or 'normals' + transport="tcp", + rtsp_url="rtsp://mediamtx:8554/stream", + ) + + # Start streaming + try: + stream.stream() + except KeyboardInterrupt: + print("\n[Stream] Received keyboard interrupt, stopping stream...") + finally: + try: + stream.cleanup() + finally: + sim.close() + + +if __name__ == "__main__": + main() diff --git a/tests/isaacsim/run-isaacsim-docker.sh b/tests/isaacsim/run-isaacsim-docker.sh new file mode 100644 index 0000000000..2019695960 --- /dev/null +++ b/tests/isaacsim/run-isaacsim-docker.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# Run Isaac Sim container with display and GPU support +sudo docker run --network rtsp_net --name isaac-sim --entrypoint bash -it --runtime=nvidia --gpus all -e "ACCEPT_EULA=Y" --rm \ + -e "PRIVACY_CONSENT=Y" \ + -v ~/docker/isaac-sim/cache/kit:/isaac-sim/kit/cache:rw \ + -v ~/docker/isaac-sim/cache/ov:/root/.cache/ov:rw \ + -v ~/docker/isaac-sim/cache/pip:/root/.cache/pip:rw \ + -v ~/docker/isaac-sim/cache/glcache:/root/.cache/nvidia/GLCache:rw \ + -v ~/docker/isaac-sim/cache/computecache:/root/.nv/ComputeCache:rw \ + -v ~/docker/isaac-sim/logs:/root/.nvidia-omniverse/logs:rw \ + -v ~/docker/isaac-sim/data:/root/.local/share/ov/data:rw \ + -v ~/docker/isaac-sim/documents:/root/Documents:rw \ + -v ~/dimos:/dimos:rw \ + nvcr.io/nvidia/isaac-sim:4.2.0 + +/isaac-sim/python.sh -m pip install -r /dimos/tests/isaacsim/requirements.txt +apt-get update +apt-get install -y ffmpeg +/isaac-sim/python.sh /dimos/tests/isaacsim/stream_camera.py diff --git a/tests/isaacsim/setup_ec2.sh b/tests/isaacsim/setup_ec2.sh new file mode 100644 index 0000000000..f9d33bb3cc --- /dev/null +++ b/tests/isaacsim/setup_ec2.sh @@ -0,0 +1,42 @@ +#!/bin/bash + +sudo apt-get update +sudo apt install build-essential -y +sudo apt-get install -y nvidia-driver-535 +sudo reboot +sudo apt install -y nvidia-cuda-toolkit +nvidia-smi + + +# Docker installation using the convenience script +curl -fsSL https://get.docker.com -o get-docker.sh +sudo sh get-docker.sh + +# Post-install steps for Docker +sudo groupadd docker +sudo usermod -aG docker $USER +newgrp docker + +#Verify Docker + +# Configure the repository +curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg \ + && curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list | \ + sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' | \ + sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list \ + && \ + sudo apt-get update + +# Install the NVIDIA Container Toolkit packages +sudo apt-get install -y nvidia-container-toolkit +sudo systemctl restart docker + +# Configure the container runtime +sudo nvidia-ctk runtime configure --runtime=docker +sudo systemctl restart docker + +# Verify NVIDIA Container Toolkit +sudo docker run --rm --runtime=nvidia --gpus all ubuntu nvidia-smi + +# Full isaac sim container +sudo docker pull nvcr.io/nvidia/isaac-sim:4.2.0 diff --git a/tests/isaacsim/setup_isaacsim_python.sh b/tests/isaacsim/setup_isaacsim_python.sh new file mode 100644 index 0000000000..27744482e4 --- /dev/null +++ b/tests/isaacsim/setup_isaacsim_python.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +sudo apt install python3.10-venv +python3.10 -m venv env_isaacsim +source env_isaacsim/bin/activate + +# Install pip packages +pip install isaacsim==4.2.0.2 --extra-index-url https://pypi.nvidia.com +pip install isaacsim-extscache-physics==4.2.0.2 +pip install isaacsim-extscache-kit==4.2.0.2 +pip install isaacsim-extscache-kit-sdk==4.2.0.2 --extra-index-url https://pypi.nvidia.com + +export OMNI_KIT_ACCEPT_EULA=YES diff --git a/tests/isaacsim/setup_ros.sh b/tests/isaacsim/setup_ros.sh new file mode 100644 index 0000000000..976487f299 --- /dev/null +++ b/tests/isaacsim/setup_ros.sh @@ -0,0 +1,47 @@ +#!/bin/bash + +# Add ROS 2 repository +sudo apt update && sudo apt install -y software-properties-common +sudo add-apt-repository universe -y +sudo apt update && sudo apt install curl -y +sudo curl -sSL https://raw.githubusercontent.com/ros/rosdistro/master/ros.key -o /usr/share/keyrings/ros-archive-keyring.gpg +echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/ros-archive-keyring.gpg] http://packages.ros.org/ros2/ubuntu $(. /etc/os-release && echo $UBUNTU_CODENAME) main" | sudo tee /etc/apt/sources.list.d/ros2.list > /dev/null + +# Update package lists +sudo apt update +sudo apt upgrade -y + +# Install ROS 2 Humble (latest LTS for Ubuntu 22.04) +sudo apt install -y ros-humble-desktop +sudo apt install -y ros-humble-ros-base +sudo apt install -y ros-dev-tools + +# Install additional ROS 2 packages +sudo apt install -y python3-rosdep +sudo apt install -y python3-colcon-common-extensions + +# Initialize rosdep +sudo rosdep init +rosdep update + +# Setup environment variables +echo "source /opt/ros/humble/setup.bash" >> ~/.bashrc +source ~/.bashrc + +# Install additional dependencies that might be useful +sudo apt install -y python3-pip +pip3 install --upgrade pip +pip3 install transforms3d numpy scipy +sudo apt install -y python3.10-venv + +# Create ROS 2 workspace +mkdir -p ~/ros2_ws/src +cd ~/ros2_ws +colcon build + +# Source the workspace +echo "source ~/ros2_ws/install/setup.bash" >> ~/.bashrc +source ~/.bashrc + +# Print success message +echo "ROS 2 Humble installation completed successfully!" diff --git a/tests/isaacsim/stream_camera.py b/tests/isaacsim/stream_camera.py new file mode 100644 index 0000000000..446f42cff3 --- /dev/null +++ b/tests/isaacsim/stream_camera.py @@ -0,0 +1,42 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from dimos.simulation.isaac import IsaacSimulator, IsaacStream + + +def main(): + # Initialize simulator + sim = IsaacSimulator(headless=True) + + # Create stream with custom settings + stream = IsaacStream( + simulator=sim, + width=1920, + height=1080, + fps=60, + camera_path="/World/alfred_parent_prim/alfred_base_descr/chest_cam_rgb_camera_frame/chest_cam", + annotator_type="rgb", + transport="tcp", + rtsp_url="rtsp://mediamtx:8554/stream", + usd_path=f"{os.getcwd()}/assets/TestSim3.usda", + ) + + # Start streaming + stream.stream() + + +if __name__ == "__main__": + main() diff --git a/tests/mockdata/costmap.pickle b/tests/mockdata/costmap.pickle new file mode 100644 index 0000000000..a29199e841 Binary files /dev/null and b/tests/mockdata/costmap.pickle differ diff --git a/tests/mockdata/vegas.pickle b/tests/mockdata/vegas.pickle new file mode 100644 index 0000000000..a7da5309c0 Binary files /dev/null and b/tests/mockdata/vegas.pickle differ diff --git a/tests/pygazebo_test.py b/tests/pygazebo_test.py deleted file mode 100644 index 116754f60f..0000000000 --- a/tests/pygazebo_test.py +++ /dev/null @@ -1,26 +0,0 @@ -import asyncio -import pygazebo -from pygazebo.msg.pose_pb2 import Pose -from pygazebo.msg.vector3d_pb2 import Vector3d -from pygazebo.msg.quaternion_pb2 import Quaternion - -async def publish_pose(): - manager = await pygazebo.connect() - publisher = await manager.advertise('/gazebo/default/pose/info', 'gazebo.msgs.Pose') - - pose = Pose() - pose.position.x = 1.0 # delta_x - pose.position.y = 0.0 # delta_y - pose.position.z = 0.0 - - pose.orientation.w = 1.0 - pose.orientation.x = 0.0 - pose.orientation.y = 0.0 - pose.orientation.z = 0.0 - - while True: - await publisher.publish(pose) - await asyncio.sleep(0.1) - -loop = asyncio.get_event_loop() -loop.run_until_complete(publish_pose()) diff --git a/tests/run_go2_ros.py b/tests/run_go2_ros.py new file mode 100644 index 0000000000..37ac4b139b --- /dev/null +++ b/tests/run_go2_ros.py @@ -0,0 +1,176 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import time + +from dimos.robot.unitree.unitree_go2 import UnitreeGo2, WebRTCConnectionMethod +from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl + + +def get_env_var(var_name, default=None, required=False): + """Get environment variable with validation.""" + value = os.getenv(var_name, default) + if value == "": + value = default + if required and not value: + raise ValueError(f"{var_name} environment variable is required") + return value + + +if __name__ == "__main__": + # Get configuration from environment variables + robot_ip = get_env_var("ROBOT_IP") + connection_method = get_env_var("CONNECTION_METHOD", "LocalSTA") + serial_number = get_env_var("SERIAL_NUMBER", None) + output_dir = get_env_var("ROS_OUTPUT_DIR", os.path.join(os.getcwd(), "assets/output/ros")) + + # Ensure output directory exists + os.makedirs(output_dir, exist_ok=True) + print(f"Ensuring output directory exists: {output_dir}") + + use_ros = True + use_webrtc = False + # Convert connection method string to enum + connection_method = getattr(WebRTCConnectionMethod, connection_method) + + print("Initializing UnitreeGo2...") + print("Configuration:") + print(f" IP: {robot_ip}") + print(f" Connection Method: {connection_method}") + print(f" Serial Number: {serial_number if serial_number else 'Not provided'}") + print(f" Output Directory: {output_dir}") + + if use_ros: + ros_control = UnitreeROSControl(node_name="unitree_go2", use_raw=True) + else: + ros_control = None + + robot = UnitreeGo2( + ip=robot_ip, + connection_method=connection_method, + serial_number=serial_number, + output_dir=output_dir, + ros_control=ros_control, + use_ros=use_ros, + use_webrtc=use_webrtc, + ) + time.sleep(5) + try: + # Start perception + print("\nStarting perception system...") + + # Get the processed stream + processed_stream = robot.get_ros_video_stream(fps=30) + + # Create frame counter for unique filenames + frame_count = 0 + + # Create a subscriber to handle the frames + def handle_frame(frame): + global frame_count + frame_count += 1 + + try: + # Save frame to output directory if desired for debugging frame streaming + # MAKE SURE TO CHANGE OUTPUT DIR depending on if running in ROS or local + # frame_path = os.path.join(output_dir, f"frame_{frame_count:04d}.jpg") + # success = cv2.imwrite(frame_path, frame) + # print(f"Frame #{frame_count} {'saved successfully' if success else 'failed to save'} to {frame_path}") + pass + + except Exception as e: + print(f"Error in handle_frame: {e}") + import traceback + + print(traceback.format_exc()) + + def handle_error(error): + print(f"Error in stream: {error}") + + def handle_completion(): + print("Stream completed") + + # Subscribe to the stream + print("Creating subscription...") + try: + subscription = processed_stream.subscribe( + on_next=handle_frame, + on_error=lambda e: print(f"Subscription error: {e}"), + on_completed=lambda: print("Subscription completed"), + ) + print("Subscription created successfully") + except Exception as e: + print(f"Error creating subscription: {e}") + + time.sleep(5) + + # First put the robot in a good starting state + print("Running recovery stand...") + robot.webrtc_req(api_id=1006) # RecoveryStand + + # Queue 20 WebRTC requests back-to-back + print("\n🤖 QUEUEING WEBRTC COMMANDS BACK-TO-BACK FOR TESTING UnitreeGo2🤖\n") + + # Dance 1 + robot.webrtc_req(api_id=1033) + print("Queued: WiggleHips (1033)") + + robot.reverse(distance=0.2, speed=0.5) + print("Queued: Reverse 0.5m at 0.5m/s") + + # Wiggle Hips + robot.webrtc_req(api_id=1033) + print("Queued: WiggleHips (1033)") + + robot.move(distance=0.2, speed=0.5) + print("Queued: Move forward 1.0m at 0.5m/s") + + robot.webrtc_req(api_id=1017) + print("Queued: Stretch (1017)") + + robot.move(distance=0.2, speed=0.5) + print("Queued: Move forward 1.0m at 0.5m/s") + + robot.webrtc_req(api_id=1017) + print("Queued: Stretch (1017)") + + robot.reverse(distance=0.2, speed=0.5) + print("Queued: Reverse 0.5m at 0.5m/s") + + robot.webrtc_req(api_id=1017) + print("Queued: Stretch (1017)") + robot.spin(degrees=-90.0, speed=45.0) + print("Queued: Spin right 90 degrees at 45 degrees/s") + + robot.spin(degrees=90.0, speed=45.0) + print("Queued: Spin left 90 degrees at 45 degrees/s") + + # To prevent termination + while True: + time.sleep(0.1) + + except KeyboardInterrupt: + print("\nStopping perception...") + if "subscription" in locals(): + subscription.dispose() + except Exception as e: + print(f"Error in main loop: {e}") + finally: + # Cleanup + print("Cleaning up resources...") + if "subscription" in locals(): + subscription.dispose() + del robot + print("Cleanup complete.") diff --git a/tests/simple_agent_test.py b/tests/simple_agent_test.py new file mode 100644 index 0000000000..fae4d5dcdc --- /dev/null +++ b/tests/simple_agent_test.py @@ -0,0 +1,38 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from dimos.agents.agent import OpenAIAgent +from dimos.robot.unitree.unitree_go2 import UnitreeGo2 +from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl +from dimos.robot.unitree.unitree_skills import MyUnitreeSkills + +# Initialize robot +robot = UnitreeGo2( + ip=os.getenv("ROBOT_IP"), ros_control=UnitreeROSControl(), skills=MyUnitreeSkills() +) + +# Initialize agent +agent = OpenAIAgent( + dev_name="UnitreeExecutionAgent", + input_video_stream=robot.get_ros_video_stream(), + skills=robot.get_skills(), + system_query="Wiggle when you see a person! Jump when you see a person waving!", +) + +try: + input("Press ESC to exit...") +except KeyboardInterrupt: + print("\nExiting...") diff --git a/tests/test_agent.py b/tests/test_agent.py index 73da481a4b..c13d59923b 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -1,6 +1,23 @@ -from dotenv import load_dotenv +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os +# ----- +from dotenv import load_dotenv + + # Sanity check for dotenv def test_dotenv(): print("test_dotenv:") @@ -8,9 +25,11 @@ def test_dotenv(): openai_api_key = os.getenv("OPENAI_API_KEY") print("\t\tOPENAI_API_KEY: ", openai_api_key) + # Sanity check for openai connection def test_openai_connection(): from openai import OpenAI + client = OpenAI() print("test_openai_connection:") response = client.chat.completions.create( @@ -19,7 +38,7 @@ def test_openai_connection(): { "role": "user", "content": [ - {"type": "text", "text": "What’s in this image?"}, + {"type": "text", "text": "What's in this image?"}, { "type": "image_url", "image_url": { @@ -33,5 +52,6 @@ def test_openai_connection(): ) print("\t\tOpenAI Response: ", response.choices[0]) + test_dotenv() test_openai_connection() diff --git a/tests/test_agent_alibaba.py b/tests/test_agent_alibaba.py new file mode 100644 index 0000000000..b93b302d95 --- /dev/null +++ b/tests/test_agent_alibaba.py @@ -0,0 +1,59 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from openai import OpenAI + +from dimos.agents.agent import OpenAIAgent +from dimos.agents.tokenizer.huggingface_tokenizer import HuggingFaceTokenizer +from dimos.robot.unitree.unitree_skills import MyUnitreeSkills +from dimos.stream.video_provider import VideoProvider +from dimos.utils.threadpool import get_scheduler + +# Initialize video stream +video_stream = VideoProvider( + dev_name="VideoProvider", + # video_source=f"{os.getcwd()}/assets/framecount.mp4", + video_source=f"{os.getcwd()}/assets/trimmed_video_office.mov", + pool_scheduler=get_scheduler(), +).capture_video_as_observable(realtime=False, fps=1) + +# Specify the OpenAI client for Alibaba +qwen_client = OpenAI( + base_url="https://dashscope-intl.aliyuncs.com/compatible-mode/v1", + api_key=os.getenv("ALIBABA_API_KEY"), +) + +# Initialize Unitree skills +myUnitreeSkills = MyUnitreeSkills() +myUnitreeSkills.initialize_skills() + +# Initialize agent +agent = OpenAIAgent( + dev_name="AlibabaExecutionAgent", + openai_client=qwen_client, + model_name="qwen2.5-vl-72b-instruct", + tokenizer=HuggingFaceTokenizer(model_name="Qwen/Qwen2.5-VL-72B-Instruct"), + max_output_tokens_per_request=8192, + input_video_stream=video_stream, + # system_query="Tell me the number in the video. Find me the center of the number spotted, and print the coordinates to the console using an appropriate function call. Then provide me a deep history of the number in question and its significance in history. Additionally, tell me what model and version of language model you are.", + system_query="Tell me about any objects seen. Print the coordinates for center of the objects seen to the console using an appropriate function call. Then provide me a deep history of the number in question and its significance in history. Additionally, tell me what model and version of language model you are.", + skills=myUnitreeSkills, +) + +try: + input("Press ESC to exit...") +except KeyboardInterrupt: + print("\nExiting...") diff --git a/tests/test_audio_agent.py b/tests/test_audio_agent.py new file mode 100644 index 0000000000..e9cd4baffe --- /dev/null +++ b/tests/test_audio_agent.py @@ -0,0 +1,39 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.agents.agent import OpenAIAgent +from dimos.stream.audio.pipelines import stt, tts +from dimos.stream.audio.utils import keepalive +from dimos.utils.threadpool import get_scheduler + + +def main(): + stt_node = stt() + + agent = OpenAIAgent( + dev_name="UnitreeExecutionAgent", + input_query_stream=stt_node.emit_text(), + system_query="You are a helpful robot named daneel that does my bidding", + pool_scheduler=get_scheduler(), + ) + + tts_node = tts() + tts_node.consume_text(agent.get_response_observable()) + + # Keep the main thread alive + keepalive() + + +if __name__ == "__main__": + main() diff --git a/tests/test_audio_robot_agent.py b/tests/test_audio_robot_agent.py new file mode 100644 index 0000000000..bac4e3e808 --- /dev/null +++ b/tests/test_audio_robot_agent.py @@ -0,0 +1,52 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from dimos.agents.agent import OpenAIAgent +from dimos.robot.unitree.unitree_go2 import UnitreeGo2 +from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl +from dimos.robot.unitree.unitree_skills import MyUnitreeSkills +from dimos.stream.audio.pipelines import stt, tts +from dimos.stream.audio.utils import keepalive +from dimos.utils.threadpool import get_scheduler + + +def main(): + stt_node = stt() + tts_node = tts() + + robot = UnitreeGo2( + ip=os.getenv("ROBOT_IP"), + ros_control=UnitreeROSControl(), + skills=MyUnitreeSkills(), + ) + + # Initialize agent with main thread pool scheduler + agent = OpenAIAgent( + dev_name="UnitreeExecutionAgent", + input_query_stream=stt_node.emit_text(), + system_query="You are a helpful robot named daneel that does my bidding", + pool_scheduler=get_scheduler(), + skills=robot.get_skills(), + ) + + tts_node.consume_text(agent.get_response_observable()) + + # Keep the main thread alive + keepalive() + + +if __name__ == "__main__": + main() diff --git a/tests/test_claude_agent_query.py b/tests/test_claude_agent_query.py new file mode 100644 index 0000000000..597a67ca37 --- /dev/null +++ b/tests/test_claude_agent_query.py @@ -0,0 +1,28 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dotenv import load_dotenv + +from dimos.agents.claude_agent import ClaudeAgent + +# Load API key from environment +load_dotenv() + +# Create a ClaudeAgent instance +agent = ClaudeAgent(dev_name="test_agent", query="What is the capital of France?") + +# Use the stream_query method to get a response +response = agent.run_observable_query("What is the capital of France?").run() + +print(f"Response from Claude Agent: {response}") diff --git a/tests/test_claude_agent_skills_query.py b/tests/test_claude_agent_skills_query.py new file mode 100644 index 0000000000..83d5749e4c --- /dev/null +++ b/tests/test_claude_agent_skills_query.py @@ -0,0 +1,134 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import threading + +from dotenv import load_dotenv +import reactivex as rx +import reactivex.operators as ops + +from dimos.agents.claude_agent import ClaudeAgent +from dimos.robot.unitree.unitree_go2 import UnitreeGo2 +from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl +from dimos.robot.unitree.unitree_skills import MyUnitreeSkills +from dimos.skills.kill_skill import KillSkill +from dimos.skills.navigation import BuildSemanticMap, GetPose, Navigate, NavigateToGoal +from dimos.skills.observe_stream import ObserveStream +from dimos.skills.speak import Speak +from dimos.skills.visual_navigation_skills import FollowHuman, NavigateToObject +from dimos.stream.audio.pipelines import stt, tts +from dimos.types.vector import Vector +from dimos.web.robot_web_interface import RobotWebInterface +from dimos.web.websocket_vis.server import WebsocketVis + +# Load API key from environment +load_dotenv() + +robot = UnitreeGo2( + ip=os.getenv("ROBOT_IP"), + ros_control=UnitreeROSControl(), + skills=MyUnitreeSkills(), + mock_connection=False, +) + +# Create a subject for agent responses +agent_response_subject = rx.subject.Subject() +agent_response_stream = agent_response_subject.pipe(ops.share()) +local_planner_viz_stream = robot.local_planner_viz_stream.pipe(ops.share()) + +streams = { + "unitree_video": robot.get_ros_video_stream(), + "local_planner_viz": local_planner_viz_stream, +} +text_streams = { + "agent_responses": agent_response_stream, +} + +web_interface = RobotWebInterface(port=5555, text_streams=text_streams, **streams) + +stt_node = stt() + +# Create a ClaudeAgent instance +agent = ClaudeAgent( + dev_name="test_agent", + input_query_stream=stt_node.emit_text(), + # input_query_stream=web_interface.query_stream, + skills=robot.get_skills(), + system_query="""You are an agent controlling a virtual robot. When given a query, respond by using the appropriate tool calls if needed to execute commands on the robot. + +IMPORTANT INSTRUCTIONS: +1. Each tool call must include the exact function name and appropriate parameters +2. If a function needs parameters like 'distance' or 'angle', be sure to include them +3. If you're unsure which tool to use, choose the most appropriate one based on the user's query +4. Parse the user's instructions carefully to determine correct parameter values + +Example: If the user asks to move forward 1 meter, call the Move function with distance=1""", + model_name="claude-3-7-sonnet-latest", + thinking_budget_tokens=2000, +) + +tts_node = tts() +# tts_node.consume_text(agent.get_response_observable()) + +robot_skills = robot.get_skills() +robot_skills.add(ObserveStream) +robot_skills.add(KillSkill) +robot_skills.add(Navigate) +robot_skills.add(BuildSemanticMap) +robot_skills.add(NavigateToObject) +robot_skills.add(FollowHuman) +robot_skills.add(GetPose) +robot_skills.add(Speak) +robot_skills.add(NavigateToGoal) +robot_skills.create_instance("ObserveStream", robot=robot, agent=agent) +robot_skills.create_instance("KillSkill", robot=robot, skill_library=robot_skills) +robot_skills.create_instance("Navigate", robot=robot) +robot_skills.create_instance("BuildSemanticMap", robot=robot) +robot_skills.create_instance("NavigateToObject", robot=robot) +robot_skills.create_instance("FollowHuman", robot=robot) +robot_skills.create_instance("GetPose", robot=robot) +robot_skills.create_instance("NavigateToGoal", robot=robot) +robot_skills.create_instance("Speak", tts_node=tts_node) + +# Subscribe to agent responses and send them to the subject +agent.get_response_observable().subscribe(lambda x: agent_response_subject.on_next(x)) + +print("ObserveStream and Kill skills registered and ready for use") +print("Created memory.txt file") + +websocket_vis = WebsocketVis() +websocket_vis.start() +websocket_vis.connect(robot.global_planner.vis_stream()) + + +def msg_handler(msgtype, data): + if msgtype == "click": + target = Vector(data["position"]) + try: + robot.global_planner.set_goal(target) + except Exception as e: + print(f"Error setting goal: {e}") + return + + +def threaded_msg_handler(msgtype, data): + thread = threading.Thread(target=msg_handler, args=(msgtype, data)) + thread.daemon = True + thread.start() + + +websocket_vis.msg_handler = threaded_msg_handler + +web_interface.run() diff --git a/tests/test_command_pose_unitree.py b/tests/test_command_pose_unitree.py new file mode 100644 index 0000000000..3546035201 --- /dev/null +++ b/tests/test_command_pose_unitree.py @@ -0,0 +1,82 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys + +# Add the parent directory to the Python path +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import os +import time + +from dimos.robot.unitree.unitree_go2 import UnitreeGo2 +from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl +from dimos.robot.unitree.unitree_skills import MyUnitreeSkills + +# Initialize robot +robot = UnitreeGo2( + ip=os.getenv("ROBOT_IP"), ros_control=UnitreeROSControl(), skills=MyUnitreeSkills() +) + + +# Helper function to send pose commands continuously for a duration +def send_pose_for_duration(roll, pitch, yaw, duration, hz=10): + """Send the same pose command repeatedly at specified frequency for the given duration""" + start_time = time.time() + while time.time() - start_time < duration: + robot.pose_command(roll=roll, pitch=pitch, yaw=yaw) + time.sleep(1.0 / hz) # Sleep to achieve the desired frequency + + +# Test pose commands + +# First, make sure the robot is in a stable position +print("Setting default pose...") +send_pose_for_duration(0.0, 0.0, 0.0, 1) + +# Test roll angle (lean left/right) +print("Testing roll angle - lean right...") +send_pose_for_duration(0.5, 0.0, 0.0, 1.5) # Lean right + +print("Testing roll angle - lean left...") +send_pose_for_duration(-0.5, 0.0, 0.0, 1.5) # Lean left + +# Test pitch angle (lean forward/backward) +print("Testing pitch angle - lean forward...") +send_pose_for_duration(0.0, 0.5, 0.0, 1.5) # Lean forward + +print("Testing pitch angle - lean backward...") +send_pose_for_duration(0.0, -0.5, 0.0, 1.5) # Lean backward + +# Test yaw angle (rotate body without moving feet) +print("Testing yaw angle - rotate clockwise...") +send_pose_for_duration(0.0, 0.0, 0.5, 1.5) # Rotate body clockwise + +print("Testing yaw angle - rotate counterclockwise...") +send_pose_for_duration(0.0, 0.0, -0.5, 1.5) # Rotate body counterclockwise + +# Reset to default pose +print("Resetting to default pose...") +send_pose_for_duration(0.0, 0.0, 0.0, 2) + +print("Pose command test completed") + +# Keep the program running (optional) +print("Press Ctrl+C to exit") +try: + while True: + time.sleep(1) +except KeyboardInterrupt: + print("Test terminated by user") diff --git a/tests/test_header.py b/tests/test_header.py new file mode 100644 index 0000000000..04ecfc9aaf --- /dev/null +++ b/tests/test_header.py @@ -0,0 +1,58 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test utilities for identifying caller information and path setup. + +This module provides functionality to determine which file called the current +script and sets up the Python path to include the parent directory, allowing +tests to import from the main application. +""" + +import inspect +import os +import sys + +# Add the parent directory of 'tests' to the Python path +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + + +def get_caller_info(): + """Identify the filename of the caller in the stack. + + Examines the call stack to find the first non-internal file that called + this module. Skips the current file and Python internal files. + + Returns: + str: The basename of the caller's filename, or "unknown" if not found. + """ + current_file = os.path.abspath(__file__) + + # Look through the call stack to find the first file that's not this one + for frame in inspect.stack()[1:]: + filename = os.path.abspath(frame.filename) + # Skip this file and Python internals + if filename != current_file and " 0: + best_score = max(grasp.get("score", 0.0) for grasp in grasps) + print(f" Best grasp score: {best_score:.3f}") + last_grasp_count = current_count + last_update_time = current_time + else: + # Show periodic "still waiting" message + if current_time - last_update_time > 10.0: + print(f" Still waiting for grasps... ({time.strftime('%H:%M:%S')})") + last_update_time = current_time + + time.sleep(1.0) # Check every second + + except Exception as e: + print(f" Error in grasp monitor: {e}") + time.sleep(2.0) + + +def main(): + """Test point cloud filtering with grasp generation using ManipulationPipeline.""" + print(" Testing point cloud filtering + grasp generation with ManipulationPipeline...") + + # Configuration + min_confidence = 0.6 + web_port = 5555 + grasp_server_url = "ws://18.224.39.74:8000/ws/grasp" + + try: + # Initialize ZED camera stream + zed_stream = ZEDCameraStream(resolution=sl.RESOLUTION.HD1080, fps=10) + + # Get camera intrinsics + camera_intrinsics_dict = zed_stream.get_camera_info() + camera_intrinsics = [ + camera_intrinsics_dict["fx"], + camera_intrinsics_dict["fy"], + camera_intrinsics_dict["cx"], + camera_intrinsics_dict["cy"], + ] + + # Create the concurrent manipulation pipeline WITH grasp generation + pipeline = ManipulationPipeline( + camera_intrinsics=camera_intrinsics, + min_confidence=min_confidence, + max_objects=10, + grasp_server_url=grasp_server_url, + enable_grasp_generation=True, # Enable grasp generation + ) + + # Create ZED stream + zed_frame_stream = zed_stream.create_stream().pipe(ops.share()) + + # Create concurrent processing streams + streams = pipeline.create_streams(zed_frame_stream) + detection_viz_stream = streams["detection_viz"] + pointcloud_viz_stream = streams["pointcloud_viz"] + grasps_stream = streams.get("grasps") # Get grasp stream if available + grasp_overlay_stream = streams.get("grasp_overlay") # Get grasp overlay stream if available + + except ImportError: + print("Error: ZED SDK not installed. Please install pyzed package.") + sys.exit(1) + except RuntimeError as e: + print(f"Error: Failed to open ZED camera: {e}") + sys.exit(1) + + try: + # Set up web interface with concurrent visualization streams + print("Initializing web interface...") + web_interface = RobotWebInterface( + port=web_port, + object_detection=detection_viz_stream, + pointcloud_stream=pointcloud_viz_stream, + grasp_overlay_stream=grasp_overlay_stream, + ) + + # Start grasp monitoring in background thread + grasp_monitor_thread = threading.Thread( + target=monitor_grasps, args=(pipeline,), daemon=True + ) + grasp_monitor_thread.start() + + print("\n Point Cloud + Grasp Generation Test Running:") + print(f" Web Interface: http://localhost:{web_port}") + print(" Object Detection View: RGB with bounding boxes") + print(" Point Cloud View: Depth with colored point clouds and 3D bounding boxes") + print(f" Confidence threshold: {min_confidence}") + print(f" Grasp server: {grasp_server_url}") + print(f" Available streams: {list(streams.keys())}") + print("\nPress Ctrl+C to stop the test\n") + + # Start web server (blocking call) + web_interface.run() + + except KeyboardInterrupt: + print("\nTest interrupted by user") + except Exception as e: + print(f"Error during test: {e}") + finally: + print("Cleaning up resources...") + if "zed_stream" in locals(): + zed_stream.cleanup() + if "pipeline" in locals(): + pipeline.cleanup() + print("Test completed") + + +if __name__ == "__main__": + main() diff --git a/tests/test_manipulation_perception_pipeline.py.py b/tests/test_manipulation_perception_pipeline.py.py new file mode 100644 index 0000000000..e5f08df28f --- /dev/null +++ b/tests/test_manipulation_perception_pipeline.py.py @@ -0,0 +1,165 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import threading +import time + +from pyzed import sl +from reactivex import operators as ops + +from dimos.manipulation.manip_aio_pipeline import ManipulationPipeline +from dimos.stream.stereo_camera_streams.zed import ZEDCameraStream +from dimos.web.robot_web_interface import RobotWebInterface + + +def monitor_grasps(pipeline): + """Monitor and print grasp updates in a separate thread.""" + print(" Grasp monitor started...") + + last_grasp_count = 0 + last_update_time = time.time() + + while True: + try: + # Get latest grasps using the getter function + grasps = pipeline.get_latest_grasps(timeout=0.5) + current_time = time.time() + + if grasps is not None: + current_count = len(grasps) + if current_count != last_grasp_count: + print(f" Grasps received: {current_count} (at {time.strftime('%H:%M:%S')})") + if current_count > 0: + best_score = max(grasp.get("score", 0.0) for grasp in grasps) + print(f" Best grasp score: {best_score:.3f}") + last_grasp_count = current_count + last_update_time = current_time + else: + # Show periodic "still waiting" message + if current_time - last_update_time > 10.0: + print(f" Still waiting for grasps... ({time.strftime('%H:%M:%S')})") + last_update_time = current_time + + time.sleep(1.0) # Check every second + + except Exception as e: + print(f" Error in grasp monitor: {e}") + time.sleep(2.0) + + +def main(): + """Test point cloud filtering with grasp generation using ManipulationPipeline.""" + print(" Testing point cloud filtering + grasp generation with ManipulationPipeline...") + + # Configuration + min_confidence = 0.6 + web_port = 5555 + grasp_server_url = "ws://18.224.39.74:8000/ws/grasp" + + try: + # Initialize ZED camera stream + zed_stream = ZEDCameraStream(resolution=sl.RESOLUTION.HD1080, fps=10) + + # Get camera intrinsics + camera_intrinsics_dict = zed_stream.get_camera_info() + camera_intrinsics = [ + camera_intrinsics_dict["fx"], + camera_intrinsics_dict["fy"], + camera_intrinsics_dict["cx"], + camera_intrinsics_dict["cy"], + ] + + # Create the concurrent manipulation pipeline WITH grasp generation + pipeline = ManipulationPipeline( + camera_intrinsics=camera_intrinsics, + min_confidence=min_confidence, + max_objects=10, + grasp_server_url=grasp_server_url, + enable_grasp_generation=True, # Enable grasp generation + ) + + # Create ZED stream + zed_frame_stream = zed_stream.create_stream().pipe(ops.share()) + + # Create concurrent processing streams + streams = pipeline.create_streams(zed_frame_stream) + detection_viz_stream = streams["detection_viz"] + pointcloud_viz_stream = streams["pointcloud_viz"] + grasps_stream = streams.get("grasps") # Get grasp stream if available + grasp_overlay_stream = streams.get("grasp_overlay") # Get grasp overlay stream if available + + except ImportError: + print("Error: ZED SDK not installed. Please install pyzed package.") + sys.exit(1) + except RuntimeError as e: + print(f"Error: Failed to open ZED camera: {e}") + sys.exit(1) + + try: + # Set up web interface with concurrent visualization streams + print("Initializing web interface...") + web_interface = RobotWebInterface( + port=web_port, + object_detection=detection_viz_stream, + pointcloud_stream=pointcloud_viz_stream, + grasp_overlay_stream=grasp_overlay_stream, + ) + + # Start grasp monitoring in background thread + grasp_monitor_thread = threading.Thread( + target=monitor_grasps, args=(pipeline,), daemon=True + ) + grasp_monitor_thread.start() + + print("\n Point Cloud + Grasp Generation Test Running:") + print(f" Web Interface: http://localhost:{web_port}") + print(" Object Detection View: RGB with bounding boxes") + print(" Point Cloud View: Depth with colored point clouds and 3D bounding boxes") + print(f" Confidence threshold: {min_confidence}") + print(f" Grasp server: {grasp_server_url}") + print(f" Available streams: {list(streams.keys())}") + print("\nPress Ctrl+C to stop the test\n") + + # Start web server (blocking call) + web_interface.run() + + except KeyboardInterrupt: + print("\nTest interrupted by user") + except Exception as e: + print(f"Error during test: {e}") + finally: + print("Cleaning up resources...") + if "zed_stream" in locals(): + zed_stream.cleanup() + if "pipeline" in locals(): + pipeline.cleanup() + print("Test completed") + + +if __name__ == "__main__": + main() diff --git a/tests/test_manipulation_pipeline_single_frame.py b/tests/test_manipulation_pipeline_single_frame.py new file mode 100644 index 0000000000..91e10aea33 --- /dev/null +++ b/tests/test_manipulation_pipeline_single_frame.py @@ -0,0 +1,246 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test manipulation processor with direct visualization and grasp data output.""" + +import argparse +import os + +import cv2 +import matplotlib +import numpy as np + +from dimos.utils.data import get_data + +# Try to use TkAgg backend for live display, fallback to Agg if not available +try: + matplotlib.use("TkAgg") +except: + try: + matplotlib.use("Qt5Agg") + except: + matplotlib.use("Agg") # Fallback to non-interactive +import matplotlib.pyplot as plt +import open3d as o3d + +from dimos.manipulation.manip_aio_processer import ManipulationProcessor +from dimos.perception.grasp_generation.utils import create_grasp_overlay, visualize_grasps_3d +from dimos.perception.pointcloud.utils import ( + combine_object_pointclouds, + load_camera_matrix_from_yaml, + visualize_clustered_point_clouds, + visualize_pcd, + visualize_voxel_grid, +) +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +def load_first_frame(data_dir: str): + """Load first RGB-D frame and camera intrinsics.""" + # Load images + color_img = cv2.imread(os.path.join(data_dir, "color", "00000.png")) + color_img = cv2.cvtColor(color_img, cv2.COLOR_BGR2RGB) + + depth_img = cv2.imread(os.path.join(data_dir, "depth", "00000.png"), cv2.IMREAD_ANYDEPTH) + if depth_img.dtype == np.uint16: + depth_img = depth_img.astype(np.float32) / 1000.0 + # Load intrinsics + camera_matrix = load_camera_matrix_from_yaml(os.path.join(data_dir, "color_camera_info.yaml")) + intrinsics = [ + camera_matrix[0, 0], + camera_matrix[1, 1], + camera_matrix[0, 2], + camera_matrix[1, 2], + ] + + return color_img, depth_img, intrinsics + + +def create_point_cloud(color_img, depth_img, intrinsics): + """Create Open3D point cloud.""" + fx, fy, cx, cy = intrinsics + height, width = depth_img.shape + + o3d_intrinsics = o3d.camera.PinholeCameraIntrinsic(width, height, fx, fy, cx, cy) + color_o3d = o3d.geometry.Image(color_img) + depth_o3d = o3d.geometry.Image((depth_img * 1000).astype(np.uint16)) + + rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth( + color_o3d, depth_o3d, depth_scale=1000.0, convert_rgb_to_intensity=False + ) + + return o3d.geometry.PointCloud.create_from_rgbd_image(rgbd, o3d_intrinsics) + + +def run_processor(color_img, depth_img, intrinsics, grasp_server_url=None): + """Run processor and collect results.""" + processor_kwargs = { + "camera_intrinsics": intrinsics, + "enable_grasp_generation": True, + "enable_segmentation": True, + } + + if grasp_server_url: + processor_kwargs["grasp_server_url"] = grasp_server_url + + processor = ManipulationProcessor(**processor_kwargs) + + # Process frame without grasp generation + results = processor.process_frame(color_img, depth_img, generate_grasps=False) + + # Run grasp generation separately + grasps = processor.run_grasp_generation(results["all_objects"], results["full_pointcloud"]) + results["grasps"] = grasps + results["grasp_overlay"] = create_grasp_overlay(color_img, grasps, intrinsics) + + processor.cleanup() + return results + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--data-dir", default=get_data("rgbd_frames")) + parser.add_argument("--wait-time", type=float, default=5.0) + parser.add_argument( + "--grasp-server-url", + default="ws://18.224.39.74:8000/ws/grasp", + help="WebSocket URL for Dimensional Grasp server", + ) + args = parser.parse_args() + + # Load data + color_img, depth_img, intrinsics = load_first_frame(args.data_dir) + logger.info(f"Loaded images: color {color_img.shape}, depth {depth_img.shape}") + + # Run processor + results = run_processor(color_img, depth_img, intrinsics, args.grasp_server_url) + + # Print results summary + print(f"Processing time: {results.get('processing_time', 0):.3f}s") + print(f"Detection objects: {len(results.get('detected_objects', []))}") + print(f"All objects processed: {len(results.get('all_objects', []))}") + + # Print grasp summary + grasp_data = results["grasps"] + total_grasps = len(grasp_data) if isinstance(grasp_data, list) else 0 + best_score = max(grasp["score"] for grasp in grasp_data) if grasp_data else 0 + + print(f"Grasps: {total_grasps} total (best score: {best_score:.3f})") + + # Create visualizations + plot_configs = [] + if results["detection_viz"] is not None: + plot_configs.append(("detection_viz", "Object Detection")) + if results["segmentation_viz"] is not None: + plot_configs.append(("segmentation_viz", "Semantic Segmentation")) + if results["pointcloud_viz"] is not None: + plot_configs.append(("pointcloud_viz", "All Objects Point Cloud")) + if results["detected_pointcloud_viz"] is not None: + plot_configs.append(("detected_pointcloud_viz", "Detection Objects Point Cloud")) + if results["misc_pointcloud_viz"] is not None: + plot_configs.append(("misc_pointcloud_viz", "Misc/Background Points")) + if results["grasp_overlay"] is not None: + plot_configs.append(("grasp_overlay", "Grasp Overlay")) + + # Create subplot layout + num_plots = len(plot_configs) + if num_plots <= 3: + fig, axes = plt.subplots(1, num_plots, figsize=(6 * num_plots, 5)) + else: + rows = 2 + cols = (num_plots + 1) // 2 + _fig, axes = plt.subplots(rows, cols, figsize=(6 * cols, 5 * rows)) + + if num_plots == 1: + axes = [axes] + elif num_plots > 2: + axes = axes.flatten() + + # Plot each result + for i, (key, title) in enumerate(plot_configs): + axes[i].imshow(results[key]) + axes[i].set_title(title) + axes[i].axis("off") + + # Hide unused subplots + if num_plots > 3: + for i in range(num_plots, len(axes)): + axes[i].axis("off") + + plt.tight_layout() + plt.savefig("manipulation_results.png", dpi=150, bbox_inches="tight") + plt.show(block=True) + plt.close() + + point_clouds = [obj["point_cloud"] for obj in results["all_objects"]] + colors = [obj["color"] for obj in results["all_objects"]] + combined_pcd = combine_object_pointclouds(point_clouds, colors) + + # 3D Grasp visualization + if grasp_data: + # Convert grasp format to visualization format for 3D display + viz_grasps = [] + for grasp in grasp_data: + translation = grasp.get("translation", [0, 0, 0]) + rotation_matrix = np.array(grasp.get("rotation_matrix", np.eye(3).tolist())) + score = grasp.get("score", 0.0) + width = grasp.get("width", 0.08) + + viz_grasp = { + "translation": translation, + "rotation_matrix": rotation_matrix, + "width": width, + "score": score, + } + viz_grasps.append(viz_grasp) + + # Use unified 3D visualization + visualize_grasps_3d(combined_pcd, viz_grasps) + + # Visualize full point cloud + visualize_pcd( + results["full_pointcloud"], + window_name="Full Scene Point Cloud", + point_size=2.0, + show_coordinate_frame=True, + ) + + # Visualize all objects point cloud + visualize_pcd( + combined_pcd, + window_name="All Objects Point Cloud", + point_size=3.0, + show_coordinate_frame=True, + ) + + # Visualize misc clusters + visualize_clustered_point_clouds( + results["misc_clusters"], + window_name="Misc/Background Clusters (DBSCAN)", + point_size=3.0, + show_coordinate_frame=True, + ) + + # Visualize voxel grid + visualize_voxel_grid( + results["misc_voxel_grid"], + window_name="Misc/Background Voxel Grid", + show_coordinate_frame=True, + ) + + +if __name__ == "__main__": + main() diff --git a/tests/test_manipulation_pipeline_single_frame_lcm.py b/tests/test_manipulation_pipeline_single_frame_lcm.py new file mode 100644 index 0000000000..14ddc5e119 --- /dev/null +++ b/tests/test_manipulation_pipeline_single_frame_lcm.py @@ -0,0 +1,419 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test manipulation processor with LCM topic subscription.""" + +import argparse +import pickle +import threading + +import cv2 +import matplotlib +import numpy as np + +# Try to use TkAgg backend for live display, fallback to Agg if not available +try: + matplotlib.use("TkAgg") +except: + try: + matplotlib.use("Qt5Agg") + except: + matplotlib.use("Agg") # Fallback to non-interactive + +# LCM imports +import lcm +from lcm_msgs.sensor_msgs import CameraInfo as LCMCameraInfo, Image as LCMImage +import open3d as o3d + +from dimos.manipulation.manip_aio_processer import ManipulationProcessor +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class LCMDataCollector: + """Collects one message from each required LCM topic.""" + + def __init__(self, lcm_url: str = "udpm://239.255.76.67:7667?ttl=1"): + self.lcm = lcm.LCM(lcm_url) + + # Data storage + self.rgb_data: np.ndarray | None = None + self.depth_data: np.ndarray | None = None + self.camera_intrinsics: list[float] | None = None + + # Synchronization + self.data_lock = threading.Lock() + self.data_ready_event = threading.Event() + + # Flags to track received messages + self.rgb_received = False + self.depth_received = False + self.camera_info_received = False + + # Subscribe to topics + self.lcm.subscribe("head_cam_rgb#sensor_msgs.Image", self._handle_rgb_message) + self.lcm.subscribe("head_cam_depth#sensor_msgs.Image", self._handle_depth_message) + self.lcm.subscribe("head_cam_info#sensor_msgs.CameraInfo", self._handle_camera_info_message) + + logger.info("LCM Data Collector initialized") + logger.info("Subscribed to topics:") + logger.info(" - head_cam_rgb#sensor_msgs.Image") + logger.info(" - head_cam_depth#sensor_msgs.Image") + logger.info(" - head_cam_info#sensor_msgs.CameraInfo") + + def _handle_rgb_message(self, channel: str, data: bytes): + """Handle RGB image message.""" + if self.rgb_received: + return # Already got one, ignore subsequent messages + + try: + msg = LCMImage.decode(data) + + # Convert message data to numpy array + if msg.encoding == "rgb8": + # RGB8 format: 3 bytes per pixel + rgb_array = np.frombuffer(msg.data[: msg.data_length], dtype=np.uint8) + rgb_image = rgb_array.reshape((msg.height, msg.width, 3)) + + with self.data_lock: + self.rgb_data = rgb_image + self.rgb_received = True + logger.info( + f"RGB message received: {msg.width}x{msg.height}, encoding: {msg.encoding}" + ) + self._check_all_data_received() + + else: + logger.warning(f"Unsupported RGB encoding: {msg.encoding}") + + except Exception as e: + logger.error(f"Error processing RGB message: {e}") + + def _handle_depth_message(self, channel: str, data: bytes): + """Handle depth image message.""" + if self.depth_received: + return # Already got one, ignore subsequent messages + + try: + msg = LCMImage.decode(data) + + # Convert message data to numpy array + if msg.encoding == "32FC1": + # 32FC1 format: 4 bytes (float32) per pixel + depth_array = np.frombuffer(msg.data[: msg.data_length], dtype=np.float32) + depth_image = depth_array.reshape((msg.height, msg.width)) + + with self.data_lock: + self.depth_data = depth_image + self.depth_received = True + logger.info( + f"Depth message received: {msg.width}x{msg.height}, encoding: {msg.encoding}" + ) + logger.info( + f"Depth range: {depth_image.min():.3f} - {depth_image.max():.3f} meters" + ) + self._check_all_data_received() + + else: + logger.warning(f"Unsupported depth encoding: {msg.encoding}") + + except Exception as e: + logger.error(f"Error processing depth message: {e}") + + def _handle_camera_info_message(self, channel: str, data: bytes): + """Handle camera info message.""" + if self.camera_info_received: + return # Already got one, ignore subsequent messages + + try: + msg = LCMCameraInfo.decode(data) + + # Extract intrinsics from K matrix: [fx, 0, cx, 0, fy, cy, 0, 0, 1] + K = msg.K + fx = K[0] # K[0,0] + fy = K[4] # K[1,1] + cx = K[2] # K[0,2] + cy = K[5] # K[1,2] + + intrinsics = [fx, fy, cx, cy] + + with self.data_lock: + self.camera_intrinsics = intrinsics + self.camera_info_received = True + logger.info(f"Camera info received: {msg.width}x{msg.height}") + logger.info(f"Intrinsics: fx={fx:.1f}, fy={fy:.1f}, cx={cx:.1f}, cy={cy:.1f}") + self._check_all_data_received() + + except Exception as e: + logger.error(f"Error processing camera info message: {e}") + + def _check_all_data_received(self): + """Check if all required data has been received.""" + if self.rgb_received and self.depth_received and self.camera_info_received: + logger.info("✅ All required data received!") + self.data_ready_event.set() + + def wait_for_data(self, timeout: float = 30.0) -> bool: + """Wait for all data to be received.""" + logger.info("Waiting for RGB, depth, and camera info messages...") + + # Start LCM handling in a separate thread + lcm_thread = threading.Thread(target=self._lcm_handle_loop, daemon=True) + lcm_thread.start() + + # Wait for data with timeout + return self.data_ready_event.wait(timeout) + + def _lcm_handle_loop(self): + """LCM message handling loop.""" + try: + while not self.data_ready_event.is_set(): + self.lcm.handle_timeout(100) # 100ms timeout + except Exception as e: + logger.error(f"Error in LCM handling loop: {e}") + + def get_data(self): + """Get the collected data.""" + with self.data_lock: + return self.rgb_data, self.depth_data, self.camera_intrinsics + + +def create_point_cloud(color_img, depth_img, intrinsics): + """Create Open3D point cloud.""" + fx, fy, cx, cy = intrinsics + height, width = depth_img.shape + + o3d_intrinsics = o3d.camera.PinholeCameraIntrinsic(width, height, fx, fy, cx, cy) + color_o3d = o3d.geometry.Image(color_img) + depth_o3d = o3d.geometry.Image((depth_img * 1000).astype(np.uint16)) + + rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth( + color_o3d, depth_o3d, depth_scale=1000.0, convert_rgb_to_intensity=False + ) + + return o3d.geometry.PointCloud.create_from_rgbd_image(rgbd, o3d_intrinsics) + + +def run_processor(color_img, depth_img, intrinsics): + """Run processor and collect results.""" + # Create processor + processor = ManipulationProcessor( + camera_intrinsics=intrinsics, + grasp_server_url="ws://18.224.39.74:8000/ws/grasp", + enable_grasp_generation=False, + enable_segmentation=True, + ) + + # Process single frame directly + results = processor.process_frame(color_img, depth_img) + + # Debug: print available results + print(f"Available results: {list(results.keys())}") + + processor.cleanup() + + return results + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lcm-url", default="udpm://239.255.76.67:7667?ttl=1", help="LCM URL for subscription" + ) + parser.add_argument( + "--timeout", type=float, default=30.0, help="Timeout in seconds to wait for messages" + ) + parser.add_argument( + "--save-images", action="store_true", help="Save received RGB and depth images to files" + ) + args = parser.parse_args() + + # Create data collector + collector = LCMDataCollector(args.lcm_url) + + # Wait for data + if not collector.wait_for_data(args.timeout): + logger.error(f"Timeout waiting for data after {args.timeout} seconds") + logger.error("Make sure Unity is running and publishing to the LCM topics") + return + + # Get the collected data + color_img, depth_img, intrinsics = collector.get_data() + + logger.info(f"Loaded images: color {color_img.shape}, depth {depth_img.shape}") + logger.info(f"Intrinsics: {intrinsics}") + + # Save images if requested + if args.save_images: + try: + cv2.imwrite("received_rgb.png", cv2.cvtColor(color_img, cv2.COLOR_RGB2BGR)) + # Save depth as 16-bit for visualization + depth_viz = (np.clip(depth_img * 1000, 0, 65535)).astype(np.uint16) + cv2.imwrite("received_depth.png", depth_viz) + logger.info("Saved received_rgb.png and received_depth.png") + except Exception as e: + logger.warning(f"Failed to save images: {e}") + + # Run processor + results = run_processor(color_img, depth_img, intrinsics) + + # Debug: Print what we received + print("\n✅ Processor Results:") + print(f" Available results: {list(results.keys())}") + print(f" Processing time: {results.get('processing_time', 0):.3f}s") + + # Show timing breakdown if available + if "timing_breakdown" in results: + breakdown = results["timing_breakdown"] + print(" Timing breakdown:") + print(f" - Detection: {breakdown.get('detection', 0):.3f}s") + print(f" - Segmentation: {breakdown.get('segmentation', 0):.3f}s") + print(f" - Point cloud: {breakdown.get('pointcloud', 0):.3f}s") + print(f" - Misc extraction: {breakdown.get('misc_extraction', 0):.3f}s") + + # Print object information + detected_count = len(results.get("detected_objects", [])) + all_count = len(results.get("all_objects", [])) + + print(f" Detection objects: {detected_count}") + print(f" All objects processed: {all_count}") + + # Print misc clusters information + if results.get("misc_clusters"): + cluster_count = len(results["misc_clusters"]) + total_misc_points = sum( + len(np.asarray(cluster.points)) for cluster in results["misc_clusters"] + ) + print(f" Misc clusters: {cluster_count} clusters with {total_misc_points} total points") + else: + print(" Misc clusters: None") + + # Print grasp summary + if results.get("grasps"): + total_grasps = 0 + best_score = 0 + for grasp in results["grasps"]: + score = grasp.get("score", 0) + if score > best_score: + best_score = score + total_grasps += 1 + print(f" Grasps generated: {total_grasps} (best score: {best_score:.3f})") + else: + print(" Grasps: None generated") + + # Save results to pickle file + pickle_path = "manipulation_results.pkl" + print(f"\nSaving results to pickle file: {pickle_path}") + + def serialize_point_cloud(pcd): + """Convert Open3D PointCloud to serializable format.""" + if pcd is None: + return None + data = { + "points": np.asarray(pcd.points).tolist() if hasattr(pcd, "points") else [], + "colors": np.asarray(pcd.colors).tolist() + if hasattr(pcd, "colors") and pcd.colors + else [], + } + return data + + def serialize_voxel_grid(voxel_grid): + """Convert Open3D VoxelGrid to serializable format.""" + if voxel_grid is None: + return None + + # Extract voxel data + voxels = voxel_grid.get_voxels() + data = { + "voxel_size": voxel_grid.voxel_size, + "origin": np.asarray(voxel_grid.origin).tolist(), + "voxels": [ + ( + v.grid_index[0], + v.grid_index[1], + v.grid_index[2], + v.color[0], + v.color[1], + v.color[2], + ) + for v in voxels + ], + } + return data + + # Create a copy of results with non-picklable objects converted + pickle_data = { + "color_img": color_img, + "depth_img": depth_img, + "intrinsics": intrinsics, + "results": {}, + } + + # Convert and store all results, properly handling Open3D objects + for key, value in results.items(): + if key.endswith("_viz") or key in [ + "processing_time", + "timing_breakdown", + "detection2d_objects", + "segmentation2d_objects", + ]: + # These are already serializable + pickle_data["results"][key] = value + elif key == "full_pointcloud": + # Serialize PointCloud object + pickle_data["results"][key] = serialize_point_cloud(value) + print(f"Serialized {key}") + elif key == "misc_voxel_grid": + # Serialize VoxelGrid object + pickle_data["results"][key] = serialize_voxel_grid(value) + print(f"Serialized {key}") + elif key == "misc_clusters": + # List of PointCloud objects + if value: + serialized_clusters = [serialize_point_cloud(cluster) for cluster in value] + pickle_data["results"][key] = serialized_clusters + print(f"Serialized {key} ({len(serialized_clusters)} clusters)") + elif key == "detected_objects" or key == "all_objects": + # Objects with PointCloud attributes + serialized_objects = [] + for obj in value: + obj_dict = {k: v for k, v in obj.items() if k != "point_cloud"} + if "point_cloud" in obj: + obj_dict["point_cloud"] = serialize_point_cloud(obj.get("point_cloud")) + serialized_objects.append(obj_dict) + pickle_data["results"][key] = serialized_objects + print(f"Serialized {key} ({len(serialized_objects)} objects)") + else: + try: + # Try to pickle as is + pickle_data["results"][key] = value + print(f"Preserved {key} as is") + except (TypeError, ValueError): + print(f"Warning: Could not serialize {key}, skipping") + + with open(pickle_path, "wb") as f: + pickle.dump(pickle_data, f) + + print("Results saved successfully with all 3D data serialized!") + print(f"Pickled data keys: {list(pickle_data['results'].keys())}") + + # Visualization code has been moved to visualization_script.py + # The results have been pickled and can be loaded from there + print("\nVisualization code has been moved to visualization_script.py") + print("Run 'python visualization_script.py' to visualize the results") + + +if __name__ == "__main__": + main() diff --git a/tests/test_move_vel_unitree.py b/tests/test_move_vel_unitree.py new file mode 100644 index 0000000000..c0bd416853 --- /dev/null +++ b/tests/test_move_vel_unitree.py @@ -0,0 +1,31 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import time + +from dimos.robot.unitree.unitree_go2 import UnitreeGo2 +from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl +from dimos.robot.unitree.unitree_skills import MyUnitreeSkills + +# Initialize robot +robot = UnitreeGo2( + ip=os.getenv("ROBOT_IP"), ros_control=UnitreeROSControl(), skills=MyUnitreeSkills() +) + +# Move the robot forward +robot.move_vel(x=0.5, y=0, yaw=0, duration=5) + +while True: + time.sleep(1) diff --git a/tests/test_navigate_to_object_robot.py b/tests/test_navigate_to_object_robot.py new file mode 100644 index 0000000000..6ec50714a6 --- /dev/null +++ b/tests/test_navigate_to_object_robot.py @@ -0,0 +1,138 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import threading +import time + +from reactivex import operators as RxOps + +from dimos.robot.unitree.unitree_go2 import UnitreeGo2 +from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl +from dimos.robot.unitree.unitree_skills import MyUnitreeSkills +from dimos.skills.navigation import Navigate +from dimos.utils.logging_config import setup_logger +from dimos.web.robot_web_interface import RobotWebInterface + +logger = setup_logger() + + +def parse_args(): + parser = argparse.ArgumentParser(description="Navigate to an object using Qwen vision.") + parser.add_argument( + "--object", + type=str, + default="chair", + help="Name of the object to navigate to (default: chair)", + ) + parser.add_argument( + "--distance", + type=float, + default=1.0, + help="Desired distance to maintain from object in meters (default: 0.8)", + ) + parser.add_argument( + "--timeout", + type=float, + default=60.0, + help="Maximum navigation time in seconds (default: 30.0)", + ) + return parser.parse_args() + + +def main(): + # Get command line arguments + args = parse_args() + object_name = args.object # Object to navigate to + distance = args.distance # Desired distance to object + timeout = args.timeout # Maximum navigation time + + print(f"Initializing Unitree Go2 robot for navigating to a {object_name}...") + + # Initialize the robot with ROS control and skills + robot = UnitreeGo2( + ip=os.getenv("ROBOT_IP"), + ros_control=UnitreeROSControl(), + skills=MyUnitreeSkills(), + ) + + # Add and create instance of NavigateToObject skill + robot_skills = robot.get_skills() + robot_skills.add(Navigate) + robot_skills.create_instance("Navigate", robot=robot) + + # Set up tracking and visualization streams + object_tracking_stream = robot.object_tracking_stream + viz_stream = object_tracking_stream.pipe( + RxOps.share(), + RxOps.map(lambda x: x["viz_frame"] if x is not None else None), + RxOps.filter(lambda x: x is not None), + ) + + # The local planner visualization stream is created during robot initialization + local_planner_stream = robot.local_planner_viz_stream + + local_planner_stream = local_planner_stream.pipe( + RxOps.share(), + RxOps.map(lambda x: x if x is not None else None), + RxOps.filter(lambda x: x is not None), + ) + + try: + # Set up web interface + logger.info("Initializing web interface") + streams = { + # "robot_video": video_stream, + "object_tracking": viz_stream, + "local_planner": local_planner_stream, + } + + web_interface = RobotWebInterface(port=5555, **streams) + + # Wait for camera and tracking to initialize + print("Waiting for camera and tracking to initialize...") + time.sleep(3) + + def navigate_to_object(): + try: + result = robot_skills.call( + "Navigate", robot=robot, query=object_name, timeout=timeout + ) + print(f"Navigation result: {result}") + except Exception as e: + print(f"Error during navigation: {e}") + + navigate_thread = threading.Thread(target=navigate_to_object, daemon=True) + navigate_thread.start() + + print( + f"Navigating to {object_name} with desired distance {distance}m and timeout {timeout}s..." + ) + print("Web interface available at http://localhost:5555") + + # Start web server (blocking call) + web_interface.run() + + except KeyboardInterrupt: + print("\nInterrupted by user") + except Exception as e: + print(f"Error during navigation test: {e}") + finally: + print("Test completed") + robot.cleanup() + + +if __name__ == "__main__": + main() diff --git a/tests/test_navigation_skills.py b/tests/test_navigation_skills.py new file mode 100644 index 0000000000..eedb60c2f8 --- /dev/null +++ b/tests/test_navigation_skills.py @@ -0,0 +1,265 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Simple test script for semantic / spatial memory skills. + +This script is a simplified version that focuses only on making the workflow work. + +Usage: + # Build and query in one run: + python simple_navigation_test.py --query "kitchen" + + # Skip build and just query: + python simple_navigation_test.py --skip-build --query "kitchen" +""" + +import argparse +import os +import threading +import time + +from reactivex import operators as RxOps + +from dimos.robot.unitree.unitree_go2 import UnitreeGo2 +from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl +from dimos.robot.unitree.unitree_skills import MyUnitreeSkills +from dimos.skills.navigation import BuildSemanticMap, Navigate +from dimos.utils.logging_config import setup_logger +from dimos.web.robot_web_interface import RobotWebInterface + +# Setup logging +logger = setup_logger() + + +def parse_args(): + spatial_memory_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), "../assets/spatial_memory_vegas") + ) + + parser = argparse.ArgumentParser(description="Simple test for semantic map skills.") + parser.add_argument( + "--skip-build", + action="store_true", + help="Skip building the map and run navigation with existing semantic and visual memory", + ) + parser.add_argument( + "--query", type=str, default="kitchen", help="Text query for navigation (default: kitchen)" + ) + parser.add_argument( + "--db-path", + type=str, + default=os.path.join(spatial_memory_dir, "chromadb_data"), + help="Path to ChromaDB database", + ) + parser.add_argument("--justgo", type=str, help="Globally navigate to location") + parser.add_argument( + "--visual-memory-dir", + type=str, + default=spatial_memory_dir, + help="Directory for visual memory", + ) + parser.add_argument( + "--visual-memory-file", + type=str, + default="visual_memory.pkl", + help="Filename for visual memory", + ) + parser.add_argument( + "--port", type=int, default=5555, help="Port for web visualization interface" + ) + return parser.parse_args() + + +def build_map(robot, args): + logger.info("Starting to build spatial memory...") + + # Create the BuildSemanticMap skill + build_skill = BuildSemanticMap( + robot=robot, + db_path=args.db_path, + visual_memory_dir=args.visual_memory_dir, + visual_memory_file=args.visual_memory_file, + ) + + # Start the skill + build_skill() + + # Wait for user to press Ctrl+C + logger.info("Press Ctrl+C to stop mapping and proceed to navigation...") + + try: + while True: + time.sleep(0.5) + except KeyboardInterrupt: + logger.info("Stopping map building...") + + # Stop the skill + build_skill.stop() + logger.info("Map building complete.") + + +def query_map(robot, args): + logger.info(f"Querying spatial memory for: '{args.query}'") + + # Create the Navigate skill + nav_skill = Navigate( + robot=robot, + query=args.query, + db_path=args.db_path, + visual_memory_path=os.path.join(args.visual_memory_dir, args.visual_memory_file), + ) + + # Query the map + result = nav_skill() + + # Display the result + if isinstance(result, dict) and result.get("success", False): + position = result.get("position", (0, 0, 0)) + similarity = result.get("similarity", 0) + logger.info(f"Found '{args.query}' at position: {position}") + logger.info(f"Similarity score: {similarity:.4f}") + return position + + else: + logger.error(f"Navigation query failed: {result}") + return False + + +def setup_visualization(robot, port=5555): + """Set up visualization streams for the web interface""" + logger.info(f"Setting up visualization streams on port {port}") + + # Get video stream from robot + video_stream = robot.video_stream_ros.pipe( + RxOps.share(), + RxOps.map(lambda frame: frame), + RxOps.filter(lambda frame: frame is not None), + ) + + # Get local planner visualization stream + local_planner_stream = robot.local_planner_viz_stream.pipe( + RxOps.share(), + RxOps.map(lambda frame: frame), + RxOps.filter(lambda frame: frame is not None), + ) + + # Create web interface with streams + streams = {"robot_video": video_stream, "local_planner": local_planner_stream} + + web_interface = RobotWebInterface(port=port, **streams) + + return web_interface + + +def run_navigation(robot, target): + """Run navigation in a separate thread""" + logger.info(f"Starting navigation to target: {target}") + return robot.global_planner.set_goal(target) + + +def main(): + args = parse_args() + + # Ensure directories exist + if not args.justgo: + os.makedirs(args.db_path, exist_ok=True) + os.makedirs(args.visual_memory_dir, exist_ok=True) + + # Initialize robot + logger.info("Initializing robot...") + ros_control = UnitreeROSControl(node_name="simple_nav_test", mock_connection=False) + robot = UnitreeGo2(ros_control=ros_control, ip=os.getenv("ROBOT_IP"), skills=MyUnitreeSkills()) + + # Set up visualization + web_interface = None + try: + # Set up visualization first if the robot has video capabilities + if hasattr(robot, "video_stream_ros") and robot.video_stream_ros is not None: + web_interface = setup_visualization(robot, port=args.port) + # Start web interface in a separate thread + viz_thread = threading.Thread(target=web_interface.run, daemon=True) + viz_thread.start() + logger.info(f"Web visualization available at http://localhost:{args.port}") + # Wait a moment for the web interface to initialize + time.sleep(2) + + if args.justgo: + # Just go to the specified location + coords = list(map(float, args.justgo.split(","))) + logger.info(f"Navigating to coordinates: {coords}") + + # Run navigation + navigate_thread = threading.Thread( + target=lambda: run_navigation(robot, coords), daemon=True + ) + navigate_thread.start() + + # Wait for navigation to complete or user to interrupt + try: + while navigate_thread.is_alive(): + time.sleep(0.5) + logger.info("Navigation completed") + except KeyboardInterrupt: + logger.info("Navigation interrupted by user") + else: + # Build map if not skipped + if not args.skip_build: + build_map(robot, args) + + # Query the map + target = query_map(robot, args) + + if not target: + logger.error("No target found for navigation.") + return + + # Run navigation + navigate_thread = threading.Thread( + target=lambda: run_navigation(robot, target), daemon=True + ) + navigate_thread.start() + + # Wait for navigation to complete or user to interrupt + try: + while navigate_thread.is_alive(): + time.sleep(0.5) + logger.info("Navigation completed") + except KeyboardInterrupt: + logger.info("Navigation interrupted by user") + + # If web interface is running, keep the main thread alive + if web_interface: + logger.info( + "Navigation completed. Visualization still available. Press Ctrl+C to exit." + ) + try: + while True: + time.sleep(0.5) + except KeyboardInterrupt: + logger.info("Exiting...") + + finally: + # Clean up + logger.info("Cleaning up resources...") + try: + robot.cleanup() + except Exception as e: + logger.error(f"Error during cleanup: {e}") + + logger.info("Test completed successfully") + + +if __name__ == "__main__": + main() diff --git a/tests/test_object_detection_agent_data_query_stream.py b/tests/test_object_detection_agent_data_query_stream.py new file mode 100644 index 0000000000..cc75fb1961 --- /dev/null +++ b/tests/test_object_detection_agent_data_query_stream.py @@ -0,0 +1,187 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import sys +import threading + +from dotenv import load_dotenv +from reactivex import operators as ops + +from dimos.agents.claude_agent import ClaudeAgent +from dimos.perception.detection2d.detic_2d_det import Detic2DDetector +from dimos.perception.object_detection_stream import ObjectDetectionStream +from dimos.robot.unitree.unitree_go2 import UnitreeGo2 +from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl +from dimos.robot.unitree.unitree_skills import MyUnitreeSkills +from dimos.stream.video_provider import VideoProvider +from dimos.utils.reactive import backpressure +from dimos.web.robot_web_interface import RobotWebInterface + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Test ObjectDetectionStream for object detection and position estimation" + ) + parser.add_argument( + "--mode", + type=str, + default="webcam", + choices=["robot", "webcam"], + help='Mode to run: "robot" or "webcam" (default: webcam)', + ) + return parser.parse_args() + + +load_dotenv() + + +def main(): + # Get command line arguments + args = parse_args() + + # Set default parameters + min_confidence = 0.6 + class_filter = None # No class filtering + web_port = 5555 + + # Initialize detector + detector = Detic2DDetector(vocabulary=None, threshold=min_confidence) + + # Initialize based on mode + if args.mode == "robot": + print("Initializing in robot mode...") + + # Get robot IP from environment + robot_ip = os.getenv("ROBOT_IP") + if not robot_ip: + print("Error: ROBOT_IP environment variable not set.") + sys.exit(1) + + # Initialize robot + robot = UnitreeGo2( + ip=robot_ip, + ros_control=UnitreeROSControl(), + skills=MyUnitreeSkills(), + ) + # Create video stream from robot's camera + video_stream = robot.video_stream_ros + + # Initialize ObjectDetectionStream with robot and transform function + object_detector = ObjectDetectionStream( + camera_intrinsics=robot.camera_intrinsics, + min_confidence=min_confidence, + class_filter=class_filter, + transform_to_map=robot.ros_control.transform_pose, + detector=detector, + video_stream=video_stream, + ) + + else: # webcam mode + print("Initializing in webcam mode...") + + # Define camera intrinsics for the webcam + # These are approximate values for a typical 640x480 webcam + width, height = 640, 480 + focal_length_mm = 3.67 # mm (typical webcam) + sensor_width_mm = 4.8 # mm (1/4" sensor) + + # Calculate focal length in pixels + focal_length_x_px = width * focal_length_mm / sensor_width_mm + focal_length_y_px = height * focal_length_mm / sensor_width_mm + + # Principal point (center of image) + cx, cy = width / 2, height / 2 + + # Camera intrinsics in [fx, fy, cx, cy] format + camera_intrinsics = [focal_length_x_px, focal_length_y_px, cx, cy] + + # Initialize video provider and ObjectDetectionStream + video_provider = VideoProvider("test_camera", video_source=0) # Default camera + # Create video stream + video_stream = backpressure( + video_provider.capture_video_as_observable(realtime=True, fps=30) + ) + + object_detector = ObjectDetectionStream( + camera_intrinsics=camera_intrinsics, + min_confidence=min_confidence, + class_filter=class_filter, + detector=detector, + video_stream=video_stream, + ) + + # Set placeholder robot for cleanup + robot = None + + # Create visualization stream for web interface + viz_stream = object_detector.get_stream().pipe( + ops.share(), + ops.map(lambda x: x["viz_frame"] if x is not None else None), + ops.filter(lambda x: x is not None), + ) + + # Create object data observable for Agent using the formatted stream + object_data_stream = object_detector.get_formatted_stream().pipe( + ops.share(), ops.filter(lambda x: x is not None) + ) + + # Create stop event for clean shutdown + stop_event = threading.Event() + + try: + # Set up web interface + print("Initializing web interface...") + web_interface = RobotWebInterface(port=web_port, object_detection=viz_stream) + + agent = ClaudeAgent( + dev_name="test_agent", + # input_query_stream=stt_node.emit_text(), + input_query_stream=web_interface.query_stream, + input_data_stream=object_data_stream, + system_query="Tell me what you see", + model_name="claude-3-7-sonnet-latest", + thinking_budget_tokens=0, + ) + + # Print configuration information + print("\nObjectDetectionStream Test Running:") + print(f"Mode: {args.mode}") + print(f"Web Interface: http://localhost:{web_port}") + print("\nPress Ctrl+C to stop the test\n") + + # Start web server (blocking call) + web_interface.run() + + except KeyboardInterrupt: + print("\nTest interrupted by user") + except Exception as e: + print(f"Error during test: {e}") + finally: + # Clean up resources + print("Cleaning up resources...") + stop_event.set() + + if args.mode == "robot" and robot: + robot.cleanup() + elif args.mode == "webcam": + if "video_provider" in locals(): + video_provider.dispose_all() + + print("Test completed") + + +if __name__ == "__main__": + main() diff --git a/tests/test_object_detection_stream.py b/tests/test_object_detection_stream.py new file mode 100644 index 0000000000..540bb40e06 --- /dev/null +++ b/tests/test_object_detection_stream.py @@ -0,0 +1,239 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import sys +import threading +import time +from typing import Any + +from dotenv import load_dotenv +from reactivex import operators as ops + +from dimos.perception.object_detection_stream import ObjectDetectionStream +from dimos.robot.unitree.unitree_skills import MyUnitreeSkills +from dimos.stream.video_provider import VideoProvider +from dimos.utils.reactive import backpressure +from dimos.web.robot_web_interface import RobotWebInterface + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Test ObjectDetectionStream for object detection and position estimation" + ) + parser.add_argument( + "--mode", + type=str, + default="webcam", + choices=["robot", "webcam"], + help='Mode to run: "robot" or "webcam" (default: webcam)', + ) + return parser.parse_args() + + +load_dotenv() + + +class ResultPrinter: + def __init__(self, print_interval: float = 1.0): + """ + Initialize a result printer that limits console output frequency. + + Args: + print_interval: Minimum time between console prints in seconds + """ + self.print_interval = print_interval + self.last_print_time = 0 + + def print_results(self, objects: list[dict[str, Any]]): + """Print object detection results to console with rate limiting.""" + current_time = time.time() + + # Only print results at the specified interval + if current_time - self.last_print_time >= self.print_interval: + self.last_print_time = current_time + + if not objects: + print("\n[No objects detected]") + return + + print("\n" + "=" * 50) + print(f"Detected {len(objects)} objects at {time.strftime('%H:%M:%S')}:") + print("=" * 50) + + for i, obj in enumerate(objects): + pos = obj["position"] + rot = obj["rotation"] + size = obj["size"] + + print( + f"{i + 1}. {obj['label']} (ID: {obj['object_id']}, Conf: {obj['confidence']:.2f})" + ) + print(f" Position: x={pos.x:.2f}, y={pos.y:.2f}, z={pos.z:.2f} m") + print(f" Rotation: yaw={rot.z:.2f} rad") + print(f" Size: width={size['width']:.2f}, height={size['height']:.2f} m") + print(f" Depth: {obj['depth']:.2f} m") + print("-" * 30) + + +def main(): + # Get command line arguments + args = parse_args() + + # Set up the result printer for console output + result_printer = ResultPrinter(print_interval=1.0) + + # Set default parameters + min_confidence = 0.6 + class_filter = None # No class filtering + web_port = 5555 + + # Initialize based on mode + if args.mode == "robot": + print("Initializing in robot mode...") + + # Get robot IP from environment + robot_ip = os.getenv("ROBOT_IP") + if not robot_ip: + print("Error: ROBOT_IP environment variable not set.") + sys.exit(1) + + # Initialize robot + robot = UnitreeGo2( + ip=robot_ip, + ros_control=UnitreeROSControl(), + skills=MyUnitreeSkills(), + ) + # Create video stream from robot's camera + video_stream = robot.video_stream_ros + + # Initialize ObjectDetectionStream with robot and transform function + object_detector = ObjectDetectionStream( + camera_intrinsics=robot.camera_intrinsics, + min_confidence=min_confidence, + class_filter=class_filter, + transform_to_map=robot.ros_control.transform_pose, + detector=detector, + video_stream=video_stream, + disable_depth=False, + ) + + else: # webcam mode + print("Initializing in webcam mode...") + + # Define camera intrinsics for the webcam + # These are approximate values for a typical 640x480 webcam + width, height = 640, 480 + focal_length_mm = 3.67 # mm (typical webcam) + sensor_width_mm = 4.8 # mm (1/4" sensor) + + # Calculate focal length in pixels + focal_length_x_px = width * focal_length_mm / sensor_width_mm + focal_length_y_px = height * focal_length_mm / sensor_width_mm + + # Principal point (center of image) + cx, cy = width / 2, height / 2 + + # Camera intrinsics in [fx, fy, cx, cy] format + camera_intrinsics = [focal_length_x_px, focal_length_y_px, cx, cy] + + # Initialize video provider and ObjectDetectionStream + video_provider = VideoProvider("test_camera", video_source=0) # Default camera + # Create video stream + video_stream = backpressure( + video_provider.capture_video_as_observable(realtime=True, fps=30) + ) + + object_detector = ObjectDetectionStream( + camera_intrinsics=camera_intrinsics, + min_confidence=min_confidence, + class_filter=class_filter, + video_stream=video_stream, + disable_depth=False, + draw_masks=True, + ) + + # Set placeholder robot for cleanup + robot = None + + # Create visualization stream for web interface + viz_stream = object_detector.get_stream().pipe( + ops.share(), + ops.map(lambda x: x["viz_frame"] if x is not None else None), + ops.filter(lambda x: x is not None), + ) + + # Create stop event for clean shutdown + stop_event = threading.Event() + + # Define subscription callback to print results + def on_next(result): + if stop_event.is_set(): + return + + # Print detected objects to console + if "objects" in result: + result_printer.print_results(result["objects"]) + + def on_error(error): + print(f"Error in detection stream: {error}") + stop_event.set() + + def on_completed(): + print("Detection stream completed") + stop_event.set() + + try: + # Subscribe to the detection stream + subscription = object_detector.get_stream().subscribe( + on_next=on_next, on_error=on_error, on_completed=on_completed + ) + + # Set up web interface + print("Initializing web interface...") + web_interface = RobotWebInterface(port=web_port, object_detection=viz_stream) + + # Print configuration information + print("\nObjectDetectionStream Test Running:") + print(f"Mode: {args.mode}") + print(f"Web Interface: http://localhost:{web_port}") + print("\nPress Ctrl+C to stop the test\n") + + # Start web server (blocking call) + web_interface.run() + + except KeyboardInterrupt: + print("\nTest interrupted by user") + except Exception as e: + print(f"Error during test: {e}") + finally: + # Clean up resources + print("Cleaning up resources...") + stop_event.set() + + if subscription: + subscription.dispose() + + if args.mode == "robot" and robot: + robot.cleanup() + elif args.mode == "webcam": + if "video_provider" in locals(): + video_provider.dispose_all() + + print("Test completed") + + +if __name__ == "__main__": + main() diff --git a/tests/test_object_tracking_module.py b/tests/test_object_tracking_module.py new file mode 100755 index 0000000000..56fef2e3d7 --- /dev/null +++ b/tests/test_object_tracking_module.py @@ -0,0 +1,291 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test script for Object Tracking module with ZED camera.""" + +import asyncio + +import cv2 +from dimos_lcm.sensor_msgs import CameraInfo + +from dimos import core +from dimos.hardware.zed_camera import ZEDModule +from dimos.msgs.geometry_msgs import PoseStamped + +# Import message types +from dimos.msgs.sensor_msgs import Image +from dimos.perception.object_tracker import ObjectTracking +from dimos.protocol import pubsub +from dimos.protocol.pubsub.lcmpubsub import LCM, Topic +from dimos.robot.foxglove_bridge import FoxgloveBridge +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + +# Suppress verbose Foxglove bridge warnings +import logging + +logging.getLogger("lcm_foxglove_bridge").setLevel(logging.ERROR) +logging.getLogger("FoxgloveServer").setLevel(logging.ERROR) + + +class TrackingVisualization: + """Handles visualization and user interaction for object tracking.""" + + def __init__(self): + self.lcm = LCM() + self.latest_color = None + + # Mouse interaction state + self.selecting_bbox = False + self.bbox_start = None + self.current_bbox = None + self.tracking_active = False + + # Subscribe to color image topic only + self.color_topic = Topic("/zed/color_image", Image) + + def start(self): + """Start the visualization node.""" + self.lcm.start() + + # Subscribe to color image only + self.lcm.subscribe(self.color_topic, self._on_color_image) + + logger.info("Visualization started, subscribed to color image topic") + + def _on_color_image(self, msg: Image, _: str): + """Handle color image messages.""" + try: + # Convert dimos Image to OpenCV format (BGR) for display + self.latest_color = msg.to_opencv() + logger.debug(f"Received color image: {msg.width}x{msg.height}, format: {msg.format}") + except Exception as e: + logger.error(f"Error processing color image: {e}") + + def mouse_callback(self, event, x, y, _, param): + """Handle mouse events for bbox selection.""" + tracker_module = param.get("tracker") + + if event == cv2.EVENT_LBUTTONDOWN: + self.selecting_bbox = True + self.bbox_start = (x, y) + self.current_bbox = None + + elif event == cv2.EVENT_MOUSEMOVE and self.selecting_bbox: + # Update current selection for visualization + x1, y1 = self.bbox_start + self.current_bbox = [min(x1, x), min(y1, y), max(x1, x), max(y1, y)] + + elif event == cv2.EVENT_LBUTTONUP and self.selecting_bbox: + self.selecting_bbox = False + if self.bbox_start: + x1, y1 = self.bbox_start + x2, y2 = x, y + # Ensure valid bbox + bbox = [min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)] + + # Check if bbox is valid (has area) + if bbox[2] > bbox[0] and bbox[3] > bbox[1]: + # Call track RPC on the tracker module + if tracker_module: + result = tracker_module.track(bbox) + logger.info(f"Tracking initialized: {result}") + self.tracking_active = True + self.current_bbox = None + + def draw_interface(self, frame): + """Draw UI elements on the frame.""" + # Draw bbox selection if in progress + if self.selecting_bbox and self.current_bbox: + x1, y1, x2, y2 = self.current_bbox + cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 255), 2) + + # Draw instructions + cv2.putText( + frame, + "Click and drag to select object", + (10, 30), + cv2.FONT_HERSHEY_SIMPLEX, + 0.7, + (255, 255, 255), + 2, + ) + cv2.putText( + frame, + "Press 's' to stop tracking, 'q' to quit", + (10, 60), + cv2.FONT_HERSHEY_SIMPLEX, + 0.7, + (255, 255, 255), + 2, + ) + + # Show tracking status + if self.tracking_active: + status = "Tracking Active" + color = (0, 255, 0) + else: + status = "No Target" + color = (0, 0, 255) + cv2.putText(frame, f"Status: {status}", (10, 90), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2) + + return frame + + +async def test_object_tracking_module(): + """Test object tracking with ZED camera module.""" + logger.info("Starting Object Tracking Module test") + + # Start Dimos + dimos = core.start(2) + + # Enable LCM auto-configuration + pubsub.lcm.autoconf() + + viz = None + tracker = None + zed = None + foxglove_bridge = None + + try: + # Deploy ZED module + logger.info("Deploying ZED module...") + zed = dimos.deploy( + ZEDModule, + camera_id=0, + resolution="HD720", + depth_mode="NEURAL", + fps=30, + enable_tracking=True, + publish_rate=15.0, + frame_id="zed_camera_link", + ) + + # Configure ZED LCM transports + zed.color_image.transport = core.LCMTransport("/zed/color_image", Image) + zed.depth_image.transport = core.LCMTransport("/zed/depth_image", Image) + zed.camera_info.transport = core.LCMTransport("/zed/camera_info", CameraInfo) + zed.pose.transport = core.LCMTransport("/zed/pose", PoseStamped) + + # Start ZED to begin publishing + zed.start() + await asyncio.sleep(2) # Wait for camera to initialize + + # Deploy Object Tracking module + logger.info("Deploying Object Tracking module...") + tracker = dimos.deploy( + ObjectTracking, + camera_intrinsics=None, # Will get from camera_info topic + reid_threshold=5, + reid_fail_tolerance=10, + frame_id="zed_camera_link", + ) + + # Configure tracking LCM transports + tracker.color_image.transport = core.LCMTransport("/zed/color_image", Image) + tracker.depth.transport = core.LCMTransport("/zed/depth_image", Image) + tracker.camera_info.transport = core.LCMTransport("/zed/camera_info", CameraInfo) + + # Configure output transports + from dimos_lcm.vision_msgs import Detection2DArray, Detection3DArray + + tracker.detection2darray.transport = core.LCMTransport( + "/detection2darray", Detection2DArray + ) + tracker.detection3darray.transport = core.LCMTransport( + "/detection3darray", Detection3DArray + ) + tracker.tracked_overlay.transport = core.LCMTransport("/tracked_overlay", Image) + + # Connect inputs + tracker.color_image.connect(zed.color_image) + tracker.depth.connect(zed.depth_image) + tracker.camera_info.connect(zed.camera_info) + + # Start tracker + tracker.start() + + # Create visualization + viz = TrackingVisualization() + viz.start() + + # Start Foxglove bridge for visualization + foxglove_bridge = FoxgloveBridge() + foxglove_bridge.acquire() + + # Give modules time to initialize + await asyncio.sleep(1) + + # Create OpenCV window and set mouse callback + cv2.namedWindow("Object Tracking") + cv2.setMouseCallback("Object Tracking", viz.mouse_callback, {"tracker": tracker}) + + logger.info("System ready. Click and drag to select an object to track.") + logger.info("Foxglove visualization available at http://localhost:8765") + + # Main visualization loop + while True: + # Get the color frame to display + if viz.latest_color is not None: + display_frame = viz.latest_color.copy() + else: + # Wait for frames + await asyncio.sleep(0.03) + continue + + # Draw UI elements + display_frame = viz.draw_interface(display_frame) + + # Show frame + cv2.imshow("Object Tracking", display_frame) + + # Handle keyboard input + key = cv2.waitKey(1) & 0xFF + if key == ord("q"): + logger.info("Quit requested") + break + elif key == ord("s"): + # Stop tracking + if tracker: + tracker.stop_track() + viz.tracking_active = False + logger.info("Tracking stopped") + + await asyncio.sleep(0.03) # ~30 FPS + + except Exception as e: + logger.error(f"Error in test: {e}") + import traceback + + traceback.print_exc() + + finally: + # Clean up + cv2.destroyAllWindows() + + if tracker: + tracker.stop() + if zed: + zed.stop() + if foxglove_bridge: + foxglove_bridge.release() + + dimos.close() + logger.info("Test completed") + + +if __name__ == "__main__": + asyncio.run(test_object_tracking_module()) diff --git a/tests/test_object_tracking_webcam.py b/tests/test_object_tracking_webcam.py new file mode 100644 index 0000000000..caf6d75387 --- /dev/null +++ b/tests/test_object_tracking_webcam.py @@ -0,0 +1,219 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import queue +import threading + +import cv2 + +from dimos.perception.object_tracker import ObjectTrackingStream +from dimos.stream.video_provider import VideoProvider + +# Global variables for bounding box selection +selecting_bbox = False +bbox_points = [] +current_bbox = None +tracker_initialized = False +object_size = 0.30 # Hardcoded object size in meters (adjust based on your tracking target) + + +def mouse_callback(event, x, y, flags, param): + global selecting_bbox, bbox_points, current_bbox, tracker_initialized, tracker_stream + + if event == cv2.EVENT_LBUTTONDOWN: + # Start bbox selection + selecting_bbox = True + bbox_points = [(x, y)] + current_bbox = None + tracker_initialized = False + + elif event == cv2.EVENT_MOUSEMOVE and selecting_bbox: + # Update current selection for visualization + current_bbox = [bbox_points[0][0], bbox_points[0][1], x, y] + + elif event == cv2.EVENT_LBUTTONUP: + # End bbox selection + selecting_bbox = False + if bbox_points: + bbox_points.append((x, y)) + x1, y1 = bbox_points[0] + x2, y2 = bbox_points[1] + # Ensure x1,y1 is top-left and x2,y2 is bottom-right + current_bbox = [min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)] + # Add the bbox to the tracking queue + if param.get("bbox_queue") and not tracker_initialized: + param["bbox_queue"].put((current_bbox, object_size)) + tracker_initialized = True + + +def main(): + global tracker_initialized + + # Create queues for thread communication + frame_queue = queue.Queue(maxsize=5) + bbox_queue = queue.Queue() + stop_event = threading.Event() + + # Logitech C920e camera parameters at 480p + # Convert physical parameters to pixel-based intrinsics + width, height = 640, 480 + focal_length_mm = 3.67 # mm + sensor_width_mm = 4.8 # mm (1/4" sensor) + sensor_height_mm = 3.6 # mm + + # Calculate focal length in pixels + focal_length_x_px = width * focal_length_mm / sensor_width_mm + focal_length_y_px = height * focal_length_mm / sensor_height_mm + + # Principal point (assuming center of image) + cx = width / 2 + cy = height / 2 + + # Final camera intrinsics in [fx, fy, cx, cy] format + camera_intrinsics = [focal_length_x_px, focal_length_y_px, cx, cy] + + # Initialize video provider and object tracking stream + video_provider = VideoProvider("test_camera", video_source=0) + tracker_stream = ObjectTrackingStream( + camera_intrinsics=camera_intrinsics, + camera_pitch=0.0, # Adjust if your camera is tilted + camera_height=0.5, # Height of camera from ground in meters (adjust as needed) + ) + + # Create video stream + video_stream = video_provider.capture_video_as_observable(realtime=True, fps=30) + tracking_stream = tracker_stream.create_stream(video_stream) + + # Define callbacks for the tracking stream + def on_next(result): + if stop_event.is_set(): + return + + # Get the visualization frame + viz_frame = result["viz_frame"] + + # If we're selecting a bbox, draw the current selection + if selecting_bbox and current_bbox is not None: + x1, y1, x2, y2 = current_bbox + cv2.rectangle(viz_frame, (x1, y1), (x2, y2), (0, 255, 255), 2) + + # Add instructions + cv2.putText( + viz_frame, + "Click and drag to select object", + (10, 30), + cv2.FONT_HERSHEY_SIMPLEX, + 0.7, + (255, 255, 255), + 2, + ) + cv2.putText( + viz_frame, + f"Object size: {object_size:.2f}m", + (10, 60), + cv2.FONT_HERSHEY_SIMPLEX, + 0.7, + (255, 255, 255), + 2, + ) + + # Show tracking status + status = "Tracking" if tracker_initialized else "Not tracking" + cv2.putText( + viz_frame, + f"Status: {status}", + (10, 90), + cv2.FONT_HERSHEY_SIMPLEX, + 0.7, + (0, 255, 0) if tracker_initialized else (0, 0, 255), + 2, + ) + + # Put frame in queue for main thread to display + try: + frame_queue.put_nowait(viz_frame) + except queue.Full: + # Skip frame if queue is full + pass + + def on_error(error): + print(f"Error: {error}") + stop_event.set() + + def on_completed(): + print("Stream completed") + stop_event.set() + + # Start the subscription + subscription = None + + try: + # Subscribe to start processing in background thread + subscription = tracking_stream.subscribe( + on_next=on_next, on_error=on_error, on_completed=on_completed + ) + + print("Object tracking started. Click and drag to select an object. Press 'q' to exit.") + + # Create window and set mouse callback + cv2.namedWindow("Object Tracker") + cv2.setMouseCallback("Object Tracker", mouse_callback, {"bbox_queue": bbox_queue}) + + # Main thread loop for displaying frames and handling bbox selection + while not stop_event.is_set(): + # Check if there's a new bbox to track + try: + new_bbox, size = bbox_queue.get_nowait() + print(f"New object selected: {new_bbox}, size: {size}m") + # Initialize tracker with the new bbox and size + tracker_stream.track(new_bbox, size=size) + except queue.Empty: + pass + + try: + # Get frame with timeout + viz_frame = frame_queue.get(timeout=1.0) + + # Display the frame + cv2.imshow("Object Tracker", viz_frame) + # Check for exit key + if cv2.waitKey(1) & 0xFF == ord("q"): + print("Exit key pressed") + break + + except queue.Empty: + # No frame available, check if we should continue + if cv2.waitKey(1) & 0xFF == ord("q"): + print("Exit key pressed") + break + continue + + except KeyboardInterrupt: + print("\nKeyboard interrupt received. Stopping...") + finally: + # Signal threads to stop + stop_event.set() + + # Clean up resources + if subscription: + subscription.dispose() + + video_provider.dispose_all() + tracker_stream.cleanup() + cv2.destroyAllWindows() + print("Cleanup complete") + + +if __name__ == "__main__": + main() diff --git a/tests/test_object_tracking_with_qwen.py b/tests/test_object_tracking_with_qwen.py new file mode 100644 index 0000000000..89ab5b775c --- /dev/null +++ b/tests/test_object_tracking_with_qwen.py @@ -0,0 +1,209 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import queue +import threading + +import cv2 + +from dimos.models.qwen.video_query import get_bbox_from_qwen +from dimos.perception.object_tracker import ObjectTrackingStream +from dimos.stream.video_provider import VideoProvider + +# Global variables for tracking control +object_size = 0.30 # Hardcoded object size in meters (adjust based on your tracking target) +tracking_object_name = "object" # Will be updated by Qwen +object_name = "hairbrush" # Example object name for Qwen + +global tracker_initialized, detection_in_progress + +# Create queues for thread communication +frame_queue = queue.Queue(maxsize=5) +stop_event = threading.Event() + +# Logitech C920e camera parameters at 480p +width, height = 640, 480 +focal_length_mm = 3.67 # mm +sensor_width_mm = 4.8 # mm (1/4" sensor) +sensor_height_mm = 3.6 # mm + +# Calculate focal length in pixels +focal_length_x_px = width * focal_length_mm / sensor_width_mm +focal_length_y_px = height * focal_length_mm / sensor_height_mm +cx, cy = width / 2, height / 2 + +# Final camera intrinsics in [fx, fy, cx, cy] format +camera_intrinsics = [focal_length_x_px, focal_length_y_px, cx, cy] + +# Initialize video provider and object tracking stream +video_provider = VideoProvider("webcam", video_source=0) +tracker_stream = ObjectTrackingStream( + camera_intrinsics=camera_intrinsics, camera_pitch=0.0, camera_height=0.5 +) + +# Create video streams +video_stream = video_provider.capture_video_as_observable(realtime=True, fps=10) +tracking_stream = tracker_stream.create_stream(video_stream) + +# Check if display is available +if "DISPLAY" not in os.environ: + raise RuntimeError( + "No display available. Please set DISPLAY environment variable or run in headless mode." + ) + + +# Define callbacks for the tracking stream +def on_next(result): + global tracker_initialized, detection_in_progress + if stop_event.is_set(): + return + + # Get the visualization frame + viz_frame = result["viz_frame"] + + # Add information to the visualization + cv2.putText( + viz_frame, + f"Tracking {tracking_object_name}", + (10, 30), + cv2.FONT_HERSHEY_SIMPLEX, + 0.7, + (255, 255, 255), + 2, + ) + cv2.putText( + viz_frame, + f"Object size: {object_size:.2f}m", + (10, 60), + cv2.FONT_HERSHEY_SIMPLEX, + 0.7, + (255, 255, 255), + 2, + ) + + # Show tracking status + status = "Tracking" if tracker_initialized else "Waiting for detection" + color = (0, 255, 0) if tracker_initialized else (0, 0, 255) + cv2.putText(viz_frame, f"Status: {status}", (10, 90), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2) + + # If detection is in progress, show a message + if detection_in_progress: + cv2.putText( + viz_frame, "Querying Qwen...", (10, 120), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2 + ) + + # Put frame in queue for main thread to display + try: + frame_queue.put_nowait(viz_frame) + except queue.Full: + pass + + +def on_error(error): + print(f"Error: {error}") + stop_event.set() + + +def on_completed(): + print("Stream completed") + stop_event.set() + + +# Start the subscription +subscription = None + +try: + # Initialize global flags + tracker_initialized = False + detection_in_progress = False + # Subscribe to start processing in background thread + subscription = tracking_stream.subscribe( + on_next=on_next, on_error=on_error, on_completed=on_completed + ) + + print("Object tracking with Qwen started. Press 'q' to exit.") + print("Waiting for initial object detection...") + + # Main thread loop for displaying frames and updating tracking + while not stop_event.is_set(): + # Check if we need to update tracking + + if not detection_in_progress: + detection_in_progress = True + print("Requesting object detection from Qwen...") + + print("detection_in_progress: ", detection_in_progress) + print("tracker_initialized: ", tracker_initialized) + + def detection_task(): + global detection_in_progress, tracker_initialized, tracking_object_name, object_size + try: + result = get_bbox_from_qwen(video_stream, object_name=object_name) + print(f"Got result from Qwen: {result}") + + if result: + bbox, size = result + print(f"Detected object at {bbox} with size {size}") + tracker_stream.track(bbox, size=size) + tracker_initialized = True + return + + print("No object detected by Qwen") + tracker_initialized = False + tracker_stream.stop_track() + + except Exception as e: + print(f"Error in update_tracking: {e}") + tracker_initialized = False + tracker_stream.stop_track() + finally: + detection_in_progress = False + + # Run detection task in a separate thread + threading.Thread(target=detection_task, daemon=True).start() + + try: + # Get frame with timeout + viz_frame = frame_queue.get(timeout=0.1) + + # Display the frame + cv2.imshow("Object Tracking with Qwen", viz_frame) + + # Check for exit key + if cv2.waitKey(1) & 0xFF == ord("q"): + print("Exit key pressed") + break + + except queue.Empty: + # No frame available, check if we should continue + if cv2.waitKey(1) & 0xFF == ord("q"): + print("Exit key pressed") + break + continue + +except KeyboardInterrupt: + print("\nKeyboard interrupt received. Stopping...") +finally: + # Signal threads to stop + stop_event.set() + + # Clean up resources + if subscription: + subscription.dispose() + + video_provider.dispose_all() + tracker_stream.cleanup() + cv2.destroyAllWindows() + print("Cleanup complete") diff --git a/tests/test_person_following_robot.py b/tests/test_person_following_robot.py new file mode 100644 index 0000000000..d2cfcdcb23 --- /dev/null +++ b/tests/test_person_following_robot.py @@ -0,0 +1,114 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import time + +from reactivex import operators as RxOps + +from dimos.models.qwen.video_query import query_single_frame_observable +from dimos.robot.unitree.unitree_go2 import UnitreeGo2 +from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl +from dimos.robot.unitree.unitree_skills import MyUnitreeSkills +from dimos.utils.logging_config import setup_logger +from dimos.web.robot_web_interface import RobotWebInterface + +logger = setup_logger() + + +def main(): + # Hardcoded parameters + timeout = 60.0 # Maximum time to follow a person (seconds) + distance = 0.5 # Desired distance to maintain from target (meters) + + print("Initializing Unitree Go2 robot...") + + # Initialize the robot with ROS control and skills + robot = UnitreeGo2( + ip=os.getenv("ROBOT_IP"), + ros_control=UnitreeROSControl(), + skills=MyUnitreeSkills(), + ) + + tracking_stream = robot.person_tracking_stream + viz_stream = tracking_stream.pipe( + RxOps.share(), + RxOps.map(lambda x: x["viz_frame"] if x is not None else None), + RxOps.filter(lambda x: x is not None), + ) + video_stream = robot.get_ros_video_stream() + + try: + # Set up web interface + logger.info("Initializing web interface") + streams = {"unitree_video": video_stream, "person_tracking": viz_stream} + + web_interface = RobotWebInterface(port=5555, **streams) + + # Wait for camera and tracking to initialize + print("Waiting for camera and tracking to initialize...") + time.sleep(5) + # Get initial point from Qwen + + max_retries = 5 + delay = 3 + + for attempt in range(max_retries): + try: + qwen_point = eval( + query_single_frame_observable( + video_stream, + "Look at this frame and point to the person shirt. Return ONLY their center coordinates as a tuple (x,y).", + ) + .pipe(RxOps.take(1)) + .run() + ) # Get first response and convert string tuple to actual tuple + logger.info(f"Found person at coordinates {qwen_point}") + break # If successful, break out of retry loop + except Exception as e: + if attempt < max_retries - 1: + logger.error( + f"Person not found. Attempt {attempt + 1}/{max_retries} failed. Retrying in {delay}s... Error: {e}" + ) + time.sleep(delay) + else: + logger.error(f"Person not found after {max_retries} attempts. Last error: {e}") + return + + # Start following human in a separate thread + import threading + + follow_thread = threading.Thread( + target=lambda: robot.follow_human(timeout=timeout, distance=distance, point=qwen_point), + daemon=True, + ) + follow_thread.start() + + print(f"Following human at point {qwen_point} for {timeout} seconds...") + print("Web interface available at http://localhost:5555") + + # Start web server (blocking call) + web_interface.run() + + except KeyboardInterrupt: + print("\nInterrupted by user") + except Exception as e: + print(f"Error during test: {e}") + finally: + print("Test completed") + robot.cleanup() + + +if __name__ == "__main__": + main() diff --git a/tests/test_person_following_webcam.py b/tests/test_person_following_webcam.py new file mode 100644 index 0000000000..d66f3d7236 --- /dev/null +++ b/tests/test_person_following_webcam.py @@ -0,0 +1,227 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import queue +import threading + +import cv2 +import numpy as np + +from dimos.perception.person_tracker import PersonTrackingStream +from dimos.perception.visual_servoing import VisualServoing +from dimos.stream.video_provider import VideoProvider + + +def main(): + # Create a queue for thread communication (limit to prevent memory issues) + frame_queue = queue.Queue(maxsize=5) + result_queue = queue.Queue(maxsize=5) # For tracking results + stop_event = threading.Event() + + # Logitech C920e camera parameters at 480p + # Convert physical parameters to intrinsics [fx, fy, cx, cy] + resolution = (640, 480) # 480p resolution + focal_length_mm = 3.67 # mm + sensor_size_mm = (4.8, 3.6) # mm (1/4" sensor) + + # Calculate focal length in pixels + fx = (resolution[0] * focal_length_mm) / sensor_size_mm[0] + fy = (resolution[1] * focal_length_mm) / sensor_size_mm[1] + + # Principal point (typically at image center) + cx = resolution[0] / 2 + cy = resolution[1] / 2 + + # Camera intrinsics in [fx, fy, cx, cy] format + camera_intrinsics = [fx, fy, cx, cy] + + # Camera mounted parameters + camera_pitch = np.deg2rad(-5) # negative for downward pitch + camera_height = 1.4 # meters + + # Initialize video provider and person tracking stream + video_provider = VideoProvider("test_camera", video_source=0) + person_tracker = PersonTrackingStream( + camera_intrinsics=camera_intrinsics, camera_pitch=camera_pitch, camera_height=camera_height + ) + + # Create streams + video_stream = video_provider.capture_video_as_observable(realtime=False, fps=20) + person_tracking_stream = person_tracker.create_stream(video_stream) + + # Create visual servoing object + visual_servoing = VisualServoing( + tracking_stream=person_tracking_stream, + max_linear_speed=0.5, + max_angular_speed=0.75, + desired_distance=2.5, + ) + + # Track if we have selected a person to follow + selected_point = None + tracking_active = False + + # Define callbacks for the tracking stream + def on_next(result): + if stop_event.is_set(): + return + + # Get the visualization frame which already includes person detections + # with bounding boxes, tracking IDs, and distance/angle information + viz_frame = result["viz_frame"] + + # Store the result for the main thread to use with visual servoing + try: + result_queue.put_nowait(result) + except queue.Full: + # Skip if queue is full + pass + + # Put frame in queue for main thread to display (non-blocking) + try: + frame_queue.put_nowait(viz_frame) + except queue.Full: + # Skip frame if queue is full + pass + + def on_error(error): + print(f"Error: {error}") + stop_event.set() + + def on_completed(): + print("Stream completed") + stop_event.set() + + # Mouse callback for selecting a person to track + def mouse_callback(event, x, y, flags, param): + nonlocal selected_point, tracking_active + + if event == cv2.EVENT_LBUTTONDOWN: + # Store the clicked point + selected_point = (x, y) + tracking_active = False # Will be set to True if start_tracking succeeds + print(f"Selected point: {selected_point}") + + # Start the subscription + subscription = None + + try: + # Subscribe to start processing in background thread + subscription = person_tracking_stream.subscribe( + on_next=on_next, on_error=on_error, on_completed=on_completed + ) + + print("Person tracking visualization started.") + print("Click on a person to start visual servoing. Press 'q' to exit.") + + # Set up mouse callback + cv2.namedWindow("Person Tracking") + cv2.setMouseCallback("Person Tracking", mouse_callback) + + # Main thread loop for displaying frames + while not stop_event.is_set(): + try: + # Get frame with timeout (allows checking stop_event periodically) + frame = frame_queue.get(timeout=1.0) + + # Call the visual servoing if we have a selected point + if selected_point is not None: + # If not actively tracking, try to start tracking + if not tracking_active: + tracking_active = visual_servoing.start_tracking(point=selected_point) + if not tracking_active: + print("Failed to start tracking") + selected_point = None + + # If tracking is active, update tracking + if tracking_active: + servoing_result = visual_servoing.updateTracking() + + # Display visual servoing output on the frame + linear_vel = servoing_result.get("linear_vel", 0.0) + angular_vel = servoing_result.get("angular_vel", 0.0) + running = visual_servoing.running + + status_color = ( + (0, 255, 0) if running else (0, 0, 255) + ) # Green if running, red if not + + # Add velocity text to frame + cv2.putText( + frame, + f"Linear: {linear_vel:.2f} m/s", + (10, 30), + cv2.FONT_HERSHEY_SIMPLEX, + 0.7, + status_color, + 2, + ) + cv2.putText( + frame, + f"Angular: {angular_vel:.2f} rad/s", + (10, 60), + cv2.FONT_HERSHEY_SIMPLEX, + 0.7, + status_color, + 2, + ) + cv2.putText( + frame, + f"Tracking: {'ON' if running else 'OFF'}", + (10, 90), + cv2.FONT_HERSHEY_SIMPLEX, + 0.7, + status_color, + 2, + ) + + # If tracking is lost, reset selected_point and tracking_active + if not running: + selected_point = None + tracking_active = False + + # Display the frame in main thread + cv2.imshow("Person Tracking", frame) + + # Check for exit key + if cv2.waitKey(1) & 0xFF == ord("q"): + print("Exit key pressed") + break + + except queue.Empty: + # No frame available, check if we should continue + if cv2.waitKey(1) & 0xFF == ord("q"): + print("Exit key pressed") + break + continue + + except KeyboardInterrupt: + print("\nKeyboard interrupt received. Stopping...") + finally: + # Signal threads to stop + stop_event.set() + + # Clean up resources + if subscription: + subscription.dispose() + + visual_servoing.cleanup() + video_provider.dispose_all() + person_tracker.cleanup() + cv2.destroyAllWindows() + print("Cleanup complete") + + +if __name__ == "__main__": + main() diff --git a/tests/test_pick_and_place_module.py b/tests/test_pick_and_place_module.py new file mode 100644 index 0000000000..7ff5a689c1 --- /dev/null +++ b/tests/test_pick_and_place_module.py @@ -0,0 +1,355 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Run script for Piper Arm robot with pick and place functionality. +Subscribes to visualization images and handles mouse/keyboard input. +""" + +import asyncio +import sys +import threading +import time + +import cv2 +import numpy as np + +try: + import pyzed.sl as sl +except ImportError: + print("Error: ZED SDK not installed.") + sys.exit(1) + +# Import LCM message types +from dimos_lcm.sensor_msgs import Image + +from dimos.protocol.pubsub.lcmpubsub import LCM, Topic +from dimos.robot.agilex.piper_arm import PiperArmRobot +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + +# Global for mouse events +mouse_click = None +camera_mouse_click = None +current_window = None +pick_location = None # Store pick location +place_location = None # Store place location +place_mode = False # Track if we're in place selection mode + + +def mouse_callback(event, x, y, _flags, param): + global mouse_click, camera_mouse_click + window_name = param + if event == cv2.EVENT_LBUTTONDOWN: + if window_name == "Camera Feed": + camera_mouse_click = (x, y) + else: + mouse_click = (x, y) + + +class VisualizationNode: + """Node that subscribes to visualization images and handles user input.""" + + def __init__(self, robot: PiperArmRobot): + self.lcm = LCM() + self.latest_viz = None + self.latest_camera = None + self._running = False + self.robot = robot + + # Subscribe to visualization topic + self.viz_topic = Topic("/manipulation/viz", Image) + self.camera_topic = Topic("/zed/color_image", Image) + + def start(self): + """Start the visualization node.""" + self._running = True + self.lcm.start() + + # Subscribe to visualization topic + self.lcm.subscribe(self.viz_topic, self._on_viz_image) + # Subscribe to camera topic for point selection + self.lcm.subscribe(self.camera_topic, self._on_camera_image) + + logger.info("Visualization node started") + + def stop(self): + """Stop the visualization node.""" + self._running = False + cv2.destroyAllWindows() + + def _on_viz_image(self, msg: Image, topic: str): + """Handle visualization image messages.""" + try: + # Convert LCM message to numpy array + data = np.frombuffer(msg.data, dtype=np.uint8) + if msg.encoding == "rgb8": + image = data.reshape((msg.height, msg.width, 3)) + # Convert RGB to BGR for OpenCV + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + self.latest_viz = image + except Exception as e: + logger.error(f"Error processing viz image: {e}") + + def _on_camera_image(self, msg: Image, topic: str): + """Handle camera image messages.""" + try: + # Convert LCM message to numpy array + data = np.frombuffer(msg.data, dtype=np.uint8) + if msg.encoding == "rgb8": + image = data.reshape((msg.height, msg.width, 3)) + # Convert RGB to BGR for OpenCV + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + self.latest_camera = image + except Exception as e: + logger.error(f"Error processing camera image: {e}") + + def run_visualization(self): + """Run the visualization loop with user interaction.""" + global mouse_click, camera_mouse_click, pick_location, place_location, place_mode + + # Setup windows + cv2.namedWindow("Pick and Place") + cv2.setMouseCallback("Pick and Place", mouse_callback, "Pick and Place") + + cv2.namedWindow("Camera Feed") + cv2.setMouseCallback("Camera Feed", mouse_callback, "Camera Feed") + + print("=== Piper Arm Robot - Pick and Place ===") + print("Control mode: Module-based with LCM communication") + print("\nPICK AND PLACE WORKFLOW:") + print("1. Click on an object to select PICK location") + print("2. Click again to select PLACE location (auto pick & place)") + print("3. OR press 'p' after first click for pick-only task") + print("\nCONTROLS:") + print(" 'p' - Execute pick-only task (after selecting pick location)") + print(" 'r' - Reset everything") + print(" 'q' - Quit") + print(" 's' - SOFT STOP (emergency stop)") + print(" 'g' - RELEASE GRIPPER (open gripper)") + print(" 'SPACE' - EXECUTE target pose (manual override)") + print("\nNOTE: Click on objects in the Camera Feed window!") + + while self._running: + # Show camera feed with status overlay + if self.latest_camera is not None: + display_image = self.latest_camera.copy() + + # Add status text + status_text = "" + if pick_location is None: + status_text = "Click to select PICK location" + color = (0, 255, 0) + elif place_location is None: + status_text = "Click to select PLACE location (or press 'p' for pick-only)" + color = (0, 255, 255) + else: + status_text = "Executing pick and place..." + color = (255, 0, 255) + + cv2.putText( + display_image, status_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2 + ) + + # Draw pick location marker if set + if pick_location is not None: + # Simple circle marker + cv2.circle(display_image, pick_location, 10, (0, 255, 0), 2) + cv2.circle(display_image, pick_location, 2, (0, 255, 0), -1) + + # Simple label + cv2.putText( + display_image, + "PICK", + (pick_location[0] + 15, pick_location[1] + 5), + cv2.FONT_HERSHEY_SIMPLEX, + 0.6, + (0, 255, 0), + 2, + ) + + # Draw place location marker if set + if place_location is not None: + # Simple circle marker + cv2.circle(display_image, place_location, 10, (0, 255, 255), 2) + cv2.circle(display_image, place_location, 2, (0, 255, 255), -1) + + # Simple label + cv2.putText( + display_image, + "PLACE", + (place_location[0] + 15, place_location[1] + 5), + cv2.FONT_HERSHEY_SIMPLEX, + 0.6, + (0, 255, 255), + 2, + ) + + # Draw simple arrow between pick and place + if pick_location is not None: + cv2.arrowedLine( + display_image, + pick_location, + place_location, + (255, 255, 0), + 2, + tipLength=0.05, + ) + + cv2.imshow("Camera Feed", display_image) + + # Show visualization if available + if self.latest_viz is not None: + cv2.imshow("Pick and Place", self.latest_viz) + + # Handle keyboard input + key = cv2.waitKey(1) & 0xFF + if key != 255: # Key was pressed + if key == ord("q"): + logger.info("Quit requested") + self._running = False + break + elif key == ord("r"): + # Reset everything + pick_location = None + place_location = None + place_mode = False + logger.info("Reset pick and place selections") + # Also send reset to robot + action = self.robot.handle_keyboard_command("r") + if action: + logger.info(f"Action: {action}") + elif key == ord("p"): + # Execute pick-only task if pick location is set + if pick_location is not None: + logger.info(f"Executing pick-only task at {pick_location}") + result = self.robot.pick_and_place( + pick_location[0], + pick_location[1], + None, # No place location + None, + ) + logger.info(f"Pick task started: {result}") + # Clear selection after sending + pick_location = None + place_location = None + else: + logger.warning("Please select a pick location first!") + else: + # Send keyboard command to robot + if key in [82, 84]: # Arrow keys + action = self.robot.handle_keyboard_command(str(key)) + else: + action = self.robot.handle_keyboard_command(chr(key)) + if action: + logger.info(f"Action: {action}") + + # Handle mouse clicks + if camera_mouse_click: + x, y = camera_mouse_click + + if pick_location is None: + # First click - set pick location + pick_location = (x, y) + logger.info(f"Pick location set at ({x}, {y})") + elif place_location is None: + # Second click - set place location and execute + place_location = (x, y) + logger.info(f"Place location set at ({x}, {y})") + logger.info(f"Executing pick at {pick_location} and place at ({x}, {y})") + + # Start pick and place task with both locations + result = self.robot.pick_and_place(pick_location[0], pick_location[1], x, y) + logger.info(f"Pick and place task started: {result}") + + # Clear all points after sending mission + pick_location = None + place_location = None + + camera_mouse_click = None + + # Handle mouse click from Pick and Place window (if viz is running) + elif mouse_click and self.latest_viz is not None: + # Similar logic for viz window clicks + x, y = mouse_click + + if pick_location is None: + # First click - set pick location + pick_location = (x, y) + logger.info(f"Pick location set at ({x}, {y}) from viz window") + elif place_location is None: + # Second click - set place location and execute + place_location = (x, y) + logger.info(f"Place location set at ({x}, {y}) from viz window") + logger.info(f"Executing pick at {pick_location} and place at ({x}, {y})") + + # Start pick and place task with both locations + result = self.robot.pick_and_place(pick_location[0], pick_location[1], x, y) + logger.info(f"Pick and place task started: {result}") + + # Clear all points after sending mission + pick_location = None + place_location = None + + mouse_click = None + + time.sleep(0.03) # ~30 FPS + + +async def run_piper_arm_with_viz(): + """Run the Piper Arm robot with visualization.""" + logger.info("Starting Piper Arm Robot") + + # Create robot instance + robot = PiperArmRobot() + + try: + # Start the robot + await robot.start() + + # Give modules time to fully initialize + await asyncio.sleep(2) + + # Create and start visualization node + viz_node = VisualizationNode(robot) + viz_node.start() + + # Run visualization in separate thread + viz_thread = threading.Thread(target=viz_node.run_visualization, daemon=True) + viz_thread.start() + + # Keep running until visualization stops + while viz_node._running: + await asyncio.sleep(0.1) + + # Stop visualization + viz_node.stop() + + except Exception as e: + logger.error(f"Error running robot: {e}") + import traceback + + traceback.print_exc() + + finally: + # Clean up + robot.stop() + logger.info("Robot stopped") + + +if __name__ == "__main__": + # Run the robot + asyncio.run(run_piper_arm_with_viz()) diff --git a/tests/test_pick_and_place_skill.py b/tests/test_pick_and_place_skill.py new file mode 100644 index 0000000000..1e6d14e780 --- /dev/null +++ b/tests/test_pick_and_place_skill.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Run script for Piper Arm robot with pick and place functionality. +Uses hardcoded points and the PickAndPlace skill. +""" + +import asyncio +import sys + +try: + import pyzed.sl as sl # Required for ZED camera +except ImportError: + print("Error: ZED SDK not installed.") + sys.exit(1) + +from dimos.robot.agilex.piper_arm import PiperArmRobot +from dimos.skills.manipulation.pick_and_place import PickAndPlace +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +async def run_piper_arm(): + """Run the Piper Arm robot with pick and place skill.""" + logger.info("Starting Piper Arm Robot") + + # Create robot instance + robot = PiperArmRobot() + + try: + # Start the robot + await robot.start() + + # Give modules time to fully initialize + await asyncio.sleep(3) + + # Add the PickAndPlace skill to the robot's skill library + robot.skill_library.add(PickAndPlace) + + logger.info("Robot initialized successfully") + print("\n=== Piper Arm Robot - Pick and Place Demo ===") + print("This demo uses hardcoded pick and place points.") + print("\nCommands:") + print(" 1. Run pick and place with hardcoded points") + print(" 2. Run pick-only with hardcoded point") + print(" r. Reset robot to idle") + print(" q. Quit") + print("") + + running = True + while running: + try: + # Get user input + command = input("\nEnter command: ").strip().lower() + + if command == "q": + logger.info("Quit requested") + running = False + break + + elif command == "r" or command == "s": + logger.info("Resetting robot") + robot.handle_keyboard_command(command) + + elif command == "1": + # Hardcoded pick and place points + # These should be adjusted based on your camera view + print("\nExecuting pick and place with hardcoded points...") + + # Create and execute the skill + skill = PickAndPlace( + robot=robot, + object_query="labubu doll", # Will use visual detection + target_query="on the keyboard", # Will use visual detection + ) + + result = skill() + + if result["success"]: + print(f"✓ {result['message']}") + else: + print(f"✗ Failed: {result.get('error', 'Unknown error')}") + + elif command == "2": + # Pick-only with hardcoded point + print("\nExecuting pick-only with hardcoded point...") + + # Create and execute the skill for pick-only + skill = PickAndPlace( + robot=robot, + object_query="labubu doll", # Will use visual detection + target_query=None, # No place target - pick only + ) + + result = skill() + + if result["success"]: + print(f"✓ {result['message']}") + else: + print(f"✗ Failed: {result.get('error', 'Unknown error')}") + + else: + print("Invalid command. Please try again.") + + # Small delay to prevent CPU spinning + await asyncio.sleep(0.1) + + except KeyboardInterrupt: + logger.info("Keyboard interrupt received") + running = False + break + except Exception as e: + logger.error(f"Error in command loop: {e}") + print(f"Error: {e}") + + except Exception as e: + logger.error(f"Error running robot: {e}") + import traceback + + traceback.print_exc() + + finally: + # Clean up + logger.info("Shutting down robot...") + await robot.stop() + logger.info("Robot stopped") + + +def main(): + """Main entry point.""" + print("Starting Piper Arm Robot...") + print("Note: The robot will use Qwen VLM to identify objects and locations") + print("based on the queries specified in the code.") + + # Run the robot + asyncio.run(run_piper_arm()) + + +if __name__ == "__main__": + main() diff --git a/tests/test_pointcloud_filtering.py b/tests/test_pointcloud_filtering.py new file mode 100644 index 0000000000..a5300eaa1e --- /dev/null +++ b/tests/test_pointcloud_filtering.py @@ -0,0 +1,101 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + +from pyzed import sl +from reactivex import operators as ops + +from dimos.manipulation.manip_aio_pipeline import ManipulationPipeline +from dimos.stream.stereo_camera_streams.zed import ZEDCameraStream +from dimos.web.robot_web_interface import RobotWebInterface + + +def main(): + """Test point cloud filtering using the concurrent stream-based ManipulationPipeline.""" + print("Testing point cloud filtering with ManipulationPipeline...") + + # Configuration + min_confidence = 0.6 + web_port = 5555 + + try: + # Initialize ZED camera stream + zed_stream = ZEDCameraStream(resolution=sl.RESOLUTION.HD1080, fps=10) + + # Get camera intrinsics + camera_intrinsics_dict = zed_stream.get_camera_info() + camera_intrinsics = [ + camera_intrinsics_dict["fx"], + camera_intrinsics_dict["fy"], + camera_intrinsics_dict["cx"], + camera_intrinsics_dict["cy"], + ] + + # Create the concurrent manipulation pipeline + pipeline = ManipulationPipeline( + camera_intrinsics=camera_intrinsics, + min_confidence=min_confidence, + max_objects=10, + ) + + # Create ZED stream + zed_frame_stream = zed_stream.create_stream().pipe(ops.share()) + + # Create concurrent processing streams + streams = pipeline.create_streams(zed_frame_stream) + detection_viz_stream = streams["detection_viz"] + pointcloud_viz_stream = streams["pointcloud_viz"] + + except ImportError: + print("Error: ZED SDK not installed. Please install pyzed package.") + sys.exit(1) + except RuntimeError as e: + print(f"Error: Failed to open ZED camera: {e}") + sys.exit(1) + + try: + # Set up web interface with concurrent visualization streams + print("Initializing web interface...") + web_interface = RobotWebInterface( + port=web_port, + object_detection=detection_viz_stream, + pointcloud_stream=pointcloud_viz_stream, + ) + + print("\nPoint Cloud Filtering Test Running:") + print(f"Web Interface: http://localhost:{web_port}") + print("Object Detection View: RGB with bounding boxes") + print("Point Cloud View: Depth with colored point clouds and 3D bounding boxes") + print(f"Confidence threshold: {min_confidence}") + print("\nPress Ctrl+C to stop the test\n") + + # Start web server (blocking call) + web_interface.run() + + except KeyboardInterrupt: + print("\nTest interrupted by user") + except Exception as e: + print(f"Error during test: {e}") + finally: + print("Cleaning up resources...") + if "zed_stream" in locals(): + zed_stream.cleanup() + if "pipeline" in locals(): + pipeline.cleanup() + print("Test completed") + + +if __name__ == "__main__": + main() diff --git a/tests/test_qwen_image_query.py b/tests/test_qwen_image_query.py new file mode 100644 index 0000000000..1f77bc2b02 --- /dev/null +++ b/tests/test_qwen_image_query.py @@ -0,0 +1,62 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the Qwen image query functionality.""" + +import os + +import cv2 +import numpy as np +from PIL import Image + +from dimos.models.qwen.video_query import query_single_frame + + +def test_qwen_image_query(): + """Test querying Qwen with a single image.""" + # Skip if no API key + if not os.getenv("ALIBABA_API_KEY"): + print("ALIBABA_API_KEY not set") + return + + # Load test image + image_path = os.path.join(os.getcwd(), "assets", "test_spatial_memory", "frame_038.jpg") + pil_image = Image.open(image_path) + + # Convert PIL image to numpy array in RGB format + image_array = np.array(pil_image) + if image_array.shape[-1] == 3: + # Ensure it's in RGB format (PIL loads as RGB by default) + image = image_array + else: + # Handle grayscale images + image = cv2.cvtColor(image_array, cv2.COLOR_GRAY2RGB) + + # Test basic object detection query + response = query_single_frame( + image=image, + query="What objects do you see in this image? Return as a comma-separated list.", + ) + print(response) + + # Test coordinate query + response = query_single_frame( + image=image, + query="Return the center coordinates of any person in the image as a tuple (x,y)", + ) + print(response) + + +if __name__ == "__main__": + test_qwen_image_query() diff --git a/tests/test_rtsp_video_provider.py b/tests/test_rtsp_video_provider.py new file mode 100644 index 0000000000..3de06f5f0b --- /dev/null +++ b/tests/test_rtsp_video_provider.py @@ -0,0 +1,142 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import numpy as np +import reactivex as rx +from reactivex import operators as ops + +from dimos.stream.frame_processor import FrameProcessor +from dimos.stream.rtsp_video_provider import RtspVideoProvider +from dimos.stream.video_operators import VideoOperators as vops +from dimos.stream.video_provider import get_scheduler +from dimos.utils.logging_config import setup_logger +from dimos.web.robot_web_interface import RobotWebInterface + +logger = setup_logger() + +import os +import sys + +# Load environment variables from .env file +from dotenv import load_dotenv + +load_dotenv() + +# RTSP URL must be provided as a command-line argument or environment variable +RTSP_URL = os.environ.get("TEST_RTSP_URL", "") +if len(sys.argv) > 1: + RTSP_URL = sys.argv[1] # Allow overriding with command-line argument +elif RTSP_URL == "": + print("Please provide an RTSP URL for testing.") + print( + "You can set the TEST_RTSP_URL environment variable or pass it as a command-line argument." + ) + print("Example: python -m dimos.stream.rtsp_video_provider rtsp://...") + sys.exit(1) + +logger.info("Attempting to connect to provided RTSP URL.") +provider = RtspVideoProvider(dev_name="TestRtspCam", rtsp_url=RTSP_URL) + +logger.info("Creating observable...") +video_stream_observable = provider.capture_video_as_observable() + +logger.info("Subscribing to observable...") +frame_counter = 0 +start_time = time.monotonic() # Re-initialize start_time +last_log_time = start_time # Keep this for interval timing + +# Create a subject for ffmpeg responses +ffmpeg_response_subject = rx.subject.Subject() +ffmpeg_response_stream = ffmpeg_response_subject.pipe(ops.observe_on(get_scheduler()), ops.share()) + + +def process_frame(frame: np.ndarray): + """Callback function executed for each received frame.""" + global frame_counter, last_log_time, start_time # Add start_time to global + frame_counter += 1 + current_time = time.monotonic() + # Log stats periodically (e.g., every 5 seconds) + if current_time - last_log_time >= 5.0: + total_elapsed_time = current_time - start_time # Calculate total elapsed time + avg_fps = frame_counter / total_elapsed_time if total_elapsed_time > 0 else 0 + logger.info(f"Received frame {frame_counter}. Shape: {frame.shape}. Avg FPS: {avg_fps:.2f}") + ffmpeg_response_subject.on_next( + f"Received frame {frame_counter}. Shape: {frame.shape}. Avg FPS: {avg_fps:.2f}" + ) + last_log_time = current_time # Update log time for the next interval + + +def handle_error(error: Exception): + """Callback function executed if the observable stream errors.""" + logger.error(f"Stream error: {error}", exc_info=True) # Log with traceback + + +def handle_completion(): + """Callback function executed when the observable stream completes.""" + logger.info("Stream completed.") + + +# Subscribe to the observable stream +processor = FrameProcessor() +subscription = video_stream_observable.pipe( + # ops.subscribe_on(get_scheduler()), + ops.observe_on(get_scheduler()), + ops.share(), + vops.with_jpeg_export(processor, suffix="reolink_", save_limit=30, loop=True), +).subscribe(on_next=process_frame, on_error=handle_error, on_completed=handle_completion) + +streams = {"reolink_video": video_stream_observable} +text_streams = { + "ffmpeg_responses": ffmpeg_response_stream, +} + +web_interface = RobotWebInterface(port=5555, text_streams=text_streams, **streams) + +web_interface.run() # This may block the main thread + +# TODO: Redo disposal / keep-alive loop + +# Keep the main thread alive to receive frames (e.g., for 60 seconds) +print("Stream running. Press Ctrl+C to stop...") +try: + # Keep running indefinitely until interrupted + while True: + time.sleep(1) + # Optional: Check if subscription is still active + # if not subscription.is_disposed: + # # logger.debug("Subscription active...") + # pass + # else: + # logger.warning("Subscription was disposed externally.") + # break + +except KeyboardInterrupt: + print("KeyboardInterrupt received. Shutting down...") +finally: + # Ensure resources are cleaned up regardless of how the loop exits + print("Disposing subscription...") + # subscription.dispose() + print("Disposing provider resources...") + provider.dispose_all() + print("Cleanup finished.") + +# Final check (optional, for debugging) +time.sleep(1) # Give background threads a moment +final_process = provider._ffmpeg_process +if final_process and final_process.poll() is None: + print(f"WARNING: ffmpeg process (PID: {final_process.pid}) may still be running after cleanup!") +else: + print("ffmpeg process appears terminated.") diff --git a/tests/test_semantic_seg_robot.py b/tests/test_semantic_seg_robot.py new file mode 100644 index 0000000000..5636cc4ba8 --- /dev/null +++ b/tests/test_semantic_seg_robot.py @@ -0,0 +1,151 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import queue +import sys +import threading + +import cv2 +import numpy as np + +# Add the parent directory to the Python path +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from reactivex import operators as RxOps + +from dimos.perception.semantic_seg import SemanticSegmentationStream +from dimos.robot.unitree.unitree_go2 import UnitreeGo2 +from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl +from dimos.stream.frame_processor import FrameProcessor +from dimos.stream.video_operators import Operators as MyOps +from dimos.web.robot_web_interface import RobotWebInterface + + +def main(): + # Create a queue for thread communication (limit to prevent memory issues) + frame_queue = queue.Queue(maxsize=5) + stop_event = threading.Event() + + # Unitree Go2 camera parameters at 1080p + camera_params = { + "resolution": (1920, 1080), # 1080p resolution + "focal_length": 3.2, # mm + "sensor_size": (4.8, 3.6), # mm (1/4" sensor) + } + + # Initialize video provider and segmentation stream + # video_provider = VideoProvider("test_camera", video_source=0) + robot = UnitreeGo2( + ip=os.getenv("ROBOT_IP"), + ros_control=UnitreeROSControl(), + ) + + seg_stream = SemanticSegmentationStream( + enable_mono_depth=False, camera_params=camera_params, gt_depth_scale=512.0 + ) + + # Create streams + video_stream = robot.get_ros_video_stream(fps=5) + segmentation_stream = seg_stream.create_stream(video_stream) + + # Define callbacks for the segmentation stream + def on_next(segmentation): + if stop_event.is_set(): + return + # Get the frame and visualize + vis_frame = segmentation.metadata["viz_frame"] + depth_viz = segmentation.metadata["depth_viz"] + # Get the image dimensions + height, width = vis_frame.shape[:2] + depth_height, depth_width = depth_viz.shape[:2] + + # Resize depth visualization to match segmentation height + # (maintaining aspect ratio if needed) + depth_resized = cv2.resize(depth_viz, (int(depth_width * height / depth_height), height)) + + # Create a combined frame for side-by-side display + combined_viz = np.hstack((vis_frame, depth_resized)) + + # Add labels + font = cv2.FONT_HERSHEY_SIMPLEX + cv2.putText(combined_viz, "Semantic Segmentation", (10, 30), font, 0.8, (255, 255, 255), 2) + cv2.putText( + combined_viz, "Depth Estimation", (width + 10, 30), font, 0.8, (255, 255, 255), 2 + ) + + # Put frame in queue for main thread to display (non-blocking) + try: + frame_queue.put_nowait(combined_viz) + except queue.Full: + # Skip frame if queue is full + pass + + def on_error(error): + print(f"Error: {error}") + stop_event.set() + + def on_completed(): + print("Stream completed") + stop_event.set() + + # Start the subscription + subscription = None + + try: + # Subscribe to start processing in background thread + print_emission_args = { + "enabled": True, + "dev_name": "SemanticSegmentation", + "counts": {}, + } + + FrameProcessor(delete_on_init=True) + subscription = segmentation_stream.pipe( + MyOps.print_emission(id="A", **print_emission_args), + RxOps.share(), + MyOps.print_emission(id="B", **print_emission_args), + RxOps.map(lambda x: x.metadata["viz_frame"] if x is not None else None), + MyOps.print_emission(id="C", **print_emission_args), + RxOps.filter(lambda x: x is not None), + MyOps.print_emission(id="D", **print_emission_args), + # MyVideoOps.with_jpeg_export(frame_processor=frame_processor, suffix="_frame_"), + MyOps.print_emission(id="E", **print_emission_args), + ) + + print("Semantic segmentation visualization started. Press 'q' to exit.") + + streams = { + "segmentation_stream": subscription, + } + fast_api_server = RobotWebInterface(port=5555, **streams) + fast_api_server.run() + + except KeyboardInterrupt: + print("\nKeyboard interrupt received. Stopping...") + finally: + # Signal threads to stop + stop_event.set() + + # Clean up resources + if subscription: + subscription.dispose() + + seg_stream.cleanup() + cv2.destroyAllWindows() + print("Cleanup complete") + + +if __name__ == "__main__": + main() diff --git a/tests/test_semantic_seg_robot_agent.py b/tests/test_semantic_seg_robot_agent.py new file mode 100644 index 0000000000..a4255d4169 --- /dev/null +++ b/tests/test_semantic_seg_robot_agent.py @@ -0,0 +1,139 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import cv2 +from reactivex import Subject, operators as RxOps + +from dimos.agents.agent import OpenAIAgent +from dimos.perception.semantic_seg import SemanticSegmentationStream +from dimos.robot.unitree.unitree_go2 import UnitreeGo2 +from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl +from dimos.robot.unitree.unitree_skills import MyUnitreeSkills +from dimos.stream.frame_processor import FrameProcessor +from dimos.stream.video_operators import VideoOperators as MyVideoOps +from dimos.utils.threadpool import get_scheduler +from dimos.web.robot_web_interface import RobotWebInterface + + +def main(): + # Unitree Go2 camera parameters at 1080p + camera_params = { + "resolution": (1920, 1080), # 1080p resolution + "focal_length": 3.2, # mm + "sensor_size": (4.8, 3.6), # mm (1/4" sensor) + } + + robot = UnitreeGo2( + ip=os.getenv("ROBOT_IP"), ros_control=UnitreeROSControl(), skills=MyUnitreeSkills() + ) + + seg_stream = SemanticSegmentationStream( + enable_mono_depth=True, camera_params=camera_params, gt_depth_scale=512.0 + ) + + # Create streams + video_stream = robot.get_ros_video_stream(fps=5) + segmentation_stream = seg_stream.create_stream( + video_stream.pipe(MyVideoOps.with_fps_sampling(fps=0.5)) + ) + # Throttling to slowdown SegmentationAgent calls + # TODO: add Agent parameter to handle this called api_call_interval + + FrameProcessor(delete_on_init=True) + seg_stream = segmentation_stream.pipe( + RxOps.share(), + RxOps.map(lambda x: x.metadata["viz_frame"] if x is not None else None), + RxOps.filter(lambda x: x is not None), + # MyVideoOps.with_jpeg_export(frame_processor=frame_processor, suffix="_frame_"), # debugging + ) + + depth_stream = segmentation_stream.pipe( + RxOps.share(), + RxOps.map(lambda x: x.metadata["depth_viz"] if x is not None else None), + RxOps.filter(lambda x: x is not None), + ) + + object_stream = segmentation_stream.pipe( + RxOps.share(), + RxOps.map(lambda x: x.metadata["objects"] if x is not None else None), + RxOps.filter(lambda x: x is not None), + RxOps.map( + lambda objects: "\n".join( + f"Object {obj['object_id']}: {obj['label']} (confidence: {obj['prob']:.2f})" + + (f", depth: {obj['depth']:.2f}m" if "depth" in obj else "") + for obj in objects + ) + if objects + else "No objects detected." + ), + ) + + text_query_stream = Subject() + + # Combine text query with latest object data when a new text query arrives + enriched_query_stream = text_query_stream.pipe( + RxOps.with_latest_from(object_stream), + RxOps.map( + lambda combined: { + "query": combined[0], + "objects": combined[1] if len(combined) > 1 else "No object data available", + } + ), + RxOps.map(lambda data: f"{data['query']}\n\nCurrent objects detected:\n{data['objects']}"), + RxOps.do_action( + lambda x: print(f"\033[34mEnriched query: {x.split(chr(10))[0]}\033[0m") + or [print(f"\033[34m{line}\033[0m") for line in x.split(chr(10))[1:]] + ), + ) + + segmentation_agent = OpenAIAgent( + dev_name="SemanticSegmentationAgent", + model_name="gpt-4o", + system_query="You are a helpful assistant that can control a virtual robot with semantic segmentation / distnace data as a guide. Only output skill calls, no other text", + input_query_stream=enriched_query_stream, + process_all_inputs=False, + pool_scheduler=get_scheduler(), + skills=robot.get_skills(), + ) + agent_response_stream = segmentation_agent.get_response_observable() + + print("Semantic segmentation visualization started. Press 'q' to exit.") + + streams = { + "raw_stream": video_stream, + "depth_stream": depth_stream, + "seg_stream": seg_stream, + } + text_streams = { + "object_stream": object_stream, + "enriched_query_stream": enriched_query_stream, + "agent_response_stream": agent_response_stream, + } + + try: + fast_api_server = RobotWebInterface(port=5555, text_streams=text_streams, **streams) + fast_api_server.query_stream.subscribe(lambda x: text_query_stream.on_next(x)) + fast_api_server.run() + except KeyboardInterrupt: + print("\nKeyboard interrupt received. Stopping...") + finally: + seg_stream.cleanup() + cv2.destroyAllWindows() + print("Cleanup complete") + + +if __name__ == "__main__": + main() diff --git a/tests/test_semantic_seg_webcam.py b/tests/test_semantic_seg_webcam.py new file mode 100644 index 0000000000..08d15abe72 --- /dev/null +++ b/tests/test_semantic_seg_webcam.py @@ -0,0 +1,141 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import queue +import sys +import threading + +import cv2 +import numpy as np + +# Add the parent directory to the Python path +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from dimos.perception.semantic_seg import SemanticSegmentationStream +from dimos.stream.video_provider import VideoProvider + + +def main(): + # Create a queue for thread communication (limit to prevent memory issues) + frame_queue = queue.Queue(maxsize=5) + stop_event = threading.Event() + + # Logitech C920e camera parameters at 480p + camera_params = { + "resolution": (640, 480), # 480p resolution + "focal_length": 3.67, # mm + "sensor_size": (4.8, 3.6), # mm (1/4" sensor) + } + + # Initialize video provider and segmentation stream + video_provider = VideoProvider("test_camera", video_source=0) + seg_stream = SemanticSegmentationStream( + enable_mono_depth=True, camera_params=camera_params, gt_depth_scale=512.0 + ) + + # Create streams + video_stream = video_provider.capture_video_as_observable(realtime=False, fps=5) + segmentation_stream = seg_stream.create_stream(video_stream) + + # Define callbacks for the segmentation stream + def on_next(segmentation): + if stop_event.is_set(): + return + + # Get the frame and visualize + vis_frame = segmentation.metadata["viz_frame"] + depth_viz = segmentation.metadata["depth_viz"] + # Get the image dimensions + height, width = vis_frame.shape[:2] + depth_height, depth_width = depth_viz.shape[:2] + + # Resize depth visualization to match segmentation height + # (maintaining aspect ratio if needed) + depth_resized = cv2.resize(depth_viz, (int(depth_width * height / depth_height), height)) + + # Create a combined frame for side-by-side display + combined_viz = np.hstack((vis_frame, depth_resized)) + + # Add labels + font = cv2.FONT_HERSHEY_SIMPLEX + cv2.putText(combined_viz, "Semantic Segmentation", (10, 30), font, 0.8, (255, 255, 255), 2) + cv2.putText( + combined_viz, "Depth Estimation", (width + 10, 30), font, 0.8, (255, 255, 255), 2 + ) + + # Put frame in queue for main thread to display (non-blocking) + try: + frame_queue.put_nowait(combined_viz) + except queue.Full: + # Skip frame if queue is full + pass + + def on_error(error): + print(f"Error: {error}") + stop_event.set() + + def on_completed(): + print("Stream completed") + stop_event.set() + + # Start the subscription + subscription = None + + try: + # Subscribe to start processing in background thread + subscription = segmentation_stream.subscribe( + on_next=on_next, on_error=on_error, on_completed=on_completed + ) + + print("Semantic segmentation visualization started. Press 'q' to exit.") + + # Main thread loop for displaying frames + while not stop_event.is_set(): + try: + # Get frame with timeout (allows checking stop_event periodically) + combined_viz = frame_queue.get(timeout=1.0) + + # Display the frame in main thread + cv2.imshow("Semantic Segmentation", combined_viz) + # Check for exit key + if cv2.waitKey(1) & 0xFF == ord("q"): + print("Exit key pressed") + break + + except queue.Empty: + # No frame available, check if we should continue + if cv2.waitKey(1) & 0xFF == ord("q"): + print("Exit key pressed") + break + continue + + except KeyboardInterrupt: + print("\nKeyboard interrupt received. Stopping...") + finally: + # Signal threads to stop + stop_event.set() + + # Clean up resources + if subscription: + subscription.dispose() + + video_provider.dispose_all() + seg_stream.cleanup() + cv2.destroyAllWindows() + print("Cleanup complete") + + +if __name__ == "__main__": + main() diff --git a/tests/test_skills.py b/tests/test_skills.py new file mode 100644 index 0000000000..f10dabed49 --- /dev/null +++ b/tests/test_skills.py @@ -0,0 +1,182 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the skills module in the dimos package.""" + +import unittest +from unittest import mock + +from dimos.agents.agent import OpenAIAgent +from dimos.robot.robot import MockRobot +from dimos.robot.unitree.unitree_skills import MyUnitreeSkills +from dimos.skills.skills import AbstractSkill + + +class TestSkill(AbstractSkill): + """A test skill that tracks its execution for testing purposes.""" + + _called: bool = False + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._called = False + + def __call__(self): + self._called = True + return "TestSkill executed successfully" + + +class SkillLibraryTest(unittest.TestCase): + """Tests for the SkillLibrary functionality.""" + + def setUp(self): + """Set up test fixtures before each test method.""" + self.robot = MockRobot() + self.skill_library = MyUnitreeSkills(robot=self.robot) + self.skill_library.initialize_skills() + + def test_skill_iteration(self): + """Test that skills can be properly iterated in the skill library.""" + skills_count = 0 + for skill in self.skill_library: + skills_count += 1 + self.assertTrue(hasattr(skill, "__name__")) + self.assertTrue(issubclass(skill, AbstractSkill)) + + self.assertGreater(skills_count, 0, "Skill library should contain at least one skill") + + def test_skill_registration(self): + """Test that skills can be properly registered in the skill library.""" + # Clear existing skills for isolated test + self.skill_library = MyUnitreeSkills(robot=self.robot) + original_count = len(list(self.skill_library)) + + # Add a custom test skill + test_skill = TestSkill + self.skill_library.add(test_skill) + + # Verify the skill was added + new_count = len(list(self.skill_library)) + self.assertEqual(new_count, original_count + 1) + + # Check if the skill can be found by name + found = False + for skill in self.skill_library: + if skill.__name__ == "TestSkill": + found = True + break + self.assertTrue(found, "Added skill should be found in skill library") + + def test_skill_direct_execution(self): + """Test that a skill can be executed directly.""" + test_skill = TestSkill() + self.assertFalse(test_skill._called) + result = test_skill() + self.assertTrue(test_skill._called) + self.assertEqual(result, "TestSkill executed successfully") + + def test_skill_library_execution(self): + """Test that a skill can be executed through the skill library.""" + # Add our test skill to the library + test_skill = TestSkill + self.skill_library.add(test_skill) + + # Create an instance to confirm it was executed + with mock.patch.object(TestSkill, "__call__", return_value="Success") as mock_call: + result = self.skill_library.call("TestSkill") + mock_call.assert_called_once() + self.assertEqual(result, "Success") + + def test_skill_not_found(self): + """Test that calling a non-existent skill raises an appropriate error.""" + with self.assertRaises(ValueError): + self.skill_library.call("NonExistentSkill") + + +class SkillWithAgentTest(unittest.TestCase): + """Tests for skills used with an agent.""" + + def setUp(self): + """Set up test fixtures before each test method.""" + self.robot = MockRobot() + self.skill_library = MyUnitreeSkills(robot=self.robot) + self.skill_library.initialize_skills() + + # Add a test skill + self.skill_library.add(TestSkill) + + # Create the agent + self.agent = OpenAIAgent( + dev_name="SkillTestAgent", + system_query="You are a skill testing agent. When prompted to perform an action, use the appropriate skill.", + skills=self.skill_library, + ) + + @mock.patch("dimos.agents.agent.OpenAIAgent.run_observable_query") + def test_agent_skill_identification(self, mock_query): + """Test that the agent can identify skills based on natural language.""" + # Mock the agent response + mock_response = mock.MagicMock() + mock_response.run.return_value = "I found the TestSkill and executed it." + mock_query.return_value = mock_response + + # Run the test + response = self.agent.run_observable_query("Please run the test skill").run() + + # Assertions + mock_query.assert_called_once_with("Please run the test skill") + self.assertEqual(response, "I found the TestSkill and executed it.") + + @mock.patch.object(TestSkill, "__call__") + @mock.patch("dimos.agents.agent.OpenAIAgent.run_observable_query") + def test_agent_skill_execution(self, mock_query, mock_skill_call): + """Test that the agent can execute skills properly.""" + # Mock the agent and skill call + mock_skill_call.return_value = "TestSkill executed successfully" + mock_response = mock.MagicMock() + mock_response.run.return_value = "Executed TestSkill successfully." + mock_query.return_value = mock_response + + # Run the test + response = self.agent.run_observable_query("Execute the TestSkill skill").run() + + # We can't directly verify the skill was called since our mocking setup + # doesn't capture the internal skill execution of the agent, but we can + # verify the agent was properly called + mock_query.assert_called_once_with("Execute the TestSkill skill") + self.assertEqual(response, "Executed TestSkill successfully.") + + def test_agent_multi_skill_registration(self): + """Test that multiple skills can be registered with an agent.""" + + # Create a new skill + class AnotherTestSkill(AbstractSkill): + def __call__(self): + return "Another test skill executed" + + # Register the new skill + initial_count = len(list(self.skill_library)) + self.skill_library.add(AnotherTestSkill) + + # Verify two distinct skills now exist + self.assertEqual(len(list(self.skill_library)), initial_count + 1) + + # Verify both skills are found by name + skill_names = [skill.__name__ for skill in self.skill_library] + self.assertIn("TestSkill", skill_names) + self.assertIn("AnotherTestSkill", skill_names) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_skills_rest.py b/tests/test_skills_rest.py new file mode 100644 index 0000000000..5aa098ca23 --- /dev/null +++ b/tests/test_skills_rest.py @@ -0,0 +1,72 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from textwrap import dedent + +from dotenv import load_dotenv +import reactivex as rx +import reactivex.operators as ops + +from dimos.agents.claude_agent import ClaudeAgent +from dimos.skills.rest.rest import GenericRestSkill +from dimos.skills.skills import SkillLibrary +from dimos.web.robot_web_interface import RobotWebInterface + +# Load API key from environment +load_dotenv() + +# Create a skill library and add the GenericRestSkill +skills = SkillLibrary() +skills.add(GenericRestSkill) + +# Create a subject for agent responses +agent_response_subject = rx.subject.Subject() +agent_response_stream = agent_response_subject.pipe(ops.share()) + +# Create a text stream for agent responses in the web interface +text_streams = { + "agent_responses": agent_response_stream, +} +web_interface = RobotWebInterface(port=5555, text_streams=text_streams) + +# Create a ClaudeAgent instance +agent = ClaudeAgent( + dev_name="test_agent", + input_query_stream=web_interface.query_stream, + skills=skills, + system_query=dedent( + """ + You are a virtual agent. When given a query, respond by using + the appropriate tool calls if needed to execute commands on the robot. + + IMPORTANT: + Only return the response directly asked of the user. E.G. if the user asks for the time, + only return the time. If the user asks for the weather, only return the weather. + """ + ), + model_name="claude-3-7-sonnet-latest", + thinking_budget_tokens=2000, +) + +# Subscribe to agent responses and send them to the subject +agent.get_response_observable().subscribe(lambda x: agent_response_subject.on_next(x)) + +# Start the web interface +web_interface.run() + +# Run this query in the web interface: +# +# Make a web request to nist to get the current time. +# You should use http://worldclockapi.com/api/json/utc/now +# diff --git a/tests/test_spatial_memory.py b/tests/test_spatial_memory.py new file mode 100644 index 0000000000..f525b016e0 --- /dev/null +++ b/tests/test_spatial_memory.py @@ -0,0 +1,306 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import time + +import chromadb +import cv2 +from matplotlib.patches import Circle +import matplotlib.pyplot as plt +import reactivex +from reactivex import operators as ops + +from dimos.agents.memory.visual_memory import VisualMemory +from dimos.msgs.geometry_msgs import Quaternion, Vector3 + +# from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2 # Uncomment when properly configured +from dimos.perception.spatial_perception import SpatialMemory + + +def extract_pose_data(transform): + """Extract position and rotation from a transform message""" + if transform is None: + return None, None + + pos = transform.transform.translation + rot = transform.transform.rotation + + # Convert to Vector3 objects expected by SpatialMemory + position = Vector3(x=pos.x, y=pos.y, z=pos.z) + + # Convert quaternion to euler angles for rotation vector + quat = Quaternion(x=rot.x, y=rot.y, z=rot.z, w=rot.w) + euler = quat.to_euler() + rotation = Vector3(x=euler.x, y=euler.y, z=euler.z) + + return position, rotation + + +def setup_persistent_chroma_db(db_path="chromadb_data"): + """ + Set up a persistent ChromaDB database at the specified path. + + Args: + db_path: Path to store the ChromaDB database + + Returns: + The ChromaDB client instance + """ + # Create a persistent ChromaDB client + full_db_path = os.path.join("/home/stash/dimensional/dimos/assets/test_spatial_memory", db_path) + print(f"Setting up persistent ChromaDB at: {full_db_path}") + + # Ensure the directory exists + os.makedirs(full_db_path, exist_ok=True) + + return chromadb.PersistentClient(path=full_db_path) + + +def main(): + print("Starting spatial memory test...") + + # Create counters for tracking + frame_count = 0 + transform_count = 0 + stored_count = 0 + + print("Note: This test requires proper robot connection setup.") + print("Please ensure video_stream and transform_stream are properly configured.") + + # These need to be set up based on your specific robot configuration + video_stream = None # TODO: Set up video stream from robot + transform_stream = None # TODO: Set up transform stream from robot + + if video_stream is None or transform_stream is None: + print("\nWARNING: Video or transform streams not configured.") + print("Exiting test. Please configure streams properly.") + return + + # Setup output directory for visual memory + visual_memory_dir = "/home/stash/dimensional/dimos/assets/test_spatial_memory" + os.makedirs(visual_memory_dir, exist_ok=True) + + # Setup persistent storage path for visual memory + visual_memory_path = os.path.join(visual_memory_dir, "visual_memory.pkl") + + # Try to load existing visual memory if it exists + if os.path.exists(visual_memory_path): + try: + print(f"Loading existing visual memory from {visual_memory_path}...") + visual_memory = VisualMemory.load(visual_memory_path, output_dir=visual_memory_dir) + print(f"Loaded {visual_memory.count()} images from previous runs") + except Exception as e: + print(f"Error loading visual memory: {e}") + visual_memory = VisualMemory(output_dir=visual_memory_dir) + else: + print("No existing visual memory found. Starting with empty visual memory.") + visual_memory = VisualMemory(output_dir=visual_memory_dir) + + # Setup a persistent database for ChromaDB + db_client = setup_persistent_chroma_db() + + # Create spatial perception instance with persistent storage + print("Creating SpatialMemory with persistent vector database...") + spatial_memory = SpatialMemory( + collection_name="test_spatial_memory", + min_distance_threshold=1, # Store frames every 1 meter + min_time_threshold=1, # Store frames at least every 1 second + chroma_client=db_client, # Use the persistent client + visual_memory=visual_memory, # Use the visual memory we loaded or created + ) + + # Combine streams using combine_latest + # This will pair up items properly without buffering + combined_stream = reactivex.combine_latest(video_stream, transform_stream).pipe( + ops.map( + lambda pair: { + "frame": pair[0], # First element is the frame + "position": extract_pose_data(pair[1])[0], # Position as Vector3 + "rotation": extract_pose_data(pair[1])[1], # Rotation as Vector3 + } + ), + ops.filter(lambda data: data["position"] is not None and data["rotation"] is not None), + ) + + # Process with spatial memory + result_stream = spatial_memory.process_stream(combined_stream) + + # Simple callback to track stored frames and save them to the assets directory + def on_stored_frame(result): + nonlocal stored_count + # Only count actually stored frames (not debug frames) + if not not result.get("stored", True): + stored_count += 1 + pos = result["position"] + if isinstance(pos, tuple): + print( + f"\nStored frame #{stored_count} at ({pos[0]:.2f}, {pos[1]:.2f}, {pos[2]:.2f})" + ) + else: + print(f"\nStored frame #{stored_count} at position {pos}") + + # Save the frame to the assets directory + if "frame" in result: + frame_filename = f"/home/stash/dimensional/dimos/assets/test_spatial_memory/frame_{stored_count:03d}.jpg" + cv2.imwrite(frame_filename, result["frame"]) + print(f"Saved frame to {frame_filename}") + + # Subscribe to results + print("Subscribing to spatial perception results...") + result_subscription = result_stream.subscribe(on_stored_frame) + + print("\nRunning until interrupted...") + try: + while True: + time.sleep(1.0) + print(f"Running: {stored_count} frames stored so far", end="\r") + except KeyboardInterrupt: + print("\nTest interrupted by user") + finally: + # Clean up resources + print("\nCleaning up...") + if "result_subscription" in locals(): + result_subscription.dispose() + + # Visualize spatial memory with multiple object queries + visualize_spatial_memory_with_objects( + spatial_memory, + objects=[ + "kitchen", + "conference room", + "vacuum", + "office", + "bathroom", + "boxes", + "telephone booth", + ], + output_filename="spatial_memory_map.png", + ) + + # Save visual memory to disk for later use + saved_path = spatial_memory.vector_db.visual_memory.save("visual_memory.pkl") + print(f"Saved {spatial_memory.vector_db.visual_memory.count()} images to disk at {saved_path}") + + spatial_memory.stop() + + print("Test completed successfully") + + +def visualize_spatial_memory_with_objects( + spatial_memory, objects, output_filename="spatial_memory_map.png" +): + """ + Visualize a spatial memory map with multiple labeled objects. + + Args: + spatial_memory: SpatialMemory instance + objects: List of object names to query and visualize (e.g. ["kitchen", "office"]) + output_filename: Filename to save the visualization + """ + # Define colors for different objects - will cycle through these + colors = ["red", "green", "orange", "purple", "brown", "cyan", "magenta", "yellow"] + + # Get all stored locations for background + locations = spatial_memory.vector_db.get_all_locations() + if not locations: + print("No locations stored in spatial memory.") + return + + # Extract coordinates from all stored locations + x_coords = [] + y_coords = [] + for loc in locations: + if isinstance(loc, dict): + x_coords.append(loc.get("pos_x", 0)) + y_coords.append(loc.get("pos_y", 0)) + elif isinstance(loc, tuple | list) and len(loc) >= 2: + x_coords.append(loc[0]) + y_coords.append(loc[1]) + else: + print(f"Unknown location format: {loc}") + + # Create figure + plt.figure(figsize=(12, 10)) + + # Plot all points in blue + plt.scatter(x_coords, y_coords, c="blue", s=50, alpha=0.5, label="All Frames") + + # Container for all object coordinates + object_coords = {} + + # Query for each object and store the result + for i, obj in enumerate(objects): + color = colors[i % len(colors)] # Cycle through colors + print(f"\nProcessing {obj} query for visualization...") + + # Get best match for this object + results = spatial_memory.query_by_text(obj, limit=1) + if not results: + print(f"No results found for '{obj}'") + continue + + # Get the first (best) result + result = results[0] + metadata = result["metadata"] + + # Extract coordinates from the first metadata item + if isinstance(metadata, list) and metadata: + metadata = metadata[0] + + if isinstance(metadata, dict): + # New metadata format uses pos_x, pos_y + x = metadata.get("pos_x", metadata.get("x", 0)) + y = metadata.get("pos_y", metadata.get("y", 0)) + + # Store coordinates for this object + object_coords[obj] = (x, y) + + # Plot this object's position + plt.scatter([x], [y], c=color, s=100, alpha=0.8, label=obj.title()) + + # Add annotation + obj_abbrev = obj[0].upper() if len(obj) > 0 else "X" + plt.annotate( + f"{obj_abbrev}", (x, y), textcoords="offset points", xytext=(0, 10), ha="center" + ) + + # Save the image to a file using the object name + if "image" in result and result["image"] is not None: + # Clean the object name to make it suitable for a filename + clean_name = obj.replace(" ", "_").lower() + output_img_filename = f"{clean_name}_result.jpg" + cv2.imwrite(output_img_filename, result["image"]) + print(f"Saved {obj} image to {output_img_filename}") + + # Finalize the plot + plt.title("Spatial Memory Map with Query Results") + plt.xlabel("X Position (m)") + plt.ylabel("Y Position (m)") + plt.grid(True) + plt.axis("equal") + plt.legend() + + # Add origin circle + plt.gca().add_patch(Circle((0, 0), 1.0, fill=False, color="blue", linestyle="--")) + + # Save the visualization + plt.savefig(output_filename, dpi=300) + print(f"Saved enhanced map visualization to {output_filename}") + + return object_coords + + +if __name__ == "__main__": + main() diff --git a/tests/test_spatial_memory_query.py b/tests/test_spatial_memory_query.py new file mode 100644 index 0000000000..5253ad6b2d --- /dev/null +++ b/tests/test_spatial_memory_query.py @@ -0,0 +1,295 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Test script for querying an existing spatial memory database + +Usage: + python test_spatial_memory_query.py --query "kitchen table" --limit 5 --threshold 0.7 --save-all + python test_spatial_memory_query.py --query "robot" --limit 3 --save-one +""" + +import argparse +from datetime import datetime +import os + +import chromadb +import cv2 +import matplotlib.pyplot as plt + +from dimos.agents.memory.visual_memory import VisualMemory +from dimos.perception.spatial_perception import SpatialMemory + + +def setup_persistent_chroma_db(db_path): + """Set up a persistent ChromaDB client at the specified path.""" + print(f"Setting up persistent ChromaDB at: {db_path}") + os.makedirs(db_path, exist_ok=True) + return chromadb.PersistentClient(path=db_path) + + +def parse_args(): + """Parse command-line arguments.""" + parser = argparse.ArgumentParser(description="Query spatial memory database.") + parser.add_argument( + "--query", type=str, default=None, help="Text query to search for (e.g., 'kitchen table')" + ) + parser.add_argument("--limit", type=int, default=3, help="Maximum number of results to return") + parser.add_argument( + "--threshold", + type=float, + default=None, + help="Similarity threshold (0.0-1.0). Only return results above this threshold.", + ) + parser.add_argument("--save-all", action="store_true", help="Save all result images") + parser.add_argument("--save-one", action="store_true", help="Save only the best matching image") + parser.add_argument( + "--visualize", + action="store_true", + help="Create a visualization of all stored memory locations", + ) + parser.add_argument( + "--db-path", + type=str, + default="/home/stash/dimensional/dimos/assets/test_spatial_memory/chromadb_data", + help="Path to ChromaDB database", + ) + parser.add_argument( + "--visual-memory-path", + type=str, + default="/home/stash/dimensional/dimos/assets/test_spatial_memory/visual_memory.pkl", + help="Path to visual memory file", + ) + return parser.parse_args() + + +def main(): + args = parse_args() + print("Loading existing spatial memory database for querying...") + + # Setup the persistent ChromaDB client + db_client = setup_persistent_chroma_db(args.db_path) + + # Setup output directory for any saved results + output_dir = os.path.dirname(args.visual_memory_path) + + # Load the visual memory + print(f"Loading visual memory from {args.visual_memory_path}...") + if os.path.exists(args.visual_memory_path): + visual_memory = VisualMemory.load(args.visual_memory_path, output_dir=output_dir) + print(f"Loaded {visual_memory.count()} images from visual memory") + else: + visual_memory = VisualMemory(output_dir=output_dir) + print("No existing visual memory found. Query results won't include images.") + + # Create SpatialMemory with the existing database and visual memory + spatial_memory = SpatialMemory( + collection_name="test_spatial_memory", chroma_client=db_client, visual_memory=visual_memory + ) + + # Create a visualization if requested + if args.visualize: + print("\nCreating visualization of spatial memory...") + common_objects = [ + "kitchen", + "conference room", + "vacuum", + "office", + "bathroom", + "boxes", + "telephone booth", + ] + visualize_spatial_memory_with_objects( + spatial_memory, objects=common_objects, output_filename="spatial_memory_map.png" + ) + + # Handle query if provided + if args.query: + query = args.query + limit = args.limit + print(f"\nQuerying for: '{query}' (limit: {limit})...") + + # Run the query + results = spatial_memory.query_by_text(query, limit=limit) + + if not results: + print(f"No results found for query: '{query}'") + return + + # Filter by threshold if specified + if args.threshold is not None: + print(f"Filtering results with similarity threshold: {args.threshold}") + filtered_results = [] + for result in results: + # Distance is inverse of similarity (0 is perfect match) + # Convert to similarity score (1.0 is perfect match) + similarity = 1.0 - ( + result.get("distance", 0) if result.get("distance") is not None else 0 + ) + if similarity >= args.threshold: + filtered_results.append((result, similarity)) + + # Sort by similarity (highest first) + filtered_results.sort(key=lambda x: x[1], reverse=True) + + if not filtered_results: + print(f"No results met the similarity threshold of {args.threshold}") + return + + print(f"Found {len(filtered_results)} results above threshold") + results_with_scores = filtered_results + else: + # Add similarity scores for all results + results_with_scores = [] + for result in results: + similarity = 1.0 - ( + result.get("distance", 0) if result.get("distance") is not None else 0 + ) + results_with_scores.append((result, similarity)) + + # Process and display results + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + + for i, (result, similarity) in enumerate(results_with_scores): + metadata = result.get("metadata", {}) + if isinstance(metadata, list) and metadata: + metadata = metadata[0] + + # Display result information + print(f"\nResult {i + 1} for '{query}':") + print(f"Similarity: {similarity:.4f} (distance: {1.0 - similarity:.4f})") + + # Extract and display position information + if isinstance(metadata, dict): + x = metadata.get("x", 0) + y = metadata.get("y", 0) + z = metadata.get("z", 0) + print(f"Position: ({x:.2f}, {y:.2f}, {z:.2f})") + if "timestamp" in metadata: + print(f"Timestamp: {metadata['timestamp']}") + if "frame_id" in metadata: + print(f"Frame ID: {metadata['frame_id']}") + + # Save image if requested and available + if "image" in result and result["image"] is not None: + # Only save first image, or all images based on flags + if args.save_one and i > 0: + continue + if not (args.save_all or args.save_one): + continue + + # Create a descriptive filename + clean_query = query.replace(" ", "_").replace("/", "_").lower() + output_filename = f"{clean_query}_result_{i + 1}_{timestamp}.jpg" + + # Save the image + cv2.imwrite(output_filename, result["image"]) + print(f"Saved image to {output_filename}") + elif "image" in result and result["image"] is None: + print("Image data not available for this result") + else: + print('No query specified. Use --query "text to search for" to run a query.') + print("Use --help to see all available options.") + + print("\nQuery completed successfully!") + + +def visualize_spatial_memory_with_objects( + spatial_memory, objects, output_filename="spatial_memory_map.png" +): + """Visualize spatial memory with labeled objects.""" + # Define colors for different objects + colors = ["red", "green", "orange", "purple", "brown", "cyan", "magenta", "yellow"] + + # Get all stored locations for background + locations = spatial_memory.vector_db.get_all_locations() + if not locations: + print("No locations stored in spatial memory.") + return + + # Extract coordinates + if len(locations[0]) >= 3: + x_coords = [loc[0] for loc in locations] + y_coords = [loc[1] for loc in locations] + else: + x_coords, y_coords = zip(*locations, strict=False) + + # Create figure + plt.figure(figsize=(12, 10)) + plt.scatter(x_coords, y_coords, c="blue", s=50, alpha=0.5, label="All Frames") + + # Container for object coordinates + object_coords = {} + + # Query for each object + for i, obj in enumerate(objects): + color = colors[i % len(colors)] + print(f"Processing {obj} query for visualization...") + + # Get best match + results = spatial_memory.query_by_text(obj, limit=1) + if not results: + print(f"No results found for '{obj}'") + continue + + # Process result + result = results[0] + metadata = result["metadata"] + + if isinstance(metadata, list) and metadata: + metadata = metadata[0] + + if isinstance(metadata, dict) and "x" in metadata and "y" in metadata: + x = metadata.get("x", 0) + y = metadata.get("y", 0) + + # Store coordinates + object_coords[obj] = (x, y) + + # Plot position + plt.scatter([x], [y], c=color, s=100, alpha=0.8, label=obj.title()) + + # Add annotation + obj_abbrev = obj[0].upper() if len(obj) > 0 else "X" + plt.annotate( + f"{obj_abbrev}", (x, y), textcoords="offset points", xytext=(0, 10), ha="center" + ) + + # Save image if available + if "image" in result and result["image"] is not None: + clean_name = obj.replace(" ", "_").lower() + output_img_filename = f"{clean_name}_result.jpg" + cv2.imwrite(output_img_filename, result["image"]) + print(f"Saved {obj} image to {output_img_filename}") + + # Finalize plot + plt.title("Spatial Memory Map with Query Results") + plt.xlabel("X Position (m)") + plt.ylabel("Y Position (m)") + plt.grid(True) + plt.axis("equal") + plt.legend() + + # Add origin marker + plt.gca().add_patch(plt.Circle((0, 0), 1.0, fill=False, color="blue", linestyle="--")) + + # Save visualization + plt.savefig(output_filename, dpi=300) + print(f"Saved visualization to {output_filename}") + + return object_coords + + +if __name__ == "__main__": + main() diff --git a/tests/test_standalone_chromadb.py b/tests/test_standalone_chromadb.py new file mode 100644 index 0000000000..2acbb68b5d --- /dev/null +++ b/tests/test_standalone_chromadb.py @@ -0,0 +1,84 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +# ----- +from langchain_chroma import Chroma +from langchain_openai import OpenAIEmbeddings + +OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") +if not OPENAI_API_KEY: + raise Exception("OpenAI key not specified.") + +collection_name = "my_collection" + +embeddings = OpenAIEmbeddings( + model="text-embedding-3-large", + dimensions=1024, + api_key=OPENAI_API_KEY, +) + +db_connection = Chroma( + collection_name=collection_name, + embedding_function=embeddings, +) + + +def add_vector(vector_id, vector_data): + """Add a vector to the ChromaDB collection.""" + if not db_connection: + raise Exception("Collection not initialized. Call connect() first.") + db_connection.add_texts( + ids=[vector_id], + texts=[vector_data], + metadatas=[{"name": vector_id}], + ) + + +add_vector("id0", "Food") +add_vector("id1", "Cat") +add_vector("id2", "Mouse") +add_vector("id3", "Bike") +add_vector("id4", "Dog") +add_vector("id5", "Tricycle") +add_vector("id6", "Car") +add_vector("id7", "Horse") +add_vector("id8", "Vehicle") +add_vector("id6", "Red") +add_vector("id7", "Orange") +add_vector("id8", "Yellow") + + +def get_vector(vector_id): + """Retrieve a vector from the ChromaDB by its identifier.""" + result = db_connection.get(include=["embeddings"], ids=[vector_id]) + return result + + +print(get_vector("id1")) +# print(get_vector("id3")) +# print(get_vector("id0")) +# print(get_vector("id2")) + + +def query(query_texts, n_results=2): + """Query the collection with a specific text and return up to n results.""" + if not db_connection: + raise Exception("Collection not initialized. Call connect() first.") + return db_connection.similarity_search(query=query_texts, k=n_results) + + +results = query("Colors") +print(results) diff --git a/tests/test_standalone_fastapi.py b/tests/test_standalone_fastapi.py new file mode 100644 index 0000000000..644f68cc61 --- /dev/null +++ b/tests/test_standalone_fastapi.py @@ -0,0 +1,79 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os + +logging.basicConfig(level=logging.DEBUG) + +import cv2 +from fastapi import FastAPI +from starlette.responses import StreamingResponse +import uvicorn + +app = FastAPI() + +# Note: Chrome does not allow for loading more than 6 simultaneous +# video streams. Use Safari or another browser for utilizing +# multiple simultaneous streams. Possibly build out functionality +# that will stop live streams. + + +@app.get("/") +async def root(): + pid = os.getpid() # Get the current process ID + return {"message": f"Video Streaming Server, PID: {pid}"} + + +def video_stream_generator(): + pid = os.getpid() + print(f"Stream initiated by worker with PID: {pid}") # Log the PID when the generator is called + + # Use the correct path for your video source + cap = cv2.VideoCapture( + f"{os.getcwd()}/assets/trimmed_video_480p.mov" + ) # Change 0 to a filepath for video files + + if not cap.isOpened(): + yield (b"--frame\r\nContent-Type: text/plain\r\n\r\n" + b"Could not open video source\r\n") + return + + try: + while True: + ret, frame = cap.read() + # If frame is read correctly ret is True + if not ret: + print(f"Reached the end of the video, restarting... PID: {pid}") + cap.set( + cv2.CAP_PROP_POS_FRAMES, 0 + ) # Set the position of the next video frame to 0 (the beginning) + continue + _, buffer = cv2.imencode(".jpg", frame) + yield (b"--frame\r\nContent-Type: image/jpeg\r\n\r\n" + buffer.tobytes() + b"\r\n") + finally: + cap.release() + + +@app.get("/video") +async def video_endpoint(): + logging.debug("Attempting to open video stream.") + response = StreamingResponse( + video_stream_generator(), media_type="multipart/x-mixed-replace; boundary=frame" + ) + logging.debug("Streaming response set up.") + return response + + +if __name__ == "__main__": + uvicorn.run("__main__:app", host="0.0.0.0", port=5555, workers=20) diff --git a/tests/test_standalone_hugging_face.py b/tests/test_standalone_hugging_face.py new file mode 100644 index 0000000000..29c0295f90 --- /dev/null +++ b/tests/test_standalone_hugging_face.py @@ -0,0 +1,121 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# from transformers import AutoModelForCausalLM, AutoTokenizer +# model_name = "Qwen/QwQ-32B" +# model = AutoModelForCausalLM.from_pretrained( +# model_name, +# torch_dtype="auto", +# device_map="auto" +# ) +# tokenizer = AutoTokenizer.from_pretrained(model_name) +# prompt = "How many r's are in the word \"strawberry\"" +# messages = [ +# {"role": "user", "content": prompt} +# ] +# text = tokenizer.apply_chat_template( +# messages, +# tokenize=False, +# add_generation_prompt=True +# ) +# model_inputs = tokenizer([text], return_tensors="pt").to(model.device) +# generated_ids = model.generate( +# **model_inputs, +# max_new_tokens=32768 +# ) +# generated_ids = [ +# output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) +# ] +# response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] +# print(response) +# ----------------------------------------------------------------------------- +# import requests +# import json +# API_URL = "https://api-inference.huggingface.co/models/Qwen/QwQ-32B" +# api_key = os.getenv('HUGGINGFACE_ACCESS_TOKEN') +# HEADERS = {"Authorization": f"Bearer {api_key}"} +# prompt = "How many r's are in the word \"strawberry\"" +# messages = [ +# {"role": "user", "content": prompt} +# ] +# # Format the prompt in the desired chat format +# chat_template = ( +# f"{messages[0]['content']}\n" +# "Assistant:" +# ) +# payload = { +# "inputs": chat_template, +# "parameters": { +# "max_new_tokens": 32768, +# "temperature": 0.7 +# } +# } +# # API request +# response = requests.post(API_URL, headers=HEADERS, json=payload) +# # Handle response +# if response.status_code == 200: +# output = response.json()[0]['generated_text'] +# print(output.strip()) +# else: +# print(f"Error {response.status_code}: {response.text}") +# ----------------------------------------------------------------------------- +# import os +# import requests +# import time +# API_URL = "https://api-inference.huggingface.co/models/Qwen/QwQ-32B" +# api_key = os.getenv('HUGGINGFACE_ACCESS_TOKEN') +# HEADERS = {"Authorization": f"Bearer {api_key}"} +# def query_with_retries(payload, max_retries=5, delay=15): +# for attempt in range(max_retries): +# response = requests.post(API_URL, headers=HEADERS, json=payload) +# if response.status_code == 200: +# return response.json()[0]['generated_text'] +# elif response.status_code == 500: # Service unavailable +# print(f"Attempt {attempt + 1}/{max_retries}: Model busy. Retrying in {delay} seconds...") +# time.sleep(delay) +# else: +# print(f"Error {response.status_code}: {response.text}") +# break +# return "Failed after multiple retries." +# prompt = "How many r's are in the word \"strawberry\"" +# messages = [{"role": "user", "content": prompt}] +# chat_template = f"{messages[0]['content']}\nAssistant:" +# payload = { +# "inputs": chat_template, +# "parameters": {"max_new_tokens": 32768, "temperature": 0.7} +# } +# output = query_with_retries(payload) +# print(output.strip()) +# ----------------------------------------------------------------------------- +import os + +from huggingface_hub import InferenceClient + +# Use environment variable for API key +api_key = os.getenv("HUGGINGFACE_ACCESS_TOKEN") + +client = InferenceClient( + provider="hf-inference", + api_key=api_key, +) + +messages = [{"role": "user", "content": 'How many r\'s are in the word "strawberry"'}] + +completion = client.chat.completions.create( + model="Qwen/QwQ-32B", + messages=messages, + max_tokens=150, +) + +print(completion.choices[0].message) diff --git a/tests/test_standalone_openai_json.py b/tests/test_standalone_openai_json.py new file mode 100644 index 0000000000..8c2eed13d3 --- /dev/null +++ b/tests/test_standalone_openai_json.py @@ -0,0 +1,106 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# ----- +import dotenv + +dotenv.load_dotenv() + +import json +from textwrap import dedent + +from openai import OpenAI +from pydantic import BaseModel + +MODEL = "gpt-4o-2024-08-06" + +math_tutor_prompt = """ + You are a helpful math tutor. You will be provided with a math problem, + and your goal will be to output a step by step solution, along with a final answer. + For each step, just provide the output as an equation use the explanation field to detail the reasoning. +""" + +bad_prompt = """ + Follow the instructions. +""" + +client = OpenAI() + + +class MathReasoning(BaseModel): + class Step(BaseModel): + explanation: str + output: str + + steps: list[Step] + final_answer: str + + +def get_math_solution(question: str): + completion = client.beta.chat.completions.parse( + model=MODEL, + messages=[ + {"role": "system", "content": dedent(bad_prompt)}, + {"role": "user", "content": question}, + ], + response_format=MathReasoning, + ) + return completion.choices[0].message + + +# Web Server +import http.server +import socketserver +import urllib.parse + +PORT = 5555 + + +class CustomHandler(http.server.SimpleHTTPRequestHandler): + def do_GET(self): + # Parse query parameters from the URL + parsed_path = urllib.parse.urlparse(self.path) + query_params = urllib.parse.parse_qs(parsed_path.query) + + # Check for a specific query parameter, e.g., 'problem' + problem = query_params.get("problem", [""])[ + 0 + ] # Default to an empty string if 'problem' isn't provided + + if problem: + print(f"Problem: {problem}") + solution = get_math_solution(problem) + + if solution.refusal: + print(f"Refusal: {solution.refusal}") + + print(f"Solution: {solution}") + self.send_response(200) + else: + solution = json.dumps( + {"error": "Please provide a math problem using the 'problem' query parameter."} + ) + self.send_response(400) + + self.send_header("Content-type", "application/json; charset=utf-8") + self.end_headers() + + # Write the message content + self.wfile.write(str(solution).encode()) + + +with socketserver.TCPServer(("", PORT), CustomHandler) as httpd: + print(f"Serving at port {PORT}") + httpd.serve_forever() diff --git a/tests/test_standalone_openai_json_struct.py b/tests/test_standalone_openai_json_struct.py new file mode 100644 index 0000000000..304e4cf475 --- /dev/null +++ b/tests/test_standalone_openai_json_struct.py @@ -0,0 +1,89 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# ----- + +import dotenv + +dotenv.load_dotenv() + +from textwrap import dedent + +from openai import OpenAI +from pydantic import BaseModel + +MODEL = "gpt-4o-2024-08-06" + +math_tutor_prompt = """ + You are a helpful math tutor. You will be provided with a math problem, + and your goal will be to output a step by step solution, along with a final answer. + For each step, just provide the output as an equation use the explanation field to detail the reasoning. +""" + +general_prompt = """ + Follow the instructions. Output a step by step solution, along with a final answer. Use the explanation field to detail the reasoning. +""" + +client = OpenAI() + + +class MathReasoning(BaseModel): + class Step(BaseModel): + explanation: str + output: str + + steps: list[Step] + final_answer: str + + +def get_math_solution(question: str): + prompt = general_prompt + completion = client.beta.chat.completions.parse( + model=MODEL, + messages=[ + {"role": "system", "content": dedent(prompt)}, + {"role": "user", "content": question}, + ], + response_format=MathReasoning, + ) + return completion.choices[0].message + + +# Define Problem +problem = "What is the derivative of 3x^2" +print(f"Problem: {problem}") + +# Query for result +solution = get_math_solution(problem) + +# If the query was refused +if solution.refusal: + print(f"Refusal: {solution.refusal}") + exit() + +# If we were able to successfully parse the response back +parsed_solution = solution.parsed +if not parsed_solution: + print("Unable to Parse Solution") + exit() + +# Print solution from class definitions +print(f"Parsed: {parsed_solution}") + +steps = parsed_solution.steps +print(f"Steps: {steps}") + +final_answer = parsed_solution.final_answer +print(f"Final Answer: {final_answer}") diff --git a/tests/test_standalone_openai_json_struct_func.py b/tests/test_standalone_openai_json_struct_func.py new file mode 100644 index 0000000000..628463de94 --- /dev/null +++ b/tests/test_standalone_openai_json_struct_func.py @@ -0,0 +1,174 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# ----- + +import dotenv + +dotenv.load_dotenv() + +import json +from textwrap import dedent + +from openai import OpenAI, pydantic_function_tool +from pydantic import BaseModel, Field +import requests + +MODEL = "gpt-4o-2024-08-06" + +math_tutor_prompt = """ + You are a helpful math tutor. You will be provided with a math problem, + and your goal will be to output a step by step solution, along with a final answer. + For each step, just provide the output as an equation use the explanation field to detail the reasoning. +""" + +general_prompt = """ + Follow the instructions. Output a step by step solution, along with a final answer. Use the explanation field to detail the reasoning. +""" + +client = OpenAI() + + +class MathReasoning(BaseModel): + class Step(BaseModel): + explanation: str + output: str + + steps: list[Step] + final_answer: str + + +# region Function Calling +class GetWeather(BaseModel): + latitude: str = Field(..., description="latitude e.g. Bogotá, Colombia") + longitude: str = Field(..., description="longitude e.g. Bogotá, Colombia") + + +def get_weather(latitude, longitude): + response = requests.get( + f"https://api.open-meteo.com/v1/forecast?latitude={latitude}&longitude={longitude}¤t=temperature_2m,wind_speed_10m&hourly=temperature_2m,relative_humidity_2m,wind_speed_10m&temperature_unit=fahrenheit" + ) + data = response.json() + return data["current"]["temperature_2m"] + + +def get_tools(): + return [pydantic_function_tool(GetWeather)] + + +tools = get_tools() + + +def call_function(name, args): + if name == "get_weather": + print(f"Running function: {name}") + print(f"Arguments are: {args}") + return get_weather(**args) + elif name == "GetWeather": + print(f"Running function: {name}") + print(f"Arguments are: {args}") + return get_weather(**args) + else: + return f"Local function not found: {name}" + + +def callback(message, messages, response_message, tool_calls): + if message is None or message.tool_calls is None: + print("No message or tools were called.") + return + + has_called_tools = False + for tool_call in message.tool_calls: + messages.append(response_message) + + has_called_tools = True + name = tool_call.function.name + args = json.loads(tool_call.function.arguments) + + result = call_function(name, args) + print(f"Function Call Results: {result}") + + messages.append( + {"role": "tool", "tool_call_id": tool_call.id, "content": str(result), "name": name} + ) + + # Complete the second call, after the functions have completed. + if has_called_tools: + print("Sending Second Query.") + completion_2 = client.beta.chat.completions.parse( + model=MODEL, + messages=messages, + response_format=MathReasoning, + tools=tools, + ) + print(f"Message: {completion_2.choices[0].message}") + return completion_2.choices[0].message + else: + print("No Need for Second Query.") + return None + + +# endregion Function Calling + + +def get_math_solution(question: str): + prompt = general_prompt + messages = [ + {"role": "system", "content": dedent(prompt)}, + {"role": "user", "content": question}, + ] + response = client.beta.chat.completions.parse( + model=MODEL, messages=messages, response_format=MathReasoning, tools=tools + ) + + response_message = response.choices[0].message + tool_calls = response_message.tool_calls + + new_response = callback(response.choices[0].message, messages, response_message, tool_calls) + + return new_response or response.choices[0].message + + +# Define Problem +problems = ["What is the derivative of 3x^2", "What's the weather like in San Fran today?"] +problem = problems[0] + +for problem in problems: + print("================") + print(f"Problem: {problem}") + + # Query for result + solution = get_math_solution(problem) + + # If the query was refused + if solution.refusal: + print(f"Refusal: {solution.refusal}") + break + + # If we were able to successfully parse the response back + parsed_solution = solution.parsed + if not parsed_solution: + print("Unable to Parse Solution") + print(f"Solution: {solution}") + break + + # Print solution from class definitions + print(f"Parsed: {parsed_solution}") + + steps = parsed_solution.steps + print(f"Steps: {steps}") + + final_answer = parsed_solution.final_answer + print(f"Final Answer: {final_answer}") diff --git a/tests/test_standalone_openai_json_struct_func_playground.py b/tests/test_standalone_openai_json_struct_func_playground.py new file mode 100644 index 0000000000..9fbb5a6aad --- /dev/null +++ b/tests/test_standalone_openai_json_struct_func_playground.py @@ -0,0 +1,197 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ----- +# # Milestone 1 +# from typing import List, Dict, Optional +# import requests +# import json +# from pydantic import BaseModel, Field +# from openai import OpenAI, pydantic_function_tool +# # Environment setup +# import dotenv +# dotenv.load_dotenv() +# # Constants and prompts +# MODEL = "gpt-4o-2024-08-06" +# GENERAL_PROMPT = ''' +# Follow the instructions. Output a step by step solution, along with a final answer. +# Use the explanation field to detail the reasoning. +# ''' +# # Initialize OpenAI client +# client = OpenAI() +# # Models and functions +# class Step(BaseModel): +# explanation: str +# output: str +# class MathReasoning(BaseModel): +# steps: List[Step] +# final_answer: str +# class GetWeather(BaseModel): +# latitude: str = Field(..., description="Latitude e.g., Bogotá, Colombia") +# longitude: str = Field(..., description="Longitude e.g., Bogotá, Colombia") +# def fetch_weather(latitude: str, longitude: str) -> Dict: +# url = f"https://api.open-meteo.com/v1/forecast?latitude={latitude}&longitude={longitude}¤t=temperature_2m,wind_speed_10m&hourly=temperature_2m,relative_humidity_2m,wind_speed_10m&temperature_unit=fahrenheit" +# response = requests.get(url) +# return response.json().get('current', {}) +# # Tool management +# def get_tools() -> List[BaseModel]: +# return [pydantic_function_tool(GetWeather)] +# def handle_function_call(tool_call: Dict) -> Optional[str]: +# if tool_call['name'] == "get_weather": +# result = fetch_weather(**tool_call['args']) +# return f"Temperature is {result['temperature_2m']}°F" +# return None +# # Communication and processing with OpenAI +# def process_message_with_openai(question: str) -> MathReasoning: +# messages = [ +# {"role": "system", "content": GENERAL_PROMPT.strip()}, +# {"role": "user", "content": question} +# ] +# response = client.beta.chat.completions.parse( +# model=MODEL, +# messages=messages, +# response_format=MathReasoning, +# tools=get_tools() +# ) +# return response.choices[0].message +# def get_math_solution(question: str) -> MathReasoning: +# solution = process_message_with_openai(question) +# return solution +# # Example usage +# def main(): +# problems = [ +# "What is the derivative of 3x^2", +# "What's the weather like in San Francisco today?" +# ] +# problem = problems[1] +# print(f"Problem: {problem}") +# solution = get_math_solution(problem) +# if not solution: +# print("Failed to get a solution.") +# return +# if not solution.parsed: +# print("Failed to get a parsed solution.") +# print(f"Solution: {solution}") +# return +# print(f"Steps: {solution.parsed.steps}") +# print(f"Final Answer: {solution.parsed.final_answer}") +# if __name__ == "__main__": +# main() +# # Milestone 1 +# Milestone 2 +import json + +from dotenv import load_dotenv +import requests + +load_dotenv() + +from openai import OpenAI + +client = OpenAI() + + +def get_current_weather(latitude, longitude): + """Get the current weather in a given latitude and longitude using the 7Timer API""" + base = "http://www.7timer.info/bin/api.pl" + request_url = f"{base}?lon={longitude}&lat={latitude}&product=civillight&output=json" + response = requests.get(request_url) + + # Parse response to extract the main weather data + weather_data = response.json() + current_data = weather_data.get("dataseries", [{}])[0] + + result = { + "latitude": latitude, + "longitude": longitude, + "temp": current_data.get("temp2m", {"max": "Unknown", "min": "Unknown"}), + "humidity": "Unknown", + } + + # Convert the dictionary to JSON string to match the given structure + return json.dumps(result) + + +def run_conversation(content): + messages = [{"role": "user", "content": content}] + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given latitude and longitude", + "parameters": { + "type": "object", + "properties": { + "latitude": { + "type": "string", + "description": "The latitude of a place", + }, + "longitude": { + "type": "string", + "description": "The longitude of a place", + }, + }, + "required": ["latitude", "longitude"], + }, + }, + } + ] + response = client.chat.completions.create( + model="gpt-3.5-turbo-0125", + messages=messages, + tools=tools, + tool_choice="auto", + ) + response_message = response.choices[0].message + tool_calls = response_message.tool_calls + + if tool_calls: + messages.append(response_message) + + available_functions = { + "get_current_weather": get_current_weather, + } + for tool_call in tool_calls: + print(f"Function: {tool_call.function.name}") + print(f"Params:{tool_call.function.arguments}") + function_name = tool_call.function.name + function_to_call = available_functions[function_name] + function_args = json.loads(tool_call.function.arguments) + function_response = function_to_call( + latitude=function_args.get("latitude"), + longitude=function_args.get("longitude"), + ) + print(f"API: {function_response}") + messages.append( + { + "tool_call_id": tool_call.id, + "role": "tool", + "name": function_name, + "content": function_response, + } + ) + + second_response = client.chat.completions.create( + model="gpt-3.5-turbo-0125", messages=messages, stream=True + ) + return second_response + + +if __name__ == "__main__": + question = "What's the weather like in Paris and San Francisco?" + response = run_conversation(question) + for chunk in response: + print(chunk.choices[0].delta.content or "", end="", flush=True) +# Milestone 2 diff --git a/tests/test_standalone_project_out.py b/tests/test_standalone_project_out.py new file mode 100644 index 0000000000..43036df464 --- /dev/null +++ b/tests/test_standalone_project_out.py @@ -0,0 +1,135 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ----- +import ast +import inspect +import sys + + +def extract_function_info(filename): + with open(filename) as f: + source = f.read() + tree = ast.parse(source, filename=filename) + + function_info = [] + + # Use a dictionary to track functions + module_globals = {} + + # Add the source to the locals (useful if you use local functions) + exec(source, module_globals) + + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef | ast.AsyncFunctionDef): + docstring = ast.get_docstring(node) or "" + + # Attempt to get the callable object from the globals + try: + if node.name in module_globals: + func_obj = module_globals[node.name] + signature = inspect.signature(func_obj) + function_info.append( + {"name": node.name, "signature": str(signature), "docstring": docstring} + ) + else: + function_info.append( + { + "name": node.name, + "signature": "Could not get signature", + "docstring": docstring, + } + ) + except TypeError as e: + print( + f"Could not get function signature for {node.name} in {filename}: {e}", + file=sys.stderr, + ) + function_info.append( + { + "name": node.name, + "signature": "Could not get signature", + "docstring": docstring, + } + ) + + class_info = [] + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef): + docstring = ast.get_docstring(node) or "" + methods = [] + for method in node.body: + if isinstance(method, ast.FunctionDef | ast.AsyncFunctionDef): + method_docstring = ast.get_docstring(method) or "" + try: + if node.name in module_globals: + class_obj = module_globals[node.name] + method_obj = getattr(class_obj, method.name) + signature = inspect.signature(method_obj) + methods.append( + { + "name": method.name, + "signature": str(signature), + "docstring": method_docstring, + } + ) + else: + methods.append( + { + "name": method.name, + "signature": "Could not get signature", + "docstring": method_docstring, + } + ) + except AttributeError as e: + print( + f"Could not get method signature for {node.name}.{method.name} in {filename}: {e}", + file=sys.stderr, + ) + methods.append( + { + "name": method.name, + "signature": "Could not get signature", + "docstring": method_docstring, + } + ) + except TypeError as e: + print( + f"Could not get method signature for {node.name}.{method.name} in {filename}: {e}", + file=sys.stderr, + ) + methods.append( + { + "name": method.name, + "signature": "Could not get signature", + "docstring": method_docstring, + } + ) + class_info.append({"name": node.name, "docstring": docstring, "methods": methods}) + + return {"function_info": function_info, "class_info": class_info} + + +# Usage: +file_path = "./dimos/agents/memory/base.py" +extracted_info = extract_function_info(file_path) +print(extracted_info) + +file_path = "./dimos/agents/memory/chroma_impl.py" +extracted_info = extract_function_info(file_path) +print(extracted_info) + +file_path = "./dimos/agents/agent.py" +extracted_info = extract_function_info(file_path) +print(extracted_info) diff --git a/tests/test_standalone_rxpy_01.py b/tests/test_standalone_rxpy_01.py new file mode 100644 index 0000000000..497697d623 --- /dev/null +++ b/tests/test_standalone_rxpy_01.py @@ -0,0 +1,130 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import multiprocessing +from threading import Event + +# ----- +import reactivex +from reactivex import operators as ops +from reactivex.scheduler import ThreadPoolScheduler + +which_test = 2 +if which_test == 1: + """ + Test 1: Periodic Emission Test + + This test creates a ThreadPoolScheduler that leverages as many threads as there are CPU + cores available, optimizing the execution across multiple threads. The core functionality + revolves around an observable, secondly_emission, which emits a value every second. + Each emission is an incrementing integer, which is then mapped to a message indicating + the number of seconds since the test began. The sequence is limited to 30 emissions, + each logged as it occurs, and accompanied by an additional message via the + emission_process function to indicate the value's emission. The test subscribes to the + observable to print each emitted value, handle any potential errors, and confirm + completion of the emissions after 30 seconds. + + Key Components: + • ThreadPoolScheduler: Manages concurrency with multiple threads. + • Observable Sequence: Emits every second, indicating progression with a specific + message format. + • Subscription: Monitors and logs emissions, errors, and the completion event. + """ + + # Create a scheduler that uses as many threads as there are CPUs available + optimal_thread_count = multiprocessing.cpu_count() + pool_scheduler = ThreadPoolScheduler(optimal_thread_count) + + def emission_process(value): + print(f"Emitting: {value}") + + # Create an observable that emits every second + secondly_emission = reactivex.interval(1.0, scheduler=pool_scheduler).pipe( + ops.map(lambda x: f"Value {x} emitted after {x + 1} second(s)"), + ops.do_action(emission_process), + ops.take(30), # Limit the emission to 30 times + ) + + # Subscribe to the observable to start emitting + secondly_emission.subscribe( + on_next=lambda x: print(x), + on_error=lambda e: print(e), + on_completed=lambda: print("Emission completed."), + scheduler=pool_scheduler, + ) + +elif which_test == 2: + """ + Test 2: Combined Emission Test + + In this test, a similar ThreadPoolScheduler setup is used to handle tasks across multiple + CPU cores efficiently. This setup includes two observables. The first, secondly_emission, + emits an incrementing integer every second, indicating the passage of time. The second + observable, immediate_emission, emits a predefined sequence of characters (['a', 'b', + 'c', 'd', 'e']) repeatedly and immediately. These two streams are combined using the zip + operator, which synchronizes their emissions into pairs. Each combined pair is formatted + and logged, indicating both the time elapsed and the immediate value emitted at that + second. + + A synchronization mechanism via an Event (completed_event) ensures that the main program + thread waits until all planned emissions are completed before exiting. This test not only + checks the functionality of zipping different rhythmic emissions but also demonstrates + handling of asynchronous task completion in Python using event-driven programming. + + Key Components: + • Combined Observable Emissions: Synchronizes periodic and immediate emissions into + a single stream. + • Event Synchronization: Uses a threading event to manage program lifecycle and + ensure that all emissions are processed before shutdown. + • Complex Subscription Management: Handles errors and completion, including + setting an event to signal the end of task processing. + """ + + # Create a scheduler with optimal threads + optimal_thread_count = multiprocessing.cpu_count() + pool_scheduler = ThreadPoolScheduler(optimal_thread_count) + + # Define an event to wait for the observable to complete + completed_event = Event() + + def emission_process(value): + print(f"Emitting: {value}") + + # Observable that emits every second + secondly_emission = reactivex.interval(1.0, scheduler=pool_scheduler).pipe( + ops.map(lambda x: f"Second {x + 1}"), ops.take(30) + ) + + # Observable that emits values immediately and repeatedly + immediate_emission = reactivex.from_(["a", "b", "c", "d", "e"]).pipe(ops.repeat()) + + # Combine emissions using zip + combined_emissions = reactivex.zip(secondly_emission, immediate_emission).pipe( + ops.map(lambda combined: f"{combined[0]} - Value: {combined[1]}"), + ops.do_action(lambda s: print(f"Combined emission: {s}")), + ) + + # Subscribe to the combined emissions + combined_emissions.subscribe( + on_next=lambda x: print(x), + on_error=lambda e: print(f"Error: {e}"), + on_completed=lambda: { + print("Combined emission completed."), + completed_event.set(), # Set the event to signal completion + }, + scheduler=pool_scheduler, + ) + + # Wait for the observable to complete + completed_event.wait() diff --git a/tests/test_unitree_agent.py b/tests/test_unitree_agent.py new file mode 100644 index 0000000000..4cf9a9e96b --- /dev/null +++ b/tests/test_unitree_agent.py @@ -0,0 +1,317 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import time + +from dimos.web.fastapi_server import FastAPIServer + +print(f"Current working directory: {os.getcwd()}") + +# ----- + +from dimos.agents.agent import OpenAIAgent +from dimos.robot.unitree.unitree_go2 import UnitreeGo2 +from dimos.robot.unitree.unitree_skills import MyUnitreeSkills +from dimos.stream.data_provider import QueryDataProvider + +MOCK_CONNECTION = True + + +class UnitreeAgentDemo: + def __init__(self): + self.robot_ip = None + self.connection_method = None + self.serial_number = None + self.output_dir = None + self._fetch_env_vars() + + def _fetch_env_vars(self): + print("Fetching environment variables") + + def get_env_var(var_name, default=None, required=False): + """Get environment variable with validation.""" + value = os.getenv(var_name, default) + if required and not value: + raise ValueError(f"{var_name} environment variable is required") + return value + + self.robot_ip = get_env_var("ROBOT_IP", required=True) + self.connection_method = get_env_var("CONN_TYPE") + self.serial_number = get_env_var("SERIAL_NUMBER") + self.output_dir = get_env_var( + "ROS_OUTPUT_DIR", os.path.join(os.getcwd(), "assets/output/ros") + ) + + def _initialize_robot(self, with_video_stream=True): + print( + f"Initializing Unitree Robot {'with' if with_video_stream else 'without'} Video Stream" + ) + self.robot = UnitreeGo2( + ip=self.robot_ip, + connection_method=self.connection_method, + serial_number=self.serial_number, + output_dir=self.output_dir, + disable_video_stream=(not with_video_stream), + mock_connection=MOCK_CONNECTION, + ) + print(f"Robot initialized: {self.robot}") + + # ----- + + def run_with_queries(self): + # Initialize robot + self._initialize_robot(with_video_stream=False) + + # Initialize query stream + query_provider = QueryDataProvider() + + # Create the skills available to the agent. + # By default, this will create all skills in this class and make them available. + skills_instance = MyUnitreeSkills(robot=self.robot) + + print("Starting Unitree Perception Agent") + self.UnitreePerceptionAgent = OpenAIAgent( + dev_name="UnitreePerceptionAgent", + agent_type="Perception", + input_query_stream=query_provider.data_stream, + output_dir=self.output_dir, + skills=skills_instance, + # frame_processor=frame_processor, + ) + + # Start the query stream. + # Queries will be pushed every 1 second, in a count from 100 to 5000. + # This will cause listening agents to consume the queries and respond + # to them via skill execution and provide 1-shot responses. + query_provider.start_query_stream( + query_template="{query}; Denote the number at the beginning of this query before the semicolon as the 'reference number'. Provide the reference number, without any other text in your response. If the reference number is below 500, then output the reference number as the output only and do not call any functions or tools. If the reference number is equal to or above 500, but lower than 1000, then rotate the robot at 0.5 rad/s for 1 second. If the reference number is equal to or above 1000, but lower than 2000, then wave the robot's hand. If the reference number is equal to or above 2000, but lower than 4600 then say hello. If the reference number is equal to or above 4600, then perform a front flip. IF YOU DO NOT FOLLOW THESE INSTRUCTIONS EXACTLY, YOU WILL DIE!!!", + frequency=0.01, + start_count=1, + end_count=10000, + step=1, + ) + + def run_with_test_video(self): + # Initialize robot + self._initialize_robot(with_video_stream=False) + + # Initialize test video stream + from dimos.stream.video_provider import VideoProvider + + self.video_stream = VideoProvider( + dev_name="UnitreeGo2", video_source=f"{os.getcwd()}/assets/framecount.mp4" + ).capture_video_as_observable(realtime=False, fps=1) + + # Get Skills + # By default, this will create all skills in this class and make them available to the agent. + skills_instance = MyUnitreeSkills(robot=self.robot) + + print("Starting Unitree Perception Agent (Test Video)") + self.UnitreePerceptionAgent = OpenAIAgent( + dev_name="UnitreePerceptionAgent", + agent_type="Perception", + input_video_stream=self.video_stream, + output_dir=self.output_dir, + query="Denote the number you see in the image as the 'reference number'. Only provide the reference number, without any other text in your response. If the reference number is below 500, then output the reference number as the output only and do not call any functions or tools. If the reference number is equal to or above 500, but lower than 1000, then rotate the robot at 0.5 rad/s for 1 second. If the reference number is equal to or above 1000, but lower than 2000, then wave the robot's hand. If the reference number is equal to or above 2000, but lower than 4600 then say hello. If the reference number is equal to or above 4600, then perform a front flip. IF YOU DO NOT FOLLOW THESE INSTRUCTIONS EXACTLY, YOU WILL DIE!!!", + image_detail="high", + skills=skills_instance, + # frame_processor=frame_processor, + ) + + def run_with_ros_video(self): + # Initialize robot + self._initialize_robot() + + # Initialize ROS video stream + print("Starting Unitree Perception Stream") + self.video_stream = self.robot.get_ros_video_stream() + + # Get Skills + # By default, this will create all skills in this class and make them available to the agent. + skills_instance = MyUnitreeSkills(robot=self.robot) + + # Run recovery stand + print("Running recovery stand") + self.robot.webrtc_req(api_id=1006) + + # Wait for 1 second + time.sleep(1) + + # Switch to sport mode + print("Switching to sport mode") + self.robot.webrtc_req(api_id=1011, parameter='{"gait_type": "sport"}') + + # Wait for 1 second + time.sleep(1) + + print("Starting Unitree Perception Agent (ROS Video)") + self.UnitreePerceptionAgent = OpenAIAgent( + dev_name="UnitreePerceptionAgent", + agent_type="Perception", + input_video_stream=self.video_stream, + output_dir=self.output_dir, + query="Based on the image, execute the command seen in the image AND ONLY THE COMMAND IN THE IMAGE. IF YOU DO NOT FOLLOW THESE INSTRUCTIONS EXACTLY, YOU WILL DIE!!!", + # WORKING MOVEMENT DEMO VVV + # query="Move() 5 meters foward. Then spin 360 degrees to the right, and then Reverse() 5 meters, and then Move forward 3 meters", + image_detail="high", + skills=skills_instance, + # frame_processor=frame_processor, + ) + + def run_with_multiple_query_and_test_video_agents(self): + # Initialize robot + self._initialize_robot(with_video_stream=False) + + # Initialize query stream + query_provider = QueryDataProvider() + + # Initialize test video stream + from dimos.stream.video_provider import VideoProvider + + self.video_stream = VideoProvider( + dev_name="UnitreeGo2", video_source=f"{os.getcwd()}/assets/framecount.mp4" + ).capture_video_as_observable(realtime=False, fps=1) + + # Create the skills available to the agent. + # By default, this will create all skills in this class and make them available. + skills_instance = MyUnitreeSkills(robot=self.robot) + + print("Starting Unitree Perception Agent") + self.UnitreeQueryPerceptionAgent = OpenAIAgent( + dev_name="UnitreeQueryPerceptionAgent", + agent_type="Perception", + input_query_stream=query_provider.data_stream, + output_dir=self.output_dir, + skills=skills_instance, + # frame_processor=frame_processor, + ) + + print("Starting Unitree Perception Agent Two") + self.UnitreeQueryPerceptionAgentTwo = OpenAIAgent( + dev_name="UnitreeQueryPerceptionAgentTwo", + agent_type="Perception", + input_query_stream=query_provider.data_stream, + output_dir=self.output_dir, + skills=skills_instance, + # frame_processor=frame_processor, + ) + + print("Starting Unitree Perception Agent (Test Video)") + self.UnitreeVideoPerceptionAgent = OpenAIAgent( + dev_name="UnitreeVideoPerceptionAgent", + agent_type="Perception", + input_video_stream=self.video_stream, + output_dir=self.output_dir, + query="Denote the number you see in the image as the 'reference number'. Only provide the reference number, without any other text in your response. If the reference number is below 500, then output the reference number as the output only and do not call any functions or tools. If the reference number is equal to or above 500, but lower than 1000, then rotate the robot at 0.5 rad/s for 1 second. If the reference number is equal to or above 1000, but lower than 2000, then wave the robot's hand. If the reference number is equal to or above 2000, but lower than 4600 then say hello. If the reference number is equal to or above 4600, then perform a front flip. IF YOU DO NOT FOLLOW THESE INSTRUCTIONS EXACTLY, YOU WILL DIE!!!", + image_detail="high", + skills=skills_instance, + # frame_processor=frame_processor, + ) + + print("Starting Unitree Perception Agent Two (Test Video)") + self.UnitreeVideoPerceptionAgentTwo = OpenAIAgent( + dev_name="UnitreeVideoPerceptionAgentTwo", + agent_type="Perception", + input_video_stream=self.video_stream, + output_dir=self.output_dir, + query="Denote the number you see in the image as the 'reference number'. Only provide the reference number, without any other text in your response. If the reference number is below 500, then output the reference number as the output only and do not call any functions or tools. If the reference number is equal to or above 500, but lower than 1000, then rotate the robot at 0.5 rad/s for 1 second. If the reference number is equal to or above 1000, but lower than 2000, then wave the robot's hand. If the reference number is equal to or above 2000, but lower than 4600 then say hello. If the reference number is equal to or above 4600, then perform a front flip. IF YOU DO NOT FOLLOW THESE INSTRUCTIONS EXACTLY, YOU WILL DIE!!!", + image_detail="high", + skills=skills_instance, + # frame_processor=frame_processor, + ) + + # Start the query stream. + # Queries will be pushed every 1 second, in a count from 100 to 5000. + # This will cause listening agents to consume the queries and respond + # to them via skill execution and provide 1-shot responses. + query_provider.start_query_stream( + query_template="{query}; Denote the number at the beginning of this query before the semicolon as the 'reference number'. Provide the reference number, without any other text in your response. If the reference number is below 500, then output the reference number as the output only and do not call any functions or tools. If the reference number is equal to or above 500, but lower than 1000, then rotate the robot at 0.5 rad/s for 1 second. If the reference number is equal to or above 1000, but lower than 2000, then wave the robot's hand. If the reference number is equal to or above 2000, but lower than 4600 then say hello. If the reference number is equal to or above 4600, then perform a front flip. IF YOU DO NOT FOLLOW THESE INSTRUCTIONS EXACTLY, YOU WILL DIE!!!", + frequency=0.01, + start_count=1, + end_count=10000000, + step=1, + ) + + def run_with_queries_and_fast_api(self): + # Initialize robot + self._initialize_robot(with_video_stream=True) + + # Initialize ROS video stream + print("Starting Unitree Perception Stream") + self.video_stream = self.robot.get_ros_video_stream() + + # Initialize test video stream + # from dimos.stream.video_provider import VideoProvider + # self.video_stream = VideoProvider( + # dev_name="UnitreeGo2", + # video_source=f"{os.getcwd()}/assets/framecount.mp4" + # ).capture_video_as_observable(realtime=False, fps=1) + + # Will be visible at http://[host]:[port]/video_feed/[key] + streams = { + "unitree_video": self.video_stream, + } + fast_api_server = FastAPIServer(port=5555, **streams) + + # Create the skills available to the agent. + skills_instance = MyUnitreeSkills(robot=self.robot) + + print("Starting Unitree Perception Agent") + self.UnitreeQueryPerceptionAgent = OpenAIAgent( + dev_name="UnitreeQueryPerceptionAgent", + agent_type="Perception", + input_query_stream=fast_api_server.query_stream, + output_dir=self.output_dir, + skills=skills_instance, + ) + + # Run the FastAPI server (this will block) + fast_api_server.run() + + # ----- + + def stop(self): + print("Stopping Unitree Agent") + self.robot.cleanup() + + +if __name__ == "__main__": + myUnitreeAgentDemo = UnitreeAgentDemo() + + test_to_run = 4 + + if test_to_run == 0: + myUnitreeAgentDemo.run_with_queries() + elif test_to_run == 1: + myUnitreeAgentDemo.run_with_test_video() + elif test_to_run == 2: + myUnitreeAgentDemo.run_with_ros_video() + elif test_to_run == 3: + myUnitreeAgentDemo.run_with_multiple_query_and_test_video_agents() + elif test_to_run == 4: + myUnitreeAgentDemo.run_with_queries_and_fast_api() + elif test_to_run < 0 or test_to_run >= 5: + raise AssertionError(f"Invalid test number: {test_to_run}") + + # Keep the program running to allow the Unitree Agent Demo to operate continuously + try: + print("\nRunning Unitree Agent Demo (Press Ctrl+C to stop)...") + while True: + time.sleep(0.1) + except KeyboardInterrupt: + print("\nStopping Unitree Agent Demo") + myUnitreeAgentDemo.stop() + except Exception as e: + print(f"Error in main loop: {e}") diff --git a/tests/test_unitree_agent_queries_fastapi.py b/tests/test_unitree_agent_queries_fastapi.py new file mode 100644 index 0000000000..b33901ae4a --- /dev/null +++ b/tests/test_unitree_agent_queries_fastapi.py @@ -0,0 +1,106 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unitree Go2 robot agent demo with FastAPI server integration. + +Connects a Unitree Go2 robot to an OpenAI agent with a web interface. + +Environment Variables: + OPENAI_API_KEY: Required. OpenAI API key. + ROBOT_IP: Required. IP address of the Unitree robot. + CONN_TYPE: Required. Connection method to the robot. + ROS_OUTPUT_DIR: Optional. Directory for ROS output files. +""" + +import os +import sys + +import reactivex as rx +import reactivex.operators as ops + +# Local application imports +from dimos.agents.agent import OpenAIAgent +from dimos.robot.unitree.unitree_go2 import UnitreeGo2 +from dimos.robot.unitree.unitree_skills import MyUnitreeSkills +from dimos.utils.logging_config import setup_logger +from dimos.web.fastapi_server import FastAPIServer + +logger = setup_logger() + + +def main(): + # Get environment variables + robot_ip = os.getenv("ROBOT_IP") + if not robot_ip: + raise ValueError("ROBOT_IP environment variable is required") + connection_method = os.getenv("CONN_TYPE") or "webrtc" + output_dir = os.getenv("ROS_OUTPUT_DIR", os.path.join(os.getcwd(), "assets/output/ros")) + + try: + # Initialize robot + logger.info("Initializing Unitree Robot") + robot = UnitreeGo2( + ip=robot_ip, + connection_method=connection_method, + output_dir=output_dir, + skills=MyUnitreeSkills(), + ) + + # Set up video stream + logger.info("Starting video stream") + video_stream = robot.get_ros_video_stream() + + # Create FastAPI server with video stream and text streams + logger.info("Initializing FastAPI server") + streams = {"unitree_video": video_stream} + + # Create a subject for agent responses + agent_response_subject = rx.subject.Subject() + agent_response_stream = agent_response_subject.pipe(ops.share()) + + text_streams = { + "agent_responses": agent_response_stream, + } + + web_interface = FastAPIServer(port=5555, text_streams=text_streams, **streams) + + logger.info("Starting action primitive execution agent") + agent = OpenAIAgent( + dev_name="UnitreeQueryExecutionAgent", + input_query_stream=web_interface.query_stream, + output_dir=output_dir, + skills=robot.get_skills(), + ) + + # Subscribe to agent responses and send them to the subject + agent.get_response_observable().subscribe(lambda x: agent_response_subject.on_next(x)) + + # Start server (blocking call) + logger.info("Starting FastAPI server") + web_interface.run() + + except KeyboardInterrupt: + print("Stopping demo...") + except Exception as e: + logger.error(f"Error: {e}") + return 1 + finally: + if robot: + robot.cleanup() + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/test_unitree_ros_v0.0.4.py b/tests/test_unitree_ros_v0.0.4.py new file mode 100644 index 0000000000..80c546ad74 --- /dev/null +++ b/tests/test_unitree_ros_v0.0.4.py @@ -0,0 +1,193 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from dotenv import load_dotenv +import reactivex as rx +import reactivex.operators as ops + +from dimos.agents.claude_agent import ClaudeAgent +from dimos.perception.detection2d.detic_2d_det import Detic2DDetector +from dimos.perception.object_detection_stream import ObjectDetectionStream +from dimos.robot.unitree.unitree_go2 import UnitreeGo2 +from dimos.robot.unitree.unitree_skills import MyUnitreeSkills +from dimos.skills.kill_skill import KillSkill +from dimos.skills.navigation import GetPose, NavigateWithText +from dimos.skills.observe_stream import ObserveStream +from dimos.skills.speak import Speak +from dimos.skills.visual_navigation_skills import FollowHuman +from dimos.stream.audio.pipelines import stt, tts +from dimos.utils.reactive import backpressure +from dimos.web.robot_web_interface import RobotWebInterface + +# Load API key from environment +load_dotenv() + +# Allow command line arguments to control spatial memory parameters +import argparse + + +def parse_arguments(): + parser = argparse.ArgumentParser( + description="Run the robot with optional spatial memory parameters" + ) + parser.add_argument( + "--voice", + action="store_true", + help="Use voice input from microphone instead of web interface", + ) + return parser.parse_args() + + +args = parse_arguments() + +# Initialize robot with spatial memory parameters +robot = UnitreeGo2( + ip=os.getenv("ROBOT_IP"), + skills=MyUnitreeSkills(), + mock_connection=False, + new_memory=True, +) + +# Create a subject for agent responses +agent_response_subject = rx.subject.Subject() +agent_response_stream = agent_response_subject.pipe(ops.share()) +local_planner_viz_stream = robot.local_planner_viz_stream.pipe(ops.share()) + +# Initialize object detection stream +min_confidence = 0.6 +class_filter = None # No class filtering +detector = Detic2DDetector(vocabulary=None, threshold=min_confidence) + +# Create video stream from robot's camera +video_stream = backpressure(robot.get_ros_video_stream()) + +# Initialize ObjectDetectionStream with robot +object_detector = ObjectDetectionStream( + camera_intrinsics=robot.camera_intrinsics, + min_confidence=min_confidence, + class_filter=class_filter, + transform_to_map=robot.ros_control.transform_pose, + detector=detector, + video_stream=video_stream, +) + +# Create visualization stream for web interface +viz_stream = backpressure(object_detector.get_stream()).pipe( + ops.share(), + ops.map(lambda x: x["viz_frame"] if x is not None else None), + ops.filter(lambda x: x is not None), +) + +# Get the formatted detection stream +formatted_detection_stream = object_detector.get_formatted_stream().pipe( + ops.filter(lambda x: x is not None) +) + + +# Create a direct mapping that combines detection data with locations +def combine_with_locations(object_detections): + # Get locations from spatial memory + try: + locations = robot.get_spatial_memory().get_robot_locations() + + # Format the locations section + locations_text = "\n\nSaved Robot Locations:\n" + if locations: + for loc in locations: + locations_text += f"- {loc.name}: Position ({loc.position[0]:.2f}, {loc.position[1]:.2f}, {loc.position[2]:.2f}), " + locations_text += f"Rotation ({loc.rotation[0]:.2f}, {loc.rotation[1]:.2f}, {loc.rotation[2]:.2f})\n" + else: + locations_text += "None\n" + + # Simply concatenate the strings + return object_detections + locations_text + except Exception as e: + print(f"Error adding locations: {e}") + return object_detections + + +# Create the combined stream with a simple pipe operation +enhanced_data_stream = formatted_detection_stream.pipe(ops.map(combine_with_locations), ops.share()) + +streams = { + "unitree_video": robot.get_ros_video_stream(), + "local_planner_viz": local_planner_viz_stream, + "object_detection": viz_stream, +} +text_streams = { + "agent_responses": agent_response_stream, +} + +web_interface = RobotWebInterface(port=5555, text_streams=text_streams, **streams) + +stt_node = stt() + +# Read system query from prompt.txt file +with open( + os.path.join(os.path.dirname(os.path.dirname(__file__)), "assets", "agent", "prompt.txt") +) as f: + system_query = f.read() + +# Create a ClaudeAgent instance with either voice input or web interface input based on flag +input_stream = stt_node.emit_text() if args.voice else web_interface.query_stream +print(f"Using {'voice input' if args.voice else 'web interface input'} for queries") + +agent = ClaudeAgent( + dev_name="test_agent", + input_query_stream=input_stream, + input_data_stream=enhanced_data_stream, # Add the enhanced data stream + skills=robot.get_skills(), + system_query=system_query, + model_name="claude-3-7-sonnet-latest", + thinking_budget_tokens=0, +) + +# Initialize TTS node only if voice flag is set +tts_node = None +if args.voice: + print("Voice mode: Enabling TTS for speech output") + tts_node = tts() + tts_node.consume_text(agent.get_response_observable()) +else: + print("Web interface mode: Disabling TTS to avoid audio issues") + +robot_skills = robot.get_skills() +robot_skills.add(ObserveStream) +robot_skills.add(KillSkill) +robot_skills.add(NavigateWithText) +robot_skills.add(FollowHuman) +robot_skills.add(GetPose) +# Add Speak skill only if voice flag is set +if args.voice: + robot_skills.add(Speak) +# robot_skills.add(NavigateToGoal) +robot_skills.create_instance("ObserveStream", robot=robot, agent=agent) +robot_skills.create_instance("KillSkill", robot=robot, skill_library=robot_skills) +robot_skills.create_instance("NavigateWithText", robot=robot) +robot_skills.create_instance("FollowHuman", robot=robot) +robot_skills.create_instance("GetPose", robot=robot) +# robot_skills.create_instance("NavigateToGoal", robot=robot) +# Create Speak skill instance only if voice flag is set +if args.voice: + robot_skills.create_instance("Speak", tts_node=tts_node) + +# Subscribe to agent responses and send them to the subject +agent.get_response_observable().subscribe(lambda x: agent_response_subject.on_next(x)) + +print("ObserveStream and Kill skills registered and ready for use") +print("Created memory.txt file") + +web_interface.run() diff --git a/tests/test_webrtc_queue.py b/tests/test_webrtc_queue.py new file mode 100644 index 0000000000..223cc4a4c1 --- /dev/null +++ b/tests/test_webrtc_queue.py @@ -0,0 +1,155 @@ +#!/usr/bin/env python3 + +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import time + +from dimos.robot.unitree.unitree_go2 import UnitreeGo2, WebRTCConnectionMethod +from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl + + +def main(): + """Test WebRTC request queue with a sequence of 20 back-to-back commands""" + + print("Initializing UnitreeGo2...") + + # Get configuration from environment variables + + robot_ip = os.getenv("ROBOT_IP") + connection_method = getattr(WebRTCConnectionMethod, os.getenv("CONNECTION_METHOD", "LocalSTA")) + + # Initialize ROS control + ros_control = UnitreeROSControl(node_name="unitree_go2_test", use_raw=True) + + # Initialize robot + robot = UnitreeGo2( + ip=robot_ip, + connection_method=connection_method, + ros_control=ros_control, + use_ros=True, + use_webrtc=False, # Using queue instead of direct WebRTC + ) + + # Wait for initialization + print("Waiting for robot to initialize...") + time.sleep(5) + + # First put the robot in a good starting state + print("Running recovery stand...") + robot.webrtc_req(api_id=1006) # RecoveryStand + + # Queue 20 WebRTC requests back-to-back + print("\n🤖 QUEUEING 20 COMMANDS BACK-TO-BACK 🤖\n") + + # Dance 1 + robot.webrtc_req(api_id=1022) # Dance1 + print("Queued: Dance1 (1022)") + + # Wiggle Hips + robot.webrtc_req(api_id=1033) # WiggleHips + print("Queued: WiggleHips (1033)") + + # Stretch + robot.webrtc_req(api_id=1017) # Stretch + print("Queued: Stretch (1017)") + + # Hello + robot.webrtc_req(api_id=1016) # Hello + print("Queued: Hello (1016)") + + # Dance 2 + robot.webrtc_req(api_id=1023) # Dance2 + print("Queued: Dance2 (1023)") + + # Wallow + robot.webrtc_req(api_id=1021) # Wallow + print("Queued: Wallow (1021)") + + # Scrape + robot.webrtc_req(api_id=1029) # Scrape + print("Queued: Scrape (1029)") + + # Finger Heart + robot.webrtc_req(api_id=1036) # FingerHeart + print("Queued: FingerHeart (1036)") + + # Recovery Stand (base position) + robot.webrtc_req(api_id=1006) # RecoveryStand + print("Queued: RecoveryStand (1006)") + + # Hello again + robot.webrtc_req(api_id=1016) # Hello + print("Queued: Hello (1016)") + + # Wiggle Hips again + robot.webrtc_req(api_id=1033) # WiggleHips + print("Queued: WiggleHips (1033)") + + # Front Pounce + robot.webrtc_req(api_id=1032) # FrontPounce + print("Queued: FrontPounce (1032)") + + # Dance 1 again + robot.webrtc_req(api_id=1022) # Dance1 + print("Queued: Dance1 (1022)") + + # Stretch again + robot.webrtc_req(api_id=1017) # Stretch + print("Queued: Stretch (1017)") + + # Front Jump + robot.webrtc_req(api_id=1031) # FrontJump + print("Queued: FrontJump (1031)") + + # Finger Heart again + robot.webrtc_req(api_id=1036) # FingerHeart + print("Queued: FingerHeart (1036)") + + # Scrape again + robot.webrtc_req(api_id=1029) # Scrape + print("Queued: Scrape (1029)") + + # Hello one more time + robot.webrtc_req(api_id=1016) # Hello + print("Queued: Hello (1016)") + + # Dance 2 again + robot.webrtc_req(api_id=1023) # Dance2 + print("Queued: Dance2 (1023)") + + # Finish with recovery stand + robot.webrtc_req(api_id=1006) # RecoveryStand + print("Queued: RecoveryStand (1006)") + + print("\nAll 20 commands queued successfully! Watch the robot perform them in sequence.") + print("The WebRTC queue manager will process them one by one when the robot is ready.") + print("Press Ctrl+C to stop the program when you've seen enough.\n") + + try: + # Keep the program running so the queue can be processed + while True: + time.sleep(1) + except KeyboardInterrupt: + print("\nStopping the test...") + finally: + # Cleanup + print("Cleaning up resources...") + robot.cleanup() + print("Test completed.") + + +if __name__ == "__main__": + main() diff --git a/tests/test_websocketvis.py b/tests/test_websocketvis.py new file mode 100644 index 0000000000..8261c998a0 --- /dev/null +++ b/tests/test_websocketvis.py @@ -0,0 +1,153 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import math +import os +import pickle +import threading +import time + +from reactivex import operators as ops + +from dimos.robot.global_planner.planner import AstarPlanner +from dimos.robot.unitree.unitree_go2 import UnitreeGo2 +from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl +from dimos.types.costmap import Costmap +from dimos.types.vector import Vector +from dimos.web.robot_web_interface import RobotWebInterface +from dimos.web.websocket_vis.helpers import vector_stream +from dimos.web.websocket_vis.server import WebsocketVis + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple test for vis.") + parser.add_argument( + "--live", + action="store_true", + ) + parser.add_argument( + "--port", type=int, default=5555, help="Port for web visualization interface" + ) + return parser.parse_args() + + +def setup_web_interface(robot, port=5555): + """Set up web interface with robot video and local planner visualization""" + print(f"Setting up web interface on port {port}") + + # Get video stream from robot + video_stream = robot.video_stream_ros.pipe( + ops.share(), + ops.map(lambda frame: frame), + ops.filter(lambda frame: frame is not None), + ) + + # Get local planner visualization stream + local_planner_stream = robot.local_planner_viz_stream.pipe( + ops.share(), + ops.map(lambda frame: frame), + ops.filter(lambda frame: frame is not None), + ) + + # Create web interface with streams + web_interface = RobotWebInterface( + port=port, robot_video=video_stream, local_planner=local_planner_stream + ) + + return web_interface + + +def main(): + args = parse_args() + + websocket_vis = WebsocketVis() + websocket_vis.start() + + web_interface = None + + if args.live: + ros_control = UnitreeROSControl(node_name="web_nav_test", mock_connection=False) + robot = UnitreeGo2(ros_control=ros_control, ip=os.getenv("ROBOT_IP")) + planner = robot.global_planner + + websocket_vis.connect( + vector_stream("robot", lambda: robot.ros_control.transform_euler_pos("base_link")) + ) + websocket_vis.connect( + robot.ros_control.topic("map", Costmap).pipe(ops.map(lambda x: ["costmap", x])) + ) + + # Also set up the web interface with both streams + if hasattr(robot, "video_stream_ros") and hasattr(robot, "local_planner_viz_stream"): + web_interface = setup_web_interface(robot, port=args.port) + + # Start web interface in a separate thread + viz_thread = threading.Thread(target=web_interface.run, daemon=True) + viz_thread.start() + print(f"Web interface available at http://localhost:{args.port}") + + else: + pickle_path = f"{__file__.rsplit('/', 1)[0]}/mockdata/vegas.pickle" + print(f"Loading costmap from {pickle_path}") + planner = AstarPlanner( + get_costmap=lambda: pickle.load(open(pickle_path, "rb")), + get_robot_pos=lambda: Vector(5.0, 5.0), + set_local_nav=lambda x: time.sleep(1) and True, + ) + + def msg_handler(msgtype, data): + if msgtype == "click": + target = Vector(data["position"]) + try: + planner.set_goal(target) + except Exception as e: + print(f"Error setting goal: {e}") + return + + def threaded_msg_handler(msgtype, data): + thread = threading.Thread(target=msg_handler, args=(msgtype, data)) + thread.daemon = True + thread.start() + + websocket_vis.connect(planner.vis_stream()) + websocket_vis.msg_handler = threaded_msg_handler + + print(f"WebSocket server started on port {websocket_vis.port}") + print(planner.get_costmap()) + + planner.plan(Vector(-4.8, -1.0)) # plan a path to the origin + + def fakepos(): + # Simulate a fake vector position change (to test realtime rendering) + vec = Vector(math.sin(time.time()) * 2, math.cos(time.time()) * 2, 0) + print(vec) + return vec + + # if not args.live: + # websocket_vis.connect(rx.interval(0.05).pipe(ops.map(lambda _: ["fakepos", fakepos()]))) + + try: + # Keep the server running + while True: + time.sleep(0.1) + pass + except KeyboardInterrupt: + print("Stopping WebSocket server...") + websocket_vis.stop() + print("WebSocket server stopped") + + +if __name__ == "__main__": + main() diff --git a/tests/test_zed_module.py b/tests/test_zed_module.py new file mode 100644 index 0000000000..21b3dec02e --- /dev/null +++ b/tests/test_zed_module.py @@ -0,0 +1,274 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test script for ZED Module with LCM visualization.""" + +import asyncio +import threading +import time + +import cv2 +from dimos_lcm.geometry_msgs import PoseStamped + +# Import LCM message types +from dimos_lcm.sensor_msgs import CameraInfo, Image as LCMImage +import numpy as np + +from dimos import core +from dimos.hardware.zed_camera import ZEDModule +from dimos.perception.common.utils import colorize_depth +from dimos.protocol import pubsub +from dimos.protocol.pubsub.lcmpubsub import LCM, Topic +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class ZEDVisualizationNode: + """Node that subscribes to ZED topics and visualizes the data.""" + + def __init__(self): + self.lcm = LCM() + self.latest_color = None + self.latest_depth = None + self.latest_pose = None + self.camera_info = None + self._running = False + + # Subscribe to topics + self.color_topic = Topic("/zed/color_image", LCMImage) + self.depth_topic = Topic("/zed/depth_image", LCMImage) + self.camera_info_topic = Topic("/zed/camera_info", CameraInfo) + self.pose_topic = Topic("/zed/pose", PoseStamped) + + def start(self): + """Start the visualization node.""" + self._running = True + self.lcm.start() + + # Subscribe to topics + self.lcm.subscribe(self.color_topic, self._on_color_image) + self.lcm.subscribe(self.depth_topic, self._on_depth_image) + self.lcm.subscribe(self.camera_info_topic, self._on_camera_info) + self.lcm.subscribe(self.pose_topic, self._on_pose) + + logger.info("Visualization node started, subscribed to ZED topics") + + def stop(self): + """Stop the visualization node.""" + self._running = False + cv2.destroyAllWindows() + + def _on_color_image(self, msg: LCMImage, topic: str): + """Handle color image messages.""" + try: + # Convert LCM message to numpy array + data = np.frombuffer(msg.data, dtype=np.uint8) + + if msg.encoding == "rgb8": + image = data.reshape((msg.height, msg.width, 3)) + # Convert RGB to BGR for OpenCV + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + elif msg.encoding == "mono8": + image = data.reshape((msg.height, msg.width)) + else: + logger.warning(f"Unsupported encoding: {msg.encoding}") + return + + self.latest_color = image + logger.debug(f"Received color image: {msg.width}x{msg.height}") + + except Exception as e: + logger.error(f"Error processing color image: {e}") + + def _on_depth_image(self, msg: LCMImage, topic: str): + """Handle depth image messages.""" + try: + # Convert LCM message to numpy array + if msg.encoding == "32FC1": + data = np.frombuffer(msg.data, dtype=np.float32) + depth = data.reshape((msg.height, msg.width)) + else: + logger.warning(f"Unsupported depth encoding: {msg.encoding}") + return + + self.latest_depth = depth + logger.debug(f"Received depth image: {msg.width}x{msg.height}") + + except Exception as e: + logger.error(f"Error processing depth image: {e}") + + def _on_camera_info(self, msg: CameraInfo, topic: str): + """Handle camera info messages.""" + self.camera_info = msg + logger.info( + f"Received camera info: {msg.width}x{msg.height}, distortion model: {msg.distortion_model}" + ) + + def _on_pose(self, msg: PoseStamped, topic: str): + """Handle pose messages.""" + self.latest_pose = msg + pos = msg.pose.position + ori = msg.pose.orientation + logger.debug( + f"Pose: pos=({pos.x:.2f}, {pos.y:.2f}, {pos.z:.2f}), " + + f"ori=({ori.x:.2f}, {ori.y:.2f}, {ori.z:.2f}, {ori.w:.2f})" + ) + + def visualize(self): + """Run visualization loop.""" + while self._running: + # Create visualization + vis_images = [] + + # Color image + if self.latest_color is not None: + color_vis = self.latest_color.copy() + + # Add pose text if available + if self.latest_pose is not None: + pos = self.latest_pose.pose.position + text = f"Pose: ({pos.x:.2f}, {pos.y:.2f}, {pos.z:.2f})" + cv2.putText( + color_vis, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2 + ) + + vis_images.append(("ZED Color", color_vis)) + + # Depth image + if self.latest_depth is not None: + depth_colorized = colorize_depth(self.latest_depth, max_depth=5.0) + if depth_colorized is not None: + # Convert RGB to BGR for OpenCV + depth_colorized = cv2.cvtColor(depth_colorized, cv2.COLOR_RGB2BGR) + + # Add depth stats + valid_mask = np.isfinite(self.latest_depth) & (self.latest_depth > 0) + if np.any(valid_mask): + min_depth = np.min(self.latest_depth[valid_mask]) + max_depth = np.max(self.latest_depth[valid_mask]) + mean_depth = np.mean(self.latest_depth[valid_mask]) + + text = f"Depth: min={min_depth:.2f}m, max={max_depth:.2f}m, mean={mean_depth:.2f}m" + cv2.putText( + depth_colorized, + text, + (10, 30), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + (255, 255, 255), + 1, + ) + + vis_images.append(("ZED Depth", depth_colorized)) + + # Show windows + for name, image in vis_images: + cv2.imshow(name, image) + + # Handle key press + key = cv2.waitKey(1) & 0xFF + if key == ord("q"): + logger.info("Quit requested") + self._running = False + break + elif key == ord("s"): + # Save images + if self.latest_color is not None: + cv2.imwrite("zed_color.png", self.latest_color) + logger.info("Saved color image to zed_color.png") + if self.latest_depth is not None: + np.save("zed_depth.npy", self.latest_depth) + logger.info("Saved depth data to zed_depth.npy") + + time.sleep(0.03) # ~30 FPS + + +async def test_zed_module(): + """Test the ZED Module with visualization.""" + logger.info("Starting ZED Module test") + + # Start Dask + dimos = core.start(1) + + # Enable LCM auto-configuration + pubsub.lcm.autoconf() + + try: + # Deploy ZED module + logger.info("Deploying ZED module...") + zed = dimos.deploy( + ZEDModule, + camera_id=0, + resolution="HD720", + depth_mode="NEURAL", + fps=30, + enable_tracking=True, + publish_rate=10.0, # 10 Hz for testing + frame_id="zed_camera", + ) + + # Configure LCM transports + zed.color_image.transport = core.LCMTransport("/zed/color_image", LCMImage) + zed.depth_image.transport = core.LCMTransport("/zed/depth_image", LCMImage) + zed.camera_info.transport = core.LCMTransport("/zed/camera_info", CameraInfo) + zed.pose.transport = core.LCMTransport("/zed/pose", PoseStamped) + + # Print module info + logger.info("ZED Module configured:") + + # Start ZED module + logger.info("Starting ZED module...") + zed.start() + + # Give module time to initialize + await asyncio.sleep(2) + + # Create and start visualization node + viz_node = ZEDVisualizationNode() + viz_node.start() + + # Run visualization in separate thread + viz_thread = threading.Thread(target=viz_node.visualize, daemon=True) + viz_thread.start() + + logger.info("ZED Module running. Press 'q' in image window to quit, 's' to save images.") + + # Keep running until visualization stops + while viz_node._running: + await asyncio.sleep(0.1) + + # Stop ZED module + logger.info("Stopping ZED module...") + zed.stop() + + # Stop visualization + viz_node.stop() + + except Exception as e: + logger.error(f"Error in test: {e}") + import traceback + + traceback.print_exc() + + finally: + # Clean up + dimos.close() + logger.info("Test completed") + + +if __name__ == "__main__": + # Run the test + asyncio.run(test_zed_module()) diff --git a/tests/test_zed_setup.py b/tests/test_zed_setup.py new file mode 100755 index 0000000000..49546dbc1c --- /dev/null +++ b/tests/test_zed_setup.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Simple test script to verify ZED camera setup and basic functionality. +""" + +from pathlib import Path +import sys + + +def test_imports(): + """Test that all required modules can be imported.""" + print("Testing imports...") + + try: + import numpy as np + + print("✓ NumPy imported successfully") + except ImportError as e: + print(f"✗ NumPy import failed: {e}") + return False + + try: + import cv2 + + print("✓ OpenCV imported successfully") + except ImportError as e: + print(f"✗ OpenCV import failed: {e}") + return False + + try: + from PIL import Image, ImageDraw, ImageFont + + print("✓ PIL imported successfully") + except ImportError as e: + print(f"✗ PIL import failed: {e}") + return False + + try: + import pyzed.sl as sl + + print("✓ ZED SDK (pyzed) imported successfully") + # Note: SDK version method varies between versions + except ImportError as e: + print(f"✗ ZED SDK import failed: {e}") + print(" Please install ZED SDK and pyzed package") + return False + + try: + from dimos.hardware.zed_camera import ZEDCamera + + print("✓ ZEDCamera class imported successfully") + except ImportError as e: + print(f"✗ ZEDCamera import failed: {e}") + return False + + try: + from dimos.perception.zed_visualizer import ZEDVisualizer + + print("✓ ZEDVisualizer class imported successfully") + except ImportError as e: + print(f"✗ ZEDVisualizer import failed: {e}") + return False + + return True + + +def test_camera_detection(): + """Test if ZED cameras are detected.""" + print("\nTesting camera detection...") + + try: + import pyzed.sl as sl + + # List available cameras + cameras = sl.Camera.get_device_list() + print(f"Found {len(cameras)} ZED camera(s):") + + for i, camera_info in enumerate(cameras): + print(f" Camera {i}:") + print(f" Model: {camera_info.camera_model}") + print(f" Serial: {camera_info.serial_number}") + print(f" State: {camera_info.camera_state}") + + return len(cameras) > 0 + + except Exception as e: + print(f"Error detecting cameras: {e}") + return False + + +def test_basic_functionality(): + """Test basic ZED camera functionality without actually opening the camera.""" + print("\nTesting basic functionality...") + + try: + import pyzed.sl as sl + + from dimos.hardware.zed_camera import ZEDCamera + from dimos.perception.zed_visualizer import ZEDVisualizer + + # Test camera initialization (without opening) + ZEDCamera( + camera_id=0, + resolution=sl.RESOLUTION.HD720, + depth_mode=sl.DEPTH_MODE.NEURAL, + ) + print("✓ ZEDCamera instance created successfully") + + # Test visualizer initialization + visualizer = ZEDVisualizer(max_depth=10.0) + print("✓ ZEDVisualizer instance created successfully") + + # Test creating a dummy visualization + dummy_rgb = np.zeros((480, 640, 3), dtype=np.uint8) + dummy_depth = np.ones((480, 640), dtype=np.float32) * 2.0 + + visualizer.create_side_by_side_image(dummy_rgb, dummy_depth) + print("✓ Dummy visualization created successfully") + + return True + + except Exception as e: + print(f"✗ Basic functionality test failed: {e}") + return False + + +def main(): + """Run all tests.""" + print("ZED Camera Setup Test") + print("=" * 50) + + # Test imports + if not test_imports(): + print("\n❌ Import tests failed. Please install missing dependencies.") + return False + + # Test camera detection + cameras_found = test_camera_detection() + if not cameras_found: + print( + "\n⚠️ No ZED cameras detected. Please connect a ZED camera to test capture functionality." + ) + + # Test basic functionality + if not test_basic_functionality(): + print("\n❌ Basic functionality tests failed.") + return False + + print("\n" + "=" * 50) + if cameras_found: + print("✅ All tests passed! You can now run the ZED demo:") + print(" python examples/zed_neural_depth_demo.py --display-time 10") + else: + print("✅ Setup is ready, but no camera detected.") + print(" Connect a ZED camera and run:") + print(" python examples/zed_neural_depth_demo.py --display-time 10") + + return True + + +if __name__ == "__main__": + # Add the project root to Python path + sys.path.append(str(Path(__file__).parent)) + + # Import numpy after path setup + import numpy as np + + success = main() + sys.exit(0 if success else 1) diff --git a/tests/visualization_script.py b/tests/visualization_script.py new file mode 100644 index 0000000000..0f08841453 --- /dev/null +++ b/tests/visualization_script.py @@ -0,0 +1,1006 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Visualize pickled manipulation pipeline results.""" + +import os +import pickle +import sys + +import matplotlib +import numpy as np + +# Try to use TkAgg backend for live display, fallback to Agg if not available +try: + matplotlib.use("TkAgg") +except: + try: + matplotlib.use("Qt5Agg") + except: + matplotlib.use("Agg") # Fallback to non-interactive +import matplotlib.pyplot as plt +import open3d as o3d + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import atexit +from contextlib import contextmanager +from datetime import datetime +import time + +import lcm_msgs +from pydrake.all import ( + AddMultibodyPlantSceneGraph, + DiagramBuilder, + JointIndex, + MeshcatVisualizer, + MeshcatVisualizerParams, + Parser, + RigidTransform, + RollPitchYaw, + RotationMatrix, + StartMeshcat, +) +from pydrake.common import MemoryFile +from pydrake.geometry import ( + Box, + CollisionFilterDeclaration, + InMemoryMesh, + Mesh, + ProximityProperties, +) +from pydrake.math import RigidTransform as DrakeRigidTransform +import tf_lcm_py +import trimesh + +from dimos.perception.pointcloud.utils import ( + visualize_clustered_point_clouds, + visualize_pcd, + visualize_voxel_grid, +) +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +def create_point_cloud(color_img, depth_img, intrinsics): + """Create Open3D point cloud from RGB and depth images.""" + fx, fy, cx, cy = intrinsics + height, width = depth_img.shape + + o3d_intrinsics = o3d.camera.PinholeCameraIntrinsic(width, height, fx, fy, cx, cy) + color_o3d = o3d.geometry.Image(color_img) + depth_o3d = o3d.geometry.Image((depth_img * 1000).astype(np.uint16)) + + rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth( + color_o3d, depth_o3d, depth_scale=1000.0, convert_rgb_to_intensity=False + ) + + return o3d.geometry.PointCloud.create_from_rgbd_image(rgbd, o3d_intrinsics) + + +def deserialize_point_cloud(data): + """Reconstruct Open3D PointCloud from serialized data.""" + if data is None: + return None + + pcd = o3d.geometry.PointCloud() + if data.get("points"): + pcd.points = o3d.utility.Vector3dVector(np.array(data["points"])) + if data.get("colors"): + pcd.colors = o3d.utility.Vector3dVector(np.array(data["colors"])) + return pcd + + +def deserialize_voxel_grid(data): + """Reconstruct Open3D VoxelGrid from serialized data.""" + if data is None: + return None + + # Create a point cloud to convert to voxel grid + pcd = o3d.geometry.PointCloud() + voxel_size = data["voxel_size"] + origin = np.array(data["origin"]) + + # Create points from voxel indices + points = [] + colors = [] + for voxel in data["voxels"]: + # Each voxel is (i, j, k, r, g, b) + i, j, k, r, g, b = voxel + # Convert voxel grid index to 3D point + point = origin + np.array([i, j, k]) * voxel_size + points.append(point) + colors.append([r, g, b]) + + if points: + pcd.points = o3d.utility.Vector3dVector(np.array(points)) + pcd.colors = o3d.utility.Vector3dVector(np.array(colors)) + + # Convert to voxel grid + voxel_grid = o3d.geometry.VoxelGrid.create_from_point_cloud(pcd, voxel_size) + return voxel_grid + + +def visualize_results(pickle_path="manipulation_results.pkl"): + """Load pickled results and visualize them.""" + print(f"Loading results from {pickle_path}...") + try: + with open(pickle_path, "rb") as f: + data = pickle.load(f) + + results = data["results"] + data["color_img"] + data["depth_img"] + data["intrinsics"] + + print(f"Loaded results with keys: {list(results.keys())}") + + except FileNotFoundError: + print(f"Error: Pickle file {pickle_path} not found.") + print("Make sure to run test_manipulation_pipeline_single_frame_lcm.py first.") + return + except Exception as e: + print(f"Error loading pickle file: {e}") + return + + # Determine number of subplots based on what results we have + num_plots = 0 + plot_configs = [] + + if "detection_viz" in results and results["detection_viz"] is not None: + plot_configs.append(("detection_viz", "Object Detection")) + num_plots += 1 + + if "segmentation_viz" in results and results["segmentation_viz"] is not None: + plot_configs.append(("segmentation_viz", "Semantic Segmentation")) + num_plots += 1 + + if "pointcloud_viz" in results and results["pointcloud_viz"] is not None: + plot_configs.append(("pointcloud_viz", "All Objects Point Cloud")) + num_plots += 1 + + if "detected_pointcloud_viz" in results and results["detected_pointcloud_viz"] is not None: + plot_configs.append(("detected_pointcloud_viz", "Detection Objects Point Cloud")) + num_plots += 1 + + if "misc_pointcloud_viz" in results and results["misc_pointcloud_viz"] is not None: + plot_configs.append(("misc_pointcloud_viz", "Misc/Background Points")) + num_plots += 1 + + if "grasp_overlay" in results and results["grasp_overlay"] is not None: + plot_configs.append(("grasp_overlay", "Grasp Overlay")) + num_plots += 1 + + if num_plots == 0: + print("No visualization results to display") + return + + # Create subplot layout + if num_plots <= 3: + fig, axes = plt.subplots(1, num_plots, figsize=(6 * num_plots, 5)) + else: + rows = 2 + cols = (num_plots + 1) // 2 + _fig, axes = plt.subplots(rows, cols, figsize=(6 * cols, 5 * rows)) + + # Ensure axes is always a list for consistent indexing + if num_plots == 1: + axes = [axes] + elif num_plots > 2: + axes = axes.flatten() + + # Plot each result + for i, (key, title) in enumerate(plot_configs): + axes[i].imshow(results[key]) + axes[i].set_title(title) + axes[i].axis("off") + + # Hide unused subplots if any + if num_plots > 3: + for i in range(num_plots, len(axes)): + axes[i].axis("off") + + plt.tight_layout() + + # Save and show the plot + output_path = "visualization_results.png" + plt.savefig(output_path, dpi=150, bbox_inches="tight") + print(f"Results visualization saved to: {output_path}") + + # Show plot live as well + plt.show(block=True) + plt.close() + + # Deserialize and reconstruct 3D objects from the pickle file + print("\nReconstructing 3D visualization objects from serialized data...") + + # Reconstruct full point cloud if available + full_pcd = None + if "full_pointcloud" in results and results["full_pointcloud"] is not None: + full_pcd = deserialize_point_cloud(results["full_pointcloud"]) + print(f"Reconstructed full point cloud with {len(np.asarray(full_pcd.points))} points") + + # Visualize reconstructed full point cloud + try: + visualize_pcd( + full_pcd, + window_name="Reconstructed Full Scene Point Cloud", + point_size=2.0, + show_coordinate_frame=True, + ) + except (KeyboardInterrupt, EOFError): + print("\nSkipping full point cloud visualization") + except Exception as e: + print(f"Error in point cloud visualization: {e}") + else: + print("No full point cloud available for visualization") + + # Reconstruct misc clusters if available + if results.get("misc_clusters"): + misc_clusters = [deserialize_point_cloud(cluster) for cluster in results["misc_clusters"]] + cluster_count = len(misc_clusters) + total_misc_points = sum(len(np.asarray(cluster.points)) for cluster in misc_clusters) + print(f"Reconstructed {cluster_count} misc clusters with {total_misc_points} total points") + + # Visualize reconstructed misc clusters + try: + visualize_clustered_point_clouds( + misc_clusters, + window_name="Reconstructed Misc/Background Clusters (DBSCAN)", + point_size=3.0, + show_coordinate_frame=True, + ) + except (KeyboardInterrupt, EOFError): + print("\nSkipping misc clusters visualization") + except Exception as e: + print(f"Error in misc clusters visualization: {e}") + else: + print("No misc clusters available for visualization") + + # Reconstruct voxel grid if available + if "misc_voxel_grid" in results and results["misc_voxel_grid"] is not None: + misc_voxel_grid = deserialize_voxel_grid(results["misc_voxel_grid"]) + if misc_voxel_grid: + voxel_count = len(misc_voxel_grid.get_voxels()) + print(f"Reconstructed voxel grid with {voxel_count} voxels") + + # Visualize reconstructed voxel grid + try: + visualize_voxel_grid( + misc_voxel_grid, + window_name="Reconstructed Misc/Background Voxel Grid", + show_coordinate_frame=True, + ) + except (KeyboardInterrupt, EOFError): + print("\nSkipping voxel grid visualization") + except Exception as e: + print(f"Error in voxel grid visualization: {e}") + else: + print("Failed to reconstruct voxel grid") + else: + print("No voxel grid available for visualization") + + +class DrakeKinematicsEnv: + def __init__( + self, + urdf_path: str, + kinematic_chain_joints: list[str], + links_to_ignore: list[str] | None = None, + ): + self._resources_to_cleanup = [] + + # Register cleanup at exit + atexit.register(self.cleanup_resources) + + # Initialize tf resources once and reuse them + self.buffer = tf_lcm_py.Buffer(30.0) + self._resources_to_cleanup.append(self.buffer) + with self.safe_lcm_instance() as lcm_instance: + self.tf_lcm_instance = lcm_instance + self._resources_to_cleanup.append(self.tf_lcm_instance) + # Create TransformListener with our LCM instance and buffer + self.listener = tf_lcm_py.TransformListener(self.tf_lcm_instance, self.buffer) + self._resources_to_cleanup.append(self.listener) + + # Check if URDF file exists + if not os.path.exists(urdf_path): + raise FileNotFoundError(f"URDF file not found: {urdf_path}") + + # Drake utils initialization + self.meshcat = StartMeshcat() + print(f"Meshcat started at: {self.meshcat.web_url()}") + + self.urdf_path = urdf_path + self.builder = DiagramBuilder() + + self.plant, self.scene_graph = AddMultibodyPlantSceneGraph(self.builder, time_step=0.01) + self.parser = Parser(self.plant) + + # Load the robot URDF + print(f"Loading URDF from: {self.urdf_path}") + self.model_instances = self.parser.AddModelsFromUrl(f"file://{self.urdf_path}") + self.kinematic_chain_joints = kinematic_chain_joints + self.model_instance = self.model_instances[0] if self.model_instances else None + + if not self.model_instances: + raise RuntimeError("Failed to load any model instances from URDF") + + print(f"Loaded {len(self.model_instances)} model instances") + + # Set up collision filtering + if links_to_ignore: + bodies = [] + for link_name in links_to_ignore: + try: + body = self.plant.GetBodyByName(link_name) + if body is not None: + bodies.extend(self.plant.GetBodiesWeldedTo(body)) + except RuntimeError: + print(f"Warning: Link '{link_name}' not found in URDF") + + if bodies: + arm_geoms = self.plant.CollectRegisteredGeometries(bodies) + decl = CollisionFilterDeclaration().ExcludeWithin(arm_geoms) + manager = self.scene_graph.collision_filter_manager() + manager.Apply(decl) + + # Load and process point cloud data + self._load_and_process_point_clouds() + + # Finalize the plant before adding visualizer + self.plant.Finalize() + + # Print some debug info about the plant + print(f"Plant has {self.plant.num_bodies()} bodies") + print(f"Plant has {self.plant.num_joints()} joints") + for i in range(self.plant.num_joints()): + joint = self.plant.get_joint(JointIndex(i)) + print(f" Joint {i}: {joint.name()} (type: {joint.type_name()})") + + # Add visualizer + self.visualizer = MeshcatVisualizer.AddToBuilder( + self.builder, self.scene_graph, self.meshcat, params=MeshcatVisualizerParams() + ) + + # Build the diagram + self.diagram = self.builder.Build() + self.diagram_context = self.diagram.CreateDefaultContext() + self.plant_context = self.plant.GetMyContextFromRoot(self.diagram_context) + + # Set up joint indices + self.joint_indices = [] + for joint_name in self.kinematic_chain_joints: + try: + joint = self.plant.GetJointByName(joint_name) + if joint.num_positions() > 0: + start_index = joint.position_start() + for i in range(joint.num_positions()): + self.joint_indices.append(start_index + i) + print( + f"Added joint '{joint_name}' at indices {start_index} to {start_index + joint.num_positions() - 1}" + ) + except RuntimeError: + print(f"Warning: Joint '{joint_name}' not found in URDF.") + + # Get important frames/bodies + try: + self.end_effector_link = self.plant.GetBodyByName("link6") + self.end_effector_frame = self.plant.GetFrameByName("link6") + print("Found end effector link6") + except RuntimeError: + print("Warning: link6 not found") + self.end_effector_link = None + self.end_effector_frame = None + + try: + self.camera_link = self.plant.GetBodyByName("camera_center_link") + print("Found camera_center_link") + except RuntimeError: + print("Warning: camera_center_link not found") + self.camera_link = None + + # Set robot to a reasonable initial configuration + self._set_initial_configuration() + + # Force initial visualization update + self._update_visualization() + + print("Drake environment initialization complete!") + print(f"Visit {self.meshcat.web_url()} to see the visualization") + + def _load_and_process_point_clouds(self): + """Load point cloud data from pickle file and add to scene""" + pickle_path = "manipulation_results.pkl" + try: + with open(pickle_path, "rb") as f: + data = pickle.load(f) + + results = data["results"] + print(f"Loaded results with keys: {list(results.keys())}") + + except FileNotFoundError: + print(f"Warning: Pickle file {pickle_path} not found.") + print("Skipping point cloud loading.") + return + except Exception as e: + print(f"Warning: Error loading pickle file: {e}") + return + + full_detected_pcd = o3d.geometry.PointCloud() + for obj in results["detected_objects"]: + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(obj["point_cloud_numpy"]) + full_detected_pcd += pcd + + self.process_and_add_object_class("all_objects", results) + self.process_and_add_object_class("misc_clusters", results) + misc_clusters = results["misc_clusters"] + print(type(misc_clusters[0]["points"])) + print(np.asarray(misc_clusters[0]["points"]).shape) + + def process_and_add_object_class(self, object_key: str, results: dict): + # Process detected objects + if object_key in results: + detected_objects = results[object_key] + if detected_objects: + print(f"Processing {len(detected_objects)} {object_key}") + all_decomposed_meshes = [] + + transform = self.get_transform("world", "camera_center_link") + for i in range(len(detected_objects)): + try: + if object_key == "misc_clusters": + points = np.asarray(detected_objects[i]["points"]) + elif "point_cloud_numpy" in detected_objects[i]: + points = detected_objects[i]["point_cloud_numpy"] + elif ( + "point_cloud" in detected_objects[i] + and detected_objects[i]["point_cloud"] + ): + # Handle serialized point cloud + points = np.array(detected_objects[i]["point_cloud"]["points"]) + else: + print(f"Warning: No point cloud data found for object {i}") + continue + + if len(points) < 10: # Need more points for mesh reconstruction + print( + f"Warning: Object {i} has too few points ({len(points)}) for mesh reconstruction" + ) + continue + + # Swap y-z axes since this is a common problem + points = np.column_stack((points[:, 0], points[:, 2], -points[:, 1])) + # Transform points to world frame + points = self.transform_point_cloud_with_open3d(points, transform) + + # Use fast DBSCAN clustering + convex hulls approach + clustered_hulls = self._create_clustered_convex_hulls(points, i) + all_decomposed_meshes.extend(clustered_hulls) + + print( + f"Created {len(clustered_hulls)} clustered convex hulls for object {i}" + ) + + except Exception as e: + print(f"Warning: Failed to process object {i}: {e}") + + if all_decomposed_meshes: + self.register_convex_hulls_as_collision(all_decomposed_meshes, object_key) + print(f"Registered {len(all_decomposed_meshes)} total clustered convex hulls") + else: + print("Warning: No valid clustered convex hulls created from detected objects") + else: + print("No detected objects found") + + def _create_clustered_convex_hulls( + self, points: np.ndarray, object_id: int + ) -> list[o3d.geometry.TriangleMesh]: + """ + Create convex hulls from DBSCAN clusters of point cloud data. + Fast approach: cluster points, then convex hull each cluster. + + Args: + points: Nx3 numpy array of 3D points + object_id: ID for debugging/logging + + Returns: + List of Open3D triangle meshes (convex hulls of clusters) + """ + try: + # Create Open3D point cloud + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(points) + + # Quick outlier removal (optional, can skip for speed) + if len(points) > 50: # Only for larger point clouds + pcd, _ = pcd.remove_statistical_outlier(nb_neighbors=10, std_ratio=2.0) + points = np.asarray(pcd.points) + + if len(points) < 4: + print(f"Warning: Too few points after filtering for object {object_id}") + return [] + + # Try multiple DBSCAN parameter combinations to find clusters + clusters = [] + labels = None + + # Calculate some basic statistics for parameter estimation + if len(points) > 10: + # Compute nearest neighbor distances for better eps estimation + distances = pcd.compute_nearest_neighbor_distance() + avg_nn_distance = np.mean(distances) + np.std(distances) + + print( + f"Object {object_id}: {len(points)} points, avg_nn_dist={avg_nn_distance:.4f}" + ) + + for i in range(20): + try: + eps = avg_nn_distance * (2.0 + (i * 0.1)) + min_samples = 20 + labels = np.array(pcd.cluster_dbscan(eps=eps, min_points=min_samples)) + unique_labels = np.unique(labels) + clusters = unique_labels[unique_labels >= 0] # Remove noise label (-1) + + noise_points = np.sum(labels == -1) + clustered_points = len(points) - noise_points + + print( + f" Try {i + 1}: eps={eps:.4f}, min_samples={min_samples} → {len(clusters)} clusters, {clustered_points}/{len(points)} points clustered" + ) + + # Accept if we found clusters and most points are clustered + if ( + len(clusters) > 0 and clustered_points >= len(points) * 0.95 + ): # At least 30% of points clustered + print(f" ✓ Accepted parameter set {i + 1}") + break + + except Exception as e: + print( + f" Try {i + 1}: Failed with eps={eps:.4f}, min_samples={min_samples}: {e}" + ) + continue + + if len(clusters) == 0 or labels is None: + print( + f"No clusters found for object {object_id} after all attempts, using entire point cloud" + ) + # Fallback: use entire point cloud as single convex hull + hull_mesh, _ = pcd.compute_convex_hull() + hull_mesh.compute_vertex_normals() + return [hull_mesh] + + print( + f"Found {len(clusters)} clusters for object {object_id} (eps={eps:.3f}, min_samples={min_samples})" + ) + + # Create convex hull for each cluster + convex_hulls = [] + for cluster_id in clusters: + try: + # Get points for this cluster + cluster_mask = labels == cluster_id + cluster_points = points[cluster_mask] + + if len(cluster_points) < 4: + print( + f"Skipping cluster {cluster_id} with only {len(cluster_points)} points" + ) + continue + + # Create point cloud for this cluster + cluster_pcd = o3d.geometry.PointCloud() + cluster_pcd.points = o3d.utility.Vector3dVector(cluster_points) + + # Compute convex hull + hull_mesh, _ = cluster_pcd.compute_convex_hull() + hull_mesh.compute_vertex_normals() + + # Validate hull + if ( + len(np.asarray(hull_mesh.vertices)) >= 4 + and len(np.asarray(hull_mesh.triangles)) >= 4 + ): + convex_hulls.append(hull_mesh) + print( + f" Cluster {cluster_id}: {len(cluster_points)} points → convex hull with {len(np.asarray(hull_mesh.vertices))} vertices" + ) + else: + print(f" Skipping degenerate hull for cluster {cluster_id}") + + except Exception as e: + print(f"Error processing cluster {cluster_id} for object {object_id}: {e}") + + if not convex_hulls: + print( + f"No valid convex hulls created for object {object_id}, using entire point cloud" + ) + # Fallback: use entire point cloud as single convex hull + hull_mesh, _ = pcd.compute_convex_hull() + hull_mesh.compute_vertex_normals() + return [hull_mesh] + + return convex_hulls + + except Exception as e: + print(f"Error in DBSCAN clustering for object {object_id}: {e}") + # Final fallback: single convex hull + try: + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(points) + hull_mesh, _ = pcd.compute_convex_hull() + hull_mesh.compute_vertex_normals() + return [hull_mesh] + except: + return [] + + def _set_initial_configuration(self): + """Set the robot to a reasonable initial joint configuration""" + # Set all joints to zero initially + if self.joint_indices: + q = np.zeros(len(self.joint_indices)) + + # You can customize these values for a better initial pose + # For example, if you know good default joint angles: + if len(q) >= 6: # Assuming at least 6 DOF arm + q[1] = 0.0 # joint1 + q[2] = 0.0 # joint2 + q[3] = 0.0 # joint3 + q[4] = 0.0 # joint4 + q[5] = 0.0 # joint5 + q[6] = 0.0 # joint6 + + # Set the joint positions in the plant context + positions = self.plant.GetPositions(self.plant_context) + for i, joint_idx in enumerate(self.joint_indices): + if joint_idx < len(positions): + positions[joint_idx] = q[i] + + self.plant.SetPositions(self.plant_context, positions) + print(f"Set initial joint configuration: {q}") + else: + print("Warning: No joint indices found, using default configuration") + + def _update_visualization(self): + """Force update the visualization""" + try: + # Get the visualizer's context from the diagram context + visualizer_context = self.visualizer.GetMyContextFromRoot(self.diagram_context) + self.visualizer.ForcedPublish(visualizer_context) + print("Visualization updated successfully") + except Exception as e: + print(f"Error updating visualization: {e}") + + def set_joint_positions(self, joint_positions): + """Set specific joint positions and update visualization""" + if len(joint_positions) != len(self.joint_indices): + raise ValueError( + f"Expected {len(self.joint_indices)} joint positions, got {len(joint_positions)}" + ) + + positions = self.plant.GetPositions(self.plant_context) + for i, joint_idx in enumerate(self.joint_indices): + if joint_idx < len(positions): + positions[joint_idx] = joint_positions[i] + + self.plant.SetPositions(self.plant_context, positions) + self._update_visualization() + print(f"Updated joint positions: {joint_positions}") + + def register_convex_hulls_as_collision( + self, meshes: list[o3d.geometry.TriangleMesh], hull_type: str + ): + """Register convex hulls as collision and visual geometry""" + if not meshes: + print("No meshes to register") + return + + world = self.plant.world_body() + proximity = ProximityProperties() + + for i, mesh in enumerate(meshes): + try: + # Convert Open3D → numpy arrays → trimesh.Trimesh + vertices = np.asarray(mesh.vertices) + faces = np.asarray(mesh.triangles) + + if len(vertices) == 0 or len(faces) == 0: + print(f"Warning: Mesh {i} is empty, skipping") + continue + + tmesh = trimesh.Trimesh(vertices=vertices, faces=faces) + + # Export to OBJ in memory + tmesh_obj_blob = tmesh.export(file_type="obj") + mem_file = MemoryFile( + contents=tmesh_obj_blob, extension=".obj", filename_hint=f"convex_hull_{i}.obj" + ) + in_memory_mesh = InMemoryMesh() + in_memory_mesh.mesh_file = mem_file + drake_mesh = Mesh(in_memory_mesh, scale=1.0) + + pos = np.array([0.0, 0.0, 0.0]) + rpy = RollPitchYaw(0.0, 0.0, 0.0) + X_WG = DrakeRigidTransform(RotationMatrix(rpy), pos) + + # Register collision and visual geometry + self.plant.RegisterCollisionGeometry( + body=world, + X_BG=X_WG, + shape=drake_mesh, + name=f"convex_hull_collision_{i}_{hull_type}", + properties=proximity, + ) + self.plant.RegisterVisualGeometry( + body=world, + X_BG=X_WG, + shape=drake_mesh, + name=f"convex_hull_visual_{i}_{hull_type}", + diffuse_color=np.array([0.7, 0.5, 0.3, 0.8]), # Orange-ish color + ) + + print( + f"Registered convex hull {i} with {len(vertices)} vertices and {len(faces)} faces" + ) + + except Exception as e: + print(f"Warning: Failed to register mesh {i}: {e}") + + # Add a simple table for reference + try: + table_shape = Box(1.0, 1.0, 0.1) # Thinner table + table_pose = RigidTransform(p=[0.5, 0.0, -0.05]) # In front of robot + self.plant.RegisterCollisionGeometry( + world, table_pose, table_shape, "table_collision", proximity + ) + self.plant.RegisterVisualGeometry( + world, table_pose, table_shape, "table_visual", [0.8, 0.6, 0.4, 1.0] + ) + print("Added reference table") + except Exception as e: + print(f"Warning: Failed to add table: {e}") + + def get_seeded_random_rgba(self, id: int): + np.random.seed(id) + return np.random.rand(4) + + @contextmanager + def safe_lcm_instance(self): + """Context manager for safely managing LCM instance lifecycle""" + lcm_instance = tf_lcm_py.LCM() + try: + yield lcm_instance + finally: + pass + + def cleanup_resources(self): + """Clean up resources before exiting""" + # Only clean up once when exiting + print("Cleaning up resources...") + # Force cleanup of resources in reverse order (last created first) + for resource in reversed(self._resources_to_cleanup): + try: + # For objects like TransformListener that might have a close or shutdown method + if hasattr(resource, "close"): + resource.close() + elif hasattr(resource, "shutdown"): + resource.shutdown() + + # Explicitly delete the resource + del resource + except Exception as e: + print(f"Error during cleanup: {e}") + + # Clear the resources list + self._resources_to_cleanup = [] + + def get_transform(self, target_frame, source_frame): + print("Getting transform from", source_frame, "to", target_frame) + attempts = 0 + max_attempts = 20 # Reduced from 120 to avoid long blocking + + while attempts < max_attempts: + try: + # Process LCM messages with error handling + if not self.tf_lcm_instance.handle_timeout(100): # 100ms timeout + # If handle_timeout returns false, we might need to re-check if LCM is still good + if not self.tf_lcm_instance.good(): + print("WARNING: LCM instance is no longer in a good state") + + # Get the most recent timestamp from the buffer instead of using current time + try: + timestamp = self.buffer.get_most_recent_timestamp() + if attempts % 10 == 0: + print(f"Using timestamp from buffer: {timestamp}") + except Exception: + # Fall back to current time if get_most_recent_timestamp fails + timestamp = datetime.now() + if not hasattr(timestamp, "timestamp"): + timestamp.timestamp = ( + lambda: time.mktime(timestamp.timetuple()) + timestamp.microsecond / 1e6 + ) + if attempts % 10 == 0: + print(f"Falling back to current time: {timestamp}") + + # Check if we can find the transform + if self.buffer.can_transform(target_frame, source_frame, timestamp): + # print(f"Found transform between '{target_frame}' and '{source_frame}'!") + + # Look up the transform with the timestamp from the buffer + transform = self.buffer.lookup_transform( + target_frame, + source_frame, + timestamp, + timeout=10.0, + time_tolerance=0.1, + lcm_module=lcm_msgs, + ) + + return transform + + # Increment counter and report status every 10 attempts + attempts += 1 + if attempts % 10 == 0: + print(f"Still waiting... (attempt {attempts}/{max_attempts})") + frames = self.buffer.get_all_frame_names() + if frames: + print(f"Frames received so far ({len(frames)} total):") + for frame in sorted(frames): + print(f" {frame}") + else: + print("No frames received yet") + + # Brief pause + time.sleep(0.5) + + except Exception as e: + print(f"Error during transform lookup: {e}") + attempts += 1 + time.sleep(1) # Longer pause after an error + + print(f"\nERROR: No transform found after {max_attempts} attempts") + return None + + def transform_point_cloud_with_open3d(self, points_np: np.ndarray, transform) -> np.ndarray: + """ + Transforms a point cloud using Open3D given a transform. + + Args: + points_np (np.ndarray): Nx3 array of 3D points. + transform: Transform from tf_lcm_py. + + Returns: + np.ndarray: Nx3 array of transformed 3D points. + """ + if points_np.shape[1] != 3: + print("Input point cloud must have shape Nx3.") + return points_np + + # Convert transform to 4x4 numpy matrix + tf_matrix = np.eye(4) + + # Extract rotation quaternion components + qw = transform.transform.rotation.w + qx = transform.transform.rotation.x + qy = transform.transform.rotation.y + qz = transform.transform.rotation.z + + # Convert quaternion to rotation matrix + # Formula from: https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation#Quaternion-derived_rotation_matrix + tf_matrix[0, 0] = 1 - 2 * qy * qy - 2 * qz * qz + tf_matrix[0, 1] = 2 * qx * qy - 2 * qz * qw + tf_matrix[0, 2] = 2 * qx * qz + 2 * qy * qw + + tf_matrix[1, 0] = 2 * qx * qy + 2 * qz * qw + tf_matrix[1, 1] = 1 - 2 * qx * qx - 2 * qz * qz + tf_matrix[1, 2] = 2 * qy * qz - 2 * qx * qw + + tf_matrix[2, 0] = 2 * qx * qz - 2 * qy * qw + tf_matrix[2, 1] = 2 * qy * qz + 2 * qx * qw + tf_matrix[2, 2] = 1 - 2 * qx * qx - 2 * qy * qy + + # Set translation + tf_matrix[0, 3] = transform.transform.translation.x + tf_matrix[1, 3] = transform.transform.translation.y + tf_matrix[2, 3] = transform.transform.translation.z + + # Create Open3D point cloud + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(points_np) + + # Apply transformation + pcd.transform(tf_matrix) + + # Return as NumPy array + return np.asarray(pcd.points) + + +# Updated main function +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Visualize manipulation results") + parser.add_argument("--visualize-only", action="store_true", help="Only visualize results") + args = parser.parse_args() + + if args.visualize_only: + visualize_results() + exit(0) + + try: + # Then set up Drake environment + kinematic_chain_joints = [ + "pillar_platform_joint", + "joint1", + "joint2", + "joint3", + "joint4", + "joint5", + "joint6", + ] + + links_to_ignore = [ + "devkit_base_link", + "pillar_platform", + "piper_angled_mount", + "pan_tilt_base", + "pan_tilt_head", + "pan_tilt_pan", + "base_link", + "link1", + "link2", + "link3", + "link4", + "link5", + "link6", + ] + + urdf_path = "./assets/devkit_base_descr.urdf" + urdf_path = os.path.abspath(urdf_path) + + print(f"Attempting to load URDF from: {urdf_path}") + + env = DrakeKinematicsEnv(urdf_path, kinematic_chain_joints, links_to_ignore) + env.set_joint_positions([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) + transform = env.get_transform("world", "camera_center_link") + print( + transform.transform.translation.x, + transform.transform.translation.y, + transform.transform.translation.z, + ) + print( + transform.transform.rotation.w, + transform.transform.rotation.x, + transform.transform.rotation.y, + transform.transform.rotation.z, + ) + + # Keep the visualization alive + print("\nVisualization is running. Press Ctrl+C to exit.") + while True: + time.sleep(1) + + except KeyboardInterrupt: + print("\nExiting...") + except Exception as e: + print(f"Error: {e}") + import traceback + + traceback.print_exc() diff --git a/tests/zed_neural_depth_demo.py b/tests/zed_neural_depth_demo.py new file mode 100755 index 0000000000..8b5fdb5564 --- /dev/null +++ b/tests/zed_neural_depth_demo.py @@ -0,0 +1,450 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +ZED Camera Neural Depth Demo - OpenCV Live Visualization with Data Saving + +This script demonstrates live visualization of ZED camera RGB and depth data using OpenCV. +Press SPACE to save RGB and depth images to rgbd_data2 folder. +Press ESC or 'q' to quit. +""" + +import argparse +from datetime import datetime +import logging +from pathlib import Path +import sys +import time + +import cv2 +import numpy as np +import open3d as o3d +import yaml + +# Add the project root to Python path +sys.path.append(str(Path(__file__).parent.parent)) + +try: + import pyzed.sl as sl +except ImportError: + print("ERROR: ZED SDK not found. Please install the ZED SDK and pyzed Python package.") + print("Download from: https://www.stereolabs.com/developers/release/") + sys.exit(1) + +from dimos.hardware.zed_camera import ZEDCamera +from dimos.perception.pointcloud.utils import visualize_pcd + +# Configure logging +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + + +class ZEDLiveVisualizer: + """Live OpenCV visualization for ZED camera data with saving functionality.""" + + def __init__(self, camera, max_depth=10.0, output_dir="assets/rgbd_data2"): + self.camera = camera + self.max_depth = max_depth + self.output_dir = Path(output_dir) + self.save_counter = 0 + + # Store captured pointclouds for later visualization + self.captured_pointclouds = [] + + # Display settings for 480p + self.display_width = 640 + self.display_height = 480 + + # Create output directory structure + self.setup_output_directory() + + # Get camera info for saving + self.camera_info = camera.get_camera_info() + + # Save camera info files once + self.save_camera_info() + + # OpenCV window name (single window) + self.window_name = "ZED Camera - RGB + Depth" + + # Create window + cv2.namedWindow(self.window_name, cv2.WINDOW_AUTOSIZE) + + def setup_output_directory(self): + """Create the output directory structure.""" + self.output_dir.mkdir(exist_ok=True) + (self.output_dir / "color").mkdir(exist_ok=True) + (self.output_dir / "depth").mkdir(exist_ok=True) + (self.output_dir / "pointclouds").mkdir(exist_ok=True) + logger.info(f"Created output directory: {self.output_dir}") + + def save_camera_info(self): + """Save camera info YAML files with ZED camera parameters.""" + # Get current timestamp + now = datetime.now() + timestamp_sec = int(now.timestamp()) + timestamp_nanosec = int((now.timestamp() % 1) * 1e9) + + # Get camera resolution + resolution = self.camera_info.get("resolution", {}) + width = int(resolution.get("width", 1280)) + height = int(resolution.get("height", 720)) + + # Extract left camera parameters (for RGB) from already available camera_info + left_cam = self.camera_info.get("left_cam", {}) + # Convert numpy values to Python floats + fx = float(left_cam.get("fx", 749.341552734375)) + fy = float(left_cam.get("fy", 748.5587768554688)) + cx = float(left_cam.get("cx", 639.4312744140625)) + cy = float(left_cam.get("cy", 357.2478942871094)) + + # Build distortion coefficients from ZED format + # ZED provides k1, k2, p1, p2, k3 - convert to rational_polynomial format + k1 = float(left_cam.get("k1", 0.0)) + k2 = float(left_cam.get("k2", 0.0)) + p1 = float(left_cam.get("p1", 0.0)) + p2 = float(left_cam.get("p2", 0.0)) + k3 = float(left_cam.get("k3", 0.0)) + distortion = [k1, k2, p1, p2, k3, 0.0, 0.0, 0.0] + + # Create camera info structure with plain Python types + camera_info = { + "D": distortion, + "K": [fx, 0.0, cx, 0.0, fy, cy, 0.0, 0.0, 1.0], + "P": [fx, 0.0, cx, 0.0, 0.0, fy, cy, 0.0, 0.0, 0.0, 1.0, 0.0], + "R": [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0], + "binning_x": 0, + "binning_y": 0, + "distortion_model": "rational_polynomial", + "header": { + "frame_id": "camera_color_optical_frame", + "stamp": {"nanosec": timestamp_nanosec, "sec": timestamp_sec}, + }, + "height": height, + "roi": {"do_rectify": False, "height": 0, "width": 0, "x_offset": 0, "y_offset": 0}, + "width": width, + } + + # Save color camera info + color_info_path = self.output_dir / "color_camera_info.yaml" + with open(color_info_path, "w") as f: + yaml.dump(camera_info, f, default_flow_style=False) + + # Save depth camera info (same as color for ZED) + depth_info_path = self.output_dir / "depth_camera_info.yaml" + with open(depth_info_path, "w") as f: + yaml.dump(camera_info, f, default_flow_style=False) + + logger.info(f"Saved camera info files to {self.output_dir}") + + def normalize_depth_for_display(self, depth_map): + """Normalize depth map for OpenCV visualization.""" + # Handle invalid values + valid_mask = (depth_map > 0) & np.isfinite(depth_map) + + if not np.any(valid_mask): + return np.zeros_like(depth_map, dtype=np.uint8) + + # Normalize to 0-255 for display + depth_norm = np.zeros_like(depth_map, dtype=np.float32) + depth_clipped = np.clip(depth_map[valid_mask], 0, self.max_depth) + depth_norm[valid_mask] = depth_clipped / self.max_depth + + # Convert to 8-bit and apply colormap + depth_8bit = (depth_norm * 255).astype(np.uint8) + depth_colored = cv2.applyColorMap(depth_8bit, cv2.COLORMAP_JET) + + return depth_colored + + def save_frame(self, rgb_img, depth_map): + """Save RGB, depth images, and pointcloud with proper naming convention.""" + # Generate filename with 5-digit zero-padding + filename = f"{self.save_counter:05d}.png" + pcd_filename = f"{self.save_counter:05d}.ply" + + # Save RGB image + rgb_path = self.output_dir / "color" / filename + cv2.imwrite(str(rgb_path), rgb_img) + + # Save depth image (convert to 16-bit for proper depth storage) + depth_path = self.output_dir / "depth" / filename + # Convert meters to millimeters and save as 16-bit + depth_mm = (depth_map * 1000).astype(np.uint16) + cv2.imwrite(str(depth_path), depth_mm) + + # Capture and save pointcloud + pcd = self.camera.capture_pointcloud() + if pcd is not None and len(np.asarray(pcd.points)) > 0: + pcd_path = self.output_dir / "pointclouds" / pcd_filename + o3d.io.write_point_cloud(str(pcd_path), pcd) + + # Store pointcloud for later visualization + self.captured_pointclouds.append(pcd) + + logger.info( + f"Saved frame {self.save_counter}: {rgb_path}, {depth_path}, and {pcd_path}" + ) + else: + logger.warning(f"Failed to capture pointcloud for frame {self.save_counter}") + logger.info(f"Saved frame {self.save_counter}: {rgb_path} and {depth_path}") + + self.save_counter += 1 + + def visualize_captured_pointclouds(self): + """Visualize all captured pointclouds using Open3D, one by one.""" + if not self.captured_pointclouds: + logger.info("No pointclouds captured to visualize") + return + + logger.info( + f"Visualizing {len(self.captured_pointclouds)} captured pointclouds one by one..." + ) + logger.info("Close each pointcloud window to proceed to the next one") + + for i, pcd in enumerate(self.captured_pointclouds): + if len(np.asarray(pcd.points)) > 0: + logger.info(f"Displaying pointcloud {i + 1}/{len(self.captured_pointclouds)}") + visualize_pcd(pcd, window_name=f"ZED Pointcloud {i + 1:05d}", point_size=2.0) + else: + logger.warning(f"Pointcloud {i + 1} is empty, skipping...") + + logger.info("Finished displaying all pointclouds") + + def update_display(self): + """Update the live display with new frames.""" + # Capture frame + left_img, _right_img, depth_map = self.camera.capture_frame() + + if left_img is None or depth_map is None: + return False, None, None + + # Resize RGB to 480p + rgb_resized = cv2.resize(left_img, (self.display_width, self.display_height)) + + # Create depth visualization + depth_colored = self.normalize_depth_for_display(depth_map) + + # Resize depth to 480p + depth_resized = cv2.resize(depth_colored, (self.display_width, self.display_height)) + + # Add text overlays + text_color = (255, 255, 255) + font = cv2.FONT_HERSHEY_SIMPLEX + font_scale = 0.6 + thickness = 2 + + # Add title and instructions to RGB + cv2.putText( + rgb_resized, "RGB Camera Feed", (10, 25), font, font_scale, text_color, thickness + ) + cv2.putText( + rgb_resized, + "SPACE: Save | ESC/Q: Quit", + (10, 50), + font, + font_scale - 0.1, + text_color, + thickness, + ) + + # Add title and stats to depth + cv2.putText( + depth_resized, + f"Depth Map (0-{self.max_depth}m)", + (10, 25), + font, + font_scale, + text_color, + thickness, + ) + cv2.putText( + depth_resized, + f"Saved: {self.save_counter} frames", + (10, 50), + font, + font_scale - 0.1, + text_color, + thickness, + ) + + # Stack images horizontally + combined_display = np.hstack((rgb_resized, depth_resized)) + + # Display combined image + cv2.imshow(self.window_name, combined_display) + + return True, left_img, depth_map + + def handle_key_events(self, rgb_img, depth_map): + """Handle keyboard input.""" + key = cv2.waitKey(1) & 0xFF + + if key == ord(" "): # Space key - save frame + if rgb_img is not None and depth_map is not None: + self.save_frame(rgb_img, depth_map) + return "save" + elif key == 27 or key == ord("q"): # ESC or 'q' - quit + return "quit" + + return "continue" + + def cleanup(self): + """Clean up OpenCV windows.""" + cv2.destroyAllWindows() + + +def main(): + parser = argparse.ArgumentParser( + description="ZED Camera Neural Depth Demo - OpenCV with Data Saving" + ) + parser.add_argument("--camera-id", type=int, default=0, help="ZED camera ID (default: 0)") + parser.add_argument( + "--resolution", + type=str, + default="HD1080", + choices=["HD2K", "HD1080", "HD720", "VGA"], + help="Camera resolution (default: HD1080)", + ) + parser.add_argument( + "--max-depth", + type=float, + default=10.0, + help="Maximum depth for visualization in meters (default: 10.0)", + ) + parser.add_argument( + "--camera-fps", type=int, default=15, help="Camera capture FPS (default: 30)" + ) + parser.add_argument( + "--depth-mode", + type=str, + default="NEURAL", + choices=["NEURAL", "NEURAL_PLUS"], + help="Depth mode (NEURAL=faster, NEURAL_PLUS=more accurate)", + ) + parser.add_argument( + "--output-dir", + type=str, + default="assets/rgbd_data2", + help="Output directory for saved data (default: rgbd_data2)", + ) + + args = parser.parse_args() + + # Map resolution string to ZED enum + resolution_map = { + "HD2K": sl.RESOLUTION.HD2K, + "HD1080": sl.RESOLUTION.HD1080, + "HD720": sl.RESOLUTION.HD720, + "VGA": sl.RESOLUTION.VGA, + } + + depth_mode_map = {"NEURAL": sl.DEPTH_MODE.NEURAL, "NEURAL_PLUS": sl.DEPTH_MODE.NEURAL_PLUS} + + try: + # Initialize ZED camera with neural depth + logger.info( + f"Initializing ZED camera with {args.depth_mode} depth processing at {args.camera_fps} FPS..." + ) + camera = ZEDCamera( + camera_id=args.camera_id, + resolution=resolution_map[args.resolution], + depth_mode=depth_mode_map[args.depth_mode], + fps=args.camera_fps, + ) + + # Open camera + with camera: + # Get camera information + info = camera.get_camera_info() + logger.info(f"Camera Model: {info.get('model', 'Unknown')}") + logger.info(f"Serial Number: {info.get('serial_number', 'Unknown')}") + logger.info(f"Firmware: {info.get('firmware', 'Unknown')}") + logger.info(f"Resolution: {info.get('resolution', {})}") + logger.info(f"Baseline: {info.get('baseline', 0):.3f}m") + + # Initialize visualizer + visualizer = ZEDLiveVisualizer( + camera, max_depth=args.max_depth, output_dir=args.output_dir + ) + + logger.info("Starting live visualization...") + logger.info("Controls:") + logger.info(" SPACE - Save current RGB and depth frame") + logger.info(" ESC/Q - Quit") + + frame_count = 0 + start_time = time.time() + + try: + while True: + loop_start = time.time() + + # Update display + success, rgb_img, depth_map = visualizer.update_display() + + if success: + frame_count += 1 + + # Handle keyboard events + action = visualizer.handle_key_events(rgb_img, depth_map) + + if action == "quit": + break + elif action == "save": + # Frame was saved, no additional action needed + pass + + # Print performance stats every 60 frames + if frame_count % 60 == 0: + elapsed = time.time() - start_time + fps = frame_count / elapsed + logger.info( + f"Frame {frame_count} | FPS: {fps:.1f} | Saved: {visualizer.save_counter}" + ) + + # Small delay to prevent CPU overload + elapsed = time.time() - loop_start + min_frame_time = 1.0 / 60.0 # Cap at 60 FPS + if elapsed < min_frame_time: + time.sleep(min_frame_time - elapsed) + + except KeyboardInterrupt: + logger.info("Stopped by user") + + # Final stats + total_time = time.time() - start_time + if total_time > 0: + avg_fps = frame_count / total_time + logger.info( + f"Final stats: {frame_count} frames in {total_time:.1f}s (avg {avg_fps:.1f} FPS)" + ) + logger.info(f"Total saved frames: {visualizer.save_counter}") + + # Visualize captured pointclouds + visualizer.visualize_captured_pointclouds() + + except Exception as e: + logger.error(f"Error during execution: {e}") + raise + finally: + if "visualizer" in locals(): + visualizer.cleanup() + logger.info("Demo completed") + + +if __name__ == "__main__": + main()