diff --git a/mpl/_mpl.py b/mpl/_mpl.py
new file mode 100644
index 0000000..209295b
--- /dev/null
+++ b/mpl/_mpl.py
@@ -0,0 +1,44 @@
+import torch
+
+def compute_weights(losses, indices, weights, ratio, p):
+ size = losses.size(0)
+
+ # find first nonzero element
+ pos = 0
+ while losses[pos]< 1e-5:
+ pos += 1
+ n = size - pos
+ m = int(ratio * n)
+ if n <= 0 or m <= 0:
+ raise ValueError
+ q = p / (p - 1.0)
+ c = m - n + 1
+ a = [0.0 , 0.0]
+ i = pos
+ nu = 0.0
+ while i < n and nu < 1e-5:
+ loss_q = (losses[i] / losses[size - 1]) ** q
+ a[0] = a[1]
+ a[1] += loss_q
+ c += 1
+ nu = c * loss_q - a[1]
+
+ # compute alpha
+ if nu < 1e-5:
+ i += 1
+ c += 1
+ a[0] = a[1]
+ alpha = (a[0] / c) ** (1 / q) * losses[size - 1]
+
+ # compute_weights
+ tau = 1.0 / (n ** (1.0 / q)*(m **(1.0 / p)))
+ k = i
+ while k < n:
+ # maybe wrong
+ weights[indices[k]] = tau
+ k += 1
+ if alpha > -1e-5:
+ k = pos
+ while k < i:
+ weights[indices[k]] = tau * (losses[k] / alpha) ** (q - 1)
+ k += 1
diff --git a/mpl/build.py b/mpl/build.py
deleted file mode 100644
index 71a6027..0000000
--- a/mpl/build.py
+++ /dev/null
@@ -1,20 +0,0 @@
-import os
-
-from torch.utils.ffi import create_extension
-
-sources = ['src/lib_mpl.cpp']
-headers = ['src/lib_mpl.h']
-with_cuda = False
-
-this_file = os.path.dirname(os.path.realpath(__file__))
-
-ffi = create_extension(
- '_mpl',
- headers=headers,
- sources=sources,
- relative_to=__file__,
- with_cuda=with_cuda
-)
-
-if __name__ == '__main__':
- ffi.build()
diff --git a/mpl/src/lib_mpl.cpp b/mpl/src/lib_mpl.cpp
deleted file mode 100644
index 1ec95fe..0000000
--- a/mpl/src/lib_mpl.cpp
+++ /dev/null
@@ -1,62 +0,0 @@
-#include
-#include
-#include
-
-extern "C" void compute_weights(int size,
- const THFloatTensor *losses,
- const THLongTensor *indices,
- THFloatTensor *weights,
- float ratio, float p)
-{
- // int size = losses->size[0];
- const float* losses_data = THFloatTensor_data(losses);
- const int64_t* indices_data = THLongTensor_data(indices);
- float* weights_data = THFloatTensor_data(weights);
-
- // find first nonzero element
- int pos = 0;
- while( losses_data[pos] < std::numeric_limits::epsilon() )
- {
- ++pos;
- }
-
- // Algorithm #1
- int n = size - pos;
- int m = int(ratio * n);
- if (n <= 0 || m <= 0) return;
- float q = p / (p - 1.0);
- int c = m - n + 1;
- float a[2] = {0.0};
- int i = pos;
- float eta = 0.0;
- for(; i < n && eta < std::numeric_limits::epsilon(); ++i) {
- float loss_q = pow(losses_data[i] / losses_data[size - 1], q);
- a[0] = a[1];
- a[1] += loss_q;
- c += 1;
- eta = c * loss_q - a[1];
- }
-
- // compute alpha
- float alpha;
- if (eta < std::numeric_limits::epsilon())
- {
- c += 1;
- a[0] = a[1];
- }
- alpha = pow(a[0] / c, 1.0 / q) * losses_data[size - 1];
-
- // compute weights
- float tau = 1.0 / (pow(n, 1.0 / q) * pow(m, 1.0 / p));
- for (int k = i; k < n; ++k)
- {
- weights_data[indices_data[k]] = tau;
- }
- if (alpha > -std::numeric_limits::epsilon())
- {
- for(int k = pos; k < i; ++k)
- {
- weights_data[indices_data[k]] = tau * pow(losses_data[k] / alpha, q - 1);
- }
- }
-}
diff --git a/mpl/src/lib_mpl.h b/mpl/src/lib_mpl.h
deleted file mode 100644
index a0fae52..0000000
--- a/mpl/src/lib_mpl.h
+++ /dev/null
@@ -1,5 +0,0 @@
-void compute_weights(int size,
- const THFloatTensor *losses,
- const THLongTensor *indices,
- THFloatTensor *weights,
- float ratio, float p);
|