diff --git a/docs/tutorials/constrained_sinkhorn_ott.ipynb b/docs/tutorials/constrained_sinkhorn_ott.ipynb new file mode 100644 index 000000000..c0f21471b --- /dev/null +++ b/docs/tutorials/constrained_sinkhorn_ott.ipynb @@ -0,0 +1,2060 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "b601f657", + "metadata": {}, + "source": [ + "# Constrained Optimal Transport with OTT-JAX\n", + "\n", + "This tutorial implements the **Sinkhorn-type algorithm for constrained optimal transport** of [Tang et al., *A Sinkhorn-type Algorithm for Constrained Optimal Transport*](https://openreview.net/forum?id=V5kCKFav9j) (ICLR 2025) on top of OTT-JAX. The paper extends the classical Sinkhorn algorithm to handle **additional linear constraints** on the transport plan beyond the standard marginal constraints : a capability that until now had only been addressed for special structured cases (capacity, multimarginal, partial OT).\n", + "\n", + "The strategy of the tutorial is to express the constrained algorithm as a **thin wrapper around OTT-JAX's existing Sinkhorn solver**: at each outer iteration we (i) build a `Geometry` whose cost matrix is *modified* by the current constraint duals, (ii) run a few Sinkhorn updates using OTT-JAX's stabilised log-domain kernel (`apply_lse_kernel`), and (iii) take a Newton step on the *new* dual variables that the constraints introduce. This reframes constrained Sinkhorn as a sibling of the multimarginal and unbalanced extensions already documented in the OTT-JAX tutorial collection, and it lets us reuse the library's hardened numerics for free.\n", + "\n", + "In what follows we\n", + "\n", + "1. recall standard entropic OT through OTT-JAX,\n", + "2. derive the dual of the **constrained** entropic OT problem and the resulting three-step algorithm,\n", + "3. implement the algorithm by *modifying* the OTT-JAX `Geometry` at every iteration and reusing `apply_lse_kernel` for the Sinkhorn updates,\n", + "4. reproduce the paper's first numerical experiment (random assignment with mixed constraints) and inspect how the dual variables evolve,\n", + "5. verify Theorem 1 numerically : exponential convergence of the entropic solution to the LP optimum,\n", + "6. trace a Pareto front between two competing transport costs,\n", + "7. visualise the algorithm on a 2D point-cloud problem with a forbidden target zone,\n", + "8. apply the algorithm to a fairness-constrained matching problem,\n", + "9. use the constraint dual as a free gradient via the envelope theorem.\n" + ] + }, + { + "cell_type": "markdown", + "id": "94891fc2", + "metadata": {}, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "f66aaf04", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-02T06:38:56.533304Z", + "iopub.status.busy": "2026-05-02T06:38:56.533134Z", + "iopub.status.idle": "2026-05-02T06:38:58.247449Z", + "shell.execute_reply": "2026-05-02T06:38:58.246370Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "OTT-JAX version: 0.6.0\n" + ] + } + ], + "source": [ + "from functools import partial\n", + "from typing import NamedTuple\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import numpy as np\n", + "\n", + "import matplotlib.patches as mpatches\n", + "import matplotlib.pyplot as plt\n", + "\n", + "import ott\n", + "from ott.geometry import geometry, pointcloud\n", + "from ott.problems.linear import linear_problem\n", + "from ott.solvers.linear import sinkhorn\n", + "from ott.tools import plot\n", + "\n", + "# Use double precision throughout: the algorithm performs Newton steps and\n", + "# benefits from extra precision, especially when the regularisation is weak.\n", + "jax.config.update(\"jax_enable_x64\", True)\n", + "\n", + "plt.rcParams.update(\n", + " {\n", + " \"figure.dpi\": 110,\n", + " \"axes.spines.top\": False,\n", + " \"axes.spines.right\": False,\n", + " \"font.size\": 11,\n", + " }\n", + ")\n", + "print(f\"OTT-JAX version: {ott.__version__}\")" + ] + }, + { + "cell_type": "markdown", + "id": "a0b9f3d5", + "metadata": {}, + "source": [ + "## 1. Standard entropic OT, in one line of OTT-JAX\n", + "\n", + "We compare two probability vectors $a \\in \\Sigma_n,\\ b \\in \\Sigma_m$ on point clouds $X = (x_1, \\ldots, x_n)$ and $Y = (y_1, \\ldots, y_m)$. The Kantorovich problem is the linear program\n", + "$$\\min_{P \\in U(a,b)} \\langle P, C \\rangle, \\qquad U(a,b) := \\{P \\ge 0 : P\\mathbf{1} = a,\\ P^\\top\\mathbf{1} = b\\}.$$\n", + "\n", + "The Sinkhorn algorithm replaces this LP with its **entropy-regularised** version\n", + "$$P_\\varepsilon^\\star \\;=\\; \\arg\\min_{P \\in U(a,b)} \\langle P, C \\rangle + \\varepsilon \\sum_{ij} p_{ij} \\log p_{ij},$$\n", + "whose solution can be computed by alternating row and column normalisations of a kernel $K = \\exp(-C/\\varepsilon)$. The resulting $P_\\varepsilon^\\star$ traces a curve inside the polytope $U(a,b)$ from the independent coupling $a b^\\top$ (when $\\varepsilon \\to \\infty$) to the LP optimum $P^\\star$ (when $\\varepsilon \\to 0$).\n", + "\n", + "In OTT-JAX, the workflow has three pieces: (a) wrap the cost in a `Geometry`, (b) wrap geometry + marginals in a `LinearProblem`, (c) call `sinkhorn.Sinkhorn()` on the problem. The output exposes the dual potentials $f, g$, the transport `matrix`, and convergence diagnostics." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "cd7d2955", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-02T06:38:58.249401Z", + "iopub.status.busy": "2026-05-02T06:38:58.249146Z", + "iopub.status.idle": "2026-05-02T06:39:02.300889Z", + "shell.execute_reply": "2026-05-02T06:39:02.299577Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "converged: True\n", + "regularised OT cost: 4.7576\n", + "row marginal max error: 1.22e-15\n" + ] + } + ], + "source": [ + "# Toy example: solve unconstrained entropic OT between two random point clouds.\n", + "key_x, key_y = jax.random.split(jax.random.key(0))\n", + "n_toy, m_toy, d = 60, 70, 2\n", + "x_toy = jax.random.normal(key_x, (n_toy, d))\n", + "y_toy = jax.random.normal(key_y, (m_toy, d)) + jnp.array([2.0, 0.0])\n", + "a_toy = jnp.ones(n_toy) / n_toy\n", + "b_toy = jnp.ones(m_toy) / m_toy\n", + "\n", + "geom_toy = pointcloud.PointCloud(x_toy, y_toy, epsilon=0.05)\n", + "out = sinkhorn.Sinkhorn()(\n", + " linear_problem.LinearProblem(geom_toy, a=a_toy, b=b_toy)\n", + ")\n", + "print(f\"converged: {out.converged}\")\n", + "print(f\"regularised OT cost: {float(out.reg_ot_cost):.4f}\")\n", + "print(\n", + " f\"row marginal max error: {float(jnp.max(jnp.abs(out.matrix.sum(1) - a_toy))):.2e}\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "abdf659c", + "metadata": {}, + "source": [ + "Visually, the standard Sinkhorn solution is a near-deterministic matching when $\\varepsilon$ is small enough. The figure below shows source and target with the dominant transport routes overlaid: each line connects a source point to one of its main targets, with line width proportional to the transported mass." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "58c6f7ae", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-02T06:39:02.303181Z", + "iopub.status.busy": "2026-05-02T06:39:02.302546Z", + "iopub.status.idle": "2026-05-02T06:39:02.696162Z", + "shell.execute_reply": "2026-05-02T06:39:02.695117Z" + } + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAvMAAAHpCAYAAAAPoLO3AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAQ6wAAEOsBUJTofAABAABJREFUeJzsvXd8nNWZ/n09dXrTqDfLNgZ3G7AJGIIxNtWUUEOWJJAGebNpQLKB7NJZWMIuWfLbTSOFdEKADS2UGBsTAwb3jpssq3fNaPo85bx/jM/xNEkjWbIk+3w/DCM/9Uy/nvvc93ULhBACDofD4XA4HA6HM+kQx3sAHA6Hw+FwOBwOZ2RwMc/hcDgcDofD4UxSuJjncDgcDofD4XAmKVzMczgcDofD4XA4kxQu5jkcDofD4XA4nEkKF/McDofD4XA4HM4khYt5DofD4XA4HA5nksLFPIfD4XA4HA6HM0nhYp7D4XA4HA6Hw5mkcDHP4XA4HA6Hw+FMUriY5xxXLrjgAgiCMK5jeOCBByAIAt55550xO8doPc5nnnkGgiDggQceKGj7W2+9dcwfG+fk5Z133oEgCLj11lvHeyick5y6ujrU1dWN9zAKoqGhYcJ+bgRBwAUXXDDew+AcI1zMT2BM08Qvf/lLXHDBBSgqKoKiKCgpKcHcuXNx66234tlnn83YfrjCj1M4+/fvx+23344ZM2bAZrPB4XCgrq4Oy5cvx3333YfGxsbxHuIJCRePY8NEv+gTBGHSCLXjAf9uHzsm+meBwykEebwHwMmPaZq46qqr8Nprr8Hj8eCKK65AdXU1AoEADh48iOeffx7r16/HTTfdNN5DPeF55513cPnllyMWi+ETn/gELrnkEtjtdhw+fBjbt2/Hww8/jNmzZ6O2tpbt89vf/hbRaHQcR83hcDgnLm+//fZ4D4HDmTBwMT9BefbZZ/Haa69hwYIFWLt2LTweT8b6eDyOf/zjH+M0upOL2267DbFYDM888wxuueWWnPUff/wxZDnzo5Qu7DkcDoczukyfPn28h8DhTBh4ms0E5b333gOQmgLMFvIAYLVacdFFF7F/33rrrfjCF74AAHjwwQchCAK70enDYDCIJ554AhdeeCGqq6uhqipKSkpw1VVXYf369XnHQae7o9Eovvvd76K2thYWiwWnnHIKHn/8cRBC8u73s5/9DPPnz4fVakVZWRluvfVWtLe35902mUzif/7nf3D55ZdjypQpsFgsKCoqwooVK/D666/n3aeurg6CICCRSOCBBx7AjBkzoKoqvv3tb7Nt1q5di2XLlsHpdMLr9eKSSy7Bhg0b8h5vILq6urB//354PJ68Qh4AZs6ciVNOOSVjWb6c+fSUkYaGBtx0000oLi6G1WrFokWL8OqrrxY8rmQyic997nMQBAG33HILNE3L2eb555/HWWedBbvdjqKiItx0001oaWnJe7z6+np84QtfYO+LsrIy3Hjjjdi+fXvOtulT/uvXr8dll10Gn88HQRAQCAQy1m/duhUrV66E1+uF3W7H0qVL8f777xf0GB944AEsW7YMAPCb3/wm4z39zDPPAMh8Tvfu3YsbbrgBJSUlEEURW7duBQCsWbMGt912G2bPng232w2bzYa5c+fiwQcfRDwez3teeo41a9bgggsugMvlgtvtxsqVK7Fnz56cfTo7O/Ev//IvmDlzJhwOB1wuF6ZOnYqbbroJ27ZtY9vR3NkLLrgALS0t+OxnP4uSkhLYbDYsWrQIf/7zn/M+F4QQ/OIXv8DZZ58Nl8sFu92O008/Hf/1X/+V97Uf7PNRV1eH3/zmNwCAZcuWZTyvw2HPnj246qqrUFRUBIfDgU9+8pNYtWrVgNu/8MILWLFiBYqKimCxWDBjxgzcfffd6O/vZ9vQ1xMADh8+nDE2mmpVWVmJ8vLynOPPnj0bgiDgW9/6Vsby7du3s89JOqZp4he/+AXOPfdceDweWK1WzJ07F4899hiSyWTex3DgwAF8+ctfZt9TJSUluOaaa7B58+acbUfyPsrHUN/tb775JgRBwPe+972M/T766CO23ZYtWzLW3XnnnRAEAWvWrMlYvnXrVtx4440oKyuDqqqoqanBl770JRw6dKigsQLH9z2enYo13O+e0fgs/P3vf8dVV12FsrIyWCwWVFdXY+XKlXjllVcK2r+jowPf/OY3MW3aNFgsFvj9flxxxRV49913c7alj49+/2UzUB1BJBLB9773Pfa+nT59Oh544IEB3+ehUAiPPPII5s2bB4/HA4fDgdraWlx99dU8HWmCwiPzE5SioiIAwL59+wra/lOf+hQCgQBeeuklLF26NKOghX649+zZg3/913/F+eefj5UrV8Ln86GxsREvv/wyXn/9dbzyyiu49NJLc46taRouvvhitLa24rLLLoMsy/jrX/+Ku+++G/F4HPfff3/G9nfddReefPJJlJSU4Itf/CKcTidef/11LFmyBC6XK+f4vb29+Na3voUlS5bgoosuQklJCdra2vDKK6/g8ssvx9NPP40vf/nLeR/3ddddh82bN+PSSy/FNddcg6lTpwIA/vrXv+L666+HJEm4/vrrMWXKFGzatAmf/OQnceGFFxb0nAKA2+2GLMsIh8Noa2tDRUVFwfsOxOHDh3HWWWdh2rRp+NznPofe3l78+c9/xtVXX41Vq1YxATsQwWAQ1157LVavXo1/+7d/w8MPP5yzzY9//GO8/PLLuOqqq7B06VJ8+OGH+POf/4xt27Zh69atsFgsbNvNmzdj+fLlCAQCWLlyJebPn4+DBw/ixRdfxCuvvIKXXnoJF198cc453n//fTz66KM4//zz8eUvfxltbW2QJImt37hxI37wgx/gnHPOwZe//GU0NjbihRdewPLly7F161acdtppgz7OCy64AA0NDfjNb36DBQsW4FOf+hRbt3DhwoxtDxw4gE984hOYNWsWPve5z6G/vx92ux0A8Pjjj+Pjjz/GkiVLsHLlSsTjcbz33nt44IEH8M4772DVqlUZ46a8+uqreOmll3DZZZfhq1/9Knbv3o2//e1v2LBhA3bv3o3i4mIAQDQaxZIlS3Dw4EEsX74cV1xxBQCgqakJq1evxooVK7BgwYKMY/f19eHcc8+Fz+fDF7/4RfT19eG5557DTTfdhNbWVtxxxx0Z299yyy343e9+h6qqKnzhC1+Aoih45ZVX8J3vfAdvvfUWXnvttZzZISD/5+Pb3/42nnnmGWzbtg233HLLiHLTDx06hHPOOQcLFizA7bffjpaWFjz33HO45JJL8Nxzz+G6667L2P5rX/safvKTn6C6uhrXXHMNfD4f1q9fj8cffxx/+9vf8N5778HlcqGurg73338/HnzwQXg8noyLc/qaL1++HL///e+xY8cOzJs3DwDQ1tbGxHH2BQVNx1i+fDlbpus6rr32Wrzyyis49dRT8ZnPfAZWqxVr167F97//fbz99tt44403Mp7T1atX4+qrr0Y8HscVV1yBGTNmoKWlBS+++CJef/11vPTSS7jkkktynqtC30cDMdR3e2lpKVRVHfBx0+fk9NNPz1hntVqxZMkStuz111/HNddcA8MwcO2112L69OnYvn07fvWrX+H//u//sHr16pzP3WAcr/d4Pgr97jnWz8L999+Phx56CA6HA1dffTWmTJmC9vZ2rF+/Hr/85S9x5ZVXDrr/4cOHcd5556G5uRlLly7Fpz/9abS1teG5557D66+/jl/+8pfHXC+UTCZx0UUX4YMPPsDs2bPxzW9+E7FYDD/5yU8yAg0UQgguvfRSvP/++zjrrLPwxS9+EaqqoqWlBf/4xz+watUqXjA7ESGcCcnmzZuJoihEEARy8803k7/85S+kvr6emKY54D6//vWvCQBy//33510fCARIV1dXzvKmpiZSUVFBZs6cmbMOAAFALrvsMhKNRtnyjo4O4vF4iMfjIclkki3/4IMPCABSW1tLOjo62HJd18nVV1/NjpdOPB4nTU1Necc7Z84c4vP5Ms5NCCFTpkwhAMi8efNyHlM4HCZ+v5+IokjWrVuXse6JJ55gY1izZk2eZymX6667jgAgU6dOJY8//jh5//33STgcHnSfpUuX5jzONWvWsHM/8MADGeveeOMN9jynk/2aNjY2krlz5xJZlsnTTz+dc95bbrmFACAul4ts3749Y91nPvMZAoD8+c9/ZstM0ySzZ88mAMgzzzyTsf3f//53IggCKSkpIZFIJGdMAMjPfvaznDGkr//1r3+dse6nP/0pAUD+v//v/8vZLx/0ObvlllsGXQ+A3HPPPXm3OXjwYN7Pzb/9278RAOTZZ5/NWH7//fcTAESSJLJq1aqMdXfffTcBQB5//HG27OWXXyYAyLe+9a2cc+i6Tvr6+ti/Dx06xMZ74403EsMw2LoDBw4Qj8dDVFUlDQ0NbPmzzz5LAJD58+eTYDDIlicSCXLhhRcSAOSJJ57IOO9gnw9Cjr5PCv0MUNKf7+985zsZ69avX08kSSJFRUUkFAqx5b/73e8IAHLNNdfkfI4ffvhhAoDceeedGcsBkClTpuQdA31//fCHP8w5x0UXXUQAkNbWVrZu5cqVBABpbm7OOe8///M/E13X2XLDMMhXvvIVAoD86Ec/YssDgQDx+/2kqKiI7Nq1K2M8u3fvJk6nk1RUVJB4PM6WD/d9NBhDfbcvXbqUiKJIenp62LLly5eT6dOnk2nTppFLLrmELe/s7CSCIJAVK1awZeFwmBQXFxNBEMjq1aszjv2LX/yCACBz584d9PeHcjzf49nvkZF894z0s/Dmm2+y92ljY2PO+vTfNPqcZH+PXXrppXl/D7Zv305sNhuxWCwZx6GPL/uxUfI9J4899hgBQFauXJnxXu/q6iJ1dXUEAFm6dGnGuQGQq6++Ouf4pmmS7u7uvOfmjC9czE9g/vznP5Py8nL25QSAeDwecvnll5M//elPGV+ShAz9hT8Y3/jGNwgAcvjw4Yzl9Lz79+/P2efzn/88AUB27NjBltEfwp/+9Kc52+/fv5+Iopgjcgfjv/7rvwgAsnbt2ozlVKz89a9/zdnnD3/4AwFAbrrpppx1mqaRadOmDevLu7e3l1x77bVEEAT2fAiCQGbPnk3uuuuujB8lymBifsqUKRlfqpTa2lri9/szlqW/plu3biVVVVXE6XSSv/3tb3nHSn+Y/vVf/zVn3erVqwkActddd7Fl69atIwDI4sWL8x7v2muvJQDIH//4x5wxLVy4MO8+dP25556bsy6ZTBJZlsmZZ56Zd99sChXzZWVlGUKqELq7uwkA8oUvfCFjORVhN998c84+9fX1BAC57rrr2DIq5u++++4hz0l/1CVJIvX19Tnr77nnHgKAPPLII2zZihUrCADy2muv5Wy/bds2AoCceuqpGcsH+3wQcuxi3uPxkP7+/pz19ILx97//PVt2xhlnEEmSMoQmRdd1UlxcTEpLSzOWDybmGxsbmTih3HrrrcTpdJJ33nmHACC/+93vCCGpz7vL5SKnnXYa29YwDHZOTdNyjh8IBIggCBmfiR/96EcEAHnqqafyjumOO+7IeY2G+z4ajKG+2x966CECgPzlL38hhKQCJFarldx2223kK1/5CrHb7SSRSBBCjgrnxx57jO3/+9//ngAgN9xwQ97jn3nmmQQAef/994cc6/F8jw8k5ofz3TPSz8IVV1xBAJDnnntuyG3zifnm5mYCgFRVVbHXJp0777yTACCPPvooWzYSMT9jxgwiCALZs2dPzva/+tWvBhTz+X4/ORMXnmYzgbnxxhtxzTXXYM2aNVi3bh22bNmCdevW4W9/+xv+9re/4ZlnnsHLL78MVVULPuZ7772Hp556Ch988AE6OztzcuZaWlpyijc9Hk9OTjgA1NTUAEhNp1I2bdoEAHmn4U455RRUV1fntXHctWsXnnjiCbz77rtobW1FIpHIGVc+zjrrrJxlg41BlmWcd955qK+vz3u8fPh8PrzwwgtoaGjAm2++iY0bN2LDhg3Yvn07du/ejZ/85Cd4/vnncdlllxV0vIULF+ZN66ipqcEHH3yQd59169bhhz/8Iex2O9auXYszzjhj0HMsWrQo7/GBzNeL5voOlHq0YsUKvPjii9i8eTM+85nPZKzL99wPNQZFUVBWVpYxhtFgwYIFGalD6UQiETz11FN48cUXsW/fPoTD4Yxaj4HeW4U+h0uXLkVNTQ0ef/xxbNy4EStXrsSSJUtwxhlnDJgWUFtby1LC0lm6dCkee+yxjBxn+hrlS7+aP38+SktL2eNyOp0Z64d6jUbKGWeckTdlbunSpfjTn/6ELVu24Oabb0YsFsOWLVtQVFSEH/3oR3mPpaoqWltb0dPTA7/fP+S5a2pqcOqpp2Lt2rXQdR2yLOPtt9/G+eefj3PPPRculwurVq3CZz/7WXz00UcIhUJYsWIF23/fvn3o7u7G9OnT8cgjj+Q9h81my8hppzVM27dvz2sPuXfvXgCpVMbLL788Y12h76NjYcWKFbjvvvuwatUqXH/99XjvvfcQj8exYsUKmKaJp59+Gu+//z4uuOACln6T/pwM9T2wfPlybNq0CZs3b8Y555xT0JiO13s8H8fju4fWmRX6vZ8NfcznnXde3t/wFStW4Mknn8xbj1EooVAI+/fvR3l5OWbOnJmzPt9v5OzZs3HGGWfg2WefRUNDA66++mqce+65WLx4MaxW64jHwhlbuJif4CiKgosvvpjlLBuGgRdeeAFf/OIX8eabb+InP/lJTsHXQPzf//0frr/+elY8O336dDgcDoiiiHfeeQdr167NEdEA4PV68x6PChXDMNiyYDAIAHkL1ACgoqIiR8yvX78eF154IXRdx/Lly3HVVVfB7XazIsaXXnop77gGOk8hYxgJdXV1uP3223H77bcDAJqbm/G1r30Nr7zyCm699VY0NzdDUZQhjzPY82maZt51W7ZsQX9/P8466yzMnj17ROcY6esFAIFAIGfdQPsMNgY6jvQxjAYDjUXTNFx44YX46KOPMHfuXNx0000oKSlhr9ODDz444Hur0OfQ7XZj/fr1ePDBB/HSSy+x3GWv14svfvGLePjhh1n+PqWsrCzvOely+rrQvz0eD2w2W959Kioq0NnZiWAwmCN0hnqNRkqh4+/t7QUhBD09PXjwwQcHPWY4HC5IzAMpcfmTn/wEH374IUpLS9HU1IQ77rgDsizj/PPPZ4I1X758T08PAODgwYNDjil7n1/+8pdDPoZsCn0fHQuLFy+Gy+XKeNyiKOLCCy+EaZoQBAFvv/02E/NerzcjIHAs3wMDcbze4/k4Ht89gUAAbre7oPHkYyye85GeIx1JkvD222/j3//93/GXv/wF99xzDwDAbrfj05/+NH7wgx8MWefBOf5wMT/JkCQJN954I3bs2IFHHnkEb7/9dsFi/t5774Wqqti4cSNmzZqVse7222/H2rVrj3l81Hmnvb09rwtPW1tbzrJHHnkEsVgMq1evzonMPPbYY3jppZcGPF8+14H0MeQj3xhGQnV1NZ599ln4fD50dnZix44dQ0bMR8rXv/51dHd348c//jFWrlyJl19+GQ6H45iPW+hzle+1HO9OvukMNJaXXnoJH330EW655ZYcB4i2traCxdxQVFZW4mc/+xl++tOf4uOPP8batWvx05/+FE8++SQCgUCOCOzo6Mh7HLo8/fn2eDzo7e1FLBbLK3bG4zUqdPz0ft68eXmdkUYKFfOrVq1CaWkpgKOR5uXLl+O1117D3r17sWrVKoiimBGBpGO68sor8fLLLxd0PrrPpk2bxuwzfizIsoylS5fi1VdfxeHDh7Fq1SosXLiQXRzNmzcPq1atwpe+9CXU19fjmmuugSgeNbM7lu+BgThe7/Hxwuv1oru7u+DZgmxG8pzT10zX9bz7BAKBjAuZkf4Wer1ePPHEE3jiiSdw6NAhvPvuu/jlL3+JX//61zh8+DD3+J+AcGvKSQqd4k5PF6CpGwNFHg4cOIDZs2fnCHnTNLFu3bpRGdeZZ54JAHntqw4cOIDm5ua8y4uKivJOsY7kAmOwMei6PmqPFQAsFgubIk1/LUYbQRDwv//7v7jrrruwevVqXHrppRmWfiOFCpNsizoK/dKmz+nxZqj39FAcOHAAAHLcVYCRvbeGQhAEzJo1C1/96lfxj3/8AxaLBf/3f/+Xs11jYyMaGhoGHFO68wh9jfK9n3fu3InOzk6ceuqpwxIUx/q8bt68GaFQKGd59vidTifmzp2Ljz/+GN3d3QUfXxTFQce2bNkyiKKIt99+G2+//TbKysqYsw2Nwr/yyitYv349zjjjDPh8PrbvzJkz4fV68eGHHw5ozZcNTS0Zr94ehbxe9HG/+OKL2LRpU0YazfLly7Fhwwa88MILGdtShvoeWL16NYDhfQ+M93u8UEb6WTj77LMBYED75KGgj/+9997L+z7M991L38dNTU052+/fvz9jtgNI6YQZM2ago6MDH3/8cc4+hdhMTp06FbfccgtWr16NmpoarF69Ouc8nPGHi/kJyp/+9Cf8/e9/z5t20d7ejqeffhoAcP7557PlNAqTLycdSKWJ7N+/H62trWwZIQQPPvggdu/ePSrjpn7Ijz76KDo7O9lywzDwne98J+/jqaurQ29vb07k7pe//CXefPPNYY/h6quvht/vx3PPPcdyXSn//d//Pax8+UgkgocffnjAKNN///d/IxwOw+fzYe7cucMe63D5z//8T9x7771Yt24dVqxYccz5n0uWLMGsWbPw0Ucf4fe//33GutWrV+PFF19EcXExrr766mM6z0gZ6j09FNRqLvtHq76+PseXe6Ts3Lkzrw93T08PNE3LG2k0DAPf+973Mj4PBw8exI9//GMoioKbb76ZLf/Sl74EAPj+97+fkcahaRruvPNOABjQunUgjvV5DQaDeOihhzKWffjhh3juuefg8/ky3i933XUXNE3DF77wBfT29uYcKxQK4cMPP8wZX1dXF2KxWN7zFxUV4fTTT8f69evx9ttvZ4jTefPmoaSkBE888QSSyWSGqAVSUexvfetb6OzsxD//8z/n7dTc3d3N+hQAqe81n8+Hhx9+OG9dCyEE69atK/jiYLgU8nrRx/kf//EfMAwjR8wbhoEf/OAHGdtSPvWpT8Hv9+P555/P8Td/5plnsHHjRsyZM4cJ2EIY7/d4oYz0s/CNb3wDAPDd7343r7geqBaHUl1djUsuuQTNzc3sdaHs2rULP/nJT2CxWPDZz36WLV+0aBFEUcTvf//7jOcpEong61//et7zfPGLXwQhBN/5zncyLli6u7tzPsNAynZ2586dOctDoRAikQgURSnYIpRz/OCvyATlww8/xFNPPYXy8nKcd955rJDo0KFDeO211xCLxXDOOedkfICXLFkCu92OZ599Fqqqora2FoIg4HOf+xymTJmCO+64A1/96ldx+umn47rrroOiKHjvvfewe/duXHnllQU3uRiMc845B3feeSeefPJJzJ07FzfccAMcDgdef/11RCIRzJ8/P0e0f/vb38abb76J8847DzfeeCM8Hg82btyIdevW4frrr8fzzz8/rDE4HA48/fTTuOGGG3DhhRfihhtuQG1tLTZt2oS1a9fisssuKziaomka7rvvPjz44IM466yzsHDhQvh8PvT29uK9997Djh07IMsyfv7znw9YgDnaPPTQQ7Db7bjnnntw4YUX4q233kJJScmIjiUIAn7zm99gxYoV+PznP4/nnnsO8+bNw8GDB/HCCy9AVVX89re/zcn5Pl6cdtppqKmpwT/+8Q/cfPPNOPXUUyFJEq666irMnz9/yP2vvPJKnHLKKXjyySexY8cOnH766WhsbMSrr76KlStXjljMprNq1SrceeedOPvsszFr1iyUlZWhvb0dL730EkzTZDmn6cyfPx8ffvghFi1ahIsvvhi9vb147rnnEAwG8eSTT2b4Xd9000145ZVX8Mc//hGzZ8/GNddcwzy49+3bh+XLl2f4sRfCRRddhCeeeAL33HMPdu7cySJ+//Zv/1bQ/p/85Cfx9NNP46OPPsK5556LlpYW/PnPfwYhBD//+c8zIqi33norNm/ejP/3//4fpk+fjksuuQR1dXUIBAJoaGjAu+++i4svvhh//etfM8b3xz/+EZdeeinOP/98WCwWLFiwIMO3mxZlBgKBDHEqCAIuvPBC1pwoOwpNH+eOHTvwi1/8Aq+99hqWL1+O6upqdHV14eDBg1i3bh3++Z//Gf/93/8NIHXx8MILL+BTn/oUlixZggsvvBBz5syBoihoamrChx9+iMbGRvT19Q3LkKBQhvpuB4C5c+eirKwMHR0dsFgsOO+889j+559/PmRZRmdnJ6qqqnJ6PDgcDjzzzDO47rrrsGLFClx33XWYNm0atm/fjtdeew1erxe//e1vh5W2Nd7v8UIZ6Wfh4osvxr333ouHH34Ys2bNwqc+9SlMmTIFnZ2dWL9+PaZPn57xns7HT3/6U5x77rm49957sXr1apx99tnMZz4Wi+Hpp59mxdJAKsf985//PJ555hksXLgQK1euRCwWw5tvvom6ujpUVlbmnOPOO+/Eyy+/jNdeew3z58/H5Zdfjlgshr/85S9YsmRJzuzJtm3bcM011+D000/HvHnzUFlZib6+Prz66qvo7e3FXXfdNSopnpxRZvyMdDiD0dTURH784x+Ta665hpx22mnE7XYTWZZJaWkpWb58OfnpT3+a4e9Oeeutt8i5555LnE4ns1FMt9z69a9/TRYsWEDsdjvx+/3kU5/6FNm+fTuzUcu258IgFnED7UNIytN37ty5xGKxkNLSUnLLLbeQtra2vJaNhBDyyiuvkE984hPE6XQSj8dDLrroIrJ27doBrbio9d5grFmzhixdupTY7XbidrvJxRdfTD766KNBx52NYRjkjTfeIHfddRf5xCc+QaqqqoiiKMRut5NZs2aR22+/nezcuTNnv8GsKQeyWcy3z2CWdE899RSzyKS+2oPZrA3kdUxIyjb0lltuIZWVlURRFFJSUkKuv/56smXLlpxth7LJG2p9Pvu0wdi0aRNZsWIF8Xg8zB6Uvh+Gek4JSVkZ/tM//ROprKwkVquVzJ49mzz++ONE07QcWzZCjr6vB7J/y95n9+7d5I477iCLFi0iJSUlRFVVUlNTQ6644gry1ltvZexLX4OlS5eS5uZm8k//9E+kuLiYWCwWcsYZZ5A//elPec9pGAb56U9/ShYvXkzsdjuxWq1k/vz55Ac/+EFeW7tCPh9PPfUUmT17NrFYLOy7YijSn+9du3aRK6+8kni9XmKz2ch5551H/v73vw+47+uvv06uuuoqUlZWxt5jp59+OrnrrrvI5s2bM7bt7Owkn/vc50h5eTmzs81+janPN4Acn++f//znBACxWCw53vYU0zTJH//4R3LRRReRoqIioigKKS8vJ5/4xCfIvffeS/bt25ezz+HDh8k3v/lNcuqppxKr1UqcTieZMWMGufHGG8kf/vCHDLvg4b6PhmKo73ZCjlqDXnjhhTn7L1myhAAgn//85wc8x6ZNm8h1111HSkpKiCzLpLKyktx6663k4MGDBY/zeL7HB7KmHO53z0g+C5TXX3+dXHbZZew9VFVVRVauXEleffXVnOck3/dUW1sb+frXv06mTJlCFEUhPp+PXHbZZQP+PiUSCXL33XeTmpoaoigKqa2tJXfffTeJRqMDPr5QKES++93vkpqaGqKqKpk2bRq5//77SSKRyHkfNjU1ke9///tkyZIlpLy8nKiqSioqKsiFF15YkA0nZ3wQCBnDRF8Oh8PhMBoaGjB16lQsXbqUt0XnnJDw9ziHc/zhOfMcDofD4XA4HM4khYt5DofD4XA4HA5nksLFPIfD4XA4HA6HM0nhOfMcDofD4XA4HM4khUfmORwOh8PhcDicSQoX8xwOh8PhcDgcziSFi3kOh8PhcDgcDmeSUrCY13Udzc3N0HV9LMfD4XA4HA6Hw+FwCqRgMd/e3o6amhq0t7eP5Xg4HA6Hw+FwOBxOgfA0Gw6Hw+FwOBwOZ5LCxTyHw+FwOBwOhzNJ4WKew+FwOBwOh8OZpHAxz+FwOBwOh8PhTFK4mOdwOBwOh8PhcCYp8mgdiBCC7u5uxONxGIYxWofljAGSJMFqtaK4uBiCIIz3cDgcDofD4XA4I2RUxDwhBC0tLQiFQlBVFZIkjcZhOWNEMplEOBxGIpFAVVUVF/QcDofD4XA4k5RREfPd3d0IhUIoLS2F3+8fjUNyxpienh50dnaiu7sbJSUl4z0cDofD4XA4HM4IGJWc+Xg8DlVVuZCfRPj9fqiqing8Pt5D4XA4HA6Hw+GMkFER84Zh8NSaSYgkSby+gcPhcDgcDmcSw91sOBwOh8PhcDicSQoX8xwOh8PhcDgcziSFi3kOh8PhcDgcDmeSMuHEfFI3EYgmkdTN8R4Kh8MZY/jnncPhcDicY2PUmkYdKztbgnh+UzPe3tMBzSRQJAHLZ5bh+jOrMbfKM97DGzWSySRUVR3vYXA448rJ8nnncDgcDmesmRCR+Ve3t+K2327E85uaEYrr0A0ToZiO5zc147bfbsSr21vH7NzPP/885s6dyzqiXnLJJTDNVJQwFovh61//OkpKSmC1WrFs2TLs2LGD7VtXV4f/+Z//yThecXExnnnmGfbvCy64AN/61rfwzW9+E36/H9dccw1M08Sjjz6KadOmwWKxoK6uDk899RTbh66vq6uD3W7HGWecgddee23Qx/HOO+/AZrNluNNs2bIFoigiFAody1PE4Ywq9PP+3IbD6I9px/XzzuFwOBzOica4i/mdLUE8+toehBI6ih0qihwq3FYFRQ4VxQ4VoYSOR1/bg12twVE/d1tbGz7zmc/gS1/6Ej7++GOsXr0al112GVv/L//yL3jppZfwhz/8ARs3bkRpaSkuvfRSRKPRYZ3nV7/6FZxOJz744AP88Ic/xEMPPYQf/vCHePjhh7F792785je/gc/nY9s/9thj+MMf/oCnn34aO3fuxFe/+lVce+212LJly4Dn2LZtG+bMmZNhEbp161ZMmzYNLpdrWOPlcMYK+nnvj2uQtRiQCEPQ43BZBPjH+PPOKRye/sThcDiTh3FPs3l+UzP64ykhL4pCxjpRFOC3q+iOJPH8pmbMqRzd6fe2tjbouo5rr70WU6ZMAQDMnz8fABAOh/Gzn/0Mv//973HxxRcDAH7961+jtrYWf/jDH/CVr3yl4PPMnDkTjz76KIBUg63HH38cP/vZz3DzzTcDAKZPn862TSQSePTRR/HOO+9g8eLFAIDbbrsNq1evxtNPP40f//jHec+xbds2LFy4MGPZ1q1bc5ZxOOMJ/bw7JBMxgYAg9Z5PJBKQJRl2iwX9cW1MPu+coeHpTxwOhzP5GNfIfFI38fbHHZBFIUfIU0RRgCwKWLWnY9SjRAsWLMAFF1yAefPm4dOf/jR+9atfob+/HwBw8OBBaJqGc889l21vt9tx+umnY8+ePcM6z6JFi9jf+/fvRzwex7Jly/Jue+DAAUSjUSxbtgxOp5PdXnzxRRw8eHDAc+QT7lu2bMGCBQuGNVYOZ6xI/7wnE4nUQnJ0vW7oiEUj0BIJvLypAc2tbeMz0JOU8Ux35HA4HM7IGdfIfDSpQzMIZCm/kKfIkgDNIIgmdajy6BWPSpKEt99+G++99x7eeOMNPPHEE7jvvvuwadMmto0gZI6NEMKWiaIIQkjGek3Tcs7jcDhylmUflxIOhwEAb7zxBsrLyzPW2Wy2vPvouo7du3dniHnDMLBt2zZ897vfzbsPh3O8oZ93ASYMM1XbQUBgGgSiKLDPhAATsXgS//uzX2DGlCosXrwYs2bNgiyP+0TiCQtNf+oJRWGFDkEETF2ABAEiBPSGNPzb85vRvHszpvoUqKoKRVGgKEf/HmrZQN95HA6Hwzk2xvXX0a7KUCQB8SQZdDvdILCpAuzq6A9XFEV88pOfxCc/+Uncf//9KC0txZtvvolrr70WiqJg3bp1uPHGGwGkCmK3bt2Km266CQBQUlKC9vZ2dqyGhgYW2R+IGTNmwGazYfXq1fj85z+fs3727NlQVRVNTU0477zzCnoMH3/8MRKJBCorK9myv//97+jv7+eRec6EgX7e+zWDffEIggBJFGAS88iFsggTAmQYkGGisbERjY2NcDgcOP3003HmmWdm1JdwRgea/mQTDJimgbQ6egCAQoBIUsFLO7rxCaVpROfIFvrDuRAYapkkSfxigcPhnLSMq5hXZRHLZ5bh+U3NME2SN9XGNAl0k2DFrDKo8uhmBX344Yd4++23cfHFF6OkpATvvvsuwuEwTjvtNDidTtx+++2466674PV6UVVVhYceegiqquKf/umfAKScan7729/iiiuugNVqxd133z2k7aTVasX3vvc93HXXXZBlGWeffTba2tpw6NAhfPazn4XL5cIdd9yBb33rW9B1HUuWLEEgEMC7776L6upq3HDDDTnH3LZtGwDg//2//4evfe1r+Pjjj/GNb3wDoigiFouN6nPG4YyU9M+7y+FGMpFAMpEEEQhEITXLZRICIgioFvogwgSQ+k6IRCJYt24d3nvvPZxyyilYvHgxTjnlFIjiuNfwT3rS058EmvdE4ytHvpIFARAIQZPhxSK5GZIweAAmH5qmQdO0YRsIFIIgCEOK/mO5eEg3FuBwOJyJxrjPW19/ZjXe2tWOnmgSfntmEaxpEvREk3DbZFx/ZvWon9vtduPdd9/FD3/4Q4TDYUydOhVPP/00PvGJTwAAfvCDH4AQgptvvhmhUAhnn3023njjDdjtdgDAPffcgwMHDuDiiy9GWVkZnnzyyYwUnYG49957IYoi7rnnHrS3t6Oqqgp33HEHW//YY4+htLQUjzzyCA4dOgSfz4fFixfjvvvuy3u8rVu34vLLL8fHH3+MuXPnYv78+Xj88cfx5S9/Gf/7v/+LH/3oR6PwbHE4xw79vPcndPjtTjjsBPFEAol4HLphICkoUIiOOrTDMAwQQiBJEhPthBDs378f+/fvh9frxZlnnokzzjgjbyobpzDS0x31I2mDBAAhJlLXU0JKzIPAhAAdIiQYgx7zeEMIQTKZRDKZHJPji6I4arMI+S4y+EUph8M5FgSSnfQ9AM3NzaipqUFTUxOqqzOFdUNDA4CU7/pIeHV76xG7Oh2yKKR+VIxURN5tk/H9y2fhivmVQx/oJOXiiy/GkiVL8MADDwxrv2N93TickTDQ5z2pG1ChYwHqUZZoAQBYLBYQQnJEfTqSJGHWrFlYvHgxamtrebrFMEnqJpY/+Q5CMR0kHgJJr0omR4L0hCAOGQp0XGZugCKJkGWZp7eMErIsj1kKkizL/DXicE5wxj0yDwBXzK/E1GIHnt/UjFV7OqAdyZFfMStlicYt6gZn27ZtuP3228d7GBxOQQz8ea/C9WdWo8r+SWzatAnr169HX18fLBYLFEWBYRjQNA2iKGakPRiGgZ07d2Lnzp0oLS3FokWLsGDBAlgslnF8lJOH9PQngRBk6D4hlWlDIAAQMUXqhxUKdF1HPB6HYRiQJAmyLHNxfwzoug5d18csLXI0ZxHypSCdCK958JVXEV6zpqBtncuWwXPlFWM8Ig6ncCZEZD6dpG4imtRhV+VRz5E/EWlvb0dFRQX27t2LU089dVj78sg8Z7wZ7PNuGAZ27NiBd999Fy0tLbBarRAEAYZhwDTNHFGfjqqqmDdvHhYvXpzjCsXJZWdLELf9diP6Y0l4bDJEgM2IGMREME5gV4BvLrSg3Kqz/Pd4PI5YLIZoNIpoNIpEIsG6UJumCUEQTgihxxkYURTHtLj5eKUgtdx5F0KrVkFQlEG3I5oG14oVqHryv47LuDicQpgQkfl0VFkcVfvJE53y8vIce0wOZ7Iw2OddkiQsXLgQCxYsQH19PdavX49Dhw6xdaZpQtO0lCNOVnQwmUxi06ZN2LRpE2pqarBo0SLMmTOH21sOwNwqD76/chYefW0PghnpT4BuivA5C0t3JISwJmDxeBzxeBxAqkBVFEWIoshmWJLJJLsooH8Pd5lp8g61441pmuw1HwskSRq1WYR8y9K/NwRFgVxSMuh49K6uMXmcHM6xwH/ZOBzOhEYQBEyfPh01NTXo6OhAfX099uzZg/b2doiiCNM0oet6XlEPAE1NTWhqasKbb77J7C2LiorG6dFMXEYj3VEQBFitVlitVng8HlaYGo/Hj3b6lWU4HA5YrVZYLJYRO8UQQmCa5ogvBIZapmkaD5RMAAzDgGEY7MJwtJFlGaqqYtHevShJJhEOBo8UfAtHZ5aO3CuyPL6dNjmcAeBinsPhTApUVUVVVRWcTidmzZqFaDSKzZs3Y9euXayBG3XAyVf0F41G8d5772XYW86YMYM7iaQxp9KDOZUe3HPZrFFJdxQEARaLhdUvUHGfSCQQiUTQ29sLURSZsLdYLAXPntCLN5vNNmBDvWOBEAJd18fkQiGZTELX9VEfM2f40HqF5JGLt0FfF6sVvBKHMxHhYp7D4UwaRFGEz+dDLBaDaZq49NJLcemll2LLli3YuHEj+vr6MkT9QA44Bw4cwIEDB+DxeJi9pdPpHIdHNDEZq3THbHEPgIn7aDSKvr4+iKLItrFareOWGkW96xVFYXbEowkhZFRnEbLXG9mdvzjHDK//4ExUuJjncDiTDpvNBlVVmfg755xzsGTJEhw8eBAbNmzAvn37YJomm6IfSNQHg0GsXr0a77zzDrO3nDJlCv/RPo6oqgpVVeFyuQCAFdbG43EEg0EAYMKeOhudCAiCwB77WEBrE8ZqZuGkTEHi3wucCQoX8xwOZ1IiSRKKi4sRCoXQ1dUFr9eLU045BaeccgqCwSA2bdqEzZs3IxwOD2hrSTFNE7t27cKuXbtQUlLC7C2tVus4PLKTGxoNTxf3NN8+GAyCEJKRljNWYniyI0kSJEkak/cwnf0arVmEfMsmIvwinzNR4WKew+FMalwuF6xWK3p7e2G1WuF2u+HxeHDhhRdi6dKl+Pjjj7FhwwY0NDQM6oBD6erqwuuvv45Vq1Yxe8uKiopxeGQc4Ki4p2lQuq4zt5xQKATTNDPScrIdSjijjyAIrLfAWNYrjFVx80jrFfjbijNR4WKew+FMehRFQWlpKYLBILq6ulBUVMSaGM2ZMwdz5sxBV1cXNm7ciK1btyIWiw3qgAOkIsKbN2/G5s2bUV1dzewtT5Q0j8kKFZEOhwMAmNNJIpFAT08PdF3PSMtRVZWL+0lGer3CWEAv6tMFfn9HB/SuLiguV6rHAo72WsCRe1nikokzMeHvTA6Hc0IgCAK8Xi/i8Ti6u7vhcrmY4AOAkpISXHbZZVi+fDl27NiBjRs3orW1dUhRD6Sa5jU3NzN7y0WLFnF7ywmCJElwOBwZ4p6m5fT29kLXdaiqmpGaw8X9yU16kTVFt9kRMgwIR+o08hKNgkzQFCDOyQ0X8xwO54TCarWitLQUfX19iMfj8Pl8GcWvqqoyB5uWlhZs3LgRO3bsQCKRGNQBBwBisRjef/99vP/++5g+fToWLVqE0047jdtbTiAkSYLdbmcONLSpUTweRyAQgKZpUFWVRe+PZ5dRzsTFuWzZmGzL4RwPBFJgSXpzczNqamrQ1NSE6urqjHUNDQ0AgLq6utEe35hzwQUXYNGiRfjP//zP8R4KgOM7nsn8unE4hRCJRBAKheD1egctBIxGo9i6dSs2bNiArq6uIUV9Om63m10c0KJNzsQlvWNpIpFAMpmEoigZqTlc3HM4nMnEuEbmg6+8ivCaNQVt61y2DK4rVsI0CURRgDiBpkmTySR3VOBwJiAOhwMWiwW9vb2Ix+PweDx5UyzsdjuWLFmCc845B/X19diwYQN27do1qAMOpb+/H2vWrMHatWsxc+ZMLF68GHV1dTyVY4IiimJGoylCCBP2/f39SCaTkGU5Iy1npF1qORwO53gwrpH5ljvvQmjVKghDFLmYmgZl6TLo338IBKmKcrdVgc+uwKaO/Hrk1ltvxW9+85uMZYcOHcKuXbvw6KOPYteuXZBlGZ/85Cfxox/9CDU1NQBS0fMFCxaAEII//OEPOPvss/Haa68hFArh9ttvx0svvQSv14t7770Xv/jFL3DFFVfggQceSD0W08R//Md/4Oc//zk6Ozsxc+ZMPPzww1i5cuWA48n3vF522WWoqqrCL37xCwDAX//6V9xxxx04dOhQwY+fR+Y5JwuEEPT39yORSMDn8xVUWNff349NmzZhw4YNCAaDQ4r6dIqLi5m95Vi4fXDGDtqllhbVJhIJyLKc4ZjDxT2Hw5lIjHvOvKAokEtKBlxvEAJ0dSGpmyCEQIAAkxD0RZIIxjRUeKzw2kcWFX/qqaewb98+LFy4EPfddx+AVJHchg0b8J3vfAfz5s1Df38/7r77btx0001477332L6/+tWv8I1vfAMffPABW3bnnXfiww8/xGuvvYaioiLcc8892Lt3L6644gq2zWOPPYY//vGPePrppzF9+nSsWrUK1157LdavXz/gePJRWVmJlpYW9u9LLrkEjY2N2LVrF+bMmTOi54PDOVERBAEej4cVRTocjiE7vrrdbixbtgznn38+9u7di/Xr1+PgwYNDFssCQHd3N9544w28/fbbmDt3LhYvXozKysrRflicMSC7Sy0V9/m61NLo/Xh1qeVwOBxgAoj5wTAJYJgEhKSi8aIAQAAEiAAIdJOgLRiHRRZHFKH3eDxQVRV2ux3l5eVs+Q033JCx3c9//nNMnToVzc3NbFZi5syZePTRR9k2oVAIv/nNb/Dcc8/hggsuAAD8+te/zpjFSCQSePTRR/HOO+9g8eLFAIDbbrsNq1evxtNPP40f//jHeceTj6qqKmzYsIH9m0aLOjs7uZjncAbAYrGgpKQEgUAA3d3d8Pl8Q0ZZJUnC7NmzMXv2bHR3d+PDDz/E5s2bkUgkhhT1mqZhy5Yt2LJlC6qqqrBo0SLMnTuX21tOIrLFPYAccS8IQkZaDn99ORzO8WSCi/mjQh4E0HXjiJhPfcEKEGAQgu5QHBUe65A/rIWyf/9+3Hvvvfjwww/R3d3N2lY3NjYycb5o0aKMferr66FpGs466yy2rLS0NCOF5cCBA4hGo1iWVQmfTCZzlg1FdmR+7dq1kCQJZ5555rCOw+GcbIiiiKKiIkSjUXR1dcHj8RScClNcXIyVK1fi4osvxvbt27F+/Xq0trZCluUhv3taWlrQ0tKCN998EwsXLsSiRYtQXFw8Gg+Jc5xRVRWqqmZ0qY3H44jH4wgesTbMbmTF4XA4Y8WEFfMEKTGf/m+A0P+OCGwBRBAQjCUhxvshSamcVkVRWGMR2jhmOCL/yiuvxLRp0/CrX/0KFRUVCIVCOOuss5BMJtk26f7VYOPJbfecXpIQDocBAG+88UZO5H24ebVVVVWsqE+WZfzLv/wLbrvtNrjd7mEdh8M5WbHb7VBVlVlYer3egr8nFEXBmWeeiTPPPBPNzc344IMPsHPnThBChnRCicfjWL9+PdavX49p06Yxe0uehz15oQ2OqLjXdZ3l3Pf394MQktPIisPhcEaLCSvmc8lXp3tU3MeTCQg4aidHb5R0cZ9+U1UVhmGw7Xp6erB37178+te/xjnnnAMgJb6HYvr06VAUBR999BGuvvpqAKm28IcPH2bbzJ49G6qqoqmpCeedd17e42SPZyBo/m1rayv+8z//E7qu4z/+4z+G3I/D4RxFlmWUlJQgFAqhs7MTPp9v2EKruroaN9xwA1auXImNGzdiw4YNCAQCBYnz+vp61NfXw+VyMXtLfkE++ZFlGU6nk9Vl6LrOvO5DoRBM02SRe96llsPhHCuTSMwLEAQRhKp3Ju4FCKlYPYBU9790MUxFvWEY0DQt5wuzvLwc69atw65du+B2u1FaWgq/34+f/exnKC0tRX19Pe6+++4hR+dyuXDLLbfgrrvugtfrhd/vx913353RbdDlcuGOO+7At771Lei6jiVLliAQCODdd99lgqCurg7r16/H4cOH4XA4UFRUlDfSV1VVBQD42te+hsOHD+Pdd9/lrhkczghxuVywWCzo6+uDzWaDy+Uatriy2+04//zz8clPfhL79u3DRx99hP379xfkWR4KhfDOO+/g3XffxWmnnYbFixdj6tSpXOCdINDgUXaX2ng8jkgkAl3XM9JyuLjncDjDYcKKeQGAKKRy4oFU+ookywBSefQEBMQkMAigEA0Dfe2ZpgnTNNm/06P2oijitttuw7e//W0sWrQI8XgcH374If73f/8X9913H+bMmYOZM2fikUcewZVXXplxnHw8+eSTuO2223DZZZfB5/Ph3nvvxaFDhzKa1Tz22GMoLS3FI488gkOHDsHn82Hx4sXMveY73/kObrnlFsyaNQuxWGxAa8rS0lLIsoy2tjasWbNmQNcbDodTGKqqorS0NKM4diQuJYIg4LTTTsNpp52G3t5efPTRR9i2bRui0eiQAs00TezZswd79uyB3+/HokWLsHDhQn6hfoKR3aWWinvqtqTrOlRVzSiq5eKew+EMxLj7zIfXrh3QmtIkgG6aIN3dEM45F+J9/562NuVmIwoC6vx2WGSROQzQrn66rg8pwAVByBH4gyFJUt50nXz7BQIBVFZW4ne/+x2uu+66IZ+P4w33medw8hOLxRAMBuF2u5ngOhY0TcPOnTuxadMmNDU1DUuYybKcYW/JRd2JT3qX2ng8Dk3ToKpqRmoO71LL4XAo4x6ZJ5oGvatrwPUCIYCugRwR9qmkmlSqjSgKqPBYmS2l1WrNiIKbpglN0zIEvmmaMAyDiXxCSN7UnPTc+3TotolEImO5JEnYuXMnGhoacNZZZ6G/vx8PPfQQ3G43Lr300mN+njgczvHDZrPlFMcei3hSFAWnn346FixYgNbWVmzevBm7du3K+R7Jh67r2Lp1K7Zu3YqKigosXrwYc+fO5UWUJzDZXWpN02SNrGiXWkVRMopqubjncE5exlXMOwu0YzRMAvOc8xAXBRCSSr8ppAMsbeyR3vxD13Xous5EPhXn6SJ/qNScfJExwzCQTCbxwx/+EPX19bBYLFi4cCFefPFFxONxGIYxZCSfw+FMHCRJQnFxMcLhMCuOTfcaHwmiKKK6uhpVVVVYtmwZdu3ahW3btqGtra2g/dva2vDyyy/jrbfewoIFC7B48WJub3kSIIpiRrCKEMKCVKFQCN3d3ZBlOSMth7sjcTgnD+OaZjNcTEJgmgSiKEAchalmGpWnAl/X9Rxhny72KTQ1h0bvRzLtLYpi3nSd4/kFzNNsOJzC0DQNfX19sFgscLvdo5bqQruLNjc3Y/v27di9ezc0TRvWMaZOnYpFixZh5syZXMCdpKR3qaWWmLIsZxTV8vcGh3PiMqnE/PEgn7gHUl+WNGKfHc0fbt79YBxPkX8ivW4czlhDCEF/fz8SiQR8Pt+oNwJKJpMIh8PYs2cPtmzZgu7u7mHt73Q6mb2lx+MZ1bFxJheEENbIikbw6Uw1jd6PpLibw+FMTLiYHwLTNDPEva7rGesHEvnAEQeeI9H7Y0UQBMiynLch1kg5kV83DmesiMfjCAQCGT7io4mmaSxav23bNnz88cdDFvKnQ910Fi1ahOnTp/OCWQ4AZBhExONxCIKQkZbDu9RyOJMXLuaHSXrePb0N9BSmC/z0/QVBGLUfWCry80XyhzrHyfS6cTijiWma6OvrAwB4vd4xmTmjtT2RSAQ7d+7E5s2b0d/fP6xjFBUVMXvL0XDl4Zw40LoxGr0HkJGWw8U9hzN54GL+GMmXdz9UFM00TSbq049R4EtREIWI/JP5deNwRoNIJIJQKASv15vhpDWaUJtCQggaGhqwefNmHDhwYFjHkGUZc+bMweLFi1FVVcWj9ZwcdF3PSMuhXWpp9J67J3E4E5dREfNNTU1IJpOYPn36mAxysjFQ3v1gZOfa06h+IRcHw4Gm/iiKgvb2diSTSVRVVcHn83GHHQ5nBOi6jt7eXqiqCo/HM2ZCmdoTEkIQDoexZcsWbNmyBbFYbFjHKS8vx+LFizFv3jwu0DgDout6RlqOaZoZjax4l1oOZ+IwKmK+q6sL3d3dKC0thd/vH5OBTmaoMNc0LW/e/UDQHHlRFEc0AzAY0WgUPT092Lt3L5qbmyGKIoqKilBcXJxxKyoq4i4IHM4QEEIQCoUQi8Xg8/nGVCRT5xL6+d+3bx82bNiA5ubmYR3HYrEwe0veQZozFLS/Co3e67qe0cSKd6nlcMaPURHzhBC0tLQgFApBVVUu/gogvXCWEFJQik26a056ig49RrY//kDnNQwDsVgMTU1N2L9//6DnFgQBPp8vR+T7/X7uhsDhZJFIJBAIBGC32+Fyucb8fLTTtaIo6OnpwYYNG7B9+/Zh21tOmTIFixcvxqxZs/j3N6cgTNPMSMtJ71JLo/dc3HM4x4dREfNASiR2d3ezBkmc4UFtxOiXYyE/xrRJCL3RgiXDMNhx0u+TySRbH4lE0NLSckx5+lTk+/3+DJHPp+45JzOmaSIYDELX9eM2s6VpGjRNgyzLME0T27dvx4YNG9A1SHftfDidTpx++uk488wz4fV6x2awnBMSWttBf280TYOiKBmOOTyVk8MZG0ZNzHNGl3g8ju7ubnR3d6Orqwt9fX1DCm9VVZmoLikpQVFRUcaXZyKRQE9PD7q7u9l9d3c3AoHAqI7d4/FkiHv697F2z+RwJhPRaBT9/f1wu93HzUlG13Ukk0lWF9PU1IQNGzZgz549wwqyCIKAU089FYsWLcIpp5zCI6ycYUNrPGhAKZlMQlGUDMccLu45nNGBi/lJgq7rGQK8u7t7yNx7mgdfUlLCBHW+qLmmaTnHpiJ/NB12XC5X3nQdm802aufgcCYShmGgt7cXkiTB6/UeN/FiGAaSySQEQYDFYkEkEsGWLVuwceNGBIPBYR3L5/Mxe0uHwzFGI+ac6BBCWOSe3miXWhq95yleHM7I4GJ+kkIIQSAQYJH77u7uglwtaNScCvzBfpypS0e2yO/t7R1Vke9wOHJEfnFxMffF5pwwhEIhRKNReL3e4zpDRVMfALDz7t+/Hxs3bsSBAweG9TmWJInZW1ZXV/NoPeeYoIXc6emgkiRlpOXwuiwOpzC4mD+BiEQiTNh3dXUV1GDGZrMxYV9SUlKQtZ5hGOjr68sR+T09PaNqo2mz2fKK/LHousnhjDXJZBJ9fX2w2WxwuVzHVQzTqCghhJkU9PX1YePGjdiyZQui0eiwjldWVsbsLXn6HGe0SE/LSSQSEEUxIy2Hi3sOJz9czJ/AJJPJjLz73t7eIcW2LMsZkXu/31/w1CftiplP5I9mUbTVas0r8o+HewiHcywQQhAMBpFMJlFUVHTcxQmNhhqGAVVVIcsydF3H7t27sWHDBjQ1NQ3reBaLBfPnz8fixYtRWlo6RqPmnKzQyD2N3tO0MRq9511qOZwUXMyfRNCIenr0fijXHEEQcvznh9vpMj0lKFvkD9dCbzDSC4DTb263m6cEcCYU8XgcgUAALpdr3PLQ020tqSjq6Ohg9pbU/apQamtrmb0lj6ByxgJN0zLScgghGWk53EmNc7LCxfxJDCEE/f39GXn3kUhkyP1oISuN3o80Ip59/vQi3OEKicFQFCXDVYfevF4vF/mcccMwDOYkNZ4dmNMdcGhXz0QiwewtOzs7h3U8h8OB008/HYsWLeL2lpwxRdf1jLQc0zQz0nK4uOecLHAxz8kgGo1mpOYUYluZnvZSUlIyKq4dVOSnR/G7urpYMd9oIMtyjsj3+/3jKqw4Jx/hcBjhcBher3fYs16jSbYDjiAIIIQwe8vdu3cP295yxowZzN6Sf6Y4Y42u6xlpOYZhZKTl0ItVDudEg4t5zqBQ20oauS8k/12SpByBPFq5jeFwOCddp1Ann0IRRTFvJN/n83HrNM6YoGka+vr6oKpqQUXoYwl1wBEEAaqqMhGebm853N4UXq8XZ555Js444wxub8k5bhiGkZGWo+s6i9zTW6GfteArryK8Zk1B2zqXLYPnyiuOZegczrDgYp4zLNKLXKnAHypaLggCvF5vRmrOaHvLp88opN8KSRsqlHz1A8XFxeNSyMg58aBpZ4lEAj6fb9yL+2jTn3QHHLr8wIED2LhxI/bv3z9se8vZs2dj0aJFqK2t5VFSznHFNM2MtBxN06CqakZazkAzSC133oXQqlUQhvhcEk2Da8UKVD35X2PxEDicvHAxzzlmQqFQhrgPhUJD7uNwODIsMcfKqi8Wi+VtiFXIGAtFEAT4fL68KTvjLcg4k49EIoG+vj44nc4JYcNKHXBM04SiKBkXroFAAJs2bcLmzZuHfeFcWlqKxYsXY/78+dzekjMu0FkoGr3XNA2KomQU1VJx33LnXQivXQu5pGTQY+pdXXAuXcrFPOe4wsU8Z9SJx+MZefd9fX1DRu+oEw0V+EVFRWOaY5tIJPJG8gvx5h8OdEaCivvBOvFyOBTTNBEIBGCa5oRK78rngAOkcpX37NmDjRs34vDhw8M6pqqqmD9/PhYtWoTy8vLRHjKHUzC0HwON3ieTSSiKAovFgv4HHkRs3TpIxX4YugHDMKAbqXsQAo/HA4CLec74wMU8Z8zRdT0nOq7r+qD7iKKIoqIiJu6PlwBOJpN5I/nDzREeCrfbnZOu4/f7x7UAciKR1E1EkzrsqgxVPnkLJ6PRKPr7++HxeEY9Ne1Y0DQNmqZBluWcz2VnZyeztxxuwXpNTQ0WL16M2bNn89Q1zrgSDofR2dmJlpYWtLa2ouR3v4e/oQFRiwWCAAgQgCOzyQKAIr8fAriY54wPXMxzjjvpvvM0NaeQAlaPx5MRvT+ehXSapqG3tzdH5Bcy6zAcnE5nXq/8iSTkxpKdLUE8v6kZb+/pgGYSKJKA5TPLcP2Z1Zhb5Rnv4Y0Luq6jr68PsixPODtVXdehaRpEUcxxCkkkEtixYwc2bNiAjo6OYR3Xbrcze0ufzzfaw+ZwGKFQCF1dXejq6kJnZyf7Ox6PZ2x39nvvo6K1FTGrFQQACAFBKs2S1oXJksTFPGdc4GKeMyGIRCIZzawKSXex2+0Zlpjj4QKi6/qAIn+obrvDIf2xpt9OJGeQV7e34tHX9qA/rkMWBciSAN0g0E0Ct1XG91fOwhXzK8d7mONGf38/YrEYfD7fhEvTSre1zC4iJISgubkZGzZswK5du4Ztbzl9+nQsXrwYM2bM4PaWnBHT39+fV7QXOnvExPwAgRWXywmLauFinjMucDHPmZAkk8mMyH1vb++Q4pg2h6KRe7/fP265xrTbbrbI7+3tHZaYGQqbzZZX5E+EwsnhsLMliNt+uxGhhA6/XYUoHr0oM02CnmgSLouMp29ZhDmVJ2eEHkh9Lvr6+mCz2eB2u8d7ODmkO+CkFw9SIpEItm7dio0bN6Kvr29Yx/Z4PMzecrK9vznHj2AwyIR6unA/1kaE2WJeACBKEmRJSjVcs1h4ZJ4zbnAxz5kUUHGcHr3XNG3QfdKtJKnAH2/XDFrYmC3ye3p6hqwjGA4WiyWvyJ+IAhAAHnh5F57f1Ixih4qkloRpGpAkCaIoQRJFEAjoiSRxw6Jq3H/lnPEe7rhC09R0XYfP55uQueW0kDDb1jJ9/cGDB7Fhwwbs27dvWKlqoigye8spU6ZMqLQjzvGBEJJXtI9293DgaN+R09e8A/e+fRD8fsiSBFGSkO+dx8U8ZzzgYp4zKaGe3OnR+0Ks8VwuV4a4d7lcx2G0Q5NeR5BdgDvURctwUFU1b0Os8WxUlNRNLH/yHYRiOoocKkKhEGLxGARBhCxJEI5Ed+OmBLsi4KlL/PD7vPB4PPB4PLDb7eMy7vEmFoshGAzC7XZP2OeA2loahgFVVfNeeASDQWZvGQ6Hh3X8kpISLFq0CAsWLODF4ycg9HsxXbTT22h+LwIp0U5/G+ittLSUOatxa0rORIaLec4JQ3rjqK6uroIcaKxWa0bevdfrnVB5uekXLdlCf7hOIYMhy3Jedx2fzzfmIj8QTeLSp/4B3TDhtiro6+tDUtNYtFYUBUiSDA0SJAH4Sk0XbNLRry1FUeDxeOB2u+H1HhX5brd73Gdixho6YyWK4oR772YzmAMOkHos1N6yoaFhWMdWFIXZW1ZUVIzSiDnHC0IIm3nNvo3mjCVwtEN5tmj3+XyDfn540yjORIaLec4Ji6Zp6OnpYZH7np6eIfPV6Rd9+m0ipjEAR5t1Zd+yXRiOBUmS8kbyh/rhGw7pkXmfQ0EkHIGu69B1HSYxkfqGIkgKKiyCic/5D8GqpnzOU6k4A4/DarUycZ9+c7vdE/Z1HQnhcBiRSARer3fCX8Douo5kMpnKM85ywKF0dnZi48aN2LZt27AvWqurq7F48WLMmTPnhHqNTwSoaE8vQKXfz6Mt2mmAIl20l5SUjPi7K/jKqwivWVPQts5ly+C58ophn4PDGSlczHNOGkzTZEWp9AdkKKFALcfSU3Mmuk1kJBLJK/Kj0eionYP2AcgW+UVFRSMqOk7PmafFr4QQ6LqORDIJTdMQMUScInVjvr4PNpsNNpsNhBBIkgRZlll30kJ/qB0OB7xeL9xud4bQd7lcEzrCPRCapqGvrw8WiwVut3vC55KnO+BYLJa8400mk8zesr29fVjHt9lszN6yqKhotIbNKQD6XZvtHNPd3T2qBgBASrRnC3Yq2if6Z4DDGS24mOec1FCPYSp4Q6HQkPs4HA4m7EtKSuByuSbFj0Y0Gs3Jx+/q6iqo1qBQaNFxdjTf7/cPGiUdys2mO5KAXRbwjdNV1DgFxONxJJNJOBwOiKKIYDCInp4e5vsvyzIT+MO9uBBFES6XK29E3263T+jXmhYGJpNJ+Hy+jC6tExXTNNlFdT4HHCD1uFpaWpi95XCjuNTe8tRTT52UF2oTFdM00dvbmyPaC5kFHS6KoqC4uBilpaUZon2i9V7gcMYDLuY5nDTi8XiG0C2kKZSqqhmRe1owNVmIxWJ5u94WcmEzHHw+HxP29Pny+/1McA7qM2+T8f3LZ+HimcVobm5GS0sLEwuSJKGiogLV1dWwWCxMXFAr0GAwiFgshkQiwaL4I/3xl2U5J5JPbxOpADMejyMQCMDlck2aXgRDOeBQotEos7fs7e0d1jncbjezt5woxe+TAcMw2OcqXbj39PSMaj8NICXaaR57umgfzyJ9Dmeiw8U8hzMIuq7nCN2hooI0BYWK++Li4gnX5KcQEolEXpEfDAZH9Ty0s29xcTH6JTfebzPx/uEQdBNQJAErZqU6wKb7y+u6jtbWVjQ3NzMrOkEQUFpaitra2gwBq+s6EokEm5no6+tDPB5HLBZDLBZDNBodlToDmt5CxT0txnW73eMSIaepDgDg9XrHrefCcCnEAYduV19fjw0bNmDv3r3DtrecOXMmFi9ejLq6Oi4Sj2AYBqszShfthfT5GC6qqmYUoNK/J0OKGIcz0eBinsMZBukWkjT6G4vFhtyPClYq8CdLtDQfyWQyI/pNBf9wmwANhkEEWJxuVBQXobw0My8/PQJumiba29vR1NSU8Tr4/X7U1tbC48ltMEXTOuiNNjlKJpOIx+OIRqOIRqMIBoMIBAKj4lttt9vzRvNdLteYi+xIJIJQKASv1zuhZg8KIZlMQtd1KIoy6AVRMBjE5s2bsWnTpmHbWxYXF2PRokVYuHDhpHt+RgoNUmT7tPf29g7roqgQLBZLjnMMFe0cDmd04GKewzlGIpFIRjOr/v7+Ifex2+0ZlpgnwhRyvlmMnp6eURcIDocjQ9zTdJ1IJILGxsaM9CCPx4Oamhr4/f4Bn18q5BOJBMvFp0WZVqsVhBAm7tNv/f39x+zAIQgCnE5nXqHvdDpH7T2h6zp6e3uhquqkfK+l21oqijLg+A3DwMcff4yNGzfi0KFDwzqHoihYouuoae8oqMh9MjiW6LrOvpfSb2Mh2q1Wa17RztOZOMcKdxIaGi7mOZxRJplMZkTuC5miVhQFfr+fRe79fv+kSYsYCppvm52uM9pT9/QCyWKxQNM0CIIAj8cDm80Gh8OBmpoalJaWFlTPQMU9vRFCYLFY2I2mTeUT+cFgEKFQ6JgfmyiKOXaax9IoixCSasgVi6GoqGhSFMdmo+s6NE2DKIoD2lpSuru7sXHjRmzdurXgNKqz3/8A1c3NIJKUsj2VRAh5+nxONC9xTdPyivZCan6Gi9VqzclnLy0thdPpHNXzcDgU7vE/NNyEl8MZZVRVRWVlJSorKwEcbeyTHr3P7l6oaRra29uZ/R51hUlPzZno/uEDIUkS+9FPJ90qNDuaPxInjGg0isbGRvbvRCKBQCAAXdeZEC4uLsasWbMwe/ZseL3eAY+lqipUVWVRRZp3n0gkEImkfPBVVYXFYoHP50N5eXnGRYJpmgiFQujv788Q+YFAoGD3IPr85Etfoo2y8t0Gqs8QBIE10urt7YXdbp90UVPqUmQYBuLxOARBgKqqeS/QiouLcemll+LCCy/Ezp07sXHjRrS2tg55DlMUETuSbiMKAixWC6wWa8bFtd7VNXoPahhQ0Z7t0z6aKW4Um82WI9pLSkq4aOeMC4KiZHTfJQAMXUdS06AcmbEbr8/lRIBH5jmc40x6V1cq8AsReC6XK6OodrIJsUIxTROBQCBv8e1I0lo0TWPR8qNdZVOt26dNm4aKiooMK81CCvDy5d3LspwRvR+ocFPX9RyRT2+jUYhbSKMs+hybpgmfzzdpZ4FM02Q1DwPZWqZD7S137tyZ97109vsfoLK1FbE8aTaKosBmtUJRVRhdXXAuXTpmEUA6u5ct2gvpaj1c7HZ7XtE+met6OCcWLXfehfDatRD8fmhaEloylXZnHvk+t1oscDqd0Mf4czmR4ZF5Duc4Q9M/PB4Ppk+fDiAVVU63xMz3ox0KhRAKhVBfXw8gJdrS88a9Xu+kssQcCOoGVFRUhBkzZrDl1EM9PYJP/x6sSJX6U/t8PpbrbpomE0oulwter5cJXbp9ukd+cXFxhp+1KIqscRUdG03NiUaj6OvrY3n32ak5siyzx5dNIpEYUOhnz+YMRDweRzweR0dHR8667Px8VVXR19eHysrKSSneRFFkdQ3JZBKmabLmYfmoqqpCVVUVLrnkEmZv2dPTU9C5aN6+3W7DaHlTJRIJ9plPF+6j7RgFpGpN8on2kaRscTjHg3g8jkOHDiHc2gprMonoADNQhX43nsjwyDyHMwHRNI25TRSaeiJJUk5H1pOlnT2d6ci+5evwa5omE8z0ORUEAQ6HAx6PZ8B0JlmWc5phUZGf7yJK07SM6L1pmlBVFVarlYn74RSiZufnp4v+Y8nPN00TsVgMbrcbFRUV8Pl8GYJ/Mol8Qgg0TSvIAYduf+jQIWZveda69waMzFN8Ph9IT8+wIoCJRCLH7rHQYvnh4nQ68/q0T/TO1RyOYRhobm5GfX09Dh48iJaWFhBCBp0xo/i8XpDeXh6Z53A4EwdFUVBeXo7y8nIAmfnlVOBnC1XDMNDR0cEisoIgwOv1ZqTmnKg/6G63G263G9OmTctYHg6H84p8WlwaDocRCASgaRrC4TDC4TCzkcx+rnRdz3h+KZIksfqG9JvP54PT6WQ5xrquM/vL7Lx7ehtsZsVut8Nut6OioiJjOSEE4XA4bzQ/HA4PWQApiiIcDgcikQh2794Nu92ecREoy3LeItyJ1igLAMuhV1UVmqYhGo1CluVB6wimTZuGadOmob+/Hwf37Qfa2gY8vqookEQRAyV7xePxHLvHrq6uUW/ABhxNu8sW7RPtNeFwBoIQgu7ubibeGxoaRmwFrGnaSS1oT+bHzuFMGkRRhN/vh9/vx2mnnQYglXZDhX2+jq2EEFZAuW/fPgCpqXYq7qlt3GSzKRwOVEzX1dVlLKdpTVR0HTp0CAcPHkQwGGQ+8xaLBV6vF3a7fdDnyDAMdpx0RFFkXW/Tb0VFRZBlOSPvvr+/f1h59+kIggCXywWXy5Uza2oYBkKhUF6hH41GM7a1Wq2QZRnRaBSKosBqtUIQBGY5mi8dxWKxDJifP95uOTQyr+s6otEoJEkadDbE7XajtLQUIVWF4nIhFo/nTN9ToUwIQSQaxaZNmzJE+3A97gvB7Xbn2D1m91vgcCYLkUiEiff6+vpRm53iYp7D4UxKqICj0eh4PJ6Rd5/Pli4SiSASiaChoQFAyrUl3TGnqKjohMi7Hwq73Y7a2lrU1tZmLG9ra8OOHTtw+PBhJnqpP/twfd9N02QieO/evWy5IAjw+XwsZYc+92VlZSCEFJR3XyiSJMHr9eZ17qGFwf39/QgEAhmOOzSqb7fbBy2OTSQS6OzsRGdnZ866gRplud3u4/oey+eAY7FYBnwtBQAKje4fuRCIx+MghCAWjyMcicCSTKKtvh7rX3111Mbp8Xhy7B4ns4sVhwOkvmcaGxuZeKeObaOBLMtQj1y0y4oCYxQMBCYrXMxzOCcIVqsV1dXVLDqbr4lTtoNHMplEa2srs+yjMwDpkeThCsjJTEVFBSoqKhAOh9HY2Iiurq6M7rBUBNKUp5FElQgh6O3tRW9vL/bv35+xzuv1Zjz3Ho8HoigimUwy7/rs1JyRzqykF/pmE4/H0d7ejpaWFhiGgWQyOexGWXSGoy0rbYXOJOQT+g6HY8xmiiRJgs1mg2mazDWIpjbRxm/9vb0QdR2R/n4YhsFqESRJBDHJqBTa0dS3bNF+Mn3OOCcuhBC0t7ez6HtjY+MxN9ejSKIIRVWgKCoURYF4As8qDxcu5jmcExRZllFWVoaysjIAqS/ZQCCQkZoTi8Uy9jFNMydlhPqz0wjyZCqIHClOpxOzZ89GLBZDc3Mz2traMtxoZsyYgaqqKgDI664zUgvBQCCAQCCAAwcOZCx3u90sD9/tdsPlcsHhcECSJCiKkiHuR8Nm0mq1oq6uDjU1NWyGwOfzQRAERKNRBAKBHNedQhtlUWvW/v5+NDU1ZayTJCknL5/eRqPeIxKJsJSYzs5OtLW1oaurC4lEAqIo4uzWVlRpGuT+/kF/HMUCHqfP58txjikpKRn39CMOZ7QJBoNMvB86dKjgXhpDYbVaMXXqVFQ2HIbS1gaLrgO6DiAGE0D2p5CcxK423M2GwzmJoRFJmppTSKSZdlqlAt/j8ZzQefdAagajpaUFLS0tLMokiiLKy8tRU1OTIzSpG1F2M6zR7shJHXg8Hg9Lu6IOO8PJux8KWhzs9XoHzNWmjbKyHXeG0yhrMFRVzVuEm69RVjgcznGO6erqyrl4BVIXFzQKP625BdWDFMCm01pVica6OhQVFeUI9uLiYi7aOScsiUQCDQ0NLHWmu7t7VI4riiJqamowbdo0TJ8+HZWVlRBFEcFXXkV4zZqCjuFctgyeK68YlfFMJriY53A4DNqshgr83t7eIaOtiqLA7/czEeP3+ydtE6KhMAwDra2taG5uZm5CgiCgpKQEtbW1Q3bH1HUdvb29Oe46vb29oyLyqTC1WCwsgu/3+1FeXs5sJ0eazqFpGvr6+tixh3MBN1CjrJ6+IEKxBFSRQBrB9WA8HkcoFEIymWTpQDQv3mq1QlXVYb0XdV2HaZqQJCljPzozke0cczLZv3JOXkzTREtLCxPvzc3Nx2SHm05JSQkT71OmTOE1IiOEi3kOhzMghmGgr68vI3o/VN6wIAjMqpEKnhPtC9o0TXR0dKCpqSnDFaaoqAi1tbV5C04Hgz7P2SK/p6fnmH80DcOAruswDAOyLMPr9aKsrIwJ/MrKyoK7CdMUmUQiAZ/PB0VRRhQ129kSxPObmvH2ng4kDROyAJxVbcOSCglFQoQJf/pei8ViCIVCCIfDrHlaKBQqKBeX1hhQf/90n//sQlz63i0qKoLP50NZWRkqKyszRDuPEnJOdGhdDxXvhw4dytuzYyQ4HA4m3qdNmwa32z0qxz3Z4WKew+EUDBVz6dH7QlIoqCc2FfhDRbAnC4QQ9PT0oLGxMSNFyeVyoba2FsXFxceUgkT7C+RL2RlpUZlpmkzc03u73Y6ysjKUlpaioqICZWVlKC4uHvCHNpFIoK+vD06nE8H77kdo1SoIQzVo0jS4VqzAllvvwqOv7UF/XIcsCpAlAbpBoJsETouELy8qwhxXyiWnubkZzc3NCIVCSCQSiMfjzM7zWGYyBEGA0+lkj7empoa5GxUVFbG6EBrtT3fAabnzrmE93pOxgQ1n8hGNRlFfX89y30erC7Esy5gyZQoT72VlZSd8WuZ4wMU8h8M5JqhnOxX4hRR/Wq3WjLz7gbqoFkpSNxFN6rCrMlR5ZMc51mMEAgE0Njait7eXLbPb7aipqUFZWdmo2jHSYuZ8kfzhOq7Q1Bxd19lNFEXIsgybzYby8nKUlZVluBx5PB42hsC99yG5fj2UkpJBz6N3dUE/awm+Uns1+uMa3BYBxDBT5zYM6LqBGJGgwsByywH4xdz89vQx05SadJEfj8czms5Q0e50OllNAS0eHuz1SG+UResRrFYrvF4vQg8+iMjadyEX8HhP1m6UnImPrutobGxk4r29vX3U6nkqKiqYeK+treWpaMcB/gxzOJxjItuznRZ/0sh9T08PDMPI2Ccej7OoK5ByMclurlTID0B6uoZmEiiSgOUzy3D9mdWYW+UpaPyjcQwAzM89HA6jqakJnZ2diEaj2Lt3Lw4dOoSamhpUVFSMyg8bzeH2+XyYMWMGW54+c5J9G6izoiAIzIudHoNG76PRKPbt24d9+/axbWhjK/o6VQWDsOg6BF0f8rE19UbRX6pDNZMIh5IgJOXrTv9nhY44ZBzQi+FXmwY8Trr/PpAqnKMXhkVFRbDb7SyNJr0oN7tR1kAM1CjLNE3Mqq9HkaYhGY1ClCRIoghRFCFKEni8kTNRIYSgs7OTpc4cPnx4VKxWgZTjGRXv06ZNg91uH5XjcgqHR+Y5HM6YQlNF0lNzhsq/FASB+XFT0ZjtGPPq9tYB0zXcVhnfXzkLV8yvHPAcSd3E/21pxn+v2o/QCI8xGPF4HE1NTWhra2N577Iso7KyEtXV1cfdVzwUCuUV+fECGq2YpglN01gE3zRNJuyXrP8QFS0tiFmtgCBAkWVWQEq7rgKA1tWFD0tn4odnfx4WaIjGYhmpQik9LyApqFCg4cLYOsiiAFEUIQgCmy0oLi5GaWkpSktLWe5/oe4xyWQybyFuMBjMSd0hhLB/p9/PeeNN+BsakLDbQXD0YkSAAJvdDsuRx8sj85zxJhQKMfFeX18/ah2KLRYL6urqmID3+/08dWac4ZF5DoczptBGVH6/H6eddhqA1I9Mut99KBTK2IcQgr6+PvT19WHfvn0AUt7vVNh3GzY8+toehBI6ih0qRPHoD4lpEvREk3j0tT2YWuzAnMrM6DqNxL++ow1d4QQIAZwWCQ6LAsuR9JqhjlEIVqsVM2bMQF1dHZqbm5mtZWNjI5qbmwe0tRwraIrJ1KlTM5aHw+G8Ij/dxlEURSbK0yP39EaFL+1gK4oiIAgQj1yUsf0IgSQSCEQAISQlAAgBAUBAQEwCCDoMAKFYEn6XjVlRer1eOJ1OSJIEURQRi8XQ2NiIxsZGdnxBENj69BtdTi8K6L3NZoPT6UR1dTU0TUMkEmE2nOm39CLkaUe6wdLZJir/ZUnKeB9yhgcvLD52kskkGhoaWOpMer+QY0EURVRVVTHxXlVVdcI6lk1WuJjncDjHHSosp02bBiAVxU7Pu8/nx06FVUNDA15uVtETVuCxiEgmE4AgpNIqBAGiKMBvV9EdSeL5Tc0ZQjw9mp/UTZhH0jzCSQNRLQ6/Q4XTIg96jHxkR3Gzb5WVlSgtLWW2lpFIBAcPHsTBgwdRVFSEqqoquFyuvPsOdeyh1hW6v81mQ3V1NaqqqkAIQTQaZRdUtJlVIBDI8Wo3jkTXTdNMiXJCoBs6JFGCIMtIHpnKF0wTIAQJzYAkpS4IBAEQJRmSJEISJYiShKguwGmR8M+f/gpEpGYENE2Drut576mozo4MUsEuiiIT8nT5QH/TOgF60SAfmWWgKUexWAzu3bshtLdDkqXUxUdq59TFi8gFzkgJr1lTcGExAC7mkfrMtba2MvHe3Nyck9I4Uvx+PxPvdXV1A/aX4EwMuJjncDjjjtVqRXV1NUvhoznL6ZFimpKhm8CeoAwRgGkYiCbiSCSSEJBKY5EVBaqqQISAN3e04vPzXZBFAXs7o3jo9cOIJA14LCI6kiYEAJIIEAIYpomuUBzxqA5ZBARBhGGIeGlDPaaHdoDWxAqCkLdQLF0U0vt0sZi+bzAYRE9PT4ZXPXVXcTqdefcb6rj5BCoAJmjp39nRabou/d+iKMLpdMLtdqOuro4dUxRFJJNJBINBJvLdDYch9PRAzIrUCUjZQipH8uhNUUSZ2wICEapFhl9Vc/LMTUIQi2q4eG4FFsw7Le9jzr4BKVFDGz8ZhsHSgZLJJLsYyL54AXIvwui+9EaxWCzM4UZ1uyFJUmpGhaTGbJomiGlC05LQkRL2gmkiFA7j0KFDLOVIzkpByl5+sqcqCIpSUGHxyUxvb29Gt9VC0uQKwW63s5z36dOnw+MZ/mwkZ/zgYp7D4Uw4ZFlGWVkZysrKABx1b+nq6kJDayfMPQFIQkqEGYYJw0ileui6DiGZQDQiQBdVaHEBf37hr5g5rRavdzgRiuvwWAWYxGTpESQVToZACAgERDUClSQhCgJMUUHCBA42tcJhs8FlU+CwWfOKTJrmwdJHcDTXOlts0+LVUCiEnp4eRKNRRCIRNDQ0wGazMYefwYR7IeI+exyjFbUDwFxi3F4PJEmC3W6HaZowj4hq0zSZ6wwACKYBr2zAKhH0hJNwWwSINKqOlCjuTxLYFQGnexL4+OOPB3yMgz3+wbalY6KpQoZhZPwbSL33aLOp7NcWAAxZAREEQBABEEhHLnJASErcgwCmCcE0ocVi6G5pKfjCiwp7Ku7T/x7oAiD9NpqOSZyJQSwWw6FDh1jue19f36gcV5Zl1NbWMvFeXl5+0l9MTma4mOdwOBOedAFcN+0UuD56B/1RDTaLCEJS6Rt6mjADACKIIIaG/Xt24uDe3XjfuwJEVGEaEhRZhpCSXakfMEGABMAgBAQq7Kqc8hjXRWgA/i90CkhYgCwKmFcEnF9rwaxyJ2w2G+x2O+x2OxN9+VJahvp3KBRCS0sLAoEAuygJhULM773QYw/3vPn+fSyvUXbnVOBoTrlAAJtM8OnTFDz7cRL9CROiAEgCYJDUzSELuPFUBX4phmBwYGvKsYQK/PRof/qtIhiEwzBgJJOpCDzAqmCZFDqy3CQm4vE425dG/NP/nf23kXUxlDpc5gVA9n36+vSLAJpiRIU/LSLOvhBI334kF0rH+m8ArADZMOkFZ+bs0skiMw3DQFNTE4u+t7a2jpplZHl5ORPvtbW1BRWNcyYHXMxzOJxJhSqLWD6zDM9vaoYkyfC4U9PBumEgmUggkUym0isgoCjaAlPXoMtW6CYBMZMIh42UgCAyiKhANwgEHBEZKVsSWKxWxDQTmq5DBJAkBAIxkdAJ3m8XsKU7jqunRHFW+VHxSoW9w+Fgt0JtKGlL80gkgubmZnR3d4MQgnA4jEQigaqqKpSXlx8Xv+bhXgREnC7opgkxrWlWXkwTdrsd1y2eigXT4nhzbxAfNIahGQSyBJxd48BF012Y7rcUlBIzkrEWcvEjiuKg51ZUBaJpQgiFjlyokKNXLOkYBuLxBHp6enIKcSVJyhFS+aKi6ak/6aI//WIjfT29p45D2d2DCxGF6eMEkFNQnJ5ylV6PQGcF6NjoffZsSPosSPqy2Y2HUaxpSAT76bPKxuN2uyds1PhYC3cJIejq6mLi/fDhwwPayA4Xt9vNxPvUqVNPmGZ9nFy4mOdwOJOO68+sxlu72tETTcJvT7nZyJIE2W6H1WpDdyQJvwJ8eu4U9NX3o7O7ByIINEGGQFJiSDKTMEQZQFrUWxAgEKAvGEIUKgABRXYZNuVoPrNmGOiLanilSUS5w0C1w4QkSTAMgzUtolPhVqsVDocDTqeT3Q9mSenz+VBdXc18+Nva2mAYBnp6ehAMBsfN1nIwgpdfjnCBjjzOZcvgmT4d06cDl5w1/EZd6SJwtO6H2ib7nLHLVyLp8abpd5LvDiAEFWefjcpPfIJZe1L3H/pvmr6TPpuRXeOQPo70C490IZw+vvTHRYV9uq0orSVIH5NhGNA0jS0zTZNFyunjZ+49WWlJ6X/TMdPHlC7+0/+dPmtAl0uSnCoiPvJcUOkuTPDUoZEU7obDYSbe6+vrc9y8RoqqqhmWkcfagZozeeBinsPhTDrmVnnw/ZWz8Ohre9AdSeZ6xNtkfP/ylEc8Ideira0Nyb/uwLtNSYjEAEAgioBANCQFFeRI/jMAEGIiSmRAAGQjiki/icSRSKqiKhBFCV6bjGDcwD6tCOfVuZkYSiQSTGDRKGk8Hs9oPqQoSoa4dzqdzCkiqZuIJDTYVQXTpk1DTU0Ns7VMJBI4dOgQGhoaUFpaiurqatjt9lEVtsO9J4QAZy2GtHhRQfv0E4L+pqaM5ZphIpo0YFNEyKIw6P6jxVApK+nLctYvPR+WpecPeoxC7unjou8dKraz76lQThfC9N/pwj9bZGen6RRynx59p/fp0fT0+/Tz5YvEU/GfXXw9kLhUP/gg1YDLYmE9B+i9oeswWOpSyv60r68vb+3AeIjXoQp3CQi0jk60tbfh5Z/8BB0dHaNzXkFAVVUVi75XV1dzy8iTFC7mORzOpOSK+ZWYWuzA85uasWpPBzSDwKYKWDEr1b2V2kkKgoDKykp8+2oHdv12A/pjMpwyATmSlpAwDCRNAUlIACGwEA06JBBTh2gkoENAMpmAETIgiCIkSYQiKzBkG9Ye6MNFJWHAPCq8qBCyWCwZ0U3DMJBMJnOEKSEEbXEJWwNW7AyIMAmgSCLOmeLCJad6MKPYirKyMnR2dqK9vR3JZBJtbW3Ytm0bioqKUFFRkTF9rhkEMc2EXZWgSIPnWo/Wfb5c6KHu97SH8fKOTqw90AvdMKFIIpbO8ONTC8owu8I9pmMejHwpIfkE7UDLCt1+sAsUKtqpcE+vmUg/Hk25Sd9HURQoisKErqIoGXny6UI9+360SU8TSr/Ri5X0W0JRYR6pXwF9rHlSmATTRCQaQV99fd5zpkf8h1M8PJpFxAQpy9YktVXVNFg1Dd3dPccs5IuKiph4r6urO259KjgTGy7mORzOpGVOpQdzKj2457JZQ6ZrzK3y4F9XzmY+85IgQpRlQDAhmQQ+EbigJIYK0oPftJXBIAIkQWEOLQQExDBgmgZ0TQdRAC1u4skf/QnTqlOFZeXl5fD5fJBlGYlEArquM3HlcDhgsViYUwoVNe8eCuPZ+gSiOiAJBJIIJHQDb+wL4B+H+vGVxX5cNqcMc+bMweLFi9HT04PGxkZEo1EAKas6QgiiFj/+fiCEt/d0QDMJFEnA8pmpC5u5VRPLZi7l9/9xRvfeRMLAS9s7sGZfT97Ou/lEdnr+eCFie7BIciEiO5tsi89sgUzTSYbabqh1hTBYHj29kKRjzkxvkcYssp19UTIYzU4nwqIIhaaQ0VkZkNSfR8S9KQiwqBbYPJ68KUTpt3S3ooFuAHJe+/SUoPTZj3wXQ56uTqi6jkQkktNITTjy/B4LNpsNU6dOZakzPp/vmI7HOTHhYp7D4Ux6VFmEKg+dR54vmu8QgHOmOHFhnR0zy52IJTS88Jd6hBIGbKoE0zShWtSUBaauQz8i6A1RhKkl0NrYgK6Ww9i6dSucTifLe6dFqy6XC6ZpMs9zQgjrqNoak/CX/ToSpgi/Q0rZYRITpklgmCZCSRNPf9QNF4lhiiclCux2O4qKiuByudDb2wtN0/D3j3vw7N42xHRAkUWosoR4kuD5Tc14a1d7XnE8mgwner2nPYyHX9mLcEKHzy5DhJAKvsqAaQKBWBIPvbQDSrwP03zqMYvsfEI5n8gertimj3swoTjYeir4BsuFH8lxB9sm24c/vWA2218//bHnm3kZ7TFPP3QIvngcWmvr4K+tYaCnsxMH33tvWO/NfMXD2U5C6TMJ6WNOh/5b13UEAgEs238AM5JJBHp7M7YTkEqpwzDFvCRJqKmpYeK9oqKCW45yhoSLeQ6Hc1IxUDQ/kUggEAhA1/uxpM6Nt/YHYXNYQajY0fUjQtuEYRDEiIRSrQXu2mr09fUhGo0ikUigt7cXDQ0NsFgscDqdKCoqQnV1NYqLi+HxeOB2u1kaxN+aFIQSElyKAF03ARwVjbIkQVGAQEzHhm4Jp5VZEY/HEY1GWVReFEUc7E3iD7t1xA3AIQMyTIiEwG5RIIgieiIa/v21Pah2q5hZ7hxR+shQ6wpFFEX8ZWM3QnEdXpsEUzcQSyYz7B2tAPpjAv74fj2un56ZG54e5aSCMh0qNkcqdgtZn+0QM96kPweDvRYDidKBtssWudkpRPSiNLtgN/sCoNAxddTWQs/ogzDwft21NYhEIgOuz0f6jEP6mLPX5YNe8PT29qKrqwvd3d2syP1c1vgtd8iFfjJKS0uZeJ8yZcqEKnDnTA64mOdwOCcl2dF8i8XCfN1vNKx4ryGE3nASPrsMq9UCQbCxYsW+mA6bYODyWUUoX3g9otEoOjs70dTUhLa2NoRCISSTSQQCAQSDQTQ1NcFms8Fms8HtdqOoqAj+kjJsT85LNasiBIYB6DRyT1MeABBDwLpDQVw7XYDL4WCRw3g8jnA4jHebdMQNCQ7JgJZMIqbpR+0DJQmKrCAQkfGLt3fin5eU5X0uBotK05uiKAVFsQdbltRNbHqpBaoswaqqiCfiiMdz/eQNqPioVYNt31oIxMwY50Djy2ebONjfhUTf0/8uVAwPtm44Fz6FMpAQHUygZq8baNvs9Jjs7fJZT9IZh3yCf6DXDABCpy9E6PSFBY0fAByDrj16jHyzC+m3fOup0A+FQujo6EB7ezs6OjpSTekEAaqqsiZLtkgEQiIBWVZyvPAHcuJxOp1MvE+bNg0ul6uAR8PhDAwX8xwOh5OGJEk4Z2YN7rtawr+/uht9MQ0idMiSAJMABhFQ5LLhuxfNwDxPEg0NDWhtbYXH48GsWbNgGAYCgQBaW1vR0tKCtrY2RCIRxGIxxGIxBINBdHZ2QmluR/SU0yAIImKaAVESkUgkAaTsCqn3vSGpiBoC3lrzLoocFrhcLng8HjgcDtidLuyLRKBIBBZL6uvcNAmMI4IKuo5kIgFdtuLtjyMoanoXFWUlKC4uRmlpKcrLy1FcXJwzjZ8upPJFWLNFaaGdZYMxDQnNgCSmmgOZhgli5gpcQTBhQkBMM6EQbcDjFSpKh9qvkO3SxWe+nOp8t/RtBxO19N/ZxZn5It0Dic/BBOtIRe2x7Jf+vkmf0ci+AKDPc3YzK1q8S/8e7nhHQiQSQX19PbON7D/SO8Fms6Guri5n+9qGw5gSi0MhBG5dH/TYIoCKinJ87WtfQ0lJyYjHyOHkg4t5DofDyUNGfv3uDiQ0HZIAnFPrxLVnVOKsGeVQVRWnnHIKDMNAa2srmpub0dvbC0mS4Pf7MW/ePJimie7ubibuu7q6EIvFkNC6QXQNhmQBTB2ieTQCTAgBgQkBAkwigBg6ejtakbBZEQgE0N7eDofDgQSREIvXAYII00gJJkkUIaoqQFJ596ZpQCAmdAJs3bkbDfUqHA4Hi7RLkgSPx5OaLfD7UVxcDL/fD6/XC4vlaAOn9Hzp9PzpdPGUvjz7bwDMYSehEQACRCmt4BEAbcBkEBmqYGBabTXrCwAcTZ8ZKJKaT+RljyV928HGXeiyY12fviwf+YpUBypeHWx5IevT1x1PsZmeq55dxEtnwwAM+Biyo/yFomkaGhsbmXhvb28f1v6Vra2wRaMQAMj5xLxwxF6TEECS4POnLqI5nNGGi3kOh8MZgOz8eqssIh4NIxAIoKGhAQ6HA16vFw6HAzU1NazhUygUQiAQQE9PyopOFEWUlpZi4cKFSCaT6O7uRktLC9Yn29GhTIWpmyCplPlURF4UIIoSAAG6IMAZPIS2zmbY7XZ2i0QisDlcEIiJpAkkTA26cSQiT44c60jE04AMyUzC47BBFAii0ShL+VEUBclkEu3t7RliRlVVuFwuFBcXo6SkBGVlZSgqKoLT6YTb7YYsy0xkp98P9feyU4vx123tIIRAVVSoXjUjt9gkJvSIjivnleIbyy/Ie6yhijjTlw90G6gYkl4wpBeopueOA/mtLrNF+UDbZC8bCjomKmiPF+l1CsO5MEgvMB7uxYSiKIMWe2YLfk3TEI/HM14fevx85xEEAZ2dnUy8NzY2Qh8ioj4kgsCcd0Bf33zbjUGKFYdD4WKew+FwhiA9v96qeuH1ehGJRBAIBNDS0gJVVeHxeODxeFhuvMfjQVlZGebNm4dEIoHOzk60tbWhs7MTlZWVqKysREVEwO8bJMQlD2QzDpOKEpNANw0Yih2SnkBZ9BCKioqg6zpCoRD6+vogCAJcLhdKiorQKFeCAEei7SLTFykvfROAiBoxAKfdiv7+fsRiMXYcu93O0nZogSm1Muzp6UFPTw/27t0LIJVDbbVaYbFY4PV6UVJSAr/fD7fbzW4ul2tQO76bPjEVq/f2oC+ms+69FNMkCMQMeOwKbj53OoqKjt1ScyQXHAMtoyIy3fpwIP/0wdxi0htCZf+dfmGRXngKjO5FxGjtN9rQ2aKRzjJkjzUUCqG5uRmNjY1obm5GLBZj2+XbN32mZzCsVis8bjcEWYZ4ZJZrIMx4HHJpKZzLlo3Kc8ThZCOQAitympubUVNTg6amJlRXV4/1uDgcDmdSoGkaK3QFALfbDa/XyxwpNE1DLJYq8rTZbBAEAR0dHWhtbUVbWxveb07g9TYr4oYAEQQCMWEQwCAEkpFEdfcGOHpTYpqmxlit1lTE3TDQByf2ll4AQ1RhQQKSSMWJCAgSoroAq0zwpRk6SpQEc+1pbW1FMBhELBZjgtVms0FV1bxR2YHEiiRJsFgssFgsTOhT68x0ke92u+F0OiGK4hGf+T0ZPvP5uveejOS7oKCCPv0iIvvv9AuDfI2Z6Db5LizybU+3HWhc2Rzvi4iBttE0Dd3d3ejo6EBnZydCoVDOezf9Ii09lz87DSo9hUeSJFRUVKC6uhpTpkxBZWUl7D/5KcQtW4CioqyutUe86oXURYHe1QXn0qWoevK/8j53HM6xwsU8h8PhjAKmabJodzKZhN1uh9frZd1ZDcNALBaDruuw2WywWCwAUo4Z63Y24MWtrdjUloROUs2jTnMmMdMSgFMLIBAIoK+vD+FwmDXKoWKDEIKAexrqXQugiwpATAjEBBFSYt4iGLjA24fzpzrh9/tZ1FzTNASDQfT09KCvr48dlxACm80GRVEQj8ehaRoSiURGB9v0bqLpgociiiIT+FTk04sEKvJ7TBvebzPxUXMMBgFUWcKK2ZndeznjS3aaUSEXAdkXGdmzE4VcTAx1EZEtxvv6+tDd3Y3Ozk709fVlrM/eF8h/4ZCNIAhwOBwszczr9bLPG73NfevvKDl8GAmHA0jrViuKIuwOB5QjbkBczHPGGi7mORwOZ5SJRqMIBAIIh8NQFAVerxdut5uJgVgshmQyyYQuFRSxhIbDre0Idneiq6ONuWkQQlgufiQSQTgcRjgcRjKZZCIorHjRZp2CHkslTIgQYaIo3oLi/gNQIh0QhJSlXmlpKWprazF16lSUl6eKeCORCNra2tDR0cFyyWmE3el0IpFIIBKJsBzlWCyGeDzOGmElEgkkk0n2ONJdSdKdX1RVzYjgWywWmIKIpCnAJovweXKj+W63m7esPwnJ182Wvte7u7tx6NAh1NfXo7GxEfF4fMDeCNl9A7JrI9LvVVVFcXExysrKUFpaCpvNxi4E8l2ozHnzLZQcPoy43Q4m5gVAkmQ4HQ5m7cnFPGes4WKew+Fwxgga/Q4GgzBNk6Xg0Kh8PB5HPB6Hoiiw2Ww5ubpUZLe1taG9vZ0VQWqahnA4zIR9IpFKnzFNE5phArIF0BPQkwnWfZYKbU3ToGkaTNOEoigoLS1FXV0dpk+fDr/fj3A4jK6uLiagAKCoqAjFxcVQVTXltb1uHZSt21JCS9dhmCZACGuqpWs6zCOFqof8Rdh+JKpJRX56kaTVamU3VVVhtVpzcu4VRWH5+B6PJ+OePpecE5doNIpDhw7h4MGDqK+vRyAQGJXjyrKMKVOmMM93v9/P0uLoZ5N+tuhnhubUi6IIy//8byrNxuc72jXqiKZXFAWSyNNsOMcHXgDL4XA4Y4SiKMzqsb+/H4FAAIcPH4bNZmMpOFarFclkkuX22u12FtFzOBw45ZRTcMopp2RYXLa1tUFRFPh8PpimiWg0ysQ9jZIT1QbYbazhE/W613UdyWQSkiRlHPPdd9/NiNx7PB4QQqAoCvr7+xGPx1FaWoqKigpIB+th7twJIssAIansApoOkfa3YBhwu11wnXsuEokEE0h0rJFIBH19fTnFjKqass+02+0siq9pGnp6enKeY4vFkjeaTx13OJMPXdfR1NTExHtbW9uoNdyi7++amhqUlZWBEIJkMolwOIxQKAQAzOPe6XTC5/NBVVUoipLjttPi8yEsSZD5BSVnnOHfdBwOhzPGCILA3G5isRgCgQDa2togyzK8Xi9bp+s6YrEYTNNkxagUam9JLS5jsRiL2re1tbHc/EQiwcRyNBplue7UUtJisUAURfT19aG/vx9WqzXDrrG3txdtbW0wDAOqqsJms8HpdKKkpATt7e2or6/H/GAQLkmC6PdDFISUF9+RiCQxTZiEgJgm0NcHp8OJ8vJy1jQrHo+jvLycReY1TYOu64hGowgGg+jv72fiStM0JvLp7AUV+VToJxIJdHV1oaurK+d5p/ab2behHHc4xxdCCDo7O5l4P3z48IitONNTaQzDgMPhQHV1NaqqqlBeXs5StmRZhq7rUFWVWbTmE+wczmSAi3kOh8M5jlDrypKSEgSDQfT19aGnpwculwterxculwumaSIWiyEajbLIdHahns1mY+3gCSHo6elBW1sbWltb0dPTA7/fD8MwEI1GWa69rusIh8MghMDpdKK0tBQWiwXJZJIV2TocDpafnEwmkUwm0dbWhoMHD0LTNNjtdhQHgpiu60iGw7CoKqS0Dp2iKEISUl75uiTB6/OhfP58RKNRJuhp6gJ9jKZppqz+PB5YrVYAR11GQqEQ+vv7EY1GEY/H0dfXh7a2NmiaBkmSYLVaYbfbmdCnxbuSJLHzdXR05LwO9OImO3XH6XQe14ZJJyuhUIiJ9/r6eoTD4YL3zVeYS+9lWUZlZSVqa2sxbdo0lJaWssh6eoR9tF5jomnQ81xIZm/D4YwlXMxzOBzOOCDLMvx+P4qKiliTqcbGRlitVvh8PhZpj8fjCAaDA+bVA6nIf3FxMYqLi5mvfXt7O0vJcblcAIBYLMby7GkRrq7rLFWltrYWoigiFouhu7sb/f39rAAxmUwikUggGAwiHAlDN3SEQyFEjuS/q6oKVVUgSTJEavFHCAgxWfMpRVFACMkQ8vSeinX6b3o+URRRXl7OIqZUhGXvEw6H0dvbi3g8ztx0bDYbE/qqqmYU5tIi4nzPZT5bTbfbDbvdzoX+CEkmk2hoaGANm/LNpKQzkJMO/ZsiyzKqqqpwyimnYPr06ZgyZQqsVuuoCvaBGI5vPPeY54wlvACWw+FwJgjxeByBQAChUAiiKLIUHFmWkUwmEYvFUrZ3dnvBaSKEEOYr39bWhq6uLtbhNBKJsKh9eodTu92OoqIieL1eGIbB8v0jkQiSySQ+uWEDajo60X8k757+iAiCAEWR4XA4IIoSpGAQxukLgTvvZBcjtGjVYrFAUZSMccbj8QyBT0U6zfenUX3qBJQu8OUjswN0VoGmG9FcaFogTC88qD1ousBPL9DNhyRJAwp97riTiWmabEanvr4eTU1NGSIcyBTs+aLsAI7O9qTVVRQXF+PUU0/FjBkzMHXqVDabw+GcrHAxz+FwOBMMwzBYIyrDMOB0OuH1emGz2Vh+OSEEdrs9QxAXgqZpaG9vZyk59FjRaDTDHQdIiS2aU1xcXAyn0wlN01D2+z/AV1+PmM0KwzgiwAwDBICqKJAVBSAEllgMfVOnouOzN8PlcsHpdDJ7SkVRYLfbWR5/trin0JSjdIGfnXpD/fBpwS4V49R1JDVrkPK5p8XGtKiXzgAYhgFFUVguPxX2sixn/D2Q0KeOO/lu6bUPJzJ9fX1MvB86dIg1JMsn1NPdkvIJ9vT+BQ6Hg6WUTZ8+HR4P70OQj+ArryK8Zk1B2zqXLYPnyivGeESc4wVPs+FwOJwJhiRJLAUnHA4jEAigqakJFosFPp8PLpcrQ4BTa8dCUBQFNTU1qKmpAQAEg0EWte/s7GRWllTYR6NR9PT0oKenh6WglB1Jo7FYLCAERwoODRiGCVVRIEoSiGlCiMVAQFiRLgCoqspy1KnQtVgsUFWVNdqi4p6KZ4fDAYfDkfE4aD0AFfrhcJjZgFJLQRqhj0ajEEWReY5LkgSfzweHw8HOTQhBKBRCNBpl3vnUgSe9GJPOAqQLfEmSWNOtk8lxJxaL4dChQzhw4AD279+Pnp6evFF2IFOw09mRdNGejizLLOd9+vTpKC8v5+lNBRBeswahVasgDHGBT3P4uZg/cZi83yIcDodzgkPFs8vlQiKRQCAQQEdHB7q6uuDxeFhXSpqeQ9NHhiN8qJPOrFmzoOs6Ojs7mbgPhUIwTZPZSIbDYfT39yMSicBrmkgkkpBEEeIRtxlFOdLOHoB4RKS53W7MmDEDkUgE/f390HUd/f396O/vR2NjIwAwYe90Oln6C7Wn9Hq9rAg4XfjSlBdaD0ChPuE0gk9z6altJ238Q1N60ruFynIqRYimAlHBaZomNE1jBcE0Tz8YDGZ4ntOOuNkRfXq+yey4Q2dIqHg/ePAgWlpacgR7eh+BwQR7NuXl5Uy819bWDnvGiZNCUBTIJSWDbjNUwS5n8sHFPIfD4UwCLBYLysrKUFxczERkb28v88L2er2sQFWWZdhstmGLQeoEUllZCSDlOELTcTo7O1mk27p1K4CUDaVumoCuA2mRV0kUQQCIAEyTME/88vJymKaJeDyO/v5+5qyjaRr6+vrQ2dmJZDLJBLXL5YLdbmfCmj7WfOKeQt1K3G43ysrK2HJac0DPS5/DWCyW0d2zv78fvb29rIZAFEWW62+32+F0OlFUVMRmDBRFYZ1xabFyX18fS5NKJz2ST+81TUMkEhnUcSf7NlaOO4ZhsKZidHaC2pE2NDSgpaUFuq5nCHZqdVqIYE/H5XKxZk3Tpk1jBd+c0YcQAghHL7Q5Jx5czHM4HM4kQpIkFBUVwefzIRKJZKTgeL1euN1uGIaBSCQCAMyqcSRQQX3qqafCMAx0dXWhra0N+ptvQSIE9mQyp2lUBkeEummaEEURiUSCWVj6fD4mthOJBCvEtVqtkGWZXZhQe04qYqn9JI3cFxUVZRSzDgTNm8/Ot47H4yyHnkbyg8EgE/e0ULijowO6rrMovsVigdPphMvlynDLqa6uxsyZM+FwOFjhcroFKf07FAplNELKztOXJAnJZBL9/f05IvlYHHeoYE8mkxnCPZlMslmYlpYWtLa2orm5GbFYjEXXnU7niD3YVVVFXV0di74XFxfz1JlRhM4YmaaJWDwO0zBgmGaqloUQOJxOqHy244SFi3kOh8OZhAiCAKfTCafTyVJwOjs7M1JwqM1kul/9SJEkCeXl5SgvL0fw5psRWPV3xGJxxGOxlHg4kjOfEsAaDCOVetFit6Hp/ffh9XpZN1yXywVd1zOa+9AiW0EQkEgk0N/fzwQkbfDT2dmJcDjMIu+0gRQtrHU6nfD7/SyaX8jMBK03KElLTTBNk7ng0Dz6YDDIOujSIuRwOIyenh4m8ul400U+dfEpKSlBXV0duyAxTRO9vb3o6elBb28v+vr60NfXh2AwyNyF0l/r7ILceDyO3t7enAsYSZIyuufabDb2GGnKEN2OXuR1dHSgqakJTU1N6O7uZoKddiQeCYIgoKqqion36urqCZU2NBmh/SBoT4j0v2OxGBYcPowSXUciFmMX2IQQCKII0zQAcDF/osLFPIfD4Uxy0lNwqI0kbQDl8/lgt9tZXr3FYoHVaj2mqKjnyisyiudM00R3dzdLyenr64OmacxWEpEIent70dvby8ZbWlqK4uJilJSUwG63I5lMMg96TdOYJSeNUIfDYaiqCr/fD0mS2AVMc3MzADARTXPQLRYLXC4X/H4/K3QtVEzSXH+3253xGOk4QqEQu0iinv1U4Pf396OrqwvNzc0wDIMJ4vTUIYvFwpa53W5UVFSwsRNC2IwLFfvpwo022cpn4UhfUyr+qTsPjaxbrVZ4vV5W7Nvd3Y2enh7WZIs+9pFSVFTExHtdXR236xwBmqblFezUGnYgdF2HcSQKb5gG0vxiIQDs4ppzYsLFPIfD4ZwgUJeW9BSc5uZmqKrKustqmjZkE6rhIooiSktLUVpaigULFiAejzNh39bWltHYiVpf0kgwAJZPX1JSgpKSElaoSi8ITNNkNpyqqiIWizGRTPP7aadbmg5Dc9xtNhs8Hg9sNhuL3FNBPZxIsSiKLMJdXFwMAKwwlnraUxcdOpZ0wR8KhdDb28v87i0WCxP49EKEWnU6HA44nU7U1tZi+vTp7LlIJBIsgk9vkUgE/eEI+qMJyMSEJBAm6umMAX0v0H10Xc/rIENdhLJvA71HbDYbpk6dynLffT7fsN87JyO6rmeI9HTRXkgnXNongr4v6CzX0Qs7EYKYVowuihB5StMJDfeZ53A4nBMYOjXf398PAHC73fB6vQDAmlDZbLYxs0gkhKC3t5cJ++7ubub1HolEmKMM3VZRFFRWVqK8vBzV1dVQFIU1q9I0jdlM0hx1RVGYsKEpKJqmZeTBx+NxZm9pt9vhcrmYYPb7/fB4PMMW9wORXkxLO91SFxw6JnpREwwGmeiPxWIAUrnlNCWKCnyfz8dSp+jrZ7Vasa8rihc2t+LtPR3QDBOSCJxd48TiYgOx1n3Yt28fGhoa0Nvbyxo0pf/kU7tI2niL/p3eaRdIFRXTjrpTpkzBzJkzMXv2bMyYMYO7zgyAruusyDpdsPf19RUk2NOhjdCocKcXaemzL36/H/P+vgquvXuBoiLmMiWKYk7hq97VBefSpah68r9G5bFyf/vxh4t5DofDOQkwTZOl4FDHGOrpHovFYJomK+QcS6hDSnrTKmp7GQ6Hoet6xph9Ph+qq6tRW1uLkpISaJqG/v5+5iATi8WYiFdVFYqiMBcaKnai0Si7oKE58FarlRXTprvG+P1+9ryMxqxFdmdb6qqT7pGfHsWn91S8CYLA6gHoOJ1OJ/bFnPjD7hiimglZFCAKQFIzoBsmZGhYJDdhqtQHIHWBQZtj0XPGYjHWNGswoe9yuVBWVsZmXuhFD03roalCx8txZyJhGAYT7Nk57KFQ6JiOmy3eBUGA1+tls1clJSWsEN7j8UCSJLTceRfCa9cWZE05mmK+5c67Cva3d61YMWrn5RyFp9lwOBzOSYAoivB6vfB6vYhGo+jr60NLSwsURWEpODRFhEaGx0KMWSwWTJkyBVOmTAGQ6hpKhX1XV1dGOg71cg8Gg9ixYwezzqSuKIqiMO976i8fCARYF1hqz6mqKqqqqlBXVwdJkpijTH9/P3seZFlmRaO0467X62WCyWq1Dijuqb1mtjsM/ZtCI9xU6NIbbY5FnWZoAXAgEGAOOMFgED09PWhtbUV7QsY7xmlImhJkMw7dNCEAEEQRkiAgIVqwIVkNhxxBqZKEJEmw2Wyw2WxsVoZCZxBogyxCCNxuN7PjlGWZNYKiVpvpfvqBQADt7e0ZBbrpFwL5XHeGctyZSJimOahgLzAeOijplqDUpam8vBzFxcUoLy9HaWkpSkpKCpoFIZo2pI88SXtPjhbc33584WKew+FwTjLsdjvsdjsrtqMdXl0uF7xeLxMwNOVjNCLUA0Fz/GfPng1N09DR0cFScqizCxX3pmmiubkZjY2NWLNmDYqKilBXV4dTTz0VFRUViMVi6OvrQzQaZcWqVCDJssxsLx0OB6ZNmwZVVZlI7erqYpF72tzJarWyQlXatdbj8cDhcECSpAwRBqRELvW5p6KM/k1F7kCk50FTK87a2lp2Hvp4AoEAfrSuDXpTAoIWgkHIEXtQAhgAIABCEnHFgQ1dEk4J7WX1A3Qs6aLQZrNh5syZmD59OqZPn46SkhI2TurL39/fj2AwyIqY6ewOvRCghZcUOivS29ub00CLetJnN8iiz++xFs2OJOUjfdYqu+g0GAyOimCnEEJYAXT6xaPf70dpaSlcLhesVuuI0pecy5aNybbHgkkIz9c/DnAxz+FwOCcpiqKgpKQEfr8foVAIgUAAhw8fht1uZ91lQ6EQi+yOtbWgoiiorq5mqZz9/f1M2Le3tzN3HFpEGwqFsH37dmzevBmyLKO2thannnoqTjvtNMiyzCLvNKWkr68PhmFAVVVW3EmLg+fMmcP26e3tRVtbG3p7e9HZ2cki/9QJiDaOKisrQ0lJCcrKyuB2u6Gq6ogjzukXAhSaK03zo71eL1weLz5+vQs2C2CSlM+/ccTik5gmTCruTQO9tio07HoFIlK2mfTxVlRUYNasWViwYAFmzZrFLlCyx06j+enNtwCwolpaAxB94w0YH36Ymo1IajCJCWISmISAEPOoswqAjtpadJx2Krq7u/M20aK2o/luhaSAhdesGTDlg3b7JYQAuo7mlmbs7k5dxKVbgY4GtNaBzkRQi1B6YUdrENJtQ4+VbJepscY0TSQSCSS1JOsQbBzxt0+5LRkQRRFFvqLjNqaTFS7mORwO5yRHFEV4PB54PB6WX97a2soEpMViGZUmVMOFiriZM2cyn3kq7nt6epiwp3aNjY2NOHToEP72t7/B7/djxowZmDNnDmbOnMnScXp7e5m7S1tbGxPL1J6R1hJMmzYNs2fPRjKZZM8JLdoNhUJob29HT08Pi957vV5mt1lcXDwqTkG00DddxPaEYtAMAlkSYEoSdEOHKAoQBQk4IggJCASIgKCiuLwakhFnopHWAhw4cAAHDx5kNQNOpxMlJSWsl4DH42H+/bQoNj2yTguIKyoq0PKznyO0cxcERcnwNyc40n2U+p0bBmRFQffsWTAMgzn7UGj6TkdHR05En1prps+QpKfwyLKcSnfSNRBJgul2s6ZJ9D4di2EgEk69D0aKzWZj3Zfpjdqi0k7HiUQCoihmiPfRqscYC2gtCk29Sk/BisViGbMzdOZsyY4dmJL8/9n70yhJris9EPzes8V393CPfckdewIgUSRBiCwuIFASF5Biq1iqlnQktdQzo5amz2jU0+f0FPvMtKTpoo6OunWkPqOeGal1NFKrSqUSelSqIsgSiyxwBUECCRBAIhNrIpeIzFg8wvfFtvfmh/l98czCPcIjMjKRmbDvpKV7+GJubmbu/t17v/tdF50hfQJSyMH5mGTnbyQSMp8gQYIECRSo/E/2eTS5lOQQjuOg0+moqat7wfUFuq6PrG3CNg9OYEgvT1aU7XY7Yn+pu8SQg89PfvITfP/73wfnHPPz8zh58iROnDiBSqWCdDqtGlODIEC/31da91arpeQuExMTOHLkCD70oQ8hCAI1SGp9fR2bm5sq+7+2toa3335bkbaJiQksLCyoysewzPdBUMikYJscjhfaWwa9AAwMUq1agjMOwWwU0ib+6f/8P2F2egq+76tM+rVr13Dt2jWsra1ha2tLvSfy7AegpCDUNzAzM4PJyclIMEck3+90IA0DKJcRvkUGHl4AROMYEGxUcfToERz7yldUUER++u12WzV+0vEIYgTcNE2srKxEGnqJaJLv/hfOv4Fjrot+uxVaNLLro5HpdDpC2PXrqVRK9Wr0+311PjmOg3Q6jUKhgOnp6euq2FwvKHs+jJwPu50sLrvdrmpOp8terzdccrSLCkkCEIFIBobdYCRkPkGCBAkS7IBpmmpia6vVUtNJybfd9330er2RQ6jOrjTw9Jnl0DZRSFgGwxP3zeJrH1nCg4ul696+TCaDI0eOYG5uDo7j4OrVq1hZWVGNtJ1OJ9J8SgT2ueeew+zsLO6//348+OCDWFpaUtrwRqMRykYG3vBkqUkZVXK7qVQqOHXqVMQCc3NzE2traxF9+TvvvKMy4uVyGfPz85idnVVe9wcheLbJ8cR9s3j6zDJyqcFPOGOwTROWbcGywkz6ZsfD5x+aRyGXRbvdVr78xWIRx48fV69Ndp8bGxtKzrS+vo5Go4Fut4tqtYpqtYo333wTnHPlXkMDv2ZmZmBzDsFYSORlWBlQdqMAQARQSriuBz7oY5iamsLs7Kxq/iUnn06ng2q1ivX1dayvr6uBSUQwdccjAjXddjsdCCnhOq7aN4wx8MH9wyxYU6lUhKjrhD2dTqvH0VAzkmxREGFZljrGB9W77wfUJL0XOacAdRSItOuEvdPpqEboUdCdoqipm2RixqByYxgGDNOAYZi3bBXiTkJC5hMkSJAgwUgwxpSUgdxiSAahN8vqQ6i++epVfOOZ82j2fZicwTQY+q7E02eW8Z3XV/H1L92Ppx5e2PO19QZT3R3GdV0lzyApyszMDBYXF5XvfLVaxdraGlZWVtQwnk6nozLUzz33HH70ox8hk8koSc29996LXC6nmj5pKBN5wlerVVy+fBm5XA7pdFoR+9nZWZw8eVKtmwZEXbt2TclzLl++jAsXLqj9VCqVsLCwgPn5eUxNTaFYLI5N7r/2kSV85/VVNPs+CvlCKIUZPFcIic2ui2LGxK9//Bjy+TyAaIMtkTVy0slmszh16hTuuusu9Rqu66LZbGJ5eRmrq6tYXV3FxsaGqn5cvXpVEb5PX3gXC54Hv9OFYZowBxIZncRJSASMIZVOoTA5iSAIlOyJZBt606nnecouUwihXIlo21zXRb/fV4+hhaw2fS2rT04/qVTY/MsNA4br4uixo/jEX//rEcKuQwiBbrersu6036j6Qs2q15t1punCu5HyePZ8P6DpxETW6Xq/31eP0T3ry+VyhKyT3ImuU58JzT5Ib1QRVKvIJUPD3jckZD5BggQJEowFaoikgTi1Wg1CCBQKBXDO0Wq18OZ6F7/5zHm0HR+VrBkSKcYAxiAlsNV18Y1nzuPEVA6nF0pKMjGMtBNpocZIy7KUpp3+HkWkKpUK7rnnHgghlKUjLURIiZy9+eabeP3112EYBmZmZtRQpIceekiRTiL3NM2VMtlE7ChbTeR+cXERp0+fVlp9suAkWculS5dw4cIFMMZU1n9+fh7z8/OYnJxU5F41bALqcjEr8X/6zBH8o2cvoeEEMD0HBmfwhUQgJAopE//HTyxgyuhjdbWnnquvRye/JG8hgq8yq5rrzN133w0AKou/sbGB9fV1bG1tqeCq1+9pR4DBti1kM1mIQYMu932sX7mCF//Vv0Kr1VI++uoZg+t6IzDdRtuu++HT9hPZ9X0fhbffBt+qKeJPMhvGOTKDBlSC5/lYXl5WfQBkMUrnoO/7MAxDNbBOTEyoAWvUPzAKFDSNI3FxXXeXT934oCoLSWIog0/VAyLjVGHSCfqwQJJ6ZnTiPjExsUMytpLNop041ryvSMh8ggQJEiTYF2jiZKVSUbrnZrOJdDqNf/+LDTR7Hio5S5EJZZ0oAUMCVcfA/+2fP4MnJ+uq0dG2bUWaMplM5DqRDnICIRLnOM4OojuM/AJQRPmBBx7A+vo6VldXsby8rKwsiVRduHABb731Fn7v934PuVwOR48exV133YVjx45hcnJS9RQQwW+321hfX1d6YpLkTExMKFJOQ5Tuvvtu9Pt9RfCr1apq5K1Wq3j77bfVOrLZrNrmqakplEolNZmVMYZPHcth7qkT+Nb5Lfz4QgO+kEibHJ86NYEvn57GPTPbXu70HLo+7DadNJNunTTr8czs3XffHXnOSnUTvR/9aGB36sP3Pfh+ANcFpOwqzXpKSvier6o9Bz33RmnYc7mcGpxUmJ5WQYQQAoEQ4fswDEgAAWOwUzZM21bDuhzHUXIR0zSRz+cVyacmaKoK0ELBKFU9KAigwWW0n+LTden6fkAuTABUwEuBAfV+UKWlUCjsa71xwk4OR+Pi/fK3TxAiIfMJEiRIkOBA0CU4/X4fG5tb+P5bm+BMgoGsFT0EgVBNkAxhlv6tTgr3u2uwTa4mig5bv35J1/Vpr+T2QksqlVLWf3Q/WTLScuzYMZw4cQJSStTrdeWSQ1n7TqejiNqlS5fw3nvvIZVKYXp6GnfffTfuuece3HPPPbAsKyLDIUkNkfV2u61cX+bm5jA9PY2FhQXlpc4YUw23jUZDNaVSoHDt2jWsrKyoQVMLCwtYWlrCwsICyuUyFhc5Pvuhw2syHoUgCFTQQkOsyJO/1WrhgQsXMB0EIVEHYJkWLDN0tjlIvtY0TZRKpaEadpINjYM4YZaS7DLlQP/uwu12wRhTjj7kstPv99FoNCKZdKoKkK++LgPSr4+6bdj2Uc9JOp2OWFXSOULNwI7jqD4Oz/MigQKAsfz5qTcgTtyv19v/VvS3/6AhIfMJEiRIkOC6kU6nUShPQTIO2wgJTOD7YIzDNDmYRuSFYGDMxPT8IjJGSHKGEfdxQBaAuv53HMSDgGKxiFKppIj1xsYGWq2Wyti7rouVlRWsrKzgueeew8TEBE6cOIF7771XTaQVQqiG2Hq9ju4f/iHYmZcQBD5afoAu51g2TaRSNjKZTPjaVuh0Mvn44zj55acUcW40GirIoMrH66+/ribh5nI5Re6XlpZQqVRgGNdH5B3HiQxO0gcoxfcvuceQ+4mqiGgONrsdSeq5iGfYy+XygZ1/5MBznzTzQgQQgYgQa4IZBGjV67h27txY6yb5EYChTbQEaoalOQb6Eteh00AwOueod4COOVWeRgUDtF3xjD9ZZhJhr1QqqnIRtxk9DJedm+1vf704yGCxWx0JmU+QIEGCBIeCrG3CMjn6noBlWaFkwfeV9p1og+tKpC2Ge04eg/A9RZbjza03EvSao0D2lURqafiU4zjwPA9SSvzkJz9RA7UqlQqOHj2K48ePY2FhAYVCAbnz52G8/jrkQNqBgfc6pIQHwGcMXQBcCHS6HRiPf1ZNmy2VSjh69Kga0NRsNrGxsYGVlRVsbW2h0Wjgrbfewrlz58AYU+T+yJEjWFpawtTU1FDS6bpuhKTr13u93o7HjwJp6uk6Bn0RJKmirDzX9PeG42DpyBF89D//z8d289Ez5bvpz8mB6MilS6g4DsTq6vZKGEO8s4IFwVivT9nzYcsw0j5KOkPHkSYuUyWnXq8PdeeJy5AoWNKz/dRrkc/nw/NtUAWyLEsFW3RJjcbD3t+wOQL65W73vV+Wm9eD3QaL6SBZUELmEyRIkOAG4kbLCxLsD7ptohQSlmnCMk0IKcMsqRAQEoAf4MkHZvDLn7h76BAq3/d3aJNHLUSu47fvZq03Lkgfv7CwoBphKWNOWXuaRNtoNHDhwgWYpqmcbn71yjIWGYOXzYYrJJmHEGpKqgSQdRy8fe48/pe/+3dRKpWUheXU1BSy2ayqHkxNTWFhYUF5h5NEiKbUkt6fNNskHckOXr/VaqHb7V73ftHBOYdl2zCEQLbfV9IPanpG4CsfchkEsO3wWNdqNXWcd7NXHEZ0ge1zhI59EITTRrdOngyz6JyD8eH1Ac4ZODeARz6MhYWFkcT8IDaTUkp1PujuPPV6fYdv/rjI5/NDNe372TZd+hMn+sNuowZ0/bZhoIrFuOQ/ftv7FQwwy4I5Pb3rY/bqAbiVkJD5BAkS3Ha40R7mCQ4Osk3c7LqYzNqD6aQM3DARMIF6x0XO5vj0ko1Go4F+vw/btiNDqKjhkkjoQTCM4O930WEYBsrlMsrlMo4fP65sOmu1GjY2NpTUx/d9paPf2trCnO+j1+0q323ODXDThAHKtgpw14VhGuh0Otjc3MT58+eVFWOxWES5XFY2iJT5JZ9/IsFEIGu1GprNpmpiFUIo+85CoaAac2no0SgnEx2MsQiR1DXsxWIRjeMn0PzjP4YQAYIgOnU1ECJy+/LEBLa+/W1FEmn9uoNOHNSUq1dwdLvKQqGAfD4fWoY+9FBIxH/+AsyXXwY3ODgfrFuTowAArq0Cv/vvDiSlIElVnLA3m80Dk3Y6PjpxL5VKu8p6xoW+jw/qgz+M9I8KDCjAGicY2G81QH9M4mEfIiHzCRIkuK1wWB7mCW4MHlws4etfuh/feOY8qh1XHSM/kPCFRDFj4etfvB+fOllURMi27cgAqkwmc90ZO93K8iDQbQr3Wnq9nrKqXF5eRrVaRbvdBgayBj8I4AcBmEeBDYdpmDBME5yH2clCoYB7770X/X4fnU5HDW1aXl7G22+/rQgi2XFKKYeSPHJ7IfJL00nr9Tqq1arK0JJ8JJvNIp/Po1gsKnI+MTGBQqGgSLIeRJD14eXLlxVJk1KCf+bT4x+b2N9E+Mj9SAihGp2JgGYyGUxPT6NQKCgZUjabVY3O8fNl5bf/DVo/+xmkZSEAMIpe7yWl0CsyOnFvNBoHqv5Q03icsJdKpVt+Sur1kue9qgH6fbqtph74xREPBsYJDAIhILE9uFYMnJuElEiPMdX6VkRC5hMkSHDb4OxKA9945jxajo+pXJj1JdCwHN3DPMH7g6ceXsCJqRyePrOM755fgxdIZGyGJ+8Pqyd0bPL5PFzXVRnNTqej5A5kS/l+ERzGmJK37BftdhvXrl1D87/72zBfew3ZdHrQmEmNjAxCSmRTqVAEwjkMI/w51q0rqXKh+5HTRE/DMBQZp2ZavYlYD2Yos01DlshKkTKolM2vVqtqfaS9HhdkMarbWOq3WZaFbDaLXC6HXC6nrufz+YjUirThRObJuYgciei2cUjlfqQUvu+rxmU9095qtQ5E2jnnirTrxL1YLH5gs8mHEQyMKxPSPzN6E3ur1UL58mUUXRfNrU0IQbSewTQMpFL2oH379kJC5hMkSHDb4Okzy2j2QyLPOEO704ExGC1uGAYqWQubHQ9Pn1lOyPz7jNMLJZxeKOE3vnD/rn0NNL11ampKESjK/qZSKUX8DkNqsBsOs/+CPOVX5ubQevNN2KVSSCrcUN8fBCGx9wbNvikh4A6ykJQRX1paArAtMaEmT/I7JylPr9dDd2CvSGSXdPakSeaDz4ht24pMUWMprZOmgzabTUVe0+k0crmcytDrAQMw0MtrE0J1L3q6XLpwATOXr6jhTQBU34AnBTYHmX0Ghq1Tp9D88IeUzj9uKapnWMl+lLaR3vNeFR0JCREIJf+B7+PK8jK++1u/daBjbRhGJMNOl/uZ6JtgPFDGXl/I2z++EKEn+1RymQqCAI/0uigMgmpKCIWDxVgo3+K3doVkGBIynyBBgtsCri/wvTfWYPLwC1gIgU6nDSnCLB4bNLX5sPD7L76Hz1UaKJeKO4YQpW7TMurtCtvksM29s9ucc6VJJ9cPati0bTtC8A4Th9V/IYRQwQg5xGSXl5FzXTitlnqcaZjgLCzzx3XERNz3WqhBc3p6GpzzQYCwrY+3LAv5fB6Tk5OYmppCPp8HY0yRH8rK64OOyMecSH2j0UCv14Pv+9jc3EStVkMul0OlUsHU1BSmp6dRKpX2JKyzV5Yx9e67EJwrJx8FjeBzIZBKpdD88IfUhN69QJIhXe5D/Rd3XbyIkuehM5DD6AsbBAUAYEkJf4xhRuR9H8+0075NsH+Ecyh2kvFRt+1WIfF9X01npgCVhmhRn8DMzEw4DCyXB+cc6XR6u5di0Ftxux7JhMwnSJDgtkDX9eEFEqYRft36QYDBFKLwRzoIM26CAW1f4P/7W7+DSi6lCGKlUkEmk1E+zMOmjWaz2Yg+OHHLeX9AUgxdgtPr9dSUWGoGvV4Std/+CyFERIqh2zoO01B/qNNBFtvSkW17SgkpwyZdCcAaBAKvvfZa5PlUmSA5Cunb45WKcABSmGWnJthOp4P33nsPFy5cQC6Xw/T0NJaWlnD8+HFMTU2pwFZ3ctEz20IItNttrK2t4dq1a7h06RJWV1fRbrfhOI5y0KlUKpiensbk5CQKhQI8z0On00Gr1UKn0wknyXIOJ5sBoDndxGCPQd7jYIypOQNUXaCAJHP1KnKDZuTI8CrGYA0I3jBYljW0CfWg3vcfNFDvwzjZ83EbhUlypXv00/muV6ioITqVCr/3ddBxLZfLMH/0Y3gXLsC6jgb7Ww0JmU+QIMFtgaxtwhqQLSD8cp6anNzOLvo+At+HLwxwKSCcLjb7bWxubgIIfxDS6bT6gSaCP0wXvO6l8PMNjlc2AgQSsAyOT50q41d/aRG/dGL6wG4QCfYHXYJDJJr03ZlMRk2vPIgOd1T/ReicGWCr4+Jv/96r2Lx4HiXRUq+rZ7T3Wo416qgMdOmjoE/wvO+++yKa8lHnmW3bI60UqSHU8zz0+31Uq1VcvXoVtVoNrVYLr7/+OgqFAhYXF3H06FFMTk4OdQ0izXexWMTdd98NIGx+XV1dxdraGpaXl7G6uopWq4U33nhDZUHL5TKWlpZw8uRJHDt2DO5rZ9G9dg25ciVsOoxlyUHXez1MlMuqCZgWcuyhqbjtdhvdbhe9Xg+9Xk/ZkMYDqSDwI7dTDSScacWUfSXnDJOTkzj5K7+CiYmJ63JQuhNB1aJxyLmvzZTYC9TzQQRdX3TiTucyffbJoz8eCKTTaXVd71egZE42m1WftZV0Gv4dFpglZD5BggS3BXQPcyEkON+2WiPpRSAkgo6Dz98zh1+/56/hypUrKptIJIAkHG+99RaEEMrDmSYlvu1O4JtXPPQDwGDh0pXA77+2hj86t4Y/fVzi0TljaGafrqdSqSSLd4jgnKtjRJpxOqYUoOXz+X3p6qn/opRm6PV7YQY58Ad6dgkGibpj41/+8AIe6J+LyFPGhgQMKZHzvFCTG5tyqy4BTE1N4dFHH91zKNF+z60TJ07A9320WmFAQiT8woULOHfuXNhrUqng2LFjOHr0KGZnZ0cGEdlsFkeOHMHU1BROnTqFZrOJtbU1bG5uolqtolqtotvt4uzZszh79ixs28an3nsPk54Hv9/f9hWnKsWgQkHSG7JPJCehbreLVquFdrsN13VhGIaqTOhTXXUPdbqeazTBuz3ldKMv6UwG2UwGAOB3u8gXi1hY+OA4YI2rPafbxgX1MexFzi3LGhmA+76Per2OtbU1ReAdx9n1dXO5XMQ2tVQq7RngS8/b00de7uO9v99IyHyCBAluGwzzMCcIIbHVdVHMWPgrj9+P0wslfOxjH1NSjVqthq2tLayurqLT6SgLwE6nA9d18Xa1j7NrAsswIEP1DmwmYBtA2mKQkqHtA//hIsNMxseSH2YKh4H0mDrJj5P+W92G7lYFDXKibF2z2cS1a9ciPup7OdDo/RfCD9BqNTFUjssDXMUkjnS64Njd0YScY0gSUyqVkFtchPvee7BMC9yg5s1tja7BjVCryziWHn8cpccfv449MxqUMS+Xyzh58iRc11UTZa9cuYKNjQ28/PLLeOGFF5BKpTAzM4PFxUUsLCwgnU6r6ardbhee56lGVF1Hfvz4cfVZ29jYwPr6OlZXV9FoNjERBOg0GmEzrhE69xicwzDNwQAtCTMIsLGxjpeffz4yXZT8/XVf8WGTR7PZbETTbvyjfwzv+eeRn5nZsT+uf5zYrYf9ZM/3I2+xbXsscj7OvII4pJSRihv1yewGy7LUnAO63K/jVH4fn7P9PPb9RELmEyRIcNtgbw9zE1//4v0RJxuSaswMftSllGi32+rHY3NzE997u47nrhhoCa70tRKAIzlcXyITBEgZQIZz9AKOF6smTkywkcNQhBCqCWsUSP89Sr9/EEvEDxIsy8L09HREgrO1tRU2nmazmJycRGaQfY1D778Im94Y5BCKx6SEZBwBM2HAV9KecrmMyclJ1WA6NTWFUqkU0Z6/n/aDUspBpSFQ5G3Y9VQqhRMnTmBmZgb1eh1Xr17FlStX8O6776LdbsPzPKTTaczNzWFxcRFLS0shUdYCURreRDaZk5OTuOeeeyB+8AP4b74Jx/MAKVEYSDCod4AzBmYMGg4tC2AMthVOud0NlIWNN6PGPy+7SSluh5oZVSfGIeee541tn2kYhppsuxdBP+yEQ6/XQ61WU9+9w+QyOkguQ4HoxMTEgedG6Ch9+al9Dwm71ZGQ+QQJEtxWGNfDfBRoQE+hUMDRo0dxdqWB7z3/InzmAhBgADgDpBQQEpBg6EkDMnBgMgnBLLxwVeCj1gbmZqbVDw1lpojIk6Z3VJmafI/r9frQ+2lQzqhG3VQq9YH1q9ZB00lLpZKaykr66kwmE3FzIej9F1mTh+4mUgKMhjqF8q2eMJBPcfydv/F1LMzNjgwObgTIpWUvQk7X9b/H1S3Tc8jKb2ZmBnNzc2CMKU/ura0ttFotrK6uolqtYmpqCseOHcPJkyexuLgIy7JUhpXcfBqNBlLPPIPca2cBKcGEgCEEdlDDXYgcDYfSCXupVNpXv8qtJqXQm0P3Iue+74+1znhz6F7Z85v1nUGVMz3rPq5chr5TP8ie/PtFQuYTJEhw22FcD/Nx8PSZZbT6Pir5FK7WegCAkPexMDMLCQkGV3IgcCAZgwOJH/30BZjCUe4JxWIRk5OTmJ2dxfHjxzE/P6+8r6WUqqFPJ/o0nVMnmnSdGhhrtdqO+4BtKQ/5bOtyHvo7Xvq+UTr+W8X1h97/9PS0IpWXL1+GaZqYmppSWWW9/4JxjlKpGMpfBpIXIJRtOR0XX/rwEk6dOH6g7dmLdO92334GFVHviGma23Z72m10nawAafiUEAKWZanzKD58KtwPAq1WC+vr67h48SLeffdd/OhHP8K3v/1teJ6HVCqlZEWFQkGRr1OOi8wg885cF3IEKWNShll6xjA9M40HvvxlFIvF654rcDOkFLqV6DBiHifo4wZZNJ13L3JOt7/foHOESHutVtvTWtS27YhUplwuJ8YC14H3/yxIkCBBggNiXA/zUdC10+aAbEgpw6E2Ay5FkhthWOA8AGDAEB4M6asfUsdxsLGxgbW1NbzyyivhIKBUCqVSSVn3HTlyBAsLC5icnEQ+n0c+n1dlbLJZowFAdEkLETud4OlSHnLsicOyLGQyGUXU9OvjDtnZDeeutfB7v7iGZ9+qwgskLJPh8bun8NUPz+P0QlE9Lv4aowKMgzxu2PYTeZ+cnESr1UK9XlcuLKVSCdPT06r/YqvrDe2/2Oy6KGZM/CcfnofjOAci5PsBkW4a/DSKkA/7e9g+kFLCdV0VMJJNJJF3cgIaRqCCIECj0Yhk2qniYds2KpUKDMNAtVpVunvSV5fL5TB7H3v/I49ZEMAYfI7yuTwqlcq+9tsoHFRKMcr7fJTkZVwMaw4dRdBv9Wx0t9uN2LLW6/VdAxXOOUqlkiLuhyWXSbANJscM/5eXl3HkyBFcuXJFTaZLkCBBgtsZ9a6Lz//jH8EPBIppC2uNLrqewLZJISLXsujDhY0TWMOHgrdVU2AwsB8kHTI1jdGPMpE7Itf0wzY/P6/011NTUygUCiqbrjf70VAfx3HQ7/fhOE6E+O9mfbgbyHN/lJxnN9vHYR7tqnchbeI3vng/nnp4Ptx/sZ8Z/e9xru/ncXGQXKXT6ajhR67rIpVK4WzDxj9/qYaOGwycixh8KREIIGcx/PkHMnhscTtY3C2IoKmkccK926X++N0CmXFAEy+JvDuOo4JKPfOu66B9399B2MkpaNzKgOd5aLfb2NraQq1Wg+d5ME0TXzh/HkfWN8JmX8ehctdOSAkYBphhoPDkk1j8h//jvt73OBjmfT6KnO+nOXQcYn7Q5tBbBbotJF3u9X2Ty+UiOvdELnPjkWTmEyRI8IFF3Lt+IpdGv9mHkBIG4yDrPCkBBgkPaaS4jwczbUzw0MqSJmsCUXJUq9XQ7/eVlIGWfr+PZrOJy5cv49VXX40QaMoa0w8hNVnm83k15KdQKKgsmJTh9FuS8biuGyH9tIzSqoZTdENHn1GgxkZ9Oy+3JP77b15A2wkiHu3hOsOs9t/71nmcnM7h9MLeU0L3wqiM+DgSljgpTaVS8DwPm5ubmOj18JeOp/BqK4tXN8NAJG0Y+PiRLD5/zwTumcnuyILTEifv+nvcT1CiWyqOes6ofUI+7HTsGWOKvJdKJdVX4XkearUaLl68qDLsjUYD7XZ712Mz6j7TNCM6drIDLBQK6Pf7WFtbQ/O/+9vgG1V4hgFjkP3Xg1Q1nKrfhzkzg+wjj4wtdxnmfb5bs+hBmkP3Iue3grzlsKEPRSPyPo5cJu4uk8hlbj7uvLMxQYIECcZE3Ls+ZXJM5mxsdlwEUoYZehlm5k2DY6qQwn/zJ+/BJ478MqrVqvLX3tzcRL1eR7/fx8zMDObn5yMSmna7HSlJk2c2kQIihToZI/KskzNyTpmZmUGlUkGxWEQmk1GEMO61TYtO/CiTr2f5HcdBEASKbOlTOul+Xbv/v73HUO8w5E2g2/NDzTPng0uGYoqj3vPwOz+7iL/71YcBHJyQH0Q/TkSM9OPDMuacc/R6PTSbTfxypwPJDGRLZRxbmEMmdeuREeqhIPLueZ4in4VCQTXnUjB59epVdb6NclWifg5CfD/r1pM0hIcGKw0LXihwXVpawtrUFLqmCYsGu2lNnaXSdoDnb2wg+8gjmP8f/oHK8o/jfz4uxsmc3+zm0FsFNHODyHuj0dhTLkPnAJH3ZMjWrYGEzCdIkOADjbh3fT5lwjI4Wn0PHceHAGBw4AsPzuGvfeaUcsuZnp6OrIcsLzc3N7G+vq5Ifq1Wg2VZKBaLOHHiBACogTiUJSXJDJFXIpxkD0eEjTL0RExt20Ymk0GlUsHk5CQmJiaUHp+cP4ic0yAVctyJk34KOkiHT9tEpN/zPDDGEEiG1zYz4AAYkwgCgf4gK6wHAYEw8PsvvofK5R8gbVuRQUhk4UiXuvSDvMVJYzxMJ76bhnw/KJVKmJubQ7/fVzalb57fxMTEBObm5iJTJW8m4nr3fr8f0bun02lwztHpdLC2tqYIfK/X2/drkdZ9mN3j9eiaDc5DdyDOYVsWLNNUlpm+74eN5RJAEGBzawvLL7ww1nqpOXQccn47y1sOG1Sd0bPue8ll8vl8JOuuNzgnuLWQkPkECRJ8oDHKu54zhoxtopA28d98/j589ZHFXdejW14eP35c3S6lRLfbRaPRwNramiL5jUYD8/Pz8H0/QpopW0zWlUTq1tfXIaVU5J4IfqFQQLVaVTp9arIjSQ41m+VyuYi1pT6FlAjd1NSUIvy6HIJkDZ1OB9c2G+BvvQGbCxgGQxC4cFwXUuoZPQZh2BCQePXcmyiktid3FgoFtR2UIaf3QtsZt+G80Z775KU+PT2NRqOBarWK8+fPI5vNYm5uDqXS7nan14tRene6nW6jZt69LP5GIZVKRcg6XR/XcnNYc+gouYu5tQkjCOAP2Vah6dLDybhQ7kujiPmN8j6/E0FyGZ287yalA7bPDb1JNZHL3D5IyHyCBAk+8Lhe7/rdwBhTJFUfGS+EQLvdVvrljY0N5elNHt90H/nVk62gEAKu62J1dRXLy8sqm59Op5HP5xVh1kk9edPbth3J9OdyOaRfeAHsxRdhGiYYZwNpETn6yMjIzNInP4m0PYVW30c6bSM1qA5gkOUPhIAIArQ9wIKP2ckyep0WNjc3sba2puQ81EBIkiKy2KQJr3pW3rIsJSUZ1qhLmerrhWEYqFQqKJfLqqnzvffeA+ccs7OzmJ6ePpTXievdycpPD+i63e6+5CQ6aLhVPNM+rNJA2zIOQd/X5FDGw5kNgwx9eMwBBjb4O7weGAby5QoWH374QO81wbZchgYyNZvNseUyRN4TucztjYTMJ0iQIAEO17t+HNB0w2KxGHEI830frVYLjUYDrVZLDe5pNBqK3FPGjZr7SFcupUSv10Or1UIQBCqznkqlkMvlVNaeiHM+n0exWMS9z3wLpfPnIQ0jJFpARDLDWEjA4HmwLAtPfuav4OmXlsMGXM5hcQ7y8pQAhJTodTx88YF5/PU/8RF4nqcqDiQJabfbyne/0+mg3W5jY2NDyX/0YTh0SYEIBSp6vwG9n1wuh0KhoPoJiPDvp2HRCyQCI4X5xSOYm5tDrVbD+vo6rl69qmYJpFKpyHMaf/BNtJ99duj6hAz7FhzHhdPvo373XVi/7z5F3IFQ271fSUg2m92RaS8WizAMYwc539jYGKo9H9f7nCpC4wwoMk0TK7/zb9ESAqzRiKxHxi9v4tCmOwE0aE7Puu9lkUlyGSLuiVzmzkNC5hMkSJBAw/V6118vTNNUP7w6+v2+IveUtSfJDpXUKbvrOI7SuHPOVYDQ6/UihJHK6OW1NeQA9GwLnG1bYlJWFQjlEKbvo1qt4sMTDr5lMmy2HVRyKZjG9uOEkKj3PJSyFv7Sp+/BsYXS0OZc6hvQM9Q0eKbVaintPslPXNdVThtCCEXkSXoxymmEggIKYIrF4g7iTxKfc9daePrMMr53fg2ekLAMhifuC6szDzzwAJrNJtbW1nD27FkUCgXMz8+jUCgAANrPPovWd78LWFZYpZAScvB+9cZSQwig0UDnyBEAGEtClM1m1fbSkk6nlZMSNccuLy/vy/tcl2XtRdD3S/5uxtCmOx1CiIg7Vr1eH0suo+vc9zs1N8HtiYTMJ0iQIMFtANLJ6423NDiq2Wyi0+mgWq1ibW0N9Xpd6atJQtHtdpUXvhACnHNIKcPBUQNrSREI+DKIkFAJIJWyYXADfEAu2lfO40/O5vDNZQtrzS4MAKbBICSDAFBImfibnzmGe2fCBkoVGGjIZDIoFosYBiFERG5CQ4/a7TYajQY2Nzcj7jyO46DT6cB1XSXjoSy+bduwbVs9plarwTAMFVRQVeONbg4/bE7CERwGo/cD/O4Ll/CtV5fxtx4/gacensexY8fQ7/dRrVbx0ksvodvtgnOO8tUVZCDRN00AUpOThJUOCnasdhucsYhbDy1USchms6oCkUqlIgEY2Z82Yhlv2s/UOzFO9vxGNocedGjTBxl0fhJ530suYxhGZBhTuVweu/8hwZ2FZGhUggQJEtxBoGZVIvnXrl3D6uoqqtWqIvd0X7PZxNbWFn7tyjJO93qoE7lTBBSAxIDwc2QdBxfKZXznwdOwLAu91CSupo5iWZYhWUiC7y/5+PiMxLFi2Kio6/NpIQ3/QckkNYaSxpwIP8mTGo2Gku9Qhp+aSD3PU6SXHHWaRhHfah2FKzhypgTZ5kuElYZOwGAhwBP228i5NRVE0LCwIAjwqxcv4VS9Dieb3d5/2F7P4OAg3eth/cgRvPXlp9S+oErBsOZOveqwF0FPmkNvH7iuu8NdZq+qSqFQ2OEuk7j1JACSzHyCBAkS3FEgWYlhGEqOQrpxIrvUaEskl55H2XMJCSmIgjJIhFaDnDHk83nMzc3B931kvR6KvbM45rhwBEPaAFIdE1c3LGxpmWV90afjEoklUkvX43r0OPTJtcNATZ1E4Ok6Efxmsxkh+89vFNAPgJTso98P5UAYVCWoOtEzM3iuKrCw9voOtx+qJAgh0O11wVi4rzgNlzIM5cPPOUehWMTJkycj9pxUedFnDNi2fUdpm3frK4gj//jjd0xmn+QyOnkf5f1PSKfTO9xl7sRBVQkOB8mZkSBBggS3Kfr9PjY3NyPL1tYWNjc3R9oXZjIZLC6GNpukXS/99KdgjgNuGJBCABJKekNgLCT1vudBCIFUKoV8Ph9m7TXttu/78H0f/X4f7XZbyQT0bLJhGIrYx8k+PU4n97rzzjjZZ2qIHeWTrg/NanW6+K3fehNp6SNvp+C6DnrdHnyaykqkXgSoZZdgbFTBIRWZtywrbMi1LfA+Ryq17RhD9+dzOZWtDxwHs7MzuPexx9S+IukTBQXdbjfSCMw5j3jq0z7Uh47RbbcyqK+A7aHhpqbY25HMSymHusvsJoIwDCPiLlMul9+3GQcJbk8kZD5BggQJbmG4routrS1F0vVlr+weIT4RNk4ULdNSw3gGT4CQAkKEz6NBTnygRafJoaRDZ4OMfaFQUGScAgV6nOu6EScVymQTKSXQduiZfMpeE3EeJtuhhtBxQO+hUCjAyrmQ/B2kbIZMOgw2fM8H4xzbq2PwYMCwsvj8U19FJZ+KNIb2+32Un/kWWCvUwyunlsG+9oNAZeXpPRYKBUW+qT9h2DEiEkikn4YukUWp7mZEf1OgMYzwxwOBmw1mWTBjA9fi8Dc2btLWXD8cx4lIZfaSy9A8Cj3rnshlElwvEjKfIEGCBO8zgiBArVaLZNZpaTab+1pXnBBSZlwneHF3EqZfZwyGacIwOAw+mLI6aJYMHAfTMzP40Ic+hI2NDaUXJ/16vV5X5JIy9yQXME1TZe6JpPq+r/TnOqF1HAfdbncHwRlG9PXJskTy41n93Uhr1jZhGQx9NyTNhsGRSqeR5Rzc4ODcgME56j0fhYyJ/+vf/FsQvrtDwiNePAN2dQXZXC7M5pOjjZTwXDck+FLC9H2sXlvF+e9+V1mG6ott2xFLzvg+0En7KOIPQBF/Ivqu66oAKgiCoUFdnPwPCwY+yAiCYIe7zDhymbi7zAd9PyY4fCRnVIIECRLcBNBUxmGSGLJb3A90NxadtMeJ2TDNtWmaqFQqmJycDH3TL12GubGBTLm8q0Zbco6pyUl86Mtfhud5qNfr2NjYwMbGhqoU9Pt9NXTIdV1sbm7i2rVr4Jwjm80qW0Ua9FQqlRSpJJcZPUAgCQoRU92bXQf56Q/T6BeLxaGynWw2C9vkeOK+WTx9ZhlChFntfEyeI4SELySevH8WmZQFpKwdEp6Vchltw4SRTm8HJuQIJAWkkIo4W3aY1afpvqurq+r9kTSHtj+TySiJUTab3eG9v+vx0gh/nPzHQaSfzifK/utuO3rQQGR/VPZ/HPIfDiYLh5NxznEr5aZJLkNSmXq9vqdcxjRN5S5DBD6RyyS4GUjIfIIECRIcEqSUanJoXBJTq9UONNFzL4kMkT9qxiTQlEci7LRUKhWUSqXIY1e+9W20ggBicxO7hRT6gB/LsjA9Pa2sMnXys7m5iY2NDdTrdUXMSV7jOA6azSY8z4NpmhHfdADq/dm2jXw+r4grADWBlIZQOY4TsbGkTHkcRPTjmXzy5L7XziJjSGy0+6hkLSUpCrdHYrPropgx8bWP7O3kpg/bGqZg91otTEyUsXD6tMrq60479D5pv3U6HTU5lyQ0NEDLtm0VIOlDwXTCP66WflimfxTxZ2rfbAeTQRCg3W6j2+2quQE0H4CO0+zyMvKeh26rpYaMhQ6eYaBnvI/Nvo7j7HCX2e3zGpfLlMtl5PP5RC6T4H1BQuYTJEiQYJ/o9Xo7suu0uK574PUeRCJTKpUUSddJ+8TExNgNkYcx4Id08/l8HkcGA5F831fkaGtrC9VqVVk6BkGgyH2v10Oj0UAQBCorTfIYIQRarRZc11VkNpfLYXp6Wmn8iUDTBNxOp4NOp6MkPDppHrbdqVQKv2RM4jlnFmsNHwYDDA5IcAgwFFIG/ovHZrGUC4n2qP0qPW9vvbfvw7YszM7O7riLCHyc5Cspz+B8IImS67rodrsqs696G2JyJFWFiEl4dCkPPZfemxBCzSWgnge6HLZtpOEftn/pvQW+Dwk5eJxUnp3cYLtmvA8bJJfRyfuwao+OTCazw13mVm84TvDBQULmEyRIkGAIqPF0GGkft/F0FHQ5wzDSzjmPSGRyudyO7DpdHsZ0xxs14Mc0TUxNTWFqakrdpksX9AwoyWlIotPpdLC1tRUZhESNrp7nodFoKEJpGIaa8JrL5bCwsIBMJhMZVkXP2draUj70+vTZWazgM6yGd9kUrogyhGDgkDhltXDaaqHx2lv4nTdtlc2vVCoR+Y79yU8gD4wlFRkVEBmGoYZGDQM1DseJNJFpIsR0fpGj0Pr6upLMxLX1dJw45+pcpIw0kfu97DEpIKLX1WU5tC7GGBhnYAitUzkN1NIqGULbfiGFClCux56TqmU6cW+1WnvKZeLEfS+5zAfVdjPBrYGEzCdIkGBfcH2Brusja5uwzdvbA1tvPI1r2ffbeDoM+5HIpNPpoZKYycnJO0p3S4Sbhg/SRFOd4JNWnkghXScynkqlYJomCoUCZmZmYBiGIrWtVgu9Xg+e50UmqubzeZTLZRw9ejQyVdX3fdRqNVSrVWxtbeHBeh3t7mU0ey6E0wOHALrARix+Y4xFtO2pVArZRz6McrmMqakpFIvFiD7/MIIu3baTqhqu66psebvdRqvVQrvdVgv56ccdVuISGepZILkWVYFM01TyHSL3lJGm92/btrIbJY982u9UEbBtG6sv/wLt9y4iM8IyVKfXUiLi2BOH7vOvB8HD3GXGkcvoOveDyGU+CLabCW5dJGQ+QYIEY+HsSgNPn1nG986vwRMSlsHwxH2z+NpHlvDgYun93ryRoIEtwyQx5L5yWK+jEyRar55pJ515XBIzOTl5XRNRb2eYpqn2AaHb7UbIPTUe6sQ+CAJ0u120Wi1YloX0YEhVqVRCOp1WXvdEctfW1tDtdpW9JQUVxWIRMzMzeOCBB5DNZpFOp8PhT4N1x88X8oEnmYvrumi1WjveFxFY0uYXCgV13Ikw6pNwaV06QR92nS73Om8Nw0CpVEKpFH42ySpUtwjVrwshdkyXpcBTrxyZprmDrNN+i0+l3c2VZxjoEYwxGIPBYMNA1QXP87C1taVkXKPkMkT6GWPIZrOoVCpqKZVKhyaXudNsNxPcPkjIfIIECfbEN1+9im88cx7Nvg+TM5gDK7+nzyzjO6+v4utfuh9PPbzwvm0fldKHOcVsbW2pxsnDwF66dpKWTE9P79CyF4vFDyRh3y+IINJwK13jrGfvge0MM5F7kuuQTjyXyylS6fs+Wq2WIvlra2t49913IYRQk1eJ4E9PT6P0yquYOvMi5gwT3DAASAR+AD/w4blhZtxxXbiOg6uLC7iwtIR+v48gCNSk2eOXL2NhdS3025cSy1JihQEAA2NQWXDTNFG/6y40P/xhZDIZpNNp2LZ94H1IunkKKoZdp0vTNOH7vpqI2263VXWDJFD64CoCVVVIDqU7KFFmn17TNE343S6E68JbX49u7GCgFkEO8WmXUqLVakWy7rpcJv/yy5g9dy62Wg7btmCZ1sBByAbnDJlPfQrZ06dV0Eb9FsMy/fR38rlNcCsjIfMJEiTYFWdXGvjGM+fRcnxM5cIfQwI5fXzjmfM4MZXD6YUbm6HXG0/jpP16Gk+HYS+JzMTEBKanpzEzM4OpqSlF3JPGuMOHYRgqk0ro9XoRct9oNHZoxolQ+76vpsIWi0VMTU0pG0xy4Wk0Guj1eiqL//rrr+PDf/zHmH3vIoRhgGG7mZMuLcZgSQn4PnL5HLKf/zyazSbq9ToajQba7TaWVldxfG0dwR5k0JASjUYD36/XI/IWkuvoMpBSqYRMJrODlMcJehykYdeX+u//AdznngPnDIxxmIyhzBkqGr2WAIQIYH3iE+Cf+cwOzX5cwkMNuo7jqEqVZVlI330XUr1emCVHOGXY4EZIphlT28A5g//II3j77bfVMex0OrvKZXLnziN/9nVAa+jVCbgA4CAMFAxuYPpXf3XHOuJWnr7vo/XMM+j+4IfhOqSE1Ow5yVoTEgjOnQMGVRSaKUCXqXT6lrLdTHDnISHzCRIk2BVPn1lGs79N5CUkKI/GOcNk1ka14+LpM8uHQubJm3yYveNejhMHAf0wxwfpAEChUFBZ9pmZGZVtL5fLh6KBTnBwkDZ7YSGsCJGcSif4cfeaIAhQrVaVVEX3dCdJDHmot9vtMAvPgO5AcqKkLRpZ5JwjGwRot9rodDpKn089AXPvvAu2VQMKBQhFFgNIQT70AlICOddFyk6hUqns0Kb3+32srq6iWq0qfX6xWMTk5KRqMCYin8lkIk40enMxVY5M00Q6nYZpmui9+CLcH/5wLK13gRuY/7N/dsd9ugvPMDceZSn62GPwHntse50DwkzNut1uV1UGgtoW+H/8jzAMY4fbjmEYSrZE79/942fRffvtHTKXbWId+tnLahX9fh9Xr16F53nKVYkkXPEl/+//PdK/eAViWICuSZ042YpqMhoJgDMGX5MZiX4f3ZdfRuMPvpno5hMcGhIynyBBgpFwfYHvvbEGkzOVkW82mvD8MLvFiXAI4NuvLOMvPpRHZSLMGlqWFcmM6dep6XCYU8ww/fFhYZhEJpVKKUJApH1mZgaVSuWOajy9U0FacCKNqVRIiPP5fGRI19bWFur1OjzPi+jN9eFbpmkq55xMJgPbtsEZh6XkLhJChNlZIuK+Hz53c2sTzz33HPL5PCqVijqPstksDMOAPdDGxzPGQJjx9dfXcequU8h/9auoVqtKSkLTcQEogklOPxcuXFDnMRF1ev90Ps/OzmJ6ehrFYnGkbOd6td7DXHh0dxchBUSwLYcK5UnhUDDpuFhdmMfyyZORY6o76xiGoQZnkRTKMAw0Gg00Gg289dZbmB942Pfi3x8k4Rn8ZwQBWltbWH/55V3fL50TtuvB5hx+Pg85sNOMBwgAkGk0wKSEoVcpBq8pPW+7uVdK+GtraD/7bELmExwaEjKfIEGCkei6PrxAwjQ0Iq5Zz2Hwu+XBRN0B/vE/+X8jb3P1g0taXGq2o0xdt9tVutphGtXrRVwiQ82A09PTShJDxD2Xy0Ws73Rf7E6ns2N7DvPvw163jtvVdSgIgn01gsYlHsNQKBRQKBRw5MgR5e5Cbi+u60ay4JTFpgABCLOrTD9HoQ2HYgC2ajhy5Cg+9alPodFooNPpoFqtYmNjA9bGBqZ9H/1uN5JZNgfnPq2fc47iRBkPPvGE2m4ppfKRX1tbi0zbbTQa8Dwv8rmRUsJxHFy7dg3Xrl1T66EghQIN6uGYnp6G9H0cRgt4PLO9+R//EM4PfghoVQ29umEDsAAUB8f7/PS0OsZ6Ay3Jh4IgQLPZ3OEyRS48pV4XOSnhB0GEvIfUWxumJgX6joNqtbpDNqU3yurfTyB7TBUYMO2cCKVBotUCyGd/N5ldEACJ/j7BISMh8wkSJBiJrG3CGjS7EorFIoKYbaAbcDDhYm3lCi52WmoKJDWV6aPf9UY6ul2/jD92t4V+hIkgFItFNUSJMpPkHkMaaR3X8/eo++KXh/Fae722jjdW2/gPr63jh+9sKdehT5+q4E8/NIP75vIjn7cXDhpo6M4ptMRtFel+yq7vJ2gZ57F6I+bc3Jy6zgwLHceD8Prod0Jbx06no4LA7CuvKOIlhIAYyEW4YSCTTkMCkAMCZ1sWjh07FmkmbTQasH/8k3A/uC7cQSaXAbBsC6ViVJZGZFwnxVJKlemnzwfp4dvtdoTgk6NLfNowravb7WJdaz4VQuCT589j0XHQrtXCz5XBwbkRCVwkAAQBtmpbWPn5z4fKUkiqRkHSXe9dxJSU6Jm7949ken1YloUTJ06opmX9c01NqjRcjBICtJCEp9cLh2r5lAUfaNZpyNhgheAstIGdmZkBY0ztz1QqNfR7if3RH0FwDjuTCacCDwg99c8QDrdjJ0GC/SEh8wkSJBgJ2+R44r5ZPH3mChzPhRQDbTll50UAISTADCzxGmamKuhm0+pHV/+x14fJ0BRNcsngnKsfbgoA4j+sRN5pUE86nVbWe5Rxp6ZA+oEmD/d2ux25zdCyoncSQtehNyOuQx1H4A/ObuAH79QO5DoUDyKIdOvTQIdlzGmJOwntFojECZIeGOnVlvh5oQeAlGHXJS1EPLvdLprNJt6r+/jZOsP5poFAMBhM4p68iw+X+pixHJW9n2y1UAgC9J2o9t7gBtLa9FnJGLhhwLZt9f4ZYyiXy8gPJCH5fF4bhhS+ByHENvGUEp4fBjWmaSKbzQ61dKTKBVUQSqWS8u2nzxoNPFtfX4/YatbrdbTbbfVcAHi41YQQItKPQkeI5EYAYAqBTruDK5cuqWNCDcP6xF3C4iCYF0LSwQRY6DATzWozFZzQenTozjiWZaFYLKptpwpKv99H+uzrKqNO+xSAOi6UURftNqanp3HXY49FJhD3ej2l2deriPPrGyj5PjoDe1SS1zDGkc/lhmbZh53jd+L3TYJbBwmZT5DgAwxdiiEDb0fT6dbWFtprHUjnCNYdA2l4kd8uKYE+LNgIcF+qjkpm23FEHznveZ76wSSCp3uGe54Xycpns1kUi0Vks1ll00eaWRpOo5N2alj0PG/sxlT9ubof9qjbb/UA4CCuQ9QouZeUha7H9eajEJc5xafd6rdT8Ba3ASTpky6BIFJH66eMMAWBdH03vN5K4Q9XM+gHHJwBBpPwBccvGmm82Unjq8clPnYybIydfO0szPUNFPIFRcSFkOCcwbQsJd7wGUO2kMfMvfdG9O39fh9b2SwCYHv+AB2TIECv3wubyRkDDwLU63W8+7OfKWmangEnm8hhFZtR1SBCqVRCsVjEkSNHlFyFMujm8kpkPbQGyj6HGX4JLsMG45deeklV3vT9TqD933ecsHIhZUjeDQ6SuxiDCbC6jpwCFHrf9N71qbV0nuhVPBXI0bAr04w4xwgp0RoQcSkl0q6LC+fO4Yd//+/vcK8h6JWBzzXqKAgBz3NVlCMBMCbQd/pgjANSwpISjNYxzAqXvjvG+PwkSLBfJGQ+QYIPEKjx9GdvXcUfnF3Hi1cduIEAEwEWWQ13GVVM8uioyzyAj1rAi94R9GGBSQkOCQEGCQYbAT5qXcEkjzrN0IRMItcTExNqG0jnTKQ87n3NGIPv+yqLThNSaTBN5D0Jia4P2MJT9pTDhtbQQgSFyNZ+sB/yr8shbgZ01yEhAvi+ZqEnJTJMYKvt4x/8ux/hqcV+JLure7UTsYnfRn/rVZZht8X9uikIoiwzuZLo9xMBjBNTXdNMf+u3USZev07r1Ks7dP+qY+EPVzNwAoacKaFaQZiEkAwdH/i9iwxTtof5TB8F10FqcL7o0o0gEGg2Gqoh0nZdXD7/Bv7Zb/5mxJsdAD735ps45nnodtrha+kNlIo8S+Q9DxfevYBv/Zt/E9luvZ9D3zc6CR11nfZTvE9DJ6tquJlhqG2ioMV1HXiDAVVFIbC2toaf/OQnO7aHGnD1ikm/34McBItxAmvb9sC3P7Tk7HQ6WFlZiUydpcBd964HoM4XnfDTuayCj1iAoWvdAcAcuOHoU4Bpn+nnkWmayLz5pnI9Ghw2JT3yPV+dA0a4ceBCINACCnoNWicAyF0sNhMkOAgSMp8gwR0GIQTq9fpQe8dGo4EL/gRe9I7AhQGGbWL+DiZxOZjAR60rOGHUIus8YdRQZH28E0zhSjABAQ4TAY4Y9UEAMNwy0rIsTE5OolQqKX/sYrGosu70ww1Albup1N7tbgcVlEEmeYdhGFjuMPx0VeLVqkAgAYMBH5ri+MQCx9GCUI8dtk2U7c/lcsq9hIh43ItbJ2c3IgAYdtsokC45vvRdD995/Ro4JAIRZl7DgIncN8JLn6fw08suOs/9B8jAU+ujdQ9rVNQJuD4YaxhpJiJHz9XJPbA9QZSu03PJF12Xy+gyGv2x+uUovfyw2xljePZdwBEMeRtgEOj1+tt+4JBgEmixFH7n5xdxd+sVfPnqNWR8H51uVxFSXcJBdNEQArVaDW+88cYO8tzvh3IT3xtN4ORgcT0Xa2trkQy8rt/Wg6P4olcyKHDSqxbx8yi+n4QQgAztMtX7k4CUgp6kLCQB7DhvSSq3/f61Y0FDssIbwA0DpmmAIWwkzefzuOuuu9Q6dLtY2jY9WNSDP3q/etAI9erqBFC3ccdBLp9Tr0c2nXoiQa84FF95FXx5BSk7pdalGqA5SYYYZBCAF4sQrRbSyRTYBDcZCZlPkOB9wvW4jUgpd4yaJ+Jeq9Xg+gI+OEwIGGz7h3xTZBWRHyWZedE7giLr7yDok7yHSX4FHzWXI+s2DAPl8pTyX6cm1GKxqPS2esZtFGFVTWoDUOObrsml5sTnr7r47XM9dL0wu2owBk8Az61KvLIp8J89UsAnj2RUgKA3A1JGL+6KQbBtW3mYkysPXSerSpIGxT2p9dtIj+s4TuibHctw6wRF7yugbY1LJ4jA6BlbIk59wdHu5CEk0O26ajt2nDcIIMDgCiA9pPGYggtqENUrHHGyrRPuOMketuwFnYjnXnoJ2bOva4yM6Ve3b2NA/6GH4Dz66I51xIl9IIFzjTZMLmAZHFKyATkU0RVzgQ1rDov9F8LeDimR32MgmSElUukUZmZmdgQd+c1NmM0Wipr8RzWV6usAUCpN4MMf/nCEzFJzMJ0fQBgI0fEhv3gi8PEgh0g99afENeJ0XhHRzkoJe4gUxGQMp00Tf6c0EWrdwbB27Cg27rtP+dunUin1eZn9rd+G/eqrmC6Xtwm9JlMBNQP3+ygWi7jrrrt2vKaUMtI8rVc86DylxtXymZdgXLyIdLcbavN1Oq//LSXmZudw8tOfjgTEepVK33c8m4XgHOaA6Ksltq3+iNsTJLgZSMh8ggQ3GWdXGnj6zDK+d35NuY08cd8svvaRJTy4uO1uQbZ0w4YnbW1tDbXk2xRZvBMs4kpQggAHh8ARo6HkM+8EU0OJPBAmr9LSQx8W3gmmMMmvaPcx5RJDk05LpRIKhQLy+fwOr2udAB5UZ845Rz6fRz6/7cAipcRL723gd7//CvoBw0SGD9KaA8IrJZquwL94uY6K5eOB+XCwjm3bimC5rqsadInU6OS63++jXq8P1XxT427ceYeIG7Bt0adnEUkWEHcA8TxvKLmnv+OkPZ6BVZlawwRDBoEMJ5WmBmQ8YqnIQhlJzjbwd/7yf4t8Nr2DmFOGdTdSTLe5gUDPE8jaBlKmsefj97pNx8offRet118fa5DR7MwsFh9/fM/zqd51wX/yI6S5QCZtQUJuV1m0AIELDoNn8Muf/RwyRxbRuvDeICAJs82M0TZv+5cHEpg8fRpfeuTDkSFNQRCAWzaab72pvVe27YqiEUAPQOHBB/HLH/kllXGOZ9Tp2FCFiM5hkpfFAy7f99Fut9HtdpXTjO/7YIwpX/hMNguz2URxECiYQmxrv2MoeB4K1Wr4h+/jxMkTSP/ar0VkL9QjQ5p7KaI9DKqmMZAbGVLC9Txsbm7uqAxR0JJOp1XTezqdRiqV2vE5EE8+iaCQB2cc3DDU8Qo/gwFEsP15dj/0MJzadvWRXiM+TdeyLKzk82gFAbC1pSoowyDHsEhNkOBGISHzCRLcRIRuI+cjbiM9V+DfvXgZz/ziCn7tHhN32duDbvYj6XgvKO+Qz/gw8E4Qymd+ybyCK0EpVLozQAQCrucNCEVITTjnYIaJq5jEsRM25mam1Ah5ICqbME0Tnueh0+kobTs9hmQxw+QAo0ipft8oMMbwB2c30HEFpvNpSCkQCAHGBiV3IVCwAzQdiW+d2wKrr0Qy3kEQKIKgEx81yGbI4rruUKnCsG3jnKs+AH3RSYgu6dGDnmG3x8mNXv6nvykL/84LW3j2Yg+GxbQhR9sQUiKQEh+Zs+H2u6g5vbE0/3pVxTCMsYPR68X1DjIiEKlmwoPBgK4fwAt7OlEoFHY8vt7zkbUZHv2lD0N+6HTYvD0gqHEJ0g5Z0oDQ0bGzLAvOxx/Fxse3Kwdx+VCkmsEY8oPPl07oicDrAaBhGMjlcsjlcurzRtW6RqOBZrOpnHFoyi31nJBsKpVKgXEDvUuXYA9ILN54A9jaArRzSFWHUjZkOqy2sXodnU4HqxcvDu2tmHIdGCSVYduBj5LfYDuGsiwLR48eVZ8RutR15gAiQTNdV432X/3TcL/4hW2XpcHnlmCa5g6yrixKd/nOyY8RKBK8lRX0Nzf3PC8T4p/gsJGQ+QQJbhLOrjTwm8+cR73TR86QkH4A3w0ghQCXQN218C9/0cXn7LdHatBHYRz5zBn/CAQYDAaYpgXBA3iDLB3pPhkYIAQcKfHSq2eRMaT6AUyn08pdZpjMIr7opOUgiBN8xhgCCXzrlS1g0JzX6/fQ7/VVGZ3RABdm4RfVAMcbb0L4bkSqEHdPoUZdItyZTAbZbBZTU1OK9OsNobqTCv0dz9DHF92JhdZP+1JfiKgcBH9jqoFf/KsX0ez7qGSt8BwY6MGFEGj2fBTSBp46PYVczlDkULcT3As/XXHwv77WHsibwmC0KyT+3YuX8e3XVvB//swxfPGhuR3BwGEi0jwqJRzXwdraWsROkBaqhhDJvDcf4LmWQJ8NBgtFtO8SQgKuz/DhcoCVK5fUa+pkL947oFc1dGmS3rytO9zoUqy4gwtVYwhxIkvBAOcc/X5f+a/HLR1pAiwAlSmnSaqZTAalUklJ4qxHHkGfcziD1839f/4p7GYTolCINP2qbDpl8KWE63pot9sq2KNgIZ1Owy4Ww4bSmNVkuEO3ZS9SSkxPTWHu4YcjnzH9fXPOVaMrWaHq0hv9sRRAFwqFCHE/6PdQ6ctPjT2pVZ96uxf2EyQkSLAXEjKfIMFNwtNnltHqezD9Plw/9Cqmn7TdJC7jIC6fYdDcPQwDOc7RdARMzmFyjlLWgu/7sC0bgRDheHrKPksDJvNRKeZh8u3MJk3NpKyensGlv/XMGa1Pl4Po2cjdyK8a2BLLmLVdCS+Q4CzMNEux/aMvIRHKnwOAc/iQaLg9TBVDYk5+3JTN1gcW0XaSHl1KGRl6RUQ/nU7vIAVCiIiVI2mS6fow7TohLjfRM6jDFr35UW8qZYzBZAz/2SMl/LMXNlFtOzA4YHIGX0gEAiikDPwXf2IODy5N7KiCUHAyTOpDQcs7mw7+1astdD2JggWELy8BDggJNHse/ofvvQvZWsOJCSuynXrWP245GXdXAQC/3YIc7Nc4eafHSUgYvo+t1VW8+53vRAjgqEzrvbaBM7yAlsOQtwCD0xRXBgmJjgMU0xy//ugS7p8r7CDoRAoZY2ofEVHWh1/1+300Gg11julOMzriASW5qeh9BrpcptPpoNlsotVq7WgC1Rs3fd9HJpNRMjjq96D7qBdldXUVnU4HjuOobfhYrYYpIcJAP7Y/GefIUJNot4uZ6Wk88KlPbXvmazr/3mc/C89OqYoDI/kWHUtl6wPYv/zJHVa1o6b8MsZUgiGbzSqyTpn89xP7If4JEhwmEjKfIMFNgOsLfO+NNZicgxkcfuAj8GNexIwBzMAlUcSpVg0mZxHCG3fxCBtPyyiVJ/HNt4tIcYaJbBaGYcDzfUgpwMBUVo0zDAiIQM8RMHgox7C10reQEl4vwOMnJvD5+59UelzKHBJJ1YkeEUB9QBCREX0ku55VBBDRkg+TLdB6dGvBQDJwmYIjACFYmHlLpQbPEZAitNXrBhw2E5grVwDho9VqodlsqswkkZzZ2Vlks1m4rotutxvR0ROBijei6qV60tnq2Vd6z5lMRpEnIvVxz3bHcfaU7+igfalrevXrecbw6wscL9VsnG+YCASDCYmHij4eKbuwVzfwo9X9n78A8M2rGbQdGzlDwPfDwUd61taUQNO38C++fx4PeW/usKzUbS/j0G8TQuCrFy7gLsdBe2sr8jjKvBJ5NQBYlo2lpSV1XOIyCn0fmaaJB85t4O99+41Q6gYWBjyBhC8kJnIm/utfOYVP31OJkPN2ux3J+tO5Hj9fgVA/7wqGjMVhGTySLabzP14Z0veR67poNptKKkODnnTfdfrs0W0k6dED7EajgatXr+4ImvSAkPYPnatBEMDzw8+j57qh88wg2OGMgSP0iIeUEAgD6F6vpz7blJXP5/Mo//qvI/ja13bYZOp2pvR5aLouqu++G9mPdMwKhULkMzfuHIkECT5ISMh8ggQ3AV3XhxdImAYDDAN+EIBxtu05jdAOjkEgANDqubBkmI0qFArI5XIRl5hKpYJyuYxUKoVuwIALV2AxGVqkSYneoPyu26cFMMFYaOFY63JkDQGDGvEACAAdjyFjAg9kmtjcHDSWaRn2fD6PiYkJ9aNMZE0nProbC5F1PQAAsCNTr5ObuKNEHMesCs66ebjS39HEG75u6H9/3G6ByQDOgDAQ6Wq326hWq+pvwzDUvi2Xy1hcXFQZfJ3cE5mj96ZbX5IDDkmRiORHt2un7aOU2xNV4xIRkhPs9hz6m0CE7lHLwmOzFqRhI5e2kLFJg2+M1HwP+1tlWyVwvpEfhIYCvuej7zjQTmAAQGACF4MCcO4cmNyWjND5Q+vUj3fcJlFyA443kDNJqSpNRBgLhYLSXot+HzMzM3j4E5/YkdEmkJ6cgtGPzRn4f/ypJTxzbhPPXWzBCwLYDPjYHMcn5gUKtbfx/PNRgh7v9WCMKV03BXNvV3v43oUuXrzahycETBbg4SmGR6cEpsxtSYwurfH9MNAkK1Za4jIufWGMKVkWybXI5lVfdLvGcfpR1PtKpVWWXtDrivB7KdyfgZLImKaJSqWiggza3zT9OU7aqUpBoOA0m81iYmJibB17ggQJokjIfIIENwFZ24RlMPRdCdswQlKtlYQ55zC4gZ4wkLUY/g9/5i+jXCoqu0Z9KqJe1nddF74IpzP2fQHPCwmIYZpID4h9mK2WEAJIMYFP5Dfxk/Y0Oj6HwQA+GJgjwJAxJb58VGAhG2a+KZOnZ5+pDE6kzLZtlEolJaGgx4yyaySSpFcZdJJEP+Zxb216/OSWi0vPN9D1OPIWwNlA+jwgnR0PyFoST57KYzF7r9rHuiVfp9NBv99XAUe320Wr1cKlS5eUtCafz6NQKKgG4MXFRZimqYIXyq67rhsJTugYUYOiTrjS6fS+SAoFDTrRJxcTum3czD6RQL03QJfw7OY81Oj7+F/+9TlwIZFPGXAdNzIplPQYnjTBLQOPfeqzyBpRyQid75SJjp8H614K57oFvNPPovf8v4HonQdSBdhcwEA4OZUzhh41hctQZnN1ZQU/fvrpiEuQnumm/UjnJWFOSvxpi8EzefgabWDzHYbmQOuuO/sQiaZzX5fTOI6D94IyXmEn4cEMG8ylgATH9zsMP77o437nPOa9a0qSpR9DPcDVz3WSjWSzWVVJKhQKKBQKkQz/YZBenXjTgKxgQNo5Y2Cmqdx3gO34zfPCqdF6peBm6NgTJEgQRULmEyS4CbBNjifum8XTZ5aRTVnI8m1Nu8GNQUZawum4+PJHlvDJP3F6z3XqjXNPrLyN339tHdwww1I4N9Qo9ZDkCriOxKMLBr5yfAEPbrl4fj3AuYYJIRkMHg5c+uSCgSP5nVk8Ihm6bIEygJR9B6Ca/shlQrd31J0udL06ZStpobI9EXwiLHRZSqfxZ++x8Ltv+Wh7UrkCBQglDuW8ib/1ueN48p7KjtfVgwrymic5Q6PRUDp5YDuje/XqVSwvLytSQuQqkysAdhYT+QIgAkXSaN0AUKvVIk3ClmWhUCiowVlUaSHpyH5B5JCClPjS7TtwAiBlhPp5wrCBWtTAOMyNp1LIwjINOP3Q/chOpWCn7EhWXgJo9HzkbI4//7VfBYSvgh4i0XE7SiLYzy338cy5Prq+jHiwu5LBCwykmYQJH4GQcD1XBW/ZIMC11VV897vf3TG9Vr+k81cnwXpzqq73jhNb+ntYtYgxhn5mGleOfwjC4DD8DoieUrXLNTN4zbgbjStvw+6uR5qDKQimc8S2beTzeeVSk8vlIn0Gw7CfoWSUqVfHbFCZ0h1kum++BffaNWQmyuoxSk4lJUQgVLa+2+2ifuUKpNxulLcsC7lcTq2TsuyHGXwcFPttUE307wluFyRkPkGCm4SvfWQJ33l9FS3Hx2Q2M/BBDiGExGbXRTFj4msfWRprffQDnclk8Bc+cRd+8E4dTcfHZNYeuu5yzsRf//yDuGsyjUc9D1/pdlFvtlGtN+H3O/BdB0J4CAIjos0GoLKcjuOg2+1GtPIAIj/WRJJogAyRV91uTtfOxsm+7gNPMhf9ervdxnEu8OePAGe2rFAb7jMYkHiw5OPRaYHU2uv4WS1qP6e/PjWaFotFHD16VG07yR5Iq9xqtdDr9SJkb6XL8R83gHd6PoQMB2fdnXPw6LTAvTMh0aeGQM/zIts+qhmW9pWefaVs+TCCphM1yrLrOLvSwLcG9pGuH8DgDI8dyeEzR1OYT/tK1kEVFkK8SqJfPlBi+HE7gOOE5DgyxEkCgRTwAoEPTZuob23LmHRPdL2iQEHISpfhj/on4UoDKXihlEcGan6AAENPGkhLHxwx6QsA27YwPT0dIYr6a+syMF2+RFnk3SRdAHYQ0LiLU236YUgzBVs4kIxBBNFeGO51IKwcGpX7sRQ0Iu43dMxpoSbneFVKH+ZF5y79Pco9KS6V0ol7/Ll0boefSV/tu+2eiJ3nh+QcuUIBCw89NFTHrgfzUkp1rg2rJMWD9mFyqcNA+9ln0frud8eaXwAgIfMJbhskZD5BgpuEBxdL+PqX7sc3njmPasdVGWVqvitmTHz9i/fj9ML+vbrHXfdHTs5GnndUu06ki2QoNGgmPhxJl85QFlqXMRDR73Q6qNfrkSwwNcXGM8Bx15bJyUllcxcHbafjOPhzjoNmu4taqwPpOwhcZ0eAQM4dRNpG+bzrk2mJUAOhlIBkOD++3MXvrxvoB2zQfyDhSo7X2lm81ZX4XKuBhyc8lZG0LAsTExORfgAKVmgCZ7/fRxAE6PV62NIaPnWiTpe6Dl/PPuuNws+vuPjt8330/LDp2WAMgZT49hsOfvAux198MIdPHAmzvvS6ZG1Imm3aNt2lxQnSkKlfQtU1YUknMulSAvBYCiY8dF//CX77tZYiwMP08QTGGF437obLTKThDvTxA3ceGaDstNRjOQBTm2YMIcCDAHP1Bj575iUEgo57mDWWWnPu2xMTODc5CQDqfIvLWvSAT2+mjZ+jujRMgOHvvCCR9oG8ZavAM/oeASEFWsXjsJqvIJu21bRUfZBZs9lEr9eLkHc9a79bMzwQ7XeI96EQMabPrf45jVtBPnD1GiZdF3Lgla4CmYGeSj3SD4PCCxcu7HCl2u1v/btE76OIu+GMqoQMI/77yfYf1vyCBAluJSRkPkGCm4inHl7Aiakcnj6zjO+eX4MXSGRshifvD4fuHITIH9a6TdNUGcLp6Wk1DZWIJ+mzM5kMKpWKIpmMsR3ylbi+n7KiRBx1bS05xuhTUOmHXJcAENnSBzCl02ksLRRxYkB8iBjFG0n1v3XiQhlDksVQRSHuF25ZFt7dcvBHZ95EYPiYzukyogC+L9DxGZ6tlbCQb2NuQIKllKhWq5EGWZLa5HI5FRzpZJosCHUiHR8WpBNRCkCEEKizAp63PgSPbEoBSBYSYQag6dv4py90cf6FcyijEyFEwHYFZpjkxPabOOn8Au/kH4bLbUCKUBvOOCQ4zMDBfPUFdDsX0QUiZJ4kPPqxS6VSsNNZbDTnYEoG27DBBtvw4uy9oZRm0GAbMA4uBSadayARzmKrjUIQwO73sbi6bdGzTT7DzD0XAouLi3j4P/1PVQ+D3lSeyWTUPtRJYZwgDiOM9a4L89UfIWMK5NIWwFg4xAwYzDwI5x64ksNgJn7p0ceQYkGEvMYvaX93u93I7br1pB6sxptbdQtM/dhS/4sumdMz/RTAdE4/AMu2wiZXte7h77977Dg677476itlT+xmUzvMtjZeOaL3NWw98WCdc46dNYEECe4MJGQ+QYKbjNMLJZxeKOE3vnA/uq6PrG3CNg+npHyY66YGTmrC1d1dut0uaoNx6KlUSrlrEEHdDVTCp/Xo8hPKLJLzCGnQidg2Gg2VUdRJDv2ID/NpT6VSyimDSJtuAxgn+xS8xPEvX+ug0XVQShuRgIOBwbIMFAyJlivxUs3GJ7ObETvBeGBDzizxwUpBEChyogc0RPIoux2fRkrruJi9Cx4MpOCGx0yG/5He2UQPDk/hlXYGS+uvRKwOgZ3TeHUiyDlHgbVwb3cT1fxdqGWXIBkHg49pZxXz/Uso2S2Y2WlFqIgk6n7t+vRdRxoQTQ6DMRjG9vv+wfxp/HDhQfW6HkwYTOLPVy5iIhsGRvPPfAvy7bchS6XwcURoGU0bHZDPWg2lYgnF48fVsRRCoNVqodXazvzvB0QkvUACgQ/HFXC4hGkYKA22B+QQBImgL5C1GD7x6EcQeG5kYJR+qTfX6rML6NzRM9m6jCZO7gl0LHVLWDqH4v7tdN+7tg324IORx4anz7br0HbwIMFeeilSCaBGap1cx8m2/ri9XHbGCarit+uBj77dR5avIO956LfbkfMkdPFJhQ5jevUhQYLbBAmZT5DgfYJtctimvfcDb5F1c86VKwsQ/rjrmft6vQ4pZYTckwZYB9neWZaFYrGo1qUTXnLJIZlNXNZDS9wZhJZOp7Mjyz8s86dLKDKZDCYmJiKZT3p+z/Hw8loDBgPYYFsdx1HWicqnHxZeqwEnWxuwDLaDaAkhIq4xlAX3PE/1P+gZU3JQCYJAEd24DlwFN2BY57OhmwrCIMMPfASBj8F8MgAM0jSxmVrAtBfA5FCuIhQMxSeYxiUfISGXMKw1wEwjlzJgmyVY1kd2SGqGBTL6oC5uWuCYhic4giD0rpcIdfCMbQcVfsCRtjke++gjKOTCybnW889DXjBgpdMYRu+IkAnGYKVSyA9kNvHKTOQ5I+7b7XGfPFHCt9+oqXNMO9HDYywlBCR++WQJRxajMjcAkaqWLn/RB3jp/RcU/NJz9O2JE3zdBjMuraFzaRT09zhMz06Enc5RPVO+w25UCzz1zwIF4/HGaNL8605CdL/+2qO2ZxThn+z2kBt83nRww0A6naLxwoCU8AM/MlV3t/3wfjb17hdJE/CdiYTMJ0iQ4EAgq0NqviQbRSL3zWYTQgglMSFyT9nx+Lr0hlsiIkTKdT/qvX44dQKpD7sa1oDZaDR2NPIOK/H3Ag430Hz5WTgxVARBxKJRcg6fAbVWB3l7Z0OwPoWTFpLOqKZDQL0uSULK5TLy+Tyyg6FgtL+JUEgp0XYFvveCABcSGdMCGGC4Ljxvm3AwFma5zZSFp776q5gth+vMv/wyMmfPhjzGdSEdN+JgopyRING67z7U7703YjlKVRudrNHtce96knjQ445ZTZzvT8D3gkHAIQdJ7e1GUh8WZrx1/OgHV1SQ8dHlZcz6PpxWKxxopBEs27ZhDvaTz0JP+NnZnUT6MPBXPlvETy+fGTSfWzuaz+tdFxM5G3/1sw/gyEx26MRYCpL0fUf7ijLr+qTTg/ixU3C126RVfZiZ7jCle8TrAUZ8INgo1x8gar2pH6thvQDDyHq4P3cGx7o8KV6diAcKtL/iCQYwBhEIVdkh6BKduMyJ9p9+bm+vjg0l/TeiqXe/SJqA70wkZD5BggSHAjYgTel0GuVyaGunk/uNjQ0EQaAmTtIyjNxzztW6gG0bTpK/6G4ccRARpufuBp1Yx51zyPWk3++De31wKeAEgO8HCERINM2B//ZgbQhgwJQ+TARwHE9ltomEUQWAKhw0FZSmxOqa9WE2oJyH/vdkbUkuQb7vo+e4+IdnX0DL8ZFKG5AynI66TaZDktVyJFIWcHRhJtS8S4nM2bNIvfwLYIjXPBH6wYFAOp1B7vOf39E4OQpxwqVXWHzfx6eLApfeA/rCRpYHdGAGpEmiBxOW9DHXu4jNZl0R0XtbbUwJAddxImb3DECxWAw16wiDLkDCC+Shy9oA4MHFiaHN514g4AcShZSB//3HpmF31/HOOzsHQekkNk7U9aFU1wtd2pVKpfb9fD0YiPdTxAMECggoKFCVrBEBAC36lNt4gDCK+OuX9BkZdj8AZQE77HtDt2o1ggDr6xs498d/vH1bzMVIl8gN+3tUFl9/TNwxSr9+I7L9jT/4JrovvwwZBHuSedyA109w45CQ+QQJEtwwkIxlYmICACJOLpubm/B9X8lLKHMfn5wKRG04SZZCUyb1Zrf9kh5d8pPJZFQQMgwveK/j6TPLyGTCr00iyKHntkQgBDxH4tE5C3/y+BMRPXQ8G9tsNtFqtSLWgHqjKGMM/X4fjUZDBRQAIr0AtN2pVAqTk5OYmprCJ0+U8K3zmzC4CcZDDr7tEy7hBwKS+XhsKYvKREY1BDPGIQ0DsrTdJM05B+McBmfgLLwebGygVKng4cceA7BN1HUCthtZi9/n+z7uEgJWpYN/9WobXY+HDjw8dOARHChaDH/u/gIenfuVyETRyVodRqOBVDo1GIwm1IA0z/e3Aw0/wAvvrOEb3/hDBCL02//oQhpPnsrh7qmMIl7xAWU6KRt1HxAGmo+fKmHiK6fw/3v5qpoqm+YMH1m08dljadwzGc4DoGrLsAz7sKD2VoIeDBwUw8h/+9vfhvPcT3d8nkTkepgNdx56CL2PfXTXc4yqevrfVC3I9bpICwFvUPlQmvntNxk2Tcuw6uZ5njru+hCyw0K8+hBfiPjH7WjjfSjxAGEU2s8+C39tDRACotsdvWFSAkMCngS3Lm7tb48ECRLcUSDiUhqQRl0CU6vV1ORUPXMfJ/fjSnKG2VpeD2hOQK038PI3dnr5Vwom/utf/Sjum82rbDtl/ImUkwSp1WqpSbSO4yiSSkOz6AeaJsdSlpO876mRVieXgVkC807gWs1D3gYs04RpGjANEwwcLTfARDaFv/q5B3DXZBqe56HT6aBthU4sUp+mKwSYENDpCw8CNFstyOVl5ZFOZOJ68Eu/BDz58UboxHRuDV4gYBoMj98zha88OIO7pzM7iFtn8psIDANmNhdpNh38A6QM7TkDYKMn0PUkDMbgBALPXuziZ8s9/LkH0nhsYbzeEv14UuM2kT0KCj9qWfj4vSkwK42JfBr5bFoFqHoPiZ41puqQnqW/U6EHQYTui2fg//jHkUwxAxCnktLzMDU5iZlPfGJkVWCvy/QPfwQ+qIBAP1di5w1J6ag3ZbdZBPF+hd0qA3HsJ0CI9x0A0UZoAu1jvfdFVQfr9W1not3I+iEHLQluPBIynyBBgvcN8UZYcrmhrPTGxgY455HMfXxa6jBJDpFieo1Rkpz9YL9zAkg2Qe8tDiKmRDQcx0Gr1UKtVkOtVlOknbKMqVRK6YFJj04BAhFM2byC+1gfZ8270ehbYOiHWUbGIBlHmgs8NtHFlVfb6E1NKXJpSbmdndQ0wYEQkEKE+5YxMCnh9PvYunRJvQ/btpXrEU0vHUfiFMd+nZhWslm0OYc5ImhzfIFWt48igLRpYqaUVXaXQkhsdT08/XaAx04v4a7JtCJWjuOg3W4rtxs9eAK2mzBTqRQKhcIO0kSNy3SMybt/P4hXAEZZNe7nvlsZ43q/xwP5/WL53/5btIQAazYjt8ttVh9CCOTyOZw4cWKHL796TozEA0DqZz+D/eqrEEIqqdgwTb0E0L7vPjQ+9HD0du2xOlGPW4/GX3fYNsXlXEII3LW1hclBsC7jAcrAESrB7YmEzCdIkOCWgWmakYFNNNSo3++j1WqhWq0q8k7knmQp+jooU6zbXAZBoDJW+2kc1HGYcwJ0BxkAyOfzmJycxHHNQlFvKqZMvu7TT1l9+rHudrs4Va/jyOZlvNbJYVmUEUiAS4FC6yKKm+dw/uUNvDboE6D98Zc2qrjfddGu1QdZSZqwysANA9YgGApaLdiWhR5jqlLAOUer1YpYbOq2pkTwx2leBg7PianV9yGkDD32WejuQ85DzAAqWYbNjovf+8Uq/uojRUXaaVCaZVmYn5/HPffco6QxujxGdxrSZUPDpEWj7qPbiDDqvQW7kTX9+riPi9tEDrOMpIpQvBFcf6x+3zjbtZdDUHcwYdn3/V0tIaUQ6HQ7uHTp0q7r3u02cf8DYK32jnUPOyuNRz8Ga+DcRdCP37DjmXr1NaR+8QokNamHL77z9QbHt/bQtgUoXdJ1PQDTrWN3c/6h90qP0xuDw++/0cFxQuNvbyRkPkGCBLcsDMNQg6yA8EeNyH2n01ETU+PkXvfU1jN51GRHE2H1rOq4uJFzAnTozcO0UBa+3W6r4VLNZhONRkMRUSkleBDgISHwgAR6ngT8Ply3B8f20BfpiJY4zMAH22465Hwz2H/GIKCwTBNMSnCDY35+XhE7CpZo4BXp+3WCT4/NZrMRgp/L5a4rayw9b+i0TgnAcnxMADAGzcp+EKDb6UBozcBBYOC759dw2n8bpUJuR8+G7/uo1WrKdnWY5/m4pHq324hEh9sUPTZxQjbK0WXYfcPcZUY5vcTvH/ac+PPjZJ96C+hSd6KhS+oPoXVMtJpICQHfdbf3UfjAgU2prYim6gGJbf/Y7+1LX4SbTsH/6U/1Z+n90+o2/OIVpHJ5pP/kr+xYN0F/HwBQK5fRt20YU1M79qH+3kS1ipnpaRz9xCeGBnV0GQ8chsl99AbzuOxGb6ylAFv31h+27xLcnkjIfIIECW4bcM53DLLSSW98kBWRfPrRjTvdxCU5pLUfh2DuJ4McBNFpunGCPuy23XS6cRSLRRQKhR0e4r7vwxeAJxkMGYDJQGm/afE8D6lmE6zvRN63RFie9wfSExEEsIVAr93BxpUrO7KzpmmiUqmoTGG8XwAAms1mpKGPqixxgj+OjCL/+OMj7+stL6Pz7hVAhiSl0t7E/+77/2IQrAzoGwMCAfxi8TQWnvgM8nY0IKOqTjzTSwQ+Tpri13WiN4xk6h7lcaIYJ8zx2+LLsNsJuse87k+vuwoNcxkieZf+PH0S8bjLXviVK8s45vvotMOMuU4p+cDGVO1b7b79fD4IUkp4P/4xxHM/Bfbq8/B9+L4H97GPR4518P3vQ/7s50OfIs6fBxxnO8BMp4FYdn+wIWGTdmw4HX0u9Pcct3uNE31dPw8g8vkn2RhJx07UG5iWEiwu/dGbgBPclkjIfIIECW5b6Hp6YKfXPfnIx8k9kdaDSnJc10W32x2boLta1vFGQW+GBYB1z8brQQHvelkIycAhcdRs4oFiE8eyQlllplIpLHV7MDtvIp/PRzzlQ2nK9nuXAHrdLs6dO6f2DWWz9Sy8LtGgqgo1E1KzMllx0rGhvgayTowTfDrGhNKXnxrpgX3lb/1XKJ09B48b4AzIuR0s1K7ueBwPAlimgcc/9fVDr66MktLolzox0x8Tt2iMZ2t12Y5+uy7X0aU6e0lehkE/7lS90m+LZ4P1xkxq6oxnjYfdZw6ar/lg+qqOMMtvgDEOn3Nkc1ksHju2Y/v2g5VcHm3bHkujn8/lsXjyZPT5/+R/Ruv554dbO/b7YfNovw+QQ5XmEKXWPbCYXbzrLgDb8h19fxGGVUB06Jaheh+ObrdLlradMy/BefddyEHvTZKNv3OQkPkECRLcMYh73UspI3aY+iArIvapVCoyXZMWkrK02210u11lK0mZyVsBev+Avvxik+MPX++i6wmYJoNtcAQCuCiz2MISfuOz9+PzD0yrTGv1Jz+B++67sAcVD3L4ECK06NMzgIZpYGJiIuKTTYSCbiOSQLp63a2F3IqI5JNkql6vq0rBsPeVSqWU5IoIfjabHer+whgA00Q9Vdi1qa/YbaCSMXD54oV9WWvuZrmpa99vNoa5xRwUQxs3xwgKxtWzE/Jnz4IzBtve6X0f9iUIcBY2lDqOi42NjbEDkWFVD8d1w/N1yGdYP1NIdtZutyPrCIIAMM2hUhq5sRFm2w0DoIbqbhemYcBOpbZdlwYVr263G5HC6AOqCDpBjw/tom2KN0DTd1SxWMTs7Cwcx0Gz2YQ3SCowKYHdvsNCPdOu+zjBrYWEzCdIkOCOgk7eh2XOm81mZCH/b93Schghiv+I6hnow4JuYRgn6MNuHzb85+xKA7/70xfhCIaZQmbHRNLNrou/963zODmdw+mFEjKZDFp2Cq7vQ2xtRX7EySKQ9oaUElNTU8jccw82NzdRq9XQHsgjaB/qWcRMJoNCoaCGg0kp0ev10Ol0VENvr9dT1pwkYdFdenTZETnGUIOqaZpqABdVB0zTxN0XL6IoRWixOXDq0Tm91CwJ00Ebzz333KEdw9sFcVvTcZxxruf+YbczxrDy5ltoX1lGdoTrE0EyhnQqhZmZmT3f27CKhGq4tUy4DCrIkwBEEMAPAmUxur2Ton0RVDXRJViCqhJBAMP3wQdOMWwQDPR7PdU0TXIWxhjMQVAbd7WKV1Xijcr6YCwd5GglhFCfiVarpfqKCoUCgmwWHdME28NtSvT7MGdmdpWyJbi1kJD5BAkS3LIQQozUlI+St+xnsEuxWFQ/0LpXPbnM6JaDetaTMtGUAdN/bPUs3m5EfNh9h5FVffrMMpp9H1M5GxJhBlLxcylRTHHUey7++R+/jr/x2EwYpNxzD2SjHrGyCxdASgERDBoqhUBjagob585FiJJu5djpdJTemvaDTuay2azKsNMUXMuyVHVEz+QDUASG7qfqCGm7KeNPA7RSqRSmtmooSoE089GX5kA6FN1PDIDJgIxx8zOQo4j0YRDqUYOvhhHpOxXDmlUJUgJSSDiug8APSTydHIZpwras8NxHWFlaXl6OzBYoNepIDxp2xYC4E6FPx3TohmEgXyjAMHgYPAoBgYEc0HVRq9Uix4gSCeMeH3Kw6vV6yiLWHay33+8jk8lgamoqbHxlDH3Ox8q4M8NA9pFHRsrYEtx6SMh8ggQJbhpIkz5uI6g+Yv1GId4USySdtrXdbqssc7FYVM2mNMyJ9N5E+sl/nAYq3Uy4vsD33liDyUMNcr3RGJBqzZYRQCAMfO/8Ok61XoVtGuCWCeNTnxq/yhBzS7EsC6VSCcViEYZh7GhMJjtNz/OwtbWFzc1NAFDSAt3GslAoYHqIntnzvEjTLq2fXHTodahaEAQBgn4HJjcQcBuCh8eCMYaUAaQNCcMFLNNSsqEbkX0edvudTKQPglGuRPHH7Hu9UqLRaGB9fR1ra2tYX1/H9MX3UPZ9ON2eIuFSijBF77nb1pBBgE6zha1Ll9RMh1arBaPVhi0E3EEgH06tDV/PMAww3wfjHAgCcMOAYZpKv6MsIxlDyrZRqVT2/Z4AqMSD7/sqOG6329ja2gLnHIVCATMzMzu+f/aTaU+y8rcXEjKfIEGCA0FKqTJW4zaC7idrfiNB2vphGXKdpOsTaEm+Q97q+mPig6z0YIBK9+O65FwPuq4PL5AwaTqtDKU1iDl4SwBeIHHh8lWkmB+5L05Ax7m0bVvJAUiOZJom0um0mvYrhIi4qPQG/uKkL6f+hLW1NTVsa2JiAuVyGeVyGYVCAalUKiLl0ReSKfR6PQT/4B/AfOVVFHI5BCKACASkdMFNE+WJktJG+04fS0tL+PhXvnJDj0uC0ThMgul5HtbX13HlyhWsrKzg2rVrKpCk4O/T9TpKAzKuPrFK/sLVYDb6frtw4QI6nc52Zr7dwsSg8Zix8DmmwcENAyzmTgPfhxgErjpEv4/uyy9j5b/6v0Te226ZcLLl7Xa7sCwLmUwGjuNga2sLrusil8thdnYW2WHuOQPs1jSe4PZGQuYTJLiBcH1xQ73IDxNEsHYj5jpxJz/xWwGmae4pYYlrzQ+aHdVtJtvtNjY3N3cEB+l0WmXF9CZcahKlTP5hZ2iztgnLYOi7Ay9uzmEYPM7lAcFh8gCVQg4GkzvcMuLLOPdR0EK+/vpCOl99sA3nXBF73SqTZDa0uK6L5eVlNSl4YmIClUoFxWIxMsxJr+nb+XoAAJ8ySURBVCqsTE2jbZrIaDrsYJBFTXLitxYOQjCpp6JarWJlZUUR92q1GrHO1JuSgaj+nXMGzg0YxsAiFIMhS74PIQVSQqDZbODdd98F5xzZbBZTU1MolUowqpvI5/MwtcCXMQbfcSCkDB1tpIT0PMgRTlb++jraP/gBgO2qw7D94HmeCkgymQyy2azKwpumqSqFNzpRkODWRkLmEyS4ATi70sDTZ5bxvfNr8ISEZTA8cV84JfTBxfGnhB4UlFUaR2NO133f33vFNwl76czjt5t7eUYfIgzD2OF1T/ux2+3u8Lqn7aVmVbJopGCIGm8P4z3YJscT983i6TPLEEKiWCgAmrwm3F4Jv+PiVz+6hP/2y39GuXYcdBnl7EP3d7vdkdsbt2qk7D3tI7qPNPPVahUbGxsqm0+SpmKxiMnJSZTLZaTTabiD5kSJbYcSg3MAt3ZAnSCEHjBTI3Sn08Hq6ipWVlZU8zXJ8PQeFjqX4lr0UqmESqWCydVVGLU60uk0pAgbWIUgqU0ou2E8lMNMTJTx+OOPI5fLYXJyEpVKBe7FS+hfvgIrZpUKACybVWeY6PfBLAvS88CHNJyybBbG4DskLjGiRvFut6uqgACwtbUF3/eRy+UwPz+/w641wQcXCZlPkOCQ8c1Xr+Ibz5xHs+/D5AzmIFP69JllfOf1VXz9S/fjqYcX9rVO3/fHbgK91bPmexH0dDp9W2mKKWtH5W09kOr3+2p6aJzc0w+x3nxLxCOead4PvvaRJXzn9VVsdl1MZu2INSO52RQzJr72kSUAiMiADoLIgKp9LtQIGl+fTvIdx1G+9DSIipqOAWBtbU09njGGVCqFj1y6hEnfR7/dBucMjHFwxgYkja5zZVFIFZMENwf0GYkven8FBcb6omfcaT2Ueacm6Ewmo8j3zMwMJicnI58147mfggUBjEYzXAegHGYAbfgWgNnZWfzS5z6nyDQArJgmRnXyGLkcoBF0XixCNJt7etoTfN9XDeSUAKA+ENu2USqVUCgUDtVBK8GdgYTMJ0hwiDi70sA3njmPlhO6iQyzBfzGM+exWLBwfMIcuxH0Vsqaj9Kaj5K53Mys+a0A3cUGiA6y6vf7Ea/7uFafMtPdblfpzocN7NkNDy6W8PUv3Y9vPHMe1Y6rAko/kPCFRDFj4utfvB+nFw6nQkSa+XGmtg6DrqXfben3+4rMt1ot1ejqOI5qwiWSEwQ+mO+D1+vhizAGVT/Qpl2yIMDVlRW8/M1vKiJIJIpsLsddEoIVhW4vqge3JNUj+0WyeazX69ja2lILDTsCdk63pWNNg8+mp6cxPz+P+fl5lEolWJalAmmq5rTbbXQ6HRyZnsbC3XeFnyvTgmlZSNk2TMsM/zZN9VnLP/54hMgTrrdhlyYJ6387roNGo4FUKgXOOba2tiCEQD6fx9LS0lAb2gQJCB+sX9kECW4wdFtAz3Phqx8sASkkTCFRdRj+7r/+Dh4v7WyMutkwBl7HezWB6lrzhLTsD/ogK4JO7jc2NpSUJK65P6gk56mHF3BiKoenzyzju+fX4AUSGZvhyftDqddhEfnDAJGyg5AVvQm7VqthfX0djUYDfq2Gmj0Y0sMAyG2vcZLfUKa+efKECppJIkWEUfezp2VUUEXOPAddbqdqFDBcCqNn14ms6xNJqdrS7XZRr9fRbrfRbDbVrAKygKUBYaRFZ4yphupcLofFxUXMz8+rrDsdu2q1ikuXLqlmat/3lQ3q4uJi2HPx5JNKZ04zEPaDcRt2JSTcK8vwGw24nocg8OH7AYKBBChfKEBqw8UY43BdF/V6KAGqVCrI5XLJ922CscDkmGPqlpeXceTIEVy5cgVLS0s3ersSJLjt4PoCT/zD76PV81HJ2aFdXr8HIWSkdNuHhRQL8OuFt2AZLKLtvN4fdF3KMQ5Jv9nWiQmGgxpkiQz5vh+RJxG5p4ZQ3/fHluTcTk3YhwWS51C2l4gjDQgj4s0YUzIhGkBF5JoccvTqWDAYLKTsCwfSjsPIzBuGcV3BwGFjNykMTUTWJVZ6dYWuA2FgRNUmWlzXVdUcCpiOXryI6UuX6dUBMPBBUJXOZJBOpWDZYYXKeehh1B56EOvr62rGAR1bIupklUqknZb97ispZcSGddglTYpttVrKKrXb7WL6f/3XmHjnHfQzGSXnkZDgjKNYLIaTzKREUK2CP/ooJn/zv0exWDxwlSvBBxdJZj5BgkNC3BaQcQYh5KBBMPwqZwDAGVxIXFpZRcYQEQIfb9rSx9bro+zphymXy0UIX5LFuT1BxIZsHGlAEmUcPc9TVRQKxqgpdC9Jjm1y2OYHixzQcKpsNouFhbA/hchprVbD1tYWarUaGo0GXNdVevlOpxPxgScHk2KxqPY9BVRUMdGHXJG7D2NMBefiBz+A+NnPVLMlVemE3GkZ6jz8MDqPPnqg93yQAIDIOJ1vOlmn3gTXdSOWovp1IYQ678i1KJ/PK1eibreLVqsFKSUsK/Tzn5ycVN9xetZ++tJlzL73HqS2/whUUXGlBAsCrF+5gl/0e8hmsygUCjh27NhQ4p7NZiOfBXq/uxHz+G16gzcFJhR8k+zLdV0V5PX7fSXpqQS+OjaMcxicw7RMGNxQQQIfPC+Xy2FqaupAxz5BgoTMJ0hwSNhhC8g4LJOmTw5uAxBIEzYLsDg7FRIt20YqlYoQ91wut2vWnLKFlGHUp5XGJ5fGb7vdyvkfRNCxKhQKAKA0471eD81mE9VqVQ2yInJvmqYiZfo6Pmg9C6Og9zIQwQeAfr+v3FEoi0/BU6/XQ71eV5l3xhjy+TyKxSJKpRLm5uaQyWQUwaOF/gYAvPAi2M9+DmZZMGLbIwE1kVN6HkqlEvJ/+S+P1UMQ76PRb0v9/OdIvfoqHCkhgnByb0hMA62JNLTqvLa4hEvHjgLAjuZSatymPgKyRiT/f2rG7Ha7aDQaqNVqWFtbU0Scc64CVGDbvYiy+rlcDkeOHMHs7CyKr52FWFmBnJhAIAS8IIAIgjDoUXaSHHang1KxhI9+9KOKsNNimqYi5O12G41GI0LMdSHCsJkJZBtLgUa8IkGzEXRYlqUGSvm+r5xz5ubmUKlMwryyDCubBR80XIcTYwdTprkBBgZf6+NIkOAgSL7lEyQ4JMRtAdOpFFK2rRw0GMIm2GrHxdc+uoT/+1P/icpgUTaHhhJROZ8IxHZj3/Z0UvpRJgIxLigrN24AkJDB9x+maarKDDDa657IfTqdhmEYygmGmlT1JtEEIdLpNBYXF7G4uKhu6/V6qhGTSD5Jm9rtNqrVqiJ/pmkqe8xSqYRyuYxMJqMaO6+l0+hZFtjkpCLIQNRFhTEGsbkJ0zBV5noc6PKOTqejhhsZL70M+9VXIQxjO1jQiKzyWhciJOITJfV5J9Kez+cVideHdXW7XWxsbODtt99W9pD0HRUnyPprSSkxOTmJubk5zM/Po1KpIAgCdDodbGxswNnYQMn34XS74XNonZyDqe9BBtYLrRpLCwvq+5GsHPXXtyxLfQ5GbRsApeHvdDpoNpvqujeigZUck6g64bou8vk8KpWKCvaov8jJ5+EOkidBEISvbxrgLPkMJjhcJL/SCRIcIuK2gIax/aUdtwXUGyOLxSJmZmYAbI+td11XXdIPC/3Y6j+uRBr0RS+J6wuwncEb174y3gw4TvY/IYw3FuN63dP5lUqlIIRAp9OJSCOSSs1wZDKZCMGXUkasEvUMPgA0Go3INFr6TBeLxXCAEABTI+hUrZODrDNd7/a6uHDhwg6HIF0CQs3TnU4H7XZ7aEXgVKeDNOdwMpnQvWewUNNvOPDLAK/Xcc899+JD/+V/qQg5DfOiQGF9fR3r6+vY2NjA5uamSiTEgwNq1tb18qRdz+fzsG0bvV4Pb7zxBp5//nmlnadG2a9ubaIwSFAwnXxzDjtlhwSYAUGrjVyxiPnjxyOSqL1AhL9er6ugp9Pp7Po9qFdMU6mUytI3m00wxmDbtppynM/nkUqllFNN+B3sQnoe2NaW2k4xWCLbtovzTYIE4yAh8wkSHCIOwxZwmOe3ECJC8OlHXEqpSsOpVEo1T9m2PfRHbhTJH3U7/dDuN/tPmv/9ZP8TUnlw7OZ1TwRG97onJJKc8UCa5lwupwwghBBKykEZ/GazqawzNzc3YRgG5jc3kQsCeL0eDJqcawzI9EBqI6WAZOFEUtd10W63Ua/XUavVFOF0HCdSkQO2G3BJ/pLJZDAxMYFSqRg6puTzKqutN+ETfM6Rz+WwePQopJTKEWh1dRVra2tKpkLfM9PT0+pvIvwUXNA0Uurp4Zyj0+mg1WphY2MD/X5ffdcwxpSTDC2T9QaMWh2WbQOQkCJ8DZgGTGP7vBSDoGO3c9VxHEXW9YrFKL8P0zR3yBxzuRyCIFAVmtXVVfV4wzCU1CiVSqnhVRRs0WCnyp/6U3Bz+bHOsXFdchIkGIbkmztBgkPGjbAFJH20TsQoI6+T/FarpSQ6ehY/lUop5whySRkHeoPcuAEANXbRkJ9xsVvmX3e9oCUZ8jMau3ndk+6evO7T6bSSb1ET550gybnRLj6cc5V9P3LkCICQzDWbTTSbTZXBJ9234/TVpFEpQnvMTCYNhlAmYvo+tlZX8foPfxghvZQBLhQKKlCPz3cgSUkul0M6nUbnO38ERwigVgO12Sp9/oDQSwCy38fWT3+KV3/jN/DW3JzSsuvQ5X4AFJkvFAqYmJgIp6pOTgKAanilxmLOOfL5PCYnJyPN+vl8PhLAO44D66fPK2kNZwzM4qo5dBRownCctI+ay6EHvUTYc7lcxD2m3+9ja2sLly9fVpaZQPj9VKlUUKlUkMlkVOMrvZ7jOEovn8/nw++nP/NnwiVBghuMhMwnSHADcHqhhNMLJfzGF+6/YYSCfuTjNmaUHSKS32w2VVadiLFO8HezQTvIQCBd1z9OAEA/vJTp6/V6Y7///TT+fpAlJbqkq1wuq2oLkXvyBacAMG7hSMHU7bD/zq408PSZZXzv/Bo8IWEZDE/cFwbSDy4err8+7cdWq6UIPDWCUlPtw/UaMkGAft+JEHmGUHrDGAMYgyElTNPE0aNHI/KU7KB5Etj+bOsSPNL3E5Htdru4+70LmPU8sBHVND0/zet1pM++DuvIEdX8qZN3agamxtdKpYJCoaD6MRqNBt566y3l6kOacgo+yJWLKohUNSSJjGmamJubgzs1ib5pIDOQjsW3V0qh3MHqjQZWfv7zXZMFVK3QSfuo6dLdbldl4Lvdrrrdtm1F4Mmlh96z7uhDAcuwAVMJEtwMJGQ+QYIbiPfDFlC35SMQ6dAbbqmETj++cZJ/0KysrhseB1SyHzfzr2t29yv/IVnCuAHAnSo7oWOeSqUwMTEBABFyTwN34kOTaL/Ytn1LVka++epVfOOZ82j2fSVx67sST59ZxndeX8XXv3Q/nnp4Ye8VDUDVL3JrIUkNyWpIWkON67rtIVWPTgcBwAB7UOngnIf+4pAwjO0Aifd6sCwTruvi2rVruHjxYsQeUZeIxK0bdYRB2PaxEYP1bzfcYnANYIMsvDn4vJIun3oyCoUCKpUKbNuG53lot9vY2NjAxYsX1WeUJHXUbEo6c6ogTExMKNtcalbVq4ZE7ldMC86gSiGECG08B84vFAABYdOuN/guA0LCHSftegA0Cq1WSx1LPSiggU1E4Cn7X61WI248FBjTcKsECd5P3Jm/VAkSJIhAJ286qDxMJF9v6qMfXJ3k3whyq2fYx4UQYl/SH/IBJ7Kyn+bfUSR/VFBwu/6wD/O6J3JPemd9yikRuFtFknN2pYFvPHMeLSecwMy57i8eNp9/45nzODGVi0jdKMhtNBpoNpvY2trC5uamWmjgFPmKUzCpO7gQiaWhbDTkK51OI//eRXBWVZ8d3Y7S932lYWdSwuk7StqhiP8QUABAjbPUlEnbOdtqYQrYXrf2nKHr40xVAyYmJsAYU/aOV69eVd8R9DmizDo1+hLxn5ycRKlUUlUw+pwGQRCR/HHO4XkeOp0ONjc30el04NRrYL4Pd8RnkwFgg32Sz+dx/KGH9rTw1SGlVMd3a2sr4laTzWYVgc9ms0qWRg2/9L0ChE29s7OzyWCnBLcUEjKfIMEHGETMclppm+wu6QecXDxIT61n1PTM2s2Evh3jIk70d8v8682/uhPQuNu2H+nPrSpfoe0rFosAEJmG2u/30e120Ww2FcEnMktB381+T0+fWUazv03kiTRThjdvSdQ6ffz93/0BPpVbU9l1cjfRyTp5retzHGzbRj6fV8PayCVIJ++6lSwt9PnYLdhRshvLVGSY1kl9DTSVFgDq9To2NzfRbDbRarXgeR6KxaIi4hNvvAm2tgYMiPdQDOQ9jDEUi0VYR4+i2Wzi8uXL6viS7SRJZPL5vBr8NDk5iXK5rLaTCDD1Zug2l67rotPpRORA8YpawfNhBwF4vR6Zshu/lEGgMv57QQgRaVDWtfRUdaD3AIQyplartaPSkslkMD09jVwud0t+VhMkSMh8ggQJIuCcK4KigyQt9INNUgwAEWJPJPtWk2Hst/l3P5afN6L5d7fs//uxb8lPnQZZBUEQIffNZhP1el15r1Oj40HPhXGbr3u9HprtDr75UhfCl2i1HUgh4HouRCC2JRtSIjDSeO6ig/de/i0YbDvwomCkXC4jl8spT3/SXWez2YiePE7adVtKev9EgrO5HLgQsDsdALEMueYuI6XE3Owc7v3UpyKTRZvNJq5cuRIh8FQZ0CU4+uVsq4mylGGja8zBhaZQq9cFUK/V8fLzzyuJDX2G5+fnMTs7i3K5rLTylHUn/T4de7LJpWOzubmJbre7aw8MBYC5XA7m5/8URKUcaun3GKG0m/NLEASo1+uKwFNDL2MMpVJJEXg9s05a+F6vp6Q01OBM7zlBglsZCZlPkCDBWCAZBg0uAqC8pYm8kg6VrOziWfzbqQl1VIPxbog3/+5G/uk+X0i0ux7SpgeTj7dvRjX/7pb9P2wZDFkf0vkghFDEvtfroVar4erVq5HtJUI8jkyKsuO5l15C9vVzSj9N00uFEMopJpDAXwtK+PncfXj+yENgAHzPhz9wdmKMweAc4AyGmcaf+NRnUcmnI9NDs9lsJFjSSTstJGuh6/TedNmLfhkEAfoPnlZEO+ItHz5y+3Yh8V4uhyu/+7totVpot9vKlnLYsaPn6UEPBRO+74PR9g2zYwx23rawsID5+XmVcSetvP56juOoakCv19s+hwdVm7gTDsE0zaG69ohs78QJ4C/8heEn2x7wfV/JZxqNxvZQLM7VeymXy5HXE4OBWZ1OB47jqP6EbDaLyclJZLPZ2+a7KkGChMwnSJDgwDAMY6hlpp7Fp8wdWWbGG21JQ3snYD/Nv7rriusLmAbDp06W8IX7KjhZtkaS3YN6/8ebf0dZfurkf799CcP+pu3s9/sR21TSl5umGXkOPZYCg4/95DkUVlYQ7CIXYQAWxTIkYzhz8hGVFQewPYGZMXR9hrQp8bGHPwSTI0LOKTCNkG5sk/V4Vl5/jd3QfuAB4IEHVNaaMr90SV7s7XY7/Jy89tqObSCSTOsgC0n9M0SOMVJKXJmbw3yjiazjIIh9tvT3wxgDDwIcO34cn/61X4s8joYsUYabpqJSNW7Ye2eMRUg7Xd+PHG5cuK6rCHyz2VS3kwd8pVLBxMTEju0kyQ9VFsnFaWJiAoVC4Y5tek9wZyM5axMkSHCo0JttSZIBbDfbEsmPN9vGSf6dXNoe5rriOALffH0TP3y3MXBdOTL0ucOy2nEiTYSYFnoMkRcikvoSv516JIZJS3RyGxlIpEk/aCGSTJ7cpMdut9uK4JNUhYILkqqkUikEIoDgHG42G04vZduvqTzTGYOsNwEwcMMEBzRHFwnIMBkdSOC+ggffdeADO94L6d51a8ZxQeRet5HUF8qYt9tttNttReB1qc7U1JRaB+1L0vFzzhUxJvJOGX+9Qub7ProzM2i22shcuYJgMMDJNE2YhgFO6x9st7+xAcMw0Ol0lL1mvV5XbldxZycix+Rrr5P3TCZzQ7PZ5AG/tbU10gO+WCzuSA5Q1ajVaqnPAzn2FIvFsaV3CRLcqkjIfIIECW4K9mq2JbJXr9chhFDNg/HBV7d76Xtc15WlUgr3TGcUQScpE2mp6W/ad3QbPVZZ/Gnkn4h6/DrJg/SMMbBTbz0M8ceQfzjdRyRXl66k02kUCgVlZ2iapsraZzIZFItF5PN55HI5TK9vgFc3YadSIGK+/YoSoWpFwoAEZxJtVyJrAgZHqL1mHJIBPQ/I2wxfeXAGJyu22h5924Zdjzdjxq/Htf1k6RknyNVqFa7rqveYzWYxNzc3OO5CVSNIZ07+7OVyWRFOytaTIxOdD1TtoCFWU1NTmH/rbWB1FelCHpxtk9vwvAiUZ7v0fVy7dg1nvvvdyMAwavak6/GM+83q2xjHA75QKAz9XiBb0Xq9rmQ0ZMdaKBRuub6eBAkOioTMJ0iQ4H3DqGZbIqREUDudTqTZNp7Fv51K47rripABWs2OJqsAuJDY6Jv4+j/7D/i4dSWitz5MxHXsekac5BpEkHRpiF4J0N1fiCwR2QR2Vlx09xfyqo9nfokcU/Z+0ukjKyU4G2TdB1nleHY+6PXw0FIZlUI2rHiwsOLhBxK+kCjnTXz9i/fjiw/O7agexBtJdV93PXNPlQbaH+SZnsvlEASBknysr69jc3Nz6DRVkslQw2Wn01GDh2ia7OTkJAzDUPuUCGm/3w+nxQ4kUuVyGaVSCTMzM5ienkapVFINm1f+4Jtoex789Q0A2HEOyUFUxIUA50z5qsd17e+HBeO4HvDDQPKgZrOJbrcLz/NgWZbav8lgpwR3Im6fX8AECRJ8YDDMd16XEriui3a7vaPZVieOt2KzresLfO+NNZicgXOGwA+171JIRA08OC55RdzvdmFgm1STfEDPEOve7zoZJ927/jeRQD0gMk0z4sEfzzSTjhzAjuFF2WwWmUxGZdnJipDWTdp+vZqgVxH0dQ1ziQmCAOl0KIHwfR9iIMsiMp9Op2HZNihXP19M4X/6s6fxe79YxbNvb8IPBFI2x6dPlfHUg1O4d8ZGvV7fIR+ifaV7u+vNzLQ/SPpjGAZarRY2Nzexvr6O9fX1iOxDlytRMEqNlq7rqoz7qVOnMDU1peRonudha2sL1WoV/X5fVaioD2NmZkaR96mpKeW40+/31WTSq1evotPpgC0uwP7wh1XTsBByIJ2i5ml7EEQZmP7c5zD1y7/8vn1e9uMBPwrUf1Cr1ZSzVDabVfvpTunLSZBgGBIynyBBgtsCo5ptdQmK67potVo7mm11kv9+/qh3XR9eIGEaJNXgoQTCZKHtN8JMs5QMBrNxz6mHUMqYezat6lIWuo804JRl10k6kUySduhZaAJNEqZ9To4vlLEl2YeujdcvCUSkdJkNyVdIAqRn+vWl2WxuBxKa/aIUAhIShmnC0KoyEhKnF0p4aKmM3xBA3xfIp22krNFyCt2ClLaDst/k9x4EATY3N3HlyhWsra1hY2NjR48B2WRSXwA1qZqmicnJSdx9992Ym5vD5OQk0uk0XNfF+vo66vU61tbWVAaagi4ir5OTk5iZmUG5XIZlWQiCAJ1OB/V6HSsrKzusH1Ugdv/9YA88oPznaRlnOurNwG4e8Pl8XhH43TLpNDCr0Wiozz1VLIrFYjLYKcEHBgmZT5AgwW0L3T4y3myr68qbzaZyfqHsatwy82Yga5uwDIa+GxJU0zAwOTm543FbHReFjIm/8hefgm3uTbyICOtNr5QNpgE4VNnQm14558hkMkr+Qg2WROCJ9OkSDXKZoeCBAgYKLDKZjNIiD5Oy+L6/Q9pC74GCC3oNOr6MM6RTqYgzjcrgkz6fMZjG9hyBFIBcbD/F+wf0Bmza9nw+r6Qy6+vrWFsLB0zppD3es6APmSoUCjh16hTm5uYwMzODfD4PKSUajQZqtRouXryoyL5eNSEP/5mZGUxMTKggjaQ4GxsbQ60faT9S8EGknTzzb6Xq1EE84Eeth6bH0uea9l8+n7+l3nOCBDcDCZlPkCDBHQfKbupleZJ86JNthzXb6peHncG0TY4n7pvF02eWIYSMNL8ShAg13k/eP6uIPGXS9T4Ckq+QLlh3pNEH31BDKRF027aRyWRU86lOrAlxhxrafwB2ZKRJDkKvTwRdl7DEhyoxxiINvd1uV20vvSYN7BGFAjxuwNqn1jmu7/d9P+LPT77yvu9jY2MDa2trWFlZwbVr11QA5Lqu6gmggIjOFwo0jh49iunpaeVnTi4z1Ph6+fJldLtdRbjpeaVSSTViUuAkhECn01H+/HEwxpB/6WVYv/iFcvahQIobA5nSQK/VABA8/jhKX35qX/vtsHEQD/hR6Pf7ymmHJrNOTU2FE2zvYPerBAn2QkLmEyRI8IGAbpmpI55pJXIMhBlbbloImIlSPoNCNnPdzbZf+8gSvvP6Kja7LiazoZuNHAwPCoTAVsdD3jbw6LTA2bNn1fboTai640w6nUYmk0Eul1PNpTRZkyQ2BF0TTpdEsIFoJp387GnfEYElqU9cd66vm0Ba/G63q6wY2+22GjBEGX9dQkULDWbqSAEJhAOg1MFU/4XWlIPX6jt9XL16VRFnqrrkcjmV6W40Grh48aIi7mtra4q0x6ep6vuNgjySf5TLZUxPTyOfz8P3feVY88477yjpCwUuFERRYyn1E+hBWhxkP5nL5VSzsGEYaPz2v4H7wgtglgUwBgFg2KxhOTiH3w8yf1AP+GGgLDz1EXDOUSqVUCwWk8FOCRIMkJD5BAkSfKBBmdq4ZebLF6v4315Yxvff3oLnC5gceOxoDl+4r4zTC6WIVGcvy0y9eXfadPDXPj6Df/KTq1hv9cABGAzwB5NMsybwlWOA2LyMi1UZkQXl8/kIYU+lUpFppDqZ1i0UdW07PW5Yw2m8EfQgREkfytPpdFRWmsA5x9TUlCKq+Xx+V1nFSjoTOsdsbYVtrjRlVbsOANDcZah5lvToKysrWF1dxdraGrrdrnLqGRbc0PGi6gBNRD1y5Ajm5uaQTqfR6/WwsbGB5eVlbG5uotfrKbkMZe5JtkQVEDpWACINnmTXqts/ptPpyFwG/Rxw0mn4tg1zenrX4+BvbIx3wA4Ju3nAUwa+VCqNXe1yHCcix8nlcpifn08GOyVIMATJJyJBggS3PVxfoOv6yNrmWBrzvfCts6s7hjr1fIk/eqeFn17p4W9+xsRnTjDVdKdPEiXEM6/6cKb5IMBfvcfHT1cFzm4xCDBkDYaPLqTwK6cKuGcmq2wBdd92ff06KdKbS+MEfVTW/DBAWVN90YkqEFYOdMvDcQcL0f6zP/kJpClrDoAPdPN8IDMBtisK3oc/jDfeeEOR91qtpjL/pImnxt0gCBRZpmNHuu2FhQUsLi4in8/DMAxsbm5idXUVZ8+eRb1eV5IgyuBTZSGbzaJYLEYqAbofPe0D0nZT1l2ft0D+8VRdmZiYuGWzz9fjAT8MQgi0222sr6+j3++r5uFCobCrk02CBB90JGQ+QYIEty3OrjTw9JllfO/8GjwhYRkMT9w3i699ZAkPLpYOvM74UCeScQRCoNZx8Y+evYhJawlHCkxJSKjxlBo4dVcUXQpTKBSQSqWwuJjBn/xYBgIMTgDkUiZSlqlca0ZNX72erPn1gPy7deIed1KhBkw9y6xLfahCAACOF6gAzGAy0lxKAYFpmjA/+1mUn3xSebQ7jhNO8+x0sLa2pkh7rVZD78J7YFeuImubMHkYSBBRbjQaqllSSol8Po/p6WksLi7i6NGjmJqaQjabhe/7WFlZweXLlyPNr7q0iVxuJicnlWwG2Lbu9DxPyWuoudi2bXVftVrF1atX1XuVUsK27Yj/frxvQQUFrRakEHA9L+JmKqWEaZmRAVE3Aq1WC7VaDVtbW+j3++r2cTzgR8F1XWxubmJrawtCCBQKBSwtLSWDnUag8QffRPvZZ8d6bP4W6JtIcOORkPkECRLclvjmq1d3ZM/7rsTTZ5bxnddX8fUv3Y+nHl4Ya116c+y//snbaPRclNIGPN9TWmp6nCUkah2O/9cfvoxPZq+pZk+S65C2OpfLqeZP0qC32230+33l705yGdu20extO8fo2zXqtsN47G7Pi1stUqMrgXoQSPaTyYT9BLVabdd9/V7dx4+WPbxaFfCFhMGA02WJPzHHcLxkKFkQNdfScen1emi1Wmi1WmpKKu33plHCJT6Ha3wKwmVgbYGJ7hVMt99Fzq0pqcf09DTm5+eV1t00TTSbTbz22muoVqtYW1tDq9WKWGxyzpHNZpVbSiaTUc2WFFyQ5z0Re70KQo20erASBEHEXpQxphqa9cFUtNDzXNfF3WtrqPg++u225jIkwRjHRKkEHDL3PQwP+FHrbbVaWF9fR7fbhWVZmJ6eTgY7jYH2s8+i9d3vhn0Tu+D97JtIcHORkPkECRLcdhiWPScIIbHZdfGNZ87jxFQOpxdKiqzrTjC6OwxlbD0h8cdvVsGkBAayDs914bhuOKiIh/IOJi1ccHP4yrFZFPNZRfD2ypqTdp6I8ubmprI1NAwj4iFPGdqbgSAIIraWvV5PbReBGjh1zf5e1YG4882L6xK/f9lAPwA4C3sF3AD42TrDK1WBJ6YauDfTVv7heqWDtOO6pKnf7+Ndr4T3SqfhcxtMivDYcRubxbvRKd+Nrxz18cTdoVtKEASo1Wq4evUqrl27hmq1qjT9QRAouQxNUiUXIL26Qk5AOqkHohIoxpgi4RSQANv6eNLWU8CgW4dS8EjVAL0fgs4NxhjsmIxHdx66XhyGB/wouK6LarWKra0tBEGAiYkJHD9+PBnstE8wy7rl+iYSvH9IyHyCBAluOzx9ZhnN/jaR7/V6kNgeoprlErWOj//nN1/EX3oozBiSLWAc1PyaSqXgSANgNaRswE6FZIkbBrKDIUUqee0KcAZUZueRNaITWvUBT0S+dBIGRJtTiczpEhPScXPOIxNth1lmxknkbtcp600EudvtotPpDCathiSd7Bv1AVGkHd/tNeITU4kcA2EW9s31Dr79yhW4MkApzcC3VxL6sDsM369N4PhkBgW/oaaeZrNZta26xj2TyUCUFrDMToPBwmSaI6VJlISQqHYc/Me1FObzTYjNy1hdXVWTZznnqFQqOHHiBMrlslriMiEKYAAMzZrrAQtVDHq9HhhjSkNvGIZ6DAUj+gJsuwqR1EZvcqbzxzAMXH31NbQvXkRmn1KWvaB7wNfrdRXMxRuBDzKISc/Ct9ttpNNpzM3NKclZggQJrg8JmU+QIMFtBdcX+N4bazA5U3p2x3HC8v/AspABCGDhx+81cGTrDLJpO0JQyUUln89HJqmmTAuWacBxQk9yBsAeUsruBS7yGRNfePKzMDkifubxoUREijzP20HudeKvk30AkUoCkVjf99Hr9YZOth2VxXddVznLDHOXoZH3tE+oKXMUSAKjv1fdX55Iq+u6kQz/77/WRscNUM6a4AiDmEAIiAHxTwNouwzffaeNj/DLqhE0CAJVqchms1hcXMSxY8ewtLSE33krADu/hfmcDbBwH/f6PXheWAmQQYA6s/HvX76GR3AZ2WwW09PTqFQqmJ2dxfT0dKTaQLp2Iui9Xg/tdlv9HZci6RIY3/fV8Uyn05EGWSCsbND5pp8H+vWbrQ8nS00i8NfjAT8MnudhY2NDVaDK5TJOnTqVDHa6TgQiPBc934MQUgsqfWSzuaHfWQnubCRkPkGCBLcVuq4PL5Awje1hRoyzkAhpRuSGACTjkIaliBh5ntdqNWUlqE/hZIzhhG3hhaaBZuDC4DtlDBKA60t8fKmI2uaGItWkmY6DJBNxkq/7xxN0ckcEL51OKzu++OCrVquFra0t5ahCUhJ6LSKZOqgRdxx3GdLO6+vThycR2XYcRwUIFDRRg2kmkwE3bbzyxz+EbZpIWTZ6/R6cvoNACEgpIMUgS80svNvP4rTpI5fLYWpqCgsLC1haWsL8/DxmZ2fVpF/XF/jpd7+vgjrX89AgQjoI6jjnMBjDlr2IJ375FCbLE4q0x3XtnuftGG5FQQ2R8vjkV87DoU06SY8HbPpttwKBPUwP+FFoNpuq/yCbzWJ+fh6lUummDna6XZtEaXAY2btSIE7Xq9Uq7nr7bUy7Lnr1BoSUEYvWlJ0CEjL/gUNC5hMkSHBbIWubsAbNrsBgKma+sEO64Pd9ZE2Gxz76CDwn6jKjyyLoOjU7HhNpvMLm0fY4skYAzhjAwny/AEPXZ0ibEkf8FfzoR8uRbdP9wPVFz6CnUikUi0UlnyDCHM9293o9dPsuWn0XGZPDMpiyWCSCSOS91WopAqDrri3LUpNGSUKSzWaHapOFEOq1aXgWBQNE4PXsdCqVGjpddti66103EoAZhgEhJcTAz51zDsY5JAxYho0/9YWv4MTCDCqVivKMp6CBjlmt46DveGCQ27Ikg0eDNG6g4wlYBsep+06jkktFyDrnXJ0vdB7QsRhWFSG7yHK5rCbI6sfjZkF63p566P9/e38aHFmWnmeC77mb77vDHVusmbGgMpnFyqxNFLdkppqqqiy2JJbayOkRZWM9Q9nMdI96RLPpYf0Yssd6aqaNJvXPHjPNL8m61dPKHrXISpGsrkWiuBWrIqsqMyIR+4YdDgd89+t3OWd+uJ8T9zocgAMBBJb4nqqbADx8uX7dAX/Pd9/v/WTzI3D4GfCjcF0X6+vr2NjYgBBCRXwe12Cnk9ok2uv1QnMYhr/KiE/OecgOJ/tHhBCYandQDCRDBfGGel2IlwMS8wRBnCosQ8M718t4/8YiOBfQNAZ9SIRwLgAm8JXPzOKdt19Tl0vhHKyuSrEW/OCMLNj4/z0U6Pg6GAQ0CHABcAhEmI+fiW2A1TpYbWhKGAYFdjBKcjchI6u/w+J/oQV891EHf/m0BY8Dps7w1y9l8DdeTaJstVGv19FsNpXIBvrWiEQigXw+r7LOZSyk9OWvrq6q68r9kgJWnkEI+vWlZz8ej4emzO41JCuI67rgThc6BDquD1cfFBKFgGmZz8S1pqPhcMQsHRPZNHq9HlZWVtRxGt7y6QQsQ4fT6yfDgDEU8gVgYI8C+oOluNuDZejIJWOA4MqSJZtyJfIMjTxe0t6TzWbVsK6T0KCZfPvtsa7n+T6cT38aH3300bYM+Fwup/Lbn1do1+t1rK+vo9FoIJFIYHZ2FplM5kQMdnrRTaKcc3Q6nZCtLVhdHzWHAYCyz0nRLvtD9oQBmqZD1+XfIeOFnv0gTg7H/9tGEASxT7721iy+dWsV1Y6DQnx0mk06ZuBrb82GbscYUxNbpV1jFO8C+NXlfob9tz9Zg+P5MDSGv34pjV+6FMdM/BV0Oh3V8Cgz14c/hIPZ8IyxUCOjzBKXlTopFH+8qeObCwa6ft8tooHD58C/+nEH/+YnPn4uvYHX0w6SySTS6bSqEsuUFPl4sglVxjp2u13Yto1ms6nsMXL4kUyqyeVySKfTyGazSKfT2xJGZAU7WKkPbjJeUdpxxJ/8CYwf/RiMMfwXWy7WWi5MnUHJ7YAt6qNzr+OPJn8a78yVcPni+VD1fCfB+e6nyvifPlwCFwJssOCCEKpZmQsO1+P4/OUYOq2mWmzJY2WaphLuUuCP04dwnGS++t6OVeRWq6Uq8CoDvtPZMwN+v0PXZBW+MhDC+Xwec3NzO8ZTnlbLyzDD041HVdWDfRWjCP4+Bs8WjiIY/yqbkEulEs7VGzDqdSSKEyfCukUcPyTmCYI4dbw+k8HXvzKHb3wwj422o3LmPV/A4wLpmIGvf3kOr00fbHAUALw2ncFr0xn89pfm9iV0pC9fivvh7+WHvvwQlyKy1+thxTbwbfsyHHBE4aomXFPrC+CuMPDn7TKuTvaQSPTjD6V4k55ueTpeRjtKcS/TcJLJJGKxGJLJpErJ0XVdVQfb7TYWFhbAOVeLn2BfgLSVBG03spIvF0ryvuv3H6B74waYYeBVAcy6fn+KKxAQ8oDu+3B9jr+4/Dn8+hcvqSmt8ixK0P4SrKp/sQT8G4Oh2nKQjevQmQYwBo0xCAHUOj6ycQv/yTuv4cLgvRCcytvtdpU1Kp1On7qqpkyJqVar2NraUhGYwHgZ8Psdulav17G2toZ6vY5UKoXz588jk8nsuegZtrwIz4MY6uUYPCF0fvSjkPA/THHvcx+e50PXNRh6WP5IW8uoarr8eSfRPQrZUzIs3IM9MsEoVNmELRdepVIJ5XJZLayDMaBL3/4OWscwOI44uZCYJwjiVPLeG9O4VEz0q+fza3B9gZjF8O5cX4w8j5APYhkaLGP8OD5d19WAoVHISnyz2USj0UCj0VDi++ZdH56jI6VzaGzwmIPPayGAOOdo+xr+dMnDzyXXQlGWtm2rMwFSoMZiMeWzl/8mbQBBv20wjUZef1igBb3ysjm1UCgov3UikdgmhjuGgV7A6uD2PGy2nX7THvp9CAIC2W4Dusbwv/9rU0j7DTx8WN3WGCwXETL33TRNnD9vwIvl8f/8w9uo2x4MjQ8t6kz89t+8jss5C7VaDY7jhM5E7McudFI4rAz4cYeuua6LtbU15YUvFos4d+4cYrHYvvY7aHnxKhWIXg8YPvZCwFtfR+vf/bv+j8/pZxdCwHFd+IFkKc45TMuCZZngXIB5HlY+/hgPf/Vr6BuzwkQGW2Hw89arr6L6qU+FriOTj4Jbr9dTC+FYLIZCoaAmAsuv8gxhJpNBJpNBOp0ee1G5374J4mxDYp4giFPLQavnLwrf90PVvXa7vS1dJpvNYmZmBlY0hv/XzZuIRXxkE30h7/kDP7vP++Kb9f379zoR/DTrQgPflmcuRTfnPDSYKOjtl2I9mOYjxXvQNiMFfnBKrOd5qFarWFpaUosJKZBlio20qMw+fYK066Jd76duCCEQFQI9rsFjOhjToDEGQ2P49HQCs3PFbWkw0o60E1/96RlcLiVDi7qoCfzCq3l86XoOVydMOI6jrAp7+d5PoiVkPxnw9T/4Jqp77H+t62Lj0SbemJzDX517HaZuwtAMJC0DAEO14+C/+uYt6O0qCnoXmUwG58+fRzabPby+AcaA4Yq+70OLRkOifxzk7IR6vY7Hjx/j8ePHKHxyC+VeD81qFWIwJwLoLyA934fgHGCAKTiMeg355WVwfffnpvkcnAs8mp1VVXq5KI5Go8hms5icnFQzAoLHyjAMpNNpJdrlYLKDHM9x+yb2e13i9EJiniCIU89+q+dHgRQUwdPzyrc8wDRN5HI5FQsZj8eViK51HHgCochNx3XRarWgMQaN9RNfdI2BaSZeufYp5BLP0mOC3vJgQkswLSco7mWzq7QPBCuXw6k50rsuoxhluots1ms0Gmg2mypVR3ry/0GrjU97Huqe92wGwCDi02IMqVQaus6gO4CBfipPNBpVAim4yNiNqxNx/NbbF/APvlBCo+sgm4giGY8dyPd+UlJQghnw9XpdHRPG2K4Z8OPsv+9yfMZzAXD8+dQ1eN6ggjuYvMagYcsx8fs31/F7v/6FkT77gxJcHI6L53kjPerNZhNra2tYWFjA2toatra21Hvv16qbmBACYmAXY6z/+6FrGrQR7wmua3ADz7N/fQBggz4MwGw2YRoGzp07t2v1PBKJhKrtmUzmUFN9duubIF5OSMwTBEEcAGmXCW5BkaJpGlKpVCjTfbdhTDJys2VzOF4PHWcQA6nFETd1RC0dlsFQ63hIRHR8/s1XoTOoRlcpxOXX4U02pUoh7vu+Gt4jFwLSN9xoNFCr1bCxsaGsQFJA2bYdegwhROg+gv58o9cDG8RPCqAvioSAGPQCOL2+NSjqeVhZWsQP/sW/ULeXiwhZ8ZdV/1gsBtM01cJC13UkEgllbSpPFJ9bNL3oFBSJ4zjY2tpCtVpFs9kMTYeVlqZxMuB3238BoLrZQbpTh+DPzriEX0sBITT8+wc1/N4//m8wM1XG7OwsZmZmMDs7i3w+H0pDGk6HkgO/5BZfW0XE89BttQAIGJ4HLSS0mdo33/fh2nb/31wXT54+wbf++T9X+2/btmq+rVarsG1b7bfs8UgkErA2NsA6HRim0U98GSxGgzMINE2DxjRopgXmuNAiEfRj20Womq8SkhhgmOHEmEQiERLtmUyGptoSLxwS8wRBEHswjl0mOIxJRhnuR1RahobLxST+ZLMvEgfx9hACaPV8dFyOfMwEB8Mv/9Q0Xr18adf9DTaLymqo9MvLiq9snNza2lJxl8MWG0k0GlUDpqSQHpXTHrTpGM0mgB6YxsAEIBgLDbjpD/ySNgOmEnHa7fY2m1Bwv+SZA9lwK4W/FPnDcZ/BCa/BRkO5BVOAHNdR4jY4UVi9KAj17j43LyIDPggfVJnZ4HsAARGvQ75EDBy+ADZqDWxVK/jxj3+sXl85YEqmKeXz+ZCAlWd+fN+Hbds4v7mF0qC6zrmPpOvCAkJnG4QQ0DB47zr9RCgT/bNKCwsLSrzLmE35HpSLO/keiEQiKJVKKDaaMJstmKk0mKZB0wZnt1i/T4P7/QnEGDymJjh8z+unIPl+34ojBKJR+bz6TejxeAKvv/66EvAnIYKTIOhdSBAEEUCK3mDFfT92mYNyc6mOT5b7EzkFAJ0x1R8ohIDPBTbaDopJKxS5KQWwFOVbW1uqmi6tL0GhHmx8DVpnTNNEsVgMXRY8JkHfvBTXwYpusFIv022SW1vQOl1YViQkiDXGwDTtWXZ7r4disYjPfvazKsFGJovIdJ6gjUiKRbn1ej21D8E40GD1PnjZMPIMgGVZeP3pAnKeh267NUgTGlSOB5VfTQp638fm1hYqP/7xtsfcaZPXs20bjUYD9Xo9lDgUiUSQy+VQLBZVSsxeswrGQUCoKbv9GM9BBj/34bnes7Mm8NViy9MjMLiLtaUFMOFve31XVlZC043ltGJd19XlshH0N5stFDiHO0jbGTbZhGw3QsBzPXi+D83zsbK8gh/r/WNsmqaaoRDcpqenceHCBVy6dAlTU1PQdR1L/+i30Hj4EGg0gMHz9gexpYEDAzaIJTU8H61mK7TA0HQdlhWBoetgmga/00Eym8XMpZ0X0gRxHJCYJwjipeaw7TIH5f0bi+i6PgpJC1sdFz4XobAPuUdFo4fv/9H/hD8KiPbgPg/v+05CM2RvCFTWg4zyNu92ma7riMfjSKVSiEajyDx+Ar26iVgs+sy3rJ6LgBAcvi8glw2yaTASiaBQKKgKO4BtE1mDEZVyASAz7mUEoBT6QSEqj0vw+AS/v9hqIeP76Nk9JXJl865hGIhYVt8/zTlqW1u4e+PGtrMYw8dICKH6C+SUXonMvY/H44hEIijfuYPCo8fBWw++hodhQQisnj+HxUuXQmdDPvfxx5ju9dCtVNT1+jdng7NFgCmfM9OgD2JGOefgggMC4ACEzhCt3kNlbWXbPAH5Phm2WwUXioZhKHuUYRpgngfdMPqzE1wXbHA2YOhAKUsYG7we2WwWX/ziF5FKpVQVvlgsYnJyElNTUygW+7aqTqeDer2Ou3fvol6vwysWYH1qTt015zLq1Eav58AZDA8DA3Kuhwn0zxIZuh4cf9DvVzkBw8IIYjdIzBME8dIg7TLBJtWd7DLJZFJNPD2q+EIp8jY2t/BHHy8C3IfOfcREDzYHPKZDNuBp3IUAcG/Vxv/wb/8lNIhtFedh7/phwhhDLBZTAk1OhQ1+P3ysrMgzi0LfixwcP/+sWsw4R7PZxP3799X+S5sGED6DMGphIu0y6XRaCelgU69s9g1GeQ5/lULU7tkqtUcIEZDSg03T1PfmYJoqgFDVWgpb2Rwsm4ylnSWZTKpBX3JwlVyQZO7dx8STJ+CaFhbjQOhnXQhUq1V8uLUVsiBdqdUw6fvoOY5q3ASeiVLGGHTRX1J5ot/L0PeO9+0nvs/haREwz4b25K/6DdiDuNNgytBw2lHwdZc/CyH6DdGOAzE4m2PskE7EGMDEYEGYSPTDbnwf6XQa165dw9TUFCYnJzExMQHHcVTm/b1791Cv19V7RVp7OpOT2LQsVKvVZz0ImoVIOqXsV9Ki9fM/vAG2ZvffX4xBN3QYhglzYN0hiJMOiXmCIM4k49hl5Gj7YNX9MEVwr9dDvV5XFXT5ffCr67roCR2N3uvgALqeD+56YK4LAwgJH65bgGaAmVHovF/dlZVRDF131Pe7/bsU48NiXVaNpYVCXj94n8Gfh2G+D6PRCF0WjAkEAMY5otEoSqXSttsHm3iHm3ulCJfPf3hRE0yzkeISCNuGOOewbRutVgv1eh3mw0eh56QsNoP98TyvLzSFgOv204CCVh55X8OTdnVdx6vLy5hZXuk/7qDJkot+3CEGU2szm5uAEHA0DT0pJJUoF0rbp1x3MASpvxiV+6DL5z/Y79Bx5xxi0IuhCw4NHD2Y0BigaQIcGoQOJDWBtyddvHLpb2Bzc1PlzEtff3BYGOdc/c7IfZCLPPmeyj56DHNzExnOAd+HLucMBBYqTJ1AYANBbUDT9X4V/qd+CvV6HY8ePcJHH32kHluehZG9IBsbG6jVauh0OuCch+YtBPsOTNNUC8BsNovy0wWY1Sqig+bi0zV5gCBIzBMEcUYI2mVarda20eqHbZdxXXdHgS6/9nq9se7LQF9YeQPDCdM0mCP2zWEmTOHiwswk2KDKHayCSoYvi0ajoectK8PBhBiZUT/KriN/Dv7bqAWCFEtSBGuf/xx6ljVIBoGKBZSCUzWUMiDzuc8h+6lPbTvDENyClw2fgZDV9KDgH075YYzBMAzlyZdTeRlj6rpTDx/C3NpCvlAIPLe+iBbSTiIENK2LaKTfFFytVtXrLiv6QN82JI+zpmk4/+GP+lV3XQ9V3PuPMHgv+D6YELA4hx9oKmWDAyWfv845SqUSfv7nfz5klyp2vgWt00E0Fnt2m9DX/v3pjoOLCR9v5hzcbVnwhAZNcJw36rjIVxFZq+Lh4GyGEAKFQgGlUkl54mUvgxzEJW1ChmGoVKRut4tms4kfMKA7iGbUGMP5bhcZzwt55+X3wnWh1esQDBBcoLpZxepPfqJsSvJ1q9VqaLVasG1bnSGQFfdUKqXeG1K0x2IxxGIxlMtlZc8pFotIpVJY+mQerbt3YVAVnjilkJgnCOLU4XmeqsbtZpeRVpn92mU8z1ONo0GBHvxepmocBjoTOKfXcd8vQAhAHzG8pp8wouGi3sT0ZDn0POXkyGBEXvDrUSRuBC0lcnNdF7ZtK++68Wu/BvPv/T1YlhVaLBx0Gx5iNUxQ6MumVinUZTNwo9FQFf1erzeotDMl9oQAuODwe88886GFzOC5W76P9fV1fPe731X+fdd11eMPH3NN0zBbryPPGNqmOUgr6jcCa1Kkaxr0dhvMcaDpGqJDEYfBRlYAMHRjWwa8aRr9Mwd7DECCAHTew6f9e7jMu7B9ARMchuifsTIG7xu52JOvoYxfTCQSKt1oc3MTT548wcOHD/H48WOsrq6qsxGGYeBuuYy7gffN61tbmGu1oWky/72/P/LY9s9Y9L9+tLiIH/23/62qsMt9kf0U0uYUTC2SZwRSqZQS7XJi8U758DRRlTjNkJgnCOJEc9h2GT7wZ+9WVQ9GBL4oXtU38NTPwoaJqHD7Ym8QFwimocs1pE0N/7uf/yl8+sIvKqF+FM244yAr/8HGVBkNmEwmYZrmkfUaAKMXE1KgdzodFb25sbGBVqsVapyVaSsyylJ61l3XxZvVKmZdD9jaCtgthLKBqIWEEGg0GlhZWVH2pFwuF4rElAIzEon0Gzerm9DrdcSi0WfCVe67EMBgsJcGgPv95xJq0NUGe8QC2wg0n8MMvYeDV+yrZo1zaExDPp9HqWSoBCLZcCyEQDweRzabVUOPIpFIKD1odXUVm5ubaDab6HQ6mJycxPnz5wEAjUYDGxsbWF1dxfr6OhqNhhpCNl8s4vbERCg+NRijOmzj4bVaKNnHsizk83nVCDs5OYl0Oo1sNhsS77JCvxc0UZU47ZCYJwjiRCHtMrLqvpddRgpHoC9AWq0WVlZWdhTqwWE8x4lhGNsq6FfqFv75zTY6rgVD12DqGjxfwOMC+YSBr395Du+9MX1s+zws3qW1IRqNIpPJHIl4HzWUaHhrt9uhWM52u63Eoe/7MIy+WPVFv+/AgA9T2+751zQNy9NT/csB1RAqRP+MQF9Q6qrqX/z0G/g//eIvKsEu03iCFhQZz9jtdpVA9X2/L+Z5P+ucof9+CNlOhIDrecraE4lEYGijP7JN03w23+Dnfg4il1NnFKTFyHM9uJ4LwQcLByHQuHZVVfY1TVPZ8cGJpcEF0Pr6urK3NJtNNeBLPr4cOCbtMJZl4fr163jzzTfh+z5WV1extraGSqWCzc3NfprMgGAvhHw9glOIZeOt7IHQNA3dbhdra2totVrI5/O4cOECEolE/8zJ4KzMONBEVeK0w8SYn2qLi4s4d+4cFhYWMDs7u/cNCIIg9kCOaQ9uw3YZmfoRFBejLDDBLPXjRtd1pFKpHW0vmUxGWRSGubVcx/s3FvHt+TW4voCpM7w7V8bX3prFa9OZfe+L43F0HA9xy4Bl7L+5V4p3OTk2OJDpIEkfwQXB8KRQeZlcxHU6Hdi2HYqglCI92CMhp9JK5D5K8VdjSTzVp7EocuDQoDOBV2NdfCbr4HLODKXy6LquHocPGnNNs3+dXC6nFi3ybFGr1QrZveTAq2D2vhTUV7/5AcoLC2hZJrjPVVqOruuIRqMQnMNqt2F4HjzDQC8e71tvVIpPf/CRpmnA1hYif/2vY/af/GMAwObmJur1uhoIVq/X0W63Qw240rsvhFCTS3O5HNLpNCKRiErUkcfScRwl3Gu1mhrY5ThOyL8uFy4yjlIee8dxQmfRZIOsPDPSaDSwtbWFbrerFkDSmmMYRuj3I9h/IB9jL7FeLBZDk2vL5TLFTBJnEhLzBEG8EIbtMtL6IP9NRgdKy4MUbDJlRDYXHjeMsV2FejqdRjKZfO4q9fOK8JtL/UXBd+bX4PL+ouCd6/1FweszOy8KgpVwOQxIived/MajqudSqMvXXIrzbrerxHIwK16K9aDFQjbaymMpbxMUnLI6nM1mkcvlVMa9ZVnIZDL4pBXFf/9JF22nfwwMXYfP+2c70lEDX//KHH7xclrZcur1uhKXcvEiBWjwPTsO8v0rF5+/eONDXKxW0bLMYHw8dE1XCw+r3Qbr9cDicRjFoorCDMI5h1epgL/5GVT//t9XKS6O46iGX+kbl6kumUwG2WwW+Xwe6XRapQHJ2wD997asrsvFQLBZWH6VcZ6+7ysPe7BqDkBN25WNyZ7nwbZtlXYj9y8Wi6n+At/3UavVVMpTt9uF53mHMpQN6J/BmJqawuzsrNrS6fRz3y9BHDck5gmCAPD84nEYKWSkBaJSqYSqmEFRIAXSSch0TiaTuzaTplKpE1/d++ZHy/jGB/No2B4MjcHQmbLrSAEr7TrSZy6r77J5c5QnXk5lDYrzTqejGkmH/c/ewAMeFOdAuFF1VGKNFHzB/QCeCcREIqHSVaanp5FOp1XTrYwdNAwDN5fq+M1/9kM0ex4KcQtM6w8P4pzD9VzUuj4sxvFeehFJr4Zut6ssI4lEQvnHx4Vzjna7rZptpbVG8jc++hjnKxX0EgkYug5jYCMJRkp6lQp4uw0tkYAxMTGY1MrBfR8+9+F7/dQdo9nExsUL+Mk778A0zf5wpkGSjPSP5/N55HI5WJalFkHSHgVAVeObzSY2NzexvLyMVqsVSvqRNhu5uJPHVy6aZPOw7FPxBxNV5esfbErdqZo+3KgatGy1220sLi5iaWlJfd3PgmovUqmUEvYzMzOYnp4+tj4Ugjgo5JkniJecg1ZwJZ7nYXNzEysrK1hfX0elUsHGxoaq7LVaLSWyTNNUVbzgafQXJeLj8fiuQj2dTp+IBcXzcHOpjm98MI9mz0MxYUHTAlNeOcdm28H//V9/DG9rGXnWUdam4MJKLrRGVc13E+fDIl3XdfVaS5uErMpKwRz8PhqNwvd9Ze8IeqqDqT3S5x30zEuRH+T9G4to2B4KCROtVhM+98G5UFGZBoAuTPykEcO7hR6KxeKOZx9GIW0j8oyB4zhqGFEmk1HHjDGGQqGAqcoGzFoNVioFXdegMW2buFX597YNe2VlW5IOG+y3xjkMw0S5XFZnJYrFYqjqLqvsQP8MhjzW1WpVieOtrS1VdZcLuuCALcuykEwmlZ9exprKqr+0EMkG6Gg0ilwut+PvkWEYKBQKKl2mWCzuumhKJBK4du0arl27po7PxsYGFhcXlbhfW1s7cB9Ms9nE/Pw85ufnAfTfy+VyWVlzZmdn1ZRZgjipUGWeIF5i9qrg/l+/dA0/fzEZ8qXX63Wsr6+rITK1Wm2bT12KOBkhd9TJJgDUBNCgMB8W6/sRaqeV3/39W3j/xiKKCQtduwvP63u/ZeMjFwI9mLjM1vE5/YmqfEuPNzBanANQVXu5BafABkW57HGQwm8ni45sWJaxn8G4T8Mw1OsoXzt5psC2+1NaZePpqPeW43G880/+LZpdD/mEhVqtBjdg05ITUbvCQEzn+D9e2oQ+dDdygZBMJlXMaSQSCQnlbrcL3/fVtNdmswnHcRCPx1Eul3Hu3DlMTEwgEonA/73fQ+/P/hxmYDgWFxy+36+8e74P0W7D6HRgJxLYKk1A0/oe8v4QKMA0TFgRC5YVQfqX3kb8b/5NdUy63a56DWW13vd9rK2tYXFxEU+fPsXa2hra7fa2iru0sMk4ymw2i1KphFwup2wuuq4rS5XrusrOtBvJZDJUdc9ms4f+t8BxHCwvL6sFyuLiIprN5qHdfyQSUeJefh1eOBLEcUKVeYJ4SZEV3IbtImUB4Bzc5wDnYD7HRoPht/+Hv8Iv6LeRdGuq8hislgJQUXFB4X7Y1W3TNHdtJpUNfC87jsfxndtrMDQGTWPgvF8x7Zd0GRgDNDAwCCyJPL5grCBqPRPn0g4RbDQMCnRZPZfV54PYjWzbVuK92WyGFhBB8R6LxdRtpBVEetiTyeSe77GO48H1BQz92dkfn/N+prvWzzbXmAaLa2C6jqufegOT+VRoNoH0cjcaDXXG6enTpyqRpt1uo1arqeeRzWZx7do1nD9/HqlUCpZlqeNoWRYWDRO268JZW1M56rIJVg6R0hgDNA2d2VlUf/3X+u/vYhGlUgnZbFZl+Xc6HdR7PdRXVsAYUwObXNdFtVrF4uIiHjx4gNXV1VAPhOu629KhMpkMyuUyyuWysrmYpqnEvrQLyWmv8uzKMLqub6u6R6PRfb9H9otlWbh48SIuXryI/qHsx4YGxf3Kysq2v13j0uv18PDhQzx8+FBdlsvlQuJ+cnLySOY5EMQ40DuPIF5SpAUhG9XRaNT7F4rBYBohYAigxyzcbCdxvfNY3W43u8xBCFZgR9leMpnMvgY+vcwMC1hN0xGxIs8GE8mvHmAaGv5X/+v/DUqZhGpwPQoxIgdwSQHvOI76t2HrTHBx4Pv+s+FThhFqlByHuGXA1Bls59kU3Gg02j8GclgRALftIBEz8MXPfkb1iriui0qlojbp0eaco16vKwEfiURQKpXw+uuvo1wuIxaLqWmo8XgcQghsbW1heXkZKysr6JoG4hcuwB9MP+1XuzXomt6PYrQsRCwLRjSKV999Bz/95S8r25Os/EuC2fhbW1u4f/8+lpeXUalUVLrM8BkzufDOZrMol8uYnp7GK6+8gmQyqYS7HMYWFPw7LZwSicS2qvtJ6CdhjKnhVp/61KcA9N9P6+vrIf/9xsbGgR9DRqF+/PHHAPrHaHJyMuS/z+Vy9HeLeCGQmCeIl5BgBVeOMO/7X8OuOwaBDWsKSWMBUcvct11GZsLvVlWXkZPE8zMsYBPx+MjrddsOYhEDl8/NHEqzc5C9rDMyTWWU7SnYjCsHGB3UzmAZGt65Xsb7NxbBuRhp8+GDVJt3r5fRaTXwZFB9r9VqSsy6rqvsZbZtI5PJ4Ny5c5idnUWhUFCxkf3psf3nfv/+fSwsLGBtbQ2NRgOu6/YHak1NIf7KKzBNE8lkErlcDoWJCUxNTSGdTsN1XSXcN7tdVJ88USJb9i00Gg1sbm5ifX1d9aiMGnKmaZo6i5JMJnH+/Hm8+uqruHjxIkzTVMK9Wq1ifX19z+OpaRoKhQIKhQImJiZQKBRGVudPKrquY2pqClNTU/jc5z4HAOh2u1heXg757w862dn3fSwtLWFpaQnf//73AfR7dILV+5mZmRdypoJ4+SAxTxAvIcEKbj/6bvA/BjV5lIGhb4AwEImlYbHtVb5kMrmrUN9tEitx+AwLWNn8GkQJ2LnyoQn53awzqVRKiff4iMXFsA8+Eokc2gCqr701i2/dWkW146AQt0LHw/M5qu0e4gbDOX8Jf/7nT9T+9Ho9JeCliH3ttdcwNTUVek+7rot2u4319XUsLCygWq0qr3ZwmFYsFkM2m8XExAQmJydRLpdVDGSn00G9XlcVdTnwyrZt1Go1FZO5ubmJSqWifPryGANQk3flY+bzeVy8eBEXLlzA1NQUotEo6vU6qtUq7t69G7rtTsTj8VDVPZfLnbnf5VgshldeeQWvvPIKAKgzKUFxv7q6euDZFZ1OB3fv3sXdu3fVZRMTE6Hm2lKpdOaOK/HioQZYgngJGW4O3Nrags/DH1ga02DDQNxg+N3PayjksiGxnkqlTn3yy1lkOI4xKGA5F6h2HKSiBv7pb3z2QEOogN2tM7IRea8YTylYpQ8+Go0eyfsp2OStMwENgOP78DkQN4C/fRl4I+spAe84DmKxGIrFIiYnJ5XNSwr3ZrOJ9fV1rK2toVarqRQXGb0ohXs+n8fk5KS6j2Dco4xobbVa6mf52DKPXT6WTISSDaqmaSrhLs96pVIpTE5O4sKFC6rBdGtrCxsbG2NVmjVNQz6fD4n301R1P0o8z8PKykrIf1+r1Q7t/k3TxPT0dKiCT9n3xH4hMU8QLynB1BPXdfoeXq0/XVLTNQgObLQd/N3PzuJ3vvrace8usQ92TSmKGfj6l5/lzI+DEALtdluJ93a7rf4t2PMgmz53YtgHL6erPg/1P/gmWt/73o6P1+v1UG12sdxw8N3CHP70/JvQmcCnsj4+ne4iK1rQdV2lrsj0FulFl8lN9XodnU4HnHO1+JALl3w+rzLvS6WSGpAmfe6yYi8HIsnYTZnJLvdVCngZ5yqbyhOJhFpIJ5NJNY22WCzCsiwwxlRW/DhVd7lYkVs+n6fq8D5otVqh3PvDzr5Pp9Oh6v3U1BRl3xO7QmKeIF5SXkQFlzg+bi335wd8e34Nrt+fH/DuXH9+wDivZ6/XC1lnglNC5WCtnawzQYQQahosY0zFGR5Wn8TSP/otNL/9bTDTVKkwMnkl+OHGPA+1uTl8/JVfgaUJZJIJ5VuXHvLNzU2sra2hUqmo5zxcCZcCeHJyEtPT00gkEipdRnrZq9UqqtUqNjc30W63+xNbPa+/UB6IZiEEJu/ew9TSEnzfg+ACmq7BNEyYlgnTtGDoOrTBQKnGtWtwvvgFlSIkq/d7wRjbVnXf6zUj9gfnHBsbG6Hq/fr6+oGz74fRNA2lUinUXEvZ90QQEvME8RJz2BVc4uQx7mRf3/eVeD+odUYyygd/VIlET/7hf47On/wJeCYDn/tQn2hC9JOZBk3dZquF3uuvw/nP/lPouo5arYaNjQ1VeW+329A0TWXkR6NRTExMYGJiAuVyGVNTUygWi+osxcbGBtbW1rC8vKzuQ3rZhRAqojX4nIMRkZxzfPn2HVxYWYEY2IuC15UDogCA+T7qc3N48qt/Z8/jEY1Gt1XdyQ734pHZ90H//WFm30ejUczMzIQq+LRIe3mhBliCeIl5741pXComQhXcmLW/Ci5xsrEMDZax/RT9XtaZXC6nBPy4p/iHffDj5MHvF9/3Ua1WVWxkYn0NscEUUiH64p1DqKFQhqWrDPdms4l/+8d/jFqtpqw+iUQCqVQK586dw/T0NIrFIqampjA5OYloNIput4v19XXcu3cP3/3ud7GysoLNzU10Oh34vg9d11VM6/BxklGacvGTSqWQSqUQjUb7x3h1DahU4CWT4EIAQjwT9IOBXQBgNpsAttfdGGPKbiOz3eV0XOJ42Sn7PhiNuby8rHoh9ott23jw4AEePHigLpPZ93Irl8uUff+SQK8yQbzkvDadwWvTGfz2l+bGquASp5f9WGfGraIfhQ9+mGazqcT71taWqn67rgvL9RDlHN2eC0docIUGIRgAAY33ENEEwD0kfA/1QVRmuVzGzMwMzp8/r7zumUwGrutiZWUFT548wZ/+6Z9iaWkJm5ubasqrnLEgoyXlMWKMIRaLIZ/PIxaLqam60vc+fCxd18Xm5iZKnQ5SgwFUmtYfXrUbkUgkVHWX0ZjEySeYff/aa/0epGD2vdyq1eqBH2NU9v3U1FSoufYoJvASxw+JeYIgAOxcwSVOL77vh1Jngk160WgUhUJBDWzajygc5YM/zHkBcoqpFPDSrtNoNLCxsYGNjQ1sbW3B8zy8u7qKGBdo+jr6xpRnFWxfs9AFEDN0GL6PS5cu4bP/8B+iVCrB932srKzg0aNH+P73v4/l5WVUq1XYtg2gL4SkcE+n06rCqes60uk0SqUSyuUyUqmUypevVquqoVVOZZVnQGRjK9AX5dlsFrF4X/jrg/vue/19cJ9D03VYA6uOpuuYmZ7Bz/ydvW02xOlhp+z7YHPt4uIiut3uge7f9321SJAkEomQNWd6epqy788AJOYJgiDOCDJFJWidkW1Ruq4fyDoTRHq+OeeIRCJIp9OHkoIihXqlUsH6+jpWVlZQq9WwtbWlmkmleDdNE5xzOI6DZrOJN1tdXMBgH8SzJBe1rtA09JgFMBuO08Mf/uEfKuEuhTfwbLJxJpOBaZrQNA2WZaFQKKBcLmNychJTU1PIZrOo1+tYWVnB6uoqVldXQ89FxnbKzfM81Wh76dIltSgQg4o85xw924Yvs/kHO5+KxZQH2ms2yS7xkhCLxfDqq6/i1VdfBdB/n2xuboaaa1dXV8dKLRpFu90OZd8zxlAsFkPNtZR9f/qgvw4EQRCnmN2sMzLScL/WmSCe56Hb7SoffCKROBRrh+M4ePr0Ke7evYuHDx+qyatBW4tswpVpMLZtq31hjEHXdYh4AaivgkEEFLxEAJzDB4Pr+VhaXMKHH34ITdOUVcY0TRWvKavtMqkmn88DACqVClZXVzE/P4/Nzc1n9y4Eut1uSLw7joNsNotMJoNXX31VVT09z1ODn6rVKrrdLn6lsoHMoFE3KJ4YAPeAXmribMEYU5N333jjDQD9M1erq6uh6v1Bs++FEOoM2I9+9CMA/YXtcHNtKpU6rKdEHAEk5gmCIE4Ru1lnIpEI8vn8cw/1kvnsvV5vVx/8uEk53W4Xa2trePjwIR48eIDHjx9jdXUVtm2DD5pX5aKh1+tBCAHP8+D7vqpeA/2zC3JIUjwehxmJQjxp9MX8QMhz/sxmE9T2HAyGaap0GrlNTU1heno6NCRJVt5v3ryJ9fV11aQYrLq3Wi017TaVSiGbzSrbjXyN5FRYmQ4k7UiJRD8W03r6tL+f6De7Mk2DxhgMw4BxyH0HxNnBNE2cO3cO586dU5e1Wq2QuF9aWgolUu0H13Xx+PFjPH78WF2WTqdD1fvp6elD740hDg6JeYIgiBPMXtaZbDarqu+RSOS5HqfX68G2beWD36lZ7sdPt/D//eEC/v29DXi8n2H/zvUy/sM3SijqtrLLLC0t4fHjx1hfX1cDl6S9xHEcNQFVCndZbWeDJJpIJIJ4PK4GJcnnp+s6otEo8pOz0Bb+Chhcv1/T9p/Z5mWD6kAsX371Cr70O7+zzUJg27ZaYKyurqLT6ai+gEajoYS7nKYajUaRy+VQLpcRi8XgeR5qtRru37+Pzc1NOI6jmmLj8bgaJMU5h67r6nnF4wlouoZ4IgFD16EbRv/5H/hVJF5Wkskkrl+/juvXrwN4ln0fjMZ8nuz7RqOBTz75BJ988gmAfkJTuVwONdcWCgVqrj0mSMwTBEGcMBzHCWW+H7Z1ZvixZIV8Lx/8zaU6fu+P7+BP71UgC+CmBjBw/Pd/8QD/8s/v4A3+AIXOE7TbbVUZlFX34HRTIUQo1jGVSinxK4c0aZqmnnOpVMLMzAwuX76sUmigGfgX3/sNGNxHttsAoFIoB97z/vdcAAb3kYzHoGkafN9X1pmVlRXlx5eiXe6nrMgbhoFkMolsNotIJALOOba2tlRzIucchmEgHo+rXHchBAzDgGVZSrzLxJtSqYRSqQRraRnOygqMM5gPvttk3mGSb7+NzFffO+I9ermQg6ZKpRLefPNNAH1L3vLycsh/L5uy9wvnHCsrK1hZWcEPfvADAM+y74MVfMq+fzGQmCcIgjhmfN9Hq9VS4l0mqgCHZ50JIv3nruvu6oP3PE8NVvrg4xX8d/MOWl548eBwgAnAEA5s6LjBL+DK1n2IrTUljGXqjbTs5HI5JJNJZDIZxONx9djxeBzZbBYTExOq8j09PY2pqSkkk8mRiwz/iz+Lv/I5Irq+3TKPvrjv+T5mMlGY16/he9/7HtbX10PCPVh1l2cOIpGI8tQD/cqkbDyUC59YLIZ0Oq0y5mUTrVyMmKaJYrGovPilUikkbpZMEz3XhVep7Pp6iUCj7mmh9b3vqcm8uyGfG4n5oycSieDSpUu4dOkSgP57vV6vh8T9ysrKoWbf5/P5UPV+cnKS4lSPABLzBEEQL5gXZZ0JwjlXcZJSVMsBQzLvWm7SJrO5uQkhBKo8ju86r6ILUz6BQdFbDP7L4MIC6zXhGxF8Yqeh370Ly7IQj8dRLBaRyWSQy+XU4KRUKoVEIqGyt4vFIpLJJKLRqEreMSMxuIIhahk7ni34/P/21/Gb1hU0ex4KcQuaNvDOD3LoNzsuIhrH3yqs4WZ9E83vfhetVguu68L3fbiuC9d1lXUnEonAMAx0u11sbGz0B1EJgXg8jlwup84cyMFPlmWpMwimaWJychKTk5MolUooFou7pgYl33577NdvP9c9KTDThDExset19lrIEEcHYwzZbBbZbDaUfb+2thby3z9P9v3m5iY2Nzfx0UcfAeif5ZqamgpV8DOZDNlznhMS8wRBEC+AoHVGRhZKgtaZRCJxaB9swz54y7LgeR5WVlZCor1are4adXfHycEWmpLuEEJ931+DCEDTIDQTjPvA9E+hsPxnyGXSqsouK9WGYcA0TTUpttVqYWNjA48ePUIikUA0GkXFi+KHVR23ajp8ARgM+Kk88MVJhvOpvj9eCmjGGH7lPMf/eNfDSs0FgwDjPjwhwAWDwXsobPwQP/n4vmq0ldNipf3FMAzYto1WqwUh+skycjJsLBZTIl/6+YG+6IlGo6rqXi6XUSgUYBiG2i/P81Tyzqgt+eUvIfWVL6ufCeK40XUd09PTmJ6eVpcFs++lyD9o9r3neVhYWMDCwoK6LJFIhKw5MzMzh1bEeFkgMU8QBHEEcM5DqTNB64xlWSgWi0rAH/ZpZ9u2sbq6ivX1dTQaDWxtbalhS9J/v9e+dzoddDodNNtdPJm4DsE4hCb3U8j/AwikxphRGNyBbsbx1hd+BhY85cnvdruqwVVWtOVgJcuyYFkWut0u7vfS+PNuDD3OoMGHxgRswfBnKww/XPPxM7E1vBppqFjIer2OjY0NZDsMfuoKWplLEEwDuA9jfR7Gwg2s1haxOjjuMkM+2AhomiZisZj6d8Mw0Gq10Gq1lK0GgLLVpFIpJJNJCCFw//593L9/f8djGRTvwcuC/yaRC5Tg98FFy17fy9vJsxijvh91++DP434ddbvQYzUaAOdwXAeyeYEFnidZLU4PO2XfB8X982bf37lzB3fu3AHQ/72YmJgIVe8nJiYo+34XSMwTBEEcEkHrjKzyAn3xchTWGel5lYOWlpaWsLa2hlqtpirM49Lr9dBut9HpdNDtdlX6is01cMHAhA9oBgANTNOUPGODTlMhADDANKKIaD5emZoF91zV7Kppmhp8JJNqZKVbHqcV28Bf1MtwBUNS52Ba/77lc+34Bv6sU0a38gT2yn00Gg2VO885h+l9ghQHhGGBeQ4so/+YWi7Xn6Q6EJGy2i6Fe/B4SvHOGEM8HlevWTKZhGEYal+l/Sb4c/Br8D4PmiBymnlzaQkl10Wv+azBsn8SRyAajSKRSBzbvhHPRzD7/tOf/jSAfpyl/BskRX69Xj/Q/QshlOVPZt9bloXp6elQBZ+y759BYp4gCOKAyCmkUsAflXVGCIFms6lsMcFNRjtKsSwF8l6P5/u+qr63222V5y4tKLFYDNlsFjPnzmOhlYADHYxpaNoedC18/0IAQpbqNQ0/cyGOt9/6ecTjcaRSKVW9jUajyhMvRbN8XM45/uv/5QF4rYKJlA4GBjGo7DqOC9d1IDwfXT2Km+0k8pubKupRWlsMw1B2Hpn2I6e4JhIJmKapjo0U2HIglaZpKBaLqlG1UCioVBr5GgyL992+Dy4MZCTnTl+D1w/ebvh+5P6Ouv5el426bfDyUdcZtY/jHgcwpk7ZiMFgLLX4I0vRmcM0TZw/fx7nz59XlzWbzVDu/fNk3zuOsy37PpPJhJprp6amXtrsexLzBEEQYzKudSaVSoUqvuMihEC73Q6Jdvm9fCyZthKMTYxGo3uKd+mfl9V3eX8yTlEOM5qamsK1a9dw/vx5TE5OwjRNdH//Ft6/sYikZaDt+PCFgCbE4DH7vnkuAEMDUhENv/aFi5gpJyCEUE2wkUhk1310PI4/fViHoTNoTEOjXoc9GCD1zKMPQAi00pcwFU9AExymaSKRSCCbzSpvu0y+kRV2oC8GHMdR4jQSiSivuxxhb1lWv5I/sIwEN+IZ4wj61SdP0VldhZVKgWkMjGnP8vNJzL8UpFKpbdn3lUol1FxbqVQOfOaqXq+jXq/j1q1bAMLZ93LL5/MvRT8KiXmCIIhd2M06k8lkVPU9Go3u+35HiXYZkTiMFPAy03wvcQz0m82kbUbTNCSTSZRKJcRiMbVFIhGUSiVV4Uqn09vu52tvzeJbt1bRsF2kIxoatt/PmZcCDn0zTCZm4h/+4nl8ajqt7n8nISzjOFutFtrtNlaqDbS6NgQXcHo+fM4hAh5cmRuvQUA3Lbz+6TdRyvbjLXdKjJHC3XVdJJNJFAoFFItFTExMIJlMqmPqeZ46szKqp0CeWdB1HbquK3vO8BYU/2dZQOx05kf2MbTbbXi+Dwa8tJVSYjtSbJfLZbz11lsAnmXfB/33h5l9H4vFVOVeNtcG42HPyjwEEvMEQRABXNcNDWwKWmeGPdTjCDbbfjYRNSjex/nAkhYQ3/eViNytSiwGfmTLstT1U6kUIpGImroq4+ik/7RUKu3YjCgjNHNo4e//dBr/9AcbaDsclg74gsH1B3GaDPjChQz+z++8gs9cmth2VkImxQS34TQMjwsYrJ9bDwZELAuu6ygRrWk6dE2DLQxEdYFrr1yCscPhZ4whl8uhUCigVCphcnIS8XhcHbvhD3B9sEXwzNctBmccYj/3s4j98i8rS5CMs3QcB51OR/0cPAsgkdYnaX+S3wcXA9LLf5oXAK7rotPpoP4H3wT//vehGzr4J59A9Ho7Rk+yeBw6+eZfenbKvg9W758n+77b7W5rUi8UCkrg5//Nv4H7Z38G7ZTPQyAxTxDESw3nPDSwKSgy5dCfcawzjuOMFO2NRmNf+yOEUNViKQZHVZ4zmQxKpRLS6bSqCnueByEEHMeB67ro9XqqOi2znWdmZvZsPux0OqjVaqjX6+h2u7BtG28WTfzOL03iz5c5/v3DGlyfQ2MCP3M5j1/7/AV87vIEfN9Hu90OifZ2u73rB3EsFkMymUQymcTPrSzjj+7U+lV/xhCLxqDpzxYvXADcB64lnZCQl4O15HTVYrGoGmxHsZ+BRrqmYeJXf3XX6wHPfOfBTYp/+Xq6rgvbtkOXCWVXgkqFGVX1H14EHLf9R1bhO50OhBBIJBLQPvwQnX/378BME9y2Ad8Hb7dH34FtQ3Q6p3IgFnF0BLPvX3/9dQDh7Hu5bW5uHvgxqtUqqtUqPvroI3zxwQNMcw7PsmCYBgyj/zumD/1unfR5CCTmCYJ46eh2u6HM9/1YZ1zXVVNRg6K9VqsdeH9288GnUilMTEyo0eyyMVM2xLZaLeUH7/V6ahGQz+dx9epVzM7OolAo7Cn8Op2O8qBKT72u60gmkzh//jxisRiuc44vXO3iP3/7Arouh6UJOHYXrdYa/vIvH+yaPa3rOhKJhBLusVgMvV4P1WoV6+vruHnzJkptD6ZIo24zxHW+Tch3fA1RTeBnZwxcPHcR+XweExMTyGQyME0zlAO/F4c90Oig4nq4CViejXFdVwn+Xq+nzgBIqxXwzO4SrPrvZAPa7/HZCc/z0G63Yds2IpGIOvYAsIVnx9VvtSB2eD9w24ZRKiH+mc8AOJ0DsYgXRzD7/vOf/zyA/t+rYHPt4uJiqIdpv7ieB9fzAPTvI7iojp6CzHsS8wRBnHmC1plmswk3UA3cyTojq0HDCTJbW1uHFjUY9MGnUilMT0+r6aFSwMdiMdUUW6lUcPv2bSXe5SYnlF68eBGzs7Mol8uIxWJ7Pr7Maa/X62i32+h2uyEBn0gk4Ps+KpUKVldX1f7KBJ2dkNNlg5umaepYPnr0aGTm/VQU+A/KbXxrLYGOr0EDoDFAsP7AqkzMwD965xJ+5adnlUA9TGuKTOR5kXaXgywCgmk0svo/fAZAWoDkzzKlSD5m0P8fXAwMnwGQSUG2batJxfF4HKVSadfjpCeTEMmksiRFBhn+QH+RFP/MZzDzT/7xwQ8c8VITj8dx5coVXLlyBUD/d6JarYaiMdfW1g6cfc85V39fd5vifFIgMU8QxJljL+tMoVBQAp4xhs3NTSwsLISaUfeainpQLMtCNptFPp9XyQvT09Mh64vv+6hWq3j48CEqlQra7Xao+i6bYIvForLO5HK5sT38UsA3m03Ytg1N05BKpZRAa7VauHv3rsqrlykvw8imWrnJyrscurS+vo4HDx6o7Ptx+ExBYG42ih9tRXBjtQdfMJi6hl+6PoG/+9Y5/NS53NjHei9834c7EMCe58L3fMTj8bEWQsfJcDPuuAxbgILVf8/ztlmAer0ems0mHMdRC7RYLIZGozHyDIA3WDBIAS8Xmn0ENcMSRwZjDMViEcVicVv2fbC5ds/se5nUhHDj+96j9o4XEvMEQZwJgtYZObkT6P9BllV3maDy9OlTJdrHnYq6XyzLClljpH0nk8moaaNBpG2mUqlgc3MzVHl3XReapiGRSOD8+fOqcXXc4VO9Xg/1eh21Wg3NZlMl3BiGoRpDNzY28OjRIziOA9M0EYlEQqJWRj4Gt1gspjLbt7a2sLq6irW1tV1TeYaJx+Mq2z2fzyOZTAIAfp0xCKbD4UDcMmAZz+cPb7fb6gPdevwI6V4P3RELjIM22p0Ghs8CjFq0CCFCVXiZeAQgJPzlQkDGnS4tLUE8eYyM46CztfXsDgd9AK7jQiQEGE5nky9x+tgp+z5ozdH+8vsAAMF5aBaCDFI1DONUvGNJzBMEcSpxXTeU+S6tM9J/7jiOEvjVahWVSuVIhJppmsoSE/S2p1IpVUmXKTPBOEnXddV+yeq7bFqVFc1gbOTU1BQymczYFhDHcdR02LW1NWxtbaHT6cAwDMRiMcTjcXDOsbm5qarzkUgEqVRqm2iX1XaJ53nY2NjAvXv31KLIHbORMZfLqWNULBZhWZZaeI2yzsR3uqNd8DxPVeTkh3bwzMAXW22oAM5Ahj1jbOCbffkIxphGIhGk0+ltC87hyvrKygoePHiA+fl5tFotvNlzwIDwVF30BbwQAq7jngrLAnF2SaVSmJubw6VLl9DpdLDxl9+Ht7qKRDIB3+fwPQ9eoLhzkHkhx8Hp2EuCIF56OOdot9tKvLfbbbTbbWxtbakBTp1OB81m80gq7dLWEhTsExMT2+wtjuMoK4sUx5qmQQiBRqOhxPvW1pYS+47jwPd9GIaBVCqFK1euYHp6WondcY9PvV7H8vIyVldXsbq6ilqtBs/zkEwmkcvlUC6XYRgGer0eOp2OshzJMwfBanuQbreLpaUlZUOqVqtj9Q3ouo5isYhSqYRyuawaceWiSnq1d0qdGQchBDY3N0PCfW1tbcf3AB/k1/ejJzmYnEo6qL9JG8rLMChKVuGlvz6RSGBiYmLX516pVDA/P4/bt2+PZZ3SmAbTsmAFPPMEcRwEh+YxxpBIJBCLRtHRdURiz8oGMlHM9bxT854lMU8QxIlFNmiurq5iYWEB1WoVW1tbqNVqaA8i7yKRyKH+wdU0baRoz+fzO4oc6Td2XRemaSIej8MwDDiOg5WVFWXn6XQ6SsBL60wkEsH58+cxPT2Ncrk8cmjTMI7jqOjHra0t1ajbbDbVmYBUKoWpqSmkUinledd1HZZloVAoYGJiIjQ8JYis6MuqfrPZHOvYyTMJQduMTGeRk0E1TdvxccdhOMViaWlp1xQdeaZGNoBqmgYwNhDx/deTAdBl8+eg4fMsE6zCW5al3iM7sbW1hfn5eczPz6Nare79AIzBMs2+gLdMstYQx4bjOErAA1DN2/L93h7xN50xBtM0T42QB0jMEwRxBDgeR8fx9u11bjQaePjwIR4/fqxGfW9ubqLX6ykfdyQSgWVZylt9UBhjKpc8KNpl9ONecM7R6/VUhGM0GkUikUCtVsPy8nKo+i43zjksy0Imk8HMzAymp6dRKBR2PJUrhzaNGrgU/FkK+GQyqSLccrkcotGoauKSHvjhybGccxUPKf3uvV5vrGOYTqdDxy+TyYSGKPV6vVDM5n7xfR+rq6uhqvs4+dLBlBeZ3R4cuBWJWIMehPggV1o/84JTCKEqk77fb/TdrQrfaDRw+/ZtzM/PY21tbc/7l83Qhmkins2e+QURcXKRA92C6Uty9sQohOvuGUN70uchkJgnCOLQuLlUx/s3FvGd+TW4XMDUGd65XsbX3prF6zMZdT3btlXl98mTJ1hYWNiWNCCr1rFYDNls9sBWDDmEJCjYpV97v35IOZDJtm3lg49Go9jY2MD9+/dRqVRg27ayzniep4T+5cuXMT09jVKpNHIh4rruNtEuP4wAhAYyyR6BSCSCcrmMixcvqkQbXdfVMB/GGGKxGGKxmBJtcriVFO7jNgAzxtREVbnFYjFVeZfV3mDc4X6P7dbWVqjqvrKyMta+jaq+SwtPNBpViT9ypHv9//Y7aD5dAGv2p/Du9Agn/QN8HOT7Rlbhk8nkjqKm3W7jzp07mJ+fx9LS0lj3f+HCBVy/fh1Xr17F5vIKmnfuwN/Y2PU2Z+G4EicL13VVBZ5zjng8jkKhMHJWSJD9zDg4yfMQSMwTBHEofPOjZXzjg3k0bA+GxmDoDF2H41/+8Cm++eOn+NuXgPPYwPLyMjY2NtDr9VRFGYCqHsvtIKc45VTUYdH+vE13cnKnN/BQOo6jmldrtZoS747jgDEGy7JUbOTk5CTy+bwSt0KIbVNS5eCnYXzfR6fTCdlyYrEYLl++jIsXL2JiYgKGYSjvc7PZhOd5iMViyOVyMAwD7XYbjx8/DuXkj0OwsbdcLqvFT3BCrVwwSGvRfrBte9vQl3ETcIDR1XfLsjA1NaVE++zsLIrF4rYqsX9GPsB3QybS7FWF73a7uHv3Lm7fvo2nT5+O1QsxMzODubk5XL16NbQwPSvCiDgdyAFmwfd5Pp/fU8AHyXz1PWS++t4R7uWLgYkxp58sLi7i3LlzWFhYwOzs7FHvF0EQp4ibS3X85j/7IWqdHhKGgJCDbLgPIQBbGDCEhzfbP0DCfSYmh60z456al5now772caMax8H3fdi2raIhm80mtra2UKlU0O12Q9YZ0zQRi8WUxaVYLCIej2+rtssPnp3y6xlj6oNIxv7JKZ+5XA6Tk5OhMwrydHKv11NnMdrtdmjIlewt2AvpJZV+92Bjb9A6A4xOndnrWK6vr4fynjf2qN4OM6r6nsvlVFSnTPx5mdNS5OKv0+nAsizE4/GRvxOO4+DevXuYn5/H48ePx5qnUC6XMTc3h+vXr4/V10EQR4EsIrTbbXieh3g8jng8fmAr31mBKvMEQTw3799YRMP2EIWHnu2oZsf+Bhjw4bAIFq1Z/LRhKwG/lxUjkUiMFO1HNdRH+rw7nQ7q9ToajYbKZx+2zliWhenpaczMzGBiYgLRaFT52+/fv6+87DthmqaKfYzH48rTXK1W0e124XkeisWiSoFR0zM9D81mUzV9yoSf/UZEZrNZJdyHrT+cc5UjDkA933ESXoQQqNfroar78vLygWJBg9V3y7Jw4cIFXLhwAefOncPs7CxSqdS+7/MsIhNppLgZVYV3XRcPHz7E/Pw8Hjx4MJZ9qVAoKAGfz+ePavcJYlekVazT6cB1XWW9fNkFfBAS8wRBPBeOx/Gd22t9a42mw3UB339W6WOMQWMMOgOqkWmkI5vQWfiEYCwWGynag1NRjwrpg9/Y2MDGxoYS8FK8y+q4ZVmq4TOVSiESicDzPDXpdLdqezwe3zYl1TRNlUazvLysPqgKhQJmZ2eVhQZ4Fsspzww0m03U63VsbW3tOyJSbsEKtpzaKVNn9mOd6fV6WF5eDlXdW63WmEc/jKy+c85RKBRw7tw5XLhwQVmKXoa4yHEJVuFN00QikdhWhfd9H48ePcLt27dx7969sRZ62WwW169fx9zcHCYmJo5q9wliV+T7W06/jsViSKfTI6NzCRLzBEE8Jx3Hg+sLGDqDAb0/PU9jyh4i0YSAYBpK07O4MBUespRMJg/8B/qgyTkyO31tbQ31el1lbfd6PXDOoes6GGNqcqusAnW73R2jEA3D2DZsKZFIKBEqhECz2cTq6irW19eVXadQKODVV19FqVRSAl4IoRqEV1dX0Wg00Ov1xmoslRGR0jYT9OxLhq0zpmnuWeninGN9fT1Uda9UKmMtKHbC933EYjFMTk5iZmZGVd+P6uzLaUcm0riuq1I6gq8t5xxPnz7F/Pw87t69O1YyUTKZVBX4qampo9x9gtiR4AK11+shFoshlUohHo+TgN8DEvMEQTwXccuAqTPYjkAs0hcVuqZDN3TougFD73/f7AmkYyb+wX/yH+5LdO/EuMk5Qba2trC4uIiVlRVV1ZbedNd1lRdbDk+SFXQAoQZVmRIzLNxH+ZOlgN/a2lIedlmBDwp4OY11cXERCwsLWF5eVvYSuQ87CXnZQyBtM5nM9ucfTJ2R97WXdabRaCjhLo/bqEbd/aDrulpoTE1N4fz584fSpHyWkSKn2+3CMAzlEZYIIbC0tIT5+XncuXNnrEbieDyOa9euYW5uDjMzMySWiGOBc64q8L1eD9FoFMlkks7E7RMS8wRBPBeWoeGd62W8f2MRmmYgl8tB0/RQajfnAr5w8O6nyoci5Ecl59iOwPs3FvGtW6v4+lfm8N4b03AcB+vr68oG0mq14LouOp0ObNtWnuxEIqGmoCYSiZCwGa62JxIJJBKJXSvkMrFGDnSSAj6fzysBL4RApVLBxx9/jJWVFSwvL6Pb7apM9OE8eInMxw/63UdVsYOpM3KRIoXgKBzHwfLycqjq3mg0DvDqhCkWi5icnNy2v+N68F9mhqvwwzMQVlZWVBb8ONamSCSCq1evYm5uDufPn6fjTxwLnHN0u120223Ytq1mdJCAPzgk5gmCeG6+9tYsvnVrFZsdF4W4tU3IVzsO0jEDX3vr+ZOwbi7V8Y0P5tHseSgmLGjas0fzuUC13cN/+a8/xvqDm/A3nqDRaKhUGt/3B8OCEigWi8hms0in08raMqravp+YMyngV1dX0Wq14HkecrkcXnnlFaRSKWxubmJhYQEffvghqtWqSsqRw6SSyeS2DzPDMLYNttopHz9onWGM7TiwSS4kgsJ9bW3tuewyQL9hWQ7DkrGgkUgEhmGEzjAQOyMrlZ1OZ2QVvlKpYH5+Hrdv30atVtvz/kzTxJUrV3D9+nVcunTpwPMaCOJ5kAPw5BmmSCSi/g6TgH9+SMwTBPHcvD6Twde/ModvfDCPjbajquWeL+BxgXTMwNe/PIfXpkfbX/aDTM6RQp4LAc9z4bpeP4vd81DrGfjv/mwZn9Ufw7IsWJal4haz2ew2wS4r7gcROp1ORwl4mfOezWZV82C1WsVf/uVfqohI13VVKo70qQfFuYyIlFs+n9/RAiFTZ2QyyU7WmVarFcpzX15eHnvK604YhoHJyUkVC1kqlRCPx9WCSYp3+qAeD5mi5DjOtir85uamqsBXq9U970vXdbzyyiuYm5vD5cuXaRFFHAtCCFWBl0PLEonEyB4e4vkgMU8QxKHw3hvTuFRM4P0bi/j2/BpcXyBmMbw71/exH4aQDybnaBqD3bPRbvVz29XwKY2BQaBiTeLKq1FMT5YxMzODTCZzoGr7KLrdLqrVKtbW1tBsNpUw1jQNvu/j3r17IW+553lKxMuMdtnUJafTShvKqOmwkp2sM0Gvvuu6IeG+uLgYmqx7UAqFQmiKaqlUgu/7cJx+FKlpmrAsa99TdV9mglV4XdeRSCSQy+UA9PsVpIBfW1vb8740TcPFixcxNzeHK1euUA8CcSxIAT8874AE/NFCf3UJgjg0XpvO4LXpDH77S3MHSpjZi2ByDtCvQAohAMag63pfKBsGDKHBMnT87b/7N1FIHU4qim3bqFarWF1dVdNf5QAjzvm2CjTnfNtUWDmhVgr3cQZdSfE+yjojhEC1Wg2J97W1tbGGAO1GLBZTwl2K92g0qhYkcqCWZVlIpVLUPLlPglX4WCymqvDtdhsffvgh5ufnsbS0NNZ9XbhwAdevX8fVq1cpAYg4FuQEalmBNwwDiUQC2WyWFvcvCDrKBEEcOpahwTIOvzIYTM4BoD40DNOAYRhgA7e+3XZgmTpSsWdC+SARlrZtY3NzE/fv38fCwgIqlQo8z4OmachkMuqDKhg9KQUv5xzJZBKXL1/G1NQUSqXStgbGUexmnel0Onj06FHI627b9v4O4hC6rqtYSCncpbVHVt7lcxplCyLGY7gKH4/Hkc1mYds2bt68idu3b+Pp06dj9S3MzMxgbm4OV69e3fVMDkEcJVLAB/s7SMAfD3TECYI4NQSTczgX0DS2rRrJed+n/+5cPzlnvxGW3W4XDx48wJ07d5SANwwDmUwmJOCDyHjLaDSK2dlZNaE0k8nsWbXeyTqj6zpWV1dDVfetra3nO4AAcrlcqOo+OTkZyrZ3XRftdjvkfafq+8FxHCc0+Cafz4Nzjnv37mF+fh6PHz8e60xKuVxWWfDpdPoF7DlBbEdOG+50OipMYHJykvoyjhkmxowvWFxcxLlz57CwsIDZ2edPpCAIgjgIN5fq+M1/9kM0ex4K8XCajUzOSUUN/NPf+CwebbS3RViqptyoga9/ZQ6/PDeBpaUl3Lt3D3fv3sXS0pKKAkyn0ztWmjjnSCQSSKfTmJycxLlz53ZtVg0ybJ3RdR3NZhNLS0tKuK+urqrq/EGJRqOYmZkJVd2Hp+oGq+/kfT8cgtF7sgqv6zoePXqE+fl5PHjwYKzXtlAoKAGfz+dfwJ4TxHZkRGqn0wFjDIlEAvF4nPoyThAk5gmCOHWMypkfTs65WEiMFP1ccLiuh82OCws+3jbuImpvwPd9JeAzmcy2SpNhGCgUCioNJ5/PI5PJIBaL7ZnYMmydCebfSwE/zqCf3dA0DeVyWYn22dlZFAqFkbGUwUQdaeOxLIuq78+J4zih6ZWRSAQLCwu4ffs27t27B9d197yPbDaL69evY25uTiUiEcSLRvZ1tNttMMYQj8eRSCRIwJ9QqPRCEMSpY5zknN/9/Vto2B4KCROe58LzfXieB9/z4HMOXXC0hYU7PIt3J9g2AR+LxVAulzExMYFMJoNoNArP8xCLxRCPx3etXEvrjBTwlUoFa2tryjYzTrzgXmQymZBwn5qa2vFUt6y+y4Qdy7IQi8UoXeIQkFV4WbWMxWLY2trChx9+iLt3744VAZpMJlUFfmpq6gXsNUFsR1rCOp0OhBBqkNNeTfrE8UNiniCIU8luyTnhCEsNLduGO2hKBQCmaTA0HQYYVvUSsgUTxVw2lO8eiUTQ7XZh2zYikYiqtO6EFO9bW1tYXl7G6uqq2jzPe67nalnWNrtMKpXa8frB6rvv+6r6nk6nqfp+SASr8JFIBO12G/fu3cOdO3fGOssSj8dx7do1zM3NYWZmhl4X4liQPTLtdhtCCMTjcTXsjTg9kJgnCOJUMyo5JxhhyQAYhg7X7XvTmab1feqaBkvoYIaGz/+1n8Vkvt9U6DgOKpUKTNNEKpVS+cijoiebzWZogury8jK63e5zPR/GGEqlUqjqPs6URM/zlPcdoOr7URCcYskYQ6PRwOPHj3H79m20Wq09bx+JRHD16lXMzc3h/PnzNFCLOBZc11UWGt/3kUgkUCgUnnv+BnF8kJgnCOLMMSrCMhaLwdB16IYBXdfAwLDZdpCIGpgpF9HrduC6LiKRCOLxOIC+QN7a2oI/sOisra1hZWUFa2trWF9fR6PR6C8QWP8MAGNs3xXWVCoVynOfnp4ey5cqhFDWGc65qr7HYjGq8h4ysnrZ6/XQarXw5MkT3L9/H7Vabc/bmqaJK1eu4Pr167h06RItrohjwfM8ZaHxPA/xeBy5XE7NrCBONyTmCYI4cwxHWEasCCJD+tjzOVyf42cupqAzoFgsKs+5EAKNRgMLCwt48uQJlpaWsLq6Ctd11QefTKKRkZJyeBQAJeqDIl/GPM7MzODcuXMqvnI/MYPB6jtjTE2SJYF4+ARH0ddqNSwsLODhw4fY3Nzc87a6ruOVV17B3NwcLl++TLF9xLHgeZ6qwMt+n2w2SwL+DEJiniCIM8nX3prFt26totpxVJqNEAJcCPjcR63rIxMz8Bs/25+cubS0hKdPn+LJkydYXFxEq9WCpmlKjAMYW5RJUZ/P5zE1NYXJyUmUy2UUCgUIIeD7vhoi1Ov1oGkadF0PbfKxpYCn6vuLQVoQKpUKnj59ikePHo3VsKxpGi5evIi5uTlcuXKFUj+IY8H3fSXgHcdBPB5XqVv0N+PsQtGUBEGcWWSEZd12oTPA0Bh8DnicI2YA7824KHSeYHV1FQBUBf0gXuZkMhlqUJ2ZmdmziUyKet/31dbr9WDbNnq9Hjjn0DQNsVgMpmluE/xS9I/y9BPjI6vwlUoF9+/fx5MnT1CpVMYSPxcuXMD169dx9erVbQPMCOJFIAV8MBY1kUiQgH+JoMo8QRBnDiEEbNvGX5uJ4HffmcK//ngN31/owHF8CN/DLKq47K3DfdjB+mDi6n4+9AzDwNTUVMjrPs6012Gk355zDs65SpPIZrOhswCjRL9MqpGbrNzvVOUPfk/0cV0X1WoVn3zyCR4/foz19XV1fHZ7LWdmZjA3N4erV68imUy+qN0lCIU8syd7OaLRKJLJJCYmJuh3/CWExDxBECcOx+Pb4ibHod1u48GDB3j69Ck2NjZQqVTQbDaR4xzvcsCDDksXMHXZqDqebaZYLIaq7uVy+bl86sPed8uykEgkdvwQlkJ8L5vPsOAn0b8dIQTq9Tpu3ryJO3fuYG1tLdTEvBPlclllwe+nz4EgDgsp4DudDmzbRjQaVVnwZ/X3lRgPEvMEQZwYbi7V8f6NRXxnfg0u74vud673B0G9PpMJXVcIgY2NDTx58gQPHz7E48ePsbm5qYSZdBBKkRZVi4Ldq+fxeFwJ99nZWUxPTz+3fYJzrpJnhBAwDAOWZanUnMNCWoQOS/TvJfhPk+jvdru4desWbt68icXFRbX/uw3/KhQKSsDn8/kXuLcE0UcOJWu322rmRSKRGCuulnh5IM88QRAnAulvb9geDI3B0Bk8X8DjAumogd965zJezziqUfXp06dotVoQQihRKa0RUtDvZXvRdR1TU1OhqnsulzsUn6kc2uR5nqq+W5Z1qj6Ah0X/qEXASRb9nufh9u3b+MlPfoKHDx+CMQbDMHbdl2w2i+vXr2Nubg4TExMvcG8Jok8wSanb7aq4XEquInaCxDxBEMfOzaU6fvOf/RDNnodC3ALT+kKsb0fx0HQFTOHhF/TbSHl1eJ4HbeB1l9XooJjfiXw+H6q6l8vlXSuz+2Gn6vvLEEt4UNG/0yLgeffl4cOH+MlPfoL5+Xn4vg/TNHd9nZPJpKrAT01NPdfjE8RBkAJe2mik9Y4EPDEOZLMhCOLYef/GIhq2h4Tho9Gsw/d8CAzqDAKwhIANE3edPN7S6ojFYnumzsRiMZUqI6vuh21rGVV9TyaTp6r6fhgcxN4jv3ddF7Zthy7br+gXQmBxcRE/+clPcOvWLbTbbViWpRZUo4jH47h27Rrm5uYwMzNDqR/EC0c26ssKvGEYSCQSyGazh1ZkIF4O6N1CEMSx4ngc37m9BkNjABeDIUxQg5gYADAGjQGr+gSsyBp0Fj6hqGkaJicnQ1X3fD5/6AJtp+p7IpE41Mc5qxy26K9UKrh37x5u376NRqMBwzAQiUQQiURGLqgikQiuXr2Kubk5nD9//qVbdBHHjxTwsgJvGIZKsCIBTxwUeucQBHGsdBwPri9g6AyGZgB2QMgzptpVmeDwATQ6NiYyCUxNTWFqagrT09MolUqqAiu98rVaTYn54ETW/V7meR5c14Xv+2po08tYfX+R7Cb6K5UK7ty5g48//hjr6+vKRpNKpVTMp1xsyX6Ky5cv49q1a7h06ZLK6+92u9sq/wRxVMgKfKfTga7rSCQSmJycfClseMTRQ2KeIIhjJW4ZMHUG2xFIWP0/SZrGIFNnGOt74znXkYoa+N3/7LeRy2yPBpTiLfj9OJfJfHd5mRzc1Ov1VPVdCkBZHR7mMBYN4172MrK5uYnbt2/j1q1bWFlZgeu60HUd0Wh0ZDVT13W88sormJubw+XLl6Hr+jYvv3wtg5cNR3Pu1NR7kqn/wTfR+t73xrpu8u23kfnqe0e8Ry8vvV5PCXjGGAl44sggMU8QxLFiGRreuV7G+zcWAWgwTQu63hfwhtEX0YIL9NoOvvTp2ZFCHsCBBa8QQnnffd9Xp70tyxr7/g6yiAguJPZz292e+1EuIl70YqLRaOD27dv45JNPsLS0pKrtpmmOzOTXNA0XL17E3Nwcrly5ss0rP86ZFCnsdxP9nHMwxsbK6j8OWt/7Hprf/jbYHoJRuC4AkJg/ZHq9nhrmxBhDPB4PnTkkiKOAxDxBEMfO196axbdurWKz46CQTA0q8304F6h2HKRjBr721uEkaclcddd1lUDcqco7DscleCX7XUQc1UJi+OfgZR4X6LgciYgBS9dGXq/T6eDu3bu4c+cOFhcX1SJLeuFHvT4XLlzA9evXcfXq1eeeB7BXU7VETuwNpvW4roterxfy+ssZB8OC/6hFPzNNGHvEanqVyqE/7suK4ziqAi+EUIOcIpHIce8a8ZJAYp4giGPn9ZkMvv6VOXzjg3lstJ3tOfMxA1//8hxem87sfWcjGK6+a5oGy7KUz/q0c5w2nL0WAreWG/hXP17G9+5swPX7g8B+4UoBf+vTZVwvJ9HpdHDv3j3cuXNHVeAdxwHnHKZpwjRNcM5D9qapqSlcvXoVr7zyClKpFID+UCjbtnddWBz0DMUwUvTvtfgbJfo9zxtb9A9ftq/Xpf9C9F8HPDt7ZJxwm9BpwXEcVYEXQiAej6NYLJKAJ44FEvMEQZwI3ntjGpeKCbx/YxHfnl+D6wvELIZ35/oTYPcr5A+7+k6MZreFxKhBYL0ex7/+aBXfurWKX57sItt4oJqMXddVNqfh16lcLqss+HS6b7Xa79kIeVmwT2Kc2+70vIPPf5zFQVCkyynFsmG31+upr71eTy0+5eUyAnVU0o/neeqyN+bnUe71YFer2/ZZ0zXksrkxX1liGNd1VQWec454PI5CoYBoNHrcu0a85NCnGkEQJ4bXpjN4bTqD3/7SHDqOh7hlwDLGS40JVt89z1PJM2el+n7c7LexcuHNn8M3PphHs+ehmLDAWF8M9TwHwnVQ7xn4V4+AXzJ95NAZ6YUvFApKwOfz+W2Pc9hnJIKLCplgFPxZbv1hZo5aLA5/Db4Pg5fJ2wYXCMGFw26XSeRCZCeb1JVu3+rhcx+qiRwA0zSMNyKSCOJ5HtrtNtrtNnzfRzweRz6fJwFPnChIzBMEceKwDA2WsXfDmKy+O47Tv51lIRaLnfjEkdPIfhsr3xeX0LA9FBImOt02ej0HUk0KABE4sIWJRyjjXHJd3T6bzeL69euYm5vDxMD3zTlX1eqgKB4ltEcJ73Gu5w72+yiR9q7nZTfhb1mRvm2H9RdFIvBfUvPj4XmestB4nodYLIZcLodoNEqFAeJEQmKeIIhTw3D1XQ5tSqfTY33IOh7fd8Wf6B933/cBw4DI5SA4BxccnIuBL9yHYZiIWBZ4dQOtdht//PESmODwXBfdrg3XdcG5DyEG9WIG+DrDHceC9ckPMVWeQKlUgmma+OEPf4gf/OAHqjLOOT/uQ3CiCL7Xh9/3w83Ywf/u1sj8suP7vrLQOI6DeDyOTCaDWCxGAp448ZCYJwjiRBO0KgAHq77fXKrj/RuL+M78Glzeb8J853rfi//6zMGaak86oyrS8rJer4dms4lms4lWq6W+tlotZSmQEyo7nQ5s28Z/tLCI690uGqurz6q9AW0oF1ZRx8XS4yeo59sQgqHl+vA9D96gd0HBGJjmg2kGEpkcbNvG06dPQ89BNoAahqH85qO+P0vIuQbBbdRlozbDMGA9fgK2vg4rne4f48HgNcYYQKI0hO/7qgLvOA5isRjS6TQJeOLUQWKeIIgThRAilGgive8H/YAd1YRpOwLv31jEt26t4utfmcN7b0wfwTMZDef8wFaQna4XHE4jt7283Put0rqeB2Dn6m7wcgYBHQIO+mc/dEMH8zRoDNA1HZquQ9MYesKECQ/5dALcc0PNnPJYyeO1G1LU7yT6ZdPp8yAn0u4mpA/yb8PXe14RuZTJoKVpMGgw0Ug45+p3pdfrIRaLIZVKIR6Pk4AnTi0k5gmCOHaC1XfGGEzTRDwef+6q682leqgJc1R+/Tc+mMelYgKfmkofug971PXGsYxIe8mwCB++LCjOjxuBoJgHzul13PMKEALQmIZIxIKm6ZCvgBCAYBouGk0Uk9sTVoIJLcHvgz+POpYy5lFuUthHIhFEo1HEYjG1xeNxtSUSCcRisR3F9/MuBl4kwnX3zJEXeyyQzhKc89BZpmg0imQyiYmJiVP1uhLETpCYJwjihXPY1fedeP/GIhq2h0yEodVu9vv/ArnbOhfYsHX8F//vf4UvmAuH9rhBpN98P+JcVqZPE8MF+ytGFU/8LGyYiDMfpqkry4cAQ8fXkDSA9+aKuJSb2nc12zAMaJqmmmOHox2DsY6777dQ1iJN0xCJRGBZVujr8GUnuYKbfPvtI7nuaYNzjm63i3a7Ddu2EYlEkEgkUCwWScATZw4S8wRBvBCOqvq+E47H8Z3bazC0vnzsdrsBkRyo0GsR3O/EkKs8hM7EtnxwmQkuvwJQVWHP89QmBbusHAcvk4OBgO355MHLXhSMsdBQot2+yu+TrgvmujAta9DAylQjK8CgMYZsLgfu+/jUp+bwH/zOb+GDj1dCFic9MAgsn+gPAjsMi9Nuk1+DC8dRQl9eJi093W4X3W5318ezLGtHoS8vOy4vf+ar7yHz1feO5bGPGyGEqsB3u11EIhGVBX/WeisIIgiJeYIgjgTpdZbVd9kgGY/HX8jjdxwPri9g6APRGaocB34QPnzBUGt1AacdEuKjvh6VpSU4WGgvkT1qOqicSBq8TFavgx5yednwYmXUkKMg8eUVaLU6LMuCxhiYpkHTGBgbWFoYg65pEINYRMbYoQ8CO+hxlQJbTosdhVxs7lbll959uThotVo73p+u63tW+U3TPLYq/37nBpzUBYIQQlXgu90uLMtCIpFAPp8/sQL+rBx74uRAYp4giENjuPouP1hf5GltmRHdbLbAuI+e44FpPjzfA/c5BETAbiPATQPMtfHw7jyY2NneIoWvFMPBiv3wz6Mu3+26uq4fygCkcRtGJVJwSj+5ZVmIRqOIRCLKUy795fn529DX1hCNRFQ1XlbmB9/B7vXAfB+1eg21W7f6TaO6jv94LoJfu34BPV8gGTURMQ3oWhfr671Q9X+n71+E4DUMQ02f3QkhxI6V/eBlcjKrrBLvhPwd2avKfxS/P/udG3CSBKUQArZtKwEvz/LlcrlTMeH5NB974mRy8t/1BEGcWOQoesdxIIQ4kuq74zgqPm6cr71eT902657Hhl+AEA4Y02AYWsgeIsDgahbOoYLXf/HnRwpyKbKDUzaDm5zEud9/3+12weme+8EwjG2ifKctGo3uSyhzwQHPA6vVtv1baE99H81mE0u3b+97/3diVFPrqO93sgnttVjY676CP8tjuxvyjNRuth6ZKCR/bjabO96fbODdq8q/7+NqmjAGg7l2Yq9G2heFFPBygWQYBhKJBLLZ7KkQ8MPsdOwFAIi+3e+kHHvi5HP6fgMIgjhWggkqsrKYTCbHqh5K8bIfcf48kzlf1Tfw1M+ixyxETRaK2RYCsGEiAh+vJ1soaIUDP85RISvk0WgU8Xgc0WhUbbFYTDUNByu4siosq/S7/bzb9Yb/rXn9+tj7vZ/rjoNsIj5JjcEHXSzouo5kMgkAoXSe4R4LmYa029me4Z/Had49bc2fsgLf6XSg6zoSiQSmpqZOpYAfxuccvrTwybQmz4MViSA1eI8QxDic/t8GgiB25DAmnu5UfU8kEsqvurm5ObY4f5GCrKB18FlzAT90z8GGCSYENAhwMAgwWPDxWXMBBW33hsfDgDEWikHc6+t+B2MdOX/rb+25OHiexcJ+b3Pcwv5FLS6kqJcxp6O+ygZreaZlWOwH/00K+ytPnyLturDb7X4PhNw0DebASnZcBCvwmqYhHo9jcnLyQGcfjhshBGq1GiqVCtbX11GpVJB/8AC5Xg/dra2Rtznu9zZx+iAxTxBnkINOPJV5zLVaDfV6HY1GA7ZtKzHf7XZDwrzT6Zz4EfGX9C2kmY37fhELfhYcGgz4OKfX8Kq+cWAhr+v6vsR5NBo9dVXRYaQ4PClV0aDoP+zFwri3Cf58FL8LMopzN4LV/Z0E/3Dj9mSziSTn6Nk2gGf2DjCGbDYDDS/2vRocfMYYQyKRQKlUgmVZL3Q/DooQAvV6PSTa19fXsbGxsS0e9YvdLrZPVniG7x//3AjidHEy/iITBHFohCeeArrG0LE5/scfPMU3f7yA//i1GN7IeiFR3mq10Gg00G63IYQIWQLOAgWti4K2gM8ai/CgwQCHzsLCSzbRSQG+lziPRCInOm/8ZUAuLk4K0g407pmFoz7TEVxcCCFCgj8ajQ4q8Qyc93s1GDCwor2Y97XjOOpvEADE4/ETL+CFEGg0GttEe6VS2XOmwfiPAficqvPE+JCYJ4gzxM2lOv7L//kj1LsOovDAwSFnZOoCqLsm/j8/6OCXrHvIipYSAPI0/EkfiLMfZMb0OFXzeDx+ogUEcTpgjJ2YsxYAVDP1sOhvtVqo/8VfgD95gmQi+azJu38j6Ee4QAo2tAshEI/HUSwWEYlEjuwxD4IQAs1mc5tgr1QqoSb7o8L3/Bd8boQ4zZycvzoEQTw3799YRMvliAhnW3GNMSAiXPRg4o6Tw+f0hopZPA3IqMRxxflJElUEcRwEh4PZto1utwvbtlXjtK3rMF7AItZ1XVWB55yrQU57pQK9CIQQaLVaI0W7PbAgHQeccxLzxNjQpx1BnBGCE09VcV3I2MD+fxkYGASWUYBmrEBjx+N3l82gOwnxE98MShCnAJkcNRzlqOs6lvSj/fh3XVdV4H3fRzweRz6fPzYBL0W7FOpB8X6col3XNei6AUMOezMM6LoGBgZvl7hSgghCYp4gzghy4qmpaQi2T/UtsM/K9JoQ4NDgQYOOw/FlymbQcavmsVjszNh5COIkEayCyySYcrk88kyVcN09s8zFPqJhPc9Tj+15nhrkJP35LwIhBNrt9khPe7d79KlVo2CMIZvNYmJiAqVSCZNPF6CvrMDiPhjngBtueJV/lfdz7ImXGxLzBHFGiFsGTJ2hG0zUGPH5ycFgwIeh3PTbGW4G3esrNYMSxPEhpx4HbSwTExO7puAk33577Pvf7brysdvtNjzPQywWQzabfSECfifRvtvU3aMmm82iVCop4T4xMYFisRjqyal3bbSy2bHubz+vE/HywsSYWVqLi4s4d+4cFhYWMDs7e9T7RRDEAfjd37+F928sgPVaoQFJjGnQGAOYhg7X8cWSwN97fbQHnZpBCeLkI2Nk5WA1+bv7IhpJfd9XAt5xHMRiMWWHOwoB3+l0Ql52+X273T70xxqXTCazTbRPTEzQ307iWKDKPEGcIb721iy+dWsVdZ5ALmFA13RoWt8pz7lAteOgEDXwf/mPPovXpnfOmycI4uQhh7S12230ej3EYjGkUqkXUgWXAr7T6YQeOx6PH9pjd7vdbVX29fX1YxftUqgHRftJS98hXm5IzBPEGeL1mQy+/pU5fOODedS6HgxNwNAZPF/A4wLpmIGvf3mOhDxBnBKEEGoiqkyikXGORy3gZfVfLh6i0SiSySQmJiaeKwXLtu2Ror3Vah3i3u+PdDodEuzyK4l24jRAYp4gzhjvvTGNS8UE3r+xiG/Pr8H1BWIWw7tz/QmwJOQJ4uSzWxLNURK079i2jWg0ikQicSABb9v2SE978xhTWlKp1EjRfhJiMgnioJBnniDOMI7H0XE8xC0DlkGpxQRxkhmVRPMiZiZwzpV9x7ZtRCIR1UMzjoDv9XojRXuj0TjS/d6NZDI50tMei8WObZ8I4qigyjxBnGEsQ4NlUEMWQZxUDpJEcxgE/ffdbheWZSGRSKBQKOxY/XccZ6Ror9frR7qvu5FIJEaK9ng8fmz7RBAvGhLzBEEQBPECGZVEk8vljtyfLQW8fGzpv8/n8yEBL0X78ETUWq12pPu3G/F4fJtoL5VKJNoJAiTmCYIgCOLICVbCHcdBNBp9IUk0soFWVuCD/nshBCqVCh49ehSqth+3aB/laU8kEse2TwRx0iExTxAEQRBHwKgkmkQi8UKSaIIV+OBCYnNzMyTax2ybO3RisdiOop0G0BHE/iAxTxAEQRCHSDCJRk5TfhFJNK1WCwsLC1hYWMDW1hZarRYajQYajcaxifZoNBryssvvk8kkiXaCOCRIzBMEQRDEczIqiaZcLo+dRLOf5CnP81CtVrG+vo6lpSUsLi5iZWUFtVoNpmnCNM0jXzgME4lERnraSbQTxNFDYp4gCIIgDsBhJNHcXKrj/RuL+M78GlwuYOoM71zvz4SYm0xiY2NjWzPq+vo6er0eXNcFAJimCcuykEqljuqpKiKRyLYqe6lUQiqVItFOEMcE5cwTBEEQxJj4vq/86DKJJh6PHyiJ5psfLeMbH8yjYXvQmIAGAdfj8LiAxXy8pT/FRa2qHtdxnJCAN03zyDLoLcsaKdrT6TSJdoI4YVBlniAIgiB24SiSaG4u1fGND+ZRbXRgCQdgAhyADkATgA0Tf+XPIIIaEs4WACj//WEKeNM0RzaiZjIZEu0EcUogMU8QBEEQQxx1Es37NxbRsD3ENB+cB06Qi/5/LNFDDxYeooQvxHvPLeBN00SxWNwm2rPZLIl2gjjlkJgnCIIgiAEvIonG8Ti+c3sNhsZgGjp6jgegP0wKAmCMQWMaNAArKILpqxio/D0xDGObPWZiYgK5XI5EO0GcUUjMEwRBEC81z5tEs186jgfXFzB0Bp09WyQwpoEFgmw0IcChwYMGHX7oPgzDQLFY3GaRyWaz0LTd03AIgjhbkJgnCIIgXjoOI4nmoMQtA6bOYDsCEfPZx/Bw4ZyDwWQcM+UiJkth0Z7L5Ui0EwQBgMQ8QRAE8ZIgk2ja7TY8z0M8HkculztQEs3zYBka3rlexvs3FsG0fmWegUHTdRi6Dt3QwTQdfk/gVz87i//0V/7OC90/giBOFyTmCYIgiDOLEEJV4GUSTTqdfq4kmsPga2/N4lu3VlHveshksjANHQz9/eFcoNpxkImb+LufPXds+0gQxOmAxDxBEARxpjjqJJrD4PWZDL7+lTl844N51LoeDI3D0Bk8X8DjAumYga9/eQ6vTWeOe1cJgjjhkJgnCIIgzgSjkmhOsrf8vTemcamYwPs3FvHt+TW4vkDMYnh3rj8BloQ8QRDjQBNgCYIgiFPLqCSawx6s9CJwPI6O4yFuGbCMk7n4IAjiZHK6/toRBEEQLz3DSTSJROKFJdEcFZahwTKs494NgiBOISTmCYIgiBPPSUmiIQiCOGmQmCcIgiBOJCc1iYYgCOIkQWKeIAiCODEEk2i63S4ikciJS6IhCII4SZCYJwiCII6d05ZEQxAEcVIgMU8QBEEcC6OSaMrl8qlLoiEIgjhO6C8mQRAE8cI4i0k0BEEQxwmJeYIgCOJIoSQagiCIo4PEPEEQBHHoUBINQRDEi4HEPEEQBHEoBJNobNuGZVmUREMQBHHEkJgnCIIgngtKoiEIgjg+SMwTBEEQ+4aSaAiCIE4G9FeXIAiCGAtKoiEIgjh5kJgnCIIgdoSSaAiCIE42JOYJgiCIEJREQxAEcXogMU8QBEFQEg1BEMQphcQ8QRDES0yv10O73Ua326UkGoIgiFMIiXmCIIiXDEqiIQiCODvQX26CIIiXAEqiIQiCOJuQmCcIgjijUBINQRDE2YfEPEEQxBmCkmgIgiBeLkjMEwRBnHKGk2gikQgl0RAEQbwkkJgnCII4pVASDUEQBEFiniAI4hThOI6y0VASDXHSqP/BN9H63vfGum7y7beR+ep7R7xHBHH2ob/+BEEQJxxKoiFOC63vfQ/Nb38bbI/3pnBdACAxTxCHAIl5giCII8bxODqOh7hlwDLGs8BQEg1xWmGmCWNiYtfreJXKC9obgjj7kJgnCII4Im4u1fH+jUV8Z34NLhcwdYZ3rpfxtbdm8fpMZtv1KYnm6BFCqK+7bZzzPa+z3+2w7/Mk7uPFx4+RcV20G43gQX/2LYBEIgF6NxPE4UFiniAI4gj4gx8v4b/64BO0eh50jcFgDB0u8C9/+BR/+NES/g8/M4W3X83A9/1QEo1pmohGo4hEImi1WmdC4J20fSSOjkKrhRTncAc2mlEIzknME8QhQmKeIAjikLm5VMc3/vA2NpsdROECDPAG/6YJYNMx8V9/6x7+7I9vIunWoOs6TNOEaZpUgScIgiD2BYl5giCIQ+b9G4to2i6icBHS5qL/H0v00EMEj9kkvpj0KEqSIAiCODD0CUIQBHGIOB7Hd26vwdCYEvJCAJwLcMEBADrToDGBZRQgmH6Me0sQLx4yOhHE4UJiniAI4hDpOB5cX8DQn/15ZQAYY9A0rW+jYYAGAQ4NHv0ZJgiCIJ4DstkQBEEcInHLgKkz2E6g/tjX7yE4GAz4MMBf6P4RB4cxtusmF2uHtR32/b2IfbSePIW+vo5kMhl+zzP1HxiGQdV5gjhESMwTBEEcIpah4Z3rZbx/YxEaNLCgkFe2GwZAw6uxFoqZ7JkXeCfxPg9yf8TeLOVyaHIOFoymHELg2dAogiCeHxLzBEEQh8zX3prFt26toskSyMctaNozIci5QLXjoBg18P/4jV/Aa9Pb8+YJ4rSSfPvtI7kuQRA7w8SYobuLi4s4d+4cFhYWMDs7e9T7RRAEcar55kfL+MYH82jYHgyNwdAZPF/A4wLpmIGvf3kO770xfdy7SRAEQZxyqDJPEARxBLz3xjQuFRN4/8Yivj2/BtcXiFkM7871J8BSRZ4gCII4DKgyTxAEccQ4HkfH8RC3DFgGpdcQBEEQhwdV5gmCII4Yy9BgGdZx7wZBEARxBqESEUEQBEEQBEGcUkjMEwRBEARBEMQphcQ8QRAEQRAEQZxSSMwTBEEQBEEQxCmFxDxBEARBEARBnFJIzBMEQRAEQRDEKYXEPEEQBEEQBEGcUkjMEwRBEARBEMQphcQ8QRAEQRAEQZxSSMwTBEEQBEEQxCmFxDxBEARBEARBnFKMca/oeR4AYGVl5ch2hiAIgiAIgiCIPpOTkzCM3eX62GK+UqkAAD7/+c8/314RBEEQBEEQBLEnCwsLmJ2d3fU6TAghxrkz27bx8ccfY2JiYs8VAkEQBEEQBEEQz8c4lfmxxTxBEARBEARBECcLaoAlCIIgCIIgiFMKiXmCIAiCIAiCOKWQmCcIgiAIgiCIUwqJeYIgCIIgCII4pZCYJwiCIAiCIIhTCol5giAIgiAIgjilkJgnCIIgCIIgiFMKiXmCIAiCIAiCOKX8/wGUNOyDI642RgAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(7.0, 4.6))\n", + "xn, yn = np.array(x_toy), np.array(y_toy)\n", + "P_np = np.array(out.matrix)\n", + "\n", + "# Top-mass arrows: keep only the 5% strongest entries to avoid clutter.\n", + "threshold_p = np.percentile(P_np, 95)\n", + "for i in range(n_toy):\n", + " for j in range(m_toy):\n", + " if P_np[i, j] > threshold_p:\n", + " w = 4 * P_np[i, j] / P_np.max()\n", + " ax.plot(\n", + " [xn[i, 0], yn[j, 0]],\n", + " [xn[i, 1], yn[j, 1]],\n", + " \"k-\",\n", + " alpha=min(P_np[i, j] / P_np.max() * 0.7, 0.5),\n", + " linewidth=w,\n", + " zorder=2,\n", + " )\n", + "ax.scatter(\n", + " xn[:, 0],\n", + " xn[:, 1],\n", + " c=\"C0\",\n", + " s=40,\n", + " alpha=0.85,\n", + " label=r\"source $\\mu$\",\n", + " zorder=3,\n", + ")\n", + "ax.scatter(\n", + " yn[:, 0],\n", + " yn[:, 1],\n", + " c=\"C3\",\n", + " s=40,\n", + " alpha=0.85,\n", + " marker=\"s\",\n", + " label=r\"target $\\nu$\",\n", + " zorder=3,\n", + ")\n", + "ax.set_title(\"Standard Sinkhorn transport between two point clouds\")\n", + "ax.legend(loc=\"upper left\", fontsize=9)\n", + "ax.set_aspect(\"equal\")\n", + "ax.set_xticks([])\n", + "ax.set_yticks([])\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "255b7e20", + "metadata": {}, + "source": [ + "## 2. Adding linear constraints\n", + "\n", + "In several applications the transport plan should respect **further linear conditions** beyond having the right marginals. Examples include:\n", + "\n", + "- a **budget cap** on a secondary cost (e.g. minimise $L^1$ transport while keeping $L^2$ transport bounded),\n", + "- **capacity constraints** $p_{ij} \\le c_{ij}$ on individual cells,\n", + "- **moment constraints** in martingale optimal transport,\n", + "- **fairness or coverage** conditions in ranking / assignment / recommendation tasks.\n", + "\n", + "All of these can be written as a finite set of linear conditions on $P$. Following the paper, we encode them in matrices $D_1, \\ldots, D_{K+L} \\in \\mathbb{R}^{n \\times m}$ representing $K$ inequalities $D_k \\cdot P \\ge 0$ and $L$ equalities $D_l \\cdot P = 0$ (where $\\cdot$ is the entry-wise inner product). After translation by the threshold, any constraint of the form $D \\cdot P \\le t$ or $D \\cdot P = t$ can be put in this homogeneous form. The constrained linear program becomes\n", + "$$\\min_{P \\in U(a,b)} \\langle P, C \\rangle \\quad\\text{s.t.}\\quad D_k \\cdot P \\ge 0 \\ \\forall k, \\quad D_l \\cdot P = 0 \\ \\forall l.$$\n", + "\n", + "Geometrically, vanilla Sinkhorn iterates inside the marginal polytope $U(a,b)$; the constrained version iterates inside its intersection with the constraint polyhedron. The cartoon below : a hexagon for $U(a,b)$, a half-plane for an inequality, a hyperplane for an equality : captures the idea (the actual polytope lives in dimension $(n-1)(m-1)$)." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "335172ba", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-02T06:39:02.698455Z", + "iopub.status.busy": "2026-05-02T06:39:02.697819Z", + "iopub.status.idle": "2026-05-02T06:39:02.804020Z", + "shell.execute_reply": "2026-05-02T06:39:02.803081Z" + } + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAArEAAAG/CAYAAABPOI/AAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAQ6wAAEOsBUJTofAAAlApJREFUeJzs3XecE2X+wPHPTHqynV2W3jtYEEEEFRDE3rFy2PvZsJx6dk/REwtY7lD8KdZTT1H0FPQsYAWVBdRDQDoodWFLNnUyz++PsJGw2V6y2f2+Xy9ebCaTmSfJZPLNM9/n+2hKKYUQQgghhBApRE92A4QQQgghhKgtCWKFEEIIIUTKkSBWCCGEEEKkHAlihRBCCCFEypEgVgghhBBCpBwJYoUQQgghRMqRIFYIIYQQQqQcCWKFEEIIIUTKkSBWCCGEEEKkHAlihRBCCCFEypEgVtTYPffcg6ZpzJ8/P9lNaTbWr1+PpmlccMEFyW6KqKFkvGejR49G07RaPWbWrFlomsY999xT7/1X9pwvuOACNE1j/fr19d6HaBlS9Zz2wAMPYLfbWbt2baPvq/y7cNasWY2+r+aqIY6Thx9+GJvNxi+//FLnbdQqiNU0rVb/WvMbXBflB8Xo0aOT3ZRGM3/+/JQ8QTaFbt261TrQaSk0TaNbt27JbkaTasggVbRcqXxeaKq2//777zz44INcfPHF9OjRo9H31xo0RTxyzTXXkJeXxw033FDnbVhrs/Ldd99dYdmsWbPYsGED559/foUvoQMPPLDODRMiFXTs2JFffvmFzMzMZDdFNGMvvfQSPp8v2c2o4MEHH+TWW2+lY8eOyW6KaCZS8Zx2//334/P5uOWWW5LdlFajIY4Tl8vFddddx6233sqXX37J4YcfXutt1CqITdRjMH/+fDZs2MAFF1zQonsQhUjEZrPRr1+/ZDdDNHNdunRJdhMSat++Pe3bt092M0QzkmrntOLiYl566SVGjx7d6q7mJFNDHSeTJk3i9ttv5+mnn65TENtoObHlOWBr165l+vTpDBo0CKfTySmnnAJED7ypU6dy5JFH0qlTJ+x2O3l5eZx00kksXLgw4TbLLzn6fD5uvvlmunTpgsPhoFevXvz9739HKVXhMe+//z7jxo2jQ4cOOBwO8vPzOeSQQ5gyZUrceuW5YfPnz+eFF17gwAMPxOVykZ+fzyWXXMK2bdsStmnt2rVceOGFseeQn5/PmWeeyY8//lhh3b0vHy5cuJBjjz2W7OxsNE1j2rRpdO/eHYAFCxbEpWXU5HLj3q/3I488Qt++fXE6nXTu3Jkbb7yR0tLShI9bunQpZ555Jvn5+djtdjp37szFF1/MunXrqt1nYWEhLpeLnj17JnztAf70pz+haRofffQR99xzD2PGjAHgxRdfrDT1RCnFc889x/Dhw0lPT8ftdjN48GAeffRRwuFwhX2UX7IKBoP89a9/pVu3brHj4m9/+xuhUChh21avXs0ll1xC165dcTgc5OXlceqpp1JQUFDtcy9XXa7h/Pnzeeuttxg2bBhut5ucnBzOPvtsfvvttwrb2LBhAxCftrPvD8OtW7dy/fXX07t3b5xOJ9nZ2YwfP55PP/20QtuqOt6KiooA+PLLLznppJPo3Lkzdrud3NxcBg8ezOTJkyu8p6Zp8txzzzFy5EgyMzNxOp0MGjSIBx98sNLXeOXKlVxyySV079499hqPGDGCxx57DPgjvQRgw4YNcc9939e0tu/Xjh07uOKKK2jfvj1Op5MBAwbwxBNPVHqsVtZ+TdM466yz4pZv37491s533nkn7r4nnngCTdN44YUXYsv2zYm94IILuPDCCwG499574553opzzpUuXcvzxx5OVlYXb7WbUqFF88803NX4elUmUE7v3ZcSdO3dy2WWX0b59exwOBwMHDox7Xvv6/PPPOemkk8jLy8Nut9O1a1euuuoqtm7dWmHdxYsXc91113HAAQeQk5OD0+mkd+/e3HjjjezevbvC+jU5nqvi9/uZOnUqQ4cOJT09HY/HQ9++fbnqqqvYuHFj3LqlpaXccccd9OvXD6fTSVZWFkceeSRz5sypsN26vF5KKV566SVGjhxJ27ZtcTgctG/fnlGjRvHss8/Gbbe688Le57977rmH3r17Y7fbuf7664Hopfb77ruPkSNH0q5dO+x2Ox06dODcc89NmIvYlOe0H3/8kXPPPZfu3bvjdDrJyclh4MCBXHnllRQXF1f+Zu7l9ddfp6ysrMJntFxNYwCo3TFS7vPPP2f06NGkp6eTkZHB8ccfX2mOZyAQ4JFHHmHIkCGkpaXh8Xg4+OCDmTFjRoXz0t7H1bZt27jooovIz8/H4/EwYsQIvvzySwC8Xi833HBDLB4aOHAg//73vyvsuzYx16xZs6qNR6rKia3N69ihQwcOP/xwZs+enfBzX51a9cTWxbXXXsvXX3/N8ccfz/HHH096ejoAv/zyC7fffjtHHHEExx9/PNnZ2WzcuJH33nuPuXPn8v7773PMMcdU2F44HGb8+PH8/vvvHHvssVitVt59911uvfVWAoFAXMrDs88+y+WXX05+fj4nnHACbdu2ZefOnSxfvpwZM2bw17/+tcL2H3vsMT755BPOOussjj32WL744gv+7//+j/nz57No0SLatGkTW7egoICxY8dSVFTE8ccfz/7778+aNWuYPXs277//PnPmzGH8+PEV9vHNN98wZcoUjjjiCC655BK2bNnCkCFDuO6665g+fTpdu3aNOzBq08N9/fXX89VXX3HmmWeSmZnJ3Llzeeyxx/jqq6/44osvcDgcsXXnzp3LqaeeSiQS4bTTTqNnz578+OOPPP/887zzzjt89tlnVaaEtGnThrPOOosXX3yRjz/+mKOPPjru/sLCQt566y169OjB+PHjcTgcrF+/nhdffJEDDjgg9oMG4lNPzj//fF5++WU6duzIhRdeiM1m4/333+emm27i448/5oMPPsBqrXjonnHGGSxevJjTTz89dlzcddddLF68mHfffTdu3c8++4yTTz6ZQCDACSecQO/evfntt9+YPXs2c+fOZc6cORWeT1384x//4L333uOkk05i1KhRLFq0iDfeeINly5axdOlSHA4HWVlZ3H333UybNo3i4uK4Y3jvnoWffvqJo446iu3btzN+/HhOPvlkCgsLeffddznqqKN47rnnuOiiiyq0IdHxZrFYmDdvXuwzedJJJ9GpUyd2797Nr7/+ylNPPcXUqVNjr7NhGJx22mm8//779OnTh3POOQen08mCBQv461//yqeffsq8efPi3pcPPviAM844g0AgwNFHH83ZZ59NcXExP//8M/fffz833HAD3bp14+677+bee+8lMzMz9sUL8cdEbd+vXbt2MWLECFavXs2wYcM4//zzKSws5M477+Tzzz+v8fvXt29fOnXqxGeffYZSKhaI7v2j4ZNPPuHUU0+N3S6/b+zYsZVu95RTTqGoqIg5c+YwatSoCkHJ3n744QcefvhhDj30UC655BI2btzI22+/zdixY1m6dCl9+/at8fOpjaKiIkaOHIndbmfChAkEAgHeeustLrroInRd5/zzz49b/+9//zu33norOTk5HH/88bRr144ff/yRf/7zn7z33nssXLiQTp06xdafOXMm77zzDqNGjWLcuHFEIhEKCgp47LHHmDt3LosWLYp9X+ytsuO5Krt37+bII49k6dKl9OnThwsvvBCn08natWt59dVXGT9+fKy3vLi4mMMOO4yff/6ZwYMHc91111FUVMS///1vTjnlFO69917uuuuuer1et99+Ow8++CDdunVjwoQJZGVlsXXrVpYtW8bLL7/MZZddVuPzQrnTTz+dgoICjjnmGE499dRYEPLFF1/w0EMPMWbMGE4//XTS0tL49ddfeeutt3jvvff4+uuvOeCAA6p8/fbWUOe0H3/8kUMOOQRN0zjhhBPo2bMnXq+XdevW8eKLL3LjjTfW6FL1f//7XwAOO+ywCvfVJgaozTFS7j//+Q9z5szh2GOP5YorrmD58uV8+OGHfP/99yxfvpzc3NzYuqWlpYwbN47vvvuOwYMHx77jP/roI6688koWLlyYcBxR+XGVk5PDxIkT2bx5M2+99RZHH3003377LZdeeiler5dTTjmFkpIS/vWvf3HWWWfRuXNnhg8fHttObWKuAw88sM7xSF1ex5EjRzJ//nwWLFgQFxfUiKqnUaNGKUB9/vnnCZd36NBBrVu3rsLjioqK1I4dOyos37Rpk2rfvr3q169fhfsABahjjz1W+Xy+2PJt27apzMxMlZmZqUKhUGz5QQcdpOx2u9q6dWuFbe277/PPP18BymazqYKCgrj7rr76agWoyy67LLbMNE01YMAABahZs2bFrf/f//5XaZqm8vLyVFlZWWz5Cy+8EHsOzzzzTIU2rVu3TgFq1KhRFe6rTvnr3aZNG7Vhw4bYcsMw1Mknn6wANWXKlNhyr9ercnNzlaZp6rPPPovb1nPPPacANWjQIGWaZmz53XffXeG9/u677xSgTjnllApteuSRRxSgHnroodiyzz//XAHq/PPPT/g8Xn/9dQWo/fffXxUXF8eWB4NBdeSRRypATZ06Ne4xXbt2VYDq3bu32rVrV2y5z+dTQ4cOVYB67bXXYsuLiopUmzZtVE5Ojvrf//4Xt63ly5ertLQ01b59exUIBBK2cW/l79m+z6f8eEpPT1c//vhj3H3nnHOOAtQbb7yR8HkkYhiG6tOnj3I4HGr+/Plx9/3++++qU6dOyuVyqW3btsWWV3e8nXbaaQpQS5YsqXDfzp07427/7W9/U4D685//rAzDiC2PRCLq0ksvVYB64oknYst37NihMjIylNVqVZ9++mmF7W/cuDHuNqC6du2a8LnX5f26/PLLFaCuuOKKuPVXr16tMjMzqzwG91X+Xu79Ol188cXK7XarkSNHqr59+8aWG4ahMjMzVa9eveK2Uf753Fv5+3P33Xcn3O/e798LL7wQd9+MGTMUoK688soaPYfqjtO9z9Hl6wLq4osvjnu///e//ymLxaL69+8ft50FCxYoTdPU8OHD1e7du+Pue+mllxSgTjvttLjl69evj9t2ufLzz97nDaWqP56rUv6Zu+SSS1QkEom7r6ysTBUWFsZuX3HFFQpQF154Ydz5b9OmTapdu3ZK0zT13XffxZbX5fXKyclRHTp0UF6vt0Jb9/1uquq8sPf9++23X8Lv1G3btqmSkpIKy5cuXao8Ho865phj4pY31TnthhtuUIB65513KtxXUlJSo/OvUkrl5+crt9td4X1VqnYxQG2OkfLvQovFoj755JO4dW+99VYFqL///e9xyy+++OKEx3UgEFDHHXecAtR7770XW773cXXttdfGHYsPPPCAAlRWVpY67bTTVDAYjN336quvJvxOrm3MVV08UtlxUpvXsdy7776rADV58uSE+6pKowex06ZNq/U2r7nmGgXEBWNK/RHE/vrrrxUec9555ylA/fTTT7FlBx10kHK73QlftH2Vf0AvuuiiCvcVFhYqj8ej3G53LEj+6quvFKCGDh2acHvlAcLewVP5SfjAAw9M+JiGCGLvu+++CvetWLFCaZoW98X6yiuvKECdccYZCbc3ZMgQBahvvvkmtixREKuUUgcffLCyWq1q8+bNccv79Omj7Ha72r59e2xZdUHsuHHjFKA++OCDCvctW7ZMAapPnz5xy8tPlC+99FKFx3z00UcKUOPGjYste+KJJxSgpk+fnrANkydPrrQN+6ruhH/77bdXeMxnn32mAHXjjTcmfB6JvPfee1V+yKdPn64A9fTTT8eWVXe8lR+jK1asqOopqkgkonJzc1Xbtm1VOByucH9RUZHSNC3us1D+A+aqq66qctvlqgpia/t+hUIh5fF4lMfjqRCMK6XUXXfdVasgtjwI2/vHU7du3dT48eNjXyabNm1SSim1cOFCBajLL788bhv1CWJHjhxZ4b5QKKSsVqsaMmRIjZ5DXYJYt9sd90Oy3BFHHKGAuMCo/FhatmxZwv0PHjxYWSyWhMHUvkzTVBkZGWrMmDFxy6s7niuzbds2peu6ys/PV6WlpVWuGwqFlNvtVm63O+68Va78WNy7Q6Mur1dOTo7q1q1bjQK1mgax7777brXb2teJJ56oHA5HXOdPU53TyoPYefPm1brd5YLBoAJU9+7dE95f0xigNseIUn98F06cOLHCfWvXrlWAOv3002PLCgsLldVqVYMHD064vfLvtr2/j8vfB4/HU+HHzsaNG2Px0L6dhIZhKJvNprp161bt8yiXKOaqSxBb29exXPl5c+/XrKYaPZ1g2LBhld739ddfM336dL799lu2b99eIa/ut99+q9DtnJmZSa9evSpsq3PnzgBxORWTJk1i8uTJDBgwgLPOOosjjjiCESNGVDmQYdSoURWW5eTksN9++7Fw4UJWrlzJoEGDYnl4Rx55ZMLtjBs3jtmzZ1NQUMA555wTd19Vr0l9JWp/3759yc/PZ/Xq1ZSWlpKenl5t+8eOHcvixYspKCjg0EMPrXKff/7zn7nwwgt57rnnYpeNPv/8c1atWsU555xDXl5ejdtf3q7y3Nm97b///rRt25ZVq1bh9XpJS0uLuz/Rcz/iiCPQNI0lS5bEln399ddA9HJWonzjlStXAtHLL8cdd1yN257IwQcfXGFZomO1OuVt3rhxY8I2//rrrwAJc7EqO94mTZrE7NmzOeSQQzjzzDMZM2YMhx56aIVLlatWrWLnzp307NmT+++/P+G2XC5X3L7Lc6yOPfbYap9bdWr7fq1YsYKysjKGDx8el/5TbvTo0dx333013v+4ceOAaNrATTfdxNq1a1m/fj1XXXUVRxxxBLfffjuffPIJF1xwQSyVoPwxDSHRMWSz2cjPz69TDllN9e7dm4yMjArLy4/foqKi2OX+r7/+GqvVyuzZs5k9e3aFxwSDQSKRCKtWrWLIkCFANDXsmWee4fXXX+d///sfJSUlmKYZe8zeOZZ7q+358/vvv8c0TQ477LAK54x9rVixAp/PxyGHHJLwvFX+vibKw67N6zVp0iSmT59O//79OeOMMzj88MMZMWIEOTk5tXpue6vqdfnggw/45z//yeLFi9m5cyeGYcTdv3PnzhoP8Guoc9o555zDE088wSmnnMLpp5/O2LFjOfTQQ2s1WKiwsBCA7OzshPfXNAaozTGyt5q+Ft999x2GYaDresJzWPlYj0Tn7969e+PxeOKWlbc/KyurwvnaYrHQtm1bNm/eXGFbdYm5aqOur2P5cb9z585a77PRg9h27dolXP7OO+8wYcIEnE4nRx11FD179sTj8aDreiw3IhgMVnhcVlZWwu2V5+JFIpHYsuuvv562bdvyj3/8g6effponnngCgOHDh/Pggw8mzO3Iz89PuP3y5eXJ5uX/V/b8yg+yRAMOKntMQ6iq/Vu3bqWkpIT09PR6tX9fZ599NjfddBPPPfccd9xxBxaLhRkzZgBwxRVX1Kr9xcXFZGZm4nK5Km3X9u3bKS4urvAhSfTcnU4nGRkZcYMEyk98//d//1dlW7xeb63ankii4zXRsVqd8ja//fbbvP3225Wul6jNlb3Hp5xyCvPmzeORRx7hxRdfZObMmQAMGjSIe+65h9NPPz1u32vWrOHee++tUXvLj5uGKN1U2/erpsd2TbVv357+/fvz5ZdfEgqF4gLV/fffn8zMTD799NNYEKtpWsIfYXVV1TmvNsdQQ+4X4o/fwsJCDMOo9vjY+/g866yzeOedd+jRowennHIK7dq1i+XsT5s2LeH5H2p//qzNsVif82JtXq9HH32UXr168fzzzzN16lQefvhhdF1n7NixTJ06tVY5quUqa/MTTzzBddddR3Z2NkcddRRdunTB7XajaRrvvvsuy5Ytq/S1TqShzmkHH3wwX3/9NQ888ADvvPMOr776KhDNmb3llltq9N1R/j0RCAQS3l/TGKCu56uavhbl57DFixezePHiSreX6PydKC+4fB+V5QxbrdYKP1TqGnPVRl1fR7/fD1Dp935VGj2IrazQ8Z133ondbueHH36gf//+cfddfvnlLFiwoEH2f+6553LuuedSUlLCt99+y/vvv8/MmTM59thjWbZsGX369Ilbv7IqBOXLyw+a8v8TjboF2LJlS9x6e2vM4s/btm1LONCjvP3lPQX1af++nE4nF110EVOnTuWDDz5g+PDhvPvuuwwYMIAjjjiiVu3PzMxk165d+P3+hAd0Ve3atm1bhV+RgUCAkpKSuB6O8scuXryYgw46qFbtS5byNr/99tucdtpptXpsVcfb0UcfzdFHH43f7+e7775j7ty5/OMf/+CMM87g888/Z9SoUbF9n3jiibz33ns12mf5yf23335j8ODBtWrvvmr7ftX02K6NsWPH8tRTT/Htt9/yySefkJuby4EHHoimaYwaNYpPP/0Uv9/PN998w4EHHpiwB7gly8zMJBQKUVJSUqP1f/jhB9555x3Gjh3L3LlzsdlssftM0+Thhx+u9LG1PX/ufSxWpyHPi1WxWCxcffXVXH311RQWFvLVV18xe/ZsXn75ZY466ih++eWXWh9DiV4XwzC4++67adeuHQUFBRV+wH377bf1eh71NWzYMObMmUMoFGLJkiV8/PHHPPXUU1x55ZV4PB4mTZpU5eOzsrJwOByxIDGRmsQAtTlG6qL8eLnmmmtigXRTa4qYq66vY/n717Zt21rvM2nTzq5evZoBAwZUeDFN0+Srr75q8P1lZGRw9NFH89RTT3HjjTcSCASYN29ehfUSvZG7d+/mp59+wu12xwLE8i/TykY6l/fWlF86q4nyEbb16V1J1P6VK1eybds2evXqFbucVV37P/vsM6Dm7b/yyivRdZ0ZM2bwwgsvEAqFEv6Sru45lrcrUZmhn3/+me3bt9OnT5+ElyoSPfcvvvgCpVRcIFWeHlFeoqS5qOq1aew2u1wuRo0axUMPPcQjjzyCUipWTqhfv35kZWWxaNGiSktp7at8VOzcuXNrtL6u65UeE7V97v369cPtdvPTTz+xa9euCvfXZdrk8koDn3zyCZ9//jlHHnlkLGgYO3YsW7Zs4dlnnyUQCFRZlWBvDfF5by4OPfRQSktLWbZsWY3WX716NQAnn3xyXAAL0Uuv5T0zDWHYsGHous5XX31V7dWVvY+dRJc263Jer06bNm04+eSTefHFFzn77LPZsWNHLIUG6nec7Ny5k6KiooSX0L1eb63KCdZFTdtut9s55JBDuPPOO3n55ZcBKpSuq8z+++/P9u3bq01lqCoGqM0xUheHHHIIuq4n9TuntjFXXY67ur6OK1asAOo2QVbSgthu3brx66+/8vvvv8eWKaW49957Wb58eYPsY+7cuQnripb/yk7U0/fyyy/H5U8C3HXXXZSVlTFx4sTYCXfEiBH079+f7777jldeeSVu/c8++4zZs2eTm5vLySefXOP2ltc83LRpU40fs6/p06fH1WGLRCLccsstKKVidSkheim5TZs2vPXWW3zxxRdx25g1axY//PADAwcOjCvRUZXu3btzzDHH8NFHHzF9+nTcbjfnnXdehfXKexcqq7l38cUXA/DXv/417kMQDodjU9NdcsklCR/7t7/9Le4yn9/v54477gCIe+4XXngh2dnZ/O1vf0vYE6GU4quvvqpxwNZQqnptTjrpJHr16sWMGTMq7Q1dsmRJlT0S+/rkk08SziK17+fDarVy3XXXsX37dv785z8nfMzOnTtZunRp7Pb5559PRkYGzz77bMIfSvvma7Vp04YdO3YkDF5q+37ZbDYmTZpEWVkZt99+e9y6a9asYfr06RW2UZ0xY8ZgsVh49tln2bFjR1zOa3nQWl53sqb5sNV9FlJJ+WfzsssuS5iLFwgE4r4oy/P49v1BUX6MNaS8vDzOPvtstm3bxg033BCXdwvR80T5j53yY8fn83HbbbfF1e4sn9pU07SEpexqKhgM8sknn1Roh1KK7du3A/HfTfU5Ttq2bYvb7Wbx4sUVzqfXXXddnXIQa6Oqtn/55ZcJ0zKq+n5OZMyYMSilWLRoUYX7ahoD1OYYqYu8vDwmTZrE0qVLueeeeypc6ofoObE8mGsMtY256hKP1PV1LB9DUZc0rEZPJ6jM5MmTueKKKxg8eDCnn346NpuNr7/+muXLl3PiiSfy/vvv13sf55xzDna7ncMPPzxWEPq7777jyy+/pGfPnpx55pkVHnPMMccwcuRIzjrrLNq1a8cXX3zBN998Q48ePeKKI2uaxosvvsi4ceM477zzePPNN9lvv/1Ys2YNb7/9Nna7nZdeegm3213j9qalpXHooYfyzTffcOKJJzJkyBCsVitHHHFEjS/LjxgxggMPPDCuTuxPP/3E0KFDufHGG2PreTweZs2axemnn864ceM4/fTT6dGjBz/++CMffPABWVlZvPTSS7W6dPfnP/+ZDz/8kC1btnDRRRclvOTWt29fOnfuzJdffsnEiRPp06cPFouFk046if3335+zzz6b999/n9dee40BAwZw6qmnxurErlq1irFjx8bVEt1bv379GDhwIBMmTIjViV27di0nn3xy3OC6nJwc3n77bU455RRGjBjBkUceycCBA7HZbGzatIlFixaxceNGdu/ejd1ur/Hzr6+jjjqK77//ntNOO43jjjsOl8tF165dmTRpEjabjXfeeSdWH/aQQw7hoIMOIi0tjU2bNrFkyRJWrlzJkiVLanwZ8qabbmLdunWxmW6cTic//vgjH330EW3atOGyyy6LrXvHHXfw008/8dxzz/HBBx8wduxYOnXqxI4dO1izZg1fffUVf/7zn5k2bRoAubm5vPbaa0yYMIFx48Zx9NFHc8ABB+D1evn5559ZtmxZ3MnsqKOO4rXXXuOYY47hiCOOwOFwcMABB3DiiSfW6f2aMmUKn376KTNmzGDJkiWMGTOGnTt38uabb3LkkUdWqBtcnczMTIYMGcJ3330HxAeqAwcOpF27dmzdujV2vqmJESNG4Ha7ef3117Hb7XTp0gVN05g0aRJdu3atVfuSbcyYMUydOpVbbrmF3r17c9xxx9GjRw/8fj8bN27kiy++oFu3brEfOkOHDmXkyJHMnj2bESNGcNhhh7Ft2zbmzp1L37596dChQ4O276mnnuJ///sfM2fOZMGCBRxzzDE4nU42bNjARx99xAsvvBCrT/nQQw/x5Zdf8txzz7FkyRLGjRsXqxO7a9cu7rrrLg455JA6t8Xv98dyU4cPH07Xrl0Jh8PMnz+fpUuXMnz48Lgv86rOC9XRdZ1rr72Whx56iP3224+TTz6ZUCjE559/zq5duxgzZkyt6ibXVlVtf/TRR/noo48YPXo0PXr0ICMjg1WrVvGf//wHl8tV6Xl+XxMmTODhhx/mo48+qlBbvjYxQG2Okbp48skn+fXXX7n33nt5+eWXOeKII2LnjZUrV7Jw4UIee+yxRpstrbYxV13jkdq+jqZp8sknn9CrV6865YI3eomtRDViy73wwgvqgAMOUG63W7Vp00adcsop6scff6y0lBNVlOFJ9Jh//vOf6tRTT1U9evRQbrdbZWZmqv3220/dfffdFUrvlJcP+fzzz9Xzzz+v9t9/f+V0OlVeXp666KKLEtaZU0qpX3/9VZ1//vmqQ4cOymazqby8PDVhwoSEtTerK6mjlFJr1qxRp5xyimrTpo3Sdb3a9cuVv95r1qxRDz/8cKy8VceOHdXkyZMrLW2zePFidfrpp6u8vDxltVpVhw4d1AUXXKDWrFlTYd3K3pdykUhEdejQQQHq+++/r7StixcvVuPGjVOZmZlK0zTFPnUwI5GImjFjhho6dKhyu93K6XSq/fffXz388MNx9fDKlZdxCQQC6rbbblNdu3ZVdrtdde/eXd17770JH6OUUhs2bFDXXnut6tOnj3I6nSotLU317t1bnXnmmerVV19NWHdwX9WVo0n0WlX2mLKyMnX11Verzp07K6vVmrC8yY4dO9Ttt9+u9ttvP+V2u5XL5VI9evRQJ554opo5c2Zc/eTqjrc33nhDnXPOOap3794qLS1NpaWlqX79+qnJkydXqOOqVLT00WuvvaaOOuoolZOTo2w2m2rXrp065JBD1J133qlWrVpV4THLly9X559/vurYsaOy2WwqNzdXjRgxQj3++ONx623fvl1NmjRJtWvXLnbc7/v61Pb92r59u7rssstUfn6+cjgcqn///mratGmxEjg1LbFV7rbbblOA6tGjR4X7zj33XAWoI444IuFjE5XYUkqpjz/+WI0cOVKlpaXFSuaUHzPVvX9du3at9Hy4r7qU2KqstE6ix5T79ttv1dlnnx17v3NyctSgQYPUlVdeqRYsWBC3bmFhobryyitV165dlcPhUD169FC33XabKisrS/jcanL+rEpZWZmaMmWKOuCAA5TL5VIej0f17dtXXXXVVRWO96KiInXbbbfFzqPp6elq1KhR6u23366w3dq+XqFQSD388MPq2GOPVV26dFFOp1Pl5OSoIUOGqEcffbRCOaXqzgvVleAKh8Pq0UcfVf3791dOp1Pl5+erP/3pT2r9+vVVvv+NfU776KOP1IUXXqgGDBigMjMzlcvlUr169VKXXHKJ+uWXXyp9PokMHTpU5eXlxZUKU6p2MUB5e2tyjJR/F+5bv7lcZcdDKBRS//znP9Vhhx2mMjMzld1uV506dVKHH364evDBB+PKVFZ3XFUVD1V2TNQ25qoqHqnsPVeqdp+1efPmKUA98sgjCZ9LdTSlajEHYwt2wQUX8OKLL8amkEs1o0ePZsGCBaxbty5p80dv3ryZbt26ceCBB/LDDz802X67devGhg0bajWdqBBCiJbhrbfe4owzzuBf//oXZ599drKbI2rh1FNPZcGCBaxZs6bSUmlVSVpOrGh5HnnkESKRCNdcc02ymyKEEKKVmDBhAocffjj33ntvixgo2VqUTwl/zz331CmAhSTmxIqWYePGjbz22musXr2aF154gf79+zNx4sRkN0sIIUQrMmPGDN588002b96ccjnlrdXWrVu5//77ueqqq+q8DQliRb2sXbuW2267DbfbzZgxY/jHP/4RK8QshBBCNIUBAwYknA1LNF/HH388xx9/fL22ITmxQgghhBAi5UhOrBBCCCGESDkSxAohhBBCiJQjQWwKMwyDzZs3J5z9QwghhBCiJZMgNoVt3bqVzp07x6bQE0IIIYRoLSSIFUIIIYQQKUeCWCGEEEIIkXIkiBVCCCGEEClHglghhBBCCJFyJIgVQgghhBApR+YHbcFM02Tbtm34fD4ikUiymyNEHJvNRm5uLunp6cluihBCiBQkPbEtlGmarF+/nq1btxIKhZLdHCEq8Hq9rF+/nnA4nOymCCGESEHSE9tCbdu2jeLiYjp06EB+fn6ymyNEBT6fj1WrVrFt2zY6deqU7OYIIYRIMdIT20L5fD6cTqcEsKLZcrvdOJ1OAoFAspsihBAiBUkQ20JFIhEsFkuymyFElXRdxzTNZDdDCCFECpIgVgghhBBCpBwJYoUQQgghRMqRIFYIIYQQQqQcqU7QyiilkrJfTdOSsl8hhBBCtEwSxLYiSinW7SxLyr6753okkBVCCCFEg5EgthUKhE1CRtPM4GW3WnDaGi9r5ZdffmHcuHH8+uuvuN3uem/v+uuvZ/Pmzbz11lsJ7584cSJDhgzhhhtuqPe+hBBCCFF3EsS2QiEjQrG/aWZJynRR6yA2EomQkZHBtGnTuPTSS+PuO/PMMwkEArz33nsA/PWvf+X6669vkAAWYNmyZRx55JGV3n/nnXcycuRILrnkEjIyMhpkn0IIIYSoPQliW7GO2S4a6wK/An7b7a/TY1esWIHP52PIkCEV7isoKOBPf/oTABs3buTDDz9kxowZ9WlqnB9//JHrr7++0vv79etHnz59eO2117jiiisabL9CCCGEqB2pTtCKaUQHXDXKv3q0a8mSJdjtdgYNGhS3vLi4mLVr13LQQQcB8OabbzJkyJAKs5LdcccdDBw4ELfbTbt27bj55ptrNKBt48aN7Nq1i2AwyIgRI/B4PBx22GH8+uuvceudcMIJvPHGG/V4hkIIIYSoLwliRbNTUFDA/vvvj91ur7BcKRXrof3yyy85+OCD49Ypn6nsueee45dffuHZZ59l5syZlea47m3ZsmUATJs2jUcffZSFCxcSDoe58MIL49Y7+OCDY/cJIYQQIjkknUA0OwUFBQlTCRYvXkzbtm3p2LEjABs2bGD48OFx61gsFu69997Y7a5duzJy5EhWrFhR7X6XLVuGx+Nhzpw55OXlAXD33Xdz4oknEgqFYkF1+/btCQQCbN++PdYWIYQQQjQt6YkVzc7SpUsrDWLLUwkA/H4/Tqczbp3169dz1VVXMXDgQLKzs0lLS2PevHm0b9++2v0uW7aMiRMnxgJYAI/Hg2maceu5XC4AfD5frZ6XEEIIIRqOBLGiWVm/fj3FxcUV8mFN0+Tzzz/n0EMPjS3Lzc1l9+7dsdvbt29n6NChlJSUMH36dL7++ms+//xzlFIccMAB1e576dKlHHjggXHLCgoKGDBgQFxqw65duwDigl0hhBBCNC1JJ2jFFEAjzeBV162GQiEg2su6t7fffpvt27czYcKE2LLBgwezfPny2O3//Oc/aJrGK6+8Elv24IMPous6++23X5X79Xq9rFmzhkjkj/q54XCYp59+miuvvDJu3eXLl9OtWzeysrJq/fyEEEII0TAkiG3F6loCqzH16tWL3r1785e//IWpU6eSnZ3Nl19+yZ133sk111zDgAEDYuuOHz+eK664AtM00XWdnJwcioqK+PDDD+nVqxezZ89m6tSp9OvXr0Lawb5++ukn7HY7zz77LCNGjMDlcnHjjTeSmZnJ1VdfHbfuV199xfjx4xvl+QshhBCiZiSIbYXsVguZrqbbV23ous6HH37ITTfdxIQJE4hEIvTp04fp06czadKkuHWPO+44NE1j/vz5HHnkkZx00kmcf/75nHXWWaSlpXHhhRdy8sknx1URmDVrFhdeeGGFklvLli2jX79+3HLLLZxwwgmUlpZyxhln8Prrr+NwOGLrhUIh3nnnHebOnVuHV0MIIYQQDUVTNSmgKZqlzZs307lzZzZt2kSnTp3i7lu1ahUAffr0iS1TSrFuZ1mTtrFc91wPmtbwUys8+eSTfPzxx7z//vs1Wv+ee+5h/vz5zJ8/v077mzlzJv/+97/5+OOP6/R4ES/RcSqEEELUhPTEtiKaptE915O0fTeGK664gpKSEnw+X42mnv3oo4+YPn16nfdnsVh48skn6/x4IYQQQjQM6YlNYbXtiRWiuZHjVAghRF1JiS0hhBBCCJFyJIgVQgghhBApR4JYIYQQQgiRciSIFUIIIYQQKUeCWCGEEEIIkXIkiBVCCCGEEClHglghhBBCCJFyJIgVQgghhBApR4JYIYQQQgiRciSIFc3SiBEj0DQNTdOw2Wx07NiRs846iyVLljTafux2OwMHDmTOnDkNug+Ap59+mm7duuF0Ohk+fDjff/99g+9DCCGEaE0kiBXNjmma/Pjjj0ybNo0tW7awcuVKZs2aRWlpKYceeijffvttg+7n8ccfZ8uWLSxfvpz+/ftzzjnnUFhY2CD7AHjjjTe44YYbuPvuuykoKGD//ffn6KOPZufOnQ22DyGEEKK1kSBWNDsrVqygrKyMI444gnbt2tGjRw+OOuoo3nvvPfr06cO9997boPs57LDDaNeuHb169eLmm2/G7/ezYsWKuHUffvhhjjvuOJ566inWrFlTq/089thjXHbZZVx44YUMGDCAGTNm4HK5mDVrVoM8DyGEEKI1sia7ASIJCqsJwlzZ4M754/buDWAala9vT4P0/D9ul/wOYX/07zY9a928goKC2KX9vVmtVsaMGcObb75Z621Wth+LxcKgQYNiyzZv3gxA27Zt49Y966yzAHjrrbe44YYb6NatG8cddxzHHnsso0aNwul0JtxHKBRi8eLF3HHHHbFluq4zbty4ButRFkKIFkkpMIJgc4IZASMANjdoWrJbJpoJCWJboycPqvr+I++AI27+4/bLp8CutZWvf8C5cOo//7j9/nXw68fRv+8prnXzCgoKGDhwIHa7vcJ9drs9tnzChAl8+umnHH300bz++usJt/XYY4/x3nvvMX/+/IT76du3L06nE6UUP//8M3fccQejRo2id+/ecet27dqVv/zlL/zlL3+hpKSE//73v3z44YdcdNFFFBUVMWbMGF599VUyMzPjHrdz504ikQj5+flxy/Pz81m9enVtXhYhhGgdlIJ1X8D8hyCrMxw3FfzFoCKQ3g5srmS3UDQTkk4gmp2CggIOOihxoL1q1Sr69esHwNVXX81LL71U5ba2b99eabBYUFDAypUrSUtLw+VyMXz4cEaPHs3s2bOr3KbT6SQzM5OsrCwyMjIIBALs2LED0zQrfYy2T8+BUqrCMiGEaNWUgrUL4IXj4KWTYOM38NO/YdP3EPJCJJTsFopmRnpiW6NrCqq+35Udf3vSu9WnE+ztxOl/pBPUklKKpUuXcsYZZ1S4r7CwkE8++YTHH38cgNGjRyfsYd3bQw89xEMPPVTpfv76179y3nnn4fF4aNeuXaWB5ZYtW5gzZw5z587l008/xeVyMX78eG6//XaOOeYYcnNzEz4uNzcXi8XC1q1b45Zv3769Qu+sEEK0Snv3vG785o/luX1g6GXR/wGMun2viJZLgtjWqLZ5qtlda7d+Rofarb+XNWvWUFxcXKEnNhQKcdFFF9GlSxcuuOCCOm9/3/0cddRR9OrVq9r1Z86cyfvvv89xxx3HbbfdxrBhw9D16i9k2O12hgwZwn//+19OPPFEIFoV4dNPP+X666+v79MQQojU5t0Ob56fOHjtczRYbNFlkXBy2ieaNQliRbNSUFCApmm0bduWrVu3UlxczHfffcdjjz1GSUkJc+fOTZgrW5f96LrO4MGDa7T+VVddxbnnnhu7vXZtxRzh7t27Y7FYKiy/4YYbOO+88xgyZAjDhg1j2rRp+Hy+BgnGhRAipTmzoGx79O9EwasQVZAgVjQrS5YsQSlFr169sFgsZGVlMXDgQCZOnMjll19Oenp6g+2nX79+pKWlVb8y8NRTT1Vb2mvTpk106tSpwvKzzjqLHTt2cNddd7F161YOPPBA5s2bV2kKghBCtEjlaQOFq2HIBdE8V38xDL8azDD0OUaCV1ErmlJKJbsRom42b95M586dEwZPq1atAqBPnz7JaFqTmT9/PjNmzKi0OoFo3lrLcSpEq7ZvzqvVCRfMjVYZ0HSwu0Gvpk8tEo7mxGZ1keoEIkZ6YkXKOv744/nuu+8oKyujU6dOvP/++zVODxBCCNHIKhuwldUlWk+8wwHVB69CVEGOHpGyPvjgg2Q3QQghxL6qqzYgOa+igUgQK4QQQoiGEwnB7MvAu6e0YG4fGHqp5LyKBidBrBBCCCHqTqloekBmxz+mhx1yAfxvtvS8ikYlQawQQggham/vtIEdv8AVX0eD2FAZ9D8JDpwowatoVBLECiGEEKLmKst5LXgJBp4GjjQZsCWahBxlQgghhKieDNgSzYwEsUIIIYSomlLwr7Nh1bw/lknwKpJMglghhBBCVE2ZkNs3GsRK8CqaCQlihRBCCPGH8rSBxbPglH9ES2b5i6HvcZDZWYJX0WxIECuEEEKIxDmv7faDPsdGp4fN7AjZXZPbRiH2IkGsEEII0ZpVNWDLnSfVBkSzJUelaJZGjBjBt99+C4DVaqVt27Ycdthh3HrrrQwePLhR9mOz2ejduzdTpkzh5JNPbrB9ADz99NNMnTqVrVu3cuCBB/Lkk08ydOjQBt1HY1BKoRp5H+GISaE32Mh7Eakqx2NH07RkN6PlWvcFfP6gVBsQKUmCWNHsmKbJjz/+yLRp0zjrrLPw+XysWbOGxx9/nEMPPZTPP/+cQw89tMH28/jjj3P22Wfj9Xq59dZbOeecc9i0aRNt2rRpgGcDb7zxBjfccAMzZszgkEMOYdq0aRx99NGsWrWK3NzcBtlHY1BKoRo7ggWMiKLIF2r8HYmUo2kauqaR7bEnuykt18+z/whgJXgVKUZTqim+pkRj2Lx5M507d2bTpk106tQp7r5Vq1YB0KdPn2Q0rV6WL1/OwIEDKSgoiOt1NQyDgw46iA4dOjBv3rwqtlC7/Xz//fccfPDBACxatIjhw4fz1VdfMXLkyHrvA+CQQw5h2LBhPPnkk0A0eO7cuTOTJ0/mpptuapB9NAZTKVBgEv2/MaxZs5oiX4iAq23j7ECktNw0B26Hlc7ZLqwWPdnNSX1KwaZF0PmQaLWBkBe2rYAPb4CDLmjewWskDIYfsrqAzZXs1ohmQnpiWxEVChH+/fcm36+tQwc0e817UgoKCrDb7QwcODBuudVqZcyYMbz55psN0q6CggIsFguDBg2KLdu8eTMAbdvGB1VTpkxhypQpVW5v+fLldOnSJW5ZKBRi8eLF3HHHHbFluq4zbty4WBpDc6T2BLAKQNFol3M1wKrrZLia6RenSBpfyKDIH8Zp09lVFqJthjPZTUpd++a8njYTOg+PTg/ryoJzXpecV5GS5KhtRcK//86aY45t8v32nDcXe7duNV6/oKCAgQMHYk8Q+Nrtdux2OytXruTiiy+muLgYu93OtGnTOPzwwyus/9hjj/Hee+8xf/78hPvp27cvTqcTpRQ///wzd9xxB6NGjaJ3795x615xxRWceeaZVba7Q4cOFZbt3LmTSCRCfn5+3PL8/HxWr15d5faSqbzjVSkFWjTYbCxWi0a6BLFiH3aLzk5vEG/QQNM0MsIRnDZLspuVWiobsPXzbGh3gAzYEilPjl7R7BQUFHDQQQclvG/VqlX069cPp9PJCy+8QO/evVm1ahUnnngiK1eurLD+9u3bKw0WCwoKWLlyJWlpaRiGgcVi4bzzzuOBBx6osG5OTg45OTl1fk779mQqpZrtYJW4XlhAa9QQVojEXHYLTpuFEr+Bx2GlsCxExyy5jFwjMj2saCUkiG1FbB060HPe3KTst6aUUixdupQzzjijwn2FhYV88sknPP7443Tt+ketwt69e1NUVJQwMHzooYd46KGHKt3PX//6V8477zw8Hg/t2rWrNLCsazpBbm4uFouFrVu3xi3fvn17hd7Z5qC8GoGiPNBu3F5YIaqS5baxtThAkS+MrmmUBsKkOyX4qtb8B2HB3/+4LcGraKEkiG1FNLu9Vpf1k2HNmjUUFxdX6IkNhUJcdNFFdOnShQsuuCDuvvfee4/BgwfXqmezfD9HHXUUvXr1qnb9uqYT2O12hgwZwn//+19OPPFEIDqw69NPP+X666+vcXubVNwgLglhRfLYLDppTiveQLQ3dldZCI/diq7LcVkpMwI9x0aDWAleRQsnQaxoVgoKCtA0jbZt27J161aKi4v57rvveOyxxygpKWHu3LlxubIbNmzg5ptv5sMPP6z1fnRdr3HN2fqkE9xwww2cd955DBkyhGHDhjFt2jR8Pl+FYDzZyktq/dELK4kEIvkynDZ8oQjF/jAOq06RP0yOlNyK2jttYMzt0H6/6PSwjgw4/QXodLAEr6JFkyBWNCtLlixBKUWvXr2wWCxkZWUxcOBAJk6cyOWXX056enps3ZKSEk4++WSeeuqpGvWm7ruffv36kZaW1tBPoYKzzjqLHTt2cNddd8UmO5g3b16zqxEbG8y15y8JYEVzYNE1Mp02dvtC+EIRNC1MutOKrTWX3EqU8/rZfXDC9Oj0sI406Fr/WtpCNHdSJzaFtdQ6sTURiUQ4/vjjOemkk7jqqquS3ZyUl6xe2NWrf8UfjpDetnMT7E2kKqUU20qCKAXtMh2kOW3kt8aSW615wJbUiRUJSE+sSElz587l008/ZevWrTz77LMAzJ8/n6ysrOQ2LEU1ZUktIWpL0zSy3DZ2lAYpDURLbvlDEVz2VlRy67cC+Oj21he8ClEFCWJFSjrhhBMIh8PJbkaLICW1RCpw2iy47BZK9gzy2ukN0inb1WxL1TW4YKlMDyvEPiSIFaIVk8FcIpVkuWxsDUdLbrVJ0ygJGGS2xIkyytMGMjpCTvfo9LDZ3WHg6dDlUAlehdhDglghBPvU1RKiWbJadNIdNkoCYdIMK0W+EGkOK5aWUnJr35zX/ifB+Aei08NqOhx1r8ywJcRe5NMgRCsV3wuL9MKKlJDuslIWMtjtC+GwOtntC5Gb5kh2s+qnsgFbO3+FQDG4siR4FSIB+VQI0UrFldSSwVwiReiaRqbLxq6yEN6gAUC604rDmoKDvKqsNnAp9DlG0gaEqIIEsUK0QnGDufb0wgqRKjwOK96gQYk/jNtuYVdZiPaZKVh2ae18ePmUP27LgC0hakWCWCFaIcXeg7mkF1akniyXje2lQUr8YXRNoywYrVrQrCkVrXdqtUenh+1wILTpDZomwasQddDMP/FCiIZW3gv7BwlhRepx2Cy47Ra8wQgeh2JXWQi33dI8ryrsnTaQPxDG3hmdHjZUBidOh8zOErwKUQcSxArRikhJLdGSZLnt+Iv9FPvC2Cwaxf4wWW57spv1h0Q5r799H606kJYfnR7WlZXUJgqRyiSIFaIViRvMJUSKs+ga6U4bJf4wgbCVIl+YNIcVq0VPbsOqmx42u5v0vArRAJL8SRepIBCOUOgNEghHkt2UVmn06NHcdNNN9d5OosFce/fCDhs2lHvvvafe+ym3bt06VqxY0WDbEyKRDGe0TuxuX4iIabLLF0pug4p/gxeOg5dOip9h69hHYOLb0P8ECWCFaCDSEysS8oUM3lv6Oy99u4HlW0piywe0z+D8EV056YCOrWve8loaPXo0Bx98MI888ki9tzV79mxstvp/6cV6YVXdS2oddthIvv32WwBsNhu9e/fm/vsf4OSTT45bb8GCBdx//9+IRCL89a+3M27cuPo1XohKaJpGlttGoTeENxRB0zQynBGctiSdn1zZULQ++rdUGxCiUUlPrKjg+/W7GPHQZ9w6+ydWbC2Ju2/F1hJuefsnDn3oU75fvytJLWwZQqGa9Rjl5OSQnp5er33F9cICdcmENU2TH3/8kccee4zffvudn3/+H/369WfixHMpLCyMW3fUqFH4fD68Xq8EsKLRue1WHFadUn8YU0UHeTUJpWDtAlj2erTaQKAYynbCoddKz6sQTUCCWBHn+/W7OHfmQkr8YQDMfVIny2+X+MOcO3NhowSy3bp146mnnopblpuby6xZs4BoL+fkyZOZPHkyWVlZdOrUiaeffjq+nabJlClT6NGjBw6Hg27dujF9+nQA/H4/V199NXl5eTidTsaMGcNPP/0U9/ia7OOtt95i0KBBOJ1OcnNzOfroozFNkwsuuIAFCxbw6KOPRi/Zaxrr169n9OjRXHfddVx77bW0adOGU089FYAPPviAkSNHkpWVRW5uLqeeeiqbNm2Ka8ve6QS1ef7dunXD7XYzZMgQ/vPBB7GSWmVeL5MmTSI9PY3OnTvxz3/+o9r3ZcWKFZSVlTFy5GG0a9eOXr16cdNNN+H3+yukDRiGwaxZL/Kvf71OMBisdttC1FeW207EVBT7wwTCkdhECI2iPHgtTxv48GbY/gsU/w5hfzRwleBViEYnQayI8YUMLn3pByKmqhC87stUEDEVl770A75QI35ZVOL5558nPz+f77//nsmTJ3PNNdfwyy+/xO6/7777ePzxx/nb3/7G8uXLefHFF8nOzgbgL3/5C3PmzOHVV1/lhx9+oG3bthxzzDH4fL4a72PLli2cc845XHzxxaxYsYLPPvuMY489FoDp06dz6KGHcuWVV7Jlyxa2bNlC586dY9tMS0vj22+/5fHHHwfA5/Nx00038cMPP/Dxxx9TVlbG2WefXa/n/+CDD/Lqq68yc+ZMfvrpJy677HLOmHA6S5YsATRuuulGvvnma+bMeY8PPviQd9+dw8qVK6vcZ0FBARaLhUGDBsWWbd68GYC2bdvGrWu1Wunduzc9e/bE4UjxKUFFSrBb9egkCAGDUMRklzeEWd2JrLb2DV7Lc17T20HR5mi1AWeGTBErRBORT5qIeW/p7xT5wjVe31RQ5Avz/rLfOWtol0ZsWUVDhgzh1ltvBeDGG2/k0UcfZcGCBfTv359AIMDf//53nnnmGSZOnAhAz549AfB6vTzzzDO88sorjB8/HoAXXniBLl268Oqrr3LppZfWaB9btmzBMAxOO+00unbtCsD+++8PQGZmJna7HbfbTbt27eLa3a9fP6ZMmRK37Iwzzoi7/eyzz9K9e3c2b95Mp06dav38g8EgU6ZMYf78+Rx88MEoBZdedhmff/4Zz//fc/Tp/XdmzZrFv/71OkceeSQAs2bNonv3blW+5kuWFNC3b1+cTidKKX7++WfuuutORo0aRe/evat8rBBNIdNlwx+KUOwLY0/XKfaHyfY0QMmt6qoNSM6rEEkhQayIeenbDehaxRSCqugavPjNhiYPYvfbb7+42+3bt2f79u0A/PrrrwQCAcaMGVPhcWvWrCEcDjNy5MjYMrfbzeDBg+N6MqvbxwEHHMDo0aPZb7/9OPbYYzn66KOZMGECGRkZVbb74IMPrrDs119/5c4772TRokXs3Lkzmr8KbNy4sdIgtqq2rV69Gp/PV+H5h0IhRo8ew9q1awmHwxxyyCFxjy8PxiuzZMkSVq5cSUZGOoZhYLFYmDRpEvff/0CVjxOiqVh0jQxXtNSWPxRB08KkOa3Y6ltyK1QGb54HgaLobQlehWgWJIgVQLSM1t5VCGrKVLB8SwmBcMONBtZ1PRbIlQuH43uI9x2tr2kapmlWWFaZfe8rL/xf031YLBY+/fRTvv76a+bNm8fUqVO56667WLx4Mfn5+ZXu1+PxVFh24okn0qNHD55//nnat29PaWkpw4YNq3LgV1Vt83q9AMydO5f8/HZxExu4XS527NiR8DWoilKKpUuXcttttzFp0nl4PB7atWvXPGdHEq1amsNKWTBCkT+M06azuyxE2wxn7TaiFHi3Q3p+dMCWisDgP8GvH8PQS6HPMRK8CtEMSE6sAKCsnoMg6vv4veXl5bF169bY7fXr11NSUvMAu3fv3rhcLj777LMK9/Xs2RObzcZXX30VW+b3+1m6dCn9+/evVTt1Xefwww/ngQceYNmyZXi9Xj766CMA7HY7kUj1dXULCwtZuXIld955J2PGjKFfv34VRvrX1oABA7Db7WzctIlevXrRs2dPevXuRe9evejYsWPsNVi4cGHsMVu3bmXDhg2VbnPNmjUUFxczbtxR9OrVi/bt20sAK5olTdPIdNswIibeoIE3aNS8xvXeOa/PjooGssW/RQds7X/OnmoDJ0oAK0QzIT2xAgCPo36HQn0fv7fRo0fz0ksvccIJJ+B0Orn11lux22ue1+Z0Ornlllu48cYbsVqtDB8+nC1btrBu3Tr+9Kc/cfnll3PjjTeSlZVFx44due+++7Db7Zx77rk13seiRYv49NNPGT9+PHl5eXzxxRd4vV769u0LRCssLFy4kA0bNuDxeMjJyUm4nezsbNq0acMzzzxD27ZtWbt2bSzXta7S09O5/vrrmXz99Rhhg+GHHkpxcTFfffkFHTt24owzzuC8887jL3+5mezsbHJycrjllluqHIBVUFCArusMHjy4Xm0Toim4bBacNgslfgO33cpOb5BO2e7KH1BZzuuSV6HvcdEBWzJYS4hmRz6VAgCnzcKA9hms2FpS65zYfu0zGrSw+G233cbq1asZP348+fn5PPbYYyxevLhW27jzzjvRdZ3bbruNrVu30rFjRyZPngzAww8/jFKKiRMnUlpayvDhw5k3bx5udxVfcvvIyMjgiy++4PHHH8fr9dK9e3dmzpwZyzO96aabOP/88+nfvz9+v59169Yl3I6u67z++utce+21DBw4kAEDBjB16tTYoLO6UEox5cEHyWvblilTHmDdunVkZ2dz8MFDufPOOwF49NHHuOKKKzjhhOPJysritttuo7BwZ6XbXLp0Cf369SMtLa3O7RKiKWW5bWwtDlDkD2PRNUoCYTKc+/SgyoAtIVKapvZNPhQpY/PmzXTu3JlNmzZVGAC0atUqAPr06VPj7f3ru43cNvun6lfcx99P36/JB3aJyiml2DO3QSwXtrle+F+9+lf84QjpbTsnuymiBSryhSgNGLTNcOK2W+ic7UbX93walIJXToM1e6UdSfDafEXCYPghqwvYXMlujWgmJCdWxJx0QHuyXDb0GkY8uhbt7TjpgI6N2zBRY6kUwArR2NKdNnRdo8gfJmIqdvv2GiypTMjuFv07t4/MsCVECpIgVsR4HDZmnn8wFl2rNpDVtWg5m+fOOxiXPUlzlIsKVOx/BRoSwIpWzaJrZLpshEIGrFuA/T9XEwoGo9PDFv8GA06T4FWIFCY5sSLO0G45vHbpcC596QeKfOEKdWPLb2e4bDx33sEc3C3xgCXR9NSeLlgFoGpXQkuIFkkpcrcvpOvCR8jc/j0AxQsHY+83BjQdcrpBbq/ktlEIUWcSxIoKhnbL4Ztbj+T9Zb/z4jcb4urH9mvr4bxhnThpcGdcrgaYCUc0mFgvrJJeWNHKKYXzt6/J/u4xXFsWxRb7M3sR0JzYdScuRy1rxwohmh0JYkVCbruVs4Z24ayhXQiEI3gDYTyaiZ09EwqYYVTABIcDTZeslGSL64UFJBNWtFbOzV+T/d2jccFrKLsXv/eexNb8UeRlp2H4FR0cSj4nQqQ4CWJFtZx7ai4CqEgEFQpF/49EUD4fms2GZrfL5eskkcFcQvwhffm/YgFsKLsXxYMuxNdtLBEsUBqk1G9g9WiUBoyKJbeEEClFglhRK5rFAk4nmmFghkKgFCocRhkGusMBFosEs0kj1fJEK6MUju1LCbY9EJSJFi6jpP/ZOAqXUzzwPHzdxsYGa1mJDl71BsJ4HBaKfWE8DisWOV8JkbIkiBW1pmka2GzoVmu0VzYcBqUwA4FokCspBk0mvhcW6YUVrcM+Oa9bjnuBUJsB6GEvEU8+W054KeEMW2lOC/6QQZE/TNt0nWJ/mBy35PYLkaokiBV1pmkamsOBKg9mJcWgyUlJLdGqVDJgK33lW+wadjOmverpYXVNI91po8gXoiwYAcDjsOCwSJlAIVKRBLGi3iTFIDmkpJZoNSoJXvfOea1pjVe3w0JZSMcbCOOy6xT5wuSnSxArRCqSa76iRkKhUJX3a5qGZrOhu91otj1fJntSDFQggDLNJmhl8zN69GhuuummRtm24o/BXMMPGcZ9997TYNtet24dK1asaLDtCVEfOd/cT4c5Z8UN2Npx+ANsOeEVfD2PqfUkBVkuGxFT4Q0YBMIRfCGjMZothGhkEsSKai1atIj09HS+++67atfVNA3d4UB3uaI9tEQrGpg+H2YwGO09bMYaOuicPXs2d999d4Ntr1x5L2xtHHbYSCwWHYtFx+l0sN9+g5gzZ06F9RYsWMBll13KVVddySeffNJALRaijswIvs5HAPUPXsvZrDpuh5WyYIRwRLHbF46m5AghUooEsaJaL7/8MqFQiJdffrnGj9EsFjSnM5pOsOcytwqHMX0+lGE0+2C2OtX1TJfLyckhPT29QfedqKRWdUzT5Mcff+Sxxx7jt99+5+ef/0e/fv2ZOPFcCgsL49YdNWoUPp8Pr9fLuHHjGrTtQlRJKZybv6L9O6dj3/4TWrAEi3cLEU8+28Y9Ue/gdW/pTiugKPUbGKZJiV96Y4VINRLEiiqZpslbb70FwFtvvYVZi7SAuqYYmKbJlClT6NGjBw6Hg27dujF9+nQA/H4/V199NXl5eTidTsaMGcNPP/0Ue+zo0aOZPHkykydPJisri06dOvH000/Hbf+tt95i0KBBOJ1OcnNzOfroozFNkwsuuIAFCxbw6KOPRtuuaaxfvz623euuu45rr72WNm3acOqpp/LBBx8wcuRIsrKyyM3N5dRTT2XTpk1x+9q3Z7e69pU/927duuF2uznooIP44IMP4rZZ6vUyadIkMtLT6NqlMzP++Y9q34sVK1ZQVlbGyJGH0a5dO3r16sVNN92E3++vkDZgGAazZr3Iv/71OsFgsNptC1FvewWvHeachev3hWQvegir93c0w49pTyPQaWSDBK/lLLpGmtNGIGwQDEco8YcxWmnakxCpSoJYUaWFCxeybds2ALZu3cqiRYuqeURFtU0xuO+++3j88cf529/+xvLly3nxxRfJzs4G4C9/+Qtz5szh1Vdf5YcffqBt27Ycc8wx+Hy+2OOff/558vPz+f7775k8eTLXXHMNv/zyCwBbtmzhnHPO4eKLL2bFihV89tlnHHvssQBMnz6dQw89lCuvvJItW7awZcsWOnfuHLfdtLQ0vv32Wx5//HF8Ph833XQTP/zwAx9//DFlZWWcffbZ1b4eVbXvwQcf5NVXX2XmzJn8/PPPXHHFFZx22mksWbIk+ropxU033sg333zNu+/O4T8ffMi7785h5cqVVe6zoKAAi8XCoEGDYss2b94MQNu2bePWtVqt9O7dm549e+JwOKp9PkLU2b7B6145r76uR2Ha01CO9CorDtRHmsOC1aJR7DeIKJMif7hR9iOEaBxSnUDE/Prrr3z99ddxy/bOmdQ0jYcffpiTTz45bp2RI0fSu3fvardfkyoGwWCQv//97zzzzDNMnDgRgJ49ewLg9Xp55plneOWVVxg/fjwAL7zwAl26dOHVV1/l0ksvBWDIkCHceuutANx44408+uijLFiwgP79+7NlyxYMw+C0006ja9euAOy///4AZGZmYrfbcbvdtGvXrkL7+/Xrx5QpU2K3+/TpE3f/s88+S/fu3dm8eTOdOnWq9HWorH09evRgypQpzJ8/n6FDhwJw2WWX8dlnnzFz5kz+8Y9/UFJayqxZs3jtX/9izNgj0dGYNWsW3bt3q/K1X7KkgL59++J0OlFK8fPPP3PXXXcyatSoGr13QjQ0x9bF5HzzQL2rDdSHpmlkOG3sKgtRFjTRNIM0hxWnVaoVCJEKJIgVMffccw+vvfZapfcrpXj33Xd5991345ZPnDiRV155pUb7qG6ihFWrVhEIBBgzZkyFx65Zs4ZwOMzIkSNjy9xuN4MHD471ZALst99+cY9r374927dvB+CAAw5g9OjR7Lfffhx77LEcffTRTJgwgYyMjGrbfvDBB8fd/vXXX7nzzjtZtGgRO3fujPUob9y4scogtrL2rV69Gp/PV+G5h0IhxowZg1KKtWvWEg6HGTbskNi87+3bt48F5JVZsmQJK1euJCMjHcMwsFgsTJo0ifvvf6Da5y1EY7CUbU84PWxTBK97c9otOEIWvIEwbofObl+Y9hkSxAqRCiSIFTHTp0+npKSE//znPzV+zIknnsi0adNqva/KJkowA4EaPXZv+w5ustlsFdYvz+W1WCx8+umnfP3118ybN4+pU6dy1113sXjxYvLz86vcr8fjibt94okn0qNHD55//nnat29PaWkpw4YNq3bQV2Xt83q9AMybN69CT7DT6UQB5p5AWddrPjOXUoqlS5dy2223MWnSeXg8Htq1ayd1ZUXT2VPnNZzZnYinHVq4jGDuIMq6jcfXeVRSgte9ZTit7CgNUuI30DUNbzBMmiN57RFC1IzkxIqY3Nxc3nvvPZ588knsdjt6JVPH6rqO3W7nqaeeYs6cOeTm5tZ5n/tWMejdsycul4tP582rUMWgZ8+e2Gw2vvrqq9gyv9/P0qVL6d+/f433qes6hx9+OA888ADLli3D6/Xy0UcfAWC324lEItVuo7CwkJUrV3LnnXcyZswY+vXrV2GUf20NGDAAu93Opk2b6NWrV9y/jh07gvrjNVi08I9LsFu3bmXDhg2VbnfNmjUUFxczbtxR9OrVi/bt20sAK5rGPjmv2d89gsW7BWvpb2iRADuP+FuDVRuoj/KSW76gQThiUuwzYj8YhRDNl/TEijiapnH11VdzxBFHMHTo0IS9ilarle+//z6WS9oQ+yxPMXBZrdw8eTI333YbVouF4cOHs7WwkPUbNvCnP/2Jyy+/nBtvvJGsrCw6duzIfffdh91u59xzz63RvhYtWsSnn37K+PHjycvL44svvsDr9dK3b18AunXrxsKFC9mwYQMej4ecnJyEwXx2djZt2rThmWeeoW3btqxduzaW51pX6enpTJ48meuuuw7DMBgxYgRFRUUsWLCAjh07MeGMM0hLS+O8887jL3+5mezsbHJycrjllluqHIBVUFCArusMHjy4Xu0TosYqmWHLseMntFBpdHrYJAeu+0p3WgmEIpT4DWxpOsUBg2xX82qjECKeBLEiIU3TKr0sHgqFKu2lre8+NYeDu+65B4vFwu333MPWbdvo2KED1197LUopHn74YZRSTJw4kdLSUoYPH868efNwu9012kdGRgZffPEFjz/+OF6vl+7duzNz5kwOOeQQAG666SbOP/98+vfvj9/vZ926dXTr1q3CdnRd5/XXX+faa69l4MCBDBgwgKlTp8YGnNXVgw8+SNu2bbn//vtZt24d2dnZHDx0KHfecWesGPujjz7GFVdcwQknHE9WVha33XYbhYU7K93m0qVL6NevH2lpafVqmxDVasDpYZtatOSWlRJ/GH8ogkaYNIcFWyOc64QQDUNTqV51vhXbvHkznTt3ZtOmTRUGEq1atQqoOIK+pu655x7uvfdeIBqwmaYZ+7/8/saYiaqcUgr2qmIAwJ5SXVgsreZyeKKJDVrSM1+9+lf84QjpbTtXv7Jo9txrP6Ld3Itit1MheN2bUoodpdEf73npdtIcNnLTpMxcsxAJg+GHrC5gcyW7NaKZkJ+YIqF///vfsb+HDx/OF198wbBhwxLe3xjqOlFCS1P+C1MpBRotKoAVLYBS0eACwIzgb38w4bQODTY9bFPTNI0MlxUjYlIWjFAWMvCHq8+RF0IkhwSxooLVq1ezfPlydF3nrrvuYsGCBRx++OF8+eWX3HnnnWiaxv/+9z9Wr17d6G2p7UQJLYna0wVb/gxbVh+sSGl7DdjK/u6R2PSwVt8Oto+dlnLB696cNgsOmwVvwCBiKnb7ajbFtBCi6UlOrKggLS2Nk08+mcmTJzNq1KjYcqvVyn333cfYsWN5/PHHmzTHsiYTJbSkFAOlohmwLTWNQKSoBDmvjh0/4es6logzG9OehunMSm4bG0CGy8rO0iAlAQOLrlEaDJMuJbeEaHYkiBUVtGvXrsKEBnsbNWpUXHDbVKqbKAGHA60lDcJouZ3MItVUM2DLSGufkr2ulbFZdDwOK96AgcduodgXxmO3oregH8pCtAQSxIqUU9lECcrnQ7PZ0Oz2lO6VbemDuURqsZZsJO+T61Ou2kB9pTmt+EIRigMGdqtOkT9Mjtue7GYJIfYiQWwLZbFYqp05KtW11BSD2GAuWv5grkjERNNaUO95CxRxZGHfHa120hqC13K6ppHutFHsC+ELRtDQSHNYsVvkeBWiuZAgtoVyu92UlJSwbdu2aqdTTWUtLcUgbjCXqjjFbkvi8/kIBALYPBnJboootydtwBIooqzHsWjhMvRgCUUHXoFpS2sVwevePA4LZSEdb8DAadcp8oVpmy4lt4RoLiSIbaHy8/Px+/38/vvvFBYWYrW2orfaNP+oLQugaZBCgSxEg9iWG75GBQIB0HWcGTnJborYJ+fVcOUSzOmHpgyUZsHb5xTQW9E5ZC9ZTis7vSG8wQi6puEPW3DZWudrIURzI5/EFkrXdbp168a2bdvw+/0YhpHsJjUdXcfYsZPQ+nXRnllAsztwdO+GpU2bJDeuchFTEYqYhCMmEVNht1rQW3Akq9ucuNOy0C1yGkqaSgZsmc4srKWbCOf0brXBazm7zYLLbqEsYOC2W9hVFqZDlkVK3gnRDLTus1MLp+s67du3T3YzkqNPHyIH7M/OJ59i1yuvQCRasNw5ciT5d9yOo3v3JDcwnmkqNu/2o4UNiooDpDmtZMogEtFYUnh62GRId1oJhIOU+g2sHo1Sv0GGS14fIZJNpp1NYVVNOyv+EFi1im33/Q3fDz9EF9hstLnoInIvvwzd7U5u4/bYVRaiyBdie2mQsGHSLtOJpSV3w4qk0gO76fLiUHTDD0jwWhMlAQOvP0ybNAcum4UO2S4sLThnvdmRaWdFAqmVKChEHTj79KHLyy/RYepULHm5EA5T+MwzrDn+BEo+/jjps36FIybF/jC+UIRgOEKmyyYBrGhYSqH7d0X/NiMozUJpn9NTdnrYZEhzWLDoGsWBMCaKIl842U0SotWTntgUJj2xtRfxeiukGHgOO4z82/+atBSDbSUBvIEwW0uCaEB+hqNFVyUQTWivtAGLbwe/nT4HPexDD3tRpoFyZErgWgu+YIQiX4gMt500h4V2mU4ce6bDFo1MemJFAtITK1oVS1oa+bfdSvd3ZuMeOhSAsq++Yt1JJ7P98WmYPl+TtscfilAWNCgNGBgRk0y3TQJYUX9K4dz8Fe3fOZ0Oc87CtWUR9uK1pK+ajWb4Me1pKHeuBLC15HZYsFl1ygJhTKXYXSa9sUIkkwSxolVy9ulDl5dejKUYqCSkGCilKCwLEjEVJQEDp82Cyya9OqIeEgSvQCxtoKz70ShHequvOFAfWS4bEVPhDRgEjQhloVZU+UWIZkbSCVKYpBM0jIjXy86nnmbXyy83aYpBsT9MoTdIoTeEP2yQn+HEJrMBiboyDdq/dw6u376JLZIBW42jyBfGH4qQm75nkFeWU0puNTZJJxAJyDemaPUsaWnk33pLk6YYRExFkS9E0DDxhQzSHFYJYEU9aRju6Ox8MmCrcaU7rWhAqd/AME1K/NIbK0QySE9sCpOe2IanlKLkPx+w7eG/E9mxEwBr+/bk33Yr6Ucd1WD5qju9QUr8YbaVBjEiJu0zneiSCytqSimcv32DZ+2HFI68F83woQdLsO1eha14o/S8NoHSgEGpP0ybNDtOm5UOWU6sKTYzYEqRnliRgASxKUyC2MbTmCkGIcNk824fZUGDXWUhsj120hySoyhqoLzawPeP4/p9IQDbjpxGqO1+KM2Csrkk37WJKKXYURoENPLS7aQ5beR6HMluVsslQaxIQH42CpFAY6YYFJYFMZWixG9gs+oSwIrq7Ttga08AG8ruBZoWrTYgA7aalKZpZDhtGBGTsqBJWdAgYESS3SwhWhXpiU1h0hPbNBoyxcAXMthaHKDYH6bEH6ZtugOHVCQQVXBu/kqmh23GCr0hwoZJ20wHTpuV9hnOZDepZZKeWJGA9MQKUQ1N08g88QR6zp1LzgUXgMWCsWULv117HZsuvYzgunU12o5SikJvCMNUlAYM3HaLBLCiWplLn61QKksGbDUfGU5r7MpKyIjgDUrtWCGaigSxQtRQfVMMSvxGdIpZXxhQZLntTdBqkVKUwr7j5+jfZgQtWELJwEmEsntL8NpM2aw6HqcVX7D8821gygVOIZqEpBOkMEknSJ7aphhETMWmXT58oQg7SgNkuGxkuiQQEXvsNT2sa8sifjv53xgZXaLTw2oWlNUpgWszFjGjg7ysFp3cNDsZLjvZ8vluWJJOIBKQnlgh6iAuxeD886tNMdhVFsJUiiJ/CIuukeGUATiCSmfYyvjlX39MD+tIlwC2mbPoGukOG6FwBH8oQqk/TNg0k90sIVo86YlNYdIT23wEVq1i231/w/fDDwBoNhs5F11E7uWXEbY7+G23n9KgQVFZiDZpdtx2CWJbtX16XsvJgK3UFS25FQIgL92Ox2EjL01KbjUY6YkVCUgQm8IkiG1eKksxsF97A6ERR7CtJIhV12gro5dbL6VA02jzxR1k/vRCbLEEry1DMByh0Bsi3WUj3WmlbboTlwzebBgSxIoEJJ1AiAZSWYqB77abCd5wLWzaKIO5Wqu9+wrMCP4OhwJSbaClcdiiFUe8AYOIqdjtC6GQfiIhGov0xKYw6Ylt3gIrV7Hh7nsxlxYAoGw2rGf/CcukC9Fc0pPQKpSnDfwwjR2j/k7EnYceLEEPlWLf/iOBTiMkcG1hwhGTnaVBnDYr2R4b2W47GU55j+tNemJFAtITK0Qj8XfqhvWJGfhuvgszOwctHCby8guE/nQGkQWfIb8fW7B9B2z99i05i/6O1ft7dMCWI51A11ESwLZANouOx2HFHzIIGSYl/jAR+awL0SgkiBWiERgRkyJ/mEDYxHvYWMLPv4HlzHPBYoHt2zDuuIXwTddibtyQ7KaKhlRJtYFQdi98HUfK9LCtRJrTiq5rFAcMIkpR7JcJEIRoDJJOkMIknaD52l4SoDQQZmtJEFC0y3CiaRrm2tUYjz2MWrYkuqLNhkVSDFoEx5bvyfn2Qak2IAAoC0Yo9oXIcttxOyy0z3Rht0i/UZ1JOoFIQD5RQjSwQDiCN2jgDRoYEZMslz02+YHeoxe2J5/BetffIKcNSIpBi2ErXi/Tw4oYj8OC1arjDURn8NrtCyW7SUK0OBLECtHAdnqDsbnUnTYLLnt8iR1N07AcdQz2196qmGJwo6QYpIQ9aQO6b2dselh/u6H4Oo+S4FXEZDmtGKaJNxghEI7gDxvJbpIQLYqkE6QwSSdofkoCYXaWBtlVFqIsaNAu04mtmkuIkmKQQvaZpKBov4soPvDyP6aHtbkk31XE2V0WIhA2yctw4LRa6JDlREOr/oEinqQTiASkJ1aIBmKait1lIYKGSVnQIM1prTaAhb1SDO68T1IMmqtKBmy5tnyHFvbJgC1RqQyXDVCU+g0M06TUL72xQjQUOeMK0UCK/GEiZnQksq5rtaoNqWkalvHHoo88nMjzM4m8/UYsxUAbOhzr9Tehd+naiK0XCcn0sKKeLLqGx2nD6w/jtlsoJozHacWiSW+sEPUl6QQpTNIJmo+QYfJbkZ+yoEGhN0i2206as+6/ESukGFitWM6ZJCkGTWXP9LCeVe+S/98/xxZL8CrqQinF9pIgmqbRNsNBmsNGG4/M3lcrkk4gEpB0AiEawG5fCNM0KfaF9xQ7r9986RWqGBiGpBg0BaXAjET/NiP4O47AcLWRagOiXjRNI8Nlw4hEB3l5g2GCkUiymyVEypOe2BQmPbHNgz8UYUuxnxJ/mGJ/mLx0B05b/YLYvakyL5EXZhJ56w3Y88UnKQYNbK+0AV/XIykZdF50etiwF71sO0ZWdwlcRb3tLA0SMRV5GQ5cNivtMpzJblLqkJ5YkYD0xApRD0opdnqjX0wlAQOX3dKgASyA5knDevVkbM+/gnbA4Oh+v19I+PyzMZ55GuX3N+j+WpUEA7aylvwTW9Ga6PSw9jSMNn0kgBUNItNlI2IqSgMGQSNCWUgGeQlRHxLEClEPJQGDcMSkyBcGFFmuxgt2EqYYvDJLUgzqoorpYXcNu5mIK1eqDYgGZ7PquB1WfMEI4YiiqCyMQj63QtSVpBOkMEknSK6Iqdi824cvFGF7SYAMp41Md9P02EmKQd3ZitaQ+9nNUm1AJEXEVOwoCWKz6rRJs5Ppsjfqj98WQ9IJRALSEytEHe32hYiY0ekkLbpGuqvpeu3iUgwOPAiQFIOaitgzcez8GZDpYUXTs+gaaU4rwXCEQChCiT+MYZrJbpYQKUl6YlOY9MQmT9CI8NtuP96gwe6yEDkeOx5Hci49K6UwP/kI46lpsKswurBtPtZrb0A/Ygxaa65HuWfAlmYE8XcZjRYuQw+WkP7L60RcudLzKpJCKcWO0iCgkZduJ81hIzfNkexmNW/SEysSkCA2hUkQmzxbiqM1YbcWB7DoGvnNYJSxpBjsZZ9JCsLpnfj9pDfQIwGZHlY0C4FwhF3eEBkuG2lOK/kZTpzWhh0U2qJIECsSkHQCIWqpLGjg33MZMGI27mCu2qgyxeDZf7SOFINKBmwpqxNb6UaZHlY0G06bBYfNgjdgYCrFrrJQspskRMqRntgUJj2xTU8pxebdfnyhCNtKArhsOm2a4WXAVpdiINPDihQUjpjsKAnicljJdtvI8dhJd8hxmpD0xIoEJIhNYRLENr0iX4hdZSF2loYIGAbtM11Y9OYbELb4FIM908NayrbS5aXhaGYYkOBVpI5if5iygEFuugOXzUKHLBd6S/uR2RAkiBUJSDqBEDVk7KkHGwib+MMG6U5bsw5goQWnGCiFHiiK/m1GMK1uvD2Pl2oDIuWkOazoukZJwCCiFMUBmQBBiJqSntgUJj2xTWt7aYBSf5itJUGUUrTPdKbUZfkWkWKwV9qAHi7j95PfRA+Vooe9EAljOrMkcBUppywQodgfIttjx2230i7Tid0ifUxxpCdWJCCfEiFqIBCO4A0YeEMRjIhJltuWGkHfXjRNw3LUMdhfewvLWeeCxQLbt2HccQvhG6/F3Lgh2U2sXIIBW46dP+NZ85/Y9LCmJ08CWJGS3A4dq1Wn1G9gKpNifzjZTRIiJUgQK0QNFJaFMJWi1B/GYdVx21N3dHtcisEBg4FmPFFCFdPD7jj8Afydj5BqAyLlaZpGptOKYZp4gxF8IQN/OJLsZgnR7Ek6QQqTdIKm4Q0abC8JsNsXwhswyM9wYre2jN9/zTbFQCm0SJB270/E9fvC2GIZsCVasl1lIYJhk7YZDhxWCx2ynGik1hWfRiPpBCKBlvFNLEQjMU3FLm+IUMTEGzDwOKwtJoCFfVIMzmxeKQZKt2Ha0wGZHla0DhlOK6Ao8RsYpkmpDPISokrSE5vCpCe28e0uC7HbF2JHaZCQYdIu09nsKxLUh7l2NcZjD6OWLYkusFqxnP0nLOddhOZqxN6PPQO23BsXsGv4rbHpYe07f8bq3So9r6LVKPEbeAPhWMmt9lkuLCmWf98opCdWJCBBbAqTILZxhSNmdGKDoMFOb5Ast410Z8sPpJRSmP+dh/H09MZPMUgwScHWo2cSzuqO0q0oq1OCV9GqmEqxoySIpmu0TXeQ7rSR47Ynu1nJJ0GsSKDlXBcVooHtLgthmiZF/jA2i06ao3UMHtI0Dcv4Yxs3xaCKAVtEAn9MDysBrGhldE0j3WnDMEzKghFKA2GCERnkJUQi0hObwqQntvEEwhF+L/JTGghT5Pvj0l5rlDDF4JxJWCZdWPsUA5keVoga2VEaxDQVeRkO3HYr+enOZDcpuaQnViQgPbFCJLDTGyRiRgdYOG2WVhvAAug9emF78hmsd94HOW3AMIi8/AKhP51BZMFn1Op3sKaR/d2jFUplyYAtIeJluWxETIU3YBAIR8tuCSHiSRArxD5KAmFCRjSNwFSKLLcEVnVOMVAKW+HK6N9mBC1YQsnA8yV4FaIaNquO22GlLBghHFHs9oVRyIVTIfYm6QQpTNIJGp5pKjbt9uELRdheEiDdaSVLBlVUUG2KwV5pA84t3/HbhP8QceWih70odJTNJYGrENWImIrtJQEcVis5aTayXHYyXa30cyPpBCIBCWJTmASxDa/QG6TYH2ZbaRAjYtIuo2WX1KqPyiZKcE88jjz1X9xbv4utW9L/HIoOvDwavMrsWkLUWEnAwOsP0ybNjstmpX2WE6veCi+iShArEmiFnwQhEgsZJiUBg7KgQSgcIdNlkwC2CvETJZwDug7bt+F7/AV2vr6aYIklljaw++DrZHpYIeog3WHBomsU+w0iKprmJISIkiBWiD12lYWImCYlfgObVcdjb72DuWpD86TRbuDvdB+/FXdeEICyrU7WftSeDYXHUNZxtKQOCFFHmqaR4bJhREzKQiZlQYOAISW3hAAJYoUAwBcy8IUMSgPR6R6zXLaGLerfkpkRAu2H4swyaHd6FmkXHAXZWRAxMd6cTeDSKzG++qZ2VQyEEDEuuwW7VacsEB1sutsnvbFCgOTEpjTJiW0YSik27/bjD0fYWhzAadPJTXMku1nN054BW1mLn2LHkY9i2tPRgyXooRIc25bg73wEWGyoMh/hV17DmPM+mCYA+kGDsV91OXqnjkl+EkKknrBhsqM0iMdpJdNlIzfNgcfeitJzJCdWJCBBbAqTILZhFPvCFJYFKfSG8IcN2mU4sVrkIkWcBJMUlPQ7k+IDL0dplkoHbJnrNxB6egbmTz9HF1itWCeciu3sM9Gcrbx4uxC1VOQL4wsa5GU4cFmjg7z01nLFSIJYkYB8U4tWLWIqdvtCBA0TX8gg3WGTAHZvVUwPG2g7+I/pYSsZsKV364rj4SnYb7kRsrPBMDBe/7ekGAhRB+lOK7qmUeI3MFR0IKoQrVkruhYhREW7ykJ7csxCWHSNdJd8JMo5fv+OnIUP1Xt6WE3TsI4ZjWXYMMKv/gvj3fdQO3YSuv9BSTEQohYsukaa00qJP0wgFEEjjMdhwdYaS24JgfTEilYsaEQoDYTxBg3Chkmmy9Z6Ls3VgGPn/xp0eljN48Z+2cU4//EE+n6DADALlhC44mpCs15CBQIN2n4hWiKPw4LVolMSMDCVSbEM8hKtmASxotUq74Ut8YexW3U8jlbcC7snbUALlsSmhy3rPAp/h+ENPj2spBgIUXfRklvWaMmtYISykJTcEq2XDOxKYTKwq+7KggbbSgIU+UKUBgzyM5zYra3wN90+A7Z2D7mWkgET0cNe0HRMm7tRJyhQZb5YioFUMRCi5gq9IcKGSV6GA6fNQofMFj7YSQZ2iQRa4be2aO2UUuwqCxGOKLzBCB6HtfUFsJUM2HJt/grN8GPa0zAdGY0+w1aVKQYvSIqBEJXJcFlRKEoCBuGISWlQ0gpE6yM9sSlMemLrpsgXYldZiJ2lIQKGQftMV+uZXjZBqSyo24Cthm+aIjJ/AaFnn4fduwHQ8nKxXX4JlpEjZPIJIfZR4g/jDRjkpjtw2Sx0yHK13Lx+6YkVCUgQm8IkiK09I2KyebcfX8hgR2mQTJeNDFfrmRI1fflr5H1+c+x2cwhe9yUpBkLUTMRU7CgNYrHo5KXZSXfayHHbk92sxiFBrEiglV1DFa3dLl+IiGmy2xfGatFId7bwwVxKgYoGgpgRfJ2OwLSlNUi1gcYiKQZC1IxF10h32giHI/iCEbwBg1DETHazhGgy0hObwqQntnYC4Qi/F/kpDRgU+ULkpjlw2S3Jblbj2CttwNv7JLx9TotODxv2YvFuIZzdq1kFrpWpPMXgUiwjD5UUAyGA7aVBMCE3w47HbqNtegucNlt6YkUCLbwbSog/FO5VUstps7TMADZBzquteD2B/INQVld0wFZu/yQ3suZkogQhqpfptFLoDeENRtA1DX/YgssmX++i5ZN0AtEqlAbCBMMRinxhTKXIbGl5sFVMD7t7yLWYzuwqp4dt7iTFQIjKOWwWnHYLZQEDw1TsKgujkIusouWTdIIUJukENWOaKjqYK2ywrTiAx2Elx9NyBj/Ydq0id/6tza7aQGORFAMhKjIiJjtKgzhtVrI9NrJd9pY1aFXSCUQCqdktI0QtFPnDGGZ0ekZd01pcL6xp9eDcVgC07OC1nKQYCFGR1aLjcdjwBsJ4HBaK/WE8TisW+VEnWjDpiU1h0hNbvfCeklplQYNCb5Ast410ZwoHd3tyXkEj0GE4WrgMPVhCxs+zMNI6tejgtTLm+g2Enp6B+dPP0QVWK9YJp2I7+0w0pzO5jROiCZlKsb04iG7RaJvuaFklt6QnViQgQWwKkyC2et6gwfaWML3sPgO2Qtm92XL8S+iGD6VZUDZXyua7NgRJMRAiantJkIipaJvhwOOwkp/eQn7ISRArEkjBb3Mhai7NYcVhs5DhsqFrGsX+FJuasZIBW6CweH/HtKel9ICthlKeYuB6bgbW004BXY+lGARvvxtz82/JbqIQjS4YjmBEzGgaga6R0dLrYItWT3piU5j0xNZMStaHbcbTw6aCaIrBPzF/+l90gaQYiFagRdeLlZ5YkYAEsSlMgtia214aoNQfZmtJEFC0y3A260vM1uINdH5lJNqeMjkSvNaepBiI1qQsGKHYFyLLbcfjsNIu04nd0oIutkoQKxKQaw2iVchx2/EFI2S7bewoDVIaMJpX+Rml0MJelD0dzAgRZza+rmOxlayX4LWO4qoYvPIaxpz3/6hiMOQg7FdeJlUMRIsQMRWlgTA2mwW3w0Ka09qyAlghKiE9sSlMemJrp8gXYldZiJ2lIYJGhHaZTix6knvj9kobQNPYetwLselhNSNAxNVGgtcGUiHFwGbFerqkGIjUV+IP4w0Y5KY7cNksdMhyobe0Kw3SEysSkJ9qotXIdNmwWXQy3TYUJHeQV4IBW67fF+Je/180w49pTyOS1k4C2Aakd+uK4+EHsd9yI2RnQ9jAeP3fBC69EuOrb5Df8yIVhSMmZUEDl8OK3Ro9v7W4AFaISkg6gWg1NE0jx2MnHAmQ5rBQGjBI23PibzLVDNgKtB8mgWsjqjLFQCZKECmoxG+gEa1EYLPopDvk/CFaD0knSGGSTlA3W4qjkx9sLQ5g0TXyM5rmUrIWKqPdfyZJtYFmRFIMRCoLhCPs8obIcNlIc1rJz3DitDbzyit1JekEIgFJJxCtTo7Hjq5pZLhshIzopbimoKxO2HOZL5Tdix2HP8CWE17B1/MYCWCTJJZi8BdJMRCpRSlFid/YM92sBY/d2nIDWCEqIT2xKUx6YusuWqEgzNaSAKapaJfpbNg8sj1pA84t31M05NrY9LCObYux+HdLz2szpMp8sRQDTBNAUgxEs+UNGJT4w+R47LjsVtpnObHpLbhfSnpiRQISxKYwCWLrLmIqNu3y4Q9H2F4SIMNpI9PdAEFlgpzX3094lUhaO5keNkUknCjh9FOxnSMpBqJ5iJiKHSVBbFadNml2Ml12sppTycDGIEGsSKAF/2wTonIWXSPbbcdh1XHbrZQGwxgRs+4brGR62FB2L/RgsUwPm0IqVDEwDIw3ylMMvpYUA5F0pQEDUykyXFasmi7Ty4pWS3piU5j0xNaPUorNu/34wxG2Fgdw2nRy02o5TaNMD9uiqTIf4Vf/hfHue5JiIJqFsGGyozSI22Ely22jjcdOWmuoSCA9sSIBCWJTmASx9ecLRasUFPvDlPjD5KU7cNpqMThCKTq+eTSOndFLzxK8tkzRFIMZmD/9HF1gtWKdIFUMRNPbWRokYiryMhw4bVbaN1F1laSTIFYkIOkEolVz26247VbSnVasuk6RP1z15WKlsBWtif5tRtBCpRQPPH9PtYH7pdpACxVNMZgSn2IgVQxEE/OHIoQME48jOqFBdkPk8QuRwqQnNoVJT2zDCBkmvxX58QbC7CoLke2xk+bYJ8dsr7QB57YCNp05D2VLQw97UejRAVsSuLYKCVMMhhyE/crLJMVANBqlFNtLgmiaRl66nTSnjVxPLdOfUpn0xIoEpCdWtHp2a3RghMdhxW6zUOwPEzH3/LZLMGBLM8Nk/u8VdMMXHbDlzJAAthXRPG7sl12M8x9PoO83EABzcQGBK68mNOslVCCQ5BaKlqg0GCFiKjJdViya3vKrEQhRA9ITm8KkJ7bhmKZi024fvlC05Fa6w0K73d/LgC1RJaUUkfkLCD37POzeDYCWl4vt8kuxjDwUTeawFw0gYiq2lwRwWK3kpNnIctnJbG1BrPTEigSkLocQgK5rZHvsRMzoqN+O868jf/17sfsleBWJaJqGdcxoLMOGxSZKUDt2Err/QUkxEA2mNGAAGumuaO5+hku+uoUASScQIibDacNu1cly6JS2OQCQ6WFFzWgeN/bLL5EUA9HgQuEIvqCBx2HBZokO5tKQHn4hQNIJUpqkEzQApWDdF/DNk3DqDAKmhW3bt+Ev2YH222LoNQ6nlFAStSApBqIh7SgNYu4pqeW2W8lPb6XnI0knEAlIEJvCJIith/Lgdf5DsPGb6LJhl8OQCyn0hSk1bewoi448z0u3S+Ahak2V+WIpBjJRgqgLXzBCkS9EptuOx2GhXaYTh6UWdaxbEgliRQKSTiBaF6Vg7QJ44Th46aQ/AtjcPpDXHxxpZGS3QbfYyXBaMSImZcFIctssUlLCFIOCJQSukBQDUT1TKUoDYaxWHY/DQrrT1noDWCEqIT2xKUx6Ymtp/dfw2f1/BK4QDV6HXgZ9jo7Ldy3yhyn2hyj0hggbJnkZDiy69MaKupEUA1FbJX4DbyBMbroDl81C+ywXltZ8nEhPrEhAemJF67Hx2/ie12MfgYlvQ/8TKgzYynBasWrRUcDRHhEjCQ0WLUV5FQPXczOwnnYK6HqsikHwjnswN/+W7CaKZiR6BSiMy26N1rF22Vp3ACtEJSSIFS1Tec5r2A9mBALF0Pd46DS0yuC1nK5pZHls2Cw6HqcVX9AgbJhN/CRESyMTJYiaKNlTUitjT0mtdKeU1BIiEUknSGGSTpDAvgO2xt0DA0+HUBloOtjdoNf8C2FLSYBA2GBHSRCLrpGb3oqmeRSNSlIMRCLBcIRCb4h0l410p5W26U5cNsmFlXQCkYj0xIqWobIBW6s+jvbGOtLAmVGrABYg221D1zQ8Dhshw8QfkkFeomHEpRicerKkGAiUUhQHDKy6TprDgstmlQBWiCpIT2wKk55YEpfKgkoHbNXFzrIg3kCYHaUhlFK0zXBIL5locOb6DYSe/ifmT/+LLrBZsZ5+Krazz0STWsWtQlkgQrE/RLbHjttupV2mE7tF+poA6YkVCUkQm8IkiAUWPQNz//LH7QYMXssZpsmWogD+cIRCb5A0l40MyVETjUApReTzBYRmSopBaxMxFTtKg1gtOrlpdtKdNnLc9mQ3q/mQIFYkIN/EIrWU/+bStOiArZ5HgtUZPbE1cPBaLjpXuQ0ThdNmpSwQxmO3SMkt0eA0TcN65GgshwyLTZRQnmIgEyW0bN6ggWkqMjxWLJpGlkumuBaiOtITm8JaVU/s3mkDQy6AvseAvzg6YGv3esgf2ODBa9zuUfy+pzd2Z2kQp00n2yO9JKJxVUgxsFqxTpAUg5YmHDHZURLE5bCS7baR47GT7pAgNo70xIoEJIhNYa0iiE2U85rdDc58NRq01rLaQH34QgY7vEFK/GG8AYPcNDt2GXQhGplUMWj59p5UxWmz0CFTgrQKJIgVCUg6gWieqhuw5cxo1J7XRNx2K06bgakU/lCE4oBBngSxopGVVzGwDBtG+NV/Ybz7nqQYtCCBcIRgOBKd0EDXyJErPELUmPTEprAW2xO79Wf48OZGqzZQH8FIhK3FAcqCEYp9IbLcdtwOCWRF04mmGMzA/Onn6AJJMUhZSkUHc4FGXrqdNIeN3DSpRZ2Q9MSKBKR2h2h+LHbYtDD6dzXTwzY1h8VCutOGx2HBatUpCYQx5XegaEJ6t644Hp6C/ZYbITsbDAPj9X8TuPRKjK++QfolUkdZMIIRUWQ4reiaTpZb8mCFqA3piU1hLaIntjxtwO6BDoMh5I0O2Pr2SWjTJ+k9r4lElGJLkT82yCvNaSPDJZk5oumpMl8sxQAzOi2ypBikhoip2FESxGbVaZNmJ9Nll4oEVZGeWJGABLEpLKWD2H1zXjsOgdNfgLCvTtPDNrWSQJjdvhC7y8IEwgZ56Q6sUpRcJEnCKgann4rtHEkxaK6KfGF8QYO8DCcuq4UO2U40ZJBepSSIFQnIt65oWuXB677TwwZLwbu1ztPDNrV0p3VP/VgroFESMJLdJNGKRVMMHoxPMXhDUgyaq7Bh4gsaeJxWbBaNLI9NAlgh6kB6YlNYSvXEKgXrv4z2vG74+o/lzWTAVl34wxG2lwYoDRiU+sO0SbPjkGoFIslUmS82UYKkGDRPO0uDRExFXoYDl81KuwzpLa+W9MSKBCSITWEpFcRuWw7/PPSP2ykcvO5thzdIWTDMjpIQaJCXbpe6naJZkIkSmid/KMLushAZbjtpDgv5GU6cVvnxWy0JYkUCkk4gGodS0dm0IDo9bGZH6HZ4s6s2UF+ZLhu6ppPusmJETHxBM9lNEgKoJMVAqhgklVKKEn8Yq0UnzWEhzWGTAFaIepCe2BTWLHti9x6w5cyE0575Y3rYUBmktU35wHVfu3whSgNhdnpDGBGTvHQHFl16Y0XzIVUMmoeSgIHXH6ZNmgOXzUKHbBcWuXJTM9ITKxKQnljRMJSCtQviB2ytmgtrv4CwPzpgK7NjiwtgAbJcNiyaRobTimkqvEEZ5CWaF83jxn7ZxTj/8QT6fgMBMAuWELjiakIvvIQKBJLcwpYvYirKAmFcdisOm07mnvOGEKLuJIgV9ZMoeIU9aQNTof0BKVFtoD50TSPTbcNu1XE5rJQFDMIRSSsQzU8sxeAvUsWgqZX4w4BGuita2SRdaksLUW+STpDCkp5O4N8N/zq3WU4PmwxbSgIEwgbbi/8oYC5EcyUpBk0nFI6w0xsizWUjw2mlbXq0KoGoBUknEAlIT6yoO3s6hPcM3mphA7bqItttQ9c00pxWguEIgVAk2U0SolLxKQaDAEkxaCxFAQOrHh3M5bRZJIAVooFIT2wKa9Ke2PIBW1t/hOFX/TE97Pqvo3+3wp7XRHZ6g3iDYXaUhgBFXrpDSm6JZk8pRWT+AkLPPg+7dwOg5eViu/wSLCNH1OgYDoVC2O1y9WFfZcEIxb4QmW47HoeF9pku7DK7X+1JT6xIQD5Jomr75rx+cg9sXAjFv0cHbHU/rNX2vCaS5Y6W3MpwWjEiirKg9MaK5k/TNKxjRuN6bgbW004BXUft2Eno/ocI3n435ubfqnz8kiVLGTjoQJYuXdYk7U0VplKUBsLYbBY8DgvpTpsEsEI0IPk0icQqG7CV0wO8O1JmetimFp2K1obTbsFhs+ANGERMudghUkNdUwzeeWcO4XCYd96Z05TNbfa8AQPTVGQ6rVg0jUyX/NgXoiFJOkEKa5R0gr3rvMqArTpRKH7fHcBvRNhZGsRlt5DlltdMpJbKUwwuxTLy0FiKgWmaDDtkBDt3FpKXl8uihV+j69I/Eo6Y7CwN4rRZyfbYyHbbyXDKeaDOJJ1AJCBnGhHPjMD71+1TKqt1D9iqLQ2NLI8Nm0XD7bDgCxqEDSm5JVJL5SkGD8alGCxZspSdOwsB2LFjJ0skpQCAEr8BaGS4rNgsOulOuWolREOTT5WIZ7HCIVfCD89Jz2s9eOxWSq0G6U5FIBSh2B8mN92R7GYJUWvlKQbW8eNY9dBUFi9bBp9/Dgu+wHLwED4r88at/8wzMznqqLFxyw4eMoTu3bs1YauTKxiOEAxHSHfZsOga2W47GjLAU4iGJkGsqGjASdDtcHBnJ7slKS3LbSNYEsHjtFHiC+EPRXDZZZ50kZr0bl35p8POnK1b/1j4XsUBXx9//F8+/vi/cctOOfkkpk17tLGb2CwopSj2/1FSy2234rLJ516IxiDpBKIiTQcpC1VvTquFNIeNNIcFq1WnxB+WGZFESrv77jsYO3ZMrR4zbtyR3HXX7Y3UoubHFzQxIibpLiu6pks+vBCNSIJYIRpRltuGjkam00bEVJRKyS2RwnJycnhu5jPce89d2Gw2dD3xj11d17HZbNx3793MfHYGOTk5TdzS5IiY0ZJadpsFl91CusuGTQa5CdFo5NMlRCMqL6vjsOm47FbKAmEpuSVSmqZpnH/+JN6bMxuLJXFGmkUp3rnnLiZNmtiqJvvwBg1Mpch0WbFqOpkymEuIRiVBrBCNLN1lxarrpLusgEaJP5zsJglRb5qmEQ4nPpbDShGZ+XyNJkpoKcKGSVnAwO2IViPIdFvRW1EAL0QySBArRCPT0Mjx2LDqGh6nFX8oQigsaQUitc2d91Hs7/Le1r17XT8uLa3RRAktRUnAQNeiJbXse/LhhRCNS4JYIZqAy2bFabNEB3npOkUBI9lNEqJePvxgbuzvwYMP5M03XuPAAw+ILfvYYUfLyQHDwHjj3wQuvRLjq69b5ODGQChaUivNaUPXNLJlMJcQTUKCWCGaSLbbjq5ppDmtGIZJmQzyEilq/foN/Lp6NZqmce21V/PmG68xbNhQ/v3mv7j2mj+jaRq//v47W++8bZ+JEh5KmRSDoGFS5DcIVjNRiVKKkkAYq0XH49DxOKw4rVJSS4imIFnnQjQRu0Unfc+0k2VhC6WBMC67LnlzIuW4PW6OOmocF190AcOHHxJbbrVaueGG6xkx4lD+7/lZeHLbxCZKCD09A/Onn2MpBtbTT8V2zploTmcSn0k8f9jk09UlvPNzEasLg7Hlvdo4OG1QFmN7ZeC0xff9eIMRjIiiTZotWlLLJb2wQjQVTbXEazutxObNm+ncuTObNm2iU6dODbfh0m3g2wWurIbbpgAgohRbivz4wxF2lgZJc1rJkC890QoopYjMX0B45guoXbsA0PJysV1+CZaRI5JexeDHLT5un/cbJUETDdj7i7H8doZD54FjOrJ/ezcQLam1oySI3WohJ81GpssuQWxjiYTB8ENWF7C5kt0a0UxIOoEQTciiaWS4bNite0puBQ3CkaovVwrREmiahnXMaJwz/9nsUgx+3OJj8vubKA1FP4v79uyU3y4NmUx+fxM/bvFFbwcMFHsqkGg6mS65uClEU5IgVogmlu6MluDJiJXckkFeovXQPG7sl12M8x9PoO83CCCpVQz8YZPb5/1GREF11yWVgogi2mPrN/AFDdwOCzaLRpbHhoakBgnRlCSIFaKJaWhku+1Y9Oggr2A4QkBKbolWRu/WFcfDU7DfcmNSqxh8urqEkqBZbQBbTikoCZp8uKIIi66R7rTisFrw2KUXVoimJkGsEEngsllw262xklslfqNFlh4SoirNIcXgnZ+Lat1/qgEfrizBU15SyyN5sEIkgwSxQiRJljs6mjndZcWImPiCkhsrWqdkpRgEDZPVhcEKObDVUcCGojA2i0aaw4bDIiW1hEgGCWKFSBKbrpPusuGyW7DboiW3Iqb0xorWq6lTDPzh+v1wDIQVWTKxgRBJI0GsEEmU6fxjVLOpFN6gDPISrVtTphi4bPX7Cmyf4cAidZ6FSBoJYoVIIl3TyHRHqxW4HVbKAgbhamYIEqI1aIoUA4dVp1cbR51yYnvnOshNt9e7DUKIupMgVogkS3PYsFstZLis6JpGSUB6Y4Uot3eKAdnZDZ5icOqgrDrlxJ43JE9KagmRZBLECtEMZLujo5zTnLZoya2QlNwSolx5ioHruRkNnmIwtlcGGQ6dmmYF6BpkOi1M2D+3zvsUQjQMCWKFaAacVgsehxWPQ8dq0SkJhKXklhD7aIwUA5ctOpWsRaPaQFbTokHsjNO71zufVghRf/IpFKKZyHJFS25luqwYEYU3KL2xQiTS0CkG+7d38/iJnUm3R78S941ly2+n23WendCdQ7tm1P9JCCHqTVPS3ZOyNm/eTOfOndm0aROdOnVquA2XbgPfLnBlNdw2RY0U+8MU+UPs8oYJGRHyMhxYdMm7E6IyqsxH+JXXMOa8D2Z0UKR+0GDsV12O3qljrbblD5t8trqE2T8XsbowGFveNcvGhP2yObZvFt3zPFKRIBkiYTD8kNUFbK5kt0Y0ExLEpjAJYlseheL3ogD+cISdpUFcdovUoRSiBsz1Gwg9PQPzp5+jC6xWrKefiu2cM9GczlpvL2iYlAQMfEGDDJeNbLeNHI+ddId8HpNCgliRgKQTCNGMaGhkuW3YLBpuhwVfUEpuCVETDZ1i4LDq6IDDopPhjJbBkwBWiOZFglghmhmP3YrDaiHdacWiaxT5w8lukhApoSGrGATCEYLhCGl7PofZbqkJK0RzI0GsEM1QtidacsvjtBE2THwyyEuIGqtvFQOlFCV+A6tFx+Ow4LZbcdksTdF0IUQtSBArRDPksFhIc9hIc1iwWnVKA2FMSV8Xolb2TjHQcnJqnGJQFoxgREzSnVZ0TZe8dCGaKQlihWimstw2dDQynTYipsIbkN5YIWqrPMXAOfOfNUoxiH7WDOw2Cy67hXSXDZsuX5VCNEfyyRSimbJoGpkuGw6bjstupSwYxojIIC8h6qKmKQalAQNTKTJdVqyaTqbTmsxmCyGqIEGsEM1YusuKVddJd1kBjdKAkewmCZHSqqpiEFzwFb5AGLcjWo0g021Fl5qwQjRbEsQK0YxpaOR4bFh1DY/Tij8UIRSWtAIh6qOyKgaRB/9O2kNTSCvcht0azUsXQjRfEsQK0cy5bFZcNmt0kJeuUyS9sUI0iL1TDBg0EADrj8sIXXkN6qWXMP3+JLdQCFEVCWKFSAFZbhsWTSfNacUwTMqk5JYQDUbr2gXvnXcTuObaWIpB8Uuvs+bMCyn5/ItaT5QghGgaEsQKkQLslmgA63ZYsNksUnJLiAbkDUYwTPCMPxLPc8+Qdc4EsOgY27bz2y33sOnaWwhu2JTsZgoh9iFBrBApItNli1YscFox95QBEkLUT8RUlAUMnDYrDpuFzNxM2k++iu6vzMR90AEAlC36gbXnXMz2fzwnKQZCNCMSxAqRIiyaRqbbht2q43JYKQsahKXklhD1UhowUPxRCSTDFS2p5ezZnS7/fIwO992OpU10ooTCWa9JioEQzYgEsUKkkHSHDZtFJ8NpRUOjxC+9sULUVXRKZwOPw4LNov1/e3cTIkl5gHH8qXrft6uqq2d6dtXMIuuadT1rlGBOgYAeQq6JMW68hHjRfGgkENFL8JiD4iWK+UAjEWIg1yAkXlUMBOLZeFiJ7q4uOz0fXd31lUP1ztqz4+pMz0z32/3/HQWHAmfsZ2qe93l1rO0U6OqkVhAE6n77Xp356ys6fvb+8YrBY09SMQCmjBALeOZYuyUTBurEVoO8VMbkFrAvl/v59s9S7Izard0vNjCdVKuPPzJeMXj7XX1w9mFd+O0fqBgAU0KIBTyTjD5s08jImlC9fsGfNoE92hqUyotKndgpDAKttL94E3ZnxaDOc3368p+pGABTQogFPLTSdgqDUEuxVVEyuQXsRVXXWs9yWRsqjZpLDSJjvtS/O1Yx2LliQMUAOFKEWMBDLgy1lDglLaPIGW1khcqKt0DAl7GRlSqrWiuJU6gv9xZ2J9NJtXplxeCuOyRRMQCOGiEW8FQ3trJBc5q6ebPEIS/gizR/uciVtKxaNtyertuv+MxpnXrxOd38zFPjFYMHqBgAh40QC3gqDAJ121bOhGpHVluDQnnB5BZwPb2skBRsT2otJbsf5tqLpmJwX1Mx+MF3m4rBx1QMgMNGiAU81omcWtZoObEKg0Br/XzajwTMrEFeKhuWSmMrGwY6no5Pak3KdFKtPvETnX71JSoGwBEgxAKeO9ZuTld3YqdhUSkbcsgL2Kmua61lhWwYqhMZJc4qcZO/hd1NfPttVAyAI0CIBTwXW6M0skqjsJncynI+JIEdtoaViqJSJ7YyQbivw1x7MVYxuLJiQMUAOFCEWGAOrCROJgjVTayKstYGk1vAtrJqJrWcM2pHRp3YqmWO5uNve8WAigFw4AixwBxo7nx3ipxR7Kw2spzJLWBkc1Coqmp1YysTBFpJDvct7G4+t2LARQnAvhFigTmxPHbaOmByC5CUl5U2skJJNJrUGnXIp2HXigEXJQD7RogF5kSgQMfaTs4ESiPD5BYgqdcvFAaBluNmjm4pOvq3sDtxUQJwMAixwBxpt6xi13T+TBjoMpNbWGBZXmqQl9s/D8fT1rQfaczVixKevrZi8CYVA+CLEGKBObPymcmtvKi0xSEvLKC6rtXrF7ImVBoZpS2r2JppP9Y1morBvddWDJ78tc79/FdUDIDrIMQCcyYyRkuxUxoZWRtqPctV8UYHC2ZzUKooKy3HVmHQdGFn2a4Vg3f+pf8++GMqBsDnIMQCc6ibOIVqTmGXVa2NjLexWBzN93zRrHW0jJYTJxf68XG3s2KgomDFAPgcfvxUA9gTEwTqJk4tGyppWW0OchUlh7ywGNazQlVdN4sdQajl+HBu5josYxWDs/ePrxhQMQC2EWKBObW0Y3Krx+QWFkDTAy+UjtYIVtLpTWpNynRSrT7+yOdUDH5PxQALjxALzKlAgY6nTjYMlMZW2bA5qQ3Ms7V+LhMGWoqtItsc6PLd7hWD16gYYOERYoE5ljirxFl1IiMbhlrLCj7wMLf6w1LDolIaN29fV2b8MNdeUDEArkWIBebcStvJBKE6sVVRVNoa0I3F/GkmtfJmUqsVKo1mc1JrUtevGLBigMVCiAXmXMs0AbYdGTlntD7IVVa8jcV8WR+UKqta3aT5pW0lmZ+3sLthxQAgxAILYSVxzWJBbFVVtTYHHPLC/CirWptZrthZRS7UcuJkPZnUmgQVAyy6+f8pB6AwCNRtjya3IquNrFDO5BbmRK+fSwq2FzmWE/8Pc+3FWMXg7jslsWKAxUCIBRbEUuTkTDi6wShQr8/bWPhvmJfqD0ulkZEzgY61nQL5Oak1qfjMaZ164VlWDLAwCLHAAjmetmTCQJ3YapCXypjcgufWsmL7ezp2Ru05mNSaxHUrBo89ScUAc4UQCyyQeLSbmUZG1oTq9Zncgr+2BqXyolJnDie1JrVrxeDtd/XB2YdZMcDcIMQCC6bbdgpHV3EWZaXNAW9j4Z+qrtXLclkbKo2MlmKnyMzfpNakdlYM6jxnxQBzgxALLBgXNqe345ZR5Iw2soLJLXhnIytVVfXV5Y05n9SaBBUDzCtCLLCAlmMrGzSnuKu61nrGIS/4o/kLQq6kZdWyzS9lJljMw1x7cf2KASsG8A8hFlhAYRBoJW3WCtLYamtQKC+Y3IIfelkhKdDyaFJrKV7sw1x79dmKgb3xhlHF4DW9/wAVA/iFEAssqLRlFVmjpdjKhIHW+vm0Hwn4QoO8VDYs1Rl93x5PWws7qTWJKxWD215/+WrF4GMuSoBfCLHAAltpN6e609hpWFTqDznkhdlV17XW+oVsGKoTNXNaieMw1yTGKgZ33SGJixLgD0IssMBia5RGVmkrHE1u5fwpETNra1CpKCstJVZhEDKpdYDiM6d16sXnuCgBXiHEAguuOd0dqps4lVWtdSa3MIOa781cLWeUtIyWEicX8hF2kMZWDB783viKARUDzCD+DwAsODua3IpcqNhZbWY5k1uYORuDQlVVazm2zaQWh7kOjemkWv3Fozr96ktUDDDTCLEArp7yTqykQD0OeWGG5GWlzaxQEjWTWt1RlxuHK779tlHF4CkqBphJhFgAChToWNvJmUBpZNQflhrm1AowG3r9QmEQqJtYtazRUkQX9qg0FYP7qBhgJhFiAUiS2i2r2Jnt6aI1LkDADMiGpQZ5M6kVBs0vWzh6VAwwiwixALZdmdzqxE55UWmLQ16Yorqu1ctyWRMqjYzSllVsmdSaJioGmCWEWADbImO0FDulkZG1oXpZrooPJUzJxqBUUTaHuZjUmh1UDDArCLEAxnRH99CvJE5VVWsj420sjl5Z1drMCkXOKG4ZLSdOlkmtmbJdMeCiBEwJGyUAxpgg0HLiVNa1kpbV5iBX0gplQ06D4+isZ4VqqQmvQahuwsfVrLpyUULvjX/q/PMvqvz0kj59+TWt/f0fWn3iUS1965sKWJPAIQhqCize+vDDD3XLLbfo3LlzOnny5MF94fXz0tYlKVk5uK8Jr9Sq9b/LmQZFqQu9AT03TEUaW3UTpxs7kdIWIdYH5camPvndK7r0+t+kspIkpd/4ulZ/+TNFt94ywRfOpaIvrZySXHJATwvfEWI9RojFYernpS6sZ8qGpXIuP8AUpJFR4qxOLMfTfhTsUfb+Bzr/m+e19e//NP/AWt3w0Pd1449+qDDZRwglxGIXhFiPEWJx2C5uDLQ1ZGoL03OiGysyLBL4qK5r9d54U+eff0Hlp5ckSXb1K/urGBBisQtCrMcIsTgKtfhfBKYnEF1K3x1IxYAQi10QYj1GiAUA+GKiigEhFrtgrwQAABy6KysGNz/zNBcl4EAQYgEAwJFoLkq4t7ko4ez9XJSAiRBiAQDAkTKdVKuPP8JFCZgIIRYAAEwFFQNMghALAACmhooB9osQCwAApm6sYnD3nZI+UzF44Y+q+tmUnxCzhhALAABmRnzmtE698Ox4xeBPf9H7D/1Um2+9M+3HwwwhxAIAgJmya8XgwicKEq4gxlV22g8AAACwmysVg+537tPmW2+p/bU7p/1ImCGEWAAAMNPiM19VfOvqtB8DM4Y6AQAAALxDiAUAAIB3CLEAAADwDiEWAAAA3iHEAgAAwDuEWAAAAHiHEAsAAADvEGIBAADgHUIsAAAAvEOIBQAAgHcIsQAAAPAOIRYAAADeIcQCAADAO4RYAAAAeIcQCwAAAO8QYgEAAOAdQiwAAAC8Q4gFAACAdwixAAAA8A4hFgAAAN4hxAIAAMA7hFgAAAB4hxALAAAA7xBiAQAA4B1CLAAAALxDiAUAAIB3CLEAAADwDiEWAAAA3iHEAgAAwDuEWAAAAHiHEAsAAADvEGIBAADgHUIsAAAAvEOIBQAAgHcIsQAAAPAOIRYAAADeIcQCAADAO4RYAAAAeIcQCwAAAO8QYgEAAOAdQiwAAAC8Q4gFAACAdwixAAAA8A4hFgAAAN4hxAIAAMA7hFgAAAB4hxALAAAA7xBiAQAA4B1CLAAAALxjp/0A2L+iKCRJH3300cF+4Y2LUv+yFG8d7NcFAGA/ykIqM2kjlFxyoF/6xIkTspY45CP+q3ns4sWLkqR77rlnyk8CAICfzp07p5MnT077MbAPQV3X9bQfAvuTZZnee+893XTTTfwWCQDAPvAm1l+EWAAAAHiHg10AAADwDiEWAAAA3iHEAgAAwDuEWAAAAHiHEAsAAADvEGIBAADgHUIsAAAAvEOIBQAAgHf+D1B8sMXC5iJHAAAAAElFTkSuQmCC", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Schematic: U(a,b) (a hexagon) intersected with two constraint half-planes.\n", + "fig, ax = plt.subplots(figsize=(5.0, 4.2))\n", + "\n", + "theta = np.linspace(0, 2 * np.pi, 7)\n", + "hex_pts = np.column_stack([np.cos(theta), np.sin(theta)]) * 1.5\n", + "hex_patch = mpatches.Polygon(\n", + " hex_pts,\n", + " closed=True,\n", + " alpha=0.15,\n", + " facecolor=\"C0\",\n", + " edgecolor=\"C0\",\n", + " linewidth=2,\n", + " label=r\"$U(a,b)$\",\n", + ")\n", + "ax.add_patch(hex_patch)\n", + "\n", + "xx = np.linspace(-2, 2, 50)\n", + "ax.fill_between(xx, -2, 0.5 * xx + 0.5, color=\"C1\", alpha=0.10)\n", + "ax.plot(\n", + " xx,\n", + " 0.5 * xx + 0.5,\n", + " \"--\",\n", + " color=\"C1\",\n", + " linewidth=1.5,\n", + " label=r\"$D_1\\cdot P \\geq 0$\",\n", + ")\n", + "ax.plot(\n", + " xx,\n", + " -0.6 * xx - 0.2,\n", + " \"-\",\n", + " color=\"C3\",\n", + " linewidth=1.5,\n", + " label=r\"$D_2\\cdot P = 0$\",\n", + ")\n", + "\n", + "ax.scatter(\n", + " [1.05],\n", + " [-0.65],\n", + " color=\"C0\",\n", + " s=80,\n", + " zorder=5,\n", + " label=r\"unconstrained $P^\\star$\",\n", + ")\n", + "ax.scatter(\n", + " [0.55],\n", + " [-0.53],\n", + " color=\"k\",\n", + " s=80,\n", + " marker=\"*\",\n", + " zorder=5,\n", + " label=r\"constrained $P^\\star$\",\n", + ")\n", + "\n", + "ax.set_xlim(-2.2, 2.2)\n", + "ax.set_ylim(-1.9, 1.9)\n", + "ax.set_aspect(\"equal\")\n", + "ax.set_xticks([])\n", + "ax.set_yticks([])\n", + "ax.legend(loc=\"upper left\", fontsize=9, framealpha=0.95)\n", + "ax.set_title(\n", + " \"Transport polytope intersected with linear constraints (schematic)\"\n", + ")\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "e1288ac8", + "metadata": {}, + "source": [ + "### The dual potential\n", + "\n", + "Adding entropy regularisation and Lagrange multipliers $f \\in \\mathbb{R}^n,\\ g \\in \\mathbb{R}^m$ for the marginal constraints (as in vanilla Sinkhorn) and $\\alpha = (\\alpha_1, \\ldots, \\alpha_{K+L})$ for the new constraints, eliminating the primal variables yields a **concave** dual function\n", + "$$\n", + "\\boxed{\\;f_{\\mathrm{dual}}(f, g, \\alpha) \\;=\\; -\\varepsilon \\sum_{ij} \\exp\\!\\big((f_i + g_j - C^{\\mathrm{eff}}_{ij})/\\varepsilon\\big) \\;+\\; \\langle f, a\\rangle + \\langle g, b \\rangle \\;-\\; \\varepsilon \\sum_{k=1}^{K} \\exp(-\\alpha_k/\\varepsilon - 1)\\;}\n", + "$$\n", + "where the **effective cost** is shifted by the constraint duals,\n", + "$$C^{\\mathrm{eff}} \\;:=\\; C - \\sum_{m=1}^{K+L} \\alpha_m D_m.$$\n", + "\n", + "The intermediate transport plan associated with $(f, g, \\alpha)$ is\n", + "$$P_{ij} \\;=\\; \\exp\\!\\big((f_i + g_j - C^{\\mathrm{eff}}_{ij})/\\varepsilon\\big).$$\n", + "\n", + "This is the **key observation** that drives our implementation:\n", + "\n", + "> The formula for $P$ above is *exactly* the formula used by vanilla entropic Sinkhorn, with $C$ replaced by $C^{\\mathrm{eff}}$.\n", + "\n", + "So we can reuse OTT-JAX's stabilised log-domain machinery verbatim: build a `Geometry` whose `cost_matrix` is $C^{\\mathrm{eff}}$, and the marginal updates become *one method call* per row/column. We then only need to add a dedicated update for the constraint duals $\\alpha$.\n", + "\n", + "We use the OTT-JAX convention $\\varepsilon = 1/\\gamma$ throughout, where $\\gamma$ is the entropy regularisation strength in the paper's notation. Maximising $f_{\\mathrm{dual}}$ over $(f, g, \\alpha)$ is equivalent to solving the entropic constrained problem (the minimax argument is standard; see App. H of the paper)." + ] + }, + { + "cell_type": "markdown", + "id": "6ea37f79", + "metadata": {}, + "source": [ + "## 3. Algorithm: wrapping OTT-JAX Sinkhorn\n", + "\n", + "The paper's algorithm alternates three updates on $(f, g, \\alpha)$:\n", + "\n", + "1. **Row scaling**, closed form, identical to vanilla Sinkhorn on $C^{\\mathrm{eff}}$:\n", + "$$f \\leftarrow \\varepsilon\\big(\\log a - \\log P\\mathbf{1}\\big) + f.$$\n", + "\n", + "2. **Column scaling**, also identical:\n", + "$$g \\leftarrow \\varepsilon\\big(\\log b - \\log P^\\top\\mathbf{1}\\big) + g.$$\n", + "\n", + "3. **Constraint dual update**, the new step. We jointly optimise over $\\alpha$ and an auxiliary $t \\in \\mathbb{R}$ that re-normalises the total mass:\n", + "$$(\\alpha, t) \\leftarrow \\arg\\max_{\\tilde\\alpha, \\tilde t}\\; f_{\\mathrm{dual}}(f + \\tilde t \\cdot \\mathbf{1},\\, g,\\, \\tilde\\alpha),$$\n", + "then absorb $t$ into $f$. The auxiliary $t$ ensures $\\sum_{ij} P_{ij} = 1$ between iterations, which is what keeps the gradients of $f_{\\mathrm{dual}}$ bounded by $\\|D_m\\|_\\infty$.\n", + "\n", + "Steps 1 & 2 are textbook Sinkhorn : we do not want to reimplement them. OTT-JAX's `Geometry.apply_lse_kernel` does exactly what we need in a numerically stable way." + ] + }, + { + "cell_type": "markdown", + "id": "d3108ce1", + "metadata": {}, + "source": [ + "### Reusing `Geometry.apply_lse_kernel`\n", + "\n", + "Given a current $\\alpha$, we build a *fresh* `Geometry` with cost $C^{\\mathrm{eff}} = C - \\sum_m \\alpha_m D_m$ and call its `apply_lse_kernel(f, g, eps, axis=...)` method. With the OTT-JAX convention $P_{ij} = \\exp((f_i + g_j - C^{\\mathrm{eff}}_{ij})/\\varepsilon)$, this method returns\n", + "$$\\varepsilon \\log \\sum_j \\exp\\!\\big((g_j - C^{\\mathrm{eff}}_{ij})/\\varepsilon\\big)$$\n", + "for `axis=1` (and the symmetric quantity for `axis=0`). The Sinkhorn updates rewrite as\n", + "$$f_i \\leftarrow \\varepsilon \\log a_i - \\varepsilon \\log \\sum_j \\exp\\!\\big((g_j - C^{\\mathrm{eff}}_{ij})/\\varepsilon\\big),$$\n", + "which is one call to `apply_lse_kernel` plus one subtraction. This delegates all the log-sum-exp stabilisation, kernel materialisation, and numerical edge cases to OTT-JAX." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "8ebf4d1f", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-02T06:39:02.805716Z", + "iopub.status.busy": "2026-05-02T06:39:02.805559Z", + "iopub.status.idle": "2026-05-02T06:39:02.810658Z", + "shell.execute_reply": "2026-05-02T06:39:02.809814Z" + } + }, + "outputs": [], + "source": [ + "def make_effective_geometry(C, Ds, alphas, eps):\n", + " \"\"\"Build the OTT-JAX Geometry whose cost is C - sum_m alpha_m D_m.\n", + "\n", + " The constrained Sinkhorn iteration on (f, g) with constraint duals alpha\n", + " is just a vanilla Sinkhorn iteration on this geometry.\n", + " \"\"\"\n", + " if Ds.shape[0] > 0:\n", + " C_eff = C - jnp.einsum(\"k,kij->ij\", alphas, Ds)\n", + " else:\n", + " C_eff = C\n", + " return geometry.Geometry(cost_matrix=C_eff, epsilon=eps)\n", + "\n", + "\n", + "def sinkhorn_marginal_updates(C, Ds, alphas, a, b, f, g, eps):\n", + " \"\"\"One row-then-column Sinkhorn step on the effective geometry.\n", + "\n", + " Uses OTT-JAX's stabilised log-domain kernel application. This is exactly\n", + " what `sinkhorn.Sinkhorn` does internally for the marginal updates.\n", + " \"\"\"\n", + " geom_eff = make_effective_geometry(C, Ds, alphas, eps)\n", + " # Row update: f <- eps * (log a - log K @ exp(g/eps))\n", + " log_Kg, _ = geom_eff.apply_lse_kernel(f, g, eps, axis=1)\n", + " f = eps * jnp.log(a) - log_Kg\n", + " # Column update with the freshly updated f\n", + " log_Kf, _ = geom_eff.apply_lse_kernel(f, g, eps, axis=0)\n", + " g = eps * jnp.log(b) - log_Kf\n", + " return f, g" + ] + }, + { + "cell_type": "markdown", + "id": "caff3a61", + "metadata": {}, + "source": [ + "### Sanity check: with no constraints, this *is* vanilla Sinkhorn\n", + "\n", + "Before adding the constraint-dual machinery, let's verify that our wrapper reproduces OTT-JAX's `sinkhorn.Sinkhorn` exactly when the constraint set is empty. This both validates our log-domain conventions and shows that the code path we're building strictly extends, rather than replaces, the existing solver." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "07333ce9", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-02T06:39:02.812608Z", + "iopub.status.busy": "2026-05-02T06:39:02.811953Z", + "iopub.status.idle": "2026-05-02T06:39:10.423750Z", + "shell.execute_reply": "2026-05-02T06:39:10.422674Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "our wrapper: = 4.61226452\n", + "OTT Sinkhorn: = 4.61226452\n", + "matrices agree to: 6.50e-14\n" + ] + } + ], + "source": [ + "# Same toy point-cloud problem as in section 1, now solved through our wrapper\n", + "# with no constraints: should give the identical result as `sinkhorn.Sinkhorn`.\n", + "C_test = pointcloud.PointCloud(x_toy, y_toy).cost_matrix\n", + "eps_test = 0.05\n", + "\n", + "f_w = jnp.zeros(n_toy)\n", + "g_w = jnp.zeros(m_toy)\n", + "alphas_empty = jnp.zeros(0)\n", + "Ds_empty = jnp.zeros((0, n_toy, m_toy))\n", + "\n", + "# 2000 iterations to fully converge (Sinkhorn is sublinear in iterations).\n", + "for _ in range(2000):\n", + " f_w, g_w = sinkhorn_marginal_updates(\n", + " C_test, Ds_empty, alphas_empty, a_toy, b_toy, f_w, g_w, eps_test\n", + " )\n", + "\n", + "# OTT-JAX reference solver, very tight tolerance\n", + "geom_ref = geometry.Geometry(cost_matrix=C_test, epsilon=eps_test)\n", + "out_ref = sinkhorn.Sinkhorn(threshold=1e-12, max_iterations=5000)(\n", + " linear_problem.LinearProblem(geom_ref, a=a_toy, b=b_toy)\n", + ")\n", + "\n", + "P_w = jnp.exp((f_w[:, None] + g_w[None, :] - C_test) / eps_test)\n", + "print(f\"our wrapper: = {float(jnp.sum(P_w * C_test)):.8f}\")\n", + "print(f\"OTT Sinkhorn: = {float(jnp.sum(out_ref.matrix * C_test)):.8f}\")\n", + "print(f\"matrices agree to: {float(jnp.max(jnp.abs(P_w - out_ref.matrix))):.2e}\")" + ] + }, + { + "cell_type": "markdown", + "id": "9db0349f", + "metadata": {}, + "source": [ + "Agreement to machine precision: every tool we'll use to *verify* the constrained algorithm (convergence of dual potentials, KL marginal violations, etc.) therefore behaves identically to OTT-JAX's reference implementation. The constraint-dual block we add next is a strict extension." + ] + }, + { + "cell_type": "markdown", + "id": "aa683e3c", + "metadata": {}, + "source": [ + "### The constraint dual update\n", + "\n", + "For the third update we need the gradient and Hessian of $f_{\\mathrm{dual}}$ with respect to $(\\alpha, t)$. With $P$ the current intermediate plan,\n", + "$$\\partial_{\\alpha_k} f_{\\mathrm{dual}} = -P \\cdot D_k + \\exp(-\\alpha_k/\\varepsilon - 1)\\quad (k\\le K),\\qquad \\partial_{\\alpha_l} f_{\\mathrm{dual}} = -P \\cdot D_l\\quad (l > K),$$\n", + "$$\\partial_t f_{\\mathrm{dual}} = 1 - \\sum_{ij} P_{ij}.$$\n", + "The Hessian is dominated by terms of the form $-P \\cdot (D_m \\odot D_{m'}) / \\varepsilon$ which we evaluate from the same $P$ matrix. We work in *rescaled* form by factoring out the maximum of $\\log P$, Newton's direction $-H^{-1}\\nabla$ is invariant under any common positive rescaling of gradient and Hessian, so this only affects numerics." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "33e8716d", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-02T06:39:10.426190Z", + "iopub.status.busy": "2026-05-02T06:39:10.426011Z", + "iopub.status.idle": "2026-05-02T06:39:10.439044Z", + "shell.execute_reply": "2026-05-02T06:39:10.438205Z" + } + }, + "outputs": [], + "source": [ + "@partial(jax.jit, static_argnames=(\"K\",))\n", + "def grad_hess_alpha_t(C, Ds, K, f, g, alphas, t, eps):\n", + " \"\"\"Gradient and Hessian of the dual w.r.t. (alpha, t), rescaled by exp(-m).\n", + "\n", + " Newton's direction is invariant under uniform rescaling of (grad, hess);\n", + " pulling out the max of log P keeps everything in O(1).\n", + " \"\"\"\n", + " if Ds.shape[0] > 0:\n", + " C_eff = C - jnp.einsum(\"k,kij->ij\", alphas, Ds)\n", + " else:\n", + " C_eff = C\n", + " log_P = (f[:, None] + t + g[None, :] - C_eff) / eps # log P with t shift\n", + " m_max = jnp.max(log_P)\n", + " P_tilde = jnp.exp(log_P - m_max) # rescaled plan, max entry = 1\n", + "\n", + " # Gradient (rescaled)\n", + " PD = jnp.einsum(\"ij,kij->k\", P_tilde, Ds)\n", + " g_alpha = -PD\n", + " rescaled_slack = jnp.exp(-alphas / eps - 1.0 - m_max)\n", + " mask = jnp.arange(alphas.shape[0]) < K\n", + " g_alpha = g_alpha + jnp.where(mask, rescaled_slack, 0.0)\n", + " g_t = jnp.exp(-m_max) - jnp.sum(P_tilde)\n", + " grad = jnp.concatenate([g_alpha, jnp.array([g_t])])\n", + "\n", + " # Hessian (rescaled)\n", + " PtD = P_tilde[None, :, :] * Ds\n", + " H_aa = -(1.0 / eps) * jnp.einsum(\"kij,lij->kl\", PtD, Ds)\n", + " H_aa = H_aa + jnp.diag(jnp.where(mask, -(1.0 / eps) * rescaled_slack, 0.0))\n", + " H_at = -(1.0 / eps) * jnp.einsum(\"ij,kij->k\", P_tilde, Ds)\n", + " H_tt = -(1.0 / eps) * jnp.sum(P_tilde)\n", + " top = jnp.concatenate([H_aa, H_at[:, None]], axis=1)\n", + " bot = jnp.concatenate([H_at, jnp.array([H_tt])])\n", + " return grad, jnp.concatenate([top, bot[None, :]], axis=0)\n", + "\n", + "\n", + "def dual_value(C, Ds, K, a, b, f, g, alphas, eps):\n", + " \"\"\"Concave dual value (modulo a constant).\"\"\"\n", + " if Ds.shape[0] > 0:\n", + " C_eff = C - jnp.einsum(\"k,kij->ij\", alphas, Ds)\n", + " else:\n", + " C_eff = C\n", + " log_P = (f[:, None] + g[None, :] - C_eff) / eps\n", + " val = (\n", + " -eps * jnp.exp(jax.scipy.special.logsumexp(log_P))\n", + " + jnp.dot(f, a)\n", + " + jnp.dot(g, b)\n", + " )\n", + " if alphas.shape[0] > 0:\n", + " mask = jnp.arange(alphas.shape[0]) < K\n", + " val = val - eps * jnp.sum(\n", + " jnp.where(mask, jnp.exp(-alphas / eps - 1.0), 0.0)\n", + " )\n", + " return val\n", + "\n", + "\n", + "@partial(jax.jit, static_argnames=(\"K\", \"n_newton\"))\n", + "def newton_alpha_step(C, Ds, K, a, b, f, g, alphas, eps, n_newton=10):\n", + " \"\"\"Newton's method on (alpha, t) with backtracking line search.\n", + "\n", + " Implemented inside `lax.scan` (outer Newton steps) with an inner\n", + " `lax.while_loop` for the line search, so the whole step is JIT-friendly.\n", + " \"\"\"\n", + " KL = alphas.shape[0]\n", + "\n", + " def newton_iter(carry, _):\n", + " alphas, t = carry\n", + " grad, H = grad_hess_alpha_t(C, Ds, K, f, g, alphas, t, eps)\n", + " step = jnp.linalg.solve(H, -grad) # H is negative definite\n", + " step_alpha, step_t = step[:KL], step[KL]\n", + " f_curr = dual_value(C, Ds, K, a, b, f + t, g, alphas, eps)\n", + "\n", + " def cond_fn(state):\n", + " lr, accepted, _ = state\n", + " return jnp.logical_and(jnp.logical_not(accepted), lr > 1e-10)\n", + "\n", + " def body_fn(state):\n", + " lr, _, _ = state\n", + " f_new = dual_value(\n", + " C,\n", + " Ds,\n", + " K,\n", + " a,\n", + " b,\n", + " f + t + lr * step_t,\n", + " g,\n", + " alphas + lr * step_alpha,\n", + " eps,\n", + " )\n", + " ok = jnp.logical_and(\n", + " jnp.isfinite(f_new), f_new >= f_curr + 1e-14 * jnp.abs(f_curr)\n", + " )\n", + " return (jnp.where(ok, lr, lr * 0.5), ok, lr)\n", + "\n", + " lr_final, ok, _ = jax.lax.while_loop(\n", + " cond_fn, body_fn, (1.0, False, 1.0)\n", + " )\n", + " new_alpha = jnp.where(ok, alphas + lr_final * step_alpha, alphas)\n", + " new_t = jnp.where(ok, t + lr_final * step_t, t)\n", + " return (new_alpha, new_t), None\n", + "\n", + " (alphas, t), _ = jax.lax.scan(\n", + " newton_iter, (alphas, jnp.array(0.0)), None, length=n_newton\n", + " )\n", + " return alphas, t" + ] + }, + { + "cell_type": "markdown", + "id": "b1640af3", + "metadata": {}, + "source": [ + "### Putting it together\n", + "\n", + "We combine the marginal updates (built on `apply_lse_kernel`) and the constraint-dual update (Newton) into a single solver. Following OTT-JAX style, we return a typed output object that exposes potentials, the transport matrix, the final `Geometry`, and convergence diagnostics : mirroring the structure of `ott.solvers.linear.sinkhorn.SinkhornOutput`." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "755897ae", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-02T06:39:10.441061Z", + "iopub.status.busy": "2026-05-02T06:39:10.440386Z", + "iopub.status.idle": "2026-05-02T06:39:10.449665Z", + "shell.execute_reply": "2026-05-02T06:39:10.448879Z" + } + }, + "outputs": [], + "source": [ + "class ConstrainedSinkhornOutput(NamedTuple):\n", + " \"\"\"Mirrors the structure of `ott.solvers.linear.sinkhorn.SinkhornOutput`.\"\"\"\n", + "\n", + " f: jnp.ndarray # row potential\n", + " g: jnp.ndarray # column potential\n", + " alphas: jnp.ndarray # constraint duals (length K + L)\n", + " matrix: jnp.ndarray # final transport plan\n", + " geom: geometry.Geometry # final effective geometry C - sum alpha_m D_m\n", + " history: dict # per-iteration diagnostics\n", + "\n", + " @property\n", + " def reg_ot_cost(self):\n", + " return jnp.sum(self.matrix * self.geom.cost_matrix)\n", + "\n", + "\n", + "def constrained_sinkhorn(\n", + " C,\n", + " a,\n", + " b,\n", + " Ds_ineq,\n", + " Ds_eq,\n", + " eps,\n", + " n_iters=200,\n", + " n_newton=10,\n", + " track=False,\n", + " track_duals=False,\n", + "):\n", + " \"\"\"Solve constrained entropic OT (Algorithm 1 of Tang, Liu & Ying 2025).\n", + "\n", + " The marginal updates reuse `Geometry.apply_lse_kernel` from OTT-JAX; the\n", + " constraint dual update is a Newton step on the new variables alpha.\n", + "\n", + " Parameters\n", + " ----------\n", + " C : (n, m) cost matrix.\n", + " a, b : marginals.\n", + " Ds_ineq : (K, n, m) inequality constraint matrices, each in homogeneous\n", + " form `D_k . P >= 0`.\n", + " Ds_eq : (L, n, m) equality constraint matrices, each in homogeneous\n", + " form `D_l . P = 0`.\n", + " eps : entropic regularisation (the OTT-JAX `epsilon`; equals 1/gamma in\n", + " the paper's notation).\n", + " track : if True, log per-iteration cost / marginal error / constraint\n", + " violation.\n", + " track_duals : if True, log the dual variables (f, g, alphas) at each\n", + " iteration -- useful for visualisation, slightly more memory.\n", + "\n", + " Returns\n", + " -------\n", + " ConstrainedSinkhornOutput.\n", + " \"\"\"\n", + " n, m = C.shape\n", + " K, L = int(Ds_ineq.shape[0]), int(Ds_eq.shape[0])\n", + " Ds = (\n", + " jnp.concatenate([Ds_ineq, Ds_eq], axis=0)\n", + " if (K + L) > 0\n", + " else jnp.zeros((0, n, m))\n", + " )\n", + " has_c = (K + L) > 0\n", + "\n", + " f = jnp.zeros(n)\n", + " g = jnp.zeros(m)\n", + " alphas = jnp.zeros(K + L) if has_c else jnp.zeros(0)\n", + "\n", + " hist = {\n", + " \"iter\": [],\n", + " \"cost\": [],\n", + " \"row_err\": [],\n", + " \"col_err\": [],\n", + " \"viol\": [],\n", + " \"f\": [],\n", + " \"g\": [],\n", + " \"alphas\": [],\n", + " }\n", + "\n", + " for it in range(n_iters):\n", + " # 1 & 2: row + column updates via OTT-JAX log-domain kernel.\n", + " f, g = sinkhorn_marginal_updates(C, Ds, alphas, a, b, f, g, eps)\n", + " # 3: Newton step on the constraint duals (and shift `t`).\n", + " if has_c:\n", + " alphas, t = newton_alpha_step(\n", + " C, Ds, K, a, b, f, g, alphas, eps, n_newton\n", + " )\n", + " f = f + t\n", + "\n", + " if track or track_duals:\n", + " geom_eff = make_effective_geometry(C, Ds, alphas, eps)\n", + " P = jnp.exp((f[:, None] + g[None, :] - geom_eff.cost_matrix) / eps)\n", + " hist[\"iter\"].append(it)\n", + " if track:\n", + " hist[\"cost\"].append(float(jnp.sum(P * C)))\n", + " hist[\"row_err\"].append(float(jnp.max(jnp.abs(P.sum(1) - a))))\n", + " hist[\"col_err\"].append(float(jnp.max(jnp.abs(P.sum(0) - b))))\n", + " if has_c and K > 0:\n", + " v_ineq = float(\n", + " jnp.sum(\n", + " jnp.abs(\n", + " jnp.minimum(\n", + " jnp.einsum(\"ij,kij->k\", P, Ds[:K]), 0.0\n", + " )\n", + " )\n", + " )\n", + " )\n", + " else:\n", + " v_ineq = 0.0\n", + " if has_c and L > 0:\n", + " v_eq = float(\n", + " jnp.sum(jnp.abs(jnp.einsum(\"ij,kij->k\", P, Ds[K:])))\n", + " )\n", + " else:\n", + " v_eq = 0.0\n", + " hist[\"viol\"].append(v_ineq + v_eq)\n", + " if track_duals:\n", + " hist[\"f\"].append(np.array(f))\n", + " hist[\"g\"].append(np.array(g))\n", + " hist[\"alphas\"].append(np.array(alphas))\n", + "\n", + " geom_eff = make_effective_geometry(C, Ds, alphas, eps)\n", + " P = jnp.exp((f[:, None] + g[None, :] - geom_eff.cost_matrix) / eps)\n", + " return ConstrainedSinkhornOutput(\n", + " f=f, g=g, alphas=alphas, matrix=P, geom=geom_eff, history=hist\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "ba465cce", + "metadata": {}, + "source": [ + "When called with empty `Ds_ineq` and `Ds_eq`, `constrained_sinkhorn` reduces exactly to log-domain Sinkhorn through `apply_lse_kernel` : the strict-extension property we verified above." + ] + }, + { + "cell_type": "markdown", + "id": "1032dfac", + "metadata": {}, + "source": [ + "## 4. Experiment 1 : random assignment under constraints\n", + "\n", + "Following the paper's first numerical test, we consider a random assignment problem of size $n = 100$ with one inequality and one equality constraint. The cost matrix $C$ and the constraint matrices $D_I, D_E$ have i.i.d. $\\mathrm{Uniform}[0,1]$ entries; the source and target are uniform $a = b = \\mathbf{1}_n / n$. We solve\n", + "$$\\min_{P \\in U(a,b)} \\langle P, C \\rangle \\quad\\text{s.t.}\\quad D_I \\cdot P \\le t_I,\\ \\ D_E \\cdot P = t_E.$$\n", + "We pick $t_I = t_E = 0.5$, which is feasible." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "7d57a806", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-02T06:39:10.451190Z", + "iopub.status.busy": "2026-05-02T06:39:10.451041Z", + "iopub.status.idle": "2026-05-02T06:39:19.046336Z", + "shell.execute_reply": "2026-05-02T06:39:19.045174Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "final transport cost = 0.0172\n", + "D_I . P = 0.4861 (constraint: <= 0.5)\n", + "D_E . P = 0.5000 (constraint: == 0.5)\n", + "row marginal error: 6.25e-05\n", + "col marginal error: 3.77e-09\n", + "alpha (constraint duals): [0.01970458 0.21334145]\n" + ] + } + ], + "source": [ + "rng = np.random.default_rng(0)\n", + "n_a = 100\n", + "C_cost = rng.uniform(0, 1, (n_a, n_a))\n", + "DI_orig = rng.uniform(0, 1, (n_a, n_a))\n", + "DE_orig = rng.uniform(0, 1, (n_a, n_a))\n", + "a_np = np.ones(n_a) / n_a\n", + "b_np = np.ones(n_a) / n_a\n", + "t_I, t_E = 0.5, 0.5\n", + "\n", + "# Convert to homogeneous D . P (>=, ==) 0 form, scaled by n so entries stay O(1):\n", + "# DI . P <= t_I becomes (t_I 1 1^T - DI) / n . P >= 0\n", + "# DE . P = t_E becomes (DE - t_E 1 1^T) / n . P = 0\n", + "D_ineq = (t_I * np.ones((n_a, n_a)) - DI_orig) / n_a\n", + "D_eq = (DE_orig - t_E * np.ones((n_a, n_a))) / n_a\n", + "\n", + "C_j = jnp.array(C_cost)\n", + "a_j = jnp.array(a_np)\n", + "b_j = jnp.array(b_np)\n", + "Di = jnp.array(D_ineq[None, ...])\n", + "De = jnp.array(D_eq[None, ...])\n", + "\n", + "# Pick eps small enough that the entropic solution is close to the LP optimum.\n", + "# In the paper's notation (gamma = 1/eps), this is gamma = 400.\n", + "eps_run = 1.0 / 400.0\n", + "\n", + "# JIT warm-up (compile cost is paid once)\n", + "_ = constrained_sinkhorn(\n", + " C_j, a_j, b_j, Di, De, eps=eps_run, n_iters=3, n_newton=5\n", + ")\n", + "\n", + "# Run with full tracking, including dual variables for the next plot.\n", + "result = constrained_sinkhorn(\n", + " C_j,\n", + " a_j,\n", + " b_j,\n", + " Di,\n", + " De,\n", + " eps=eps_run,\n", + " n_iters=200,\n", + " n_newton=10,\n", + " track=True,\n", + " track_duals=True,\n", + ")\n", + "\n", + "print(f\"final transport cost = {float(result.reg_ot_cost):.4f}\")\n", + "print(\n", + " f\"D_I . P = {float(jnp.sum(result.matrix * jnp.array(DI_orig))):.4f} (constraint: <= {t_I})\"\n", + ")\n", + "print(\n", + " f\"D_E . P = {float(jnp.sum(result.matrix * jnp.array(DE_orig))):.4f} (constraint: == {t_E})\"\n", + ")\n", + "print(\n", + " f\"row marginal error: {float(jnp.max(jnp.abs(result.matrix.sum(1) - a_j))):.2e}\"\n", + ")\n", + "print(\n", + " f\"col marginal error: {float(jnp.max(jnp.abs(result.matrix.sum(0) - b_j))):.2e}\"\n", + ")\n", + "print(f\"alpha (constraint duals): {np.array(result.alphas)}\")" + ] + }, + { + "cell_type": "markdown", + "id": "457388a0", + "metadata": {}, + "source": [ + "The output confirms that the algorithm reaches a transport plan that respects all constraints to high precision: marginals match $a$ and $b$ within $10^{-5}$, the equality constraint is satisfied to machine precision, and the inequality is active near the threshold. The duals $\\alpha$ are non-zero because both constraints are *binding*, their values measure how strongly each constraint pulls the LP solution away from the unconstrained optimum." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "21c79d11", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-02T06:39:19.048684Z", + "iopub.status.busy": "2026-05-02T06:39:19.048033Z", + "iopub.status.idle": "2026-05-02T06:39:19.778999Z", + "shell.execute_reply": "2026-05-02T06:39:19.778002Z" + } + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAABYYAAAFyCAYAAABBWg5oAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAQ6wAAEOsBUJTofAABAABJREFUeJzs3Xd4k+XXwPFv0jZddC+6y55llSUgUxkC4kBAmSrgq0zRCqIMAUUp4gL1p4gsFw5AVBSEspG992ihhdJdutu0ed4/QkNLW+hOx/lcV64m97NO2jR3cp77ObdKURQFIYQQQgghhBBCCCGEEDWG2tgBCCGEEEIIIYQQQgghhKhYkhgWQgghhBBCCCGEEEKIGkYSw0IIIYQQQgghhBBCCFHDSGJYCCGEEEIIIYQQQgghahhJDAshhBBCCCGEEEIIIUQNI4lhIYQQQgghhBBCCCGEqGEkMSyEEEIIIYQQQgghhBA1jCSGhRBCCCGEEEIIIYQQooaRxLAQQgghhBBCCCGEEELUMJIYFkIIIYQQQgghhBBCiBpGEsNCCCGEEKJMzJ07F5VKxY4dOyr82N27d0elUlX4cWuClStXolKpWLlypbFDKTI/Pz/8/PxKvR+VSkX37t1LvZ/72bFjByqVirlz55brcYQQJVNW7yei+pB+sXup93M/0i9WLEkMC1GOVCpVsW5VqWOpDEJDQyukYypPfn5+ksgQQhj6AbVazZUrVwpd79FHHzWs++WXX1ZghKKyqMp9X1WOvTRyvuCOGTPG2KEIIe7I6UtF1VeV+5aqHHtpSL9YuZgaOwAhqrM5c+bka1u5ciXXrl1j9OjR+c7YtWrVqmICE0IIUemYmpqSlZXF8uXLWbhwYb7lISEhbNu2zbBeZTRx4kSGDRuGj4+PsUMRZejJJ5+kY8eOuLu7GzuUItu2bZuxQyiy9u3bc+7cOZydnY0dihCiAFXp/URUDOkXy5f0ixVLEsNClKOCLn3YsWMH165dY8yYMTXuzKAQQojCOTk5UadOHVauXMn8+fMxNc37MW358uUoisLAgQNZv369kaK8P2dnZ/kQXw3Z2dlhZ2dn7DCKpV69esYOocisrKxo3LixscMQQhSiKr2fiIoh/WL5kn6xYkkpCSEqiZzaiFevXuWTTz6hefPmWFhY8MQTTwBw+/ZtgoKC6NmzJ15eXmg0GlxcXHj88cf577//CtynSqXCz8+P1NRUAgMD8fHxwdzcnPr16/PBBx+gKEq+bTZt2sQjjzyCh4cH5ubmuLm50aFDB9577708640ZM8ZQR/Lbb7+lVatWWFpa4ubmxtixY4mMjCwwpqtXr/L8888bnoObmxtDhgzh5MmT+dbNqd00d+5c/vvvP/r164eDgwMqlYqPP/6YOnXqALBz5848JTmKWosoPj6eWbNm4e/vj7W1Nba2tjRv3pzXX3+d+Pj4POtGRkYyefJk6tati7m5OU5OTgwYMIBdu3bl26+iKKxevZrOnTvj6uqKubk57u7udOvWja+++gq4e9nQtWvXgLxlR+SEgRA117hx47h16xabNm3K056VlcW3335L+/btadGiRYHbHjlyhClTptCyZUscHR2xsLCgQYMGvPbaa/ne0+D+77EJCQkAJCcnExgYiK+vL+bm5tSrV4/Zs2eTmZlZ4PtVYTWGS9IfrVy5kqeffpq6detiaWmJra0tnTt3Zu3atUX/hT7AjRs3mDp1Kg0bNsTS0hIHBwcCAgKYPXs2Wq02z7ol7b+OHz9O//79sbe3x8rKim7durFv37582yQlJbFgwQL8/f2xs7PD2toaHx8fBg0aZPh9rly58oF9X+7LUm/evMkLL7yAu7s7JiYmbNiwASj5a+Xeklc55ZCysrJ47733aNCgAebm5nh7ezN9+nQyMzPz7KOk/XaTJk0wMzMr9LPFl19+iUqlYtasWXliK6iWYmZmJkFBQbRq1QorKytsbGzo2LEjK1asKPB1WJCbN28yb948OnfuTO3atdFoNHh4ePDcc89x7ty5POvOnTuXHj16ALBq1aoCS4jdr5Zieb/uhBAPVtD7SUn/33Q6HcuXL6dz587Y2dlhYWFB8+bNWbhwYZ73zBwbNmxgxIgRNGzYEGtra2rVqkVAQACffvopOp0u3/q5v6OtWbOGdu3aYW1tXeQrU6VflH5R+sWaR0YMC1HJTJ48mb1799K/f3/69++PjY0NAOfOneOtt96ia9eu9O/fHwcHB65fv87vv//O5s2b2bRpE3379s23P61WS+/evbl58yb9+vXD1NSUDRs2MGPGDNLT0/OUu/jqq6946aWXcHNzY8CAAbi6uhITE8PZs2f58ssvmTlzZr79L1myhH///ZehQ4fSr18/du3axTfffMOOHTs4cOAATk5OhnWPHj1Kr169SEhIoH///rRo0YIrV67w22+/sWnTJjZu3Ejv3r3zHWPfvn289957dO3albFjxxIREUFAQABTpkzhk08+wdfXN099oqIkVkNCQujRowfXrl2jdevWvPTSS6hUKi5dusSyZcsYMWIEDg4OAFy7do0uXboQHh5Ot27dGDp0KBEREaxbt47NmzfzzTff5Dn+W2+9xcKFC/Hz82Pw4MHY29tz69YtTpw4wZo1axg/fjz29vbMmTOHjz/+mNu3b+f5O8jkFkLUXEOHDmXq1Kl8/fXXPPnkk4b2P//8k4iICObNm0d4eHiB23799desX7+ebt268cgjj5Cdnc3Ro0dZsmQJmzdv5sCBA4Y+JbeC3mNNTEzIyMigV69eHDx4kObNmzN06FBSU1NZvnw5p06dKvZzK05/BPDyyy/TtGlTunbtiru7O7Gxsfz111+MHDmSCxcuMH/+/GLHkNvhw4fp27cvsbGxdOnShSeeeIL09HTOnz/PwoULmTZtGvb29kDJ+6/Dhw+zaNEiHnroIcaOHcv169f59ddf6dWrF8ePH6dRo0aA/oRi37592bdvH+3bt+eFF15Ao9Fw48YNdu/ezb///kv37t1p1apVkfu+2NhYHnroIezs7HjmmWfQ6XQ4OjoCJX+tFOa5555j9+7d9OvXD1tbW/766y8WLVpEVFQU3377LUCxYr/X6NGjefPNN1m7di2vvfZavuWrVq0yrHc/Wq2Wfv36sX37dho2bMjLL79MZmYmv/32Gy+++CJ79uxhxYoVD3y+u3bt4v3336dHjx48/fTT1KpVi0uXLvHLL7/w+++/s3fvXlq2bGl4bqGhoaxatYqWLVsaTvjn/E7up7xfd0KI0ivO/1tWVhZPPfUUmzZtomHDhjz77LNYWFiwc+dOZs6cybZt2/j777/zXDE0Y8YM1Go1HTp0wNPTk9u3b7N9+3amTJnCoUOHWLNmTYFxLV68mG3btvH444/Tq1evApPOBT0X6RelX5R+sQZShBAVqlu3bgqgBAcHF9ju4eGhhISE5NsuISFBiY6OztceFhamuLu7K40bN863DFAApV+/fkpqaqqhPTIyUrGzs1Ps7OyUzMxMQ3ubNm0UjUaj3Lp1K9++7j326NGjFUAxMzNTjh49mmfZxIkTFUAZP368oU2n0ylNmzZVAGXlypV51t+6dauiUqkUFxcXJSUlxdD+7bffGp7D//73v3wxhYSEKIDSrVu3fMsepFOnTgqgzJs3L9+y+Ph4JSkpyfC4b9++CqDMnTs3z3onT55ULC0tFXNzcyUsLMzQ7ujoqHh4eCjJycn59n3v79HX11eRt2IhBKC4ubkpiqIoL730kqJWq5Vr164Zlvfv31+pVauWkpSUpMyZM0cBlC+++CLPPkJDQ5WsrKx8+16+fLkCKO+//36e9ge9xy5YsEABlCeffDLPfmNjY5X69esX+P6bE9u9fVxx+yNFUZTLly/niykjI0Pp2bOnYmpqqoSHh+dZltOPFkVGRobi5+enAMqqVavyLY+IiFC0Wq2iKKXvv7799ts823z55ZcKoLz88suGtpMnTyqAMmjQoHyx6HQ6JSYmxvD4QX1fznJAGTlypOF55FbS18q9zyWnD2vTpo0SGxtraE9OTlbq1aunqNVqJSIiosixF+bGjRuKiYmJ4u/vn2/ZhQsXFEB5+OGH88Xm6+ubp+39999XAKV3795KRkaGoT0hIUFp1qyZAig///xznm0KijcyMlJJTEzMF8vx48cVa2trpW/fvnnag4ODFUAZPXp0gc8vZ/mcOXMMbRXxuhOiJsv5XymKgt5PSvL/Nn/+fAVQJkyYkOc9ODs7Wxk3bpwCKJ9++mmebQrqC7Ozs5VRo0YpgPLff//lWZbzHc3Kyko5duxYkZ6foki/KP2i9Is1mWQjhKhgD0oMf/zxx8Xe56RJkxQgTxJBUe5+4Ll06VK+bXI+TJw6dcrQ1qZNG8XKyipPJ1aYnA8dL7zwQr5lsbGxirW1tWJlZWX4or9nzx4FUNq1a1fg/p566ikFUL7//ntDW84beatWrQrcpqQd6eHDhxVA8ff3V7Kzs++7bnh4uAIonp6eeTrLHNOmTVMA5b333jO0OTo6Kn5+fkp6evoDY5HEsBBCUfImhnPeo3I+DIeFhSkmJibKuHHjFEVRCk0MF0an0ym2trZKjx498rQ/6D22fv36ikqlKrAPWb16dYkSw0Xtj+7n119/LfCLa3ESw7/88osCKI899tgD1y1N/9W5c+d862dmZiqmpqZKQECAoS3nC/CwYcMeGE9RvwBrNBolMjLygfvL7UGvlcK+AG/dujXfvmbPnq0AyqZNm4oc+/3knKQ9cuRInvY333xTAZRvvvkmX2z3fgHOOaFx+vTpfPv//fffDV+OcytuvAMHDlTMzc3znOgoyRfginjdCVGTlVViuKj/b9nZ2Yqzs7Pi6upaYGIyISFBUalUhf7P3+vIkSMKoLzzzjt52nO+o02dOrVI+8kh/WLBpF+UfrEmkBrDQlQy7du3L3TZ3r17GTJkCN7e3pibmxtq8Xz22WeAvibUvezs7Khfv36+dm9vb4A8NZNGjhxJamoqTZs2ZcqUKfz6669ERETcN95u3brla3N0dMTf35/U1FQuXLgA6C/7AOjZs2eB+3nkkUfyrJfb/X4nJZFTk7l3796o1fd/G8yJp0uXLmg0mnzLC4p75MiRhIaG0qRJE6ZPn84ff/xBXFxcWYUvhKjmAgICaN26NStWrCA7O5tvvvmG7Oxsxo0bd9/ttFotS5cupUuXLjg4OGBiYoJKpUKtVpOYmFhgHwEFv8cmJSVx+fJlateuXWAf8vDDDxf7eRWnPwK4fv06EyZMoFGjRlhZWRn6vKeffhoouM8rqpx+oF+/fg9ctzT9V9u2bfO1mZmZ4ebmluf5Nm3alDZt2vDjjz/y0EMP8f7777N7927S09Mf/GQK4efnh6ura4HLSvpaKUxBz7Owv2tJ5Vxim7ueo06nY+3atVhZWfHMM8/cd/uc17SbmxvNmjXLt7xXr15AwX/Hgvz5558MGDAAd3d3zMzMDK/PTZs2kZGRQUxMTNGeWCEq4nUnhCi9ov6/Xbx4kZiYGGxsbFiwYAFz587Nc/voo4+wtLTMV481NjaWGTNm0KJFC2rVqmV4rwkICAAK7wuL+/1J+kXpF+8l/WLNITWGhahkateuXWD7+vXrGTx4MBYWFjz66KPUq1cPa2tr1Go1O3bsYOfOnWRkZOTbLqcO1L1yaldlZ2cb2qZOnYqrqyuff/45y5Yt49NPPwWgY8eOLFy4sMBaR25ubgXuP6f99u3beX4W9vzc3d0BDBMe5VbYNiWVcwxPT88HrluSuD/88EPq16/PihUrCAoKYtGiRajVanr16kVQUJChvpIQQhRm3LhxvPLKK/z555+sWLGCli1b0q5du/tuM3ToUNavX0/dunV54oknqF27Nubm5gB8/PHHBfYRUPD7W857X2Hv8SV5Xy5Of3T16lXat29PfHw8Dz/8MH369MHOzg4TExNDXbrCnk9RlHc/kON+zzn38zUxMWHbtm28++67/Pzzz7z55puAflbuoUOHsmjRIpydnR8Ya273+xuV9LVSmIKeZ0F/19J44okncHBw4Pvvv2fx4sVoNBq2b99OWFgYo0aNemDtxwf9Ha2srLCzsyvw73ivTz/9lClTpuDg4MCjjz6Kj4+P4eTFhg0bOHHiRKlen0WJtyxed0KI0ivq/1tsbCwAV65c4Z133inSvhMSEmjXrh0hISG0b9+eUaNG4ejoiKmpKQkJCXzyySfF6tsfdCyQflH6xbukX6w5JDEsRCWjUqkKbJ81axYajYbDhw/TpEmTPMteeukldu7cWSbHf+6553juuedITExk//79bNq0ia+//pp+/fpx4sQJGjZsmGf9wmZCzWm3s7PL8/PWrVsFrp8zMjlnvdwK+52UVE7HUJQzvyWJ28TEhIkTJzJx4kRiY2PZs2cPv/32G2vWrOHRRx/l3LlzeSblE0KIew0fPpzXX3+dCRMmEB4ezvTp0++7/uHDh1m/fj29evVi8+bNmJmZGZbpdDoWLVpU6LYFvcfmvKcV9h5f2HtiWVmyZAmxsbGsWLGC559/Ps+yH374wTCpSkmVdz9Q0piCgoIICgoiJCTEMJnrt99+y7Vr19i2bVux9ldY31ma14oxmZubM2zYML744gv++OMPnnrqKcMoqdwT9hTmQX/H1NRUbt++/cD+OSsrizlz5lC7dm2OHj1q+CKaY//+/Q9+MkVQUa87IUTFyPlfHThwIL///nuRtlm+fDkhISHMnj07XzJ5//79fPLJJ4VuW9zvT9IvSr94L+kXaw4pJSFEFXH58mWaNm2aLyms0+nYs2dPmR/P1taWPn36sHTpUl577TXS09P5+++/861XUEI6Pj6eU6dOYWVlZZjhs02bNgAEBwcXeLycjj3nsqiiMDExAYp/1rVjx44AbNmyBZ1Od991W7duDejLeBQ0m++D4nZycmLQoEGsWrWKYcOGER0dzd69e0v9HIQQ1ZutrS1Dhw4lPDwcS0tLhg8fft/1L1++DMCgQYPyfKEBOHjwIGlpacU6vo2NDfXq1ePWrVuGfee2e/fuYu2vuHKOOXjw4HzLyuJEaE4/sHnz5geuWx7914PUqVOH0aNHs337dry9vdm+fbthpExp+42yfq0UR2ljz/miu2rVKpKSkli/fj1+fn4PnL0d9K/p+vXrExkZydmzZ/Mt3759O/Dgv2NMTAwJCQl06tQp35ff5OTkAi9hLcnzNsbrTghRfho3boy9vT0HDhwo8DtFQcq7L8xN+kXpF+8l/WLNIYlhIaoIPz8/Ll26xM2bNw1tiqLwzjvvFPhGXhKbN29Gq9Xma885K2dpaZlv2Zo1azh27FiettmzZ5OSksLw4cMNnWunTp1o0qQJBw8eZO3atXnW3759O7/99hvOzs4MGjSoyPE6ODigUqkICwsr8jag7yw6derEqVOnePfdd/Mtv337NsnJyQB4eXnRp08fwsPD850tPnPmDF988QXm5uaMGDECgIyMDP799998CWdFUYiKigLy/h5zzsBev369WM9BCFH9zZs3j/Xr1/PPP/88cPSDn58fADt27MjTHhUVxYQJE0p0/DFjxqAoCm+88Uae97T4+HjmzZtXon0WVWHP559//mH58uWl3v/AgQPx8/Pjr7/+Ys2aNfmWR0ZGkpWVBZRP/3WvkJAQTp8+na89KSmJlJQUzMzMDJeglrTvy1Eer5WiKm3s7du3p1mzZvz1118sW7aM1NRURo8eXeSRcS+++CIAr732Wp7PO4mJicycOROAsWPH3ncfrq6uWFlZceTIEcNnBdDXp5wyZUqBNRRL0tdXxOtOCFFxTE1NmTJliuG9NjU1Nd86MTExHD9+3PC4sPfrY8eOsXDhwjKNT/pF6RelX6y5pJSEEFXEq6++yv/93//RunVrnn76aczMzNi7dy9nz55l4MCBbNq0qdTHePbZZ9FoNDz88MP4+fmhUqk4ePAgu3fvpl69egwZMiTfNn379qVz584MHTqU2rVrs2vXLvbt20fdunV57733DOupVCpWrVrFI488wqhRo1i3bh3+/v5cuXKFX3/9FY1Gw+rVq7GysipyvLVq1eKhhx5i3759DBw4kICAAExNTenatStdu3a977Zr166le/fuzJ49m/Xr19OjRw9UKhVXrlzhn3/+Yd++fbRq1QqAL7/8ks6dOzNr1iy2b99Ox44diYiIYN26daSlpfH1118bJhNIS0sz1FXq2LEjvr6+aLVaduzYwfHjx+nYsSM9evQwxPHoo49y6NAhnnrqKR577DEsLS3x9fVl5MiRRf49CCGqJy8vL7y8vIq0brt27ejcuTO//fYbnTp1okuXLkRGRrJ582YaNWqEh4dHsY//+uuvs2HDBtavX0/Lli3p168faWlp/Prrr3To0IHLly8/cALPknrllVf49ttveeaZZ3j66afx9PTk9OnT/P333wwZMoSffvqpVPvXaDT8/PPP9OnTh1GjRvH111/z0EMPkZmZyYULF/j333+JiorC3t6+XPqve504cYInn3yS1q1b4+/vj4eHB/Hx8YbJS1977TWsra2B0vV9UD6vlaIqbewAo0eP5o033mD27NmoVCpGjx5d5ONPmzaNv//+m7///ht/f38GDBiAVqvl119/5caNG4waNeqBk/Wo1WomT57M+++/j7+/P4MGDSIzM5Pg4GDi4uLo0aNHvtFMjRo1wtvbm927dzN8+HAaNmyIiYkJjz/+OC1atCjwOBXxuhNC3P+S+8WLFxe7ju39vP3225w6dYrly5fz559/0qtXL7y8vIiOjubKlSvs2bOHCRMm8PHHHwMwatQogoKCmDp1KsHBwTRo0IBLly4ZygaUti/MTfpF6RelX6zBFCFEherWrZsCKMHBwQW2h4SEFLrtt99+q7Rs2VKxsrJSnJyclCeeeEI5efKkMmfOnAL3CSi+vr4F7qugbb744gvlySefVOrWratYWVkpdnZ2ir+/vzJnzhwlJiYmz/ajR482bL9ixQqlRYsWioWFheLi4qK88MILyq1btwo87qVLl5TRo0crHh4eipmZmeLi4qIMHjxYOXbsWIHPF1DmzJlT6O/kypUryhNPPKE4OTkparX6gevnFhMTo8yYMUNp1KiRYm5urtja2irNmzdXAgMDlfj4+DzrRkREKBMnTlR8fX0VMzMzxcHBQenXr1++33lmZqayaNEipV+/foqPj49iYWGhODo6KgEBAcqHH36oJCcn51k/JSVFmThxouLt7a2YmpoqgNKtW7cixS+EqD4Axc3NrUjr5rx/f/HFF3naY2NjlZdfflnx9fVVzM3Nlbp16ypvvvmmkpKSovj6+ubrD4ryHnv79m1l2rRpipeXl6LRaJQ6deoos2bNUsLDwxVAGTRoUIGxlbY/UhRF2bt3r9KjRw/F3t5eqVWrltK5c2dl/fr1SnBwcIFx5/SjxXH9+nVlwoQJSp06dRSNRqM4ODgoAQEBypw5c5TMzMw865Zl/3Xv3yMsLEyZOXOm0qlTJ6V27dqKRqNR3N3dlZ49eyrr1q3Lt/39+r6QkJAH9iUlfa18++23+Z5HYb/zwrYpTb+tKPr+uCj9ZUHPQ1EUJT09XXn//fcVf39/xcLCQrGyslLat2+vfP3114pOp8u3fkHH0Wq1yocffqg0adJEsbCwUNzc3JQRI0YooaGhhs9H936eO3LkiPLII48odnZ2ikqlyvO7Kew1rSjl+7oToiYDHnjL+T8uSR9a2P+bTqdTvv/+e+XRRx9VHB0dFTMzM6V27dpKhw4dlFmzZikXL17Ms/6ZM2eUgQMHKi4uLoqVlZXSpk0b5euvvza8148ePTrP+rm/o5WE9IvSL0q/WPOoFEVRyi7NLISoKcaMGcOqVasIDg4uUg0jIYQQ1ceWLVvo06cPM2bMKPPLWYUQQgghhBAVQ2oMCyGEEEKIAuWua58jJiaG6dOnA/D0009XdEhCCCGEEEKIMiI1hoUQQgghRIGGDh1KcnIybdu2xcnJievXr7N582YSEhKYMGECbdu2NXaIQgghhBBCiBKSxLAQQgghhCjQiBEjWL16NevXr+f27dtYWVnh7+/P2LFj7zthjxBCCCGEEKLykxrDQgghhBBCCCGEEEIIUcNIjWEhhBBCCCGEEEIIIYSoYSQxLIQQQgghhBBCCCGEEDWMJIargKysLMLDw8nKyjJ2KEIIIUSVI/2oEEIIUTrSlwohRPUkieEq4NatW3h7e3Pr1i1jhyKEEEJUOdKPCiGEEKUjfakQQlRPkhgWQgghhBBCCCGEEEKIGkYSw0IIIYSoloKCgnB1daVly5bGDkUIIYQQQgghKh1JDAshhBCiWgoMDCQqKooTJ04YOxQhhBBCCCGEqHQkMSyEEEIIIYQQQgghhBA1jCSGhRBCCCGEEEIIIYQQooYxNXYAQgghhBBCCFHVKYpCTEwM6enpZGdnGzucSsfExAQLCwucnZ1RqVTGDkcIIYQQSGJYCCGEEEIIIUpFURRu3LhBUlISGo0GExMTY4dU6WRmZpKcnExGRgaenp6SHBZCCCEqAUkMCyGEEKJaCgoKIigoSEbuCSHKXUxMDElJSbi6uuLk5GTscCqt2NhYoqKiiImJwcXFxdjhCCGEEDWe1BgWQgghRLUUGBhIVFQUJ06cMHYoQohqLj09HY1GI0nhB3ByckKj0ZCenm7sUIQQQgiBJIaFEEIIUYllZGTwwgsv4OPjg62tLR07dmT//v3GDksIIfLIzs6W8hFFZGJiIldyCCGEEJWElJIQQohKSFEUdEqunygoCuiUXD8BRadfpsu1TLmzTJdrH/r2u9vp11XuHAsUw3H1+zPcv7Mg5/jkWi+nPf8+lDz7o8B1uHv8AvbH/WIq4DgUGnfxYrK3MqNzfecC/iLCWLKysvDz82PPnj14eXmxZs0aBg4cyPXr17GysqqQGBRFYcvZSDKzdAxs6VEhxxRCCCGEMJbw+FQysnTUc6ll7FCEEOVMEsNCiAqlKAoZWToys3VkZunQZuvIylb0P3WK4XGWToc2W9Ev091py9ah1el/5m7P2TYr+842Ov3j7GyFbEVBp9P/zNZhuH+3TUF352e2DsN9naI/du51H7ifO/f1Cdi8CdqcpCxKroTtnYRuTlvudYVxtPSyY+PELsYOQ+RibW3N7NmzDY9Hjx7NtGnTuHTpEi1btqyQGJZsvchn2y/jYmPOI03csNTIqEAhhBBCVE/p2mz6fbKb5IwsAvs04uVu9WSySCGqMUkMC1HDKYpCulZHamYWqZnZpGmzSc3MJjUzi7RM/f20zGxScpZnZpOuzdYnd7N0ZGTlvq9/fPd+rnW0OjLuJIOFqAg5n19VhseqXPfB8Eh1dx2NqVRYKkxycjJBQUEcOnSIQ4cOERMTw8KFC5kxY0a+dbVaLfPnz2flypVERUXRsGFDZsyYwXPPPVfqOM6fP09qaip169Yt9b6K6qk2Xny+4wrRSRl8d+AaYx+uuGMLIYQQQlSk+NRMktKzAFj09wUu3Erig6dbYGEmJ8aFqI4kMSxEFaYoCskZWcSnaIlPzSQxXUtSehZJd34m5rp/92cWielaUjLuJoKrwghVlQrM1GpMTVSYqlWYmeTcV2NmosLURJ2n3UytxkStwkStQq1WYaJCf1+Vu02Vq+2e5Xd+5lmuurud+s6yu20Y9qu+k4BUq1T6BKRKhVqlfw76trzL1SoA/Tp32/TZytz7UqvI26bW/7y7/zs/yTluznaFHycnOZp7EEBOzPq1cydYVXmTrbnaDdsZ9nGfJGyufRR2nAJjKspxZDRDuYmJiWHevHl4eXnRunVrtm7dWui648ePZ/Xq1UyYMAF/f382bNjA8OHDycrKYtSoUSWOITU1lZEjR/L2229jY2NT4v0UVx1na55s7ckvR8L5cucVnuvgg5VGPkIJIYQQovrR3fPdcOPxm4TGpPDVqLa42VoYJyghRLmRbzVCVDLp2myikzKITEwnMjGDqKR04lMyiUvNJD5FS1xKJvGpmYaf2uzyzepamplgpTHBUpPz0xQLUzUWZiZoTNWYm6oxN81130yNuYkaczMTzE3VBa9z577GVJ/UNbsnqVtQstdELQk/IYzJ3d2dGzdu4OHhQWhoKHXq1ClwvWPHjrFy5UrmzZvHrFmzABg7diw9e/YkMDCQYcOGodFoAOjduze7du0qcD/PP/88X3zxheGxVqtl8ODBNG3alJkzZ5bxs3uwyT0bsP7YDWKSM/lk2yXe7NekwmMQQgghhChvulyZ4cdbevD7iZucCL/N40v38NXItrT0tjdecEKIMieJYSEqUFa2jojb6YTFpxIel0Z4fCo3EtKJSko3JIJvp2lLdQyNiRobC9M7NzPDfVsLszyPbSxMsdSYYpUn8WuaJwlsYWqCWhKyQgjA3NwcD48HT7y2bt061Go1EyZMMLSpVComTpzI4MGDCQ4Opk+fPgBs2bKlSMfW6XSMGjUKMzMzvvnmG6OMDPdxsmJYO2++O3Cd/+28io25KWMfriuXVQohqpXMzEzDyTshRM2U+2rSyb3q07uZG6//fILIxAyG/G8/iwa3YFArT+MFKIQoU5IYFqKMKYpCZGIGl6OSuRKdbPh5PS6ViNvpZN97bc59aEzUuNiY41xLg4O1BkcrDfZWGhytzQyPHaw1OFprsLcyw9bCTJIUQgijOnLkCPXq1cPR0TFPe4cOHQA4evSoITFcVC+99BIRERH8/fffmJo++KNLYmIiiYmJhscRERHFOl5hZg1oSmhsCnsvx7J4y0WWBV+hfR1HOtR1pEMdR/w97aVOtRDCIDNLx42EtAo/rqe9ZZHfi7p3707Lli1RFIXvvvuOjh078vrrrxMYGMipU6dwdnZm3LhxzJ49G7VazWeffcaKFSs4duwYAGvXrmXkyJF89913hjryrVq1Yty4cXlOEAohqg5drsywSqViQAsP/JysGbf6MBG305ny43HORSTxeu+GmJrI5x4hqjpJDAtRCunabM5FJHL6xm1O30jkfGQSV6OSScrIeuC2dpZmeDlY4mlvSW07C9xsLXC1McfN9u59eyszqZkqhKhSbt68ibu7e772nNHGN2/eLNb+rl27xvLly7GwsMDZ2dnQ/r///Y/hw4cXuM2SJUt45513inWcorAwM+HrUW2Z+P0xtp+PIk2bzc6L0ey8GA2AvZUZL3auw5jOfthYmJX58YUQVcuNhDR6LN5R4ccNfr07dZyti7z+ihUrmDRpEvv37+fWrVv06dOH8ePHs3btWk6dOsW4ceOwt7dn6tSpdOvWjalTp5KQkIC9vT27du3CycmJnTt38txzzxEfH8+pU6fo1q1bOT5DIUR5yp0YVt/5Ltrc046NEzvz0pojHLuewJc7r3AiLIFPn22Ni425sUIVQpQBSQwLUUSKohAen8aBkDgOhcRxIjyBS1HJ9x0B7GpjTn3XWtR3rYWvkzVeDpZ4O1jh5WiJrSQNhBDVUFpaGubm+b8gqNVqzMzMSEsr3ug5X19flGLOkDlt2jTGjh1reBwREUH79u2LtY/CWGlMWTGmHTcS0th5IZoDIbEcuBrHrcR0ElK1fLj1Ir8eDWft2A54OViVyTGFEKI8NW7cmPfeew+AlStXUrduXT7++GNUKhWNGzcmJCSEoKAgpk6dir+/vyEh/Pjjj7Nz505ee+01Vq1aBcCuXbtwcHCgWbNmxnxKQohSyP31NndVQVcbC34Y15F3Np3hh4Nh7L8ay4DPdvP58DYE+Drm35EQokqQxLAQ9xGZmE7w+Sj2XYnlYIj+i39BHKzMaO5pR1N3W+rdSQTXc6mFnaUkf4UQNYuFhQUZGRn52nU6HVqtFguL8p/N2tbWFltbW4KCgggKCiI7O7vMj+Fpb8lzHXx4roMPiqIQEpPCt3tD+eHgdUJjU3nmy/38OL4jvk5FH7UnhKhePO0tCX69u1GOWxxt27Y13D937hydOnXKc8Va586dmT59OomJidja2vLwww+zc+dOOnToQFhYGBMnTmT+/PlERkayc+dOunbtKle8CVGFKQWMGM5hYWbCwqda0NrHgVkbThOZmMHQ//3HW/2bMKaTn/zvC1EFSWJYiFx0OoWTN26z/Vwk2y9EcfpGYr51zE3VtPaxJ8DXAX9Pe5p72uJpbymdoBBCoC8Zce3atXztOSUkijKBXVWjUqmo61KL+U80p2cTV/5vzREibqcz+Ydj/PpyJ6m/J0QNpTFVF6ukg7FYW+eN8d7PtPdetdGtWze+++47OnTowEMPPYSNjQ0dO3Zk165d7Ny5k1GjRpV7zEKI8pN7xHBhX3GHtPWmmYctL689yvW4VN7ZdJaj1xN4/yl/rM0lzSREVSL/saLGUxSFMzcT2Xj8BptOROQbFWxuqqZjXSfD5ELNPe0wN5UJ3oQQoiBt2rRh+/btxMXF5ZmA7sCBA4blFSUwMJDAwEDCw8Px9vaukGP2aOTK0ufaMG71YU6E32b5nhD+r1u9Cjm2EEKUVpMmTdi4cSOKohgSxPv27cPT0xNbW1tAnxgODAzk999/N9QS7tatGxs3buT48eN88803RotfCFF6BdUYLkgzDzs2TezCaz8f599zUWw6cZNzEYl8MbwNDdxsKiJUIUQZkCEsosZKSM1k+e6rPLJkJwM+28PXu0MMSWEPOwtGdPRhxZi2HJ/dm1UvtOeV7vUJ8HWUpLAQQtzHkCFD0Ol0fP7554Y2RVFYunQpLi4u9OjRo8JiCQoKwtXVlZYtW1bYMQEeberGU609AViy9SJhcakVenwhhCipV155hatXrzJ16lQuXLjAL7/8wnvvvcfrr79uWKdVq1bUqlWLH3/8ke7duwPQvXt3fvzxR2xtbWnRooWRoq/Zli5dSuvWrTE1NWXu3LlFXibEvYqaGAawszLjq5FtCezTCLUKLkclM3DpHn48eL3Yc0QIIYxDRgyLGuf0jdus2BvCHycjyMzSGdo97S0Z2NKDgS3daepuK6UhhBDiHkuXLiUhIYGEhAQAgoODycrKAmDSpEnY2dkREBDAyJEjmTNnDtHR0fj7+7NhwwZ27NjBihUrCpyYrrwYY8RwjtkDm7LzYjSxKZl8su0Si5+p2OS0EEKUhJeXF3/99ReBgYF8+eWXODs7M3XqVCZPnmxYR61W06VLF7Zt20aHDh0A6NixI2ZmZnTp0gW1WsYeGYOnpyfz5s1j9erVxVomxL2UQiafK4xarWJCj/q09rZnyk/HiU7KYMZvp9h9OYaFT/nLpOtCVHKSGBY1gqIoHAyJ4/MdV9h5MdrQbm6qZkALD4a286atrwPqovR8QghRQy1evDhP/eAtW7awZcsWAEaMGIGdnR0Ay5cvx9fXl5UrV/Lll1/SsGFD1qxZw4gRI4wStzHYW2l4uXs9Fvx5jt+OhvNK93rUdall7LCEECKPHTt25Gvr0aMHhw8fvu92f/zxR57H5ubmpKWllWVoopiefPJJADZu3FisZULcK/eI4eIMlupU35nNUx5m2roT7LoYzZ8nIzgZnsCnw1rT2sehPEIVQpQBOZ0rqr1j1+MZ+tV/DP3qP0NS2MfRirf7N+HAzF58OKQl7es4SlJYCCEeIDQ0FEVRCrz5+fkZ1tNoNMyfP5+wsDAyMjI4deqUUZLCxiolkWNER19cbczRKfDJtktGiUEIIUTFSU5OZs6cOTz22GO4uLigUql4//33C1xXq9Uye/ZsfHx8sLCwoEWLFnz//fcVHLEQ+emKOWI4N+da5qwc046ZjzXGVK0iLC6NZ77cz/92XkGnk9ISQlRGkhgW1VZoTAoTvjvKk5/v42BIHACNa9vwybBWbH+tG2Mfrou9lcbIUQohhCgvgYGBREVFceLECaMc38LMhAk96gOw6cRNrsdKrWEhhKjOYmJimDdvHqdOnaJ169b3XXf8+PG8++67PPHEE3z22Wd4e3szfPhwKfcgjK44NYYLolarGN+1Hr+83AkfRyuydAoLN59n9LcHibxnonchhPFJYlhUOxlZ2Xy09SKPfrSTP09FAFDXxZovRwSwecrDDGrliamJvPSFEKK6M/aIYYCh7bxxstagU2D5nqtGi0MIIUT5c3d358aNG4SFhfHVV18Vut6xY8dYuXIlc+fO5dNPP2XcuHH88ccfdO/encDAQDIzMw3r9u7dGwsLiwJvL7/8ckU8LVHDKKVMDOdo5W3PH5O7MKCFOwC7L8XQ5+Nd/H06otQxCiHKTqXNjpXm0hpFUfjkk09o2LAh5ubmNGzYkE8//TTfrJgRERHMmDGDXr16YWdnh0ql4scffyxwnyqVqtBbgwYN8q2/ZcsWOnbsiKWlJbVr12by5MkkJycX/xchiuVgSBz9PtnNJ9suoc1WcK5lzrtPNmfL1K70bV5bJpQTQogaxNgjhkE/anh0Jz8A1h0OIy4l8/4bCCGEqLLMzc3x8PB44Hrr1q1DrVYzYcIEQ5tKpWLixIlERUURHBxsaN+yZQvp6ekF3r744otyeR6iZstd8UFVyoyRrYUZnz3bmkWDW2CtMSEhVcv/rT3K6z+fICldW7qdCyHKRKVNDJfm0pp58+YxdepUOnbsyLJly+jQoQNTpkxh/vz5eda7cOECH3zwAaGhobRq1eq++1yzZk2+29tvvw1Anz598qy7fft2HnvsMUxMTPj444958cUX+frrrxk0aFC+5LQoGxlZ2byz6QxD/refq9EpqFQw6iFfgl/vxvAOvjJCWAghhNGM7OiLpZkJ6Vodo1Yc4Oj1eGOHJIQQwoiOHDlCvXr1cHR0zNPeoUMHAI4ePVrsfWZlZZGenk52dnae+w9aJsS9ctcCLs2I4RwqlYohbb3ZPKUrAb76Seh+ORJOv092cyg0rtT7F0KUjqmxAyhIzqU18+bNY9asWQCMHTuWnj17EhgYyLBhw9BoCq4Ne+vWLRYuXMgLL7zAN998Y9jWxMSE9957j/Hjx1O7dm0AAgICiImJwcnJiR07dtCjR49CYypo0pzp06cXuOzVV1+lYcOGBAcHG+Js0KABzz//PBs3buSJJ54o3i9E3FdITAqTfjjK6RuJADR0q8XCp1oYOh0hhBA1U1BQEEFBQUb/8utgrWFSr/os+vsCp28k8tTn+xjS1ovpfRvjVMvcqLEJIYSoeDdv3sTd3T1fe85o45s3bxZ7nwsWLOCdd94xPH733Xf59ttvGTNmzH2XFSYxMZHExETD44gIufy/pijN5HP34+NkxU/jO/Llzit8/O8lwuPTGPq//bzcvR5TejVEYyqDuYQwhkr5n1ecS2vutXHjRjIyMpg0aVKe9kmTJpGRkcHGjRsNbTY2Njg5OZUoRkVR+OGHH6hfvz4dO3Y0tF+4cIGTJ08yfvz4PMnrESNGYG9vz08//VSi44mC/X7iJgM+3W1ICo97uA5/THpYksJCCCEqRSmJHK90r8+aF9tT18UagHWHw+mxeAdr9oeSLbN0CyFEjZKWloa5ef4Tg2q1GjMzM9LS0oq9z7lz56IoSp5bTuL3fssKs2TJEry9vQ239u3bFzsmUTWVVY3hgpiaqJnYswG/vdKJui7W6BRYFnyFp77Yy+WopDI9lhCiaCplYrg0l9YcOXIEc3NzWrRokae9devWaDSaEl2WU5AdO3YQFhbG8OHD8x0fyNdxmpqaEhAQUKTjJyYmEh4ebrjJ2dn8FEVhyZYLTP7hGCmZ2Thaa/h2TDve6t9UzjQKIWq0W7ducfr0ac6cOUNkZKSxwxG5PNzAhb+ndGVGv8ZYaUxITM9i1sYzDPtqP2FxqcYOTwghRAWxsLAgIyMjX7tOp0Or1WJhYWGEqPKaNm0aYWFhhtvBgwfL/ZiKojB+9WHav/svy4Ivk5Yp5S6MIU+N4XKaoqeFlz1/TnqYkR19ATh9I5H+n+7hmz0heUpZCCHKX6XMoJXm0pqbN2/i5uaGWp33qanVatzc3Ep0WU5BvvvuOyB/GYmc/RcWf1GOL2dn7y9dm82UH4/z6fbLALT2seevyQ/To7GrkSMTQoiKl5KSwooVKxg4cCBOTk54enrSsmVLWrRogYeHB05OTgwYMIBvvvlGJkGtBDSmav6vWz22vdaN/ndm6T4UGk/fj3ex7nCYzEUghKhS/Pz8WLp0qbHDqHI8PDwKHPyT812xKBPYlTdbW1u8vLwMt4K+35a1sxGJbDkbSVRSBkH/XGD0ioPSLxqBrhxHDOdmqTFh/hPN+fb5drjYmJORpWP+H2cZ9tV/XItNKbfjCiHyqpSJ4dJcWlPYtqA/M1uSy3LulZGRwa+//krHjh2pX79+vuMDBcZQ1OMb4+xsVXE7Vcvw5Qf4/YT+Q9PAlh78MK4jte2Mf1ZdCCEqUlxcHIGBgdSuXZvx48dz/fp1nnrqKRYsWMDnn3/OsmXLWLBgAU8++SRhYWG89NJL1K5dm8DAQOLiasZEH0FBQbi6utKyZUtjh5KPu50ly55rw+fD22BvZUZKZjZv/HKSCd8flVm6hRCimmvTpg1XrlzJ1x8fOHDAsLwm+ueM/konzZ2Jww+GxnEsLMGIEdVMFZUYztGjkSv/TO3KwJb6EyIHQ+Po+/FuVu8PldHDQlSASjn5XGkurSlsW4D09PQyuSxn06ZNJCQkFDghXc7+C4qhqMe3tbXF1ta21HFWN/EpmYz45gBnburrCU/uWZ+pjzREXZYV8YUQooqoU6cO3t7evPfeezzzzDOGiVULc+vWLX7++We++uorvv76axISEiomUCMKDAwkMDCQ8PBwvL29jR1OgR7zd6etrwOBv5xk58Vo/jp1i/O3kvhyRAAN3WyMHZ4QQohyMGTIEBYtWsTnn3/O22+/DejLKCxduhQXF5f7Topenf1z+hYAQ9p5sedSDKGxqfxyJJw2PjJ/TEVSymnyuftxtNbw2bOt6dusNrM2niYuJZPZG8+w+dQtFg1ugbejVcUEIkQNVClHDJfm0hoPDw8iIyPR6XR52nU6HZGRkWVyWc7atWsxMzNj6NChBR4fCp619ebNm5XisqCqKC4lk+eW65PCKhW8/5Q/03o3kqSwEKLGWr16NadPn2bSpEkPTAoD1K5dm0mTJnHq1ClWr15dARGKonK1tWDl8+1467EmmKhVXI1OYdDSvaw7HEZWtu7BOxBCiFLQ6XS899571K1bF3Nzc/z8/Pjkk08ACA4Opm3btpibm+Pp6cncuXPzfc8SeS1dupQFCxYYSmwEBwezYMECFixYwO3btwEICAhg5MiRzJkzhylTprB8+XIGDhzIjh07+OCDDwq9ArY6C4lJ4UKkfvKxvs3cGRzgBcCmEzdJ10qt4YqUe8SwqgJGDOfWv4U7W17tSt9m+s+2+6/G0vfjXXx34JqUFRGinFTKEcNt2rRh+/btxMXF5ZmAriiX1rRp04bly5dz8uRJWrVqZWg/duwYmZmZpb4sJz4+ns2bN9O3b1+cnZ0LPD7AwYMH6dSpk6E9KyuLo0eP0rt371IdvyaKTc5g+PIDnL+VhEoFi55uwTNtK+fILyGEqCiDBg0q8baPP/54GUYiyoJKpWJc17r4e9kx8ftjxCRn8MYvJ/lo60Wea+/D0PbeuNpI2SQhqpTsLEiJrvjjWruASdG/5s2bN49ly5bx8ccf07FjR8LDw7l27Rrh4eE89thjjB8/nrVr13Lq1CnGjRuHvb09U6dOLb/4q7jFixdz7do1w+MtW7awZcsWQD8/jZ2dHQDLly/H19eXlStX8uWXX9KwYUPWrFlT4FWpNcE/Z/Sjhe0szehQ15G6LtZ8uPUiSelZbDkbyeMtZYBVRcmp3mCsMVjOtcz5YkQbNp2MYPbG0ySkanlr/Wk2n7rFB4Nb4GlvaZzAhKimKmViuKiX1sTExBATE4OPjw9WVvpLCwYNGsTUqVNZunQpy5cvN+zzs88+Q6PRlOqLNMC6devIzMwstMNu0qQJzZs356uvvuKVV15Bo9EA+lHG8fHxDBkypFTHr2lSM7N4YeUhzt9KQq2CD4e05MnWXsYOSwghqoz9+/cTHx9Pt27dsLa2NnY44gE61nXiz8ldeP3nE+y+FEPE7XQ+3HqRT7Zd4jF/d57v7EdruaRWiKohJRr+mFrxxx3wMdgWbaKw9PR0PvjgA/73v/8xfPhwAOrVqwfAzJkzqVu3Lh9//DEqlYrGjRsTEhJCUFCQJIbvIzQ0tEjraTQa5s+fz/z588s3oCri6LV4ALo2dMHMRI2HvSWd6zmz53IMm09FSGK4AuWMGK6I+sKFUalUPN7Sg451HZn522n+PRfJnssx9PloF7MGNGFIW+8KH80sRHVVKRPDuS+tiY6Oxt/fnw0bNrBjxw5WrFhhuLRm6dKlvPPOOwQHB9O9e3dAX8ph+vTpzJ8/H61WS9euXdm5cydr1qxh9uzZ+WZTXbBgAQAhISEArF+/nsuXLwMYktK5rV27Fltb2/uOtlqyZAl9+/alZ8+ejBo1iuvXr/Phhx/SvXt3nnzyyVL/fmqKbJ3C5B+OcyJcf8nV4mckKSyEEIV599132bNnD5s3bza0DRgwwPDYw8ODPXv24Ovra6wQRRG52Vqw5sUOnL+VyNr/rrH+6A1SMrP5/cRNfj9xk1be9rzUtS59mtWWkkpCiFK5dOkS6enpBda0PXfuHJ06dcqTfOncuTPTp08nMTFR5kQRZSoyST9Hj4/j3dGg3Ru5sOdyDP9djUWnU6TPqyBKJUgM53C1seDrUQFsOH6DORvPkJiexfRfT/HXqVu8/7Q/7nYyeliI0qqUiWEo3aU177zzDg4ODixbtowff/wRb29vlixZUuCZ7VmzZuV5vG7dOtatWwfkTwxfu3aNvXv3MmbMmPtOIvfoo4/y559/Mnv2bKZMmYKtrS0vvvgiCxculLNaRaQoCnN/P8O/5/Qz077RtxFPtZGksBBCFOann37i0UcfNTzetGkTf/31F9OnT6dVq1ZMnjyZefPm8c033xgxyooVFBREUFAQ2dlVszZh49q2LHjCnxn9mvDrkXBW7gslJCaF42EJvPzdURq52TCpV336NXfHRL4sC1H5WLvoR+8a47jFVNh3lHvbpcanKC/RiekAecomdaqnL90Yn6rl/K0kmnrIyYiKkFNKorKkLlQqFU+29qJTPWdm/HqS4AvR7LwYTe8lu3hbRg8LUWqVNjFclEtr5s6dy9y5c/O1q1QqXn31VV599dUHHqc4H258fX2LPNlC37596du3b5H3LfJavjuENf/pa3M918GHl7vVM3JEQghRuV2/fp1GjRoZHm/YsIEGDRqwcOFCAC5cuMCKFSuMFZ5RBAYGEhgYSHh4ON7eVbc2fS1zU0Z38mNkR192Xozmq11X2X81lguRSUz8/hgN3S4xo19jejRylS9GQlQmJqZFLulgLA0aNMDS0pLt27czatSoPMuaNGnCxo0bURTF8N6yb98+PD09ZbSwKFOKohCdrB8x7Gpzd+K9xrVtcLAyIz5Vy74rMZIYriCVoZREQdxsLVgxph2/HAln3qazJGXoRw9vOhHBwqf88Xa0MnaIQlRJamMHIMS99l+JZeHmcwD0aOTCvMebyRddIYR4AEVR8oyM/ffff/OcoPTy8iIyMtIYoZXaM888g5ubG7a2trRo0YI//vjD2CEZhVqtokdjV34Y35F1Lz3Eww30I6kuRibzwsrDPPv1f5wMTzBukEKIKsXCwoLp06fz2muv8f3333P16lX27t3L2rVreeWVV7h69SpTp07lwoUL/PLLL7z33nu8/vrrxg5bVDPxqVq02fpkpKvt3cSwWq2iY10nQP8dUVQMY08+dz8qlYpn2nqzZVpXejZ2BdDXHv54F6v2haLTyVUNQhSXJIZFpRKdlMHkH4+hU6CeizWfPdcGUxN5mQohxIM0atSIDRs2ALB161bCw8PzJIbDwsJwcKiak5bNnTuXsLAwEhMTWb58OcOHDyc2tmZ/QWxfx5E1L3bgl/97iDY+9gD8dzWOx5fu5Y1fThCXkmncAIUQVcasWbOYPHkyb775Jk2aNGHkyJHEx8fj5eXFX3/9xd69e2nRogVTpkxh6tSpTJ482dghi2omKindcD93KQmATvX0ieEDIXFkZRft6l1ROpWpxnBh3O0s+WZ0Wz4e2gp7KzNSM7OZ8/sZhn31HyExKcYOT4gqpdKWkhA1T7ZO4dWfjhOdlIGFmZrPhwdQy1xeokIIURSvv/46zz77LA4ODqSkpNC4ceM8NYe3bdtGq1atjBdgKTRr1sxwX61Wk5mZyY0bN3BycjJiVJVDWz9Hfn25E/+cucUHf18gJCaFdYfD2Xo2khn9GvNMgLdM1iOEuC+1Ws2sWbPyzb0C0KNHDw4fPlzotqGhoeUYmagpohIzDPddcpWSAHjoTp3h5IwsTt24TWufqnmSuyrJKSVRifPCgH708BOtPelc35k5v5/mr1O3OBgaR9+Pd/Fa74a82KWuzMEgRBHIUExRaSwLvsyeyzEAzBvUnEa1bYwckRBCVB1Dhgzhn3/+4fnnn2fmzJls374dU1P9ybW4uDicnZ0ZN25cifefnJzMnDlzeOyxx3BxcUGlUvH+++8XuK5Wq2X27Nn4+PhgYWFBixYt+P7770t8bIDhw4djYWFBu3bt6NmzJ/7+/qXaX3WiUqno29ydf6Z25Y2+jbAwUxOfqmX6r6cY8r/9nL+VaOwQhRBCiEJFJekTw7YWpliYmeRZVs/FGudaGgCOXIuv8NhqopxplarKiWUXG3M+Hx7AF8Pb4FxLQ0aWjvf+Os9TX+zjYmSSscMTotKTxLCoFI5ej+fjfy8C8FQbT54J8DJyREIIUfU88sgjLFmyhLlz5+Lm5mZod3R05LfffuPJJ58s8b5jYmKYN28ep06donXr1vddd/z48bz77rs88cQTfPbZZ3h7ezN8+HBWr15d4uN/9913JCcn888//9C7d2+pPV8AjamaV7rXZ+ur3Xikib7u3uFr8fT/dA+L/j5Pujb7AXsQQgghKl5kor6UhKutRb5lKpXKMEr42PWEigyrxqqsk889SD9/d7a+2o2nWnsCcCIsgf6f7uazbZfQShkSIQoliWFhdBlZ2Uz/5SQ6Beo6W7PgiebyhV8IIUooPj6edevWERQURFBQEOvWrSMuLq7U+3V3d+fGjRuEhYXx1VdfFbresWPHWLlyJXPnzuXTTz9l3Lhx/PHHH3Tv3p3AwEAyM+/Wvu3duzcWFhYF3l5++eV8+zY1NaV3795s3bqVv/76q9TPqbrydrRi+eh2fDUyAE97S7J1Cp/vuEL/T3dz5FrpXwtCCCFEWYq+M2LYzda8wOVt7iSGZcRwxVAq8eRzD+JgrWHJ0FasGNOW2rYWaLMVPtx6kceX7uX0jdvGDk+ISkkKuAqjWxZ8hUtRyQB8MLgFVhp5WQohREnMnz+fhQsXkp6enqfd3NycN998k9mzZ5d43+bm5nh4eDxwvXXr1qFWq5kwYYKhTaVSMXHiRAYPHkxwcDB9+vQBYMuWLSWKJTs7m8uXL5do25qkd7PadK7vzOItF1i5L5Qr0SkM/nI//ZrXZlg7H7rUd64yl4kKIYSovnImn7t34rkcOZOs3kpM52ZCGh72lhUVWo10t8Zw1f2M0LOxG1umObLwr3P8cDCMcxGJDFq2l5e71WNSr/qYm5o8eCdC1BAyYlgY1YVbSXyxQ//lfmRHX9r5ORo5IiGEqJref/995syZQ7du3di0aRMXLlzg/Pnz/P7773Tr1o133nmn0JrAZenIkSPUq1cPR8e87+cdOnQA4OjRo8Xa382bN/n1119JTU1Fq9Wybt06goOD6dq1a6HbJCYmEh4ebrhFREQU/4lUE9bmpswZ2Ixf/u8h6rlYoyjw16lbjFpxkIcXBfPptktE3E4zdphCVHkmJiZkZ0u5lqLIzs7GxESSMuKunMnnXG0KHjHcwsse0zsnMo9el1HD5U1XhUcM52ZrYcbCp1qw9sUOeDnor6BaGnyZ/p/ukdeRELlIYlgYTbZO4Y1fT6LNVnC3s+CNvo2MHZIQQlRZy5Yto2/fvmzevJn+/fvToEEDGjZsyIABA/j777/p3bs3y5YtK/c4bt68ibu7e772nNHGN2/eLPY+P/roI9zd3XFxcWHRokX88MMPtGrVqtD1lyxZgre3t+HWvn37Yh+zugnwdeTPyQ/z7pPN8fe0A+BGQhpLtl6k8/vbGbvqMHsuxaDkXD8qhCgWCwsLMjMziY2NNXYolVpsbCyZmZlYWBQ8MlTUTDmTz7kUkhi21JjQ1MMWkHISFaGq1hguTJcGzvwztStjOvmhUsHlqGSe/mIfC/44S1qmnNATQq7ZF0azZn8oJ8ISAHj3yebYWJgZNyAhhKjC4uLiGDhwYKHLBw4cyK5du8o9jrS0NMzN83+xU6vVmJmZkZZWvNGpHh4e7Nmzp1jbTJs2jbFjxxoeR0RESHIYsDAzYXgHX4Z38OX0jdusOxzG+mM3SErP4t9zkfx7LpJ6LtY8296HPs1q4+1oZeyQhagynJ2dycjIICoqioSEBBkRW4Ds7GwyMzOxsbHB2dnZ2OGISkJRlLulJAqYfC5HGx8HTobf5qhMQFfulGqWGAb9FVRzH29G/xbuTP/lJFdjUli+J4St5yL54OkWdKzrZOwQhTAaGTEsjOJ2mpaPt10CYGBLD3o2djNyREIIUbUFBARw6tSpQpefPn2atm3blnscFhYWZGRk5GvX6XRotdoKGSVma2uLl5cXP/zwA23atKFv377lfsyqprmnHfMGNefgzEcIGtyC5p76kVhXolNY8Oc5Hl4UbJjJ+2JkkowkFuIBVCoVnp6eODs7o9FojB1OpaTRaHB2dsbT07NK1y4VZSspI4t0rQ4ovJQEQBtf/QR0Z2/eJl0rozzLU04pier4b9rOz5G/pjzMS93qolbBtdhUhn31H29vOEVyRpaxwxPCKGTEsDCKz4Mvk5CqRWOqZka/xsYORwghqrylS5fSp08f/Pz8eOWVV6hVqxYAycnJLFu2jA0bNvDPP/+UexweHh5cu3YtX3tOCYmiTGAnKo6lxoRn2nozOMCLo9cTWPvfNbaejSQ5I4szNxM5czORD7depI6zNX2a1aZPMzdaetnLpHVCFEClUuHi4mLsMISoUnLqC8MDEsN3JqDTZiucunFb5qYpR9WtlMS9LMxMeLNfEx5r7k7gLye4GJnM2v+uE3w+moVP+dO1obyPi5pFEsOiwoXFpfLtvlAAXuhcB0+ZVVYIIYqtadOm+dpUKhVvvvkmb731Fm5ubqhUKm7duoVOp8PNzY1hw4Zx5syZco2rTZs2bN++nbi4uDwT0B04cMCwvKIEBgYSGBhIeHg43t7eFXbcqkilUhHg60CArwMZWdnsuxLLljO32HImktiUTEJiUvhy5xW+3HkFN1tzejZ2o1djVzrXd8ZSI5fMCyGEKJmoxHTD/fuVkvC0t8TVxpyopAyOXouXxHA5qi6Tzz1IS297Nk3qwrLgK3wefJkbCWmMWnGQZwK8eLt/U+yspNSlqBkkMSwq3OItF8jM0uFgZcYrPeoZOxwhhKiSXF1d812K6+bmRqNGeSfyrF+/fkWGxZAhQ1i0aBGff/45b7/9NqCvVbd06VJcXFzo0aNHhcUSFBREUFAQ2dlyyWlxmJua0KORKz0aubLgCYUj1+L558wt/jlzi/D4NCITM/jh4HV+OHgdc1M1nes707OxK72auOJuJyd7hRBCFEynU1i9P5T6rjZ0aaCvM52QpgVAY6qmlnnh6YmcE5ibT9/i6HWZgK48Vccaw4UxNzVh2qMN6dusNm/8eoLTNxL5+Ug4Oy9G896T/jzSVEpeiupPEsOiQp0MT2Djcf3lxFN6NcBWJpwTQogS2bFjR4Ufc+nSpSQkJJCQkABAcHAwWVn6emyTJk3Czs6OgIAARo4cyZw5c4iOjsbf358NGzawY8cOVqxYUeDEdOVFRgyXnolaRfs6jrSv48jb/Ztw5mYi285Fse18JCfDb5ORpWP7+Si2n4/i7Q3Q1N2WXk1c6dnYVUpOCCGEyOPwtXjmbjqLjYUpJ+f0RqVSkZGlP3lrYfrg6Y/a+OgTw0euJaAoitSqLic5pSRq0q+3qYctG17pzFe7r/Lx1ktEJWUwdvVhnmjlwZyBzXCwltrxovqSxLCoUAv/Og9AHWdrnuvga+RohBBCFMfixYvz1A/esmULW7ZsAWDEiBHY2dkBsHz5cnx9fVm5ciVffvklDRs2ZM2aNYwYMcIocZfajaNwZr3+fu/5xo3FiFQqFc097WjuaceURxoQlZhO8IUo/j0XxZ5LMaRpszkbkcjZiEQ+234Z51oaejTSjyTu0sDlviPBhBCiovz77798/fXXXL16lbi4uHyTa6pUKq5cuWKk6Kq3uBR9PeGk9Cy02QoaU5Vh4jlzsweXJcqZgC4mOYPw+DS8Ha3KL9ga7G4piRqUGQZMTdS80r0+vZu6EfjLSY5dT2DD8ZvsuRzLgiea07d5bWOHKES5kE/oosIcuBrL/quxAAT2aYSmCGeFhRBCFI9Wq+XChQskJCSg0+nyLe/atWuJ9x0aGlqk9TQaDfPnz2f+fOMmUcuslES2FmIugsoEFKVmDaG5D1dbC4a282FoOx/Stdn8dzVWP5r4XCQ3b6cTk5zJz0fC+flIOBoTNe3rONK9kQvdG7lSz8VaRnoJISrcJ598wrRp03BxcaFjx440b97c2CHVKJnZSq77OjSmajK0+j7avAjfDZt72qIxUZOZrePo9XhJDJeT6j753IPUd7Xhl//rxIo9ISzecoGY5Az+b+0RBrRw553Hm+FUq+KufhOiIkhiWFSYz7ZfBqBxbRv6NpOzbUIIUZYURWH27Nl8+umnJCcnF7peTaq3W2alJKzuTHCjZEN6Alg6lEl81YmFmQndG7nSvZEr8wY14/ytJLaf1yeJj4UlkJmtY8/lGPZcjmHBn+fwdrSke0NXejR24aG6MoGdEKJifPjhh3Tr1o2///4bjUYuDa9o2qy7J6wztNnUMjcl405bURLD5qYmNPO05dj1BI5ei2dQK89yi7UmyxlEX0PzwoC+lNa4rnXp1cSVN345yeFr8fxxMoL9V2KZN6g5/Vu4GztEIcqMDNkUFeLItXj2XI4BYGLP+lJzUAghylhQUBDvvvsuw4YNY/Xq1SiKwvvvv8+XX35J8+bNadWqlaHsQ00RFBSEq6srLVu2LN2OLHPNfJ4aV7p91QAqlYom7rZM6FGf317pzKG3HmHxMy0Z0MIdO0v93AJhcWms+e8aL6w8TMt5Wxi14iAr9oQQEpNi5OiFENVZTEwMQ4cOlaSwkWiz7yaGM+/cz0kMWxShlARAgI/+5OwRmYCu3Oh0NXvEcG51XWrx00sPMWtAUyzM1MSmZDLh+6O88t0RYpIzjB2eEGVCEsOiQnyxQ1+nq75rLfo1l7NrQghR1r755huefvpp/ve//9G3b18AAgICGDduHAcPHiQ7O5udO3caOcqKFRgYSFRUFCdOnCjdjiztgTtfjtLki2hxOdcyZ3CAF0ufa8ORtx/hl/97iIk96tPMwxaAzCwduy5GM++Ps/RYvINuQcHM/f0MwReiSNfWnBHuQojyFxAQQEhIiLHDqLFyJ4YztDmJ4aKXkoC7dYbPRSSRmplVxhEKyFVjWLJFgH708Itd6vD3lK60r6MfLPDXqVs8umQnG4/fyFenXIiqRv7VRbm7Ep3Mv+ciAXipa11MZLSwEEKUuWvXrtGrVy8ATEz0o24yMvQjGczNzRkxYgSrVq0yWnxVmtrkbvmI1FjjxlLFmZqoaevnyOt9GvHn5Ic5OLMXiwa34DH/2thY6CucXYtNZeW+UJ7/9hAt39nC898eZNW+UK7Hpho5eiFEVbdkyRJWrlzJv//+a+xQaiTtPTWG4W6C2Ny0aCOG29wZMZytUzgZfruMIxQgNYYL4+dszY/jOvLO482w0pgQn6plyo/HeWnNEaKS0o0dnhAlJjWGRblbvlt/Vt7VxpzHW3kYORohhKieHBwcSE3VJ85sbW3RaDSEhYUZlltYWBATE2Os8IyizCafA32d4bQ4SQyXMVdbC4a09WZIW2+02TqOXU8g+EIUweejOH8riYwsHcEXogm+EM0czlDX2fpOLWMX2tdxLPKlx0IIATBnzhzs7e3p06cP9evXx8/Pz3AyNYdKpeLPP/80UoTVW55SEll5S0mYmxVtzFptOws87S25kZDGkWvxdKzrVPaB1nA5I2Blktj81GoVozv50bOxvvbw/quxbDkbyYGQOOYMbMqTrT3l9yaqHEkMi3IVk5zBr0fDARjT2a/IZ4KFEEIUT/PmzTl27BgAarWa9u3b8/nnn/PYY4+h0+n43//+R+PGjY0cZcUqs8nnQJ8YjkVqDJcjMxM17es40r6OI9P7Nibidho7L0Sz40I0ey7HkJyRxdWYFK7GhLBibwiWZiZ0qudE98audG/oIrPTCyEe6OzZs6hUKnx8fMjMzOTixYv51pGkTvnJU0riTgmJ4paSAGjtY8+NhDSOSZ3hcmEoJSH/CoXydrTiu7Ed+P7gdRb+dY7baVqmrTvBnycjePdJf2rbWRg7RCGKTBLDolz9dCiMzCwdlmYmDG/va+xwhBCi2ho+fDjLli0jPT0dCwsL3nvvPXr37o2vr/69V6PRsGHDBuMGWZXlTECXJonhiuJuZ8mw9j4Ma+9DZpaOI9fi2XEhih0XorkQmUSaNptt56PYdj4K0M9j0KORC90budLOzxFNMZIMQoiaITQ01Ngh1GiZuUpJZNw7YrgYA4ja+Djwx8kIjl5PQFEUSeaXMSklUTRqtYoRHX3p3siFN387xe5LMWw7H8XBj3Yya0BTngnwktemqBIkMSzKTVa2ju/+uwbAk208sbMyM3JEQghRfY0ZM4YxY8YYHnfp0oUzZ87w+++/Y2pqSu/evWnQoIHxAqzqrO5cqnq/UhLaNDj/J6hNwak+uDUD+UJQJjSmah6q58RD9Zx487Em3EhIMySJ916OITUzm8tRyVyOSubr3SFYa0wI8HMkwMeBAF8HWnjbYWshn0OEEMKY8o4YvrfGcNFP5gXcmYAuLiWT0NhU6jhbl2GUQkYMF4+XgxWrX2jPT4fCePfPcySlZ/HGLyf582QEC5/yx8Pe0tghCnFfkhgW5Wbb+Shu3tYXYR/1kIwWFkKIilanTh2mTJli7DCqh9yJYUXJn/BNjoZdiyDh+t0295bQaRKY21RcnDWEp70lwzv4MryDLxlZ2RwOjSf4fBQ7LkZzOSqZlMxsdl2MZtfFaMM2Xg6WNHG3pYm7LU3v3LwcLFHLN18hapwdO3bwxx9/EBISgkqlws/PjwEDBtC9e3djh1atabMKqjF8p5REEWsMAzRxt8XcVE1Glo6j1+IlMVzGpMZw8alUKoa196FrQxdmrj/FjgvR7LwYTe+PdvFW/yYMa+ctv09RaUliWJSbNfv1o4Xb13GkcW1bI0cjhBA1S2JiIlOnTuWNN96ocbWFc5T55HMA2VrITM6b7M3Ogu3zIDkKUOmTyKkxEHEC/nkLes0Ba5kcp7yYm5rQub4znes78zYQFpfKrkvRHLkWz9Fr8YTG6idlDI9PIzw+ja1nIw3b1jI3pXFtG0PCuIm7DY1q22ClkY/IQlRHWq2W4cOH8+uvv6IoCvb29iiKwu3bt/noo48YPHgw3333Haam8h5QHrJ0d0tJ5CSG07XFLyWhMVXTwsuOQ6HxHL0ez9MBXmUbaA13t5SEkQOpgjzsLfl2TDt+ORLOvD/OkpSexZu/neKvUxG8/3QLPGX0sKiEpMcT5SIsLpU9l2MAGNFRRgsLIURFS0tLY9WqVYwYMaLGJobLdvK5XInd1Li8ieGIE3eSwkDX18EzAC5thaOrIDkSdiyER98BjYxoqgjejlaG0cQA0UkZnLl5m3MRSZyLSORcRCJXY1LI1ikkZ2Rx+Fo8h6/dncBIpQIfRysauNrQ0K0WDd1saOBWi3outbAwk0l0hajKFixYwC+//GI4cVq7dm0AIiMjWbRoER999BFNmzZlzpw5Ro60esosqJTEnRHDFsUYMQzQxteBQ6HxHLkmE9CVtbulJCQzXBIqlYpn2nrrRw//dopt56PYfSmGvh/tYtZAqT0sKh9JDIty8fPhMADsrczo08zNyNEIIUTNlHMpoCgDlg5376fFgUOuk56hu/Q/nRuAV1v9/Ya99aOMdy2G22Gw+0Po8RaoJbFY0VxszOneyJXujVwNbenabC5FJnMuIpGzd5LF5yISSUzPQlHgWmwq12JT+ffc3dHFahX4OlnTwLUWDXISxq421HWxloSxEFXE2rVree6551iyZEmedjc3Nz788EMiIyNZvXq1JIbLScGlJIo/Yhj0E9ABXIxMIjkji1rmktooKzL5XNlws7Vg+ei2rD92gzm/nzHUHv7n9C0WPuWPq62FsUMUApDEsCgH2TqFn4+EA/BEK89id/JCCCFEpWNiBua2kJGYdwK6zBS4cUR/369r3m282kLbF+DwNxB5Bo5/D21GVlzMolAWZib4e9nh72VnaFMUhZu30zl3M5GLUUlcikzmYmQSl6OSycjSoVMgJCaFkJgUtpzNmzD2c7Kmnqt+VHFdF2vquVhT17kWDtYaYzw9IUQhbty4QZcuXQpd3rlzZ3755ZcKjKhmyTv5XPadn8WffA7uJoZ1CpwIS6BzfecyilLkjCuQvHDpqVQqnmrjxUP1nJj+6yl2XYxm2/koen+8i3mDmjOwhbuMHhZGJ4lhUeb2XI4h4s6kc0PblfLSXSGEECWi0Wjo1q0bDg4OD15ZFI2V453EcNzdtrCD+rrDKhPw6Zh/m4a94fZ1fWmJ83+AU33wfajiYhZFplKp8LS3xNPekkea3r3aKVunEB6fysU7ieJLkUlcjEzmcnQymXcSxldjUrgak8JWIvPs08HKjLoutajrrE8c13W2pq5LLXydrDAzKV4SRAhReu7u7hw8eJD/+7//K3D5oUOHDOUlRNnTZuevMWyYfK6YiWEXG3N8HK24HpfKkWvxkhguQzJiuOy521my6vl2/HgojAV/nCUhVcvkH47xz+lbzH+iOY5yIlkYkSSGRZn77ah+tLC/px1N3GXSOSGEMAYHBweCg4ONHYZRlenkc6CvMxwfmjcxnDNa2L0lWBTS57UZA3EhEHsZDnwBdp5g71M2MYlyZ6JW4etkja+TNY/ekzC+HpdqSBZfjU7hSnQyV6NTSMrIAiA+VcuRa/lrYJqoVfg4Wt1JFFvfGWlcCz8nK1xszGX0kBDlZOjQoQQFBeHj48O0adOwtdW/bycmJvLRRx+xatUq3njjDSNHWX3lrjFsSAznTD5XgpI8bXzsuR6XytHrUme4LMnkc+VDpVLxbHsfutR3JvCXE/x3NY4/T0VwICSWd5/0p08zOSkljEMSw6JMpWZmseWMfrTMk609jRyNEELUHDt37qRbt273XWfZsmVMmDChgiIyvjKdfA70I4ZBX2M4R86kc071Ct/OxBQefg3+ngHpt/X1hvu8J5PRVXEmahV1nK2p42yd58ucoihEJ2dwNTolV7I4masxKYTFpaJT9EnlnLIU287n3a+lmQm+Tlb4Olnhdych7edkhY+TFe52lpjIN3UhSmzu3LmcOHGCefPmsWDBAsPo4Fu3bqHT6ejbty9z5841bpDVWFaBk8+VrJQEQICvAxuO3+TItXiydYq8P5YRmXyufHk7WvH92I6s2h/K+5vPE5OcyUtrjvBUa0/mPN4MO0szY4coahhJDIsytfVsJGnabNQqGNDC3djhCCFEjdGzZ0+mTJnCwoULMTc3z7MsLCyM559/nuDg4CqbGA4JCaFp06YMHz6c5cuXGycIyzuJ4dw1hlNj9D+tnO6/rZUjdHkVts2HpFuwfxl0DZQCftWQSqXC1cYCVxsLOtbN+7rIyMrmemwqV6KTuXIncXw1JpkrUckkputHGadpszl/K4nzt5Ly7Vtjosbb0RJfJ2u8HSxxs7Ogtq3+lnPfWiZgEqJQFhYWbN68mT/++IM///yT0NBQAAYOHMiAAQPo37+/cQOs5vKUksi+t5RE8UcMP1RP/x6blJ7FqRu3aeVtX/oghWHyYrl6pfyo1Sqe71yHrg1deP3nExy7nsBvx26w70osHwxuQbeGLsYOUdQglfaTo1arZf78+axcuZKoqCgaNmzIjBkzeO655x64raIofPrppyxbtoxr167h6+vLxIkTmTRpUp43t4iICD755BMOHTrE4cOHSUxM5IcffmDYsGGF7nflypV88cUXnDlzBo1GQ+PGjZk7dy59+vQxrOfn58e1a9fybT98+HDWrl1bgt9G1fH78ZsAdKrnLLNsCiFEBQoMDOTDDz/k77//ZtWqVbRr1w6AFStWMG3aNNRqNWvWrDFylCU3efJkAgICjBtEzojhnFIS2nT95HPw4MQwgGsT/eRzR1bqS1Cc/hX8B5dLqKJyMjc1oYGbDQ3cbPK0K4pCXEom1+JSuRabQmhMKtfjUgmNTeFabCpxKZmAPpFyJTqFK9EphR7DxtzUkCR2s7Wgtp05tW0tcL2TQK5tZ4FzLXMZWSdqtAEDBjBgwABjh1Hj5C4lkaHNRlGUUo0YrudSC1cbc6KSMth7OUYSw2VEd+fPJN1E+avnUotf/q8TX+26ykdbL3IrMZ3RKw7ybHsf3urfhFpysldUgEr7Khs/fjyrV69mwoQJ+Pv7s2HDBoYPH05WVhajRo2677bz5s1j7ty5jBw5kjfeeIOdO3cyZcoUEhISmD17tmG9Cxcu8MEHH1C3bl1atWrFrl277rvfCRMm8NVXX/Hcc88xduxYMjMzOXPmDOHh4fnWbdGiBYGBgXna6tatW4zfQNUTn5LJzovRAAxq5WHkaIQQomZ5//33GTRoEKNGjaJz58689tprnDp1ir/++ou+ffuyfPlyPDyq5nvzxo0b0Wg0PPLIIwX2uRUmJ/mrTdUnhXNGCwNYF3HSm4Z99bWGQ/fAqZ+hlivU6Vr2sYoqRaVS4VTLHKda5rTxyT9hZGK6luuxqVyLzUkWp3AjIY1bt9OJTMwg+U5NY4CkjCySopK5HJVc6PFM1CpcapnfSSCb5xlx7GpjgYuNOa425thbmcmIMSFEmdHmrjGcrUObrXBncCrmZsVPDKtUKrrUd+a3YzfYcymGCT3ql1WoNZpMPlexTNQqXu5ej56NXZm27jhnbibyw8Hr7L4UzeJnWua7+kiIslYpE8PHjh1j5cqVzJs3j1mzZgEwduxYevbsSWBgIMOGDUOjKXjWxlu3brFw4UJeeOEFvvnmG8O2JiYmvPfee4wfP95QSyogIICYmBicnJzYsWMHPXr0KDSm9evX88UXX7Bu3TqeeeaZBz4Hd3d3RowYUdynXqX9ey6SLJ2CmYmKPs2lcLoQQlS0hx56iJMnT9KzZ08WLVoE6E+Wvv3226Xab3JyMkFBQRw6dIhDhw4RExPDwoULmTFjRr51S3PFT0HS0tJ48803+euvv1i5cmWpnkep5R4VnBYHKbkSw1ZFTAyrVNB+vL6cROxl+O9LMLcBj9ZlG6uoVmwtzGjuaUdzT7sClydnZBGZmE7k7XRuJepvd+9nEHk7nejkDLLvFI7M1imG9U7c57hmJvoEsouNOS53EsY5SePcP11szEt0GbgQ5a1Hjx6o1Wr++ecfTE1N6dmz5wO3UalUbNu2rQKiq3m099QYzikjASUrJQHQ+U5i+Mi1eNIys7HUyHtRaRlqDBc/Vy9KoVFtGzZM6MzS7ZdZGnyZ8Pg0hn31Hy90rsMbfRthUYIJGoUoikqZGF63bh1qtTpPHUSVSsXEiRMZPHgwwcHBeUo35LZx40YyMjKYNGlSnvZJkyaxatUqNm7cyEsvvQSAjY1NQbso0OLFi2nbti3PPPMMOp2O1NRUatWqdd9ttFotmZmZWFvXjMll/jlzC9B3zrYWUjBdCCEq2u3bt5k4cSIHDhygXbt2nDt3jk8++YTGjRszeHDJSxbExMQwb948vLy8aN26NVu3bi103dJc8VOQd999lyFDhuDn51fi+MtMTo1h0NcZzqk1bG4LpgWfsC6QqTl0ewO2zILkSP1kdN2mQ23/so1X1Bi1zE2p5VKLei6FfzbN1inEJGcQmZh+Z6TxncTx7QzD/cjb6STlGn2szVa4eTudm7fTgdv3jcHO0qzQpHHOKGSXWubYWZqhluuTRQVRFAWd7m4yUqfTPXAUfE59VVH2tFl3f7cZWTrStXf/NhYlGDEM+u+eoB+BfCg0jq5Sm7XUpMaw8ZiZqHn10YY80sSN134+zsXIZFbsDWH3pWg+Gtqq0BPEQpRGpUwMHzlyhHr16uHo6JinvUOHDgAcPXq00MTwkSNHMDc3p0WLFnnaW7dujUaj4ejRo8WOJzk5mf/++4+XX36ZmTNnsnTpUpKSkvDw8ODNN99k4sSJ+bbZuXMnVlZWZGVl4eXlxaRJk3j99ddRF+G0W2JiIomJiYbHERERxY65oqVkZLHrkn7kVO6ZuYUQQlSMLVu28OKLLxIbG8uHH37I1KlTCQ0N5fnnn2fo0KEMHTqUZcuW4eCQ/zL1B3F3d+fGjRt4eHgQGhpKnTp1ClyvOFf89O7du9ASTs8//zxffPEFly5d4ueff+b48ePFjrlcmFmAmZW+lETuxHBRy0jkZmEHPWfBv3P1JSl2fgCdXwUvI9dRFtWWiVqF2526wy28Cl8vLTOb6KQMopPTiU7KICopQ/8zMYPo5AyikvTtMcmZhhHIALfTtNxO0963hEVOHI7WGpysNTjXMse5luZOGQ0Nztb6n061zHGy1uBiYy4jpESp7Nix476PRcXS5krSZ5bRiOHadhbUd63F5ahkdl+KlsRwGZBSEsbn72XH7xO7sGTrRb7efZVLUck8+fleXn20IS91rSfzBIgyVSkTwzdv3sTd3T1fe05txJs3b953Wzc3t3wJWLVajZub2323Lczly5fR6XT89NNPgH70kpubG99++y2TJk1CUZQ8I5RbtGhB586dady4MbGxsaxdu5bp06dz7do1li1b9sDjLVmyhHfeeafYcRrTzovRZGbpUKng0aZuxg5HCCFqnL59+xIQEMCWLVto0qQJAHXq1GHHjh189NFHvP322+zatatENXrNzc2LVJ+4OFf8bNmy5YH727t3L2FhYYZEdHJyMjqdjsuXLxvvy72VI9xO1U9Al1NKwsrx/tsUppYL9JoF2+bpk8y7gqDtC9DgUX3JCSGMwFJjgo+TFT5OVvddL1unEJ+amTd5nJQ3mRxz537uGsjZOkWfeE7KAJIeGI+1xsSQOHay1ieSnWvdTSA7W99NLDtYaeTLsrivXbt20aRJE1xcCk4exsTEcPbsWbp2ldrv5SFPjeEsnWHiOSjZ5HM5ujZw4XJUMtvPR/FW/6alilHkKiUhb6dGZWFmwszHmtCzsSuvrTvBjYQ0Fv19gR3no/lwSEu8He/fTwtRVJUyMZyWloa5uXm+drVajZmZGWlpacXeFsDCwuK+2xYmOVk/8iEmJoa9e/fSqVMnAAYPHkybNm2YN28er7zyCiYm+rOcv//+e57tn3/+eQYNGsQXX3zBlClTaNiw4X2PN23aNMaOHWt4HBERQfv27Ysdd0X6+7S+jEQ7X0ecaxX8+xdCCFF+Zs+ezaxZswx9UW6vvvoqjz32GKNHjy7XGEpzxU9Bhg4dSt++fQ2PFy9eTEREBJ9++ul9tyvXK2+snOB2uL7GcM7kc0WtL1wQm9rw6HwIfhcSb8DhbyDuqj5BXJzyFEJUMBO16s6IX3Oa5B/PkUdqZpZhlHFMcgaxyZnEJmcQm6J/bGhLySQ+NZPcV/KnZGaTEpfK9bjUB8akVoGjtSbPzcGq4McO1hocrTRSj7SG6dGjB2vWrCm07v22bdt47rnnyM7OLnC5KJ28pSSyychVSqIkk8/leKSJKyv2hnAlOoXQmBT8nGtGKcfyIiOGK5eOdZ3YPPVh5m48w2/HbnAwNI5+n+xmzsCmDA7wkpIfotQqZWLYwsKCjIyMfO06nQ6tVouFhUWxtwVIT0+/77aFyUk016lTx5AUBn2i+tlnn2XGjBmcP3+eZs2aFbi9SqVi2rRpbNq0ie3btz8wMWxra4utrW2x4zSWrGwdOy5EAdC7mYwWFkIIY5g7d+59lzdq1Ih9+/aVawylueKnIJaWllhaWhoe16pVC0tLS5yc7j87c7leeZNTZzj3iOGSlJLIzdoJHn0H9n4Ct07B1WCIPg8PTQDnBqXbtxCVgJXGFF8nU3ydHpysycrWEZ+qJTYlg5ikTP3PnERysv5x9J3HMckZeWqU6hTuJJ8zixybhZkaJ2tzHKzN7iaNcyWPnfI81q9jZiIzMlVVD6ofnJmZWaTSf6Jk8o8YLn0pCYB2dRyxsTAlKT2Lf89FMvbhuqWKs6bL+TeRfGPlYWthxpKhrejVxI2Z609xO01L4C8n2XYuivee8sfRWgYTiJKrlIlhDw8Prl27lq895wvl/S5n9fDwYNu2beh0ujyduk6nIzIyskiXwt7Ly0tfiM3NLX/SM6ctISHhvvvw8fEBIC4urtjHr+xOhCeQmK6/RLBHY1cjRyOEEKIw5f1ltzRX/BTFg5LfOcr1yhurO0nplJi7NYZLM2I4h7kNdJ8Jp3+B079BUoR+cro6XaHlsJKXqxCiijE1URsmraMI01akZmYRm5xJ9D0jkeNSMolPySQu9e7PuORMUjLzjgRN1+q4kZDGjYSivz/ZWJjeHYVspU8g21ua4WCtwc7SDHsrfQI5930rjYmM6jKSxMTEPN/VYmNjuX79er714uPj+eGHH/D09KzA6MrO0qVL+eabbzh16hRvv/12nj7zmWeeYdeuXaSlpeHn58d7773HgAEDKjzGzFyJ4YwyLCVhZqKmW0MX/jgZwfbzUZIYLiUZMVx59W/hToCvA4G/nGD3pRj+PnOLI9fjCRrcgu6NJBcjSqZSJobbtGnD9u3biYuLy3M56oEDBwzL77ft8uXLOXnyJK1atTK0Hzt2jMzMzPtuWxh3d3fc3d0LrMt448YNAFxd7/9PePXq1SKtVxXtuBANgJeDJXXlsh0hhDCas2fP8sknn3DkyBESEhLyzMSeI6c/Kg+lueKnLOVceRMUFERQUFDZXhKck6BNyHUCu7QjhnOo1dBiCLi3hP8+h6RbELITru+DxgOh8WP6BLIQwsBKY4qVo2mRay2ma7NJSNXqE8epmXl+5tz0j7X6hHJKZp5kFkBSehZJ6Vlci31weYscZiYq7Cw1dxLFZnnu21tJQrk8ffTRR8ybNw/QX8k5depUpk6dWuC6iqLw7rvvVmB0ZcfT05N58+axevXqfMvmzp1LgwYN0Gg0HDx4kEcffZSrV68+8AqcslZeNYYBHmnixh8nIzgYEsftNC12lmal2l9NdjcxbORARIFq21mw6vn2rN4fysLN54lOymDMt4cY2dGXmY81kRJJotgqZWJ4yJAhLFq0iM8//5y3334b0HfSS5cuxcXFhR49egD6mr8xMTH4+PhgZaX/MDho0CCmTp3K0qVLWb58uWGfn332GRqNhkGDBpUopmHDhvHRRx/xzz//GOojZmZmsmbNGurUqUP9+vUB/chhOzu7PB/gsrKyWLhwISYmJjzyyCMlOn5ltvOiPjHcvZGLfHAVQggj2bdvH4888gi2tra0a9eOo0eP0rNnT9LT09m/fz/NmjUjICCgXGMozRU/5SEwMJDAwEDCw8Px9vYum50WNHLXqoy/WLs0gscWw8V/4PSvoE2FM7/B+U3g1xUa9gEH37I9phA1hIWZCbXtTKhtV7QTVYqikJqZfTdxnDMCOVcCOS4lg4RULbfTtMSnZhKfqiUzK28yWZutGOopF0dOQlmfQC44oexgpbmzTH+ztTTDxtwUtWR1AHjkkUewsLBAURRmzpzJ0KFD8wwgAn3C2NramrZt2xrq4lc1Tz75JAAbN27Mtyx3yUO1Wk1mZiY3btyo8MRwVvbdUh6Z2ToytPoTtyZqFaalLNHSvZELJmoVWTqFbecieaqNV6n2V5PdnXxO3kMqK7VaxZjOdehc35mpPx3nzM1E1vx3jb1XYvh4aCtaeNkbO0RRhVTKxHBAQAAjR45kzpw5REdH4+/vz4YNG9ixYwcrVqwwXKa6dOlS3nnnHYKDg+nevTug/9I5ffp05s+fj1arpWvXruzcuZM1a9Ywe/bsfLUPFyxYAEBISAgA69ev5/LlywCGpDTAm2++yc8//8zgwYOZOnUqrq6urF27lkuXLvHrr78aEqIbNmzgww8/pH///tSpU4f4+Hh+/PFHTpw4wcyZM/Hz8yvPX12Fi0nO4GT4bQC6Nax+o6GFEKKqePvtt/Hx8eHAgQNotVpcXV2ZOXMmPXv2ZN++ffTv35/FixeXawylueKnPJTLiGHLexLDalOwdCi7/ecwMYMmA6BuN31pictbIVsLV7bpb071wbcT+DwkZSaEKEcqlQprc1OszYs+KhnujkyOT828kzTWJ4wTUrUkpGWSkHLnZ662skwoq1RgY26K3Z2Esa3F3cRxTvLY1jJvm349U2wtzapVHeXOnTvTuXNnADIyMnj66adp3rx5uRwrOTmZoKAgDh06xKFDh4iJiWHhwoXMmDEj37parZb58+ezcuVKoqKiaNiwITNmzCh0YrzSGj58OL/++isZGRk89thj+Pv7l8txCqPTKWTpck0+p707Yri0o4UB7K00dK7vzK6L0fxxMkISw6WQU4tbBn1Vfg3cbFj/Smc+2XaRL3Zc4Wp0Ck99vo8pvRrwcvd6pT7hImqGSpkYBli+fDm+vr6sXLmSL7/8koYNG7JmzRpGjBjxwG3feecdHBwcWLZsGT/++CPe3t4sWbKkwEuGZs2alefxunXrWLduHZA3Mezi4sKePXt44403WLp0KWlpabRu3Zo//viDfv36GdZr0aIFdevW5bvvviMqKgqNRoO/vz9r165l+PDhJfxtVF67L+lHC5uZqOhUr2LPOAshhLjr0KFDzJ49Gzs7O0M9+5yEaKdOnRg3bhyzZs2id+/e5RZDUa/4qdLuHR1ct3v5zs5ibgMBo6H5U3B5G1zaoq9tHHtZfzu6BpzrQ+2W4N5CnzBWyyWEQhhbcUcm50jLzDYkjONTM7mdqiUhLdf9XEnk2/dJKCsKJKZnkZieRRjFr+9urTEpNIGsTzIXnHS2tTTDwqzyvgfNmTOnXPcfExPDvHnz8PLyonXr1mzdurXQdcePH8/q1auZMGGCYSDU8OHDycrKYtSoUWUe23fffceqVavYvn07586dq/Ckn/ae8laZ2WWbGAYY0MKdXRej2XUxmoTUTOytZEKukrg7Yti4cYii0ZiqCezTmB6NXHl13XHC4tL4cOtFdl2K5uNhrfG0t3zwTkSNVmkTwxqNhvnz5zN//vxC15k7d26BE9GoVCpeffVVXn311Qce50Ez0+bm6+vLTz/9dN912rRpU+DlO9XV3sv6iXcCfB2wNq+0LychhKj2VCoVdnZ2AFhb6+u9x8bGGpY3bNiQZcuWlXj/S5cuJSEhwTCBT3BwMFlZ+olHJ02ahJ2dXZGv+Kko5VJKQmMNbs0hPgRaPgf1e5XNfh/E3AaaPQFNBsLN43BtL9w4DFkZEHNJfzv9C5hZgmszfZLYrRnYesq04kJUIZYaEyw1lrjbFf2LvKIopGmzSUzL4naaNs8t8Z77ienafOuka/PXo0/JzCYlM5ubt9OL/Rw2TeyCv5ddsberSPv37y+0Hr9Kpco3eKio3N3duXHjBh4eHoSGhlKnTp0C1zt27BgrV65k3rx5hmONHTuWnj17EhgYyLBhw9Bo9EnN3r17s2vXrgL38/zzz/PFF18UOT5TU1N69+7Np59+SoMGDXjssceK+QxLTpud93t3hjabjCz9CeyyOpnQp1lt3lp/Cm22wt+nbzGsvU+Z7Lemkcnnqqa2fo78Nflh5m06y89HwjkUGk+/j3fx/tMteMzf/cE7EDWWZPJEqRwI0ScdOtaV0cJCCGFMderUMZRCMjc3p06dOmzdupVhw4YBsHv37lLVEly8eHGe+sFbtmxhy5YtAIwYMcKQlC7NFT9VgkoFPd/WD8dTG+HyPLUJeAXob1mZcPMYRByHiJOQGgPaNH3C+MZh/foW9uDWVJ8kdmsOtdwkUSxENaNSqfST8GlMiz1CGSAjK29SOTE9V0I59W7b3WRylj7JnKYlKSMr3/5sLSvvV8zbt28zcOBA9u7di6IoqFSqPJfN57SVNDFsbm5epHr669atQ61WM2HCBEObSqVi4sSJDB48mODgYMO8Njl9bVnKzs42fGaoKNqs/COGc05KlNWIYTtLM7o1dOXfc5FsOnlTEsMlJCOGqy4bCzOCnmlJt0YuvPnbKRLTs3jlu6MMa+fN7IFNsdJU3vdnYTzyqhAldiMhjbA4/aVpHepIYlgIIYypd+/e/PLLL3zwwQeoVCrGjRvHzJkzCQ0NRVEUduzYwfTp00u8/9DQ0CKtV5QrfipKudQYBn1itTIkV0014NNBf1MUSIqAW6f0SeKos/pJ69IT4No+/Q3AylmfJK7dHDxa60ciCyFqNHNTE1xsTHCxKf5VHVnZOpLSs/Ikj91si5+crijTp0/n8OHDrF27loceeoi6devyzz//UKdOHYKCgjh8+DB///13ucdx5MgR6tWrl6ceP2CY+O7o0aOGxHBRZWVlkZWVRXZ2NllZWaSnp2NmZkZkZCT79++nX79+mJmZsX79eoKDg1m4cGGZPZ+i0Gbnr5+dfmfyOXPTsis/8ngrD/49F8m+K7GEx6fi5VD02uBCT2oMV30DWnjQytueKT8e58i1eH48FMbB0Dg+Hdaa5p6V+4oOUfGkErUosQNX9aOFNaZqWvvYGzcYIYSo4WbOnMnPP/9sKO8wffp0FixYQFxcHElJScydO5d58+YZOcqKFRgYSFRUFCdOnDB2KOVPpQJbD2jYB7oFwtPfQJ/3oNVz4N4STO7UWUyNgZCdsH8Z/DYets2Hi/9Aapxx4xdCVEmmJmocrDX4OVvTwsuehxu4VOoaw5s2bWLcuHE8++yz2NjoT4yp1Wrq16/P//73Pzw9PZk2bVq5x3Hz5s18k6IDhtHGN2/eLPY+FyxYgKWlJStXruTdd9/F0tKSNWvWAPDRRx/h7u6Oi4sLixYt4ocffqBVq1b33V9iYiLh4eGGW0RERLFjyi0zO3/JksR0LQDmZmWXlujd1A07SzMUBX4+HF5m+61JpJRE9eDlYMVP4zsypVcD1CoME9Mt330Vna7oJVVF9ScjhkWJHbiq/xLZytu+Un8AFEKImsDBwYGAgADDY5VKxcyZM5k5c6YRozKuchsxXBWo1eBUT39rOgiysyDuCkSehlunIeYi6LL0jyNPw+EV4NwA6nQH34f0tZSFEKKaiY2NpUWLFgCGGr4pKSmG5f369StxGYniSEtLK7DuvlqtxszMjLS04k8YWNj8OwB79uwp9v6WLFnCO++8U+ztCpOVnT8RlZSuP5ldVqUkQF+v+MnWnqzcF8rPh8OY3KsBJlIToVhyym7Lr63qMzVR8+qjDelc35mpPx7j5u10Fvx5jt2XYlj8TMsSXSkiqh8ZMSxKzFBfuI7jA9YUQgghKl6NGjH8ICam4NIImj8Nj8yBp76CTpPAu/3d0cQxl+DQ17D+Jdi3FCLP6ktUCCFENVG7dm1u3boFgI2NDTY2Npw/f96wPC4urkJOJlpYWJCRkZGvXafTodVqsbAwfjmOadOmERYWZrgdPHiwVPu7t5QEQGLanRHDZVhKAmBoO/2Eszdvp7PrYnSZ7rsmMIwYlsxwtdG+jiObp3Sl/51J6HZejKbfJ7vYcSHKyJGJykBGDIsSiUxMJzQ2FYAOMvGcEEJUClFRUXz33XdcuXKF+Ph4Q424HCqViu+++85I0YlKRWMNfl30t6xMuHUSQnbBjSOQrYXQ3fqbTW1o9BjU6QZmxk9UCCFEaXTs2JGdO3fy1ltvAfoRwosXL8bDwwOdTsdHH33EQw89VO5xeHh45JnQNUdOCYmiTGBX3mxtbbG1tS2z/RVUSqI8RgwDNHG3paW3PSfCEli1P5QejV3LdP/VXc7HR6kkUb3YWZmx9LnWdD3szNzfzxKTnMmYbw/xYpc6vNG3UZmfoBFVhySGRYkcu54A6C8vaeVtb9RYhBBCwO+//86wYcNIT0/HxMSEWrVq5Vunpk0iUqNLSRSHqQa82upv6bchdA9c2Q63wyHplr7MxMl10OBRaNAbrORKISFE1TRx4kTWrVtHeno6FhYWLF68mEcffZRRo0YB0KBBAz755JNyj6NNmzZs376duLi4PBPQHThwwLC8utEWWEqi7GsM53i+kx9TfzrOjgvRXIpMooGbTLZaVFJjuPpSqVQMbedDWz9HJn1/jLMRiXyzJ4T/rsay9Lk21HGWUmI1kZSSECVyPCwBgIZuNliby/kFIYQwtmnTpuHl5cWBAwfIzMwkPj4+3y0urmZNMCalJErAwg4a94fHFkPvBeDbCVRqyEyGM+vh90nw3xf6hLEQQlQxXbp04dNPPzWUavDy8uLs2bMcO3aMU6dOcfbsWRo0aFDucQwZMgSdTsfnn39uaFMUhaVLl+Li4kKPHj3KPYaKVmApCcOI4bIfqdi/hTvudvq/8/LdIWW+/+rsbmLYyIGIclPPpRbrJ3RibJc6AJy5mcjAz/aw6UTxJ74UVZ9k9ESJnLiTGJbRwkIIUTlERETwwQcf0K5dO2OHIqoDlUo/GZ3zFGg1HC78BZe3QVY6XN2hLzvh9zA0f0pfbkIIIaoolUpFy5Yty2x/S5cuJSEhgYSEBACCg4PJytInQCdNmoSdnR0BAQGMHDmSOXPmEB0djb+/Pxs2bGDHjh2sWLGiwInpqjptVkGlJHJqDJf9eDUzEzVjOvmxcPN51h+7wWu9G+JqKyWRikJ3Z3C3jBiu3sxNTXh7QFM613dm2rrjxKdqmfTDMf67GsusAU2xMJPSEjWFJIZFsWXrFE7duA1IYlgIISqLNm3aEBUlE0iIcmDtDG1G6Seuu7Idzm3Sl5wI2amvQywJYiFEJXX9+vUSbefj41PiYy5evDhP/eAtW7awZcsWAEaMGIGdnR0Ay5cvx9fXl5UrV/Lll1/SsGFD1qxZw4gRI0p87MpMq8tfSiIxTZ8wL68E1LMdfFi6/TJJGVl8vuMKcx9vVi7HqW5y5qioaSXIaqoejV35a8rDTPr+GIevxfPdgescvZ7AsudaU9clf2k6Uf1IYlgU25XoZJIz9J14S0kMCyFEpfDhhx8yaNAgunfvTs+ePY0djqiONNbQZKC+zvDlf+HsxrwJ4no99cljqUEshKgk/Pz8SpTcKk1t+tDQ0CKtp9FomD9/PvPnzy/xsaqSgkYMp2n1v+fyGDEMYGthxosP1+Hjfy/x/YHrjO9aFw97y3I5VnVyd8SwceMQFcfdzpIfx3dkydaLfL7jCuci9KUl3nvKn0GtPI0dnihnkhgWxZZTX9hKY0JDKeIvhBCVQvv27fnoo4/o3bs3Pj4+eHt7Y2KSdwSOSqVi27ZtRoqw4snkc+XE1Fxfh7j+I3kTxJf/1ZeYaDxAn0DWWBk7UiFEDbdixQoZ9VhJ5NQY1pioybyn3nB5JYYBXuhSh5X7QklI1fLZ9kssfKpFuR2rupDJ52omUxM1b/RtTPs6jkxbd4K4lEym/Hic/67GMWeglJaoziQxLIotJzHc3NMOEzmNKIQQlcLatWsZM2YMiqKQnp4uZSXQTz4XGBhIeHg43t7exg6n+smdIL7wlz5BrE2DM7/B5a360cP1HwETM2NHKoSoocaMGWPsEMQdOclgjakaVJCZawSxeTkmnGwtzBjftS6L/r7AT4fCGPWQH03cbcvteNWBjBiu2bo3cuWvyQ8z+YdjHAyN44eD1zl2PZ5lw9tQT0pLVEvld2pOVFsnwxMAqS8shBCVyezZs2nZsiUhISHcvHmTc+fOFXgTosyZmkOzJ+Hxz6DRY6A2hYwkOLIS/ngVQveAkr+2pBBCGFtSUhLJycnGDqNG0Gbr+wEzExXmJnnTEOU5Yhjghc518LS3RKfAO5vOGGroioJJjWFR286C78d1YEKPegCcv5XEwM/2sPH4DSNHJsqDJIZFsWRl67gYqf/w1MxDzrQKIURlERkZydixY0s1YY4QpWJuAwGjYcDH+gnpUEFKNOz7DP6eAREnjR2hEEIQFhbGmDFjcHFxwd7eHjs7O5ydnXn++ecJCwszdnjVVk4pCTMTtX7UcC7lnRi2MDPh7f5NAPjvahx/nIwo1+NVdVJKQoC+tERgn8aseqE9jtYaUjOzmfLjceZsPJ1nxL+o+iQxLIolNDbF8CYgl+AIIUTl0bFjR0JCQowdhhBQywU6TYR+74N7K31bfCgEvwvbF0CcvE6FEMZx6dIlAgICWLt2LQEBAUyZMoXJkyfTrl071qxZQ9u2bbl8+bKxw6yWsnIlhu9NBJubln/t0r7Na9OpnhMAc38/Q1xKZrkfs6qSUhIit24NXfhr8sO083MAYNX+awz9aj8Rt9OMHJkoK5IYFsVyNiIJ0E8aUMfZ2sjRCCGEyLF06VJ+/vlnvv/+e2OHUmkEBQXh6upKy5YtjR1KzeTgBz3ehJ6zwFF/KSK3TulHD+9bCsnRRg1PCFHzzJgxg+zsbA4fPszff//NkiVL+Oijj9i8eTNHjhwhOzubGTNmGDvMainzTikJjWkBI4bNyj8toVKpeO9JfyzM1MSmZDLn9zPlfsyqSkYMi3vpS0t05MUudQA4dj2BAZ/uYd/lGCNHJsqCJIZFsZyPSASgvmstzEzk5SOEEJXF008/TUZGBiNHjqRWrVo0atSIpk2b5rk1a9bM2GFWqMDAQKKiojhx4oSxQ6nZajeHPu9C5ylQy1XfFrob/pgKR1fr6xELIUQFCA4OZvLkybRq1SrfspYtWzJx4kS2b99e8YHVADmlJEzVqnwjhMu7lEQOP2drAvs0BmDTiZusPxZeIcetanJKMEteWORmZqJm1oCmLHuuDdYaE2JTMhnxzQG+2HFF6nZXcabGDkBULedv6b+8NXa3MXIkQgghcnN1dcXNzY1GjRoZO5Qy1717d/777z9MTfUfWzp16sSWLVuMHJUoFpUKfDuBV3u4vBVO/6pPCJ//E64EQ9NB+onrTDXGjlQIUY1lZGRgZ2dX6HJ7e3syMjIqMKKaQ5t1t5SEyT01ChrVrrgShWM6+bH17C3+uxrHW+tP4+9pT33XWhV2/KpARgyL++nfwp1GtWvxf2uPcjkqmQ/+Ps/R6/F8OKQlthZmxg5PlIAkhkWx5IwYbir1hYUQolLZsWOHsUMoV8uXL2fEiBHGDkOUlokpNOoHdbrBud/h/B+gTYUTP8ClLeD/jH6ZWq5KEkKUPX9/f9asWcNLL72EpaVlnmUZGRmsWbMGf39/I0VXvRkmnzNVY5orMexpb4mfk1WFxWGiVvHJsNY89sluYlMyeXntEX57pRM2ktAyuJsYNnIgotKq72rDhgmdmf7rSf48GcHWs5E8/tkevhwZQOMKPNEjyoZ86hZFdjtVy83b6QDyzy6EEEKIktNYQcthMPBTqNcLUEFqLBz4Eja/ATeO3L2WVQghysjMmTM5fvw4bdu2ZdmyZfz777/8+++/LF26lDZt2nDixAneeustY4dZLRlqDJuo8pSOeLiBM6oKHpnqZmvBx8NaoVbBpahkJv9wjGyd9Dk5dPocPmrJDIv7qGVuytJnWzNrQFNM1SpCY1N5YtleNp24aezQRDFJYlgU2blbiYb7UkpCCCGM6+LFixW6bXJyMnPmzOGxxx7DxcUFlUrF+++/X+C6Wq2W2bNn4+Pjg4WFBS1atCj1pHivvvoqLi4u9OzZk6NHj5ZqX6ISsXKEDuOh/2LwbKtvux0GOxfB1lkQcVISxEKIMvP444+zdu1a4uPjmTRpEn369KFPnz5MnjyZ+Ph41q5dy8CBA40dZrWUlX23lETmnbISAF0aOBslnocbuPB2/6YABF+IZuZvp9BJchjAUC+2ohP2oupRqVS82KUOP4zviKuNOelaHZN+OMaiv8/LyZYqRBLDosgu3Kkv7FxLg3MtcyNHI4QQNVuzZs0YNmwYO3fuLNKED4qiEBwczDPPPEPz5s2LfbyYmBjmzZvHqVOnaN269X3XHT9+PO+++y5PPPEEn332Gd7e3gwfPpzVq1cX+7gAixYtIiQkhOvXrzNgwAD69etHQkJCifYlKik7L+gWCI/MBecG+raYSxD8Lmx7ByLPGjU8IUT18eyzz3L9+nX279/P999/z/fff8/+/fu5fv06w4YNM3Z41ZY2V2L46PV4Q3vnesZJDAM839mP4R18APjpcBhzN52RSbT4//buPCyquv3j+HuGXVlFVkFw33fTtDJ3y0rNfMxyeazUFpfMou1X7mWFWRmZmZlpqy1q2aYWauXjvuaWuyIooiKi7Mzvj5FRAhFwYAb4vK7rXM2cOcvNCfnCfe5zfyEnn6eCYSmsm8Kr8MPoW2lR3RuAWasOMuyTjSSlZtg2MCkU9RiWQjsQnwxAHX9VC4uI2NrGjRt58cUX6dSpE4GBgXTp0oVWrVpRs2ZNfHx8MJlMnDt3jsOHD7Np0yZ+//13Tp06RY8ePdiwYUORzxcUFMSJEycIDg7myJEj1KhRI9/ttm7dyvz585k8eTIvv/wyAMOGDaNz585EREQwYMAAnJ3NE4x1796dNWvW5Huchx56iPfffx+ANm3aWNaPGzeOefPmsXbtWnr27Fnkr0PsnH8D6DYFYrfAjq/h3GGI32NODgc2gSb9wa+uraMUkTLO0dGRtm3b0rZtW1uHUmHktJJwcjDgaDSSfjlR7FPZdpOOGgwGJvduzMW0TJZsi2XB/47i4mjkxZ4NKnS1rCafk+II8HTlyxE38/KSv1m0KYbofafp895ffDikNbX8NMGjPVNiWArtUII5MVzLv7KNIxERkebNm/PTTz+xa9cu5s2bx5IlS/jss8+AK4/+5VS91KhRgwEDBvDwww8Xq1oYwMXFheDg4Otut2jRIoxGIyNHjrSsMxgMjBo1in79+hEdHU2PHj0AWL58ebFiMWpisvLNYIBqrSC4JcRsgp2LIPEYnNxpXoKaQ9P+4FvL1pGKiJ07duwYANWrV8/1/npythfrubpi+M3+zXhx8U4m3tPIxlGZJ6Ob/p9mpGdl89POk3z4x2EcjEaeu6NehU0Oq2JYisvF0YHX72tKo2AvJi/bzaHTF+kT9RczH2hBp/r+tg5PrsGqiWGTyVRhf3hWBIdOXwSgZlXd7RERsReNGjXizTff5M033yQuLo69e/eSkJAAQNWqVWnQoAGBgYGlFs/mzZupVasWVapUybU+pypry5YtlsRwYSQmJrJx40Y6dOiAwWBg9uzZnDx5knbt2l1zn6SkJJKSrvTFj4uLK+JXIXbBYIDQmyCkNRxbBzu/hqQTELfNvAS3gEb3gl89W0cqInYqPDwcg8FASkoKzs7OlvfXk5WVVQrRVSxXJ4bvaRbM3U2D7CZ34Ohg5J0BLUjP3MLKPaeYvfogyWkZTO7VuEJOwKYew3IjDAYD/20fTh1/d0Z+voVzlzJ4+JONPNO9Hk90rKXvKzt0w4nh48ePs2TJEpYuXcratWtp27Yt9957L7169SI8PNwKIYo9uJiWSdz5VABq+qliWETEHgUFBREUFGTTGGJjY/ONIafaODa2aDMVZ2Rk8MILL7B3716cnZ1p1qwZP/30Ez4+PtfcZ8aMGUyaNKlogYv9MhggrB2EtoWjf5kTxMmnIHarefFvYE4QBzY1bysictm8efMwGAw4OTnlei+l7+rEMNhf0tHJwch7A1sw9stt/Pz3ST5dd4yklEze7N/MEnNFoVYSYg3ta1fl+1G3MnzBJvaevEDkr/vYHZdEZL+mVHJW8wJ7Uqz/Gzt27GDJkiUsWbKE7du34+HhwZ133klkZCRr165l4sSJPPXUUzRp0oQ+ffrQu3fv605UI/btcMJFy2v1hxERkWtJSUnBxSXvBKVGoxEnJydSUlKKdDw/Pz82bdpUpH3GjRvHsGHDLO/j4uJy9SmWMspohBq3QVh7OLoWdi+B8zHmHsTxe6BKTWjYB0LbKEEsIgAMHTq0wPdSejIu9xh2drTfn88ujg68+0ALXly8k0WbYvh+eywXUjOYNbAVbs4Otg6v1KiVhFhLaJVKfPdEeyK+3sGPO+P4cUccx85cYu5/WxPg6Wrr8OSyIt/6qlevHi1atODDDz+kXbt2/Pzzz5w+fZovvviCkSNH8tlnnxEfH8+vv/5Khw4dmD9/Pq1atVL1cBl38LS5v7CLo5FgbzcbRyMiIvbK1dWVtLS0POuzs7PJyMjA1bXkfwn09PQkJCSEL774gpYtW3LHHXeU+DmlFBkdzAnintOhQwRUudxr+Owh+HMG/DgODqyEzHTbxikidmfFihVkZ2fbOowK6d8Vw/bK0cHI6/c1Zfht5kl2o/edZsi89ZxPybBxZKVHFcNiTZWcHYl6sAXPdDdPHrzzxHl6R/3FrtjzNo5MchT5p/LgwYPZsGEDx48f57333qN79+6WR3NyODo60rVrV2bOnMmRI0fYvHkzDz30kNWCltKX01+4RtXKOOjWoYiIXENwcHC+PX1zWkgUZgI7a4mIiCA+Pp7t27eX2jmlFBkM5v7DPV6BTv8HAZcnMUqKhQ0fwtKRsONrSNUfHiJi1qNHD4KCghg9ejR//fWXrcOpUMpKYhjMbS5e7NmAiB7mHvYbj5zjgTnrOH0h743v8uhyXlgP34jVGAwGRnWuw6yBLXFxNHIyKZX/zP4fK3efsnVoQjESwy+99BKtWrUq0j4tWrRgwoQJRT2V2JFDl1tJqL+wiIgUpGXLlhw8eJCzZ8/mWr9+/XrL56UlMjISf39/mjVrVmrnFBswGCCoKXQZD92nmnsRY4C0JPj7G1jyBKybbW47ISIV2qJFi7jtttuYN28eHTp0ICwsjGeffZbNmzfbOrRyLyPTnG10dCgb2UaDwcDITrWZ2qcxBgPsjkui/wf/I+bcJVuHVuJUMSwlpWeTIL56tB1V3V24lJ7F8IWbmPvHIcuEh2IbRU4Mp6amMmbMGD799NOSiEfs1MF4cyuJmlXVX1hERK6tf//+ZGdnM2vWLMs6k8lEVFQUfn5+dOrUqdRiUcVwBVS1Dtw2DnrNhHp3gqMLZGfCoWj48Wn4fSoc3wjZWbaOVERsoF+/fnzzzTfEx8ezYMECmjZtysyZM2nTpg1169ZlwoQJ7N6929ZhlkvplyuGnctAxfDVBt0cxjsDWuBoNHA44SL/mf0/DsRfsHVYJUqJYSlJzUO9WTKyPfUDPTCZYOqPe/i/JX9bniqQ0lfkyeeioqJ47733uPvuu0siHrFD2dkmy+RzqhgWEbE/KSkpjBw5kp49e9KvX78SO09UVBSJiYkkJiYCEB0dTWZmJgCjR4/Gy8uLVq1aMXjwYCZMmMDp06dp0qQJS5YsYdWqVcybNy/fielKSmRkJJGRkWRlKQlY4bj7Q6uh0OQ/cPB32PczXDoDJ3eal0q+ULsr1OoMbt62jlZESlnlypUZOHAgAwcO5Pz583zzzTcsWrSIV199lVdeecUyton1lKVWEv/Wq1kwHi6OPPbpZuLOmx+B/+ThNjQN8bZ1aCVCk89JSQvxqcTXj7Vj9BdbWbXvNJ+vP8bxs5eYNbAlHq5O1z+AWFWRfyp/9dVXdOvWje7duxe4XWRkJK1bt9Yd13Lg1IVUUjLMf1TX9FPFsIiIvXFzc2PRokWWhG1JmT59Oi+//DJvvvkmAMuXL+fll1/m5Zdf5ty5c5bt5s6dy4svvsh3333HyJEjOXr0KAsXLiz1+QZUMSw4V4YG98A9M+GWseDfwLz+0hnY8ZW5D/Ff70D83itNFUWkQnFzc8PX1xcfHx+cnZ31SHMJycwyX9eymBgG6FTfn0+HtcXD1ZFzlzJ4YM461h06Y+uwSkTOvwGDKoalBHm4OjF3SGuGtg8H4I/9CfT/YB3xSam2DawCKvJP5T179nDnnXded7vRo0dz/Phxvvjii2IFlpGRwfjx46levTqurq40bdqUzz//vFD7mkwm3nnnHerWrYuLiwt169Zl5syZeQb5uLg4nn/+ebp06YKXlxcGg4Evv/yywON+/PHHtGnThsqVK+Pj40O7du349ddf82y7fPlybr75Ztzc3AgMDGTMmDEkJycX7SLYiWNnrvRRCvetZMNIRETkWtq0aVPiCdAjR45gMpnyXcLDwy3bOTs7M2XKFI4fP05aWho7d+5k0KBBJRpbftRjWCwcHCGsHXSdCD0joU63K20mjq6FlRPg52fhn+WQftHW0YpICcvMzOSnn35iyJAh+Pv7c9999/HHH38wfPhw1q5da+vwyqUrFcNlN9l4U3gVvhxxM1XdnbmYnsWwTzbx94nyN8GpKoaltDg6GJnYqxET7mmIwQB74pK4d9ZaDsSXzdxZWVWs23WVK1+/nYCrqyt9+/Zl+fLlxTkFI0aM4JVXXqFPnz68++67hIaGMnDgQBYsWHDdfSdPnszYsWO5+eabee+992jbti1PPvkkU6ZMybXdvn37eP311zly5AjNmze/7nFHjhzJ8OHDqV+/Pm+99RZTpkyhefPmxMTknszk999/p2fPnjg4OPD222/zyCOP8OGHH9K7d+8yeQf62FlzYtjD1REvN5X1i4jYo5kzZ/Ldd9/x/vvvk5GRYetw7IIqhiVf3tXhpmHQZza0fgS8QszrE4/Bpo9g8aOwNgpO7VIVsUg5s2LFCoYNG0ZAQAD33HMPP/30E/fffz8rV64kJiaGt99+m7Zt29o6zHLJ0mPYsWxWDOdoFOzFokfbUdXdmeS0TIZ+vDFXIVV5oB7DUtoeuqUGUQ+0xNnByInEFPrNXsvmo2evv6NYRZF7DIeEhLB3795Cbdu0aVMWL15c5KC2bt3K/PnzmTx5Mi+//DIAw4YNo3PnzkRERDBgwACcnZ3z3ffkyZNMmzaNhx9+mI8++siyr4ODA6+++iojRowgMDAQgFatWpGQkICvry+rVq0qcEKcxYsX8/7777No0SL+85//FBj/U089Rd26dYmOjrbEWadOHR566CGWLl1Knz59inpJbOr4uRQAQn0q6XESERE7lTPp26hRoxg7dizBwcG4ubnl2sZgMLBr1y4bRVj61GNYCuRcCep2N1cPx++B/cshZiNkZcCRP8yLewDU7GheKlWxdcQicoN69OiBh4cHvXv3ZsCAAXTv3h1HxyL/SSzFUJZ7DP9bTT93Ph7ahvvn/I+E5DSGzFvPN4+3p6p76c2jUFLMT4KZXxvL/v8qKUPuahqEr7szwxdsIvFSBg9+uJ6oB1vSrWGArUMr94r8T7179+4sXLiQixev/5idwWAoVr/DRYsWYTQaGTlyZK5jjRo1ivj4eKKjo6+579KlS0lLS2P06NG51o8ePZq0tDSWLl1qWefh4YGvr2+hYpo+fTqtW7fmP//5D9nZ2ddsC7Fv3z527NjBiBEjciWvBw0ahLe3N1999VWhzmdPjl+uGK5eRW0kRETslb+/P/Xr16dDhw60b9+e8PBwAgICci3+/v62DrNUqWJYCsVggICGcOtY6PM+tPyvuaoYIPmUuRfxkidg1WtwbD1kaVIqkbLq66+/Jj4+ngULFtCzZ08lhUtRxuUew47lpD9BkxAvZg9qhaPRwJEzl3hk/kZSM8r+jeirH5RRUZiUtptr+vLNY+0J8nIlLTObRxdu4rP1R20dVrlX5MTwmDFjuHDhAv369ePSpYIfmdi2bRtBQUFFDmrz5s3UqlWLKlVyV2bkPNazZcuWAvd1cXGhadOmuda3aNECZ2fnAve9luTkZNatW0fbtm158cUX8fb2xsPDg2rVqhEVFZXn/GDu9Xg1R0dHWrVqVazz21pOYji0itt1thQREVtZtWoV0dHR111EpACunlC/J9z5BvR4FWp3BSc3wASxW+HPGbDkcdj8CZw9rFYTImXMfffdh4tL2a/qLIsyMstHK4mrdajrx/T/mOcx2B5znpeW/F0mW0deLfuq+NVKQmyhXqAH3z7enroB7mSb4P8W/81bK/4p8/+27FmRfyrXrl2bWbNmsWLFCpo3b853332X7yOa69evZ/78+dxxxx1FDio2NjbfhHJwcLDl84L2DQgIwPiv5x6MRiMBAQEF7nstBw4cIDs7m6+++ooPP/yQV155ha+++oqmTZsyevRo3n333VznB64Zf2HOn5SURExMjGWJi4srcszWdEwVwyIiIlKRGAzgWwvaDId7P4CbnwD/BubP0pJg30/wy/PwUwTsXgqX1AdPRKQg6eWolcTV+rSoxujOtQH4ZnMMn60/ZuOIbkz2Vbm3clLcLWVQsLcbXz/WnjY1zMWi7/y2n0k/7CY7W8nhklCsZ2eGDh2Kh4cHjz32GP/5z3/w9/enS5cuhIWF4ezszN9//82SJUvw9vbm//7v/4p8/JSUlHzv5BqNRpycnEhJSSnyvmCeEK+gfa8lp21EQkICf/31F+3btwegX79+tGzZksmTJ/PEE0/g4OBgOX5+MRT2/DNmzGDSpElFjrMkpGZkEX8hDYAQJYZFROxadnY2CxYsYNmyZRw+fBiDwUB4eDj33HMPgwcPznPTtLxTj2GxCkcXqHm7eUmKg0OrzP2HL52B88dh2+ew7QsIbAzht0FoW3BytXXUIiJ25ZV7m5CemU3jYC9bh2J1Y7vWZXvMedb8c5pJP+yiYbAnLav72DqsYlHFsNgLLzcnFjzchtFfbGXF7lPMX3uE5LRMXuvbBMdydoPJ1op9Ne+77z727NnD888/j6OjI59//jnTpk1j0qRJfPvtt9x0002sWrWKkJCQIh/b1dWVtLS0POuzs7PJyMjA1fXav2xfa1+A1NTUAve9lpwkb40aNSxJYTAnqh944AESEhIsE/LlHD+/GAp7/nHjxnH8+HHLsmHDhiLHbC0x5660C1HFsIiI/bpw4QK33XYbjzzyCCtXrsRkMpGVlcXKlSt5+OGH6dChwzX745dX6jEsVucZBM0fgN7vQeeXzZPSOboAJji5E9bNgu+Gw9ooiNsO2dm2jlhExC70axXCg22rU923/P1N6WA0MHNAc0J83MjIMvHEp1tISM4/J2HvcvcYtl0cIgCuTg7MGtiSe1tUA8xV+aO/2Epapoo+rOmG0uxVq1bllVde4fjx4xw4cIDo6GhWrlzJkSNHWLt2LY0aNSrWcYODg/Ntn5DThiGnpcS19j116hTZ//pFPDs7m1OnThW477XkJLcDAvLOhpizLmeSvZzjXyv+wpzf09OTkJAQy1KcPs3WktNGwmCAat7qMSwiYq9efvll1q9fz1tvvUV8fDxbtmxh69atnD59mrfffpt169bx8ssv2zpMkfLBYDBXCN/8OPT9ENqPgaDmgAGy0s0VxdGvwtInYMtC9SMWESnnvCs5M3tQK1wcjZxMSuX5b3eUyZ6oqhgWe+PkYOTN/zRj8M1hAPz890mGfbKJS+maDNharFZ/XbNmTW6//XY6d+5M9erVb+hYLVu25ODBg5w9m7tf2/r16y2fF7RvWloaO3bsyLV+69atpKenF7jvtQQFBREUFERMTEyez06cOAFgmek95/j/rvLNzMxky5YtxTq/LR0/a259EeDhiquTg42jERGRa/n222959NFHGTNmDM7Ozpb1Tk5OjB49mhEjRvD111/bMEKRcsrRBcJvgU4vwL3vQ8sh4BNu/izlHOxdZu5H/OPT8Pe3cOGkTcMVqcjWrFnD6dOnr/l5QkICa9asKcWIpDxpXM2L8fc0BGDlnngWbTpu44iKTolhsUdGo4HJvRsxslMtAP7Yn8DQeRtJTlNy2BrssjFH//79yc7OZtasWZZ1JpOJqKgo/Pz86NSpE4ClhcOlS1faHfTu3RtnZ2eioqJyHfPdd9/F2dmZ3r17FyumAQMGEBMTw6+//mpZl56ezsKFC6lRowa1a5sbzjdo0IDGjRszZ84c0tPTLdt++umnnDt3jv79+xfr/LaiiedERMqG06dP07hx42t+3qRJExISEkoxItuLjIzE39+fZs2a2ToUqSjcfKD+XXDn69AzEhr0AjfzxCkknYAdi+CHJ+HX/4N9P5sTxyJSajp16sSKFSuu+flvv/1m+VtTpDgebFOdTvX8AJj0w26Onrlo44iKRpPPib0yGAxE9KjP83fWB2DDkbMMnbdByWErKNbkcyWtVatWDB48mAkTJnD69GmaNGnCkiVLWLVqFfPmzbP0/I2KimLSpElER0fTsWNHwNzK4bnnnmPKlClkZGTQoUMHVq9ezcKFCxk/fnyetgxTp04F4PDhwwAsXryYAwcOAPDSSy9ZtnvhhRf4+uuv6devH2PHjsXf359PP/2U/fv38+2332K46m7ajBkzuOOOO+jcuTNDhgzh2LFjvPnmm3Ts2JF77723xK5bSTh+OTEcUkVtJERE7FlYWBgrVqzg8ccfz/fzFStWEBYWVspR2VZERAQRERHExMQQGhpq63CkovGuDi0GQvMHIX4PHP0Ljv0P0i/CmQPmZfMn5pYUYbdAaBtwrmzrqEXKtes92p+enl7hJmoV6zIYDLzeryk93lrDuUsZjFu0nUWPtsOhjGRZr/43YlDFsNihx26vhZODkSnLdrPp6DmGfLSeTx5ug4erk61DK7PsMjEMMHfuXMLCwpg/fz6zZ8+mbt26LFy4kEGDBl1330mTJuHj48N7773Hl19+SWhoKDNmzGDs2LF5tv13v8VFixaxaNEiIHdi2M/Pjz///JNnn32WqKgoUlJSaNGiBcuWLePOO+/MdYxu3brx448/Mn78eJ588kk8PT155JFHmDZtWpn74RpzztxKItRHFcMiIvbs4Ycf5oUXXmDo0KFERERQt25dAPbt28eMGTNYunQpr732mo2jFKmADAYIaGheWj0EJ7fDkT8hZpO5H/HJneZl41yo1hLCboXgFuDofP1ji8h1JSUlWeaDAThz5gzHjh3Ls925c+f44osvqFatWilGJ+WRv4cr0/o24bFPt7D56Dk+WHOQJzrWtnVYhaKKYSkLHrm1BkaDuSp/y7FEhszbwCcPt8FTyeFiMZjKYkf0Cian0un48eOWifBKS/PJy0m8lMEb9zWl/02qthIRsVcmk4nHH3+cOXPmYDAYLDciTSYTJpOJRx99lPfff9/GUdqGLcdRkWvKSIUTm+DIXxC3HUxXzbDt5AYhbcy9iwMag1HzPIgU16RJk5g8eXKhtjWZTLzyyiu88MILJRxV2aOxtOieXrSdb7fE4OxoZPnYDoRXtf+nQhKS02g9dSUAS0feQrNQb9sGJFKAT9YeYcL3uwBoFurNgofb4OWm5HBRlWjFcOfOnQkODubFF1+kYcOGJXkqKQEX0zJJvJQBQLC3WkmIiNgzg8HA7NmzGT16NMuWLePIkSMAhIeHc9dddxXYf1hEbMDJFcJvNS9pF+DYejj6p7ntREYKHF5tXlw8Iay9eTvf2uYKZBEptK5du+Lq6orJZOLFF1/k/vvvp3nz5rm2MRgMVK5cmdatW9O2bVvbBCrlzvi7G7L6n3gSktN5acnfLHykjd0/QazJ56Qs+W/7cIwGeHnpLrYfT+S/8zbw6bC2uLvYbXMEu1SiV2vVqlUAfPnllzzwwAMsXLiwJE8nVhZ3PsXyOtjb1YaRiIhIQVJSUrjrrrsYPHgwDz30EI0aNbJ1SCJSFC4eUKerebl4xtyP+OhfcO4IpCXBP7+Yl8p+V5LJXqrYEymMW265hVtuuQWAtLQ07rvvPt0slVLhVcmJl+9uyJNfbuPPAwks2XaCe1vY98/uq58nV15YyoLB7cIxGAy8tORvth1PZNgnG5n/UBtcnfS0VWGVaGf97OxsLly4wPfff59n0jexf7GJqZbXqhgWEbFfbm5ubNmyhcxMzcp7tcjISPz9/WnWrJmtQxEpvMq+0LAX3Pk63PUmNL4P3APMn108DbsWw49Pw08RsGsJXEywabgiZcmECROUFJZS1atZMLfVqQrAlGV7OHcx3cYRFUwVw1IWDbo5jJfuagDAukNneeKzLaRnZts4qrKj2D2GMzIyuHTpEl5eXtaOSf7FVv2cvtxwjOe/24lvZWc2v9yt1M4rIiJF16dPH7y8vPjkk09sHYrVzZ8/n1deeYXY2FiqV6/O0qVLLZPrFYb6IkqZZzLBmYPmVhNH/wepibk/96tvriIObQuunjYJUaSsSExM5IsvvuDQoUOcPXuWf/85bDAY+Oijj2wUnf3SWFp8x85cottbq0nLzKZ/6xDe6Ge/N6xPJKZwy2u/A/DL2NuoH6gxRcqOd1bu562V/wBwV9MgZg5ogYNmUbyuIreSOHv2LEOHDuWXX34hKyuLGjVqMGnSJAYOHFgS8YkNxSaaW0kEqY2EiIjde/fdd+nevTvPPfccjz/+ONWrV8doLNEHg0rFjz/+yMSJE/nmm29o1aoVhw8fxtvb29ZhiZQugwGq1jYvLYbAqb/h6Fo4vs7cj/j0XvOy6WMIamZOEldrZe5jLCIWv/32G3379uXChQt4enri4+OTZxt77wErZU9130o82bUOb/yyj0WbYnigTXVaVM/7vWcPsrNVMSxl15gutbmYnsmcNYf4cUcclZwceP2+phiVHC5QkRPDL7zwAsuWLaN169YEBgayYcMGhgwZgtFo5IEHHiiJGMVGTlxuJRHspTYSIiL2rm7duphMJqZPn8706dMxGo04OeWelddgMHDx4kUbRVg8kydPZsKECbRu3RqAmjVr2jgiERszGiGoqXlp/TDEbjVXEp/YAtmZELvFvDg4Q8hNEH4LBDYDB03EIjJu3Dh8fX1Zs2aN2gxJqRp+W02+2RzDodMXmfTDbr57vL1dJquuLqC3w/BECmQwGHjhzvpcTMvks/XH+HpzDB6uTrx8dwPd9CtAkX9D/Pnnn7n//vv54osvAEhOTubuu+/m//7v/5QYLmdyKobVX1hExP7df//9JfYLT3JyMpGRkWzcuJGNGzeSkJDAtGnTeP755/Nsm5GRwZQpU5g/fz7x8fHUrVuX559/ngcffLDI583KymLLli2cPHmSWrVqkZGRwZAhQ5g8eXK5qIYWuWGOzlC9rXlJvwgxG+HIn3Dyb8hKvzKJnbM7hLWDsFvBr55mFJIKa9++fbzxxhtKCkupc3Iw8vJdDXlo/ka2HU9k6Xb7nIju6h7DSqRJWWQwGJjSuzGX0rNYvPUE8/46TKCXCyM61LJ1aHaryInhEydO0KNHD8t7d3d3Jk6cSJcuXTh48CC1aulilxdx582J4WpKDIuI2L358+eX2LETEhKYPHkyISEhtGjRghUrVlxz2xEjRrBgwQJGjhxJkyZNWLJkCQMHDiQzM5MhQ4YU6bynTp0iMzOT77//nnXr1pGWlkaPHj2oUaMGjzzyyI1+WSLli3NlqNnRvFw6C8fWmZPCZw5AejLsX2FeKleFsFvMi0+YraMWKVU1atQgJSXF1mFIBdWpvj+31/Vj9T+nef3nffRoFEglZ/t6mkOTz0l5YDQaeKNfU85dSmfVvtO8+tNeAjxd6d28mq1Ds0tFLrcxmUw4OzvnWle/fn1MJhNxcXFWC0xsKzvbROz5y60klBgWEbFrKSkpODg48Oqrr5bI8YOCgjhx4gTHjx9nzpw519xu69atzJ8/n4kTJzJz5kyGDx/OsmXL6NixIxEREaSnX5mJu3v37ri6uua7PP744wC4uZnHnyeeeAI/Pz9CQkJ49NFH+fnnn0vk6xQpNypVgfo9occrcM870OQ/4BFk/uxiAuxeCj8/Cz8+A7sWQ/Jp28YrUkpefPFFZs+ezZkzZ2wditVFRUXRokULHB0dmThxYq7POnbsiKurK+7u7ri7u9O9e3fbBCm8fHcDHIwGTialMnvVQVuHk0e2WklIOeHkYOS9B1vSNMQLgGe+3s5fBxJsHJV9KtbtqdjYWDIyMiy9C3P+e/UffFK2nbmYTnpmNqDJ50RE7J2bmxv+/v54eXmVyPFdXFwIDg6+7naLFi3CaDQycuRIyzqDwcCoUaPo168f0dHRlqeOli9fft3j+fj45HtePdooUgQegdCkHzS+D84dhiOX20uknIPzx2H7l+alal0Iv83clsK1ZH6WiNhaXFwcVatWpU6dOvTv35+wsDAcHBxybWMwGIiIiLBRhMVXrVo1Jk+ezIIFC/L9fO7cuQwaNKiUo5J/q+3vwZB2YXz81xE+WHOI/jeFEuJTydZhWZhUMSzlSGUXR+YNvYn73l/L0TOXeHThZr59vD31Aj1sHZpdKVZi+LnnnuP//u//aNiwIa1ataJ27doYDAaysrKsHZ/YSE5/YVArCRGRsuD+++9n0aJFPP744zbrv7t582Zq1apFlSpVcq1v27YtAFu2bMnVjqowhg4dyvvvv89dd91Feno6H374IU8//XSB+yQlJZGUlGR5ryeaRDD3Fa5S07w0Hwjxu80J4mPrIOMSJPxjXjZ/DIFNIfxW8+R1TioQkPLj6t7413oCpqwmhu+9914Ali5dauNI5HrGdqnLkq0nOHcpgxnL/2HG/c1tHZLF1RXDygtLeVDV3YUFD7fh3llrOXsxnYfnb2TpqFuo6u5i69DsRpETw9HR0Wzfvt2yfPbZZ6SlpQFw5513EhYWRqNGjWjYsKHlv61atbJ64FKychLDTg4G/PQPRkTE7vXp04fffvuN2267jREjRhAeHm5pxXC1Nm3alFgMsbGxBAUF5VmfU/UbGxtb5GNOmDCB+Ph4atSogYeHB4888gj//e9/C9xnxowZTJo0qcjnEqkwjEYIbGxeWj8McdvhyB9wYjNkZUDcNvPi4GRODofdCkHNwMG+emGKFNXhw4dL9Pi2mqy1MJ566imeeuopmjRpwvTp02nZsmWJnEeuz6uSE6M612HKst0s3naCYbfVpGGwp63DAtRjWMqnMN/KzBncigc/XM+JxBRGLNjE58NvxtXJ4fo7VwBF/u3u9ttv5/bbb7e8z8rKYt++fbmSxZs2bWLZsmUAqiQuo3L6Cwd6uWJUcyEREbvXuXNny+v//e9/edotmEymEh+TU1JScHHJezPRaDTi5ORUrAl/nJ2d+fDDD/nwww8Lvc+4ceMYNmyY5X1cXFyJJsRFyjQHJwhpbV7SL0HMRnMlcdwOc5L46Frz4uwOYe0gvANUraNSMimTwsJKdsJFW03Wej1vvPEGDRs2xMHBgffff58777yTffv24e3tbdXzSOENurk6H/91mJhzKbz+y14+edg+fk9RYljKq9bhVXijX1PGfrWNLccSefabHbwzoLla1FHMVhJXc3BwoGHDhjRs2JAHHnjAsv706dNs27aNHTt23OgpxAZyKoaDvdRGQkSkLPj4449tHQKurq6Wp4iulp2dTUZGBq6upfNIuqenJ56enkRGRhIZGakb1CKF5VwJat5uXlISzW0mjvwBZw5AejLsX2Fe3AOgRgdzT2KPAFtHLWI3ciZrDQ4O5siRI9SoUSPf7XIma508eTIvv/wyAMOGDaNz585EREQwYMAAy4Tv3bt3Z82aNfke56GHHuL999+/blxX3xwdN24c8+bNY+3atfTs2bOoX6JYiYujA890r8fYr7ax+p/TrD2QQPvaVW0dFiZNPiflWJ8W1Th0OpmZvx/g++2x1Av0YGSn2rYOy+aKnBg+d+4cPj4+193Oz8+Pbt260a1btyLtJ/bBkhhWf2ERkTLheu0VSkNwcDBHjx7Nsz6nhURhJrCzpoiICCIiIoiJiSE0NLRUzy1S5rl5Q707zMuFU+YE8ZE/4MJJSD4FO782L1XrQo3bofrN4OJu66hFcunUqRNGo5Fff/0VR0fHXE/XXIvBYOC3334r1vlsNVlrUdlqLgLJrVezYOasOcTuuCRe+2UvS0feYvPqxasrhm0di0hJGNu1LgdPX+THnXG8uXwfTap50aGun63DsqkijwihoaF07NiRt99++7o9mo4ePcrMmTPp3LkzAQGqJihLriSGNeGIiEhZc/z4cTZv3kxycnKpnrdly5YcPHiQs2fP5lq/fv16y+elKTIyEn9/f5o1a1aq5xUpdzwCoEk/uPtt6D4V6nQzt5YA84R1Gz+ExY/CH29CzCbIyrRpuCI5TCYT2dnZlvfZ2dmYTKYCl6u3LymFmay1qDIzM0lNTSUrKyvX68TERFasWEFaWhrp6enMnDmTkydP0q5duwKPl5SURExMjGXRRK7WZzQaeP7O+gDsiDnPjzttf42zVTEs5ZzRaOCNfk2p4+9OtgnGfLmV42cv2TosmypyYviTTz4hNDSUKVOmULt2bZo1a8b48ePZunUrANu2bWPSpEm0aNGCmjVrMmHCBAIDA1m4cKHVg5eScyLR3GNYFcMiImXHkiVLqFu3LuHh4bRp04YNGzYA5p6HTZo0YfHixSV6/v79+5Odnc2sWbMs60wmE1FRUfj5+dGpU6cSPf+/RUREEB8fz/bt20v1vCLllsFg7i980zC49wO47RkIbQNGR8jOhOMbYE2kOUm88SNI2J/7uWSRUrZq1Sqio6NxdHTM9f56S0kriclap06dipubG/Pnz+eVV17Bzc2NhQsXkpGRwQsvvICvry+BgYEsXryYn3766bpP886YMYPQ0FDLol79JeO2OlW5pbYvADOW/0NmVsnfmCiIegxLRVDZxZEPBrfCw8WRxEsZPP7ZZlIzKm7ruSK3krjvvvu47777yMrKYvXq1SxdupSFCxdaBp+UlBRCQkLo1asXkZGRdOzY0TIQS9mQlplFQrK5R6QSwyIiZcOPP/7IfffdR5s2bRg0aBATJ060fFa1alWqV6/O/Pnzuffee4t1/KioKBITE0lMTAQgOjqazExzVeDo0aPx8vKiVatWDB48mAkTJnD69GnLZDqrVq1i3rx5+U5MV5LUY1ikBDk4QuhN5iXtgrkf8eE15gri9GTYv9y8eARBjdvMk9a5V+xHNUVylMRkrRMnTsw19l9t06ZNRT6eJnItHQaDgYge9fnrwF8cSrjIkm2x9GsVYrN4TEoMSwVR08+dGfc3Z/iCTfx9IokJS3fxer+mtg7LJoqdsXVwcKBz58507tyZd955h23btvHXX3/Rvn17WrRoYc0YpZSdPJ9qeV1NiWERkTJh8uTJ3HLLLaxZs4YzZ87k+eOwXbt2fPjhh8U+/vTp03P1D16+fLml7+GgQYPw8vICYO7cuYSFhTF//nxmz55N3bp1WbhwIYMGDSr2uYtLPYZFSomLh7m9RJ1u5h7ER/6Ew6shOR4uxMGORebFr7550rrqN4NzZVtHLRVccnIyiYmJ+baOqF69eome214may1IzkSuUvKah3rTtYE/K/fE885v/9C7eTBODrbpA311KwmDWlFLOdetYQCjO9fm3d8P8NWm47Sv7Uvv5tVsHVaps1opb/PmzWnevLm1Dic2dCLxyh3qIC/b/1IiIiLXt3PnTiIjI6/5eVBQEPHx8cU+/pEjRwq1nbOzM1OmTGHKlCnFPpe1qGJYxAY8As39iBvfZ64ePrwGjq6FjEtweq952fwxhLaFmh0hoLG5RYVIKfnoo4944403OHDgwDW3Kelxw94maxXbe6pbXVbuief42RS+3hTDg21L9ubEtWRnq2JYKpYnu9Rh/eGzbDh8lv9b/DfNQ70J861YN691D0jyiL3cX9jD1REPVycbRyMiIoVxreqjHEePHrVU9VYU6jEsYkMGA/jVgzbDoe8cuHUchLQ29yPOyjBXFf8+Fb4fZa4mvnDK1hFLBTBv3jyGDx9OeHg4U6dOxWQyMXbsWJ5//nkCAgJo3rw5H330UYnHYW+TtYrtNQr2omeTQACift9PWqZtbmpr8jmpaBwdjLwzoDk+lZxITstk9BdbSc+0ba/v0qbEsOQRe7liWG0kRETKjltuuYWvvvoq38+SkpKYN29eqU/+ZmuRkZH4+/vTrFkzW4ciUrE5OEH1ttAhAvrMgpb/Be8w82cXE+Dvb+GHMbByIhxaBRmpBR1NpNjefvttunbtyq+//sqIESMAuOuuu3jllVfYtWsX586d4/z58yUeh71N1ir2YWzXuhgMEHs+lS83HLdJDOoxLBVRkJcb0/9j/nthR8x5In/da+OISpdmhZM84s6bE8OaeE5EpOyYNGkSt956K926dWPgwIEAbNmyhX/++Yfp06dz/vx5xo8fb+MoS5d6DIvYIVcvqN/TvJw9bE4EH/nTPGFd/B7zsmkeVG9vbjXhV0+tJsRq9u/fz6OPPgqYJ3oDSE9PB8DHx4dhw4bx3nvv8eSTTxb7HGVxslaxD3UDPOjVLJil22J5L/oA998UiquTQ6nGkKvHsH70SgXSpUEAD99Sg3l/HebDPw7TuX4A7Wr52jqsUqHEsORx4nIriWBv9RcWESkrWrZsaamAevjhhwF49tlnAahTpw6//PILDRo0sGWIIiK5ValhXloMghNbzEni2K2QmQaHos2Le4A5QVzjdqhcMf5Ak5Lj7u5umWzOw8MDBwcH4uLiLJ/7+voSExNzQ+coi5O1iv14sksdftgeS/yFND5dd5Rht9Us1fNnq2JYKrDn7qzH2oMJ7D15gYhvtvPr2A5Udin/adPy/xVKkeW0kgjyUsWwiEhZctttt7Fnzx527NjBvn37yM7OplatWrRq1QqDfrkXEXuV02qielu4dBaO/GFOEifFQvIp2PGVuQ9xYBOo1RlCbgIH/RkjRVe/fn127twJgKOjI82bN+eTTz5h0KBBZGVlsXDhQmrUqHFD5yiLk7WK/ajp5859LUP4enMMs1Yd5IE21Us1MaXEsFRkLo4OTP9PM/q89xcx51J49ac9vHJvE1uHVeL0G5XkYjKZ1GNYRKSMa9q0KU2bNrV1GDYXGRlJZGRkic8uLyJWVKkKNOwNDXrBmQPmBPHRvyAjBU7uMC8unlCjA9TuAp7Bto5YypA+ffrw1ltvkZqaiqurKy+99BJ9+/bFx8cHg8HAxYsXWbBgga3DlApuTJc6LN56grMX05m/9ggjO9UutXObNPmcVHCNq3kxslNt3vltP5+tP8YdjQO5rY6frcMqUZp8TnI5n5LBpXTzH9BBXmolISIiZVdERATx8fFs377d1qGISFEZDFC1DrQZDvfOgfajIaCR+bO0JNi7DJY9ZZ6w7vAayEy3abhSNjz99NPExMTg6mr+O6d3796sXr2a4cOH89hjjxEdHa1WDmJzoVUq0f8m89wIH/5xiOS0zFI799UVw3raTCqqUZ1r0zDIE4DnvtlBUmqGjSMqWaoYllxizqVYXodUqWTDSEREREREAEdnCL/VvFw4CQd/N1cSp56/asK6j81VxLU6g0+YrSMWO5SRkcH//vc/goKCqFOnjmX9rbfeyq233mrDyETyGtmpNl9vOk7ipQwW/O8IT3QsnarhnMnnVC0sFZmTg5E3+zejV9SfxJ5PZcbyf5jYq5GtwyoxqhiWXGLOXQLA0Wgg0FMVwyIiIiJiRzwCofmD0HsW3PY0BDUHDJBxCf75BX5+Fn79P3PyOCPV1tGKHXFwcKBr166WieBE7Fk1bzf6tbpcNbzmEBdLqWo4p2JY/YWlomsQ5Mljt9cCYMH/jrAr9ryNIyo5SgxLLjkVw8HebjjoNqGIiJRhkZGR+Pv706xZM1uHIiLW5uAIoW2g0wvQ611ofB9U8jV/duYArP8AFo+ADR/C2UO2jVXsgtFoJDw8nOTkZFuHIlIoT3SshaPRwLlLGSxcd7RUzmlSYljE4omOtQnxcSPbBC8v+ZvsbNP1dyqDlBiWXHISwyE+mnhORETKNvUYFqkg3P2gaX/oFQW3PwvVWoPBCJlpcGAl/PKCuYr40Gr1Iq7gxo0bx5w5czh9+rStQxG5rtAqlejXKgQwVw1fSi/5qmFLKwllikRwc3Zg4j3mFhJbjiXyzeYYG0dUMtRjWHI5ftbcSiLUR/2FRURERKQMMRqhWivzcumsuQ/xwd/h4mlzFfGZA7BlgbkPcZ1u4O5v64illCUlJVGpUiVq165N3759CQ8Px80td0GMwWAgIiLCRhGK5DayU22+2RzDmYvpfLruKCM61CrR82Vlq2JY5GpdGwbQtYE/K/fEM+3nPXRrGIBPZWdbh2VVSgxLLqoYFhEpGzp37lzkfQwGA7/99lsJRCMiYmcqVYHGfaFhH4jbBvtXQOxWSE+GPd/Dnh8guIU5QRzcApQEqRCef/55y+tPPvkk322UGBZ7ElqlEn1bVmPRphjmrDnE4JvDcXN2KLHzqcewSF4T7mnEH/sTOHcpgxkr/mFKn8a2DsmqlBgWC5PJZJl8LqSKEsMiIvYsOzsbQxF/ac/pGyciUmEYjVCtpXlJjjcniA/+bk4Qx24xL5X9oE53qNkRXD1tHbGUoMOHD9s6BJEiG9WpDt9uOUFCcjqfrT/KsNtqlti5cn5VVF5Y5IrQKpV4omNt3lr5D19sOMaw22oQ5lvZ1mFZjd0mhjMyMpgyZQrz588nPj6eunXr8vzzz/Pggw9ed1+TycTMmTN57733OHr0KGFhYYwaNYrRo0fn+iM6Li6Od955h40bN7Jp0yaSkpL44osvGDBgQJ5jDh06NN+7ytWqVSMmJnefkfDwcI4ezdscfuDAgXz66aeF+fJt4tylDC6mZwEQolYSIiJ2bdWqVbYOQUSkbHH3hxYDocl/4Pg6+OdXc3uJi6dh22ew4ysIaw/17oQqJZd4EdsJCwuzdQgiRVbdtxL3tqjGN5tjmL36EAPbhpVY1bAqhkXyN+y2Gixcd5SE5DTeXP4PMx9oYeuQrMZuE8MjRoxgwYIFjBw5kiZNmrBkyRIGDhxIZmYmQ4YMKXDfyZMnM3HiRAYPHsyzzz7L6tWrefLJJ0lMTGT8+PGW7fbt28frr79OzZo1ad68OWvWrCnwuE5OTsybNy/XusqV879L0LRp0zyPINWsad+/YOZUC4N6DIuISNkXGRlJZGQkWVlZtg5FROyJozPU6GBezh6Cf5bD0T8hKwMOrzEv/g2gXk/zRHaahanccHBwYOHChdcsNvrqq6948MEHNW6I3RnVqTaLt54gITmNzzcc45Fba5TIeSyTzykvLJJLZRdHxnSpzfilu/h+eyyP3l6TRsFetg7LKuwyMbx161bmz5/P5MmTefnllwEYNmwYnTt3JiIiggEDBuDsnH+z55MnTzJt2jQefvhhPvroI8u+Dg4OvPrqq4wYMYLAwEAAWrVqRUJCAr6+vqxatYpOnToVGJfRaGTQoEGF+hqCgoIKva29yOkv7ORgwN/DxcbRiIhIcSUnJ5OYmEh2dnaez6pXr26DiG6Mu7t7rveXLl0iMjKSp59+usD9IiIiiIiIICYmhtDQ0JIMUUTKqio14ebHoMUgOLzaXEWcfAri95iXyn7mCuKancBZhRNl3fVaKhWnTZNIaQivWpnezYP5bssJZq8+yMC21XF1sn7VsCqGRa5twE3V+fCPQxw/m0Lkr/uY/1AbW4dkFXZ5+3vRokUYjUZGjhxpWWcwGBg1ahTx8fFER0dfc9+lS5eSlpbG6NGjc60fPXo0aWlpLF261LLOw8MDX1/fIsWWnZ1NUlJSofo0ZmRkcPHixSId35ZyKoarebth1C1CEZEy56OPPqJevXp4eXkRFhZGjRo18ixlUXJysmXZv38/RqORvn372josESlPXNyh/l1w99vQIQL8G5rXXzwNWxbAksdh8ydw4ZRNw5QbV1Did/369fj4+JRiNCKFN7pzHYwGOH0hjS82HCuRc+TkOXSDRCQvZ0cjT3erB8CqfadZd+iMjSOyDrtMDG/evJlatWpRpUqVXOvbtm0LwJYtWwrc18XFhaZNm+Za36JFC5ydnQvc93rS09Px9PTEy8sLHx8fRowYQWJiYr7brl69mkqVKuHu7k5oaChvvPFGvpVb9uT4WXPFcGgVVUOIiJQ18+bNY/jw4YSHhzN16lRMJhNjx47l+eefJyAggObNm1uepCnLPvvsM9q1a1dmk9wiYueMRghpDV0nwB2vQY3bwegImamw7yf44UlYEwmndl+ZpUns2jvvvEPNmjUtbf3Gjh1reX/1UqVKFd59913uuusuG0cskr8aVSvTu3k1AGavPkhqhvVbnqiVhEjBejULpn6gBwCRv+4rF5N722ViODY2lqCgoDzrg4ODLZ8XtG9AQADGf/UCMxqNBAQEFLhvQYKCgnjmmWf46KOP+PLLL+nXrx9z586lS5cupKWl5dq2adOmTJw4kW+++YaPPvqIOnXq8Nxzz+WpYr6WpKQkYmJiLEtcXFyxYi6qnIrhEB+3UjmfiIhYz9tvv03Xrl359ddfGTFiBAB33XUXr7zyCrt27eLcuXOcP3++2MdPTk5mwoQJ9OzZEz8/PwwGA6+99lq+22ZkZDB+/HiqV6+Oq6srTZs25fPPPy/2ua+2cOHC6841ICJiFVVqQLsnoHcUNL4PXDwBE8Rsgt8mwfKX4PgGJYjtXNWqValXrx716pmrvIKCgizvc5b69evTpUsXXn31VWbNmmXjiEWubVTn2hgNcCopja82Hrf68dVKQqRgRqOBZ+8wjyebj55j45FzNo7oxtllj+GUlBRcXPL2uDUajTg5OZGSklLkfQFcXV0L3Lcg06ZNy/X+/vvvp0GDBjzzzDMsXLiQYcOGWT77/vvvc2370EMP0bt3b95//32efPJJ6tatW+C5ZsyYwaRJk4oV5404djYnMayKYRGRsmb//v08+uijAJabo+np6QD4+PgwbNgw3nvvPZ588sliHT8hIYHJkycTEhJCixYtWLFixTW3vZEJZAuyc+dO9u3bx3/+859iH0NEpMjcfKBpf2jYB47+Za4cTjwGZw7AH2+CRxA07AXhHcDBLv+8qtAGDhzIwIEDAejUqRMvvfQSXbp0sXFUIsVTy8+de5oFs3RbLLNXH2RAm1BcHK3Xa1gVwyLX16meP/UDPdh78gJz1hykTY0q19/JjtllxbCrq2ueKlww9/fNyMjA1dW1yPsCpKamFrhvUY0ZMwYnJydWrlxZ4HYGg4Fx48ZhMpn4/fffr3vccePGcfz4ccuyYcMGa4V8TdnZJo5fnnwuzFeJYRGRssbd3d3SssjDwwMHB4dcT5z4+voSExNT7OMHBQVx4sQJjh8/zpw5c665Xc4EshMnTmTmzJkMHz6cZcuW0bFjRyIiIizJaoDu3bvj6uqa7/L444/nOfaCBQvo1asX3t7exf46RESKzdEZanWCO9+Aji9c6UN8IQ7WfwDfj4Y9P0BG8QpRpORFR0crKSxl3qhOtTEYIO58Kt9sLv7vdvlRj2GR6zMYDAy/zdyeaOWeeA7EJ9s4ohtjl4nh4ODgfNsn5LSByGkpca19T506laefb3Z2NqdOnSpw36JycnIiMDCQs2fPXnfbnFngC7Otp6cnISEhliW/thrWdjIplfRM8zULq1K5xM8nIiLWVb9+fXbu3AmAo6MjzZs355NPPiEjI4PU1FQWLlx4Q315XVxcCjWGFmUC2eXLl5Oamprv8v777+c6bnZ2Np9//jmDBw8u9tcgImIVBgMENzf3Ie4+FUJuAgyQcha2fgpLnoDtX0Jq8dv3SMn466+/eO+993Kt+/zzz6lXrx7+/v48+eSTdj8vjEidAA96NjHnCGZFHyQjy3rfs9mXS4aNdpkpErEf9zQLJsDT3K3goz8P2TiaG2OX/9xbtmzJwYMH8yRR169fb/m8oH3T0tLYsWNHrvVbt24lPT29wH2LKjU1lbi4OPz9/a+77aFD5m+UwmxrC0fPXLK8rq6KYRGRMqdPnz789NNPpKamAvDSSy/xxx9/4OPjg5+fH2vXruWFF14o8ThuZALZgvz2229kZGRw55133nCMIiJWU7UOdHgG7noTanYyT1SXcQl2LYalo8yJYiWI7caECRNYs2aN5f2+ffsYOnQoRqOR1q1bExUVxcyZM20YoUjhjO5cG4ATiSks3nLCase90kpCFcMiBXF2NPLQLeaim2+3nOD0hfw7F5QFdpkY7t+/P9nZ2bka/5tMJqKiovDz86NTp06Aud/h3r17uXTpSlKzd+/eODs7ExUVleuY7777Ls7OzvTu3bvI8aSlpeU6R45XX32VzMxMevToYVmXmJiYZ1bCzMxMpk2bhoODA127di3y+UvDsbMXAfByc8LLzcnG0YiISFE9/fTTxMTEWFom9e7dm9WrVzN8+HAee+wxoqOjGTRoUInHcSMTyBZk4cKFDBgwAEfH6/fvtNUkriJSgXlVg5sfg3tmQv27wdEFstLNrSWWjoItC5UgtgN///235UYlwBdffEGlSpVYv349P/30E4MHD2bevHk2jFCkcOoHetKjUQAAUdEHyLRS1bAmnxMpvAfaVKeyswPpmdks+N8RW4dTbHY5O0KrVq0YPHgwEyZM4PTp05aJa1atWsW8efMsk8tFRUUxadIkoqOj6dixI2D+w/O5555jypQpZGRk0KFDB1avXs3ChQsZP358nj9Wp06dCsDhw4cBWLx4MQcOHADM1VYAcXFxtGvXjr59+1K3bl0MBgMrV67khx9+4Pbbb+eBBx6wHG/JkiW8+eab3HXXXdSoUYNz587x5Zdfsn37dl588UXCw8NL8tIVW07FsPoLi4iUH7feeiu33nprqZ7zRiaQLciCBQsKva2tJnEVEaGyL7QcDA17w94f4Z+fITMN9i6D/cuhTnfzRHWuXraOtEI6f/48Pj4+lve//PIL3bp1w9PTEzCPm99++62twhMpktGd6/DrrlMcO3uJpdtiua9VyA0fM6fGTXlhkevzcnPigTbVmfvnYRauO8rjHWtRydku06wFstuI586dS1hYGPPnz2f27NnUrVuXhQsXFqraadKkSfj4+PDee+/x5ZdfEhoayowZMxg7dmyebV9++eVc7xctWsSiRYuAK4lhb29vunXrxooVK/jkk0/IzMykZs2aTJo0iWeffTZX9VLTpk2pWbMmn332GfHx8Tg7O9OkSRM+/fRTy2y49ujoWXNiuHoVJYZFRKT4bmQCWWsZN24cw4YNY/bs2cyePZusrCwSExNL/LwiIhauntD8AWhwtzlBvO+nqxLEv0K9u8wJYmfN7VGagoKC2L17NwAnT55k06ZNPPLII5bPk5KSCvVkiog9aFzNi64N/Fm5J56o6AP0aVENB+ONZXRVMSxSNA/dWoOP1x4h8VIG32+LZUCb6rYOqcjsdtRzdnZmypQpTJky5ZrbTJw4kYkTJ+ZZbzAYeOqpp3jqqaeue55/t33Ij7e3d6ErlVq2bMnSpUsLta09OaaKYRGRMm/u3Ll8+OGHHDp0KN/JTg0GA5mZmSUaQ3BwMEePHs2zvjATyFqLp6cnnp6eTJ06lalTpxITE0NoaGiJn1dEJA8XD2g2AOrflTtBvHsJHFgBje6FOj3A0dnWkVYIffv2JSoqirS0NDZu3IiLiwu9evWyfL59+3Zq1qxpwwhFimZ05zqs3BPP4YSLLNsRS+/m1W7oeFd6DFshOJEKoJq3G13q+7N89ym+2nRciWEpu45drhgOq6KqBRGRsujFF1/k9ddfp2nTpgwcODDXo7KlqWXLlvz++++cPXs21wR0hZlA1toiIyOJjIwkKyur1M4pIpKvqxPEu5bAP79A+kXz5HT7foYm/aDG7WB0sHWk5dqkSZM4deoUn376KZ6ensybN4+AAHOf1qSkJL799ltGjRpl4yhFCq9ZqDcd6/mxat9p3v39APc0DcZ4A1ldVQyLFN39N4WyfPcpth5LZP+pC9QJ8LB1SEWixLBw/lIG51MyAAhVKwkRkTJp3rx59O7dm++++86mcfTv35833niDWbNmWVoy5TeBbGmIiIggIiJCFcMiYj9cPMw9iOv1hJ1fw6FVcOkMrP8A9iyDZg9ASGs1+CwhlStXZuHChfl+5u7uzokTJ6hUSX8PSdkyunMdVu07zYH4ZH7++yR3Nc07CXBh5TxRbdDPIJFCu72uH/4eLsRfSOOrjcd56e6Gtg6pSJQYFo6evWh5rVYSIiJlU3JyMnfccUeJniMqKorExERLv97o6GhLa4rRo0fj5eVV6AlkS4MqhkXEblX2hZsfM/cg3v4lxGyEpBPwx3QIaAQt/ws+YbaOskIxGo14eWlSQCl7WoX5cGvtqvx5IIF3f9/PnY0Di101rFYSIkXn6GCkX6sQZq06yOKtJ3j+zvo4OhhtHVahKTEsHL3cX9jZ0UigZ8lPCiQiItZ36623smPHjhI9x/Tp03P1D16+fDnLly8HYNCgQZY/qG9kAllrUsWwiNg9rxDo8Awk7De3lTi9F07tgp+fgzpdoUl/80R2YjWJiYl88cUXln78/55zxmAw8NFHH9koOpHiGdOlDn8eSGDvyQss332KOxoHFus4aiUhUjz3XU4Mn7mYzl8Hz3B7XT9bh1RoSgwLx8+ZE8MhPm431I9IRERsJyoqii5dutCiRQsefvjhEnkE8MiRI4XarjATyJYGVQyLSJlRtQ50nQjH1pkTxJcSYP8KOPIXNO1vnqDOWHaqj+zVb7/9Rt++fblw4QKenp759uPXI/RSFrWpUYWba1Zh3aGzzPxtPz0aBRTre1kVwyLFU8vPnUbBnuyKTeL7bbFlKjGs3y6EmHMpAIT6qI2EiEhZ1atXLzIyMhgxYgTu7u7Uq1ePhg0b5loaNWpk6zBLVUREBPHx8Wzfvt3WoYiIXJ/BAGHt4O63zMlgB2fIuASb58OvL8KZg7aOsMwbN24cvr6+bN26lcTERA4fPpxnOXTokK3DFCmWMV3qALA7Lolfd50s1jHUY1ik+Ho1CwZg+a6TpGaUncIUJYbFkhgO8XGzcSQiIlJc/v7+1KtXjw4dOtCmTRuCg4MJCAjItfj7+9s6zFIVGRmJv78/zZo1s3UoIiKF5+gMje+De96BsPbmdecOw6//B5vmQfrFgveXa9q3bx9jx47VuCDlUvtaVbm5ZhUA3lqxn+xs03X2yOtKKwmrhiZSIdx9OTF8IS2TVfvibRxN4amVhBBjaSWhimERkbJq1apVtg7B7qjHsIiUaZWqwC1PQs1OsHEuJJ+Cf36F4xug1UNQva2tIyxzatSoQUpKiq3DECkxT3evx39m/499py7w48447rmcqCqsK60klBkWKapq3m7cFO7DxiPn+GnnSe5oHGTrkApFFcMVnMlk4oQqhkVERERE7FNQU+g53VxFbHSElHPw5wz46x1Iu2Dr6MqUF198kdmzZ3PmzBlbhyJSIm4Kr8JtdaoC8PbKf8gqYtWwJp8TuTHdG5onflz9z2kys7JtHE3hqGK4gjudnEZapvmbVYlhEZGy49ixYwBUr1491/vrydleRETKEEdnc9/hsFtgwwdweh8cXQvxe6DNCKjW0tYRlglxcXFUrVqVOnXq0L9/f8LCwnBwcMi1jcFgICIiwkYRity4cd3q8sf+BA6evsjSbSfo2zKk0PtezgujvLBI8XSq788rP+3hfEoGW48nclN4FVuHdF1KDFdwOf2FQa0kRETKkvDwcAwGAykpKTg7O1veX09WVtmZCOFGRUZGEhkZWaG+ZhEp57yqQZeJsO8n2P6FuXp49etQqzO0GAzO+n2+IM8//7zl9Zw5c/LdRolhKetaVPehc31/ft8bzzu/7adXs2AcHQr3sHhOX2JVDIsUTy2/ylSvUoljZy/x+954JYbF/uUkhl0cjVR1d7ZxNCIiUljz5s3DYDDg5OSU671coR7DIlIuGY3Q4G4Ibg7/mwVnD8LB3+HkTrj1KfCtZesI7dbhw4dtHYJIqRjXrS6/743n6JlLfLf1BP1bF+73IEuPYTUdFSkWg8FA5/r+zF97hOi98Tx3R31bh3RdSgxXcFcmnnNTQkFEpAwZOnRoge9FRKSc8wqBbpNh9xL4+zu4eBpWjIcWg6DuHXoWPB9hYWG2DkGkVDSu5kX3hgEs332KqN8PcG+LajgVompYPYZFblyny4nhvScvEHPukt0/na/7QBVcjGXiOfv+RhURERERkX9xcIQm/aDbJKhUFbIzYfN88+R06ZdsHZ3dSk5OZtmyZbz77rtERUWxbNkykpOTbR2WiFWN6VIHgGNnL7F464lC7WO6nBhW0ZhI8bWtUQVXJ3O69a8DCTaO5vpUMVzBXUkMa+I5EZGybPLkyQV+bjAYcHV1JSQkhNtvv53g4OBSisx21GNYRCqMqnXgztdg3Ww4sQmOb4CkWOjwLHgE2Do6uzJr1ixeeOEFkpOTLUkwAA8PD1577TUef/xxG0YnYj2Nq3nRrWEAK3af4t3f99O7eTAujg4F7mNpJaG8sEixuTo5cFN4Ff7Yn8BfB85w/032Pfm3EsMV3JVWEqoYFhEpyyZOnGip7rj6D10gz3oHBwcee+wxZs6cWa4rQtRjWEQqFBcP6PAM7F0GWz+D8zHw64vmvsOBjW0dnV34/PPPGTVqFG3atGHs2LE0bNgQk8nEnj17eOeddxg1ahTe3t488MADtg5VxCrGdq3Dyj2nOH42hU/XHeORW2sUuL1aSYhYR/taVfljfwJrD57BZDLZ9d9caiVRgZlMJk6oYlhEpFyIjY2lefPmDBo0iI0bN3L+/HnOnz/Phg0bGDhwIM2bN2f//v1s3ryZAQMGMGvWLN544w1bhy0iItZkMECDe6Djc+DkBunJEP2qeXI6Yfr06bRv354///yTAQMG0LRpU5o1a8aAAQP4448/aNeuHZGRkbYOU8RqGgV7cW/zagC8+/t+zqdkFLi9KoZFrOOW2r4AJCSnsT/evlsVKTFcgZ25mE5aZjYAwd5KDIuIlGVjxoyhZs2afPLJJ7Rq1QoPDw88PDxo3bo1CxYsoEaNGrz44ou0aNGChQsX0qVLFz7++GNbhy0iIiUhuAV0fwU8AsGUBes/gN1LbR2Vze3du5cBAwbg6Jj3wVlHR0cGDBjA3r17bRCZSMl5ukc9nB2NJF7KYNaqAwVuqx7DItbRKNgLT1fzWGPvfYaVGK7ATiWlWl4HernaMBIREblRv/76K126dLnm5126dOHXX3+1vL/77rs5cuRIKUQmIiI24VUNuk819x8G2PY5bP0U/tVuqCKpXLkycXFx1/w8Li6OypUrl2JEIiWvmrcbD90SDsDHfx2xtJPMz5VWEqURmUj55WA00K6WuWp47cEzNo6mYEoMV2DxSWmW137uLjaMREREbpTRaGTHjh3X/HzHjh25qj9MJhPu7u6lEZqIiNiKiwd0egmCmpvf7/kBNsypsMnh7t27M3PmTFatWpXns9WrV/Puu+/So0eP0g9MpIQ90bE2PpWcSM/M5s3l/1xzuyutJJQZFrlRt9SuCsCGw2fzzAFjT5QYrsByKoZ9Kzvj7KhvBRGRsqxPnz7MmTOH6dOnc/HiRcv6ixcvEhkZyYcffkifPn0s69euXUvdunVtEGnpiYyMxN/fn2bNmtk6FBER23FyhQ4RENbe/P7g77B5foVMDr/22mt4e3vTpUsXWrduzaBBgxg0aBCtW7emc+fOeHl58dprr9k6TBGr83JzYnRn89MDi7ee4O8T5/PdTpPPiVhPsxBvAM6nZBBzeX4ve6RsYAUWf8FcMeznoWphEZGy7q233qJdu3Y8++yzeHt7ExoaSmhoKN7e3jz33HO0bduWt956C4DU1FQ8PT0ZM2aMjaMunG3btnHLLbfg5eVFjRo1mDNnTqH2i4iIID4+nu3bt5dwhCIids7BEdqPgfDbzO//+QV2LLJtTDYQGhrKtm3beOqpp7h48SLffvst3377LRcvXmTcuHFs27aNkJAQW4cpUiIG3RxG9SqVAJj28558KxhzVikvLHLj6gV64HC5L8uu2CQbR3NtSgxXYDkVwwGe6i8sIlLWeXt7s2bNGr799lsefvhhGjVqRKNGjXj44Yf57rvv+PPPP/H29gbA1dWVDz/8kAEDBtg26EIaPHgwPXr04Ny5cyxevJhnnnmGnTt32josEZGyxWCAmx+HkJvM73d9VyEnpPP19WX69Ons2bOHlJQUUlJS2LNnD5GRkfj6+to6vGKLioqiRYsWODo6MnHixFyfubu751qMRiNvvvmmbQIVm3F2NPLsHfUA+OvAGVb/czrPNqoYFrEeVycHavmZ+9bvjs2/St8e5J2OVSqMU5d7DAd4qmJYRKQsS01N5Y033uDmm2/m3nvv5d5777V1SFZ15MgRHnjgAYxGI82bN6dx48bs3buXJk2a2Do0EZGyxegAtzwJq1+HkzvNE9K5+UCNDraOrFRcvHiRM2fOUL169Xw/P3bsGFWrVqVSpUqlHNmNq1atGpMnT2bBggV5PktOTra8jouLIzQ0lL59+5ZmeGIn7moSxNzQw2w7nsi0n/ZyWx0/S0UjXN1j2EYBipQzjYK9+OdUsiqGxT6dvqCKYRGR8sDV1ZVp06Zx7NixEjtHcnIyEyZMoGfPnvj5+WEwGK7ZhzEjI4Px48dTvXp1XF1dadq0KZ9//nmxzz169GgWLlxIZmYmmzZt4tChQ7Rv377YxxMRqdAcnOC2Z8C3tvn9+g8gYb9tYyolTz31FL17977m53369OGZZ54pxYis59577+Wee+7By8urwO0+++wz2rVrR40aNUopMrEnBoOBF3s2AGDfqQss2Xoi1+eqGBaxrkbBngDsjlNiWOxQTsWwv3oMi4iUec2bN+fgwYMldvyEhAQmT57Mzp07adGiRYHbjhgxgldeeYU+ffrw7rvvEhoaysCBA/OtYiqMnj178umnn+Lq6krbtm2ZNGkS1apVK9axRESEKxPSVaoK2ZmwZjqkJNo6qhK3YsWKAp+quffee/n111+LfXxb3kQtrIULFzJkyJASP4/YrzY1qtC1gT8AM1b8Q2pGluWznL7DBiWGRayi4eXEcNz5VM5eTLdxNPlTYriCyso2cTr5cmJYFcMiImXea6+9xty5c/n5559L5PhBQUGcOHGC48ePFzj529atW5k/fz4TJ05k5syZDB8+nGXLltGxY0ciIiJIT7/yC1H37t1xdXXNd3n88ccBOHv2LD179mTatGmkpaWxb98+3n77bX766acS+TpFRCoMN2+4PcJcQZyaCGvfhexsW0dVouLi4ggKCrrm54GBgcTGxhb7+La8iVoYO3fuZN++ffznP/8psXNI2RDRoz5GA5xITOHTdUct63N+BKiVhIh1NAzytLzeZad9hpUYrqDOXkwn63IDIbWSEBEp+15//XW8vb25++67CQ8Pp1OnTvTs2TPXctdddxX7+C4uLgQHB193u0WLFmE0Ghk5cqRlncFgYNSoUcTHxxMdHW1Zv3z5clJTU/Nd3n//fQAOHjxI5cqVuf/++3FwcKB27drcddddLF++vNhfi4iIXOYTDq0fMb8+9TfsXmLLaEqcn58fu3btuubnu3btskzUWhy2uolaWAsWLKBXr1439DVK+VAv0IP7WoYAEBV9gPMpGYBaSYhYm3clZ6p5uwHYbZ9hJYYrqFNJqZbXaiUhIlL27d69m8zMTKpXr47BYODIkSPs2bMnz1LSNm/eTK1atahSpUqu9W3btgVgy5YtRTpevXr1SElJ4dtvv8VkMnH06FGWLVumiedERKylVqcrk8/9/S2cO1rw9mVYz549mTNnDmvXrs3z2bp165gzZw49e/Ys9vFtdRO1MLKzs/n8888ZPHhw0b4oKbee6lYXZ0cjiZcy+GC1uR2ZZfI5ZYpErKbB5arhf05dsHEk+XO0dQBiG/EXriSG/ZQYFhEp844cOWLrEACIjY3N9zHdnD+Ui/qIrqenJ19//TXPPfccDz30EB4eHgwcOJCHHnromvskJSWRlHTljnxcXFyRzikiUuG0GgqndsGlM7BuFnR/BRzK35+KkyZN4qeffqJDhw7ceeedNG7cGIPBwM6dO/n5558JDAxkypQpJR5HYW6i9ujRo0jHzMzMJDMzk6ysLDIzM0lNTcXJyQkHBwcAfvvtNzIyMrjzzjsLdTyNpeVfsLcbD7UP54M1h5j312GGtAtXj2GREhDuWwmAY2cu2TiS/JW/0V4KJf7yxHNV3Z1xctDtQBERsY6UlBRcXPLecDQajTg5OZGSklLkY3br1o1u3boVevsZM2YwadKkIp9HRKTCcq4MbR+F6Ffh3BHYvxzqF79y1l4FBgayadMmnnvuOZYsWcKPP/4ImG9CDh48mGnTphEYGFjicVj7JirA1KlTc419r7zyCh9//DFDhw4FzJPODRgwAEfHwqUANJZWDI93rMUXG46RlJrJO7/9c1UrCRsHJlKOhF1ODB89W3BieO4fhzh65hK31/Wja8OA0ggNUGK4wjp1OTHs56H+wiIi5U1ycjKJiYlk5zOJUPXq1Uv03K6urqSlpeVZn52dTUZGBq6uJT/ujBs3jmHDhjF79mxmz55NVlYWiYmJJX5eEZEyLagZhN8GR/6AnV9D+C3g6mXrqKwuICCA+fPnYzKZOH36NCaTCX9//1KtkCyJm6gTJ05k4sSJ1/y8qJPa5YylOeLi4mjTpk2R4xL75l3JmSc61ea1n/fy1cbjBF/uhaoewyLWU923MgCnL6RxKT2TSs75p2J/3XWSjUfO4ehgKNXEsEpFK6hTl1tJBHiqjYSISHnx0UcfUa9ePby8vAgLC6NGjRp5lpIWHByc7+OmOdVPhem9eKM8PT0JCQlh6tSpJCQksHPnzhI/p4hIudD8QXB0gYxLsP1LW0dTogwGA/7+/gQEBJT6Y/P2cBP1enLG0pwlvwpnKR+Gtg8n0NOVbBPEnDPflFBiWMR6wqpUsrw+do2qYZPJxD+nkgGo4+9RKnHlUGK4goq/PPlcgCqGRUTKhXnz5jF8+HDCw8OZOnUqJpOJsWPH8vzzzxMQEEDz5s356KOPSjyOli1bcvDgQc6ePZtr/fr16y2fl5bIyEj8/f1p1qxZqZ1TRKRMq1QFGvU1vz60Ci6ctGk45ZU93EQVyeHq5MCYLnVyrVNeWMR6qvm44XC5P8vRa/QZTkhO53xKBgB1AtxLLTZQYrjCOp2cDmjiORGR8uLtt9+ma9eu/Prrr4wYMQKAu+66i1deeYVdu3Zx7tw5zp8/X+Jx9O/fn+zsbGbNmmVZZzKZiIqKws/Pj06dOpV4DDkiIiKIj49n+/btpXZOEZEyr96d5hYSpmzYtcTW0ZRL9nQTVQSgV/NgXByvpIdUMSxiPU4ORoK9zUWZ15qAbn/8BcvrOv5KDAOQkZHB+PHjqV69Oq6urjRt2pTPP/+8UPuaTCbeeecd6tati4uLC3Xr1mXmzJmWGTZzxMXF8fzzz9OlSxe8vLwwGAx8+WX+j0wNHToUg8GQZwkJCcl3++XLl3PzzTfj5uZGYGAgY8aMITk5uWgXoQQlXLgy+ZyIiJR9+/fvp3fv3oC5RyFAerr5JqCPjw/Dhg3jvffeu6FzREVFMXXqVKKiogCIjo5m6tSpTJ061ZJ0btWqFYMHD2bChAk8+eSTzJ07l3vuuYdVq1bx+uuv59tTsaSoYlhEpBgcXaDBPebXh9dAcrxt4ymH7OkmqgiAu4sjXRr4W95r8jkR6wq/3Gf4yJmL+X6+/3IbCT8PF7wrlW6ezm4nnxsxYgQLFixg5MiRNGnShCVLljBw4EAyMzMZMmRIgftOnjyZiRMnMnjwYJ599llWr17Nk08+SWJiIuPHj7dst2/fPl5//XVq1qxJ8+bNWbNmTYHHdXJyYt68ebnWVa5cOc92v//+Oz179qRt27a8/fbbHDt2jBkzZrBr1y5WrlxZ6j2s/s1kMpGQbE4M+7qrYlhEpDxwd3e3TDbn4eGBg4NDrsdUfX19iYmJuaFzTJ8+naNHj1reL1++nOXLlwMwaNAgvLzMkxTNnTuXsLAw5s+fz+zZs6lbty4LFy5k0KBBN3T+ooqIiCAiIoKYmBhCQ0NL9dwiImVa7W6w+3tIS4K9P0Lrh2wdUZkRFRVFYmKiZdLT6OhoMjMzARg9ejReXl65bqKePn3a8vfuqlWrmDdvXqneRBXJ0atZMD/tNLePiTufauNoRMqX6pf7DF+rx3BOxXBpVwuDnSaGt27dyvz585k8eTIvv/wyAMOGDaNz585EREQwYMAAnJ3zz6CfPHmSadOm8fDDD1t6KQ4bNgwHBwdeffVVRowYQWBgIGCuakpISMDX15dVq1Zd986s0Wgs1B+1Tz31FHXr1iU6OtoSZ506dXjooYdYunQpffr0KeylKBEX07NIyzQnD6oqMSwiUi7Ur1/fMsmao6MjzZs355NPPmHQoEFkZWWxcOHCG5587siRI4XaztnZmSlTpjBlypQbOt+NioyMJDIykqysLJvGISJS5ji5Qp3u8Pc35qrh5gPBUU8aFkZZu4kqkqNjvSsVw6cv5J0cUUSKL8zXnBi+Vo/h/ZaJ50o/MWyXrSQWLVqE0Whk5MiRlnUGg4FRo0YRHx9PdHT0NfddunQpaWlpjB49Otf60aNHk5aWxtKlSy3rPDw88PX1LVJs2dnZJCUl5WlLkWPfvn3s2LGDESNG5EpeDxo0CG9vb7766qsina8kJFz1Q97PQ7/giYiUB3369OGnn34iNdVc4fHSSy/xxx9/4OPjg5+fH2vXruWFF16wcZSlSz2GRURuQK1OgAEyLsHxdbaOpsw4cuQIJpMp3yU8PNyyXc5N1OPHj5OWlsbOnTuVFBabcnVyYFy3ulR2dmB05zrX30FECq16FXO3gROJKWRkZef5/ED85cRwgEepxgV2mhjevHkztWrVokqVKrnWt23bFoAtW7YUuK+LiwtNmzbNtb5FixY4OzsXuO/1pKen4+npiZeXFz4+PowYMcLyiNDV5wdo06ZNrvWOjo60atWqUOdPSkoiJibGsuQ3Y+2NyGkjAeBbWRXDIiLlwdNPP01MTAyuruaJDXr37s3q1asZPnw4jz32GNHR0fqDU0RECq9yVQhuYX59YKVtYxGRUjGmSx12TuzBrXWq2joUkXIlp2I4K9tEbGJKrs/OJKdx5qJ5bhi1krgsNjaWoKCgPOuDg4Mtnxe0b0BAgGXinRxGo5GAgIAC9y1IUFAQzzzzDK1atQJgxYoVzJ07l82bN7N27VpLH6ic418r/pyZZgsyY8YMJk2aVKw4CyMnMexoNODl5lRi5xEREdu69dZbufXWW20dhs2olYSIyA2q3QVit8DpfZAUC57Bto5IREqYUTPPiVhdTo9hMPcZDvO9Ml/Z/svVwmCbimG7TAynpKTk23DfaDTi5ORESkpKPnsVvC+Aq6trgfsWZNq0abne33///TRo0IBnnnmGhQsXMmzYMMv5gXxjKOz5x40bZzkeQFxcXJ4K5BuRkGy+E+Hr7qwf+iIiUm5p8jkRkRsU3AKc3SE9GWI2QcNeto5IRESkzKns4ohPJSfOXcrIUzF88LQ5MVylsjNVKpd+u1e7bCXh6upKWlreZufZ2dlkZGRYHpMtyr4AqampBe5bVGPGjMHJyYmVK688WpVz/PxiKOz5PT09CQkJsSz5VR/fiJyKYU08JyJSvnz//ff07t2bxo0bU61aNYKDg3Mt1apVs3WIIiJSlhgdILi5+XXsdVri/fkW7FoCl86WdFQiIiJlTrC3GwAnzuVODJ86n3r5c+vlK4vCLiuGg4ODc83kmiOnTUNOS4lr7fvbb7+RnZ2dq51EdnY2p06dKnDfonJyciIwMJCzZ6/88pNz/Li4uDyzv8fGxlr1/MWVkxj2VWJYRKTcmDx5MpMmTcLb25tmzZpRp44mDVErCRERKwhuCUf+NLeTSLsALvk85no+Bo6tMy9V60ClKnm3ERERqcCqebuxKzaJmH9VDJ++nKPz91Bi2KJly5b8/vvvnD17NtcEdDn9eVu2bFngvnPnzmXHjh00b97csn7r1q2kp6cXuG9RpaamEhcXR4cOHXKdH2DDhg20b9/esj4zM5MtW7bQvXt3q52/uM5cbiVR1b30S9RFRKRkREVF0aVLF3744YdrtlSqaNRKQkTECoKbg8EIpmyI2w7h+fSuP355HhUXT/BrUKrhiYiIlAXVfMwVw/9uJRGflJMYts3fcHbZSqJ///5kZ2cza9YsyzqTyURUVBR+fn506tQJgISEBPbu3culS5cs2/Xu3RtnZ2eioqJyHfPdd9/F2dmZ3r17FzmetLS0XOfI8eqrr5KZmUmPHj0s6xo0aEDjxo2ZM2cO6enplvWffvop586do3///kU+v7XlVAz7qWJYRKTcyMjI4L777lNSWERErMu5MvjVN78+sTn/bY5vMP83pDUY7fJPTBEREZuqltNK4t+J4Qu2TQzbZcVwq1atGDx4MBMmTOD06dM0adKEJUuWsGrVKubNm2f5ozcqKopJkyYRHR1Nx44dAXMrh+eee44pU6aQkZFBhw4dWL16NQsXLmT8+PF5+vVOnToVgMOHDwOwePFiDhw4AMBLL70EmNtCtGvXjr59+1K3bl0MBgMrV67khx9+4Pbbb+eBBx7IdcwZM2Zwxx130LlzZ4YMGcKxY8d488036dixI/fee2+JXbfCunryORERKR+6d+/Oxo0befTRR20dioiIlDfVWkL8bjj5N5hMYLhqAusLp+DcEfPrUOtNmC0iIlKe5CSGT55PJSvbhIPRPJaevpwY9lNiOLe5c+cSFhbG/PnzmT17NnXr1mXhwoUMGjTouvtOmjQJHx8f3nvvPb788ktCQ0OZMWMGY8eOzbPtyy+/nOv9okWLWLRoEXAlMezt7U23bt1YsWIFn3zyCZmZmdSsWZNJkybx7LPP4uiY+zJ269aNH3/8kfHjx/Pkk0/i6enJI488wrRp0zBc/UuUjSRc0ORzIiLlTVRUFN27d2f8+PEMGzaM0NBQuxhzbEk9hkVErMS3tvm/aUmQci53D+GYy9XCTpUgoEnpxyYiIlIG5Ew+l5Fl4vSFNAK9XMnONl15qt9GPYYNJpPJZJMzS6Hl9EY8fvw4ISEhN3Ss1Iws6r/8CwALHm5Dh7p+1ghRRETswLRp0yw3NfNjMBjIzMwsxYjsgzXHURGRCin9EnzzkPn17c9CtVZXPls5EeL3QPht0H6UTcKTkqexVETkxiQkp9F66koAvn28Ha3CquRa990T7WlZ3afU47LbimEpGWcuXul7rFYSIiLlx7PPPsubb75JzZo1uemmm/Dy8rJ1SCIiUl44VwL3AEi+3DYiJzFsMsHZQ+bXAY1sFp6IiIi9863sjIujkbTMbE4kptIq7MrEc6Aew1JKctpIgCafExEpT+bNm8e9997LN998Y+tQRESkPKpS40piOMeFOMi8/PeFT7gtohIRESkTDAYD1bzdOJRwkRPnzBPQnU6+kqOzVbtXTRlbwZy5aP6mMxigSmVVDIuIlBfZ2dl069bN1mGIiEh55VPD/N+zh6+sy0kSGx3BK7TUQxIRESlLqvmY+wyfSLwEQHxSKgBebk64OjnYJCYlhiuYhAvmVhLebk44Ouh/v4hIedGrVy9WrVpl6zBERKS8yqkIvnga0i+aX+ckhr1CwEEPo4qIiBQk2MucGI5NNCeE4y8/1W+rNhKgxHCFk1Om7m+j2Q5FRKRkvPjii+zbt48RI0awYcMG4uLiiI+Pz7NUJJGRkfj7+9OsWTNbhyIiUvZVqXHldU5COKd62KdGns1FREQkN0vFcE4riZzEsKcSw1JKcsrU/Wx4N0JERKyvfv36bNu2jblz59KuXTtCQkIICgrKs5RFu3btokOHDnh6etKwYcNCV0ZHREQQHx/P9u3bSzZAEZGKwNUL3KqYX587Yp54LidBrP7CIiIi1xXkZS7SjDufOzFsyznA9LxPBXOlYliJYRGR8mT8+PEYDAZbh2F1GRkZ9O7dm3HjxhEdHc0333xDnz59OHjwIL6+vrYOT0SkYvGqBilnITkeUs5BWpJ5fRVVDIuIiFxP1cu5uKTUTNIzs4m/YC7e9Pe03VP9SgxXMPFJl+9GKDEsIlKuTJw40dYhlIh9+/Zx7tw5nnjiCQDuv/9+xo8fz+LFixk2bJiNoxMRqWBcvc3/TU2EczmT0BnAO8xGAYmIiJQdV1cGn7mYph7DUvpyKoaVGBYRkaJITk5mwoQJ9OzZEz8/PwwGA6+99lq+22ZkZDB+/HiqV6+Oq6srTZs25fPPPy/WebOzs/Ndt2vXrmIdT0REboCrl/m/KYmQeNz82t0fnDR/iYiIyPVUvSoxnHAh/UorCSWGpbTYwzediIiUPQkJCUyePJmdO3fSokWLArcdMWIEr7zyCn369OHdd98lNDSUgQMHsmDBgiKft379+ri7uzNz5kwyMjL47LPPOHjwIBcvXizulyIiIsXl5m3+b2qiuaUEmBPDIiIicl1VKjtbXh89e5FL6VmAEsNSSpLTMu3im05ERMqeoKAgTpw4wfHjx5kzZ841t9u6dSvz589n4sSJzJw5k+HDh7Ns2TI6duxIREQE6enplm27d++Oq6trvsvjjz8OgLOzM4sXL+arr74iMDCQ77//nq5duxISElLiX7OIiPxLTiuJlES4dDkxnDMhnYiIiBTI2dGIl5sTADtjzlvWB3m52Sok9RiuSHKqhQH8PfS4l4iIFJ6LiwvBwcHX3W7RokUYjUZGjhxpWWcwGBg1ahT9+vUjOjqaHj16ALB8+fJCnbtly5b89ddfAGRlZVGrVi3GjRtXjK9CRERuSE7FcGYqXIgzv66kxLCIiEhhVXV35nxKBttjEi3rgrxsl6NTxXAFEp+UanmtimERESkJmzdvplatWlSpkjtR0LZtWwC2bNlS5GPu3LmTtLQ0Lly4wPPPP09AQAB33HGHVeIVEZEiyKkYBjh/wvxfVQyLiIgUWk6f4b9PJFneuzo52CweJYYrkJyJ55wdjXi6qlhcRESsLzY2lqCgoDzrc6qNY2Nji3zMTz75hMDAQIKDgzl06BBLly4tcPukpCRiYmIsS1xcXJHPKSIi+ciZfA4Ak/k/lXxtEoqIiEhZVPVyoWZyWiYA1bxt+0S/soMVSE4rCX8PFwwGg42jERGR8iglJQUXl7xPpRiNRpycnEhJSSnyMadPn8706dMLvf2MGTOYNGlSkc8jIiLX4eIBBgcwZV1ZV8nHdvGIiIiUMX7uuf9WCva2XX9hUMVwhRJ/OTGsNhIiIlJSXF1dSUtLy7M+OzubjIwMXF1L/o74uHHjOH78OP/3f/+Hr68v3t7eJX5OEZEKwWD4V9UwqhgWEREpAt/KzrneV1NiWErL1RXDIiIiJSE4ODjf1g05LSQKM4HdjfL09CQkJISpU6eSkJDAzp07S/ycIiIVRs4EdABGR3DxtFkoIiIiZU1VD1UMi42oYlhEREpay5YtOXjwIGfPns21fv369ZbPS0tkZCT+/v40a9as1M4pIlLuXT0BnZu3uYpYRERECqWqWkmIreRUDPu527axtYiIlF/9+/cnOzubWbNmWdaZTCaioqLw8/OjU6dOpRZLREQE8fHxbN++vdTOKSJS7l3dSsKtiu3iEBERKYOquuduJRHiY9vEsCafq0BOX0gFwN9TFcMiIlJ0UVFRJCYmkpiYCEB0dDSZmebZdEePHo2XlxetWrVi8ODBTJgwgdOnT9OkSROWLFnCqlWrmDdvXr4T05WUyMhIIiMjycrKuv7GIiJSOFe3klB/YRERkSKxt4phJYYriEvpmZy5mA7knQFRRESkMKZPn87Ro0ct75cvX87y5csBGDRoEF5e5iqyuXPnEhYWxvz585k9ezZ169Zl4cKFDBo0qFTjjYiIICIigpiYGEJDQ0v13CIi5dbVFcOVVDEsIiJSFFcnhl2djPhUcrJhNEoMVxizog9iMoGj0UCDYE0QISIiRXfkyJFCbefs7MyUKVOYMmVKyQZ0HaoYFhEpAbl6DCsxLCIiUhRuzg5UdnbgYnoWwd5uGGzcq189hiuAo2cuMmfNIQD+2z6cajYuUxcRESkN6jEsIlIC1EpCRETkhlT1MFcN20N+TonhCmDqj3tIz8qmqrszT3atY+twRERERESkrLq6YlitJERERIosp51EsJcSw1LCMrKy8XIz9yt5tkd9PF1t27tERESktERGRuLv70+zZs1sHYqISPnhEQiBTcG3NlSpaetoREREypx+rUII9HSld4tgW4eCwWQymWwdhBQsZ9Kc48ePExISUqxj7Io9T4NAT4xG2/YuERERKW3WGEdFREQqMo2lIiLlkyafqyAaBXtdfyMRERERERERERGpENRKQkRERMoltZIQERERERG5NiWGRUREpFyKiIggPj6e7du32zoUERERERERu6PEsIiIiIiIiIiIiEgFo8SwiIiIiIiIiIiISAWjxLCIiIiUS+oxLCIiIiIicm1KDIuIiEi5pB7DIiIiIiIi16bEsIiIiIiIiIiIiEgFo8SwiIiIiIiIiIiISAXjaOsA5PoyMzMBiIuLs3EkIiJyPYGBgTg6ani1JxpHRUTKFo2l9kdjqYhI2VGUcVSjbRlw+vRpANq0aWPjSERE5HqOHz9OSEiIrcOQq2gcFREpWzSW2h+NpSIiZUdRxlGDyWQylXA8coNSU1PZuXMnfn5+xbpzHhcXR5s2bdiwYQNBQUElEGHFoWtpPbqW1qNraT3WuJaqcrI/NzqOgv6dWZOupfXoWlqPrqX1aCwtn/Q3qf3QtbQeXUvr0bW0ntIeRzXalgGurq7cdNNNN3ycoKAg3Xm3El1L69G1tB5dS+vRtSxfrDWOgr43rEnX0np0La1H19J6dC3LF/1Nan90La1H19J6dC2tp7SupSafExEREREREREREalglBgWERERERERERERqWCUGK4APD09mTBhAp6enrYOpczTtbQeXUvr0bW0Hl1LuRZ9b1iPrqX16Fpaj66l9ehaSn70fWE9upbWo2tpPbqW1lPa11KTz4mIiIiIiIiIiIhUMKoYFhEREREREREREalglBgWERERERERERERqWCUGBYRERERERERERGpYJQYFhEREREREREREalglBgWERERERERERERqWCUGBYRERERERERERGpYJQYLscyMjIYP3481atXx9XVlaZNm/L555/bOiy7tmrVKgwGQ77LypUrc2176tQpBg8ejK+vL+7u7nTu3JnNmzfbKHLbSk5OZsKECfTs2RM/Pz8MBgOvvfZavtsW5bpt2rSJzp074+7ujq+vL0OGDCE+Pr4kvxSbK+y1nDhx4jW/VzMzM/NsX9Gu5caNGxkzZgxNmjTB3d0dLy8vunfvzpo1a/Jsq+9JKYjG0qLROFo8GketR+Oo9WgsFWvQOFp0GkuLR2Op9WgstY6yMo46FntPsXsjRoxgwYIFjBw5kiZNmrBkyRIGDhxIZmYmQ4YMsXV4dm3kyJHcfPPNudY1btzY8jolJYXOnTtz8uRJxo0bh7e3N7NmzaJjx46sX7+ehg0blnbINpWQkMDkyZMJCQmhRYsWrFixIt/tinLddu/eTceOHQkPD+f1118nMTGRN998ky1btrBx40bc3NxK68srVYW9ljmioqLw8vLKtc7BwSHX+4p4LV9//XXWrFlDv379GDNmDImJicyZM4fOnTuzbNky7rjjDkDfk3J9GkuLR+No0WgctR6No9ajsVSsQeNo8WksLRqNpdajsdQ6ysw4apJyacuWLSbANHnyZMu67OxsU8eOHU3+/v6mtLQ0G0Znv6Kjo02A6YsvvihwuxkzZpgA05o1ayzrEhISTL6+vqbevXuXcJT2JzU11XTixAmTyWQyHT582ASYpk2blme7oly3Xr16mfz8/ExnzpyxrMv5//PWW2+VyNdhDwp7LSdMmGACTHFxcdc9ZkW8ln/++acpNTU117qzZ8+aAgMDTS1btrSs0/ekFERjadFpHC0ejaPWo3HUejSWyo3SOFo8GkuLR2Op9WgstY6yMo4qMVxOPf/88yaj0Zjrm8VkMpm++eYbE2D65ZdfbBSZfbt6EL5w4YIpPT093+1uvvlmU7NmzfKsHzVqlMnZ2dmUlJRUwpHar4IGjsJet6SkJJOTk5PpySefzLNt48aNTTfffLO1w7ZLhRmEY2NjTefPnzdlZWXlewxdy9wefPBBk4uLi+W9vielIBpLi07j6I3TOGo9GkdLhsZSKSyNo8WjsfTGaSy1Ho2l1mdv46h6DJdTmzdvplatWlSpUiXX+rZt2wKwZcsWW4RVZgwfPhwPDw9cXV257bbbWLt2reWz7Oxstm3bRps2bfLs17ZtW9LT0/n7779LM9wyoSjXbefOnWRkZFxz2+3bt5OdnV3iMZcFdevWxcvLCw8PD+6//35iYmJyfa5rmVtsbCxVq1YF9D0p16extPg0jlqffmaVDI2jRaexVApL4+iN0VhqffqZVTI0lhaNvY2j6jFcTsXGxhIUFJRnfXBwsOVzycvZ2Zm+fftamqzv3buXN998k44dOxIdHc0tt9zC2bNnSU1N1fUtoqJct5z/XmvblJQUzp07h6+vbwlGbN98fHwYOXIk7dq1w83Njb/++ot3332XdevWsWXLFsu10bW84q+//mL16tU89dRTgL4n5fo0lhadxtGSo59Z1qVxtHg0lkpRaBwtHo2lJUc/s6xLY2nR2eM4qsRwOZWSkoKLi0ue9UajEScnJ1JSUmwQlf1r37497du3t7zv1asX999/P40aNeK5557jzz//tFy7/K6vq6srgK5vPopy3XSNr+/JJ5/M9b5v3760b9+efv368dZbbzF16lRA1zJHfHw8Dz74INWrV2f8+PGAvifl+jSWFp3G0ZKjn1nWpXG06DSWSlFpHC0ejaUlRz+zrEtjadHY6ziqVhLllKurK2lpaXnWZ2dnk5GRYfmmkesLCwujX79+/O9//+PSpUuWa5ff9U1NTQXQ9c1HUa6brnHx3HfffYSHh7Ny5UrLOl1LuHDhAj179iQpKYkffvjBMmOuviflejSWWofGUevQz6ySp3H02jSWSnFoHLUejaXWoZ9ZJU9jaf7seRxVYricCg4OJi4uLs/6nNLznJJ0KZzq1auTnZ1NYmIivr6+uLi46PoWUVGuW85/r7Wtq6trnl5lYhYaGsrZs2ct7yv6tUxJSeGee+5hz549/PjjjzRp0sTymb4n5Xo0llqPxtEbp59ZpUPjaF4aS6W4NI5al8bSG6efWaVDY2lu9j6OKjFcTrVs2ZKDBw/m+scIsH79esvnUniHDh3CwcGBKlWqYDQaadasGRs2bMiz3fr163Fycsr1D13MinLdmjRpgpOT0zW3bdasGUajfnzl5/Dhw/j7+1veV+RrmZGRQb9+/Vi7di3ffvttrkfyQN+Tcn0aS61H4+iN08+s0qFxNDeNpXIjNI5al8bSG6efWaVDY+kVZWEcLZ9XXujfvz/Z2dnMmjXLss5kMhEVFYWfnx+dOnWyYXT269+/tADs2bOHb775httvv91Slt+/f3+2b9/On3/+adnuzJkzfPHFF9xxxx14eHiUWsxlSWGvm6enJz169ODzzz/n3Llzlm1XrVrF33//Tf/+/Us9dnuT3/fqvHnziImJoUePHpZ1FfVaZmdnM3DgQH755RcWLFjAHXfcke92+p6UgmgsLTqNoyVLP7OsR+Po9WkslRulccYbI8wAAA0YSURBVLR4NJaWLP3Msh6NpQUrK+OowWQymYq1p9i9IUOG8NlnnzFq1CiaNGnCkiVL+PHHH5k3bx4PPfSQrcOzS926dcPd3Z3WrVvj7+/Pvn37+OCDDzCZTPzxxx+0aNECgEuXLtGqVSvi4+N5+umn8fLyYtasWRw9epR169bRuHFjG38lpS8qKorExEQSExN588036d69O7fddhsAo0ePxsvLq0jX7e+//6Zt27bUqFGDxx9/nPPnzzN9+nQCAwPZtGkTlSpVstWXWuIKcy19fHzo27cvDRs2pFKlSqxdu5bPPvuM+vXrs27dOjw9PS3Hq4jXcty4cbz11lt069aNIUOG5Pl80KBBQNH+LVfE6ygaS4tK42jxaRy1Ho2j1qGxVKxB42jRaSwtPo2l1qOx9MaVmXHUJOVWWlqa6aWXXjKFhISYnJ2dTY0bNzYtXLjQ1mHZtXfeecfUtm1bU5UqVUyOjo6mgIAA0wMPPGDavXt3nm3j4uJMAwcONPn4+JgqVapk6tixo2nDhg02iNo+hIWFmYB8l8OHD1u2K8p1W79+valjx46mSpUqmXx8fEwDBw40xcXFldJXZDuFuZbDhw83NWrUyOTh4WFycnIy1apVyzRu3DjTuXPn8j1mRbuWt99++zWv4b+HPn1PSkE0lhaNxtHi0zhqPRpHrUNjqViDxtGi01hafBpLrUdj6Y0rK+OoKoZFREREREREREREKhj1GBYRERERERERERGpYJQYFhEREREREREREalglBgWERERERERERERqWCUGBYRERERERERERGpYJQYFhEREREREREREalglBgWERERERERERERqWCUGBYRERERERERERGpYJQYFhEREREREREREalglBgWERERERERERERqWCUGBYRqzty5AgGg4H58+fbOpQimThxIgaDwdZhiIhIBadxVERE5MZoLBUpHCWGRcq5ZcuWMXHiRFuHwZ9//snEiRNJTEy0aRwxMTFMnDiRbdu22TQOEREpGzSO5qZxVEREikpjaW4aS8WeGEwmk8nWQYhIyXnsscf44IMPKM1/6iaTibS0NJycnHBwcADgtdde44UXXuDw4cOEh4eXWiz/tm7dOtq1a8fHH3/M0KFDc32WmZlJZmYmrq6utglORETsjsbR3DSOiohIUWkszU1jqdgTVQyLSLFcunTpmp8ZDAZcXV0tA7Ct4igqR0dHDcAiIlIqNI6KiIjcGI2lIjdOiWERO/P333/Tq1cvvL29qVSpEu3bt+fXX3/Ntc38+fMxGAwcOXIk1/p/91EaOnQoH3zwAWAeGHOWq/f7+uuvufnmm6lcuTJeXl7cc8897N69O9dxhw4diqurK8eOHaNPnz54eXnRs2fPa34N/45j4sSJvPDCCwDUqFHDEseqVass+/z+++907twZT09PKleuTKdOnVi7dm2u4+b0W9q9ezf//e9/8fX1pVGjRgAcPXqUkSNH0qBBAypVqoS3tzf33HMPu3btsuy/atUq2rVrB8BDDz1kiSPnsaZr9XP66KOPaNq0Ka6urvj7+zN06FDi4uLyvUanTp3i/vvvx9PTEx8fH0aMGEFqauo1r5WIiFiXxlGNoyIicmM0lmoslYrD0dYBiMgV//zzD7fccgsuLi489dRTuLu78/HHH9OzZ08WL15Mr169inS8Rx99lOPHj/P777+zcOFCy3o/Pz8Apk+fTkREBH379mXw4MEkJycza9YsbrnlFrZs2UKNGjUs+2RnZ9O9e3duuukm3njjDRwdC//jo2/fvuzdu5evvvqKt956i6pVqwLQoEEDAL766isefPBBOnXqxJQpU8jOzmbevHl07tyZ1atX07Zt21zH69+/PzVq1GDq1KmkpaUBsHHjRlatWsW9995LeHg4cXFxfPDBB3To0IFdu3YRGBhIgwYNmDhxIhMnTmTEiBHcdtttADRt2vSasec8btShQwciIyM5duwYUVFR/PHHH2zZsgUvL69c16hbt240bdqU119/nQ0bNvDhhx9StWpVXn311UJfLxERKR6NoxpHRUTkxmgs1VgqFYxJROzGfffdZ3J0dDTt2bPHsu78+fOm6tWrm8LDw01ZWVkmk8lk+vjjj02A6fDhw7n2P3z4sAkwffzxx5Z1jz76qCm/f+rHjh0zOTo6miZMmJBrfWxsrMnT09P08MMPW9b997//NQGmp59+ulBfR35xTJs2Ld+YL168aKpSpYrpv//9b5714eHhps6dO1vWTZgwwQSY+vXrl+ecFy9ezLPu4MGDJhcXF9PUqVMt6/73v//lie3fx89x+vRpk4uLi6ljx46mjIwMy/olS5aYANPLL79sWZdzjV544YVcx+zdu7epatWqec4lIiLWp3E093qNoyIiUlQaS3Ov11gq5Z1aSYjYiaysLH755Rfuuece6tevb1nv6enJY489xpEjR/j777+tdr7vvvuOzMxMBgwYQEJCgmVxcnLi5ptv5vfff8+zzxNPPGG18+dYsWIFZ8+eZeDAgbniuHTpEl27duWPP/4gMzMz1z6PP/54nuNUqlTJ8vrSpUucOXMGT09P6tWrx+bNm4sV28qVK0lLS+Opp57KdTe6d+/e1KtXj2XLluXZZ+TIkbne33777SQkJHDhwoVixSAiIoWjcVTjqIiI3BiNpRpLpeJRKwkRO3H69GkuXryYawDO0bBhQwAOHz5c4CMmRbFv3z7gyqMz/3b1oAZgNBqpXr26Vc6dXxzdu3e/5jbnzp2zPGoE5HqcKEdqairjx4/n008/zdNrydfXt1ix5fS9yu//SYMGDXL1owJwcnKiWrVqudb5+PgAcPbsWTw8PIoVh4iIXJ/GUY2jIiJyYzSWaiyVikeJYZEywGQyAVia0OfXjB7Md3gLKzs7G4Cff/45395M/5691cnJqUg9nIoax/z58/MMYDmu7pkE4ObmlmebMWPG8NFHHzF69GhuueUWvLy8MBqNjB071nIOazKZTHn+PxiN134II+f/oYiIlD6NoxpHRUTkxmgs1Vgq5ZMSwyJ2ws/Pj8qVK7N37948n+WsCw8PB8Db2xuAxMTEXNv9e0ZYuPaAXatWLQCqV69uuftbkq4Xh5+fH127di328RctWsSQIUN4++23c60/d+6cZWKBguLIT8713rt3L3Xr1s312d69ey2fi4iI7Wkc1TgqIiI3RmOpxlKpeNRjWMROODg4cMcdd7Bs2TL++ecfy/oLFy7wwQcfEB4eTuPGjQGoXbs2QJ7HRt577708x61cuTJgHoyu1q9fPxwdHZkwYUK+dy9Pnz59Q19PYeO444478Pb2zjWba3HicHR0zHMH9IsvviA2NrZQceSna9euuLi48M477+S68/3DDz+wb98+7r777kLFJiIiJU/jqMZRERG5MRpLNZZKxaOKYRE78sorr7BixQpuu+02Ro4cibu7Ox9//DHHjh3ju+++szwW0qhRI2699VZefPFFzpw5Q0BAAN9//32+A0vr1q0BGDVqFHfeeSeOjo7cc8891KhRgzfeeINx48Zx880307dvX6pUqcLRo0f56aefaNu2LbNnz7ba15YTx4svvsgDDzyAs7MznTt3xt/fnzlz5vDAAw/QrFkzBg4cSGBgIDExMURHR1O5cmV+/vnn6x7/nnvuYcGCBXh6etK4cWO2bdvGV199Rc2aNXNtV6dOHTw9PXn//fdxd3fHw8ODxo0bW37BuVrVqlWZOHEiL7zwAl27dqVv374cP36cd999l5o1a/L0009b5+KIiIhVaBzVOCoiIjdGY6nGUqlgTCJiV3bu3Gm6++67TZ6eniY3NzdTu3btTD///HOe7Y4cOWK64447TG5ubiZfX1/TqFGjTLt27TIBpo8//tiyXWZmpmnMmDGmgIAAk8FgMAGmw4cPWz5ftmyZqWPHjiYPDw+Tm5ubqXbt2qahQ4eaNmzYYNnmv//9r8nFxaXQX8Phw4fzxGEymUxTp041hYaGmoxGowkwRUdHWz7766+/THfddZfJx8fH5OLiYgoPDzfdf//9phUrVli2mTBhggkwxcXF5Tnn+fPnTcOHDzf5+/ubKlWqZOrQoYNpw4YNpttvv910++2359r2+++/NzVu3Njk5ORkAkwTJkzIdfx/+/DDD02NGzc2OTs7m3x9fU1DhgwxxcbG5trmWtfo448/znPNRUSk5Ggc1TgqIiI3RmOpxlKpOAwmk7pPi4iIiIiIiIiIiFQk6jEsIiIiIiIiIiIiUsEoMSwiIiIiIiIiIiJSwSgxLCIiIiIiIiIiIlLBKDEsIiIiIiIiIiIiUsEoMSwiIiIiIiIiIiJSwSgxLCIiIiIiIiIiIlLBKDEsIiIiIiIiIiIiUsEoMSwiIiIiIiIiIiJSwSgxLCIiIiIiIiIiIlLBKDEsIiIiIiIiIiIiUsEoMSwiIiIiIiIiIiJSwSgxLCIiIiIiIiIiIlLBKDEsIiIiIiIiIiIiUsH8P4G8pv6LoTLIAAAAAElFTkSuQmCC", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "hist = result.history\n", + "fig, axes = plt.subplots(1, 3, figsize=(13.0, 3.5))\n", + "\n", + "axes[0].plot(hist[\"iter\"], hist[\"cost\"])\n", + "axes[0].set_xlabel(\"outer iteration\")\n", + "axes[0].set_ylabel(r\"$\\langle P, C \\rangle$\")\n", + "axes[0].set_title(\"Transport cost\")\n", + "\n", + "axes[1].semilogy(hist[\"iter\"], np.maximum(hist[\"row_err\"], 1e-16), label=\"row\")\n", + "axes[1].semilogy(\n", + " hist[\"iter\"], np.maximum(hist[\"col_err\"], 1e-16), label=\"col\", alpha=0.7\n", + ")\n", + "axes[1].set_xlabel(\"outer iteration\")\n", + "axes[1].set_ylabel(\"marginal error (max-abs)\")\n", + "axes[1].set_title(\"Marginal constraint violation\")\n", + "axes[1].legend(fontsize=9)\n", + "\n", + "axes[2].semilogy(hist[\"iter\"], np.array(hist[\"viol\"]) + 1e-20)\n", + "axes[2].set_xlabel(\"outer iteration\")\n", + "axes[2].set_ylabel(\"constraint violation\")\n", + "axes[2].set_title(\"Linear constraint violation\")\n", + "\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "8cf71f7c", + "metadata": {}, + "source": [ + "The middle and right panels show that both the *marginal* constraints (i.e. $P\\mathbf{1} = a$ and $P^\\top\\mathbf{1} = b$) and the *new* linear constraints are driven to near-zero violation after only a few dozen outer iterations. This matches Theorem 2 of the paper, which establishes a sublinear convergence rate on the first-order stationarity of the dual." + ] + }, + { + "cell_type": "markdown", + "id": "00faff74", + "metadata": {}, + "source": [ + "### How the dual variables evolve\n", + "\n", + "Each outer iteration updates three families of dual variables: the marginal duals $f, g$ (the OTT-JAX analogues of the row/column scalings $u, v$ in vanilla Sinkhorn) and the new constraint duals $\\alpha$. The figure below traces five components of $f$, five components of $g$, and the two $\\alpha$ values across iterations. The marginal duals stabilise quickly with typical Sinkhorn-style behaviour; the constraint duals take a few more iterations because they require solving the small Newton subproblem at each step.\n", + "\n", + "We extracted these traces from the same single run as before : no extra computation, since `track_duals=True` simply pickles the duals at each outer iteration." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "47f17716", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-02T06:39:19.781174Z", + "iopub.status.busy": "2026-05-02T06:39:19.780527Z", + "iopub.status.idle": "2026-05-02T06:39:20.145738Z", + "shell.execute_reply": "2026-05-02T06:39:20.144889Z" + } + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAABXoAAAF9CAYAAAC6fNCqAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAQ6wAAEOsBUJTofAAA8ERJREFUeJzs3XdgFNXagPFnaza9EwgpNENNaFIUQUQRRUWkKBekSLMgFuwoHbFzFVGRJoif7VoQERGkiRTpIL0GAkkIIWXTs2W+P5YsCekh2U3C+9N1Z86cM/POgNmcd8+cUSmKoiCEEEIIIYQQQgghhBCixlI7OwAhhBBCCCGEEEIIIYQQ10cSvUIIIYQQQgghhBBCCFHDSaJXCCGEEEIIIYQQQgghajhJ9AohhBBCCCGEEEIIIUQNJ4leIYQQQgghhBBCCCGEqOEk0SuEEEIIIYQQQgghhBA1nCR6hRBCCCGEEEIIIYQQooaTRK8QQgghhBBCCCGEEELUcJLoFUIIIYQQQgghhBBCiBpOEr1CCCGEEEIIIYQQQghRw0miVwghhBBCCCGEEEIIIWo4SfQKIYQQQohaKTo6GpVKxYgRI2r8sbt3745KpaqUfZWHM6+hEEIIIYQoH0n0ClENScf0+jniGmZkZPDCCy/QqFEj9Ho9KpWK1157rcqOJ4QQjqRSqVCpVKjVak6dOlVsvZ49e9rrzps3z4ERCiGEEKI2qO5fKtbmPq2ofSTRK2od6ZgKRxk5ciSzZ88mIiKCl19+mSlTpvDwww87OywhhKg0Wq0WRVFYuHBhkdvPnDnDunXr0Gq1Do6sbOrXr8+RI0d46623nB2KEEKIG8zx48d57rnnaN26NT4+Puj1eoKCgrj77ruZO3cuRqPR2SEWaePGjTU2uViTYxeislTP38qFuE5arRaz2czChQuL7Nzl75iazWYnRFiyvI6pt7e3s0MRxYiOjuZ///sf3bp1Y/Xq1c4ORwghqoS/vz8NGzZkyZIlzJgxo1BCd+HChSiKwgMPPMDPP//spCiLp9PpaNasmbPDEEIIcYOZNWsWkyZNwmq10qFDBx599FG8vb25dOkSmzdvZvz48UydOpXExERnh1otSP9XiMojI3pFreTv70/nzp1ZsmRJkYnc/B3T6iivY1qvXj1nhyKKsW7dOhRFoV+/fs4ORQghqtSYMWOIj4/n119/LVBuNpv54osv6NixI1FRUUW2XbJkCf3796dRo0a4urri5eVFly5d+OqrrwrVzbs9sXv37sTGxjJy5Ejq1auHRqNh+fLl9nqKovDZZ58RGRmJwWAgKCiI4cOHExcXV+jWyuJuecx/rMTERMaOHUu9evVwcXGhZcuWfPHFF9d9PhX1+eefExUVZT+3ESNGEB8fX2TdvJFLU6dOLXJ7cbeaVtZ5/Prrr9x1110EBwfj4uJCUFAQnTp1YtasWWXeR25uLu+//779zzMkJIQ33ngDs9lM/fr1qVOnTrliEkIIZ3vnnXd4/fXXqVevHn///Tc7duxg7ty5vPnmm8yfP58jR46wadMm6tev7+xQqw3p/wpReSTRK2ot6ZhW7Hwq6kbpmP7000+oVCpGjx4NwHPPPWefAmTXrl3likMIIWqCRx55BE9PTxYsWFCg/LfffiMuLo4xY8YU2/bJJ58kOjqabt268dxzzzFo0CDOnj3L0KFDmTRpUpFtLl++zC233MKuXbsYOHAgTz75JH5+fvbtzz33HE899RSXLl1i1KhRDB8+nH379tGlSxdSUlLKdW4pKSl06dKFbdu2MWDAAIYNG2b/LF+6dGmlnE95vPDCCzzxxBPEx8czcuRIhg8fzu7du7n11ltJTk6+7v3nqYzzmD9/Pn369OHgwYP07t2bF154gQcffBCdTlfmKbGysrK46667eOmll1Cr1YwfP5577rmHDz/8kBEjRhAbG0u7du2u51SFEMKhzp49y6RJk9DpdKxatYouXboUWa9bt27s2LGjUPlPP/3EHXfcgY+PDwaDgebNmzN58mTS09ML1Ktov7C0ftDUqVO54447AFi6dKm9n6NSqViyZEmhYxfX/y1vv62o/m95z7EssZfmRunT5mc2m/n4449p3749rq6uBAcH88QTT5CRkYHFYqFp06bcdddd5dqncC6ZukHUWo888gjPPfccCxYs4KGHHrKX53VMp0+fzvnz54ts++STT9KiRQu6detGvXr1uHz5MqtWrWLo0KEcO3aMGTNmFGqT1zH19vZm4MCBWK3WQh3TOXPmEBQUxKhRo3B3d+ePP/6gS5cueHl5levc8jqmer2eAQMGkJ2dzQ8//MDIkSNRq9UMHz78us+nPF544QVmz55NYGAgI0eOxMPDg99//51bb70VT0/P69p3fpVxHvPnz+fxxx8nKCiI+++/nzp16pCYmMjhw4eZN28eEydOLLF9cHAwU6ZMYf78+cTFxTFp0iTUatt3Zi1btqyU8xRCiOrE3d2dwYMHs2DBAs6dO0dYWBgACxYswMPDg0GDBvH+++8X2fbgwYM0bty4QFlubi733nsvb7/9Nk888UShEU0HDx5k6NChLF68uNBUEVu2bGHOnDk0atSIHTt24O/vD8Dbb7/NI488wg8//FCuc9u/fz+jRo3i888/R6PRAPD8888TFRXFO++8U+jztCLnU1bbt29n9uzZhIWFsXPnTvtI1rfeeov+/fvzyy+/VGi/RamM8/j888/R6/Xs37+foKCgAtvKeivyU089xebNm5k4cSIzZsywf54+8cQTdOrUCYD27duX9bSEEMLpvvjiC0wmE4888kixg4ryuLi4FFifPHkyM2bMwM/Pj0ceeQQfHx/Wrl3LjBkzWLFiBZs3by7UtypPv7As/aDu3bsTHR3N0qVLad26NX379rW3b9OmTYFjl9T/rcz+Z1nPsTyxF+VG6tPmSU5Opnfv3mzfvp177rmHu+66iz/++IPPP/8cf39/mjZtyvHjx/nyyy+v97SFIylC1DKAEhQUpCiKojz++OOKWq1Wzp49a99+3333KR4eHkpaWpoyZcoUBVA+++yzAvs4efJkof3m5OQoPXr0ULRarXL+/Hl7+ZkzZxRAAZShQ4cqJpOpUNu///5bAZRGjRopiYmJ9nKLxaIMGDDA3v7afQ4fPrzAfvIfa9SoUYrZbLZvO3TokKLRaJTmzZsXOn5FzufaYxdn27ZtCqCEhYUpFy9etJebzWblwQcfLHRuiqIoGzZsUABlypQpRe7z9ttvL9Smss6jXbt2il6vV+Lj4wvt69KlSyWdqp3ValW8vLyURo0alam+EELURPk/T3ft2lXg53ZMTIyi0WiUMWPGKIqiFPt5Wpwff/xRAZSlS5fay/J+buv1+gKfJ/mNHj1aAZSFCxcW2nbmzBlFo9GU6/PUzc1NSU1NLbSvbt26KYBiNBqv+3zK+nk6ZswYBVDmzZtXaNuJEycUtVpdaZ+nlXEe7dq1U9zc3JTLly+Xef/5bd++XQGUXr16Fbk9KipKAZQffvihQPnkyZOVrl27VuiYQghR1Xr06KEAyoIFC8rVbtu2bYpKpVLq16+vXLhwwV5utVqVYcOGKYAybtw4e3lF+oVl7QflfbYU9/lVlv5vefpt+feZ/5gVOcfSYi/OjdinVRRFueeeexRAeffdd+1laWlpipeXl9KxY0elSZMmyoMPPljm/YnqQaZuELXamDFjsFqtLF68GIDz58+zevVq/vOf/+Dh4VFsu2tHuQDo9XrGjRuH2Wxm3bp1RW5///33i3zyeN6tIhMnTrSPPgJQq9W899579lFEZeXm5sbs2bMLtGvRogVdunThyJEjpKWlXff5lFXetZ04cWKBefQ0Gg3vv/++fXROZais89Bqteh0ukLlAQEBZWp/8uRJjEZjsaOMnn/+eXr37l2mfQkhRE3Qvn172rZty+LFi7FYLCxatAiLxVLitA0A586dY9y4cTRt2hQ3Nzf7LZT9+/cH4MKFC4XaNGjQoNh5Wffs2QNA165di2wXEhJSrvO66aabiryrJjQ0FKDQVBAVOZ+y2r17N2AbkXStJk2alPvcSlIZ5zF06FAyMzNp0aIFzz77LD/++CNxcXFljmHu3LkAvP7660Vuz/t96dqpGyZMmMDKlSvLfBwhhHCkvJ+D5f2ZvWjRIhRFYeLEiQQHB9vLVSoV7777Lq6urixZsgSTyVSgXXn7hdfbD8qvpP5vZfY/y3uOFXEj9mk3bNjA6tWr6dy5My+++KK93MPDg+bNm7Njxw5Onz7NzJkzy7Q/UX3I1A2iVsvfMZ00aVK5OqbvvPMOf/75JzExMWRlZRXYXhUd07Nnz5b1tMrUMc1/e0lFzqesytIxPXfuXIX3n19lnMfQoUN5/vnnadGiBY888gjdunXj1ltvLdfE/3nnXNy8gVOmTCl38l4IIaq7MWPG8NRTT/Hbb7+xePFiWrduTYcOHYqtf/r0aTp27EhycjJdu3alV69eeHt7o9Fo7LdW5uTkFGpXt27dYveZmpoKUGiqgPxty/N56uPjU2R5XqfVYrHYyyp6PmWVd27FnX+9evUq5fO0ss7jueeeo06dOnz66ad88sknzJkzB4DOnTvz1ltvFfl7QX5r1qzB09Oz2Pkr4+Li8PX1pWHDhgXK5YnsQojqTFEUgCLnZi1JXn+xR48ehbYFBQURGRnJjh07OH78eIHp4srTL6yMflB+JfV/K7P/Wd6+b0XciH3avOkYnn/++UJ/Xw0GAwCDBw+mVatWZd6nqB4k0StqPemYSsc0v+vtmMLVX8SKG9Fb3J+PEELUZEOGDOHFF19k3LhxnD9/nldeeaXE+rNnz+by5cssXryYxx57rMC2b775psiHnUHJneO8JN/FixeLTPgV98CUylDR8ymrvPOJj48v8tyKGi2bN8LIbDYXuc+iHk5XmecxePBgBg8ejNFoZNu2bfz6668sWLCAe++9l/379xMREVFku+zsbBISEmjWrFmRo6QuXLjAsWPHCiU80tLS8Pb2Ztu2bfY5fIUQojoJDg7m6NGjxMTElKtdWfpUUPjnenn6hZXRD8qvuFgru/9ZnnOsqBuxT7thwwY0Gk2xd6LqdDqmTZtWrvMT1YMkekWtJx1T6Zheq6Id0zwljeiNj4+nXr16HDx4UB7OJoSoVby8vHjkkUf44osvcHV1ZciQISXWP3nyJAADBgwotG3Tpk0ViqFt27bs2bOHzZs3F/pZHR0dXexDVitDVZxPfu3bt2fPnj1s3LiRpk2bFjp2Uefm6+sLUGRCITU1lePHjxcqr4rz8PLyolevXvTq1QsvLy/eeustVq9eXeznqVarRavVkpycXOT29957D0VRCn3O7t+/H5VKRWRkZIXiFEKIqta1a1fWr1/PunXrGD16dJnb5e9TFZXYzOtTXe9dDdfbD8qvuP5vVfc/q8KN1qfNysri3LlzhIeHF5rSMi4ujj179tCyZUsaNWpU5nhE9SFz9IpaL69jev78ead2TAE2b95caFtt6JgCbNy4schj14SO6dy5c3nhhRfIzs5m9erVpbbbu3cv4eHhBeZbzr/NYDDQrFmzCsUkhBDV2fTp0/n555/5448/Su1sNmjQACj8+fDHH3+wcOHCCh1/xIgRAMyaNYvLly/by61WKy+//HKljOopTlWcT355Hb5Zs2aRkJBgL7dYLLz44otYrdZCbZo1a4a3tzfLly8v8KWx2WzmueeeK3Q7aGWex++//15orki4+uW1q6trsW21Wi0RERFcvHiR3377rcC2JUuW8PHHHwOF75zZt28fN910E25ubmWOUwghHOmxxx5Dp9Px448/cvDgwRLr5h+5mffF1oYNGwrVu3TpEgcPHsTd3b3QF4EVVVI/KG8Kuop+plZ1/7MkFY39RuvTZmVloShKkXfVvPjii6SlpVXqvMTCseRPTtwQpGMqHdM819MxBdvtNsnJycXOz7tv3z5atWolc/QKIWqlkJAQ+vbtW+Sc89d66qmn0Ov1DBw4kCFDhvDyyy/Tu3dv7r333iI7OGVx2223MW7cOE6fPk3Lli15+umneeWVV2jXrh27du2idevW5Z4Xsayq4nzyu+WWW5gwYQLnzp2jVatWjBs3jpdffpk2bdpw4MABoqKiCrXR6XQ8//zzGI1G2rZty7hx4xg3bhxRUVFs376d1q1bV9l5/Oc//6F+/fr079+fF154gRdffJFu3brxxRdf0LhxYx5++OES20+cOBGAfv36MXToUF555RW6d+/O888/b58P8Oabby7QZt++fbRp06bMMQohhKOFh4czY8YMTCYTvXv3Zvv27UXW27JlC507d7avjxw5ErD1qfL3jxRF4eWXXyYzM5Phw4cX+fCtsiprPyhvMEtFpyqo6v5nSSoa+43Wp/X19cXLy4szZ86wd+9ee/myZcv4+uuvgaJHJIuaQaZuEDeEkJCQMj/59KmnnuKLL75g4MCB9O/fn/r163Pw4EFWr17Nww8/zHfffVfu4+d1TD/55BNatmzJgAEDcHd3548//sBoNNK6dWsOHDhQ7v2WRVWcT355HdPZs2fTqlUrBg4ciLu7O7///jsZGRlERUUVOre8junUqVNp27Yt/fr1A2zfYCuKQuvWrdm/f3+VnMd//vMf9Ho9Xbt2pUGDBqhUKnbs2MHmzZvL1DEtbX5e6YQKIYRNVFQUGzZs4I033mDVqlWYzWZat27NTz/9hI+PT4U/fz7++GOaN2/OZ599xoIFC/Dx8eGee+7h7bffpmfPnkU+sKUyVNX55PfBBx8QERHB3LlzWbRoEd7e3tx77728/fbbDBo0qMg2kydPxt3dnc8//5wFCxbg7+9P3759efPNN+2fr1VxHm+//TZr1qxh3759rF69Gp1OR1hYGFOmTGH8+PGlfrE+ZMgQUlJS+O9//8t3331HUFAQ9957L0uXLqVbt26EhoYWejL5/v376d+/f5niE0IIZ3nllVcwm81MnjyZW265hY4dO9KxY0e8vb1JTExky5YtHDx4kICAAHubW265hddee4233nrL3qfy9vZm7dq17Nmzh8jISGbNmnVdcZW1H9S0aVNCQ0PZvHkzQ4YMISIiAo1GQ58+fYr80vFaVd3/LElFY7/R+rQqlYqRI0fy4Ycf0rNnTwYOHEhycjL/+9//uPPOO3Fzc+PXX39l2LBhjBo1ittvv71McYlqQhGilgGUoKCgMtWdMmWKAiifffZZgfItW7Yod9xxh+Lj46N4eHgoXbp0UX7++Wdlw4YNCqBMmTLFXvfMmTMKoNx+++0lHstqtSpz585VWrZsqej1eqVOnTrKsGHDlNjYWKVly5aKt7d3oX0OHz68wD5KO9bw4cMVQDlz5sx1n8+1xy7NvHnzlFatWikuLi5KnTp1lOHDhytxcXHK7bffrhT1o8ZqtSrvvfee0qRJE0Wn0yl169ZVnnjiCeXy5cvFtqmM8/jss8+Uhx56SGnUqJHi5uameHt7K5GRkcqUKVOUxMTEUs/z1VdfVQBl1apVRW6PiIhQ5s6dW+p+hBBCVK6UlBTFYDAonTt3dnYo4jqsWLFCAZRXX321QLnZbFYMBoPy+++/OykyIYQon2PHjinPPvusEhkZqXh5eSlarVYJDAxUevToocyZM0cxGo2F2nz//fdKt27dFE9PT0Wv1ytNmzZVXn/99UJ1K9IvLE8/aPfu3cpdd92leHt7KyqVSgGUL774okzHVpTy9dvy7zN/362ifd+SYi/NjdKnVRRFyc7OVl555RUlNDRU0el0ir+/v/LUU08p2dnZyoEDB5TIyEgFUL7++usy7U9UHypFURSHZJSFEEVKTU2lbt26tGnThm3btjk7HHEdMjMz8fT05K+//qJLly7ODkcIIWql+Ph46tSpU2DuOJPJxOOPP84XX3zBe++9x4svvujECEVpLBYLly9fpk6dOgXKt27dyv33349arebo0aMFRrsdOnSIVq1aERcXV+xT0YUQQgghbnQydYMQDlJcx/T5558nOztbbkWsBfbv34+iKGW6pUkIIUTFzJ07l6VLl9K9e3fq16/PpUuX2LRpE6dOnaJ9+/Y8/fTTzg5RlOLo0aO0a9eOu+++m5tuugmz2cy///7Lxo0b8fHx4eeffy6Q5AXb1Eh16tSRJK8QQgghRAkk0SuEg0jHtPbbv38/jRs3xtPT09mhCCFErXXnnXeyb98+1q1bR1JSEmq1msaNGzNlyhReeuklDAaDs0MUpdDr9dx///1s376dP//8E7A9oOb5559nwoQJRT5XYceOHXTq1MnRoQohhBBC1CgydYMQDrJhwwY++OAD9uzZU6Bj2r9/f1566SXc3d2dHaIQQgghRLWSlZXFkSNHuO+++5g8eTJPPvmks0MSQgghhKi2JNErhBBCCCGEqJamTp3KggULuP/++/n444/R6/XODkkIIYQQotqSRK8QQgghhBBCCCGEEELUcOrSqwghhBBCCCGEEEIIIYSoziTRK4QQQgghhBBCCCGEEDWcJHpLYTabOX/+PGaz2dmhCCGEELWCfLYKIYQQlUs+W4UQQoAkeksVHx9PaGgo8fHxzg5FCCGEqBXks1UIIYSoXPLZKoQQAiTRK4QQQgghhBBCCCGEEDWeJHqFEEIIIYQQQgghhBCihpNErxBCCCGEEEIIIYQQQtRwkugVQgghhBBCCCGEEEKIGk7r7ACEEKKyGI1GUlNTsVgszg5FVBNqtRq9Xk9AQABarXzkCSFEUeTzU5SHRqPB29sbLy8vZ4cihBBCiGtIr1cIUSsYjUYuXLiAVqtFp9M5OxxRTZjNZjIyMjAajdSvXx93d3dnhySEENWKfH6K8srOziY9PR1Akr1CCCFENSOJXiFErZCamopWq6VRo0ZoNBpnhyOqkezsbM6dO0dycrIkeoUQ4hry+SnKy2KxcPr0aVJTUyXRK4QQQlQzMkevEKJWsFgs6HQ66aSKQgwGA3q9HrPZ7OxQhBCi2pHPT1FeGo0GnU4nU30IIYQQ1ZBDE70mk4nJkycTFhaGwWAgKiqKr7/+ukxtFUXho48+IiIiAhcXFyIiIpgzZw6KohSoFxcXx6uvvsqdd96Jt7c3KpWKb7/9tipORwghhBBCCCGEEEIIIaoFh07dMHbsWL788kvGjRtHZGQky5cvZ8iQIZjNZoYNG1Zi2+nTpzN16lSGDh3Kyy+/zKZNm3j22WdJSUlh8uTJ9nrHjh3jnXfeoVGjRrRp04a//vqrqk+rzBRFQaVSOTsMIYQQQgghhBBCCCEqhaIoYLXClXfl6oaC76WUXS0qW/38ZdcOBC37MfOXXbNw7T4rSO3mhsZB0x05LNG7d+9elixZwvTp05k0aRIAo0ePpkePHrz00ksMGjQIvV5fZNv4+HjeeustRo4cyaJFi+xtNRoNs2bNYuzYsdStWxeA9u3bk5iYiL+/Pxs3buSOO+5wzAmW4MSui5zee4nQFn606BLs7HCEEEKIWkVRFC5GG4k9nkJWWi5t7grD3cfF2WEJIYQQQogaTlEUFJMJJSfH9p73yr26jNlUcFtRr7z6ZrOtjdWCYrbY3xWrBfLeLVYUiwUsFtt7Wepcec/bhtUKKCjWqwlYRbGCVbmajFXyL1ttuc38yVql6Lr2RO6V5cpKhtZmvkOHUvf1iQ45lsMSvd9//z1qtZpx48bZy1QqFU8//TQDBgxgw4YN9OrVq8i2v/zyCzk5OYwfP75A+fjx41m6dCm//PILjz/+OACenp5VdxIVdOlcGonn07FaFUn0CiFITU1lyJAhbN26lY4dO7J69Wpnh1RtybUSZXFi50V2/hZtXzcmZnP3qJZodPIoAiGEEEKI2kyxWLBmZmLNyLj6XmA537asLJTsbKy5OSjZOSg52Vizc2xlOXnv2bZt2dlYc3NRsrMlkSlqFIclenfv3k3jxo3x8/MrUN6pUycA9uzZU2yid/fu3bi4uBAVFVWgvG3btuj1evbs2VNpcRqNRoxGo309Li7uuvcZ3sqfM/sTSYrNIC0pG08/w3XvUwhRc33wwQd4e3uTmJiIWi2JqJLItRJlEf3vZQBcPXRkpZtIistg9+poOj7QyMmRCSGEEEKIkigWCxajEavRiMVoxJJqxGpMtS9bjKm2balXthtTsaal2xO5Sna2s09BOFL+6VDzlksoKzB5ahnqF9m2EqZgVel1172PsnJYojc2NpZ69eoVKg8ODrZvL6ltUFBQoU6+Wq0mKCioxLblNXv2bKZNm1Zp+wOo29gbvUFDbraFc4cu07Jr/UrdvxCiZlm7di2vvPKKJC7LQK5V6UwmEzNmzGDJkiUkJCQQERHBq6++yuDBg0ttqygKc+bM4ZNPPuHs2bOEh4fz9NNPM378+AJzyq9du5aPPvqI/fv3c+nSJXx9fWnTpg1vvPEGXbp0qcrTK1VOlpnEmDQA2vduwOUL6RzZEseJXQk0iAqgTrhj5sISorYwWaxcNDqm0xzkZUCnKdvPd0VReP/99/n0009JT09n2rRpzJo1i5UrV9KmTZuqDbQs8ZlMmBMSHHY8bZ06qHRl6zRW92snhKh9rDk5mBMSMCcmYklKwnz58pX3JCyXL2NOSsJyOdG2npxcc0bM6nSoyvPSaguu623vaLWoNFpUGjWoNai0moLvGg1o1Kg0Wvt70XXVqLRa23v+Oho1qNWgUqNSq2yJyivrqEClvrId1TXbVbY+wLXbr2xDpbK1vVJHpSq4b5WKq3Xz9d+KTJiWJUkrz7iqEIclerOysnBxKTxfnlqtRqfTkZWVVe62AAaDocS25TVhwgRGjx5tX4+Li6Njx47XtU+NRk1Icz9O773E2YOS6BXiRmU2mwkMDCQlJYVHH30UT0/PSrlroDaSa1V2jnjQ6ZEjR3BxceGpp56iTp06JCcn89VXX9GtWzd+/fVXevfuXdWnWazYEykoiu33wXqNvQlt5kvsiRRSE7L4d8N57hzRwmmxCVETXTRmM2zRDocc68tRHQnxdStT3TfffJNff/2V9evXExQURP/+/UlISKB58+ZVHGXZmBMSODdqdOkVK0nYooXo6petT1Hdr50Qomax5uZiunABU2ws5viLmBMuYoq/iDk+HlNCAub4eFvy1sFUbm6o3dxQu7uhdnO3vbu7Xym78u7qhtrggkrvgsrggtpgQOVisJXlvRsMqFyu2WYwoNLrJfEoagSHJXoNBgM5OTmFyq1WKyaTCYOh+OkMimsLkJ2dXWLb8vLy8sKrCp6EF97Sn9N7L5Ecn4kxMQuvANdKP4YQ4qrqOCJJq9Wya9cu2rRpg9ForDa/KFTHUUjV9VpVN4560OkzzzzDM888U6D9U089RaNGjfjvf//r3ETvcVtHIjDcE73B9mtNVPcQNn9/gvgzRhLOGmVUrxA1XFJSEu+88w7bt2+nYcOGADz00EOcO3cOFxcXJk2axMaNG/Hz82PZsmVV8rt8TVXatdu6dSuTJ0/GYrHw8MMP8+STTwLINRXiBqYoCpaUFEznzpEbcx7T+RhyY2IwnYsh9/x5zPHxlT8CV61G4+mJ2tsbjZcXGi8v1N5eaLyurHt7ofa6su7thdrD05a8db+S0HV1tY2CFUI4LtEbHBzM2bNnC5XnTbuQN4VDcW3XrVuH1WotcPuu1Wrl4sWLJbatLuo28kLvqiU3y8yWH09y5/Dm9g6pEKLyVdcRSfv27SMyMtKeuKwOHanqOgrp2mtVVGe0Olw/Z3LUg06L4ubmRkBAACkpKZVyLhVhtSrEnrAdv36Er708tLkf3nVcZVSvEBUQ5GXgy1HXdzdbeY5VFuvXryckJISWLVvay5KSkoiKiuLgwYMcOXKEzZs3s3TpUubOncvEiY55qnV+2jp1CFu00KHHK4uSrl1OTg5vv/02K1euLDBwprpcUyFE1VIUBfPFi+ScPEXu6VPknDxFzqlT5J46haUSfr9Tu7ujDQxE4++P1s8Pjb8fWv8A27ufP1p/P/s2tZeXbUoAIcR1c1imsV27dqxfv56kpKQCD2T7559/7NtLartw4UIOHDhQYB6pvXv3kpubW2Lb6kKtUdP+nnC2/XyKpNgM1n95lC4DmsiD2YS4wezbt8/+c0w6UiXLf62K6ozK9XP8g05TU1MxmUwkJCTwxRdfcOjQIV577bVS46yKB50CXL6QTm62BSiY6FWpVUR2D+HvK6N6k2Iz8At2r5RjClHb6TTqMn956SiJiYmFBnasWLGCBx54gL///pt7770XgN69e5c6ZU1VUel0ZZ5KwZFKunbbtm3DYDDQr18/AObMmUOTJk2qzTUVQlQexWIh98wZsg8fJvvQIbIPHSb76FGs6ekV2p/a0xNd/fro6tZFWzcIXVAQ2qC66OoGoQ2yvTQeHpV8FkKIsnBYovfhhx/m3Xff5dNPP+WNN94AbN8gzZ07l8DAQO644w7A9stIYmIiYWFhuLnZfsl88MEHee6555g7dy4LF179pvzjjz9Gr9fz4IMPOuo0rkujNoGYcizsWhXN5Qvp/PbJAaJ6hNDslnqo1XJbshCVqTqOSALbF1R9+vQBqDYdqeo6Cin/tSqqM1pdrp8zOfpBp/fddx9btmwBwMXFhccff7zAXL7FqYoHnQIYL9nm6Ne5aPAKKPj/YWhzPzx8XUhPzuHItji69G9S6ccXQjhGixYt2LFjB0eOHCE0NJS3336b7du38/rrr7N//34iIiIA8PX1JSkpycnRVi8lXbu4uDiOHDnCjh07OH78OM888wyrVq0iKSlJrqkQNZwpPp7M3bvJ2rvPltg9ehSlPM82UqvR1auHLjQUfWgIutAw23uIbV3j41NlsQshro/DEr3t27dn6NChTJkyhUuXLtkfGLNx40YWL15sf9ja3LlzmTZtGhs2bKB79+6ArcP6yiuvMGPGDEwmE926dWPTpk0sW7aMyZMnF+rkzpw5E4AzZ84A8PPPP3Py5EkAe5LZWZp2qovBXcfO386Qk2lm75pznD14mc4PNsK3row2EqKyVMcRSWAbpZo3l2p16UhV11FI+a9VUZ3R2267rVpcP2dy9INOP/zwQ5KSkjh37hyLFy8mOzu71Hn2oWoedApgvGybh9vT31BoHme1WkWzznXZ9ftZzv6bSJu7QnH3Lvp8hRDVW7du3Rg1ahSdOnXC39+fiRMnYjAYaN++PefOnSM1NRWAlJQUfH19S9nbjaWka7d3715uu+02XF1dad26NfHx8YDtM1WuqRA1h6IomM6eJWP7djJ37yFr925MJXzZX4BOh0uDBugbN8alcWNcmjRG36gx+oYNUBfznAchRPXm0EliFy5cSHh4OEuWLGHevHlERESwbNkyHn300VLbTps2DV9fXz755BO+/fZbQkNDmT17Ns8991yhunmJgTzff/8933//PeD8RC9AeCt/ghp6sWf1Wc4cSCQpNoPf5/1Li9uCibojBHUZHuokhKh5EhMTiYuLIzIyEpCOVEmuvVY+Pj6FOqNy/Rz/oNObb77Zvvzoo4/Spk0bRo4cyf/+978S46yqB52mXUn0evkXfZ6N2tXhwIbz5GZbOP7PRdreHVbpMQghHOPDDz/kww8/BGyDOFq0aEHdunXp0qULM2fOZNSoUaxatYquXbs6N9BqqLhr16lTJ9555x2sVivnz5+3f47e6NfUZDIxY8YMlixZQkJCAhEREbz66qsMHjy4xHZHjx5l6dKl/PHHH5w6dQqLxUKbNm148cUX6du3b6H6aWlpTJw4kf/973+kpqbSunVrpk+fzt13311FZyZqE0tqKhnbtpOxZQsZW7diunCh1DZqNzdcWjTHtWVLDC1aYGjRAn3Dhqi08uwgIWoTh/4frdfrmTFjBjNmzCi2ztSpU5k6dWqhcpVKxfPPP8/zzz9f6nGUyn4CZBUwuOu4tX8TGrQOYMeK02Sk5nJocyyXz6fT9ZEI9K7yw1aI2iYgIACLxWJfv9E7UiW59loV1RmV6+fcB50aDAb69OnD+++/T1ZWFq6urhU8i4pLu2wbdezpX/SxdXoNTdoHcXhLLCd3X6RV9/ro9PJEZiFqmm3btlG3bl3Cw8PZunUr48eP5/PPPwcgKiqKJk2a0LVrV3x9fVm2bJmTo61eSrp2fn5+DB48mNtvvx2r1crHH38MyDUdO3YsX375JePGjbPfhTpkyBDMZnOJ00QtXLiQBQsW0K9fP8aMGUNOTg7Lli3joYceYsGCBQXubFEUhb59+7J161YmTJhAWFgYS5cupXfv3qxZs4YePXo44lRFDWOKjSXtzz8xrllD1p69YLUWX1mtxtC8Oa7t2+Ea1RpDyxbow8PlgWdC3AAkm+hkwU18uO/p1uz94ywndiUQf8bIHwsPcufwFrh5ya0SQtRmN3pHqjyK6ozK9XP+g06zsrJQFIW0tDSHJ3oVq0Ja0tWpG4oT0SmII1tjyc22cGpPAs06F57TWAhRve3fv58HHniAnJwcGjduzH//+1/uu+8++/ZZs2Y5MbrqrbRrN3bsWMaOHVuo3Y16Tffu3cuSJUuYPn26/S7R0aNH06NHD1566SUGDRqEvpjb2QcNGsSUKVPw9PS0lz311FPcfPPNvPbaa4wcOdL+xery5ctZv349X375JUOHDgVgxIgRREVFMWHCBPbt21e1JypqjNxz5zCu/oO0NWvIPniw+IpaLW5t2uDWqRNu7dthiGqNxkOmhhTiRqRSasLwVyc6f/48oaGhxMTEEBISUqXHOr4znl2/RaMo4Onnwp0jWsh8gkKUUXR0NAANGjRwahyieqqtfz92797NzTffzIwZMwo86LRHjx4cOnSImJgYXFxcinzQaWxsLA0bNmTo0KEFHnQ6YsQIvvnmG6Kjo+1z4CckJFDnmofoJSUlERUVhVqt5ty5c+WKuzI+WzNSc1g+ey8A94xthX/94p/svOXHk0QfSMTdx4U+z7aRB6AKkU9t/fkoqlZt/Xvz2muv8e6773Lp0qUCX6D++OOPDBgwgNWrV9OrV69y7fP1119n1qxZxMXFUbduXQD+85//sGbNGhISEtBort5p8sEHH/Diiy9y9OhRmjZtWq7jOLLfKqqWNTMT45o1pP74E5k7dxZbT9+wIe5duuDe5VbcOnSUxK4QApARvdVKRIe6GNx0/P2/E6Ql5bD+yyPcPboVLjKNgxBCiCI46kGnXbp0oXXr1tx8880EBAQQHR3N4sWLuXjxIt99950zTt0+Py+UPKIXoPkt9Yg+kEhGSg7njyQR1tK/qsMTQghRA+3evZvGjRsXSPKCbQopgD179pQ70RsbG4tWq8Xb27vAcdq3b18gyXvtccqb6BU1X9b+/aT88APGVb9jzcgoso5r69Z43n03nj3vQh8mzx4QQhQmGcRqJqylP101KjZ/exxjYjab/u8oPYY3R6ur2XMK5lhyMOYYyTJnkWPJIdeSS44lh2xLtn3ZYrXNx2nFiqIoKNgGm1sVq33eZY1ag1qlRqPSXH3lL1NfLb+2TKvW2te1Kq3t/UrZte/XPr1dCCGqK0c86HTs2LH8/PPPvP/++6SmpuLn58ctt9zCCy+84LS5kY1X5uc1uGvRG0r+dcYv2J2gBl5cjDZyZFucJHqFEEIUKTY2tsAXnXny5q3PmwO/rE6dOsW3335Lnz59CkxxFBsbyy233HJdxzEajRiNRvt6XFxcuWIT1YNiNpO2di2Xlywhe/+BIuu4tm+P17334tnzLnRBQQ6OUAhR00iitxoKbeZHxwca8c+K01yKSWfrDye57ZGIanuraVpuGvEZ8VzMvMjFjIu298yLpOakYsw1YswxkmMp+snu1ZVapb6aGM5LAJchQZw/mZy3fm2CuaiEc3HJZrVKbX+pUBW5rFapbXW5Zl2lRo264H6uXS9ivyrVldeVfwD7et61yV+mQoXt38L1JWEuRNVzxINOX3rpJV566aXrDbVS5Y3oLe5BbNdqdms9LkYbSYxJ51JMGoGhnqU3EkIIcUPJysqy3w2Tn1qtRqfTkZWVVeZ9ZWZm8vDDD+Pi4sLs2bPLdByDwWDfXprZs2czbdq0MscjqhdrTg4pP/zA5UWLMMcWTtJr69TB+6GH8HmoL/paNkWKEKJqSaK3mmrSvg5Zabkc2HCemKPJ7FoVTYf7Gjg9cZacnczhy4c5nnycaGM0Z1LPkJydXOH9qVVqXDQuuGhcCiQp7f9ckzC0KlYsigWL1XJ1+cq6RbGVWZUSnj5aRlbFSq4l97r3I2zyJ4rz/3kW9WddXP0S66hU9PDoQds6bTmTeqbgsXH8/zOO+P/UGed1vaoiZpVKRYinzEN3IzJeLv1BbPnVv8kHL38DxsvZHN0aR+AjkugVQghRkMFgICen8AAVq9WKyWSyJ2JLYzKZGDhwIIcOHeK3334jPDy8TMfJzs62by/NhAkTGD16tH09Li6Ojh07lik+4Tz2BO/8BZgvXiy4UaXC484e+D78MO5duqDS1Oy7eoUQziGJ3mqs1e31yTTmcnJ3Aid2XsTD14UWXYIdGkOmKZO9CXvZdXEXBxMPEp8RX2xdrVpLoGsgQe5BBLkF4WvwxUvvhbeLN156Lzz1nrjr3O2JXb1Gj1Zd+X8FFUUplAQ2W82YFTNWxYrZaraVKcW8W832tqXVtdcpqm6+/eSV28vyxZS3bFWs9vW8+BUU+9QVFsViOzeuTmVRU+SfisNWUPnHyHHNQUHBbDVX/s5FtZU3wlzcePJG9HqVMdGrUqtodms9dvx6hpgjSRgTs/AKKNtoYCGEEDeG4OBgzp49W6g8byqFvKkVSmK1Whk2bBh//PEH3333HXfeeWeRxylqqoXyHMfLywsvL69S64nqQVEUjL+tImH2B4VG8Krd3PDu3x+/oY/KvLtCiOsmid5qTKVS0eH+hmSl5XLheAp715zD3ceF8CqeWzDLnMXfF/7m7wt/c+DSgUKJM7VKTUPvhvZXA68GBHsE42fwqxZJF5VKZZsagdr9DWheAtiqWLFiLbxutaKQL0GsWO2J47xX/u32ZDKKPZGsXPnH9u/V8vzJ5kL1r5lfuVD5lX1dW56/nb0s/7Hzz9lcRExeGV4YNAbquNWxX6MCyeVrVVKyucRjOGO/5WhWVbE7krPvchDOYbUqpCeVb0QvQMPWAfy78TxZaSYObDzPbQNuqqoQhRBC1EDt2rVj/fr1JCUlFXgg2z///GPfXponnniCb7/9lgULFtC/f/9ij7Nu3TosFkuBB7KV5zii5sjat4+Lb71N1v79BcrVHh74DR+O37ChaPI9rE8IIa6HJHqrObVaRZeBN/Hn4sMkxWWw9ceTuHnqCQyr3FtOFUXhaNJR1pxdw+bzmwvNqdvMrxlt6rShpX9Lmvo1xVUro6CcTa1Sg4pan9Auq+joaAC8XeSXJCFqO6vZSqvb62O8nI1PkFuZ22l1Glp2rc+uVdGc/fcyrbrWL1d7IYQQtdvDDz/Mu+++y6effsobb7wB2PpJc+fOJTAwkDvuuAOAxMREEhMTCQsLw83t6ufIiy++yIIFC3jnnXcKTKtQ1HG+//57vvnmG/vDU3Nycpg/fz6RkZE0a9asCs9SOIolNZWL771H6g8/FihXu7nhN2IEfsOHSYJXCFHpJNFbA+j0GroPacrq+QfJNOay6Ztj3D26JV5lfABNSRRFYdfFXfzv2P84knTEXq5WqWkf1J5bg2/l5qCb8TH4XPexhBBCiMqg1WuIuiO0Qm2btK/DkS2xZKTmcmDjebo9ElHJ0QkhhKip2rdvz9ChQ5kyZQqXLl0iMjKS5cuXs3HjRhYvXmx/gNrcuXOZNm0aGzZsoHv37gDMmTOHDz74gLZt2xIcHMxXX31VYN8PPfQQ7u7uAPTr14/bb7+dMWPGcPToUcLCwli6dCmnTp1i9erVDj1nUTWMa9YQP2MGlkuJVwvVanz69yfw2WfQBgQ4LzghRK0mid4awtVTzx2PNmPNokPkZJrZ+NVReo1phYubrsL73BW/i2WHl3E69bS9rL5nfe4Ov5s7Qu/A1+BbGaELIYQQ1YZGq6bV7SH8s+I0MYeTSIrNwC/Y3dlhCSGEqCYWLlxIeHg4S5YsYd68eURERLBs2TL7yNvi7NmzB4C9e/cydOjQQtvPnDljT/SqVCpWrFjBxIkTWbBgAUajkcjISFauXMldd91V+SclHMaSkkLc1GmkXZOwd+vQgaDXJ2KQ0dpCiCqmUmraU50c7Pz584SGhhITE0NIiPOf7B53KpUNy46gKBAY6kGP4c3R6sp36358RjwLDixgR/wOe1lTv6Y8HPEwHep2kDkvRY2UN3VDgwYNnBqHqJ7k70f14uzPVqvFyq8f7yc9OYf6ET50HyKdLnHjkp+PoiLk70314+zPVgGZu3Zx4aWXMed70J7ay4ugV17Gu18/6WcLIRxCRvTWMPUae9OpTyO2/3KaSzHp/PXtcW4f1BSNrvSHoJmtZr4/9j3/O/4/+wPWInwjGN5yOJEBkfLBI4SDpKamMmTIELZu3UrHjh3lFr0SyLUSVUGtURPZPYRtP5/iwvEUEs+nERBSuXPfC1FjWUyQFld6vcrgWQ80Zbs7TVEU3n//fT799FPS09OZNm0as2bNYuXKlbRp06Zq4ywDi8VKRkpO6RUribuPCxpN2R6CXN2vnRA1nWKxkDhvHomffApWq73c8+67CXrjdXR16pTQWgghKpckemugxu3qkGnM5cCG88SdTOWv74/TbVBEib/sJWQm8N7O9ziadBQALxcvRrQcwZ1hd9oe6iWEcJgPPvgAb29vEhMTUavl/7+SyLUSVaVBVACHNl/AmJjN7tVnuXtkS1Rq+cJTCNLiYNlDjjnW0J/BJ6xMVd98801+/fVX1q9fT1BQEP379ychIYHmzZtXcZBlk5GSw69z9jvseA8807rMz+uo7tdOiJrMkpbGhRdeIOOvzfYylZsbdSdNwrvvgzKYSgjhcNJrrqEiu4fQqlt9AGKPp/D3dyewWqxF1t16YSvPrH/GnuTt1aAXn9/1OT3De0qSVwgnWLt2LQMHDpTEZRnItRJVRa1W0e7ucAASY9I5cyCxlBZCCGdJSkrinXfeYfHixTRs2BA3NzceeughbrrpJk6fPs1tt91Gt27duO+++0hJSbG3mzRpEl27duXBBx/EaDQ67wScqKRr5+LiwoQJE7j11lu59dZb2b17t72dXDshSpcbHU30I4MKJHldWjSn4Q8/4PNQX0nyCiGcQkb01mBRPUKwWhUO/x3L+WPJ/PXdCW4b0ASt3jZnr1WxsvTQUn468RMAbjo3nmn7DF3qd3Fm2ELcsMxmM4GBgaSkpPDoo4/i6elJXJyDbo+tYeRaCUeo39SX4AgfYo+nsHfNOUKa+aI3yK9G4gbnWc820tZRxyqD9evXExISQsuWLe1lSUlJREVFERAQwG+//Ya3tzeff/458+bN49VXX+XgwYMcOXKEzZs3s3TpUubOncvEiROr6kxw93HhgWdaV9n+izpeWZR07c6cOcOBAwfYunUr//77L1OmTOGnn35y+LUToibK2L6d888+hzU11V7m88gjBL0+EbVe78TIhBA3OunN1GAqlYo2d4WiWBWObI3jwrFk/vziMLcPaYrGFWbvns2WC1sA21y8L3d4mSD3ICdHLYSDVMM5BrVaLbt27aJNmzYYjcZq8y1/dZxXsLpeK1H7tL8nnPhTqWRnmDi46QLteoU7OyQhnEujK/N0Co6SmJhIcHBwgbIVK1bwwAMPEBgYaC/T6XRoNLYBD3///Tf33nsvAL1792bYsGFVGqNGoy7zVAqOVNK1CwgIwNXVFYvFQnJyMgEBAYDjr50QNY1x9WouvPQymEy2Ao2GoNcn4jd4sHMDE0IIJNFb46lUKtreHYaLm5Z9f8ZwOTaD3+f/y78t/mBX1jYA7gi7g/Ftx6NTl+1hF0LUCtV0jsF9+/YRGWl7+OGRI0cYM2YMarUaT09P/u///g8fHx8mTZrExo0b8fPzY9myZXh5eVVp+NV1XsH81wpgwoQJbN++HYCPP/6Y9u3bO/xaidrHy9+VZrfU4/DfsRzdFkfjdnXwDqx+yRohbmQtWrRgx44dHDlyhNDQUN5++222b9/O66+/bq+TlJTEJ598wpo1a+zrERERAPj6+pKUlOSU2J2tpGvn4eFBaGgozZo1IzMzkz///BOQaydESZK//574KVNBUQBQe3sT8uF/cb/lFucGJoQQV8ikh7WASqWiZdf6dOnfBJUaYuJi0a1vgN/lUAZGDOT5ds9LkleIamLfvn32J1zn3W76119/0adPH+bNm1fgdsl+/foxd+5c5wbsRPmvVf7bSz///HPefPNNuVai0rS6vT6unjoUBXatOoNypfMmhKgeunXrxqhRo+jUqRORkZGEh4djMBho3749ANnZ2Tz88MN89NFH+Pv7A7YEZeqVW6pTUlLw9fV1WvzOVNK1W7t2LWlpaRw/fpx169bxzDPPAHLthCjO5YULiZ88xZ7k1darR4NvvpEkrxCiWpERvbVIaCtfYs7twPJXAFqzns6n+tIyJAJrMwWNRm57FjeYajjHIMDevXvp06cPQJG3mzrjdsnqOq9g/mtV1O2lcmupqCw6vYZ2d4ez5ceTxJ82cmrPJZq0r+PssIQQ+Xz44Yd8+OGHAPz888+0aNGCunXroigKjz32GCNHjuS2226z1+/SpQszZ85k1KhRrFq1iq5duzopcucr7trt3bsXf39/VCoVfn5+9oeuybUTorDEefO49OFH9nV9w4aELV6Erl7Z+wFCCOEIkuitRRb+u5C/TWsxtPHgrguP4pHuxdFt8Vw6l0aXATfh6WdwdohCOE41nGMQbKNUJ02aVKAs/+2mn3/+ucNvl6yu8wrmv1ZF3V76888/y62lotKER/oT/W8iF46nsHt1NPUae5f5SwkhRNXatm0bdevWJTw8nK1btzJ+/Hg+//xzAFavXs3KlSuJi4tj/vz59OnThwkTJhAVFUWTJk3o2rUrvr6+LFu2zMln4RwlXbu7776bZcuW0a1bN7Kzs5kxYwaAXDshrpG0dGmBJK+hRQtCFy5A6+fnxKiEEKJokuitJdaeXcvK0ysB6Nb8VkYNvId9a2M49k88ly9k8NsnB2h9ZwhNO9dDrZbRvUI4Q2JiInFxcURGRtrLrr3dVG6XtLn2WuW/vfTYsWOMHz+efv36ybUSlUalUtGpTyNWzt1PbraFbctPceew5qjkM1MIp9u/fz8PPPAAOTk5NG7cmP/+97/cd999ANx7772kpaUV2W7WrFmODLNaKunaaTQavv766yLbybUTwib5u++5+Nbb9nVDZCRhixeh8fR0YlRCCFE8maO3FjiRfIJP930KQGRAJE+1fgqtTsPNvRvQbVAELm5aLGYre/44x5oFB0mOz3ByxELcmAICArBYLLi5uQEUebtply5d+OOPPwBu6Nslr71WFoul0O2lcq1EZXP11HPzfQ0BuHjGyPGdF50ckRAC4IknniAxMZG0tDT27dvHwIEDnR1SjSHXToiKS/31V+KnTrWvuzRtStiC+ZLkFUJUazKit4bLMmfx3s73MFvNBLoF8krHV9CoNfbtoc39CAzzZPfqs0QfSORybAa/z/uXJu3rENUjFIO7PKRNCGcp7nZTuV2ysKJuL5VbS0VVaBDpT8yhy8QcTWbv2nPUa+JdLac2EUIIIUTVydi+ndjXJtofvKZv1Mg2ktfHx7mBCSFEKSTRW8Mt/HchcRlxqFVqXu7wMt4u3oXqGNx1dOnfhIZR/uz49QwZqbmc2JVA9L+XadWtPhGdgtDqNEXsXQhRlYq73VRulyysuNtL5VqJyqZSqej4QCMSzu0nJ9PM5u9O0GtMS/mcFEIIIW4QOadOcf6ZZ8FsBkAXEkLYF4vR+vs7OTIhhCidTN1Qg+25uIc10WsAeLjpwzTza1Zi/eCbfLl/fGta9whBo1NjyrGwd+05Vny4j2Pb47CYrI4IWwghhKjWDB46bu3XGICUi5ns+i3auQEJIYQQwiHMly8T8/gTWI1GANTe3oTOn48uKMjJkQkhRNlIoreGyrXkMm//PAAa+zTmkaaPlKmdVqeh1e0hPPhsGxq1DUSlgqx0E7t+P8svH+3l6PY4TLmWqgxdCCGEqPaCb/KlVbf6AJzae4lj/8Q7OSIhhBBCVCVrTg7nxz2N6fx5W4FOR8jHc3Bp1NC5gQkhRDlIoreG+uH4D8RlxAEwrs04tOryzcLh6qnnlr6Nuf/p1jSItN2CkpVmYvfvZ1n+wR72rjlHpjG30uMWQgghaorIO0Ko18Q2JdLu36OJPZHs5IiEEEIIUVUuzpxJ1r599vV6M6bj3rGj8wISQogKkERvDZSYlcgPJ34AoHfD3tzke1OF9+UV4EqXATfZEr5RAahUkJtt4fCWWJbP3sNf3x4n9kQKVqtSWeELIYQQNYJareK2gTfhHeiKosBf353gUkzhebWFEEIIUbOl/PgjKf/7wb7u/+QT+PTt67yAhBCigiTRWwN9e/RbTBYT7jp3hrYYWin79A50pUv/Jjz4fFuad6mHzkWDokDMkSQ2fHWUFR/u5d+N58lIzamU4wkhhBA1gd6gpfuQphjcdVhMVjZ+dZTk+AxnhyWEEEKISpJ16BDx06bb1z1uv53A8eOdGJEQQlScJHprmAvpF1h7di0AAyIG4KH3qNT9u3u70O7ucB56oR0dH2iIf7A7ABmpuRzYcJ5f/ruXP784zPGd8WSnmyr12EIIIUR15OFr4M7hzdEbNORmW/jzi8Mknk93dlhCCCGEuE6WlBQuPPMsSq5t2kJd/foEv/sOKrWkSoQQNVP5JnYVTvfNkW+wKlZ8Db7c3+j+KjuOzkXDTTcHcdPNQSTHZ3BydwJn9idiyrFwMdrIxWgju36Lpk4DL8Jb+hPSzBdXT32VxSOEEEI4k0+QGz2GNWf9l0fIzbawbulhuj4SQXATH2eHJoQQQogKUBSF2NcmYrpwAQCVXk/9OR+h8fZ2cmRCCFFxkuitQRIyE9h8YTMAAyMGYtAaHHJc37rudLivIW17hnH+WDLnDiUReyIZi1nh4hkjF88Y2bHyDL513Qi+yYf6Eb74h3igVqscEp8QQgjhCP71PbjrsRas//II2RlmNn51lPb3NiCiYxAqlXzmCSGEEDVJynffkb5hg3297pTJuLZs6cSIhBDi+kmitwZZfnI5VsWKp96TnuE9HX58rV5Dg8gAGkQGYMqxcOFYMmcPXbY9rM2ikByfSXJ8Joc2x6I3aAhq6E2dBp7UCfPCp66bJH6FEELUeL513bl7dCs2/t9RjInZ7FoVzaWYNDo90Aidi8bZ4QkhhBCiDHJOn+bi2+/Y1736PIBP//5OjEgIISqHJHprCGOukTXRawC4r9F9DhvNWxydi4YGUQE0iArAnGsh/oyR2BMpxJ5IJiMll9xsCzFHkog5kmSvHxjmSWCYJ/7B7vjV98DFVf76iRtTamoqQ4YMYevWrXTs2JHVq1c7OyQhRDl4+hnoNboVW344QezJVM7+e5nEc2l07tuYuo3kdk9Rs5msJi5lXnLIsQLdAtGpdWWqqygK77//Pp9++inp6elMmzaNWbNmsXLlStq0aVO1gQohahUlN5cLL76Ikp0N2OblrTt5spOjEkKIyiGZthpibfRaciw56DS6Kp2btyK0eg0hTX0JaeqLojTAmJhN3MkUEqKNXDybRm6WGVOO5UoiOMXezsPXBf9gD3zrueFdxw2fOq64e7ugkpG/opb74IMP8Pb2JjExEbU86EGIGknvqqX7kGYc3hLL/nUxZKTmsm7pEcJa+NGmZxiefs79QlaIirqUeYnH1z7ukGN93vNzgj2Cy1T3zTff5Ndff2X9+vUEBQXRv39/EhISaN68eRVHKYSobS7NmUPO4SO2FbWa4PfeReNRuQ85F0IIZ5FEbw2gKAprztpG894ecjveLtV3tJBKpcI70BXvQFea3VIPxaqQmphFwlkjCdFpXL6QTnpyDgDpyTmkJ+dw9tBle3uNTo13gK29dx1XvAJc8fB1wd3HBb1B/rqK2mHt2rW88sorkuQVooZTqVW07Fqfek182PbzKVIuZnLucBIxR5OJ6BBEq271MXiUbbSiEKJ4SUlJvPPOO2zfvp2GDRsC8NBDD3Hu3DlcXFwA2Lp1K126dCE5ORkfHx8AJk2axMaNG/Hz82PZsmV4eXk56xSEENVE5u7dXF602L4e8MTjuLVr58SIhBCicknmrAY4mHiQ2PRYAO4Ov9vJ0ZSPSq3Cp44bPnXciOhQF4CcLDOXL6STFJvB5QvppFzMtCd/LSYrSXEZJMVlFNqX3qDBw9dgS/z6uuDu7YKrhw5XT/2Vlw6NVhJnovoym80EBgaSkpLCo48+iqenJ3Fxcc4OSwhxnfzquXPvE5Gc2p3AgQ3nyc4wceyfeE7sukh4K3+adqqLf30ZKSRqhkC3QD7v+bnDjlUW69evJyQkhJb5HpKUlJREVFSUff3DDz/k5ptvtq8fPHiQI0eOsHnzZpYuXcrcuXOZOHFi5QUvhKhxrNnZxL3+BigKAIaoKAKefNLJUQkhROWSRG8N8Ef0HwCEeYXRzK+Zk6O5fi6uWoKb+BDcxMdeZs61kJqYRWrClVdiFqkJmWSk5OR9DpObbSk2CZxHb9Dg6qXHxU2Li5sOF1cteletbT3v3U2HzqBB56JBp9egddHIg+Jqoeo4x6BWq2XXrl20adMGo9GISiV/74SoLdRqFTd1CKJBVABHtsZyeEscFpOVM/sTObM/Eb9gd8Jb+hPS3Bcvf1dnhytEsXRqXZmnU3CUxMREgoMLxrRixQoeeOABAP788086dOhAQkKCffvff//NvffeC0Dv3r0ZNmyY4wIWQlRLiZ98Qm50NAAqvZ7gt99CpZM7b4QQtYskequ59Nx0tsRuAaBXg161NjGk1WvwD/bAP7jgiCerxUpGai4ZKTlXpnrIJj0lh4zkHDKNuWSl5doTwWBLBudmZ5X7+BqtGp2LGq1eg1aflwRWo3XRoNWpUWvVaDRqNFo1Gp3K9q5Vo9ZeXc7/UmtUqDUqVGoVanXBZdWVdfuyWoVKo0Klotb++TpDdZ1jcN++fURGRhb4s772dlO51VSImkvnoiHqjlCada7H6X2XOL4jnrSkHJJiM0iKzWDv2nP41HGlXhMf+0NKDe7SyawuFEVBUWzvWK+sA4pVAdu/+ZYVlCt1ULC3u3Yd5Zr9FrOt6HpXj5v3C4/91568NvYKV4vtbfMV5luDvGMCGZZs3Lz1ZGearqlUlgtWcKVMzct7DKBxg5vYsWMHe3ftJ6R+KO/Pfo/t27fz4vMvk5WWy4f//Yili5fx64qVZKXl4qLJ5WLcJZo0uYnMtFxctO4kXrpMZlpulcR3fSp4QKXE1eui02vQy0OTRS2T9e/BglM2PP00Lo0aOTEiIYSoGvIJXs1tj9uO2WpGrVLTPbS7s8NxOLVGjaefodiH2ihWhZwsM1lpuWSlma6855KdaSY3y0xOlpncTDM5mSZyMs3kZluK3I/FbMVitkKGuSpPp1QqFQWSwqiuJoBVKiBvWZ2/LF+SOF/dq+uFtxVZLy+AvOW8dfK2kW85f6Vrt1/ZVrCKLc58DVSFdnrNflRFbMsfkyr/qgq80vGp40pGqm0akMzMXKxWaxmu+vXLSs8l05rDNWdcpJ3/7KZVi0gyjVc7mx+8N5t2bduTaczlxNG9HDxwiD9+W8dXXy/jvx98xEsvvlLiPqv864EqPoCqhLWybJIkmaiO9K5amt1Sj6ad6hJ3KpXofxO5cCyZ3GwLKQlZpCRkcWSrbeoWTz8XfILc8QlyxbuOGx6+Lnj4uKB31da4LwCtVgWr2YrVomCxWLGaC79f3ZZXz9bGYrGtW80Klmv2YbVY7fWs1ivJUIuCVVFQLLbkqtV6JfFqVVCseesK1ryEbVHl+dYVq1Lgy+MbRd3WGkKa+ZJxZRqt6qhdZCeGPDKMbnd0xc/Xj+eefgGDi4FmTSL58Yef6XTzrWDWYjFbyUrLRa/Oxc3gSeLFy2QZc7mclIiXhzdZxjIkegV4IIleUasoubnEvf46XOkbGFq0wH/kY06OSgghqoZ8gldzmy9sBqBtnbZ46WVU37VUahUGdx0Gdx2+dUuvr1gVcnMs5GaZMeVYMOdaMOVYrixbr67nWjBfKbdYrnQ4TbZksH3dbMVypTNqMVuxmKzX3UFUFLCYrwzREeVSt7UGTz8D2ekmADytPrzb+kOHHNvD5ENWmqlMdffs2cM9PXuTdWVU0abNG4hs2Yb4+Itkp+eyaeNfdO96J1lpudx+aw/GPfs4T499virDr9HyfgYIUV2p1CqCb/Ih+CYfrBYrF88YuXA8mYvRaaRczAQgLSmHtKQcYo4UbKvRqXHz1KF31dmnINK7atEZNFfvItGp0Ghsd5JQXFJYsSVXbS9bovTq+pWX1VpgvcgEqyVfAtZste/PYr667xsxUepohf6Y877ELbBNle+L0/xVVYXKbH93bH9Xiz1GqUHZ/lPmZhX4/uLtWe/x9qz3AFi5agVNmzYjuH49vv7+CJv/3simzes5fPQQTz03lu+++oFbO9/C+x++y4hhj7F+01puvaULWl05n+VQ4e9ZqscXNBX9nkijqR7xC1FZLi9aRM7x47YVrZZ6s95EpZVUiBCidpKfbtVYak4q+xL2AdA1pKtzg6klVGqVba7eKhqlkDeSyZI34ujKaCGLxWpftuaNPLLkH1WkFBiZlNcJB+wjjK7evmnbL/bbSYu6/fNqG+XKoNa85bxbPK9uu1LXmu+W0Px3eypX7/cscLsoV9sW2F7SraXkLeeTF2++RlfbXhPTtev56rp65aBWq9C5aADQoSHUNaToP6hrVH1i4uoBDh3+l1defA2t3tbZXLj0cxbNW8KfG9ag0akxGlNo0rgJWp2agAB/UlKTS+yYVknoDkzUKMWtlTGGmjbaUdzY1Bo19Zr4UO/KHPW5WWYuxaSRHJdBysUsUi5lYryUZf+ZZDFZSUvKAarvSMvrpdaoCkx5ZFtWX1lWodaor75rbO8q+/RH2KdBUqmuTJOUd9eL+popkvKXqQpPnWTfj/2OGuzvKq7eSQNcXS5tWzF33JR1W96dMPkTsFUh+sp8lX713Kv0ONdj27Zt1K1bl/DwcLZu3cqrb7zI559/jk8dN2a+NdVer3v37nz7/dd4+7jRpU4nflvblPv798LX15dly5bh7e3mvJMQQjhF7rlzJH42z74eMHYMhmY1/7k3QghRHEn0OlpGImyebcvO3fIU+IQVW3Vb7DasihWtWkvnep0dGKSoKLVahfrKPL/CsfI6ql4B1fchR4mJicRfjOeWbh1wc3NjxYoV3NWzB/XCA9Dq1HgHulEvtA4msvGu40ZiYiIBdfzxriMdUyFqI72rlvoRvtSP8LWXWa0KWcZcMlJzyEjJISvdVHAqoit3pOR9qWjJN81BSfISqeorydK8ueLty9ds02hUqLV5y1fmpLcnYK+0yVdmm8v+Svui6uYlbLVX95+XoBWiNPv37+eBBx4gJyeHxo0b89///pf77ruvUL2NGzcWWJ81a5aDIhRCVEeKohD/5psoubY76fQNGuD/xBNOjkoIIaqWJHod6fxuWP0KZCXb1k9vhNueh9aPFFk9b9qG9kHtcddV31EWQoiyCQgIwGK5Ok/0v//+y/r161mzZg0HDhxg6NChvPnmm8ycOZNRo0axatUqunaV0fxC3EjUahXuPi64+7hAuLOjEaJ6eOKJJ3hCkjNCiHJK+/NPMjb9ZV+vO3kSar3eiREJIUTVk0SvoygKrJ1sS/JqXUHvBpmX4a/3wDsEGnQpUD3DlMHBxIMA3Fb/NmdELISoYq+//jqvv/46YLvddNmyZfj4+NCkSRO6du1qv9VUCCGEEEIIUXbWzEwuznrLvu7Vuzfut97qxIiEEMIxJNHrKJeOQXq8bbnPHAi4CX4aC4nHYc3rMOhr8Aq2V9+bsBerYkWlUtE+qL2TghZCOEr+203lVlMhhBBCCCEqLvGzzzDHxQGgdnenziuvODkiIYRwjHI+elZUWLRtGgbcAyG4Lbh4wr3vgN4dctJg3fQCT4TaEb8DgGZ+zfDUezojYiGEEEIIUcup1WqsVquzwxA1jNVqRa2WrqSonnJOn+HyF0vs64HPjEcXVMd5AQkhhAPJp7Oj5CV6G9wGeQ8e8QmDO2y3bXN+JxxdCYDFamH3xd0AdKzb0dGRCiGEEEKIG4RerycnJ4fs7GxnhyJqiOzsbHJyctDLXKeimkp45x0wmwFwadoU3yFDnByREEI4jkzd4AgZl+HiIdtyg2serHTT3XD0Nzi7BTbPhvAunMhOwJhjBKBD3Q4ODlYIIYQQQtwoAgICMBqNnDt3ThJ3okxyc3PRaDQEBAQ4OxQhCknf/DfpmzbZ14MmTkSllbSHEOLGISN6HeHs37Z3jR5Crxmhq1JB99dsD2jLMcLmD9gVvwuAQLdAwjzDHBysEEIIIYS4UWi1WurXr4+bm5uzQxE1hJubG/Xr10cryTNRzShmMxffftu+7tmzJ+6d5A5ZIcSNRT6dHSH6SqK3fnvQuRbe7lUPbnkKNn8Ax1ez3xILQLs67VDlTfMghCiRRqMhOzsbi8WCRqNxdjiiGsnOziY3N1eSGEIIUQx3d3fc3d2dHYYQ1YLJZGLGjBksWbKEhIQEIiIiePXVVxk8eHCpbWfNmsWOHTvYsWMHcXFxPP7448ybN69QvSVLlvDYY48VuY8TJ07QpEmT6z6PG1Hyt9+Re+oUACqdjjovv+TkiIQQwvEk0esIUYPAIwiCWpVQ5xE4torMhMMcT9gPXvVoHdjacTEKUcN5e3uTnp7O6dOn0el0zg5HVBNWq5WcnBw0Gg2+vr7ODkcIIYQQ1dzYsWP58ssvGTduHJGRkSxfvpwhQ4ZgNpsZNmxYiW1ff/116tSpQ4cOHfjtt99KPdbUqVNp3LhxgbK6deteV/w3KktKCokff2xf9xsxHH1oqBMjEkII55BEryOEtLe9SqLWQI9JHPphMFZrLmSnEBUY5Zj4hKgFvLy8AEhNTcVisTg5GlFdaLVa3NzcCAgIkFtMhRBCCFGivXv3smTJEqZPn86kSZMAGD16ND169OCll15i0KBBJc5lffr0aRo2bAhQpjsze/XqRefOnSsn+Btc4mefYUlNBUATEID/4487OSIhhHAO6fVWJ4FN2R/aGhJ20iArHe/UWKjj7eyohKgxvLy87AlfIYQQQgghyuP7779HrVYzbtw4e5lKpeLpp59mwIABbNiwgV69ehXbPi/JWx5paWm4ubnJ1GPXIff8eZK+/sa+HvjsM2g8PJwYkRBCOI9DH8ZmMpmYPHkyYWFhGAwGoqKi+Prrr8vUVlEUPvroIyIiInBxcSEiIoI5c+agKEqhumlpaYwfP566devi6upK586dWbNmTWWfTpU44OoKGh1RZhWsnwlWGZkohBBCCCGEEFVt9+7dNG7cGD8/vwLlnTp1AmDPnj2VeryePXvi5eWFm5sbvXv35vDhw2VuazQaOX/+vP0VFxdXqbHVJJc+mgMmEwAuNzXBp18/J0ckhBDO49BE79ixY3nzzTfp27cvH3/8MaGhoQwZMoQvv/yy1LbTp0/nueeeo3PnznzyySd06tSJZ599lhkzZhSopygKffv2ZeHChYwaNYoPP/wQtVpN7969Wb9+fVWdWqVIzUnljPEcuPoRZVHBpaOwr2yJcCGEEEIIIYQQFRcbG0u9evUKlQcHB9u3VwY3NzeGDx/O3Llz+fnnn3n11VfZvHkzt956KydPnizTPmbPnk1oaKj91bFjx0qJrabJPnwY46+/2tcDJ0xAJaOjhRA3MJVS1JDYKrB3717atWtXYL4jRVHo0aMHhw8fJiYmptj5juLj42nQoAFDhgxh0aJF9vIRI0bw7bffEh0dbZ+0/ueff6Zfv358+eWXDB06FICcnByioqJwdXVl37595Yr7/PnzhIaGEhMTQ0hISAXOvOy2XtjKWzveQqVS8Y2hBe4n1oLWFQZ/C95Ve2whhBDCURz52SqEEEKUVePGjWncuHGRd4Pq9XqGDRvGwoULy7QvlUrF448/zrx588pUf+fOnXTu3Jn//Oc/fPXVV6XWNxqNGI1G+3pcXBwdO3a84T5bz40aTcaWLQC43tye8GXLyjQ/shBC1FYOG9Fb0nxHCQkJbNiwodi2v/zyCzk5OYwfP75A+fjx48nJyeGXX34pcBw/Pz8GDx5sL3NxcWHs2LHs37+fY8eOVeJZVa7DSbZbdRp6NcS928tg8AZzFqx+Dcy5To5OCCFEdeSIaZHWr1/P6NGjadasGa6urvj7+/PQQw9x4MCBqjglIYQQwikMBgM5OTmFyq1WKyaTCYPBUGXH7tChA127duXPP/8sU30vLy9CQkLsr6JGItd2GVu32pO8AEEvvihJXiHEDc9hid7rme9o9+7duLi4EBUVVaC8bdu26PX6Am13795N+/btC01mX9Z5lZw519GRy0cAaO7fHNz8oMcbtg0Jh+Hv2Q6LQwghRM3hiGmRXn75ZdauXct9993Hxx9/zHPPPce2bdu49dZb2b9/f1WdmhBCCOFQwcHBRfb/8qZsyJvCoaqEhYWRlJRUpceoLRRFIeGDq31kz549cW3TxnkBCSFENaF11IGuZ76j2NhYgoKCUKsL5qXVajVBQUEF2sbGxnLLLbdU6Dhgm+to2rRpJdapCjmWHE6lnAKghX8LW2HjHtD2Udj7Ffz7P/BvApEDHB6bEEKI6mnv3r0sWbKkwLRIo0ePpkePHrz00ksMGjSoxGmR3nrrLUaOHGmfFmn06NFoNBpmzZrF2LFj7dMiffDBB3Tt2rXA5/CQIUNo1aoV06dP58cff6ziMxVCCCGqXrt27Vi/fj1JSUkFBij9888/9u1V6fTp09SpU6dKj1FbpP35J9mHDtlWNBoCn3/euQEJIUQ14bARvVlZWbi4uBQOQK1Gp9ORlZVV7rZgu70mf9vi6ubdZlPScQAmTJhATEyM/bVjx44S61eWE8knsCgWAJr7Nb+64ZbxUL+9bXnTuxC9pYjWQgghbkSOmhbp9ttvL/Rla6NGjWjbti2H8jpZQgghRA338MMPY7Va+fTTT+1liqIwd+5cAgMDueOOOwBITEzk6NGjZGZmVug4RY3aXbduHVu2bKFXr14VC/4GolitJM752L7u/VBfXBo1dGJEQghRfThsRO/1zHdUXFuA7OzsAm2Lq5udnW3fXhIvLy+8vLxKrFMVDl+2zc8b4BpAoFvg1Q0aLfR+H354DJKj4feX4YE5ENLe4TEKIYSoXsoyLVJxHcbyTItUnLi4uDLdxlrUA2OEEEKI6qZ9+/YMHTqUKVOmcOnSJSIjI1m+fDkbN25k8eLF9gFFc+fOZdq0aWzYsIHu3bvb2y9btoyzZ8/a1/fs2cPMmTMBGDp0KOHh4QDcdttttG3blqioKHx8fNi3bx+LFi0iKCiIqVOnOux8ayrj77+Tc+KEbUWnI/DJJ50bkBBCVCMOS/QGBwcX+NDLU5b5joKDg1m3bh1Wq7XAiCKr1crFixcLtHX2vEoVlZfobe7fvPBGg5ctufvjaMhIgF+fhQc+hJCbHRukEEKIasVR0yIV5euvv+bMmTM8X4ZbJZ01LZIQQghRXgsXLiQ8PJwlS5Ywb948IiIiWLZsGY8++mipbRctWsSmTZvs6zt37mTnzp2ALbmbl+jt168fq1atYtWqVWRkZFC3bl1GjBjBlClTqF+/ftWcWC2hmM0kzv3Evu47cAA6uWZCCGHnsKkb2rVrx6lTpwrdplKW+Y7atWtHTk5Ooad77927l9zc3AJt27Vrx549e7BYLOU+jrNYFStHk44C10zbkJ93fXhoHrgFgDkLfnkajv/hwCiFEEJUN46aFulax44d46mnnuLmm2/miSeeKDVOZ02LJIQQQpSXXq9nxowZxMTEkJOTw7///lsoyTt16lQURSkwmhdg48aNKIpS5Ct/3ZkzZ7Jnzx6Sk5PJzc3l3LlzzJ8/X5K8ZZC6ciW5Z84AoHJxwf/x0n8PEUKIG4nDEr3XM9/Rgw8+iF6vZ+7cuQX2+fHHH6PX63nwwQcLHOfy5ct888039rKcnBzmz59PZGQkzZo1q6pTrLC4jDgyTBlACYleAN9w6DcfvILBaoI/JsLWuWC1FN9GCCFEreWoaZHyi42N5Z577sHHx4fly5ej0+lKjdPLy4uQkBD7q6hRyEIIIYQQJVFMJhI/uZpP8B00CF2QPLxOCCHyc9jUDdcz31FwcDCvvPIKM2bMwGQy0a1bNzZt2sSyZcuYPHlygQ5jv379uP322xkzZgxHjx4lLCyMpUuXcurUKVavXu2o0y2XE8m2+YW0ai3h3uElV/YNhwFLYOWzkHAEdn8BcfvgzingE1rlsQohhKg+HDUtUp7Lly/Ts2dPMjMz2bx5s4w8EkIIIYTDpK5YgSkmBgCVqyv+Y8c4OSIhhKh+HDaiF2zzHU2cOJGffvqJcePGcfbsWZYtW8Zjjz1Wattp06Yxe/ZstmzZwlNPPcXWrVuZPXt2ocnqVSoVK1asYNSoUSxYsIBnn30Wk8nEypUrueuuu6rozK7P8eTjADT0bohOXfrIKNz9of8iaNnPth67F775D+xeCubcKoxUCCFEdeKoaZHA9kC1Xr16cf78eVavXk1EREQlnYUQQgghRMkUs5nE+fPt636PDkHr7+/EiIQQonpSKYqiODuI6uz8+fOEhoYSExNDSEhIlRzjpU0vcTTpKPc1uo8nWpdzjqETa2Hj25CdYlv3qg83PwZNe4O26LkXhRBC1A67d+/m5ptvZsaMGbzxxhuAbVqkHj16cOjQIWJiYnBxcSExMZHExETCwsJwc3MDbKN+GzZsyNChQ1m4cKF9nyNGjOCbb74hOjrafsdMVlYWvXr1YufOnaxevZrbb7/9uuJ2xGerEEIIcSOp7Z+tqb+uJPallwBQGQw0Wb8OrZ+fk6MSQojqx2FTN4iima1mTqWeAuAm35vKv4ObekL99rD1YzjyKxgvwPqZsO1TiHoYWvUHN/kAFKJC8r4Hs38fVsT3YtduK229Qm0pWF6WusW1LVBezuOWpW2p8RV1DtcerojjFFeuUkNABX521hKOmhZpyJAhbN68mcGDBxMTE8NXX31l3+bh4UHfvn0dedpCCCGEuIEoViuX539uX/d95GFJ8gohRDEk0etk54znMFlMAET4VPA2WDc/uGsKRD0CuxbDqfWQlQT/zIMd822J4CZ3QaM7bNM+iKpntYAlFywmUCxgtdoeoGe12NYtJlCsYDVfeVmuvMz5tpfUzmIrz1tWrLYkmGIFlHxl+V9KEWVF1bHkq2vJV2a15dYUS+G6KMW8U3w5ytX9FCovyz5KeM9rk3c97PnBMrYVNYfOFZ7429lRONXChQsJDw9nyZIlzJs3j4iICJYtW1boCeFFmTZtGr6+vnzyySd8++23hIaGMnv2bJ577rkC9fbs2QPA119/zddff11gW3h4uCR6hRBCCFFl0tatI+fESQBUOh1+I0c6OSIhhKi+ZOqGUlT1LTCro1fzyd5PcNW68u3936JWVcK0ySnnYN83thG+5qyC23wbQnBb26tOM/AOA00tzfcrii3ZasoCc7btZcq2XRNT9tWy/OXmHNu6xXQlUZtbhuX8ZWbbu2Jx9tkLcWOQRG+NVNtvLxVCCCEcrbZ+tiqKQvSAgWQfOgSAz8MPU2/6NCdHJYQQ1VctzfDVHCeSTwDQxKdJ5SR5AXzCoPsrcOvTcGYznPwTzm4FSw4kn7G9Dv1kq6vWgW84+ISDRx3byz3vPQD0HraXVl85seWxWq4kYHMKJ1wtOfm2ZdkeMGfKzLeeczV5W6Be/vIr+6vJIzRVGtDobLemqzW2Pyu1BtRa27tKc+VdbVtWqa4sl/QqrY7q6vEoalltW847bl65SnVlWXUldnXBdchXJ1/9AuWqq+2uLS+w76LKi3u/NhaVffclxmJfpvj1An9WpdQtsS0l11Vd2+baa1qBtgXKy9r22mNez3FLOIcij1PS9mLiFEIIIYQQNV7G31vsSV40GvzHjHZuQEIIUc1JotfJTqXY5udt4tuk8neud4em99heuZkQu/fq6+KhK1MCmODySdurJGqdbX861yuJRu2VxJ/WlvxTXUn8Wc222+XtUw+Yr05JYLXYkrjmHNu6s6k0oDOA1mA7L+2VZa2LLcGq0V955VtWa4suL7b+lcSsRnc1MZs/UWtP5Gqvludd2+ISbUIIIYQQQghxA0j8fJ592fv++9CHhjoxGiGEqP4k0etEFquFs8azADTyblS1B9O7QYMuthfYErEp5yDpNFw+BakxkHEJ0hNs7+bsgu2tJshOsb0cQa27moTVuhSRjDVc2e6ar15eHRdbudblapv85XntauuUFUIIIYQQQghRw2UdOEDWrt22FZUK/7FjnRuQEELUAJLpcqLz6ecxXxnZWuWJ3mtpdODf2Pa6qWfBbYoCOWmQmWgbCZybAbnptndT5tUHg9kfHpbvYV0FRvrmf6mvjobNP3JW52oru3ZErVrj2OshhBBCCCGEEKLaSFqyxL7scccduDRu7LxghBCihpBErxOdST0DgE6jo75HfSdHk49KBQYv20sIIYQQQgghhHAg04ULGP9YY1/3GzHcidEIIUTNUUlP/xIVcTr1NAANvBqgkRGsQgghhBBCCCEESV8uA4sFAEPLlrh16ODkiIQQomaQRK8TnU65mugVQgghhBBCCCFudJa0NFJ++MG+7jdiBCp5ULUQQpSJJHqdRFEUzhhtUzc08nHw/LxCCCGEEEIIIUQ1lPK/H7BmZACgrVsXr3t6OTkiIYSoOSTR6yRJ2UkYc4yAEx7EJoQQQgghhBBCVDOKyUTSsmX2db+hQ1HpdE6MSAghahZJ9DpJ3vy8IFM3CCGEEEIIIYQQaX/+iTkuDgC1mxs+Awc4OSIhhKhZJNHrJNHGaACC3INw07k5NxghhBBCCCGEEMLJkr/9zr7s3b8/Gi8vJ0YjhBA1jyR6nSQmLQaAcM9wJ0cihBBCCCGEEEI4V87p02T+84993fc/g5wYjRBC1EyS6HWS82nnAQjxDHFyJEIIIYQQQgghhHOlfHd1NK9bx464NJJn2QghRHlJotcJFEWxj+gN8wxzcjRCCCGEEEIIIYTzWLOySPl5uX1dRvMKIUTFSKLXCRKzEsk2ZwMyolcIIYQQQgghxI3N+PtqrEYjAJqAADzvvNPJEQkhRM2kdXYAN6K80bwAoZ6hToxECCGEEEIIIYRwruRvv7Uv+/Tvj0qvd2I0orZTFIXExESys7OxWCzODkfcoNRqNTqdDi8vL9zd3Sttv5LodYLz6bb5ef0Mfrjp3JwcjRBCCCGEEEII4RxZhw6RfeCAbUWlwvfhgc4NSNRqiqJw4cIF0tLS0Ov1aDQaZ4ckblBms5msrCxSUlLw9PQkODgYtfr6J16QRK8TnDOeAyDMS+bnFUIIIYQQQghx40r57nv7ske3bujq13diNKK2S0xMJC0tjTp16uDv7+/scMQNzmq1kpiYyOXLl0lJScHPz++69ylz9DpB3tQNMm2DEEIIIYQQQogblTUnB+OqVfZ1n0GPODEacSPIzs5Gr9dLkldUC2q1msDAQHQ6Henp6ZWzz0rZiyiX82m2qRtCPORBbEIIIYQQQgghbkzpGzZgvZLc0AQE4NG1q5MjErWdxWKR6RpEtaJSqdBqtVit1krZnyR6HSw1JxVjru1pojKiVwghhBBCCCHEjSp1xa/2Ze/7eqPSyuySQghxPSTR62BxGXH25fqeMveQEEIIIYQQQogbjzk5mfS//rKve/Xp48RohBCidpBEr4MlZCYAoNPo8HXxdXI0QgghhBBCCCGE4xlXrQKzGQB9k8YYWrRwckRCCFHzSaLXweIz4gGo41YHlUrl5GiEEEIIIYQQQgjHM+aftuGBPtI/FkKISiCJXgfLG9Eb5Bbk5EiEEEIIIYQQQgjHy42OJmv/fvu69wP3OzEaIYSoPSTR62AXMy8CkugVQgghhBBCCHFjSv11pX3ZrUMHdMHBToxGCFGcDh06MGfOHPv61KlT8fDwcGJEleva87l2ffny5Xz66acV2vfMmTPp2bPndcdYXpLodbCLGZLoFUIIIYQQQghxY1IUhdRf803b8KA8hE2I6uinn37i7NmzjBkzxl42evRoNmzY4MSoqta153c9id6nn36af/75h/Xr11dWeGUiiV4HsipWLmVdAiDIXRK9QgghhBBCCCFuLNkHDmA6dw4AlV6P5913OzkiIURRPvzwQwYPHoyrq6u9LCQkhA4dOjgxqqpVmefn4+PDQw89xEcffVQp+ysrSfQ6UFJ2Emar7amiddzqODkaIYQQQgghhBDCsYxr1tiXPW6/HY2XlxOjEaLmMxqNPPvss9StWxdvb2+GDx+OxWJh1KhRBUbjlsfp06fZvHkzAwYMKFB+7dQGI0aMoFWrVmzcuJG2bdvi7u5Ox44d2b17d6F9btu2jR49euDu7o63tzeDBw8mISGhQJ0FCxbQoEED3NzcuPPOO9m+fTsqlYolS5bY63Tv3p377y84r/euXbtQqVRs3LjRfqw+ffoQHByMu7s7bdq0YdmyZaWed/7zGzFiBEuXLuXQoUOoVCpUKhUjRoxgxYoVqFQqTpw4UaBtamoqbm5uBaa6GDhwIKtWreLSpUulHruyaB12JGF/EBvI1A1CCCGEEEIIIW4siqKQtmatfV1G84rqQFEUTFaTs8MAQKfWoVKpylw/PT2dbt26kZOTw6effsqlS5d48skn6dChA99++y0HDx6sUBzr1q1Dp9OVaXRrfHw8zzzzDK+++ipeXl68+uqrPPTQQ5w6dQqdTgfYEq/du3end+/efPfdd2RkZPDGG2/Qp08ftm/fDsDKlSsZO3YsI0aMYNCgQezatYtBgwZVKP6zZ8/SpUsXnnjiCQwGA1u2bGHUqFEoisKwYcPKtI9JkyZx6dIljh49yv/93/8BEBgYSIMGDahfvz6LFy/mrbfestf/5ptvsFqtPProo/ayLl26YDab2bhxIwMHDqzQuZSXJHodKG9+XheNC156+dZSCCGEEEIIIcSNI+fYMUwxMQCodDo87uju1HiEADBZTbT/qr2zwwBg96O70Wv0Za7/zjvvcPz4cY4dO0ZoaCgA8+fP54033mDQoEE0bNiwQnHs2rWLiIgIXFxcSq2blJTEpk2baNmyJQAGg4GePXvyzz//cNtttwHw6quvcvPNN/PTTz/ZE9mtWrUiMjKSVatW0bt3b2bOnEnXrl354osvAOjVqxcZGRkFkqlllT9BrCgK3bp14/z588ybN6/Mid7GjRsTGBjI2bNn6dy5c4Ftjz32GIsXL2bmzJloNBoAFi9eTN++ffHz87PX8/X1JSwsjH/++cdhiV6ZusGBLmZeeRCbe1C5vqERQgghhBBCCCFqurR80za433ormny3gAshym/x4sUMGzbMnuQF29ywGRkZTJw40V42atQo6tevj0qlwmw2l7rfuLg4AgMDyxRDcHCwPckL0KJFCwDOnz8PQGZmJlu2bGHgwIFYLBbMZjNms5mmTZtSr149du7cicViYffu3Tz00EMF9n3t1BFllZyczDPPPEN4eDg6nQ6dTsf8+fM5fvx4hfZ3rVGjRhEXF8fq1asBOHjwIDt37mTUqFGF6gYEBBAfH18pxy0LSfQ6kD3RK9M2CCGEEEIIIUS1YzKZmDx5MmFhYRgMBqKiovj666/L1HbWrFn07duX4OBgVCoVTzzxRLF109LSGD9+PHXr1sXV1ZXOnTuzJl8StLZKW5t/2oaeToxEiJrv5MmTxMbGcvc1U6CYTCYGDx5M48aN7WXDhg1jz549Zd53dnZ2mUbzgi2xnJ9er7fvA2xJV4vFwvPPP29Puua9YmNjiYmJ4dKlS5jNZurUKfg8q6CgiuXPRowYwTfffMOLL77ImjVr2LlzJyNHjrTHdL0aNGhAz549WbRoEQCLFi0iPDycO++8s1Bdg8FAVlZWpRy3LGTqBgfKm6NXHsQmhBBCCCGEENXP2LFj+fLLLxk3bhyRkZEsX76cIUOGYDabS73d9/XXX6dOnTp06NCB3377rdh6iqLQt29ftm7dyoQJEwgLC2Pp0qX07t2bNWvW0KNHj8o+rWohNyaGnBMnbStqNR619DxFzaNT69j9aOGHhzmDTq0rc93Tp08DtqRjnpMnT7Jt2zb69etXoO7tt99erjj8/PyIjo4uV5vi+Pj4oFKpmDhxIn379i20PSAggMDAQLRabaGHs128eLFQfYPBQG5uboGypKQk+3J2dja//fYbH3zwAePHj7eXW63W6zyTgsaMGcPgwYO5cOEC//d//8e4ceNQqwuPp01OTi4w4rmqSaLXgfISvTKiVwghhBBCCCGql71797JkyRKmT5/OpEmTABg9ejQ9evTgpZdeYtCgQfaRakU5ffq0fT7MkqbqW758OevXr+fLL79k6NChgG30WVRUFBMmTGDfvn2Vd1LVSPrGTfZl17Zt0fr6OjEaIa5SqVTlmhe3ushLKuZPcr788suYzWYURbmufTdt2pQNGzZc1z7yuLu7c8stt3DkyBFmzpxZbL127drx888/8/zzz9vLfvjhh0L1QkJCWLt2LYqi2H/Wrs13t0BOTg4Wi6XAz+u0tDRWrFhR7tj1en2xo4AffPBBfH19GTx4MJcvX+axxx4rVMdqtXLu3DmaNm1a7mNXlEzd4EBpuWkAeLt4OzkSIYQQQgghhBD5ff/996jVasaNG2cvU6lUPP300yQkJJSa9CjrQ4++//57/Pz8GDx4sL3MxcWFsWPHsn//fo4dO1axE6jm0jddTfR6dC/f6EIhRGHt27fH1dWV1157jd9//53nnnuO7du307lzZ3755RcOHz5c4X136dKFhIQE+zy71+u9997jt99+45FHHuHnn39m48aNfPXVVwwfPpyNGzcCtrsiNm/ezGOPPcYff/zBm2++WeTUOQMGDODcuXOMHz+eP//8k2nTpvHTTz/Zt3t7e9OhQwfefvttfvjhB5YvX07Pnj3x9i5/Lq558+ZER0fzzTffsGvXrgKjnHU6HcOHD+evv/7irrvuIiwsrFD7w4cPk5GRQdeuXct97IqSRK+DWKwWMkwZAHjqPZ0cjRBCCCGEEEKI/Hbv3k3jxo0LPDEdoFOnTgDlmt+ytOO0b9/e/qT2ihzHaDRy/vx5+ysuLq5SYqsq1owMMv/5x77uUc7byIUQhfn6+vLVV1+RnJzMgw8+yJo1a/j999+ZPn06//77L++++26F9929e3cCAgL4/fffKyXWW2+9lb///pv09HQee+wxevfuzfTp03Fzc6NJkyYA9OnTh3nz5rFu3Tr69u3L2rVr+eabbwrt65577uHdd99lxYoV9O3bl8OHD/PZZ58VqPP111/TuHFjhg8fzjPPPMOAAQNKnX6nKKNGjWLgwIGMHz+eDh06MHXq1ALb8x4eV9RD2ABWrVpFeHg4HTp0KPexK0qlXO947lru/PnzhIaGEhMTQ0hISIX3Y8w1MuS3IQC82+1dmvs3r6wQhRBCiBqlsj5bhRBCiMrUqlUr/P392ZRv5CnYbr3VaDQ8/fTTfPzxx2Xal0ql4vHHH2fevHmFtnl4eNC/f3+WLl1aoPz06dM0btyY999/nxdeeKHE/U+dOpVp06YVKq+un61p69ZxftzTAOiCg2m87s8Sp7cQoqrkjcjMP6/tjUSlUmEymdBqS5/J9YUXXmDv3r2sX7/eAZEVLTExkcDAQL744gtGjBjhtDiKM3nyZD799FMuXLhQ5MPr2rVrR9++fZk8eXKJ+6nMv5cyotdB0nPT7cseeg8nRiKEEEIIIYQQ4lpZWVlFdtTVajU6na7Snppe3HEMBoN9e2kmTJhATEyM/bVjx45Kia2q5J+f16P77ZLkFcLBhg4dav8SqEGDBvznP/8ptc1LL73EP//8w969e6s6vBrn2LFjrFixgo8//pgnnniiyJ/pmzZtIjo6mmeeecahscnD2Bwkb35eAE+dTN0ghBBCCCGEENWJwWAgJyenULnVasVkMtkTsVV1nLwH/pTlOF5eXnh5eVVKPFVNURQytmyxr7t36+bEaIS4MS1btqzcberWrcuSJUu4dOlSFURUsz3++ONs376de+65h9dee63IOkajkS+//BIfHx+HxiaJXgfJm58XZESvEEIIIYQQQlQ3wcHBnD17tlB5bGysfXtlHaeoOXUr+zjVhSkmBtOVc0Orxd2Bc1UKIa7PwIEDnXr8gIAAquOMs3kPkCvJAw88UPWBFEGmbnCQvBG9Bq0BrVry60IIIYQQQghRnbRr145Tp06RlJRUoPyfKw8Ra9euXaUdZ8+ePVgslio9TnWRsXWbfdm1TWvU7u5OjEYIIWo3yTg6SJrJluj11Mu0DUIIIYQQ1YWiKJgVMxarBYtie1mtVhQU27JixapYURQFK1asVqvtXSnidWW7gmIvsyiWAm2L26+iKOT9gwJWxWqLL++fK9vzYs5fljfSJW8/BeoUUe/aMvt+r8Rd5DGuxJVXt8A1zBfXtWXXlhe1vah6BdrnLedrUtQxi4ypHHGUdXtRsZepfTHtSmpfXqUdoyzH6VyvM70b9b7uWGqihx9+mHfffZdPP/2UN954A7Bdr7lz5xIYGMgdd9wB2B4OlJiYSFhYGG5ubhU6zvfff88333zDo48+CkBOTg7z588nMjKSZs2aVd5JVQMZ264met1vucWJkQghRO0niV4HyXsYm7tOvr0UQgghxI3HqljJteSSZc4ix5JDjiUHk8VErjUXk9VEriW3wHrecq4l33aryb6ct25RLPYkrdlqvvpeVNmVd6titSd38xKbQgibEM8QZ4fgNO3bt2fo0KFMmTKFS5cuERkZyfLly9m4cSOLFy+2P2xn7ty5TJs2jQ0bNtC9e3d7+2XLlhWY+mHPnj3MnDkTsD0IKTw8HIB+/fpx++23M2bMGI4ePUpYWBhLly7l1KlTrF692nEn7ACKxULm9u32dfdbbnViNEIIUftJotdB0k22RK88iE0IIYQQNYVVsZJpyiTDlGF/pZvSbWXmDNJz08k0Z5Jtzra9LNe851vOsRR+8NCNRKVSoUaNWq22vauuvlSoUKlUqLA9hV6tUtvb5N+mQoXt34LlqECN2v4U+2vb5JXb91vMPovcni+u/Mt56/ZlVAXerz33QmWl7Keo9qXtu6wxFNiuKiKefOdbUtyl7bukuItTprqlVCnLPkqq09K/Zekx1GILFy4kPDycJUuWMG/ePCIiIli2bJl95G1JFi1axKZNm+zrO3fuZOfOnQDcdttt9kSvSqVixYoVTJw4kQULFmA0GomMjGTlypXcddddVXNiTpJ95CiW1FQA1O7uuEa2cnJEQghRu0mi10Hy5uiVB7EJIYQQwhkURSHdlE5qTurVV+7V5ZScFIy5Roy5RntSN9OU6dAYtWotOrUOnUaHXq1Hp9ah1+gLlNnXr5RpVBo0ag1alRatWlt4Xa2xlak09vVrt2nVWrQqLWq1Go1KY0vAXpuMVans2/KSokW98rZpVBpbcvdKuRCiZtDr9cyYMYMZM2YUW2fq1KlMnTq1UHlZHs6Tx8vLi7lz5zJ37twKRFlzZF6ZdxjArWNHVDqdE6MRQojaTxK9DpI3dYOHThK9QgghhKhcJquJpKwkErMTuZx1mcSsxAKvy1mXSclJqZRpCtQqNW5aN9z17njoPHDVuuKqdcWgNWDQGAq+X1PmqnXFReuCQWNAr9HjonGxJ2x1ap0kRIUQopbJ3L3bvuzWoYMTIxFCiBuDJHodIDXLxMnLl8gxW+RhbEIIIYSokExTJnEZcVdf6VeXL2ddLvf+1Co13i7etpfeGx8XH7xcvPDSe+GusyVx3XXu9oSuu84dd507Bo2hyFvahRBCiPwUq5WsPXvs627t2zkxGiGEuDFIotcBZq48zMHki7i4mmXqBiGEEEKUyGQxEZMWQ7QxmrPGs0Qbo4lOjSYpO6lM7fUaPf4GfwJcAwhwDcDf1bbsa/DF28WW0PXWe+Ouc5eErRBCiCqTe/o0lpQUAFQGA4YWLZwbkBBC3AAk0esAHRr68WdqNtlmC+5aSfQKIYQQwsZsNROdGs3x5OMcTT7KieQTxKbHljjFgkaloa57Xeq61yXYI5h67vUIcgsi0C2QANcAPHQeksAVQgjhdJm7r47mdW3dWubnFUIIB5BErwN0CPdF2Z+FYoWUdI2zwxFCCCGEk5gsJo4mHWXfpX0cTDzIyZST5Fpyi6yr1+gJ8wyjgXcDwr3CCfcKJ9g9mADXADRq+X1CCCFE9Za5e5d9WaZtEEIIx5BErwPU9dGgUVuxWCE6weLscIQQQgjhIIqicC7tHLsv7mZfwj4OXT5UZGLXXedOhG8ETf2a0tCrIQ28GxDkFiQJXSGEEDVWVv4Rve3bOzESIURFdejQgaFDh/LMM884O5QSTZ06lffff5/09PQi15cvX05sbCxPPfVUufc9c+ZMNm3axNq1ays15qri0ESvoijMmTOHTz75hLNnzxIeHs7TTz/N+PHjy3yL4f/93//xzjvvcPz4cYKCghgxYgRvvPEGuny3gaSnp/Pee++xc+dOdu7cSWJiIm+99RavvvpqVZ1aiTLMGRh0GjJyzByLNzklBiGEEEI4hqIoHE8+ztbYrWyL3UZcRlyhOnXd6xIVGEULvxY09WtKsEcwapXaCdEKIYQQlc90MQHThQu2FbUa19ZtnBqPEKL8fvrpJ86ePcuYMWOcHUq5jR49mvvuu8++vnz5cnbt2lWhRO/TTz/Nu+++y/r16+nRo0dlhlklHNqjmD59Os899xydO3fmk08+oVOnTjz77LPMmDGjTO2XLl3Ko48+SlhYGB9//DEPPvggM2fO5PHHHy9QLzExkenTp/Pvv//Stm3bqjiVcknPTcegs13q6EsKqZmS7BVCCFE5TCYTkydPJiwsDIPBQFRUFF9//XWZ2iqKwkcffURERAQuLi5EREQwZ84cFEUpUC8uLo5XX32VO++8E29vb1QqFd9++21VnE6NlpCZwHdHv+PxtY/z4qYX+enET/Ykr6fek9vq38bTbZ9mwd0LWHD3Asa3Hc+d4XcS4hkiSV4hhBC1Sva/B+zLLhERaDzcnRiNEKIiPvzwQwYPHoyrq6uzQym3kJAQOnToUCn78vHx4aGHHuKjjz6qlP1VNYf1KuLj43nrrbcYOXIkX375JaNHj2bZsmUMHz6cWbNmER8fX2L7nJwcXn75Ze68805WrlzJmDFjmDNnDpMmTeKLL75g37599rr16tXjwoULxMTEMH/+/Co+s9Klm9Ix6DSgArXFwI7osj01WwghhCjN2LFjefPNN+nbty8ff/wxoaGhDBkyhC+//LLUtmX9AvbYsWO88847REdH06ZNmyo6k5pJURT2Juxl6tapjPpjFF8d+cqe3A10C6RP4z683fVtlt27jFc6vkKvBr2o617XyVELIYQQVSvr34P2ZdfIVk6MRIjaz2g08uyzz1K3bl28vb0ZPnw4FouFUaNGVXg07unTp9m8eTMDBgwotG3btm306NEDd3d3vL29GTx4MAkJCQXqLFiwgAYNGuDm5sadd97J9u3bUalULFmyxF6ne/fu3H///QXa7dq1C5VKxcaNG+3H6tOnD8HBwbi7u9OmTRuWLVtWavxTp07Fw8MDgBEjRrB06VIOHTqESqVCpVIxYsQIVqxYgUql4sSJEwXapqam4ubmxpw5c+xlAwcOZNWqVVy6dKnUYzubw6Zu+OWXX8jJyWH8+PEFysePH8/SpUv55ZdfCo3MzW/jxo0kJCTw9NNPFygfN24c06dP57vvvrN3Pl1cXAgODq70c6iotNw01CoVBq0W0LPxWAI9WwQ5OywhhBA13N69e1myZAnTp09n0qRJgO02pR49evDSSy8xaNAg9Hp9kW3zfwG7aNEie1uNRsOsWbMYO3YsdevaEpLt27cnMTERf39/Nm7cyB133OGYE6zGzFYzG2I2sPzkcs4Zz9nL3XRudAvpRo/QHjTza1bmqamEEEKI2iT733/ty4ZWkU6MRIjSKYqCYqoed16rdLpy/f6Ynp5Ot27dyMnJ4dNPP+XSpUs8+eSTdOjQgW+//ZaDBw+WvpMirFu3Dp1OV2hU7LZt2+jevTu9e/fmu+++IyMjgzfeeIM+ffqwfft2AFauXMnYsWMZMWIEgwYNYteuXQwaNKhCcZw9e5YuXbrwxBNPYDAY2LJlC6NGjUJRFIYNG1amfUyaNIlLly5x9OhR/u///g+AwMBAGjRoQP369Vm8eDFvvfWWvf4333yD1Wrl0UcftZd16dIFs9nMxo0bGThwYIXOxVEclujdvXs3Li4uREVFFShv27Yter2ePXv2FNPyanuAjh07FigPDAykUaNGpbYvK6PRiNFotK/HxRWeV6+80k22yZ/93bzJQsWOM0kYs014GXSltBRCCCGK9/3336NWqxk3bpy9TKVS8fTTTzNgwAA2bNhAr169imxbni9gPT09q+4kahirYmXLhS0sO7yswNy7jX0a82DjB7m1/q24aFycGKEQQgjhXIqikHXokH1dRvSK6k4xmTgW1drZYQDQ9MB+VMUM1ChK3jOsjh07RmhoKADz58/njTfeYNCgQTRs2LBCcezatcs+vVt+r776KjfffDM//fSTPSHdqlUrIiMjWbVqFb1792bmzJl07dqVL774AoBevXqRkZFRIJlaVvkTxIqi0K1bN86fP8+8efPKnOht3LgxgYGBnD17ls6dOxfY9thjj7F48WJmzpyJRmN7CPLixYvp27cvfn5+9nq+vr6EhYXxzz//VPtEr8OmboiNjSUoKAi1uuAh1Wo1QUFBxMbGltoesI8uyi84OLjU9mU1e/ZsQkND7a9rE8sVkZ5rS/TW9fRGq1FhsSr8dbz6D/cWQghRve3evZvGjRsX+CUEoFOnTgAlfgl6vV/A3ogOXz7MhI0TeHfnu/Ykb6d6nXi769v8t/t/uSPsDknyCiGEuOGZzp3DmpoKgMrFBZebbnJyRELUXosXL2bYsGH2JC/Y5pTNyMhg4sSJAMTExHDXXXfRtGlTIiMjGTNmDLm5uSXuNy4ujsDAwAJlmZmZbNmyhYEDB2KxWDCbzZjNZpo2bUq9evXYuXMnFouF3bt389BDDxVoW9QUEGWRnJzMM888Q3h4ODqdDp1Ox/z58zl+/HiF9netUaNGERcXx+rVqwE4ePAgO3fuZNSoUYXqBgQElDrtbHXgsERvVlZWoW8C8hgMBrKyskptr9PpCiWKy9q+rCZMmEBMTIz9tWPHjuveZ1puGgC+Bm86NfQHYN2RhJKaCCGEEKWKjY2lXr16hcrzpi8q6UvQ6/0CtjyMRiPnz5+3vyrjbhlHyjZnM//AfF756xVOpZwCoHVgaz7o/gFvdH6DlgEtZYoGIYQQ4or88/MamjVDpZM7WYWoCidPniQ2Npa77767QLnJZGLw4ME0btwYAK1Wy5tvvsmxY8fYv38/6enpBeafLUp2dnahHF5ycjIWi4Xnn3/ennTNe8XGxhITE8OlS5cwm83UqVOnQNugoIpNXzpixAi++eYbXnzxRdasWcPOnTsZOXIk2dnZFdrftRo0aEDPnj3tU9ktWrSI8PBw7rzzzkJ1KzP3WJUqNHVDeebnW7t2LXfddRcGg4GcnJwi62RnZ2MwGErcj8FgwGQyoShKoc5UWdqXlZeXF15eXpWyrzx5Uzd46Dzo3rwOW04msv98CvGp2dT1rpy4hRBC3HiK+xJVrVaj0+lK/EXker+ALY/Zs2czbdq0StufI/176V8+2vsRFzMuAhDiGcLjUY/Tpk4b5wYmhBBCVFMF5ueNlPl5RfWn0uloemC/s8MAKNcXI6dPnwZsyco8J0+eZNu2bfTr189eVq9ePfvgELVazc0338yZM2dK3Lefnx/R0dEFynx8fFCpVEycOJG+ffsWahMQEEBgYCBarbbQw9kuXvz/9u47vIoq/+P4e25PryQQIFRBaQqIoK6KrG2xwCJrAUF0V7cgurpVf0p11+6uK7p2UNa6xbJ2VMAuCrgICEpPSCC9J7fO74+bXHJJIYEkN4HP63nuMzNnzpn5ziTcQ7733DP7GtR3uVwNRhYXFRWF1mtqanjjjTe47777wqabCwQCzcbeWtdccw3Tpk1jz549PPvss8yePbvRQabFxcUMHTq0Tc/dHg4p0Tt48GAef/zxFtUdMmQIEBxd9P777xMIBMJuWCAQYN++fQd9eFrd/tzc3AZ1c3Jy6N+/f2suoUP1juvN6PTRHJN0DOP6pxAXZae82su/1mRx3QR9jUVERA5NUx+iBgIBvF5vsx+CHu4HsK1x00038bOf/Sy0nZub2yZTI7WngBngpS0v8ey3wQc2WAwLU46ZwuXHXo7D2vJ500RERI421fUe/qT5eaUrMAyjVfPidhZ1ubX6ydHf//73+Hw+TNNstE11dTVLlizh3nvvbfbYgwcPZsWKFWFlMTExnHzyyXz77bfcfvvtTbYdNWoUL7/8MjfeeGOo7F//+leDer169WL58uVhAzqXL18e2u92u/H7/WEPly4vL+e1115rNvbGOByOJkcBT5o0iaSkJKZNm0ZhYSFXXXVVgzqBQIDdu3czePDgVp+7ox1SordHjx5hf7C1xKhRo3jiiSdYv349J5xwQqh83bp1eDweRo0addD2AKtXrw775KCgoIDt27cf8nwfHeHCARdy4YALQ9s/PiGDZz7bxRvf7GXmKX31UDYRETkkGRkZ7Nq1q0F53bQLzX2IergfwLZGe3xbpj25/W7u/fJePs8NPjm4d1xvbhp9EwOTBkY4MhERkc7NNE3cmzeHtl1dYPSbSFc1evRooqKiuPnmm/F6vbzzzjt8/vnnjBs3jldffZVzzz03NPgSwO/3c8UVV/DDH/6Q8847r9ljn3rqqSxcuJDs7Gx69eoVKr/nnnuYMGECl156KZdddhlJSUlkZ2ezfPlyrrrqKsaPH8///d//MWnSJK666iouu+wyvvrqK5577rkG55g6dSpPPvkkc+bMYfLkyXzyySf85z//Ce1PSEhgzJgx3HnnnaGRwnfeeScJCQkNRgwfzHHHHcdTTz3F888/zzHHHENqampoJLTdbufKK6/knnvu4ZxzziEzM7NB+02bNlFZWclpp53WqvNGQofN0Ttp0iQcDgeLFy8OK3/wwQdxOBxMmjQpVFZaWsrmzZsprZ3AHeDMM88kNTWVhx56KKz94sWLMU2TSy65pH0voA1NHtkTh82C2+vn1a/bbg5EERE5uowaNYpt27aFfYoP8MUXX4T2N9fW7Xazfv36sPKWfgB7pCrzlHHrx7eGkryn9zqd+8ffrySviIhIC3j35BCorATAcDhw1PtKuYi0raSkJP7xj39QXFzMpEmTePfdd3nrrbdYuHAh33zzDXfffXeormmaXH311cTExPDXv/71oMceP348qampvPXWW2Hlp5xyCh9//DEVFRVcddVVTJw4kYULFxIdHc3AgcH/L1900UU88sgjvP/++0yePJnly5fz/PPPNzjHeeedx913381rr73G5MmT2bRpE3//+9/D6jz33HMMGDCAK6+8kuuvv56pU6cyc+bMVt+rn/70p/zkJz9hzpw5jBkzhvnz54ftr3t4XGMPYQN488036dOnD2PGjGn1uTuaYTY1nrsdzJ07l0WLFjFz5kxOP/10Vq1axbJly5g7d27Y3H1Lly7lqquuYsmSJcyaNStU/tRTT/HTn/6UCy64gEmTJrF+/XoeeughrrjiCp5++umwcy1evJiSkhJKSkq47777OOecc0KZ9zlz5pCQkNCimLOzs+nduzdZWVlhn2Icrr++9x2vfZ1DnMvGsp+N1aheERFptTVr1nDiiSeyaNEibr31ViD4n7gJEyawceNGsrKycDqdFBQUUFBQQGZmJtHR0UBw1G+/fv2YMWMGTzzxROiYdQ882LlzZ6MPequbp//555/nsssuO6S426tvPVzlnnJu+fgWdpbuBGD6cdO5dPCletCaiIh0ep2lby3/4AOyfzUbANeQIfT7z78jFotIY+rmne17lH0I8ctf/pK8vDxeeuklrFZri9r85je/Yd26dXzwwQeHff6CggK6devWIM/XWcydO5eHH36YPXv2NPock1GjRjF58mTmzp3bLudvy9/LQ5q64VAtWLCApKQkHnroIV544QV69+7N/fffz69//esWtb/66qux2+3cfffdzJ49m7S0NG655RZuu+22BnXvvffesK+zvvvuu7z77rsAXHHFFS1O9LaXaSdl8vaGvZTX+Fj6yU6u/6Hm6hURkdYZPXo0M2bMYN68eeTn5zN8+HBeeeUVVq5cyVNPPRX6T8rixYtZsGABK1asYPz48UBw6oY//OEPLFq0CK/X2+AD2AOTvHXzcNU9uOHll19m69atAKEkc1dW6a3ktk9uCyV5rx91PWf3OTuyQYmIiHQx7i1bQuvOLjCXpcjR4JNPPuGRRx5hyJAhjB49Ggh+a/4vf/lLs+1+97vfMWDAANatW8fIkSM7ItQOt2XLFrZs2cKDDz7I7NmzG03yrlq1ip07d3L99ddHIMLW69BEr2EY3HjjjWETMjdm1qxZTWb4Z8yYwYwZMw56rgOfDtjZpMW7uPykTJ7+dCev/S+HC0b0oH+32EiHJSIiXcwTTzxBnz59WLp0KY888giDBg1i2bJlXHHFFQdt25oPYA/8UPWll17ipZdeArp+otcX8HHn6jvZVrINgBtG3cBZfc6KcFQiIiJdT81334XWnYMHRTASEalz6qmnNvlwtuZ0796dpUuXkp+f3w5RdQ4///nP+fzzzznvvPO4+eabG61TVlbGM888Q2JiYscGd4g6dOqGrqgtvgJTvX49VWvW4ujXl7jakVQANV4/s5asJq/MzTHpsTx4+Sgctg6bNllERCQiOsvXS+s8/PXDvLUjOP/YtSOuDXuAqoiISFfQWfrWbRPPx7N9OwCZTz1JzCmnRCwWkcYcrVM3SOfWlr+Xyip2gLI33qDkxRepWLEyrNxlt3LjWcFPOb/fV8HjH22PQHQiIiJHr5VZK0NJ3gv6X6Akr4iIyCEK1NTgqffNWk3dICLS8ZTo7QCuYcMAqNm4ETMQCNs3tn8KU08MfuL67zXZvPVNbofHJyIicjTKrcjl4a8fBmBY6jB+NvxnEY5IRESk63Jv3Qa1f+9aU1OxpaREOCIRkaOPEr0dwDV0KACBigq8u3c32H/Naf05LiMegHvf3cL73+7r0PhERESONgEzwP1r7qfaV02sPZbfnPgbrJaWPYFYREREGqr/IDbXIM3PKyISCUr0dgBHnz5YYmIAqN6wscF+u9XCHVOG0y81BtOEP73xLU9+vAN/QNMni4iItIflu5azuWgzAHNGziE1KjXCEYmIiHRt7u+/D607legVEYkIJXo7gGG14hpyHBCcvqEx8S479/zkeAakxQLw7Oe7+Pmyr/h0a4ESviIiIm2o1F3Kkg1LADip+0mc0lMPihERETlc7h37nznjHDgggpGINM1qteL3+yMdhkiIaZr4fD4slrZJ0SrR20FcQ2vn6d2wock6yTEOHrx8JBOH9wBge34lt76ygUse/Yx73tnMq1/vYcOeUmq8elMSERE5VM9tfo5KbyVOq5OfH//zSIcjIiJyRPBs3xFad/TrF8FIRJrmcrnweDwUFhZGOhQRAoEA+fn5eL1eYmNj2+SYtjY5ihxU3Ty9vvx8vPvysKenNV7PbuW35w7m7CHpPPnxDjbsKaW40sNb3+zlrW/2BisZ0C3WSc/EKDISo+iZVLusfUU5NMegiIhIY/Kq8nhnxzsAXDzoYtKiG++PRUREpOUCHg/ePXtC247+/SMYjUjTUlNTcbvd5OXlUVJSgtWq/IlERiAQwOv1EggEiIuLIzExsU2Oq0RvB3EOHoRht2F6fdRs+AZ7+g+brX9870T+dvlIdhRUsmJzHhtzytiaV055jQ9MyC93k1/u5uuskgZtk2IcwaRvUhS9kqLokxxDn5RoMhKjsFqMdrpCERGRzu+FzS/gN/3EOeKYNGBSpMMRERE5Inh37YJAAABrQgK2pKQIRyTSOMMw6NmzJwUFBdTU1GgaB4kYm81GVFQU8fHxxNQ+16tNjttmR5JmWRwOnMcdR836b6has5a4Hzaf6K3TLzWGfj8Ifu3FNE3yy91sL6hkT3E1OSXV7Kl95ZbWEKidy7e40kNxpYcNe0rDjmW1GMHEb0oMmcnRZCZH0yclmt7J0bjs+hRLRESObPsq9/H+7vcBuPiYi4m2R0c4IhERkSODu/60DRrNK52cYRh069Yt0mGItAslejtQ9Ikn1iZ6v8IMBDBaOdGyYRikxbtIi3c12OcPmOwrqwlL/u4prmZ3URU5JdWYZrDOrsIqdhVWHXBgSI9z0SclmsyU6NAI4MyUaOJd9sO5ZBERkU7jjR1vEDADxDviOb//+ZEOR0RE5Ijh2aH5eUVEOgMlejtQ9IljKHpqCYHSMtzff49r8OA2O7bVYpBRO2fviQfs8/gC7CmpZldhJbuLqkLJ3uziKjy+AJiwr6yGfWU1rN5RFNY2MdreYARw9wQXaXEuHDY9y09ERLqGGl8N7+58F4Bz+p6Dy9bwQ1MRERE5NPUTvc7+SvSKiESKEr0dyNGvL9bUFPwFhVR9+VWbJnqbPa/NEpwCIjV8zo+6UcC7CqvYXVRZuwy+Kmp8AJRUeSmpKuF/B84FbEBKjIP0eBfp8S66x7tIi3fSIyGKbnFOkmMcxLtsGIbmBBYRkchbmb2SSm8lFsPCxH4TIx2OiIjIEcWtEb0iIp2CEr0dyDAMokefSPk771D11VckXzE9ovHUHwV88oCUULlpmhRXeRuMAN5dVElhhae2EhRWeCis8LApp6zJ4yfFOEiOdpAcW7uMCb7io2zEuezEufYvYxw2PSxORETaxRvb3wBgXI9xdIvWnGwiIiJtxTRNTd0gItJJKNHbwaLHjKH8nXdwb9mCv6QEa2JipENqwDCMUEJ2ZGb401KrPX72ltWwt7SGvPLgcm9ZDXllbvaW1VBc6QnV9QdMCsrdFJS7YV9LTgyxThtxLhuxThvRDhsuu5Uou5Uoh6V2aSPKbsVlr9u24rBasNsswaXVgqNu3WY02KdEsojI0Wdn6U52lu4E4Ef9fhTZYERERI4w/oICAuXlwQ2bDUfv3pENSETkKKZEbweLHnkCht2O6fVS8fEnJFzQtR4GE+WwNjoNRJ0ar5+iSk/oVVwVHPVbXBXcLqzwUFTlobzGh9vrD29sQkWNLzRtRHuwWAxsFgPDCC7rti1GcASy1WJgMYJLq1Fv2xq+bTGgblYKi2HUrgeXBtQu628HKzcsD+6rbY5B8Nh1dQ/GNA//nhzsEC05h3nQo4SfqH5ts5kTmE22afr8Lb0n9c97uMc+sDi8XtPtw6+v8YM3jKcF7ZtpE35ss8l6Td/7Zn5ezZy0qfvYknvQXDxOm4XF00Y12U4EYFX2KgCSXEkMTx0e4WhERESOLJ6dO0Prjl69MOx6oLeISKQo0dvBLDExRI8dS+XHH1OxcmWXS/QejMtuDU0HcTAeX4AKt4+yam9oWe72UV7jo6LGS7U3QLXXT43XT7XHT7W39uXZv13j9QcfKNdCgYCJJxBME7kP+SpFpLNw2a2RDkE6OdM0+TD7QwBO63kaVot+Z0RERNqSZ3dWaN3Rp08EIxERESV6IyD2jDOo/PhjajZswJuXhz0tLdIhRYTDZiHZFpwi4nCYpokvYOLxBfD6A3j8gdp1M7hdW1637QuYBAIm/oCJ3wyu+wImAbO2rPYVMMEfCNRbD9b3B0xM08Q0gyMLTRMCtcMNzWBABMzg6MS6UYh124TaNN6+7nrq6teO9W3WwUb+tmSyio56aF7909S/tgNPX3+zqdAOjNloYqOl52kqzuaOEX7OZuI5zGM39+Np6md3SNfQzDGau49NxdPyYzfevrl6dexWSzNnEYHNRZvJq8oD4IzeZ0Q4GhERkSOPJ3t/oteuaRtERCJKid4IiD5pDEZ0FGZVNZUffkji1KmRDqlLMwwDu9VQwkdERBr4aM9HAHSP6c4xicdEOBoREZEjjzcrO7Tu6N0rgpGIiIgyYxFgcTqJPeUUAMo/WNHsnJciIiJyaEzTZPXe1QD8oOcPOuzbCyIi0nV5vV7mzp1LZmYmLpeLESNG8Nxzz7WorWmaPPDAAwwaNAin08mgQYP429/+1uDvvaVLlwaf09HIa+vWre1xWe3Km6URvSIinYVG9EZI7IQfUv7e+3i2b8f93Xe4Bg+OdEgiIiJHlD0Ve9hXuQ+AE7ufGOFoRESkK7j22mt55plnmD17NsOHD+eVV15h+vTp+Hw+Zs6c2WzbhQsXMn/+fGbMmMHvf/97Vq1axQ033EBJSQlz585tUH/+/PkMGDAgrKx79+5tej0dwZO9f0SvvadG9IqIRJISvRESdcLx2DMy8ObkUPbGm0r0ioiItLGv9n0FQIw9hmOTjo1wNCIi0tmtW7eOpUuXsnDhQm677TYAfvaznzFhwgR+97vfcdlll+FwNP58kb1793LHHXdw9dVX8+STT4baWq1W/vznP3Pttdc2SOKee+65jBs3rn0vqp0FKivxFxaGth29ekYwGhER0dQNEWJYLMSfPxGAilWr8FdURDgiERGRI8tXe4OJ3pFpI7FarBGORkREOruXXnoJi8XC7NmzQ2WGYXDdddeRl5fHihUrmmz76quv4na7mTNnTlj5nDlzcLvdvPrqq422Ky8vx+/3t80FRIAne09o3ZqSgiUmJoLRiIiIEr0RFHfWWRh2O6bbTfny9yIdjoiIyBGj2lfNxsKNgKZtEBGRllmzZg0DBgwgOTk5rHzs2LEArF27ttm2TqeTESNGhJWPHDkSh8PRaNuzzz6b+Ph4oqOjmThxIps2bWpxrGVlZWRnZ4deubm5LW7blrzZ++fndfTStA0iIpGmRG8EWRMSiDn9NABKX3kFswt/kisiItKZrM9fjy/gA2BU2qgIRyMiIl1BTk4OPXr0aFCekZER2t9c2/T0dCyW8D+xLRYL6enpYW2jo6O58sorWbx4MS+//DJ//OMf+eijjzjllFNa/DC2+++/n969e4deJ510UovatTWPHsQmItKpKNEbYYkXTwXAt28fFR9+GOFoREREjgzfFHwDQL+EfiS5kiIcjYiIdAXV1dU4nc4G5RaLBbvdTnV1davbArhcrrC2l1xyCUuXLuXKK69k8uTJLFiwgA8++IDy8nLmz5/folhvuukmsrKyQq/Vq1e3qF1b82bVexBbb43oFRGJND2MLcKc/fsRdeJoqr9aQ8m//k3s+PEYhhHpsERERLq0DQUbABiWOizCkYiISFfhcrlwu90NygOBAF6vF5fL1eq2ADU1Nc22BRgzZgynnXYa773Xsin94uPjiY+Pb1Hd9uQJm7pBI3pFRCJNI3o7gaSf/AQAz7ZtVH/1VYSjERER6doqvZVsL90OKNErIiItl5GR0ehct3XTLtRN4dBU23379hEIBMLKA4EA+/bta7ZtnczMTIqKiloZdWRpRK+ISOeiRG8n4BoxAuexxwJQ9MwyTNOMcEQiIiJd16bCTaG+dGjK0AhHIyIiXcWoUaPYtm1bg2TrF198EdrfXFu328369evDytetW4fH42m2bZ3t27eTlpZ2CJFHhmmaePfsCW3rYWwiIpGnRG8nYBgGyTNnAOD+/nuqPv88whGJiIh0XXXTNmTGZ5LgTIhwNCIi0lVccsklBAIBHn744VCZaZosXryYbt26ceaZZwJQUFDA5s2bqaqqCtWbNGkSDoeDxYsXhx3zwQcfxOFwMGnSpFBZY6N233//fT755BPOPffctr6sduMvKsKsm67CasWWnh7ZgERERHP0dhZRI0fiGjaMmg0bKHr6GaJPOgnDao10WCIiIl3OxsKNAAxL0bQNIiLScqNHj2bGjBnMmzeP/Px8hg8fziuvvMLKlSt56qmnQg9bW7x4MQsWLGDFihWMHz8eCE7d8Ic//IFFixbh9Xo5/fTTWbVqFcuWLWPu3Ln06NEjdJ4f/OAHjBw5khEjRpCYmMjXX3/Nk08+SXp6eosfxtYZeHP2T3NhT0/X368iIp2AEr2dhGEYJF85k5zf/R7Pzp2UL3+P+PO6zqe5IiIinUGNr4bvi78HND+viIi03hNPPEGfPn1YunQpjzzyCIMGDWLZsmVcccUVB227YMECkpKSeOihh3jhhRfo3bs3999/P7/+9a/D6k2ZMoU333yTN998k8rKSrp3786sWbOYN28ePXv2bKcra3ve3JzQuq1eIltERCLHMDUhbLOys7Pp3bs3WVlZ9OqAOYf2LlxI5aefYU1KIvPJJ7BER7f7OUVERDpSe/atGwo2cPNHNwOw9LylpESltOnxRUREOqOO/rsVoOiZZ9j35zsAiL/gAnree0+HnFdERJqmOXo7mZSf/hRsVvzFxRS/8GKkwxEREelSthRtASAlKkVJXhERkXYUNnWDRvSKiHQKSvR2MvaePUmcPBmAkv/8G8/OnRGNR0REpCv5rvg7AAYnDY5wJCIiIkc2b269RG+GEr0iIp2BEr2dUNK0adhSU8HnJ/9vD2IGApEOSUREpEvYUhwc0TsoaVCEIxERETmyeffuT/TaunePYCQiIlJHid5OyBIdTersXwFQs2kTZW++FeGIREREOr/C6kIKqwsBGJysEb0iIiLtyVd/6oaMjAhGIiIidZTo7aRiTj6ZmFNOAaDwiSfw7tkT4YhEREQ6t7rRvIZhMCBxQISjEREROXKZHg++goLQtuboFRHpHJTo7cRS51yHJSEes6aGvPvuw/T7Ix2SiIhIp/VdUXB+3j5xfYiyRUU4GhERkSOXNy8PTBMAS0wMlri4CEckIiKgRG+nZktKotuc6wGo2fQtRU8/E+GIREREOq+6B7ENStb8vCIiIu3Jm5MTWrf16I5hGBGMRkRE6ijR28nF/uBU4if+CICSl16i8rPPIhyRiIhI5xMwA2wt2QrAMUnHRDgaERGRI5svt978vD00P6+ISGehRG8XkPKLX+A8ZiAAeffcq/l6RUREDrCvch/VvmoABiYMjHA0IiIiRzZv7t7QuubnFRHpPJTo7QIsDgfpt96KJS6OQFUVexfdTqCmJtJhiYiIdBrbS7cDYDEs9InvE+FoREREjmze+iN6M5ToFRHpLJTo7SLs6emk/+H3YBh4du4k7149nE1ERKTOtpJtAGTGZWK32iMcjYiIyJHNu3d/otfWvXsEIxERkfpskQ5AWi76xBNJmnEFxc8so/Ljjyl45BFSf/UrTXwvIiJHvboRvQMSB0Q4EpFDYJq1T683wQzUbgcOsn2Q+qH1um0O2D5wvxkeT912XdvQdljgLahvtm/9sNVG6jRW/1CZBzvGQfYftH1bHKMNrjOpH/QYcfjHkSOab19eaN2enh7BSEREpD4leruYpMsvx7cvj/J33qHsv69jS0kh6bLLIh2WiIhIRNWN6O2f2D/CkUiHMU3we8Dnrn1Vg88Dfjf4vRDwBZd+DwS84PfVLuv2ecLrBWrrHlgv4AfTH0yIBnwQCAS368pDy9r9Zr2lGai3v/4+M7yOSGdy/GVK9MpB+fbtC63blOgVEek0lOjtYgzDoNv1c/CXllD1+RcULX0aa1Iy8eeeE+nQREREIqKopogSdwkAAxI0ordTqUvGeirBWwWequAytF4ZXHoqwFtdW15ZL3HrrveqCR7LW70/wdsWoxelEzIg9I21et9cO7Cs/nZj9UObB9Zv4pzNhnQ436A7nGNHoK09pvl2ctQLeDz4i4tD27Y0JXpFRDoLJXq7IMNqJf2PfyT3lluo2fQt+Q88gGGzEvfDH0Y6NBERkQ5XN5oXoF9CvwhGcoTyecBdDu7S4LKmrHa7HNxlta/65WXBZK2nIpjENTvBMwWsDrDYwGoHiz24bGrdYgvWt9rBsOxfWmy1SysY1oZLwwjWabDfsn877BiW/e0MS/BVl7AM26beNgepW5fwrCs36u2zHLCvbpt69aD1iVUaKWtlfU1DJtKl+PLyQ+uWmBissfpwQESks1Cit4uyuFx0X7CAnN/9fv/D2Xx+jewVEZGjzvaS4Py8PWJ6EG2PjnA0nZzfC9UlUF0cfNXUrZc03K6pTez63e0Ti9UJjmiwR4MjBuxR4etWJ9icYHOBzVG7dAWTsPao4NLmqq1Tr67VuT95a3WEJzRFROSw+fLqTduQlhbBSERE5EBK9HZh1rg4Mu66k5xb/g/Ptm3k/+UvmF4vCRecH+nQREREOszOsp1AcDTv1rxyarwB0uKdpMW5IhtYRwj4g0nZqsLwV1NJXE9l25zXEQPO+NpXHLhql6FXfG3Cti6RWz+hW7u0WNsmFhER6VCan1dEpPNSoreLsyYkkHHnHeTe8n+4v/+egsWLCVRWknjJTzA0ekVERI4CO8p2EDBN1my1sfyjNQBYLAaTT8hg1qn9iHV2sf/uBALBpOyByduqIqgqgKri2mVRMIF7OPPU2qIgKhFcCRCVVPtKBFdtmSthf+K2fiJXSVoRkaOWt16i156uEb0iIp1JF/vLRxpjjYujxx13kHvbrbi/3UzRkiX49u0ldfZsDKv+EBMRkSOXx+8htyKXwkoPtqI4HLXlgYDJf9bu4fPtRdzzkxH0SIiKaJwhPg9U5kNlHlTUvfYFyyry9i8PdV5bZxxEJUN0cjBZWz95G5VUW5a4f2nvJPdFRES6jPpz9GrqBhGRzkWJ3iOENTaGjDvuYN+dd1L1+ReUvfkW3rw8ut9yC5ZozVcoIiJHpqzyLEqq3dR4/CT407jq1L5cMCKDl9ft4fnVu8kpqeb659dx70+Op09KBzwsJhAIjrYtzYayPcFlaTaU5QS3qwpbf0xHTDB5G5Nab5l0wHZtctfmbPtrEhERqSds6oY0Td0gItKZKNF7BLG4XHS/7TYKHn2Ustf+S/VXa9hz02/oftut2Hv2jHR4IiIibW5n6U4q3T7AzpkDj+GKcX0wDIOrf9CP4b0SmPvqRgorPPz+X+t5aPooUmPbMBFaXQIF30PhViisW24HX3XL2ludENsNYtMhphvEpkFMWu2y2/4Erv0omGtYRES6jPA5ejWiV0SkM1Gi9whjWK2k/vKX2HtkUPjYY3h27iR7zvWk/e63xJx8cqTDExERaVOr92zB5zex+rsxdXRm2Pz0Y/omc9fFw/ndP9eTX+7mlv98wwOXjSTKcYjTGlWXwO7PIWcd5KyFou3N13fEQHwvSOgJ8T0hoRfEdd+f2HUlgObTFxGRLsablxdat0fwYWymaeIzfXj9XrwBLwEzgN/0B5cB//712qUv4CNgBhotq1/XHwiuBwhgmiYmZoNlwAwEY6hdN00zfLuubl275o5Rb9usnXe//jFD11tvTn7zwPn5G2w207aZ4za2vzV1DzXGg8bUihhFOqPR6aM5q89ZHXKuDk30mqbJ3/72Nx566CF27dpFnz59uO6665gzZ06LHxz27LPPctddd/Hdd9+Rnp7OrFmzuPXWW7Hb7aE6X375JcuWLWPFihXs2LEDq9XK2LFjufXWWzn99NPb6/I6DcMwSPzxZByZvdl3510EysvZu2AhiZdeSvLMGZq3V0TkCOL1elm0aBFLly4lLy+PQYMG8cc//pFp06YdtG1r+uXy8nJuueUW/vnPf1JaWsrxxx/PwoULOeecc9rr0lrkyz1bAEiyZzA0I77B/hG9Ernl/ONY8NpGtuZVsPD1jdw+eThWSwsTrDWlsH0lfP8uZH3Z+Ny5Md0gZQCkDAy+kvoFk7pK5IqIyBHGNE189RK9LZmj1xfwUVxTTGFNIYXVhRTVFFHpraTKV0W1r5oqbxVVvqrQstpbTbW/OpTA9Qa84esBLx6/B2/A256XKiLSZkzMIzPRu3DhQubPn8+MGTP4/e9/z6pVq7jhhhsoKSlh7ty5B23/9NNPM2vWLM4//3zmzJnDN998w+23305WVhZPPfVUqN5dd93Fhx9+yNSpU7n++uspKSnhscceY8KECbz++uucd9557XmZnUb06NH0Wvwg+26/Hff3Wyl58UWq1/+P9N/9DntGRqTDExGRNnDttdfyzDPPMHv2bIYPH84rr7zC9OnT8fl8zJw5s9m2Le2XTdNk8uTJfPrpp9x0001kZmby9NNPM3HiRN59910mTJjQ3pfZKLfPz+6yXQCMzhjc5IfGZwzqxs/P6M+jq7bzxfYiHnj/e24865jmP2Qu3gXrlsHmN8Hv3l9usUH6UMgYBRkjg+tRiW14VSIiIh2vpKYEExOn1YnD6sBmaTxVECgrw6ypCW4YBt9b8inM3hpK4hbWFFJQXUBRdVGorMRd0nB0p4iItAvD7KAx73v37qVv375Mnz6dJ598MlQ+a9YsXnjhBXbu3En37t2bbO92u8nMzGT48OG89957ofL58+ezYMEC1q1bxwknnADAJ598woknnojTuX8evuLiYoYMGUJGRgZr1qxpcdzZ2dn07t2brKwsevXq1Yor7jwCHg+FjzxC2ZtvAWC4XKT+4hfEnXtOi0dSi4hI57Nu3TpGjRrFwoULue2224BgUnbChAls2rSJrKwsHA5Ho21b0y+//PLLTJkyhWeeeYYZM2YAwX55xIgRREVF8fXXX7cq7rbqWz/4bic3ffRTAO6fcCcT+o1psq5pmjzw/ve89nUOAL84YwCXjOndsGJVEXzxCGx8Zf/oXYsNMk+GY86GfqeDM+6QYxYREWkPh9u3/vK9X/Lxno9D28muZJKcSVgsFgyCfzNWeauIyy5mwd9LASiOgZ9f33lmg7QZNiyGBavFGlwaVqzG/nWLpemyBvUtVixYwAALFgzDCN6HA7ZDy9p1i2EBwCC4Xtem/naDtgcs6x/jwL/X634WQLP7DlS/7oH1WnzMAw4f1u7AYzZzvqbqtfY4zR1XDoNua7sYnjqcH/T8QYecq8PelV999VXcbjdz5swJK58zZw5PP/00r776Kj//+c+bbL9y5Ury8vK47rrrwspnz57NwoULefHFF0OJ3lNPPbVB+6SkJCZMmMC///3vw7+YLsbicNDt+uuJPvFE8h54gEBpGfl//SuVn31G6uxfYW/B121ERKTzeemll7BYLMyePTtUZhgG1113HVOnTmXFihWce+65jbZtTb/80ksvkZycHDYdhNPp5Nprr+W3v/0tW7ZsYfDgwe1whc37eOe3ANitBqMzBjVb1zAM5kw4hn1lbr7YXsgjH24jIzGKHxyTur/Slrfhw7uD0zUARKfCCZfD0B8Hp2EQERE5Qnn8nrDtopoiimqKGtQ7vigQWi9uxeeedoudlKgUUl2pJEclE2uPJdoeTbQtmihbVGi9blk3sthusYeWoZc1fOmwOLBaND2hiAh0YKJ3zZo1OJ1ORowYEVY+cuRIHA4Ha9euPWh7gJNOOimsvFu3bvTv3/+g7QFycnJITU1ttk5ZWRllZWWh7dzc3IMet6uIOeUUeh93HPn330/Vl19R9cUXZK1fT/KVM0m48ELN3Ssi0sWsWbOGAQMGkJycHFY+duxYANauXdtkorc1/fKaNWsYPXo01gP6ifrnaS7R21596zf7gg9Di3PEk+A8eCLWajG47YLjuP75r9meX8Gf3vyWBy47gUEpzmCCd+PLwYq2KBhzNRx/Odij2iRWERGRzszAwGbY8Jm+ZuslVexfL42zkhGTQUpUCimuFFKiUkh2JZMalRpWlhKVQpw9Tt8mFRHpAB2W6M3JySE9PR2LxRJWbrFYSE9PJycn56DtgUand8jIyDho+08++YRVq1Zx4403Nlvv/vvvZ8GCBc3W6cpsSUl0X7iQ8neXU/j44wQqKih85FEqPlhB6q9+ievYYyMdooiItFBOTg49evRoUJ5ROw97c31ja/rlnJwcTj755EM6D7RP31pVWUNJ3jZIhMz4zBa3i3bY+POPh/GrZ9dSVOlh0b8+5/HkZbgKNgYr9BwNP5wHCT3bNF4REZHO7IlznwCCD06r9lWTX5UfmlvXNE1MTKLt0biK/ouHpwE4c/RUpk2dH8GoRUTkQJaDV2kb1dXVYXPm1udyuaiurj5oe7vd3uAP0pa0z8vLY9q0aWRmZh70oW833XQTWVlZodfq1aubrd8VGYZB/Lnn0Pvxx4gdPx4A93ffsefXN5J37734CgsjG6CIiLRIU32rxWLBbrc32ze2pl9uqq7L5Qrtb05b962m38/3C/7EL976lNRSH0O79W1V+7R4F3+eMpxUayW/KLqb4m1r8AdMGDUTJv9dSV4RETlq2Sw24hxx9E/sz6j0UYxOH82J3U9kTPcxDE0ZSnRpTaiuPT09gpGKiEhjDmlE78qVKznzzDNbVHf58uWcddZZuFwu3G53o3VqampCfyw2xeVy4fV6MU2zwVc+mmtfXl7OxIkTKSsr48MPPyQhofmvdsbHxxMfH99snSOFLSmJ9D/+gbizfkjBI4/izc6m/L33qfj4ExIvvpiEH/8Ya2xMpMMUEZEmNNW3BgIBvF5vs31ra/rlpurW1D51+2B9eFv3rZ4dO/Cu+YrEqmquf6Ma28joVh9jUCI8EvskleV78AbgCfs0po3+FXGa409ERKRJvrz80LpNz3oREel0DinRO3jwYB5//PEW1R0yZAgQ/Hrn+++/TyAQCBuVGwgE2LdvX+jrn02p25+bm9ugbk5ODv3792/Qprq6mgsvvJBvv/2W5cuXM3z48BbFfLSJPvFEev/9eEr/+zrFzz5LoLKS4mefpfTVV0n8yU9ImHQRloP8ES8iIh0vIyODXbt2NSivm0qhub61Nf1yRkZGo/PqtuQ87cE5cCBvnDeTMf+9g7hqk273/ptyf3/izjqrZQfweeCN35BcswtnjIM/+6bxWdWJbHh5A/dMHYHLrmSviIhIY3z79oXWbWka0Ssi0tkcUqK3R48e/OxnP2tVm1GjRvHEE0+wfv16TjjhhFD5unXr8Hg8jBo16qDtAVavXs3kyZND5QUFBWzfvp2pU6eG1fd6vUydOpVPP/2U1157jVNOOaVV8R5tDLudxCk/Jm7CmRS/8AJlb7xBoKKCoiVLKH35ZRIvvZT48ydicTgiHaqIiNQaNWoUH3zwAUVFRWEPZPviiy9C+5tr29J+edSoUbz//vv4/f6wB7K15DztocbrZ4XD4MuJifzyvTLsPpO8e++j8tPPSPn5tc1/ldQ0YeUdsCf4kNeYs2/hpMCpfPbe92zcU8qtr2xg0aRhRDmU7BURETmQN69eojddI3pFRDqbDpujd9KkSTgcDhYvXhxW/uCDD+JwOJg0aVKorLS0lM2bN1NaWhoqO/PMM0lNTeWhhx4Ka7948WJM0+SSSy4JlQUCAaZPn87bb7/NM888w3nnnddOV3XksSYmkvqLX9D7ySeJn/gjsFrxl5RQ+Oij7J55JUXPPYe/vDzSYYqICHDJJZcQCAR4+OGHQ2WmabJ48WK6desWmmapoKCAzZs3U1VVFarXmn75kksuobCwkOeffz5U5na7eeyxxxg+fDjHdvCDPL3+AKcdZ5Cf4eTZK3sTOzKYaK789FOyrrmWomX/INDEtBRs+Dd8+1pwfdRMGHEJk07oyTWnB78ZtHZXMb/559eUVnk74lJERES6DNPrxV+w/3kudk3dICLS6RzSiN5DkZGRwR/+8AcWLVqE1+vl9NNPZ9WqVSxbtoy5c+eGPTX85Zdf5qqrrmLJkiXMmjULCM7/d9ddd/HTn/6UCy+8kEmTJrF+/XoeeughZs6cGTaa6Le//S3//Oc/Ofvss/H5fPzjH/8Ii+WKK67okGvuyuxpaXS7/noSp06l+LnnKP9gBf6SEoqfWUbJS/8k/txzSPjxj7F37x7pUEVEjlqjR49mxowZzJs3j/z8fIYPH84rr7zCypUreeqpp0IPUFu8eDELFixgxYoVjK99CGdr+uUpU6ZwxhlncM0117B582YyMzN5+umn2bZtG2+//XaHX3ecy87gXl6+rnSRktCXjOl/pvzttylc+jSBsjKKn32W8nffJWna5cSdfTaGrfa/O3nfwof3BNczx8HJ14WOeflJmVgMeHTVdjbnljPnhXXcM3UE6fGaukhERAQIPrTbNAEwnE4sB3n+jYiIdLwOS/QCLFiwgKSkJB566CFeeOEFevfuzf3338+vf/3rFrW/+uqrsdvt3H333cyePZu0tDRuueUWbrvttrB6a9euBYIPglu+fHmD4yjR23L2jAzSfvtbki6/nJL/vEz58uWYNTWUvvoapf99nZhTTiHhgvNxHX98g4fkiYhI+3viiSfo06cPS5cu5ZFHHmHQoEEsW7asRX1dS/tlwzB47bXXuOWWW3j88ccpKytj+PDhvP7665zV0nlx21hWeRYAveN6Y1gsxE+cSMzpp1P0zDOUvf4Gvvx88h/4G8UvvUTStGnEnXYyxjv/BwEfxPWAc/8MBzx47dIxmSRGO7jnnS1kF1Ux57l1LJo8jMHd4yJxiSIiIp1K2Py86en6+09EpBMyTLP2IzlpVHZ2Nr179yYrK4tevXpFOpyI85eUUPr6G5T+9zUCpWWhcnvPnsSffz5xZ5+FNU5/EIuISNPaom+d/f5sdpft5vJjL2facdPC9nl27qRo2T+o/OSTUJk92kNSv0Ji+9gwLn4MejY9r/Bn2wpZ8N+NeHwBbFaDX44fwOQTeuoPWhER6bQ64u/WsnffZc/1NwDBB3r3+ceydjmPiIgcug6bo1eODNbERJKvmE6fp58mdc51OPoH5zT07tlD4WOPsWv6Fey75x6qv/4aMxCIcLQiInIk8gf85FTkANArruEfs46+fel+2630Wvwg0ePGgrcab04ueZ94yPowhfItpZh+f5PHP3lACn+97ATS4p34/CYPvr+VBf/dRIXb127XJCIi0tn59uWF1m2an1dEpFNSolcOicXlIuH88+n10GJ6/uV+4s76IYbdjunxUPH+B+T88WZ2XzmLwiVL8WRlRTpcERE5grjdpfwwuhdDq6voG930XPHOgQPp8Ycb6Hk2RPWwgNWBt8Ig7+67ybr255S//36TCd9ju8fz2MwTOXlACgAffpfP1Uu+ZNV3+ejLUCIicjTy5dVL9KanRzASERFpihK9clgMw8B13HGk/fa39HnuWVKuvRZH374A+PLzKXnxRbKuuZbs62+g5F//wltvXicREZFDER0IcN2Wz7mzsITM/G1NVzRN+OBPuGLKyTg7np733UP0mDFA8JsoeffcS9Y111K2fHmjCd94l53bJw/jF2cMwGIxKKhws+C1jfzx39+QXVzVXpcnIiIR5PV6mTt3LpmZmbhcLkaMGMFzzz3XoramafLAAw8waNAgnE4ngwYN4m9/+1ujHxCWl5czZ84cunfvTlRUFOPGjePdd99t68tpU768+nP0akSviEhnpESvtBlrXByJU35M70f+Tq+HFpPw4x9jTUwEwP3ddxQ+8SS7r5xF9pzrKf7nP/Hm5kY2YBER6ZpiUqH32OD6t683Xe/b12D7iuD6qTfgGnc2PRYtoucDfyX6pJMA8ObkkH/f/WRdcw1l77yL6QufnsEwDC4Z05snZp7I8b0TAfhyZxFXLfmS+97dQk5JdVtfnYiIRNC1117Ln/70JyZPnsyDDz5I7969mT59Os8888xB2y5cuJBf//rXjBs3joceeoixY8dyww03sGjRorB6pmkyefJknnjiCX7605/y17/+FYvFwsSJE/nggw/a69IOm7fe1A12Td0gItIp6WFsB6GHsR0e0++n6qs1VHy4iqrPPidQFT4CynnMQGJOPZXok8bi6NdXD7oRETkKtEnfuuVtePf/AAOuehNiD/iDsyQLXrgcvNWQeTJc9CAc0Me4v/+eomefperzL0Jltu7dSbrssuCURDZbWH3TNHnv2zz+vnIrJVVeACwWgx8el8bFo3pxTFqs+jERkS5s3bp1jBo1ioULF3LbbbcBwff+CRMmsGnTJrKysnA4HI223bt3L3379mX69Ok8+eSTofJZs2bxwgsvsHPnTrp3D0439PLLLzNlyhSeeeYZZsyYAYDb7WbEiBFERUXx9ddftzr2jvi7dduPJuLZsQOAPs/+g+jRo9vlPCIicug0olfalWG1EjP2JNJ/9zv6vvA83RcsIO7ss7DExgLg/n4rRUufJvtXv2L3zCvJf3AxlV+sJuB2RzhyERHp1PqfAfZowIQtb4Xv87nh7T8Gk7yuBDhrfoMkL4DzmGPoMX8+vRY/SMwpJweb7t1L/l//yu6rf0rZm29ier2h+oZhcPaQdP7xs7Fce3p/EqPtBAImyzfu4xfL1nDNM1/x7zXZlFR52vHCRUSkvbz00ktYLBZmz54dKjMMg+uuu468vDxWrFjRZNtXX30Vt9vNnDlzwsrnzJmD2+3m1VdfDTtPcnIy06ZNC5U5nU6uvfZa/ve//7Fly5Y2vKq2ozl6RUQ6P9vBq4i0DcPhIGbsScSMPQnT66X666+p+Ohjqr74An9pKb78fMreeIOyN97AcDqJOv54osecSNTIkdh79tQoKRER2c8eBQPPCk7P8O1rcMJ0sNqC8/J+eC/kbw7WO2tBcKqHZjgHDqT73Lm4t2+n+NnnqPzkE3x5eeT/7UGKX3iRxEsuIf6cszFqR3FFO2xcdlImk0f25I31ufxrTTb7ymrYnl/JQyu28vDKrQzNSOCUASmcPCCFzORo9WEiIl3AmjVrGDBgAMnJyWHlY8cGpwtau3Yt5557bpNtnU4nI0aMCCsfOXIkDoeDtWvXhtUdPXo0Vqu1yfMMHjz4sK+nLfkrKglUVoa2bZq6QUSkU1KiVyLCsNuJHjOG6DFjMAMB3N99R9Xq1VSuXo1n6zZMt5uq1aupWr0aAGtqCtEnnEBU7cuW2vwf7SIichQYMimY5C3eCStuhwm3waq7YeN/gvtHz4J+p7X4cM7+/el+2624d+yg+Lnnqfz4Y3x5eRQsXkzJCy+QeNllYQlfl93KxaN78eORPfk6u4S3vsnlo+8L8PgCbNhTyoY9pTz24XaSYhyM6JXA8b0SGZIRT5+UaJw260GiERGRjpaTk0OPHj0alGdkZIT2N9c2PT0diyX8S7MWi4X09PSwtjk5OZx88smHdJ46ZWVllJWVhbZz2/n5J/UfxGZNTMTSxBQWIiISWUr0SsQZFguuY4/FdeyxJM+cia+ggKovv6Ry9Wqq//c/zKpq/AWFlL/3PuXvvQ+AvXdvok44Ppj4HT4ca3x8hK9CREQ6XMYJMPIKWPcP+Pa/sPlNMP3Bff3PhHG/OqTDOvv1o/v/3YJn506Kn3+eig8/wldQEEz4vvQSSZddStzZZ2PY7UBwnt5RmUmMykyi0u3jy51FfLqtkC+2F1Je46O40sOqLfms2pIfqp+ZHM3AtFgGdoulX7cYeiVFkRbnwmrRyF8RkUiprq7G6XQ2KLdYLNjtdqqrm34AZ1NtAVwuV1jbpuq6XK7Q/oO5//77WbBgQYPyL7/8kl27dh20fWtV/W89ebXPW7GnpZH/ySdtfg4REWncqaee2uK6SvRKp2NLTSX+Rz8i/kc/wvT7cX/3HdVff03VunW4v/0W0+vDm5WFNyuLsv8Gn7Zuz+xN1LBhuIYOxTV0KLb0dH1NVkTkaHDKDVC+F7a+tz/JO3QKjP8jWA5v1Kyjb1/Sb76ZpGnTKH7ueSo+/HD/lA4vvkTS5Zc3eGhbjNPG+MFpjB+chj9gsnlvGd9kl7K+doRvRY2PQMBkZ0ElOwsqeY96I6QsBj0So+iVFEVGgou0OBfd4p10i3WSFuckJdapRLCISDtyuVy4G3lWSCAQwOv1hhKxrWkLUFNTE9a2qbo1NTWh/Qdz00038bOf/Sy0nZuby0knncSYMWPa5WFsJbm55EZHAxA7dCi9W5F0EBGRjqNEr3RqhtWK67jjcB13HEmXX06gpoaaTZuoXreO6q+/xr11G5gm3t1ZeHdnUfZm8IE81pQUXEOHEDV0KK4hQ3D069fg6ekiInIEsFjgnD/BMecEH7gW3xO6te28ho4+fUi/+Y8kXX4ZRc8+R+VHH+Hbt4/8v/6V4hdfCCZ8f/hDjAPmWrRaDIZmJDA0I4HLgEDAZG9ZDVvzKtiaV8G2/Aq+z6ugoDz4x74/YJJdVEV2UVUTl2qQHOOgW6yTxGg7idF2kqIdJEY7SIyykxRjJzHaQVK0g4Qou5LCIiKtlJGR0eho2LqpFOqmVmiq7fvvv08gEAibviEQCLBv376wthkZGY1OtdCS89SJj48nvgO/1ejN3RtatzcyvYWIiHQOynxJl2JxuYgeNYroUaMA8FdU4P72W6o3bqRm4ybcmzdjer34Cwup/PAjKj/8CAg+CM45cADOQYNwDhqMa/AgbBkZGvUrInIksNpg4A/b/TSOvn3p/n+34N6+g+Jnnw0+tC13L/n3/4WSF14kafo0YsePb5DwrWOxGGQkRpGRGMXpg7qFyqs9fvaUVJFdXM2ekmr2FFeTW1pDfrmb/HI3Xn8ACCaKC8rdocRwswyIc9mJd9mIddqIc9mIcdqIc9mJq92OrbcvzmUn1mkj2mElymHFYbWojxSRo86oUaP44IMPKCoqCnsg2xdffBHa31zbJ554gvXr13PCCSeEytetW4fH4wlrO2rUKN5//338fn/YA9lacp5I8e7dn5i29egewUhERKQ5hmmaZqSD6Myys7Pp3bs3WVlZ7fIVGGlbpseDe+vW2sRvMPkbKC9vtK4lNhbn4EG4Bg8OJoAHD8aWlNTBEYuIHH2OlL7VvW0bRf/4B1WffR4qs/fqFUz4nn56kwnf1jBNk5IqL/kVwaRvXnkNhRUeSqq8FFd5KKn2UlLlobjSS43Xf9jnq2OxGEQ7rLUvG1EOK9F2a20i2BbaF1W3tAfruGwWnHYrTpsFVyNLjTIWkc5szZo1nHjiiSxatIhbb70VCL4PT5gwgY0bN5KVlYXT6aSgoICCggIyMzOJrp3OICcnh379+jFjxgyeeOKJ0DFnzZrF888/z86dO0MPevv3v//N1KlTWbZsGVdccQUAbrebESNG4HQ6Wb9+fatjb+++dffVP6Xy008ByLjnHhIuvKDNzyEiIodPI3rliGI4HLiGDME1ZAj85CeYgQDenFzc323BvWULNVu+w7NtG6bXS6Ciguo1a6leszbU3pqcjHNAfxz9B+Ac0B/ngAHYevTAOODpuSIiIs4BA+gxbx7u77+naNk/qFq9Gm92Nnl33U3x88+TfMUVxPzgB4fVhxiGQVKMg6QYB4PS45qtW+P1U1odTAAXVwYTwOVuHxU1PircPsprvLXL4KuuzOdv+Jl/IGAG29X4gBaMIG4hq8XAZbfisFlw2S04bdYGS6fdgtNqwW61YLcFlw6rEdyuLavbdtTuD9axYLcZwWXYPiO036JEs4g0Y/To0cyYMYN58+aRn5/P8OHDeeWVV1i5ciVPPfVU6AFqixcvZsGCBaxYsYLx48cDwekW/vCHP7Bo0SK8Xi+nn346q1atYtmyZcydOzeU5AWYMmUKZ5xxBtdccw2bN28mMzOTp59+mm3btvH2229H4tIPyru3/tQNGtErItJZKdErRzTDYsHRqyeOXj2JmzABANPrxbNzJzVbvgsmf7/bgnd3Fpgm/qIiqoqKqPryq/3HcLlw9u8XTP4OHICjf38cffpgaeKpuiIicnRxHnMMPRYuoGbLFoqWLaP6qzV4d2ex78934OjTh6QrriDm1FPa/UNDl92Ky24lPf7gD/Gpr8brDyWAqzw+arx+qjzBV7Wnbt0XXPfWlflCder21XgDoWkmmuIPmFS6fVS2Xe64VSwWA5vFwFq7tFktoXWrJZgQbs22zWppcLy6bavFwGKAxTCwGLXbFgOrUVtet24htN9qGBhGvba1ZVaLgWFQW26E1a/f3jCoPX7wXIYBBvtjqJuNw1Ibg2EYGASnt9ZUHSJBTzzxBH369GHp0qU88sgjDBo0KGzkbXMWLFhAUlISDz30EC+88AK9e/fm/vvv59e//nVYPcMweO2117jlllt4/PHHKSsrY/jw4bz++uucddZZ7XRlh840zbBEr6275ugVEemsNHXDQRwpXy+V5gUqK3Fv24Z7+3Y827bh3rYdz65d4G/ia7CGgb1HD+x9MnH27Ys9MxNH3744evbEcDg6NngRkS7mSO9bazZtougfz1K9dv83Rhz9+pF46SXEnnZam0zp0Fn5AyYeXwC3z4/bF6DG28jSG2h6X+3S6w/UvoLH84S2A7XbJl5fcNsf0H9l20pdwtdigIER2jYglFA2aivWJbGNA9pBSxLLddv7zxM6P7XZaUKLUBI6dK5GyutWwo5XW3//7rprqldS73hhMdQ714FJ8P3HNw6Ip5G4GjnegecyDrje+kE3ln6vf/6GZQ3rhbdtWHjgPRreM4EfHJPayJmlM2vPvtVfUsJ3404ObhgGx67/H4bd3qbnEBGRtqERvSKAJSaGqBEjiBoxIlRmejx4srJwb92GZ8f2YPJ32zYCVVVgmnhzcvDm5ITNzYjVij0jA0efPsFX3744+mRiz8jAsOmfm4jI0cA1ZAgZf/4T1Rs2UrxsGdX/+x+eHTvIu/Muip5+msSLpxJ3ztlYjsAPBq0Wg6jauXs7SiBg1ksEm6FkcFOJYl/AxO838QVq1wMmvua2/Sa+QOPb++uatW0D+PzBYwRME78ZHAlXtx0IgL/+tmniDwTr+GrLiGDe2qyNNzguWwn0o5U/EFCiV8KEjebt1k1JXhGRTkyZJ5EmGA4HzgEDcA4YECozTRPf3r14du3Cs3NXcLlrF96sLEyvF/x+vFlZeLOyqPz44/0Hs1qxd++OvXcvHL16Ye/VC3vPXth79cSamKivS4qIHIGihg0l6q47qV6/nuIXX6R6zVp8uXspWLyY4mefJWHyZOLPPx9rbEykQ+3SLBYDlyU4bcWRwDRNAiYHJIP3J4mDCePaOrXr9esGTPD5TUzM2sQtmATLTdPErHeOUGK3dgkQMCFQr55Zt12bgw6uh2+Ht99/3v3t6o6zPxZq2wVq29V9x9Csdx/Cy83Qel29umOH1atXqX7evH4c9Y9RV9Dk8Q5oT73yxs7VVLx1sTUXT/3z1Y+h/p6w+MOupWFivsG1NteuXlnPpKgGx5Kjmzc3N7Ru0/y8IiKdmhK9Iq1g1E3Z0KMHMePGhcpNvx9vbm5t8nfn/gRw9p7g9A9+P949e/Du2UMVX4Qd0xIbi71XLxy9eoYSwI7evbBlZByRo71ERI42dd8YcW/dSsk//0nFRx/jLy6maMkSSl58kfjzJxJ/wYXY09MiHap0AoZhYK2dj1dEpDPw1X8Qm+bnFRHp1JToFWkDhtWKo1dwtC4/ODVUbnq9ePfswZOVjXdPNt7sbDzZe/BmZxOoqAAgUFGBe/Nm3Js3H3BQA1tqKrYe3bH3yKhNMHfH1qMH9owMrLGxHXmJIiJymJwDB5J+880kz9xDyb//Q/ny5QSqqij5578o+fd/iBk3joTJk3ANH65veoiISKfhza2f6NWIXhGRzkyJXpF2ZNjtwXl6+/YNKzdNE39JCd7sPfUSwNnB7dzc4Chg08SXn48vP5+a9d80OLYlNhZ7jx7hieCMHti6d8eWmtruT3cXEZFDY+/Zk27XzyFpxhWUvvwKZW+9RaC8nMpPP6Xy009x9OlDwqSLiJ0wAYvLFelwRUTkKOfdu3/qBnuGRvSKiHRmSvSKRIBhGNiSkrAlJRE1fFjYPtPnw7t3H97sLLx79+LLzcWbk4s3Nxffvr2YXh9QOxL4++9xf/99w+PbbdjS0rGlp2NL64YtLQ17ejq2tLRgeWrKEf3UdxGRrsCWlETK1VeRNH0aFStWUvraa3i2b8ezaxf5f3uQwqeWEHf22cSfdy6OPn0iHa6IiBzB/BUVeLZtw711G/E/Og9LdHRony+n3hy9mrpBRKRTU6JXpJMxbDYcvXri6NWzwT7T78dXWIhv795Q8tebm4Mvdy/e3NzQdBCm1xeaE7hRViu2lJTaxG8a9vS00LotPR1bt25YnM72vEwREallcTqJP+9c4s49h5oNGyh99TUqP/uMQEUFpS+/TOnLL+M87ljizzmH2NNPxxKjh7eJiEjb2n7BhaG5eJ2DjiFq+PDQPm9OTmjdroexiYh0akr0inQhhtWKPS0Ne1oaUSNGNNjvr6jAm5MTHAW8bx++fXn48oIvb14eZnV1bUV/qLwp1sTEYOI3NRVbt1SsqanYUlKxpaZgSw1u62FxIiJtxzAMooYPJ2r4cLx5eZS9+Sbl7y7HX1SE+9vN5H+7mYJHHyP2B6cSe+aZRJ1wgr6dISIibcLZv38o0eveui2U6A1UVwenlqulb5iIiHRuSvSKHEGssbFYBw2CQYMa7DNNk0BFBb59+0KJ3/2J4H148/IIlJaF6vtLSvCXlOD+7rsmz2eJjw8mglODCeADk8G21FSNPBMROQT2tDRSZs0iecYMqr5aQ/k771D5xReYNTWUv/c+5e+9jzUhgZjTTiN2/HhcQ47T3OwiInLIHAMHUPnppwB4tm0NlXt27gTTBMCamoo1ISES4YmISAsp0StylDAMA2tcHNa4OJwDBzZaJ1BTs38E8L59+PLy8RcW4CsoxFdQgK+gALOmZn/9sjI8ZWV4tm9v+rzRUcHkb0oK1uRkbCnJWJOSsSYl7i9LSsKIjtZT5kVEDmBYrcSMPYmYsSfhKy6m4oMPKF+xAs/WbfhLSyl7/XXKXn8dW7duRJ88jphxJxM1YjiGTf/FExGRlnMO2P/3gXvrtv3r2/b/P9/Zv3+HxiQiIq2nvwJEJMTicuHIzMSRmdnoftM0CVRW4i8oCM4VnB9M/voKC4Jl+cHyQHn5/jZV1XirsvBmZTV7bsPlCiZ/k5LrJYSTsCbVrifXbickaNSaiByVbElJJF58MYkXX4wnO5uKlauoWLkSb3Y2vvx8yl77L2Wv/RdLTAzRY8YQc/I4okaNwhoXF+nQRUSkk3MOHBBad9cbxFF/QIejf78OjUlERFpPiV4RaTHDMILTQ8TG4ujbt8l6gZoa/IW1o4DrEsKFBfiLivEXFeErLsJfVIzpdofamDU1+HL34svd23wQVmtw/uDkJKyJSVgTE/e/khKxJiQEE8KJiVjj4zWqTUSOSI5evUi+YjpJ06fh2baNyk8+ofLTz/Ds2kWgspKKlSupWLkSDAPnoEFEjxpJ1KhRuI49FsNuj3T4IiLSyTjqjdb1ZmURqKnB4nKFJX2d/Qc01lRERDoRZUBEpM1ZXC4sPXti79mzyTrB0cFV+IuL8BcX4yssDCaCi4vw1U8IFxYRqKjY39Dvx19YiL+wsGWxxMUFE8CJiVgTErEmJoRGBtcliG21S00fISJdjWEYOAcOxDlwIMlXXok3J4fKzz6n8vPPqdm4EQIB3Fu24N6yheLnX8BwuXANHYJr6FCihg3DOWgQFpcr0pchIiIRZktKwpqSEvw/tmni2bED13HHHTCiV1M3iIh0dkr0ikhEBEcHx2CNjYHevZutG/B48BeX4C8qxFdUFEoI1z0wru7lKynBrKoOb1teTqC8HO/u5qeOADDs9trkbwKW+His8QlYE+LD1q3xtdsJCVjj4jQyTkQ6FXtGBokXTyHx4in4Kyqo/t//qF67jup1a/Hm5GLW1FC9Zi3Va9ZSDGCz4jzmGFyDB+M85hicAwdi79ULw2qN9KWIiEgHcw4YQFXtYAr31m04Bw0KPowttF+JXhGRzk6JXhHp9CwOB5b0NOzpaQetG3C78ZeU4i+tTQAXl9Sulwa3S2vLSkrwl5aC3x9qa3q9+PLz8eXntzg2IzoqmASOj2+QFLbExYUniOuSw5pOQkQ6gDU2lthTTyX21FMB8ObmUv3111Rv2EDNxk349u4Fnx/3t5txf7s51M5wuXD2749z0DE4BgzA2bcv9t69NfJXROQI5xw4gKrVqwFwb9uKd88eTI8HACM6Glv37pEMT0REWkDZBhE5oliczhYnhc1AIPhwuQNGBvuLS/CXlxEoK8NfWoa/vJxAaSn+0lJMrzf8GFXV+KqqgwmTlsYYE1ObCI7HEheLtW4ZG4slLh5rXCyW2LjafXHBhHFsrEYPi8hhsffogb1HD+J/9CMAfAUF1GzcSPXGjbi//x7Ptu2YHg9mTQ01mzZRs2lTWHtbejqOzN7YMzNx9M7EkdkbW48ewalvNO2NiEiX5xiwfw5ez7ZtuLdtC207+/XTe72ISBegRK+IHLUMiwVrXFzwifQHmT4CgvMKm243/tIyAmWl+OsSwWWl+5PCZbX7atf9ZaXg84cdJ1BZSaCyEshtXbxRUbXJ4LhgMjguHktszP5kcFwclti6fbXr8XEYLpf+Yy4iDdhSU4k94wxizzgDANPvx7N7N+7vv8e9dSuerVtxb9seenCmb98+fPv2wZdfhR3HcLmwp6dj69Ede/fu2LrXLlNSsKakYE1I0FQQIiJdgHPAwNC6e+s2PNt3hLY1P6+ISNegRK+ISAsZhoHhcgW/vtyCEcNQmxyurg4lhUMJ4pJSApUV+MuCcwj7K8oJlJXvL6usBNMMP1Z1Nb7qamjF1BIAWK3BJHBMDJbYWCyxMcFRxbGxwbKY2GB5TPT+sti6shgMh0OJYpGjgGG14uzXD2e/fnDOOUDwmw++vDw8u3bjzdqNZ3cWnqzdeHdn1X5gBWZNDZ5du/Ds2tX4gS0WrElJ2JKTsaYkY0tJDS4TE7HEJ2CNj9P0NiIinYBzYL0Rvbt3U/bGG/v3aX5eEZEuQf+TFhFpR4ZhYERHY4mOxt6Kec1Mv59AVVVwpHBFJYGK8uBo4YqK4FQSdUnh2nV/RTmB8gr85WUNRhDj94empTika7DbgsngFiSKw8rrEsVOpxLFIl2UYbFgrx2hy9iTQuWmaQYfgpmbi3fvPrx7c/HVW/ry8/d/WBUI4C8sDD7J/fuDnzNsepvYGCxRwfdQS0zMAcva8uhoLFFRwQ+lXC4MhxOLy6mEsYhIK1lTUnD0749n+3bw+8Om8Ik7++wIRiYiIi2l/wGLiHRCRu0oXGtcHK2ZmTc0vURtUjhQXo6/vIJARXlwu7ISf0UFgYrg9BGhssoKApVVmNXVDY/p9R1WohibtWGCJjoaS3RMbaKm3vLAZE5t8tgSFaWvfot0IoZhYEtKwpaUhGvIkAb7Tb8ff3ExvsIi/EWF+5cFhfiKCvEXFOIvDX7Dof5DMeHQp7dpwGrFcDqwOF21SWAnFocz+OGT04HF6QzOfW6zYdjsGDYbht0WTBDbbBjWA7abqmOzgsWCYbGAxQIWK4bFCJ4/VFa732oNfvBltYJhwbA2UWaxBMssFqj7oMwwwDD0wZmItBvDMEi5+ipyb70trDzm1FNx1pu/V0REOi8lekVEjiBh00uktWx6ifpMny84krixpHBlE4niULK4stFEMT4/gdIyAqVlh3dtLlcjI/piSL/lZiU+RDoZw2rFlpqKLTW12XqmaWJWVdXOaV77EMyysuAUNmWlwfejysoDllXBZVVV6GnwjfL7Mauq8Vc18r7U1dUmfYMvat8DGyaGjfr1oHY/GBgHHKPuOAckli3G/mPXldWPoUEZDfdTe77GrqG59SaO2/j7/UGOFXbcgx2rXqWDHeuA4zV+3INcW3jDZsNpsUaOc9BrrSd63FgSzj+/lSeVI0XCRReR/+Di4JzstZJmXBHBiEREpDWU6BURkRDDZsMaH481Pv6Q2pt+f/hI4bqkcFVdYqayNklTb1m3r7KyyVHFEJwH1F9TE/z6d128UVFK8op0YYZhYMQEp4Ox9+jR6vam1xt873B7MN01mG73/nWPh0BNDabbg+lxE3C7MWvc4es+H6bPCz5/7Xrdtg/T6wuV4a+37fVi+v3g82J6fe1wV1py4WbYPO5mU9U6Jho5wth7ZkQ6BIkgw+Eg5eqr2HfHnQDYMzOJPf30CEclIiItpUSviIi0GcNqPaxEMdQmi6urwxPBB4zoM6uq8FdWKskrcpQz7HasiYlEamIX0zQhEIBAALNu6Q+AGQiOKDbN4LJ+WaBuaULAH2pXv63p9wfLgicB0wweyyR4nPrlgfpJ32A8Zl0i2Kwtq3+MgLm/fVj9ukMEGjlneGI5VLl+stlsIq1cV242Unbgjtrygx6rfjOzYfuGJ2y432zsWGHxHuRYTcXb1HXW/xm1RhP3okX3qLGYWnDsxqZjkaNL4mWXUf2/9dRs3kyP2xcFp54REZEuQYleERHpVAyrFWtsLNbY2EiHIiLSrNDculZrq79dLyLSWVmcTnref1+kwxARkUOgj+ZEREREREREREREujglekVERERERERERES6OCV6RURERERERERERLo4JXpFREREREREREREujglekVERERERERERES6OCV6RURERERERERERLo4JXpFREREREREREREujglekVERERERERERES6OCV6RURERERERERERLo4W6QD6Ox8Ph8Aubm5EY5ERETaS/fu3bHZ1CV2FPWtIiJHPvWtHUt9q4jIka2l/ap63oPIz88H4KSTTopwJCIi0l6ysrLo1atXpMM4aqhvFRE58qlv7VjqW0VEjmwt7VcN0zTNDoiny6qpqeGbb76hW7duh/WJdG5uLieddBKrV6+mR48ebRjh0U33tX3ovrYP3df20Rb3VaOOOlZb9K3699Q+dF/bh+5r+9B9bR9tdV/Vt3Ys9a2dl+5r+9B9bR+6r+2jI/9mVc97EC6XizFjxrTZ8Xr06KFPttuB7mv70H1tH7qv7UP3tetoy75VP/f2ofvaPnRf24fua/vQfe1a1Ld2frqv7UP3tX3ovraPjrivehibiIiIiIiIiIiISBenRK+IiIiIiIiIiIhIF6dEbweJj49n3rx5xMfHRzqUI4rua/vQfW0fuq/tQ/f16KSfe/vQfW0fuq/tQ/e1fei+Hr30s28fuq/tQ/e1fei+to+OvK96GJuIiIiIiIiIiIhIF6cRvSIiIiIiIiIiIiJdnBK9IiIiIiIiIiIiIl2cEr0iIiIiIiIiIiIiXZwSvSIiIiIiIiIiIiJdnBK9IiIiIiIiIiIiIl2cEr0iIiIiIiIiIiIiXZwSve3I6/Uyd+5cMjMzcblcjBgxgueeey7SYXUZK1euxDCMRl/vvfdeWN19+/YxY8YMUlJSiI2NZcKECaxZsyZCkXceFRUVzJs3j4kTJ9KtWzcMw+DOO+9stG5r7uFXX33FhAkTiI2NJSUlhZkzZ5KXl9eel9KptPS+zp8/v8nfYZ/P16D+0X5fv/zyS66//nqGDx9ObGwsCQkJnHPOOXz44YcN6ur39eilvvXwqG89fOpb24f61ranflVaSn3roVO/2jbUt7YP9a1tryv0rbZDaiUtcu211/LMM88we/Zshg8fziuvvML06dPx+XzMnDkz0uF1GbNnz2bcuHFhZcOGDQutV1dXM2HCBPbu3ctNN91EYmIiDz/8MOPHj+eLL75gyJAhHR1yp1FQUMDChQvp1asXI0eOZPny5Y3Wa8093LRpE+PHj6dv377cddddlJSUcN9997F27Vq+/PJLoqKiOuryIqal97XO4sWLSUhICCuzWq1h27qvcNddd/Hhhx8ydepUrr/+ekpKSnjssceYMGECr7/+Oueddx6g39ejnfrWtqG+9dCpb20f6lvbnvpVaSn1rYdP/erhUd/aPtS3tr0u0bea0i7Wrl1rAubChQtDZYFAwBw/fryZlpZmut3uCEbXNaxYscIEzOeff77Zevfff78JmB9++GGorKCgwExJSTEnTZrUzlF2bjU1NeaePXtM0zTNHTt2mIB5xx13NKjXmnt40UUXmd26dTMLCwtDZXU/q7/85S/tch2dTUvv67x580zAzM3NPegxdV9N8+OPPzZramrCyoqKiszu3bubo0aNCpXp9/Xopb718KlvPXzqW9uH+ta2p35VWkJ96+FRv9o21Le2D/Wtba8r9K1K9LaTP/7xj6bFYgn7QZmmaf7rX/8yAfPtt9+OUGRdR/1Os7y83PR4PI3WGzdunHn88cc3KL/uuutMh8NhlpWVtXOkXUNzb+wtvYdlZWWm3W43b7jhhgZ1hw0bZo4bN66tw+70WtJh5uTkmKWlpabf72/0GLqvzZs2bZrpdDpD2/p9PXqpbz186lvblvrW9qG+tX2pX5X61LceHvWrbU99a/tQ39q+OlPfqjl628maNWsYMGAAycnJYeVjx44FYO3atZEIq0u65ppriIuLw+Vycdppp/Hpp5+G9gUCAb7++mtOOumkBu3Gjh2Lx+Nhw4YNHRlul9Oae/jNN9/g9XqbrPu///2PQCDQ7jF3NYMGDSIhIYG4uDguvfRSsrOzw/brvjYvJyeH1NRUQL+vRzv1rW1HfWv70ntV+1PfeujUr0p96lvbhvrV9qf3q/anvvXQdaa+VXP0tpOcnBx69OjRoDwjIyO0X5rncDiYMmVKaOLwzZs3c9999zF+/HhWrFjBqaeeSlFRETU1NbrXh6E197Bu2VTd6upqiouLSUlJaceIu46kpCRmz57NySefTFRUFJ988gkPPvggn3/+OWvXrg3dJ93Xpn3yySesWrWKG2+8EdDv69FOfevhU9/aMfRe1X7Utx4e9atyIPWth0f9asfR+1X7Ud96eDpb36pEbzuprq7G6XQ2KLdYLNjtdqqrqyMQVddyyimncMopp4S2L7roIi699FKGDh3KH/7wBz7++OPQfWzsXrtcLgDd64NozT3U/W6dG264IWx7ypQpnHLKKUydOpW//OUv3H777YDua1Py8vKYNm0amZmZzJ07F9Dv69FOfevhU9/aMfRe1X7Utx469avSGPWth0f9asfR+1X7Ud966Dpj36qpG9qJy+XC7XY3KA8EAni93tAPTFqnT58+TJ06lc8++4yqqqrQfWzsXtfU1ADoXh9Ea+6h7vfhu/jii+nbty/vvfdeqEz3taHy8nImTpxIWVkZ//3vf0NPf9Xv69FNfWv7UN/a9vRe1bHUtx6c+lVpivrWtqd+tX3o/apjqW89uM7atyrR204yMjLIzc1tUF43LLtuuLa0XmZmJoFAgJKSElJSUnA6nbrXh6E197Bu2VRdl8vVYH4vaah3794UFRWFtnVfw1VXV3PhhRfy7bff8sYbbzB8+PDQPv2+Ht3Ut7Yf9a1tS+9VHU99a9PUr0pz1Le2D/WrbU/vVx1PfWvTOnPfqkRvOxk1ahTbtm0L+0cB8MUXX4T2y6HZvn07VquV5ORkLBYLxx9/PKtXr25Q74svvsBut4f9g5OGWnMPhw8fjt1ub7Lu8ccfj8Wit5WD2bFjB2lpaaFt3df9vF4vU6dO5dNPP+Xf//532FfhQL+vRzv1re1HfWvb0ntVx1Pf2jj1q3Iw6lvbh/rVtqf3q46nvrVxnb5vNaVdfPXVVyZgLlq0KFQWCATM8ePHm926dTNramoiGF3XUFhY2KBs06ZNptPpNCdMmBAqu/fee03A/Oijj0JlBQUFZkpKinnhhRd2SKxdwY4dO0zAvOOOOxrsa809vOCCC8xu3bqZRUVFobIVK1aYgHnfffe13wV0Us3d18Z+h5988kkTMBcuXBhWrvtqmn6/3/zJT35iWiwW8/nnn2+ynn5fj17qWw+f+ta2pb61fahvbRvqV6Ul1LceHvWrbU99a/tQ39o2ukLfapimabYuNSwtNXPmTJ599lmuu+46hg8fziuvvMIbb7zBU089xVVXXRXp8Dq9s88+m9jYWE488UTS0tLYsmULjz76KKZp8tFHHzFy5EgAqqqqGD16NHl5efzmN78hISGBhx9+mF27dvH5558zbNiwCF9JZC1evJiSkhJKSkq47777OOecczjttNMAmDNnDgkJCa26hxs2bGDs2LH069ePX/7yl5SWlnLvvffSvXt3vvrqK6KjoyN1qR2qJfc1KSmJKVOmMGTIEKKjo/n000959tlnOfbYY/n888+Jj48PHU/3FW666Sb+8pe/cPbZZzNz5swG+6+44gqgdf/mdV+PPOpbD4/61rahvrV9qG9tW+pXpaXUtx469attR31r+1Df2ra6RN/a6tSwtJjb7TZvvfVWs1evXqbD4TCHDRtmLlu2LNJhdRkPPPCAOXbsWDM5Odm02Wxmenq6efnll5ubNm1qUDc3N9ecPn26mZSUZEZHR5vjx483V69eHYGoO58+ffqYQKOvHTt2hOq15h5+8cUX5vjx483o6GgzKSnJnD59upmbm9tBV9Q5tOS+XnPNNebQoUPNuLg40263mwMGDDBvuukms7i4uNFjHu339Ywzzmjynh7YXen39eilvvXwqG9tG+pb24f61ralflVaSn3roVO/2nbUt7YP9a1tqyv0rRrRKyIiIiIiIiIiItLFHfmzJIuIiIiIiIiIiIgc4ZToFREREREREREREenilOgVERERERERERER6eKU6BURERERERERERHp4pToFREREREREREREenilOgVERERERERERER6eKU6BURERERERERERHp4pToFREREREREREREenilOgVERERERERERER6eKU6BU5iuzcuRPDMFi6dGmkQ2mV+fPnYxhGpMMQERFpQH2riIhI21G/KnJ4lOgViZDXX3+d+fPnRzoMPv74Y+bPn09JSUlE48jOzmb+/Pl8/fXXEY1DRES6LvWt4dS3iojI4VC/Gk79qnQFhmmaZqSDEDka/eIXv+DRRx+lI/8JmqaJ2+3GbrdjtVoBuPPOO7n55pvZsWMHffv27bBYDvT5559z8skns2TJEmbNmhW2z+fz4fP5cLlckQlORES6BPWt4dS3iojI4VC/Gk79qnQFGtErcoSpqqpqcp9hGLhcrlCHGak4Wstms6nDFBGRiFHfKiIi0nbUr4q0HyV6RVpow4YNXHTRRSQmJhIdHc0pp5zCO++8E1Zn6dKlGIbBzp07w8oPnGdo1qxZPProo0CwI6t71W/3z3/+k3HjxhETE0NCQgIXXnghmzZtCjvurFmzcLlc7N69m8mTJ5OQkMDEiRObvIYD45g/fz4333wzAP369QvFsXLlylCbDz74gAkTJhAfH09MTAxnnnkmn376adhx6+Yj2rRpE1deeSUpKSkMHToUgF27djF79myOO+44oqOjSUxM5MILL2Tjxo2h9itXruTkk08G4KqrrgrFUfc1oabmO3ryyScZMWIELpeLtLQ0Zs2aRW5ubqP3aN++fVx66aXEx8eTlJTEtddeS01NTZP3SkRE2p/6VvWtIiLSdtSvql8VsUU6AJGu4LvvvuPUU0/F6XRy4403Ehsby5IlS5g4cSIvv/wyF110UauO9/Of/5ysrCw++OADli1bFirv1q0bAPfeey+/+93vmDJlCjNmzKCiooKHH36YU089lbVr19KvX79Qm0AgwDnnnMOYMWO4++67sdla/s96ypQpbN68mRdffJG//OUvpKamAnDccccB8OKLLzJt2jTOPPNMFi1aRCAQ4KmnnmLChAmsWrWKsWPHhh3vkksuoV+/ftx+++243W4AvvzyS1auXMmPf/xj+vbtS25uLo8++iinn346GzdupHv37hx33HHMnz+f+fPnc+2113LaaacBMGLEiCZjr/v6zumnn84999zD7t27Wbx4MR999BFr164lISEh7B6dffbZjBgxgrvuuovVq1fz+OOPk5qayp///OcW3y8REWk76lvVt4qISNtRv6p+VQQAU0QO6uKLLzZtNpv57bffhspKS0vNzMxMs2/fvqbf7zdN0zSXLFliAuaOHTvC2u/YscMEzCVLloTKfv7zn5uN/RPcvXu3abPZzHnz5oWV5+TkmPHx8ebVV18dKrvyyitNwPzNb37ToutoLI477rij0ZgrKyvN5ORk88orr2xQ3rdvX3PChAmhsnnz5pmAOXXq1AbnrKysbFC2bds20+l0mrfffnuo7LPPPmsQ24HHr5Ofn286nU5z/PjxptfrDZW/8sorJmDedtttobK6e3TzzTeHHXPSpElmampqg3OJiEjHUN8aXq6+VUREDof61fBy9atytNLUDSIH4ff7efvtt7nwwgs59thjQ+Xx8fH84he/YOfOnWzYsKHNzvef//wHn8/HZZddRkFBQehlt9sZN24cH3zwQYM2v/rVr9rs/HWWL19OUVER06dPD4ujqqqKs846i48++gifzxfW5pe//GWD40RHR4fWq6qqKCwsJD4+nsGDB7NmzZpDiu29997D7XZz4403hn0aPGnSJAYPHszrr7/eoM3s2bPDts844wwKCgooLy8/pBhEROTQqW9V3yoiIm1H/ar6VZE6mrpB5CDy8/OprKwM6zDrDBkyBIAdO3Y0+5WN1tiyZQuw/6soB6rfCQFYLBYyMzPb5NyNxXHOOec0Wae4uDj01R0g7Os5dWpqapg7dy7/+Mc/GsxFlJKSckix1c0L1djP5LjjjgubrwnAbrfTs2fPsLKkpCQAioqKiIuLO6Q4RETk0KhvVd8qIiJtR/2q+lWROkr0ihwG0zQBQpOuNzb5OgQ/YW2pQCAAwFtvvdXo3EUHPn3Ubre3ao6j1saxdOnSBh1OnfpzCgFERUU1qHP99dfz5JNPMmfOHE499VQSEhKwWCz8+te/Dp2jLZmm2eDnYLE0/eWFup+hiIh0Dupb1beKiEjbUb+qflWOLkr0ihxEt27diImJYfPmzQ321ZX17dsXgMTERABKSkrC6h34RFNouoMdMGAAAJmZmaFPX9vTweLo1q0bZ5111iEf/6WXXmLmzJn89a9/DSsvLi4OTaTfXByNqbvfmzdvZtCgQWH7Nm/eHNovIiKdk/pW9a0iItJ21K+qXxWpozl6RQ7CarVy3nnn8frrr/Pdd9+FysvLy3n00Ufp27cvw4YNA2DgwIEADb6G8dBDDzU4bkxMDBDsPOqbOnUqNpuNefPmNfrpYX5+/mFdT0vjOO+880hMTAx7GumhxGGz2Rp8Avn888+Tk5PTojgac9ZZZ+F0OnnggQfCPnn+73//y5YtW7jgggtaFJuIiESG+lb1rSIi0nbUr6pfFamjEb0iLfCnP/2J5cuXc9pppzF79mxiY2NZsmQJu3fv5j//+U/oaxZDhw7lBz/4AbfccguFhYWkp6fz2muvNdoRnHjiiQBcd911/OhHP8Jms3HhhRfSr18/7r77bm666SbGjRvHlClTSE5OZteuXbz55puMHTuWRx55pM2urS6OW265hcsvvxyHw8GECRNIS0vjscce4/LLL+f4449n+vTpdO/enezsbFasWEFMTAxvvfXWQY9/4YUX8swzzxAfH8+wYcP4+uuvefHFF+nfv39YvWOOOYb4+Hj+/ve/ExsbS1xcHMOGDQv9h6S+1NRU5s+fz80338xZZ53FlClTyMrK4sEHH6R///785je/aZubIyIi7UZ9q/pWERFpO+pX1a+KAGCKSIt888035gUXXGDGx8ebUVFR5sknn2y+9dZbDert3LnTPO+888yoqCgzJSXFvO6668yNGzeagLlkyZJQPZ/PZ15//fVmenq6aRiGCZg7duwI7X/99dfN8ePHm3FxcWZUVJQ5cOBAc9asWebq1atDda688krT6XS2+Bp27NjRIA7TNM3bb7/d7N27t2mxWEzAXLFiRWjfJ598Yp5//vlmUlKS6XQ6zb59+5qXXnqpuXz58lCdefPmmYCZm5vb4JylpaXmNddcY6alpZnR0dHm6aefbq5evdo844wzzDPOOCOs7muvvWYOGzbMtNvtJmDOmzcv7PgHevzxx81hw4aZDofDTElJMWfOnGnm5OSE1WnqHi1ZsqTBPRcRkY6lvlV9q4iItB31q+pXRQzT1KzOIiIiIiIiIiIiIl2Z5ugVERERERERERER6eKU6BURERERERERERHp4pToFREREREREREREenilOgVERERERERERER6eKU6BURERERERERERHp4pToFREREREREREREenilOgVERERERERERER6eKU6BURERERERERERHp4pToFREREREREREREenilOgVERERERERERER6eKU6BURERERERERERHp4pToFREREREREREREenilOgVERERERERERER6eL+H1bT0nJK2NOgAAAAAElFTkSuQmCC", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "f_hist = np.stack(hist[\"f\"]) # (n_iters, n)\n", + "g_hist = np.stack(hist[\"g\"]) # (n_iters, n)\n", + "alpha_hist = np.stack(hist[\"alphas\"]) # (n_iters, K+L)\n", + "\n", + "fig, axes = plt.subplots(1, 3, figsize=(13.0, 3.6))\n", + "sample_idx = [0, 20, 40, 60, 80]\n", + "for k in sample_idx:\n", + " axes[0].plot(f_hist[:, k], label=f\"$f_{{{k}}}$\", alpha=0.85)\n", + "axes[0].set_xlabel(\"outer iteration\")\n", + "axes[0].set_title(r\"Marginal duals $f_i$\")\n", + "axes[0].legend(fontsize=8, ncol=2, loc=\"upper right\")\n", + "\n", + "for k in sample_idx:\n", + " axes[1].plot(g_hist[:, k], label=f\"$g_{{{k}}}$\", alpha=0.85)\n", + "axes[1].set_xlabel(\"outer iteration\")\n", + "axes[1].set_title(r\"Marginal duals $g_j$\")\n", + "axes[1].legend(fontsize=8, ncol=2, loc=\"upper right\")\n", + "\n", + "axes[2].plot(\n", + " alpha_hist[:, 0], color=\"C2\", linewidth=2, label=r\"$\\alpha_1$ (inequality)\"\n", + ")\n", + "axes[2].plot(\n", + " alpha_hist[:, 1], color=\"C3\", linewidth=2, label=r\"$\\alpha_2$ (equality)\"\n", + ")\n", + "axes[2].axhline(0, color=\"k\", linewidth=0.5, alpha=0.3)\n", + "axes[2].set_xlabel(\"outer iteration\")\n", + "axes[2].set_title(r\"Constraint duals $\\alpha$\")\n", + "axes[2].legend(fontsize=10)\n", + "\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "0750844e", + "metadata": {}, + "source": [ + "The constraint dual $\\alpha_2$ corresponding to the equality constraint settles at a non-zero value, which is normal : equality multipliers can have either sign. The inequality dual $\\alpha_1$ is positive and bounded, consistent with a binding upper-bound constraint. We will revisit the *interpretation* of these values as Lagrangian shadow prices in §9." + ] + }, + { + "cell_type": "markdown", + "id": "3158d605", + "metadata": {}, + "source": [ + "## 5. Theorem 1 in practice : convergence to the LP optimum as $\\varepsilon \\to 0$\n", + "\n", + "The theoretical centerpiece of the paper is **Theorem 1**: the entropy-regularised solution $P_\\varepsilon^\\star$ is exponentially close to the LP optimum $P^\\star$ as the regularisation $\\varepsilon$ shrinks,\n", + "$$\\|P_\\varepsilon^\\star - P^\\star\\|_1 \\;\\le\\; 8 n^{2(K+1)} (K+1)\\, \\exp\\!\\Big(-\\,\\frac{\\Delta}{\\varepsilon\\,(K+1)}\\Big),$$\n", + "where $\\Delta$ is a problem-dependent vertex gap. Because we have a small enough problem, we can compute the *exact* LP optimum with `scipy.optimize.linprog` and verify the exponential decay numerically." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "925adc10", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-02T06:39:20.147795Z", + "iopub.status.busy": "2026-05-02T06:39:20.147240Z", + "iopub.status.idle": "2026-05-02T06:39:20.161736Z", + "shell.execute_reply": "2026-05-02T06:39:20.160986Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "LP optimal cost: 0.057710\n" + ] + } + ], + "source": [ + "from scipy.optimize import linprog\n", + "\n", + "# Smaller instance so the LP solver is cheap.\n", + "rng_th = np.random.default_rng(0)\n", + "n_th = 30\n", + "C_th = rng_th.uniform(0, 1, (n_th, n_th))\n", + "DI_th = rng_th.uniform(0, 1, (n_th, n_th))\n", + "a_th = np.ones(n_th) / n_th\n", + "b_th = np.ones(n_th) / n_th\n", + "t_I_th = 0.5\n", + "assert abs(a_th.sum() - b_th.sum()) < 1e-12, \"marginals must have equal mass\"\n", + "\n", + "# Solve the constrained LP exactly:\n", + "# min c^T x s.t. A_eq x = b_eq, A_ub x <= b_ub, x >= 0\n", + "N_var = n_th * n_th\n", + "A_eq_rows = np.zeros((n_th, N_var))\n", + "for i in range(n_th):\n", + " A_eq_rows[i, i * n_th : (i + 1) * n_th] = 1.0\n", + "A_eq_cols = np.zeros((n_th, N_var))\n", + "for j in range(n_th):\n", + " for i in range(n_th):\n", + " A_eq_cols[j, i * n_th + j] = 1.0\n", + "A_eq_lp = np.vstack(\n", + " [A_eq_rows, A_eq_cols[:-1]]\n", + ") # one redundant constraint dropped\n", + "b_eq_lp = np.concatenate([a_th, b_th[:-1]])\n", + "A_ub_lp = DI_th.flatten()[None, :]\n", + "b_ub_lp = np.array([t_I_th])\n", + "\n", + "lp_res = linprog(\n", + " C_th.flatten(),\n", + " A_ub=A_ub_lp,\n", + " b_ub=b_ub_lp,\n", + " A_eq=A_eq_lp,\n", + " b_eq=b_eq_lp,\n", + " bounds=(0, None),\n", + " method=\"highs\",\n", + ")\n", + "P_star_lp = lp_res.x.reshape(n_th, n_th)\n", + "cost_star_lp = float((P_star_lp * C_th).sum())\n", + "print(f\"LP optimal cost: {cost_star_lp:.6f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "8ce64055", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-02T06:39:20.163726Z", + "iopub.status.busy": "2026-05-02T06:39:20.163196Z", + "iopub.status.idle": "2026-05-02T06:39:44.580956Z", + "shell.execute_reply": "2026-05-02T06:39:44.579821Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " gamma = 3.0 (eps = 0.3333): cost = 0.310509, gap = 2.528e-01\n", + " gamma = 5.0 (eps = 0.2): cost = 0.224522, gap = 1.668e-01\n", + " gamma = 10.0 (eps = 0.1): cost = 0.131429, gap = 7.372e-02\n", + " gamma = 20.0 (eps = 0.05): cost = 0.085137, gap = 2.743e-02\n", + " gamma = 30.0 (eps = 0.03333): cost = 0.072030, gap = 1.432e-02\n", + " gamma = 50.0 (eps = 0.02): cost = 0.063680, gap = 5.970e-03\n", + " gamma = 70.0 (eps = 0.01429): cost = 0.061040, gap = 3.330e-03\n", + " gamma = 100.0 (eps = 0.01): cost = 0.059480, gap = 1.770e-03\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAyoAAAHVCAYAAAAekg8iAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAQ6wAAEOsBUJTofAAAgF1JREFUeJzt3XlcVOX+B/DPzMCwM4DssqOCuADirimaWtotLS0rc8/SzNab97apebO6P8q6ZmabmaW2Z6ZpWqLmvuK+IiCr7Pt+5vz+QEYYQGaGWeHzfr18+eKc55zzPXO+HPhyzvM8ElEURRAREREREZkRqakDICIiIiIiUsdChYiIiIiIzA4LFSIiIiIiMjssVIiIiIiIyOywUCEiIiIiIrPDQoWIiIiIiMwOCxUiIiIiIjI7LFSIiIiIiMjssFAhIiIiIiKzw0KFiIiIiIjMDgsVIiIiIiIyOyxUiIiIiIjI7LBQIdKT5ORkSCQSzJgxw9ShNNJcXOYaq6WKjY2FRCIxdRhEKvr6Hm9pP4a+h1jqPWrOnDno1KkTioqKdN5H/bnHxsbqLzDSSmpqKmxtbfHaa6+ZOpQOj4UKtQsJCQl4/PHH0bVrV9jb28PR0RERERFYsGABrl69qpdj7N692yJ/cGqjI5wjUXvA71Xzk5CQgDVr1mDhwoVQKBQAAIlEotW/tWvXmvYk2qkff/wRCxYswB133AFnZ2dIJBI8/PDDLbb39/fH3Llz8d577yElJcWIkZI6K1MHQNQWoiji9ddfx1tvvQWpVIo777wTEyZMgFKpxNGjR7Fy5UqsXr0a//vf//DUU08ZNJbOnTvjwoULqh9Q5sySYiUi7Vn697glxv/KK6/A3t4eTz/9tGrZ4sWLm7Rbu3YtUlJSMH36dAQFBTVaFxUVZeAoO6Y333wTp06dgqOjI/z8/HDx4sVWt1m4cCFWrlyJpUuX4osvvjBClNQcFipk0d566y0sW7YM/v7+2Lx5c5ObfHx8PCZOnIj58+dDoVBgypQpBovF2toa4eHhBtu/PllSrESkPUv/Hre0+K9evYrt27djxowZcHBwUC1fsmRJk7a7d+9GSkoKZsyY0ezrXcnJyYYLtIN6//334efnhy5dumDPnj0YMWJEq9v4+vpi9OjR2LhxI9599124uroaIVJSx1e/yGKlpKTgjTfegJWVFX777bdm/xI1YsQIfP311wCAZ599FqWlpQAavwOcnp6Oxx57DB4eHrCzs0Pfvn3x3XffNdrPkiVLVDe2r776qtlH9a29yx0bG4sbN25g1qxZ8PLygoODAwYPHoy///4bAFBaWooXXngBAQEBsLGxQY8ePfDDDz80Oae1a9di4sSJCAkJgZ2dHZydnTFkyBB88803Gn92zcXa2jmeO3cOEonktjf4oUOHQiKRaPTXKlEU8fHHH6NXr16wtbWFl5cXpk+fjszMzBb7fGhz7tpeY0188skn6N27tyreGTNmICsr67bbnDhxAg8//DB8fX0hl8vh4+ODqVOn3vaVxKNHj+Lhhx9G586dIZfL4e3tjZEjR+Krr77S6fPQ9dr99ttvGDVqFHx9fWFjYwMvLy8MGDAAb7311m3PWdfPYNKkSZBIJFi+fHmTfbz77ruQSCSYMGGCallbrvHPP/+MESNGwMXFBba2tujevTsWLVqkukc0d4zc3Fw88cQT8PHxUX2Pfvnll20+b22P0Zb7EaCfe4g6fd8fbte3TpdroW7z5s2wtrZG3759W/xeXLFiBSQSCf7zn/+0ur8vvvgCoihi8uTJGsegCUPlXD1dvg8yMjIwa9Ys+Pj4QCaTYdOmTTrFoO+fjbczYsQIdO3aVeu+hA8//DAqKiqwYcMGrbYjPRKJLNSiRYtEAOKDDz7Yatu+ffuKAMQvv/xSFEVRTEpKEgGIvXv3FgMDA8WoqChx4cKF4pw5c0SFQiECEJcvX67aPj4+Xpw+fboIQIyMjBQXL16s+nfy5MlG+5w+fXqjY9cvj4yMFENDQ8V+/fqJzz//vPjggw+KEolEtLOzExMSEsR+/fqJ3bt3FxcsWCBOnz5dlMvlokQiEQ8ePNhof7a2tmKfPn3E6dOni//+97/FOXPmiJ07dxYBiK+99lqTc28uruaWaXKOw4cPFwGIFy9ebHKcs2fPigDE2NjYVq+HKIriM888IwIQvby8xKeeekp86aWXxN69e4vBwcFiZGSk2NztSZtz1/Yat+aFF14QAYgeHh7ivHnzxJdeekns2bOnGBwcLPbu3bvZeL/55hvRyspKtLe3Fx9++GHxpZdeEu+//35RJpOJCoVC9bk29Nlnn4kymUy0trYWH3jgAfHll18W58yZI8bExIhRUVE6fx7aXrtPPvlEdX1mz56timPIkCGiv7+/xp+bNp9BYWGhGBwcLFpbW4uHDh1SLT906JBobW0tBgYGivn5+arlul7j119/XQQgurm5iU888YS4cOFCMTo6WpX7xcXFTY4RGRkpduvWTezZs6f49NNPi48//rjo4uIiAhDXrl3bpvPW9hhtuR+Jom7fRy3d1xou1+f94Xb3LW2vRXMOHDgg3nPPPSIAcebMmU3WZ2ZmigqFQuzSpYtYWVnZ6v5iYmJEqVQqFhUVtdq2/nOKj49vdr0xck4Udfs+6NmzpxgQECD26tVLXLBggTh//nxxz549OsWg75+NmoqPjxcBiJMnT2617ZUrV0QA4vjx43U6FrUdCxWyWCNHjhQBiJ9++mmrbV9++WURgDh79mxRFG/dIAGIDz30kCgIgqrt1atXRYVCIcrlcjE5OVm1vP7m1twP/ob7bOkHOgDxmWeeEZVKpWrdsmXLRACii4uL+MADD4hVVVWqdevXrxcBiBMmTGi0v6tXrzY5dlVVlThy5EjRyspKTEtLazWulmJt7Ry///57EYD43HPPNVn39NNPiwDEb7/9ttltG9q3b58IQAwJCRFzc3NVywVBECdNmqT6vNRpc+66XOOWHDx4UAQgBgQEiDdu3FAtr62tFcePH99svFeuXBFtbGzEkJCQJtckPj5elMlkYp8+fRotP3funGhlZSUqFArx9OnTTeK4fv26zp+HtteuT58+olwuF7Oyspq0z8nJabKsObp8BocPHxatra3FoKAgsaCgQCwoKBCDgoJEKysr8cCBA43a6nKNDx48KEokErFz585ienq6arlSqRSnTZsmAhDnz5/f7DFmz54t1tbWqtadO3dOlMlkYvfu3dt03rocQ9f7Uf1no6617yNNChV93R9a2r8un9PtVFdXi87OzmLv3r2brJsyZYoIQNy+fXur+yktLRVlMpkYHh6u0XE1LVQMmXNt+T6YOnWqWFNT0yTutuS9Pn42akqbQkUURdHFxUV0dXXV6VjUdixUyGJ1795dBCBu27at1barVq0SAYhjx44VRfHWDVImk4nXrl1r0r6+sHnzzTdVy9paqDg4OIilpaWN1l2/fl11o05KSmq0rra2VvULmyZ++uknEYD41VdftRqXroVKTU2N6OPjI7q6uooVFRWq5WVlZaKLi4vo5eUlVldXtxrr448/LgIQP//88ybrkpKSRJlM1myh0pLmzl2Xa9ySOXPmiADE1atXN1l35coVUSqVNom3/gnMr7/+2uw+77//fhGAeO7cOdWy+l/m/u///q/VmG6nuc9D22vXp08f0d7eXszLy9M5Dl0+A1EUxffee08EIN5///3ihAkTWvxMdLnG9bn30UcfNWmflZUl2tnZiQ4ODqrPov4Y9vb2zf61fNiwYSKARn991va8dTlGWwqVltzu+0iTQkVf94eW9q/L59SawYMHizY2No2Kgd27d4sAxEmTJmm0j0uXLokAxBEjRmjUXtNCxZA5p+v3gVwub/THmoZ0zXtD/2xUp22hEh4eLgIQS0pKdDoetQ0705PFEkURALR651S9bUBAAIKDg5u0Gz58ON5++22cPHmybUE20LVr10adLAHAx8cHAODi4tJk9BeZTAZPT0+kpaU1Wn79+nX897//xZ9//onU1FRUVFQ0Wp+enq63mNVZWVnhiSeewBtvvIHvvvsO06dPBwB89913KCwsxMsvvwxra+tW93PixAkAwB133NFkXVBQEPz8/JodElKXc9fHNT5+/DgANNvxtUuXLvDz88P169cbLd+/fz8AYO/evarzbejGjRsAgAsXLiAiIgIAcOjQIQDA2LFjW40J0O7z0PbaTZ06Fc8//zwiIiIwefJkDBs2DIMHD1blrCZ0+QwA4IUXXsDu3bvxyy+/AKj7PP75z3+2eBxtrnF9HCNHjmzS3svLC7169cKRI0dw+fJl9OjRQ7Wua9eucHZ2brKNv78/AKCwsBBOTk5tOm9tjtEWhrqH6Ov+0Bp9fk7du3fHgQMHcO3aNXTt2hU1NTV46qmn4OjoiPfff1+jfeTl5QGA3jtbGzLndP0+CAoKgqenZ7PxtiXv9fGz0VDc3NwA1PUXcnR0NMox6RYWKmSxfHx8cPHixSa/IDYnNTVVtU1DXl5ezbavX96WSbvUNTfMppWVVYvr6tfX1taqvr527Rr69++PgoIC3HHHHbjrrrugUCggk8mQnJyMr776ClVVVXqLuTlPPPEEli1bhk8++UT1i8jq1ashlUrxxBNPaLSP+s+1pc/f29u7SaGi67nr4xrXt/H29m52vY+PT5M8rP/l5b333rvtvht2Wi0sLARQNzRra3T5PLS5ds899xw8PT2xatUqfPTRR1ixYgUAYODAgXj77bc1moxOl8+g3sSJE/Hbb7+pYrndHyS0ucaaXEvg1rWo5+Li0mz7+u9hQRBUy3Q9b22OoStD30P0cX9ojT4/p+7duwOo++W5a9eu+OCDD3D+/HnExcXBz89Po33Y2dkBACorKzU+riYMmXO6fh+01F6XGOrp42ejIdUX8vXXmYyLhQpZrKFDhyI+Ph47d+7EnDlzbtv2zz//VG3TUP1feNTVLze3MfyXL1+OvLw8rFmzBjNnzmy0buPGjU1GhTIEX19f3H///fjhhx9w+vRpKJVKHDlyBGPHjm3yl6+W1H+uN27caPYzbm4kLV3PXR/XuL5NVlZWs+0zMzNb3CYvL0/1F7nW1P9ikp6e3upfZ3X5PLS9do8++igeffRRFBcX4+DBg/jtt9/w2WefYezYsTh16hS6det22xh1+QwAICkpCc8++ywUCgXKysowb948nDx5stm/LgPaXeOG17K5XwTrr2Vbvvd1PW9jMPQ9RB/3B2OqL1TOnz+PPn36YOnSpejRoweee+45jfdRXxDX/6JuCtrmnK7fB7f7g4E5531b5OXlQSqVwt3d3dShdEgcnpgs1syZM2FlZYVNmzbhzJkzLbbbtm0bjh49Cjc3N0yaNKnRuuvXrzc7Zv2ePXsAANHR0aplMpkMgH7+qqmr+uEd1c8DuBVzW2h6jvWTZ65evRqffPIJAGDevHkaH6f+c60ffrKh5OTkZh/p63ru2lzjlsTExACom/+gubiai3fQoEEAmj/HlgwcOBBAXc62RtfPQ5dr5+zsjLvuugsrV67Eiy++iMrKSmzfvr3VGHX5DKqrqzF58mQUFRXhq6++wptvvolr167h8ccfb3Ebba5xnz59ANTNsaQuJycHZ8+ehYODA8LCwjSOWZ0u560tXe9Hhr6HAG2/PxhTw0Ll+eefR2lpKVatWqX6i74mfHx84OnpicuXL6teSTY2bXPOEN8Hxsh7YysrK0N6ejp69uyp+p4j42KhQhYrODgYr732GmpqanDvvffi9OnTTdrs2bMHjz32GIC6MfHV3y8VBAH/+te/oFQqVcsSExOxatUqWFtbN5ogslOnTgCg0atmhlL/F0n1X5j/+OMPfP75523ev6bnGBsbix49emD9+vVYv349/P39MW7cOI2PUz83wltvvdXor5BKpRILFy5s9pcvXc9dm2vckvq/PL/11lvIzs5utO9//vOfjfZd7+mnn4ZcLseLL77Y7LwRgiA0OZd58+bBysoKy5Ytw9mzZ5ts07Ag0vXz0PTabdu2DTU1NU2W1z/t0uQ1CF0+g4ULF+Lo0aN45plnMH78eCxcuBB33XUXfvjhB6xevbrZ42hzjWfNmgWg7lo2fHIniiIWLlyI8vJyTJ8+vU19KXQ5b23pej8y9D0EaPv9wZgCAwNhZ2eH3377DT/++COmTp2KYcOGab2f2NhYFBQU4NKlSwaIsnXa5pwhvg+MkffGduTIEQiCoNEEkWQYfPWLLNqiRYtQUVGB//73v+jTpw9GjRqF3r17Q6lU4tixY9izZw+srKywcuXKZn8h7d27Nw4fPoy+fftizJgxyM/Px/fff4+ioiIsX7680asKYWFh8Pf3x99//40pU6agW7dukMlkuO+++9C7d2+jnO9TTz2FL7/8Eg8++CAmTpyIzp074+zZs9i+fTseeughnSYxbEibc3zqqacwf/58AMBLL72k1V+bhg4divnz5+Ojjz5Cjx49MGnSJDg4OOCPP/5AcXExIiMjmxSeup67Nte4JYMGDcILL7yA5cuXo2fPnnjwwQfh4OCAbdu2oaysDL17924Sb1hYGNauXYuZM2eiZ8+euPvuu9GtWzcIgoDU1FTs378fVVVVjd4Bj4iIwKpVqzB37lzExMTg3nvvRbdu3VBQUICTJ0+iqqpK1TG8LbmgybV75JFHIJfLcccddyAoKAgSiQRHjhzB33//jdDQUDz00EOtfm7afgabN2/G//73P/Tp0wdxcXEA6l41+frrrxEZGYnnn38egwcPbpKL2lzjQYMG4eWXX8bbb7+tupYKhQI7d+7EiRMn0KtXL60ntGzreet6DF3uR4a+hzQ8jq73B2OSSqXo1q0bTp06BYVCoco7bU2aNAnff/89/vjjD4SHh+s5ytZpm3OG+D4wRt7ratOmTaqJKesLs8OHDzeaULR+stSG/vjjDwB1febIREw76BiRfhw/flycOXOmGBISItrZ2Yn29vZiWFiYOH/+fPHy5ctN2tcPizh8+HAxLS1NfPTRR0V3d3fRxsZG7NOnj7hx48YWjzNq1ChRoVCIEolEBJpOItnSMJ7Dhw9vdp8AxMDAwGbXBQYGNhn2dv/+/eKIESNEFxcX0dHRURwyZIj4yy+/qIZcXLx4cbPH12R44tbOsaHi4mLR2tpatLKyEjMyMpqN/3aUSqW4cuVKsUePHqJcLhc9PT3FadOmiRkZGWKPHj1EhULRZBttzl3Xa3w7q1evFnv27Cna2NiInp6e4vTp08XMzEzVcKPNOXfunDh79mwxKChIlMvlokKhEMPDw8Xp06eLv/32W7PbHDx4UJw4caLo5eUlWltbi15eXuLIkSPFdevW6fx5NKTJtfv444/F+++/XwwJCRHt7e1FhUIh9urVS1y8eHGjuW80oclnkJKSIrq6uopOTk7ilStXmuxj165dolQqFcPCwlTDhLblGn///ffisGHDRCcnJ1Eul4thYWHiq6++2mRo29a+f+snXlQfQlXT827LMXS5H4mibt9HmgxP3FBb7w+3u2/pci1uZ/LkySIA8cMPP9Q6zno1NTWir69vk0lZm6Pp8MSGzLl6+vo+0CUGff9svJ3Fixerhjtu6Z86QRDEzp07iz179tT4OKR/LFSoQ9LmpkvNq58EceLEiXrdb2FhoWhraysOHDiwTfvhNW6Zoa6dsfEamy9LyrExY8aIALQuwNXFxcWJAHSeMZ3My6ZNm0QA4tq1a00dSofGPipEpJN33nkHALBgwQKdts/KymrSt6OmpgbPP/88Kisr+ajdgNp67YhaY0k5duLECQQEBKj6/ejqmWeeQWhoKF5//XU9RUamIooiFi9ejJiYGEybNs3U4XRo7KNCRBo7c+YMtmzZgpMnT+LXX3/FqFGjMHz4cJ32tXLlSnz11VeIjY1F586dkZOTgz179iAxMRExMTF4+umn9Rx9x6bPa0fUHEvMsZSUFOTm5mLChAlt3pdcLse6deuwc+dOFBUVmd3w9qS5jIwMTJgwARMmTNBqUmnSPxYqRKSx48eP45VXXoGzszPuv/9+fPzxxzrv684770RCQgL++usv5OfnQyqVIjQ0FIsXL8ZLL70EW1tbPUZO+rx2RM2xxBw7fvw4gFvD9bbV4MGDMXjwYL3si0ync+fOWLJkianDIAASUTTRoN9EREREREQtYB8VIiIiIiIyOyxUiIiIiIjI7LBQMZLa2lqkpaWhtrbW1KEQEREREZk9FipGkpWVBX9/f9WMqK0RBAFZWVkQBMHAkZGlYE6QOuYEqWNOkDrmBDVkafnAQoWIiIiIiMwOCxUiIiIiIjI7LFSIiIiIiMjssFAhIiIiIiKzw5npjUwQBI06MAmCAKVSaTGdncjwmBOkjjlB6pgTpI45QQ2ZSz7IZDKN2rFQMbC4uDjExcWpEiIvLw82NjatbqdUKlFUVAQAkEr54IuYE9QUc4LUMSdIHXOCGjKXfPD29taonUQURdHAsRCAtLQ0+Pv7Izk5GX5+fq22FwQBubm5cHd317jqpPaNOUHqmBOkjjlB6pgT1JC55AOfqJgpmUym8cWRSqVataf2jzlB6pgTpI45QeqYE9SQJeUDnwESEREREZHZ4ROVDqCwvBpHkvJRWlULRxsr9A92g4u93NRhERERERG1iIVKO5aYU4qPdydic0IGqgWlarlcJsX4KF/MjQ1FqIejCSMkIiIiImoeC5V26tC1PMxaexQVNQLUh0uoFpT48Xgatp7JxJoZ/TAwpJNpgiQiIiIiagH7qLRDiTmlmLX2KCqbKVLqiQAqawTMWnsUiTmlRo2PiIiIiKg1LFTaodW7E1FRI0DZysDTShGoqBbwyZ5E4wRGRERERKQhFirtTGF5NX5NyGjxSYo6EcCmkxkoLK82aFxERERERNpgodLOHEnKb9RxXhPVghJHkvINFBERERERkfZYqLQzpVW1Rt2OiIiIiMgQWKi0M442ug3kput2RERERESGwEKlnekf7Aa5TLvLKpdJ0T/YzUARERERERFpj4VKO+NiL8f4KF9IJJq1lwCYEO3LmeqJiIiIyKywUGmH5saGws5aBmkrxYpUAtjJZXhyeKhxAiMiIiIi0hALlXYo1MMRa2b0g621DLerVWysZFgzox9CPRyNFhsRERERkSZYqLRTA0M64bcFQzEpxq/FPit39fDCwJBORo6MiIiIiKh1HOqpHQv1cETcg5F49Z7uOJKUj9KqWlzMKsane5MAANvOZuHVkip4ONmYOFIiIiIiosZYqHQALvZyjOnhDQAQlCL+PJ+Na7llqKpV4ot9Sfj32HATR0hERERE1Bhf/epgZFIJ5sbe6jz/zaEUFJXXmDAiIiIiIqKmWKh0QBOiOsNXYQugbkb6dQeTTRsQEREREZEaFiodkNxKiieGhai+XrM/CeXVtSaMiIiIiIioMRYqHdTkfgHo5FA3yWNBeQ02Hkk1cURERERERLewUOmg7OQyzBoarPr6s73XUFUrmDAiIiIiIqJbWKh0YFMHBcLJpm7gt6ziSvxyIt3EERERERER1eHwxEYmCAIEofUnF4IgQKlUatRWVw7WUjw2MAAf77kGAFi1OxH3R/nAqoUJIsm0jJETZFmYE6SOOUHqmBPUkLnkg0wm06gdCxUDi4uLQ1xcnCoh8vLyYGPT+gSLSqUSRUVFAACp1HCFw31hjlizX4KqWhHX88vx3YErGBPuZrDjke6MlRNkOZgTpI45QeqYE9SQueSDt7e3Ru0koiiKBo6FAKSlpcHf3x/Jycnw8/Nrtb0gCMjNzYW7u7vGVaeu/rP1AtYeSAEAdPNyxNanh0AqlRj0mKQ9Y+YEWQbmBKljTpA65gQ1ZC75wCcqZkomk2l8caRSqVbtdfXEsFCsP3wdNYKIyzdKsedKHkZFeBn0mKQbY+UEWQ7mBKljTpA65gQ1ZEn5wGeABF8XOzwQfespz8r4q+CDNiIiIiIyJRYqBACYGxuK+re9ElILcfBanmkDIiIiIqIOja9+EQAg2N0B43r5YMvpTADAB39eRmllLUqrauFoY4X+wW5wsZebOEoiIiIi6ihYqJDKU7FdVIXKkaQCHEk6rlonl0kxPsoXc2NDEerhaKoQiYiIiKiD4KtfpFJcWQOZpPnRvqoFJX48noZ7P9yHQ3wtjIiIiIgMjIUKAQASc0oxa+1RKNFyJ3oRQGWNgFlrjyIxp9R4wRERERFRh8NChQAAq3cnoqJGQGuDfSlFoKJawCd7Eo0TGBERERF1SCxUCIXl1fg1IaPVIqWeCGDTyQwUllcbNC4iIiIi6rhYqBCOJOWjWlBqtU21oMSRpHwDRUREREREHR0LFUJpVa1RtyMiIiIiag0LFYKjjW6jVOu6HRERERFRa1ioEPoHu0Eu0y4V5DIp+ge7GSgiIiIiIuroWKgQXOzlGB/lixamUGlCAmBCtC9nqiciIiIig2GhQgCAubGhsLOWQapBsWInl+HJ4aGGD4qIiIiIOiwWKgQACPVwxJoZ/WBrLUNrtcqzd3ZFqIejUeIiIiIioo6JhQqpDAzphN8WDMWkGL8mfVYavha28ch1VNYIRo6OiIiIiDoSDttEjYR6OCLuwUi8ek93HEnKR2lVLRxtrNDZxQ4TVx9AZY0SyXnl+HTvNTxzZ1dTh0tERERE7RQLFWqWi70cY3p4N1r29IgueHfHZQDAR/FXMSGqMwI62ZsiPCIiIiJq5/jqF2lszrAQhLg7AACqapVYvPksRFE0cVRERERE1B6xUCGN2VjJsHR8T9XX8ZdysOP8DRNGRERERETtFQsV0srQru74R28f1ddLfzuP8upaE0ZERERERO0RCxXS2mv3RMBBLgMApBdW4MNdV00cERERERG1NyxUSGveCls8P7qb6uvP/76Gq9klJoyIiIiIiNobFiqkkxmDgxDu7QQAqBFEvL7pHDvWExEREZHesFAhnVjJpPjPhFsd6w9ey8PmUxkmjIiIiIiI2hMWKqSzfkFumBTjp/r6za0XUFxZY8KIiIiIiKi9YKFCbfLvseFwtq2bNzSnpArv77xs4oiIiIiIqD1goUJt4u5og4V3h6u+/upAMs5lFJkwIiIiIiJqD1ioUJs90j8Avf0UAAClCLy+6SyUSnasJyIiIiLdsVChNpNJJXhzQk9IJHVfn7heiB+Pp5k2KCIiIiKyaCxUSC96+7lgyoAA1ddvb7uAgrJqE0ZERERERJaMhQrpzUtjwtHJQQ4AKCivwf/9ccnEERERERGRpWKhQnqjsLfGy+O6q77+9uh1nLxeYMKIiIiIiMhSsVAhvZrYpzP6B7kBAEQReG3TWQjsWE9EREREWmKhQnolkUiwdEIPyKR1PevPZRTjm0MpJo6KiIiIiCwNCxUNrFy5EtHR0bCyssKSJUtMHY7ZC/d2xszBQaqv391xCTklVaYLiIiIiIgsDgsVDXTu3BlLly7F/fffb+pQLMZzo7vBy9kGAFBSWYu3f79g4oiIiIiIyJKwUNHA/fffj3vvvRcKhcLUoVgMRxsrvP6PCNXXP59Mx6FreSaMiIiIiIgsicUUKqWlpVi8eDHGjRsHDw8PSCQSvPPOO822rampwaJFixAQEABbW1v07t0bGzZsMHLEdE8vH9zR1V319eubzqJGUJowIiIiIiKyFBZTqOTm5mLp0qU4c+YMoqOjb9v2iSeewLJlyzBhwgR8+OGH8Pf3x5QpU7Bu3TojRUtAXcf6N+7rAbmsLs2uZJdizb4kE0dFRERERJbAYgoVHx8fpKenIzU1FZ9++mmL7U6ePIm1a9diyZIlWLFiBebMmYMtW7YgNjYWL730Eqqrb82WPmbMGNja2jb7b968ecY4rXYvxMMRTw4PUX39v7+uIKOwwoQREREREZElsJhCxcbGBr6+vq22+/777yGVSjF//nzVMolEgqeffhrZ2dmIj49XLd+xYwcqKyub/ffxxx8b5Dw6oqdiu8DP1Q4AUF4t4D9bzps4IiIiIiIyd1amDkDfjh8/jtDQULi5uTVaPmDAAADAiRMncNddd2m1z9raWtTW1kIQBNTW1qKyshLW1taQyWQtblNcXIzi4mLV15mZmQAAQRAgCEKrxxQEAUqlUqO25k4uAxb/ozvmfH0CALDtbBZ2XcjC8G4eJo7MsrSnnCD9YE6QOuYEqWNOUEPmkg+3+x26oXZXqGRkZMDHx6fJ8vqnMRkZGVrv880338Qbb7yh+nrZsmX48ssvMWPGjBa3Wb58eaNt6uXl5cHGxqbVYyqVShQVFQEApFKLefDVol6dJLgjRIG/r9Wd06JNZ7F+agRsrCz/3IylveUEtR1zgtQxJ0gdc4IaMpd88Pb21qhduytUKioqmi0EpFIprK2tUVGhff+IJUuWaD3R4wsvvIDHH39c9XVmZib69++PTp06wcOj9ScJ9ZWuu7u7xlWnuXvzAQfc9b99qKxRIq2oCr9cKMGCkV1MHZbFaI85QW3DnCB1zAlSx5yghiwtH9pdoWJra4uqqqazoCuVStTU1MDW1tYocTg7O8PZ2bnJcplMpnFiSKVSrdqbu0B3JywY2RVxf1wCAKzacw339/FDYCcHE0dmOdpbTlDbMSdIHXOC1DEnqCFLyod29wzQ19dX1R+kofpXvjTpkE+G8/gdwQhxrytMqmuVWLL5HERRNHFURERERGRu2l2h0qdPHyQmJiI/P7/R8sOHD6vWk+nYWMmwdHxP1dfxl3Kw4/wNE0ZEREREROao3RUqDz30EJRKJVatWqVaJooiVq5cCQ8PD4wYMcKE0READO3qjn/0vjXgwRubz6G8utaEERERERGRubGoPiorV65EYWEhCgsLAQDx8fGora37BXfBggVQKBSIiYnB1KlTsXjxYuTk5KBXr17YtGkTdu/ejTVr1mg04pYhdcThiZvzytgw7L6UjdIqARlFlVjx52W8dFeYqcMya+09J0h7zAlSx5wgdcwJashc8kHT/jES0YI6CAQFBSElJaXZdUlJSQgKCgIAVFdX4z//+Q/Wrl2L7OxsdOvWDf/617/w2GOPGTHaOnFxcYiLi4MgCMjPz8fx48c16idTP3ycQqFot8MJbjxxA//bmwYAkEmBbx6LQLCbnYmjMl8dISdIO8wJUsecIHXMCWrIXPJB0+GJLapQsWRpaWnw9/dHcnIy/Pz8Wm0vCAJyc3MtZvg4XdQKSoxfdRAXs0oAAAND3PDNrH6QSCQmjsw8dYScIO0wJ0gdc4LUMSeoIXPJhw474aO568jDE6uTyWR4c0JPTFp9EABw6Fo+tp69gfFRnU0cmflq7zlB2mNOkDrmBKljTlBDlpQPfAZIJtU3yA0Pxtx6wvTm1gsorqwxYUREREREZA5YqJDJ/XtsOBR21gCAnJIqLN9x2cQREREREZGpsVAhk+vkaIOFd98a8WvdwWScyygyYUREREREZGosVMgsPNwvAJF+CgCAUgRe33QWSiXHeSAiIiLqqNiZ3sg4j0rLltwbgQdWH4QoAieuF+LL/dfQ2cUOpVW1cLSxQr8gV7jYy00dpsl0xJyg22NOkDrmBKljTlBD5pIPHPXLTDScRwUA8vLyNJp0sn6cawAdZtxzHxvggV4e+Ol0DgDgP1svNlpvLZPgrjA3TO3rjUA3W1OEaFIdMSfo9pgTpI45QeqYE9SQueQD51ExM5xHRTO7LmZjztcnWlwvAWAnl+HzaTEYEOxmvMDMQEfNCWoZc4LUMSdIHXOCGjKXfOATFTPFeVRalphTime/OwUJgJaqZxFAZY2Ax9cdx28LhiLUw9GIEZpeR8sJah1zgtQxJ0gdc4IasqR84DNAMhurdyeiokZosUippxSBimoBn+xJNEpcRERERGR8LFTILBSWV+PXhAxo+iKiCGDTyQwUllcbNC4iIiIiMg0WKmQWjiTlo1pQarVNtaDEkaR8A0VERERERKbEQoXMQmlVrVG3IyIiIiLzxkKFzIKjjW7jOui6HRERERGZN/6WZ2Sc8LF5MQEKyGUSVAuaj5Ytl0kQE6DoMJ9RR8sJah1zgtQxJ0gdc4IaMpd84PDEZoITPmpuTJgbtp7Pa3XUr3oR3g6oKStCTplBwzIbHTEn6PaYE6SOOUHqmBPUkLnkAyd8NDOc8LF113JKMX7VQVTWCFBqmJXvTuqF+6M7GzYwM9ERc4JujzlB6pgTpI45QQ2ZSz7wiYqZ4oSPLevqrcCaGf0wa+1RVFQ3P5+K5Ob/9ete+ukMJBIpJsa0Xvy1Bx0tJ6h1zAlSx5wgdcwJasiS8oHPAMmsDAzphN8WDMWkGD/IZY3TUy6T4sG+ftgwZwACO9kDAEQR+OePp/Dj8TRThEtEREREBsInKmR2Qj0cEfdgJF69pzuOJOWjtKoWjjZW6B/sBhd7OQDg2ycG4uFPDyElrxyiCLz04ymIoogH+/qbOHoiIiIi0gcWKmS2XOzlGNOj+c5WPgo7fPvEQDzy6SEk3yxWFv50GgBYrBARERG1A3z1iyxWXbEyCEENXgNb+NNpfH8s1cSREREREVFbsVAhi+atsG1SrPzrp9P4/iiLFSIiIiJLxkKFLF59sRLs7gDgZrHyM4sVIiIiIkvGPipGxpnpDcPD0RrfzOqHKV8cadRnRVAq8VDf9jF0MXOC1DEnSB1zgtQxJ6ghc8kHg82jsnTpUq2Dac6iRYv0sh9zx5npjUcG4MP7QzH/p8u4XlAFAHj5l7MoKSnBfT3dTRucHjAnSB1zgtQxJ0gdc4IaMpd8MNjM9MHBwToF1OigEgmuXbvW5v1YEs5Mbzw3iisx5YsjSMotVy17a0IPTO5n2aOBMSdIHXOC1DEnSB1zghoyl3ww2BOVpKQkrYOhWzgzveH5ujrg2ycG4ZFPD+FabhkA4JVN5yCRSvFI/wATR9c2zAlSx5wgdcwJUsecoIYsKR/4DJDaJS9nW3z7xECEeDiolr388xlsOHzdhFERERERkab0Xqho+SYZkcF4Otvi2zkDEdqgWHnlFxYrRERERJZAL4VKamoqPvzwQ4waNQoODg4YMWIEVqxYgeTkZH3snkhnns622PhE02Jl/eEUE0ZFRERERK3RuVA5ffo0li5dij59+iAoKAivv/46PDw8EBcXB19fXyxZsgShoaGIiorCkiVLcPLkSX3GTaQxT6emxcqrv5zFN4dYrBARERGZK50KlbCwMERHR+Ozzz7DoEGDsG3bNuTk5GDjxo2YP38+1q9fj+zsbPzxxx8YNmwY1q5di5iYGAQFBek5fCLN1BcrXTwdVcte23QWX7NYISIiIjJLOk34OHXqVIwdOxYxMTEt79jKCqNGjcKoUaOwYsUKnDx5Eps3b9Y5UKK28nSyxcY5A/HIZ4dwNbsUAPD6prMAgKkDA00ZGhERERGp0alQee2117TeJjo6GtHR0bocjkhvPJxssHHOQDz62SFcaVisiCKmDgoybXBEREREpKLTq1+VlZV45pln8M033+g7HiKD83CywYY5A9G1wWtgr/96DusOJpsuKCIiIiJqRKdCZeXKlfjoo4/g6emp73iIjKK5YmURixUiIiIis6HTq1/fffcdRo8ejTFjxty2XVxcHL777jusW7cOEREROgXY3giCAEEQNGqnVCo1aku6cbO3wjez++GxL46qXgNb9Os5CIIS0waZX58V5gSpY06QOuYEqWNOUEPmkg8ymUyjdhJRhxkaHR0dsWzZMjz77LO3bVdZWYnAwEA88cQT+M9//qPtYdqFuLg4xMXFQRAE5Ofn4/jx4/D19W11O6VSiaKiIigUCkilep+XkxrIL6/Bgp8uIzGvUrXsxVh/PBhlXk8MmROkjjlB6pgTpI45QQ2ZSz54e3tr1E7nQuWDDz7A448/3mrbefPm4cSJEzh8+LC2h2lX0tLS4O/vj+TkZPj5+bXaXhAE5Obmwt3dXeOqk3SXW1qFqWuO4vKNUtWyRf/ojulm9GSFOUHqmBOkjjlB6pgT1JC55IOmx9bp1S8/Pz9cvHhRo7a9e/fGL7/8osth2iWZTKbxxZFKpVq1J915KexvjgZ2GJdulAAAlm65AIlEgplDgk0c3S3MCVLHnCB1zAlSx5yghiwpH3R65jNmzBh8/fXXKCsra7WtRCJBYWGhLochMqpOjjbYMGcAwr2dVMve+O081uxLMmFURERERB2TToXKM888g5KSEkyaNAnl5eW3bZuQkAAfHx+dgiMytk6ONlj/eONiZekWFitERERExqZTodKlSxesWrUKO3fuRFRUFH7++edmRw84fPgw1q5di7vvvrvNgRIZS92TlYFNipUvWKwQERERGY1OfVQAYMaMGXBycsLcuXPx4IMPwtPTE3feeScCAwMhl8tx9uxZbNq0CS4uLnj11Vf1GTORwbk5yLHh5gz2F7Pq+qz8Z8t5iKKIx+8IMXF0RERERO2fzoUKAEycOBHDhw/H+++/j3Xr1mHDhg2N1g8cOBCfffaZRqNcEZmb+mJlyueHcSGzGADw5tYLAMBihYiIiMjA2jyAsru7O5YtW4bU1FRcvXoV8fHx+PPPP5GcnIwDBw6gR48e+oiTyCTcHOTY8PgAdPdxVi17c+sFfLb3mgmjIiIiImr/2vRERV1ISAhCQviXZmpfXG8WK1M+P4zzN5+sLPu97snKnGHMdyIiIiJD4BSlRBpwdZBj/eMDENHgycqy3y/g072JJoyKiIiIqP1ioUKkofpipYfvrWLlrd8v4pM9LFaIiIiI9I2FCpEWmitW3t52EatZrBARERHpFQsVIi252NcVKz073ypW3mGxQkRERKRXeu1M35yRI0fC19cXr7zyCiIiIgx9OLMnCEKzk2M2106pVGrUlozPyUaGr2b0xbQvj+FcRl0H+3e2XYRSqcSTBupgz5wgdcwJUsecIHXMCWrIXPJBJpNp1E4iiqJoyECkUqnq/0ceeQRff/21IQ9nduLi4hAXFwdBEJCfn4/jx4/D19e31e2USiWKioqgUChUnyGZn6LKWjzz8xVcyi5XLXtqSGdM6+et92MxJ0gdc4LUMSdIHXOCGjKXfPD21uz3JIMXKgBQVlaGPXv2YPfu3fi///s/Qx/OLKWlpcHf3x/JyckaTYApCAJyc3Ph7u6ucdVJplFUUYNpa47i7M0nKwDwzzFdMW94qF6Pw5wgdcwJUsecIHXMCWrIXPJB02Mb/NUvAHBwcMC4ceMwbtw4YxzOrMlkMo0vjlQq1ao9mYabowzrHx+IqWsO43RaEQDg3R1XIJFIMX9EF70eizlB6pgTpI45QeqYE9SQJeUDnwES6YHC3hpfzx6A3n4K1bK4Py7ho/irJoyKiIiIyHKxUCHSE4Vd88XKyl1XTBgVERERkWVioUKkR/XFSmSDYuXdHZfx4V8sVoiIiIi0wUKFSM8UdtZYN3sAIv1dVMve23kZK1isEBEREWlM50Jlz549KC0t1WcsRO1G3ZOV/o2KleUsVoiIiIg0pnOhcvToUYwbNw4zZ85EbW2tPmMiahecbZsvVv73J4sVIiIiotboVKh89913+Oyzz7B//35ER0fDysoooxwTWZz6YiWqQbHy/p+X8cGfl00XFBEREZEF0KnCmDx5MkpLS5GdnQ0rKytUV1dDLpfrOzaidsHZ1hrrZvfH9DVHcPJ6IQDggz+vQBSB50d3M21wRERERGZK50cho0ePRkBAgD5jIWq3nG2tsW5Wf0xrUKz872Z/FRYrRERERE3p3EeFRQqRdpxuFivRAS6qZf/76wre38nXwIiIiIjUcXhiIiOqL1b6qBUry3dehiiKpguMiIiIyMywUCEyMidba3w1qz9iAl1Vy1bcfLLCYoWIiIioTpsKlZqaGhQVFekrFqIOo9liZddVPlkhIiIiukmnQiU/Px/33XcfHBwc4Obmhi5dumD9+vX6jo2oXXO0scJXs/qjb4Ni5UMWK0REREQAdCxUXn75ZWzZsgVRUVG45557UFpaimnTpmHjxo36jo+oXXO0scLaZoqV93awWCEiIqKOTafhibdt24bJkyerCpPS0lL84x//wKuvvopHHnlErwG2N4IgQBAEjdoplUqN2pJls7OS4IvpMZj91XEcSykAAKyMvwpBqcSLo7tCIpGgsLwah6/lITO3ED7utRgQ0gku9py7qKPjfYLUMSdIHXOCGjKXfJDJZBq106lQSU9Px1133aX62tHREUuWLMGdd96JxMREhIaG6rLbdikuLg5xcXGqhMjLy4ONjU2r2ymVSlX/H6mUYx50BP+9JxAvbKrBqYxSAMDHe64hM68YNYISOy4XoEaof8KSAmuZBHeFuWFqX28EutmaLmgyKd4nSB1zgtQxJ6ghc8kHb29vjdpJRB3eL5FKpfjmm2/w6KOPqpZlZWXB19cXe/fuxdChQ7XdZbuXlpYGf39/JCcnw8/Pr9X2giAgNzcX7u7uGledZPnKqmoxe91xHE0uaLWtBICdXIbPp8VgQLCb4YMjs8P7BKljTpA65gQ1ZC75YNAnKgCQkZGBmpoaWFtbA4Dq/+rqal132SHIZDKNL45UKtWqPVk+Z3sZ1s7sj4c/PYQz6bcfUU8EUFkj4PF1x/HbgqEI9XA0TpBkVnifIHXMCVLHnKCGLCkfdH7m869//QuOjo6Ijo7G448/js8++wwSicTk77wRWToHGyt08dSs6FCKQEW1gE/2JBo4KiIiIiLj0umJSnx8PE6dOqX6t379elRVVQEAxo4di8DAQPTo0QMRERGq/2NiYvQaOFF7VVheja2nMzVuLwLYdDIDr4zrzg72RERE1G7oVKgMHz4cw4cPV30tCAIuXbrUqHg5duwYtmzZAgB80kKkhSNJ+agWlFptUy0ocSQpH2N6aNY5jYiIiMjc6dxHpSGZTIaIiAhEREQ0Gp44JycHCQkJOH36tD4OQ9QhlFbVGnU7IiIiInOkl0KlJR4eHhg9ejRGjx5tyMMQtSuONrp9W+q6HREREZE50qkzfUFB60On6nM7oo6kf7Ab5DLtvjVlEiDSX2GgiIiIiIiMT6dCxd/fH7Gxsfjggw+QlJR027YpKSlYsWIFRo4cCS8vL52CJOpIXOzlGB/lC4lE820EEXjs8yM4eZ1/DCAiIqL2QadC5auvvoK/vz/+85//oEuXLoiMjMSiRYtw8uRJAEBCQgLeeOMNREdHIyQkBIsXL4a3tze+/vprvQZP1F7NjQ2FnbUMUi2KlSvZpZj48QG89fsFVFRz8AoiIiKybDrNTF9PEATs2bMHv/76KzZv3ozr16/Dzs4OFRUV8PPzw3333Yfx48cjNjYWVlYd+/35+pnpU1NTNZ6ZPicnBx4eHhYxIQ/p36FreZi19igqqgU0901aPzP9Pb18sCkhHTXCrVZBnezx34m9MSCkk9HiJePjfYLUMSdIHXOCGrK0fGhToaIuISEB+/fvx+DBgxEdHa2v3bYLLFRIF4k5pVi9OxG/JmQ0GrJYLpNiQrQvnhweilAPR1zKKsHCH0/hVFrj2eynDQrEwrvD2dG+neJ9gtQxJ0gdc4IasrR80GuhQi1joUJtUVhejUOJucjMLYCPuysGhro3mdyxVlBizf4kvLfjMqpqbxU1nV3s8M7EXrijq4exwyYD432C1DEnSB1zghqytHzQqY8KERmXi70coyO8MLZ7J4yO8Gp2BnormRRPDAvFtmfvQP8gN9Xy9MIKTP3iCBb+eApFFTXGDJuIiIhIZyxUiNqZEA9HfPvEQCwd3wP28lt/Lfn+WBrGvL8HO8/fMGF0RERERJphoULUDkmlEkwbFIQ/nhuGoV3cVctvFFdhzrpjePbbk8gvqzZhhERERES3x0KFqB3zd7PH17P7478Te8GpQYf6XxMyMHr5Hmw5nQF2UyMiIiJzpPdCZdeuXcjKytL3bolIRxKJBJP7BWDnC8NxZ7inanleWTWe3nASc785juziShNGSERERNSU3guV0aNH4/fff9f3bomojbwVtvh8el/87+EouNpbq5b/ce4GRr+/Fz8dT+PTFSIiIjIbei9U+IsOkfmSSCQYH9UZO18Yjnt6+aiWF1XU4MUfTmHGl0eRXlhhwgiJiIiI6rCPClEH5O5og4+m9MHqx/rA3dFGtXzP5Rzc9f5erD+cAqWSf3QgIiIi02GhQtSB3d3TB3++MAwP9OmsWlZaVYtXfzmLKZ8fRkpemQmjIyIioo6MhQpRB+diL8fyh6Lw5Yx+8FHYqpYfvJaHuz7Yiy/2JUHg0xUiIiIyMhYqRAQAGBHuiR3PD8OjAwJUyyprlPjPlvN4cPUBXM0uNWF0RERE1NGwUCEiFSdba7x1fy9seHwA/N3sVMtPXC/EuBV/46P4q6gVlCaMkIiIiDoKq9abkD4JggBBEDRqp1QqNWpLHYMxc2JAsCt+XzAE7+28gq8OpkAUgepaJeL+uITfz2Tivw/0RHcfZ4PHQbfH+wSpY06QOuYENWQu+SCTyTRqx0LFwOLi4hAXF6dKiLy8PNjY2LSyFaBUKlFUVAQAkEr54ItMkxNz+7tjsJ8tlu1MRkpBFQDgXEYxxq86gBn9fDCjvzesZcxPU+F9gtQxJ0gdc4IaMpd88Pb21qidRNTzxCdSqRSff/45Zs2apc/dWry0tDT4+/sjOTkZfn5+rbYXBAG5ublwd3fXuOqk9s2UOVFVI2DFrqv4bF9yo4713bwc8d8HeqG3n8Ko8VAd3idIHXOC1DEnqCFzyQeTPVGJj49HWFiYvnfbbshkMo0vjlQq1ao9tX+mygl7mQz/HheBe3p3xks/nsLFrBIAwOUbpZi4+iDm3BGC50d3g601c9XYeJ8gdcwJUsecoIYsKR/0/sxn+PDhGj/OISLL0stPgc1PD8Xzo7rBWiYBAChF4JO91zD2f3/jaHK+iSMkIiKi9oIvKxKRVuRWUjw7qit+WzC00StfSblleOiTg1j861mUVdWaMEIiIiJqD1ioEJFOwr2d8fO8wfj32HDIrepuJaIIfHUwBXd9sBf7ruSaOEIiIiKyZCxUiEhnVjIp5g4PxbZn70DfQFfV8rSCCjz2xWH8+6fTKK6sMWGEREREZKlYqBBRm4V6OOL7Jwdhyb0RsGvQof7bo6kYvXwP/rpww4TRERERkSXSe6Gi59GOichCSKUSzBgSjB3PD8Pg0E6q5TeKqzD7q2N47tuTKCirNmGEREREZEn0Uqikpqbiww8/xKhRo+Dg4IARI0ZgxYoVSE5O1sfuiciC+LvZY/3jA/D2A73gZHNrBPRNCRkY/f4e/H4m04TRERERkaXQuVA5ffo0li5dij59+iAoKAivv/46PDw8EBcXB19fXyxZsgShoaGIiorCkiVLcPLkSX3GTURmTCKR4JH+AdjxwjCMDPdULc8trcZT609g7tfHkV1SacIIiYiIyNzpVKiEhYUhOjoan332GQYNGoRt27YhJycHGzduxPz587F+/XpkZ2fjjz/+wLBhw7B27VrExMQgKChIz+ETkTnzUdjhi+l98cHkKLjYW6uWbz+XhdHL9+LnE2l8XZSIiIiapdPM9FOnTsXYsWMRExPT8o6trDBq1CiMGjUKK1aswMmTJ7F582adAyUiyySRSDAhujOGdHHH4s1n8fuZLABAUUUNXvj+FH47lYFl9/eCr4udiSMlIiIicyIR+edMo0hLS4O/vz9SU1Ph5+fXantBEJCTkwMPDw/IZLJW21P7115y4vczmVj061nklt7qWO9oY4VXxnXHI/39IZFITBidZWkvOUH6w5wgdcwJasjS8kGnV78qKyvxzDPP4JtvvtF3PETUzo3r5YOdzw/H/dGdVctKq2rxyi9nMOXzw7ieV27C6IiIiMhc6FSorFy5Eh999BE8PT1bb0xEpMbVQY73J0dhzYy+8Ha2VS0/kJiHuz7YizX7kiAomz7sLSyvxo5zWfj5RBp2nMtCYTmHOyYiImqvdOqj8t1332H06NEYM2bMbdvFxcXhu+++w7p16xAREaFTgETUfo0M98KOF9zw9u8XsPFIKgCgokbA0i3nsfVMJv47sTe6eDoiMacUH+9OxOaEDFQLStX2cpkU46N8MTc2FKEejqY6DSIiIjIAnZ6oXLhwAWPHjm213YIFC5CamoqNGzfqchgi6gCcba3x9gO9sf7xAfBzvdWh/nhKAcat+Bsv/3wG9364Dz+dSGtUpABAtaDEj8fTcO+H+3DoWp6xQyciIiID0nkeFQcHh1bb2Nra4oEHHsCOHTt0PQwRdRBDurjjj+eGYcbgINT3p6+uVWLjkesorxbQ0rAfIoDKGgGz1h5FYk6p0eIlIiIiw9KpUPHz88PFixc1atu7d2+kpKTochgi6mAcbKyw5L4e+P7JQQhxb/2PIfWUIlBRLeCTPYkGjI6IiIiMSadCZcyYMfj6669RVlbWaluJRILCwkJdDkNEHVS/IDdsmDMAUi1GKhYBbDqZwQ72RERE7YROhcozzzyDkpISTJo0CeXltx9KNCEhAT4+PjoFR0Qd1+m0IjQz8NdtVQtKHEnKN0xAREREZFQ6FSpdunTBqlWrsHPnTkRFReHnn3+GIAhN2h0+fBhr167F3Xff3eZAiahjKa2qNep2REREZF50Gp4YAGbMmAEnJyfMnTsXDz74IDw9PXHnnXciMDAQcrkcZ8+exaZNm+Di4oJXX31VnzETUQfgaKPb7UmmzftiREREZLZ0LlQAYOLEiRg+fDjef/99rFu3Dhs2bGi0fuDAgfjss8/g5+fXpiCJqOPpH+wGuUzaZEji1rzy8xmcTivCjMFB8HezN1B0REREZGhtKlQAwN3dHcuWLcOyZctw7do1pKamQhAEdOnSBQEBAfqIkYg6IBd7OcZH+eLHE2ktDk3cnLJqAV/sS8KX+5MwJsIbs+8IRt9AV0gkfNJCRERkSdpcqDQUEhKCkJAQfe6SiDqwubGh2HomE5U1wm071kslda98udhZI6e0btQvpQhsP5eF7eey0NtPgdlDgzGulw+sZTpPH0VERERGxJ/YRGS2Qj0csWZGP9hay9DS8xAJAFtrGb6ePQAHXr4TKx6JRqS/S6M2p9OK8Oy3Cbjjv/FYtfsqhzAmIiKyACxUiMisDQzphN8WDMWkGD/I1Z6GyGVSPNjXD78tGIqBIZ1gLZPivkhfbHpqMH6aNwjjenk3moslq7gS/7f9Ega9vQuvbTrDmeyJiIjMmF5f/SIiMoRQD0fEPRiJV+/pjiNJ+SitqoWjjRX6B7vBxV7epL1EIkFMoBtiAt2Qml+OdQeT8e2RVJTcHLq4okbAN4eu45tD1zEy3BOzhwZjcGgn9mMhIiIyIwYvVEaOHAlfX1+88soriIiIMPThiKgdc7GXY0wPb6228Xezx6v3RODZUd3ww7FUfLk/Gdfzb01Uu+tiNnZdzEa4txNmDQ3GfZG+sLWW6Tt0IiIi0pLBX/3avXs3NmzYgN69e2Pq1KmGPhwRUbMcbawwc0gw4v8Zi0+mxqB/sFuj9RezSrDwx9MY+t9deH/nZeSUVJkoUiIiIgKMUKgolUqUlJRg8+bN8PHxMfThiIhuSyaV4K4e3vj+yUH47emhuD+6M6wadGTJLa3G//66giHv7MJLP5zCxaxiE0ZLRETUcUlEUZsZCkhXaWlp8Pf3R2pqqkYTYAqCgJycHHh4eEAm42soxJwwpBvFlVh3MBnrD19HYXlNk/VDunTC7KHBiO3mCanUfPqxMCdIHXOC1DEnqCFLywd2pieiDs/L2RYv3RWOp0d0xc8n07BmXxISc8pU6/dfzcP+q3kI8XDAzCHBmNinM+zlvH0SEREZEocnJiK6yU4uw5QBgdj5/HB8ObMf7ujq3mj9tZwyvL7pLAa9vQv/3X4RWUWVJoqUiIio/eOfBImI1EilEowI88SIME9cvlGCNfuS8PPJdFTXKgEARRU1+Hh3Ij7bew3jevlg9tDgJpNMEhERUdvwiQoR0W1083LCOxN74+C/R+KF0d3g7mijWlerFLH5VAbGf7Qfkz4+gG1nMiEo2e2PiIhIH1iotKKqqgqzZs1CQEAAnJ2dMXDgQBw8eNDUYRGRkXVytMEzd3bF/n+PwLsPRqK7j3Oj9cdSCjBv/QkMj4vH539fQ0ll0075REREpDmdCpU9e/agtLRU37GYpdraWgQFBWHfvn0oLCzEvHnzcO+996K8vLz1jYmo3bGxkmFSjB9+f2YoNswZgFHdPdFwQvu0ggq8ufUCBr29C0t/O4/UfN4riIiIdKFToXL06FGMGzcOM2fORG1trb5jMisODg5YtGgRAgICIJVKMX36dIiiiCtXrpg6NCIyIYlEgsGh7vh8ej/sejEW0wYFwq7BjPalVbVYsz8Jw+PiMffr4zianA+OBk9ERKQ5rQuV7777Dp999hn279+P6OhoWFkZpz9+aWkpFi9ejHHjxsHDwwMSiQTvvPNOs21rampUxYWtrS169+6NDRs26CWOixcvory8HCEhIXrZHxFZvmB3Bywd3xOHXr4T/x4bDh+FrWqdUgS2n8vCg6sPYvxH+/FrQjpqBKUJoyUiIrIMWhcqkydPxsKFC/Hmm2/CysoK1dXVt23/j3/8A5cuXWq0rKqqCkqldj+oc3NzsXTpUpw5cwbR0dG3bfvEE09g2bJlmDBhAj788EP4+/tjypQpWLdunVbHVFdeXo6pU6fitddeg5OTU5v2RUTtj8LeGnOHh2LvwhH48JFoRKmNBHY6rQjPfpuAof/dhY/ir6Kw/Pb3TyIioo5Mp1e/Ro8ejZdffhlPPfUU5HL5bdvu27cPXbp0AQBs3rwZAJCamorY2Fitjunj44P09HSkpqbi008/bbHdyZMnsXbtWixZsgQrVqzAnDlzsGXLFsTGxuKll15qVFiNGTMGtra2zf6bN29eo/3W1NRg0qRJiIiIwCuvvKJV7ETUsVjLpLg30heb5g/BT/MG455ePmg4of2N4irE/XEJA9/+C6/+cgaJOR2jzx8REZE2dHpvKyAgQPMDWFmp3sueMmUKSkpKEBoairNnz2p1TBsbG/j6+rba7vvvv4dUKsX8+fNVyyQSCZ5++mlMmjQJ8fHxuOuuuwAAO3bs0OjYSqUS06ZNg7W1Nb744gtIGvacJSK6jZhAV8QEuiKtoBxfHUjGt0dSUVJV17evskaJ9YevY/3h6xgR5oHZQ0MwpEsn3mOIiIhghAkfR44ciWeeeQbdunWDIAhITU2Fm5ub1q9+aer48eMIDQ2Fm5tbo+UDBgwAAJw4cUJVqGjqySefRGZmJrZv3260PjlE1L74udrj1Xsi8OyobvjhWCq+3J+M6w1GBIu/lIP4SzkI93bCrCHBuC/KF7YNOucTERF1NAb/rfvTTz/FSy+9hPPnz+Odd97BvffeC39/fwwcONAgx8vIyICPj0+T5fVPYzIyMrTaX0pKCj7//HPY2trC3d1dtfyTTz7BlClTWtyuuLgYxcXFqq8zMzMBAIIgQBCEVo8rCAKUSqVGbaljYE60D3ZWEkwbGIAp/f3x18VsfLk/GUeSC1TrL2aVYOFPp/Hf7RcxZYA/pgwIaDTJZEPMCVLHnCB1zAlqyFzyQSbT7A9xbSpUampqUF5eDoVC0WIbFxcXfPbZZ6qvAwICcObMGcyePbsth25RRUUFbGya/lCXSqWwtrZGRUWFVvsLDAzUaUjR5cuX44033miyPC8vr9n41CmVShQVFQGoi52IOdH+RHtIET0hBBezy/HdyRvYeakAtTdnts8rq8aKXYn4eM81jAlzw8PRnujqYd9oe+YEqWNOkDrmBDVkLvng7e2tUTudCpX8/HzMmDED27dvhyAICA4OxhtvvHHbJwz1JkyYgAkTJuhyWI3Y2tqiqqqqyXKlUomamhrY2to2s5X+vfDCC3j88cdVX2dmZqJ///7o1KkTPDw8Wt2+vtJ1d3fXuOqk9o050X55eAB39AjEjeJKfHP4OjYeSUVBed3M9jWCiK3n87D1fB4Ghbhh1pAgxHbzgFQqYU5QE8wJUsecoIYsLR90KlRefvllbNmyBX379oW3tzeOHDmCadOmQSqV4pFHHtF3jFrx9fVFSkpKk+X1r3xp0iFfH5ydneHs7NxkuUwm0zgxpFKpVu2p/WNOtG++rg5YeHd3LBjZDb+cTMea/Um4mn1rRLCD1/Jx8Fo+QtwdMHNIECZE+TAnqAnmBKljTlBDlpQPOj3z2bZtGyZPnowjR45g8+bNuHr1Ku644w68+uqr+o5Pa3369EFiYiLy8/MbLT98+LBqPRGRObOTy/DogADseG4Y1s7shzu6ujdafy23DK//eg5D/28PPtqXhsyiShNFSkREZDg6FSrp6emNRs5ydHTEkiVLkJKSgsTERL0Fp4uHHnoISqUSq1atUi0TRRErV66Eh4cHRowYYcLoiIg0J5VKEBvmia9nD8CO54fh4X7+kFvdum0XVdTg62M3EPvuHjyz8SROpRaaLlgiIiI90+nVL1EUm0z0GB4eDlEUkZmZidDQUL0Ep27lypUoLCxEYWEhACA+Ph61tXXzESxYsAAKhQIxMTGYOnUqFi9ejJycHPTq1QubNm3C7t27sWbNGo06shMRmZtuXk54Z2JvvHRXGNYfvo51B1OQW1rXH69WKWLzqQxsPpWBvoGumD00GGN6eEMm5XwsRERkuXQe9SsjIwM1NTWwtrYGANX/DWd+17d33323Uf+THTt2qCZtfOyxx1Sjj33++ecIDAzE2rVrsXr1anTr1g1ff/01HnvsMYPFpikOT0y6Yk4QALjYWWF+bAgeHxqEzQnp+OLva7iSe2s0w2MpBTiWUgA/VztMGxiAh/r6wcnW2oQRkzHxPkHqmBPUkLnkg6b9YySiDmPvSqVSSCQSWFlZISIiAjExMejSpQteffVVbN++HaNHj9Y64PYqLi4OcXFxEAQB+fn5OH78uEYd+uuHj1MoFBxOkAAwJ6gppVKJwsJCJJbI8P2pHOy7VgT1G7q9XIp7I9zxULQnOiv4RLm9432C1DEnqCFzyQdNhyfWqVDZs2cPTp06pfp37tw51ZDAUqkUgYGB6NGjByIiIlT/x8TEaHuYdiUtLQ3+/v5ITk6Gn59fq+0FQUBubq7FDB9HhsecIHXqOZGUW4avDqbgpxPpKK9u/NcyqQQY1d0Ls4YEom+gKyQSvhbWHvE+QeqYE9SQueSDQZ+oqBMEAZcuXWpUvJw6dQpZWVl1B5FITP6IydTqC5XU1FSNC5WcnBx4eHjwxkIAmBPUVEs5UVReg2+PXsdXB5KR0cyIYL06KzB7aDDG9fJp1DmfLB/vE6SOOUENWVo+tGlm+noymQwRERGIiIhoNI9KTk4OEhIScPr0aX0choiINKCwt8aTw0Mxa2gwtp/Nwhf7kpDQYESwM+lFeO67BLy97QKmDQrCo/0D4Oogb3mHREREJqCXJyrUOj5RobZiTpA6bXLieEoB1uxLwrazmVCq3fVtraV4oI8fZg0JRhdPRwNGTIbG+wSpY05QQ5aWDzo9USkoKICrq6vRtiMioraJCXRFTKAr0grK8dWBZHx7JBUlVXXDu1fWKLHh8HVsOHwdsWEemD00GEO7uLMfCxERmZROhYq/vz/69u2LCRMmYPz48QgODm6xbUpKCn799Vds2rQJ+/btM+jwxZaAwxOTrpgTpE6XnPBxtsG/7w7D0yNC8ePxNHx1MAXX828Nb7z7Ug52X8pBNy9HzBwchPGRPrCxNv+/ulEd3idIHXOCGjKXfDBoZ/qffvoJmzZtwu+//47CwkL07NkT48ePx/3334/o6GgkJCSoipPTp0/D2dkZY8eOxfjx4zF58mStT8aScXhi0hfmBKnTR04IShH7k4qw8cQNnEwvbbLe1c4KD/T2wAO9PdDJgfOxmDveJ0gdc4IaMpd8MOjwxPUEQcCePXvw66+/YvPmzbh+/Trs7OxQUVEBPz8/3HfffRg/fjxiY2NhZaWXfvsWi8MTU1sxJ0idvnPibHoRvjyQgq1nMlEjNP7RIJdJcG+kL2YNCUK4t1Obj0WGwfsEqWNOUEPmkg9GHZ64XkJCAvbv34/BgwcjOjpaX7ttF9iZntqKOUHqDJUTN4or8fXBFKw/nIKC8pom6weHdsLsocEYEeYJqZT9WMwJ7xOkjjlBDVlaPuj1MUdUVBSioqL0uUsiIjIyL2db/POuMMwf0QW/nEzHmv1JuJp967WwA4l5OJCYhxB3B8wcEoSJMX6wl3fsp+ZERKR/fFmRiIiaZSeX4dEBAdjx3DCsndkPd3R1b7T+Wm4ZXv/1HAa+9Rfe3nYBGYUVLeyJiIhIe1r/CWzp0qV6OfCiRYv0sh8iIjIsqVSC2DBPxIZ54vKNEqzZl4SfT6ajulYJACiurMUne67h87+TMK6XD2YPDUaUv4tpgyYiIoundR+V2w1FrPFBJRJcu3atzfuxJOyjQm3FnCB1psyJvNIqbDh8HesOpSCnpKrJ+phAV8weGowxEV6wkvHhvbHwPkHqmBPUkKXlg9ZPVJKSkgwRBxERWZBOjjZYcGdXPDE8BFtOZeKLfUk4n1msWn88pQDHUwrQ2cUOM4cE4aF+/nC25fDGRESkOfZ+NDJO+Ei6Yk6QOnPICSsJMCHKB+MjvXE4KR9f7k/BX5eyUf+sPr2wAm9uvYD3d17Gg339MH1QIALc7E0Wb3tnDjlB5oU5QQ2ZSz6YZHhiaooTPpK+MCdInbnmxPWCSnyfkI2t5/NQUaNstE4CYFioCx7p44lIX0dIJBzeWJ/MNSfIdJgT1JC55INRJnwkzXHCR2or5gSpM/ecKKqowXfH0rDuYAoyiyqbrO/p64yZQ4Iwrqc35Fb8BUofzD0nyPiYE9SQueSDpsfmq19GJpPJNL44UqlUq/bU/jEnSJ0554SbowzzYrvg8TtCsP1sFr7Yl4SE1ELV+rMZxXjxh9P4vz8uYdqgIDzaPwCuDnLTBdxOmHNOkGkwJ6ghS8oH/gmLiIgMylomxb2Rvtg0fwh+fmow7untA1mDGe1vFFch7o9LGPTOX3jllzONJpckIqKOi09UiIjIaPoEuKLPo65IKyjHuoMp2HjkOkoqawEAlTVKbDh8HRsOX0dsmAdmDw3G0C7ut+3HUlhejSNJ+SitqoWjjRX6B7vBxZ5PZYiI2gMWKkREZHR+rvZ4ZVx3PHNnV/x4LBVfHkhGSl65av3uSznYfSkHYV5OmDU0COOjOsPW+tZrCok5pfh4dyI2J2SgWrjVYV8uk2J8lC/mxoYi1MPRqOdERET6xVe/iIjIZBxtrDBjSDB2vRiLT6fGYECwW6P1l26U4F8/ncGQd3Zh+Y5LyC6pxKFrebj3w3346URaoyIFAKoFJX48noZ7P9yHQ9fyjHkqRESkZ3yiQkREJieTSjCmhzfG9PDG2fQirNmXhN9OZ6BGqBuYMq+sGit2XcXHuxOhBKAURbQ0ZqUIoLJGwKy1R/HbgqF8skJEZKH4RIWIiMxKz84KLJ8chX3/GomnR3SBq/2tGe1rlCIEZctFSj2lCFRUC/hkT6KBoyUiIkNhoUJERGbJy9kW/7wrDAdfvhNv3d8Lwe4OWm0vAth0MgOF5dWGCZCIiAyKr34ZmSAIEARBo3ZKpVKjttQxMCdIXUfJCWspMLlvZ7jaW2He+pNabVstKHEoMRejI7wMFJ156Sg5QZpjTlBD5pIPnPDRTMTFxSEuLk6VEHl5ebCxsWl1O6VSiaKiIgB1E/MQMSdIXUfLiazcAp22++tsGro6i3CwMf/Jzdqqo+UEtY45QQ2ZSz54e3tr1E4iiq296Uv6kJaWBn9/fyQnJ8PPz6/V9oIgIDc3F+7u7hYxcygZHnOC1HW0nNh5/gbmavlEpZ6VVIJ+Qa4YEeaBkeGeWr9GZik6Wk5Q65gT1JC55AOfqJgpmUym8cWRSqVataf2jzlB6jpSTgwMdYdcJm0yJLEmapUiDl7Lx8Fr+Xhr2yUEdbLHiHBPjAz3RP9gN9hYtZ/PryPlBGmGOUENWVI+sFAhIiKL4GIvx/goX/x4Iq3VUb8AQAKgi6cjqmqVuJ5f3mhdcl45vtyfjC/3J8NBLsPQru4YGe6JEWGe8HS2NcwJEBGRVlioEBGRxZgbG4qtZzJRWSNAeZtiRSoBbK1lWD01BiHuDkjMKUP8xWzsupiNo8n5qG2wcVm1gD/O3cAf524AAHp1VmBEuCfuDPdEr84KSKUSQ58WERE1g4UKERFZjFAPR6yZ0Q+z1h5FRbWA5moVCeqKlDUz+qkme+zi6Yguno6YMywExZU1+PtyLnZdzMbuS9nIK2s8fPGZ9CKcSS/Cir+uwN3RBrFhHrgz3BNDu7rDyda6mSMSEZEhsFAhIiKLMjCkE35bMBSrdyfi14SMRn1W5DIpJkT74snhoS3OSO9sa417evvgnt4+UCpFnEorRPzFbPx1MRvnMoobtc0trcKPx9Pw4/E0WMsk6BfkhpE3+7aEcMZ7IiKD4qhfRlI/6ldqaqrGo37l5OTAw8PDIjo7keExJ0gdcwIoLK/GkaR8lFbVwtHGCv2D3eBiL9d5fzeKK1VFy/6ruSivbnmugaBO9hgZ7qXqkC+3Mv3Qr8wJUsecoIYsLR/4RIWIiCyWi70cY3poNh6/JrycbfFw/wA83D8AVbUCDl/Lx66bfVua65C/Zn8S1uxPgoNchju61g19HBvuAU8ndsgnImorFipERETNsLGSYVg3Dwzr5oHF90aoOuT/dfEGjiUXNOmQv/1cFrafywIA9PZTYESYJ+7s7omevuyQT0SkCxYqRERErZBIJI065BdV1ODvKzk3O+TnIF+tQ/7ptCKcTivC/252yB8R5oE7u3tiaFcPONrwRy8RkSZ4tzQyQRAgCC2/89ywnVKp1KgtdQzMCVLHnDAdR7kUY3t4YWwPLwhKEafTihB/KQe7L+XgXGbTDvk/HE/DDw075Id5IDbMA8HuDnqNizlB6pgT1JC55IOm/WPYmd7A4uLiEBcXB0EQkJ+fj+PHj8PX17fV7ZRKJYqKiqBQKCCVmr6DJpkec4LUMSfMU3ZpNQ4kFeFAUhGOXC9BZa2yxbb+LjYYEqzAkGAFojo7wlrWtuvInCB1zAlqyFzywdtbs76FLFSMpH7Ur+TkZI1H/crNzYW7u7tFjMpAhsecIHXMCfNXVSPgcHI+4i/lIP5iDlILKlps62gjw9Au7hhx82mLu6ON1sdjTpA65gQ1ZC75oOmx+eqXkclkMo0vjlQq1ao9tX/MCVLHnDBv9jIZRoR7Y0S4N8T7RCTmlGLXxWz8dSEbx1IKIDTokF9aJWD7uRvYfu4GACDST4ERN+ds0aZDPnOC1DEnqCFLygcWKkREREZQ1yHfCV08nfDEsNBbHfIvZGP35aYd8k+lFeFUWhE++PMKPJzqOuSPDGeHfCLqOHinIyIiMgGFnTX+0dsX/+jtC0Ep4lRaIXZdqJuz5bxah/yckip8fywN3x+r65A/ILgTRoR74s5wTwTpuUM+EZG5YKFCRERkYjKpBH0CXNEnwBX/vCsMmUUViL9YN/zx/qu5qKi5NUJPjSBi39Vc7Luai/9sOY8QdwdV0RLtrzDhWRAR6RcLFSIiIjPjo7DDowMC8OiAAFTWCDh0LQ/xF7Ox61I2UvMbd8i/lluGa/uS8MW+JDjayNDP3wl3R1ZhZLg3PJy075BPRGQuWKgQERGZMVtrGWLDPBEb5oklooir2XUd8nddbL5DfvzVQsRfLQRwFpF+CowM98LIcE/08HXWuEM+EZE5YKFCRERkISQSCbp6OaGrlxOeHB6KovIa7L1S94rY7kvZKCivadS+vkP++39ehoeTDUaGeWJEuCeGdnVnh3wiMnu8SxEREVkohb017o30xb2RdR3yT6TkYcuJFBxJLcOFrJJGbXNKqvDdsVR8dywVcpkUA0LcMCKsbvhjdsgnInPEQoWIiKgdqO+Q729Xi0UeHrhRUo34S9mIv5iNfVdzUVmjVLWtFpT4+0ou/r6Si6VbziPEwwEjbxYtfYPcILfiDOZEZHosVIiIiNohXxc7TBkQiCkDAlFZI+BgfYf8i9lIK1DrkJ9Thms5Sfh8XxKcbKxwRzd3jLj5mpi7IzvkE5FpsFAhIiJq52ytZXWFR5gn3rhPxJUGHfKPq3XIL6mqxe9nsvD7mSxIJEBvPxeMDPPEnd3rOuRLJOyQT0TGwUKFiIioA5FIJOjm5YRuXk6Ye7ND/p4rOYhvpkO+KAKnUgtxKrUQ7/95GZ5ONnX9Wrp7YmgXdziwQz4RGRDvMEYmCAIEQdConVKp1KgtdQzMCVLHnCB1uuSEo40U9/T0wj09vSAoRSSkFiL+Ug52X8pp0iE/u1GHfAn6B7thRJgHRoR5IrCTvb5Ph/SA9wlqyFzyQSaTadROIoqi2Hoz0lVcXBzi4uIgCALy8/Nx/Phx+Pr6trqdUqlEUVERFAoFpFJ2aiTmBDXFnCB1+s6JGyXVOJBUhH1JRTiWWoyq2pZ/ZQh0tcGQYBcMCVYg0tcRVjK+ImYOeJ+ghswlH7y9vTVqx0LFSNLS0uDv74/k5GT4+fm12l4QBOTm5sLd3V3jqpPaN+YEqWNOkDpD5kRljYBD1/IRfykH8ZeykV5Y2WJbRxsrDOvqjtgwDwzv5s4O+SbE+wQ1ZC75oOmx+eqXkclkMo0vjlQq1ao9tX/MCVLHnCB1hsoJB5kMd0Z4484Ib4hiXYf8vy7UDX98/HrjDvmlVbX4/WwWfj9b1yE/0s8FI8Prhj9mh3zj432CGrKkfGChQkRERFpp2CF/XmwoCsursefyzQ75l3NQqNYhPyG1EAmphVi+8zK8nG1UE00OYYd8IroN3h2IiIioTVzs5Rgf1RnjozpDUIo4eb1ANfzxRbUO+TeKq/Dt0VR8ezQVcpkUA0LccGe4J0aGeyGAHfKJqAEWKkRERKQ3MqkEfYPc0DfIDQvvDkd6YQXiL9a9IrY/MReVNUpV22pBib+v5OLvK7lY8tt5hHo44M7uXhgR5om+Qa6wlmnf2bewvBpHkvJRWlULRxsr9A92g4u9XJ+nSERGwkKFiIiIDKazix0eGxiIxwYGorJGwMHEPNXTlvTCikZtE3PKkJhzDZ/uvQYnWysM6+aBkWGeiA3zQKdWOuQn5pTi492J2JyQgWrhVjEkl0kxPsoXc2NDEerhaJBzJCLDYKFCRERERmFrLcOIcE+MCPfEUlHE5RulN4uWGzieUoAG/fFRUlmLraczsfV0JiQSIMrfBSPD6rZV75B/6FoeZq09iooaAepjmVYLSvx4PA1bz2RizYx+GBjSyUhnS0RtxUKFiIiIjE4ikSDM2wlh3o075O+6mI09zXTIP3m9ECevF+K9nZfh7WyLEeF1E012drXDrLVHUdlMkaLaHnXDK89aexS/LRjKJytEFoKFChEREZlcww75tYISCamF+Otm3xb1DvlZxZXYeCQVG4+kQipBoycxLVGKQEW1gE/2JOL/JkUa6CyISJ9YqBAREZFZsZJJVR3y/3WzQ/6u+g75V3NRVXurD4omRUo9EcCmkxl4ZVx3drAnsgAsVIiIiMisdXaxw9SBgZg6MBAV1QIOXsvFrovZ2Ho6EwUNXhHTRLWgxJGkfIzp4W2gaIlIX1ioEBERkcWwk8swMtwLI8O9EO3vihd/OKX1Pt7edgFHkvIR6e+CKH8X+LnaNeqcT0TmgYUKERERWSQnW91+jUnKLcfn+5JUX7s7yhHp54JI/7p/UX4uUNhb6ytMItIRCxUiIiKySP2D3SCXSRvNm6KL3NJq/HUxG39dzFYtC3Z3QKSfAlE3i5cIX2fYWMnaGjIRaYGFChEREVmkupHCfPHjibQWhyZuSALgrp7eGBPhhVOphUhILcT5zGLUCE03TsotQ1JuGTYlZAAArGUSRPg4q14Xi/R3QXAnB0ilfGWMyFBYqBAREZHFmhsbiq1nMlFZI9x2BDCppG7CyZfuCkOohyMe6OMHAKiqFXAhswQJ1wtwKq0Ip1ILcS23rMn2NYJYtz6tCOsOpgAAnG2t6l4Xu/naWJS/CzycbAxynkQdEQsVIiIislihHo5YM6Nf3cz01QKaq1UkqCtS1szo12SyRxsrGaJuFhn1Csurcfpm0ZJw819eWXWT/RZX1uLvK7n4+0quallnF7ubT1wUiPRzQS8/Bezl/HWLSBf8zjEyQRAgCIJG7ZRKpUZtqWNgTpA65gSp66g50S/QBb8+NQif7E3C5lMZqG7wKpdcJsH4KF88cUcwQjwcNfpsnGxkGBLqhiGhbgAAURSRXliheqJyKrUIZzOKUFnTtG9MemEF0gsrsPVMJoC6JzndvJzQ20+BKD8FIv1d0MXDAVYyqZ7O/vY6ak5Q88wlH2Qyzfp7SURRk7c6SVdxcXGIi4uDIAjIz8/H8ePH4evr2+p2SqUSRUVFUCgUkEqNczMj88acIHXMCVLHnACKKmtxKr0UZdUCHOQyRHZ2hELH0cFup1YpIimvAueyynAuqwzns8pwLa+y2Sc66uyspQj3tEeEtwN6eDsgwssBXk7WBhkimTlBDZlLPnh7azaPEQsVI0lLS4O/vz+Sk5Ph5+fXantBEJCbmwt3d3eNq05q35gTpI45QeqYE6ZVWlWLcxnFSEgtrHt1LK0ImUWVGm3r7ihHlJ9L3ZMXfwV6dVbA2a7tQyQzJ6ghc8kHTY/NV7+MTCaTaXxxpFKpVu2p/WNOkDrmBKljTpiOwl6GwV08MLiLh2pZdnGlqp/LqbRCnE4tQklVbZNtc0ur8efFbPzZYIjkEA8HVf+ZSD8XdPdxhtxK+7+CMyeoIUvKBxYqRERERAbi6WyLMT28MaZH3asuSqWIa7mlSEi91Vn/QmYxapsZsuxaThmu5ZTh5xPpAAC5TIoIX+dbxYu/C4I62RvklTEic8BChYiIiMhIpFIJung6oYunEybF1L0KXlkj4HxmMRKu1z11SUgtREpeeZNtqwWl6ulMPYWddd3QyH4KRAXUPXnp5Mghkql9YKFCREREZEK21jL0CXBFnwBX1bKCsmpV0VL/5KWgvKbJtkUVNdh7OQd7L+eolvm52qmeuvTydYaXvOnoZESWgIUKERERkZlxdZAjNswTsWGeAOqGSE7Nr0BCWqHqycvZ9CJU1TYtQtIKKpBWUIEtp+uGSJZJgDBvJ0QFuCLq5uSUXTwdIZPylTEybyxUiIiIiMycRCJBQCd7BHSyx32RddMc1AhKXMoqafTU5WpOKdTHcxVE4HxmCc5nlmDD4esAAAe5DL1uzusSfbO/i4/CztinRXRbLFSIiIiILJC1TIqenRXo2VmBxwYGAgBKKmtwJr3oVvFyvRA3SqqabFtWLeDQtXwcupavWublbIPIm09cov1d0MtPASfbtg+RTKQrFipERERE7YSTrTUGh7pjcKg7gLp5M84lpSOt3Aqn04txKrUQp9MKUVbddGbyG8VV2HH+BnacvwEAkEiAUA9H1QhjUX4uCPdxgrWME0eScbBQISIiImrHPB3l6BHsgXG9614ZE5QiEnNKb83vklqIi1klENSGSBZF4Gp2Ka5ml+LH42kAABsrKXr4OiPK3xWR/gpE+bsgwI1DJJNhsFAhIiIi6kBkUgm6eTmhm5cTHurrDwCoqBZwLuPmK2NpRUhILUBqfkWTbatqlThxvRAnrheqlrna1w2RHOnnohoi2c1BbqzToXaMhQoRERFRB2cnl6FvkBv6BrmpluWVVuF0WhFO3nzqciqtEIXNDJFcUF6D3ZdysPvSrSGSAzvZq/q7RPm7oIevM2ytzX8mdDIvLFSIiIiIqIlOjjYYEe6JEeG3hkhOySvHqbRCnLw5RPK5jGJUNzNEckpeOVLyyrH5VAYAwEoqQbiPU11/Fz8XRAe4IMTdEVIOkUy3wUKFiIiIiFolkUgQ5O6AIHcHjI/qDACorlXiYlZdJ/36Jy+JOWVNtq1VijibXoyz6cX4BnVDJDvZWKGXn+JWZ31/F3g52xr1nMi8sVAhIiIiIp3IraTo7eeC3n4umDqobllxZQ1OpxbhVFqhqsN+TjNDJJdU1eJAYh4OJOaplvkobBu9MtbLTwFHG/662lHxyhMRERGR3jjbWmNoV3cM7Vo3RLIoisgsqmw0MeWZ9CKUNzNEcmZRJTKLsrD9XBaAuiGSu3k6IdJfoSpewrycYMUhkjsEFipEREREZDASiQS+LnbwdbHDuF4+AOqGSL6SXaIqXBJSi3ApqxhqIyRDFIFLN0pw6UYJvj9WN0SyrbUUPX0bvzLm52rHIZLbIRYqRERERGRUMqkE4d7OCPd2xuR+AQCA8upanL05KWVCWiESrhcivbDpEMmVNUocSynAsZQC1bJODnJV0VI3VLICLvYcItnSsVAhIiIiIpOzl1uhf7Ab+gffGiI5p6RKNTRy/atjxZW1TbbNK6vGrovZ2HUxW7Us2N0BkQ0660f4OsPGikMkWxIWKkRERERkljycbDAqwgujIrwAAEqliOS8srrC5XohEtKKcCGjGNVC0yGSk3LLkJRbhk0JdUMkW8skiPBxbjQ5ZXAnBw6RbMZYqBARERGRRZBKJQjxcESIhyPuj/YDAFTVCriQeau/y6nUQlzLbTpEco0g4lRaEU6lFQFIAQA42VrVFS0N+rt4ONkY85ToNlioEBEREZHFsrGSIepmkTH95rKi8hqcSits0Fm/EHll1U22Lamsxb6rudh3NVe1rLOLHSL9FarJKXt2VsCBQySbBD91IiIiImpXFPbWGNbNA8O6eQCoGyI5vbBC9cTlVGoRzqQXoaKm6RDJ6YUVSC+swO9n6oZIlkqAbl5OjZ66dPV05BDJRsBChYiIiIjaNYlEAj9Xe/i52uMfvX0BALWCEpdvlKr6u5xKK8TlGyVNhkhWisDFrBJczCrBt0dTAQB21jL08rv11CUqwAW+ClsOkaxnLFSIiIiIqMOxkkkR4euMCF9nPNK/bojksqpanEkvatTfJaOossm2FTUCjiTl40hSvmqZu6PNzVfQ6ian7O3nAoWdtdHOpz1ioWJkgiBAEJo+ZmyunVKp1KgtdQzMCVLHnCB1zAlSx5zQjq2VBP0CXdAv0EW1LLu4UtUJ//TN/0urmg6RnFtahT8v3MCfF26oloW4O6C3n6KuePFzQbi3E+RWpntlzFzyQSbTbJhoiSiKYuvNSFdxcXGIi4uDIAjIz8/H8ePH4evr2+p2SqUSRUVFUCgUkEr5DiQxJ6gp5gSpY06QOuaE/ilFEdcLqnAuqwzns8pw/kYZLueUo5kRkpuwlknQzcMeEd72iPByQA9vB/i72BjtlTFzyQdvb2+N2rFQMZK0tDT4+/sjOTkZfn5+rbYXBAG5ublwd3fXuOqk9o05QeqYE6SOOUHqmBPGUVUj4HxmSd1IYzefuqTklWu0rcLOGr39FIi8+a+3nwLujoYZItlc8kHTY/PVLyOTyWQaXxypVKpVe2r/mBOkjjlB6pgTpI45YXj2Mhn6BndC3+BOqmUFZdU3h0guQkJqAU6lFSG/mSGSiypq8PeVXPx95dYQyX6udoj0d0H0zZHGevoqYCdv2/UrLK/GocRcZOYWwCcPGBjqDhd7eZv2aWgsVIiIiIiI9MzVQY7YME/EhnkCqBsiOa2gAidVQyQX4kx6Eapqm74zllZQgbSCCmw9nQkAkEklCPNyalS8dPF0hEza+itjiTml+Hh3IjYnZKBa9X5aMuQyKcZH+WJubChCPRz1dt76xEKFiIiIiMjAJBIJ/N3s4e9mj/si6/or1whKXMoquTW/S1ohrmSXQr1jhqAUcT6zGOczi7HxyHUAgIO8bojkhsWLt3PjIZIPXcvDrLVHUVEjNNlntaDEj8fTsPVMJtbM6IeBIZ1gblioEBERERGZgLVMip6dFejZWYHHBgYCAEoqa3AmvajR5JRZxU2HSC6rFnDoWj4OXbs1RLKnk41qYkpPJxss+vUcqmqbFin1RACVNQJmrT2K3xYMNbsnKyxUiIiIiIjMhJOtNQaHumNwqLtqWVZRZV3hcnNyytNphSirbjrEcHZJFXacv4Ed5280WdcSpQhUVAv4ZE8i/m9SpF7OQV9YqBARERERmTFvhS3uVnjj7p51w/oKShHXckpV/V0SUgtxMasEglK3wXxFAJtOZuCVcd3NqoM9CxUiIiIiIgsik0rQ1csJXb2c8FBffwB1r3CdyyjCyeuF+ONcFo4mF2i1z2pBiSNJ+RjTQ7M5ToyBhQoRERERkYWztZYhJtANMYFucHOQa12oAEBpVa0BItMdpyglIiIiImpHHG10exah63aGwkKFiIiIiKgd6R/sBrlMu1/z5TIp+ge7GSgi3bBQISIiIiJqR1zs5Rgf5QtJ6/NBAgAkACZE+5pVR3qAhQoRERERUbszNzYUdtYytDZ5vVQC2MlleHJ4qHEC0wILFSIiIiKidibUwxFrZvSDrbUMLdUqEtR1wl8zo5/ZTfYIsFAhIiIiImqXBoZ0wm8LhmJSjF+TPitymRQP9vXDbwuGYmBIJxNFeHvm1bWfiIiIiIj0JtTDEXEPRuLVe7rjUGIuMnML4OPuioGh7mbXJ0UdCxUiIiIionbOxV6O0RFeyMmRwsPDAzKZzNQhtYqvfhERERERkdlhoUJERERERGaHhQoREREREZkdFipERERERGR2WKgQEREREZHZYaFCRERERERmh8MTG0ltbS0AIDMzU6P2giAgLy8PVVVVFjF8HBkec4LUMSdIHXOC1DEnqCFzygdvb29YWd2+FGGhYiQ5OTkAgP79+5s4EiIiIiIi00pNTYWfn99t20hEURSNFE+HVllZiTNnzsDDw6PV6hGoe/LSv39/HDlyBD4+PkaIkMwdc4LUMSdIHXOC1DEnqCFzygc+UTEjtra26Nevn9bb+fj4tFptUsfCnCB1zAlSx5wgdcwJashS8oGd6YmIiIiIyOywUCEiIiIiIrPDQsVMOTs7Y/HixXB2djZ1KGQmmBOkjjlB6pgTpI45QQ1ZWj6wMz0REREREZkdPlEhIiIiIiKzw0KFiIiIiIjMDgsVIiIiIiIyOyxUiIiIiIjI7LBQISIiIiIis8NChYiIiIiIzA4LFTNTU1ODRYsWISAgALa2tujduzc2bNhg6rDIwI4ePYpnnnkGvXr1gqOjIxQKBcaMGYO9e/c2aXvjxg1MnToVnTp1gqOjI0aOHInjx4+bIGoytm+++QYSiQS2trZN1jEvOo5z585h0qRJ8PDwgJ2dHbp27YqFCxc2asN86DguXbqEyZMnw9/fH/b29ujWrRtefvllFBYWNmrHnGh/SktLsXjxYowbNw4eHh6QSCR45513mm2rzfU/duwYRo4cCUdHR3Tq1AnTpk1Ddna2IU+lRZxHxczMnDkT69atw/z589GrVy9s2rQJv//+O7766itMmzbN1OGRgUyaNAl79+7FpEmTEB0djcLCQnz66adISkrCli1bcPfddwMAKioq0LdvX2RlZeGFF16Ai4sLVq1ahevXr+Pw4cOIiIgw8ZmQoZSWliIsLAxFRUWora1FZWWlah3zouPYvXs3xo0bh4iICDz88MNwdXVFSkoKEhMTsX79egDMh44kOTkZkZGRUCgUmDt3Ljw8PHDs2DF88cUX6Nu3Lw4dOgSAOdFeJScnIzg4GH5+fujevTt27tyJt99+G//+978btdPm+p8/fx79+/dHUFAQ5s2bh8LCQrz33nvw9fXF0aNHYWdnZ9yTFMlsnDhxQgQgLl26VLVMqVSKsbGxoqenp1hVVWXC6MiQ9u3bJ1ZWVjZalp+fL3p7e4t9+vRRLVu+fLkIQNy7d69qWW5urtipUydx/PjxxgqXTGDhwoVieHi4+Oijj4o2NjaN1jEvOoaSkhKxc+fO4r333ivW1ta22I750HG88cYbIgDx9OnTjZY///zzIgDx/PnzoigyJ9qryspKMT09XRRFUUxKShIBiG+//XaTdtpc//vuu0/08PAQ8/LyVMvi4+NFAOL7779vkPO4Hb76ZUa+//57SKVSzJ8/X7VMIpHg6aefRnZ2NuLj400YHRnSkCFDYGNj02iZq6srRo4ciXPnzqmWff/994iMjMQdd9yhWtapUyc88sgj2LZtG0pKSowWMxnPlStX8MEHH+C9996DtbV1k/XMi45h48aNSE9Px9tvvw2ZTIaysjIIgtCkHfOh4yguLgYA+Pj4NFpe/3X9X7+ZE+2TjY0NfH19W22n6fUvKSnBtm3b8Oijj8LNzU3VNjY2Fj179sR3332n/5NoBQsVM3L8+HGEhoY2Sg4AGDBgAADgxIkTpgiLTCgjIwPu7u4AAKVSiYSEBPTv379JuwEDBqC6uhpnz541dohkBM899xxGjhyJcePGNVnHvOg4du7cCWdnZ2RlZSE8PByOjo5wdHTEo48+iry8PADMh44mNjYWADBjxgycOHECaWlp+OWXXxAXF4cpU6YgKCiIOdHBaXP9z5w5g5qamhbbnjp1Ckql0uAxN8RCxYxkZGQ0+asIAFW1nJGRYeyQyIT279+PPXv2YPLkyQCA/Px8VFZWMkc6mC1btmDHjh1Yvnx5s+uZFx3H5cuXUVtbi/vuuw8jRozATz/9hJdeegk//vgjxo4dC0EQmA8dzD/+8Q8sWbIEu3btQkxMDPz9/fHAAw9g4sSJWLduHQDeIzo6ba5//f8tta2oqEBBQYEBo23KyqhHo9uqqKho8voPAEilUlhbW6OiosIEUZEpZGdn49FHH0VAQAAWLVoEAKrr31yO1I8CxRxpX6qrq/H888/jqaeeQvfu3Zttw7zoOEpLS1FeXo45c+bg448/BgA88MADcHFxwYsvvoitW7ciOjoaAPOhIwkJCcHgwYPxwAMPwNfXF3///TdWrFgBKysrfPjhh7xHdHDaXH9zzBUWKmbE1tYWVVVVTZYrlUrU1NQ0OyQptT8lJSUYN24ciouLsXfvXigUCgC3bhLN5Uj9CFDMkfblvffeQ35+PhYvXtxiG+ZFx1H/y8Njjz3WaPmUKVPw4osv4u+//8agQYMAMB86ig0bNmDOnDm4cOECgoODAQATJkyAu7s7XnnlFUybNg1BQUEAmBMdlTY/I8zx5wlf/TIjvr6+yMzMbLK8/lGcJh2myLJVVFTg3nvvxYULF7B161b06tVLta5Tp06wsbFhjnQQRUVFWLZsGWbPno38/HxcvXoVV69eRUlJCURRxNWrV5GZmcm86ED8/PwAAF5eXo2We3h4QCqVorCwkPnQwaxevRqRkZGqIqXe+PHjAdS9Qsyc6Ni0uf71/7fU1tbWtkk/akNjoWJG+vTpg8TEROTn5zdafvjwYdV6ar9qamowadIkHDhwAD/99BMGDx7caL1UKkVkZCSOHDnSZNvDhw/D2tq6UWFDlq2goABlZWWIi4tD165dVf9+/vlnVFdXo2vXrpg5cybzogOJiYkBAKSlpTVanpmZCaVSCU9PT+ZDB3Pjxg3U1tY2WV6/rLa2ljnRwWlz/Xv16gVra+sW20ZGRkIqNW7pwELFjDz00ENQKpVYtWqVapkoili5ciU8PDwwYsQIE0ZHhqRUKjFlyhRs374d69atU03wqO6hhx7CqVOnsG/fPtWyvLw8bNy4EXfffTecnJyMFTIZmKenJ3744Ycm/2JjY2FtbY0ffvgBr732GgDmRUcxefJkSCQSfPHFF42Wf/755wCAMWPGAGA+dCRhYWE4ffp0k1G7vvnmGwC3ilvmRMem6fV3dnbGXXfdhQ0bNjTqNL97926cPXsWDz30kNFj58z0ZmbatGlYv349nn76adXM9Fu3bsWaNWswc+ZMU4dHBvLCCy/g/fffx+jRozFt2rQm6+vfSS8vL0dMTAyys7Px4osvQqFQYNWqVUhJScGhQ4fQs2dPY4dORjZjxgx8++23jWamZ150HHPnzsUnn3yCBx54AKNGjUJCQgI+++wzTJgwAT///DMA5kNH8vfff2PkyJFQKBSYP38+fHx8sHfvXmzcuBHDhw9HfHw8JBIJc6IdW7lyJQoLC1UzyI8ZM0Y1X8qCBQugUCi0uv5nz57FgAEDEBwcjHnz5qGoqAjvvvsuvL29cezYMdjb2xv3BI0+xSTdVlVVlfjaa6+Jfn5+olwuF3v27Cl+/fXXpg6LDGz48OEigBb/NZSZmSlOmTJFdHV1Fe3t7cXY2FjxyJEjJoqcjG369OlNZqYXReZFR1FTUyO+9dZbYkhIiGhtbS0GBASIr776qlhVVdWoHfOh4zh27Jh47733ip07dxatra3FwMBA8cUXXxRLS0sbtWNOtE+BgYEt/u6QlJSkaqfN9T98+LAYGxsr2tvbi66uruKUKVPEzMxMI51RY3yiQkREREREZod9VIiIiIiIyOywUCEiIiIiIrPDQoWIiIiIiMwOCxUiIiIiIjI7LFSIiIiIiMjssFAhIiIiIiKzw0KFiIiIiIjMDgsVIiIiIiIyOyxUiIiIiIjI7LBQISIiIiIis8NChYiIiIiIzA4LFSIi6pBKS0uxePFijBs3Dh4eHpBIJHjnnXda3S4wMBCLFy82QoRERB0bCxUiIuqQcnNzsXTpUpw5cwbR0dEabXPmzBlcv34d99xzj4GjIyIiK1MHQEREZAo+Pj5IT0+Hr68vkpOTERwc3Oo2W7duhaenJ/r162eECImIOjY+USEiIpOaO3cuOnXqhHXr1jVZd+LECUilUnz66ad6P66NjQ18fX212mbLli0YO3YsJBIJAKCsrAyLFi1CeHg47Ozs4Orqir59+2Lbtm16j5eIqKPhExUiIjKpiRMn4u+//8brr7+OadOmNVq3cOFChIWFYfbs2Y2W19TUoKioSKP9KxQKWFtbtznO/Px8HDp0CM899xwAQBRF3H333Thx4gTmzp2LHj16oLi4GCdOnICDg0Obj0dE1NGxUCEiIpMaPXo0nnrqKTz99NMoKSmBk5MTAOCPP/7AX3/9hV9++QUymazRNvv378eIESM02n98fDxiY2PbHOf27dshlUoxZswYAMDp06exb98+fPPNN5gyZUqb909ERI2xUCEiIpMLCwsDAFy5cgV9+vSBUqnEv/71LwwePBgTJkxo0j4yMhI7d+7UaN+RkZF6iXHLli0YOnQonJ2dAQAuLi6wsrLCjh070L9/f7i6usLZ2RlyuVwvxyMi6uhYqBARkcl16dIFwK1CZf369Th16hT27dvXbHtXV1eMGjXKaPEJgoA//vgDr776qmpZYGAg1qxZgyeeeELVv+bAgQMYNGiQ0eIiImrPWKgQEZHJBQQEQC6X48qVK6iqqsLrr7+OCRMmYMiQIc22r66uRn5+vkb7dnNza/NTjoMHDyI/P7/RsMSrV6/GP//5Tzz77LMYPHgw7O3t0adPnzYdh4iIbmGhQkREJieVShESEoLLly9j5cqVSEtLw/bt21tsf+DAAaP2UdmyZQtCQ0NVr6ilp6fjmWeewf/+9z/MmzevTfsmIqLmsVAhIiKz0KVLFxw/fhxbt27FrFmzEB4e3mJbY/dR2bp1K/7xj3+ovk5ISEBNTQ26du3a5n0TEVHzWKgQEZFZ6NKlC7Zs2QJ7e3ssWbLktm311Udl5cqVKCwsRGFhIYC6py+1tbUAgAULFkChUOD69es4e/Ysli9frtque/fukMvlmDZtGp588kn4+/sjNzcXBw8exLPPPquXUcaIiDo6FipERGQW6p9OPP/881pPxKird999FykpKaqvd+zYgR07dgAAHnvsMSgUCmzZsgWOjo4YPny4ql1ISAh++eUXLFu2DHFxcaitrYWPjw/69++P3r17GyV2IqL2TiKKomjqIIiIiD799FM8+eSTSE1NhZ+fn6nDUbnnnntgY2ODn3/+2dShEBF1KHyiQkREZuHMmTNwcXExqyIFAIYPH97i6GNERGQ4fKJCRERmITY2FkqlEnv37jV1KEREZAakpg6AiIgIqHui0rNnT1OHQUREZoJPVIiIiIiIyOzwiQoREREREZkdFipERERERGR2WKgQEREREZHZYaFCRERERERmh4UKERERERGZHRYqRERERERkdlioEBERERGR2WGhQkREREREZoeFChERERERmR0WKkREREREZHZYqBARERERkdn5fyPG+Dyu++HwAAAAAElFTkSuQmCC", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Solve the entropic problem at decreasing eps and measure the gap on the cost.\n", + "C_th_j = jnp.array(C_th)\n", + "a_th_j = jnp.array(a_th)\n", + "b_th_j = jnp.array(b_th)\n", + "D_th = jnp.array(((t_I_th * np.ones((n_th, n_th)) - DI_th) / n_th)[None, ...])\n", + "\n", + "# Warm up\n", + "_ = constrained_sinkhorn(\n", + " C_th_j,\n", + " a_th_j,\n", + " b_th_j,\n", + " D_th,\n", + " jnp.zeros((0, n_th, n_th)),\n", + " eps=1.0 / 10.0,\n", + " n_iters=10,\n", + " n_newton=5,\n", + ")\n", + "\n", + "# Sweep gammas (= 1/eps).\n", + "gammas_sweep = np.array([3.0, 5.0, 10.0, 20.0, 30.0, 50.0, 70.0, 100.0])\n", + "gaps_sweep = []\n", + "for gamma in gammas_sweep:\n", + " eps_g = 1.0 / float(gamma)\n", + " n_iters_g = max(400, int(5 * gamma))\n", + " res_g = constrained_sinkhorn(\n", + " C_th_j,\n", + " a_th_j,\n", + " b_th_j,\n", + " D_th,\n", + " jnp.zeros((0, n_th, n_th)),\n", + " eps=eps_g,\n", + " n_iters=n_iters_g,\n", + " n_newton=15,\n", + " )\n", + " cost_g = float(jnp.sum(res_g.matrix * C_th_j))\n", + " gaps_sweep.append(abs(cost_g - cost_star_lp))\n", + " print(\n", + " f\" gamma = {gamma:6.1f} (eps = {eps_g:.4g}): cost = {cost_g:.6f}, gap = {gaps_sweep[-1]:.3e}\"\n", + " )\n", + "gaps_sweep = np.array(gaps_sweep)\n", + "\n", + "fig, ax = plt.subplots(figsize=(7.5, 4.4))\n", + "ax.semilogy(\n", + " gammas_sweep, gaps_sweep, \"o-\", linewidth=2, markersize=8, color=\"C0\"\n", + ")\n", + "ax.set_xlabel(r\"$\\gamma = 1/\\varepsilon$\")\n", + "ax.set_ylabel(\n", + " r\"$|\\langle P_\\varepsilon^\\star, C\\rangle - \\langle P^\\star, C\\rangle|$\"\n", + ")\n", + "ax.set_title(r\"Optimality gap decays exponentially in $\\gamma$ (Theorem 1)\")\n", + "ax.grid(True, which=\"both\", alpha=0.3)\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "aba770c7", + "metadata": {}, + "source": [ + "The gap on the optimal cost shrinks by more than two orders of magnitude as $\\gamma = 1/\\varepsilon$ goes from 3 to 100, in line with the exponential bound. This is the constrained version of the well-known phenomenon for unconstrained Sinkhorn (Weed 2018): entropy regularisation acts as a **smoothing** that biases the solution toward the analytic centre of the polytope, and this bias vanishes at an exponential rate as we sharpen $\\gamma$.\n", + "\n", + "Practically, this means that for most applications a moderate $\\gamma$ (say, of order $1/\\sigma_C$ where $\\sigma_C$ is the typical scale of the cost) already gives a $P_\\varepsilon^\\star$ that is barely distinguishable from the LP optimum." + ] + }, + { + "cell_type": "markdown", + "id": "8bbb4a97", + "metadata": {}, + "source": [ + "## 6. Tracing a Pareto front\n", + "\n", + "A particularly nice consequence of having a generic constrained-OT solver is that we can **profile a Pareto front** between two competing transport costs by sweeping the inequality threshold. We solve\n", + "$$\\min_{P \\in U(a,b)} \\langle P, C_1 \\rangle \\quad\\text{s.t.}\\quad \\langle P, C_2 \\rangle \\le t,$$\n", + "for a sequence of thresholds $t \\in [t_{\\min}, t_{\\max}]$, where:\n", + "\n", + "- $t_{\\min}$ is the smallest achievable value of $\\langle P, C_2 \\rangle$ over $U(a,b)$ (i.e. the value attained by the $C_2$-minimising plan),\n", + "- $t_{\\max}$ is $\\langle P_{C_1}^\\star, C_2 \\rangle$, the $C_2$-cost of the $C_1$-minimiser, beyond which the constraint becomes inactive.\n", + "\n", + "Both endpoints correspond to *unconstrained* OT problems (one minimizing C₁ only, the other minimizing C₂ only). Since the inequality constraint is inactive at these points, we can efficiently compute them using vanilla sinkhorn.Sinkhorn rather than our more general (and slightly more expensive) constrained solver." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "44e8c2dd", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-02T06:39:44.582962Z", + "iopub.status.busy": "2026-05-02T06:39:44.582758Z", + "iopub.status.idle": "2026-05-02T06:39:46.651765Z", + "shell.execute_reply": "2026-05-02T06:39:46.650709Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "endpoints: t_min = 0.0164 (D-only minimum)\n", + " t_max = 0.4963 (D when minimising C)\n" + ] + } + ], + "source": [ + "# Endpoints via vanilla OTT-JAX Sinkhorn.\n", + "\n", + "\n", + "def solve_unconstrained(cost_matrix, eps=1e-3):\n", + " \"\"\"Solve unconstrained entropic OT and return the transport matrix.\"\"\"\n", + " geom_u = geometry.Geometry(cost_matrix=jnp.array(cost_matrix), epsilon=eps)\n", + " out_u = sinkhorn.Sinkhorn(threshold=1e-7)(\n", + " linear_problem.LinearProblem(geom_u, a=a_j, b=b_j)\n", + " )\n", + " return np.array(out_u.matrix)\n", + "\n", + "\n", + "P_C = solve_unconstrained(C_cost) # minimises C\n", + "P_D = solve_unconstrained(DI_orig) # minimises D_I\n", + "\n", + "t_min = float((P_D * DI_orig).sum())\n", + "t_max = float((P_C * DI_orig).sum())\n", + "print(f\"endpoints: t_min = {t_min:.4f} (D-only minimum)\")\n", + "print(f\" t_max = {t_max:.4f} (D when minimising C)\")" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "50b0fe07", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-02T06:39:46.654051Z", + "iopub.status.busy": "2026-05-02T06:39:46.653351Z", + "iopub.status.idle": "2026-05-02T06:40:25.055811Z", + "shell.execute_reply": "2026-05-02T06:40:25.054766Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " t = 0.0164 -> = 0.3043, = 0.0164, alpha (phys) = +13.494\n", + " t = 0.0764 -> = 0.0845, = 0.0764, alpha (phys) = +0.874\n", + " t = 0.1363 -> = 0.0486, = 0.1363, alpha (phys) = +0.352\n", + " t = 0.1963 -> = 0.0349, = 0.1963, alpha (phys) = +0.170\n", + " t = 0.2563 -> = 0.0272, = 0.2563, alpha (phys) = +0.105\n", + " t = 0.3163 -> = 0.0219, = 0.3163, alpha (phys) = +0.073\n", + " t = 0.3763 -> = 0.0185, = 0.3763, alpha (phys) = +0.036\n", + " t = 0.4363 -> = 0.0175, = 0.4363, alpha (phys) = +0.011\n", + " t = 0.4963 -> = 0.0172, = 0.4851, alpha (phys) = +0.000\n" + ] + } + ], + "source": [ + "# Sweep t along the Pareto front.\n", + "ts = np.linspace(t_min, t_max, 9)\n", + "sweep = []\n", + "for t in ts:\n", + " D_ineq_t = (t * np.ones((n_a, n_a)) - DI_orig) / n_a\n", + " res_t = constrained_sinkhorn(\n", + " C_j,\n", + " a_j,\n", + " b_j,\n", + " jnp.array(D_ineq_t[None, ...]),\n", + " jnp.zeros((0, n_a, n_a)),\n", + " eps=eps_run,\n", + " n_iters=200,\n", + " n_newton=10,\n", + " )\n", + " P_t_np = np.array(res_t.matrix)\n", + " # alpha_solver lives in the (D_solver = (t1 - D)/n) coordinate system.\n", + " # The Lagrangian dual on the *physical* constraint D.P <= t is alpha/n.\n", + " sweep.append(\n", + " (\n", + " float(t),\n", + " float((P_t_np * C_cost).sum()),\n", + " float((P_t_np * DI_orig).sum()),\n", + " float(res_t.alphas[0]) / n_a,\n", + " )\n", + " )\n", + "\n", + "sweep = np.array(sweep)\n", + "ts_arr, C_costs, D_costs, alphas_phys = sweep.T\n", + "for t, mc, dc, al in sweep:\n", + " print(\n", + " f\" t = {t:.4f} -> = {mc:.4f}, = {dc:.4f}, alpha (phys) = {al:+.3f}\"\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "3cf9de16", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-02T06:40:25.058091Z", + "iopub.status.busy": "2026-05-02T06:40:25.057508Z", + "iopub.status.idle": "2026-05-02T06:40:25.465785Z", + "shell.execute_reply": "2026-05-02T06:40:25.465246Z" + } + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAABOIAAAG/CAYAAAD1rnjIAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAQ6wAAEOsBUJTofAAA9mRJREFUeJzs3XdcVeUfB/DPucBl741s90DEAe699yg3rswyNUeR9ksFtdIiTU0t03Cbq9x7p5mYIrg1EUQFvOy9Lvf8/kBuEqDsy/i8Xy9exZmfgxfuc7/nPM8jiKIogoiIiIiIiIiIiMqVRNUBiIiIiIiIiIiIagIW4oiIiIiIiIiIiCoAC3FEREREREREREQVgIU4IiIiIiIiIiKiCsBCHBERERERERERUQVgIY6IiIiIiIiIiKgCsBBHRERERERERERUAViIIyIiIiIiIiIiqgAsxBEREREREREREVUAFuKIiIiIiIiIiIgqAAtxRFVEYmIiZs2aBWdnZ2hoaEAQBFy4cEHVsYiIiIiI8pkwYQIEQUBoaKiqoxRIlfkcHR3h6OhY4eetCIIgVJlru3DhAgRBwIQJE8r1PIIgoHPnzkXePjQ0tEJykeqwEEeViiAIeb7U1NRgbGyM9u3bY/369cjOzlZ1RKXcP5DF+aNaGp999hlWrVoFBwcHzJ07F97e3pXmTa6838RevHiBL774Au7u7jA1NYWGhgbMzMzQqVMnLFu2DC9fviyX85aF3Ebe669pQ0NDODk5YcCAAfjuu+8QERFR7OM6OjpCEIRySFxzVVRjrDxVpcYvEVFV8+jRI8yaNQuurq4wMjKCVCqFpaUlevbsiTVr1iAxMVHVEQu1efNmCIIAHx8fVUcptqqcvTpiW4Oo9NRVHYCoIN7e3gAAuVyO4OBg7N+/H3/++SfOnDmDvXv3qjidahw5cgR6eno4deoUNDQ0VB2nwmzevBlTp05Feno6GjdujHfffRempqaIi4uDv78/Pv/8c3z55Zd4/PgxrKysVB23UIMGDUKzZs0AAMnJyXjx4gX+/PNPHDlyBAsWLMDixYvh5eWl2pBERERUoK+//hoLFiyAQqFAq1atMHbsWBgaGiIqKgqXLl3CjBkz4OPjg+joaFVHrTSWLl2KefPmoVatWqqOQkRUqbAQR5XSf+943b59Gx4eHti3bx8uXbqEDh06qCaYCoWHh8Pe3r5GFeF27dqFiRMnwtDQELt27cKgQYPybRMYGIjZs2cjPT1dBQmLbvDgwfmetFIoFNi7dy+mTp2Kzz77DKIo4rPPPlNNQCIiIirQN998gy+++AK1atXC7t270a5du3zb/PHHH5gxY4YK0lVe1tbWsLa2VnUMIqJKh11TqUpwcXFRdgH19/cHANy4cQMzZ86Eq6srTExMoKWlhbp16+KTTz5BXFxcvmO8/lj71atX0adPHxgbG0MQBMTHxwPIKYxs3LgR7dq1g6GhIbS0tNCkSRMsXboUmZmZeY7l5OQEALh48WKerof/LSL+/vvv6NKlC4yMjKClpYWGDRti4cKFSE5OLtK1d+7cGYIgQBRFPH36VHme3J/H693pHj58iHfffRfm5uaQSCQIDAwEAGRmZsLX1xfNmjWDjo4O9PX10bp1a/j5+UEUxXznzH3kPDU1FV5eXrC3t4empibq1KmDb775Js8+Pj4+6NKlCwBgy5YteX4WmzdvLtI1FiQ5ORnTp08HAPz6668FFuEAoFmzZjh79uxb77Z++OGHEAQBv/32W4HrHzx4AEEQ8hR5ZTIZPvvsMzRo0AC6urrQ19eHk5MTRo4ciaCgoBJe2b8kEglGjBiBPXv2AMj5Wb6tm2pul+inT58CyNud+/Vu0rldVzMyMuDj44O6detCKpVi1qxZAHIKu4sXL0a7du1gZWUFqVQKGxsbjB49Gvfv3y/0vJ07d0Z0dDSmTJkCa2traGpqonHjxti0aVO+fURRxNatW9GuXTtYWFhAU1MT1tbW6NSpE37++ec82+a+zp88eYLvvvsO9evXh5aWFuzs7PDJJ58gKSmpwJ9HYGAghg8fDktLS0ilUtjZ2eG9995DSEhIvm19fHyUr8tjx46hY8eOMDAwgLGxcZm8jl+8eIFZs2ahXr160NbWhrGxMVq0aIGFCxciKysrz7ZPnjzBxIkTYWtrq+zaNHz4cNy6dSvfcTMzM/HDDz+gRYsWMDU1hZaWFmrVqoVevXrh999/B/Dv3wEAef5O/Ler7aVLlzBw4EDY2dlBKpXCzMwMbm5umD17doF/C4iIarqnT59iwYIF0NDQwLFjxwoswgFAx44dce3atXzLL1y4gH79+sHU1BSamppwdnbGrFmzEBUVlW/b3OEsLly4gH379sHd3R06OjowMTHByJEj8eLFi3z7PHnyBB988AHq1q0LbW1tGBkZoV69epgwYYKyrTBhwgRMnDgRALBo0aI87xG54w0XpZ184MABjB07FvXq1YOuri709PTQokULrF69GgqFotDreX0MtuK2J4qS/W1+++03eHh4QFtbG6amphg2bBj++eefArfNfT8trBtsbnvldZmZmVizZg369u0LBwcHaGpqwsTEBN27d8fx48eLlLEwL168gJqaGlxcXArdZuTIkRAEAWfPnlUuO3z4MLp37w4bGxtoamrC0tISHh4e+Prrr0ucpahtjVxF+RwB5H1NhIeHY9KkSbC2toaamhoOHDig3O78+fMYOHAgzM3NIZVK4eDggI8++giRkZH5zl2U34v/Cg0NxciRI2FmZgYtLS20bNkSR44cKXDb4n62KkxUVBQ+/PBDWFtbQ0tLC40aNcLq1avZJqsB+EQcVRm5f5By3wA2bNiA/fv3o1OnTujevTuys7MREBCAFStW4Pjx4/D394e+vn6+41y5cgVff/01OnbsiMmTJyMiIgJqamqQy+UYOnQoDh8+jHr16mHUqFHQ0tLCxYsX8b///Q9nz57FiRMnoK6ujmbNmmHmzJnKMdtef/N5vRiycOFCLFmyBCYmJhgxYgSMjIxw+vRpLFmyBIcOHcKlS5cKzPi6CRMmoHPnzli0aBEMDQ2VhZT/js3w+PFjeHh4oGHDhvD09ERiYiJ0dHSQlZWFPn364Ny5c6hXrx6mTp2KzMxM/P7773jvvfdw+fJl+Pn55TtvVlYWevbsifDwcPTp0wfq6uo4cOAA5s2bh/T0dGX34c6dOyM0NBRbtmyBq6srBg8erDxGblfM3O0uXryITZs2FWkMrn379iEmJgYeHh7o06fPG7eVSCSQSN58X2HChAlYv349Nm/ejGHDhuVbv2XLFuV2QE7joW3btggODka3bt3Qv39/AMCzZ89w7tw5dO/eHa6urm+9jqLo3r072rdvj8uXL2P//v346KOPCt3WyMgI3t7eWLlyJRISEpT/DkD+1wQADBs2DAEBAejduzeGDBmiLCD/8ccfWLZsGbp06YJhw4ZBT08P//zzD/bt24dDhw7hzz//LPD64uPj0a5dO0ilUrzzzjtIT0/Hvn37MGnSJEgkEowfP1657RdffIGlS5fC0dER77zzDoyMjBAZGYmgoCBs27YNU6ZMyXf8WbNm4fLlyxg+fDgMDQ1x/PhxrFixApcvX8Yff/wBTU1N5bbHjx/HkCFDkJ2djaFDh6J27dq4desW/Pz8sH//fpw7dy7PazDX3r17cfLkSfTr1w9Tp07Fy5cvi/w6Lsz169fRu3dvxMTEoH379hg8eDDS09Px4MEDLF26FHPmzIGRkREAICAgAN26dUN8fDz69euHpk2bIjg4GL///jsOHz6MgwcPomfPnspjT5gwAb/++isaNWqEMWPGQFdXF+Hh4bh27Rp+//13DB06FI6OjvD29s73d+L1/CdOnEC/fv2gr6+PgQMHwtbWFnFxcfjnn3+wZs0a+Pr6Ql2dzQIiotdt2rQJWVlZGDFiBJo2bfrGbV9/jwKAjRs3YsqUKdDW1sa7774La2trXLlyBatWrVIOu2Jra5vvOOvWrcOhQ4cwcOBAdOrUCf7+/ti9ezeCgoIQGBioPE9ERARatWqFxMRE5ft8ZmYmwsLCsH//fowdOxYODg4YPHgw4uPjcfDgQXTq1CnfjbvXFdZOBoB58+ZBIpHAw8MDtWrVQkJCAs6dO4eZM2fi77//xrZt24r8cy1qe6I42Qvyww8/4OOPP4a+vj5Gjx4NCwsLXLhwAR4eHm8sbhVHbGwsZs6cibZt26JHjx4wNzdHREQEDh8+jL59+2LDhg2YPHlyiY5dq1Yt9OjRAydPnsSNGzfQokWLPOsTEhJw8OBB2Nvbo2vXrgCAn3/+GR988AEsLS3Rv39/WFhYIDo6Gvfu3cNPP/2E//3vfyXKUpS2Rq6ifo54XUxMDNq0aQNDQ0O8++67UCgUMDExAZDzVOq8efNgYmKCfv36wcrKCrdu3cKPP/6IQ4cO4erVq8rfpaL+Xrzu6dOncHd3h7OzMzw9PREbG4vdu3dj0KBBOHPmjPJmbe61leSz1X/Fxsaibdu2ePz4Mdzd3TF+/HjExMRgwYIFOH/+fFH/WaiqEokqEQBiQS/LoKAgUUtLSwQgXrp0SRRFUQwNDRXlcnm+bTdu3CgCEJctW5Zn+aZNm5THX79+fb79lixZIgIQp02blue42dnZ4vvvvy8CEFevXq1cHhISIgIQO3XqVOC1/PXXX6IgCGKtWrXEFy9eKJcrFApx3LhxynMVFQDRwcEh3/Lz588rr+vzzz/Pt37ZsmUiALFnz55iRkaGcnl8fLzYuHFjEYC4d+/efOcCIPbp00dMTU1VLn/58qVoaGgoGhoaipmZmfkyjB8/vtD8nTp1EgGImzZtKtL1Tpo0SQQgfvHFF0XavigaNmwoqquri5GRkXmWZ2dni7a2tqKOjo6YmJgoiqIoHjp0SAQgzpw5M99x5HK5GBcXV6Rzjh8/vkjXPX/+fBGAOG7cuCId18HBocDflf+ud3FxEaOiovKtf/nypfJaXxcYGCjq6uqKvXv3zrM89/UOQHzvvffy/I7cvXtXVFNTExs2bJhnHxMTE9HGxkZMTk7Od57/Zsp9fZiamopPnz5VLpfL5eKgQYNEAOLXX3+tXJ6cnCyamZmJgiCI586dy3Os3L8BTZo0ERUKhXK5t7e3CEAUBEE8fvx4vkxFeR0XJCMjQ3R0dBQBiFu2bMm3PiIiQszKyhJFMef3v1GjRiIAcfPmzXm2O336tCgIgmhubi6mpKSIopjzeyoIgtiiRQvlMV73359jYX8nRFEUhw4dKgIQb968mW9ddHR0US6ViKjG6dq1qwhA3LBhQ7H2CwsLE6VSqairqyveuXMnz7rc9/x+/frlWZ7bZtDX1xdv3bqVZ92oUaNEAOLu3buVy1avXi0CEL///vt8509PTxeTkpKU3+e2g729vQvM+7Z2siiK4uPHj/Mty87OVrZrr169WuD1hISEKJeVpD3xtuyFefr0qSiVSkV9fX3x0aNHedZNnz5dmeP1fLltgcLOldteeV16err47NmzfNvmtrWNjY3ztKdFMaedVtj79X/t2rVLBCBOnz4937r169eLAMQFCxYolzVv3lyUSqX52ruimL/dUBJvamvkri/O54jXXxOenp752jsXL14UBUEQW7duna/9vXXrVhGAOHToUOWy4vxevP45ysfHJ8+2J06cUF7H60r62eq/nxk/+OADEYD44Ycf5ln++PFj0dDQsERtUqo62DWVKiUfHx/4+Phg/vz5GD16NNzd3ZGeno533nkH7du3BwA4ODgo79C9btKkSTAwMMDJkycLPHazZs3yPYmjUCiwatUqWFhYYOXKlXmOK5FI4OvrC0EQinWn75dffoEoivjf//4HGxsb5XJBEPDtt99CW1sbmzdvztdlraQsLS0LvLu0ceNGAMCKFSsglUqVyw0NDbF06VIAOU8XFmT16tXQ1tZWfm9hYYFBgwYhISEBDx8+LFa+rVu34v79+xgyZEiRts/tolnQneKSGj9+PORyOXbs2JFn+ZkzZ/D8+XMMGzYs3xOKr19/LjU1NeXTTWUlt2ttQV1VSmPJkiUwMzPLt9zCwqLApzFdXV3RtWtXnD9/vsDXpo6ODlasWJHnd6RRo0Zo164d7t+/n68LqVQqLfApq4IyAcDMmTNhb2+v/F5NTQ3ffPMNBEHIc3fxwIEDiI6OxjvvvJPnLiUAvPfee2jRogXu3LmDq1ev5jvHoEGD0Lt37wLPXxKHDx9GaGgo+vbti3HjxuVbb2VlpfwZXLlyBffu3UOrVq3yPD0I5DwZOWTIEERFReHgwYMAcv7+iKIIqVRa4N+7wn6Ob1LQa9rU1LTYxyEiqglK2h7ZunUrMjMz8dFHH6Fx48Z51s2fPx82NjY4evQowsPD8+378ccf53ta6/333weAAru/FvR3XVNTE3p6esXKDBTcTs5Vu3btfMskEglmzpwJAIW2vQtS3PZESWzfvh2ZmZmYNm0a6tatm2fdkiVLYGBgUOpzADk/64JeH4aGhpg0aRLi4uLw999/l/j4gwcPhpGREX799dc8Q+UAUA6f8d82hbq6eoHjSpek3VBSxf0cIZVK8d133+VrN65atQqiKGL9+vX52t+enp5wc3PDwYMH871mivN74eDggPnz5+dZ1qtXL9jb2+f7nSvNZ6tcWVlZ2L59O3R1dfHll1/mWVe7dm3l7xRVXyzEUaW0aNEiLFq0CF9//TWOHTuGli1b4scff8SuXbuU22RlZWHNmjVo3749jI2NoaamBkEQIJFIkJiYWOA4GgDg7u6eb9mjR48QHR0NfX19fPnll8pCYO7X999/D21t7QLHzipMQEAAACgfE3+dpaUlXFxckJKSgkePHhX5mG/i6uqar0tEUlISHj9+DEtLy3yNQADo1q1bnqyvMzQ0RJ06dfItt7OzA4ACx+F7E3t7ezRo0ACGhoZF2l78T1fksuDp6Qk1NbV8Y379t1sqAHTq1Al2dnb45ptv0KNHD6xcuRLXrl2DXC4vszyvK4/rBQp+vec6evQo+vfvD2tra2hoaCjH+Th8+DAyMjIKnPmtbt26BTZcc18XuePIADk/79DQUDRs2BBz587FkSNHEBsb+8a8nTp1yresfv36sLS0xOPHj5WNrDf9fgFvfm2/6WdSErnFvrd1oX49T2G5u3fvnmc7fX19DBo0CH/99ReaNm0Kb29vnD59ushjTL7O09MTAODh4YEpU6bg119/zTNuDxER5VfS9+c3/b3X1NRU3li+efNmvvUtW7bMt6yg9tegQYOgr6+P6dOnY8iQIfjpp58QFBRU4HhtRfWm98iYmBjMmzcPTZs2hZ6enrLdkNtdsrC2d0GK054oqRs3bgDIO2xMLiMjoyINPVFUd+/exYQJE+Ds7AwtLS3lz+aTTz4BULyfzX9pampi5MiRiImJyTNm2T///IO//voLHTp0yFMk9fT0RGpqKho1aoSZM2fit99+e+sYxGWtJJ8jHB0dYWFhkW/5n3/+CXV1dfz+++/5PqP5+PggIyMD2dnZys9UJfm9aNasWYE3PO3s7PJkLc1nq9c9ePAAKSkpcHFxKfBmaEGvWapeOBgMVUq5jZ43GTFiBPbv3w9nZ2cMHjwYVlZWykLUypUrkZGRUeB+VlZW+ZbFxMQAAIKDg7Fo0aJSJP9XQkJCoecDoJxFqiwaGoWd520ZdHR0YGhoWGCGwp74yr1LlZ2dXbKgRZT7FOGzZ8/K9Jg9e/bE8ePHERAQgObNmyMxMRH79++Hg4NDnierDAwMcPXqVSxatAgHDx7EmTNnAOT8XCZNmoQlS5ZAR0enzLLl3hEvqAFSGoX9269evRozZ86EsbExevToAXt7e+jo6EAQBBw4cABBQUEF/g4V53WxfPly1KlTB35+fvD19cW3334LiUSCbt26wdfXt8Ax6CwtLQs8vqWlJSIjI5GYmAh9ff1S/X4Vtk9J5Z7jbROGACX7u7Br1y5899132LFjBxYvXgwA0NDQwIABA7B8+fIijZED5NxRP3HiBL777jts2bJFebe2SZMm8PHxKXDsRCKims7GxgYPHjwodnukNO9TBb3XFvQ+a29vj7///huLFi3C8ePHlQPbW1paYsaMGZg3b16BxYU3KSxvfHw8WrVqhZCQELi7u2PcuHEwMTGBuro64uPjsWrVqkLb3gWpiHZmUf8NSuvq1avo2rUr5HI5unXrhoEDB8LAwEA5cdrBgweL9bMpyMSJE/HTTz9h8+bNGDp0KIB/byTnTmaRa9asWbCwsMC6deuwdu1arF69GgDQunVrLF26tEKKPCX59y3s3ykmJgZyufytn9Fyb1KW5PfiTXlfL+CV5rPV6yrqtUmVF5+Ioyrp+vXr2L9/P7p164YHDx5g06ZNWLp0KXx8fLBw4cJ8j22/rqA7mrlPaQ0YMACiKL7xq6hyj1nQTD7Av10divqE2Nu86boKy5CamoqEhIQyy1CWcmcvfX0GqLKQ+9Rb7lNxe/bsQVpaGsaPH5/vZ2hjY4P169cjIiIC9+7dw48//ggHBwesWLECM2bMKNNc586dA5DztFJZKuh1IZfL4e3tDSsrK9y9exe7d++Gr68vFi1aBB8fn0KLYcWlpqaG6dOnIyAgAFFRUcrZ1s6cOYMePXooC+Cve/nyZYHHyl2ee/e8NL9fZf3UYW7jrSh3u0uSW0tLC/Pnz8f9+/fx4sUL7Nq1C71798bvv/+O3r17F6t7e69evXD69GnEx8fjwoULmDt3Lp4+fYp3330XFy9eLPJxiIhqipK2RyqqHVi/fn3s3LkT0dHRuHnzJpYvXw5dXV3Mnz+/RDNkFvYeuXHjRoSEhGDhwoXw9/fHunXrlL1IRowYUaprKC9F/Td4Xe7kX4X1gCiowPLll18iLS0NJ0+exPHjx7Fy5UosXrwYPj4+Zdauc3d3R6NGjXD8+HHIZDKIooht27ZBV1cX7777br7tR48ejcuXLyM2NhYnTpzAtGnTEBAQgD59+pRZb5yyVthrz9DQEPr6+m/9jPZ6r4qy/r14PQtQ+s9WJXltUvXCQhxVSY8fPwaQ8+jxf8c/uHbtGtLS0op1vAYNGsDIyAj+/v5vLOK9LvdOSmF37Jo3bw4ABc56ExUVhTt37kBXVxf169cvVtbi0NfXR506dfDy5Uvcu3cv3/rc4s9/Z2Aqrrf9LEpi2LBhMDU1xdWrV9865ohCoShyMWLQoEEwNjbGr7/+iqysLGzZsgWCIOQbW+N1giCgYcOG+PDDD3Hp0iVoampi//79xbqeNzlz5gyuXLkCHR2dIo+hV5qfeXR0NOLj49G2bdt8d9ySk5Pf+jh9SZiammLQoEHYsmULRo4ciaioKPz555/5tiuoGPTw4UO8fPkSderUUY5r96bfL6Bkr+2S/kxbt24NIGcW17d5W+7cD3qF5baxscGIESNw6NAhtGnTBg8fPszTZV4ikRQpv7a2Njp16oRly5bhu+++gyiKynHpiIjoXxMnToSGhgZ+++033Llz543bvv7U05v+3mdkZCjfA3O3Ky01NTU0a9YMc+bMUb4fvd5WKW1bLbft/c477+RbV943ckqaPfe99MKFC/nWxcfHIzAwMN9yY2NjAAX3yEhISCiwiPX48WOYmJjkG7MWKNufzetjHZ87dw5hYWEYNmzYG8cCNDAwQK9evbBmzRp88sknSE9Px4kTJ0qVo6htjbLSpk0bJCUlISgoqNj7vu33orjK6rNVgwYNoKOjg9u3bxc4bEtBr1mqXliIoyoptyvWf/9IyWQyTJs2rdjHU1dXx8yZM5X7p6am5tsmOjo6zxu2sbExBEEotKvCpEmTAABff/11nrsdoijis88+Q2pqKsaPH1/gQKpl6b333gMAfPLJJ3mKVYmJicrpy0s6pXqu3LENwsLCCt0mLCwMDx48UD6K/Tb6+vpYs2YNAGDUqFF5xsR43Z07d9C9e/cij72hqamJUaNGITo6GqtWrcLly5fRsWNHODs75ztuSEhIvv1jYmKQlZVV4ACwxSWKIvbu3Yvhw4cDyBk4uKhPoxXlZ14YCwsL6Ojo4MaNG3nGGsvKysLMmTMLHBuuuDIyMnDmzJl843GIogiZTAag4EF0V61aleeasrOzMXfuXIiimKfrxeDBg2Fqaop9+/bhjz/+yHOMzZs34/r162jcuLGySFYUJf2ZDhgwAI6Ojjh27FiBE7q8fPlSeWe9bdu2aNiwIa5du4bt27fn2e7cuXP4/fffYWZmhkGDBgHIKdoXNOFERkaG8q786z9HU1NTREVFFXgz4syZMwX+bcv9+1QWr2kiourGwcEBS5YsQVZWFvr27Vvg32QgZxyr199zxo4dC6lUinXr1uHBgwd5tl26dClevHiBvn375pnQq7iuXbtW4BM1Bf1dL027ASi87X3z5k3lAPXlpaTZx4wZA6lUirVr1+Kff/7Js27BggVITEzMt0/ueMYHDhzI87OVy+WYNWtWge+vjo6OiI2Nxa1bt/Is/+WXX4o1gcXbvD7WcWHdUoGcG4MF3aAu6HWRkJCABw8eFOtn+6a2RnmYM2cOAGDKlCl4/vx5vvXp6em4fPmy8vvi/F6URFl8ttLQ0ICnpydSUlLwxRdf5FkXHByMVatWlSojVX4cI46qpFatWqFdu3b4/fff0bZtW7Rv3x4vX77E8ePHUb9+/RI1aubPn4/bt29j48aNOHr0KLp16wZbW1tERUUhODgYly9fxrRp07By5UoAgJ6eHtq0aYMrV65gwIABaNGiBdTV1dGxY0d07NgRbdq0weeff46lS5eiSZMmePfdd2FoaIjTp08jICAALi4upXo0uqjmzJmDEydO4MSJE3BxcUH//v2RlZWF3377DS9evMC4ceMKfKS9OOrXrw87OztcunQJY8aMQb169aCmpoaBAweiadOmAIBx48bh4sWL2LRpU55JEd5k5MiRSEtLw0cffYQBAwagSZMmaN++PUxNTREfH49r167h+vXr0NXVLdab6oQJE7Bu3Trlm2VBec6cOYM5c+agdevWaNiwoXKMsoMHD0KhUODzzz8v8vmAnFk+cwfGT01NxYsXL3D58mWEhYVBS0sLy5cvVzY0iqJHjx74+++/MXToUPTt2xfa2tpwcHBQDsj/JhKJBB9//DGWLVsGFxcXDBo0CJmZmTh//jxiY2PRpUuXQp/YKqq0tDTl2HOtW7eGg4MDsrKycOHCBQQGBqJ169YF3jlu27YtmjVrhuHDh8PQ0BDHjx/H7du30apVK+WAxwCgq6uLzZs3Y9iwYejevTuGDRsGZ2dn3Lp1C0ePHoWRkRG2bt1arG6oRXkdF0QqlWLv3r3o1asXxo0bhw0bNqBNmzbIzMzEw4cPcebMGchkMhgZGUEQBGzZsgXdu3fHuHHjsGfPHri4uCA4OBi//fYbpFIptm7dqhx/8MWLF2jTpg3q16+PFi1awM7ODikpKTh58iT++ecfDBs2LM9McD169MDOnTvRu3dvdOzYEZqamnB1dcWAAQPw6aefIiQkBJ07d4ajoyO0tLRw69YtnDx5EqampoXOkkdEVNPNnTsXcrkcCxcuRJs2beDu7g53d3cYGhoiOjoaf/75J+7cuZNnRkoHBwesXr0aU6dORcuWLTF8+HBYWVnhypUruHjxImxtbfHjjz+WKtfOnTuxdu1adOjQAXXr1oWpqSmePn2KgwcPQk1NDZ999ply27Zt20JHRwe7du2CVCqFvb09BEGAp6cnHBwc3nqucePGwdfXF7NmzcL58+dRt25d/PPPPzhy5AiGDh2K3bt3l+pa3qSk2R0cHODr64uZM2eiRYsWGD58OMzNzXHx4kU8ePAAHTt2zHczT0NDA7Nnz4aPjw/c3NyU47GdP38eoijC1dU135NZs2bNwsmTJ9G+fXtl++X69eu4fPky3nnnHezbt69Mfg7W1tbo1asXjh07hvv378PR0bHASa5GjRoFqVSKDh06wNHREYIg4Nq1a7h06RJq166tvAEM5DwdNnHiRHTq1KnIT2G9qa1RHrp06QJfX1/MnTsXdevWRd++feHs7Iy0tDSEhYXhjz/+gKOjo/KBieL8XpREWX22+vrrr3H27Fn89NNPuHnzJrp06YLo6Gjs2bMHXbt2VY5tR9WUSFSJABCL+rKMiYkRp06dKjo4OIiampqis7Oz+Pnnn4spKSmig4OD6ODgkGf7TZs2iQBEb2/vQo+pUCjEnTt3ij169BBNTExEDQ0N0crKSvTw8BAXLFggPnr0KM/2wcHB4uDBg0VTU1NRIpEUePw9e/aIHTt2FPX19UWpVCrWr19f/OKLL8TExMQiXWcuAPmuSRRF8fz58yIAcfz48YXum56eLi5btkx0cXERtbS0RB0dHdHd3V3csGGDqFAoinwuURRFb29vEYB4/vz5PMtv3Lghdu/eXTQ0NBQFQRABiJs2bVKu79SpU75lRfX8+XPxf//7n9iyZUvR2NhYVFNTE42NjcV27dqJX331lSiTyYp9zMaNG4sARF1dXTEpKSnf+nv37omzZ88WW7ZsKZqbm4tSqVS0s7MT+/fvL546darI5xk/frzydQ1AFARB1NfXFx0dHcV+/fqJvr6+YkRERLHzp6SkiNOnTxft7OxEdXV1EYDYqVMn5XoHB4c3/i5lZWWJy5cvFxs2bChqaWmJlpaW4tixY8XQ0FBl5pCQEOX2ISEh+c5R0HXm7pOZmSl+++23Yp8+fUR7e3tRS0tLNDExEVu0aCEuX75cTE5OzrN/7usjODhY/Pbbb8V69eqJUqlUrFWrljh79uxCf19u3LghDhs2TDQ3NxfV1dVFGxsbccKECWJwcHC+bXNfu296Db7tdfwmYWFh4rRp00QnJydRKpWKxsbGYosWLURvb28xMzMzz7b//POPOH78eNHGxkbU0NAQzc3NxXfeeUe8efNmnu3i4uLExYsXi126dBFr1aolSqVS0cLCQmzbtq24YcMGMSsrK8/2MplM9PT0FK2srJR/k3L/NuzevVscNWqUWLduXVFPT0/U09MTGzRoIM6ePVsMCwsr0jUSEdVkDx8+FGfOnCm6uLiIBgYGorq6umhubi527dpVXL16dYHvVWfPnhV79+4tGhsbixoaGqKjo6M4Y8YMMTIyMt+2ue+l/21jieK/78Ovt/euXr0qTp06VXR1dRVNTExETU1N0dHRURwxYoTo7++f7xinTp0S27VrJ+rp6SnbJbnnKko7+e7du+KAAQNEc3NzUUdHR2zevLm4YcOGArO9fj2laU8UJfvb7Nu3T2zVqpWopaUlGhsbi0OHDhUfPXpU6LkUCoXo6+sr1qlTR/lZ4MMPPxRjYmKU7ZX/Onz4sOjh4SHq6emJhoaGYo8ePcSLFy8qf67/bUsU9FmlKPbu3au8/sL+rX788UdxyJAhorOzs6ijoyMaGhqKLi4uore3txgdHZ1n29x8hf17FORNbQ1RLP7niLe9JnL99ddf4siRI8VatWqJGhoaoomJidikSRNx6tSp4sWLF5XbFef34m2fowr79y7JZ6uCrk8mk4lTpkwRLS0tRU1NTbFhw4biypUrxSdPnrz18x1VbYIoFmP0eSIionLQuXNnXLx4ESEhIUWeBZSIiIiIiKiq4RhxREREREREREREFYCFOCIiIiIiIiIiogrAQhwREREREREREVEF4BhxREREREREREREFYBPxBEREREREREREVUAFuKIiIiIiIiIiIgqAAtx5UQul+P58+eQy+WqjkJERERElQTbiERERDUbC3HlJDIyEnZ2doiMjFR1FCIiIiKqJNhGJCIiqtlYiCMiIiIiIiIiIqoALMQRERERERERERFVABbiiIiIiIjKma+vLywsLODq6qrqKERERKRCLMQREREREZUzLy8vyGQyBAUFqToKERERqRALcURERERERERERBWAhTgiIiIiIiIiIqIKwEIcERERERERERFRBWAhjoiIiIiIiIiIqAKwEEdERERERERERFQB1FUdgIpOFEUEPovHibuRSEyTw0BbHb0bW6GZnREEQVB1PCIiIiIiIiIiegMW4qqIh5FJmLMnEHfDEwEAEgFQiMD6i0/Q2MYAK4Y3Q30rfRWnJCIiIqKKpkhNRcLhI0g4fAjZMbFQMzWB4YCBMBzQHxIdHVXHIyIiotewEFcFPIxMwrAfryA1U65cphD/XX8/IhHDfryC36a2ZTGOiIiIqAbJeBKCsEmTII+MBAQBEEUgNBRp128g+scfYe/nB01nJ1XHJCIiolc4RlwlJ4oi5uwJRGqmPE/x7XUKEUjNlGPOnkCIYiEbEREREVG1okhNzSnCyWQ5C3Lbga/+K5fJEDZpEhSpqSpKSERERP/FQlwlF/gsHnfDEwstwuVSiMDd8EQEPU+omGBEREREpFIJh4/kPAmnUBS8gUIBeWQkEo4cqdhgREREVKgqUYjLysrCwoULYW9vDy0tLTRt2hQ7d+58636hoaEYPXo06tatCz09PRgZGcHd3R1bt24t8MmxpKQkzJgxA1ZWVtDW1kbr1q1x6tSp8rikIjtxN7J4298p3vZEREREVDUlHD6U0x31TQQBCYcPV0wgIiIieqsqMUbclClTsHXrVkybNg0uLi44cOAAxowZA7lcjnHjxhW6X0REBF6+fImRI0fCzs4OmZmZOH36NMaPH4979+5h2bJlym1FUcTgwYNx5coVzJkzB/b29tiyZQv69u2LU6dOoWvXrhVxqfkkpsmVEzO8jSBm42V0DJ49e4akpCQkJSUhOTkZDRs2hI2NTfmHJSIiIqIKkx0T+2931MKIIrKjYyomEBEREb2VIFbyQcVu3ryJ5s2bY/HixViwYAGAnKJZ165dce/ePTx79gxSqbRYxxwwYADOnj2LhIQEaGhoAAD279+PoUOHYuvWrfD09AQAZGRkoGnTptDW1kZgYGCxzvH8+XPY2dnh2bNnsLW1Lda+r1t6/D7WX3zy1u1iTq5DdkosmtZ1RI/mdaCmpqZc16tXL7Rp06bEGYiIiIiobJRVGxEAQseORdqNgDcX4wQB2i1bwHHbtlKdi4iIiMpGpe+aumfPHkgkEkybNk25TBAETJ8+HTKZDOfPny/2MR0cHJCWloaMjIw85zExMcHo0aOVyzQ1NTFlyhQEBQXh4cOHpbuQEurd2KpI26mb2EDMSodmRhxSUlLg6OiI4cOHY/78+SzCEREREZWBNWvWwM3NDerq6vDx8VF1HBgOGFikJ+IMBwyomEBERET0VpW+EHfjxg3Url0bJiYmeZZ7eHgAAAICAt56jNTUVERHRyMkJAQbN26En58f2rRpAz09vTznadGiRZ4nyYp7nvLQzM4IjW0MIHnL8B9GrQaiQccBaNqwLiwtLfHs2TPs2bMH69evx59//onk5OSKCUxERERUTdWqVQuLFy/GkCFDVB0FAGA4oD/UrawASSFNekGAupUVDPv3r9hgREREVKhKX4gLDw+HtbV1vuW5Y56Fh4e/9Rhff/01zM3N4ezsjPfffx/t2rXD7t27y/Q8iYmJeP78ufIrIiLirbmKQhAErBjeDDpS9UKLcRIB0NWUYsfXs9GiRQukp6ejX79+6Nu3L9TV1XH69GmsWLECO3fuxP3795GdnV0m2YiIiIhqkiFDhmDAgAEwNDRUdRQAgERHB/Z+flC3sMhZ8J+JGwRNTdj7+UGio6OCdERERFSQSl+IS0tLg6amZr7lEokEGhoaSEtLe+sxJk6ciNOnT2PHjh149913kZWVhdTU1CKdR0tLS7n+TVasWAE7Ozvll7u7+1tzFVV9K338NrUtGlobKJe93sxqaG2A3z9qi0a1jDFq1ChYW1vj6NGjsLCwwAcffICpU6fC3d0dz58/x+7du7F8+XKcOHECkZGcYZWIiIiqp+TkZHh7e6Nv374wNzeHIAh5Jup6XVZWFhYuXAh7e3toaWmhadOm2LlzZwUnLhlNZyfUPnYUVosXQbtlC0idnKBhZwcAENPTkfXihYoTEhER0esq/aypWlpaecZyy6VQKJCVlaUslL1J7dq1Ubt2bQDA6NGjMWHCBPTo0QOPHj1S7l/YedLT05Xr32TOnDmYPHmy8vuIiIgyL8YdmdEeQc8TcOJOJB5EJOLCoygAwJaJrWCmn5NPU1MTY8aMwZYtW/DgwQM4OjrC0tISvXv3Vl5zYGAgrl27hqtXr8LKygpubm5wcXGBDu+WEhERUTURHR2NxYsXw9bWFm5ubjh9+nSh206ZMgVbt27FtGnT4OLiggMHDmDMmDGQy+UYN25cBaYuGYmODoyHD4fx8OEAAFEuR+iIkUi/exeRPj5wPnyIT8URERFVEpW+EGdjY4OnT5/mW57bVTS362hxDB8+HFu2bMHFixfRq1cv5XEK6k5a1PMYGBjAwMDgjduUliAIaGZnhGZ2RkhKz4LrolNQiMD1p3Ho3eTfbrW6urqYOnUq/jshrpqaGho2bIiGDRsiOTkZt27dws2bN3H8+HGcOnUK9evXR7NmzVCnTh1IChtrhIiIiKgKsLa2xosXL2BjY4PQ0FA4OTkVuN3NmzexefNmLF68GAsWLAAATJ48GV27doWXlxdGjhwJqVRakdFLTVBXh/WXSxDyzrvIevECUT+sgeXcz1Qdi4iIiFAFuqY2b94cwcHBiI2NzbPc399fub64cruZJiQk5DlPQEBAvvHTSnOe8qSvpYEmtXLGJ7n6JDbfekEQ3lhM09PTQ9u2bfHRRx/h/fffh5ubG548eYKdO3fi+++/x+nTpxEdHV1u+YmIiIjKk6amZpFu2O7ZswcSiQTTpk1TLhMEAdOnT4dMJsP58+fLM2a50WrYEKYTJwAAYrdsQdqdu6oNRERERACqQCFu+PDhUCgUWLdunXKZKIpYs2YNzM3N0aVLFwA53Q8ePHiQZ+w3mUyW73iiKMLPzw+CIOQprg0fPhwxMTH49ddflcsyMjLw888/w8XFBQ0aNCiPyysVD6ecmWSvheQvxBWVIAioVasW+vfvj08++QTDhg2DhYUFrly5gjVr1mDjxo24ceOGsosuERERUXVy48YN1K5dGyYmJnmWe3h4AAACAgKUy+RyOdLT05GdnZ3n/9+kvCb0KgqzadNyxotTKBCxcAFEubzCzk1EREQFq/RdU1u0aAFPT094e3sjKipKOW7HhQsX4Ofnp5xgYc2aNVi0aBHOnz+Pzp07AwDmzp2Lf/75B926dYOdnR2io6Oxd+9eBAQEYMaMGahTp47yPEOHDkWnTp3w/vvv48GDB7C3t8eWLVsQHByMEydOqOLS38rdyRQbLoXgfmQiElKzYKijUarjaWhowMXFBS4uLkhISEBQUBACAwNx+PBhnDhxAg0bNkSzZs3g5OQEQShkClciIiKiKiQ8PBzW1tb5luc+TZc7TAkAfPnll1i0aJHy+6+++gqbNm3ChAkTCj3+ihUr8uxTkSTa2rDy8caz9yYj4959xG7ZCtP3JqkkCxEREeWo9IU4ANi4cSMcHBywefNm/PTTT6hXrx62bduGsWPHvnG/IUOG4KeffsKGDRsQHR0NbW1tNG3aFJs2bcL48ePzbCsIAg4dOoT//e9/2LBhAxITE+Hi4oIjR46ge/fu5Xl5JebuaAJBAEQR+Ds0Ft0bWZbZsQ0NDdGxY0d06NABYWFhCAwMxN27d3Hr1i0YGRnB1dUVzZo1g7GxcZmdk4iIiKiipaWlKW/svk4ikUBDQ0M5pAkA+Pj4wMfHp1jHL+8Jvd5Gr107GA4aiISDhxD1ww/Q79UTUlvbCjs/ERER5SWI/x3Rn8rE8+fPYWdnh2fPnsG2HBs7vVf+gQeRSZjS0Rn/69uw3M4DAJmZmbh37x5u3rypnEDD0dERzZo1Q6NGjarcQMZERERUM+RO1rB06VLMmzcvz7rGjRvDzMwMFy9ezLNcoVBATU0N06ZNw5o1a8osS0W1EV8nj4vDkz59kR0fD9127WC3cQN7NxAREalIpR8jjt6stbMpAMD/SUy5n0sqlaJZs2aYOHEiPv74Y3Ts2BFxcXE4cOAAvvvuOxw8eBBhYWH5ZmslIiIiqqxsbGwKHLctt0tqUSZ8KApfX19YWFjA1dW1TI5XHOrGxrD8PKcAmfLnn0g8cqTCMxAREVEOFuKqOPdXEzbcCU9EckbFDcBrYmKCrl27YubMmfD09ET9+vVx+/Zt+Pn5Yc2aNbh06RISExMrLA8RERFRSTRv3hzBwcGIjc07+ZW/v79yfVnw8vKCTCZDUFBQmRyvuAwGDoRu27YAgJdfL4U8Lk4lOYiIiGo6FuKquNxCXLZCxI2nFd+gkkgkqF27NoYNG4ZPP/0U/fv3h5aWFs6ePYvvv/8e27dvx507dyDnLF1ERERUCQ0fPhwKhQLr1q1TLhNFEWvWrIG5uTm6dOmiwnRlRxAEWC3ygaClhey4OMiWfaPqSERERDVSlZisgQpnpqeJ2ua6CI5KwbWQGHSqZ66yLFpaWmjZsiVatmwJmUyGwMBA3Lp1C/v27YO2tjaaNGkCNzc3WFtbc1wSIiIiKndr1qxBfHw84uPjAQDnz59X3hycMWMGDA0N0aJFC3h6esLb2xtRUVFwcXHBgQMHcOHCBfj5+RU4kUNVJbWzg/mM6ZD5foeEgwdhOOjfp+SIiIioYnCyhnJSkQPx/m//bez0D0NLB2Psm1q5GlPZ2dl4/PgxAgMD8fDhQygUClhYWMDNzQ1NmzaFrq6uqiMSERFRNeXo6KicYOq/QkJC4OjoCCBnQqolS5Zg8+bNkMlkqFevHubOnYuxY8eWWRZfX1/4+voiOzsbsbGxFTpZw+tEuRwh7w5Hxv370LCzg/Ohg5Boa1d4DiIiopqKhbhyUpGFuIOBLzBzVyA01ATc9ukFLQ21cj1fSaWkpOD27du4efMmXr58CYlEgnr16qFZs2aoW7cu1NQqZ24iIiKisqKKWVP/K+3OXYQOHw4oFDB9fzIsPvlEJTmIiIhqInZNrQY8nHJmTs3KFhEQFoe2tc1UnKhgurq6aN26NTw8PBAZGYmbN2/i9u3bePDgAXR1ddG0aVO4ubnBwsJC1VGJiIiIqi3tJo1hMm4cYjdvRozfJhj07Quthg1VHYuIiKhGYCGuGrAy1IKDqQ6exqTC/0lspS3E5RIEAdbW1rC2tkbPnj3x8OFDBAYG4urVq/jrr79gY2MDNzc3NGnSBNrsKkFERERU5sw/noGkU6eQFR6OiAUL4bh7FwT2TiAiIip3LMRVE+6OJngak4prIbGqjlIs6urqaNy4MRo3bozExETcunULN2/exNGjR3Hy5Ek0aNAAzZo1g7OzMyQSTvJLREREVBYkOjqwWuSDZ+9PQfqdO4jbvh0m48erOhYREVG1x0JcNeHhbIq9N54jICwOGfJsaKpXvTuaBgYGaN++Pdq1a4fnz5/j5s2buHv3Lu7cuQMDAwO4urqiWbNmMDU1VXVUIiIiomJ5fbKGykKvQwcY9O+PxCNHIFu1Gvrdu0OjVi1VxyIiIqrWOFlDOanogXifxaaiw7fnAQD7PmyDlo4m5X7OipCZmYn79+8jMDAQISEhAAB7e3u4ubmhUaNG0NTUVHFCIiIioqKrDJM1vE4eE4MnffshOyEBuh07wG79egiCoOpYRERE1Rb7+lUTtsbasDHUAgD4V7HuqW8ilUrh6uqK8ePHY+bMmejcuTMSExNx8OBBLF++HAcOHEBoaChYTyYiIiIqPnVTU1jMnQsASPnjEhKPHVNxIiIiouqNXVOrCUEQ4OFsiv03X8A/JBbTuqg6UdkzNjZG586d0alTJ4SGhiIwMBB3795FYGAgjI2N0axZMzRr1gyGhoaqjkpERERUZRgOGYyEQ4eQevUqXn69FHrt2kHNyEjVsYiIiKolPhFXjbg75XRHvREaC3m2QsVpyo8gCHBycsKQIUPw6aefYuDAgdDT08P58+fx448/IisrS9URiYiIiKoMQRBgvcgHgqYmsmNi8NLXV9WRiIiIqi0W4qoRj1eFuJTMbNwJT1RxmoqhqamJ5s2b47333sP06dMxbNgwaGhoqDoWERERUR6+vr6wsLCAq6urqqMUSOrgALNp0wAACb/9jpSr/ipOREREVD2xEFeNOJnpwkwvZ/KCayExKk5T8czMzFC3bl1VxyAiIiLKx8vLCzKZDEFBQaqOUijTiROgWb8+ACDS2xuK9HQVJyIiIqp+WIirRnLGict5Ks7/SfWZsIGIiIiIyp+goQHrJYsBQUDm06eI/vEnVUciIiKqdliIq2Zyu6deC41FtoIziRIRERFR0Wk3bQrjsWMBADG//IL0h49UnIiIiKh6YSGumvFwMgUAJKXL8SCyZowTR0RERJVPcnIyUlJSVB2DSsB85kyoW1sDcjkiFi6AmJ2t6khERETVBgtx1UxdCz0Y6eRMVnAthN1TiYiIqGKcO3cOM2bMgJubG7S1tWFoaAgDAwNoa2ujefPmmD59Os6ePavqmFQEanq6sFq4AACQHnQLcb/uUnEiIiKi6oOFuGpGIhHg7shx4oiIiKj8yeVyrF27Fk5OTujevTu2b98OMzMzjBs3Dp999hm8vLwwbtw4mJiYYMeOHejRowccHR2xdu1ayOVyVcevUJV91tT/0u/SBfp9egMAolasQFZEhIoTERERVQ+CKIocSKwcPH/+HHZ2dnj27BlsbW0r9Ny/XA7BkiP3YKIrxY353SEIQoWev6rp3LkzWrZsie+++07VUfKpzNmIiIicnJyQkpICT09PjBgxAu7u7m/c/urVq9i7dy+2bdsGPT09PHnypIKSVh6qbCMWlzwqCsH9+kORmAi9Ll1gu24t25VERESlxCfiqqHcCRtiUzLxWJas4jRF07ZtWwiCAEEQoKGhgVq1amHEiBG4efOmqqNVmM6dO+PTTz/Ns+z333+Ht7e3ihIRERG92axZs/D06VMsX778rUU4AGjdujWWL1+O0NBQzJo1q/wDUqmom5vDwiunbZJ8/jySTp5ScSIiIqKqj4W4aqihtQH0NdUBAFerwDhxCoUCt27dwsqVKxEREYGHDx9i8+bNSEpKQps2bfDXX3+pOqLKmJiYQF9fX9UxiIiICjRz5kxoa2srv8/MzCzSfjo6Ovj444/LKxaVIaN33oFOq1YAgMivvkR2QoKKExEREVVtLMRVQ2oSAS0djQFUjQkbHjx4gJSUFHTs2BFWVlZwdnZGjx49cOjQIdSrVw+LFi0qdN+0tDRMnz4d5ubm0NLSQpcuXXD79m3l+s6dO2P27NmYPXs2jIyMYGtri7Vr1xZ6vB07dsDCwgJZWVl5lnfr1g0zZswoVY6ZM2fio48+gqGhISwsLLB06VLl+gkTJuDixYtYvny58snA0NDQfE/Jde7cGbNmzcL06dNhaGgIa2trbN++HXFxcXj33Xehp6eHxo0bw9/fX7mPo6Mj1qxZkyevmZkZNm/eXKrjEhER/Zeenh7279+v6hhUhgRBgNWiRRCkUmRHRUO2fIWqIxEREVVpLMRVUx7OpgAA/ycxqOzDAAYEBEAqlaJx48Z5lqurq6NLly4ICgoqdN/PPvsMBw8exI4dO3D9+nVYWFigd+/eSE1NVW7j5+cHS0tL/P3335g9ezZmzJiB+/fvF3i8YcOGISsrC0ePHlUuCwsLw4ULFzBx4sRS59DT08O1a9fwzTffYNGiRdi7dy8AYNWqVWjTpg2mTp2KiIgIREREwM7OrsBz+fn5wdbWFtevX8fEiRMxefJkjBkzBkOGDMHNmzfRqFEjjBs3rtj/7uV1XCIiqjnkcjnS0tIKXX/t2rU8N6KoatB0doLZ1A8BAPF79iD1779VnIiIiKjqYiGumnJ/NU6cLCkDT2NS37K1agUEBKBx48aQSqX51kml0gKXA0BycjLWr1+P5cuXo2fPnmjSpAk2bdqEjIwM7NixQ7ldixYtMG/ePNStWxeffPIJrKyscPHixQKPqaWlhdGjR2PTpk3KZVu2bEGTJk3QvHnzUuVwdnbGt99+i/r16ysLXd9//z0AwNDQEFKpFDo6OrCysoKVlRXU1NQKPN/r1+Pt7Q25XI569eph9OjRqFu3LubOnYtHjx7hxYsXBe5fmPI6LhERVW9nz56Fn5+f8sbZmwbz/+effzB//vyKikZlyPS996BZtw4AIGKhNxRF7IZMREREebEQV0251DKEjjSnkOMfEqPiNG8WEBBQaJHr0aNHaNCgAXbs2AE9PT3l16VLlxAcHIysrCy0a9dOub2Ojg7c3NzyPPHm4uKS55jW1taQyWSF5pk0aRKOHTum3Gbr1q1vfBquqDk8PDzy7NemTZtCn8x7k6ZNmyr/X1NTE4aGhnmeJrS0tASAN15jRR6XiIiqt6tXr2Ly5Mlo3rw5BEHA/PnzMWbMGHzzzTc4fvw4wsPDlduGhYXV2LFPfX19YWFhAVdXV1VHKRFBKoXV4sWAICAzJAQxP61XdSQiIqIqiYW4akpDTYIWDjnjxPlX4nHiRFFEYGAg3Nzc8q2LiYnBmTNnMGzYMAwcOBCBgYHKr5YtWyq3+++dd1EU8yzT0NDIs14QBCgUikIztWjRAo0bN8b27dtx6dIlPH36FGPGjHnrtbwtR0FPCLzpqYHCFHQ9ry/LPWbuNUokknzdSf87Bl5JjktERAQAX3zxBZ48eYJff/0VoijCyMgI169fxxdffIF+/frBzs4OZmZmaNasGXx8fNC+fXtVR1YJLy8vyGSyNw65UdnpuLnBeNQoAED0hg3IePxYxYmIiIiqHhbiqjF3x5zuqf5PKm8hLjg4GAkJCfmeiMvMzMSkSZNgb2+PCRMmQF9fH3Xq1FF+aWtro3bt2tDQ0MDly5eV+6WlpSEwMBANGzYsVa733nsPmzZtwubNm9G/f3+Ym5sXum1Rc/x3ooOrV6+iQYMGyu+lUimys7NLlbsg5ubmiIyMVH4fGhqKxMTEMj8PERHVXI6Ojhg+fDhat26NBQsW4OHDh0hKSsLVq1fx008/YdSoUbC3t8eHH36IX375RdVxqRTM58yGuqUlkJWFiAULIfIGHRERUbGoqzoAlZ/cCRtexKfheVwqbI11VJwov4CAAAiCAAsLC0RGRiIhIQHXrl3DihUrkJiYiOPHjxc6Rpyenh4++OADfPLJJzAyMkKtWrWwePFiSKVSjB49ulS5xowZAy8vLzx8+BC//fbbG7ctao7g4GDMmzcPkyZNwl9//YWNGzfmmbnU0dERV69exdOnT6GrqwsTE5NSXUOuzp07Y+vWrejfvz+0tLQwb968Qn+mREREpXHlyhXl/2tra8Pd3R3u7u4qTERlTU1PD1YLF+D5tOlIu3kT8bt3K5+SIyIiordjIa4aa2prCKm6BJlyBa6FxFbKQtzNmzchiiLq1KkDNTU1GBkZoXHjxhgzZgw++OCDt44j8+2330IURYwZMwZJSUlo3bo1Tpw4AR2d0l2riYkJBg0ahIsXL6JPnz5v3b4oOSZNmoTY2Fi0bNkSmpqamD9/PkaMGKFc/+mnn2L8+PFo2LAh0tLSEBISUqpryPX555/j8ePH6NmzJywtLbFixQrcuHGjTI5NRERENY9+t27Q79EDSadPQ7Z8BfS6doXGq7FkiYiI6M0E8b+DR1GZeP78Oezs7PDs2TPY2tqqLMeI9X/BPyQWI1ra4Zt3mr59B1Lq2LEjWrdujW+//bbUx+rcuTNatmyJ7777rgySERERUVVVWdqIpZX1UoYn/fpBkZwM/R7dYfvDD6qOREREVCVwjLhqzsMpp3vjtdDKO05cZRMbG4vt27fjypUrmDp1qqrjEBEREVU6GpYWsPj0EwBA0ukzSDx9WsWJiIiIqgYW4qq53HHiQqJTIEtMV3GaqqF58+aYPn06vv/+ezg5Oak6DhEREVGlZDR8OLRbtAAAvFzyJbKTklSciIiIqPLjGHHVnJu9EdQlAuQKEf4hsRjgaqPqSJVeaGhomR/zwoULZX5MIiKiyiYwMBBqampwcXHJs9zf3x9GRkYwNDTEvXv30LVrVxUlpLIkSCSwXrwITwYPgVwmQ9T338Nq4UJVxyIiIqrU+ERcNacjVUdTW0MAgH9IjIrTEBERUXW2detWDBo0KM+y5ORkdOrUCWfPnsXRo0fRo0cPFaWj8qBZuzbMpkwBAMT9ugupATdVnIiIiKhyYyGuBsjtnur/hOPEERERUfnp3bs3nj59ivv37yuXnT17FllZWejbt68Kk6mer68vLCws4OrqquooZc70gymQOjsDooiIhQugyMxUdSQiIqJKi4W4GsD91YQN/8iSEZOcoeI0REREVF117twZ2traOHbsmHLZ8ePHUa9ePTg6OqouWCXg5eUFmUyGoKAgVUcpcxKpFNZLFgMAMh8HI2bjRhUnIiIiqrxYiKsBWjoYQyLk/P/fnD2ViIiIyolUKkXXrl1x9OhR5bLjx4+jT58+KkxFFUGnRQsYjRgBAIj58SdkPHmi4kRERESVEwtxNYC+lgYa2+SOE8dCHBEREZWfPn364MqVK0hOTsbdu3fx7NmzGt8ttaaw+GQO1M3NIWZlIWLhQogKhaojERERVTpVohCXlZWFhQsXwt7eHlpaWmjatCl27tz51v0ePHiAzz//HM2bN4ehoSH09PTQvn17HDhwIN+2mzdvhiAIBX49fvy4HK6qYnm86p7KceKIiIioPPXt2xeZmZk4deoUjh8/Dl1dXXTq1EnVsagCqBkYwHL+fABA2vUbiN+3T8WJiIiIKh91VQcoiilTpmDr1q2YNm0aXFxccODAAYwZMwZyuRzjxo0rdL+NGzdiw4YNGDp0KN5//31kZGRg27ZtGDJkCDZs2IDJkyfn28fHxwe1a9fOs8zKyqrMr6miuTuZYOPlENyPTERCWhYMtTVUHYmIiIiqIQcHBzRo0ADHjh1DaGgounTpAg0NtjtqCv2ePaDXrRuSz56FzPc76HXuDA0LC1XHIiIiqjQqfSHu5s2b2Lx5MxYvXowFCxYAACZPnoyuXbvCy8sLI0eOhFQqLXDfkSNHwtvbG/r6+splH330EVq2bInPP/8ckyZNgkSS96HAXr16oXXr1uV3QSri7mQCQQBEEbgeGotuDS1VHYmIiIiqqb59+2Lbtm1ISEjAypUrVR2HKpAgCLBaMB9Prl6FIikJL79eCtuV36s6FhERUaVR6bum7tmzBxKJBNOmTVMuEwQB06dPh0wmw/nz5wvdt2XLlnmKcEDOIMIDBgxAdHQ0ZDJZgfslJSUhOzu7bC6gkjDSkaK+Zc7PguPEERERUXnq06cPoqKikJWVxYkaaiANKyuYz5kNAEg6cQJJ5wpvrxMREdU0lb4Qd+PGDdSuXRsmJiZ5lnt4eAAAAgICin3M8PBwqKurw9DQMN+6Hj16wMDAADo6Oujbty/u3btXpGMmJibi+fPnyq+IiIhi5ypvynHiWIgjIiKictSxY0fo6OigQYMGcHBwyLNOFEUVpaKKZDxyJLRdXQEAkYsXIzs5RcWJiIiIKodKX4gLDw+HtbV1vuU2NjbK9cURHByMXbt2YeDAgdDW1lYu19HRwfjx47FmzRrs378f8+bNw6VLl9C2bdsiTdawYsUK2NnZKb/c3d2LlasieDibAgDuvEhAcoZcxWmIiIioutLQ0MAff/yBff8ZrL9fv35v7M1A1YegpgarJYsBdXXIIyMRxS7KREREAABBrOS3JWvXro3atWvj1KlT+dZJpVKMGzcOGzduLNKxUlNT0aFDBwQHByMoKCjfHdr/+vvvv9G6dWuMGjUK27dvf+O2iYmJSExMVH4fEREBd3d3PHv2DLa2tkXKV96ikjLQ6qszAICtk9zRsZ65ihMRERER1SzPnz+HnZ1dpWojlifZypWI+Wk9IAhw3PWr8ik5IiKimqrSPxGnpaWFjIyMfMsVCgWysrKgpaVVpONkZWXh3Xffxd27d/Hbb7+9tQgHAK1atUKHDh1w5syZt25rYGAAW1tb5VdBT/Gpmrm+Jmqb6wIA/ENiVJyGiIiIiKo7s6lTIXV0BEQREQsWQszKUnUkIiIilar0hTgbG5sCx1vL7ZKa20X1TRQKBcaNG4eTJ09ix44d6NatW5HPb29vj9jY6jOmmrtTTvfUaxwnjoiIiIjKmURTE1aLFwEAMh49QozfJhUnIiIiUq1KX4hr3rw5goOD8xXD/P39levf5sMPP8SuXbvw008/YdiwYcU6/5MnT2BhYVGsfSqz1s45EzYEPUtAelb1mhmWiIiIqLLy9fWFhYUFXGtg10xdd3cYvpPTBo9euxaZoaGqDURERKRClb4QN3z4cCgUCqxbt065TBRFrFmzBubm5ujSpQsAIDo6Gg8ePEBqamqe/T/99FNs2LAB33zzDSZPnlzoeQp66u3s2bP4888/0atXrzK6GtVzfzVzama2AgFhcSpOQ0RERNVFVlYWEhISVB2j0vLy8oJMJkNQUJCqo6iEpZcX1MzMIGZmIsLbh7PnEhFRjaWu6gBv06JFC3h6esLb2xtRUVFwcXHBgQMHcOHCBfj5+UFTUxMAsGbNGixatAjnz59H586dAQCrV6/G8uXL4ebmBhsbm3wTLgwZMgS6ujljprVv3x5ubm5o2rQpjIyMEBgYiF9++QWWlpbw8fGpyEsuV9aG2rA30UFYbCquhcSibW0zVUciIiKiKiw2NhYTJkzAiRMnkJ2dDScnJyxatAhjxoxRdTSqRNQMDWH1v8/xYs4nSPX3R8Lv+2E0bKiqYxEREVW4Sl+IA4CNGzfCwcEBmzdvxk8//YR69eph27ZtGDt27Bv3CwgIAADcvHkTnp6e+daHhIQoC3FDhw7FsWPHcOzYMaSkpMDKygoTJkyAt7c3atWqVfYXpUIeTiYIi02F/xOOE0dERESl8/nnn+PIkSNo2bIlrKyscO3aNYwbNw4SiQSjRo1SdTyqRPT79IHewUNIvngRL7/9FnqdOkLdjDeFiYioZhFEPhdeLirz1PR7rz+D175b0FSX4LZPL0jVK30PZSIiIqqk7O3t0a5dO/z6668AgOTkZPTv3x9hYWF48uSJitNVPpW5jVgRssLDEdx/AMTUVBj064day79TdSQiIqIKxQpMDdTaOWfm1Ay5Areex6s2DBEREVVpL168yDOerp6eHnx8fPD06VMEBwerMBlVRho2NrCYNRMAkHj0KJL/+EPFiYiIiCoWC3E1kK2xNqwNtQAA/iHsnkpEREQlJ4oipFJpnmUNGjSAKIqIiIhQUSqqzIzHjIGWiwsAINJnERQpKSpOREREVHFYiKuBBEGAx6vZU1mIIyIiotIKDw9HVlaW8nsNDQ0AQGZmpqoiUSUmqKnBesliQE0NWeHhiFr9g6ojERERVRgW4mood6ec7qk3QmMhz1aoOA0RERFVZXPnzoWenh7c3NwwefJkbNiwAYIgIDs7W9XRqJLSatAAppMmAQBit21D2u3bKk5ERERUMarErKlU9jycc56IS8nMxt3wRLjaGak2EBEREVVJ58+fR1BQkPJrx44dyMjIAAD06dMHDg4OaNy4MRo1aqT8b4sWLVScmioDs2kfIfHkSWSFhSFiwUI47d0D4dXTlERERNUVC3E1lLOZLsz0NBGdnAH/kBgW4oiIiKhEOnXqhE6dOim/z87OxsOHD/MU565fv44jR44AAJ+UIyWJlhasF/kgbOIkZDx4gNgtW2A6ebKqYxEREZUrFuJqqNxx4o7ejsC1kFhM6Vhb1ZGIiIioGlBTU0OjRo3QqFEjjBo1Srk8KioKgYGBuHXrlgrTUWWj26YNDAcPRsKBA4hasxb6PXtCam+v6lhERETlhmPE1WC53VOvhcQiWyGqOA0RERFVZ+bm5ujRowc++eQTVUehSsZi7mdQMzGBmJ6OSB8fiCLbpUREVH2xEFeDub+aOTUxXY6HkUkqTkNERERENZG6sTEsP/8cAJBy5S8kHjqk4kRERETlh4W4GqyehT6MdHIGxPUPiVFxGiIiIiKqqQz694Nu+/YAgJdLl0EeG6viREREROWDhbgaTCIR0Mrx3+6pRERERFS46Oho9O/fH7q6uqhbty5OnDih6kjVhiAIsPLxhqCtjez4eLxctkzVkYiIiMoFC3E1nIfTv4U4jsdBREREZWXx4sVYsmQJ0tPTVR2lzHz00UewsLBAVFQUVqxYgREjRkAmk6k6VrUhtbWF+YwZAIDEQ4eRfPlPFSciIiIqeyzE1XCtnU0BADEpmXgsS1ZxGiIiIqoufHx84OPjg9TUVFVHKRPJyck4cOAAFi1aBB0dHQwYMADNmzfHgQMHVB2tWjEZ5wmtRo0AAJE+PlBUk9cPERFRLhbiariG1gbQ11QHAPizeyoRERGVkUuXLuGPP/6AkZGRSs6fnJwMb29v9O3bF+bm5hAEAcsK6e6YlZWFhQsXwt7eHlpaWmjatCl27tyZZ5t//vkHenp6sLOzUy5zdXXF3bt3y/U6ahpBXR1WSxYDamrIev4cUWvXqjoSERFRmWIhroZTkwho6WgMgIU4IiIiKjvt2rVDu3btIJGoprkZHR2NxYsX4/bt23Bzc3vjtlOmTMFXX32FwYMH44cffoCdnR3GjBmDrVu3KrdJTk6GgYFBnv0MDQ2RnMweBWVNu3FjmIwfDwCI3bwF6ffuqTgRERFR2WEhjuDulNM99VpIDMeJIyIiomrB2toaL168wLNnz/Dzzz8Xut3NmzexefNm+Pj4YPXq1Xj//fdx5MgRdO7cGV5eXsjMzAQA6OnpITExMc++iYmJ0NPTK9frqKnMp0+Dhq0tkJ2NiAULIcrlqo5ERERUJliII3g450zY8DIxA09jOA4HERERVX2ampqwsbF563Z79uyBRCLBtGnTlMsEQcD06dMhk8lw/vx5AEDdunWRnJyM58+fK7cLCgpC48aNyz48QaKjAysfHwBA+t27iN22XbWBiIiIyggLcQSXWobQ1lADkDN7KhEREVFNcePGDdSuXRsmJiZ5lnt4eAAAAgICAOQ8ETdo0CB4e3sjNTUVR48exY0bNzB48OA3Hj8xMRHPnz9XfkVERJTLdVRHeu3bwWDgAACAbNUqRP+0HqFjxyK4T1+Ejh2LuN17OJkDERFVOSzEETTUJGjhkDNO3NWQGBWnISIiIqo44eHhsLa2zrc892m68PBw5bIff/wRkZGRMDMzw8yZM7Fr1y5YWFi88fgrVqyAnZ2d8svd3b1sL6Cas5w3DxJ9fSA9HVErVyLtRgAyQ0KQdiMAkd7eCO7bDxlPQlQdk4iIqMhYiCMAgLtTzl1g/yd8Io6IiIhKTiaTITIyEtnZ2aqOUiRpaWnQ1NTMt1wikUBDQwNpaWnKZWZmZjh69ChSU1Px+PFj9OnT563HnzNnDp49e6b8unbtWpnmr+4kWlrA6xN+5I5n/Oq/cpkMYZMm8ck4IiKqMliIIwCAx6tC3Iv4NDyPY0OGiIiIii45ORkff/wxzMzMYG1tjVq1akFbWxtt27bF8uXLkZSUpOqIhdLS0kJGRka+5QqFAllZWdDS0irV8Q0MDGBra6v8KujpOypcwuEjUCQkFL6BQgF5ZCQSjhypuFBERESlUKaFuOTkZKSkpJTlIamCuNoZQaqe83LgOHFERERUVFlZWejatSvWrFkDc3NzjBw5EqNHj0bXrl3x+PFjeHl5wcnJCQcPHlR11ALZ2NgUOG5bbpfUokz4UBS+vr6wsLCAq6trmRyvpkg4fAgQhDdvJAhIOHy4YgIRERGVUqkKcefOncOMGTPg5uYGbW1tGBoawsDAANra2mjevDmmT5+Os2fPllVWKkdaGmpoZmcEgIU4IiIiKrr169fj+vXr+PHHH3H//n3s2LED27Ztw4kTJyCTyXDixAnUqlULw4YNw969e1UdN5/mzZsjODgYsbF52z/+/v7K9WXBy8sLMpkMQUFBZXK8miI7Jvbf7qiFEUVkR3OcYyIiqhqKXYiTy+VYu3YtnJyc0L17d2zfvh1mZmYYN24cPvvsM3h5eWHcuHEwMTHBjh070KNHDzg6OmLt2rWQy+XlcQ1URlrnjhPHQhwREREV0b59+9CjRw988MEHBa7v2bMn/v77b3Tp0gWTJ09GXFxcBSd8s+HDh0OhUGDdunXKZaIoKp/w69KliwrTkZqpydufiAOgZmxcAWmIiIhKT724O9StWxcpKSnw9PTEiBEj3jrz09WrV7F3714sWrQIy5cvx5MnT0oclsqXu5MpgMcIiU6BLDEdFgalGxOFiIiIqr87d+5g7ty5b9xGKpVi+/btqFu3LjZu3AgvL68KybZmzRrEx8cjPj4eAHD+/HnljeEZM2bA0NAQLVq0gKenJ7y9vREVFQUXFxccOHAAFy5cgJ+fX4ETOVDFMRwwEGnXb7x1u/QHDxCzaTOMR43MmeCBiIiokhJE8W3Peue1atUqTJkyBdra2sU6UWpqKjZu3IiPP/64WPtVVc+fP4ednR2ePXsGW1tbVccpktRMOZr6nIJcIeKHUW4Y4Fo2Y6IQERFR9aWuro5NmzbB09Pzrdt6enri5cuXOHXqVAUkAxwdHfH06dMC14WEhMDR0REAkJmZiSVLlmDz5s2QyWSoV68e5s6di7Fjx5ZZFl9fX/j6+iI7OxuxsbFVqo2oSorUVAT37Qe5TAYoFG/dXt3SEmYffQSjoUMgaGhUQEIiIqLiKXYhjoqmKhbiAGDIuj9xMyweY1vb48vBLqqOQ0RERJWcRCLB9u3bMXr06Lduu3r1anz77bd4/vx5BSSrnKpqG1GVMp6EIGzSJMgjI3O6qYqi8r/qVlawXroUiYcOIeHgQWWxTsPBHuYzPoZB3z4QJGU6Px0REVGp8F2J8vBwMgXACRuIiIio6F68eIH09PS3bmdiYlLpxoijyk/T2Qm1jx2F1eJF0G7ZAlInJ2i3bAGrxYtQ+9hR6LVpDZulX8P58CHo9+oFAMh6GobwTz9FyJChSDp3Hnz2gIiIKosyeyLu2bNnuHDhAmQyGYYPHw47OzvI5XLExsbCxMQE6urFHo6uSquqdzvPP5Rh4qa/AQABC3rARFeq4kRERERUmUkkEgiCAIlEgjp16qBp06bKLxcXF2X3TwDYsWMHxo0bh+zsbNUFVrGq2kasStLu3EXUqlVIuXRJuUzb1RXms2dDt7WHCpMRERGVYLKGgsyZMwc//PADsrOzIQgC3NzcYGdnh9TUVNSpUweLFi3C7Nmzy+JUVM5aOhhDIgAKMeepuN5NrFQdiYiIiCqx8+fPIygoSPl16NAh7N27F8KrmS719fXRpEkTNG3aFElJSSpOSzWBdpPGsN/wM1L//huy71ciLSAAaUFBCJswAbpt28B81ixoN22q6phERFRDlfqJOF9fX8ydOxdeXl7o2bMnevTogTNnzqBr164AgPHjxyMkJAR//PFHmQSuKqry3c4BP1zG7RcJmNjOEd4DGqs6DhEREVUh2dnZePjwYZ7iXFBQECIjIwEAgiDUyCfiOFmDaoiiiJRLlyD7fiUy7t9XLtfr3g3mH38MrXr1VJiOiIhqolI/EbdhwwaMHTsW33zzDWJiYvKtd3FxwcmTJ0t7GqpA7k4muP0igePEERERUbGpqamhUaNGaNSoEUaNGqVcHhUVpSzK1UReXl7w8vJS3qyliiEIAvQ6doRu+/ZIOnUKUatWIzMkBMlnziL57DkYDOgP8+nTIbW3V3VUIiKqIUo9WUNYWBg6dOhQ6HoDAwPEx8eX9jRUgTycTAAA9yISkZCWpeI0REREVB2Ym5uje/fu+OSTT1QdhWogQSKBQe/ecD58CNZffQV1G2tAFJF46DCC+/ZDhI8Psl7KVB2TiIhqgFIX4szMzBAREVHo+tu3b6NWrVqlPQ1VoFaOOYU4UQRuPOVTcURERERUPQjq6jAaNhS1T5yA5RdfQM3UFJDLEb9rN4J79sRLX1/IObMvERGVo1IX4vr164eff/4Z0dHR+dbdvHkTv/zyCwYPHlza01AFMtaVooGVPgDA/wkLcURERERUvUikUph4jkWdUydhPns2JAYGEDMyEPuLH4K790DU2rXITk5RdUwiIqqGSl2IW7JkCTQ0NODi4oJ58+ZBEAT4+flh5MiRaN26NWxtbTF//vyyyEoVyP1V99SrHCeOiIiIqNR8fX1hYWEBV1dXVUeh10h0dWH2wRTUOX0KplOmQNDWhiIlBdE/rEFwjx6I2bQZivR0VcckIqJqpNSzpgJATEwMvvjiC+zduxdxrx7lNjAwwDvvvINly5bBzMys1EGrmqo8ayoAHL0VgWk7A6AmEXDLuyd0NUs9rwcRERFRjVfV24jVnTwqCtHrf0bc7t1AVs5YyeqWljD76CMYDR0CQUNDxQmJiKiqK5NC3OuioqKgUChgbm4OiaTUD9xVWVW9kRWVlIFWX50BAGyd5I6O9cxVnIiIiIio6qvqbcSaIvP5C0SvW4eEAwcAhQIAoGFvD/MZM2DQry+EGvw5h4iISqfUjzmdPn0a3bp1UxbdzM1ZsKkOzPU14WyuiydRKbgWEstCHBERERVZcnIynj59itjYWBR0z7djx44qSEVUdFLbWrD5+iuYTn4PUat/QNKJE8gKC0O4lxdiNmyA+ayZ0OvSBYIgqDoqERFVMaUuxPXq1Qvm5uYYPnw4RowYgfbt25dFLqoEPJxM8SQqBf4hMaqOQkRERFVAXFwcZsyYgT179iA7OzvfelEUIQhCgeuIKiNNZ2fYrvweaXcnI2rVKqT8cQkZjx7h+UfToO3qCvPZs6Hb2kPVMYmIqAopdSFuz5492LVrF/z8/LBu3TrY2tpixIgRGDFiBFq0aFEWGUlFPJxM8Ou1MAQ9S0B6Vja0NNRUHYmIiIgqsQ8++AD79+/H9OnT0alTJxgbG6s6ElGZ0G7cGPY//4zU69ch+34l0m7cQFpQEMImTIBu2zYwnzUL2k2bqjomERFVAaUe3OCdd97Bvn37IJPJsHXrVjRt2hSrV6+Gu7s76tWrB29vb9y7d69U58jKysLChQthb28PLS0tNG3aFDt37nzrfg8ePMDnn3+O5s2bw9DQEHp6emjfvj0OHDhQ4PZJSUmYMWMGrKysoK2tjdatW+PUqVOlyl6V5c6cmpmtwM2weNWGISIiokrv+PHj+Pjjj/H9999j8ODB6NSpU4FfNRFnTa0edFq2hMP2bbD7eT00GzUEAKRc+Quhw0fg2fTpSH/0SMUJiYiosiuzUUZ1dXUxZswYHD58GC9fvsTPP/8MJycnfP3112hayrtDU6ZMwVdffYXBgwfjhx9+gJ2dHcaMGYOtW7e+cb+NGzdi3bp1cHV1xbJly/Dll18iLS0NQ4YMwcaNG/NsK4oiBg8ejI0bN+K9997DypUrIZFI0LdvX5w7d65U+asqGyNt2JloAwC7pxIREdFbaWpqom7duqqOUSl5eXlBJpMhKChI1VGolARBgF7HjnDatw+1Vn4PqZMTACD5zFmEDBqMF599hsywMBWnJCKiyqrMZ00FgMzMTBw7dgy7du3C4cOHkZ6eXuKxQG7evInmzZtj8eLFWLBgAYCcolnXrl1x7949PHv2DFKptMB9r1+/jvr160NfXz9PtpYtWyIiIgIvX75UTjKxf/9+DB06FFu3boWnpycAICMjA02bNoW2tjYCAwOLlbu6zIj16d4g7LvxHG1rm2Ln+61VHYeIiIgqsQ8//BDh4eE4dOiQqqNUWtWljUj/EuVyJBw8hKi1ayAPj8hZqK4Oo3eGwWzqVGhYWqo2IBERVSpl9kScXC7HsWPHMG7cOFhYWGDYsGG4dOkS3n//fVy5cqXEx92zZw8kEgmmTZumXCYIAqZPnw6ZTIbz588Xum/Lli3zFOEAQCqVYsCAAYiOjoZMJstzHhMTE4wePVq5TFNTE1OmTEFQUBAePnxY4muoyjxedU8NCItDplyh4jRERERUmX366aeIiIiAp6cn/vrrL0REREAmk+X7IqpOBHV1GA0bitonTsDyiy+gZmoKyOWI37UbwT174eW3vpDHxak6JhERVRKlnqzh9OnT2L17N/bv34/4+HgYGxtjxIgRGDlyJDp37lzqKb1v3LiB2rVrw8TEJM9yD4+c2YkCAgLQq1evYh0zPDwc6urqMDQ0zHOeFi1aQE0t74QEr5+nfv36hR4zMTERiYmJyu8jIiKKlamy8nAyBQCkZylw+0U8WjiYvGUPIiIiqqnq1asHQRBw48aNN47ny1lTqTqSSKUw8RwLo2FDEbttO2J++QWKxETE+vkhfvdumEycCJMJ46Gmp6fqqEREpEKlLsT16tUL+vr6GDRoEEaOHImePXtCXb3Uh1UKDw+HtbV1vuU2NjbK9cURHByMXbt2YeDAgdDW1s5znjZt2pT4PCtWrMCiRYuKlaUqsDPRhrWhFiIS0nH1SSwLcURERFSohQsXlvomLFFVJ9HRgdkHU2A8cgRi/DYhdutWKFJSEL1mDeK2b4fplCkwHj0KEi0tVUclIiIVKHXFbO/evejfvz80NTXLIk8+aWlpBR5bIpFAQ0MDaWlpRT5Wamoqhg8fDk1NTaxYsaJI59F69Qb5tvPMmTMHkydPVn4fEREBd3f3ImerrARBgLuTCQ4GhuNaSCymdVF1IiIiIqqsfHx8VB2BqNJQMzSExexZMBk7BtHrf0bc7t3Ijo+H7NtvEbt5M8w++ghGw4ZC0NAAAChSU5Fw+AgSDh9Cdkws1ExNYDhgIAwH9IdER0fFV0NERGWl2IW4sFczANnb2wMAWrVqhZcvX751v9zti0tLSwsZGRn5lisUCmRlZSkLZW+TlZWFd999F3fv3sXRo0fh4OBQpPOkp6cr17+JgYEBDAwMipSlqvFwMsXBwHBcD42FPFsBdbUyG1qQiIiIqqk7d+4gJCQEgiDA0dERTZo0UXUklfL19YWvry+75dZA6ubmsJr/BUwnTkDU2nVIOHAAcpkMkT4+iPHzg/mM6dBs2BDPJr8PeWQkIAiAKAKhoUi7fgPRP/4Iez8/aDo7qfpSiIioDBS7EOfo6AhBEJCWlgapVKr8/m1K2uiwsbHB06dP8y3P7Sqa23X0TRQKBcaNG4eTJ09i9+7d6NatW4HnKWhct+Kcp7pyfzVhQ0pmNu6GJ8LVzki1gYiIiKjSOnLkCGbOnInQ0FAAObPdC4IAJycnrFy5Ev3791dtQBXx8vKCl5eXctZUqnk0atWCzddfwXTye4ha/QOSTpxAVlgYwr0+A9TVgdzPS6KY579ymQxhkyah9rGjfDKOiKgaKHYhzs/PD4IgQOPVI9S535eX5s2b49y5c4iNjc0zYYO/v79y/dt8+OGH2LVrFzZs2IBhw4YVep6zZ88iOzs7z4QNxTlPdVXbXBdmelJEJ2fiWkgsC3FERERUoFOnTmHw4MGwtbXFV199hUaNGkEURdy/fx/r16/HkCFDcOzYMfTo0UPVUYlURtPZGbYrv0fa3cmIWrUKKX9cAuTywndQKCCPjETCkSMwHj684oISEVG5EEQx95ZL5XTjxg20bNkSS5Yswfz58wHk3Fnt2rUr7t69i2fPnkFTUxPR0dGIjo6Gvb09dF67U/Tpp59i+fLl+Oabb/DZZ58Vep7ffvsN77zzDrZt24axY8cCADIyMtC0aVNoamri1q1bxcqde7fz2bNnsLW1LcGVVy4f7biBY7cj0b2hBTaOb6XqOERERFQJtWvXDklJSfjzzz+hr6+fZ11SUhLatm0LIyMjXLp0SUUJVa+6tRGp9J4MGoyMhw/fvJEgQLtlCzhu21YxoYiIqNyUerKGP/74Aw0bNoS5uXmB66Ojo3Hv3j107NixRMdv0aIFPD094e3tjaioKLi4uODAgQO4cOEC/Pz8lBMsrFmzBosWLcL58+fRuXNnAMDq1auxfPlyuLm5wcbGBtu3b89z7CFDhkBXVxcAMHToUHTq1Anvv/8+Hjx4AHt7e2zZsgXBwcE4ceJEibJXJx5Opjh2OxLXQmKhUIiQSDgjGhEREeUVGBiIL7/8Ml8RDgD09fUxadIkLFiwQAXJiCovMTOzCBuJyI6OKf8wRERU7kpdiOvSpQu2bduG0aNHF7j+7NmzGD16dKkGpt24cSMcHBywefNm/PTTT6hXr16eJ9cKExAQAAC4efMmPD09860PCQlRFuIEQcChQ4fwv//9Dxs2bEBiYiJcXFxw5MgRdO/evcTZqwsP55xuwYnpcjyITEIjm+o5MQURERGVnFQqRUpKSqHrk5OTlcObEFEONVMTIDT037HhCpGdkIDUgABou7mV69BARERUvkrdNVUikWD79u2FFuK2bduGSZMmISsrqzSnqXKqW7cDhUJE8y9PIz41Cz4DGmFCO87aRERERHkNHDgQ165dw6VLl1C3bt086x4/foz27dvDw8MDBw8eVFFC1atubUQqvbjdexDp7V3k7TUbNIDxqFEwHNCfkzcQEVVBJXoiLjExEfHx8crvY2JiEBYWlm+7uLg4/Prrr6hVq1aJA1LlIJEIaOVogtP3XsI/JJaFOCIiIspn6dKlaNu2LZo0aYKBAweifv36AIAHDx7gyJEj0NbWxtKlS1WckqhyMRzQH9E//gi5TAYoFPk3kEgg0dODmqkpskJCkPHgASK9vSHz9YXhkCEwHjUKms5smxMRVRUleiJu0aJFWLx4cZG2FUURX331FT7//PNih6vKquPdzo2XnuDLo/dhqivF9fnd+Ug8ERER5fPPP//g888/x8mTJ5XdVHV1ddG7d2989dVXqFevnooTqlZ1bCNS6WU8CUHYpEmQR0YCgpDTTfXVf9WtrGDv5wepkyNS/a8hbudOJJ09C7w29I9u2zYwGjUK+l26QFAv9ehDRERUjkr0V7p79+7Q0tKCKIr43//+hxEjRqBZs2Z5thEEAbq6umjZsiU8PDzKIiupmIeTKQAgJiUTwVHJqGORfyBmIiIiqtnq1q2Lffv2QaFQICoqCgBgbm4OiUSi4mRElZemsxNqHzuKhCNHkHD4MLKjY6BmZgrDAQNg2P/fLqi6rT2g29oDWS9fIn73HsTt3YPsqGikXPkLKVf+grqVFYxHDIfRu+9C3cxMxVdFREQFKfUYcYsWLcKwYcPQpEmTsspULVTHu53ybAWaLT6N5Aw5vhzcBGNbO6g6EhEREVGV4OvrC19fX2RnZyM2NrZatRFJdcTMTCSdPYu4HTuRev36vys0NGDQsyeMR4+CdvPm7MlCRFSJlLoQl5KSgpiYGNjb2xe4PiwsDGZmZtCpYQOJVsdCHABM2HQNFx5GYaCrDVaPclN1HCIiIqIqpbq2EUn10h89QtyvvyLx4CEoUlOVyzXr1/93cgddXRUmJCIiACh1H4HZs2dj0KBBha4fPHgwPv3009KehiqJ3O6p/iExKGUNl4iIiIiIyohWvXqw9vZGnT8uwnLBfEhr1wYAZDx8iEgfH/zTqTMiv/wKGU+eqDgpEVHNVupC3OnTpzFkyJBC1w8ZMgQnT54s7WmoknB3MgEAvEzMQFhs6lu2JiIiIiKiiqSmpweTMWPgfOQw7LdsgX6vXoCaGhTJyYjbvh1P+vbD04kTkXjqFES5XNVxiYhqnFJPqRMREQFra+tC11tZWSE8PLy0p6FKwqWWIbQ11JCWlQ3/J7FwMOXj7URERERElY0gCND1cIeuh3vO5A579iJ+zx7Io6KQ+tdVpP519d/JHd55B+rm5qqOTERUI5T6iThzc3PcvXu30PV3796FkZFRaU9DlYRUXYLmDkYAAP+QWNWGISIiIiKit9KwtIT5jOmoc+4san2/AjqtWgEA5JGRiFq1Gv907YYXcz5B6vXrHH6GiKiclfqJuL59++Lnn3/G8OHD0bZt2zzrrl69ip9//hmjRo0q7WmoEvFwMsWfj2PgHxKj6ihERESkQmFhYSXar7BJvoiofAkaGjDo0wcGffrkm9wh8dgxJB47Bs169WA8ejQndyAiKielnjU1MjISrVq1QkREBPr06YMmTZpAEATcvn0bx48fh5WVFa5duwYbG5uyylwlVOcZsa4+icHIn68CAP6c1xW1jLRVnIiIiIhUQSKRQBCEYu+XnZ1dDmmqhurcRqSqKTs5GQkHDyLu11+R+ThYuVyipwfDwYNhPHoUNJ2dVZiQiKh6KfUTcVZWVrh+/Trmzp2LAwcO4OjRowAAAwMDeHp6YunSpbCysip1UKo8mtkZQaomQWa2AtdCYjDEjY1IIiKimsjPz69EhTgiqjxyJ3cwHj0aqdf+RtyvvyLp9Gnl5A5x27dDp3VrGI8eBf2uXSGol/ojJBFRjVbqJ+JeJ4oioqKiIIoiLCwsanTDrLrf7Rz+01+4FhqLka3ssGxYU1XHISIiIqoSqnsbkaqHrJcyxO/Zo5zcIZe6pSWMXk3uoGFhocKERERVV6kna8j17NkzbN++Hdu2bUNmZiYEQUB2djZkMhnknBa72vFwNgEAXOOEDURERERE1YqGpcW/kzus/P7fyR1evkT06h/wuGs3vJgzB6l//83JHYiIiqlMniueM2cOfvjhB2RnZ0MQBLi5ucHOzg4pKSmoU6cOFi1ahNmzZ5fFqaiS8HAyxQ94jCfRKZAlpsPCQEvVkYiIiKgSkMlk+OWXX3Djxg3Ex8dDoVDkWS8IAs6ePauidERUHIKGBgx694ZB797I+OcfxP36KxIOHHw1ucNxJB47/mpyh1EwHDCAkzsQERVBqZ+I8/X1xcqVKzFnzhycPn06zx0RAwMDDBkyBPv37y/taaiSae5gBHVJTtfja6F8Ko6IiIiAe/fuoXHjxli8eDEeP36M8+fPIyoqCo8ePcKFCxfw7NkzPj1DVEVp1q0Lq4ULUeePP2C5cAGkdWoDADIePUKkzyL807ETIpd8iYzg4LcciYioZit1IW7Dhg0YO3YsvvnmGzRr1izfehcXFzx69Ki0p6FKRkeqDhdbQwCA/xMW4oiIiAiYN28eNDU1cf/+fZw5cwaiKGLVqlV4/vw5duzYgbi4OPj6+qo6pkr4+vrCwsICrq6uqo5CVCpqerowGT0azocPw37rFuj37g2oq0ORkoK4HTvwpF9/PB0/AYknT0HMysqzryI1FXG79yB07FgE9+mL0LFjEbd7DxSpqSq6GiKiilfqQlxYWBg6dOhQ6HoDAwPEx8eX9jRUCbk75YwT5x8So+IkREREVBlcunQJH3zwARwdHSGR5DQzc7umjho1CiNGjICXl5cqI6qMl5cXZDIZgoKCVB2FqEwIggBdd3fYrvwedc6ehdn06VA3NwcApPr748XMmXjcrTui1q5FlkyGjCchCO7bD5He3ki7EYDMkBCk3QhApLc3gvv2Q8aTEBVfERFRxSh1Ic7MzAwRERGFrr99+zZq1apV2tNQJdTayRQA8OhlMmJTMlWchoiIiFQtMzMT1tbWAABtbW0AQEJCgnJ9s2bN8Pfff6skGxGVHw1LC5hPn/ZqcoeV0HF3BwDIZTJE/7AGj7t0RcjQoZC/fJmzQ24X9Vf/lctkCJs0iU/GEVGNUOpCXL9+/fDzzz8jOjo637qbN2/il19+weDBg0t7GqqEWjga49UwcZw9lYiIiODg4IDQ0FAAOYU4a2trXLlyRbn+zp070NPTU1E6IipvOZM79ILD1i1wPnIYxqNH50zgkJ0NMT393wLcfykUkEdGIuHIkYoNTESkAqUuxC1ZsgQaGhpwcXHBvHnzIAgC/Pz8MHLkSLRu3Rq2traYP39+WWSlSsZASwONbAwAsBBHREREQJcuXXDw4EHl92PGjMHq1asxefJkTJo0CevWrcOgQYNUmJCIKopmnTqwWrgAdS5ehIa9/dt3EAQkHD5c/sGIiFRMvbQHsLCwwPXr1/HFF19g7969EEURO3fuhIGBATw9PbFs2TIYGxuXRVaqhDycTHHnRSLHiSMiIiLMnTsXXbt2RXp6OrS0tLBkyRLEx8dj7969UFdXh6enJ7777jtVxySiCqSmpwtBTe3tG4oiMh4+QvKlS9Bp1QoSLa3yD0dEpAKCWMZzyEdFRUGhUMDc3Fw5SG9N9Pz5c9jZ2eHZs2ewtbVVdZxyc/JuJD7YdgOCAAR594SBloaqIxERERFVWjWljUj0utCxY5F2I6Dwrqn/IWhqQqdVK+i2bwe9Dh0gdXaGIAjlnJKIqGKUeaXM3NwclpaWNboIV5O4O+bMnCqKwPVQdk8lIiIiIqK8DAcMLFIRTt3CAgAgZmQg5fJlyJZ9gyf9+uNxt26IWLAQiadOITspqbzjEhGVq2J3TQ0LCwMA2L/q55/7/VtPpK4OExMTaPER42rFWFeK+pb6ePgyCf5PYtG1gaWqIxEREZGKFLVdaF+U8aKIqNowHNAf0T/+CLlMBigU+TeQSKBuYYHax44iOzERKX/+ieRLl5Fy5QoUiYmQh0cgfu9exO/dC6ipQbtZM+h1aA/d9h2g1aghBD4EQkRVSLG7pkokEgiCgLS0NEilUuX3RTqZIKBdu3bYvHkznJycShS4qqhJ3Q4WHryDrX89RTM7IxyY1k7VcYiIiEhFitouzM7OroA0lVNNaiMSvS7jSQjCJk2CPDISEIScJ+Re/Vfdygr2fn7QdM77GVGUy5F2+zZSLl1G8uXLSL99O9+TdWomJtBt1y6nMNeuHdRNTSvysoiIiq3YT8T5+flBEARoaGjk+f5tsrOzER4ejvXr1+P999/HmTNnip+WKiV3JxNs/espbr9IQEqGHLqapZ4DhIiIiKqggtqF2dnZCAkJwdatW2FlZYWPPvpIRemISJU0nZ1Q+9hRJBw5goTDh5EdHQM1M1MYDhgAw/79IdHRybePoK4OHTc36Li5wfzjGZDHxSHlypWcwtyfl5EdFY3s2FgkHj6MxFczrmo1agTdDh2g16E9tF1dIWhwDGsiqlzKfLKGt1m7di0+++wzpKSkVORpK1xNutspS0qH+1dnAQDb3nNHh7rmKk5ERERElU1ycjJatWqFGTNm1OhiXE1qIxKVJ1EUkfHwIZIvXULK5T+RGhAAZGXl2UaipwfdNq2h274D9Nq3g0atWipKS0T0rwp/dOndd99FkyZNKvq0VI4s9LXgbKaLJ9Ep8H8Sy0IcERER5aOnp4eJEydi+fLlNboQR0RlQxAEaDVoAK0GDWD2/vvITk5B6jV/pFy+jOQ/LiHr+XMokpORdPoMkk7n9MaSOju/GluuPXRatYKE45cTkQqUWSHu2LFjOHLkCJ4+fQoAcHBwQP/+/dG3b98821lYWMDi1Ww4VH14OJvgSXQKroVw5lQiIiIqmFQqxYsXL1Qdg4iqITU9Xeh37Qr9rl0hiiKynj5F8uU/kXLpElKuXYOYlobMJ08Q++QJYrdshaCpCZ1WraDbvh30OnSA1Nm5yGOfExGVRqm7piYmJmLYsGE4d+4cBEGApaUlRFGETCaDKIro3Lkz9u/fDwMDg7LKXCXUtG4HB26+wKzdgZCqSXDLpye0NNRUHYmIiIgqkaCgIAwePBjGxsYICAhQdRyVqWltRKLKQJGZibQbN3JmYr10CRn//JNvG3Uba+i1aw/dDu2h26YN1PT1VZCUiGqCUs/zPHv2bJw7dw5ffvkl4uLi8OLFC4SHhyMuLg5LlizBhQsXMHv27LLISpWYu5MJACAzW4HAZ/GqDUNEREQq4eTkBGdn53xfJiYmaN68OWJjY7FixQpVxyyxNWvWwM3NDerq6vDx8VF1HCIqIolUCt02bWD5mRecDx9CnQvnYf3Vl9Dv3RuSVw+MyMMjEL93L158PBOPWrdB6NixiP7pJ6TduQtRoVDxFRBRdVLqJ+KMjY0xatQorFu3rsD1U6dOxa5duxAXF1ea01Q5NfFuZ4dvz+FZbBpmd6+Hmd3rqjoOERERVbAJEybk69olCAKMjIxQp04djB49GkZGRqoJVwb2798PdXV1bN26FY0bNy5RMa4mthGJKjNRLkfa7dtIufwnki9fQvqt28B/PiKrmZhAt127nPHl2rWDuqnpG4+pSE1FwuEjSDh8CNkxsVAzNYHhgIEwHFDw7LBEVLOUyRhxLi4ub1y3a9eusjgNVXLujqZ4Fvsc/iExAFiIIyIiqmk2b96s6gjlasiQIQCAgwcPqjgJEZUVQV0dOm5u0HFzg/mM6ZDHxSHlyhVlYS47KhrZsbFIPHwYiYcPAwC0GjWCbocO0OvQHtqurhA0NJTHy3gSgrBJkyCPjAQEIaeoFxqKtOs3EP3jj7D384Oms5OqLpeIKoFSd03t27cvjhw5Uuj6I0eO5JuwgaonD+ec7qkBYXHIlPPxbSIiIip7ycnJ8Pb2Rt++fWFubg5BELBs2bICt83KysLChQthb28PLS0tNG3aFDt37qzgxERUlagbG8OwXz/YLP0adf/4A04HD8Di00+g4+EBvCq4pd+7h5j16/F0rCcetWmL5zNmIG73HmQ8fpxThJPJcg6W+2Tdq//KZTKETZoERWqqKi6NiCqJYj8RJ8v9o/LK/PnzMXLkSPTv3x/Tp09HnTp1IAgCHj16hDVr1iA8PBzLly8vs8BUeXm8GicuPUuB2y/i0cLBRMWJiIiIqDyFhYWVaD97e/sSnzM6OhqLFy+Gra0t3NzccPr06UK3nTJlCrZu3Ypp06bBxcUFBw4cwJgxYyCXyzFu3LgSZyCimkEQBGjVrw+t+vVhOnkyFCkpSPG/hpTLl5B86TKynj2DIjkZSafPIOn0mbcfUKGAPDISCUeOwHj48PK/ACKqlIpdiLOysso39ocoirh9+zaOHz+ebzmQ0z1VLpeXIiZVBfYmOrAy0EJkYjr8Q2JZiCMiIqrmHB0d87ULiyI7O7vE57S2tsaLFy9gY2OD0NBQODkV3MXr5s2b2Lx5MxYvXowFCxYAACZPnoyuXbvCy8sLI0eOhFQqBQD07NkTf/zxR4HHmThxIn788ccS5yWi6kOiqwv9rl2g37ULACDz6dOcmVgvX0aKvz/EtLS3H0QQkHD4MAtxRDVYsQtxCxcuLFGDi6o/QRDg4WyCg4Hh8H8Si486qzoRERERlSc/P7887UJRFLFq1SqEhoZizJgxqF+/PkRRxMOHD/Hrr7/C0dERH3/8canOqampCRsbm7dut2fPHkgkEkybNk25TBAETJ8+He+88w7Onz+PXr16AQBOnTpVqkxEVDNJHRxg4uAAk7FjoMjMRHDPXjljw72JKCL9zl3IVq6EtosLtJq4QMPSomICE1GlUOxCHKdqpzdxd8opxN14Ggd5tgLqaqUehpCIiIgqqQkTJuT5/ptvvkFKSgoeP34M0//MKujj44O2bdsiKiqqQrLduHEDtWvXholJ3if0PTw8AAABAQHKQlxRyeVyyOVyZGdnQy6XIz09HRoaGlBTUyt0n8TERCQmJiq/j4iIKNY5iahqkEil0LCtBfnLl/lmXf0vMS0NMT+tV36vbmEBLReXnMKcSxNoN2kCNUPD8o5MRCpSJrOm5rpz5w5CQkIgCAIcHR3RpEmTsjw8VQEeTjmN7uQMOe5FJKKprZFqAxEREVGFWbduHWbMmJGvCAcA5ubmeP/997F27Vp4eXmVe5bw8HBYW1vnW577NF14eHixj/nll19i0aJFyu+/+uorbNq0KV9B8nUrVqzIsw8RVV+GAwYi7fqNt26n7eYGRXISMh4HA6IIuUyG5LNnkXz2rHIbDQd7aDd5VZhr2hRaDRtCoq1dnvGJqIKUyeNKR44cQe3ateHq6orBgwdj4MCBcHV1RZ06dd44o2pRlWbGq6+//hqDBw+GjY0NBEHAhx9+WOB2mzdvhiAIBX49fvy41NdQU9Q214WZXs54K/5PYlWchoiIiCqSTCZDVlZWoevlcnm+ib/KS1paGjQ1NfMtl0gk0NDQQFpRxnL6Dx8fH4iimOfrTUU4AJgzZw6ePXum/Lp27Vqxz0tEVYPhgP5Qt7ICJIV8zJZIoG5lBftfNsL58GHUv/43HLZthcVnn8Ggbx9o2NkpN816GobEo0chW/YNno4eg4ctW+HJoMEInz8fcbt2I+3uXYhv+HtLRJVXqZ+IO3XqFAYPHgxbW1t89dVXaNSoEURRxP3797F+/XoMGTIEx44dQ48ePUp8jtLMePXFF1/AwsICrVq1wtGjR996Lh8fH9SuXTvPMisrqxJnr2kEQYC7kwmO3Y6Ef0gs3u/orOpIREREVEHc3Nywdu1ajBo1Co6OjnnWhYSEYO3atXBzc6uQLFpaWsjIyMi3XKFQICsrC1paWhWSw8DAAAYGBhVyLiJSLYmODuz9/BA2aVLOWHGCkNNN9dV/1S0sYO/nB4mOTs72urrQadUKOq1aKY8hj4tD+p07SLt9G+m37yDtzm1kR0UD2dnIePgQGQ8fImHfbwAAQSqFVsOGr7q1NoGWiwukjo4QCisEElGlUOpC3KJFi9CoUSP8+eef0NfXVy4fNGgQpk2bhrZt22Lx4sUlLsQVZ8argjx58kQ5m1ZRJpno1asXWrduXaKslMPdMacQ93doLBQKERIJJ/cgIiKqCVasWIEePXqgQYMGGDhwIOrVqwdBEPDgwQMcPnwY6urqWL58eYVksbGxwdOnT/Mtz+2SWpQJH8qSr68vfH19SzVjLBFVfprOTqh97CgSjhxBwuHDyI6OgZqZKQwHDIBh//7KIlxh1I2NodehA/Q6dACQMwmO/OVLpN26pSzMpd+5C0VSEsTMTKQFBSEtKAhxr/aX6OlBq3FjaDfNmQhC26UJ1K2tOeEiUSVS6kJcYGAgvvzyyzxFuFz6+vqYNGmSsoBWEsWZ8aoghU1p/yZJSUnQ0dF548C7VDgP55xxYRLSsvDwZRIaWvMuMBERUU3QunVrXLt2DfPnz8exY8ewb98+AICOjg769euHxYsXo3HjxhWSpXnz5jh37hxiY2PzTNjg7++vXF+RvLy84OXlhefPn8Pute5nRFT9SHR0YDx8OIyHDy/1sQRBgIaVFTSsrGDQsycAQFQokPn06b9Pzt26jfT79yFmZECRnIxUf3+kvvpbBwBqpqbQbpLzxJx2UxdoubhA3di41NmIqGRKXYiTSqVISUkpdH1ycjI0NDRKfPzymPHqTXr06IHk5GRIpVJ069YN3333HRo1alRmx68J6lvqw1BbAwlpWfB/EsNCHNH/27vv8KbKvw3g90n3SheddDJllC0gQ4ay9xAEZIhKUUAFrAgiWwGroIggq2z8Ca9SBAFBBQRRNmUj0BYKbeluOtM0Oe8fpZHSlbZpTpven+vKRXPmfXJKefj2Oc9DRFSDNGrUCD/++CM0Gg3i4+MhiiJcXV0hM/CjUiNGjMDnn3+ONWvWYO7cuQDyepasXr0aLi4u6Natm0HzEBHpiyCTwcLfHxb+/rAfMAAAIKpUUN69+9Qjrdeg/PdfQK2GOjER6SdOIP3ECe0xzGrX/u+R1qYBsGzSBCa2Njpn0GRmInX/AaTu/xnqxCSYODvBfsBA2A8ovdcfUU1X4UJc586dsXr1aowcORL169cvsO7u3bv49ttv8eKLL5b7+JUx41VRrK2tMX78eHTr1g329va4dOkSVqxYgQ4dOuD8+fOoV69eiftzavr/yGQCnvdzwm83H+NsZBImdCx7r0QiIiKq3mQyGdzc3Crl2KtXr0ZKSgpSUlIAAMeOHUNubi4AYNq0abC3t0fr1q0xduxYzJ8/H/Hx8dpxho8fP46QkJAiJ3IgIqquBDOzvPHiGjUCnvTE02RlIfvmLWRfu4qsq9eQffUqciIjAQCqR4+gevQIaYcPPzmAAPO6df6bqTUgABbPPQdZEcNAKcMjCo+DFxmJrPMXkLB2LXxCQmBRh/8HJCqOIIqiWJEDXL9+HR06dEB2djYGDhyIhg0bAgBu3bqFAwcOwMrKCn/99Ve5e5XVrVsXdevWxZEjRwqtMzc3x7hx47Bx40adjiUIAgIDA/Hdd9/ptP25c+fQvn17jBo1Cjt27Chx2wULFhQ5NX1UVBS8vLx0Op8x2XgyHEt+uYlatuY49/HLHJOAiIiI9MbPz6/I8d+AvEkh8ieKyMnJweLFi7FlyxbExcWhQYMGmDVrFl577TUDps3z9BhxSUlJNbaNSETSUisUTx5pvaYt0OXGxha9sZkZLBs0gGWzAG2BzszDA+EDBiI3Lg7QaArvI5PB1NUVdQ/+wp5xRMWocCEOAO7cuYPZs2fj119/1T6mamNjg969e+PTTz9FgwYNyn3sJk2aoFatWjjxVDdaIG/GKxMTE0yZMgWrV6/W6VhlLcQBQNeuXXHr1i3EFvfD6YmiesS1bdu2xjayrjxMwcDVfwEAfpvxIuq5Fh5DkIiIiKimyR8jrqa2EYmo6smNj0fW1WvIupo3IUT21atQp6YWvbG5OZCTU+ox3Rct1MsYeUTGqMKPpgJA/fr18X//93/asUAAwMXFRS9jgUg945WPjw9Onz5d6nacmr6gxh5y2FqYIl2ZizMRSSzEERERERERVUGmLi6w694Ndt3zxs4URRGqhw+RffW/R1qzbtyAmJmpUxEOgoDU/ftZiCMqhl4KcfkqYywQqWe8Cg8Ph6ura6WewxiZmsjQ2tcRJ/6Nx5nwJIxp5yt1JCIiIiIiIiqFIAgw9/aGubc35H37AgBEtRrKe/fwYOIbUCcklHwAUYQ6IdEASYmqpzJ3WcsfFLc8yrPviBEjoNFosGbNGu2yoma8SkhIwK1bt5CZmVmubElJSYWW/f777/jrr7/0OitrTdKuTl7h9GxEEvTwBDQRERERERFJQDAxgWWDBjD3882boKEUOY8eIn7NGqgePzZAOqLqpcw94nx8fPDOO+8gMDAQ/v66zYQSHh6OtWvXYv369Ugt7lnzYug649Xq1auxcOFCHDt2DF27dtXuv3379gKPtl68eBFLliwBAIwdOxa+vnk9tTp16oSWLVuiWbNmcHBwwOXLl7Fp0ya4ublhwYIFZcpMedr55xXiYhXZeJCUCV9n3afDJiIioqrvwYMH5drPx8dHz0mqvqcnayAiqq7sBwxE1vkLpW+Yo0LCqm+QsPpb2HbtCocRr8C2c2cIJiaVH5KoiitzIW7btm2YN28egoOD0aZNG/Ts2ROtW7dGnTp14OjoCFEUkZycjIiICJw/fx5Hjx7F+fPn0bhxY2zbtq1cITdu3AhfX19s2bIF3333HRo0aIDt27frNOPVpk2bCkz0cO7cOZw7dw5AXvEtvxA3dOhQHDx4EAcPHkRGRgbc3d0xYcIEzJ8/H7Vr1y5X7pouoLYDLM1kyFZpcCYiiYU4IiIiI+Pn51eumdFrYjEqKCgIQUFB2skaiIiqI/sB/ZGwdm2Js6aaODrC7qXuUBw8BE16OtL/+APpf/wBUw8POAwfBodhw2Dm7m748ERVRLlmTRVFEQcPHkRISAgOHTqE7OzsQo0wURRhaWmJ3r1744033kDfvn3L1VCrrjgjVp4xG//BX3cTMayVF74c0VzqOERERKRHW7ZsKVf7bvz48ZWQpnpgG5GIqjtleAQeTJyI3NjYvMdURVH7p6m7O3xCQmBRxx+azEwoDh1C8u7dyA678t8BZDLYdumS10vuxRfZS45qnHIV4p6mUqlw/vx53Lp1CwlPBm2sVasWGjVqhNatW8PMzEwvQasbNrLyfP3bHaz87V94OVrh1KzuUschIiIikhTbiERkDDSZmUg9cACp+/dDnZAIk1rOsB8wAPb9+0NmbV1o++xbt5Cyew9Sf/4ZmvR07XJTd3c4DB8Oh2FDYebhYchLIJJMhQtxVDQ2svL8E56IV9f/AwD466PuqO1gJXEiIiIiIumwjUhENZkmMxOKw78i5YcfkBUW9t8KmQy2L74IhxEjYPtiZwimZR5Fi6ja4Hc3VaoW3g4wN5EhR63B2YhEDGnJBicREZExi4uLw6ZNm3DhwgWkpKRA88wYQoIg4Pfff5conXQ4WQMRESCztobD0CFwGDoE2bdv/9dLLi0N6cePI/348bxecsOGwWH4MPaSI6PEHnGVhL/t/M+I7/7G2cgkjGrrjaVDm0kdh4iIiCrJjRs30KVLF6Snp6Nhw4a4evUqGjdujOTkZERHR6Nu3brw9vbGH3/8IXVUybCNSERUkCYrK6+X3O7dyLp06b8VMhlsO3eGw8gReWPJsZccGQmZ1AHI+LX1dwIAnAlPkjgJERERVaaPPvoIFhYWuHnzJn777TeIooivv/4aDx8+xM6dO5GcnIzg4GCpYxIRURUis7KCw5DB8Pt+F/x/3gfHsWMhk8sBjQbpJ07g4TtTcLf7S4hftQqq6Gip4xJVGAtxVOna1ckrxIUnZCAuLVviNERERFRZTp48icDAQPj5+UEmy2tm5j+aOmrUKIwcORJBQUFSRiQioirMskEDuH88B/VPHIfHsqWwatUKAJAbF4eENWtx96WX8SAwEGm//w4xN1fitETlw0IcVbpWPo4wkQkAgLMR7BVHRERkrHJycuDxZDwfK6u8CZpSU1O161u0aIFz585Jko2IiKoPmZUVHAYPht+unaiz/2c4jnvSS04UkXHiTzycMhV3u7+EuK+/hurRI6njEpWJ3gtxHHKOnmVjYYqA2vYAWIgjIiIyZr6+voiMjASQV4jz8PDA6dOnteuvXbsGW1tbidIREVF1ZFG/PtznzEH9P0/Ac/kyWLVuDSCvl1zi2u9w9+UeeDBpEtJ++w2iSiVxWqLS6aUQFxUVhW+++QYvv/wybGxs0K1bN6xatUrbECPKfzyV48QREREZr27dumHfvn3a92PGjMGqVavw5ptvYuLEiVizZg0GDRokYULpBAcHw9XVFc2bN5c6ChFRtSSztIT9oEHw27kDdQ7sh9P4cZDZ2+f1kvvzJB5OnZbXS+6rr5DzkL3kqOoq96ypV65cQWhoKEJDQxEWFgY7Ozv06dMHnTp1wunTp3Ho0CGkpqYiICAAgwcPxqBBg9CyZUt956+yOCNWQX/ceoyJW84DAC5+0gNONuYSJyIiIiJ9e/DgAc6dO4d+/frB0tISSqUS06ZNw549e2BqaooBAwZg1apVNbpXHNuIRET6o1EqkXbkCJJ/+AFZ5y/8t0IQYNOxIxxGjoBd164QzMykC0n0jHIV4ho2bIi7d+/C09MTAwcOxKBBg9CtWzeYPfXNnZubi+PHj+Pnn3/Gzz//jAcPHsDHx6fG9JJjI6sgRbYKzRcegSgC68a2Rq8m7lJHIiIiIjI4thGJiCqH8t49pOzeg9TQUKifGp/UxKUWHIYOg8Mrw2HOn7tUBZSrELdkyRL06dMHrZ88m62LS5cu4eeff8b8+fPLerpqiY2swvqtOonr0QpM7OiPeQMaSx2HiIiI9CwjIwOJiYnw8fEpcv2DBw9Qq1YtWFtbGzhZ1cE2IhFR5crrJXcUKbt3I/PpCYIEATYdOuT1kuvWjb3kSDLlfjSVSsZGVmGL9t9AyF8RaFpbjgPTOksdh4iIiPRs0qRJOHfuHC5dulTk+latWqF9+/ZYs2aNgZNVHWwjEhEZjjI8/L9ecikp2uUmtWrBYejQvF5y3t7SBaQaqVyTNWRnZ+Pdd9/Fjh079J2HjFhb/7wJG25EK6DI5mw2RERExubo0aMYMmRIseuHDBmCX3/91YCJiIioJrOoUwduH81CvRPH4fnFF7Bu2xYAoE5IQOL69bjXoyceTHwDil+PFJpxVZOZieQfdiPytddwr09fRL72GpJ/2A1NZqYUl0JGxLQ8O61evRrffvst+vfvr+88ZMTyC3EaEbgQmYxuz7lKnIiIiIj0KSYmBh4eHsWud3d3R3R0tAETERERATILC9j37wf7/v2gDI9Ayv/9H1L37oU6ORkZp08j4/TpvF5yQ4bA4ZXhEHPVeDBxInJjYwFBAEQRiIxE1vkLSFi7Fj4hIbCo4y/1ZVE1Va4ecT/88AN69OiBnj17lrhdcHAw2rRpgxs3bpQrHBkXJxtzNHSzAwD8E5EocRoiIiLSNxcXF1y/fr3Y9devX4eDg4PhAlUhwcHBcHV1RfPmzaWOQkRUo1nU8Yfbh0F5veS+/ALW7doBeNJLbsMG3OvZCxGDByP38eO8HfJH83ryZ25cHB5MnMiecVRu5SrE3bx5E3369Cl1u2nTpiEqKgrff/99eU5DRii/V9yZ8CSJkxAREZG+9e3bF+vXr8fp06cLrfvnn3+wfv169O3bV4Jk0gsKCkJcXBzCwsKkjkJERABk5uaw79cPvlu3oM6hg3B6YyJMHB0BAGJOzn8FuGdpNMiNjUXqgQMGTEvGpFyFOACwsbEpdRtLS0sMHToUR44cKe9pyMi0q5NXiLv2KBUZylyJ0xAREZE+LVy4EM7OznjxxRcxYMAAzJ49G3PmzMGAAQPQqVMnODk5YfHixVLHJCIiKsDC3x9uQXm95Mzr1Cl9B0FA6v79lR+MjFK5xojz8vLCrVu3dNq2WbNm2Lt3b3lOQ0Yov0dcrkbExQfJ6FzfReJEREREpC/u7u44f/48Zs2ahdDQUPzyyy8AALlcjrFjx2Lp0qVwd3eXOCUREVHRZObmxfeEe5ooQp3A4ZaofMrVI65nz57Yvn07MjIySt1WEASkPDVNMNVsrnaWqFMrrzfl2Qg+nkpERGRs3NzcsGXLFiQnJyM2NhYxMTFITk7G5s2bWYQjIqIqz8TZKW+ChpIIAkxqORsmEBmdchXi3n33XaSlpWH48OHILGWAwsuXL5c4exbVPBwnjoiIyPgJggBXV1e4ublBKO0/NERERFWE/YCBpfeKE0WIOTlQp6cbJhQZlXIV4urVq4c1a9bg6NGjaNGiBX766Seo1epC2505cwZbtmxB7969KxyUjEf+OHGXo1KQrSr8fUNERETVw4MHD/DgwYNC70t7ERERVVX2A/rD1N0dkJVcLskOu4KIwUOQeemSgZKRsRBEUZcHoIv2448/YvLkyUhKSoKrqyteeukl+Pr6wtzcHNeuXUNoaCgcHBxw6dIleHl56TN3lffw4UN4e3sjKiqqxl17aR6lZKHjsj8AAP+b1B7t67BLLxERUXUkk8kgCAKysrJgbm6ufV+aon6BW1OwjUhEVPUpwyPwYOJE5MbG5j2mKoraP03d3WHXpzeSt+8AcnMBExPUmvIOagUGQjAxkTo6VQPlmqwh37Bhw9ClSxesXLkS27Ztw65duwqsb9++PTZs2MBGBhVQ28EKXo5WeJichTPhSSzEERERVVMhISEQBAFmZmYF3hMREVVnFnX8UffgL0g9cACp+/dDnZAIk1rOsB8wAPb9+0NmbQ37vn3x6IMPoLr/AAmrvkHGX6dR+/PlMKtdW+r4VMVVqEfcs8LDwxEVFQW1Wo169erBx8dHX4eudvjbzpLN3B2GHy8+RMd6ztj5Znup4xARERFVquDgYAQHB0OtViMpKYltRCIiI6DJyEDsZ58h9cefAAAyOzt4LFwAed++EiejqqxcY8QVp06dOujSpQu6d+9eo4twVLp2TyZsuHA/GTm5GonTEBERkT789ddf+Pbbbwss27VrFxo2bAhXV1e899570Ghq5r/7QUFBiIuLQ1hYmNRRiIhIT2Q2NvD89FPUXrkCMjs7aNLS8GjGTETPngN1eobU8aiK0mshjkhX+RM2ZKs0uPooVeI0REREpA/z58/Hn3/+qX1/+/ZtTJgwATKZDG3atMHq1auxatUqCRMSERHpn7xPH9TZFwqrNq0BAKl79yJi6FBkXbkicTKqiliII0n4OFnDTW4BADgTkShxGiIiItKHa9euoV27dtr333//PaytrXHmzBkcPHgQY8eORUhIiIQJiYiIKoeZpyd8t26Fy/vvASYmUD14gMjRY5Cwbj3EGjxJERXGQhxJQhAEtPPPm6ThbESSxGmIiIhIH1JTU+Ho6Kh9f/jwYfTo0QNyuRwA0KlTJ0REREgVj4iIqFIJJiaoNXky/HbugJm3N5Cbi/iVK/FgwutQxcRIHY+qCBbiSDJtn4wTdz4yGbnqmjleDBERkTHx8PDAjRs3AACxsbE4f/48evbsqV2vUChgamoqVTwiIiKDsGrRAv57f4L9oIEAgMxz5xA+eAgUvx6ROBlVBZVeiOvevTtee+01baOMKF/7J+PEpStzcSNGIXEaIiIiqqihQ4di9erVePfddzFkyBBYWFhg4MCB2vVhYWGoU6eOhAmJiIgMw8TWFp7Ll8Pziy8gs7WFJjUVj957DzGffAJNZqbU8UhClV6IO378OHbt2oVmzZph7NixlX06qkbqutjC2cYcAB9PJSIiMgYLFy7E8OHDsWPHDsTExCAkJARubm4A8nrD/fjjj+jRo4fEKYmIiAzHvn8/+IeGwqplSwBAyp7/Q8TQYci6dl3iZCSVSi/EaTQapKWl4eeff4aHh0dln46qEUEQtI+n/hPOQhwREVF1Z2Njg+3btyMpKQmRkZEYOXKkdp2trS0ePXqExYsXS5iQiIjI8My9asN3+zbUmjoVkMmQExmJyFGjkLhpE0QNh2mqaSpUiFOpVEhNTS11OxsbG/Tt2xeff/55RU5HRqjdk0LcucgkaDSixGmIiIioIg4ePAhNMf+hkMlksLe3h5mZmYFTERERSU8wNYXL1Cnw3bEdZp6egEqFuOAv8OCNN6B6HCd1PDKgchXikpKSMHDgQNjY2MDJyQn16tXDzp079Z2NaoC2T2ZOTc1S4fbjNInTEBERUUX0798fnp6emD59Oi5cuCB1HCIioirHulUr+O8LhbxfPwBA5t//IGLQIKT9/rvEychQylWImz17Ng4cOIAWLVqgX79+SE9Px7hx4/D999/rOx8Zuefc7SC3zJs9jePEERERVW/79+9Ht27dsGHDBrRt2xaNGzfGsmXLEBUVJXU0IiKiKsPEzg6eXwTDc/kyyKytoU5JwcMpUxGzYAE0WVlSx6NKVq5C3KFDhzBy5EicPXsWP//8M+7evYvOnTvj448/1nc+MnIy2X/jxJ2JSJQ4DREREVVEv3798P333yM2NhabNm1C7dq1MXfuXPj7+6N79+7YsmUL0tLYA56IiEgQBNgPGgT/0L2wbN4MAJDyvx8QMfwVZN+8KXE6qkzlKsQ9evQIvXr10r63tbXFggULcP/+fdy7d09v4ahmaPfk8dSzEUkQRY4TR0REVN3Z2tpiwoQJOHr0KKKiorBs2TIkJSXhjTfegLu7O0aPHo1ff/1V6phERESSM/fxgd+OHXCeHAgIAnLu3UPkiJFI3LKFEzkYqXIV4kRRhLm5eYFlzz33HERRRExMjF6CUc2R3yMuIT0H9+IzJE5DRERE+qRSqZCTkwOlUglRFCGXy3Hy5En06dMHzZo1w5UrV6SOaBDBwcFwdXVF8+bNpY5CRERVjGBmBtf334fvtq0w9fCAqFIhbtlyRL01Cbnx8VLHIz0r96yp0dHRUKlU2vf5M2Dl5ORUPBXVKE085bC1yBsnjo+nEhERVX+pqanYsGEDunTpgjp16mDx4sUICAjA/v378fDhQzx48AA///wz0tPT8cYbb0gd1yCCgoIQFxeHsLAwqaMQEVEVZf3886gTuhd2vXsDADL++gvhgwYj7fhxaYORXpW7EDdr1izY2tqiZcuWePPNN7FhwwYIggC1Wq3PfFQDmJrI0NrXEQAnbCAiIqrO9u3bh+HDh8Pd3R2BgYHIycnB6tWrERMTg927d6Nfv34wMTGBIAjo378/5syZw8IUERHRU0zs7VF75Qp4fPopBGtrqJOS8HDy24hdvASa7Gyp45EelKsQd+zYMaxcuRKvvfYaTExMsHPnTsyZMweiKKJPnz6oW7cuBg4ciI8++gjbt2+v8PT1KpUK8+bNg4+PDywtLdGsWTPs2rVLp30/++wzDB48GJ6enhAEAZMnTy5227S0NEybNg3u7u6wsrJC+/btceTIkQplJ91oJ2wI5zhxRERE1dWQIUNw9uxZzJgxA7du3cLff/+NyZMnw8HBocjtmzVrhjFjxhg2JBERURUnCAIchg1FnZ9+hGXTpgCA5J07EfnKCGTf/lfidFRRpuXZqUuXLujSpYv2vVqtxu3btxEWFqZ9nT9/HgcOHACACveUmzRpErZt24YpU6YgICAAoaGhGDNmDHJzczFu3LgS9/3444/h6uqK559/Hr/88kux24miiMGDB+P06dOYMWMGfHx8sHXrVvTt2xdHjhxB9+7dy52fSte+Tl4hLlaRjaikLPg4W0uciIiIiMrqyJEjeOmllyAIgk7bt23bFm3btq3kVERERNWTuZ8f/HbtRPw3q5G4cSOUd+4g8pVX4BoUBMfXxuj87y1VLYJYid2P4uPjcfnyZVy5cgUzZ84s1zEuXbqEVq1aYdGiRfjkk08A5BXNunfvjhs3biAqKqrQxBFPi4iIgL+/P4C8gmBgYCC+++67Qtvt3bsXQ4cOxbZt2zB27FgAgFKpRLNmzWBlZYXLly+XKffDhw/h7e2NqKgoeHl5lWnfmignV4NmC39FtkqDz4c3w4g23lJHIiIiItI7thGJiKg8Mv45g+hZs5D7+DEAwKbLi/D87DOYOjtLnIzKqlyPpiYnJ+u0nYuLC3r06KEtwum639N2794NmUyGKVOmaJcJgoCpU6ciLi4Ox44dK3H//CKcLudxcnLC6NGjtcssLCwwadIkhIWF4fbt22XOTrozN5WhlQ/HiSMiIjIG6enpuH79Ok6ePIk///yz0IuIiIjKxqZ9O/iH7oVdj5cBABkn/kT4oMFIP3lS4mRUVuV6NNXb2xtt2rTB4MGDMWjQoBKLXffv38e+ffsQGhqKU6dOlXlW1QsXLqBu3bpwcnIqsLxdu3YAgIsXL6JXr15lv4giztO6dWuYmJgUe56GDRsWu79CoYBCodC+j4mJqXCmmqatvxNO30vkzKlERETVVHJyMqZNm4bdu3cXOSyJKIqc3IuIiKicTB0dUXvVKqTs2YPHS5dBnZCAqLcmwWn8OLjMmAGZhYXUEUkH5SrEbd26FaGhoVi8eDFmzpyJpk2bYtCgQRgyZAhatmyJy5cva4tvV65cgVwuR58+fbB9+/Yynys6OhoeHh6Flnt6emrX60N0dDReeOGFcp9nxYoVWLhwoV6y1FTt/J0B3EFUUhaiU7Lg6WAldSQiIiIqg8DAQOzduxdTp05Fly5d4OjoKHUkIiIioyIIAhxHjIB1mzZ49MEHUN64iaSt25DxzxnU/vILWNSrJ3VEKkW5CnHDhg3DsGHDoFarceLECezbtw/bt2/Hp59+CisrK2RlZcHLywsDBw5EcHAwunbtClPTcp0KWVlZsCiiqiuTyWBmZoasrKxyHVfX81haWmrXl2TGjBl48803te9jYmI4+HAZtfRxgLmJDDlqDc5GJGFwy9pSRyIiIqIyOHToEN599118+eWXUkchIiIyahZ16sDvf/9D/MqvkLR5M5S3byNi2HC4fTQLDq++yokcqrDyVceeMDExQffu3dG9e3d8/fXXuHz5Mv766y906NABLVu21EtAS0tLKJXKQss1Gg1UKpW2UFZZ58nOztauL4lcLodcLtdLlprK0swEzb3tcS4yGWciElmIIyIiqmYsLCxQv359qWMQERHVCDJzc7jN+hA2nToi5qPZyI2PR+zCRUg/eQoeny6BKXumV0nlmqyhOC1atMCUKVP0VoQD8h4NLWq8tfxHRfMfHa0u56GS5T2eCpzhhA1ERETVzvDhw3Hw4EGpYxAREdUoth07wn9fKGy7dQMApP/xByIGDkLG6dMSJ6Oi6LUQVxlatWqFe/fuISmpYGHmzJkz2vX6Os/FixcLDR6s7/NQydr6503KER6fgbi0bInTEBERUUni4uIKvD744APExMRg7Nix+PvvvxETE1Nom7i4OKljExERGR1TJyd4rfkW7vPnQbCwQG58PB5MfAOPPw+GWMZJM6lyVflC3IgRI6DRaLBmzRrtMlEUsXr1ari4uKDbk4pvQkICbt26hczMzHKfJzExEd9//712mVKpxPr16xEQEIDnnnuuYhdCOmnt6wgTWd6z7OcikiVOQ0RERCVxd3eHh4eH9tWwYUNcuHABO3fuRKdOneDl5VVgff6LiIiI9E8QBDiOGgX//9sDi4YNAQBJISGIePVVKMPDJU5H+So0RpwhtG7dGmPHjsX8+fMRHx+PgIAAhIaG4vjx4wgJCdFOsLB69WosXLgQx44dQ9euXbX7b9++Hffv39e+v3jxIpYsWQIAGDt2LHx9fQEAQ4cORZcuXfDWW2/h1q1b8PHxwdatW3Hv3j0cPnzYcBdcw9lYmKJpbXuERaXgTEQi+jVjY52IiKiqmjdvXo0ZDFqpVOLtt9/Gb7/9hpSUFDRu3BgrV67ECy+8IHU0IiKiAizq14ff7h8Qv2IFkrZug/LGTUQMHQa3ObPh8MorNebf7qqqyhfiAGDjxo3w9fXFli1b8N1336FBgwbYvn07XnvttVL33bRpE06cOKF9f+7cOZw7dw4A0KlTJ20hThAE/Pzzz5gzZw42bNgAhUKBgIAAHDhwAC+//HLlXBgVqb2/U14hLpzjxBEREVVlCxYskDqCweTm5sLPzw+nTp2Cl5cXtm/fjgEDBuDBgwewtraWOh4REVEBMgsLuM2eDZtOnRD90WyoExMRO28+Mk6ehPuiRZzIQUKCKIqi1CGM0cOHD+Ht7Y2oqCh4eXlJHada+f3mY7yx9TwA4NInPeBoYy5xIiIiIqLCnJ2d8ccff6B58+Y678M2IhERGVpuQgKi58xBxp8nAQCmbm7wXL4cNu3bQZOZidT9B5C6/2eoE5Ng4uwE+wEDYT+gP2T8RVOlqPJjxFHN08bPCfk9Zc9GslccERFRdXPw4EG888476NevH/r164d33nlHb7OppqenY/78+ejbty9cXFwgCAKWLVtW5LYqlQrz5s2Dj48PLC0t0axZM+zatUsvOfLHJq5Tp45ejkdERFRZTGvVgve6dXD7+GMI5ubIffwYD15/HTGfzMO9Pn0RO38+si5cRE5EBLIuXETs/Pm417cflOERUkc3SizEUZVjb2WGxh5yAMDZCBbiiIiIqguFQoEePXpgwIABWL9+PS5fvoxLly5h/fr1GDBgAF566SUoFIoKnSMhIQGLFi3C1atX0bJlyxK3nTRpEj799FMMHjwY33zzDby9vTFmzBhs27atQhkyMzMxduxYzJ07F3Z2dhU6FhERkSEIggCnsa/Bb88eWNSvB4giUvbsQe7jx3kb5D8s+eTP3Lg4PJg4EZpyTohJxWMhjqqk5/3ynlffd/kRZv90FUsP3cSlB8ngk9RERERV1/Tp0/HHH39gyZIlSE5OxqNHjxAdHY3k5GQsXrwYx48fx/Tp0yt0Dg8PDzx69AhRUVFYv359sdtdunQJW7ZswYIFC7Bq1Sq89dZbOHDgALp27YqgoCDk5ORot+3ZsycsLS2LfL399tsFjqtSqTB8+HA0btwYc+bMqdC1EBERGZplwwbw27MH1m3blryhRoPc2FikHjhgmGA1CAtxVOXcjk3D7zfjAAAJ6Tn44dwDrDsRjiFrTqP/N6dwOzZN4oRERERUlJ9++gmBgYGYPXt2gZ5idnZ2mDNnDiZNmoSffvqpQuewsLCAp6dnqdvt3r0bMpkMU6ZM0S4TBAFTp05FXFwcjh07pl1+5MgRZGdnF/lau3atdjuNRoNx48bBzMwMmzZt4qxzRERULcksLSFq1EBp/44JAlL37zdMqBqEhTiqUm7HpmHY2tN4lJKlXaZ5qhPczRgFhq09zWIcERFRFRUQEFCudfp24cIF1K1bF05OTgWWt2vXDgBw8eLFMh8zMDAQMTEx+OGHH2BqaqqXnERERFJQJyb99zhqcUQR6oREwwSqQViIoypDFEXM2H0ZmTm5BYpvT9OIQGZOLmbsvszHVImIiKqYvn374kAJj7AcOHAAffv2NUiW6OhoeHh4FFqe35suOjq6TMe7f/8+Nm7ciDNnzqBWrVqwtbWFra0tdu7cWeJ+CoUCDx8+1L5iYmLKdF4iIqLKYOLspFOPOJNazoYJVIPwV3lUZVyOSsH16NIHcNaIwPVoBcIepqKFt0PlByMiIiKdzJ07F6+++ir69++PqVOnol69ehAEAf/++y9Wr16N6OhofPnll4iLiyuwn6urq96zZGVlwcLCotBymUwGMzMzZGVlFbFX8Xx9fcv1S8AVK1Zg4cKFZd6PiIioMtkPGIis8xdK3kgUYT9ggGEC1SAsxFGVcfh6bNm2vxbLQhwREVEV0qRJEwDA1atXcejQoQLr8otYTZs2LbSfWq3WexZLS0solcpCyzUaDVQqFSwtLfV+zqLMmDEDb775pvZ9TEwM2pY2QDYREVElsx/QHwlr1yI3Lg7QaApvIAgwdXODff/+hg9n5FiIoypDkZULmYBiH0t9mkwAUrNUlR+KiIiIdDZv3rwqM4GBp6cn7t+/X2h5/iOpukz4oA9yuRxyudwg5yIiItKVzNoaPiEheDBxInJjY/MeU32657cownXGdMisraULaaRYiKMqQ25lqlMRDsgr1tlbmVVuICIiIiqTBQsWSB1Bq1WrVvjjjz+QlJRUYMKGM2fOaNcbUnBwMIKDgyul9x8REVF5WNTxR92DvyD1wAGk7t8PdUIiTBwdkXP/PtSJiYj7cgVsOneGqaOj1FGNCidroCqjdxP3Mm3fuX6tSkpCREREukhJSZFkX12MGDECGo0Ga9as0S4TRRGrV6+Gi4sLunXrVqnnf1ZQUBDi4uIQFhZm0PMSERGVRGZtDccRI+C3fTvqHjoIv1074ROyCYK5OXIfP0bMR7M5UaKesUccVRktvB3QxFOOmzEKnXrGzdx9GcuHN0eXBi6VH46IiIgK8fHxwTvvvIPAwED4+/vrtE94eDjWrl2L9evXIzU1tVznXb16NVJSUrTFvGPHjiE3NxcAMG3aNNjb26N169YYO3Ys5s+fj/j4eAQEBCA0NBTHjx9HSEhIkRM5EBEREWDZsCHc5sxG7IKFSD9xAklbtsL59QlSxzIagsjSZqV4+PAhvL29ERUVBS8vL6njVBu3Y9MwbO1pZObkFlmMkwmAqUwGjahB7pPxJEe28cbH/RtBbslHVYmIiAwpNDQU8+bNw/Xr19GmTRv07NkTrVu3Rp06deDo6AhRFJGcnIyIiAicP38eR48exfnz59G4cWMsWbIEgwYNKtd5/fz8ihz/DQAiIiLg5+cHAMjJycHixYuxZcsWxMXFoUGDBpg1axZee+218l5yhbGNSERE1YEoinj0/nSk/forYGYGv107YRUQIHUso8BCXCVhI6v8bsemYcbuy7gerQCAAhM4NPGUY+XIFtCIIj7YE4Zrj/K28bC3xLJhzdg7joiIyMBEUcTBgwcREhKCQ4cOITs7u9CEDaIowtLSEr1798Ybb7yBvn37VplJHQzl6THikpKS2EYkIqIqT61QIGLIUKgePYKZlxf89/4EEzs7qWNVeyzEVRIW4ipGFEWEPUzF4WuxSM1Swd7KDL2buqO5l7224a5Sa7DuxD18/fsdqNR538bsHUdERCQdlUqF8+fP49atW0hISAAA1KpVC40aNULr1q1hZsZ/n9lGJCKi6iTryhVEjh4D5ObCrndv1F65osb9Mk3fWIirJGxkGc6tWAV7xxEREVG1wDYiERFVN4khmxH3+ecAAPcFC+D46kiJE1VvnDWVqr3n3OXY+05HfNCzAcxMBMSkZmN8yFnM+r8rUGSrpI5HREREREREVG05TRgPmy4vAgAeL12K7Nv/SpyoemMhjoyCmYkMU7vXx4FpnRFQ2x4A8MP5KPRa+SdO/BsvcToiIiqPrl274oMPPpA6RpGqcjaqmoKDg+Hq6ormzZtLHYWIiKhMBJkMnsuWwdTVFaJSiUfTp0OTmSl1rGqLhTgyKg3d7fDTOx3YO46ISM/OnDmDV155BR4eHrC0tESDBg0wZcoUPHjwQOpola6oottPP/2E+fPnS5SIqqOgoCDExcUhLCxM6ihERERlZuroCM8vggGZDDnh4YhdvETqSNUWC3FkdNg7johIv7799lt07NgR7u7u2LdvH27fvo1Vq1bh9u3b+PHHH6WOJwknJyfYcdYwIiIiqkFs2rZFrSnvAABS9+5F6s8/S5yoemIhjowWe8cREVXcsWPHMG3aNGzYsAHffPMN2rZtC19fX/Tu3RtHjx7F66+/XuR+WVlZmDp1KlxcXGBpaYlu3brh6tWr2vVdu3bF9OnTMX36dDg4OMDLywvffvttsTl27twJV1dXqFQFf36/9NJLmDZtWrH76ZLjvffewzvvvAN7e3u4urpi6dKl2vUTJkzAiRMn8OWXX0IQBAiCgMjIyEK95Lp27Yr3338fU6dOhb29PTw8PLBjxw4kJyfjlVdega2tLZo0aYIzZ85o9/Hz88Pq1asL5K1Vqxa2bNlSoeMSERERVZZakyfDul07AEDMgoVQRkRInKj6YSGOjFpJveOO346TOB0RUdU3c+ZM9OzZs8iCmyAIcHBwKHK/Dz/8EPv27cPOnTtx/vx5uLq6onfv3sh8ajyRkJAQuLm54dy5c5g+fTqmTZuGmzdvFnm8YcOGQaVS4ZdfftEue/DgAY4fP15sMbAsOWxtbXH27FksX74cCxcuxJ49ewAAX3/9NV544QW8/fbbiImJQUxMDLy9vYs8V0hICLy8vHD+/Hm8/vrrePPNNzFmzBgMGTIEly5dQuPGjTFu3DiUdcL6yjouERERUVkJJibw/PxzmDg5QczMxKMZM6FRKqWOVa2wEEc1QlG94yZsPocP/y+MveOIiIpx7do1XLp0CW+//XaZ9ktPT8e6devw5ZdfomfPnmjatCk2b94MpVKJnTt3ardr3bo1PvroI9SvXx8zZ86Eu7s7Tpw4UeQxLS0tMXr0aGzevFm7bOvWrWjatClatWpVoRx16tTB559/joYNG2oLXStXrgQA2Nvbw9zcHNbW1nB3d4e7uztMTEyKPN/T1zN//nzk5uaiQYMGGD16NOrXr49Zs2bh33//xaNHj8r0eVbWcSvLb7/9hmXLlmHVqlX49ddfERsbK3WkKoGTNRARkbEwc3OF5/JlAADlzZuIW/65xImqFxbiqMYoqnfc7vMP2TuOiKgYly9fBpBXCCrOzp07YWtrq32dPHkS9+7dg0qlQseOHbXbWVtbo2XLlgV6vAUEBBQ4loeHB+Liiv95PHHiRBw8eFC7zbZt20rsDadrjnZPHq/I98ILLxTbM68kzZo1035tYWEBe3t7NGnSRLvMzc0NAEq8RkMetzLMnz8fvXr1wpw5c/D++++jT58+qF27Ntzc3NCzZ08EBQVhx44dUseUBCdrICIiY2LbuTOc33wDAJC8axcUR45InKj6YCGOahz2jiMi0k1WVhYAwNbWtthtBg4ciMuXL2tfbdq00a4TBKHAtqIoFlhmZmZWYL0gCNBoNMWeq3Xr1mjSpAl27NiBkydP4v79+xgzZkyp11FajmfXF7esNEVdz9PL8o+Zf40ymazQ46TPjoFXnuNKad26dejcuTMeP36M5ORknD59GmvWrMGwYcOQkZGBdevWYfz48VLHJCIiIj1wee89WD3p6R0z9xPkPKwavfOrOhbiqEZi7zgiotI1bdoUAHDy5MlC60RRRHZ2Nuzs7FCvXj3ty8rKCnXr1oWZmRlOnTql3T4rKwuXL19Go0aNKpTpjTfewObNm7Flyxb0798fLi4uxW6ra45nJzr4559/8Nxzz2nfm5ubQ61WVyh3UVxcXAo8thkZGQmFQqH38xhSVlYWxowZAxcXF9jb26N9+/YIDAzEmjVr8Ndff0GhUODff/+VOiYRERHpgWBmBs8vv4RMLodGoUD0zJkQi/ilIhXEQhzVaOwdR0RUvBdeeAEvvfQSAgMDsWvXLty7dw+3bt3Cli1b8PzzzxdbNLK1tUVgYCBmzpyJI0eO4Pr165gwYQLMzc0xevToCmUaM2YM7ty5g+3bt5f4WGpZcty7dw8fffQR/v33X2zduhUbN27Ee++9p13v5+eHf/75B/fv30dCQoLeep517doV27Ztw99//41Lly4hMDAQ5ubmejm2VLp27YqIUmZPq1u3roHSEBERUWUz96oNjyWLAQBZYWGI//priRNVfaZSByCSWn7vuB6N3fHBnjBcfZSK3ecf4uSdBCwdGoCuDV2ljkhEJJkDBw7giy++wGeffYbw8HBYW1ujQYMGGDVqFFxdi//5+Pnnn0MURYwZMwZpaWlo3749Dh8+DGtr6wrlcXJywqBBg3DixAn06dOn1O11yTFx4kQkJSWhTZs2sLCwwNy5czFy5Ejt+g8++ADjx49Ho0aNkJWVVWqhSVezZ8/G3bt30bNnT7i5uWHFihW4cOGCXo4tlU8//RS9evVCYGAgfH19pY5DREREBiDv2ROZo0cjedcuJG7cBOt27WDbubPUsaosQeRc95Xi4cOH8Pb2RlRUFLy8vKSOQzpSqTVYd+Ievv79DlTqvL8aI9p4YW7/xpBbmpWyNxERGcKLL76I9u3b4/PPKz5DV9euXdGmTRt88cUXekhGL7zwAtLT0xEdHY25c+di0KBBqFOnjtSxqoTg4GAEBwdDrVYjKSmJbUQiIjIqGqUSka+OgvLmTZg4OcF/716YubFTS1H4aCrRUzh2HBFR1ZWUlIQdO3bg9OnTePvtt6WOQ0VwcHBAcnIykpOTMXPmTNSvXx/Ozs7o3r07pk+fji1btuDSpUtSx5QEZ00lIiJjJrOwQO0VX0KwtoY6KQnRH34IsRLG2DUGfDSVqAj5Y8fl947LHzuOveOIiKTTqlUrpKSkYOXKlfD395c6DhXh0KFDAIDk5GRcuXIFV65cwdWrV3HlyhVs3LgRGRkZEAShUia/ICIiImlZ+PvDY/48RM/6CJlnziDhu+/gMmWK1LGqHD6aWkn4aKrxuB2bph07DgA87C05dhwREVE53Lt3D1evXsXgwYOljiIZthGJiMjYRc+eg9S9ewGZDD5bNsOmbVupI1UpfDSVqBT5veOCejXkzKpEREQVULdu3RpdhCMiIqoJ3D+ZC/M6dQCNBtEfBCE3KUnqSFUKC3FEOjAzkWFKt3pFjh13jGPHEREREREREQEAZNbWqL1yBQRzc+TGxSF69myIGo3UsaoMFuKIyqChux32PtM77vUnveNSs9g7joiIiIiIiMiyYUO4zZkNAMg48SeStmyVOFHVwUIcURmZsnccERGR1sKFC3HhwgWpYxAREVEV4zByJOx69wYAxK1YgawrVyROVDWwEEdUTs/2jotVsHccERHVPJs3b0bbtm3h4eGBiRMnYu/evUhPT5c6VpUTHBwMV1dXNG/eXOooREREBiEIAjwWL4KZlxeQm4tHM2ZCrVBIHUtyLMQRVQB7xxERVQ9paWnYv38/1Gq11FGMTmRkJC5fvoz33nsPd+7cwYgRI1CrVi28/PLL+Prrr3Hnzh2pI1YJQUFBiIuLQ1hYmNRRiIiIDMbEzg61V3wJmJpC9fAhYj6ZB1EUpY4lKRbiiPSAveOIiKq2W7du4cKFC1Dwt7CVIiAgAB999BFOnjyJuLg4hISEwM3NDYsXL8Zzzz2HBg0aYMaMGfjjjz+Qm5srdVwiIiIyIKtmzeA6YwYAIO3XX5Hyww8SJ5IWC3FEesLecUREVVd+Ac7Ozk7iJMbP0dERo0ePxs6dOxEXF4cTJ05g2LBh+O233/Dyyy/D2dkZw4cPx99//y11VCIiIjIQpwnjYdulCwDg8WdLkX37tsSJpFMtCnEqlQrz5s2Dj48PLC0t0axZM+zatUunfUVRxNdff40GDRrAwsICDRo0wKpVqwp1hdyyZQsEQSjydffu3cq4LDJS7B1HRFT1KBQK2NrawtTUVOooNYpMJkOnTp2wdOlSXLlyBffv38fSpUuRnZ2Nv/76S+p4REREZCCCTAaPZUth6uYGMScHj6bPgCYjQ+pYkqgWrdFJkyZh27ZtmDJlCgICAhAaGooxY8YgNzcX48aNK3HfRYsWYcGCBRg7diw+/PBDnDhxAu+99x5SUlIwb968QtsvWLAAdevWLbDM3d1dr9dDxi+/d9zLjdzwwZ4wXH2Uit3nH+LPfxOwdFgAujV01W4riiIuR6Xg8PVYKLJyIbcyRe8m7mjh7QBBECS8CiIi46FQKCCXy6WOUeN5e3vjnXfewTvvvCN1FCIiIjIwU0dH1P4iGPfHT0BOeDhiFy+B57KlUscyOEGs4qPkXbp0Ca1atcKiRYvwySefAMgrXHTv3h03btxAVFQUzM3Ni9w3NjYWfn5+GDNmDDZt2qRdPmHCBPzvf/9DZGSktsi2ZcsWvP766/j777/Rvn37Cud++PAhvL29ERUVBS8vrwofj6qvXLUG6/4Mx1e//QuVOu+v24g2Xvi4X2PEpmZjxu7LuB6d98iUTAA0T/5GNvGUY8WIFmjozseoiIgq6ptvvoGLiwteffVVqaNQDcc2IhER1XTx336LhG9WAwA8ly+D/aBBEicyrCr/aOru3bshk8kwZcoU7TJBEDB16lTExcXh2LFjxe67b98+KJVKTJs2rcDyadOmQalUYt++fUXul5aWxlnVSG+KGzuu+xfHMejbU7gZ89/A4ZqnyuI3YxQYtvY0bsemGToyEZFREUWRPeIM5O2338b9+/eljkFERERVWK3Jk2Hdrh0AIGbhIijDIyROZFhVvhB34cIF1K1bF05OTgWWt3ty0y5evFjivhYWFmjWrFmB5S1btoS5uXmR+/bo0QNyuRzW1tbo27cvbty4oYerICo4dpypDEjMyEG2SlOg+PY0jQhk5uRixu7LNX56ZyKiisjKyoJKpWIhzgDu3LmDBg0aYMKECbh161ah9WfPnsXSpTXvERQiIiL6j2BiAs/gz2Hi5AQxMxOPZsyARqmUOpbBVPlCXHR0NDw8PAot9/T01K4vaV83NzfIZAUvUyaTwc3NrcC+1tbWGD9+PFavXo29e/fio48+wsmTJ9GhQwedJmtQKBR4+PCh9hUTE6PrJVINkt877stXWui0vUYErkcrEPYwtXKDEREZsfwZU1mIq3y//fYbfv/9d1y6dAlNmzbFyJEjcfLkSVy/fh3nzp3D8uXLsXjxYqljSiI4OBiurq5o3ry51FGIiIgkZ+bqCs/lywAAylu3ELd8ucSJDKfKT9aQlZUFCwuLQstlMhnMzMyQlZVV5n0BwNLSssC+I0aMwIgRI7TvBw8ejP79+6N9+/ZYsGABduzYUWLOFStWYOHChaVdDhEA4EasovSNnnL4WixaeDtUThgiIiPHQpzh7N27F6+//rr2M9+zZw/27NlTYPKh6dOnSxVPUkFBQQgKCtKOEUdERFTT2XbuDOc330Dixk1I3vU9rNu1h7xXT6ljVboq3yPO0tISyiK6KGo0GqhUKlhaWpZ5XwDIzs4ucV8AeP7559G5c2f89ttvpeacMWMGoqKitK+zZ8+Wug/VXIqsXMh0nBBVAHDncRqUuRy3kIioPFiIM5wPP/wQfn5+OHbsGCIjIxEZGYmzZ89i0KBBEEURI0aMQHBwsNQxiYiIqIpwee89WD3pLR4zdy5yHj6SOFHlq/KFOE9PzyIf88x/rDT/EdXi9n38+DE0Gk2B5RqNBo8fPy5x33w+Pj5ISkoqdTu5XA4vLy/tq6jHaYnyya1Mix0b7lkigN9vxaHZgiMYveEffPP7HVy4nwSVWlPqvkRExEKcIT169AiTJ09Gly5d4OPjAx8fH7Rp0wY//fQTDhw4gF9++QVBQUFSxyQiIqIqQjAzQ+0VX0Iml0OTloZHM2dAVKmkjlWpqnwhrlWrVrh3716hYtiZM2e060vaV6lU4sqVKwWWX7p0CTk5OSXumy88PByurq7lSE5UvN5N3Mu8jzJXg9P3EvHl0X8xbO3faL7wCMaFnMV3J+4hLCoFuSzMEREVSaFQwMbGBqamVX5EjmqvadOmOH36dJHr+vbti1mzZiEkJMTAqYiIiKgqM6tdGx5L8saQzQ67grivvpI2UCWr8oW4ESNGQKPRYM2aNdploihi9erVcHFxQbdu3QAACQkJuHXrFjIzM7XbDRo0CObm5li9enWBY37zzTcwNzfHoEGDtMuK6vX2+++/46+//kKvXr30fVlUw7XwdkATT3mpj6fKBKBpbTlOftgVwcObYWir2vCwz3ukOjNHjT//jceyQ7cw6Nu/0HLRUby59Rw2ngzH9ehUaHTtckdEZOQUCgV7wxnIrFmzsGPHDsycORNpaWmF1mdlZSE7O1uCZERERFSVyXv2hOPo0QCApE0hSP/zT4kTVZ4q/6vh1q1bY+zYsZg/fz7i4+MREBCA0NBQHD9+HCEhIdrJGFavXo2FCxfi2LFj6Nq1K4C8R1NnzZqFxYsXQ6VS4cUXX8SJEyewfft2zJs3r8Djo506dULLli3RrFkzODg44PLly9i0aRPc3NywYMECCa6cjJkgCFgxogWGrT2NzJzcIh9TlQmAtbkpVoxoAW8nG3g72eCVNt4QRRH3EzPxd3gi/r6XiNP3EpGQrkSaMhe/3YzDbzfjAAAO1mZo7++MF+rmveq72hYYLJuIqKZQKBRwdnaWOkaNMGzYMMybNw+LFi1CSEgIevXqhdatW8PR0RFXr17Fd999h7Zt20odk4iIiKog11kfIvPSJShv3kT0rI/gHxoKMzfje0KxyhfiAGDjxo3w9fXFli1b8N1336FBgwbYvn07XnvttVL3XbhwIRwdHfHtt9/if//7H7y9vbFixQq8//77BbYbOnQoDh48iIMHDyIjIwPu7u6YMGEC5s+fj9q1a1fSlVFN1tDdDj++3QEzdl/G9ei88YtkArRFuUYecqwc2QIN3OwK7CcIAvxq2cCvlg1GtfWBKIq4F5+Ov+8laotzyZkqpGSqcPh6LA5fjwUA1LK1QPs6TnmFuTrO8K9lw8IcERk9URShUCjg7+8vdZQaY8GCBejVqxdWrlyJw4cPY/fu3dp19erVw9q1ayVMR0RERFWVzMICtVd8iYhhw6FOTkZ0UBB8NodAMDGROppeCaIo8vm1SpA/NX1UVBS8vLykjkNVmCiKCHuYisPXYpGapYK9lRl6N3VHcy/7chXKNBoRtx+naXvLnYlIRFp2bqHt3OWW2qLcC3Wd4e1krY/LISKqUrKysrB8+XK89NJL6Ny5s9RxahyNRoPw8HAkJCTAwcEBzz33nNSRJMc2IhERUclSf/4Z0R/OAgDUmjoVLlOnSJxIv6pFjzgiYyYIAlp4O6CFt4NejieTCWjkIUcjDzkmdvKHWiPiRrQCp+8l4O/wRJyLSEJGjhqximzsvfQIey/lTQ/t5WiFF+o4o0M9Z7xQpxbcn4xFR0RUnXHGVGnJZDLUq1cP9erVkzoKERERVRP2Awci4+9/kLp3LxLWrIF12+dhY0RDW7AQR2TkTGQCArzsEeBlj8AudaFSa3D1UWreo6z3EnH+fhKyVRo8TM7CngsPsefCQwCAfy0btK/jjA51ndG+jjNc7CwkvhIiorJjIY6IiIio+nH/ZC6ywsKQEx6O6A+C4B+6F6ZOTlLH0gsW4ohqGDMTGVr5OKKVjyOmdKsHZa4alx+kaMeXu/QgBTlqDSISMhCRkIHvzz4AANR3tUWHJxM/tPN3hqONeZnOK4oiLkel4PD1WCiyciG3MkXvJu5o4e3AseqIqNKwEEdERERU/cisrVF75UpEjhiB3Lg4RH/0Eby/+w6CTCZ1tApjIY6ohrMwNUG7Os5oV8cZ778MZKvUuHA/WTv5Q1hUCnI1Iu7EpeNOXDq2/n0fggA0cpdrx5hrW8cJckuzYs9xOzatyEkp1p0IRxNPOVaMaIGG7nbF7k9EVF4sxBERERFVT5YNG8Bt9mzELliAjD9PImnzFji/MVHqWBXGQhwRFWBpZoKO9WqhY71aAIAMZS7ORSbh7/BE/HMvEVcfpUIjAjdiFLgRo8CmUxGQCUBAbXu0f1KYe97PCTYWeT9ebsemYdja08jM+W/CCM1TU8TcjFFg2NrT+PHtDizGEZHeKRQKWFtbw8ys+F8WEBEREVHV5DByBDL++Qdphw8jbuVKWLdpDavmzaWOVSGcNbWScEYsMlapWSqci8grzJ2+l4ibMYpC25jKBDT3dkB7fyf8cjUGD5IyCxTfniUTgEYechyY1omPqRKRXm3fvh0ZGRmYPHmy1FGohgsODkZwcDDUajWSkpLYRiQiItKROi0NEUOGQvXwIcxq14b/3p9gUo2fdmAhrpKwEEc1RXJGDs5E5I0vd/peIu7EpZf7WKFTOupt9lgiIgD49ttv4ejoiNGjR0sdhQgA24hERETlkXX1KiJHjwFUKtj16oXaX62stp04+GgqEVWIo405ejf1QO+mHgCA+DQl/nnSW+7g1RikZql0Ptae81EIqG0PE1n1/IFKRFWPQqGAr6+v1DGIiIiIqAKsAgLgOmMG4pYvR9qvvyLlf/+D46hRUscqFxbiiEivXOwsMKC5JwY09wQA/HDuQYmPpT5t55kH2HP+IXycreHnbA0/Zxv41bJ58qc1PO2tIGORjoh0lJ2dDaVSyYkaiIiIiIyA04TxyPznH6SfOIHHS5fBqmVLWD73nNSxyoyFOCKqNHIrU52LcPly1BrcjUvH3SIecTU3lcHHKa9A51/LGr7ONvCvlVes85BbskhHRAVwxlQiIiIi4yEIAjyWLUXE4CHIffwYj6bPgP//7YHMxkbqaGXCQhwRVZreTdyx7kS4ztt/+UpzmJvKcD8xAxEJmYhMzMD9xAwkpOcAAHJySy7S+TpZP+lBl/env7MNfFmkI6qxWIgjIiIiMi6mjo6o/UUw7o+fgJyICMQuXgLPZUuljlUmLMQRUaVp4e2AJp5y3IxRlDpramNPOYa2ql3kgJuKbBXuPynMRSZkIDLxv68TM/4r0t2JSy9ysggLUxl8nZ/qQffkUVc/Zxu4V0KRThRFXI5KweHrsVBk5UJuZYreTdzRwtuh2g4oSlQdsRBHREREZHysn38etaZOQcKqb5AaGgrr9u3gMHiw1LF0xkIcEVUaQRCwYkQLDFt7Gpk5uUUW42QCYG1uihUjWhRbpJJbmiHAyx4BXvaF1uUX6SISM3A/IQMRTwp09xMztUU6Za4G/z5Ox7+Piy7S+TnbwNfZWvuYa/7XbnZlL9Ldjk3DjN2XcT1aob0+jQisOxGOJp5yrBjRAg3d7cp0TCIqHxbiiIiIiIxTrcBAZJ49h8x//kHsosWwatYcFnX8pY6lExbiiKhSNXS3w49vdyiyOAUAjTzkWDmyBRq4la84VVKRLjVLhfuJT3rQJeT3pst7n/RUke724zTcfpxWaH9LMxl8nZ70nsvvSfekV52rnUWhIt3t2DRt0THf08XHmzEKDFt7Gj++3YHFOCIDUCgUsLKygrm5udRRiIiIiEiPBBMTeH6+HBGDh0CdlIRH06fDb/cPkFlYSB2tVCzEEVGla+huhwPTOiHsYSoOX4tFapYK9lZm6N3UHc297CvtcU17KzM083JAMy+HQuvyi3QRCRmITMjM+/pJb7rkTBUAIFtVcpEuvzDnW8safk7WWPdneLE9/4C8olxmTi5m7L6MA9M68TFVokqmUCjYG46IiIjISJm5usJz+XJEvfUWlLdvI275crjPmyd1rFKxEEdEBiEIAlp4O6CFt4PUUQCUUqTLVD3pOZdXpIt8UrC7n1iwSHcrNg23YgsX6UqiEYHr0QpcjkpBSx9HfVwKERVDoVDA3r5wb1kiIiIiMg62nTvB+a03kbhhI5J3fQ/rdu0h79VT6lglYiGOiOgZ9tZmaG7tgOZFFA1TM1V549Fpe9PlPep6M0YBZa5G53O88t3fqO1oBVc7C7jaWcJV/uRPOwu4yi3gJs/72t7KrEr3nOPEFFTVXLlyBSqVCq1bt4ZCoYC3tzcAQKVSwczMTOJ0RERERKRvLu++i8xz55F1+TJi5s6FZZPGMPfykjpWsViIIyIqA3trM7SwLtyzb/ZPV/HDuQclzg77tFyNiPuJmbifmFniduamMrjYWsBN/nTBzgKuTwp1rnaWcJNbwNHaXO+zv5aGE1NQVRQZGYmrV6+iQYMGyM7Ohlwux7Vr1xAaGop3332Xj6oSERERGRnBzAy1v/wC4UOGQqNQ4NHMmfDbsQNCFf0lLAtxRER6ILcy1bkIBwAvN3JF+zrOiEtTIk6RjccKJeLSshGXpkRa9n+TPeTkavAoJQuPUrJKPJ6pTICLXeEiXV7vuv962znbWsBEDwU7TkxBVVXTpk1x8eJFnDt3DgBgZmaGgwcPwsXFBba2thKnIyIiIqLKYFa7Njw+XYJH095FdtgVxH31FdyCgqSOVSQW4oiI9KB3E3esOxGu8/ZTu9cvdry8rBy1tigX96RAl1+oi09T4rEib13Kk/HqgLwedjGp2YhJzQaQWux5ZQJQy9bimUdh8wt3eV+7yS1Qy9YCZiayIo8hiiJm7L7MiSmoSvL394e9vT0uXLgAALh06RKys7Mxbtw4yGRFf08T6eqVV17Bn3/+iaysLPj5+eGzzz5D//79pY5FREREAOQ9eiBzzBgk79yJpE0hsGnXDrYvvih1rEJYiCMi0oMW3g5o4inHzRhFiT3jZALQ2FOO5l7FDyBvZW4CX2cb+DrblHhOZa76SWFOifinCnf5hbq4tLzlCek52n00IrTrAEWxxxYEwMnaHC52/41Xl1+8y8zJ1T6OWpL8iSnCHqZWmUk6yotj4VUfgiCgWbNm2LNnD9RqNTIyMtCrVy+4u7tLHY2MwIIFC1C/fn2Ym5vj7Nmz6NGjB8LDw+Hs7Cx1NCIiIgLg+mEQMi9ehPLmTUTP+gj+oXth5uYmdawCWIgjItIDQRCwYkQL7eOaRRXjZAJgbW6KFSNa6KV4Y2FqAi9Ha3g5Wpe4nUqtQUJ60UU67SOxCiUS0pXa3KIIJGbkIDEjp8wzwz7rq9/+xai2PpBbmkFuZQq5pRnsrc1ga25q8HHtyqOmjYVnDEXH5s2bY/v27YiMjMSgQYPQpUsXqSORkWjSpIn2a5lMhpycHDx69IiFOCIioipCZmGB2iu+RMSw4VAnJyM66EP4bA6BYGIidTQtFuKIiPSkobsdfny7Q5FFGwBo5CHHypEt0MDNsEUbMxMZPOyt4GFvVeJ2ao2IxHTlk0Jd9pPCnfKpx2Tz/oxNzUYZhsPD8dvxOH47vtByQQDsLEwhtzKDvZVZwUKdlRnkVmaQWz61/qlt7K3MYGVmUumFoZo2Fp4xFB1FUURUpgn+TcpFdFI2YhwDcC0mvVoVEqlk6enpCA4Oxrlz53Du3DkkJCRg6dKl+Oijjwptq1KpsHjxYmzZsgVxcXFo0KABPvroI4wePbrc5x8zZgx+/PFHKJVK9O3bFwEBARW5HCIiItIzC39/eCxcgOigD5F59iwS1n4Hl6lTpI6lJYiiWJb/T5GOHj58CG9vb0RFRcGrCk+bS0T6J4oiwh6m4vC1WKRmqWBvZYbeTd3R3MveKAoBnx28ifV/6j4enoWpDLkaEeqyzGahA1OZoC3WPVuo++/rotfbW5nBwrTk34qJooj+35zS6XHjRh7yaj8W3tNFx5J6dFblouPThcScxxHQpMfDsm5bAKg2hUQqXWRkJPz9/eHl5YVGjRrh6NGjxRbiXn/9dWzbtg1TpkxBQEAAQkNDcfDgQWzduhXjxo0rd4bc3Fz88ccfuHnzJt57770y7cs2IhERkWFEz/kYqT/9BAgCnMaPR9a1q1AnJsHE2Qn2AwbCfkB/yKxLfrqoMrAQV0nYyCIiY3XpQTKGrDmt8/ahUzqiuZc9MnPUSM1SQZGtgiIrF4osVcH32U/eP7Usf/3TM8nqi4WprPhCnqUZ0pW52Pb3/TJdZ3UdC88Yio7GUEgk3SiVSiQmJsLT01NblCuqEHfp0iW0atUKixYtwieffAIg73u9e/fuuHHjBqKiomBubg4A6NmzJ/78888iz/f6669j7dq1Ra7r378/3nnnHfTt21fn/GwjEhERGYYmMxPhgwZDFRWVt0AQ8sbgefKnqbs7fEJCYFHH36C5+GgqERGVSXkmphAEATYWprCxMIUnSn5EtihqjYh0ZRHFO+3XKiiyc4st5GXmqAsdU5mrQXyaEvFpyjLnKcrQNX/ByswEpiYymJkIMJXJYGoiwMxEBlOZoF2e/97MJG+9qezJ9iYymMmEvGVPvs7b5tnjFdzn6WOZFXPu4s6Vv/xGtKJaT8DBmXxrFgsLC3h6epa63e7duyGTyTBlyn+PogiCgKlTp2L48OE4duwYevXqBQA4cuRIubKo1WrcvXu3XPsSERFR5dNkZf33RhQL/JkbF4cHEyei7sFfDNozjoU4IiIqEykmpjCRCbB/Mlacdzn2V6k1SCuhUJe/LPWp4t69uHQoytATTyMCGTlqAIWLfsZmzMZ/4GCV15NIJgNkggABT/4U8r5HZELeezy1XPZkOZ5aX3C/gtsK2vWFjy1AgEyW92dKVk61LiRS5bhw4QLq1q0LJyenAsvbtWsHALh48aK2EKeL6Oho/P333+jTpw/MzMywd+9eHDt2DEuXLtVrbiIiItKP1P0HoE5IKH4DjQa5sbFIPXAAjiNGGCwXC3FERFRmVXViiuKYmcjgZGMOJxtznfdZeugm1p3QfSy8no3d0K+ZB3LVInI1GqjUInLVGuRqRO3XKrUGKk3+13nb5arFZ77O30dT8FgaDVS5IlRPtst96li56qeW63ksvqJkKNXIUGaVvmEVdfhaLAtxNUB0dDQ8PDwKLc/vTRcdHV3mY65cuRITJ06EIAioV68evv/+e7Ro0aLEfRQKBRSK/wrFMTExZT4vERERlV3q/p//exy1OIKA1P37WYgjIqKqr6G7HQ5M62S0E1P0buJepkLcO93qVYnijijmFeMKFOeeFO1UuZqnioT/rd/+dyT2X9G9ONC5fi30bOIOURSh0YgQkVeEFUURoghoRDHvPZ68126TtxxPrdc82R5P7/fka1EsuJ/45Bwa7Xny3l98kIzIxEydsssEIDVLVZ6PlqqZrKwsWFhYFFouk8lgZmaGrKyyFZM9PT1x6tSpMudYsWIFFi5cWOb9iIiIqGLUiUklF+EAQBShTkg0TKAnWIgjIqJyEwQBLbwdqkQBSt/KMxZeVSAIwpOx6AArlDwzbD4zE6FMhbiZPRtWqXtelt6LGhGwtzKr5ERUFVhaWkKpLDwGpEajgUqlgqWlpUFyzJgxA2+++ab2fUxMDNq2bWuQcxMREdVkJs5OQGRkqT3iTGo5GywTAMgMejYiIqJqIn8sPGtz07xxzYqg77HwpJJfdCzuOvPJBKBp7apTdMzXu4l72bZvWrbtqXry9PQs8jHQ/EdSdZnwQR/kcjm8vLy0r6IelyUiIiL9sx8wUKcecfYDBhgm0BMsxBERERUjfyy8Rh5y7bKni1WNPOT46Z0OVWYsvPKq7kXH6l5IpMrRqlUr3Lt3D0lJSQWWnzlzRrvekIKDg+Hq6ormzZsb9LxEREQ1lf2A/jB1d8+bXawoMhlM3d1h37+/QXOxEEdERFSC/LHwQqd0xOQudTHyeR9M7lIXoVM64sC0TtW+CJevOhcdq3shkSrHiBEjoNFosGbNGu0yURSxevVquLi4oFu3bgbNExQUhLi4OISFhRn0vERERDWVzNoaPiEhMHV1zVuQ3wZ88qepqyt8QkIgs7Y2aC6OEUdERFQKYx4L72nVeQKO6jaTL1XM6tWrkZKSgpSUFADAsWPHkJubCwCYNm0a7O3t0bp1a4wdOxbz589HfHw8AgICEBoaiuPHjyMkJKTIiRyIiIjIuFjU8Ufdg78g9cABpO7fD3VCIkxqOcN+wADY9+9v8CIcAAiiWNoDs1QeDx8+hLe3N6KiouDl5SV1HCIiohpBFMVqWUiksvHz88P9+/eLXBcREQE/Pz8AQE5ODhYvXowtW7YgLi4ODRo0wKxZs/Daa68ZMG1BbCMSERHVbCzEVRI2soiIiIgoX3BwMIKDg6FWq5GUlMQ2IhERUQ3FMeKIiIiIiCoZx4gjIiIigIU4IiIiIiIiIiIig2AhjoiIiIiIiIiIyABYiCMiIiIiqmTBwcFwdXVF8+bNpY5CREREEmIhjoiIiIioknGMOCIiIgJYiCMiIiIiIiIiIjIIU6kDGKvc3FwAQExMjMRJiIiIqDK4u7vD1JRNKSobthGJiIiMly7tQ7YeK0l8fDwAoG3bthInISIiosoQFRUFLy8vqWNQNcM2IhERkfHSpX0oiKIoGihPjZKdnY2rV6/CxcWlUDU0JiYGbdu2xdmzZ+Hh4SFRQuJ9qBp4H6oG3gfp8R5UDWW5D+wRR+XBNmLVxM9eWvz8pcPPXlr8/KVTWZ89e8RJyNLSEs8//3yJ23h4ePA36VUA70PVwPtQNfA+SI/3oGrgfaDKwjZi1cbPXlr8/KXDz15a/PylI8Vnz8kaiIiIiIiIiIiIDICFOCIiIiIiIiIiIgNgIU4Ccrkc8+fPh1wulzpKjcb7UDXwPlQNvA/S4z2oGngfSEr8/pMOP3tp8fOXDj97afHzl46Unz0nayAiIiIiIiIiIjIA9ogjIiIiIiIiIiIyABbiiIiIiIiIiIiIDICFOCIiIiIiIiIiIgNgIY6IiIiIiIiIiMgAWIgjIiIiIiIiIiIyABbiiIiIiIiIiIiIDICFOD1SqVSYN28efHx8YGlpiWbNmmHXrl067SuKIr7++ms0aNAAFhYWaNCgAVatWgVRFCs5tfGpyH347LPPMHjwYHh6ekIQBEyePLmS0xqv8t6HW7duYfbs2WjVqhXs7e1ha2uLTp06ITQ0tPJDG5ny3oPIyEiMHj0a9evXh62tLRwcHNC2bVts27aNP5PKoSI/k5526tQpCIIAQRAQGxtbCUmNW0X+PuR/7s++Nm7caIDkZAzYRpQO24XSYntQOmwHSovtP+lUizafSHozYcIEUSaTidOmTRPXr18v9u3bVwQgbt26tdR9FyxYIAIQx44dK27YsEF87bXXRADiwoULDZDcuFTkPgAQXV1dxX79+okAxMDAQAMkNk7lvQ8zZ84U5XK5OGHCBHHNmjXiypUrxVatWokAxA0bNhgovXEo7z04ffq02L17d3Hu3LniunXrxG+++UYcOHCgCECcNWuWgdIbj4r8TMqnVqvFFi1aiDY2NiIAMSYmphITG6fy3oeIiAgRgDhy5Ehx+/btBV537941UHqq7thGlA7bhdJie1A6bAdKi+0/6VSHNh8LcXpy8eJFEYC4aNEi7TKNRiN27dpVdHV1FZVKZbH7xsTEiBYWFuLEiRMLLB8/frxoYWHBv3BlUJH7IIqiGB4erv2aDa7yq8h9OHfunKhQKAosUyqVYkBAgFirVi1RrVZXWm5jUtG/C0Xp37+/aGVlJebk5OgzqlHT131Ys2aN6OzsLL733ntsiJVDRe5DfqNs6dKlhohKRohtROmwXSgttgelw3agtNj+k051afPx0VQ92b17N2QyGaZMmaJdJggCpk6diri4OBw7dqzYffft2welUolp06YVWD5t2jQolUrs27ev0nIbm4rcBwDw9/ev7Ig1QkXuQ5s2bWBnZ1dgmbm5OQYMGICEhATExcVVWm5jUtG/C0Xx9fVFVlYWlEqlPqMaNX3ch6SkJHzyySdYtGgRHBwcKjGt8dLX34fMzEx+/1OZsY0oHbYLpcX2oHTYDpQW23/SqS5tPhbi9OTChQuoW7cunJycCixv164dAODixYsl7mthYYFmzZoVWN6yZUuYm5uXuC8VVJH7QPpTGfchOjoapqamsLe310tGY6ePe5CZmYmEhARERERg48aNCAkJwQsvvABbW9tKyWyM9HEf5s6dCw8PDwQGBlZKxppAH/fh008/hY2NDaysrNCqVSvs37+/UrKS8WEbUTpsF0qL7UHpsB0oLbb/pFNd2nymej9iDRUdHQ0PD49Cyz09PbXrS9rXzc0NMlnBuqhMJoObm1uJ+1JBFbkPpD/6vg/37t3D//73PwwcOBBWVlZ6yWjs9HEPPvvsM3z66afa9y+//DJCQkL0F7IGqOh9CAsLw/r163Ho0CGYmJhUSsaaoCL3QSaToUePHhgyZAi8vLwQGRmJr776CoMGDcL//vc/jBgxotJyk3FgG1E6bBdKi+1B6bAdKC22/6RTXdp8LMTpSVZWFiwsLAotl8lkMDMzQ1ZWVpn3BQBLS8sS96WCKnIfSH/0eR8yMzMxYsQIWFhYYMWKFfqMadT0cQ9ef/11dO3aFXFxcQgNDUVcXBwyMzMrI67Rquh9mDZtGvr164cePXpUVsQaoSL3wcfHB0eOHCmwbPz48WjcuDFmzpyJV155BYIg6D0zGQ+2EaXDdqG02B6UDtuB0mL7TzrVpc3HR1P1xNLSsshniDUaDVQqFSwtLcu8LwBkZ2eXuC8VVJH7QPqjr/ugUqnwyiuv4Pr16/jxxx/h6+ur76hGSx/3oG7dunj55ZcxevRo7N69G35+fujRoweys7MrI7JRqsh9+P777/HPP//giy++qMyINYK+/22Qy+V488038fDhQ9y+fVtfMclIsY0oHbYLpcX2oHTYDpQW23/SqS5tPhbi9MTT0xMxMTGFlud3fczvClncvo8fP4ZGoymwXKPR4PHjxyXuSwVV5D6Q/ujjPmg0GowbNw6//vordu7ciZdeeknvOY1ZZfxdGDFiBKKionDixIkK56spKnIfgoKCtL95u3v3Lu7evYukpCQAQGRkJB48eFA5oY1QZfx98PHxAQDtPSEqDtuI0mG7UFpsD0qH7UBpsf0nnerS5mMhTk9atWqFe/fuFbo5Z86c0a4vaV+lUokrV64UWH7p0iXk5OSUuC8VVJH7QPqjj/swefJk/O9//8N3332HYcOGVUpOY1YZfxfyu3KnpqZWPGANUZH78OjRI+zatQv169fXvr755hsAwAsvvICePXtWXnAjUxl/H8LDwwEArq6uFQ9IRo1tROmwXSgttgelw3agtNj+k061afOJpBfnz58XAYiLFy/WLtNoNGLXrl1FFxcXMTs7WxRFUYyPjxdv3rwpZmRkaLd79OiRaG5uLr7xxhsFjjl+/HjR3NxcjI6ONsxFGIGK3IdnARADAwMrPbMxquh9mDlzpghAXL58uUFzG5OK3IPHjx8XOp5GoxH79u0rCoIg3rlzp/IvwEhU5D7s2bOn0OuVV14RAYgbNmwQf/31V4NfT3VVkfuQmJhY6HixsbGii4uLWLdu3coPT9Ue24jSYbtQWmwPSoftQGmx/Sed6tLmYyFOj8aOHSvKZDLx3XffFTds2CD269dPBCCGhIRot5k/f74IQDx27FiBfT/55BMRgDhu3Dhx48aN4tixY0UA4rx58wx8FdVfRe7Dtm3bxMWLF4uLFy8WAYjPP/+89n1kZKSBr6R6K+99+Prrr0UAYsuWLcXt27cXeqWnp0twNdVTee/BhAkTxI4dO4rz5s0TN2zYIC5dulRs1aqVCECcNm2aBFdSvVXkZ9Kz8reLiYmp5NTGp7z34fXXXxd79OghLliwQFy/fr348ccfi66urqKpqal48OBBCa6EqiO2EaXDdqG02B6UDtuB0mL7TzrVoc3HQpweKZVKce7cuaKXl5dobm4uNm3aVNy+fXuBbYr7y6bRaMQVK1aIdevWFc3NzcW6deuKK1asEDUajQGvwDhU5D506dJFBFDkq7QfkFRQee/D+PHji70HAMSIiAjDXkg1Vt57sG/fPrFPnz6ih4eHaGZmJsrlcrFTp07i5s2b+TOpHCryM+lZbIiVX3nvw65du8QXX3xRdHFxEU1NTUVnZ2dxwIAB4t9//23gK6DqjG1E6bBdKC22B6XDdqC02P6TTnVo8wmiKIq6PMJKRERERERERERE5cfJGoiIiIiIiIiIiAyAhTgiIiIiIiIiIiIDYCGOiIiIiIiIiIjIAFiIIyIiIiIiIiIiMgAW4oiIiIiIiIiIiAyAhTgiIiIiIiIiIiIDYCGOiIiIiIiIiIjIAFiIIyIiIiIiIiIiMgAW4oiIiIiIiIiIiAyAhTgiIiIiIiIiIiIDYCGOiIiIiIiIiIjIAFiII6JqITQ0FHZ2dkhKSpI6iuS2bNkCQRAQGRkpdRSD+uqrr+Dn54ecnBypoxAREVV7NaE98ew1VpVrXrBgAQRBQGxsrKQ58lVGnrJ81lXlvhAZCgtxREbuxx9/hI2NDRITE7XLjh8/DkEQtC+ZTAZvb29MmjSpwHa6ePZYJiYmcHZ2RpcuXfDNN98UKposXrwYDRo0KNM5NBoNPvnkE0yePBlOTk5l2lcKp06dwoIFC5CSkiJ1lEKqcjag5HyTJk1CRkYG1q1bZ/hgREREZJQqq21U1dtc1dmSJUtw9OhRqWMQlRsLcURGbu/evWjdujWcnZ21y8LCwgAAy5Ytw/bt27Fx40Z06NABGzZswPjx48t0/PxjLV26FNu3b0dISAg+/PBDmJub491330WHDh2gUCi02/fp0wd37tzB5cuXdT7HoUOHcO3aNUyePLlM2aRy6tQpLFy4sNIaXmPHjkVWVhZ8fX3LvG9lZ6uokvJZW1tj7Nix+OKLL6DRaAwfjoiIyIhUpD1RXRV1zZXVNqrqba7qKiIiAp988gnu378vdRSicjOVOgARVZ7c3FwcPHgQH3/8cYHlYWFhMDU1xfvvvw8LCwsAwMSJExEeHo7Dhw9DqVRql5cmLCwMMpkM7777LqytrbXLZ82ahbVr1+Kdd97BnDlzsHr1agBA69atUbt2bYSGhqJFixY6nSMkJARt2rRB3bp1ddq+usnMzCzw2ZXGxMQEJiYmlZio6ho5ciRWrlyJP/74Ay+//LLUcYiIiKqtmtieMIZrLmu70dicO3cOQN7/KYiqK/aIIzJif/75J5KTkzF48OACy8PCwtC4ceNCxTZ3d3eo1WpkZmbqfI6wsDDUq1evyAbB22+/jYYNG2LXrl3aHkyCIGDgwIEIDQ3V6fg5OTk4ePAgevbsWWhdTEwMAgMD4eXlBQsLC/j5+eGtt95CWlqadptr165h4MCBcHBwgLW1NTp06IBff/210LHyx8a4c+cOJk+eDGdnZ9ja2uKVV14p8Lhueno6PvjgA/j7+8PS0hJubm546aWXcPLkSe1xZs+eDQDw9/fXPrJ7/PjxAue5ceMGxo8fD2dnZzRp0gQAcP/+fUyZMgWNGjWCtbU1HBwcMGDAAFy/fr1A1qLG0dAlf2nZiqPPz7mkz0+XfG3btoW9vT327t1bYmYiIiIq2bPtCV3bQvkeP36MwMBA1K5dGxYWFqhfvz4+//xziKJYYLt//vkHHTp0gKWlJXx8fLBs2TJs3ry5UFtmwoQJ8PPzKzWnru0lXa+5uLbH0aNHIQhCkW2OX375BYIg4MCBA0WeR9c2V1paWomfdUntRkC3e1Ba27UsefLp2u4rSlHfD89+zxTnhRdewMiRIwEArVq1giAIMDU1RXZ2tk77E1UV7BFHZMRCQ0PRtGnTAj3JcnNzcf36dYwaNarAtmlpaTh79iz8/Pzg6Oio0/HzjzVgwIBit2nVqhW+//57xMXFwd3dHQAwaNAgrF27Fvfv3y/1cYjz588jOzu70G+9YmNj0bZtW8THx+Ott95C06ZNERMTg7179yIxMRF2dnb4999/0bFjR1hYWGD69OmwtbXF5s2b0bdvX+zduxcDBw4sdL5Ro0bB3d0dixcvxp07d/DNN9/AzMwMu3btApBXXNy9ezemTJmCJk2aIDk5GWfOnMHly5fRuXNnDB06FLdu3cIPP/yAlStXolatWgCARo0aFTjPiBEj4O/vjyVLlkCpVALI+w3f8ePHMWTIEPj5+SEmJgbr1q3Diy++iOvXr2s/v5KUlF/XbJX5OZf0+emSTxAEtGnTBqdOnSr1syAiIqKyK60tBAAJCQlo3749lEolJk2aBA8PD5w8eRKzZs1CdHQ0vvrqKwDAjRs38PLLL0Mul2Pu3LkwNzfH+vXrYWtrW+58+mgv5Sup7eHi4gIvLy9s374dQ4YMKbDfjh074OLigt69e5f5uE/T5bMGim436noPSmu7ljVPedrX+Sr6/RAUFIQlS5YgOTkZixcvBgBYWVnB0tJSp/2JqgyRiIyWr6+v+PHHHxdYdu3aNRGA+Omnn4rx8fFidHS0eOzYMbFLly4iAHHr1q06Hz//WIsWLSp2mzFjxogAxEePHmmXKZVKUS6Xi1999VWp59i0aZMIQLx06VKB5ePHjxcFQRD//PPPQvtoNBpRFEVx2LBhoqmpqXjz5k3tutTUVNHHx0f08/MT1Wq1dvn8+fNFAOKoUaMKHOu9994TTUxMxNTUVFEURdHBwUGcMmVKiZmXLl0qAhAjIiIKrcs/z/Dhwwuty8jIKLTs3r17ooWFhbhkyRLtss2bNxc6vq75S8pWFH1/zqV9frrke+utt0Rzc3Od8hMREVHRnm1P6NqWEEVRDAwMFF1dXcXY2NgC2wYFBYkymUyMjIwURVEUhw4dKpqZmYl3797VbhMXFyfa29sX+vd+/Pjxoq+vb6k5dW0vFbVvUW2oktoes2fPFs3NzcWkpCTtMoVCIVpbW4vTpk0rtP3TdGkPlvZZl9Ru1PUe6NJ2Lcu917XdV9RnXZbvh+LUr19ffPXVV0vdjqgq46OpREbq8uXLuH//fpGPpQLAxx9/DBcXF3h6eqJbt26IjIzE9u3bMW7cOJ3PkX+sZs2aFbtNfHw8TExM4OLiol1mbm6OPn36YN++faWeIyEhAQAK9NLTaDTYu3cvevfuXeg3eUBerym1Wo3Dhw9jwIABeO6557Tr5HI5Jk+ejMjISFy7dq3QvlOmTCnwvkuXLlCr1doBYeVyOc6dO4eYmJhSs5fk7bffLrTs6cd7MzMzkZiYCLlcjoYNG+LChQs6Hbe0/GVRGZ+zPj4/Jycn5OTkFJgEhIiIiPSjtLaEKIrYvXs3+vfvDxMTEyQkJGhfvXr1gkajwfHjxwu0EZ5+OsPFxQVjxowpdz59tJd0NX78eOTk5GD37t3aZT/++CMyMzMxduzYCh9f13bbs+1GXe8BULa2V2l5ytu+fnbf8n4/ZGZm4t69eyX+34OoOmAhjshInTp1ClZWVoUe6QwLC4MgCDh06BCOHj2KEydO4N9//0VERARee+21Mp0jvxAXEBBQ5HpRFHHx4kW0adMGZmZmBdZ16tQJp0+fhlqt1ulc4lNjR8THx0OhUBR73vxtMjIyCjQS8jVu3BhA3qxLz3p2bJL8AmBSUhIAYPny5bhy5Qq8vLzQrl07zJ8/H7du3dLpGp7m7+9faFl2djY+/PBDeHp6wsbGBrVq1YKLiwuuXLmi84xbpeUvi8r4nPXx+eV/LwiCUKb9iIiIqHSltSXi4+ORnJyMkJAQuLi4FHjlT6QUFxeH+Ph4ZGZmomHDhoXOUdQyXemjvaSrhg0bol27dtixY4d22Y4dO9CwYUM8//zzFT6+ru22Z9uNut4DoGxtL13ufXna1/n7VvT74erVq9BoNCzEUbXHMeKIjFSHDh2QlZWFS5cuoVWrVtrlYWFh8PX1LXZMi7IICwuDnZ1dkUUlADh69CgSEhIQFBRUaN3p06fRrl27Umeuyh9TIzk5Wds4qGghpqT9i8uTv8+rr76KLl264Oeff8aRI0ewcuVKfPbZZ9i8eXOZCplWVlaFlr377rvYtGkTpk2bho4dO8Le3h4ymQzvv/++drKL0pSWvywq43PWx+eXnJwMc3Nz2NnZlSsXERERFa+0tkR+m2TUqFGYOHFikdvWr1+/xHZEUe2S4tobz/7SVh/tpbIYP348pkyZgsjISJibm+PYsWNYtGiRXo6ta7vt2XajrvcAKFvbqyLtyNLajWX9fihKfieA5s2b67Q9UVXFQhyRkWrVqhW8vb2xb9++QoU4fU33HRYWhqZNmxb5D6pSqURQUBAcHR0RGBhYYJ1KpcKhQ4cwb968Us+RP6htREQEWrZsCQBwdXWFXC7H1atXi93PxcUFNjY2Rf7GL39ZUTNz6cLDwwOBgYEIDAxESkoK2rdvj4ULF2obM+UtXO3evRvjxo3TDq6bLzk5WVuQrKiyZKusz7mkz0+XfPfu3StxggkiIiKqPC4uLpDL5cjNzdX2viqKWq2GtbV1kW2Ef//9t9AyBweHInu0PT2zKqD/9lJpbY9XX30V06dPx44dO2BhYQFRFHX65WFl9tzX9R7kK63tWpbzlrd97erqWqbvh6JcuXIFjo6O8PLy0j00URXER1OJjNjAgQMRGhqqfR8XF4fY2Fht1/GKyD9WUV3DHz9+jP79++Pq1avYuHEj7O3tC6w/ceIEUlJSMGjQoFLP07p1a1haWuL8+fPaZTKZDEOGDMGhQ4dw+vTpQvuIoggTExP07t0bBw4cKPCPe1paGtatWwc/Pz80bdq0LJcMtVqN1NTUAsscHBzg7++P5ORk7TIbGxsAKLBMF6ampoV+I/j9998jOjq6TMcpSVmy6ftz1uXzKy2fKIq4cOECOnTooMPVEhERkb6ZmJhg+PDh2Lt3Ly5evFhofWpqKlQqFUxMTNCrVy/s378f9+7d066Pj48vNCsoANSrVw+pqam4dOmSdll6ejq2bt1aYDt9t5dKa3s4OjpiwIAB2LFjB3bs2IEXX3wRvr6+FT5uReh6D3Rtu5blvOVtX5f1+6Eo9+/fZxGOjAJ7xBEZsUGDBuHbb79FREQE/P39td25mzRpotP+giCgS5cu2sFen5Z/rLS0NOzYsQOiKCIpKQnnzp3TTsKwa9cuDB06tNC++/btQ0BAAOrUqVNqBnNzc/Tu3RtHjx7FZ599pl2+dOlSHD16FC+99BImTZqEJk2a4PHjx/jpp5+wd+9e+Pn54dNPP8XRo0fRuXNnTJkyRTu9+oMHD/DTTz9BJivb7yLS0tJQu3ZtDBs2DM2bN4dcLsdff/2Fw4cPFxjctk2bNgCAOXPmYNSoUTA3N0f37t3h6upa4vEHDBiAbdu2QS6Xo2nTprh8+TJ++OEHnT4nXZU1mz4/Z10+v9LynTlzBqmpqYUmISEiIiLDWbZsGU6cOIGOHTvijTfeQEBAABQKBa5du4Yff/wRd+/ehbu7OxYtWoRff/1V20YwMzPD+vXr4evrW6j32+jRozF79mwMGTIE7733HlQqFUJCQuDm5oaoqCjtdvpuL+nSNho/fjwGDBgAANi4caPejlsRutwDS0tLndquZVGR9nVZvh+K4u/vj8OHD+Ozzz6Dj48P6tevj3bt2pXrOogkZajpWYnI8HJyckR7e3tx5cqVoiiKYnBwsAhAPHv2bKn7pqWliQCKnR48/1j5L0tLS9HDw0N86aWXxKVLl4oJCQnFHtvHx0f85JNPdL6O/fv3iwDEO3fuFFgeFRUlTpgwQXR1dRXNzc1FPz8/cdKkSWJaWpp2m6tXr4r9+/cX5XK5aGVlJb7wwgvioUOHCp0jf9r2mJiYAsuPHTsmAhCPHTsmKpVKMSgoSGzRooUol8tFa2trsUmTJuIXX3whqlSqAvstWbJE9Pb2FmUymXb/ks4jinlTv7/11luiq6uraG1tLb744ovi2bNnxS5duohdunTRblfUdPC65C8tW3H09Tnr+vmVlG/GjBmit7e3qFarS8xMREREJXu2PVGWtoQoimJCQoL4/vvvi35+fqKZmZno4uIiduzYUQwODhaVSqV2u7/++kts3769aGFhIXp7e4tLly4VQ0JCCrVlRFEUf//9d7F58+aimZmZ6OfnJ65atapQTl3bS0VdY1FtKFEsvW2kUqlENzc30dLSUkxJSdH1Iy5ze/DZz7qkdqMoln4PdG17lfXe69LuK+6zLsv3w7NiYmLEPn36iHZ2diIAccmSJSVuT1RVCaJYjhG8iajaGDVqFGJiYors1VaSgwcPon///ggLCytx1syyyp884vz58zqPVZc/O1Lv3r3xxRdf6C0LVS+ZmZnw9fXFJ598gnfffVfqOERERFROW7Zsweuvv46IiIhyj9lrSBqNBj4+PujYsSN++OEHqeMQUTXHMeKIjNygQYNw6tQpJCYmlmm/Y8eO4dVXX9VrEQ7IeyzVy8urTBNGyGQyLFmyBOvWrSs0nTvVHBs2bICNjQ0mT54sdRQiIiKqQQ4ePIhHjx5h/PjxUkchIiPAHnFERERERERkMNWlR9yZM2dw9epVfPrpp7C2tsbVq1fLPMYwEdGz+FOEiIiIiIiI6Blr167F5MmT4ejoiJ07d7IIR0R6wR5xREREREREREREBsCSPhERERERERERkQGwEEdERERERERERGQALMQREREREREREREZAAtxREREREREREREBsBCHBERERERERERkQGwEEdERERERERERGQALMQREREREREREREZAAtxREREREREREREBsBCHBERERERERERkQGwEEdERERERERERGQALMQREREREREREREZwP8DGmX/oCNxjZwAAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, axes = plt.subplots(1, 2, figsize=(11.5, 4.2))\n", + "\n", + "axes[0].plot(D_costs, C_costs, \"o-\", color=\"C0\", markersize=7)\n", + "axes[0].set_xlabel(r\"$\\langle P, D \\rangle$ (constraint cost)\")\n", + "axes[0].set_ylabel(r\"$\\langle P, C \\rangle$ (objective)\")\n", + "axes[0].set_title(\"Pareto front: C vs D transport costs\")\n", + "axes[0].annotate(\n", + " r\"$C$-only optimum\",\n", + " xy=(D_costs[-1], C_costs[-1]),\n", + " xytext=(D_costs[-1] - 0.05, C_costs[-1] + 0.05),\n", + " arrowprops=dict(arrowstyle=\"->\", alpha=0.5),\n", + " fontsize=9,\n", + ")\n", + "axes[0].annotate(\n", + " r\"$D$-only optimum\",\n", + " xy=(D_costs[0], C_costs[0]),\n", + " xytext=(D_costs[0] + 0.05, C_costs[0] - 0.04),\n", + " arrowprops=dict(arrowstyle=\"->\", alpha=0.5),\n", + " fontsize=9,\n", + ")\n", + "\n", + "axes[1].semilogy(ts_arr, np.maximum(alphas_phys, 1e-3), \"o-\", color=\"C3\")\n", + "axes[1].set_xlabel(r\"inequality threshold $t$\")\n", + "axes[1].set_ylabel(r\"$\\alpha\\,/\\,n$ (physical dual on $D\\cdot P \\leq t$)\")\n", + "axes[1].set_title(r\"Constraint dual vs. threshold\")\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "2cc028a6", + "metadata": {}, + "source": [ + "**Reading the figures.** The left panel traces the Pareto-optimal trade-off: as we relax the cap $t$ on the secondary cost, the primary cost $\\langle P, C \\rangle$ decreases monotonically, with strong gains at first and diminishing returns near the unconstrained optimum. The right panel plots the constraint dual $\\alpha/n$ on a log scale: when the constraint is barely active (large $t$), $\\alpha$ is essentially zero; as $t$ approaches $t_{\\min}$, $\\alpha$ explodes : this is the Lagrangian shadow price, telling us how expensive each unit of additional tightening would be for the objective.\n", + "\n", + "The smooth, monotone behaviour of the Pareto curve is direct evidence that the constrained Sinkhorn solution does *not* depend on initialisation noise or solver hyperparameters in any meaningful way at this level of regularisation. We make the *price* interpretation of $\\alpha$ precise in the next section." + ] + }, + { + "cell_type": "markdown", + "id": "666a549b", + "metadata": {}, + "source": [ + "## 7. A geometric example : capping mass to a forbidden zone\n", + "\n", + "The previous experiments use random costs, which makes for clean numerical tests but doesn't visualise as a \"movement of mass\". Here is a small geometric example: source and target are 2D point clouds with several clusters, and the constraint forbids transporting more than a small fraction of the total mass into one designated *target* cluster (say, an out-of-budget destination).\n", + "\n", + "We let OTT-JAX build the cost matrix from the two point clouds via `pointcloud.PointCloud`, and we solve the unconstrained baseline with `sinkhorn.Sinkhorn` directly : only the constrained version goes through our wrapper. The constraint matrix $D$ has uniform rows: $D_{ij} = 1$ if target $j$ is in the forbidden cluster, $0$ otherwise. So $\\langle P, D \\rangle$ is exactly the total mass routed to the forbidden cluster, and the cap reads $\\langle P, D\\rangle \\le t_\\mathrm{cap}$." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "dd78cd2b", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-02T06:40:25.467896Z", + "iopub.status.busy": "2026-05-02T06:40:25.467740Z", + "iopub.status.idle": "2026-05-02T06:40:29.123206Z", + "shell.execute_reply": "2026-05-02T06:40:29.122189Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "unconstrained mass to forbidden cluster: 0.250\n", + "constrained mass to forbidden cluster: 0.100 (cap = 0.1)\n", + "unconstrained transport cost: 0.2843\n", + "constrained transport cost: 0.2847\n", + "alpha (constraint dual): 35.156\n" + ] + } + ], + "source": [ + "# Build a 2D point-cloud problem with a 'forbidden cluster'.\n", + "np.random.seed(1)\n", + "\n", + "\n", + "def cluster(center, k, sigma=0.04):\n", + " return np.array(center) + sigma * np.random.randn(k, 2)\n", + "\n", + "\n", + "src_pc = np.vstack(\n", + " [\n", + " cluster([0.2, 0.85], 8),\n", + " cluster([0.2, 0.55], 8),\n", + " cluster([0.2, 0.25], 8),\n", + " cluster([0.5, 0.55], 8),\n", + " ]\n", + ")\n", + "tgt_pc = np.vstack(\n", + " [\n", + " cluster([0.85, 0.55], 8),\n", + " cluster([0.85, 0.85], 8),\n", + " cluster([0.85, 0.25], 8),\n", + " cluster([0.55, 0.25], 8),\n", + " ]\n", + ")\n", + "n_pc = src_pc.shape[0]\n", + "\n", + "# Cost matrix: squared Euclidean, computed by OTT-JAX's PointCloud.\n", + "geom_pc = pointcloud.PointCloud(src_pc, tgt_pc)\n", + "C_pc = np.array(geom_pc.cost_matrix)\n", + "\n", + "# The forbidden target cluster: targets with y > 0.7 (the upper-right group).\n", + "forbidden = (tgt_pc[:, 1] > 0.7).astype(float)\n", + "D_forbid = forbidden[None, :] * np.ones(\n", + " (n_pc, n_pc)\n", + ") # rows uniform, ones in forbidden cols\n", + "t_cap = 0.10\n", + "# Homogeneous form: D_solver . P >= 0 with D_solver = (t_cap 1 1^T - D_forbid) / n.\n", + "D_pc = (t_cap * np.ones((n_pc, n_pc)) - D_forbid) / n_pc\n", + "a_pc = jnp.ones(n_pc) / n_pc\n", + "b_pc = jnp.ones(n_pc) / n_pc\n", + "eps_pc = 1.0 / 200.0\n", + "\n", + "# Unconstrained baseline: vanilla OTT-JAX Sinkhorn on the same Geometry.\n", + "out_unc = sinkhorn.Sinkhorn(threshold=1e-7)(\n", + " linear_problem.LinearProblem(\n", + " geometry.Geometry(cost_matrix=jnp.array(C_pc), epsilon=eps_pc),\n", + " a=a_pc,\n", + " b=b_pc,\n", + " )\n", + ")\n", + "P_uncon_pc = np.array(out_unc.matrix)\n", + "\n", + "# Constrained: our wrapper with the single inequality.\n", + "res_con_pc = constrained_sinkhorn(\n", + " jnp.array(C_pc),\n", + " a_pc,\n", + " b_pc,\n", + " jnp.array(D_pc[None, ...]),\n", + " jnp.zeros((0, n_pc, n_pc)),\n", + " eps=eps_pc,\n", + " n_iters=200,\n", + " n_newton=10,\n", + ")\n", + "P_con_pc = np.array(res_con_pc.matrix)\n", + "\n", + "print(\n", + " f\"unconstrained mass to forbidden cluster: {(P_uncon_pc * D_forbid).sum():.3f}\"\n", + ")\n", + "print(\n", + " f\"constrained mass to forbidden cluster: {(P_con_pc * D_forbid).sum():.3f} (cap = {t_cap})\"\n", + ")\n", + "print(f\"unconstrained transport cost: {(P_uncon_pc * C_pc).sum():.4f}\")\n", + "print(f\"constrained transport cost: {(P_con_pc * C_pc).sum():.4f}\")\n", + "print(f\"alpha (constraint dual): {float(res_con_pc.alphas[0]):.3f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "1d2376b8", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-02T06:40:29.125494Z", + "iopub.status.busy": "2026-05-02T06:40:29.124853Z", + "iopub.status.idle": "2026-05-02T06:40:29.553018Z", + "shell.execute_reply": "2026-05-02T06:40:29.552003Z" + } + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAABNgAAAJCCAYAAAASm3pjAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAQ6wAAEOsBUJTofAABAABJREFUeJzsnXecHHX5xz+zvd3uleQu5dJ7r4RQQ4dIE+kdQZCi/AC7dFFApagoXYqIiIo0hYB0CIEACSUJ6fVyud623Nb5/v6Ye+a+Mzu7t9dSn/crk50+35mdvXnm832KIoQQYBiGYRiGYRiGYRiGYRimR9h2dQMYhmEYhmEYhmEYhmEYZk+GBTaGYRiGYRiGYRiGYRiG6QUssDEMwzAMwzAMwzAMwzBML2CBjWEYhmEYhmEYhmEYhmF6AQtsDMMwDMMwDMMwDMMwDNMLWGBjGIZhGIZhGIZhGIZhmF7AAhvDMAzDMAzDMAzDMAzD9AIW2BiGYRiGYRiGYRiGYRimF7DAxjAMwzAMwzAMwzAMwzC9gAU2hmEYhmEYhmEYhmEYhukFLLAxDLPLURQFhx122K5uRhaHHXYYFEXZ1c1gGIZhGKYANm/eDEVRcNFFF+3qpnSb++67D1OmTIHP54OiKLjlllv6/Bh0fbpjc/X0mvbEhtqT7a49+d5jGKbvYIGN2Sd45513CjIoFEXZYx/sfQUZN5s3b97VTdnjUBQFI0eO3NXN2Kt44okn+u1Fg2EYZl+D7BybzYYNGzbkXO/oo4/W133wwQd3Ygt3D8hu3Fliyd///ndcffXVSKVS+P73v4+bb755t+x4ZHYvdvZ9urtQVVWFiy++GEOGDIHb7cbIkSNxzTXXoLm5uVv7+de//oXvf//7OOSQQxAMBqEoCs4666x+ajWzr+DY1Q1gGIb5+uuv4fP5dnUzGIZhGGavx+FwIJ1O49FHH8Udd9yRtXzTpk1488039fX2JIYOHYqvv/4aoVBoVzelW/znP/8BAPzlL3/B/Pnzd3FrjOyp15TZO9mwYQMOPPBA1NXV4eSTT8bEiROxdOlS/P73v8eiRYuwePFilJWVFbSvX/7yl/jiiy8QCARQWVmJ1atX93PrmX0B9mBjGGaXM3HiRAwfPnxXN4NhGIZh9nrKysowf/58PPHEE5YC2qOPPgohBE488cRd0Lre4XQ6MXHiRAwePHhXN6VbVFdXAwAGDRq0i1uSzZ56TZk9g3Q6jffeew8//elPceWVV3a5/pVXXom6ujr84Q9/wAsvvIA777wTb731Fq699lqsWbMG119/fcHHvvfee7F27Vq0tbXhgQce6M1pMIwOC2wM0wVyiNrnn3+O448/HsXFxfD5fFiwYAE+/PBDy+0ymQwefvhhHHLIISguLobX68Xo0aNx4YUXYsWKFYZ1k8kkfvvb32LmzJnw+XwoKirC/Pnz8dhjj0EIkbVvCkWMxWL40Y9+hOHDh8PtdmPs2LH49a9/bbnNyy+/jKOOOkp3p66oqMD++++P22+/3bDfd999FwAwatQoPTxEDnukENKNGzfi97//PaZOnQqPx4NvfvObAIDW1lb89re/xRFHHIHKykq4XC4MHDgQJ510Ej766CPLa2UVvnvLLbdAURQ88cQTePvtt3HYYYehqKgIwWAQxx9/PL7++mvLfcXjcdx1112YM2cOAoEA/H4/5s6diwcffNDyugDAQw89hOnTp8Pj8aCiogIXXXQRampqLNe1glz0AWDLli36dTO77dO1bG1txTXXXIMRI0bA4XDgd7/7HQBg7dq1+OlPf4q5c+di4MCBcLvdGDFiBC677DJUVVXlPO5FF12EzZs346yzzsKAAQPg8Xgwd+5cvUdcJplM4r777sOcOXNQVlYGj8eDoUOH4thjj8W///1vw7ojR46EoihIJBL4+c9/jpEjR+r32W233YZkMpnzehx//PEoKyuD2+3G6NGjcc0116C+vj5r3YsuugiKouCdd97BU089hf322w9+vx8zZ87ERRddhG9/+9sAgFtvvdVwXd95551CvhqGYRjGgksvvRQ1NTV4+eWXDfPT6TQef/xxzJs3D9OnT7fc9oknnsCpp56K0aNHw+v1IhgM4qCDDsJf//pXy/ULsT96sq4VVnmw5LxjDQ0NuOyyyzB48GC43W5MmTIFjz/+uGEft9xyCw4//HAAwJNPPml49jzxxBOGdf/973/j8MMPR3FxMTweDyZNmoSbbroJkUikoPaSrfP2228DMNpePTmOfK7V1dW4+OKLMXjwYNjtdrzwwguGdbdv347zzjsPAwcOhNfrxdy5c/Hss88WdE1lemJDdXebZcuW4ayzzsKQIUPgcrkwePBgnH/++Vi/fn3O9hb6fRfCJ598grPOOgtDhw6Fy+XCoEGDcMQRR+DJJ5/scluy1XKlurDKO1fI76A792mh1687909vqK6uxmOPPYbTTjsNZWVlWLBgAX79619jy5YtebfbuHEjXn/9dYwcORJXXXWVYdmtt94Kv9+Pp556CtFotKB2HH744Rg3btw+nx6I6Vs4RJRhCuTTTz/Fb37zGxxwwAH4zne+g61bt+K5557DkUceic8//xwTJkzQ100mkzjhhBPwv//9D5WVlTjrrLNQXFyMrVu3YtGiRZg1axamTp0KAEilUli4cCHeeustjB8/HldccQWSyST+/e9/45JLLsEHH3yAxx57LKs9qVQKxxxzDKqrq7Fw4UI4HA688MIL+OlPf4p4PI6bb75ZX/fhhx/Gd7/7XVRUVOCEE05AeXk5GhoasGrVKjz44IP4+c9/DgC4+eab8cQTT2DLli34v//7PxQXFwOA/ilz9dVXY/HixTj++ONx/PHHo6ioCIAW7nn99dfj0EMPxfHHH4+SkhJs3boVL730El599VW8/PLLOO644wq+7v/5z3/w4osvYuHChbj88suxatUqvPLKK/jkk0+watUqDBgwQF83HA7jqKOOwtKlSzFr1izdGHzttddwxRVX4KOPPsoyOn7wgx/gnnvuwcCBA3HxxRcjEAjg1VdfxYEHHqifU1eMHDkSN998M2699VaEQiFcc801+rKZM2ca1k0kEjjiiCPQ0tKChQsXwufzobKyEoBmQD/44IM4/PDDceCBB8LlcmHlypV49NFH8fLLL+PTTz/F0KFDs46/ZcsWzJs3D6NHj8b555+PpqYmPPvsszj55JPxxhtv6AYYoAlazzzzDCZPnoxzzz0Xfr8f1dXVWLp0Kf7973/jW9/6Vtb+Tz/9dHz22Wc49dRT9fvspptuwmeffZZlcD366KO47LLL4PV6cfrpp2Pw4MH48MMP8fvf/x7PP/88Fi9erJ+vzF133YU333wTJ510Eo488kgkk0kceuihaGlpwYsvvogFCxYYRFjOdccwDNNzzjzzTFxzzTV45JFHcMopp+jz//vf/2LHjh34xS9+YdmxAwBXXHEFJk+ejEMPPRSDBw9GY2MjXnnlFZx//vlYs2YNbrvtNn3dQu2P7q7bE1paWnDQQQfB5XLhtNNOQzwex7/+9S9cfPHFsNlsuPDCCwFogsfmzZvx5JNPYsaMGXoHImB8pt9000247bbbUFpaijPPPBPFxcX43//+h9tuuw0vvfQS3n///S7tCHquWdlevTlOY2MjDjjgAIRCIZx++ulQVRWlpaX68ubmZhx00EEoKSnBxRdfjObmZvzjH//AWWedherqalx77bUFXdOe2FDd3ebpp5/GRRddBJfLhZNOOgnDhg3D+vXr8cwzz+Dll1/GO++8k2VrAYV/313x6KOP4vLLL4fNZsOJJ56ICRMmoKGhAcuWLcPvfve7gvdTKIX+Dgq9T3ty/bq6f7pLOp3GkiVL8Morr+DVV1/FF198AQCw2WzYb7/98I1vfAMLFy7E3Llz8+7nrbfeAgAcc8wxsNmMfkJFRUU46KCD8Prrr+Ojjz7CkUce2eP2MkyvEAyzD/D2228LAGLBggV51wMgzD+Lxx9/XJ//+OOPG5Y9+OCDAoC44oorDPN/9rOfCQDiG9/4hmhvbzcsSyaToqamRp++8847BQBxzDHHiEQioc9vaWkRU6ZMEQDEP//5T8t2Lly4UMRiMX1+bW2tCIVCIhQKiWQyqc+fPXu2cLlchuMS9fX1hukFCxYIAGLTpk0WV6hz+ZAhQyzXaWlpydqnEEJs27ZNDB48WEycODFrmdV3c/PNNwsAwm63izfeeMOw7Kc//akAIH79618b5l9yySUCgLjzzjsN8+PxuPjGN74hAIiXXnpJn79kyRIBQAwfPlzU1tbq89PptDj55JMt74d8ABAjRozIuxyAOPLII0U0Gs1aXlVVJeLxeNb81157TdhsNnH55Zcb5tN9DUDccssthmWLFi3S7xGipaVFKIoi5syZI1KpVNZxzN/biBEjBAAxbtw40dTUpM+PxWJiv/32EwDE3/72N33+1q1bhcvlEn6/X6xYscKwrxtuuEEAEMcff7xh/oUXXigACJ/PJ5YvX57VJvr93XzzzVnLGIZhmO4BQFRUVAghhPjud78rbDab2LJli778+OOPF4FAQITDYf05/MADDxj2sX79+qz9JhIJccQRRwiHwyGqqqr0+d2xP7qzbi42bdokAIgLL7wwax4Acckll4h0Oq0vW7lypbDb7WLSpEmG/dDzVd6PzJIlS4SiKGLo0KFi+/bt+nxVVcUFF1wgAIirrrqqoDYLkdv26u5x5HM9//zzs5718vIzzjhDZDIZfdn69etFKBQSLpdLbN68OWsb87XoiQ3V3W3WrVsn3G63GD16tOG+EkL7jux2u5g9e3bOcyz0+87FypUrhcPhEKFQSHz55ZdZy7du3Zp1XPN1onsplx1D3z3Rnd9BV/dpd69fV/dPd3nmmWfEaaedJkKhkL7fsrIycc4554innnqq4N818cMf/lAAEHfddZfl8quuukoAEPfff3+320rX8swzz+z2tgwjwyGiDFMgBx10UJZ7/MUXXwyHw4GlS5fq8zKZDO6//354PB48+OCD8Hg8hm2cTicqKir06UcffRQAcM8998DlcunzQ6GQnnz4kUcesWzTH/7wB3i9Xn26vLwcJ598MlpbW7FmzRrDug6HA06nM2sfsgdYd/jxj39s6UUUCoUs91lZWYnTTjsNq1evxtatWws+zllnnZXVC3XZZZcBgOG6NzU14cknn8SsWbPwk5/8xLC+2+3Wr+VTTz2lzyfPwJ///OcoLy/X59vtdtx1111ZvWN9xd13321Z1GHo0KFwu91Z84855hhMmTIFr732muX+RowYgRtuuMEw79hjj8Xw4cMN18hms0EIAZfLBbvdnrWfXPfCjTfeiJKSEn3a6/Xil7/8JQAYvCv/8pe/IJlM4sorr8SUKVMM+7jhhhswZMgQ/Pe//9Vzzchcdtlllj3QDMMwTP9w6aWXQlVV/e94VVUVFi1ahLPPPhuBQCDndmPGjMma53K5cNVVVyGdTuPNN980LOuO/dHXtoqMz+fDPffcY3j+TZ48GQcddBC+/vprhMPhgvf15z//GUII/PznP8eQIUP0+Yqi4De/+Q28Xi+eeOIJpFKpXrW5p8dxuVy466674HBYByvZ7XbceeedBjtnzJgxuPLKK5FMJnOG+8r0xIbq7jYPPPAAEokE7r333iwP/sMOOwwnnXQSli1bhlWrVmUdqy++7wceeADpdBrXX389pk2blrV82LBhXe6jJ/TV76Cn16+r+6dQbrjhBvzrX/9Ca2srRo8ejeeeew51dXV4+umncd5553X7fFpbWwEgZ8ENmt/S0tKrdjNMb+AQUYYpECu3ZRLL5LLQq1evRmtrK+bMmdPlgzccDmP9+vWoqKjIEiQA6MLSsmXLspaFQiGMHTs2az4dU27T+eefj2uvvRaTJ0/GmWeeiUMPPRQHHnhgrxLWzps3L+eyxYsX4/e//z2WLFmCurq6rFxd27dvL7iogdV1tzrHpUuXIp1Ow2azWea5IONTzt322WefAYAh9JAYO3YsKisruyUGFoLH48mZ10YIgaeffhpPPPEEvvjiCzQ3NyOTyejLZQFWZubMmZaC2bBhw7BkyRJ9uqioCCeffDJefPFFTJ8+Hd/61rdw8MEH44ADDsj7MrVgwYKseYceeigURcHy5cv1eXSfHnHEEVnru91uHHzwwfjHP/6B5cuXG14UgPz3E8MwDNP3zJkzB7NmzcJjjz2GG2+8EX/+85+RyWRw6aWX5t1u69at+PWvf4033ngD27ZtQ3t7u2H59u3b9fHu2B/9YavIjBs3DsFgMGs+2RQtLS0Fp4bI97yrqKjAtGnTsHTpUqxdu9bSviuUnh5n5MiRBgHLzPDhwzFq1Kis+QsWLMAdd9xheLbnoic2VHe3Wbx4MQDgvffes7SFa2trAWi23eTJkw3L+uL7ptzBCxcuzLteX9KXv4OeXr+u7p9Cueqqq/Dss8/ik08+wcaNG3HqqadixowZWLhwIRYuXIgDDzyw1yKejOjItcw51ZhdCQtszD4B9YipqppzHVqW64+yVR4yQOtlkkUQ6jWxypVlhnpiclWN8vl8CIVClj0x+doDwNCma665BuXl5bj//vvxpz/9CX/4wx8AAPPnz8cdd9xhaeh0Ra42P//88zjttNPg8Xhw9NFHY8yYMfD7/bDZbHjnnXfw7rvvIpFIFHwcq/O0OsfGxkYAmvFGBpwVckLgrq7/4MGD+1xgKy8vz3mPXXfddfjd736HwYMH49hjj8XQoUN1D0XKz2JFvnvBfM///e9/x1133YWnn34av/jFLwBoQvGJJ56Iu+++29IrUfa4JDweD4LBoH4NgcKuJ2Dds7g7Vk5jGIbZ27n00ktx5ZVX4r///S8ee+wxzJgxA/vtt1/O9Tdu3Ih58+ahubkZhxxyCI499liEQiHY7XY9H5T8jO+O/dEftopMd+ymrujN86479PQ4XT1TrZ7r8nz52d6btpltqO5uQ7bd3XffnbctVkUl+uL77o5N31f05e+gp9evr2yya6+9Ftdeey0aGxvx2muv4ZVXXsFrr72GO++8E3feeSdCoRCOOuooLFy4EMcdd1yX15k81HLdn21tbYb1GGZXwAIbs09Af2jpQWNFQ0MDgNwP5EKh7eUe3K7alatyUiwWQ2trK8rKynrVJgA455xzcM4556CtrQ1LlizByy+/jEceeQQLFy7EF198gfHjx3drf7lEohtvvBEulwuffvopJk2aZFj23e9+V69S2tfQtfz+97+vGyOFblNTU2P5MN6xY0ffNbCDXNeNSo5PmTIFS5YsyepVfeaZZ/rk+B6PBzfccANuuOEGVFdX4/3338fTTz+Nf//731i5ciW++uqrrLCE2traLI/DeDyOtrY2Q9Lbru5nup5W15p7GxmGYXY+5557Ln74wx/iqquuQlVVVVaKBTP33HMPGhsb8dhjj+lVnolnnnnGsqpid+yPvrZV+gv5eWdlN+Z73u2M43T1TCXPpVzzC2l3T2yo7m4j2++9SbLfU2SbXk6V0R2okz+dTlsutxJh++p30NPr19c2WVlZmX5Oqqri008/xauvvopXXnkFzz//PJ577jkA2t+jfOHJVFBu7dq1lsvXrVsHALvN3wlm34RzsDH7BBMmTIDH48HatWt1Ic3Mhx9+CACYMWNGr441ceJEFBcXY+XKldi2bVvedYuKijB27FjU1tZa5o+gajlz5szpVZtkgsEgjj32WPzxj3/ED37wA8TjcSxatEhfTqGG3enJlVm/fj0mT56cJa6pqooPPvig5w3vgv333x82mw3vv/9+wdvQdX3nnXeylq1fvz5nBbVc2Gy2Hl+3jRs3QlVVHHvssVniWlVVFTZu3Nij/eZjyJAhOPPMM/HSSy/hgAMOwJo1awwhtISVKPree+9BCIFZs2bp82bPng0AePvtt7PWTyQSeqgCrVcIvb0fGYZhmNwEg0GceeaZqKqqgtfrxbnnnpt3/fXr1wMATjvttKxlXXWgdWV/9HTd/qCrZ0++5119fT1WrFgBv99vqDDfE/rrOFu3bsXmzZuz5tN3KD/bc9ETG6q72xxwwAEA0C3bri+ZP38+AODVV1/t8T5ImLN6J2htbc0pFgG9t9l39fWzwmazYd68ebj55pvx8ccfo7a2Fn/9619xzjnndBnhcvjhhwMAXn/99awIjXA4jMWLF8Pr9erfG8PsClhgY/YJPB4PzjrrLKTTafzgBz/I+qPc3NyMm266CYBWuKA32O12XHnllYjH47j88suzHhbpdNrQc3jJJZcA0MqWy0lq29ra9FLc3/nOd3rVpldffdUy0S55GsmFEshbrqehkSNHjsS6desMieyFELj11lstRcS+YuDAgTj//PPx+eef45ZbbrHsKayqqsLq1av1aep9v/3221FXV6fPz2Qy+OEPf5g3pNiKsrIy1NfXZ+WjKQQKzfzggw8MhlIkEsGll16as+ezO9TX1+v5RGQSiYTegyrfC8Rtt91m6GFtb2/XiyrIHgznnXceXC4X7r//fsN1BoA77rgD27dvxze+8Y2s/Gv56O39yDAMw+TnF7/4BZ5//nm89tprXXou0bPKLJC89tpretEmme7YH91Zt7/p6tlDtuLtt99u8NoWQuDHP/4xYrEYLrzwQstE9d2hv46TyWTwk5/8xGDnbNiwAffffz+cTmeXQivQMxuqu9t873vfg8vlwg9+8IMsu4K2sxLr+oorrrgCDocDv/rVr7BixYqs5YV0xE6cOBGhUAgvvPCC4TtMp9O45pprsmzGvrTZd/X1K4QBAwbg3HPPxdNPP41//vOfedcdM2YMjjnmGGzevBl/+tOfDMtuvvlmRKNRXHDBBfD7/fr8DRs2YPXq1YjFYv3SfoYxwyGizD7D3XffjU8//RR/+ctfsGTJEj1vSHV1NV566SU0NjbivPPOw3nnndfrY918881YunQpXnnlFYwdOxYnnHACiouLUVVVhTfeeAM/+clPcM011wDQ8m4tWrQIixYtwrRp03DCCScglUrhueeew/bt23HBBRfg9NNP71V7zj77bLhcLhxyyCEYOXIkFEXB0qVL8f7772PMmDE444wz9HWPPvpo/POf/8Sll16K0047DYFAAMXFxfje975X0LGuvfZaXH755Zg1axZOPfVUOJ1OLF68GKtWrcKJJ56Il19+uVfnko/77rsP69atw6233oqnnnoKhx56KAYNGoSamhqsWbMGH330Ee655x5MnDgRgNazd9111+Gee+7B1KlTcfrpp8Pv9+PVV19FNBrF9OnT8eWXXxZ8/KOPPhp/+9vfcNxxx+HQQw+F2+3GjBkzcOKJJ3a57aBBg3DWWWfh73//O2bOnIljjjkGra2t+N///gePx4OZM2fi888/7+mlAaCFOBxwwAGYMGGCXoQjGo3itddew7p163Dqqadi3LhxWdtNnDgRU6ZMwWmnnQaHw4EXXngBGzduxMknn4yzzz5bX2/EiBH4wx/+gCuuuAJz587FGWecgUGDBuHDDz/Eu+++i8rKSjzwwAPdavOBBx4In8+Hv//973C5XBg+fDgURcH555+PESNG9Op6MAzDMFqV78rKyoLWvfLKK/H444/j9NNPx6mnnoqhQ4dixYoVWLRoEc444ww8++yzhvW7Y390Z93+ZsKECRg2bBjef/99nHvuuRg/fjzsdjtOOukkTJ8+HQcccAB+9rOf4Y477tDth1AohP/9739YtmwZpk2bhttvv73X7eiv40yfPh0ff/wx5s6di2OOOQZNTU34xz/+gdbWVtxzzz2W+Vit2tZdG6q720yYMAFPPPEEvv3tb2Pq1Kk47rjjMH78eGQyGWzbtg2LFy82dBL2NZMnT8b999+Pyy+/HHPmzMGJJ56I8ePHo7m5GcuXL0cikeiyIITT6cS1116LW265BbNmzcK3vvUtAJpXohACM2bMwBdffKGv353fQVf36a6+fscee2zO/MFWzJs3D3/5y1/yrnP//ffjwAMPxNVXX40333wTkyZNwscff4y3334b48ePx69+9SvD+kceeSS2bNmCt99+Oyt/3QsvvIAXXngBQKeA+fHHH+Oiiy7S13niiScKbj/DAAAEw+xDRKNR8etf/1rMmzdPBINB4XA4xIABA8QxxxwjnnnmGcttHn/8cQFA3HzzzZbLR4wYIUaMGJE1P5VKiT/96U9i//33F4FAQHi9XjF69Ghx0UUXiRUrVhjWjcfj4s477xTTpk0THo9H+Hw+MW/ePPHII48IVVWz9g3A8phCCHHzzTcLAOLtt9/W5z3wwAPilFNOEaNHjxY+n0+EQiExbdo0cfPNN4uGhgbD9plMRtx4441izJgxwul0Zh1rwYIFAoDYtGmT5fGF0K7ZjBkzhM/nE2VlZeKb3/ym+PLLLy3bRuezYMECy/N4/PHHLY9htY0QQiSTSfHAAw+Igw8+WIRCIeFyuURlZaU45JBDxB133CGqqqqytnnwwQfF1KlThdvtFuXl5eLCCy8UO3bs0M+1UOrq6sT5558vBg0aJGw2mwAgLrzwQkObc31vQmj3589//nMxZswY4Xa7RWVlpbjyyitFQ0ODZVvefvvtrGPImLdpbm4Wv/jFL8Thhx8uhg4dKlwulygvLxcHHnigeOSRR0QqlTJsP2LECAFAxONx8bOf/UyMGDFCuFwuMWrUKHHrrbeKRCJhedw333xTHHfccaKkpEQ4nU4xcuRI8f3vf1/U1NRkrXvhhRda3hMyr7/+ujjooINEIBAQALpcn2EYhrEGgKioqChoXXoOP/DAA4b5ixcvFocffrgoLi4WgUBAHHTQQeL555/Xn0myvdQd+6M76+Zi06ZNWc9FmmdlMwjR+Rwy2zWfffaZOOqoo0QoFBKKoljaJP/4xz/EoYceKoqKioTL5RITJkwQ119/vWhrayuovURXtlWhx+nqXOXlVVVV4pxzzhEDBgwQbrdbzJ4929IWtrqmMj2xobq7zcqVK8Ull1wiRo4cKVwulwiFQmLixIniwgsvFC+//HK3rkGu7zsfS5YsEaeeeqqoqKgQTqdTVFRUiCOOOEL85S9/yTqu1XVSVVX89re/FWPHjhVOp1MMGjRIXH755aKxsTHrnLv7OyjkPi30+nV17brLmDFjdLutkKHQ427dulVcdNFFYtCgQcLpdIrhw4eLq6++WjQ2NmatS7asld1If+PyDQzTXRQhOurZMgzDMIzEyJEjsWXLFvBjgmEYhmEYhmEYJj+cg41hGIZhGIZhGIZhGIZhegELbAzDMAzDMAzDMAzDMAzTC1hgYxiGYRiGYRiGYRiGYZhewDnYGIZhGIZhGIZhGIZhGKYXsAcbwzAMwzAMwzAMwzAMw/QCFtgYhmEYhmEYhmEYhmEYphewwMYwDMMwDMMwDMMwDMMwvaBggS2dTqOqqgrpdLo/28MwDMMwDMP0AWy7MQzDMAzD7DwKFthqamowbNgw1NTU9Gd7GIZhGIZhmD6AbTeGYRiGYZidB4eIMgzDMAzDMAzDMAzDMEwvYIGNYRiGYRiGYRiGYRiGYXoBC2wMwzAMwzAMwzAMwzAM0wtYYGMYhmEYhmEYhmEYhmGYXsACG8MwDMMwDMMwDMMwDMP0AhbYGIZhGIZhGIZhGIZhGKYXsMDGMAzDMAzDMAzDMAzDML2ABTaGYRiGYRiGYRiGYRiG6QUssDEMwzAMwzAMwzAMwzBML2CBjWEYhmEYhmEYhmEYhmF6AQtsDMMwDMMwDMMwDMMwDNMLWGBjGIZhGIZhGIZhGIZhmF7AAhvDMAzDMAzDMAzDMAzD9AIW2BiGYRiGYRiGYRiGYRimF7DAxjAMwzAMwzAMwzAMwzC9gAU2hmEYhmEYhmEYhmEYhukFLLAxDMMwDMMwDMMwDMMwTC9ggY1hGIZhGIZhGIZhGIZhegELbAzDMAzDMAzDMAzDMAzTC1hgYxiGYRiGYRiGYRiGYZhe4NjVDWAYhmEYZifS3Az861/Axx8Dy5YBNTWAqgLFxcD06cDcucCppwJjxuzqljIMwzAMwzBCAO+8A7z+OvDZZ8CaNUB7O+B2A2PHAnPmAIcfDhx3HGC37+rW7tMoQghRyIpVVVUYNmwYtm3bhsrKyv5uF8MwDMMwfUlVFXDrrcDTT2tGWVccdxxw003AAQf0f9uYfoFtN4ZhGIbZg1FV4NFHgXvu0US1rhg+HPje94D/+z/A5er/9jFZcIgowzAMw+zNCAH8+c/AlCmakVaIuAYAixYBBx0EXHcdEIv1bxsZhmEYhmGYTtavBxYsAL773cLENQDYuhX48Y+B/fYDli/v3/YxlnCIKMMwDMPsraiq1pP5wAPG+aNGAWefDcybB4wbB9hsQHW1Fnbw4ovA4sXaekIA994LfPgh8MorQGnpzj8HhmEYhmGYfYnFi4HjjwdaWzvnuVxaCo/DDwdmzgRCISAaBb78Enj/feDvf9emAW3eAQcAzz4LnHzyLjmFfRUOEWUYhmGYvZWrrwbuu69zesgQbfqb39REtVwsXw5ceSXw0Ued8+bNA956C/D7+625TN/CthvDMAzD7GF89hlw2GFAJNI573vf09J2DByYe7u2NuCuu4A77gDSaW2e0wm8/DJw7LH92mSmEw4RZRiGYZi9keeeM4prCxcCK1YA3/pWfnENAGbNAj74ALjxxs55S5cCP/lJ/7SVYRiGYRhmXycaBc48s1NcCwaBN97Q7Ll84hqt+4tfaEWsBg/W5qVSwHnnAXV1/dtuRocFNoZhGIbZ22hsBK64onP6qKOAF14ASkoK34fdrhlqt9zSOe9Pf9KqWDEMwzAMwzB9yw03ABs2aOMeD/Daa8CRR3ZvH7Nna7YapfVoaACuuqpPm8nkhnOwMQzDMMzexh//CNTXa+OlpcBTTwEuF9oWLULk/fcL2kXgkEMQPO44zYvtzTe1/B4AcPPNwLvv9lPDGYZhGIZh9kFqajT7jfjVr4D58wGgZ/bb/fcDZ52lzfzXv7S8bNOn93WrGRMssDEMwzDM3kQqBTz0UOf0bbcBgwYBACLvv4+2VxfB1kXpdjWZBADNQLPZtP1NnqwtfO89LdR06tR+aT7DMAzDMMw+xyOPdOZOmzIF+L//0xf1yH474wytevwbb2gL778fePDBfmk60wkLbAzDMAyzN/HBB8COHdp4MAhccIFhsc3lgnPo0Ly7SG3fbpwxaRJw9NHA//6nTT/7LAtsDMMwDMMwfcWzz3aOf+97WqoOiW7bb4qiFbsige3ZZzWRras8vEyv4KvLMAzDMHsTn3zSOX7ccUAg0Df7PfXUzvFPP+2bfTIMwzAMw+zrRCLAqlWd09/6Vt/sd+FCwOvVxltagI0b+2a/TE5YYGMYhmGYvYnPP+8cnzNHHxVCIJPOQBUCiWQSsfYY2sJtaGxqRE1NDSJyOXgrpH1h+fK+bTPDMAzDMMy+yldfAUJo48OGAeXlfbNfhwOYMaNzWrYRmX6BQ0QZhmEYZg8llUqhvb0d7e3taGtrQ2trK0atW4eyjuX/+fJLfPjzn6OxsRFNTU34xqZNmBKNoT4chhCqZssJAQFAURQE8nm7DR/eOd7c3I9nxTAMwzAMsw/R1NQ5PmwYMpkMUqkUkskk2tvbEYlEkMlkkG5vhypUqKqAUFX4fD44HF1IOsOHAx99lH0cpl9ggY1hGIZhdhMymQzi8Tja29sRj8fR2tqKtrY2tLS0oK2tDeFwGOFwGJFIBK2trYhGo7rhlUgkEI/HcduGDbrA9u477+DlQACqqkIIgUMUG1RFQcZ0XEXRxLq8qKpxA4ZhGIZhGKZLVFVFOp1GOp3W7bX29nZ93Lt2LSZ2rNvU2Ih3XnwRmUwGosOrbWh9PYKZDNKJBBRFgWKzwWazAYWYY2y/7VRYYGMYhmGYfkIIYTCkZE+z1tZWhMNhtLW1IRKJIBKJIBaL6T2W8XgcqVQK6XQayWQSiURC/6RxMtZUVdVFtGqqQAVgcHs7Ml4vFEWB2+2GW7HBnknD7/bA4XDA6XTC6XLB5XTC7XbnPxk5b8eAAf10xRiGYRiGYXZvhBC6l5ksmskD2XJks2UyGYNoBgA2mw0ulwsDpYIGRQ0NmDRxIjxeL9xuN7xeLyKfLUNs4yYEiou731jZfhs4sBdnzRQCC2wMwzAM0w1SqZRBMCPRjIQz8jILh8OIRqNIJpNIpVL6IISAqqqaq3+HQCYLaDSk02mkUindICMBTe3oiVQUBTabDQ6HA263Gx6PB6FQCJGiIj1H2umjRuGA++7D+PHjUVJSgpobb0TkzbcwoIsqVJbIhQ3kfGwMwzAMwzB7OGSX5RPNqJMzkUjo9plZNHM4HPB4PHC5XHC73SgqKtLtNJfLBa/XC4/HA6/XC6fTCUVRgPZ24Ac/ANJpOBsbMSkYBCor9X22O+wFOatlEY9r+d2I2bN7foGYgmCBjWEYhtmnUVXVEJbZ3t6OWCyGcDish2jKoZnkWWYWzQgSwUg8I2NNFtDI64wENJoWQsButxuEM2eHd5nP50MwGMSAAQMwePBgDBs2DIMGDcLgwYMxePBghEIhLVzgww+Bgw4CAAxbuRLDJkwASkt7f6GeeaZzfN683u+PYRiGYRimnxBC6GKZWTQzi2fUGWolmimKApfLBZfLBY/HA5/Ph9LSUt27jMQzEs3skjdawXi9wPTpwLJl2vQzzwA/+lHvL8K//w1QCpCKCq2AAtOvsMDGMAzD7FUIIfSksLk8zWiIRCKIRqMGA8xKNDPvXxbP5JwasoFGXmpCCNg6cmXY7XbY7XZ4PB593OVywefzwe/3o7i4GH6/H6FQCGVlZaioqEBFRQXKy8sRCoXg8/m6NtzmzwfGjAE2bNB6Lh96CPjZz/TFajKJ1PbteXehJpPGGR9/3JkgV1GAc87p8ntgGIZhGIbpS2SxjD6tQjJpWhbMZNGM7C85AsDj8RhEM1k8U3ZG7rLzz+8U2O6/H7j6akBK39Ft+01Vgd//3rh/zsHW77DAxjAMw+z2pNNpXSSTvc1isViWYEYJ/2UDjLzGcolmMnLoZjqdNlRyImONPMzI28xut+teZrRMFtScTiecTie8Xi8CgYAuog0cOBDl5eUoLS1FSUkJfD6fHjJgiRBAWxuQTgM+H+DxZK9jswFXXAH88Ifa9G23AaefDowdi8Ahh+irORoa4Fm3Ds4dO+Cor4eSSgE2GzKBAFKDB8MGaAKdzQZcemnn/r/xDWDUqMK/PIZhGIZhGAvIm98snCUSCb0IgDyQXWYWzWRbjIZgMKiLZrKHmcej5aHdKaIZEY8DsRjgcABFRdZC14UXAj//uRYuunkzcOutwO23A4BuvynJJDxr18K1fTucNTWwRSJaNXi3G6mBAzX7bdw4bX8PPAAsXaqNKwpw+eU74UQZRRTytgGgqqoKw4YNw7Zt21ApxQMzDMMwTHehsEzZw8wqPDMSiSAcDqO9vT0rLJMMsQIfYwCyvc9ITCMURYHD4dDFM3lwOp26cEYhoFR5k5LUut1u+P1+FBUVIRAIIBQK6eJZSUkJAoEAvF4vvF4vXC5XYY3euBF4/HFg8WLgs880gY0YN07Lh3bSScC3vtXZ0xmNAtOmAZs2adOzZgFvvw2EQsAbbwB33AG89VbXxy4r08IJPv9cm3Y4gE8+AWbOLKztzC6FbTeGYRhmZ0J2Vi7RzJzLzByWKYtmZI/JnmYknpFgRqKZy+WC3W7fuaJZPuJx4F//Av7zHy2H7YYNncuKi7VcaAcfDFx8MTBiROeyX/4SuPFGbdxmA154ATjxRKC+XhPbHn8caG3Nf2xFAfbfX7Pd4nFt3ne/Czz4YB+eIJMLFtgYhmGYPsFcLdPK04yqZlJYZl+IZjIkmAkhoCiKQRijaTLYSEQjY0wIYRDgKEyUllMODvJACwQCKCoqgs/nQygUQnFxMYLBoC6gUS9pj1m9Wsu/8d//ap5rXVFeDlx3nTY4ncA77wCHH965fOZMTZD75z973qabbwZuuaXn2zM7FbbdGIZhmN4i20VdiWbmsEzKMWv2/pc9zWQvMwrNJC+z3Uo0K4RkEvjtb4Hf/Q5oaOh6fUXROknvugsYO1bLlzZ/fmeoqNOpRRE8+yzQ2NizNg0bBqxcqXnOMf0Oh4gyDMMwllBeC3MuM7OnGYVmUk9kX4tmAHSjjEQySv5P4pksnFHYJhl0cm40+dyoXYqiQFEUXUzz+Xy6eOb3+3WjLxgM6nnQzCJanxp/QgD33quFCSQS1usoSrboVlcH/PSnwD/+ATz1FHDYYcCdd2rzAK0nkzzRiMMPBxYu1HpSy8s1w3DVKuDvfwcWLdLyd8j4fH1wggzDMAzD7CrksMxcOc3kYgDmsEyapsJM1JlJnmbm/GU0Ldtse5RoVihffaXlOfviC+vlVrabEMCLLwKvvw78+tfA976n2XEHHwzU1GiC2/33G7cZPBg46yzNS238eC26YMcOLULhiSc0bzeZQECzJ1lg2ymwwMYwDLOPoKqqntOiEE+zWCxmmfxfHnoDGWNUlYmMMJfLZRDPSFhTFAU2m03Ph0bCGX22t7dniXi0HQlz9FlaWoqioiL9uNQG8k6TBTSfzwePx6NV6OxvVFVz43/0UeP8Aw4Avv1trVdz4kTNmGpu1no4Fy3SQgaamrR1ly0DDjwQeOUV4Cc/ASIRLeRAxm7XcqkddZTm1ZbJaOGnn30GvPwysGWLdft+8hMt79vVV/f9uTMMwzAM023M3vddiWY0jzoiZdGMUmVQxyUVYjKHZZLtZBURsE/y7rvACSdoNhdRVgZccglwzDFaR2ZJiSaYff01sGSJZrt9/LG2bnu7ZlutXg388Y9aWo/99zemBQE0Qe3UU7X0IMGgZq99+SXwwQeawGbVmf3118Cxx2qRDSyy9TscIsowDLMHY66Wmc/TLBwOGzzKaJwKAPRWNKOKTCSQUWVM8gSjnk0yyAAYcpnRuJV4Rp+q5FFFIpwsogGd+dDIu4zEMzkcweyFRkOPSqv3JddcY6z4NHYs8Oc/A4cemn+7WEwrZvCb33R6nYVCwPvva7k8XnyxZ+3xerUw1Zde6vR+s9u1iqJz5xa+H1UF1q0Dams14y8UAiZNMlTHYvoett0YhmH2TKjAUiGiWSKRMIhkJJxRjlnqXCRPM6vwTPI4y5VGg8nD558DhxzSKa7Z7Vr19uuvty5GJfPWW8B3vtOZMxfQOjOPOgo4+uiet+mYY4CpU4F77umcd8kl2R24XdHUBKxdqwmAbrdml5aX97xd+wAssDEMw+xGZDIZS5Esl6dZV2GZvRHNbDabLmKRcOb1elFUVKQPwWBQzztmt9shhICqqrq3XCwWQ3t7O5IdZcMzmYxejdMsnpHQR8cmQc7pdMJms+lVogDoyWztdnuWeOZ0OmG32y0FtLwVOnc1r7wCHH985/RJJwHPPNO9sMz33tOS4VKP54gRRm+0k0/WvNSqqvLvx+8HzjtPq0Q6dqzmLXfIIVoODwCYMkXzlMtXqCEe18IcnnxSK4wQDhuXO53A9OnAGWdoSX4HDCj8PJmCYNuNYRhm90DuUCxENJMLMcmiGXmZAdDtHYfDkSWayV5m5vyzTB8Sj2veZKtWadPFxVru3AMPLHwfkYhW7X3Ros555eVa6g8AmDABGDq0sOJUBx6oecKdcYYWknrHHVrKEWLRIs2bLR/Ll2sFEV5/XatmaqayEjjiCK0q6fz51hVR92FYYGMYhulHhBA5q2Wa50Wj0ZzVMvtKNCOxzCyaBQIBXSyjXGNFRUW6IEXCGXnMxWIxXTxrb2/XPcuEEHk90Kjd1INqFtGAzkqeNG2z2QxGI3mt2Ww2eDweS2+0git07i6Ew5pH1/bt2vThh2tGUE/O4/33te0zGeP8M8/Ucqul04j++tfIvPYaXFVVcLS2QkkmAUVBxutFprgY8fHjkRwxAplQCIFDDkHwuOO0nCJz5wJUdfXJJ4ELLsg+vhDAww9rPbeFJuR1u7W8I7/4Bed560PYdmMYhukfCgnLpM/29nZDHjN5nOwnEs3Ic0z2NCOhzOxlJgtnzC7ixhs703A4nZoNtv/+3d9PPK55rS1ebJzv92u53UaNQvixx5D55z/hrK6Gq6YGSjwOJZOB6nZDDQSQqKxEYtw4pAYNQuCwwzTbTQjNm+2NN7T97bcfsHSpdRtWr9ZEs3ffLbzd++0HPPSQVq2eAcACG8MwTLdJpVJ5PczM4Zn5cpj1VjSjsEyzcEZVLmXRLBgM6vnEKHzS7XYjnU4bxDKzeJYwJdqXc6BZeaFRtU0yDOXQULkipyyoATD0vsrLqL3moVcVOnc3/vQnTWACtPwYK1YAw4ejbdEiRN5/v6Bd6EIYYDT4AE3A2rpV6xHNZNB0xhlwv/46/HKuEAtSDicyw4fBc8opwOjRwIcfAk8/rS2cP1/LISJTW6t5vpEhJ+Pzafuw2TQh0Up8GzdO89qbM6egc2byw7YbwzBM4ViFZeYSzuLxeFY4Jolmssc9AIMgJodkUgoLOeesLJwxuznhsOZZRh76t90G3HADAPTMftu4EZg2TUv7Qdx+uxZuCqD2iivgfO45hJqaYDd3okpkFAWZQYPgOv54LZKhvFwTzkj2WbpUE8YIIYD77gN+/GPr4lqjRmmpPSIRYMOG7DxvDgdw003aubM3Gxc5YBiGkcMy8+Uyo3ErYSyfeNYdFEUxeJfJ4z6fTw/NDIVCCAaDCAQCehgAiU5yQv5MJqO3ncSz1tZWw7RqqhQpV960EtDk8E23241AIIDS0lJ9mvYhi2gErWPOdZZPRNvr838IYawQ9ZOfAMOHAwAi77+PtlcXwdaFJ5vaEYKrC2w33KCVfI/Hten99tMMrNWrgYsuQikl1e0CZzoF58aNwN13Zy/86CMtdGDkSG16+3bNc27dus51BgzQcouce67moUffuxDAtm3ACy9o4uLatdr8deu06qevvKKFpDIMwzBMD6GwzEK8zMhukz3L5PBMAIbOP9l7TC4CIOeaNQ+7PM8r07c8/XSnuDZ0qGa/ddAj+230aODss7Xcu4AmVl12mVbh/Ze/RPlDD0EpwDfKLgTsO3bkzrf29NOdApsQmoD36193LlcUrRDW5ZdrOYCDwc5lkYjW2frII8Dzz2vREum0JrBt3ap5s+2MomC7MSywMQyz1yGE0CsldZXLjPKDySXL+1o0oypMuTzN5FxmFJYpi2Vm0UyGqoK2t7ejubk5yxPN7H0GQA/1tBLRUqkU7Ha7wfOsqKgIAwYM0KftdjvS6TRsNhtsNpvBK01VVf38zFB4g1ydk85zp1To3F1Zt64zd4fTqQlSEjaXC86hQ/PuIkWhpYTbrYUVkMCWSmkG1SWXZPVOJkIhqC4X3E1NsOXpEbXkiiuAf/1LGz/2WKO49p3vaCJfKJS9naJoIuLVVwNXXaUl4b3xRq1tkYhWiWvpUi3vCMMwDMNAs+8KSf5Pn1bhmHK1TACGIkkkhDmdToP9RbYPi2aMgeef7xy/7DLNhpPokf02ZkznuBCaV9t3vwssXw65uznjciFRXAxHNApXNNq9dj/yiBZtMHcu8LvfGcW1sWO16qYHH2y9bSCghZwec4yWq+2ii7QqpoAm6JWXA7/6Vffas5fBAhvDMHsE5I7fVR4z+qScYWRkUf6vXCJad0Qz8uCyEsxk0ayoqAjFxcV6WKbVkE9YUlVVF8oaGhqyxLP29nZkLAQRyoNmJaJlMhlDTjOXy4VQKGQoECAbmmYRTQgBIQR8Pp+lZ5nT6cxZXICN0Bx89lnn+H77ARUVvd9ne7sxBHP1as2Y6kAAaC0vR2TqNAQ3rEdQLoQAQCgKksEgRDIFd3sMOX0IFy0CZs4EDjigswCComhG1sUXF9ZWu12rVLpggSbStbRoRRouukgrO8/3DcMwzF4LiWCFimYksplFMyFEXtFMDsmUwzKthDO2V5guEcJov51wQt/s11xU4OSTgR079MmEz4eWqdOgOhwo/2Qp7Kb0MimfD2nFBlcmDTt1spqJxbQ0H9ddZ6xcv2AB8PLLWqqSQpg1C/j4Yy3H70svafPuuEMr2NWdIg97GSywMQyzS1BVtaBwTPokAaw/RDPKGZZLMHM6nYaqmXJYpjxQ8tlCvbFyFQyg8XiuByOs86BR/jMAhhBOCl0wh2bSOVIZdlkwk0MizFhV6CRvNM4Zkh/qeafvL5VKwbd4Mcj5vnboUKx48020tLQgHA5j6MpVGBiLIVxVhYyagZpR9SqtdrsdQyuHQqgCIpNBY0MjNrz/vpZMuaoK35AP3Nqqj9b5/fhs+HAMEALDPv0EQao4CiBltyM8cRLCI0dCdbuR2r4dxdOmonzJkk4Bzcz69dpA3HFH4eKazLx5Wsjo4YdrhutHH2nho1df3f19MQzDMLsEIUSXIpk8rqpqlmhGNogsmNEnFUKS85lREQCrPGYsmjH9Qk1NZ0em0wlMnZp3dQGh9W4CBntbABCq9ptRVRWOujoY7lRJXGvef3/UpjMIKAqGfLQENslOjw4ZitZx45AoLUWquhqBQw/BkExGixCwIpMBfvvbzumJE7snrhEej1Yx/tBDtcgDITQbcMUKLTfbPsi+edYMw/Q55DXVVTgmjcuhiySa5Uv+313RjHJiWOUzo08qAkBhmWaxTA4N6G4IIwmIJJZZfVp5nxGZTCbLA42MUDIuSTTz+/1ZRQSATiGMhDVZRJO93KzaYbPZ4Pf7LT3R9rgKnT2E7ku6/vK9TPcwVX6NxWJZ9zqJn7FYDJFIBNFoVN9ODsn94ebNOKvjmE+/+Sb++Omn+nd9jcOJA+x2NLbHdMNMAIAQUGwKqrZVQRUqSlIpfLH0YzyxcgUymQyK43GjwNbBNpcLl4wahVNUgem1NRgkiWt1fj9WDhuGkaawzHR5uVYV68gjtXAAGZsNkIXY/fcHfvhDAD1M8LtgAXDttZ0G4b33aiGk/GLEMAyzyyhULKMcZgAsRTNK/m/u1CO7RrbRPB5PzpBMFs2Y/oTuZbKT5agQsuHs69djdsf6Cb8fr77yit5xmk6nMWLNWpTE42iuq9OiPzoENrfHDa/H2ym4ZTJoaWlB0+rVAIDh0SgsEmtgxy9+gZaaWjjfew9DPl+ui2tppxO1M2chOWyYcQOnE/jNb7TccD/4Qf4TVhTgiSd0ca1H9tuTT2oRDYkEsGaNlkv3pJMK2sfeBgtsDMPkhEqLFxKSGY/HDR5P+UQzs+dZIaIZhWXKopmVeBYIBFBUVAS/32/pYdYb0UwmlUrlFc/yeZ/R9ZEf2KlUShfPhBCGvG1FRUV6z6wMVdikEFQyNMlopQIHudqiKIqlgEYeb3sSVNnVfJ+Sd6DsFUjfD4ldlIePRLR4PG4QHylfC4UdW1UMsxpoPdqWIPGMhlYpd0YqGkWj5K2ZLC6B8HiQycjehNq+lIyCZCoJBQqE0F5m7HY7nE4nhNuNDGDoBU0rCp496ywcMnYsDnjxJQxp6gwhbS0tQ9OYMRhisxnygaixGFJVVUBJCfDii8CUKZ0JfQEtJILCAgBg2DBdDOtxgYabbgIefFALYdi8WQtDPf74vPtgGIZhCsecd7YrAY2eYWbRjOaTYGblaeZ0OvWcZlaCGYtmTHcw22K5CnLJApk8ni/XMkUZ0DFk281MSWOjLrCpiQTWrFmjp1VxOp3IqCoUADapg9umKHC6nJ32vKIgY7PBWxTAgNGjtXedUaOyD/a972HwjTciecEFGLhxg54vN+1woGrSZCSpOju1h2w3QAsDXb4c+OtfO/cXCGi5eskDz+Uy5H7rkf02cSLw7W9r9hsAPPAAC2wMw+z9kFdVocn/zcJXLtHMSkQrRDSTy5PnEsxINMtVLdM89EXFSTl8NZeI1tX5UYiEnPtMNjyFELoHWlFREdxut2XbqZQ7hWCSJ5q5na2trTkNgV1doZMMchKkKJ9eNBpFLBbTvcDIw8tKvCXxSzai5EE29snwko0jWdCiga4hLZe3IYFMFsloPNe+zfMI+r5l5HthhyT0TgD03HyUK0+x2eC226HYlI75dtg7lvt8Pi1MJhzGjGlTMfmyy1Di96OsqgrKqlVAfb2+b8exx+KHV10FzJyJ+N1363nV2j0eVFdUANGo1k65rUKgfeVKVF9/PQDAd9BBKF60qHN5R7VTneee0wo2TJ4MoIcJfkMhYxWt115jgY1hGCYPQohu5TGTO0TNoplZLJMrZwJahIBsu+XLY8ai2b6H2Xaid4J8XmHm9wjZE8wsipmLVuQTwQAtmoNsKrvdrk/Tp8PhQFFRkS4EU24++X2EokTk9xN3Og1x331QhIC3vR3Xnnsu7IMH6/d6dW0tItXVKCory9s+oShwOp3w+/1aFU5z57/Ho3n319bC/fnn8MVi+qLtQyvRnk5reWvz2G7KwIGocLlg6xDEMGQI0NDQuX4iAdxyC/DHP+qzemS/XX55p8D21lta9dN9JOpFhgU2htnDyVUt02pevoqS+cIyCxXN6IHk9/u7zGkWCATyJv/va9FMhrydcolnVCShkP2QcWA2SKmaptfrRXFxcc7cZBTKmiuPGYXdtre3o66uLmdeNLlCJ+2n0AqdZo8s+R4gjy8SvyKRSF4RTB5kA0o2kuScK7LQ1R0xrJDvh9Y1i2Q0LYtiuYw0c/4XueiDoiiGaTkM16onnww6Cu+l78gdDgNvvgkAOMzjwY9/9COUlpVhwIABqHzhRbi/+AKhigrY7DbYlI5j22ywSb+LVCKB8kAAFS++CDz2mLHAAbFokTYEg/B0hIYKADXDR0DpMIBEKmUM+QQgYu2IvPkWACARjaJYXvj11+YLrvVa3nef9fcBQAgVEMh/Xx55ZKfAJicRZhiG2UewCr/M9WlO9SCLZkBuDzMK2SShjMSFXHnM5Hksmu2ZmD3xzYM8Px6P653FskBmFsHMXmBWnv+57CwSvciGcjgc+qfZbpI9ImmwEsJonFKkyKKaPMhFvHrExIm6HeRavhyorOzJFwLvihVasaiPPspeHo8Dp58OABgovR80l5WhvbgYCrq23QCgTFGgS12trUBTk/E4Tz4J3H47EAyix0yfDpSWavtOJrU8bLNnd73dXgYLbAyzm0EPtEKT/1sJLn0lmpHRRT0ruTzM6CFnFZaZKzyzvzynhBCWuc/kcMFC87iR63k6nc4yJCkcz+fzobi4OKdYYLPZDKIXfVIeMwrDpbY1NjZmGcpkJFP4HxUmIGPEbrfrAlZLS0uWWCjfL1YimFUeFdnIovO1Er2s5lstLxRzr7ncy0iGkCyIyeNkzMkvFvT7oBcD2k+uHDCyh6D8Xcvfp5zHjubJ1VN9Ph98Ph/8fr/ufen3++Hz+VBaWopQKKTnzXO73fAkEhAzZkBJpVDc2orr9tsPyoIFAIDqjz5CZKUDTo8n90VTVZRs344Bn36iJa3tCinvWtLlQtxuB6gKlcXfE8XhgHPoUAghUPK5MQdbfNUqmFuWfPRR/HHIEExYthyDo1G0bt6sfR9qBqoqIFQVRUVFqMxniM6a1TluFvEYhmH2QOSwzEKrZcrIzzY5JYRZOANgEM3ISyefaEbzWTTbtciCVD7hSx6n+4pEMHPnpjlFiznEVxbAzGKYledXLk8wKqpF88hWJWFM9gKT3x/MgpdZBOszMaw/OPDAThvlL38BTjyxW5u7mpsxZMVXcH+ytKD17WSTA2gNBDRhDchruwGAs60NLtnJQvZeczo1GzASAZ59Frj0UsN+KFec6Mj9q+SuK6/lcps1S+80xtdfs8DGMEzfo6qqnuepEE+zVCpluR+5+mA+0cyqJ1OGBBqv15u3AIDcA5RPLNsZopmMWZAyi2eFep8BncYukC2ckGFBwmI+rHrOSKhpb29HJBJBXV0dIpGIIdl9Mpk05PsisUsWjuQeP9ngkg0k2djqygust2KYbMzLYpUs9pmFMdkN3yxy0WAWy2TPM/n8SDSTBU0S3SiHHYlq5hcPOkcZc0gn7YeMQxLG5IFENL/fD7/fj6KiIgQCAf37J+HMPFCvrCWnngr8/e9au3/7W60aE3lDJpPZbvgd2FIpDF27Bt5IxLggGIRoazOYQZmyMiCdhl2qKOpOJjFiy2ZsGj0aabsD9mQSNiEg6GVNCCQyaWzftAlF4TBGbdpkOExaCkGN2mzwqypc8Tg+fvRROJ0uDIBAJJXManeXIndxced4e3v+dRmGYXYBFJZZqJeZVYcoiWYAsjp+ZLGMbBLKt0oePoUIZ73JN8sYyeX9lUv8srLTZI9+2RvMvL1sA+YSwcw2mdnuMgtXPp/PIKbK65iFL7JbrISufCKY3AG813PxxZ3e9s8/D6xdC4wfry/OZ78V79iBgVu3ZMtVPp+Wg5aw2YDRow3V2hUAw7dtQ82YMYiUlkGNxTSxjX7rqgo4HB33j4pBy0yRANJ7YqKsDO6aGgBAzXPPYeXo0QhWbYc3mUS0pUUvrAUARUVFOSNyOk+suHN8H7XfWGBjmB5grpaZz9MsX7J7K9Esl+dZLtFMfigGAoG8uczoQbm7iWaEEEKvumglnrW3t+cUIHPtTxadzL2CciJTKzd583Wn/VCvH4k9smcYjct5Tsz5KMweUbJRLefrModEWgliMlbeX7nErnzLzMY53TuyAU/HIq86eTuab77WdF3ovpe96OSqY+ZcZ7InmuxlJrfZ/B3J1422ox5Wur9lsYwEM/o0i2W5hDP6jfWaq67SBTb897+IPPoowiecgPaJEyFaW5HKaNcylU4jndbuLSXWjunvvG0Q12pLSvDf6dNx+FdfwZwm197YiEcqK3FJayvk1y1fLIZha9bi/SFDUOpyochmA8lfDgBJVQDRCKZs2ZxlCDqk38c2hwMTO/J7TIxGgWL5uiid/1vkpMtCNi73wfwdDMPsGuj5VGjyfyvkDh/5+W4WHczzWTTrHXLnnJXYZQ5fzCWM0fY0LYtgVvlYc4lgspe7OTQ3lxAmRyjkErLkaquFCF2FCGP7hBjWHxxwgOaxtXy5JlpdfDHw7ruA3Y7AIYfk3CywZAmCW7fo08LhQPrKK5EZMACem24yrqyqSNjtSC5YgKJ339Vn21QVg9etw5apboSdLthTKQjKlygE0gDaW1sxeMN6eDvCQQWQZcelbDZQWbOitWu1DmOPGza7HW6PG4pig6J02NH2Av7usP3GAhvDAJpBVWji//b29px5sGhfXRUAyCeaUbgZCWO5QjPlUEFFUXThIF8BgF0hmsnI3md0LcPhMMLhMKLRKKLRqB6WKRu5ZOiYjSMSyMjQsRJzrCo+mo0hwNrbCTCKcrJIJg9AtgeUvE8yhM2Gjpy7K583mNVy2dCinCnyPFmEosHj8eiFJWTR1ao30woKPY7H47rRSaJiLBZDW1uboYqq7HFmVSDA3D55cLlcWcKZ3Fsrz6fqZCSamQUz+k2QEEbXjM7f/LuieeQ1YM7TpqoqotEo2tra9PuMDHDztFzogoo5UD5Eum7k4fq9ykoc2FH1SbniCvzioYew0uezFmvTaTywZQuKpAqkDxUV4bceD27+5BOMkg0ciUupqhQAVVFg67hXg6kk5tbWYMOQoRCKArUjT5pQFDjbY5i3Ywf8HeK2XJ1UKIres6m43VrODQBzysuRHjgQ7voGlIdCnfeZwwFHIWFIX33VOW5VUYthGKYA6PldqGiWT/w3ewu5crw80nOMtumqaua+Iprl8ky38gTL92kVHmneb66QR7P3F2Eu5GDOpyrbXbLtbWV7yTYbgCwxq6cimFWnImNNvvQlvR3M95nzuusw/PzztQMvXozas8/G5uuuQyYYROa44wzpZlKpFAYsWYKD3n5bb2vT0KF46+KL4YjHcfwvf2k8D2iCmHvNGq2qegcpux3OTEbzZFu5EhvGjUdKEkoVAE5FwdDNm1C6dau+XdrlgjNpjCoINDfr4/6GBuy///6ofuklROx2OD3e7l/8FSs6x/dR+40FNmavhPJwFZr8vyuvKFm4ySeaJZNJS/FNfnn3er0IhUJZQpn8wk8P9K6S/5Nw4HK5+kU0Mz9I6DqYE9jTEIlEEA6HEYlEEIvF9HBIEtQoP4Q5vNFcbcicCNX8YDPn2LLqUaTtCHNPspXXWK5taXt5O7rmJMhYJV21Mo7IADMLSrLYRAZcrlBD2QuLvMesjENZ9CMj1HzfCiF0oSwajepij/y9kWBGoplZdKSQTfO5yOdELuVyHjVZtCLoe6TrTYKgfM5y3j/aL4V1Wr3IyL8pAPo9bJVvThZx5UGuZEr3slkItqqWRfszC7tWn99XFCyy2TBQVeHPZPCbZctwXTCIF12urHUvSCRwiOQd+yOnE08nk/h9bS1Ot/gbtBnASNO8jKLgi1AxZrVoxlVFPI5IOIxYSQkgAHsmg4GtrShvatSFOACorazEkA6hzhYKAR3G2ZhAAAiHAQAnHXccqp1ORN58C6Hy8qz2dMnixZ3jc+Z0f3uGYfZK+iIsk5A9x+TnOpFrnJ7ThQhne5IYIj/bcyXAz+UdlitthTzfKu1FLhGMOuEAY6isVe5UeVoWuvKJVbm8tvpKBJOn9wa6I0DJ68tpP7q7n94MMmRXmt8h5M5fc1vNgqx872cdz2bD7GOOwcTXXwcAVPzzn4hUVeHj885Dyu/X71uHwwFPLIbZDz2kt61t3DisuvtuTP/sM4z54x9hN4lfmeHD4egQyNxSDrW2yZMRWr8BjvYYbEJgWNU2bBoxEorNBggBbzSKQZs3wS2FaMbLBkCJdXbKwuPRCijIYZzdiBKyZNs2bQA687Htg7DAxuwxUOXHQpP/d0Uhohl5oZgNNBJASEiRPYKskv/L4kdXBQBkEaEr0Yz+2MveQuay17JHjZznK1f+B7mXxSymkPhC+5f3Jws2XeWNMIf8EVYPSbMAk6t3UTZozWKXbGzJx6I2mD3ZzCKR7PFEg9vt1t3z5XBB+u6oSqRVOKEswsltpPxh9D3mEjJTqRRisRhaWlp0UVc+HyvvMlkoo08K0ZS/DzJqcw1yqIMsCJq/Q6uXGtnbTvaoNF8Lp9Opi3AktpkFPBrkCmnkJWZ+QZDPVTb46bdt1XtuDgMynxcd18ogk401GqdQIvqU16HrViMEzvR48FIshgCAIiHwSGsrTnI4cK/LhU9tNkBR4FNV3Cr9jfub3Y6EzYZlySQGSb+VVrsdoY52NgaK4E4mMFgy3pyqitHRKFpdLoQ65o9sqEcdBNzxBEKRMOyy0A2gdsYMuNIZoENgc8+bB7z2mvb97tjR+WX7/bo3W7dJJrVqVsRhh/VsPwzD7BFYhWXmE85yYe7Mke0o2Q6TX/bN21oJZ7uDaCZ3MlqJX/meZ10JX3IO11yCpFXoo/naWoVCyssB6PaCEMJQ4dFKnOqKvhLBdoUYlkuQ2l2HnpDPlsw3WIlzViIsgKx5VuIY2Vu53i+s7FazaEtRDXTvyNEOdrsdtrvvRuzKK+F7/30AwJglSzBqyxYkf/YzqGeeCXswqP2N+fGPobS0aDsOBhH87W9x8MMPA//5j3481emEjfJD19Sg9fDDEZI83gCgZOVKtJeXw94egwLAE4thcPV2ZFxuBNpaDWIcALQPGIi6/fbD8Fdf6Zx54IHAW28Z1kMg0JOvupNHH+0cnzu39/vbQ2GBbQ8mnsrg820tiMTTCHgcmDmsGB7nnlP9Rw7LLCSXWb7E/URXopk8mP+QmgUUcz4zebn84DeLZrJYJns3kVFmJZa0tLQYwsrkT7PwJQtgVjkg5Ie17Blm5eklf+YKazOLZ4U+cM3JVWUhjAQUaqfZ28nKq8ssnplzhMleYXLvlNmgpBBcekDKgo+cg0sOM5RFOnM4gDmM0CxakjDc1tZmuJ5Wgm4+yANN9qSSP81Vo2SsPMrMPfZWYq4cpmrVk08vMLJgmKsqmc1myxKggE5jW/7uSAiTPQutXrKshDBZHJcFcrPBZWWE5TPQzL8zEv7MHo/y70K+t+XvQQih/wbM38tGRcF58Tj+XF+Pso59n5hO48R0GuucTnzh9aJIUVDScYwEgJNUFeeYjKnGoiK0+nwI1dYCACalklg1YiTQUI9Bzc16Do6QqQCBXQgMlgoXEMlAAA1z5iJRVoZBH7zfuWDBAiCd7qwYRUydCixbljfBL6GahbiHHwbq6rTxkhLgtNPybs8wexJ7uu1WCLLQU4iXWS57Qg7BUxRF76yyEnHouLJtJ9sOXQlnvRVXZJvLyvPLShgzC1y5OoBke60QsSNXRxldM3NHFdlD9MySr7+8TVfeX/noSxEslwDXXdEon/iypwhWuebn6yDt6wGApT1l9a5hHrdaJl8fKxtMHqyuEd3/dD+b06/QO4P8KTtEmMNxzZ857/3XXgPOPVcrdgDAVl0Nz/e/D9xwA3DwwZpd9OCDnev7fMA3v2nYRdrhxPZx4zFs9dewqSpsySTaamrRNnEShqxdA3vHOdtUFf6OwgREsVTAihCKgpaJk9AyYQJ8NTWd+dcGDwa+//1sgW3KlM5r3137rb4euP/+zunvfjfvtnszLLDtgcRTGTz98Va8/EU16triyKgCdpuC8qAHJ84YgvPmD4fbsfONNXoxLjSXWbJA7wbZSKMX5lwimvzHloQX+sPp8/myvMtkzyZZaEqn04Zqh7IxQe2RBaJwOJwlfskv9uZeP6trl0v4ssphZf60MtZkEcPKs0b2xKH1cj2MPR6PboQBnYaSLKjQfDofEptkYU4+Bo3TdiRymr2cvF6v7gkmezjJx5LPmcJ95YehbEiT15TZs0yuikP7JE8xK0M513R3DCchRFY1UaokaxaNUqmU4Xuh85K/I7/fn9foIiOaREabzWYwLmmcBEhzSK15AJAlzCaTSYOIJ98P5mPJv2uzmGslIucSw+T7Tb7f5e8inxhmJUyb71G6p2ggYVI+T/mayfceXW+zlyuJbWajW7XZcG0igcu//BIHSgbUuFQK40zu+24Abuk8k3Y7tlcMwoYBA+BKpzGythY2AL5EAh6PB9UzZiL02afwdYRydkXS6UTroMGIzp4NYbfDEQ7DS+IXoJWknzIlW2CbMwcBv7+gYwDoTAa8YQPw0592Lrj0UsDbgxwgDLObsbvaboUgPxMLEc3ydYjK4hZ5fQPI+ptLL8gkMFE7eiqayTYVPWNkr+ZcIY9Wy628v3K98FtBzwj5GW4WLGQ7izB32Fh9mtfrCWYxwer5n6szNNc0fX99IXrl2ldP6I2w1F+CVT4hrS+Qr7PVe4VZ1M0lkMnfVa7jmEMzze8Cshgmn59ZIDMX06J3ga7EMPNnv3oper3Ac89pXlw/+IGeMgOtrcB//6sNMiaBrK20DM0HHADV7UasrRWBjkiBknAb6uftj0a/H+WffVpQU1RFQaS0DG0zZyIVCgEAijZu7Fzh5JOBE07QcqTJ1eFnzwaAvAUazAQOOUTLxfu97wENDdrMsjLgrLMK3sfehiIK/ItUVVWFYcOGYdu2baisrOzvdjE5iKcyuOGFFXhvbT1UIRD0OOG025DKqGiLp2BTFBw6fiB+dcrUPjHUyGOmkDxm8Xi84AeclUBmnkdeObLRQmF0suiVyzgxC0/00mwWHMy9G/QHOxfyS34+jzCzMGZlkMkPMfMDRj4vINsIIGg72Tix6kWyejCbhQZ53OypZT6+LGwBMKwrX0fZ8KWKjVYVHD0eT5ahLERnfjDzYOWlRe0gYU72UKOHs9wjnM9QLsSwMn8vssgci8WyhCd5oNBd+T7usndMuv/M5y1/ygIrXSfZ0042imSxjtohG1/m+1QWjGWDieZZCY5deYjR79VsfNG41bU2XxN5yHevy0ab+drT3wT5+5VfFuTwYBJ9ZQGY7mMKe6Hzk/+e0fWRPe2sqtACgN1mw9y2NpxcVYX5dXXI9xe92efDp7Nnw+nzoayuHmLgQLjcLlQuW45AtdYDmSgpQfWCwxBcvx5lK7QiAhmfD/Hx4+GsrYWrI8RTAGgcWI5YwI+YPwC4nLA5XYAQGLpmNfwdvaSJYcPg3rpVM6wmTQLWrEHHRQPef18LP+gOO3YAhx/euZ8RI7RiB0VF3dsPo8O22+7BzrbdCkHu5CjkMxf08msWi+S/2/Q3Vz6u+Tlm9oqibcx2itWzO1cnozyvUO8vwJiHy+yJbWV30jmYz1s+P9k2pWez3C4r4cg8bTWvKwGMviOz+JXvU77e3WVnCU69Ealy2RK7K4V0wndnWaFCMNl45uuXSwyzWibbVlaeiN0Vx/aU7ww7dmjeag8/nCWkGbDbgVNOQYPbjab1G+AcOhQA4G5owJD3tIqhAsCOQxcgFQxi2H//o+fEjU2eDFsiAVdVFWwddmPM50NrSSkiwSAyXo9muwHwNzdj6No1ncf94gtg+nTg2WeNQtiJJwIvvKDZcYUiBHDjjcCvftU57+mngXPOKXwfexnswbaH8fTHW/He2np4nXYEvU59vsthg9/tQFt7Cu+trcdfP9qKSw7OrtyhqmpB4Zj0mc+okl+wM5mM/nJoFbomf8ov7LnEMcBocOXqcVAUJavnkgYSVNxut76+bNhR76X8EJINHnnc7EVDx8714KZxuZ2EbMhQqEO+B77ZQ8vKiMzXHlk8oGnZk8tKjJITx9N5Wwlm8jybzWbwOCPxTJ62MvAAGO4RWWxIJBL6feB2u/V2U3userEURbH0aDRfq1whBub7m+4X8jCTBT/5XifPM7qvqS1mw8g8nUwmswxZ2s4sTlFbzAYP3VdmZG+xXOEmdE1kMdbc00j7kvcpv+zI8/MZsvIyEvTk3wcJoFaGnPnlQv67IBtfZkNM/g2YhXQ5xFguomC+h30+n8Hrle5F2YiNx+N6JVwqnJJOpw3572KxmF48QhZY5cHr9cLfkRQXADaUluKekSNRmk5jWiyG+bW1mNdRYTPjdqP5+usROOgglPh8OPqLLxB59FFktldBaW1B2udHfOAA+Ku3a1WomptRvPprtI0Zi5KVK2ATAvZYDP5LLwWuvBKorAS2a+uKhcfBUVmJoHQ/+T77TBfXACB92WVaafft2zVjsvPHAxx9NHD33VqIQCEG8VtvAd/+NkDVrhQF+POfWVxj9gp6a7sVghCiS5FMHs/1ki3bVU6nU7eh6HkkP0Po2RKJRAwRBlZilmw70d/5nnghyXaMuYOU/s7LIfjmba2um7xcfg7KHSKUqqRQMcw8Tc8tK/HAaj6dn1nos0pT0deik1WnVnf2x6BgwavQjvpCoXuJxuXPfCKVLIzJNpbZzuypONaTsOI9msGDgVtv1cJDly0DPvsMuOUWLYwSAM44Azj/fGD4cGDDBrjvvhvlGzfCXluDjNOFZCiIRHEx3C0tUAAM/OxTVB92OMIDBiDUsQ+f3Q6sWgXce2+n5//IkRDf/Cbk+AFbJIKBjz2mT6f22w/O6dO1iY4cujovv6yFrT70kHYOXdHUpIWa/u1vnfO+9S3g7LO7c7X2Olhg24V0Nw9HPJXBy19Ua72fuoEmoAoBoapQVQG3TUVrKoW/f7gOw5NbkE6065Uc6cWOwq6s3N3N47kSrMthALSu2XiwGuiPPq0DZBsgZHjJf9hz9ZTQ9uaHkLw/sydKIcYDLTPnSertw0Hel/ywkYUUK88zEpjMifvlvAGyaBAIBOD3+1FUVIRgMIhAIGAIh8xkMnolRNljpitDiUJtzQIafeZzvY7FYmhra0NbWxvC4bBebTQajVqGYKZSKf2eoXbTp5y4mLzcCrn28jgZzGZBmIQQCmMm4czsvSUb4WZPRdkrK991pRcNejGSl5mPZX5hMQvc8kDb5Tqm1e/Kqp1WvxXzerJgKnvC0WD1smHeXghhCD+22q9ZNKP90e9UFsvM+RDlyqvm3wstM+f3o3XoXiQvtHA4jIaGBrS2tupFP1pbWxEOhw0CGt1bJGCaQ87l0B/ZU8/lcqG4uBhlZWUYOHAgysrKMGDAABQXF8PpdKLo9dc1ry4ALaNGoXbFCvjuvVev5Kmnk5XyqMmvrCWrV0M4HGgdWI6SOi0/G378Y+DwwzUvtI58GwO/8Q3gzDM7N3z8ceDOOzunL7gA/htu0PKvXXIJ0NZGN5HWmxmLAVdcATz1FHDttVo4grNTWOg4ca1a6B//qPWiyjz8MHDkkWCY3Y2+sd2MBL1ORJNpvPxFNc7df7i+v0LFMrkqsjnUi/4GmT3KZLHM7GlPth51isrPFPPfZvPfadmDPZd4RLaVPN8swMnea4BRCLPqdJRtQfk5bKYrgYmgdskdiPSMM4ehmufJApi5OAILVrsnvfEIs1pWKPT7kX8L5g5g2Z6iT8Day5B+B1Ydt1ZRO4WIveZ3NivyieNyZ253t+3N8nzbdbWNbHNbfZptW/O2hvmVlUBlJQbdcoseidA0aRK8d94Jb0fFdL0r0ZQCVwBQADijUQxa/AFqBw9BsLERiqpqtuCNNwIHHaSv7yspgU/2JKuqAo45BohEtOlgEM5//1sbf+EFzbYz8/LLwAcfaMLZpZdqna9m6uq0TtDf/x7oyPULQGvLX/5SWMfqXgyHiO4CepqH46ONjbju2c/hcdrhdzvQ3NyM5uamjh85ICAAAWQUB4TNgYEbX4G7dVvWHwKzUJAv9AswupbnetGVhRXZ3V2+vXLtQxbAcrng5xPE5PlWopyVoGBut1mAoGVWDxr5IURigtzjKFdHlKvOyEat3PMrG3FWBprZy8rpdFqKWy6XS7/m5I1IQlF7e3vehxtBHjSy5458DLvdbgjlNVcijUajiEQiBlGXhlzHp3xUsgBIg1XvM82XDQ4SwGKxmH5sGperZ5JnpuxJaSVMWd2/ZqOG7hOzIUUvB/TyQuvTp1lMlZdbiWpWWP3+5IG+S6v73PybMwvQuX6n5vOWt7US/qxEOGqT1XkT8m9J9qKgT8qXJxcfkUVoWm71ciMLcfQSCMDgkUjhvW1tbYhGo7rHYiQS0UU0Wpc6Heh85U4E+VrK9xK9XFI4KYnhskhNvyfz3+VkMomZa9bg/5YsAQCk0fNeskgoBK+qwk45QoYOBcaMAd57T5t+6ingvPO0fBr/93/G3snx44GPPtIqhV5wgVEcu/de4J57Osu0Ez6fVq593Dgt9GD7duDTT4HGxuz1Hn5YSxTM9Bq23fqOvrLdACCdSSOZ6EiBIVQIVSCeVpHMAKcMacMQR8yQH0y2G+S8pnKHRa6XbcAoHtC68t94899Z6qwwf5ptErljRP6kY8qdUVbPP3lZLruOxnPNo3Gzh5dZ/JKXyX+vzSFrXUVO9MfLflfL+2vbXXHMfMu7s51VJ2OudxqzmGx+17ESpPJh5TBgtpnM9qJsP5nPx9wuK9vPLG6bRW75vrUSwM3HNJ9vPsHIat1CPq3WzyWa92afha5vPl4+obKn5BPFj//2t+GmzsgeknY4kRo7Bt7VqztnXnQR8MQT2vi8ecDHH2sdmH/9K3D11QBVLQWAZ57RQkJfeQU49VSAKtIffTRQUaFtY2bcOGDmTCAYBKJRLbx09WrtGDInn6yFhnYj/+7eCnuw7WTy5eGoao7hoXc34OsdbZZ5OCLxNDKqgNPe+VKbIM8dof0nAAioEC4HWiIJ2Bsash42QLYHivxH2FxFhY4lfxLmF1izkWY2aMwGodwWGjd/djVOmB8wVg8ds2gmizrk4SJ7u5iriMrVEqnikux5RR4tsqjTHRRFycpPRp8UTppMJg3iWX19ve6ZWMj+SYyThQq52AMJZ+StY1XJ1BzOSevnCgshbx35GpI3EQkuJNzJXpY0TedLAgddWzmvlRy2AhgrGuXqbZLvNbm9Vr1RND+TyRh+F1Y9WjTP6kWAps1Gl9kwMoe/yPcwYJ2XTL63c9378u/R/DuS22Ilgpk9IKyMFbPIbfUSZzWQhyb9zaDfoblirPlvizydy7AhEZaEtLa2Nj2kMxwO67/btrY2tLS0IBKJ6GKbLKLJ359sQNNxqRqezWbM1xYIBODz+RAIBPRcg1Z/w8irVP6k3zuJ2uWSICU/wKN2O1YVF8PtdKJEFUjabQjEEwjFovCYKowCQKC1FcJm08QuVc0O81y/Xuu5fPppoL29c/7YscD//gds3qwZdV9+2bns8suBa64BLr4Y+OEPgUce6VwWi2neah29tZYcdpjWIzp6dO51GGYX0Je2GwBEo1E0mcRlFXak7W4sX7EaNWqDwQtMFo3kv+1A599jWSyiv6Hmv59mAUnuADA/M+VPVdXSJJi9o+n4+bBaLnekUM5Kqw5Nq09zZ6dsq+Y7Zq7lZMf1ZNvuLM+3XX8ds6v29FWbChFBrMIe5ZQsZuGLhGTzJ4lhZhvFjHxsc+e4bJNYCVNW4pi8TyunBMoZnct7zOycYBbprH7fVgJZd78X+VoVuq2VzZprHbMNal4vl+ODvF6u5WYnCKtjmvfR1TGttu3OcrPtnWsfVp8OKWWRjJgxAzFFQayxEfZQCPb2drhbWuBuaoJifu9Op+BYvVrrjIzFtJkkrhG/+50W4imLcIqiVfg86STguuu0dWjfFRWa7TVsGHDKKVrKENkrbd06bchFcbG2vwsu2Oc91wgW2HYyhebheOKDjTh1aonuDRSJRLBmWxuSiSQaE1E4oCIWi0HN0B9NAFC0f3Y7AAGvU4EnGDQ8SGQXdW277D92VsaLVe4Hs2dVd7ASDbqaR4KX2duJRDErgU82Ls1hZLJhadW7QmGUJO60tbXp0zTP/CDNJTrQPLvdrreXhC7Z2w2A4WVfzomXkF6SZbFHDvWVQzisetXovGSxIJegJIcCm8Us83nTg0YWOuTl5mqYJFpYVcy08irLdW1l8hkOVkaf3EMoC79Wgpf5OLKwIhtMVgaG2aCSB/NLhdV9L+/DSqCSxVHzb5uEMmqzfL/Qy5N8zrROrp5U8wuc7G1oJVbL4xSuKVfTlKtq0sug7JVmZUDJ15fOh86JBDPyYgyHw7poRp6Uzc3NaGpqQktLiyHvpDlvUL5j03dG7S8qKoLf70cwGEQwGITf79fDicz3HQnH5KVmLnoBQP+eyWuUrrdHUXBZR/l3IlZcjPXnnIPYKaegZNgwOB9+GLHFi+EaWokYgJgQ8DTUo3jNGmP1T0ALMZCRp2+9FVkcdZTW83nFFcCrrxp7L885Rwv1BLRezocf1oy4Bx7QwgXkHlQZl0vr9bzqKuDQQ9k4Y3ZLepNDLeBxwG5TkMqocDlsEBBaVXPKQ6kosCk2tKdUxNMqzl74TYwvsenCvpVXcK7njfx33/ys76lII9t+Zq+vXPPyLcvlGdZf5PNwyeVBU6h3TG8+u/Km6c9jFrquLG51FSbZlXdYLiHMCllkMttH5ElptUz2IKfBSkAmu0YW68zVb2VBxcrbkcQwOVee3DlqFV7ZXUGmO5/9uW/5k+kh//iHIXUHAOA73wGuuQbKlClovf56RN58Sy9yAAD29nYUbdqI0Nq1sJntNRLXzCxdqg0y5eVap+fmzVrxKKr0Scv+9z9NXAO0/GnHHqt5uj3wgJY/LheTJ2sdqxdcAHRUKmU0OER0JxJPZXDWwx+hqjmGwSEvmpub9R7BjJqBUAVUNYO4zQtXKozRm1+Gona+9GVgw5YxpyDlLIIjFenwVwNg+gbTrgAciVZUrHgGisgYHkJdGT1mjzWg8w+3lZAgPzTkl3uzp5f8Mm7VmypPy279slFnbldvIU8sOdxLni7EI8yMLDRYDUKIrGqYZsGuq3wPQOcD3+xdR9dO3t5qnzTIoZIknslJiuVjyyIUYPRikkNe5dA2c2+lbHTleljne4ibH/RmAVNej+57wCgam+9duRffav/y/swCi5UQLd/X5hLiVoKUHIpL36XZUwzorJwmCzLk4WQ2fnN937JYZm43iVwkgslFMORPOQefLJbJ96L8KQvxZKjnyidUyHwK1QyHw7o4Fo/H9XBgGm9ra9O9HWUvNPPfNvPfNbOnq+yFFgwG9ZxoJSUlKCoq0vMb0vWh70wubCAPhfxdka9rUVERSkpKUFpaiumPPIKB5lxl//wncNpp+mS1hZHWcfER2LwZZV98nm2odYXdDrjd1gadwwFcfz1w0025q06pqtb7+dlnWq+oqmo9njNnAlOnavtm+gW23XqP2XYjKC2HgPYMqmmNY3CRC/ecOALIdHYmRdoTuPX9FtRHVRS7c3v7tCYVlLoFvj9VhceZLVLpx83jJZTLE8yqc8fs0SNvK+/DfIyefu4KQWl3wUoIlef3xadsj5nHZdvLLDRZLbMSmqw6G61EqHxirFmEMnfaywJxV4PcISiPm+nKI7KrTxaWmF7z1VfA3LmAnC/6gAOADz/UJ3PabgAckTDKP/kE7o5cu93C4wFSKcAqTc/8+VpI6JgxubdvbNREtjVrtHBSt1uLZJgzRxPnGEvYg20n8vm2FtS1xRH0aL2f7e3taGltyVpP2FXEbW5sjShwtlQbltncy6GOPBgpxQFbOtHR2a9on4oC1e6Boigoi2xGWUlIf+G3Eslo3OzxZZVMX85ZJI/LooK8Xl8/kKwMpVw9NzROSezlEEYa5JAvq55gemGWj0XeZ3KFQRIc6Jolk0lDQYlIJILGxkbDy7VZ+JB718hQkEUP2fAgI5sEB/ICi0QiWcIKCW0kvpD4JXvN0L7ktshePFYu+GbhTT4Hs+BlJUo5HMY/O1ZeQvl66M0vHWZxV75nCuk1NR/LKo8LhbGQF5a5Qqrf7zcskwUoekFRVVX3UpLvQ7MglEsws8oVZ3Vd6Lcpe5DJg1wAQBb3zB5ksnBL15g8xqzEL/ICs0rAbRZuc4nHNE7XxWqg6yaLViSQmyvHmkVUq1BTj8ejFwQpLS1FRUUFhgwZgsrKSgwePBilpaUoLS2Fx+OB3W7XK4OGw2E0NjaiqakJTU1NqKmpQUtLi56jLR9yh4Pb7UYoFNILG5SVlenCnd/v7/y9vP12diEAQMuVJglsAKAmk0h1FCuQaXa5EB8/AUPXrNbLvGPiRE0kW7Eid4MzGWtx7aCDgPvu03Kr5cNmAyZM0AaG2cMw224A0NrWikg4ootsAJCGDZuidtz5yBIMsoUN4tYAdRCq04PQklHhVjLQ7LaO55sCJFQbhLBhgjeKSGsCsRyeMFadpPJ6VuINAIONoSiKoaiOOQytP7xtzPP68xi5jtmfn10d2wpZ+Mr1LMy1zKojrVBBkb5r6gCzusdyiU25lpHAl6vNVp9kC5k7cs1ttTqmWcSz+jQL0wyzS0ilgAsvNIprgJbHdvNmYORIfVYu2y0FYOuYsRi6ZjV8lDfX6wVmzwY+/1zLi5YLq1RFZWXAz36mpfTo6jdSVqblZzv66PzrMQZYYNuJmPNw2OwdYhcAQHoYK4DqcKCkfAjKQjAIXS6PitWONtS5B0JBEC5bBg4FUGFDEnYoAMYVpXHa3BnwuObkzR9Gg1VYWS7Ryrw8nlbxdU0U0ZQKj9uBKYMC8LqyjbyeTudqQyaTMVR7lHOS0Qs3PahtNpsuKuSDPFDMidRlUYiSnVPY2fbt29Ha2qp7z5hDOuRcEnIom+wRBnQKWGTAmIUhqzBJ+iTPMXMlTjmUUz6eLM6YjRk5f1munkD6JCPNHM4nG0RmzD3YhGy4mY1+87HNoYtWgpzZK4nEURr8fj8CgQACgQCKiop0QaOoqMiQ904eSFyRvcdISJXvxaamJl1kJUHNbBCbhSbz92A2JOXwaLNHmZzjiwaqskrGpeydStNmgYyquZpFMbkn25w7pbsvBvI9RUKaLPqaj20uUEHio3x/yIKiLIrKoac+n08P3ywuLkZFRQWGDh2KQYMGoaysDH6/H36/H06nE3JoeEtLC6qrq7FixQo9rDQajeo50XJh5a1IXmilpaUIhUK6gOb3+7sOmRJCKzRAzJihJZkFgJdeQsM55yA5SgtLS1Vth6N8oOVunBUVsFdWou3TT1H8+uvazNWrgUcf1UIViCuu0OYvWwa0tsonpnmcHXywtv7MmfnbzTB7AVY51KjTRVHoGWRDWgDtGQUTx83ErAqH4W/UIULBEysT+Ko+jSQAv9MGuwJkBBBNqVAcCuaWO3HBzHI47UYPaiu6Ennk9azWTaQF1jelEEup8DltGFfqgtvRWf26EPvMyqO9K3GtL6bNy3q7vdV0Lgp95nUlkOWysXJhJTSZPeQLEcPMdpZMV6KYnGMsV4iomVzto2dzLlGMvceYvYbHHweWL9fG7Xato3HVKkAItB97LJpPPRVQlC5tN2dlJdoOOhDuRx6BPRbTcuNOnqyJa59/rq140EGaYPfxx1oeXZnBgzUvutNP14Yu3omZ3sEC207EnIejKFCkeUXY7LDZbbAptg7RSiCZEfjZuddi/1GlWXmJEukM/vpRdiWrEV1UsupLelpNqxCEEHolPyvxrKuXWzlkUfbOkkMBaT3ZjZ68uuTQLhJQzLmaZOHK7G4vhDAIULlc5WUhiF6wzeGU8n5l0Uxug7nX0Owqb2VEykaW3JOcT6Qyn2MurzWg00C3Mp6sjiV/H7RcPr65OAIJTV6v1yBUBINBXTALhUIoKSlBcXGxLp5QT78QQheV5IFEs8bGRmzduhXhcFivjGoloFr1Jptfisz3gOz1aC6oQV5yZHyS0CtfD/N3KoddRiIRtLS0GBI2W/0uuiOSyeO5oO9a7lkm8d7sHSnfQ0CnpynlTaNrnclkDJ4ZRUVFuncn3Ud0LLfbrXuh0T0waNAgDBkyBIMGDUJJSYleqdPlcunecPRdr127Fs3NzWhpadFzH1L4tBUk2ssims/n00W04uJiXcAl78ZevSi8/74WYgAATqcWFnrFFcCbbwIAQv96DlunTUPG6cy5CzWZhG/OHAz51a+0EM0FC7Qy7ABw7bWdKy5YAPzpT4CiaOtVV2sGnNOpGWher/UBGGYvxWy7AYDf54ffZ6yUFk2kYUtlcNhBM7H/6LKs/Rx8iNF2S6oCdruCUSWF2U1Wtkt3pxNpFc9+Vo1Fq+pQH0lptpsCDAi4cOzEMpw2swJOu1LQ/go9tvz3vjf76glmAchqoOeT+dM8buWZL0+bn89mu09+nuULnTTPp+3yHQsw5hOjjlerUFGradleJKwEL/rMJ+7J4wyzTyOEZk8R116riWCnnAIA8K5di9a/P4u2PKGWBtsN0MI6zz5bG3/iCc1DDtDEOzmioLVVy/mmdqTk4HDOnQoLbDuRmcOKUR70oKo5Br/bAX+OMrbNre0YVuLHYdNGwuPMNrbcDjsuOXgUzt1/OL7Y1oJwPI0ijwMzhhVbrt/X9KaaFqCFJFqJZrKQRd4tcnin/Cl7vMiiF4lPZm8t2ZCS843JFSjNoZKy+GTluSUXTDAXUgCQdfx8BqeqqgZjhLahkDezqGc+VytjTnaNl41DK++5QgxYKxHFbDCScUWCHLVNzvElhymSUEbiEk2Tx1EoFEIoFDJUUpWT6dO1Jo9GOTw3Go1i27ZtWLt2rS7c0CCHYuYLubBCPlc5/5h8D9Ag51WT2yv39ssGrfydkWcXkasH3KrtVoJZrtALK2GMwiHl79lcxEAWBun3av6tkojZ1NSEhoYGNDY2orm5Wff6TKfTloY4FSGha0LHDwQCCIVCuoBaVlaGIUOGoKKiQhdSqTCA7GW4Y8cONDU16ccmjzjKgWl1r7tcLkPFXTkfWklJiS6ekZDm7s88Yo891jl+6qlayfSHHwamTweiUThTSVRuWI/agw5GJkevpCHswGYDfvSjToFNDjl49FFNXKP1OG8Xs49jtt1y0RZPYViJDzOGFVsu763tZvYO6y7xVAa35rDdasJJPPVpDba0ZXLabjuDfB5SVkWQzBXEzUNXYh3ZJ0II/bkj2zFmL3B5Gdldcueh2QayOrbVtOy1neuZbu5AtbIv5fMyt7nQwWzL7UxvRPP0rjg2w/QZy5Z1Vlp3OLTCT4MHA2ecoRU9AFCxeRNsAwYgSkUGTGSFjJ5+uhbeuXlzp7gGAD/+sTFdRyjEhQd2ISyw7UQ8TjtOnDEED727AW3tKUMlKqKtPQWbouCEGUO6NLg8TrtlL2l/Y11NS8Bus0FV7WiLp/DaV9XwJltwxBAVra2taGlpQWtrqyGcUs6rRC/nNJCwJIs0QPaD0UqksBK2zAaN2TuLXqhlDy9zDrRcvZBmg4raInuJUdgnCQ/koSdXCKXztjIK5d5L2ZNJ7tkkZHGuK0hQyVVt1eFwGPLNyeuSh5IsnJkFNNqGkNfNVRDC6XTq3liyd1l9fX2WUEbTciimlaFtFlFkg1S+lrKYI4dgyJ5z8jz6buVQYvlTvidzeYZZiWS5vMdyiWLmcGbzPSuHlpoFSqsciuZ5iqLoIrgsYDY3NxvykdXX16OhocEgoqVSKctcKOSZKIvDHo9HF7FIYC0pKcHgwYMxcOBAFBcXw+v16vc9ifWRSARbtmzRvdDMf1+sXj6cTqfu0SYPdEyqDCp7ojnzeIj1K4sXd45feKH2OXq01jN60UUAAHdrK4a+8T80zpyF6NChnSKZFUJo4QWKYqwEet99WvJahmF09m7bDXDYFQgAkXgKb6yqxdjyAC5fkCfptYT8vOpOeGSuz+54q5mjA0gg6yo8MtdnT7G6BoV8yt5z9KwlrEIlzZEP8rTc+UufvfF07M50vmVke/X2WDuTngh3u1KE7OtjMX2IbLsde6wmrgHAn/6E9Ouvw9HSAkUIlH+yFOG6WjRNmw7V5cq/zx07NLFOZv/9gZtv7tu2M72CBbadzHnzh+PrHW14b209osm0oQexLa4ZaIeOH4jz5g/f1U21JJ7K4OUvqrXeT68T27dvR6y9HSmbB6rDDaHYoOWTU/DU8kY8++wbcGz6ACKdMjw4ZchQyNejJhsOcs+g2RMMQJabvdwrJ4sKcjieOS+d3W43eI7R/qnHlKqMktAge+9QdVJK4Cr3sJqFHvPDzar3MFe+DPmBaBbdSCygBPxy+BxN+/1+PcltvmtVCLIHmlkwozBBSlxPgllTUxPC4bCe004ODzTnLaNwB6u8crKQKl9Ps9AknxeFl5LoRO02i0s0bfYcNI/LiXrNxjN5J5pFMXmcjm0l5sreY263O68QlmteIYaTEJ2h2ZSsn+5vEjSbmppQX1+ve6JRTrJEImF4EZC/B7PA6vF4EAwG9fBdEtIGDRqEAQMGIBQKGXJDUrva2tqwfv16XdSTw7apYIcM/a6Li4st86GRcCZ7ovn9/iwhcJfT2mrMpbH//p3jF16I1r/+FaE33gAA2JNJlC/9GPFAAA1DhkIdOgSZULHmiSYEHHV1wP33Aw891NmrSpx6KnDJJf1/PgyzB7K32W4AkFYzaIom0dqeRjrTYVMJgd+/sQZtra341rRSOBRr7+meCGJWgleuZPX5xLC+DD+kZ3sh+cUKFQVlm9b8Sc+2fELfbvcM2g3obyGwP6ZlO7W37d6Z7I7C385qS5/z2Wed47LtNmAAGs8+G6V//jOcHfZr0ZYt8G/fjsjw4YgOGYpkcbEuttkiEeC//wWeeUZLESLbvB6Ptoyrse9WsMC2k3E77PjVKVMNeTiiiTTsNgWVJb6dlkOtp5iracWTCcRtXsDpgVadQQWgAooNsDmRHHMYMv5yeFY+DzuEQQAwC1vkCWQWRmRvIDJM6EWZlpm91mSPH/rMZDK6pxiJYK2trVnFB+TwUxrMIpn5wQlkl4Y3/7G2csMnIchKVJTDDX0+n0EkkwfytgkGg3qoGrXBqnpqoSiKYhDMSKAgkZFEs0wmg2g0irq6OoTDYb0aZjgcRjgc1sVIc1J9up4kXJrFMrkdspcW3T9yNUwSy2TBTPa+MouzVi8MFDJM0/KD10oUI+HLShSTp+X7tBBBzDy/r/OYkIgme6PFYjG0tbXpxTvk6pgtLS0Ih8P6vUNtkoVup9OpG4N2ux1er1cvKEA5yWh84MCB8Hq9BhEZ0O5VKioQiUR07zMS0sz50EhI9vl8WSKa1+vVC1eYQzl9Pl+vDSnzC1dXL2K9Wce9cSPmdxy3vagI9z/2GFpbW/Xf1xHRGPYfWolhNTvg6vi754lEULl2DbB2DYSiQNjtUDIZKEs/zn1Skyb16powzN7M3ma7qUJge1MM0aTW6WdTtK5RASCeEvjLpzVYXd2CK+eVwuPqFMGo06wrMWxniUbdFcOs1jXTlThm7oy08jZj+pZ93cNqd/BG7Mm+cqWk6c6xdyZ9LfxVbNgAylrbMGAA2rdt05e1u90IT5qMwdXb4a+vBwDY0mkEN25EcONGAFrBQ6gqbB9/BPzhD9aNdru1Sp/MbgULbLuAXZ1DrTeYq2kJpx8QTkAIKIoKxaZ5r0FRIBTA63FBGT0bC+ZPxbEjnXqFRvKMIq8m8kSJRqP6yzx5gFkVOiDBKJPJZPU+kCeTOVeZbFjJApl5nPYl/7HM9TAwh5rSNL38y9UeKYeYx+OB0+mE3+/Xc4wVFxfroWk0r6hIK4Jhs9n0a2EWzGhIpVJ6CG5XKIqii2ay2ETil+xlRqJLJBJBOBxGLBbLEsnMlVHlTxnZIJUr2pJwZiWWkWAi5zeRvxdzmAWtQ0IphSeaBTKz555VWCXtqxAPsXwi2a4gmUwaxDMab21tRVtbG5qbm9HU1JTliZZMJiGEyHpRILFV9shzu936vTtgwADdI43uZTkJMu0jmUyiubkZ27Zt03/zcjinnPdOFtyLiooMgjN5Ivp8Psv7xul0GkQruqd37NjRpShWiBDWVwghkEgkEIvF9N8YeQzSd1dRV6cLbOFEAi+99JJhH/NjMTQVFWFLIIAJNTUY3NoCWZJVhIBiVazB4dByuX39tTadyV3EgmGYvct2a4omEUsJ2BQFdlLXACgCyKgCXrcTKxpVfBEL4ZLZo/qlTebOrp6IZGbyCV+cmJ/ZU9mXBcadLSL2lQipqqohR1pCCMRisc71kynA6cSWufuhuGobBq5bB2c8bjh3W45CWxgzBtiwQRtn2223hAW2XciuysPRG+RqWg67AuHywZ5R4bR3VueE0HpHVQH4XVoJ+k9qVTg2v4doW4vu4UTJXEmsMYfXWXkxkYhFRpDZY838R5KWmwUwc0ikjDlc05ybjD5JLCPR0Ov1GqoI+ny+nIPT6dRFMyvhrKamBlu2bMlZxVCGjFQSu+R8blb53ijxu1wp0SxCdoVskJKXlrnSpznnl5wnTBbKZE8zsxEhe6915Slm1UNeaI6xXPN2Z2NbDlGWhbSWlhZdNKNx8kQjcZpEaHNvO4mTdB9QTjSr5P7Um0/Qb622thYbNmwweKDRfUihtHSPWX2HsrhHOfvofqL5JMxSaHZbW9tOv/5dITqMKfI0o+9JFs/Iu9PqN0e/BbdkPAXTaXhdLghJtHREY1CSSWTcbqweNQob02kMbWxEaSSMovZ22KW/IarbDdsBBwBHHKGFg15/fafAVlTUvxeEYfYS9gbbrSWmvfg57KYcrhBQFKDI7URbPIWXv6jGufsPtxQQu5trzPxp9k6R7S4rcYw6HPOFV+6rIgTD7K30awhnfyNV7hzqdAITJujT1cUhROx2OL1eJMaNR9WYsfDV7IC/qgru5mY4o1F9XaEoUCZNAubNAy64AAgGgblztYVsu+2WsMDGdAu5mpYAOkq8awZaKpmCKlQtxsBmB9Q0GprrIWxOxBxOvPLJStgaNxr2Z86zRmFnslhCYhuJTbKnkuyRZBVaSSGoHo/HEGoqizm0P/kln/6QO51OBAIBQ0VL8tQpKioyiGYej0cPQ7USzurq6gzil5zjzTyQZx8JIhS2Sh59ZkFODmE1h1wWAl1LOceXLCSac6yR4EGih7nwAdETUUxuU09CKuV5uxuy6FtI+GAikdC9CUmYIRGNctjJ4YJUIVMWbWSjRPYwTKfTuohF93AwGNQFNfm7ArTvg45DbZSLkshh1XRc83diziEHQM+JR4KsPL4zvkPzC1uuFzh5PJPJ6LkDKY+g7H1GYhqJ+3Qcui42m00P684VZkACqNvpRPyzz+BJpeBRVZw0fjzSEyboIbeVL7wI1xdfIDBoEOw2G2x2O5Tx49EMoFkI2BIJ2NQMkrV18B53LIbcfnvnycv5QaZN6/drzTDMrsFsu6VVFTbp2ZBRM4AAMgJwOmxw2gR8TgU7mqP437K1mD7IlyWSmcmXMy2f91iuzk6GYZg9mqlTgf/8Rxtftiz/ujYbYkOGIjZkKABASaVgS6WQrqmB9+ijMPg3v+lc95FHOsfZdtstYYGN6RZyNa1IPAUhtPzZCsUXCHRUrxNAMqZVqBMZQPHA7g3AZcpZJYtc+gullO9LFsjkeeZ8XCSwyOGB5jxq5OVGx6CXea/Xq4e1FRcX6546ZWVl8Pv9UBQFqVRKF7zkcLvq6mpDCKtVcv5MJmPwHqOQMMr3JW9jbitgFEUKQRYD6JqQuEFhdF6vV/e88/v9hrxlcj48q5x4hYgeJK50VxCTp/s6T1ZfhAb2dh2r75CENFmUpZBG8vYkj0+69yg8Wt6f+eVE9hSjvGQ+n8+Q3J8KEJiFHdofCcb0e6LfFIlz8ndGIZvy96l7YUnenlbenFbCVq4XsELW6Y5YJpNMJhGLxdDc3IyGhga9qAN5A1KIbXt7e5e/R5/PZzmffhvmirwOh0PPJ1daWqpXUg0EApph9rGWP+3KkSOBq6/W91e9ZInWC2pVeUpRoHo8UAFkXC5jZdEtW4AVKzqn58zJez4Mw+y55LLdCO15ISAUBU41jUQcsNsdyAggkbHB7/fnFcd2xw4thmGYXQp5mQHAokVaOGeBfyuF04mM04m02w1hrmD/yiud42y77ZawwMZ0m7P3q8RXWxvx9toGZISAyAjYFM2FFTbtD4ddTcJly8Dm80G1OyEUB2ZNnYTBjuEG0cZKZOlKWCHPL3rZp9xNtD8SkoqLiw1inZzknNoghMgSNGpraw1hlfF43FBNVBbwyJNM9iiTk/nLYa8yheRTINGRkuvKgyyWyR50fr8fXq9XF1To/D0eT7cEK/N3k08Qkz0QZU/EroQo8sYrVAjrrli2O6GqquFeoXBJ8jqTRTQSa0mAlXM6yIKQnIeOoO+fvND8fr/+3cvh1ZTDjPZHYdo0UEU18gKl4hl0L8lVVylk0+/36x5wgUAARUVFCIVCemXOXeXeL//Gw+GwXsCBwmibm5t14UwOoe0usveslYDm8/kM3q800LUKBAJw5SrPfuaZusCGBx8ErrvOYKSpySRS27fnbZ9qLnDy0ENaBwigGYFUPp5hmL0SqoT6xqpaqEJoHaLo+FDsUGwCHpsKt0giEklCtTkBuxMuJaPbGQzDMEyBHHUU4PMBsRiwebMmsh1/vL64R7bbtm2AnIv3pJP6sMFMX8EC204insrg820tiMTTCHgcmLmbJcXNZDJZCdHpRd88JBIJDBfAdGcIH8VLkYICITryZdkAj5KBz65AQQAAEM44ELKnsV9lORxK1y+ucuJb2cNNDnm02+3w+XwGTzhVVQ1eKXK4JgDU1tYaqoKmUik9VE4Wy+R8UbIXnCyYAdleP+TRI4eCkbeK7BUke4fJyf0pNE6uGEreZXK1StqXufImebmREENil6qqaG9v148vFwwwnwdNUy63Qry89mXMnlKySJVKpdDe3o5IJIKWlhZEIhHLYh30vckCFHl8AZ0VcOleIuFKHsi7QFVV/XsnIRYwikxyYQEhhC6OyR6i5uISAAwitVyVk0S8nQ0VCKBzikajer45OQdda2urfr49Ec9I4M4nnvn9fl04o+tSsHhWCBddpOVLa28HNm4Efv97TWQDEDjkkIJ3o6+7fj3wu991Lrjqqp63jWH2YnZ32607UCXUseUB/OHNdYinMlSTCk67DcU+J0r9LkCoiMXaURdJwp+K4uvFixDeWIHRo0dj6NChGDBgAIttDMMwXREKAeecAzz6qDb9wx8CRx4JeDw9s90AzfajjvXZs4H99+/DBjN9hSIKfNuoqqrCsGHDsG3bNlRWVvZ3u/Ya4qkMnv64s6x7RhWw2xSUBz39XtadkqDL3jFWghl5y6SEgpqkG0lhg0tRMciVgLMLQeyTSAifhEtgU1S4FQGHonZ4RQgIASSEDWlhwxT7DkzAdt1zhl5MAWNeKtlTzCofhxyaJie0pU8SD6zylcmVOGXRjLYnAcPKI8sq/E4W1SiskoQws7ePOcRV3qd8/rJAZpWbyZwI3lxsYG/NY9IfoYGFrmNeX658S0ns6+rqUFNTg9bWVrS0tOi5uaLRqO4BKQtpZkFNDgGmUEHyBqNPr9erH5/Wo4qZBFWBpSGZTFoKZzSYw3ooZNhKSHOaXdT7Cdm7kc4jFotlCWfkcSb/5rsjnpEHZi7RTA7btBLMZDGtV+JZoVx3HXDvvdq4xwMsWQLMnNn9/cTjmoH34Yfa9KhRwMqVgNebfztmj4Rtt56xK223ntBdIfChdzfggXc2wOlQUOR2wuuyG3KytbWn0J7K4OxZAzEB27Fu3TrEYjGUlZVh8ODBGDJkCAYOHMhiG8MwTD5WrwZmzADIE+173wP+8Adjyo5CefJJrcOV+Ne/gFNP7ZNmMn0LC2z9SDyVwQ0vrMB7a+uhCoGgxwmn3YZURkVbPAWbouDQ8QPxq1OmFmyoUY4gs1hGVSHleZRsPB+ZTAbJjMCXsSDWJUKIqg4tjRoEvEhiBBowRlRDEZksz6VMJoO0CqzyTUe9YyCgKHApGTgUQCh2JIUdigKM9iWwcGALnDYlq8qlvp+OkEpzWJwcVmeVTJ32IXuXkWAle2bJgpU56bhcWhnoTMxP4XHmgbyLZJHCLMJYJfktdP6urJTTVZ6rvsqR1V2xbGdCnlHyb4tCh2tqavSQQrk6pOzhaN4XfZ+yB6Yc5ksiGoVh2u12Xfg1V8+k318ymTSISrkENIfD6KSsKIohD5tZTOvPPDqqqhq853KJZ5TjzFyFtFDMRTdyeZ/ZbDa43W5LwUyet7OExS6JRjUjjUqzDxyohRvMnl34PiIR4IwzgFdf7Zz39tvAYYf1aVOZ3Qe23bpPf9hu/dnWngiBiXQG1z9f+Dkmk0msW7cOX331FXbs2AEhBAYNGoTBgwcjFAphwIABLLYxDMNYcfvtWhQC8bOfAb/8pTERZlf87W/AhRcCVBn+1FOBf/6zZ0Id0++wwNaP/PmDTXjo3Q3wOu0IerNf0qiH8JKDRuCMGQOyvMqsPM7oJT5fInWrvFWyiCVPZ4SC9SX7odk9CIACu5qEIlQImw0ZmwsKFJRn6jEzsxpuhzEPF3ltqbDhy/YQNiSLERNOqAIABFyZGMqiWzCgeRUyqU4PMjlPGbWVbkMSyGSxzIxZKAOM3m6y1xl5qVCxBLNQJotlJHqQMNETkaxQMchcjn5XeWSZt90jy2D3kGQyqYdDR6NRtLa2oq6uDtXV1WhqakJbW5vuidbe3q7fryR+ATCEL5vH6ffhcrl0wcbn8+kiGnkqyiIa3e+U5y8ej+tisZWAZiX+2Gy2LPGMpn0+X58LlnIRD1k4o6IMsseZeb3uiGfy3518IZt0fiSe5RPQdhvxrFA++gg4/HDNCw0AXC7gllu0sIOuzuWdd4BLLtFCTImf/AS4887+ai2zG8C2W/cp1Hb77oIxuOTgUbughRq9FQIT6Qz++lH3xbn6+nqsWbMGa9euRTgchtfrxZAhQ/R0BQMGDMDAgQPhZa9YhmEYTRQ79ljgrbc65y1YoIWOjh2bf9umJi2C4cknO+eNHKnl5S0v75fmMr2HBbZ+Ip7K4KyHP0JVcwyDgh5Eo1HtBV3NQM10CmIR1QGv2o79o0uATEqfLwthsmgmJ8ynl/lcA2BMgG4WxxwOB+pKpqA6OAV2kYET6SzhKgUHMoodk5QqTFS0XktzRU0Kg4unMoh5BiKlOJBpj0Bt2AQ1lTC0V943CRIycuikLJTZbDaDUEYDCQ0kUni9XsNA61i9lJOQJi/va6+tXOP7kpC1K0mlUrqIRtUha2trdRFN9kRrb2/PWyRBFoJJHKPvUw4l9Pv9cDqdushGhTfknGb0G5V/1yTGmQU0q3uF8n5ZeaJ1t6BFLtLptKVoJuc8o+IAVssLEc/kEOt8opksnAGaZ19XYZt7nHhWKK+9BpxyipaPjRg8GPjOd4BjjtHCRgMBrVrV6tWaKPfnP2shpTJXXgn88Y/c+7mXw7Zb95Btt8EhTSBqamqE2+2Bz+/TK6bvaG1HZYkPf79s/i7LydZXQmA8lcEX21oQjqdR5HFgRoF55pLJJDZt2oR169ahpqYGqqoiGAyioqJCL+xCYaQstjEMs08TDgPf+AbwwQed8+x2rUjBuedqxaaGD9dssvp64LPPgOefB/76V61IAjFiBPDmm8CYMTv/HJiCYYGtn/hoYyOue/ZzeJx2+N12bNq0Wc+dJKPanBB2JwZueAXecFW3jyO/kMovpXIoI3lqUVgUeV0JmwPPNQ9DW8aBoCMDQEEymUAqnTaIgCmnH0qsGY73/oR0IqbnkzKLELIXlNU4hblRW0kUc7vduneNz+fTX5JDoRCCwSBKSkoQDAbh8/myPNHkc7XyMmMha+8nnU4binM0NzejuroaO3bsQENDg+6FRkUG6L4lD0nzPSIXsQCg30ckgMmVWs0CriyikSBHYZfpdNoy7DjXPep2u3OGclLxgp4gFzuwEsZIPA+Hw5bryAUDckG/80KEM/P5ezyevMJZUVFRVujrPsennwLnn68JaFa43UAq1ZkIV8br1bzWvv99Ftf2Adh26x5G282BZDKBrdu2QagCTpcTwaIgAkUBpFQF8ZSKe8+cif1Hl+30dpqFQCFUNDQ0IhgsgtvdWXRmZwmBdXV12LhxI7Zu3YpoNAqPx4NQKISioiIoiqJ7trHYxjDMPksspkUcPPCA9XKHQwsbtdALAGhecI89BgwZ0n9tZPqEffwtpf+IxNPIqAJOuw1AdpJ8QhEqhKJA2F26MESCEXlmyaGM5KXl9/t1Txn55V5+ye/Kk2plfRKZxQ0o9drgddmgQMGmzZuQSppyt6mAcAWgFlfC1bxZD2cjzxOqfEnVLykpe0lJCYqLixEMBhEKhVBSUoLS0lIUFxfrFUDpJbs/cz8xez5yldtIJIKmpiZUV1ejuroajY2NuogWDof1ypyAMb+eORealTeaoij6b0wWqM0imgz9DgBNVDOLaPlCMuk3YyWkdUdEMlcJzRWyScU+zJ5m5vFc4hl55dHvPleus1zidj7xjD73efGsEObOBZYv13J4/PGPQGurcbnV96cowAknAHffDYwbt3PayTB7GEbbDQiHIxCq9sxIxOOojyfQ1NQEl9cLxV2EqtpGzBtVutM78z7f1oK6tjiCHqfezubmZjS3NMPldCFQFEBxqBhBjxN1bXF8sa2lX4XA8vJylJeXY9asWdi0aRM2bdqE1tZWRCIRBAIBpNNpRKNRbNmyhcU2hmH2TXw+4P77gdNOA370I2DZMuNyKQWNgREjgJtuAr79be4Y3UPgN5l+IuBxwG5TkMqocDls8Hq9HUKSDTabHTa7DXabHSloVTZ/cPVVmD9mgC485QtJ7CvqVtXCZm+B3+uCy9GRt8jlhlCF3k673Q7F4UTG7sax3zod+w/zY+DAgSgtLUVZWRlKS0vh9/vZU4zpNaqqGkS0hoYGVFdXo6amBvX19XpOtEgkgkQikTM/HyGEQDqd1r3UzCGdJKCZPSPlXF4kvsl5+QDkrcRpbk8uAY1+6/mgggu5RDNZOKNQ03xeZ1QYwYzNZtPPmwovWHmf5RO/aDtznjMWz/oJj0cT2H72M+CZZ4D//lcLKdi2rXOdQEArgnDIIcDFFwOjR++69jLMHoDZdisuLobD4ejovIkj0+HdH43FIRIqnnz0QWz+cCgOPPBAjBs3DgMGDNgp7TQLgaqagcNhRzqTQTKZRFNjE5qbmuB0e6F4AqhrDgPof087t9uNiRMnYuLEiaitrcWGDRuwY8cOxGIx/XkQiURYbGMYZt/liCO0SISlS4G//EXLp/bll1rkAaCJaBMnap2pZ5wBLFyohZMyewwcItpPWOXxsGJHazuGlfjwzC7I42EOhchFNJFGPJXZZaEQzN6DECJLRKuqqkJNTQ3q6ur0nGiRSASpVCpLRDMLzJQTMJPJ6GGc5D1F036/31DIgj7l0E9ZTKJ95KrEacZut+ctKmAlPquqWrBwJhcAyed1ZiWekZdoV1U1uxL6ZPHM6tPv97N4trsQDmsVR51OoKSke1WqmL0Ott26R27bTSAeT+j5OqPCCUeiFeVf/Q1Q0wiFQhg5ciQmT56M2bNnY9y4cSgtLe23dlrZb0IIveq1lg4hg4zihLA5MGz7W5g3qhRHHXUU9t9/fwSDwX5rm5l4PI5NmzZh8+bNiMVi8Pv9KC4uhtvtRjQa1ddjsY1hmH2WZFKLRlBVIBjU0nkweyz8RtRPeJx2nDhjCB56dwPa2lM5E9DaFAUnzBiyS5LkzhxWjPKgB1XNsbwCW1s8hWElPswYVrzzGsfssQgh9DxebW1tqK+vx/bt21FdXY36+nqDiEYeZjKy8AVAL/iRSqUMoZyU449yklmFctJ+3G637rUme7DlqsRphnKv5SoqQKiqqgtfLS0tqKmpsRTFzCGY5kqc5s94PG4Qz0gUJIGM2mcWzwoJvabiDFYeZyye7YEUFWkDwzDdJrftpuhpOpoi7RCxJCb7mlA2ZRJqamoQiUTw5ZdfYs2aNXj//fcxbNgwjBo1CrNnz8b48eNRXFzcp+20st8URdH/hmcyGYQjYdSFk7DHW+CJ1WLZss1YtmwZQqEQpk+fjiOPPBKTJk1CcXFxn1eXlvF4PJg0aZLu1bZx40Zs374dDocDFRUVKC4uRiqVQmtrK3u2MQyzb+JyAQMH7upWMH0Ee7D1I4l0Btc/3/MS6juDPaUcPbN7QaGLsVgMra2tqKurQ1VVFXbs2KF7olEYSCqVytqevMQA6FVySUiTC2OQGEYeZeaiAiSOUc408jijbSj8s5AQZqpMaVVUwG635/U0sxLBiEwmk9PjjD7pGpmLleTyPiv0Zcgsnpk/WTxjmL0btt26T6G220+PHIF1a77Gl19+ibq6Ov35Fw6HkclkEAqF9FQastgWCoX6pJ2F2m9nzijFgMavsHTpUtTU1CCRSOjPqkGDBmHGjBlYsGABhg0bhpKSkqw8o/1Be3u7XhQhFouhuLgYw4YNg9frRWtrK1paWvR1WWxjGIZh9iRYYOtnEukM/vrRVrz8RTXq2uLIqAJ2m4LyoAcnzhiC8+YP32XiGrVvdxcBmV2HLKLt2LFDr85ZW1ure6LFYlplWTMUnki50EhIAzRPL/JAk8MUSRhzOByGHGfmqrNy6GOhIpGiKFlFBcizzWazIZ1O5xTOrERCwCieWXmdtbe3I51OZ3mcWYln3RHOABgq7prFM/I84+IhDLNvw7Zbz+iO7RaPx7FmzRqsWrUKTU1NaGhoQG1tLerq6tDe3g6Xy4VQKAS3241QKISxY8di9uzZmDBhAop64W3aXftNVVV89dVXePvtt/Hll1/qRYGSySRsNhtGjhyJWbNmYfbs2SgvL0dJSQkCgUBfXVJLhBDYsWMHNm3ahJqaGjidTlRWVqKyshKZTAaNjY0stjEMwzB7FCyw7STiqQy+2NaCcDyNIo8DM4YV75KwUCt2dxGQ6V9SqRSi0ShaWlqwY8cOQ0605uZmPecMiWMy5IlGCfYzmYyhOidV1STPMzk/GiXWJyGNxn0+ny58keea1+stKJTTZrPpOddkTzYSmsyVNq2EQUIW3HKJZwAshTMSCGm8u0VASAQ0e5zJYZssnjEM0xVsu/WO7thumUwGGzZswIoVK9Da2opYLIb6+nrU1taisbHR4Iltt9sRCoUwbtw4zJkzBxMmTIDf7+92+3pqv7W3t2PJkiV47733sHHjRt0DO5VKwe12Y9y4cZg1axZGjx6te+GFQqF+DSWNxWK6V1s8HkdJSQlGjhyJiooKtLW1oaGhgcU2hmEYZreHBTZGZ3cWAZnekU6nEYvF0NjYiJqaGmzbtg01NTWora1FS0sLwuEw4vG4pYhGghiQXZmT5udLmO90OvUcaSRCeb1evbIuVemkQgT5xKhUKgVVVXXxSq64qyiKHrpqdR7m62EVtkn50Uh4y1cggMTC7pJLPJPDNlk8YximL2DbbecjhEBVVRVWrFiBmpoaANCL+tTW1iKdTsPpdCKTyUBVVdjtdgSDQUyYMEEX23w+X7eO2Rv7ra6uDu+88w4+/vhj7NixA6qqIpVKIZVKoaioCJMnT8bMmTNRWlqKkpISfeivUFJVVXWvtrq6Ot2rbdSoUfD5fLqXIIttDMMwzO4IC2wMs5eQyWQQjUbR0NCA6upqPZyzrq4OTU1NiEQiiMfjUFXVsB1NU880eaPRPBLByOPM6XTCZrMZ9iPnO5Pzo5E3mlxYgEI8CRLtksmk3oNOHm50bJoupChBLvEsmUwinU53KZ71JieZ3+/P6XVG4/3pAcAwDCPDttuupbGxEStWrMCmTZt07+5wOIyGhgZEo1E4HA6oqorW1lZdbCsqKsLEiRMxd+5cTJw40VBIpz9RVRXr1q3TQ0ibm5sBQE+RMHDgQEyePBkTJ07UO4pKSkpQWlraI++7QohGo9i0aRO2bNmCRCKB0tJSjBw5EkOGDIGiKGhsbNSLJxEstjEMwzC7EhbYGGYPQlVVXUSrqqoy5ESjcM5kMmkQv4QQeuim7HVGIprdbofX6zVU7yTvMBK/gGxPNBKmqMCAEAIulwtut1sP0STBLJVKGQQ0WUSjKp9yjrVcIpeVeJZKpfQiCeRdl8vjrDdYiWeU74w8z1g8Yxhmd4Jtt92DSCSCVatWYe3atbpgJYTQK2rT87alpQXNzc0GzzZZbHO73TulvfF4HMuXL8c777yDdevW6SkRMpkM3G43hg4dikmTJmHEiBF6mgcS2/ojlFRVVVRXV2PTpk2or6/X2zBixAgUFxcjnU6joaEBDQ0NLLYxDMMwuxQW2BhmN0MIgWg0irq6Omzfvl3PiVZbW6t7oiWTyaxcZ3LoJgA9ZFJRFLjdbvh8PkMONBKdMpkMUqmU7r0li1PkrUaiGiXrJ482QBO9SECTxbR0Oq0Lcl6v1yCgeTyeLAPcLJ6RcEbiGVU9MxcG6K1wpihKQWGbLJ4xDLOnwbbb7kUymdQLIsRiMX2+EAKxWAxerxdCCDQ3N+thkJlMRvdsmzx5si627Yxqn4Dmhffhhx9iyZIl2LZtmyF3aSgUwpgxYzB27FgMGDBAT9lQXFzcb6GkkUgEmzZtwtatW5FMJnWvtsGDB8PlciGVSuUU2wYOHIgBAwbsNK9AhmEYZt+DBTaG2QWQMV1bW4tt27ahqqpKrzrW1NSEcDiMVCqli2iUqyWdTkNVVYMnGqCFO3q9Xvh8PoPoRIKanOtFFshIRKNCBbRfWYgDtN5jEtBk45rEN4/Ho3uyyUIabZ9KpXThjEQzEs5IlBNCGMSzvhC0SDzLJZyxeMYwzN4M2267J6qqYtOmTVixYgWampqylgeDQfh8PjQ0NKCmpsZSbJsyZQrmzJmDSZMmFVQEqLdkMhls3boV7777LpYvX476+nrdRnE6naioqMDEiRMxcuRIeDwe/fkfCAT0/G19GUqayWR0r7aGhgZ4PB7dqy0UCkFRlJxiWyAQ0D3bWGxjGIZh+hIW2BimnyBPNBLRtm/fjh07dqC+vh6NjY2IRCJ6qIhchZPEMBkKo/R6vfD7/TnDKMljTV5GOcxkIYuENjNWSfvtdrsumpkH6i2mwgLyech51WTxrK/ELEVRdJEsl/eZz+dj8YxhmH0Wtt12f6qrq7FixQps3749a1l5eTkGDhyIRCKhJ/0nsS2dTuti29SpUzFnzhxMnDhxp4ht7e3t+Prrr/Huu+9izZo1WeLVyJEjMW3aNFRUVEBVVV2Ic7lcutjWl6Gk4XBY92pLp9MoKSnBqFGjMGjQIN2DLplMorGxkcU2hmEYpl9hgY1hekkkEkFNTQ22bt2K6upq1NTUoK6uThfRyOuMPNBkIYqghP4koAUCAT2PGUHimRDC4NlGoZRy5a9MJqN7ppGYRoOViOZ0Og3CGXmiyd5tcvvJmy2RSADQPOh6UlEzF7J4Juc5kwU0Fs8YhmHyw7bbnkNzczNWrlyJDRs2ZHWyFRcX6wUPtm3bhs2bN6O+vh4NDQ1obm7WxbZgMGgQ23pTtKdQmpqasGzZMixevBibN29GPB4HoNk1ZWVlmDhxIqZNm4aSkhLEYjHdC14OJS0tLe0TYTCTyWD79u3YtGkTGhsb4fF4UFlZieHDh+tebQCLbQzDMEz/wQIbwxRAJBJBdXU1qqqqDOGcDQ0NiMViemEB2QuNBgC6VxmJaMFgUBfRqMqmHE5JecxIQKMQSxK4yBONvNrMhQfMeckURYHL5dIFNAonpe0URTGEayaTSSQSCUNOt77CZrPlDdskz7O+Pi7DMMy+Bttuex6xWAxff/01Vq9ereceJTweDyZPnoyxY8eivr4emzdvxpYtW/QwSLPYNm3aNF1s622+0q6gkM3Fixdj+fLlqK6uNhRJGjp0KGbOnInJkyfD5/Ohra1NL54A9H0oaVtbGzZu3Iiqqiqk02mUlpZi1KhRqKioMOSFY7GNYRiG6UtYYGOYDqLRKKqqqvRwThLRyBONBC8S0sgzjSp0Uu4zp9MJn8+HUCiki0Vy5a9UKoVYLKbnM0skEnooZSKR0MNGCTmvmcvl0oU0s9eYoijw+Xx67ha3262vT15odCwqItDX2Gy2LE8zFs8YhmF2DWy77bmkUimsW7cOK1euRCQSMSyz2+0YN24cpkyZAq/Xq3u1bd26FU1NTbrYlkqldLFt+vTpmDNnDiZMmNDvYlt7ezvWr1+PDz74AKtXrzbka/P7/RgxYgTmz5+P8ePHw+FwoKWlBW1tbfo6brdbL5LQ21DSdDqNqqoqbN68GU1NTfB6vaisrMSwYcMMXm0Ai20MwzBM72GBjdmniEQiuohWXV1tyIkWDoeRSCQs86FlMhndC81ms8HlculeWMFg0CB8AUAikUAsFtOFrPb2dl2UI08xCvWkYgNyoQAK56TjkXDn9Xr1dVwul35cMj4TiQTC4bChEEFfYSWemT9ZPGMYhtl9YNttz0dVVWzZsgUrVqxAQ0ND1vIRI0Zg6tSpKC8vRzKZxLZt27Bp0yZUVVVZim2hUAjTpk3D3LlzMX78+H4X2xobG7Fq1Sp8+OGH2LhxI9ra2gBoNkVJSQnGjRuH+fPnY+zYsRBCoKmpSc8xB0BvM3m39TSUVAhh8GpTVRVlZWUYOXIkysvLs6qdstjGMAzD9AQW2Ji9CiEEwuEwtm/fjq1bt2LHjh16TrT6+nqEw2FDOKfZC42qa5KHmNvt1itNkheZ0+mEEAKJRELPsUaDXGmTBqAzx5q5QAAJZj6fTzcaqaCBy+XSveKoSEEikehX8cyqyiaLZwzDMHsmbLvtXdTW1mLFihXYunVr1rLy8nJMnToVw4cPh6IoSCQS2Lp1KzZt2oTq6mo0NTWhsbERTU1NuthWXFysi21jx47t15xtlB9t+fLl+PTTT1FVVaXna6MqpFOnTsXcuXMxatQopNNpNDU1obm52RBKWlRUpHu39TSUNJVK6V5tLS0tBq+2YDCYZeuw2MYwDMMUCgtszB4H9UJSPrTt27ejpqZGD+mk6pzmXGh0q6uqquc8IzHL5/PB5/PB7/fD5/Pp4paqqojH43r4Zjwe13OhyYOqqrDb7XA4HAbvNlmcImEuk8noudKoaAHth4oS9CWyeJZLQPN6vSyeMQzD7GWw7bZ30traipUrV2L9+vWGgkn/z96fxkaWpvm92P+cEyfixL5HkBHB4B7MTDKrqrOru6d7pqelaei670BjSRe2JcEDCLjCtQADxhVkQP4w+mRgAH8T/EWGrqFrw7iCfAHfqzEGwh0DHmNGM6qla8usyo3B5B7BYOz7iRNn9Qf2+1YEtyRzqcpkPj8gipk85ImFzOLD3/k/zwOcCKj19XWsrq5yYaZpGvb397G7u4tKpYJut4tGo4Fmszkl29577z0u215nsk1VVezv7+Pjjz/GkydPcHx8zC8e+v1+Pq/tBz/4AbLZLCzL4rLtvFbSWCyGUCh07VZSx3HQ6XSwu7uLcrkM27aRSCSwsLCAZDJ5JtUGkGwjCIIgLocEG3EtNMPC/cMOBpqJgOLCB3MRKPKrL8Js20a/38fh4eFUO+fR0REajQZPjk22XbJ2TrYSngkjJr3Yhk5WiPn9fsiyzGeTaZqG4XAIVVWnZqRNnn+ylZMJq0gkwts6JUniSwjYFjCWkJtMpr2q7ZeSJD23bZPkGUEQxLsJ1W43G03T+EIElgZjeDwe3Lp1C7dv34bgcvPaTYaJsNVBpXSA4+NjLttarRZ0XYckSYjFYnxBwuuUbawltFgs4tNPP8XOzg4ajQbvKIhEIlhYWMCHH36I27dvI5VKwbZtdDqdV95Kqus6SqUS9vf30el0uOjL5XLnptrY55BsIwiCICYhwUZcCc2w8G8/PcCfPjhCrafBsh1IooBUSMEfvJ/BH/5WHh7X9Qowy7LQ7XankmgHBweoVqt8sQDbaslk16RAYwKLzUZTFAWBQACRSIQXWMFgEG63G7ZtYzQaQdM0dDod9Pt9jEYj3tIJgJ+Pbfr0eDwIhUK8hVNRFH6/bNMmezysnXRyNtrLSi0mzy4SZywZRxAEQRDnQbXbu4FlWXj27BkePnzIZ5wBgGkDX/R82BqHMbQlOI4wVbv9F3cTOCodYG9vD7Va7VLZ9uGHH2Jpaem1tZGapolKpYKHDx/iyy+/xMHBAX8usiwjkUigUCjg3r17WF1dRTQa5WNBLmsljcVi16qVHMdBu93G3t4eyuUyHMdBIpFAPp8/d1Ybg8m2er0+9TUg2UYQBPFuQYKNeC6aYeFf/MlD/MdiHbbjIKTIkCURhmWjpxkQBQG/W0jij//exhnJZpoml2gHBwfY39/naTTWzsnmijGBxiTWpEBj7ZdsS2Y0GkU8HkcsFkM4HIaiKLBtG6qqot/v841U7NzsvJPLCjwez1SqbLI907ZtvoxAlmW43W7+cUyknd7ieR0kSbpwy+Zk2yZBEARBvChUu71bOI6Dw8NDPHz4EOVKFf+hGsSzoRsOAEV0oLhdkNweqIZzpnbr9/vY29vD3t4eGo0Ger0el23j8RiSJCEej3PZtrCw8MILB57HcDjE4eEhvvjiCzx9+nRqXpvP50Mmk8Ht27fx/vvvY2Fhgc9iG41GaLfbaLVa6Pf7L91KOh6PUS6Xsbe3h16vB5/Ph1wuh2w2e2GqDTiRbY1Gg7+ODJJtBEEQNx8SbMRz+Td/vYt//Zfb8MoSQt6zxVR3pEMdm/iDFQVrwhHf0HR8fIx2u43hcAhd16dmoZ3+tmPLBdxuNwKBAMLhMKLRKBKJBBdpPp8PlmVhOBxC0zT0+330ej1+/smFBYIgQBAEfn9MqkmSxOevsVlssixPpc8m//wiV2pdLteFiwJInhEEQRDfFVS7vbv8n//sG/y3Hx1AdEwo4nTN5XK5YIlumBDxT36xjH/8O4tTx3u9HnZ3d7G3t8dlFZNtmqZBkiQkEgku2+bn51+LbHMcB81mEzs7O/jyyy+xvb3N57WJoohQKIR8Po+NjQ1sbGwgm83yhJlpmhe2kkYiEb4o4SqP27ZttNttPsNOEATEYjHMz89fOKuNQbKNIAji3YIEG3EpmmHhH/w3n6DUVjEb9qLX7520Vo6/3ZZpWRZMdwCC2ob7r/8VBPv8LZcsiaYoCp9dFovFEI/HkUwmkUqlkEgkIIoi39DZ7XbRbrfR7Xahqipv6dQ0bap1lEk7JuomxRlrH52UZ5OtnNeZLXJanp33lgolgiAI4k2Aard3k8naLRVwQ1VVHB8fw+12Q1EUXvf0LQnZiIL/1//25wj6zq9dOp0O9vb2sLu7y0dsnJZtyWSSbwDN5/OXCqcXxTRNHB0doVgs4sGDB1z+2bYNt9uNWCyGpaUlbGxsYG1tDTMzMzypxpZjsXTb5Ly667aSaprGU239fh9+vx+5XA6ZTObSVBtAso0gCOJdgAQbcSmf7DTxz/77+1BkCX6PC8+ePTszSBcAHMkNuNxwf/nvoPQO+RKASCSCVCqFbDaLfD6Pubk5xGIxuFwuPg+NrY1nG606nQ5fNMDaNFkLKft2nWzZZPLM6/VCUZSpJBqTaFdpB3C5XBdu2SR5RhAEQbxtUO32bnK6duv3+zgsHZ6M3nDwmzEZbkiKH44o4385p+JX95Zx584d3m55Hq1Wi8u2Xq+HwWCAer3O559JkoRUKsVlWy6Xg8fjeeXPj7WQPnr0CE+ePMHh4SFvCfX7/UilUlhZWcHdu3extLSEWCw2Jb4uayVlM3yf10pq2zZarRZ2d3dxfHwMURR5qi2RSDxXMpJsIwiCuJmQYHsHuc4m0P/v4yr+6N9/g4jPDbdLxN7+HoaDAQThJCXmcklwuWRIsgeWy4P/8j0f/vYP5hGLxeDxeGAYBiqVCkqlEiqVCg4PD1Eul9FqtTAcDvmMtEmJxto5mRxjiTO2BdTr9Z6ZiSbL8qVXDc+TZ6ffUiFDEARB3CSodrs5vEztVq/X0Ww2+TxaPqpDlCAoQczXP8LP5oNYXV3F8vIyNjY2EIvFLn08jUaDz2zr9/sYDAb8gqmqqpAkCTMzM1hfX8e9e/eQy+VeeZ3lOA4ajQb29/fx9ddfY2dnB0dHRxiNRhBFEeFwGJlMBqurq7h79y7y+TwCgcDUOUzTRLvd5jfLsgBcr5V0NBrxWcODwYCn2mZnZ5+bagMul23JZBLxeJxqVIIgiLcEEmzvEC+yCfT0VdDxeAzgZKPTt9F7G92hhoFm4FfhY4jNHVQqFTSbTfT7fT4fbTKBBpwUL6clmt/vh9/v51s7J49fVNwweXZR6iwYDL6WK6gEQRAE8SZDtdvbz6uo3QDAMAz0+j0M+gOMRiPYjg1bkOGILgif/j8gNLahKAoymQwKhQJ+8Ytf4Hd+53eQzWYvfXyO46Ber3PZNhwOp2QbS7ZNyrZMJvPKZ9EahoGjoyNsb2/jyZMn2NvbQ7VahWEY8Hg8iEQiyOfzWFtbw61bt85N171sK6llWTzVVq1W+WKIubm5K6XagItlWzAY5Mk2qmkJgiDeXEiwvSO86CZQNsfjsDlExHMye4LNQWOpM9u2YLqDEEcdhL/4v03NYDtJubmm0mZ+vx+hUAh+v39qNtp5SwVkWb4wccb+TIUGQRAEQZyFare3m5et3dj83NNYloVer4faQAeGTQR+/W+gDU8uiLIN7rIs862hf/fv/l389m//NoLB4KVpLMdxUKvVsLu7i/39faiqiuFwiGaziWazyWXb7Ows7ty5g3v37mF2dvZKs8+uw2AwwOHhIYrFIra2trC/v49Wq8VbSBOJBPL5PG7fvo3V1VXMzMycO493NBqh1Wqh3W5Pya7JVtJwOHzua6KqKk+1qaoKv9+PbDZ75VQbcLLFtNlskmwjCIJ4iyDB9o5w8SZQB5Zto6vqGOkW/tYcsGId4vDwEJVK5eSqpGcJ3cyPIVg6BPOc+WsuBZDcCJU/Rbz5NTweD0+OsY2Z5y0VuEyesbdUOBAEQRDEi0G129vN87a490YGRoZ17ibQq37uf/mzPJKtb/AXf/EXePToEdrtNjRNg2VZsG0btm1DkiTEYjHcu3cPf+fv/B3cvXsX4XAYfr//QlHkOA6Oj4+xu7uLg4MDjEYjqKrKFySoqgqXy4VsNotbt27hBz/4AWZmZi6dAXddWLru4OAAT58+xd7eHm/jZFtIZ2dnMT8/j1u3bmFpaQnxePzc5/SiraSWZfF22lqtxjew5nK5K6faAJJtBEEQbwsk2N4BTm+T6na70HWdt26yIsqQ/eem0BxBwvD2H0BPrEIQBIjmGKJgQxBl2C4PRFFERujgd71l+JRvWzovE2ckzwiCIAji9UK129vL6RSaAwelwxL8fh/C4Qi/WFnpjpCL+vD//N/81tRMtrFp4Y/+/dXTb2xo/1/+5V/ir//6r/HgwQPUajU+r82yLFiWBUEQMDMzg3v37uGnP/0pNjY2+FKAQCBwrpyybRuVSoXLtvF4DFVVebJNVVXIsoxsNou1tTX84Ac/QDqdPjMv7WUwDINv/2SpNjavjbWQ5nI55PN5rK+vY25uDqFQ6NxzsVZSlm473UrK0m2nk3nD4RClUgmHh4dQVRWBQADZbBYzMzNXTrUBJNsIgiDeZEiwvQNMzuLwygJ2d/cw1sdnPs6R3IAoI/zkTxAYHUNRlG/nmEWiOA6s4gApaHADoghJFBFVRPw878V/8V4S0fC3c9Bex4p2giAIgiCuDtVuby+n56gNh0PsH+wDAARBgOJREAwGIHsDMB0R//Lvf4CfLMWnzjE2Lfx3n1xvfhsA9Ho91Ot1PH78GH/+53+OBw8eoNFoTC1IsG0bsiwjmUxifX0dd+/exdraGlKpFEKhEK8HT2/itCwLR0dH2N3dxeHhIXRdh6qqaLVaXLZNJts++OADpFIpBIPBV/ba9vt9HB4eYmdnB7u7u9jb20O9XodpmvD7/YjFYpibm8Pi4iJu3bqFbDZ76ZKBi1pJFUXhc9smBZppmnw5Q71en0q1xePxa9XQJNsIgiDeLEiwvQNMb5MSsLe3h9FIO0mjiQJEUYIkiYAkw5IU/H6yi9/K+xGNRs+kz1weL55UVfQ1E0HFhfcv2WJFEARBEMT3B9Vuby+nN4HWajU0m004mC7bBdEFQQniP4s28Q9+dx0/+tGPzsggzbDw4LBz7dptNBqhXq+j0+ng2bNn+Iu/+As8fvwY9Xqdt0g6jgNZluH1ehEOh5HL5bCxsYFbt24hFoshEAggFArxhNukcDNNE+VyGbu7uyiVSjAM44xsm0y2ffDBB0gmkxcmy67LZAvpzs4O9vf3cXBwgHa7DQAIhUJIpVLIZrN8w+rs7OyZecGTPK+VNBaLIRKJQJZlOI6DwWCAcrmMUqmE0WjEU23pdPpaqTaAZBtBEMSbAAm2d4DTV0Hb7TYgALJLhiRJkCQJoihibDnQDPvcq6AEQRAEQbxdUO329nLeJtDxeIxOt4PhYAjdOFlIwLoPgo/+R0SMBmKxGNbW1vDee+/hpz/9KRYXF59zT8/HMAw0Gg00m020223s7u7i17/+NXZ3d1GpVKDrOoCTre5sBIjH40E6neZbO1OpFCRJOiPcWKuraZo4PDzE3t4eyuXylGybnNk2NzeHtbU13L17l8u260ioi9B1HUdHR9jf3+ez2kqlEvr9PtxuN8LhMG/nvHXrFhYWFpBMJp+79OGqraTsNd7f30ej0YAsy4jFYi+UagO+lW31eh39fn/qfkm2EQRBvD5IsL0DPG+bFKPSHWEu6sO/OzXHgyAIgiCItw+q3d5eLqvdHMeBNtbQ7/fR1hyIozaCn/0b2MaJ6GJzcH0+H+LxOO7cuYMPP/wQP/rRj15qY6dt22i326jX6+h2u6hUKlxKPXv2DKVSCaqqQhAEBAIB+P1+vk0+kUjg1q1bWFlZwczMDGRZhiAIfLM8ayuVJAm6rk/JNsuyeBtmq9XCcDiELMuYm5tDoVDA3bt3kUgkLtzoeV36/T4ODg6wt7eHo6Mj7O3toVKpQNM0+P0nHR7ZbJYn6+bm5hCJRJ57XlVV0W630Wq1pqTXZCtpMBjEYDDA0dERSqUSNE1DIBBAJpN5oVQbcCLbGo0GGo0GyTaCIIjXDAm2d4SX2URFEARBEMTbB9VubzdXqd2GYwO/mpeQ17axvb2N3d1dHB8fYzgcQpIk+P1++P1+uFwuyLKMxcVFvPfee/jZz36GxcXFFxZSbE5bp9PB8fExjo+PYds2jo6OsLm5if39ffT7fXi9XgQCAb5dkyWzVldXebJtMp11WrhZlsVlV6VSOVe2ud1uzM3NYXV1FRsbG0gkEohEIi8t22zb5i2k+/v7qFQq2NvbQ7PZhGVZXFBlMhku+3K5HLzeiy9mMwzDQKfTQavVQqfTObeV1O/3o9PpYH9/H61Wi7922Wz2hVJtAMk2giCI1w0JtneE626TIgiCIAji7YZqt7eb69Rujmlge3sbxWIRnU4HtVoNOzs7ODg4QKvVgsvlQjAY5PPZBEFANBrFrVu38OGHH+LHP/7xC23tZHPa2u02jo+PUa1WMR6PIYoi6vU6Njc3fzP7dwSv1zslb2RZRiQSwcrKCu7cuYNcLgeXy8VlEwD4fD4u3NxuN98EyoTeaDRCu92emtk2Pz+PlZUVrK+vIx6PIxKJnFm2cF10XUe5XObbR0ulEg4ODtDtdiFJEsLhMGZmZpBKpbCysoKFhQVkMhkuFi/jslbSUCiESCQCl8uFVquFo6MjjMdjnmpjSyVeRCaSbCMIgnj1kGB7h3jRbVIEQRAEQbx9UO329nPd2s1xHFQqFRSLRZRKJT5If39/H7u7u2g2mwAAr9fL558BJ/PTFhYW8P777+MnP/kJVlZWriWlJue01et1VCoV9Ho9eDweOI6DbreLYrGIo6MjqKoKURSnpJAkSQiFQlheXsZ7772HpaUl+P1+DIdDmKbJP87n8yEYDEKWZbRaLZTLZdRqNdi2DU3TziTb5ufneVouHo8jGo2+tGzr9Xo4PDzE/v4+qtUqDg8PUS6XMRgM4PP5EIlEkMlkkEgksLa2hnw+j1QqdeX7vayV1OfzwXEctNttdDod3n47Ozv7wqk24HLZlkwmEY/HSbYRBEFcARJs7yAvuk2KIAiCIIi3B6rdbg4vUrupqoqtrS1sbW1hNBoB+FaE7e7uot1uQ9dPliWchqXbfvjDH+JHP/oRwuHwlR6nbdtotVqo1+s8cdVsNuF2u7lQYzPGOp0O6vU6BoPB1DlEUUQgEMDc3Bx++MMf4s6dO4jFYhiPx+j1ejAMg3+s1+uFy+Xi5+p2u1y2sWTbcDiEx+Phyba1tTUu2yYl43WxbRu1Wo0vRKjX61y66bqOYDCIeDyOmZkZzMzMYHV1FXNzc4hGo1e+j4taSQVBgCzLGI1GUFUVlmUhEAhgdnb2pVJtwMWyLRQKIZFIkGwjCIK4BBJsBEEQBEEQNxCq3QjgRASVSiUUi0VUKhX+fsdxThYltNvodrtnJA6DzTh77733eLrN5XJdep+s7ZG1jzIBxTbXMyzLgiAIKJfLKBaLaLVaU8JPEAT4fD7MzMzg/fffx7179zA3NwdBENDr9dDr9fgWU/Zc2XMZjUaQJAmapqHT6XDZpigK8vk8VlZWsLq6ikQi8dKyTdd13jbKWmUPDg7QbDYhCALC4TASiQRSqRRyuRyWl5eRy+Xg9/uvfB8sCcjSbePxGMCJELMsC8PhEI7jwO/3I5FIYGZm5qVSbezcJNsIgiCuDgk2giAIgiCIGwjVbsRper0etra28OzZsykxBZzILo/Hg16vh+3tbRwdHXFpM0k0GkWhUMC9e/fwox/9CNFo9NK0lKqqvH20XC6jWq3Ctu0pSSeKIrLZLGRZxtdff4379++jVqtNpdWAk8RaMpnErVu3cO/ePdy5cwfBYBD9fh/9fh+9Xo+Lp+FwiEajgXa7DcMw4PF4YFnWlGzzer3I5/NYXl7GysoK4vE4YrHYcwXiZXS7XRweHuLw8JA/58PDQ/R6PSiKgnA4jFQqhUQigYWFBSwuLiKTyVxbhKmqyue29ft92LaNwWDApaPH40E6nUYmk0EymXypVBvwfNmWSCReSuYRBEHcBEiwEQRBEARB3ECodiMuwrIs7O/vY3NzE41G48zxaDSKZDKJwWCAr776Cs+ePUOz2TwjvNxuN7LZLN577z38+Mc/xsrKyoWJJtae2mg0UKlUUKlUMBqNznx8NpvF2toaNE3DRx99hM8//xzlcnlq+D9wMpMsGo1iaWkJ9+7dw8bGBjKZDGzbRq/X48JN0zT0+31UKhXU63UuEgVBgKqqU7Jtfn4eS0tLWFpa4rLtKosKzsO2bT6j7ejoCK1WC4eHh6hUKlBVFcFgEJFIBOl0GvF4HCsrK5ibm8PMzMy158QZhoF2u81ns7E5bo1GA4IgIBaLIZ/PI5fLvXSqDSDZRhAEcREk2AiCIAiCIG4gVLsRV6HVamFraws7OztTCwWAk02frJ2x2+3iyy+/xIMHD1CpVNDv96fSbYIgIBKJYHV1FT/84Q9x7949xOPxM7Jock5brVZDqVRCt9uFoihTCatoNIq7d+8in8+jVCrhs88+w2effYb9/f0z9+12uxEOhzE3N4f3338f7733Hubn5+Hz+aDrOm8n7fV6qFarOD4+RqVSwXg85rLNMAz0er0p2ba8vIz5+XkkEomXkm3j8Zi3kLI03/7+PhqNBizLQigUQiwWQzKZRDqdxtLSEubm5hCLxa6dOmOCsd1uo16vo9ls8iSfLMuYnZ3FysoKcrkcwuHwS6Xa2HMj2UYQBHECCTaCIAiCIIgbCNVuxHUwDAM7OzsoFovodDpnjqfTaRQKBWSzWTSbTTx9+hRffPEFdnZ20Gg0eGsmw+PxIJPJYGNjAz/+8Y/5ZlDG5Jy2RqOB/f19tFotuN3uKSnn8/lw584drK2tweVyoVqt4quvvsInn3yC7e3tM3PjZFlGIBBAJpPBnTt38P7772N5eZkvF2Airdfr4eDgADs7OzxNx7abmqaJ4XAIXdfh8/l4K2c+n0c8Hn+pFFin0+GbR9vtNl+U0G634XK5+Ly2eDyO2dlZLC0tIZfLIRgMvtD9sVbScrnMt64Oh0P4/X7Mzc1hZWUFCwsLr2SWmqZpXOiRbCMI4l2EBBtBEARBEMQNhGo34kWp1WooFovY398/s2VUURS+ICAQCKDf7+Pg4AAPHjzAw4cPUalU+DZPhiAIiEajWFxcxIcffoj3338fqVSKzzpTVXVKtNVqNUiSNDULTZZlFAoF3LlzB4FAAI7joNls4uHDh/j00095u+vkbDlJkuD3+zEzM4OVlRW8//77KBQKSKfTXOIx4ba3t4enT59ib28PmqbBcRyYpgnLsqCqKkzTRDgc5rItl8vxZNuLyCnbtnF8fIzDw0McHx+j2+3i6OgIpVIJqqrC6/UiHA4jmUzyFs/5+Xlks9kXlmGGYaBer+Pg4IC3/dq2jUAggIWFBSwtLSGfz79wUm8Skm0EQbyLkGAjCIIgCIK4gVDtRrwsmqZhe3sbxWIRg8HgzHE2Ly2TyfDkV7VaxdbWFr766ivs7u6iXq9zYcVQFAXpdBrr6+v40Y9+hMXFRYTD4ak5bWx2GZuZxhAEAQsLC7h79y7i8Th/f7fbxebmJn7961/j0aNHOD4+xmg04sdFUeQbNtnn37lz58yCAZbke/z4MZ49e4ZerwfHcaDrOkajEUajESzL4u2wS0tLyGazfGaboigv9DqXSiUcHh7yxQWlUgnHx8cwTROBQADRaJSn55jgm52dfeHtp7Zt4+joCDs7O9ja2kK324UsywiFQsjn81hcXEQ6nYbX632h859+fs1mE/V6fer7iGQbQRA3DRJsBEEQBEEQNxCq3YhXheM4qFQqKBaLKJVKZzaL+v1+FAoFrKysTAkmtlHz4cOHU+m2yVlvoigiGo1ifn4eP/jBD/Dee+8hmUzyLaBsCydLdU3ODJuZmcHGxgZyudzU+weDAba3t/Hll1/i/v37Z2bGCYIAn8+HSCSCfD6P27dvY2NjA/l8HoFAYOp5l8tlPH78GMViEd1uF6PRCLquYzgc8u2d0WgUKysrWFtbw+zsLOLxOBKJxAvJtk6ng4ODA5TLZfR6PTSbTd5CCpxIKSbbksnkVJruReepGYaBo6MjbG5uYn9/H4PBAD6fD7FYDOl0ms+De9lNpMCJbGMSlWQbQRA3DRJsBEEQBEEQNxCq3YjXgaqq2NrawtbW1lRCDDiRZfl8nrdhTqLrOo6Pj7Gzs4P79+/zVlBVVaeEndfrRTqdxq1bt3Dv3j0kEgmejGOiyev1Ts1pC4fD2NjYwPLy8plEl6Zp2Nvbw4MHD/DVV1/h8PDwjORj7ZjZbBbLy8u4e/cuFhcXEY/HuVCybRuVSoUn+nq9HlRV5YsR+v0+xuMxIpEIlpaWUCgUMD8/j1QqhXg8fu0kGGshPTg4QK1WQ7/fR61Ww+HhIQaDAWRZRjgcRiwWQywWw8zMDObn5zE3N4dQKHSt+2I4joPBYICjoyMUi0VUq1U+hy4Wi3GpF4vFEIlEplp4XwSSbQRB3DRIsBEEQRAEQdxAqHYjXie2baNUKqFYLKJSqZw5Hg6HUSgUsLS0dEaSOI6DdruNw8NDPH78GI8fP0a1WkWn05maoSaKImKxGObm5nDr1i3Mzs7CsizUajXU63V4PJ4pyaMoCm7fvo1bt26dmx7TdR3lchkPHz7El19+ib29PbRarakFDYqiIBAIYHZ2FvPz87h79y5WVlaQTqf5fVmWhaOjI+zt7WFvbw+DwQCqqqLb7aLVaqHX60HXdXi9XiwsLGBlZYXLr2QyCZ/Pd63XmrWQHhwcoNvtotPpoFqtolwuwzAMKIoyJdtyuRzm5uaQzWZfuMXTMAzeqsqkpGVZCAQCCIVCCAaDCIfDiEajL9wae/o5kmwjCOJthwQbQRAEQRDEDYRqN+K7otfrYWtrC8+ePZsSZMDJooHFxUUUCoWpmWmTjMdjng57+PAhDg4O+LyuyWUJXq8XyWQS6XQa2WwWtm1DVVXIsjwlXyRJwurqKtbX1y9Mc1mWhePjYzx+/BhfffUVtre30Wg0MBwO+ce43W4EAgF+f2ybaTab5ULJsiyUSiXs7e3h8PAQhmFA0zR0Oh0+5F/TNPh8PszMzGBhYQG5XI4vLbjudlAmJsvlMobDIZdg9XodALgAY7Itn89jbm4Os7OzL5Q4cxwH/X4f9XodlUqFC0lBEOB2uxEKheB2u+H1erlsCwaDL9VKeplsSyaTL7XFlSAI4nVCgo0gCIIgCOIGQrUb8V1jWRb29/f5Rs/TxONxFAoFLCwsXCh7HMfhs9eePn2Kzc1NHB8fo91uYzwe83ZSURTh8/m4THK73byVcVLuzM/PY2NjA6lU6sLHbds26vU6isUi7t+/j2fPnqFSqUwJPlmW+ZKETCaDlZUV3LlzB3NzcwiHwwAA0zRRKpWwu7uLcrkMy7JgWRafpVatVvl8s3Q6jXQ6jdnZWWSzWeTz+WstLWCCkMlIVVXRaDRQKpXQ6/UgSRJPmcViMSSTSczNzSGXyyGZTE612F4VXdfRarVQqVSmvh4ul4vPtRMEAS6XC5FI5JW0kl4k28LhMBKJBMk2giDeKEiwEQRBEARB3ECodiO+T1qtFra2trCzszM17ww4kVXLy8soFApcTl3EaDTi2y6fPHmCUqmEWq2GwWAA0zThOA5M04QoilAUBW63G8lkEtlsFuFwGLIsAwBSqRRfZHBZuspxHLRaLWxvb+Obb77B5uYmXzhgGAaAk4Sc3+9HJBLh7Zjr6+tYWFhAMpmEIAgwDAOHh4fY3d3F0dERbNuG4zgYjUZotVo4Pj7ms9RSqRRSqRTC4TBmZmaQy+WQyWQQDAavJKdGoxFvIe33+zxxViqVYBgGXC4XQqHQlGybn59HLpdDJBJ57vnPe4263S4ajQbf1irLMlwuF1wuFxzHmVoowRYzvGwrKck2giDedEiwEQRBEARB3ECodiPeBAzDwM7ODorFIjqdzpnj6XQahUIB+Xz+uakqljQrlUrY3NzEzs4Ojo+P0Wq1MBqNYBgGdF2HrusQBAEejwf5fJ6LJL/fj3A4jPX1dayurl5JXnU6Hezv7+Phw4d4+vQpn0fGFjyIogi/349QKIRMJoPZ2Vmsr69jcXERs7OzkGUZuq5jf38fe3t7qFQqXEBpmsaXFzDZxlo7WfpsdnYWs7OzCIVCCIVCz33MrVaLt5BqmoZ2u41arYZqtQrbtvmMOSa8ZmZmeLLturPhgJP2XpZq63Q6cBwHgUAAPp8PoijCNM2ptttX1UpKso0giDcREmwEQRAEQRA3EKrdiDeNWq2GYrGI/f39qdlqwMlygZWVFayuriIQCFzpfIPBAJVKBbu7uzxpVqvV0Ov10O/3oWkaVFUFAEQiEaRSKWSzWT40/4MPPsDt27evvAhgMBigVCrh0aNHePLkCQ4PD9FsNjEcDuE4Dm+TDAaDmJmZQSqVwq1bt7CyssIFlqZpODg4wO7uLqrV6pRsGw6HaLVaUFUVbrebSyi2xCAcDiOZTCIcDvNFAyyhdxrLslCpVHBwcIBGo8FFGEvjscc6KdsymQzm5uaQyWQuPO9F2LbNU221Wg2j0QgejwfhcJin5FRVRafT4V97l8uFaDTKb1dtjz0NyTaCIN4USLARBEEQBEHcQKh2I95UNE3D9vY2isXilBBhZLNZrK2tIZPJXDnhZFkWqtUq32zKUlxsfttgMMBoNOLtkpFIBOl0GplMBj/4wQ/ws5/9DIlE4srPgbWuPn36FE+ePMHBwQGfsWZZFgDA5/PB7/fzFtClpSWsra0hl8shFothNBphf38fu7u7qNVqADAl29rtNizLgsfjQSKR4AsRWBKPJcVYui0UCp0rxkajEQ4PD3F4eIjBYIDhcIhGo4GjoyOe9gsEAggEAojFYojH48hms5ibm0Mqlbr2vLbRaIR2u43j42N0u12eaovH44hGowBOkoGtVosvxWCtpLFYDNFo9IVbSUm2EQTxfUKCjSAIgiAI4gZCtRvxpuM4DiqVCorFIkqlEk7/WuL3+1EoFLCysnJt4dLr9XB0dIS9vT2emjs4OECtVuOtjJIkwev1crHz3nvv4ec//zk2NjaeOxtuEl3XUalUsLW1xWXb0dERBoMBxuMxAPDWzFgsNrWVdG5uDul0GpqmYW9vD7u7u3xBxKRs63Q6sCwLoVAIqVSKb0d1uVw8icbm0E0Kt9Myqdls4vDwEEdHRxiPx+j1eqjX6zxNJ0kSfD4fn9cWj8d5C2ksFrvW18CyLJ5qq9fr0DQNbrcbkUgE8Xgc8Xgcpmmi3W6j1WpNCTGv18tl24u2kpJsIwjiu4YEG0EQBEEQxA2EajfibUJVVWxtbWFra4vPN2OIooh8Po9CoYB0On3tcxuGwdNtDx8+xNbWFra3t3F0dIThcAhRFPn2S5/Ph0wmgw8//BA/+9nP+NZNj8dzpfsyTRPHx8fY29vD48ePcXBwwDd7snZVj8fDU2islfT27duYn59HJpOBYRjY29vD3t4ems0mgG9lm6qqaLfbAIBgMIhsNsvFlyzLkCQJHo+Hp848Hs+UcGPPw7IsHB0d4fDwEI1GA4ZhoN1uo1qt8ll5Ho8HiqLwFtLJTaR+v/9aXwP2uI+Pj9Hr9QBgKjEXDof5Y2i326+8lZRkG0EQ3wUk2AiCIAiCIG4gVLsRbyO2bfM2z0qlcuZ4OBxGoVDA0tLSCwuRdruNvb09fPPNN/j666+xvb2N4+NjWJbFN54KggC/34/l5WVsbGzg9u3bWFxcRCqVQjQavVLbpG3bqNVqODg4wKNHj/imz06ng8FgAMdxIMsyX5KQTqcRj8exurqKpaUl5HI5OI7DZRsTa5OyjcmwUCiEbDaLVCoFt9vNt3oCmNri6vF4EAwGuXBTFAWqqvIWUlVVoaoq33SqaRoEQYCiKFOpMrbtNJvNXuvrYFkWOp0O6vU6ms0mxuMx3G43wuEwT7V5PB4+063VaqHdbr/SVtLnybZEInHtGXQEQRAACTaCIAiCIIgbCdVuxNtOr9fD1tYWnj17xgULQ5IkLC4uolAoIB6Pv/B9qKqKJ0+e4OHDh/j1r3+NUqmEwWAATdOg6zosy4IgCIjFYpifn8f8/Dzy+TxWV1eRyWSQSqWutH3TcRw0m02USiU8efIE+/v7KJVKXPKYpglJkvgstFQqhVgshsXFRSwvLyOXy0GWZS7but0uPy+Tbd1uF4Ig8GRbMpmE3++H3++H2+2GKIoYDofQNI0/LrfbzRcmBINBjEYj3uJqmib6/T6azSZvIWXbWVn6LBqNIpPJIJfLIZ1OXytdxubMVatV9Pt9PqttMtXGWkPZAgg2T4/xsq2kJNsIgniVkGAjCIIgCIK4gVDtRtwULMvC/v4+Njc3+XyySeLxOAqFAhYWFnhq67o4joN2u43NzU18/vnnePToERqNBnRdR6/Xg6ZpMAwDPp8PkUgEsVgMqVQKc3NzmJ+fx+LiItLpNBKJxJUkU6fTQblcxubmJvb29lAul/mSBE3TIIoil2NM9MzOzvIlCR6PB4eHh9jb20O/3+fPYVK2SZKEUCiE2dlZJBIJ+P1+RCIRBINBuFwuqKqKXq831ZIryzJCoRB8Ph+GwyFPmrHkWaPR4Ek6SZIgyzJ/PaLRKLLZLHK5HOLx+JVll2ma/NzNZhO6rvOk3WSqjaHr+qWtpLFYDJFI5NqtpCTbCIJ4WUiwEQRBEARB3ECodiNuIq1WC1tbW9jZ2ZlqfQRO5NDy8jIKhcK1lhScZjAYoF6v49GjR3jw4AGOj4/59s1+v49+v88XJLDWyUgkgmw2i2w2i6WlJZ5uu8rj6Pf7KJfL2N7exvb2NqrVKsrlMvr9PobDIW9X9fv9iEajXPasra1hbm4OiqKgWq1ib2+Pi6HTsk2WZQQCAWQyGcTjcS4K4/E4T671er2pWXHAibRyuVwYDAZot9twHAfj8Rjtdhu1Wo23kIqiCI/HwwVXLBZDLpdDLpfj20+v+tq3Wi3UajUMh0OeaotGo2dSbQBeWyspyTaCIF4EEmwEQRAEQRA3EKrdiJuMYRjY2dlBsVjkc8gmSafTKBQKyOfzV5qXdh7j8Rj1eh17e3vY3NxEuVzmibZWq8XbLdlSAdu24XK5EAgE+KbQTCbD021sPtpljEYjlMtl7O3tYWtrC7VaDaVSCd1uF8PhEJZlwefz8SUJyWQS0WgUKysryOfzUBSFz5hjouy0bGNtoWzmG9saGo/HEYvFIAgCl21M8rHzDIdD9Ho9DIdDeDweGIaBTqeDWq0Gx3HgOA5cLtdU62YikeCy7arLIth5m80mWq3WVKqNtZCed66LWkl9Ph+Xf4FA4FqtpEy21et1/loA4K9/PB4n2UYQBAASbARBEARBEDcSqt2Id4VarYZisYj9/X3eLshQFAUrKytYXV1FIBB4ofObpolWq4WDgwMUi0UcHx9DFEVomoZ2u416vQ5d1+F2u3nCi7VdTqbbZmdnkc/n+TKC5y1L0HWdb/rc3NxErVZDuVxGu91Gv9+HYRhQFAWBQADBYBCJRIInxxYXF+HxeDAYDHBwcMDnrk3Ktn6/z6UVm/nm9XqnZJvb7eaz2Jh0Gw6HME2TSyf2OAzDwHA4RLfbhSiKcBwHoigiFAohGo0iEokgnU4jl8thdnb2Si2cjuPwVFu9XufSkLW7npdqm3z92u02Wq0Wut0u/96YbGu9bivpaDTiybZJ2RaJRPg2UpJtBPHuQoKNIAiCIAjiBkK1G/GuoWkatre3USwWp9JLjGw2i7W1NWQymWsPwwdOZE+n08HR0RGePn2Kw8ND3io6GAzQ6XTQbrchyzJ8Ph86nQ663S5GoxF0XYfL5eIyK5vNIp1OY2FhgafbLluWYJomjo+P+YbVarWKo6Mj3sI4Go3g8Xim5rbFYjGk02ksLS3B4/FA0zSUy2WMx2P+fJhsGwwGUBQFwWAQyWQSkUhkSrbF43GevrMsa0q4NRoNHB8fo1arQdd1yLIMwzCgqipfjMDSfUxshcNhzM7OYm5uDolE4kpfj0lh1m63YRgG3G73c1NtwEkrKfv6nG4lDYfDiEaj124lJdlGEMRpSLARBEEQBEHcQKh2I95VHMdBpVJBsVhEqVTC6V93/H4/CoUCVlZWrj2bizEYDFCtVrG5ucnnwQUCAUiShNFoxFsxXS4XLMtCs9nkLYaj0QiO4/A5aEy2sW2cqVTq0mUJtm3zNNvW1hYqlQpqtRoqlQqGwyEGgwFcLhf8fj+fX8YEFEu26brOk3fsNWOybTgcQlEUhEIhxONxLtvY30+LLNu20e/30e12sb+/j52dHdTrdTiOA13XMR6PeSutLMtwHAdut5vPaguHw3w5wlVm1jmOg16vh3a7jUajwdOCk3PlLkq1TX79mKyblGMv2kpKso0gCIAEG0EQBEEQxI2EajeCAFRVxdbWFra2tqa2ZQKAKIrI5/MoFApIp9MvdH42p61YLOLZs2fo9/sIBoNcpliWBbfbjWg0ivF4jFqtxmeWsYUCmqbB7Xbz2W1sEcH8/DxPt4VCoXPv33EcNBoNlMtl7OzsoFwuo16vo1Kp8IUMABAIBOD3+7kkY62kPp8P4/EYnU4HlmXxc07KNr/fzzd6hkIheL3eqQ2fpyWlbdtot9v8da9Wq7BtmyflWEup2+2GJEkIBAJ8XlskEsHc3Byy2Sy8Xu+VXv/JNlDTNPlCh+el2hiXtZKyZNt1WklJthHEuwsJNoIgCIIgiBsI1W4E8S22bfP2ykqlcuZ4OBxGoVDA0tLScxcRnIdpmmg2m9jZ2cHm5ibq9Tr8fv+UJAqFQsjn8/B6vahWq9jf3+cbMCfTbbZt88UDbF4bm12WSqWQTCYvfIztdhvlchn7+/s4ODhAs9lEpVJBp9Phcou1kbLWStZK6vP5eGsnk0yTsk1VVS7potEol22BQIDLtvOkWK/XQ7FYRLFYRLvd5osSer0eTNOE2+2GLMuQZRnJZBKJRAKhUAjJZBK5XA6ZTAYul+vS19+2bZ5qazabGI1GEAQBPp8PoVAIiUTiuak2dp7ntZLGYrErL2sg2UYQ7xYk2AiCIAiCIG4gVLsRxPn0ej1sbW3h2bNnXKAwJEnC4uIiCoUC4vH4tc/tOA7a7TYODw/5nDY2G43JHUVRcOvWLSwvL6PT6aBSqeDg4AD1eh29Xo+n29iW0sl02+zsLKLRKHK5HGZmZviyhPPEUb/fR7lcxuHhIfb29tBsNnF8fMzbIkejEd9IGggE+Hw0Jsosy8J4POaLGCZl23g85vIqEokgGAzC6/XC7/dz2XZ6ppzjOKjX6zg4OMDOzg7falqtVtHr9fi8NuBkOQRL84VCIczOziKXyyGZTD53KyxbPtFqtdDr9WBZFmRZht/vv3KqjXFZKylL3V21lfQ82cbEHck2grgZkGAjCIIgCIK4gVDtRhCXY1kW9vf3sbm5iUajceZ4PB5HoVDAwsLCcxNU59Hv91GpVPD48WPs7u5CEAQ+pw04kXkrKytYX19HKBRCu93G0dERjo6OUCqVeCKLpdtUVYVlWXyDZiaT4YmvTCaDVCp14bIEVVVxdHTEW0mbzSZqtRoajQb6/T5UVYWiKDzdFg6H+Xw0RVFgWRZM0+RialK26brOk22RSIQn93w+H5dtfr9/6vEYhsHlX6vVgqZpaLVaOD4+RrvdhuM4sG0buq5PycVEIoF8Po9cLodoNHrp62/bNk8IttttjMdjPvvuOqk2hq7r/Fwv20pKso0gbiYk2AiCIAiCIG4gVLsRxNVptVrY2triCwsmkWUZy8vLKBQKVxrCfxpN01Cv1/Ho0SNsb29D07SpOW0AkM/nsb6+jpmZGQAns8UqlQqXbc1mE71eD/V6Hd1uF6qqYjQa8XRbKpXC7OwsIpEIZmZmeLrtvGUJ7NylUgl7e3tc9FSrVQwGg6klCZOtpKFQCC6XC47jwLIsntyalG2GYSAQCCAUCiEUCnHZ5vV6uWwLBAJTj6ff7+Pw8BClUgmapvGlEGxpw3g8xng8hmEY8Pl8SCQSvG2WCdDLNrACJ0KLtX32+32eamNJtOuk2oATOdvtdnm6zTAMAC/WSkqyjSBuDiTYCIIgCIIgbiBUuxHE9TEMAzs7OygWi+h0OmeOM6mTz+ef26p4GjanbXNzE8ViEa1WC8FgcGpJQCKRwMbGBubn56daMxuNBhdulUoF3W4X3W6XS5lJuRUOh5HJZJBIJHjSjaXbTi9LME0TlUqFz22r1+s83dbtdjEYDOA4DpdtrM1y8jyO4yAcDkOSpCnZZpomQqEQgsEggsEgfD4fvF4vFEXhsi0YDE6dh7WQHh8fw7ZtmKbJt4VqmobxeIxutwvDMBAMBhEKhfjzXV1dRaFQOJOWm4SJMbbQYDLVFgwGr51qYwwGA55uu6iVdPK5XgTJNoJ4uyHBRhAEQRAEcQOh2o0gXo5arYZisYj9/X3eDshQFAUrKytYXV09k8h6HmxO2+7uLh4/foxKpQKfzwefz8fFTiAQwPr6OlZXV88IldFoxFtJy+UyOp0Ol0adToen22RZRjAYRDKZxMzMDB+sz9Jtp5clWJaFWq3GWzer1eqUbOv3+zBNky828Pl8fAbZ5OsTi8XgdrunZJtt21yIBQIBnmrzeDxTso09f13XcXR0hIODAy46x+MxWq0WRqMRLMuCqqpoNpt8cUM4HIbH48Hs7CyWl5extLSESCRy4UIIVVV5qm0wGMC2bciyDK/X+0KpNgbbbHpRKylrvX1eKynJNoJ4+yDBRhAEQRAEcQOh2o0gXg2apmF7exvFYhGDweDM8Ww2i7W1NWQymWsnn/r9PkqlEr755hscHBxAlmUEAgGeXnO73bh16xZu3759bhukbduo1+tcuDEZxoTbcDjEcDiEaZpcQs3OziIejyMcDnPZdnpZAkvNlctllEolHB8fo9lsol6v8zZLXdfh8XgQCASgKAoikQgCgQAsy4Jt2xBFEbFYDH6/f0q2AeCibbKF1O12c9kWCoX4Y+n3+zg4OECpVOKJs+FwiHa7DcuyIIoiVFVFo9GAKIpwu93w+XzweDxIJpOYm5vD3NwcF3ynpZllWeh0OnwpAlt8oSjKS6Xa2Lkn58CdbiVl6bbniTySbQTxdkCCjSAIgiAI4gZCtRtBvFocx0GlUkGxWESpVMLpX6P8fj8KhQJWVlam2j6vgqZpOD4+xtdff42dnR3Yto1AIMCliSiKWFpawsbGxqXD/QeDwVQraafT4csSTqfbAoEAT7Qx2ZNOp7lw83q9/LztdvuMbGs2m2g0GhgMBvycTJgFg0H4/X7Ytg3DMHh6i4kzJtuYJGJbSBVF4bKNpciY3GIy8eDgANVqFbZtw7Is9Ho99Pt9CIIAURQxGAzQ7/fhdrvh8XggSRJcLteUSGStq6FQaOprNRwOefvocDiEbdtwuVwvnWqb/Pq8bCspk231ep0Ly0nZlkgkXmgpB0EQLw8JNoIgCIIgiBsI1W4E8fpQVRVbW1vY2trCaDSaOiaKIvL5PAqFAtLp9LXOa5om6vU6Hj9+jM3NTQwGgzNz2rLZLDY2NpDJZC49l2VZqFarPN3GFiX0ej20Wi0MBgM+K41tAZ2dneXbMNkygVQqhXg8zlsae70eyuUyyuUyKpUKWq0WWq0W6vU6P6cgCLzt1efzcdk2Ho/hdrv51k1JkrhskySJL0fweDxctsmyzOVWJBKBIAjQdR3lchkHBwfodrsATubntVotjMdj/lg7nQ4sy+ICDwBcLhcikQiSySRkWYbb7eb3yz6OzX5jaT3TNOE4DhRF4WLyRVNtDNZKyoQe+7X8Oq2kLLnXaDRIthHEGwAJNoIgCIIgiBsI1W4E8fqxbRulUgnFYhGVSuXM8XA4jEKhgKWlpQtngV103na7ja2tLTx69Aj1ep3PPWNSJxqN4u7du1hcXLzSwoVer8dl2/HxMfr9Pnq9HjqdzlS6zeVyIRgMIh6PI5VKIRwOIxwOT6XbWMpKVdULZVu/38dwOITjOPB6vfD7/fB4PPB6vVy2eTweRCIRvnGTyTaXy4VwOIxQKMSTaF6vFy6Xa0q2iaKIXq+Hw8NDHB4e8vZONl/NcRy4XC4YhoF2u81lHZsbx+bJ+f1+LrJkWZ4SbrZtcwmmqio/p6IoryTVBlzcSiqKIkKh0JVaSUm2EcT3Dwk2giAIgiCIGwjVbgTx3dLr9bC1tYVnz55x0cOQJAmLi4soFAqIx+PXOm+/38f+/j7u37+PcrnMU1RMqvl8Pty5cwdra2tXlniGYUyl21grabfbRafT4Uk0y7KgKApCoRBmZmZ4qioajSKZTCKdTvMk2Hg85osXjo6O0G63eSsp20hq2zbcbjf8fj/cbjcUReGyjcmvWCwGn8+H8XgMVVXh8Xh4GylLnDHZFo1GEY/HedtsrVbjLaSO48BxHPT7ffT7fYiiCJfLheFwiE6nwzeiRiIRnp4LBoM8VcdgyyK8Xi8sy8J4PMZwOIRlWfz1CQQCUwm7l6Xf7/N0G5NlwEkbMku3XbZc4zLZlkwmEY/HSbYRxGuABBtBEARBEMQNhGo3gvh+sCwL+/v72NzcRKPROHM8Ho+jUChgYWHhWpJD0zSUy2Xcv38fOzs7kCRpak6bLMsoFAq4c+fOtTebttttPretVqvxdBtbmDAajTAajfh9xuNxJBIJRCIRhEIhnnZjM85M08Tx8fG5sq3T6fDFC6IoIhAIwO12w+VywXEcGIYBSZK4SAqFQtB1HaqqQlEUnmxzuVx86yf7eCbbLMtCqVTCwcEBer0e/7p0Oh3oug5JkiAIAn8sLEUXDAbh8/mQTqcRDofhOA56vd6U5JIkCaIowrbtqe2pLpcLHo/nlaXaGJe1krJk22WtpBfJNtYGTLKNIF4dJNgIgiAIgiBuIFS7EcT3T6vVwtbWFnZ2dmCa5tQxWZaxvLyMQqGAcDh85XMahoF6vY4HDx5gc3MThmEgGAxyoSMIAhYWFrCxsYFEInHtx6zr+tSihH6/j263y4UbS7fZtg2Px4NgMIh0Os1FTygUQjKZ5MLN7XajVquhVCrxtNykbBsMBrwlkm3/FEURjuNwGcYEWCQSgWVZUFUVXq8X0WgUgUAALpeLLyNgHx+PxxGLxTAYDHB4eIhSqcSThePxmMsqWZZhGAYajQYcx+Fiz+v1IhwOI5fLIZ1OwzAMPsOOLSgwDIMvQ5AkCYqi8PlxrzrVBlzeSsqShZe1kpJsI4jXCwk2giAIgiCIGwjVbgTx5mAYBnZ2dlAsFtHpdM4cT6fTKBQKyOfzV5qnBnw7p+3Ro0d49OgRut0ub2VkQmdmZgYbGxvI5XIvJHkcx0Gr1eKtpGxrKNvcyeaSjcdjCIKAQCCAWCzGlwCEQiFEIhEu22KxGFqtFp/b1ul00Gq10Gw20Wq1MBwOMR6P+dw2j8cD27bhOA5M04QkSVwksRlso9EIPp+PyzZJkiBJErxeL0RR5LItEomg2Wzi8PCQt5ACJ5s9B4MBJEmCLMt8k+hkQsztdiORSGBubg4zMzMQBIGn/Hq9Hj9Hv9+HpmmQZZm31rJk3atMtTFeppWUZBtBvHpIsBEEQRAEQdxAqHYjiDeTWq2GYrGI/f39qRZDAFAUBSsrK1hdXb1Wm2ev18OzZ8/w9ddf4+joCH6/f2pOWzgcxsbGBpaXly/dSvk8NE3jybZKpYLhcMglU7fbxXA45IsA2LKAZDKJSCSCSCQCv98/tZ3UMAwu29jWTibbBoMBxuMxn9vm8Xh4YsswDD7LjaW2ZFmGpmlc8vl8Pt7OyWRbOBxGPB6H3+9HtVrFwcEB+v0+gBNh2ev1oOs6ZFmGKIpot9tcXLKEnizLmJ2dRS6XQzKZhCAIsCyLi8dGo4FKpYJer4fxeMy/rmxuXT6fRywWe2WpNsZ4PObJtuu2kpJsI4hXAwk2giAIgiCIGwjVbgTxZqNpGra3t1EsFjEYDM4cz2azWFtbQyaTubKMGY1GODg4wFdffYW9vT243W4Eg0EuRxRFwe3bt3Hr1i0oivJSj9+2bTQaDZ5ua7fbU8Kt3+9DVVXous4lF5tPxtJtgUCAyzaPx4NarYZyucwXI0zKttFoBMdxIAgCFEWBruuwbRumafK2TCbbFEWBYRhctgUCAQiCAEEQuGxjs+MkSeLz4lgLqWEY6Ha7EAQBbrcblmVx+cTaVQOBADweD3K5HHK53FSbL5N15XIZpVIJzWaTP162ITWbzWJ+fh7xePzKqcWrwubNMWl5XitpLBY7dykGyTaCeHFIsBEEQRAEQdxAqHYjiLcDx3FQqVRQLBZRKpVw+tczv9+PQqGAlZWVK0sxwzBQqVRw//59FItFOI4zNadNkiSsrq5ifX0doVDolTwPVVW5bDs+PsZoNJqa36aqKkajEb//YDCIZDKJcDiMcDjM02ipVAp+vx+qqqJcLqNer6Pf76PVaqHVaqHX63HZ5jgOPB4PxuMxTNPkc+Em2zN9Ph9//rFYDH6/HwC4bBMEgX8s265aq9X412E0GmE4HEKSJHg8HgyHQzQaDdi2zbeesll0TLZ5vd4zr02pVEK5XObCkM1t8/l8yGQyXNIFg8FXKtwcx8FgMLiwlZSl285LTJJsI4jrQYKNIAiCIAjiBkK1G0G8faiqiq2tLWxtbXEZxRBFEfl8HoVCAel0+krns20bzWYTX3/9NR4/fozhcHhmTtv8/Dw2NjaQSqVe2fOwLAv1ep0LN7YYgMk2lkgzDAOiKMLj8fBUFUu3+Xw+JJNJhEIhniCr1WoYDAZctrEZcEy2uVwuaJoGwzBg2zYUReELByKRCHw+H2RZRjAYRDweh8/n44+ZvSbBYBB+vx+apqFarfIWUsdx0O/3YZomZFmGLMtot9toNBp8e2g0GoUkSUgkEsjlcpidneVbXtnXgyXzGo0GOp0On9vGUnWRSATRaBShUAihUAjBYPCl2npPc1Erqdvt5glANt9uEpJtBPF8SLARBEEQBEHcQKh2I4i3F9u2USqVUCwWUalUzhwPh8MoFApYWlo6t83vPHq9Hp48eYIHDx6gXq8jEAhMzWlLpVJYX1/H/Pz8K58P1u/3uWyr1WoYj8e8lZQl0lRVhSiKEEWRb+CcTLexBJrjOFBVFbVaDaqq8iUJnU4Ho9GIL0UATjaisvf5fD54vV4oisLnwbE/JxIJeL1eLuqYbGMbSllqjbVaWpaFXq/HW0gdx+GPgaXhQqEQJEnCzMwMcrkcUqnUlLQajUY8kadpGobD4VQrbCAQQDAYhCAI8Pv9U8LtVYms57WSTi55mITJtnq9zkUwyTaCIMFGEARBEARxI6HajSBuBr1eD1tbW3j27BmfEcaQJAmLi4soFAqIx+NXOt9oNMLOzg6+/PJLHBwcwOv1TkmbYDCI9fV1rK6uvhZJYpomqtUqF27D4XBqdttgMICmaTBNEy6Xi88sY8mqUCgEj8fDU1aGYaDT6UBVVXQ6HTSbTbTbbYxGI5imCeBEJOm6juFwCMdx4PP54PP54Ha7EQ6H4fP5uNRjyTYm6hRF4e2klmXxxzop8Zgc9Hq9GI1GqNfrUFWVp/LYfWWzWeRyOUSjUf56TEou9pht28Z4PMZ4PIYkSXC5XFNyjglHJt1exdeJtZKydNtVW0lJthHEt5BgIwiCIAiCuIFQ7UYQNwvLsrC/v4/NzU00Go0zx+PxOAqFAhYWFq4kNAzDQKlUwhdffIFnz57xuWhsTpvH48GtW7dw+/btMzPFXiXdbpfLtnq9DsMw+FZSlu4aj8cQRRGCIMDn8/FWUibHvF4vZFmGaZpQVZUn5JrNJprNJjRNg67rEAQBpmnyuWos2eb3+3nrKEuLpdNpJBIJeDyeM7JNkiTous5bRhnD4RCWZUGSJCiKgk6ng3q9Dsdx+Lw2WZbh9/uRy+UwNzc31abKEnm9Xg+GYcBxHH4+SZIgyzIEQThzvz6fb0q4TbalviiapvG5bZNC8bJWUlVVUa/X0Wg0SLYR7yQk2AiCIAiCIG4gVLsRxM2l1Wpha2sLOzs7U6IFAGRZxvLyMgqFwtRmy4uwbRv1eh1ffPEFnjx5AsMwpua0iaKIlZUVrK+vIxKJvKZndIKu6zg+PkalUsHR0RFvHWXCbTgcYjwew7ZtiKIIl8vF55ax2W2yLMPlcsGyLGiaBuBEfDHZNhwO+ew3lmpjSweYbHO73fB6vfD7/QiHw8jlclOyjc13EwQBhmFA0zSoqspnpdm2jcFgwNNnkiSh0Wig2WzyeW3hcBiSJCEWiyGXyyGTyfBWTNM0eapN0zTYtg3LsiAIAmRZRjQahd/vn2q1Ze2dwMk8OdZOGgqFrtxGfBEsZcfSbex77rJWUtZWS7KNeJcgwUYQBEEQBHEDodqNIG4+hmFgZ2cHxWIRnU7nzPF0Oo1CoYB8Pv/czZSO46Db7eKbb77B119/jU6nc2ZOWy6Xw8bGBmZnZ1/H0znzeNrtNk+3NZtNmKY5NbuNtVG6XC4IgsBbRydnt7GWy8n5aWxuW6/Xg2maXJT1+32eDvP7/fD7/fB4PHC73fD5fIhEIpifn8fMzAzcbjcsy4JlWVAUBbZtYzgc8uUNbOMrS8yxVJuu66jX63xeWywWQyAQgCRJSKVSmJubQzqd5q852wDKHhdLtU3OqotEItA0Df1+n782k+3EbKsqk24spfiiX5eLWkkDgQBPt022kl4m25LJJGKxGMk24kZAgo0gCIIgCOIGQrUbQbxb1Go1FItF7O/vw7btqWOKomBlZQWrq6tnZmidh6qqePr0Kb766itUKhXegsgkSDwex8bGBhYWFp4r7l4V4/GYJ9sqlQrG4zFGoxFvJR0Oh9B1nW8TZUsKmGwLh8OwLAvj8RiWZcG2bbjdbr7Vs9PpcNmm6zq63S76/T50XYff70cgEIDX64UkSfB6vYhGo1hYWEA2m52SbR6Ph3/+cDiEy+XiLbbsvtk5ut0un13G5rUpigJZlpHJZJDL5RCLxXhrK1tGwJ4ne7wulwvRaBTxeJyLPU3TuGzr9/sYj8f8tfR4PFy4sZl2L8rzWklZWo99n1wk26LRKBKJBMk24q2GBBtBEARBEMQNhGo3gng30TQN29vbKBaLGAwGZ45ns1msra0hk8k8d1uoruvY39/HZ599ht3dXbjd7qkElN/vx/r6OgqFwiuZ+3VVHMdBo9Hgwo21Lfb7fS7GxuMxDMPgCwJOt5PKsswTbKw9VNM0NJtNtFot3kZqGAba7Ta63S50XedtpD6fD5IkwePxIB6PY2Fhgc+/sywLpmnC4/Gg3+/zBQZ+v59vK9U0jT8ul8uFZrOJRqPB57VFo1G4XC74fD7kcjnkcjkEAgGeIGOpNjYfjrWQTm5gnZSfpze3Tgo3t9s9JdyYpLsul7WSRiIRnm5jraQk24ibBgk2giAIgiCIGwjVbgTxbuM4DiqVCorFIkqlEk7/2uf3+1EoFLCysvJcoWLbNo6Pj/HZZ59hc3MTtm1PzWmTZRlra2u4c+cO/H7/63xa5zIajXgr6fHxMV92MLmZlAkzSZIgiiI8Hg9PtgUCAQwGA76kwO/385bZVqsFTdP4YoN2u41OpwNN07hs8/v9XOQlEgksLCxgeXkZsizDtm2YpglJknhr6uTrZ9s2dF2Hy+XiSbh6vY5mswlFURCLxRAKhbikyuVyyGaz8Hg8XP612+2ppQgAzk21TaLr+pRwY/PqgJM5fpPC7UWWXDiOw+Viq9XiAg34tpU0Fovx7xeSbcRNgAQbQRAEQRDEDYRqN4IgGKqqYmtrC1tbW1OiAzhJF+XzeRQKBaTT6UvP4zgOOp0OvvzyS3z99ddQVZVv3hRFEaIoYnFxERsbG4jFYq/zKV0IW9rAhBtLqZ2eT2bbNpdigiAgGAwiHA4jGAzCNE2eYvP5fHC5XBiNRmi32xgMBnC5XFy2tVotjMdjeDwe3krqdrth2zYSiQTm5+dx69atqTZS27ZRq9XQbDZ5ss7n8/EZa6IoQlEULp0m57X5/X4IgoBUKoVcLoeZmRmIooh+v49Wq4XhcAjHcaZk20WptknY9lZ2m/w+YRtWmXCb3Hx6Va7TSkqyjXhbIcFGEARBEARxA6HajSCI09i2jVKphGKxiEqlcuZ4OBxGoVDA0tLSczdPqqqKb775Bl9++SUajQZfiMDEx+zsLO7evYtsNvtanstVGQwGvJW0Wq3CsiyoqsrbSYfDIUzThCzLXD6xlslgMAhBELicUxQFiqLAMAwMBgN0u124XC6eJGs0GhiPx5BlmSfbFEWBZVmIx+Nctvl8Pt5GqqoqarUa2u02n3Xn9/t5qk0URbjdbn7+0Wg0tbXT5XJhdnYWc3NziMfjU6k21qLJNp8+L9U2CVv6wITb5DIDJgWZdPP5fM9tN57Esiz+GJ/XSkqyjXibIMFGEARBEARxA6HajSCIy+j1etja2sKzZ8+mNk4CgCRJWFxcRKFQQDwev/Q8uq5ja2sLv/71r1Eul6EoCoLBIBd0kUgEGxsbWFpagiRJr+35XAXLslCtVnm6jbWEMpHE5qwJggC32w1BEOA4Dnw+H0KhECRJwng8RrvdhiAI8Pv9sG0b4/EYnU4HoijCNE2+vIC1ljLZ5vV6YVkWIpEIFhYWsLa2NrV8odPpoF6vo9/vIxgMTrWRMgEoCAKazSbq9ToAIBaLIRKJ8C2lbF5bMBhEr9dDq9XicoyJNsdxrpRqm4RtcGWv1XA45MdcLheXbUwQXlW4Pa+VlMlEv99Pso144yHBRhAEQRAEcQOh2o0giKtgWRb29/exubmJRqNx5ng8HkehUOAD/C87z+HhIT777DMUi0VIkjQ1p83r9eLOnTtYW1t7qa2Vr5Jer8dlW61Wg+M4GI1GU7Pb2LbRyeceCoV4yydrp2RSyTRNDAYD3qLZ7XbRaDSgqioEQZia28ZmseXzeaytrSEej/MlBs1mE7VaDbquc9mmKAokSYLL5eIz4RqNxtS8Npa6C4VCmJubQzabhSAIaLVa6Ha7vG2UbSGVJAmxWOxKqbZJJttu+/0+b00FwL/2rKX0OsLtslZSJtvC4TBGoxHJNuKNgwQbQRAEQRDEDYRqN4Igrkur1cLW1hZ2dnZ42x5DlmUsLy+jUCggHA5feA624fOzzz7Dw4cPYZomAoEAn9PmcrmwurqK9fV1BIPB1/2UroxhGFPpttFoBMuyeCso27wpyzJPk1mWBY/Hw5NpmqZhMBjA7XZDkiSeTGOvZbfbRbPZ5DLK6/Vy2QYAPp8Pc3NzKBQKSCaTEEURrVaLLz1gyTO2yVVRFIiiCFmW0e/3Ua/X0ev1EAwGEYvF+HKCZDKJXC6HdDoNVVV5Uowl9Njst+um2iaZTAL2+30uGYGT1s/Twu2qqblOp3NpK2ksFuOikaUGAZJtxPcDCTaCIAiCIIgbCNVuBEG8KIZhYGdnB8ViEZ1O58zxdDqNQqGAfD5/qSjp9/u4f/8+vvzyS/R6vak5bYIgYH5+HhsbG0gmk6/x2bwY7XYbR0dHqFQqaDQacBwHmqZx2TYYDAAAiqJweWNZFhRFgW3bMAyDJ6tcLhcsy4JhGDBNE4IgoNfrod1u8zZVtiSBySev14tMJoPV1VUkk0nIsozj42Mu0URR5LLN6/XypKAgCFPz4KLRKCKRCGRZhiRJmJ2d5S2k7XYb3W6XizDHcWAYxgun2iaxbXtqhttp4RYIBLhwCwQCzxVurJW01Wqh3W5f2EoK4FLZFo/Hv/dWZeLmQoKNIAiCIAjiBkK1G0EQr4JarYZisYj9/X3Ytj11TFEUrKysYHV1FYFA4MJz6LqOx48f47PPPkOlUuHD/NmctnQ6jY2NDczNzV1rWP53ha7rfFFCpVLBeDw+I5A0TYOiKHC73XwWGwDeNqrrOt/uCYDLNkmSMBgM0Ol0MBwOYRgGXC4XAoEAfD4fZFmGz+dDIpHAysoKUqkUl23VahXj8RiiKHI5Fw6HuRSzLAvNZhPNZhOCICAWiyEUCkEURXg8HmSzWWQyGQAnQlHTNIiiyFNttm2/VKptEtu2MRgMzrTfstdoUrgFg8Hn3pemaVy2TbaSejweviTB5XKh2Wyi0WiQbCO+E0iwEQRBEARB3ECodiMI4lWiaRq2t7dRLBZ5emuSbDaLtbU1ZDKZCyWZZVnY3d3FJ598gr29PciyzBNYwMlss/X1dayurr6x4sNxHLRaLd5K2mq1AJy8Pkwe9ft9iKIIn88HSZIgCAJGoxGfy8ZSbJZlwbIs2LYNXdf5EoVOp4PRaMSFF5NnLOUWDoextLSEZDIJSZJQrVbRaDRg2zZfvsDaRGVZhiAI0DQNjUYD7XYbHo8HsViMS9FgMIhcLod4PI7RaIRutwsAvA32VaXaJrFtG8PhcOo1Oy3cJhcnXPb9wFpJW60WOp0Ol5uSJCEcDvPXgc3DY7KNtZqSbCNeFSTYCIIgCIIgbiBUuxEE8TpwHAeVSgXFYhGlUgmnf530+/0oFApYWVm5UMQ4joPj42P8+te/xuPHj+E4DoLBIHw+H0RRhKIouHXrFm7fvv1KZM7rRNM0nmyrVCowDGMqrdXtdqFpGnw+H5+Zpus6n8Om6zpvmR2NRly2sflo7PNZSyRbksC2tQYCAczPzyORSPD5d0yQsaUKk5IJOGndbTQafFtpLBbjiyfi8ThmZ2fh8/kwGAwwHo8hSRJvH7VtG36/H4lE4qVTbZM4jnNGuLGFDJPSkAm3i2aqXdZKGgwGEY1GIcsyX5JwWrYlk0nEYjGSbcQLQYKNIAiCIAjiBkK1G0EQrxtVVbG1tYWtra0pkQGcCIt8Po9CoYB0On3hObrdLj7//HPcv38fqqoiGAzC7/fzTZkrKytYX1+/dLHCm4Jt22g0GjzdxkTXeDzmsq3f7/Mtm0xctdtt6LoO0zThcrngcrkwGo2g6zoMw4BhGHC73RgMBjAMg89tUxQFgUAAiqIgEonA5/Mhm80iGAzCcRy+4ZTh8/kQj8cRjUYhSRJs2+aprvF4zBcHSJIEURQxMzPDFwQMBgMIgsBTbSxx9ypTbZM4jgNVVaeE2+TiDb/fz2VbKBS6ULiNRiO+JOG8VlK32w1d19FqtaZk2+SCBJJtxFUhwUYQBEEQBHEDodqNIIjvCtu2USqVUCwWUalUzhwPh8MoFApYWlric9dOMx6Pcf/+fXz++edoNpvw+/0IBAL84/P5PNbX1zEzM/Nan8urRFVVLtuOj495O+hkum08HiMYDPJ0G1t+YFkWBEGA1+uFYRgYDAbQdR2apsHtdsMwDFiWBVVVoes63G43/H4/fD4fYrEYl2lerxe2bUPTNJ4IA05kWzKZRCQS4dJscl5bNBpFKBSCIAhwu91IpVIIhUI8dSfLMk/bva5U2ySTwo3NvpsUbj6fb2qGG0vrTfK8VlKPxwPTNHlqECDZRlwPEmwEQRAEQRA3EKrdCIL4Puj1etja2sKzZ8+g6/rUMUmSsLi4iEKhgHg8fu7nW5aFp0+f4te//jUODg54KySb05ZIJLCxsYH5+fnXInJeF5ZloV6vc+HW7/cBnCxQYJtJ+/0+XC4XQqEQZFnGeDxGpVLhssfn80EQBAwGAwwGA4xGI95eKggCVFWFpmmQJInPbYvH4/zPLpfrTEsvcDL7ji0yEAQB4/EYjUYDnU6Hz2vz+XwATpJj0WgUfr8flmVBFEUu6Fg76etKtZ3mdMLNMAx+zOv1cuHGXs9JWMKv3W5PpdcA8AUcbE4cyTbiqpBgIwiCIAiCuIFQ7UYQxPeJZVnY39/H5uYmGo3GmePxeByFQgELCwvntvc5joPDw0N88skn2NragiAIU3PaAoEAX4hwXlrpTaff73PZVqvVYNv2mU2bmqYhHA7D7/fDNE2+EdOyLPh8Prjdbr6UoNvtQpIkuFwuuN1uaJoGVVUBnIi5QCCAZDKJQCDAxZDjOFAUhf9dEAREIhHEYjH4/X5IkoThcIh6vY7hcIhAIDA1y421Z3o8HoiiOJVqsyzrtafaTjMajfhr1+v1poSboihTwu10kpK1krZaLfT7/alWUrfbDcdxoGkal8Yk24jzIMFGEARBEARxA6HajSCIN4VWq4WtrS3s7OxMtfUBgCzLWF5eRqFQuHDOWqvVwieffIJvvvkGhmEgEAhwUeR2u/lCBJayetswTRPVapULNybGdF2fEkayLPOWzl6vh0qlAlVVoSgKfD4fDMNAp9PhQtPj8cDr9cKyLAyHQ5imCa/Xi2AwiGQyiWAwyIUYk5Zs2QETSJFIBIqiTG3hNAwD4XAY4XCYb0kNBALw+/3wer1wu90QRRGmaX7nqbZJJje79nq9qUSlx+OZEm7sebOvB2slZe26wEkCk0lE0zT59zLJNoJBgo0gCIIgCOIGQrUbQRBvGoZhYGdnB8ViEZ1O58zxdDqNQqGAfD5/buJJVVV88cUX+OKLL9Dr9fhmSVmWIYoilpaWsLGxgWg0+h08m9dHt9vlsq1er8NxHN6uyGTReDxGOByG1+vlKbNGowG3281FY7fbRa1Wg67rXMKJoghVVTEej/mShHQ6jWAwyLebMvnE5sJ5PB4Eg0GEw2G43W7IsoxWq4VWq8W3bwYCAd6qGggEuMjzeDw8/fV9pNomYcsmJl9DBnuOTLgxEXhZKylbFDE5245k27sNCTaCIAiCIIgbCNVuBEG8ydRqNRSLRezv78O27aljiqJgZWUFq6urCAQCZz7XNE08fPgQn376KY6Pj8/Mactms9jY2EAmk/lOnsvrRNd1HB8fo1Kp4OjoiAueyXRbv9+H2+1GOByGYRhot9s4Pj6GaZrw+XxwuVxQVRXVahW9Xo+/Xi6XC7quYzQaQZZlBINBpFIp+P1+jMdj9Pt9LjHZDDefz8eXKXi9XoiiiEajgV6vxzdzKooCx3HgOA4CgQBCoRBP3pmmyefEfR+pttOv7aRwm5Rnbrd7aksp+96abCXt9Xr845lkM02Tb2El2fbuQYKNIAiCIAjiBkK1G0EQbwOapmF7exvFYhGDweDM8Ww2i7W1NWQyGQiCMHXMcRzs7Ozg448/xs7ODlwuF5dBbBPm3bt3sbi4+FYtRLgIx3HQbrd5uq3ZbPL3D4dDvixhPB4jEolAkiT0+31Uq1UMBgO43W4+n61er6Pb7UIURZ5MsyyLS6ZQKIR0Og2v1wtVVTEYDOD3+xEKhbhgm2wJ9fv9cBwHjUYDqqryZQgulwuGYfAW1FQqhUAgMLXZlC1jYBLu+4IJN7aldDQa8WOyLE+1lHq9XpimiXa7zW9Mso3HY55sc7vdcLlcJNveEUiwEQRBEARB3ECodiMI4m3CcRxUKhUUi0WUSqUz2y79fj8KhQJWVlbOTTxVq1V8/PHHePToEU9OsTltPp8Pd+7cwdra2pnh9m8zbMvo0dERKpUKnzFmGAZ6vR663S76/T4URYGiKBiPx7y107ZtKIoCXdf5+4CTxQUsgWYYBizLQjgcRiwWg8fjgaZp0DQNXq+Xv8YshcZkUiAQ4EsZ2Ly2UCjEN516PB6Ew2Fks1m43W4YhoHRaPRGpNomYa8jk25sNh4Avu11Urid10qqqipM04RlWfB4PLztlmTbzYQEG0EQBEEQxA2EajeCIN5WVFXF1tYWtra2plJEwMmMq3w+j0KhgHQ6feZz+/0+Pv30U3z11VcYjUZTc9pkWUahUMCdO3fObT19m2HpMSbc2u02f//k7DZd13m7Y6/XQ6vV4osI2HD/drsNx3Hg9Xr5PDfLsqaWG7jdbui6DtM0eXtpIBDAzMwMFEWBIAg82TYej9HpdPiWUr/fD13XuaibmZnB7OwsHMeBqqpvVKptEtM0p4TbcDjkx1h6kgk3QRB4so21kqqqitFoBNu24Xa7uQAm2XZzIMFGEARBEARxA6HajSCItx3btlEqlVAsFlGpVM4cD4fDKBQKWFpaOpNM03Ud9+/fx6effopWqzU1p00QBCwsLGBjYwOJROK7ejrfKaPRiLeSsnlswHQqi20mFUUR4/EY3YnoQrAAAQAASURBVG4XqqrCtm0+g63T6cC2bXg8Hr4kwbIsmKbJE2ysvRQAX3IQiUQwOzsLr9cLwzDg9Xrh9XoxGo0wHA4hyzJvIVVVFYZhIBKJIJ/PIxKJ8Nlwb1qqbRLTNHk7aa/XmxJukiRx4ebz+aDrOpeXlmVBVVX0+32ebAsGg3C73STb3nJIsBEEQRAEQdxAqHYjCOIm0ev1sLW1hWfPnvFWSIYkSVhcXEShUEA8Hp865jgOnj59io8++gilUomnrdictpmZGWxsbCCXy52Z8XZTsG0b9XqdCzeWqGKJMdZOqmkaRFHkqTdN0zAejzEcDjEYDNDv92Ga5lQ7qG3bfJkCW6jABvwDgNfrRSKR4LPXNE2D3++HLMu83dTn8yESicA0TQyHQ/51WVhYgCzLGA6Hb2yqbRLLss4IN6ZbRFFEMBhEMBgEAC7cNE2Dqqrodrt822swGITP50M0GkUymUQ0GiXZ9pZAgo0gCIIgCOIGQrUbQRA3EcuysL+/j83NTTQajTPH4/E4CoUCFhYW4HK5po4dHh7i448/xubmJk9asTa9cDiMjY0NLC8v33iZMRgMeCtptVqd2oDJ5FCn04FlWRAEAYZhwDAMDIdD9Pt9dLtdDAaDqbligiDAcRxYlgVZlvnWUbfbzV9Pt9uN2dlZxGIx+Hw+jMdjBINBiKLIFx6wJQqTixUWFhYwMzMD27ahquobnWqbxLIsDAYD/poOBoMp4RYIBCDLMizLgq7rGAwGUFUVnU4Ho9EIHo+HbzKNx+NIJBIk295wSLARBEEQBEHcQKh2IwjiptNqtbC1tYWdnR3eAsmQZRnLy8soFAoIh8NnPu+TTz7BgwcPYBjG1Jw2RVFw+/Zt3Lp1642WN68Ky7JQrVZ5uo21OTqOg9FohG63y2+CIHDpNhqN0G630Ww2uTjyeDzweDwQBIFv0QTARRJLuAmCAFmWkU6nEY1G4fF44DgO30TKto6ypNpgMICmaUilUpifn0csFntrUm2T2LZ9RrjZtg3gRLix18E0TS40O50O+v0+byMNh8NIpVIk295QSLARBEEQBEHcQKh2IwjiXcEwDOzs7KBYLKLT6Zw5nk6nUSgUkM/npyTMaDTC559/js8//xy9Xg9er5dv0ZQkCaurq1hfX0coFPoOn833S6/X47KtXq9zAcTmjXU6HdRqNZimCdu2eQKrXq+j0Wig2+1y2aYoClwuF3Rdh2VZfHOp3+/nbaKiKMLtdiMWiyESiUCSJP65TObJsoxQKITxeIzBYMAXXWQyGSiK8lal2iZhwo21lfb7ff56M9i8O03TeHqQbTANh8PIZDIk294gSLARBEEQBEHcQKh2IwjiXaRWq6FYLGJ/f/+MrFAUBSsrKygUCvD7/fz9tm3j66+/xieffIJqtXpmTtv8/Dw2NjaQSqW+66fzvWIYxlS6jW10nUy3VatVLtVcLhdkWUaz2USlUuFtpm63Gz6fj28eZe2gbNGB3++Hx+OBJEnwer0IBALwer18C6nH44FhGLz11O12Q1VVqKqKYDCI+fl5pNNpmKb51qXaJnEchyfc+v0+X4IAAOPxmD+/4XAIwzDQ7/cBAMFgEJFIBHNzczSz7XuGBBtBEARBEMQNhGo3giDeZTRNw/b2NorFIgaDwZnj2WwWa2tryGQyU8sNtre38dFHH2FnZweSJE3NaUsmk9jY2MD8/PyNXYhwGe12G0dHR6hUKmg0GnyeGEu31et1VCoVGIbBZVmv18Px8TGazSbG4zFkWUYgEODSbDgcwjRNiKKIaDSKcDg8JdfYHDc2j4wtVmAtkwC4iJqZmUEmk0E4HOaP4W1LtU3Clk2wllL2PNnrzV6/8XgMXdfhOA5vI52fn8fMzAzJtu8YEmwEQRAEQbwyen/2Zxj81V9d6WMDP/85Qr/61Wt+RO8uVLsRBEGcSIpKpYJisYhSqYTTv/76/X4UCgWsrKxMSZharYaPPvoIDx8+hOM48Pl8fE5bMBjE+vo6VldXzyxSeFfQdZ0vSqhUKhiPx/yYqqpoNBool8vo9XoQBAE+nw+6ruP4+Bi1Wg2apkGSJL59lG3gHI/HEEUR4XAYkUgEgUAAoVAIiqLwRKLf70c4HIYkSRBFEYqiIBAIQNd1jEYjyLKMbDaLdDoNt9vN57u9jam2SSa3vjLhxiQbe5+qqlwusjbS+fl55HK5S2Ub1W+vBhJsBEEQBEG8Mo7+6I/Q+5/+DKLbfenH2bqO0H/+K2T++I+/o0f27kG1G0EQxDSqqmJrawtbW1u83ZHB5noVCgWk02n+/sFggE8//RRffvklVFWdmtPmdrtx+/Zt3L59G16v97t+Om8MjuOg1WrxVtJWq8WPWZaFVquFcrmMarUK0zTh9Xr55xwfH2M0GsFxHN4eapomBoMBl3bBYBChUIin0WRZhq7rME0ToVCIJ9smW0hZG2owGMTs7Cyi0ShkWX7rU22nmRRubHFCr9fjrbumafKZbbFYDPPz81hYWEAsFpuSbVS/vRreTd1OEARBEMRrQ3S7IWezl36MUS5/R4+GIAiCIE7w+Xx4//33cffuXZRKJRSLRVQqFQAnc9j29vawt7eHcDiMQqGApaUlBAIB/PKXv8QvfvELfPHFF/j0009Rq9XgdrsRDAZx//59fPPNN1heXsbGxgYikcj3+yS/BwRBQDweRzwex927d6FpGk+2VSoVJJNJJJNJACfCslQqoVwuw+PxoFAo8GRWtVpFq9WCbdtwuVyIx+O8TfLo6AjlchmKoiAUCiGZTCKTycC2bZTLZYzHY4TDYfj9fiiKwpcp6LqOYrEIl8uFaDSKVCoFVVVRr9dvRKqNJQBnZmYAnCzuYLKt1Wqh2WyiVqvh+PgY5XIZT548gc/nQzqdxtLSEpaWlvjXhuq3l4cEG0EQBEEQBEEQBPHOwNJq+XwevV4PW1tbePbsGXRdBwB0u1189tln+PLLL7G4uIhCoYB4PI6f/OQn+PGPf4ynT5/i448/xuHhITqdDoLBIDY3N7G1tYVcLoeNjQ3Mzs5+z8/y+0NRFC5vbNtGo9Hg6TYAuHXrFm7dugXDMLg403UduVwOjuPAMAx0Oh30ej0YhgFBEBAMBiEIAjRNQ7PZRL1ex9OnTxEMBjEzM4N8Pg9JklCv16FpGsLhMBRFgc/ng9frhd/vR6vVQr1eh9frRSQSQTQaRafTgaIoNybV5vV64fV6eQpT0zT0ej10Oh1UKhUcHh6iUqng8ePHePLkCX/uPykfwW/bcAF496YLvjqoRZQgCIIgiAtxHAe2bcOyLNi2DdM0oaoqRqMRBoPB1JXSbreL7J/8CRI7uxh4vbBsG/ZvPs+yLTi2g3w+D0VRYJTLCPzy96jF4DVCtRtBEMTVsSwL+/v72NzcRKPROHM8Ho+jUChgYWGBz10rlUr46KOPsLm5CeBkNlggEIAsy4jH49jY2MDCwsJbm456HaiqymXb8fEx35LpOA7q9TpKpRKf6eY4DkzT5DWHpmkAwGevmaYJ4ORrx7aHZrNZ5HI5uFwudLtdaJqGYDDIE4eKoiAYDEIURciyDI/Hg1gsBr/fj2g0imAw+Nan2i6DCbfJbbv1eh0/+fxz5I+r0IJBeDweeL1eeBQPhFO6jeq3yyHBRhAEQRDfI6cF1nlvX/UxXdehaRq/jcdjDAYDqKrK345GI6iqivF4zLdTjcdjGIbBz22aJl8Zb9s2bNvGP3XJ+C1JQvU3g4gd9l/nJDGwvLwMv99PBdp3ANVuBEEQL0ar1cLW1hZ2dna4xGHIsozl5WUUCgWEw2EAJ9s1P/74Yzx48AC6rvOFCB6PB36/H3fu3MHa2hpkWf4+ns4bi2VZqNfrXLj1+31+rN/v4/DwEOVyGd1ul9cfpmliPB5jOBxyOQec1BiiKMKyLBiGAY/Hg0wmw2Ubm+mmKArfZMoWKLjdbni9XrhcLoTDYYTDYZ5o+75TbY7j8FrxsrdX+ZjJt7ZtwzAMqKqKVqsF8V/9XxDZKqIty3Ds3ygiAfB6fZj9TfspQILteVCLKEEQBHHjeV2i6lUcYxuxrgMrjnRdh2EY/KZpGkaj0Rl5NinJ2Oew+7Ysixdm7Db52E7fThdzpzEDAUCSzj0GYKoYJgiCIIg3kVgshp/85Ce4d+8ednd3sbm5iU6nAwAwDANPnz7F06dPkU6nUSgUkM/n8fu///v4vd/7PXz22Wf47LPPUK1WeWpqMBjg/v37WFtbw507d+D3+7/fJ/iGIEkSZmZmMDMzg3v37qHf73PZJooigsEg7ty5A1VV+QyxarUKSZLg8Xh4TTIcDjEej2GaJmzb5ssMjo6OsLe3B1mWkUqlkMlkIIoiv7DIpNxkeq3f76NWq8HlcsHn8yEWiyGRSPDjgiC8tNy6zudexmQt97w/n/f3SRZcEiTJhWAgCMM0YRgGzN/Ui8TVIcFGEMQLoxkW7h92MNBMBBQXPpiLQJHPX/1M3GzeNGl1+u2bCGt7mJRkp5NlTJCxt5MpMlZEPk/UTRZvp4us0++/SJoxBEGY+jP7O/uzKIpwSS4IOCma2fv5cenmtVoQBEG8TVDtdj1kWUahUEChUJhqqWM/c6vVKqrVKhRFwcrKCgqFAn7+85/jt3/7t/HgwQN88sknqNVqfE7b119/jcePH2NxcRHr6+uIx+Pf8zP8briOdEomk0gkEtB1HbVaDUdHR9A0DZFIBOFwGMvLy3z7aLVaRb/fh+M48Hg8AIDxeAxVVWGaJhzH4e28h4eH2N7ehiRJCAQCCIfDcLlcXJhZlgVRFOHxeBCNRhEOhxEMBgEALpcLfr8fiUQC8XgcoVCI39+LMtl+OnmRk31vTb69SMSx80zWWm63m7e/SpIEWZbhcrnOvbFjsiyj8+gx1P0DBBOJ6a+d82bW0W8qJNgIgrg2mmHh3356gD99cIRaT4NlO5BEAamQgj94P4M//K08PC4q1l4lp688vWkS612FxeuZHLtMkjFRdrrVcjIx5jjOlLRib0/LLOD8loHTsmwynTb5uZPnZYXZ5PsZoijywsvtdsPtdsPn8/GWCtZCEYvFEIlEEIlEkP/TP4X3m4cIp9P8yrAgihAnngNBEATx3UK128uTSqWQSqXw4YcfYnt7G8ViEYPBAMDJXKuHDx/i4cOHyGazWFtbwwcffIAf/OAHePbsGT755BPs7Oyg1+vxzZbb29uYnZ3FxsbGS7fxv2g66mU/56of8zKwls3BYIBGo4FGowGPx4N8Po9MJoN+v8+XHnQ6Hdi2zVtzTdPkM8cEQeBiSdd1npLz+Xzw+/3weDyQpJN/A7VaDfV6HYqiYGZmBrOzs5BlGe12G6qqIpFIYG5uDqlUCuFw+NxU23lvJ//MauqLkCTphW8vUm/1XdK5yw1EgS6OXgcSbNeArvgQxMm/g3/xJw/xH4t12I6DkCJDlkQYlo1SW8W//sttPKn08Md/b+OtKtQmJcebIq2u8sOXuB5sKC57y64esjliTHoxCXaeKBuNRlyisVbL05H7y+SYIAhQFAVer/eMNGOfPznb7PSss0kZ97y0GbtqyyQaE2Yej4cXlH6/H+FwGJFIhM8bYb9EpFIphEIheL1eKIoCj8fz3IG/R198gcGTp5Df8i1cxM2B6jfiXeem1m7fJZOiRBRFrKysYHFxEZVKBcViEaVSiR/f3NzE06dP4fP5sLi4iIWFBfzyl7/E3bt38dlnn2FzcxOHh4dQFAWlUgkPHjxAIBDA6uoqcrkcBEG4ttz6Ppi8SDf5Z1mWp/5+1bfXOWYYBur1Ok+xjcdjWJaFXq+HVquF/f191Go19Ho9BAIBuFwuWJYFTdN4uo3VNZIkYTQaYTgcwuPxwO12w+PxQJZlOI6DnZ0dPHv2DIqiIJPJYHZ2FsPhEJubm1AUBclkErlcDtFodCrVJgjCGfHFUmVXuRFvJyTYrgBd8SGIb/m3nx7gPxbr8MoSQt5vFzm7XSL8Hhd6IwP/sVjHf/fJAf7x7yzyzzs91+m0VLrs2Hchsb6v4uSmweTVpMCa/Pt5b1/2mCAIvGWSJcMMw8BoNEK/38dwOJwa4D/59+FwyD+XiazrSs1JkXWayZlmk+efvB/29+dJs8nCkhV+7Ma2PbFkGZNliUQCyWQS8Xict1y8iraG52HrOoxy+bkfQxCvE6rfCOKEydot4JFgmiYkQYDbI11au32XvMrh7a/jcy4jHA7D4/Hg8PAQpVIJ4/GYH9vZ2YEgCEin08jn87hz5w4WFhbw6NEjFItF1Go1ng4vl8tQFAULCwvI5/PweDznyqbTAutFZNbLyC329vsmFothbW0NjuOg1WqhVCqhVCohFAphbm4O7XYb7XYbBwcHqFar6HQ6sKyTTaOGYWA8HqPb7cK2bb74wDAMnkabTPF7vV6MRiM8ffoUT548QTAYxNLSEubm5jAcDvH48WMkk0ksLy9jcXER8Xj8wtrwTYbqt5eHtog+h9NXfHwuQNc0+AIB9DQToiDgdwtJuuJDvBNohoV/8N98glJbxWzYi729Xaiq+psfvt/+EDbdAbiNAVYP/wPckjD1A8rlcsHtdk/JAbfbzd/vcrkulCjEtxuSXre0etFjLwqL8LMB/eztYDBAt9vFYDBAr9fDYDDgt+FwCFVVoarqVMprcqvl6+KypNmkzGXzPybnZEwWyKf/HUzeFEWB3+9HKBTiLZjRaHRqLkgwGEQoFILP54Pb7X4j/p30/uzPMPirv7rSxwZ+/nOEfvWr1/yI3l3e1doNOFu/QR8hFPQDgoSeZlD9RrwznK7dNG2EWq0Ox3GmEjUDy4WkX8T/8W8mkIpFEAwGeYLnu5Bb3wcvKpYu+xjHcXB8fIzd3V3U6/Wp+xEEAeFwGIVCAYuLi5AkCffv38fnn3+OTqcDl8uFUCiEQCAARVGwurqK9fV1PgfspjN5Qfy6t8maT9d11Ot1frMsC6PRCL1eD9VqlbeSmqY5tXWUbYtlXQbsdxPW3SBJEq/PJj8+FothcXER2WwWbrcbkiQhk8mgUChgbm7ue91Aeh2ofns1kGB7Dv/mr3fxr/9yG15ZglcGjivHsG0bPr8PiXgCfc3EyLDwT36x/L1d8SGI74pPdpr4Z//9fSjyyRXPZ9vPMFJHZz/Q5QYkN1yf/1u42ntcKky25F0EK1BO35h4Y4Xg6cGc7Db5d/ZDblICXUUcvWnSavJxvQkC5Twcx+Ftk6clWb/fR6/X429ZgmxSkrHNT6fbIV+nJDvvOZxOlZ2XNpucZ8akMPt+Oy3KPB7P1I0VZmw1vN/v52+ZMGODd30+35VaMgniIt7V2g2YrN9EGGoPg8EQkiQilUrD7XajNzKofiPeCU7XboPBAMfVY9g8qS1AkkQ4kgeO6MLG8EvEzCYAwOPxcOHDfj6FQiEEg0H4fL4r1SSvI1H1KpJbk+MbXhe9Xg9bW1t49uwZ9FOpH5fLhYWFBRQKBUSjUWxubuKjjz5CuVyGIAi8LpBlGfPz89jY2EAymXytj/dV8KKC7LJukvNaLa96E0URjUYDlUoFR0dHaLfbGAwGaDabfFtps9mEpmlwHId3QbBxH6yNVFEU3jGh6zpEUYSiKJAkiX+uJElIpVKYm5tDOp3m/34WFxdx+/ZtJBIJquneAd6+3OJ3iGZY+NMHR7AdB363gOPjKlRVBYTfDJK0bSSSSQx1E3/64Aj/65/kaaYHcaMZaCYs24H8m02Ajn2Bn7ctwCXCFGVYEzF5xnnSa3IoJ/vhw4oftm2RXWE6b6bVeX9nrXtMyJ2WdOfJuskfpG63m8+eOk/qnff2Kh/zpv5wtW2bzxq7TJKx26QgGw6HUwmySTn1fbbfTg6SnRRl7Ar7aaHLrk76/f4pScZaBxRF4V/nSeHJ2jTZTDMmztgvJewXkkAgMDVElyCIV89k/WZpQ/T7AxiGAVmWcVw9Ppkt6FWofiPeCU7XboZxkuB02OByAbAdB7alw4GEkenwlkVRFGHbNv/5P1lvuVwuhMNhRKNRnrKOxWL870xIvMuEQiH88Ic/xAcffID9/X1sbm6i0WgAOEnuP3v2DM+ePUM8HkehUMA/+kf/CJVKBR999BGKxSL6/T58Ph80TcPe3h7S6TQ2NjYwNzf32l7b0xcbr3u7iPPqflYPXUWSvQzJZBLJZBLvvfceRqMRF2vHx8d8ZlupVMLR0RFarRafsavrOobDIV+S4PV64Xa7EYlEYFkWxuMxBoMBl20AcHBwgIODA7jdbqTTaWQyGezv7+Orr77C7OwsCoUC1tbW4PP5Xuo5EW8uJNgu4f5hB7WehqBHQrVaw3is8asP+pglNTT4wzGUm338fz57ip+uJOHz+eD1eukXKOLGEVBckEQBhmXD7RKRSCZg6MZv2uUsWNZJ25wBEZYjIjeXgcs/PpMAuqwdgP0AnhRj7O3LJLgmr0ppmvZC57joChp7bOyxPu9jmKhhiaeLBN1VZN3pt+zqGpstwQQZKw5OSzI2l+x0q+XpJNmbNKNuMkHGrhiy1/d0SzGTrJNpMva6P69oY593WpoxUcbEmd/v560EBEF8/7D6TbJ09Po9mKbBf/7Isgu1ag3JZBIhRUatp+HBYQc/WYp/3w+bIF4Lp2s3n8/Htyjq+hhjXT9Jswku2HBgj4eod+tcoLFE9Xn113A4xHA4RKlUOnPM5/Px0Qan34ZCoXdKvkmShKWlJSwtLaHVamFraws7Ozu8xbDZbOLjjz/G559/juXlZfzqV7/C3/pbfwsff/wxvvnmG1SrVXg8HqiqiuPjY4TDYayvr2NlZeXc2uNVtVpOwi5Inq5vWbfIZbUvu4D+JuD1erG8vIzl5WXYto16vc6FW61W4zPbjo6O0Ol0oKoqX27V7XbhcrnQ7Xbh9/uhKArS6TRM08RgMICmaTz1ZlkWtre3sbe3B0VREI/HkU6nuWhbXl7G+vo65ufn39gL78SLQb8NXAK74tPvdaEbOsa6DuDbXzJN00Cv14OmG3AHo/j0y6/R2fo2/su2tE3evF7v1N/p6g7xNvHBXASpkIJSW4Xf40I0Ej334yrdEeaiPvy7/9O/glsSeMJpOByi1+uh2+2i2+2i3W5jOBz+psg7e2NXj5gwYlfGJmdaTcb9J9NKtm3DsIGhJwFLckOydPi0OkTnxTdysiQdK4helouEHXsu7D5ZsWPbNn8d2OwHXddhmib/OxuWf3pg/uQMsMmE3+m/v8j7XhRWdLGEGJOOk6Ly9P2cnmV23aKNFYOTGzRZ0oy1vrC3Xq8Xsiy/8PMjCOL7YaCZ0HQDer8H2DZM89v/7xuGCdO0cHxcQTSeggUX+tqr+X86QbyJnK7d2MWmSUzTQLmtIiDo+ElkBmM1wOu2vb2TUR/Pk22nYRfujo6OzhwTRRGRSORcAecNhPCkNrqxW39jsRh+8pOf4N69e9jd3cXm5iY6nQ6Ak3Th06dP8fTpUyQSCayuruLHP/4xvvjiC3zxxRc4ODiAJEnwer3Y3d2Fx+PB4uIiFhcX+ZbM57VaTkqvyYu+py8Wn/7zTQyOiKKIdDqNdDqNH/zgBxgMBryVtFQqoVar4eDgAOVyGb1ej1+U1jQNzWYTsiyj2WzybeyJRAKmaaLdbkPTNH7x2zRN7O/v4+DggNediUQC8Xgcy8vLeO+997CxsYFYLPZ9vyTEK4AE2yWwKz5KIARL1yGJEkywIszhb3XTgjEYYNhpAcEQ//zxeIzxeIx2u33hfQiCcEa6nSfl3G7363uiBHFFFFnCH7yfwb/+y230RgZC3rPyoTc6aT342+9neEHE5MV5sPj1eTdVVU+SZzZwMBCg2QIU0UE+4ED+zcWeyaQVkzWQZHzSkPFFHejqgO0AIhyYsoO7ER33gkMY42+TXezfqqZp194+yt5OMhmvP+9cp1spTw/Kn5Rjp2/P47QEu6h19rJjk++7yv0xacU2WbLhsOzqHkuNTYozdsWVLQLQdZ2/jqdbNp8n8lgR6Pf7uTgLBoN8QQB7y6TZm7IM4PuChtgSN52A4oLiliH6/BgPe5Akaer/047jYDzWUa3X4Q/HEFSoHCZuLlep3VQDUDwe/ONf3MH/4u7v4fDwEIeHh+j1enAcB6PRaGrzNhubYJrmyficczAdAXXbD90R4RZsJMUhXMJJHWPbNlqtFlqt1tTHPzUS2DajGMENCCJckoCYV8LfWAzi73+YQToRRygUeisSP1dptfT5fHjvvfdQrVaxvb2NUqnEL5QeHBzgyy+/hMfjwdzcHP7G3/gbKJVKePToEWq1GkRRRCBwIkK3trawvLyMO3fuIBqNXpgoe5drn6sQCASwurqK1dVVWJaFarWKo6MjHBwc4PDwEAcHByiVSlygDQYDjEYjtFotdDodyLKMYDCIeDyO2dlZ6LrOZ76x3+XZ51WrVfh8Pnz99df48z//c+RyOWxsbOAnP/kJbt++fe62d6rf3g5oycElTG7dSQXkqTZRy7Z/49gcWJ4QxFEbwc/+W/y9//nfxo9//GPoug5VVWEYxit5LC6X60L5Nvn3t+EHDvF2MzYt/NG//3YzW0iRIUsiDMt+5ZvZRmMD//f/tI3/8PUxaoMxTMuBAAdht4P3oxZ+K67Ddepb3rCB/2HPhac9CY4DeCUHkgBYAMaWCFEUcC/jxf/+FxlEQ0GeYmLCxzRNjMdjqKo61UrZ7/f5cP7hcIhutzu11ZIN72cD+1my7LRkuyhZxlJqp99/Hbn2MpwX+59cFnF6WP9kSyp7fCxJx6TZ5DyO6w4TZsk1NtuMyTsm0SZbNQOBAE+bvcyMPPZ83jauW3AN/uqv0Puf/gzicy7c2LqO0H/+K2T++I9fxcMkvgfexdoNmK7fvM4Y3W4Xtm1BN4yp2aGWJwhJ6+K/Wh7in/3X/ztq8yZuLC9au3W73SnZdhpFUTAzM8OlV7fbRa3Rwp89G+KLhoC+IcAGIADwCQaWXW3clhuQhOmaxnQE/KfxHEpWCHAAt2BBhAMbAnRHAgQgJ/XwO55DyJLAZ7+dl4B7lRsbX1er5UUCzDAMHB4eYnd3F6PR6MzCr1wuh7W1NYxGI3zyySfY29s7Wb7n8yEYDHIZt7GxgZmZmVf2OhAnCyuOjo5weHiIra0tLkTZzLbhcIjRaMS/7mxWWyqVgt/vR6/XQ6fTwXA45DX9jwwD71kWBHzbreGSXQgEAkilUphJz5xskBWofnubIMH2HCa3iAY8Imq1+m+23RnQdQO2ywNHcsO799fwHnwMWZbxy1/+En/4h3+IW7du8Ss7rH+b/fn07VX98syGbJ8n39jN4/G8lb9EEm8OY9PCf/fJAf70wRFqPQ2W7UASBaRCCv7g/Qz+8LfyLy3XNMPCv/iTy4vB31mJ4//we3kY4xFPvf0PD1v402djyLAh42TAviD+RgLZDjRbgOEIWBeOsKDv8hllrPWSpclOz/k6Txo9j0k5NrmO/rqr6y8Sc5N/nizYTi92OG+e3ek2TDajjrXnTrahXrZJ8zqwx3ha3E2m3yZTcB6PZ2qhwOkZd6+DVzEH73lS71U/9qM/+qNrFVwAMPjz/x/kbPbSjzfKZQR++XtUoL3FvKu1GzBdv8Ecod066SbgFwIkDxxJRqD0MQLlz/DBBx/gn//zf46FhYXv94ETxGviZWu3q8i2dCaHf/vUxCcHfTiOg5DigghAN030tJPaai1k4T+LNjHodTAYDAAAD/UkHhhpuGDD1gYwTfPbWsUlwYAMEwLel2vYcNcvfZ6KonDZFolEeIKdXZS7zhD/580LftGtls/DcRxUKhUUi0WUSqUzj8Pv96NQKCAQCOD+/ft4/PgxDMOAx+NBMBiE1+tFMpnExsYGzfh6DRiGgWq1inK5jEePHuHp06fY399Ho9GApmlQVZXX0o7jQJZlnmpTFAWtVgv9fh9/v9HEjw0DLI5z+rdzQRAgShI8ogDl934Jn99H9dtbAAm253D6ik9QcaHXaWOk6TAEGbZtwVXbhP/Jn0L4zWwnSZLw4Ycf4qc//Sl+//d/H/H45YNzT1oVxueKt0kp96KD2U8jiuKF8m3y/TR7iHgemmHhwWEHfc1EUHHh/Vc4J2Pyl6OLWlFVw8I//ukc/lcfJDEej9EdqPiv/9+7OO4bCMs2jipHMI2Ttm6eoBIAxxOGMGrB98n/FYJtPlc4T7Yrvo7tRgy2xXQyqTU5J4wN02f/VtnHsi1Mtm1Pza9jV9SGwyE6nZNCdnL7Jxvayj5+cvnE6TbV89pW2fsuGnw7OcODvYbAt20TL8tVlkw8bwHF5CyS7wo2A+VFBd3p9+n/8l/C+E8fwTU7O/F9Lpwp1FjBBZwINjEQgP2bX27Ow1ZVuDKz8P3wh9Rq8JbyrtZuwNn6zQ0LvU4bNgQYgguObUNubiGy9We8fpufn8c//If/EH/zb/7Nk9QAQdxAXkXtdpFs+6uajL+syvBIAsI+N9xu+TfJ0JOfSL2RgZFh4Z/8Yhn/+HcWYRgGjust/Ff/7iEqvTHCbgfNRhPj8fikHjkp2yCKEmxPCD5nhL/hPEDQp1z5AuUkgiDwOaunU3DRaBR+v/9Kkuy7DCqoqoqtrS1sbW1hNBpNHRNFEfl8HrOzs9jb28P9+/cxHA7hcrmmRmWsr69jdXWVfrd7TbTbbZTLZTx58gSff/459vb2UKvVMBqNfhPMObmw4zgOXC4XUqkUMpkM/uDgACvNFhq/2dbrOA4c/OZ3FuDkdxYHSNg2vlYUBIMBrHd7kBXl0t89WP2W+qf/lGq37wESbFfg9BUf03agqUMI4x5m9CPEW49QOtjDcDjknyOKIj744APMzc3ht3/7t/G7v/u7L/0/NcuyuHC7LA33qgawy7J8YSvq5Pvpqgjxqpls75kNe/n2yxORw4bdA0NHRkDQ8T+Tn0AWgWM7iL8yFiHBhsux0Gw2YVpn/z0ILgWi7IH2F/8aQmN7qli6qB1zspg6XVhNCiYmyRRF4YtMmAxj75u8Tcozltg6vUSAySl2v0yksf8XsB/gTMRPbjyaTOOdfuzsuU62r7IbewxM4kxuiXK73bx9k/37n/y88163i4rR8+57sjBmV/9Ot9ie9/6LPv68+z7vcV5V2F32vsmBwM97/q+qQN/4y79Eem8fo9NCQGBF2snX0tProbe2BlEUEXz6FP9/9v47OpL7vhdEP5U65wQ0gEbGRAxnhjnMUKJIWQyisle0ryzfa60ln2uvpb0O5x3rrnf37PF7d/fd+0xR0sqiJa5tWaIoUhJpSSYvLTHPcBgnB2ACMMjonFOl90fj95vq6mqEGWAS63NODzDoRnehQ9W3Pr9PAACuUgGd4vTbpShgBAHgedNqcI3igzy7Aa3zW7VeR7lUglWp0PltYW6maQGzq6sLe/fuxZ133okbb7zxugz2NmFiPUHItrOTU/iv70vI1Bl4hcbckc/nwbAMnA4nHE4H7DY7FvJV9Pgd+PGXb4dN4HDgXAr/6alDsAkc7BYW01PTqFarjfMZBmCWWDaVtUDlBNgOPQVL9jycTicCgQACgQD8fj+dn7Szk36G0hc26WG1Wg2tp36/H16v94qe8yiKgpmZGYyPj2N+fr7leq/Xi/7+fuTzeRw8eBCZTAYMw9BcWofDgc2bN2Pbtm1wOBxX4C/4YKBer2N+fh4nTpzAG2+8gTNnzmBubo66ZYhjRJZl/CnL4VYABbsdAC7Yi1WyqA2wDIOQquB9noeqqrhTksEDlIhjlxZVm6AoAMvC+6lPmrPbFYBJsK0B2hUfl5XDyX0v4MC+Ru6NqqpIJpOYnZ2lBBfDMFSa6/P58MADD2Dz5s0bvuohimJb8k2rjFsvW6oRaaAl5QhxYMLEaqEdtpxWHgsLC4jH4wDIMWRpMOKtAG+B5/jPYCvMQAxvQXbkAXBSBVAahLSiGmRgsDw4uxu533wXtXPvLt3vhSFMq7oiZIm2sXI5NZt+aCNqJXJ/xCKotWkCFwgqYlHVW1aJZVNvFdU+jv57fWmBdrtJjhrJONNmnZH/64midiUI5Hu9pdWI7Frue/J/8veuB9qp7rTP43JqvdXezuj1aDfYa99n2tvpn2e9HdiIgFQUBV+q1nCTLCPBAFBJBQ/9BizHggGDoKLgfZ4DVODGpefXAUBLQfNLFl4AUEUR3BJpZ1oNrk2Ys1sD2vktl1zAyz/9B4jVRjC7qqpIpVJYXFykapdwOIy77roLHR0duOeee9Db23slN9+EiWsCB86l8LUfHwSnyuDUxuwST8ShyI3PFcsyYFkOvM0BRrDhi5tU3L0litNlGx7bH4fXLgCyhEQigVq9BlmSoUJdWv5hAJaHarGjc/oVeIrnm7JfVVWF1+ulVlASM0HmnPUgyhmmkf1mlPvm9/thXyJJLgcKhQLGx8dx9uxZ1Gq1put4nkcsFgPDMDh16hTm5uagqirsdju1jw4ODmJ0dBR+v/+ybfMHEaqqIp1O48SJE3jttddw6NAhSrbJsowviyJukhXEVZWee/BCIwtUli7MyEFFwXscB4YB7pRkcGie3ViWbSpEVEURDMfB8/DHzdntCsBMc10DbAKH2wYv2D1vG/x38LmdeOGFF8AwDMLhMPx+P2ZnZ5FIJKCqKo4ePQpRFDE0NIQnn3wSw8PDq7KNXgoEQaBV2u2gqipVuiyniNPvtI1Amhi1TUB6sCy7bDkDuZgBwyYAoFiVICsqBO6COopAXSoXUVUVkESAsyBXrqGYTILh4lBlCZICMIoMhmXAYWmoUjXEk2ADGCDgsoPt7qbqJT0ZQtRnRIFGSA+taquxTa3EDbleS5LVajUUCoWWrDdt3puejCH3r4U+ew1Ak/JKayPUZp5pbZp6KIpC1XBkWwhWahbVq7RWo2LTfq8nJ7XbaERqEeJOT+KR5i19gyt5nvVFEtrXTvt/PXlGHmM5os2I+FqO1NO+7u0KLvTft0M2FIbsdKJqoF5mgAtlFAwDsS6SF2DF+zVh4npB8/zWgVv6/2c89thjKBQKYBgGoVAIXq8XCwsLyGazSCQSeOWVV3D33XfjZz/7GUZGRnD33XebtlETJpZBsSpBVQGXww4Lz6JcKTctGCmKCkWRIJZLYCzA08++iP1Ig+3cjLL/NkgVpkHMyRIYMPScQFWVxu8yLKAoqJdyqNfrNMuVEAuyLCOTySCTaeQtOp1O+P1+mr1ms9loY3y9Xm+JW1hJAKGqKrLZLLLZLCYnJ1uut1qtbck3r9e7rmpYt9uNm266Cbt27cL58+cxNjaGZDIJoJEzOTExAQDo7OzE8PAwZmdnMTExgcXFRVitVhSLRZw+fZo2V3Z1da3btpm4AIZhEAwGsXfvXuzduxfVahUnT57EK6+8gv3794MfGwfkOp17RUlCXRTBsgzsdgdYhoEsy2BrVQgCD1lW0KhtM3E1w1SwrQNefPFF/PSnP236WalUwtTUFLWNDg0NYcuWLfREfs+ePdi7d+9V74XX2lKXu6yX2sRisbQl37TknFnScG2jKso4NJ1FsSrBZeOxS5f/8dp4Av/zU4fAcwycFh5yvYJKubxEXMhQFBWyIkNSOSgMh5H0fvjFFFRWwJHgh1HhnHCgBpZlwDAXSCVZliHJEiqwgq/l0TP+DFhVpqROO4KDkGqksdLlcsFqtUJRFFSrVdRqNdocSr4Se6YR8aYnzox2w2Qli5B8RA1KmqKI3J/kr2nJIC3pBKCJADIigvTWS0JI6VtQjb4nijPt72vVae0sqEZklJ44M7KKGmWraAkqo9dO+732dSA/NyLYtFgt4aV/fdeKdr/XjtQkBOv/4g/gbrsd80Z2aDCwWK1gWRYhRcGhpWPOrqWGa4eqNlZBl5ShLMOA5VhwLAdIkqlgu8Zhzm7tsbi4iEcffbRlcbBUKlGFgdvtxt13302bpm+77TbTNmriA4u1zm5WnkGtVkW1UkG5XKGqHYXlobACHId/Ai49AcHmQOGWP4Bi88OmVprmNq3Kvc47wFWz8Lz7/wCy2DQn2Wy2tnEf5KvNZkM0GkU0GkUkEoHL5UK1WqUt8WSWI5ZxQr4RAu5S7aHLqd/Ww7KZTqdx+vRpnDt3riUuSBAEBAIBZLNZTExMoFqtNuW0BYNBjI6OYnBw0Iz+uUxQFAXjX/0qaq+/gSTDXMgdXEKstxc2q7Wx8J1IYNznRa1aww3ZLFi0KtgEQbjw/jcVbFcUJsG2Tti3bx9+8IMftJyIaW2jsVgMN9xwA33zr8Y2utLB7GoByYPSKuJIiPpG2FL1JJwRKWfaUq8+VEUZP3yrfYPVb9/Ug6ffm8FzB2dxciEPSVbBcwx4loXPISDgaBw81CX723yugoiTw//rVjuyqQSSySTeiAt4u+AFBwUWiFCUC2ohhmEgsQJkhsdN9hSGlWm62plKpVCpVKiyTNueSS56ckhrsSSrqNqhUK9a0tr/9EUAWhsqOUi2I5u0pFA7CyuxbOoJLCN7pt7Sqd1+PQlmpNprp85a6aK/H4K1ElqrxcXel9F1ehVhu9/Xkm7L/V3a+yFk72qJuiaCTb8ZzAUFWwQM3lJVACpuYxrvUyeahzQCi8UCVlFMgu0ahzm7LY9MJoNHH30UCwsLTT9X1Qu2UavVirvvvpu6Avx+/4q20WtldjNhYjW45NnNaQHLMFDVxqLkfK4KoZ5H7PTPUMpnoSgKsh03It99G1hFBK+K4DmezjKqqkJkeCgMj97yGEKpo0ilUsjn8xBFkW4nyYvVzmLauUt/TGVZFpFIBNFoFF1dXejq6qJEV7VaRbFYbLoQxw5Rz2kVcJdqQ7VYLIa5b6QJdS33LYoiJiYmMDY2hmw223K91+tFtVrFzMwMCoVC0wKyx+PBtm3bsHnz5ibLoYmNwdzXv05bQUVRRDaXQyGfhyzLGB4ZBqmrEmdnYfvQh5DNZMDt2wfIMiT97Mkw4DgWAi+YBNsVhkmwrSMOHjyI733vey2rBpIkUdtoZ2cnbrzxRrrjl1QGtu4tGN19M6IhPx3CVjqYrVSlfTVCVZttqe0u9Xp9XR6P47hVqeGuBVuqnuAwIj2MfrbS9xfzO+0ed6Xb1SQFz0xwOJVjoQKwsSo4RoWsAhWZBQMVAquiLjO0NaeiNH5OAtgtjAw3J4GBiqrCQlJZ7LImMGpJUIKoWKniLWkAC0wAiqqClauAIkMBC4WzQlVV8PExcIeehlyvtRBQ2uyztRA/ZHgjAxcZuoyy24zytcj3ZHvI6679uRHpYnQ/5P9Gg6X2/3rlltFXPZZTdy33e9q/p939ap8XI3us/u9fjoDSrsDqf0er/tL+zOh2q/m+HVmmfWx9Fh75Xkuo6f/edpZW/Xvmf2JZ3MawiJOfaf4FGc8YaAg24Lal7WlLsFmtYGXZJNiucZiz28ooFot47LHHcP78+ZbrRFHEwsICyuUy7r77bgQCAXpd3+AIvIM3QGEtlEQDcN3NbiY+2KiKMv7zsxfaeD02AQLHQpQV5KsiGDBwWDmUaxJUNGa3bEVsarF223h0+exgGYa2iH757gF8fMSJw4cP46233sLZifM4Yt2OvKMbYBiwUg0so4DlBCi8HRzLoM9Swl77LGSxBlEUUavVUCwWkUgkkEgkmtpMtfbR1Vg/yfHc5/Ohs7MT0WgUnZ2d8Hq9YBiGFj0BDTdEuVxuamUnX4n6rV0r+MUq3D0ej6Hyzefzwel0tv29eDyO8fFxnD9/vmUGI+c/8Xgc+XweqqpSl4TT6cSmTZuwbds2uFyui9pmEytDS7BpoagKWObCHKttgc//4pdQZRngeciKAkXj8CBEr0mwXVmYBNs64+TJk/jOd75jmF1GbKN2ux27b7oFZ5kuTChBVFQBKsPAZrUiFvLgwR1RnFoo4I3TCciKCp9DgMBx9GDGMgzu3hTG33x69KIHtat5dZUcuMiFHLRKpRLK5TL9Xp9fdTEKGlVVqbzcYrE0hbyT/5MD9HoQUhf7O5eK5UgwvaWvna1vpeuNrH9a5dW8ewvm/TvAyiI4pY5GjtrS9kGFKLih8DawUhV8LQ8VgGT1QOWIErFBtDFiBYQ8EBLjsBz9GRSxDkW5YFWUVQbq0B4w/TcDNi/AslAVBWo5A/HsW6gf/zWgGLftkuFHrwIjtkj9bVeye+qD7slttIRSO7JOPwwZPYb2++Xy0YygJXq0gfv6v02fPWe0Cky+ktvp89TaNXjpc+f096e9jZ6oJF/JbfXlAfrfN3o8o23R/1ybXbfc/WiVhEb3rX9vGSkP9cUaenUjUT8SMpdlWUR/9nO4TpyAFPCDEmra98tSZqElk0V20yYoqoLgyVOAooCVZaht3jcMAM7jAWASbNcqzNltdahWq/jOd76DU0vtunqUSiXE43HcdNNNCEY6caTsxXjVhZLCgxessFstCLsbx6q5bAUq0EJEXO+zm4nrE99/YwLfffUs7AIHj7011mYmU0auIsJrF9Djd0BRVcxmKyhWL8xYqqrC57CAYWD4OVAUBYlEAkePn8Q/v3UehzMCygoPBQygyGDrRfgL5zAozyDWHUUwGITL5WqZRQqFAhYXF5FMJpHNZpHL5aAoSlNp01ojbaxWKyKRCEKhEDo6OhCJRJqseOTvIyo2i8VC1W+lUonaT4kKrlarNdlO9STcxdgzBUFYtvmU53lUq1WcPXsW4+PjKBaLTb+vKI229UwmQ3OCbTYbjSLp7+/H6OgoQqHQmrfNxPJoR7DpoSXYcs8+R1tCL6BxTkXfl6oKhudNgu0KwSTYNgATExP45je/SfPXtFBVFfFUBhOBW1ELDgFgYIEMFioUMJAYAeAFyGqj4rpayCw16CydfHEsREaAAg53+ku4p/tCiycJY9eGspPHJCRBTZTx3Ik0Xj5bQKrcCJLnGCBg57C334H7h5zgGGMiaCNUTav5/XYQRRGiKNK8q3q9bvj/9YCqqnT1Shser/2enAS3s90ZZUrpCSvtddrnpd1z1c46qCfojNRNlxMqyyO18wuQbT5wtfzSQUBzPRhIjiDA8oAsAflFSJLYqGe3e8BaXQDLAQwHqDLk5HmIZ95E9di/GRJl9O/lBHCRIbBWB5RaGdLiGTCKZKiOWk5xpCdI9KSk/rGNvidoRwK1uw8jVZUReaMn7PT3aURM6cki7f+1tgftxWKx0OFSb0k1IpvaXbTFBkbEknY7tY9lZKnVX6/92/RkVLv71j+2/v9k5Vn//GnbZo3+/uWIu/XC3Ne/jvzzL4Bdwc6h1OvwPHA/XHv3ovj66yi/9x6kuXmwy2S/sG4XlELRJNiuUZiz2+ohiiK+973v4dChQ4bXq6qKdC6P1MBvYZENAipgZWWwABiOQ0m1oCYpsAsc+kNOsLrPOFHufOVDQ/jSnoE1bdv16GowcfWjKsp45PEDmMmUEfU2GjIVRbkQhaGqOJcooSrKsAkcBsNOpFMpKIqKOmtFRQZkFfT9ui3qwSd2dS/7flUUBdPzi3jh7ZM4dWYChUwS9vIi6pUSjZsBGsRXT0+PIdmmqioqlQpyuRzi8TjNxy2VSmBZFk6nk5Jt+jgQbQ6tEViWRSgUQmdnJzo7O9HR0QGHwwFJkmj2rnZRTrtAaV3KQ5VluUn5pr0oimKofCNftTlzq4Xb7W6ym4qiiHQ6jVwu1xKnQ1SBhUKBkpMulwtOpxNdXV0YHR1FT0/Pus4wH2RczPwWf/TRFWc3AFDEOjz332/OblcAV7837hrEwMAA/vzP/xzf+MY3WrzvDMOg0n0zKtYBMGINNnbJDsUy4BkWPBRkRQYqWNQYFaIoQjIIrpYtbrw6WcSJXz4Jto0SB7ggkbZarRBsdkxF7kLWFgWYJWKPUSExLKbKAn6ULuHNE1P4WCgDgb24Hafe0mRkZ9KSQO1UUO1aAldSSemJKXKgJGQcuWhztcgBdSWl22qhP/k3+v9GH5j0J/FGJAjZVr2qxuiivV27+2pH/DAMg7w1gozTD0GVwdnsKJVKEEWxcZ8MwAh2MGi0Q4FhIbEcFKUhs0cxA7mYBSPYGu2fHI/c6/8MceZoWxVUE8GSn77wGnjdbZ97I3JS/3wSkGIBrd1Tq6DTk17ke3Jf+vu2Wq00pJcE9Rr9PQTkMcj9a2+rJcNsNhv9nmGYliFSu21aAo1c9NtBvhcEAU6ns6l0geR3OJ1O+njtFF5GpNNKX020h2vv3jXd1nP//fDcf39j5TS78sqpUigue70JE9cDBEHAV77yFfzgBz/A/v37W65nGAaFjt2YlT2wswqsnAxZVhp2HEWCKLGN2U1SkC7V4bYA1UoVNntjn+60sijWRPzi0Cz+3W29q1aeLWfRm8mU8d1Xz+LkfP6SlHEmTBjh0HQW8XwVHltDuaYoChYXFylZpLACRFkBxzKQFAXlmoRCoXCBZGIAhrOC5y1gVA57PBlsBnBmrNw4L9Et4BEiKRbtwFc+HYWiKEgmkzh58iSOHTuGubk5SrBJkoR0Oo1MJgOLxQKv1wuv10vJNhIF09HRgVKphHw+j3K5TAkk4MLc6nK5WjLOiGvBKJM3m80imUziyJEjABr2TS3h5vf7W+YWRVFoCRaZf8nMFw6HAYBmnhGxAHHtFItFpNNplEolumC4HAlnpOwrFAqYmppq+rkoisjlcqhUKuB5vmmuczqdlAgkt8vn85iZmaGFCENDQ2bhyyXiYua34uuvr2p2E2dnL3XzTFwkTIJtg9DV1YW//Mu/xKOPPop4PE5/LoPFgqULYBhYGQXVaiNHQFXVxpDGW6DYrACjIleRgEoVgHqBiAAAMFDVMkTeiTzvgyU71URyABcILpIFoKoqKrHbUebDYKolqPUyxKX7IlAFG46XLTh14H1YJvdR1YaR6kV/IY+rt0K2+9lGQk8saX9G/k+CULW3Wc7ySC4r3a+WrFrp8cmFKIKMlEJ6gkFPSOh/ZvSY+u+vBKZkL9i6AAsYcAyHYrHYRFwyvAoWANQGwcY0NAFN96FKVaiKBM7uBm+1g1kigEgpAPmqXeUj718yeBAFkvZ5J0MeIbnIV0EQWkhcAPQxtZ8NjuPgdrvh9XppDTyxFCWTySaizeg1Iblv5HHcbjd8Ph8lrozeR0YXcl96Jad2xZnn+SbrM/n7AdBGVBLkq/27tb9DLuT3yPBGQIYzo4tZPrL+IITZxUCp11ccwpR1ysU0YeJqB8uy+OIXvwiHw4Ff//rXTddp5ze1XoYsCKhWq42FFosdCgAWCmSFQbpYBSNIEMU6KpVK49hkESCowFQyj3/61asYjdgMyQX990++t4BXx+JwWDh47BdUDhaehdPKI18R8dp4Av98YGrNyjgTJpZDsdpwughcY34QxTqdDcrlEmTWCplzgGFUACyySxloqtpoIWdYBoxcgSLWoPA2PP9v+3Gcy8Hj8SAUClFSjDSja+cDMmcR++M999yDQqGA8+fP4/Tp00gkEk3Wz3K5DFmWUavVEA6HYbFYUK/XwbIsvX9ZllEoFJDP52mUjyRJqFardI7O5/NgWZYudi4X9E8WVmVZRiaTQTKZxKFDh8DzPMLhMEKhEAKBACKRSFPbqRFUVaWzG5m9SGRNKBSiKjhFUeicV6vVaKlcOp2mrh39PNzOhioIAr3vXC6HRCKBc+fONW0Xz/N0Xieznd/vx/T0NCKRCLZu3YotW7a0/btMLI+Lnd/M2e3qhkmwbSCCwSD+4i/+Ao899himp6cBADnejxpjA680doKyLEOSJXAst7SisZRvoMhQGRYyw0EVK0tWAw3hIKtgWAGJTBHydGNFQq8IaVKfCDYo4a2NnWStBHGpiEFjPgNEEawrAKlzBzLv/gKsKhuexOuJHQK9bdJoZUVrMdPWXhupqox+pldIGSmmyHNBvhr9bLnrliOijNRwWlUckYevFaqqUjJECyPijVyIWknbRqlvplzL/9frNka/8950Hif+ZQw2CwenpXGwLpVKDfIVKlSWh8wAAAswgNUigLd4ll4LFYqypC5jeagsh77uCDjbSNNzQAgxoDFcGOXptTuZ0Q4h5OfkPrXWQDIAERKqXq833Y4MNzzPo7OzEzt27IDH40G9XkexWMTCwgJmZmZQLpdX9b4glgGv14uOjg46rNntdlSrVVQqFdrUSxpQOY5bdtCRJImuouZyOfq+kySJfqbI8+V0OunnlZCNtVoN2WwW1WoVoihSO4M2s7BQKNDf0+4jyHPTjnwjtg0TlwdrXTk1YeKDAIZh8LnPfQ4ulwvPPvss/bl2flNUBblcbul4ZYVUq0HlBTAMwEKFKKsoqTJsHAeGbeRp1ms1sLwFKhhUJBge8/Woyyp+8nYVtboCJ8uiWKwhlUrDYrXA5XSB4zlYGBZ5ScLP3pnEfQM2uO3WluOaCRMXA5eNB8cyEGUFFr4x83u9XlQqZUiSDHXpcN0YvRTUKhUISzOR1oHCsAJ1zhQKBeRyOczMzNB4G4fDQedFp9NJY24cDgdcLhc9v+A4Dg6HA7t370apVMLs7CwmJiaQSqXoZ4nMoR6PB+FwGD09PXC73VBVFdlsFoIgIBKJQFEUmums/YwEAgE669RqNeTzedTrddjt9hayjcxLRhBFEfPz85ibm4OiKIhGo+ju7kZnZycikQg4jqMzXKVSgaIodPGyHch5BlHBWSwW2O12mo2mVdqR29RqNaqE0263vgGV53l0d3dDlmVks1mk02l6X0CDTJyfn6fCDTJn+3w+BAIBDA8PY9euXejp6aFWVLIIa2J9Yc5uVz/MDLbLgHK5jG9/+9s4c+YM4kIHTjhugKCIYKGgWCpCEhs7L0EQoLACFEcAzFLymlJMAvWGHLpxkt9Q9qi8BQxngeX9J8GlJ+hj6ckhqs4KDKC66/OALEIVqxCl5mwyoowDbwUjWFF5+TtQFsbbKmWMbF56ddVK35PtM7KmkZPzdt9vhCSZkCiXSjZprXjkAKc/2BHF0krPU7ufaf8vCAIdUshgYtSWeiVl3EY5HkAjfF2WJNTqImZyddQVgIcKDypQoUCRGwQcyzT+1jIssCtl7Ei8hEoxTy2ZhBwiq4Okop2sApI8DS2xRl4zMgAuBy2Jph1KyGtNXldCVGnfE9rHdLvdCAQCsFqtkCQJlUoFqVQKhUKBktOreZ1I9khPTw9isRi6urogCAJEUaRkm5Z4035vVMCihf7v0Wcaks+sthSErKoSmzb5/VqtRldhtYo5vQ1VC/I+1l9cLpepfjNhYo0wZ7dLw2uvvYYf/ehHjfxczfymSHWqNGFZFhaHG1Xe2QiWZhsLom5WgpVtqId5gQfLsKjJgMxw+Nv/YSfu2tTRtGhn9P37MwX8t31JWFjAygHVagWJRAKqCnAcC5vN3jieMBxklcUnO3MYdMlNxyCinCH7a3KctNvtdOHJSD1nLnaYaDu7qSoq1QoqlSoSVQaiykBgVPg5EYACRVWhKheygausDR5Owm/xJ1AtFZrmYqKyt1qtNGbCbreDxFeoqtoUoUFmPOACoZTP5zE/P4+pqSlkMpmmHDSGYeB0OhEKhTA0NIRIJAIAyOfzyGazdMGQzEdEKcbzPHUmsCyLbDaLYrGIarWKcrlhcb2U2ToUCmF4eBhDQ0MYHh5GIBBoWjg1uqyUK60oimEmNVkMrdVqTdE55PpqtdrkdCDnGIRgFEWx6ZyHPAcko404I4jaMBKJ0NgQo9ZTv99vWFJhwsT1AnNZ6zLA4XDgq1/9Kr773e/itVPzYNRGoQELQOAbQ5csS5BlCQLHQYICleHAqCosPA+WtUGFulSc2Gi9lCweWKUihiN2yL7Bpp0pGdC0UG1OgOXASLWlLkYjqI3AeMYOhbOuqiDAyPK2XO6SVl2nt9hp/69drdJ+TzKvLBYLJY/0J+Ha74ksnOzotXJyI2LsckKSpKa2VP2FkCIr2WolSUI+n2+qJzeC1WptId60BJzD4YDNZtuQ58EmcHh4Zxe+++pZ5CsibaJiwIDnBfC8AGcNkCoinBYOPruXWqdVVYGiqKjIAKcw2O6s4s7BW+ByuaCqKvL5PFKpFG2W1Q5sLpcLDocDgiA05WeQFTlVVWkpiJbAZRiGDiDar+Q1I8SeEWRZRqlUQq1Wo79L3vt6gpaouex2O/3skvYpbc6G9nNAvqbTaRw/fpyeDHV2diIWi1HirbOzs+32aQc27XtN+/PV2BgIiaYn4kg+Cxm4yKBMBjlyH9VqlX6e25Fw2hVQsrpt9HknK+AmTJgwsV64++674XA48P3vfx+8KtH5jcwrjXmrjlq5AMZth8pwUBVlaci6EJNRr9XB8xzKMo+gTUZ+4gimbJvR19e3rGJlWl4Ex2Xhdlhg4VlYbVbIsoJisQBZbpwwS7IEi9UOlbfD6vLCbq80Wc0ICWEE/cKi9tikXUQhxAYhFdoRc0b5TyauXbSd3RgGDrsDVosVBbkMqd4o93DYGkp+cs6iqirKEsDIKnqUBbgdNoQDPjpjlctlSqJpM88KhQI8Hg+1j2rjOorFImRZbsqCDQQC2LJlCywWC3K5HCYmJqiNlJBB58+fx+TkJKxWK7xeL7q6uqiNs1qtolAooFgsIplMIp1Oo1arIR6P01mMPFYwGITP50OxWEQul6ORGhzH0XnSKE5H75BJJpNIJpM4cOAAgMa54uDgICXcRkZGWvYNZGF2uQuxtxpBO7/po0RKpRIl8ci8S2ZX4qYg+XfAhSicbDYLVVXp3FYqlZBMJuH1ehGJRFAoFDAzM9OyLRzHNbWd6gk4U/1m4lqGqWC7jJBlGX//xD/gO2MWVFgHbCTEnV4vQRQlcE4fKowVjKpAKCfBLu3cgaWdIyNAZXnssiZwe6BCVzy0J/wk/4DsRBcVN465bgKjSIBUo4qeliB/3gpwAqQ3vg9lcbw5J0tzYFjN9yvBSKm10kVv4dOTDu1+Rg5s7Qg5QsRo/09uQ7LQrgRUVaXybj3xpr0QyfalgjxHRuSb9nIxB76aJOPrPzcOas5XRTBg4LByKNckqAA8Nh4MVNTqIvLVxvt1i1fBb/dLYFSZ5oUJggC73U6HhVwu15StUa/X6QomsR04nU46aOgJNEK+kZMIrWqSkH4Amgg7/f0QhSL5P5Hnk4Dder2+bEMVUeCRgbJSqUCWZcOhjRB3+pMcr9dLyba+vj50dnY2qRisVmvb97WqqoYWVL0irt1Jm5boNCLgiMpPq2DTFqFo8+i0llWt8k372pBBr536zel0mvkgJj6QMGe39cHx48fxre98F/ttt7TMb5IsoVatQbE4oVrdAMOAgwJLLQuLIMBqtQBgUFNYSGDx4Q4RD4002t88Hg9GR0fh8/kMH/fAuRT+01OHYBM4OK0X1sRJ0HqhUICsyFBZCxjBii8MSfjiA3eiq6ur5RhVqVyYF7V5m9qv5Ji1UpMiIR2MiDl9nqmWmCP7bD0xZy6OXL1YaXYDAAujoFyXwfE8vHYBgIJaTUS+KgOqgmFnHXdap8GqMnw+H11wI4uf5H2pjVjRFk6RRXKv1wvHcq3XLEvnd4fDgVqthtnZ2aaCBO392mw2BAIBDA4OIhqNwuFwUKVXIpHAzMwMZmZmKFFH8nTJe1rb6kmy0IrFIlXCkc+ANlN5pQu5Hc/z6O/vx8jICCXdPB7Piq+XoigrKuH054xaaM8dtQo3Yu1dWFhAJpOhWXxk/0JiQ7RzGlHKRiIRhEKhJmKePG/tQLLejBRwbnf7wjITJq4GmATbZYaqqviTb/4M/zYDcKoEQZVarq8pLGSLCxzLgGcUQKxAFuvgeAsU3goGQBebwx38JLZsGsaDDz4Iv9+Per3e1haWK5bx396XkKoBbq61jZOgpFrgUCu4vXwAilijO039IKY9cSbSY32750biYog5i8WyrCJI/zNyoNOuXBkRcEbKucudJaUoSlvyjfycKLzWA1rl1XKKOP1zUJNk/POBKfzi8Bzi+Sqtbo94bHh4Zxf+h5t78JN3Zwyvf2B7BPf1W5BJJrC4uGhodSQFASzbCLxNpVK0ZEB7G6KEcrvdEASBZqktpyQg2SOEjCWKP0KCaj8TRl+1dlXSYpXL5ZpWBrUruNpWUkJoEtKOfBaNQIYWQryRz4DFYoHf70cwGEQwGEQoFGp6P+vJN212nfaifU2X2+eQ740yhrQ2hnZWVCNiX7sYQKxZ5ISuHQFHSDy9+k1/MU/wTFyPMGe39cPZs2fxte88hzG2r2V+I4RBzeIFeAs4VYaLkyCLIiRFASwOcCyLPksJH3YtwONyIBKJUOK/t7cXmzdvblm8amfRI5BlCZlMFsmyBKFewI2ZV9Ed7cDWrVtx5513rrnlT6/01pbe6Ik5MgtWq9WWpsV2MHIpkAUifZu2lphrVwRhnmRfHqw0u312dxce//VR/Ho8jVy9YZEm19+/LYwP97CIz8/hxIkTKBQK8Pv9sNtb388Mw9DMWtI2r53fZFmG3W6Hz+eD1+uli6vLvccZhoHb7QYAxONxaiPVziaqqsLpdCISiWDbtm3o7e2Fx+Oh7+tSqYTp6WmcOXMG8/PzyGazjQxGgL5PATTlzlUqFeRyOeRyOepqIOci5HxkuTxpreOHzHKBQAADAwPo7+/H8PAwuru7W3KtV2PvJvurlUg4IycTmd9SqRRmZmboPK4tW8jlcvT8SRsrQkgy8veTr/oyhpU+30T91k4Bt9wsb8LE5YBJsF0BVEUJv//NF3BwoXGCzCsiWDRsBxLbGK581Xn0uFgs8J0oyhwkRYUk1sGLJQxyaexyFcAxFwLd9+zZg7179y6rLPr+GxP47qtnYRc4KvMGyEmrglxZREWU8blRHx4YtBraxvT/JwcSbXaI1upGLsRKpj2ZJuQBIeQuBzHXLt/MiJTQHhTIDn85+6r2e6BZTWNEwBmRdsupitYDJKNrJUXcKncLK4KE1LbkwVlsOJdTIIJDwO3ArpgPNuHCgFQVZRyezqJQleC28dipu540Hi0uLmJxcRGpVMpwm91uN/x+P7UczM/Po1Ao0BMCoEEWkmGPhPmTFTxFUWiTVDs4nU56oCcXp9PZcjtZlg3JN/I9kdWT2nmyQkgIbHLyQog6hmHoiQ553bRk92pAVJsulwsejwdOp9Ow7EE7AGmDh8kq8XKkHM/zyxJxRIGpff20NoZ2JBx5TsgASFaWtfsS8rnV2praqeCIsrFd9pt5ImfiWoQ5u60vzp2fwh98+0XMwQ+gdX5TFQVMJQuO56BY3OB4HlBVsLUCOmozuDsKeN2N44MsN9Q8JPDcYrFgy5YtLa9Tu9mNIF8RUaqJuNGeQuXwv6JUKsHn86GjowMDAwO44447sHnz5g096dQqtvXznzbfSj8Paom55Y5bepUc+Z6oZMjxhqjUycWImDMXUy4ey81mqqpicmYOb56ahcrbMBiLYndfoGV2Gxsbw6lTp+i5QjvYbDawLEvVU/oFS6AxvwUCAdpGSmzb2iB+PVRVRbFYRCKRwMLCAnK5HFWzEVLH7Xajq6sLO3fuxPDwMILBIJ0BSEP8wsICXcQlllJJkmCxWJra58n39XqdNo2m02mUSiUakUFmLrKASOYY7fdG+ddWqxWdnZ3o6uqiF6KYI58L7TmMUcHXcipSvSWVzGtkfiOvzcLCAubm5ujzLooi0uk0FhYWIIoiFS2Q54LMX8CFRW+j8ynt9huV57X7LDscjrbqN4/HY85zJjYcJsF2hVCTZPyv//RrPH8ijhpjg8owYFQVVrWKzvocemuTYKFg07ZRlG1h5Mo1WCBDjp/F+KkTsNvtiMViTbYnn8+HBx54AJs3bzbceawk82YZBndvCuNvPj0KK7/yAEJOgtuRb8spW0igpj4LQFs7rR3CtCfYRmq59SKEloOeiDOy6ZHbaIsbjCytRt+T+1mOgGunmlvPljBiD1xOEUesjusBlmXbWlG1l+X+RlEU6bC0uLjYlBNBwHEcgsEgDc/N5/N05U1LBLtcLnog5nkepVKJEsKERCakTbuDtCAILaSbx+NZ82BfKpWQSqUQj8eRSCRoU5bW0qq19NTrdZTLZRreS0oGjIY97edIO/QQwk0bNLzS66dvXdU3CZP8P/17V0ty6XNYtKur5L2otaSuZEPVt/xqf5fYbck2kkxHcnKmJ+FsNltbknyl96YJE1cS5uy2/pidX8SfPvY0zohew/ktkDqGcqUC1d8LmbMg4nNjc9iG+PwspqamKPFFTqgBIBKJwO9vkHaBQACjo6NwuVwA1ja7VUtFPP3003jttddQKpUQDAYRiUTQ3d2N2267DZs3bzZcALoSIPt7beA6mf3InKg9NmvdE9psueVsrO3y5bQWVvKV7P+1jeNaMsI8KV8dcrkcpqamIAgCBgYGDMuJ8vk8ZmdnqSp+fn4e8/PzhrMbAEpEMQzTVJCgX7AkzaHBYBBWq5XmEJL3lyRJLa8lyQg7f/48MpkMADQtFgaDQfT09OCGG27AyMhIE9mWzWbpfEbej6VSiZ6blEqltvnIsiwjk8lgcXGRzneiKDY9NsdxTTOaVpigFyeQn/n9fgQCgSanQrvPgfb/hJTTFkmsRMyR7SPnDJOTkzh58iSmp6fpa0KUbqVSCXa7HR6Ph5KZpIGUvG8KhQIKhQJKpRL97ANo6zjSKl/bWc/1n1uWZdsWL/h8PrNMy8S6wCTYrjBe2/cmvv3kL1EHB14V4ZWy4NC8infLLbdgdHQUx44doyeeZAcWiUQQjUabTtxHRkbw4IMP0p2WFivJvL9we++qyLVLhT4jYCWlnJboITkh2pNoQsJVKhVqWdCedBNlj9EB6XJgNTZWcoDTk3PLWVmJ2s5qta7Zxroa8mQ5kGywdnZUct162VJJW6qRFZWQQXa7HSzLolAoUHVbIpEwXBknyjPSNprL5ejQRg7sJGvD4XBQ6ymx4xL7AvlMktehnYqUYRh4PJ4W4m0tB3NFUZDL5ajCLZPJNFWva0EI8FKpRH+nUCi0qOGMLEFalRwpCSHPuVY5tlxGz0rQljgYEXQWi4USWOQ9S+yeejJOf5JGnist4aYn4cjJmp6II3+P3rZA/n6iaNOScCSMuV32m3liZuJKwZzdNgbZbBb/7dHHcCpRgcQILfMbKSgaHh7G9PQ07HY7ent7oaoqzp07h0wmg4GBAUSjURr4brfb0dnZSY/Ng4ODGB4eBsdxa57dUqkUfvazn2Hfvn2oVCo0GqCjowO7d+/Gtm3bKKF3rYGozMmFnNxrF8v0VlZ9vly7/FCgPTGnJyD0llYjYu6DuvhSrVYxOTkJSZKo3VKPSqWC6elpCIKAWCwGnueRyWQwPz+Pubk5xOPxZbNqWZalcwCZC7QFTCzLwuPxwO12IxgMguMarfFkLiUzA7Fu8jyPYrGImZkZTE5O0gVNAPRY7/f70d/fj1tvvRXbt2+nZJuiKEin0y2OClKoQBZ14/E44vF42wXqSqXREkwu+Xyekr/aGddIaKD/npByLpcL4XAY4XAYkUgEHo+naQ7Ufpa0nwt9CV07kk6rGiWfCUmSMDMzg6mpKRrBks/ncf78eczNzYHjOFpCJwgCXC4XOjs7EQgE6N9HXpt8Po9cLkfL3LSxKkQBqf3M6s+VtAS6kRJOP5/Z7fa2xQum+s3EamESbFcBDh8+jMcff3xZIuKGG27AI488gnfffReTk5MAgEwmg6NHj6JcLqOnp6eJUON5HnfddVdb2+hKFryrGSup5oyUc+VymUrM9aq55WysWqWPfpXocmA1oaj6VSWtnHo5co5l2RVz5IxIu7UWHBBV1XJquEqlsm7PKSHdyFer1Uozz/L5PF3B1IKsaJFMNUL+EpAhRJu9RlZDCchjaFuYZFmmB/d2mRgkT0R7WUt9ObEdZLNZSroZ5Z4BaFKIkufDSFWnXS3VfiWWpkAgAL/fD6/XCwAoFotN5He7BlY9iXWxMFLJke+1K5laZah2GNcO5NpBTm/ZJX+P/kLs49rt0NqhCYFNTsjIgKb/XK23+tSECT3M2W3jUCqV8K1vfQvnzp0zvJ6Etn/0ox/F5OQkZmZm0NPTA6/Xi2w2izNnzoDjOAwPD9N9ab1eh8/nowundrsd27ZtQ0dHB4C1z26zs7P4xS9+gbfffhvVahWhUIjaUrdt20bv+4Nw0khIBW3pQzulnL74YaXSB+DCSb5+3tKGvRvlzF1Mhta1AFmWcf78eRQKBXR2dtL3sBaSJGFqagqyLLe4ckRRxOLiIubm5jA3N9d2MVGSJGrtJAowLcjrrm0lJb8HgOalaeeXer2OQqGARCKB8+fPN5FtBE6nE93d3di1axduuOEG9Pf3w+12Q5ZlJBKNrGCS0wYALpcLHR0dCIfDqFQqWFxcRDwep6Sc0UKwJElIpVKUcEsmk2AYpolwW+0iHpl3rFYrent70dfXh1gshs7OTnAcRz8b2s+EkZXb6Hs9tOcZhFgk8yYpaEmlUgAAv9+Pzs5OBINB+P1+DA4O0sUIcv/kXI28TloHFcnDIwQcUcAZFTksZzPXkoTtbKjkXKGdAs5Uv5kgMAm2qwTj4+P49re/vWwb5MjICP74j/8YyWQS+/btQy6Xg6IomJqawqlTp2C1WhGLxZoadnw+Hx588EFs2rTpAzFALQdtI+dqlHPlcpkG0ZN8K62NVXsA0itk2q0qXU6shZgjAx45sLRTy5GD0XIEnFEbq1HhgRZaW6qRGo58vx62VEKEkdeOrHBps8M8Hg+CwSAEQaAyf/2ukoTskoB8EsqrzbCTZZke7LU2U/Jct2uo5flGC6iWdPN6vasmY0qlUpPKjewr9FAUBYVCgVa05/N5KIqyJtKHtJXGYjFEo1H4fL4mwsroKyH4iBJQS0q2a2JdTnGwWpD3M/n7iEqPfB60Vm/tiihZoSYXURTpe8iIgFNVlVpntYo8QsBp7aVE1RgIBOB2u5s+U6b6zcSlwpzdNha1Wg1/93d/hxMnThheT04KP/e5z4FlWbz++uvgeZ4q12ZnZzE5OYlIJILBwUFql5dlmSo6AKCjowPbtm0zDIZfDSYmJvD888/j0KFDqNfr8Pv9cLvdiEQiGB4extatWxGLxa4LYme9Qfb5ROGtJSC0xQ9aUk4b5bDSsUurFDKyvmm/kgU+fbYWOc5cbccLVVWxsLCAeDwOr9eLWCzWsqinKArm5uZQLBbR3d1NCTA98vk85ubmMD8/j8XFRcPnlSwOEhWU9nivhdfrRTgchs/ng91ub5r/yfxG5iJRFJHL5TA5OYlz584hm822PDZZSBsaGsLmzZsRi8WoQ6FWqyGVSqFcLtPb+/1+dHR0IBQKUWVdMpmkzgtCSBk9n7lcjpJt8Xgc5XK5ReW2lrxFlmURi8UwPDyMoaEhDA0NNbUa6xWjRhEl+s8F+b92gbZQKGBubg4LCwtUwJDL5ZBMJlGr1WCz2WiWHsnA6+3thcPhAMdxTe97sp/SztraRVziWioUCi0EXKlUolnM7Z4P7bkPWUDVlq5o1apaGypZMG+nfjP3rx8cmATbVYSpqSk89thjKBQKbW8Ti8Xw1a9+FU6nE0ePHsV7771HiZ+TJ09iamoKoVAI3d3dTQeV5WyjJlaGPuhzuYw5svMmO3Stas4o5FffMnm5ix/0WEk1p23+IbJ5I0ur/gBlVPqwko1VT0ARef9KirjVkjGKolAlVy6Xa8n/IAdWQqQRwoOQctpgZSK/DwaDVAFFSDeihgIurJYWCgVqVajX62AYhg7S7Qgut9vdonZbzQnXWqylZJvJNlar1Zb20OVgtVrR3d1NSbeurq4Vhz3awKdTkWm/EjKOkJjkKyEujQa/SyXltJknZDvJhZBwwIWBjAQQa1d2ta1i+m3SKuGI+k5LwrndboTDYQQCAQQCAVqgsdHqt/wLL6D4+uuruq1r71547r9/Q7bDxKXDnN02HpIk4YknnsB7771neD1ZPPrMZz6Dbdu24eWXX8bExAQ6OjqoGvrcuXNIpVLo7+9Hd3c3DWu3WCzo6elplAMtqd0GBgYu6kRNVVWcOnUKv/nNb3Dy5EnIsgy32w2bzYZIJIL+/n5s3rwZ/f39ZgvfOoAcB7TqG7111chBQY4dK52eGdn1tHOJ3rZHFP2rCbdfb6wml42otCKRCEKh0LL3J8sy4vE4Vbe1yzkjUR8ERtEoFosF0WgUXV1d6OzsbJozyYWo18jjnj59GmfPnkU8HkelUmmakQnZ1t/fj66uLhqzIQgCPRcgx3uWZREKhRCJRBAIBJq2rVKpUIXbctZSYislhFsmkwHLsk2EG9l/rBbBYLCJcOvq6rqofY4RMVer1TA5OYlTp05hcXER9XodiUQC09PTyGazAC7MujabDaFQCB0dHbBarW0/E/rPAiFYyfOpb6NXFAW1Wg2FQoHOvNqvxWJx2c+g3pKqJcG1ZSvkvIjMhsup37TqzUuBOb9dHTAJtqsMi4uL+Nu//VsatGmEjo4OfO1rX0MgEECpVMLbb7+N8fFxABdso2QlKBQK0R0Lz/PYs2cP9uzZs2abn4mLAyEOVpMxR0g5QsyRlTNtppw+6JecsBup5S6XjVWL5fLltEoeQshpv29nrVirjdXpdBoOSHplnJEttV6vN610GRE0RMXHMAxVGmjVbzabDR0dHejq6kIsFqPFChzHtZBupVKJPgaxJBSLxaYhm9yndtVOC6vValiosNJKdr1eb7KVtrOW1mo1ZDIZSh4Xi0W6ur4aMAyDjo4OxGIx9PT0oKenxzCP5WJAVlbbqeXINusvWmLOyMq6FmhtR+0yTch9ktdPq4bTPk+qqrbYuAmhqy9jIBmC4XAYnZ2dCIfD8Pv9ayqoaIe5r38d+edfALvCSbZSr8PzwP3o+pu/uajHMbHxMGe3ywNFUfCjH/0Ir7c5sanX68jn83j44YfxkY98BIcPH8bBgwchSRLNaCK2UaCxKEoy0iqVCgKBALq6usBxHFwuF0ZHRy96wVSWZRw9ehSvvvoqJiYmAFwIdQ+Hw01qFq0jwsTGQ1v8oJ0djTLl2uXLrcbGulzpg56Y0yt2CDG31uNLtVrFxMQEZFlum8tGyg88Hg+6urpW/RilUomq2+bn5w2P4ySvjwTzk0wu/fMTCoVoIyexbkuSZOiqyOVymJiYwMmTJzEzM0NVW2QhzuFwIBQKYWBggJJtZHuLxSJUVW06nvf396Ozs9PwuSHqtZWspcRWSgi3ZDJJyfqLsZYCDYUeIduGhobakqRrRTqdxunTp3Hu3DlIkoR0Oo3JyUnMz8832eWDwSA6OzsxNDQEr9drSFSTcyR9tvByIgV9Xp32PEpLwBk9FvmcLfccap0/euJb+9kiCj1CtukJONKQuxqY89vVAZNguwqRyWTw6KOPYmFhoe1t/H4/vva1r6GzsxMAsLCwgH379iGZTDbZRnmeR29vL22jAkzb6LUCUiKgbfQ0Us1ls9mmEFCiOiJV5doDDyEfyEHnShY/aLGcak7fWqQNlyerQ3pyThuM387GSqydhMQg5QXkeSVV7MlkkhKeRtASX0aEnM1mg9vtht/vRzgcpuQHUSqRYY88/+T1JvenJeWI7UQ7EBspDViWbbGY+ny+FYn11VhLFUWh1xFimPydq1VUeTweSrjFYjFEIpHLvi9q10JWrVYNiVBCzlar1RZCbiW1ASEC9eSbtglVayki9gZt2QSAJiurvpmYDM9am2kgEKCr49FoFNFolF6/3Hth7utfR/E3L0Ho7l72ORRnZ+G69yPmgHYVw5zdLh9UVcVzzz2H559/3vB6QrLdd999+PznP4/x8XGcOnUK8/Pz9LOsqiq1jQYCAQwPD8Nms1GbYjQapeqe7u5ubN269aLVZtVqFYcOHcKbb76Jubm5pszKcDiMrq4u9Pf3Y2ho6JotRPiggBxjtDmrRuUP+txhcsxZySlhRMwZqeX05IG+iRLAirlsRuUHa4GiKEgmk1TdZiRYUFWVElwAaOi+fg5xOp2UbItEIi0qMBI7Uy6Xsbi4iOPHj+PYsWO0zV5rVxQEAW63GwMDA+jt7UUgEKA2xlwuh2KxCIZh4HA4EAgE0N/fj4GBAXR2drb9jK/GWkqIuWQySVWC5LEu1lpKbKWEcBseHm6yla4VoihiYmICY2NjyGazKBaLmJ6extTUFKrVKhwOB1207urqwujoKAYHBw3txnrVnFHeIvmMaEUL2kVR7SxP9r3aC1GkahVweuEDATnPIBfiftDPddrPkpbcJmTocs2nWieLOb9dHTAJtqsUhUIB3/zmN3H+/Pm2t3G5XPjTP/1T9PX1AWjsRE+ePIl33nmH7jCIbTQQCKC7u7tp57lp0yY88MADpm30OoM2/FNPxpGBi2TLZbNZZLNZWo1dLBabFHNatRzJTGinlrsSxByApoMWOZBpc0yMcua0RRDa2xoRcna7HTzPQ1EUqmzLZrM0a4uoDAkRQlZ2jZ4PMsSRZiujwZEErpK/iYAE8xOZPHm9yBCnqmrTyrORHYCskmpl6sspFFZrLS0Wi0ilUtT2Wq1W6basZtXNYrGgu7ubkm76fdXVBNJOqiXmSK6f3magV0zq1XL694gRAae3l2uVd9qTI3KCpCWr2zV+kfdBKBRCNBpFLBZDX18furq60NHRgfz/+/+D4kvmgHY9wJzdLj9efPFF/PSnPzW8juQ53XHHHfjDP/xDLCwsYGxsDMlkErlcrin/6dy5c0gmk+jt7aXZVbVaDYIgoK+vjypwSObTxS5SFAoFvPvuuzh48CBSqRSsVitt5wuHw9Q+NzQ09IEpRPiggBASehtrOwsr+X61NlZ98QNxJpCZ1O/3UxGAlphjGAaLi4sAgIGBgUtSUlYqFdpMSpRRetTrdRSLRbrg6fF4WmYQjuPQ2dmJrq4uRKPRttukKAoWFxdx6NAhHDp0CJOTkygWi3RWliQJLMvC5XJRpxGZB6vVKvL5fFM+mN1uR0dHB/r6+hAKhWhOmcPhMPwsrsZaWqlUmgi3dDpNFfQXay0NBoNNhNvF2kqJBXdychKVSoUuOBBSMBKJUEXg9u3bsXnz5ktW0xnlzJHPApnfjFRs2iZ6bSlWsVik51Xtmk+1Crh2zh8tsW2Uyagl34jNuP9X/wr70aPgOjvB8Us5wmh9n5jz28bCJNiuYlSrVXz729+m9k8jWK1W/PEf/zE2b95Mf1ar1fDOO+/gxIkTUFWV2kYLhQKi0WjTgGTaRk1oQRRURmo5EoJPAkPJhRxEtAcgrWKOrN5dDcUPWpAVJO2BTNscpM+Z06vmCNFC/gatokifWUcISG3wPcuyNOTe7XavKktNa4ckfwN5bEEQmoYDkhOmlaZbrdaWgUwQhBal23Jy9NVYS+v1OlKpFNLpNFWACYJAV0hXOkEjtlKicFtPW+mVAiFjtcQcyWvUEnJaYo58/rSkmhEJRz5v2pMj7XthNcpUbdnDn7IcblZVZAUBHM9DIGT1EvnrXDqxMAe0qx/m7HZlsG/fPvzgBz8w/NyRFr1du3bhT//0T1Gv13HkyBHkcjmIooi5uTn6e8Q2qigKhoaGEA6HATTUxn6/nxJvPp8Po6Ojl7SfTCQSeO+993DixAkUi0U4HA6USiVIkkSJNnISbRYifLBBiDktCaFXB2kVc9p8OaLuqVQqSKfT4DiOlkoRkAXGQqEASZIQiURoS6ORYo4QDXobqx6qqiKVSlHCLZlMGv5tpHGeLFwaBdT7/X5Eo1F0d3e35Kdpkc/ncfLkSbz33ns4e/YsisUiJWJIxIjP50NXVxfN4iqXy0in05RoI8+Ny+WC1+ulBCAh28jF4/G0EGKrsZbKstzUVppIJCgRdLHWUpvNhsHBQUq4rdVWWqvVcObMGYyPjyObzWJhYQETExN0P+nz+dDX14fe3l5s3boV27dvb1uQsZHQlnEREYMRSV2tVqm4gZCaxIFE5kASX6I/VwKaz1m0pJx2drPZbPiDcgU3VKvILWUnsywLfukzYbNa6XNkzm8bC5Ngu8ohiiL+/u//HocPH257G57n8Yd/+IfYtWtX089TqRT27duH+fn5JtsokfaSbAHggm1US9SZMLEWaGXYemIuk8lQtZyWmCOrPEbFD9o8hCtd/KCHfsWJEHRkd0oOdmTVVk++kaFN+zNBEKh6jgRPk+w1bXW4HtqCAO1FFMWmzC+t3dC6dJAlajr9Si3DMLTd0u/3w+v1wu/3t1WVrWQtVRQF2WwWyWQS2WyWquDIsLYaC4jH42ki3D5ISgotMacN4dWTctoCCK2aUGsX0q++ajPiCL7K87id5RA3GA+sNis2jWxqbJc5oF31MGe3K4eDBw/ie9/7Xts8qFwuh02bNuHP//zPYbfbcfz4cczOzsLlciGZTGJ2dhYAmmyjHo8HIyMjNGu0Vquhu7sb4XAYDMOgr68PmzZtuugCFFVVMTMzg4MHD+Ls2bOo1+twOp3I5XKoVqs0n8rj8WBgYMAsRDCxJhAygsyHuVwOp0+fpjmDRKWpXawlc4PL5WqKuzGCXrFN5qd2GXMsyyKfz9OSgHq93kKkVSoV5PN5SnCQ8H0trFZrU1FCu89EoVDAiRMncOTIEUxOTtJSLkLQELK8v78fPT094HkeyWQS8/PzyGazqNVqVM0UDAYRCATgdDqbwvwJCUcWS71ebwuxtZK1VFVV+ryQ54bEgVystZRlWfT09FDCbbXWc1VVMT8/j/HxcUxNTSGZTGJychLJZBLVapUWwQwODmLr1q0YHR2lCxFXM4j1VNtITJSFs7OztH2XKA2J20ibM2d0PvQ/sRxuYxjEQea3hn6NZdmmZmpzfttYmATbNQBFUfCDH/wA+/fvb3sbhmHwxS9+EXfeeWfLdWfOnMGBAwdok+GpU6cwNTVFK7O1O17TNmrickNVVUPVXD6fp4SNtma7SXZdqaNgCaCmMJCrJcjxs1Cl5tWfK5Evpw9N1YIMb9o8LULGaS/aNiKyekmytogVVq9O067uCoLQZDfSEirkoE6GWFVVwfN8E8FHWlPJfQqCAIfD0dKCpB3utH//StbSUqmEZDJJBzdS005It5XUEcRWSki3q9lWeqVATsDJ50lvYyWkdzqdbhriPrewiJ31OuIAVFUBNB8bp9OJwcFBAOaAdi3AnN2uLE6ePInvfOc7hu1/siwjl8shFovhL/7iLxCJRDAxMYHx8XFYLBZ4PB4cO3aMNuvV63WcPXsWyWQS3d3d6O/vB8/zqFQq4Hkeg4ODNDR727ZtiEajF73diqLgzJkzOHLkCCX6nE4n0uk08vk8bat3OBzo7e01CxFMrAlVUcah6SyKVQkOCwu/kkO9Umqby5ZOpzExMUHVXnqrnv57bb7cStAqgLQZrOVyuanlWxAEsCxL2+Z5nqf5ulplPsMwNMOQkNFGKBaLOHHiBA4fPkztkARkJgsEAhgcHER3dzcEQUAmk8Hc3BwWFxeb2k1J47jX66XbqYXNZmtRu7lcrjW1lmptpclkEqlUqinD7mKspWu1lZbLZZw+fRqnT5/G/Pw8pqamMDc3R0sHSBvy6OgoduzYcUnW+asRRP1M/nZCxM3NzdFsvd9ejOOGWhUJMNQ9owLgOBY93RfcIOb8trEwCbZrBKqq4plnnsGvf/3rZW/327/927jvvvtafi6KIt5//30cOXKEhpQfPXoU+XweHR0diEajdKdm2kZNaIcfl43HrpgPNmHjq9xXuz1bO9145v1Z/OLQLBbyVUiyApYB/DYWt3QwuMlbQTHXTMxpyTmtao7YWNtly10qMae933YNX1oJOLGSakFyGMiwRwY+beacPkgVuEDmORyOJhWc/nZkwAQuqN2IfUH7eE6nk94XIfccDgfC4TAikQhts/R6vS3D1UrWUmIrTSaT1FpKVp9XYy0l2Rza8oRr3VZ6pTD3V19H4Te/AdPZCVmTA1cX67BZbXTV2RzQrn6Ys9uVx+TkJB577DHD/EpCsnV0dOAv//Iv0dvbi2QyiePHjyOfz2NoaAjZbBbvvvsuJQuIbVQURQwODtKyq0KhgGAwiN7eXnAch1AohO3bt8PpdF70ttfrdZw6dQqnTp3C4uIiDWqPx+NIp9MIBoPo6uqC3W5HNBq95LBzE5eOq3l+s/AsTs7n8fyxBcTzVciKCo5lEPHY8KEBF/Z2qggFfNT6rMVayw9ILMNy2XJ69wQpCCNzECkQy+Vyhtlt5P6Jqsvj8cDlcjVZVX0+H51JYrEYzfUlMxzQINtOnjxJM9u0ZBvQIMi6urqwbds29PX1QRAEzM7OYnZ2lto5K5UKFEWhs5l+EVZ/PsdxHDweDyXcSAM9eV6JtZSQbnprqdZWSog37XN0MdbS1dpKFUXBzMwMxsfHcfbsWUxPT2N6epoqAp1OJ/r7+7Ft2zbs3r0bw8PDF63qvZZQKpUw//X/jPq+fZADAYiSCEm8EC8SCAbgdDSOB+b8trEwCbYNwkYc4FRVxfPPP4/nnntu2ds9+OCD+MQnPmG4I8vlcti/fz+mpqaobfTkyZNgWZbmCBD4/X488MADpm30A4SqKOOHb03hF4fnWoafh3d24Qu398LKX75BzWh7WAaoSgpEWYWVZ2CBDI/LAVFWka+KYBkGd28K428+Pdp2W0kFN1HNFQoFakPQKua0GXPke5I1V6vVWnLlVlP8oL2tEXmn/7+2bYj8X99ItBoYBadq8xwIsUcGPm1TpZbwU1WVlkXoc1AIGUdWS0OhEMLhMMLhMDo6OmipgsPhAMuyy1pLtbbSVCpFm8C0w9pKAxOxlWrbSs3coJVhtlBdPzBnt7Vho8iJ+fl5PProo1SNpgVR/Pp8PvzlX/4lhoeHUalUcOrUKczMzCAcDqO3txeHDh3CqVOnAFywcp4/fx4ulwsjIyNwu930hL+3txfhcBgsy1KFyKXs+8rlMo4ePYqJiQlks1lYLBa4XC7E43EkEgn4/X50dXXBZrNRVcoHycZ/NeBqn98kWUGhJi3NbiwiHissHAdRVujsdke/F7+/3QqHzWJIrkiShOnpaUiShFgs1mLVXA+QfF1tuL3WukfUQtVqFbIs04xUEuFAssvIrKL93BFSKxAIUHunPrxelmXMz89jYmICCwsLdKGTzHCk2fSGG27Atm3b4HA4MDMzg+npaaRSKUoakjlNEATqDiE2VEK6EUeEFqQESat2IxnBWmspId6ItVRrKyWEG7GVAhdnLV2NrbRQKGB8fBwnT57ExMQELUQQRRE8z6OrqwubNm3CHXfcga1bt27Ie+ZqQrv5TYUKqBeym835bWNhEmzrjMtxgHv11Vfx5JNPLquq+dCHPoTf+Z3faTvcTE1NYf/+/cjlck22UZfLhd7e3qbAddM2+sFAVZTxn589htfGE1BUFR4bD0E3/KxEXG3k9jh4QOA5ZCsy0uU6GAA2VgFXy8PpcCAYDDbyNCoiKqKMr3xoCF/aM7Du2yXLMh1WiL2OXLTlD0QxR+rXi8UiXR3VknLaVkltGyTQTLaR0gI94aZvHrpYUCm57gJcOCBr1XVGxJ+26YgMjKQUQhAE2Gw2mvtGQrOj0Sj8fj+1pJJBlTRzlstl+hhaW2kqlUK5XKZk22qspRaLhVa9m7bS9jAJtusH5uy2OlyO2S2VSuHRRx9FPB5vuY4EqzscDvzZn/0ZRkdHIcsyzp8/jzNnzoBhGGzbtg2KouCVV15BIpEAcME2mkgk0NnZicHBQVgsFhSLRQiCgOHhYdjtdjgcDmzfvv2Ss4kymQy1jVYqFapiTiQSWFxchNvtpoo2l8tlFiJcJrTObwIEjr1q5jePTUChKiFdqkOFCpZhYONURFwCLc0hs9uX7uzF3Z2NjKne3t4WJbqiKJibm0OxWER3d/cVC7aPx+OYm5vD3Nwc8vk8zdQixQ6ZTAbpdBr1ep1aKFmWbSogslgscLvdbVtBa7UaFhYWMDMzg3Q6TRtItaH2HR0d2Lx5MzZv3gyPx4NyuUy3h1xcLhdVp5GwfWKFJfZKLemmdwwYFSq43W5qmW1nLa1Wq02Em9ZWClyctTQQCDQRbt3d3WBZlu4vT5w4gePHj+P8+fPIZDKo1WpgGAZ+vx9DQ0O46667sHPnzqYc8usJ5vx2dcAk2NYR2gOKrCiwMgrcTseGHODeeecdPPHEE8uGvd9yyy349//+37dVeciyjCNHjuD999+HJEnUNprL5Wh+APld0zZ6/eP7b0zgu6+ehV3g4LbzSMTjsNps8Lg9YBhmw4mr5bbHZeUwPz8PWVFQFTyQVQaAClmWwYllCHIFgmBBMBiE1WpFoiSix+/AU1+544paI7QgdoVCoUAbNtPpdJNiLplMUkUAyS0jTVxamym5P6I8I5eLVbet5W8A0FJAoVXske0jIKo4rZVVa2clQcSEeCP5b8SK6nQ6KUlHVHd2ux1utxuCIKBYLFKlWzqdpm2lq7GWmrZSY8x9/evIP/8C2BXIR6Veh+eB+80B7SqGObutjMtJTuTzeTz22GOYnp5uuY5YsgRBwFe/+lXccsstAIBkMolTp04hl8uhv78fg4ODGBsbw/79+2nLILGN1ut19Pf3o6urCwzDIJvNIhQKYWBgACzLIhqNrouKY35+HkePHkUqlYIkSfTEPJVKYWFhAQ6HgxJtVqvVLETYYGjnJY+9dUa/kvObxy6gWqthJidClBUIHAtJliHJCixKFRGPHV6vFwzDYD5XQY/fgR9+6RYszs2gUCi0zWUjAfyRSAShUGjD/6blUCqVMDc3h/n5eczPz7dkv5HM01KpBJ7naZYtaX8XRZGSQD6fDy6XizaEEytrPp/HuXPncO7cOSwsLLRYVsn9dnV1IRqNgmEYFItFVCoVmulmtVppgZXD4aBzm5b0IxdSxqXN4tWeT7IsC4/HQ++PEG+CICCbzRpaS4mtlBBuelspsHZrqc1mw8DAACXcBgYGUC6XMT4+jnfeeQfnzp2jpJ+iKLDZbOjt7cUdd9yB2267jVrsrxeY89vVAZNgW0eQA4pNYFHOpVCr1uD2uBHwBzaEoDh27Bj+7u/+jsqRjbB9+3b80R/90bJDTalUwoEDB2gVPLGNAkB3dzeCwSDduZm20esTVVHGI48fwEymjKjX3iCB0ilABWw2K/yBAKwWKx1+fvzl2zeUuGreHhsSySTKpTIkcKhwDrAMA6hKg2hTZHDlJIAGweR0OMFa7ZDB4d8Nirgh6mhZIdMfvK/G1XVVVZFOp7G4uIi5uTnMzs7S5lXSvppOp2kzZL1eb7I1EPsCcKHyfiMbWbVEn558W8k6qw0HJiShtgmMKOAI0UZuT25jsVjgdDrpCrE2R45YMki+CCHm2mE9bKX5F15A8fXXV3Vb19698Nx//5ruf6NxrW+/iQswZ7eVoScDRLG+tJ9pfO7Xe3Yrl8v49re/jTNnzrRcR2xWAPAf/+N/xN13301/5/Tp05idnYXP56Mk2b59+3DkyBH6u8Q2arfbMTw8DL/fj3q9jnK5jIGBAYTDYXAch02bNqG/v/+SFmFUVcXExASOHz+OUqlE99ccx9EgbkEQ0N3dTZUpZiHC+kM/v6mqimwuC6/H23TsujLzmx3VWhXz8TQqrAMcx4JlGIiSCFkBGFWBQy7A7XLC7w+gJjd+/28/vwu3DgRok6LX60Vvb2/LsTifz2N2dhYej4eSylcaiqIgmUxSdRuJtyAQRRG5XA7pdJqqy/x+P9xuN91+lmURDofR3d2Nrq6ulvbUcrmMw4cP491338Xp06dRKpUoSSbLMliWhdfrpQ3DJOM2l8vR+YuUNBBlG81aXbqQ/xOnBVlYtVgsNOpDn8srCAI8Hg8lCkm2m9VqNbSWrmQrBdZuLWUYBj09PRgeHkZfXx94nsfY2BiOHTuGubk5VCoVarsNh8PYvXs37rnnHgwMDKD44ovX/Oxjzm9XB0yCbZ2gPaDY1Rri8ThEsQ6LxQqn04lIJAKe59f9AHf27Fl885vfbAnE1GJoaAh/8id/suJAMz8/j3379iGVSjXZRh0OB2KxWNMOftOmTXjwwQdXVbFs4urHgXMp/KenDsEmcLByjfdCrVZrSPlZFgLfCDYW7C7UJAV/+/lduG0weFm2RxUrSKXSAACR4VFlHYCqAKoKMEyjHaeSBeQLbUe8xQbW5sLt/CSG7ZUWUs2oYWk5Ek7b5HmlUKvVmlYE9a10xLZKlGFkJZE0SJKhC2jYioiNlXzNZrOUvKvVak2W1XYFEKuBnnjTf12J6CPKPAKtTdbIpqol64jiTZvrxjBM04BIrBPk/9pBkdif+vv7MTAwQG2lRqG7WpgriCauFpiz2/LQkwGyLGF+fgGCICAYDIDnG/v99Z7d6vU6Hn/8cRw9erTlOlVVUSgUIMsy/uAP/gAf+9jHwDAMtUBNTk5ClmWMjIwgFoshlUrhpZdewtzcHIDGseLcuXNIJBIIh8MYGhqCzWZDPp+HIAjYtGkT7HY7PB4Ptm/ffslznCzLGBsbw9jYGERRpIoTVVUbodvz8wBAiTaGYcxChHWEdl5yWnm6MGe1WdHZ0QmHww6AQakmUfLqcs1vNp7BwsI8ijUFosUFlmHA89zSIpgCVQWEeh4Cq8Bus8Pl8aEiM/ibT+/AfdsaqrVcLoepqSlYLBb09/e3HH+r1SqmpqZWXX5wuVGpVDA/P08VblrFFvmsZzIZFAoFWK1W+P1+2kxKQKzX3d3dCIVCTTNsuVzGyZMncfjwYZw5c4aqWgmsVivNbNu0aRNKpRJmZmaoeqxer9MWeYfDAVmWaY5buVymUR3ktuVyuakYglhMCcFOLKaECCNWVr/fT7PnQqEQnE4nzUDOZrNIpVI0Q1JLuBEiUou1Wkv9fj8ikQjNt8tkMiiXyxBFESzLwuVyYcuWLXhwYgKWd941ZzcTlwyTYFsnkAMKo0jIpeIolUqQZRnM0sHEYm0c6BjBtu4HuJmZGXzjG9+gq55G6OnpwVe/+tUV7U+KouDkyZN45513UKvVkE6ncezYMeRyOQQCAfT09FCiwbSNXj/49YlFfP3nR+FzWFApFZDNZSGJIhSSwQWA43jwFitYuwf/12/vpsPPRm6P1y4gmViAJDbk9kTBBqhgQCyLLGxqBbwiQlEVqIoKCSzAWfCAP44Bl0yHAEmSmggWm81Gv66kUuJ5vkn11o6IuxwrqMRKpJXfa3flkiShUCjQ4cHpdMLj8dABxGKxIBKJUGuFfiAVRRH5fL4pX45YWbXNrNoiiFKpRMN+tRZSI2LN6GJ0m9Wo67Q5cXrSzShPTnu/2vvXK+e0agwyNNrtdjidTrqy3Nvbi97eXkQiEbhcLjidTjidTjDf/r8h7t8PS08Plns3mBkYJjYa5uy2PLRkgMPCIZFMIJVMAQBsdhs6Ih1wOBwbQk7Isox/+Id/wNtvv91ynaqqdLHjC1/4Aj71qU/R/VoikcDZs2eRyWTQ3d2NoaEhOJ1OnDhxAq+//jptKyW20Wq1ir6+Pvr6Z7NZhMNhDA4OgmVZxGIxbNmy5ZLnuGq1iuPHj2NiYgIMw8Dj8VDbWa1Ww+LiImRZRkdHB13wDQaDGB4eRiQSuSrUR9citPObhWeQSqeRTCQbERocB7/fj2AwAFllkS3Xm8irjd4eRaohkUiiLCqoC24wUMGyjcUvVQUkWYFQL4CVa43oCKsdgs2B//v3bsPtQxdsn9VqFRMTE21z2S5H+cF6QFVVpFIpSrYlk8mm62u1Gi19kiQJXq8Xfr8fHo+HzqiCIKCzs5NaQbV/a6VSoW2kRmSbxWJBd3c3duzYgZGREUiShHg83pR16/f70dHRgVAoRGdG8hkul8sol8tNecO5XI7GdGhdFpVKhcaAaC/aWZtYUIkajuf5ppxjIh4pl8soFAp0MZiUNWhnvtVaS2VZRqlUok4PUujF8zw+u7CArfkCmI4OmjFnBHN2M7ESTIJtnfDrE4v4q58fRb2QRqXUCDPXPrUsy4LjObg8PghOH/7LZ3eu6wEuHo/j0UcfRSqVanubcDiMr33ta6vKKqhWq3j77bdx8uTJJtuoqqqIRqNNw5BpG7320bwCyqFcriCTyaBer0OWJZB3ssoKYHgrHumv4T/93ic2LK9Kuz12gUEymWrkSAAocS4oYMFCAcBAVQG+loONv0CsVFkb7EoZe+rv4fOf+wyGhoboKmE6naarcyRMv16vN7VskqD+i7GOXgk1nCiKSCQSWFhYwOLiYpOilagIisUigMZqps/no0UmDMMgEAhQws3pdF7UNqiqikqlgmKxiHw+j1QqRds/ta2sZBgjJRCk2YrYEWRZphcSGkyy3fR5c2T7tYQaUU6Q67RfCaG2HMlHrjf6fSO1HBkgSUaJ0+nEn3IcbpRkZJYIOnrhefAcR4k4c0gzsdEwZ7floSUDVFnE4uLi0r6ysQ9gWQ6BgB8eXxD5qrTu5ISqqnjqqafw8ssvG15PynE+85nP4Hd+53foMalcLuPs2bOYm5ujLaLhcBiiKOLAgQN4//336WIFsY1aLBYMDw8jFAqhWq2iVCpheHgY4XAYFosFW7ZsWZf3SD6fx9GjRzE3NwebzQafz9e0AEOUM+FwmBJtbrcbQ0ND6OnpuSojG65m6BVsABq2zPl5VCoVMGAaBIY3AIa3XFYFm9PKQxTrSCSSSMsWqGDBQAHLsGBYFmAY2OUypGqpkQVmccHLS/jjrRJ++zOfaiowIArOdrlsqqpidnYWxWIRXV1d10SeKikxIHZSLSEmy3LTvGS326m6TUuqBYNBdHV1oaurq0mNuhqyraurCzt37kR/fz+dI4nCjmVZhEIhRCIRBAKBFQlwWZYpAUeUbwsLC0gmk00LtCRvDrgwRxEbKlkEJ3O01WqFoii0mKFQKFBijBB5RFFHolC0szwh7wiBR5RuhDgkCttEIgGWZfEfFRU3ShKyVisEnqcOB/3Mbs5uJlaCSbCtE8gBReCAYiaFYqmIeq3eFPjNMAxYiw281YEvblLxP37yw4hGo+u2DdlsFt/4xjeoTcAIXq8XX/va1xAId6yqij6ZTOKNN96gljRiG7XZbC2B4KZt9NqF3iYDAIoiI5vNolAsQhIlKKoCiXdBEPPYmXwZse4oPvOZz+D2229f95Xn1u1Rkc8XkMlmUIMVddaKxgkQAxYKbGIeYr0Oi0WAzFkhMzwGqmfQXzsHALj33nvxmc98BjzPQ1EUeqDO5/PI5/MolUpNCicS8spxHARBoAdkhmFoZbs+mHUt0KrhlsuGu9jnNZ/PU3VbMplsUmoReyhptvJ4PPB6vXTgcDgc6OjoQCQSgd/v3/CTHUmSUCqVKAlI1HKEDM3lcpSsS6VSdNgk8n6SFUIIOD2Ws5ECFxRt7Wyw+sy45Ug3gr8OBvEhuwMLS/t/epul4S/W0wO/328OaSY2HObstjz0ZEC9XkcylURx6bhHjjO8zQmXL4BvPLIbezavbyi2qqr45S9/iV/+8peG15dKJZTLZXzsYx/DF37/P+DoXAHFqgSHhUVAzSM+PwtRFNHX14e+vj4IgoB0Oo2XX34Z58+fB9BsGw0EArRdNJvNUnLNbrcjEAhgdHS0JfPpYpBMJnH48GGkUika6p7NZlGtVqEoCo0jIS3SQGMBaHBwEH19fWYhwiphNL8BgKoqS43baSiqAsbuR9jJ4Uf/460Y6N24fYHxPKlgOp5FQQJURQXDqADDgWeBPr8N+XwO2XIdCsOjM3sUdwZr6Orqwic/+UkMDFzIPVRVdcVctqup/GAtUFUV2WyWkm2JRKJp9iiVSnQ2kmUZPp8Pfr+/aX6z2+2UbOvo6KDuhGq1ipMnT+LgwYPLkm07duxALBaDKIpIpVL0HFYQBLoIu1biUqt8K5fLyGQySKVSTWq3arVKYzkEQWjKDSbzHcnhVVWVKtrS6TS9r3w+T+9PO9NrQdRqpJTFbrfT5yifz+ML2RxuBbCwtFjLcSx4jofD6UAgEKSNt+bsZmIlmATbOkF7QOn02JDJZJDP51Cvi7S5BAAYhx9MOY3OEz/GXXfchptvvhl33XUXent712U7SqUSvvWtb+HcuXOG18tgsegegdS1E7kaVl1FPz4+jrfeegvlcrnJNurz+dDT00MzEXiex969e3HXXXeZttFrDO1aqBp142mU6ipElUEoeQih1BHwPI9gMIhdu3bh85//PCKRyIZvT71eQzyRRFERIDICAAaCUodVrUEBUFM4MAwQkZPYXjm6pHJrYHBwEH/4h3+IQCDQ8liEdCOEG7E8GoXxazO7yCBASDf9hZxIXAwYhoHVar1kNZwkSTRc9oI648LfXSgUUCgUoKoqHA4HfD4fPdnheZ6uYIbD4RWzxy4XiFouHo8jmUzSrI6pqSksLCzQMF9CntZqNYii2JQrp29l1Q7phFwzur0+g86IaPtfQyF8yO7AvK5JDEsDW29vH/w+nzmkmdhwmLPb8jAiA2RFRjaTXTrxq0CWFShWD/h6Dl8aKOLhB+/H4OAgVQGvF37zm9/gJz/5ieF1xXIV54VeiF07wTr9UDSz2z1DHtziq6CQy6CzsxP9/f3wer0AGjm9r7zyCnK5HIALttFKpYKenh709fUBANLpNDo6OjA0NASO4zAwMICRkZFlc41WA1VVMT09jWPHjqFYLKKjowN+vx8LCwtNDaj1ep1mYgKNjKW+vj4MDg6ahQirwHItopVqBVPzSYgK4J1/B/d0qXjggQfw4Q9/eMNITKPtUVQV5xN5lOpyI+RDVWFRRdgFQOVtUBUFltQZhM7/GnaLgKGhIYRCIXpOoT3GZrNZTE9Pt81luxrLD9YKURSxsLBA89uI9RtoLJQSRVg+n4fT6aTqNvJ5YVkWHR0dlHAjcx0h24iyTZ/hTci27du3o6urC6IoIpvN0pnHbrfTRdiL3QcqikJtoIQsSyQSNAuYzOAAqDvAYrHQr2S/5HK5aANtpVKhVlVCUMbjcbo4S87F9YuqZCFdEAR8RZKxWxSxqJn3CHiBRygYgs1ugzWXh/vejyD2X/7LRf39Jq5/mATbOkJ/QCmVSrSauFqroq5wAG+BfOwFKKd+A6/XizvvvBP9/f3o7e3Fnj17MDQ0dMnbUavV8J3vfIc2gRLIYHHCsQNJoQMAg4jPCafdvuoq+nq9jvfffx9Hjx6FJEk4f/48Tp06BVmW0dnZic7OTnqS6vf78eCDD2LTpk2X/PeYuDyoSTK+/vNjeG08AUVV4bEJEDgWoqwgVxEh1mvwVeYQnX0VpUKO/p7D4UA0GsVDDz2Ee++995IH8pW2py7JWMwUUJMUsFDAqTJUhgGjqrCqVTgSJ2GdfhvDg/0t2+J0OvGlL30J27dvX/HxZVluUbqVy+UW0k2bcUYamZxOJ1iWhaqqTc2e7S6XUw1XKpUo2ZZIJJrq5CuVCnK5HCqVCgRBoC1QhMTzer1Nq5hX49AqimKLIo40VpFcuVwuR1dFAVBrArEME3KU5I4QlRxRzBHCjlhaCRFHLKp/HQziww5nK8GGxvvF7/fDbrfDUSpBuflmRP73/w0dHR2mNcrEusOc3VaGERmgohE+XiwUUazLEBUG1onXYJ3cjxtuuAGf+tSnMDw8TPPP1gsHDhzAP/7jPzYtzNDZjY9AUVU4BRaRcBCyAjq73TXox+9vtyKVWITNZkN/fz+i0ShYloUkSXj33Xfx9ttvQ5KkJtsoz/MYGhpCR0cHtXVt3rwZ4XAYNpsN27dvb7HiXQwURcGZM2dw8uRJiKKI3t5euN1uTE9PU6KtWCxCFEXaCg00Fi66urowNDRkFiIsg+Xmt3xVBAPAX1tA/jffhSLVMTIygltuuQUPPfQQRkZGLtv21CUZC/kKqnW50QCv1MGxDLyCips7WHximw///I//gHQ6DZ7n0dfXh66uLgwMDOATn/hEE9m6Ui7b1V5+sFbk83ma3UbyDIELC6VE3aaqKiXbvF4v/bu9Xi8l20KhEHVknDhxoi3ZJggCurq6sG3bNnR2dqJerzct0no8HjoTroeoQhRFSryRc+h4PN6kdqvValSFpiXfLBYLzWJzu91QFIXO3+S5I+pG0laqjQdRFAX/oVzGblFCmuOgonEeoCy5I6w2K+y2BqHoqVZxLhTCyXs/Qo8Dw8PDhgv4Jj6YMAm2dYTRAQWqhGQqg6rCAqoKZuEkiq98D4rUOKHmeR7btm3DHXfcAZ7n0dnZiT179mDz5s2XdPIqSRK+//3v4/3336c/m7QOYsI2DE6VwCsND7w2A2O1VfTZbBb79+/H9PR0k23UYrGgZ8n6RLB582Y88MADpm30GkFNkvHPB6bwi8NziOerLQrHe2Icfv7M0xgbG0MikYAoigAukAZbt27FI488gv7+/g3fnm2uCpLv/BIZ1Q6JEcCrIrxSFhwaw8b58+fR19fXlOEBNAb2hx56CA899NCaCQ1ZlqniS6t004O0EulJt3afaRLqutJltc2dehjVnJN8C22Lk/ZvkSSJklLaKnmXy0XVddqihPUiVjcCqqrSzI5SqYRcLod4PI5EItFkKSADKxnabDYbrFYrOI6jVgWiWqzX6/S+EokELZvIZDL4D6UybmUYLCoKVKgkzgkqAJZh4PP7YREEOIpFzMdieP/OO+ggG4vF0NPT06QMNmHiYmHObitjOXIiU6qiVq0iLCfgHvsV0ok4KpUKvF4vHn74YWzduhXRaBRDQ0Prlvl05MgRfPe736WLH9rZjZPrUBQZNpsd0WgnWJajs9uX9w7g3hiLubk51Ot1dHV1obe3l6pM8vk8XnvtNYyPjzf+bo1t1OPxYGRkBC6XC+l0GlarFdu2bYPNZkMkEsH27dvXRbFXr9dx6tQpnDlzBgAwMjICu92OiYkJenJPFjXI4hFBKBTC0NCQWYjQBivNb1+4vRcH9r2Bv/3bv8XMzAwikQh27tyJm2++Gffff/+6EsUrbc89Iz6ceONFnDp9FoxURZe1jq2bR+D1erFjxw789Kc/xblz58BxHD0eer1efPrTn0Z3dzd9jJVy2a6V8oO1QpZlLC4uUnWbtuSOKLlIMymZ3YgVm5BQ0WgUXV1d6OzshNVqpWTb4cOHcfr06bZk26ZNm9DR0UEXG4HmLF99y+mlQlXVJptpNpulyjSykFoqlejfpVW7Wa1WWqxgtVohyzIl3VKpFBYWFijhlk6n8YVcDjtrdaS086yqQlYU8DwHnuPBsCx89TrOhYL4tS573O/3Y2hoiF5isZi5cPoBhUmwrTOMDigsAzDVPNzZ0+itTSKbTmJsbIyq24DG4PBbv/VblIgKBoO46667MDo6etEfTkVR8MMf/hBvvPEGZLB4x30nKqwDNqVKQ3wtFis6OjroYLiWKvrJyUns37+f+uCJbdTtdiMWi9HByLSNXnuoijIOT2dRqEpw23js1GT0KYqC1157Db/85S+xuLjYJB0npMt9992Hj3/84+tGELTbntnZWTz++ONYWFho+R1RFHH69Gm6aqcfyLdu3YovfelLLQTcWqEn3YjSTQ8t6UaIt+VINz3IkLESCUdIz4v9W7SDDMuysFgsEASB1qfX63XY7XaaqWOxWMCybFNRwrVi6SFtUmRISyaTmJ+fRy6Xa1otBRqvHyHctF/J66pVLbr+8R8hHDyEstOJuihCWrKnktVSjm9k+nlrdUx3duK9O26nq69aRCIR9PT0UNLNVHCYWCvM2W11WI4MuHfEB35yPxIL85iamsLU1BQl5Hfs2IGPfOQjVN07NDS0LguK4+Pj+Pa3v41Std40uwGAJDf2JTarFdFoFBzHN81uhWwaU1NTyGazCAQCiMViTVlUU1NTePnll2kpFrGNlstlRKNRDA4OAmhkqBH1mCAIGB4exsDAwLqcMJbLZRw9ehTT09OwWq3YsmULeJ5vUtGQRmpy4kxgFiIsj+XmN6BRivZf/+t/xf79+wEAO3bswMDAAD760Y9i165dG5Kpa7Q9sizjn//5n/Hcc8+hWq3C4/Fg8+bNcDqdGBoawuHDh/Hee++BYRh0dnZiYGAAPM/jIx/5CG655RZ6/yvlsl2L5QdrRbFYpGTbwsICJefJQmk6nUY2mwUASrb5fD4IggCGYRAMBtHd3Y2uri54vV7UajVqI21HtkWjUVqSIssynT05jkMoFEJHRwd8Pt+GkeHattFisUiJMmIzLRaLkCSJllBpSTcyt5KYELJg2v3ss+iamUWG5xvOhDYUSVCWcdRux8+7osvmJlutVgwMDFCV2+Dg4HVD8ppYHibBtkFoOaD0eLHv9VfxzDPP0A/03NwcJicnkc/nIcsyeKsd2/Y+iN6hTbAwCsJsCSG/F3fccQd27tx5UfJmVVXx85//HE/+5l0cdewGq8qAXEOxUISiKLRxxev1ojPaCUnl1lRFL0kSjhw5goMHD6JerzfZRsPhMLq6uqiyxbSNXl9IpVJ4+umnceTIEcTj8abQVK/Xi8HBQTzyyCOrsmNeCmq1Gn784x/TYVELYoUpFAoYGRlpIXh9Ph++/OUvr4s1WwtJklqUbu1IN6JwI18dDsclDSTrpYZTVRXFYpG2WBF1G9k2SZJQq9XA8zz8fj+CwSACgQBsNhsCgQC6uroQDodX1T51taFWq1HiLZPJ0JYpYlPQWoUFQWgh3gb/9Xl4xsYgBwIAwwBQoaoXrKjk4qlUMO734Vf9/WBZlg5pZGDT7/NdLleTwk1ryzdhwgjm7LY2tCMDRFHE888/j6NHj6JYLOL06dO0aTsY6UTPzj2I9g7AaeFwU38QWzcNX3LA+tTUFP6Xx/4JbzGbwaoyeMhQVQWFQhEMw1CFV3d3F2oy0zS7lctlTE5OIpFIUOVJT08P3acoioJDhw7hzTffpNlEs7OzOH/+PFiWxcDAALq6uuj+bsuWLQiHw3C5XNi+fTuCwfVpocxkMjh69CgWFhbg8Xiwfft2KIqC06dP02Mmsd+T4HMCUojQ399vLt6uEZIk4Sc/+QmeeuopxONxxGIxbNq0CSMjI3j44Ycvm9VNVVW8/PLLeOKJJ5BOp+FwOLBlyxZ4PB74fD4UCgW88cYbtHV2aGgIPM9j8+bNeOihh5oWcVfKZbtWyw/WCkVRkEgkqJ00k8kAuNAoTxo9i8Vik7rN7XaDYRqNs93d3YhGo+jo6IAsy7QgYTmyrb+/H+FwmNotgUaeWyQSQUdHx7oUp6wG9XqdLhITi2kymaSkG9l+LeFGLj3PPgfXyZOoeb1QFQWiJEFaihORZAmy3Pi7grKMw1YL/kmzQG/kFNFnHDIMg+7ubkq4DQ0NXZMzsomVYRJslxlnzpzB448/TgNnS6USpmbmMOccghi9AarNC14QYLNa4GRlDPEZbBWS8LqduP3223HjjTdeVCjp//XD5/H9wyUIiohquYhqrUptS0DjQy8IApxuD1i7B//n53bh/hu629+hDsViEW+++SbOnTvXZBvleR7d3d0IBoN0B2LaRq8fqKqK999/Hz/96U8xPz+PdDpNbXYkIH/Pnj34zGc+s+GrhgcOHMCPfvQjqjjSIpvNYnJyEoODgy3bwbIsPvvZz+Lee+/d0IOcnnTL5/MtgwrQWP1zu91NxNulkm56XIwaThRF5HI5eiHXkcKBcrmMarUKq9UKh8NBbbGklZTYEbxe77o0pV5ukOaqUqnU1NKaTqebVksB4Ja330bv1DQUngPAgMESMan7WxlRxPzgIPbt3oVkMklXmAkEQWgi3YhiTns9OXGOxWLo7u42V0dNNMGc3dYXBw8exH//7/8doihidiGOQwUXMp5BKFYPOI6HxcIjYOdxcxh4eKsPWzYNX5Kl8Zn9Y/jP/3IcvFQDC4Xut9WlAY5lGtanaHcMVYXF33x6B+7b1rDJEftcPB5HvV5HKBRCb29v04luuVzGG2+8gWPHjgFoto06nU6MjIzA4/EgmUzCZrNhdHQUNpsN3d3d2LJly7qp1Ofn53Hs2DFkMhlEIhHs2LEDpVIJ4+PjlGhTVZU2emuPTWYhwsXj4MGDeOKJJ3D06FFYLBbs2LEDgUAAd999N+68887LFv1w/PhxPP744zh9+jQsFgtGRkYQDocBNI5z7777LlVkDg8Pw2KxIBAI4NOf/nRTwdZKuWzXQ/nBWlGpVKi6bX5+nmb+kqKETCaDXC4HhmFoM6nf76dlAtqiBJ7nqbJtfHzckGzr6OhAb28vff0ItA31l3tO0c6pJNYjHo8jk8lQi6miKNi9bz+6JiehcI3CtMb0BqpkI3ltrCzjiNuF79tsyxaZcRzXEs2i/0z5fL4mws20lV4fMAm2K4B8Po+///u/x/j4eCO81r4DcS4MWVGg1ktQwYDhBHBWBwSOQYwvYI91Ghyjwm6349Zbb8Utt9yyph3UgXMp/PE/vYVKMQ+pVkK91gjn1r/8nNUB3mrHYPx13H/TMD75yU9ieHh41Y8zOzuLffv2IZPJIJ1O4+jRo7ThJhaLNbUTmrbR6welUgnPPvss3nzzTVqZTUBe+8997nO47bbbNnSgWVxcxOOPP46ZmZmW60j2i9/vR09PT8t27N69G7//+7+/7s1wy4HI97VKt+VIN63SzW63b/hwKEkSqtVqC/FWLpepnXJxcZGGxQLNlexAY6hyOp2w2+0tNlmn02mYDae/XM3hxJIkUbVbPB7H4uIixFdfhf3kKYiSCFmS6W0ZlmmcILIcPVHMDw8ju20rWJaFKIpIpVKUgE2lUpSwJtCSbWSFVPs+CIfDVOUWi8VMW+kHHObstv6Yn5/HU8/8DP+66Ma05IYky1DrFTAsB1awQGUFWHgOo0EGvzMCBLweDA0NIRqNrnmffeBcCl978n2UCzlAqkFVVYhiHTXdDMda7HB4fPjf7u/H5+7e1XQSl0gkqJLb5XIhGo2is7OzaVsWFhbw0ksv0bgFrW2UWF+Bhm20p6cHQ0NDsFqt2Lx5M2Kx2Loci1RVxcTEBI4fP45qtYpYLIbt27cjk8ng9OnTTRmhJFtJqyY2CxEuDslkEv/4j/+I/fv3I5PJUPViNBrFJz7xicu235iensYPf/hD6kYg20EUUbOzs5icnITX68WmTZtgs9nA8zw+9rGP4YYbbqD3s1Iu2/VWfrAWqKqKVCpFybZkMgmgoXrL5/NU3VatVuF0Oinh5vF4wLIsfD4fJdvcbjfGxsZw8OBBQ7KN53mEw2H09PQgFAo1iUO8Xi86OjoQDoev6PMvSRIqlQqd3/Iv/Hfg4PsQ6yJESaTkGcuwYDmWzm6KqiDe3Y2JaBSJRAILCwt0jhdFcdn9ocViaVtCBlywlRLCbSMaq01sPEyC7QpBURQ8++yz+O4rpzFhGwarSlDAQIQAlVlirlUADMAzwI2WBdxgTdDft1gsuPnmm3HbbbetKpiUVNFPxPNAOQNFVSCJEmr1GiRRurATcQWgFpPwvPv/wO9xweFwYGhoCPfffz/uu+++VT2Woig4fvw43n33XVSrVWoblSSJ+vwJqWbaRq8vjI2N4Sc/+Qmmp6dbShACgQBuuukmPPLIIy0rW+sJURTxzDPP4JVXXmm5TlVVTE5O0qY0vRo0HA7jK1/5CmKx2IZt30oQRbGJcMvn8032WwKO45oIN7fbfVlINz2IlXR6ehrT09OYnZ1FPp+nMn1SJV8qlWCz2eByNfYrpPXJ4/HA6/XC7Xa3XbUjSq7lLlar9apajSalCtlsFgsLC1hcXGyyKWgPvRaLhdpLiSKEXE/yU0hOnP69sJK11OVyUUtpLBYzbaUfMJiz28bg714ex7d/MwZVEqEyDKoKC0VlGmMbw4BhAJ4FHuxj8Vuxxn6JzFPd3d2r/gyS2W06XQJTyTU1TkuSiFqt3sh3tHuhFpLoHnsad9x2Cx588EHceuut8Hq9ABqLYJOTk8hmszQjKRaLNSnQVFXFiRMn8Prrr6NcLjfZRgGgr68PsViMRh5s374doVAIXq8Xo6Oj9LEuFbIsY2xsDGNjY1AUBUNDQ9i6dSvi8TjGx8ebiDan0wmLxUKjVgjMQoS1QZIk/OIXv8CLL76IiYkJmnNnt9tx880347777rsspTuJRALPPfccfvOb36BYLKKnpwd9fX0AGqSvJEkYGxuDw+HApk2bqGJx586d+K3f+i167Fspl+16LT9YK2q1GhYWFjA3N4e5uTk6X1QqFdpKWigUaJQQUbeReYU4E4LBIM6ePYtDhw5hbGzMkGwj53/BYJC+biT/raOjA4FA4KqZTchCcTqdpoVYZH4ji8faUgVBEGiOWyqVoqVXxJLK83zb/dBK1lJiKyWEG2krNfdrVzdMgu0KoirK+MTf/gYTqSKgAhIjNIYzVYGqKoDKACwLMAw4uYbP2Y7BbWtm+nmex+7du3H77bevONyQKnoOMsrZFD2Ba1QZ11BTOKicgMrBX0A99RuqkvF6vXA6nXC73bjtttvw0EMPYevWrStKxyuVCt5++22cOnWKBmZOT0+D4ziaz2TaRq8/1Ot1vPDCC3jppZeQSqVaShCi0Sg++clP4iMf+ciG2g/ef/99/NM//ZOhIiyVSmFyctJwpZvnefzu7/4u7rzzzqvmACaKYovSzYh043neUOl2uaG1UCaTSSiKAkmSkE6nkUqlkE6nIYoibDYbbDYbLBYLZFmG1WqlhNtarfDtmlKvJjWcoigol8vI5/OIx+OUeCN5KFrbE2m9IoG8QOOkoF6vI51O0zKNarXaMpQuZy3leb6lrfSDenLxQYA5u60/tKRXrS6iJDYWKBmoUJVG3iKZ3eyMgv/zQ3YIqkSPgzabDYODg4jFYqs6BpLZzcazqBYyLfv+mspCVBjUj/wKtSMvIBqNIhgMYmBgAHfffTduv/12bNq0CYqiYGpqCqlUCvV6HV6vF93d3S1zV61Ww5tvvomDBw/SOAFiG7Xb7RgeHobP50M8HofT6cTo6Cjsdjv6+vqwadOmddvHVqtVHD9+HBMTE+B5Hlu3bsXw8DDm5+cxPj6+VNbVADnWZbPZppgIsxBhbThy5AieeeYZqkiKxWIIBoPwer146KGHsFnXmrgRyOfzePHFF/HSSy/h/Pnz6OjowMjICFiWxeLiIjiOw+TkJBRFwaZNm6gNNBKJ4NOf/nRTftxyuWwfhPKDtUBVVWSzWUq2JRIJqKoKSZLoQmkmk4EkSbDb7bQowefzgeM4mrsdDocxNzeHgwcPtiXbiBIuFArRojGieOvo6Fg3sn49oSgKKpUKCoUCdSuQUgVtCz3HcbBarWBZFtVqFblcDvPz84jH4ygUCrTxlLSbGmEla6nP52si3Exb6dUHk2C7gjhwLoX/9NQhlOsichURqqJAXWpMYlkGqoqlnA0WYFlE8mO4CWcNcwNYlsUNN9yAO++8s204qbaKXpQkiKU8oMhQwEBiBUBV4cifR+X1J5BOxCHLMrxeL1iWhSAIVGUiCAL6+vrw0Y9+FHv37m3KPzBCIpHAG2+8gXg83mQbtdvtiMViTTtX0zZ6/WBmZgY//vGPcfbsWSQSCXpSwDAMvF4vtm/fjt/93d+lq5MbgWQyib//+7/H5ORky3XVahUnT56E3+9HX19fy2fqzjvvxO/8zu9cVObh5UC9Xm9Ruhnlz/E8b6h0u1yQJAnJZJISbuSkqFar0Ty3YrEIt9sNn88Hp9MJnudhsVio2k0QhCar6qU0pV6Najiy8plMJumqOxnG9OUYPM9Tawxpfc3n88jlciiVSqhUKmBZtmXFdDlrKbFxxGIx01Z6ncGc3dYfZHarSQpyFRGAClkSARX0JEdVFShLs9s2IYH/eFvj5DMejzcFgA8MDKCvr29ZUko7uymqCrlahFitXpjdAITERQwk38L4qROIx+MIh8PUlk+slrfeeituvPFGiKKIubk5lMtl2Gy2lkIqglQqhZdffhlTU1MAmm2jwWCQxockEgn09fVRxdPWrVvR1dW1bs93Pp/H0aNHMTc3B4fDgdHRUcRiMczNzbUQbT6fD263G9lstimqwmazYWBgwCxEWAWSySR+9KMfYWxsDIuLi/B4PIhGoxAEAVu3bsWDDz54ye3rK6FSqeCVV17B/v37ceTIEbjdbmzZsgUWiwULCwtN88/IyAglia1WawsRuFIuWzKZRDwev+7LD9YKURSxsLBA89tKpRJ1LBB1G2mb93g8VN3mcDjgcrko2ZZOp3HkyBGcOnWqhWwjv0tuS2YPq9WKjo4OdHR0XPW5iqIo0vmNkG4k6kM7k5OF5HK5jMXFRSwsLNAMX0VRwPM8jQ4xmj+Xs5aSYwkh3Exb6ZWHSbBdQfz6xCL+6mdHkK9KEBUFiiRCkRWojRS2JZtB48Ojshy4SgZ7iq/BZbfS/CI9GIbB1q1bsWfPHkPiS1tFP58pIVcoAooMq1pFZ30OvbVJMKqMQqGAc+fOYXFxEYIg0IGQnKx7PB5wHAefz4edO3fivvvuw44dO9ruCFVVxfj4ON566y2USqUm26jP50MsFqNEhmkbvX6gKApeffVV/OpXv0IqlUIqlWp6L0UiEdx///0tbVDrCUmS8Nxzz+HFF1803L6zZ8/SljT9NnR3d+MrX/lKS4bH1Yp6vd6idDMi3QRBaFG6XS4VU6lUomRbPN4g8mVZpiRRPp9vCtv1er1wOBwIh8OIRCJU+bpSQUO1Wl2xKbUdrgY1HAnlzeVyWFhYoMQbsR3og3UJKUjUjmT4rVQqUBQFFosFgiBQAmA5a6lpK71+YM5u6w/t7CYpKni28blTFBmq2shaZBkWKgAFDFyo4FO2k4gE/bjnnnvAcRymp6ebCoH6+/vR39/fdkFHO7vF81UUSiWItVrT7MaikVWVSqWwsLCAYDBIT4gVRUFnZyeGhoawY8cObN68GTzPo1Kp0LiBWCxmOMOdPn0ar7zyCgqFQpNtVFVVxGIx9Pb20tKeHTt2IBQKIRQKYfv27auKFVktkskkjhw5gmQyiUAggBtuuIGqZcbHx5sINZ/Ph0AggHw+T3OmgAuFCIQMNGEMWZbxy1/+Em+88Qad24LBIFwuF+x2O+69917cfPPNG7oQJYoi9u3bh3feeQfvv/8+VFXF1q1bYbfbkUwmkUgkYLVasbCwgIGBgaZznltvvRX33HMPPW6tlMuWz+cxNzcHt9v9gSk/WCvIczQ3N4fFxUUoSqN0heS25XI5KIoCq9VKyTav1wubzYbOzk6Ew2EUCgWcPHkSY2NjLYuHDMPA7Xajs7MTnZ2dlDR1uVy0HOFqXfDWQ1VVVKtVZLPZJjdHOp1uyo0EGnN7JpOhM142m0W9XqeEG/mqn8GWs5Zq8yhJgYK2bNDExsMk2K4gDpxL4Y9/+D6yFREcw4BjGx80camJDmg0mDBLVgNGrMJ95Cns6m7sbPx+/7LtJSMjI9izZw+6u1vbQEkV/Vwig+f/5WdQkhPg0HpfpGVGkiRMTU3RXA6O4+ByueDz+ajaJBaLYc+ePbj11lsxODhoaH2o1+t49913cezYMaogmp6eBsuy9KBHdiKmbfT6QSqVwtNPP41jx47RHAMCp9OJ4eFh/O7v/i62b9++Ydtw7NgxPPHEE035LQTxeByTk5M020ALm82GL37xi7jppps2bNs2ErVarUXpps3xIRAEoUXpttGkGzkZJKt5+XyetnUSsq1cLtN9DamSDwQCiEQiiEQibavfL6Ypda24Emo4WZZRLBYRj8dpyQSxKRi9roSMTKfTSKfTqFarkCQJPM/T7BBCrLWzlpq20msX5uy2/mid3RioUFGv1SArCqAu5bCxLFSGhaDUcKt4FCM+BizLYtu2bfjwhz+MdDqNyclJ2j7McRx6e3sxMDDQ9vNFZrd8VcT7B/bh9Fu/bpndFEVBLpdDrVbDbbfdhlKphLGxMWSzWciyDJ/PRy2qoVAInZ2dVC1MToT1+yxJkvD222/jnXfeocpZYhu1WCz02Lm4uAi3243R0VE4HA4MDg5iaGho3eIgVFXF9PQ0jh8/TsmSXbt2we12Y35+HmNjY01Em9/vR2dnJyUH9IUIw8PDV6Ud7WrB0aNH8fOf/xzJZBKVSgUOhwNut5u+Vx9++OEVXSyXAkVR8O677+Ldd9/FkSNHkEgksHXrVng8HmSzWZoXl81maVs5ee/29PTgU5/6FFXbrZTL9kEuP1grJEmiM8jc3Bzy+Tzd72SzWaTTadTrdUqaEcLN5XIhEAggHA6jXC7j3LlzhmSbqqqUWOvu7qZzud/vR0dHB0Kh0GVruF1PyHJDwJJIJOjslkwmkcvl6HGA5DCTRWhyzkTcZFrSTf8cLGct9Xq9TYTbaiMKTFwcTILtCqIqynjgG6/jfKoEgWPBMgxUVYEoSo3Q2qUadoblAUUGI1XAHXwa1uQ4+vv7MTo6Cr/fj2g02miua3Oi2N/fjz179qC/v9/w+nw+j2984xuGzYsETqcTe/fuxcsvv4y3336bZkAxDAOXywWv10vVP+FwGJs2bcLevXsxOjpqePDNZDLYt28fZmdnm2yjVqsVPT09VCZMbKN79uwxD3bXOFRVxXvvvYef//znNDSUHFBYlkUwGMSHPvQhfPazn90w+0Emk8H3v/99nD59uuW6crmMEydOIBAIoL+/v2W16N5778VnPvOZ6+J9SGyFWuLNiJyxWCwtSreNDDquVCpN6jZRFKkii9hJiXJW225FyLZgMLhmpRVpkdpINdxKJJxRdftaUavVkE6n6eoyGcyIrYOAnByTQTiXy6FSqVCFGyHdBEFoWiHVW0u1ttKenh5zIeQqhTm7rT+MZjcANGuStntyHBhFgQAJWwoH0cNm0NnZCVVV4XA48KEPfQi7du3CzMwMJiYm6D6YZVn09PRgcHBwWXuUqqp4/vnn8dxzz7VcR1oBRVHExz/+cdx66614/vnncfjwYaRSKYiiCJfLhd7eXpqzOzAwgN7eXnR0dCAWixlaKXO5HF599VWcOXMGQLNt1Ov1YmRkBEBDbTY4OIjBwUG4XC5s3759XcuNiPr8xIkTEEWRzsRWq5VmtOXzeXp74pQgDgp9IcLw8PCGEkXXMtLpNJ588kmcP38e1Wq16TjB8zzuvPNOfOhDH9rQ2ej48eM4cOAAzpw5gzNnzmBoaAiBQIC6bchjk+gZMgfY7XZ88pOfxMDAAL2v5XLZSPmBKIro7e01F5JWiWKxSMm2hYUFSJKEcrlM1W2kad5isdDZjdi5w+EwqtUqpqen6b5EC1mW4XA40NHRgb6+PlqGEAqFEIlErovA/3q9Ti2m5PwolUpRS22pVKI5bqlUCplMBrIsU9KN5/m2FtN21lLy/ieE20rHGxNrg0mwXWH89XPH8M8HzoNlGAjchRNDRVVQr9VBonOVahFQRNReeRxcegIcx8Hv9+Pmm29GJBLB6OgoQqEQrTc3Qnd3N/bs2UMHIC3K5TK+9a1v4ezZs2231W6340/+5E/AcRyefPJJvPzyy0ilUjTc0eFwwOfz0QOSw+FALBbDzp07sXv3bmzdurXlw3vu3Dm8+eabyOfzTbZRYlcg92XaRq8flEolPPvss3jrrbeQTqeRy+Wawp97e3vx+c9/HrfeeuuGHDQVRcGvfvUr/OpXv2ohTWRZxunTp1GpVLB169aW4WpwcBBf/vKXr0syoVqttijdjEh70vypJd42gnRTVRXpdJoSbplMBoqioFgsUrKtVqtRdRu5EBtpJBJZt+GYyP1XIuIkjfp4rdgINZyqqigUCjRHZWFhgdoUtCeYxOZBVIP5fB6SJDU1mwqCAIvFAp7nDa2lTqezSeHW2dlpro5eBTBnt41Bu9lNhQpFUSCKMlSoUGtlMIqEwOlfYsTbmKN6e3tpBEgsFsP999+Pzs5OTE9P49y5c015pcTm006pCwCvvvoqnnzyyZbjmaqqyOVyEEUR99xzD/7kT/4Ek5OT+PWvf423334bi4uLqNVqsNvt6OjogMVigdPpRF9fH7Zv347R0dG2mb7nz5/Hyy+/jHQ63WQbVRQF3d3d6OvrQy7XaD0lttHOzk5s27ZtXUmLer2OU6dOUcJv06ZN2LJlCziOw8LCAsbGxlqItv7+fqrA0xciDA8Pr6nl9YMCWZbxr//6r3jzzTchiiK1orEsS8mOhx9+uO1C/npgYmICr732GmZmZjA+Pg6/34+uri6Uy2WcOXMGtVoNPp8Pqqqiv7+/6fhD8p3J8XO5XDZVVakqq7u7+wNffrBWKIqCRCKBubk5zM/PI5PJQBRF5HI5SriRWYlk7xI7aTAYRK1Ww/z8PCYnJ5vINlK4QMi2gYEBBAIBCIJAF1mvp9dKVVXk8/mmXF6idhNFkT7HJN+c5OER14GWdNMWKrSzljIMg2g0Sgm34eFh01Z6CTAJtiuMXKWOe/6/ryBTFsGxjXIDBmhkdyiNQU2tl6FIEuR8HPnn/g+wqgyO4+hqwODgIIaHh9HT04MHHngADMPgnXfeMbTCAY2mnT179mDr1q1NH5x6vY7vfve7OHbsWNvtFQQBf/RHf4TR0VEUCgU8/fTTeP755zE7O4tSqYRyuQyr1Qqfzwe73b5U2MAiGo2ir68Pu3fvxvbt2zEwMEAPfpIk4dChQzh06BDK5TK1jTIMg0gkgmg0Sm9r2kavH5w6dQpPP/00PUCQQZfkb91666145JFH1nXVW//43//+95uGb6BxUJufn8f58+cxPDzc8vhOpxNf+tKXNtTOerWgWq02EW6EeNHDarVSwo2QbuudlVGr1WhzEzkxJA1NpCiB4zh4vd4mso1cvF7vhg8KG62GY1l2VdlwKxFboihSewcZ3hKJBF0tJYNdMpmkKrdyuQyO4+hqqMVioRftyij5arFY0NXVRVVu3d3dZubRFYA5u20MVprdAEBgZNQqZcj5OJR/+//BYbOgo6MDwWAQ/f39cDgcNAvxpptuwt133w2r1YrZ2VmaDUpA8tPa2RnfeecdPPHEEy2xIYRkJ3bRP/uzPwPP85ifn8frr7+Offv2YW5uDsViEYIg0KIZq9WKrq4u3HTTTdi1a5ehSkSWZRw6dAhvvvkm6vV6k22U53kMDg4iGAxicXERPp8Po6OjcDqd2LRpE/r7+9d1f1wul3H06FGqTNq+fTsGBwcBAIuLixgbG0Mul6O3J2o7SZJw5syZptgK0vLa19dnFiLocPToUTz33HMolUpUyVar1ShhvHv3bnzsYx/bMOXXwsICXn75ZczOzmJqaooSZPV6HWfOnEEul0MwGATP8+jp6WlS1Q0ODuITn/gEPQ6tlMtGyg/C4fCGzaEfBFQqFapum5+fp7ElhGwj+znSLkrspD6fD/V6HfF4HLOzs00FCSQCxG63IxqNUkUjWSyIRCLX7bxBnhOtUyGRSCCTyWB2dhbz8/M0NkRVVXoezvM8VbtpSTeGYdpaSz0eTxPhZtpKVw+TYLsK8N3XzuKbvzmDutQoOFBVgGEAnmXhcwhQJBHpXAE4+W8ovv8L1Ot1qKpKSTZRFOHxeKi0/8Ybb8RnP/tZlEolvPnmm01DhRaBQAB33nknduzYQT8wsizjH/7hH/D222+33V6WZfEHf/AHuOWWWwA0Tir/7d/+DT/72c9w+vRplMtlakvyer1wuVx0kPJ6vYjFYhgYGMDo6Ci2b99OZfmFQgH79+/H5ORkk21UEAR0d3cjGAwCMG2j1xPq9TpeeOEFvPTSSzS3gZwgCIKAzs5OfPazn6XB0OuNQqGAJ554AidOnDC87uTJkwgEAhgcHGxa0WYYBg899BAeeuihD9xKN6kp1xJv7Ug3vdJtvUg3oswgZFsqlYIkSU1WUrLSSQa1YDCIaDRKh+Urte+4mtVwuVyOWkwJ8ZbJZKCqKrUwJJNJZDIZavkgg5rNZqPEGrEWOxyOJmtpOBymKrdYLGYulFwGmLPbxmGl2Y3nGJSqIjrTR3D+3/6BKgyIaqO7uxuRSIRaRqPRKO69915s2bIFDMNgfn6+hfwJh8P0ZFKPY8eO4e/+7u9alMdakm3Hjh34q7/6K6o6TqfT2L9/P43sIPvSQCBALfck54zYPPVOhFKphNdff50eR7W2UaIKYxgGqVSKWpG8Xi+NOVlPZDIZHD16FAsLC/B4PNixYwfNISaKNiOijWVZnD17FqlUil7H8zx6e3vNQgQd0uk0nnrqKczMzIDjOIRCoSbFu8vlwv33348dO3ZsyONnMhm8/PLLOH/+PFKpFLLZLLq7u6EoCs6dO4dUKgWn00kX2bRzh9vtxqc//Wn6nlgpl41k97lcLnR3d5tqnkuEqqpIpVKUbEsmk1RBT1T0hJYgryEh3CRJotllWuWpLMuoVquw2+3o7u6mgf7a+JAPAlGuLUogz9P4+Di17RLLKSHcCLGmt5iSBnoja6lRW6lpKzWGSbBdBSAV7K+OJVCXZdgEDhaOBc8xKFQlsAyD2/o8cJ38Bfa/8Rpl8mVZph8UbR4ayWW799578fGPfxyFQgFvv/020um04eN7PB7ccccd2L17N3ieh6qq+PGPf4xXXnml7TYzDINHHnkEH/7wh+nPVFXFwYMH8dRTT+Hw4cMolUoolUoQRZEGo5IDFyHNYrEYtSMQC+nMzAz27duHdDrdZBt1Op3o7e2lH2bTNnr9YHp6Gk899RQmJiZobhSBy+XCjh078Hu/93vo7e1d98dWVRUvvvginn322ZbVf1EUMTY2hnq9ji1btrQcSLZu3YovfelLG15Zf7WjUqlQhVuhUEChUDAkiGw2Wwvpth6DD5HLLywsYHFxEeVymRYlEPUVy7Lwer10WCMnt5FIZF2b7tYLV4sajig6SZMpacOqVCrIZrNIJBJIp9PIZrO0kZDYFAjpRi5ut5u20Nntdni93qYcN9NWuv4wZ7eNw2pmt7s3hfG/P7wV//qL5/Ctb30Li4uLkGWZNu0RIouc0Hi9XuzevRv33XcfbQGNx+NUnUPg9/sNFdZnz57FN7/5zSa1BwHJzh0ZGcFf//VfN+33isUi3nrrLRw4cACTk5OYn59HqVRCJBKhZATJ1x0YGEBXVxdCoVDTQsX8/Dxeeukl2jCotY2S/KRCoQBRFHHDDTcgFAohFoth8+bN6654np+fx7Fjx5DJZBAOh7Fz505KShJFWzabpbf3eDzYvHkzbDYbzp49i/n5ebMQYRlIkoTnn38eb731FliWRTgcBsdxlKAFGkVrDz/88IY8Z6VSCa+88grOnj2LYrGIbDYLr9cLlmUxMTGBRCIBAOjt7W2KrgEax76PfOQjVCQALJ/LRvLBeJ43yw/WGbVaDQsLC7SdtFQq0XzYTCZDcymJO4FYSUlBViKRaJrbydxEsviGhoYQCoVoMVYoFPpALYrX6/WmBdPTp0/j2LFjmJ6eRiKRQC6Xa+IStGo3EgmiLVXQz4xWq5UqCIeHh01bqQYmwXaVQF/BLisqOJZBxGPDwzu78IXbe8GhQXz9y7/8C86ePYt0Og1JkmiooSRJqFarsFgsYFkWDocD/f39uPfee/HRj34UxWIRx44dowcePZxOJ2677TbcdNNNsFgs+MUvfoFf/epXy273Jz7xCTz44IMtH6bJyUk888wz2L9/P22TqdVqtH1Ue4AKhULo7e1FJBLB0NAQtm/fjr6+Ppw8eRLvvvsuisUitY2S23d3d9P7MG2j1wcURcErr7yC559/HplMBslksqkEIRQK4eMf/zgeeuihDcn8Onv2LL73ve+1ENGktWxmZgZDQ0MtNgKfz4cvf/nLGBoaWvdtulahqqqh0k2b+0WgJ908Hs8lD7D5fL6pGl1rJS0UClAUBTabja6ORqNRRKNRGph7rQxgl0MNp1/FJBdRFOmKMxl0yYp0KpVCOp2m6kaWZelxymKxUMLN6XTS19xut1OVgDbLbSXlSP6FF1B8/fVV/S2uvXvhuf/+i34urkWYs9vGYjWz2/+fvf+OjvO+83vx1zQMpmDQgUHvHQRJUezFkqhCUb1413Icy7YiyZuNt9zk3v3lOifJHzfnJif3ZO+enJxYlr1x2SJZXTYtWTYlFrGToEj0NgBmBgNMxWB6n98f2Od7MQRAUSIlkdLzPof/PBxgKub5PO/Pu2jVy6SxzWbj//q//i/Onj1LNBpFrVZTUFAgiCap5dNgMFBdXc3dd9/Nzp07xfnO4/GsUllJzXCVlZViDrPb7fzN3/zNqvgDWCbSotEo9fX1/Mf/+B9FoZR4PvE458+f5/Tp00xPTzM9PU0gEMBsNtPY2IjRaBQXVRI5Xl5eTlFREQqFgmw2y+DgIB9++CHRaDTHNqpSqWhoaKCsrAyXy0VJSQm9vb2YTCY6OztvuEIom80yMzPD0NAQ0WiU2tpa+vr6BLEoKTwWFxfFzxQUFNDR0UFhYSEWi0XYECXIhQi5GBwc5K233iIej2MymaiqqsLlcolWyby8PO644w527tx5w8+riUSC48ePMzo6KpZqEhlgtVpZWFggkUhQX1+P2Wxetczr7Ozk4MGD4u/rarlscvnBZ49sNovf7xdkm9vtJhQKCXXbSiWvTqcTJVdSoYvH4wEQnzMpRkSn09HQ0EBLS4sg2SorK8V31lcN0us8Pz+P1Wrl4sWLDA4OMjMzg9frJZFICNJyJfEmWUxXKt40Gg16vT5nNpSWPx9nK/2yz24ywXaTQapgD8ZSFOSr2VhXRL4m94N5+vRpfv7znzM7O8vMzAzhcFg01Wm1WkKhEAqFgnQ6TTqdJj8/XxBtmzdvRqFQMDExkTOkrUR+fj5bt25l27ZtnDx5kl/96ldXfcz79+/n61//+ppfVG63m9/+9rd88MEH2O12kskkiUQi549TgrRxqKmpobCwkK6uLpqbm7FYLIyPj+fYRlUqFdXV1aJOXraNfnng9Xp5+eWXGR0dxev15mzt8/PzaW9v51vf+hbd3d03/L7D4TA///nPuXTp0qr/8/v9jI6OipPHyhOGUqnkiSeeYP/+/V/JE/a1IJvNEolEcki3YDC4Jumm0+lWKd0+7d91KpUS7UxOp1Pct0S4JRIJYdsqLi6mrKyMhoYGUZbwWbamfl5IJpMfS8LF4/HrVsNpNBrC4TChUEhkrNjtdrEplYoppOFNygCRLKbS+15YWEhhYSF6vZ7q6mrq6uoE6XalNc7xwx8SeOddlB+jgMkkEpjuP0D1f/pPn+o53qqQZ7fPB9cyu8HyIukXv/gFf//3f8/8/LxQ51dUVLBp0ybi8ThWq1UEtnd1dfHAAw+wYcMGceG4uLjI5ORkzrLUaDTS0tJCdXU1CoUCt9vN//v//r/ionMlpBiPyspK/uN//I+rlkaw/L156dIlTp8+zfDwMLOzswQCAUpLS2lvb6e2tpZkMkl+fj61tbViQSFZSGOxGCdPnuTSpUvigk6yjep0OlpbW1Eqlfh8Ptra2kReW29v7w1XhKfTacbHxxkdHSWTydDS0kJ3d7dQzblcLsbGxlYRbe3t7ZSVlTE7O8v09LRciLAOFhcXefnll5mbm0Ov19PU1EQymRQqSICqqioeeeQRqqqqbuh9p9Npzp49y8DAgJgVpeufubk5HA4H0WiUsrIyuru7V7kUSkpKeOyxxwRhujKXTfpMS5DLDz5fJJNJUdDkcDjEQk/KbpNmR6VSKeYGadbx+XyiYVNaREpkW1NTE62trZjNZioqKqisrLxqicxXBbFYjLm5OQYGBjh//jxDQ0PMzMwQCoVIpVI586HkUpBe45Wkm16vzyHdCgoKaG5uXmUr/bLPbjLBdovCbrfzwgsviBP//Pw8iUQCrVZLYWEheXl5/9RklRR1vlI71O23305XVxc6nQ673Y7P51uTXdZoNGzZsgWFQsGrr7666sS0Ejt27ODpp59ed9AIBAIcO3aM3/3ud0xPTxMKhchms2g0mlU/I2V+1NXVUVhYSGVlJWazWQRur7SNSqScNJCVlJRw//33y7bRWxzZbJbz58/z5ptvCnXMlSUI+/fv58knn7zhJ8ZsNsuRI0d49dVXVyl/EokEIyMjJJNJurq6VlkLN2/ezNNPPy3ntVwjPg3pJhFun5Z0C4VCuFwuFhYWxIZUas+Uvpfy8vJEdltDQwPV1dWioerLSqB+Vmq4bDZLOBzG5/MRCoXweDw4HA48Ho/IhcpkMiIkW7IiKJVKtFqtaKeWyiskJY2knuF//A/C73+A5p9sbOshOTeHcf9dt9yQdr2QZ7ebE5cvX+Z//I//weXLl/H7/SiVSkpKSti2bRs9PT1cvHgRm81GLBajtLSU7du3c+DAAZqbm8X3XiAQEHZGCTqdjpaWFmprawkEAvzN3/wNDodj1f1LJFtxcTH/4T/8h3XjF7LZLKOjo7z//vtcunQJj8dDIBAgPz+fjRs30tXVRTAYJJFIiFKq0tJSysvLKSsrE3lZdrt9lW1UWmZIquK+vj7Ky8tpamqira3thtvFY7EYw8PDWCwW1Go1XV1dIn8NlhfCY2NjOSr2goIC2traqKqqwm63C0uiBLkQYRmpVIrf/e53nDlzhry8PKEKt9lseDwescjasWMH+/fvv6GL8Gw2y6VLl+jv78fr9Yo2XJ/Px8LCAjabjUgkgsFgYOvWravOYWq1mvvuu4++vj7x+66WyyaXH3wxkPLwpPw2iWxbXFwUzcuAuA5OpVKEw2GCwaAok4Hl775oNIper6e5uZm2tjbq6uoE2fZlWKreKKRSKSwWC2fPnuXcuXOMjIwI0vrKOX2l2k0i3aR/BoMhJx6ksbGRu4ZHKJmaIm+FI20t3Kqzm0yw3cKIRCL87Gc/46OPPsLn82GxWFhcXCSTyYiNaGFhIVqtFpvNJqyaJpOJxsZGGhsbaWlpwWQyCVuPXq9fdQEpNYlcvHjxqgPPxo0befbZZ686ZMRiMU6fPs3hw4eZmJjA5XKRzWaF+i4Wi+X80ZpMJurq6qiqqhKecKmpdHR0VNhGi4uLqa2tFRtJ2Tb65UAoFOKNN97g/PnzLC0trSpBqK+v55vf/Cbbtm274cSH1WrlxRdfxOVy5RyXbCdzc3O0tLRgNptz7ru8vJznn39++eJfxieGRMhcSbqtRfDr9fpVSrdPclGWyWRy1G2SrVEi3FKpFAqFQoSSV1ZW0tzcLGwGX0W17LWo4VYOu+shkUjg9/tZWFjAbrczNzeHy+UiFAott2dnszmZIJL1RyrPkPKrHpy10uR2kS1fDrPOW2NpA7fukHa9kGe3mxder5ef/exnnDhxgunpaaLRKFqtlubmZp599lk8Hg/nzp1jdnaWeDxOZWUl+/btY9++fdTU1IgFTygUwmKxMDc3J1QG0u8pLS3lf/7P/4nFYll1/7FYjGAwiNFo5Ic//CEdHR1XfbyTk5O8/fbbjI+PEwqF8Hq9qFQqNm3aRHd3N8lkkmAwiE6no7q6WvyNlpeX43Q6OX78uCDVJduoQqGgvr5e2EbLy8vp7e2lqKiI7u5uzGbzDX/dA4EAAwMDOBwO9Ho9vb291NfXi/P4WkSb0Wikvb2d6upqXC7XmoUIDQ0NNDc3f6UXbIODg7z99tukUinhRAmHw8zOzopWw6KiIh566CFaW1tv6H2Pj49z6tQp0WJbW1vL7OwsTqeT2dlZwuEwKpWKbdu2kZeXt4po27hxI/fee684r18tl00uP/hikUqlRBO6w+EQhUyLi4sEAoGc/ESj0UgqlSIUChGLxVCpVOK7UyLbjEYjTU1NImOysrLyCy3DupkRDocZGhri3LlznD9/nvHxcQKBAJFIZM3l+MoZTqPRkJeXh1ar5blEkr54nCWtFs0/uRi0Wi3afC15mjzxN3Wrzm4ywXaLQwpof+ONN0gmk9jtdqxWq8giKCwspLGxkb6+Pmw2G+Pj4/j9fpLJpFBnSBlEZWVlIqzaaDSuCp31er1MTk5SUlKybv5Ae3s7f/qnf/qx+QTJZJKLFy+K5qm5uTni8ThGoxGj0UgymczJD1Gr1SKbR6vVCtJDp9NhsVgIBAIolUoh6ZZsR/v27WP37t3yl+QtjtHRUX71q1/hdDpxu92i1huWN8w7d+7km9/8JmVlZTf0fmOxGH/3d3/HuXPnVv2f1+tlbGyM4uJi2tracj5jGo2Gp556it27d9/Qx/NVxZWkm6Q2W4t0MxgMOYTbJyHdotGoINucTid+v19YG6XQcI1GQ1FRkWiXra2tpaKiQm5SWoFPq4bLZDL4/X6sVitWq1XYS6XbrhxXpKHtLzV5bFcq8SqVqNQq1Cq1UL5p8/PR/9PF7q06pF0v5Nnt5kYikeDQoUMcOXKEwcFBUYJQUFDA448/zr333sulS5e4ePEis7OzJJNJqqqq2Lt3L1u2bBFEllKpJBqNYrFYsNlsOYuo6upq3n33XcbHx1fdv0Sy5efn81d/9Vds3LjxYx/zwMCAcCNEo1G8Xi9qtZr29nZ6enrQaDRiKVJcXIzZbEan01FUVMTs7CyDg4Ok02kWFxeZmpoSS1NJnbe4uEh7ezvNzc2YzWa6u7s/k+9Xj8fDpUuX8Hq9lJSU0NfXl2MJ9Hg8jI2N5RBpBoNBWGT9fv+ahQhSk+FXtRDB5/Px6quvYrfbMZlMImh+dnaWubk5Mb9t2LCBgwcP3tD31m63c+zYMRYWFshms3R3dzM3N4fFYmF6elqomSTy9soykIqKCh5//HGxnF+Zy9bQ0JBjX5bLD24ehEIhQbZJbcgS4bayVVmj0ZBOpwkEAsTjcfLy8jAajSiVSkG2mUwmmpubaW9vp62tjcrKylsqm/fzRjqdxmazMTo6yrlz5+jv78fpdBIMBoXS7UqqSaFQ8BdqDduVStz/5F5Q/ZPlVMrr1el0VFRWkrpFZzeZYPuSYHx8nB//+McEg0HC4TDT09M4nU6SySRarZby8nIOHDhAWVkZFy5cYHJyUoSQFhYWUlVVhclkwmg0UlJSIlp5pNY36QLV7/dz6tQp9Ho9ZrN5zRNjQ0MDP/jBD64pRyOTyTA0NMSJEycYHh4WllWDwUBlZaWodF6piCgtLaW+vh6DwcDMzIwYJL1eL9lsFq1WK+ylsGwbPXjwIG1tbTfipZbxBSGRSPDb3/6Wo0ePEgwGc0oQVCoVlZWV/NEf/RF33nnnDbWWZLNZTpw4wUsvvZRzoob/z3KSSqXo6upa9ZnftWsXTz311A1vSJOx/L5IWV+S0u1qpNtKpdvK77Sr/X6fzyfItoWFBaFuW6moMxqNFBUViYsqs9lMSUmJvNG+BlyLGi4SieDxeLBarUI5KllCUqkUf6HRsFOpwrlylFGAAgUlJSWiAVEm2OTZ7WbG+fPneeedd5iammJ0dJSlpSVUKhXd3d1873vfo6CggLGxMS5duiSUamazmS1bttDZ2Ul9fb3IjIzH40xPTzM7OysUBQqFgnPnzmG1Wld998XjcQKBABqNhr/8y79kx44dH/t4Q6EQly9f5uLFi1gsFrGEkB7Xhg0bKCgoIB6PE4lE0Ov1VFZWYjAYyGQyjI+P4/F4UCgUObbRoqIiGhsbhQ1TIr2k7J4bfZGbzWax2+0MDg4SDAYxm81s2rQpJ1/L4/GIxythJdEWjUaZmpqSCxFWIJlM8oc//IHTp0+L64WOjg5CoRCTk5O4XC7S6TR6vZ57772XzZs337D79ng8HD16FJvNRiqVoqenh1Qqxblz55iamhLL+/r6enp6enJyfmFZ/fnAAw8IRefVctnk8oObD5lMBrfbLeykdrtd5LattHdns1mhupXy2aTZUFoQFhUV0dTURFdXF52dnVRWVn5lifNrhXTdPjk5ydjYmOAcpNdfIt3+TKlihyp3dpOmZpVKRVFxMdVVVbfs7CYTbF8i+P1+fvzjHzM1NSUqjKenp4Us22AwcNttt/Hoo4/i8Xg4f/48MzMzRCIR1Go1RqNRXIRKPulYLIbP56OwsBCj0YherycUCnHq1CnB9JvN5lU5WJWVlfzFX/zFqkDq9ZDNZpmamuLkyZOMjIxgt9uZn58X21cpT8Hr9YoBJj8/n7q6OvLz85mfn88J69br9RQWFor/h+XGoPvvv39Va5aMWws2m42XXnoJm822qgRBp9OxceNGnn766Rtu0XQ4HPz4xz/OybuB5ZP51NQUTqeTpqYmETItoaamhueff37NIGkZNxaZTIZwOJzTXCrlqq2EQqFYU+l2tYu3eDyOy+XC6XTicDjEZ08qSoBlpW1hYaG4sGpoaKC8vFwmWK8DmUxmlRpOalMcHx+n/fe/p92/hIussJaS/f+UJNI56FYd0q4X8ux262Bubo7XXnuN+fl5RkZGmJmZIRaLUVxczIMPPsgdd9xBIBDAbrczPDwsrHBSLmFTUxPNzc1UVlZiMplIJpPMzMwwMzNDMpkkk8lw7NgxpqenV7U1JxIJ4QR4/vnnufvuuz/28aZSKWZnZ3G5XIyPjwuSCpaXT0VFRXR0dFBeXi7yNtVqNSaTiZKSEpxOJ2NjY6TTaVQqlbCNAtTW1lJaWorH46GyspLe3l5RglBaWnrDX3vpPD48PEwikaCxsZENGzbkECaSan0l0abX6wXRlkqlmJmZkQsR/gnZbJahoSF+85vfkE6nKSkpob29HZPJxNzcXA7Z1dTUxEMPPXTD3ttAIMCRI0eYnZ0lFovR2tpKU1MThw4dYmxsDL/fTyaTwWw2s3nzZoLB4Ko5Ydu2bdx5550olcqr5rLJ5Qc3N6LRqFC3Sd9XUmGCdD2ZzWYJBoMim1KKH1GpVCSTSfF91tzcTG9vL93d3VRWVsrOhWtEJBLBYrEwNTXFyMgIAwMD3DM2Tnc4jAvISrMbkAXy8jRUlFdQUlJyy85uMsH2JUM6neb111/nD3/4A7C8RbLZbNjtdjHcmM1mnn76aTHgXL58WWx6pItMSTar1WpRq9VEIhH8fr+Q+atUKi5cuCC2AUajEbPZnHNiKS4u5i/+4i8+cYbG3NwcJ06cYHR0lPn5eex2O/F4HLPZjNlsJhQKiVY6WL6QqqioQK1WEwqFCIfDuFwuNBoNBQUFoqJbyvCRbaO3PtLpNEeOHOHdd98lGAzicrkEyaFQKCguLubRRx/lgQceuKGBpYlEgpdeeokTJ06s+j+Xy8XExARFRUW0t7fnZBHm5+fz9NNPc9ttt92wxyLj2pDJZNZUuq1Hul2pdFvrgiibzbK0tMTCwgJOpxO73Y7f7xeEngS9Xk9xcTENDQ20trYKpbCMGwfHD39I8PD7ZMvKlom4WJRoNEY6laK6uloMwLfqkHa9kGe3WwuRSIQ33ngDi8WC0+lkeHhYRGJs2rSJgwcPUltbi8vlEuqMcDiMVqultLQUo9FIbW0tzc3NVFdXU1ZWRjabxWq1Mj09TSwW49SpU1y4cEF830nnqmQyKeaqb3/72zzyyCMf+3iz2axQ+CoUCsbHx+nv7yeRSIjvyYKCgpzlUzqdFnnARqORmZkZRkdHyc/PJ5lMMjs7SzQaRaPR0NjYiFqtJhAI0NHRIWz5XV1dn0kYeTKZZGRkhImJCWA5z7ezszNnXvT5fIyNjeW0uOr1ehGWDsuLQLkQYRk+n4/XX38du91OcXExjY2N1NXVEY/HGRsbw263k0wm0Wg07N27lz179twQF0IsFuP48eNMTk4SCoWorq5m7969HDp0iDNnzuD3+0mlUpSWltLX1ycUTStRW1vLo48+KtwJV8tlk8sPbn5I6irJSipll/v9fuGUkr63AoGA+G4tLi5GpVKRyWRIJBIUFxfT3NzMpk2b6OnpoaKiQl6kfgKk02mm/vJ/I3HiBBGDgWgstrwESqeXFy5qNdXVVRj0hlt2dpMJti8pLly4wM9//nPi8biwUk1PT+N2u0kmk+h0Ou655x7uvvtulpaWcLlcwlYKUFRUJHzTeXl5InzQ6/Xi9/vJy8tDp9OJ2nXpIlSSghcWFopwyT/7sz+joaHhEz8Hr9fLiRMnGBwcxOfzYbfb8Xq9lJeXU1dXJ6ro3W63+GLUarWoVCoUCoV4Xmq1muLiYjo6OoSKSLaNfjng8Xj41a9+JTaSPp9PECd5eXm0tbXxne98h66urht6v2fPnuXv/u7vcrbUsBz+OTIyQjqdprOzc5WUfP/+/Tz++OMyufsFQyLdVirdwuHwmqSb0WjMUbqtRbolk0mhbrPb7TidTjGgraySLywsxGw2097eTmNjI6WlpTe8Ke+rBscPf0jo8PurWkSTqZTI9ACZYJNnt1sH2WyWo0ePcuLECRKJBBaLRQRJV1ZWsnv3bnbv3o1SqcTj8QiSTalULmcParVkMhlKS0tpaGigqalJXADa7XYsFgsnTpwQiyJJrSEFv/v9fgAee+wx/tk/+2fXZHcPhULC5mkwGLh48SIjIyPE43FisRihUAidTkdNTQ21tbXiey+RSKDT6VCpVExMTDA7Oysy3NxuN9lsloKCAhoaGkRIfV9fn7AcriwnuJGIRCIMDg5itVrJy8ujp6eH5ubmnPtai2jT6XS0t7eLGdXpdMqFCCy/z++//z6nT5/GYDBQXV1NS0sLOp0Ot9vNyMiIiHmprKzkwQcfXLfZ9pMglUpx6tQpRkdH8fv9FBUV8cgjj3Du3Dlef/11Uf4mlWro9fpVJT16vZ5HHnmExsZGYFkRNTMzs2YuWzAYZG5uTi4/uEUQj8dZWFjA4XAwOTnJwsICi4uLQtEokW1er5dwOIxaraaoqEiQaVLOZHNzM7fffju9vb2UlZXJc9014MrZLZVKLZ8v4nHi8fiycEalumVnN5lg+xJjYWGBH/3oR8LSlk6ncbvdzMzMCFl2c3Mz3/zmNzEajYTDYZGBFgwG0Wg06PV6IY+VWjzz8vJYXFwUW1VJYZaXlyfIrfz8fMxmM8XFxeh0Ov70T/+U9vb2T/U8gsEgp0+f5uLFiwQCAebn51lYWMBgMFBXV4fRaCQQCOB2u/F4PGQyGbGFys/PF0HlsJzf1tHRITapsm301kc2m+XcuXO8+eabLC0trSpBMJlM3HPPPXz9619fZWW+HrhcLn784x+LJlsJ6XSaiYkJ3G43DQ0NYtCW0NzczHPPPSc33N5kyGQyorFUIt6uRrqtVLoZDIYc0i0QCAhVx8zMjGi2WhmonJ+fL4oSpHa6r8rF1o3EegTblbhVh7TrhTy73bqYmJjg7bffJhqNEgwGGRkZERlqHR0dbNu2jZ6eHqHqX1hYQKPRUFxcjMFgEMSWVqulvr6epqYmqqqqKCoqwuFw8MYbb/DOO++I7zidTidsUUtLS2SzWe69916effbZa7I2SpbRYDBIYWEhyWRSkGYej4dwOEwkEkGj0VBbW4vZbBYXqlJrcDKZZGhoiKWlJZLJpPg5jUZDVVUVpaWlLC4uYjab6e3tFfbRzyoXaXFxkYGBARYWFjCZTGzYsEHkOq68zdjYWE7juE6no62tTVgJpUKHr3IhQjabZXh4mEOHDpFKpUS2XklJCdlslrGxMSwWC9FoFKVSyW233ca999573UrFTCZDf38/AwMDeDweQZj5/X5+9KMf4XA4iMfjFBQU0NXVhclkyhEOSNi7dy+7d+9GoVCQSqWwWq1r5rLJ5Qe3JrLZLH6/X1hJJycnRVGCVLIkXWMEAgHRRqrX61GpVCiVSoqKimhtbWXHjh309PTIebxXwZd9dpMJti854vE4v/zlL3NaEOPxOFarFYfDQSwWw2g0cuDAAbZs2UIqlWJpaYn5+Xni8TgajQaNRoNSqRQVxwqFQhwLh8PYbDamp6eJRCKCgNNoNKhUKrRaLZWVlVRUVPD973//mhqqrvZczp49y/nz5wkEAvh8PhwOB5lMhtraWgoLC8lkMvh8PtxutwhUTKfTaLVaQqGQUByZTCZaW1uF2k62jd76CIVCvPHGG1y4cEGUIEjqIZVKRV1dHd/+9rfZunXrDTvhpVIpXn31VT744IOc49lslvn5eSwWCyaTic7Ozhz5uMFg4JlnnqGnp+eGPA4Znw3S6fSaSrcroVQq11S6SYO4x+PB6XQyOzuLw+FgaWkpp5BBqVRSUFBAbW2tsEEVFRXJg9k1wPHDHxJ4512UH2PPyCQSmO4/cMsNadcLeXa7teH3+3n11VdFu6iUz7a0tITZbKa1tZVt27ZRUlJCOBzG4XDg9/spLy+nqakJrVaL0+nE6/WKBtLGxkYaGhooLS3lyJEjvPDCC6vIf4PBQDweJ51Os3fvXn7wgx9c03y00jIqqemkZcXU1BRTU1OivEStVlNbWyvKGVb+DikkO5FIEAwGcTgcZLNZ9Ho9TU1N5OXlEQwG6erqoqmpiaamplXRDDcSCwsLDAwMsLi4SHl5ORs3blyVMby4uMj4+LhwgsBqoi0cDmOxWFYVIpSXl9PS0vKVKETweDz8+te/ZnZ2lrKyMkH86vV6FhcX+eijj3C5XGQyGUwmE/fffz/d3d3Xfb/Dw8NcuHABl8uFUqnk/vvvp7y8nL/5m79heHiYaDSKTqejp6cHg8FAKpVadQ5ubm7m4YcfRqfTXTWXbWX5QV1dnbw8uwWRTCaFum18fFxEgUiLV4lsW1xcRKFQiPbyvLw8YSttbW1lz5499Pb2yvEgV+DLPrvJBNtXANlsliNHjvDKK6/kBDouLS2JzWImk2HDhg3cddddlJaWkk6ncTqd+Hw+ke2hUCiIxWKiwS2bzaJWq1EoFITDYU6dOoXL5RJZZxLZplar0Wq1mM1m/vIv/5K9e/de1/NJp9NcvHiRM2fO4PV6iUajOBwOwuEwlZWV4rHG43Hcbjfz8/OiUVWlUgmSTalUYjKZqKiooKKigra2Nh599FHZNnqLY2RkhFdeeQW3243X6xVqTViW+u/cuZN//s//OWVlZTfsPi9evMgvfvGLHOUcIFQHmUyGjo6OHNWaQqHggQce4IEHHvhKBR/f6kin02sq3a6ERLqZTCZBuhkMBpERKVm1pLDdldkveXl5lJeX09raKhS3X6XMnk+CwLvvEjp+/Jpua9y7F9OBA5/xI7q5IM9utz5SqRTvvvsuly5dApYVMpOTk0xPTwu3QGdnJ319fcDysmlubo50Ok1TUxPd3d3k5eUxOzuLzWYjHA5TUFAgihEkJY/H48mJPZAINY1Gw5YtW/g3/+bfXLOaKBQKMTMzQzabpaSkhEAgIGbB/v5+hoeHRaC4SqUSirYrixdmZ2dZWFgAEJmXUs5qc3OzyO7q6+ujpqaGrq4uqqurb8jrfiWy2SwzMzMMDQ0RjUapra2lr68Pg8GQczu/38/4+Lh43LBMWra1tdHQ0IBSqSSRSDAzM4PFYhH5sbC8/G1pafnSFyLE43GOHTvGqVOnMBgMIjdQp9Oh1WoZHR0VtmiArq4uDh48eN0kxczMDCdPnsTpdJLJZPja177Gpk2b+OlPf8rhw4dF5lZXV9e6joeCggIee+wxoWRcL5dNLj/4ciEQCOBwOIRlXyq6SqfTBAIBXC4XPp+PdDqNRqMRhJter6esrIzOzk52795NX1+fTLjy5Z/dZILtKwSLxcILL7wg8jVgeXBzOp3YbDYCgQAVFRXs2rWLrq4ukcdht9uJxWJUVlaKE4rP52NhYYFwOIxCoUD5T1k358+fZ3JyUlwsqlQqNBqNINu0Wi3f+MY3+P73v3/dddaZTIaRkRFOnTollGxOp5NAIIDRaKSiokIMKIFAAKvVytTUFLFYjEwmQyaTEY+tqKiI/Px8iouL2bFjB9/+9revuQFVxs2HeDzOO++8w9GjR4lEIrjd7pwShIqKCr7xjW+wf//+GzbEer1efvKTn2CxWHKOJ5NJxsbGWFxcpK6ujoaGhpytaFdXF88880xOjoeMWwsS6bZS6XYl2Qr/n1JtpbU0Go3idDqZnp4W38MrCTuFQoHJZKKuro6Ojg5aW1vlQV3GNUOe3b48+Oijj3j33XdFPq7f72dkZIRwOIzJZKKsrIyNGzdSW1tLOp3G6/XidrspKCigo6OD3t5e8vPzGR8fZ2pqSigvamtr0Wq1/OY3vyESiRAMBnMCv+PxODqdjs2bN/N//p//5ypCaT1IhQWhUIiSkhLS6TSRSEQQaSdPnmRgYIBQKEQkEiGbzVJXV0dtbS3ZbFacJ5eWlhgZGRHn8NnZWRYXFwHEXBqNRqmpqaG3t5eamhp6enpuaCTESqTTacbGxhgfHyedTtPc3ExPT8+qkPOlpSXGxsZWEW1Sw7RKpSKdTmO327+ShQjSDP/b3/6WVCpFZWUl7e3tQgEeCoW4cOECDoeDZDJJfn4+d911F9u2bbsudffCwgLHjx/H4XCQSCTYsmULd9xxB++99x5/93d/h9/vR61W097eTklJibheWJmrpVQqueuuu9i6dStw9Vw2ufzgy4dUKiUWpSMjI9hsNiE+WUm2JRIJoW6TmusrKiro7e1l7969bNiw4Uv79/1Vh0ywfcUQDAb5yU9+wujoaM7xSCSCzWbD6XSiVCrp6upi48aNVFVViQanhYUFVCqVqLuWtnnz8/Mir0Oj0TA5Ocng4CCRSIRkMinyJlaq2vr6+njmmWfYvn37NQ9rV4MU2ittTIPBoBgepYZRWP5SlGwKUmuMtFU1Go0UFhYKa+v+/fv54z/+Y6qqqq778cn4YmC1Wnn55Zex2+0iS2FlCYL0OZSav64X6XSat956i9/97nc5x7PZLDabjdnZWZHzsVIJUFRUxHPPPUdLS8sNeRwyvnikUqkcpduVGWwSVCqVULppNBqi0Sgej0cEYy8tLQkrKSyrSSorK2lra6Ozs5Pq6uovtdJBxvVBnt2+XFhYWOC1114Ti9JUKsXc3Bx2u11EetTV1bFhwwbRoDs3N0c4HKauro7Ozk66u7vRarVMTEwwNjaG0+kkHo+jUqk4ffq0WJiGw2Gi0SjZbFYsU7u6uvi//+//+5oXkCstowaDAaPRiM/nEyHwyWRSRH8Eg0ExN9bX19Pc3Ew2myWVSgGIi1m1Wi0ysKRyh8rKSgoKClAqlWzYsIHm5mba2tpoaWn5zALHY7EYw8PDWCwW1Go1nZ2dtLe3r/o+XlpaYnx8XOQhw3IhV2trK42NjahUKvE6TU5O4vP5xO2+CoUIbreb3/72t8zMzFBeXk57eztlZWXk5eVRWFjI8PCwKEEAqKur48EHH8RsNn/q+/T7/Rw9ehS73U4kEqG9vZ0HHniAoaEhXnjhBaxWK0qlUkTJJBIJ0un0KhK1s7OTgwcPotVqr5rLtrL8QD5nf/kQCoWYn59nfHycsbExPB4PgUBAlO0tLi6K79KVM19NTQ2bNm0SZJv8ufjyQCbYvoLIZDK8/fbbvPPOO6uOLy4uihNEbW0tTU1NdHR0YDQaycvLY2FhgaWlJUpKSuju7qapqYlUKiWGB7fbTSaTwW63Mzw8TDweFwOTdB9KpRKVSkVjYyPbtm1jz5497Nu374aoMhwOBydOnGBsbEyUHfj9fjKZDAUFBYLUSKVSzMzMMDMzg9PpJBgMkpeXh16vp6SkROQn6fV69uzZw/79++ns7PzSDjhfZqTTaY4cOcK7775LOBzG7XbnEB2FhYU8+uijPPLIIzesZntoaIj/9b/+F8FgMOe43+9ndHRUWEZLS0vF/ymVSp544gn2798vZ299SSGRbiuVblcj3dLpNNFolIWFBRYWFggEAquaa00mE42NjXR0dNDe3i4uqmXIAHl2+zIiFovx1ltvMTk5KY6Fw2HR4hmPx8nPz6e9vZ3m5mby8vJElIZGo6G1tZWuri5aW1tRq9XYbDbGx8eZnp5mYWFBtItKpH8sFhN2Tklp9O/+3b9j06ZN10xeBYNBZmdnyWazVFRUiMVBbW0tRqORZDJJf38/Z8+eFRejsViM6upqurq6RLNoMplkcnISi8WCSqUiFovhcrlIJBKo1WrKyspQKBQUFRWxbds2mpqa6Onp+UyVQ8FgkIGBAebm5tDr9fT29q7ZbhoIBBgbG7sq0QbLWW6Tk5M5t5MKEb6sCuZoNMqHH34oWkYbGhpoaWlBoVBQWlpKJBLh3LlzWK1WsRjfuXMnX/va1z61AigSiXDs2DFmZ2cJBAJUV1fz+OOP43a7+du//VsuXLgAIMqqJJLtSpt0SUkJjz/+OOXl5VfNZZPLD74ayGQyuN1uZmdnGRoaYnZ2Fr/fLxTFPp+PaDQqFgcGgwGTyURDQwPbtm1j3759dHd3y9cBtzhkgu0rjMuXL/O3f/u3qy7wEokE8/PzzM/PU1BQgNlsprGxkdraWvR6Pdlslrm5OVKplMj3aGlpIRgMCqJtfn6e2dlZPvroI2D5wjISiRCPx1EoFGQyGbLZLKWlpXR1dVFWVsbOnTu59957b0g2ls/nE/YDidyTyA6lUinsqZJyb2Fhgbm5Ofx+PxqNBqPRSGVlpSDUKioq6O7upre3l56eHpqamuRNwy0Gj8fDyy+/zMTEBIFAAK/XKzIJ1Wo1LS0tPPvss3R2dt6Q+/P7/fz0pz9lfHw853g8Hmd0dJSlpSVqampWfZY2b97M008/LZO5XxEkk8lVpJtkz1qJdDpNMplkaWmJhYWFnNIWCWq1murqaqFuq6qqkoe0rzjk2e3LiWw2y8mTJzl69KhQZWcyGTwejygGiEQilJaW0t3dLRRBHo8Hj8cjVLBdXV3ie0IigC5dusTrr7/O4uIier2egoICdDod8XhclCWYTCa+/e1vs337dhoaGq6JLFhpGS0rKxNZwGVlZVRWVorZcHBwUGT6SuReeXk5fX19FBYW4vF4WFpaYmhoCJfLJSz60WiUTCaDVquloKCAWCxGfX09W7dupbu7m56enuuOJrkaPB4Ply9fxuPxUFJSwoYNG6isrFx1u0AgwPj4OA6HQxzTarW0tLTQ1NQkiDapGMJms60qRGhtbf3S2Q3T6TQjIyO89957JBIJKisrxXsm5ZKOjIwwMDAgShDKyso4ePDgp1b/JxIJTp48ydTUFD6fj5KSEp544gkSiQSvvfYa77//vrAfNzU1AYgm3JWfebVazYEDB9iwYQOwfi6bXH7w1UM0GmVubo6RkRFGR0dFNrhEtkUiEVKpFOl0WjQ519fXs2/fPu688046Ojq+6Kcg41NAJti+4vB4PPzoRz/CZrPlHM9ms4RCIex2O+l0mtLSUgoLC2ltbaWwsBCTySSkrwaDgY6ODnp6emhsbMTlcjEzM8Pc3BwXL17k2LFjZDIZ0ai3kmiTWqHq6+uFP33r1q0cPHhQnMyuB6FQiDNnztDf3y+IxGQyiUKhIJ1OixOk1+sVBJsUAiw1GFVVVZGfn49SqaS5uZmmpiaMRiNdXV309vbe0LB8GZ8tstks586d48033yQYDK5ZgnDvvffyjW9844bkt2QyGX7729/ym9/8hpVftZK92mazUVBQsEodWVFRwXPPPXfDrKsybi18HOkmWbZ8Ph9erxefzycsXiu3+YWFhTQ3Nwt123oB5V/2sNmvMuTZ7cuN6elp3nzzzZzMx0QiQTweF5a1bDYrLJcGgwGNRiOWpM3NzbS2ttLZ2SnOealUiuHhYf76r/+aiYkJUqkUeXl5FBQUoNfrCQQChEIh9Ho9f/RHfyRaSRsbGz9WBX6lZbSoqAi3241Wq6Wuri7n58fHxzl16hQ2m024IUwmE5s3b6aqqgqPx4PFYhGRJNJ3okqlQqVSiRysdDpNS0sLPT09bN++nc7Ozs9s8ZDNZrHb7QwODhIMBjGbzWzcuJHCwsJVt5WW0nNzc+JYXl6eULRJ82kikWB6eprp6elVhQitra1fOsuhy+XivffeY2ZmhpKSEjo7O6msrCSdTlNZWUkikeDs2bNMT08TDAZRKpX09fVx3333fSoFdzqd5ty5c4yNjeF2u9Hr9Tz22GPodDreffddfv/73+N0OkUZmkqlIhQKieuXldi4cSP33nsvarV63Vw2ufzgqwupIXlqaorBwUGmp6eZnZ0VpWwS2SZ95xqNRhoaGrjrrrs4cOAAzc3NOb9Pnt1uXsgEmwySyST/+I//KGwBKyEF5fr9foqLi0WmWV1dHUajEaPRyPz8vLCUSkRbZWUlTqcTq9VKf38/r776KsFgUJBqkvUpFouhUCgwGAzU19cL0kuS2d9///1s3Ljxuq17iUSCc+fOce7cuRxCRWpBlTafUuOo3+/H6XQSi8VIJpNC0abX6zEYDHR2dortobRlky2ktw5CoRCvv/46/f39ogRBUjoqlUpqamr4zne+c91huhLGx8f56U9/mlMwAsvE7tjYGABtbW05G2mNRsNTTz3F7t27r/v+Zdz6kFQpK0k3ScGWSqVYXFzE4XDgcDhylBx5eXloNBpUKhV1dXVCtbJSWfFlr0v/KkOe3b78CAQCvP766zlEjTRnmUwmxsfHcblc6PV6WltbqaiooKCggHg8zvz8vCBq2traaG5uFgqqRCLBf/2v/5UzZ84QCAQEuWMwGFAqlcKq98QTTwiLY319PU1NTR+rFJMsowBms5mlpSXi8ThVVVU5bduw/Bk+ceIEk5OTJBIJIpEIeXl5bNq0iY6ODrxeL6dPnxZkoLR4yM/Px2QyUVRURDgcFuRVc3MzO3fupKWl5TOz6WUyGaamphgZGSEej9PY2MiGDRvWfF3WI9okRZv0GK9WiNDS0nLNSsJbAZFIhFOnTgnLaF1dHT09PcTjcfR6PWazmfHxcS5evMj8/DypVIqCggLuvvtuNm7c+Inntmw2y8DAAJcuXcLlcqFSqXjggQeorq7mgw8+4NixYwwPDwvCT61WE4lEiMVimEymnPurqKjg8ccfp7i4+Kq5bHL5gYx4PI7NZhM5g2NjYzgcDrxer7Dkp9NplEoler2euro67rrrLh5//HEaGxvl2e0mhkywyRA4ceIE//AP/yB84Ssh5VwUFxeTSCTEIFVSUkJxcbHYyKhUKtrb22ltbaW7uxuDwYDH4+H8+fP89Kc/xePxEIvFyGazIrxWItqMRiPt7e2C8JKyDpqamrjvvvu4/fbbKS0tva5NXTqd5vLly5w6dQqPxwMgvri0Wq34UrPZbPh8PmFFWLlRKC0txWQyYTabc0g1pVIptqSyhfTWwPDwMK+88go+nw+fz4ff7xdKM61Wy86dO/nud7+bk5X2aREMBvnZz37G4OBgznEpKDkUClFVVZVzgQOwa9cunnrqqRuWDyfjy4OVpJtEvMXjccLhMC6XSxTXSJYWiXDTarWUlpbS0tJCR0cHRX//90Q+OILmn1qi10Nybg7j/rvkIe0Wgjy7fTWQTqc5fPgw586dyzkuqdcWFhYYGhoiHA5TVlZGU1MThYWFFBYW4na7WVpaEkq0zs5OQQSk02l+9rOfcfLkSZaWlkQ5VCqVQqFQoFKpyM/P58EHH6S9vR1YnoVqa2tpbm6+qqJopWVUan33eDwUFhZSXV29Kt/N7XZz6tQphoaGSCQSouCgu7ubrVu34vP5+N3vfsfs7CzJZFK0yufn51NRUUFpaSlLS0sUFhbS1NREXV0dmzZtora2lsLCws9E1ZZMJhkZGWFiYgKA9vZ2urq61iTCQqGQINpWljE1NzfT3NwsfuarUoiQSqUYHR3l8OHDxONxysvLue2229BoNKTTaaqrq8lkMpw9ezbntWhtbeXgwYOfam6bnJzk7NmzOJ1O0uk0d911F93d3Rw/fpyLFy9y+vRp8vPz6e7uJi8vj1Qqhd/vx2Aw5MxoWq1W/E1cLZdNLj+QIWFlM/Tw8DAXLlxgenoaj8cjyDap8CovL4/a2lr+UpNHk9tNfn39VX+3PLt9/pAJNhk5sFqtvPDCC4J8WgkpL0On06HRaPD5fJhMJqqrqykoKKCsrAy/3y/yPVpbW+no6BAtTpOTk/yX//JfmJ+fJxaLCZWFVEYgNax0dXVhNBpJpVLEYjFhjaqpqWHXrl3s3bsXs9mcU4P9SZHNZhkdHeXkyZM5W8OioiIKCwvx+XzMzMxgt9uFmi0ej+c0oubn51NcXExXV9cqQk2v18sW0lsE8Xic3/72txw7dox4PC5yXySUlpbyz/7ZP+Puu+++7uEnm83y+9//njfeeCOnGVLads/Pz2MwGOjq6sq5MKmtreX555/P2X7KkLEW4vF4jtJtcXERp9OJ3W7Hbrfj8/lIJpOoVCpBtt07MUGT24OisgJdvm5dFYQ8pN16kGe3rxaGhoY4dOiQUGRLqK+vx2g0cvbsWSwWCwqFgvr6eioqKigsLCQ/P5+5uTlRglBfXy8WiNlslpdeeokjR46I+JClpSUCgYCwNJlMJh5++GE6OjrE94dCoaC6upqWlpZ1Ixck8sHpdGIwGCgvL2dhYUGQdGsRdIFAgNOnT3Pp0iWi0agIDG9tbWXPnj14vV7eeecdHA4H4XCYhYUF4vE4arWampoaSktLhVquoqKCpqYmamtrqaiooLy8/DMhp8LhMENDQ1itVvLy8kR28Vqk3lpEm0ajEYq2lTEAX/ZChGw2i8vl4v3338disVBcXEx3dzd1dXWEQiFxDTIxMUF/fz9zc3PEYjG0Wi179+5l586dn1jVNzc3x8mTJ5mfnycej7N161Z27drFyZMnGR0d5dSpU0SjUbq7u8Xn0+fzoVAoVr3e27Zt484770SpVK6byyaXH8hYC9ICYmBggNOnTzM4OMjCwoIg27LZLH+mUrFdocSv0WAwGCguKcagN6z+XfLs9rlDJthkrEIkEuFv//ZvGRgYWPP/pc1lWVmZOJmVl5cLNZvJZGJ+fp5EIkFLSwvNzc10dXVRUVGB3+/n//l//h8mJiaIRqMkEgmi0SjJZJJsNksymSSVSlFdXY3RaBSlChIhl0gkKC4u5rbbbmP79u00NTVRWVl5XeqemZkZTpw4gcViEcNMYWEhVVVVQqbucDiEok3Kk9NqtYJwKSkpYdu2bWsSILKF9NaA1WrlH//xH3E4HCKfbWUJwoYNG3j++edvyPefxWLhJz/5iaidlyBtpQFh5ZEG8Pz8fJ5++mluu+22675/GV8trCTdFhYWGB8fx2KxiGHtgZkZ2v1LBLRaFEoFarUGXX4+Or0eXX6++AzKQ9qtB3l2++rB7Xbz6quv5qibAIqLi9m6dSuXLl2iv78ft9uN0Wikvr4ek8lERUUF0WgUp9NJdXU19fX1tLS00NjYiEKh4Ne//jWHDh0Svy+RSOD3+3G5XCILa//+/ezbt490Oi2s6bBsA21paVkziwxyLaM1NTUEg0GCwSAVFRWiGfRKxGIxzp49y/nz54lEIkSjUeLxOHV1dezatQufz8f777+P0+nE7XaL+U2lUlFdXS3yhBsbGykqKqKurg6DwYDRaBSKtxtNdvj9fi5fvozT6cRoNLJhw4Z1/y7D4TDj4+PY7fYcok1StK0k2qRCBKvVmrO8+7IUIoRCIc6ePcuZM2fQ6XTU1NSwZcsWka0svYbnzp1jfHwct9tNJpOhurqagwcPfuI8W6/Xy7Fjx5ifnyccDtPZ2cmBAwc4f/68KAKxWq309PQI8lhqBS8tLc35vNbW1vLoo49SUFCwbi6bXH4g4+OwuLjI8PAwR48e5cyZM9hsNr4bjrAVcElUjkKBWq3CoDcIhTLIs9sXAZlgk7Emstks7777Lm+99RZX+4i0traysLCA3W5HpVJRVVWFTqejqqqKbDYrmkjb29tpbGykq6uLTCbDf//v/53Jycmc6neJaJMIrIaGBsLhMCqVCpPJhEqlEvkb0WgUrVZLT08PfX19tLW1ic3kp1UZOZ1OPvzwQ0ZHRwWxYjAYaG1tJR6P84c//AGr1Spso7BshSgoKBCPraysjO3bt6/KEJFu29LSQm9vL42NjbIU/CZEOp3mgw8+4N133yUWi+H1ekX7LIDRaOTxxx/n0UcfvW7LZiQS4Re/+AUXL17MOR4OhxkZGSESiQgl6EqrzP79+3n88cflLaeM60IsFmNpaWk51+i//TWF4+P4NRqy5H7fmyvNYksvD2m3HuTZ7auJeDzOoUOHGBkZyTmu0Wi4++67icViHDt2jNHRUSKRCBUVFVRWVmIymSgvL8flchGPx2ltbaWqqorOzk5KSko4fPgwv/rVr3J+Zzqdxu1243Q6SSQSdHV1cccdd1BQUIBCoUCj0YjsMYnwWWtGSiaTzMzMEA6HxeJ0YWEBvV5PbW1tDqF05c999NFHnDlzRmS5RaNRKisr2bRpE36/nwsXLrCwsIDVahU5vFqtluLiYoxGIx0dHbS2tooIEKVSiVKppLi4WKj8bqSFdGFhgYGBARYXFykvL2fjxo2UlJSsedtwOMzExAQ2m+1jibYvcyFCMplkbGyMI0eOCCJr69atGAwGgsEgJSUlVFVVMTMzw/nz57HZbIRCIdRqNbfffjt33nnnJ2qRDQaDHD9+HLvdztLSErW1tTz22GMMDw+LgPqxsTFB0EqPUYrTWVkqpNfreeSRR2hsbFw3l00uP5BxrUilUkxPT2P7q/8fprExnNkMmXQm53pdrVbT1dUFyLPbFwGZYJNxVYyMjPCTn/wkJ1B1JbLZLL29vWQyGUZHR/F6vRQWFoqhpbq6msXFRfx+v2iZamtro6qqihdeeIGRkRGy2SyRSIRQKEQqlSKRSIjBYOfOnQSDQdFyajKZhKotFAqJ7VVDQwO9vb3U1NTQ0tJCVVXVp7aQ+v1+Tp48yeXLl8Xj0Gq1dHR0kM1mee+99xgfHxe2UVj2w5tMJtRqNYlEgtraWrE1XGso0+v1ojZetpDefHC73fzqV79iYmJizRKE5uZmnn/+eXHy+rTIZrMcPXqUV155JSf7MJVKMTk5icvlQqfTCdu0hObmZp577rk1L1JkyPikcPzwh4QOv4/CXEk0smy5isVjZDKZ5bwYxfIFmTyk3XqQZ7evNs6cOcPhw4dXLUq3bt1KX18fx48fp7+/n5mZGRQKBbW1tRQUFFBZWYlWq2Vubo7i4mIaGxupra2lvb2dixcv8vOf/zxHKQXLpL3b7SYajdLQ0EBfX59o85Sy2oxGIwqFgpKSElpbW1fNPysto1K5lNPpJJlMfizpkMlkGB4e5tSpU7hcLpLJJJFIhKKiIpqbmwkEAjgcDmZnZ5mYmBDzm16vR6fTUVhYyNatW9mwYQPV1dUAYvbNy8sTYfQ3Sl2UzWaZnZ1laGiISCRCbW0tfX19GAyrLV6wvJSbmJgQzbCwfBEtEW0rl37pdBqbzcbU1BThcFgc1+l0NDc337KFCNLn4+jRo0xPT1NYWEhHRwednZ34fD7UajX19fUoFAouXrzI8PAwCwsLpFIpiouLOXDgAB0dHddMlsZiMT788EOsViter5fS0lKefPJJ7HY7/f39QhFeVFQkPstSRp5KpVpFmu7bt49du3YBMD8/j9vtXpXLJpcfyLhWSLOb0mwmEAiwFFgiElnOMDcVFNDQ0ADIs9sXAZlgk/GxWFxc5IUXXmB6enrd27S0tNDe3s7ly5eZnJwkHo9TVlaGTqcTW9H5+XlUKhUdHR1UVVXR0dHBG2+8QX9/P4CwgoZCIWEVTSaT7N69WzQGSVu5goICSkpKUKlURKNRQqEQ6XQas9lMc3MzlZWV1NfX09jY+KktpJFIhNOnT4umSVgeZqRMuWPHjvHRRx/hdDrFoKnX68UwqdFoqK2txWw2U1RUtO72VbaQ3pzIZrOcPXuWt956i3A4vGYJwr333su3vvWtdQfia4XNZuPFF1/E6XTm3P/8/DwWiwVYJtWqqqrEYGgwGHjmmWfo6em5rvuWIUMa0q4sOUilUjkXYfKQdutBnt1k2Gw2Xn/99VWL0pqaGh555BHm5+c5cuQIQ0NDeDweoWKT2hvD4TB+v5/GxkZh9fT5fLz44ourSrHi8TiBQIBMJkN7ezubNm3C6/WSTCbRarWoVCr0ej2FhYVoNBoKCwtpaWmhsrIyh/QIBAJYrVYA8Rh8Ph8lJSWYzeaPVWFNTU1x8uRJrFYr6XSaSCRCfn4+JSUlIjdubGyMqakpkskkSqUSg8GASqWitLSUnTt3iqbRUCiEx+MRC9cbbSFNpVKMj48zPj5OOp2mubmZ7u7uHAXUSkhEm81mE7OnWq2mqamJlpaWnHn3aoUIjY2NNDU13ZJzZyAQoL+/n7Nnz6LVaqmqqmLnzp3EYjGhyDSbzczNzXHu3DlmZ2dZXFxEqVTS1dXFgQMHrlkhlkqlOH36NFNTU7jdbvR6PU888QThcJgzZ87gcrmw2+1kMpmcdu7FxUVCodCqz2tzczMPP/wwOp0uJ5etqalJvHdy+YGMa8Fas1uWLMFgCPU/fdeCPLt9EZAJNhnXhFQqxauvvsoHH3yw7m2Ki4t54IEHsFgsDAwMMDc3h1qtpqSkBK1WS0NDA6lUCqfTSVVVlQiW7e/v58yZMzm/Kx6PEwqFxIZx+/btNDc3E4lEmJqaYmxsjFAohMFgoLS0FIPBQCaTEVXy5eXlmM1miouLKS8vp6Wlhdra2k9lIU0mk1y4cEHYD2BZxVRVVYVareajjz7i3LlzghyRbKN6vR6FQkFpaamwr6pWfOFdCdlCenMiGAzy+uuvc/HiReLxOG63O6cEoaqqiu9973vs2LHjuiwksViMf/iHf1j1txAMBhkZGSEWi1FWVkZ7e3tOiPQDDzzAAw88IH9eZHxqrEewXQl5SLv1IM9uMmBZifXGG28I0kqCXq/nscceo6ysjDNnznDmzBkR31FWVobJZKK0tJSKigpcLpdYMpaXl6NWq/nFL36Rcz6EZZtiIBBAoVCwZcsWnnjiCWZnZ7FYLITDYbLZLCqVajmUu7gYvV6P0WgUdlSR+XiFZdRgMIi5sq6u7prsflJg/cTEhCjOymaz4hyqVqsZHR3FZrOJIgRpdquqqmL79u3ceeeddHZ2EgwGcblc+Hw+stnsDbeQSo3iFosFlUpFV1cX7e3t657bo9GoULR9HNEGy0H8UpmSBEm12NLScstZEhOJBOPj4xw7doxoNEpRURFbt26lvLwct9tNfn6+UIZdvnyZwcFBHA4H8Xgcg8HAnXfeyW233baqrXYtZLNZ+vv7GR0dxel0olQqeeihh9BqtZw4cQKXy4Xf7ycajeZ8FuLxOA6Hg4qKipzPa0FBAY899hg1NTXr5rJJ5QcqlYr6+vpbUnEo47OFPLvdvJAJNhmfCGfPnuWXv/xlTrbDSqjVah577DH0ej2nT59mYmJCtI1KW8uamhp8Ph+xWIy2tjYqKyuxWCxcuHBh1YCSTCYJh8NEo1H27NlDd3e3OGaxWBgcHMTn86HT6cTWVaVSiWw3acuYl5dHcXExZrNZDHGf1EKayWS4fPkyp0+fxuVyAcvDidFoRKlUYrVaOX36NG63m2w2i0ajwWQyiTKEhoYGqqurycvLEwPceoOTbCG9+TA8PMyvfvUrFhcXCQQCeL3enKF2586d/It/8S8+VTW8hGw2y6lTp/jHf/zHnL8xKXvE5/ORn59PZ2dnzjDc1dXFM888c13NujK+upCHtC8v5NlNhoRMJsORI0c4depUznGFQsEdd9zBjh07cDgcHD16lIGBAXFxX1ZWRn5+Pg0NDWg0GlGCIClrfv3rX4u4DgkSyQawYcMG/vW//tdkMhnGxsYYGRnB7/cTCoVQqVTodDrKy8spKCigoKCA5uZmamtrUSqVqyyj1dXVOJ1OIpEIZrP5ms+3Xq+XEydOMDQ0RDqdFkqnWCxGQUEB2WwWi8WCx+MhGo2i0WjQarWCzLvtttu45557aGtrI5VK4fV6cblcn4mFNBgMiiW1Tqejt7eXhoaGdQm8tYg2lUpFU1MTra2tq4i29QoRKioqBHl6qyCbzTI3N8epU6eYmprCZDLR1tbGli1bcLlcJBIJqqqqBOl29uxZZmZmROFFU1MTBw4cwGw2X9P9jYyMCOdgbzYNAAEAAElEQVRKOp3m7rvvprq6mmPHjonSoHQ6nTPfZzIZ7Ha7+IxIkEpBbr/9dlKpFLOzs4RCoZxcNrn8QMbVIM9uNy9kgk3GJ4bD4eBHP/pRjp3tSuzYsYMDBw5w+vRpBgYGsFgsJBIJCgsLUalU1NXVUVBQwMLCAoWFhbS2tmKxWLh8+fKads50Ok04HKa3t5ddu3aRSqWIx+PE43FmZmbo7+/H6XSiVqsxGo1iWEun0wSDQUpLSykvLxfNVmVlZdTX13+qFtJsNsvExAQnTpwQ2XDZbFaUM/h8Pvr7+/F6vaRSKXQ6HQUFBajVanQ6HW1tbZSUlKBUKsWm60qbxUrIFtKbB1Jw9PHjx0kmk3g8nhzbTWFhId/61re47777rktRNj8/z49//GMcDoc4ls1msdlsommtqamJmpoaMXQXFRXx3HPP0dLS8qnvV8ZXE44f/pDAO++i/JjvwUwigen+A/KQdgtBnt1kXImxsTF+/etfC4eAhLa2Nh5++GFUKhWXL18Wyi8pJ8poNGI0GmlqaiIUChGLxWhpaSGTyXD48GGSyWQOCZRMJoXqv62tjf/j//g/KCkpIZVKMTU1xejoqAiQTyaTaDQa0e5ZWFhIU1MT9fX1qFSqHMtofX098Xgcl8uF0WikpqbmmtU9oVCI06dPc/HiRZH363Q6CYVClJaWkkgkRHNkOBxGrVaj0WgwGAw0NDTQ09PDHXfcQWtrK2q1mmg0isvlyslpLSgooLy8/LotpB6Ph8uXL+PxeCguLqavry/HgnglotEok5OTzM7O5hBtjY2NtLa2rrKcrleIINl2byVrot/v59KlS5w/f14QWXv27EGhUODxeDAajdTV1aFSqRgZGeHSpUvMzc0RDofRarXs2LGD3bt3r2vLXYnZ2VmxaI/FYmzfvp2+vj6OHj0qPjs6nY5oNCoK02D5/QwEAtTV1eW8rp2dnRw8eJC8vDyRyya12koks1x+IGMtyLPbzQuZYJPxqRCLxfj5z38u8tPWQk1NDc8++yx+v5/jx48zOTnJ3NwcGo0GnU5Hfn4+TU1NpFIpFhcXaWpqwufzce7cOdHidCUymQw9PT1s3bqVZDJJNpsVzaPT09OcP3+e2dlZFAoFer1eWEQ1Go0Iu62oqBAEncFgEBbSurq6T2whnZ2d5dSpU0xMTJDNZkmlUmJrGwgEmJqaErJxaThVKBSUlZXR2toqJOOlpaXo9XrRqLoWVCoVLS0t9PT0yBbSLxizs7O89NJLYphyu93ifVMqlfT09PAv/+W//MTV8CuRTCZ5+eWXOX78eM5xv9/P6OgoiUSCkpISOjo6RL6fUqnkiSeeYP/+/Te08UzGlxuBd98ldMXnbD0Y9+7FdODAZ/yIZNwoyLObjLXg8/l47bXXhBpfQlFREU8++SSVlZUsLi5y4sQJLl26hMViIRaLUVpaikajoaqqirKyMkEGFBcX8+6775LJZHLyZlOpFH6/H1gmxv7qr/4qRy0khcRPTEzg8XhYXFwkLy8Ps9lMWVkZxcXFNDU1ibBuyTJqNpsxmUwi+6q2tjanCOjjEI/HOXfuHOfPnyccDotlbSgUory8nFAoRCgUIpFICHWeWq2mrKyMhoYGWltbuf3222ltbRUKuKWlpVUW0pKSEsrLyz+1hTSbzWK32xkaGiIQCFBVVUVfXx+FhYXr/kwsFmNycpKZmZlrItq+LIUI8Xic8fFxTp06RTAYpLCwkNtvv53GxkYcDgeZTIaamhqKi4tZWlri/PnzWCwWUYJgNpu59957aW5u/tj3yul0cuLECUHOdnV1sW/fPj788EPm5uZYWlqitLRUXN9IiEajWK1Wqqurc+JiSkpKePzxxykvL183l00uP5BxJeTZ7eaFTLDJ+NTIZrMcPnyY1157bVWblIT8/Hy++93v0t7ezunTp7lw4QKzs7P4fD4MBgNKpZLy8nLRNqpUKlGpVHz44YeYTKZ1w+M3bdrEQw89JAgO6fFIcvHTp08zOTlJOp0mLy9PtGIVFhaSSCQoKysTiralpSWy2SyFhYWYzWba29s/sYXU7Xbz4YcfMjw8TDqdJhqN4nQ6yWazBAKB5XaXpSXC4TBGo5H8/HxhG125zTKZTEIafuXguxKShbS3t/e6LIkyPj3S6TTvv/8+v/vd70gkEvh8PvFZguXh9LHHHuPJJ5/8VCUbEs6fP88vf/nLnJybeDzO6OgoS0tLaLVaOjs7cwbuzZs38/TTT8uKRxkyvuKQZzcZ6yGZTPLOO+8wMDCQc1ytVnPgwAE2btxIJpPBYrFw7NgxJiYmmJubQ6VSYTKZ0Gg0tLS0oFar8fv9VFZWcvToUcLhMAUFBWKuSaVS4txYVVXF//6//++CMJMQjUYZHx8XqjaXy0Umk6GiooKqqipKSkpoaWmhsbFR2DMlVZLb7cbv91NWVraqLOHjkE6nuXjxImfOnMHv9xMOh5mYmCAWi2EymYTlL5vNkkgkyGQyqFQqampqBPHX2NgoFF+SI+FGW0il92F4eJh4PE5jYyMbNmy4ag7dWkSbUqkURNuVPyvZcScnJ3NIIakQobm5+Zpy775IZDIZUWwwNTWFwWCgubmZXbt2sbi4iN/vp6ioSFiQJycn+eijj7Db7fj9ftRqNX19fezfv/9jCVu/38+HH37I/Pw8fr+f+vp67r//fs6cOYPNZsPn81FWVkZJSQmjo6M5j3F6elqUwEmQ/u42bNiwbi6bXH4gQ8atAZlgk3HdmJiY4MUXXxRWgLVw33338eijj4p67fHxcWZnZ0kkEuj1elQqFQ0NDRgMBtxut6jYVqlUFBcXr9nA2dXVxZ/8yZ8QDoex2WxiIFAoFKjVajweD8ePH2d4eJhEIiFaosrKyigtLSU/P18UEKRSKXw+H8FgEI1GQ2lpKfX19aKR9FoJkkAgwMmTJ7l06RLxeJzFxUUcDgepVIpwOIxSqSQQCBCLxUS+x0rbqARpoDGZTDidToLB4Lr3WVlZSW9vL52dnTf98PNlhNvt5uWXXxbtuS6XS1hvFAoFDQ0N/Mmf/Ml1tX263W5efPFFYQ+F5WF4ZmYGm82GQqGgvr5e1NPDcp7K888/L39fy5DxFYY8u8n4OPT39/Pee+/l2NlgeZF53333oVaricVinD9/XjQyut1uTCYTarWaoqIiYRvNZDIMDg7icrkoKioSZJK0zEyn05SXl/Ov//W/pr29fdVjyWQyWK1WxsbGmJiYwG63E41GKS0tFVla7e3tlJWV4XQ6xTk2nU7jcDjQarXU1dV94qVWJpNhZGSE06dPs7CwgM/nE62earWaeDyORqMR2VrJZBK1Wk1DQwNNTU1UV1eL/LimpiYxi0UiEdxu9w2zkCaTSUZHR4Vror29na6urqv+nng8Log26T2WFrxtbW1rzo1rFSIolUpqampuiUIEn8/H4OAg/f39ohl23759oihDqVSKqJpIJEJ/fz8TExM4HA4SiQTFxcXceeed9Pb2XrUEIRKJcPz4cRwOBx6Ph7KyMh599FEuX77M9PQ0Xq+XoqIi+vr6OHnyZI4te2FhgaWlJZqamnLuY+PGjdx7770Aa+ayrSw/qKurW/P6SIYMGV8sZIJNxg1BIBDgxRdfZHx8fN3btLe38+yzz6LT6RgYGODkyZPY7XYcDgdKpRK1Wo3BYBC20bm5OcbGxoDlgWQt22hTUxM/+MEPMBgMBINBbDabKBlQKBTodDpCoRBHjhxhYGBAEFU6nY7CwkLR/lRVVUVzczPpdBqXy8Xi4iLxeBy9Xk95eTmtra2fyEIai8WEYi8QCOB0OnE4HEQiEbLZLFqtVmxDU6nUmrZRCWazmcbGRtLpNBaLRQxpV0K2kH5xyGaznDlzhrfeeotIJMLS0hI+n09sjTUaDXfffTff+c531lVlfhxSqRSvv/46hw8fzjnu9XoZGxsjlUpRVFRER0eHsIBoNBqeeuopdu/efX1PUIYMGbck5NlNxrXA4XDw2muviYgLCWazmSeeeIKioiJg2Rp3/PhxxsbGmJ2dFUovhUJBXV2dCJMfGRnB5XKRn59PUVERarWaTCaD3+8nnU5TVFTEn//5n7Nx48Z1H5Pf72dsbIyhoSGmpqbwer2irKqyspK2tjY0Gg2ZTEa0xtvtduLxONXV1eIxf1JYLBZOnjzJ9PQ0DoeD6elp0uk06XQapVKJVqtFpVKh1WpJp9OiAKKxsZHS0tKcZk6ppCqbzeL3+3G73TfEQhqJREQRRV5eHt3d3bS0tFz1d1yNaGttbV1TWRcKhbBYLGsWIrS2tt7UJVzRaJSpqSnOnj3L0tISBQUFbNq0ie7ububn5wmFQpSVlVFVVYVSqcRms3HhwgXm5ubEor+9vZ177rnnqs8zmUxy8uRJrFYrLpdLNPNOT08L67Ner+euu+7izJkzOaRlOBxmenqa+vr6nNmwsrKSxx57jKKiojVz2VKpFHa7nUQiIZcfyJBxE0Im2GTcMGQyGd58801+97vfrXubwsJCnnvuOVpbW1laWuLDDz9kaGgIm82G3+8XW0ez2UxVVRWzs7P09/ej0WjQaDQUFxevOpFUV1fz53/+52KYisVi2O125ufnxRBRUFBAMpnkxIkTnDlzRmxSpVKEiooKSkpKaGxspKenh/z8fGZnZ/F6vSwtLZHJZCgqKsJsNtPW1ia2lR+HVCpFf38/p0+fxul0CkIxGAySn58vbLJqtZpQKEQqlVplG5WQn5/Phg0bKC4uZm5uThQsrAWDwUBXV5dsIf2cEQgEeOONN7h48aIoQViZaVJeXs6zzz7Lzp07P3VG2qVLl/jZz35GJBIRx2KxGMPDw4RCIfLy8ujo6KC4uFj8/65du3jqqaeuy6oqQ4aMWw/y7CbjWhGJRHjrrbewWCw5x/Pz83nkkUdobW0FltVoIyMjnDx5kpmZGRwOB2q1mry8PLRaLa2trSiVSj788EO8Xi9qtVo0hEpZZalUCpPJxPPPP8/OnTuv+riSySRTU1MMDg4yPDyMw+EgPz+fiooKqqurKS8vp6ioiMrKSurr61lcXBSExPXY6BwOB6dOnWJgYEA8z3g8TiKRID8/n/z8fLRaLQaDAYVCQWFhIdXV1TQ0NIhFqclkoqWlJSfHLJVK4fF4cLvd120hlcL9Jcvshg0bPvbvPB6PMzU1JYhDWCba6uvraWtrW/P+pYy6W60QIZ1OY7fbuXTpEhMTE+h0OpqamtizZw+xWIz5+Xny8vKor69Hr9eTSCS4fPkyIyMjYiluMBjYs2cPt99++7ozVCaT4dy5c0xOTuJyuVAqlTz00EMsLS0xNDSEx+NBpVLxwAMPYLPZOHPmjPjZVCrF5OQkBoOB6upqcVyr1fLggw/S3t7O4uKiaCKVctnk8gMZMm5eyASbjBuOixcv8rOf/SwnM2olVgaxA0xPT3Ps2DFmZ2eZm5sjFouhUqnQaDQ0NjaSyWR47733RJ27TqcTG1EJpaWl/MVf/IWQUMPyUOZwOJibmxMDgdFoRK1Wc+HCBY4cOYLP5yMej4tSBGmb2NTUxI4dOwSZNTc3JyykarWa0tJS6urqaG1tvSYLqWSbkAoRZmdnWVhYIBgMYjKZRAGC0WjE5/MRi8VW2UZXoqWlha6uLuLxOCMjI6u2zithNptFC6lsIf18MDQ0xCuvvILf7ycUCuHxeHJKELZt28af/MmffGry0+fz8ZOf/ISpqSlxLJPJ5Fg66urqaGxsFERebW0tzz//fM7fiAwZMr7ckGc3GZ8E2WyW48ePryrXAdizZw/79u0T55RQKMTJkycZHBwUmVMSOVNaWkpjY6NQghkMBmEnzcvLIxAIkEwmMRgMfOc73xHz4Mdhfn6e0dFRYVVNpVJiQVpYWEhHRwcbN25EoVBgs9lQKpXU1tbmBMp/UiwuLnLy5EnRqrq0tMTS0hKJRIKCggK0Wi1Go5HCwkJ0Oh0VFRWUlpZSWVkpSCfJTnqlvfJGWUgXFhYYGBhgcXGRsrIyNm3atO78KCGRSAiibeV8cjWi7VYtRPB4PIyNjYlituLiYvbs2YPZbMZqtRKLxaisrKSiogKFQoHL5eL8+fPYbDacTifpdJr6+nruuece6urq1lyQZrNZBgcHGRoawul0kkqluPfee1Gr1Vy8eBG32w3Avffei1ar5dChQ6JAAxDlCFKuoYTt27dzxx13CJLzylw2r9eL0+mUyw9kyLiJIBNsMj4TOJ1OfvSjH+FwONa9zW233cbTTz9Nfn4+8XicCxcucP78eRYWFnA4HCgUCqEcq6io4OjRowQCAQoLC1Gr1ZhMJgoKCsSJzmQy8ed//uerPp+ZTAan04nNZhOqH8m2cOnSJQ4fPszCwgLRaJRsNkteXh4mk4nKykoaGxu54447aGpqwul0Mjs7K1qi4vE4Op2OsrIy2traqK+vvyYL6eTkJMePH+fcuXPYbDbm5+eJxWIUFhaK+y0uLsblcokT6XrEWGFhIVu2bKGyspKZmRnGx8c/1kLa29tLQ0PDTblt/DIhHo9z6NAhjh8/TjqdFmpICUajkW9961scPHjwU70XmUyGt99+m3fffZeVX+NOp1MUfJhMphxiNT8/n6effprbbrvt+p+gDBkybnrIs5uMT4OpqSneeuutHAIAlmM5Hn30UUFYZbNZrFYrH374ITMzM9jtdmKxmDjnNDQ04HA4OH/+PEajEY1Gg16vx2QyEQqFSCaT6PV6vv71r/Pwww9fs7I7EokwNjbGqVOnGBsbE8UKWq2W2tpavva1r9HV1cXc3ByhUIiKigrKysquq107HA5z6tQp3n33XcbHx4lGo3g8HjKZDCaTCa1WS1FRkVDOSc2mVy5gpeb6laqvG2EhzWazzM7OMjQ0RDgcpra2lr6+vo8N608kElgsFiwWSw7RVldXR1tb25rk5HqFCBqNhoaGhpuyECESiWCxWDh//jx+vx+DwUBfXx+bN28WpRR6vZ76+nph/x0eHmZoaEgUGWi1WrZs2cKePXvWjfuYmpriwoULuN1uIpEIu3btoqqqinPnzuFyuUgmk9xxxx20t7fzm9/8BqvVKn42EAhgsVhobGzMccnU1tby6KOPotPp1sxlk8sPZMi4uSATbDI+M8Tjcf7+7/8+Rwp9JSorK/n+978vZNFut5ujR48KJY7P5xMnirKyMsbHx5mfn0en02EwGIRtVDqR63Q6fvCDH9DS0rLqvrLZLD6fD6vVKogOjUZDeXk5k5OT/OEPf2BmZoZIJEIqlRKlCOXl5TQ2NrJ//342b97M0tISDoeDmZkZvF4vfr9fDFhms5mOjo5rspDa7Xbef/99jhw5gtVqFVuywsJCSkpKqK+vp6SkhMXFRXQ6HXq9ft2Tpkqloqenh40bNxKLxRgaGsJut6973waDge7ubnp6emQL6WeMmZkZXnrpJRYWFojFYrjd7pwShM7OTn7wgx9QX1//qX7/yMgIP/3pT3OKMMLhMCMjI0QiEdRqtQiFlrB//34ef/zxm3LTLEOGjBsHeXaT8WmxtLTEa6+9lpMZBcvLzMcff5yamhpxLJFIcOnSJc6ePcv8/DwLCwsolUqy2awgaC5cuIBSqcRoNKJSqSgoKCCdTpNMJtHpdDz44IN84xvf+EQkmFSKcPr0ac6fPy+yT3U6Hd3d3Rw8eJDi4mLm5+fR6/XU1tZedyh8IpHg5MmTvPrqq1gsFiKRCB6PB6VSSUFBAfn5+YIY6+zspLW1lby8PPx+f87vyc/PF6UIK9Vi12shTaVSTExMMDY2Rjqdpqmpid7e3o91WqxFtEnZeu3t7euqAH0+H5OTkywsLIhjUiFCa2vrNcWpfF5IpVLYbDaR66fRaKirq2Pfvn2o1WqsViupVIrq6moxGy8tLXH+/HmsVqsoQaioqGD//v20t7evOZdL9mKPx0MgEKC3t5fe3l5OnTolZsCtW7eye/duTp06xcmTJ8WiVCqyMJlMOd/Zer2eRx55hIaGhjVz2eTyAxkybh7IBJuMzxSS3eDll18WJ+wrkZeXx7e//W22bt0KLEvQx8bGOH78OC6XC4fDQTgcJpPJoNVqWVhYwOv1kslkKCwsFBvRoqIiYS39/ve/T29v77qPKxAIiEIEWB4GKioqcDqdHD58WGxEJZurVIrQ2NjIPffcw44dO1AqlSwsLGC324U9IhgMisai+vr6a7KQejweDh06xLvvvovNZsPj8aBQKCgoKKCqqoqWlhZhbS0tLcXr9a5q+1qJyspKtm3bRkNDAxMTEwwNDckW0i8Y6XSaw4cP87vf/Y5UKrWqBCEvL4/HHnuMP/7jP/5UOWmBQICf/vSnOVXw0pAtfcZrampoamoSw2BzczPPPfdcTlabDBkyvlyQZzcZ14NUKsXvf/97Ya2ToFQquffee9myZUvO8cXFRY4fPy6WoYuLi2g0GtLpNKlUCovFQjqdRqfTodVq0Wg0qFQqstks+fn53HnnnXzve9+7anPjelhcXOTs2bN8+OGHWK1WQqEQxcXF3HbbbezYsQODwUAmk7lheVWZTIYPPviAl156CZvNRiAQwOfzodFo0Ol0FBQUiJy4LVu20NvbSyaTYXZ2NsdpoFAoqK6upqWlZVWEw9UspGVlZVd9nWKxGCMjI1gsFpRKJV1dXesSQiuRTCYF0Sbdp0S0tbW1ravculUKEbLZLG63m4mJCS5duiSuJXbt2iXIK5/Ph8lkoq6uDrVaTTabZXJyksuXL+NwOPB6vWKxfeedd645R/l8Pj788EOhSmxoaGD37t2cOXMGl8tFOBymt7eXe+65B7vdzm9+8xtBqErK0KWlJdrb23PIsn379rFr1y78fv+qXDa5/ECGjJsDMsEm43PBzMwML7zwAj6fb93b3HnnnTz55JNCVRMOh0W+h8fjYWFhgWQySSqVYmFhgUgkQiaTIS8vT2xEi4qKMBgMqFQqvve97wnSbj1Eo1FsNhsLCwtiICgrKyMSiXD06FEGBwdZWloS9yWVItTV1XH33Xdz5513otfrCYfDQtXmdDqFhTQ/P1+0kDY0NFzVQhoIBPjVr37Fb3/7W6xWK4FAQGx7GxoaaG9vFydSs9nMyMjIqo3oSmi1WjZt2sTWrVuFqk22kH6xcLlcvPzyy0xNTZFMJoWFQEJtbS3/6l/9q6uSw+shk8nw7rvv8vbbb4tNaDabZX5+HovFQiaTwWg00tXVJYYuo9HIM888Q3d39415gjJkyLipIM9uMm4ELl++zDvvvLNqUdrb28vBgwdzCIBMJsPExAQnT55kfn4eh8NBNBpFoVCwtLSEzWYTt5fmNUAQUzt37uT73//+py7lSSaT9Pf384c//IHBwUGi0ahQk3V2dlJVVYXZbMZsNt+wOeedd97hpZdeYn5+Hq/XSzgcRq1WC9toVVUV7e3t7Nixg82bNxMKhZiamlo1wxUUFIhShJWv6fVYSIPBIAMDA8zNzaHT6cSM93FKwfWIttraWtrb29cl2q5WiNDa2ipaO79ohEIhZmZmuHz5Mm63W7w2t99+O9FoVLhAamtrKSwsBJYJz/7+ftEwG4lEKCoqYu/evWzcuHGVaiwUCgmxgNvtpry8nP3799Pf34/T6WRpaYm2tjYeeOABEokEhw4dYnp6Wvz84uIiU1NTNDc3i8cAywvShx9+GGBVLptcfiBDxhcPmWCT8bkhFArx05/+lOHh4XVvs5aqxmazcezYMaE483q9pFIpsd3R6/Wk02mRv5GXl0dxcTFarZannnqKr33tax/72JLJpCgzkAYJ6aR05swZPvroI7xeL5FIhEQiIUoRJJn4/fffT2FhobChOhwOLBaLsJBKeViShbSmpmZd2XwwGOSXv/wlhw4dwmq1ihIGg8FAR0cHLS0t5Ofns3fvXsrLy+nv72dycpKr/Sk3NTWxbds2mpubRRuXbCH9YpDNZjl9+jRvvfUWsVhsVQmCSqXirrvu4plnnvnY7JS1MDExwU9+8pOcwT0YDDIyMiIKRNra2sSmXKFQ8MADD/DAAw/cFEOvDBkybhzk2U3GjYLT6eS1117LydyC5UyxJ554YtWsEI1GOXPmDAMDA7jdbpEr6/P5mJqaErEXkgshHo+j0WgwmUxs3ryZP/uzP7uucgKAsbEx3nrrLS5evEg6nRaZvk1NTWzcuJHe3t4bptxPJpP8+te/5s0338RqteL1ekkmkyiVSvLy8qisrKS6upqNGzdyxx130N3dLYg2m82Wo/pSq9XU19fT0tKSQ6zA1S2kFRUV6z4fr9fLpUuX8Hg8FBcX09fXR2Vl5TU9r+npabEYhOW5oaamhvb29nXnlFuhECGZTGK1WhkfH2dychKVSkVNTQ379u2joKBAKBNLSkqoqakRM5LNZqO/v5+FhQWcTifZbJaWlhbuuuuunCZQWCYcP/zwQ+bn53E6nRgMBg4cOMDIyIhQy0kZa1qtlnPnznHs2DHxeYjH4wwPDws76Mrc6UcffZTKyso1c9nk8gMZMr44yASbjM8VmUyGQ4cOcejQoXUJIaPRyLPPPktnZ6c4lkwmuXTpEmfOnGFxcZGFhQUCgQAzMzOinlylUqFSqTCZTEL5VVhYyGOPPcb9999/Tbke6XRaFCJI4b46nQ6NRsPQ0BDnz5/H4/EQDAZzShGKi4u54447ePDBB6mqqgKWhyDpd1mt1hwLaXFxMfX19bS3t69rIXW5XPz85z/nvffeE8OXQqHAaDTS09NDQ0MD5eXlHDx4kJKSEs6fP09/f/+qUOKVMJlMbNmyhS1btuQEuH6chbS3t5eOjg7ZQnoDEQgEeO2114RF4coShJKSEv7Fv/gX7N279xMHM4dCIX7+859z+fJlcSyZTDI2NiZUpGazmZaWFqEe6Orq4plnnrmp8lJkyJBxfZBnNxk3ErFYjF//+teMj4/nHM/Ly+Ohhx7KmdskLCwsiKZ4t9vN4uIiwWCQwcFB1Gq1ULHpdDpSqRSZTIbi4mI2btzIX/7lX94QBc7U1BSvv/46IyMj4phWq6W9vZ17772Xzs7OT2VLXQtLS0u8+eabvP/++6IEIJPJiDw6s9lMfX09u3fv5s4776S2tpZEIsHMzAwWiyWHjILlRtaWlhZqa2tXLcE+qYU0m80yNzfH4OAggUCAqqoq+vr6VpF4ayGVSgmiTVKmXQvRdrMXIkiPb3p6mqGhIeLxOCaTie3bt9Pe3i6W5hLpKSn3pNzBiYkJ5ufnCQQCGAwGtm3bxrZt23LI4XQ6zenTp7FarSwsLKBSqThw4ABWqxW73Y7X66W0tJQnn3wSo9GIw+Hg7bffFrN5JpNhenqaQCBAZ2enuGZQKpXs37+fLVu2rJnLJpcfyJDxxUAm2GR8IRgcHORv//ZvVw0SEhQKBY888ggHDhzIIRf8fj/Hjh1jYmKCxcVFXC4XFouFmZkZtFotBoMBhUKRUwpQVFTEI488wpNPPnnNREU2m8Xj8YjtFSwPA/n5+UxNTXH+/Hm8Xi+Li4s5pQhGo5EdO3bw6KOP0t7eLn5fJBLB4XAwOzsrNlbxeBytVitaSBsbG9e0kE5NTfF3f/d3vP/++yJsWKFQUFhYSE9PDzU1NWzYsIEDBw6g1+sZGhri7NmzzM3Nrfv8lEol3d3dbN26lfr6ejFwTUxMXNVC2traKsg9+UR9YzA4OMgrr7zC0tISsVgMl8uVM7xu2bKFP/3TP/3EG8hsNsvhw4d5/fXXRWZfNpvFZrMxOzsrhv2uri4xMBYVFfHcc8+tWRIiQ4aMWw/y7CbjRkNSYX/wwQerFqXbt2/nrrvuWjUfpNNpBgcHOX36ND6fD6fTicfj4fLlyySTSQwGA/n5+Wg0GpRKJfF4nIKCAjZu3Mi/+Tf/5obkdyUSCYaHhzl58iSzs7MEg0GWlpZQKBRs2LCBhx56aFXhwPVgZmaGQ4cOce7cOYaHh0URUSqVoqCggOrqajo6OrjrrrvYu3cvJpNJkD1TU1M5pQGwTAg2NTXR3Ny8Stl3NQtpRUUFJpMpZ/7NZDJYLBaGh4eJx+M0NDSwYcOGa3ruaxFtgCDarraku5kLEQKBAFarleHhYZxOJ1qtlq6uLrZu3YpCocBqtRKJRKioqMBsNovX0+VyceHCBebn55mfnxf5Z3fccQfNzc05TbEXL15kYmICp9NJKpXirrvuYmlpCavVitvtpqCggMcff5zS0lJisRjvvPMOExMT4jF6PB4mJydpbW3Ncfp0dnZy8OBBIpHIqlw2ufxAhozPHzLBJuMLg9fr5YUXXmB2dnbd2/T19fHd7343Z5iQwkalXAOfz8fQ0BBjY2MolUp0Oh06nU4UBUg5GPfddx/PPffcJyaGpMwQj8cDLA8Der0eq9XKxYsX8fl8eDweQqGQaIeUshwef/xxNm/eLGTw2WyWxcVFYSH1eDzCQlpQUIDZbKazs3OVhTSVSnHx4kVeeeUVjh07JlRICoWCkpISOjs7qa+v5+6772bnzp2o1Wrm5uY4d+4cAwMD6xZMwHIA7datW9m4cSMKhYLx8XG5hfRzRiwW4ze/+Q0nTpwgk8ng9/vFkAzL7VHf/OY3efjhhz/x53dmZoYXX3xRfH5hmageHR0lkUigVCpFGYdCoUCpVPLkk09y1113fWLlnAwZMm4uyLObjM8Ks7OzvPHGG6sWpXV1dTz++ONrKpqCwSAnTpxgdHRUhLT39/cTCoXIy8sTM5tSqSSdTqPRaOjq6uLf/bt/t8p692kg5VPNzc1hs9mEakwqA9qyZQt33nknbW1tFBUVXfc5MJ1Oc/HiRY4ePcpHH33EwMAA8XicbDZLMpmkqKiI2tpaNm7cyIEDB9i6dasgQEKhENPT06uyzBQKhSjAqqioWPUYJQupFKQP61tIpcbKiYkJstks7e3tdHV1XZN1M5VKMTMzw+TkZM7jk4jDq5Fl61ljv+hChEQigdVqxWKxMDk5CUBVVRV79uyhrKwMl8uF0+kkPz+f+vp68VpKjpCRkRGcTiderxetVktfXx+7d++mqKhI3Mfo6CiXLl0SWX27du1CpVIxNTWF2+0mLy+PRx99VHze+/v7+eCDD8SiNBqNMjg4SHFxMY2NjeL9LykpEX93V+ayyeUHMmR8vpAJNhlfKJLJJL/61a84duzYurcpKyvj+9//PnV1dTnH4/E4Z86c4eLFi4RCIQYHB7lw4YIYyiTbQX5+vlC2bd26lX/7b//tp5KjRyIRbDYbTqdT2DXz8/OZn59naGgIr9crKrnD4TDZbBa1Wk1LSwsPP/wwu3fvzhk4JTuq3W5nZmZGWEilraNUjW42m4UcPBAI8OGHH3Lo0CHOnj1LIBAgm82K5tLm5mZ6e3t54oknhAopGo1y8eJFzp07d9WSiby8PDZu3Mi2bduoqKjA7/fLFtLPGTMzM7z00kui0OPKEoS2tjZ+8IMf0Nzc/Il+bzQa5Ze//CUXLlwQx+LxOKOjo8KWWlFRQVtbm7CU3HbbbXz729+WBzEZMm5hyLObjM8SwWCQ119/fdVCzmAw8Pjjj1NfX7/qZ7LZLLOzsxw/flxkWB0/flzkhup0OqGqluaompoa/v2///d0dXXdkMctqYakdtP+/n4GBgZwOp1oNBo6OzvZunUrGzZsoKKi4rpVP6FQiA8//JBLly7R39/PyMgIyWSSdDpNJpOhpKSEhoYGdu3axT333ENXV5c4F6fTaex2O1NTU6tmOKPRKEoR1ooauVYLaSQSYXBwEKvVikajoaenh5aWlmuOVpGINmnJDMtEW3t7+1UtvvF4nOnpaWZmZtYsRKiurv7cF30SCWuz2RgZGSESiWA0Gtm2bRsdHR3EYjGsViuJRIKqqqocd4Hf7+fChQvMzc2Jco+Kigr27NlDd3e3+BxZrVbOnTsnokE2bNhAeXk5Y2Njgux96KGHxKzndDp5++23hcVWKhIJhUJ0dXWh1WqB5ey+AwcO0NXVtSqXTSq9Wlpaorq6+ppswTJkyPh0kAk2GTcFTp06xd///d+va09Uq9V885vfZPfu3av+z+VyceTIEaxWK9PT0xw7doxIJIJCochpGDUYDEI2/Vd/9VdrDn7XgkQiIQoRJGWYVqvF5XKJjKvFxUU8Hg+RSIRkMilyKu655x7uvvvuVVvHaDTK/Pw8MzMzOBwOYSHNy8sTFtKmpiZhIbXZbPzhD3/g6NGjggCT8uBKS0upqanhrrvu4utf/7o4iWazWaampjh79izj4+NXLUVoaGhg27ZtdHV1oVQqsdvtsoX0c0I6neYPf/gD7733Hul0mmAwiMfjEdtLtVrNww8/zD//5//8Ew392WyW48eP8/LLL4vPbTabZXp6Wlwc6XQ6Ojs7xea5oqKC559/Xv7OlyHjFoU8u8n4rJHJZDh8+DBnz57NOa5QKLjrrrvYsWPHmj+XSCTo7+/n/PnzLC4ucvjwYdHorlQqMRgMaDQaMpkM+fn5FBcX86/+1b9i//79N2S+kHLPJNtfPB7ngw8+4PLly/h8PlQqFWazmc2bN9PX10dtbe2nKh5aCYfDwfvvv8/s7Cznz5/HYrGQSCRIJBKoVCpKSkpob2/n7rvvZvfu3dTV1eUQZ1KrpM1mEzMBIOx/ra2tOWopCddqIfX7/Vy+fBmn04nRaKSvr4+ampprem7rEW1VVVV0dHRclWhLp9NYrVampqZyloo6nY6Wlhbq6+s/90IEv9+PzWZjfHychYUF1Go1nZ2dbNu2Da1Wy/z8PB6PB6PRSH19vZjHJJfNwMCAULypVCra29v52te+JoolXC4XJ0+exOfz4fV6aWpqoqOjg6GhIdxuN6lUinvvvZeenh5g+fP63nvv5RTFLSwsYLFYaG1tzXGTbNy4kXvuuUeQqytz2aTyg7KyMlGIIEOGjBsLmWCTcdPAbrfzox/9SGxv1sLu3bt56qmnVhEL6XSaoaEhTp48yczMDIcPH2ZpaYl0Oo1KpUKr1aLT6UROm9ls5plnnuH222//1AqddDrN/Pw8drudWCwGLJMfXq+XyclJAoGAGGikUgRYVuTt2rWL+++/f82to9/vZ25uTtgWJAup0WiksrKSrq4uamtrMRgMDAwM8N5773H27FnRqprJZNDr9aIa/utf/zoPP/xwznDi9/s5f/48Fy5cyBlmroTRaGTLli3cfvvtmEwmEonEJ7KQ9vb2UlJS8qle3686nE4nL7/8MhaLhXQ6jdfrzVESms1m/vRP/5TNmzd/ot9rt9t58cUXczJQvF4vY2NjIkuwqalJbI41Gg3f/OY32bVr1w17bjJkyPh8IM9uMj4vjIyM8Jvf/CZHiQTQ0dHBQw89JFQ2V8Lr9XL8+HEmJyc5evQoFouFVCpFNptFo9GIGU0qlHryySd55JFHbshsIamV3G43JpOJ2tparFYrR48eZWpqikAgQCKRwGg0ipmmpaVlzbzcT3KfAwMDfPjhhzgcDi5evMj8/LxYyEqL0g0bNnDfffexadMmKioqcqJSVpYiSE2iEkpKSkQpwlrFDddiIV1YWGBgYIDFxUXKysrYuHHjNceBpNNpZmdnmZiYyCHazGYzHR0dV1VOSSqrqampVYUIjY2NNDU1fa5OiXg8zuzsLHa7nYmJCdLpNBUVFezdu5fKykqCwaCwudbU1OTkokUiES5cuIDNZhMlCMXFxWzfvp1Nmzah0+lYWlri+PHjeL1e3G43FRUVbNmyhcHBQbxeL9FolH379rF161bxewcGBvj9738vFqXhcFhYRpuamsTnsrKyksceewxgVS6bXH4gQ8ZnC5lgk3FTIRKJ8LOf/YxLly6te5u6ujq+//3vr5nRIMnwz5w5w+HDh1lcXBTElkqlQq/Xo9VqRZPTk08+yaZNm2hsbLyuYcntdmOz2USIrVKpxO/3Mzk5SSgUIhgM4vP58Pl8RCIRkbm2efNm7rvvPjo6OlZtHTOZDC6XC5vNxvT0dI6FtKioiPr6epFzceHCBX7/+98zPDyM1+vF7/eTyWQoKCjAaDRSV1fHd7/7Xb72ta/lPM9UKsXw8DBnz57FZrOt+xyVSiUdHR1s27aNpqYmsekcGhpieHhYtpB+Rshms5w6dYq3336bWCxGNBrF7XbnlCB87Wtf4/nnn/9E4cDxeJx//Md/5NSpU+JYLBZjeHhYDOulpaW0t7cLMns9cluGDBk3L+TZTcbnCY/Hw2uvvZaT+QnLpM8TTzyxrmImk8kwPj7OsWPHeP/99xkfHycajZJOp8WiR1qWlpSUcODAAfbv3097e/u6xN0ngWQZVSqVNDY2otPpuHTpEh999BFer5dgMIjb7RbNlz09PXR0dGA2mz/1XBOLxThx4gQfffQRNpuN0dFRETOSTqcxGAyUlpaybds27rvvPhobGykrK6OgoEA4ILLZLC6Xi8nJSVGCJUEiVJqbm4Xl9kqsZyGtqKigpKQEu93O0NAQ4XCY2tpa+vr6rlnFJxFtk5OTYgkNyzNhe3v7mkq7lfB6vavKHpRKJbW1tbS0tHxuhQiZTIa5uTnm5+cZHx9naWkJg8HAli1b6O7uFs2sfr9fZOqtJDatVisfffQRbreb+fl50uk0TU1N7N27l4aGBuLxOMePH8ftdgvl4K5du8TnIRQKcfvtt/O1r31NvO9er5e33npL/J2lUinGxsaIRCJ0dXUJUlqr1fLggw9SV1e3KpctHo9jtVrl8gMZMj4DyASbjJsO2WyW9957jzfeeGNdG6Ner+e73/0ufX19a/6/1Wrl0KFDvPbaa6LpMx6Po1QqycvLQ6/Xk5+fT2VlJU8++SR1dXV0d3df90Z0cXERm82Wk5OxtLTE9PQ0oVCIWCzG4uIibrebcDhMPB5Hp9PR3d3N3r172bRpE1VVVau2jvF4HIfDwczMDHNzc8JCqtFohIXUYDDQ39/PiRMnBNm3uLhINpulsLAQrVZLR0cH3/ve99iyZcuq+5ifn+fcuXOi1Ws9lJWVsXXrVjZt2kR+fr5ophwaGmJ8fHzdQgXJQtrb20tDQ4McoP8JEAgEeO2117h06ZIoypDeW1jOK/ne9773iYsJTp8+zT/8wz+ILXMmk2FqakoM6lKLlmTtqK2t5fnnn5dtBTJk3CKQZzcZnzcSiQSHDh3KsbLBssL/gQceoLe3d92fjUajIjJkcHCQWCxGLBYT5zqpZKqkpIT9+/ezY8cOWlpaqKuru+6ZQlKFRaNRkVsVDAYFAZZKpcTSE6C8vJyOjg7a29tpaGigsLDwUz0Gt9vN+++/z9TUFBaLBbvdLmyDSqVSEF779u1j3759VFRUUFpaSnFxcc7CNBwOi1KElcoxQJQiSGVGV+JqFtKSkhKcTicTExOkUimam5vp6elZM/NtLWQyGaFoW0m0VVZWrrlcvhLrFSJUVlbS0tLyuRUi+Hw+7HY709PTzM3NoVKpaGtrY/v27RgMBhYXF5mbm0OpVFJXV5dDACYSCS5dusTU1BQulwuv14vBYGDz5s1s27YNvV7PyZMncTgcuFwulEole/bsYWZmBo/Hw9LSEl1dXdx///3iPU8mk7z//vs5ggTJ/dLW1paTDbd9+3b27NmDzWbLyWWTyw9kyPhsIBNsMm5ajI2N8eKLLwpV2Fo4ePAgDz300Jrqs2QyyfHjx/lv/+2/CdVPKBQS+R6SXbS4uJinnnqKyspKqqur6ezsvObBYT2Ew2FRiJDNZslmswSDQWZmZsR2cmlpSZw4I5EIeXl5NDc3s3XrVrZs2UJTU9OqKnZYJuwcDoc4UUsWUoPBQEVFBXl5eQwMDDA4OEggECAUCrG4uIhCocBkMpGXl8fmzZv54z/+YzZu3LjqucZiMT766CPOnj2L1+td9zlqNBr6+vrYtm0bZrMZ4JotpJLdoqenR7aQfgIMDAzw6quvsrS0RDKZxOVyCYUmLOdu/OAHPxDvx7VgYWGBF198Mef9koZpqcyjsbGR2tpaUezx9NNPc9ttt93Q5yZDhowbD3l2k/FF4dy5c/zhD3/IIUVguanznnvuWdO+KGF+fp6/+Zu/4ejRo6RSKaLRqCCN0uk0Op2O8vJy9u7dy759+zCZTHR2dl53cLukRvJ4PJhMJpHVOzc3h9VqJRgMkkwmcTgcWK1WYrEYJpOJhoYGWlpaaG9vp7y8/FPlhY2OjnLs2DGRRRYIBPB4PLjdbrRaLQUFBdTU1HDHHXewbds2ioqKKCkpobS0NOf+MpmMKEW4coYzGAw0NzcLq+BaWM9CWlBQgM/nEyRSZ2cn7e3tV30fVyKTyWC1WpmYmMiZWyorK2lvb8+xV66F9QoRioqKaGlp+VwKEaLRKFarlfn5edGeWl5ezu7du6muriaRSAgSq6ysjKqqqpzrE6fTyYULF3C73TgcDmKxGLW1tezevZvW1lY++ugjpqencblcpFIpdu3ahcvlEsRnY2MjDz/8cM57NzIywu9+9zvxmgQCAYaGhigqKqK1tVXcf21tLY888gihUCgnl02hUMjlBzJk3GDIBJuMmxp+v58f//jHTE1NrXubrq4unnnmmXXl4vPz8/zbf/tvGRsbI51OE41GRQmCSqVCp9NRWFjI17/+dVpbW1Gr1bS3t9+QjWg8Hsdut+NwOEin04Jom52dZWlpiWw2SzgcFtvKcDiMQqGgrq6Ovr4+NmzYQHd3N2VlZaseSyaTwe12Y7fbsVgswsagUCgwGAyEw2GmpqaYm5sjk8kQiUTw+/0iPNhgMHD77bdz3333sXnz5lVknhSAf/bsWUZHR69ailBXV8e2bdvo7u4Wg961WkirqqqE3UK2kH48YrEYv/nNbzhx4oT4PK0sQcjPz+eP//iPeeKJJ67Z9pxMJnnllVc4evSoOBYOh0WDFkBxcTEdHR1isLv77rt5/PHHr3m4liFDxucPeXaT8UVibm6O1157bdWitLq6mieeeOKqwfepVIpf/vKX/K//9b9IJBIkk0lBtKVSKTQaDaWlpezcuZODBw+iVCqpqamhtbX1uu1uUsC9ZBk1GAx4vV4WFhZIJBLCjeDxeLDb7Xi9XvR6PZWVlSK+Q8rK/SRIJpOcOXOGc+fOMTs7K9oqJdWTlK/b0NDA3r176e3tpaioiKKiIkpLS1fNUH6/n6mpKdGYKkFSWbW0tFx1yRmJRHC5XHg8HuFsUCqVwrpoMBg+sSthPaKtoqKCjo6OjyXavuhCBKnV1e12Mzk5ic/nQ6fTsXnzZnp7e1Gr1cIOmpeXR0NDQ44yTMqMHhsbE0RmXl4e3d3d7Ny5E5fLJeJewuEwW7duJRwOi8+A2Wzmsccey5nZFxcXefvtt3E6ncDy52hkZIRYLEZXV5e4rV6v55FHHqGwsBCbzYZWqxVkq1x+IEPGjYNMsMm46ZFOp3nttdc4fPjwurcpLi7mueeeE5XWVyKRSPCf/tN/4oMPPiCRSJBKpQiFQiQSCWEbNRqNHDx4kG3btqFUKiksLKS7u/uGbHNSqZQoRJC2sH6/H7vdLtRlyWQSv9+Py+UiGAySSqWorq4WFoSNGzeum5OQSCSYn58XjZCShTSRSOBwOJibmyMcDpOXl0c8HicUCqFQKNBqtZSXl7N582a2b9/O7bffvubzXVpa4sKFC1y4cGFVoO5KGAwGbrvtNm6//XYh+5ctpJ8Npqenefnll1lYWFizBKGpqYk/+7M/o62t7Zp/Z39/Pz//+c+FjSOVSjExMSGKR/Ly8ujs7BTvbUtLC88+++zHDsQyZMj4YiDPbjK+aITDYd58801mZmZyjut0Oh599NF15zYJx44d4z//5//M4uIimUyGeDxOOBwWip3CwkJuv/12nnjiCbRaLRqNhra2Nqqqqq5rjojH48zMzBCLxYSlLhaLCbtoQUEBLpeLhYUFotEodrudhYUFNBoNJpOJ6upqWltbaWlpWWXn/DgsLi5y9OhRhoeHsVgseDweUqkUCwsLBINB9Hq9yEndvHkz7e3tlJSUUFBQQFlZ2aqctGQyKUoRriQ7i4uLhc12vYWZZCF1uVwiniIQCOByuUin09TU1LBhw4ZPpJ7PZDLYbDYmJiZyiDLJevtx7gapEGFychK/3y+Of16FCG63m4WFBWw2G1arFYVCQXNzMzt27KCgoIBYLMbs7CzxeJzKykoqKipyPo9S2ZjT6WR+fp5gMEhFRQXbtm3DaDQyNDTE4uIifr+fnp4e1Go1TqcTt9stRAErCep0Os2RI0e4cOGCeH2sViuzs7O0traK9lKAffv2sXnzZmZnZ3Ny2eTyAxkybgxkgk3GLYPz58/zi1/8YlW2hASVSsUf/dEf5QSBrkQmk+EnP/kJv/71r/F4PDmD2krb6J133snevXvFiau+vp62trYbEgC6srhAkt5LraFS3kYmkyEYDOL1egVRVllZSVNTEw0NDfT29tLe3r7u5lc6QU5OTgoLqdPpZHJyUtSF63Q6MpkMiUSCbDaLWq2msbGRtrY2ent72bFjx5q5Ful0mpGREc6ePcvs7Oy6z1OhUNDe3s62bdtoaWkR74dsIb2xSKVSHD58mPfee0+oM10ul9g0q1QqDh48yHe+851rDoL2eDy8+OKL4mJIGmItFouwjNbX11NfX49CocBoNPLMM8/Q3d39WT1NGTJkfErIs5uMmwHZbJajR49y4sSJVf+3b98+9uzZc1UybHh4mP/yX/6LICwk22gwGCSbzZKfn097ezvf/va3xcxQVFREZ2fnNYfyr4VMJoPD4cixjCqVShYWFvD5fJSUlGA0GpmensbhcIicNinMXqfTCcVZZ2cnVVVVnyiCxGKxcOTIESwWC1NTU0SjURKJBAsLCySTSfLz88UitrW1lba2NkpLS9Hr9ZSVla2ZC+dyuZiamsLhcOQ4E/Ly8mhsbKS5ufmqr9lKC6lkN5TUfu3t7WzduvUTLaYlS+v4+PinItrgiytEkOJgPB4Pk5OTRCIRSktL2bVrF3V1dWSzWRYWFnC5XBgMBurr63Pe/2w2y8TEBIODg/h8Pubn51EoFLS1tdHc3IzVahVW4aamJkpKSpifn8ftdpOfn88TTzyRk7UGMDExwTvvvCMWpYuLiwwPD1NUVERbW5sgUZubmzl48KB4HyUSWS4/kCHj+iETbDJuKczPz/PCCy+saktaiW3btvGtb31rTUIhm83yyiuv8Otf/xqHwyHItZW2UaVSyaZNm9i/fz81NTWo1Wqh3Kmurr5hz8Xn82Gz2UQV+dLSEnNzc7hcLiFvl2wILpeLSCRCSUkJ9fX1YjMqbQzX2jJlMhlhX5Dy2kZGRsRWTFLt5eXloVAoiMfjYkg1m820trayY8cOampq1hx8nU4n586d49KlSzl5GFeipKSErVu3snnz5hyZvGQhHRoaumrOXlVVlWghvRFtYV9GOJ1OXn75ZSwWC9lsFp/Ph9/vF8NzeXk5//Jf/sucqverIZVK8eabb/L73/9eHAsEAoyOjoqhrbCwkM7OTrRaLQqFggcffFDYdGTIkHFzQJ7dZNxMmJiY4K233lq1KG1paeGRRx65asj69PQ0f/3Xf83s7KyIRUgkEiwtLZFIJMjLy8NsNvONb3yDtrY21Gq1WAg1NzdfV5yB3+///7P3nt9t3VfW/0ZvRCEBNgAECAJEYS8iVezEcXdsJ5Flp9gp9sSxnDgzjtd61vwZz1pOnsyM45KejBxbdizZcpUlWRJFkZQo9goWgA29d+De3wv+7jeCWETJTbbvZ628CIh2LyDfg3PO3ps0HRjJaDQaxfLyMgQCAWpqalAoFEhIAUVRiEQi8Hg8SCQSEAgEEAqF0Ov1sNlsMBqN28pjL6dQKODChQvo6enBwsICkXumUilSPzJSv7q6OlK/VVRUQCwWk0CEK48/lUphbm4O8/PzReEDwL8CBK62BchISL1eLxYXFzE/Pw8ej4fGxkbcfPPN19TcZBptMzMzZAANrAdr2Ww2qNXqqz7H5xGIkM/nSbjZwsICPB4PJBIJWltb0dzcDIFAgEQiAZfLRZQpVx5LIpHAxYsXsby8DI/Hg2AwSDYLE4kE8dwtLy+HXq/H2toafD4fuFwu9u/fj5qamqLni0ajOHr0KJaXlwGsb2OOj48jk8nA4XCQz0WhUOA73/kOuFxukS8bozxhww9YWK4PtsHG8oUjk8ngz3/+MwYGBra8j1arxc9//vOilWgGmqZx7Ngx/POf/0QoFMLa2hry+TyRjeZyOXA4HJjNZuzbtw/19fVkglZWVoaGhoaPNRG9klgsBrfbDZ/PR9bul5aW4PF4yOSICUXw+XyIRqMoKSmBwWCAWq2GTqdDa2sr6urqtlyHz+VyWF1dxcLCAiYnJ9Hb20saW0yqqkKhgEwmQzweh0KhgNVqhUwmQ01NDfbt2weTybRp8ySTyWBoaAh9fX1ESrgZfD4fLS0t6OrqKmpUMhfy0dFRklK1GUxiU2NjIysh3QSapnHu3DkcOXIE6XQa2WwWPp+vyOPkpptuwlNPPbXj6fLIyAj+8Ic/kGI3l8thamqKpOQKBIKiCXNDQwN++tOffmrTYhYWlmuDrd1YbjRCoRAOHz5M/KIYlEolHnzwQVRXV2/52JWVFfz617/G2toaVlZWEI/HUSgUEIvFkEgkwOfzoVKpcOutt2L37t1QqVREnWCz2T6Wt9RmktFcLoelpSUkk0lUV1ejrKwM6XQa8/PzpBGWy+VIo42iKNA0DY1GA7PZTNIed9L8i8ViOH36NC5dukRko4VCAfF4HJlMBjwejwxJKysrycaSVquFQCBAaWkpNBrNho0kiqKwvLwMp9MJv99f9DepVEpCEbYbcDIS0uXlZZK6CgAtLS3Ys2fPjppjlz8Xs9F2eaNNrVbDZrPtqEnGBCLMz8+TjX7g0wtEoGkaXq8XHo8Ha2trWFhYAEVRMJlM2LNnD5RKJQqFAlZWVhAMBqFQKFBTU7PBK87lcuHSpUsIhUJYWVlBNptFVVUVCZhi0ketVivW1taIdPi+++7bYAdCURTOnDmD3t5e8v/n5+fhdrvJOQDWN/1uv/12mM3mIl82gUDAhh+wsFwnbION5QsJTdM4efIk/vGPf2xIqGIQiUR47LHHtkw7PHnyJA4dOoRMJkMuIoxslPEZq6qqQnt7OywWC2pra8m2jslkgtls/kQN3tPpNJaWloi0IB6Pw+12E6NU5riZUAS/3w+JRAK9Xg+lUgmNRgOHw4GGhgaUlZVtWTzE43GsrKygp6cH7733HjFCZYrQiooKlJaWIh6Po6KighQB1dXV2LNnD2w226YGsjRNY3FxEX19fZiYmNjycwEAnU6H7u5uYgjLkM1mMTU1hbGxMTJ52wxWQro1kUgEhw8fxvDwMAlBCAQCxOC4pKQEjz32GO6+++4dFZihUAgvvfQSZmZmAPyrIbq4uEg25PR6PWpra8HlcqFSqXDw4EGYzeZP7yBZWFh2BFu7sdyI5PN5vPPOOxgaGiq6ncfjkeClrfD7/Xj22WeJH5jH40E2m0UqlSJJ8SUlJWhvb0dLSwscDgdpDqnVatjt9uveyLlSMmo0Gonpv9frJUmfPB4P2WwWCwsLWFhYQD6fB03TCIVCpMbM5XKQSCQwGo0kFGEn72tpaQknTpzA1NRUkWw0Go1CJBKBz+dDJpPBZrOhpKQEUqmUeKwJhUIoFApoNJpNXysSiWBubg6Li4tFw87L5ZZXa5bl83ksLi6it7cXCwsLJDiss7OTbNXtBCbRdXp6usj791oabVsFIjCNw086EIGp2yORCHnfZWVl2LNnDxkMRyIRYpHC1O+Xwwyt5+fnEQgE4PP5iNKkrKwMqVSKbAky226pVAq333472traNrynhYUFvPnmm+T4/X4/JiYmoFKpYLVayfHb7Xbceuut5DcI48vGhh+wsFw7bION5QuN0+nE888/X2RweiXbpR329/fj97//PWloMbHZTOpmKpUi6+lVVVVobGwkRqUSiQQOh+MTv+AwEfDLy8ukaJyfnyeNNqYpksvlEAqF4PP5IBAIUF5ejrKyMshkMtTW1qKtrQ16vX7L4oGZuB0/fhyvvPIK5ufniWxDIBCgqqoKGo0GFEWhqqqKNO00Gg26urrQ3Ny85UQzFovhwoULGBgY2Fb+KZVK0d7ejq6urg1G+ayE9OMxMjKCV199FZFIBPl8nqTMMjQ2NuLpp5+GTqe76nNRFIU333wTx44dI021cDiMyclJIg+Wy+VwOBwQi8Xgcrl46KGHcNttt7GbhiwsnyNs7cZyI3Pp0iW88847RQmXwPrm0z333LOl/1M0GsWvf/1rLC0tkQ2xYDCIdDqNVCqFXC4HkUiEhoYG6HQ6dHR0kK0lLpdLPG2v19IgFArB7XaDz+fDaDRCJpMhmUwSeWhNTQ1JEGUaTvPz8+R6yWyZR6NRJBIJ0DSN6urqolCE7a6dFEVheHgYp0+fLkoJZerW0tJSiEQiqFQq6PV6CAQCspnEpEbKZDJoNJpNN85zuRxpTF2ZAs9sgW22gXUly8vLOHPmDObn5yEUCmE2m2E2m1FRUQG1Wr2jIfVWjbaysjLYbLYNHmRbPcd2gQh1dXWfWP3InLtYLIbFxUUsLy9DLBajqakJbW1tEAqFyOfz5D5lZWXQ6XQbvosejwcXLlwg22yxWAyZTAalpaUQi8WgaRpNTU0IhUIIBoOIx+PYt28f9u3bt+E9xeNxvPXWW8Q7OZVKYXR0FPl8Hg6Hg3wHysrK8O1vfxvJZLLIl43xdpbJZJu+VxYWlmLYBhvLF55YLIYXX3wRk5OTW97HYrHgiSeeIOmHlzM6OornnnsOuVyOhBD4fD5QFEVkozKZDPX19ZDJZLBYLEQ+CaxHizscjk/co4CiKHg8HrjdbiSTSWQyGczNzWFlZQUCgYBc4C4PRaAoChqNBqWlpeDz+aisrERLSwusVuu2stZwOIzXX38d//znP7G4uEjOBY/HQ2VlJUpLS6FSqYqmjyqVCu3t7ejo6Ngyir5QKGBqagp9fX2Yn5/f8vU5HA4sFgu6u7thsViKLt6shPT6SafTOHr0KDGWTiaT8Pl8RDIhFArx4IMP4gc/+MGOCt3JyUm89NJLpODOZDKYnJxEJBIBADKpZibLHR0dePTRRz/VJC8WFpatYWs3lhudtbU1Mgy6nIqKCjz00ENbplQnk0n89re/hdPpBLDuY7W8vIxwOIxcLkeadg6HgzSaOjo6SC0klUpht9uvewt+M8koIwOMRCIoLy8vSo4sFApwu92Ym5sjnmcURREfYK/Xi0wmA6VSidraWtjtdlRVVW1rMp9KpXD27Fn09fXB6XSSbfVoNAoOh4OKigoolUqUlpZCoVCAw+FAIBDAYDDAYrFALBZDJBJBo9FApVJtWjf5fD44nU4sLy8XhSIIBAIYjcYdhQisra2hp6eH+NjV19ejtLQUarUa5eXlO5If0jSNlZUVTE9PFw0Ly8rKYLVadzzsDgQCJPSL4ZMORGDCDXw+HwKBAGZmZkBRFAwGA/bu3Uu+04FAACsrK+Dz+TAYDBtq6UKhgLGxMUxPTyMUCpGAA+Yz4/P5aGhoQDKZRDgcRjgcRmtrK+64444NnyVN0+jt7cWZM2dA0zQKhQJmZ2exurqK2tpa6PV6cDgc8Pl83H333dBoNEW+bEzjkA0/YGG5OmyDjeVLAUVROHLkCN5+++0t76NQKPDEE0/AarVu+Nvs7Cx++9vfEr+qTCaD5eVlxGIxkrbJ5XJhsVggFAqhVCrR0tICvV4PHo8HHo8Hs9lMZHKfJIxpvcvlQiQSQS6Xw9zcHNxuNwQCQVFjJJ1OkymuRqMhBRWzCt7c3IzKysotm0+rq6s4dOgQjh8/Dq/Xi3Q6jXw+D4FAgLKyMpKipVKpyMVVJpOhtbUVu3bt2rSByeDz+dDf349Lly5tmQQLrEfG79q1Cx0dHZBKpUV/YyWk18f8/DwOHToEj8ezaQhCTU0Nnn76aTgcjqs+VzQaxe9//3tMTEwAWP+3t7CwUJQKq9VqUVdXBy6Xi4qKCjz55JPsdYOF5XOArd1YvgikUikcOXIEs7OzRbeLRCJ8+9vf3rRuA9Zrgt/97ncYHR0FsH49CgQCcLvdxPqCpmnU1taSzZ+GhoYiq4vKykpYrdbr2mBi/MsCgQCUSiUMBgN4PB7C4TBWVlYgEomINPPKx1wpWxSJREXbT3w+HzU1NbDZbDAYDNs2frxeLz788EOMjo5idnYW6XQa6XQasVgMJSUl0Ov10Gq1kMvlZMDGNJVsNhukUin4fD7KyspQVla26WYao6aYn58v8nYF1puhV/M2oygKi4uLGB4extraGiQSCXQ6HaRSKUQiEcrLy1FeXn7VgdxWjbbS0tJr8tmLxWJwOp1k65ChsrISFovlmnzjtiIajcLtdiORSJDtOZVKhd27d6Ouro4EjLlcLiSTSVRUVBDPtcsJh8MYGBiAz+eDx+PB7OwsUqkUKioqIJVK0djYCJqmEYlEEAwGUV9fj/vvv3/Twanb7cbRo0fJNuDa2hqmpqbI7wTmu9rW1oZdu3ZhdXWVbD/yeDw2/ICFZQewDTaWLxXDw8P4/e9/v+Hiz8DlcvHAAw/gzjvv3HABc7vd+PWvf00u2Ixp6+rqKtno4nA4MBqNEAgE4PP50Gq1aG1tJY0lprHzaTV1mIu1z+dDoVAgaUlcLrdomsSEIqRSKahUKkilUuKxptfr0d7eTjzlroSiKIyOjuLll1/G4OAgYrEYkskkSelSKpVQqVQwGAwQiUQQi8XEXNdut2P37t2bhkswZLNZDA8Po6+vb4PJ8eXw+Xw0Njaiu7t70yRTJnr8ahJSrVaLxsbGr7yENJ/P44MPPsD777+PQqGATCYDn89HJulcLhd33XUXHn/88asWTTRN491338Ubb7xBClO/34/p6WmyYVhSUkI2OwUCAR555JFNpQssLCyfHmztxvJFgaZp9PT04OTJkxv+tm/fPtxyyy2bDjDz+Tz++Mc/or+/n9zGbJcFAgEyBDUYDCS5s7S0FN3d3SgvLweHwyFD0pqamuvafr9cMlpbWwupVIpsNgu3241MJgOtVrthAMnIFp1OZ1ENwwxG3W43CeHSaDREPaHRaLYc5I6Pj+PkyZOYnp4mstFEIoFUKoXy8nKYzWY0NDQgl8shEAgAWFcQVFVVwWazkffIbJdtVSMy79vr9Rb9TSKRkFCErRpl+Xwes7OzmJqaQjweh0qlKqqZ5XL5jiSkzPmbmpra0Ghjgh52AqMOWVhY2BCIYLFYrpqkejWy2SwWFxfJhqXL5YJQKERDQwM6OjogEomKQhLEYjEMBsOG80fTNGZmZjA6OopoNIqJiQm4XC6oVCoolUo0NTVBKpUiFovB7/dDr9dj//79m36GqVQKx44dK9r+ZCSjdrudfA8qKytxzz33IBKJEF+2kpISNvyAheUqsA02li8dPp8Pv/vd70iK0Wa0tbXhscce29BI8Hq9ePbZZ0nhAaDI34OmaVJAAesXPLFYDIfDgfr6enIh02q1sNvtRVPLT5JUKlVUfDHFAYCiizITipDJZIjZLdMo1Gg0aGxsRENDw6YSjHQ6jVOnTuG1114jRWIqlUI6nYZEIoFEIkFFRQVJ0pLJZJDJZBAKhbBYLNi9e/e2xSoj/ezr68P4+PgGD5bLqa6uRnd3N4k83+x5riYh5fP5sFgsX3kJqcfjwaFDhzA/P09Saxl5MbBuIHzw4EHcdNNNV30up9OJF154AaFQCMD693JiYoJMRnk8HiwWCyl0b7rpJjz88MOstICF5TOCrd1YvmjMz8/jn//8Z9F2FwAYjUY88MADm1pS0DSNQ4cOFTXnaJomNhsURUEoFJI6rVAogMvlwmg0oquri9Rucrkcdrv9upoGl0tGtVotysvLyXvw+/1QqVTQarUbmmNMc8XpdBb5g6lUKkgkEvh8PszOziKZTEIikaC2tpaEImzWxMpms+jt7UVPTw9mZmYQCASQz+eJtYNWq0VbWxs6OjqwtrZWVCur1Wo4HA6Ul5eDoigoFAqo1eotbUCi0SgJRbi8OcXhcIjkcqswgnQ6jcnJSTidTnC5XFRXV0OhUCAajYKmaXC53B1JSBk55tTUVJFfHLORVVVVteVjL4fxRZubm9sQiGA2m8l24vXANCUZL1zGv7ampgZ79+4l23LJZBIulwvZbBZarXbTc5dIJHDhwgWsrq5iYWEBg4ODJNiiqakJer0e8XgcPp8ParUaDz300Jb/ZgYGBnDq1CliiTM1NQWv1wuj0QiDwQAOhwORSIR77rkHIpGoyJeNDT9gYdkatsHG8qUkl8vh73//O3p6era8z1bStXA4jGeffRarq6tFtycSCaysrCCZTEIkEsHhcCCdTiOXyxEZZmdnJ1nvZvyornciutPjXF5eJoEIi4uLmJ2dRaFQgFQqLXrdXC6HdDoNuVwOiURCNpdkMhnq6urQ1tZG0q8ux+/347XXXsOJEycQi8VQKBTIcUskEmLyKxaLEQqFUFJSQrbmDAYDdu/evcFX7Uri8TguXryIgYGBDT4slyMWi0kowmbr+5lMBtPT0zuSkDLNxa+ihJTZFDh69CiRAfv9/iID4d27d+Opp566qkwikUjgT3/6E0mDoygKTqez6N8PI7ng8XjQ6/V48skn2YKMheUzgK3dWL6IRKNRHD58GCsrK0W3y+VyHDhwYNPvMk3TOHLkCI4dO1Z0O5PsGI1Gyaa9SqUiNglisRhtbW0wm82kTtHpdLBYLNc8DNpKMhqPx7G0tAQul7uttM7v9xMvNQaFQoHKykqytR8MBsnWmdVqhdlshlKp3FBnBoNBnDx5EhcvXiSyUSZplanPvv71r2P37t0YGRnB9PQ0eaxKpYLdbodOp0M+n4dEIimyHbkSpjnldDo31HAKhQJmsxlGo3FT6WksFsPo6CgJA7DZbJDJZPD7/UgkEgCwIwkp02ibnp4ueg9KpZIEle0ERoJ6ZcNTIBCQkIjrVUOEw2EsLy8jnU5jdnaWfE+6urpIncw04/x+P0pKSmAwGDb9Hi4uLuLSpUvwer3o6elBLBaDQCCA2WxGc3Mz8vk8fD4fpFIpvvvd727pZbi6uoojR46Qc7a8vIyZmRkoFArY7XZyrN3d3bDZbAgEAsSXjQn2YMMPWFiKYRtsLF9qzp49i7///e9bbjUJBAL86Ec/wp49e4puTyQS+M1vfkO2whgYfw+PxwMej4fOzk7kcjlEo1EUCgUiQ2hrayNGukqlEg0NDZ/qGnWhUCCT2lQqheXlZUxOTiKTyUAulxcVRIynXGlpKYRCIYm15/P5qKqqQmtrK6xW6wb/s4mJCfzv//4vhoaGUCgUQNM0crkcBAJBUUJVJpOBx+MBl8slk0e9Xo/du3fDbrdvW7BSFIXp6Wli2LsdZrMZ3d3dsFqtm17UQ6EQxsbGMD4+zkpItyASieDVV1/FyMgIgPXvvc/nI/9eZDIZfvjDH+L+++/ftnCiaRonTpzAq6++SjYRPR4PMfYF1qfADocDMpkMYrEYjz32GNrb2z/lI2Rh+WrD1m4sX1QKhQI++OADDAwMFN3O5XJx++23o7u7e9PHHT9+HP/4xz+Kbkun0/D7/QgEAqBpGlarFSaTCT6fj2wrqdVq7N69mySmCwQC1NfXX5dEcDPJaD6fx/LyMuLxOCorK7fc7GIef6UEs6SkBEajEZlMBhMTE1hcXEQ2m4VKpSoKRbiyieV0OnH8+HGMj4/D5XKR8K5UKgWNRgOr1Ypvf/vbcDgcOH/+PMbGxsh1nLluG41GUBQFgUAAtVqN0tLSLWsCJkRgeXm5yNuMGcaazWYi1b0cv9+PkZERYqzf0tICuVwOn89XVJcoFAqUl5dvKyFlNtqubLQxG207/Ty3CkSoqalBXV3ddQUipNNpuFwupFIpeDwezM3Ngc/nw+FwoKOjgzRfY7EYXC4XaJqGXq/f1OM4k8lgaGgIMzMzGBgYwPLyMpEkM79DAoEA+Hw+Dhw4sGWTMZPJ4J133sHU1BSA9QY38z2w2WxkEK3X63HLLbcgHA6Tmp+maTb8gIXlCtgGG8uXHpfLheeee65oInglX//61/H973+/qDDJZDL47//+703TSbPZLFZXVxGPx7F7924IhUIEg0Ekk0lQFAWxWIzW1lbU1dWR5zQYDKivr/9ULz40TcPv95Np7draGiYmJpBIJCCXyzcUI/l8HhUVFeDz+YhGoyR8gJleNjc3E48S5rhPnz6NV155pcjUHlgveoRCIeRyOWpqauD3+7G4uIhMJgO1Wg2NRgODwUCknlczsg0EAujv78fg4CDZttsMpVJJQhE2S0plLv5jY2OshHQLhoeH8eqrryIajYKiKASDQUQiERKCYLPZ8PTTT8NoNG77PIuLi3jhhRfg8/kArDfsJiYmyA8YLpcLs9lMCtw77rgDBw4cuG7ZBQsLy/awtRvLF53R0VEcO3asSIIIrKeD3n///ZtacZw7dw5//vOfixo8mUwGkUgE8XgcyWSS+NFGIhEEAgFkMhkSZtXW1kYGbkw9tF0S+2ak02ksLCyQhkd5eTmA9dpmbW0NMpkMer1+060uhmg0umEjXCKRwGw2QygUYnp6GhMTE4jFYhAKhdDr9bDb7TAajUWywHw+j4GBAXz00UeYmppCIBAgw2EAqKqqQldXFx566CFUVFSgt7cXg4ODyGazANY3yKxWKywWC4B1CWhZWRnUavWW7585/isllwCIH9yVklmaprG8vIyxsTFEIhFUVVWhpaUFSqUSoVAIPp8PoVBoxxJSj8eDqampok00hUJBNtp2WudtFYhQVVUFs9l8zYEIzKZjOBwmdVI6nYZOp8OePXvIhn8+n8fS0hIikQhJwt2sXvJ4POjr68PAwAAWFhbI9ltLSwt0Oh0ZpH/rW9+CyWTa8n1dunQJx48fR6FQQC6Xw8TEBAKBAGpqalBbWwsOhwOpVIq7776bJJEajUZIpVI2/ICF5TLYBhvLV4JEIoHf//73JGlqM2pra3Hw4MGiC2U+n8eLL76IwcHBDfenaRqxWAwejwctLS0oLy9HMBhENBolIQvl5eXo7OyEWq0Gh8OBUCiE3W6HVqv95A/yCiKRCFwuFwKBAHw+HylYFArFpk0+rVYLsVhMmisASAJWR0cHjEYjKWTD4TBef/11vPvuu0WyQqFQiOrqaojFYmJ+ysgGQqEQpFIpKioqYDQasXfv3i2bYpeTzWYxOjqKvr6+DbLdy+HxeGhsbERXV9eWslxWQro1qVQKR48eJbLqTCYDr9dLmq5CoRDf+ta38MMf/nBbb8F0Oo2//vWvxHA6n89jZmaGNN2A9X8X9fX14PP5MJvNeOKJJ7aUL7CwsFw/bO3G8mXA5/Ph1VdfRTAYLLqd8ZjabBtsaGgIzz//fNFQLZvNIhKJgKIopFIpyOVydHV1IZ/PIxAIIBwOo1AoQCKRoLOzE0ajEVwuFxwOBwaDAXV1ddc0ENpKMppOp+F2u5HP56HX66+6CZVIJOB0OrG8vEwGXyKRCHV1daioqMDc3BxGRkbg8XhQKBTINba+vh5qtZo0sWKxGE6dOoVz587B6XQilUoR2ahEIoHBYMDtt9+OBx54AHw+H/39/RgYGCBSTYFAgLq6OqJGyOfzUCqV0Gg020o3mVCEK4OtxGIxCUW4vClTKBSwsLCA8fFxpNNpGI1GNDc3QyKRIJfLwe/3w+fz7VhC6vF4MD09TfxigXW5sc1mu6YNxXQ6jfn5+U8sECEQCGB1dRWFQgEzMzPwer1QKpXo6OiAzWYjn1soFMLy8jLZntvs+5LP5zE6Oop33nkHLpcLKysryGazZFuTx+Mhl8vhnnvu2TYx3uv14siRI8R3mvGlk8vlcDgc5Pzu27cPOp0OiUSC+MWx4QcsLOuwDTaWrww0TePtt9/GkSNHsNXXXiaT4fHHH0djYyO5jaIo/PWvf8XZs2c3fQxFUfB6vaitrYVOp0M0GkUsFkMsFkMmkyFG742NjWSiWFZWhoaGhmueiF4PyWQSbrcbHo8HgUAAY2Nj5CK+mSTSYDCgrKyMGPPm83lwOByUl5ejqampSO7qdDrxl7/8BYODg0UhBcwWm1qtRktLC5LJJObn5zE6Ooq1tTXw+XxSbN50003Yt2/fVRtZzGSzr68Po6Oj24YiVFZWoru7Gy0tLVs2g1gJ6ebMzc3h5ZdfhsfjKYp9Z6a2Op0OTz31FFpbW7d8DpqmcfbsWRw6dAi5XI4U2HNzc+R5mHAQuVyOkpIS/OxnP9u26GNhYbl22NqN5ctCJpPBm2++uUFVIBAIcP/996OhoWHDY6anp/Ff//VfRVvw2WwW0WiUpIfyeDy0tLSAy+UiHo8jFAohFouBpmlUV1dj165dRAEgEolgs9mu2UM0GAxiaWmpSDJKURTW1tYQDAahVqt3tFGVSqUwNzdHghuY46+trYXRaMTq6iomJiZI80wqlRL5qE6nIzWMy+XCBx98gOHhYbhcLuRyOcRiMaTTaajVatjtdhw4cAB79+4FTdMYHBxEX18faVAxARGNjY2QSqUkSEuj0Wxb18ZisU0TOzkcDrRaLcxmc9G5zeVymJqawvT0NJH2OhwOsjXH2FrsVELq9XoxNTW1odFmtVqh1Wp33By73HOOGagD1xeIkEqlSLCBz+fDzMwMeDwerFYrOjs7ye8GJpU2Ho9Do9Ggurp6S3uUV199FePj4/B6vfD5fKiqqoLRaIRKpQJN07j11lvR2dm55XvKZrN4//33MTY2Rp6TCSOz2WykoW0ymdDV1UXSYGtqahAKhdjwA5avPGyDjeUrx8TEBF544QUy+boSDoeD+++/H/fddx+52NI0TTa2tiKVSsFgMECpVCKXyyEcDiOVSpF4a5lMhvb2duj1eggEAnA4HJhMJpjN5s9EIpfNZkkgQjAYxNjYGFZWVkjowZWFhclkgl6vx9raGlZXV8n5kslkMJvNaG9vh1arBUVROHfuHP73f/8XLpeLPJ5pymm1WjQ1NaGzsxOBQAAjIyO4ePEiFhYWQFEUJBIJqqqqsHfvXtxxxx3Q6XRXPZZEIoHBwUH09/cXrf5fiUgkQltbG7q7u7f0O7kWCWl9fT0aGxtJutKXlXw+j/fffx8ffPABCoUCMctlvgNcLhff+MY38MQTT2w7eV9ZWcHzzz9PNg+j0SgmJyfJjx0ulwuTyURkIvfffz/uvfde1iiXheUTgq3dWL5snD9/HsePH98wKO3q6sLtt9++oZ5aXFzEb37zm6Jt+1wuRzb15XI5pFIp6urqwOFwkMlkEI/HEQwGkUqlwOPx4HA4irZ3NBoNbDbbNUnhtpKMRqNRLC8vQyAQoKamZkeDvEwmg/n5eSwuLpJhI4/Hg9FohMlkQjqdxtTUFEZGRhCJREgogs1mI6EIFEXh0qVLOHHiBMbHx4lMlhk4VlZWYs+ePfjBD36Ampoa0DSN8fFx9Pb2Ym1tDcC/GmPNzc1QqVRIpVIQi8XQaDSbBi8wFAoF0qC6soaTy+UkFIFRWySTSYyNjWFxcRECgQANDQ2wWCxFNfq1SEh9Ph+mpqaKNiKvp9G2VSCCUChEbW3tjgMRCoUClpaWyGB+dHQUyWQS1dXV2L17N6qrq8nr+f1+rK6uQiQSwWAwbPodpGkaH374Id555x1Eo1EsLS2R71d5eTlEIhH27NmDW265Zdv3NTo6ivfffx+5XA6ZTIYoYbRaLerq6sDlcqFQKHDrrbcin89DLBajtrYW2WyWDT9g+UrDNthYvpIEg0H87ne/2xBicDmNjY14/PHHi3ws3n33Xbz22mtbPoamaTQ2NkIgECCRSCAejxO/D2YiypiPMiaxEokEDofjM5v0FAoFrK6uYmlpiWy0uVwuyGQylJSUbNpoq6+vRyQSgdPpJGvjfD4f1dXVaGtrQ319PQqFAl5//XW89dZbRYUsn8+HTqdDdXU1vvGNb6CjowNerxdjY2M4ffo0kQDweDwoFAp0dHTgvvvuQ0NDw1UvyhRFYXZ2Fn19fZidnd1yM5E5ju7ubtjt9i2fN5PJYGpqijQft0Iul6OhoQGNjY1famnj2toaDh06RP6dxONxstUIAKWlpfi3f/s33HrrrVsWpJlMBi+//DLZAGUm0pcXtmq1GlarlRTOjz/++Gey3cnC8mWHrd1Yvoy43W689tprRbUGsL5hfeDAgQ0m+mtra3j22WeLNpfy+TxpiigUCmg0GrS1tWFtbQ2FQgGZTAbRaBTBYBD5fB4KhQLt7e2oqqqCUCgkAyJGRroTKIrC0tISgsEg2fhhpHtLS0ukqbJTa4pcLoeFhYWijbDLTfgFAgHm5uZInZfL5aBSqWAymUgoQiaTwZkzZ3Dq1CmykZVIJJBIJCAWi2EwGHDvvffi29/+NmnmLCws4OzZs0U1dHl5OVpaWlBZWYlkMgkejwe1Wo2ysrJth8jBYBBOp7NoKw9Yrx0NBgNpCALr9iTDw8PweDwoKSlBc3Pzhv+uXYuEdLNGW0lJCaxWK3Q63TUNUpkE2M0CEcxm845qGp/PRx4/OztLhuDt7e1wOBzkPKbTaeJxXFVVVeSVfDlzc3P4+9//Do/Hg7W1NaRSKZSWlqK8vBwlJSXYs2cP7rnnnm2/v4FAAEeOHIHP5wNFUZifn4fL5UJJSQkcDgckEgm4XC5uuukmqNVq4ssmFArZ8AOWryxsg43lK0s+n8crr7yCkydPbnmfsrIyPPnkk6itrSW3nT59Gn/729+2beZ0d3dDp9NhfHwcmUwG4XCYGMqmUinw+XzY7XbYbDaS1llRUUEuVp8FNE3D5/PB5XLB5/MRWYFEIoFCodhwwTUYDGhubkahUCAyU8YAV6VSweFwoKWlBYlEAn/+858xMDBQJOOUyWSoqamB0WjEfffdh7q6OkSjUSwsLODtt9/GxYsXEQ6HQdM0RCIRLBYL7r33Xtx88807OifBYBADAwO4ePFi0cr+lSgUCnR2dqKzs3Pb7StWQroOI/c8evQoMpkMSdK9PJ2ro6MDTz311JYJVcD61sHf/vY3ZDIZ0DQNt9uNxcXFIj8Zu90OpVKJ0tJSHDx4EHV1dZ/68bGwfJlhazeWLyvxeByvv/560eY8sC7Te+CBB4rqNmC9Rnj22WeLGiD5fJ4E+igUClRUVOA73/kOxsfHEYlEkM1miUcZM1w0GAxoaWmBQqEAj8eDVCqF3W6/Jr/WzSSjTE3m8/kgl8uh0+l2rG5gJIvz8/PEN5XD4UCn08FsNkMmk8Hj8WBiYgJTU1MkFKGmpoaEIkSjUbz//vsYGBgghvXMRpVKpUJDQwO+//3vo7u7mzRzVldX0dPTg6mpKXItVyqVaGlpgcFgQCqVAk3TKC0thVqt3ta/NZPJkFCEKxUmarUaZrMZer0eXC4XHo8Hw8PDCIVCUKvVaGtr2zRoYKcSUr/fT8IfGK630fZxAxESiQRcLhcKhQKRZjLhG11dXaRRR9M01tbW4PV6IZPJYDAYNj2/oVAIr732GsbHx+Hz+UhTWq1Wo6SkBF1dXfjud7+7bQMsn8/jww8/xKVLl8j5mpiYAE3TqK+vJwsCVqsVDQ0NyGaz0Gq1KCsrY8MPWL6SsA02lq8858+fx1//+lfSLLoSPp+PH/zgB7j55pvJRfbChQt46aWXtvUB6+zsxN13341Tp07B4/GQjTZmxTufz0OlUqG1tRWVlZUQiUTg8Xgwm82ora39TFeqmUh5Jtp8cnISQqEQSqVyQ0JUdXU1uru7IRaLMTY2RhJLAZCV9fb2dni9Xvz973/H4uJi0ePVajV0Oh3a29tx9913Q6FQgKIoeDwefPjhhzh+/Dg8Hg9yuRy4XC4qKirwta99Dffcc8+OVs1zuRzGxsbQ19e3bZABl8tFQ0MDurq6tk0NZSWk64TDYRw+fBgjIyMA1ieoPp+PFPNSqRTf//738Z3vfGfLVDGPx4MXXngBbrebPOfk5CT5t8fhcGA0GslU/6GHHsJtt932pTuXLCyfFWztxvJlhqIonDhxAr29vUW3czgcfOMb38DevXuLrh+xWAy/+c1vippyhUKBWHkoFAqo1Wr8/Oc/x/LyMi5duoRcLodUKoVcLodgMIh4PA6RSASHw4G6ujqicqisrITVat3xoO1yyahOpyM2FoxvLgDo9foiFcXVYKSGc3NzZNDIyEMtFgvkcjlSqRRmZmZIKAJN00WhCGtra3j//fcxOjqKQCCAdDpNFBhMPfbII48Q2SKwXkP29PRgdHSU1EhSqRTNzc2wWCzIZrPIZrMkEGG7RgvTOHI6nUSKyiASiWAymVBXVwexWExqs0QiAb1ej5aWli3T5HciIfX7/Zienobf7yePlclksFqt0Ov111SLbBWIUFpaCrPZvG0gQj6fh9vtRiKRQKFQwPDwMGKxGKm/L7dSicfjJDBDp9Nt2uhNJBI4fvw4hoeHSaprNpuFUChESUkJmpqacPDgQTLw34qpqSm8/fbbpPE8OjqKeDxOmoc8Hg9lZWXYt28fKIoiyacej4cNP2D5SsE22FhYsO4V9dxzz21IOLqcvXv34pFHHiETovHxcfzP//zPlo05AGhoaMATTzyBsbEx9Pb2IplMIhQKEfkBs5ZuMBhIYiWXy0VJScnnkmCZSCTgdruxvLyMqakpTExMgMvlQqlUbpiMaTQa7Nu3DxUVFZiZmcHk5CQCgQAKhQLxX7NarZibm9uQNsrj8aDVaqHVanHrrbdi7969ZJKYzWZx7tw5vPXWW5ienkY6nQZN05DJZGhtbcWdd96JpqamDRKQzVheXkZ/fz9GRka2bIwB69uDXV1daG1t3bY4/qpLSGmaxvDwMA4fPoxoNLppCILFYsFTTz0Fq9W66XPkcjkcPnwYJ06cALB+TicnJ4s24kpLS2Gz2SAUCtHR0YFHH310y4QyFhaWrWFrN5avApOTkzh69OiGesxqteJb3/pW0fUjnU7jv/7rvzA9PU1uoyiKJIjK5XKoVCr88pe/hFKpxIkTJ7C8vIxsNot0Oo1CoYC1tTXkcjloNBo0NjZuGJJulWR+JVtJRguFAlZWVhCJRFBeXo6Kioprau4woVBOp7NoG6yiogIWiwUqlQoURcHtdmN8fByzs7PIZDKQyWQwGo2oq6uD2+3GqVOnMDs7i2QyiUQigWQyCaFQCIPBgP379+Nb3/pWUW2YSCTQ29uLwcFBMnxjttMbGxtJeqtMJoNarb5qHRePxzE/P4/5+fmiz5bD4aC6uhpmsxllZWVwOp2YmppCLpdDXV0dmpqattyWYySkXq8XyWSSvMfLJaSBQABTU1NFjTapVEoabdcyAL/eQASapuH1euH1eiEQCDA7O0ukmS0tLWhqaiLDTOb7EgwGoVAoUFNTs2HQmc1mcfbsWczMzGBoaAjRaBSFQgHpdJqkw/7iF7+AwWDY9njC4TCOHj1K0k8ZKatUKoXD4YBMJgOfz8e+ffsgl8uJL1ssFmPDD1i+MrANNhaW/590Oo0//elPuHjx4pb30ev1ePLJJ8nFYW5uDv/v//0/cpHejLq6Ovz7v/878vk8Tpw4QYqVcDgMkUgEv9+PZDIJiUQCm82G2tpaIl3UarWw2+3brtV/GmQyGSwtLcHtdmN6ehpjY2NkunvlhKu0tBR79+5FfX09XC4XRkZGsLq6Ss4JU0iNj4+TFCIGxt/DZDIR2ejlTE5O4tixYxgYGEAsFkM2myUTzD179qCrq2tHpsDJZBKXLl1Cf39/kdfGlQiFQrS2tqK7u/uqBUAwGMT4+DjGxsY2+MBcjk6nQ2Nj4zVNtm90UqkUjh49ip6eHgDrBavP5yOfuUAgwL333osf/ehHW06qBwcH8ec//xnJZBIURWFhYQFLS0vk70KhEDabDaWlpaioqMCTTz7JXntYWK4RtnZj+aoQDAbx6quvwufzFd2uUqnw0EMPobKyktyWy+XwwgsvYGhoiNxGURQikQjy+TxKSkogl8tx8OBBNDc3Y3R0FGfOnCFyUZqmkUwm4fV6SaKm3W6HSqUCj8eDXC4nlgc7fe+MEX1tbS25boZCIayurkIsFkOv119zLXj5NhijNADWlQQWi4XIFcPhMKampjA6OopQKESGoJWVlVhYWMDg4CDcbjfxpcvlclAoFGhqasIPf/jDDYmU2WwWAwMD6O/vJ/URn8+HxWIh6e6xWAwikQhqtRoqlWrbphWzmcf4AF9OSUkJzGYzqqqq4HQ64XQ6weVyYbfbYbVat5XZXk1CGolEMDU1VfSdkkqlqK+vR01NzTU12phAhNnZ2aKB4tUCEWKxGNxuNzgcDrEuAdZ/W3R1dRU1KSORCKmjampqNjQwC4UC+vr6sLi4iPHxcczPz4PH4xHfPZ1Oh5/85Cfo7u7eVjJaKBTw0Ucfob+/HwCI+gVYH7IydiHMlifzb4TD4bDhByxfCdgGGwvLZdA0jePHj+Pw4cNF3gmXIxaL8dOf/hStra0A1rffnn322aIL5pVotVr86le/glKphNPpxIkTJxAKhRAOh5HNZsHn88lEtKKigpjPikQiCAQCcjH/rKVy+Xweq6urcLlcpPhKpVJQKpWQyWRF70cul2PPnj1obW2F3+/H2NgYSVaiaRo8Ho9sgDEFHAMzue3o6CCy0cvx+/14++23ce7cOXg8HqTTaSJ7aGxsRFtbGwmK2O6CTdM0nE4n+vr6SOz7VhiNRnR3dxcZy271nIuLixgbG8Ps7OxXSkLqdDpx6NAhUnxeGYJQXV2NgwcPYteuXZsebyAQwAsvvID5+XkA/5JnMI/ncDjEt08oFOKRRx7Bvn37PqOjY2H54sPWbixfJXK5HI4dO4bR0dGi2/l8Pr75zW+ipaWF3EZRFP70pz8VyUspiiINJCb46dFHH8WePXsQj8fx0UcfYXJyErlcDvF4HAKBAD6fD9FoFHK5HBaLBTU1NaSxptPpYLFYdmTwvpVklBl4Mrdfr8TO5/Nhdna2KOhBpVLBYrGQgWIul8Pc3BxGR0eJ7FClUkEsFmN+fh5zc3MkXZVpnGk0Gtx666340Y9+RJJRGRh5Y29vL2mMcblcGAwGdHZ2oqSkBNFoFBwOhwQibGUxwRAKhUgowuUDW8ZMv7KyEsvLy1haWoJYLEZTU9O2NiDA9hLSiooK5PP5DY02iUQCq9V6zY02YL3WmZ2dhdfrJbdtF4iQy+XgcrlIAvvQ0BDC4TAqKyvJoJk5vlwuB7fbjVgshrKysg2NLEaJMD09DY/Hg6GhIRQKBbIxJ5PJ8OCDD+KWW27Z1lcXWK8B33rrLaTTaSQSCZJ+yihYeDweNBoNOjs7SdNWoVCw4QcsX3rYBhsLyybMzMzg+eefL5r4Xcndd9+N/fv3g8vlwu/349lnn90wOb0cjUaDZ555BuXl5chms+jt7cWFCxeQSqXg9/shlUrJNI0xvq2rqyNGrEqlEo2NjTuSRn7SUBQFr9cLl8uFyclJEv2uVCohl8uLLt4SiQS7d+/Grl27kE6nMTs7i9HRUXi9XuRyOeItwUzOmIksl8tFVVUVampqNshGGWKxGE6fPo0zZ85gZWUF4XAY+XweZWVlMJlMsFgsaGhoQH19/VXPUzgcxsDAAC5cuLDtBmJJSQk6Ozuxa9euqz7nV1FCmsvl8P777+ODDz4ARVEbQhC4XC5uvvlmPP7445ua+xYKBbzxxht49913Aaxvx01MTBRtBSoUCjgcDohEItx00014+OGH2aKMhWUHsLUby1eRCxcu4L333tswKG1vb8ddd91Fmjg0TeOVV17B8ePHyX0Y64NcLgepVAqZTIbvfe97uP322wGsJ2iePHkSgUAAyWQSuVwOfD4fbrcbuVyOSBcrKyshkUjIkFSr1V71fW8lGaVpGh6PB36/H6Wlpaiurr7u7Z9AIACn01kkf5TL5UW+YMzrTUxMYHJykjQT0+k0CcZKp9OIx+NIpVIQCAQwGAz43ve+h/vuu29Dk4ymaUxOTuLcuXNYXV0F8C9vuF27dkGj0RDJIhOIcLWN/2w2S0IRrlQRlJWVobS0FIFAgJzLlpaWqzaMgO0lpHw+H4uLi0WNMYlEgvr6ehgMhmv+TLYLRLBYLEU2Mcw2ot/vh1gsxuzsLObn51FSUoLGxkY0NzcXbTgGAgGsrKyQNNYrvfxmZmZw6dIlJJNJjI2NwefzQSKRYG1tDel0GnfeeSe+8Y1vFIWxbXUMR48exdLSEmlEer1eSCQSOBwOlJSUQCgUoru7GwqFAiqVClqtlsiu2fADli8jbIONhWULotEonn/+eczMzGx5H6vViieeeAIKhQLRaBS//vWvi2RuV6JQKPCrX/2K/Bvy+Xw4fvw43G43IpEI4vE4VCoV1tbWEIvFoFQqYTabodPpoFKpAKz7tdXX139uDYZgMAi3243R0VGMjo7C7/dDLpdDoVAUFVUikQi7du0iq+aMfNTlciEejyMejxOTXcangcvlkmQri8WyqWwUWG9k9ff34/Tp01hZWYHf70cqlYJUKoXBYIBWq4XBYCDbYtsVavl8HuPj4+jr6yPGwpvB5XJhs9nQ3d0Nk8l01Q20a5WQMp5jX1RWV1fx8ssvY2FhAcDGEASVSoUf//jHuOOOOzbdCBwdHcUf/vAHxONxUBQFp9NJCnFgXXZqtVqhVqs3SLVZWFg2h63dWL6qrKysEL/Qy6mqqsKDDz5IaiqapnHs2DEcOXKE3IemaUSjUWSzWUgkEpSUlOC+++7Dt771LXA4HOTzeQwMDKCvrw+ZTAaRSARisRiJRAJra2sQCoUksIdpzKhUKtjt9k1N+K9kK8loPB7H0tIS2Xb6OI2JcDgMp9NZ5D3M+IJdvvWUTCYxMzOD4eFhYgXh9XqxtrZGLB6i0Sjy+TxkMhlaWlrwb//2b0XbgpezsLCAnp4esrkOrEtWd+3aBb1eT5JL5XI5NBrNVUMemM2r2dnZopoBWK8bZDIZEokEMpkMqqqq0NLSQj77q7GVhFQoFMLn8xU1KT9Oo227QARGcsnUnIwMlMfjIRaLYWhoCBRFwWQyYdeuXUVD20wmA5fLhWQyicrKSlRWVhbVrktLSzh//jxpnLpcLnA4HMTjcfh8PjQ3N+O2226D3W7f9rgoisLZs2dx7tw5AOv+x8xvp7q6OhLK0NDQAKPRCKlUCqPRSIaxbPgBy5cNtsHGwrINhUIB//znP/Hee+9teR+lUoknn3wSZrMZyWQSv/3tb+F0Ore8v1Qqxb//+7/DbDYDWL8wMf4e0WiUbLAJhUK43W5QFIXq6moYDAbodDqIxWIIhULY7fYdTUQ/LRhfiNHRUYyMjGBlZQUlJSUbAhH4fD7a29uxd+9eyOVy+Hw+TE5OklCE5eVlXLx4EfF4HAqFAiUlJeDz+cSoddeuXZvKRoH15tilS5fQ09ODlZUVhEIhhEIh8Pl8aLVaVFRUQKVSwWw2o6GhAZWVldsWPqurq+jv78fw8HBRkXMlGo0GXV1daGtru6r5/ldJQsoUWW+++SYymQxomkY4HEYoFCLT2dbWVjz55JObGumGw2G89NJLxHza4/FgZmamaLKr0+lgMpkglUrx2GOPob29/bM5OBaWLyBs7cbyVSaZTOKNN97A3Nxc0e1isRjf+c53YLFYyG0nT57EoUOHiHUETdOk2SMWi1FSUoJbb70VP/jBD8j1ORgM4sSJE1hYWEAqlUIqlYJCocDq6ipCoRA0Gg1qampQXV1NQqwMBgPq6uq2tZ4AtpaM5vN5LC8vIx6Po7Kyktx+vTBbVKurq+TYxWIx6urqyAYdsH59d7lcJBTB4/EQyShFUcSXjqZplJWV4c4778RPfvKTLcO61tbWcO7cOUxOTpJrvEKhQEdHBywWCwlWkEgk0Gg0UCgUV62LEokECUVghnvMe2fOKWO639zcvOMGJSMh9Xq9xPaEy+VCIBAgHA4XDVHFYjHq6+thNBqvudG2VSCCTCYj4Rk8Hg/ZbBaLi4vIZrMQCAS4ePEigsEgysvL0dnZWTQEZrYRvV4v8T2+vG71+/04e/YsUqkUPB4PVlZWkEgkkM1m4fV6UV5ejptvvhk2m434DG7F4uIi3nzzTSQSCUSjUYyNjSGdTkOtVsNms4HP56OiogItLS2QSCQwGo3I5XJYW1tjww9YvlSwDTYWlh0wODiIP/7xj8T/4Eq4XC4eeugh3Hbbbchms/jd735HjEg3QyAQ4Be/+AUaGxvJbYlEAh999BHGx8cRi8Xg9/tRXl6OWCyGtbU1SKVS1NTUQKvVoqqqisRhNzQ07Ggi+mmRTqextLSEkZERDA8Pw+VyQSqVQqlUFl3EuVwuWltbsXfvXpSVlSGZTGJubg7Dw8NYWlrC0NAQRkZGUCgUIJPJoFQqIRKJUFlZCaPRiNtvvx179uzZtCilKArj4+M4d+4clpaWEI/HEQwGUSgUoNFoUF5eTlb8HQ4HrFbrtnLPdDqNS5cuoa+vD4FAYMv7CQQCtLS0oLu7e0fSg2uVkDY1Ne140nojEQ6H8corr5B/A1eGIEilUjzwwAN44IEHNjQoKYrCsWPH8Oabb4KmaSQSCUxMTBTJeBkDaYlEgjvuuAMHDhy46o8VFpavImztxvJVh6ZpnD59GqdPn97wt5tvvhlf//rXSTOiv78fv//970lDhqZpxONxpNNpiEQiyOVy7N69G4899hi55tA0jampKZw6dQqxWAzhcBh8Ph8ikQjz8/OgKAo6nQ7V1dWorq6GTCaDSCSCzWa7akOBSfoMhUIoLS2FXq8nrxsIBLC2tgaZTAa9Xn9V77KrkUwmiVyR+WkoFAphMplgNBqLnj8UCmFychKjo6MYGxvD5OQkotEoBAIBCoUC8vk8+Hw+ampq8MMf/hDf/OY3t7xGh8NhnDt3DiMjI2SwKZVK0dTUhKamJhQKBUQiEfD5fGg0GpSWll61ccVIbZ1OZ1ENl8/nEQqFSG3Y0NCAhoaGazp3m0lIs9ks2Xhk1CVisRgWiwVGo/Ga6xOKorC6urppIILJZEJtbS0EAgFWV1cRDAYhk8ngdDoxOztL0jxbW1uL1BvJZBIulwvZbBZarbaoMcvYrzC/PYLBIOLxOKLRKMLhMGQyGWw2G2w2G6xWK8xm85YqmkQigbfeeots401MTCAQCEAkEsHhcEChUEAkEqGzs5PIRSUSCRt+wPKlgm2wsbDsEI/Hg+eee27bxkhnZyd+8pOfgM/n449//CNJ2NkMLpeLn/70p+jq6iq63eVy4cMPPyR+G/l8HuXl5VhaWkI0GkVFRQW0Wi2qq6uhVqvB4XBgMplgNps/1yZDLpfDysoKxsbGcPHiRSwsLEAgEGwIROBwOGhoaMDNN9+M8vJyFAoFrK6uYmRkBKOjozhx4gTm5+dB0zTEYjGUSiXxIrHZbFvKRhnm5uZw9uxZEusei8VIMINKpYJQKCSJXI2NjTAajVtKSGmaxvz8PPr6+jA5ObltKEJNTQ26u7t3XKx92SWkNE1jaGgIhw8fRiwWA7BexAUCAbLFZzab8bOf/QzNzc0bHj89PY0XX3yRpLrNzMwUeRwyG3/l5eUwm8144oknvvB+diwsnzRs7cbCss7s7CzeeOONDYNSk8mE/fv3E5+p0dFRPPfcc0Vb7IzXmFAohEKhQHNzM5588smia3I6ncaZM2dIGFQ4HCbeYktLS5DL5aiurkZVVRWqq6tJw8hms111k4rZ9r9SMppOp0kYgV6vJwn0H4d0Oo25ubmiEAGBQACj0Yja2tqiY87lcnA6nejv78epU6cwNTWFRCJBaiAulwupVIqWlhYcPHgQDQ0NW75uIpFAX18fBgcHyfYW04js6OgAj8dDMBgEh8MhPm07sUphpLAul4scTyaTwfLyMiiKgl6vx+7du2GxWK5ZOXClhDQejyMQCCCXy6GkpAQ8Hg8ikYhstF1PjX61QASm9maSWQcHB5HP51FbW4tdu3YVed9SFIWVlRUEAgHI5fKikAHm+xsMBhEKhRAIBEBRFGm4lZSUoLy8HDU1Naivr4fdbi9K5r0cmqZx/vx5nD59mmw+MlukJpMJer0eHA4HDocDtbW1JC1+eXmZHBvrs8vyRYZtsLGwXAOZTAZ/+9vfcP78+S3vU1lZiV/84heorKzEoUOHcOrUqS3vy+Fw8Mgjj+DrX/960e2FQgH9/f04f/48otEoPB4PSktLwefzsbCwAC6Xi+rqalRUVBAfDsZQ9PNesaYoCh6PB2NjY7hw4QJmZ2fB5XKhVCpRUlJSNJmyWq24+eabidQ1HA5jZmYG7777Lt555x3SUBEKhSgpKSHGr93d3VvKRhlWV1dx9uxZTE5OolAoIJVKIZ/PQywWF0kNFAoFzGYzGhsbt5WQRiIRXLhwARcuXNi2ISaTydDR0YFdu3btaPvscgnpzMxMUSrW5TANpaamps8lUfZ6SSaTOHLkCElqKxQKCAQCxBdHIBDgjjvuwE9+8pMNPw5isRj+8Ic/YGxsDDRNY3V1FXNzc0WSUcZQWqFQ4Gc/+xkcDsdnd3AsLDc4bO3GwvIvIpEIDh8+vMGrS6FQ4MCBA8QranZ2Fr/97W+LZHqJRALJZJIMDuvr6/HLX/5ygwH86uoqPvzwQ6ytrZEBUWVlJRYXFxEKhVBVVYWKigpUVlaivLwcPB6PbIltt7nD+HRls1no9XrSOGG2nUKhENRqdZFf18chm81ifn4ei4uLZCjG4/FgMBhgMpmKts8ZA/7Tp0/jjTfewPT0NEl753K5EIlEKCsrwz333IPHHntsW7+tXC6HCxcuoL+/n9QJfD4fdXV16OrqgkwmQzAYRD6fh1KphEajuapVB/O8i4uLcDqdZOiXTCaxuLiIXC6Huro63HrrrTAajdd8rq6UkMbjcbjdbmSzWSiVSkilUohEIlgsFtTW1l5Xoy0Wi2F2dpY0Bhmqqqqg1+sRj8eRz+chEolw8eJF+Hw+lJeXo729HXV1dUXfrVgsBpfLBZqmodfrSa2az+dx7tw5rK2tkd8eAoEAqVQKTqcTSqUSOp0OFEWRTbbtGsTLy8s4cuQIYrEYQqEQxsfHkc1mUVpaSobGlZWVRK1RU1MDj8fDhh+wfOFhG2wsLNcITdP46KOP8PLLL2/ZDBGJRPjxj3+MXbt24ciRIzh27Ni2z7l//37cc889G4qicDiM48ePY35+Hn6/H/F4HHq9HsFgECsrK2TqU1FRQaQDFRUVcDgcn/uFiaZpBAIBTE5Oore3l3hpKRQKKBSKogLDZDLh5ptvJoUNkw71pz/9Ce+++y4phphpaG1tLZqamnDXXXdtKRtlCAaDOHv2LEZGRpDP51EoFEBRFGQyGcRiMSiKQqFQAJfLRXl5Oex2O2w225YFYKFQwMTEBPr6+rC4uLjl63I4HFitVnR3d8NsNu+o4GUkpKOjoxt+AFyOXC5HY2MjGhsbvzAS0tnZWbz88sukaZpKpeDz+ZDNZgGsF4mPPvoobrrppg2R8u+//z5ef/11YqY8MTFR5K8ik8ngcDggk8lw//3347777vvCNCBZWD5N2NqNhaWYfD6P9957D4ODg0W3c7lc3HXXXejs7AQAuN1u/PrXvyb1BwDiCyYQCKBQKGAwGPCrX/1qw7CPoijiD5tIJBAIBIg/7ezsLHg8Hqqrq6FUKmEwGFBSUgKpVAq73b6lZxnzvJdLRmtqasj1MhqNki23mpqaq6Zw7hSmMbWwsECu11wuF3q9HnV1dRsajPF4HMeOHcPLL78Mp9OJTCYDiqJI4EBtbS0effRR3HPPPds2FAuFAkZHR3H+/HlSNzCv293djfLycgSDQaTTaZSUlECj0ezYLsXr9cLpdGJlZYV4xS4uLiKTyaC+vh533nknamtrr/t8MRJSn88Hl8uFSCQCpVJJBs0fp9HGbBgyjUGGyxt5SqUSTqcT09PTEIvFsNvtaG1tLfpdkM/nsbS0hEgkgtLSUuh0OpJYe+HCBczPzyOZTGJlZQVcLhclJSUYHR0Fj8dDQ0MDgsEgysrKUF9fD6vVWvRdvPL9Hjt2DLOzs8hkMhgbG0MkEoFAIIDdbkdpaSnEYjHa29tRXl4Og8GAWCzGhh+wfKFhG2wsLNfJwsICnnvuOYRCoS3vc9ttt+HBBx/EyZMn8corr2z7fHfccQceeuihDY0BmqYxPT1NYuGZhKqqqirMz88jHo+Ti5BOpyMTUbPZjNra2hvCyyAajWJ6ehpnzpzB1NQUWZ9XKpVFa+B6vR4333wzMR6maRpzc3P49a9/jVOnThHjfA6HA4lEArPZjD179uDBBx+EyWTa9j3E43GcP38eFy5cIBIRZrNOKpUim80ikUiApmmIRCLodDo0NTWhtrZ2y0LV4/Ggv78fQ0NDpPDcjLKyMnR1daG9vX3Hjc9gMIixsTGMj49/aSSkuVwO7733Ho4fP05MkZlgCsY0ePfu3Xj88cc3eNrNzc3hxRdfJPKLqakpBINB8ncejweLxYKKigo0NTXhpz/96efqTcjCciPA1m4sLJszPDyMt99+e0PwUFNTE+69914IBAJ4vV48++yzRT5eqVQK8XgcfD4fSqUSVVVVeOaZZ4qkeAzxeBwnTpzA9PQ0MX6vqalBKBSC2+1GWVkZ+Z/BYIBAIEBlZSWsVuu2DbKtJKO5XA5utxvpdBpVVVXbNuuulUKhAJfLhfn5eVJDcTgcaLVamM3mDdfbRCKB119/Ha+88goJHcjn8xAIBFCpVOjq6sJ//Md/wGq1bvu6NE1jZmaGeOwyr1tZWYnu7m7U1NQgEokgFotBLBZDrVZDpVLtaMiWSqUwNzdHmkl+vx9utxuZTAYWiwV33HEH6uvrr3tgl0gk4PV6sbi4iPn5efh8PuJRXFZWBpvNhtra2uvyz8vn81hcXMTc3FzRpiWT5moymZBOp3HhwgVks1kYDAZ0dnZuULmEQiGSSso0ewEQC5NMJoOlpSUUCgVotVrMzMwgHA6jo6MDyWQS0WgUFosFdXV12w6nL1y4gBMnTiCfz2N+fh4ulwsAYDAYYDQaiX1MbW0ttFoteDweG37A8oWFbbCxsHwM4vE4XnrpJYyPj295n7q6Ohw8eBATExP4y1/+UrTafSX79u3Dj3/8402bYplMBmfPnsWlS5cQCoXg8/mIj4HT6SSBACUlJWSqWFJSgoaGhk+0yPo4MGvmH330EUkX2iwQobKyEjfddBMcDgcpbEZGRvB//+//xaVLl8hEFPhX2tW9996L733ve1eddmWzWfT19aG/v59MpgUCATQaDSorKxGPx+H3+8mGlFwuJxLSqqqqLT+boaEh9PX1FfmEXQmfz0dLSwu6urp2nAD7ZZSQrq6u4tChQ2QDMJfLwev1kiJRpVLhe9/7Hu65556ipmEymcSf//xnDA4OgqZpuN1uLC4uFnnjVVZWwmKxQKPR4ODBg9v69bGwfNlhazcWlq3xeDw4fPjwhkFpeXk5HnroIZSVlSEcDuPZZ58t2ipPp9OIxWLg8XhQKpVQq9V45plnUF1dvenrzM3NkSFpIBAAj8eDTqcjcsXq6mpIpVISYsXn82GxWEiNtxmpVIoYyet0OtLgo2kaPp8PXq8XCoWCbCZ9UjABAnNzc0XhQ1VVVTCbzRtqML/fj0OHDuGdd96B2+1GPB5HoVCAQCBAeXk5HnzwQRw8eHBH/nEulws9PT1wOp3ktrKyMuzatQtWqxXxeBzhcBg8Hg9qtRplZWU7OnaKorC8vAyn0wmPx4O1tTWsrKwgn8/DaDTiG9/4Bmw223VvBVIUhXA4jPn5eYyMjMDr9YLD4UChUECj0aClpQUmk+m6Gm2Mr5rT6SSBCJlMBuFwGDqdDlarFUNDQ/B4PFCr1Whra0N9fX1RLZvNZuFyuZBIJFBeXo7q6mpwOBwsLCxgYGAAuVwOS0tLSCQSMBgMCIVCWF5eRkNDAwQCAVZWViCVSlFfXw+z2Qyz2bzpsaytreHIkSMIh8Pw+/2YmJggUl+73U4ko62traioqIBKpcLKygobfsDyhYNtsLGwfEwoisJbb72FN998c8v7yOVy/OxnP0Mmk8Hzzz+/YWJ6OW1tbfjZz362pcGnx+PBBx98ALfbDY/Hg3w+D7PZDK/Xi+XlZVRUVEChUKCsrIwkP2m1WnLxuhHI5XJYWFjAqVOnMDQ0hGQySdbapVIpKSjVajX27duH5uZmcLlc0DSN9957D7/73e/IxZ45l4z04KGHHsKBAweu6slRKBQwNDSE3t5e+P1+AOtbUEyDplAoYHFxEeFwmEhI1Wo1HA4H7Hb7po08phnW19eHiYmJbZupOp0O3d3daGpq2nFRtVMJqUKhQENDww0tIaUoCmfOnMFbb71FmplMghXTRGxubsZPf/pT1NfXk8fRNI1Tp07hlVdeIYlgU1NTRRuEjNRGoVCQdN8bueHIwvJpwdZuLCzbk06ncfToUUxPTxfdLhQK8a1vfQt2ux2JRAK/+c1vsLCwQP6eyWQQjUZJk02hUODpp5/eUlqYz+fR29uLgYEBJBIJ+P1+VFdXQygUYnJyEhKJBGq1mgwNFQoFSczeanC4nWQ0mUzC7XYDWFcHyGSyj3+yLoPxRZ2dnS3asi8vL4fFYtkQOjQ1NYV//OMfOHfuHFZXV4k/nUAggMFgwMGDB/H9739/R00Un8+Hnp4ejI+PkzpLLpejvb0dzc3NyGQyCAaDoGmaBCLstP6NRqMkkXNhYQEejwccDodIU61W66bbijuFkdwODAzA7XaDpmky5G1vb4fdbr/uRNjLAxEKhQL8fj+y2SysViupaYVCIWw2G9ra2orkvUxjdm1tDSKRCAaDARKJBB6PBz09Pchms1hbW4Pf74fBYCCDUZ1OB6VSSZrHBoMBdXV1sFqtm26eZTIZvPvuu5icnEQqlcLo6CjZCGUk0oxklAl0Yz4DZsuTheVGh22wsbB8QoyOjuKll14qmuhdDofDwXe+8x2YTCb893//d5GH1JXYbDY89dRTWzaJKIrC8PAwTp8+TfzYmA2smZkZIg8QiUSoqalBRUUFhEIh6uvrb6jtpkKhgOXlZZw6dQoXLlxALBYjsgu5XE7ep1KpxN69e9HW1gY+n49kMok//elPOHr0KJEGMA0WuVyO2tpafPe738U999yzbRACsF5UTExM4Ny5c1heXgawLh1lpooqlQpzc3NYWlpCPB4nxZBer0dTUxNMJtOmU81YLIYLFy5gYGCgyMPlSqRSKdrb29HV1XVNKZhfFglpOBzGK6+8grGxMQAbQxCkUinuu+8+PPTQQ0U/ENxuN55//nl4vV5kMhlMTk4WxdlzuVzU1dWhuroanZ2dePTRR3dkhMzC8mWCrd1YWK4OTdPo7e3FiRMnNqSF7969G7fddhtyuRz++7//G5OTk+Rv2WwW0WgUHA6HJKb/4he/2DZsJxAI4MMPP8Ti4iKCwSCy2SzMZjMCgQDcbjcqKipIKACT2qnT6WCxWLZsLgQCASwtLUEoFBYFEBQKBaysrCASiaCiogLl5eWfeP1H0zQ8Hk/RBhWwvllmNptRXl5Obsvlcujp6cE///lPTExMwOfzIRgMolAoQCgUwuFw4JlnnsHXvva1HdUrkUgEvb29GBoaIn5kEokETU1N6OzsBIfDIU0mZlvsSs+4rcjlcnC5XBgZGcH4+DgCgQCEQiH0ej2sVissFgtqamquuxkGrA/M+/r6MDMzQ4bFCoUCbW1taGtru+6ahWkSLi8vIxQKIRwOQywWQy6Xw+/3E5++jo6ODXYcqVQKLpcLmUwGVVVVKC8vRyQSwenTp5FKpeD1erG2tkaScGOxGGnsZrNZMvy12WwwGo2wWq2bHsfQ0BCOHz+OTCaD2dlZrKysAFhvBtfW1oLD4aCxsZFscoZCIRLwsdPPkIXl84JtsLGwfIIEAgH87ne/29b8vqWlBbfffjuef/55JBKJLe9nNBrx9NNPb+sjlUgkcOLECUxMTMDr9RIvBIqiMD09DYVCgdLSUpJeJJPJoFQq0djYeNXG02cJMzk7deoUent7EQ6HiT+aXC4nK/4ymQx79uxBZ2cnhEIhFhYW8D//8z8YGRlBKpVCNBpFKpUixa7ZbMY999yDr33ta9Dr9VedjC4uLuLMmTNEfsDEwbe1tcFkMsHv92N8fBw+n480SEtKSoiEtLq6esNrFAoFTE1Noa+vD/Pz81u+NofDIQmpFotlx6vwO5WQCgQC1NfXo7Gx8YZqsgLrx3Dp0iW89tprpBnJFHJM0WwymfDoo4+io6ODnJt0Oo2///3vOH/+PCiKwsLCAvFoYdBoNLBardBqtXjyySfZ6xfLVwq2dmNh2TmLi4t4/fXXN9RmNTU1ZDP+xRdfLApIyOVypLGkUqkgFovx+OOPo6OjY8vXoWka4+PjOH36NMLhMDweD1QqFbRaLaanp5HJZFBeXg4ulwudTgedTkeGpFvZS2wlGQXWfbZWV1chFouh1+s/tWEbs0F1uT8qU4tVVlaSuiMcDuPYsWM4fvw4VldX4fF4EA6HSQrm7t278bOf/Qzt7e07qlVTqRT6+vpw8eJFMuQWCoUkaEomk8Hv9yOZTEIqlUKj0RQNca+Gz+fDxYsX0dvbi2g0ColEgtraWpSXl8NoNMJsNu9I4roVzEB2YmKiaJDb2NiIzs7O696YYwIRJicnsbq6Cg6HA7FYjNXVVXC5XJhMJrS0tMBmsxVJaZntRJ/PB5lMBoPBgHw+j9OnTyMajSIUCsHlckGj0UAmkyEWi6GxsRH5fB7JZBKRSARra2uoqqpCXV0daUZeeb59Ph+OHDlC/KWnpqZAURTkcjkcDgeEQiGqq6vR2tqKmpoaFAoFhMNhNvyA5YaHbbCxsHzC5HI5vPzyyzh9+vSW99FoNHjggQfw6quvbhuSwJjnXm2zaWFhAcePH8fa2hqWlpYgFothsViwsrICj8dDPD3UajXxeTAYDKivr7/h1q2ZSdnp06fh9/vB4XAgl8uhUCjIexWLxeju7kZXVxckEglOnjyJP/zhD2RSGY/HyTZcaWkprFYrdu3ahV27dsHhcFx1KujxeHDmzJkimadcLkdzczPa29uRz+cxPT0Np9NJpq+MhNRut8PhcGx68ff5fOjv7yc+cltRWlqKXbt2oaOj45omdel0GlNTUxgbG/tCSkiTySSOHDmC3t5eAOtFXjAYRDgcBk3TEAqFuOWWW/DDH/4QGo2G3OfcuXP4+9//TtK7pqeni2TYYrEYDocDZWVleOSRR7Bv377P5fhYWD5r2NqNheXaiMVieO211zYMa2QyGQ4cOAC9Xo+//vWvOHv2LPnb5U02Ji30xz/+MW666aZtXyudTuP06dMYHR0lm0YWiwUcDgdTU1OQy+WQyWQQCASwWCxQKpVQqVSw2+2bDl8LhQLcbjfC4fAGyWgmk4Hb7UYul/vUGxShUAizs7NFnrTMMFKr1ZJGy/z8PF555RVcuHABkUgEPp8P8XgcmUwGcrkct956K+677z4ywLyan1oul8Pg4CD6+vrI58Hn81FbW4s9e/ZAo9GQDXmhUAiNRgOVSrXjgWYymcT58+dx5swZhEIhqFQqGI1GyGQyVFRUbDi+ayWRSGB8fBzj4+MIh8PIZDLg8/kwmUxoa2uDVqu9Lh+4fD4Pp9OJ8+fPIxKJQC6XY21tDR6PB9XV1dizZw86Ojo2fKfi8Tjcbjfy+Tx0Oh1KSkrQ09MDn8+HWCyG+fl5lJSUoKysDNFoFJ2dnSTBNJ/Pw+PxIJlMwmq1wmAwwGazbWiY5nI5fPDBBxgZGUEikcDo6CiSySR4PB5sNhvUajVEIhG6u7thNBohlUrh9XqhVqtRWVl5XeeZheXThm2wsbB8SvT09JAf/ZshEAhw//33o6enBx6PZ8vnKSsrwzPPPHPVC0k+n8f58+fR398Pr9cLj8cDg8EAjUZDpkIVFRWgKApGo5FISO12+44N9z9Lkskkzp07hxMnTpBmEbOBxxQYQqEQnZ2d2L17NwQCAf7yl7/gzTffRD6fR6FQQCKRQCwWI/4WDocDNTU1aGhoIJHg2xGJRHD27Nki+YFUKoXNZiOSzuXlZUxMTGBpaQmxWIxMHhlpptls3lAQZbNZDA8Po6+vb9vPns/no7GxEd3d3dDpdNdUtDES0rGxsW03JfV6PRobG2G1Wm8YCens7CxefvllUpxns1n4fD4SglBZWYmHH34Yt9xyC2m6rq6u4vnnn8fKygpSqRSZBDNwOByYTCbodDrcfPPNePjhh2+45jILyycNW7uxsFw7FEXh+PHj6OvrK7qdw+Hgtttuw+7du/H666/j3XffJX/L5/OIRCKgaRoKhQJCoRAPPvgg7rrrrqu+3vLyMhmSer1e0lxgzPYZLyulUon6+nrikVVXV7dp08nv92N5eRkikQi1tbVkqEjTNNbW1hAIBFBaWrrp1v0nSSQSgdPpxNraGrlNKpWirq6OqAoKhQL6+vrw6quvYm5uDrFYDOFwGKlUCplMBpWVlbjtttvQ0tICh8MBs9l81TR2iqIwOjqK3t5eUkcw24C7d+8mRv3hcBgcDgdlZWVQq9U7lnvmcjn09fXho48+gt/vh0ajgdFohFAohEQiQV1dXZFU91pJJpOYnp4miZ2MZYZWq0VDQwO0Wi3UavU1f3aFQgHDw8MYHBwERVHIZrOYmZlBLpdDU1MT7r77bphMpg2PYaSmSqUSWq0WFy9ehNvtRiqVwuzsLAQCAbRaLSKRCNm6u3jxIiKRCJLJJFZWVqBSqWA2m2EymTYNdBgbG8N7772HVCqFqakpeL1eAEB1dTXMZjNomkZLSwvsdjs0Gg28Xi8bfsByw8I22FhYPkXcbjd+97vfbZss2dnZibW1NeL/tRlyuRxPP/00DAbDVV8zFArhgw8+IJ4G+Xwedrsd2WwWTqcTZWVlEIlE4PP5sNlsZPrU2Nj4iZvgfhJks1kMDAzg/fffJ7HeYrEYSqUSEokEHA4HPB4P7e3t2Lt3L2KxGP7nf/4HQ0NDANYLrUwmg2QyCT6fD71eT+SyRqMRnZ2dqKur27awSiQS6O3txeDgIJEfiEQimEwm7NmzBzU1NYjFYlhYWMDExAQ8Hg/ZUJPJZERCqtVqiwoBJgmzr68P4+PjW8o7gfUio7u7G83NzdfUGKIoCi6XC6Ojo5idnf3CSEhzuRzee+89HD9+HBRFgaZpxGIxBAIBsjG4a9cu/OQnPyGm0tlsFv/4xz9w+vRpUBQFp9O5YZOvrKwMNpsNJpMJTz75JBv/zvKlhq3dWFiun4mJCRw9enTDoNRut+P+++/HyZMn8dprr5HbGQkbRVFQKBQQiUS45557sH///qteUwuFApEhhsNhYiBfVVWFyclJFAoFlJaWIpfLQa/Xo6amBlKpdEsz+cslo3q9vihNPh6PY2lpCTweD3q9/qoNq49LPB6H0+nEysoK8bgTi8UwmUwwGAzg8XiIx+N4++238c4778Dj8SAWi5EmW6FQQH19Pbq6ulBVVYX6+nrSaNnuvNI0jdnZWZw7d44EPnA4HFRUVKC7uxv19fWIx+OkrlCpVNBoNDveEsvlchgYGMC5c+fg8XhQUVEBnU4HPp9PghHMZjPZuL9WkskkZmZmsLCwgFgshmg0inQ6jerqahgMBuKRdq2WL7FYDENDQ1hZWSGvEQgEIJfLsWfPHtxxxx0bNhzD4TCWlpbIcS0uLpKAqdnZWVAUhZqaGkSjUdTV1eHee+/F3NwcxsfHkc/n4ff7EQ6HUVdXRxQ0Vw65g8Egjhw5QkLbZmZmQNM0ZDIZHA4H+Hw+ampq0NXVBa1Wi1AoxIYfsNyQsA02FpZPmWQyiT/+8Y+k4bMZzEr5dk02sViMX/7yl7BarVd9TZqmMTk5iZMnT2J1dRUrKysoKyuDxWLB0tISAoEAKisrkc1moVarYbFYiDmu2Wz+RCPdPymYydu7776LmZkZAOtNIaVSiZKSEnA4HHC5XDQ1NeGmm27CxMQEXnrppaLmZi6XA0VREAqFqKmpgVarBY/HIz5rLS0t23reMcXU5fIDxiy2u7sbZrOZmOrOzMxgZmaGSEg5HE6RhPRKaWY8HsfFixcxMDBQZBR8JUy6UldX1zX7cnwRJaQrKys4dOgQaa4yyViMV5tKpcL+/ftx7733Ejltf38//vrXvyKdTsPj8WBmZqYo0ZXZ3KysrMRjjz2G9vb2z/7AWFg+A9jajYXl4+H3+3H48GGSNs5QVlaGBx98EFNTU/jb3/5GGkeFQgGRSASFQgFyuRxisRhf+9rX8Mgjj+xo0yYWi+HDDz/E9PQ0/H4/MpkM7HY7KIrC7OwsFAoFGQjabDaUlpZCo9HAZrNtaJRdLhktKysr8qLN5/Mkjb2ysvK6m0DXQjKZJKFRzDVZKBSitrYWRqMRAoEAy8vLePnll4mcMRaLoVAoIJvNQiqVoq2tDQaDgQxJHQ7HjhosS0tL6OnpwezsLPmsysrK0NHRQZJHmfMtl8uJv9hOj2toaAgXL16E3+9HWVkZqqqqSPNPoVDAbDbDaDReVyhCKpXCzMwMXC4XstksIpEIEokE1Go1dDod5HI5ysvLUV5efk3NQZfLBb/fT+rPhYUFcLlcVFdXY+/evWhtbS1qzOZyObjdbsRiMajVanLc+Xwe8/PziMfjMJlMSCQSqKqqwv79+5HP53HhwgV4vV5ks1msrKwUha5Zrdai95zP53Hy5ElcvHgR0WgUo6OjyGQy4HK5pCknEolw0003wWQyIZPJIJ1Ok4YzC8uNANtgY2H5DKBpGu+99x5ef/31DQlVDGKxGDKZDIFAYMvnEQgEOHjwIFpaWnb0uplMBqdPn8bg4CBWV1cRDodRX1+P0tJSzMzMgMvlQi6XI5vNwmQyQavVQiqVoqGh4aryyc8LmqYxPT2Nt956C+Pj46BpGjweDwqFAgqFghSPDocD3d3dOH78ON54442iCTRFUSRR6fJEIpFIhPr6enR2dm4ryWRSXM+dO0caeDweD9XV1ejq6iKJnUwxMTExAZfLRSSkfD6fSEgtFktRccEEVPT19ZGwha0wm80kNv5aV+S/SBJSiqJw5swZvPnmmyQtNplMwufzkc+1sbERP/7xj9HQ0AAulwuv14sXXngBLpcLiUQCExMTRQm/zNTTYDDgzjvvxIEDB27IxjILy8eBrd1YWD4+2WyW1ByXw+fzcd999yGTyeCll14iG+IURSEcDhc12To7O/HTn/50xw0Wp9OJDz/8ED6fD6urq1CpVLDZbFhaWkIwGERpaSlSqRSUSiXsdjskEglMJhOMRuOGemArySgAYjAvk8mg1+s/VirmTslkMpibm4PL5SLnjM/nw2g0wmQyQSAQ4OLFizh06BBmZmYQi8WQTCZB0zRomkZtbS06OzvB5XLB4/FQUVEBm80Gm8121bABv9+Pnp6eItVASUkJWltb0dnZCWD9nMTjcUgkEqjVaiiVyh1t9UciEQwPDxNpJ6MQYWCO0Ww2X1fQ2OWNNoqikE6nEYvFUFJSgurqajJ0Li8v35GElJEM+/1+SCQSzMzM4Pjx40gkEigrK0NdXR1JTL28Yej3+7G6ugo+nw8+n4+RkRHSsPP5fKirq0M2m4VSqcSBAwcgl8uxsLCAoaEhZLNZsqFZU1MDo9FIJMOXn+Pp6Wm8/fbbiMfjGB8fJ8EZlZWVsFgsKBQK6OrqQmNjI/h8PqLRKKqrqz/3oTALC8A22FhYPlOmpqbwwgsvkO2bK6FpGmKxGOl0esuLOZfLxWOPPYbdu3fv+HVXVlZw/PhxzM/PY2lpCSKRCA6HA5lMBvPz8ygrKwNFUeByuWhoaIBCoUBFRQUcDsenLh34OCwuLuLYsWMYHBwkW2JyuRxKpZIUiRaLBfX19Thy5AguXLhQ9Hgej0emuiqVCul0GsB680Wr1aK9vZ0kGW0G0+w7e/YsMURmwg527dqFxsZG0ryLx+NYWFjA+Pg4PB4PeS3Gj6SpqWmDl0QgEEB/fz8GBwfJ/TdDqVSSUITtNvA241olpE1NTRsKoc+KUCiEV155hfzIuTIEQSaT4Y477sBDDz2E0tJS5PN5vPbaazh+/Djy+TxmZmY2yLWZHy0NDQ04ePAgW5yxfKlgazcWlk+O/v5+fPDBB0Ub0QDIUO75558nQyCKohCJRJDP5yGTySCVSuFwOPCLX/xix1tG2WwWvb29uHjxInw+HwlB0Gg0mJ6eJte9ZDIJg8EAo9GIkpIS2O32osYOsN6cmZ+fRz6f3yAZTaVSWFpaQqFQIBtRnwXZbBaLi4tEygqs12U1NTWoq6sDALz33nt44403sLa2hlgsRs6vRCLB17/+dTQ0NGBlZYU0M81mMxwOx1X95WKxGM6dO0eaPsxzNjQ0oKuriySPRiIREhJWVla2o2Hm2toaRkdHsba2hlwuB7FYvGHDrry8nIQiXOuAlPE+W1xcBEVRoCgKqVQKJSUlUCqVEAgE4PF4UKvVO5KQRiIR8ttAKpXi2LFjGB4eJgoNg8FA0mBramrA4/GQyWTgcrmQTCYhEAgwPT2NbDaL1dVVuFwumEwm0DQNkUiEAwcOQKPRIJ1O49KlSyQ4wePxIJ/Pw2azobq6ekODNBKJ4MiRI1hZWYHL5cLc3BwAkH9LXC4XFosFu3fvhlKpRDAYZMMPWG4I2AYbC8tnTDgcxvPPP7/ldhJN08hkMuDxeNuuvH//+9/HbbfdtuPXpSgKg4ODOHPmDJaXl8n0yGQyweVyIR6Po7S0FLFYDOXl5bBarRCLxTCbzaitrb2hTUQ9Hg/efvtt9Pb2kiKNKTSY5hhTeB49epSYpzJIpVLU1NSgubkZAOD1esnzyGQyNDY2oqOjY1tJ5uLiYpH8gMPhQKlUor29Ha2traTAoWmaSEinp6eLJKSlpaWw2+0bpJnZbBajo6Po6+vbVtrJ4/HQ2NiIrq6u6/JRYySkTGG4FQqFAo2NjWhoaPjMG1I0TePSpUs4fPgwCTHIZDLw+XykCVlbW4sf/OAH2LNnD/h8PoaGhvDHP/4RiUQCq6urmJubK/qBJBQKYbPZYDAY8Pjjj8PhcHymx8TC8mnB1m4sLJ8sS0tLeO211zYMSrVaLTo6OvDHP/6RbEtTFIVoNIpcLkeabHV1dfj3f//3a/K89fv9OH78OBYWFrCyskICkJitIalUSpKz7XY7aTJcKb/bTjJKURRWV1cRCoWgVquLNpY+bfL5PBYXFzE/P0+aXZd7mMXjcRw6dAhnz55FKBRCPB4n13CtVosf/ehH0Gg0mJiYQCQSgVgshlarhcPh2KASuJJ0Oo3+/n5cuHCBbPMzEsbdu3dDrVYjGAySDSomEOFqklRmeDk+Po5YLAaxWAyxWLzheyMWi0kowrUOtNPpNGZnZ7GwsEDOBxOyUVJSQr4TYrH4qhJSpmGWy+VQWVmJmZkZnD59Gl6vF0qlEnV1dSgpKSF2MsymocfjgdfrRT6fh9vtRjabJTWuXq8nr7d//37odDoA66FUFy9eRDKZJHVZRUUFkQqbTCaiKKAoCh999BH6+voQCoUwNjaGXC4HLpcLs9kMtVoNuVyOW265BTqdDuFwmA0/YPncYRtsLCyfA4VCAYcPH8bx48c3/TtN0wiFQuDz+dtOnu6//37cf//911QExWIxnDhxAqOjo1haWkIul4PdbifR2gKBAHw+H4lEAvX19WSa2dDQsGEieqMRiUTw9ttv4/Tp06TZIpFIoFAoyCZZZWUlwuEwzp07Rwo5hrKyMrS0tGDv3r0IBoOYm5sjTRwej4fa2lrs2rWr6OJ/JV6vF2fOnMHExATZBmPO365du4qadLlcjkhIFxcXiYSUx+NBq9WiqamJJIYB69+L5eVl9PX1YXR0dNtQhMrKSnR3d6OlpeW6pJ2BQABjY2MYHx+/ISWkyWQSb7zxBs6fPw9g/dxEo1EEAgHis7dv3z48/PDD0Ol0CAaDePHFF+F0OhGNRjExMUGCKBiYhvO3vvUt3HfffZ970AMLy8eFrd1YWD55EokEXn/9dSwuLhbdLpFIsHfvXrz++uvES5W5NmWzWUgkEpSUlECr1eJXv/rVNQ2oaJrG6OgoTp8+TZoaWq0W9fX1WF5eRiQSgVQqRTweJz6qMpkMFotlw9b5dpLRSCSClZUVsr200227T4JCoYClpSU4nc4iRQGTJLmwsIC//vWvmJycRDQaJcniPB4PnZ2d+OUvf4lcLofh4WGsrKyQwaXNZiNbfVtd1/P5PAYHB9Hf349QKATgX5LO7u5uGAwGRCIR+P1+5PN5KJVKqNXqqzbF8vk8nE4nJicnkcvlUFVVBbFYjOXl5SLrEkY5YTabrzl8abNGG5fLRUVFBRQKBWKxGGm2bSchpSgKKysrpMmayWRw7tw5otAoLS1FRUUF8Tw2GAzEe5gZ1DOy32g0irGxMVRVVZFm37333guLxULOy+joKPHI9fv9SCQSRI5qtVqLfAHn5ubw1ltvIRwOY2xsjPz7YjYBaZrGTTfdBIfDQVRAbPgBy+cF22BjYfkcGRgYwJ///OcNP/QZ1tbWQFEUqqurtywKbr31Vnz/+9+/5mbA3Nxc0US0tLQUDQ0NiMfjWFtbg0KhQCKRIMEBTDy33W7/3Ly4dkoymcT777+PEydOkGmhUCiEUqmETCYDh8OBWCzG/Pw8FhYWis4dY/B6++23Y8+ePVheXsbIyAhp3ADrjbj29nY0NzdvOYWORCLo6enZID+w2WwkCevy100kEkRCura2RopLZouQkWYyBVEikSDFYDgc3vJciEQitLW1obu7+7pMjCmKwuLiIsbGxm5ICenMzAxefvllYkDNpFUxjdGqqio88MADuPXWWyESiXD06FG8/fbbyOVymJqaIlNpBoVCAbvdTvxyrlVyy8JyI8HWbiwsnw4UReHUqVPo6enZ8LfW1lacO3eOXJeYFOxMJkOabBqNBs8888w1+92mUimcOnUKIyMj8Hg8SKVSaGhoQGlpKRYXF8lGfCKRgMFggMlkgkqlIoNUhmQyiYWFhU0lo4yZPZNYWVpaep1n6fpgGj1Op7NowFdZWYmamhpcuHABr776KlZXVxGNRknzSCKR4IEHHsDjjz9Otp1mZmaQSqUgk8lQW1sLu90Og8GwpdccRVEYHx9Hb28vPB4PgPW6UKvVoru7GxaLBclkEn6/nzyvRqO5qqw2k8lgcnISTqcTHA4HFosFEokECwsLG2o4RurKBD/slEwmQxptTK3GNMLUajVisRix1NhOQhoKhbCysgKxWAyVSoW+vj4sLCxAIBBAIpFAIpEUnb/q6mrU1dUhlUrB4/HA6XSiUCggl8thaGgIKpUKarUa2WwWt99+O1GLAOt+wEy4VyaTwerqKkpKSoh89vIhczwex9GjR8m2IxN+JRaLifKgubkZe/bsAbCu/mDDD1g+D9gGGwvL58zq6iqee+65LSV5Xq8X4XB423TP3bt349FHH71mk/ZcLofe3l6cP38eLpcL0WgUZrMZBoMBy8vLSKfTEIvFCIfD0Gg0sNvtkMlkJP3nRt/wyWazOHHiBD744APSSOHxeFAqlZDL5eByuYhGo5iZmUE2my06HiZC/oEHHkB7ezvcbjcGBwfhdrtJQ1QkEsFms6Gzs3PLJmgqlSIeKkyhKBKJYDKZyFT08iki4ys2MzODyclJIiEF1ht7NpsNjY2NpOBlksX6+vqK0rE2g3lNu91+Xavz1yohbWxs3BD1/mmQy+Xw7rvv4sMPPyRN0EQiAb/fT6QEHR0dePjhh2G1WjExMYHf//73iEajcLvdWFxcLDpvfD4fNpsNFosFBw8eJF4wLCxfNNjajYXl02V6ehpHjhzZMCitrq7GwsICadJc3mQTiUTEL/ZXv/rVdf3bdLvd+PDDD+FyubCysgKlUommpiaSns0oEWiaRmNjIzQaDXQ6HSwWC2nabCcZpWkaPp8PXq8XCoUCOp3uMw8CYkz4Z2dni6SVGo0GarUa7733Hk6ePIlgMIhEIkGu/9XV1Xjqqadwxx13IBaLYWpqigxK+Xw+KisrYbfbUV9fv2VjjKZpzM3NoaenhzRyAKCiogK7du1CQ0MDGejFYjGIRCJoNBqoVKpta+N4PE4UJGKxGI2NjVAoFJibm4Pb7S6yr+Dz+WRL7Fpqqa0abUajEQaDAYlEAl6vl2wAbiYhTafTZButqqoKc3NzuHTpEmiahk6ng0KhIB6DDGVlZaisrEQ6ncb4+DhSqRQEAgGGh4chEolQXV2NVCqFm2++GV1dXeQ8URSFqakpjI+Pk4CQQCCA2tpastWn1WrB4XBA0zR6enpw9uxZ+P1+TExMIJ/Pg8PhwGQyEQnv7bffDoVCgWQyyYYfsHzmsA02FpYbgHQ6jb/85S8YGBjY9O/BYBCLi4uw2WxbTmKam5vx5JNPXtc6NOPvMT4+DrfbDZFIRAz6l5aWwOPxkMvlyPp2TU0NSktLSWFwo0NRFM6ePYt33nmHNIaYBFUmedTpdMLtdkMsFhc1n5RKJdra2vDQQw+htrYWPp8Pw8PDmJycJCvqzGp/R0cHHA7Hpp9BLpfDxYsXcf78eTKt5PP50Ov1ZCp65UQ1n89jdXUVExMTmJ+f3yAhZaSZTEHETAIvXrxICqfNUCgU6OzsRGdn53WbGV+LhLSpqQlWq/VTX9VfWVnBoUOHSDFMURSCwSAikQhomkZpaSm++c1v4t577wWHw8Hvf/97TE5OIhQKYWpqaoNkWKfTwWw243vf+x5uu+22G76hzMJyJWztxsLy6RMKhXD48GHSTGOQSCQIBoNF4TrxeBypVIo02aRSKf7jP/4DZrP5ml+3UChgYGAAvb29xD/NbDbDZDLB6/UiEomAy+UiEolAoVAQNUJ9fT2qq6vJ8/h8PqysrGwqGU0kEkQi+HluA3m9XszOzhZtezGDxqNHjxLZILP9zwzW/vM//xNGo5H4vA0NDZFNP4VCQUIRqqqqthw8Li8vE49dpgFWWlpKPHZ5PB78fj/C4TDZDCsrK9u2IRkIBDAyMgKfzweFQoHW1laUlpZiYWEBc3NzG+oqtVoNi8VyTd5imUwGTqcT8/PzGzba6uvrUSgU4PP5iOwVKJaQAuvXkEgkgoqKCmQyGfT09CAUCqGqqgoOhwM8Hg/z8/NFQVxisRhSqZRImZVKJcbHx5HP51FTU4NMJoOOjg58/etfLzqWWCyGCxcuwOfzIZ/PY21tDVwuF1arlSTFMqoCl8uFN998Ez6fD6Ojo0S1oFariY3L7bffDqPRiHg8zoYfsHymsA02FpYbBJqmceLECbzyyisbEqqAdcnh5OQkzGbzllK/+vp6/PKXv7yu5E+apjE2NoaTJ09ifn4ePp8PWq0WNpsNiUQC4XAYfD6fFGxNTU0oKytDTU0N6uvrvxA+BzRN4+LFi3jrrbeKvFOYQIR8Po9Lly6RYpQpjjgcDiorK3HnnXfi3nvvhVwuRzKZxMzMDAYHB7G2tlYU997U1ISOjo5NZRUURWF0dBTnzp0jxTiPx0NlZSW6urpgt9s39TxJJpNEqrm6ulokIWVSSGtqasDlcpHL5TA2Noa+vj4sLy9veT6Y1Niuri4YjcbraiAxEtLR0VEiC9gMgUAAq9WKxsbGT1VCSlEUTp8+jbfeeos0zDKZDLxeLzKZDDgcDhwOB773ve+hpaUFH3zwAY4cOYJ0Oo2JiQlEo9Gi5yspKYHD4cBNN92ERx99tOiHBwvLjQ5bu7GwfDbk83m8/fbbGB4eLrqd8WBjBnLAv5psQqEQCoUCQqEQv/jFL9DY2Hhdrx0Oh/Hhhx9iamoKy8vL4HK5aG5uhlwuh8/nQzKZRDabRSwWg8lkQl1dHdRqNex2O2lYXC4ZZYaoDIVCAcvLy4hGo6ioqEB5efnnNnAKBAKYnZ1FIBAgt0mlUni9Xrz33ntYWlpCLBYjtYhEIsH+/fvxxBNPQCqVkqCp8fFxcs0Xi8XQ6XRoaGiA2Wze0gYlEAjg3LlzGBsbIw2pkpIStLS0oLOzE1KpFMFgEIFAADRNQ6VSQaPRbJtEv7KygtHRUUQiEVRWVqK1tRVKpRJra2twOp0b1AKMAqKurm7Hzc5sNksabcz7vrzRJhKJEAqF4PP5iP/c5RLSXC6H1dVVyGQyqFQq9Pf3Y25uDnK5nNR1wWAQs7OzRTVUJpNBJBJBJBKBRqOBy+VCOByGyWRCNpuF3W7H3XffvaEROT8/j6GhIeRyOcTjcXg8Huh0Ouj1ehgMBtTW1oLH4yGZTOKtt97C7OwsZmdnsbKyAuBfwVUcDgfd3d3o6OhAJpNhww9YPjPYBhsLyw2G0+nE888/v6mvFrNazqTtbHaRqKmpwdNPP33dm2XpdBqnTp3CwMAAmfIxEdqM5wSwvvVWXl6OhoYGyOVy2O12aLXa63rNz4Px8XG8+eabJOoeWC/SlEolwuEwBgYGwOFwoFKpyMWfSU968MEHsXv3bvB4PFJ4Dg4OYm5ujqSH8Xg81NXVobOzEyaTacNnRdM0Zmdn0dPTQ5p9XC4XarV6W383JgCDkZAGAgFSSDI+K42NjcRPZXl5Gf39/RgZGSla5b+SiooKdHV1obW19bpNjdPpNCYnJzE2Nva5S0iDwSBeffVVjI+PA9gYgiCTyfCNb3wDDzzwAGKxGF588UUEg0EsLCyQaT0Dj8cj/nI///nPSRIWC8uNDlu7sbB8tly6dAnvvPNO0bCJoih4vV6k02lSCzAJigKBAAqFAnw+Hz/96U/R1dV13a89PT2NEydOYGlpCWtra8Q3N5fLIRgMIpvNIhqNgqIoNDU1obKyEgaDAXV1daSecblciEQiUKvVG5oRoVAIq6urEIvFqKmp+VwHq+FwGLOzs0Wp8BwOhySuh0IhIpEF1v1Yn3rqKdx5552kOZhKpeB0OjE0NITV1VVwudwNoQibEYvFcP78eQwNDRUFatntdnR1dUGtViMcDsPv9yObzUKhUECj0WzZECsUClhYWMDExASSySSMRiOam5tJaMX8/HxRwipzrEz4AxM8cDW2arQxg3KpVEpSQK+UkJaUlJCmcHV1Nebn53Hx4kVQFAWj0Yi2tjaUl5fD5/NhdnaWbG3mcjksLCxgbW0Ner0e6XQaKysrqK2tBQAYDAbcf//9G+rOdDqNS5cuEdms3+9HOp0m4Qc2mw1lZWWgaRp9fX346KOPsLq6iqmpKbKkUFtbC6VSiZqaGtx2223g8Xjg8Xhs+AHLpw7bYGNhuQGJxWJ44YUXMDU1teFvqVQKIyMjEIlEcDgcm07GKioq8MwzzxQlVl4rS0tLeP/99zE1NVUUgiAUCokZaSKRQCwWg9VqhdFohEajQWNj4zXFz3/eOJ1OvPXWW0WpnMxUeXFxEePj45BIJCgtLSUSTrlcjvb2dnzve98jRQKwXvCNj48Trw/mP68ajQZtbW1oaWnZdLvQ7Xbj7NmzmJmZIY9RqVRoaWlBe3v7lk2oQqFAJKRzc3NEQsoENTDSTLFYjGQyiUuXLqG/v3+Dsf/lCIVCtLa2oru7+5qTrC7nRpCQ0jSNwcFBvPbaa0Q+kM/n4fP5yHuqra3FgQMH0NzcjJdffhkjIyPw+/2Ynp7e0JCsqqqC3W7HT37yE+zdu/cTfa8sLJ8GbO3GwvLZs7q6isOHDxdtrdE0jcXFReRyOdJMYJpsfD4fSqUSPB4PjzzyCL7+9a9f92tns1mcOXMGFy5cwMrKClKpFBwOB3Q6HeLxOEKhENLpNILBIJRKJVpaWlBWVkZkeMD2ktFMJgO3241cLgetVvuZ+KxuRywWg9PpxOrqKqmfwuEwent7SV3E+ONxuVy0tbXhP//zP2EymchzFAoFskk2MzNDtp1MJhMJRdhM7pnJZNDf348LFy6QGkMoFKKurg67d++GVqtFPB6H3+9HMpmEVColgQibNcSY8CUmWbO+vp7YjlyesHplDceEAhiNxh2FkGWzWczNzWFubo7UORwOBzU1NbBaraQRGI/HiySkhUKByJuZxm1PTw/8fj8qKyvR1NQEi8VCZMlOpxPLy8ugKAqrq6u4ePEiUYz4fD7U1taCz+ejvLwc+/fv3/S3w8rKCrE9SafTWF1dhVqthtFoRHV1Nerr6yEUCrGysoIjR45gdXUVo6OjZNitUqlgMBgglUpxzz33oKysDBRFseEHLJ8qbIONheUGhaIoHDlyBG+//faGv2UyGbKR5HA4Ni1wVCoVnnnmmSKfjet5DwMDA/joo4+IB5jJZILFYkEmk0EsFkM2m0UwGASHw0FzczM0Gg1MJtO2oQw3IktLSzh27BguXrxIotP5fD5EIhEmJyfhdrshl8tRWlpKGkEVFRW48847cf/99xd5mWWzWTLdW1paIlNHsVhMEiqrqqo2vAe/34+zZ88WNfvkcjkcDgc6Ozu3TRtLpVJEQrqyskKmqiKRiEhIDQYDOBwOnE4n+vr6irb3NoOJp2d8Nq6HG0FCmkgk8MYbb6Cvr4/cxhS9+XweQqEQe/bswYEDBzA/P4/XXnsNiUQCExMTpGhmkEqlcDgcuOuuu/Dwww+zU1CWGxq2dmNh+XxIpVI4cuQIZmdnyW00TcPpdCKZTBLT9XQ6jVgsBh6PB5VKBS6Xi+985zv45je/+bGugx6PBx988AFmZ2exvLwMhUJRtBXFKBIikQhMJhPq6+tRWVkJm80GiUSyrWSUCR8IBAIoLS1FdXX15y67SyQSpKFD0zQoisLMzAz6+vpIg4upP8RiMfbv34+DBw9uaLJEIhFMTU1heHgYwWAQQqGQhCJYrdZNm0D5fB7Dw8M4f/48aX4xAQXd3d0wmUzIZDLw+/2IRCIQCoVQq9UoLS3d9Lwlk0mMj49jcXERfD4fDQ0NsFgs5PsQCoWIb/DlNRWPx0NNTQ0sFsuOTP1zuRxptDF1L9Noq6+vJ8dKUVSRhDQUCiEcDkOv16Ourg7T09OYmZkhIWjM9wxY/3cwNzeHxcVFhEIh9Pb2Ih6PQyKRIJFIwG63QyKRQKFQ4IEHHtjUWiWfz2NkZIQEeYXDYUQiEZjNZpSXl8NsNqO6uhqZTAZvv/02JiYmMDU1RbYbBQIB8Tm+5ZZbYLPZkMvl2PADlk8NtsHGwnKDMzQ0hD/84Q8bTOtzuRwx9jSZTNDpdBuKMZlMhqeffrpoy+p6iEaj+OCDD0jDiJleVVRUIJVKIRaLkdTG8vJyknLZ0NBwzRH0nzderxfvvPMOzp8/X2SWy0SsRyIR0mgTiUTg8/mora3Fd7/7XezZs6eoEUXTNLxeL4aHhzExMUFSsDgcDvR6PTo7O2Gz2TaEG0SjUZw7d26D/MBisaC7u5ukKW1FOBzGzMwMJiYm4Pf7SQGmVCpht9uJfx4jhb1w4QKZ9m1GSUkJOjs7sWvXro8VavF5S0inp6fxj3/8A36/H8C/QhAYOXZVVRXuvfdeWCwWvPzyy8RU+cr3yuVyYbFY0NnZiZ///Ocfa9OPheXThK3dWFg+P2iaxtmzZ3Hq1Kmi22dnZ+H3+0n6+OVNNmaT7Y477sBDDz30sZpsNE1jaGgIp0+fhsvlQigUgslkgtVqBZfLRTAYJM2SQqGA5uZmVFdXo66uDkajETRNbysZjcViWF5eJo2dG8GjNJVKYX5+Hi6XCxRFIZPJ4Ny5cxgZGSFbUMB6HVZRUYFf/OIXuPvuuzec51wuR0IRmDRNpVIJi8UCu92OqqqqDY+hKAqTk5MkdAIAURR0d3ejvr4eNE0jEAggFAqBw+GQ1Msr60Bgvdk3PDyMtbU1yGQyNDc3o6amhvw9m82SUIQrh4FlZWUwm83Q6/VXHZBu1WjT6/UbmoqMhHRxcRELCwvgcDjQ6XRIp9NYWFgAsD6cbWlpKRokM+dzamoKp0+fhsfjQaFQQCgUgtVqhV6vh1QqxXe+851NB9DAujJiYGAA0WiUhCCIxWLiTW2z2SCTyXDx4kWcOHECi4uLpCkHrCsmFAoFHA4Hbr75ZlAUxYYfsHwqsA02FpYvAD6fD88999wGb6hCoYCxsTGEw2FoNBpYrdYNF2mRSISnnnoKdrv9Y7+PmZkZIhsNBAKoqqpCQ0MDRCIRkRxEIhHEYjHU19fDZDKhuroaDofjuoIXPk9CoRCOHz+OM2fOkMYYE1s/Pz+PXC6HkpISlJaWksSk9vZ2PPzww5s2NJPJJKampnDp0iWsra0Rjwi5XI7m5mZ0dHRsaCal02n09fWhv7+fSBpFIhHZLNvKh4+Boiisra1hfHwcc3NziEajREJaVVWFxsZG2O128Pl8jI2Nob+/H263e8vn43K5sNlsZCL7cQr/z0tCmsvl8M477+DEiRPkM0in0/D5fMhkMuByuWhvb8c3v/lNDA4O4tKlS/B4PESycTkVFRVoamrCz372M7S3t3/s98bC8knD1m4sLJ8/c3Nz+Oc//1k0KHW5XJidnYXRaIRAIEA2myUhUkqlEnw+H/v27cOPf/zjj70dlkgkcPLkSQwODpKGWGNjIxkOra6uIhaLwe/3Q6lUorW1FeXl5cSHjJGMisVi1NbWFvll5fN5LC0tIZFIoKqq6mNZk3ySZDIZLCwsYHFxEfl8Hn6/Hx988AFcLhdyuRwZPPJ4PLS2tuL//J//s2mSK1P3jY2NYXJyErFYDBKJBHq9Hg6HA2azeUNtQtM0FhYW0NPTQ5pOAFBeXo7Ozk40NjZCIBCQQIRCoUACETbzwPV4PBgZGUEwGIRarUZra2tR2BkzzL1SKgv8yzu4rq7uqvYtuVwO8/PzcDqdRY02nU4Hq9VKAjEYwuEw8a8rLS1FJpPB/Pw8MpkMUU7U19cXNfgoioLb7cbhw4cxNjZGQqjKy8thNBpRUVGB/fv3b7kYQFEUpqamMD4+DoqiiIzVaDQST8Ha2lr4fD4cOXKEKCgYmbBCoYBOp4NGo8Fdd90FqVQKuVzOhh+wfKKwDTYWli8IuVwOf//739HT01N0OzMx8/v9kEgkaGho2HAR5fP5n1gTIJfL4cyZMzhz5gwpXOrr60mzJxaLIRgMwu/3kySrqqoqmM3mqzaEbkRisRhOnjyJU6dOIRwOg6ZpUrh5vV5wOBxIpVKUlZVBIpFArVbjrrvuwre//e0i2SgD46PBhCIwBTePxyNbUbW1tUXNq3w+j8HBQfT29pKEJz6fD51Oh+7ublgslqs2n5j3PDY2huXlZTLFZYqv5uZmGI1GeDwe9Pf3Y3h4mBRYm6HRaNDV1YW2traPNbW+VglpU1PTptua18ry8jIOHTpEGoo0TSMSiSAYDIKiKJSWluKOO+6AUqnE+++/j0gkQgyIL0cikcDhcGD//v04cODAF0oWzfLlh63dWFhuDKLRKA4fPkySDoH1xtbQ0BAMBgNKSkpICMH/x96fB8d1n1fC8Ol93zegse87CBALSUmkJCpabcX2OGMndjyJ4zh2HMeW7UneTDSpmUpVvsn8MY4rNd+8dmInduR4bMnabMlcLFLiIq4gSIDYdzS6G+h93/ve+/4B/n6+3QBIkAQpUuxT1QUSBIFG9+2+557nOecIBAIqsvX09OCP//iPd2TAtLy8TG2jXq8XNpsNHR0d0Gq1yGQycDgciEajCIfDqK+vR3NzMxVXSBD/ZpZRYD3iwuPxQKVSobKyctONrA8CZHOKlATMzMzg17/+Nd3aI1xCLpfj+eefx5e//OUthahkMon5+Xk6KBWLxTAYDGhtbUVra+umdkO3242zZ89iZmaGDul0Oh16e3vR09MDpVKJcDiMQCCAdDoNjUYDk8m0QcwiwtT4+Dji8TgqKirQ3d29gWcmk0ksLCxQoYsPUopgs9muy6Hy+TzdaCMRJ1sJbcQuPDc3R3PapqensbKyAovFgl27dmHv3r0bfh+O43D8+HEcPnwYa2trWFtbg1arhclkgtFoxKc+9Sns3r17y/sYi8Vw6dIl+Hw+sCwLn88HjuPQ2NhIiypUKhWOHj2KkZERTExMbLDvqtVqPPHEE6isrIRMJiuVH5SwYygJbPcZ0jkGV1bCiKfzUMvF6KnSQy4pXdA9KCB2g//7f/9vQQA7x3GYnZ3F2toahEJhQVgtgUAgwOc+9zk8/PDDO3JffD4fDh8+jJGREaytrUGn01FrqEQiwdraGoLBIHw+H930Ia2jWzUz3ctIJpM4c+YMjh07Br/fTzMp5ubmkEwmIRKJaBmCVqtFdXU1Pv3pT2Pfvn1bii7hcBhjY2N0MklgNpuxe/dudHV1FYhXHMdhfHwcZ86cobZFoVAIm82G/v5+tLW1bUvsikQi1ELq8/moqKXVatHS0kIbTK9cuYILFy4gEAhs+b0kEgm6u7sxODi45Vr/dkEspGNjY/B4PFt+HTnW2tvbb8tCyrIsTp48iV/96leUROZyOfh8PiSTSQgEArS2tmLfvn0YGhqim2ykHYtAKBSirq4O+/fvx5e+9KVSpkcJ9wxK3O3eQIm7lQCsD9h+/etf49KlS/Rzfr8f58+fh81mg9VqRS6Xo+UIOp0OEokELS0t+MpXvrIjFsx8Po8LFy7g9OnTWFlZQSaTQXNzM+rq6qDRaOD1euFyuRAKhZDP59HV1YXq6mo0NjaivLwcKysrW1pGU6kUbX2srKzcIKp8kCANqcRO+f777+Ps2bN02CgSiaht9Itf/CI+8pGPbClCkfb4sbExzM3NIZvNQqlUoq6uDm1tbaiurt4wTA6FQjh79izGxsbo8FKtVqOzsxN9fX3Q6/U0Gy8ej0Mul8NsNkOn020YuM7Pz2Nqagq5XA51dXXo7OzcsPnGsiwtRSjmcCqVCvX19airq7tuKUI+n6cbbfwGUyK08cW9WCxGn3u5XI4rV65gZGQEuVwOVVVV2LNnD7q6ujbcz/n5eRw9epTGh8hkMsolH3vsMXzkIx/Z8pqB4zgsLi7SgXA6nYbH44HNZkNFRQXKy8vR2NiIqakpHD16lDaoEthsNuj1euzevRt9fX0Qi8Wl8oMSdgQlge0+QTrH4N/PO/DLETe80TQYloNIKIBVK8fzu+z4/b3VkIlLZO1BwfLyMr73ve8VnDTJSjrZyLHb7aivr99wkv+d3/kdPPnkkztyPziOw+joKN555x3MzMwgmUyiqqoKLS0t0Gg0lNAEAgFEo1G0tLSgvr4e1dXVaGlp2Vbb0b2GTCaD8+fP49ixY1hbW0M2m4XT6cTS0hJYlqXFCCRXY/fu3fjsZz973Rw80uh06dIluFwuSr7kcjna29vR19e3QTBdWFjA6dOnqf1AKBTCaDSit7cXXV1d2yK2LMvC4/FgcnISc3Nz1EIqEAgKLKSrq6u4cOECpqamrluKUFVVhcHBQbS3t9/29Nrv91ML6fXy4aqqqtDR0XFbFtJgMIhXXnkFk5OT9HP8EgSVSoV9+/Yhl8thZmYGq6urWFxc3GAZJcLol7/8ZbS1td3SfSmhhJ1Eibt9sChxtxI2w9jYGN5++206KA2Hwzhz5gxUKhWqq6tpkDvwG5GtpqYGX/va13ZMtCIxGKOjo3C73VTssdlsUCqVmJ2dRSAQKLCN2u12tLa2IpvNbmkZJY2RoVAIZrP5httSdxt84Wl1dRVHjhzBxMQEcrkcxGIxvXV2duIb3/gGWlparvv9wuEwHQwGg0HIZDLYbDa0tbWhqalpg1gTj8dx4cIFXL58mYp7pABrYGAAVqsV6XSaFiKIxWJaiMAf1pJc4Pn5eQgEAjQ3N6O1tXXTgW44HKZFA3yXgFAoRFVVFRoaGq47+M7n81haWqJiIkGx0JbL5eBwOJBOp1FRUYFgMIijR49icXEROp0ONTU16Orqgt1uh9FopNcnbrcbp0+fhsvlwpkzZxAOh6lTobW1FQ899BAt4djsWEqn07h8+TKcTic4jkMoFEIqlUJDQwMMBgO18f7iF7/A7OwsxsfHKc9WqVQoKytDdXU1Dh48CJVKVSo/KOG2URLY7gOkcwz+6xtjODnjA8tx0MolkIiEyDEsoukchAIBDjRb8Hef6CwRtQcIiUQC//Iv/4KxsbGCz6+srNAJDWmgLJ56Pvvss/jYxz62Y6QnlUrh3XffxenTp+F0OiGRSNDW1ga73Q6DwQC/3w+n0wmv1wuRSISuri5UVlaiqakJVVVV9xT52i5yuRyGh4fxzjvvwOVyIRqN0kB8oVBIhTaDwYCqqir6mG9mGyUgORpXrlzB1NQUDa0lrU6kFIFPoFwuF06fPl2QEabX69HV1YWenp5NG5k2QzabpS2kKysrlPhJJBKapWEwGHD58uWCSvrNoFKpsHv3bvT39982SWFZllpb76SFlOM4DA8P0wZR8rMDgQDdJqitrUVdXR1mZmYQjUYxOTm5wYIhl8vR1taGz372s3juuefuy2O7hA8PStztg0OJu5VwPXi9Xrz66qt0ez0Wi+HMmTPgOA719fUQi8WIRCLgOA5arRZSqRRlZWV44YUXtn1e3w4mJydx/PhxzM/PIxwOo6amBi0tLbBarchms5icnKRFCPX19WhtbUVNTQ3Ky8vhcrnAMAyqq6s3nOsjkQhcLhekUimqqqo2zRb7IMFxHNxuN+bn5zEyMoLDhw/T34cUWCkUCjz33HP48pe/fF3uBvymcGB0dJQWLOj1ejQ2NqKtrQ1Wq7WAD2QyGVy6dAmXLl1CNBoF8JvIjsHBQVRVVSGfz9NCBI7jYDAYYDabC4aJ8XgcY2NjtICss7NzQ8wIAbHLzs/P02xhAiJEVVVVbem6IDbhubm5Au5jt9vR3NwMrVYLjuOwurqKQCAAo9EIvV6PoaEhDA8PI51Ow2AwoLGxkdpBrVYrNBoNgsEgTp8+DZ/Ph+HhYTgcDirCkRgTnU5H7+NmcTNutxvDw8NIpVLI5XLweDzQ6XSorq6GyWRCfX09zp07hwsXLmB8fJxyO6FQiIqKCphMJjzxxBMoKyuDxWIplR+UcMsoCWz3AX5wehHfOzEPhUQErWLjhkY0lUMqx+BLjzbgC4/UfQD3sIQPChzH4Ve/+hV++ctfFmwWra6u0uYcYi8onk4dOHAAv/d7v7ejmWjLy8s4dOgQxsbGEAqFaKOo0WiERqPBwsIC1tbWqG2UNFaRDJD7EQzD0C0+h8MBt9uN6elpRKNRiEQiSCQSSKVSGI1GtLe34zOf+QwefvjhG2Z1JZNJTE5O4sqVK/B6vVQ802q16O7uxu7duwsIXyAQwPvvv4+xsTE6FVer1Whra8Pu3btviijEYjHMzs5iYmICXq+XiloajQYtLS1ob2+Hz+fDhQsXsLy8vOX3IVPVwcFBNDQ03LbYlEqlMD09fUctpIlEAm+++SYuXLhAP8cvQZBKpWhra0MkEkE8Hsfs7GyBvRdY/71ra2vx9NNP4wtf+MI9ZZMp4cFCibt9cChxtxJuhEwmg7feegtTU1MAfhNFkUwmUVNTA61Wi0gkAoZhoNVq6dDuG9/4xo5e/GcyGZw6dQpnz56F0+mEQCBAR0cHqqqqUFZWhpWVFczPzyMYDNK20bq6OjQ0NNDcOLPZDLvdXsApyYZ/Op1GeXn5jgqDOwWO4+DxeKiN8NSpUwiHw9TqKJPJYLFY8IUvfAHPP//8DXkMyQMjpQjxeBwKhQJVVVVob2+n4ikB4ZDnz5+nrhTSyjo4OIj6+noIBAKEQiG6Va/VamE2mwvKwwKBAK5evQqfzwetVotdu3ZdN7aDlCK43e4NpQi1tbWor6/fkrtsJbSVl5ejpaWFHrdE9KuoqMDi4iKGh4cRjUbp/ddqtRAIBJDL5bBarZDL5bhw4QJCoRCuXr2K2dlZ2Gw25PN5WCwW9Pb2QiKRQCaToa6uDrW1tRucMLlcjlp3gXU+Gw6HUVdXB6PRiJqaGiQSCRw5cgTT09NwOBz0/xLnCbGzGgyGUvlBCbeEksB2jyOdY/C7/3QOzlAS5ToFMpkMBAIBRCIRRCIhgPU3+tVICpUGJX76J3tLuR4PICYmJvD973+/oI3R5/NhenoaLMtCIBCguroa1dXVBeSgv78fn//853c0jJZhGFy4cAHvvPMOFhYWwDAMGhoaUF9fT0+UExMT8Pv9tG20ubkZtbW1aGpqum8DRlmWxcTEBN59913Mzs5ienoa8/PzyGQyBUKbxWLBZ2rr8JhWC7XmxsKL8uGHEW5vx/DwMBYWFiiZEYlEaGpqQn9/f8HzGovFcO7cuQL7gUKhQH19PQYHB1FZWbltoYts1E1MTGB2drbAQmqz2dDe3g6TyYSxsTGMjIwUWAeKYTQaMTAwgN7e3h1plL0ZCylpsrqZY2t6ehovv/wyJbzEshMKhcCyLKxWK1QqFcLhMK2sLz6dGo1G7NmzB1/96ldRX19/a79oCSXcBkrc7YNBMXfbCiXuVgIAGjtBSpTOnj2LcDgMq9UKm82GWCwGhmGg0Wggl8uh0Wjwta99DdXV1Tt6P1ZXV3H06FE6XLNYLOjs7ITdbodWq6Wt2n6/H1qtFj09PaipqYHJZEI0Gt3UMkpaOL1eL3Q6Hex2+z1bBkS2p372s59hdHSUZrHK5XLI5XJ0dnbia/39MPCaQa8HyeAg3LW1GB0dpaUIRqORliLwB4Asy2JmZgZnz56lRRik8X1gYIDGYESjUfj9fqRSKahUKpjNZjpsJVt5Y2NjiEQisNls6O7uvq6wmUqlaCkC4YwENpsNDQ0NKC8v35Q3MgyD5eVlzM7OFghtZWVlaGlpgVwup62tlZWVSKfTOHfuHFwuF8xmM918jEajdDCsUChoQcPi4iIuXbqEyspKKJVKKBQKtLa20uNLJBKhuroaDQ0NG6y4gUAAQ0NDiEajYFkWXq+XbgjqdDrYbDacOHGC8kgySCYZcM3NzThw4AAMBkOp/KCEm0ZJYLvHcW4hgG/+7ArkEhFUMjEWFxeRTCYhFAogFIogFosgFksgkMgBsQQv7DHgyZ51lb4U0vhgIRgM4nvf+15BJXgoFCo4cRiNRrS0tBScKDo6OvClL31px9f3w+Ewjhw5gnPnzsHr9UKtVqO9vR3l5eWoqqrC8vIyFhYW4PF4IBKJ0N3djaqqKmotvV9BCifee+89jI6OYnh4GC6XCyzLQigUQiKR4OtiCR4RCCCSySCVySDcQvBis1lon30G9r/7OwDrj+nVq1dx9epV2iYKAFarFb29veju7qbTvEwmg4sXL+LixYvUCkBakgYGBlBXV3dTJJdYC8bHx2nGBrBuzSTiaCwWw9DQ0IYSAD7EYjG6u7sxMDCwI88z30I6Nze3IReNgGxydnR0bPscls1mcfjwYbz33nv0+/JLEEQiEaxWKyKRCFiWxezs7AaRUSqVor29HV/84hfx+OOPlyyjJdxVlLjbB4Ni7gZwWFtb37oVS8SQSCSQiCXIcQJkWQH+4VO7sK/R8sHe6RI+UDgcDrz++uuIx+PI5/M4d+4c/H4/1Go1qqurkUqlkM/noVaroVAoIJfL8Wd/9mdobm7e0ftB4hKOHz9OhY6mpiY0NjaitrYWiUQCw8PDCAQCdDOoo6OjQITZzDKaSCTgdDoB4J4PkifZYT/+8Y9pEYRQKFx/zDlgIJeDVKXakrsBhfwtn8/D5XJhdHQU8/PzNNu1oaEBra2tqKyspFtSHMdheXkZZ8+exeLiIh3cmc1m9PX1oaOjAwqFAolEgg6qZTIZzGYz9Ho9BAIB3TAjzeck++x6jznLsnC5XJifn4ff7y/4N6VSSUsRNrtWIELb3NxcgUhXVlaGpqYmJBIJmsmn0+kwPDyMqakpGkHS0dFBB7qk3XVychLpdJpyWaPRCJvNhvLycuoiII+NQCCgDal8MZFlWUxNTWFychIsy1I3QkVFBaxWK6xWK5xOJ86dO4exsbGCWBar1YrKykocPHgQFRUV9/wxW8K9hZLAdo/jnQkPXnz9KvRKKaRi4fqUIJvZ8HWcUARIVJCOvAp5cBZyuRxqtZoGjFZVVdF1WrvdDovFcl8GzJdwfeTzebzyyit477336Oei0WiBbZBkRPHthfX19fjqV7+6ZTX57WB6eppaIFKpFOx2O9ra2lBWVgaDwYDR0VE4nU5qGyVCW0dHxx25P3cTi4uLOHnyJE6ePImhoSGao/EtmRx7RSL4BALIZDJYrZb1VXkUkrWcywX1EwepwEaQzWYxPz9PSxH4z21HRwf6+/thNpsBrB8TIyMjOHfuHN3GEovFsNvtGBgYQFNT002/F8TjcczMzGBychIej4cKuCqVCi0tLdDr9Zifn6ekZitUVFRgcHAQnZ2dO7JFebMW0u1ak51OJ37605/SiwNgfVMwEAggn89DqVQin89DKBRiZWWF5qkQkA3Sj3/84/jDP/zDHWmCK6GE7aDE3T4YFHM3lmXgcKwUtH8LBAArEIEVy9GZHkOzKgOLxYKysjKUl5fDZDLBaDTCYDBApVKVxPkHAPF4HK+99hpWVlbAMAwuXrxIN59qamrAcRxyuRxUKhWUSiUkEgn+5E/+BN3d3Xfkvhw7dgwXLlzA2toaFAoFbRStra3F9PQ0pqamqG2xq6uLbhJJpdJNLaOkfTMajcJqtcJisdzTx3UoFMK//uu/4vXXX6etql+FAAMAojIZTCYT9AYDNvsNNuNvZBN+cnIS4+PjCIVCkMvlBaUI/A1/j8eDM2fOYGpqinIpUjjR09MDjUaDTCZDc9pEIhG1OYpEIuRyOUxPT2Nubg4Mw6CxsRHt7e033MaKRqOYn5/H8vJywXuWUChEZWUlGhoaYDKZNvw/lmXpRhtfaLPZbLDZbIjH41AqlaioqMDS0hKGhoYQj8epfbampoYOMb1eL0ZGRuByuZDNZjE2NgatVguLxYLGxkY888wziMViG+6jyWRCQ0NDQSECGf76/X5wHIdgMIh8Po+GhgZoNBqIRCIqspHtQY7joNFoYLfbsW/fPnR1daGioqJUflDCtlAS2O5x8KegSqkIi4uLyGazYFkWHMeCPHmcSAqBSArJ8P+FKLCw5fcTCASQSqWQyWTQ6XSUzFVVVaG+vh5VVVX0jfBeXeEu4cY4f/48XnrpJdqSk0gkcPXqVbpdIxQK0dDQgLKyMnoCstvteOGFF246s2o7yGazOHnyJC0EEIvFaG5uRk1NDRobG+lE1OPxIBaLobm5Gc3NzWhqakJDQ8N9fyw6nU6cOHECv/jFL3DlyhV8Oc9gn0gEzzW7pUAohFKpRFlZGZQ8crWVwEZAskOuXLlCJ5XAbwQdIqAJhUJwHIeJiYkN9gOr1Yq+vj60t7fftHWTWD+IhZRMFAUCASwWC2pra5FKpTA2NrYhUJcPpVKJ3t5eDAwM7FhGy05bSFmWxYkTJ3Do0CH6OmIYhjbkAqCENp1OY3V1dcP30Ol0OHDgAL7+9a+joqLiNn/DEkq4MUrc7YNB8QYby7KIRCLI5XLI5/PI5/NgWRY5TgBGIEa16zjUqTX6/0kUCOFrSqUSJpMJZrMZVqsVZWVlMJvNMBgMMJlMO2K7L+HeAMuyePfdd3Hu3DmwLIvLly9jZWWFbunIZDI62FGpVBAKhfiDP/gD7N27947cn4WFBRw9ehSTk5OIRqOw2+3o6OhAfX09dDodzpw5A5fLBZ/PR8Ufk8kElUoFvV6/wTIKrG+Ira2tQS6Xo6qq6p6337ndbnz729/Gu+++i88nkhgABy8EEAmFUCgVsFptBdwNuDF/y2QyWFpawpUrV+jzq9Pp0NzcjLa2Nlgsv9loDYfDOHv2LK5evVrQftnR0YG+vj4YjUbk83kEg0EEAgFwHAe9Xg+TyQSZTIZkMomJiQksLy9DLBajra0NjY2NN8wWI62g8/PzGwaHer2eFg4UD0hZloXD4cDs7CxSqVTB/1EoFFCr1aiqqqKW0ZWVFZjNZioAkuMlHo/j3LlzuHjxItLpNK5cuQKRSASLxYLm5mZ8+tOfhs1mo44YvqinVqsLChE4jsPi4iJGR0eRy+WQy+Xg9XqpECyVSjE7O4urV6/SeB1gndeRrOhHHnmEXieXUML1UBLY7nFsluPBcRwYhkE+n0M6k0Emk0EkK4AkG0P55M8Qj4SQTCavm4e0FUhWlFwuh16vh8ViQXl5Oaqrq1FfX4/y8nKUlZXBarWWQh/vcbhcLnz3u9+F1+sFsB7UPjo6umF9my9gmc1mvPDCCwUn9p2Ex+PBW2+9haGhIUQiERiNRprv0dTUhImJCboVJRaL0dXVRdfH79R9upvweDw4cuQImH/8R7TF4vDwtrsEAgGEQiE0Wg3KbGWQSqU3JGh8kFKEy5cvw+v10tV5kpPS29tLA2sXFxdx5swZLCwsgOM4CIVCGAwG9PT0oKur65YKJ/L5PBwOB8bHx7G0tESPM5FIhJqaGigUCqytrd2wFKGxsRGDg4PbIn/bAbGQjo2NYX5+/rYtpIFAAK+88goNpQbWN+d8Ph+y2SxyuRyy2SzEYjFWV1c3tJ5KJBJ0dHTga1/7Gvbt23fbv18JJVwPJe72weBGGWwcOHAsi9VIGlalCN/aLYZvzQ2v1wu/349gMIhUKoVkMolMJrMh35EMS0ngNxHgSGaXzWaDwWCAwWCA0WgsCXD3IaampvDLX/4SmUyGnr+A9UGNwWAAwzBUrACAT33qU3jiiSfuyH0hltXjx49TMai1tZVaHIPBIM6fP08tfiQ2ggwOa2pqNmz+ZDIZrKysIJfLoaKi4r4ourpy5Qomv/pV1PsD8NIVBwHEIhG0Oh2sFgsVm7bL30g+2NjYGKanp5FIJKBUKlFdXY22tjbU19dTjp5IJHDx4kXalAmsOxeampowODgIm81Gt+QCgQAymQwtFFAqlYhEIrh69SrW1tagVCrR2dm57Rw/n8+H+fl5uFyugvcjEhNSX1+/oWl1M6GNcDCLxYKmpiZoNBqMjIxgfHwcIpEItbW16O7uLihmW1pawnvvvQe/34+hoSEkEgkoFArY7XZ87nOfQ19fH7W4zs3NFQx0iwsR0uk0Ll++TB0J0WgUyWQStbW1UKvV8Pl8uHr1asFwlmVZmM1m1NfX44knnkB9fX2BrbeEEopREtjuA9xsE1UqlUI0GsXa2hrm5uawsLAAp9OJtbU1+P1+JBIJJJNJOgW5GZCgdplMBqPRCKvVCrvdjpqaGlRXV1NiZ7FY7vutow8D0uk0fvjDH+Ly5csA1jfJrl69WlCGQJomCQHXarV44YUX7tiGDcdxuHLlCt58800sLi6CZVlaC9/Q0AC9Xo/Tp0/D4XDA7/fDYrGgu7ubVp1/GOx1i//5PyP6zjG4r03R+G/DQqGQ2oNuRmAjYBgGDoeDliIQoZ1sDfb399Oig9XVVbz//vsF9gO9Xo+Ojg709PRsagHYDpLJJGZmZjAxMQGPx0PX91UqFWw2G53c8kNxi2EwGNDf34/du3fvWO5FKpXC1NQUxsfHb8tCynEcLl26hNdff52+ljiOQygUQigUQjabRTweh0QiQSQS2RAcDKxvzn3mM5/BZz/72Xt+el/C/YsSd/vgcDstogzDIJFIIJFIIBgMwu12Y3V1FT6fDz6fD9FoFIlEAqlUCtlsdoMAR4alMpkMEokEGo2mQIAzmUxUgDObzTuewVrCziAYDOLnP/85vF4vjWUAQJstSQC/Wq2GQCDARz/6UXz0ox+9Y7bLQCCAI0eO4NKlSwgEAjAajbRRtKmpCZcuXcL4+Dh8Ph8YhkFbWxvNzmpoaEBFRUXBfeM4DmtrawgEAjAYDCgvL7/nRQv3X7+IwOHDWM3n1197IDlgQpiMRtrceSv8jTSTj46OwuPxQCKRwGQyoa2tDa2trVTAymazuHz5Mi5evIhIJAJg/fqsrq4OAwMDqK6uhlAoRCwWo9d9CoWCNncSQS8QCMBkMqG7u3vbQ+xUKoXFxUUsLi4WbKcB63nADQ0NsNvtBc8zy7JYWVnBzMwM/T9kI45wLYfDQYfvFRUVaG9vR11dHf0+Pp8P77//PhKJBM6dOwen0wmxWAy5XI4nn3wSBw8ehNVqhVQqhdfrxdzcXEGWXHEhgtvtpkIlwzDw+XxQqVSorq5GPB7H5OQkpqam6JICaZWtqanBgQMH0Nvbi6qqqlLcUgmboiSw3QfI5Bm8+PoYTs74wHIctHIJJCIhcgyLaDoHoUCAA80W/N0nOiETX1/U4jgOiUQCkUiEKv1LS0tYWVmB1+tFMBikpI3vad8O+PZTsgFXVlYGu91Os9+sViut6i5dVN4dcByHd955B6+99hpYlkU+n8fY2FjBurdYLEZLSwsVVJRKJb761a+ioaHhjt0vUpNNplJKpRLt7e2orq5Gd3c3/H4/tTMmEgk0NDSgs7MTLS0tqK2tvedJ2PXgfvFFxI8dh8Bmg8/vRygUApPPgwMHhVyBxsZGCASCWyJofITDYYyMjODq1auUhAHreRi7d+9GV1cXJBIJgsEgzpw5U2A/UKvVVJDjZ1ncDDiOQyAQwMTEBKanpwtCaY1GI8RiMRWltoJYLEZHRwcGBwc3kPPbwXYtpNXV1ejo6NjUQppIJPDGG2/g4sWL9HPEdpBKpWhGG8dxiMfjG+67RqPBE088gW9+85sfig3NEu49lLjbB4ed5G7FyGazmwpwXq8XPp8PsVgMyWQS6XR6g5tBIBBALBbTgalUKoVGo4HFYoHFYqHbb8RiZjQaSwLcB4hcLodf/epXGBsbw+LiIkZGRgCsP4/EHiyTyaDRaCAQCPD444/j05/+9B0T2TiOw/j4OA4fPoy5uTlks1nU1taira2NRk2cOHECy8vL1DZaV1dHS7aam5s3iBKxWIwKJlVVVff0IJXwN1F5OXw+H0LhEPJ5BlKJBHX1dZCI13nC7fC3fD4Pp9OJ0dFRLCws0HILfikCKTMYGxvD+fPnabmUSCSi+bYNDQ0Qi8VIpVLw+/2IRCKQSqUwmUzQ6XRwuVyYmJhALBZDRUUFuru7N2yhbQWWZbG6uor5+XkqQhGQ5vq6urqC55IIbbOzs3Q71+fzwWAwYP/+/VAoFLhw4QKWl5dhNBrR0NCAjo4O+j2i0ShOnTqFRCKByclJLC0tQSAQIJVKoa+vD3v37qXLH0ajEbFYDHNzc3C73RsKERobG6FSqXD16lW6HZpMJhGJRFBVVQWZTIbp6WnMz89Ttwf5Hna7Hb29vTh48CDq6upK5QclbEBJYNthpHMMrqyEEU/noZaL0VOl35Hq9UyewY/POfDLETe80TQYloNIKIBVK8fzu+z4/b3VN03QNgPDMIjFYggGg3A6nZifn8fS0hLc7nXrQigUQiKRQDabBcMwG6am1wPJE5FKpVCpVDAYDCgrK0NFRUXB9pvNZoNWq72nT7D3I2ZnZ/FP//RPiEajtKEnGAwWfE1VVRVqa2upWPrlL38ZHR0dd/R+LS0t4fXXX8fVq1eRyWRgtVrR0dGBuro6tLa2YmhoCCMjI3SaR0S2jo6OghXy+wmEoEmubQlmczl41tYQi8dRYbfTHLzbFdgIstks5ubmaCkCsS0qFAp0dnaiv78fRqORTgb59gNClPr7+2mWxa0gn89jZWUF4+PjBXXwQqEQarUa6XSalkBshfLycgwODlJhcCewExbSqakpvPLKK7REAgCdHEciEQSDQUilUkSj0Q2bvWKxGJ2dnfjLv/xL9Pb27sjvVEIJBCXutn3cCf52t7gbHxzHIZ1OI5lMIh6PIxAIUBeDz+dDIBCgToZUKrXBzcC3n5KhqVarhclkgs1mo22FxH5qMBggk8nu6bD6DwMuXbqEo0ePwuFw4NKlS/RcSbLO5HL5elmSQIA9e/bgD/7gD+6okySdTuPdd9/FiRMn4PV6aclSY2Mjenp64HA4cObMGayuriISiaC8vBw2m43yieK8VSIqJRIJlJWV3fIG/Z1GMX/LZDLwer3Upk2wE/yNbMZPTExgfHwc4XAYCoUCZWVlaG9vR2NjI+RyOW2vJ3l4wDq3stls6O/vR2trK6RSKbLZLC1EEAgEMBqN0Ol0WF5extTUFHK5HOrq6tDZ2XlTono0GsXCwgKWl5cL3k8EAgEtRSDFW8A673I6nZiZmaFcKZ1Oo6GhAbt27YLT6cT4+DgEAgFqa2vR2dlJh5CpVAqnTp1CJBLBwsIC5ufnodVqEYvF0NDQgL6+PiiVSohEIpjNZliu2XbJ/SsuRGhsbIRIJMKlS5cQjUbpYJgUiywsLGBmZgYLCwvUdZHL5WAymdDc3IznnnsObW1tpfKDEgpQEth2COkcg38/f+dJVDrHYGQljFg6D41cjF07JOBtF7lcDtFoFF6vFw6HA4uLi1heXsba2hq8Xi+1LeRyObAse90GwWLw7QxqtRrGa6vWpB7ZYrHQWmWNRgOFQlEidDeJSCSCf/7nf8bs7CxYlsX09DSdehHo9Xp6MhaJRPijP/oj9Pf339H7xTAM3n//ffzyl7+E2+2GWCxGY2MjGhsb0d3dDYVCgXfffRfz8/MIBAIwm83o7u5Ge3s7Wlpa7rsV7WKCRpDJZApIzU4JbATEjjE8PIyZmZmCUoS6ujrs3r0bzc3NyOVyGBoawoULF+imo0wmQ0VFBfbs2YO6urrbav1MpVKYmZmhVk1CeIRCIQQCASU5W0Eul9NShJ0k4TdjIe3s7ER7ezu1kGazWRw+fBjvvvsuve+kBIEMLAQCATKZDM3b48Nut+OLX/wi/uN//I8le30JO4YSd7sx7gZ/+6C5Gx8syyKVSiGRSCAWi8Hn88HtdlMBLhgMIplM0g244mEqf1gqlUohl8uh0WhoAYPRaCwQ4MjA9H7eOr9X4HK58Nprr2FmZgYXL16kwzIiJqhUKuh0OggEAnR1deFLX/rSHXeLuFwuvP322xgdHUU8HofNZkNXVxfNDzt79iyuXLlCG8fLy8tRWVmJwcFBtLa2buDxfr8fHo8HarUaFRUVO9IwvpPYir8VY6f5WzqdpqUITqcTQqGwoBTBbDaD4zg4HA6cPXuWbl0B60LS7t270dnZCaVSCYZhEAqF4Pf7wTAMdDodNBoNlpaWMD8/D4FAgObmZrS2tt4UHyFZvPPz8wWuCWA9fqahoQE1NTX0OSW5aTMzM3C5XAiHw1AqlWhqaoJKpcLCwgJCoRAqKipojIxQKEQul8PZs2fh8XjgdDoxMTEBi8WCdDqN+vp67N27FyKRiA5y5XI5rFYr9Ho9VldXNxQiaDQa1NXVIR6P04IDIkaWlZWB4zgMDQ1hZWWF/l75fB5SqRS1tbV4+umnsWfPnlL5QQkUJYFtB5DOMfivb9wZG8D9hlQqRe2nDocDKysrcDgcWF1dhd/vRzweRzKZBMMwYBhm2wKcUCikZI6swhMLalVVFcrKyqi1gZAM0q5UQiEYhsEbb7yBo0ePguM4zM3NbWg8lMlkaGtro9PQz3zmMzhw4MAdv2/BYBBvvvkmzp49i3g8TrOw6uvr0dfXh6WlJZw6dQoulwvJZBINDQ3o6upCR0cHqqqq7hvB9YMiaHwkk0mMj4/j8uXLtLocWBePenp6sHv3bshkMoyMjODcuXM0y0IsFqO8vBz9/f1obm6+LesQqUufmJjA1NQUtZByHId8Po9MJoNsNntdAbWhoQGDg4Nobm7e0de7z+fD+Ph4QTvrZii2kDqdTvz0pz+lAbrA+vui1+uF2+1GIBCAQCAAy7Ibfi+1Wo1nnnkG3/rWt0rT0BJ2BCXudn2U+NtG5PN5JJNJJBIJhMNh+P1+uN1ueDwe+Hw+RCIRKsARNwPhcsR+SviaRCKBUqmkG3AWi4XaT4kFVaVSQS6X33NCyr2KZDKJN954A0NDQzh37hzdGGJZlmbrabVaCIVCNDU14Stf+codt7AxDINLly7hV7/6FS1BaGpqQltbG3p7eyEUCvHOO+9gfn4efr8fQqEQlZWV2LVrFw4cOACVSlXw/VKpFFZWVsCyLCorK2mRw72AD5q/MQxTUIqQSqVoKQLJLRMKhfB6vThz5gympqaoEKvT6dDd3Y2enh5otVpwHIdIJEI3yNRqNeRyOZaWluB0OiGTyaij5Gb5dSAQwPz8PJxOZ8G1HtkMa2hooANKjuPgcrkwMjJCbZ8WiwVqtZpu4ZLW0o6ODiiVSrAsi6GhISwvL8Pr9WJkZAQWiwX5fB5lZWV46qmnUF5eTt/D+I+B2WxGKpXCwsLChkIEi8VC3QcAaPuz3W7H5cuX6TEMgMbuVFRU4JFHHsHTTz99W06PEj48KAlsO4DbCbJ9kEDy34LBIBwOBxwOB1wu14YChkwmQwlbcQPfVuATOrlcDp1OB51ORzfgyOab1WqFwWCAWq2GSqV6oAnd8PAwfvSjHyGVSmF5eRkOh6Pg3wUCAerr62lY6cc//nE888wzd0XEGhsbw89//nPMz8+D4zhUVFSgra0NnZ2dqK+vx/vvv4/h4WHaNtrR0YGuri50dnbeF01U7hdfRPTQYQhvsHnHZrPQPvvMHRHYCBiGwfLyMi5duoTFxUVK1sViMVpbW9Hf3w+73Y7p6Wm8//77BfYDi8WCvr4+tLW1bSDIt3I/nE4nxsbGCqaLJMQ7kUhArVZvSVx0Oh0tRdhJMs6yLBYXFzE+Pr5tC6ndbsd7772HQ4cO0ceTWD2cTidtUs3lcpDL5QUTYpFIhO7ubvzN3/wN2traduz3KOHBRIm7XR8l/nbzIO/H8XgcwWAQPp8PXq+XbsDF4/GCLF8yUAXWzxsSiYQ6FqRSKdRqNbRaLYxGI82G0uv11H6qUCggl8vvu031OwmO43Dq1Cm89dZbOHPmDLWu5fN5aDQa2Gw26PV6CIVCVFVV4Wtf+9pd4UbRaBSHDx/GmTNnEA6HodVq6RB0165dmJmZwcmTJ+FyuRAIBGjJ1uOPP46WlpaC78WyLNxuN8LhMMxm8y1nwe407iX+FovFMDs7i5GREfh8PkilUpjNZrS0tKCtrQ1qtRrhcBjnz5/H6OgozWNUqVRoa2vD7t27qfUyHo/TRQiycbq8vAy/3w+NRoNdu3ahvLz8pu8j2bxbWFjYMKy0WCy0FEEoFILjOCwvL+P06dMIhUIwGo2Qy+UIhUJIJpPQaDSorq5GZ2cn3RYbGxvD5OQkQqEQhoeHodfrIRaLodfr8fDDD6OlpQV6vZ6+V4XDYQC/2foEQK9BCYRCIUQiEaLRKIRCIXUjECvqxYsXsba2BmD9tZjJZGjhxyc/+cn70llTws6iJLDdJoqr2LPZLBaXFiESiui2lUQiQYwRoVwjxfc/242KMktJ3d4EJP9tbW0NDocDTqcTbrebTk6JbSGXyxVswN3oEBYIBAVZImSaSoIwyfYb2YDTarV0A+7DHuzr8Xjw3e9+F263Gy6XiwZ98kGqtMViMZ588kl88pOfvCskJ5PJ4MiRIzhy5AhCoRDkcjkNyB0YGADLsnjnnXcwMzNDm5B6e3vR1dW1aSD9vYTo4cOInzq1ra9V798P7TPP3OF7tI5wOIzLly9vKMEoKytDX18fOjs74XQ68f7771P7Acnx6O7uRnd3N82Pux2k02nMzs5ibGwMa2tryOfzyOfzCIfDSKVSEIvFWwbxikQidHR0YGBgYMe3GrdrISVNrGVlZTh06BCmp6fpv2WzWayurmJubg6RSAQMw9DBAB9lZWX4+te/jo997GP3xEVFCfcnStxtaxTzt3Q6jWg0CoVCAYVCQQdwq5EUKg1K/PRP9n5gts77Bfz8t2g0Cp/PB7/fD6/XC6/Xi0AgQAU4MkwlIhyw/v5NeDPhbESAI62nOp2O2k91Oh3kcjkUCgWkUukD9145NzeHf//3f8fx48epeJHP5yEWi1FZWQmz2QyhUAir1YoXXnjhruWazc3N4Y033sDMzAxyuRwqKyvp5lRFRQVOnTqFoaEhuN1uRCIRVFdX46GHHsJjjz22YUBGXDEymQyVlZUfOC+/F/lbLpeD0+nEyMgIFhcXwTAMNBoNGhoa0NbWBrvdjlQqhaGhIQwPD9P2c5lMhsbGRgwODqK8vBwCgQDpdJpucAmFQpqNF4vFYLVasWvXrg35edsBx3G0FKGYP8nlclqKoFAowLIsRkZGMDw8DGC9FMvv92NtbQ0KhQLNzc1obm5GU1MTRCIRFhYWMDw8jFgshkuXLkEqlUKr1UKtVmNgYAA1NTWorKyEWCxGNpulbcwka1gul0Mmk9H3LHJdmc1mEYlEIBKJoFKp6GavWq3GhQsXsLi4iHw+D6FQiGw2C4lEgvr6enzyk5/Evn37SuUHDzBKAttt4txCAN/82RXIJSKoZGIkEgksLS8BvEdVKBQCEhkEIilaw+dh5sIFEzqTyUSDGG02G6xW6z0tDnxQyOVyCIfDcLlcWF5ehtPpxOrqKs1/i0QitG6Zf7vRIV5sP1Wr1ZTAka038vyQlWWyASeXy+97QpfJZPDv//7vOH/+PDweD2ZmZjY8ZqThU6lU4uGHH8bv//7v3zWR2O124+WXX8aVK1eQzWZhMplo/lVfXx8mJibw3nvvYWVlBalUCnV1dejt7b3laVsJ66RiZmYGw8PDcLlcdGtLqVSiq6sLfX19yGazOH369Ab7AbGE7FQrZigUwvj4OKamphAOh8GyLKLRKEKhEHK53HUb7mw2GwYHB9Hd3b3j08TtWkirqqrAcRwuX75Mt/I4jkM0GsX09DRWVlaQz+chEokKLuqB9cf7Yx/7GL75zW/eUxaZEu4flLjb1ijmb4FAAD6/DwIIIBAAIrEYMqkUkMghEEnxF/ut2NdogVarhUajKfG0WwA//y0UClEBjohwZJBKsnyLBTi+W0EqlUKhUECj0UCr1UKr1VJuTWyoJP+NNG1+GIfb4XAYL730El5//XVqdyP8l0SoiEQi6PV6vPDCC3eNF+VyOZw8eRJHjhyh21VtbW3o7u7G7t27kU6n8c4772BqagorKysAgF27duHpp5/ekP2VzWbhdDqRTqdRXl5+SwLPgwASu0Ea0snAoLy8HK2trWhubgYAXLlyBRcvXqTbXBKJBLW1tVSMIsIayZDN5/OIxWLwer3IZrOorq5Gd3f3LQtI8XiclugVlyLY7XY0NDTAarUiGo3i0qVLcDqdUCqVyGQymJubQywWQ2NjI3bt2oXOzk6o1Wqsrq7i7NmzSCQSGB4eBsuyVJDv6emB1WpFVVVVwSYnyaDkW0hlMhm1x5NroUAgQNtOtVotQqEQRCIRXC4XLl++jFgsBplMhnw+j1wuh/Lycjz33HN4/vnnS3EfDyhKAttt4p0JD158/Sr0SimkYiFCoRA8nrVrm1W/+TpOKAIkSsjH3oAmtgSNRgONRgO1Wr1pgCTJGCMCnMlkomKPzWaDQqG4i7/l/YFUKoVAIEDtjkR8I7khsViMTkz5FtQb5cDx7QykTUun08FoNMJoNMJsNtPnx2QyUfGN3O4HAY7jOJw8eRI/+9nP4PV6MTk5ueFxEYlEaG5uhsViQW9vL77whS/ctQsMlmVx/vx5vPrqq3C73RCJRKirq6Mim91ux8mTJ3Hu3LkC2ygJdb1d++KDCjJxHB4exvT0NBWIiH24v78fJpMJ586dw8jICLUfqNVqNDY2or+/n05FbxekdYpYSIl1dHV1FbFYDCqVCiaTadMLKJlMhp6eHgwODhY0We0Etmsh5TiONvgRsYzYcycmJhCLxcBxHN2eIY+ZUChET08P/vZv/xYNDQ07et9L+PCjxN22RjF/i0QiCEfCyOdyYFgWHLtO4jihCKxYgfrAOTQq0zCbzTCbzdDpdFCpVHToxud1Wq22xNNuAST/jeQueb1e+P1+BAIB+P1+hMNh2n5KNpuJCFfsViACHOFsKpUKWq2W5r8ZDAaoVCpqQS22699PyOfz+MUvfoHvfve7CIVCANbPL9lsFjabDbW1tRCLxVCpVPja176G2trau3bfvF4vfvGLX+DSpUtIpVIwmUzYtWsXent70draiqmpKbz77ruYm5vD2toarFYrDh48iEceeQRWq5V+H47j4PV64fP5oNPpYLfb79vn624gnU5jYWEBIyMjcDqdEIvFtBShvb0dOp0OExMTdLgOrPP8iooKDAwMoLGxEWKxGCzLIhQK0fbhQCBAWzZJzt6tXgswDIOVlRXMzc1RsY+AbOCVl5fD7XbD6XQiHo8jm83C4XBgeXkZcrkcu3fvxiOPPAK73Y5QKIRTp04hFovR0g2z2QyDwYDe3l6a+0gsqQQsy26wkHIch1QqRdvf8/k8lpeXEQ6HUVZWBrVajVgshnQ6jYsXL8LtdtM25VQqBb1ej7179+Jzn/scqqqqbunxKeH+RUlgu00UT0CB9TeMXC6LZCqFdCqNTCaDDAMwAhGUo68A3ln6/4VCIZRKJSVnKpVqWycMhUJR0NJksVhoRgGxOZbwG3Ach3g8jrW1Nbr95na7aW5IMBhEIpEoyAvZTg4cv85eJpNRMsdvzyIiKcl/I8IbEeLuJYKwuLiI733ve1Qw2Ox3r6iooOLWn/7pn26wtt1JxGIxvP766zhx4gRSqRRUKhU6OjrQ0dGBwcFBxONxHD16FBMTEwiHwzAajejp6UFfXx8aGhruqcf6fkMymcTVq1dx5cqVgqwKg8GAnp4etLS00JV+stGlUChQU1NTMBXdCWQyGczOzmJ8fBxutxu5XI7aB1iWpU3Dm6Guro62l+30NkMqlcLk5CTGx8fh9Xo3/Rqv14vp6WnIZDKYTCZIpVJEo1FcvnwZHo8H6XS6wPJEYLPZ8Fd/9Vd4+umn7wvRvoR7AyXutjU242/AOofLZDJIp9NIp9NI5VjkIUS95yS0GS/daiDcjYT3F2/JEkGDCHCE55Ftq9Lg5+ZB8t9Imz0R38jF8Wb203w+D5ZlKV/j34gYqtFo6PPJ34AjAw8iwN3rW4sXLlzA3/7t31LBhBzLWq2WtsPLZDJ85StfQWtr6127XxzHYXR0FK+99hrdVquvr0dPTw8GBgag1+tx+vRpnD59GnNzc8hms+js7MSzzz6Ljo6OArE6kUjQEqGqqqqSDe8GYBgGHo8HV69exczMDNLpNJRKJWpqatDe3o6amhosLCzg3LlzNIuZZOz29/ejtbUVcrmcbt77/X4Eg0G43W7EYjFoNBq0tbWhsbHxtjhVMBjE/Pw8LbcgEIvFqKqqgkqloqUqoVAICwsLmJ+fRzgcRlVVFR555BE89NBDyGQyOHXqFHU/BAIBOhB56KGHaPxPdXX1pscOsZB6vV6k02mwLItIJIJ4PA6RSIRUKoW5uTnkcjnYbDZIpVKkUinMzMxgcnISwLrlNJPJQCgUor29HZ///OfR09PzodygLWFzlAS220RxhsdWcIeTsCpF+NPWDBZmZ7C8vIxgMIhUKkWncel0GtlsFmKxuEBwu5UXpEQiKRDgCAEkNtStNj0eVDAMg0gkQltPif2UEDgyNeWLbvzbVhCJRAVhvqRCnTwvKpUKer2eWlBJZgBfhLvbhC4ej+P73/8+Lly4gLGxsYL1bQKtVou2tja0tLTgz//8z++6fW12dhY//vGPaQlCWVkZurq6sHv3brS3t2NsbAy//vWv4XA4kMlkUFNTQwPwd8q6+KCCYRgsLi5ieHiY5k8A6+85ra2t6OnpgcvlKrAfSKVSOhVtaGjY0WM6HA5jYmKChtwS8h0MBqHT6WC1Wje1kGq1WvT19aGvr29LMe52QCykExMTNOeDIJ/PY2pqCvPz89BoNDCZTNBqtZidncXExATi8TjEYjG9DsojeQAAydtJREFUuCPv1XK5HJ/61KfwjW98464K2yXcvyhxt62xPf7GwRVKwijj8CmTG8uL8wiFQsjn83RbSiQS0cwjfotjJpOhVsfNwB+wEs7HF+E0Gk2Jp90E+PlvRHAj9tNQKIRQKIRYLEazfPnbbyzLFsSFEN5GxFDSrqhSqaj4ptPpCsQ3UsTwQQ9AXC4XvvWtb9FMXZZlkU6naTu8Wq2GWCzGH//xH6O3t/eu3rdUKoXDhw/j+PHjiMViNHaiv78f3d3diEajOHr0KM6fPw+32w2dTocDBw7gwIEDBUM6hmHgcrkQi8Xotc0H/bjfD4hGo5iZmcHo6Cj8fj9kMhnMZjNaW1vR2tqKYDCIs2fPFmzjm0wm9PT0oKuriw4Fkskk/H4/XC4XlpaWkE6nYbFY0N3djerq6tu6j5lMhpYikKw4ArLIUFVVBblcjrGxMVy+fBmrq6tQKBSor6/HU089herqapw+fRo+nw/T09NwuVw04uexxx6DXq9HNpulrrCtjh2+hTSfzyMUCiESiSCfz9NsQIFAAIPBAJFIRIXMSCQClUoFlmWRzWZht9vx2c9+Fs8880yp/OABQUlg2wHcSgtVIpGga7FLS0sIBALUfsUwDM2gyGazBQ1KpAhgu+2aW0EoFBYIPWQLzmKxUCvqvT6pu5vI5XLw+XwF7aerq6t0ehqNRulkZbMNuK1eZoSg8/Pf+NNTlUoFjUZDLSn8/Dd+DtxOg2VZvP3223j55Zdx9epV2lDFh1QqRWtrK9ra2vD1r3/9rmdi5PN5/PrXv8abb76JaDQKiUSC5uZmdHV1Ye/evdBqtThx4gROnz4Nr9cLsViM9vZ27N27Fx0dHSWBYgcQCoVoKQK/6txut6O3txcMw+DixYt0m0ssFsNms6Gvr49ORXcKpHFsbGwMc3NziMfj8Hg8dFpOLNzFF6xkwki27HaapF/PQhoOh3HlyhUaJkxI2vj4OPx+P7LZLJRKJc0PIve3t7cXf//3f186F5dwQ5S42/Vxs/yNZVkaozAxMYGlpSUkEglwHEdjJMRiMUQiEcrKylBbW4uKigpYLBZkMhnEYjFEo1G6aRWPxyn32wwKhWLDBhx/C+5BbkK/WZD8t1gshmAwCI/HQ62n5MK52H7Kz38r3oCTy+UF24hyuRxKpZLmwGk0mg0C3N3K7U0mk/iLv/gLXLhwgf7umUwGAoEAjY2NsFqtEAgE+NznPoeHH374jt+fYjgcDrz88suYmppCPp9HeXk5ent7sWfPHtTU1GB8fBxvv/02xsbGkEwm0dzcjMcffxx9fX0wGo30+wSDQayurkKpVKKysrJ03bJNZLNZrKysYHR0FIuLi2BZFhqNBo2NjWhra4NUKsWZM2cwOTlJh6ikEbanp4fmimUyGQQCAczOzmJ+fh4Mw6CmpmZHcng5jsPa2hrm5+dpYyewfj0WjUZRVlaGPXv2IJfL4ejRoxgbG0M+n4fZbEZ7ezv27duH5eVlWuC2tLQEk8kEuVyO3/qt30J1dTUCgQCUSiWqq6uvW55RbCGNxWJYXV1FJBKBz+ej8SgKhQKpVArz8/NwOp30vSKVSkGr1eLZZ5/FH/7hH5ZcZg8ASgLbDiCTZ/Di62M4OeMDy3HQyiWQiITIMSyi6RyEAgEONFvwd5/ohEy8uUUtEonA6XTSNwHSllcMmUwGu91eYIEKhUJU6AmHwwiHw/TFfjsgWRUGg4EWMfBz4EoCxW+QTCbhdruxvLyMlZUVaj8lxC0ej28Q3/gi3Ga5TcVkjthPyXNCwnvVajXNgeNnvxEBTqlU3jKhGxsbw//+3/8bFy9e3DTIXSAQoLa2Ft3d3fjGN75Ba7PvJvx+P3784x9jeHgY+Xweer0enZ2d6O/vR09PD4LBII4cOYLR0VFEIhGaxbB3717U1taWNgR2ANlsFlNTUxgeHsbq6io9nlUqFbq6uqDX63H16tUN9oPe3l60t7fv+AZkNpvF3NwcxsbG4HK5EAgE4HQ64fV6aXnJZgTHarViYGAAu3btuiNNZZtZSFmWxfz8PKampujjJhavB657PB4kEglIpVIolUrI5XJ6QW21WvHf/tt/w8GDB3f8fpbw4UGJu10ft8vfcrkcHA4HJicnMT09jdXVVSpkkM11gUAAkUiE8vJy1NbWoqGhATU1NXSTIZ/PIx6PIxKJIBaLIR6PIx6PIxaLIZFIIJlMbjmkk8lkm+bAERGuxNO2D5L/FolEKH/z+XwIBoOIRCKIRqNIpVLIZDIbtt+A9fMaf2hKGuuJ/ZRs3xABjmzFkRsR43aKk7Asi//xP/4H3nzzTXAcR0U2juNQWVlJB0qf/OQn8dRTT+3Iz7wZMAyDM2fO4M0334Tf76dD0sHBQfT19UEqleLkyZN46623sLq6CplMht27d+Oxxx5DW1sbPUdnMhmsrKwgl8uhoqKiJF7cBIhwRAqbSCmC3W5HW1sbbDYbhoeHMTIyQgftSqUSra2t6Ovroxl5+Xwefr8f4+PjmJ1dj0Fqbm5Gf3//jjwfiUQCCwsLWFxcRDabpblwiUQCTU1N6O/vh8vlwpEjR+B0OmE0GmGz2dDW1gZg/Tp5ZWUFs7Oz0Ov1kEqlePTRR7F7927aVF9eXr6tjF6+hTQcDsPpdGJ2dhahUAhSqRRCoRDpdBpra2tYW1tDLpeDUqmkOZF9fX34+te/jpqamtt+XEq4d1ES2HYImTyDH59z4JcjbnijaTAsB5FQAKtWjud32fH7e6u3FNeKQVpgiOrucDhoQ2YxyKpsfX09GhsbYbPZIBKJEIvF4PF4aCAoCaUMhULU7ni7IDlw/LB/IsLZbDbodLrb/hkfBpDcArL9tpn9dLPyhWIRrhgikahg+40QN1JlTzZflEolLBZLQSBzsQB3vWwyv9+Pf/zHf8ShQ4cQj8cL/m0fB/QCkMtkMJnN6OraulDgTteVX7p0CT/5yU+wuroKoVCI6upq7Nq1C/v27UNVVRXGx8dx+PBheoKuqqrCvn37NkxES7h1cBwHt9uNS5cuYWZmhhIygUCAhoYG2O12rKysYGFhgWbiGAwGdHV1obu7+45sQUYiEWohXVtbg8vlwsrKCjiOo+9XxRehUqkUu3btwuDgYEHI8k6i2EKaSCQ25NuFw2Gsra0hlUpBIBAU5AE9BAH6RUI0NjSivaMDQuHWIvqdfu2VcO+ixN1ujJ3kb5FIBIuLi5iensbCwgJt4BMKhQUtlhKJBGVlZairq0NjYyMqKyu3tA6xLEsFN/6NbMElEoktXQ2lHLidQzabRTweRygUgtfrpdtvwWAQoVAIqVQKqVSK2k+JCMeyLEQiEQQCQUELKl8MlcvlkMlkdCtOr9fToUqxAHezW4s/+tGP8N3vfhf5fJ7aaPeyLPZIpdBqtRAKhaisrEJNdTWwyWnkTp8/IpEIXnvtNZw7dw6ZTAZ6vR67du3CQw89hKamJgSDQfz85z/HhQsXkEgkUFZWhsHBQRw4cACVlZUQCAR02ykQCMBoNKKsrKw0PL1JkM2r0dFRuFwuiMViGAwGNDU1oa6uDvPz8xgeHqbXATKZDI2NjRgYGIDdbqfPg9/vx+XLlzE9PQ2xWIzOzk4MDg7uyNCSYRi6jBIMBhGPxxEMBiGTyVBTU4OKigrMzs7i4sWLYFkWFosFNTU1VDzPZDKYnJyk4vbg4CAOHjxISxw0Gg2qqqq2vQlJLKSrq6tYXFykx6hSqYRQKITD4YDX60Uul4NEIoFYLEZ3PI69Mhk62jtgtV2fX5a42/2LksC2w0jnGIyshBFL56GRi7GrSg+55PaC1RmGgd/vh9vtxsLCApxOJ6LR6KYimVqtRmVlJSVtVqt105NxKpUqEODIBhwR4Eij3e1AIpFQqyNpQiVZY2VlZTAajaUTINafX5/PR9tPXS4X1tbWqDAai8XolHQrC+pm5JpkiJCPJO+N1FYTAU6hUMBoNNL8l+INOLVaDY7j8NJLL+Gf/umfEIlE6M/4Cgc8DIAkzIjFYigVig2CHZvNQvvsM7D/3d/dwUdyvTXp1VdfxTvvvINMJgO5XI62tjYMDAxgYGAAYrEYp06dwnvvvQev1wuRSIS2tjY8+uijdC2+hJ1BIpGgpQiBQIB+3mg0oqamBrFYDHNzc/TY1el0NMPtTmxCEvFvbGwMs7OzcDqdcDgc8Pv90Gg0tCim+NitqanB4OAg2tra7khJBt9COjc3h+Xl5YLsw2w2C6fTiVAoBJZl6Wv5L1Uq7BeKkAMgFomgUqsgFGx8P71br70S7k2UuNv2sdP8LZfLwePxYGlpiVqGkskk3WTgczOZTFaw4Wa327d9QUra7oo34PhbcNfLgeOf70s5cLcGIlzF43H4/X7KrYn4Robkxdtv5PwnFAohEoloDhzJfyPbb3yBjd9SWyy+kTiZYhw5cgR///d/j3g8Do7j8PlEEvs4DnkAQpEIQoEA4mtcka+x3c3zx8TEBF5++WUsLy8DWD/37tmzB3v27IHRaMTly5fx8ssvw+FwQCqVorq6Gvv378fg4CDdkorFYrQxk+R0lXBzyOfz8Hg8GB0dxezsLI2qqKurQ0NDA8LhMIaGhmhbrUQiQVVVFQYHB1FbW0t5kt/vx4ULFzAzMwOpVIrdu3ejv79/x2ztoVAI8/Pz1ELKcRy1gJKMPq/XC6VSiYqKCgDr+YRSqRQulwsSiYS6LJ577jlwHEeHr1VVVTe1JEI2AT0eD86ePUtdP0qlkhYyZLNZyOVyfDGTRX8mg7xAANk1YX2z8WiJu93fKAls9yFIHhgR3NxuN51kFkOr1aKyshL19fWor6+H2WzeljJPVmD5NyLABYNBRKPR286BE4lEBRtXxTbUrcTBBw3ZbBYulwuLi4u0/dTj8cDn8yEUCiGZTFKb6WYiHBHi+C91YmUgmTHF+W/FpE6n00GtVmNubg4/+tGPEA6HIRQK8TWRCHsggPeaBVWA9SD2crsdSl7rU87lgvqJg3ftRLG4uIgf/ehHmJ2dBcdxsFgs2LVrFx5++GE0NzfD7/fj0KFDuHz5MmKxGHQ6Hfr6+rB//346ES1hZ8AwDBYWFnDp0iUsLS3R9w2JRIKamhqwLEvLKIB1W2ljYyP6+vpQUVFxR56LXC6Hubk5jI+PY3p6muZ0MAwDg8Gw6QauWq1GX1/fjlkeNgOxkF66dAnvvvsubUpjWRY+nw8ej4dus/2tzYYDMjl8AgEgEEAiFqOiomKD3fZuv/ZKuLdQ4m73BsgWmtvtxtLSEj2Pk+11stFLoFQqUV5ejpqaGjQ0NNx2LEcmk0E0GqUC3HZz4MjmLH/gxregajSaUu7VNsCyLJLJJKLRKOVvfGdJLBaj4lux/VQgEGwQ4Mim22Y5bwqFYsPnFQoFxsbG8O1vfxterxd/nM6gN5+HD9fiSGQySMRiKFVKmE1meize7fNHNpvF4cOHceTIESQSCcjlcrS3t+PAgQPo6OhALpfDL37xC5w4cQKJRAIqlQrNzc144okn0NbWBolEgnw+D6fTSbfdTCbTXbnvHzYQ983U1BTNhZXL5TCbzWhqagIAjI6O0nw0kUgEu92O/v5+NDY2UqHX5/PhzJkzWFhYoFtjPT09OzawzGazmJ+fx8WLF+Hz+WgBSSAQgM/no0KXRqOBTqfD7Ows0uk0IpEIfR01NTXhYx/7GLRaLVZWVhCNRmEwGFBRUXHT95Ncsx06dAhXrlyhQ5X5+XlEIhF8XSTGAMfBLxJBAECj0cBWZoNYVHi9W+Ju9zdKAtuHAOl0Gl6vF6urq1hYWIDX66VNScWiik6nQ0VFBerq6lBXVweTyXRLWzukJplsWpENOEIWotHopsH4NwOBQACNRrNBgCNFDKUcuHXEYjE4HA4sLS3B6XRibW2N2heCwSDN4+OLb5sJcfxjhdgYiAhH8t8MBgMEAgGOHz+OcDiMv1Sp8bBEAg/H/WYCIxBAIhGjzFZGp9/M6updP1GwLItjx47h1VdfRTQahVgsRn19PQYHB7Fv3z4YDAZMTEzg7bffxuLiIs3veOSRRwomoiXsHEKhEC5duoSxsbGCgYDFYoFEIoHP56MXegqFAtXV1bR84E5sjwHrr5+JiQlcvXoVExMTWFlZQTgchlwuh8lk2vA+IxQK0dLSgsHBQdTV1d0xMdbn8+FXv/oVXnnlFUSjUXpfSXPa3xiNeEypwhrLQiQUQigSQiAQXhtSWCDAB3OBVMK9hRJ3u/fAcRwVW/iFSX6/H7lcjt4IBAIBVCoVbDYb3SIxm81QKLZurr9Z5PN5xGIxRCIRuvnGF+BSqdSmURUAaLsm2YAr3oIr8bQbI5/PI5FIIBAI0PgQkuEbDocRj8eRy+UoXyMiHMdxEAqFBQKcRCKh+W/ECkfsp16vF6+99ho+ubqGAQCea7xPSOyrMhmUCgUsViuEAsEHdv5YW1vDT37yE4yNjVGr38DAAPbv34+ysjLMzMzgjTfeoHlfGo0GfX19OHjwIOx2OwDQnCy1Wo2KiorSwP42kMlk4HQ6ceXKFSwtLQFYf8wbGhqgUqkwNzcHh8MBjuMgEAhgtVppmRV5n1pbW8OZM2ewvLwMlUqFPXv2oL29fccEeo7jMDk5ieHhYcRiMRiNRmSzWSwtLcHr9UKr1cJoNEKn09HoJbfbDZVKBbvdjsrKSvyH//AfUFZWRmOayCbkreYELy0t4e2338bExATC4TCWl5fxO2se7BWJ4BMIIBQIIBAI1hcTyssL3itL3O3+Rklg+xAikUhQwW1xcRF+vx/xeJxuOhGIRCLo9XrY7XY0NDSguroaRqNxR3zyHMchFotRsYeQhWAwiGAwuGWJw81CqVQW5MAREc5ms8Fms0Gj0TzQ20gcxyEQCGBhYQEOhwNutxurq6t0esrfROQ47roWVHLskMDm1dVVfEsux6MKJVYZZv1yXoD1C3sBIBSsC7oqtQqySBSJjg7E/+A/FdhRyJ9JXsGdQCQSwY9+9CNcuHABLMvSlfADBw6gq6sLLMvi5MmTeOedd+D3+yEWi9Ha2oqDBw+ira2tRMruALLZLCVCq6urVNwlk3b+RoVUKt10KrrTIBkuV69exdDQEGZmZrC2tgaGYQospPzjwWw2Y2BgAD09PXfsIjKVSuGHP/whDh06hHA4jFwuB7fbjT9lWRyQK+DO57G+xCakF1pqtRpVVZUQi8QlkvaAo8Td7n2QdtFwOAy3200thaFQCLlcDul0umDDjAwfSYZbbW0tTCYTlErlHbuPDMMgkUhs2IIjGXBkS2MziMXiggw4ciM5cLdTxPSgIJvN0rwnfv4bP1eZZKyxLFtgPxWJRFR8EwqFSCQSGLxwEb35PDwsC3IRKBAAQqEI0muDVaPJCHEwBOHevVD/P39ZYEW9UwMvPliWxfnz5/Hzn/8cgUAAIpEIDQ0NeOyxx9Db24t8Po/Tp0/jxIkTCIfDYFkWVqsVjz32GPbt2weVSoVUKoWVlRWwLIvKysodL1R60MCyLAKBAC1FiMVi1IJpMBiwtrZGW0kBwGAwoKenB11dXdBoNOA4Dk6nE+fOnYPL5YJOp8PAwAAaGxt3jEMlk0lMT09jZWUF2WyW2kUdDgfS6TQtM/D5fIhEInA4HMjlcqiqqkJDQwN+53d+B9XV1chms3A4HEgkErBYLCgvL7+l9ymWZTE+Po73338fCwsL6HzvBJojEayxDAQC4bXX3bowbrVaodfpAZQEtvsdJYHtAQBpRFpbW8PS0hJttUwkEgWCm0QigV6vR0VFBerr61FRUQGj0XjHLhyJEEgIQ3EOXPH9uxXIZDJqQyUCnMVioTlwBoPhgc4XyefzcLlcWFpawsrKClwuFzweDxVDE4kEFT8IceNPUD8TDGGfWAx3Pg+A/1YioBttMpkMZSIh5o0mnOvvg1qtptlv5CYSiQpsKMU5cCqV6raFrsuXL+Oll17C2toaBAIB7HY7BgcH6UQ0EAjgV7/6FYaGhpBIJOhE9IknnkB5eflt/ewSNgfJRRsaGsL09HTB1oZUKqVh0cD6RZrVasXu3bvR1ta2o9sbxSDr/MPDwzh37hyWlpaQTCYhEoloCympqQfW3zu7u7sxODiIsrKyO3KfVlZW8NJLL2FsbAx+vx/PO1bQz7Jw53LgsG7PFlybhpKJaGVlJaThcImkPcAocbf7C2STLBqNIhwOU74WjUYRCoVo4yh/QCkSiaDVamGz2VBbW4uamhoYjca7WmDAsiy1XfFLGMjtejlwIpEISqWSnu/J5hXZgFOr1Q80T7sRSP5bOBwuyFUm+W/RaBTpdJoKbgzD4MDVq6j3B7DKMOA4toC+kXxAiVSCMoEQoeYmuD/xiYKMN4lEsmn+G/nzTtqGE4kEXn31VZw6dQq5XA4qlQq9vb04ePAgKisrsbi4iDNnzhRsxjc1NeGZZ55Ba2srBAIB3G43wuEwHcCXBN3bB2n2HBkZwerqKsRiMYxGI6xWK6LRKJaXl+lrXqvVoqOjAz09PTAajWAYBouLi7h06RI8Hg9MJhN27959W9tifBCbcDQapY2jc3NzWFhYQCAQgE6ng06nQyqVotb9eDwOvV6P6upq/NEf/RGam5sBrG9Crq2t0SKFW70mjkajGBoagvh7/wTN5CRW83lks9mCY5H/GH4Qzp8Sdg4lge0BA8dxCIVC8Hg88Hg8cDgcdP28WNCSy+V0w62+vp5mGdytdf9sNguv11uwAUcmdiSzYquJ6XYhFoup9ZG/AUcy4CwWywO9wZRMJrG8vEztp263m05PQ6EQ/qPXi65UCs5stlBfIxAAErEENqEQFwH8/wXrYgSxnpIwX41GQ1e3lUrlhlBfIhhsJryRv29nsymbzeLVV1/FkSNHkM1mIZVK0dLSgkcffRR9fX2QyWTUerCwsACGYVBeXo7HH38cg4ODpca1O4hEIoGRkRFcuXKFhufyRV0AdApvMpnQ09ODjo6OO27ljcfjGB8fx4kTJzAyMgKv1wtgXTgmFlK+2EfCftvb23f8vYNhGLz33ns4fPgwBi9cRK3XC9c12xB/C2F9Y0EMgUCAKpkMlueeg/3/VyJpDyJK3O3+BcdxVFwjsR8kfoNsvPEz1AidF4vF0Ov1sFqtqK2tRVVV1V0X3Db7Xcj9JrfiHLhsNrtpuRY/B47fhMq/lXLgtgbLskgkEjSE3efzQfNvL0E7OYlVlkEulwPL8HJ6BQKIr229lQmFWLRYcLK7i7oXCG9WqVQF2W/EiiqRSCASiTaIbsVFDDcrcs3NzeHHP/4xtShWVFRg//792LdvH1KpFObm5nD58mUsLCwgkUhAoVCgv78fTz31FGw2G90QlclkqKqqKpVa7RDy+TzW1tYwMjKCubk55HI5KJVK2Gw2WqhHNnCVSiWam5vR19cHm82GfD6PmZkZjI6OIhKJwGKxoL29HWVlZdDr9bcthHq9Xng8nnU3jUqFyclJnDx5khZhGI1GJJNJhMNh+P1+pFIpmmH4u7/7u3jyySchFouRSqVoVjDZgLuV+8ZxHGa+9nVk338fWZ2Ovg+yHAtAQDmuUqlEmVAI3ZO/VRLY7lOUBLYHHAzD0MyHtbU1uFwuSnri8XiB4KZSqaDT6VBeXo66ujrYbDaYTKY7uklyPeTzeSoW8osYAoEAwuHwjuTAEbsV34ZqtVrpFMxqtT7Q+SKL//k/I/neCYQlEvgDAbAMAw4cFdukUinEYjFMDIMRuQz/ptHcsAmVn/1Gbnq9Hnq9ntpJCFlTKBTU0kwagTZrQlWpVAXHqdPpxA9+8ANMT08DWF9j3717Nw4ePIi6ujpqPThy5Aj8fj9EIhGam5vx9NNPo7W19a7YIx5UMAyD+fl5WopA3oNSqRRyuRyEQiEUCgUEAgEMBgM6OjrQ3d19x4OMOY6Dx+PBuXPncPz4cczPz9N8w80spCqVirZm8bfddgJ+vx8jf/IlGGZnEZXJkEgkkMlmwV17rEgej0AohBXAksWCZ9765U21YpXw4UCJu314kEqlEI1G6UZSNptFOp2mBQbBYJByt1QqRQUTqVQKvV4Pi8WCuro62O12GAwGqNXqe2qTJ5fLIR6Pb8iBI02o6XR6U1eDQCCg7cr8HDh+GcNORJ98mOB+8UXEjx2HsKyMludkMhmwHAuxWAyxSLzOm+JxLFksONbWSv8viRMhghsR3YRC4XphglRKoz/4+W/kRkQEku+72SYcGawWI5/P4+jRo3jrrbeQTCYhlUrR1taGp556ChUVFXA6nXC5XBgZGYHD4UA2m4XFYsETTzyBffv2QSQSYWVlBZlMBna7fcfPzQ8ySCnCxMQEJiYm4Pf7aX4zx3EIBoNUaJPJZKirq8PAwAAqKyuRyWQwMTGB2dlZWk5RV1cHi8UCo9F4W5w7Ho9jZWUFQqEQNTU14DgOJ06cwMmTJxEMBmEymWicDinH0Ov1kMlkeOihh/D888+jtrYWAGhe5noMx62JtO4XX0TsnWPI6vU0bzMUDiGbyUIqk4Lj1kVxKwDRvr3Y/YMf3PLvXsIHh5LAVkIBcrkc/H4/3XBbXV0tWPPnN4dqNBro9XqUlZWhtrYWFovljueA3AxYlkU0GqUTDH4OHLGh8qe+twIyWSXtm2QDji/AfZhz4AhJk1RUIJlMwulyra88A7BardDpdMhks4DXi9WqSvy6uZlW18disQ3f70ZNqADo1hv5SHL49Ho9DfQl4ptSqaSCh0gkKth6U6lUGBkZwdGjR5HJZCCRSFBbW4sDBw7goYcegkajQTgcxltvvYXz588jlUpBrVZjYGAATz31FCwWy918qB9IBINBWoqQTCYBoODi0Wg0QiAQQKvVoqWlBb29vXfF/sEwDGZmZnD48GGcP38egUAAwLqwxbeQErtmc3MzBgcH0dDQsGP3zf3XLyJ09ChCYjEYlkU2k0EylQSTZ+jPEAgFKBeKcEEgwM9tVvyv//W/0N/fvyM/v4T7AyXu9uFELpejYhuJciBbYqlUim6Zk+0wfukV4Sxmsxn19fWwWq0wGo33nOBWjHw+T7f4+Jtw/CKGrVwN/AEcfwuO5MCRoc2DAj53Y1kWXq933UbKMpDL5CgrK4NIJELO5YLy8ccgeeEFOswOBoMF+W/Fwic5zsighwhqfGFNqVRSQY0IcGRgSs6bUql0SwEuHo/jJz/5CUZHR8FxHPR6PR566CHs37+fZgWurq7i/Pnz8Hg8AIDGxkY8//zzaGpqgt/vh8/ng06ng91uLw1NdxjpdBorKysYGRnB8vIyBAIBzVqMRCLUOiqRSFBZWYmBgQHU1dVRx4DT6UQul0N5eTnKy8vp9dWtbh3mcjmsrKwgmUzCbrfDaDRibW0N7777Li5fvkzLGVKpFF02kUqldPje39+P9vZ2VFVVIZFI0Fw/kj13M+C/9rK5LJLJdd4WjcWQSqVo8YEhn8eoTIb8l7+EL3zhC6WNy/sMJYGthOsik8lQgYqIVPx2KUJmhEIhtFotdDodysrKUFNTA5PJBJPJdM/a6jiOQyKRoBlwxPpI2jdJcOxO5MARAa44B85ms0Gv19+3J3f+iQIA0pkM3C4X5AoF7Lzcss3COsnKNcl/W11dpXbgUChUEOpMwC9iKBbiWJalQaH87TeygUjsp8X5b5lMBqdOncL8/DwV4Zqbm/Hoo4/SYNa1tTUcOXIEi4uLYBgGZWVlePLJJ7Fnz54HeoPxbiGbzWJ8fBzDw8O0Ej6dTsPv94PjOJjNZkilUqhUKtTX16Ovrw+VlZV3JbcnkUjg1KlTOHLkCCYnJws2MYnQTrYnjUYjBgYG0Nvbe9ubv3QDobwc4VBoPeSaySOZSBZkHZWJRDjPsfieWAyJRIKvfOUr+NKXvvRAXUw+yChxtw8/GIYpsJIyDEMvxlKpFA3HJ5thxYKbSqWigltNTQ0V3O634SDDMEin0wUCHH8TLplMlnLgrqGYu7EcB7/Ph2w2i7KyMjqYvFHQOsdxVNBdW1ujAhzh0JFIZFPrr0AgoHytWIBTKBSQSqWbCmt8kUEqlWJhYQHHjx9HLBaDRCJBXV0dnnrqKdTW1iIWi0EgEGBqagrnz59HPB6HXC7HwMAAnnvuOcjlcrrZVFlZec8sB3yYwLIs/H4/xsbGMDU1hXg8DoVCQe3v+XyeirBlZWXo6+tDS0sLQqEQxsbG4PP5aNGVWq2GTqe75QZl4kLw+XwwGAyw2+1IpVIYHR3F1atXEQgECqz3JJdNKpWip6cHDz/8MAwGA5qamlBRUUFz/XQ6HSorK7cdCVL82uM4DslrG6SpVBKxaAwsx8HKcbgkFuP7Mikeeugh/M3f/E1psH8foSSwlXBTSCaT8Hg8VHQLh8MFYbZEcBOJRNDpdDQHpLq6GiaTiU5J7xek02lqoeUXMRDyQKrTbwcSiQRarZbaUMkGHMmBM5lM92y+iPvFFxE9dBhCHulhr02C+LSczWahffaZbWcJcByHcDhM209dLldB+2k4HC7YpuT/v82EN36dvVQqpdtvMpmMZvCl02lMTk4inU7Tht329na0t7dDrVaDZVksLS3RbSq5XI7GxkZ85CMfQVdXF9Rq9T37PH1YwHEcXC4XLl68iNnZWeRyOWSzWayuriKbzcJkMkGn00GhUKCqqgr9/f2ora29azmKKysr+MUvfkGtBwRqtZqK6mKxGGKxGN3d3RgYGIDdbr+ln1X82mMY5prFhytokJMCOJbJ4NuZNLRaLUQiER566CF85zvfgUajue3fuYR7GyXu9mCB4zi65RWNRpHNZungiGVZJJNJur1TXEJAtjiIO8FsNhdwN61We18JbsUgZQBEdNupHDitVguNRnNf5fVuxt04rD9GQt5zfLPcrRgsyyIWixW4SPhZyvzrBgJ+1lvxBpxYLKYOBr7ldGhoCFNTU+A4DlKpFE1NTejt7YVIJKJW1EuXLmFxcREcx8FqteK5557D4OAgXRwg8S/38zF+LyORSGBubg6jo6NYXV2FRCKhwwGWZen2osViQW9vL1pbWxEMBjE+Po5oNEodU0QMN5vNtzQEiEajcDqdkEgkqK6uhlgspvcrGAyCZVnMzMzQ0pB0Oo1cLgez2YxHHnkENTU10Gq1aGpqglarhdvthkAgQHV19bY41WavPQDXuFsODMOuv2+zLC5KJfgXhQK5XA7V1dX467/+awwODt7U71vCB4OSwFbCbYFvwfR6vRuao/hrwMTGZ7FYaEOpyWS6520J10M2m0U4HN6QA8dvbyI5TbcKQo71ej0luiTviRQxfFCTt+jhw4ifOrWtr1Xv3w/tM8/syM9lWRZutxtLS0tUgOO3nxbnBxb/380sqGSiSkg3adIymUxoaWlBbW0tPVbn5ubgcDgoKWhoaEB3dzcVkLdqQi1tu+0cEokELl++jJGRESq4kokiEfaVSiXKy8vR19eHpqamu5bDQzL83n77bVy9epWKXUKhEHq9nm6uCgQCVFRUYHBwEJ2dnTd1gbbZa49lWKw4V7C2uoZYPI5gIIBsLofz2QzezWQgEAhoNo7NZsO3v/1t7N69e0d/9xLuLZS424ONdDpNz2nJZLLgPSCbzSIUCsHr9dLWcCI2kdIroVBIG/dMJhMtTCCC24dpq4vjOJoDR7bgCJclVtRMJrPpcO9+y4H7oLhbMUiWMrmOIPw5GAwiEols6iIhg1KS+0Y+RqNRGpYvEAig1+vR2dmJ8vJySCQSWK1WBAIBnD9/HqFQiIoipASBiDj19fXQaDRU8ClhZ5HL5bC6uoqRkREsLCzQtnjSNExEKoPBgF27dqGtrQ1+vx9TU1N0w9JmsyGTyVCngF6vv6n3omw2S0sLKioqoNfr4fF4MDIygsXFRajVajidTvj9fkxPTyOTydDtyurqalRWVqKsrIwWyQgEAiSTSZjNZpSXl1/3vlzvtcdxHCLhCILBIJwuJ96LRHFBvC4ysywLjUaDz3/+8/jsZz97Xwn6DyJKAlsJt410jsGVlTBi6RyQTcEqTiESXJ+QJpPJAkspEZtkMhkMBgN0Oh0sFgvKysqo4Ha/2RKuB4ZhrpsDF4lE1rdObsOGKhQKIZfLYTAYqA3VbDbDYrHQLbgPm73hRkin01heXsby8jKcTidWV1fp4x8KhZBKpbb8vyzLIpVKYXV1ldpoSNmF0WikOSEsyyIQCCCXy9EtxLa2NjQ3N9MLmOIToFgsvm4T6oOWA7MTYBgGc3NzGBoagsPhQD6fh8/ng9vthkKhQEVFBbRaLaxWK52K3k3busfjwZtvvol3330Xfr+ffp4QQyIEKpVK9Pb2YmBg4KYzPYrhcDjw05/+FJOTk7hw4QI8Hk/Beww5XhUKBf78z/8cX/ziF2/r55Vw76LE3UogyOfzBVtbqWwey3EBIJbDoJajTidEJOiH1+tFOBxGKpUq2HBj2fUAfOJOMJlMtDCBtIB/2HlGLpcryIHjW1ATicQNc+DI9hs55z/IOXDbQTqdpjEu/AF2OBxGOBzetMhsZWUFy8vLYBgGIpEIFosFtbW1kMvlUKvVMBgMmJubw8LCAvL5PORyOdrb27F37176/JnNZqjV6g1FDMWW1Q/78X4nQZwqU1NTtBSBlCFks1nodDpIpVJoNBrqJvH5fFhYWADHcairq4PZbKZWYLKAsF3hieM4rK6uIhAIwGQyoby8HKlUClevXsXU1BTy+TzNi5uZmYHP5wMAaLVaVFVVgWVZqFQqlJeXo7KyEhaLBSKRiIpwt7P4EI1GMTQ0hOPHj+Odd94piMyRyWR49NFH8V/+y3+5ba5Ywp1DSWAr4ZaRzjH49/MO/HLEDW80DYblIBIKYNXK8fwuOz4zWIl4JEy32wKBAK1pJ6SNCG78oHoiDn1YbAnXA6lPJ9tvxIZKBDgiBm02Mb0ZyOVy6HS6DQIcyYHT6XT3bQ7czYLjOEQiESwuLlIBjn+M8gNYyWSVbLiR0FO1Wk3z4CKRCMLhMJ32k60p8nXkY3H2m0Kh2EDOyP/nC3DFQlyJ0G2NQCCAoaEhjI2NIZVKIRgMwuFwgOM42O12agHp7u5GZ2fnXW3VZBgGZ8+exS9/+UuMjY0VbLYSCynJkmtsbMTg4CAaGxtv+flmGAbvvvsu3nrrLQwNDVFLLR/kmDpw4AC+/e1v31f2/RK2hxJ3K6EY6RyDH59bwpuXXfBG08gxLEQCwKQU4+k2M/7TQ3VQyqUF+bREPCKbXIlEgma9kYxTo9GIsrKyAsHtQeEVBCQHLhKJbNiAI0UMW7kaxGJxwflfo9EUiHAqleqBezyvB5LjxR9gk+23tbW1goIDqVQKi8UCg8EAlUoFi8UClmUxNTVFB19KpRKdnZ0oKyujAo/NZqNtqJuFzBcXMfAFOIVCUXq+tglSinD58mU4HA7kcjkEAgGkUimagahUKtHY2IjW1lb4/X64XC7IZDK0trZCp9MhGAyCYRh6HbndjdFwOEy/V3V1NUQiERYWFnD16lX63pdOpxEKhRAMBuk2W29vLzKZDMLhMMRiMaxWK8rLy+m1K+Gct3r9ynEc5ufncfz4cbz66qvweDz0WlAmk6G+vh5//dd/jZ6enlv6/iXcWZQEthJuCekcg//6xhhOzvjAchy0cgkkIiFyDItoOgehQIADzRb83Sc6IROvn2DIZgk/vy2bzRYQEEI8SAYIIW1ms5luuOl0ug+t4FYMEiAbDAYLcuCCwSDNIiNr1bcDstHCL2Lg58AZjcYHpsGGYRh4PB4sLi7SEoZjx45hYWEB2WwWDMNAqVQWNOYyDAO/349IJAKWZakNh4jD/PwQkUhEM+CI9UalUtHWUz5B2+w4JzkwxeIb+XNpbXy9nGVsbAyXL1+Gx+NBJBLBysoKotEobDYbKioqYLVa0d7ejl27dsFsNt/V++fxePDWW2/hvffeg9frpdtlxEJqtVrpa7G/vx+7d+9G/uTJW7L0+P1+/PSnP8WxY8dw6dIlxOPxgq8ViUTQarWorq7GP/zDP5TI2ocMJe5WAh9bcbdsnkEklQU4Dr3lcvz5XjOMei3NF0un05SDkFwisuFGrKUMw0ChUNANN4PBAIvFAqPRCIPBcF8XOu0UWJalg2b+NiH/cczlcpu6GooHcMVtqKUc2EIwDIMTJ07gRz/6EdbW1pDJZKBQKGA2m2m2m1wuh8/nw+LiIrVQk1gQmUwGiUQCk8kEuVwOsVgMqVRKM+HEYjEV1cjXFkMikWy5/aZQKErPVxEIl7569Sqmp6cRiUQQCAQQjUbpwFqpVKK2thYNDQ0Ih8Pw+/3QarXo6uqCQqGA3+9HJpOBRqOB2WzelmMhk8lQYa+qqgoajQY+n4/aWElBTCwWQyaToflxHR0dUCgUWF1dRSgUAgDodDpotVrU1NSgqakJltlZZM6d29bvv5kdO5VK4dy5c/jhD3+I0dFRMAyDXC4HmUwGi8WCP/zDP8Tv/d7vlYbv9xhKAlsJt4QfnF7E/zk+g2wyDoUYVDgQi0UQicRIMUCWAb78aAO+sL9+0++RyWTg8/kK8tv4gls8HkfmWm4QKQEgNyICPSi2hOshk8kgEonQBieyRk824MgJ4XYgEoloGyd53Ekbqs1mo60+H2bhc2xsDP/yL/8Cp9OJZDIJkUiEiooKmM1mhMNhBAIBuN1uOBwOJJNJAOvkqjirhl/EwLIstaCS8HuRSERJmV6vh1arLai0J9tvW03nSA7MVhtw91IOzJ0Gx3FwOp24cOECZmdnEY1G4XA44PP56ISxqqqKBiKXl5ff1WM4m83iwoULePvtt2nDFoFUKoXJZILVaoVOp8NvTU1DPzEOkez6OX6bhVJzHIcLFy7gpZdewvHjx+HxeDYEecvlcphMJnz961/HF77whQ/1a/lBQom7lcDHD04v4nsn5qGQiKBVbLy4j6ZySGbz+NxAGT7apKLnMqVSSa2MUqkUkUiEim0+nw+5XI4WCJAbwzA0P5a/QU8+6vX60kCoCBzHUR5MxDf+JhyJWtnMhkpy4Mg5n+S/8QsZHsRcsVwuh1deeQXHjh2jW5f19fXo6emBRqOBVCqFx+PBsWPHMDw8TDN8jUYjrFYrANAMXb6oJpFIaAsqEd8IhxMKhRAKhZDJZAVlDMUgIt9WApxUKn3gni+CeDyO2dlZjI6OwuVy0cUCIixpNBrY7XbU1NTQ14nVakV3dzckEgn8fj8SiQQVVW/khiLZzqFQiF7bpNNpjI+PY3JyEm63G5FIpMAur1ar8fDDD0Mul8PhcGB1dZUWXAkEAiiVSjw5Mwv9+DhEN+DeNyoUcTqd+Nd//Ve89dZbSKfTtJDNaDTiwIED+Ku/+itotdpbf8BL2FGUBLYSbhrpHIPf/adzWPRGwSWC1D4HASCAgL6BpYVyKLk0PqKYhcWopxlhxTcizCSTSUrYPB4PUqkUDZzlTw7Ilge5aTQaGAwGKvzcbNjlhx25XA7RaJSKmfwm1GAwiGg0uiM5cMTmu1kOnMVi+VDkwDEMg5///Oc4dOgQstkshEIhGhoa8PGPfxy7du2CQCBAKBTCm2++iTfffBNerxepVIpm5OXz+W2VXvCLGEijFiFy5M/k8Sb2U774tln+GwEhBVttwG21OXe/Ix6PY3h4GCMjI/B6vVhZWcHa2hpkMhnKy8tRX1+PhoYG7N69GzU1NXf1WOU4DisrKzh27BhOnjwJr9dbcJyoVCp8NhRGYyiEnF5PNx43e55yLhfUTxzclKTFYjG8/PLL+OEPf4jp6ekNF2kH5XI8olKhwl6BwT2D152u38ng6xJ2DiXuVgIB4W7OUBLlOjk8Hi/drhaJRRCL1sWBQIpBlUGJn35pH6QiAd20isVitNiHnxsGrMcpkLInv99PW42JpZQIbhqNhm64kY9EdDMYDCXBbRvIZrNIpVK0KZZfwhCPx5FOp7cU4CQSyZYbcBqN5kOdK+ZwOPDDH/4QCwsLANajGXbt2kVLkDQaDdbW1vDKK69gfHwc6XQaYrEYzc3NkMvlyOVy4DiOPu7Fj7FQKIREIqG5bTKZjA5OidhGBDiBQEAFUfK1m53PBQLBpuIb+btMJvvQPl8EpCl+ZGQEc3NzWFtbg8fjoY2jRAi12WzI5XK0cbOrqwtCoZBGv5DSMqPReN3HLBgMwu12Q6lUorq6GkKhEIuLixgdHcXMzAzNzybCqkqlwjPPPIOOjg44HA5MTU1heXkZwWAQHMdh8OIQGvx+cFYrdNcphrked+M/Fr/+9a/xne98Bz6fD5lMBrlcDnq9Hh+3WvGJmloYjDfOZSvxtzuPksBWwk3j3EIA3/zZFbC5DJLR9aZMPtbPEQJALAXEUqhGfw5ZdIWKA8U3Ij6QZioiymg0GjrxiUajtGmGL7iREyBfcCtu3CzZEq4PhmEKcuAIQSZbcJFI5LrBvdsBIRJkim0ymegGHLGharXa+4Zcr66u4p//+Z8xNTUFYH26v3fvXnzsYx+DxWIBsL7W/dprr+H06dN0nXvv3r3Ys2cPAoFAQfupz+ej7afbeZxJ/hvJfROLxQUCHDnp8+2nN8p/IxAKhRtEt+ItuPtZgGMYBjMzMxgaGsLCwgKcTidcLhcYhoHNZkNdXR1aWlrQ39+Purq6u27hiMfjGBoawtGjR7GwsIBQKASWZfF7wSC6UmlEZDJIrj3fSpUSKmWhLXg7JG18fBz/83/+Txw/fpxuqADAX2m0eEImQxa/2VoVb/LeeaNJawn3DkrcrQQCwt3kEhEUEiGcTicNz16/4AcAAVihBJxQjI7YEOrUTMHmGdmCkslk1K5IxDYyRGMYBoFAgA5MQ6EQOI6jG258wY3EJJBhKT+qwmAwlCx0t4B8Pk/z8siNL8KR4fVml3/8HLji7TeNRnPd4d39gnfeeQevvfYaPfeZzWY8+uijGBgYgM1mg0AgwOXLl/Haa6/RDLeqqio8/PDDsNvtqKyshFKpRDgcLsh/C4VCVMzZrMleJBIVvHbIxhu5EbGMbMCR3N+t8t8IyJZc8fbb9Tbn7kewLItwOIzJyUmMj49jfn4eHo8HmUwGZrOZDvWNRiOAdQdJY2Mj2tvbAYAuFgCgLpyt3l9IURrLsqiqqoJarYbf78fo6ChGRkbgdDoRi8XowFuhUODxxx/HwYMHIRQK4XA4aFlC67HjqFpdRUQqhVAkhFajhVanhVBQyMG3w90IXC4X/vt//+8YGhqigvs3ZTI8DAHECjlk0q235Ur87e6gJLCVcNN4Z8KDF1+/CkEuhURsPch1UwhFgEwF4fDLEK6O3/LPIxs8xdMg/po2OeFzHLehUVOr1cJsNsNut6O8vBx2u/2+Jwh3ExzHIZlMbpoDR8oYrhfcu10Qa8NmOXA2mw16vf6esze+9957+OlPf0pF5vLycjz//PN46KGHKCFyOBz4yU9+gvn5eXAcB6PRiN/+7d/Gvn37NhCfZDIJt9uN5eVlrKysYHV1FWtra/D7/bTVbTubhkSA4ziOim+/sXGvfyQXM/yttxvlvwG/WXsvbkDlf7xfCJ3f78fFixdx9epVLC0tYWVlBclkEgaDATU1Nejs7ERfXx+dYN9NsCyLmZkZnD59GhcvXsQjIyNojUQRvPbeJRAI1t8Pr20zmq6Ryu2StEwmg3/7t3/DP/zDP1DL6F9ptNgvk2HtWpCuSCRCeXn5hoy6myGCJXywKHG3EggId9MrpRBwDNxuNxKJxAahhROKAKkS8qtvQOydpOcCsnVD/sxxXEGmKNmqJoNSk8lEQ+X5mW1k4MSyLLLZLLLZLBX6yKCUDIc0Gk3BhtuDkgV7J0G2C4noxm9EJQUWW+XAkaE4fwOO5M2SIob7QRSNRCL40Y9+hOHhYXr8t7e349lnn0VbWxskEgkymQwOHz6MY8eOUTtea2sr9u7di9ra2usG2OdyuQ1Da8KXCZcrBt9+SngYyXwrHqISAY4Ic9d7zIkAtJUV9X54voqRSqXgcDhw5coVjI+Pw+l0IhqN0gxbq9VKI1G0Wi3a2trQ1NQEANRqms/nodPpYDabN+V3DMPQ72uz2WC1WpFKpTAxMYHh4WFMTExQIZVw6D179uC5556jm72BQABLf/EXEF0aRkwup7xcKBJBo1FDo9HSIebN8iqGYfCDH/wA3//+95FOp/GFdBp7BEKExGJotFqUldkgEm7k4iX+dndQEthKuGmQKWgmGUMmGUMiHsemR5FYCoikEF38MQT++R29D/wcq2IrHSGB5FYcMkrsjPy1eLJqTcJKN5sEkY/8QHpyu18EhTuFVCqFSCRSYBEhE71QKIREIoFMJrPpxHS7ILZIQrSJAEeyEkwm0weSAxePx/Fv//ZvOHPmDM1l6OzsxKc//WnU1NQAWD9ez5w5gzfeeAOhUAgCgQBNTU343d/9XVRXV2/r57Asi0gkAofDsWX7aTab3dZjTF4/5PVBiBs/T4RMr4kAt538NwLyOtnMgnovBjFnMhmMjo5ieHgY4+PjtBBBLpejoqIC3d3d2LNnD9ra2j6Qps1AIICFb34L4suXEZJIkM/nwV27+BGLxfQ1Adw8eZqdncULL7yACxcu4C+UqgKBDQDNwKyuqaZkrUTQ7h+UuFsJBPwNNhGXv2ZFz4Dj1s8JHDiAAxihGBBKoJ14HeLQEs0KBa59He8cQ/7OcRzy+TxYlqXCDD/egIgFMpmM8iXC4YhQQD5HPk+2aMn2lEKhoBtu5PxPNurI+YnklZL8qg+7fe5OgAifiURiQw5cLBZDMpncMgeOWCH5NlT+BhzJPLtXtuCvXLmCl156CYFAALlcDkqlEk899RSeeuop2jLu8/nw05/+FGNjY+A4DjKZDAMDA9i3bx9qampuSfQlkTjFsS3hcLigyZ5AKBTS1w+fjxULcGQBgZRrEYvp9a5RRCLRlttv5HV7rzxfxcjn83Sz7PLly5iZmUEgEIBaraYDevK+UV5ejo6ODtTU1IDjOFqgkEqloFKp6PtJMfx+P9bW1qBWq1FZWQmhUIjl5WUMDQ3hzJkz1E1F8pJ37dqFj33sYzQPzf3ii4gdO4acwYhEPI50JgP22vue2WKB6lpZ2q3yqvHxcfzlX/4lPrK4hN0MA++1wYdcoUB5eTkUReJhib/dHZQEthJuGiTHY9rpgySXAMAnZxwla1mREpJcDDXzr4PLr5+M+YIYwzCUkJGP5PM3K8SQjR1yIwSPED9C2jiOg0QioScqmUxGV7DJVo5SqbxpUka+J//7khtfvCsW7PiEkC/YfdimtNlsFtFoFH6/n4pw/Bw4kq93OzlwhCSQ5lliQyUZcBaL5Y5uV42NjeGHP/wh3G43gPUmoSeffBLPPPMMbRslttFTp04hn89DKpVi//79+PjHP04nXreKbDYLv98Ph8NB7adk+43YT2+mbZaIhcXhvYTIkdr0m8l/IyD/f6sNuLu9LUbAcRwcDgcuXryI8+fPY2lpiQqiFosFnZ2d2L9/Pzo7O6mgdbewTtKOI28wrNt80mnkczmIJWLYbGWQXXvPuBXyxDAMvvOd70Dwf/4P9onEBQIbgUQqRW1NDVQqVYmg3UcocbcSCPgZbBaVBOl0ep2HMde4WX6dPyUhhYpL4wBzGSKwlF/x+Voxl2NZlm495XI5ZLNZZDIZZDIZ+u9k6MlxHP3Ijzsg4NvkgN8MhMi2HOFZ/MwrcuOf3/nxCWTLh3A+8j0I/yP/v5inkc08ImaQn0vEiwcRpIih2IbKz4Ej2VDFILZHpVK5YQOOCHB3Owcum83itddew69//Wvkcjkkk0k0NTXhc5/7HNra2ujzfOXKFbzyyivw+XxgmHXr9NNPP42BgQHo9foduz9E/CEDVP72WygUovZqPghPKz5uiQDHt6PyBTgANyy+4DuGNhPg7oXcXo7jaCnC0NAQLl26BK/XC5FIBLPZTJtgLRYLGhsb0d3dTQssEokE/H4/YrEYZDIZzGYz9Hp9we+UTCbhcDgAANXV1VAqlQgEArh8+TLeeecdeDweJJNJ6HQ6qNVqtLS04JOf/CTMZjPcL76I+LHjkFRUAFgf6JIMv7LycpCfcju8KpPJ4PDzv40yhwNecOBYDgLh+gDdYrHCyOOrJf52d1AS2Eq4JZAWUamQg0zIgWHyyF8jZ3kmj3RegLxAhLr0HGozCzf9/fkkLp/P0xv5O//zxUSPVBgTkpfNZunJaDPBjUxZ+Scd/kSUCGSEoEml0ju+gcO3XfBFO34l+Fahp8WCHVmT/qBPgNcDydbz+/0FTahkCy4SiWxJ2LYL8rzy28xIboPNZqOtRLfz3Obzebz++uv41a9+RQs5amtr8bvVNTAuLYGcSROJBObn56m1VCaToaa6BlarFeoDdyZ8NJFIYG1tjW6/kfw3Yj/dKhj5eiCkrThLhC9a30z+G4FIJNo0+438XalU3vHjORaLYXh4GCdPnsTExAT8fj9YloVGo0FLSwsef/xx7N69m5K0O41ikpbNrW8Y5HN5mvsH3B55Gv/KV5A4/i5cW7zOBAIBampqoEokSgTtPkGJu5XAx41bRLNIZhn8Xo8ZT9dJEQ6H6VYN/+N2zxVks41wsXw+TzfWiFBVPCBNp9M0KyyVShWIeAKBgIpxQqGQ8jM+T+O7EUQiEeV6/O274s9t51KI3N/i3Cwi2JFtn61EO34wfbEFkHzvYnHxfgVpld0qBy6dTm+6cU+ODYVCsWkOHIm1uBMxL6QEYX5+HqlUCiKRCM888ww+YjAge/48AIBlOThXVuB2u+kxqdPpUFtbu55DeIf4Gx/5fB6BQIAOUPkCXCQSQTKZLHhciVNhq6w2sh3Hz8Ym10P87dKtQDKWt7omuduCaSaTgdvtxtDQEE6fPg2HwwGWZaHX66mQW15eju7ubvT29tJNs0wmQ/mwSCSiOd5EtM/n83A6nYjH4ygrK4PZbEY6ncbExAQOHTqE6elpJJNJqFQqGI1G1NTU4Hd+53cg/H//3wLuRsByHIQ8Hnu7wpf7xRcROHQYzmwWzLVhCASAWLRuGa2w2yEQCEoC211CSWAr4ZaQyTN48fUxnJzxgeU4aOUSSERC5BgW0XQOQoEAD9Ub8LWHLEhEI3Tywr8Fg8HbEkxuBrlcjk7WSM05IW3ZbJaKcfl8voDoEOLFt5wS8MU2vvhGbveSDY4/6S3esttKtONbA4ttsQqF4q7aYlmWRSKRQCgUgtfrpSIc2YAjmRa3kwPHr7gnOTL8HDir1QqdTretTTO3241/+Zd/weTkJDiOw8ecTrSHwpAolYUn1FwOGR7JlAKQ/9YTaPrOd27597gVsCyLUCgEt9uNpaUluN1urK6ubrCf3syGIdkcKBbhyKYhsQ3dTP4bgUAguG4G3E7athmGwdTUFE6cOIFz585hbW0NDMNAIpGgrq4Ojz76KB555BFUVFTcUdGvWGAj4EB1WwC3R9LIltwayyAYDKHY+y+WiNHS3AJ4vSWCdp+gxN1K4GM73O1AswV/94lOyMSbv4eSXNZi8a34z5uJcHyxLZfL0Yt/wpuK37dJOUIqlUIymUQqlaLnSzJoJdty5BxOhBjCXfilCSqVCkKhEPl8HplMht4X/o0M88gwFkDBx2JhrtgxQT7HF+42u9QiA11ioyUWP/KxeOuumKuRjSX+1hLhnnyhrvhGolQ+SJDHn5SWFefAJZNJ5HK5DZtawG+yyorbUEkO3O24QDiOw/Hjx/Haa68hHA4jHo/jD+JxtAaCEPM261mOW9/OzOdBnlmZUAjts8+i+n/+/S397J1COp2m4pvX66XiWyAQQDQa3cCTicDGd9vwjyv+8Ulbh3mv0+0cS+T43UqAuxOCKeG1IyMjePfddzExMYFMJkOvYfR6Perq6rBnzx7s3r2bRp8QAZO0gOr1epjNZnpMEWuvTqej51SHw4FDhw7h7NmziMfjkMlksFgssNvt+LjLBeHFoQ3crRg7IbDFjx0HZ7HA6XQilU6DZRgIRSIYDAbYy8t35OeUsD2UBLYSbhmZPIMfn3PglyNueKNpMCwHkVAAq1aO53fZ8ft7q7ckaMD6iSyVSlGBhIhu/D+HQqHbDs/fDLlcjk5HyYSU3Kd8Pk8bS8mFNCEu5CRQHDgKYMOWHcuylPTwc67IjVgk+Ft2t5NRdjdRTIq3EuyKp7V8G2Hxpp1YLL5l0kcIfzQapQJc8WSP5Ibcjg2VWBuIAEfCnEkbqsFgoBZjjuNw4sQJvPzyy3hkZATN4QgSKhUMBj3UajUE12QRlmMRCoURj8WgSaexaDEj9yd/gidVKjo1vRHudOU2mew5nU4sLy/D7XZTAhcIBGgo8naPX2ITIRcD5CN5XRHCtdmxs92iCzIB38qKeiuEzuv14tSpUzhy5AgcDgcdEJSXl9Nw29ra2t+0Hx8+jPipU9v63jd6DrcS2IpxuwIb+RmRSAQrKyv0IlkgEKC2rhY6ra5E0O4jlLhbCcW4Xe62HZBz8lbiWyQSQTAYpANPss1EhjFbDSlZlkU6nUYymUQymaTlCOTf+MNSfmujRCKhbgWVSgWr1YqysjJYrVZaqsAvWFAoFOsDsEwG6XSa2l351lcixBULdXz3BPk7Xyy6GbHuZrbv+IIdGWYR3soX8PiD4GLxjn/jC3b8jcPiTbvNbrcr4JHhNz8Hjp8Fd70cOPJ7b5UDp1arb+jqiEajeOmll3DhwgX81vQ02mJxsNdiR/jbWMlUCqFgELl8DppUGivlZaj/5CdhWXEWTr62wJ3mbsVgWRaxWKwgv5eISZHIemld8WNKuFixMEa4Gf/6hn+ckMd3M5G0GEQ03SoL7nZjc5LJJBYXF3HkyBGcP38esViMHtd6vR7t7e149NFH0dPTA7FYvM7dTp5CJptFJp0GwzKQSqSQyeUQi0XI5/JIJBKAALA+9TRMv/08gsEgjh07hkOHDsHv90MqlcJoNOLjLhdqvT4obpC3vFMCm6SiAhzHYW1tDZFIBFKpFLV1tbS1tMTf7g5KVYol3DJkYhG+8EgdPrunGiMrYcTSeWjkYuyq0kMuuTE5I22ESqVyS+JPRLhwOFzQwFP850wmc1P3nZwQtFptQZ5EKpVCOp2mxI5YF4ggx7JsgeBGLqTJ58iJ+0YX73yLAyF4pIGp2JIqEomQyWSQTCaRyWTofSSkj0/++CSQiHx3AnxhcKdAyPBmgl3xyZ1/8uWLdQqFAlVVVWhpadkQcJxOpxGNRhEIBGiuBb8NleSGXI8MkN85HA5jaWlpw78TcYjYUE0mEw4cOADF7BwQjiCXzcLn8yMRT8BoMkIqkUIoEMJkNEKjViN1bZX92LFjUDocaPJ4Ib7Bxhx77Tm4kyRNJpOhoqICFRUV2LNnT8G/kewLj8cDp9MJh8NB208DgQBCoRDS6XTBtip3bfq71euWHAt8Cw2x+5CLATJ13Sr/jbye/X7/lr/TZltw5M+bCXlWqxWf/OQn8dGPfhQXL17EL37xC4yPj2N1dRVvvPEGjh07hu7ubnziE59AR0cH4qdOIXroMIQ3IIfbfQ7ZbBY5l2tb3+t2odPpoFSpsLS0hEQiDqPRCJ1WtyPfu4QSSvjgcLvcbTvgbxnb7fZNv4bwu0gkQjeo19bW4PF4EA6HKechAz0iIJFzPrDOpchmWzKZpFEQ5N9yuRzi8XiB4Ea4wNzcHG3FLG5UFIlEVHQrFt8IZ1Or1dsWkghnItZI/sdioW6zzbpMJkO39YiAxx9obSXWkZ9LMm63I9bxb3ybIF+k42+l84U7fsYx3ypLeC0R9ooFu63EOpLVVvx5wiHIc8kX4YhbheT+bnZs8gU48pGfA6dSqfBnf/Zn2L9/Pxa/9Z+BWJzaL80WM1RKFQBAqVBAYbcjEomAy3iRSWcw9m8voTsWg0SlKnAsFONucLdiCIVC6HQ66HQ6NDc3b/j3fD6PYDAIr9dbIMAR+2koFCq4rtisSI7PycgxwG+yLxbgyPUKef62ut/XE+BuJJgqlUp0dHSgpaUFHo8Hx44dw3vvvYfV1VWkUim43W5cvHgRvb29ePbZZ2E+eQrRw+vcTQgAHIc8yyKHdd1UIBRCKBCATaexxrIQ7n8ERqMRv/3bvw2j0Yg333wTc3NzYBhmPW8tlURyaQliiWRL3XWnuBuw/ryUl5dTdwgR10q4eygJbDeBdI7BlZUw4uk81HIxenaQjNzPkEtE2FNvuiPfmy/CXY+kpdPpTW2o/Bt/0ln8M4iYo9fr6YmbL7iRTAi+4EZyDoi4RjK3gN9MY8j33UpwY1mWCjxbgZwQicVBr9ejqqqK/tlgMECn023IOCAr+GTSS0go+Z3I9Jf8nfzOxWIdIXtbrevvFBiGofdzJ0DW2osLLchJnxC+iooKNDY2UrGUkF5SV09uZGJ6PVszmVTHYjE4nU76+WfTKYDjkM6kwbLrk/1QKASlSgWtVgO1WgOxWAylQoHmpmac02rWnw+GQVYkgtFkgnQLy/GNBJc7DYFAQElpY2Pjhn9nGAbBYBBra2tYWVmB0+nE6uoqPB4PgsEgtSzwp6Y3OhZIoyx/ws63ohJxlbxuN8t/I8f2Vq89sVh83Q24hx9+GA8//DDm5+fx6quv0jap999/HxcuXEBzczP+UzQGjUSyra2zG0G9f/8Nv+ZWvvZ6kIjFaGxsQCgYuuulDiWUsFMocbfNcSe523bA53fl5eVob2+n/5ZKpRCNRhGJRBCJRJBIJAoyduPxON2GC4fDdNiXz+ep2LaV4BaLxWhLKeFuicR6YRdfcCMOi61ARLithDgiwhExgfCQnUI+n98waOVv2BWLdXwud70Nu2JstllHfiZfrCvOOt7MKgv8pl2Wb4/l57jyt9qL7bLkxg/xJ+d94iggjzefq/OPCfJ3IiRxHEd5ARF/JBIJ3YKvqKiA2OsBBEAmm4HbvQqNWg2zxQLxNTeKXq9HJh6HxWJBIBBAhuPgz+dhMBqhu5bxVYwPmrttBrFYTCNRNkM6nYbP56P5vT6fj7qOotEofD5fwXPNz7QmQit/+41wcX6uIV+AI9d3REjfihOSttTrFTEQTlhRUYHPfe5z+PjHP44zZ87g7bffxvz8PAKBAA4dOoT3338fX2ZY1IpEkBXnpvHywQVCAQQ+P2QyGZxOJxKJBOx2Ow4ePAiLxYKf/exnuHLlCq6KxWC1uvVjqbICZWXlEAo3l9l2irsRkCbcEu4+SgLbNpDOMfj383d2nb6EW4dAIKBv2FuJcAC2JcKlUqmCN2qDwUAtCUScymazGwQ3ssoO/GabLZvNFghu/I2sm7GnkRyBUCi05dcQEc5gMGx5q6qquq2gUZZlKamKx+NUpCsW7YhIR7bsNtuwI9PY4vawnQQ/MHmnwS/eIGSVWCQ5jiuwAovFYqTT65taMpkM+dz6/8mxLKKRCFKp5DVxSQB5Oo1EMoEDH/0IxBOT4LgAEsn1sGeNVgujwXDfhR+LRCJqo+3q6trw75lMBl6vF6urq3A4HHSLgWSHbGbtzeVyiEQim/48ciHDb37jE3Xy+uNPyDfLf8vn8/TibjMIhUIqvO3fvx8PP/wwzp49i/fffx/hcBhjY2MYjUSwK50BFwhAp9dBLLr1U672mWfuyqR7sy05DYB8KlXwNSWUcK+jxN3uXxBOZ7PZkMvl6JZSIpEAx3FQKpU0c0sqlSKdTm9qR/V6vbRROxKJFAhuZLurWHDj87niBkWySQf8Zni0nQHpVgIcX4S7WYjFYmqB3AkwDEP5WfGm3WaiHdm643/k86HtZiyzLEujVW5GrOM305LHg992Tj4W22T5m+/kZxT/PnzRUejxoIIDZHLZNT6y/lhEohGUlZWt8wmBEAKWRU1tDYxGI7hrm14+nw+xaBQWiwVSmWw7rtF7GnK5HFVVVaiqqtrwbyzLUjdDcVEZyYArPib4USCEt/HLQPhDcf7zS553hmFozA+5FtkKm7UE9/f3Y3BwEIuL/197bx4mSVmm/d4RkZF7Zi1Za+/dNJv2AraAAu0yKIsDIgKfIirtModvZlyOgh6/QWQQ0etCcAHUQUbRo8wRdRSEYVoE2WUVmq0ZGpteqqtrr8rKfYuI80fW89YbUZFLVWV1Lf38uJLKzIrOjFwq4877fZ772YO77roLzz//fHmBfGQU7aUiRnt7xQRSYDKvjuKESqaJYqGIQCCA0dFRZLNZrFq1Cps2bUIsFsPtt9+Oxx9/HM8VCuWFhFAIp7/9bTj99NPFbTaSQ9nlwFSHM9hqkCsa+OqdswuEZRYPZMJVa0lNJpNTDDdCrnCTp2XJGSDA5CrpTAy3maIoSk0Trqmp6ZANL6DVZFo9llcW5edXbol1Gnay+CMxNJdVdtWQ8/ucmS2fzGTxVtPEACzAmnifmCZgWfDoHnh1LxRFQathYGc4hDuXL8cF/QM4OjGOAcuCaZSnAamqimAgiEg0gkAgCE1VYQ0OIvDOd6DrmmtES/FSwbIs0ebhbD8dGhrC+Pi4aD+tN/+NTDYy4OQBJfQ7eYqWvOpaT/WBYRh45ZVX8Ne//hXn9/fjraaFYVWFR9MQDIXQ3NyEgD8ApFIwJ9ohzEwGnmXdCG7Z4nqbhyqnpZGZcczCgLUba7elhGEYtlB8wzBsgw2qTZjO5XLo6enB3r17xTRtMu3kSnU6npAZ5/P5bBVOzoyy2eaNqaqKaDTqar7R+UgksugW1+QoCFmrOXWbrJXkqjo57kRexJzuPtQzhIIu02tJz7WiKDj5+eexsq8fyYC/bMblC8jlc/Dq5TZTRVEARUE4nUb/qlVQFAVdPfsR173I5rIwzfK0yGAohGbS2IoCZWgIOPEEqJ/9bF3tsm7XO6+b7+EV1SgWi8L0ppxkaj+lQWWyfpeHszmNMdmAC4VColPIOciD4n3qMXs1TcP4+Dj+8pe/4Ng//xlvTqUxRG3poRDa29rgDwRgJsZhJsvazchkoHZ0wDz66PJ3v1IRHo+OaCQC3atDP+FE3Dk0iPvvvx+jo6MIBoOIxWI47bTTcM4554gJpo2A9dvCgg22GtQeaV5Etmjg0ncegU+dunYe9pA51OTzeVvV28DAgBBrfX19iMfjIluKDDeqdpINNznUX86coi/88wGZcNR62traOuV8c3PzgjRxLMsSq6fV2mKpws5ZaefWFkuG3WyHT5yzbz+OGh9H3KvDsgDLNGFaFoyJqbVQyvvfZph4Ttfx04Afn8zm8JZiEcOqCsM0J6ZVlfcj4A+I90hToYBdTU34r7VrbCu2co6dW4adLFho1dA5fIJWfBcqcl6Is/10ZGRE5OrVu5JOeSLy8A25DZXEnmzAyds7jfK2X92B9j170C8JeXUi/2WlzwcrlQIUBTBNKLoOdSJXSMYsFBA960wOpGVmBGs31m5LFXm4EUUNaJpmy9KqZkpZliW+8FMuK7USUo4XLa6SRnBWTcktcNTyNhcmh6qqiEQiVU24aDS66Ey46SBrPNJtdF426pxtsfQ7eYItGXb1dFCc8NTTWN7bi3QoJK4zzfKipwJFGHXhbBZvxGIwTRNHjI4g7vXBMk3kKQLDssptp+EwFAUIptIYWrMGO//u3baKO3lggNwuSz+dw9JkKKdwukZdvdfPFdQKOjIygqGhIZH/Rn9/iURC/P0R8hATp/lGbac0bZYKGZwZcLL+lwnd9jP4X3oJgxPvOdMyoSoqfH4/Vug6kE5P0W6mZcEyzfLiOQDNMND0vveh85qv45577sFdd92F/fv3w+/3IxqN4h3veAcuuOACtLW1zdnzyswf3CJahVzRwN0vHCyvfgZ09Pb2IpvLlieJUOWR7kGxZOF3z+zF+9/Ugljz0j7AMeUy466uLnR1dbn+PpVKoaenB2+88Qb27t2LsbExIdbS6bSthUE23Ch3Cph64DhUJgcJzkoh/kD5AB6NRmtWwh2KqjznflH+ViNaJuSKtGw2i3Q6bauqc2bY0Xm3qWPevv6y8FFUWAoARYGGcsaVjFooIBQKorO9Hf6+fiilElRNhaKq0FRV5JSpLganaZqiKrBamXy9UE4KDRZwhhbL+RmyYSeX+JNZR8M7SCA26jNSzgvZsGHDlN9ns1kxUr2npwd9fX04ePAghoaGRDm/nP8mD+9waw2Vg3aDwaBYOZVbT2gog6Zp6NDKK7CxcLj8fpkQiKVSCYauQ1UUKLoOq1iEGgy6ZrUtxJwWhlnIOLUbUJ4K6NE90D3lz5+I34N0oYS7XziIi09axZlsixB5kEJ3dzdyuRySySQSiQTGxsagKIqYIBmNRqfoKEVRhGY5+uijRSYuHTNGR0enDBKgyahk3KiqKq6nKjhaTG2k4Waapogs2LdvX8XnIxqNVmxFbWpqQjQaXZALpPXQaI0HTGYVVzoVCgX49+2HZ2AAwWBwsgLOsmBZJixzsipOzeVE26kWj8Pv88GCBZ/fJzobAoEAYFmwLAATeqNa/IvbYAoyeEmjyQME5MvyiQw7p5En6xUyn6oxG7Ou1raBQAArVqxwXQii9lPKf6P2U/q+Mjo6OmXwBxlsbuYbLZTSQip9PohswOZmWD4flrW2Ip1JI5lIIl+YqL60LOgA1ArazZiI0yn19+NgXx/00VGcc845aGtrwx133IHXXnsNIyMj+NOf/oTx8XFcfPHFh9Xi1+ECV7BV4ck3RvDFO3bAr2sI+TzYtWsXsrnJL64KJj7kvH6oHh+W99yPFb4cOjs70dHRgdbWVrS1tSEWi6GpqUkc5CORiC3rh1m6UIsbjcQeHBwUX+ZLpfKYZ5rSk0gkkMvlRC4ZGTfA5Bd758j5hUwtE665ufmQm3DzBY3P9ixbJoSSYZrlSjbppA4PI3nssdh79t9j2e9+j+bXX0c2HBYrY6ZpwjCNcpspyrfTlMvjtaYm3L26+gjwhYAsAOXqOjLu5CBcp2knTwqlExl2M3kfWZaF8fFxDA0Nifw3aj8l4UYCu96cQHmoQigUwtn79mHt0DAy1EYCiL/pNsuCJ58HPB6ohgE1GoW3gsFWbaQ6tQUUDxxAcWCg6v7pnV3QVyzn9oDDCNZuHpiWib+9/rey8aEq0FQNmqYCHh8Ujw8fXJbEiWsnhwhFIhFbNS8vmi4+SqWSaCNNJpOwJqqHKLetnvyjUqmE4eFhod3i8bjt92Ss0eR5qp4LBoNC31Gur6IoME0ThUIBiURCtMQdamgoUbVMuMVsws0FpN/qGVYU+rt3wzRNpB98CGpXJyzTgjlhxMmmnGlZ0EZGkDz2GOw7++yKwyfmMqdYnhRLFV2yRpONuGqGHf2ehghQZR1drhe5+m66Bh5NIR0bGxO5iDR8IZlMIpvN2gZpUZZapfy3Nff8FyL/8z8wYjFoqgpFVZDPF5BJpxHJ5aAXCrA0DappQqug3fIHDiCzYQMOnPt+LF++HKtXr0ZfXx9+9atf4cUXX0Qmk8HJAE7QPDgyGoWeSlZ9fli/LS4Oj2+3MySVK8EwLehaeSR0ySjZfk9fcFEswFR1vHGgD7sPvFj+4PEFoXethzcYhV+z0GTEEQ74RMWD3+8XrXexWEyc6OBP5e1er3eeHj3TCKjlkkZiW5ZlWyEdHh5GU1MT1q1bJ/5NPp9HMpkUrQimaUJVVeTzeTHGOpPJiKw3wzAWpOFGbRuVVlsBIBKJuLahykbcUjHhzEIBpYMHxWVl4iR/dTMBrF6zGm+/6CIcfPllpHp6EO3srHibFiwUe3txyimn4LTLL7NV2WUyGTFFVq6uc06KlaePkZk0V6KOWm7lEOnZIK+8ymadc2qV3MLjrLDz+/1Yt24dNmzYYKtGK5VKol2ht7cXPT096O/vF6unyWRySi5MqVQSX+gAYGx8HKsmhB8UQFFUqGq5krFYLEKbaCH3mCay6TTig4PQvTpCwVDdn/2pRx9F4r+3A6USrEKh3LbghmWhdLAP2ZdeAgAWaMySRdZuAGAaEwsTACzTQsksoVQCrIIBywv87p7/xh/jbwhzPNzUAqtlFfRQFBG/jiOaNbREw7acrJaWFoRCIVHJOheB1czM8Xg8aG1tRWtrq6h+SSQSGBkZwcDAAHRdF3o7JC2AOG9D7lYoFAoiP2pwcBDJZNI2ZMA0TSSTSYyOjgqjobm5WVQ1AZOTR1tbW8W/I8PNbUhDo6aqE7Tom0gkKm5DJpybASf/PJxMuHrD48WEVEWB7qneeVJMJhFbswYnX3RR5ducqHivNCxMboV15tjJJznHjrSd3OosZ9NRtd1McWbCycacm1knXycbe3RZzpZTVXXKVFrnfbe1tdnaLikLkCYP098WabWhoSGhfek5eV/vAawvFpFKlAejqIoq2sGViefGME3ANJFLp5EcHUEwEITX54Wmlv8uVEVBV2cnmjdswL59+/DKK6+gra0N559/PsLhMJ555hms2bsPK7NZlCb2s+L3ONZvi46l8c11jgj7PdBUBUXDhKYo8Oo6YE22YgHlL7hQNcA0gUIWUD2wjnwnSmtOgBFoQk7VkDANDObGoex/DtrOv8CvT5atymOnKTtJvi4YDKK5uRmtra1TjDgy4SKRyGF1oFvMKIoiXsNjjz0WhmHYVkhHR0enjHKnKVdUxuz3+23ihg6+tGpEBxNavaEVnIUIHeBqmXD0N0AVBvL5lpaWBZ0TBkxv9PZ0tlUm/tO9OjqrGHH1QuKKBFaltlh5CIVz6IRT6JGpNhfDJ+TbbUQ1AK2MUlusM8vO7/fjqKOOEl+mKf8llUphfHwcyWQS4+PjSKfTKBaLUOSZYRZgWSYMEzBgwKJjiGnCQjkAeGR0BKqqobmpMJkZYk5OV6skvlSvF/B6y6PjK/wtUCsDwyx1ZO3m9ahQFCDWGhNfMMUEa1UDLBOqUW73KxgWBlvejELLJpi+CBRDBdIWHo6n0ZLYjWXZ3Qj5J4eiyG3vXq9XGDZUvU3VQOFw2FZ1yxxaaIgABYrLuW0jIyMi34x0dSU97fV6bS1sFEFAp2w2i5CU00Vm1ujoKLxeL1RVRTgcFtNHad9IxxxxxBFTjCu54q2SCdeIxSoZ2YTr6empuF09lXBLYYF0uvqt3qD5eqBKq0Z9blD0CWkX0mry+UpmHbVFy9V29FnqbKWWF2id+WbTxVm1Jpty9F2Zqu7czDq6rqWlBZ2dnbaKPSpioLxFmnoaGhgERsdgGCZKxbL+okfYXCpBkx5vqVTCyOAQNI+G7q7ucsYe/VIB1q1bh+bmZhw4cAC5XA5+vx/veMc7AAD+/n6UcjmUYCFoAdaEMT/ldWP9tuhY/J98c8hxK5vREfXjwFgGoaYA1q9fD9O0YJoGisXJFYXxogItF0ckaGJ06ydQaj8KgAIUM7CKZdMNoRisY96DYrQbuaduB+JxsdJBK126rovMBvmDwzkZkIK15VMkEhEVQNSa2traajPhQqEQtzosMDRNQ2dnpzBHisWiyBgYHBzE+Pg4vF4vYrGY+DeFQgEHDx4URoDf758S5huJRHDkkUcKYzYYDIpcErl0Wj5fbUVzPiETrprQC4fDVdtR59uEi5555oxWnA71yG15alo4HJ51+KocSlwsFl0NO3lqLFXYuZl2JPBkw262wyecyDl2M13B9Xg8iEajKBQK8MTH4QHQahgAJnJXAMCy4EW5glGb+Fk22UpQFQOJRAKpVAqKoqC5UMCLf3kC1332s8Kcb29vR2dnJ9rb2xHMZmFN3AbDMA7t5vNA08pZjTKmaeJgPINm3cD/etP70Nvbiz8nO5DW2svTnUs5WEYBFhQUvVEMth2HseEQwq/eDa9n8kseVb3Rl2D5ix8wqdfoOgq4pmo4MuOamppENVwwGORF0zmEzM6uri5hYCUSCRw4cED8nl6jasZGIBDA6tWrsXr1agBlrUJmG1XFBKUvxZRxS1U9qqoiFAqJlrbdu3dDVVUx7Z0q+p0VOU6KxaLIZ6tkxKUmJlY3knq1WbXBDPOR1ztdpqvfUo8+esi1W73I3x9lM3im0CA3MuHk4RNuhp1s0jnbYSu1xZIuk1s8Z4OoMpQWVOUhEhTr5O3tRZu0H2SyaZYFFYBCg6smDEDTNJHOpJEv5KEqKrzFIgYHBjH84ovi74DazNeuXYtIJIK+vzwBZTxRHn4GS+jahf43wdSGX8Eq+HUN52xehlse3o1EtohoQC+396gqPJ6ysEpki1CKBi59/1tgnP9O3PLwbuiKBc0sIJ/3TFRwlGAYeRThgblyEwL5USivPyRKdSlIPZPJIB6P20w3t1Ja50lRFAwNDU3ZfxpxLAdwywfuWCyGtrY2IfLIiON8uPlD13UsW7YMy5YtA1DOayLBNjAwgHQ6PWU1q1AooK+vTxixFAA7NDQk3heaponX/aijjkJTU9MUs7VUKgmzjUJDnecpy2ShQa2z1YReKBSqacItpOqCuap6O5TIocQA0NLSMqvbkwdPFAoFm0nnNOvk4RNk0jkr7GiyGBl2jWqLVRQFPp8P+2IxV2PXsiysTKXQXCwiP1F1mkXZJDMtE7l8XhhmYQADgwO47777bOHGJOr+0TTxVsMEFAXBiWw/RVEACkGW2isY5nDATbs5SeUN6B4PPvqOo/GpU/8eP3lsD4oP7UanBwh6FFHJWywWUTJKyBkqSt1vQnMIaBt9SXyOZLNZJBKJKVUW8tATWYPpuo4Bl6xEuTKDuhcosF424ciI43y4xuD1eoWBZRiGGJIwMDCAvr4++Hw+YbYFAoGqn6WkoY844giR8SkbboZhTDHcxsbGxBd8CnqnzoM33nhDxIzIhpvzy7eu6zVNuFKpJEy3SkbcXJhwpM3IvHQjHA5XHcyw2PJ6l4J2qxe5BTTYgAoruU3V2QrrjDNxmnZu5l2pVBIaT76PWmZdPhhEd4XMvbahIYQyGZQm3pOmx4OgrsM0TeRzeeSQg2UB0Xwef9v1Gu6/6SbxHKmqKoajdHd34216+TqDBm2ZJoyJFliPx8PabRHDQw5qkC8ZuOL3L+ORXUPliVR+HbqmomiYSOSKUBUF7ziqHV87+034+E+fxoGxDLqbAigWC0il0uVJJR6tnPtRKmEoXUQIebzbehH5TFKEsKZSKfHB4CzNpQ8GOTiSRJz84eY03uo9IMmltbTC6pYPR0KP8+Hmj3Q6Lcy2wcFB1yobOtDIhpvzwEfZIGS0Njc31yXUSaSRAHQ7JRKJBWnC1QOZcNVaUuX2XWbxQ4YdVa2RUedm2rlNinVW2NEKbK222HP27cdR4+MYd/kcFVPKTBOtpRKe0z24xeOZkvkGAJ9RFJw0YccFAdhk44TBBkWBDkCLRACg6tAEZmnB2q26drv2vA2wLODDP35S6DcLltTeXTatjZKBgWQeMR/wybVJJOOjNgM/mUwil8uJFtR8Pi/+TulLo/x5QIacbL7V84VKbnmi7gU5L4vMGGpL5Xy4mWFZFtLptKhuKxaL8Hg8opU0HA5Py9yUJ5RSHIhzQYemWMvZUz6fz3Y/8hR30iWNMp9I31WrhJuvuJFQKFSzEm6hR4Uw84Oc/VtpodXNrHOrsCsUCtj82ONY1tODdI0KwFA6jQPLl+HRTZumfLdPpVJIJpP4yOgYNmazMC0LAdO0ZbHRd32PxwOrWGT9tshgg60O8iUDv3xyP+5+4SAGEzkYpgVNVdAR9eOczcvw0betwvP747apVUNDQ+jv7xe3QX8oil6eWnWCsROrAgVxQKBqDMrQMk3T9qWOXHgSb/QHD9j73alMVS4zdTPfZrLq6bydSvlwNDGV8+HmHlohHRgYwNDQkHhPyND7hoQaja6WoRyQWCwmVkln+roZhiEEWSUTbnx8fNGacMFgsGYlHJtwhy80fERumchkMjbjLpvNouU//gPhna8iF42WJ8TSpNiJn3QciORy+FtLC/5r7RqbUCSR+OGREWzO5WFhwmCT/67YYDvsYe1WXbv5PNqUqaP79+9HNpsttw95NOgeDzweHZamw1Q8+PxJzXjLyqioes3lciKCQa6elc14c+JvWq5Uo2NssVgUhn6pVIJpmrZFVNKP00HWfl6vV5hwdCKDgvPh6oNe40QiIbR6OBwW1W3TNbloiA4ZbmNjYxW3k3W704SlgQSy4TaXRpNhGHVVws2HvguFQuL9LVd8yuf5Pc7MloP/cgWSf/4zPN3d9gmxpNsmrlOHh5HZsAF953/Q1jFBRls6ncZR9/0JqwcGAEWBt1hEEQBlidBnv9frZYNtEbJ4am7nEZ9Hw6dOXYuLT1qFF3riSOZKiPg92LyyGX69LHqcU6sKRXtvvTDBDBPwqnj8r8/hqcFXRYURTbzzeDxiW7lHXJ7KQgdb0zTFYAQ6uJdKJZt4I7ONpgmSgJNbjOSe/GorqGTyUZj4yMiIa0tepXw4uSKura2N8+EaAImGI488UrQbUHXb8PBweULhhNAmSPh7vV5bGzLlgADlD3bZPG1paalb4GuaJl7nShiGIcJ/3VpS6edCNOHo76i3Sr5GIBCoacJxVcHSRM5GrMbBv/wFiVf/B363AOCJFk+oKkyvF6eccgred8W/iJVVyrLLZDLQb/kxvC/sgGUBajYL3eOxTdlSVLVci7MA/5YYZi6pR7sBU/UbmVwij3FiO0vRYHkDuOnffouW7AHRpklGSzgcFnqMFjlp8coZqk2mWqFQgKZpiEaj6O7uRigUQiAQgMfjEX/vZKZTpIh8UlXVNlGcoN9TlfvIyIjrcyRXwwUCAVGhRSZcS0uLmLR5OOfD0WCyjo4OFItF0Up68OBB9Pb2IhgMiueunmO7x+Ox5e8WCgUxnZQmlNJ2VNkGlKtb6LsC6XjnxHaazj4X0Reapoms50qQCVfNiJuLuJF0Oo10Oo2D0rR2J8FgsGY7KptwTFUUwCoUYPT1iavcvr2aloXlK5bjhPe/v+JN9cbHkbj//nLVmmFA9XhQMgxgIu+Nim6YxQcbbNPAr2s4aZ27aeCcWkUrTVOCuCemVqGYtVUjUPYBHTDJ9JDz2EjwydVqFDBJ9yOPXXbL81i+fLkwtTwej8gQSafTSKVSyGazol9dbkmdDtRuRQwPD2PPnj22bZyDGuR8uJaWFpEnQeKO8+FqoyiKED40oZRG0lNLghyeSTkE+XzelgFCpi5Vm+3evVvkgFCF22zbEigTrloml2maNdtRx8fHG5ab1UioiqGa0PP7/bZcFbfzbMItXaab0xKZWL10cvAPf0BqZ/kLgZHPQ6nwd2m5VLcyzOFANe0GTNVvoXAImqahZJRglAxRfQZVBUwTxWxSGAaEnJXr9/vF8ANd14XhbRiG6ECgfEr5uAtM6jpgsi2OBpuEw2FbdZJ8/JanilOHQzabRTKZFJUTpM3kYyZdR1lybvlwzkENZCY5hzQ0NzcfFvlwuq4LrWWapgj7Hx4exsDAgJgoS4vH9WhWr9eL5cuXY/lE7lM2m7UZbplMBkD5taDqaGByMBHlO8n7Q4ZbOBwWuq21tXXODaR6TTiajlqpEm4uTDhaIO2TzBEngUCg6nTUpqYm7lI4jGlkxp6ilgcumJYFExOftR4PSkYJMMvHjZJhwMMm26KDW0QbRK5o2DI8ZEqlIrLZ8grkWN6Ct5jC6t2/RzI+KgwtedyxW/iic+Ioudr0U85po98DELcnZwKRuNI0DYFAQOR0kIkVCATK2XFSiwLdnnNijGEYUBSlonibLc5BDZQPR6Kz0qAGXoGaCk0oJcEmfzmQMU3TNr2WVuFlKAdErnCbj/wL0zRFa06lUzweX5AmXD2QCSefnJlwtYKXmaXNwSuuQOqBPwMAjGQSSoW/Q24xODxh7VabavoNgJgc35/IIaqVcJbvNYwM9otpv5S9RhEe9JMWKsmYkieJ0rGVugzop3yslfUb6T2PxyOmPFOlWSQSEaYWVVrJkwJlA4Yq6KgCKx6P23LG5P2f7nFTbn+lYUtyayotHtH+LrV8OMuykMlkRLYymZ+NiExJpVK2gQl5t8pnQEwBpwVQt9iQUChkM9wWqllEQycq5cHRe3c+Knzk7yKVjLiF+rwyCwenfoO0QGr7vm5ZUMNhaJrG+m2RwBVsDaLa1CqPR0ckosPy+BH1Grj0ncfhU6d+DOl0GiMjI+jr60Nvby/6+/sxMDCAVColyqtp3LE80UoWb2TGkRlGFW6ymUamG5lztB1VMI2Pj2N0dNRm0smCkIQZ9YJ7vV5xECcThtoHZDFH2STxeNwmQOlUz0HRNE0RQEm4rTw5BzXI+XCtra1oa2sTYf6Haz6cc0JpPp+3DUygKkpqOaH3gzwJkn5HE7LGx8dFdWI0GhUVbq2trYfEcKP2m+bmZqxdu9Z1GzLhqPV0dHTU9fxCNOFyuRz6+vqqrrb6fL6a7ahswi1tzEIBKJUAw4BV6X1sWTAzGZuAYxim9tRRVdWQN034fT584p1vwqdOPddmplDl0tDQEEZGRkT2WiaTQTqdFpk7lM1I2kzuEKDKNdJmcscCaThaPDVNE2NjYxgaGhJGmBz7QSaerNV8Pp8w3/x+PyKRCNra2rBixQp0dHQgFovB5/NBURTkcjmk0+kpbX5kKDqNONJypmmKxdd0Ol0xV0we1EDVXnSi4zlptcWUD6coitDB3d3dyOfzwmzr6ekRv6fF4Ok8pnA4jHA4jHXr1sGyLFFtSNPi6buAUy8Hg0GhyS3LEhED6XRaRLwEg0Gb4bZQTE8axtXc3FxxGzKLa5lwjdZ3uVwO/f39tqxtJ36/v2Y76kJ5rpn5Q9ZvME3RHqrB3npqpNOwFsHnIFOGK9gayHSmVvk87saOZVlIpVKi+mZgYEB8iLuF59LL55x04mwBkMcdZzIZmKZpy3ejn/J+OCfhyaG5ZNZVgww5WsWkFctwOIxAoLxKTC0O1KoqD3Nwtpo2AuegBhJ0JCza29sRi8WEAIpGowgGg0u21UEmk8kIs21gYMB1QikAkS/l9XpF9aKbeIlEIjbDbSELZBKstSrhqk2FXMj4fD7XNlT5FAwG2YRbhCS2b0fq0UdRPHAARZf2Lhm9swv6iuUIb92K6JlnHqI9ZOYT1m710Qj9BpS/8MvTJ+PxuDDeaOKonKsmdyzQ4iZVmhF0ezSplAZfyTpOXjylBVgy3agLgW7bDTLmfD6fqDqjYwPpora2NoRCIRFZUiwWkclkRJQDmXCJRELk/cqLwtNF1ptu+XB0Wgz5cKVSSVQNplIpmKYJv98vjEXSxDOBDFeqcBsZGXHVZDQQQV4Ep9ZTmUAgYDPcZrNvCwEy4apNR52vuBHSZrUq4VibLU1c9Zs1OfmUFjxyuRzSfj+y0Qi6zjoLb/3c5/g9scBhg63B1Du1arrIrXBUcTMyMoLh4WFkMhnxB0gGHFWMOaHqLsuyxPaZTAapVAqZTGaKGJJXUuU2Ufk6ueJNbk+thdxKQCfKLqHbkgUjmYc0pIFWcsmUaxQkRMmE8/l8woijari2tja0trYu6Xy4RCJhq3CrZHZ6vV7xmqmqinw+72pEUQ4ImW6LrXzesiwkk8mqJtzY2NiiNeG8Xm/FNlQ6zyYcwywuWLvVz1zpN6D8hSmVSgkDanx8XHzpd5pupOEKhQIsyxIDCOR8M03TkM/nxYCgeDwuBlABsFXAkU7K5XKiDZU01HQ+z8mEI61GUxtpGiktoNKgBrnyjsw4euxkxuXzeVtF3HSPn6TXSLNR3Ik8MZWOYwslH44W0um9QMPJyGwLh8OzOs4ahoHh4WEMDQ1hYGCg4sAoTdNEnpimaeI96sTv99sMt2AwOON9W6iYpolUKlU1E24+Tbh6KuFYmy0tEokEDhw4AAB45ZVX8PTTT6NUKsHr9eI973kPTjvttMOi+GOxwgbbHJErGlWnVjUKmtbjrLRJJpNCUMmmG52niVSEz+cTFUnUYlosFpHL5Wy5a2TAkQhyZrJR6yfdttxiWi8UqEsBwGS+0UqmvEJL2zsHPjgHNcgHnkZ9IDkHNVA5OJlIdHIOaljIlVyVoAmlcgZIJSFMXwTocWYyGddtQ6GQrcJtKZTKk3Cu1IZKp0YawocSXddrtqPWG+rMMMzcw9pt+hwq/QbYK5tk8y2Xy4mWS9l0k404VVWFYUTHXIrmoDbATCYj8tvISKHPZ3loFjA58IDuI5VKVcz6cuJcMKXzNO2SFmJlozAUColtaEGVoktyuZxosU0mk2JxdaZZv3I+nK7rIsPObWLqoc6HowETiUQCuVwOqqqK/aOBZLPBmb+bSCRct6NFNnqfFAoF1219Pp/NcAuFQrPav8UCLbLWqoSbj0VWr9dbczADR4UsPgqFAvbv349sNos9e/bg8ccfR6FQgMfjwSmnnIKzzz571p8PzNzABtsSpVgsii/18pd7Kgcn4Sabbm7Vb5S7Rf+GzLRq4oaMJ8uyhFiTTS4y6AzDsGWT1COYKLvDKeKoekqGjDbZFJQnrpJRJ7dR0E/nVM3ZHJScgxpoMhgZS1QVR+G/iyEfzjRN24TSkZGRipl6JFgpk43MXyfBYNBmuC32toRKkAlHlQf0NyqfHxsba3h79KHC4/HUNOFmu0LPMEx9sHZbnBSLRddqN5oeCUAYUc7qN9IyNPRAURQYhiEMK2rflKHhCXQKBAJob29HU1MTPB6P0Ijj4+MYHh4W2bpy/AhVozm1gJte83q9VTUOGXJer1foB6cRR22qcgyK3PFQb9avDOk1ej6oLdU5LZV0TaPz4QqFgjBc0+k0LMtCMBgU+9CIyv9cLifMNnlCqZNAIGAz3GgqrdvrKxtu4XB41vu4WJH1XTUjbr5MOHngiJsJx10KCw/LsnDw4EGMjo5i//79eOyxx4QR/5a3vAUXXHDBousIOhxgg+0wg1oK5Gq30dHRKauUJNycxhtdJjNLFjm0SloLWThROTytFFLVnDNrLpvNipVcEnHOAxQZWHLLKQmlSvvlFGnUqkDGH7WfUtZJJSNOFn/TaZN1Pi+V8uEo/yQWi4mpYQspH65UKokVUmpJcENRFJGZQqG7NL3MSSAQsBluS7EtoRKWZYmQ6GqnxWzCOdtQnflwkUiEhR7DzBLWbkuLfD5vM9zkNkMZCrQn061QKNg0jDzoSh7GUAky3iiTraOjQ5g+qqoimUyKAV2JRMKWGSyf5P2hAQlu1W/T/eynCjXLsmxDvaizgXSerOfoOaBF1+kgD2qQ8+HIfKPjGrXMziQfTm4rTiaTMAzDZvo1ygxJp9M2w63S+4BiPvx+P1RVFYMw3Aw3Op7zgtpUSN9VG8wwX5m/uq4L062SEccm3PwQj8fR29uLAwcO4JFHHhHG+DHHHIOPfvSjh00l6WKBDTYGQLlE3VntVukLPAk3pwEnT8uS2zWnYzgpimJrp6QVQhJAVAUlB+ySkKOTvJpLZhxgN/ZIxNUreEiIOgcwUBUgtcjSAdFpxtF9O4cs0GrsdEwyt3w4qgyS8+HIiJuPfLh8Pi/yP2gyrhuqqiIWiyE8MX6awnrd3nd+v99muB3uBxOaYlfNgBsdHV30Jly1Sjg24RimOqzdDg+ozVA23ciUcUNeSCRtQlVumUxGDMTKZrM1zSc6Nnd0dIiJpHR8JpOITs79IS1HWpL0JCFn/lJXgjy5dCbIw7vIbKPrSevJWk7WsNNdQJUHNcj5cLIJR4uO1fLh5Km1iUQChUIBmqbZ9HKjFlvHx8dtcSCVIi2amprQ1tYmOl3o/ed8bShaggw3Pm7XRjbh3Aw4+jkfcSOkzaoZcRwVMjfk83ns27cPPT09eOihh8R3q1WrVmHbtm1oaWmZ5z1kCDbYFji5ooEdPXGkciWE/R4cN4dZIG44q2jIgKu2siJXv9HqGxlR6XQaxWJRCLp6IeFB4+XD4bAQhBQQ6/P5KgoMahtwmnE0qEHOnpNbDKab8yEPXZCNODL7yIhzm7zpNOJkYeZmyFU7eDkHNfj9fluVEFXEOQc1zFU+XCaTsQ1MkMOYZTweD9rb20WFW7FYxNjYmK0thqDgXTodzm0JlZBNuEotqZWe38WApmk1TbhoNMpCjzlsYe02f8y3fqMv6c5qN5piWe3fmaYpKsnIbKMKtXpa3BRFQTAYRDgcRltbm6h6o1YmMs3y+TySyWTFNkUnmqaJlkWCbouGdlH+HF2eiZYDYNNypNnIfJMHSJAmI/ONqtvkIQ+1nivSfF6vd0o+nDyogYw4AMJEzWQyUBQF4XBYVLfRYvRscebvDg8PV5xQStqSpl5WGgxAMRJkuPExemaQvqtVCTdfJlytSjg24WaGaZro7e3Fnj17cP/99yOZTAIA2tvb8YlPfALd3d3zvIcMwAbbgiVXNHD7U3MzzWq2yNMU5Yq3eibsUPXb6OgoksmkEG7pdFq0MNRbVUZBteFwWOR5yLldtErp9/vh9/unJTgURYHf759i2pEgJJGaTqfFT7mCrRbOfLhSqSTMN9mMk/PiqJ1BroqThZyzQk4eMFFp1dU5qIGmgpFpJRtxjcyHSyaTwmwbHBysaPD4fD50dHQIQzWXy7m2NNO28n6z4VYflmWJCtZqp3rDrhcaqqpWbEOl89TmxDBLDdZuh56FrN+AyYmJ8mAFOfOrGqQTaMGS9M/Y2JgI6a+G3+8XxpusN+QAdtI7hUJBGGT1EAgEhMlEP0OhEHRdF9VfcjUd7bucTZfJZKal5ajiTZ6ESjqODDqquJMXSEmXybl0csyIG/KgBq/XK8y0cDhs03yhUAgtLS3o6OhAW1tbQ4c1GIaBkZERod3GxsYqTigl/ej3+0V3QjwedzXc5MiIpqYmNl4aBJlw1fLg4vH4vHQ60BTbatNRub24MqOjo3j99ddx7733Ih6PQ1VVRKNRfPzjH8cRRxwx37t32MMG2wIkVzTw1TtfxiO7hmBaFqJ+HbqmomiYSOSKUBUF7ziqHdeet2FeRZoT0zSRSCSmVLwlEom6yvkpWySVSolqNwrslXPeqkHTl0hcyePsC4WCuC1adaQw32rVb05oWo+zRDoQCIiMO1rlpawUOtFqMrVgVDuoueXDkXgjAUfn5dVU+neyGSdP8ZJPco6IM0RYfj6cgxoq5cORWUHtvfU8p5ZlIR6Piwq3ahNKg8EgOjs7RUtpJpPB6Oioq6in4F06cVvCzLEsC7lcrmY76mI24SpVwpHwZxOOWYywdju0LFb9BpTNE+dE00QiUVd1md/vFwuc1L2QTCZtwxByuZxrNQ2ZQvKJWiXlAUnA5OR6GpJVj7Z06kL5p1vVfrFYFC23pNtkU87NkKvVtkpDv5xGHD03cmcDmXGk3+SFU5q6Kk9flRdQyfCTqxAjkQja29vR2dmJjo4OsbBEBuRM8uHk52p4eFjot0oTSnVdR3t7O9ra2uDz+WAYhpiw7tR7VJlO2o2PvXMLLbLWMuHmo9OBTLhqRtzhrO2z2Sz+9re/4e6778bBgweh6zoCgQAuuugibNy4cb5377CGDbYFyE8e24NbHt6NgK4hGphadZXIFpEtGrj0nUfgU6eunYc9nB6GYbhONK2UyyVTKpWEKUWCJJ1Oi7ZOy7KESKlkVimKIlY2ZWEFQLQQULsoiRJ5kpbP56u7dTIcDosP/lqrMLQ6m0gkxEGMWjnIkJNXWcmQc67+VcuHI+EmT/dymnHyR4DTkHOacrUMuWr5cE4jrlI+HE0opRXSkZGRipWR0WhUhC1T0PLo6KhrC6qu65MTSp9/HuazfwXqOCaHt25F9Mwza2/ICBOOhLPb+VoVDgsVVVVFq06lU1NTE38RYBYUrN0OLUtNvwFlHeY03eqpVgNga+ckUyyZTIq8XGdmLukRMtdk081pBJE+IS0oZ7VZllXXZzFVgsnakE71fpZTDAiZik5TjlpW5Xy7SgvGzu4GZ9avs7uBkNtM5cVQWSuTrqOuDurSqJQPR9WFtfLhZHK5nBh4NTg4iHQ67bqd3+9HR0cH2tvb4fV6USgURGyFm+EmKh2ffx547nnWbocYWmSt1oo6XyYcaTO3wge6bjp/z4sNwzCwb98+3HnnnXj99ddFx9YHPvABnHzyyfO9e4ctbLAtMHJFAx/+8ZM4MJZBd1MAFsqrXgodTRRAgYKBZB7Lm/y49SMbEPTpU6ZbymXodIBdaBSLxSkTTcfGxmqulpqmKfJAKDTXMAwhHuhDVB7EQOPbZWRhRadAICAOJJTTRoKIBAqt1E6n+s1ZCi0fCOodr0yrqvIIcMpVkce60/MiCzF63uoZ1CCvospmnBv03pJbHqqZcXQ9GZher1fkwzU3NyMWi9lOcmUg5YCMjY1V3Bdqi6DA3/HxcYyOjk55T7X96g6EXngB0HWoigJFVaG6/I2YhQKiZ52JZddeW9drxNSmViXc2NhYxYy+hY6iKOKLgFtLKlXELVWhxyw8WLsdOpz6rRJ941msaAniV//X2w5pJlujKRQKrsZbrS/ZNDlUzr6lqis5w5dMN7n6ze/3IxQKIRwOCwPObQFUHlJA90lakfRINeQMOVoMlHXibKBIEFo8lhdTnYYcDX0oFos246la1q/c3SDnxcnnSS87F5PlKjmKHyGd5syHo44FGqpFr4nckkoTSsl0q2TKhkIhYbj5/X7RnTA2Nmaremz9j/8PwR07bNpNUZQpfhtrt0MPfXeqVQk3H50OZMLVqoRbzNpsaGgIv/vd7/DMM8+ILp/TTz8dp59++oL0AJY6bLDNETMNt33yjRF88Y4d8OsaQj4PLMvE2Fh8ynZ5AyhZCj60Jod14frCW52TLSuZctWMukrbTuf3tbYtFouiiks2kqhCjT4AnR+E1JZAAoUOyrTyRoJKNt7k8nygbISReJMFFd0X5Y6Q8KGfhmHA6/XC5/PZVgdJlFQjEAi4mm+zLYsnoSpPInJOGJOfr2w2K6Z5VcuHcxNusiFXDdmQc2tblU05EnW0yhqJRES7AGXLUSsFmZ7yCitNKO3s7EQ0GhWtqKOjowj85KcIvvwySq2ttn3TPBo8mgeaxwOPpqHY24vwaX/HIu0QQyul1Uy4ekOxFxpkwsmZM87zzc3Ns845ZBiAtdt0mc1gAqd+y+dzGBwcmjzmqQpURUUJGkpQcdHaPN7c7rO1/FHFl5v2kkP03TSa2zZuJ3mbufjilcvlXI23WmHrpVJJ7A+ZYaZpwuv1QlEUUQEnm2+5XE7oALnirdLkdOp6kO+LzDfLsoThVE17eTyeKRpRHr7VaAzDQKFQsLWtOjOA5Qq5bDY7xXyTL5P2Jd1HC9ByrAiAKe8x5+Atet+SKUd6LRAIiEWlSCRi07WapgljsVr2V1NTkzDcfD4fUqkURkdHYd54E3wvvmjXbihrd9JtmkdDqfcga7cFSq1KuPHx8XnpdKAMMzfzjboXFroJl06ncffdd+O+++4T+ZNvf/vbcf755y/o/V6KNP5IcJgz03BbEnWPvT6MbNFAyFd+aSrZn5oCFCwgZ9QvjqpVIi10gsEgCoWCCOaVg2npMcnVevJk0JGREVs7AokoEgORSAS6rtsqugYGBtDT0yPMpEAgYAvmpZJ5ajVVFMUmXOgAQa0AsuFG01BpxbBUKiGVSuHgwYO2akNFUVw/7Jubm8UUqWqQCAyHw+js7Ky5vWEYomqP2lZp7Los6mRjjsQcPU65NZWucw5qoEo6ADWDVd0MOTdTTv6i4PP5xCorlYZT68PKlSuxatUqRGOtKHo8UL1elIwSTKMsLkvFEkrFEt059FIJiUQSgbGxmu1/ie3bkXr00ZrPM8CtC7Xw+/3o6upCV1dXxW0o79DZhiqfKrWnzCdk9Mbjcezdu9d1G/rbr9WOOhdf5hjmcGQ2gwlq6TfDNGzVRwYUlBQvnvrri9hTGrBNm6S8UzIuyHCTfy+bZPIgo5kim21uxp58/HXbRu4ccE46l9v7crmczQQiHefMi6XbodZGmpJZfk7LZlgoFEJra+uU7LFcLoeRkRFbpRpphmAwKExMt84BMvCoSp4iNEzTFNVeVJ1Fn+FO5En38s9gMDhjM5MGeAUCAbRKxlIlaBgXtePKhhxlADuz5GgBVs6UkyvfyKSjOBV5YZVwvjfoeZdNOa/XKx4LRYTI7xeqUIzH4+jv7xfvf+pOCLS0oOjxQPf7UTJKMEoGTMtCyTBQMgzkUTbcPKUSEuMJeIeH0dzcXPNYyfrt0FGPvpMr4SoZcY024UzTrPh3TVT6Xiafn8/MwFAohAsuuAAtLS34zW9+g1KphL/85S9IJBL42Mc+VnfcETN7uIKtgcwk3NYp6rIFA4lcEZqqIhb2ojXkhVEqSkabBQtAtmAiXzLx/7yzCxs6/VPC7Z1VRdP5fb3b1vnWmXOock0ezy4bb05okimZYdlsFoVCYUrrIq0ok6CTDaRisWgLoK1WrUYrps4KsHw+D8MwbKuAJEDoJGdoyCeaoCpP9pTHszsrBOeiGhGAeBwk0JwGHGXK0dRYEtb0nNPzQdlybu+/mSC/R52Vcrqu47OqihNMC+N+H7x6+bnWPBpUxf74PKOjyGzYgOEPfwiqqoqWP6o2kg+iB6+4Aon/3g61xgGMWxcOHYVCwdaCTi0n8vmFaMLVg6IoiEQiVU24er5YMEsb1m61melggnr0myoMIBOGUT6+pfNF5Iom/mGTDyt9k3EUzsmTtDhFRhodx+SFJtncIrNNNuAqGWZ0mY618kKY2/GXzju1n3x5JtdTWxkZb3Q+m8263hcwuaDprIhTFEW0OToXK8mAkzPK6LlzGwrlXNwjDUi6TT72k2lEBpxTJwGT3RFyhwTlnpHZ59xn+bx8XaVtpnNehia2Utvq8PAw+vv7MTg4iHg8Ll4LWkQmU470s7M11dndUG0iq7PS0vn9wuv1isrESCSCc3sOYN3ICArNzfD7fPD6vNBUDSa1AJdKMC0LntFRZDdswOhHLhLHSlm7OY+LrN8WH/JQuUpG3HzEjdD7rVIrKplwc9mlYFkW/vrXv+LWW2+FYRgIh8NYt24dPv3pT9dVoMHMHlbeDeT2p/bjkV1D8Osqon5dHMi8HhUhnweJbBGP7BrCL5/cj0+dutZV1IV8HmSKBgolE0PJPHJFA8uaA1PyoUazWaxsDePv375h3jI85NW96Rh1btfX+v10DUTDMES1G+WTkfFDQwwikYjt39DQASpfzuVyMAzDZrjRRCZN04QAoEmfiUQCpmnaRC4JYzLFdF23CUZaMSXjjownCsClfy9nmMnn3cSSruu2lkk67zT/qgnlei/LPytto+u6mF5Fz5Vbzoos2uTWXXo+SLTRc063T+crtYO4ncZjbSgGgxiPSytgE/mG4UgYoWAIUBSEDAMjoyN46aWXbJNpScST4RaLxcqhyl4v9OXLp+yHTLG3t+rvmcbh9XrR2dlZtYKzWCzWzISrZyDLocayLNF6tW/fvorbUSWcWyYcnZ+OCccr/cxSg7SbczBBJe0GuJty1fSboqjweMqmzEjWwMq2MC45257BViqVbMc9OYaC2vnIeJN/Un4ZZdHKLaeybnA7Rjp1hWw+yQYe6Rr6rHDTf9V+TndbecCVvHBHFW/ybZIGIF1RKBRsv5cXTglaMAUg4j3oMVP8BJl9cqacrPvoPjKZDAYHB4UBJxuBbtWIBBleVKFG+WXyeaepZln24Q3VTDfAPVLFaSTS45X/HXU/0GRZei4sa7KNlhagne9T0ts07EE2T+n1qZT161xUpVZh2n5zKIRlHh0DfQeBiXRqegzNzc1oam6GR5kY+qB7xGtFx0qqGqfjIhluAFi/LTJ8Pl9NfZfP52tWwjXahJO12f79+123IROuWiZcU1PTjE04RVHw1re+FbFYDN/73vcQj8dhPvYYHvnTn3DMMcfUrGRj7TZ72GBrELmigbtfOAjTsuCFgV279pQPXhOZG6pWNlzyagA/+q+nMfiX/8TrVheeTjbBpwFhrwaraMBQFIR1IG4ApmUhkS1CtQw0+SYPjumiiVLJwoamAp5/9ukpq4F0oJQv0+8r/XT7d7X+jdtt1Lpd5/W1bsvttp2/r3SbdPJ6vYjFYmhtbRUDEuTS+HQ6LUqNSWSGQiEoiiIEby6XE5OvqNqNhKkcCKsoik3cARCrpnIro6IorpOkaDu/32/Lz8hkMiJDQ54eRSvVJIKd7Rny8yObfHLFHIlp52sgv2Zuz7Xzuumcd8PZtkHCiwSWs/1Ubkulk5zbIt+fs5qxXAdqlX9a4kpYStloNY3y9l7DwIEDvbjnnnvE80/7JYtmv9+Pi0ZGcWwqhcyBA+X30cTzqiiwDSnRikX09/Xjb4884lotOBfVhtPd9nBC13V0dHSgo6Oj4jbOgSxup2QyOav9UBQFxxxzDJqbm6HrU6cPNgLTNDE8PIzh4WHb9fQ+pi+a9FNeXKD3RSadRrGjA0qN94llmoin0whWaINditBnd1tb24zbwZhDi6zdogEd+UIe2Ux28niqqfBrGpI5E3fv6MXFJ62CX9cqmnKtIS+GknkAFpK5EkbTBcRCk19mErkiFABnvbkDqmUgny9XYcnHdDl31KlrqIrIWRUv/5QnUVJGLGWO0fuS/p7pOO+cyOm2gEaXZd3hrKyTtQ5lqsotrPVoR6CshShMX97GMAxhvMlDnTKZDHw+H4LB4BSjjqqu5Ep50m5erxfFYhGpVMpm8imKIhYpQ6GQeL7I+JQXAmWzSH7+SqWSMOjkLDR6/t2qEZ2fG/J+yvqNzC3Arm+ci4nVnmM3Q0tRytEq9Hhl401eICUjUVVVYbbRa01QZ0csFrNpOcr1lbsZ5MVWeeAWmcfitTGlx2zZdtxmlIQKBRw4cAB/feIJ22On197v9wsTMxgM4pjegwgaBlAqwqO5L14ziw+fz1dT3xUKhZqDGWpl/sZiMaxcubLuoXVu+zA0NIShoSHb9c6uJrdTtfeqoij4p3/6J7z00ktoHRhAcz6PUUUV303cOBy1G9B4/cYGW4PY0RPHYCKHqF+HUZyYWmm5tCh6Msh6vPj+vfdA3/JBKGEFVmYMgxMHCqX8bRxqKAZ4A4CiYiRZwNjgGBRVhaUHoAAIJfbhqacfwnPqZC6GWz6HfL4Sbgdp+Xo6X01wuV12276a4eJmrNXal1r7UM/29HunWSMf3J3/1rmqRtvJz79sernh1uIB2Ffs3J5DEkBktFUyleg+3EwbWrGUb9MpjOmy2/PmfD/Iq6vy7+XrnYZdve8n53V0nkwA5/X03MiGpNx2Y5omVEUFJNNLNtnKlQHlvwVDUZBOpdATH5tyP842iHeHwjhC1zHiyIVQlPLwhO7ubgCAapnIZrPo6+ub8jgXCjMx9WqZhbM1/aZjPFa6bqbouo729na0t7dX3IaGs7i1oZI5l0gkKj7fJ510Eo4++miRE7RQoL8b+nKuH3kkvEccASjqRNWnjHTJNKAcZnkflBOaz+exfPly/pK2CJC1GwDksjkcPHhwynaW5sXYiI4Tz/wOAtlBjB3/cZR8zdBLKQxAEQuqiqIAnjAMRYcFBQPxFNLjeVhQUIQHCoAVWgJ9jz6NG56of8p7ve8lMnbkNlNntZtcAU7I+kGuinJbqKNjH7X+1dJ18mNw6hC3z2xZv8pU0qhknMgmI50nQ8u5j/IUTnrsdN+VzC5apKQ4EDpRRwPdHhlI9JwTqqqKxRPKM5P1Hmk+OVePtJlTt8maTa5UlE1SeqxOPUbPpdtr6zxP9+X2kxY+4/G4MOXkhVvncVf+95qmiVgT+fd0kt9jZL7lcjk0Dw1Dy+fg03xl7WWZME0LsMo5bNlsFhYs+EolDA8N48knn5xy33KmHL3+n8xkcXypiJHdmYn3ngZtwgRdsXKl65R5Zmng9Xpr6juamOxmwqXTaRx77LFoa2s7pBlnTm3mrEylyz6fD1u2bEExmYRqmLAm3srlz1eX9/VhqN2Axuu3haPiFzmpXAmGaUHXVGRzlfMGYBqAokFtXwMl2ASrUHbFLViiagYAEB+AGmyCGohC0TwoagFYpTzMRC+yrz2OAy9ux+uWPdxfFiz0xqiUw1HtD7LSdbLgme0X1rnGbd+c4qDav5PbOUmQykKVxIm8ekDbkOFGK3z03Mmmm5ztJu+Xqqoiz4NWKUkQyPtBosAJ3aa8Wug0gui+ZOHmNnGMtpVXUKk1klYsnc+r25cFN4HVqN/JYtMpJCuZgZZlYe3uN+AbHUVLODwh0ixYE8+PqpUPOpZlQTVK8Hl8CPt9ti8ncrUB3a7I2qH/W5OviWKUX39Y5Wo2j+5BIBBYsNmG9D5aajTCqKvH9KMhG2vWrLFtI7euUxvN+Pg4IpEI1q9fj5aWlikCrZ7Pq0NJJJeDr1SC6XbfZMJ5dFjFAtRoFN4aLTdLjZGREQwODmJ4eLiqYGcWBrJ2AwDTNABYcH4MW2YJ8HhRhAdFvRVFTxDIp1E0ygtypmVOfFVRAMShBMr6zVA1xHMmYOSBzDjMfc/g1dcfxeuqPX9Kbsukzwy54l3WXvTT+TvnZdkooWMi/XSL2XBerqSjnPvmdqyWF7zoNumyvJ0bzttzfua66VrnddR2KZtpZHoVCgVhCMmLGaSbqCqNHoMcy0FVgfI+kiaiUzgcFp/jiqLYzBzSkHKLqvxYZc0pm6S0rXNhVtaVQPlLt9MElM9P55jhZtLJl+Xz1B5LnQXApHak93al26l2u5ZlIRQKifNNhSK0kRKCjkohy7Js2ktRDNtrRmar/BrIkTKFUnnBumg4XhNVxQL+qsMcIrxeL9ra2tDW1jbld/39/RgbG8Pq1auh67rtu5dsoMsxN4caVVXRGgjCU8iDLH9FUeDRNKiaJh27cNhqN6Cx+o0NtgYR9nugqQqKhomSUSoLAbfvyR4dgAWv7gVUD4A8oGplQaPY//DM7DjMfAZaqBmZF/+Iwv4XUeh/HZg4ANDWzj/YSqtNAKaYKM6Tm9HiJuBkg8Y54crtwO+cgCX/O/myU7DRedmocjN0ZKOl2nPhvM7579zEh5vhIbeOyqXtztBd0zRFe4Bc9k645aq5rZwGg8EpmWqqqorbdA4MIAEpPwZ5VVsWGrQdHQjoeSAh5yzpl19HuYVMbi2TV1orUe33bgbadKhm1gGA7tWhqKrI6JjYasrteLJZdHZ04LQNbxZGp1wdR697LpdDc38/PJks/Lo+adhZJiyrfHBuijYBAIqpFFpXrsRJ73+/676TkJxutmGt3zci27De23L7/UJAfq8vBDRNQ2trK9ra2hCNRtHe3j7leXYLjK70mV+J2Zh0zt9bE605lss/U8SXnMOXWCw2JxPOmLlB1m5ejwpV1SbaDCe+4E+8ow1Fg6Wo6GiNwFBj6Nd0qDCgat7y54lp61cDcgmYhSyUYBOKO/+MUu/LMIbeEPpNbOnQM04tJhsqbvlozugJOvbK2znPk9aQtZlzsU3WWvSZ5KyGc1anEaTpqu03GRt0286plfLv5FgMWbdVQjYBK+lSy7KEXqKW21wuN+Wz1VkZRxEgcqUWmUgylmUJ3SbrN+fnqawj3LSc/Fhl7SGfaDGXnmu6f9lYpPcImYDyZHtaMK33WFDpdafLVHVGQxAsyxJGHy0ey//OWTFX6bYBoCM+Dj0xjnAkLG8JeVETAPTxcaxYuQLnvPe9UBRlisFL3Q30+q9+/C/Q+/oQ9U9UxhkmDNOAoqiTER8M40KhUBDt6QCqRnzQ5x59jroZcXOhl20FFihHUFmWhYJpQpWyO8vdPYcvjdRvbLA1iONWNqMj6seBsQw6mpoQCoZc/4CShoaAmcHq1THs9OpQPCFoZgmGUUI+X7Ct3gCA5dGhmEWog7ugx/dB93sBVC/drMeMcBpldJ9ydZZsssgrrJVyNuSpTPWcZBEki0o6+Msrb5VyKOT+dLqeqqzkrIpqz890L1fbhqZpyqXE8Xjc9oXemChjp+wQmnoqm1uyMKOMCKcoiUQiaG9vRywWE6dIJCJ+T/dBYbNyWDBNTqTVVTmklvJFKNDW7XHS80yGH71mJJ7C4TCamppEuDqFd0YikSltH3NdheMmxvuu+CoS27cjpFYPELV0HSeddCLOveaamrc7cOXXkHrwQcSWLbNtY5iGyA2pB2cVwlKhUUZdI02/6Wzrdl+Nqjakzz5g0tyuFm47XROukVWR5duaWuFD+z4fVXULDU3TFoyJy1RH1m4hn0eESwO02FH+mxpIFNAR1vD1iz+DHT1x/Oj5DDwIwAMTuq6LKjZLMuYs1QsLBkL5fsCKw4q12Cq7Ju+j+t8n/V3T9HJnNZdbHiyZGs5FSDJeAEzRX256TF4EpeO7XM3u/Hd0X/Ljks02eV9JO5BWkyvkaaEOmIzFcC4SyouFctuffKqF/HmlaZow1GjhlML5ZYOOdBMZSJTjq+u6MNKCwaDQbfLty0Nn6FSprcwwDJETTHlzdF7WjJZlIZ/PC+0mZ8TRgAZVVW1aze1nJBIRw5zkSaczDVqnfZOrtqnCjtpDw+HwtI4ZB//2NyR274Y/l6+6namqOProo/HuSy+t73avuAKpB/6MlsOwaoeZHVQNWw9kZFWLAaHv4tWMuNnoC0VVoTqqig93Y02mUfqNDbYG4dc1nLN5GW55eDfSBQvRgH/KNolsEUrRwKXvPA4Xn/RhfPjHT6BnNIP2kF7+UmQaMA37F6R4AQghh1PftxXJ+JtFIKhzwqIcCOr2hcvt5BR2bkJPLmF3mnJubaQkoMjwcooqMuPcssfccGsDqFV157xMgabOk5yhIYd81xMqWUlwUNCljGVZSKfTU0LRx8fHbQIpm80KE4yMMDmzSdM0m/FGkzflvBgKaGxvbxflzKtWrZryPDunrMqTVhOJhFg5JaONjDcaEEETV92CPxVFcX2+/X4/QqEQWltbbdNyyIDz+6f+zTQCN/EWfsdW1+gBN8Jbt9YlMBV1cqKVfN9LzSibKfT5sJRoVLXh+Pi4GK7i/Ex2O1/td/IXablFvNJ5t8c0G9hgYxYTsnZLZIu2gQVlXaMhlTfh8Wi48KQjcOKWtdi0ycC9vU/iwFgGbU0B0TZkwYJlTv5tD6YKaPMr+MJZn0U2PbnQFY/HxXlZw9GAAjlqwvn3WslQJ4PFaW45K+MURbFVX8kLk7Juc2Z61YuzCo+Q94tOtK08vVTWYDRgIBgMIhwOT6kIC4fDtoB9WgyWkU04Z6umbMjJ1+m6jnA4bNt30j80UEHWP6FQCJFIRJhzY2Nj6OvrK2cxmyaCwSBCoRDC4TBCoRCGh4dtVS6hUGiK6RYMBkVOmbxwKpPP56eYbnSiwH+qJpNNt0QigVwuZ/sSSQanm/lGi6NO842mnNZ6P9BjWL58ObLZrDDbRkdHbZEKkUikZv5oeOvWqr+f6bYMs1CQTbhKQxNkfVfNiKt2H6o6GQUEhbVbo2GDrYF89G2r8GpfAo/sGkK6UELUr0PXVBQNE4lcEaqi4B1HteOjb1sFn0fDOZuX45aHdyNbAqKBqX9EiWwRYd3Ape/chE+d+iFheJABQwdW+TwFLyYSCYyPj9smSjn/COXz1b6A1WPYUbuAvCrrNN/czDm5baFSS4NszjXqC7pbK0Y1E895vZtpFwgEbC2Tbgae3+/HqlWrsH79eui6LgwsGm1OxhuJYzLR5Aq0wcFB22PxeDxitTQQCGBwcNAWlK5pGmKxmM14i8VitpV6JzQ1Vc6Kop+pVEoE9MqVb/STKvQqjb6mfBDnqampCZ2dnWhpabEZcNFotOHGTPTMM3kENTNrGlVtSJUWM51ANVPcVkrdfsqouRxQoaUHqNu3ZpgFw3S0G1DZlFMwMTleVZHIFqF7PPjwKUfgfaeudb1fufKIKqFIu8nHX1njURSFcwiT0+x3fgFzM+5M0xStkRRU71b56mxTdWobN50m54bVIpvNTmkPdVsIkHPW3PQjaQu5goyq7Mkwks0jqsgLBoPi3zuHO9Fz6mbO0fRBes3I3KJqMWCyYyEej6O/vx/ZbBb5fF5UBJKBSD99Ph80TRPTN+nU3t6OpqYm2+cuPZ5YLDblOXVWvzlPzpgTMuCSySSGh4dtCy2Vqt+CwSBisRiam5tt5ls4HK7YJkcVfp2dnSgUCuJ93tvbC8uyhFkZjUZdj4es3Rhm+pVwhmHAMgahTERYWZZV1mpy/NOh2fXDBjbYGojPo+Ha8zbgl0/ux90vHMRgIod0vgRNVbCiJYhzNi8T5howfVGnKJNjw+sJ36OMiUpmHFUwjY+Pi1M2m61owtH5eu5XNuXqMe3oQE//HrC3UBByG6mzKs7NkKv04SN/8NQrAuuh3lw7t9+Taef1eoVAloWcaZpibL284k0rkvL903NEoq2vr0+IThKPzc3NosqNzDfKECBB6vY+M03TteqNfqZSKVfjjQw3WklOJpNTbpuq30j8Uq5BLBZDV1cX2tvbbS2noVCoYa9dIzALBRR7e2tuwzALgemulJZKJWBwCEo2A03VyhU7VnmCm3Sjh2jvGaYxTFe7AdPXb27I1d4tLS117StlodJxlUw42ZiTq+Cdhtx0832cxp38hY3MOWc2lxwpImsxOV/W2flQb0u804AjbVRNm1bLBZb1ozOvTc7AIzNLzpuVzaampiZbpAqZVXJ7Zz5fbmsk0y2dTotYEVoMp/tyxqTQ/tHCY0tLC1pbW9HS0oJAICDMQdnwDAaDVavf3Iw3+n5AxpvTfBsbG5tyW27Vb9FoFG1tbWhtbRXVaVTBR+8Br9crDES5q2JoaAj9/f2ibTUajSIYDM55hQ3rN2Yp4TThCroOU1GgaPL34omUUcsCL482FjbYGozPo+FTp67FxSetwgs9cSRzJUT8Hmxe2Qy/rk3ZdrqibjrIgaZuK1xuFIvFimYcCTaqkhsfH0cmk3E14SpVP8yUSplDBN0/CRi5mo5EBwkm+r1bK6nbgIHp4FwBbhS033LIrZw9QvftbHeQHyc9PlqxpbYLMt50XUcoFBLjqjs6OtDZ2Yn29nZbZZ6qqnVVvzmNt/HxcQwNDdmMNzpPJie1YtCKrxM5+43EV3t7O7q7u9HV1WXLfKsWNFoPie3bkXr00bq2DW/dyq0LzIIgVzSwoyeOVK6EsN+D41yOPdPBacIVvHGYOQWKY/GCzDZuM2AWI9PRbrT9XOq3StBxuNLx10mpVLItcqVSKdHlIFelJxIJ0VooG3LGRAA2aZCZHldJG9HnibOLQK70l3N0Sbs5s92cAwFosVGOS3Eu8lZaTHXmAVcajjAd3BZY5YBzMjtpInwwGBQLvul0GvF4XJiHspanE1XmaZomMm/l9k1aMKGWYNl8k396vV6EQqEpbbZkADpbUOPxOFKp1BTzjTot3J4Hp/kmV+XJ1W+kKylWhbTj8PCwyK+j3LZqVePT1W7RM89k/cbMO43WbvUxUbnGuq3hsME2R/h1DSetq21qTVfUzTW6rqOlpaXu1VTDMGwGnJs5R0H/VN1EplCltlW3DKB6VjnrQa5ykwN0qRddURSxGkdVcs7BDM6WDFnkubXTNiJgnEw7OT9F/h1VtNFPWkGm55Qy05xViPRlmIRgJSEmt1xEIhG0tLSIAQzUkkE/aTWVTDnKF6EAZnlqFomzsbExDAwMCNOWvgxQ6whB1W+pVEpc9+qrr4rHIk/HamlpQWdnJzo7O7Fs2TIh5iKRSF0mQOrRR5H47+1QKwQQi+d/YkVz2bXXcusCM2/kigZuf2ryy75hWtBUBR1Rf8O/7FuWBRSnruTTX5UFNORzz0mhUKgYCM4wjaBe7QYsPP3mhsfjEcZEPcgRD2TIURuks21Vbjesd7CAZVnT7hqQq+BI/9ACYXt7u8j4okopygaTq+WKxSIymYytqkzOKyPdQRV/pDXIcDRNs6YBJ5s+cnyKG7KmpUo8Z66yPNxifHzcpp1pMUPu2nAO+aLcOlm7ya298sAKZ1YyGWKk22KxGDo7O4WuJB0nm5vxeBxjY2NiMIQ8dEGODNm7d684L0eGUEavnCHc1dUFj8cjDLd9+/ZBURRbbpvT9J2udqO2U9ZvzHywELTblG0azOGo3dhgWyBMR9QtJORVpXowTVPkQrgZcrSSShVyFLhfLUNuOhVjzkq3eh6fW5uDfD4UColMs1AoZCvvlyerApNCzxk6SxOr6Kc8tt0t0Fh+zLRC6GzxIhEkCzbncylfJpEnI6/sys8BGXjyiqo8/GG6+XZ0W/JE2kgkIgJ6qdqNzEIS8s7bprYHVVXR398vzDfAXv1GFXpdXV3o7u5Gd3c3WltbpwxaUL1e6DWmStVqKWCYuSZXNPDVO1/GI7uGYFqWrV3twFgGtzy8G6/2JXDteRtmLdQ0KfwbAH53zz245rvfxe69exEOBnH8pk24+xe/gALgmhtvxE9/+UsMDw9j48aNuOGGG/DOd74TAPCzn/0Ml19+OYaHh8Vt3Xzzzbj++uvFF7Bt27YhlUph06ZN+NGPfoRYLIaXX34Z+/fvx+WXX44//elPKBQK2LBhA/793/8dGzduBADceuutuP7667Fv3z6sW7cOX/rSl/CJT3xiVo+bYSqxWPWbG1QRFXb8nVfCNE2hZ+g47WxZlTNk0+n0lGmftb7Q0Xb16jZ5cAKZRoFAQBhxZM50dXUJk4YWB/1+P1RVFXpIXgykLg4y4mRTjHQJPRa3OBR5wZP0GD3ndP/Oxy3rNzo5uxVIU8oLsc7IFNJtpEvl5waAa/WxZVk2847MNvk25MEU8mIrLUjTvtJrT8+T/BxQZMjQ0JDNfJOr35qamkTeWyAQQDqdxtDQkIgRoe8hpOFYuzGLAdZuS1e7scHGHFJoYlA4HEZnZ2fN7esZ7OA05Gg6UrWhDvVSyXiq9vjcTDhZ7JBYoAD/WCwmsinouZEnTpEpR2Ikn8+L1Vh5pZlOlUw7ap+gfA15WxKLzqlacuaK/BzIq6/y45NNN1oZnW2rJuEMbi6VSjZDklZ3DcOwtca6Za7IRqG8YhuJRNDW1oaOjg6c/Le/obVYhJHPw6vrUFSVEwqYBcntT+3HI7uGENA12xREr0dFyOdBIlvEI7uG8Msn9+NTFQLX60VraoI20Z7W19eHj33mM7juuutw3nnnIZFI4M9//jO8y5fju9/9Lm665Rbceuut2LhxI370ox/hfe97H3bt2oXlNb74yPzxj39ENBrFAw88AKC8Evre974XK1aswL333ou2tjY888wz4nP99ttvx9VXX42bb74ZmzdvxjPPPINPf/rTaG1txbnnnjurx84wjB1VVUWFVD1RJM7BDmTIOavkZM0nL6rVkyNHwxumq9vkEw0gIEOuqalJmDhdXV1TcuXcuhncoHxcMqVof+WMO7kqjPQN6Td5AJWctyfnptHzRcab/DidlWryAC65e4MWSyluhUwxt0nYzpNsmJJWrDQMSF7wdVb7yRWLZOzRT7mtmIZZUO7bkSMj0OegCodhGg1rt6Wr3dhgYxY0h2qwQyVDrt7BDgQJu+k8PjdDTv7p8/kQjUaFKec041pbW8X5cDhcc3Q6tWjIK6L5fB7xeByDg4MYGhrC6OgohoeHRdsmPYdU+u+2Kkthx05jq1JA8HSzmuSJjSRM5VV2edKX02CUKwErtSA7w5cDHg9OUlQMvr5rYjqcCm3icQUCAaxYsWJa+88wc0GuaODuFw6WVz8D7mZ2NKAjXSjh7hcO4uKTVjWsfa2vrw+lUgkf/OAHsXr1agDApk2bAAA33HAD/uVf/gUXXnghAOB73/se7r//fvzwhz/EtddeW/d9RKNR3HLLLeJv/uc//zkGBwfx9NNPixyqI488Umx/1VVX4bvf/S4+8IEPAADWrl2L5557DrfccsuiFGkMs5RoxGAHuVVVrpKjNkJ5sANlnFVjuoacszWTNA5psEgkIkwfOskRJ2QuVqrc83g8tmgNOQO3kn6yLEu0846OjmJoaAgjIyMYGhoS5iUZl3JcCGkiWqSkVk65q4A0V726TR6KUcmAkzP55IVOefhYJpOxLf7Sa2mapi22xRnj8n8Vi9himBhNpWwVeLquw+vzobnO/EKGmUtYuy1t7cYGG7OkmKvBDnKVXDqddjXhZIOu3h52MrvqzSSRK8cqta3quj7FkJNPzkq55uZmqKoqPmRlaGLo2NiYOA0ODuLgwYNTppU5W1nJfKPnQp7URW0owWDQ1kpLLaHyKrDbbcqXCblV1jldlFaGSUDKBhyJcacIN1UVUCeClWEBpgmyWjnInVko7OiJYzCRQ9RfvVI06tcxmMjhhZ54w9rZNm/ejHe9613YuHEjzjrrLJxxxhm44IILAJQF3CmnnCK2VRQFJ598sq1tux42bdpkq4J98cUX8Za3vMU15D2dTmP37t245JJLbG0FxWIRa9asmeajYxhmIdCIwQ6yySS3rcqDHeSMs2pMV7cBmGLI0YRQOb9WHgQgt3RSdZmsO8jwci5cNjU1oaOjAxs3bhQtmrIhmUgkMDY2hqGhoSlZdE4NpOu6MAhpOFZLS4sYokDajdp3nYuactQJGXjyeadGrFUdJ1fluXVf0G0VVQ0WdX4UC+VQUABQAL8/wAYbsyBg7TbJUtRubLAxhz1zMdhBnrSaSqWmTOWaaY6cZVnTziOplSHn8XjEBCcaAuBmxnV2duKII44Q5tX4+DjGxsYQj8cxPDyM3t5e9Pf32/JWqglVymKRByh0dXVh1apV6OzsRFtbG5qamuDz+aZU3MmVd9VaZN1aLNzMOjLeKEMmk8lAz2YB0wJgTQo06TllmIVAKleCYVrQtcpT1QBA11Sk8yUkc42Z7AyU/w4eeOABPP7449i+fTu+/e1v42tf+xoef/xxAFONaDnnh6bqybh9YXWa5W63S9Dwk9tuuw1btmyx/a5RreoMwyxsZjvYgTScbErJC4pUATaTHLl6kQcpUJUWLSBSBSAZa5Qh55w0Kg8q8Hq9aG1tRXd3t2ixzOVytkrAgYEBjIyM2AZY0O9lvF6vTbd1dHSgu7sby5YtQ3NzM5qbm6dkAjtxalm3aBO3bGK3ybH5fB6ZTEa8Pi2734CWTEJXVZiWBcs0YVoWYFnwsHZjFgis3SZZitqNDTaGmSYzHexQT9sqTVt1Tlp1np9Jjtx0Hl+lDDn6GQwGhZAiQ66zsxOrV68WwqlYLCIej2NoaAhjY2NCoMqZIGSUjY+PAwB2796NJ598UgTXUnjtypUrbcZba2vrjEwuaiF2npwrq/l8Hv6f/AT+F1/C8qYmlIpFFEtFlIollAwDwWBw2vfNMHNB2O+BpiooGia8nspCrWiY0FQFEX9jD/uqqmLr1q3YunUrrrrqKnR0dODhhx9Gd3c3HnvsMZx88skAyn97TzzxBM477zwAQHt7O+LxOHK5nAimfuGFF2re36ZNm/DTn/4U4+PjU1ZCaWLwG2+8gQ996EMNfZwMwyxNGjHYIZlMTqmSkxdh5ZbVegw5N92WTqenbEfVXsBkvAWd5EzcUCiEYDCIUCgkTpQ7G41G0d7ejlWrVkFVVZvBNTY2hv7+foyOjtqmrjp1GwDbJPmWlhYsW7YMy5YtQ2trq5g8L0dxUNumc7jDTCHt2ffVK5F56CE0d3WVF1MtC0aphGKpCE1lg41ZGLB2m2Qpajc22BhmjpEHO9SDZVmirYEMOLfBDnKOXD6fr2jITTdHbrqGnDxAwG2ogxyk29raCsuyxOPJ5XLCTJT/Da1Kjo2Nobe3V5Qm+/1+BINBhMNhLFu2DKtWrcLKlSvR0dGBtra2mqumiqIg/+CDSD36qLhOnzg5LbNMfz8MAOHW1rqfC4Y51By3shkdUT8OjGUQ8lU+pCdyRaxsCWLzyuaG3fdTTz2FBx54AKeffjra29vxyCOPIJVK4eijj8Zll12Ga665BuvWrcPGjRvxwx/+EHv37sU//dM/AQBOPPFE+P1+XHnllbj00kvx4IMP4p577qn5OXnRRRfhm9/8Jj74wQ/iG9/4Bjo6OvDss8/i6KOPxnHHHYcrr7wSl19+OSKRCE4//XTkcjk8+eSTMAwD//iP/9iwx84wzOHJTAc7yG2rZMi5TVpNpVKiBbXWYAeKDXG7T1pATCQSGBgYmBKvIZtx1I4qZ8dRnlxTUxOWL18uKuooH432nSa9y6bcwYMH8corr4jMPRpCQF+kly1bhpaWFjQ1NbkOP6hEYvt2m35zo/DCDiCbhZJOw0tf5GtoQ4Y51LB2W9rajQ02hllgKIoixFs91DPYQRZz8mAHpwkn/6wXEmvTzZGj8fB0f4lEQrR8ksEnT4pyZpjs3LlTZM5RqHBbWxtWrlyJ1atXY9WqVWhvb0ckErHdf+rRR5H47+1Qvd6q+2kkk1C4nYBZ4Ph1DedsXoZbHt6NRLboGpabyBahKgrO3rysYSG5QDnE9pFHHsF3v/tdpFIprF27FrfeeitOOukknHDCCUgkEvjc5z6HkZERbNy4Effeey+WLVsGAIjFYvj5z3+OL3/5y/jhD3+Ic889F5dffjn+7d/+rep9er1e3HffffjCF76AM844A6ZpYtOmTfjxj38MAPjf//t/IxQK4frrr8dll12GSCSCzZs348tf/nLDHjfDMEy9yIMd6sUZfyHnyFH0CA0toAmicgyJrOHk7N5KUGQGDUlwy9KVtZuszegka0Gq0KPtvF4vRkdHsXfvXgBlk5IWTOUW0+7ubjHcwY169JuZycAqFGCmUmJqIsMsNFi7LW3tplh1prEfOHAAK1euRE9PD0/PY5hFTq3BDiTmqGU1nU5PMeGchly9gx3qgUbVy2PnaRIoCT3n6iuF/oppUV6vGOKwYsUKrF69GkcffTRW/uEPsJ56Gt4a46bzr78OyzCgOQw6J2ahgOhZZ2LZNKbrMIwMfemYaZhrvmTgit+/jEd2DZUnUvl16JqKomEikSsLtHcc1Y5rz9sAn4dN47lmtq9nI2HtxjBLH8qRpbZVOUeOhhzIFXK0kElVcvUM56o0bEo+T7dB15VKJdsQAgBiyIOcJyfnyoXDYbS3t2PlypVYtmwZli9fjmXLliEYDOLgFVcg9cCfoVfRb8XeXhjxOBSvF2qVhWrWbkwjmM3xnrXbwqNR+o0r2BjmMGQuBzuQkKs0YZXOVxvsoGmaaFMgaBKovLJLwbfOllY3A47Of96j422qitFkEh7dA6/uhdfnhVf3QvNoCPjLE7Oge6Av60bQEbjpRnjr1rqeR4aZC3weDdeetwG/fHI/7n7hIAYTOaTzJWiqghUtQZyzeRk++rZVLNAYhmGWILSwOJvBDvJQB6chR4ustQY7uE16p8tk6FHe7ejoqBhYYBiGqJCTY0eo+i0QCKC5uRkfGRvD+kwGpdER+LzlAQ6qqpY1nqpCAaBGwjAzGXjq0G+s3Zj5hLXb0oUNNoZhajKXgx1olVVeSXVrWzUMQ7SHytNpLMuyrd7Ko9vdTDzDr8D0eJA3SsgX8kijHBysoCzuVq1aVRZtuhfBLVt4dZNZFPg8Gj516lpcfNIqvNATRzJXQsTvweaVzQ1tLWAYhmEWN7Md7JDNZm1mHLWtkilHU+JlQ470mNN0o4VTGhhBXQu5XE5MFyTe4fFghaphsPcgIC2gejQNy1esQCgYhBZtgplMsX5jFgWs3ZYmbLAxDNNwGjHYwXmecuSoSq5QKFRsW6VpXSQIqdpNnmA6ZR9gwQJmNJ2UYRYKfl3DSetqB28zDMMwTD00YrADRY84Bzskk0mk02mh6eQqN2p3JfNNzeflOyl3NkxUynlYuzGLGNZuSws22BiGmXdmMtghn8+7mnHyYIdEIoF4PG4b7ND52mvwxMfhUxQhzCzThAULmqpC1eqfaMUwDMMwDMNM0ojBDul02hY7kkgksOae/4Jv715EdH1yYdU0oSoKlGlMI2UYhplL2GBjGGbRIYu3elZTAaBQKCCdTmPk6q+j+PjjCLS1waQgXsNAsVCAaZnQNP5YZBiGYRiGOVT4fD74fD40NzdX3Obgvn1IDQ6ieflyWJZVznmzJirYPKzdGIZZGPCnEcMwhwUUlpsNBZHSNPjrrJZjGIZhGIZhFg40FEEDt4YyDLOw4HpahmEYhmEYhmEYhmEYhpkFXMHGMMxhh1kooNjbW3MbhmEYhmEYZmHA+o1hmIUOG2wMwxxWhLdunZNtGYZhGIZhmLmB9RvDMIsBNtgYhjmsiJ55JqJnnjnfu8EwDMMwDMPUCes3hmEWA2ywMQzDMMwiJbF9O1KPPlrXtuGtWxv+5eRd73oX3vrWt+L6669v6O3OFEVRcPfdd+Pss8+e711hGIZhGIaZAms3O0tNu7HBxjAMwzCLlNSjjyLx39uher1Vt6NMmoW4+l8oFOCtsf8MwzAMwzBLAdZuSxueIsowDMMwixjV64W+fHnVUy0RNxO2bduGhx9+GDfccAMURYGiKHj99dfxyU9+EmvWrEEgEMCxxx6LW2+91fbv3vWud+Hzn/88Pve5zyEWi+G8884DANx1111Yv349AoEAzjjjDPz4xz+Goii2f3vnnXfiuOOOg9/vx/r163HdddfBNE0AwJo1awAA55xzDhRFEZcZhmEYhmEWEqzdlq524wo2hmEYhmGmzfe//33s2rULxx13HL72ta8BACKRCFatWoXf/va3iMVieOihh/CP//iPWL16NU4//XTxb3/605/is5/9LJ544gkAwN69e3HhhRfisssuw7Zt2/DUU0/hy1/+su3+HnvsMWzbtg033XQTTjnlFLz22mv4h3/4B/h8Pnz+85/HM888g46ODvziF7/Ae97zHmiaduieDIZhGIZhmAUOa7e5hw02hmEYhmGmTVNTE7xeL4LBILq6usT1//qv/yrOr127Fg8//DB+/etf20TaMcccg29+85vi8le+8hW86U1vwre+9S0AwNFHH43nnnsO3//+9223e8UVV+BjH/sYAGDdunX46le/ihtvvBGf//zn0d7eDgBobm627Q/DMAzDMAzD2u1QwAYbwzAMwzAN4wc/+AF++tOfYt++fcjlcigUCnjXu95l2+atb32r7fJrr72GE0880Xad8/ILL7yAxx9/HFdffbW4zjAM0WbAMAzDMAzDTB/Wbo2DDTaGYRiGYRrCr371K3zpS1/Cd7/7XZx44omIRCK4+uqr0dPTY9suFArZLluWNSWzw7Is2+VUKoVvfOMbOPfcc+dm5xmGYRiGYQ4zWLs1FjbYGIZhGIaZEV6vF4ZhiMuPP/44tm7diksvvVRc9/rrr8Pv91e9nWOOOQZ//OMfbdc9++yztsvHH388du3ahfXr11e8HV3XbfvDMAzDMAzDTMLabW7hKaIMwzAMw8yINWvW4Mknn8S+ffswPDyMI488Ek899RT+9Kc/YdeuXfjSl76EnTt31rydSy+9FK+88gquuOIK7Nq1C7/85S/xq1/9yrbNlVdeidtuuw3XXHMNdu7ciZ07d+L222/Htddea9uf+++/H/39/RgbG2v442UYhmEYhlnMsHabW9hgYxiGYRhmRlx++eUAgGOPPRbt7e0444wz8IEPfAAXXngh3v72t6NQKOCTn/xkzdtZu3Ytfv3rX+OOO+7Apk2b8Itf/AJf+cpXbKunZ511Fu666y7ce++92LJlC04++WT84Ac/sI10v+GGG7B9+3asXLkSxx9/fMMfL8MwDMMwzGKGtdvcoljORtkKHDhwACtXrkRPTw9WrFgx1/vFMAzDMIcFe/fuBQCb2KiXg1dcgcR/b4fq9VbdziwUED3rTCyTVgwXOv/n//wf3HPPPXjppZfme1emxWxez0bD2o1hGIZh5oaZHu9Zuy1MGqXfOIONYRiGYRYp4a1b52Tb+eDmm2/G2972NrS0tOChhx7CTTfdZBsbzzAMwzAMs9hh7ba0YYONYRiGYRYp0TPPRPTMM+d7NxrCrl278M1vfhOjo6NYs2YN/vVf/xVf/OIX53u3GIZhGIZhGgZrt6UNG2wMwzAMw8w7N954I2688cb53g2GYRiGYRimDli7TYUNNoZhGIZZzOzYAfzpT8CzzwK7dgH5PODzAUccAbz1rcDf/R1wwgmAosz3njIMwzAMwzDxOPCHPwDPPFPWcaOjZZ3W3g4cfzxw4onAOecAodB87ykzTdhgYxiGYZjFhmUBv/kN8J3vAE895b7Njh3Af/5n+fzxxwP/9/8NfPSjgMoDxBmGYRiGYQ45b7wBfOtbwO23A9ms+zYPPVT+GY0C27YBX/kK0N19qPaQmSWsshmGYRhmMdHbC5x9NvChD1U215w8/zxwySXAe94D7Nkzt/vHMAzDMAzDTGKawI03Ahs2AP/+75XNNZlEovxv3vQm4P/9f8uLq8yChyvYGIZhGGax8NxzwBlnAMPDk9cpCnDmmcC73gVs3lxe8UylgBdfBB55BPiv/wIMo7ztgw+Wq9nuuQc49dR5eQgMwzAMwzCHDYVCuYPgN7+xX792LfDBDwJbtgCrVpUNtD17ypEfv/kN0NdX3i4eLy+SPvEE8IMfcCfCAocNNoZhGIZZDLz8MnDaaWWhRXzsY8DXvw6sWTN1+/e+F7jsMuDAAeCaa4Af/7h8/fh42ZD785/LGR8MwzAMwzBM4zEM4OKLgd/+dvK6NWuA738f+Pu/BzTNvv2pp5a13fXXl022L34RGBgo/+7f/q1srt18M+fqLmDY/mQYhmGYhU42C1xwwaS5FgwCv/99uWXAzVyTWbECuOUW4I9/BJqaytel08CFF5bbDxiGYRiGYZjG873v2c21j3wEeOkl4P3vn2quyeh6edtXXgHOOmvy+h/+EPiP/5iz3WVmDxtsDMMwDLPQueoq4LXXyue93nLb5wc+ML3bOP308rRRmki1fz/wpS/NardKpRIuueQStLS0QFEU7NixY0a3s23bNlxwwQVVt1mzZg1uvvnmir9PpVJQFAUPUTiwC/fccw8UXvVlGIZhGGau2bUL+OpXJy9/9KPAL34BhMP130YsBtx5Z7nzgPjsZ4H+/hnvFmu3uYUNNoZhGIZZyPT3l1dAiauvLuetzYQTTgC+/e3Jy7feWp5oNUN++9vf4s4778T999+Pvr4+bNiwYca3VYtnnnkGn/zkJ+fs9hmGYRiGYRrG1VcDuVz5/Pr15W6CmeSneb1lY66jo3x5bMyu5aYJa7e5hQ02hmEYhlnI/OQnQLFYPn/sscDll8/u9i69dDJ7zbLKgm+G/O1vf8P69euxZcsWdHV1weOZXrRrqVSCVedUrPb2dgSDwZnsJsMwDMMwzKFjcNA+1OAHPyjHe8yUtja7qXbbbfVNInWBtdvcwgYbwzAMwyxkbr998vxnPwtMUwhNQVWBz31u8vIvfzmjm9m2bRuuvPJKPPfcc1AUBWvWrEE2m8VnPvMZtLe3w+/3493vfjdeeukl8W9+9rOfoa2tDXfeeSeOOeYY+Hw+DEsTUa+66iq0tbWhubkZX/jCF2DQ9FNMbTN47bXXsHXrVvj9fmzcuBGPPvrolH285557cOSRRyIQCOD0009Hb2/vlG3uvPNOHHfccfD7/Vi/fj2uu+46mKYpfq8oCn7yk5/g7LPPRjAYxIYNG/DYY4/N6DljGIZhGOYw4D//0744+t73zv42P/xhexXb9u3TvgnWbnMPG2wMwzAMs1BJJIBXX528fOGFjbnd886bDNc9eLB8mibf//73cdlll2Hz5s3o6+vDM888gy9/+cu46667cPvtt+PZZ59FR0cHzjzzTGQyGfHvkskkbrjhBvz85z/Hyy+/jGg0CgD44x//iD179uCRRx7Bz372M/z85z/Hd77zHdf7Nk0T5513HoLBIJ5++ml8//vfx5cceXL79u3D+eefj/POOw87duzARRddhK/KWSgAHnvsMWzbtg2XXXYZdu7ciZtuugk33ngjbrrpJtt2X//617Ft2za88MIL2LRpEz7ykY+gSMKZYRiGYRhG5umnJ8//r//VmKmfXm9Zv7ndR52wdjsE2s2qk56eHguA1dPTU+8/YRiGYRimBnv27LH27Nnj/stHHrGsciOnZa1Z09g73rRp8rbvuWdGN3HVVVdZW7ZssSzLspLJpKXrunXHHXeI36fTaSsWi1k//vGPLcuyrNtuu80CYL388su227nkkkus9vZ2K5fLieu+/e1vW93d3eLy6tWrrZtuusmyLMvavn27peu61d/fL37/m9/8xgJgPfjgg5ZlWdZXvvIVa9OmTbb7+ed//mdLlj6nnXaadd1119m2+dGPfmQde+yx4jIA6+qrrxaXX3vtNQuA9eqrr7o+J1Vfz0MMazeGYRiGmRuqHu83b57UWHff3bg7vfXWyds944wZ3QRrN3capd+4go1hGIZhFiojI5Pn16xp7G2vXet+PzNk9+7dKBaLOOWUU8R1wWAQxx9/PF6VqvACgQDe/OY3T/n3xx13HHw+n7j89re/HX19fRgfH5+y7f/8z/9gzZo16OzstG3v3Oakk06yXefc5oUXXsDXvvY1hMNhcfrCF76A3bt327bbuHGjON/d3Q0AGBwcnPokMAzDMAzDSC2UDdVvrN0WvHabZZALwzAMwzBzRp0hsgvpfpyj1C3Lsl1XKey20gh2t+udt+lGPdukUil84xvfwLnnnlt1O13Xp+yPnPXBMAzDMAwjOBT6jbVb1e3mS7txBRvDMAzDLFRiscnz+/Y19rb37p0839o665s74ogjoOu6LUQ2m81ix44dOPbYY2v+++effx75fF5cfvLJJ9Hd3S1yPmSOPfZY7Nmzx7YS+eSTT07Z5qmnnrJd59zm+OOPx65du7B+/fopJ4ZhGIZhmBnR1jZ5vpH6jbXbgtdubLAxDMMwzEJl8+bJ83v2NKQdAACQywGvvDJ5+fjjZ32T4XAYl156KS677DLcd999eOWVV7Bt2zZ4vV585CMfqfnvs9ksLr30Urz66qu466678M1vfhOf//znXbd973vfi3Xr1uGSSy7Biy++iAcffBBXX321bZtLL70UO3fuxFe+8hXs2rULP//5z3HHHXfYtrnyyitx22234ZprrsHOnTuxc+dO3H777bj22mtn/kQwDMMwDHN4c9xxk+effbZxt/vXv06ef8tbZn1zrN0aDxtsDMMwDLNQaWoCjj568vJ//mdjbvfOO4FSqXy+qwtYvrwhN3vdddfh3HPPxcUXX4wtW7ZgYGAA27dvr9haIHPGGWdgxYoVOPXUU/Hxj38cH/3oR3HZZZe5bquqKu68804kEgmccMIJ+MxnPoNvfetbtm3WrFmD3/zmN/jtb3+LTZs24Re/+AW+/vWv27Y566yzcNddd+Hee+/Fli1bcPLJJ+MHP/gB1jQ6745hGIZhmMOHE06YPP/b3zamnbNYBH7/e/f7mAWs3RqLMjFloSYHDhzAypUr0dPTgxUrVsz1fjEMwzDMYcHeiXL/isLg618HrrqqfH7DBmDHDkDTZn6HlgWccgrwxBPly5ddBlx//cxvj7FR8/U8hLB2YxiGYZi5oerxvr8fWLlycjHzgQeAv/u72d3hf/wHcPHF5fNNTcDBg0AdJhhTH43Sb1zBxjAMwzALmU9/GvBMzCR6+WXge9+b3e395CeT5hoAXHrp7G6PYRiGYRiGmaSrCzj//MnL//RPQDY789sbHS0viBLbtrG5tkBhg41hGIZhFjLLlgGf+czk5SuuAB5/fGa3tWMH8MUvTl7+5CeBI4+c1e4xDMMwDMMwDr72NcDrLZ9/7bWylptJq2ipVDbU+vvLl5uagC99qWG7yTQWNtgYhmEYZqFz7bUATUfK54GzzgK2b5/ebTz8MHDaaUAyWb68fDlwww2N3U+GYRiGYRgGeNObADnE/6c/BT71KSCTqf82xseBCy8E7r578rrvfa9h2blM42GDjWEYhmEWOsEg8OtfAzT2PJksm2z/8A9Ab2/1fzswAHzuc8C7311uMQCAQAC44w6guXlOd5thGIZhGOaw5fLLgXPOmbx8223lCaN//GP1ajbDKA+2evOby4OpiE99CrjkkrnaW6YBeOZ7BxiGYRjmcEbTNBQKhdobHn88cN99ZWNtbKx83b//e1msvf/9wDvfCWzeXDbhUingxReBRx8tT5wqFidvJxwG/vCH8qADpuEYhgEvtYQwDMMwDLMkUVUVJRpiUAmPp7xA+qEPlbUXALz+OnDmmeUp8eefD2zZAqxeDZgmsGcP8Oyz5X+zb5/9tj7xCeCWWwBFmZsHdJjTKP3GBhvDMAzDzCN+vx+pVAojIyOIxWLVNz7pJOC558qDDx54oHydYZRNNHl0eyVOPrlsyB111Ox3nJnCyMgICoUColRpyDAMwzDMksTr9SKdTiOXy8Hv91fe0O8Hfvc74DvfAa68shz1AZRz2b75zdp3FImUIz0+/Wk21+aIRuo3NtgYhmEYZh5pa2tDPp/H4OAg4vE4NE2r/Y9uvRWh3/0OTT/5Cbyvvlpz88KRRyKxbRtSH/oQoGnAxChypnEYhoFCoYBIJIK2trb53h2GYRiGYeaQtrY2JBIJ7N+/v77KpwsvhOctb0HzzTcjdO+9UGp0L5jBINLnnov4P/8zjOXLp1a0MQ2h0fpNsaz6RlkcOHAAK1euRE9PD1asWDHrO2YYhmEYpoxlWRgeHkYul4NhGNP5h/Dt2AH/Y4/B99JL8OzZAyWfh+X1orR6NfKbNiH39rcjf8IJvOo5x2iaBr/fj7a2NigL5Llm7cYwDMMwc0c6ncbY2FjtVlEH6sgIgvfdB9+LL8K7cyfU8XFAUWC0tKDw5jcjf9xxyLz3vbC4In7OabR+4wo2hmEYhplnFEVBe3v7zP7x2rXAeedNudoLIDi73WIYhmEYhmEqEAqFEAqFpv8P16wpZ6850AFUaTZlFgE8RZRhGIZhGIZhGIZhGIZhZgEbbAzDMAzDMAzDMAzDMAwzC9hgYxiGYRiGYRiGYRiGYZhZwAYbwzAMwzAMwzAMwzAMw8yCuocc0GSMvr6+OdsZhmEYhmGYpU5XVxc8nrmfM8XajWEYhmEYZvbUq93qVndDQ0MAgBNPPHHme8UwDMMwDHOY09PTgxUrVsz5/bB2YxiGYRiGmT31ajfFsiyrnhvM5XJ46aWX0N7efkhWXRmGYRiGYZYih6qCjbUbwzAMwzDM7KlXu9VtsDEMwzAMwzAMwzAMwzAMMxUecsAwDMMwDMMwDMMwDMMws4ANNoZhGIZhGIZhGIZhGIaZBWywMQzDMAzDMAzDMAzDMMwsYIONYRiGYRiGYRiGYRiGYWYBG2wMwzAMwzAMwzAMwzAMMwvYYGMYhmEYhmEYhmEYhmGYWcAGG8MwDMMwDMMwDMMwDMPMAjbYGIZhGIZhGIZhGIZhGGYW/P/81WIeQNy/VwAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, axes = plt.subplots(1, 2, figsize=(11.5, 5.2))\n", + "forbidden_mask = forbidden.astype(bool)\n", + "for ax, P_pc, title in zip(\n", + " axes,\n", + " [P_uncon_pc, P_con_pc],\n", + " [\"Unconstrained transport\", rf\"Mass into forbidden cluster $\\leq {t_cap}$\"],\n", + "):\n", + " # Source and target points\n", + " ax.scatter(\n", + " src_pc[:, 0],\n", + " src_pc[:, 1],\n", + " c=\"C0\",\n", + " s=45,\n", + " alpha=0.8,\n", + " label=\"source\",\n", + " zorder=3,\n", + " )\n", + " ax.scatter(\n", + " tgt_pc[:, 0],\n", + " tgt_pc[:, 1],\n", + " c=\"C3\",\n", + " s=45,\n", + " alpha=0.8,\n", + " marker=\"s\",\n", + " label=\"target\",\n", + " zorder=3,\n", + " )\n", + " # Highlight forbidden targets with red rings\n", + " ax.scatter(\n", + " tgt_pc[forbidden_mask, 0],\n", + " tgt_pc[forbidden_mask, 1],\n", + " facecolor=\"none\",\n", + " edgecolor=\"red\",\n", + " s=220,\n", + " linewidth=2,\n", + " zorder=4,\n", + " label=\"forbidden\",\n", + " )\n", + " # Top-mass arrows\n", + " threshold_p = np.percentile(P_pc, 88)\n", + " for i in range(n_pc):\n", + " for j in range(n_pc):\n", + " if P_pc[i, j] > threshold_p:\n", + " w = 4 * P_pc[i, j] / P_pc.max()\n", + " ax.plot(\n", + " [src_pc[i, 0], tgt_pc[j, 0]],\n", + " [src_pc[i, 1], tgt_pc[j, 1]],\n", + " \"k-\",\n", + " alpha=min(P_pc[i, j] / P_pc.max() * 0.8, 0.6),\n", + " linewidth=w,\n", + " zorder=2,\n", + " )\n", + " ax.set_xlim(0.05, 1.0)\n", + " ax.set_ylim(0.1, 0.95)\n", + " ax.set_aspect(\"equal\")\n", + " ax.set_xticks([])\n", + " ax.set_yticks([])\n", + " ax.set_title(title)\n", + " ax.legend(loc=\"lower right\", fontsize=9)\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "f4d3a5c9", + "metadata": {}, + "source": [ + "On the left, standard OT routes a quarter of the mass into the upper cluster (red rings), as that is its closest target for several source points. On the right, with the cap activated, the algorithm reroutes most of that mass to other available targets, paying only a small price in transport cost in exchange for satisfying the constraint. The constraint dual $\\alpha$ is exactly the *shadow price* the LP would need to pay to relax the cap by one unit; we make this interpretation precise in §9.\n", + "\n", + "This kind of pattern, redirecting mass away from one zone, is exactly what arises in fairness-constrained matching, capacity-limited assignment, and any application where a hard usage cap must be respected." + ] + }, + { + "cell_type": "markdown", + "id": "44e34b0f", + "metadata": {}, + "source": [ + "## 8. A realistic example : fairness-constrained matching\n", + "\n", + "We turn to a more applied scenario: an assignment problem where the cost matrix carries a structural bias against one group. Standard OT will faithfully optimise the average cost, but the resulting allocation can subject one group to systematically worse matches : a form of *cost inequity*. With the constrained solver we can fix this directly: ask that the gap between the two groups' average matching cost stays below a chosen threshold." + ] + }, + { + "cell_type": "markdown", + "id": "9dafb4b9", + "metadata": {}, + "source": [ + "**Setup.** We have $n = 50$ jobs to assign to $n = 50$ candidates. Half the candidates are in group A, half in group B. The cost matrix $C$ is constructed so that A-candidates have systematically lower mismatch on most jobs : group B systematically receives the worse matches.\n", + "\n", + "Concretely, if we write $\\mathrm{avg\\_cost}_A = \\frac{\\sum_i \\sum_{j \\in A} C_{ij} P_{ij}}{\\sum_i \\sum_{j \\in A} P_{ij}}$ and similarly for $B$, the *cost gap*\n", + "$$\\mathrm{gap} \\;=\\; \\mathrm{avg\\_cost}_B - \\mathrm{avg\\_cost}_A$$\n", + "captures how much worse off group B is. We constrain $\\mathrm{gap} \\le \\delta$. With uniform marginals, both groups receive equal mass (50% each), so $\\mathrm{avg\\_cost}_B - \\mathrm{avg\\_cost}_A = 2 \\langle P,\\ C \\odot (\\mathbf{1}_B - \\mathbf{1}_A) \\rangle$; our solver implements this as a single inequality $D \\cdot P \\le 0.5\\,\\delta$ with $D = C \\odot (\\mathbf{1}_B - \\mathbf{1}_A)$." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "7c5059b1", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-02T06:40:29.555088Z", + "iopub.status.busy": "2026-05-02T06:40:29.554520Z", + "iopub.status.idle": "2026-05-02T06:40:31.350863Z", + "shell.execute_reply": "2026-05-02T06:40:31.349887Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Vanilla Sinkhorn (no fairness):\n", + " avg cost on group A = 0.0586\n", + " avg cost on group B = 0.3172\n", + " gap (B - A) = 0.2587\n", + " total transport cost = 0.1879\n" + ] + } + ], + "source": [ + "# Construct a biased matching problem.\n", + "rng_fair = np.random.default_rng(3)\n", + "n_fair = 50\n", + "half = n_fair // 2\n", + "group_A = np.arange(half)\n", + "group_B = np.arange(half, n_fair)\n", + "\n", + "C_fair = 0.5 + 0.1 * rng_fair.standard_normal((n_fair, n_fair))\n", + "# 'Top-half' jobs match group A particularly well.\n", + "C_fair[np.ix_(np.arange(half), group_A)] -= 0.3\n", + "# Across all jobs, group A gets a small uniform advantage.\n", + "C_fair[np.ix_(np.arange(half, n_fair), group_A)] -= 0.1\n", + "C_fair = np.clip(C_fair, 0.01, None)\n", + "\n", + "a_f = np.ones(n_fair) / n_fair\n", + "b_f = np.ones(n_fair) / n_fair\n", + "C_fair_j = jnp.array(C_fair)\n", + "a_fj = jnp.array(a_f)\n", + "b_fj = jnp.array(b_f)\n", + "\n", + "# Vanilla OTT-JAX Sinkhorn: ignores fairness.\n", + "geom_fair = geometry.Geometry(cost_matrix=C_fair_j, epsilon=0.005)\n", + "out_v = sinkhorn.Sinkhorn(threshold=1e-7)(\n", + " linear_problem.LinearProblem(geom_fair, a=a_fj, b=b_fj)\n", + ")\n", + "P_van = np.array(out_v.matrix)\n", + "\n", + "mass_A_v = P_van[:, group_A].sum()\n", + "mass_B_v = P_van[:, group_B].sum()\n", + "avg_A_v = (P_van[:, group_A] * C_fair[:, group_A]).sum() / mass_A_v\n", + "avg_B_v = (P_van[:, group_B] * C_fair[:, group_B]).sum() / mass_B_v\n", + "print(\"Vanilla Sinkhorn (no fairness):\")\n", + "print(f\" avg cost on group A = {avg_A_v:.4f}\")\n", + "print(f\" avg cost on group B = {avg_B_v:.4f}\")\n", + "print(f\" gap (B - A) = {avg_B_v - avg_A_v:.4f}\")\n", + "print(f\" total transport cost = {(P_van * C_fair).sum():.4f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "681fe945", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-02T06:40:31.352841Z", + "iopub.status.busy": "2026-05-02T06:40:31.352683Z", + "iopub.status.idle": "2026-05-02T06:40:35.108261Z", + "shell.execute_reply": "2026-05-02T06:40:35.107254Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "With fairness constraint (gap <= 0.05):\n", + " avg cost on group A = 0.2328\n", + " avg cost on group B = 0.2829\n", + " gap (B - A) = 0.0501\n", + " total transport cost = 0.2578\n", + " alpha (constraint dual, raw) = 47.56\n", + " cost overhead vs vanilla = 37.2%\n" + ] + } + ], + "source": [ + "# Fairness constraint: gap = avg_cost_B - avg_cost_A <= delta.\n", + "# With uniform marginals, sum_{i, j in B} P_{ij} = sum_{i, j in A} P_{ij} = 0.5.\n", + "# So 0.5 * gap = sum_{i, j in B} P_{ij} C_{ij} - sum_{i, j in A} P_{ij} C_{ij}\n", + "# = \n", + "# Encode as a single inequality:\n", + "# <= 0.5 * delta\n", + "# i.e. D = C * (1_B - 1_A); we want D . P <= 0.5 * delta.\n", + "delta = 0.05\n", + "indicator = np.where(np.arange(n_fair) >= half, 1.0, -1.0) # +1 for B, -1 for A\n", + "D_gap = C_fair * indicator[None, :]\n", + "D_solver_fair = (0.5 * delta * np.ones((n_fair, n_fair)) - D_gap) / n_fair\n", + "\n", + "# Warm up\n", + "_ = constrained_sinkhorn(\n", + " C_fair_j,\n", + " a_fj,\n", + " b_fj,\n", + " jnp.array(D_solver_fair[None, ...]),\n", + " jnp.zeros((0, n_fair, n_fair)),\n", + " eps=1.0 / 200.0,\n", + " n_iters=10,\n", + " n_newton=5,\n", + ")\n", + "\n", + "res_fair = constrained_sinkhorn(\n", + " C_fair_j,\n", + " a_fj,\n", + " b_fj,\n", + " jnp.array(D_solver_fair[None, ...]),\n", + " jnp.zeros((0, n_fair, n_fair)),\n", + " eps=1.0 / 200.0,\n", + " n_iters=250,\n", + " n_newton=10,\n", + ")\n", + "P_fair = np.array(res_fair.matrix)\n", + "\n", + "mass_A_f = P_fair[:, group_A].sum()\n", + "mass_B_f = P_fair[:, group_B].sum()\n", + "avg_A_f = (P_fair[:, group_A] * C_fair[:, group_A]).sum() / mass_A_f\n", + "avg_B_f = (P_fair[:, group_B] * C_fair[:, group_B]).sum() / mass_B_f\n", + "overhead = ((P_fair * C_fair).sum() - (P_van * C_fair).sum()) / (\n", + " P_van * C_fair\n", + ").sum()\n", + "print(f\"With fairness constraint (gap <= {delta}):\")\n", + "print(f\" avg cost on group A = {avg_A_f:.4f}\")\n", + "print(f\" avg cost on group B = {avg_B_f:.4f}\")\n", + "print(f\" gap (B - A) = {avg_B_f - avg_A_f:.4f}\")\n", + "print(f\" total transport cost = {(P_fair * C_fair).sum():.4f}\")\n", + "print(f\" alpha (constraint dual, raw) = {float(res_fair.alphas[0]):.2f}\")\n", + "print(f\" cost overhead vs vanilla = {100*overhead:.1f}%\")" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "a1a04b32", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-02T06:40:35.110685Z", + "iopub.status.busy": "2026-05-02T06:40:35.110033Z", + "iopub.status.idle": "2026-05-02T06:40:35.235825Z", + "shell.execute_reply": "2026-05-02T06:40:35.234940Z" + } + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAvMAAAHVCAYAAAB8GHRhAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAQ6wAAEOsBUJTofAAAlcFJREFUeJzs3XdUFNfbB/DvUpdeBBGUoiC2QBQUETsWbNiiiEYQSyzBSjR2QU0ssURjiYoFxd4rdkQTu6gxUbGgYkMRbChVdt4/fNmfm13asIjG7+ccznFvm2dWdnl29s69EkEQBBARERER0WdHo7QDICIiIiIicZjMExERERF9ppjMExERERF9ppjMExERERF9ppjMExERERF9ppjMExERERF9ppjMExERERF9ppjMExERERF9ppjMExERERF9ppjMExERERF9ppjME33mgoKCIJFIcO/evdIO5ZMWFhYGiUSCmJgYhXKJRIImTZoUqu3n7smTJwgKCoKdnR00NTWL9Xuj6nkjKkhERAQkEgkiIiJKOxQF9+7dg0QiQVBQUIkdQ53nLpFI4ODgUKi2MTExJX5uVLqYzNNHcfPmTQwfPhxff/01TE1NoaOjAysrK7Rs2RILFy7E69evP0ocX8qb2pdynlQ0QUFBWL16NWrVqoXx48cjNDQUpqampR0WlZLcBJYfysTJzMzE3Llz4enpKf+7Vq5cOdSqVQsDBw7EgQMHSjtE+kJolXYA9N83bdo0TJw4ETKZDHXq1EHPnj1hYmKCZ8+e4Y8//sCQIUMQFhaG5OTk0g71szR9+nSMGTMG5cuXL+1QPmmDBw+Gv78/7OzsSjuUUpGVlYXDhw+jSpUq2LVrV7HHu379OvT19dUQGX1JOnXqBE9PT1hbW5d2KMXy9u1bNGnSBBcuXICVlRU6deqEcuXK4dmzZ7h16xZWrVqFe/fuoVWrVvI+/5Vzp08Pk3kqUTNnzsT48eNRvnx5bNq0CfXr11dqc+LECQwZMqQUovtvsLa25h+HQrCwsICFhUVph1Fqnjx5AplMhnLlyqllvKpVq6plHPqymJiYwMTEpLTDKLb58+fjwoULaNmyJfbs2QMdHR2F+tTUVFy8eFGh7L9y7vTp4TQbKjEJCQmYOHEitLW1ERUVpTKRB4BGjRrh3LlzSuUxMTFo27YtypQpA11dXVSqVAnDhw/Hs2fPlNomJSXhxx9/RNWqVWFgYAAjIyNUrFgR/v7++OuvvwC8nwfdtGlTAMDq1ashkUjkP4Wdw/jkyRMMHz4clStXhlQqhZmZGVq2bImjR4+qbP/27VuMHj0a9vb20NXVhaOjI8LCwpCVlaXy6+385r/nNafz330KOs99+/ZBIpGgd+/eKmOWyWSwt7eHvr4+Xrx4Uajn5UOtWrWCRCJBbGysyvoDBw5AIpEgICBAXhYbG4thw4bh66+/hrm5OaRSKSpXrowffvhBZQy5c0/DwsJw+fJltG3bFqamptDX10fjxo1x6tQppT7qmAe/c+dO9OzZE87OzjAwMIChoSHc3d3x22+/QSaTFWqMjRs3QiKR5PkBViaToXz58jAwMEBqaiqA91fVFyxYAHd3d5QpUwZSqRTly5eHj48Ptm/fXuAxHRwcYG9vDwA4fvy4/Pch93fp1atXmDVrFry9vVGhQgXo6OjA0tIS7du3x5kzZ1SOmd+9BhEREYiKikKjRo1gbGwMMzMzpfpjx46hSZMmMDIygrGxMdq2bYvr16+rPFZGRgZmz54Nd3d3GBoawsDAALVr18aSJUsgCIJS+z179qB58+awsbGBrq4urKysULduXUybNk2hXWHeNwrr0aNHGD58OJydnaGnpwczMzO4u7tj0qRJyM7OVmh7584d9O7dW/5cW1lZwc/PD1euXFEaV8zvempqKn766Se4uLjAxMQEBgYGsLOzQ4cOHeS//xEREahYsSIAxd+J3GMBitNwHj9+jD59+sDa2hqamprYuXMnAPGv3X+/5zo4OEAikeDdu3eYNm0aKleuDF1dXdja2mL06NHIyspS+bzfvn0b/fr1k7/HWlpaolOnTkqJdK5nz55h4MCBsLa2hlQqRfXq1fHbb7+p/D3Kz8mTJwEAgwYNUkrkAcDIyAiNGzcu0XNXZfv27dDT00OlSpVw48YNpfp79+7B398fFhYWkEqlqF27Nvbu3atyrKysLMyaNQs1a9aEvr4+jIyM4OnpiZUrV6p8vnLn8b969QrDhw+Hvb09tLS0MG/ePIX6tLQ0jBo1CnZ2dtDV1YWTkxNmzpxZ5P8D+h9emacSs2rVKmRnZ6Nbt25wdXXNt62urq7C4+XLl6N///7Q09ND165dYW1tjVOnTmH+/PnYsWMHTp48iQoVKgAA0tLS4OXlhfj4eDRr1gzt2rUDADx48ADR0dFo3rw5vv76azRp0gT37t3D6tWr8fXXX6Njx47y49WsWbPA8/n777/RokULJCUloWXLlujQoQNSUlKwc+dOtGjRAsuXL0efPn3k7bOystCiRQucPn0a1atXx9ChQ5Geno7ff/+9yIlCURR0nq6urqhYsSI2bdqEX3/9VWnO9P79+3H//n0EBQXBzMwMERER6N27Nxo3blyoRDgoKAgHDx5EREQE3N3dlepXr14tb5crPDwcO3bsQOPGjdG8eXPk5OTg4sWLmDt3Lvbv34+zZ8/CyMhIaawLFy7gl19+Qb169dCvXz/cv38f27ZtQ7NmzXD58mVUqVKlUM9ZYY0ZMwYaGhqoW7cuypcvj1evXiE6OhrDhg3D+fPnERkZWeAYHTt2hKmpKTZs2IA5c+YoJQKHDx/G48eP0bNnT/k5BwUFYcOGDahevTq+/fZbGBgY4PHjxzh37hy2b9+Ozp0753vM4cOH4969e5g/fz7s7e3lz33u7/3169cxfvx4NGrUCG3btoWZmRnu37+P3bt3Y//+/dizZ4/CdIGCbNmyBQcPHkTbtm0xaNAgPH36VKF+79692LVrF1q3bo2BAwfi2rVriIqKwvnz53Ht2jWFb1BSU1PRvHlznDt3DrVq1ZLHfvDgQQwaNAhnzpxRSI6WLVuGAQMGwMrKCu3atUPZsmWRnJyMa9euYcmSJRg3bhyAwr9vFMaFCxfQqlUrpKSkoEGDBujYsSMyMjIQFxeH6dOnIyQkRP46u3jxIpo1a4aXL1+ibdu2cHV1RXx8PLZv3449e/Zg165daNmypcpjFOZ3XRAEtGrVCqdOnYKHhwf69OkDHR0dPHr0CH/88QeOHDmCJk2aoGbNmhg2bJjS7wQApQ9pKSkpqFevHkxMTNC1a1fIZDKYm5sDEP/azUuPHj3wxx9/oHXr1jA2NkZUVBR++eUXJCUlYdWqVQpto6Oj0aFDB2RkZKBdu3aoXLkyHj16hO3bt2P//v3YtWsXfHx85O2fP38OLy8v3L59Gx4eHujVqxdSUlIwceJEHDt2rNAxApCf/82bN4vULz9FOXdV5s+fj5CQELi5uWHv3r2wsrJSqE9ISICHhwcqVaqEgIAAPH/+HJs2bUKHDh1w5MgR+UUgAMjOzkbr1q0RHR0NZ2dnDBo0CFlZWdi+fTv69u2LP//8EytXrlSKITMzE97e3nj58iVat24NfX19+d/q3HFbtmyJx48fo3Xr1tDS0sLOnTsxZswYZGRkIDQ0tBjP4BdMICoh3t7eAgAhPDy8SP3u378v6OjoCAYGBsI///yjUDdhwgQBgNC2bVt52e7duwUAwrBhw5TGevfunfDixQv542PHjgkAhF69ehUppnfv3gnOzs6Crq6uEBMTo1D3+PFjoUKFCoKenp7w9OlTefn06dPlsb57905e/uzZM8HBwUEAIDRu3FhhrF69egkAhLt37yrFcPfuXZWxq+pT0Hn+8ssvAgBh3rx5SnXt2rUTAAhnzpwRBEEQVq1apTLWvKSnpwumpqZCmTJlhMzMTIW6ly9fClKpVLC3txdkMpm8/N69ewrPUa7ly5cLAIQZM2YolOfGBEBYtWqVQt2SJUsEAMKgQYMUykNDQwUAwrFjxxTKVZ1bXm1v376tFGNOTo4QGBio8JwVZODAgQIAYevWrUp13bt3FwAIR48eFQTh/XMmkUgEd3d3ITs7W6n9s2fPCnXM3N8fVf+PL1++VDnOgwcPBGtra6Fq1apKdfk9bxKJRNi/f79Sn9x6TU1N4ciRIwp1Y8aMEQAIM2fOVCjv27evyt+BjIwMoU2bNgIAYffu3fJyNzc3QUdHR3jy5InS8T88x6K8b+QnMzNT/npevXq1Un1iYqL8/00mkwnVq1cXAAgREREK7Q4fPixIJBLB0tJSePv2rby8qL/rV65cEQAIHTp0UIpFJpMJycnJ8sf5/U58WA9ACAgIUPn7J/a1++9zsbe3FwAIbm5uQkpKirz8zZs3gqOjo6ChoSEkJibKy1++fCmUKVNGMDc3F65evaow1rVr1wRDQ0PB2tpayMjIkJcPGDBAACAMHDhQof3t27cFExOTIv1t2LNnjwBA0NHREQYOHCjs3r1bePjwYb591HXugvD+9Wdvby8Iwvv/15CQEPnfmzdv3ii0zf17AEAICwtTqDtw4IAAQGjdurVC+YwZMwQAQsuWLRXex1++fCnUqFFDACBs2bJFKSYAQrNmzRR+h/9d37p1ayEtLU1e/vTpU8HExEQwMTERsrKy8nj2KD9M5qnEVKtWTQCg8o96fn766ScBgDBq1CiluoyMDMHGxkYAIDx69EgQhP/9UR4zZkyBY4tN5nOPMWLECJX18+fPFwAIixYtkpdVrlxZkEgkwvXr15Xar1y5slST+eTkZEEqlQrVqlVTKL9//76gqakp1KpVS1728uVL4fr160JCQoLKsVTJTVa3bdumUL506VIBgDBx4sRCjSOTyQRjY2OhadOmCuW5fxTr16+v1CcrK0vQ0tIS3N3dFcrVkcznJTY2VgAgTJ48uVDtz5w5IwAQfH19FcpfvXol6OnpKXzYef36tQBAqFevnsIHoKIqKHHLy5AhQwQASv//+T1vHTt2VDlWbv23336rVHfnzh0BgPDNN9/Iy1JSUgQtLS2F38cP/fXXXwIAoWvXrvIyNzc3QV9fXyEpUqUo7xv52bp1qwBAaNOmTYFt//zzTwGAUKdOHZX1nTt3FgAI69evl5cV9Xc9N5n39/cvMJ7CJvM6OjoKFyoKo6DXbl4J7eHDh5XGmjRpkgBA2LNnj7zst99+EwAI8+fPV3n8ESNGCACEffv2CYLw/rkyMDAQDAwMFD7Q/PsYRfnbMH/+fPmHgNwfS0tLoUuXLvLjlsS5C8L/kvn09HShS5cuAgChf//+Kj9Y5f49sLe3V1lvZ2cnlClTRqHMyclJAKB0QU0Q/vfaadmypVJMAITLly8r9fmw/tatW0p1uRdE/v77b5V9KX+cZkMlRvj/+W8SiaRI/XLnOnp7eyvV6erqokGDBti8eTMuXboEGxsbNG7cGLa2tpg5cyYuXLiAtm3bwsvLC25ubtDSUs+veO78yPv378vnk37o1q1bACCf85uamopbt26hXLlyKm8ULO2l4MqUKYNu3bph9erVOH78uHxu5/Lly5GTk4OBAwfK24q5aSsoKAhLlixBRESEwhSQ3Ck2vXr1UmifnZ2NpUuXYuPGjbh69Spev36tMAf90aNHKo9Tu3ZtpTJtbW1YWVmJmu9fkJSUFMyaNQtRUVG4c+cO3r59q1CfV5z/VrduXVSrVg379+9HUlISypYtCwDYtGkT0tPTERgYKH/dGBkZoUOHDti1axdcXV3RuXNnNGjQAPXq1YOhoaHazu3kyZOYP38+Tp8+jaSkJKV5uo8ePSr0SkAeHh751qv6f7O1tQUAhf+3c+fO4d27d9DQ0FD5usudi/7hXPuAgACMGDEC1atXR7du3dCoUSN4eXkp3SRelPcNVccOCgqCg4OD/J6C1q1b53vOQP7vbQDQvHlzbN++HRcvXkT37t0V6gr7u169enW4ublh48aNuHfvHjp06ID69eujTp06kEqlBcaoioODg/x39N/EvnbzUtjfjdz35CtXrqj8/8mdL379+nW0adMGcXFxePv2LTw9PVGmTBml9k2aNMGUKVOKFOvQoUPRr18/HD58GKdOncKlS5dw6tQpbN26FVu3bkWfPn2wfPnyQv8NLOy550pPT0fz5s1x8uRJ/PTTTxg/fny+49esWROampoqj3H69Gn549TUVNy+fRtWVlaoUaOGUvtmzZoBgMr7EqRSab7Tak1MTODk5KQyBkD1eVLBmMxTibGxsUFcXBwePHhQpH6vXr0CgDxX3cj9o/zy5UsAgLGxMc6cOYPJkydj165dOHLkCADA1NQUffr0wdSpU4u9hF5KSgoAYNu2bdi2bVue7d68eVOkcyhNwcHBWL16NZYuXYrGjRsjJycHK1asgJGREXr06FGssevWrYvq1asrJKu3b9/GqVOn0KhRIzg6Oiq079atG3bs2IFKlSqhY8eOKFeunPw+innz5iEzM1PlcfJaI11LSws5OTnFOod/e/nyJerUqYO7d+/Cw8MDgYGBMDc3h5aWFl6+fIn58+fnGacqQUFBGD16NNauXYuQkBAA/7th+d8fdjZu3IjZs2dj3bp18oRDW1sbvr6+mDNnTqE3j8nLjh070KVLF0ilUrRo0QKOjo4wMDCAhoYGYmJicPz48SKdW0Er5qj6f8tNoD/8f8t93cXGxuZ5QzXwv9cd8P7+gLJly2Lx4sVYtGgRfvvtNwCAp6cnpk+fLv8gXZT3jcmTJysds0mTJnBwcJC/DxVmadiivrd9qLC/65qamjh69Ch+/vlnbNmyBWPHjgUA6Ovro1u3bvjll1+KvKpTfv+fYl+7eSnq78aKFSvyHa+k35P19fXRoUMHdOjQAcD7e6XCw8MxbNgwrFy5Eu3bt5fXFaSw557rzZs3iI2NhZGRUaE+TOb3O/ThB7CCnit9fX2YmJio/D0tW7Zsvh9e8osBUH2eVDCuZkMlpmHDhgCQ50ovecm9CvzkyROV9YmJiQrtgPcfHJYuXYrExERcu3YNv//+O+zt7TF37ly1LHuZe6xt27ZBeD89TeVP7k1KhT2Hf9PQeP+SfPfunVKdqjfO4qhTpw7q1KmDbdu24dmzZ9izZw8ePXqEnj17quWKb69evfDu3TusXbsWAOQ3Kf57NZ4LFy5gx44daNasGeLi4rBq1SpMnz4dYWFhmDRpUpFWcihJy5cvx927dzFp0iScPXsWixcvxk8//YSwsDB069atyOMFBARAU1NT/m3F7du3cfLkSTRs2FDpw45UKsWECRNw/fp1PHr0CBs3bkSrVq2wfft2tGrVSmm1lKKaOHEidHR0cOHCBezcuRNz5szBlClTEBYWJuom4qJ+G5eX3NfRkCFD8n3d3b17V6Ffjx498Oeff+L58+c4cOAAgoODcfHiRbRu3VrhhsXCvm+oOmbuh4Lc5KQwV6DFvLeJYWpqilmzZuHevXu4c+eO/Gb0VatWifpdzev/szRfu7nPUWxsbL6/G7k3VIp9Ty4qHR0dBAcHy79ZKerfv6KwtLTEvn37IJPJ4O3tnefKU0VV0HOVlpaGV69eqfw9Vddrn4qGyTyVmN69e0NbWxvbtm3DP//8k2/bD6/euLm5AYDK1QUyMzPlX6/mtvuQRCJBtWrVMHDgQPzxxx/Q1dXFjh075PW5XzEW9dN/vXr1AAB//PFHodobGRmhcuXKePr0KeLi4pTq81oVJncJP1XfZpw/f76Q0Rb+PL///ntkZWVh1apVWLp0KQAoTLEpjg+TVUEQEBkZCQMDA3Tt2lWh3e3btwEAHTp0gLa2tkLduXPnkJ6erpZ4iis3zi5duijVHT9+vMjjWVtbo2XLlrhy5QouXbqkcpUfVWxsbNCtWzfs3r0b9erVw40bN/Jc0rGwbt++jerVq6NatWoK5TKZDH/++Wexxi6OunXrQkNDo9Cvu38zNjaGj48PFi5ciB9++AEZGRkqd+Us6H0jP56engDerwJVkPze24D/JX6qVoESq2LFiujVqxeio6Nha2uL6Oho+ZVXse+HuUrztVvU9+SqVatCX18ff//9N54/f65UX5wla1XJXcEnd7ppSfH29sbBgwchCAJatGiBEydOFHtMIyMjODk54enTp7h27ZpSfXR0NAD1/p5S8TCZpxJjb2+PqVOnIjs7G23atMnzqsHJkyflfxABoGfPntDR0cHixYuVEuHp06fj0aNHaNOmDWxsbAAA//zzj9KVOeD917DZ2dnQ09OTl+XOlbx//36RzqV9+/ZwcnLCkiVLsHv3bpVtLl26JP/qFwD69OkDQRAwcuRIhT+WycnJec7NrFu3LgBg6dKlCn8EEhISVH7Vn5fCnqe/vz/KlCmDefPm4dChQ/Dy8lKa7/jq1SvExcUV+TmztraGj48Prly5grlz5+L+/fvo0qWL0lX/3Cki//5jmpSUhODg4CIdsyTlFeelS5cwffp0UWPmrvW/atWqPD/sPHv2TOVrJzMzU/5tzYe/42I4ODjg1q1bePz4sbxMEARMnjxZ5R/zj8XS0hIBAQG4fPkywsLCVH5j9fDhQ4X3if3796v8piL3KmPuc1WU9438+Pr6wsHBAVFRUSqXJn369Kk8bi8vL1SrVg3nzp2Tf2OVKzo6Gtu3b4eFhUWhp2WocvfuXZUXT1JTU/H27Vtoa2vLpzSYmZlBIpEUeSpkrtJ87fbu3RtmZmaYOnWqwnzvXIIg4M8//5R/O6CtrY2AgAC8fftWaW55fHw85s+fX6TjL1myJM+/aXFxcdiyZQuA9/uolLT69evj6NGj0NHRQevWrXH48OFij9m3b18AwA8//KDwenr9+rV8edd+/foV+zikHpwzTyVq9OjRePfuHSZNmoR69erBw8MDHh4eMDExQXJyMk6ePIl//vlHYQ6nvb09fvvtNwwaNAi1a9eGn58fypUrh1OnTuH48eOoUKECfv/9d3n7I0eOICQkBJ6enqhWrRqsrKzw5MkT7Nq1CzKZTD5nFACqVKkCW1tb/PHHH/j222/h7OwMTU1NtG/fPt+bdrS1tbFjxw75+vJ169aFm5sbDA0N8eDBA1y6dAk3btzApUuX5Il0SEgIdu/ejX379sHV1RVt2rRBeno6tmzZAi8vL5UbQ3Xo0AFVq1bFpk2b8ODBA3h5eeHx48fYvXs32rVrh40bNxbqeS/seUqlUvTp0wezZs0CoPqq/I4dO4q0zvyHgoKCEBUVJf8/UHXVuU6dOqhfvz62b98OLy8vNGjQAE+fPsX+/ftRpUoV+Ye20hYYGIhZs2Zh+PDhOHbsGCpXroxbt25h79696Ny5MzZt2lTkMdu3bw9zc3MsWbIE2dnZ6NWrl9KHnUePHqFevXqoUqUK3N3dYWtri7dv3+LgwYO4desWvvnmG1SuXLlY5zZixAgMHDgQtWrVwjfffANtbW2cPHkS165dg6+vL/bs2VOs8YtjwYIFuHXrFiZPnozIyEg0atQI5cqVw5MnT3Djxg2cOXMGc+fOld9o3r17d+jo6KBhw4byzXjOnTuHP/74A46OjvDz8wNQtPeN/Ojo6GDLli3w8fFBYGAgwsPDUa9ePWRlZeHGjRs4cuQIkpKSYGpqColEgtWrV6N58+YIDAzE5s2b4eLigvj4eGzbtg06OjpYs2ZNse7x+euvv9CpUyfUqlULLi4usLGxwYsXL7B37148f/4cP/zwAwwMDAAAhoaGqFevHk6dOgVfX1+4u7tDS0sLjRo1KlQSWpqvXXNzc2zbtg0dO3aEl5cXvL29UaNGDWhra+PBgwc4e/Ys7t+/jxcvXsj3cpg2bRqOHj2KJUuW4NKlS2jatCmSk5OxefNmeHt7yzfCKowDBw5g0KBBcHBwgJeXF+zs7JCZmYlbt27h4MGDyM7ORufOnVV+k1cSateujWPHjqFFixbw9fXF1q1b5XsniBESEoIDBw7gwIEDcHFxQbt27ZCdnY1t27bh0aNHCAwMVLrwQKWopJfLIRIEQbhx44YwbNgwwcXFRTA2Nha0tLQES0tLwdvbW/jtt9+E169fK/U5evSo0KpVK8HMzEzQ1tYWHBwchCFDhiitH33t2jVhxIgRQu3atQVLS0tBR0dHsLW1Fdq1ayccOnRIadzY2FihefPmgomJiSCRSFQuFZaXZ8+eCePHjxdcXFwEfX19QU9PT6hUqZLg6+srhIeHK6ydKwiCkJqaKowaNUqwtbUVdHR0hEqVKgmhoaFCZmZmnkvCPXr0SOjRo4dgbm4u6OrqCi4uLsLy5cuLtDRlUc4zLi5OACCUKVNGSE9PV6ov6jrzH8rIyBDMzc0FAELFihXzXFoxJSVFGDRokGBvby/o6uoKlSpVEsaOHSu8fftWsLe3l6+n/O+YQkNDVY6nqo86lqa8evWq4OvrK1haWgr6+vqCm5ubEB4enuf/TWEEBwfLl2xTtRTmixcvhClTpghNmzYVypcvL+jo6Ahly5YVvLy8hPDwcJVrf6tS0DKEq1atEr7++mtBX19fKFOmjNCxY0fhypUrop63vF5PBdXnFV9WVpbw+++/Cw0aNBBMTEwEHR0doUKFCkLDhg2F6dOnK6zv/fvvvwudOnUSKlWqJOjr6wsmJiaCi4uLEBoaqrAkYVHfNwpy//59ITg4WKhYsaKgo6MjmJmZCe7u7kJoaKjS2tm3bt0SevXqJdjY2Aja2try5QwvXbqkNG5Rf9cfPHggjBs3TvDy8hLKlSsn6OjoCNbW1oK3t7ewefNmpf7x8fFCx44dhTJlyggaGhoKxyrMcqZiX7t5Lc+oSl59BEEQEhIShKFDhwrOzs6CVCoVDA0NhcqVKwt+fn7CunXrhJycHIX2SUlJQv/+/QUrKytBV1dXqFatmjBv3jz50qiFfQ3fuHFDmDNnjtCmTRvByclJMDAwELS1tQUbGxuhTZs2wvr165Xe79R57vhgnfkPXb9+Xf57lbsOfEFLFTdu3Fjl8TMyMoQZM2YILi4uglQqFfT19QUPDw8hPDxc5Xt5XjEVpr6oywGTIokgcP9cotIgkUhEXe1Wt40bN6J79+744YcfMHv27FKNhYiIiIqGc+aJvmA5OTmYNWsWNDQ08P3335d2OERERFREnDNP9AX6888/cfz4cRw/fhwXL15Ev379UKlSpdIOi4iIiIqIyTzRF+jIkSOYPHkyzMzM0Lt3b8ybN6+0QyIiIiIROGeeiIiIiOgzxTnzRERERESfKSbzRERERESfKSbzavLu3Ts8fPhQ5Q6FREREREQlgcm8mjx58gS2trbyLcOJiIiIiEoak3kiIiIios8Uk3kiIiIios9UsdaZz8rKwvLly7F3714kJCQAAOzt7dGuXTv07dsXurq6agmSiIiIiIiUib4yn5iYCDc3NwwePBgXL16EqakpTExMcPHiRQwePBhubm54/PixOmMlIiIiNVm0aBEcHBwglUrh6emJ8+fP59l2165dqFOnDkxNTWFgYICaNWsiMjJSoc327dvh4+MDCwsLSCQS/PPPPwr1MTExkEgkKn/yOzYR5U90Mj948GDEx8dj3bp1SExMxMmTJ3Hq1CkkJiZi7dq1uHPnDoYMGaLOWImIiEgNNm3ahJCQEISGhuLixYtwdXWFj48PkpOTVbY3MzPD2LFjcfr0aVy5cgV9+/ZF7969ceTIEXmbt2/fon79+pgxY4bKMby8vJCYmKjw069fPzg4OKB27dolcp5EXwLRO8AaGRnh+++/x8yZM1XW//jjj/j999+RmpparAA/Fw8fPoStrS0ePHiAChUqlHY4REREeapbty48PDywYMECAIBMJoOtrS1GjBiBkSNHFmoMNzc3dOjQAaGhoQrl9+7dQ8WKFfH333/jq6++yrN/dnY2KlSogMGDB2PixIniT4boCyf6yryenh5sbW3zrLezs4Oenp7Y4YmIiKgEZGVlITY2Fi1btpSXaWhooHnz5jh9+nSB/QVBwNGjR3Hjxg00bNhQdBy7d+9GcnIygoKCRI9BRMW4AdbPzw8bNmzAgAEDoK2trVCXlZWF9evXo1u3bsUO8L9GEAQkJycjIyMDOTk5pR0O0WdNU1MTUqlUPkeXiAqWnJyMnJwcWFlZKZRbWVnh9u3befZ79eoVypcvj8zMTGhqamLJkiXw9vYWHceKFSvg4+OT74VBIiqY6GS+S5cu+OOPP1CnTh0MHDgQTk5OkEgkuHnzJpYuXSpvc+7cOYV+Hh4exYv4MyYIAh49eoTU1FTo6OhAU1OztEMi+qxlZWXhzZs3yMzMRPny5ZnQExXBv18vgiDk+xoyMjLC5cuX8ebNGxw9ehTDhw+Ho6OjqKvzDx8+xMGDB7F58+Yi9yUiRaKT+Q8/jX///ffyN4APp+B/2Cb3TeJLvhqdnJyM1NRUlC1bFmXKlCntcIj+E1JSUpCUlITk5GRYWlqWdjhEnzwLCwtoamoq7VielJSkdLX+QxoaGnBycgIA1KxZE9evX8eMGTNEJfOrVq1CmTJl0L59+yL3JSJFopP5VatWqTOOL0JGRgZ0dHSYyBOpUZkyZfDy5UtkZGSUdihEnwUdHR24u7vj8OHD8PX1BfD+Btjcq+2FJQgCMjMzi3x8QRCwatUqBAYGKk3TJaKiE53M9+rVS51xfBFycnI4tYaoBGhqan7R3/oRFVVISAgCAwPh7u4ODw8PzJs3D2lpafKbUQMDA1G+fHlMnz4dADBjxgy4ubnByckJWVlZ2L9/P9asWYNly5bJx3z+/Dnu378v32Pmxo0bePfuHezs7GBubi5vFx0djbt376Jv374f74SJ/sOKtQMsERERfX66deuGZ8+eYdKkSXjy5Alq1qyJAwcOwMLCAgBw//59aGj8b8G79PR0BAcH4+HDh9DT00PVqlWxdu1ahYUudu/ejd69e8sfd+nSBcD7b/I/XLFmxYoV8PLyQrVq1Ur4LIm+DKLXme/Tp0/Bg0skWLFihZjhPzuFWWf+3r17AAAHB4ePFxjRF4CvLSIi+lKJvjIfHR2tdNd7Tk4OEhMTkZOTA0tLSxgYGBQ7QCIiIiIiUk30plH37t3D3bt3FX7u37+PtLQ0/PrrrzA2NkZMTIwaQyXKW1hYmNKHyyZNmqBJkyYKZRKJBGFhYR8vMDUdu0mTJqhatWqhjjFw4EBRxyAiIqLPj+hkPi/a2toYNmwYmjZtiiFDhqh7eKJPzv79++Ht7Y1y5cpBT08P9vb26NChA9avX1/aoREREdF/XIndAFu7dm388MMPJTX8f5L37JjSDgHRI5uUdgiiTJgwAWPGjPnox507dy5++OEHeHp6YuTIkTA2Nsbdu3dx8OBBhIeHo0ePHvK26enp0NLiPedERESkPiWWWZw5cwa6urolNTyRAi0trY+eKL979w5TpkxB48aNER0drbDyw/Tp0+XLs+WSSqUfNb6SkpOTg5ycHOjo6JR2KESfjE/hYgx9PJ/rhS/6bxI9zWbNmjUqf3777Td07twZERERCktW0X/b1q1bIZFIcPDgwQLrEhISEBwcjGrVqkFfXx+mpqbw9fXF1atXFfrFxMRAIpFgw4YNmDt3LhwcHCCVSlGvXj1cunRJoa2qOfOF8fz5c4waNQqurq4wMjKCoaEhmjZtipMnTxbYNzk5Ga9evULDhg0VEvlcNjY2Co//PWc+IiICEokEf/zxB8aNGwdra2vo6emhZcuWuHv3boHH//PPP2FiYoJWrVohPT1doW7//v2oVasWpFIpKleujA0bNij1v3//Pr799ltYWFhAT08PtWrVwrp16xTa3Lt3DxKJBDNmzMDvv/+OypUrQ1dXF6dOnSp2/ERERFR8oi9lfrhm7L9ZWFhgwoQJGD9+vNjh6TPTrl07GBkZYdOmTfDx8VGo27hxIywsLNCsWTMAwPnz5xETE4NOnTrBwcEBiYmJWLp0KRo1aoSrV6+iXLlyCv3nzJmDzMxMDBkyBNnZ2Zg1axY6duyI27dvF3v3wDt37mDLli3o0qULnJyc8OrVK6xYsQLNmjXDhQsX8NVXX+XZt2zZstDT08OePXswfPhw0Tv7jhgxAlKpFGPHjkVycjJmz56Nb7/9FqdOncqzz9GjR9G+fXu0aNECmzdvVrhKfubMGezYsQODBg1C3759sXz5cvTs2RO1atWS30SbnJyM+vXr4/nz5xgyZAhsbGywceNG9OzZEy9evMDgwYMVjrd27Vq8ffsW/fv3h6GhIaytreXLQYqJn4iIiNRDdDKv6sqbRCKBmZkZjIyMihUUfX6kUik6dOiAHTt2YMmSJfLk8s2bN4iKikJQUJB8GkybNm3km4nkCggIQPXq1bFixQqlD4EvXrzAP//8Az09PQBA1apV0alTJxw6dAht27YtVtwuLi6Ij49X2Jm3f//+qFq1KubPn4/w8PA8+2poaGD06NEICwuDnZ0dGjdujHr16qFFixaoW7duob8p0NPTQ0xMjDwGc3NzjBgxAlevXkWNGjWU2u/btw9dunRBx44dERkZqTS96OrVq7h8+bK8b9euXWFvb48VK1Zg1qxZAN7v5vjw4UMcOXJE/iFr4MCBaNiwIcaOHYvAwEAYGxvLx0xISMCtW7cUPmidPn1aVPxERESkPqKn2djb2yv92NnZMZH/gvn7++Ply5c4dOiQvGzXrl1IT0+Hv7+/vExfX1/+77S0NKSkpMDY2BhVqlRBbGys0rh9+vSRJ/IA0LhxYwDvr6oXl66urjwJzcjIQEpKCnJyclCnTh2VsfxbaGgoIiMjUbNmTRw+fBiTJk1CvXr1UKVKFZw5c6ZQMQwcOFDhw0R+57d161Z06tQJ3bt3x7p161TeJ9CkSROFJNrKygpVqlRRGG/v3r2oVauWPJEHAB0dHQwfPhxv3rxRWla2U6dOSt+YiImfiIiI1KvYdwzeuXMHUVFRuHv3LiQSCRwcHNCmTRtUqlRJHfHRZ6Rly5YwNzfHxo0b0a5dOwDvp9hUqFABDRs2lLfLyMjApEmTsHbtWiQmJiqMoWqqyr939TQzMwPwfr57cclkMvzyyy9YtmyZ0rdNFStWLNQYPXv2RM+ePfH27VucPXsWmzZtwvLly9G2bVvExcXB0tIy3/6FPb/79+/D398f7dq1w4oVK/K88q9qF1QzMzOF8e7du4fOnTsrtatevToA5W/e8nsuSvL/h4iIiPJXrGT+xx9/xNy5cyGTyRTKhw8fjh9++AEzZ84sVnD0edHW1sY333yDjRs3IiMjAxkZGTh06BAGDx6skHgOHToUK1aswJAhQ1C/fn2YmJhAQ0MDw4cPV/pdAqBw1fdDgiAUO+YZM2Zg/Pjx6NWrF3766SeUKVMGmpqamD59OuLj44s0loGBAby9veHt7Q0rKytMnToV+/fvR2BgYL79Cnt+VlZWcHBwwOHDh3Hy5Ek0aNCgWOPl1+bfHxQ+/GZEnccjIiKi4hGdzM+fPx+zZ89Gp06dMGrUKPkVvWvXrmH27NmYPXs2KlSowI2jvjD+/v4IDw9HVFQUXrx4gaysLIUpNgCwefNmBAYGYt68eQrlL168gIWFxUeM9n0sTZo0QUREhEJ5aGhoscb18PAAAKXlKYtDV1cXe/bsQYsWLdC2bVscPXoUtWvXFjWWg4MD4uLilMpzy1Rd3SciIqJPj+g580uXLkWbNm2wbds2eHp6wtjYGMbGxvD09MTWrVvRqlUr/P777+qMlT4DTZo0gbW1NTZt2oRNmzbB0dERderUUWijpaWldNV2w4YNak18C0tVLKdOnZLf3JmftLS0PJewjIqKAgD56jHqYmhoiP3796NSpUrw8fHBP//8I2qcdu3a4dKlSzh27Ji8LDs7G/Pnz4ehoSGaNGmipoiJiIioJIm+Mn/nzp18r7q3bdsWISEhYoenz5SGhga6du2K8PBwZGVlqdyV1dfXF2vWrIGxsTG++uorXL58GZs2bSqV+yx8fX0RFhaGwMBANGzYELdu3cKyZctQvXp1vHnzJt++aWlpaNCgAerUqYPWrVvD3t4eqampOHLkCPbu3Yu6devK7x1QJ1NTUxw6dAiNGzdGixYtcOLECVSuXLlIY4wZMwYbN26Er68vhg4dKv8AdubMGSxYsEBhJRsiIiL6dIm+Mm9ubo4bN27kWX/z5k2Ym5uLHZ4+Y927d0d6ejpycnKUptgA76do9e3bF5s2bZIvYXjgwAHY2tp+9FjHjh2LH3/8EdHR0Rg6dCiOHTuGjRs3Fmr6iqmpKZYvX44KFSpgzZo1CA4OxtixY3Hv3j1MmjQJR44cKbFdaS0tLXHkyBEYGBigWbNmSEhIKFJ/CwsLnDp1Cu3bt8fSpUsxatQovHnzBpGRkUprzBMREdGnSyKIvEtt0KBBWLFiBZYtW4ZevXrJb5gTBAFr1qxB//790bdvXyxevFitAX+qHj58CFtbWzx48AAVKlRQ2SZ3kx3ORyZSL762qLR5z44p7RDoI4oe2aS0QyCSE33ZcNq0aTh9+jT69u2L0aNHy7/mv337NpKSkvD111/j559/VlugRERERESkSPQ0GzMzM5w7dw7z58+Hm5sbUlJSkJKSglq1amHBggU4c+aMfL1pIiIiIiJSP1FX5jMyMvDLL7/A09MTgwcP5hxbIiIiIqJSIOrKvFQqxfTp03H//n11xwPg/RJ5kyZNgp2dHaRSKVxdXbF+/foC+927dw89evRA5cqVYWhoCFNTU3h4eGDNmjUqN7BJTU3FkCFDUK5cOejp6cHT0xOHDh0qiVMiIiIiIlI70XPma9asWeQdMgurf//+8tVBXFxcsHPnTnz77bd49+5dvrtpJiYm4unTp/D394etrS2ysrJw+PBh9OrVC9euXcOMGTPkbQVBQMeOHXHq1CmEhITAzs4Oq1evRps2bXDo0CF4e3uXyLkREREREamL6NVsjh8/ji5dumDNmjVo3bq12gK6dOkS3NzcMGXKFEycOBHA+8Tb29sb165dw4MHD6Cjo1OkMX19fXH06FG8evUK2traAIAdO3agc+fOWLNmDQICAgAAmZmZcHV1hZ6eHi5fvlykY3A1G6LSw9cWlTauZvNl4Wo29CkRfWV+5syZMDU1Rbt27WBra4uKFStCT09PoY1EIsG+ffuKNO7mzZuhoaGB4OBghXEGDx6MLl264NixY/Dx8SnSmPb29khPT0dmZqY8md+8eTPMzc3Ro0cPeTtdXV30798fI0eOxI0bN1ClSpUiHYeIiIiI6GMSncxfu3YNEokEdnZ2AP53ZexDuWvPF0VsbCwcHR2VNpyqW7cuAODixYsFJvNpaWlIS0tDamoqjh49ipUrV6JevXowNDRUOI67uzs0NTXzPA6TeSIiIiL6lIlO5lUl7+rw+PFjWFtbK5Xb2NjI6wsybdo0hTXumzdvjpUrVyodp169eqKP8/r1a7x+/Vr+ODExscC4iIiIiIjUqWT2mi+G9PR06OrqKpVraGhAW1sb6enpBY7Ru3dvNGnSBElJSdi5cyeSkpKQlpZWqONIpVJ5fX7mzp2LyZMnFxgLEREREVFJEZ3MF7QspUQigVQqhYWFRZGm20ilUmRmZiqVy2QyZGdny5Pt/Dg6OsLR0REA0KNHDwQFBaFFixa4efOmvH9ex8nIyJDX5yckJAT9+vWTP05MTISHh0eBsRERERERqYvoZN7BwaFQSbpUKkXjxo0xadIkeHp6FtjexsYGCQkJSuW5015yp8EUhZ+fH1avXo3jx4/L59vb2NionBpT2OMYGxvD2Ni4yLEQEREREamL6GR+xYoVWLBgARISEtCjRw84OztDEATcvHkTGzZsQMWKFREUFIRbt25h7dq1aNKkCY4cOYIGDRrkO66bmxuio6Px/PlzhZtgz549K68vqtwpM69evVI4ztGjR5GTk6NwE2xxjkNERERE9DGJ2gEWAJKTk5GWloZbt25hwYIFGDJkCIYOHYqFCxfixo0bSE1NRWZmJubPn4+4uDhYWloiLCyswHH9/Pwgk8mwePFieZkgCFi4cCEsLS3RtGlT+fHj4uIU5sInJSUpjScIAlauXAmJRKKQoPv5+SElJQUbNmyQl2VmZmLZsmVwcXFB1apVxTwtREREREQfjehkftGiRfjuu++UlpAEAAsLC/Tr1w8LFy4EAFhaWqJv3744f/58geO6u7sjICAAoaGhGDZsGJYvXw5fX1/ExMRg5syZ8ptWFy5ciGrVquHcuXPyvqNHj0aDBg0QGhqK5cuXY8aMGahduzaioqIwePBgODk5ydt27twZjRs3xnfffYcJEyZg2bJl8Pb2Rnx8PObOnSv2aSG83/irUaNGMDQ0hEQiQUxMTKH7RkREQCKRlNhqSZ+CnTt3wsjICM+fPy/tUD5p8+bNg4ODA7Kysko7FCKiz96iRYvg4OAAqVQKT0/PfHOy8PBwNGzYEGZmZjA3N0fLli1x4cIFhTaJiYno0aMHrKysYGhoCA8PD0RFRSm06dSpE+zt7SGVSmFtbY2AgIBCrUpIRSN6ms3Tp0+RnZ2dZ/27d+/w5MkT+ePy5csjJyenUGMvX74c9vb2iIiIwJIlS+Ds7IzIyEj07Nkz336dOnXCkiVLEB4ejuTkZOjp6cHV1RWrVq1Cr169FNpKJBLs3r0b48aNQ3h4OF6/fg0XFxfs3bsXzZs3L1Sc6ua7w7dUjvuhPZ32FKu/TCZDt27d8O7dO8yePRuGhoaoVq2amqL7/MlkMkycOBEDBw5U+UEYAFavXo2goCA4Ojri9u3bHy227OxsTJ06FREREUhKSoKzszPGjBmjsLGaOvrHxMTIv2H7t8OHD8tff/3798fPP/+MpUuXYsiQIcU7OSKiL9imTZsQEhKCJUuWoG7dupg3bx58fHxw8+ZNWFhYKLWPiYlB9+7d4eXlBalUipkzZ6JFixa4du2afPnwgIAAvH37Fnv37kWZMmWwatUqdOrUCVevXpVfPG3SpAl+/PFH2NjY4PHjxxg5ciT8/Pzw559/ftTz/6+TCIIgiOlYt25dPHnyBCdPnkSFChUU6h48eID69evDxsYGZ86cAQCMGjUKu3fvxo0bN4of9Sfo4cOHsLW1xYMHD5Sej1wFbTn/X0jmExIS4ODggF9//RXDhw8vcv+cnBxkZ2dDV1dX1KZjn7p9+/ahXbt2uH37tnzFpX9r0aIFTpw4gaysLJw6dUrlfggloXfv3lizZg2Cg4Ph4uKCnTt3IioqCqtXr0ZgYKDa+ucm88HBwUo3xTdv3hzlypWTPw4JCcG2bdtw9+5daGjk/UViQa8topLmPTumtEOgjyh6ZJPSDqFI6tatCw8PDyxYsADA+wtLtra2GDFiBEaOHFlg/5ycHJiZmWHJkiXyCzSGhoYIDw9H9+7dAbyf1qynp4e1a9eiS5cuKsfZvXs3OnfujMzMTKVNO0k80Vfm58yZAx8fHzg7O6N9+/aoXLkyAODWrVvYvXs3BEHA+vXrAbxf7jEyMhLt27dXT9T0yXr69CkAwNTUVFR/TU3NAl/gaWlp0NfXFzV+aVu5ciVq166dZyKfmJiI6OhojBs3DvPnz8fatWs/SjJ/6dIlREREYMqUKZg4cSIAoF+/fvD29saoUaPg7+8PHR0dtfZv0KAB/P39842rW7du+PXXXxEdHV1q35gREX3OsrKyEBsbiwkTJsjLNDQ00Lx5c5w+fbpQY6SlpSE7O1vhG2UvLy9s3LgRrVq1gomJCdasWQNdXV3Ur19f5RjPnz/HunXr0KBBAybyaiZ6znyDBg1w6tQptGjRAnv27MHPP/+Mn3/+GXv27EHLli1x5swZ+co1UqkUT548wbJly9QWOH16goKCULduXQDvr9JKJBI4ODggISEBwcHBqFatGvT19WFqagpfX19cvXpVaYx/z5kPCwuDRCLBtWvX0KtXL5QpUwY1atRQqLt16xYGDhyIMmXKwNDQEF27dkVKSorS2E+fPsWAAQNQvnx56OrqonLlyvjll1/w4ZdTb968wciRI1GxYkVIpVJYWVmhWbNm+OOPP4rURpWsrCxERUWhZcuWebZZv349ZDIZgoKC0L59e2zatCnf6WwFyczMxKFDhzBixAhcvHgxz3abN2+GhoYGgoOD5WUSiQSDBw9GUlISjh07lu9xxPZ/8+ZNvufn4eEBExMT7NixI9/jExGRasnJycjJyYGVlZVCuZWVlcJ06PyMGTMGdnZ28Pb2lpdt3rwZaWlpMDc3h66uLn744Qfs3btXPg0n1+jRo2FgYIAyZcrgwYMH2L59e/FPihQUawfYr7/+Grt27YJMJpOvJFO2bNl8vw6n/64BAwagYsWKCAsLQ//+/dGwYUMYGhri/PnziImJQadOneDg4IDExEQsXboUjRo1wtWrVxWmVeTFz88PFStWxE8//aS02Vf37t1Rrlw5TJ06Vb66kra2tvybIeD9m5mnpycyMzPRv39/WFtb448//sDo0aPx+PFjzJs3DwAwaNAgbN68GcHBwahRowZevHiBs2fP4vLly2jYsGGh26hy4cIFZGRkwN3dPc82kZGR8iv3/v7+WLduHQ4cOABf38JPwbp37x7279+PqKgoREdHIy0tDfb29ggKCsqzT2xsLBwdHZXm8ed+OLt48aJ8jwZ19f/uu+/w5s0baGhowMvLCzNnzoSXl5dCG4lEgtq1a3N+JRFRMf176qogCIWazvrLL79g48aNOH78uMI3rBMmTEBaWhqio6NhZmaGTZs2oXPnzjh79iwqVaokbzdq1Cj07dsXCQkJmDx5Mnr37o1du3ap78SoeMl8Lg0NjUIlZPTfVq9ePUgkEoSFhaFevXryG5bT0tKU5s8FBASgevXqWLFiBcaPH1/g2NWqVcOWLVtU1jk7Oysk7rlLmS5ZskS+sVfum86VK1fkVycGDBgAGxsbzJkzByNGjIC9vT327t2L7777Lt8VjQrTRpW4uDgAUHiT+9DVq1fx119/Yfbs2QAAHx8fmJubY+3atfkm81lZWThx4oQ8gY+Li4O2tja8vLwQGhqKtm3byr/NyMvjx4+VrqYA/9s8raDVB4rSX0dHB507d0abNm1gaWmJuLg4zJkzB02aNMGxY8eUvqKtVKlSgd96EBGRahYWFtDU1FS6Cp+UlKR0tf7fZs+ejWnTpuHIkSP46quv5OXx8fFYtGgR4uLiUKVKFQBAzZo1cezYMSxbtgwzZsxQOL6FhQWcnZ1RrVo12Nra4vz586hTp44az/LLxkvoVOI+nN+elpaGlJQUGBsbo0qVKoiNjS3UGIMGDcqz7sOpHQDQuHFj5OTkyHcSFgQBmzdvRrt27aCpqYnk5GT5j4+PD2QymXz5TGNjY5w/f17l7sC5CtNGleTkZACAmZmZyvrIyEhIJBJ069YNAKCtrY3OnTtj9+7deP36tco+sbGxMDc3R4sWLRAZGYl69ephy5YtSE5ORkxMDH788ccCE3ng/cZqucu+fkhDQwPa2tryjdfU0d/Lywvbtm1D37590b59e/z44484d+4cdHR0MHr0aKUxzM3NkZWVledzQEREedPR0YG7uzsOHz4sL5PJZDh69Gi+92TNmjULU6dOxYEDB1C7dm2Futw9fv49911TUxMymSzPMXOntf77G3YqHibzVOIyMjLkS1MZGBjAwsIClpaWuHLlCl6+fFmoMSpWrJhn3b9XMMlNlnPXcX/27BlevHiBlStXwtLSUuEn96bK3GliM2fOxJUrV1ChQgXUrVsXoaGh8ivquQrTJj+qFpDKvWG8QYMGCqsh+fv7IyMjA9u2bVM5lp6eHmxtbQEAL1++xL1793D37l3cv3+/0PEA7+9rUfXmKpPJkJ2dDalUWqL97e3t0aVLF5w+fVphIzjgf8/Xf3F1IyKijyEkJARLly7F6tWrcf36dQwaNAhpaWny6ZeBgYEYO3asvP0vv/yCCRMmYOXKlXBwcMCTJ0/w5MkTvHnzBgBQtWpVODk5YcCAAbhw4QJu376Nn3/+GWfOnJF/k3zhwgX89ttvuHz5MhISEnDs2DH06NEDTk5O8PDw+OjPwX8Zk3kqcUOHDsWcOXPg5+eHzZs34+DBgzh8+DBq1KiR7yf4D+np6eVZl9dd8blJYO4xunfvjsOHD6v88fPzA/A+eb5z5w4WL16MChUq4Ndff4WLiwvWrl0rH7cwbVTJXcv3xYsXSnUxMTF48OCB0uouTZo0gZWVVZ5jV69eHdevX8edO3cwb9486OvrIywsDC4uLrCzs8OAAQOwc+dOpKam5hubjY2Nym8acqfH5E6XKan+AGBnZweZTKb0Ae/FixfQ0dGBkZFRgWMQEZGybt26Yc6cOZg0aRJq1qyJy5cv48CBA/K/S/fv31d4D1+8eDGysrLQpUsXWFtby39yp4Fqa2sjKioKJiYmaNOmDWrWrIlt27Zh48aN8nvH9PT0sGvXLjRr1gxVqlRBnz598NVXXynNvafiU8uceaL8bN68GYGBgfKbTHO9ePFC5WYV6mZpaQljY2O8e/euUMsbWltbY8CAARgwYABevnwJT09PTJ48WWHTssK0+bfczbPu3r2LWrVqKdStXbsWmpqaSvcWaGpqomvXrli8eDEePXqE8uXLqxy7YsWK+P777/H9998jMzMTx48fR1RUFPbv349ly5ZBW1sbx48fz/MrVTc3N0RHR+P58+cKN7GePXtWXp+f4vYHgDt37kBTU1PpJtr4+HhuPEZEVEyDBw/G4MGDVdb9e6f2wuzCXrly5XxXpqlRowaOHj1alBBJJF6ZpxKnpaWlNLVkw4YNH21L59wkeceOHSqXZ3z16hWys7ORk5ODV69eKdSZmpqiYsWK8qvphWmTF3d3d0ilUqUtsTMyMrB161Z4e3ujbNmySv38/f0hk8kUbvLNj66uLlq2bIl58+bhxo0biI+Px9y5c/Nd+9/Pzw8ymQyLFy+Wl+XeSGxpaamwY2taWhri4uLk9wAUtX/u9KcPXb9+HVu3bkXjxo0VpuQIgoDY2FilVW6IiIjovWJdmc/JycHGjRsRHR2NpKQk/Pzzz3B1dcXLly9x6NAhNGzYUOUKF/Rl8fX1xZo1a2BsbIyvvvoKly9fxqZNm/Jc1aUkzJgxA8ePH0f9+vXRt29fuLi44PXr1/jnn3+wbds23L59G1KpFOXLl8c333yDr7/+GsbGxjh58iQOHDggv8k2NTW1wDZ50dHRQatWrXD48GFMmzZNXp57g6umpqbCCgC5BEGAtrY21q5di1GjRinUJSYmYt++fQWev1QqhYmJSZ717u7uCAgIQGhoKJ49eybfwTUmJgYrV65UuLn13LlzaNq0KUJDQxEWFlbk/t26dYOhoSFq166NsmXL4saNG1i6dCm0tLTkX+HmOnv2LF69eoWOHTsWeI5ERERfItHJ/KtXr+Dj44Nz587B0NAQb9++xYgRIwAARkZGCAkJQWBgoELSQvnb02lPaYdQIubPnw9tbW1s2rQJK1asQO3atXHgwAGlxLQkWVpa4uzZs/jpp5+wc+dOLFu2DKampnB2dkZYWJh8akdwcDAOHz6MXbt24d27d6hYsSJmz56NYcOGAXi/Mk9BbfLTt29f+Pr64vbt23BycgIA+Xz4AwcO4MCBA3n2vXLlCq5cuQJXV1d52Y0bN/Ddd98V6jk4fPhwvnPXly9fDnt7e0RERGDJkiVwdnZGZGRkvlOHxPT39fXF+vXrMXfuXLx+/RplypSBr68vJk6cqDSdZsuWLbC1teXur0RERHmQCKqW1iiEQYMGITIyElu3bpVfYTty5Ih8d7Bhw4bhxIkTuHTpkloD/lQ9fPgQtra2ePDggcJqJB/KnYP279VX6Mshk8ng6uqKVq1aKV2FJkW5m11NnDgRQ4cOzbctX1tU2rxnx5R2CPQRRY9sUtohEMmJvjK/c+dODBkyBK1atUJKSopSfeXKlREZGVms4Ij+azQ0NPDTTz8hICAA48aNU7rZk/4nPDwcBgYGGDhwYGmHQkSkwHdH4Xflps/fpz5zQvQNsC9evICjo2Oe9YIgcFMAIhU6duyI1NRUJvIFGDZsGO7du8clzIiIiPIhOpmvWLEi/vnnnzzrjx8/Lt/il4iIiIiI1E90Mt+zZ0+Eh4fjxIkT8rLcHRoXLFiAHTt2yHcWIyIiIiIi9RM9Z3706NE4e/YsvL294ezsDIlEgqFDhyIlJQVPnjxBx44dMWTIEHXGSkREREREHxB9ZV5LSwu7d+/G2rVrUbVqVVStWhXv3r2Dm5sb1qxZg23btsmv1NN7GhoakMlkpR0G0X+OTCaDhgb3wCMioi9PsTaNAt7vTunv76+OWP7ztLW1kZ6ezsSDSI1kMhmys7Ohp6dX2qEQERF9dMwoPyJjY2PIZDIkJydD5PL+RPQBQRCQnJwMmUwGY2Pj0g6HiIjooyvWlfkjR44gPDwcd+7cwfPnz5USVIlEgvj4+GIF+F9iYGAAIyMjpKSk4PXr19DSKvYXI0RftHfv3iE7OxtGRkYwMDAo7XCIiIg+OtHZ5Pz58xESEgJLS0t4enriq6++Umdc/1k2NjZ4+fIl3rx5w/nzRMWko6MDc3NzmJqalnYoREREpUJ0Mj9nzhw0btwYBw4c4KYuRaChoQFzc3NuGERERERExSZ6znxycjK6devGRJ6IiIiIqJSITubd3d1x9+5ddcZCRERERERFIDqZnzt3LiIiInDkyBF1xkNERERERIVU6Dnzbdq0USozNTWFj48PnJyc4ODgAE1NTYV6iUSCffv2FT9KIiIiIiJSUuhk/tq1ayp3dLWzs0NWVhZu3ryp1sCIiIiIiCh/hU7m7927V4JhEBERERFRUYmeM3/ixAk8e/Ysz/rk5GScOHFC7PBERERERFQA0cl806ZNcfjw4Tzrjx49iqZNm4odnoiIiIiICiA6mRcEId/6rKwsaGiIHp6IiIiIiApQpB1gX79+jZcvX8ofp6Sk4P79+0rtXrx4gQ0bNqB8+fLFDpCIiIiIiFQrUjL/66+/YsqUKQDeLzs5fPhwDB8+XGVbQRDw888/FztAIiIiIiJSrUjJfPPmzSGVSiEIAsaNG4du3bqhZs2aCm0kEgkMDAxQu3Zt1K1bV52xEhERERHRB4qUzNevXx/169cHAGRmZuKbb77BV199VSKBERERERFR/oqUzH8oNDRUnXEQEREREVERcbkZIiIiIqLPFJN5IiIiIqLPFJN5IiIiIqLPFJN5IiIiIqLPlOhkXiaTqTMOIiIiIiIqItHJfIUKFTBy5EhcvnxZjeEQEREREVFhiU7m69evj8WLF8Pd3R0uLi6YNWsWHj16pM7YiIiIiIgoH6KT+S1btuDJkydYtmwZLCwsMGbMGNjb26N58+aIjIzE27dv1RknERERERH9S7FugDU2Nkbfvn1x7NgxJCQk4KeffsLTp08RFBQEKysrBAQE4NChQxAEQV3xEhERERHR/1PbajYVKlTAmDFjsH//fnTt2hVpaWlYt24dWrduDVtbW8yePRvv3r0r1FjZ2dmYNGkS7OzsIJVK4erqivXr1xfYLy4uDmPHjoWbmxtMTExgaGiIBg0aYOfOnUptIyIiIJFIVP7cvn27qKdPRERERPTRaaljkNTUVGzduhVr167F8ePHoaWlhW+++Qa9evWCjo4Oli1bhtGjR+PatWtYuXJlgeP1798fa9asQXBwMFxcXLBz5058++23ePfuHQIDA/Pst3z5coSHh6Nz58747rvvkJmZicjISHTq1Anh4eHo16+fUp+wsDA4OjoqlJUrV67oTwIRERER0UcmEUTOgcnJycH+/fuxdu1a7NmzB+np6fDw8ECvXr3g7+8PMzMzhfaTJk3CvHnz8Pr163zHvXTpEtzc3DBlyhRMnDgRACAIAry9vXHt2jU8ePAAOjo6KvteuHABVapUgZGRkbwsKysLtWvXRmJiIp4+fQoNjfdfRkRERKB37944ffo0PD09xTwFCh4+fAhbW1s8ePAAFSpUKPZ4RET0+fCeHVPaIdBHZOA4p7RDoI9oT6c9pR1CvkRPsylXrhw6dOiAU6dOYdiwYbh+/TrOnDmDQYMGKSXyAFC9enW8efOmwHE3b94MDQ0NBAcHy8skEgkGDx6MpKQkHDt2LM++tWvXVkjkAUBHRwe+vr5ITk5GUlKSyn6pqanIyckpMDYiIiIiok+J6GTex8cHBw8eREJCAqZNm4YqVark297f379QG03FxsbC0dER5ubmCuV169YFAFy8eLHIsT5+/BhaWlowMTFRqmvRogWMjY2hr6+PNm3a4Nq1a4Ua8/Xr13j48KH8JzExschxEREREREVh+g582vXrlVnHHKPHz+GtbW1UrmNjY28viji4+OxceNGtG/fHnp6evJyfX199OrVC02bNoWJiQkuXbqEuXPnwsvLCxcuXICTk1O+486dOxeTJ08uUixEREREROpU7Btgo6KisHfvXiQkJAAA7O3t0a5dO7Rp00bUeOnp6dDV1VUq19DQgLa2NtLT0ws9VlpaGvz8/KCrq4u5c+cq1Pn5+cHPz0/+uGPHjmjXrh08PT0RFhZW4IeVkJAQhRtqExMT4eHhUejYiIiIiIiKS3Qy//r1a3zzzTeIjo6GRCKBlZUVBEHAwYMHsXTpUjRp0gQ7duyAsbFxkcaVSqXIzMxUKpfJZMjOzoZUKi3UONnZ2ejatSuuXr2Kffv2wd7evsA+derUQcOGDXHkyJEC2xobGxf53IiIiIiI1En0nPkRI0YgOjoaP/30E168eIFHjx7h8ePHePHiBaZOnYqYmBiMGDGiyOPa2NionH+eO70md7pNfmQyGQIDA3Hw4EGsW7cOzZo1K/Tx7ezs8Pz588IHTERERERUSkQn89u3b8eAAQMwduxYhRVkjIyMMG7cOPTv3x/bt28v8rhubm6Ij49XSqjPnj0rry/IwIEDsXHjRixZsgTffPNNkY5/584dlC1btkh9iIiIiIhKQ7F2gHVxcRFVlx8/Pz/IZDIsXrxYXiYIAhYuXAhLS0s0bdoUAJCcnIy4uDikpaUp9B85ciTCw8Mxc+ZMlZtE5VJ19f3o0aM4efIkfHx8RMVORKotWrQIDg4OkEql8PT0xPnz5/NsGx4ejoYNG8LMzAzm5uZo2bIlLly4oNBm6tSpqFatGgwMDGBmZobmzZvLP/Dn6tSpE+zt7SGVSmFtbY2AgIAi30BPRET0qRM9Z75NmzbYu3cvBg0apLJ+7969om6CdXd3R0BAAEJDQ/Hs2TP5DrAxMTFYuXKl/ObYhQsXYvLkyTh27BiaNGkCAPjtt98wZ84c1KpVCzY2Nko3sXbq1AkGBgYAgAYNGqBWrVpwdXWFqakpLl++jBUrVsDKygphYWFFjpuIVNu0aRNCQkKwZMkS1K1bF/PmzYOPjw9u3rwJCwsLpfYxMTHo3r07vLy8IJVKMXPmTLRo0QLXrl2Tr3Tl5OSE3377DY6OjkhPT5ePGR8fjzJlygAAmjRpgh9//BE2NjZ4/PgxRo4cCT8/P/z5558f9fyJiIhKUqF3gP33hkspKSnw9/eHra0tBg8eDCcnJ0gkEty8eRMLFy7Eo0ePsGHDBlSrVq3IQWVlZWHq1KmIiIhAUlISnJ2dMXr0aPTs2VPeJiwsTCmZDwoKwurVq/Mc9+7du3BwcAAATJgwAVFRUbh79y7evn2LcuXKoVWrVggNDUX58uWLHDN3gCVSrW7duvDw8MCCBQsAvL+nxdbWFiNGjMDIkSML7J+TkwMzMzMsWbIEPXr0UNnm9evXMDExQUxMDBo3bqyyze7du9G5c2dkZmZCU1NT/AkRqcAdYL8s3AH2y/Kp7wBb6GReQ0MDEolEoSy3a17lGhoaePfunTri/OQxmSdSlpWVBX19fezYsQO+vr7y8l69euHNmzfYtm1bgWOkpqaibNmy2LFjB1q1aqXyGL/99humTZuG27dvK204B7yfVjdo0CA8ffoUMTExxTonIlWYzH9ZmMx/WT71ZL7Q02wmTZqklLQTEeUnOTkZOTk5sLKyUii3srLC7du3CzXGmDFjYGdnB29vb4XyvXv3wt/fH2lpabC2tsbhw4eVEvnRo0dj4cKFSEtLQ7169bB3797inRAREdEnptDJPOeRE5FYqr69K8zFgV9++QUbN27E8ePHoaOjo1DXtGlTXL58GcnJyQgPD4efnx/Onj2rMA9/1KhR6Nu3LxISEjB58mT07t0bu3btUs9JERERfQKKvQMsEVFeLCwsoKmpiSdPniiUJyUlKV2t/7fZs2dj2rRpOHLkCL766iulegMDAzg5OcHJyQmenp6oXLkyVq1ahVGjRikc38LCAs7OzqhWrRpsbW1x/vx51KlTRz0nSEREVMqKtTQlEVF+dHR04O7ujsOHD8vLZDIZjh49inr16uXZb9asWZg6dSoOHDiA2rVrF+pYgiCo3D36w3oA+bYhIiL63PDKPBGVqJCQEAQGBsLd3R0eHh6YN28e0tLSEBQUBAAIDAxE+fLlMX36dADvp9ZMnDgR69evh4ODg/yqvqGhIQwNDZGRkYHJkyejQ4cOsLa2RkpKCn7//Xc8fPhQvknchQsXcOrUKTRq1AhmZma4c+cOJk2aBCcnJ3h4eJTK80BERFQSmMwTUYnq1q0bnj17hkmTJuHJkyeoWbMmDhw4IJ/bfv/+fWho/O9LwsWLFyMrKwtdunRRGCc0NBRhYWHQ0NDAzZs38c033+DZs2coU6YMPDw88Mcff8iXwtXT08OuXbswefJkvH37FtbW1mjVqhU2bdqkNPeeiIjoc1bopSkpf1yakojoy8WlKb8sXJryy/KpL03JOfNERERERJ+pQk+zuX//vqgD2NnZiepHRERERET5K3Qy7+DgIGrTqJycnCL3ISIiIiKighU6mV+5cqVCMi8IAubPn4979+7h22+/RZUqVSAIAm7cuIENGzbAwcEBQ4cOLZGgiYiIiIioCMl87jJyuWbOnIm3b9/i9u3bKFOmjEJdWFgYvLy88OzZM7UESUT/47vDt7RDoI/oU7/xioiISpfoG2AXL16M/v37KyXyAGBpaYnvvvsOixYtKlZwRERERESUN9HJfFJSErKzs/Osf/fuHZKSksQOT0REREREBRCdzNeqVQuLFi3CvXv3lOru3r2LRYsWoVatWsWJjYiIiIiI8iF6B9i5c+eiRYsWqFq1Ktq3bw9nZ2dIJBLExcVhz5490NLSwpw53FSBiIiIiKikiE7mPT09ce7cOUyYMAFRUVHYunUrAEBfXx9t27bFlClTUKNGDbUFSkREREREikQn8wBQrVo1bNu2DTKZDM+ePYMgCChbtiw0NLixLBERERFRSStWMq8wkJYWTExMmMgTEREREX0kxcq8z58/jxYtWkBfXx9ly5bFiRMnAADJyclo27YtoqOj1RIkEREREREpE53Mnz59Go0aNUJ8fDwCAwMhCIK8zsLCAm/evMHy5cvVEiQRERERESkTncyPHz8ezs7OuHbtGqZNm6ZU37RpU5w9e7ZYwRERERERUd5EJ/Pnzp1Dnz59IJVKIZFIlOorVKiAxMTEYgVHRERERER5E53Ma2pq5nuz65MnT6Cvry92eCIiIiIiKoDoZL527drYs2ePyrqsrCysW7cOXl5eogMjIiIiIqL8FWvOfHR0NPr27YvLly8DAB4/fowDBw7A29sb8fHxGDdunLriJCIiIiKifxG9zry3tzfWrVuH4OBgREREAAB69eoFQRBgamqK9evXw9PTU11xEhERERHRvxRr06hu3bqhffv2OHz4MG7cuAGZTAZHR0e0atUKhoaG6oqRiIiIiIhUEJ3MnzhxAtWqVYOlpSXat2+vVJ+cnIxr166hUaNGxQqQiIiIiIhUEz1nvmnTpjh8+HCe9UePHkXTpk3FDk9ERERERAUQncx/uOOrKllZWfkuXUlERERERMVTpGk2r1+/xsuXL+WPU1JScP/+faV2L168wIYNG1C+fPliB0hERERERKoVKZn/9ddfMWXKFACARCLB8OHDMXz4cJVtBUHAzz//XOwAiYiIiIhItSIl882bN4dUKoUgCBg3bhy6deuGmjVrKrSRSCQwMDBA7dq1UbduXXXGSkREREREHyhSMl+/fn3Ur18fAJCZmYnOnTvDxcWlRAIjIiIiIqL8iV6aMjQ0VJ1xEBERERFREYlebmbEiBGoXLlynvXOzs4YNWqU2OGJiIiIiKgAopP5ffv2oVu3bnnWd+vWDXv27BE7PBERERERFUB0Mv/gwQM4ODjkWW9vb48HDx6IHZ6IiIiIiAogOpk3NjbGnTt38qyPj4+Hnp6e2OGJiIiIiKgAopN5b29vLFmyRGVCf+fOHSxduhTe3t7FCo6IiIiIiPImejWbKVOmYP/+/XB1dUXv3r3x1VdfQSKR4O+//0ZERAS0tLQwdepUdcZKREREREQfEJ3MV65cGSdPnkRwcDAWLVqkUNe4cWMsWLAAVapUKXaARERERESkmuhkHgBq1KiBmJgYJCcn486dOxAEAU5OTihTpoy64iMiIiIiojwUK5nPZWFhAQsLC3UMRUREREREhST6BlgASEhIwHfffQdHR0cYGxvj+PHjAIDk5GR8//33iI2NFTVudnY2Jk2aBDs7O0ilUri6umL9+vUF9ouLi8PYsWPh5uYGExMTGBoaokGDBti5c6fK9qmpqRgyZAjKlSsHPT09eHp64tChQ6JiJiIiIiL62EQn89evX4ebmxu2bNmCypUr4+3bt8jJyQHw/kr9mTNnsHjxYlFj9+/fHz///DM6duyIBQsWwNbWFt9++y3WrFmTb7/ly5dj8eLF+PrrrzFjxgz89NNPSE9PR6dOnbB8+XKFtoIgoGPHjli+fDn69u2LefPmQUNDA23atEF0dLSouImIiIiIPiaJIAiCmI7t27fH33//jTNnzkBTUxNly5bFkSNH5MtRTpgwAVu2bMGNGzeKNO6lS5fg5uaGKVOmYOLEiQDeJ97e3t64du0aHjx4AB0dHZV9L1y4gCpVqsDIyEhelpWVhdq1ayMxMRFPnz6Fhsb7zy87duxA586dsWbNGgQEBAAAMjMz4erqCj09PVy+fLlIcT98+BC2trZ48OABKlSoUKS+REXhu8O3tEOgj2hPJ+6k/Tnwnh1T2iHQR2TgOKe0Q6CP6FN/HxZ9Zf7EiRP4/vvvYWVlBYlEolRvb2+PR48eFXnczZs3Q0NDA8HBwfIyiUSCwYMHIykpCceOHcuzb+3atRUSeQDQ0dGBr68vkpOTkZSUpHAcc3Nz9OjRQ16mq6uL/v3746+//iryhxAiIiIioo9NdDL/7t07GBoa5ln//PlzaGkV/f7a2NhYODo6wtzcXKG8bt26AICLFy8WeczHjx9DS0sLJiYmCsdxd3eHpqamqOO8fv0aDx8+lP8kJiYWOS4iIiIiouIQncy7uLjkeZVcEARs374d7u7uRR738ePHsLa2Viq3sbGR1xdFfHw8Nm7ciPbt20NPT09tx5k7dy5sbW3lPx4eHkWKi4iIiIiouEQn8yEhIdi6dSumTp2KlJQUAEBOTg7i4uLg7++PCxcuYOTIkUUeNz09Hbq6usqBamhAW1sb6enphR4rLS0Nfn5+0NXVxdy5cwt1HKlUKq/PT0hICB48eCD/OXfuXKHjIiIiIiJSB9HrzHft2hUJCQkYN24cwsLCAACtWrUCAGhqamLOnDlo3bp1kceVSqXIzMxUKpfJZMjOzpYn2wXJzs5G165dcfXqVezbtw/29vaFOk5GRoa8Pj/GxsYwNjYuVCxERERERCWhWJtGjRw5Ev7+/ti2bRtu3rwJmUwGR0dHdOnSBQ4ODqLGtLGxQUJCglJ57rSX3Gkw+ZHJZAgMDMTBgwexadMmNGvWTOVxVM1zL8pxiIiIiIhKU7F3gK1QoQKGDRumjlgAAG5uboiOjsbz588VboI9e/asvL4gAwcOxMaNGxEeHo5vvvkmz+McPXoUOTk5CjfBFuU4RERERESlqVg7wJYEPz8/yGQyhQ2nBEHAwoULYWlpiaZNmwJ4v8tsXFwc0tLSFPqPHDkS4eHhmDlzJvr165fvcVJSUrBhwwZ5WWZmJpYtWwYXFxdUrVpVzWdGRERERKRehb4yX7FiRWhoaCAuLg7a2tqoWLGiyvXllQ6gpQULCwu0aNECY8aMUVhRRhV3d3cEBAQgNDQUz549g4uLC3bu3ImYmBisXLlSftPqwoULMXnyZBw7dgxNmjQBAPz222+YM2cOatWqBRsbG6xdu1Zh7E6dOsHAwAAA0LlzZzRu3Bjfffcd4uLiYGdnh9WrVyM+Ph4HDhwo7NNCRERERFRqCp3MN27cGBKJRL6Dau7jguTk5CAxMREzZszAw4cPsWLFigL7LF++HPb29oiIiMCSJUvg7OyMyMhI9OzZM99+uWvDX7p0Sb6r64fu3r0rT+YlEgl2796NcePGITw8HK9fv4aLiwv27t2L5s2bFxgjEREREVFpkwiCIHyMA02ZMgULFizAs2fPPsbhPrqHDx/C1tYWDx48QIUKFUo7HPoP893hW9oh0Ef0qW8jTu95z44p7RDoIzJwnFPaIdBH9Km/Dxf7BtjC6tSpEz7S5wYiIiIioi9CsZL5rKwsLF++HHv37pUvJ2lvb4927dqhb9++Cpsyubi4wMXFpXjREhERERGRnOjVbBITE+Hm5obBgwfj4sWLMDU1hYmJCS5evIjBgwfDzc1NvmY7ERERERGpn+hkfvDgwYiPj8e6deuQmJiIkydP4tSpU0hMTMTatWtx584dDBkyRJ2xEhERERHRB0RPszl06BCGDh2K7t27K5RLJBL06NEDly9fxu+//17sAImIiIiISDXRV+b19PRga2ubZ72dnV2Ba8oTEREREZF4opN5Pz8/bNiwAdnZ2Up1WVlZWL9+Pbp161as4IiIiIiIKG+FnmZz7tw5hcddu3bFH3/8gTp16mDgwIFwcnKCRCLBzZs3sXTpUgBAly5d1BstERERERHJFTqZ9/T0VNrxNXfd+O+//15e9+Fa8t7e3sjJyVFHnERERERE9C+FTuZXrVpVknEQEREREVERFTqZ79WrV0nGQURERERERST6BlgiIiIiIipdoteZ79OnT4FtJBIJVqxYIfYQRERERESUD9HJfHR0tNINsTk5OUhMTEROTg4sLS1hYGBQ7ACJiIiIiEg10cn8vXv3VJZnZ2dj8eLFWLhwIY4ePSp2eCIiIiIiKoDa58xra2tj2LBhaNq0KYYMGaLu4YmIiIiI6P+V2A2wtWvXRnR0dEkNT0RERET0xSuxZP7MmTPQ1dUtqeGJiIiIiL54oufMr1mzRmX5y5cvERMTg507d2LQoEGiAyMiIiIiovyJTuaDgoLyrLOwsMCECRMwfvx4scMTEREREVEBRCfzd+/eVSqTSCQwMzODkZFRsYIiIiIiIqKCiU7m7e3t1RkHEREREREVkegbYJ8+fYpLly4plF2/fh0DBgyAn58fduzYUezgiIiIiIgob6KvzA8ZMgRPnjzBiRMnAAApKSlo1KgRXr16BT09PWzbtg27du1Cu3bt1BYsERERERH9j+gr82fOnEGrVq3kj9evX4+XL1/i4sWLSE5ORv369TFr1iy1BElERERERMpEJ/PPnj2DtbW1/HFUVBQaNWqEr776Ctra2vD398fVq1fVEiQRERERESkTncybmZkhMTERAJCRkYETJ06gZcuW8nqJRIKMjIziR0hERERERCqJnjPfoEEDLF68GNWqVcPhw4eRkZGB9u3by+tv3LiB8uXLqyVIIiIiIiJSJjqZnz59Olq2bIlvvvkGADB8+HBUq1YNAJCTk4OtW7eiTZs26omSiIiIiIiUiE7mHR0dcePGDVy7dg3GxsZwcHCQ16WlpWHRokX4+uuv1REjERERERGpIDqZBwAtLS24uroqlRsZGaFDhw7FGZqIiIiIiAog+gZYIiIiIiIqXUzmiYiIiIg+U0zmiYiIiIg+U0zmiYiIiIg+U0zmiYiIiIg+U6KTeU1NTaxfvz7P+k2bNkFTU1Ps8EREREREVADRybwgCPnWy2QySCQSscMTEREREVEBijXNJr9k/ezZszAzMyvO8ERERERElI8ibRo1f/58zJ8/X/54+PDhGD9+vFK7ly9f4tWrVwgICCh+hEREREREpFKRknkLCwtUqVIFAHDv3j1YW1vD2tpaoY1EIoGBgQFq166NIUOGqC9SIiIiIiJSUKRk/ttvv8W3334LAGjatCkmTJiAZs2alUhgRERERESUvyIl8x86duyYOuMgIiIiIqIiEn0D7MmTJ7Fo0SKFsvXr16NKlSooW7Yshg0bBplMVuwAiYiIiIhINdHJfGhoKE6cOCF/fOPGDQQFBUFDQwO1a9fGwoUL8dtvv6klSCIiIiIiUiY6mf/nn39Qt25d+eMNGzZAX18fZ8+eRVRUFAICArBy5UpRY2dnZ2PSpEmws7ODVCqFq6trvhtUfWjatGno2LEjbGxsIJFIMHDgQJXtIiIiIJFIVP7cvn1bVNxERERERB+T6Dnzr169UlhH/sCBA2jRogWMjY0BAA0aNMC2bdtEjd2/f3+sWbMGwcHBcHFxwc6dO/Htt9/i3bt3CAwMzLfv+PHjUbZsWdSpUwf79u0r8FhhYWFwdHRUKCtXrpyouImIiIiIPibRyby1tTWuXbsGAHjy5AkuXLiAvn37yutfv34NLa2iD3/p0iVERERgypQpmDhxIgCgX79+8Pb2xqhRo+Dv7w8dHZ08+9+5cwcVK1YEkP+mVrl8fHzg6elZ5DiJiIiIiEqb6Gk2nTt3xsKFCzF06FB06tQJurq6aN++vbz+r7/+QqVKlYo87ubNm6GhoYHg4GB5mUQiweDBg5GUlFTgKjq5iXxRpKamIicnp8j9iIiIiIhKk+hkfvLkyejSpQvWrl2LxMRErFy5ElZWVgDeX5Xftm0bWrRoUeRxY2Nj4ejoCHNzc4Xy3Pn5Fy9eFBuySrlTg/T19dGmTRv5tw0Fef36NR4+fCj/SUxMVGtcREREREQFET3NxsDAAJGRkSrrDA0N8ejRI+jr6xd53MePHyvtKgsANjY28np10NfXR69evdC0aVOYmJjg0qVLmDt3Lry8vHDhwgU4OTnl23/u3LmYPHmyWmIhIiIiIhJDdDL/b6mpqZBIJDA0NISGhgZMTExEjZOeng5dXV2lcg0NDWhrayM9Pb24oQIA/Pz84OfnJ3/csWNHtGvXDp6enggLC8PatWvz7R8SEoJ+/frJHycmJsLDw0MtsRERERERFYboaTYA8ODBAwQFBcHS0hKmpqYwMTGBhYUFevfujQcPHogaUyqVIjMzU6lcJpMhOzsbUqm0OCHnq06dOmjYsCGOHDlSYFtjY2NUqFBB/qPq2wQiIiIiopIk+sr8rVu3UL9+fTx//hzNmzdH9erVIQgC4uLiEBkZiaioKJw8ebLA6Sr/ZmNjg4SEBKXy3Ok1udNtSoqdnR1OnTpVoscgIiIiIlIH0cn8mDFjkJOTgwsXLqBmzZoKdX/99ReaNWuGMWPGYOvWrUUa183NDdHR0Xj+/LnCTbBnz56V15ekO3fuoGzZsiV6DCIiIiIidRA9zebYsWMYOnSoUiIPAF9//TUGDx6M6OjoIo/r5+cHmUyGxYsXy8sEQcDChQthaWmJpk2bAgCSk5MRFxeHtLQ0UfE/f/5cqezo0aM4efIkfHx8RI1JRERERPQxib4yn5mZme9NrqampirnvhfE3d0dAQEBCA0NxbNnz+Q7wMbExGDlypXym2MXLlyIyZMn49ixY2jSpIm8f2RkpMI0nYsXL+Knn34CAAQEBMDe3h7A+x1qa9WqBVdXV5iamuLy5ctYsWIFrKysEBYWVuS4iYiIiIg+NtHJvIuLCyIjIzFgwADo6ekp1GVmZiIyMhIuLi6ixl6+fDns7e0RERGBJUuWwNnZGZGRkejZs2eBfVesWIHjx4/LH58/fx7nz58H8D6Bz03mO3fujKioKERFReHt27coV64cgoKCEBoaivLly4uKm4iIiIjoY5IIgiCI6bh792506tQJVatWxffff48qVaoAAOLi4vD777/jxo0b2LFjB3x9fdUa8Kfq4cOHsLW1xYMHD1ChQoXSDof+w3x3fBmvKXpvT6c9pR0CFYL37JjSDoE+IgPHOaUdAn1En/r7sOgr8+3bt8fatWvxww8/YMiQIZBIJADez28vV64c1q5d+8Uk8kREREREpaFYm0Z1794dXbt2RWxsLO7duwcAcHBwgLu7O7S01LYfFRERERERqVDsjFtLSwt169ZF3bp11REPEREREREVkuilKVesWIFvvvkmz/ouXbpg9erVYocnIiIiIqICiE7mf//9d5QrVy7PehsbGyxatEjs8EREREREVADRyfzNmzfh6uqaZ32NGjVw8+ZNscMTEREREVEBRCfzEokEycnJedanpKQgJydH7PBERERERFQA0cm8u7s71q1bh4yMDKW69PR0rFu3Dm5ubsUKjoiIiIiI8iY6mR87dixu3LiB+vXrY9u2bbhx4wZu3ryJbdu2oWHDhrhx4wbGjh2rzliJiIiIiOgDopembNGiBSIiIjBkyBD4+fnJywVBgLGxMVasWIFWrVqpJUgiIiIiIlJWrHXmAwIC0LFjRxw6dAjx8fEQBAFOTk5o2bIljIyM1BUjERERERGpUOxNo4yMjPJdb56IiIiIiEqG6DnzRERERERUupjMExERERF9ppjMExERERF9ppjMExERERF9ppjMExERERF9ppjMExERERF9pkQvTent7Z1vvUQigVQqRYUKFdCsWTN07twZWlrFXgmTiIiIiIj+n+jsWiaT4dGjR4iPj4epqSkqVqwIQRBw7949vHz5Ek5OTjAxMcHZs2exfPly1KxZE4cPH4a5ubk64yciIiIi+mKJnmYzY8YMPH/+HCtWrMCzZ88QGxuLixcv4tmzZwgPD8fz58+xaNEiJCUlYenSpfjrr78wYcIEdcZORERERPRFE53Mjxw5Er169ULv3r2hqakpL9fU1ETfvn0RGBiIkJAQaGhooF+/fggKCsLu3bvVEjQRERERERUjmb906RIqV66cZ72TkxMuXbokf1ynTh0kJyeLPRwREREREf2L6GS+TJky2Lt3b571e/bsQZkyZeSPX7x4wfnyRERERERqJDqZHzBgAPbv34/27dvj4MGDiI+PR3x8PA4cOABfX18cOnQIAwYMkLffu3cvatasqY6YiYiIiIgIxVjNZvz48Xj79i3mzJmDffv2KdRpamrixx9/xPjx4wEAGRkZ+O677+Dq6lq8aImIiIiISK5YC79PmzYNI0aMwJEjR5CQkAAAsLe3R/PmzWFpaSlvJ5VK0atXr+JFSkRERERECoq9i5OlpSW6d++ujliIiIiIiKgIip3Mv3nzBgkJCXj+/DkEQVCqb9SoUXEPQUREREREKohO5l+8eIEhQ4Zg8+bNyMnJAQAIggCJRKLw79w6IiIiIiJSL9HJ/IABA7Bjxw4MHjwYjRs3hpmZmTrjIiIiIiKiAohO5vfv34+hQ4dizpw56oyHiIiIiIgKSfQ687q6uvnuAEtERERERCVLdDLfpUsXREVFqTMWIiIiIiIqAtHJ/MiRI5GYmIiAgACcPn0aiYmJSEpKUvohIiIiIqKSIXrOvLOzMyQSCWJjY7F+/fo823E1GyIiIiKikiE6mZ80aZJ8GUoiIiIiIvr4RCfzYWFhagyDiIiIiIiKSvSceSIiIiIiKl2FvjK/Zs0aAEBAQAAkEon8cUECAwPFRUZERERERPkqdDIfFBQEiUQCf39/6OjoICgoqMA+EomEyTwRERERUQkpdDJ/9+5dAICOjo7CYyIiIiIiKh2FTubt7e3zfUxERERERB8Xb4AlIiIiIvpMiV6aEgCOHDmC8PBw3LlzB8+fP4cgCAr1EokE8fHxxQqQiIiIiIhUE31lfv78+fDx8cHx48dRvnx5NGrUCI0bN1b4adSokaixs7OzMWnSJNjZ2UEqlcLV1TXfXWY/NG3aNHTs2BE2NjaQSCQYOHBgnm1TU1MxZMgQlCtXDnp6evD09MShQ4dExUxERERE9LGJvjI/Z84cNG7cGAcOHJDfFKsu/fv3x5o1axAcHAwXFxfs3LkT3377Ld69e1fg6jjjx49H2bJlUadOHezbty/PdoIgoGPHjjh16hRCQkJgZ2eH1atXo02bNjh06BC8vb3Vek5EREREROomOplPTk7G+PHj1Z7IX7p0CREREZgyZQomTpwIAOjXrx+8vb0xatQo+dKYeblz5w4qVqwI4P00n7zs3LkT0dHRWLNmDQICAgC8X37T1dUVISEhuHz5svpOioiIiIioBIieZuPu7l4iy1Nu3rwZGhoaCA4OlpdJJBIMHjwYSUlJOHbsWL79cxP5whzH3NwcPXr0kJfp6uqif//++Ouvv3Djxg1xJ0BERERE9JGITubnzp2LiIgIHDlyRJ3xIDY2Fo6OjjA3N1cor1u3LgDg4sWLajuOu7s7NDU1S/Q4REREREQlpdDTbNq0aaNUZmpqCh8fHzg5OcHBwUEpMZZIJPnOW1fl8ePHsLa2Viq3sbGR16vD48ePUa9ePdHHef36NV6/fi1/nJiYqJa4iIiIiIgKq9DJ/LVr11TOQbezs0NWVhZu3rypVJffnPW8pKenQ1dXV6lcQ0MD2traSE9PL/KYRTmOVCqV1+dn7ty5mDx5slpiISIiIiISo9DJ/L1790owjP+RSqXIzMxUKpfJZMjOzpYn2yV1nIyMDHl9fkJCQtCvXz/548TERHh4eKglNiIiIiKiwijWplElwcbGBgkJCUrludNecqfBqOM4qqbGFPY4xsbGMDY2VkssRERERERiiL4Bdvfu3Rg8eHCe9UOGDMHevXuLPK6bmxvi4+Px/PlzhfKzZ8/K69XBzc0NFy9eRE5OTokeh4iIiIiopIhO5mfNmoW0tLQ869PT0/HLL78UeVw/Pz/IZDIsXrxYXiYIAhYuXAhLS0s0bdoUwPt17uPi4vKNoaDjpKSkYMOGDfKyzMxMLFu2DC4uLqhataqocYmIiIiIPhbR02z++ecf+Pv751nv5uaGHTt2FHlcd3d3BAQEIDQ0FM+ePZPvABsTE4OVK1fKb1pduHAhJk+ejGPHjqFJkyby/pGRkQrTdC5evIiffvoJABAQEAB7e3sAQOfOndG4cWN89913iIuLk+8AGx8fjwMHDhQ5biIiIiKij010Mp+dnZ3vVfG0tDT5zaRFtXz5ctjb2yMiIgJLliyBs7MzIiMj0bNnzwL7rlixAsePH5c/Pn/+PM6fPw8AaNCggTyZl0gk2L17N8aNG4fw8HC8fv0aLi4u2Lt3L5o3by4qbiIiIiKij0kiCIIgpmO9evUgCAJOnToFDQ3F2ToymQz169dHTk4Ozp07p5ZAP3UPHz6Era0tHjx4gAoVKpR2OPQf5rvDt7RDoI9oT6c9pR0CFYL37JjSDoE+IgPHOaUdAn1En/r7sOg588OGDcO5c+fQoUMHxMbGIjMzE5mZmYiNjUXHjh1x7tw5DB06VJ2xEhERERHRB0RPs/H390d8fDxCQ0MRFRUF4P3UFUEQIJFIEBoaWqhpMUREREREJE6x1pkfP348unfvju3bt+POnTsQBAFOTk7o1KkTKlWqpK4YiYiIiIhIBdHJ/P3792FpaYlKlSph5MiRSvXp6el49uwZ7OzsihUgERERERGpJnrOfMWKFfNdenL37t2oWLGi2OGJiIiIiKgAopP5ghbBeffuHSQSidjhiYiIiIioAKKTeQB5JuuvXr3C/v37UbZs2eIMT0RERERE+ShSMj958mRoampCU1MTEokEPXv2lD/+8Mfc3BwbNmxAt27dSipuIiIiIqIvXpFugK1duzb69+8PQRCwbNkyeHt7o3LlygptJBIJDAwMULt2bXTt2lWtwRIRERER0f8UKZlv27Yt2rZtCwDIzMzEwIEDUbdu3RIJjIiIiIiI8id6acpVq1apMw4iIiIiIiqiYm0aBQCPHz/GxYsX8fLlS8hkMqX6wMDA4h6CiIiIiIhUEJ3MZ2VloW/fvtiwYQNkMhkkEol8ucoPV7lhMk9EREREVDJEL005adIkbNy4EVOnTkVMTAwEQcDq1atx6NAhtGzZEjVr1sSVK1fUGSsREREREX1AdDK/ceNGBAYGYuzYsahRowYAoHz58mjevDmioqJgYGCAJUuWqC1QIiIiIiJSJDqZf/LkCTw9PQEA2traAID09HQA76fZdOnSBVu3blVDiEREREREpIroZN7S0hIvXrwAABgZGUFPTw937tyR12dnZ+PNmzfFj5CIiIiIiFQSfQNsrVq1cObMGQDvr8Q3btwYv/76K2rVqgWZTIYFCxagVq1aaguUiIiIiIgUib4y/91330EQBGRkZAAAZs2ahdTUVDRu3BhNmjTB27dvMWfOHLUFSkREREREikRfmff19YWvr6/8cY0aNRAfH49jx45BS0sLXl5eMDMzU0uQRERERESkrNibRn3I2NgYHTp0UOeQRERERESUB7Uk88nJyXjx4oV806gPOTs7q+MQRERERET0L6KT+bS0NEyYMAErV65Eampqnu1ycnLEHoKIiIiIiPIhOpnv27cvNm3ahNatW6Nu3bowMTFRZ1xERERERFQA0cn8nj17MGDAAPz+++/qjIeIiIiIiApJ9NKUxsbGcHV1VWcsRERERERUBKKT+V69emHbtm3qjIW+MIsWLYKDgwOkUik8PT1x/vz5fNtv2bIFVatWhVQqhYuLCw4cOKBQHxQUBIlEovDTqlUrhTYXLlyAt7c3TExMULZsWYwYMQJZWVlqPzciIiKij0F0Mv/TTz/B2dkZTZs2RWRkJI4dO4YTJ04o/RCpsmnTJoSEhCA0NBQXL16Eq6srfHx8kJycrLL96dOn0b17d/Tt2xeXLl1Cp06d0KFDB1y/fl2hXbt27ZCYmCj/2bBhg7zu0aNHaNGiBb766iucP38e27dvx6FDh/DDDz+U6LkSERERlRTRc+ZfvXqFhw8f4vjx4yqTdkEQIJFIuJoNqTR37lz0798fvXv3BgAsWbIE+/btQ0REBEaOHKnUft68eWjdujVGjRoFAJgyZQoOHTqERYsWYeHChfJ2urq6KFeunMpj7tu3D3p6epg/fz4kEgmcnZ0xc+ZMdOnSBT///DOMjY1L4EyJiIiISo7oZL5Pnz44cOAA+vTpAw8PD65mQ4WWlZWF2NhYTJgwQV6moaGB5s2b4/Tp0yr7nD59Wp7I5/Lx8cHevXsVyo4ePYqyZcvCzMwMzZs3x9SpU2Fubg4AyMzMhK6uLiQSiby9np4eMjMzERsbi6ZNm6rrFImIiIg+CtHJ/NGjRzF8+HD88ssv6oyHvgDJycnIycmBlZWVQrmVlRVu376tss+TJ09Utn/y5In8cevWreHn5wc7Ozvcvn0b48aNQ9u2bXHy5EloaGjA29sbI0aMwLx58xAcHIyUlBRMnTpVPj4RERHR50b0nHkzMzPY2dmpMxb6wnx4hRz439Qsse27deuGNm3a4KuvvkLHjh2xd+9enDlzBn/88QcAoEaNGli5ciV+/vln6OnpoVKlSvDx8QHw/psBIiIios+N6Axm4MCBWLduHd69e6fOeOgLYGFhAU1NTaWr4UlJSUpX33OVK1euSO0BoFKlSrCwsFC42h8YGIhnz57h4cOHSE5ORqdOnQAAFStWFHs6RERERKVG9DSbSpUqITMzE19//TWCgoJQoUIFaGpqKrXz8/MrVoD036OjowN3d3ccPnwYvr6+AACZTCafuqVKvXr1cPjwYQwZMkRedvjwYdSrVy/P4zx8+BApKSmwtrZWqsu9SXbTpk0oX7483NzcinFGRERERKVDdDLfo0cP+b9Hjx6tso1EImEyTyqFhIQgMDAQ7u7u8PDwwLx585CWloagoCAA76+gly9fHtOnTwcADBs2DI0aNcKcOXPQtm1bbNy4EbGxsVixYgUA4M2bN5g8eTK++eYblCtXDvHx8fjxxx9RpUoVNG/eXH7chQsXokGDBpBKpdi5cyd+/vlnbNiwAVpaol8KRERERKVGdAZz7NgxdcZBX5hu3brh2bNnmDRpEp48eYKaNWviwIEDsLCwAADcv39fYR67l5cXNmzYgAkTJmDcuHGoXLkydu7ciWrVqgEANDU1ceXKFaxevRqvXr2CjY0NfHx8MGXKFOjo6MjHOXXqFCZNmoS0tDS4uLhg27Zt8m8HiIiIiD43EkEQhNIO4r/g4cOHsLW1xYMHD1ChQoXSDof+w3x38MPHl2RPpz2lHQIVgvfsmNIOgT4iA8c5pR0CfUSf+vswl/AgIiIiIvpMMZknIiIiIvpMMZknIiIiIvpMMZknIiIiIvpMcT2+/wDeePVlMXAs7QiIiIjoU8Er80REREREn6liJfMJCQn47rvv4OjoCGNjYxw/fhwAkJycjO+//x6xsbFqCZKIiIiIiJSJnmZz/fp1NGjQADk5OfD09MS9e/eQk5MDALCwsMCZM2eQmZkp36GTiIiIiIjUS3QyP3r0aBgbG+PMmTPQ1NRE2bJlFerbtGmDLVu2FDtAIiIiIiJSTfQ0mxMnTuD777+HlZUVJBKJUr29vT0ePXokauzs7GxMmjQJdnZ2kEqlcHV1xfr16wvVVxAEzJ8/H87OztDV1YWzszN+++03/Huj24iICEgkEpU/t2/fFhU3EREREdHHJPrK/Lt372BoaJhn/fPnz6GlJW74/v37Y82aNQgODoaLiwt27tyJb7/9Fu/evUNgYGC+fadMmYKwsDAEBATgxx9/xPHjxzFs2DC8fPkSkyZNUmofFhYGR0fF5UHKlSsnKm4iIiIioo9JdDLv4uKCY8eOYdCgQUp1giBg+/btcHd3L/K4ly5dQkREBKZMmYKJEycCAPr16wdvb2+MGjUK/v7+0NHRUdn3yZMnmD59Ovr06SOfq9+vXz9oampi2rRp6N+/v1Ki7uPjA09PzyLHSURERERU2kRPswkJCcHWrVsxdepUpKSkAABycnIQFxcHf39/XLhwASNHjizyuJs3b4aGhgaCg4PlZRKJBIMHD0ZSUhKOHTuWZ99du3YhMzMTQ4YMUSgfMmQIMjMzsWvXLpX9UlNT5TfvEhERERF9LkRfme/atSsSEhIwbtw4hIWFAQBatWoFANDU1MScOXPQunXrIo8bGxsLR0dHmJubK5TXrVsXAHDx4kX4+Pjk2VdXVxeurq4K5bVq1YKOjg4uXryo1KdFixZ48+YNdHR00KxZM8yePRvVq1cvMM7Xr1/j9evX8seJiYkF9iEiIiIiUqdi7QA7cuRI+Pv7Y9u2bbh58yZkMhkcHR3RpUsXODg4iBrz8ePHsLa2Viq3sbGR1+fX18rKChoail84aGhowMrKSqGvvr4+evXqhaZNm8LExASXLl3C3Llz4eXlhQsXLsDJySnfOOfOnYvJkycX5dSIiIiIiNSqWMk8AFSoUAHDhg1TRywAgPT0dOjq6iqVa2hoQFtbG+np6UXuCwBSqVShr5+fH/z8/OSPO3bsiHbt2sHT0xNhYWFYu3ZtvnGGhISgX79+8seJiYnw8PDItw8RERERkToVO5lXN6lUiszMTKVymUyG7OxsSKXSIvcFgIyMjHz7AkCdOnXQsGFDHDlypMA4jY2NYWxsXGA7IiIiIqKSIvoGWA0NDWhqaub7Y2BggCpVquD777/H3bt3CzWujY2NyvnnuVNkcqfb5NX36dOnkMlkCuUymQxPnz7Nt28uOzs7PH/+vFCxEhERERGVJtHJ/KRJk/D1119DU1MTbdq0wfDhwzFs2DC0bt0ampqaqFmzJr7//ntUq1YNy5cvh5ubG/75558Cx3Vzc0N8fLxSQn327Fl5fX59MzMzceXKFYXyS5cuISsrK9++ue7cuaO0my0RERER0adIdDLv4OCAJ0+e4OrVq9izZw/mzJmDuXPnYu/evfj777/x+PFjfPXVV9i5cyeuXLkCiUSCCRMmFDiun58fZDIZFi9eLC8TBAELFy6EpaUlmjZtCgBITk5GXFwc0tLS5O06dOgAHR0dLFy4UGHMBQsWQEdHBx06dJCXqbr6fvToUZw8eTLP1XKIiIiIiD4loufMz5w5E8HBwahcubJSXZUqVRAcHIzp06ejV69eqFq1KgYMGIClS5cWOK67uzsCAgIQGhqKZ8+eyXeAjYmJwcqVK+U3uC5cuBCTJ0/GsWPH0KRJEwDvp9mMHj0aU6dORXZ2Nho1aoTjx48jMjISkyZNUlglp0GDBqhVqxZcXV1hamqKy5cvY8WKFbCyspIvtUlERERE9CkTnczfu3cPenp6edbr6+sjISFB/rhSpUrIyMgo1NjLly+Hvb09IiIisGTJEjg7OyMyMhI9e/YssO/kyZNhZmaGRYsWYePGjbC1tcXcuXMxfPhwhXadO3dGVFQUoqKi8PbtW5QrVw5BQUEIDQ1F+fLlCxUnEREREVFpkgiCIIjp6OLiAgA4ffo0DA0NFepSU1Ph6ekJTU1N+fz1CRMmYP369bhz504xQ/40PXz4ELa2tnjw4AEqVKjwUY/tPTvmox6PSpeB45zSDoE+oj2d9pR2CFQIfB/+svB9+Mvyqb8Pi74yP3XqVHTp0gXOzs7o1auXfJOlW7duITIyEk+fPsXWrVsBADk5OdiwYQO8vLzUEzUREREREYlP5jt27Ii9e/di9OjRmDlzpkKdq6srli9fjtatWwN4fwNrdHQ0zMzMihctERERERHJFWvTqFatWqFVq1ZITEyUz4+3t7dXuNEUALS0tGBvb1+cQxERERER0b+oZQdYa2trpQSeiIiIiIhKVrGT+cePH+PixYt4+fKl0s6rABAYGFjcQxARERERkQqik/msrCz07dsXGzZsgEwmg0QiQe7COBKJRN6OyTwRERERUckQvQPspEmTsHHjRkydOhUxMTEQBAGrV6/GoUOH0LJlS9SsWVO+LCUREREREamf6GR+48aNCAwMxNixY1GjRg0AQPny5dG8eXNERUXBwMAAS5YsUVugRERERESkSHQy/+TJE3h6egIAtLW1AQDp6ekA3k+z6dKli3ydeSIiIiIiUj/RybylpSVevHgBADAyMoKenp7C7q7Z2dl48+ZN8SMkIiIiIiKVRN8AW6tWLZw5cwbA+yvxjRs3xq+//opatWpBJpNhwYIFqFWrltoCJSIiIiIiRaKvzPf7v/buPSiq8/7j+GcRWCAVxBkVExewKlAbJUKMJOkIOgw61gLxSmJorTE0tiGx2ow2thqkqQ1pnEaMVqOpMV6iTkdkEmtrAW1mkskFiNFJwHhLIIiKyq3KJXD6R3/szy33wyIg79fMjnO+53me811mfPju2ec8LF4swzBUU1MjSXr55ZdVVVWlyMhIRUVF6d///rdeeeUVpyUKAAAAwJHpO/OxsbGKjY21H3//+9/X2bNnlZOTI1dXVz300EPy9fV1SpIAAAAAmjNVzNfU1CgtLU0RERGKiYmxx729vRUXF+e05AAAAAC0ztQyGw8PD61bt05ff/21s/MBAAAA0EGm18zfd999Onv2rDNzAQAAANAJpov5P/zhD9q2bZv+9re/OTMfAAAAAB1k+gHYl156SYMGDdLMmTNls9k0cuRIeXp6OrSxWCx69913u5wkAAAAgOZMF/Off/65LBaL/P39JUkXLlxo1sZisZhODAAAAEDbTBfzLRXvAAAAAG4f02vmAQAAAPSsLhXzDQ0N2r17t5544gn96Ec/0meffSZJKi8v1/79+3Xx4kWnJAkAAACgOdPFfEVFhR5++GElJibqwIEDOnz4sMrKyiRJAwcO1LJly5Senu60RAEAAAA4Ml3Mr1y5UqdOndLhw4d17tw5GYZhPzdgwADNnj2bbSsBAACAbmS6mM/IyFBycrKmT5/e4q41Y8aM0VdffdWl5AAAAAC0znQxf/36dY0aNarV84ZhqLa21uzwAAAAANphupgfOXKkTp061er548ePKzg42OzwAAAAANphuph//PHH9frrr+tf//qXPda03CY9PV0HDx7UwoULu5wgAAAAgJaZ/qNRK1as0IcffqipU6cqKChIFotFzzzzjK5evarS0lLFx8crOTnZmbkCAAAAuIXpO/Ourq7KzMzUrl27FBISopCQEH377bcKCwvTzp079de//rXFB2MBAAAAOIfpO/NNEhISlJCQ4IxcAAAAAHSC6Tvzq1atavMBWAAAAADdy3Qxn5aWptDQUN1777168cUXdebMGWfmBQAAAKAdpov5b775Rq+++qp8fX21evVqBQcHa+LEiVq/fr2Ki4udmSMAAACAFpgu5ocOHaqnn35a7733nr7++mulpaXJYrHoV7/6lQIDAzV58mRt3rzZmbkCAAAAuIXpYv5W99xzj5YvX66PPvpIZ86c0QsvvKATJ06wNSUAAADQjbq8m82tPvjgA+3bt08HDhxQVVWVfHx8nDk8AAAAgFt0uZjPzc3Vvn37tH//fhUVFcnT01MzZ85UQkKCZsyY4YwcAQAAALTAdDG/atUq7du3T+fPn5ebm5tiYmK0bt06xcXFycvLy5k5AgAAAGiB6WI+LS1NUVFRev755zVr1iwNGjTIiWkBAAAAaI/pYv6bb77R0KFDnZkLAAAAgE7o0taUAAAAAHpOlx6AvXz5srZv367c3FyVl5ersbHR4bzFYlFWVlaXEgQAAADQMtPF/Oeff67IyEhVV1crODhYJ0+e1NixY3X9+nWVlJRo1KhRstlszswVAAAAwC1ML7NZuXKlrFarvvjiC/3zn/+UYRh69dVXVVxcrN27d+v69et6+eWXnZkrAAAAgFuYLubfe+89/exnP1NgYKBcXP47TNMym0cffVTz58/Xc889Z2rs+vp6rV69Wv7+/vLw8ND48eO1Z8+eDvVt+lARFBQkq9WqoKAgbdiwQYZhNGtbVVWl5ORk+fn5ydPTUxEREfrHP/5hKmcAAADgdjNdzNfV1Wn48OGSJE9PT0lSRUWF/fx9992njz/+2NTYSUlJevHFFxUfH6/09HTZbDYtWLBAO3fubLfv2rVrtXTpUkVEROi1117TpEmT9Oyzzyo1NdWhnWEYio+P17Zt2/TEE0/oT3/6k1xcXDRjxgxlZ2ebyhsAAAC4nUyvmQ8ICNCFCxck/beYHz58uN5//33Nnj1bknTq1Cl95zvf6fS4+fn52rFjh9auXavf/va3kqTFixdr6tSpeu6555SQkCB3d/cW+5aWlmrdunVatGiRtm/fbu87YMAA/f73v1dSUpL8/PwkSRkZGcrOztbOnTuVmJgoSVq4cKHGjx+vZcuW6dNPP+107gAAAMDtZPrO/JQpU3To0CH78YIFC7RhwwYtXrxYixYt0qZNmxQXF9fpcffv3y8XFxf94he/sMcsFouefvppXb58WTk5Oa32PXTokGpra5WcnOwQT05OVm1trUO++/fv1+DBg/XYY4/ZY1arVUlJSTpx4oQKCws7nTsAAABwO5m+M79ixQpNnTpVNTU18vDwUGpqqsrLy3XgwAG5uroqMTFRf/zjHzs9bm5urkaNGqXBgwc7xCdNmiRJysvL07Rp01rta7VaNX78eIf4hAkT5O7urry8PIe24eHhGjBgQKvXCQ4ObjXPyspKVVZW2o+LiookSRcvXmzvLTpdTfmV235N9ByXsps9nQJuo+Li4p5OAR3APNy/MA/3Lz05D/v5+cnVte1y3XQx7+/vL39/f/ux1WrV1q1btXXrVrNDSpJKSkrsa/Fvdffdd9vPt9V32LBh9gdym7i4uGjYsGEOfUtKSvTggw+auo4krV+/XikpKc3iDzzwQJv9AKAzbGKLXwDoST05DxcVFWnEiBFttunSH43qDjdv3pTVam0Wd3FxkZubm27ebP3TcGt9JcnDw8Ohb2ttPTw87OfbsmzZMi1evNh+XFNTo6KiIo0cObLdT1CAWRcvXtQDDzygjz76qMUPvQCA7sU8jNup6VnPtvS6qtPDw0O1tbXN4o2Njaqvr7cX253pK8m+HKi9tjU1NfbzbfH29pa3t7dDbPTo0W32AZxl+PDh7X5SBwB0H+Zh9BamH4DtLnfffXeL686blr00LYNpre+lS5fs+903aWxs1KVLlxz6duU6AAAAQG/Q64r5sLAwnT17VteuXXOIf/jhh/bzbfWtra3VZ5995hDPz89XXV2dQ9+wsDDl5eWpoaGh09cBAAAAeoNeV8zPmzdPjY2N2rRpkz1mGIY2btyoIUOGaMqUKZKksrIyFRQU6MaNG/Z2cXFxcnd318aNGx3GTE9Pl7u7u8NWmfPmzdPVq1e1d+9ee6y2tlZbt27VuHHjFBIS0l1vETDN29tba9asabbECwBwezAPo7fpdWvmw8PDlZiYqDVr1ujKlSsaN26cMjIydOzYMb3xxhv2h1Y3btyolJQU5eTkKCoqStJ/l8asWLFCqampqq+v1+TJk3X8+HG99dZbWr16tcODKrNmzVJkZKSefPJJFRQUyN/fX2+++abOnj2rI0eO9MRbB9rl7e2tF154oafTAIB+i3kYvU2vK+Yladu2bQoICNCOHTv05z//WUFBQXrrrbf0+OOPt9s3JSVFvr6+eu211/T222/LZrNp/fr1Wrp0qUM7i8WizMxMPf/883r99ddVWVmpcePG6Z133lF0dHQ3vTMAAADAeSyGYRg9nQQAAACAzut1a+YBAAAAdAzFPAAAANBHUcwDAAAAfRTFPAAAANBHUcwDJly7dk1r1qzRhAkT5O3tLavVKn9/f82bN0+HDh3SnfJc+ZQpU2SxWJSamtrTqQCAgzt5Ho6KipLFYrG/3NzcFBgYqJ/+9Kc6d+5cT6eHXobdbIBOys/P1w9/+ENdvXpVc+fOVUREhLy8vFRcXKy///3vev/997Vp0yYtWbKkp1PtkuLiYgUEBCggIEBubm4qLCzs6ZQAQNKdPw9HRUXp9OnTSktLkyTV19ersLBQmzdvlpeXlwoKCuTj49PDWaK36JX7zAO9VUVFhWJjY9XY2Kjc3Fzde++9DudXr16t48eP6/r1622Oc+PGDXl5eXVnql22e/dueXl5afPmzZo+fbo+/vhjTZw4safTAtDP9Zd52Nvbu9nf1wkODtaiRYuUnZ2tRx55pIcyQ2/DMhugE7Zs2aLi4mK98sorzX6BNImMjFR8fLz9eMeOHbJYLMrJydHSpUvl5+enu+66y34+MzNTkyZNkpeXl3x9fTVr1iydPn3aYcyFCxcqMDCw2bWaxr5w4YI9FhgYqOnTpys7O1vh4eHy9PTU6NGjtW3btk691127dikuLk4xMTEaMWKEdu3a1an+ANAd+tM8/L+GDh0qSXJ15V4s/h/FPNAJmZmZ8vDw0Ny5czvdNzk5Wbm5uVq1apXWrl0rSdq7d6/i4+NVV1en3/3ud3rmmWd07NgxPfTQQ/rqq69M53nu3DnNmjVLU6dO1UsvvaQhQ4boySef1Jtvvtmh/p9++qlOnTqlRx99VBaLRfPnz9fbb7+tb7/91nROAOAM/WUebmhoUFlZmcrKynTx4kUdO3ZMq1at0j333KMpU6aYzgt3IANAhw0ePNgIDQ1tFq+urjauXLlif5WXl9vP/eUvfzEkGZMmTTLq6+vt8bq6OsPPz88IDg42qqur7fG8vDzDxcXFSExMtMd+8pOfGAEBAc2u2zT2+fPn7bGAgABDkrFnzx57rKamxhg3bpzh5+fnkENrli9fbgwePNioq6szDMMwPvnkE0OScfjw4Xb7AkB36g/zcGRkpCGp2Wv06NHGF1980WZf9D/cmQc6obKyUgMHDmwWX7FihYYMGWJ/xcXFNWuTlJTk8NVobm6uSktL9fOf/9zh694JEyYoOjpa7777ruk8hw4dqvnz59uPrVarkpKSVFpaqk8++aTNvo2Njdq7d6/mzJkjNzc3SVJ4eLjGjBnDUhsAPa4/zMOSZLPZdPToUR09elRHjhzRpk2bJEkxMTHsaAMHFPNAJwwcOFBVVVXN4snJyfZJ12aztdh35MiRDsdN6ytDQkKatR07dqyuXbumiooKU3mOGjVKLi6O/72DgoIkqd2vjbOyslRSUqIf/OAHunDhgv0VHR2tjIwMVVdXm8oJAJyhP8zDkuTl5aXo6GhFR0dr2rRpWrJkibKzs3X58mWtXLnSVE64M/EEBdAJ3/ve95SXl6e6ujq5u7vb48HBwQoODpakVndH8PT07PB1jP/bMdZisTj8+78aGhpajLfU3ujgLrRNd99//OMft3j+4MGDSkxM7NBYAOBs/WEebo3NZlNISIhycnK6NA7uLNyZBzohNjZWNTU12rdvX5fHatoVoaCgoNm5goIC+fr6ytvbW5I0aNAglZeXN2t36+4Jtzpz5owaGxsdYl9++aUkKSAgoNWcbty4oYMHD2rOnDk6ePBgs9fYsWNZagOgR93p83B7Ghoa+IYUDijmgU546qmnNGLECC1fvlwnT55ssU1H77yEh4fLz89Pmzdv1s2bN+3xEydO6OjRo5o5c6Y9Nnr0aFVUVCg/P98eq66ubnVXhMuXLzv8oqutrdXWrVs1bNgw3X///a3mlJGRoaqqKj311FOKj49v9po7d66ysrJUWlraofcIAM52p8/Dbfnyyy9VWFio0NBQU/1xZ2KZDdAJPj4+yszM1IwZM3T//fdrzpw5evDBB+Xl5aWSkhK98847On36tCIiItody83NTevXr9eCBQv08MMPKzExURUVFUpPT5evr69SU1PtbR977DH9+te/1iOPPKJnn31W9fX1euONNzRs2DAVFRU1G3vMmDFasmSJ8vPzZbPZtGfPHp08eVLbt29vc3/iXbt2ycfHR5MnT27xfGxsrFJSUrR371798pe/7MBPDACc606fh5tUVlbavwltaGjQ+fPntWXLFjU0NCglJaUTPzHc8XpyKx2gryorKzN+85vfGKGhocZdd91luLu7GzabzZg9e7aRkZFhNDY22ts2bVv2wQcftDhWRkaGMXHiRMPDw8Pw8fEx4uPjjcLCwmbtsrKyjNDQUMPNzc0IDAw0NmzY0OqWaNOmTTOysrKMsLAww2q1Gt/97neNLVu2tPmeLl26ZLi6uhoJCQlttrPZbEZYWFibbQCgu92J83CTlram9PHxMWJiYoycnJxO/Zxw57MYRhefxgDQqwQGBiokJERHjhzp6VQAoF9iHsbtxJp5AAAAoI+imAcAAAD6KIp5AAAAoI9izTwAAADQR3FnHgAAAOijKOYBAACAPopiHgAAAOijKOYBAACAPopiHgAAAOijKOYBAACAPopiHgAAAOijKOYBAACAPopiHgAAAOij/gMx+GkQhGglZgAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(7.0, 4.4))\n", + "groups = [\"Group A\", \"Group B\"]\n", + "xs = np.arange(2)\n", + "w = 0.35\n", + "v_costs = [avg_A_v, avg_B_v]\n", + "f_costs = [avg_A_f, avg_B_f]\n", + "ax.bar(xs - w / 2, v_costs, w, label=\"vanilla Sinkhorn\", color=\"C0\", alpha=0.85)\n", + "ax.bar(\n", + " xs + w / 2,\n", + " f_costs,\n", + " w,\n", + " label=rf\"fairness ($\\Delta\\leq{delta}$)\",\n", + " color=\"C2\",\n", + " alpha=0.85,\n", + ")\n", + "ax.set_xticks(xs)\n", + "ax.set_xticklabels(groups, fontsize=11)\n", + "ax.set_ylabel(\"average matching cost subjected by the group\")\n", + "ax.set_title(\"Cost equity: vanilla vs fairness-constrained Sinkhorn\")\n", + "ax.legend(loc=\"upper left\")\n", + "for i, (vc, fc) in enumerate(zip(v_costs, f_costs)):\n", + " ax.text(i - w / 2, vc + 0.005, f\"{vc:.3f}\", ha=\"center\", fontsize=9)\n", + " ax.text(i + w / 2, fc + 0.005, f\"{fc:.3f}\", ha=\"center\", fontsize=9)\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "f6fedb80", + "metadata": {}, + "source": [ + "**Reading the figure.** Vanilla Sinkhorn produces a large average-cost gap between the two groups : group B is subjected to costs roughly five times higher than group A, simply because the cost structure favours A. With the fairness constraint, the algorithm is forced to redistribute matches: group A's average cost rises (it gets some mediocre matches now), group B's drops (it receives some of the previously A-only good matches), and the total transport cost is a few tens of percent above the unconstrained optimum : the **price of fairness** that the user can dial in by choosing $\\delta$.\n", + "\n", + "The dual variable $\\alpha$ on the fairness constraint is informative on its own: it is the *marginal cost* of tightening the fairness gap by one unit. We make this *price* interpretation precise in §9 below." + ] + }, + { + "cell_type": "markdown", + "id": "bf6c56d8", + "metadata": {}, + "source": [ + "## 9. Sensitivity analysis: $\\alpha$ as a free gradient\n", + "\n", + "A small but consequential observation: at the optimum, the constraint dual $\\alpha$ already encodes the **derivative of the optimal cost with respect to the constraint threshold**, no autodiff call required. This is the *envelope theorem* applied to our Lagrangian.\n", + "\n", + "If we write the constrained problem with a parameterised threshold,\n", + "$$\\mathrm{cost}^\\star(t) \\;=\\; \\min_{P \\in U(a,b)} \\langle P, C \\rangle \\quad\\text{s.t.}\\quad D \\cdot P \\le t,$$\n", + "then by the envelope theorem\n", + "$$\\frac{d\\, \\mathrm{cost}^\\star}{dt} \\;=\\; -\\alpha^\\star_{\\mathrm{phys}}(t) \\;=\\; -\\alpha^\\star(t)\\,/\\,n,$$\n", + "where the factor $1/n$ comes from the way we normalised $D$ in the homogeneous form (recall: $D_{\\mathrm{solver}} = (t\\,\\mathbf{1}\\mathbf{1}^\\top - D_{\\mathrm{phys}}) / n$). Let's verify this empirically by sweeping $t$ and comparing the analytical $-\\alpha/n$ value with a finite-difference numerical derivative." + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "be516613", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-02T06:40:35.237516Z", + "iopub.status.busy": "2026-05-02T06:40:35.237362Z", + "iopub.status.idle": "2026-05-02T06:41:25.147728Z", + "shell.execute_reply": "2026-05-02T06:41:25.146450Z" + } + }, + "outputs": [], + "source": [ + "# Sweep the threshold and record both cost and alpha.\n", + "ts_env = np.linspace(0.12, 0.45, 9)\n", + "costs_env = []\n", + "alphas_env = []\n", + "\n", + "# warmup at a representative t\n", + "_ = constrained_sinkhorn(\n", + " C_j,\n", + " a_j,\n", + " b_j,\n", + " jnp.array(((0.3 * np.ones((n_a, n_a)) - DI_orig) / n_a)[None, ...]),\n", + " jnp.zeros((0, n_a, n_a)),\n", + " eps=eps_run,\n", + " n_iters=10,\n", + " n_newton=5,\n", + ")\n", + "\n", + "for t in ts_env:\n", + " D_t = (t * np.ones((n_a, n_a)) - DI_orig) / n_a\n", + " res_t = constrained_sinkhorn(\n", + " C_j,\n", + " a_j,\n", + " b_j,\n", + " jnp.array(D_t[None, ...]),\n", + " jnp.zeros((0, n_a, n_a)),\n", + " eps=eps_run,\n", + " n_iters=250,\n", + " n_newton=10,\n", + " )\n", + " costs_env.append(float(jnp.sum(res_t.matrix * C_j)))\n", + " alphas_env.append(float(res_t.alphas[0]))\n", + "\n", + "costs_env = np.array(costs_env)\n", + "alphas_env = np.array(alphas_env)\n", + "dcost_dt_num = np.gradient(costs_env, ts_env)\n", + "dcost_dt_thy = -alphas_env / n_a # divide by n: physical units" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "73061275", + "metadata": { + "execution": { + "iopub.execute_input": "2026-05-02T06:41:25.151040Z", + "iopub.status.busy": "2026-05-02T06:41:25.150369Z", + "iopub.status.idle": "2026-05-02T06:41:25.420972Z", + "shell.execute_reply": "2026-05-02T06:41:25.419885Z" + } + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAABVIAAAG/CAYAAAC61iwxAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAQ6wAAEOsBUJTofAABAABJREFUeJzs3Xd4FFUXwOHfbnpvhBQICQRCD733jnQQEVSkiCjS7aJAAAVFRZEiKGAQpCm9Cwih9470GkgCKSQB0nfn+yPfriy76QkJcN7n4dGduXfmzN3JlLMz96oURVEQQgghhBBCCCGEEEIIkSF1YQcghBBCCCGEEEIIIYQQRZ0kUoUQQgghhBBCCCGEECILkkgVQgghhBBCCCGEEEKILEgiVQghhBBCCCGEEEIIIbIgiVQhhBBCCCGEEEIIIYTIgiRShRBCCCGEEEIIIYQQIguSSBVCCCGEEEIIIYQQQogsSCJVCCGEEEIIIYQQQgghsiCJVCGEEEIIIYQQQgghhMiCJFKFEEIIIYQQQgghhBAiC5JIFUIIIYQQQgghhBBCiCxIIlUIIYQQQgghhBBCCCGyIIlUIUSu3bhxA5VKRf/+/Qs7FANFNS6RuRkzZlC5cmVsbW1RqVQEBQUVdkg59rzte8/DdyKEeL4FBwejUqkIDg7OUT2VSkXz5s3ztO7+/fujUqm4ceNGnpbzosvtd5hXunN2XvcD8XT4+fnh5+dX2GFkqLDiy8u1Z9u2bSlfvjwajSb/AytkL8LxWY5huXP48GFUKhXz5s3L9TIkkSrEM+rkyZMMGjSIcuXKYWtri729PZUqVWL48OFcuXIlX9axa9eu5yopVNS9yO29bNkyRowYQWpqKsOHD2f8+PFF9qLgRfmecvudzJo1C5VKxS+//FLwQQoh8kyr1TJ//nyaN2+Oq6srFhYWuLu7U6VKFfr378+yZcsKO8RcUalUhZLUeFHOEc8CSTIUHF2SateuXYUdSo49y7Hnp7Vr17Jt2zYmTpyImZlZYYcjTJBjWMGoW7cuXbt25YsvviA+Pj5XyzDP55iEEAVMURTGjh3L5MmTUavVtGrVim7duqHVajly5AgzZ85kzpw5TJ8+nffee69AYylRogTnz5/HycmpQNcjnn8bNmwA4Pfff6d+/fqFHE3uPU9/E7n9To4ePQpArVq1CiQuIUT+0Wq1dOnShY0bN+Lk5ESnTp0oWbIksbGxXL16lb/++ouDBw/Su3fvwg41Q927d6d+/fp4eXnlqN758+extbXN07qnTJnCp59+SokSJfK0HCFE1nbs2FHYITxXxowZg7+/P7169SrsUIR46j777DPq16/P9OnTGTt2bI7rSyJViGfM5MmT+eqrr/Dx8WHdunVUr17dYP7OnTt5+eWXGTp0KE5OTrz++usFFouFhQUVKlQosOWLF0dYWBgAnp6ehRxJ3jxPfxO5/U6OHTuGhYUFVatWLYiwhBD5aNmyZWzcuJFq1aoREhJi9CNQUlISe/bsKaTossfJySlXP17lx7Hay8srxwlcIUTu+Pv7F3YIz40dO3bw77//MmHCBFQqVWGHI8RTV69ePSpUqMDcuXP57LPPMDfPWWpUXu0X4hly8+ZNJkyYgLm5OevXrzdKogK0aNGCRYsWATBy5EgePnwIGL4acOfOHd544w3c3d2xsbGhdu3aLF++3GA5QUFBtGjRAoCFCxeiUqn0/3R9WGXUJ8/j67p79y4DBw7Ew8MDOzs7GjZsqL8pe/jwIe+//z6lSpXCysqKypUr8+effxptU3BwMC+//DJlypTBxsYGR0dHGjVqxOLFi/PSnEaOHDlC7969KVGiBJaWlnh6etKyZUsWLlxoVHbVqlW0aNECZ2dnrK2tqVixIuPGjdO39+PWr19P69at8fb2xsrKCg8PD+rVq8fkyZP1ZbLT3pnJaRtlJ6b8XqcpQUFBqFQqdu7cCUDp0qX12w3/vSKZUd+czZs3N7oAfHz/i4qKYvDgwXh5een3sd9++83ksrLz/Wf1PWXVT1V295vcbkNWsrP+rL6TjHz++eeoVCrOnDlDamoqVlZW+nozZ87MVbxCiIK1b98+IP1VV1PJSGtra9q0aWOy7s6dO+nSpQvu7u5YWlri6+vLe++9R0REhFFZ3bH6xo0bzJ07l6pVq2JtbY2HhweDBw8mLi7OqM7p06d57bXXKF26NNbW1ri6ulK5cmWGDBliUP7J/jV15w1Iv256/Dj9+LH5ydcl3333XVQqFStXrjS5vRcuXEClUtGkSRP9tCf74MvqHLFx40ZUKhUDBgwwuQ6tVouvry+2trbcv3/fZJknXblyhUGDBuHr64uVlRXu7u50796d48ePmywfGRnJu+++i5eXF9bW1lSqVImffvqJ69evmzx/mTrP6mR0jj527BgjR46kWrVquLq6Ym1tTbly5fjggw+yvV2Q/X3AlODgYEqXLg1ASEiIwXdh6poip+fanOz/WVmxYgVt2rTBzc0NKysrfH19eeWVV9i7d69BOUVRmDdvHvXr18fBwQFbW1tq1KjB999/T2pqqtFy/fz8UKlUpKWlMXnyZMqVK4eVlRU+Pj588sknpKSkGNXZs2cPXbp0wcfHB0tLS4oVK0aNGjUYPXo0iqLol6u7NmrRooVB2+o8/vr8okWLqFOnDnZ2dvp7l5SUFGbOnEmHDh30+66rqyutW7dm8+bNJtvJVB+kur//oKAgTp48SceOHXF2dsbW1pZmzZqxf/9+o2VkFXtmUlNT+frrrwkICNC35ciRI3nw4IHJ8rprqoy6ETD12nZYWBgTJ06kUaNGeHp6Ymlpibe3N6+99hrnz5/PVpxZ0fUN+eqrr2ZYZuXKlbRu3RpXV1esrKwoV64cn376qclXoXOyr925cwczM7NMf3Dv3bs3KpXK6Cnk48eP07t3b7y9vbG0tMTLy4u+ffvmuFu7nNzH6Y6B165d47vvvqN8+fJYW1vj4+PDBx98kOF3HxERwahRoyhXrhzW1ta4uLjQtm3bbD9ZXdSOYdeuXWPAgAGULFkSS0tLPDw86NWrF6dPnzYZuy7OgwcP8tJLL+Hi4oJKpSI2NhZIP9/NmzePRo0a4eTkhLW1NVWqVGHKlCkmj0267noePnzI6NGj8fHxwcbGhurVq7NmzRoA0tLSmDRpkr7N/f39M70H6d27N3fu3GHr1q0ZlsmIPJEqxDNkwYIFpKam8sorr1CtWrUMy3Xs2JHatWtz9OhR/vrrL4OL4vv379OoUSNcXFwYOHAg9+/fZ8WKFfTu3ZuwsDBGjx4NpJ80bty4wcKFC6lWrRrdunXTL8NUAteU2NhYGjVqhKurK6+//jq3b9/mr7/+ol27dhw4cIC3336bhw8f0q1bN+Lj41m6dCmvvvoqPj4+Bq8SDxkyhEqVKtG0aVO8vLyIjo5m06ZN9O3bl4sXLzJp0qQctaMp8+bN491330WtVtO5c2fKly9PVFQUx48f58cff6Rfv376suPGjWPSpEm4urry6quv4uzszLZt25g0aRLr1q1jz549ODg4APDLL7/wzjvv4OHhQadOnShevDhRUVH8+++/zJkzhzFjxuRLe+ekjbIbU36uMyO6i8fg4GBu3rzJyJEjcXZ2ztb6s6Lb/ywtLenZsydJSUn89ddfDBw4ELVabfCdZvf7z8v3lJP9JjfbkJXsrj+330mNGjV46623mD9/PrVr16Zjx476ee3atct2nEKIp8fV1RWAS5cu5ajeN998w6effoqrqysdO3bE09OT06dP8/PPP7Nu3ToOHjxIyZIljep9/PHHbN26lc6dO9O2bVt27tzJr7/+ypUrV/jnn3/05U6fPk29evVQqVR06tQJf39/Hj58yPXr11m4cCEffPBBhk+h+vn5MX78eCZMmICTkxOjRo3Sz8vsON2/f3/mzp2r/5HwSbrkS2Z9n2Z1jggMDKR06dIsX76cH374wejYunnzZm7dukX//v1xcXHJcD06//zzD127diUpKYlOnTpRrlw57ty5w6pVq9i8eTNr1641OP7GxMTQsGFDrly5Qt26denXrx/R0dGMHTtW/+NZfvj1119ZvXo1zZo1o3Xr1mg0Go4fP860adPYvHkzhw4dMjrfPSkv+wCkt/fIkSOZPn06vr6+Bt/bk4mrnJ5rc7v/P0lRFAYMGMDChQtxc3Oja9eueHh4cOfOHXbv3s1ff/1F48aN9eX79evHokWLKFGiBAMGDMDCwoL169fz4Ycf8vfff7Nx40aTT1W99tpr7Nmzh5deeglHR0c2bdrE1KlTuXfvnkGiZcuWLXTs2BEHBwe6dOlCyZIluX//PpcvX2bmzJl8++23mJubM2rUKIKDgzl16hT9+vXLtC/i7777jh07dtClSxdatWqlT5DExMQwcuRIGjZsSJs2bXB3dyc8PJz169fToUMHfv31VwYNGpRlG+ocPXqUqVOn0qBBAwYNGsStW7dYuXIlrVq14uTJk5QvXx4gR7Gb0qdPH1auXImvry9DhgxBrVazevVqDh48SEpKCpaWljlanim7d+/m66+/pkWLFrz88svY29tz+fJl/vrrL9atW8e+ffsyvQ/MiqIo7Nixg2LFiunb5UnvvfceP//8MyVLlqR79+64uLhw8OBBvvnmGzZt2sS+fftM/g1nZ18rUaIEbdq0YevWrRw7dsyoK6i4uDjWrl1LqVKlaNmypX76H3/8Qf/+/bG0tNQn+69cucLSpUtZv349u3btytY9U26uxyF939m7dy+9evXCycmJzZs3M23aNPbu3cvu3buxsrLSlz1z5gxt2rTh3r17tG3blq5duxIdHc2aNWto06YN8+bNY+DAgZnGWZSOYcePH6dVq1bExsbSsWNHAgMDuXr1KqtWrWL9+vWsXbuWtm3bGm3D/v37mTx5Mk2bNmXQoEGEh4djZmZGWloaPXr0YP369QQEBNCnTx+sra0JCQlhzJgx7Nixgy1bthgdz1JTU2nTpg1xcXF0795dnz94+eWX+fvvv5kxYwYnT56kffv2ACxdupThw4fj7u5u8keDRo0aAbBt2zaD+5ZsUYQQz4yWLVsqgPLLL79kWfazzz5TAOWtt95SFEVRrl+/rgAKoPTq1UvRaDT6sleuXFGcnJwUS0tL5caNG/rpO3fuVAClX79+JtehW+aT8x9f14gRIxStVquf99VXXymA4uzsrPTo0UNJTk7Wz/vjjz8UQOnWrZvB8q5cuWK07uTkZKVly5aKubm5cvv27WzFlZFz584p5ubmipOTk3L69Gmj+bdu3dL//4EDBxSVSqWUKFFCuXPnjn66VqtV3nzzTQVQhg4dqp9es2ZNxdLSUomIiDBabmRkpMHnrNo7Mzlpo5zElF/rzEqzZs0UQLl+/brBdF2bjB8/PtN6j3t8/3vrrbeUtLQ0/bxz584pZmZmSsWKFQ2mZff7fzwmU99TRvteTvebnG5DVnK6fkXJ+DvJTHBwsAIo06ZNy3YdIUThOX78uGJhYaGoVCrl9ddfV/7880/l2rVrBuftJ4WEhCgqlUqpX7++cv/+fYN5v//+uwIoPXr0MJiuO574+PgoN2/e1E9PTU1VmjRpogDKoUOH9NPff/99BVBWr15ttP74+HglKSlJ//m3335TAOW3334zKAcovr6+GW4HoDRr1sxgWsWKFRVzc3Oj86NGo1FKliyp2NraKvHx8frp/fr1MzpOZnUunzp1qgIoP/74o9G8Tp06KYBy8ODBDOPWiY2NVdzc3BRXV1fl3LlzBvP+/fdfxd7eXvHy8jJoq3feeUcBlHfffdegvO460FTcps6zT27rk+foGzduGJy3dObNm6cAytdff20w3dR3mJN9ICO6c+mT3/OT83Nyrs3N/p+RuXPnKoBSq1YtJSYmxmCeRqMxuI5atmyZAiiBgYFKXFycfrruugtQvv32W4Nl+Pr6KoBSs2ZNJTo6Wj/94cOHir+/v6JWq5Xw8HD99B49eiiAcuLECaNYo6KiDD7r9v2dO3ea3DbdfFtbW5PLS0pKUkJDQ42mx8bGKpUrV1ZcXFyUhIQEo+158m9at++YOgbMmTNHAZQhQ4bkKPaMLF26VN+eDx8+1E9/9OiRUqdOHZPHnPHjx2e6LlP75927dw2OMzonT55U7OzslPbt2xtMz+l9z/nz5xXAaDk6ixYtUgCle/fuRt/BpEmTFEB5//33DabndF/T7c/Dhg0zWr/u72Ls2LH6aZcvX1asrKyUMmXKGN1f7Ny5UzEzM1Nq1qxpMN3U8Tkv18Nubm4G56+0tDSla9euCqBMnjzZYHpAQIBiZWWl7Nq1y2BZYWFhSsmSJRUbGxvl7t27Rtv+pKJwDNNqtUqlSpUUQAkODjYov23bNkWlUinu7u7Ko0eP9NMf/7ucO3euUdy6/Wjo0KEGMWs0GuXtt99WAOWnn34yqKNbXkb5A2dnZ6VBgwYGfzv79u1TAKV69eom2y82NlYBlBo1apicnxlJpArxDKlYsaICKJs3b86y7OzZsxVAeemllxRF+e9Aa2Zmply7ds2ovC7x+uWXX+qn5TWRamdnZ3ChoSjpSSndgfDJBE1aWppiYWGh+Pn5Zbl9iqIoK1euVABl4cKF2YorI8OGDVMAZerUqVmWHTRokAIos2bNMpoXERGh2NjYKHZ2dkpKSoqiKOlJS1tbW4OLiozkJZGaEVNtlJOY8mudWSmIRKqtra3BzYZO06ZNFUB/os3J9/94TDlJpOZ0v8npNmQlp+tXlNwlUkeMGKEASkhISLbrCCEK1/LlyxVPT0/9uRlQnJyclA4dOihLly41+OFVUf5Ltpw6dcrk8mrUqKGYmZkZHJ90x5Nff/3VqPyCBQsUQJkxY4Z+mi6JtmXLlizjz89E6tdff60Ayvfff28wfevWrQqg9O3b12B6bhKpUVFRirW1tdHN7a1btxQzM7Ns39D99NNPCqBMnz7d5PzRo0crgLJx40ZFURQlJSVFsbOzU+zs7IySYoqiKOPGjcu3RGpGtFqt4ujoqLRo0cJgemaJ1OzsAxnJbhIiJ+fa3Oz/GalSpYoCKIcPH86ybOvWrQ2+z8edOnVKAZSAgACD6brk1rZt24zq6L7v9evX66fptu3ChQtZxpPdROqoUaOyXNaTvv/+e5PXEpklUhs1amS0nJSUFMXc3FypVatWjmLPSJs2bTLcJ//55598S6RmpnPnzoqVlZXB9VpO73t0x7MBAwaYnF+zZk3FzMzM5H1CWlqaUqxYMaV48eIG03O6ryUlJSnOzs6Km5ubQVJMURSlQYMGCmDwwIbueLB27VqTMXfv3l0BDH5UMnV8zsv18MSJE43qXLhwQVGpVErZsmX109atW6cAyujRo03GOn369AxjeFJROIbt3btXAZQ6deqYLK9b3pIlS/TTdH+XphKYGo1Gvw+lpqYazY+NjVVUKpXR+rLKH2T0d1a6dGnFzMzM5I97iqIo1tbWipubm8l5mZFX+4V4hij/75soJ52CP1m2VKlS+v5WHtesWTOmTJnCiRMn8hbkY8qVK4ednZ3BNN2gDM7Ozkav05iZmVG8eHFu375tMP3WrVt88803bN++ndDQUBITEw3m37lzJ09xHjx4EICXXnopy7K6Pscef9VEx8PDg6pVq3L48GEuXbpE5cqV6du3L6NHj6ZSpUq8+uqrNG3alIYNG+b74BQ5aaP8iqmgv5e8KleuHI6OjkbTfXx8gPTXYBwcHHL0/edWTvcbnexuQ0GtP6eOHTuGSqWiRo0auV6GEOLp6tWrF927d2fnzp3s3buXEydOsHfvXjZt2sSmTZsIDg5m3bp1+ldW9+3bh7m5OatWrWLVqlVGy0tOTkaj0XDp0iWjVzZr165tVF53PHu878w+ffrw008/0a1bN15++WVatWpFgwYNCnwwv759+/L5558THBzM+++/r5+endf6s8vNzY1XX32VhQsXEhISQrNmzYD0LmY0Gg3vvvtutpaj69/29OnTJvvLu3jxIgDnz5+nQ4cOXLhwgUePHlG/fn3c3NyMyjdv3pyJEyfmcqsMpaamMnfuXJYtW8a5c+eIj49Hq9Xq52fn+uBp7gM5OdfmZf9/3KNHjzh79ixubm7UqVMnyxh153FdH7yPCwwMpHjx4ly6dImHDx9ib29vMD+7f3d9+/Zl1apV1KtXj169etGiRQsaNGiQ49ffH1e3bt0M5507d45vv/2W3bt3ExYWRnJyssH8nFxHmtpGCwsLPDw8ctQvb2Z01zi6v9nHNWnSBDMzs3xZD8DGjRv5+eefOXbsGFFRUaSlpRnMj4qKyvW9RHR0NIDJ7kMSExM5ceIErq6u/PTTTybrW1paEhYWRnR0tNGxJLv7mpWVFb1792bOnDls2LCBHj16AHD58mUOHDhAkyZNDAYX0x3vdu/ebbL/57t37wLpx7tKlSpluO15uR429b2XL18eDw8Prly5woMHD3BwcNDHeuvWLZPH5suXL+tjzS8FeQzLrM0AWrduzapVqzh+/Dh9+vQxmGfq7//SpUtERUXh7+/Pl19+aXKZNjY2Jtsns/zBnTt3TN5/eHt7c/36dSIiIihRooTRfFdXV8LDw9FqtajV2R9CShKpQjxDvLy8uHDhArdu3cqybGhoqL7O4zw8PEyW103PqvP+nDDVd5Wur5OM+rUyNzc3uFi4du0adevW5f79+zRp0oR27drh5OSEmZmZvh+yJy+8ckrX6bWpg+uTdO2T0UjmuvbWLXPUqFEUL16c2bNnM2vWLP1FSf369ZkyZYpRHze5kdM2yo+Ynsb3klcZ9eup2wc1Gg2Qs+8/t3K63+hkdxsKav05odVqOXnyJOXKlctWclcIUXRYWFjQtm1bfR9nGo2GlStXMnDgQLZu3crPP//MyJEjgfSb8LS0NCZMmJDpMk0N2mHqmGbqeFa7dm327dvHV199xerVq/njjz+A9P5PP/nkk2wnG3PK29ubtm3bsnnzZo4fP07NmjWJj49n9erV+Pr6mkxi5cbQoUNZuHAhc+fOpVmzZmg0GubPn4+DgwOvvfZatpahS4bMnz8/03K67yG754H88Oqrr7J69WrKlClDt27d8PT01Pcf+OOPP2br+uBp7gM5OdfmZf9/XE6vPeLi4nBycsLGxsbkfC8vL+7du0dcXJxRIjW7f3fdunVjy5YtfPfddyxcuJBff/0VgCpVqhAUFGSy7+CsZLS/HTx4kJYtW5KWlkarVq3o0qULjo6OqNVqTp48ydq1a3N0HZnZd5jda6Ws6L4Da2trk+spVqxYvqznp59+YuTIkbi4uNCmTRtKlSqFra0tKpWKNWvWcOrUqTxdY+v2oaSkJKN5MTExKIpCdHR0tvbxJxOp2d3XAAYMGMCcOXMIDg7WJ1J1P1o9OSCf7nj3/fffZxlTZvJyPZzZPXRERATx8fE4ODjoY125cmWGgxdmJ9acKMhjWF7azFQdXftcvXo1yxielFn+IKP5unmmBuSD9B8PrKyscpREBUmkCvFMady4MTt37mTbtm28/fbbmZbdvn27vs7jdL/YPUk3PbOO+wvDtGnTiI6OZsGCBUYn1aVLlxqMqJ5bupPPnTt3shzcQdc+ERERJk9a4eHhBuUgveP11157jfj4eA4cOMD69ev59ddfeemllzh16hQBAQF5ij83bZTXmJ7G9wLoT2pP/hKvk5fEn05Ovv/cys1+86yt/+LFizx69CjTJ3CEEM8GMzMzevXqxZkzZ/jyyy/ZsWOHPpHq5ORESkqKyZGb81PdunVZu3YtKSkpnDhxgr///puZM2cyZMgQ7Ozs6Nu3b4Gst3///mzevJng4GBq1qzJihUrSExMpF+/fjl6IygzderUoU6dOqxcuZLIyEj27dvHnTt3GDJkiFESLCO64/WxY8eoWbNmtstnNCKz7jzwpMfPw08O/GHqHHz06FFWr15Nq1at2Lx5MxYWFvp5Wq2WqVOnZhmrTmHtA5nJr/3/8WuP7K43JiaGxMREk8nU/LqOaNeuHe3atSMxMZHDhw+zefNmZs+ezSuvvMLOnTtNPpWXmYz+Zr788ksSExP5559/jH6gmDJlCmvXrs31NhQUJycn7t+/T1JSklEyNS0tjaioKKNBxjK7jjX195OWlsb48ePx9PTk+PHjRj9wHDhwII9b8V9CUJfQepxu/6latarJ0djzU926dalUqRKbN2/m3r17uLu7s2jRIuzs7HjllVdMxhUdHa0fJDE38nI9fPfuXZODc+nuoXVPhOrqrly5Up8gLkpyegzL7rnDVJuZ+vvXlevcuTPr1q3LVgwFRavVEhsbq39yNydylnYVQhSqAQMGYG5uzpo1azhz5kyG5TZv3syRI0dwdXWlZ8+eBvNu3brFjRs3jOqEhIQAGDwSr3tFJb9+yc2NK1euABhtB/wXc17Vr18fSG+3rOhuVkyNbhsZGcnZs2exs7MzeaJ1dHSkXbt2zJw5kw8++ICkpCS2bNmin5/b9s5LG2UVU0GsMyd0iU3dE9aPi4uLy/FI06bk5PuH3H1Pedlv8sPTWL+uWxB5rV+I54fu6XJd10IADRo04MGDB5w6deqpxGBpaUm9evUYO3YsixYtAmD16tVZ1lOr1bm6funatSsuLi4sXbqU1NRUFi5ciEqlMhr1OCPZPUe89957pKSk8NtvvzF37lyAHD1l2aBBAwD27NmTrfIVKlTA1taWM2fOEBMTYzR/165dJutldh4+cuSI0TTd9UHXrl0NkqgAhw8fNuoGKDtyuw8UxHVsfu3/dnZ2VKlShejoaJPt+CTdedzU93T27Fnu3btHQEBAthPxWbGxsaFZs2Z8/fXXfPfddyiKYpDczGvbXrlyBVdXV5NPeefndaQpuY29Vq1aKIpiMr49e/aYXF5O/36ioqKIjY012eXWw4cPTb7WnlNVqlTBzMyMCxcuGM2zt7enSpUqXLhwgaioqDyvKyv9+vUjLS2NP/74g3/++Ydbt27x8ssvG+3HOT3eZSQv18OmvveLFy9y9+5dypYtqz9f5lesUDSOYZm1GcCOHTsAsv0gRYUKFXB2dubQoUOkpKRkq05BuXjxIoqiUL169RzXlUSqEM+Q0qVL88UXX5Camkrnzp1N/lIYEhLCG2+8AaS/GvLkiUij0fDJJ58Y9FV19epVZs+ejYWFBa+//rp+uu51jex0JVBQdP2gPHnhuHXrVubNm5cv6xgyZAjm5uZ89dVXnD171mj+4322Dhw4EIDJkycb/DKnKAoff/wxCQkJ9OvXT3/zsHnzZpOvEujqPv5UQW7bO6dtlJOY8muduVWhQgWcnJxYs2aNQXunpaUxatSoXN2QPSkn3z/k7nvK6X6T357G+u/duweY7nNLCFE0LV26lG3bthlcE+hEREToX+1t2rSpfrqu79DBgwcbHR8h/XXRvXv35imuPXv2mHxaKyfnKTc3NyIjI3N8nrCysqJPnz5ERUUxffp09u7dS9OmTSlTpky26mf3HNG7d2/c3Nz48ccf+fvvv2nYsCGBgYHZjnPAgAG4uLgwadIkk0+pKYrC3r179TeqFhYW9O3bl0ePHvH5558blL169SrTp083uZ569eoB8PPPPxtMP3nypMk6GV0f3Lt3j6FDh2Zr2yB/9gEXFxdUKpXJJFZu5ef+P2LECCD9OuTJbVUUhbCwMP3nt956C4AxY8YYvBKcmpqqj2nQoEHZ3xATtm/fTkJCgtH0/Lxm1fHz8yMmJsboXmb+/Pls3bo1V8vMrtzGrruW+vzzz3n06JF+ekJCAp988onJOrq/n/nz5xtce0dFRfHRRx8ZlS9evDi2trYcO3bM6HseOXJkviQ3HRwcqFWrFmfPnjXYDp0PPviA1NRUBgwYYPJHlwcPHnDo0KE8xwHp/fKamZkRHByc4Wv9AMOGDcPS0pIPPvjAZAJYo9Fk+GPQ4/JyPTx9+nSDfUZ3T60oikHMXbp0oWzZssyZMyfDJy5PnDhh8ongJxWFY1jDhg2pWLEihw8fZvHixQZl//nnH1atWkWxYsXo2rVrttZvbm7OyJEj9ecEU8ecqKgoTp48mYOtyh3dOBm56bZHXu0X4hkzbtw4EhMT+eabb6hZsyatW7cmMDAQrVbL0aNHCQkJwdzcnJkzZxokRXUCAwM5dOgQtWvXpm3btsTExLBixQri4uKYNm2aQQfO5cuXx8fHhz179vD6668TEBCAmZkZXbp0ydHFfl689957/Pbbb7zyyiu8/PLLlChRgrNnz7JlyxZ69erF8uXL87yOSpUqMXv2bN59911q1apF586dCQgI4P79+5w4cYLk5GT903YNGjTgs88+Y8qUKVSpUoVXXnkFJycntm3bxvHjx6latSqTJ0/WL7tPnz5YWlrSpEkT/Pz8UKlUHD58mD179uDv70+vXr30ZXPb3jlto5zElF/rzC0LCwtGjx5NUFAQNWrU0L8is3PnThRFoVq1anl+KiQn3z9k/j2Z6ugdcr7f5LensX7dL9Hjx4/nwoUL2NvbU65cOZPHISFE0XDo0CGmT5+Op6cnjRs31g9Gef36dTZu3EhiYiINGjRg2LBh+jotWrTg22+/5ZNPPqFcuXJ06NCBMmXKkJiYyK1bt9i9ezd+fn55ugn6/vvv2bp1K82bN6dMmTI4Ojpy6dIlNmzYgI2NDaNGjcpyGW3atGHJkiW0b9+epk2bYmVlRbVq1ejcuXOWdfv378/s2bMZM2aM/nN2Zfdcbm1tzcCBA/n222+BnD2NCukDZKxcuZJu3brRsGFDWrZsSeXKlbGwsCA0NJRDhw5x69Yt7t+/rx8obPLkyezYsYM5c+Zw4sQJWrRoQVRUFCtWrKBly5asWbPGaD0DBw7ku+++49tvv+X06dMEBgZy7do11q1bx8svv8yyZcsMytepU4dGjRqxatUqGjZsSOPGjbl79y6bN2+mfPnyeHt7Z2v78mMfsLe3p0GDBuzfv5/OnTtTq1YtzM3Nadq0qcGPAzmRn/v/oEGD2Lt3L7///jtly5ala9eueHh4EB4eTkhICF26dOHHH38E0hPv69evZ8mSJVSqVInu3btjYWHB+vXruXTpEq1atcpWm2Tmww8/5Pr16zRv3hw/Pz+sra05ffo0W7duxc3NjcGDB+vLtmnThm+//ZbPPvuMs2fP6n9E/eKLL7K1rlGjRrF161YaN25Mr169cHJy4ujRo+zdu5eePXvy119/5WlbMpPb2Hv37s1ff/3FypUrqVy5Mt27d9f3W+ru7m6yn+G6devSokULdu7cSe3atWndujUxMTFs3LiRVq1aGV3DqtVqRowYwddff03VqlXp2rUrKSkp7Ny5k5iYGP2y8qpnz54cPnyY7du3GyXA+vfvz/Hjx5kxYwb+/v60a9cOPz8/YmNjuXHjBrt376Zt27Ymjxc55eXlRbt27di0aRPnz5/Hz88vw0GdgoODGTBgAFWqVKF9+/YEBASg0WgIDQ1l3759JCcnZ9nlV16uhxs2bEj16tX1++vmzZs5c+YMderU4YMPPtCXs7CwYPXq1bRt25auXbtSr149atasib29PaGhoZw4cYKLFy9y4sQJkwP/Pa4oHMNUKhULFy6kdevWvPnmm6xYsYKqVaty9epVVq5ciaWlJb///ju2trbZjuGLL77gzJkzzJs3T/+3ULJkSSIjI7l69Sp79+5l6NCh+uNfQdm6dStqtZpu3brlvLIihHgmHTt2TBkwYIBSpkwZxcbGRrG1tVXKly+vDB06VLl06ZJR+evXryuA0qxZM+X27dvKa6+9phQrVkyxsrJSatasqSxdujTD9bRu3VpxcnJSVCqVAii//fabwTL79euX4bpMARRfX1+T83x9fZUnD0379u1TWrRooTg7Oyv29vZKo0aNlNWrVys7d+5UAGX8+PEm1/9kXFk5cOCA8vLLLyseHh6KhYWF4uHhobRs2VL5/fffjcquWLFCadq0qeLg4KBYWloq5cuXVz7//HMlPj7eoNzPP/+sdO/eXSlTpoxia2urODk5KVWrVlXGjx+vREVFGS03s/bOTE7aKKcx5cc6s9KsWTMFUK5fv240T6vVKt9++61StmxZxcLCQvH09FTeffddJTo6Wl/vcVntf/369TO5rpx8/xl9T1nte9ndb3K7DVnJ7voVJfPvJCOzZ89Wypcvr1hZWSmAMmLEiBzFJ4R4ukJDQ5XZs2cr3bt3V8qXL684Ojoq5ubmSvHixZVWrVopc+bMUVJSUkzWPXDggNK7d2+lRIkSioWFheLq6qpUqVJFGTJkiBISEmJQNrPjialzxtatW5UBAwYolSpVUpycnBQbGxulbNmyyqBBg5Tz588b1P/tt99Mnivv3bun9O3bV/H09FTUarXRsTmzY6yiKErlypUVQLGzs1MePHhgskxGx+LsnssvXLigAIqbm5uSmJiYYSyZuXnzpjJixAglICBAsba2Vuzt7ZVy5copvXr1Uv744w9Fo9EYlL93754yePBgxcPDQ7GyslIqVqyo/Pjjj8q1a9cyPH+dP39e6dy5s+Lo6KjY2Ngo9erVy/R8Hx0drQwZMkTx9fVVrKyslDJlyiifffaZ8ujRI8XX19foGtDUd5iTfSAzV69eVbp166a4ubnp9wNdvHk51+Zk/8/KH3/8oTRr1kxxcnJSrKyslFKlSimvvPKKsm/fPoNyGo1GmTNnjlKnTh3F1tZWsba2VgIDA5WpU6cqycnJRss1dU2tY6rNly9frvTp00cpV66cYm9vr9jb2ysVKlRQRo8erdy6dctoGdOnT1cqVaqkP+c/vi5d2+3cuTPD7V6/fr1Sr149xd7eXnFyclLatGmjhISEZPg3ndm+k9E1p6k6WcWemZSUFGXy5MlK2bJlFUtLS6VEiRLKiBEjlLi4uAzXFRsbq7z77ruKp6enYmlpqQQEBCjffPONkpaWZnL/S01NVb7//nulYsWKirW1teLh4aG88cYbyo0bN0zuk7m574mKilKsra2Vbt26ZVhm8+bNSpcuXfTXxe7u7kqNGjWUDz74QDl+/LhB2Zzua4/7888/9d9BVvcO586dU9566y3Fz89PsbS0VJycnJQKFSoo/fr1U9avX29QNrO/39xcD1+9elWZOnWqEhAQoP/uR48ebbKOoihKZGSk8vnnnytVq1ZVbG1tFRsbG6VMmTJK586dlV9//VVJSEjIdFt1isox7PLly0q/fv0Ub29v/f7Qs2dP5cSJE0Zls/q7VJT0+7slS5Yobdq0UVxdXfX3ePXq1VPGjh1rlM/Iaf5AJ6Prj9jYWMXa2lrp1KlThjFmRvX/oIQQz7kbN25QunRpmjVrlq1XH4QQQgghRMFZtmwZffr04YMPPuC7774r1Fh014n9+vUjODi4UGMRQhS8YcOG8csvv3D16tVcDbbzomjevDkhISFcv37d4M1N8WybPn06o0aNYteuXTkeRA+kj1QhhBBCCCGEeKo0Gg3ffvstarWa9957r7DDEUK8YIKCgrCzs2PixImFHYoQT1VCQgJff/013bt3z1USFaSPVCGEEEIIIYR4Kvbu3UtISAghISEcP36cQYMGZXsgKyGEyC/FihXjjz/+4MSJE2g0Gv0I8UI8765fv84777yTo/7PnySJVCGEEEIIIYR4CrZv386ECRNwcXFhwIABBT6YhhBCZKRDhw506NChsMMQ4qmqXLkylStXztMypI9UIYQQQgghhBBCCCGEyIL0kSqEEEIIIYQQQgghhBBZkESqEEIIIYQQQgghhBBCZEESqUVAWloat2/fJi0trbBDEUIIIYQQLwi5BhVCCCGEyBlJpBYBERER+Pj4EBERUdihFGkajYaIiAg0Gk1hh/JMkvbLG2m/3JO2yxtpv7yR9ss9abvnn1yDZo/8LeSNtF/eSPvlnrRd3kj75Y20X+4V9baTRKoQQgghhBBCCCGEEEJkQRKpQgghhBBCCCGEEEIIkQVJpAohhBBCCCGEEEIIIUQWJJEqhBBCCCGEEEIIIYQQWZBEqhBCCCGEEE9Ramoq48aNo1SpUlhbWxMYGMiSJUuyVVdRFKZPn05AQABWVlYEBATw008/oShKAUcthBBCCCEkkSqEEEIIIcRTNHjwYL766iu6devGjBkz8PHx4fXXX+f333/Psu7EiRMZNWoU9evXZ9asWdSrV4+RI0cyadKkpxC5EEIIIcSLzbywAxBCCCGEEOJFceLECYKDg5k4cSJjx44FYNCgQbRs2ZKPPvqI3r17Y2lpabJuREQEU6ZMYeDAgcyfP19f18zMjMmTJzN48GA8PT2f2rYIIYQQQrxoJJEqhBBCCCHEU7JixQrUajVDhw7VT1OpVAwbNoyePXuyc+dO2rVrZ7Lu2rVrSU5OZvjw4QbThw8fzsKFC1m7di3vvPNOvsesKApRUVEkJSWh0WjyfflFnaIoJCUlkZiYiEqlKuxwnjnSfnnzePuZm5tjbW1NsWLFpC2FEKKQyKv9QgghhBBCPCXHjh3D398fV1dXg+n16tUD4Pjx45nWtbKyIjAw0GB6jRo1sLS0zLRubimKwp07d4iKiiIlJSXfl/8sUKlUWFlZSeIql6T98ubx9ktJSSEqKoo7d+5Iv8hCCFFI5InUF4CiKJwMjWXLuQjiE9NwtDGnfWVPqvs4ywWNEEIIIcRTFBYWhpeXl9F0b29v/fzM6np4eKBWGz4LoVar8fDwyLQuQHx8PPHx8frP4eHhAGg0mgyfNI2KiiI+Pp7ixYvj5uaW6fKfV4qikJaWhrm5uVw754K0X9482X7R0dHcu3cPCwsLihUrVtjhFWkajQatVvtCPkmfH6T98kbaL3vuTpxI8qXLT0xVSE1NI9HCHPjvvGEVUA6PcePyPQYzM7MclZdE6nPuYsQD3l9xknNh6RfNahVoFZgbco3K3o5M61Wd8p4OhRylEEIIIcSLITExESsrK6PparUaCwsLEhMTc1wXwNraOtO6ANOmTWPChAlG06OjozNcbmxsLBYWFjg5OZGWlpbp8p9XiqLob4QlEZhz0n5582T7OTk5cf/+fe7fvy9PpWZBq9USFxcHYPQDlMiatF/eSPtlz8Nz/6I5e9bkvCdT0KmpqagjI/M9hpz2Ly+J1OfYxYgHvPzzfhJS/rvo1T52rj0fHs/LP+9n5ZCGkkwVQgghhHgKrK2tSU5ONpqu1WpJTU3F2to6x3UBkpKSMq0L8P777zNo0CD95/DwcOrWrYubmxvu7u4m6+j6tTQ3f3FvG3TJKnmiMnek/fLGVPtZWFhgbm6e4d+tSKdLQBcrVizHT5wJab+8kvbLnkQLC6OEaUYsLCyKxHHvmbgiSk1NZdKkSQQHB3Pv3j0CAgL49NNPee2117KsqygKP/30E7NmzeLmzZv4+voybNgwhg8fbnAiDw4OZsCAASaXcfnyZcqWLav/3L9/fxYuXGhUrkSJEty+fTsXW5j/FEXh/RUnSUhJM0iePk6rQEJKGu+vOMmG4Y3lwkYIIYQQooB5e3tz8+ZNo+m61/J1r/hnVHfHjh1otVqDp1u0Wi13797NtC6Ao6Mjjo6ORtPNzMwyvMnTXR++6NeJKpVK/0/knLRf3phqP5VKJcmZbFCr1Zke40TmpP3yRtovazk5LxSV494zkUgdPHgwv//+O0OHDqVq1aqsWbOG119/nbS0NN58881M606cOJGgoCD69u3Lxx9/TEhICCNHjiQ2NpZxJvpWCAoKwt/f32Caqcd8LSwsWLBggcE0Ozu7XGxdwTgZGqt/nT8zWgXOhcVz6nYc1X2cCz4wIYQQQogXWM2aNfnnn3+IiYkxGHDq0KFD+vmZ1Z03bx6nT5+mevXq+uknTpwgJSUl07pCCCGEEM+yxJSi0d9skU+knjhxguDgYCZOnMjYsWMBGDRoEC1btuSjjz6id+/eWFpamqwbERHBlClTGDhwIPPnz9fXNTMzY/LkyQwePNgoSdquXTvq16+fZVxqtZo33ngjj1tXcLaci8hZ+bMRkkgVQgghhChgvXr1YurUqcyePZsvvvgCSH+TaObMmbi7u9OiRQsgfZCnqKgoSpUqha2tLQBdu3Zl1KhRzJw5k3nz5umXOWPGDCwtLenatevT3yAhhBBCiGxSFAVNbCypYWGkRUSgefAg23WvRz2kgqIU+tsNRT6RumLFCtRqNUOHDtVPU6lUDBs2jJ49e7Jz507atWtnsu7atWtJTk5m+PDhBtOHDx/OwoULWbt2Le+8845RvQcPHmBra5vlI8NarZaHDx/i4OBQ6F/kk+IT0/QDS2VFrYK4xNSCD0oIIYQQ4gVXq1Yt+vbty/jx44mMjNS/bbVr1y4WLFigH/Rp5syZTJgwgZ07d9K8eXMg/dX+Tz75hEmTJpGamkrTpk0JCQlh0aJFjBs3Di8vr0LcMiGEEEK86BStFtVj3Q/FLFlC0tlzpIaHkRYeQWpEBEpS0n/l7R3IbjYtIUVTJN6mLvKJ1GPHjuHv72/w6hNAvXr1ADh+/HiGidRjx45hZWVFYGCgwfQaNWpgaWnJ8ePHjeq0adOGhw8fYmlpSatWrfjuu++oVKmSUbmUlBQcHR159OgRTk5O+qcLnJ2ds9ym+Ph44uP/e+0+PDwcSO+MWNchcV45WJllK4kK6clWR2uzfFt3QdFoNGi12iIfZ1El7Zc30n65J22XN9J+eSPtl3tPs+2KQn9XT9O8efPw9fUlODiYOXPmEBAQwKJFi7L1ttOECRNwcXFh1qxZLFu2DB8fH6ZNm8aoUaMKPnAhhBBCvNBSIyJIDQsjNSyctIhwUsPC06eFh5MWHo5D+3Z4BQUB6U+f3v97OykHD5hcllalJjE5jZx0klkU3qYu8onUsLAwk7+u6zrT13XMn1FdDw8Pg874If21fA8PD4O6tra29OvXjxYtWuDk5MSJEyeYNm0aDRs25OjRowaDTXl5efHhhx9Sq1YtALZt28a8efM4duwY+/fv1z9JkJFp06YxYcIEo+nR0dFZ1s2uut5W/JLD8pGRkfmy7oKi1WqJi4sDMPpORdak/fJG2i/3pO3yRtovb6T9cu9ptp2p/uifZ5aWlkyaNIlJkyZlWCYoKIig/9+IPE6lUjF69GhGjx5dgBGK54FuMN3r16/j5+f3zK5DCCGeV+FBQSRfuqz//Cg5jfsJKWi0CmZqFS62lthZpaftrALK6ROUBUFRFLRxcaSGh+v/pYWHY1O9Og6tW+vL3Ro0iJQrVzNczvEjF/hl1j4i45OIfJjMy3FOBHhWJtLGmUhbZyJtnLln40KkjTMx1g5M3fszlWNuZDvOovA2dZFPpCYmJppMLqrVaiwsLEhMTMxxXQBra2uDur169aJXr176z926daNTp07Ur1+foKAgFi9erJ83ZcoUg2W9+uqrVKxYkQ8//JBFixYxaNCgTLfp/fffNygTHh5O3bp1cXNzw93dPdO62dW8WDEq777D+YgHmT6ZqgIqeTvSrIpvkeue4Em6J2KKFSv2wj25kh+k/fJG2i/3pO3yRtovb6T9ck/aThQURVE4GRrLlnMRxCem4WhjTvvKnlT3cS7y16OiYCxcuJD+/fvj7+/PlStXCjscI3v37mX79u2MGjUqwzcQfX196d+/v9EDMznZti+//JJ69erRpk2b/ApdCFGEJV+6TOJjb0qrAbcnymSc8coZbVISaRERqO3sMH8s73R7xEiSr1whNTwcxUR+7XqDtmyPLsa9B8lEPkim/yNLqgPJanMibdMTounJ0fT/3nL04EJorL7+svL/JWEdrMxxd7DC3cEKfwcrijtY43A0Z2lJJxuLnG56vivyiVRra2uSk5ONpmu1WlJTU7G2ts5xXYCkpKRM6wLUqVOHJk2asH379izjHDFiBJ999hnbt2/PMpHq6OiIo6Oj0XQzM7N8vVGZ9moNXv55PwkpaRknU1XwzctVMTcv8rsCkJ5Az+92epFI++WNtF/uSdvljbRf3kj75Z60nchvFyMe8P6Kk5wLS+/mSten/9yQa1T2dmRar+qU93Qo5CifD3379qV379759sZbQVq8eDGWlpZcvXqVAwcO0KBBg8IOycDevXuZMGEC/fv3N5lIPXPmDLdu3aJjx45G87K7bdevX2fs2LH8+uuv+R2+EOI5kJ0R6xVF4cGWLY+9bh9GangEKWFhKPfvAxDxcj/OtnyZyIfJ3ItPpvexs3hEG77prUHFfWtH7tk6czDOnFUn7ujnfV/1ZVKqWxBvaQsqFWoVuNlb4W5vRXFHKwIdrGj1/ySpu4MVxf+fOHV3sMLW0jjvdC7YFjJ+0dxI+yqF/xZTkc+eeXt7c/PmTaPputfyda/4Z1R3x44daLVag1fStFotd+/ezbSuTqlSpdi/f3+W5SwsLPD09CQmJibLsk9LeU8HVg5paPJiVUdRYPflKKqUcC6cIIUQQgghxAvhYsQD/Y/8Oo9fl54Pj+fln/ezckhDSabmUUJCQrYGzy0KwsPD+eeffxgzZgzTp09n8eLFRS6RmpWNGzdSvHhx6tSpYzA9J9t25MgRAH33cUII8bg7ETGU/GcnaXcjSA0LJ/HOHZLuhPOoc0/CAxsQ+SCJew+SafvZF1glJWS4nOPHLjJDc0n/2bxEPWw8kvVPlkbaOhNt7YSFlQXFHawp7mDFS48lRIs7WOPu+F/i1M3OCjN17t8mca1SkRPRj0jIRqI43suXaiWdcr2u/FLkE6k1a9bkn3/+ISYmxmDAqUOHDunnZ1Z33rx5nD59murVq+unnzhxgpSUlEzr6ly7do3ixYtnWS4pKYnw8HCaNm2aZdmnqbynAxuGN+bU7Ti2nI0gLjEVJxsL4hJTWHo4FIAZO67QvUYJvJxsCjlaIYQQQgjxPFIUhfdXnMz0TSmtAgkpaby/4iQbhjd+aq/5BwUFMWHCBC5dusT333/Pn3/+SXJyMi+99BJz5szR34MMGDCAXbt2cePGDYP6pvoJ1S3zwoULTJkyhbVr16JWqxk4cCDffPMNUVFRDBs2jL///hsLCwuGDRvG+PHjDZZ79+5dxo0bx4YNG4iKiqJUqVK8/fbbfPTRRwZto1vXuXPn+Oabb9iwYQOOjo6MHz/eZP+l4eHhBAUFsXHjRiIjI/Hy8qJNmzZMmzYNBwcHbt68ydSpU/nnn3+4efMmlpaWNGnShK+//prKlStn2Z4XLlzA1taWUqVKZav9lyxZglarpX///ly/fp3ly5fz448/YmGR+eubWW0HwNmzZxkzZgy7d+8mJSWF6tWrM378eIPBih8+fEhQUBArV64kPDwcJycnqlSpQlBQEE2aNNG3L0Dp0qX19Xbu3Enz5s0B2LBhAy+99JLRPpvdbWvQoAEHDx4E/ru/NTMz4+HDh8/EE8VCiNx5lJxGdnuhLxEVyu333jOa/mdycZae/e84Ud7aFTszW4PEqK5P0khbZ9KKFaeCi4P+KVH3Zn0p7mBNk/8nS3UJU3sr86dyHvaaEET8Yz+0mrpGUKvA1tKcVe81LBJdABX5RGqvXr2YOnUqs2fP5osvvgDSL8RmzpyJu7s7LVq0ACAqKkp/gWFrawtA165dGTVqFDNnzmTevHn6Zc6YMQNLS0u6du2qn/ZkohZgx44d7Nu3j4EDB+qnJScno9Fo9OvQmTx5MmlpaQYn5aJCpVJR3cfZYGSzpFQN+65EcysmgcRUDV9tPM/M17JOLAshhBBCCJGSpuVObPZ7bvs3PF7/hlRmtAqcC4tn89kIKnoZd4WVkRLONlia521QtD59+uDp6cmkSZO4fPkyM2bMwMLCgj/++CPXy+zduzcBAQFMnjyZbdu28d133+Hi4sKKFSuoWbMmU6ZMYeXKlQQFBVGtWjW6desGpN/b1K9fn+TkZAYPHoyXlxd79uzhk08+ISwsjB9//NFoXb169aJ06dJ8+eWXGXZvFhERQd26dYmMjOTtt9+mSpUqhIeHs3r1aqKjo3FwcODIkSPs2rWL7t274+fnR3h4OHPnzqVp06acO3cuy8HhKlasSLNmzdi1a1e22mjRokXUrl0bf39/evfuzR9//MGWLVvo3LlzhnWysx2XLl2iUaNGWFlZMXr0aOzt7fntt9/o0KEDq1evpkuXLgAMGTKEFStWMHToUCpXrsz9+/c5dOgQJ0+epEmTJvTo0YMLFy6wfPlyfvjhB4oVK6bfTki/jzx48CCjRo3K9bZ99NFHfPnll9y/f18/CJ2NjQ3W1tYoSiYDXgghnlmKohAf+wDnHNTRouK+tYNBv6TnXX2xNFPrE6Mbh0zG3dFa/2p9Jd2TpP9/ejSv58qCkNXb1BW9HPnh1eoEeBSNt1WKfCK1Vq1a9O3bl/HjxxMZGUnVqlVZs2YNu3btYsGCBfpf6GbOnMmECRMMfhn09vbmk08+YdKkSaSmptK0aVNCQkJYtGgR48aNw8vLS7+exo0bU6NGDQIDA3F2dubkyZPMnz8fDw8PgxFTw8PDadCgAT169CAgIACVSsX27dtZv349zZo1o0+fPk+zeXLN2sKMsZ0q8fbvRwHYcDqc1+tF08D/ya6NhRBCCCGEMHQnNpEW3+0qsOW/98fxrAs9ZueHzSldzC5P6wwICGDJkiX6z7qHN37++Wejhyiyq2bNmsyfPx+Ad999l3LlyvHFF1/w2Wef8dVXXwHpT7p6e3szf/58fSL1iy++ICEhgdOnT+Ph4QHAO++8g7e3N99//z2jR4/G19fXYF0VK1bkzz//1H8ODg42iufTTz/lzp07hISE0KRJE/30oKAgfcKuQ4cO9OzZ06Be3759qVSpEvPnz+fzzz/PVVuYcu7cOU6dOsV3330HQLt27XB1dWXx4sWZJlKzsx1jxowhISGBQ4cOUaFCBQDefvttqlatysiRI+nUqRNqtZoNGzbw9ttvM23aNJPrCgwMpHr16ixfvpxu3boZPN0LsGXLFtRqNW3bts31tvXo0YNPP/2U+vXr88Ybb2Sj5YQQzyIlLY3QFasI/2c35qeO4fwgNtt1LzmX5JdXPqNrHV+KO1hT2cGKFg5WDHOwwsnGokg8qZkXj79NvflMGHdjHuDh6sBLVb2pVtKpSG1fkU+kAsybNw9fX1+Cg4OZM2cOAQEBLFq0KFsnmQkTJuDi4sKsWbNYtmwZPj4+TJs2zegXwx49erBp0yY2bdrEo0eP8PT0pH///owfP54SJUroyzk7O9OmTRu2bdvGwoULSUtLo0yZMkyYMIGPP/74mRm0CaB1xeI0L+/OrouRAAStO8eGEY2xMCt6v1AIIYQQQghRkIYOHWrwuVmzZkyfPp2bN2/qnz7Mqbffflv//yqVijp16nD16lWD6dbW1lSrVo2rV68C6QncFStW0L17d8zMzIiKitKXbdeuHd9++y27du2iX79+BusaMmRIprFotVpWr15N+/btDZKPj8cHGCSNExISSExMxNHRkfLly3Ps2LEstzknT1AuWrQIlUrFq6++CqSPO9GjRw8WL15MfHy8yQF6s7MdGo1G/+SnLokK6YP+vvvuu4wZM4azZ88SGBiIo6MjR44cITw83OBBm+zasGEDjRs3Noo1J9uWkJDA1atXGTBgQI7XL4QourTx8YTHJ3EswYKD12I4dDWKiX98j1tS1m9oPClVbU6TSt4MbupfAJEWDbq3qat6OxAZGYm7u3uR7Ov7mcj6WVpaMmnSJP1rDqYEBQUZPDmqo1KpGD16NKNHj850HV9++SVffvlllrE4Ozvz+++/Z1nuWaBSqRjfuTL7r+wmRaPl4t0HLDpwk4GNS2ddWQghhBBCvLBKONuw88Pm2S4/N+Qqy46EZrt8n7o+ObpZLOGc977+n3zS0MXFBSBPg8n6+PgYfHZycspw+oULFwCIjIzk/v37LFiwgAULFphc7r1794ymPd5/pymRkZHEx8dTtWrVTMslJSUxbtw4Fi9eTHh4uME8N7f8e3tNURSWLFlC48aNKVmypH567969mTdvHitXrjSZWMzOdkRGRvLo0SODJKpOpUqVALh+/TqBgYF88803DBgwgJIlS1K7dm3at29Pnz59TNZ9kkajYevWrUZP6eZ0286cOYNWqyUwMDDLdQohii5tUhJ39hzk5rZdKMeO4HrnOhtKN+Dnaj30ZQ55VsIrIZp7ZatS+9pR3CJuZXv5RWHEevGMJFJFwSldzI63mpTm513pv4D/sO0Snat54+4gnZoLIYQQQgjTLM3VOXqV/tU6PjlKpL5ap1SeX9XPqYyeetE9YZnRa4UaTcYjDWe0TFPTdevRarVAep+tj4/V8Lhy5coZTbOxyTyZnNV26IwYMYL58+czfPhwGjVqhJOTE2q1mlGjRuljyw+7du0iNDSUTz/91GB68+bN8fDwYPHixSYTqdndjow8Wb937940a9aMdevW8ffff/PDDz8wefJkfvvttyzfgDxw4AAxMTF07NgxT9t26tQpAKpVq5arbRJCFJ7QQye4tmk7aUcO437zAhaaNB7/yal61BUszFRUK+lMvTKu1H1rKrV8XbC1NOfGa6+TmM1Eqq2VWZEYsV5IIlUAw1qUZfXxO0TEJ/EgOY2pWy7w7StyEhdCCCGEEPmjuo8zlb0dOR8eb3JEXh21Cip5OxbJm0UnJydiY2ONpt+4cSNf1+Pu7o6joyNpaWm0bt0635ZbvHhxHB0dOXPmTKblVqxYwZtvvmk0oNX9+/f1Ay3lh8WLF2NmZmbUH6uZmRmvvPIKs2fP5s6dOwbdrEH2tsPd3R07Ozv9U76P0017/AlkLy8v3nnnHd555x1iY2OpX78+EyZM0CdSM0rabtiwAX9/f8qXL5+nbTt9+jQuLi4GT68KIYoeRVG4fSuCI9EaDl6L5uC1GN7YMocmYacNyoU6FOdO6SqkBQZS4aVWnKrqh61l3tJvpYvZF6l+Ql9k0hmmwM7KnM87/tfv05/HbnP81v1CjEgIIYQQQjxPVCoV03pVx9bSHHUG94FqFdhamjOtV/UiebNYtmxZ4uLiOHHihH7aw4cPWbhwYb6uR5eAW716NcePGw+6FRcXR2pqao6Xq1ar6d69O5s3b2b//v1G83VPapqbmxv1c7p06VLCwsKytZ4LFy5w61bmT1glJSXx119/0bJlS4oXL240v3fv3mi1WoPBv3KyHWZmZrRv354NGzZw6dIl/bwHDx4wd+5c/Pz8qFKlChqNhri4OIP6zs7OlC5dmvv3/7sfsrNLfzr68WkAGzdupFOnTnnetps3b0oSVYgi6talm2z78TfW9X6HA7UacqrXG7y/4hQrjt7mVkwCJ9zLEW3tyMmKDTj1xkhig1fRZN8/vLn0JzoP6krD6qUzTKJaBZTDpmZNqFKN657+nHP1M/p33dMfqlbDuXLW3Y2Ip0OeSBUAdAr04o9DNzl4Lb0PqPFrz7FmaCPMMrrSFUIIIYQQIgfKezqwckhD3l9xknNh6QNtqFXon1Ct6OXID69WJ8DDoRCjzNhrr73GmDFj6N69OyNHjiQ1NZUFCxbg4eFBaGj2uy3Ijq+//pqQkBAaNWrEW2+9RdWqVYmPj+fs2bOsXLmSK1eu4OmZ877ypkyZwrZt22jVqhWDBw+mcuXK3L17l1WrVrF69Wr8/Pzo3Lkzv//+O46OjlSpUoWTJ0+yfPlyypQpk611VKxYkWbNmrFr164My6xbt474+HjMzMz4+uuvjeYrioKFhQWLFy/mo48+ytV2fPXVV2zbto0mTZowdOhQ7O3t+e2337h16xarVq1CrVYTGxtLiRIlePnll6lWrRqOjo7s27ePLVu2GAw+Vrt2bQDGjBlDnz59sLS0pGHDhpw9e5Zp06bledtKly7Nli1bmDx5MqVKlaJcuXLUq1cvW+0thMhft27d5d/Nu3i4/wAuF07hHRfB4z9zuBCLmyaBsmVLUr+MG/VK1aJG6SAaP5EszazbFx2vx8b5qaAonLodx5azEcQlpuJkY0H7Kp5FbsR6IYlU8X8qlYoJXarQ4ac9aLQKZ+7EseJoKH3qlirs0IQQQgghxHOivKcDG4Y3fiZvFl1cXFizZg3vv/8+n3zyCSVKlOD999/HwcEh30dbd3d359ChQ3z55ZesWbOGX375BWdnZwICAggKCsLV1TVXy/Xy8uLQoUOMHTuWZcuWERsbi7e3N23bttW/tj99+nQsLCxYvnw58+fPp3bt2mzZssVkQjO3Fi9eDMCWLVvYsmVLhuVOnz7N6dOnjQZhys52lC9fnn379vHZZ5/x/fffk5qaSvXq1dm4cSPt27cHwNbWlqFDh7Jt2zbWrl1LWloapUuX5rvvvmPkyJH69dWvX58vv/ySuXPnMmDAALRaLbNmzcLe3p5mzZrledvGjBnDlStX+Prrr3nw4AFffvmlJFKFeAoUReHWvXgOhcZz8Fo0h67FUOfoVt45u86gXKrajNslypFavTbeLZqwt2V9bKwt8zUW3Yj11X2c83W5Iv+plCff2xBP3e3bt/Hx8SE0NLTQX+mYsP4cv+27AYCLrQU7P2yOs23+HiByS6PREBkZibu7e4Yd94uMSfvljbRf7knb5Y20X95I++WetN3zLzvXoLr+P58c0f5FoigKaWlpmJubF+lEb1H1vLZfx44dsbKyYtWqVQW6HlPtJ3+X2SPnsbx5HttPURRuRj7g9K4j3N+9D7tzxykdeYO+7caSYGENgG98OLP/mUaEhy/JgbUo3qwxldo2xtbRPkfreh7b72kp6m0nT6QKA6NaB7D+VBhRD1O4n5DK939fYlK3KoUdlhBCCCGEEEIUGc2aNaNRo0aFHYYQIhOKonAj6hEnDpwhKmQP1mdOUCH8IuVSEw3K1Yy9jtKgMfVLu1GvdH1KT+xO5WJuhRS1KOokkSoMONlY8HH7Cnz8V/qoc38cuknvuj5U9i56I6cKIYQQQgghRGH4+OOPCzsEIZ4LiqJwMjSWLeciiE9Mw9HGnPaVPanu45zjp9gVReF61CMOXovh0PVoDl6Lxv7WVWbu+tGo7ENbRx5WroFL44b80rU9tp7Gg8MJYYokUoWRnjVLsuTQLU6GxqJV0gee+vPdBs/VqzhCCCGEEEIIIYR4+sKDgki+dJnEFA3Xox6SkKKh3GPzrwLhlmaULmaPc5UKBoMyPU5RFK5FPeLwuVvcCTmA+YmjBIRdYEWVLpwsHgBApJM3Dy1ssEDhQYWqODZsQJn2LbGvECA5DpErkkgVRtRqFRO7VqbrrH0oChy9eZ81J+/QvUbh9t8qhBBCCCGEEEKIZ1vypcskHj8OQOnMCkZALOD1/4+KonA18iGHLt3lxp5DqI8fpdydC1S9f4vqilZfrW7MFRwbN6J+GVfqlXajwtt/YutbCpWFRQFtkXiRSCJVmBRY0pnedXxYejgUgMmbLtC6ogcO1nLgEUIIIYQQQgghRMG7GvmQI/uvc+jGfQ5diyb6QRJ/bJlIzeSHRmUfliyNdf36DOvWEcfatR6b4/L0AhbPPUmkigx91K4Cm85EEJeYSuSDZGb8c4UxHSoWdlhCCCGEEEIIIYR4Rj1KTkOdzbLF7oezfdocNpZtmj5BpeaGS0lcIi6QWMwTy7r1KNGqKQ4N6mPu6lpgMQuhI4lUkSFXO0s+bBvA2LXnAFiw9zq9apekbHGHQo5MCCGEEEIIIYQQz6L7CSm4ZbOsU0oCjcPPEN+pJ/XLuFK/jBv+/fyxcLDHsqR0Pyievuz+CCBeUK/V86WilyMAaVqFoHX/oihKIUclhBBCCCGEEEKIZ83D5DQSUjTZLv/AwoZH9ZuyoH8dBjf1J7CkM3YVK0gSVRQaeSJVZMrs/wNPvTLnAAB7r0Sx9VwE7at4ZVFTCCGEEEIIIYQQLzKNVuHsnTj2/XuHG3sOY3PmOG2i72a7/i0HD8Jbdyu4AIXIIUmkiizV8XOle40SrD5xB4BJG87TLKA4NpZmhRyZEEIIIYQQQgghipKw2ET2XIzgwu6jaI4epnzYRRpGX6eFNg0AbQ6X176KZ/4HKUQuSSJVZMtnL1Xg73MRPErRcCc2kZ93XeH9tuULOywhhBBCCCGEEEIUokfJaRy6Hs3uS1HsuRxJqeN7eO/UKgLTkozKpnh4Y5WajBITna1l21qZUa2kU36HLESuSSJVZEtxR2tGti7H5E0XAJiz+xo9a/lQys22kCMTQgghhBBCCCHE06LVKpwLi+fQoXNEhezF7uJZfqjWE606/a1VO2tH7P+fRE1xcsGyTl08mzfGvkEDLEqU4MZrr5OYzURq6WL2qFSqAtsWIXJKEqki2/o3LM3yI6FcjXxESpqWiRv+ZV6/2oUdlhBCCCGEEEIIIQpQeFwi+49fJXTHHsxPHqNC+EUaPYrSz9/oWw9V5UCalitGE9/quDZxwblxQyz9/fOUCLWxkC4FRdEiiVSRbZbmaoK6VKbv/MMAbD9/l50X79GifPFCjkwIIYQQQohnT3BwMAMGDOD69ev4+fkV2WU+zeULIYqGhJQ09l+P48zhKC4dPUefLb9QKT6MSk+USzO3ILlCVea/WYsSjev/N6Pymxku2yqgnMHnR8lp3E9IQaNVMFOrcLG1xM7K3GRZIQqbJFJFjjQp5077yp5sORcBwMT1/9LQ3w0rc/mVSAghhBBCmBYeFETypcvZKmsVUA6voKCCDegZt3fvXrZv386oUaNwdnYu7HCeGb6+vvTv358JEyYUdihCFDlarcK/N6M4vW0fD/YfYJlDBa7bpT80ZZ1mwacP0nMAWpWah6UDcGzUAJ9WTbGtUQO1lVWO1iXHePEsk0SqyLEvOlVk58V7JKdpuR71iPl7r/Ne87KFHZYQQgghhCiiki9dJvH48cIO47mxd+9eJkyYQP/+/Y0SqX379qV3795Y5TCx8bw7c+YMt27domPHjoUdihBFRsT9RxzdfpDIkL3YnT1Buchr1NCkAnCpspbr5YpTwdOBpgFlSPIahn+1AJzq18PM3r6QIxei8EgiVeRYSRdbhrYoy7RtlwCYseMK3WuUwMvJppAjE0IIIYQQ4sVmZmaGmZm8LfakjRs3Urx4cerUqVPYoQhRaBJTNBy+EcOeC3fxnzMF/9Dz+Kcm4v9EuXiX4nSs6cOA3oFU9PNOP6Z0qFgoMQtR1KgLOwDxbBrctAw+rumJ08RUDZM3XSjkiIQQQgghhMi5mzdvMnToUCpWrIitrS3Ozs507tyZc+fOGZQLCgpCpVJx+fJl3n33Xdzc3LC3t+eVV14hOjo6x8t70rZt21CpVKxevdpo3saNG1GpVGzYsIGgoCA+++wzAEqXLo1KpUKlUrFr1y4gvQ9TlUrFjRs3DJYRHh7OO++8Q8mSJbGyssLPz4+3336bBw8e5ClunQsXLnDr1q1slc2LlStXUqdOHWxsbChbtiwrV64EoHXr1rzzzjsZ1tuwYQMvvfSSyUFvXnrpJerVq8eZM2fo0qULjo6OeHh4MG7cuALbDiGeBkVR+PfkZdZMmcvYL+ZTbeLf9FtwmHn7b2Ibcw+H1EQAHto6cq9OM5QPP6f0339T70AIzT4dSjE7i0LeAiGKHnkiVeSKtYUZ4zpV5u3fjwKw/lQYr9UtRQN/t0KOTAghhBBCPC0pTyTrnmTm7IxZHvrwTL17FyUxMcP5KisrLLy8cr18gCNHjrBr1y66d++On58f4eHhzJ07l6ZNm3Lu3Dk8PDwMyvfp0wdPT08mTZrE5cuXmTFjBhYWFixZsiRby/P09DQZR6tWrShZsiSLFi2ie/fuBvMWL16Mu7s77du3p1SpUly4cIHly5fzww8/UKxYMQAqVsz4abGIiAjq1q1LZGQkb7/9NlWqVCE8PJzVq1cTHR2Ng4NDruPWqVixIs2aNdMndAtCUFAQEyZM4LXXXmPAgAHMnTuXAQMGULx4cUJCQvjll19M1ouJieHgwYOMGjXK5PzTp0/j4uJC27Zt6devHx07duSPP/5g0qRJNG7cmLZt2xbYNgmR3yJCwzmz/h/i9x/E5eIpvB5EUh6I9qpCSr3+AJT3cCDupe48sFdRvn0LHCuVN/kjgxDCmCRSRa61rlic5uXd2XUxEoCgdefYOKIx5mbyoLMQQgghxIvgavuXMp3vPmokxd59N9fLDx83jkchuzOcb1OjBn5Ll+R6+QAdOnSgZ8+eBtP69u1LpUqVmD9/PmPGjDGYFxAQoE+aQvoTXzNnzmTOnDk4OjpmubzPP//cZBxqtZq+ffvy/fffc//+fVxcXAB48OAB69at46233sLc3JzAwECqV6/O8uXL6datG35+fllu46effsqdO3cICQmhSZMm+ulBQUEoipKtdsgo7qflwIEDTJgwgZEjR/Ljjz8CUKlSJVq0aMHQoUPp3bs3ZcqUMVl3y5YtqNVqkwnRmJgYwsLCSEhI4NixY/pldO/eHQ8PD44ePSqJVJFvnhx4L6vR6rMzKFNSqoYjN2IImx+M655teEffpiSKQZkUtTkezrZ890o1GpcthqeTNdA0PzdNiBeGJFJFrqlUKsZ1qsS+K7tJ1ShcvPuARQdvMqBR6cIOTQghhBBCiGyxtbXV/39CQgKJiYk4OjpSvnx5jh07ZlR+6NChBp+bNWvG9OnTuXnzJlWrVs3x8h7Xr18/pkyZwooVK/Svqa9cuZKEhAT69u2bq+3TarWsXr2a9u3bGyRRdXRPoeUlbkCfkC0oP/zwA46OjkyYMEE/rXTp9PuOs2fPsmzZsgzrbtiwgcaNG+Po6Gg07/Tp0wB8/vnnBolYS0tLwLBdhMirJwfeUwNPvtOZ8TP46bTJyVzcdZDz/95ktWN5Dl+PITlNyzunLxMYHQqABhXhXqXRVKtNydZNqdyqEdVsrPN1W4R4UT0TidTU1FQmTZpEcHAw9+7dIyAggE8//ZTXXnsty7qKovDTTz8xa9Ysbt68ia+vL8OGDWP48OEGj64HBwczYMAAk8u4fPkyZcsajkp/5coVPvjgA3bt2oWiKLRo0YJp06bh7/9kN83PtzLu9gxqUoafd10FYNrfl+gU6I27g4wSKoQQQgjxvPPfsjnT+Xl5rR/Aa+LELF/tz6ukpCTGjRvH4sWLCQ8PN5jn5mbcbdWTT4DqnhyNiYnJ1fIeV758eerVq8fixYv1idTFixdTvnz5XA+SFBkZSXx8PFWrVs20XF7izoldu3bRokWLbJXdtm0brVu3RqPRsHXrVtq1a4eTk5NRuR49elCpUiWTy9DVzeiJWl0itVu3bgbTz58/D6R/J0IUhsQUDQCKRkPEsdNc2rKT5MOHKH79AlaaFDyt7NnTfjz8P69xs3JdLnk74t6sEYEdmlPFzaUwwxfiufVMJFIHDx7M77//ztChQ6latSpr1qzh9ddfJy0tjTfffDPTuhMnTiQoKIi+ffvy8ccfExISwsiRI4mNjTXZeXhQUJBRMvTJ/oAiIiJo0qQJFhYWjB8/HoAff/yRJk2acOLECaN+lJ53w1qUZfXxO0TEJ/EgOY2pWy7w7SvVCjssIYQQQghRwCyz8Vp5Xlg8hevqESNGMH/+fIYPH06jRo1wcnJCrVYzatQotFqtUXkzMzOTy9E9kZnT5T2pX79+DB06lBs3bmBpacnOnTuZOHFirrdPF1dW/R/mNe7sKl++PL/++mu2yuqSo9euXSM+Pp5atWoZzI+MTO9ibPjw4Rku48CBA8TExNCxY0eT88+cOYOzs7PRgzOnTp0CoFo1ua8RhePOnUjuvNIfl0tnsE1OoPgT81PNrXjV35ZaNcvRpFwxvJxsCiVOIV40RT6ReuLECYKDg5k4cSJjx44FYNCgQbRs2ZKPPvqI3r1761+7eFJERARTpkxh4MCBzJ8/X1/XzMyMyZMnM3jwYKMkabt27ahfv36mMU2ZMoWYmBj+/fdffdK1c+fOVK5cmSlTpuj77HlR2FmZM6ZjRUYsPQHAn8du81q9UtQoJb+ACSGEEEKIom3FihW8+eabRtfw9+/f1w/k9DSX17t3b0aPHs3ixYuxsrJCURTeeOMNgzI5GRSmePHiODo6cubMmQKNO7u8vLwYNGhQjupERUUB/z39qzN58mQAXF1dM6y7YcMG/P39M3yy9PTp0yaTpadOnaJYsWJ4e3vnKFYhMvMoOY3sjiiiTU3F58wh/edYK3tC/SphXrsu5dq3pEmtCjRTywBRQjxtRX5UoBUrVqBWqw36IlKpVAwbNox79+6xc+fODOuuXbuW5ORko18ohw8fTnJyMmvXrjVZ78GDB2g0mkxjeumllwyeXC1Xrhzt2rVjxYoV2d2050rnQC/qlf7vAmbc2nNotAXbT5IQQgghhBB5ZW5ubtS/59KlSwkLCyuU5bm4uNC5c2cWL17M4sWLadq0Kb6+vgZl7OzsgPQkZ1bUajXdu3dn8+bN7N+/32i+Lta8xn3hwgVu3bqVrbI55fz/LiIeTwZv2rSJ1atXA+l9umZk48aNdOrUyeQ8RVE4d+5cholUeRpV5Lf7CSnZLvvIwoZj3pU50bk/kT8uoOqh/fRau5AeY4dStU5F1JJEFaJQFPknUo8dO4a/v7/Rr4z16tUD4Pjx47Rr1y7DulZWVgQGBhpMr1GjBpaWlhx/rJNnnTZt2vDw4UMsLS1p1aoV3333nUF/O2FhYURERFC3bl2juvXq1WPDhg2Eh4fj5eWV4TbFx8cTHx+v/6zrg0ij0WSawC3qxneqSOdZ+9FoFc7ciWPZ4Zv0ruOTb8vXaDRotdpnuo0Kk7Rf3kj75Z60Xd5I++WNtF/uPc22y+hVafH8sAooVyBl80Pnzp35/fffcXR0pEqVKpw8eZLly5dnOAL801hev3796Ny5MwDz5s0zml+7dm0AxowZQ58+fbC0tKRly5YUL/7ky7/ppkyZwrZt22jVqhWDBw+mcuXK3L17l1WrVrF69Wr8/PzyHHfFihVp1qwZu3btyvZ2ZleFChUoW7Ysc+bMwdHREVtbW77++mt69erFihUr+Oabb/jkk0/094g6t27d4uzZs0ybNs3kcq9evcqjR4+MEqaKonDmzBkGDx6c79siXkwPklI5feAM1mHZ/7FBo1Jzbth4pvTIvH9jIcTTVeQTqWFhYSaTkrpXLDL7hTQsLAwPDw/UasMHb9VqNR4eHgZ1bW1t6devHy1atMDJyYkTJ04wbdo0GjZsyNGjR/V95ujqZBVTZonUadOmGYw2qRMdHY1VPnSYX1hczeDlQHdWnLwHwNQtF6jtaY6Tdf7sZlqtlri4OACj71RkTdovb6T9ck/aLm+k/fJG2i/3nmbbPdnVknj+eAUFFXYIGZo+fToWFhYsX76c+fPnU7t2bbZs2cJHH31UaMtr3749Hh4exMXF0bNnT6P59evX58svv2Tu3LkMGDAArVbLzp07M0ykenl5cejQIcaOHcuyZcuIjY3F29ubtm3b6l/bz+92yE8qlYqVK1cyZMgQvv/+e6ytrRk0aBDfffcdJUuWZNasWdSqVcsokbphwwbs7e1p1qyZyeXqnnB9MpF69epVHj58aPRAjhDZkRZzn9C9h7iz5yB7KjVld7wFFyPicUqMZ0lyxk9Pm+JkY1FAUQohckulPPn+RhHj7++Pv78/f//9t9E8S0tL3nzzTZO/0gK0atWK0NBQLl26ZDQvICCAUqVKsX379gzXfeTIEerXr0+fPn1YvHgxAHv27KFp06b88ccfvPbaawbllyxZwuuvv86ePXto3Lhxhss19URq3bp1uXHjBiVLlsyw3rMgPjGVVj/sIeZR+isLfeuVIqiL6RE0c0qj0RAVFUWxYsXkyZVckPbLG2m/3JO2yxtpv7yR9su9p9l28t0Ujtu3b+Pj40NoaGiG16A3btwAjEeqf5EoikJaWhrm5uY56ps0L7RaLaVKlaJRo0YsX778qayzoBRG++l07NgRKysrVq1a9VTXm59MtZ/8XWaPRqMhMjISd3f3AjvPKIpCwvWbXN21n+gDh7E8fwbXqP8e2Pq+5qtsL1VH/3np1ok4J8abWpSRc65+lF2+lOo+zvkddrY8jfZ7nkn75V5Rb7si/0SqtbU1ycnJRtO1Wi2pqalYW1vnuC5AUlJSpnUB6tSpQ5MmTQySrbo6ppablJRkUCYjjo6OODo6Gk03MzMrkjtJTrjYm/HpSxX4+K/TAPxx+Ba965WisrdTvixfrVY/F+1UWKT98kbaL/ek7fJG2i9vpP1yT9pOiMKxadMm7ty5Q79+/Qo7lGdas2bNaNSoUWGHIZ4zsQkpHLt5n1MXbtN40lAcEuKwAB5/t0KLiluOnpRyd+CdpmWo5etCTV8XHt5aQqKJLgZNsbUyo1rJ/LmPFkLknyKfSPX29ubmzZtG03Wv2Gc2iqK3tzc7duxAq9UavJKm1Wq5e/dutkZgLFWqlEGn7Lo6un5NcxrTi6BnzZIsOXSLk6GxaBUIWneOFe80eOq/QAshhBBCCPEsOXToEGfOnOGrr76iUqVKtG/fvrBDeqZ9/PHHhR2CeMalPXjA9d2HuLPnEJrTJ9niV48/HSvq51fHDAcgWW3OtWK+xJetjF2tmpRpXp+W5UvykrnhD5EPc7Du0sXs5R5aiCKoyCdSa9asyT///ENMTIzBgFOHDh3Sz8+s7rx58zh9+jTVq1fXTz9x4gQpKSmZ1tW5du2aQV9DJUqUwMPDg8OHDxuVPXToEJ6eni98IlWtVjGxa2W6ztqHosCRG/dZezKMbjVKFHZoQgghhBBCFFk///wzixcvJjAwkAULFkjfzkI8ZQ9D73Bpxz6iDx7B4vwZ3O/dQq0oePx/vpfGDmqkJ1IDPOz595XBJJb2olKzunT3dMoy8akbTC8xRcP1qIckpBgP6GhraUZpd3ucK1fI120TQuSPIp9I7dWrF1OnTmX27Nl88cUXQHo/JDNnzsTd3Z0WLVoAEBUVRVRUFKVKlcLW1haArl27MmrUKGbOnGnQj+qMGTOwtLSka9eu+mlPJmoBduzYwb59+xg4cKDB9FdeeYVffvmFq1ev4u/vD8Dly5fZunUr7777bv43wjMosKQzvev4sPRwKABfbTpPq4rFcbCWzrKFEEIIIYQwJTg4mODg4MIOQ4gXgqLVcvduDCejUzl64z7Hbt1n5LyP8UiI4cleo+/Yu3PPtzzu9RvzW5c61PRxwcnWAjA9kFlGHh94r4KicOp2HFvORhCXmIqTjQXtq3hSrWTWCVkhROEp8onUWrVq0bdvX8aPH09kZCRVq1ZlzZo17Nq1iwULFuhHuZ85cyYTJkxg586dNG/eHEh/xf6TTz5h0qRJpKam0rRpU0JCQli0aBHjxo3Dy8tLv57GjRtTo0YNAgMDcXZ25uTJk8yfPx8PDw+CnhhldMyYMfz555+0bNmSUaNGAfDDDz/g6urKZ5999jSa5ZnwUbsKbDqTflKIfJDMjH+uMKZDxawrCiGEEEIIIYQQ+Sg1IZErj72mX+zmRXZ5V2NG9Z76Mudc/XBLjOWmWyniy1bCtlZN/Js3pFklXyzM8vcJcZVKRXUf50IbTEoIkTtFPpEKMG/ePHx9fQkODmbOnDkEBASwaNEi3njjjSzrTpgwARcXF2bNmsWyZcvw8fFh2rRp+gSoTo8ePdi0aRObNm3i0aNHeHp60r9/f8aPH0+JEoavpHt5ebF7924++OADxo8fD6R3ZD5t2jSD5OyLztXOkg/aBjBu7TkAFuy9Tq/aPpQtbl/IkQkhhBBCCCGEeJ49Sk7j7KbdpBw9juX5s3jcvYG5VsPjd/eVo6+jVkElb0dqlXLB76XPcCnvTVVP1wyXK4R4sakURVEKO4gX3e3bt/Hx8SE0NJSSJZ98ieDZlqbR0nnmPs6HxwPQuGwxFr1VN1evKmg0GiIjI3F3d5fRg3NB2i9vpP1yT9oub6T98kbaL/ek7Z5/2bkGvXHjBgB+fn5PL7AiRlEU0tLSMDc3l9dtc0HaL29Mtd+z+ncZHhRE8qXL+s+PktO4n5CCRqtgplbhYmuJnVX6s15WAeUMXoPPiqIohJ65yPnzt9hv7c3Rm/c5Hx7Ph0f+oMXtEwZl79m6EOlbAXVgNbyb1Kdqk1r69Yr/yHVA3kj75V5Rbzs5WogCZW6mZmLXyrwy5wAAe69EsfVcBO2ryJO7QgghhBBFnVqtJi0trbDDEEI8RqvVYm7+7N3KJ1+6TOLx4/rPasDtiTKJ2VxWSmISF/cc5c6eg6SdPoX7zQs4Jj1E6+DBwlYf6cudKeaPf0IkD8pWwqZmTcq0bEjjKmUxU0tSXwiRO8/e0Vc8c+r4udKtujdrToYBMGnDeZoFFMfGsuj9siCEEEIIIf5jYWFBYmIiWq1WRpAXogjQarWkpqZiY2NT2KEUqMQnRrOPS0jleOh9wlavx23bejzDr2GlTcX3iXrFE2Op62VDoL8nNUs54WsXSAW/oCL5VJsQ4tkkiVTxVHzWoSLb/r3LoxQNd2IT+TnkKu+3CSjssIQQQgghRCYcHR2JjY0lKioKd3d3eTVbiEKkKApRUVFotVocHR0LO5wCdTMiltDpwURfus5vZVty6e5DADpdu8rQOxf15e5bO3LPrzxmVavj3aQelRrXYoWtNfDf68FCCJGfJJEqngoPR2tGti7H5E0XAJgTcpWeNUtSys22kCMTQgghhBAZsbOzw8HBgejoaOLj45/J14nzgzyRmzfSfnmja7+0tDRSU1NxcHDAzs6usMPKsUfJaWR3L/CNugk/f0MJVNyxqgqW6U/gxgRU5aJlPLY1a1CmZSPqBwbIviWEeKpezCshUSj6NyzNsiOhXIt8REqalkkb/+XXN2sXdlhCCCGEECIT3t7exMbG8vDhQ7RabWGH89QpikJycjLW1tbyRG4uSPvlzePtZ2lpiaurK87OzoUdVq7cT0gx6hM1M6lqM8I9SvNedRcq1KlCzVIuuNhZAn0KKkQhhMiSJFLFU2Npriaoc2XeXHAYgG3/3mXnxXu0KF+8kCMTQgghhBAZUavVuLq64urqWtihFIqiPnpwUSftlzfPQ/spWi03Dp9GFR2V7TrXHb3YO3oqk/vUpl0BxiaEEDkliVTxVDUNcKd9ZU+2nIsAYOL6f2no74aV+bN5USCEEEIIIYQQ4j+KohBz8TIXNu7kwcGDuF06g33yI3LyU0yCuRWOzvYFFqMQQuSWJFLFU/dFp4rsvHiP5DQt16MesWDvDYY09y/ssIQQQgghhBBC5IJGq3D6dixH9p+l2rcf4/goFlcwSJ6mqdSYK9nvHqR9Fc98j1MIIfJKemUWT11JF1vea15W/3nGP5cJj0ssxIiEEEIIIZ4ORVGYPn06AQEBWFlZERAQwE8//YSiKNmqP2fOHF599VVKly6NSqWiffv2BRyxEEIYS4uK4vqK1ex5930mfrOMGhP/pvvs/Xx9/D5myUlA+lOl5/0C+bfbQB7O/A27atWyvXxbKzOqlXQqqPCFECLX5IlUUSjeaVaGv46HEhqTSEKKhsmbLjCjT43CDksIIYQQokBNnDiRoKAg+vbty8cff0xISAgjR44kNjaWcePGZVn/66+/Ji4ujjp16hAVlf3+BoUQIi80cXHE7D/IjW27STt2BOe7oQAUA9LCUoiv9BIAHi52HH75HcpWLUetNg2o5WSrX8aNBTOyvb7SxexlcDIhRJEkiVRRKKwtzBjXqTJv/34UgPWnwnitbika+OdkHEchhBBCiGdHREQEU6ZMYeDAgcyfPx+AQYMGYWZmxuTJkxk8eDCenpm/yhoSEkKpUqVQqVT4+fk9haiFEC8qrVbh7J1YHg15G6cr/6JC4fFeS9NUai67+uJcxpdxnSrRNKAY/u72qFStTC7PKqAcAIkpGq5HPSQhRWNUxtbSjNLu9jhXrlAQmySEEHkmiVRRaFpXLE6zAHdCLkUCELTuHBtHNMbcTHqcEEIIIcTzZ+3atSQnJzN8+HCD6cOHD2fhwoWsXbuWd955J9Nl+Pr6FmSIQogXlDY5mcSTp7gbspcrFs6s86jOvitR3E9I5au4VGqioEXFFecS3PGrhHXdelRo05iOFUtke+Bgr6Ag/f9XUBRO3Y5jy9kI4hJTcbKxoH0VT6qVdJInUYUQRZokUkWhUalUjO9ciXY/7iZVo3Dx7gMWHbzJgEalCzs0IYQQQoh8d+zYMaysrAgMDDSYXqNGDSwtLTl+/HiBrj8+Pp74+Hj95/DwcAA0Gg0ajfGTYSKdRqNBq9VKG+WStF/eFFT7KWlpJJ07R9z+A0Tu3o/Zv2cwT0sFIKaYPxsae+jLbg9sQ5R7R/xaNqZ+NT+K2VsZxZgbVb0dqOrtYDBNq83+YFRZkX0vb6T98kbaL/eedtuZmWXvxyAdSaSKQlXG3Z5BTcrw866rAEzbdonO1byNTs5CCCGEEM+6sLAwPDw8UKsN375Rq9V4eHgQFhZWoOufNm0aEyZMMJoeHR2NlZVce2VEq9USFxcHYPTdiaxJ++VNfrefVlEImz4Hy83rsUhOH/D38b/+O3bFCHXyol4pR+r6OlLf15EybjX1T4kqifFEPiPjBMu+lzfSfnkj7Zd7T7vtsupW6UmSSBWFbliLsqw+foeI+CQeJKUxdcsFpvbM/oiOQgghhBDPgsTExAwTltbW1iQmFmx24v3332fQoEH6z+Hh4dStWxc3Nzfc3d0LdN3PMt0TMcWKFcvxUytC2i+vctt+iqKQeuMGCYcOEX83iqNNu7P3chR7r0TT7kIsff+fRI2yduKke1nula2Kc8N61K5bkfd8XbC2ePa/K9n38kbaL2+k/XKvqLedJFJFobOzMmdMx4qMWHoCgBVHb9OnbilqlHIp5MiEEEIIIUzbtWsXLVq0yFbZbdu20bp1a6ytrUlOTjZZJikpCWtr6/wM0YijoyOOjo5G083MzIrkjUpRolarpZ3yQNovb7Lbfql37vDo4CHiDxwgfv9BzGKiAEgys+CzCD/S1Om3/3u9A0lzdMa6Xj2q16/Kq+XcKe5YsMefwiL7Xt5I++WNtF/uFeW2k0SqKBI6B3rxx8GbHLoeA8D4dedY814j1GrpaFwIIYQQRU/58uX59ddfs1W2UqVKAHh7e7Njxw60Wq3Bq2parZa7d+/i7e1dILEKIQpXeFAQyZcuZ6usVUA5g0GZsiP+723c+fobCLujn6ZLPTwyt+ZMsTK4pSVRtpIvTcq506RcYyp6Osq9lhBC5IIkUkWRoFKpmNC1Mh1/2otGq3D6dhwrjobSu26pwg5NCCGEEMKIl5eXwWvy2VGzZk3mzZvH6dOnqV69un76iRMnSElJoWbNmvkcpRCiKEi+dJnEfBhMThMby6PDh0k8dgzVuyPYfyOW3Zcjidl9kU/+n0RNMrPgX9fSnHQvS0xAVfzq16JJRQ92lXbDxrLoPdklhBDPGkmkiiKjgqcjfev7Erz/BgDfbLlA+yqeONtaFm5gQgghhBD5oGvXrowaNYqZM2cyb948/fQZM2ZgaWlJ165d9dPi4uIIDw/Hy8sLJyenwghXCFHINA8f8fDIYRJ27uLG6dMkX7yISlEAGHnDmUsu6Q+dWFmX4I/ybbjqU4Hi9WrTqKIX75crhpeTTWGGL4QQzyVJpIoiZXSbANafCiP6UQr3E1KZtu0SE7tWKeywhBBCCCHyzNvbm08++YRJkyaRmppK06ZNCQkJYdGiRYwbNw4vLy992dWrVzNgwAB+++03+vfvr5++fv16Tp06BaQnW69du8aXX34JQJcuXQgMDHyq2ySEyF+JKRpSbtwg7NPPSDxzBv4/6AqACtCg4opzSazSUjBXq6jl60LTAHealmtJZW95XV8IIQqaJFJFkeJkY8En7Svw8crTACw+eJPedUpRydt4YAQhhBBCiGfNhAkTcHFxYdasWSxbtgwfHx+mTZvGqFGjslV/5cqVLFy4UP85NjaWsWPHAlCyZElJpArxjLt87wG/7wnn9VOnMVO0AFx39OJUsbKcdC/Lw/JVqFPFl1HlilGvjBv2VnJLL4QQT5McdUWR07NWSf44fItTobFoFRi/7iwr3mlQ2GEJIYQQQuSZSqVi9OjRjB49OtNy/fv3N3gSVSc4OJjg4OCCCU4IUSAeJaehzroYAClpWpaeiyGtYnsi7Fy5XqI8AeW8aF25BH3LF6eki22BxiqEECJzkkgVRY5arWJil8p0m70PRYEjN+6z9mQYnQM9Czs0IYQQQgghhMiR+wkpuOWgvKejFT5D3+W1csWo7OVATHQU7u7umJnJYFFCCFHYsvvDWLYp/+/8Woi8qObjTO86PvrPkzed52FyWiFGJIQQQgghhBA5p9Hm7B65RQUPRrQqR41SLphJn6dCCFGk5EsiNTQ0lBkzZtC6dWvs7Oxo0aIFP/30Ezdu3MiPxYsX1Idty+Nonf7Q9L0HyczcebWQIxJCCCGEEEKI7FMUhTRNzhKpTjYWBRSNEEKIvMp1IvX06dNMnDiRmjVr4ufnx9ixY3F3d+fbb7/F29uboKAg/P39qV69OkFBQZw4cSI/4xYvADd7Kz5sV17/+bd9N7gRk1SIEQkhhBBCCCFE5tKio7n7/TT+PnadLjP3cT8hJUf121eRLs2EEKKoylUfqeXLl+fy5cuUKFGCLl268PXXX9OiRQssLNJ/ORs6dChpaWns2rWLdevWERwczMSJEylVqpQ8pSpy5LW6pVh6OJTz4fGkaRWm7brFkoCShR2WEEIIIYQQQhjQxMcTNX8BUcELUScnsW3fHc4EtMzRMmytzKhW0qmAIhRCCJFXuXoitW/fvhw5coTQ0FBmzZpF27Zt9UlUHXNzc1q3bq1/xf/YsWMMGDAgV0GmpqYybtw4SpUqhbW1NYGBgSxZsiRbdRVFYfr06QQEBGBlZUVAQAA//fRTln25fvnll6hUKipUqGA0r3///qhUKqN/JUtKgi+/mZupmdClsv7z4VsP+Pvfe4UYkRBCCCGEEEL8R5uQwN05cznfvBUxc+eiTk4iwdwKrUpNvdKu+NerBlUCuVCsNOdc/TL8d6FYacrUqYZKJf2iCiFEUZWrJ1I//PBDPv74Y86fP88bb7yRrTo1atSgRo0auVkdgwcP5vfff2fo0KFUrVqVNWvW8Prrr5OWlsabb76Zad2JEycSFBRE3759+fjjjwkJCWHkyJHExsYybtw4k3VCQ0OZMmUKdnZ2GS7XwsKCBQsWGEzLrLzIvbqlXelW3Zs1J8MA+GrTeVpU8MDGUkatFEIIIYQQQhQObUoKUUuWEjF7DhbxsZgBKWpz1pduSGj7nrzVqRZ1S7sCDQBQRzzg/RUnORcWn/5ZBbpxqCp7O/LDq9UJ8HAonI0RQgiRLblKpM6cOZNZs2bRqVOn/I7HyIkTJ/RdA4wdOxaAQYMG0bJlSz766CN69+6NpaWlyboRERFMmTKFgQMHMn/+fH1dMzMzJk+ezODBg/H0NO5/5sMPP6RBgwakpaURERFhctlqtTrbSWSRd591qMi2f+/yKEXDndgkfg65yvttAgo7LCGEEEIIIcQLKDFFw55x3+KzZjEWQJpKzVbfetzu1Jv+XetS3cfZqE55Twc2DG/MqdtxbDkbQVxiKk42FrSv4km1kk7yJKoQQjwDcvVq//Lly2nTpg1t27bNtNy3335L7dq1+ffff3MVHMCKFStQq9UMHTpUP02lUjFs2DDu3bvHzp07M6y7du1akpOTGT58uMH04cOHk5yczNq1a43qhISEsHLlSn744YcsY9NqtcTHx2fZTYDIOw9Ha4a1LKv/PCfkKreiEwoxIiGEEEIIIcSLRNFqeZCQzJyQqzSZ+g8fpJTjgYUNO3xqseS97+iw4Ed+HNbWZBJVR6VSUd3HmU9fqsCUHlX59KUKVPdxliSqEEI8I3L1ROr58+f56quvsiw3fPhwvvvuO5YuXcqkSZNysyqOHTuGv78/rq6uBtPr1asHwPHjx2nXrl2Gda2srAgMDDSYXqNGDSwtLTl+/LjBdI1Gw/Dhw3n77bepWrVqpnGlpKTg6OjIo0ePcHJyolevXkydOhVnZ+cstyk+Pp74+Hj95/DwcP36NRpNlvVfVH3rlmTpoRvcup9MSpqWiRvOMfeNmoUd1jNDo9Gg1WplH8slab/ck7bLG2m/vJH2y72n2XZmZtJdjxCi6FIUhXtbt3Pz22ks9mnIOs/0exAza3tWffATgzpUo2xxeSVfCCFeBLlKpEL2+gO1tramR48e/P3337lOpIaFheHl5WU03dvbWz8/s7oeHh6o1YYP3qrVajw8PIzq/vzzz4SGhmYZq5eXFx9++CG1atUCYNu2bcybN49jx46xf/9+rKysMq0/bdo0JkyYYDQ9Ojo6y7ovMq1Wy+BaLnyxPb27he3n77HuyBUa+Mmoltmh1WqJi4sDMPqbEFmT9ss9abu8kfbLG2m/3HuabWeqqyUhhCgKInbu4drX3+Fy8xIOQI+YOP72qk632r4Mae6Pr5uMkyGEEC+SXCVSS5YsyYULF7JVNjAwkNWrV+dmNQAkJiaaTC6q1WosLCxITEzMcV1IT/I+XjcqKopx48Yxbtw4ihUrlmlMU6ZMMfj86quvUrFiRT788EMWLVrEoEGDMq3//vvvG5QJDw+nbt26uLm54e7unmndF5lGo6FlJWhzR8u28/cAmL4njHY1ymBlLjfHWdE9UVSsWDF58icXpP1yT9oub6T98kbaL/ek7YQQL7I7+w5zZcp3FL9yBpf/TzvnVpqInv3ZMaA1JZxtCjU+IYQQhSNXidS2bduyaNEiJkyYkOWTqSqVitjY2NysBkhPeCYnJxtN12q1pKamYm1tneO6AElJSQZ1P//8c9zd3Rk2bFiu4hwxYgSfffYZ27dvzzKR6ujoiKOjo9F0MzMzuVHJglqt5ouOFdh9OYrkNC03ohNYeOAWQ5r7F3ZozwS1Wi37WR5I++WetF3eSPvljbRf7knbCSFeNKFnL3Fx7ERKnD9G8f9Pu+pckoie/ej89st4OEkCVQghXmS5eoxvxIgRPHjwgJ49e5KQkPmAPydPnjT5an52eXt76/sQfZzutXzdK/4Z1b179y5ardZgular5e7du/q6ly5dYt68eQwdOpSbN29y5coVrly5QmJiIqmpqVy5coXIyMhM47SwsMDT05OYmJicbqLIoZIutrzX/L+Bp2b8c5mIuKRCjEgIIYQQQgjxLLsZ/YhPV56mz7xDFL9wEoDbjh4cG/AxDbZvYNCHb0gSVQghRO4SqWXLlmX27Nls27aN6tWrs2rVKpMDERw6dIjg4GDat2+f6wBr1qzJ1atXjRKUhw4d0s/PrG5ycjKnT582mH7ixAlSUlL0dcPCwtBqtYwcOZJy5crp/x0+fJhr165Rrlw5xo4dm2mcSUlJhIeHU7x48UzLifzxTrMy+LimX8gkpGiYvOl8IUckhBBCCCGEeNZcPnOZL39cQ8vvQ1h2JJTb1q6srtqOs2+Oos4/W3jjkwG42cs4FkIIIdLlerCp/v374+DgwLvvvssrr7xC8eLFadWqFb6+vlhaWnL27FnWrFmDs7Mzn3/+ea4D7NWrF1OnTmX27Nl88cUXQPqoiTNnzsTd3Z0WLVoA6X2cRkVFUapUKWxtbQHo2rUro0aNYubMmcybN0+/zBkzZmBpaUnXrl0BqFy5Mn/++afRusePH09sbCzTp0/H3z/91fHk5GQ0Go1+HTqTJ08mLS2Ndu3a5XpbRfZZW5gxtmMlBi86BsC6U2G8Vq8U9cu4FXJkQgghhBBCiKLu37PXOf3Nj1Q6toNqDp5om4+kmIM1bzcpw+v122FvletbZSGEEM+xPJ0dXn75ZZo1a8YPP/zA77//zpIlSwzm169fn19//ZWSJUvmeh21atWib9++jB8/nsjISKpWrcqaNWvYtWsXCxYs0A8mNXPmTCZMmMDOnTtp3rw5kP5q/yeffMKkSZNITU2ladOmhISEsGjRIsaNG6fvcsDd3Z2ePXsarXvmzJloNBqDeeHh4TRo0IAePXoQEBCASqVi+/btrF+/nmbNmtGnT59cb6vImTaVPGgW4E7IpfRuF4LWnWPD8MaYm8nAU0IIIYQQQghjJ8/d5OR3M6l2eCvVNKkAeCbeZ3I9F7p1aoCNpfQJLYQQImN5/pmtWLFifPXVV3z11Vdcu3aN0NBQNBoNZcuWpVSpUvkRI/PmzcPX15fg4GDmzJlDQEAAixYt4o033siy7oQJE3BxcWHWrFksW7YMHx8fpk2bxqhRo3IVi7OzM23atGHbtm0sXLiQtLQ0ypQpw4QJE/j4448xN5dfLp8WlUrF+M6VaPfjblI1ChciHrD44E36Nypd2KEJIYQQQgghipAj/4ZybNocah3cRL209PEVEi2sie34Cg0+HUodZ6dCjlAIIcSzQKUoilLYQbzobt++jY+PD6GhoXl6evd5p9FoiIyMxN3d3WD04K83X2BOyFUAHKzN2flhc4pJP0ZGMmo/kT3SfrknbZc30n55I+2Xe9J2zz+5Bs0e+VvIm8JsP0VR2Hclml83neSdBZ/hkvwQgBQzC+Lbd6fOmJFYu7k+1ZhySva/3JO2yxtpv7yR9su9ot528g60eOYNb1kWD8f0xOmDpDSmbrlQyBEJIYQQQgghCouiKPxz4S7dZ+/njfmHCAlP5pxbadLUZsS260qFf7bT5PsJRT6JKoQQouiR99DFM8/OypzPO1ZixNITAKw4eps+dUtRo5RLIUcmhBBCCCGEeFq0WoWtp++w79dlOF89x8nq6WNdVC3hROmxYyjn74Z1KZ9CjlIIIcSzTBKp4rnQOdCLPw7e5ND1GADGrzvHmvcaoVarCjkyIYQQQgghREFK02jZeDqM3Qv+ovXBNbz+4C4A4bWb0rlvR5oFuKNSyX2BEEKIvJNX+8VzQaVSEdSlMmb/T5yevh3HiqOhhRyVEEIIIYQQoqCkarSsOHyL4e//jDJkIIO2zcXv/0nU1Nr1+erNRjQvX1ySqEIIIfJNgSdSW7ZsyRtvvMG///5b0KsSL7iKXo70re+r/zx160XiElILMSIhhBBCCCFEfktK1bDo4E0GfTgfZfQQRmydSfnY9IcoNFWr4/vHYgIX/4Z1QEAhRyqEEOJ5U+CJ1F27drFkyRICAwPp27dvQa9OvOBGtwnAzc4SgJhHKUzbdrGQIxJCCCGEEELkh4SUNObtuUbTqTsZu+YsHQ+vJjD6GgBKQAV85s2j8ool2NaqVciRCiGEeF4VeCJVq9Xy4MED1q1bh5eXV0GvTrzgnGws+KR9Bf3nRQdv8m9YfCFGJIQQQgghhMiLB0mpzN51hY5Ba/ly43nuPUhGpYIrnd8Av9KUmPETFdeuwr5xI3mNXwghRIF6KoNN2dnZ0aFDBzp06PA0VidecD1rleSPw7c4FRqLVoGgdedY/k59uagSQgghhBDiGRKbkMJv+26wYetRup7azPTbJxje6gPqNKvJkOb++Lvbo3zcG5Vahv4QQgjxdDyVRKoQT5NarWJil8p0m70PRYHDN2JYdyqMrtVLFHZoQgghhBBCvJAUReHErftsORdBfGIajjbmtK/sSXUfZ6MHHqIeJjN/73XW/3Oazme28sONQ1goGgDm2V6m3CsD9GUliSqEEOJpkkSqeC5V83Hm1do+LDuS3un8VxvP06qiB/ZWsssLIYQQQgjxNF2NSuSt5fs5F/4AALUKtArMDblGZW9HpvWqTnlPB+7GJzE35Brr9/xL5/M7mHFtH9aa9MFjVc7OuL/7Di69exfmpgghhHjBSVZJPLc+aleeTWfCiU9K496DZGbsuMxnHSoWdlhCCCGEEEK8EMKDgrh/9jx3Ix4wQKtkWO7cChWnfcvwuV9HWl3dz89nN2CXlgSAys6eYoMG4tL3Tczs7Z5W6EIIIYRJkkgVzy03eys+bFeecWvPAbBg33Veqe1D2eL2hRyZEEIIIYQQz7/kS5dRnT1Ndh5lOKdVSPHRorG1wy4tCZW1Na59++L21kDMnJ0LOlQhhBAiW3KdSA0JCaFWrVrY20tSShRdr9UtxZJDt7gQ8YBUjcIHf56kfmk34pMy75dJCCGEEEIIkTePktPISQ+mb9QrxUftWpO6yBXnHt0xd3cvsNiEEEKI3Mh1z9xHjhyhQ4cODBgwgLS0tPyMSYh8Y26mZmLXKvrPp0LjmLv7GsuP3GJuyDW6z95Ppxl7uRjxoBCjFEIIIYQQ4vlzPyElR+XtrS1wsrWi2DuDJYkqhBCiSMpVInX58uX8+uuv7Nu3jxo1amBuLj0EiKLLycYCc7XhE6ePd9F0Pjyel3/eL8lUIYQQ4n/s3XlclOX6x/HPDDvKpiCL4i6aiuJe5pKWqS2iVlqaHTOzxcyl0mxR0cpfnY62YIupWZaaZWpZ2jHXstx3XHIXBRQBQWWfmd8f5CQHRWSAYfm+Xy9fOffz3M9zzRXgwzX3IiJShEz5rIt6LclpWcUUiYiISNEoVAW0f//+XLp0iXPnzuHo6EhmZibOzs5FHZuIzSwWC2MW7cJsuf5DnNkCqZnZjFm0i+UjOmiav4iIiIhIEXAw3txztZebUzFFIiIiUjQKPZS0W7du1KxZsyhjESlyu6IvEBWTcsPzzBaIiklh9+lkwoK9iz8wEREREZFyzsf95gbb9GgaUEyRiIiIFI1Cr5GqIqqUBSuj4m7u/H03d76IiIiIiFxbJZeCj9txd3GgeQ2vYoxGRETEdoUupIqUBSlp2RR0RpHRoHWZRERERESKytmU9AKfW8e3spbYEhGRUk+FVCnXPN0cKega92aL1mUSEREpbzIzb27XcBEpGt+t3UdMwqUCn+/m5FCM0YiIiBQNFVKlXOvR5ObWWdK6TCIiIuVL5cqVWbJkib3DEKlQVqzehcfY4QRfPGvvUERERIpUoTebAsjKyiI1NRUvL61lI6VTWLA3TYI8ORCbcsORqU4OBqq4a0SqiIhIeZKdnU1aWtp1j2/ZsoXVq1czfvz4EoxKpPxat3ITlV8ehW96MmbAoW49cHfDyckp36n7LiENSi5IERGRQipUITUxMZHBgwezcuVKTCYTderUISIigoEDBxZ1fCI2MRgMTOsXxgMf/0FqZna+xdQsk4VHPtvMgidvpWZV95ILUkRERIrU6tWrOXnyJK1atQLIt3hz+PBhXnvtNRVSRYrA5h/WUOnVF6mclUamgxNV33mXwB53Eh8fj5+fHw4Omr4vIiJlW6Gm9o8fP57ly5cTFhbGvffey6VLl3jsscdYsGBBUccnYrOGAR4sfqY9twR6Wtuu3oCqurcbV16euZBGv0//5Pj5yyUbpIiIiBSZTZs2MXToUFq2bInBYOC1115j4MCBvP3226xYsYKYmBjruadOncLDw8OO0YqUD7sW/oDL+FFUzkrjsrM7Ph/PpMa9d9s7LBERkSJVqBGpK1asoH///tbC6aVLl7jvvvt49dVXeeSRR4o0QJGi0DDAg+UjOrD7dDIr98WRnJaFl5sTPZoG0LyGFz/tjWXkwl2YzBbiUtLp/+mfzH/yVupXq2zv0EVEROQmvfrqqwwcOJAtW7bw8MMP4+3tzbZt2/jmm28wm80YDAZ8fHyoUaMGBw4coFu3bvYOWaRMi5o1D6f/TMVosZDg7k3QpzOp2SbU3mGJiIgUuUIVUs+cOUP37t2trytXrsykSZO48847OXr0KPXq1SuyAEWKisFgICzYm7Bg7zzH7msWhKPRwHPzd5JttnDuYgYPz8wppob4a5SKiIhIWVO7dm1q167Ne++9x9ixY+nduzdpaWns3buX3bt3s2fPHk6ePEnnzp155ZVX7B2uSJl1+K/TpH0wnUoWC6e9Aqg56zPqhta3d1giIiLFolCFVIvFgrOzc662Ro0aYbFYiI2NVSFVyqQeTQP5+FEjz369nSyThfOXMnl45ia+eqIdjYM8b3wBERERsbvHHnuM0NBQQkNDadasGX/88Yf1mJubG23btqVt27Z2jFCk/IhOTGXQd4fwbTuYx/76lUaffEjDxrXsHZaIiEixKdQaqQAxMTFkZWVZXzs55ex2npmZaXtU/yMrK4sJEyZQs2ZNXF1dadasGfPnzy9QX4vFwvvvv09ISAguLi6EhITwwQcfYLHkv4X7G2+8gcFgoFGjRtc8fuTIEcLDw/Hy8sLT05Pw8HCOHj160+9NSpdujf2ZOag1zo453xqJlzMZMGsT+84k2zkyERERKYjff/+dl19+mXvuuYfg4GB8fX3p0qULzz//PLNmzWLz5s1cvqy10EVsYc7MJC7hIgNnbSYuJZ0jgQ2oPfdzmqmIKiIi5VyhRqQCjBs3jldffZXGjRvTqlUr6tevj8FgwGQyFWV8AAwbNowvv/yS4cOHExoaytKlSxk4cCDZ2dk89thj+fadPHkykyZNYtCgQYwdO5b169czcuRILly4wIQJE67ZJzo6mqlTp1KpUqVrHo+Li6Njx444OTkxceJEAN577z06duzIzp078ff3t+0Ni111aVSNWY+15skvt5GRbeZCahYDPtvEvCfa0fwaywKIiIhI6XHs2DFSU1OJiooiKiqKffv28euvvxIZGQnkLPVjMBioXbu2ddRqaGgoDz74oJ0jFykbTJcuceKZ4fyZaCa6aT+cHR2ZOag1betWtXdoIiIixc5gudHQzGtYv349u3fvtv6JiooiIyMDAKPRSK1atWjSpAmNGze2/rdVq1aFCnDnzp20bNmSyZMn8/rrrwM5o0y7du3K/v37iY6OzrPMwBVxcXHUrl2bgQMHMnv2bGv74MGDWbhwISdOnCAgICBPv/79+5OQkEB2djZxcXEcPHgw1/GRI0fyySefsH//fusyBocPH6ZJkyY8++yzvPfeezf1Hk+fPk1wcDDR0dHUqFHjpvpWJCaTifj4ePz8/HBwcCj2+/1x5DxPfLGNtKycDwc8XByZO6QtrWr5FPu9i0NJ56+8Uf4KT7mzjfJnG+Wv8MpL7tasWUPfvn157rnn6NKlCyaTia1bt/LZZ59x6tQpDAYDQLEMBijt9AxaMOXle6EoZMfHc+LJYWT9/fvRpNue4PEXHqVH07y/U12h/NlG+Ss85c42yp9tlL/CK+25K9TU/s6dO/P8888ze/Zstm3bxqVLl9i3bx9ff/01L774IiEhIWzbto133nmHf/3rXzatQ7Vo0SKMRiPDhw+3thkMBp577jnOnTvH2rVrr9t32bJlZGRkMGLEiFztI0aMICMjg2XLluXps379ehYvXsz06dPzjalnz5651oJt0KAB3bt3Z9GiRTfz9qQUa1/fl7mPt8HdOecb92JGNo/N3szWE4l2jkxEREQKavTo0QwbNow33niDO++8k7vvvptXX32Vo0ePMnXqVKpWrcrixYvtHaZIqZd56hTHHxlA1sGDmDHwcWg4Dz//cL5FVBERkfKm0GukXs3BwYHGjRvzyCOP8H//93+sWLGCmJgYzp49yy+//MI777xT6Gtv376devXqUaVKlVzt7dq1A2DHjh359nVxcaFZs2a52lu0aIGzs3OeviaTiREjRvDkk08SGhp6zWvGxMQQFxd3zeJwu3btiI2NJTY2tkDvTUq/dnWrMu+JtlR2yVkF43Kmicdmb+HPowl2jkxEREQK4vDhw4SEhORpd3BwYNy4cXTr1o158+bZITKRsiMtKorjDz9C9unTZBkceLv1QNq88Ax9Wmgks4iIVCyFWiM1KSkJH58bT2/28/OjW7dudOvW7ab6XS0mJobAwMA87UFBQdbj+fX19/fHaMxdLzYajfj7++fp+/HHHxMdHc2UKVPyvSZww5iudfyKlJQUUlJSrK+vFF5NJlOFnFZWUCaTCbPZXOI5CqvhxZePt+Zfc7dxMT2btCwTj8/dwsxHW3J7fd8SjcUW9spfeaH8FZ5yZxvlzzbKX+GVZO6Kc9pWkyZN+PHHHxk6dOg1j3fp0oWXX3652O4vUtZd/vNPooc/hyU1lVRHF6a0/Rc9/9WLQbdqYykREal4ClVIDQ4OpnXr1vTu3Zvw8HDq1Klz3XNPnjzJsmXLWLp0Kb///juZmZk3da+0tDRcXFzytBuNRpycnEhLS7vpvgCurq65+p4/f54JEyYwYcIEfH2vXxy70uda13V1dc11zvVMmzaNiIiIPO0JCQnXjVfAbDaTnJwMkKc4XtyCXOHDvg14fvFfpGSYSM8yM/TL7bx9fz1uq+1VorEUlj3zVx4of4Wn3NlG+bON8ld4JZm7a61ZX1Ree+01+vTpwwsvvMCkSZPw8PDIdXzLli0qtItcR3ZSEtHPDseSlkaSS2Um3DaUe/rewVOd6924s4iISDlUqELqF198wdKlS5kyZQovvPACTZs2JTw8nD59+tCiRQt27dplLZ7u2bMHT09PevbsWahpU66urtaNrK5mNpvJysqyFi9vpi9Aenp6rr6vvvoqfn5+PPfcczeMB7jmddPT03Odcz1jxozJNSoiNjaWtm3bUrVqVfz8/PLtW5Fd+SXH19fXLgsO+/nB/Co+PDZnK4mpWWSaLIz78SgzBrSga6NqJR7PzbJ3/so65a/wlDvbKH+2Uf4Kr7zkLjw8nI8//piRI0cyZ84cBg4cSMuWLYGcjagWLFhAnz597BylSOnk4O3NhvuHUu+n+bx225P06N6a0d3yLpUhIiJSURSqkPrAAw/wwAMPYDKZWL9+PcuWLWPevHm8+eabuLm5kZaWRo0aNejVqxf//ve/ueOOO3B0LNStCAoK4uTJk3nar0yxvzKd/np9V69ejdlszjWSwmw2c/bsWWvfv/76i1mzZjF9+vRc90pLSyMrK4sjR47g5eWFn5+ftc+11kEtSEwAnp6eeHp65ml3cHAo07+olASj0WjXPDWt4cOCYbcxcNYmzl/KJNNk4dn5O/nwkZZlYqF9e+evrFP+Ck+5s43yZxvlr/DKS+6eeuopOnTowBtvvMEXX3zBRx99ZD3Wo0cPPvnkEztGJ1K6WCwW69/f+vkAn2XWwqnri/RtV5cJ9zXGYDDYMToRERH7Klx1828ODg507dqVrl278v7777Nr1y42btxI+/btadGiRZEE2LJlS9asWUNiYmKuDac2b95sPZ5f31mzZrFnzx7CwsKs7Tt37iQzM9PaNyYmBrPZzMiRIxk5cmSe6zRo0ICnnnqKTz75hOrVq+Pv78+WLVvynLd582YCAgJuWEiVsq1hgAcLh93KI59tJv5iBlkmC8/N38H7D7fg3mbXXxtXRERE7KdJkyYsWLCArKwsjh49SmpqKjVr1sx3SSeRisaSnU3sxIk4+vmxoOm9fPbbcQC6t6jJW31DVUQVEZEKr0gXvAoLC2P48OFFVkQF6NevH2azOdfIAYvFQmRkJH5+fnTp0gXIWeP04MGDpKamWs8LDw/H2dmZyMjIXNf88MMPcXZ2Jjw8HMh5sP7222/z/GncuDFBQUF8++23PPXUU9b+Dz30ECtWrODo0aPWtsOHD/PLL7/w0EMPFdl7l9KrfjUPvhl2KwGeOcs4ZJstPL9wJ8t2nbFzZCIiIpIfJycnGjVqRMuWLe1SRLVYLLz//vuEhITg4uJCSEgIH3zwQa5RgNdz+vRppkyZwm233UbVqlVxc3OjZcuWzJ49uwQil/LOnJbG6RHPk7z4exI++ZTfv1oGwJ2NqjG9fxgORhVRRUREbBqRWhJatWrFoEGDmDhxIvHx8YSGhrJ06VLWrVvHnDlzrJszRUZGEhERwdq1a7njjjuAnCn248aNY8qUKWRlZdGpUyfWr1/PvHnzmDBhAoGBOaMH/fz8ePDBB/PcOzIyEpPJlOfYK6+8wrfffkvXrl0ZNWoUANOnT6dKlSqMHz+++JIhpUpdv8p889StDPhsM2cupGEyWxj9zS6yTRYeaFXD3uGJiIhIKTR58mQmTZrEoEGDGDt2LOvXr2fkyJFcuHCBCRMm5Nt36dKlvPXWW/Tq1YsBAwZgNBr5/vvvGTp0KMePH+eNN94ooXch5Y3pwgWin3mWtJ07Afihzu1s929I+3pVmTGwJU4O2qxPREQEykAhFWDWrFnUqlWLuXPn8sknnxASEsK8efN49NFHb9g3IiICHx8fZsyYwcKFCwkODmbatGnWAmhhBAYGsmHDBl544QUmTpwIQOfOnZk2bZq1OCsVQ62qlf6e5r+J00lpmC3w4ne7yTab6d+mpr3DExERkVIkLi6OqVOnMmTIEOso0qFDh+Lg4MBbb73FsGHDCAi4/prrd9xxBydPnqRatX82uRw+fDg9evTg7bffZsyYMbmWwhIpiKzYWE4NfZLMv2fbzW3ck28adKVFLR8+e6w1rk5le41kERGRolQmPlp0dnZmypQpREdHk5GRwd69e/MUUSdNmoTFYrGORr3CYDAwevRojhw5QkZGBkeOHGH06NEFWt9n3bp1HDx48JrHQkJC+PHHH0lJSSElJYUff/yRBg0aFPo9StkVXMWdb566jVpV3QGwWGDc4r18tSnvJmkiIiJS/CIiIti+fbu9w8hj2bJlZGRkMGLEiFztI0aMICMjg2XLluXbv2nTprmKqFc88MADZGdn89dffxVpvFL+ZRw+zIlHBpB59CgWo5H3W/bjm5A7aRzkxdzBbankUibG3YiIiJSYMlFIFSntqnu78c2w26jrW8na9trSfczdeNyOUYmIiFRMn3/+OW3btiUwMJAhQ4awZMkSLl26ZO+w2L59Oy4uLjRr1ixXe4sWLXB2dmbHjh2Fum5MTAyANs6Sm5KdkMCJRweRHReHxdmZt9oNZmXNttT1q8SXT7TFy93J3iGKiIiUOvqIUaSIBHi5snDYrQyYtZkj53J+WZv0436yzRaGdqxr5+hEREQqjhMnTrB3715++uknfvrpJ/r164eDgwMdOnTg/vvv55577rHLTKKYmBj8/f0xGnOPZTAajfj7+1sLojcjKSmJjz76iJYtW1K/fv18z70yk+qK2NhYAEwmEyaT6abvXVGYTCbMZnO5y5HB2xufRx/l/Bdf8mqbwez0qkUNHze+fLwNPm6ORfZ+y2v+SoryV3jKnW2UP9sof4VX0rlzcLi5JWxsLqRu2LCBW265BT8/v2seP3/+PPv376dTp0623kqk1KvmmVNMHfjZZg6dvQjAGz8dIMtk4Zk76tk5OhERkYojNDSU0NBQXn75ZZKSklixYgU//fQTU6ZMYcyYMdSrV4/77ruP++67j06dOuHoWPzjC9LS0qwbpf4vV1dX0tLSbup6ZrOZRx99lKSkJJYvX37D86dNm0ZERESe9oSEhOvGJTl5Tk5OBshTBC/rDnbqyYSTvpxy9MC3khPvhdfDMfMi8fEXi+we5Tl/JUH5KzzlzjbKn22Uv8Ir6dzltz79tdj8xNilSxfmzZvHgAEDrnl89erVDBgwQFV4qTB8K7uwYNitPDprM/tjc0Z9vL3yIFkmM8/fqXV0RURESpqPjw8DBgxgwIABmM1m/vjjD+to1ffeew8PDw+6devGCy+8wG233Vaga65bt44uXboU6NxVq1Zx11134erqSkZGxjXPSU9Px9XVtcDvCeCZZ55hxYoVfPHFF7Rp0+aG548ZM4ahQ4daX8fGxtK2bVuqVq163UERgvX3GF9f35setVKaWCwWkmbNxqFqFbz69uXIuUuMWbaHREcPfNyd+GpoOxpUq1zk9y0v+bMX5a/wlDvbKH+2Uf4Kr7TnzuZCqsViyfd4Zmamqu9S4VSp5Mz8J9sxaPYW9p7J+SRl2qq/yDaZGd0tpECbnYmIiEjRMxqNdOjQgQ4dOjB16lSio6P58ccf+fnnn9m4cWOBC6kNGzbks88+K9C5jRs3BiAoKIjVq1djNptzPR+bzWbOnj1LUFBQgd/HuHHjmDlzJtOnT2fQoEEF6uPp6Ymnp2eedgcHh1L5i0ppYjQay3SeLGYzZ//vbZLmzQMHB5K9/XlsayaJqVl4uDgy74l2NAr0Krb7l/X82ZvyV3jKnW2UP9sof4VXmnNXqEJqSkoKFy5csL5OSEjg1KlTec5LSkpiwYIFVK9evdABipRV3u7OfDW0Hf+as4Vd0RcA+GDNEbLMFsZ2b6hiqoiISCkQHBzMs88+y7PPPntT/QIDA3ON7iyIli1bMmvWLPbs2UNYWJi1fefOnWRmZtKyZcsCXeeNN97gnXfe4bXXXmPUqFE3FYNUPObMTGJfHk/Kzz8D4HRre4ZuTeVsigk3Jwc+f7wNTasXXxFVRESkPCnUUNHp06dTp04d6tSpg8FgYNSoUdbXV/9p2bIlK1eu5KmnnirquEXKBC83J+Y90ZZWtXysbR+vO8pbPx+44WhuERERKV/Cw8NxdnYmMjIyV/uHH36Is7Mz4eHh1rbk5GQOHjxoXSPs6nNff/11nnnmGaZMmVIicUvZZbp0mdNPP20torre34tnGvXnaIoJZwcjMx9rRevaVewcpYiISNlRqBGpV9Z4slgsvPLKK/Tv3z/Xp+oABoOBSpUq0bp1a9q1a1cUsYqUSR6uTnwxpC1D5m5ly/FEAD777ThZJgsT72+skakiIiIVRFBQEOPGjWPKlClkZWXRqVMn1q9fz7x585gwYQKBgYHWc5csWcLjjz/O559/zuDBgwFYtmwZI0eOpGbNmtx666189dVXua7frVs3/P39S/ItSSmWff480cOeIn3/fgAqP/4ET7m05XDcRRyMBiIHtKBjA62NKyIicjMKVUi9/fbbuf322wHIyMigb9++hIaGFmlgIuVJZRdH5j7ehqFfbOOPowkAzP3jBNlmM5N7NcVoVDFVRESkIoiIiMDHx4cZM2awcOFCgoODmTZtWoGm6O/cuROLxcKpU6f417/+lef42rVrVUgVALITEzkxYCBZfy+/5v3SWIZnhLD/1AUMBvjPQ825u8nN7VIsIiIihZzaf7WJEyeqiCpSAO7Ojsz+Vxs6NvC1tn216RSvLNmL2axp/iIiIhWBwWBg9OjRHDlyhIyMDI4cOcLo0aPzzFAZPHgwFovFOhoVYNKkSVgsluv+ueOOO0r2zUip5eDtjXuLFuDkhN/b7/Ci5RZ2nLoAwJu9Q+ndQntYiIiIFIbNhdSNGzcyY8aMXG3z58+nYcOGVKtWjZEjR2I2m229jUi54ObswGePtaZLw3+mUS3cGs1L3+3BpGKqiIiIiBQBg9FI4BtTqDFvHi9fCGDjkZwZUa/ecwsD2tW0c3QiIiJlV5GMSN2wYYP19aFDhxg8eDBGo5HWrVsTGRnJBx98YOttRMoNVycHPhnUirtu+Wfq3eIdpxmzaBfZJn3oICIiIiI3L2XFCi7/8Yf1tdnBkVcOmPn1wDkAnr+zAU92qmuv8ERERMoFmwup+/bty7WZ1IIFC3B3d2fz5s38/PPPDBo0iDlz5th6G5FyxcXRgY8GtqTHVWtTLdsVw8hvdpGlYqqIiIjN1q9fz6VLl+wdhkiJSJz3FWfGvMDp50aQvn8/FouFV5fs5cfdMQA80aEOo+9qYOcoRUREyj6bC6nJycn4+PhYX69cuZJu3brh6ekJQIcOHTh+/LittxEpd5wdjXw4oAX3Nftnh96f9sTy3PwdZGarmCoiImKLrVu3cs899/D444+TnZ1t73BEioXFYuHce+9x9s03wWLBqVYtHHx9eeOnAyzcGg3Aw22Cee3eW/KswysiIiI3z+ZCamBgIPv37wcgLi6Obdu2cffdd1uPp6Sk4OjoaOttRMolJwcj7/UPo3dYkLXtl6izPPv1djKyTXaMTEREpOz65ptv+Oyzz9i4cSMtWrTQs6iUS5bsbGJff52ETz4FwL1dO2rN+5LI3ReY/XvOQJZezYN4s0+oiqgiIiJFxOanyr59+xIZGUlGRgZbt27FxcWFXr16WY/v3r2bunW1Fo/I9Tg6GPlPvzAcHYx8t/00AL8eOMdT87bzyaOtcHVysHOEIiIiZUv//v25dOkS586dw9HRkczMTJydne0dlkiRMaelceaFF7m0Zg0AHj16EPTO28zedJr3Vx8G4K5b/PlPv+Y4GFVEFRERKSo2F1IjIiI4e/YsX331FZ6ensyZMwd//5xNdFJSUli8eDHPPfeczYGKlGcORgPvPNAMJwcDC7bkTMNadyieJ7/cxsxBrXFzVjFVRETkZnTr1o2aNbU7uZQ/ppQUop9+hrQdOwDwGTAA/1dfYf6207z58wEAbq9flcgBLXBysHkCooiIiFzF5kJqpUqVmDdv3jWPVa5cmTNnzuDu7m7rbUTKPaPRwJu9Q3E0Gpm36SQAvx0+z5C5W5k9uDXuzpqWKCIiUlAqokp5ZXBxwfD3chV+o0ZS9amnWLYrhteW7gOgVS0fZg5qrVlNIiIixaBIKzP79u3j+PHjGAwGateuTdOmTfHy8irKW4iUa0ajgcnhTXB0MPD5xhMA/HksgcFztjLn8TZUdlExVURERKQiM7q4UGNGJJc3bsSzRw9+iYrjhW93Y7FA40BP5gxuQyU9M4qIiBSLIpnrsXz5curVq0fz5s3p3bs34eHhNG/enPr167N8+fKiuIVIhWEwGJhwX2Oe6vTP2sJbTiTy2OzNpKRn2TEyEREREbGHtF27yDh82PrawcMDzx49+O1wPCPm78RktlDPrxLznmiLl5uTHSMVEREp32wupP73v/+ld+/emEwm3nzzTZYsWcL333/Pm2++iclkok+fPqxataooYhWpMAwGAy/3bMTwLvWsbTtOXWDQrM0kp6qYKiIiUlAbNmwgPj7+usfPnz/Phg0bSjAikZtzcd06Tg5+nFNPDiMrNtbavvVEIk9+uY1Mk5ngKm58PfRWqlZ2sWOkIiIi5Z/NhdSIiAgaN27M3r17efnll+nVqxfh4eG8/PLL7Nmzh0aNGjF58uSiiFWkQjEYDLx4d0NG3dXA2rb7dDIDZ28i6XKmHSMTEREpO7p06ZLvh/qrV6+mS5cuJRiRSMFd+H4Jp4c/hyU9HSwWzJcvA7D3dDJDPt9KepYZf08Xvn7iVgK8XO0crYiISPlncyF1165dPP7443h4eOQ55uHhwZAhQ9i5c6ettxGpkAwGA6PuCuHFu0OsbfvOpPDIZ5tIuJRhx8hERETKBovFku/xzMxMjEbtbC6li8Vi4fxnnxH7yitgMuFcty61F8zHpX59Dp+9yGNzNnMxI5sqlZz5emg7albV5r4iIiIlweZVyJ2dnbn89yej13Lp0iWcnLROj4gtnuvaACcHI1NXHATgYNxFHvlsE18PvRU/D03hEhERuVpKSgoXLlywvk5ISODUqVN5zktKSmLBggVUr169BKMTyZ/FbObs//0fSV/OA8CteXNqfPIxjj4+nEy4zMBZm0lKzcLD1ZEvh7SlfrW8A1pERESkeNj88XvHjh2JjIzk8FWLn19x5MgRZsyYQadOnWy9jUiF91Tnerx+X2Pr67/OXuLhmX9yLiXdjlGJiIiUPtOnT6dOnTrUqVMnZ3bHqFHW11f/admyJStXruSpp56yd8giAJgzM4l58SVrEbVy587U/HwOjj4+xCanMeCzzZy7mIGbkwNzH29D0+pedo5YRESkYrF5ROrUqVNp3749TZs2pVevXjRs2BCAgwcPsnz5ctzc3Jg6darNgYoIPNGhDk4OBiYsiwLgaPxl+s/cxPwn2xHo5Wbn6EREREqHu+66C1dXVywWC6+88gr9+/cnLCws1zkGg4FKlSrRunVr2rVrZ59ARf6HJTWV9IM5M5C8evcmcMpkDE5OnL+UwcBZmzlzIQ1nByOfPdaaVrWq2DlaERGRisfmQmqTJk3Ytm0b48ePZ+XKlSxevBiASpUqcf/99/Pmm28SEhJyg6uISEE9dlttHI1GXlmyF4Dj5y/T/9OcYmoNH62PJSIicvvtt3P77bcDkJGRQd++fQkNDbVzVCI35uDtTc1Zn5H8w49UfWoYBoOB5NQsBs3ewrH4yzgYDcwY2JIODXztHaqIiEiFVCQr6zdo0IDvvvuO5ORkYmNjiY2NJTk5mW+//bZIiqhZWVlMmDCBmjVr4urqSrNmzZg/f36B+losFt5//31CQkJwcXEhJCSEDz74IM/GA6tWreK+++4jODgYV1dXAgMD6dmzJxs3bsxzzcGDB2MwGPL8qVGjhs3vVaQgBrSryTsPNsNgyHl9KjGV/p9uIjox1b6BiYiIlDITJ05UEVVKtczoaLLj462vnYKC8H36KQwGA5cyshk8dwsHYlMwGGBav+Z0a+xvx2hFREQqNptHpF7NaDTi71/0/7APGzaML7/8kuHDhxMaGsrSpUsZOHAg2dnZPPbYY/n2nTx5MpMmTWLQoEGMHTuW9evXM3LkSC5cuMCECROs5x04cAAXFxeeffZZqlWrRlJSEl999RWdOnXixx9/5J577sl1XScnJ+bMmZOrrVKlSkX3pkVuoF/rYJwcDLywaDdmC5y5kEa/T/9kwZO3UttXX4siIiIAGzduZNeuXQwfPtzaNn/+fCIiIkhKSuKRRx5h+vTpGI1FMr5ABIDYSZPI+CvvHhL/y5yaSuaJEzjXrUOtL7/EoXJl67H0LBNPfrGNnacuAPBWn1DCw7QxmoiIiD3ZXEi9fPkyCQkJ1KxZ85rHT506ha+vL+7uhZtyvHPnTubOncvkyZN5/fXXARg6dChdu3blpZde4uGHH8bZ2fmafePi4pg6dSpDhgxh9uzZ1r4ODg689dZbDBs2jICAAACef/55nn/++Vz9n332WerWrcv06dPzFFKNRiOPPvpood6TSFHp06IGDkYjo7/ZhclsITY5nf4z/2T+k7dSz6/yjS8gIiJSzk2cOJGqVataC6mHDh1i8ODB1KtXj9atWxMZGUmdOnUYNWqUfQOVciXjr8Ok7dhR4PMzT5wk88gR3P5eyzfLZGb41zv481gCAK/dewuPtL3271siIiJScmz+6H306NGEh4df93jv3r158cUXC339RYsWYTQac40iMBgMPPfcc5w7d461a9det++yZcvIyMhgxIgRudpHjBhBRkYGy5Yty/fe7u7u+Pr6cuHChWseN5vNpKSk5FkmQKQk9WoexIePtMDRmDPP/2xKBv0/3cThsxftHJmIiIj97du3L9dmUgsWLMDd3Z3Nmzfz888/M2jQoDyzjERKlKMjtb74wlpENZktjP5mF6sPngNg1F0NGNqxrh0DFBERkStsLqSuWrWKPn36XPd4nz59+OWXXwp9/e3bt1OvXj2qVMm9K+WVB+Id+XzSu337dlxcXGjWrFmu9hYtWuDs7HzNvsnJyZw/f579+/fz0ksvERUVRbdu3fKcl5mZiaenJ15eXvj4+DBs2LDrFlxFits9oYF8NLAlTg45xdTzlzJ4eOYmDsal2DkyERER+0pOTsbHx8f6euXKlXTr1g1PT08AOnTowPHjx+0VngiuISG4hTYFwGy28Mr3e1m+JxaAJzvWYeSdDewZnoiIiFzF5qn9sbGxBAYGXvd4QEAAMTExhb5+TEzMNa8fFBRkPZ5fX39//zxrXl1Zy/Vafe+9917rBlMuLi489dRTudZSBQgMDOTFF1+kVatWQE4xedasWWzfvp0//vgDFxeXfN9TSkoKKSn/FLhiY3MelEwmEyaTKd++FZnJZMJsNitH13FnIz8+HtCCZ+fvJNNkIeFyJo/M3MSXj7ehcZCn8mcj5a/wlDvbKH+2Uf4KryRz5+DgUGzXDgwMZP/+/UDOsk/btm3jiSeesB5PSUnB0bFItw0QuSkGV1cgZ5PcKT/t55tt0QA80rYmr9xzC4Yru4uKiIiI3dn81Ojn50dUVNR1j0dFReHt7V3o66elpV2zMGk0GnFyciItLe2m+wK4urpes+97771HYmIip06dYs6cOaSnp5OVlYXr3w84AFOnTs3Vp3///txyyy28+OKLzJs3j6FDh+b7nqZNm0ZERESe9oSEhBsWYSsys9lMcnIygDaEuI6mVQ28c389xv14lAyThaTULAbO2sz7fRvQ0M9N+bOBvv4KT7mzjfJnG+Wv8Eoyd1fWrC8Offv2JTIykoyMDLZu3YqLiwu9evWyHt+9ezd162ratNjf9FV/8fnGEwCEhwXxRu+mKqKKiIiUMjYXUu+55x5mzpxJv379aN++fa5jmzZtYubMmTzyyCOFvr6rqysZGRl52s1mc54CZ0H7AqSnp1+zb+vWra1/f/TRRwkLC2PIkCF8++23+cb5/PPPM378eH799dcbFlLHjBmT65zY2Fjatm1L1apV8fPzy7dvRXZlRIyvr2+xjlwp6+7386NKFR+GzdtOepaZlAwTzy85wpxBLanu5aX8FZK+/gpPubON8mcb5a/wykvuIiIiOHv2LF999RWenp7MmTMHf39/IGc06uLFi3nuuefsHKVUdJ+uP8oHa44A0K2xP+8+1BwHo4qoIiIipY3NhdSIiAh+/vlnOnXqRM+ePWnaNOeT071797JixQoCAgKYMmVKoa8fFBTEyZMn87RfmZZ/ZYr/9fquXr0as9mcaySF2Wzm7Nmz+faFnEJsr169ePfdd0lLS8PNze265zo5OREQEEBiYuKN3hKenp7Wdbmu5uDgUKZ/USkJRqNReSqATiHVmPt4W4bM3UpqpomL6dk8/uV2/hNenzurGdlzJoWVUXGkpGXj6eZIjyYBhAV7a9TDDejrr/CUO9sof7ZR/gqvPOSuUqVKzJs375rHKleuzJkzZ3B3dy/hqET+ce5iOlNXHASgQ31fPnykBU4OGkEvIiJSGtlcSA0ICGDbtm2MGzeOpUuX8tNPPwE5xcJBgwYxdepUm6ZrtWzZkjVr1pCYmJhrw6nNmzdbj+fXd9asWezZs4ewv3fBBNi5cyeZmZn59r0iLS0Ni8XCxYsX8y2kpqenExsbS6dOnQrwrkSK3611q/LlkLYM/nwrlzKyuZRh4vnv/yJwdTQnElIBMBrAbIFP1x+jSZAn0/qF0TDAw86Ri4iIFI99+/Zx/PhxDAYDtWvXpmnTpnh5edk7LKngTv39XNaqlg8zH2uFq1PZ/eBCRESkvCuSjzr9/f2ZO3cuSUlJxMXFERsbS1JSEp9//rnNa17169cPs9nMRx99ZG2zWCxERkbi5+dHly5dADh//jwHDx4kNTXVel54eDjOzs5ERkbmuuaHH36Is7Mz4eHh1rZz587luXdiYiJLliwhODiYatWqAZCRkZHrHle89dZbZGdn0717d5ver0hRal27Cl8+0RYPl5zPTDKyLdYiKuQUUa84EJvCAx//waG4iyUdpoiISLFavnw59erVo3nz5vTu3Zvw8HCaN29O/fr1Wb58ub3DE6FJkCdzBrfB3Vkbn4mIiJRmRfovtcFgsBYci0qrVq0YNGgQEydOJD4+ntDQUJYuXcq6deuYM2eOdXOmyMhIIiIiWLt2LXfccQeQM7V/3LhxTJkyhaysLDp16sT69euZN28eEyZMIDAw0Hqf22+/nebNm9O6dWt8fX05ceIEc+bM4ezZs3zzzTfW82JjY7ntttvo27cvISEhGAwGfv31V3788Uc6d+5s03qwIsWhZU0fvhralj4f/ZGrcPq/zBZIzcxmzKJdLB/RQdP8RUSkXPjvf/9L7969qVGjBm+++SaNGzfGYrFw4MABPv30U/r06cPPP/9Mt27d7B2qVFCuTg58OaQtXm5O9g5FREREbqBMfOQ5a9YsatWqxdy5c/nkk08ICQlh3rx5PProozfsGxERgY+PDzNmzGDhwoUEBwczbdo0Ro0aleu8YcOGsWTJEt59912Sk5OpUqUKt912Gy+88AIdO3a0nuft7U23bt1YtWoVX3zxBdnZ2dStW5eIiAjGjh2Lo2OZSKlUMGYL+RZRrz4vKiaF3aeTCQv2Lva4REREiltERASNGzdm48aNeHj8s3xNeHg4w4cPp3379kyePFmFVClSLiENSMs0cfpkLDUu5sx8O+YZSJqjS55z67dsStXKedtFRESk9DFYLJYClFekOJ0+fZrg4GCio6OpUaOGvcMptUwmE/Hx8fj5+ZXpTS/sYeqKA3y6/liBz3+6cz1e7tmoGCMqe/T1V3jKnW2UP9sof4VXXnJXqVIl3njjDUaPHn3N49OnT+f111/n0qVLJRyZ/ekZtGAK871gsVi478PfGbLwLRonnmBv1bqM7fhsnvOMBrgl0LNczwYqLz9L7EX5KzzlzjbKn22Uv8Ir7bnTdpAiFUBKWjbGAj6bGw2QnJZVvAGJiIiUEGdnZy5fvnzd45cuXcLJSVOqpWjtir6AZe9uGieeAOC7Bndc87yrZwOJiIhI6adCqkgF4OnmWKCp/ZDzQK81ukREpLzo2LEjkZGRHD58OM+xI0eOMGPGDDp16mSHyKQ8WxkVxwOH1wFw0sOfrf75z/RZuS+uBKISERERW2lBT5EKoEeTgJua2t80yOPGJ4mIiJQBU6dOpX379jRt2pRevXrRsGFDAA4ePMjy5ctxc3Nj6tSpdo5Syp2TJ7k1bj8Ai+t3xmK4/vgVzQYSEREpO266kHrq1KlC3ahmzZqF6icitgsL9qZJoAcH4i4WaGTq6EW7OXMhnaEd6+JQ0DUBRERESqEmTZqwbds2xo8fz8qVK1m8eDGQs3bq/fffz5tvvklISIido5Typs6pKIxYSHD1ZF2Nlvmeq9lAIiIiZcdNF1Jr165dqIXQTSbTTfcRkaJhMBj490PNeOiTP0nLMl+zmGoArjRnmSxMXXGQlVFxvPtQc+r5VS7JcEVERIpUgwYN+O677zCbzcTHxwPg5+eH0ahVrqToWSwWDt3agxnJPvinJpLlcONfuXo0DSiByERERMRWN11InTNnTrndUVKkPGvo78HMfo14a3U0UbEXgZypZFeKqo2DPBnasS4frj7MsfM5m3LsPHWBe97/jZe6N+Tx2+todKqIiJQ5ly9fJiEhgZo1a2I0GvH39891/NSpU/j6+uLu7m6nCKU8SbycyYvf7mbNwXPg4U+0h3++5xsNOc9gzWt4lVCEIiIiYoubLqQOHjy4GMIQkZJQz9eNZcPbsy/2Eiv3xZGcloWXmxM9mgbQvIYXBoOBnk0DePeXQ8zeeByLBTKyzbzx0wFW7ovj3w81p45vJXu/DRERkQIbPXo0W7duZefOndc83rt3b2699VY++uijEo5MyptNxxIYuWAHZy9mAtChflV2Rl8gLdN0zdlARgO4OzsyrV+YBqqIiIiUEdpsSqSCMRgMhAV7Exbsfc3jrk4OvHZfY7o3DeClb3dzIiEVgG0nk+j5/gbGdm/E4Pa1MWp0qoiIlAGrVq3i8ccfv+7xPn36MHfu3JILSModk9nCB6sP8+Gaw/Q6vIFbzx3A7bHH6fVEO/46e4kxi3YRFZMC5J4NdEugJ9P7hxHir00+RUREyooiKaSeO3eO2bNns337di5cuIDZbM513GAwsHr16qK4lYiUkDa1q7BiZCfe+eUgn288AUB6lpnJy/ezMiqOfz/YjFpVNTpVRERKt9jYWAIDA697PCAggJiYmBKMSMqT2OQ0Ri7cxZbjiTiYTTx04neqXErE6/AmDIY+NAzwYPmIDuw+nXzd2UAiIiJSdthcSN2/fz+dO3fm0qVLNGzYkL1799K4cWOSkpKIiYmhXr16BAcHF0WsIlLC3JwdmHh/E7o3CWDsd3s4lZgzOnXL8UR6vPcb4+9pxKPtaml0qoiIlFp+fn5ERUVd93hUVBTe3t4lF5CUG6sPnOXFb3eTlJoFwEsu0VS5lAhAlSFPWM+70WwgERERKTts3qr05ZdfxsXFhQMHDvDrr79isVh4//33OX36NF9//TVJSUn8+9//LopYRcRObq1blZWjOvKv22pZ29KyTExYFsXAWZuJ/rvAKiIiUtrcc889zJw5kz/++CPPsU2bNjFz5kzuueceO0QmZVVGtonJP+7niS+2kZSaRWUXR97v35y79+XMwKvUsSOuDUPsHKWIiIgUB5sLqb/99htPPfUUtWvXxmjMudyVqf2PPPII/fv356WXXrL1NiJiZ+7OjkSEN2X+k+2o4eNmbf/zWAI93tvAV5tOYrFcYycFERERO4qIiKBq1ap06tSJ+++/n/Hjx/PKK69w//3306FDB6pUqcKUKVPsHaaUESfOX+bBj/9kzsbjAIRW92L5iA7clXaKjAMHAKj6xBB7higiIiLFyOZCamZmpnXdKTe3nOJKcnKy9XhYWBhbt2619TYiUkq0r+fLylGdGNiuprXtcqaJ15buY9DsLZxO0uhUEREpPQICAti2bRuPPvoov/32G2+//Tb/93//x2+//cagQYPYtm0bQUFB9g5TyoAfdsdw7we/sfdMzu86T3Sow+Jn2lPbtxIJs+cA4Nq4Me7t2tkzTBERESlGNq+RWqtWLU6cOAHkFFIDAwP5448/eOCBBwDYt28flStXtvU2IlKKVHZx5M0+ofRsGsi4xXs4cyENgN+PnKfHe7/x6r238HCbYG2gICIipYK/vz9z587FYrEQHx+PxWKhWrVq+ndKCiQ1M5s3V53gx6gEAHzcnXj3oebceYs/AOkHDnB540YAqjwxRF9XIiIi5ZjNI1K7dOnCsmXLrK8HDhzIBx98wNChQxkyZAgfffQR4eHhtt5GREqhDg18WTmqI4+0/WdDuUsZ2Yz/fi//+nwrMX8XWEVEREoDg8FAtWrV8Pf3V7FLCuRAbAq9P/rTWkRtV6cKK0Z2shZRAS6uXQuAU/XqeHbvbpc4RUREpGTYPCJ13LhxdO3alfT0dFxdXZkyZQoXLlzg22+/xdHRkUGDBvHuu+8WRawiUgp5uDoxtW8zejQN5OXFe4hNTgdgw1/xdJ++gdfvb8xDrWroF1YREREpMywWC19vPsXk5fvJzDZjNMCILvV5/q4QHIy5n2n8nn2WSrfdhjklBYOjzb9eiYiISClm87/0NWvWpGbNf9ZKdHFxYebMmcycOdPWS4tIGdI5xI9fRnfijeX7WbTtNAAXM7IZ+90eVuyNZWrfZgR4udo5ShERKe8ee+wxmjVrRtOmTWnWrJnWP5WblpyWxcuL97BiXxwA/h4uTOxei+4t6uYpol7h3qJFSYYoIiIidqKPTEWkyHi6OvHOg83p2TSQl7/fw9mUDADWHorn7unrmXh/E/q2rK7RqSIiUmx+//13vv76aywWCwaDAR8fH0JDQwkNDaVZs2aEhobStGlTKlWqZO9QpRTacSqJEfN3Wtd/79qoGm/3bYopNfkGPUVERKQiKJJC6g8//MDs2bM5evQoSUlJWCyWXMcNBgNnzpwpiluJSBnQpVE1/juqMxHLo/h+R873fkp6Ni98u5sV+2J5q08o1Tw1OlVERIresWPHSE1NJSoqiqioKPbt28evv/5KZGQkkPNcajAYqF27trW42qxZM+tGqVIxmc0WPt1wjHf/ewiT2YKTg4GXe97CkNtrYzabiU/N2ydh9mzS9u6j6hNP4BbatOSDFhERkRJncyF18uTJRERE4O3tTfPmzWnQoEFRxCUiZZyXuxPT+oVxT9NAxi/ZS/zFnNGpvx44x9YTG4jo1YTwsCCNThURkSLn7u5OmzZtaNOmDWvWrGHWrFm88sordOnSBZPJxNatW/nss89YtmwZP/zwAwAmk8nOUYu9xF/MYMyiXfx2+DwAtaq68+EjLWhWw/u6fcyZmSTMnYsp/jwOXl4qpIqIiFQQNhdSIyMjufPOO/nxxx9xcXEpiphEpBy5q7E/rWv7MOmHKJbuigFy1h4b9c0uft4by5t9QvHz0M8OEREpHqNHj2bYsGG88cYb1ra7776bl19+mXfffZf//Oc/Wtu/Avv98HlGfbOL85dyPvDt1TyIN/s0xcPVKd9+KT/8gCn+PBgMVH18cAlEKiIiIqWB0dYLZGVl8cADD6iIKiLX5e3uzHsPt+CTR1vhW9nZ2v7f/We5e/p6ftwdk2dJEBERkaJw+PBhQkJC8rQ7ODgwbtw4unXrxrx58+wQmdhTlsnMOysPMmjOZs5fysDVycg7DzTj/YfDblhEtZjNJMz5HACPu+7CuXbtEohYRERESgObC6l33303W7duLYpYRKSc69E0gP+O7sz9zf/ZQTkpNYsRC3YyfP4OEv4eDSIiIlJUmjRpwo8//njd4126dGH9+vUlGJHY2+mkVB6euYmP1h3FYoFGAR4sH9GBfm2CC7Tk0KV168g8dgyAqk8MKe5wRUREpBSxuZAaGRnJ9u3bmTBhAqdOndKoMhHJV5VKznz4SAs+GtiSKpX+GZ3689447p6+gZ/3xtoxOhERKW9ee+01fvzxR1544QUuXryY5/iWLVu0PmoFsnJfLPe8/xvbTyYB8OitNVk6/HbqV/Mo8DUSZs8BwK1VK9zCwoojTBERESmlbF4j1c/Pj379+vHaa6/x5ptvXvMcg8FAdna2rbcSkXLkntBA2tapwoRl+/h5bxwACZczefbrHdzXLJDJ4U1zFVpFREQKIzw8nI8//piRI0cyZ84cBg4cSMuWLQFYs2YNCxYsoE+fPnaOUopbepaJN386wLxNJwHwcHXk7QeacU9o4E1dJ3XnTtK2bwc0GlVERKQisrmQOnbsWP7zn/9Qt25d2rRpg5eXV1HEJSIVgG9lFz4a2Irle2J4fek+klKzAFi+J5ZNxxJ4o3coPZoG2DlKEREp65566ik6dOjAG2+8wRdffMFHH31kPdajRw8++eQTO0Ynxe1o/CWem7+TA7EpALSo6c0HD7cguIr7TV8r5ecVADjXrUvlO+4oyjBFRESkDLC5kDpnzhz69OnDd999VxTxiEgFdF+zINrVqcprS/fyS9RZAM5fyuTpr7bTOyyISb2a4O2u0akiIlJ4TZo0YcGCBWRlZXH06FFSU1OpWbMmvr6+JRqHxWLhgw8+YMaMGZw8eZJatWrx3HPPMWLEiBuuz5mcnMyoUaPYvHkzZ86cwWw2U69ePYYMGcIzzzyDk1P+myRVRN9tP82EZftIzcxZvuHpzvV44e4QnBwKt8KZ/yvjqdzhdjAYMBhtXiVNREREyhib//U3m81069atKGK5rqysLCZMmEDNmjVxdXWlWbNmzJ8/v0B9LRYL77//PiEhIbi4uBASEsIHH3yQZy3XVatWcd999xEcHIyrqyuBgYH07NmTjRs3XvO6R44cITw8HC8vLzw9PQkPD+fo0aM2v1eRisrPw4VPHm3F+w+H4eX2zy+CS3fF0G36BlbtP2vH6EREpLxwcnKiUaNGtGzZssSLqACTJ09m1KhR3HrrrcyYMYN27doxcuRIpkyZcsO+KSkp/PXXX/Tq1YupU6fy73//m2bNmjFq1CgGDRpUAtGXHZcyshnzzS5e/HY3qZkmfCs78+WQtrzcs1Ghi6iQs2RZ5c6dqdypUxFGKyIiImWFzSNSe/Xqxbp163jqqaeKIp5rGjZsGF9++SXDhw8nNDSUpUuXMnDgQLKzs3nsscfy7Tt58mQmTZrEoEGDGDt2LOvXr2fkyJFcuHCBCRMmWM87cOAALi4uPPvss1SrVo2kpCS++uorOnXqxI8//sg999xjPTcuLo6OHTvi5OTExIkTAXjvvffo2LEjO3fuxN/fv3gSIVLOGQwGwsOqc1vdqryyZC+/HjgHQPzFDJ78cht9W1Rn4v1N8HLXiBsRESl74uLimDp1KkOGDGH27NkADB06FAcHB9566y2GDRtGQMD1l7QJDg7O8yH/008/jZeXF5GRkbz77rvUqFGjWN9DWbDvTDIjFuzk+PnLAHSo78u0/s2p5uFq58hERESkrDNY/ndo5k3666+/ePjhh2ndujVDhw4lODgYBweHPOdVq1atUNffuXMnLVu2ZPLkybz++utAzijTrl27sn//fqKjo3F2vvaU37i4OGrXrs3AgQOtD6sAgwcPZuHChZw4cSLfh9XU1FTq1q1LaGgoq1atsraPHDmSTz75hP3791OvXj0ADh8+TJMmTXj22Wd57733buo9nj59muDgYKKjo/Xwmw+TyUR8fDx+fn7X/BqT/JW1/FksFpbsPMOkH6JISf9nszp/Txem9g2la6OS/cCirOWvNFHubKP82Ub5Kzzlruh9+umnPP300+zcuZOwq3Z73759O61bt+aTTz4p1OCEd999l5deeol9+/bRpEmTAvcrb8+gFouFuX+cYOrPB8k0mXEwGhjTLYRnOtfDaMx/2YT8mEwmTk2fjktyClWHPI5LnTpFGHX5p58ltlH+Ck+5s43yZxvlr/BKe+5sHpHaqFEjAHbt2pWrWPm/TCZToa6/aNEijEYjw4cPt7YZDAaee+45HnzwQdauXUv37t2v2XfZsmVkZGQwYsSIXO0jRozgiy++YNmyZfk+rLq7u+Pr68uFCxfyxNSzZ09rERWgQYMGdO/enUWLFt10IVVE8jIYDPRtWYP29XwZ//0e1h6KB+BsSgZD5m7joVY1eO2+xrmWARARESnNtm/fjouLC82aNcvV3qJFC5ydndmxY0eBrpORkcHFixdJTU1l8+bNvPPOO9SqVYuQkJB8+6WkpJCSkmJ9HRsbC+Q8pxf2Wb20SErN5OXF+/j1YM5sliBvV97r15xWtXywWMzY8vayLl4kY9Ei0lMuYnB1pdrL44oo6orBZDJhNpvL/NeYvSh/hafc2Ub5s43yV3glnbubLdbaXEidMGHCDRfGt8X27dupV68eVapUydXerl07AHbs2HHdQmphHlaTk5PJysri3LlzfP7550RFRTF+/Hjr8ZiYGOLi4mjbtm2evu3atWP58uXExsYSGBh40+9VRPIK8HJlzuA2fLf9NJN/3M/FjJzRqd9uP83vR87zfw80o3OIn52jFBERubGYmBj8/f0x/s8mRUajEX9/f2JiYgp0nXnz5vHkk09aX7dp04bZs2ffcLOpadOmERERkac9ISEBFxeXAt27NNp15iITVhzn3KUsADrX8+aVbrXwcs0mPj7e5uunfbcYS8pFMBox33tPkVyzIjGbzSQnJwPk+dqXG1P+Ck+5s43yZxvlr/BKOnf5zVS/FpsLqZMmTbL1EvmKiYm5ZlEyKCjIejy/vjf7sHrvvfda155ycXHhqaeeyrWW6pU+N4opv0JqeR4NUJz0iY5tynr++rYI4tY6PryyNIrfDp8HIDY5nX/N2UK/1jV4pWcjPFxt/pF2XWU9f/ak3NlG+bON8ld4JZm70jhtqzikpaVdt2Dp6upKWlpaga5zzz33sGrVKpKSkvjvf//Lnj17uHTp0g37jRkzhqFDh1pfx8bG0rZtW6pWrYqfX9n7UNJktvDx+mO8v/owZgs4Oxp5tWcjBrYLLrKBHpbsbI5//z0AlXv2ICA0tEiuW5Fc+Rni6+tbYb7Xi5LyV3jKnW2UP9sof4VX2nNXfFWHInK9B06j0YiTk1O+D5yFeVh97733SExM5NSpU8yZM4f09HSysrJwdXW1XhO45nX/95zrKa+jAYqbPtGxTXnInxPwzj01+SGqEu9viCY10wzAom2nWX/oLK92q03bmp7Fcu/ykD97Ue5so/zZRvkrvJLM3c2OBCgN1q1bR5cuXQp07qpVq7jrrrtwdXUlIyPjmuekp6dbnyVvJCgoyPoB/kMPPcSkSZO4++67OXz4cL659PT0xNMz77+TDg4OpfIXlfycTUln1MJd/HksAYC6fpX48JEWNAnyKtL7pPzyC9l/D6SoOmRImctTaWE0Gsvk11lpofwVnnJnG+XPNspf4ZXm3N10IfXUqVMA1KxZM9frG7ly/s263gOn2WzOVeC8mb5w/YfV1q1bW//+6KOPEhYWxpAhQ/j222+t1wSued309PRc51xPeRsNUFJK+6cSpV15yt/QatW4p0VtXv5+HxuP5vwCdfZiFs9/f5hH2gTzcs+GVHYp2s+JylP+SppyZxvlzzbKX+Epd/lr2LAhn332WYHObdy4MZBTAF29ejVmszlXcdpsNnP27FlrcfRm9evXj4iIiBuu/19erD10jhcX7SbhciYAD7aqQUSvJlQq4n/7LRYLCbPnAODYujUuf+8NISIiIhXXTT9t1K5dG4PBQFpaGs7OztbXN1LYaWFBQUGcPHkyT/uVKfb5PXDa+rDq6upKr169ePfdd0lLS8PNzc3a58p0/JuNCcrXaICSVpo/lSgLylP+gqtW5quh7Zi/5RRv/nSA1MycnzELtkaz4fB5/v1gM9rX9y3Se5an/JU05c42yp9tlL/CU+6uLzAwMNcH4wXRsmVLZs2axZ49ewgLC7O279y5k8zMTFq2bFmoWK7Mhroygri8ysw28+5/DzFzwzEA3J0deLNPU/q0qFEs90vdvIX0qCgAXB/uXyz3EBERkbLlpgupc+bMwWAwWBezv/K6uLRs2ZI1a9aQmJiYa8OpzZs3W4/n19fWh9W0tDQsFgsXL17Ezc2N6tWr4+/vz5YtW/Kcu3nzZgICAgo9mkBEbo7BYGBgu1p0auDH2O/2WKf3nbmQxoBZm3nstlqM69GoyEeoiIiIFEZ4eDijRo0iMjKSWbNmWds//PBDnJ2dCQ8Pt7YlJydbNzD18sqZrh4fH3/N2UtXrnX1zKry5lRCKiMW7mR39AUAmgR58uEjLajrV7nY7pm8ZAkALo0a4diqVbHdR0RERMqOm64uDB48ON/XRa1fv3688847fPTRR7z22mtAzjSbyMhI/Pz8rGtTnT9/nvPnz1OzZk3c3d2Bm3tYPXfuHNWqVct178TERJYsWUJwcHCuYw899BAzZ87k6NGj1KtXD4DDhw/zyy+/8PTTTxdPIkTkuoKruPP10HZ8vfkkb/18kLSsnNGpX/55knWH4vn3g81oV7eqnaMUEZGKLigoiHHjxjFlyhSysrLo1KkT69evZ968eUyYMCHXZqVLlizh8ccf5/PPP7c+b8+YMYMlS5Zw7733Urt2bVJSUlixYgVr1qzh/vvvp2vXrnZ6Z8Vr+Z4Yxi/ey8WMbAAGt6/N+Hsa4eJYvCOlA9+YQqXb22Pw9CK1GAeOiIiISNlh8zCtDRs2cMstt1x3bc/z58+zf/9+OnXqVKjrt2rVikGDBjFx4kTi4+MJDQ1l6dKlrFu3jjlz5lg3Z4qMjCQiIoK1a9dyxx13ADf3sHr77bfTvHlzWrduja+vLydOnGDOnDmcPXuWb775JldMr7zyCt9++y1du3Zl1KhRAEyfPp0qVaowfvz4Qr1PEbGN0Whg0G216RTix0vf7WHL8UQATiWm0n/mJga3r83YHg1xd875sWexWNgVfYGVUXGkpGXj6eZIjyYBhAV7F+soexERqdgiIiLw8fFhxowZLFy4kODgYKZNm2Z9pszPXXfdxZ49e/jqq684e/YsTk5O3HLLLUyfPp3nnnuu+IMvYWmZJiYvj2LBlmgAvN2d+PeDzenW2L9E7m9wcsKrVy9MJhOp8fElck8REREp3WwupHbp0oV58+YxYMCAax5fvXo1AwYMKPQaqZAzXalWrVrMnTuXTz75hJCQEObNm8ejjz56w74FfVgdNmwYS5Ys4d133yU5OZkqVapw22238cILL9CxY8dc5wYGBrJhwwZeeOEFJk6cCEDnzp2ZNm1aruKsiJS8WlUrsfDJW/nizxO8vfIg6VlmAOb+cYJ1h87x74ea4+nqxJhFu4iKSQHAaACzBT5df4wmQZ5M6xdGwwAPe74NEREppwwGA6NHj2b06NH5njd48OA8M786dOhAhw4dijG60uOvsxd5bv4O/jp7CYA2tX14/+EWBHm72TkyERERqcgMFovFYssFjEYjX3311XULqfPmzWPIkCFkZWXZcpty7fTp0wQHBxMdHU2NGsWzWH55YDKZrGuDadOLm1cR83f8/GVe+nY3204m5Wp3cjBgMlswX+Onn9EA7s6OLH6mfa5iakXMX1FR7myj/NlG+Ss85a78K43PoBaLhYVbo4n4MYr0LDMGA4zoUp/n72yAo4PxxhcoAvEffYQlLR2fQY/iVK2avhdspPzZRvkrPOXONsqfbZS/wivtuSvUiNSUlBQuXLhgfZ2QkMCpU6fynJeUlMSCBQuoXr16oQMUESmsOr6V+Oap2/h843H+/cshMrJzRqdmma7/+ZHZAqmZ2YxZtIvlIzpomr+IiEgJSUnP4pXv97J8TywA1TxceO/hMNrX8y2xGEwXL5I4ew7my5fBaKTa6FEldm8REREp/QpVSJ0+fTqTJ08GcqYnjRo16rrrOlksFt58881CBygiYgsHo4GhHevSpVE1nvlqu3WKYH7MFoiKSWH36WTCgr2LP0gREZEKblf0BUYs2EF0YhoAdzT0492HmuNb2aVE47jwzTeYL1/G4ORElUcHlui9RUREpPQrVCH1rrvuwtXVFYvFwiuvvEL//v0JCwvLdY7BYKBSpUq0bt2adu3aFUWsIiKFVs+vMnc0rFagQuoVK/fFqZAqIiJSjMxmC7N/P87bKw+SbbbgaDQwrkcjnuhQB6OxZGeFmDMzSfxyHgBevXvjeJ3NdEVERKTiKlQh9fbbb+f2228HICMjg759+xIaGlqkgYmIFLWL6dnWjaVuxGiA5DSt7SwiImILi8XCrugLrIyKIyUtG083R3o0CSAs2JvEy5m88O1u1h2KByC4ihsfPtLSbh9ipvy4nOxz58BgoMrjj9slBhERESndClVIvdqVXet//vlnli9fzsmTJwGoXbs29957L/fcc4+ttxARKRKebo4FKqJCTrHV09XmH5EiIiIV1qG4i4xZtIuomBQA64eZn64/Ru2q7qSkZZGYmvOh5b3NApnaNxRPVye7xGoxm0n4fA4Albt2xaVuHbvEISIiIqWbzVWClJQUHnjgAdasWYPBYMDf3x+LxcIvv/zCJ598wh133MGSJUvw9PQsinhFRAqtR5MAPl1/rMDn/xJ1lla1fOjW2L8YoxIRESl/DsVd5IGP/yA1M9vadvWHmScSUgFwdjQwuVdT+rcJtusGj5c2bCDzyFEAqj4xxG5xiIiISOlmtPUCo0ePZs2aNbzxxhskJSVx5swZYmJiSEpKYsqUKaxbt47Ro0cXRawiIjYJC/amSZAnBV1y7UTCZYbN2074jI2sOxSPxVLA4awiIiIVmMViYcyiXaRmZt9wJkjNKu52L6ICXPhmEQBuLVrg3rKlXWMRERGR0svmQur333/PU089xfjx4/Hw8LC2e3h48MorrzBs2DC+//57W28jImIzg8HAtH5huDs7XreYajSAk4MBx6t+Ou45ncwTX25n2KJDbDxyXgVVERGRfOyKvkBUTEqBltM5cu4yu08nF39QN1D9P+/i/+qr+D433N6hiIiISClmcyEVyHejKW1CJSKlScMADxY/055bAv9ZbuTqouotgZ789HxH1r7YhX6ta+Bw1cG9sZd57PNt9J+5ic3HEkoybBERkTJjZVTczZ2/7+bOLw5Gd3eqDHqUyn9vqCsiIiJyLTavkXrPPfewfPlynnnmmWseX758uTacEpFSpWGAB8tHdGD36WRW7osjOS0LLzcnejQNoHkNL+v0wncebM4zd9Tng9WHWbrrDFcGom45nkj/mZvoUN+XMXeH0LKmjx3fjYiISOmSkpZt3VjqRowGSE7LKv6gRERERIqAzYXU1157jYcffpj77ruP5557jvr162MwGPjrr7+IjIwkJiaG//znP5w7dy5Xv2rVqtl6axGRQjMYDIQFexMW7J3veXV8KzG9fxhPd6rDOz9HsfpwkvXY70fO8/uR83Rp6MeYbg0JreFVzFGLiIiUfp5ujgUqokJOsdXLzal4A8pHfOQMjK4uePfvj8NVy5SJiIiIXIvNhdQmTZoAsHfvXlasWJHr2JV1BJs2bZqnn8lksvXWIiIlpn61yrx5b11eMLny/poj/BJ11nps7aF41h6Kp1tjf8Z0C8m1bICIiEhF06NJAJ+uP1bw85sGFGM015edmEjCZ59hycjAkm3C9+mn7BKHiIiIlB02F1InTJhg9102RURKSsMADz4d1Jp9Z5KZtuov1hz8Z7T9qv1nWbX/LPeGBjLqrgY08NfIFhERqXjCgr1pEuTJgdj8N5wyGqBxkCfN7TSjI+nr+VgyMjC4ueHdv59dYhAREZGyxeZC6qRJk4ogDBGRsqVpdS/mDG7DjlNJTF/1F78dPm899tPeWH7eF0t48yBG3hVCHd9KdoxURESkZBkMBqb1C+OBj/8gNTP7msVUowHcnR2Z1i/MLoMyzGlpJH39NQDeDz6Io4/WOxcREZEbMxblxfbt28ePP/7I8uXL2bdvX1FeWkSkVGpZ04d5T7Rj0VO30a5OFWu7xQJLd8Vw17T1vPTtbqITU+0YpYiISMlqGODB4mfa51ruxnhVvfSWQE++f7Y9IXaavXHh++8xXbgADg5U+de/7BKDiIiIlD02j0gFWL58OSNHjuTEiRNAztqoBoOBOnXq8N5773HfffcVxW1EREqttnWqsHDYrfx5NIH/rPqL7SdzNqUymS18u/00S3ae4aHWwYzoWp8gbzc7RysiIlL8GgZ4sHxEB3afTmblvjiS07LwcnOiR9MAmtfwstvyYBaTicS5XwDg2b07zjWq2yUOERERKXtsLqT+97//pXfv3tSoUYM333yTxo0bY7FYOHDgAJ9++il9+vTh559/plu3bkURr4hIqWUwGGhf35fb6lVl/V/xTFv1F3tOJwOQbbawYMspFm8/zSNtgxnepT7VPF3tHLGIiEjxMhgMhAV7Exbsbe9QrC6uWkVWdDQAVZ4YYudoREREpCyxuZAaERFB48aN2bhxIx4e/0zNCQ8PZ/jw4bRv357JkyerkCoiFYbBYOCOhtXoHOLHrwfOMW3VXxyITQEg02Tmiz9PsnBrNINurcXTd9TDt7KLnSMWERGpOJK+ylkb1f22W3Fr0sTO0YiIiEhZYvMaqbt27eLxxx/PVUS9wsPDgyFDhrBz505bbyMiUuYYDAa6NfbnpxEd+GhgSxpUq2w9lpFtZtbvx+n0zlreXnmQpMuZdoxURESk4qj+4Qf4Pvccvk89be9QREREpIyxeUSqs7Mzly9fvu7xS5cu4eTkZOttRETKLKPRwD2hgXRvEsDyPTG89+thjp/P+bmZmmni43VHmffnSYbcXpsnOtbFy00/M0VERIqLo48Pfs8Nt3cYIiIiUgbZPCK1Y8eOREZGcvjw4TzHjhw5wowZM+jUqZOttxERKfMcjAbCw6qzanQn3n2oOcFV/tl06lJGNh+sOULHt9fw4erDXMrItmOkIiIiIiIiIvK/bB6ROnXqVNq3b0/Tpk3p1asXDRs2BODgwYMsX74cNzc3pk6danOgIiLlhaODkQdb1SA8LIjvtp/mw9WHiUlOByAlPZv/rPqLORuP81Tnejx2Wy3cnW3+US0iIlLhxUfOwNG/Gl69emF00frkIiIicvNs/u28SZMmbNu2jfHjx7Ny5UoWL14MQKVKlbj//vt58803CQkJsTlQEZHyxsnByCNta9K3ZXW+2RpN5JojnLuYAUBSahb/t+Igs347xjN31Gdgu5q4OjnYOWIREZGyKevsWc5/+ilkZWFJz6DKoEftHZKIiIiUQUUyzKlBgwZ89913mM1m4uPjAfDz88NotHnlABGRcs/F0YHHbqtNv9bBfLXpJJ+sP8r5SzmbT52/lMmU5fuZueEow7vUp3+bYFwcVVAVERG5GUnz5kFWFkYPD7z69LF3OCIiIlJGFWml02g04u/vj7+/v4qoIiI3ydXJgaEd67JhbBde7tkIH/d/Np06m5LBhGVRdH13PQu2nCLLZLZjpCIiImWH6dIlkhZ+A4DPww/jULmSnSMSERGRskrVThGRUsbd2ZGnO9djw9guvNAtBE/XfyYPnLmQxvjv93Lnf9bz3fbTZKugKiIikq8L3yzCfOkSBicnfDSlX0RERGygQqqISCnl4erEiDsb8Nu4rjzftT6VXf4pqJ5KTOXFb3dz9/QNLNt1BpPZYsdIRURESidLZiaJX34JgGev+3GqVs3OEYmIiEhZpkKqiEgp5+XmxJi7G/Lb2C483bkebldtOnXs/GVGLtxFz/c38PPeWMwqqIqIiFgl//wz2WfPAlB1yBA7RyMiIiJlXZkopGZlZTFhwgRq1qyJq6srzZo1Y/78+QXqa7FYeP/99wkJCcHFxYWQkBA++OADLJbcxYY1a9YwdOhQGjVqhJubG1WrVqVPnz7s2bMnzzUHDx6MwWDI86dGjRpF8n5FRK7Fp5IzL/dsxIaxXRjaoQ4ujv/8CP/r7CWe/XoH9374O6v2n83zM05ERKQiujIatfIdd+BSr56doxEREZGyzvHGp9jfsGHD+PLLLxk+fDihoaEsXbqUgQMHkp2dzWOPPZZv38mTJzNp0iQGDRrE2LFjWb9+PSNHjuTChQtMmDDBet7YsWOJj4/nwQcf5JZbbiE2NpYZM2bQvn17Nm7cSPPmzXNd18nJiTlz5uRqq1RJC9eLSPHz83Dhtfsa82Snuny09ggLtkST+fdaqQdiU3jyy200r+HF6G4hdA7xw2Aw2DliERER+6jxwYckfvEFnj172DsUERERKQcMllI+bGnnzp20bNmSyZMn8/rrrwM5o0y7du3K/v37iY6OxtnZ+Zp94+LiqF27NgMHDmT27NnW9sGDB7Nw4UJOnDhBQEAAAOvXr6djx44Yjf+M8Dp27BhNmzalZ8+eLF68OE//9PT0InmPp0+fJjg4mOjoaI1qzYfJZCI+Ph4/Pz8cHBxu3EFyUf5sU5rzd+ZCGpFrjvDttmiy/2dqf6taPrzQLYT29X3tFF3pzl1ZoPzZRvkrPOWu/NMzaMHoe8E2yp9tlL/CU+5so/zZRvkrvNKeu1I/tX/RokUYjUaGDx9ubTMYDDz33HOcO3eOtWvXXrfvsmXLyMjIYMSIEbnaR4wYQUZGBsuWLbO2de7cOVcRFaBu3bq0aNGCqKioa17fbDaTkpKiKbQiYlfVvd2Y2jeUNS/cwYOtamC8agDq9pNJDJi1mYdn/smW44m5+lksFnaeSmLqigOM/34vU1ccYOepJP1MExEREREREbmGUj+1f/v27dSrV48qVarkam/Xrh0AO3bsoHv37tft6+LiQrNmzXK1t2jRAmdnZ3bs2HHD+8fGxhIUFJSnPTMzE09PTy5fvoyXlxf9+vXjnXfewdvb+4bXTElJISUlJdc9IKfqbjKZbti/ojKZTJjNZuWokJQ/25SF/FX3duHtvk15ulMdPlxzhB/2xHKlJrrpWCL9Pv2TDvWrMvquBrg5O/DSt3uIir0IgNEAZgt8uv4YTQI9+PdDzWjo71EkcZWF3JVmyp9tlL/CK8nclcbRBlK2xc+YgUuDBnjceScGfX2JiIhIESn1hdSYmBgCAwPztF8pbsbExOTb19/fP89IU6PRiL+/f759AebPn8/x48cZPXp0rvbAwEBefPFFWrVqBcCqVauYNWsW27dv548//sDFxSXf606bNo2IiIg87QkJCTfsW5GZzWaSk5MB8vw/lRtT/mxTlvJXGRjfJYiHm/kwa1Msqw8nWY/9fiSB348k4PB34fSKq/9+IO4iD33yJzP7NaKer5vN8ZSl3JVGyp9tlL/CK8ncXVlqSaQoZEZHc37GR2A2Ezh1Kt59ets7JBERESknSn0hNS0t7ZrFRaPRiJOTE2lpaTfdF8DV1TXfvocOHeLZZ5+ldevWPP3007mOTZ06Ndfr/v37c8stt/Diiy8yb948hg4dmt9bYsyYMbnOiY2NpW3btlStWhU/P798+1ZkV0bE+Pr6auRKISh/timL+fPzg7aNanIgNoX3Vh/h1wPnrMdM+czeN1sgLcvMW6ujWTa8vc2bVZXF3JUmyp9tlL/CU+6krEqc+wWYzThUqaJNpkRERKRIlfpCqqurKxkZGXnazWYzWVlZuLq63nRfgPT09Ov2jYmJoUePHnh7e7N06VKcnJxuGOfzzz/P+PHj+fXXX29YSPX09MTT0zNPu4ODg35RuQGj0ag82UD5s01ZzV/TGj7M+lcb9py+wKQfothx6sIN+5gtEBV7kX2xlwgL9rY5hrKau9JC+bON8ld4yp2UNdlJSVz4/nsAfAYOwJjP7woiIiIiN6vUz3ELCgqyriF6tSvT8q+1funVfc+ePYvZbM7VbjabOXv27DX7JiQk0K1bN1JTU/nvf/9L9erVCxSnk5MTAQEBJCYm3vhkERE7aFbDmzZ1qtz4xKus3BdXTNGIiIgUvaQFC7CkpWFwdcVnwAB7hyMiIiLlTKkvpLZs2ZKjR4/mKVBu3rzZejy/vhkZGezZsydX+86dO8nMzMzTNyUlhe7du3P69GlWrlxJSEhIgeNMT08nNjaWatWqFbiPiEhJS0nLxngTM/VPJVwuvmBERESKkDk9naSvvgbAu29fHH187ByRiIiIlDelvpDar18/zGYzH330kbXNYrEQGRmJn58fXbp0AeD8+fMcPHiQ1NRU63nh4eE4OzsTGRmZ65offvghzs7OhIeHW9vS0tK47777iIqK4ocffqBFixbXjCcjIyPXPa546623yM7Opnv37ja9XxGR4uTp5phrY6kb+XlfHA998gdLd54hI1u7nouISOmVvHQppsREMBqp8vhge4cjIiIi5VCpXyO1VatWDBo0iIkTJxIfH09oaChLly5l3bp1zJkzx7qZVGRkJBEREaxdu5Y77rgDyJnaP27cOKZMmUJWVhadOnVi/fr1zJs3jwkTJhAYGGi9z8CBA/ntt98YMGAA0dHRfPXVV9ZjlStXpnfv3kDOxlC33XYbffv2JSQkBIPBwK+//sqPP/5I586deeSRR0osNyIiN6tHkwA+XX/spvpsPZHE1hNJTF7uzEOtavBI25rU9q1UTBGKiIjcPIvZTOLncwHw6H43zsHB9g1IREREyqVSX0gFmDVrFrVq1WLu3Ll88sknhISEMG/ePB599NEb9o2IiMDHx4cZM2awcOFCgoODmTZtGqNGjcp13o4dOwCYP38+8+fPz3WsVq1a1kKqt7c33bp1Y9WqVXzxxRdkZ2dTt25dIiIiGDt2LI6OZSKlIlJBhQV70yTIkwOxKfmOTDUA7i4OXM74ZxRq4uVMPt1wjE83HKNjA18GtqvJnbf44+RQ6ic3iIhIOWcwGqn+/nskzJlDlUGP2TscERERKacMFovlJiZ5SnE4ffo0wcHBREdHU6NGDXuHU2qZTCbi4+Px8/PT7sGFoPzZpjzl71DcRR74+A9SM7OvWUw1GsDd2ZHvn22Ps4ORBVtOsWhbNEmpWXnO9fd0oX+bmjzcJpggb7dr3q885c4elD/bKH+Fp9yVf3oGLRh9L9hG+bON8ld4yp1tlD/bKH+FV9pzp2FEIiIVTMMADxY/055bAj2tbVdvQHVLoCffP9ueEH8PavtWYvw9t/Dn+Dt5/+Ew2tTOvXHH2ZQMPlh9mA5vr2HoF9tYe+gc5ptZhFVERERERESkjNA8dBGRCqhhgAfLR3Rg9+lkVu6LIzktCy83J3o0DaB5DS8MBkOu812dHAgPq054WHUOxV1k/uaTfL/jDBczsgEwW+DXA2f59cBZavi4MaBdTR5qFYyfh4s93p6IiFQg5z/+GLeWrXBv2ybPv18iIiIiRUmFVBGRCspgMBAW7E1YsPdN9WsY4EFEeFPG9WzEj7tj+GrTKfaeSbYeP52UxjsrDzF91V90bxLAI21qULeyRqmKiEjRyzh6lPj3PwCgRuSHeNx1l50jEhERkfJMhVQRESkUd2dH+repSf82Ndlz+gJfbzrFst1nSM8yA5BlsrB8TyzL98RSy8eVQe1TeahVTbzcnewcuYiIlBcJc+YA4OjvT+VOnewcjYiIiJR3WiNVRERs1qyGN28/2IzNr9xFRK8mNKhWOdfxk0npvPHTQdq+9SsvfrubnaeS0F6HIiJii6xz50j54UcAqjw2CIOzs50jEhERkfJOI1JFRKTIeLk58a/2tXnstlpsPZHE15tPsmJvLJmmnKJpRraZ77af5rvtp2kS5MnAdrUIDwuikov+ORIRkZuTNO8rLFlZGCtXxrtfP3uHIyIiIhWARqSKiEiRMxgMtK1ThfcfbsHv47rwXIfq1KzinuucqJgUXlmyl3Zvrea1pXs5EJtip2hFRKSsMV26TNLChQB49++Hg4eHnSMSERGRikBDgEREpFhVreTMo60DGNm9KX8ezxml+uuBc5jMOaNUL2Vk89WmU3y16RQta3rz6K21uCc0EFcnBztHLiIipdWF777FfPEiODlR5bHH7B2OiIiIVBAqpIqISIkwGg10CvGjU4gfccnpfLM1mgVbThGXkm49Z8epC+w4dYHJy/fzYMsaDGhXk7p+lfO5qoiIVDSWrCwSv/gSAK9778XJ39/OEYmIiEhFoan9IiJS4gK8XBl5VwN+H9eFmYNa0TnED4Phn+MXUrOY9ftxuv5nPQM+28TPe2PJMpntF7CIiJQejo4Evf1/VO7cmSpDHrd3NCIiIlKBaESqiIjYjaODkbubBHB3kwBOJaSyYOspFm2NJuFypvWcP44m8MfRBPw8XOjfOpiH2wZTw8c9n6uKiEh5ZjAYqNS2LZXatrV3KCIiIlLBaESqiIiUCjWrujOuRyP+GN+VDx5pQbs6VXIdj7+YQeTaI3R8Zy1D5m5lzcGz1nVWRURERERERIqbRqSKiEip4uLoQK/mQfRqHsSRcxf5evMpFm8/TUp6NgAWC6w5eI41B89R3duNR9oG069NMNU8XO0cuYiIFLfzMz+jcqeOuDZqZO9QREREpALSiFQRESm16lfzYOL9Tdj8yl38+8FmNA/2znX8zIU03v3vX7SfuoZnv97OxiPnsVg0SlVESi+LxcL7779PSEgILi4uhISE8MEHHxTqZ9fRo0dxdXXFYDCwadOmYoi2dEmLiiJ+2jSO9+7D5U2b7R2OiIiIVEAqpIqISKnn5uzAQ62DWTb8dpaP6MAjbWvi7uxgPZ5ttvDz3jgGztrMnf9Zz6zfjpF01TqrIiKlxeTJkxk1ahS33norM2bMoF27dowcOZIpU6bc9LVGjx6No2PFmWCWOOdzAJyCg3Fv09rO0YiIiEhFpEKqiIiUKU2rezG1byibX7mTKb2b0ijAI9fxY+cv88ZPB2g3dTVjvtnF9pOJGqUqIqVCXFwcU6dOZciQIXz55ZcMHTqUefPm8a9//Yu33nqLuLi4Al9r5cqV/PLLL4wePboYIy49ss6cIWXlSgCqDP4XBgeHG/QQERERKXoqpIqISJnk4erEoFtrsWJkRxY/cxt9W1TH2fGff9Yys818v/MMD3z8Jz3f/415m05yMT0rz3UsFgs7TyUxdcUBxn+/l6krDrDzVJKKryJS5JYtW0ZGRgYjRozI1T5ixAgyMjJYtmxZga6TlZXFqFGjGDlyJPXq1SuOUEudhC++AJMJB29vvPv2tXc4IiIiUkFVnLlAIiJSLhkMBlrVqkKrWlV4/b7GLN5xmq83n+L4+cvWcw7GXeT1pfuY+vMBwsOqM7BdTZpW9+JQ3EXGLNpFVEwKAEYDmC3w6fpjNAnyZFq/MBr+z4hXEZHC2r59Oy4uLjRr1ixXe4sWLXB2dmbHjh0Fus706dNJSkritdde4/vvvy/w/VNSUkhJSbG+jo2NBcBkMmEymQp8nZJmupDMhW+/A8DrkYexODuXaLwmkwmz2Vyqc1SaKX+2Uf4KT7mzjfJnG+Wv8Eo6dw43OctFhVQRESk3fCo5M7RjXYbcXoc/jyXw9eaT/DfqLNnmnNGlqZkmFmw5xYItp2jo78GJhMtkmczW/uarBqEeiE3hgY//YPEz7VVMFZEiERMTg7+/P0Zj7klhRqMRf39/YmJibniN2NhY3njjDaZNm4anp+dN3X/atGlERETkaU9ISMDFxeWmrlWS0uZ9hSUtDZydMXXvTnx8fIne32w2k5ycDJDn/53cmPJnG+Wv8JQ72yh/tlH+Cq+kcxcQEHBT56uQKiIi5Y7RaOD2+r7cXt+XcynpLNoWzYIt0Zy5kGY959DZi/lew2yB1MxsxizaxfIRHTAYDMUdtoiUc2lpadctWLq6upKWlnbNY1cbO3Ys9erVY8iQITd9/zFjxjB06FDr69jYWNq2bUvVqlXx8/O76euVBHNGBseXLgHAq08f/Bs0KPEYroyI8fX1velRK6L82Ur5KzzlzjbKn22Uv8Ir7blTIVVERMq1ap6uPNe1Ac/cUZ91h87x9eZTrDl4rkB9zRaIiklh9+lkwoK9izdQESlT1q1bR5cuXQp07qpVq7jrrrtwdXUlIyPjmuekp6fj6uqa73X++OMPvv76a9auXVuoERqenp7XHMXq4OBQKn9RATC6uhIYEUHCnM/xHfK43eI0Go2lOk+lnfJnG+Wv8JQ72yh/tlH+Cq80506FVBERqRAcjAbuvMWfO2/x55Ule5m/+VSB+y7fHaNCqojk0rBhQz777LMCndu4cWMAgoKCWL16NWazOVch1Gw2c/bsWYKCgvK9ztixY+nYsSPVq1fnyJEjAJw7l/PB0OnTpzl27Bh169YtzNsptQxGIx533YXHXXfZOxQRERERFVJFRKTisVj+2ViqIOZsPM7R+Et0aVSNLg2rEVzFvXgDFJFSLzAwMNc0+YJo2bIls2bNYs+ePYSFhVnbd+7cSWZmJi1btsy3/+nTpzl58iQNrjG9/aGHHsLFxYX09PSbiklERERECk6FVBERqXA83RwLXESFnILr2kPxrD0UD0RRv1plujT0o0ujarSuVQVnRy0gLyI3Fh4ezqhRo4iMjGTWrFnW9g8//BBnZ2fCw8OtbcnJycTGxhIYGIiXlxcAH3/8MZcvX851zXXr1jFjxgzeeustGjVqVDJvpIQkfD4Xj2534Vyjhr1DEREREQFUSBURkQqoR5MAPl1/rND9j5y7xJFzl/jst+NUdnGkYwNfujSsxh0N/ajmmf8ahyJScQUFBTFu3DimTJlCVlYWnTp1Yv369cybN48JEyYQGBhoPXfJkiU8/vjjfP755wwePBiAnj175rnmpUuXAOjSpQu33npribyPkpC6Yyfn3n6bc//+N7UXLcKtaRN7hyQiIiKiQqqIiFQ8YcHeNAny5EBsSr4jU40GaBzoyf89EMr6v86z5uA5dp5KytXnUkY2K/bFsWJfHABNq3vSpWE1ujSqRvMa3jgYDcX8bkSkLImIiMDHx4cZM2awcOFCgoODmTZtGqNGjbJ3aKVKwpzZADjXqYNr41vsHI2IiIhIDhVSRUSkwjEYDEzrF8YDH/9Bamb2NYupRgO4OzsyrX8YIf4eNK3uzfAu9Um6nMmGw/GsPXiO9X/Fk5SalavfvjMp7DuTwodrjlClkjOdQ/y4o6EfnUP88HZ3LqF3KCKllcFgYPTo0YwePTrf8wYPHmwdiVoU55UlGceOc2n1GgCqDnkcg1HLp4iIiEjpUCaeSrKyspgwYQI1a9bE1dWVZs2aMX/+/AL1tVgsvP/++4SEhODi4kJISAgffPABFkvu35rXrFnD0KFDadSoEW5ublStWpU+ffqwZ8+ea173yJEjhIeH4+XlhaenJ+Hh4Rw9etTm9yoiIiWjYYAHi59pzy2Bnta2qweP3hLoyffPtifE3yNXP59KzoSHVee9h1uw7bVuLH6mPSO61qdpdU/+V+LlTJbsPMPIhbtoOWUVD378BzPWHmF/TEqef4dERCRH4uefg8WCo58fnvffb+9wRERERKzKxIjUYcOG8eWXXzJ8+HBCQ0NZunQpAwcOJDs7m8ceeyzfvpMnT2bSpEkMGjSIsWPHsn79ekaOHMmFCxeYMGGC9byxY8cSHx/Pgw8+yC233EJsbCwzZsygffv2bNy4kebNm1vPjYuLo2PHjjg5OTFx4kQA3nvvPTp27MjOnTvx9/cvnkSIiEiRahjgwfIRHdh9OpmV++JITsvCy82JHk0DaF7DC4Mh/2n5DkYDrWr50KqWDy/c3ZCzKemsPxTPmoPn+P3IeS5lZFvPNVtg28kktp1M4t+/HCLA05Uujfy4o2E1OtT3pZJLmfgnWUSkWGWfP0/ysmUA+Dw2CKOzRvKLiIhI6VHqf2vbuXMnc+fOZfLkybz++usADB06lK5du/LSSy/x8MMP43ydB6y4uDimTp3KkCFDmD17trWvg4MDb731FsOGDSMgIACA//znP3Ts2BHjVVOHBg4cSNOmTZk8eTKLFy+2tk+dOpXExET2799PvXr1ALj//vtp0qQJU6dO5b333iuOVIiISDEwGAyEBXsTFuxt87X8PV3p1yaYfm2Cycw2s+1EImsPnWPtoXiOnLuU69y4lHQWbIlmwZZonB2MtK1ThTsa+tG1UTXq+Fa6YRFXRKQ8SvzqKyyZmRjd3fHp39/e4YiIiIjkUuqn9i9atAij0cjw4cOtbQaDgeeee45z586xdu3a6/ZdtmwZGRkZjBgxIlf7iBEjyMjIYNnfn3YDdO7cOVcRFaBu3bq0aNGCqKioPDH17NnTWkQFaNCgAd27d2fRokWFep8iIlK+ODsaaV/fl1fvbcyvYzqz4aUuTA5vwh0N/XBxzP3vTabJzO9HzvPGTwfo+p/13PHuOib9EMX6v+LJyDLZ6R2IiJQs8+XLJC1YCIB3v344eOZdMkVERETEnkr9iNTt27dTr149qlSpkqu9Xbt2AOzYsYPu3btft6+LiwvNmjXL1d6iRQucnZ3ZsWPHDe8fGxtLUFCQ9XVMTAxxcXG0bds2z7nt2rVj+fLlxMbGEhgYeN1rpqSkkJKSkuseACaTCZNJvzBfj8lkwmw2K0eFpPzZRvkrPOUuR3VvFwa2DWZg22DSMk1sOpbAur/Os/bQOc5cSM917smEVOb+cYK5f5zAzclIqxqV6db0Ml0b+RPk7Wand1A26euv8Eoydw4ODsV+Dyn9DE5O+L/0IonzvqLKv/JfvktERETEHkp9ITUmJuaaRckrxc2YmJh8+/r7++cZaWo0GvH398+3L8D8+fM5fvx4rl1Vr/S5UUz5FVKnTZtGREREnvaEhARcXFzyjakiM5vNJCcnA+T5fyo3pvzZRvkrPOXu2ppWNdD0Nj+G3+rLicR0/jiRzMbjyeyOuYTJ/M95aVlmfj+ewu/HU5j440HqVXWlfR0v2tf2IjSoMo5GLQGQH339FV5J5u7KUktSsRmcnfF+8EG8HnhAy5uIiIhIqVTqC6lpaWnXLC4ajUacnJxIS0u76b4Arq6u+fY9dOgQzz77LK1bt+bpp5/OdU3gmtd1dXXNdc71jBkzhqFDh1pfx8bG0rZtW6pWrYqfn1++fSuyKyNifH19NXKlEJQ/2yh/hafc3Vi1atC2EYwCLqZn8fuRBNYdimfdX/Gcv5SZ69yjCekcTUhn3razeLg60rGBL3eE+NE5xBffyvow7n/p66/wlDuxFxVRRUREpLQq9YVUV1dXMjIy8rSbzWaysrKsxcub6QuQnp5+3b4xMTH06NEDb29vli5dipOTU65rAte8bnp6eq5zrsfT0xPPa6z55ODgoF9UbsBoNCpPNlD+bKP8FZ5yV3DelRy4r3l17mteHbPZwp7TSfy88yRbTl9m9+lkLJZ/zr2Yns3Pe+P4eW8cBgM0q+5Fl0bV6NKwGqHVvTBqtCqgrz9bKHdSEiwWC0nz5+PZsyeO/7Ocl4iIiEhpUuoLqUFBQZw8eTJP+5Up9levX3qtvqtXr8ZsNueakmY2mzl79uw1+yYkJNCtWzdSU1P57bffqF69ep5rwj/rmt5sTCIiIgVlNBoIre5FgHMg4+7z40JaNuv/imftoXjWHzpHSnq29VyLBXafTmb36WTe+/UwvpWd6RxSjS6N/OjYwA8vN6d87pRTyNgVfYGVUXGkpGXj6eZIjyYBhAV7a3SYiBSr1E2bODvlDc6982/qLluKc+3a9g5JRERE5JpKfSG1ZcuWrFmzhsTExFwbTm3evNl6PL++s2bNYs+ePYSFhVnbd+7cSWZmZp6+KSkpdO/endOnT7Nu3TpCQkLyXLN69er4+/uzZcuWPMc2b95MQECACqkiIlIsqlZ2oW/LGvRtWYNsk5md0RdYe/Acaw6e42DcxVznnr+UyeIdp1m84zQORgOtavnQ9e/RqiH+lXMVRw/FXWTMol1ExeRshGg0gNkCn64/RpMgT6b1C6NhgEeJvlcRqTgSZs0GwKVePZxq1bJzNCIiIiLXV+p3XejXrx9ms5mPPvrI2maxWIiMjMTPz48uXboAcP78eQ4ePEhqaqr1vPDwcJydnYmMjMx1zQ8//BBnZ2fCw8OtbWlpadx3331ERUXxww8/0KJFi+vG9NBDD7FixQqOHj1qbTt8+DC//PILDz30kM3vWURE5EYcHYy0qV2FsT0asXJUJ/4c35W3+oTSrbE/7s65p2GbzBa2HE/k/1YcpPt7G+jw9lpeXbKXX/efZXd0Eg98/AcHYlOs55uvWj7gQGwKD3z8B4f+p1ArIlIU0g8e5PLGjQBUeWKIRsCLiIhIqVbqR6S2atWKQYMGMXHiROLj4wkNDWXp0qWsW7eOOXPmWDd9ioyMJCIigrVr13LHHXcAOVPsx40bx5QpU8jKyqJTp06sX7+eefPmMWHCBAIDA633GThwIL/99hsDBgwgOjqar776ynqscuXK9O7d2/r6lVde4dtvv6Vr166MGjUKgOnTp1OlShXGjx9f7DkRERH5X4FebgxoV5MB7WqSkW1iy/FE1h6MZ+2hcxw/fznXuWcupPH15lN8vfkUBsBy7UsCOUXV1MxsxizaxfIRHVTkEJEilTBnDgBO1avj2b27naMRERERyV+pL6QCzJo1i1q1ajF37lw++eQTQkJCmDdvHo8++ugN+0ZERODj48OMGTNYuHAhwcHBTJs2zVoAvWLHjh0AzJ8/n/nz5+c6VqtWrVyF1MDAQDZs2MALL7zAxIkTAejcuTPTpk3LVZwVERGxBxdHBzo2yFkbdcL9jTl+/jJrD55j7aFzbD6WSKbJbD03vyLqFWYLRMWksPt0MmHB3sUWt4hULFkxMaT8vAKAKoMHY3AsE7+aiIiISAVmsFgsBfkdSorR6dOnCQ4OJjo6mho1atg7nFLLZDIRHx+Pn5+fdg8uBOXPNspf4Sl3tinq/F3OyOaPowmsPXSOZTvPcDnTVOC+gZ6udAzxpbZvJepUrUStqpWo7euOu3PpLX7o66/wlLvyz97PoGen/h+JX3yBg5cX9deuwejuXuIxFIS+F2yj/NlG+Ss85c42yp9tlL/CK+25K72/+YiIiEiRq+TiSLfG/nRr7I/FYmHh1mgK+pFqbEo6i7adztPu7+lC7aqVcv74VqKOrzu1fStRq0ol3JxL38OPiNifKSWFC99+C4D3gEdKbRFVRERE5GoqpIqIiFRQnm5OBS6i5udsSgZnUzLYfDwxz7EAT1dq+7pTxzen0FqraiXq+FaiVlV3XJ1UZBWpqAwODvgOf5akbxZRpQDLdYmIiIiUBiqkioiIVFA9mgTw6fpjBT7/xbtDADiRkMqJ85c5kXCZ85cy8+0Tl5JOXEo6m47lLbIGebn+vTzA36NY//57zSoqsoqUd8ZKlaj6xBNUGTJEm9iJiIhImaFCqoiISAUVFuxNkyBPDsSmYM5nZKrRAI2DPBnepX6egsfF9CxOJqRy/PxlTpy/zPGEy5z8u9CacDn/ImtMcjoxyen8eSwhV7vBAEFebtT2dc8ZwXpVsTW4ijsujsVTZLVYLOyKvsDKqDhS0rLxdHOkR5MAwoK9VegRKSb63hIREZGyRIVUERGRCspgMDCtXxgPfPwHqZnZ1yymGg3g7uzItH5h1yx4eLg60bS6F02re+U5lpKe9ffI1b9HsP49ivVEQiqJ+RRZLRY4cyGNMxfS2Hjk2kXWOr45G11dvTZrzSruODsabz4RwKG4i4xZtIuomBTr+5AkduIAAB2aSURBVDZb4NP1x2gS5Mm0fmE0DPAo1LVF5B8Ws5nkJUvw7NEDY6VK9g5HRERE5KaokCoiIlKBNQzwYPEz7a9ZRAS4JdCT6f3DCPG/+SKip6sTzWp406yGd55jyWlZ/xRWz6dyIuEyx89f5mTCZZJSs657zauLrL8fyX3MaIDqPm65iqu1fFzxMGTg5WPG7Tq7fh6Ku2gtJl9xdVH5QGwKD3z8B4ufaa9iqoiNLq1bT+yrr3H2nX9T94cfcPKvZu+QRERERApMhVQREZEKrmGAB8tHdGD36WRW7osjOS0LLzcnejQNoHkNr2KZeuvl5kTzYG+aB3vnOZacmsXxhKtGsJ6/zPG/R7Ump12/yGq2QHRiGtGJafx2+HyuYw7G/VT3dqO2byVqV80ZyXpl06vR3+y67ojcK9dNzcxmzKJdLB/RQVORRWyQMHs2AC716qmIKiIiImWOCqkiIiKCwWAgLNibsGsUNkual7sTYe7XjuVCambOeqxXjWQ9cT5nNGtKenbei/3NZLZwKjGVU4mpbChETGYLRMWksPt0cqnIUXG6slbsir2xnEu6SDWfRHqGBmqtWCmw2EmTyPjrcJ528+XLZBw6BIDpwgVODBiIS0gDAidNKuEIRURERApHhVQREREpM7zdnWlR05kWNX1ytVssFi5cPZL177VZj5+/xPH4y1zKNBXJ/ft+tBFvd2cquThQydmRyi6OVHZ1pJKLI5Wd//6viwOVXK783fHvvzvknHtVm7uzQ6krTF57rdjzzPztuNaKlQLL+OswaTt25HtO5rFjJRSNiIiISNFRIVVERETKPIPBgE8lZ3wqOdPyqiKryWTi3LlzOFby5lRSeq4Nr06cv8zB2BSyrjen/xrMFki8nEni5aKIGSo55xRZrUXX/ynGVna9ukCbuyj7v20ujtdeA7agtFasiIiIiEj+VEgVERGRcs1gMFClkjN+nm60qpV7JOvUnw/w6Qb7jIyzWOBSRjaXMrKBDJuv5+xgzF2UdbmqKPs/xdjKLg45I2n/HlXr7uzAi9/t0VqxIiIiIiL5UCFVREREKqweTQNuqpD6+eDW1PatzOW/C6BX/vvP301cvqo95++mnL9n/tOenmUu8veSaTKTmWomKfX6G3LZqiKtFSsiIiIi8r9USBUREZEKKyzYmyZBnhyITbnuSEzIWSu0cZAndzSsViQjMbNN5pwCa+b/Fl3/KcZeq1BrLcpeXazNNGG6ieUJisLKfXEqpIqIiIhIhaNCqoiIiFRYBoOBaf3CrGuDXqseaTSAu7Mj0/qFFdl0dkcHI17uRrzcnWy+lsViISPbbC2wXkz/u9Caee0Rsle3Xfz7vycTUv9eYuDGjAZITiu+Ua8iIiIiIqWVCqkiIiJSoTUM8GDxM+2vsVt9zvFbAj2Z3j+MEP/SucGSwWDA1ckBVycHfCu7FOoaU1cc4NP1BVviwGwBLzfbC8AiIiIiImWNCqkiIiJS4TUM8GD5iA7sPp3Myn1xJKdl4eXmRI+mATSv4VXuN1bq0SSgwIVUyFlbVkRERESkolEhVURERISckZ1hwd4Vcu3Pm10rtnkNr5ILTkRERESklDDaOwARERERsa8ra8W6OztivM7g2+JYK1ZEREREpCzRiFQRERERKfNrxUrp4RLSoFjOFREREbE3FVJFREREBMi9VuyKvTGcTbyIfxUPeoYGVYi1YqVoBE6aZO8QRERERIqFCqkiIiIiYnVlrdjQIA/i4+Px8/PDwcHB3mGJiIiIiNid1kgVERERERERERERuQEVUkVERERERERERERuQIVUERERERERERERkRtQIVVERERERERERETkBlRIFREREREREREREbkBFVJFREREREREREREbsDR3gEIZGdnAxAbG2vnSEo3k8lEQkICGRkZODg42DucMkf5s43yV3jKnW2UP9sof4VX0rkLCAjA0VGPpiVJz6AFo58jtlH+bKP8FZ5yZxvlzzbKX+HZI3c38xyqp9VSID4+HoC2bdvaORIRERER+4iOjqZGjRr2DqNC0TOoiIiIyM09hxosFoulmOORG0hPT2fv3r34+flpJEY+YmNjadu2LVu2bCEwMNDe4ZQ5yp9tlL/CU+5so/zZRvkrvJLOnUakljw9gxaMfo7YRvmzjfJXeMqdbZQ/2yh/hWeP3GlEahnj6upKmzZt7B1GmREYGKgRKzZQ/myj/BWecmcb5c82yl/hKXfll55Bb46+F2yj/NlG+Ss85c42yp9tlL/CK62502ZTIiIiIiIiIiIiIjegQqqIiIiIiIiIiIjIDaiQKmWGp6cnEydOxNPT096hlEnKn22Uv8JT7myj/NlG+Ss85U4kh74XbKP82Ub5KzzlzjbKn22Uv8Ir7bnTZlMiIiIiIiIiIiIiN6ARqSIiIiIiIiIiIiI3oEKqiIiIiIiIiIiIyA2okCoiIiIiIiIiIiJyAyqkioiIiIiIiIiIiNyACqkiIiIiIiIiIiIiN6BCqoiIiIiIiIiIiMgNqJAqJSorK4sJEyZQs2ZNXF1dadasGfPnzy9Q37feeovevXsTFBSEwWDg6aefvuZ5c+fOxfD/7d17UFTn+Qfw7wILy0oWjIJglEsNsUatjQkVJIpJtDFGnKiRKqmuSauQ2RKTtrZjNdUIXtJ2kraZaaNYSJBcpE20bawX0sFMTBtrYuI1ViGgaUQBL3hBVsI+vz/87cZ1d9llz+4BDt/PzM5kz7573nOe8xC/8y6c1encPqqrqwN5Oqryt3ZHjx7FkiVLMHr0aERHRyMqKgr33nsvtmzZ4nb8pUuXUFBQgPj4eERGRiI9PR07d+4M8NmoT436abX3AP/rV1dXh9zcXKSmpiIqKgoxMTH4zne+g7KyMoiIy3gt9p8atWPvebd7925HTU6fPu3yOnvPs45qp+XeI21hBvUfM6gyzKDKMIP6jxlUGWZQZbScQ8OCtmciNxYuXIiysjJYLBaMHDkSW7ZswWOPPYavvvoK8+bN6/C9S5cuRVxcHNLS0rB161avc61YsQJDhgxx2hYfH6/o+LuSv7XbsGEDiouLMWPGDCxYsABWqxUbN27E9OnTUVxcjB/+8IeOsSKCRx55BP/617/w4x//GImJiXj11VcxZcoU7Ny5E/fff78apxoUatTPTmu9B/hfv/r6epw5cwazZ8/G4MGDce3aNVRWVsJsNuPIkSNYu3atY6xW+0+N2tmx99yz2WwoKChAnz59cOXKFZfX2XueeaudnRZ7j7SFGdR/zKDKMIMqwwzqP2ZQZZhBldF0DhUilezbt08AyMqVKx3bbDabTJgwQeLi4sRqtXb4/s8//9zx3wAkLy/P7bjS0lIBIP/+978Dc+DdgJLa7d27Vy5evOi0zWq1ysiRI6V///7S3t7u2P72228LACkrK3Nsa21tlTvuuENGjRoVuBNSmVr102LviSj/2XVn6tSpEhkZKdeuXXNs02L/qVU79l7H/vCHP0i/fv1k0aJFAkDq6+udXmfveeatdlrtPdIWZlD/MYMqwwyqDDOo/5hBlWEGVUbrOZR/2k+qqaioQEhICCwWi2ObTqfDj370IzQ0NKCqqqrD96ekpHR6zkuXLqG9vb3T7+tulNTunnvuwS233OK0LTw8HNnZ2WhqakJDQ4PTPLfeeityc3Md2yIiIrBw4ULs378f//3vfwN4VupRq3430krvAcp/dt1JSkrC1atXYbVanebRWv+pVbsbsfecnTt3Ds8++yxWrlyJmJgYj/Ow91z5Ursbaan3SFuYQf3HDKoMM6gyzKD+YwZVhhlUGa3nUC6kkmo+/vhjDBkyBLfeeqvT9jFjxgAA9u3bF9D5Jk2aBJPJBKPRiClTpuDIkSMB3b+aglG7U6dOISwsDNHR0U7z3H333QgNDQ3YPN2BWvWz01LvAYGpX0tLC5qamlBbW4sNGzagpKQEGRkZiIqKcppHa/2nVu3s2Huuli1bhoSEBOTl5XU4D3vPlS+1s9Na75G2MIP6jxlUGWZQZZhB/ccMqgwzqDJaz6G8Ryqp5tSpU0hISHDZPnDgQMfrgWA0GmE2m3HfffchOjoan3zyCV544QWMHTsWH330EW6//faAzKOmQNeupqYGb775JqZNm4bIyEineTIyMgI2T3ehVv202HtAYOq3evVqrFq1yvF84sSJKCkpcZlHa/2nVu3Ye+7t378f69evx7Zt21wC6s3zsPec+Vo7rfYeaQszqP+YQZVhBlWGGdR/zKDKMIMqo/UcyoVUUs3Vq1cRERHhsj0kJAR6vR5Xr14NyDw5OTnIyclxPH/kkUcwdepUpKenY8WKFSgvLw/IPGoKZO1aWlqQk5ODiIgIvPDCCz7NYzAYHK/3RGrVT4u9BwSmfo8//jgmTJiAhoYGbNmyBQ0NDWhpafFpnp7cf2rVjr3nXkFBAR5++GFMmjTJr3l6c+/5Wjut9h5pCzOo/5hBlWEGVYYZ1H/MoMowgyqj9RzKhVRSjcFgcHs/FZvNhra2Nsf/LIIhLS0N48aNw7vvvhu0OYIpULVra2vDrFmzcPjwYWzduhVJSUk+zdPa2up4vSdSq37u9PTeAwJTvyFDhji+STE3Nxfz58/HpEmTcOzYMcf7tdh/atXOnd7ee2+88QY+/PBDHD582O95emvvdaZ27mih90hbmEH9xwyqDDOoMsyg/mMGVYYZVBmt51DeI5VUM3DgQNTX17tst/9at/3XvIMlMTER586dC+ocwRKI2tlsNsybNw87duzAa6+9hgceeCAo83RHatXPk57ce0Bw+iInJwdffPEF3nvvvaDO09XUqp0nvbn3Fi9ejFmzZkGn06G6uhrV1dWOWtTV1eHkyZMBmae7Uqt2nvT03iNt6eqf8Z7888AMqgwzqDLMoP5jBlWGGVQZredQLqSSakaPHo2amhqXht6zZ4/j9WD6/PPPERcXF9Q5giUQtcvPz8ebb76Jl19+GTNnzvQ4z759+1y+7U6taxQsatXPk57ce0Bwfnbtf87R3NzsNI/W+k+t2nnSm3vvyy+/xOuvv47U1FTH46WXXgIAZGRk4Lvf/a7TPOy9r3Wmdp709N4jbWEG9R8zqDLMoMowg/qPGVQZZlBlNJ9DhUglH330kQCQwsJCxzabzSYTJkyQ2NhYaW1tFRGRxsZG+eyzz+TKlSse9wVA8vLy3L529uxZl23vvvuuAJAnnnhC4Vl0DaW1+8lPfiIA5Pnnn+9wnr/85S8CQDZu3OjY1traKnfccYeMHDkygGekLrXqp8XeE1FWvzNnzrjsz2azyZQpU0Sn08nx48cd27XYf2rVjr3nWr8///nPLo9Zs2YJACkuLpYdO3Y4xrL3/K+dVnuPtIUZ1H/MoMowgyrDDOo/ZlBlmEGV0XoO5UIqqWru3LkSEhIiTz31lBQXF8vDDz8sAKSkpMQxZvny5QJAqqqqnN5bVlYmhYWFUlhYKAAkLS3N8byurs4xbtiwYZKbmytr166Vl19+WfLz80Wv18uAAQPk5MmTap1qwPlbu9/97ncCQO666y7ZuHGjy+Py5cuOsTabTbKyssRgMMjSpUtl3bp1MnbsWAkNDZXKyko1Tzfg1KifVntPxP/6zZ8/XzIzM+WXv/ylFBcXy5o1a2T06NECQAoKCpzm0Gr/qVE79p77fzduZh9XX1/vtJ2953/ttNx7pC3MoP5jBlWGGVQZZlD/MYMqwwyqjJZzKBdSSVVWq1WWLVsmgwYNkvDwcBkxYoTTpy8inn+YsrKyBIDbx41jly5dKnfddZfExMSIXq+XwYMHy4IFC+R///ufCmcYPP7Wzmw2e6wbAKmtrXXaR3Nzs1gsFomLixODwSBpaWmybds2Fc4wuNSon1Z7T8T/+v31r3+Vhx56SBISEkSv14vJZJJ7771XSktLxWazucyjxf5To3bsPWUhTIS952/ttNx7pC3MoP5jBlWGGVQZZlD/MYMqwwyqjJZzqE5EBERERERERERERETkEb9sioiIiIiIiIiIiMgLLqQSERERERERERERecGFVCIiIiIiIiIiIiIvuJBKRERERERERERE5AUXUomIiIiIiIiIiIi84EIqERERERERERERkRdcSCUiIiIiIiIiIiLyggupRERERERERERERF5wIZWIiIiIiIiIiIjICy6kEhEREREREREREXnBhVQiIiIiIiIiIiIiL7iQSkTd3iuvvAKdToe6urquPpSgufkcu8s5r1ixAjqdDqdPn+7S47ALxvF0ptbd5boQERFR8PWGf/eZQX3DDEpEdlxIJSLqIXbv3o0VK1bgwoULPWK/BBQVFaGysrKrD4OIiIjIb8ygPQ8zKFHwcCGViLq9uXPn4urVq0hKSurqQ1GNu3PevXs3nnvuuaCE2GDst7erra3Fs88+ixMnTnT1oRAREZEfmEGvYwbtWZhBiYIrrKsPgIjIm9DQUISGhnb1YahKC+fc0tICo9HY1YfRZfbu3QsAuPvuu7v4SIiIiMgfWshjnaWFc2YGZQYlCib+RioRdXs33xPIfo+i48ePIz8/H/369UNUVBRmzZqFs2fPurz/zJkzyMvLw2233YaIiAikpqbiV7/6FUTEadyHH36IsWPHwmAwIDExEWvXrkVpaanL/Yjmz5+P5ORkr8d54sQJWCwWDBs2DEajETExMcjOzsbhw4f9OuclS5YAAFJSUqDT6aDT6bBr1y5UVlZCp9Nh8+bNLvvZunUrdDod3nnnHbfzdLTfG126dKnDWtuvyZEjR2A2m9GvXz8MHz7c8bov1+Dy5cv46U9/ipSUFBgMBgwYMAAPPPAA3n//fZfj9nY8docOHcK0adMQExMDo9GIsWPHYseOHR6q7sxdP9zcM55kZGTge9/7HgBg9OjR0Ol0CAsLQ2trq0/vJyIioq7HDMoMejNmUCLib6QSUY81Z84cxMfHo7CwEMePH8dLL70EvV6P119/3TGmqakJ6enpsFqtWLhwIRISEvD+++/j5z//OU6dOoXf/va3AIAjR45g4sSJMJlMWLZsGcLDw7F+/XpERUX5fXx79+7Frl27MH36dCQnJ6O+vh7r1q3D+PHjcfjwYcTHx/u8rxkzZuDo0aPYtGkTXnzxRfTv3x8AMGzYMMTGxmLQoEHYuHEjpk+f7vS+8vJyxMbGYvLkyZ3e7418qTUA5OTkICUlBUVFRbBarQB8vwZPPvkkKioqYLFYMHz4cJw/fx579uzBp59+inHjxnX6eI4dO4bMzExERETgmWeeQVRUFEpLSzFlyhRs3rwZ06ZN81hvpf2wePFiFBUV4fz58ygsLAQAREZGwmAw+PR+IiIi6r6YQZlBmUGJejEhIurmSktLBYDU1taKiMjy5csFgMyZM8dp3KJFiyQ0NFSam5sd2/Ly8iQuLk5Onz7tNHbx4sUSEhIidXV1IiIyY8YM0ev1Ul1d7RjT0NAg0dHRTnOLiJjNZklKSvJ6nFeuXHEZU1NTIxEREVJUVNThe29+LiKyZs0al212S5YskfDwcDl37pxj28WLF8VoNEpBQYHL+Bt1tF9fa20f9+ijj7rsw9drEBMTIxaLpcNj7cy1nzlzpoSFhclnn33m2Nbc3CyJiYmSnJws7e3tIuK+1p3pB09SU1Nl9uzZXscRERFR98QMeh0zKDMoEX2Nf9pPRD2WxWJxep6VlYX29nbHjdVFBBUVFZg6dSpCQ0PR1NTkeDz44IOw2WzYtWsX2tvbsX37dmRnZ2PIkCGO/cXGxuKxxx7z+/huvDdTS0sLzp49C5PJhKFDh+Ljjz/2e7/umM1mXLt2DRUVFY5tb731FlpaWjB37lzF+/dWa7snn3zS6bmv1wAATCYT9u7di/r6esXHc+M1/eY3v+kYZzKZkJ+fj7q6Ohw6dMjtvgPRDy0tLaipqcG3vvUtn8YTERFRz8EM+jVmUGZQot6GC6lE1GPdfI+ovn37AgDOnTsHAGhsbMT58+dRUlKC2NhYp8fEiRMBAA0NDWhsbERLSwuGDh3qMoe7bb5qbW3Fz372MwwcOBB9+vRB//79ERsbiwMHDgT820mHDh2KMWPGoLy83LGtvLwcQ4cORVpamuL9e6u1XUpKitNzX68BADz//PM4cOAABg0ahDFjxmD58uU4evSoX8fT2NiIK1euOAVYuzvvvBPA9W80dScQ/XDw4EHYbDaGWCIiIg1iBv0aMygzKFFvw3ukElGP5ekbReX/b8Zus9kAXL+X0RNPPOF2bGpqqmO8TqfzuK8buRsHXP8U+UZPPfUU/vSnP6GgoACZmZmIjo5GSEgInn76acexBZLZbIbFYkFdXR3Cw8NRVVWFlStXBmTf3mptFxkZ6fTc12sAALNnz0ZWVhb+9re/YefOnXjxxRexevVqlJaW4vvf/75fx9PRGE/XsbP94M7+/fsBAKNGjfJpPBEREfUczKDOmEGZQYl6Ey6kEpFmxcbGwmQy4auvvnJ88uxOe3s7jEaj20+ejx075rItJibG7af5N36rKgBUVFRg3rx5jhvZ250/f95xQ/3O8BS67GbPno1nnnkG5eXliIiIgIi4hD9/9quEr9fALiEhAXl5ecjLy8OFCxeQnp6O5557zqfzuHnePn36uL2m9m3uvvUWAOLi4jrVD+4cOHAAffv2xaBBg3w/aCIiItIEZlBmUGZQIu3in/YTkWaFhobi0UcfxebNm7Fv3z6X15ubm9HW1obQ0FA8+OCD+Pvf/46amhrH642NjS7fCAoAt99+O5qbm/HJJ584tl2+fBmvvvqq07iwsDCXT4/feOMNnDp1yq/z6dOnD4DrIdidvn37Ijs7G+Xl5SgvL8f48eORlJSkeL9K+HoN2tvb0dzc7PRaTEwMUlJS/Dqu0NBQTJ48Ge+8845T8Lx06RLWrVuH5ORkjBgxwuN7O9MP7pw4cYIBloiIqJdiBmUGZQYl0i7+RioRadratWvx3nvvITMzEz/4wQ8wcuRIXLx4EYcOHcJbb72F6upqxMfHY+XKldixYwfGjRsHi8UCvV6P9evXIykpyeWT/9zcXCxZsgTTp0/HokWL0NbWhpKSEgwYMABffPGFY1x2djbKyspgMpkwYsQIfPrpp9i0aRO+8Y1v+HUu99xzDwDgF7/4BebMmYPw8HDcf//9iIuLc4wxm83Izs4GAGzYsCFg+1XCl2tgMBhw2223YebMmRg1ahRMJhM++OADbN++3eWm/r5atWoVKisrHdc0KioKpaWlOHnyJN5++22EhHj+LLEz/eBOSkoKtm/fjtWrVyMxMRGpqakYM2aMX+dBREREPQ8zaGD2qwQzKDMoUVAIEVE3V1paKgCktrZWRESWL18uAKS+vt5pXFVVlQCQqqoqp+1NTU3y9NNPS3Jysuj1eomNjZXMzEz59a9/LVar1THugw8+kPT0dImIiJDBgwfLmjVrpKSkxGluu3/+858yatQo0ev1kpycLL///e9djrO5uVkWLFggcXFxYjQaZfz48fKf//xHsrKyJCsrq8NzvPm5XVFRkQwePFhCQkLcnmtbW5sMGDBADAaDXLhwwdcSe9yvr7X2NM7O2zWwWq2yePFi+fa3vy0mk0mMRqMMHz5cfvOb30hbW5tjP5299gcPHpSpU6eKyWSSyMhIycjIkG3btjmN8VTrzvTDzerr6+Whhx6SW265RQBIUVFRh+OJiIio+2EG/RozKDMoEV2nE/HxrsVERL3QK6+8gscffxy1tbUe72fUndhsNiQmJiIzMxObNm3q6sMhIiIiIj8wgxIRdU+8RyoRkYb84x//wJdffgmz2dzVh0JEREREvQQzKBH1FrxHKhGRBuzZswcHDx7EqlWrcOedd2Ly5MldfUhEREREpHHMoETU2/A3UomINOCPf/wj8vPz0bdvX7z22msd3sSeiIiIiCgQmEGJqLfhPVKJiIiIiIiIiIiIvODHRURERERERERERERecCGViIiIiIiIiIiIyAsupBIRERERERERERF5wYVUIiIiIiIiIiIiIi+4kEpERERERERERETkBRdSiYiIiIiIiIiIiLzgQioRERERERERERGRF1xIJSIiIiIiIiIiIvKCC6lEREREREREREREXnAhlYiIiIiIiIiIiMgLLqQSERERERERERERefF/W0fzTa1pJosAAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, axes = plt.subplots(1, 2, figsize=(12.5, 4.2))\n", + "\n", + "axes[0].plot(ts_env, costs_env, \"o-\", linewidth=2, markersize=7, color=\"C0\")\n", + "axes[0].set_xlabel(r\"inequality threshold $t$\")\n", + "axes[0].set_ylabel(r\"optimal cost $\\langle P^\\star, C\\rangle$\")\n", + "axes[0].set_title(\"Optimal cost as a function of $t$\")\n", + "axes[0].grid(alpha=0.3)\n", + "\n", + "axes[1].plot(\n", + " ts_env,\n", + " dcost_dt_num,\n", + " \"o-\",\n", + " label=r\"numerical: $\\Delta\\,\\mathrm{cost} / \\Delta t$\",\n", + " markersize=7,\n", + " color=\"C0\",\n", + ")\n", + "axes[1].plot(\n", + " ts_env,\n", + " dcost_dt_thy,\n", + " \"s--\",\n", + " label=r\"analytical: $-\\alpha\\,/\\,n$\",\n", + " markersize=7,\n", + " color=\"C3\",\n", + ")\n", + "axes[1].set_xlabel(r\"inequality threshold $t$\")\n", + "axes[1].set_ylabel(r\"$d\\,\\mathrm{cost}^\\star\\,/\\,dt$\")\n", + "axes[1].set_title(\"Sensitivity equals the constraint dual (envelope theorem)\")\n", + "axes[1].legend()\n", + "axes[1].grid(alpha=0.3)\n", + "\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "aacee411", + "metadata": {}, + "source": [ + "The two curves overlap to numerical precision: the constraint dual $\\alpha$ that the algorithm computes anyway is *literally* the gradient of the optimal cost with respect to the threshold. There are at least three concrete uses for this:\n", + "\n", + "- **Pricing constraints.** If you have a budget on which you can spend resources to relax the constraint (e.g., raise the cap $t$ by $\\Delta t$), the dual tells you that the cost will decrease by $\\alpha \\cdot \\Delta t / n$ : useful for cost-benefit analysis.\n", + "- **Calibrating thresholds.** If you want to find the threshold $t$ that achieves a target trade-off, the dual gives you a Newton step on $t$ for free.\n", + "- **Differentiating through the solver.** Combined with `jax.grad` (which works through the Sinkhorn updates we built on `apply_lse_kernel`), this lets the constrained solver slot into a larger learning pipeline as a differentiable layer." + ] + }, + { + "cell_type": "markdown", + "id": "5ad49267", + "metadata": {}, + "source": [ + "## 10. Takeaways and pointers\n", + "\n", + "We've reproduced the main experiments of the constrained-Sinkhorn paper *as a thin wrapper around OTT-JAX*. The key idea is that the row and column updates of constrained Sinkhorn are vanilla Sinkhorn updates on a **modified cost matrix** $C^{\\mathrm{eff}} = C - \\sum_m \\alpha_m D_m$; we delegate them to `Geometry.apply_lse_kernel` and only own the Newton step on the constraint duals $\\alpha$. The full implementation fits in well under a hundred lines and inherits OTT-JAX's per-iteration cost of $O(nm)$ as long as the number of additional constraints is $O(1)$.\n", + "\n", + "Several extensions discussed in the paper are natural next steps:\n", + "\n", + "- **Regularisation scheduling** (§3.1, doubling $\\gamma$ progressively) for tighter solutions at low entropy. Plumbs naturally into OTT-JAX's `epsilon_scheduler.Epsilon` class : one would set `epsilon` to an `Epsilon` instance instead of a float.\n", + "- **Sparse Newton acceleration** (§3.1, \"Sinkhorn-Newton-Sparse\"): exploits the fundamental theorem of LP, the optimal $P^\\star$ has only $O(n)$ non-zero entries : to apply a *full* Newton update on $(f, g, \\alpha)$ at $O(n^2)$ cost rather than the $O((n+K+L)^3)$ of a naive full Newton.\n", + "- **Partial OT** (§E of the paper) requires inequality marginals and a different Lagrangian; the same wrapping strategy applies, this time around `ott.solvers.linear`'s unbalanced-Sinkhorn building blocks.\n", + "\n", + "This tutorial sits naturally next to the **Multimarginal OT** and **Unbalanced OT** tutorials in OTT-JAX's *Linear OT* tutorial collection: all three extend the marginal-update pattern with extra dual variables specific to the structure of their respective problems, all three rely on `Geometry.apply_lse_kernel` for the Sinkhorn core, and all three are best understood as *additions* to the standard solver rather than replacements of it. We hope this makes the constrained extension feel like a natural neighbour of what is already there." + ] + }, + { + "cell_type": "markdown", + "id": "d31087cb", + "metadata": {}, + "source": [ + "## References\n", + "\n", + "The papers and tools cited in this tutorial:\n", + "\n", + "**Source paper.** Xun Tang, Holakou Rahmanian, Michael Shavlovsky, Kiran Koshy Thekumparampil, Tesi Xiao, and Lexing Ying. *A Sinkhorn-type Algorithm for Constrained Optimal Transport.* International Conference on Learning Representations (ICLR), 2025. [[arXiv:2403.05054]](https://arxiv.org/abs/2403.05054) [[OpenReview]](https://openreview.net/forum?id=V5kCKFav9j)\n", + "\n", + "**OTT-JAX.** Marco Cuturi, Laetitia Meng-Papaxanthos, Yingtao Tian, Charlotte Bunne, Geoff Davis, and Olivier Teboul. *Optimal Transport Tools (OTT): A JAX Toolbox for all things Wasserstein.* arXiv preprint arXiv:2201.12324, 2022. [[arXiv]](https://arxiv.org/abs/2201.12324) [[code]](https://github.com/ott-jax/ott)\n", + "\n", + "**Sinkhorn distances and entropic OT.** Marco Cuturi. *Sinkhorn distances: Lightspeed computation of optimal transport.* Advances in Neural Information Processing Systems (NeurIPS) 26, 2013. [[paper]](https://papers.nips.cc/paper/2013/hash/af21d0c97db2e27e13572cbf59eb343d-Abstract.html)\n", + "\n", + "**The original matrix-scaling algorithm.** Richard Sinkhorn. *A relationship between arbitrary positive matrices and doubly stochastic matrices.* The Annals of Mathematical Statistics, 35(2):876–879, 1964.\n", + "\n", + "**Exponential convergence of entropic LP (Theorem 1 in §5).** Jonathan Weed. *An explicit analysis of the entropic penalty in linear programming.* In Conference on Learning Theory (COLT), pages 1841–1855, 2018. [[paper]](https://proceedings.mlr.press/v75/weed18a.html)\n", + "\n", + "**Background on computational OT.** Gabriel Peyré and Marco Cuturi. *Computational Optimal Transport.* Foundations and Trends in Machine Learning, 11(5–6):355–607, 2019. [[arXiv:1803.00567]](https://arxiv.org/abs/1803.00567)\n", + "\n", + "**Convergence analysis used by the source paper (Theorem 2).** Jason Altschuler, Jonathan Niles-Weed, and Philippe Rigollet. *Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration.* Advances in Neural Information Processing Systems (NeurIPS) 30, 2017. [[paper]](https://papers.nips.cc/paper/2017/hash/491442df5f88c6aa018e86dac21d3606-Abstract.html)\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}