@@ -1458,10 +1458,75 @@ def take(x, indices, axis=None):
14581458
14591459
14601460def take_along_axis (x , indices , axis = None ):
1461- raise NotImplementedError (
1462- "`take_along_axis` is not supported with openvino backend"
1461+ x = get_ov_output (x )
1462+ indices = get_ov_output (indices )
1463+
1464+ if axis is None :
1465+ target_shape = ov_opset .constant ([- 1 ], dtype = Type .i32 ).output (0 )
1466+ x_flat = ov_opset .reshape (x , target_shape , False ).output (0 )
1467+ indices_flat = ov_opset .reshape (indices , target_shape , False ).output (0 )
1468+ result = ov_opset .gather_elements (x_flat , indices_flat , 0 ).output (0 )
1469+ return OpenVINOKerasTensor (result )
1470+
1471+ x_rank = len (x .get_partial_shape ())
1472+ if axis < 0 :
1473+ axis += x_rank
1474+
1475+ x_shape = ov_opset .shape_of (x , Type .i32 ).output (0 )
1476+ indices_shape = ov_opset .shape_of (indices , Type .i32 ).output (0 )
1477+
1478+ zero_const = ov_opset .constant (0 , dtype = Type .i32 ).output (0 )
1479+ axis_index = ov_opset .constant ([axis ], dtype = Type .i32 ).output (0 )
1480+
1481+ # Fix negative indices
1482+ dim_size = ov_opset .squeeze (
1483+ ov_opset .gather (x_shape , axis_index , zero_const ).output (0 ), zero_const
1484+ ).output (0 )
1485+ zero_scalar = ov_opset .constant (0 , indices .get_element_type ()).output (0 )
1486+ is_neg = ov_opset .less (indices , zero_scalar ).output (0 )
1487+ dim_size_cast = ov_opset .convert (
1488+ dim_size , indices .get_element_type ()
1489+ ).output (0 )
1490+ indices = ov_opset .select (
1491+ is_neg , ov_opset .add (indices , dim_size_cast ).output (0 ), indices
1492+ ).output (0 )
1493+ indices = ov_opset .convert (indices , Type .i32 ).output (0 )
1494+
1495+ x_target_parts , indices_target_parts = [], []
1496+
1497+ for i in range (x_rank ):
1498+ dim_idx = ov_opset .constant ([i ], dtype = Type .i32 ).output (0 )
1499+ x_dim = ov_opset .gather (x_shape , dim_idx , zero_const ).output (0 )
1500+ indices_dim = ov_opset .gather (
1501+ indices_shape , dim_idx , zero_const
1502+ ).output (0 )
1503+
1504+ if i == axis :
1505+ # For axis dimension: keep original dimensions
1506+ x_target_parts .append (x_dim )
1507+ indices_target_parts .append (indices_dim )
1508+ else :
1509+ # For other dimensions: use maximum for broadcasting
1510+ max_dim = ov_opset .maximum (x_dim , indices_dim ).output (0 )
1511+ x_target_parts .append (max_dim )
1512+ indices_target_parts .append (max_dim )
1513+
1514+ x_target_shape = ov_opset .concat (x_target_parts , axis = 0 ).output (0 )
1515+ indices_target_shape = ov_opset .concat (indices_target_parts , axis = 0 ).output (
1516+ 0
14631517 )
14641518
1519+ # Broadcast to target shapes and gather elements
1520+ x_broadcasted = ov_opset .broadcast (x , x_target_shape ).output (0 )
1521+ indices_broadcasted = ov_opset .broadcast (
1522+ indices , indices_target_shape
1523+ ).output (0 )
1524+ result = ov_opset .gather_elements (
1525+ x_broadcasted , indices_broadcasted , axis
1526+ ).output (0 )
1527+
1528+ return OpenVINOKerasTensor (result )
1529+
14651530
14661531def tan (x ):
14671532 x = get_ov_output (x )
0 commit comments