Torch for Numpy users

Peter O'Connor edited this page Aug 25, 2018 · 17 revisions

Torch equivalents of NumPy functions

| Numpy | Torch | 
|---|---|
| np.ndarray | torch.Tensor | 
| np.float32 | torch.FloatTensor | 
| np.float64 | torch.DoubleTensor | 
| np.int8 | torch.CharTensor | 
| np.uint8 | torch.ByteTensor | 
| np.int16 | torch.ShortTensor | 
| np.int32 | torch.IntTensor | 
| np.int64 | torch.LongTensor | 

| Numpy | Torch | 
|---|---|
| np.empty([2,2]) | torch.Tensor(2,2) | 
| np.empty_like(x) | x.new(x:size()) | 
| np.eye | torch.eye | 
| np.identity | torch.eye | 
| np.ones | torch.ones | 
| np.ones_like(x) | torch.ones(x:size()) | 
| np.zeros | torch.zeros | 
| np.zeros_like(x) | torch.zeros(x:size()) | 

| Numpy | Torch | 
|---|---|
| np.array([ [1,2],[3,4] ]) | torch.Tensor({{1,2},{3,4}}) | 
| np.ascontiguousarray(x) | x:contiguous() | 
| np.copy(x) | x:clone() | 
| np.fromfile(file) | torch.Tensor(torch.Storage(file)) | 
| np.frombuffer | ??? | 
| np.fromfunction | ??? | 
| np.fromiter | ??? | 
| np.fromstring | ??? | 
| np.loadtxt | ??? | 
| np.concatenate | torch.cat | 
| np.multiply | torch.cmul | 

| Numpy | Torch | 
|---|---|
| np.arange(10) | torch.range(0,9) | 
| np.arange(2, 3, 0.1) | torch.linspace(2, 2.9, 10) | 
| np.linspace(1, 4, 6) | torch.linspace(1, 4, 6) | 
| np.logspace | torch.logspace | 
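
The `np.arange(2, 3, 0.1)` row works because `torch.linspace` takes an inclusive end point and an element count, while `np.arange` takes an exclusive stop and a step. The arithmetic behind that row can be sketched in plain Python as follows. The helper name `arange_to_linspace_args` is hypothetical (it is not part of NumPy or Torch), and rounding the quotient to 9 decimals is an assumption made here to absorb floating-point noise. For `np.arange(10)` the same conversion gives first 0, last 9, count 10, which matches the `torch.range(0,9)` row.

```python
import math

def arange_to_linspace_args(start, stop, step):
    """Convert half-open arange arguments (start, stop, step) into the
    inclusive (first, last, count) triple that linspace expects.

    Hypothetical helper for illustration only.
    """
    if step == 0:
        raise ValueError("step must be non-zero")
    # NumPy documents the length of arange as ceil((stop - start) / step).
    # Rounding first is an assumption of this sketch: it keeps float noise
    # in the quotient from adding a spurious extra element.
    count = math.ceil(round((stop - start) / step, 9))
    if count <= 0:
        raise ValueError("empty range has no linspace equivalent")
    # The last generated element sits (count - 1) steps after the first.
    last = start + (count - 1) * step
    return start, last, count

first, last, count = arange_to_linspace_args(2, 3, 0.1)
print(first, last, count)
```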

| Numpy | Torch | 
|---|---|
| np.diag | torch.diag | 
| np.tril | torch.tril | 
| np.triu | torch.triu | 

| Numpy | Torch | 
|---|---|
| x.shape | x:size() | 
| x.strides | x:stride() | 
| x.ndim | x:dim() | 
| x.data | x:data() | 
| x.size | x:nElement() | 
| x.shape == y.shape | x:isSameSizeAs(y) | 
| x.dtype | x:type() | 

| Numpy | Torch | 
|---|---|
| x.reshape | x:reshape | 
| x.resize | x:resize | 
| ? | x:resizeAs | 
| x.transpose | x:t() (2D) or x:transpose(dim1, dim2) | 
| x.flatten | x:view(x:nElement()) | 
| x.squeeze | x:squeeze | 

| Numpy | Torch | 
|---|---|
| np.take(a, indices) | a[indices] | 
| x[:,0] | x[{{},1}] | 
| np.put | ???? | 
| np.tile | x:repeatTensor | 
| x.fill | x:fill | 
| np.choose | ??? | 
| np.sort | sorted, indices = torch.sort(x, [dim]) | 
| np.argsort | sorted, indices = torch.sort(x, [dim]) | 
| np.nonzero | torch.nonzero(x), or torch.find(x:ne(0), 1) (torchx) | 
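
The `np.sort` and `np.argsort` rows share one Torch entry because a single `torch.sort` call returns both the sorted values and the indices that produce them, and Lua indices start at 1 rather than 0. A minimal plain-Python sketch of that return shape, assuming a flat list as input; `sort_with_indices` is a hypothetical name used only for illustration, not a NumPy or Torch function:

```python
def sort_with_indices(values):
    """Return (sorted_values, indices), mimicking the two results of
    Torch's sort. Indices are 1-based, as in Lua.

    Hypothetical helper for illustration only.
    """
    # 0-based positions ordered by the value they point at (what argsort gives).
    order = sorted(range(len(values)), key=values.__getitem__)
    sorted_values = [values[i] for i in order]
    # Shift to 1-based indexing to match Lua conventions.
    indices = [i + 1 for i in order]
    return sorted_values, indices

sorted_values, indices = sort_with_indices([3.0, 1.0, 2.0])
print(sorted_values, indices)
```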

| Numpy | Torch | 
|---|---|
| ndarray.min | mins, indices = torch.min(x, [dim]) | 
| ndarray.argmin | mins, indices = torch.min(x, [dim]) | 
| ndarray.max | maxs, indices = torch.max(x, [dim]) | 
| ndarray.argmax | maxs, indices = torch.max(x, [dim]) | 
| ndarray.clip | torch.clamp | 
| ndarray.round | torch.round | 
| ndarray.trace | torch.trace | 
| ndarray.sum | torch.sum | 
| ndarray.cumsum | torch.cumsum | 
| ndarray.mean | torch.mean | 
| ndarray.std | torch.std | 
| ndarray.prod | torch.prod | 
| ndarray.dot | torch.dot (vector, vector), torch.mv (matrix, vector), torch.mm (matrix, matrix) | 
| ndarray.cumprod | torch.cumprod | 
| ndarray.all | ??? | 
| ndarray.any | ??? | 

| Numpy | Torch | 
|---|---|
| `ndarray.__lt__` | torch.lt | 
| `ndarray.__le__` | torch.le | 
| `ndarray.__gt__` | torch.gt | 
| `ndarray.__ge__` | torch.ge | 
| `ndarray.__eq__` | torch.eq | 
| `ndarray.__ne__` | torch.ne |