88published_version.json file
99"""
1010
11- import json
12- import copy
1311import argparse
12+ import copy
13+ import json
14+ from enum import Enum
1415from pathlib import Path
1516from typing import Dict
16- from enum import Enum
1717
18- BASE_DIR = Path (__file__ ).parent .parent
18+ BASE_DIR = Path (__file__ ).parent .parent
19+
1920
2021class OperatingSystem (Enum ):
2122 LINUX : str = "linux"
2223 WINDOWS : str = "windows"
2324 MACOS : str = "macos"
2425
26+
2527PRE_CXX11_ABI = "pre-cxx11"
2628CXX11_ABI = "cxx11-abi"
2729DEBUG = "debug"
@@ -38,29 +40,30 @@ class OperatingSystem(Enum):
3840 "cuda.x" : ("cuda" , "11.8" ),
3941 "cuda.y" : ("cuda" , "12.1" ),
4042 "cuda.z" : ("cuda" , "12.4" ),
41- "rocm5.x" : ("rocm" , "6.0" )
42- },
43+ "rocm5.x" : ("rocm" , "6.0" ),
44+ },
4345 "release" : {
4446 "accnone" : ("cpu" , "" ),
4547 "cuda.x" : ("cuda" , "11.8" ),
4648 "cuda.y" : ("cuda" , "12.1" ),
4749 "cuda.z" : ("cuda" , "12.4" ),
48- "rocm5.x" : ("rocm" , "6.0" )
49- }
50- }
50+ "rocm5.x" : ("rocm" , "6.0" ),
51+ },
52+ }
5153
5254# Initialize arch version to default values
5355# these default values will be overwritten by
5456# extracted values from the release marix
5557acc_arch_ver_map = acc_arch_ver_default
5658
5759LIBTORCH_DWNL_INSTR = {
58- PRE_CXX11_ABI : "Download here (Pre-cxx11 ABI):" ,
59- CXX11_ABI : "Download here (cxx11 ABI):" ,
60- RELEASE : "Download here (Release version):" ,
61- DEBUG : "Download here (Debug version):" ,
62- MACOS : "Download arm64 libtorch here (ROCm and CUDA are not supported):" ,
63- }
60+ PRE_CXX11_ABI : "Download here (Pre-cxx11 ABI):" ,
61+ CXX11_ABI : "Download here (cxx11 ABI):" ,
62+ RELEASE : "Download here (Release version):" ,
63+ DEBUG : "Download here (Debug version):" ,
64+ MACOS : "Download arm64 libtorch here (ROCm and CUDA are not supported):" ,
65+ }
66+
6467
6568def load_json_from_basedir (filename : str ):
6669 try :
@@ -71,32 +74,39 @@ def load_json_from_basedir(filename: str):
7174 except json .JSONDecodeError as exc :
7275 raise ImportError (f"Invalid JSON { filename } " ) from exc
7376
77+
7478def read_published_versions ():
7579 return load_json_from_basedir ("published_versions.json" )
7680
81+
7782def write_published_versions (versions ):
7883 with open (BASE_DIR / "published_versions.json" , "w" ) as outfile :
7984 json .dump (versions , outfile , indent = 2 )
8085
86+
8187def read_matrix_for_os (osys : OperatingSystem , channel : str ):
8288 jsonfile = load_json_from_basedir (f"{ osys .value } _{ channel } _matrix.json" )
8389 return jsonfile ["include" ]
8490
91+
8592def read_quick_start_module_template ():
8693 with open (BASE_DIR / "_includes" / "quick-start-module.js" ) as fptr :
8794 return fptr .read ()
8895
96+
8997def get_package_type (pkg_key : str , os_key : OperatingSystem ) -> str :
9098 if pkg_key != "pip" :
9199 return pkg_key
92100 return "manywheel" if os_key == OperatingSystem .LINUX .value else "wheel"
93101
102+
94103def get_gpu_info (acc_key , instr , acc_arch_map ):
95104 gpu_arch_type , gpu_arch_version = acc_arch_map [acc_key ]
96105 if DEFAULT in instr :
97106 gpu_arch_type , gpu_arch_version = acc_arch_map ["accnone" ]
98107 return (gpu_arch_type , gpu_arch_version )
99108
109+
100110# This method is used for generating new published_versions.json file
101111# It will modify versions json object with installation instructions
102112# Provided by generate install matrix Github Workflow, stored in release_matrix
@@ -109,42 +119,62 @@ def update_versions(versions, release_matrix, release_version):
109119 if release_version != "nightly" :
110120 version = release_matrix [OperatingSystem .LINUX .value ][0 ]["stable_version" ]
111121 if version not in versions ["versions" ]:
112- versions ["versions" ][version ] = copy .deepcopy (versions ["versions" ][template ])
122+ versions ["versions" ][version ] = copy .deepcopy (
123+ versions ["versions" ][template ]
124+ )
113125 versions ["latest_stable" ] = version
114126
115127 # Perform update of the json file from release matrix
116128 for os_key , os_vers in versions ["versions" ][version ].items ():
117129 for pkg_key , pkg_vers in os_vers .items ():
118130 for acc_key , instr in pkg_vers .items ():
119131 package_type = get_package_type (pkg_key , os_key )
120- gpu_arch_type , gpu_arch_version = get_gpu_info (acc_key , instr , acc_arch_map )
132+ gpu_arch_type , gpu_arch_version = get_gpu_info (
133+ acc_key , instr , acc_arch_map
134+ )
121135
122136 pkg_arch_matrix = [
123- x for x in release_matrix [os_key ]
124- if (x ["package_type" ], x ["gpu_arch_type" ], x ["gpu_arch_version" ]) ==
125- (package_type , gpu_arch_type , gpu_arch_version )
126- ]
137+ x
138+ for x in release_matrix [os_key ]
139+ if (x ["package_type" ], x ["gpu_arch_type" ], x ["gpu_arch_version" ])
140+ == (package_type , gpu_arch_type , gpu_arch_version )
141+ ]
127142
128143 if pkg_arch_matrix :
129144 if package_type != "libtorch" :
130145 instr ["command" ] = pkg_arch_matrix [0 ]["installation" ]
131146 else :
132147 if os_key == OperatingSystem .LINUX .value :
133148 rel_entry_dict = {
134- x ["devtoolset" ]: x ["installation" ] for x in pkg_arch_matrix
149+ x ["devtoolset" ]: x ["installation" ]
150+ for x in pkg_arch_matrix
135151 if x ["libtorch_variant" ] == "shared-with-deps"
136- }
152+ }
137153 if instr ["versions" ] is not None :
138154 for ver in [PRE_CXX11_ABI , CXX11_ABI ]:
139- instr ["versions" ][LIBTORCH_DWNL_INSTR [ver ]] = rel_entry_dict [ver ]
155+ if gpu_arch_type == "rocm" and ver == PRE_CXX11_ABI :
156+ continue
157+ else :
158+ instr ["versions" ][LIBTORCH_DWNL_INSTR [ver ]] = (
159+ rel_entry_dict [ver ]
160+ )
161+
140162 elif os_key == OperatingSystem .WINDOWS .value :
141- rel_entry_dict = {x ["libtorch_config" ]: x ["installation" ] for x in pkg_arch_matrix }
163+ rel_entry_dict = {
164+ x ["libtorch_config" ]: x ["installation" ]
165+ for x in pkg_arch_matrix
166+ }
142167 if instr ["versions" ] is not None :
143168 for ver in [RELEASE , DEBUG ]:
144- instr ["versions" ][LIBTORCH_DWNL_INSTR [ver ]] = rel_entry_dict [ver ]
169+ instr ["versions" ][LIBTORCH_DWNL_INSTR [ver ]] = (
170+ rel_entry_dict [ver ]
171+ )
145172 elif os_key == OperatingSystem .MACOS .value :
146173 if instr ["versions" ] is not None :
147- instr ["versions" ][LIBTORCH_DWNL_INSTR [MACOS ]] = pkg_arch_matrix [0 ]["installation" ]
174+ instr ["versions" ][LIBTORCH_DWNL_INSTR [MACOS ]] = (
175+ pkg_arch_matrix [0 ]["installation" ]
176+ )
177+
148178
149179# This method is used for generating new quick-start-module.js
150180# from the versions json object
@@ -158,21 +188,25 @@ def gen_install_matrix(versions) -> Dict[str, str]:
158188 for os_key , os_vers in versions ["versions" ][ver_key ].items ():
159189 for pkg_key , pkg_vers in os_vers .items ():
160190 for acc_key , instr in pkg_vers .items ():
161- extra_key = ' python' if pkg_key != ' libtorch' else ' cplusplus'
191+ extra_key = " python" if pkg_key != " libtorch" else " cplusplus"
162192 key = f"{ ver } ,{ pkg_key } ,{ os_key } ,{ acc_key } ,{ extra_key } "
163193 note = instr ["note" ]
164194 lines = [note ] if note is not None else []
165195 if pkg_key == "libtorch" :
166196 ivers = instr ["versions" ]
167197 if ivers is not None :
168- lines += [f"{ lab } <br /><a href='{ val } '>{ val } </a>" for (lab , val ) in ivers .items ()]
198+ lines += [
199+ f"{ lab } <br /><a href='{ val } '>{ val } </a>"
200+ for (lab , val ) in ivers .items ()
201+ ]
169202 else :
170203 command = instr ["command" ]
171204 if command is not None :
172205 lines .append (command )
173206 result [key ] = "<br />" .join (lines )
174207 return result
175208
209+
176210# This method is used for extracting two latest verisons of cuda and
177211# last verion of rocm. It will modify the acc_arch_ver_map object used
178212# to update getting started page.
@@ -195,8 +229,8 @@ def gen_ver_list(chan, gpu_arch_type):
195229
196230def main ():
197231 parser = argparse .ArgumentParser ()
198- parser .add_argument (' --autogenerate' , dest = ' autogenerate' , action = ' store_true' )
199- parser .set_defaults (autogenerate = True )
232+ parser .add_argument (" --autogenerate" , dest = " autogenerate" , action = " store_true" )
233+ parser .set_defaults (autogenerate = False )
200234
201235 options = parser .parse_args ()
202236 versions = read_published_versions ()
@@ -217,8 +251,11 @@ def main():
217251 template = read_quick_start_module_template ()
218252 versions_str = json .dumps (gen_install_matrix (versions ))
219253 template = template .replace ("{{ installMatrix }}" , versions_str )
220- template = template .replace ("{{ VERSION }}" , f"\" Stable ({ versions ['latest_stable' ]} )\" " )
254+ template = template .replace (
255+ "{{ VERSION }}" , f"\" Stable ({ versions ['latest_stable' ]} )\" "
256+ )
221257 print (template .replace ("{{ ACC ARCH MAP }}" , json .dumps (acc_arch_ver_map )))
222258
259+
223260if __name__ == "__main__" :
224261 main ()
0 commit comments