Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions download_TrackingNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,18 @@ def main(trackingnet_dir="TrackingNet", csv_dir=".", overwrite=False, chunks=[],
destination_path = os.path.join(trackingnet_dir, chunk_folder, datum.lower(), Google_drive_file_name)

if (not os.path.exists(destination_path)):

downloader.download(url='https://drive.google.com/uc?id={id}'.format(id=Google_drive_file_id),
output=destination_path,
quiet=True,
)
while True:
try:
downloader.download(url='https://drive.google.com/uc?id={id}'.format(id=Google_drive_file_id),
output=destination_path,
quiet=False,
)
break
except Exception as e:
print("\nException:", e)
print("Retrying...")
# from IPython import embed;embed()
continue



Expand Down
71 changes: 58 additions & 13 deletions extract_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,45 @@
import zipfile
import argparse
import shutil
import logging
import time


def main(trackingnet_dir="TrackingNet", overwrite_frames=False, chunks=[]):
def getLogger(title):
log_dir = "../log"
if not os.path.exists(log_dir):
os.makedirs(log_dir)
log_file = os.path.join(log_dir, "{}.log".format(title))

logger = logging.getLogger("testzip")
logger.setLevel(logging.DEBUG)
fh = logging.FileHandler(log_file)
fh.setLevel(logging.DEBUG)
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)
logger.addHandler(fh)
logger.addHandler(ch)
return logger, fh, ch


def releaseLogger(logger):
for handler in logger.handlers[:]:
handler.close()
logger.removeHandler(handler)
del logger


def main(trackingnet_dir="TrackingNet", overwrite_frames=False, test_zips=False, chunks=[]):

for chunk_folder in chunks:
chunk_folder = chunk_folder.upper()
zip_folder = os.path.join(trackingnet_dir, chunk_folder, "zips")

logger = None
if test_zips:
logger, fh, ch = getLogger(chunk_folder)
logger.info("Start at {}".format(time.strftime("%Y-%m-%d %X")))

if( os.path.exists(zip_folder)):

for zip_file in tqdm(os.listdir(zip_folder), desc=chunk_folder):
Expand All @@ -34,23 +65,34 @@ def main(trackingnet_dir="TrackingNet", overwrite_frames=False, chunks=[]):

# if frame folder does not exist, jsut create it
else:
same_number_files = False
os.makedirs(frame_folder)
same_number_files = False
if not test_zips:
os.makedirs(frame_folder)

# extract zip if necessary
if(overwrite_frames or not same_number_files):
zip_ref.extractall(os.path.join(frame_folder))
if test_zips:
test_result = zip_ref.testzip()
if test_result is not None:
logger.info("{} corrupted".format(zip_file))
else:
zip_ref.extractall(os.path.join(frame_folder))

# check that all the files were extracted
same_number_files = len(zip_ref.infolist()) == len(os.listdir(frame_folder))
if (not same_number_files):
print("Warning:", frame_folder, "was not well extracted")
# from IPython import embed;embed()
if not test_zips:
same_number_files = len(zip_ref.infolist()) == len(os.listdir(frame_folder))
if (not same_number_files):
print("Warning:", frame_folder, "was not well extracted")


except zipfile.BadZipFile:
print("Error: the zip file", zip_file, "is corrupted, please delete it and download it again.")


if test_zips:
logger.info("Done at {}".format(time.strftime("%Y-%m-%d %X")))
logger.info("="*50)
releaseLogger(logger)


if __name__ == "__main__":
Expand All @@ -59,6 +101,8 @@ def main(trackingnet_dir="TrackingNet", overwrite_frames=False, chunks=[]):
help='Main TrackingNet folder.')
p.add_argument('--overwrite_frames', action='store_true',
help='Folder where to store the frames.')
p.add_argument('--test_zips', action='store_true',
help='Only check .zip files, donnot extract.')
p.add_argument('--chunk', type=str, default="ALL",
help='List of chunks to elaborate [ALL / Train / Test / 4 / 1,2,5].')

Expand All @@ -77,10 +121,11 @@ def main(trackingnet_dir="TrackingNet", overwrite_frames=False, chunks=[]):
except:
chunk = []


print("extracting the frames for the following chunks:")
if args.test_zips:
operation_desc = "testing zips"
else:
operation_desc = "extracting the frames"
print(operation_desc, "for the following chunks:")
print(chunk)

main(args.trackingnet_dir, args.overwrite_frames, chunk)


main(args.trackingnet_dir, args.overwrite_frames, args.test_zips, chunk)