| 
13 | 13 |    "metadata": {},  | 
14 | 14 |    "outputs": [],  | 
15 | 15 |    "source": [  | 
16 |  | -    "#|default_exp utils"  | 
 | 16 | +    "# |default_exp utils"  | 
17 | 17 |    ]  | 
18 | 18 |   },  | 
19 | 19 |   {  | 
 | 
22 | 22 |    "metadata": {},  | 
23 | 23 |    "outputs": [],  | 
24 | 24 |    "source": [  | 
25 |  | -    "#|export\n",  | 
 | 25 | +    "# |export\n",  | 
26 | 26 |     "import matplotlib.pyplot as plt\n",  | 
27 |  | -    "from   matplotlib.collections import LineCollection\n",  | 
 | 27 | +    "from matplotlib.collections import LineCollection\n",  | 
28 | 28 |     "import numpy as np\n",  | 
29 | 29 |     "import jax\n",  | 
30 | 30 |     "import jax.numpy as jnp\n",  | 
 | 
44 | 44 |    "metadata": {},  | 
45 | 45 |    "outputs": [],  | 
46 | 46 |    "source": [  | 
47 |  | -    "#|export\n",  | 
48 |  | -    "key       = jax.random.PRNGKey(0)\n",  | 
 | 47 | +    "# |export\n",  | 
 | 48 | +    "key = jax.random.PRNGKey(0)\n",  | 
49 | 49 |     "logsumexp = jax.scipy.special.logsumexp"  | 
50 | 50 |    ]  | 
51 | 51 |   },  | 
 | 
55 | 55 |    "metadata": {},  | 
56 | 56 |    "outputs": [],  | 
57 | 57 |    "source": [  | 
58 |  | -    "#|export\n",  | 
 | 58 | +    "# |export\n",  | 
59 | 59 |     "def keysplit(key, *ns):\n",  | 
60 |  | -    "    if len(ns) == 0:  \n",  | 
 | 60 | +    "    if len(ns) == 0:\n",  | 
61 | 61 |     "        return jax.random.split(key, 1)[0]\n",  | 
62 | 62 |     "    elif len(ns) == 1:\n",  | 
63 |  | -    "        n, = ns\n",  | 
64 |  | -    "        if n == 1: return keysplit(key)\n",  | 
65 |  | -    "        else:      return jax.random.split(key, ns[0])\n",  | 
 | 63 | +    "        (n,) = ns\n",  | 
 | 64 | +    "        if n == 1:\n",  | 
 | 65 | +    "            return keysplit(key)\n",  | 
 | 66 | +    "        else:\n",  | 
 | 67 | +    "            return jax.random.split(key, ns[0])\n",  | 
66 | 68 |     "    else:\n",  | 
67 | 69 |     "        keys = []\n",  | 
68 |  | -    "        for n in ns: keys.append(keysplit(key, n))\n",  | 
69 |  | -    "        return keys\n"  | 
 | 70 | +    "        for n in ns:\n",  | 
 | 71 | +    "            keys.append(keysplit(key, n))\n",  | 
 | 72 | +    "        return keys"  | 
70 | 73 |    ]  | 
71 | 74 |   },  | 
72 | 75 |   {  | 
 | 
122 | 125 |    "metadata": {},  | 
123 | 126 |    "outputs": [],  | 
124 | 127 |    "source": [  | 
125 |  | -    "#|export\n",  | 
 | 128 | +    "# |export\n",  | 
126 | 129 |     "def bounding_box(arr, pad=0):\n",  | 
127 | 130 |     "    \"\"\"Takes a euclidean-like arr (`arr.shape[-1] == 2`) and returns its bounding box.\"\"\"\n",  | 
128 |  | -    "    return jnp.array([\n",  | 
129 |  | -    "        [jnp.min(arr[...,0])-pad, jnp.min(arr[...,1])-pad],\n",  | 
130 |  | -    "        [jnp.max(arr[...,0])+pad, jnp.max(arr[...,1])+pad]\n",  | 
131 |  | -    "    ])"  | 
 | 131 | +    "    return jnp.array(\n",  | 
 | 132 | +    "        [\n",  | 
 | 133 | +    "            [jnp.min(arr[..., 0]) - pad, jnp.min(arr[..., 1]) - pad],\n",  | 
 | 134 | +    "            [jnp.max(arr[..., 0]) + pad, jnp.max(arr[..., 1]) + pad],\n",  | 
 | 135 | +    "        ]\n",  | 
 | 136 | +    "    )"  | 
132 | 137 |    ]  | 
133 | 138 |   },  | 
134 | 139 |   {  | 
 | 
137 | 142 |    "metadata": {},  | 
138 | 143 |    "outputs": [],  | 
139 | 144 |    "source": [  | 
140 |  | -    "#|export\n",  | 
 | 145 | +    "# |export\n",  | 
141 | 146 |     "def argmax_axes(a, axes=None):\n",  | 
142 | 147 |     "    \"\"\"Argmax along specified axes\"\"\"\n",  | 
143 |  | -    "    if axes is None: return jnp.argmax(a)\n",  | 
144 |  | -    "    \n",  | 
145 |  | -    "    n = len(axes)        \n",  | 
146 |  | -    "    axes_  = set(range(a.ndim))\n",  | 
 | 148 | +    "    if axes is None:\n",  | 
 | 149 | +    "        return jnp.argmax(a)\n",  | 
 | 150 | +    "\n",  | 
 | 151 | +    "    n = len(axes)\n",  | 
 | 152 | +    "    axes_ = set(range(a.ndim))\n",  | 
147 | 153 |     "    axes_0 = axes\n",  | 
148 |  | -    "    axes_1 = sorted(axes_ - set(axes_0))    \n",  | 
149 |  | -    "    axes_  = axes_0 + axes_1\n",  | 
 | 154 | +    "    axes_1 = sorted(axes_ - set(axes_0))\n",  | 
 | 155 | +    "    axes_ = axes_0 + axes_1\n",  | 
150 | 156 |     "\n",  | 
151 | 157 |     "    b = jnp.transpose(a, axes=axes_)\n",  | 
152 | 158 |     "    c = b.reshape(np.prod(b.shape[:n]), -1)\n",  | 
153 | 159 |     "\n",  | 
154 | 160 |     "    I = jnp.argmax(c, axis=0)\n",  | 
155 |  | -    "    I = jnp.array([jnp.unravel_index(i, b.shape[:n]) for i in I]).reshape(b.shape[n:] + (n,))\n",  | 
 | 161 | +    "    I = jnp.array([jnp.unravel_index(i, b.shape[:n]) for i in I]).reshape(\n",  | 
 | 162 | +    "        b.shape[n:] + (n,)\n",  | 
 | 163 | +    "    )\n",  | 
156 | 164 |     "\n",  | 
157 |  | -    "    return  I"  | 
 | 165 | +    "    return I"  | 
158 | 166 |    ]  | 
159 | 167 |   },  | 
160 | 168 |   {  | 
 | 
177 | 185 |     "test_shape = (3, 99, 5, 9)\n",  | 
178 | 186 |     "a = jnp.arange(np.prod(test_shape)).reshape(test_shape)\n",  | 
179 | 187 |     "\n",  | 
180 |  | -    "I = argmax_axes(a, axes=[0,1])\n",  | 
 | 188 | +    "I = argmax_axes(a, axes=[0, 1])\n",  | 
181 | 189 |     "I.shape"  | 
182 | 190 |    ]  | 
183 | 191 |   },  | 
 | 
194 | 202 |    "metadata": {},  | 
195 | 203 |    "outputs": [],  | 
196 | 204 |    "source": [  | 
197 |  | -    "#|export\n",  | 
198 |  | -    "def cam_to_screen(x): return jnp.array([x[0]/x[2], x[1]/x[2], jnp.linalg.norm(x)])\n",  | 
199 |  | -    "def screen_to_cam(y): return y[2]*jnp.array([y[0], y[1], 1.0])"  | 
 | 205 | +    "# |export\n",  | 
 | 206 | +    "def cam_to_screen(x):\n",  | 
 | 207 | +    "    return jnp.array([x[0] / x[2], x[1] / x[2], jnp.linalg.norm(x)])\n",  | 
 | 208 | +    "\n",  | 
 | 209 | +    "\n",  | 
 | 210 | +    "def screen_to_cam(y):\n",  | 
 | 211 | +    "    return y[2] * jnp.array([y[0], y[1], 1.0])"  | 
200 | 212 |    ]  | 
201 | 213 |   },  | 
202 | 214 |   {  | 
 | 
205 | 217 |    "metadata": {},  | 
206 | 218 |    "outputs": [],  | 
207 | 219 |    "source": [  | 
208 |  | -    "#|export\n",  | 
209 |  | -    "def rot2d(hd): return jnp.array([\n",  | 
210 |  | -    "    [jnp.cos(hd), -jnp.sin(hd)], \n",  | 
211 |  | -    "    [jnp.sin(hd),  jnp.cos(hd)]\n",  | 
212 |  | -    "    ]);\n",  | 
 | 220 | +    "# |export\n",  | 
 | 221 | +    "def rot2d(hd):\n",  | 
 | 222 | +    "    return jnp.array([[jnp.cos(hd), -jnp.sin(hd)], [jnp.sin(hd), jnp.cos(hd)]])\n",  | 
 | 223 | +    "\n",  | 
213 | 224 |     "\n",  | 
214 |  | -    "def pack_2dpose(x,hd): \n",  | 
215 |  | -    "    return jnp.concatenate([x,jnp.array([hd])])\n",  | 
 | 225 | +    "def pack_2dpose(x, hd):\n",  | 
 | 226 | +    "    return jnp.concatenate([x, jnp.array([hd])])\n",  | 
216 | 227 |     "\n",  | 
217 |  | -    "def apply_2dpose(p, ys): \n",  | 
218 |  | -    "    return ys@rot2d(p[2] - jnp.pi/2).T + p[:2]\n",  | 
219 | 228 |     "\n",  | 
220 |  | -    "def unit_vec(hd): \n",  | 
 | 229 | +    "def apply_2dpose(p, ys):\n",  | 
 | 230 | +    "    return ys @ rot2d(p[2] - jnp.pi / 2).T + p[:2]\n",  | 
 | 231 | +    "\n",  | 
 | 232 | +    "\n",  | 
 | 233 | +    "def unit_vec(hd):\n",  | 
221 | 234 |     "    return jnp.array([jnp.cos(hd), jnp.sin(hd)])\n",  | 
222 | 235 |     "\n",  | 
 | 236 | +    "\n",  | 
223 | 237 |     "def adjust_angle(hd):\n",  | 
224 | 238 |     "    \"\"\"Adjusts angle to lie in the interval [-pi,pi).\"\"\"\n",  | 
225 |  | -    "    return (hd + jnp.pi)%(2*jnp.pi) - jnp.pi"  | 
 | 239 | +    "    return (hd + jnp.pi) % (2 * jnp.pi) - jnp.pi"  | 
226 | 240 |    ]  | 
227 | 241 |   },  | 
228 | 242 |   {  | 
 | 
238 | 252 |    "metadata": {},  | 
239 | 253 |    "outputs": [],  | 
240 | 254 |    "source": [  | 
241 |  | -    "#|export\n",  | 
 | 255 | +    "# |export\n",  | 
242 | 256 |     "from genjax.incremental import UnknownChange, NoChange, Diff\n",  | 
243 | 257 |     "\n",  | 
244 | 258 |     "\n",  | 
245 | 259 |     "def argdiffs(args, other=None):\n",  | 
246 |  | -    "    return tuple(map(lambda v: Diff(v, UnknownChange), args))\n"  | 
 | 260 | +    "    return tuple(map(lambda v: Diff(v, UnknownChange), args))"  | 
247 | 261 |    ]  | 
248 | 262 |   },  | 
249 | 263 |   {  | 
 | 
252 | 266 |    "metadata": {},  | 
253 | 267 |    "outputs": [],  | 
254 | 268 |    "source": [  | 
255 |  | -    "#|export\n",  | 
 | 269 | +    "# |export\n",  | 
256 | 270 |     "from builtins import property as _property, tuple as _tuple\n",  | 
257 | 271 |     "from typing import Any\n",  | 
258 | 272 |     "\n",  | 
259 | 273 |     "\n",  | 
260 | 274 |     "class Args(tuple):\n",  | 
261 | 275 |     "    def __new__(cls, *args, **kwargs):\n",  | 
262 | 276 |     "        return _tuple.__new__(cls, list(args) + list(kwargs.values()))\n",  | 
263 |  | -    "    \n",  | 
 | 277 | +    "\n",  | 
264 | 278 |     "    def __init__(self, *args, **kwargs):\n",  | 
265 | 279 |     "        self._d = dict()\n",  | 
266 |  | -    "        for k,v in kwargs.items():\n",  | 
 | 280 | +    "        for k, v in kwargs.items():\n",  | 
267 | 281 |     "            self._d[k] = v\n",  | 
268 | 282 |     "            setattr(self, k, v)\n",  | 
269 | 283 |     "\n",  | 
 | 
297 | 311 |    "metadata": {},  | 
298 | 312 |    "outputs": [],  | 
299 | 313 |    "source": [  | 
300 |  | -    "#|export\n",  | 
301 |  | -    "# \n",  | 
 | 314 | +    "# |export\n",  | 
 | 315 | +    "#\n",  | 
302 | 316 |     "# Monkey patching `sample` for `BuiltinGenerativeFunction`\n",  | 
303 |  | -    "# \n",  | 
 | 317 | +    "#\n",  | 
304 | 318 |     "cls = genjax._src.generative_functions.static.static_gen_fn.StaticGenerativeFunction\n",  | 
305 | 319 |     "\n",  | 
 | 320 | +    "\n",  | 
306 | 321 |     "def genjax_sample(self, key, *args, **kwargs):\n",  | 
307 | 322 |     "    tr = self.simulate(key, args)\n",  | 
308 | 323 |     "    return tr.get_retval()\n",  | 
309 | 324 |     "\n",  | 
 | 325 | +    "\n",  | 
310 | 326 |     "setattr(cls, \"sample\", genjax_sample)\n",  | 
311 | 327 |     "\n",  | 
312 | 328 |     "\n",  | 
313 |  | -    "# \n",  | 
 | 329 | +    "#\n",  | 
314 | 330 |     "# Monkey patching `sample` for `DeferredGenerativeFunctionCall`\n",  | 
315 |  | -    "# \n",  | 
 | 331 | +    "#\n",  | 
316 | 332 |     "cls = genjax._src.generative_functions.supports_callees.SugaredGenerativeFunctionCall\n",  | 
317 | 333 |     "\n",  | 
 | 334 | +    "\n",  | 
318 | 335 |     "def deff_gen_func_call(self, key, **kwargs):\n",  | 
319 | 336 |     "    return self.gen_fn.sample(key, *self.args, **kwargs)\n",  | 
320 | 337 |     "\n",  | 
 | 338 | +    "\n",  | 
321 | 339 |     "def deff_gen_func_logpdf(self, x, **kwargs):\n",  | 
322 | 340 |     "    return self.gen_fn.logpdf(x, *self.args, **kwargs)\n",  | 
323 | 341 |     "\n",  | 
 | 342 | +    "\n",  | 
324 | 343 |     "setattr(cls, \"__call__\", deff_gen_func_call)\n",  | 
325 | 344 |     "setattr(cls, \"sample\", deff_gen_func_call)\n",  | 
326 | 345 |     "setattr(cls, \"logpdf\", deff_gen_func_logpdf)"  | 
 | 
0 commit comments