@@ -22,6 +22,7 @@ class OperatingSystem(Enum):
2222 LINUX : str = "linux"
2323 WINDOWS : str = "windows"
2424 MACOS : str = "macos"
25+ WINDOWS_ARM64 : str = "windows-arm64"
2526
2627
2728PRE_CXX11_ABI = "pre-cxx11"
@@ -130,7 +131,7 @@ def update_versions(versions, release_matrix, release_version):
130131 )
131132 versions ["latest_stable" ] = version
132133
133- # Perform update of the json file from release matrix
134+ # Perform update of the JSON file from the release matrix
134135 for os_key , os_vers in versions ["versions" ][version ].items ():
135136 for pkg_key , pkg_vers in os_vers .items ():
136137 for acc_key , instr in pkg_vers .items ():
@@ -145,7 +146,6 @@ def update_versions(versions, release_matrix, release_version):
145146 if (x ["package_type" ], x ["gpu_arch_type" ], x ["gpu_arch_version" ])
146147 == (package_type , gpu_arch_type , gpu_arch_version )
147148 ]
148-
149149 if pkg_arch_matrix :
150150 if package_type != "libtorch" :
151151 instr ["command" ] = pkg_arch_matrix [0 ]["installation" ]
@@ -200,17 +200,42 @@ def gen_install_matrix(versions) -> Dict[str, str]:
200200 key = f"{ ver } ,{ pkg_key } ,{ os_key } ,{ acc_key } ,{ extra_key } "
201201 note = instr ["note" ]
202202 lines = [note ] if note is not None else []
203- if pkg_key == "libtorch" :
204- ivers = instr ["versions" ]
205- if ivers is not None :
206- lines += [
207- f"{ lab } <br /><a href='{ val } '>{ val } </a>"
208- for (lab , val ) in ivers .items ()
209- ]
203+ if os_key == "windows" :
204+ if pkg_key == "libtorch" :
205+ ivers = instr ["versions" ]
206+ if ivers is not None :
207+ # Flatten x64/arm64 links into separate lines with arch in label
208+ for lab , val in ivers .items ():
209+ if isinstance (val , dict ):
210+ for arch , url in val .items ():
211+ if url :
212+ lines .append (f"{ lab [:- 1 ]} { arch } :<br /><a href='{ url } '>{ url } </a>" )
213+ else :
214+ lines .append (f"{ lab } <br /><a href='{ val } '>{ val } </a>" )
215+ elif pkg_key == "pip" :
216+ command = instr .get ("command" )
217+ if isinstance (command , dict ):
218+ for arch , cmd in command .items ():
219+ lines .append (f"<b>{ arch } </b>: { cmd } " )
220+ elif command is not None :
221+ lines .append (command )
222+ else :
223+ command = instr .get ("command" )
224+ if command is not None :
225+ lines .append (command )
210226 else :
211- command = instr ["command" ]
212- if command is not None :
213- lines .append (command )
227+ if pkg_key == "libtorch" :
228+ ivers = instr ["versions" ]
229+ if ivers is not None :
230+ lines += [
231+ f"{ lab } <br /><a href='{ val } '>{ val } </a>"
232+ for (lab , val ) in ivers .items ()
233+ ]
234+ else :
235+ command = instr .get ("command" )
236+ if command is not None :
237+ lines .append (command )
238+
214239 result [key ] = "<br />" .join (lines )
215240 return result
216241
@@ -235,6 +260,44 @@ def gen_ver_list(chan, gpu_arch_type):
235260 acc_arch_ver_map [chan ][label ] = ("cuda" , cuda_ver )
236261
237262
263+ def merge_windows_arch_entries (entries ):
264+ """
265+ Merge x64 and arm64 entries for Windows
266+ """
267+ from collections import defaultdict
268+
269+ def entry_key (entry ):
270+ # Exclude validation_runner and installation from the key
271+ return tuple (
272+ (k , v )
273+ for k , v in sorted (entry .items ())
274+ if k not in ("validation_runner" , "installation" , "upload_to_base_bucket" )
275+ )
276+
277+ grouped = defaultdict (dict )
278+ for entry in entries :
279+ key = entry_key (entry )
280+ arch = "arm64" if "arm64" in str (entry .get ("validation_runner" , "" )).lower () else "x64"
281+ grouped [key ][arch ] = entry
282+
283+ merged = []
284+ for key , arch_dict in grouped .items ():
285+ if "x64" in arch_dict and "arm64" in arch_dict :
286+ base = {k : v for k , v in arch_dict ["x64" ].items () if k not in ("validation_runner" , "installation" )}
287+ base ["validation_runner" ] = {
288+ "x64" : arch_dict ["x64" ]["validation_runner" ],
289+ "arm64" : arch_dict ["arm64" ]["validation_runner" ],
290+ }
291+ base ["installation" ] = {
292+ "x64" : arch_dict ["x64" ]["installation" ],
293+ "arm64" : arch_dict ["arm64" ]["installation" ],
294+ }
295+ merged .append (base )
296+ else :
297+ merged .extend (arch_dict .values ())
298+ return merged
299+
300+
238301def main ():
239302 parser = argparse .ArgumentParser ()
240303 parser .add_argument ("--autogenerate" , dest = "autogenerate" , action = "store_true" )
@@ -248,7 +311,13 @@ def main():
248311 for val in ("nightly" , "release" ):
249312 release_matrix [val ] = {}
250313 for osys in OperatingSystem :
251- release_matrix [val ][osys .value ] = read_matrix_for_os (osys , val )
314+ if osys == OperatingSystem .WINDOWS_ARM64 :
315+ winarm64_matrix = read_matrix_for_os (osys , val )
316+ windowsx64_matrix = release_matrix [val ][OperatingSystem .WINDOWS .value ]
317+ merged = merge_windows_arch_entries (windowsx64_matrix + winarm64_matrix )
318+ release_matrix [val ][OperatingSystem .WINDOWS .value ] = merged
319+ else :
320+ release_matrix [val ][osys .value ] = read_matrix_for_os (osys , val )
252321
253322 write_releases_file (release_matrix )
254323
0 commit comments