forked from hhhaaahhhaa/TTS-systems
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathclean.py
More file actions
107 lines (86 loc) · 2.99 KB
/
Copy pathclean.py
File metadata and controls
107 lines (86 loc) · 2.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import argparse
import os
import torch
from tqdm import tqdm
import json
import gc
from Parsers.parser import DataParser
class LengthChecker(object):
def __init__(self, data_parser: DataParser, mi=1, mx=15):
self.data_parser = data_parser
self.data_parser.units["mfa"].segment.read_all()
self.mi = mi
self.mx = mx
def check(self, query) -> bool:
try:
segment = self.data_parser.units["mfa"].segment.read_from_query(query)
l = segment[-1][1] - segment[0][0]
assert self.mi <= l and l <= self.mx
except:
return False
return True
class ExistenceChecker(object):
def __init__(self, data_parser: DataParser):
self.data_parser = data_parser
def check(self, query) -> bool:
try:
filenames = [
self.data_parser.mel.read_filename(query, raw=True),
self.data_parser.units["mfa"].duration_avg_energy.read_filename(query, raw=True),
self.data_parser.units["mfa"].duration_avg_pitch.read_filename(query, raw=True),
]
for f in filenames:
assert os.path.exists(f)
except:
return False
return True
class UnknownTokenChecker(object):
def __init__(self, data_parser: DataParser):
self.data_parser = data_parser
self.data_parser.units["mfa"].phoneme.read_all()
def check(self, query) -> bool:
try:
phoneme = self.data_parser.units["mfa"].phoneme.read_from_query(query)
assert "spn" not in phoneme.split(" ")
except:
return False
return True
def clean(root: str, output_path: str):
os.makedirs(os.path.dirname(output_path), exist_ok=True)
data_parser = DataParser(root)
res = data_parser.get_all_queries()
print("Check existence...")
filtered = []
checker = ExistenceChecker(data_parser)
for query in tqdm(res):
if checker.check(query):
filtered.append(query)
res = filtered
print(f"Remain samples: {len(res)}.")
print("Check length...")
filtered = []
checker = LengthChecker(data_parser)
for query in tqdm(res):
if checker.check(query):
filtered.append(query)
res = filtered
print(f"Remain samples: {len(res)}.")
print("Check unknown tokens (spn)...")
filtered = []
checker = UnknownTokenChecker(data_parser)
for query in tqdm(res):
if checker.check(query):
filtered.append(query)
res = filtered
print(f"Remain samples: {len(res)}.")
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(res, f, indent=4)
def main(args):
""" Write results to args.output_path """
clean(args.preprocessed_data_dir, args.output_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("preprocessed_data_dir", type=str)
parser.add_argument("output_path", type=str)
args = parser.parse_args()
main(args)