Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
76a9041
fix rebase with master
fracape Feb 2, 2024
fdb884c
docs: add missing autodocs
YodaEmbedding Feb 1, 2024
07a7f0a
feat: compressai.typing.TTransform
YodaEmbedding Feb 1, 2024
af6396b
refactor: compressai.registry.transforms
YodaEmbedding Feb 1, 2024
5b2b8d0
feat: fast EntropyBottleneck aux_loss minimization via bisection search
YodaEmbedding Apr 27, 2023
246200d
feat: net optimizer
YodaEmbedding Jan 31, 2024
4d4192b
feat: basic layers (Lambda, Reshape, Transpose, Interleave, etc.)
YodaEmbedding Jan 31, 2024
5e69c5b
chore(deps): einops, pandas, torch-geometric, tqdm
YodaEmbedding Feb 1, 2024
6b0d638
chore(deps): pointops, pyntcloud [pointcloud]
YodaEmbedding Feb 1, 2024
09ed45c
chore(deps): pyntcloud use PR with *.off header fix
YodaEmbedding Feb 2, 2024
7f1f81c
feat: compressai.registry.transforms torch_geometric
YodaEmbedding Feb 1, 2024
f9d3708
feat: point cloud datasets (ModelNet, ShapeNet, S3DIS, SemanticKITTI)
YodaEmbedding Jan 31, 2024
b7c1f1e
feat: point cloud transforms
YodaEmbedding Feb 1, 2024
bd8a0d9
feat: point cloud losses
YodaEmbedding Jan 31, 2024
4ece458
feat: point cloud layers (pointnet, pointnet2, hrtzxf2022)
YodaEmbedding Jan 31, 2024
8ac44b9
feat: point cloud compression models
YodaEmbedding Jan 31, 2024
4ca480b
fix: point cloud datasets optional dependencies
YodaEmbedding Feb 1, 2024
16a4752
fix: point cloud layers optional dependencies
YodaEmbedding Feb 1, 2024
214050c
fix: point cloud losses optional dependencies
YodaEmbedding Feb 1, 2024
20838fc
feat: zoo.pointcloud_models [placeholder]
YodaEmbedding Feb 1, 2024
ec7c7fa
feat: examples/train_pointcloud.py
YodaEmbedding Feb 1, 2024
50ab5ea
chore(deps): python_requires>=3.7
YodaEmbedding Feb 2, 2024
115d91f
update github workflow and gitlab-ci
fracape Feb 3, 2024
2a6876e
update github workflow and gitlab-ci
fracape Feb 4, 2024
cb07223
update numpy version for point cloud deps (pandas)
fracape Feb 4, 2024
ce8b9cf
[chores] update github actions
fracape Feb 4, 2024
9c9b37b
refactor: rename transforms.point -> transforms.pointcloud
YodaEmbedding Feb 5, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,15 @@ jobs:
strategy:
matrix:
python-version:
- "3.6"
- "3.7"
- "3.8"
- "3.9"
include:
- os: "ubuntu-latest"
# no Python 3.6 in ubuntu>20.04.
- os: "ubuntu-20.04"
python-version: "3.6"
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install Python dependencies
Expand Down
16 changes: 8 additions & 8 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ jobs:
sdist:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- name: Set up Python 3.8
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: 3.8
- name: Cache pip
uses: actions/cache@v2
uses: actions/cache@v3
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }}
Expand All @@ -38,15 +38,15 @@ jobs:
runs-on: macos-latest
strategy:
matrix:
python-version: [3.6, 3.7, 3.8, 3.9]
python-version: [3.7, 3.8, 3.9]
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Cache pip
uses: actions/cache@v2
uses: actions/cache@v3
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }}
Expand All @@ -72,7 +72,7 @@ jobs:
matrix:
python-version: [cp36-cp36m, cp37-cp37m, cp38-cp38, cp39-cp39]
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- name: Install dependencies
run: /opt/python/${{ matrix.python-version }}/bin/python -m pip install build twine
- name: Build wheel
Expand Down
8 changes: 2 additions & 6 deletions .github/workflows/static-analysis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,15 @@ jobs:
strategy:
matrix:
python-version:
- "3.6"
- "3.7"
- "3.8"
- "3.9"
include:
- os: "ubuntu-latest"
# no Python 3.6 in ubuntu>20.04.
- os: "ubuntu-20.04"
python-version: "3.6"
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install Python dependencies
Expand Down
11 changes: 5 additions & 6 deletions .gitlab-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ wheel:
expire_in: 1 day
parallel:
matrix:
- PYTHON_VERSION: ['3.6', '3.7', '3.8', '3.9']
- PYTHON_VERSION: ['3.7', '3.8', '3.9']
tags:
- docker

sdist:
image: python:3.6-buster
image: python:3.7-buster
stage: build
before_script:
- pip install build
Expand All @@ -51,7 +51,7 @@ flake8:

black:
stage: static-analysis
image: python:3.6-buster
image: python:3.7-buster
before_script:
- python --version
- pip install black
Expand All @@ -62,7 +62,7 @@ black:

isort:
stage: static-analysis
image: python:3.6-buster
image: python:3.7-buster
before_script:
- python --version
- pip install .
Expand All @@ -79,7 +79,6 @@ test:
before_script:
- python --version
- pip install -e .
- pip install -r requirements.txt
- pip install pytest pytest-cov plotly
script:
- >
Expand All @@ -93,7 +92,7 @@ test:
- PYTORCH_IMAGE:
- "1.9.0-cuda11.1-cudnn8-devel"
- "1.8.1-cuda11.1-cudnn8-devel"
- "1.7.1-cuda11.0-cudnn8-devel"
# - "1.7.1-cuda11.0-cudnn8-devel"
tags:
- docker

Expand Down
2 changes: 1 addition & 1 deletion Readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ CompressAI currently provides:

## Installation

CompressAI supports python 3.6+ and PyTorch 1.7+.
CompressAI supports python 3.7+ and PyTorch 1.7+.

**pip**:

Expand Down
3 changes: 3 additions & 0 deletions compressai/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,16 @@
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from . import pointcloud
from .image import ImageFolder
from .pointcloud import *
from .pregenerated import PreGeneratedMemmapDataset
from .rawvideo import *
from .video import VideoFolder
from .vimeo90k import Vimeo90kDataset

__all__ = [
*pointcloud.__all__,
"ImageFolder",
"PreGeneratedMemmapDataset",
"VideoFolder",
Expand Down
126 changes: 126 additions & 0 deletions compressai/datasets/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# Copyright (c) 2021-2022, InterDigital Communications, Inc
# All rights reserved.

# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:

# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of InterDigital Communications, Inc nor the names of its
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.

# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import json
import os
import os.path

from pathlib import Path

import numpy as np

from torch.utils.data import Dataset
from tqdm import tqdm


class CacheDataset(Dataset):
def __init__(
self,
cache_root=None,
pre_transform=None,
transform=None,
):
self.__cache_root = Path(cache_root)
self.pre_transform = pre_transform
self.transform = transform
self._store = {}

def __len__(self):
return len(self._store[next(iter(self._store))])

def __getitem__(self, index):
data = {k: v[index].copy() for k, v in self._store.items()}
if self.transform is not None:
data = self.transform(data)
return data

def _ensure_cache(self):
try:
self._load_cache(mode="r")
except FileNotFoundError:
self._generate_cache()
self._load_cache(mode="r")

def _load_cache(self, mode):
with open(self.__cache_root / "info.json", "r") as f:
info = json.load(f)

self._store = {
k: np.memmap(
self.__cache_root / f"{k}.npy",
mode=mode,
dtype=settings["dtype"],
shape=tuple(settings["shape"]),
)
for k, settings in info.items()
}

def _generate_cache(self, verbose=True):
if verbose:
print(f"Generating cache at {self.__cache_root}...")

items = self._get_items()

if verbose:
items = tqdm(items)

for i, item in enumerate(items):
data = self._load_item(item)

if self.pre_transform is not None:
data = self.pre_transform(data)

if not self._store:
self._write_cache_info(len(items), data)
self._load_cache(mode="w+")

for k, v in data.items():
self._store[k][i] = v

def _write_cache_info(self, num_samples, data):
info = {
k: {
"dtype": _removeprefix(str(v.dtype), "torch."),
"shape": (num_samples, *v.shape),
}
for k, v in data.items()
}
os.makedirs(self.__cache_root, exist_ok=True)
with open(self.__cache_root / "info.json", "w") as f:
json.dump(info, f, indent=2)

def _get_items(self):
raise NotImplementedError

def _load_item(self, item):
raise NotImplementedError


def _removeprefix(s: str, prefix: str) -> str:
return s[len(prefix) :] if s.startswith(prefix) else s
65 changes: 65 additions & 0 deletions compressai/datasets/ndarray.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright (c) 2021-2022, InterDigital Communications, Inc
# All rights reserved.

# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:

# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of InterDigital Communications, Inc nor the names of its
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.

# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# Adapted via https://github.com/pytorch/pytorch/blob/v2.1.0/torch/utils/data/dataset.py
# BSD-style license: https://github.com/pytorch/pytorch/blob/v2.1.0/LICENSE

from typing import Tuple, Union

import numpy as np

from torch.utils.data import Dataset


class NdArrayDataset(Dataset[Union[np.ndarray, Tuple[np.ndarray, ...]]]):
r"""Dataset wrapping arrays.

Each sample will be retrieved by indexing arrays along the first dimension.

Args:
*arrays (np.ndarray): arrays that have the same size of the first dimension.
"""

arrays: Tuple[np.ndarray, ...]

def __init__(self, *arrays: np.ndarray, single: bool = False) -> None:
assert all(
arrays[0].shape[0] == array.shape[0] for array in arrays
), "Size mismatch between arrays"
self.arrays = arrays
self.single = single

def __getitem__(self, index):
if self.single:
[array] = self.arrays
return array[index]
return tuple(array[index] for array in self.arrays)

def __len__(self):
return self.arrays[0].shape[0]
40 changes: 40 additions & 0 deletions compressai/datasets/pointcloud/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright (c) 2021-2022, InterDigital Communications, Inc
# All rights reserved.

# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:

# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of InterDigital Communications, Inc nor the names of its
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.

# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from .modelnet import ModelNetDataset
from .s3dis import S3disDataset
from .semantic_kitti import SemanticKittiDataset
from .shapenet import ShapeNetCorePartDataset

__all__ = [
"ModelNetDataset",
"S3disDataset",
"SemanticKittiDataset",
"ShapeNetCorePartDataset",
]
Loading