diff --git a/docs/tutorials/sinkhorn/001_constrained_sinkhorn.ipynb b/docs/tutorials/sinkhorn/001_constrained_sinkhorn.ipynb new file mode 100644 index 000000000..d85ca6469 --- /dev/null +++ b/docs/tutorials/sinkhorn/001_constrained_sinkhorn.ipynb @@ -0,0 +1,1610 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# A Sinkhorn-type Algorithm for Constrained Optimal Transport\n", + "\n", + "This tutorial illustrates how to solve optimal transport problems with additional linear equality and inequality constraints. It implements the constrained Sinkhorn-type algorithm introduced in a recent paper , leveraging OTT-JAX's Geometry class to seamlessly inject constraints into the dual problem.\n", + "\n", + "We implement educational examples of:\n", + "1. **Algorithm 1** : Constrained Sinkhorn (alternating row/column scaling + Newton constraint dual update)\n", + "2. **Algorithm 2** : Sinkhorn-Newton-Sparse (SNS): warm-start Sinkhorn + sparse Newton acceleration\n", + "3. **Algorithm 4** : APDAGD baseline (adaptive primal-dual accelerated gradient descent)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "JAX version: 0.9.0.1\n", + "OTT-JAX version: 0.6.0\n", + "Devices: [CpuDevice(id=0)]\n" + ] + } + ], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "from functools import partial\n", + "\n", + "import ott\n", + "from ott.geometry import geometry as ott_geometry\n", + "from ott.problems.linear import linear_problem as ott_linear_problem\n", + "from ott.solvers.linear import sinkhorn as ott_sinkhorn\n", + "\n", + "jax.config.update(\"jax_enable_x64\", True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Background: Constrained Entropic Optimal Transport\n", + "\n", + "### Problem (Eq. 4--5 of the paper)\n", + "\n", + "The constrained OT problem is:\n", + "\n", + "$$\\min_{P:\\,P\\mathbf{1}=r,\\,P^\\top\\mathbf{1}=c,\\,P\\ge0} \\langle C, P\\rangle \\quad\\text{s.t.}\\quad P \\in \\mathcal{S}$$\n", + "\n", + "where $\\mathcal{S} = \\mathcal{I} \\cap \\mathcal{E}$ encodes $K$ inequality constraints $\\langle D_k, P\\rangle \\ge 0$\n", + "and $L$ equality constraints $\\langle D_{l+K}, P\\rangle = 0$.\n", + "\n", + "Under entropic regularisation (Eq. 5), introducing slack variables $s_k$ for inequalities:\n", + "\n", + "$$\\min_{P,s}\\; \\langle C, P\\rangle + \\frac{1}{\\eta}\\,H(P, s_1,\\ldots,s_K)$$\n", + "\n", + "subject to $P\\mathbf{1}=r$, $P^\\top\\mathbf{1}=c$, $P\\ge0$, $D_k \\cdot P = s_k \\ge 0$, $D_{l+K} \\cdot P = 0$.\n", + "\n", + "### Dual / Lyapunov Function (Eq. 7)\n", + "\n", + "$$f(x,y,a) = -\\frac{1}{\\eta}\\sum_{ij}\\exp\\!\\Big(\\eta\\big(-C_{ij}+\\textstyle\\sum_m a_m(D_m)_{ij}+x_i+y_j\\big)-1\\Big)\n", + "+ \\sum_i x_i r_i + \\sum_j y_j c_j - \\frac{1}{\\eta}\\sum_{k=1}^K \\exp(-\\eta a_k - 1)$$\n", + "\n", + "with intermediate transport plan $P_{ij} = \\exp\\!\\big(\\eta(-C_{ij}+\\sum_m a_m(D_m)_{ij}+x_i+y_j)-1\\big)$.\n", + "\n", + "### First-order conditions (Eq. 8)\n", + "\n", + "$$\\nabla_x f = r - P\\mathbf{1},\\quad \\nabla_y f = c - P^\\top\\mathbf{1},\\quad\n", + "\\partial_{a_k}f = \\exp(-\\eta a_k-1) - \\langle D_k, P\\rangle,\\quad\n", + "\\partial_{a_{l+K}}f = -\\langle D_{l+K}, P\\rangle$$" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Core primitives loaded with OTT-JAX geometry.\n" + ] + } + ], + "source": [ + "# =========================================================\n", + "# Core primitives (shared by all algorithms)\n", + "# =========================================================\n", + "\n", + "def build_constraint_modifier(a, Ds, shape):\n", + " \"\"\"Aggregate the linear constraint contribution sum_m a_m D_m.\"\"\"\n", + " modifier = jnp.zeros(shape)\n", + " for m in range(len(Ds)):\n", + " modifier = modifier + a[m] * Ds[m]\n", + " return modifier\n", + "\n", + "\n", + "def make_ott_geometry(x, y, a, C, Ds, eta):\n", + " \"\"\"Build the OTT-JAX geometry for the current dual iterate.\n", + "\n", + " The paper defines\n", + " P_ij = exp(eta * (-C_ij + sum_m a_m D_m,ij + x_i + y_j) - 1).\n", + " OTT-JAX represents the Gibbs kernel as exp(-cost / epsilon), so we encode\n", + " the dual shifts and constraints in an effective cost matrix with\n", + " epsilon = 1 / eta, then multiply by exp(-1).\n", + " \"\"\"\n", + " constraint_mod = build_constraint_modifier(a, Ds, C.shape)\n", + " effective_cost = C - constraint_mod - x[:, None] - y[None, :]\n", + " return ott_geometry.Geometry(cost_matrix=effective_cost, epsilon=1.0 / eta)\n", + "\n", + "\n", + "def compute_P(x, y, a, C, Ds, eta):\n", + " \"\"\"Intermediate transport plan from dual variables using OTT-JAX geometry.\"\"\"\n", + " geom = make_ott_geometry(x, y, a, C, Ds, eta)\n", + " return jnp.exp(-1.0) * geom.kernel_matrix\n", + "\n", + "\n", + "def solve_ott_entropic_plan(cost_matrix, r, c, eta, max_iterations=5000, threshold=1e-8):\n", + " \"\"\"Reference entropic OT solve with OTT-JAX Sinkhorn.\"\"\"\n", + " geom = ott_geometry.Geometry(cost_matrix=cost_matrix, epsilon=1.0 / eta)\n", + " problem = ott_linear_problem.LinearProblem(geom, a=r, b=c)\n", + " solver = ott_sinkhorn.Sinkhorn(max_iterations=max_iterations, threshold=threshold)\n", + " return solver(problem).matrix\n", + "\n", + "\n", + "def dual_objective(x, y, a, C, Ds, r, c, eta, K):\n", + " \"\"\"Evaluate the Lyapunov / dual function f(x, y, a) -- Eq. 7.\"\"\"\n", + " P = compute_P(x, y, a, C, Ds, eta)\n", + " val = -(1.0 / eta) * jnp.sum(P)\n", + " val += jnp.dot(x, r) + jnp.dot(y, c)\n", + " if K > 0:\n", + " val -= (1.0 / eta) * jnp.sum(jnp.exp(-eta * a[:K] - 1.0))\n", + " return val\n", + "\n", + "\n", + "def compute_gradient(x, y, a, C, Ds, r, c, eta, K, L):\n", + " \"\"\"Full gradient of f w.r.t. (x, y, a) -- Eq. 8.\"\"\"\n", + " P = compute_P(x, y, a, C, Ds, eta)\n", + " M = K + L\n", + " grad_x = r - jnp.sum(P, axis=1)\n", + " grad_y = c - jnp.sum(P, axis=0)\n", + " grad_a = jnp.zeros(M)\n", + " for k in range(K):\n", + " grad_a = grad_a.at[k].set(jnp.exp(-eta * a[k] - 1.0) - jnp.sum(P * Ds[k]))\n", + " for l in range(L):\n", + " grad_a = grad_a.at[K + l].set(-jnp.sum(P * Ds[K + l]))\n", + " return grad_x, grad_y, grad_a, P\n", + "\n", + "\n", + "def compute_violation(P_rounded, Ds, K, L):\n", + " \"\"\"Constraint violation metric (Section 5).\n", + "\n", + " Violation(P) = sum_k |min(D_k . Round(P), 0)| + sum_l |D_{l+K} . Round(P)|\n", + " \"\"\"\n", + " viol = 0.0\n", + " for k in range(K):\n", + " viol += jnp.abs(jnp.minimum(jnp.sum(P_rounded * Ds[k]), 0.0))\n", + " for l in range(L):\n", + " viol += jnp.abs(jnp.sum(P_rounded * Ds[K + l]))\n", + " return viol\n", + "\n", + "\n", + "def round_transport(P, r, c):\n", + " \"\"\"Rounding algorithm from Altschuler et al. (2017).\n", + "\n", + " Projects an approximate transport plan onto U_{r,c} while preserving\n", + " non-negativity. Returns Round(P, U_{r,c}).\n", + " \"\"\"\n", + " P = jnp.maximum(P, 0.0)\n", + " row_sums = jnp.sum(P, axis=1)\n", + " row_sums = jnp.maximum(row_sums, 1e-300)\n", + " P = P * (r / row_sums)[:, None]\n", + "\n", + " col_sums = jnp.sum(P, axis=0)\n", + " col_sums = jnp.maximum(col_sums, 1e-300)\n", + " P = P * (c / col_sums)[None, :]\n", + "\n", + " err_r = r - jnp.sum(P, axis=1)\n", + " err_c = c - jnp.sum(P, axis=0)\n", + " P = P + jnp.outer(err_r, err_c) / jnp.sum(jnp.abs(err_r)).clip(1e-300)\n", + " P = jnp.maximum(P, 0.0)\n", + " return P\n", + "\n", + "\n", + "# -- Constraint conversion helpers (Section 2) --\n", + "def convert_leq_constraint(D, t, n):\n", + " \"\"\"Convert <= t --> D'.P >= 0 where D' = (t*1 - D)/n.\"\"\"\n", + " return (t * jnp.ones((n, n)) - D) / n\n", + "\n", + "\n", + "def convert_geq_constraint(D, t, n):\n", + " \"\"\"Convert >= t --> D'.P >= 0 where D' = (D - t*1)/n.\"\"\"\n", + " return (D - t * jnp.ones((n, n))) / n\n", + "\n", + "\n", + "def convert_eq_constraint(D, t, n):\n", + " \"\"\"Convert = t --> D'.P = 0 where D' = (D - t*1)/n.\"\"\"\n", + " return (D - t * jnp.ones((n, n))) / n\n", + "\n", + "\n", + "print(\"Core primitives loaded with OTT-JAX geometry.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Algorithm 1: Constrained Sinkhorn (Section 3)\n", + "\n", + "The three alternating update steps per iteration:\n", + "\n", + "1. **Row scaling**: $x \\leftarrow x + (\\log r - \\log(P\\mathbf{1}))/\\eta$\n", + "2. **Column scaling**: $y \\leftarrow y + (\\log c - \\log(P^\\top\\mathbf{1}))/\\eta$\n", + "3. **Constraint dual update** (Newton): $(a', t') = \\arg\\max_{\\tilde a, \\tilde t} f(x+\\tilde t\\mathbf{1}, y, \\tilde a)$, then $x \\leftarrow x + t'\\mathbf{1}$\n", + "\n", + "The Newton system for step 3 is $(K{+}L{+}1)\\times(K{+}L{+}1)$, so each iteration is $O(n^2)$." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Algorithm 1 (Constrained Sinkhorn) loaded.\n" + ] + } + ], + "source": [ + "def newton_constraint_update(x, y, a, C, Ds, r, c, eta, K, L, n_newton=5):\n", + " \"\"\"Newton update for constraint duals (a) and normalisation shift (t).\n", + "\n", + " Solves: (a', t') = argmax_{a~, t~} f(x + t~*1, y, a~)\n", + " using Newton's method with backtracking line search (Armijo condition).\n", + " Corresponds to Step 10-11 of Algorithm 1.\n", + " \"\"\"\n", + " M = K + L\n", + " n = C.shape[0]\n", + " t = 0.0\n", + "\n", + " for _step in range(n_newton):\n", + " x_shifted = x + t * jnp.ones(n)\n", + " P = compute_P(x_shifted, y, a, C, Ds, eta)\n", + "\n", + " # Inner products \n", + " PD = jnp.array([jnp.sum(P * Ds[m]) for m in range(M)])\n", + "\n", + " # -- Gradient of f w.r.t. (a, t) --\n", + " grad = jnp.zeros(M + 1)\n", + " for k in range(K):\n", + " grad = grad.at[k].set(jnp.exp(-eta * a[k] - 1.0) - PD[k])\n", + " for l_idx in range(L):\n", + " grad = grad.at[K + l_idx].set(-PD[K + l_idx])\n", + " grad = grad.at[M].set(jnp.sum(r) - jnp.sum(P)) # d_t f (Eq. 9)\n", + "\n", + " # -- Hessian (K+L+1) x (K+L+1) from Appendix C --\n", + " H = jnp.zeros((M + 1, M + 1))\n", + " for m in range(M):\n", + " for m2 in range(m, M):\n", + " val = -eta * jnp.sum(P * Ds[m] * Ds[m2])\n", + " H = H.at[m, m2].set(val)\n", + " if m2 > m:\n", + " H = H.at[m2, m].set(val)\n", + " if m < K:\n", + " H = H.at[m, m].add(-eta * jnp.exp(-eta * a[m] - 1.0))\n", + " cross = -eta * PD[m]\n", + " H = H.at[m, M].set(cross)\n", + " H = H.at[M, m].set(cross)\n", + " H = H.at[M, M].set(-eta * jnp.sum(P))\n", + "\n", + " # -- Newton direction --\n", + " delta = -jnp.linalg.solve(H, grad)\n", + "\n", + " # -- Backtracking line search (Armijo) --\n", + " alpha = 1.0\n", + " f_cur = dual_objective(x_shifted, y, a, C, Ds, r, c, eta, K)\n", + " descent = jnp.dot(grad, delta)\n", + " for _ls in range(20):\n", + " a_try = a + alpha * delta[:M]\n", + " t_try = t + alpha * delta[M]\n", + " f_try = dual_objective(\n", + " x + t_try * jnp.ones(n), y, a_try, C, Ds, r, c, eta, K\n", + " )\n", + " if f_try >= f_cur + 1e-4 * alpha * descent:\n", + " break\n", + " alpha *= 0.5\n", + "\n", + " a = a + alpha * delta[:M]\n", + " t = t + alpha * delta[M]\n", + "\n", + " return a, t\n", + "\n", + "\n", + "def constrained_sinkhorn(C, Ds, r, c, eta, K, L,\n", + " n_iters=100, n_newton=3,\n", + " x_init=None, y_init=None, a_init=None,\n", + " verbose=True):\n", + " \"\"\"Algorithm 1: Sinkhorn-type algorithm under linear constraint.\n", + "\n", + " Parameters\n", + " ----------\n", + " C : (n, n) cost matrix\n", + " Ds : list of (n, n) constraint matrices [D_1, ..., D_{K+L}]\n", + " r, c : (n,) source / target marginals\n", + " eta : regularisation parameter (1/epsilon)\n", + " K, L : number of inequality / equality constraints\n", + " n_iters : outer Sinkhorn iterations (N in Algorithm 1)\n", + " n_newton : Newton steps per constraint update\n", + " x_init, y_init, a_init : optional warm-start dual variables\n", + "\n", + " Returns\n", + " -------\n", + " x, y, a : final dual variables\n", + " P : final transport plan\n", + " history : dict of convergence metrics\n", + " \"\"\"\n", + " n = C.shape[0]\n", + " x = jnp.zeros(n) if x_init is None else x_init.copy()\n", + " y = jnp.zeros(n) if y_init is None else y_init.copy()\n", + " a = jnp.zeros(K + L) if a_init is None else a_init.copy()\n", + "\n", + " history = dict(cost=[], violation=[], dual_obj=[], marginal_err=[])\n", + "\n", + " for i in range(n_iters):\n", + " # Step 5-6: Row scaling\n", + " P = compute_P(x, y, a, C, Ds, eta)\n", + " x = x + (jnp.log(r) - jnp.log(jnp.sum(P, axis=1).clip(1e-300))) / eta\n", + "\n", + " # Step 7-8: Column scaling\n", + " P = compute_P(x, y, a, C, Ds, eta)\n", + " y = y + (jnp.log(c) - jnp.log(jnp.sum(P, axis=0).clip(1e-300))) / eta\n", + "\n", + " # Step 10-11: Constraint dual update (Newton) + shift\n", + " a, t = newton_constraint_update(x, y, a, C, Ds, r, c, eta, K, L, n_newton=n_newton)\n", + " x = x + t * jnp.ones(n)\n", + "\n", + " # -- Metrics --\n", + " P = compute_P(x, y, a, C, Ds, eta)\n", + " P_rounded = round_transport(P, r, c)\n", + " history[\"cost\"].append(float(jnp.sum(C * P_rounded)))\n", + " history[\"violation\"].append(float(compute_violation(P_rounded, Ds, K, L)))\n", + " history[\"dual_obj\"].append(float(dual_objective(x, y, a, C, Ds, r, c, eta, K)))\n", + " history[\"marginal_err\"].append(float(\n", + " jnp.sum(jnp.abs(jnp.sum(P, axis=1) - r))\n", + " + jnp.sum(jnp.abs(jnp.sum(P, axis=0) - c))\n", + " ))\n", + "\n", + " if verbose and (i + 1) % 20 == 0:\n", + " print(f\"Iter {i+1:4d} | cost: {history['cost'][-1]:.6f} | \"\n", + " f\"violation: {history['violation'][-1]:.2e} | \"\n", + " f\"marginal err: {history['marginal_err'][-1]:.2e}\")\n", + "\n", + " return x, y, a, P, history\n", + "\n", + "\n", + "print(\"Algorithm 1 (Constrained Sinkhorn) loaded.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Algorithm 2: Sinkhorn-Newton-Sparse (SNS) -- Appendix C\n", + "\n", + "Two stages:\n", + "\n", + "1. **Sinkhorn stage**: Run Algorithm 1 for $N_1$ iterations as warm start.\n", + "2. **Newton stage**: Use a sparse approximation of the full Hessian (Eq. 16--17) for $N_2$ Newton iterations.\n", + "\n", + "The modified Lyapunov function adds a regulariser for the degenerate direction\n", + "$v = (1_n, -1_n, 0_{K+L})$:\n", + "\n", + "$$\\tilde f(x,y,a) = f(x,y,a) - \\tfrac{1}{2}\\big(\\textstyle\\sum_i x_i - \\sum_j y_j\\big)^2$$\n", + "\n", + "The Hessian submatrix for $(x,y)$ (Eq. 10) is sparsified by keeping only entries of $P$ above threshold $\\rho$." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Algorithm 2 (SNS) loaded.\n" + ] + } + ], + "source": [ + "def sparsify(M, rho):\n", + " \"\"\"Entry-wise truncation: keep entries >= rho, set rest to 0 (Appendix C).\"\"\"\n", + " return jnp.where(jnp.abs(M) >= rho, M, 0.0)\n", + "\n", + "\n", + "def sns_constrained(C, Ds, r, c, eta, K, L,\n", + " N1=20, N2=30, rho=None, n_newton_inner=3,\n", + " x_init=None, y_init=None, a_init=None,\n", + " verbose=True):\n", + " \"\"\"Algorithm 2: Sinkhorn-Newton-Sparse for OT under linear constraint.\n", + "\n", + " Parameters\n", + " ----------\n", + " N1 : Sinkhorn warm-start iterations (Algorithm 1)\n", + " N2 : Sparse Newton iterations\n", + " rho : sparsification threshold (auto-chosen if None)\n", + " \"\"\"\n", + " n = C.shape[0]\n", + " M = K + L\n", + "\n", + " # -- Stage 1: Sinkhorn warm-start --\n", + " if verbose:\n", + " print(\"=== SNS Stage 1: Sinkhorn warm-start ===\")\n", + " x, y, a, P, hist_sink = constrained_sinkhorn(\n", + " C, Ds, r, c, eta, K, L,\n", + " n_iters=N1, n_newton=n_newton_inner,\n", + " x_init=x_init, y_init=y_init, a_init=a_init, verbose=verbose\n", + " )\n", + "\n", + " # -- Stage 2: Sparse Newton --\n", + " if verbose:\n", + " print(f\"\\n=== SNS Stage 2: Sparse Newton ({N2} iters) ===\")\n", + "\n", + " # Degenerate direction v = (1_n, -1_n, 0_{K+L})\n", + " v = jnp.concatenate([jnp.ones(n), -jnp.ones(n), jnp.zeros(M)])\n", + "\n", + " # Project z into non-degenerate subspace (Step 16)\n", + " z = jnp.concatenate([x, y, a])\n", + " z = z - (jnp.dot(z, v) / jnp.dot(v, v)) * v\n", + "\n", + " history = dict(cost=list(hist_sink[\"cost\"]),\n", + " violation=list(hist_sink[\"violation\"]),\n", + " dual_obj=list(hist_sink[\"dual_obj\"]),\n", + " marginal_err=list(hist_sink[\"marginal_err\"]))\n", + "\n", + " for it in range(N2):\n", + " x_cur, y_cur, a_cur = z[:n], z[n:2*n], z[2*n:]\n", + " P = compute_P(x_cur, y_cur, a_cur, C, Ds, eta)\n", + "\n", + " # Auto-choose rho to keep ~3n entries\n", + " if rho is None:\n", + " P_flat = jnp.sort(P.ravel())[::-1]\n", + " keep = min(3 * n, n * n)\n", + " rho_auto = float(P_flat[keep - 1]) if keep < n * n else 0.0\n", + " else:\n", + " rho_auto = rho\n", + "\n", + " P_sparse = sparsify(P, rho_auto)\n", + "\n", + " # -- Build sparse Hessian (Eq. 17) --\n", + " dim = 2 * n + M\n", + " P1 = jnp.sum(P, axis=1)\n", + " Pt1 = jnp.sum(P, axis=0)\n", + "\n", + " # Gradient of f~\n", + " grad_x = r - P1\n", + " grad_y = c - Pt1\n", + " grad_a = jnp.zeros(M)\n", + " for k in range(K):\n", + " grad_a = grad_a.at[k].set(jnp.exp(-eta * a_cur[k] - 1.0) - jnp.sum(P * Ds[k]))\n", + " for l_idx in range(L):\n", + " grad_a = grad_a.at[K + l_idx].set(-jnp.sum(P * Ds[K + l_idx]))\n", + " grad_f = jnp.concatenate([grad_x, grad_y, grad_a])\n", + " # Regulariser gradient\n", + " reg_val = jnp.sum(x_cur) - jnp.sum(y_cur)\n", + " grad_reg = jnp.concatenate([-reg_val * jnp.ones(n), reg_val * jnp.ones(n), jnp.zeros(M)])\n", + " grad_ftilde = grad_f + grad_reg\n", + "\n", + " # Cross terms\n", + " xa_block = jnp.zeros((n, M))\n", + " ya_block = jnp.zeros((n, M))\n", + " for m in range(M):\n", + " xa_block = xa_block.at[:, m].set(-eta * jnp.sum(P * Ds[m], axis=1))\n", + " ya_block = ya_block.at[:, m].set(-eta * jnp.sum(P * Ds[m], axis=0))\n", + "\n", + " aa_block = jnp.zeros((M, M))\n", + " for m in range(M):\n", + " for m2 in range(m, M):\n", + " val = -eta * jnp.sum(P * Ds[m] * Ds[m2])\n", + " aa_block = aa_block.at[m, m2].set(val)\n", + " if m2 > m:\n", + " aa_block = aa_block.at[m2, m].set(val)\n", + " if m < K:\n", + " aa_block = aa_block.at[m, m].add(-eta * jnp.exp(-eta * a_cur[m] - 1.0))\n", + "\n", + " # Assemble full Hessian (Eq. 17)\n", + " H = jnp.zeros((dim, dim))\n", + " H = H.at[:n, :n].set(-eta * jnp.diag(P1))\n", + " H = H.at[:n, n:2*n].set(-eta * P_sparse)\n", + " H = H.at[n:2*n, :n].set(-eta * P_sparse.T)\n", + " H = H.at[n:2*n, n:2*n].set(-eta * jnp.diag(Pt1))\n", + " H = H.at[:n, 2*n:].set(xa_block)\n", + " H = H.at[n:2*n, 2*n:].set(ya_block)\n", + " H = H.at[2*n:, :n].set(xa_block.T)\n", + " H = H.at[2*n:, n:2*n].set(ya_block.T)\n", + " H = H.at[2*n:, 2*n:].set(aa_block)\n", + "\n", + " # Add regulariser: H_ftilde = H_f - vv^T (Step 20)\n", + " H = H - jnp.outer(v, v)\n", + "\n", + " # Solve sparse Newton system (Step 21)\n", + " delta_z = -jnp.linalg.solve(H, grad_ftilde)\n", + "\n", + " # Line search on f~ (Step 22)\n", + " def f_tilde(z_try):\n", + " xt, yt, at = z_try[:n], z_try[n:2*n], z_try[2*n:]\n", + " fval = dual_objective(xt, yt, at, C, Ds, r, c, eta, K)\n", + " reg = -0.5 * (jnp.sum(xt) - jnp.sum(yt)) ** 2\n", + " return fval + reg\n", + "\n", + " alpha = 1.0\n", + " f_cur = f_tilde(z)\n", + " descent = jnp.dot(grad_ftilde, delta_z)\n", + " for _ls in range(30):\n", + " z_try = z + alpha * delta_z\n", + " f_try = f_tilde(z_try)\n", + " if f_try >= f_cur + 1e-4 * alpha * descent:\n", + " break\n", + " alpha *= 0.5\n", + "\n", + " z = z + alpha * delta_z\n", + "\n", + " # -- Metrics --\n", + " x_cur, y_cur, a_cur = z[:n], z[n:2*n], z[2*n:]\n", + " P = compute_P(x_cur, y_cur, a_cur, C, Ds, eta)\n", + " P_rounded = round_transport(P, r, c)\n", + " history[\"cost\"].append(float(jnp.sum(C * P_rounded)))\n", + " history[\"violation\"].append(float(compute_violation(P_rounded, Ds, K, L)))\n", + " history[\"dual_obj\"].append(float(dual_objective(x_cur, y_cur, a_cur, C, Ds, r, c, eta, K)))\n", + " history[\"marginal_err\"].append(float(\n", + " jnp.sum(jnp.abs(jnp.sum(P, axis=1) - r))\n", + " + jnp.sum(jnp.abs(jnp.sum(P, axis=0) - c))\n", + " ))\n", + "\n", + " if verbose and (it + 1) % 5 == 0:\n", + " print(f\"Newton {it+1:3d} | cost: {history['cost'][-1]:.6f} | \"\n", + " f\"viol: {history['violation'][-1]:.2e} | \"\n", + " f\"marg: {history['marginal_err'][-1]:.2e}\")\n", + "\n", + " x_final, y_final, a_final = z[:n], z[n:2*n], z[2*n:]\n", + " P_final = compute_P(x_final, y_final, a_final, C, Ds, eta)\n", + " return x_final, y_final, a_final, P_final, history\n", + "\n", + "\n", + "print(\"Algorithm 2 (SNS) loaded.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Algorithm 3: SNS with Entropy Regularisation Scheduling (Appendix C)\n", + "\n", + "Doubling schedule: start at $\\eta_\\mathrm{init}=1$ and double $\\eta$ until reaching $\\eta_\\mathrm{target}$.\n", + "At each level, run Algorithm 2 with warm-started dual variables from the previous level.\n", + "Final Newton stage at $\\eta_\\mathrm{target}$ to reach convergence." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Algorithm 3 (SNS + scheduling) loaded.\n" + ] + } + ], + "source": [ + "def sns_with_scheduling(C, Ds, r, c, eta_target, K, L,\n", + " N1_per_level=5, N2_per_level=5,\n", + " N2_final=30, rho=None, verbose=True):\n", + " \"\"\"Algorithm 3: SNS with entropy regularisation scheduling.\n", + "\n", + " Parameters\n", + " ----------\n", + " eta_target : target regularisation parameter\n", + " N1_per_level : Sinkhorn iters per scheduling level\n", + " N2_per_level : Newton iters per scheduling level\n", + " N2_final : final Newton iterations at eta_target\n", + " \"\"\"\n", + " n = C.shape[0]\n", + " M = K + L\n", + " N_eta = int(np.ceil(np.log2(eta_target))) # number of doubling steps\n", + "\n", + " x, y, a = jnp.zeros(n), jnp.zeros(n), jnp.zeros(M)\n", + " eta = 1.0\n", + "\n", + " if verbose:\n", + " print(f\"Scheduling: {N_eta} levels, eta_target = {eta_target}\")\n", + "\n", + " for level in range(N_eta):\n", + " if verbose:\n", + " print(f\"\\n--- Level {level+1}/{N_eta}, eta = {eta:.1f} ---\")\n", + " x, y, a, P, _ = sns_constrained(\n", + " C, Ds, r, c, eta, K, L,\n", + " N1=N1_per_level, N2=N2_per_level, rho=rho,\n", + " x_init=x, y_init=y, a_init=a, verbose=False\n", + " )\n", + " eta = min(2.0 * eta, eta_target)\n", + "\n", + " # Final Newton stage at eta_target\n", + " if verbose:\n", + " print(f\"\\n--- Final Newton stage at eta = {eta_target} ---\")\n", + " x, y, a, P_final, history = sns_constrained(\n", + " C, Ds, r, c, eta_target, K, L,\n", + " N1=0, N2=N2_final, rho=rho,\n", + " x_init=x, y_init=y, a_init=a, verbose=verbose\n", + " )\n", + "\n", + " return x, y, a, P_final, history\n", + "\n", + "\n", + "print(\"Algorithm 3 (SNS + scheduling) loaded.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Algorithm 4: APDAGD Baseline (Appendix D)\n", + "\n", + "Adaptive primal-dual accelerated gradient descent from Dvurechensky et al. (2018).\n", + "Applied to the dual objective $f(x,y,a)$ from Eq. 7." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Algorithm 4 (APDAGD) loaded.\n" + ] + } + ], + "source": [ + "def apdagd_constrained(C, Ds, r, c, eta, K, L, N=200, verbose=True):\n", + " \"\"\"Algorithm 4: APDAGD for constrained OT.\n", + "\n", + " Uses the Lyapunov function f(x,y,a) from Eq. 7.\n", + " \"\"\"\n", + " n = C.shape[0]\n", + " M = K + L\n", + " dim = 2 * n + M\n", + "\n", + " def f_eval(z):\n", + " return dual_objective(z[:n], z[n:2*n], z[2*n:], C, Ds, r, c, eta, K)\n", + "\n", + " def grad_f(z):\n", + " gx, gy, ga, _ = compute_gradient(z[:n], z[n:2*n], z[2*n:], C, Ds, r, c, eta, K, L)\n", + " return jnp.concatenate([gx, gy, ga])\n", + "\n", + " z = jnp.zeros(dim)\n", + " zeta = jnp.zeros(dim)\n", + " lam = jnp.zeros(dim)\n", + "\n", + " alpha = 0.0\n", + " beta = 0.0\n", + " L_k = 1.0\n", + "\n", + " history = dict(cost=[], violation=[], dual_obj=[], marginal_err=[])\n", + "\n", + " for k in range(N):\n", + " M_k = L_k / 2.0\n", + " # Inner loop: adaptive line search\n", + " for _inner in range(50):\n", + " M_k = 2.0 * M_k\n", + " alpha_new = (1.0 + jnp.sqrt(1.0 + 4.0 * M_k * beta)) / (2.0 * M_k)\n", + " beta_new = beta + alpha_new\n", + " tau = alpha_new / beta_new\n", + "\n", + " lam_new = tau * zeta + (1.0 - tau) * z\n", + " g = grad_f(lam_new)\n", + " zeta_new = zeta + alpha_new * g\n", + " z_new = tau * zeta_new + (1.0 - tau) * z\n", + "\n", + " # Check sufficient decrease condition (Step 12)\n", + " f_z = f_eval(z_new)\n", + " f_lam = f_eval(lam_new)\n", + " diff = z_new - lam_new\n", + " lhs = f_z\n", + " rhs = f_lam + jnp.dot(g, diff) - (M_k / 2.0) * jnp.dot(diff, diff)\n", + " if lhs >= rhs:\n", + " break\n", + "\n", + " L_k = M_k / 2.0\n", + " z = z_new\n", + " zeta = zeta_new\n", + " alpha = alpha_new\n", + " beta = beta_new\n", + "\n", + " # -- Metrics --\n", + " x_k, y_k, a_k = z[:n], z[n:2*n], z[2*n:]\n", + " P_k = compute_P(x_k, y_k, a_k, C, Ds, eta)\n", + " P_rounded = round_transport(P_k, r, c)\n", + " history[\"cost\"].append(float(jnp.sum(C * P_rounded)))\n", + " history[\"violation\"].append(float(compute_violation(P_rounded, Ds, K, L)))\n", + " history[\"dual_obj\"].append(float(f_eval(z)))\n", + " history[\"marginal_err\"].append(float(\n", + " jnp.sum(jnp.abs(jnp.sum(P_k, axis=1) - r))\n", + " + jnp.sum(jnp.abs(jnp.sum(P_k, axis=0) - c))\n", + " ))\n", + "\n", + " if verbose and (k + 1) % 50 == 0:\n", + " print(f\"APDAGD {k+1:4d} | cost: {history['cost'][-1]:.6f} | \"\n", + " f\"viol: {history['violation'][-1]:.2e}\")\n", + "\n", + " x_final, y_final, a_final = z[:n], z[n:2*n], z[2*n:]\n", + " P_final = compute_P(x_final, y_final, a_final, C, Ds, eta)\n", + " return x_final, y_final, a_final, P_final, history\n", + "\n", + "\n", + "print(\"Algorithm 4 (APDAGD) loaded.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Experiment 1: Random Assignment under Constraints (Section 5, Figures 2-4)\n", + "\n", + "Following the paper exactly:\n", + "- Problem size $n = 500$, entropy regularisation $\\eta = 1200$\n", + "- Uniform marginals $r = c = \\frac{1}{n}\\mathbf{1}$\n", + "- Cost and constraint matrices $C, D_I, D_E$ with i.i.d. $\\mathrm{Unif}([0,1])$ entries\n", + "- Thresholds $t_I = t_E = 0.5$\n", + "- Inequality constraint: $\\langle D_I, P\\rangle \\le t_I$; Equality: $\\langle D_E, P\\rangle = t_E$\n", + "- Conversion: $D_1 = (D_I - t_I\\,\\mathbf{1}_{n\\times n})/n$, $D_2 = (D_E - t_E\\,\\mathbf{1}_{n\\times n})/n$ (Eq. 13)\n", + "\n", + "We compare Algorithm 1, Algorithm 2 (with $N_1=20$ Sinkhorn steps), and APDAGD." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Problem size: n = 500, eta = 1200\n", + "K = 1 inequality, L = 1 equality constraints\n", + "Thresholds: tI = 0.5, tE = 0.5\n" + ] + } + ], + "source": [ + "# -- Problem setup (Section 5, random assignment) --\n", + "n = 500\n", + "eta = 1200\n", + "\n", + "key = jax.random.PRNGKey(42)\n", + "keys = jax.random.split(key, 3)\n", + "\n", + "# Uniform marginals\n", + "r = jnp.ones(n) / n\n", + "c = jnp.ones(n) / n\n", + "\n", + "# Random cost and constraint matrices ~ Unif([0, 1])\n", + "C = jax.random.uniform(keys[0], (n, n))\n", + "DI_raw = jax.random.uniform(keys[1], (n, n)) # inequality matrix\n", + "DE_raw = jax.random.uniform(keys[2], (n, n)) # equality matrix\n", + "\n", + "# Thresholds\n", + "tI, tE = 0.5, 0.5\n", + "\n", + "# Convert to general form (Eq. 13):\n", + "# D_I . P <= t_I --> D1.P >= 0 with D1 = (t_I * 1 - D_I) / n\n", + "# D_E . P = t_E --> D2.P = 0 with D2 = (D_E - t_E * 1) / n\n", + "D1 = convert_leq_constraint(DI_raw, tI, n)\n", + "D2 = convert_eq_constraint(DE_raw, tE, n)\n", + "\n", + "Ds = [D1, D2]\n", + "K, L = 1, 1 # 1 inequality + 1 equality\n", + "\n", + "print(f\"Problem size: n = {n}, eta = {eta}\")\n", + "print(f\"K = {K} inequality, L = {L} equality constraints\")\n", + "print(f\"Thresholds: tI = {tI}, tE = {tE}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "============================================================\n", + "Running Algorithm 1 (Constrained Sinkhorn)...\n", + "============================================================\n", + "Iter 20 | cost: 0.003437 | violation: 1.68e-06 | marginal err: 2.35e-02\n", + "Iter 40 | cost: 0.003478 | violation: 2.88e-07 | marginal err: 4.68e-03\n", + "Iter 60 | cost: 0.003484 | violation: 6.00e-08 | marginal err: 1.41e-03\n", + "Iter 80 | cost: 0.003486 | violation: 1.86e-08 | marginal err: 6.01e-04\n", + "Iter 100 | cost: 0.003486 | violation: 8.19e-09 | marginal err: 3.26e-04\n", + "Iter 120 | cost: 0.003486 | violation: 4.44e-09 | marginal err: 2.02e-04\n", + "Iter 140 | cost: 0.003487 | violation: 2.70e-09 | marginal err: 1.35e-04\n", + "Iter 160 | cost: 0.003487 | violation: 1.76e-09 | marginal err: 9.57e-05\n", + "Iter 180 | cost: 0.003487 | violation: 1.21e-09 | marginal err: 7.06e-05\n", + "Iter 200 | cost: 0.003487 | violation: 8.74e-10 | marginal err: 5.39e-05\n", + "Iter 220 | cost: 0.003487 | violation: 6.57e-10 | marginal err: 4.22e-05\n", + "Iter 240 | cost: 0.003487 | violation: 5.10e-10 | marginal err: 3.37e-05\n", + "Iter 260 | cost: 0.003487 | violation: 4.06e-10 | marginal err: 2.73e-05\n", + "Iter 280 | cost: 0.003487 | violation: 3.28e-10 | marginal err: 2.25e-05\n", + "Iter 300 | cost: 0.003487 | violation: 2.69e-10 | marginal err: 1.87e-05\n", + "Iter 320 | cost: 0.003487 | violation: 2.21e-10 | marginal err: 1.57e-05\n", + "Iter 340 | cost: 0.003487 | violation: 1.83e-10 | marginal err: 1.33e-05\n", + "Iter 360 | cost: 0.003487 | violation: 1.51e-10 | marginal err: 1.15e-05\n", + "Iter 380 | cost: 0.003487 | violation: 1.25e-10 | marginal err: 1.02e-05\n", + "Iter 400 | cost: 0.003487 | violation: 1.03e-10 | marginal err: 9.25e-06\n", + "\n", + "============================================================\n", + "Running Algorithm 2 (Sinkhorn-Newton-Sparse)...\n", + "============================================================\n", + "=== SNS Stage 1: Sinkhorn warm-start ===\n", + "Iter 20 | cost: 0.003437 | violation: 1.68e-06 | marginal err: 2.35e-02\n", + "\n", + "=== SNS Stage 2: Sparse Newton (80 iters) ===\n", + "Newton 5 | cost: 0.003487 | viol: 4.29e-10 | marg: 1.10e-05\n", + "Newton 10 | cost: 0.003487 | viol: 3.81e-12 | marg: 3.59e-08\n", + "Newton 15 | cost: 0.003487 | viol: 3.80e-14 | marg: 2.92e-10\n", + "Newton 20 | cost: 0.003487 | viol: 2.30e-14 | marg: 1.73e-10\n", + "Newton 25 | cost: 0.003487 | viol: 7.36e-15 | marg: 6.70e-11\n", + "Newton 30 | cost: 0.003487 | viol: 7.36e-15 | marg: 6.70e-11\n", + "Newton 35 | cost: 0.003487 | viol: 7.36e-15 | marg: 6.70e-11\n", + "Newton 40 | cost: 0.003487 | viol: 7.36e-15 | marg: 6.70e-11\n", + "Newton 45 | cost: 0.003487 | viol: 7.36e-15 | marg: 6.70e-11\n", + "Newton 50 | cost: 0.003487 | viol: 5.33e-15 | marg: 4.35e-11\n", + "Newton 55 | cost: 0.003487 | viol: 3.46e-15 | marg: 2.75e-11\n", + "Newton 60 | cost: 0.003487 | viol: 3.46e-15 | marg: 2.75e-11\n", + "Newton 65 | cost: 0.003487 | viol: 3.46e-15 | marg: 2.75e-11\n", + "Newton 70 | cost: 0.003487 | viol: 3.46e-15 | marg: 2.75e-11\n", + "Newton 75 | cost: 0.003487 | viol: 3.46e-15 | marg: 2.75e-11\n", + "Newton 80 | cost: 0.003487 | viol: 3.46e-15 | marg: 2.75e-11\n", + "\n", + "============================================================\n", + "Running Algorithm 4 (APDAGD)...\n", + "============================================================\n", + "APDAGD 50 | cost: 0.003482 | viol: 1.73e-05\n", + "APDAGD 100 | cost: 0.003485 | viol: 1.69e-05\n", + "APDAGD 150 | cost: 0.003485 | viol: 1.63e-05\n", + "APDAGD 200 | cost: 0.003485 | viol: 1.54e-05\n", + "APDAGD 250 | cost: 0.003485 | viol: 1.43e-05\n", + "APDAGD 300 | cost: 0.003485 | viol: 1.30e-05\n", + "APDAGD 350 | cost: 0.003485 | viol: 1.16e-05\n", + "APDAGD 400 | cost: 0.003485 | viol: 1.01e-05\n", + "\n", + "Done.\n" + ] + } + ], + "source": [ + "# -- Run Algorithm 1: Constrained Sinkhorn (400 iters as in Figure 3) --\n", + "print(\"=\" * 60)\n", + "print(\"Running Algorithm 1 (Constrained Sinkhorn)...\")\n", + "print(\"=\" * 60)\n", + "x_sink, y_sink, a_sink, P_sink, hist_sink = constrained_sinkhorn(\n", + " C, Ds, r, c, eta, K=K, L=L,\n", + " n_iters=400, n_newton=3, verbose=True\n", + ")\n", + "\n", + "# -- Run Algorithm 2: SNS (N1=20 Sinkhorn + Newton) --\n", + "print(\"\\n\" + \"=\" * 60)\n", + "print(\"Running Algorithm 2 (Sinkhorn-Newton-Sparse)...\")\n", + "print(\"=\" * 60)\n", + "x_sns, y_sns, a_sns, P_sns, hist_sns = sns_constrained(\n", + " C, Ds, r, c, eta, K=K, L=L,\n", + " N1=20, N2=80, rho=None, verbose=True\n", + ")\n", + "\n", + "# -- Run Algorithm 4: APDAGD --\n", + "print(\"\\n\" + \"=\" * 60)\n", + "print(\"Running Algorithm 4 (APDAGD)...\")\n", + "print(\"=\" * 60)\n", + "x_apd, y_apd, a_apd, P_apd, hist_apd = apdagd_constrained(\n", + " C, Ds, r, c, eta, K=K, L=L, N=400, verbose=True\n", + ")\n", + "\n", + "print(\"\\nDone.\")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Computing reference P_eta* via SNS (high accuracy)...\n", + "Final TV (Sinkhorn): 2.41e-05\n", + "Reference cost: 0.003487\n" + ] + } + ], + "source": [ + "# -- Compute reference P_eta* using SNS (Algorithm 2) at high accuracy --\n", + "print(\"Computing reference P_eta* via SNS (high accuracy)...\")\n", + "x_ref, y_ref, a_ref, P_ref, _ = sns_constrained(\n", + " C, Ds, r, c, eta, K=K, L=L,\n", + " N1=50, N2=100, rho=None, verbose=False\n", + ")\n", + "\n", + "def tv_distance(P, P_ref):\n", + " \"\"\"Total variation distance between two transport plans.\"\"\"\n", + " return 0.5 * float(jnp.sum(jnp.abs(P - P_ref)))\n", + "\n", + "# Compute TV distances by re-running Sinkhorn step by step\n", + "tv_sink = []\n", + "x_t, y_t, a_t = jnp.zeros(n), jnp.zeros(n), jnp.zeros(K + L)\n", + "for i in range(400):\n", + " P_t = compute_P(x_t, y_t, a_t, C, Ds, eta)\n", + " x_t = x_t + (jnp.log(r) - jnp.log(jnp.sum(P_t, axis=1).clip(1e-300))) / eta\n", + " P_t = compute_P(x_t, y_t, a_t, C, Ds, eta)\n", + " y_t = y_t + (jnp.log(c) - jnp.log(jnp.sum(P_t, axis=0).clip(1e-300))) / eta\n", + " a_t, t_t = newton_constraint_update(x_t, y_t, a_t, C, Ds, r, c, eta, K, L, n_newton=3)\n", + " x_t = x_t + t_t * jnp.ones(n)\n", + " P_t = compute_P(x_t, y_t, a_t, C, Ds, eta)\n", + " tv_sink.append(tv_distance(P_t, P_ref))\n", + "\n", + "print(f\"Final TV (Sinkhorn): {tv_sink[-1]:.2e}\")\n", + "print(f\"Reference cost: {float(jnp.sum(C * P_ref)):.6f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAABW0AAAIDCAYAAAB/zjwXAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAAPYQAAD2EBqD+naQABAABJREFUeJzs3Xd4FNX+x/H3lvQKJKRACB2kC9K7UlQs2Ch6pXnVe61cxPpTERtXrwXFguUK1osigoqoKIKiFEEBAQUBKVJCT0LqZnfP748hG5YkmITABvi8nmef3Z05M/OdM7Obk++eOWMzxhhEREREREREREREpEqwBzoAERERERERERERESmipK2IiIiIiIiIiIhIFaKkrYiIiIiIiIiIiEgVoqStiIiIiIiIiIiISBWipK2IiIiIiIiIiIhIFaKkrYiIiIiIiIiIiEgVoqStiIiIiIiIiIiISBWipK2IiIiIiIiIiIhIFaKkrYiIiIiIiIiIiEgVoqStiIiIiIiIiIiISBWipK2IiIiIiIiIiIhIFaKkrYiInBK2bNmCzWZjxIgRgQ5FTiId97JTXZ3eFixYgM1m46GHHipT+VPxfDDG0K5dO/r16xfoUOQMtn79epxOJy+99FKgQxERkTOckrYiIhIwhUmFYz3S09MDHeZJsWPHDiZOnEi/fv2oU6cOwcHBJCYmcsUVV7B06dJK2UZJ9R0UFEStWrUYNGgQy5cvr5TtiFRV5U18ysn11ltv8fPPP/Pwww8HOhTq1q1b6t+lXr16lbjMsmXLuPDCC4mNjSUiIoJOnTrxwQcfHHM7FVmmKnvnnXe48cYbOeeccwgJCcFmszF16tRSy1f0b19566085Zs0acLQoUMZP348hw4dKvO+i4iIVDZnoAMQERFp0KABf/vb30qcFxoaCkCtWrX47bffiImJOZmhnTSTJk3iiSeeoEGDBvTr14/4+Hg2bNjArFmzmDVrFu+99x6DBw+ulG0dWd/Z2dn89NNPTJ8+nVmzZvH111/To0ePStmOnFyn+2dETm9er5eHHnqI7t2706lTp0CHA0BMTAyjR48uNr1u3brFps2fP5/+/fsTGhrKkCFDiIqKYsaMGQwePJg///yTO+64o1KWqeruv/9+tm7dSlxcHElJSWzduvWY5Svyt6+89VaRer7rrrt45513eP755/m///u/468YERGRijAiIiIBsnnzZgOY/v37BzqUgJsxY4ZZsGBBsenfffedCQoKMtWqVTN5eXnHtY1j1feECRMMYHr06HFc26hshTEPHz480KHIaWD+/PkGMOPGjQt0KOVW3thPtc/O7NmzDWBee+21QIdijDEmNTXVpKamlqlsQUGBadCggQkJCTErVqzwTU9PTzeNGzc2wcHBZsuWLce9zKngq6++8sVd+HdlypQppZYv79++8tbb8dRzq1atTGpqqvF4PGXcexERkcql4RFEROSUUNr4jG63mwkTJtCgQQNCQ0Np2LAhEyZM4I8//iix/NSpU0u9XLOkS6ePnLZo0SL69etHbGwsNpvNb9nvvvuOiy++mLi4OEJCQmjUqBH3338/OTk5Zdq/yy+/nJ49exab3r17d3r37s3BgwdZvXp1mdZVEddddx0AP/30U7F5LpeLSZMm0b9/f1JSUggJCaFmzZpcfvnlrFixolj5I+ts+fLl9O3bl6ioKGJiYrjsssvYsmVLsWU8Hg9PPPEEDRs29DuOXq+31JinTJlCx44diYyMJDIyko4dO/7lcV20aBG9e/cmKiqK+Ph4brrpJnJzcwH47LPP6Ny5MxERESQkJHDXXXfhdrvLWIPlryeAGTNm0LNnT2rWrEloaCjJycn06dOHGTNmlLtcZX1Gynv8TkT9lvXzVJ5YH3roIXr37g3A+PHj/S53L+mcLG0733//Pb169SIqKorY2FiuuOIKNm7ceMxljvXdUdbz+EhljaE0FanfE/nZKawHm83GFVdcUWoMZf0+Odm++eYbNm3axNVXX02bNm1802NiYrjvvvtwuVy8+eabx71MRSQmJtK4cWPS09O54447SE1NJSQkhMaNG/Puu+8e9/qP1qdPH1JTU8tcvrx/+8pbb8dTz4MGDWLr1q3Mnz+/zPsjIiJSmZS0FRGRU9qoUaO47777ALj55ps5//zzefbZZ0u8pPV4LFq0iF69emGz2bjhhhv8Ltd8+eWX6dWrFz/88AMDBgzgtttuo3bt2jz22GP07dsXl8t1XNsOCgoCwOn0H9WoMJ4FCxYc1/qPdPQ2AA4cOMDo0aPJz8/nwgsv5F//+he9evVizpw5dOnShWXLlpW4rmXLltGjRw+Cg4N9YxzOmjWLPn36kJeX51f2hhtu4J577sHr9XLzzTfTv39/nnnmGW6//fYS133bbbcxatQoduzYwXXXXcd1113Hjh07GDlyZKnLLF26lPPOO4+YmBhuvPFG6tSpw8svv8z111/P+++/z5VXXklqaio33ngjsbGx/Oc//+Hxxx8vc92Vt55efvllrrzySjZs2MBll13GmDFjOP/880lLS2PmzJnlLleain5GynP8oPLqtyKfp7LE2qtXL4YPHw5Az549GTdunO8RGxv7l/UIsGTJEt8+3nrrrfTs2ZOZM2fSpUsX/vjjjxKXOdZ3R0XO44rEcKSK1O+J/uwYY5g/fz5NmjShWrVqJZYp7/lYGfLz85k6dSqPP/44L7zwQqljrBZ+B5d0A7X+/fsD8O233x73MuW1a9cudu/eTXR0NO3atWPZsmVcccUVDBw4kA0bNjBs2DDWrVt3XNs4kUr621feejueeu7cuTMA8+bNK2fkIiIilSTQXX1FROTMVXj5boMGDcy4ceOKPRYvXlys7JGX+n799dcGMG3atDHZ2dm+6Tt37jQJCQklXho8ZcqUUi/XLOny48JpgHnjjTeKLbN27VrjdDpN69atzb59+/zmFV4a+tRTT5WvYo6wdetWExISYpKSkozb7fab17NnTwOY+fPnl2ldxxoe4fHHHzeAGTBgQLF5eXl5Zvv27cWmr1mzxkRGRpo+ffr4TT+yzqZNm+Y379prrzWA+d///lesfOvWrU1WVpZv+vbt201cXFyx4/jtt98awJx11lkmPT3dN/3AgQOmcePGBjDfffddifHMmjXLN93lcplWrVoZm81m4uLizI8//uibl5mZaWrWrGmqV69uXC5XsX0vSXnrqW3btiY4ONjs3r272DJHnktlLVdZn5GKHr/KqN/yfp4qGmt5h0c4cjuTJ0/2mzd58mQDmIsuuqjUZUr67jie87gsMZR0PhxP/Z7Iz87atWsNYK655ppi88p7jAs9++yzJf5dKe1x5KXzxljDIxRu98hH+/btzcaNG/3KXnnllQYwy5cvL3H/IiMjTUpKynEvU15z5swxgLHb7eb999/3m3fTTTcZwLzyyit+04+33o5UluERSlPa377y1tvx1HNGRkaVHDZIRETOHEraiohIwBQmFUp7PPvss8XKHpmAGDFihAHMRx99VGzdhUnIykratm3btsR9uO2224olVwp5PB4THx9v2rVrd8x6KI3L5TI9evQwgHnrrbeKzd+6dav57bff/JJxx1JSknzs2LGmd+/eBjAJCQnm119/LVeMF198sQkODvZLzhTWWUn/6BbOGzNmjG/ayJEjDWBmzJhRrPwjjzxS7DiOGjXKAMWSEMYY8+677xrAjBo1qtg2e/fuXaz8ww8/bAAzcuTIYvMKt/PHH3+UXgFlVFI9tW3b1kRERJgDBw4cc9mylqusz0h5j19l1m95P08VjbWiSdvGjRsXG9/S4/GYRo0aGZvNZvbs2VNsmdK+Oyp6Hpc1hpLOh4rW74n+7Hz55ZfFjtXRMZT1GBcqLela2uPovwkPPfSQmTdvntm9e7fJzs42K1as8CWJU1NTTWZmpq9s3759DWA2bNhQ4v4lJyeb6Ohov2kVWaa8HnvsMQOY0aNHF5tXmOg/+geA4623I1U0aXusv33lrbfjrefQ0FBTv379csUvIiJSWYpfAykiInKS9e/fny+++KLcy61atQqAbt26FZvXtWvX447rSO3bty9x+pIlSwD48ssvS7yEMigoqEKXn3q9XkaMGMF3333H9ddfz7XXXlusTJ06dcq9XoBNmzYxfvx4v2mJiYksXLiQhg0blrjMypUrefLJJ/n+++9JS0ujoKDAb/6+fftISkrym9auXbti66lduzYA6enpvmmFx7F79+7Fypc0rXB82F69ehWbVzhm6cqVK4vNO3I8w0KFMR9r3s6dO6lXr16x+SUpTz0NGTKEu+66ixYtWnD11VfTu3dvunXrRnR0tN8yZS1XkuP5jJT1+BWqjPqt6OepvLFWVNeuXbHb/UcXs9vtdO3alQ0bNrBq1Sr69OnjN7+0746KnscViaFQRev3RH929u/fD3DMYSrKe4yPd6zbcePG+b1v06YNb731FgBvv/02r732GmPGjDmubZxohefP3//+92Lzdu/eDVDs+AR6jOCy/O07mapXr86+ffsCGoOIiJy5lLQVEZFTVmZmJna7nbi4uGLzEhISKnVbpa3vwIEDADz22GOVti2v18uoUaN47733+Nvf/sbkyZMrbd3gnyTfu3cvb775JnfffTeXXHIJP/74I5GRkX7lFy1axLnnngtY4wI2atSIyMhIbDYbs2bNYtWqVeTn5xfbTklJxcKxCT0ej29aRkZGuY5j4XGPj48vsbzNZiMzM7Nc8Rxr3tGJ19KUt57Gjh1LjRo1ePnll3n66ad56qmncDqdDBgwgGeffdaXTClruZIcz2ekrMevLOXLWr8V/TyVN9aKKq3OCqdnZGSUeZmKnscViaHQiajfyvjshIWFARxzbNqTdYz/yo033sjbb7/NDz/84EvaxsTEAKXXfWZmZrGxeiuyTHmtXLmS+Ph4mjdvXmzeL7/8ApScdA+UsvztK2+9HW895+bmEh4eXuZ9EBERqUxK2oqIyCkrOjoar9fLvn37iiU+CnsRHa2wh1pJdzY/VrLj6Du+HxkDWP/4RUVFlSnuY/F6vYwcOZK33nqLoUOHMnXq1GK96ipTfHw8Y8eOJSMjg0cffZT777+fiRMn+pV57LHHyM/PZ+HChcV6bC5ZssTXm7OiYmJiynUcC4/73r17qVmzpt+8PXv2YIwpUy/UylbeerLZbIwaNYpRo0axf/9+Fi5cyP/+9z8++OADNmzYwC+//ILD4ShzuZJU5DMSSJX9eapspdVZ4fTCBNGRjvXdUZHzuCIxHLlNqHr1W3huFiaVK8PEiRPL1ct64MCBZUpgFv4Akp2d7ZvWqFEjADZs2FCsR3BaWhpZWVl06NDBb3pFlimPrKwsNm7cSN++fUucv2LFCmrVqlXs3DtR9fZXyvq3r7z1djz17PV6ycjIKDHpLSIicjIoaSsiIqes1q1bs2LFCn744QcGDhzoN2/RokUlLlPYo2bHjh3F5hVerlweHTt25Oeff2bJkiWl/nNcVkf+0zp48GDefvvtUpNxle2+++7jjTfe4KWXXmL06NHUrVvXN2/Tpk1Ur169WCIyJyeHn3/++bi33bp1a37++WcWLlzI5Zdf7jdv4cKFxcqfffbZrFixggULFjBo0CC/eYV3Cg9E77HjqacaNWowcOBABg4cyL59+/jmm2/YuHEjTZo0qVC5QhX5jARSZX6eSlL4eapoz8wffvgBr9frl0zyer0sWrQIm81G69aty7yuip7HxxPDia7fimrevDl2u53169dX2jonTpzI1q1by1y+bt26ZfreWLp0qa98oZ49ezJhwgTmzp3LkCFD/Mp/+eWXvjJHqsgy5bFq1SqMMbRt27bYvIyMDDZv3syAAQOKzTtR9XYs5fnbV956O5563rBhA16vl5YtW1Z430RERI7Hieu6IyIicoJdc801ADz88MPk5ub6pqelpfHcc8+VuEy7du2w2WxMmzbN71LcDRs2lLrMsdx00004nU5uvfVWtm3bVmx+enp6mZLBhZeFvvXWW1x11VW88847f5mw3bZtG+vWrSMnJ6fccR8tLCyMu+++m4KCAh555BG/eampqRw8eJC1a9f6pnk8HsaOHcvevXuPe9uFYxY+/PDDfr3XduzYUeIxGT58OADjx4/3u3w8IyPDN1ZvYZmTqbz1tGDBAowxftMKCgp8vQ1DQ0PLVa4kFfmMBFJlfZ5KU716dQD+/PPPCi3/+++/89prr/lNe+211/j9998ZMGBAiUMdlKai5/HxxHCi67eiYmNjadWqFcuXL8fr9VbKOrds2YKxbrpcpseIESN8y5b2vbpu3TruvvtuAK6++mrf9PPOO4/69evz3nvv+Y1DnJGRweOPP05wcDDDhg3zW1dFlunVqxc2m82X1D+WwnWeffbZxeatWLGi1ITu8dRbRZT3b195660i9VyoMEF/PMlzERGR46GetiIicsrq06cPV199Ne+99x4tW7Zk4MCB5Ofn88EHH9CxY0c+/fTTYpdXJicnM3ToUN577z3atWvH+eefz549e5g5cybnn38+M2bMKFcMLVq04KWXXuKf//wnTZo04cILL6RBgwYcOnSIP/74g2+//ZYRI0b85bi0Dz/8MG+++SaRkZE0btyYRx99tFiZoy9DHTZsGN9++y3z588v8WZG5XXDDTfwxBNP8NZbb3HffffRoEEDAG699Vbmzp1Lt27dGDRoEKGhoSxYsIAdO3bQq1evMiUQjqV3796MHDmSKVOm0LJlSy677DLy8/N5//336dSpE7Nnz/Yr36NHD2699VYmTZpEixYtuOKKKzDGMGPGDLZv385tt91Gjx49jiumiihvPQ0cOJDo6Gg6depEamoqBQUFfPXVV/z6669ceeWVpKamlqtcSSryGQmkyvo8laZp06YkJyczbdo0QkJCqF27NjabjVtvvfWYwwoU6t+/P7fddhtz5syhefPmrF27lk8//ZS4uLhyJ8Ereh4fTwwnun6Px2WXXca4ceNYsmQJXbp0OenbP9K0adN45pln6NGjB6mpqURERPD7778zZ84cCgoKuPfee/2OjdPp5PXXX6d///706NGDIUOGEBUVxYwZM9i6dStPPfWUX8/cii5TmNAuHMv3WAqT76UlbUubd7xef/11vv/+ewBWr17tm1b4/detWze/G6OV929feeutIvVc6KuvvsLpdHLRRRcdf8WIiIhUhBEREQmQzZs3G8D079+/zGWHDx/uN72goMA88sgjpl69eiY4ONjUr1/fPP7442bp0qUGMLfffnuxdeXk5JjbbrvNJCQkmJCQENOqVSvz7rvvmvnz5xvAjBs3zle2pGkl+fHHH82QIUNMcnKyCQoKMnFxcaZt27bmnnvuMb/99ttf7t/w4cMNcMzHlClT/Jbp2bOnAcz8+fP/cv3GlK2+J02aZABz7bXX+k3/8MMPTdu2bU14eLiJi4szgwYNMps2bfLFvXnzZl/ZY9VZacfR7XabCRMmmPr16/sdx40bN5ZY3hhj3njjDdO+fXsTHh5uwsPDTfv27c0bb7xRrNyx4pkyZUqJdWuMMePGjStX/RpTvnp66aWXzCWXXGJSU1NNaGioqVGjhunQoYN5+eWXjcvlKne5yvqMlPf4nYj6LevnqSLn2pIlS0zPnj1NVFSU77N15HEpyZHbWbhwoenZs6eJiIgw0dHR5rLLLjMbNmw45jLHUpHzuCwxlLb/xlRO/Vb2Z2fHjh3G6XSaf/7zn6Xu99GOtY/HY8GCBWbQoEGmUaNGJjo62jidTpOYmGguvfRS8+WXX5a63NKlS835559voqOjTVhYmOnQoYOZNm3aMbdV1mW8Xq+pXr26qVu3rikoKPjLfWjXrp2JiooyXq+32Lxrr73WAGbr1q1/uZ7y+qu/ZUcfq4r87TOm/HVd3vLZ2dkmMjLSDBw4sKJVISIictxsxhx1vZ2IiMhp4PXXX+f666/39SoTEX/6jJTdggUL6N27N+PGjeOhhx4KdDinrWuvvZbPPvuMrVu3VqkbpVUFa9asoWXLlrz44ovcdNNNx7WuFi1akJaWxr59+yoputNP4ffjt99+G5ArN0REREBj2oqIyCkuLS2t2HifO3bs4NFHH8XhcOiyRjnj6TMip4pHH32U3NxcJk2aFOhQqpyFCxeSkJDAqFGjjms9ubm5rFu37oQMjXC6cLvdPP7441xyySVK2IqISEBpTFsRETml/fvf/+azzz6je/fu1KxZk23btjF79mwOHTrEQw89REpKSqBDFAkofUbkVJGamsqbb77J7t27Ax1KlfPPf/6zUnrEr169Go/Ho6TtMWzbto1hw4b5bpIpIiISKEraiojIKe3888/n119/5bPPPuPgwYOEhobSqlUrbrrpJr+7e4ucqfQZkVPJoEGDAh3Caa3wJmRt27YNcCRVV/369TUMioiIVAka01ZERERERERERESkCtGYtiIiIiIiIiIiIiJViJK2IiIiIiIiIiIiIlWIkrYiIiIiIiIiIiIiVYiStiIiIiIiIiIiIiJViJK2IiIiIiIiIiIiIlWIkrYiIiIiIiIiIiIiVYiStiIiIiIiIiIiIiJViJK2IiIiIiIiIiIiIlWIkrYiIiIiIiIiIiIiVYiStiIiIiIiIiIiIiJViJK2IiIiIiIiIiIiIlWIkrYiIiIiIiIiIiIiVYiStiIiIiIiIiIiIiJViJK2IiIiIiIiIiIiIlWIkrYiIiIiIiIiIiIiVYiStiIiIiIiIiIiIiJViJK2IiIiIiIiIiIiIlWIkrYiIiIiIiIiIiIiVYiStiIiIiIiIiIiIiJViJK2IiIiIiIiIiIiIlWIkrYiIiIiIiIiIiIiVYiStiIiIiIiIiIiIiJViJK2IiIiIiIiIiIiIlWIkrYiIiIiIiIiIiIiVYiStiIiIiIiIiIiIiJViJK2IiIiIiIiIiIiIlWIkrYiIiIiIiIiIiIiVYiStiIiIiIiIiIiIiJViJK2IiIiIiIiIiIiIlWIkrYiIiIiIiIiIiIiVYiStiIiIiIiIiIiIiJViJK2IiIiIiIiIiIiIlWIkrYiIiIiIiIiIiIiVYiStiIiIiIiIiIiIiJViJK2IiIiIiIiIiIiIlWIkrYiIiIiIiIiIiIiVYiStiIiIiIiIiIiIiJViJK2IiIiIiIiIiIiIlWIkrYiIiIiIiIiIiIiVYiStiIiIiIiIiIiIiJViJK2IiIiIiIiIiIiIlWIkrYiIqeJESNGULdu3UCHUWGnevwiIiIiZxq13yrHggULsNlsLFiwINCh+Dz00EPYbLYKLdurVy969epVZeIROVUpaSsiAGzatIkbb7yR+vXrExoaSnR0NF27duW5554jNzc30OGJnLZycnJ46KGHqlQj/Ui//vorDz30EFu2bAl0KCIicppR+1NOJS+99BJTp04NdBjF7Ny5k4ceeoiVK1cGOpTjVtXbxSInm80YYwIdhIgE1meffcZVV11FSEgIw4YNo0WLFrhcLr7//ntmzJjBiBEjePXVVwMdpvyFgoICvF4vISEhgQ6lQk71+Ctq3759xMfHM27cOB566KFAh1PMhx9+yFVXXcX8+fMrvceEiIicudT+PD2cSe23Fi1aEBcXd0ISil6vF5fLRXBwMHZ7+frWLV++nPbt2zNlyhRGjBhRaTG53W7cbjehoaHlXrawzVjeujpWu/h44hE5VTkDHYCIBNbmzZsZMmQIqampfPPNNyQlJfnm3XzzzWzcuJHPPvssgBEev7y8vAo1gE41QUFBgQ7huJzq8YuIiEjZqP15+lD7rWTZ2dlERESUubzdbq9yyUin04nTWXVSRlUtHpGT4fT+CyIif+nJJ58kKyuL//73v34N5kINGzbk9ttv9713u9088sgjNGjQgJCQEOrWrct9991Hfn6+33J169bloosu4vvvv6dDhw6EhoZSv3593nrrLV+Z5cuXY7PZePPNN4tt98svv8RmszF79mzftB07djBq1CgSEhIICQmhefPmvPHGG37LFY4HNW3aNO6//35q1apFeHg4mZmZAEyfPp1mzZoRGhpKixYtmDlzZoljcXm9XiZOnEjz5s0JDQ0lISGBG2+8kYMHD5Z7Pwulp6fzr3/9i7p16xISEkLt2rUZNmwY+/bt85XJz89n3LhxNGzYkJCQEFJSUrjrrruK1W9Jjt6PLVu2YLPZeOqpp3j11Vd9x6x9+/YsW7bsL9d34MABxo4dS8uWLYmMjCQ6OpoLLriAVatWFSs7adIkmjdvTnh4ONWqVeOcc87hvffe880/dOgQo0eP9u17zZo16du3Lz///HOp8QPs37+fa6+9lujoaGJjYxk+fDirVq3CZrP5XZ42YsQIIiMj2bFjBwMHDiQyMpL4+HjGjh2Lx+MpsU5efPFF6tevT3h4OP369ePPP//EGMMjjzxC7dq1CQsL49JLL+XAgQPF9vfzzz+ne/fuREREEBUVxYABA1i7dm2x4/FXMW3ZsoX4+HgAxo8fj81mw2az/WWP27KcS3v27OG6664jISGB0NBQWrduXeJnbdq0abRr146oqCiio6Np2bIlzz33HABTp07lqquuAqB3796++HTJmoiIHA+1P9X+PJbKaueUJ5a0tDRGjhxJ7dq1CQkJISkpiUsvvdQ3PFTdunVZu3Yt3377ra89VNibdOrUqdhsNr799ltuuukmatasSe3atQHYunUrN910E02aNCEsLIwaNWpw1VVXFRt2qqQxbXv16kWLFi349ddf6d27N+Hh4dSqVYsnn3zSb7n27dsDMHLkSF9spQ3j8OGHH/piPdorr7yCzWZjzZo1QMljyJb1s3g0l8vFgw8+SLt27YiJiSEiIoLu3bszf/58X5m/ahcfTzzl+cyIVClGRM5otWrVMvXr1y9z+eHDhxvAXHnllebFF180w4YNM4AZOHCgX7nU1FTTpEkTk5CQYO677z7zwgsvmLZt2xqbzWbWrFnjK1e/fn1z4YUXFtvOyJEjTbVq1YzL5TLGGJOWlmZq165tUlJSzMMPP2xefvllc8kllxjAPPvss77l5s+fbwDTrFkz06ZNG/PMM8+YCRMmmOzsbDN79mxjs9lMq1atzDPPPGMeeOABU61aNdOiRQuTmprqt/2///3vxul0muuvv95MnjzZ3H333SYiIsK0b9/eF1N59vPQoUOmRYsWxuFwmOuvv968/PLL5pFHHjHt27c3K1asMMYY4/F4TL9+/Ux4eLgZPXq0eeWVV8wtt9xinE6nufTSS8t0bI7cj82bNxvAnH322aZhw4bmiSeeME8++aSJi4sztWvX9tuPkixbtsw0aNDA3HPPPeaVV14xDz/8sKlVq5aJiYkxO3bs8JV79dVXfefEK6+8Yp577jlz3XXXmdtuu81X5uqrrzbBwcFmzJgx5vXXXzdPPPGEufjii80777xTavwej8d07tzZOBwOc8stt5gXXnjB9O3b17Ru3doAZsqUKX7LhoaGmubNm5tRo0aZl19+2VxxxRUGMC+99FKxOmnTpo1p1qyZeeaZZ8z9999vgoODTadOncx9991nunTpYp5//nlz2223GZvNZkaOHOlXL2+99Zax2Wzm/PPPN5MmTTJPPPGEqVu3romNjTWbN28uV0xZWVnm5ZdfNoC57LLLzNtvv23efvtts2rVqlKPS1nOpZycHHPWWWeZoKAg869//cs8//zzpnv37gYwEydO9K1r7ty5BjDnnXeeefHFF82LL75obrnlFnPVVVcZY4zZtGmTue222wxg7rvvPl98aWlpxzhzREREjk3tT7U/S1OZ7ZzyxNKlSxcTExNj7r//fvP666+bxx9/3PTu3dt8++23xhhjZs6caWrXrm2aNm3qaw/NnTvXGGPMlClTfMe/Z8+eZtKkSebf//63McaY6dOnm9atW5sHH3zQvPrqq+a+++4z1apVM6mpqSY7O9u3/cJzaP78+b5pPXv2NMnJySYlJcXcfvvt5qWXXjLnnnuuAcycOXOMMdY5+vDDDxvA3HDDDb7YNm3aVGL95uTkmMjISHPTTTcVm9e7d2/TvHlz3/tx48aZo1NGZf0s9uzZ0/Ts2dP3fu/evSYpKcmMGTPGvPzyy+bJJ580TZo0MUFBQb7j+lft4uOJp6yfGZGqRklbkTNYRkaGAcrUIDPGmJUrVxrA/P3vf/ebPnbsWAOYb775xjctNTXVAOa7777zTduzZ48JCQkxd9xxh2/avffea4KCgsyBAwd80/Lz801sbKwZNWqUb9p1111nkpKSzL59+/y2PWTIEBMTE2NycnKMMUUNnvr16/umFWrZsqWpXbu2OXTokG/aggULDODX2Fy4cKEBzLvvvuu3/BdffFFseln388EHHzSA+eijj8zRvF6vMcaYt99+29jtdrNw4UK/+ZMnTzaA+eGHH4ote6TSGs01atTwq9+PP/7YAObTTz895vry8vKMx+Pxm7Z582YTEhJiHn74Yd+0Sy+91K+BV5KYmBhz8803lyv+GTNmFGt8ezweX2P16KQt4BeXMcacffbZpl27dn7xAyY+Pt6kp6f7pt97770GMK1btzYFBQW+6UOHDjXBwcEmLy/PGGP9IxEbG2uuv/56v+2kpaWZmJgYv+lljWnv3r0GMOPGjTtm/RQqy7k0ceJEA/glxV0ul+ncubOJjIw0mZmZxhhjbr/9dhMdHW3cbnep25s+fXqxfyJEREQqSu1PtT+PpTLbOWWN5eDBgwYw//nPf44ZW/Pmzf0SkYUKk7bdunUr1qY6+nwwxpjFixcbwLz11lu+aaUlbY8ul5+fbxITE80VV1zhm7Zs2bJibeNjGTp0qKlZs6ZfrLt27TJ2u92v3Xp0krQ8n8Wjk7Zut9vk5+f7LXfw4EGTkJDg95k7Vrv4eOIp62dGpKrR8AgiZ7DCS7aioqLKVH7OnDkAjBkzxm/6HXfcAVBs7LFmzZrRvXt33/v4+HiaNGnCH3/84Zs2ePBgCgoK+Oijj3zT5s6dS3p6OoMHDwbAGMOMGTO4+OKLMcawb98+36N///5kZGT4XWYPMHz4cMLCwnzvd+7cyerVqxk2bBiRkZG+6T179qRly5Z+y06fPp2YmBj69u3rt6127doRGRnpdxlPWfdzxowZtG7dmssuu6xYvRZe5jN9+nTOOussmjZt6rfdc889F6DYdstq8ODBVKtWzfe+MNYj4ytJSEiIbxw2j8fD/v37iYyMpEmTJn71HRsby/bt2495yVtsbCxLly5l586dZY77iy++ICgoiOuvv943zW63c/PNN5e6zD/+8Q+/9927dy9xP6+66ipiYmJ87zt27AjA3/72N7+xsjp27IjL5WLHjh0AfPXVV6SnpzN06FC/Y+RwOOjYsWOJx6isMZVVWc6lOXPmkJiYyNChQ33zgoKCuO2228jKyvJdEhcbG0t2djZfffVVheMREREpD7U/1f48lsps55Q1lrCwMIKDg1mwYEGxoSjK4/rrr8fhcPhNO/J8KCgoYP/+/TRs2JDY2Nhi509JIiMj+dvf/uZ7HxwcTIcOHY6rLTl48GD27NnjNxTDhx9+iNfr9Z3/JSnvZ/FIDoeD4OBgwBoG5MCBA7jdbs4555wy1UNlxFOWz4xIVaOkrcgZLDo6GrDGGy2LrVu3Yrfbadiwod/0xMREYmNj2bp1q9/0OnXqFFtHtWrV/BpDrVu3pmnTprz//vu+ae+//z5xcXG+xuLevXtJT0/n1VdfJT4+3u8xcuRIwBrX6kj16tUrFjtQLPaSpm3YsIGMjAxq1qxZbHtZWVnFtlWW/dy0aRMtWrQoVu7o7a5du7bYNhs3blziPpbV0fEVNlr/qlHq9Xp59tlnadSoESEhIcTFxREfH88vv/xCRkaGr9zdd99NZGQkHTp0oFGjRtx888388MMPfut68sknWbNmDSkpKXTo0IGHHnroLxtIW7duJSkpifDwcL/pJR1DgNDQUN84WEfua0n7eXSdFCZwU1JSSpxeuI4NGzYAcO655xY7TnPnzi12jMoTU1mV5VzaunUrjRo1Knbzk7POOss3H+Cmm26icePGXHDBBdSuXZtRo0bxxRdfVDg2ETl1fffdd1x88cUkJydjs9mYNWvWCd1e4diERz6aNm16QrcpVYPan5Q4Te3Pssdc1nZOWWMJCQnhiSee4PPPPychIYEePXrw5JNPkpaWdsw4jnb08QfIzc3lwQcfJCUlxa89nZ6e7teeLk3t2rWLjeN6vG3J888/n5iYmGLnf5s2bXzHvSTl/Swe7c0336RVq1aEhoZSo0YN4uPj+eyzz8pUD5URT1k+MyJVjW69J3IGi46OJjk52TfYfFkd3XAozdG/NBcyxvi9Hzx4MI899hj79u0jKiqKTz75hKFDh/p6PHq9XsDqBTl8+PAS19mqVSu/90f+ql1eXq+XmjVr8u6775Y4/+gkXFn3syzbbdmyJc8880yJ849OKJZVReN7/PHHeeCBBxg1ahSPPPII1atXx263M3r0aN8xAauBvH79embPns0XX3zBjBkzeOmll3jwwQcZP348AIMGDaJ79+7MnDmTuXPn8p///IcnnniCjz76iAsuuKBC+1XW/SxP2b+qq8L9fvvtt0lMTCxW7ug72pYnpkCoWbMmK1eu5Msvv+Tzzz/n888/Z8qUKQwbNqzEG7SIyOkrOzub1q1bM2rUKC6//PKTss3mzZvz9ddf+97rruBnBrU/S6b254lTllhGjx7NxRdfzKxZs/jyyy954IEHmDBhAt988w1nn312mbZT0vG/9dZbmTJlCqNHj6Zz587ExMRgs9kYMmSIX3v6eGIvr5CQEAYOHMjMmTN56aWX2L17Nz/88AOPP/54mZYv62fxSO+88w4jRoxg4MCB3HnnndSsWROHw8GECRPYtGlTuddXkXiq0jkpUlZqGYmc4S666CJeffVVFi9eTOfOnY9ZNjU1Fa/Xy4YNG3y/ZAPs3r2b9PR0UlNTKxTD4MGDGT9+PDNmzCAhIYHMzEyGDBnimx8fH09UVBQej4c+ffpUaBuFsW3cuLHYvKOnNWjQgK+//pquXbseV+P76HX+1T8nDRo0YNWqVZx33nkVagxVtg8//JDevXvz3//+1296eno6cXFxftMiIiIYPHgwgwcPxuVycfnll/PYY49x7733EhoaCkBSUhI33XQTN910E3v27KFt27Y89thjpSZtU1NTmT9/Pjk5OX69bUs6hidLgwYNACvZWdFz8WjlPdZlOZdSU1P55Zdf8Hq9fr1Q1q1b55tfKDg4mIsvvpiLL74Yr9fLTTfdxCuvvMIDDzxAw4YNq8S5KCIn3gUXXHDMH9Hy8/P5v//7P/73v/+Rnp5OixYteOKJJ3x3T68Ip9NZ4g9gcvpT+1Ptz2PFU5ntnPJu+4477uCOO+5gw4YNtGnThqeffpp33nkHqFiy8sMPP2T48OE8/fTTvml5eXmkp6dXKMaSVCSuwYMH8+abbzJv3jx+++03jDHHHBoBju+z+OGHH1K/fn0++ugjv3jHjRtX4X05Ud8NIlWJhkcQOcPdddddRERE8Pe//53du3cXm79p0yaee+45AC688EIAJk6c6Fem8Jf5AQMGVCiGs846i5YtW/L+++/z/vvvk5SURI8ePXzzHQ4HV1xxBTNmzCixEbd3796/3EZycjItWrTgrbfeIisryzf922+/ZfXq1X5lBw0ahMfj4ZFHHim2HrfbXaFG1hVXXMGqVauYOXNmsXmFv+4OGjSIHTt28NprrxUrk5ubS3Z2drm3ezwcDkexX56nT5/uG9+10P79+/3eBwcH06xZM4wxFBQU4PF4il32VLNmTZKTk8nPzy91+/3796egoMCvPrxeLy+++GJFd+m49e/fn+joaB5//HEKCgqKzS/LuXi0woR0Wc+rspxLF154IWlpaX6XvbndbiZNmkRkZCQ9e/YEih87u93u6zVUeGwiIiLKFZ+InJ5uueUWFi9ezLRp0/jll1+46qqrOP/8833DxlTEhg0bSE5Opn79+lxzzTVs27atEiOWqkztT7U/S1OZ7ZyyysnJIS8vz29agwYNiIqK8murRkRElPs4lNSenjRpEh6Pp1zrOZaKtNX69OlD9erVfed/hw4dShze4UjH81ks7OV6ZF0sXbqUxYsX+5UrT7v4RH03iFQl6mkrcoZr0KAB7733HoMHD+ass85i2LBhtGjRApfLxaJFi5g+fTojRowArPG/hg8fzquvvkp6ejo9e/bkxx9/5M0332TgwIH07t27wnEMHjyYBx98kNDQUK677rpiY1T9+9//Zv78+XTs2JHrr7+eZs2aceDAAX7++We+/vprDhw48JfbePzxx7n00kvp2rUrI0eO5ODBg7zwwgu0aNHCryHds2dPbrzxRiZMmMDKlSvp168fQUFBbNiwgenTp/Pcc89x5ZVXlmv/7rzzTj788EOuuuoqRo0aRbt27Thw4ACffPIJkydPpnXr1lx77bV88MEH/OMf/2D+/Pl07doVj8fDunXr+OCDD/jyyy8555xzyrXd43HRRRfx8MMPM3LkSLp06cLq1at59913qV+/vl+5fv36kZiYSNeuXUlISOC3337jhRdeYMCAAURFRZGenk7t2rW58sorad26NZGRkXz99dcsW7bMr9fB0QYOHEiHDh2444472LhxI02bNuWTTz7xHetA9AaJjo7m5Zdf5tprr6Vt27YMGTKE+Ph4tm3bxmeffUbXrl154YUXyrXOsLAwmjVrxvvvv0/jxo2pXr06LVq0KHU8t7KcSzfccAOvvPIKI0aM4KeffqJu3bp8+OGH/PDDD0ycONF385e///3vHDhwgHPPPZfatWuzdetWJk2aRJs2bXw9Ftq0aYPD4eCJJ54gIyODkJAQzj33XGrWrHl8lSkip4xt27YxZcoUtm3bRnJyMgBjx47liy++YMqUKWW+pPZIHTt2ZOrUqTRp0oRdu3Yxfvx4unfvzpo1a8p8gyo5dan9qfbn8cRc1nZOWf3++++cd955DBo0iGbNmuF0Opk5cya7d+/2633drl07Xn75ZR599FEaNmxIzZo1fWMgl+aiiy7i7bffJiYmhmbNmrF48WK+/vpratSoUaH6KUmDBg2IjY1l8uTJREVFERERQceOHY+ZhA0KCuLyyy9n2rRpZGdn89RTT/3ldo7ns3jRRRfx0UcfcdlllzFgwAA2b97M5MmTadasmd/noDzt4hP53SBSZRgREWPM77//bq6//npTt25dExwcbKKiokzXrl3NpEmTTF5enq9cQUGBGT9+vKlXr54JCgoyKSkp5t577/UrY4wxqampZsCAAcW207NnT9OzZ89i0zds2GAAA5jvv/++xBh3795tbr75ZpOSkmKCgoJMYmKiOe+888yrr77qKzN//nwDmOnTp5e4jmnTppmmTZuakJAQ06JFC/PJJ5+YK664wjRt2rRY2VdffdW0a9fOhIWFmaioKNOyZUtz1113mZ07d1ZoP/fv329uueUWU6tWLRMcHGxq165thg8fbvbt2+cr43K5zBNPPGGaN29uQkJCTLVq1Uy7du3M+PHjTUZGRon7VGj48OEmNTXV937z5s0GMP/5z3+KlQXMuHHjjrm+vLw8c8cdd5ikpCQTFhZmunbtahYvXlxs31555RXTo0cPU6NGDRMSEmIaNGhg7rzzTl+8+fn55s477zStW7c2UVFRJiIiwrRu3dq89NJLx4zfGGP27t1rrr76ahMVFWViYmLMiBEjzA8//GAAM23aNL9lIyIiiu3DuHHjzJF/6kqrk9LOmylTphjALFu2rFj5/v37m5iYGBMaGmoaNGhgRowYYZYvX17umIwxZtGiRaZdu3YmODi4TMemLOfS7t27zciRI01cXJwJDg42LVu2NFOmTPFbz4cffmj69etnatasaYKDg02dOnXMjTfeaHbt2uVX7rXXXjP169c3DofDAGb+/PnHjE9ETm2AmTlzpu/97NmzDWAiIiL8Hk6n0wwaNMgYY8xvv/3m+zte2uPuu+8udZsHDx400dHR5vXXXz/RuydViNqfan+WpLLaOWWNZd++febmm282TZs2NRERESYmJsZ07NjRfPDBB37LpKWlmQEDBpioqCgD+Oq6tPaiMdZ3W2GckZGRpn///mbdunUmNTXVDB8+3Feu8Bw6so3Vs2dP07x582LrLKnN/PHHH5tmzZoZp9NpgGJ1UZKvvvrKAMZms5k///yz2PyS2qxl/SwefS56vV7z+OOPm9TUVBMSEmLOPvtsM3v27BL3pbR28fHEU97vBpGqwmaMRl0WkTNbmzZtiI+P56uvvgp0KFIGs2bN4rLLLuP777+na9eugQ5HROS0Y7PZmDlzJgMHDgSsu4pfc801rF27ttiNXCIjI0lMTMTlcvHHH38cc72FdwsvTfv27enTpw8TJkw47n0QqerU/hQRkb+i4RFE5IxRUFCAzWbzuzv1ggULWLVqFY8++mgAI5PS5Obm+t2Mw+PxMGnSJKKjo2nbtm0AIxMROXOcffbZeDwe9uzZQ/fu3UssExwcTNOmTSu8jaysLDZt2sS1115b4XWIVEVqf4qISEUpaSsiZ4wdO3bQp08f/va3v5GcnMy6deuYPHkyiYmJ/OMf/wh0eFKCW2+9ldzcXDp37kx+fj4fffQRixYt4vHHH6+0OyuLiIiVND3ybvabN29m5cqVVK9encaNG3PNNdcwbNgwnn76ac4++2z27t3LvHnzaNWqVYVu9jJ27FguvvhiUlNT2blzJ+PGjcPhcDB06NDK3C2RgFP7U0REKkpJWxE5Y1SrVo127drx+uuvs3fvXiIiIhgwYAD//ve/K/VmAFJ5zj33XJ5++mlmz55NXl4eDRs2ZNKkSdxyyy2BDk1E5LSyfPlyv5u2jBkzBoDhw4czdepUpkyZwqOPPsodd9zBjh07iIuLo1OnTlx00UUV2t727dsZOnQo+/fvJz4+nm7durFkyZJjDp8gcipS+1NERCpKY9qKiIiIiIiIiIiIVCH2QAcgIiIiIiIiIiIiIkWUtBURERERERERERGpQjSm7Qnk9XrZuXMnUVFR2Gy2QIcjIiIictozxnDo0CGSk5Ox29U/4VjUVhURERE5+craXlXS9gTauXMnKSkpgQ5DRERE5Izz559/Urt27UCHUaWprSoiIiISOH/VXlXS9gSKiooCrIMQHR0d4GhERERETn+ZmZmkpKT42mFSOrVVpaw8Hg+LFi0CoEuXLjgcjgBHJCIicuoqa3tVSdsTqPAys+joaDWERURERE4iXe7/19RWlbLyer00aNAAgJiYGA09IiIiUgn+qr2qpK2IiIiIiIiUym6307Rp00CHISIickbRT6QiIiIiIiIiIiIiVYh62oqIiIiIiEipjDF4vV7A6nWr4UdEREROPPW0FRERERERkVJ5vV4WLlzIwoULfclbERERObGUtBURERERERERERGpQpS0FREREREREREREalClLQVERERERERERERqUKUtBURERERERERERGpQpS0FREREREREREREalClLQVERERERERERERqUKcgQ5AREREREREqi6bzUZ8fLzvtYiIiJx4StqKiIiIiIhIqex2O82bNw90GCIiImcUJW1FpGKMgcKeFl4PFORY0xzBEBQa2NhERERERE4WY8DjAk8BBIWB3RHoiERE5DSgpK2IlMzrgX2/w86VsH8jHEqDQ7sga7f13PNu6HijVXb3Wnile9GyUUlQvQHUqG891+sOtdoFZDdERERERI5LXiakrYb8Q9Dk/KLpr/aCPevAnQeYounOMKhWF25eUjRt2xIIj4MaDYo6PoiIiBxDlbgR2YsvvkjdunUJDQ2lY8eO/Pjjj8csP336dJo2bUpoaCgtW7Zkzpw5fvONMTz44IMkJSURFhZGnz592LBhg1+ZSy65hDp16hAaGkpSUhLXXnstO3fu9M3fsmULNput2GPJkiWInJbyD0H6n0XvtyyElzrBrH/Awqdg5TuwaR7sXgM5+yGz6PNSrOF5aBds/R5+fgu+Hge/f1k0L/1PePNi+HQ0LHoBNn5trcsYRERERKTq8Xg8LFiwgAULFuDxeAIdzollDBzYDD+/DTP/Ac+3hX+nwNQL4Yu7/cu6XeDOxS9hC9Y0T77/tNn/ghfawXOtrHbwr59YnSJERERKEfCetu+//z5jxoxh8uTJdOzYkYkTJ9K/f3/Wr19PzZo1i5VftGgRQ4cOZcKECVx00UW89957DBw4kJ9//pkWLVoA8OSTT/L888/z5ptvUq9ePR544AH69+/Pr7/+Smioddl27969ue+++0hKSmLHjh2MHTuWK6+8kkWLFvlt7+uvv/Ybv6lGjRonsDZEAiBrLyydDMteg7bDoN+j1vSk1hAcCYktoeZZEJ1s9aCNSoTIRIhNKVpHzebwf4cbnQW5cOAP2L8JDmyyeunW7lBUdt962Pyd9ThSaCzUbAadb4KzLj6huywiIiIiUqK3B8IfC4pPj64NCS3A4wbH4X+jh/7PenaGWsOD2Z1QkAeuLPC6i5Z150NEHNiDIH0b/DTFegBE1ITmA+HC/5zAnRIRkVORzZjAdm/r2LEj7du354UXXgDA6/WSkpLCrbfeyj333FOs/ODBg8nOzmb27Nm+aZ06daJNmzZMnjwZYwzJycnccccdjB07FoCMjAwSEhKYOnUqQ4YMKTGOTz75hIEDB5Kfn09QUBBbtmyhXr16rFixgjZt2lRo3zIzM4mJiSEjI4Po6OgKrUPkhDmwGRa/ACveOXxJF9D8crhqSlEZrxfsldwh/1AabJpvJXT3bYA9v1mJXXO418bAydBmqPV662L49HYrcZzY0kokJ7WG8OqVG5OIiJw21P4qO9WVlJXH42HhwoUAdO/eHYfjNBqzNWM7xNQuev/ZHfDTm9bQXnW7Qp0ukNzGSroeL1c2bPkeNs6zOjDsWw/GC22ugYEvWWU8bni1JyS1sYYYq9sdYmod/7ZFRKTKKGsbLKA9bV0uFz/99BP33nuvb5rdbqdPnz4sXry4xGUWL17MmDFj/Kb179+fWbNmAbB582bS0tLo06ePb35MTAwdO3Zk8eLFJSZtDxw4wLvvvkuXLl0ICgrym3fJJZeQl5dH48aNueuuu7jkkktK3Z/8/Hzy84sug8nMzCx950UCZd8G+PYJWDPDaiQCJLeFbv+CpgP8y1Z2whasnrqFSdlC7nxr/Nzdv1qN40K7VlmN2X3rYc2HRdNj6kBSK+g+RmPlioiIiEj5eT3ww3Mw/3G48g1odvj/vF73Qt9HIDi88rcZHAGN+1sPAFcO7PnVunlZobRV1nBku9dYw5OBdY+Iut0OP7pDdFLlxyYiIlVOQMe03bdvHx6Ph4SEBL/pCQkJpKWVPL5PWlraMcsXPpdlnXfffTcRERHUqFGDbdu28fHHH/vmRUZG8vTTTzN9+nQ+++wzunXrxsCBA/nkk09K3Z8JEyYQExPje6SkpJRaViRgvnsKVk+3ErYNzoPhn8L131gN1UDd6dYZYvWkbT3Yv6dDy6vgmg/hvAeh2UCoXt+anrEN1s22eiIU2vIDfPsf2L7caoSLiIiIiJRk/yZ443yYNx68BbDhiPsvRMSdmIRtSYLDofY5kFA0HB81m8HfPoKuo62OFTa7dYXaz2/CR9db94wolJ9lDUume0OIiJyWAj6mbSDdeeedXHfddWzdupXx48czbNgwZs+ejc1mIy4uzq9Hb/v27dm5cyf/+c9/Su1te++99/otk5mZqcStlInb68Zus2O3nYDfUQpvLlY4Bm2vu61xtnrcaV3qVZVF1IBGfa1HobwM6+69u1ZBYoui6aunW2ODzX8UQmOgdnur0VuzGSQ0s8bddZzRX3kiIhIg3333Hf/5z3/46aef2LVrFzNnzmTgwIGllh8xYgRvvvlmsenNmjVj7dq1ADz00EOMHz/eb36TJk1Yt25dpcYuclrxemDZf60b5RbkQEg0nP9vaHN1oCMrEhQGDc+zHgC56bB1EWz9wXqu172o7KZ58MEwiEq2rlZL7QKp3SCuUfEbBYuIyCknoBmMuLg4HA4Hu3fv9pu+e/duEhMTS1wmMTHxmOULn3fv3k1SUpJfmaPHpo2LiyMuLo7GjRtz1llnkZKSwpIlS+jcuXOJ2+7YsSNfffVVqfsTEhJCSEhIqfNFSuL2urnykys55DrE7e1u56L6F1VO8tbtsoZBWPQ8nHUJXPlfa3r1+jDk3eNff6CExhRdHnaket0hey9sXmgldjd+bT0K3b0FwqpZr3//0ioT3wRqNDp5vSlEROSMlJ2dTevWrRk1ahSXX375X5Z/7rnn+Pe//+1773a7ad26NVdddZVfuebNm/P110V/65xO/TgpUqodP1v3Skj7xXpfrwdc+pL/zXWrorBYaHqh9Thaxnbr5maHdlodGFZPt6ZHJVlX1PW4o+hKNREROeUEtGUXHBxMu3btmDdvnq+3gdfrZd68edxyyy0lLtO5c2fmzZvH6NGjfdO++uorX6K1Xr16JCYmMm/ePF+SNjMzk6VLl/LPf/6z1Fi8XmtszyPHpD3aypUr/RLBIpVhU/omNmVsAuD/vv8/pq2bxj0d7qFVfKuKr3T3WvjoRti92nqfvRc8BeAIOvZyp7IWV1gPj9vqhZu2yhojd8+vVoK2MGELsOSlI+4KbLMa69UbQI2G1qPDDSdmPF8RETkjXXDBBVxwwQVlLl841FahWbNmcfDgQUaOHOlXzul0ltrRQaQy2Ww2qlev7nt9Sso/ZCVsQ2LgvAfgnOtO/fZe55uh3UjYvszqibvlB+v1oV3WeLg97ywqu+MnyNgB1VKhWl2rI4SIiFRpAf85fsyYMQwfPpxzzjmHDh06MHHiRLKzs32N0mHDhlGrVi0mTJgAwO23307Pnj15+umnGTBgANOmTWP58uW8+uqrgNWIGD16NI8++iiNGjWiXr16PPDAAyQnJ/sSw0uXLmXZsmV069aNatWqsWnTJh544AEaNGjgS/6++eabBAcHc/bZZwPw0Ucf8cYbb/D666+f5BqS092v+38FoGZYTbLd2azet5pr5lzDxfUvZnS70dQMr1n2lXk9sPgF+OZR8LggrDpcPBGaXXpigq+KHE6o3c56lKZWO6sn8t51kHsA0rdZjz/mQ0Q8dPpHUdmZ/4DMHVbjtlpdqFYP4hpbyd2g0BO9NyIiIvz3v/+lT58+pKam+k3fsGEDycnJhIaG0rlzZyZMmECdOnUCFKWczux2O61aHUeHgkDweqybeSW1tt7X7wkDnrGuQIuMD2xslSk43Nq3+j2t9wV5sG2RlaStVreo3HdPwfo5Re/DqkFsqtUTN66xdYNfp64aFRGpSgKetB08eDB79+7lwQcfJC0tjTZt2vDFF1/4biS2bds27Ef8AtqlSxfee+897r//fu677z4aNWrErFmzaNGiaGzLu+66i+zsbG644QbS09Pp1q0bX3zxBaGhVoIlPDycjz76iHHjxpGdnU1SUhLnn38+999/v9/wBo888ghbt27F6XTStGlT3n//fa688spy76PL5cIY4/tV2uPx4PF4sNvtfpexuVwuAIKCgiq1bEFBAcYYnE6nry69Xi9utxubzUZQUFCVKut2u/F6vTgcDhwOR7nLGmMoKCgArN7cJ6JsSfVenrJHHqO1+9fidXvpn9Kfka1H8vyK55m1cRafbPyEuX/MpXH1xtSIqEH10OpUD61OlD2KyKBIgoODcdgd2Gw2jMdA9kHsy1/DfnAjJtSJSWqHu+U1eO1eHBs/8dW7x+PB6/Fit9txOItuPOYucGOMweF0+B0jj9uDzWbDGeSsUmU9bg9erxe7w+5X7+4C6+ZkQcFBpZet1xZT92yrbH4WQfn7IWs3ZO3G4zV4180sKrtrESZrD+6dy6z1Og/3LrHZ8ETVwdvz7iPKrsYYL25bKASH4wyPxhYSBXZ7qfVe4LLOE2eQ0+88OZllq8LxPN5jf0LOk+MoW2K9l6dsAI/96XyenIxjf9znSVX8jnAV1bvNduTf8MJ6P7xvQIHr6GNkfMcoJTqFTikdimI4ge2I08nOnTv5/PPPee+99/ymd+zYkalTp9KkSRN27drF+PHj6d69O2vWrCEqKqrEdeXn5/tdVZaZmXlCYxcJmIztMON66+qrW5ZBTC1revvrAhvXyRAUCg3OtR5HikyAWudA+lbrKrzcg9Zj10oIjoJe9xSV/Wws7N8AMSkQW8e6UXDh6+hauk+EiMhJYjNGt5o8UTIzM4mJieGee+7h/vvvJyIiArBuRvHNN9/Qtm1bv5uaPfbYYxQUFDB69GhiY2MBWLJkCV988QUtW7bkiiuu8JV98sknycnJ4aabbqJmTasn5k8//cSnn35K06ZNGTJkiK/sxIkTSU9P5/rrr6dWLavB8ssvv/DRRx9Rv359hg0b5iv74osvsnfvXkaMGEHdunUBWLduHdOmTSMlJYXrritq6Lz66qvs3LmTq6++msaNGwOwadMm3n77bRITE/nHP4p6K06dOpUtW7Zw1VVX0by5dXfUbdu28cYbb1C9enVuu+02X9l3332XDRs2MHDgQN8QF2lpaUyePJmoqCjuuOMOX9kPPviAX3/9lQsvvJAOHax/BPfv38+kSZMIDQ3lnnuKGh+zZs1i5cqV9O3bl65du/qO0TPPPIPdbufBBx/0lf3ss89YtmwZvXr1olevXgDk5eX5xpd74IEHfP+Az507l0WLFtGlSxf69esHWP/4PvLIIwDcc889vh8MFixYwIIFC2jfvj0DBgwAYOjsoXzzxjf0TenLS4+8RHR0NGv2rWHM1DGs/H4lYXXCiG4X7Ytt7+y9eAu81OhbA2ek1WDK+SOHQ6sOEVorlJgORZc67ft8H548D9V7Vyco1vqHNndrLpk/ZxKSEEJsl1hf2f1f7ced5aZaj2oE17AS2Hk78sj4MYPguGCqdS8aXuDANwcoyCggtmssITWtHzry0/JJX5xOULUgqveqXlT22wMUHCggtmMsIclWWddeFwe/P4gzykmNPjV8ZQ/+cBDXHhcx58QQmmLVWcHBAg4sOIAj3EFc/zhf2fQl6eTvyif67GjC6oYB4M50s3/efuwhduIvLOpBkbEsg7zteUS1jCK8oTV+rSfbw765+7A5bNS8pKg3c+aKTHK35BLZLJKIJtZn1pvvZe+cvQAkXJbgK3vol0PkbMohokkEkc0irbJuL3s/tcrGXxyP3WklF7J+zSJ7fTbhDcKJalX0z/TumdYY3fEXxmMPscpmr88m69cswuqGEX120bHf88kejMcQ1y8OR4R1/uVszOHQ6kOE1g4lpn3Rsd87Zy/efC81zquBM9o6T3K35JK5IpOQpBBiOxUd+31f7sOT46F6r+oEVbPOk7w/88hYnkFwzWCqdS069vu/3o/7kJtq3aoRHG+dJ/k780lfmk5Q9SCq9zzi2C84QMHBAmI7xxKSePg82ZNP+g/pBMUEUf3corIHFx7Etc9FTIcYQmtZx96138XB7w7ijHRSo2/ReZK+KJ383flEt40mLNU69gXpBRyYfwBHqIO4C4rOk4wfM8jbkUdU6yjC61vH3p3lZv9X+7EH2Ym/qOg8yfwpk9xtuUS2iCSikXXsPbke9n2xD5vdRs1Li86TQysPkbM5h4imEUSedfjYu7zs/cw69jUvrYnNbiXCDq05RM6GHMIbhRPVwjr2xmvY8/Ee69gPiMcefPg8+S2L7HXZhNcLJ6pN0Xmy5+M9GK8h7vw4HGHWsc/ekE3Wmix9R+g7wle2qn9HxMc1Z+3r831lT1Q74rXXXuOGG24gIyOD6Oii+qkqbDbbX96I7EgTJkzg6aefZufOnX4/MB8tPT2d1NRUnnnmGb/22pFKunkZUGXrSqRC1s2Bj2+yEpLBUTDozaKbeoklP8tK3h7cAvs3gTsPet5VNP+F9rDv95KXDatm3Sui0JoZ1vpCY6zxd0Nji55Dok/9IShERE6AwnzhX7XB9BOZSAAVeAr4/aDVIIoPL0ogtIhrwYOdHuS9Pe+R2DCRpp2acjDvIAfzD/L5D5+Tk5dDi6QWhESF4DVedmfsZlvkNmIjQqgf1wCbMwxssDp6NQVBBTRLaEZ4dSsRsT9nP1uithBTLYaGSQ1921wTvYZ8Wz5NajYhsqaVXDiQf4DNUZuJio2icVJjX9lfY34l15tLo/hGRCdZXzDpnnQ2RW0iPCacs5LO8pVdF7uO7IJsGsQ3IDYpFoBDtkP8HvU7oTGhNE9q7iv7e8zvHMo9RN24utRIshI12UHZrItaR3BEMC2TWvrKbqq2ifSsdOrE1SE+yaq73LBcfo36FWeIk9aFl8IBm6tt5kDGAWrH1SYhyUqo5GflsyZqDXaHnbOTzvaV3bplK/v27yO5RrJvDOuCvAJ+ibJuWtEu6fCwCx4Xf8ZsY09UBonVE6mVVAt2r8bjzWWlFzBe2uTm4TjcM3eH10laVAo1q9ckJSkFti0Gdy4/He4R1io3lyCv1ajd5baxMyqFuGpxpCalWsM4eFyssIfiNdAiqLp1VYDNzu6wbLZHxVK9WnXqJdWDnP3g9bAqxOC2eWgW3YCwWCthtTc6g21R+4mtFkuDpAZQkA1eD6vD15NvCmgUXY/wmDCMgQMR6WwJ30tERBR1qzXE7jqE8Rbwe5CNPGc+9YJrExEUCsaQ4czij+AwQkPCqR3WkND8/Tg9uWzy5pNn8kj1RBLtDgFjOFRgZ409hGB7KNXtDYlzbSfCncmmAhtZbqib56V6rhsbhpy8Ar43cTg8IYS6G5Hq3kJ1c4CNBS4y3G7q5udTI9c6HDl5br52Vwd3MCanEfXZTqJtP5tcLtLdblLy86mZax2LvFwPc12xeE0Q7qxG1LftpJZ9H1tcLva73dTKzycx1zoWrlwvXxbE4LZZZVNtadSx72GbK5+9bjdJLhfJuVYQ7gIvXxRE4zq83jqOvdS1p7E938Vut5uEfBe1D5f1eg2fF0SRb4JxZzekdsFBGth3stPlYpfbTbwrnzq5RX+iP3dFkusNxp3VgATvIRrZt5OWn88Ot5saLhd1D68XYG5BBFmuYDxZ9Ym35dDU/id78vP50+0m1uWiwRFl57kiyHAF4cmpR1xQPs3sW9mfn88Wt5uYAhsNjyg7Pz+cg64gPNmpRIZ6aGnfzIE8F5vdbqIKoPERZb9zhbHPFY0nO4XwbBut7ZtIz3Oxye0m3AVnHVH2B1cYu13RuHNqE53tpK19A4fyC/jd7Sa0wND8iLJL8sPY6YrCnZOMMzuU9vb1ZOe5Wed2E1zgpeURZZfnh7LNFYknJxF7ViQdHb+Rm+vmV7cbp91G6yPK/pwfzBZXBJ7cBExWDF0ca8nP87DG7cZubJx9RNlf8oPY6ArHkxuPJ6sG3RyrKcj38svhz3K7I8quzXey3lUDT24N3Fk16WpfjfF4WXm47JHfEevz7Kx11cCTVx13VgKd7Wtx2jwlfkdszLPxi6sGntxY3FlJdLD/RoitgBVuN16PoUVeLiF2K2m7NR9+ctXAkxeNO6sW59jXE2bLZ1VBAW63oVluLmGHeyhvz/Pyo6sG3rwo3Fm1Odu+gUhbLqsLCnC5vTTNyyMi14onLdfDIlcNvHkRuLPq0Nq+iWhbNmsLCshze2icn0dUrgeAfXkevnPVINR+xNjiUibGGN544w2uvfbaYyZsAWJjY2ncuDEbN24stcy9997LmDFjfO8zMzNJSaniN2GSKsHj8fDDDz8A0LVrV1/nhSrFnQ9zH4AfX7HeJ7WBK9+AGg0CGlaVFBIJCc2tR0kufh4Obob0PyFj2+HnP60ezNG1/Mt+95R1D4mSRCXBHeuK3n821lpPcKQVQ/DhR0gURMRB66KOR+SmgzPUGrLhVB1HWUTkOKmn7QlUmDnfu3cvNWrU0PAIGh6h2DH6bf9vDJo9iAh7BN8O+pbg4ODyHfv8Q7B7LZ7aHcp87E/n8+RkHPtynSd2G7a8g5C9D4/bhSe+WVG9f/UgZO7EdWg/JucAzvyD2PIOYs/PJDuuDb8P+JA8j8FtbJwzqxdh2dtxeczh9RbdBGRPcAqTmryDy+ulwGvnjk0jqeX6o8Syu0x1LnW8TIEXCoyNd7iPNraNFHgMBnDawX647D5vBO1yXwJs2BxO/hf0KJ0dv5ZYNtsbRLPc131l3wh6knMdK0ss6zWG+jlTMTiwOZy8EPQ8FzmWlFq2Rc4rZBOOzeHkCeerDHYuIMdjJ88E4bU58diDyDdB5JgghuTdx0GisTmcXO2YR3f7L+R57eQbB16C8NqdFODEZRy84LqITCKwOYLobF9LS/tmXF47Huxgd2LsQRibAw8OFnpbkeuIxBEUTF12kGL24PGC12YDRxA2RxDGZgdsbLHVpsAehiPISQ0yqG7SwesFAzaHHbvDgQ0bNhvstcXhcoThdAYRbQ4R6z0IxovNWJecBzsd2GxgB/bZ48l3RlhlvQep6dphjdVnPDhsEOSwY7d5sWPYEVSP3KDqOIKCiHXvoVb2b+AtAK+HIJvB6bBhx2DHy5bwlmSG1cbhcFLdtZOGB77FeN3gKcBp8xJsN9iNB4dx83v1XuyNboHD6aRG7hZa/vkueDx4vAU4MIQ4wHY4/t8TB7AjvgcOp5Pq2Ztpu+5Jq6zHg93mJdgONqyyG2pfydY6l+J0BBGTs5kOS2/DeNx4vAabDYLtBhsGjJdNqYPZ3OTvOBxOInN30OmbQdZ3hMdgt0GQAzAGMPxZbzCbWv4Lh9NJuGs/53za1/p8esAGBDuL/gHclXoJm9uPw+FwEOzN5ezpnQ6XtcoEBxX2ErKxN6Uf27pMwOF04rAZWrzbFsMRZZ32w/9c2khP6sb2nk/jcDhxOBw0nNYNPC4KPDaMzUaQw47t8Hddds227Or19OHviCBqz7gYuyvr8GfZRpDTKmuwkVe9KXt7W1eJBAU5ift0OPbc/bjcBmw2gpwOa702O/lRdTjQ+0kcDuu7J+arMTiy0orKOmy+sgWh8aSfOwGH3YHT6STiu/E4MrZR4DYYGzgd1pAHNpsNd1A0h3o/ht1ufTeHLH4Gx8HNFLi9GOw4HXbsDps1lIw9lJzeDxPsdBAVXjQW+In6+3DgwAFq1KhRZXuPlqen7YIFC+jduzerV6/2GwqsJFlZWdSpU4eHHnrIr+fxsZS1l4eIx+Nh4cKFAHTv3r3qJW33bYQPR1o3GgPofAucNw6cx/6xQ8rJ64X8o27wO+dO694QuemQl374OQPcuRDXBG75sajsi51g728lrzsyEcauL3r/xvlWJwdsEBRujd0bFGa9joiHEbOLyv7wHBzYbM13BBcle53WkGW0G1FUdvtPVi9shxPsTrAHHX52WM+JR3zXZu4EVw5gfO0LjLfodc1mRQnl/ZusYSe8HvC6Dz8KXxdAkwFFw0psmm/dONpbYN1E2XvUo+fdEHr4O3nNDKt84XaN1z+O/o9B1OEbUq6dCes/t+rM4bTqwhFs3YzaEQzt/w7RyVbZXb/AzhWH6ykEnGHW0BrOw48aDa3EeuFxt9mUPBepRGVtgylpewKpISx/5cPfP2T84vF0SurEa/1eK9/CXg/8bwhs+gYufdH/l2k5YTxeQ7bLTXa+m6w8N4cOP2fnu8nKd5Pj8hx+dpOd77GeXR5yDs+zHm7yCrzkuKxp+W6vb/1BuAkjn0wifNPOt/9INdshwnARRj7htjxCKCCEAg4QxbPuq3xln3C+Sqp9N048OHETjAc7Xhx42WdiGFpwv6/sq0FP08y+tdg+GuxkEsE1jicJctgJstu4z/MyTb0b8docGJsdY3NibHa8NicF9hBeSHycIIcNp8POuekzSHZtxmMPxesIxuMIxesIwTiCwR7EmuSrcDidOO02ErJ+JbLgAMYZit0ZBI4QcAZjc4ZicwSTH1Ubp8OJw2HDYYp+HHHYbdbDZit6fdQ0++HXdjvFpjnsNuw2G3ablSS229UIFTldVMX2V1ZWlq8H7Nlnn80zzzxD7969qV69OnXq1OHee+9lx44dvPXWW37LXXvttWzYsIElS5YUW+fYsWO5+OKLSU1NZefOnYwbN46VK1fy66+/Eh9ftpssVcW6kqqpyidtF/wbFkyA8BowcDI07hfoiMSdDwU5/gne37+07iORnwWuLMg/dPg5C4IjrBsoFzpmgjcBxh4xfMN/+8Ofxb8nAasn7307it6/fTlsmldyWZsdxh0sej/tGlg3u+SyAPfvLfphYMbfYfX00svevdUaNgLgk1vh57dKLzvmt6Lk6uf3wNKXSy97y08Qd/jqya/Hw/fPlF72xoWQdPiGgt/9x7p5dWlGfg6pXazXS1+BL+6BoAjrOPkeh3tM9/4/SG5jld2/yUoGh1e3jn1I9OFHlHpNixxBwyOInALW7l8LQPMapVyadCxz74cNcw//EtqokiM7fRljyC3wcDCngPQcFxk5BaTnFpCRW0Bm4XNeARm5bg7lFXAo78hnKzF7ItmdwdiCwkgMchAW7CDEaSctqB/pQXZCgxyEOh2EBNl9zyFOO7cGWeWCnXbynM+x2Wn3vQ922AkJchDssN7PPjw9yGEn2HmuNd1hJ8hp9fpz2m2+nrkr/SI79lhwr/u9O7uUUhb/22LokkUROf0tX76c3r17+94XDlEwfPhwpk6dyq5du9i2bZvfMhkZGcyYMYPnnnuuxHVu376doUOHsn//fuLj4+nWrRtLliwpc8JW5JTnyrF6UQJ0vd26hP/c+yE6KbBxiaWwB+eRGvcv+/L/XASuQ1CQayV/C3KLXh/d76zdCGjQ20oUu/OtMXrdeVZ5x1G9ravVhYSWYA73gvUUFPWIPVpQOITEWJfoYLOSuoevpsFmO9zr9bDIBKhWz+rVemTP3cIHR8Rc6xwrNnvQ4R6/h5dxBFnbCAo/os76WUNH2OxHPA7Hgg0iisb+p1E/K1lqvEX75nEdfhRAZNEY/VSvD40vAM/hOivILaozd56VkC3kyrbW6TpkPY7W9fai15u+gTlji5cBaz+H/g8a9S0qu/SVoqSu73H4fb3u1k3wwErw52UcThZH6WZ4csZQT9sTSL0X5K8M+nQQvx34jWd6PUPf1L5lX3DZf+Gzw2PSXTUVml92QuI7FXi8hvQcF/uyXOzPymdftvV8MNvFgRwXB7MLOJDt4mCOiwPZLtJzCnB5vH+94r8Q5LARGeIkMtRJRLCTqFAn4cFOIkOchAc7iDj6OdhJeIiD8GAHYUHWtPBgB6FBh6cFWwlZ9fgUETk+an+VnepKyqpK9bTN2mN1Xti7Dq6fbyXHROTEKci1EqaubCtpnp9lvXYdsp4b9StKCK+daf2vmrPfWib/EORnFq1r+GwrGQvw42ulJ3gBhrwHTa2bd7PqfZh5Q9G8oHAruRt6uCfvuf8HDQ53Ddm1Cla8e3i4hzCrN7QjpOiHhNSuRWNdZ+21hlUpHELC7rQS4oWJ96gkKxFeWA9Ze/yT50cKiS4aUsKdbw2X4Uu3HTnEhrF6XRf2Qi/Is4YYObpM4XLhcRCVULTefRus177tH/GDQli1omPhccOhnUfE6/D/QcERrCFkAkg9bUWquHxPPhvSrS/cZjWalX3BTfOtsaPA6k1wmiZsCzxe9hzKJy0jj92Z1mPPoXz2Hn4Uvj6QnY+3Aj89BTlsxIQFUy08iJgw6xF9xHN0qNP3HBUaRFSok+jQICJDrQRtiFP/IIiIiIicVKs/hNljrHFVscHWH6Bej0BHJXJ6CwqzHmXR/LLi/596vUXDYRQmQMH67F783OFetJlFZQofhUNEgDX+ryPY6jUMh3tf50BWmvXelVNUdu/vRTckLMllrxQlbbf/CNOuLr3sRc/COaOs19sWw9vH+N+7/wTofJP1eucKeOMYPcvPvR96HP6fft96eOUY32Pd74DzHrReZ2yHyV1LL9vxH3DBE9br7L0wsWXpZc/+mzXMIljJ92eaWVfx+sY2DilKdtfvBT0OJ9iNgc/uKCoTFHZ4ucPnSbW6ULdb0XZ2rz08//C41MGR6ildDqopkQDZcHADbq+b2JBYkiOS/3oBgN2/wgfDrMt5Wg2G7sf4ZbIKM8awL8vF9oM57EzPY1dGLjvSc9mZnsuujDx2puexPzu/2JVPx1ItPIgakSHUiAimRmQwNSJCqBYRTPXwIOs5Iphq4cFUiwgmNiyI8GCHbxgAEREREaniVr0PM28EDCS1gYuegVrtAh2ViPwVu93qERt6VG/C+CbWoyzO/pv1cLuKeu/mZ1rJ3vxM/++C+CZWorMgz7ohnttVNAyEx1U05AJYwy0ktCwaRsJ4Dg+XcXjIjKCII4KwWT13jdcqdzSb3b+sb2gOW9H8wuE17EU3UMXuhNBY/6E3fM+HYzxyGxFHDHNx9E3yjhxaA3M4Xo/vBsZ+7EekAwvyrBsJlia6VtFrdz4s/2/pZZte5J+0faVH8eFHgiMhNMZKBg98qWj65/dY+xgcYQ1/ExxpvQ4Kt5L4KR2KymbttZLGwRGn9RUXStqKBMiv+38FrPFsy5w8XP2B9UepTme4ZFKVHsg9r8DD9oM5bNmXw5b92Ww7kMP2g7n8efg5t6CEP3RHcdptJESHkhAdQmJMKDWjQomPCvE9akaFEB8ZQvWIYJwO+1+uT0REREQqJjY2NnAbX/0hzPoHYKDdSBjw9Gn9T7qIlMIZDM4a/mP5Hi2pVdEN1/5K/V7wz+/LVrZBb7g/rWxl63SEB/aWrWxCc7in+M2hS1S9Hty5oWxlo5OLx+s9PN6x8eBLJoOVQL152eEkd+EYx/lFye6YlKKyNhv0ureojDvXSvoW5FjjISe1KSrrcVs3iHTlQEF20TjQrsM3I8xN949v2etWr+qS1O0OI464MeCLHSD3gPXaGVp0g7ygCEg+Gy474gZ+3zxmJeVDo619DYk5XDbM6v2d1LoMFRoYStqKBEjhTcjKNTTCeeMgKhlaXll8YP8A8HoNOzNy2bQ3m017sti0N4vN+7LZuj+HnRm5x+wpa7NBYnQotWLDSIoNIzk2lOSYMJJjw0iKCSUhOpQaEcEa41VEREQkwBwOB23atAnMxn/7FD66wfpnv+0wGPCM1XNPRETKx24Hewnj2DqcEN+4bOtwhkCve8pW1uGEsb9br42xEqf5WdYQN3kZVk/gQsZAz7sOj5dc+MiyksGubEho4b9ud/4Rrw/ffDBnv/X+yGE4AJa/ATn7So4xqTXc+F3Z9icAlLQVCZC1+6ykbfO45scu6HFbGU67w3rueMOxy58Axhh2ZuTxe9oh1u8+xPo06/HHvizyCkq/qVdkiJO6ceGk1oigTvVwUqqFk1I9jJRq4STHhhHsVINbRERERI4hvilExEPD8+Ci55SwFRE5FdlsRTeDK6mntM1mJW3L6r4dVqLWlXO45252UYL3yCElwBrrN2f/4SE1MqxnV47VW7h6g+PbrxNMSVuRAMhz57ExfSNgDY9QKmPg8zshcxdc+d/iXz4nQL7bw4bdWfy6M5O1OzNYuzOT9WmHOJTvLrF8kMNGao0IGsRH0CA+kvrxkdSLC6dujQiqRwRr3FgRERERqbi4RnDDfIhMUMJWREQsNlvRzc+ONVwGQM87T05MJ4CStiIBsP7gejzGQ43QGiSEJ5Re8IfnrK782ODPpdDg3EqNw+3xsn73IX7ZnsGqP9NZtT2DDbsP4fYWH9fAabfRID6SxolRNEmIpFFCFI1qRpJSPZwgjScrIiIictryeDwsWbIEgE6dOuFwnMDxZLP3w/fPQP3e0KiPNS26jDftFREROY0oaSsSAIVDIzSr0az0nqi/zYavx1mvz59QKQnbvYfy+WnrAZZvOcjKP9NZszOjxOENYsKCaJ4cffgRw1lJ0dSLi9BwBiIiIiJnqIKCUm4OU1nyMmHxi9bDdQh+fA3+/lWVvkGMiIjIiaSkrUgAFN6ErNTxbNP/hI9vsl53uAE6/bPc2zDGsHV/Dks372fZloMs33KALftzipWLCnHSKiWG1rVjaVU7lha1oqkVG6ZhDURERETkxHPlwLLX4PtnIfegNS2xFZz3oPUsIiJyhlLSViQAft3/K1DKeLYeN8z4uzVAdq120P/xMq93R3ouizftZ9GmfSzetJ9dGXl+8202aFwzinPqVqNtnWq0TomlflwEdrsStCIiIiJykm35Hj66ETK3W+/jGkPv/4OzLtH4tSIicsZT0lbkJMspyOGPjD8Aa3iEYr59Av5cAiHRcMV/wRFU6rpyXR4W/7GPb9fv5dvf9xbrSRvksHF2SjXa16vGOanVaVunGjHhpa9PREREROSkCY+DrDSISYFe90KrweDQv6giIiKgpK3ISbf+4Hq8xkvNsJrUDK9ZvEDD82DV/6DPQ1C9XrHZf+zN4pt1e/j2970s3XwAl7toTFq7DVrVjqVLgxp0aRBHu9RqhAWfwBtFiIiIiIiUlSsb/vgWml5ova/ZFK5+H1K7WncAFxERER8lbUVOMt9NyOJK6GULUKcT3PwjBIcD1ti0v2zPYO6vacxdu5sNe7L8iteKDaNnk3h6NY6nU4MaRIeqJ62IiIiIVDHblsKHI+HQLrjhW0g6PF5twz6BjUtERKSKUtJW5CQrvAmZ39AIxkD6VqhWFwCvM4xlf+zns9W7mLt2N2mZRWPTOu02OtWvQa8m8fRqEk+D+EjdNExERERETqioqKiKL7z9J3jnCnAdgthUyD9UeYGJiIicppS0FTnJCpO2fjchW/IyZt7DbO/yMG/mdGP2L7v8ErURwQ56Na1Jv2YJ9GpSk5gw9aYVERERkZPD4XDQrl27ii286xd45zIrYVu3O1z9ge+KMhERESmdkrYiJ1F2QTZbMrYART1td2zfSsLccTiNi1fnreVtTy0AokKdnN88kQtbJtG5QQ1CgzQ2rYiIiIicQvb8Bm8PhLwMSOkEQ6cpYSsiIlJGStqKnES/7f8NgyEhPJEf1ufzwfIlnL35dcYGuVjpbcB0e38uap7Ixa2T6dUknhCnErUiIiIicgrK3AVvXQo5+yH5bLjmAwiJDHRUIiIipwwlbUVOorl/LAFg9944bv9pJQ48PBUyDwD3OdfzU/9+RIToYykiIiIiVYfH42HZsmUAtG/fHoejDB0LIhOg6UXw54/wt48gNOYERykiInJ6UXZI5ATLK3AxacknfLjxf+TYfwcgP7sWtWLDuCf1d5LWH4DwOM65cBQ49ZEUERERkaonLy/vrwsdyW6HAU9bNx0LjT4xQYmIiJzGlCESOUG2HNjD49+/yZK9n2KcB8EOxtiJt7fjzgE30r9pQxxvPWMVbjccnCGBDVhERERE5Hi4cmDJS9D1dnAEgc2mhK2IiEgFKWkrUsl+SdvCuAUvsSH3a2z2AutT5gmnRVR/7u42ijZJda2CmTth22Kw2eGcUYEMWURERETk+Hi9MPMG+O1TSFsNg94MdEQiIiKnNCVtRSrJoq3rePT7F9lWsBCbzYPNDkHu2pxf50ru6jaE2LAI/wWik+H2X2DrDxBTOzBBi4iIiIhUhq/HWQlbRzB0vDHQ0YiIiJzylLQVOU5fb1zFhMUvsNuzFJvNYLNBuLcx17X4O39v1x+73V76wjG1oNWgkxesiIiIiEhl+2kqLHreen3pi5DaJaDhiIiInA6UtBWpoJ92bOLe+U+x0/2DL1kba1pzc9sbGNKqx7EXduVAcPjJCVRERERE5ETZ9A3MHmO97nWfOiSIiIhUEiVtRcrpjwO7GTv3GX7P/RKb3YPNBjU4h7s63cKFTdr99QqMgVd7QWyKdUfdanVPdMgiIiIiIsclPLyEDgd7foMPhoPxQKsh0POukx+YiIjIaUpJW5EyOpiTxb++fI7l6TOx2fOx2SHC25S7Oozh8uady76iPxbAvvXWjcjCa5yweEVEREREKoPD4aBDhw7FZ+zfZD2ndoVLngeb7eQGJiIichpT0lbkL3i9Xp5dNJM31z+PcR6wbjDmSeHGFrdw/TnnH3vM2pL8+Jr13GYohERVfsAiIiIiIifDWRdBchtwhoEzJNDRiIiInFaUtBU5hkVb13Hn/IfJtK0GJ9jc1RjU4B/c030wToej/CtM/xN+/9x63f7vlRusiIiIiMjJFlM70BGIiIiclpS0FSnBwZwsbvn8KVZlzsJm92CMg1aRlzDpwjupEX4cvWOXvQ7GC/V6QnyTygtYREREROQE8Xg8/PTTTwC0a9kUx/tXQ9fR0KhPYAMTERE5jSlpK3KUt1d8w1M/P4zXuR+bHaK8zXmi14N0r9fs+FacexCW/dd63fEfxx+oiIiIiMhJkpOTY72Y9yhsWQgH/oBbf4ag0MAGJiIicppS0lbksPTcbK775BHW587B5jTYPLH8reGtjO12ZfnHrS3Jz2+B6xDUbA6Nzz/+9YmIiIiInEwHt8Kqw/dnuOR5JWxFREROICVtRYAZaxfxyJIH8Dj3YLNBbWcvplz+MIlR1SpvIx3/ASHREF0LKiMJLCIiIiJysrjzYf0c63W7EdBQQyOIiIicSFUic/Tiiy9St25dQkND6dixIz/++OMxy0+fPp2mTZsSGhpKy5YtmTNnjt98YwwPPvggSUlJhIWF0adPHzZs2OBX5pJLLqFOnTqEhoaSlJTEtddey86dO0vc3saNG4mKiiI2Nva49lOqnuz8fK7+8CHGLfsHHuce8ERzQ+NH+fyaSZWbsAXrjrrnjITG/Sp3vSIiIiIiJ9qm+ZCXAdEp0O/RQEcjIiJy2gt40vb9999nzJgxjBs3jp9//pnWrVvTv39/9uzZU2L5RYsWMXToUK677jpWrFjBwIEDGThwIGvWrPGVefLJJ3n++eeZPHkyS5cuJSIigv79+5OXl+cr07t3bz744APWr1/PjBkz2LRpE1deeWWx7RUUFDB06FC6d+9e+TsvAbVi52Z6vnslq7NnYLMZEuyd+ezyWdza+dLK3ZDbBR535a5TRERERORk2bkSdq20Xl/yPIQcx415RUREpExsxhgTyAA6duxI+/bteeGFFwDwer2kpKRw6623cs899xQrP3jwYLKzs5k9e7ZvWqdOnWjTpg2TJ0/GGENycjJ33HEHY8eOBSAjI4OEhASmTp3KkCFDSozjk08+YeDAgeTn5xMUFOSbfvfdd7Nz507OO+88Ro8eTXp6epn3LTMzk5iYGDIyMoiOji7zcnLivfzjHF5aMx4cOeAJ49qGY7mrx6ATs7Elk2HpZOg7HppVckJYRERE/Kj9VXaqKykrz/wnWfjtNxDXhO7/fAaHwxHokERERE5ZZW2DBbSnrcvl4qeffqJPn6LxkOx2O3369GHx4sUlLrN48WK/8gD9+/f3ld+8eTNpaWl+ZWJiYujYsWOp6zxw4ADvvvsuXbp08UvYfvPNN0yfPp0XX3yxTPuTn59PZmam30OqFpfbzbUzHuHFX+8BRw7BnjpM6ffeiUvYuvPhh+fg4GbI2X9itiEiIiJV2nfffcfFF19McnIyNpuNWbNmHbP8ggULsNlsxR5paWl+5co7xJhIhXX7F6Fdrie0ca9ARyIiInLGCGjSdt++fXg8HhISEvymJyQkFGuUFkpLSztm+cLnsqzz7rvvJiIigho1arBt2zY+/vhj37z9+/czYsQIpk6dWuaeBxMmTCAmJsb3SElJKdNycnL8cWA3vd7+GyuzPsBmM6QGncc3V0/nnNoNT9xGV74Hh3ZCVBK0uebEbUdERESqrOzsbFq3bl3mjgCF1q9fz65du3yPmjVr+uaVd4gxkePhCAqiU78r6NTvMvWyFREROUkCPqZtIN15552sWLGCuXPn4nA4GDZsGIWjRVx//fVcffXV9OjRo8zru/fee8nIyPA9/vzzzxMVupTTlxtWMHDmVRyyr8V4g7i89lhmXz2RmNDwE7dRTwF8/4z1uuvt1o3IRERE5IxzwQUX8Oijj3LZZZeVa7maNWuSmJjoe9jtRU33Z555huuvv56RI0fSrFkzJk+eTHh4OG+88UZlhy8iIiIiARDQpG1cXBwOh4Pdu3f7Td+9ezeJiYklLpOYmHjM8oXPZVlnXFwcjRs3pm/fvkybNo05c+awZMkSwBoa4amnnsLpdOJ0OrnuuuvIyMjA6XSW2hgOCQkhOjra7yGB99/lXzL2+xswzoPY3fE82+0Nxp83/MRvePV0SN8GEfHQ9iRsT0RERE4rbdq0ISkpib59+/LDDz/4pldkiDGRClv4NMz8J/y5LNCRiIiInFECmrQNDg6mXbt2zJs3zzfN6/Uyb948OnfuXOIynTt39isP8NVXX/nK16tXj8TERL8ymZmZLF26tNR1Fm4XrHFpwRo7d+XKlb7Hww8/TFRUFCtXrix3LwkJnAe+nsqza+4Cex7h3sbMvmI6fRu1OfEb9nqsBi5A51sg+AT26BUREZHTSlJSEpMnT2bGjBnMmDGDlJQUevXqxc8//wxUbIgx0P0XpIJWz8Czaho/rVjFTz/9hMfjCXREIiIiZwRnoAMYM2YMw4cP55xzzqFDhw5MnDiR7OxsRo4cCcCwYcOoVasWEyZMAOD222+nZ8+ePP300wwYMIBp06axfPlyXn31VQBsNhujR4/m0UcfpVGjRtSrV48HHniA5ORkBg4cCMDSpUtZtmwZ3bp1o1q1amzatIkHHniABg0a+BK7Z511ll+cy5cvx26306JFi5NUM3I8vF4v1338b5Zn/g+bDWraOzFryPNEhYSdnAA2fwf7N0JoLLS/7uRsU0RERE4LTZo0oUmTJr73Xbp0YdOmTTz77LO8/fbbFV7vhAkTGD9+fGWEKGeKrD2wZy1g51BoMhw6FOiIREREzhgBT9oOHjyYvXv38uCDD5KWlkabNm344osvfD0Htm3b5jd+V5cuXXjvvfe4//77ue+++2jUqBGzZs3yS6beddddZGdnc8MNN5Cenk63bt344osvCA0NBSA8PJyPPvqIcePGkZ2dTVJSEueffz73338/ISEad/RUl1OQzxUfjGW7ewEAzcMH8s7lD+E8mTdNaNAbhn0CWbshJOrkbVdEREROSx06dOD7778HKjbEGFj3XxgzZozvfWZmpm6cK8e2+TvrOaGFrhwTERE5yWym8M5bUukyMzOJiYkhIyND49ueJIfyc7lw2g2ksxJjbPRP/CdPn//PQIclIiIiJ0lVb3/ZbDZmzpzpuwKsrPr27UtUVBQfffQRAB07dqRDhw5MmjQJsK4yqlOnDrfccgv33HNPmdZZ1etKqoCPb4EVb+PpeAsLw/oC0L17dxwnszOEiIjIaaasbbCA97QVqSyH8nO58H/Xk25bhfEGcUOTcdzW5dKTG8TuXyE0GmJqn9ztioiISJWVlZXFxo0bfe83b97MypUrqV69OnXq1OHee+9lx44dvPXWWwBMnDiRevXq0bx5c/Ly8nj99df55ptvmDt3rm8dfzXEmMhxMwb++NZ6Xa8HlD5csoiIiJwAStrKaeHohO2/Wk7gunP6n9wgPG746AY48AcMegsa9fnrZUREROS0t3z5cnr37u17XzhEwfDhw5k6dSq7du1i27Ztvvkul4s77riDHTt2EB4eTqtWrfj666/91vFXQ4yJHLeDmyFjG9iDoE4nSFsR6IhERETOKBoe4QTSJWcnR5VI2AIsmQxf3G3dfOzWnyAi7uTHICIicoZT+6vsVFdyTH/+CJ+OhrBYPMM+ZeHChYCGRxARETleGh5Bzgj+CVtn4BK2h9Jg/mPW6z7jlLAVERERkVNbSge4aRG48wEICgoKcEAiIiJnFiVt5ZSVnZ9/VML234FJ2ALMvR/yMyG5LbQdHpgYREREREQqmzMEB9C1a9dARyIiInJGsQc6AJGK8Hq9DPzg9qqRsN38HayeDtjgomfArsvFREREROQUln8ICnIDHYWIiMgZTUlbOSUNm/kIad4fMMbOzc0fC1zC1uuBz++2Xre/DpLPDkwcIiIiIiKVZfkU+HcqzHs40JGIiIicsTQ8gpxy7pv7X1ZlfQjAJcm38c8OFwYuGI8LGvaBnANw7v2Bi0NEREREpLJs/hY8+RBeAwCPx8Pq1asBaNmypW5EJiIichIoaSunlFeXfc4nO5/HZoNWkVfweL/rAhtQUBj0ewR6/x8EhQY2FhERERGR4+V2wdZF1uv6vXyT09PTAxKOiIjImUrDI8gp44vff+b5NQ9gs3lJtHfh7cseDHRIRZSwFREREZHTwfZlUJADEfFQs1mgoxERETljKWkrp4Q1adu46/vbsNnzCfc25qNBE7HbA3j65h6E/w2FP5cFLgYRERERkcr2xwLruV4PsNkCGoqIiMiZTElbqfLSc7MZNucfGEcGDnciH172KlEhYYEN6vuJsH4OfHobeL2BjUVEREREpLJs/tZ6PmJoBBERETn5lLSVKu+amfdR4PgTPJG83v9lUmJrBDagjB2wdLL1+rxxEMgevyIiIiIilSX3oDU8AihpKyIiEmC6EZlUaQ/Pf4dtBd9gjI3bWj7IObUbBjokWDAB3HlQpws07h/oaEREREREKokN+j0Ke36D2DqBDkZEROSMpqStVFkL/ljDB1uexWaHNlFXcEP7CwIdEuxZByvftV73Ha9xvkRERETk9BEWC51vLnFWQO8nISIicgZS0laqpIM5Wfxr/h3YnC4ivE1449L/C3RIlnkPg/FC04sgpUOgoxEREREROeEcDgc9evQIdBgiIiJnFP1cKlXS1TPvxe3cCZ4opl40kWBnFfh9YftyWP8Z2OzWWLYiIiIiIqeLfRthxTtwKC3QkYiIiAjqaStV0Lh5b7LdvQBjbIxp9RBN42sHOiRLdDJc/Bx4CiC+caCjERERERGpPGtmwILHrSvKhrwb6GhERETOeEraSpUy/4/VzNj2HDY7tIsezKhz+gU6pCLRydBuRKCjEBERERGpfBvmWs8l3GjX6/WyZs0aAFq0aKHxbUVERE4CJW2lynC53dy54F5sjgIivWfx2iV3BzokEREREZHTX/Y+2PGT9bph32KzjTEcOHDA91pEREROPP1EKlXG6M8nke/YCp5QXr/w6aoxjm2hg1vgx9dg99pARyIiIiIiUrk2fQMYSGwJ0UmBjkZERERQ0laqiCXb1vPdvrcBGFD7BponpAQ4oqNs/BrmjIUv/y/QkYiIiIiIVK7CoRFK6GUrIiIigaGkrQSc1+tl9Lz/w2YvIMLblMf7XBfokIrbvtx6rt0+sHGIiIiIiFQmr8fqoADQqArdT0JEROQMp6StBNw9X71Otn09xhvEc+c9XjVvbPDnj9ZzSofAxiEiIiIiUpn2/Aa56RAaow4KIiIiVUgVGjRUzkRr0rbx+c7XwA494q6lY51GgQ6puOz9cGCT9br2OYGNRURERESkMiW2gDs3wb714NC/hyIiIlWF/ipLwHi9Xv755f+BPY8QT10mXnBroEMq2Y7DQyPENYawaoGNRURERESkskXUgIgugY5CREREjqCkrQTMhO+mkc5KjHEwofsjBDur6OlYODSCLhcTERERkTOQw+GgV69egQ5DRETkjFIFBw+VM8H2jANM+2MSAG2jr6RvozaBDehYti+znpW0FREREZHTyappMOVCWPV+oCMRERGRo1TRro1yuhvz5dPgyMLhTuTlAXcGOpxjG/Iu7PgZ4psEOhIRERERkcqzfg5s/QHqdg90JCIiInIUJW3lpFu+fSO/Zn+GzQ5/b34bESEhgQ7p2EKioH7PQEchIiIiIlJ5PAWwab71ulG/Yxb1er389ttvAJx11lnY7bpgU0RE5ETTX1s56e6Z/yQ2u4cIb1Nu6jAg0OGIiIiIiJx5/lwK+ZkQHgfJZx+zqDGGvXv3snfvXowxJylAERGRM5t62spJNevXJez2Lgbg/zrdWfV/pV/4NOQcgDbXQEKzQEcjIiIiIlI51s2xnhueB1W9TS4iInIG0l9nOWm8Xi8Tlv4HgCR7Vy4+q0OAIyqDlf+DxS9Axp+BjkREREREpHJ4vbB2pvW62aWBjUVERERKpKStnDTPL/mEHPvvGK+TJ867O9Dh/LWcA7B/g/W6dvvAxiIiIiIiUlm2/wiHdkJINDQ4L9DRiIiISAk0PIKcFHkFLqb+9gI4oXnkAM5OrhfokP7ajp+s5xoNIbx6YGMREREREaksjiBociFExEFQaKCjERERkRIoaSsnxf/N+y8e527whPNs/zGBDqds/vzRelYvWxERERE5ndRqB0P/B7qpmIiISJVVJYZHePHFF6lbty6hoaF07NiRH3/88Zjlp0+fTtOmTQkNDaVly5bMmTPHb74xhgcffJCkpCTCwsLo06cPGzZs8CtzySWXUKdOHUJDQ0lKSuLaa69l586dvvnr16+nd+/eJCQkEBoaSv369bn//vspKCiovB0/Q+zOymDuzrcA6J3wN5KjT5Feq9uVtBURERGR05jNFugIREREpBQBT9q+//77jBkzhnHjxvHzzz/TunVr+vfvz549e0osv2jRIoYOHcp1113HihUrGDhwIAMHDmTNmjW+Mk8++STPP/88kydPZunSpURERNC/f3/y8vJ8ZXr37s0HH3zA+vXrmTFjBps2beLKK6/0zQ8KCmLYsGHMnTuX9evXM3HiRF577TXGjRt34irjNDX2y+fAkYXdHc+EPjcEOpyy8Xpg++HhEVJOgRumiYiIiIiUxcav4cDmci1it9vp3r073bt3x24P+L+QIiIiZwSbMYG9JqZjx460b9+eF154AQCv10tKSgq33nor99xzT7HygwcPJjs7m9mzZ/umderUiTZt2jB58mSMMSQnJ3PHHXcwduxYADIyMkhISGDq1KkMGTKkxDg++eQTBg4cSH5+PkFBQSWWGTNmDMuWLWPhwoVl2rfMzExiYmLIyMggOjq6TMucbtIOHaTv9H7gyOPquvdzb8/BgQ6pbDK2w6u9wZUN9/4JdkegIxIREZEyUPur7FRXZyCPG55uDDn7YdSXUKdToCMSERE545S1DRbQn0ldLhc//fQTffr08U2z2+306dOHxYsXl7jM4sWL/coD9O/f31d+8+bNpKWl+ZWJiYmhY8eOpa7zwIEDvPvuu3Tp0qXUhO3GjRv54osv6NmzZ7n28Uz30II3wJGHw53And2u/OsFqoqY2jD2d7h9pRK2IiIicly+++47Lr74YpKTk7HZbMyaNeuY5T/66CP69u1LfHw80dHRdO7cmS+//NKvzEMPPYTNZvN7NG3a9ATuhZwWtnxnJWzDa0CtcwIdjYiIiBxDQJO2+/btw+PxkJCQ4Dc9ISGBtLS0EpdJS0s7ZvnC57Ks8+677yYiIoIaNWqwbds2Pv7442Lb69KlC6GhoTRq1Iju3bvz8MMPl7o/+fn5ZGZm+j3OZIfyc1m0dyYAF6YMxek4xZKfNhtE1gx0FCIiInKKy87OpnXr1rz44otlKv/dd9/Rt29f5syZw08//UTv3r25+OKLWbFihV+55s2bs2vXLt/j+++/PxHhy+lkzUfW81mXgKPs96T2er2sW7eOdevW4fV6T1BwIiIicqQzekCiO++8kxUrVjB37lwcDgfDhg3j6NEi3n//fX7++Wfee+89PvvsM5566qlS1zdhwgRiYmJ8j5SUlBO9C1XahO/ewzgysHliuLfn1YEOR0RERCQgLrjgAh599FEuu+yyMpWfOHEid911F+3bt6dRo0Y8/vjjNGrUiE8//dSvnNPpJDEx0feIi4s7EeHL6cLtgt8On0PNy3YuFjLGkJaWRlpaWrH/l0REROTEKPvPqydAXFwcDoeD3bt3+03fvXs3iYmJJS6TmJh4zPKFz7t37yYpKcmvTJs2bYptPy4ujsaNG3PWWWeRkpLCkiVL6Ny5s69MYeK1WbNmeDwebrjhBu644w4cJfQavffeexkzZozvfWZm5hmbuHV7PMzZ9j9wQue4gUSFhAU6pLLLPQgvdoLa58BVU8FR8pAZIiIiIieD1+vl0KFDVK9e3W/6hg0bSE5OJjQ0lM6dOzNhwgTq1KlT6nry8/PJz8/3vT/Trwo74/yxAPLSIaIm1O0W6GhERETkLwS0p21wcDDt2rVj3rx5vmler5d58+b5JU6P1LlzZ7/yAF999ZWvfL169UhMTPQrk5mZydKlS0tdZ+F2Ab+GbEllCgoKSr0kKCQkhOjoaL/Hmer5xR/jce4GTyjje18X6HDK588fISsN9q5XwlZEREQC7qmnniIrK4tBgwb5pnXs2JGpU6fyxRdf8PLLL7N582a6d+/OoUOHSl2Prgo7w609PDRCs0t1zwYREZFTQEB72gKMGTOG4cOHc84559ChQwcmTpxIdnY2I0eOBGDYsGHUqlWLCRMmAHD77bfTs2dPnn76aQYMGMC0adNYvnw5r776KgA2m43Ro0fz6KOP0qhRI+rVq8cDDzxAcnIyAwcOBGDp0qUsW7aMbt26Ua1aNTZt2sQDDzxAgwYNfIndd999l6CgIFq2bElISAjLly/n3nvvZfDgwaXerEyK/O/3t8ABzaPOJzGqWqDDKZ9th29YV6djYOMQERGRM957773H+PHj+fjjj6lZs2is/QsuuMD3ulWrVnTs2JHU1FQ++OADrruu5B/MdVXYGczrtXraArS4PKChiIiISNkEPGk7ePBg9u7dy4MPPkhaWhpt2rThiy++8N1IbNu2bdjtRR2Cu3Tpwnvvvcf999/PfffdR6NGjZg1axYtWrTwlbnrrrvIzs7mhhtuID09nW7duvHFF18QGhoKQHh4OB999BHjxo0jOzubpKQkzj//fO6//35CQkIAa4ywJ554gt9//x1jDKmpqdxyyy3861//Oom1c2p6d9UC8hz/z959h0dR9W0c/+5ueiAJoaTQEkB6CaCUCNK7NEWlKIg8iAUVkSKIiMqjooKo+IgFRAUFLC8oTQEJHZTekRIIEEJPo4Rkd98/VlYjBJKwyaTcn+vaa3Znzszcu0Yy+e2Zcw5ht1l4pdkgo+NkXcwGx7Jcxj2zRURERHLa7Nmz+c9//sN3331H69atb9o2ICCAypUrc/DgwQzbeHp6Oq91pZAxm+GZzXBwOZRtZHQaERERyQSTXSPJ55jExET8/f1JSEgoVEMlNJ3xMPGm7ZRzb8HC3h8YHSdr0lLgzbJgTYHBm6FEJaMTiYiISBbk9esvk8nE//3f/znvAMvIt99+y2OPPcbs2bPp2rXrLY+bnJxMuXLlGDduHM8++2ymsuT1z0ryDqvVyurVqwFo2rTpDef3EBERkczJ7DWY4T1tpWBZdnA78abt2O0mXmz8hNFxsi52m6Ng61MCilc0Oo2IiIgUAMnJyel6wEZHR7Nt2zYCAwMpV64co0aN4sSJE3z11VeAY0iEfv368f7779OwYUPi4uIA8Pb2xt/fH4Bhw4bRuXNnypcvT2xsLK+88goWi4VevXrl/huUvM1uB5PJ6BQiIiKSRYZORCYFzzsbPgGghKk+TcOrG5wmG45dGxqhkS5uRURExCU2bdpE3bp1qVu3LuCY06Fu3bqMHTsWgJMnTxITE+Ns/+mnn5KWlsbTTz9NSEiI8/Hcc8852xw/fpxevXpRpUoVHnzwQYoXL86GDRsoWbJk7r45yfv2LYCPGsHvn2X7EGazmcjISCIjI9MNXSciIiI5Rz1txWV2xB3hRNpaTCZ45s6BRsfJHt+SUKYBhN9jdBIREREpIJo3b87NRiSbMWNGutdRUVG3PObs2bNvM5UUGrv/D87shQtHsn0Ik8mEh4eH6zKJiIjILaloKy7z7rqZmEw2fGx3cH+NSKPjZE9Eb8dDRERERCS/S70M+5c4ntfobmwWERERyRIVbcUl0qxWtl1YCm7QKew+o+OIiIiIiMjBZZB6EfzLQun62T6MzWZzjstcqVIlDZEgIiKSC/TbVlziy63LsLudB6sXzzbKp9/iJ56ElCSjU4iIiIiIuMbueY5l9a63NV+D3W4nNjaW2NjYmw71ISIiIq6joq24xLd7fwAgzLsJAd6+BqfJpuWvwVvlYMNUo5OIiIiIiNye1Mvw519DI1TvZmgUERERyToVbeW2HYs/R1zaHwAMqPOQwWluw7ENYLdB8YpGJxERERERuT0Hl8PVZPArA2XuNDqNiIiIZJHGtJXb9t76uZjMabilhdKlagOj42RP8mk4fxgwQZm7jE4jIiIiInJ7/EKh9kNQLOy2hkYQERERY6hoK7dt5ckFYIG7gzrm30kJYjY4lqWqg3eAoVFERERERG5b6Xpw36dGpxAREZFsyqcVNskrFu3fzFVLDHa7haGRPY2Ok33XirblGhmbQ0RERERERERECj0VbeW2fLZ1DgAlzXWpEBhkcJrbcOxa0baxsTlERERERG7X7v+DkzvAbjc6iYiIiGSThkeQbEtKuczBSyvBAj0q32d0nOy7ehFObnc8L9fQ2CwiIiIiIrcjLQV+ehZSEuE/y10yCZnZbKZRo0bO5yIiIpLzVLSVbPto409guYTJ6s/AOzsYHSf77HboMAFO7QH/skanERERERHJvkO/OQq2RUMhtJ5LDmkymfDy8nLJsURERCRzVLSVbPv58HwwQU2/lni45eMfJc8icNd/jE4hIiIiInL79sx3LKt3AfWKFRERybfycaVNjLQj7ggJ7MIEDL6rt9FxREREREQkLQX2LXI8r97NZYe12WxER0cDEB4eriESREREcoF+20q2TN7wLSaTHR9bZSLLVzU6TvbZrLB5Bpzeq4kaRERERCR/OxwFKQlQNATKum6uBrvdzrFjxzh27Bh2XTOLiIjkChVtJctsNhubzy0FoG3ZzganuU2n98DPz8HnbRwFXBERERGR/GrfAseymoZGEBERye/0m1yybPmhHdjczmC3ufFso/uMjnN7YjY4lmXuBItGCxERERGRfOzkdscy/B5jc4iIiMhtU5VKsuzb3YsBCDBVp2QRP4PT3KZjvzuW5RoZm0NERERE5Hb1XwJn9kFgBaOTiIjkC1arldTUVKNjSAHk7u6OxWK5rWOoaCtZtuP8WrDA3SHNjY5y+2K3Opal6xubQ0RERETkdnn4QOl6RqcQEcnz7HY7cXFxxMfHGx1FCrCAgACCg4MxmUzZ2l9FW8mSXXExpFiOYrebeKxeJ6Pj3J4rCXDugON5qC5uRURERERERAqDawXbUqVK4ePjk+2imsiN2O12Ll26xOnTpwEICQnJ1nFUtJUsmb51IQDetnCqlAw1OM1tit3mWAaUA9/ihkYREREREbktm6ZD3E6o2QPC7jY6jYhInmW1Wp0F2+LFVQuQnOHt7Q3A6dOnKVWqVLaGSlDRVrJkQ9xKMEPdEk2MjnL7rg2NEFrX2BwiIiIiIrdr3yI4uBSCarq8aGs2m7nrrrucz0VE8rNrY9j6+PgYnEQKums/Y6mpqSraSs6KTTxPomkfJuDhWh2NjnP76vV1XNR6FjU6iYiIiOQx8fHx/N///R+rV6/m6NGjXLp0iZIlS1K3bl3atWtHZGSk0RFF0ju127EMqunyQ5tMJnx9fV1+XBERI2lIBMlpt/szpq9JJdOmb1mCyWTFkhbEPeE1jI5z+3wC4Y7WUK6h0UlEREQkj4iNjeU///kPISEhjB8/nsuXLxMREUGrVq0oU6YMK1asoE2bNlSvXp05c+YYHVfE4dJ5SIp1PA+qbmwWERHJFTVq1GDBggWZajtu3Di6det2w21RUVEEBAS4Lpi4jHraSqatOLYCgCp+jQxOIiIiIpIz6tatS79+/di8eTPVq9+4+HX58mXmzZvH5MmTOXbsGMOGDcvllCL/cq2XbbGwHLmLzGazERMTA0C5cuU0RIKISC7Zv38/w4YNY/369Vy9epXQ0FD69+/PyJEj2b17t9HxJIepaCuZcjElhdNp28AM91dtb3Sc23d8M+xbABWaQYXmRqcRERGRPGLPnj23nJTE29ubXr160atXL86dO5dLyURu4tQuxzIHhkYAxyzYR44cAaBs2bI5cg4REblep06d6NmzJ3PmzMHT05N9+/axZ88eo2PdVFpaGm5uKje6gr4ilUyZuX05mK+AtSj3VW9sdJzbd3AZrJkE274xOomIiIjkIf8s2KakpHDx4sVMtxcxjLNoWwCGMBMREQDOnj3LoUOHGDRoED4+PlgsFmrUqMEDDzwAQFhYGPPmzQNgxowZRERE8Prrr1OqVCmCgoKYPHlyhsd+6aWXqFOnDidPnnSu+/zzzylbtizFixdnxIgR6drPnDmTatWqERAQQJMmTdiyZYtzW/PmzRkxYgRt27bF19eXxYsXExYWxttvv02jRo0oWrQozZo149ixY677cAoJFW0lUxYcWgpAOa+7cMvGjHd5Tuxf/8CE1jU2h4iIiOQ5Z86coUOHDhQpUgQ/Pz8aNWrEwYMHjY4lkrHkM46lirYiIllmt9u5dDUt1x52uz1TuYoXL06VKlXo378/c+fO5ejRozdtv3v3bnx8fDhx4gRz5sxh+PDhHDp0KF2btLQ0BgwYwNq1a1m1ahUhISEAJCUlsWfPHg4cOMCaNWv46KOPiIqKAmDVqlU8+eSTfPLJJ5w5c4YePXrQvn17EhISnMedMWMG48ePJzk5mdatWwOOQu+3337LmTNn8PX15eWXX87sfxL5i/oryy3ZbDaOXv4DLNChQmuj49w+ux1itzqeh9YzNouIiIjkOSNHjmTbtm289tpreHl58cknnzBw4EBWrFhhdDSRG+szFy5fAIun0UlERPKdy6lWqo/9JdfOt+e1dvh43LocZzKZiIqK4p133uHVV19l3759VKlShffff582bdpc175EiRK88MILgKP3a1hYGNu2baNixYoAXLp0ie7du+Pp6ckvv/yCp+ffvzPsdjvjx4/Hy8uLatWqERkZyebNm2nevDlff/01Dz/8MPfccw8AQ4YM4eOPP2bhwoX07t0bgN69e9OgQQPAMYwUwFNPPUV4eDgAffr04a233sruR1ZoqWgrt/TTvt+xWxKw2zx5JKKV0XFuX2IsJJ8CkwWCaxmdRkRERPKYpUuXMmPGDNq1awfAvffeS7Vq1UhJSUn3B45InuJdzOgEIiLiYsHBwUycOJGJEydy/vx5/vvf/9K9e3fn5JD/FBQUlO61r68vSUlJztfbtm0jMTGRTZs2XXc94+fnh4+Pzw33PX78OM2bN0/XPjw8nOPHjztflytX7obZM8oimaOirdzS3D1LAChpqYW/l88tWucD14ZGKFUNPArA+xERERGXio2NpU6dOs7Xd9xxB56enpw8eZKwsDDjgomIiIjLebtb2PNau1w9X3YEBgYybtw4Jk2aRHR0dJb3j4yMpHPnzrRp04Zly5ZRo0bmhtQpU6aMczLKa44cOUKZMmWcr81mjb6aE1S0lVvam7Ae3KBZmRZGR3EN59AIGs9WREREbszyrzH8LRZLpsegE8lVaz+AQ8uh/qNQo7vRaURE8h2TyZSp4Qpy24ULF5g4cSIPP/wwd9xxBykpKUyaNInAwECqVq2arWMOGDAAi8VCq1at+PXXX6ldu/Yt93n44Yfp3LkzDz/8MA0bNuTjjz/m3LlzdOzYMVsZJPPy3k+l5CkbYvaT5haL3W5mQL1ORsdxjdP7HEsVbUVEROQG7HY7lStXxmQyOdclJydTt27ddD1Jzp8/b0Q8kfSOroPDUVAl567VzWYz9erVcz4XEZGc5+HhwYkTJ+jYsSOnT5/Gy8uLevXqsXjxYnx9fbN93EcffRSLxUKbNm345Zdbj+XbrFkzPvzwQwYMGMDJkyepWbMmixcvJiAgINsZJHNMdnUZyDGJiYn4+/uTkJCAn5+f0XGy5ekFk1l1bhq+tqps6P+d0XFcw26H84cd4375BBqdRkRERFzIFddfX375Zaba9evXL1vHzysKwrWqAO/VgoQYeHQRhN1tdBoRkTzvypUrREdHEx4ejpeXl9FxpADL6Gcts9dg6mkrN7Xj3GYA6hRvaHASFzKZoHhFo1OIiIhIHpXfi7FSiFxJcBRsAYKqG5tFREREXCpP3Nvy0UcfERYWhpeXFw0bNuT333+/afvvvvuOqlWr4uXlRa1atVi0aFG67Xa7nbFjxxISEoK3tzetW7fmwIED6dp06dKFcuXK4eXlRUhICI888gixsbHO7VFRUXTt2pWQkBB8fX2JiIhg1qxZrnvT+cDVtDTirfsBaFdR39qLiIhIwaeb0CRfObXbsfQv67iLLIfYbDZiYmKIiYnBZrPl2HlERETkb4YXbefMmcPQoUN55ZVX2LJlC3Xq1KFdu3acPn36hu3XrVtHr169GDBgAFu3bqVbt25069aNXbt2Odu8/fbbfPDBB0ydOpWNGzfi6+tLu3btuHLlirNNixYtmDt3Lvv37+eHH37g0KFD9OjRI915ateuzQ8//MCOHTvo378/ffv2ZcGCBTn3YeQxvxzYCpbL2G2edKxc3+g4rrHmPZjbDw6tMDqJiIiI5EE1atRg9uzZXL169abtDhw4wJNPPslbb711y2OuWrWKzp07ExoaislkYt68ebfcJyoqinr16uHp6UmlSpWYMWPGdW2y2vFBCqBrRdugzM0Anl12u53Dhw9z+PBhfbEhIiKSS7I8PMLevXuZPXs2q1ev5ujRo1y6dImSJUtSt25d2rVrx/3334+np2emjzdp0iQGDhxI//79AZg6dSoLFy5k+vTpvPjii9e1f//992nfvj3Dhw8H4PXXX2fp0qVMmTKFqVOnYrfbmTx5MmPGjKFr164AfPXVVwQFBTFv3jx69uwJwPPPP+88Zvny5XnxxRfp1q0bqampuLu7M3r06HTnfe655/j111/58ccfuffee7P2oeVTSw6tBcDfVBkvdw+D07jIgaVwdC3c0dboJCIiIpIHffjhh4wcOZKnnnqKNm3acOeddxIaGoqXlxcXLlxgz549rFmzht27dzN48GCefPLJWx7z4sWL1KlTh8cee4z77rvvlu2jo6Pp1KkTTzzxBLNmzWL58uX85z//ISQkhHbt2gF/d3yYOnUqDRs2ZPLkybRr1479+/dTqlSp2/4cJJ849VfHlRwu2oqIiEjuy3TRdsuWLYwYMYI1a9Zw991307BhQ7p37463tzfnz59n165dvPTSSzzzzDOMGDGCIUOG3LJ4e/XqVTZv3syoUaOc68xmM61bt2b9+vU33Gf9+vUMHTo03bp27do5eyxER0cTFxdH69atndv9/f1p2LAh69evdxZt/+n8+fPMmjWLyMhI3N3dM8ybkJBAtWrVbvqeCpKd57YAUCOwrsFJXMRmhdhtjuehBeQ9iYiIiEu1atWKTZs2sWbNGubMmcOsWbM4evQoly9fpkSJEtStW5e+ffvSp08fihXL3O3oHTp0oEOHDpnOMHXqVMLDw5k4cSIA1apVY82aNbz33nvOom1WOz5IAWXxcAyLEFTT6CQiIiLiYpku2t5///0MHz6c77//noCAgAzbrV+/nvfff5+JEyde11v1386ePYvVaiUoKCjd+qCgIPbt23fDfeLi4m7YPi4uzrn92rqM2lwzcuRIpkyZwqVLl2jUqNFNhz6YO3cuf/zxB5988kmGbVJSUkhJSXG+TkxMzLBtXpdmtXLBug8s0DY80ug4rnH2AKReBHdfKFnF6DQiIiKShzVp0oQmTZoYcu7169en64AAjk4KQ4YMAbLX8QEK1rWq/KXjO9DhbbBrnFkREZGCJtNj2v7555889dRTNy3YAjRu3JjZs2c7hy/Iy4YPH87WrVv59ddfsVgs9O3b94ZjNK1YsYL+/fvz2WefUaNGxrcevfnmm/j7+zsfZcuWzcn4OerXg9v+Hs+26p1Gx3GNWEfPYULqgNlibBYRERGRDGTUSSExMZHLly/ftOPDvzsp/FNBulaVfzCZdG0rIiJSAGW6p+21YQPsdjsHDx7k6tWrVKlSBTe3Gx/iZsMMXFOiRAksFgunTp1Kt/7UqVMEBwffcJ/g4OCbtr+2PHXqFCEhIenaREREXHf+EiVKULlyZapVq0bZsmXZsGEDjRs3drZZuXIlnTt35r333qNv3743fT+jRo1KN3RDYmJivr0YXnzQMZ6tn+kOfNwzP0Zxnnbir6Jt6XrG5hAREZF8Y/ny5SxfvpzTp09js6XvzTh9+nSDUmVPQbpWFcBudxRsRUREpEDKdE9bcIwXW7t2bapWrUrt2rWpWLEimzZtyvbJPTw8qF+/PsuXL3eus9lsLF++PF3h9J8aN26crj3A0qVLne3Dw8MJDg5O1yYxMZGNGzdmeMxr5wXS3TIWFRVFp06dmDBhAo8//vgt34+npyd+fn7pHvnVjrObAaheLMLYIK4Uu9Wx1Hi2IiIikgmvvvoqbdu2Zfny5Zw9e5YLFy6ke+SUjDop+Pn54e3tna2OD1CwrlUFWDkBJteCDVONTiIiIiI5IEtF2+HDh5OWlsbMmTP5/vvvKVOmDIMGDbqtAEOHDuWzzz7jyy+/ZO/evTz55JNcvHjROalC3759043X9dxzz7FkyRImTpzIvn37GDduHJs2bWLw4MEAmEwmhgwZwvjx4/npp5/YuXMnffv2JTQ0lG7dugGwceNGpkyZwrZt2zh69Ci//fYbvXr1omLFis7C7ooVK+jUqRPPPvss999/P3FxccTFxXH+/Pnber/5QZrVynmrY0zhNgVlPFu7HcxuYHZX0VZEREQyZerUqcyYMYONGzcyb948/u///i/dI6fcqpNCdjo+SAEUtxPiY3JlPFuz2UxERAQRERGYzVn6E1JERG7TY489hslkYu/evc51UVFRmEwmihQpQtGiRSlfvjxjxoxxdkhs3rw5np6eFC1aFH9/f2rWrMkLL7zAmTNnrjv+qlWrMJlMjBw58rptVquVSZMmUbt2bXx9fSlVqhSNGjViypQppKWlATBjxgwsFgtFihTBz8+P0qVLc//997Nq1aoc+kQKjyz9xl2zZg2fffYZvXr1onv37nz//fds27aNixcvZjvAQw89xLvvvsvYsWOJiIhg27ZtLFmyxDlGV0xMDCdPnnS2j4yM5JtvvuHTTz+lTp06fP/998ybN4+aNf+eMXXEiBE888wzPP7449x1110kJyezZMkSvLy8APDx8eHHH3+kVatWVKlShQEDBlC7dm1WrlyJp6djKIAvv/ySS5cu8eabbxISEuJ83Hfffdl+r/nF8sM7wHIJu82dztUaGB3HNUwmGPALjDoOgRWMTiMiIiL5wNWrV4mMvP0vsJOTk9m2bRvbtm0DHHevbdu2jZiYGMAxbME/h+F64oknOHz4MCNGjGDfvn3873//Y+7cuTz//PPONrfq+CCFwKndjmVwzZu3cwGTyURAQAABAQGYNCSDiEiuSUpKYu7cuQQGBjJt2rR02/z9/UlOTiYpKYkFCxbw2WefpWszYcIEkpKSiI+PZ+7cuZw4cYL69etfd6fOtGnTCAwM5KuvvnIWYq/p3bs306dP54MPPuDs2bPExcUxZcoUfv/9dxISEpztatWqRXJyMomJiezcuZOWLVvSoUMHZs2alQOfSiFizwKTyWSPi4tLt87X19d++PDhrBym0EhISLAD9oSEBKOjZMmzCz+015xR0954+oNGRxERERHJEldef40YMcL+2muv3fZxVqxYYQeue/Tr189ut9vt/fr1szdr1uy6fSIiIuweHh72ChUq2L/44ovrjvvhhx/ay5UrZ/fw8LA3aNDAvmHDhizlyq/XqmK3268k2u2v+DkeyWeNTiMikq9cvnzZvmfPHvvly5eNjnJLn332mb1UqVLO5dWrV+12u+M6wd/fP13b+++/3z548GC73W63N2vWzP7ee++l256ammqvUaOGffjw4c51CQkJdh8fH/s333xj9/X1tc+bN8+5bcWKFXZPT89b1vy++OILe506da5b/+abb9qDg4PtVqs1C++4YMnoZy2z12CZnogMHN+wJicn4+3t7VxnNptJSkoiMTHRuU7jY+Vv2884xrOtWpDGs710HnwCjU4hIiIi+ciVK1f49NNPWbZsGbVr175uot1JkyZl6jjNmzfHbrdnuH3GjBk33Gfr1q03Pe7gwYOdQ4RJIbNnvmNZNAR8i+f46Ww2m/Pux5CQEA2RICIF09Wb3EVusoC7VybbmsHd+8ZtPXyzFGnatGn06dOHnj17MmTIEH7++ecb3gG+fft2Vq1axYQJEzI8lpubG926dWPp0qXOdd9++y1FihThgQceYMmSJUybNo2uXbsC8Msvv9CgQQPCw8OzlPmaHj16MGrUKPbv30+1atWydYzCLktFW7vdTuXKla9bV7duXedzk8mE1Wp1XULJVTabjXPWvWApQOPZxh+D92tD2YbQ72ewuN96HxERESn0duzYQUREBAC7du1Kt023iIthYrfCwhccz+v1y5VT2u12Dhw4AHDTye5ERPK1N0Iz3nZHW+jz3d+v36kEqZdu3LZ8E+i/8O/Xk2vBpXOO5+MSbrzPDezZs4cNGzYwdepUihQpQvfu3Zk2bZqzaJuQkOActqZUqVI888wzPProozc9ZunSpdPN1TRt2jR69+6Nm5sbffv2pV27dpw8eZKQkBDOnj1LaGj6z6RKlSqcOnWKlJQU5s6dS+fOnW96LqBQzA2VU7JUtF2xYkVO5ZA84rfDO8FyEbvNnS5VGxkdxzUO/OKYoMFuV8FWREREMk3XvpInrZkMaVfgjnbQbITRaUREJIdMmzaNOnXqUKdOHQD69etH+/btOXHiBOAY0zY+Pj5Lxzxx4gSBgY67kHfu3Mkff/zBp59+CkCLFi0IDQ3lyy+/5MUXX6REiRLs378/3f7XXoeFhd2yw+a1nNfOJ1mXpaJts2bNciqH5BGLDqwBoCiV8P1rUrZ8b/8Sx7JyO2NziIiISL51/PhxAMqUKWNwEin07vsUStwBkc+A2WJ0GhGRgmN0bMbbTP/693b4wZu0/dcQMkN2ZjlKamoqX3/9NcnJyc47HOx2O1arlRkzZnD33Xdn+ZhpaWnMnz+fjh07AjgnLWvfvr2zTXx8PNOnT+fFF1+kTZs2TJ48mSNHjhAWFpbl833//fcEBwdTpUqVLO8rDpkejOjixZuM1+GC9pI3bDuzBYAqAXUMTuIiVy9C9CrH8yodjM0iIiIi+YrNZuO1117D39+f8uXLU758eQICAnj99dex2WxGx5PCys0TWo4BL3+jk4iIFCwevhk//jme7S3bemfcNpN++uknEhMT2bJlC9u2bWPbtm1s376dl19+menTp990rPwb2bdvH/369SMhIYGhQ4dy9epVZs6cyVtvveU8/rZt29i4cSOHDx9m1apVtGzZkk6dOtGlSxdWrlzJ5cuXsdlsbN26laSkpAzPdeHCBT755BPGjx/Pu+++q3HQb0OmP7lKlSrx1ltvOQegvxG73c7SpUvp0KEDH3zwgUsCSu6x2WycSdsDQOuCMp7t4SiwpkBAOShZ1eg0IiIiko+89NJLTJkyhbfeeoutW7eydetW3njjDT788ENefvllo+NJYfLH57B0LNg0d4iISGEwbdo0evXqRdWqVQkODnY+nn32WWJjYzNVtB05ciRFixbF39+f++67j+DgYDZt2kRQUBDz5s0jNTWVp556Kt3x69SpQ7du3fj8888BmD17No888ghPP/00xYsXJyQkhCeeeILx48en66G7c+dOihQpgp+fHzVq1OCXX35h4cKF9OnTJ8c+o8LAZM9keX7//v2MHj2ahQsXUqdOHe68805CQ0Px8vLiwoUL7Nmzh/Xr1+Pm5saoUaMYNGgQFkvhvl0nMTERf39/EhIS8PPzMzrOLUUd3sUzq3tht7mxvvc6inp633qnvO6nZ2DLV9BgEHR82+g0IiIiksNcef0VGhrK1KlT6dKlS7r18+fP56mnnnKO1ZZf5bdr1ULryBr4qivY0qDHdKh5f65HsFqtrF69GoCmTZsW+r/zRCR/u3LlCtHR0YSHh+Pl5XXrHUSyKaOftcxeg2V6TNsqVarwww8/EBMTw9y5c1mzZg3r1q3j8uXLlChRgrp16/LZZ5/RoUMH/RLPp37+03EhVoSKBaNga7PBn784nms8WxEREcmi8+fPU7Xq9XfqVK1aVTMhS+5Z+IKjYFuzB9S4z+g0IiIikkuyNBEZQLly5Rg2bBjDhg3LiTxioGvj2Vb2jzA2iKvYrdBqLBxcDmFNjE4jIiIi+UydOnWYMmXKdcN+TZkyxTmTs0iOOncIzuwDsxt0mggmkyExzGYztWrVcj4XERGRnJfloi3Aa6+9xrBhw/Dx8Um3/vLly7zzzjuMHTvWJeEk99hsNk6n7gYLtAprbHQc17C4Q92HHQ8RERGRLHr77bfp1KkTy5Yto3Fjx/XR+vXrOXbsGIsWLTI4nRQKB5c5luUag3eAYTFMJhPFixc37PwiIiKFUba+Jn311VdJTk6+bv2lS5d49dVXbzuU5L7NsYfBkoTdbqFb9QJStBURERG5Dc2aNePPP/+ke/fuxMfHEx8fz3333cf+/ftp2rSp0fGkMLg21NcdbY3NISIiIrkuWz1t7XY7phvcmrN9+3YCAwNvO5TkvuWHNwPgaSuNv5fPLVrnA0mnYOd3ULk9lKhkdBoRERHJp0JDQ/nvf/9rdAwpjK5edExCBoYXbW02G6dPnwagVKlSGiJBREQkF2SpaFusWDFMJhMmk4nKlSunK9xarVaSk5N54oknXB5Sct62U7sACPaqaHASF9m/CH59CfbMg/8sMzqNiIiI5BM7duygZs2amM1mduzYcdO2tWvXzqVUUiilXoE7H4PTe6BkFUOj2O129u3bB0DJkiUNzSIiIlJYZKloO3nyZOx2O4899hivvvoq/v7+zm0eHh6EhYU5x/uS/CUm+U8wQ7XA6kZHcY1rt5JVbmdsDhEREclXIiIiiIuLo1SpUkRERGAymbDb7de1M5lMWK1WAxJKoeFbHDq8ZXQKERERMUiWirb9+vUDIDw8nLvvvhs3t2yNriB5jM1mI8l+BIAm5SIMzeISqZfhcJTjeeUOhkYRERGR/CU6OtrZkzA6OtrgNCIiIiJSWGVrMKKiRYuyd+9e5+v58+fTrVs3Ro8ezdWrV10WTnLH1pPRYLmE3W6mVcU6Rse5fdGrIO0y+JWBoBpGpxEREZF8pHz58s4hwI4ePUrp0qUpX758ukfp0qU5evSowUmlQEs4AYdXQpr+thIRESmsslW0HTRoEH/++ScAhw8f5qGHHsLHx4fvvvuOESNGuDSg5LzfDm8BwMMaSlFPb4PTuMCfSxzLKu3hBhPmiYiIiGRGixYtOH/+/HXrExISaNGihQGJpNDY+R181QW+7290EhERETFItoq2f/75JxEREQB89913NGvWjG+++YYZM2bwww8/uDKf5IKtp3YCEORVyeAkLnL8D8eygv6YEhERkeyz2+3pJt695ty5c/j6+hqQSAqNA0sdy/BmxuYQERFD7d+/n86dO1OiRAn8/PyoWrUqEyZMAKB58+ZYLJZ0E6fGx8djMpk4cuQIAGlpaYwePZqwsDCKFClCSEgI9957L0lJSUa8HcmibA1Ka7fbsdlsACxbtox7770XgLJly3L27FnXpZNccST5AJgK0CRkF2Icy+IFpAgtIiIiueq+++4DHJONPfroo3h6ejq3Wa1WduzYQWRkpFHxpKC7HA8x6x3P72hjaBQRETFWp06d6NmzJ3PmzMHT05N9+/axZ88e5/ZixYoxatQoFi5ceMP933rrLX799VdWrFhBeHg4p0+fZsGCBbkVX25Ttoq2d955J+PHj6d169asXLmSjz/+GHBM1hAUFOTSgJKzbDYbSbZosMDd5QrAeLYAQ7bDhaMQWMHoJCIiIpIP+fv7A46OCkWLFsXb++/hozw8PGjUqBEDBw40Kp4UdIdXgN0KJSpDYLjRaQAwm81Ur17d+VxEpCCx2+1cTruca+fzdvO+4Z08/3b27FkOHTrEoEGD8PHxAaBGjRrUqPH33D1PPfUUH3zwAatWreKee+657hgbNmyga9euhIc7fp+UKlWKxx57zEXvRHJatoq2kydPpk+fPsybN4+XXnqJSpUcPRq///579TrIZ3acigHLRex2M60rRhgdxzW8izkeIiIiItnwxRdfABAWFsawYcM0FILkrmtDI9zR1tgc/2AymShVqpTRMUREcsTltMs0/KZhrp1vY++N+Lj73LJd8eLFqVKlCv379+fxxx+nYcOGlC9fPl2bwMBARo4cyYsvvsi6deuuO8bdd9/N+++/T9GiRWnSpAkRERG4uWWrFCgGyNbXpLVr12bnzp0kJCTwyiuvONe/8847fPnlly4LJzlv+aHNALhbQ/D3uvU/GiIiIiKFxSuvvKKCreQumw0O/Op4noeKtiIikvtMJhNRUVHUqVOHV199lQoVKlC9enWWLl2art2QIUM4evQo8+bNu+4YI0eOZPz48fz88880b96cEiVK8OKLL2K1WnPpXcjtuK3y+ubNm9m7dy8A1atXp169ei4JJblny7VJyDwrGpzERfb+DNGroFIbqKwLXREREbk933//PXPnziUmJoarV6+m27ZlyxaDUkmBFbcdLp4Bj6JQrrHRaZzsdjtnzpwBoGTJkpm6rVdEJL/wdvNmY++NuXq+zAoODmbixIlMnDiR8+fP89///pfu3bsTExPz9/G8vXnllVcYPXo0q1evTre/2WzmP//5D//5z39IS0vj119/pXfv3lSoUIHHH3/cZe9Jcka2etqePn2aFi1acNddd/Hss8/y7LPPcuedd9KqVSvnL3PJH44k/QlA1eIFZBKyw1Hw+6d/T94gIiIikk0ffPAB/fv3JygoiK1bt9KgQQOKFy/O4cOH6dChg9HxpCAKiYBBq6Drh+DmYXQaJ5vNxp49e9izZ49zQmoRkYLCZDLh4+6Ta4/sfvEVGBjIuHHjuHjxItHR0em2DRgwAJvNdtO7393c3OjYsSOtWrVi586d2coguStbRdtnnnmG5ORkdu/ezfnz5zl//jy7du0iMTGRZ5991tUZJYfYbDYSbY7/0SPLFJBJyOL/+rapWPmbtxMRERG5hf/97398+umnfPjhh3h4eDBixAiWLl3Ks88+S0JCgtHxpCAymSCkDtTobnQSEREx2IULFxgzZgz79u3DarVy6dIlJk2aRGBgIFWrVk3X1mKx8N///pc33ngj3fr33nuPZcuWkZycjN1uZ+3atURFRWk+qnwiW0XbJUuW8L///Y9q1ao511WvXp2PPvqIxYsXuyyc5Kw9p4+DJRm73UzbSnWNjuMaF446lgEq2oqIiMjtiYmJcf5R4+3tTVJSEgCPPPII3377rZHRREREpIDz8PDgxIkTdOzYEX9/f8qVK8fatWtZvHjxDcfcv//++6lUqVK6db6+vowePZrSpUsTEBDAwIEDGTt2LL169cqttyG3IVtj2tpsNtzd3a9b7+7urttl8pGlzknIggnwLgCTbNjtf/e0DShnbBYRERHJ94KDgzl//jzly5enXLlybNiwgTp16hAdHY3dbjc6nhQ0u/8P9i+BOg9BxZZGpxEREYP5+vryxRdfZLg9KirqunUbNmxI9/rxxx/X2LX5WLZ62rZs2ZLnnnuO2NhY57oTJ07w/PPP06pVK5eFk5y1Jc4xhkmpgjIJ2cUzkHYZMIF/WaPTiIiISD7XsmVLfvrpJwD69+/P888/T5s2bXjooYfo3l23r4uL/f4Z7JgNxzcZnURERETygGz1tJ0yZQpdunQhLCyMsmUdxbFjx45Rs2ZNZs6c6dKAknOik/aDCaoUq3brxvnBtaER/ErnqYkbREREJH/69NNPnXeRPf300xQvXpx169bRpUsXBg0aZHA6KVDOHoSja8FkhojeRqcRERGRPCBbRduyZcuyZcsWli1bxr59+wCoVq0arVu3dmk4yVkJtiNggciyBWQSsgRNQiYiIiKuYzabMZv/vjGtZ8+e9OzZ08BEUmBt/dqxrNgK/MsYm0VERETyhCwVbX/77TcGDx7Mhg0b8PPzo02bNrRp0waAhIQEatSowdSpU2natGmOhBXX2Xv6OFgSsdtNtCkok5DVuA8qtICUJKOTiIiISD61Y8eOTLetXbt2DiaRQsOaCtv/mtiuXl9js2TAZDI5Zyo3mUwGpxERESkcslS0nTx5MgMHDsTPz++6bf7+/gwaNIhJkyapaJsPLD14bRKyIIr7FDU4jYuYTOAT6HiIiIiIZENERAQmk+mWE42ZTCasVmsupZIC7cCvkHwKfEtC5fZGp7khs9lMcHCw0TFEREQKlSwVbbdv386ECRMy3N62bVvefffd2w4lOW/TX5OQlfQoIJOQiYiIiLhAdHS00RGksNny19AIdXpqXgYRERFxylLR9tSpU7i7u2d8MDc3zpw5c9uhJOdFJ+0HoHJBmYQM4Kdnwc0LmgwBv1Cj04iIiEg+VL68xsaXXFahGZw/DHXz5tAIAHa7nfPnzwMQGBioIRJERERyQZaKtqVLl2bXrl1UqlTphtt37NhBSEiIS4JJzkqwRoMFGheUSchsVtj2DdhSIfIZo9OIiIhIAXHo0CEmT57M3r17AahevTrPPfccFSvqbiVxkUZPQsMnHEN95VE2m42dOx136jVt2hSLxWJwIhERkYLPfOsmf+vYsSMvv/wyV65cuW7b5cuXeeWVV7j33ntdFk5yxv4zsdgtCdjtJtpWrGd0HNdIOuko2Jrd1ctWREREXOKXX36hevXq/P7779SuXZvatWuzceNGatSowdKlS42Ol+9tOn6Qh74bQ91p7Rmz7Auj4xgrDxdsRUQkf3viiScYOXJktvaNiooiICDAtYEk07JUtB0zZgznz5+ncuXKvP3228yfP5/58+czYcIEqlSpwvnz53nppZdyKqu4yNJDjknI3KylKFnk+knl8qULRx1L/zJg1jf/IiIicvtefPFFnn/+eTZu3MikSZOYNGkSGzduZMiQIdn64+ejjz4iLCwMLy8vGjZsyO+//55h2+bNm2Myma57dOrUydnm0UcfvW57+/Z5cyKra2w2G5/98Qv3fNmXR5fdx55L80lzO8H8E5N45IfXsdlsRkfMPbFbYdu3cPWS0UlERCQP2r9/P507d6ZEiRL4+flRtWrVm84zlZGpU6c69zty5Agmk4n4+HgXp3WYMWMGERERt3WM5s2bY7FY2LFjh3NdfHw8JpOJI0eO3F7Afxk3bhzdunVz6TFdKUtF26CgINatW0fNmjUZNWoU3bt3p3v37owePZqaNWuyZs0agoKCshQgKxevAN999x1Vq1bFy8uLWrVqsWjRonTb7XY7Y8eOJSQkBG9vb1q3bs2BAwfStenSpQvlypXDy8uLkJAQHnnkEWJjY53br1y5wqOPPkqtWrVwc3PL0/8Bs2PTSccPfoGahCz+r6JtQDljc4iIiEiBsXfvXgYMGHDd+scee4w9e/Zk6Vhz5sxh6NChvPLKK2zZsoU6derQrl07Tp8+fcP2P/74IydPnnQ+du3ahcVi4YEHHkjXrn379unaffvtt1nKlVtOJSfw3KIp1PuiLR/sGcYFtmIy2Slqq05591YAbEueS8dvnuFSaorBaXPJ+o9g3hOw7BWjk4iISB7UqVMn6tSpQ0xMDBcuXOCHH36gQoUKRsfKFcWKFWPUqFFGxzBcloq24JicYdGiRZw9e5aNGzeyYcMGzp49y6JFiwgPD8/SsbJ68bpu3Tp69erFgAED2Lp1K926daNbt27s2rXL2ebtt9/mgw8+YOrUqWzcuBFfX1/atWuXbkiHFi1aMHfuXPbv388PP/zAoUOH6NGjh3O71WrF29ubZ599ltatW2fxE8r7jiQ6ith3BFQ1OIkLXetpW0yTh4iIiIhrlCxZkm3btl23ftu2bZQqVSpLx5o0aRIDBw6kf//+VK9enalTp+Lj48P06dNv2D4wMJDg4GDnY+nSpfj4+FxXtPX09EzXrlixYlnKlVt+P7af3858gtXtFHabJxU82vJh029Z138OC3pPpn3wYOx2Myesq2g9sy9xSReMjpyzLl+APT85ntfpZWwWERHJc86ePcuhQ4cYNGgQPj4+WCwWatSo4bwOmDNnDo0aNXK2v//++9PNMfXCCy/wzDOO+X4effRRhgwZAkCDBg0AKFOmDEWKFGHWrFkAbN68mZYtWxIYGEjJkiWd+17z+eefU7ZsWYoXL86IESNumHnr1q088cQT7Ny5kyJFilCkSBFiYmKw2+1MnDiRihUrEhgYSPv27Tl8+PBN3/9TTz3F2rVrWbVqVYZtZs+eTe3atQkICOCuu+5i3bp1AGzcuPG6z8Ld3Z3k5GQAPvzwQzp37sy8efN44403WLBggTMvQGpqKqNGjaJcuXKULFmShx56iDNnzjiPZzKZmDp1KjVr1sTPz48uXbqQkJBw0/eTXVku2l5TrFgx7rrrLho0aJDti8OsXry+//77tG/fnuHDh1OtWjVef/116tWrx5QpUwBHL9vJkyczZswYunbtSu3atfnqq6+IjY1l3rx5zuM8//zzNGrUiPLlyxMZGcmLL77Ihg0bSE1NBcDX15ePP/6YgQMHEhwcnK33lpfFpx0HICK4ABVtnT1tVbQVERER1xg4cCCPP/44EyZMYPXq1axevZq33nqLQYMGMXDgwEwf5+rVq2zevDldZwCz2Uzr1q1Zv359po4xbdo0evbsia+vb7r1UVFRlCpViipVqvDkk09y7ty5DI+RkpJCYmJiukdu6VytASGWJrQq9QTLeixlfq+JNK9Q07n9nXaDeKLqeOw2D5LMe+j4XW92xcXkWr5ct+M7sKZAUC0IrWt0GhERwdGBL6PHv4fvuVlbq9WaYdvMKl68OFWqVKF///7MnTuXo0ePptvevHlzNm/eTFJSEna7nTVr1uDl5eWcOPW3336jZcuW1x332t3tx48fJzk5mT59+nDixAlatmxJjx49iI2N5ejRozz44IPOfZKSktizZw8HDhxgzZo1fPTRR0RFRV137Lp16zJ16lRq1apFcnIyycnJlCtXjq+//ppJkyYxb948YmNjqVGjBp07dyYtLS3D9x8YGMjIkSN58cUXb7h90aJFDBs2jBkzZnD+/HlGjRpF586dOXfuHPXr1+fixYvpPovy5cuzevXqdJ9Nt27dGD16NPfee68zL8Cbb77JggULWLNmDdHR0ZhMJvr06ZPu/HPnzuW3334jJiaG48eP895772X4Xm6HW44cNROuXbz+s7vzrS5e169fz9ChQ9Ota9eunbMgGx0dTVxcXLoLYn9/fxo2bMj69evp2bPndcc8f/48s2bNIjIyEnd399t6TykpKaSk/H07V25eCGfWxZQU0ixnMAGRZWsYHcd1kk85lsXCDI0hIiIiBcfLL79M0aJFmThxovOaNTQ0lHHjxvHss89m+jhnz57FarVeN4xYUFAQ+/btu+X+v//+O7t27WLatGnp1rdv35777ruP8PBwDh06xOjRo+nQoQPr16/HYrl+jP8333yTV199NdO5Xe3Xhz++6fbBjTpT1q8UYzY8T6rlOL0X9mFauy+4q0ylXEqYi7Z+5VjWe0STkImI5BHXino3EhgYSO3atZ2v165dm+E47AEBAenGdf1nJ8HmzZtnKovJZCIqKop33nmHV199lX379lGlShXef/992rRpQ1BQEJUrV2b16tWEhIRQvnx5GjZsyIoVKwgKCmLXrl2ZPtfMmTOpX78+Tz31lHNd06ZNnc/tdjvjx4/Hy8uLatWqERkZyebNmzN9/K+//ppnn32WWrVqAfDGG2/w2Wef8fvvvxMZGZnhfkOGDGHKlCnMmzfvunN99NFHDB8+nHr16gFw3333MXHiRBYtWsQjjzxC06ZNnZ9FXFwcw4YNY8WKFbRr146VK1cybty4m+YdP3485co5ht+cNGkSpUuXJjY2ltBQx8T3I0aMcN51df/997Nhw4ZMfRZZle2etrfrZhevcXFxN9wnLi7upu2vLTNzzJEjR+Lr60vx4sWJiYlh/vz5t/V+wHEh7O/v73yULVv2to/pauuP7cNksmG3eVK9VBmj47jOwz/C8MNQOW9PviEiIiL5h8lk4vnnn+f48eMkJCSQkJDA8ePHee655zDlYqFt2rRp1KpVy3lL4zU9e/akS5cu1KpVi27durFgwQL++OOPG/Z+ARg1apTzfSQkJHDs2LFcSJ81Xas35Iu2X2FJK4XdLZ7//PI4+8/E3nrH/OTsQYjbCWZ3qPXArdvnASaTiTvuuIM77rgjV3/2RUQKs+DgYCZOnMju3bs5c+YMHTp0oHv37pw/fx5wDP25YsUKfvvtN1q0aEGrVq1YsWIFK1asoHbt2pm+K/7o0aPccccdGW738/PDx8fH+drX15ekpKRMv4/jx48TFhbmfO3p6UloaCjHjx+/6X7e3t688sorjB49+rpeykeOHGH06NEEBAQ4H9u2bePEiRPA35/NihUruOeee5yfzdatWzGbzekK8LfKGxoaiqenZ7q8/7wrP6ufR1YY1tPWaMOHD2fAgAEcPXqUV199lb59+7JgwYLbuggZNWpUup7AiYmJea5w+8cJR28OL3soZrNhNXvXM5nAt7jRKURERKSAKlq0aLb3LVGiBBaLhVOnTqVbf+rUqVsOxXXx4kVmz57Na6+9dsvzVKhQgRIlSnDw4EFatWp13XZPT088PT2zFt4Ad5apxMxOX9B74SPY3M7Q6+f/sKDHTEL9Ao2O5hp/LnYsw+4Gn/zxnsxmM6VLlzY6hohIjvpn79J/+3et6O677870cf859mx2BQYGMm7cOCZNmkR0dDSBgYG0aNGCN998k6CgIJ599lkaNmzIE088QcmSJWnRosUNj3OjOlD58uX59ddfbztjRscvU6YMR44ccb6+evUqsbGxlClz646EAwYMYNKkSXz55Zfp1pctW5ZnnnmGJ5544ob7tWjRggkTJlCyZElatmzpnNDt//7v/2jevLnzv+fN8jZs2BBwdBBNSUnJVF5XM6xql52L1+Dg4Ju2v7bMzDFLlChB5cqVadOmDbNnz2bRokW33Z3Z09MTPz+/dI+8Zu85xyRkJTzyVjFZREREJC+oV68eFy44JsGqW7cu9erVy/CRWR4eHtSvX5/ly5c719lsNpYvX07jxo1vuu93331HSkoKDz/88C3Pc/z4cc6dO5du8o38qmZwOT5uPRWsRUi1HKP7DwNJuHLJ6Fiucf6vyVcqdzA2h4iIpGOxWDJ8/Lu4d7O2/x6iKKP1N3PhwgXGjBnDvn37sFqtXLp0iUmTJhEYGEjVqo75iZo1a8b27dtZv349TZo0ISAggDJlyjBr1qwbjmcLjklWzWYzhw4dcq7r06cPv//+O1OnTiUlJYVLly7ddKiImwkKCuLkyZNcvnzZue7hhx9mypQp7Nmzh5SUFMaMGUPp0qWvu4PoRiwWC//9739544030q1/+umneeedd9i8eTN2u51Lly6xbNkyZ2/YunXrkpaWxqxZs2jRogUmk4mmTZvy4YcfpvtsgoKCOHr0aLrxdR9++GHeeOMNjh07RnJyMkOHDqV169bOoRFyk2FF2+xcvDZu3Dhde4ClS5c624eHhxMcHJyuTWJiIhs3brzpBfG1cUj+OR5tQXU8ORqAMP8KBidxoeOb4JuesHqS0UlEREQkn+vatSuxsY7b8bt160bXrl0zfGTF0KFD+eyzz/jyyy/Zu3cvTz75JBcvXqR///4A9O3bN91cD9dMmzaNbt26Ubx4+juKkpOTGT58OBs2bODIkSMsX76crl27UqlSJdq1a5fNd5+33F2+Gm80/gC7zZNL5j/pPOcJrqReNTrW7bv3PRi6D+o8ZHSSTLPb7cTHxxMfH4/dbjc6johIgefh4cGJEyfo2LEj/v7+lCtXjrVr17J48WLnpKQlSpSgevXqVK9e3bmuVatWXLp0iXvuueeGx7025ECHDh0ICAjgm2++oUyZMixfvpxvvvmGoKAgwsLC+P7777OVu2XLljRq1IjSpUsTEBBATEwMffv25ZlnnuHee+8lODiY7du38/PPP+Pmlrmb/++//34qVUo/vn3nzp156623GDhwIMWKFSM8PJz333/fWd8zm83cc889FC1alMqVKzs/m8TExHRF2wceeAA/Pz9KlixJQEAA4LiLvl27djRu3JiwsDBSU1OZOXNmtj6P22WyG/hbd86cOfTr149PPvmEBg0aMHnyZObOncu+ffsICgqib9++lC5dmjfffBOAdevW0axZM9566y06derE7NmzeeONN9iyZQs1azpmn50wYQJvvfUWX375JeHh4bz88svs2LGDPXv24OXlxcaNG/njjz9o0qQJxYoV49ChQ7z88sucOnWK3bt3O28Z27NnD1evXmXs2LEkJSU5Z4L752DSt5KYmIi/vz8JCQl5ptdt3WntSHOLZVCVNxjcqLPRcVxj8wz4+Tmo1AYezt4/LCIiIlIwuOL6y2w2c9dddzFgwAB69ep1W0Mj/NOUKVN45513iIuLIyIigg8++MB5613z5s0JCwtjxowZzvb79++natWq/Prrr7Rp0ybdsS5fvky3bt3YunUr8fHxhIaG0rZtW15//fXr5nfISF68Vr2RGZuX8e6O4ZjMaZS23MOi3h8WrGG+8gGr1ersddW0adMs9RYTEclrrly5QnR0NOHh4Xh5eRkdRwqwjH7WMnsNZuiYtg899BBnzpxh7NixzovXJUuWOC80Y2Ji0l2QRUZG8s033zBmzBhGjx7NHXfcwbx585wFW3DM4Hbx4kUef/xx4uPjadKkCUuWLHF+OD4+Pvz444+88sorXLx4kZCQENq3b8+YMWPSjfHVsWNHjh496nxdt25dgHz9zfKV1Kukmk9hAhqVqWZ0HNe58Nd/p2Lljc0hIiIiBcLKlSv54osvGDZsGEOHDqVHjx4MGDDgpmPdZcbgwYMZPHjwDbfdaPKwKlWqZHjt6e3tzS+//HJbefKLR+u35vzll5l+cBwnrKsY+NPbTOv2otGxsiftKrh5GJ1CRERE8gFDe9oWdHmt98Lq6D08teoh7DZ3tvX9A7eC8g359wNg1/fQ5nW4+1mj04iIiIiBXHn9dfHiRebOncuMGTNYvXo1lSpVYsCAAfTr1++WE4jlB3ntWvVWhiz+iOWnp2JOK8n2Ab8ZHSfrrKkwqRoE1YD7PociJY1OlGnqaSsiBYl62kpuud2etrqvqBD5/cReADztwQWnYAsQr562IiIi4nq+vr7079+flStX8ueff/LAAw/w0UcfUa5cObp06WJ0vELnhcY9AbC5neHI+dMGp8mGmPVw8QzE7QKfQKPTiIiISB6nom0hsufsAQAC3csZnMTFrg2PEFDA3peIiIjkGZUqVWL06NGMGTOGokWLsnDhQqMjFTplA4pjSSsFwMI/NxqcJhv2L3EsK7cDcwHqQCEiIiI5wtAxbSV3xSRFA1CuaLjBSVwo9TJc/KunRYB62oqIiIjrrVq1iunTp/PDDz9gNpt58MEHGTBggNGxCqVSHndw0naajbHbeZrbmFT33CFYNNzxpX9AWcd1ZOl6EFjBdWH/yW6HPxc7nldulzPnEBGRLNFooZLTbvdnTEXbQuR86jGwQI0SdxgdxXWS4sDd19FbwbuY0WlERESkgIiNjWXGjBnMmDGDgwcPEhkZyQcffMCDDz6Ir6+v0fEKrWrFa3LyzFoOJe65vQOdPwyHlqdf5+YNz2wC/zK3d+wbOXvAcU6LB1Rs6frji4hIprm7uwNw6dIlvL29DU4jBdmlS5eAv3/mskpF20IizWolxRSHCWhYtprRcVwnMBxGn4ArCWAyGZ1GRERECoAOHTqwbNkySpQoQd++fXnssceoUqWK0bEEaFa+Pr+dgUTbIWw2G2ZzFkZ7WzcFqnZ09KYtVR26TIGEYxAfA0fWOJ5v+Bja/df1wa/1sg1rAp5FXX/8HGYymahQoYLzuYhIfmaxWAgICOD0acdduz4+Pvq3TVzKbrdz6dIlTp8+TUBAQLYn8FTRtpDYejIakzkVu81C/dCKRsdxLZMJvAOMTiEiIiIFhLu7O99//z333ntvti+yJWe0q1SPsb9bMFku8fvxAzQql8li+qEV8OtLsOINeG47+JeGeo/8vf3AMph1P2yeAfcMc/0dXM7xbDu49ri5xGw2U66c5o8QkYIjODgYwFm4FckJAQEBzp+17FDRtpDYcGw3AO62ILzcPQxOIyIiIpJ3/fTTT0ZHkAz4enriZS9HCtH8euiPzBVtUy/Dgucdz+v2gSIlr29TqRWUquEYcispzvVF21o9wN1L49mKiOQRJpOJkJAQSpUqRWpqqtFxpAByd3e/7S//VbQtJHadOQBAMfeyBidxsV9ecowP1ngwhN1tdBoRERERyWFlfKpwKCWaLae2Aw/feodV78KFaCgaAi1fvnEbkwn6zgffEjkz5NZdAxyPfMput5OUlARA0aJFdRuxiBQYFotFd9VInpWFQaAkPzuaGA1A2SLhBidxsSOrYf8iSEk0OomIiIiI5IJ6QXUAOH5p/60bn94Layc7nnd4G7z8Mm5bpKTmSMiAzWZjy5YtbNmyBZvNZnQcERGRQkFF20Li7NVjAFQrUcngJC524ahjGVDe2BwiIiIikivaVrwLgCumY1xMScm4oc0GPz8HtjSo0hGqdc7cCa4kwqbpYLPefti0q7DpC0iMvf1jiYiISKGiom0hYLPZuGJyXCjeGVrV4DQudCUBrsQ7ngdoYgQRERGRwqBBmTvA6oPJnMYvB7dk3HDnd3BsI3gUgY7vZK4Xrc0KU5s4xsDdv+j2wx5dCwuGwCfNHEVkERERkUxS0bYQ2HU6BpM5BbvdTOOyBahoGx/jWPoUB88ixmYRERERkVxhNpvxM1cEYOXRzRk3rNENWrwErceBf5lMHtzimDQMYO37YLffVlYOLHUsK7cFs/70EhERkczTlUMhsD5mDwBu1pL4enoanMaFNDSCiIiISKFU0a86AHvP7cq4kZsnNBsBDQZm7eANBoHFE47/ATEbbiMlcHK7Y1leE+aKiIhI1qhoWwjsOP0nAAFumexhkF/E7XAsS1YxNoeIiIiI5KqGoY7JyE5fPeD6gxcNgohejudr38/+cex2OO3oPEGp6refS0RERAoVFW0LgSOJ0QCULhJmbBBXS70E7j5Q5k6jk4iIiIhILupUuSEAVrfTHIs/l37jlUT4qhv8/hlY07J3gsbPACb4czGc2Z+9YySfgsvnwWRWJwMRERHJMhVtC4EzVxxjv1YJrGRwEhdrOx5ePAYRDxudRERERERyUVhgKcxpJQFY9OfG9BsP/AqHV8DGqY4xarOjRCWodq/j+YaPs3eMU7sdy8CK4O6dvWPkESaTibCwMMLCwjBlZkI3ERERuW0q2hZwNpuNy5wAoH5oAZqE7BqLG7h7GZ1CRERERHJZSQ9Hh4T1J7am37BnnmNZrQvcToGxXj+weMCV+Oztf3qvY1mqWvYz5BFms9lZtDVrQjUREZFc4WZ0AMlZB87FgeUydruJu8vl/wtGJ5tNM/CKiIiIFGLVAmty6ux6Dibs+Xvl1YtwYJnjefWut3eCCs1h1HHHhGbZUesBCKwA3gG3l0NEREQKJVW9Crh1MY7bsizW4gR4+xqcxoUWDYMpd8GuH4xOIiIiIiIGaFauPgAJtsPYbDbHyoPLIO0yBJSDkDq3dwKLe/YLtuCY0KxqRygfeXs58gC73c7Fixe5ePEidrvd6DgiIiKFgoq2Bdz2038C4OdWxuAkLnb8dzj7J5jVWVxERESkMGpXuR52uwUsyWyOPexYuWe+Y1m96+0NjfBvhbxQabPZ+OOPP/jjjz/+LpCLiIhIjlLRtoCLjndcwIb6lDc4iQtdvQin/roNrvSdxmYREREREUMU9fTGy+bomPDLwd8h9Qr8+YtjY7XbHBrhmsMrYWoTmJPFiW8TYyFqwt95RERERLJIRdsC7tSVowBULlbJ4CQuFLsN7FYoGgr+pY1OIyIiIiIGKe1dBYCtp3bApXNQrjEUC4PS9V1zAncfiNsJMRuy1tv2+CaIegNWvOGaHCIiIlLoqGhbwF20xwJQN6SKwUlc6MQmx7KMiy7GRURERCRfqhvkGLc25uI+x5f5D38Pgze5bsLa4FpgdodLZyH+aOb3O/3XXWFBNVyTQ0RERAodFW0LsCPnT4MlCYC7y1c3OI0LHf+raKuhEUREREQKtVbhjuvBy6ajXEpNcay0uLvuBO5eEFzT8fzE5szvd8oxGTClCtA1uIiIiOQqFW0LsLXHHN/wm9ICCCrib3AaF7pWtC1zl7E5RERERMRQd5evClZvTOY0lm3LofFjr3UUOLEl8/uc3utYlqrm+jwiIiJSKKhoW4DtPnUIgCLmUIOTuFBaClRsCSWrQWiE0WlERERExEBms5lilqoAfLt1GEkrJrj+JNfGx73WceBWUi/Decd1uIZHEBERkexS0bYAO5xwBICSXmWMDeJKbp7Q7SN4egN4+BqdRkREREQM9nazF/G32tjl6cnDR37HZrO59gRl/uppe3I7WFNv3f7MfrDbwDsQigS5NotBTCYTZcuWpWzZsphMJqPjiIiIFAoq2hZgcZeOAVC2aDmDk4iIiIiI5IxGtjNMPn0GN7udw+Y/6T/vTdeeILAiBNeGqp0gJenW7f85CVkBKXCazWYqVqxIxYoVMbtqkjcRERG5KTejA0jOSUg7CW5QrUQFo6O4zrlDEFAeLPrRFRERERFgz3zuvJLCIFs4H1mOsCVpNm+srMDoZr1cc3yzGZ5Ynfn2tR6A0HpgTXHN+UVERKRQ0tekBVSa1Uqq+TQA9UMrG5zGRayp8PHd8FZZiD9mdBoRERERyQsOrwDgibufprJXJwC+Ofw28/ZsMCaPxR1KVYWQOsacPwfY7XauXLnClStXsNvtRscREREpFFS0LaB2norBZE7DbjcTERJudBzXOL0H0i6D2R38ShudRkRERESMlngSzh8GkxnKN+bb+8fjb6+FyZzG2A3D2BUX47pz2WxwPtp1x8tHbDYbGzZsYMOGDa4fM1hERERuSEXbAmpL7J8AuFlL4OXuYXAaFzn+h2NZpr7jNjURERERKdxi1jmWwbXAyx8PNze+v/8j3NJCsFsS6LOwH2+vmnv7hcYrCTAhDD6oC1cSM253+QL83xOw7kNQj1QRERG5Dap8FVC7zxwGoKglxOAkLnR8s2NZ+k5jc4iIiIhI3hDeHHpMhyZDnauCixbj03b/w2T1x+Z2lq+jX6fBjK58sXlp9s/j5Q9efoAdTm7LuN2pPbD9W/j90wIzCZmIiIgYQ0XbAupo4lEASnmXMTiJCzl72t5lbA4RERERyRt8i0PN+6FGt3Sr7ypTiV8f+Jk6RXpgt3mQYjnCpF1DiZzRi4X7N2XvXKXrO5YnNmfc5vQex7JU9eydQ0REROQvKtoWUKcuOybqCvMvb3ASF7l8Ac4dcDy/dsEsIiIiko989NFHhIWF4eXlRcOGDfn9998zbDtjxgxMJlO6h5eXV7o2drudsWPHEhISgre3N61bt+bAgQM5/TbyjeCixZh5/yv8cO/PVPBoi91uJsm0ixc39Kft10+wI+5I1g547Rr0+E2Kvqd2O5Yq2oqIiMhtUtG2gEqyngSgRomKBidxkWs9GgIrOHpUiIiIiOQjc+bMYejQobzyyits2bKFOnXq0K5dO06fPp3hPn5+fpw8edL5OHr0aLrtb7/9Nh988AFTp05l48aN+Pr60q5dO65cuZLTbydvOPY7rJ4Isdtu2qxKyVDm95rIpy3mUMrcCICTtrX0Xnwfj/zwOucuJWXufGX+GqLrxJaM21zraRtUI3PHFBEREcmAirYF0KXUFKyWcwDcWaaywWlcJLACtBwDdw4wOomIiIhIlk2aNImBAwfSv39/qlevztSpU/Hx8WH69OkZ7mMymQgODnY+goKCnNvsdjuTJ09mzJgxdO3aldq1a/PVV18RGxvLvHnzcuEd5QF75sPy12DLl5lqHlm+Kssf+Yxxd36Ct7UiJnMq25Ln0mJ2B8Ys+4I0q/XmBwipAyYLJMVCYuz12+12OL3X8Vw9bUVEROQ25YmibVZuFQP47rvvqFq1Kl5eXtSqVYtFixal256ZW8W6dOlCuXLl8PLyIiQkhEceeYTY2PQXXzt27KBp06Z4eXlRtmxZ3n77bde84Ry2NfYwJpMNu82d6iXLGh3HNQIrwD3DIXKw0UlEREREsuTq1ats3ryZ1q1bO9eZzWZat27N+vXrM9wvOTmZ8uXLU7ZsWbp27cru3bud26Kjo4mLi0t3TH9/fxo2bHjTYxYoR9c6luXvztJu99eIZMOjP9I7bAymtEDslgTmn5hEoy+7siFmf8Y7evj+XYy90bi2CcchJRHMblC8UpYy5XUmk4nQ0FBCQ0MxaYI1ERGRXGF40Tart4qtW7eOXr16MWDAALZu3Uq3bt3o1q0bu3btcrbJzK1iLVq0YO7cuezfv58ffviBQ4cO0aNHD+f2xMRE2rZtS/ny5dm8eTPvvPMO48aN49NPP825D8NFtp48CIC7rRRuFovBaUREREQKt7Nnz2K1WtP1lAUICgoiLi7uhvtUqVKF6dOnM3/+fGbOnInNZiMyMpLjx48DOPfLyjFTUlJITExM98i3UpLg5HbH83KNs7y72WxmVLOHWN1nMY2L9QWbFymWowxc9jCT1vyY8Y51+0DTYVD8juu3nT8MmKBEZXDzyHKmvMxsNlO5cmUqV66M2Wz4n5AiIiKFguG/cbN6q9j7779P+/btGT58ONWqVeP111+nXr16TJkyBcj8rWLPP/88jRo1onz58kRGRvLiiy+yYcMGUlNTAZg1axZXr15l+vTp1KhRg549e/Lss88yadKkHP9Mbtfes4cA8HcLMTiJi6Rehj9/hTN/Gp1EREREJFc0btyYvn37EhERQbNmzfjxxx8pWbIkn3zySbaP+eabb+Lv7+98lC2bj+/IOrYR7DYoFgb+pbN9GH8vHz7tMpxZHb7Dy1oBLFf44tArPPjdS1xJvXr9Do2ehFYvQ6mq12+r0AxGx0Kv2dnOIyIiInKNoUXb7Nwqtn79+nTtAdq1a+dsn51bxc6fP8+sWbOIjIzE3d3deZ577rkHD4+/vyVv164d+/fv58KFCzc8Tl7pvXAsKQaAEJ98fCH+T2cPwDcPwBcdjE4iIiIikmUlSpTAYrFw6tSpdOtPnTpFcHBwpo7h7u5O3bp1OXjQcUfVtf2ycsxRo0aRkJDgfBw7diyrbyXvOJK9oREyUjs4jKg+c7nDsyMAey/9RLOZPdl7+njWDuThA8XKuyRTXmK327l69SpXr17FbrcbHUdERKRQMLRom51bxeLi4m7aPiu3io0cORJfX1+KFy9OTEwM8+fPv+V5/nmOf8srvRfOXHFcXIYHFJALxvi/ZkougBfAIiIiUvB5eHhQv359li9f7lxns9lYvnw5jRtn7tZ+q9XKzp07CQlx3EkVHh5OcHBwumMmJiaycePGDI/p6emJn59fuke+dXSdY1k+0mWH9PX05MeeE+gTPga7zZNL5gM8uOBBvt2+Mn3Di2dh60xY+AIcWQPWVJdlyKtsNhvr1q1j3bp12Gw2o+OIiIgUCoYPj2Ck4cOHs3XrVn799VcsFgt9+/a9rW+O80rvhYs2R1G5VqkCMgHChSOOZUEpQouIiEihM3ToUD777DO+/PJL9u7dy5NPPsnFixfp378/AH379mXUqFHO9q+99hq//vorhw8fZsuWLTz88MMcPXqU//znP4BjYqghQ4Ywfvx4fvrpJ3bu3Enfvn0JDQ2lW7duRrzF3GNNhTN7Hc9dWLS95sV7HuLj5l/hlhYKliT+u2UI/9u44O8G8wfD/Kfhj89hRieYEA7fPATj/OGnZ+DqRZdnEhERkcLHzciTZ+dWseDg4Ju2/+etYtd6Ilx7HRERcd35S5QoQeXKlalWrRply5Zlw4YNNG7cOMPz/PMc/+bp6Ymnp+ct3nXOir98EZslHhPQoEwVQ7O4zAX1tBUREZH87aGHHuLMmTOMHTuWuLg4IiIiWLJkifNOrpiYmHQTPF24cIGBAwcSFxdHsWLFqF+/PuvWraN69erONiNGjODixYs8/vjjxMfH06RJE5YsWYKXl1euv79cZXGHYQchbicUC8+RUzQNr85vvX6gy9wniDdv5397x5CYcpEX73kIekyDfQvh4DI4uBwunYU/lzh23PUjdP4gRzKJiIhI4WJoT9vs3CrWuHHjdO0Bli5d6myfnVvFrp0XHOPSXjvPqlWrnBOTXTtPlSpVKFasWBbfae754/gBTCY7WL0IL1bK6DiucW14BPW0FRERkXxs8ODBHD16lJSUFDZu3EjDhg2d26KiopgxY4bz9XvvvedsGxcXx8KFC6lbt26645lMJl577TXi4uK4cuUKy5Yto3Llyrn1dozl5gFl6oPJlGOnKOZThF/6TKOkqSEmk5WZh//LS0ung4cv1H4Q7vsUhh2AgSugxRio0AJajc3RTCIiIlJ4GD48QlZvFXvuuedYsmQJEydOZN++fYwbN45NmzYxePBgIHO3im3cuJEpU6awbds2jh49ym+//UavXr2oWLGis7Dbu3dvPDw8GDBgALt372bOnDm8//77DB06NHc/oCzaFueYnMKToHS9NfI19bQVEREREQP4uHuypPdUyro1x2Sy81Psezy78MO/G5jNULoeNBsOfedBw0GGZRUREZGCxfCq3kMPPcS7777L2LFjiYiIYNu2bdfdKnby5Eln+8jISL755hs+/fRT6tSpw/fff8+8efOoWbOms82IESN45plnePzxx7nrrrtITk5Od6uYj48PP/74I61ataJKlSoMGDCA2rVrs3LlSufwBv7+/vz6669ER0dTv359XnjhBcaOHcvjjz+ei59O1v15PhqAAPdQg5O4iN0O8TGO5+ppKyIiIlK4pV2FT+6Bn4dASnKunNLDzY0Fvd6nslcnAFac/ZTH5r2ZK+cWERGRwstkv52Zt+SmEhMT8ff3JyEhIddm5+0wazDH01ZSt+hDfHXfmFw5Z46ypsHuHx29be9+znErnIiIiEgGjLj+yq/y5Wd17A+Y1hq8A2H4IUdP11xis9l4dN4bbE2aA0AD/4eZ1m1krp3fSFarldWrVwPQtGlTLBaLwYlERETyr8xegxne01Zc61xKLACVioUZG8RVLG6OMcOaDVfBVkRERKSwO7rWsSwfmasFWwCz2cxX940hslg/AH5PmMlzi6bkagajmEwmgoODCQ4OxqQxe0VERHKFirYFzGXiAKgdVMngJCIiIiIiLnZ0nWNZPtKwCJ90GUZEkQcB+O3MJ4xZ9oVhWXKL2WymatWqVK1ateDMmyEiIpLH6TduARKXdAEsSQA0KFNAZg6O3QoHlkJirNFJRERERMRINivEbHA8N7BoC/Bl95ecY9zOO/4eb6+aa2geERERKXhUtC1ANh770/HEWoRQv0Bjw7jKH9NgVg/Y8pXRSURERETESKd2Q0oCeBSFoFqGRjGbzXz3wBuUdWuOyWTnq8Nv8L+NCwzNlJPsdjtWqxWr1YqmRBEREckdKtoWIDtPHwLAm2CDk7hQ/FHHMqC8sTlERERExFjHNjqWZRs45j0wmNlsZt5D71HS1BCTycr/9ozl662/GR0rR9hsNlavXs3q1aux2WxGxxERESkUjL/aEZc5cD4agECPUIOTuNCFv4q2xVS0FZHCzWq1kpqaanQMkTzDw8NDY2sWNmYLFK/kKNrmER5ubvz80Ee0+/YxEsw7eHvbSEr6fkL7yvWMjiYiIiL5nIq2BUjsxRgAyhQtZ3ASF7GmQcJxx3P1tBWRQsputxMXF0d8fLzRUUTyFLPZTHh4OB4eHkZHkdxy52OORx7r6enr6cmChz6l3exHuGQ+wPA1z1DS9yvql65odDQRERHJx1S0LUAupMaCBaoEVjA6imskngC7FSweUDTE6DQiIoa4VrAtVaoUPj4+mEwmoyOJGM5msxEbG8vJkycpV66c/r8obPJgD+sAb19+uO9zOn/fizS3WAYseZwfu39DhcAgo6OJiIhIPqWibQFhs9lI4RQAtYMrGZzGRa6NZ+tfNk9enIuI5DSr1eos2BYvXtzoOCJ5SsmSJYmNjSUtLQ13d3ej40hOu3oR3LwcQyTkUWX8A/m602f0XtgHq9tpHpw3kF96zqK4T1Gjo4mIiEg+pEpYAXEk/gxYLgPQsExlg9O4iMazFZFC7toYtj4+PgYnEcl7rg2LYLVaDU4iuWLNe/BWOccyD6sZXI73W/wPrN6kWKLpOvcJrqReNTqWiIiI5EMq2hYQvx/fD4ApLYAAb1+D07hI+D3Q/VNo+ITRSUREDKVbv0Wup/8vCpnjf8DVZPDyNzrJLbWoUIuX7nwXu82NBNMOus8dii2PjcMrIiIieZ+KtgXEzlMHAfA1F6CxX4uVhzoPQeV2RicREZHbUKNGDRYsWJCptuPGjaNbt2433BYVFUVAQIDrgolI/mCzwYktjudl7jI2Syb1rH0PAyqPxW43cTxtJQ//+JrRkW6LyWSiZMmSlCxZUl+YiIiI5BIVbQuI6IQjABT3DDU2iIiIFEr79++nc+fOlChRAj8/P6pWrcqECRMA2L17N/fee6/BCUUk3zr7J6QkgrsvlKxmdJpMe/7u7nQKeQaAnRd/YOjijw1OlH1ms5kaNWpQo0YNzJprQkREJFfoN24BcfLSMQDKFS1A479unQkHlkHqFaOTiIjILXTq1Ik6deoQExPDhQsX+OGHH6hQoYLRsW4qLS3N6AgikhnH/3AsQ+uCJX/Nozyh3UDqFn0IgF9PfczENT8YnEhERETyCxVtC4iEtDgAqpXI238gZ1rqZZj/NMy6H1IvGZ1GRERu4uzZsxw6dIhBgwbh4+ODxWKhRo0aPPDAAwCEhYUxb948AGbMmEFERASvv/46pUqVIigoiMmTJ2d47Jdeeok6depw8uRJ57rPP/+csmXLUrx4cUaMGJGu/cyZM6lWrRoBAQE0adKELVu2OLc1b96cESNG0LZtW3x9fVm8eDFhYWG8/fbbNGrUiKJFi9KsWTOOHTvmug9HRG7ftaJtmTuNzZFNM7qNppx7C0wmO18cGM+321caHUlERETyARVtCwgfczFMVn8igu8wOoprxMc4lh5FwbuYsVlERPIIu93Opatpufaw2+2ZylW8eHGqVKlC//79mTt3LkePHr1p+927d+Pj48OJEyeYM2cOw4cP59ChQ+napKWlMWDAANauXcuqVasICXGM2Z6UlMSePXs4cOAAa9as4aOPPiIqKgqAVatW8eSTT/LJJ59w5swZevToQfv27UlISHAed8aMGYwfP57k5GRat24NOAq93377LWfOnMHX15eXX345s/9JRCQ3HN/kWOaT8Wz/zWw288MD7xJABCZzGm9sHs5vh3YYHStLrFYrUVFRREVFYbVajY4jIiJSKOSv+4skQ2se/cboCK514a8/+IuVB012ICICwOVUK9XH/pJr59vzWjt8PG59qWAymYiKiuKdd97h1VdfZd++fVSpUoX333+fNm3aXNe+RIkSvPDCC4Cj92tYWBjbtm2jYsWKAFy6dInu3bvj6enJL7/8gqenp3Nfu93O+PHj8fLyolq1akRGRrJ582aaN2/O119/zcMPP8w999wDwJAhQ/j4449ZuHAhvXv3BqB37940aNAAAG9vbwCeeuopwsPDAejTpw9vvfVWdj8yEXE1ux1qdAe/0HxbtAXwcvfgpwc+ps3s3qRYohkS9TTf+M6iZnA5o6OJiIhIHqWetpI3xf9VtA0oQGP0iogUYMHBwUycOJHdu3dz5swZOnToQPfu3Tl//vx1bYOCgtK99vX1JSkpyfl627ZtLF26lHHjxqUr2AL4+fnh4+Nzw32PHz9OWFhYuvbh4eEcP37c+bpcuesLJMHBwRlmERGDmUzQbDg8/D0UDbp1+zysmE8R5nb7DEtaKexu8TyycCDHE67/N1JEREQE1NNW8qr4f/S0FRERALzdLex5rV2uni87AgMDGTduHJMmTSI6OjrL+0dGRtK5c2fatGnDsmXLqFGjRqb2K1OmDEeOHEm37siRI5QpU8b5WrOei4iRKgQG8Xm7T+j/a1/S3GK574cB/NLza4r5FDE6moiIiOQx+stF8qYL6mkrIvJvJpMJHw+3XHuYMjk8zYULFxgzZgz79u3DarVy6dIlJk2aRGBgIFWrVs3Wex0wYABvvvkmrVq1YseOzI39+PDDDzNr1izWrl1LWloaH374IefOnaNjx47ZyiAiecCxPyDplNEpXOrOMpWYcPeHYPXisuUgXeY+waXUFKNjiYiISB6jnraSN6mnrYhIvuHh4cGJEyfo2LEjp0+fxsvLi3r16rF48WJ8fX2zfdxHH30Ui8VCmzZt+OWXW4/l26xZMz788EMGDBjAyZMnqVmzJosXLyYgICDbGUTEQHY7fNsTLp2Fgb9B6fpGJ3KZjlXqc/7y27y1dSjx5u10mzOERb2m4GbJ3h0OIiIiUvCY7JmdGlqyLDExEX9/fxISEvDz8zM6Tv5yfBOcPQAVW+b78ctERLLrypUrREdHEx4ejpeXl9FxRPKUjP7/0PVX5uX5z+p8NHwQARYPGHUc3DxvuUt+M3ndPD7/8xVMJhtVvO9lbo//5slhXKxWK6tXrwagadOmWFRcFhERybbMXoPlvSsCEYAyd0JELxVsRURERAqr45scy+DaBbJgCzAkshvdygwBYP/lBQz6eaKxgTJgMpkIDAwkMDAw00PniIiIyO1R0VZERERERPKe4384lmXuMjZHDhvfuj9NAvsDsCH+K0b88qnBia5nNpupXbs2tWvXzpM9gUVERAoi/caVvOfUHvhjmmPiCREREREpnJxF2zuNzZELPu48lOo+XQFYdHIK4377yuBEIiIiYjQVbSXvOfQbLBwKG/5ndBIRERERMULqZYjb4XhewHvaXvPt/a8R7tEGk8nO9zHv8ubKOUZHEhEREQOpaCt5T/xRx7JYeWNziIiIiIgxTu4AWxr4loKAckanyRVms5kfH3yHMm7NMZnszIp+g0lrfjQ6FuCYiGzVqlWsWrUKq9VqdBwREZFCQUVbyXsu/FW0DVDRVkRERKRQKl4Jun8KLUZBIZr4ys1iYf5D7xFsjsRksjH9wGv8b+MCo2MBYLPZsNlsRscQEREpNFS0lbznwhHHUj1tRURERAon3+JQ5yG48zGjk+Q6Dzc3fu75ISVMd2IyW/nfnpeZvulXo2OJiIhILlPRVvIWux3iYxzP1dNWRERERAohL3cPFvaaSgARmMxpTNr5ogq3IiIihYyKtpK3JJ+GtMuACfzLGp1GRERERHLb/iWw7FVIOGF0EkP5uHuyqOen+NlrYjKnMmnnSD7a8LPRsURERCSXqGgrecu1Scj8SoObh7FZRERERCT3rXkP1kyCTdONTmK4op7eLOk13dnj9uN9L/PO6u+MjiUiIiK5QEVbyVtKVYf+i6Hz+0YnERGRLHrssccwmUzs3bvXuS4qKgqTyUSRIkUoWrQo5cuXZ8yYMc7JbJo3b46npydFixbF39+fmjVr8sILL3DmzJnrjr9q1SpMJhMjR468bpvVamXSpEnUrl0bX19fSpUqRaNGjZgyZQppaWkAzJgxA4vFQpEiRfDz86N06dLcf//9rFq1Koc+EZH0PvroI8LCwvDy8qJhw4b8/vvvGbb97LPPaNq0KcWKFaNYsWK0bt36uvaPPvooJpMp3aN9+/Y5/TZyVuxWOLYBzG5w13+MTpMnFPX05pfen1PS1BCTycqXh8bz2oqZRscSERGRHKaireQtnkWgfCTc0droJCIikgVJSUnMnTuXwMBApk2blm6bv78/ycnJJCUlsWDBAj777LN0bSZMmEBSUhLx8fHMnTuXEydOUL9+fU6dOpXuONOmTSMwMJCvvvrKWYi9pnfv3kyfPp0PPviAs2fPEhcXx5QpU/j9999JSEhwtqtVqxbJyckkJiayc+dOWrZsSYcOHZg1a1YOfCoif5szZw5Dhw7llVdeYcuWLdSpU4d27dpx+vTpG7aPioqiV69erFixgvXr11O2bFnatm3LiRPphwxo3749J0+edD6+/fbb3Hg7OWfDVMeyRnfwCzE2Sx7i4+7Jkt5TCbU0xWSyMffo24z85bNczRAQEEBAQECunlNERKQwU9FWREQkv7l6MeNH6pUstL2ccdssmjNnDr6+vkyYMIGvv/6a1NTUG7arVasWTZs2ZceOHddtM5lMVK9enZkzZ+Ln58fEiROd2xITE/n++++ZMmUKSUlJLFy40LktKiqK+fPn8/PPP9O8eXO8vb0xm83ceeedfPXVVxQvXvyGWQIDA3n66ad5+eWXGTZsmLP3r0hOmDRpEgMHDqR///5Ur16dqVOn4uPjw/TpNx4CYNasWTz11FNERERQtWpVPv/8c2w2G8uXL0/XztPTk+DgYOejWLFiufF2ckbSKdj1g+N5wyeNzZIHebi5sbDXh4R7tMFksrMo7gOeXjA5V85tsViIiIggIiICi8WSK+cUEREp7FS0lbxlzWTY/CVcOm90EhGRvOuN0Iwfcx9J3/adShm3ndkjfdvJtf7elkXTpk2jT58+9OzZk4sXL/LzzzeeLGf79u2sWrWKevXqZXgsNzc3unXrxsqVK53rvv32W4oUKcIDDzzA/fffn66n7i+//EKDBg0IDw/Pcm6AHj16EBcXx/79+7O1v8itXL16lc2bN9O69d93EpnNZlq3bs369eszdYxLly6RmppKYGBguvVRUVGUKlWKKlWq8OSTT3Lu3DmXZs9Vm6aDLRXKNIAy9Y1Okye5WSzMe+hdqnp3BmDVuWncP+dF0qxWg5OJiIiIqxletM3K2F4A3333HVWrVsXLy4tatWqxaNGidNvtdjtjx44lJCQEb29vWrduzYEDB5zbjxw5woABAwgPD8fb25uKFSvyyiuvcPXq1XTHmTt3LhEREfj4+FC+fHneeecd171puTGbFVa8AT8/C1fijU4jIiKZtGfPHjZs2EC/fv0oUqQI3bt3T1dUTUhIICAggGLFivHggw/yzDPP8Oijj970mKVLl+b8+b+/wJs2bRq9e/fGzc2Nvn37smjRIk6ePAnA2bNnCQ1NX2iuUqUKAQEBeHt7Z1hA/ue5gHTnE3Gls2fPYrVaCQoKSrc+KCiIuLi4TB1j5MiRhIaGpiv8tm/fnq+++orly5czYcIEVq5cSYcOHbBmUMBLSUkhMTEx3SPPSEuBTX/9u9HoCWOz5HFms5k5PcbTKKAvAH9eWUjrmf8h/nLW75IQERGRvMvNyJNfG9tr6tSpNGzYkMmTJ9OuXTv2799PqVKlrmu/bt06evXqxZtvvsm9997LN998Q7du3diyZQs1a9YE4O233+aDDz7gyy+/JDw8nJdffpl27dqxZ88evLy82LdvHzabjU8++YRKlSqxa9cuBg4cyMWLF3n33XcBWLx4MX369OHDDz+kbdu27N27l4EDB+Lt7c3gwYNz9TMqVM5HgzUF3LwhIMzoNCIiedfo2Iy3mf512+rwgzdp+6/vbofszFacadOmUadOHerUqQNAv379aN++vXPsTX9/f+Lj47N0zBMnTjh7FO7cuZM//viDTz/9FIAWLVoQGhrKl19+yYsvvkiJEiWu6yV77XVYWFiGBax/ngu4rgejSF7x1ltvMXv2bKKiovDy8nKu79mzp/N5rVq1qF27NhUrViQqKopWrVpdd5w333yTV199NVcyZ1nqJajeDQ5HQbUuRqfJ88xmM591Hc5rK0KYe+Rdzpk30ebb3szp9ikVAoNufYAsslqtbNiwAYBGjRppiAQREZFcYGhP26yO7fX+++/Tvn17hg8fTrVq1Xj99depV68eU6ZMARy9bCdPnsyYMWPo2rUrtWvX5quvviI2NpZ58+YBjh4JX3zxBW3btqVChQp06dKFYcOG8eOPPzrP8/XXX9OtWzeeeOIJKlSoQKdOnRg1ahQTJkzAbrfn+OdSaJ3e41iWqgpmwzuBi4jkXR6+GT/cvbLQ1jvjtpmUmprK119/zZ9//ukcU7NPnz5YrVZmzJiRrbeXlpbG/Pnzad68OYCz12779u0JDg4mNDSU06dPO68X2rRpwx9//MGRI0eydb7vv/+e4OBgqlSpkq39RW6lRIkSWCyW6ybXO3XqFMHBwTfd99133+Wtt97i119/pXbt2jdtW6FCBUqUKMHBgzf+smbUqFEkJCQ4H8eOHcvaG8lJ3sWg07vw9EawuBudJt8Y2+JhRkZMAqs3VyyH6f5/vVh3dF+OnCs1NTXD8cpFRETE9QyrjGVnbK/169enaw/Qrl07Z/vo6Gji4uLStfH396dhw4Y3HS8sISEhXe+alJSUdL0YALy9vTl+/DhHjx7N/JuUrDm917EsVd3YHCIikmk//fQTiYmJbNmyhW3btrFt2za2b9/Oyy+/zPTp07P8Zee+ffvo168fCQkJDB06lKtXrzJz5kzeeust5/G3bdvGxo0bOXz4MKtWraJly5Z06tSJLl26sHLlSi5fvozNZmPr1q0kJSVleK4LFy7wySefMH78eN59913M+sJQcoiHhwf169dPN4nYtUnFGjdunOF+b7/9Nq+//jpLlizhzjvvvOV5jh8/zrlz5wgJCbnhdk9PT/z8/NI98hyzenBm1SN1W/JB82mY0gKxuZ1h0PJH+Xrrb0bHEhERkdtk2F8n2RnbKy4u7qbtry2zcsyDBw/y4YcfMmjQIOe6du3a8eOPP7J8+XJsNht//vmncwbra+Pn3UieHicsP3D2tK1mbA4REcm0adOm0atXL6pWrZpuBvtnn32W2NjYTBVtR44cSdGiRfH39+e+++4jODiYTZs2ERQUxLx580hNTeWpp55Kd/w6derQrVs3Pv/8cwBmz57NI488wtNPP03x4sUJCQnhiSeeYPz48bRv3955rp07d1KkSBH8/PyoUaMGv/zyCwsXLqRPnz459hmJAAwdOpTPPvuML7/8kr179/Lkk09y8eJF+vfvD0Dfvn0ZNWqUs/2ECROcX36EhYURFxdHXFwcycnJACQnJzN8+HA2bNjAkSNHWL58OV27dqVSpUq0a9fOkPeYKTf6N2H9/+DY7zfeJpnSokIt5naZhYe1HFguMmH78zy9YDI2m83oaCIiIpJNho5pa7QTJ07Qvn17HnjgAQYOHOhcP3DgQA4dOsS9995Lamoqfn5+PPfcc4wbN+6mvXDy9Dhh+YGzp62KtiIi+cW/JwS9pkSJEly+fBngpuPZRkVF3fT4Dz74IA8++OANt33//ffO525ubgwfPpzhw4dneKxHH330lhOgieSUhx56iDNnzjB27Fji4uKIiIhgyZIlzs4GMTEx6a4zP/74Y65evUqPHj3SHeeVV15h3LhxWCwWduzYwZdffkl8fDyhoaG0bduW119/HU9Pz1x9b5lmTYO3ykGRUuBXGvxCoWgQrP8I7DZ4ZgsUr2h0ynyraskyLOs5hwd+GMop23pWnZtGm1m7mXv/JIr7FDU6noiIiGSRYUXb7IztFRwcfNP215anTp1Kd1vYqVOniIiISLdfbGwsLVq0IDIy0jmxyTUmk4kJEybwxhtvEBcXR8mSJZ23s1WoUCHD9zRq1CiGDh3qfJ2YmEjZsmUzbC//kJYC5/4af03DI4iIiEgBNHjw4Awntf33Fxi3GqPZ29ubX375xUXJcknyKUi9CBeiHY9/qtBcBVsXKOZThF/7TOW5xVNYcWYap20baP1tDz5q/T6R5asaHU9ERESywLDhEbIztlfjxo3TtQdYunSps314eDjBwcHp2iQmJrJx48Z0xzxx4gTNmzenfv36fPHFFxn2nrVYLJQuXRoPDw++/fZbGjduTMmSJTN8T/linLC8yuIBL+yDvj9B0RuPwyYiIiIi+VjRYHhuO/RfDPdPg9avQoNBUKcXdHjb6HQFhtls5sNOzzKq7vtgLUqaWyyDlj/C5HXzjI4mIiIiWWDo8AhDhw6lX79+3HnnnTRo0IDJkydfN7ZX6dKlefPNNwF47rnnaNasGRMnTqRTp07Mnj2bTZs2OXvKmkwmhgwZwvjx47njjjsIDw/n5ZdfJjQ0lG7dugF/F2zLly/Pu+++y5kzZ5x5rvXUPXv2LN9//z3NmzfnypUrfPHFF3z33XesXLkyFz+dQsZkctwqV6SU0UlEREREJCeYLVAszPGQHNenTnMigufw6MJnuGI5xLQDL/Pb0VV82fU1ivkUyfLxihbVEAsiIiK5ydCibVbH9oqMjOSbb75hzJgxjB49mjvuuIN58+ZRs2ZNZ5sRI0Zw8eJFHn/8ceLj42nSpAlLlizBy8sLcPTMPXjwIAcPHqRMmTLp8vxzspQvv/ySYcOGYbfbady4MVFRUTRo0CAnPw4RERERERGXqRFUlqg+c3jk/8ZyIGUR0VeX0uLb7Yxr/F+6VW+U6eNYLBbq16+fg0lFRETk30z2zEzrLNmSmJiIv78/CQkJGirhVla+A1eTIaIPlKxsdBoRkTzhypUrREdHEx4e7vzyUUQcMvr/Q9dfmafPqnD59I/FTNn5X+yWBOx2M3f5P8TH9w7Dy93D6GgiIiKFSmavwQwb01YknW0zYe1kxwQVIiIiIiLiUo/f1YEF9/0fJUx3YTLZ2JT4LU2/foC1R/caHU1ERERuQEVbMd7Vi3DhiON5qWqGRhERERERKajKBZRk+cOfc1+ZYdhtnlyxHGbQb73p++N4klIuZ7if1Wplw4YNbNiwAavVmouJRURECi8VbcV4Z/Y5lr6lwLeEsVlERERERAows9nMq636MaPNbIraamAyp7E1aQ5NZ97L9E2/ZrjflStXuHLlSi4mFRERKdxUtBXjnf7rliz1shURKZBq1KjBggULMtV23LhxdOvW7YbboqKiCAgIcF0wEZFC7M4ylVjT7xt6lh8F1qJY3U7z3u4XaPP1E+w/E2t0PBERkUJPRVsxnrNoW93YHCIikm379++nc+fOlChRAj8/P6pWrcqECRMA2L17N/fee6/BCV1jxowZmEwmhg0blm59t27dGDdunEvPdeTIEUwmE/Hx8S497r99/fXX1KpVCz8/P4oXL06TJk34448/cvScIpI3mM1mXmremyU9FlDBoy12u4k421p6/NSD95Zv5UqqhkIQERExioq2YrzTexxL9bQVEcm3OnXqRJ06dYiJieHChQv88MMPVKhQwehYN5WWlpat/YoVK8bHH3/MsWPHXJwo961evZpnn32Wjz/+mISEBGJiYhg9ejSenp45cr7U1NQcOa6I3J7SfoHM7zWR1+76BIstENyS+Oj3n2jxbhRzNx3DarMbHVFERKTQUdFWjBcf41iqp62IyE3Z7XYupV7KtYfdnrk/0s+ePcuhQ4cYNGgQPj4+WCwWatSowQMPPABAWFgY8+bNAxw9VSMiInj99dcpVaoUQUFBTJ48OcNjv/TSS9SpU4eTJ086133++eeULVuW4sWLM2LEiHTtZ86cSbVq1QgICKBJkyZs2bLFua158+aMGDGCtm3b4uvry+LFiwkLC+Ptt9+mUaNGFC1alGbNmt2yGFuuXDnuv/9+XnnllQzbHDp0iM6dO1OyZEnKly/P+PHjsdlsAFSrVo0lS5YAsHPnTkwmE1OnTgUgISEBd3d3zp49S4MGDQAoU6YMRYoUYdasWQD8+uuv1K1bF39/f+rVq8eyZcuc53300UcZOHAgPXv2pGjRolSpUoWoqKgMc27cuJF69erRpEkTTCYTvr6+dOzYkdq1awN/D0nx4YcfEhISQnBwMK+88orzZyMmJoY2bdpQsmRJihUrRqdOnThy5Ei6PAMGDODBBx/Ez8+PqVOnsmXLFho1aoSfnx8lSpSgc+fOzvanT5+mT58+hISEEBoaypAhQ0hJSbnpfw8RcZ37ajTmsdr3A1A04BAnE64w4vsddPlwDdFnL2b694KIiIjcPjejA4gweBMkHIMiQUYnERHJ0y6nXabhNw1z7Xwbe2/Ex93nlu2KFy9OlSpV6N+/P48//jgNGzakfPnyGbbfvXs3jzzyCCdOnGDt2rW0adOGzp07U7FiRWebtLQ0Bg0axKFDh1i1ahX+/v7s37+fpKQk9uzZw4EDB4iOjubOO++kY8eONG/enFWrVvHkk0+ycOFCGjduzEcffUT79u05cOAA/v7+gKNovGDBAu666y7nhDozZ85k/vz5hISEcN999/Hyyy8zY8aMm77n1157jRo1ajBs2DCqV0//peOlS5do1aoVQ4YM4YcffiAuLo6OHTsSEhLCgAEDaNGiBStWrKB9+/b89ttvVKxYkRUrVvDEE08QFRVF9erVKVGiBL///jvh4eEcP37cOZbvwYMH6dq1K7NmzaJLly7MmzePLl26sHv3bsLDwwGYM2cOP/30E7NmzeLNN9/k0UcfTVdI/afIyEhGjx7NqFGjaNu2LXfeeSdFixZN1yYpKYktW7Zw6NAhZ5G2QoUK9OvXD5vNxtChQ2nRogVXr15lwIABDBw4kKVLlzr3//bbb/m///s/Zs+ezZUrV2jdujWdO3dm3bp1pKamsnHjRsDxpUSXLl24++67OXToEJcvX6ZHjx6MHz+e119//ab/PUTEde4uczef7foM32LRDKp+B/9bEc2fp5OYf/4kof7e2IJO07JaMCaTyeioIiIiBZp62orxTCYIKAduOXMrpoiI5CyTyURUVBR16tTh1VdfpUKFClSvXj1d4e6fSpQowQsvvIC7uzvNmzcnLCyMbdu2ObdfunSJ7t27k5CQwC+//OIsuIKjsDd+/Hi8vLyoVq0akZGRbN68GXCMzfrwww9zzz334O7uzpAhQyhWrBgLFy507t+7d28aNGiAyWTC29sbgKeeeorw8HC8vLzo06eP83g3ExYWxuOPP87o0aOv27Zw4UKKFSvGkCFD8PDwoFy5cjz33HN88803AM6iLcBvv/3Gyy+/zMqVK52vW7ZsmeF558yZQ/Pmzbnvvvtwc3OjR48eNGnShG+//dbZ5loR22Kx0L9/f44ePcq5c+dueLzIyEiWLFnCgQMHeOihhyhevDg9evTgzJkzzjY2m40JEybg4+ND1apVGTx4MF9//bXzc+jQoQNeXl74+fnx0ksvsXr1amevYoC2bdvSrl07zGYzPj4+uLu7c/ToUWJjY/H09OSee+4BYNOmTRw4cIB33nkHHx8fihcvzujRo52fm4jkjtola+Pr7ktCSjxNalxl1YgWDGhSgTSzB4fj0/jPV5vo9MEaFu44qWETREREcpB62hYkF8+B2QzexYxOIiIiOcDbzZuNvTfm6vkyKzg4mIkTJzJx4kTOnz/Pf//7X7p3705MTMx1bYOC0t9Z4evrS1JSkvP1tm3bSExMZNOmTdeNrern54ePj88N9z1+/DjNmzdP1/5aT9VrypUrd8PsNzreG2+8wRtvvAFA06ZNWbx4cbr9XnrpJSpWrMj69evTrT9y5Ai7du1y9o4FR+GzbNmygGOYht69e3PhwgXWrVvHzJkzee+999i9eze//fab85w3cvz4ccLCwtKtq1ChQrr3+O/3Azh7KHfo0MG5LTk5GYCWLVs6C8Xbt2/n0UcfTVdk9vLyolSpUs79ypcvz4kTJwA4c+YMzz33HKtXryYhIZqHQ0MAADMUSURBVAGAlJQUkpKSnMX2f3/m06dP59VXX6V+/foUK1aMwYMHM3jwYI4cOUJ8fDyBgYHOtna7HatVEyGJ5CZ3szsNgxvy27HfWHtiLYPq1OSle2tw6p6KfL76MDEbY9hzMpGnv9lChZK+PNGsIl0jQvF0sxgdXUREpEBRT9uC4uBy+F9DWPyi0UmyZtU7MOcROLTC6CQiInmeyWTCx90n1x7ZvfU1MDCQcePGcfHiRaKjo7O8f2RkJB999BFt2rRh9+7dmd6vTJky1w0DcOTIEcqUKeN8bTZn/tJn9OjRJCcnk5ycfF3BFhw9hocPH87IkSPTrS9btiz169cnPj7e+UhMTHS+l5IlS1K1alUmT55MpUqVKFq0KC1btmTOnDns27fP2fP0Rlkz8x4z0rRpU+f7uVaw/bc6derw2GOPsXPnTue6K1eucPr0aefrmJgYSpcuDcCoUaO4dOkSW7ZsITExkVWrVgGkG/fy3++jYsWKfPXVV8TFxfH5558zbNgwNm/eTNmyZSlVqlS6zy0hISHDrCKSc+4ufTcA62LXOdcF+XnxUqfqrB3Zkmdb3YGflxuHz1xkxPc7iHzzN95eso8T8ZeNiiwiIlLgqGhbUHj5w6VzsGM27FtkdJrMO7QC9v4Eyadv3VZERPKkCxcuMGbMGPbt24fVauXSpUtMmjSJwMBAqlatmq1jDhgwgDfffJNWrVqxY8eOTO3z8MMPM2vWLNauXUtaWhoffvgh586do2PHjtnKkBnPP/88Bw4cYM2aNc519957L6dOneJ///sfV65cwWq1sn///nQTgrVo0YLJkyfTokULwNHb9f3333dOMAaO4q7ZbObQoUPO/R566CGioqKYP38+aWlp/Pjjj6xatYqePXtmK/+8efP4+uuvncMhREdHM2vWLCIjI51tzGYzo0aN4vLly+zfv5+PPvqIPn36AJCYmIiPjw8BAQGcO3eOV1999Zbn/Oqrrzh16hQmk4mAgADMZjMWi4W77rqLsmXLMmbMGJKSkrDb7Rw9evSGBXMRyVmRoY5/A7af2U7S1aR024r5ejC0TWXWjWrFix2qEuLvxbmLV/lf1CGaTviNx7/axJoDZzVpmYiIyG1S0bagKHMnRD7jeL5gCFw6b2icTLHb4fQex/NS1YzNIiIi2ebh4cGJEyfo2LEj/v7+lCtXjrVr17J48WLn7fnZ8eijj/LOO+/Qpk2bdGPeZqTZ/7d35/FNVfn/x19J2nRfWFtKW1qgyFaBArJUAbU/ERgHBgcQETs6P/wiMAPyRUdBREVlURTFBUdnlEFQxxlF5SfMsAhuyFZQdlktdINSu9C9yf39EQkEChRoSQLv5+MRk9x77r2fmyPw6acn5/Tuzbx58/jjH/9IgwYN+OCDD1i2bJnLNAW1LSgoiCeeeMJlztjg4GBWrlzJqlWriIuLo0GDBtx9991kZ2c729x8880UFhY6pyXo3bs3JSUlLvPZBgQEMG3aNPr160d4eDiLFy+mZcuWfPzxx0ybNo369evz9NNP88knn9C8efNLir9+/fosXLiQtm3bEhwcTJ8+fejatStz5sxxtgkJCaFjx440b96cXr16ce+995KamgrAU089xb59+6hXrx7Jycku0y+cy8qVK+nQoQPBwcEMHDiQ559/no4dO2KxWFi6dCkZGRm0adOGsLAwBgwYwL59+y7p3kTk0kWHRNMstBk2w8aGrA3YbDY2bNjAhg0bnFOWBPv5MLp3C75+5Gbm35NEzxYNsBvw35053PO39dz8whpeWbWXw3klbr4bERER72Qy9CvQOlNYWEhYWBgFBQWEhobW/QUry+DNXpC7BxKHwJ1v1/01L0dRNsy5DkxmmJwFvv7ujkhExKOUlZVx8OBB5yJZIlfamjVrGDRoEPn5+e4O5Szn+vNxxfMvL6bPSs7nufXP8f7u9xnSaghTbnAsMgiOqVYslurnr92bU8TC73/m35uPUFxxaj7qbvH1ubNzNP0TmxDsp2VVRETk2lbTHEwjba8mvv4w6A1HEXTbR7BrqbsjOr+To2zrt1DBVkRERETEgyRHnZrXtqbjfBIiQnh6YHs2TElhzpAOJLdsgMkE6w/m8ci/fqTLMysYs2gzn/2QyYnyqroMX0RExOvp15xXm+jOkDwevnnJMU1CbA8IauDuqKp3dJfjWVMjiIiIiIh4lK6RXfEx+5BxIoOfC3++qGOD/Hy4s3M0d3aOJiO/lCVbMvh32hEOHCvmi23ZfLEtG6uPmV4JjeifGMmtbSIIC/CtozsRERHxTiraXo36PAZ7lkPTzmDx4C52zmfb1r1xiIiISLX69OnjkVMjiEjdC/QNJKlxEhuyN7Aucx3RRF/SeZqGBzD25paM6dOC7RmFLNuexbLt2RzMLWblrhxW7srBYjbRpVk9bm7dmFtaNyahcTAmk6mW70hERMS7eHBFTy6Zjx/835XgF+zuSM7PVgUWP420FRERERHxQD2jejqKtlnrGOI75LLOZTKZSIwOIzE6jIf7XseenCK+2JbN8u1Z/JRzgvUH81h/MI+Zy3bTNDyAPtc1olerRnSPb0BYoEbhiojItUdF26vV6QVbux3K8iGwvtvCqdbgN2Hga2DY3R2JiIiIiIicIblpMnPT5rIpZxODowZjMVe/ANnFMplMtI4MpXVkKBP/TyvSj5fw5Z6jfLnnKOv2Hycjv5RF69NZtD4dswnaNw2jZ4uG9GzRgK5x9Qmw1k4cIiIinkxF26td8XH4eBSU/gL3L3eMwvUknjx9g4iIiIjINaxVvVY08G9AXmkeuVW5NAttVifXiW0QSGrPOFJ7xlFaYWPdgVy+3H2M7/bnsv9YMT8eKeDHIwXMX7sfX4uJxKZhdI2rT9e4+nRuVo96QdY6iUtERMSdVDG72lWcgIzNjpG2/5kMA+a4OyKHjDSI6gSaq0pERERExCOZTWaSmybz2f7PyGiQwbDOw+r8mgFWC7e0juCW1hEAZBeU8d3+XL7dd5zv9ueSVVBGWno+aen5vPnVAQASGgfTuVk9OsaE0zE2nITGIVjM+jlDRES8m4q2V7t6zWDwW7B4CGx8G2K6w/WXNx/VZcvYDG/dArE9IHWpRtuKiIiIiHionlE9+Wz/Z3yX8R0TO0+84tePDPNncFI0g5OiMQyD9LwSNh76hU2H8thwKI8Dx4rZe/QEe4+e4IONhwEIslpIjA6jY0w9EpuG0b5pKLH1A7W4mYiIeBVVy64FrW6DXg/DV8/D53+GyERo3Np98ayZ6Xiu31wFWxERERERD9YjqgcmTOz5ZQ+5pbk0DGjotlhMJhPNGgTRrEEQv+8cDcDxE+Vs+vkXtqTns/XwL/x4pIDiChvfH8jj+wN5zmND/H1oH+Uo4LaLCqN1kxBaNArG12J21+2IiIiclypm14o+j8HhDXBwLfxzJIz60nWxsivlyGbY+18wWeCm/73y1xcREa/Srl07Zs2axW9+85taP3dcXBxz585l0KBBHnXeurzni5Genk5sbKxbYxAR96vvX5829dtQ+nMpn639jNT/k4rF4jkLgTUI9qNvu0j6tosEwGY32Hu0iK3p+fxwJJ8dmYXsziqiqKyKdQeOs+7AceexvhYTLRuH0CYyhNZNQkiICKFVRAhRYf4alSsiIm6nou21wmyBO/8Gb94EuT/BFw/D79648nGs/XWUbYe7oEGLK399ERGpMytXruSpp55iy5YtWCwWevbsybPPPktSUlKNjq+u2Lljx446itYzXOl7rqysJCQkBIvFgslkwmq1MmDAAN5++238/E4tVrpmzRqef/55Bg8ezJ133kl4eHidxSQinq9nk558t/87fsr6yd2hXJDFbKJ1ZCitI0O56wbHL54qquzsPVrEjoxCtmUUsDv710JueRW7sgrZlVUIW06dI8hqoWXjYBIiQmjZOJjmDYNo0TiY2PqBGpkrIiJXjIq215LgRjDkXfh0HHQffeWvf2STRtmKiFylPvvsM+6++25eeuklvvjiC6qqqpg/fz69evVizZo1dOnSxd0hCo6CcFVVFbm5uQQHB3PgwAE6d+7MggULeOCBB5zt2rRpw+rVq2nbtq0KtiLCLbG38B3fcbDgIIt3LWZk+5HuDumiWH3MtIsKo11UGEO7xgBgGAYZ+aXszipid3Yhu7KL2JtTxMHcYoorbPxwpIAfjhS4nMfHbCK2fiDNGwUR39AxTUNcgyCaNQgkKjxAi5+JiEit0q8JrzWx3WHM99Ckw5W/9hqNshURuRoZhsH48eN59NFHGTVqFCEhIdSrV4/HHnuMYcOGMWnSJMAxqvTkyNvQ0FD69u1LZmYmAEOGDCE9PZ3hw4cTHBzM6NGjnccsWbLEea24uDhmzJhB165dCQoKol+/fuTl5TFmzBjCw8NJSEjgu+++c7Z/8cUXSUhIICQkhBYtWvDqq6/W+L5efPFFYmNjCQkJIS4ujrfffhuAnJwchg4dSqNGjYiNjWXKlClUVVWd8zwmk4mtW7c638+dO5c+ffrU+J4vdL24uDhmz55N9+7dCQkJoXfv3hw+fLjaWNLS0mjXrh3BwY4pkpo3b05ERARlZWUu7QzDYMeOHdx55501/rxE5OrVtkFbukd1B+CFTS+w6udVbo7o8plMJqLrBZLSNoJxtyTw2t1J/Peh3ux8+nZWTuzFGyOSeCilFXd0iKJdVCiBVgtVdoMDucWs3HWUt74+yONLtnPP39Zz0+wvaT11Gbe8sIbUv2/g8SXbeHPtfr7YlsW2IwXkl1RgGIa7b1lERLyMRtpei05f/GvfKjj4FaQ8CXU5b1NZIRQc1ihbEZGr0E8//cShQ4e4++67z9p3991307dvX0pLSwF4++23WbZsGbGxsTz44IPcc889rF69mo8++qjGc8F++OGHfP7554SGhpKcnEz37t2ZOXMm8+bN4+mnn2b06NH8+OOPADRr1ozVq1cTHR3NmjVr6N+/P506dSI5OfmC9/T444+TlpZG69atycnJIScnx3lPkZGRHDx4kOPHj9O/f3+CgoKYPHnyRX1uNb3nmlzvvffe49NPP6VJkyYMHjyYqVOn8u677551rrS0NLp3dxReysrK+Nvf/kZmZiYDBgxwaRcZ6Zgbsnnz5hd1TyJy9erYqCOF5YXssu3iL1//hb8H/p3rG13v7rBqna/FTMvGIbRsHEK/xFPbDcMgu7CMA8eKOXDsBIeOl/Dz8WIOHS8h/XgJFTY7B3KLOZBbXO15A60WmoYHEBUeQNN6ATQND6BJmD+RYf40CQsgMtSfAKvnzBUsIiLup6LttawoGz4YAVWlUFUGt8+su8Ktfyg8+B1kbtEoWxGRy1RRUQGAr6+vc6EUm82GzWbDbDbj4+NTa21rIjc3F4CoqKiz9kVFRWGz2cjLc6zg/eCDD9K6dWsAZs+eTWRkJEeOHCE6OrrG13vwwQeJiXF8vbV///58/fXXDB48GIBhw4Yxffp0KioqsFqtLiNFb775Zvr27cuaNWsuWLS1WCzO0abNmjUjIiKCiIgIMjIyWL16NdnZ2QQHBxMcHMyUKVN48sknL7poWxM1vd6YMWOIj48HYMSIEcycObPa86WlpbF161b++c9/EhwcTMeOHVm5ciUtWujfZhE5P5PJxE3RN5FTmcPXWV/zp9V/4r1+7xETGuPu0K4Ik8lEk7AAmoQFkNyyocs+m91R0P05t5jDv5RwOK/01+cS0vNKyT1RTkmFjb1HT7D36IlzXiM80JfIUH8iQv2JCPWjccivz6H+NArxo1GwH41C/PD3VXFXRORaoKLttSwkEvrNhM8nwPr5YKuE/i+AuY5mzTBbIFpzGoqIXK7nnnsOgIcffpigoCAAvv32W1avXk1SUhK//e1vnW2ff/55KisrmTBhgnNu0o0bN7J8+XISExNdippz586lpKSEMWPG0Lhx4xrH07Ch44fXzMzMs0ZmZmZmYrFYqF+/PuAY+XpSREQEfn5+ZGRkXFTRNiIiwvk6MDDwrPeGYVBSUoLVamXRokXMmTOHQ4cOYbfbKSkpcRY3z6dFixYsWLCAV199lfvuu4/u3bsze/ZsysvL8ff3d7lm8+bNOXLkSI3jvxhHjhyp0fVOjowFCAoKoqio6Kxz2Ww2fvjhB77++ms6d+5cJ/GKyNXNbDIzq9cs/rjij+zK28WYVWNY2G8h4f7h7g7NrSxmE03DHaNnq1NWaSMzv5TM/DIy8kvIyC8j45dScgrLyCwoJbugjJIKG/klleSXVLI7++y/w08X4udDoxA/Ggb70SDY6ngE+dEw2EqDYD/qB1mpF2ilXpAv9QKtWjxNRMRLqWh7rev8BzD7wqdjYdPfoKIY+s8G/7DaOX/mVlgzA+78G/gF1845RUTEo7Rq1YpmzZrx/vvvM2XKFJd977//PsnJyQQEOH6Q/fnnn537jh49Snl5OU2bNgXAXMu/NExPTyc1NZXly5fTp08ffHx8GDRoUI3nFRw6dChDhw6ltLSUJ554gpEjR7J8+XLKysrIyclxFlIPHTp03qJzUFAQJSUlzvdZWVnO1xe65+jo6Iu+3rns3r2b8vJy2rVrd9HHioj4+voCEOgbyKu3vsqIL0ZwqPAQ478czxspbxDoG+jmCD2Xv6+F5o2Cad6o+p+HDMOgsKyK7IIysgpKOVpUzrGicnIKy8gpLCO7sJzconKOnSinospOUXkVReVV55yK4Uwhfj7UC7JSL9CXsEAr4QG+hAf6Eh5oJSzA1/kI9fchLNCXUH9fQgN8CbJaLuqbNyIiUrtUtBXoNALMPrBkNPz4ARxY4yjcth146ec0DEhbAF88ArZyR+G277O1FrKIyLXs5NfiT/4ADTjndj2zCPjwww+f1bZr164kJSWd1XbChAlnta0Jk8nESy+9xMiRI4mMjGTYsGFUVVXx5ptv8sEHH7Bq1akFa958800GDhxIbGwsf/nLX+jVq5ezABkREcH+/fsv6trnc+LECQzDoHHjxpjNZr744gv++9//8sADD1zw2D179pCens6NN96I1WolODgYHx8fmjZtys0338ykSZOYP38+x48f59lnnyU1NfWc50pKSmLhwoXccMMNbN++nYULF9KqVasa3fOlXO9c0tLSaNmyJf7+/uds8/vf/54mTZqwb98+SkpKWLNmjX5gFxEsFovLtDKNAxvz+q2vc++ye0k7msawpcOYedNM2jXUL4UuhclkchZOr4sMOWe7k8XdY0Xl5J5wFHbziis4fqKc3OIK8k5UcLzYse2Xkkp+KanAMHAWedPzLi4uswlC/H0J8fc59eznQ7C/D8F+px5Bpz0H+lkcr60+BFotBPpZCLT6EOhrwWzWvyciIhdDRVtx6DAMQpvA0ofg+D7IO3jp56oogf83EX543/H+uv7Qa1LtxCkiIlit1rO2WSwWLJaz57i73LY19bvf/Y5///vfTJ8+nfHjx2M2m+nRowdffvklXbt2dba7//77GT58OPv27aN79+4sWrTIuW/y5Mn8+c9/Zvr06dx99928/vrrlxwPQNu2bZkyZQq33HILNpuN3/72ty5TR5xPRUUFU6dOZefOnZjNZjp06OBc2Gvx4sWMGzeOZs2aERAQwIgRI3jkkUfOea558+aRmppKeHg4ycnJpKamsm7duhrf88Ve71zS0tK4/vrzLxq0fft2nnjiCa6//npuu+02cnJyXKZekMvz2muv8fzzz5OdnU2HDh2YN28eN9xwwznbf/TRR0ydOpVDhw6RkJDArFmz6N+/v3O/YRhMmzaNt956i/z8fJKTk3njjTdISEi4Ercj17iEegm8nvI6k9ZO4lDhIe754h7GdBzD/e3vx2LWnKt14fTibsvGF/4Wo81uUFhaSV5JBb8UVzimXyitJL+kgoJSR1E3v6SSwrIqCksrKSytpODXR5XdwG7gfA+llx2/v6+ZQKsPAb4WAqwWAq0W/H0dzwG+joffr88BVjP+Po79fr6O136+Zsd7n1PPfr9ut1rM+Pma8bNYsPqYsfqYsahILCJezmTU9DuCctEKCwsJCwujoKCA0NBQd4dTM1XlsPld6HI/WH4daXV8P1iDHHPgXkjuPvjnSDi6E0xmuHUa9Pxz3c2TKyJyFSsrK+PgwYPEx8efd3Skt4iLi2Pu3LkMGjTI3aFINU6cOMHtt9/ON998A0CvXr1YvXq1y2J1nuRcfz48Nf/68MMPuffee5k/fz7dunVj7ty5fPTRR+zZs6faOaS/++47evXqxYwZM/jNb37D4sWLmTVrFmlpabRv3x6AWbNmMWPGDBYsWEB8fDxTp05l27Zt7Ny5s0Z/Z3jqZyXepaC8gKfWPcWKn1cAkNQ4iedueo6mwU3dHJlcKsMwKKu0U1TmKOie/lxcXkVRWRUnyqs4UeZ4XVxRRXF5FcXlNufrE+U2SiuqKKm04a6Kg4/Z5CzgWi1nPPuY8bWY8bWY8LU4tvtazPhYTFh/ffa1nGrjYzHja3Y8+1hM+Jodzye3W8wm5/E+ZhMW86nXPr++tphP7nNss5hxtDObMP+6z2wyuby3nLFNRK4ONc3BVLStQ1dFIlxVAfNvhNw9UC8eYntAbHfHc3AjsNsg6NfVUzPS4O1bwbBDUGMY8g7E3eje+EVEvJiKtnIlrVu3jsWLFzNv3jxsNhu9e/d2FnA9kbcVbbt160bXrl159dVXAbDb7cTExPCnP/2JRx999Kz2w4YNo7i4mKVLlzq3de/enY4dOzJ//nwMwyAqKor//d//ZdIkxzeaCgoKiIiI4N133+Wuu+66YEye+lmJ57HZbGzbtg2AxMTEs76tYRgGn+3/jOfWP0dJVQnBvsHc3eZuooOjaRjQkMaBjWkY0JB6/vUwmzSY41pysgBcUlFFSYWNkgobpZU2SiqqKKu0ObeVVzq2l1XaKa20UVpho6zSRnmVnbJK19flVfZfHzbKK+3O7RU2u9sKxFeKxWzCYjJhNvPr86nCrtlkchSCf93ueG/CbMKljdnMqdemU69Np7UzObfz63vHuU9ve3K/6bT3JnDuN7mc/1Q7EyffO/Zx2rEml9en2ju245yyyXTGMSfbmFy2nX4Ox8az2p/2ntOOgdP2ndYeXK/haOd6Lk5ve47znTyO09pRTdszz396YxOnPotqz2E6fdsZB7tscz327GNOXevM7a7XN521/0LnOP1U556N60LHmappef4Yz3e9ekFWgv3qfsBCTXMwzxw6IZ7jl4Pg4weYHK9/OQg/LD61v8e4U3PV1o93FGzjboI7367ZyFwRERHxCD/88ANJSUmAY9Gyk/PuyuWrqKhg8+bNPPbYY85tZrOZlJQU51QZZ1q3bh0TJ0502da3b1+WLFkCwMGDB8nOziYlJcW5PywsjG7durFu3bpqi7bl5eWUl5c73xcWFl7Obck1Jj8//5z7TCYTA1sOJCkiiclfT2brsa389ce/ntXObDJj5iKKthcxsNB0MY3F+/j8+qjmd9gWIOjXx5kM53+q2V7dO6Pa5uc5tvqN1Z/jPFsvsdhsAFWXdujFXcQA7HV9IRH3+13MQ0xPufi1I+qKirZyfo2ug9FfQ1kBHN4I6esg/XvI2ARVZVCUfaptQD3481aoF3e+X5OIiMg16tChQ+4OQc5j9OjRztdbt26lY8eO7gvmKpObm4vNZiMiIsJle0REBLt37672mOzs7GrbZ2dnO/ef3HauNmeaMWMGTz311CXdg0hNxITE8M7t7/Dx3o/ZcXwHR0uOkluay9GSo+SV5WE37NgvpvJzlY+alKvIZfz4a7rM40Wk9phMnvXbCbcXba/0ggyHDh1i+vTprF69muzsbKKiorjnnnuYMmWKywIs//nPf5g2bRo7duzA39+fXr16MWfOHOLi4urss/Bo/mGQkOJ4gGPahKoy8D9jGHf9+Csfm4iIiNSqLVu2cMcdd7g7DKlljz32mMvo3cLCQmJiYtwYkVyNfMw+DL1u6FnbK+2V5JflY9SwEnuuWfxqeryIiMjFCrV61nRRbi3afvjhh0ycONFlQYa+ffued0GG4cOHuyzIMGjQIJcFGWbPns0rr7zisiBD3759nQsy7N69G7vdzptvvknLli3Zvn07o0aNori4mBdeeAFwfN1s4MCBTJw4kUWLFlFQUMBDDz3E4MGDSUtLu6KfkcfysToeIiIictU5mRNJ7WjYsCEWi4WcnByX7Tk5OURGVj+dVGRk5Hnbn3zOycmhSZMmLm3ONUraz88PPz+/S70Nkcvia/alUWAjd4chIiLiNdw6C/yLL77IqFGjuO+++2jbti3z588nMDCQv//979W2f/nll7n99tt5+OGHadOmDdOnTycpKcm5oINhGMydO5fHH3+cgQMHcv311/OPf/yDzMxM5/xft99+O++88w633XYbzZs357e//S2TJk3i448/dl5n8+bN2Gw2nnnmGVq0aEFSUhKTJk1i69atVFZW1vnnIiIiIiJXD6vVSufOnVm1apVzm91uZ9WqVfTo0aPaY3r06OHSHmDFihXO9vHx8URGRrq0KSwsZP369ec8p4iIiIh4D7cVbU8uyHD64gk1WZDh9PbgWJDhZPsLLchwLgUFBdSvX9/5vnPnzpjNZt555x1sNhsFBQUsXLiQlJQUfH19z3me8vJyCgsLXR4iIiIiIhMnTuStt95iwYIF7Nq1iwcffJDi4mLuu+8+AO69916XhcrGjx/P8uXLmTNnDrt37+bJJ59k06ZNjBs3DnAs/DRhwgSeeeYZPvvsM7Zt28a9995LVFQUgwYNcsctioiIiEgtctv0CJ6yIMO+ffuYN2+ey9cA4+Pj+e9//8vQoUP5n//5H2w2Gz169OCLL7447z1pcQcRERERqc6wYcM4duwYTzzxBNnZ2XTs2JHly5c789b09HTM5lPjKXr27MnixYt5/PHHmTx5MgkJCSxZssQ5JRjAI488QnFxMQ888AD5+fnceOONLF++HH//apZXF7lMp///KSIiInXP7QuRuVNGRga33347Q4YMYdSoUc7t2dnZjBo1itTUVIYPH05RURFPPPEEv//971mxYgUmU/VLO2pxBxERERE5l3HjxjlHyp5pzZo1Z20bMmQIQ4YMOef5TCYTTz/9NE8//XRthShSLYvFQq9evdwdhoiIyDXFbUVbdy/IkJmZyc0330zPnj3561//6rLvtddeIywsjNmzZzu3vffee8TExLB+/Xq6d+9ebXxa3EFEREREREREREQul9u+4+LOBRkyMjLo06cPnTt35p133jnrqz4lJSVnbbNYLM4YRUREriT92yNyNsMw3B2CiIiIiEidcev0CBMnTiQ1NZUuXbpwww03MHfu3LMWZGjatCkzZswAHAsy9O7dmzlz5jBgwAA++OADNm3a5Bwpe/qCDAkJCcTHxzN16lSXBRlOFmybNWvGCy+8wLFjx5zxnBypO2DAAF566SWefvpp5/QIkydPplmzZnTq1OkKfkIiInIts1qtmM1mMjMzadSoEVar9ZxT9IhcSwzD4NixY5hMpvMuEisitcNut7N9+3YA2rdvr/ltRURErgC3Fm3dsSDDihUr2LdvH/v27SM6OtolnpMjNm655RYWL17M7NmzmT17NoGBgfTo0YPly5cTEBBQ1x+LiIgI4Fj0JT4+nqysLDIzM90djohHMZlMREdHO78NJSJ1xzAM8vLynK9FRESk7pkM/atbZwoLCwkLC6OgoIDQ0FB3hyMiIl7KMAyqqqqw2WzuDkXEY/j6+lZbsFX+VXP6rKSmbDYbX3/9NQA33XSTflkiIiJyGWqag7l1pK2IiIhc2MmvgOtr4CIiIiIiItcGTUYkIiIiIiIiIiIi4kFUtBURERERERERERHxICraioiIiIiIiIiIiHgQzWlbh06u8VZYWOjmSERERESuDSfzLq21e2HKVaWmbDYbxcXFgOP/Fy1EJiIiculqmq+qaFuHioqKAIiJiXFzJCIiIiLXlqKiIsLCwtwdhkdTrioiIiLiPhfKV02GhiHUGbvdTmZmJiEhIZhMpjq9VmFhITExMRw+fJjQ0NA6vZbUPvWfd1P/eTf1n/dS33m3uuo/wzAoKioiKioKs1kzgZ3PlcxVQX9mvZ36z7up/7yX+s67qf+8m7vzVY20rUNms5no6Ogres3Q0FD9ReDF1H/eTf3n3dR/3kt9593qov80wrZm3JGrgv7Mejv1n3dT/3kv9Z13U/95N3flqxp+ICIiIiIiIiIiIuJBVLQVERERERERERER8SAq2l4l/Pz8mDZtGn5+fu4ORS6B+s+7qf+8m/rPe6nvvJv679qjPvdu6j/vpv7zXuo776b+827u7j8tRCYiIiIiIiIiIiLiQTTSVkRERERERERERMSDqGgrIiIiIiIiIiIi4kFUtBURERERERERERHxICraXiVee+014uLi8Pf3p1u3bmzYsMHdIckZZsyYQdeuXQkJCaFx48YMGjSIPXv2uLQpKytj7NixNGjQgODgYO68805ycnLcFLGcz8yZMzGZTEyYMMG5Tf3n2TIyMrjnnnto0KABAQEBJCYmsmnTJud+wzB44oknaNKkCQEBAaSkpLB37143Riwn2Ww2pk6dSnx8PAEBAbRo0YLp06dz+rT86j/P8dVXX3HHHXcQFRWFyWRiyZIlLvtr0ld5eXmMGDGC0NBQwsPD+eMf/8iJEyeu4F1IbVOu6h2Ur149lKt6J+Wr3km5qnfxplxVRdurwIcffsjEiROZNm0aaWlpdOjQgb59+3L06FF3hyanWbt2LWPHjuX7779nxYoVVFZWctttt1FcXOxs89BDD/H555/z0UcfsXbtWjIzMxk8eLAbo5bqbNy4kTfffJPrr7/eZbv6z3P98ssvJCcn4+vry7Jly9i5cydz5syhXr16zjazZ8/mlVdeYf78+axfv56goCD69u1LWVmZGyMXgFmzZvHGG2/w6quvsmvXLmbNmsXs2bOZN2+es436z3MUFxfToUMHXnvttWr316SvRowYwY4dO1ixYgVLly7lq6++4oEHHrhStyC1TLmq91C+enVQruqdlK96L+Wq3sWrclVDvN4NN9xgjB071vneZrMZUVFRxowZM9wYlVzI0aNHDcBYu3atYRiGkZ+fb/j6+hofffSRs82uXbsMwFi3bp27wpQzFBUVGQkJCcaKFSuM3r17G+PHjzcMQ/3n6f7yl78YN9544zn32+12IzIy0nj++eed2/Lz8w0/Pz/j/fffvxIhynkMGDDAuP/++122DR482BgxYoRhGOo/TwYYn3zyifN9Tfpq586dBmBs3LjR2WbZsmWGyWQyMjIyrljsUnuUq3ov5aveR7mq91K+6r2Uq3ovT89VNdLWy1VUVLB582ZSUlKc28xmMykpKaxbt86NkcmFFBQUAFC/fn0ANm/eTGVlpUtftm7dmtjYWPWlBxk7diwDBgxw6SdQ/3m6zz77jC5dujBkyBAaN25Mp06deOutt5z7Dx48SHZ2tkv/hYWF0a1bN/WfB+jZsyerVq3ip59+AuCHH37gm2++oV+/foD6z5vUpK/WrVtHeHg4Xbp0cbZJSUnBbDazfv36Kx6zXB7lqt5N+ar3Ua7qvZSvei/lqlcPT8tVfWr1bHLF5ebmYrPZiIiIcNkeERHB7t273RSVXIjdbmfChAkkJyfTvn17ALKzs7FarYSHh7u0jYiIIDs72w1Rypk++OAD0tLS2Lhx41n71H+e7cCBA7zxxhtMnDiRyZMns3HjRv785z9jtVpJTU119lF1f5eq/9zv0UcfpbCwkNatW2OxWLDZbDz77LOMGDECQP3nRWrSV9nZ2TRu3Nhlv4+PD/Xr11d/eiHlqt5L+ar3Ua7q3ZSvei/lqlcPT8tVVbQVcYOxY8eyfft2vvnmG3eHIjV0+PBhxo8fz4oVK/D393d3OHKR7HY7Xbp04bnnngOgU6dObN++nfnz55Oamurm6ORC/vnPf7Jo0SIWL15Mu3bt2Lp1KxMmTCAqKkr9JyJSR5Svehflqt5P+ar3Uq4qdUXTI3i5hg0bYrFYzlr1Mycnh8jISDdFJeczbtw4li5dypdffkl0dLRze2RkJBUVFeTn57u0V196hs2bN3P06FGSkpLw8fHBx8eHtWvX8sorr+Dj40NERIT6z4M1adKEtm3bumxr06YN6enpAM4+0t+lnunhhx/m0Ucf5a677iIxMZGRI0fy0EMPMWPGDED9501q0leRkZFnLVBVVVVFXl6e+tMLKVf1TspXvY9yVe+nfNV7KVe9enharqqirZezWq107tyZVatWObfZ7XZWrVpFjx493BiZnMkwDMaNG8cnn3zC6tWriY+Pd9nfuXNnfH19Xfpyz549pKenqy89wK233sq2bdvYunWr89GlSxdGjBjhfK3+81zJycns2bPHZdtPP/1Es2bNAIiPjycyMtKl/woLC1m/fr36zwOUlJRgNrumLBaLBbvdDqj/vElN+qpHjx7k5+ezefNmZ5vVq1djt9vp1q3bFY9ZLo9yVe+ifNV7KVf1fspXvZdy1auHx+WqtbqsmbjFBx98YPj5+RnvvvuusXPnTuOBBx4wwsPDjezsbHeHJqd58MEHjbCwMGPNmjVGVlaW81FSUuJsM3r0aCM2NtZYvXq1sWnTJqNHjx5Gjx493Bi1nM/pK/IahvrPk23YsMHw8fExnn32WWPv3r3GokWLjMDAQOO9995ztpk5c6YRHh5ufPrpp8aPP/5oDBw40IiPjzdKS0vdGLkYhmGkpqYaTZs2NZYuXWocPHjQ+Pjjj42GDRsajzzyiLON+s9zFBUVGVu2bDG2bNliAMaLL75obNmyxfj5558Nw6hZX91+++1Gp06djPXr1xvffPONkZCQYAwfPtxdtySXSbmq91C+enVRrupdlK96L+Wq3sWbclUVba8S8+bNM2JjYw2r1WrccMMNxvfff+/ukOQMQLWPd955x9mmtLTUGDNmjFGvXj0jMDDQ+N3vfmdkZWW5L2g5rzMTYfWfZ/v888+N9u3bG35+fkbr1q2Nv/71ry777Xa7MXXqVCMiIsLw8/Mzbr31VmPPnj1uilZOV1hYaIwfP96IjY01/P39jebNmxtTpkwxysvLnW3Uf57jyy+/rPbfu9TUVMMwatZXx48fN4YPH24EBwcboaGhxn333WcUFRW54W6ktihX9Q7KV68uylW9j/JV76Rc1bt4U65qMgzDqN2xuyIiIiIiIiIiIiJyqTSnrYiIiIiIiIiIiIgHUdFWRERERERERERExIOoaCsiIiIiIiIiIiLiQVS0FREREREREREREfEgKtqKiIiIiIiIiIiIeBAVbUVEREREREREREQ8iIq2IiIiIiIiIiIiIh5ERVsRERERERERERERD6KirYiI1FhcXBxz5851dxgiIiIiItVSvioiVwsVbUVEPNQf/vAHBg0aBECfPn2YMGHCFbv2u+++S3h4+FnbN27cyAMPPHDF4hARERERz6V8VUSk7vi4OwAREblyKioqsFqtl3x8o0aNajEaERERERFXyldFRBw00lZExMP94Q9/YO3atbz88suYTCZMJhOHDh0CYPv27fTr14/g4GAiIiIYOXIkubm5zmP79OnDuHHjmDBhAg0bNqRv374AvPjiiyQmJhIUFERMTAxjxozhxIkTAKxZs4b77ruPgoIC5/WefPJJ4Oyvm6WnpzNw4ECCg4MJDQ1l6NCh5OTkOPc/+eSTdOzYkYULFxIXF0dYWBh33XUXRUVFzjb/+te/SExMJCAggAYNGpCSkkJxcXEdfZoiIiIiUtuUr4qI1D4VbUVEPNzLL79Mjx49GDVqFFlZWWRlZRETE0N+fj633HILnTp1YtOmTSxfvpycnByGDh3qcvyCBQuwWq18++23zJ8/HwCz2cwrr7zCjh07WLBgAatXr+aRRx4BoGfPnsydO5fQ0FDn9SZNmnRWXHa7nYEDB5KXl8fatWtZsWIFBw4cYNiwYS7t9u/fz5IlS1i6dClLly5l7dq1zJw5E4CsrCyGDx/O/fffz65du1izZg2DBw/GMIy6+ChFREREpA4oXxURqX2aHkFExMOFhYVhtVoJDAwkMjLSuf3VV1+lU6dOPPfcc85tf//734mJieGnn36iVatWACQkJDB79myXc54+31hcXBzPPPMMo0eP5vXXX8dqtRIWFobJZHK53plWrVrFtm3bOHjwIDExMQD84x//oF27dmzcuJGuXbsCjmT53XffJSQkBICRI0eyatUqnn32WbKysqiqqmLw4ME0a9YMgMTExMv4tERERETkSlO+KiJS+zTSVkTES/3www98+eWXBAcHOx+tW7cGHKMFTurcufNZx65cuZJbb72Vpk2bEhISwsiRIzl+/DglJSU1vv6uXbuIiYlxJsAAbdu2JTw8nF27djm3xcXFORNggCZNmnD06FEAOnTowK233kpiYiJDhgzhrbfe4pdffqn5hyAiIiIiHkv5qojIpVPRVkTES504cYI77riDrVu3ujz27t1Lr169nO2CgoJcjjt06BC/+c1vuP766/n3v//N5s2bee211wDHwg+1zdfX1+W9yWTCbrcDYLFYWLFiBcuWLaNt27bMmzeP6667joMHD9Z6HCIiIiJyZSlfFRG5dCraioh4AavVis1mc9mWlJTEjh07iIuLo2XLli6PMxPf023evBm73c6cOXPo3r07rVq1IjMz84LXO1ObNm04fPgwhw8fdm7buXMn+fn5tG3btsb3ZjKZSE5O5qmnnmLLli1YrVY++eSTGh8vIiIiIu6nfFVEpHapaCsi4gXi4uJYv349hw4dIjc3F7vdztixY8nLy2P48OFs3LiR/fv385///If77rvvvAlsy5YtqaysZN68eRw4cICFCxc6F3w4/XonTpxg1apV5ObmVvs1tJSUFBITExkxYgRpaWls2LCBe++9l969e9OlS5ca3df69et57rnn2LRpE+np6Xz88cccO3aMNm3aXNwHJCIiIiJupXxVRKR2qWgrIuIFJk2ahMVioW3btjRq1Ij09HSioqL49ttvsdls3HbbbSQmJjJhwgTCw8Mxm8/913uHDh148cUXmTVrFu3bt2fRokXMmDHDpU3Pnj0ZPXo0w4YNo1GjRmctDAGOEQeffvop9erVo1evXqSkpNC8eXM+/PDDGt9XaGgoX331Ff3796dVq1Y8/vjjzJkzh379+tX8wxERERERt1O+KiJSu0yGYRjuDkJEREREREREREREHDTSVkRERERERERERMSDqGgrIiIiIiIiIiIi4kFUtBURERERERERERHxICraioiIiIiIiIiIiHgQFW1FREREREREREREPIiKtiIiIiIiIiIiIiIeREVbEREREREREREREQ+ioq2IiIiIiIiIiIiIB1HRVkRERERERERERMSDqGgrIiIiIiIiIiIi4kFUtBURERERERERERHxICraioiIiIiIiIiIiHiQ/w+RC8nEeM0IrgAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# =========================================================\n", + "# Figure 2: Cost(P) and Violation(P) convergence\n", + "# =========================================================\n", + "N1_switch = 20 # iteration where SNS switches to Newton\n", + "\n", + "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n", + "\n", + "# -- Cost convergence (left) --\n", + "ax = axes[0]\n", + "ax.plot(hist_sink[\"cost\"][:100], label=\"Sinkhorn\", linewidth=1.5)\n", + "ax.plot(hist_apd[\"cost\"][:100], label=\"APDAGD\", linewidth=1.5, linestyle=\"--\")\n", + "ax.plot(hist_sns[\"cost\"][:100], label=\"Sinkhorn-Newton-Sparse\", linewidth=1.5)\n", + "\n", + "# Optimal solution line\n", + "opt_cost = float(jnp.sum(C * round_transport(P_ref, r, c)))\n", + "ax.axhline(opt_cost, color=\"black\", linestyle=\":\", alpha=0.5, label=r\"Optimal solution $P_\\eta^\\star$\")\n", + "\n", + "ax.set_xlabel(\"Iterations\")\n", + "ax.set_ylabel(\"Cost(P)\")\n", + "ax.set_title(\"Convergence in assignment cost\")\n", + "ax.legend(fontsize=9)\n", + "\n", + "# -- Violation convergence (right) --\n", + "ax = axes[1]\n", + "ax.plot(hist_sink[\"violation\"][:100], label=\"Sinkhorn\", linewidth=1.5)\n", + "ax.plot(hist_apd[\"violation\"][:100], label=\"APDAGD\", linewidth=1.5, linestyle=\"--\")\n", + "ax.plot(hist_sns[\"violation\"][:100], label=\"SNS\", linewidth=1.5)\n", + "ax.axvline(N1_switch, color=\"gray\", linestyle=\"--\", alpha=0.5, label=\"Switch to Newton\")\n", + "\n", + "ax.set_xlabel(\"Iterations\")\n", + "ax.set_ylabel(\"Violation(P)\")\n", + "ax.set_title(\"Convergence in constraint violation\")\n", + "ax.legend(fontsize=9)\n", + "\n", + "fig.suptitle(r\"Figure 2: Random assignment problem (n=500, $\\eta$=1200)\", fontsize=14, y=1.02)\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAA3kAAAJOCAYAAAAK+M50AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAwTFJREFUeJzs3Xd8U+X+wPHPSdIm3bu0QIGyaVkCAoJAQRAQURQRUGQpekW8LhC4/lBwoYKDq15BRcCFXFFwoIJw2YKy95ZNS0v3bpOc3x9p0qZNS1raFNrv+/XKq8k5z3nOc56M5ptnKaqqqgghhBBCCCGEqBE01V0AIYQQQgghhBCVR4I8IYQQQgghhKhBJMgTQgghhBBCiBpEgjwhhBBCCCGEqEEkyBNCCCGEEEKIGkSCPCGEEEIIIYSoQSTIE0IIIYQQQogaRII8IYQQQgghhKhBJMgTQgghhBBCiBpEgjwhhBBCCCGEqEEkyBNCCCGEEEKIGkSCPHHDWbx4MYqicObMmeouihBVoia+xmfOnImiKHbbauJ1VtRbb71Fy5YtMZvN1V0UO46et9pEXqNCVK758+fToEEDcnNzq7soNZ4EeeK6Yv2H6ug2bdq06i5epTh06BDDhg2jcePGeHp6EhwcTM+ePfnpp58qJf9Tp07x2GOP0bhxYwwGA76+vnTv3p158+aRnZ1dKecQN47S3k/Fb76+vnh6epKenl5qXg8++CDu7u4kJia68Ars/fHHH8ycOZOUlJRqK0NlS0tL480332Tq1KloNPJvWVS/G+V9tmHDhlI/07Zv326XNjc3l6lTp1K3bl08PDzo0qULv//+u8N8y5P2epGRkcFLL73EgAEDCAwMRFEUFi9eXCLdjh07mDRpEtHR0Xh5edGgQQPuv/9+jh8/7jBfZ+vC2XRjx44lLy+PBQsWXPM1i7LpqrsAQjjy8ssvExkZabetdevWADz00EOMGDECvV5fHUW7ZmfPniU9PZ0xY8ZQt25dsrKy+O6777jrrrtYsGABjz76aIXzXrVqFcOGDUOv1zN69Ghat25NXl4eW7ZsYcqUKRw6dIiPP/64Eq9GVIXKfI1/8cUXdo8///xzfv/99xLb3dzcGDFiBCtWrGD06NEl8snKyuKHH35gwIABBAUFXXO5oGLX+ccffzBr1izGjh2Lv79/pZSjun322WcYjUZGjhxZ3UURxdzo/28q6kZ7n/3zn//k5ptvttvWtGlTu8djx45l+fLlPP300zRr1ozFixdzxx13sH79em699dYKp71eXLlyhZdffpkGDRrQrl07NmzY4DDdm2++ydatWxk2bBht27YlLi6ODz74gA4dOrB9+3bbdy0rZ+vC2XQGg4ExY8bwzjvv8OSTT9bqngJVThXiOrJo0SIVUHfs2FHdRXEoIyOjSvI1Go1qu3bt1BYtWlQ4j7///lv19vZWW7ZsqV66dKnE/hMnTqjvvffetRSz2lVV/dcmTzzxhOrooz8rK0v18fFR+/fv7/C4r7/+WgXUb775pkLnfemllxyet7zmzJmjAurp06evOa/rRdu2bdVRo0ZdNV11vP4r63kTN5Yb5X22fv16FVC//fbbMtP9+eefKqDOmTPHti07O1tt0qSJesstt1Q47fUkJydHjY2NVVVVVXfs2KEC6qJFi0qk27p1q5qbm2u37fjx46per1cffPBBu+3O1kV562znzp0qoK5bt65C1yqcI/1CxA2ntDESGzZsoFOnThgMBpo0acKCBQtKjCcZO3YsjRo1KpGno3En1m2HDx/mgQceICAgwO7XqIsXLzJ+/Hjq1KmDXq8nOjqazz77rELXpNVqiYiIcNg15ujRo5w7d+6qebz11ltkZGSwcOFCwsPDS+xv2rQpTz31lO3xnj17GDhwIL6+vnh7e3PbbbeV6N5irYOTJ0/aftH18/Nj3LhxZGVl2dItX74cRVHYuHFjifMuWLAARVE4ePCgbZszdXe1+nfm+Xb2XOW5VmueDz/8MHXr1kWv1xMZGcnjjz9OXl5euc/rSPHXeHnKVlEeHh7ce++9rFu3jvj4+BL7v/76a3x8fLjrrruumteWLVu4+eab7Z4bR4pfZ3p6Ok8//TSNGjVCr9cTGhpKv3792L17N2CphylTpgAQGRlp65Z15swZzp49y8SJE2nRogUeHh4EBQUxbNgwh2Oprqfn+vTp0+zfv5++ffs6LKOj17+z11re142zz5sznx1Fz3/8+HFGjRqFn58fISEhzJgxA1VVOX/+PHfffTe+vr6EhYXx9ttvX7W+nL32q72WnE3j6P+Ns589ztZ/ZdRTZX7OlfU+Kw9XfG4VlZ6ejtFodLhv+fLlaLVau54yBoOBhx9+mG3btnH+/PkKpS2PJUuWoCgKe/bs4Z///CdhYWF4eHgwePBgkpOTK5RnUXq9nrCwsKum69atG+7u7nbbmjVrRnR0NEeOHLHb7mxdlLfOOnbsSGBgID/88EO5r1M4T7priutSamoqV65csdsWHBxcavo9e/YwYMAAwsPDmTVrFiaTiZdffpmQkJBrLsuwYcNo1qwZr7/+OqqqAnD58mW6du2KoihMmjSJkJAQfv31Vx5++GHS0tJ4+umnr5pvZmYm2dnZpKam8uOPP/Lrr78yfPjwEulatWpFr169Su16YfXTTz/RuHFjunXrdtVzHzp0iB49euDr68vzzz+Pm5sbCxYsICYmho0bN9KlSxe79Pfffz+RkZHMnj2b3bt38+mnnxIaGsqbb74JwKBBg/D29ua///0vvXr1sjt22bJlREdH27qAlLfuHNW/s893RZ6nq13rpUuX6Ny5MykpKTz66KO0bNmSixcvsnz5crKysnB3d6+U14cjVyvbtXrwwQdZsmQJ//3vf5k0aZJte1JSEqtXr2bkyJF4eHiUmceBAwe4/fbbCQkJYebMmRiNRl566SXq1Klz1fP/4x//YPny5UyaNImoqCgSExPZsmULR44coUOHDtx7770cP36cpUuX8u6779o+E6z1+8cffzBixAjq16/PmTNn+Oijj4iJieHw4cN4enqWON/18Fz/8ccfAHTo0MHhfkev/x07dpTrWp153Tj7vJX3swNg+PDhtGrVijfeeINVq1bx6quvEhgYyIIFC+jTpw9vvvkmX331FZMnT+bmm2+mZ8+epdaXs9d+tdeSs2mKq8j/GmfftxWtp8r+nCvrfVYRzlx/fn4+qampTuUXGBhYYuzquHHjyMjIQKvV0qNHD+bMmUOnTp1s+/fs2UPz5s3x9fW1O65z584A7N27l4iIiHKnLY/9+/ej0WiYMGECbdq0YebMmWzYsIFly5bx73//m5deegm49rqoCFVVuXz5MtHR0Xbbna2LitRZhw4d2Lp16zWXXZShOpsRhSjO2l3T0a14mqLdSAYPHqx6enqqFy9etG07ceKEqtPp7I4dM2aM2rBhwxLnddQlybpt5MiRJdI//PDDanh4uHrlyhW77SNGjFD9/PzUrKysq17rY489Zrs2jUaj3nfffWpSUlKJdIDaq1evMvNKTU1VAfXuu+++6nlVVVWHDBmiuru7q6dOnbJtu3Tpkurj46P27NnTts1aB+PHj7c7/p577lGDgoLsto0cOVINDQ1VjUajbVtsbKyq0WjUl19+2bbN2borq/6dfb7L8zw5e62jR49WNRqNwy7FZrO53Od1pPhrvDzPw9WU1l1TVS3dhsPDw0t0r5k/f74KqKtXr75q/kOGDFENBoN69uxZ27bDhw+rWq22xHmLX6efn5/6xBNPlJl/ad3IHNXptm3bVED9/PPP7bZfT8/1//3f/6mAmp6e7rCMjl7/zl5reV43zj5vzn52FD3/o48+attmNBrV+vXrq4qiqG+88YZte3Jysurh4aGOGTOmxLVV5NqdeS05k6b4a9TZzx5Vdb7+r7WequJzrjK6a5bn9WftdunMrWiZtm7dqg4dOlRduHCh+sMPP6izZ89Wg4KCVIPBoO7evduWLjo6Wu3Tp0+JMh46dEgF1Pnz51cobXn069dPBdRly5bZbQ8PD1fvuOOOa66LosrqrunIF198oQLqwoUL7bY7WxcVqbNHH31U9fDwcKp8omKku6a4Ln344Yf8/vvvdrfSmEwm1q5dy5AhQ6hbt65te9OmTRk4cOA1l+Uf//iH3WNVVfnuu+8YPHgwqqpy5coV261///6kpqbadfcpzdNPP83vv//OkiVLGDhwICaTya4LWNHzXa0VLy0tDQAfH5+rntdkMrFmzRqGDBlC48aNbdvDw8N54IEH2LJliy0/q+J10KNHDxITE+3SDR8+nPj4eLuyLl++HLPZbGuhrEjdFT+3s893RZ+nsq7VbDazcuVKBg8ebPcrsZWiKJX2+nDEmefhWmi1WkaMGMG2bdvsumZ9/fXX1KlTh9tuu63M400mE6tXr2bIkCE0aNDAtr1Vq1b079//quf39/fnzz//5NKlS+Uue9EWxvz8fBITE2natCn+/v6l1vf18FwnJiai0+nw9vZ2qowVudarvW6cfd4q8tkB8Mgjj9jua7VaOnXqhKqqPPzww7bt/v7+tGjRgr///tthPZT32p15LZX39VbR/zXOvm8rUk9V8TlX2Zw5V7t27Ur8zy/tVrRLYrdu3Vi+fDnjx4/nrrvuYtq0aWzfvh1FUZg+fbotXXZ2tsPJcwwGg21/RdKWx/79+4mJieH++++32x4UFGTX+l7Ruqioo0eP8sQTT3DLLbcwZswYu33O1kVF6iwgIIDs7Owq6borLKS7prgude7c2eEXK0fi4+PJzs4uMZMWlJxdqyKKz/KZkJBASkoKH3/8cakzVToa01Rcy5YtadmyJQCjR4/m9ttvZ/Dgwfz555/lnm3K2kWirOnvrRISEsjKyqJFixYl9rVq1Qqz2cz58+ftum0U/eIHlg9ngOTkZNu5BwwYgJ+fH8uWLbMFA8uWLaN9+/Y0b97cdu7y1l3x+nf2+a7o81TWtWZnZ5OWllZi9rGiKuv14Ygzz8O1evDBB3n33Xf5+uuv+de//sWFCxfYvHkz//znP9FqtWUem5CQQHZ2Ns2aNSuxr0WLFvzyyy9lHv/WW28xZswYIiIi6NixI3fccQejR4+2CyhKk52dzezZs1m0aBEXL160dW0ESu36dD0/11bFX/9Q/mu92uvG2eetIp8djs7v5+eHwWAo0QXfz8/vqstzOHvtzryWyvt6q+j/GmfftxWpp6r4nKusz5LynCsgIKDEuNSKatq0KXfffTfff/89JpMJrVaLh4eHw3XZcnJyAPsfD8qT1lkJCQlcvnyZqVOnlth38eJF+vTpY3tcmXVxNXFxcQwaNAg/Pz/buLqinK2LitSZ9b0rs2tWHQnyRK1S2oeJyWQq9ZjiH07WxYpHjRpV4lcvq7Zt25a7bPfddx+PPfYYx48fd/glqiy+vr7UrVvXbnKTylTal/uiX7D0ej1DhgxhxYoV/Oc//+Hy5cts3bqV119/3ZamInVXkX+oFT0XOHetVXFeZ1xr2ZzRsWNHWrZsydKlS/nXv/7F0qVLUVWVBx98sNLOUZr777+fHj16sGLFCtasWcOcOXN48803+f7776/aKv/kk0+yaNEinn76aW655Rb8/PxQFIURI0aUusD49fBcBwUFYTQaSU9Pd9gS7+j1X95rdcXrpiyOzl/RMjl77c68lq7l9VYezl5rReqpuj7nysOZc+Xl5ZGUlORUfiEhIVf9wSkiIoK8vDwyMzPx9fUlPDycixcvlkgXGxsLYNcyW560ztq/fz9QcuztxYsXSU5Opk2bNrZtlV0XpUlNTWXgwIGkpKSwefNmh9flbF1UpM6Sk5Px9PSs8P94cXUS5IkbXmhoKAaDgZMnT5bYV3xbQECAwxksz5496/T5QkJC8PHxwWQyVeqvbdbuDM4OuC7uzjvv5OOPP2bbtm3ccsstpaYLCQnB09OTY8eOldh39OhRNBpNhQaVg6XL5pIlS1i3bh1HjhxBVVW7yWQqo+6cfb6r4nkKCQnB19e3zGC6ql4frvTggw8yY8YM9u/fz9dff02zZs1KrEHlSEhICB4eHpw4caLEPkevN0fCw8OZOHEiEydOJD4+ng4dOvDaa6/ZvnSX9kPN8uXLGTNmjN3Mgzk5ORVezNlVz7W1Nf/06dNOB/9Vca3OPG9V+dnhrPJc+9VeS86msSrP/xpXqarPG1e3rvzxxx/07t3bqbSnT592OEt2UX///TcGg8HWDbp9+/asX7+etLQ0u5bKP//807bfqjxpnWUN8oq/xw8cOFBie2XXhSM5OTkMHjyY48ePs3btWqKiohymc7YuKlJnp0+fplWrVuUuu3CejMkTNzytVkvfvn1ZuXKl3diKkydP8uuvv9qlbdKkCampqbYPXLD80rRixYpynW/o0KF89913Dr8AJiQklHm8o64z+fn5fP7553h4eJT4sHV2CYXnn38eLy8vHnnkES5fvlxi/6lTp5g3bx5arZbbb7+dH374wW7c1eXLl/n666+59dZbK9xdp2/fvgQGBrJs2TKWLVtG586d7bqbXWvdWfNw5vmujHMVp9FoGDJkCD/99BM7d+4ssV9V1So5r6tZW+1efPFF9u7d63QrnlarpX///qxcudLuNXvkyBFWr15d5rEmk6nEDxyhoaHUrVvXrhuQl5cXQIkv9VqttkQrxPvvv19mK31ZXPVcW3+QcXSO0lT2tTr7vFXlZ0d5ynq1a3fmteTs6634uZ39X+MqVfV5U9r7rKpUdByao+vbt28fP/74I7fffrtt5sn77rsPk8lk16U1NzeXRYsW0aVLF7sfJ8qT1lkHDhygfv36tq6qVvv370dRFLsuzlU9Js9kMjF8+HC2bdvGt99+W+aPws7WRUXqbPfu3U7NBi4qTlryRI0wc+ZM1qxZQ/fu3Xn88ccxmUx88MEHtG7dmr1799rSjRgxgqlTp3LPPffwz3/+k6ysLD766COaN29erskw3njjDdavX0+XLl2YMGECUVFRJCUlsXv3btauXVtmV4vHHnuMtLQ0evbsSb169YiLi+Orr77i6NGjvP322yUmYHB2CYUmTZrw9ddf26bhHj16NK1btyYvL48//viDb7/9lrFjxwLw6quv8vvvv3PrrbcyceJEdDodCxYsIDc3l7feesvpeijOzc2Ne++9l2+++YbMzEzmzp1bIs211J2Vs893ZZyruNdff501a9bQq1cvHn30UVq1akVsbCzffvstW7Zswd/fv0rO60qRkZF069bNtoZRebpqzpo1i99++40ePXowceJEjEYj77//PtHR0XY/rhSXnp5O/fr1ue+++2jXrh3e3t6sXbuWHTt22LXadOzYEYAXXniBESNG4ObmxuDBg7nzzjv54osv8PPzIyoqim3btrF27VqCgoIqWAuuea4bN25M69atWbt2LePHj3eqXFVxrc4+b1X12eEsZ67dmdeSs6+34pz97HGlqvi8Ke19Zg3+FEVx6v+Ssyo6Dm348OF4eHjQrVs3QkNDOXz4MB9//DGenp688cYbtnRdunRh2LBhTJ8+nfj4eJo2bcqSJUs4c+YMCxcutMuzPGmdrYf9+/c7bKk/cOAAjRs3ttXrtdQFwAcffEBKSortR4iffvqJCxcuAJauzn5+fjz33HP8+OOPDB48mKSkJL788ku7PEaNGmW772xdlKfOAHbt2kVSUhJ33313ha5TOKlqJ+8Uonys01U7mrK8eJriUwevW7dOvemmm1R3d3e1SZMm6qeffqo+99xzqsFgsEu3Zs0atXXr1qq7u7vaokUL9csvvyxzCYWEhASH5bh8+bL6xBNPqBEREaqbm5saFham3nbbberHH39c5jUuXbpU7du3r1qnTh1Vp9OpAQEBat++fdUffvjBYXqcWEKhqOPHj6sTJkxQGzVqpLq7u6s+Pj5q9+7d1ffff1/Nycmxpdu9e7fav39/1dvbW/X09FR79+6t/vHHH07VQWnPgaqq6u+//64CqqIo6vnz5x2W0Zm6u1r9O/t8O/s8ledaz549q44ePVoNCQlR9Xq92rhxY/WJJ55Qc3Nzy31eR0pbQqE8z0NpylpCoagPP/xQBdTOnTs7nbfVxo0b1Y4dO6ru7u5q48aN1fnz5zt8jxUtf25urjplyhS1Xbt2qo+Pj+rl5aW2a9dO/c9//lMi/1deeUWtV6+eqtFobMcnJyer48aNU4ODg1Vvb2+1f//+6tGjR9WGDRuWmJb/enquVVVV33nnHdXb29vhVPeOXv/OXmt5XzfOPm/OfHaUdf4xY8aoXl5eJdL36tVLjY6OLrWenL12Z15Lzr7eHNWVs589ztZ/ZdRTVXzOOXqfqaqqpqenq4A6YsSIEmWr6Lkqat68eWrnzp3VwMBAVafTqeHh4eqoUaPUEydOlEibnZ2tTp48WQ0LC1P1er168803q7/99pvDfJ1J62w9mEwm1cPDQ502bVqJfe3atVOHDBlSjisuW8OGDa+63EKvXr3KXJahOGfrrTz1O3XqVLVBgwa2pWhE1VBU1UUjr4WoBkOGDOHQoUMOx5qImkeeb3EjSk1NpXHjxrz11lt20+WLG0dt+uz55ZdfuPPOO9m3b5/dhCG1zbXWg9FoxMvLi2nTpjFr1qwqKOH1KTc3l0aNGjFt2jSeeuqp6i5OjSZj8kSNUXwdlhMnTvDLL78QExNTPQUSVUqeb1FT+Pn58fzzzzNnzpxSZwIV14/a/tmzfv16RowYUasDPLj2ejh27Bh5eXm1rh4XLVqEm5ubwzVAReWSljxRY4SHhzN27FgaN27M2bNn+eijj8jNzWXPnj0O138SNzZ5voUQ1UE+e0Rl+Oabbxg5ciRHjx4t97JJQjhDJl4RNcaAAQNYunQpcXFx6PV6brnlFl5//XX5p1tDyfMthKgO8tkjKsOBAwcwGAw0bdq0uosiaihpyRNCCCGEEEKIGkTG5AkhhBBCCCFEDSJBnhBCCCGEEELUIDImrwqYzWYuXbqEj48PiqJUd3GEEEIIIYQQNYCqqqSnp1O3bl00mtLb6yTIqwKXLl0iIiKiuoshhBBCCCGEqIHOnz9P/fr1S90vQV4V8PHxASyV7+vrW23lMJvNJCQkEBISUmak7womk4k//vgDgG7duqHVaqu1PJXteqrrmkzq2XWkrl1D6tl1pK5dQ+rZNaSeXed6q+u0tDQiIiJs8UZpJMirAtYumr6+vtUe5OXk5ODr61vtL0qTyYSXlxdgqZeaGORdL3Vdk0k9u47UtWtIPbuO1LVrSD27htSz61yvdX21IWHXT0mFEEIIIYQQQlwzCfKEEEIIIYQQogaRIE8IIYQQQgghahAZkydcQqPR0LVrV9t9IYQQQtxYzGYzeXl5V02Tn59PTk6O/L+vQlLPruPqunZzc6uUuSskyBMuoSgKBoOhuoshhBBCiArIy8vj9OnTmM3mMtOpqorZbCY9PV3WCq5CUs+uUx117e/vT1hY2DWdT4I8IYQQQghRKlVViY2NRavVEhERUWZrhqqqGI1GdDqdBB9VSOrZdVxZ16qqkpWVRXx8PADh4eEVzkuCPOESZrOZ06dPAxAZGSldC4QQQogbhNFoJCsri7p16+Lp6VlmWgk+XEPq2XVcXdceHh4AxMfHExoaWuGum/JNW7iEqqqcP3+e8+fPo6pqdRdHCCGEEE4ymUwAuLu7V3NJhKgdrD+m5OfnVzgPCfKEEEIIIcRVSYuREK5RGe81CfKEEEIIIYQQogaRIE8IIYQQQtRKiqKwcuVKp9MvXrwYf3//Uvdv2LABRVFISUm55rIJcS0kyBNCCCGEEDVSQkICjz/+OA0aNECv1xMWFkb//v3ZunUrALGxsQwcOLCaSylE5ZPZNYUQQgghRI00dOhQ8vLyWLJkCY0bN+by5cusW7eOxMREAMLCwqq5hM7Jz8/Hzc2tuoshbiDSkieEEEIIIWqclJQUNm/ezJtvvknv3r1p2LAhnTt3Zvr06dx1112AfXfNM2fOoCgK33//Pb1798bT05N27dqxbdu2Us+RkJBAp06duOeee8jNzbVt37VrF506dcLT05Nu3bpx7Ngxu+M++ugjmjRpgru7Oy1atOCLL76w268oCh999BF33XUXXl5evPbaa8ycOZP27dvzxRdfEBkZSXBwMCNHjiQ9Pb2SakzUJBLkCZfQaDTcfPPN3HzzzbJGnhBCCHEDU1WVrDxjtdzKswyTt7c33t7erFy50i4Au5oXXniByZMns3fvXpo3b87IkSMxGo0l0p0/f54ePXrQunVrli9fjl6vt8vj7bffZufOneh0OsaPH2/bt2LFCp566imee+45Dh48yGOPPca4ceNYv369Xf4zZ87knnvu4cCBA7bjT506xcqVK/npp59YuXIlGzdu5I033nD62kTtId01hUsoioKXl1d1F0MIIYQQ1yg730TUi6ur5dyHX+6Pp7tzX191Oh2LFy9mwoQJzJ8/nw4dOtCrVy9GjBhB27ZtSz1u8uTJDBo0CIBZs2YRHR3NyZMnadmypS3NsWPH6NevH/fccw/vvfdeiSnvX3vtNXr16gXAtGnTGDRoEDk5ORgMBubOncvYsWOZOHEiAM8++yzbt29n7ty59O7d25bHAw88wLhx4+zyNZvNLF68GG9vb4xGI6NGjWLdunW89tprTtWJqD2kSUUIIYQQQtRIQ4cO5dKlS/z4448MGDCADRs20KFDBxYvXlzqMUUDwPDwcADi4+Nt27Kzs+nRowf33nsv8+bNc7imWVl5HDlyhO7du9ul7969O0eOHLHb1qlTpxL5NmrUCB8fH7u8i5ZNCCtpyRMuYTabOXfuHAANGjSQLptCCCHEDcrDTcvhl/s73KeqKkajEZ1OVyWLp3u4act9jMFgoF+/fvTr148ZM2bwyCOP8NJLLzF27FiH6YtOcGK9BrPZbNum1+vp27cvP//8M1OmTKFevXrlzsMZjnpAFZ98RVGUcucragf5pi1cQlVVzpw5w5kzZ8rVn14IIYQQ1xdFUfB011XLrTICx6ioKDIzMyt8vEaj4YsvvqBjx4707t2bS5culev4Vq1a2ZZwsNq6dStRUVEVLpMQxUlLXm1hDayq4Fc1IYQQQojrTWJiIsOGDWP8+PG0bdsWHx8fdu7cyVtvvcXdd999TXlrtVq++uorRo4cSZ8+fdiwYYPTyzFMmTKF+++/n5tuuom+ffvy008/8f3337N27dprKpMQRUmQVwso30+AvzfAg/+Feh2ruzhCCCGEEFXO29ubLl268O6773Lq1Cny8/OJiIhgwoQJ/Otf/7rm/HU6HUuXLmX48OG2QM8ZQ4YMYd68ecydO5ennnqKyMhIFi1aRExMzDWXSQgrRZW+c5UuLS0NPz8/UlNT8fX1rbZymM1m4uPjqbPmHyh/r4c734NO4656XFUwmUxs3rwZgB49eqDVlr9P/fXMWtehoaEy3rAKST27jtS1a0g9u47UdcXl5ORw+vRpIiMjMRgMZaat6jF5wkLq2XWqo67Les85G2fIp1xtEFYww1PsvuothxBCCCGEEKLKSZBXC6jh7Sx3JMgTQgghhBCixpMgrzawtuRdPgSm/OotixBCCCGEEKJKycQrtUFgJLj7QF46XDkOdaJdXgSNRkOHDh1s94UQQgghhBBVQ75t1waKBsLaWO7H7q+eIigKvr6++Pr6ygBhIYQQQgghqpC05NUWjbqDRgt6n+ouiRBCCCGEEKIKSZBXW/T5v2o9vdls5sKFCwDUr19fumwKIYQQQghRRSTIEy6hqip///03APXq1avm0gghhBBCCFFzSXNKbZOTBsbc6i6FEEIIIYQQoopIkFebfHEPvBEBpzdXd0mEEEIIIYQQVUSCvNrE4G/5G3+4WoshhBBCCOEq27ZtQ6vVMmjQILvtZ86cQVEU2y0oKIjbb7+dPXv22NLExMTY9uv1eurVq8fgwYP5/vvvSz1fy5Yt0ev1xMXFOdy/fv167rzzTkJCQjAYDDRp0oThw4ezadMmW5oNGzbYzqvRaPDz8+Omm27i+eefJzY29hprRNQGEuTVJqGtLH/jj1RvOYQQQgghXGThwoU8+eSTbNq0iUuXLpXYv3btWmJjY1m9ejUZGRkMHDiQlJQU2/4JEyYQGxvLqVOn+O6774iKimLEiBE8+uijJfLasmUL2dnZ3HfffSxZsqTE/v/85z/cdtttBAUFsWzZMo4dO8aKFSvo1q0bzzzzTIn0x44d49KlS+zYsYOpU6eydu1aWrduzYEDB66tUkSNJ0FebWIL8qQlTwghhBA1X0ZGBsuWLePxxx9n0KBBLF68uESaoKAgwsLC6NSpE3PnzuXy5cv8+eeftv2enp6EhYVRv359unbtyptvvsmCBQv45JNPWLt2rV1eCxcu5IEHHuChhx7is88+s9t37tw5nn76aZ5++mmWLFlCnz59aNiwIW3btuWpp55i586dJcoWGhpKWFgYzZs3Z8SIEWzdupWQkBAmTpxYORUkaiwJ8mqT0CjL34RjYDZXb1mEEEIIcWPLyyz9ZsxxPm1+tnNpK+C///0vLVu2pEWLFowaNYrPPvsMVVVLTe/h4WEpQl5emfmOGTOGgIAAu26b6enpfPvtt4waNYp+/fqRmprK5s2F8yB899135Ofn8/zzzzvMU1GUq16Ph4cH//jHP9i6dSvx8fFXTS9qL1lCoTYJaAQ6AxizIeUMBDZ22ak1Gg3t27e33RdCCCHEDe71uiU2KYAboDa7HR78tnDHnKaQn+U4n4a3wrhVhY/fawNZiSXTzUwtdxEXLlzIqFGjABgwYACpqals3LiRmJiYEmlTUlJ45ZVX8Pb2pnPnzmXmq9FoaN68OWfOnLFt++abb2jWrBnR0dEAjBgxgoULF9KjRw8Ajh8/jq+vL2FhYbZjvvvuO8aMGWN7vG3bNtq0aVPmuVu2bAnA2bNnqVu35HMgBEhLXu2i0UJwc8t9F4/LUxQFf39//P39nfqlSgghhBDiWhw7doy//vqLkSNHAqDT6Rg+fDgLFy60S9etWze8vb0JCAhg3759LFu2jDp16lw1f1VV7b7TfPbZZ7aAEmDUqFF8++23pKen27YV/w7Uv39/9u7dy6pVq8jMzMRkMjl1Xkd5CVGUtOTVNq3ugro3gU/Y1dMKIYQQQpTmXyUnMVFVFaPRiM5db79jysnS81GKtTk8XTmTiixcuBCj0WjX2qWqKnq9ng8++MC2bdmyZURFRREUFIS/v79TeZtMJk6cOMHNN98MwOHDh9m+fTt//fUXU6dOtUv3zTffMGHCBJo1a0ZqaipxcXG21jxvb2+aNm2KTuf8V/IjRyw/1Dds2NDpY0TtIy15tU2vKXDXv6FeR5ee1mw2c/HiRS5evIhZxgMKIYQQNz53r9JvOoPzad08nEtbDkajkc8//5y3336bvXv32m779u2jbt26LF261JY2IiKCJk2aOB3gASxZsoTk5GSGDh0KWALKnj17sm/fPrvzPfvss7aWw/vuuw83NzfefPPNcl1LUdnZ2Xz88cf07NmTkJCQCucjaj5pyRMuoaoqJ06cALDriy6EEEIIUdl+/vlnkpOTefjhh/Hz87PbN3ToUBYuXMiAAQOcyisrK4u4uDiMRiMXLlxgxYoVvPvuuzz++OP07t2b/Px8vvjiC15++WVat25td+wjjzzCO++8w6FDh4iOjubtt9/mqaeeIikpibFjxxIZGUlSUhJffvklAFqt1u74+Ph4cnJySE9PZ9euXbz11ltcuXKF77777hpqR9QG0pJXG+XnQOx+MBmruyRCCCGEEJVu4cKF9O3bt0SAB5Ygb+fOnaSlpTmV1yeffEJ4eDhNmjTh3nvv5fDhwyxbtoz//Oc/APz4448kJiZyzz33lDi2VatWtGrVytaa9+STT7JmzRoSEhK47777aNasGXfccQenT5/mt99+KzHpSosWLahbty4dO3bkjTfeoG/fvhw8eJCoqKjyVomoZaQlr7ZRVZjbHHJT4Ym/IKRFdZdICCGEEKJS/fTTT6Xu69y5s23ykrKWUwDYsGHDVc81dOjQMidMOXzYfn3ivn370rdv3zLzjImJKbNsVyu3ENKSV9soCgQ3s9yXRdGFEEIIIYSocSTIq41CLeuruHoZBSGEEEIIIUTVkyCvNgot6MctLXlCCCGEEELUOBLk1UahrSx/449WbzmEEEIIIYQQlU4mXqmNQgqCvKRTlpk23Qxlp68EGo3GNmOURiO/LQghhBBCCFFV5Nt2beQTBgZ/UM1w5bhLTqkoCkFBQQQFBaEoikvOKYQQQgghRG0kLXm1kaLALU+A1g08g6q7NEIIIYQQQohKJEFebdXreZeezmw2Ex8fD0BoaKh02RRCCCGEEKKKSJAnXEJVVY4etUz0EhISUs2lEUIIIYQQouaS5pTaymyChGNwfE11l0QIIYQQolooisLKlSudTr948WL8/f1L3b9hwwYURSElJeWayybEtZAgr7bKSoQPO8PX90NeVnWXRgghhBCi0iUkJPD444/ToEED9Ho9YWFh9O/fn61btwIQGxvLwIEDq7mUVSsmJgZFUfjmm2/str/33ns0atSoys47duxYhgwZUmX5F7dv3z7uuusuQkNDMRgMNGrUiOHDh9uGC9U2EuTVVl4h4BEIqJB4orpLI4QQQghR6YYOHcqePXtYsmQJx48f58cffyQmJobExEQAwsLC0Ov11VzKq8vPz7+m4w0GA//3f/93zflcrxISErjtttsIDAxk9erVHDlyhEWLFlG3bl0yMzOr9Nx5eXlVmn9FSZBXWymKLIouhBBCiBorJSWFzZs38+abb9K7d28aNmxI586dmT59OnfddRdg313zzJkzKIrC999/T+/evfH09KRdu3Zs27at1HMkJCTQqVMn7rnnHnJzc23bd+3aRadOnfD09KRbt24cO3bM7riPPvqIJk2a4O7uTosWLfjiiy/s9iuKwkcffcRdd92Fl5cXr732GjNnzqR9+/Z88cUXREZGEhwczMiRI0lPT79qXYwcOZKUlBQ++eSTMtP98MMPdOjQAYPBQOPGjZk1axZGoxGAyZMnc+edd9rSvvfeeyiKwm+//Wbb1rRpUz799FNmzpzJkiVL+OGHH1AUBUVR2LBhAwAHDhygT58+eHh4EBQUxKOPPkpGRoYtD2sL4Ny5cwkPDycoKIgnnniizAB169atpKam8umnn3LTTTcRGRlJ7969effdd4mMjAQKu9KuWrWKtm3bYjAY6Nq1KwcPHrTlk5iYyMiRI6lXrx6enp60adOGpUuX2p0rJiaGSZMm8fTTTxMcHEz//v1RVZWZM2faWozr1q3LP//5T9sxubm5TJ48mXr16uHl5UWXLl1s9VFVJMirzUJaWv4mHKnecgghhBDihqGqKln5WaXeso3ZZe6/lpuqqk6X09vbG29vb1auXGkXgF3NCy+8wOTJk9m7dy/Nmzdn5MiRtkCnqPPnz9OjRw9at27N8uXL7VoEX3jhBd5++2127tyJTqdj/Pjxtn0rVqzgqaee4rnnnuPgwYM89thjjBs3jvXr19vlP3PmTO655x4OHDhgO/7UqVOsXLmSn376iZUrV7Jx40beeOONq16Tr68vL7zwAi+//HKpLVubN29m9OjRPPXUUxw+fJgFCxawePFiXnvtNQB69erFli1bMJlMAGzcuJHg4GBbsHLx4kVOnTpFTEwMkydP5v7772fAgAHExsYSGxtLt27dyMzMpH///gQEBLBjxw6+/fZb1q5dy6RJk+zKsn79ek6dOsX69etZsmQJixcvZvHixaVeX1hYGEajkRUrVlz1NTJlyhTefvttduzYQUhICIMHD7YFkDk5OXTs2JFVq1Zx8OBBHn30UUaPHs2OHTvs8liyZAnu7u5s3bqV+fPn89133/Huu++yYMECTpw4wcqVK2nTpo0t/aRJk9i2bRvffPMN+/fvZ9iwYQwYMIATJ6quN53MrlmbWYM8ackTQgghhJOyjdl0+bpLtZz7zwf+xNPN06m0Op2OxYsXM2HCBObPn0+HDh3o1asXI0aMoG3btqUeN3nyZAYNGgTArFmziI6O5uTJk7Rs2dKW5tixY/Tr14977rnH1qJV1GuvvUavXr0AmDZtGoMGDSInJweDwcDcuXMZO3YsEydOBODZZ59l+/btzJ07l969e9vyeOCBBxg3bpxdvmazmcWLF+Pt7Y3RaGTUqFGsW7fOFoiVZeLEicybN4933nmHGTNmlNg/a9Yspk2bxpgxYwBo3Lgxr7zyCs8//zwvvfQSPXr0ID09nT179tCxY0c2bdrElClTbC2hGzZsoF69ejRt2hQADw8PcnNzCQsLs51jyZIl5OTk8Pnnn+Pl5QXABx98wODBg3nzzTepU6cOAAEBAXzwwQdotVpatmzJoEGDWLduHRMmTHB4bV27duVf//oXDzzwAP/4xz/o3Lkzffr0YfTo0bY8rV566SX69etnK0/9+vVZsWIF999/P/Xq1WPy5Mm2tE8++SSrV69m+fLl3HLLLbbtzZo146233rI9XrVqFWFhYfTt2xc3NzcaNGhA586dATh37hyLFi3i3Llz1K1bF7C8xn777TcWLVrE66+/frWnrkKkJa82C7W25FV9kKfRaIiKiiIqKkrWyBNCCCGESwwdOpRLly7x448/MmDAADZs2ECHDh3KbBUqGgCGh4cD2E3ekZ2dTY8ePbj33nuZN29eiQDvankcOXKE7t2726Xv3r07R47Y96zq1KlTiXwbNWqEj4+PXd7WfL/66itb66W3tzebN2+2O1av1/Pyyy8zd+5crly5UiLvffv28fLLL9vlMWHCBGJjY8nKysLf35927dqxYcMGDhw4gLu7O48++ih79uwhIyODjRs32gLb0hw5coR27drZAjzrtZvNZrsurdHR0Wi1WofX+frrr9uV8dy5c4AlsI6Li2P+/PlER0czf/58WrZsyYEDB+zKUDRYCwwMpEWLFra6N5lMvPLKK7Rp04bAwEC8vb1ZvXq17RxWHTt2tHs8bNgwsrOzady4MRMmTGDFihW21t8DBw5gMplo3ry5Xbk3btzIqVOnyqyvayEtebVZSMGYvOQzlhk23Z37ZawiFEUhNDS0yvIXQgghhGt46Dz484E/He5TVRWTyYRWq3UY/FTGucvLYDDQr18/+vXrx4wZM3jkkUd46aWXGDt2rMP0bm5utvvWazCbzbZter2evn378vPPPzNlyhTq1atX7jycUTQQcpSvNW9rvnfddRdduhS2sDoq16hRo5g7dy6vvvpqiZk1MzIymDVrFvfee2+J4wwGA2AZj7Zhwwb0ej29evUiMDCQVq1asWXLFjZu3Mhzzz1XrmssTVnX+Y9//IP777/fts/aOgYQFBTEsGHDGDZsGK+//jo33XQTc+fOZcmSJU6dd86cOcybN4/33nuPNm3a4OXlxdNPP11icpXiz01ERATHjh1j7dq1/P7770ycOJE5c+awceNGMjIy0Gq17Nq1yy5wBUuX4qoiQV4pfv75Z5577jnMZjNTp07lkUceqe4iVT7vEOg7EwIbgyKta0IIIYS4OkVRSu0yqaoqRsWITqerkiCvMkRFRZVrbbziNBoNX3zxBQ888AC9e/dmw4YNdoHG1bRq1YqtW7faukWCZeKQqKioCpcJwMfHx66VzxGNRsPs2bO59957efzxx+32dejQgWPHjtm6WzrSq1cvPvvsM3Q6HQMGDAAsgd/SpUs5fvw4MTExtrTu7u628XtWrVq1YvHixWRmZtoCpa1bt6LRaGjRooVT1xkYGEhgYOBV07m7u9OkSZMSYxC3b99OgwYNAEhOTub48eO0atXKVpa7776bUaNGAZbA/Pjx43ZddUvj4eHB4MGDGTx4ME888YStFfGmm27CZDIRHx9Pjx49nLrGyiBBngNGo5Fnn32W9evX4+fnR8eOHbnnnnsICgqq7qJVvlufcclpVFUlISEBgJCQkOv2g18IIYQQNUNiYiLDhg1j/PjxtG3bFh8fH3bu3Mlbb73F3XfffU15a7VavvrqK0aOHEmfPn3YsGGD3dizskyZMoX777+fm266ib59+/LTTz/x/fffs3bt2msqk7MGDRpEly5dWLBggd14tRdffJE777yTBg0acN9996HRaNi3bx8HDx7k1VdfBaBnz56kp6fz888/2yZ8iYmJ4b777iM8PJzmzZvb8mvUqBGrV6/m2LFjBAUF4efnx4MPPshLL73EmDFjmDlzJgkJCTz55JM89NBDJcbOlcfPP//MN998w4gRI2jevDmqqvLTTz/xyy+/sGjRIru0L7/8MkFBQdSpU4cXXniB4OBg23p+zZo1Y/ny5fzxxx8EBATwzjvvcPny5asGeYsXL8ZkMtGlSxc8PT358ssv8fDwoGHDhgQFBfHggw8yevRo3n77bW666SYSEhJYt24dbdu2tY3/rGzSfOPAX3/9RXR0NPXq1cPb25uBAweyZs2a6i7WDc1sNnP48GEOHz5c7u4KQgghhBDl5e3tTZcuXXj33Xfp2bMnrVu3ZsaMGUyYMIEPPvjgmvPX6XQsXbqU6Oho+vTp4/Si20OGDGHevHnMnTuX6OhoFixYwKJFi+xawaram2++SU5Ojt22/v378/PPP7NmzRpuvvlmunbtyrvvvkvDhg1taQICAmjTpg0hISG2wKdnz56YzeYS4/EmTJhAixYt6NSpEyEhIWzduhVPT09Wr15NUlISN998M/fddx+33XbbNT8fUVFReHp68txzz9G+fXu6du3Kf//7Xz799FMeeughu7RvvPEGTz31FB07diQuLo6ffvoJd3d3AP7v//6PDh060L9/f2JiYggLC3NqQXd/f38++eQTunfvTtu2bVm7di0//fSTrYFo0aJFjB49mueee44WLVowZMgQduzYYWtRrAqKWp65aG8QmzZtYs6cOezatYvY2FhWrFhR4gn68MMPmTNnDnFxcbRr147333/fNgvO8uXL2bBhg+0FN2fOHBRFsZttpyxpaWn4+fmRmpqKr69vpV5beZjNZuLj4wkNDS19spPsZDj3J5hyIeraftUqi8lksg0A7tGjR4k+yTc6p+paXDOpZ9eRunYNqWfXkbquuJycHE6fPk1kZKRtbFZpVFXFaLy+u2vWBFLP5bdhwwZ69+5NcnIy/v7+Th9XHXVd1nvO2TijRn7KZWZm0q5dOz788EOH+5ctW8azzz7LSy+9xO7du2nXrh39+/d3+heYGiXuACwdDr+/VN0lEUIIIYQQQlSCGhnkDRw4kFdffZV77rnH4f533nmHCRMmMG7cOKKiopg/fz6enp589tlngGWWnosXL9rSX7x4sVwDam8oxWfYFEIIIYQQQtzQat3EK3l5eezatYvp06fbtmk0Gvr27cu2bdsA6Ny5MwcPHuTixYv4+fnx66+/Olw00io3N5fc3Fzb47S0NMDSNaQ6x5+ZzWZUVS27DB6BKB6BKNlJmBOOQXi7Ki2L9X5N61rgVF2Layb17DpS164h9ew6UtcVZ6076+1qrGlq4Iig64rUc/n06tXL9v4vb525uq6t7zVHsYSzn2G1Lsi7cuUKJpOpxAw+derU4ehRy6LgOp2Ot99+m969e2M2m3n++efLnFlz9uzZzJo1q8T2hISEEoNaXclsNpOamoqqqmWOPwj0b4J7dhJpp/4kRxteJWUxmUy2KWzj4+Nr5Jg8Z+paXBupZ9eRunYNqWfXkbquuPz8fMxmM0aj0bbAc2ms6+QBNe4H3euJ1LPrVEddG41GzGYziYmJJdYMTE9PdyqPWhfkOeuuu+7irrvucirt9OnTefbZZ22P09LSiIiIICQkpNonXlEUhZCQkDL/oSl1W0PsDvxyY/GtogXLTSaTbT2U0NDQGhnkOVPX4tpIPbuO1LVrSD27jtR1xeXk5JCeno5Op0Onc+6rY/EvpqJqSD27jivrWqfTodFoCAoKKjHxytUmP7LlURUFu54FBwej1Wq5fPmy3fbLly87vb5JcXq9Hr1eX2K7RqOp9n8kiqJcvRyhlsU3lYRjKFVYXutCk1qtttrrpSo4Vdfimkk9u47UtWtIPbuO1HXFaDQaFEWx3cqiqqotjbQwVR2pZ9epjrq2vtccfV45+/lV6z7l3N3d6dixI+vWrbNtM5vNrFu3jltuuaUaS1aNQgsWeEw4UmWn0Gg0hIWFERYWJv9chRBCCCGEqEI1siUvIyODkydP2h6fPn2avXv3EhgYSIMGDXj22WcZM2YMnTp1onPnzrz33ntkZmYybty4aix1NQprC/d+WhjsCSGEEEIIIW5YNTLI27lzJ71797Y9to6XGzNmDIsXL2b48OEkJCTw4osvEhcXR/v27fntt99KTMZSa3j4Q9thVXoKVVVJSkoCIDAwULoWCCGEEEIIUUVqZL+5mJgYu6l+rbfFixfb0kyaNImzZ8+Sm5vLn3/+SZcuXaqvwLWA2WzmwIEDHDhwQKavFkIIIcQNZ+zYsQwZMqTa8xBXl5iYSGhoKGfOnKmW8//222+0b9++Wr/z1sggT1RA8lnY9h/Ytbi6SyKEEEIIcc0SEhJ4/PHHadCgAXq9nrCwMPr378/WrVsrlN+8efPsGgxiYmJ4+umnK6ewZZg5cybt27evlLwURcFgMHD27Fm77UOGDGHs2LGVcg5HGjVqxHvvvVdl+Rf32muvcffdd9OoUSPbthUrVtC1a1f8/Pzw8fEhOjra7vlbvHgxiqIwYMAAu7xSUlJwd3dnw4YNtm0bN26kT58+BAYG4unpSbNmzRgzZgx5eXkADBgwADc3N7766quqvMwySZAnLC4fhNXT4c+Pq7skQgghhBDXbOjQoezZs4clS5Zw/PhxfvzxR2JiYkhMTKxQfn5+fvj7+1duIauBoii8+OKL1V2MKpOVlcXChQt5+OGHbdvWrVvH8OHDGTp0KH/99Re7du3itddeIz8/3+5YnU7H2rVrWb9+fan5Hz58mAEDBtCpUyc2bdrEgQMHeP/993F3d7etpweWVtt///vflX+BTpIgT1jU7WD5m3AE8rKqtyxCCCGEENcgJSWFzZs38+abb9K7d28aNmxI586dmT59um0d5MmTJ3PnnXfajnnvvfdQFIXffvvNtq1p06Z8+umngH1Xy7Fjx7Jx40bmzZtnm+7e2jXw0KFD3Hnnnfj6+uLj40OPHj04deqUXfnmzp1LeHg4QUFBPPHEEyWCDavFixcza9Ys9u3bZzuPtTXx3LlzDBkyBG9vb3x9fbn//vtLLBHmyKRJk/jyyy85ePBgqWnMZjOzZ88mMjISDw8P2rVrx/Lly237O3XqxNy5c22PhwwZgpubGxkZGQBcuHABRVE4efIkMTExnD17lmeeeabEMhzfffcd0dHR6PV6GjVqxNtvv21XjkaNGvH6668zfvx4fHx8aNCgAR9/XHaDxC+//IJer6dr1662bT/99BPdu3dnypQptGjRgubNmzNkyBA+/PBDu2O9vLwYP34806ZNKzX/NWvWEBYWxltvvUXr1q1p0qQJAwYM4JNPPsHDw8OWbvDgwezcubPEc+8qEuQJC99w8A4D1Qxx+6u7NEIIIYS4zplMplJvxccilZW2aOtHWWnLw9vbG29vb1auXElubq7DNL169WLLli22vDdu3EhwcLCtW97Fixc5deoUMTExJY6dN28et9xyCxMmTCA2NpbY2FgiIiK4ePEiPXv2RK/X87///Y9du3Yxfvx4jEaj7dj169dz6tQp1q9fz5IlS1i8eLFdN9Cihg8fznPPPUd0dLTtPMOHD8dsNjN06FCSkpLYuHEjv//+O3///TfDhw+/at10796dO++8s8xAZvbs2Xz++efMnz+fQ4cO8cwzzzBq1Cg2btxoqztrPamqyubNm/H392fLli22uqxXrx5Nmzbl+++/p379+rz88su2awDYtWsX999/PyNGjODAgQPMnDmTGTNmlKiLt99+m06dOrFnzx4mTpzI448/zrFjx0ot++bNm+nYsaPdtrCwMA4dOlRmYGs1c+ZMDhw4YBfUFs8rNjaWTZs2lZlPgwYNqFOnDps3b77qOatCjZxdU1RQvQ5w7Be4uBsadL16eiGEEELUWo6+vKqqitlsJiQkhLZt29q2b926tdRJKPz9/e3GnG3fvt1hy5ajYKs0Op2OxYsXM2HCBObPn0+HDh3o1asXI0aMsJWrR48epKens2fPHjp27MimTZuYMmUKK1euBGDDhg22QKU4Pz8/3N3d8fT0JCwszLb9ww8/xM/Pj2+++QY3NzcAmjdvbndsQEAAH3zwAVqtlpYtWzJo0CDWrVvHhAkTSpzHw8MDb29vdDqd3XnWrFnDwYMH+fvvv2nQoAEAn3/+OdHR0ezYsYObb765zPqZPXs2bdu2ZfPmzfTo0cNuX25uLq+//jpr1661rSHduHFjtmzZwoIFC+jVqxcxMTEsXLgQk8nEwYMHcXd3Z/jw4WzYsIEBAwawYcMGevXqBVhmVddqtfj4+NhdwzvvvMNtt93GjBkzbPV0+PBh5syZYzc+8I477mDixIkATJ06lXfffZf169fTokULh9d29uxZ6tata7ftySefZPPmzbRp04aGDRvStWtXbr/9dh588EH0er1d2rp16/LUU0/xwgsvOJwkZ9iwYaxevZpevXoRFhZG165due222xg9ejS+vr4l8io+/tFVpCVPFLJ22by0p3rLIYQQQghxjYYOHcqlS5f48ccfbYFHhw4dbC1F/v7+tGvXjg0bNnDgwAHc3d159NFH2bNnDxkZGWzcuNEWqDhr79699OjRwxbgORIdHY1Wq7U9Dg8PJz4+vlznOXLkCBEREURERNi2RUVF4e/vz5EjR656fFRUFKNHj3bYmnfy5EmysrLo16+frUXU29ubzz//3Nb1sGiAbK2nmJgYW+vexo0brxqUHzlyhO7du9tt6969OydOnLBruS36Y4GiKISFhZVZX9nZ2RgMBrttXl5erFq1ipMnT/J///d/eHt789xzz9G5c2eyskoOU5o6dSoJCQl89tlnJfZptVoWLVrEhQsXeOutt6hXrx6vv/66rbW1KA8PD4f5u4K05IlC9W6y/L20u9KzVhSFZs2a2e4LIYQQ4sZWvAUILC15RqOxRJBT/Mt8WYqOpbpWBoOBfv360a9fP2bMmMEjjzzCSy+9ZGspsgYmer2eXr16ERgYSKtWrdiyZQsbN27kueeeK9f5io7JKk3xulEUpVqm2p81axbNmze3tVxaWcfVrVq1inr16tnts7Z6FQ2Qt23bRr9+/ejZsyfDhw/n+PHjnDhxotwBcmnKW1/BwcEkJyc73NekSROaNGnCI488wgsvvEDz5s1ZtmwZ48aNs0vn7+/P9OnTmTVrFoMGDXKYV7169XjooYd46KGHeOWVV2jevDnz589n1qxZtjRJSUmEhIQ4e6mVSlryRKG6HUCjg6wkyEio1Kw1Gg316tWjXr16aDTyshNCCCFudFqtttRb8f/1ZaUt2qpVVtrKEBUVRWZmpu2xdVzeunXrbC1PMTExLF26lOPHj5fZGlV8NkXA1gWytIlUKsLReVq1asX58+c5f/68bdvhw4dJSUkhKirKqXwjIiKYNGkS//rXv+zyj4qKQq/Xc+7cOZo2bWp3K9py2KtXL9avX8+mTZuIiYmxBcivvfYa4eHhdt1US7uG4stZbN26lebNm1/T833TTTdx+PDhq6Zr1KgRnp6edq+Hop588kk0Gg3z5s27al4BAQGEh4fb5ZWTk8OpU6e46aabnC98JZJv26KQZyCM+g6eOQTe1fOrgxBCCCHEtUpMTKRPnz58+eWX7N+/n9OnT/Ptt9/y1ltvcffdd9vS9ezZk/T0dH7++We7IO+rr74qEagU16hRI/7880/OnDnDlStXMJvNTJo0ibS0NEaMGMHOnTs5ceIEX3zxRZkThVxNo0aNOH36NHv37uXKlSvk5ubSt29fWrduzahRo9i9ezd//fUXo0ePplevXnTq1MnpvKdPn86lS5dYu3atbZuPjw+TJ0/mmWeeYcmSJZw6dYrdu3fz/vvvs2TJElu6mJgYVq9ejU6no2XLlrZtX331VYlWvEaNGrFp0yYuXrzIlStXAHjuuedYt24dr7zyCsePH2fJkiV88MEHTJ48ucJ1BdC/f38OHTpk15o3c+ZMnn/+eTZs2MDp06fZs2cP48ePJz8/n379+jnMx2AwMGvWLN5//3277QsWLODxxx9nzZo1nDp1ikOHDjF16lQOHTrE4MGDbem2b9+OXq+3jWt0NQnyagmnuwE0jgF3z0o/v6qqpKSkkJKSgqqqlZ6/EEIIIYSVt7c3Xbp04d1336Vnz560bt2aGTNmMGHCBD744ANbuoCAANq0aUNISIgtUOnZsydms/mq3Q0nT56MVqslKiqKkJAQzp07R1BQEP/73//IyMigV69edOzYkU8++aTMMXpXM3ToUAYMGEDv3r0JCQlh6dKlKIrCd999R0BAAD179qRv3740btyYZcuWlSvvwMBApk6dSk5Ojt32V155hRkzZjB79mxatWrFgAEDWLVqFZGRkbY0PXr0KFFPMTExmEymEi2gL7/8MmfOnKFJkya27osdOnTgv//9L9988w2tW7fmxRdf5OWXX77mRdnbtGljy9uqV69e/P3334wePZqWLVsycOBA4uLiWLNmTakTuACMGTOGxo0b223r3LkzGRkZ/OMf/yA6OppevXqxfft2Vq5caVcXS5cu5cEHH8TTs/K/VztDUeUbd6VLS0vDz8+P1NTUErPsuJLZbGb/38d5c+/XnEo7yPaxy53vKmk2WxZID2979bROMJlMtlm4evToUWndLq4XZrOZ+Ph4QkNDpTtqFZJ6dh2pa9eQenYdqeuKy8nJ4fTp00RGRpaY0KI465g8nU4nY/CrkNRz2VatWsWUKVM4ePDgNb/fK1LXV65coUWLFuzcudMuMHZWWe85Z+MM+ZSr4RRFw4H0X8jWnmD+jl+cOyj1AvynCyzoAatfAHP51qYRQgghhBCiugwaNIhHH32UixcvVsv5z5w5w3/+858KBXiVRYK8Gq6Oly/NPPoA8OXhL507yLcetBlmub/tA/j9xSoqnRBCCCGEEJXv6aeftpsoxpU6derk1ML0VUmCvFpgWrdHUFWFdM0hfj+x9+oHKAr0eh7uWWB5vO1DOLutSssohBBCCCGEqBwS5NUCN9dvSrDSEYB3/lro/IHtRsBNowAVfngC8qpnMUchhBBCCCGE8yTIqyUeaz8WgPP5m9l8+uprh9jc/hr41IWkU/D7jKopnBBCCCGuezJXnxCuURnvNQnyaomR7Xrhp7ZBUUz8a+Przh/o4Q93F0w1fPgHy0LpQgghhKg1rDNi5+XlVXNJhKgdsrIsveeuZekNXWUVRlz/Xu/1AhM3PkCKso+P/vqFxzvf4dyBTW+Duz+EZrdbFkyvAEVRbOuMyFS/QgghxI1Dp9Ph6elJQkICbm5uZU5JL1P7u4bUs+u4sq5VVSUrK4v4+Hj8/f2vackxCfJqkZ6R0TT/cwAncn9h/oG5jGzTC38PL+cOvmmU/WOzGcqx7ohGo6FBgwblKK0QQgghrgeKohAeHs7p06c5e/ZsmWlVVcVsNqPRaCT4qEJSz65THXXt7+9PWFjYNeUhQV4t88EdUxmwfCtmXQKP/vQ6/73/tfJnsv9b2PEpjPkRdPrKL6QQQgghrivu7u40a9bsql02zWYziYmJBAUFyaLzVUjq2XVcXddubm7X1IJnJUFeLVPXN5CHW07h0xP/x+Gsn1h+8A7ua93d+Qxy0uC3aZB1BdbOhAGznTpMVVXS09MB8PHxkV+dhBBCiBuMRqPBYDCUmcZsNuPm5obBYJDgowpJPbvOjVrXN05JRaV5qtvd1NHcgqKovPLnDGLTU5w/2OBrGZ8HsP0/cHyNU4eZzWZ2797N7t27MZvN5S+0EEIIIYQQwikS5NVSiwa/jmLyx6xL4MHvp5VvqtYWA6DLPyz3v58AV05UTSGFEEIIIYQQ5SZBXi0V4R/M/938GqqqkMBWJv/2afky6DsL6neGnBT4ahhkJlZJOYUQQgghhBDlI0FeLXZ/m57cGvQAAKvj/sPKw386f7CbAUZ8Df4NIPk0rHgUZJFUIYQQQgghqp0EebXch4OmEKi0Q9EYeXHb85xKvOz8wd4hMHIZuHlCaCswG6uuoEIIIYQQQginSJBXy2k1Wr4aMg+NMRhVl8TIH/5Bek6O8xnUiYKn9sHtr4LWreoKKoQQQgghhHCKBHmC+r5BvBfzbzAbyNae5N5vn8VkKscMmN6hhffzsiA3vfILKYQQQgghhHCKBHkCgN5N2vBkm1moqkKceTPjVr5e/kzSLsGigbD8YTCb7HYpikKjRo1o1KiRrJEnhBBCCCFEFZIgT9g82ukOBtWdCMCejGW88Ptn5csgPQ4SjsKJ1fDXJ3a7NBqNLci7kRaSFEIIIYQQ4kYj37aFnTdv/wetvYYA8MPFeczd/L3zB9frAP0LWgD/9wqkXqz8AgohhBBCCCHKJEGeKOGLe2ZSV9sDRTGz+OQrfPzXGucP7jjOsn5eXgb8MsW2rIKqqmRmZpKZmVm+hdeFEEIIIYQQ5SJBnihBp9Xy44h5BCkdUDRG/n1oOl/u2eDcwRoNDJ4HGh0cWwW/zwBVxWw2s2PHDnbs2IHZXI5JXYQQQgghhBDlIkGecEivc+On4fPxVaNRNHm8sXcyy/Zvcu7gOlFw57uW+3+8Dzs+rbqCCiGEEEIIIexIkCdK5aP3YNXwT/Eyt0DR5PLKrmf5fPd65w7uMBr6z4awNhA1pErLKYQQQgghhCgkQZ4ok7+HN78MX4y3uSWKJpe39k3mkx2rnTv4lonwyDrwDqnaQgohhBBCCCFsJMgTVxXo6c2vIxbZum7OOziNeX/84NzBOn3h/fgjlrX0hBBCCCGEEFVGgjzhFH8Pb1aPXEwglslYPjn+Ii+u+9z5DPZ/C4d/gIMrICO+6goqhBBCCCFELSdBnnCat97Abw98TLi2O4piZsWFOTz24zvOLYnQvD94BkFeOqx5seoLK4QQQgghRC0lQZ4oFw83Pb8+8CEtPAYB8EfyIu5dNg2jyVTmcYqHHxExo4ngEsqhb+HSHlcUVwghhBBCiFpHgjxRblqNlm+HzebWwDEAnMz9hb5fPkJSVkapx2g0Gpp07k+TNl3QoMKaGbaF0oUQQgghhBCVR4I8USGKovDR4Mnc33AqqllLIjvpt3Qkhy9fKPvA22aAVg9nNsPer1xTWCGEEEIIIWoRCfLENZkRM4rpN70LJk/ydGcY/vNIfji8s0Q6VVXJyckhxxCK2mOyZeMvUyAjwcUlFkIIIYQQomaTIE9cswfb9+bjvkvQmkJBl8ILfz7G7I3/tUtjNpvZvn0727dvx9z9GWj/IAz/QtbQE0IIIYQQopJJkCcqxS0NWvLLsG/xVaNQNHl8feYVHvruZccTsmg0MOQ/0LSv6wsqhBBCCCFEDSdBnqg0dX0C+d+oL2nucQcAezO+pfcXY7mUmlz2gYmnYMXjkJ/jglIKIYQQQghRs0mQJyqVXufGd/e/yV31nkE1a0lR9jJw+X2sOb7f8QEmI3x1H+z7GlY8BmazawsshBBCCCFEDSNBnqgSr/Udz6tdPkIx+WHWxfP8tgls/PtAyYRaHQyeBxo3OLwSfn7aEvgJIYQQQgghKkSCPFFlhkTdwg9DluOjtkDR5HE4bQvfHFhLZl6ufcLInnDPfECB3Utg2YOQk1otZRZCCCGEEOJGJ0GeqFKRgWFsGLWU1l53A5BsOsGQJf/mYkq2fcI298H9n4POAMd/g3ntYNuH1VBiIYQQQgghbmwS5Ikq565z48uhswiv05wkfRKXMs9w5783s/5ovH3CqLtgzM8Q3ByykyHhWPUUWAghhBBCiBuYBHnCJTQaDe3bNCHOMw5/3zySs/IZt3gHr606TJ6xyGQrETfD49tgyEfQ6/nC7fFHYftHEH9EJmcRQgghhBCiDLrqLoCoPQIMAQC0aehG/9CGLNl2lk82n+av00m8P7IDDYI8LQm1Omj/gP3B/3sFjv5sua/zgICGYMoD1QwNb4V+s8Ar2LJfVUFRXHRVQgghhBBCXF8kyBMuoaoqvlpftGYtqTnJfHJ3a7o1Deb55fvZdyGVQf/ezOyhbbizbV1HB0OjWyEvA879CcZsSDhauD/zCtz5buHjZaMgdj/41bfcAhpBs35Qr5NlIXYhhBBCCCFqMAnyhEuYzWbiD8fTIrUFKV4pAPSPDqN1PT+eWrqHnWeTmfT1HraevMKLd0bj4a4tPFhRoOvjlpvJCClnLTc3T0vglxYLOvfC9MlnIfWc5Wa16S3wDLIEeg/+t3B7wjFLEOnuCT7hoHWr2ooQQgghhBCiikmQJ1zGQ+cBQHJusm1bPX8Pvnm0K++uPc5/Npxi6V/n2XU2mX+PvImWYb4lM9HqIKiJ5VaaB7+F1PMFt4sQtx+O/QZZiZCZYJ/2mwcg8aTlvqIFv3qWlj+vUAhqCr2nF6b9433Iy7QEglp30OoL7xv8IaBjYdqLuy3dSTVuljJr3Qvv6wzgE1aY1mQEjVa6mAohhBBCiEohQZ5wGYPOAECOKYes/Cw83Sxj8HRaDVP6t+SWxsE8vWwvxy9ncNcHW3m+fwvGd49Eoyln8OMbbrlFdC7clp8DCUfsF1pXVctfj0BLi6ApD1LOWW4A9TraB3nb50PaBYenVEJawtAfCjes+AdcKWV2UL8G8EyRheE/vQ1i91qCTK1bYTCoaME7FCZuK0z739FwcY8lIFQ0BbeC++5e8OiGwrSrJsOFHcXSWm9aGLeqMO2GN+H89sL9FMv/vkWFraV/fQLntjvIU7Hc+s8Gvbcl7cHvLF1si6exnuPWZ8DD35L25Fq4sLPI+YuV4aaHLME0WK4rdp+Dayt43HwgeAVZ0sYfhcsH7fcXvcYGXcEz0JI25TwknrDPs2hdhDQHD8vYUjITLa8HR/WlaCyBvN7HkjYvE7KSSqax3ty9QKe3pDUZwZTrIF+NdDcWQgghhFMkyBMu46ZxQ6tYumGm5KbYgjyrW5sF8+tTPZj63X7+dzSeV1cdYf2xeOYOa0e4n8c1ntwAdW+y36Yo8OQuy32zGTLjIfkMJJ2G7CTwCrFP336k5Yu6KRdM+WDMBbPR8tevnn1av/pgzrd8YTfnWwJI6313++vGXBB4qiYwmoCcwn0arX3a9Mv23VCL0hdr+Uw8aQkeHVGK5Ru3H079z3FaS+EK757/Cw4uLz1pv5cL75/eBLsWl56286NFgrz/wfYy1kZscYctyFOOr4Ytb5ee9tGNhUHesVWw7uXS0477FRp2s9w/ugp+m1p62lHfQdO+BWl/hp/+WXra+7+wLAsCcOxX+O7h0tPe8zG0G265f2K1pYW5NHe+C53GW+6f3gxLR1ieT0UpaBEuEsj3eh46jbOkjd0Py8cXBIvaYsGjFjqOgw4PWdImnUb58Z8E5BtRDB4lj2l1V2F5M+JhzYzCINR6busxDbtD9BBL2tx02PxOsfyKlL1OG2hWUL/GPNi9pCAg1tqXVdGAf0NoeIslrapanju7NEWO8wqBOlGFdXhxd5FAW2ufr7sX+BYZG5wehy3YLp6vRmf5bBFCCCGuMxLkCZdRFMXWmpeck0xd75KTrIT46Fk4phNf/XmOV1cdZuvJRPq/u4nX7mnD4HYOJmWpLJqClhefMEvLjiN9/q/Uw1WzGeKLrPv30PfOn3vcrwVBYH5BQJhvCfxUc0FrThF3/RtyMyz7VDOgFt4vnva2GdB1on0aVS08rqiuj0OrwfZpit40RT4q2g4vCJgdpFPNltlPrZrdDp7BjsugqpYv1FYRnS2T6qillLdIEKuGtkKJGlKkHrA/xlAk4PVrAJE9HVxXwWNraxtYWvTqtHZQ1oIyuBUpr5sn+NR1/FyoqqWLrpWiWLrpFj+/9Xko2lVXLfbcFFf0eTbnW1qhS5OfXeR+lqWVsjQtBhbez01HObMJfWlpg5vbpWX/N2WX1xbkZcCWd0pP22F0YZCXnwW/TC49bev7CoM8sxGWPVh62hZ3wMilhY8X3m6pO0cie8GYHwsff9gZclIdp63fGR75vfDxO1GQHus4eAyNgodXF6b9bCCknEPRaAg2Kyhu7pb3mUYL/g3sy/vjk5axxhqtJV9rOkVjGWs8+L3CtFvnWX6s0hT0BtBoC49z94SeUwrTHloJaRcL0haU1Xqc1g3a3l+Y9sIuS3d3jc7yeWnLv6AsdTsUtjSnxVpar63ntkursfxYY33Nm82FLfxCCCEqjQR5wqUcjcsrTlEURnVtyC1Ngnhm2V72X0jlyaV7+N/ReGbdHY2voYZNjlI0ILmakBbOp63X8epprBrd6nzaZn0Lv4hfTctBlpszoocUBgOlsa6R2HootB3mXL5th5Uj7f32X2wrK9/WQy234lS1ZFDXfAD865LjoFg1gbt3YdqIrvDPPZZ9ZlNBOlPhsT7hhWlDW8HYX+z3m4sEncHNCtP6R2C+9xPSUpLx9fFBgzV9wbHh7QrTegRAv1cK87XlWfC4/s2Fad08LD88FC+r2WS5hogiP7BodBA1pCBN0esrOK5OtH29RXSxL2PR4/zq26f1q2dpWbc7v7lk/UJh66D1x4Ti+4qy/jijmgGj/b78TPvH6bGQdgEFB/+IjTn2jy/sgvhDJc8P9s8xwJGfLN2ZHTH42Qd5Oz+D0xsdp9UUC/I2z4VjvzhOCzDjCrald9f8X9mt/VPPFrbg//wU7P68sGW0eGA6cVvhGOaNc2Df0mLBbpHjhi0qfK73fAWHfygMhjU6FEWDX14+iqc39P6XJZgG+HsDnFxnyUfrVpCvrvBx1BDLEACwTNYVu9/SpV6js9STNdDVullel9Yu3VlJRQJjXcn83Twt+QghRBWQTxcHzp8/z0MPPUR8fDw6nY4ZM2YwbJiTX+ZEmYq25F1NkxBvvnu8G++vO8EH60+yYs9F/jqdxJxhbenWJLiqiypE1XPUgqHVOf/Fz90TAhs7l9bgB426O5fWIwBa30dOfDy+oaFljwX0DITuZXRbtcvXHwbMdi6t3hvuX+JcWq0bPLzGubQAT+1zPu3UM5a/1oC8aGBY3ON/WAK9EkGsueRzOnIp5GdjNhlJTkwgwN8XjfUYXbEuoP1ehpyUwrzNBQGq2UHaDmMsXYrt0potf3XF2mYb97KM+y2ezmwq2VU8sLGlBd9ccF7r+a1lKdoF3M1gaXm3pTMWdksH+7ytP9yoZkuPhhKKvD8y4yHplIM0BYoef+WYpetzsZxs/QxumVS44/xf8Me/S8+3XsfCIO/E77DmhdLTPrQSmvS23D+0AlY9W3raEUuh5R2W+/v/Cz9MKggEiwWQWh0MeKOwtf30Jlg7q0jQqLUPNjtPgMYxlrTxR+Cvj4vkp7UPNpveVviDYHqcpWu5XUBaJO+Q5oWfN3mZlvHOdmUtbLVV8nKAUEtac8HrSusmLbZCuJAEeQ7odDree+892rdvT1xcHB07duSOO+7Ay8vr6gcLhxRFISwsDPdEd0h3LsgDcNNqePb2FvRqEcIzy/ZxLimLBz75k9G3NGTqgJZ46eUlLIRwAVtAXkbA61WOH59CW1n+ms3ku8dDWcG0sy3nUDiu0hk9nnM+bf/XnE9794eWW3HmgiC2aPfvgW9C35lFgsFiAadnUGHaW56wtIiXSFdw37tOYdqoIZZuxUWCbrMpn4y0VLw9DWiKtoDW72QJ+qzBqNlYMIa6YBx10TL41bN06TWbLPvMxoIu9gVlKDo2Wutm+cHEbCrshl+0m3DRejDlFYz3zgVHPYmLtu5mJcLFnaVUPvZdr1POW1psS+MRUBjkXTkBPz9detq+s+DWgv0JR+HTPg6TaQCvjhOhfsFrJvGEpdszFLbYFg1gb55QOMlZWix8MaRY66d7wUzWbpaeDp0nWNLmZsDq6QX73UumrdMGmt9uSWs2wYFvi0xuViSd1t0ybrforN0p54rMil0kXfEfP4S4jsk3ZAfCw8MJD7f8AwgLCyM4OJikpCQJ8q6BRqOhZcuW+KX5oR5Ry+yu6UjHhoH88lQPXlt1hKV/nePzbWf539F43hralm5NpVVPCCGuexoNJYJkvXfhbLxXE9DIcnNGvQ6WW1FmM1nx8XgXD6ib9LHcnBF9j+XmjA6jLbfirAFh0SAv+p6C4NHoOIAs2mIf0dXSCmhLUyQ4NedbxopaBTaGmH8Vpi0ewIa0LEzrEQAt7yxy3qJBb75912CNzjLe2ZpPkXxVsxFVU2RYRdFWXGuLbdFWV2ORscPGHEsAWZqi9ZCXaenqW5r2owqDvPxsWPFY6WmjhhT2HFBVeK9NKQmVgjG+Xxduer+j5TnSFgkerYFk3fb2P5D8/Ixlsja7YLMgOPWLgI5jCtPu/9YS9DsKYA1+EF5kMrmk04DqOIDVuMnMzLXUDRnkbdq0iTlz5rBr1y5iY2NZsWIFQ4YMsUvz4YcfMmfOHOLi4mjXrh3vv/8+nTt3dpxhGXbt2oXJZCIiIqKSSl+7+ev9Aedb8ory1uuYfW8bBrUJZ+p3+7mQnM0Dn/7JqK4NmDawFd7SqieEEOJ6Zx1zWJS7l/1EVGWxLhPkjOCmEFPGjMFFhbWGEV85lza8nf1SQEWoZjOZ8fHYriakpaXbs12LZpGbdQwjWMZfjvmpMLg15RXOVG3Ks5/0yd3TMiGaqWDCMlNeQbBZEEQ2uMW+YE36FEubX/i4aABryrd0gzblW1qY7a+uZJfT5DP2gWxRxbtpH/gOcsuYyKlokPf7DMvYXUfqtIbHNhc+/uq+wjV/iwtoZN9FfcldcOV4sVbKgqDTuw48UGQirXWvQNLfJYNGrZtl0rKYaYVpj/4CGXEFawi7W5Zd0ha5RfYoTJseZ3kOrOsN6/SFwax06a00N+S34szMTNq1a8f48eO59957S+xftmwZzz77LPPnz6dLly6899579O/fn2PHjhEaaukj3r59e4zGkm/KNWvWULeuZRbHpKQkRo8ezSeffFK1F1QLqKqK2WwmwC0A1IoFeVa3Ngtm9TM9mf3LEb768xxfbj/H+qMJvHVfW7pLq54QQghx/dBo7QO5srh5WGZDdobex34ioTLTesNDK5xLq3OH/7tsuW82F1kGqSAg1BT76vzw744DR1OefVdfsASl+Zn2aax/rRMBWTXuDVlXCtMWLUfRrqVgmcTH3adIK2mRCb2KzvQMlsmASgsefYtNUvX3htK7Bhv87YO8Pz+yjBd1ROMGL14pfPzzM6VM5KRYAr5p5wvX5v3tX5Z1dG2BY7HA8N6PC38gObDcskROiXT6whmDrWnjj0L6pWJBaZFjvMMKg/SrzXp9nbohg7yBAwcycODAUve/8847TJgwgXHjLOtDzZ8/n1WrVvHZZ58xbZrlBbl3794yz5Gbm8uQIUOYNm0a3bp1u2ra3Nxc2+O0tDQAzGYzZuug8mpgNpttwVV1M5lMbNmyhYyUDBQUknKSrqlcnm4aXrk7moGtw5j2/QEuJGfz4Kd/MuLmCKYOaIGfh2tn4Lye6romk3p2Halr15B6dh2pa9eocfWsKWjlKvq1oui1hbcv+/iiaW9+xPm0jsa12iUtUs+PFpslt2grqNlsn+/9X1i6uhYNXK0BpMbNPu0tkwpb3QoCWMUalOr0luWjCigRXQsCzSJrCVuDTo3WPq2iAa0exVT43dlCBWMOZkVrK4eSeh7lyrHS60FVC9OeXIuyb2npaVvcYVvmSfnrY5SdC0tP++QeW/dw9Y9/ozYdcd28pp0txw0Z5JUlLy+PXbt2MX36dNs2jUZD37592bZtm1N5qKrK2LFj6dOnDw89dPVB7LNnz2bWrFkltickJJCTk+PgCNcwm82kpqaiqiqaau6PbTKZyMzMRDVafg1JzEokvui6chXU1Ac+H9mC/2y9yPJ9CXyz4zyrD8XyVM/69G8RiOKiZv/rqa5rMqln15G6dg2pZ9eRunYNqWfXKFc9pxf9vuUNGm/L8FhHv4cX/W4WfAuU1UGqaNqohyHKybQxb0MMBcvcGFHMeQUBpCWYNCck2JJq2/4DbdP7wJxnCzAVc2Ha7MRU0FjWi9WHdsGtvTeKKc+2XykINBVzPqkpGaiZlu+hXoo3hqAWBfst6YqW40pyOuZ8S5k9MjJISUm5bl7T6enpTqWrcUHelStXMJlM1KlTx257nTp1OHq0jMG8RWzdupVly5bRtm1bVq5cCcAXX3xBmzaOB+JOnz6dZ58tnCY5LS2NiIgIQkJC8PUtxxpolcxsNqMoCiEhIdX+ojSZTHh5eRGgDYBMSMtPs3WdrQxvDQ/nvs5JvLDyIKcSMpn52xl+P5nOy3dFExlc9RPmXE91XZNJPbuO1LVrSD27jtS1a0g9u0atqeerfFf0sUs7prRkAIQUfTBwJjDTYToV+9jWHPMU/imZ101dGwyGqyeiBgZ5leHWW28tV5OsXq9Hr9eX2K7RaKr9xaAoynVRDlVVURTFthh6Wl4aZszoivdtvwZdmwTzy1M9+GTT37z/v5NsPZnIwH9vYWJMEx6PaYJeV7VTH18vdV3TST27jtS1a0g9u47UtWtIPbuG1LOL6L1RlKzrpq6dLUP1l7SSBQcHo9VquXz5st32y5cvExYWVk2lElYGnQEFBRWV1NJmmLoGep2WSX2aseaZnvRoFkye0cx7a08w8L3NbD155eoZCCGEEEIIcYOrcUGeu7s7HTt2ZN26dbZtZrOZdevWccstt5RxpHAFjaLBt2Cx2JTclCo7T8MgLz4f35n3R95EiI+ev69k8uCnf/LEV7s5n5RVZecVQgghhBCiut2QQV5GRgZ79+61zZB5+vRp9u7dy7lz5wB49tln+eSTT1iyZAlHjhzh8ccfJzMz0zbbpqheAe6WqZSTcpKq9DyKojC4XV3WPtuL0bc0RKPAqgOx9H1nI++sOUZWXinr2gghhBBCCHEDuyHH5O3cuZPevXvbHlsnPRkzZgyLFy9m+PDhJCQk8OKLLxIXF0f79u357bffSkzGIlzHOjgYwD/LHzKqtiWvKD8PN16+uzUjOzdg1k+H2P53Ev/+30n+u/MC0+9oyV3t6rpsFk4hhBBCCCGq2g0Z5MXExKBeZWHCSZMmMWnSJBeVSFyNRqMhOjoaAL84PwDSctNcWoZW4b4sndCV1YfieHXVES4kZ/PUN3v5fNtZZg6Opk19P5eWRwghhBBCiKpwQ3bXFDc2P70lmHJVS15RiqIwoHU4a5/txZT+LfB017LrbDJ3fbiFZ5btlfF6QgghhBDihidBnnA5a5CXmlf5s2s6y+Cm5YneTfnfczHcc1M9VBVW7LnIbW9v5JWfD5OUmVdtZRNCCCGEEOJaSJAnXMJkMrFhwwY2bNiAr5tldk1Xd9d0JMzPwLvD2/PTpFvp3jSIPJOZhVtO0+ut9Xy4/iTZeabqLqIQQgghhBDlIkGecDk/t4KWvCpYJ6+i2tT346tHuvL5+M5EhfuSnmtkzupj9JqznqV/ncNoMld3EYUQQgghhHCKBHnC5azr5FVnd83S9Gwews9P3sq8Ee2pH+BBfHou078/QN93NvLdrgsS7AkhhBBCiOueBHnC5WxB3nXUkleURqNwd/t6rHuuFy/eGUWglztnErN47tt99Ht3E9/vlmBPCCGEEEJcvyTIEy5nm3jlOg3yrPQ6LeNvjWTz872ZOqAlAZ5unL6SybP/lWBPCCGEEEJcvyTIEy7n614w8Upe9U+84gwvvY7HY5qwZWqfEsHe7e9u4rtdF8iXYE8IIYQQQlwnJMgTLmdtycs2ZpNryq3m0jjPGuxtntqH5we0IMDTjb+vZPLct/vo/fZGlu25TFaesbqLKYQQQgghajkJ8oRLKIpCYGAggYGB+Lj7oFW0wPXfZdMRb72OiTFN2VzQshfsredSSg7vbrzArW9u4J3fj8s6e0IIIYQQotpIkCdcQqPR0LZtW9q2bYtWq7V12bwRgzwrb1s3zt68enc09f30pGTn8+91J+j2xjpe+uEg55OyqruYQgghhBCiltFVdwFE7eSn9yM5N/mGDvKsDG5aHujSgN4N9exJMDN/098cvJjGkm1n+WL7WfpHhzGueyQ3NwpAUZTqLq4QQgghhKjhJMgT1eJ6XiuvorQahTvahDOobV22nkxk/sZTbDl5hV8PxvHrwTii6/oytlsjBreri8FNW93FFUIIIYQQNZQEecIlTCYTW7duBaB79+74uVsmX0nLvTFm2CwPRVG4tVkwtzYL5mhcGkv+OMP3uy9y6FIaU5bv541fj/JglwY82LUhdXwN1V1cIYQQQghRw8iYPOEyZrMZs9my1IC/3h+4scfkOaNlmC+z723L9um38fyAFoT7GUjMzOPf/ztJ9zf+x8SvdrH5RAJms1rdRRVCCCGEEDWEtOSJamFbEL0GddcsS4CXOxNjmvJoj8asPnSZRVtPs/NsMr8ciOOXA3E0CPRkROcI7utYn1Afad0TQgghhBAVJ0GeqBa2MXk1vCWvOJ1Ww6C24QxqG87hS2l8s+McK3Zf5FxSFm/9dox31hynX1QdRnZuwK1Ng9FoZKIWIYQQQghRPhLkiWphHZNX24K8oqLq+vLy3a2ZNrAlq/bH8vVf59hzLsU2UUtdPwP3dKjHPTfVp2mod3UXVwghhBBC3CAkyBPVwtZdsxYHeVae7jqGdYpgWKcIjsSm8c1f5/h+z0Uupebw4fpTfLj+FO0i/BnaoR53tq1LoJd7dRdZCCGEEEJcxyTIE9Wito3Jc1arcF9m3d2a6Xe0Yt2ReL7ffYENxxPYdz6FfedTeOXnw/RuEcqQm+rRu0UoHu6yFIMQQgghhLAnQZ5wGX9/f9t96a5ZNoOb1jZ270pGLj/uvcT3ey5w8GIaaw5fZs3hy3i6a+nbqg53tg2nZ/MQWXtPCCGEEEIAEuQJF9FqtbRv3972WLprOi/YW8/4WyMZf2skx+LS+X7PBX7eF8vFlGx+3HeJH/ddwkevo19UHe5sF86tTUNw18nqKEIIIYQQtZUEeaJaWIO8LGMW+aZ83LRu1VyiG0OLMB+mD2zFtAEt2Xs+hZ/3x7JqfyxxaTl8v+ci3++5iK9BR//oMO5oG063JkHoddLCJ4QQQghRm0iQJ6qFt5s3CgoqKql5qQR7BFd3kW4oiqJwU4MAbmoQwAt3tGLXuWRW7Y9l1YFYEtJz+XbXBb7ddQEvdy0xLULpF1WH3i1C8fOUYFoIIYQQoqaTIE+4hMlkYvv27QB07doVrVaLr96X1NxUUnJSJMi7BhqNws2NArm5USAz7ozir9NJrDpwid8PX+ZyWi6rDliCP51GoUvjQG6PCqNvVB3q+XtUd9GFEEIIIUQVkCBPuEx+fr7d4wB9AKm5qSTnJldTiWoerUbhliZB3NIkiJfvas3+i6n8fjiONYcucyI+g60nE9l6MpGXfjxEdF1f+raqQ++WobSt5ycLrwshhBBC1BAS5IlqE2AI4EzaGVJyU6q7KDWSRqPQPsKf9hH+TOnfkjNXMvn98GXWHI5j59lkDl1K49ClNOatO0Gglzu9mocQ0yKEHs1CZC0+IYQQQogbmAR5otr46/0BSM6RljxXaBTsxYSejZnQszGJGbmsOxrP+qPxbDlxhaTMPFbsuciKPRdRFGgf4U9M81BiWoTQRlr5hBBCCCFuKBLkiWoTYAgAkJa8ahDkref+ThHc3ymCfJOZXWeT2XAsgQ3H4jkal86ecynsOZfCu2uPE+jlzi1Ngri1aTC3Ng0mItCzuosvhBBCCCHKIEGeqDbSknd9cNNq6No4iK6Ng5g2sCWxqdm2gG/ryUSSMvMsM3fujwWgQaAn3ZsG0b1pMN2aBEvXTiGEEEKI64wEeaLaBBoCAWTiletMuJ8HIzs3YGTnBuQZzey7kMKWE1fYevIKe8+ncC4pi3N/ZbH0r/MARNf1pXvTYLo3DebmRgF4usvHihBCCCFEdZJvY8JlfHx87B5bW/JSclJcXxjhFHedxrY8wzP9mpORa+Sv04lsOZHIH6eucDQu3TaBy8eb/kanUWhT34/OkYF0iQykY8NA/DxkbT4hhBBCCFeSIE+4hFarpWPHjnbbrGPypCXvxuGt19GnZR36tKwDQHx6DttOJbL15BW2nkzkYkq2bTzfgo1/oyjQMsyXLpGBdI60BIshPvpqvgohhBBCiJpNgjxRbWRM3o0v1MfA3e3rcXf7egBcSM7ir9NJttvfVzI5EpvGkdg0Fv9xBoDGIV62Vr4ODfyJDPZCUWT2TiGEEEKIyiJBnqg2AXqZXbOmqR/gSf0AT+7tUB+wtPTtOJ3MX6cT+fN0Ekfj0vk7IZO/EzJtY/r8Pd24KcKfDg0C6NAwgHYR/njr5aNJCCGEEKKi5JuUcAmTycSOHTsAuPnmm9FqtbbumtnGbLKN2XjoPKqziKIKhPoYGNQ2nEFtwwFIycpj55lk/jqTxO6zyey/mEpKVj7rjyWw/lgCABoFmtfxoUPDAEvw1zCAxtLaJ4QQQgjhNAnyhMvk5OTYPfZy80Kn0WE0G0nNTZUgrxbw93Snb1Qd+kZZxvTlGc0cjk1j99lkdp9LZs+5FC6mZHM0Lp2jcel8/ee5guMsrX1t6vnRwFulh4cfdfzk9SKEEEII4YgEeaLaKIpCgD6AhOwEknOSCfMKq+4iCRdz12loH+FP+wh/xhMJwOW0HPacS2b3uZRSW/v48RThfgba1vejbX1/2tb3o009P/w9Zc0+IYQQQggJ8kS18jf424I8IQDq+BoY0DqcAa0tXTytrX17zyWz/0Iqu88mcjY5h9hUy231ocu2YxsGedKmnp8t+Iuu64uPQZZwEEIIIUTtIkGeqFbWyVdkGQVRmqKtfWazmfj4eDz9AjkSm87+C6nsv5jKgQspnEnM4mzB7ef9sbbjGwZ5El3Xl6hwX6Lq+hIV7kcdX72M8RNCCCFEjSVBnqhW1slXZIZNUR7eeh1dGgfRpXGQbVtqVj4HLqay/2IK+8+nsv9CCpdSc2yB3y8H4mxpg7zcCwI+X9vfyGAvdFpNdVyOEEIIIUSlkiBPVCtZK09UFj9PN25tFsytzYJt25Iy8zgSm8bhS2kcjk3j0KVUTiVkkpiZx+YTV9h84ootrV6noWV4YeDXMsyH5nV88POQ7p5CCCGEuLGUK8j74YcfuPvuu8nMzMTLy6uqyiRqKE9PzxLbpCVPVKVAL3e6Nw2me9PCwC8n38Txy+kcvpTGoYLg70hsGll5JvadT2Hf+RS7PML9DLQI86FFHR9aFAR+TUO9MbhpXXw1QgghhBDOcTrI27RpE88//zyvv/46q1atkiBPlItWq6Vz584ltltb8pJyklxcIlFbGdy0BTNy+tu2mc0qZ5OyOHQp1dbqdzwunUuphRO8bLDO7IllLb9GwV62wM/6t2GQF1qNjPUTQgghRPVyOsgLDw/Hw8MDf39/srOzHaZZvnw533zzDYGBgbRu3Zq2bdvSrl07AgICKq3AomaxTrwiLXmiOmk0CpHBXkQGe3Fn27q27Wk5+RyPS+fY5XSOxRXcLqeTkpXP3wmZ/J2Qya8HC8f66XUamoR40zTUcmtW8LdhkBfuOhnvJ4QQQgjXcDrIa9asGf/+97/p2bMnZrPZYZrnnnuOzz77DI1Gw8GDB/nyyy95/vnnSU5Opnnz5qxatarSCi5qBmt3TRmTJ65HvgY3OjUKpFOjQNs2VVVJSM/laFw6xy+n2/4ev5xOTr5luYfDsWl2+Wg1Cg2DPGlaJABsGupNkxBvvPQyNFoIIYQQlatc3y569uwJgEbj+Bfppk2bcttttwHQu3dv23az2czx48crWkZRA5hMJnbt2gVAx44d0Wot45kkyBM3GkVRCPU1EOproGfzENt2k1nlXFIWJ+MzOBmfwYn4dE7FZ3AqIZOMXKOt5W/N4ct2+dX1M9CkSODXNMSbyBAvQrxlmQchhBBCVEyl/oQcGRnJ008/zYwZMwgKKpzaXKPR0LJly8o8lbgBZWVlldhmHZOXmpeKqqrypVbcsLRFunz2i6pj266qKnFpObbgz3o7lZDBlYw8LqXmcCk1x26mT7AsE9Eo2JPIYO+CfAvvy4yfQgghhChLpQZ5DRs25MCBA/Ts2ROTyUTLli1p3bo1r776amWeRtQgvu6+ABjNRrKN2Xi6lZyBU4gbmaIohPt5EO7nQY9mIXb7UrLy7IO/BEvwdzE5m4xcIwcvpnHwYlqJPAO93G0BZdFboyAvPNxl1k8hhBCitqvUIG/GjBm2+7m5uRw5coQDBw5U5ilEDeOh88BN40a+OZ/U3FQJ8kSt4u/pXmLMH0Cu0cT5pCz+TsjkTGImp69k2u5fTsslKTOPpMw8dp0t2c053M9gCfiCvWgY6EmDQE8aBFn++hikBVAIIYSoDcod5J09e5b9+/dTp04dh1PiW+n1etq3b0/79u2vpXyihlMUBT+9H1eyr5Cal0o44dVdJCGqnV6npWmoD01DfUrsy8g1cuZKQfCXYAkATydagsDU7Hzbkg9/nEoscWyglzsRBYFf8QAwzNeARpZ/EEIIIWqEcgV5S5cuZezYseTn56MoCjfddBO//vorISEhVz9YiFL4uRcEebmp1V0UIa573nodrev50bqeX4l9yZl5nC4I/s4kZnIuKYuziVmcT8oisaD1Lykzr8SC7wDuWg31Az1swV/9AA/8tPm0MXvQKNhbuoEKIYQQN5ByBXmzZs3igQceYPr06Zw/f56pU6cybdo0Fi5cWFXlE7WAn97yZVWCPCGuTYCXOwFe7nRoUHJt0vScfM4nZXMuKYtzSfYB4IXkbPJMZtsMoPZOARDio6d+gAf1/D2oH+BJvQAP6gd4EBHgQT1/TwkChRBCiOtIuYK8v//+m99++41GjRrRvHlzvvzySzp27ChBnnCKwWBwuN1Xb5l8JS2v5AQTQojK4WNwI6quG1F1fUvsM5rMxKbmFASAluDvXGImf8encSktj7QcIwnpuSSk57LnXIrD/IO83C1BYEBBEOhvCQKtAaG3rAcohBBCuEy5/usajUY8PQsnxmjZsiVms5m4uDjCwsIqvXCi5tBqtXTt2tXhPj93ackTojrptBoiAj2JCPSke8E2s9lMfHw8oaGhpOeYOJeUxcUUS6tf4S2Li8nZpOcaSczMIzEzj30XHL+P/T3d7AO/gvt1/S23AE83WUJFCCGEqCTl/ml1yZIldO/enbZt2+Lt7Y1Op3O4/pkQzrJ118yTIE+I65GfpxttPP1oU7/kOECA1Ox8W8BnDQCLBoSp2fmkZFluhy45brE3uGkKlpowEO7nQV3/wr91/S3bZXZQIYQQwjnlCvJ69OjBq6++Snp6OhqNhsjISHJycli4cCF9+/alU6dO+PiUnA1OiLJYg7y0XOmuKcSNyM/DDT8PP6LrOg4C03PyuZiSXSQIzOJiiuX+pZRsrmTkkZNvtswUeqX4mMBCPnod4f72QWC4X2EQWNffA4ObjA0UQgghyhXkbdy4EYATJ06wa9cudu/eze7du/noo4+YPXs2Go2GZs2aceTIkSoprLhxmUwm9u7dC0D79u3Ragu/iEl3TSFqNh+DGy3D3GgZVnI8IEBOvonLaTlcSskhNjWb2NQcLqUU/r2Ukk1ajpH0XCPplzM4fjmj1HMFeLrZBYFhfgbq+BoI8zUQ5qenjq8Bb71OuoYKIYSo0So0Er5Zs2Y0a9aMESNG2LadPn2anTt3smfPnkornKhZ0tPTHW6X7ppC1G4GNy0Ng7xoGORVaprMXCOxqdm2QNBRQJiVZyI5K5/krHwOx5beM8DTXUuYb0Hw52cg1FdvCQJ9DdTxs/wN8dHjptVUxeUKIYQQVa7SpjuLjIwkMjKSYcOGVVaW1S4rK4tWrVoxbNgw5s6dW93FqbGss2tKS54QojReel2pC8QDqKpKWraRS6nZdkHg5bRcLqflEJeaQ1xaDuk5RrLyTPx9JZO/y+gaqigQ5KUnzM8SAIZaWwOLBIJ1fPX4eciEMUIIIa4/Mqd1GV577bVSZ4QUlUfWyRNCXCtFUfDzdMPP041W4Y67hQJk5Rm5nJZLXGqOJfhLs/y1BoLWoNBoVrmSkcuVjFwOXiy9VdDgpiHUx9LyF1pws9w3EOKrJ8RbT6ivniAvPVqNBINCCCFcQ4K8Upw4cYKjR48yePBgDh48WN3FqdGsY/JknTwhRFXzdNcRGawjMrj0rqFms0piZl5h8JeWw+VUa0CYa9uWkpVPTr7Ztr5gWTQKBHmXDASDvd1xN+fQLNuNMD8PQnz0MnmMEEKIa3ZDBnmbNm1izpw57Nq1i9jYWFasWMGQIUPs0nz44YfMmTOHuLg42rVrx/vvv0/nzp2dPsfkyZOZM2cOf/zxRyWXXhRnbcnLNmaTa8pFr9VXc4mEELWZRqMQUhCIta7neMZQKJwwJiE9l/j0XOLTckjIyCU+zfLYuj0xMxezim1B+UMOc/vbds/HoLMLBG33ffW2VsMQb0tXUY20DgohhHCgwkFeSkoKCxcutM2kGR0dzfjx4/HzK/0fYmXJzMykXbt2jB8/nnvvvbfE/mXLlvHss88yf/58unTpwnvvvUf//v05duwYoaGhgGWGR6PRWOLYNWvWsGPHDpo3b07z5s0lyHMBbzdvtIoWk2oiLTeNEM+Q6i6SEEJclTMTxgAYTWaSMvMsgWB6QVBYEAjGp+dwKSmD5Bwz8em55BnNpOcYSc8xciqh9DGDADqNQpC3O0FeeoJ99AR7uxPirSfYW0+QtzvBBfeDfdwJ9HRHJxPJCCFErVGhIG/nzp30798fDw8PW+vYO++8w2uvvcaaNWvo0KFDpRayuIEDBzJw4MBS97/zzjtMmDCBcePGATB//nxWrVrFZ599xrRp0wBs0/k7sn37dr755hu+/fZbMjIyyM/Px9fXlxdffNFh+tzcXHJzc22P09Is3Q7NZjNms7m8l1dpzGYzqqpWaxmKlkWn09nuF5+owNfdl+TcZJJzkgkyBFVHEa/J9VTXNZnUs+tIXVcejQLB3u4Ee7sTFW4/cYzZbCYhIYGQkBAURSE9x2jXCpiQnmtpHSwIDK8U3E/NzsdoVgu6j+ZCbNllUBQI8HArCAAtAWFhIFgYEAYVlFOvq3ldRuU17RpSz64h9ew611tdO1uOCgV5zzzzDHfddReffPKJ7Yu70WjkkUce4emnn2bTpk0VybZS5OXlsWvXLqZPn27bptFo6Nu3L9u2bXMqj9mzZzN79mwAFi9ezMGDB0sN8KzpZ82aVWJ7QkICOTk55byCymM2m0lNTUVVVTSa6v8Ft2nTpgAkJiaW2Oel9SKZZM5ePotvXumTJlyvrre6rqmknl1H6to1HNWzL+DrA018dFj+TZdsKcw3mUnOMpKUlU+Sg7/J2YWPU7ONmFUK9uVDfOnrDFp5u2sJ9NIR6OlGgIflb6BnweOCv/4eOgI8dHjrtTfEDKPymnYNqWfXkHp2neutrktbkqy4CrfkFQ3wAHQ6Hc8//zydOnWqSJaV5sqVK5hMJurUqWO3vU6dOhw9erRKzjl9+nSeffZZ2+O0tDQiIiIICQnB17f6AhZri1lISMh18aIsS6BnIBeyLqDx1Ni61N5IbqS6vpFJPbuO1LVrXEs913MyncmskpyVx5WMPNuMoVcy8ki0Pc4jMSOXxEzL43yTSkaeiYw8E+eSc6+av5tWIdDLvfDm6U5Qwf0g78Lt1m2+huoZSyivadeQenYNqWfXud7q2mAwOJWuQkGer68v586do2XLlnbbz58/j4+P4zWMblRjx469ahq9Xo9eX3KyEI1GU+0vBkVRrotyXI118pW0vLTrvqyluVHq+kYn9ew6UteuUdX1rNFAqK8Hob4eV01rXW8woSAYTLQLDHPtAsWkjDwy80zkm4p0G3WCVqMQ4GnpFmofAOoJ9LbcD7IFiHr8K3GCGXlNu4bUs2tIPbvO9VTXzpahQkHe8OHDefjhh5k7dy7dunUDYOvWrUyZMoWRI0dWJMtKExwcjFar5fLly3bbL1++TFhYWDWVSphMJg4cOABAmzZt0Grtx3sUDfKEEEJUj6LrDTYN9b5q+px8E0mZeSQVtAJa7ydm5pGUYfmbmFmwPSOP9FwjpiJrEDpDo0CAZ2FAGOytLwwOiwSK1lZEf0933HXV/0VMCCGqU4WCvLlz56IoCqNHj7bNUOnm5sbjjz/OG2+8UakFLC93d3c6duzIunXrbMsqmM1m1q1bx6RJk6q1bLVdSkpKqftkQXQhhLjxGNy01PX3oK7/1VsJAXKNJpIz80nMtLQS2gLCgkAwMcP62NKFNC3HMp4wsSCds7z1OgK83Aj0dCfAy50AT3f8Pd1wV/OoH5JDkLeeAE93WxoJDIUQNU2Fgjx3d3fmzZvH7NmzOXXqFABNmjTB09OzUgtXmoyMDE6ePGl7fPr0afbu3UtgYCANGjTg2WefZcyYMXTq1InOnTvz3nvvkZmZaZttU1x/rAuiS5AnhBA1l16nJcxPS5ifc2NK8k1mkgsCPEsAWLK1MCkzjyuZuaRk5ZOSlYdZhYxcIxm5Rs4nZTvI9ZLDcxUNDP0LWg4DPN0J8HQjoKCl0N/TTVoMhRA3hAoFeefOnSMiIgJPT0/atGlTYl+DBg0qpXCl2blzJ71797Y9tk56MmbMGBYvXszw4cNJSEjgxRdfJC4ujvbt2/Pbb7+VmIxFXD989ZYJalLzJMgTQghh4abVEOprINTXuaDQbFZJy8knKTOP5Kw8kjPzScrKI7kgMLyUmEqOWUtKtjWNs4GhY9bA0BIMFgkEi7QgFt3v7+mGwa3mLU8hhLj+VCjIi4yMJDY2tsQsiImJiURGRmIymSqlcKWJiYlBVdUy00yaNEm6Z95ArN01U3JTqrcgQgghblgajYJ/QStbcWazmfj4eEJDQ+0mLrAPDPMtAWFWHilZeSRlWh4nZ1luSZl5pGTlk3wNgaHBTYO/hyXg8/NwswV/fp4F9z3cCvZZAkRrWgkOhRDlUaEgT1VVh2viZGRkOD2tpxBFWbtrpuXKxCtCCCFcp6zAsDTWwDA5qyA4tAsE8wsCROs2S6CYkp2PyaySk28mLj+HuLTyraMrwaEQojzKFeRZu0UqisKMGTPsxuCZTCb+/PNP2rdvX6kFFLWDv94fkJY8IYQQ17+igWFkcMnF6h1RVZWMXGPB2MF8UrLzbOMILY8tLYSpBfeLbpfgUAhRXuUK8vbs2QNYPqgOHDiAu3vhr17u7u60a9eOyZMnV24JRY1R1roe0l1TCCFETaYoCj4GN3wMbkQEOn9cdQWHep0GPw+3EjdfB9v8PAvv++glOBTielCuIG/9+vUAjBs3jnnz5uHr61slhRI1j1arpWfPnqXutwZ52cZs8kx5uGud7zYjhBBC1FTVFRzmGs3Ep+cSn+7ceoZF6bUKfp7u5Q4Q/TykBVGIylKhMXmLFi2q7HKIWs7H3QeNosGsmknNTSXEM6S6iySEEELcsCojOEzNLvuW5uCxWYVck1rhANG9lBbEUoNEuwBR43DOCCFqowoFeUJUNo2iwc/dj+TcZFJyUyTIE0IIIaqBXXBYzmMtE9Lk8ff5ONy8fEnPMTkdJKblGDGZVfKMZhLSc0moSICo1RQEgjp8PdzwNbgV/C36WFfqdr1OWhFFzSFBnnAJs9nMwYMHAWjdurXD8Xl++sIgTwghhBA3Fo1GwdfgRl0/PaGhfmWOxS/O2oJYVkuh5WZ0uN9kVskzmbmSkcuVjPIHiGAZh1gy+HPDx6C7aoDoa5CupuL6IkGecAlVVUlKSrLdd8Q6Li81VxZEF0IIIWqToi2I9QPKd6yqqmTmFbQaZllbBvNJzzGSVnA/LdtY8Lfk4/RcI6oKudfQigiWrqbOBIPW7T6GglbHgm16nXQ3FZVHgjxx3ZBlFIQQQghRXoqi4K3X4a3XUc/fo9zHm80qGXkFAaFdMFgySEwvGiAWua+qkGe8tpZES3dTS9DnYw0QDZaWRJ+CoND611uvxZSdQUS+Hj9Pd9s+N63zraeiZqtwkLd582YWLFjAqVOnWL58OfXq1eOLL74gMjKSW2+9tTLLKGoJackTQgghhKtZu5n6GtygnK2IYAkSM/OMhUFhKQFiicdFWhvNKgXdTfO4kpFXjrMft3tkcNMUBoP6ooGhfZBo6YJacpuPQcYm1hQVCvK+++47HnroIR588EH27NlDbq7lF4vU1FRef/11fvnll0otpKgdJMgTQgghxI1GoynsalqRlkRrd1O7INB23xIIpucaLa2IOUbL4+x8kjNzyDaqpOcYycozAZCTbyYnv+JdTsHa7dTaYqhzGCT6lhI4WscvStfT6lehIO/VV19l/vz5jB49mm+++ca2vXv37rz66quVVjhRu0h3TSGEEELUNkW7m9bFuSDRbDYTHx9PaGgoGo0Go8lMRq4lALS2EKbnGMnILbxfdHu67X7RtEbA2u20vC2K9ty0il3g560v2oJYWhdUnV16L3cdGo0EihVVoSDv2LFjDhe29vPzIyUl5VrLJGopCfKEEEIIIcpPp9Xg7+mOv6d7hfMwmdWCQNFxMJhWSoBYNLjMKJjEJt+kkpSZR1JmxQNFwBb8etsCxcIg0FvvhndBt1Trfu+CbqjWfdbjtbUwWKxQkBcWFsbJkydp1KiR3fYtW7bQuHHjyiiXqIWku6YQQgghRPXQahTbwvIVZR2fWDxIvForYlpOPpl5RjIKHhvNlpnYM3ILWhjTru3aPN21tiDQx+BmCQyLBY/W1sbigaOXu4bsHCMhpcwOf72qUJA3YcIEnnrqKT777DMUReHSpUts27aNyZMnM2PGjMouo6gBtFotMTExZaaRIE8IIYQQ4sZVdHxiRamqSq6xsPtpRo6R9Nx8MgpaDW3bc422bUVbFa3b03ON5BnNAGTlmcjKMxFfwbGKOo3CsVfCK3xN1aFCQd60adMwm83cdtttZGVl0bNnT/R6PZMnT+bJJ5+s7DKKWkK6awohhBBC1G6KomBw02Jw0xLsrb+mvHKNJjJzTWQU6U5qCwxt9/NtQWFGkfGJhcFkPnqtcsNNJFOhIE9RFF544QWmTJnCyZMnycjIICoqCm9v78oun6hFrEFeam4qqqrecG8mIYQQQghx/dDrtOh1WgK9Kj5W0Ww2cynuciWWyjUqFOTNnj2bOnXqMH78eKKiomzbP/vsMxISEpg6dWqlFVDUDGazmSNHjgDQqlUrNJqSi3X6uvsCYFSNZBmz8HLzcmkZhRBCCCGEKE53A07cUvKbthMWLFhAy5YtS2yPjo5m/vz511woUfOoqkpCQgIJCQmopQxc9dB54K6x/NIiXTaFEEIIIYSomAoFeXFxcYSHlxx8GBISQmxs7DUXStROiqLIuDwhhBBCCCGuUYWCvIiICLZu3Vpi+9atW6lbt+41F0rUXn6Gghk2c2SGTSGEEEIIISqiwksoPP300+Tn59OnTx8A1q1bx/PPP89zzz1XqQUUtYufe0GQlydBnhBCCCGEEBVRoSBvypQpJCYmMnHiRPLyLCvZGwwGpk6dyvTp0yu1gKJ2ke6aQgghhBBCXJsKL6Hw5pv/396dR0lV3vkf/9yq3ugd6B2afVEEGmSzNRAMHbE1SnQS0XgMYqJjIic6HZKRnASDk4jRA+MSEiaTH4PmZKJjEs0ZHY3YChgEhIaOLC6AIAi9sfW+1n1+f2BVaKHphepbVbffr3P6nOqqW7e+9T03bT48z32eX+gnP/mJ3n//ffXr10+jR49WbOzF7WUB+DdEJ+QBAAAAPdOjkOeXmJioadOmBasWQP3j+kuSTjWdCnElAAAAQGTqccgrLi5WcXGxKisrZdt2u9fWrFlz0YXBXTwej2bOnBl43JG0fmmSpOONxx2pCwAAAHCbHoW8ZcuW6eGHH9bUqVOVnZ0ty4q8DQLhLMuy5PV6Oz3OH/JONJ7o7ZIAAAAAV+pRyFu9erXWrl2rO+64I9j1oI9jJA8AAAC4OD0KeS0tLbryyiuDXQtczLZtffTRR5KkMWPGdDhl0x/yqhqrHKsNAAAAcJMebYb+7W9/W//93/8d7FrgYsYYlZeXq7y8XMaYDo/zh7zGtkY1tDY4VR4AAADgGj0ayWtqatJvfvMbvfHGG5o4caKio6Pbvb5y5cqgFIe+JyE6Qf2i+qmxrVHHG49rSPSQUJcEAAAARJQehbz33ntPkyZNkiTt3r273WsswoKLldYvTUdqj5wJecmEPAAAAKA7ehTy3nrrrWDXAQT4Qx735QEAAADd16N78oDexAqbAAAAQM/1eDN0Sdq7d68OHz6slpaWds/feOONF1UU+jb2ygMAAAB6rkch7+OPP9ZNN92kXbt2ybKswGqJ/vvxfD5f8CpEn8NIHgAAANBzPZquef/992v48OGqrKxUfHy89uzZo40bN2rq1Klav359kEuEG3g8Hl155ZW68sorO9wjzy+9X7ok9soDAAAAeqJHI3mbN2/Wm2++qbS0NHk8Hnk8Hn3hC1/Q8uXL9b3vfU87d+4Mdp2IcJZlKSYmpkvHDuw3UBLTNQEAAICe6NFIns/nU1JSkiQpLS1Nx44dkyQNHTpUH374YfCqQ5/EdE0AAACg53o0kjd+/Hj9/e9/1/DhwzVjxgw99thjiomJ0W9+8xuNGDEi2DXCBWzb1v79+yVJo0aNuuCUzcDCK00n5LN98nq8jtQIAAAAuEGPQt6Pf/xj1dfXS5IefvhhfeUrX9HMmTM1cOBAPf/880EtEO5gjAmM+I4cOfKCxw6IGyBLlmxj61TzqUDoAwAAANC5HoW8Sy+9VIMHD5Z0ZlTmgw8+0MmTJ9W/f38dOXIkqAWi74nyRKl/XH+dbDqpE40nCHkAAABAN/Tonrzhw4fr+PH290sNGDBAJ0+e1PDhw4NSGPo27ssDAAAAeqZHIc+/L97n1dXVKS4u7qIKAiS2UQAAAAB6qlvTNYuKiiSdWQ5/6dKlio+PD7zm8/m0detWTZo0KagFom/yb6PASB4AAADQPd0Kef7974wx2rVrV7t9z2JiYpSXl6fFixcHt0L0SYEVNtkrDwAAAOiWboW8t956S5K0cOFCPfnkk0pOTu6VogDuyQMAAAB6pkera/7qV79qd1/eJ598ohdffFHjxo3TNddcE7Ti4B4ej0dXXHFF4HFnCHkAAABAz/Ro4ZV58+bp2WeflSSdPn1a06dP14oVKzRv3jz9+te/DmqBcAfLshQXF6e4uDhZltXp8YQ8AAAAoGd6FPJ27NihmTNnSpL++Mc/KisrS5988omeffZZPfXUU0EtEH0TIQ8AAADomR6FvIaGBiUlJUmSXn/9dd18882B6XiffPJJUAuEO9i2rQMHDujAgQOybbvT4/0hr661To1tjb1dHgAAAOAaPQp5o0aN0ksvvaQjR47or3/9a+A+vMrKShZjwXkZY3TkyBEdOXKkw30Wz5YYnahYb6wkVtgEAAAAuqNHIW/p0qVavHixhg0bphkzZig/P1/SmVG9yZMnB7VA9E2WZTFlEwAAAOiBHq2u+bWvfU1f+MIXVFZWpry8vMDzc+bM0U033RS04tC3Dew3UEfrjjKSBwAAAHRDj0KeJGVlZSkrK6vdc9OnT7/oggC/tDhG8gAAAIDu6nLIKyoq0r/9278pISFBRUVFFzx25cqVF10YEJiu2UTIAwAAALqqyyFv586dam1tDTzuSFf2QAO6wh/yqhqqQlwJAAAAEDm6HPLeeuut8z4Gekta/JmQxz15AAAAQNf1+J48oDs8Ho+mTZsWeNwV3JMHAAAAdF+37snrKjfck3fw4EHdddddqqiokNfr1ZYtW5SQkBDqsiKWZVnd7h/35AEAAADd16178s62Y8cOtbW1aezYsZKkjz76SF6vV1OmTAluhSFy55136mc/+5lmzpypkydPKjY2NtQl9Tln75NnjOF+TwAAAKALenRP3sqVK5WUlKRnnnlG/fv3lySdOnVKCxcu1MyZM4NfpcP27Nmj6OjowHcZMGBAiCuKfLZt6/Dhw5KkIUOGdGnK5sB+AyVJbXabalpqlBKb0qs1AgAAAG7QtZujPmfFihVavnx5IOBJUv/+/fWzn/1MK1asCFpxHdm4caNuuOEG5eTkyLIsvfTSS+ccs2rVKg0bNkxxcXGaMWOG3n333S6ff9++fUpMTNQNN9ygyy+/XI888kgQq++bjDE6dOiQDh06JGNMl94T441RckyyJKmyobI3ywMAAABco0cLr9TU1Kiq6txl7auqqlRbW3vRRXWmvr5eeXl5uuuuu3TzzTef8/rzzz+voqIirV69WjNmzNATTzyhuXPn6sMPP1RGRoYkadKkSWprazvnva+//rra2tr09ttvq7S0VBkZGbr22ms1bdo0ffnLX+7174b2MhMyVdNSo8qGSo3uPzrU5QAAAABhr0ch76abbtLChQu1YsUKTZ8+XZK0detW/eAHPzhv6Aq2wsJCFRYWdvj6ypUrdffdd2vhwoWSpNWrV+uVV17RmjVr9OCDD0qSSktLO3z/oEGDNHXqVOXm5kqSrrvuOpWWlhLyQiAjPkP7Tu1jJA8AAADooh6FvNWrV2vx4sX6xje+EdggPSoqSt/61rf0+OOPB7XA7mppaVFJSYmWLFkSeM7j8aigoECbN2/u0jmmTZumyspKnTp1SikpKdq4caP++Z//ucPjm5ub1dzcHPi9pqZG0pn70Gzb7uE3uXi2bcsYE9IaPl+L/3FXF1HJ7JcpSSqrLwuL79GRcOq1m9Fn59BrZ9Bn59BrZ9BnZ9Bn54Rbr7taR49CXnx8vH71q1/p8ccf14EDByRJI0eODIstBo4fPy6fz6fMzMx2z2dmZuqDDz7o0jmioqL0yCOPaNasWTLG6JprrtFXvvKVDo9fvny5li1bds7zVVVVampq6t4XCCLbtlVdXS1jTJf3pustPp9P9fX1kqTKykp5vd4uvS/RJEqSPjnxiSorw3c0L5x67Wb02Tn02hn02Tn02hn02Rn02Tnh1uuu3hp3UZuhJyQkaOLEiRdzirDV2ZTQsy1ZsqTdPoI1NTXKzc1Venq6kpOTe6vETvlHzNLT00N+Ufp8vsA/AmRkZHQ55A2vHi4dkGrsmsD9lOEonHrtZvTZOfTaGfTZOfTaGfTZGfTZOeHW67i4uC4dd1EhLxylpaXJ6/WqoqKi3fMVFRXKysrqlc+MjY097z56Ho8n5BeDZVlhUcfZ+9x1p57sxGxJUkVjRci/Q2fCpdduR5+dQ6+dQZ+dQ6+dQZ+dQZ+dE0697moNoa80yGJiYjRlyhQVFxcHnrNtW8XFxcrPzw9hZX2bx+PR5Zdfrssvv7xb/wPJjD8z7ZaFVwAAAICuiciRvLq6Ou3fvz/w+8GDB1VaWqoBAwZoyJAhKioq0oIFCzR16lRNnz5dTzzxhOrr6wOrbcJ5lmX1aOpqRsKZKZrVzdVqbGtUv6h+wS4NAAAAcJVuhbzdu3dr/PjxvVVLl23fvl1XX3114Hf//XALFizQ2rVrNX/+fFVVVWnp0qUqLy/XpEmT9Nprr52zGAvCX1J0kvpF9VNjW6MqGyo1NHloqEsCAAAAwlq3Qt7EiRM1bdo0ffvb39att96qpKSk3qrrgmbPnh1Yjr8jixYt0qJFixyqCJ2xbVuffvqpJGnw4MFdnrJpWZYy4zN1qOaQKuorCHkAAABAJ7p1T96GDRt02WWX6fvf/76ys7O1YMECvf32271VG1zEGKOPP/5YH3/8cacB/fMyE86MwFY0VHRyJAAAAIBuhbyZM2dqzZo1Kisr09NPP61Dhw7pi1/8osaMGaNf/OIXKi8v76060Yf5F18h5AEAAACd69HqmgkJCVq4cKE2bNigjz76SF//+te1atUqDRkyRDfeeGOwa0QfFwh59YQ8AAAAoDMXvYXCqFGj9KMf/Ug//vGPlZSUpFdeeSUYdQEBjOQBAAAAXXdRWyhs3LhRa9as0Z/+9Cd5PB7dcsst+ta3vhWs2gBJ3JMHAAAAdEe3Q96xY8e0du1arV27Vvv379eVV16pp556SrfccosSEhJ6o0b0cUzXBAAAALquWyGvsLBQb7zxhtLS0vTNb35Td911l8aOHdtbtQGS/jGSd6LphFp9rYr2Roe4IgAAACB8dSvkRUdHa+3atbr11lvl9Xp7qya4kMfj0aRJkwKPu6N/bH9Fe6LVareqsrFSgxIH9UKFAAAAgDt06/9tv/LKK/ryl79MwEO3WZal1NRUpaamyrKsbr+XKZsAAABA13Qr5HV3E2sgWLISsiSx+AoAAADQmW4vvNLdURhAkmzbVllZmSQpOzu721M2/fflldeXB702AAAAwE26HfLGjBnTadA7efJkjwuCOxljtG/fPklSVlZWt9/PXnkAAABA13Q75C1btkwpKSm9UQvQocB0Te7JAwAAAC6o2yHv1ltvVUZGRm/UAnSIkTwAAACga7p1YxT34yFUuCcPAAAA6BpW10RE8I/kHW88rla7NcTVAAAAAOGrWyHPtm2maiIkBsQNULQnWkZGxxuOh7ocAAAAIGx1bx17IEQ8lkcZ8Wf+gaG8gSmbAAAAQEe6vfAK0BMej0cTJkwIPO6JzPhMHa07ygqbAAAAwAUQ8uAIy7I0cODAizqHf/EVVtgEAAAAOsZ0TUQM/155rLAJAAAAdIyRPDjCtm1VVlZKkjIyMno0ZZO98gAAAIDOEfLgCGOMPvjgA0lSenp6j86RFc9IHgAAANAZpmsiYvina7LwCgAAANAxQh4ihn/hlarGKrX62BAdAAAAOB9CHiLG2RuiVzZWhrocAAAAICwR8hAxPJYnsPgK9+UBAAAA50fIQ0TJTsyWRMgDAAAAOkLIQ0RhhU0AAADgwthCAY7weDwaN25c4HFP+VfYLKsvC0pdAAAAgNsQ8uAIy7KUkZFx0edhGwUAAADgwpiuiYjiD3nlDUzXBAAAAM6HkTw4whijqqoqSVJ6erosy+rReVhdEwAAALgwRvLgCNu2tXfvXu3du1e2bff4PP7VNU83n1ZjW2OwygMAAABcg5CHiJIUnaT4qHhJjOYBAAAA50PIQ0SxLOsf9+UR8gAAAIBzEPIQcQh5AAAAQMcIeYg4rLAJAAAAdIyQh4jDSB4AAADQMUIeIk52wpkVNgl5AAAAwLnYJw+OsCxLl1xySeDxxfCHvLL6souuCwAAAHAbQh4c4fF4lJWVFZRznT2SZ4y56NAIAAAAuAnTNRFxMhMyJUmNbY2qbq4OcTUAAABAeCHkwRHGGJ04cUInTpyQMeaizhXrjdXAuIGSpGP1x4JRHgAAAOAahDw4wrZt7dq1S7t27ZJt2xd9Pu7LAwAAAM6PkIeIlJ3ICpsAAADA+RDyEJECI3l1jOQBAAAAZyPkISIxXRMAAAA4P0IeIhIbogMAAADnR8hDRMpKPLPnHiN5AAAAQHuEPEQk/0heVWOVWnwtIa4GAAAACB9RoS4AfYNlWRo9enTg8cXqH9tfsd5YNfuaVdFQodyk3Is+JwAAAOAGjOTBER6PR4MGDdKgQYPk8Vz8ZWdZFvflAQAAAOdByEPEyko4c1/esbpjIa4EAAAACB9M14QjjDGqrq6WJKWkpARlyqY/5DGSBwAAAPwDI3lwhG3bKi0tVWlpqWzbDso5AyGvgZAHAAAA+BHyELG4Jw8AAAA4FyEPESsrnumaAAAAwOcR8hCxuCcPAAAAOBchDxHLH/LqWutU11IX4moAAACA8EDI68C///u/67LLLtO4ceP0ve99T8aYUJeEz4mPjldSTJIkRvMAAAAAP0LeeVRVVemXv/ylSkpKtGvXLpWUlGjLli2hLgvnwQqbAAAAQHvsk9eBtrY2NTU1SZJaW1uVkZER4ooim2VZGjFiROBxsGTFZ2nfqX2M5AEAAACficiRvI0bN+qGG25QTk6OLMvSSy+9dM4xq1at0rBhwxQXF6cZM2bo3Xff7fL509PTtXjxYg0ZMkQ5OTkqKCjQyJEjg/gN+h6Px6MhQ4ZoyJAh8niCd9mxjQIAAADQXkSGvPr6euXl5WnVqlXnff35559XUVGRHnroIe3YsUN5eXmaO3euKisrA8dMmjRJ48ePP+fn2LFjOnXqlF5++WUdOnRIR48e1TvvvKONGzc69fXQDf7pmmX1ZSGuBAAAAAgPETlds7CwUIWFhR2+vnLlSt19991auHChJGn16tV65ZVXtGbNGj344IOSpNLS0g7f/8ILL2jUqFEaMGCAJOn666/Xli1bNGvWrPMe39zcrObm5sDvNTU1kiTbtmXbdre+WzDZti1jTEhr8DPGqLa2VpKUlJQUtCmbGfFnptGW15fT6z6APjuHXjuDPjuHXjuDPjuDPjsn3Hrd1ToiMuRdSEtLi0pKSrRkyZLAcx6PRwUFBdq8eXOXzpGbm6t33nlHTU1Nio6O1vr163XPPfd0ePzy5cu1bNmyc56vqqoK3NcXCrZtq7q6WsaYoE6R7Amfz6edO3dKkiZPniyv1xuU88a1xEmSjtYcbTdS67Rw6rWb0Wfn0Gtn0Gfn0Gtn0Gdn0GfnhFuv/YMmnXFdyDt+/Lh8Pp8yMzPbPZ+ZmakPPvigS+e44oordN1112ny5MnyeDyaM2eObrzxxg6PX7JkiYqKigK/19TUKDc3V+np6UpOTu7ZFwkC27ZlWZbS09NDflH6fD4lJCRIkjIyMoIW8i7pd4m0TTrefFzp6elBXdSlO8Kp125Gn51Dr51Bn51Dr51Bn51Bn50Tbr2Oi4vr0nGuC3nB8vOf/1w///nPu3RsbGysYmNjz3ne4/GE/GKwLCss6jDGBAJYMOvJTjyz8Eqzr1k1rTXqH9c/KOftiXDptdvRZ+fQa2fQZ+fQa2fQZ2fQZ+eEU6+7WkPoKw2ytLQ0eb1eVVRUtHu+oqJCWVlZIaoKvSXGG6OBcQMlsfgKAAAAILkw5MXExGjKlCkqLi4OPGfbtoqLi5Wfnx/CytBb/NsoEPIAAACACJ2uWVdXp/379wd+P3jwoEpLSzVgwAANGTJERUVFWrBggaZOnarp06friSeeUH19fWC1TbhLdmK2dp/YzV55AAAAgCI05G3fvl1XX3114Hf/oicLFizQ2rVrNX/+fFVVVWnp0qUqLy/XpEmT9Nprr52zGAvcwT+Sd6zuWIgrAQAAAEIvIkPe7NmzZYy54DGLFi3SokWLHKoInbEsS8OGDQs8DqacxBxJTNcEAAAApAgNeYg8Ho8nEPKCLSvhzII6ZXWEPAAAAMB1C6+g78lJODOSd6ye6ZoAAAAAIQ+OMMaovr5e9fX1nU617S7/PXknm06qqa0pqOcGAAAAIg0hD46wbVvbtm3Ttm3bZNt2UM+dEpuiflH9JIkVNgEAANDnEfIQ8SzLCkzZZPEVAAAA9HWEPLhCVuJni68Q8gAAANDHEfLgCoHFV9grDwAAAH0cIQ+u4F98hZE8AAAA9HWEPLhCdiIhDwAAAJAIeXCJwEgeG6IDAACgj4sKdQHoGyzLUm5ubuBxsPnvyStvKJdtbHks/v0CAAAAfRMhD47weDwaOXJkr50/LT5NHsujNrtNJ5tOKq1fWq99FgAAABDOGO6AK0R7opUWdybYVdRXhLgaAAAAIHQIeXCEMUZNTU1qamqSMaZXPiMr4cxeeeUN5b1yfgAAACASEPLgCNu2tWXLFm3ZskW2bffKZ2QmZEpiJA8AAAB9GyEPrpEZ/1nIayDkAQAAoO8i5ME1CHkAAAAAIQ8uwnRNAAAAgJAHF2EkDwAAACDkwUXOHsnrrRU8AQAAgHBHyINrZPTLkCS12C063Xw6tMUAAAAAIULIgyMsy1JOTo5ycnJkWVavfEa0N1oD4wZKYsomAAAA+i5CHhzh8Xg0ZswYjRkzRh5P7112LL4CAACAvo6QB1dh8RUAAAD0dVGhLgB9gzFGra2tkqTo6Ohem7LpD3nl9eW9cn4AAAAg3DGSB0fYtq133nlH77zzjmzb7rXPCUzXZCQPAAAAfRQhD67CdE0AAAD0dYQ8uEpWQpYkFl4BAABA30XIg6ucPZLHhugAAADoiwh5cBX/PXmNbY2qaakJcTUAAACA8wh5cJVYb6z6x/aXxAqbAAAA6JsIeXCdwH15LL4CAACAPoiQB0dYlqWsrCxlZWX12h55fuyVBwAAgL6MzdDhCI/Ho0suucSRz/Lfl0fIAwAAQF/ESB5ch+maAAAA6MsYyYMjjDGybVvSmVG93pyyGdhGgb3yAAAA0AcxkgdH2Latt99+W2+//XYg7PUW/0heeQPTNQEAAND3EPLgOlnxn4W8+nI2RAcAAECfQ8iD62QkZEiSmn3Nqm6uDnE1AAAAgLMIeXCdWG+sBsQNkMSUTQAAAPQ9hDy4EouvAAAAoK8i5MGV2CsPAAAAfRUhD67kX3yFvfIAAADQ17BPHhxhWZbS09MDj3ubfxuFT2s/7fXPAgAAAMIJIQ+O8Hg8uuyyyxz7vLEDxkqS9p7c69hnAgAAAOGA6ZpwpfEDx0uSPqn5hG0UAAAA0KcQ8uBKqXGpGpw4WJK058SeEFcDAAAAOIeQB0f4fD6tX79e69evl8/nc+Qzx6edGc3bc5yQBwAAgL6DkAfX8oe8Xcd3hbgSAAAAwDmEPLgWI3kAAADoiwh5cK1LB1wqj+VRZWOlKurZLw8AAAB9A1sowLXio+M1MnWk9p3ap8e2PaYrcq5QtCdaidGJyk7IVnx0fKfnyE3KVZSH/5kAAAAgcvD/XuFqM7JmaN+pfXr9k9f1+iev9+j9v537216oDAAAAOgdhDy4WtGUIk3JnKKSihIdqT0in/GppqVG5XXlavI1dfg+Y4xqW2v13vH3HKwWAAAAuHiEPDjCsiwNGDAg8Ngp0d5oFQwtUMHQgm69r7q5Wl947gtqbGtUq69V0d7oXqoQAAAACC5CHhzh8Xg0ceLEUJfRZUkxSbJkyciouqVaaf3SQl0SAAAA0CWsrgmch8fyKDEmUZJU01IT4moAAACAriPkAR1IiUmRJNU0E/IAAAAQOZiuCUf4fD5t2rRJknTVVVfJ6/WGuKLOJccmS3WM5AEAACCy9PmRvJtuukn9+/fX1772tXNee/nllzV27FiNHj1av/0ty+hfLNu2Zdt2qMvosuSYZElnFmEBAAAAIkWfD3n333+/nn322XOeb2trU1FRkd58803t3LlTjz/+uE6cOBGCChEq/pDHSB4AAAAiSZ8PebNnz1ZSUtI5z7/77ru67LLLNGjQICUmJqqwsFCvv979zbQRuZJjPwt53JMHAACACBLWIW/jxo264YYblJOTI8uy9NJLL51zzKpVqzRs2DDFxcVpxowZevfdd4Py2ceOHdOgQYMCvw8aNEhHjx4NyrkRGQILrzCSBwAAgAgS1iGvvr5eeXl5WrVq1Xlff/7551VUVKSHHnpIO3bsUF5enubOnavKysrAMZMmTdL48ePP+Tl27JhTXwMRKjCSR8gDAABABAnr1TULCwtVWFjY4esrV67U3XffrYULF0qSVq9erVdeeUVr1qzRgw8+KEkqLS3t0Wfn5OS0G7k7evSopk+fft5jm5ub1dzcHPi9puZMKAj1QiO2bcsYExaLnfhr8T+2LCvEFXUuKfrMNN7q5upOexhOvXYz+uwceu0M+uwceu0M+uwM+uyccOt1V+sI65B3IS0tLSopKdGSJUsCz3k8HhUUFGjz5s0Xff7p06dr9+7dOnr0qFJSUvTqq6/qJz/5yXmPXb58uZYtW3bO81VVVWpqarroWnrKtm1VV1fLGCOPJ7SDtj6fL1BDZWVlRGyhYBrPhNLjdcfbjQ6fTzj12s3os3PotTPos3PotTPoszPos3PCrde1tbVdOi5iQ97x48fl8/mUmZnZ7vnMzEx98MEHXT5PQUGB/v73v6u+vl6DBw/WCy+8oPz8fEVFRWnFihW6+uqrZdu2fvjDH2rgwIHnPceSJUtUVFQU+L2mpka5ublKT09XcnJyz75gEPhHzNLT08PioszOzg51Cd2S68uVJDWaRmVkZFzw2HDrtVvRZ+fQa2fQZ+fQa2fQZ2fQZ+eEW6/j4uK6dFzEhrxgeeONNzp87cYbb9SNN97Y6TliY2MVGxt7zvMejyfkF4NlWWFRRyRKjUuVJNW21Hapf/TaGfTZOfTaGfTZOfTaGfTZGfTZOeHU667WEPpKeygtLU1er1cVFRXtnq+oqFBWVlaIqoKbsPAKAAAAIlHEhryYmBhNmTJFxcXFgeds21ZxcbHy8/NDWBnOx+fzadOmTdq0aZN8Pl+oy+kS/2bozb5mNbWF7t5KAAAAoDvCerpmXV2d9u/fH/j94MGDKi0t1YABAzRkyBAVFRVpwYIFmjp1qqZPn64nnnhC9fX1gdU2EV5aW1tDXUK3JEQnyGN5ZBtbNS01iovq2hxoAAAAIJTCOuRt375dV199deB3/+ImCxYs0Nq1azV//nxVVVVp6dKlKi8v16RJk/Taa6+dsxgL0BMey6PkmGSdbj6tmuYaZcRfePEVAAAAIByEdcibPXt2YG+1jixatEiLFi1yqCL0NYGQx315AAAAiBARe08e4AT/fXnVzdUhrgQAAADoGkIecAGssAkAAIBIQ8gDLsA/kkfIAwAAQKQI63vy4C5JSUmhLqHbUmJTJDFdEwAAAJGDkAdHeL1eTZkyJdRldBsjeQAAAIg0hDzgAvwh7+Pqj7Xp6KYOj7ONrdOnTyu1LVUey/lZ0B7Lo8yETI1IGeH4ZwMAACC8EPKAC/BP19xatlVby7aGuJrO/fd1/60J6RNCXQYAAABCiJAHR/h8Pm3btk2SNG3aNHm93hBX1DWzBs/SVTlX6UTTiQseZ4xRW1uboqKiZFmWQ9X947M/rftU9a312n96PyEPAACgjyPkwTFNTU2hLqHbBvYbqNVfXt3pcbZtq7KyUhkZGfJ4nJ+u+cONP9SrB19VbUut458NAACA8MIWCoAL+O8drG0l5AEAAPR1hDzABZJizmxPwUgeAAAACHmACxDyAAAA4EfIA1zAH/LYzw8AAACEPMAFGMkDAACAH6trwjHx8fGhLsG1kqM/W3ilg5DX0NqgVrvVyZJ6hW3bqmmpUX+7v2I9saEuBwAAICwR8uAIr9er6dOnh7oM17rQSN6L+17UQ+88JCPjdFm95tlrn9XkzMmhLgMAACAsMV0TcIGOQt6xumN69N1HXRXwAAAAcGGM5AEu4A95da118tk+eT1eGWP08OaH1dDWoMkZk/X/5v4/eSL833Vs21ZlVaWy0rJCXQoAAEDYIuTBET6fTyUlJZKkKVOmyOv1hrgid/GHPOlM0EuJTdH2iu3adGyTYjwxWnblMkV7okNYYXBYsuS1vLIsK9SlAAAAhK3I/md9RJSGhgY1NDSEugxXivHGKM4bJ+kfUzYPnD4gSbpq0FUanjI8ZLUBAADAWYQ8wCU+f19eZUOlJCkzPjNkNQEAAMB5hDzAJT4f8ioaKiRJmQmEPAAAgL6EkAe4REcjeRnxGSGrCQAAAM4j5AEu4Q95NS01kqSqhipJhDwAAIC+hpAHuAQjeQAAAJDYQgEOiouLC3UJrpYckyxJqm2tVUNrg2pbz4Q9Fl4BAADoWwh5cITX69UVV1wR6jJc7eyRPP8oXkJ0ghKiE0JZFgAAABzGdE3AJc4X8piqCQAA0PcQ8gCXOHvhFf/2CYQ8AACAvofpmnCEz+dTaWmpJGnSpEnyer2hLciFzjeSx/14AAAAfQ8hD46pra0NdQmulhz92cIrTNcEAADo05iuCbgE9+QBAABAIuQBrkHIAwAAgETIA1zDH/LqWutUVl8miXvyAAAA+iJCHuAS/pAnSVWNVZIYyQMAAOiLCHmAS8R4YxQfFR/43WN5NDBuYAgrAgAAQCgQ8uCY6OhoRUdHh7oMV3tgygOakDZBWQlZunXsrfJ62KoCAACgr2ELBTjC6/XqqquuCnUZrnfbJbfptktuC3UZAAAACCFG8gAAAADARQh5AAAAAOAiTNeEI3w+n3bt2iVJmjBhgrxe7hUDAAAAegMhD445ffp0qEsAAAAAXI/pmgAAAADgIoQ8AAAAAHARQh4AAAAAuAghDwAAAABchJAHAAAAAC7C6ppwjMfDvykAAAAAvY2QB0d4vV7NmjUr1GUAAAAArsfQCgAAAAC4CCEPAAAAAFyE6ZpwhG3b2r17tyRp/Pjx3J8HAAAA9BJCHhxhjNHJkycDjwEAAAD0DoZTAAAAAMBFCHkAAAAA4CKEPAAAAABwEUIeAAAAALgIIQ8AAAAAXITVNXuBf/XImpqakNZh27Zqa2sVFxcX8i0LfD6f6uvrJZ3pi9frDWk9wRZOvXYz+uwceu0M+uwceu0M+uwM+uyccOu1P190tlo9Ia8X1NbWSpJyc3NDXAkAAAAAt6mtrVVKSkqHr1uGTcuCzrZtHTt2TElJSbIsK2R11NTUKDc3V0eOHFFycnLI6ugL6LUz6LNz6LUz6LNz6LUz6LMz6LNzwq3XxhjV1tYqJyfngiOLjOT1Ao/Ho8GDB4e6jIDk5OSwuCj7AnrtDPrsHHrtDPrsHHrtDPrsDPrsnHDq9YVG8PxCP7EUAAAAABA0hDwAAAAAcBFCnovFxsbqoYceUmxsbKhLcT167Qz67Bx67Qz67Bx67Qz67Az67JxI7TULrwAAAACAizCSBwAAAAAuQsgDAAAAABch5AEAAACAixDyXGzVqlUaNmyY4uLiNGPGDL377ruhLimi/fSnP5VlWe1+LrnkksDrTU1Nuu+++zRw4EAlJibqn/7pn1RRURHCiiPHxo0bdcMNNygnJ0eWZemll15q97oxRkuXLlV2drb69eungoIC7du3r90xJ0+e1O23367k5GSlpqbqW9/6lurq6hz8FuGvsz7feeed51zj1157bbtj6HPnli9frmnTpikpKUkZGRn66le/qg8//LDdMV35e3H48GFdf/31io+PV0ZGhn7wgx+ora3Nya8S1rrS59mzZ59zTd97773tjqHPnfv1r3+tiRMnBvYJy8/P16uvvhp4nes5ODrrM9dz73j00UdlWZYeeOCBwHNuuKYJeS71/PPPq6ioSA899JB27NihvLw8zZ07V5WVlaEuLaJddtllKisrC/z87W9/C7z2L//yL/rf//1fvfDCC9qwYYOOHTumm2++OYTVRo76+nrl5eVp1apV5339scce01NPPaXVq1dr69atSkhI0Ny5c9XU1BQ45vbbb9eePXu0bt06vfzyy9q4caPuuecep75CROisz5J07bXXtrvG//CHP7R7nT53bsOGDbrvvvu0ZcsWrVu3Tq2trbrmmmtUX18fOKazvxc+n0/XX3+9Wlpa9M477+iZZ57R2rVrtXTp0lB8pbDUlT5L0t13393umn7ssccCr9Hnrhk8eLAeffRRlZSUaPv27frSl76kefPmac+ePZK4noOlsz5LXM/Btm3bNv3Hf/yHJk6c2O55V1zTBq40ffp0c9999wV+9/l8JicnxyxfvjyEVUW2hx56yOTl5Z33tdOnT5vo6GjzwgsvBJ57//33jSSzefNmhyp0B0nmxRdfDPxu27bJysoyjz/+eOC506dPm9jYWPOHP/zBGGPM3r17jSSzbdu2wDGvvvqqsSzLHD161LHaI8nn+2yMMQsWLDDz5s3r8D30uWcqKyuNJLNhwwZjTNf+Xvzf//2f8Xg8pry8PHDMr3/9a5OcnGyam5ud/QIR4vN9NsaYL37xi+b+++/v8D30uef69+9vfvvb33I99zJ/n43heg622tpaM3r0aLNu3bp2vXXLNc1Ingu1tLSopKREBQUFgec8Ho8KCgq0efPmEFYW+fbt26ecnByNGDFCt99+uw4fPixJKikpUWtra7ueX3LJJRoyZAg9v0gHDx5UeXl5u96mpKRoxowZgd5u3rxZqampmjp1auCYgoICeTwebd261fGaI9n69euVkZGhsWPH6jvf+Y5OnDgReI0+90x1dbUkacCAAZK69vdi8+bNmjBhgjIzMwPHzJ07VzU1Ne3+VR//8Pk++/3+979XWlqaxo8fryVLlqihoSHwGn3uPp/Pp+eee0719fXKz8/neu4ln++zH9dz8Nx33326/vrr2127knv+RkeFugAE3/Hjx+Xz+dpdeJKUmZmpDz74IERVRb4ZM2Zo7dq1Gjt2rMrKyrRs2TLNnDlTu3fvVnl5uWJiYpSamtruPZmZmSovLw9NwS7h79/5rmf/a+Xl5crIyGj3elRUlAYMGED/u+Haa6/VzTffrOHDh+vAgQP60Y9+pMLCQm3evFler5c+94Bt23rggQd01VVXafz48ZLUpb8X5eXl573m/a+hvfP1WZK+8Y1vaOjQocrJydF7772nf/3Xf9WHH36oP//5z5Loc3fs2rVL+fn5ampqUmJiol588UWNGzdOpaWlXM9B1FGfJa7nYHruuee0Y8cObdu27ZzX3PI3mpAHdFFhYWHg8cSJEzVjxgwNHTpU//M//6N+/fqFsDIgOG699dbA4wkTJmjixIkaOXKk1q9frzlz5oSwssh13333affu3e3u30XwddTns+8XnTBhgrKzszVnzhwdOHBAI0eOdLrMiDZ27FiVlpaqurpaf/zjH7VgwQJt2LAh1GW5Tkd9HjduHNdzkBw5ckT333+/1q1bp7i4uFCX02uYrulCaWlp8nq956wCVFFRoaysrBBV5T6pqakaM2aM9u/fr6ysLLW0tOj06dPtjqHnF8/fvwtdz1lZWecsKtTW1qaTJ0/S/4swYsQIpaWlaf/+/ZLoc3ctWrRIL7/8st566y0NHjw48HxX/l5kZWWd95r3v4Z/6KjP5zNjxgxJandN0+euiYmJ0ahRozRlyhQtX75ceXl5evLJJ7meg6yjPp8P13PPlJSUqLKyUpdffrmioqIUFRWlDRs26KmnnlJUVJQyMzNdcU0T8lwoJiZGU6ZMUXFxceA527ZVXFzcbl43Lk5dXZ0OHDig7OxsTZkyRdHR0e16/uGHH+rw4cP0/CINHz5cWVlZ7XpbU1OjrVu3Bnqbn5+v06dPq6SkJHDMm2++Kdu2A/8RRPd9+umnOnHihLKzsyXR564yxmjRokV68cUX9eabb2r48OHtXu/K34v8/Hzt2rWrXahet26dkpOTA1O3+rrO+nw+paWlktTumqbPPWPbtpqbm7mee5m/z+fD9dwzc+bM0a5du1RaWhr4mTp1qm6//fbAY1dc06Fe+QW947nnnjOxsbFm7dq1Zu/eveaee+4xqamp7VYBQvd8//vfN+vXrzcHDx40mzZtMgUFBSYtLc1UVlYaY4y59957zZAhQ8ybb75ptm/fbvLz801+fn6Iq44MtbW1ZufOnWbnzp1Gklm5cqXZuXOn+eSTT4wxxjz66KMmNTXV/OUvfzHvvfeemTdvnhk+fLhpbGwMnOPaa681kydPNlu3bjV/+9vfzOjRo81tt90Wqq8Uli7U59raWrN48WKzefNmc/DgQfPGG2+Yyy+/3IwePdo0NTUFzkGfO/ed73zHpKSkmPXr15uysrLAT0NDQ+CYzv5etLW1mfHjx5trrrnGlJaWmtdee82kp6ebJUuWhOIrhaXO+rx//37z8MMPm+3bt5uDBw+av/zlL2bEiBFm1qxZgXPQ56558MEHzYYNG8zBgwfNe++9Zx588EFjWZZ5/fXXjTFcz8FyoT5zPfeuz69c6oZrmpDnYk8//bQZMmSIiYmJMdOnTzdbtmwJdUkRbf78+SY7O9vExMSYQYMGmfnz55v9+/cHXm9sbDTf/e53Tf/+/U18fLy56aabTFlZWQgrjhxvvfWWkXTOz4IFC4wxZ7ZR+MlPfmIyMzNNbGysmTNnjvnwww/bnePEiRPmtttuM4mJiSY5OdksXLjQ1NbWhuDbhK8L9bmhocFcc801Jj093URHR5uhQ4eau++++5x/GKLPnTtfjyWZ//qv/woc05W/F4cOHTKFhYWmX79+Ji0tzXz/+983ra2tDn+b8NVZnw8fPmxmzZplBgwYYGJjY82oUaPMD37wA1NdXd3uPPS5c3fddZcZOnSoiYmJMenp6WbOnDmBgGcM13OwXKjPXM+96/Mhzw3XtGWMMc6NGwIAAAAAehP35AEAAACAixDyAAAAAMBFCHkAAAAA4CKEPAAAAABwEUIeAAAAALgIIQ8AAAAAXISQBwAAAAAuQsgDAAAAABch5AEAEIaGDRumJ554ItRlAAAiECEPANDn3XnnnfrqV78qSZo9e7YeeOABxz577dq1Sk1NPef5bdu26Z577nGsDgCAe0SFugAAANyopaVFMTExPX5/enp6EKsBAPQljOQBAPCZO++8Uxs2bNCTTz4py7JkWZYOHTokSdq9e7cKCwuVmJiozMxM3XHHHTp+/HjgvbNnz9aiRYv0wAMPKC0tTXPnzpUkrVy5UhMmTFBCQoJyc3P13e9+V3V1dZKk9evXa+HChaqurg583k9/+lNJ507XPHz4sObNm6fExEQlJyfrlltuUUVFReD1n/70p5o0aZJ+97vfadiwYUpJSdGtt96q2trawDF//OMfNWHCBPXr108DBw5UQUGB6uvre6mbAIBQIeQBAPCZJ598Uvn5+br77rtVVlamsrIy5ebm6vTp0/rSl76kyZMna/v27XrttddUUVGhW265pd37n3nmGcXExGjTpk1avXq1JMnj8eipp57Snj179Mwzz+jNN9/UD3/4Q0nSlVdeqSeeeELJycmBz1u8ePE5ddm2rXnz5unkyZPasGGD1q1bp48//ljz589vd9yBAwf00ksv6eWXX9bLL7+sDRs26NFHH5UklZWV6bbbbtNdd92l999/X+vXr9fNN98sY0xvtBIAEEJM1wQA4DMpKSmKiYlRfHy8srKyAs//8pe/1OTJk/XII48EnluzZo1yc3P10UcfacyYMZKk0aNH67HHHmt3zrPv7xs2bJh+9rOf6d5779WvfvUrxcTEKCUlRZZltfu8zysuLtauXbt08OBB5ebmSpKeffZZXXbZZdq2bZumTZsm6UwYXLt2rZKSkiRJd9xxh4qLi/Xzn/9cZWVlamtr080336yhQ4dKkiZMmHAR3QIAhCtG8gAA6MTf//53vfXWW0pMTAz8XHLJJZLOjJ75TZky5Zz3vvHGG5ozZ44GDRqkpKQk3XHHHTpx4oQaGhq6/Pnvv/++cnNzAwFPksaNG6fU1FS9//77geeGDRsWCHiSlJ2drcrKSklSXl6e5syZowkTJujrX/+6/vM//1OnTp3qehMAABGDkAcAQCfq6up0ww03qLS0tN3Pvn37NGvWrMBxCQkJ7d536NAhfeUrX9HEiRP1pz/9SSUlJVq1apWkMwuzBFt0dHS73y3Lkm3bkiSv16t169bp1Vdf1bhx4/T0009r7NixOnjwYNDrAACEFiEPAICzxMTEyOfztXvu8ssv1549ezRs2DCNGjWq3c/ng93ZSkpKZNu2VqxYoSuuuEJjxozRsWPHOv28z7v00kt15MgRHTlyJPDc3r17dfr0aY0bN67L382yLF111VVatmyZdu7cqZiYGL344otdfj8AIDIQ8gAAOMuwYcO0detWHTp0SMePH5dt27rvvvt08uRJ3Xbbbdq2bZsOHDigv/71r1q4cOEFA9qoUaPU2tqqp59+Wh9//LF+97vfBRZkOfvz6urqVFxcrOPHj593GmdBQYEmTJig22+/XTt27NC7776rb37zm/riF7+oqVOndul7bd26VY888oi2b9+uw4cP689//rOqqqp06aWXdq9BAICwR8gDAOAsixcvltfr1bhx45Senq7Dhw8rJydHmzZtks/n0zXXXKMJEybogQceUGpqqjyejv9TmpeXp5UrV+oXv/iFxo8fr9///vdavnx5u2OuvPJK3XvvvZo/f77S09PPWbhFOjMC95e//EX9+/fXrFmzVFBQoBEjRuj555/v8vdKTk7Wxo0bdd1112nMmDH68Y9/rBUrVqiwsLDrzQEARATLsHYyAAAAALgGI3kAAAAA4CKEPAAAAABwEUIeAAAAALgIIQ8AAAAAXISQBwAAAAAuQsgDAAAAABch5AEAAACAixDyAAAAAMBFCHkAAAAA4CKEPAAAAABwEUIeAAAAALgIIQ8AAAAAXOT/A+GL7RvnPZzsAAAAAElFTkSuQmCC", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Final TV -- Sinkhorn: 2.41e-05, SNS: 2.27e-11, APDAGD: 9.93e-03\n" + ] + } + ], + "source": [ + "# =========================================================\n", + "# Figure 3: TV distance to P_eta* (iterations)\n", + "# =========================================================\n", + "\n", + "# Recompute TV for SNS by re-running\n", + "tv_sns_list = []\n", + "x_t, y_t, a_t = jnp.zeros(n), jnp.zeros(n), jnp.zeros(K + L)\n", + "\n", + "# Sinkhorn stage (N1=20)\n", + "for i in range(20):\n", + " P_t = compute_P(x_t, y_t, a_t, C, Ds, eta)\n", + " x_t = x_t + (jnp.log(r) - jnp.log(jnp.sum(P_t, axis=1).clip(1e-300))) / eta\n", + " P_t = compute_P(x_t, y_t, a_t, C, Ds, eta)\n", + " y_t = y_t + (jnp.log(c) - jnp.log(jnp.sum(P_t, axis=0).clip(1e-300))) / eta\n", + " a_t, t_t = newton_constraint_update(x_t, y_t, a_t, C, Ds, r, c, eta, K, L, n_newton=3)\n", + " x_t = x_t + t_t * jnp.ones(n)\n", + " P_t = compute_P(x_t, y_t, a_t, C, Ds, eta)\n", + " tv_sns_list.append(tv_distance(P_t, P_ref))\n", + "\n", + "# Newton stage (N2=80)\n", + "M_dim = K + L\n", + "v = jnp.concatenate([jnp.ones(n), -jnp.ones(n), jnp.zeros(M_dim)])\n", + "z_t = jnp.concatenate([x_t, y_t, a_t])\n", + "z_t = z_t - (jnp.dot(z_t, v) / jnp.dot(v, v)) * v\n", + "\n", + "for it in range(80):\n", + " xt, yt, at = z_t[:n], z_t[n:2*n], z_t[2*n:]\n", + " Pt = compute_P(xt, yt, at, C, Ds, eta)\n", + " P1v = jnp.sum(Pt, axis=1)\n", + " Pt1v = jnp.sum(Pt, axis=0)\n", + " P_flat_sorted = jnp.sort(Pt.ravel())[::-1]\n", + " keep = min(3 * n, n * n)\n", + " rho_v = float(P_flat_sorted[keep - 1]) if keep < n * n else 0.0\n", + " Ps = sparsify(Pt, rho_v)\n", + "\n", + " # Gradient\n", + " gx = r - P1v\n", + " gy = c - Pt1v\n", + " ga = jnp.zeros(M_dim)\n", + " for k_i in range(K):\n", + " ga = ga.at[k_i].set(jnp.exp(-eta * at[k_i] - 1.0) - jnp.sum(Pt * Ds[k_i]))\n", + " for l_i in range(L):\n", + " ga = ga.at[K + l_i].set(-jnp.sum(Pt * Ds[K + l_i]))\n", + " gf = jnp.concatenate([gx, gy, ga])\n", + " rv = jnp.sum(xt) - jnp.sum(yt)\n", + " greg = jnp.concatenate([-rv * jnp.ones(n), rv * jnp.ones(n), jnp.zeros(M_dim)])\n", + " gft = gf + greg\n", + "\n", + " # Hessian\n", + " dim = 2 * n + M_dim\n", + " xa_b = jnp.zeros((n, M_dim))\n", + " ya_b = jnp.zeros((n, M_dim))\n", + " for m_i in range(M_dim):\n", + " xa_b = xa_b.at[:, m_i].set(-eta * jnp.sum(Pt * Ds[m_i], axis=1))\n", + " ya_b = ya_b.at[:, m_i].set(-eta * jnp.sum(Pt * Ds[m_i], axis=0))\n", + " aa_b = jnp.zeros((M_dim, M_dim))\n", + " for m_i in range(M_dim):\n", + " for m_j in range(m_i, M_dim):\n", + " vv = -eta * jnp.sum(Pt * Ds[m_i] * Ds[m_j])\n", + " aa_b = aa_b.at[m_i, m_j].set(vv)\n", + " if m_j > m_i:\n", + " aa_b = aa_b.at[m_j, m_i].set(vv)\n", + " if m_i < K:\n", + " aa_b = aa_b.at[m_i, m_i].add(-eta * jnp.exp(-eta * at[m_i] - 1.0))\n", + " Hm = jnp.zeros((dim, dim))\n", + " Hm = Hm.at[:n, :n].set(-eta * jnp.diag(P1v))\n", + " Hm = Hm.at[:n, n:2*n].set(-eta * Ps)\n", + " Hm = Hm.at[n:2*n, :n].set(-eta * Ps.T)\n", + " Hm = Hm.at[n:2*n, n:2*n].set(-eta * jnp.diag(Pt1v))\n", + " Hm = Hm.at[:n, 2*n:].set(xa_b)\n", + " Hm = Hm.at[n:2*n, 2*n:].set(ya_b)\n", + " Hm = Hm.at[2*n:, :n].set(xa_b.T)\n", + " Hm = Hm.at[2*n:, n:2*n].set(ya_b.T)\n", + " Hm = Hm.at[2*n:, 2*n:].set(aa_b)\n", + " Hm = Hm - jnp.outer(v, v)\n", + "\n", + " dz = -jnp.linalg.solve(Hm, gft)\n", + " alpha_v = 1.0\n", + " def ft_eval(zz):\n", + " return dual_objective(zz[:n], zz[n:2*n], zz[2*n:], C, Ds, r, c, eta, K) - 0.5*(jnp.sum(zz[:n])-jnp.sum(zz[n:2*n]))**2\n", + " fc = ft_eval(z_t)\n", + " desc = jnp.dot(gft, dz)\n", + " for _ls in range(30):\n", + " zt_ = z_t + alpha_v * dz\n", + " if ft_eval(zt_) >= fc + 1e-4 * alpha_v * desc:\n", + " break\n", + " alpha_v *= 0.5\n", + " z_t = z_t + alpha_v * dz\n", + " Pt = compute_P(z_t[:n], z_t[n:2*n], z_t[2*n:], C, Ds, eta)\n", + " tv_sns_list.append(tv_distance(Pt, P_ref))\n", + "\n", + "# APDAGD TV\n", + "tv_apd_list = []\n", + "z_a = jnp.zeros(2*n + M_dim)\n", + "zeta_a = jnp.zeros(2*n + M_dim)\n", + "alpha_a, beta_a, L_a = 0.0, 0.0, 1.0\n", + "for k_iter in range(400):\n", + " M_a = L_a / 2.0\n", + " for _inner in range(50):\n", + " M_a = 2.0 * M_a\n", + " al_new = (1.0 + jnp.sqrt(1.0 + 4.0*M_a*beta_a))/(2.0*M_a)\n", + " be_new = beta_a + al_new\n", + " tau_a = al_new / be_new\n", + " lam_a = tau_a * zeta_a + (1.0-tau_a)*z_a\n", + " gx_a, gy_a, ga_a, _ = compute_gradient(lam_a[:n], lam_a[n:2*n], lam_a[2*n:], C, Ds, r, c, eta, K, L)\n", + " g_a = jnp.concatenate([gx_a, gy_a, ga_a])\n", + " zeta_new = zeta_a + al_new * g_a\n", + " z_new_a = tau_a * zeta_new + (1.0-tau_a)*z_a\n", + " f_z_a = dual_objective(z_new_a[:n], z_new_a[n:2*n], z_new_a[2*n:], C, Ds, r, c, eta, K)\n", + " f_l_a = dual_objective(lam_a[:n], lam_a[n:2*n], lam_a[2*n:], C, Ds, r, c, eta, K)\n", + " diff_a = z_new_a - lam_a\n", + " if f_z_a >= f_l_a + jnp.dot(g_a, diff_a) - (M_a/2.0)*jnp.dot(diff_a, diff_a):\n", + " break\n", + " L_a = M_a / 2.0\n", + " z_a = z_new_a\n", + " zeta_a = zeta_new\n", + " alpha_a = al_new\n", + " beta_a = be_new\n", + " Pa = compute_P(z_a[:n], z_a[n:2*n], z_a[2*n:], C, Ds, eta)\n", + " tv_apd_list.append(tv_distance(Pa, P_ref))\n", + "\n", + "fig, ax = plt.subplots(1, 1, figsize=(9, 6))\n", + "ax.semilogy(range(1, len(tv_sink)+1), tv_sink, label=\"Sinkhorn\", linewidth=1.5)\n", + "ax.semilogy(range(1, len(tv_apd_list)+1), tv_apd_list, label=\"APDAGD\", linewidth=1.5, linestyle=\"--\")\n", + "ax.semilogy(range(1, len(tv_sns_list)+1), tv_sns_list, label=\"Sinkhorn-Newton-Sparse\", linewidth=1.5)\n", + "ax.axvline(20, color=\"gray\", linestyle=\"--\", alpha=0.5, label=\"Switch to Newton (SNS)\")\n", + "ax.set_xlabel(\"Iterations\")\n", + "ax.set_ylabel(r\"TV distance to $P_\\eta^\\star$\")\n", + "ax.set_title(r\"Figure 3: Convergence in TV distance (random assignment, n=500, $\\eta$=1200)\")\n", + "ax.legend(fontsize=10)\n", + "ax.grid(True, alpha=0.3)\n", + "plt.tight_layout()\n", + "plt.show()\n", + "\n", + "print(f\"Final TV -- Sinkhorn: {tv_sink[-1]:.2e}, SNS: {tv_sns_list[-1]:.2e}, APDAGD: {tv_apd_list[-1]:.2e}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Aggregating 50 seeds...\n" + ] + } + ], + "source": [ + "# ============================================================\n", + "# 0. Random instance generator\n", + "# ============================================================\n", + "def generate_random_instance(n, key, K=1, L=1):\n", + " \"\"\"Random OT instance with K inequality + L equality constraints.\"\"\"\n", + " k1, k2, k3 = jax.random.split(key, 3)\n", + " C = jax.random.uniform(k1, (n, n))\n", + " r = jnp.ones(n) / n\n", + " c = jnp.ones(n) / n\n", + " DI = convert_leq_constraint(jax.random.uniform(k2, (n, n)), 0.5, n)\n", + " DE = convert_eq_constraint( jax.random.uniform(k3, (n, n)), 0.5, n)\n", + " Ds = jnp.stack([DI, DE]) # shape (K+L, n, n)\n", + " return C, r, c, Ds, K, L\n", + "\n", + "# ============================================================\n", + "# 1. SNS Newton step\n", + "# ============================================================\n", + "def sns_newton_step(x, y, a, C, Ds, r, c, eta, K, L):\n", + " \"\"\"A single sparse Newton step for the full variable (x, y, a).\"\"\"\n", + " n = x.shape[0]\n", + " M_dim = K + L\n", + " dim = 2 * n + M_dim\n", + " z = jnp.concatenate([x, y, a])\n", + "\n", + " gx, gy, ga, Pt = compute_gradient(x, y, a, C, Ds, r, c, eta, K, L)\n", + " gf = jnp.concatenate([gx, gy, ga])\n", + "\n", + " v = jnp.concatenate([jnp.ones(n), -jnp.ones(n), jnp.zeros(M_dim)])\n", + " rv = jnp.sum(x) - jnp.sum(y)\n", + " gft = gf - rv * v\n", + "\n", + " P1v, Pt1v = jnp.sum(Pt, axis=1), jnp.sum(Pt, axis=0)\n", + " keep = min(3 * n, n * n)\n", + " rho = float(jnp.sort(Pt.ravel())[::-1][keep - 1]) if keep < n * n else 0.0\n", + " Ps = jnp.where(Pt >= rho, Pt, 0.0)\n", + "\n", + " xa_b = jnp.zeros((n, M_dim))\n", + " ya_b = jnp.zeros((n, M_dim))\n", + " for m in range(M_dim):\n", + " xa_b = xa_b.at[:, m].set(-eta * jnp.sum(Pt * Ds[m], axis=1))\n", + " ya_b = ya_b.at[:, m].set(-eta * jnp.sum(Pt * Ds[m], axis=0))\n", + "\n", + " aa_b = jnp.zeros((M_dim, M_dim))\n", + " for mi in range(M_dim):\n", + " for mj in range(mi, M_dim):\n", + " val = -eta * jnp.sum(Pt * Ds[mi] * Ds[mj])\n", + " aa_b = aa_b.at[mi, mj].set(val)\n", + " if mj > mi:\n", + " aa_b = aa_b.at[mj, mi].set(val)\n", + " if mi < K:\n", + " aa_b = aa_b.at[mi, mi].add(-eta * jnp.exp(-eta * a[mi] - 1.0))\n", + "\n", + " Hm = jnp.zeros((dim, dim))\n", + " Hm = Hm.at[:n, :n ].set(-eta * jnp.diag(P1v))\n", + " Hm = Hm.at[:n, n:2*n].set(-eta * Ps)\n", + " Hm = Hm.at[n:2*n, :n ].set(-eta * Ps.T)\n", + " Hm = Hm.at[n:2*n, n:2*n].set(-eta * jnp.diag(Pt1v))\n", + " Hm = Hm.at[:n, 2*n: ].set(xa_b)\n", + " Hm = Hm.at[n:2*n, 2*n: ].set(ya_b)\n", + " Hm = Hm.at[2*n:, :n ].set(xa_b.T)\n", + " Hm = Hm.at[2*n:, n:2*n].set(ya_b.T)\n", + " Hm = Hm.at[2*n:, 2*n: ].set(aa_b)\n", + " Hm = Hm - jnp.outer(v, v)\n", + "\n", + " dz = -jnp.linalg.solve(Hm, gft)\n", + " f_curr = dual_objective(x, y, a, C, Ds, r, c, eta, K)\n", + " descent = jnp.dot(gft, dz)\n", + " alpha = 1.0\n", + "\n", + " for _ in range(20):\n", + " z_new = z + alpha * dz\n", + " f_new = dual_objective(z_new[:n], z_new[n:2*n], z_new[2*n:],\n", + " C, Ds, r, c, eta, K)\n", + " if f_new <= f_curr + 1e-4 * alpha * descent:\n", + " break\n", + " alpha *= 0.5\n", + "\n", + " if alpha < 1e-6:\n", + " return x, y, a\n", + " z_new = z + alpha * dz\n", + " return z_new[:n], z_new[n:2*n], z_new[2*n:]\n", + "\n", + "# ============================================================\n", + "# 2. Aggregation across random seeds\n", + "# ============================================================\n", + "n, eta = 100, 1200\n", + "K, L = 1, 1\n", + "N1, N2 = 20, 80\n", + "n_iters = N1 + N2\n", + "n_seeds = 50\n", + "\n", + "all_tv_sns = []\n", + "print(f\"Aggregating {n_seeds} seeds (n={n}, eta={eta}) ...\")\n", + "\n", + "for seed in range(n_seeds):\n", + " key = jax.random.PRNGKey(seed)\n", + " C_s, r_s, c_s, Ds_s, K_s, L_s = generate_random_instance(n, key, K=K, L=L)\n", + "\n", + " _, _, _, P_ref_s, _ = constrained_sinkhorn(\n", + " C_s, Ds_s, r_s, c_s, eta, K=K_s, L=L_s,\n", + " n_iters=250, n_newton=5, verbose=False\n", + " )\n", + "\n", + " x_s, y_s, a_s = jnp.zeros(n), jnp.zeros(n), jnp.zeros(K_s + L_s)\n", + " tv_vals = []\n", + "\n", + " for i in range(n_iters):\n", + " if i < N1:\n", + " P_curr = compute_P(x_s, y_s, a_s, C_s, Ds_s, eta)\n", + " x_s = x_s + (jnp.log(r_s) - jnp.log(jnp.sum(P_curr, axis=1).clip(1e-300))) / eta\n", + " P_curr = compute_P(x_s, y_s, a_s, C_s, Ds_s, eta)\n", + " y_s = y_s + (jnp.log(c_s) - jnp.log(jnp.sum(P_curr, axis=0).clip(1e-300))) / eta\n", + " a_s, t_s = newton_constraint_update(x_s, y_s, a_s, C_s, Ds_s,\n", + " r_s, c_s, eta, K_s, L_s)\n", + " x_s = x_s + t_s * jnp.ones(n)\n", + " else:\n", + " if i == N1:\n", + " v = jnp.concatenate([jnp.ones(n), -jnp.ones(n), jnp.zeros(K_s + L_s)])\n", + " z = jnp.concatenate([x_s, y_s, a_s])\n", + " z = z - (jnp.dot(z, v) / jnp.dot(v, v)) * v\n", + " x_s, y_s, a_s = z[:n], z[n:2*n], z[2*n:]\n", + " x_s, y_s, a_s = sns_newton_step(x_s, y_s, a_s, C_s, Ds_s,\n", + " r_s, c_s, eta, K_s, L_s)\n", + "\n", + " P_iter = compute_P(x_s, y_s, a_s, C_s, Ds_s, eta)\n", + " tv_vals.append(float(jnp.maximum(tv_distance(P_iter, P_ref_s), 1e-16)))\n", + "\n", + " all_tv_sns.append(tv_vals)\n", + " if (seed + 1) % 10 == 0:\n", + " print(f\" seed {seed+1:3d}/{n_seeds} | \"\n", + " f\"TV after Sinkhorn: {tv_vals[N1-1]:.2e} | \"\n", + " f\"TV after Newton: {tv_vals[-1]:.2e}\")\n", + "\n", + "# ============================================================\n", + "# 3. Plot\n", + "# ============================================================\n", + "all_tv_sns = np.array(all_tv_sns)\n", + "iters = np.arange(1, n_iters + 1)\n", + "median = np.median(all_tv_sns, axis=0)\n", + "q25 = np.percentile(all_tv_sns, 25, axis=0)\n", + "q75 = np.percentile(all_tv_sns, 75, axis=0)\n", + "\n", + "plt.figure(figsize=(8, 6))\n", + "plt.semilogy(iters, median, color='tab:green', lw=2, label='SNS Median')\n", + "plt.fill_between(iters, q25, q75, color='tab:green', alpha=0.2, label='IQR 25–75%')\n", + "plt.axvline(N1, color=\"red\", linestyle=\"--\", label=\"Switch to Newton\")\n", + "plt.xlabel(\"Iterations\")\n", + "plt.ylabel(r\"TV distance to $P_\\eta^\\star$\")\n", + "plt.title(rf\"SNS Performance over {n_seeds} Seeds ($n$={n}, $\\eta$={eta})\")\n", + "plt.legend()\n", + "plt.grid(True, which='both', alpha=0.3)\n", + "plt.tight_layout()\n", + "plt.savefig(\"figure4.png\", dpi=150)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Experiment 2: 1D Transport under $\\ell_2$ Constraint (Figure 1)\n", + "\n", + "Reproducing **Figure 1** of the paper:\n", + "- Main cost = $\\ell_1$ Manhattan distance: $c_1(x,y) = |x-y|$\n", + "- Inequality constraint on $\\ell_2^2$ Euclidean distance: $\\langle C_2, P\\rangle \\le t$\n", + "- By tuning $t$, the transport plan interpolates from minimising $\\ell_2$ cost (binding constraint)\n", + " to minimising $\\ell_1$ cost (no constraint).\n", + "\n", + "$$\\min_{P \\in \\mathcal{U}_{r,c}} \\langle C_1, P\\rangle \\quad\\text{s.t.}\\quad \\langle C_2, P\\rangle \\le t$$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "t_min (W2 distance) = 0.0275\n", + "t_max (L2 cost of L1-optimal plan) = 0.0310\n", + "[1/4] t=0.0275 | L1=0.1051, L2=0.0275, viol=5.73e-06\n", + "[2/4] t=0.0287 | L1=0.1060, L2=0.0287, viol=3.71e-06\n", + "[3/4] t=0.0298 | L1=0.1067, L2=0.0298, viol=1.95e-06\n", + "[4/4] t=0.0310 | L1=0.1074, L2=0.0310, viol=0.00e+00\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAB7cAAAIICAYAAAAFVmqVAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAx6ZJREFUeJzs3XmcjXXj//H3md02dgY3JmtkJxLCTUaLpaSS7BHfhObO+rOmIt2kUtxkS0npRt0qNcYSZSnDt6whW2WnGQazXr8/+s65HTPXmXPOnDNzHfN63o/r8bhdn8/1OZ/rLPPu+nyuxWYYhiEAAAAAAAAAAAAAACwsIK87AAAAAAAAAAAAAABAdpjcBgAAAAAAAAAAAABYHpPbAAAAAAAAAAAAAADLY3IbAAAAAAAAAAAAAGB5TG4DAAAAAAAAAAAAACyPyW0AAAAAAAAAAAAAgOUxuQ0AAAAAAAAAAAAAsDwmtwEAAAAAAAAAAAAAlsfkNgAAAAAAAAAAAADA8pjcBgAAAAAAAAAAAABYHpPbAAAAAAAAAAAAAADLY3IbAAAA8DObNm2SzWbTkiVL8rorgIPIyEi1adMmr7shKeu+mP12jh07pq5du6p06dKy2Wzq27ev0/VwXX78e5XVPufH98FKeP8BAACA2weT2wAAAIBFZAy+my3bt2/P6y76zLRp09S9e3dVqVJFNptNkZGRedL2rZ9BYGCgihcvrjp16qhPnz5at26dDMNw6/UTEhI0depUNWrUSEWKFFHBggVVu3ZtjRw5UmfPnnWo6+zzv3U5fvy4B+8GstK3b19t3rxZo0eP1rJly/Tss886Xe8P9uzZo8mTJ7v9Pdm2bZvuuusulSxZUvv37/dN5yDJ88/oduIv74G/9BMAAADID4LyugMAAAAAHPXo0UMPPvhgpvXVqlWTJN133326fv26goODc7trPjNu3DiVKFFCjRo10p9//pnnbWd8BoZh6MqVKzp06JDWrFmj999/X+3bt9fKlStVrFixbNv55ZdfFBUVpRMnTujRRx/VgAEDFBwcrO3bt+vNN9/U4sWL9Z///EfNmzeXJC1btsxh+y1btmj+/PkaNGiQWrVq5VBWunRpl/YF/5XVbycpKUlbtmzR0KFD9eKLL2a73l/s2bNHU6ZMUZs2bVw+WSQlJUVPPPGEevfurblz5youLk61a9fOUT9ux79XnsjqffDkM7rd5NZ7kNPvIZ8VAAAAYB1MbgMAAAAW06hRIz399NOm5QEBAQoLC8vFHmWWlpampKQkFSxY0CvtHT16VFWqVJEk1alTR1evXvVKu562ndVnMGvWLI0aNUqzZs1Sjx499NVXXzlt49q1a+rUqZN+//13/ec//9FDDz1kLxs0aJD+53/+R+3bt1eXLl30888/q2zZspleMzU1VfPnz1fz5s2dfiduV97+nmX12zl79qwMw1CJEiVcWu8N3t4vb1m7dq3OnTunIUOGaPr06SpfvnyO27TC3ysr4H3wHk9+P7z/AAAAwO2D25IDAAAAfsbs2aHHjx9Xt27dFB4ervDwcHXp0kXHjh3L9OzhyZMnm97WOqvnFC9ZskQ2m03r16/X1KlTVbVqVYWFhemTTz6R9NcVrq+++qruuusuhYWFqVixYurUqZN2797t8j5lTD674ujRozp48KBP2nYmMDBQM2fOVMuWLbVu3Tpt3brVaf2FCxfql19+0YgRIxwmtjM0adJEr776qs6fP6/XX3/dK33M4M5nnPH5btiwQf/85z9VtWpVhYaGqkaNGlq6dGmW7Z86dUqPP/64ihYtqvDwcHXq1ElHjx7Nsq6r34/svmdm3OnLrb+dvn37qnLlypKkKVOm2G/5HhkZmeX6TZs2eX2/3G3Hlc9p8uTJ6tevnySpbdu29v5n98zw1atXq3Xr1jp27JhsNpsaN27stL4rsvp75e53zp2/MWbfh6z+trnzO7ly5YrGjx+vZs2aqVSpUgoNDVW1atU0ZswYXbt2ze33wdlntHr1atlsNi1YsCDLtu666y5Vq1bNpUckJCcna8aMGWrQoIEKFiyookWLqkmTJpozZ45DvQsXLui5555TxYoVFRISoooVK+q5557TxYsXHeq589nduHFDkydPVs2aNVWwYEEVK1ZMdevW1ciRI7N9D259vax+P+58Jjn5HmbXz+z2EwAAAIB3ceU2AAAAYDHXrl3ThQsXHNaFhoaqSJEipttcvHhRrVq10tmzZzV48GDVqlVLW7ZsUdu2bZWYmOiVfr344otKSUnRwIEDFR4erpo1ayolJUUdO3bU999/r169emno0KGKj4/XggUL1KJFC3377bdq0qSJV14/Q7t27XTixAm3n33tLQMGDNDWrVv1xRdfqGXLlqb1Pv30U0l/XaVtpm/fvhoxYoT+/e9/65///KfX++qOcePG6fr163r22WcVGhqquXPnqm/fvqpWrZpatGhhr/fnn3/qvvvu06lTpzR48GDVrl1bmzdvVtu2bXX9+nWHNj35fmT1PTPjTl+y8uyzz6pBgwZ64YUX9Mgjj+jRRx+VJFWsWFG7d+/OtL5WrVpe3S9P2nHlc3r00Ud1+vRpzZ8/X+PGjbP3u2rVqk7fj6+//lovvviiYmJidO+996po0aLZvoc54cq+uPMe5fT74Mzvv/+u9957T926ddNTTz2loKAgbd68WTNmzNDu3bv19ddfu9Wes8/o7rvvVkREhBYtWqSBAwc6bLd9+3bt379fr7zyimw2m9PXSE5OVlRUlDZt2qQOHTro6aefVlhYmH7++WetWrVKQ4cOlSTFx8fr3nvv1ZEjR9S/f381atRIu3fv1ty5c7Vhwwbt3LkzU/648tk999xzWrRokXr37q3o6Gilpqbq8OHD2rBhQ7bvwa2y+v146zPJbl+y62d2+wkAAADAywwAAAAAlrBx40ZDUpbLE088kane4sWL7etGjhxpSDI++OADhzYz1rdu3dq+btKkSYYk49ixY5n6ULlyZYe6hmEYixcvNiQZNWrUMBITEx3KZs2aZUgy1q1b57A+Pj7eqFixYqa2XHHXXXcZlStXNi2vXLmy4emhTHZtZ7y3r7/+ummdXbt2GZKMRx991OlrlShRwihSpEi2fapbt64hybhy5Uqmsoz3/ubP2hXufMYZr9GgQQMjKSnJvv63334zQkJCjCeffNJh+7FjxxqSjEWLFjmsHz58eKbvmjvfD2ffMzPu9MUwsv7tHDt2zJBkTJo0yaGu2Xpv7pcn7bj6OWXU37hxY6b+Z+XgwYOGJGPz5s3GnXfeaSxZssSl7bKT1Xvuzr648x65+31w53eSlJRkJCcnZ6o3fvx4Q5KxY8cOp/vs7H3I6jPK2Jd9+/Y5rH/mmWeMwMBA4/fff8+0za1ee+01Q5IxduzYTGVpaWn2/z9u3DhDkvHOO+841JkzZ44hyRg/fnymPrvy2RUvXtx44IEHnPYxu++ps99PTj8Td/bFWT9d2U8AAAAA3sNtyQEAAACLGTRokGJiYhyW8ePHO93mP//5j8qVK6cePXo4rH/xxRe91q8hQ4ZkesbpBx98oDvvvFONGzfWhQsX7EtycrLuv/9+bd26NcdXTN7q+PHjeXbVtiSFh4dLkhISEpzWS0hIcOnK14z24uPjc965HPif//kfhYSE2P9doUIF1ahRQ4cPH3aot2bNGpUtW1a9e/d2WD969OhMbXry/cjqe2bGnb54k7f2y5N2XP2c3LV9+3bZbDZduXJFly5d0uOPP56j9lzhyr648x758vsQEhKi4OBgSVJqaqouX76sCxcuqH379pKkHTt25Pg1bjZw4EDZbDYtXLjQvi4xMVEff/yxHnjgAZeeh/7hhx+qePHimjhxYqaygID/DgetXr1apUuXznSXiWeffValS5fW6tWrM23vymdXtGhR7du3T3v37s22r9nJ6vfjrc8kp78pb+4nAAAAgOxxW3IAAADAYqpXr24fnHfVsWPH1LRpU4cJC0kqU6aMihUr5pV+1ahRI9O6AwcO6Pr16ypdurTpdhcuXFDFihW90gcryJjUzpiUNhMeHp7tBPjN7fn6FtDZyerZ5CVLltSJEycc1v3666+6++67FRgY6LC+XLlymb5rnnw/svqemXGnL97krf3ypB1XPyd37d+/XxUqVNC//vUvjRs3TgUKFLCXJSUlaejQoYqNjdX58+dVrlw5Pf/883r++edz9Jqu7Is775Gvvw/vvvuu5s2bp3379ik9Pd2h7PLlyzlu/2Z33HGH2rdvr2XLlmn69OkKDg62P2f6mWeecamNw4cPq0GDBgoLC3Na79ixY2rSpImCghyHiIKCglSjRg3FxcVl2saVz2727Nnq1auX6tatqypVqqht27bq1KmTOnXqlCmrsmP2d8Ebn0lOf1Pe3E8AAAAA2WNyGwAAAMhnnD2nNTU11bQsq6tpDcNQ3bp1NWvWLNPtnE1K+aOffvpJkpw+C1qS6tSpo2+//VZHjhxRtWrVsqxz7do1HTx4UJGRkSpcuLDX+ujJZ3zrhGCGnFwl78n3w9WrtvOSt/bLk3Z88TlJf00QJyUlad++fVq5cqVDWWpqqiIiIvTNN9+oSpUq+umnnxQVFaWyZcvm6ApvV/bFl39j3PmdzJo1S//4xz/UoUMHDRs2TOXLl1dISIh+//139e3bN9PEqjcMGjRI3bt31+eff65u3bpp4cKFioiI0EMPPeT113KXK59dly5ddPz4cX355ZfavHmz1q9fr4ULF6pVq1Zav369w9XS2cnq9+OtzySnvylv7icAAACA7DG5DQAAANwGIiMjdeTIEaWnpztcKXbu3Dn9+eefDnVLlCghSbp06ZIiIyPt62/cuKHTp0+bTsRmpXr16jp//rz+/ve/55sr1DJuE5zdBNOjjz6qb7/9Vu+9956mT5+eZZ33339fKSkpevTRR73aR29+xreqUqWKDh8+rLS0NIdJodOnT2f6rvn6++FOX7zJW/vly/fH2cStmfPnz2vx4sUKDQ11WF+oUCFNnTrV/u8GDRqoc+fO2rp1q89vX+7Oe+Tu98Gd38myZcsUGRmpr776yqEf69at83DPsv+MunTpojJlymjhwoWqU6eOvvvuO40ePTrTFdZmatSooYMHDyopKSnTZ3qzKlWq6NChQ0pNTXVoOzU1Vb/88kuWVza7qkSJEnr66af19NNPyzAMjRkzRjNmzNBnn32m7t27e/Q9zeCLz8RMdv3Mbj8BAAAAeE/+GH0CAAAAbnOdOnXS6dOn9dFHHzms/+c//5mpbsbtXdevX++w/o033nD76sPevXvrzJkzpldVnj171q32XHH06FEdPHjQ6+1mJy0tTS+++KK2bt2qBx98UC1atHBa/5lnnlG1atU0a9asLCdb4uLiNHbsWJUuXVojR470al+9+RnfqkuXLjp79qzef/99h/WvvfZaprq+/n640xdv8tZ++fL9ybgTwKVLl1yqf+bMGTVo0MClq4JTUlK0ZcsW1atXz+P+ucqd98jd74M7v5PAwEDZbDaHq3lTU1NNT1xxRXafUXBwsPr27auvv/5aU6ZMkSQNGDDA5fZ79uypy5cv6+WXX85UdvN+dO3aVefPn9d7773nUGfBggU6f/68HnnkEZdfM0NaWlqmEwpsNpsaNmwo6b/77O739Ga++EzMmPXT1f0EAAAA4D1cuQ0AAADcBkaPHq3ly5erX79+2rlzp+68805t2bJF33//vUqVKuVw1Vn79u1Vs2ZNTZw4URcvXtQdd9yhrVu3avv27SpVqpRbrzt8+HDFxMRo5MiR2rBhg/7+978rPDxcJ0+eVGxsrMLCwrRx48Zs21m2bJn9+abnz59XcnKyfUKmcuXK6tWrl71uu3btdOLECZdvGetO2xni4uL0wQcfSJKuXLmiQ4cOac2aNTpx4oQ6dOig5cuXZ/u6hQoV0ueff66OHTvqoYceUrdu3dSmTRsFBQVp586dWrZsmQoXLqw1a9YoIiLCpX1xlTc/41uNGjVKy5cv18CBA7Vr1y7ddddd2rRpk7Zt25apbW99P7zRF2/y1n758v25++67FRAQoFdeeUWXL19WoUKFdMcdd6hZs2aZ6n700UeKiYlRYGCgfvvtN+3bt0//+te/9PHHHys4ODhT/aFDh6pIkSLq3bu3R31zhzvvkbvfB3d+J4899pjGjh2rBx54QI8++qgSEhK0fPnyLN8fV7nyGQ0cOFCvv/66PvroI7Vu3VrVq1d3uf3hw4frP//5j15++WX98MMP6tChg8LCwrRv3z4dOnTIPqk/atQorVy5Us8995zi4uLUsGFD7d69WwsXLlTNmjU1atQot/ftypUrKleunDp37qyGDRuqTJkyOnbsmObOnavixYurU6dOLr8HZnzxmZgx62fNmjVd2k8AAAAA3sPkNgAAAHAbKFWqlLZu3ap//OMfWrRokWw2m9q2bauNGzfq7rvvVoECBex1AwMD9fnnn2vYsGF6++23FRISog4dOmjz5s3ZXo18q+DgYH3xxRd69913tWzZMk2aNEmSVL58eTVt2lR9+vRxqZ2FCxdq8+bNDusmTJggSWrdunWWE9Cu8qTtjz76SB999JECAgJUuHBh/e1vf1Pr1q3Vo0cPdezY0eXXrlWrln766Se9+eabWrVqlb788kulpaWpcuXKev755/Xiiy96fWJb8u5nfKvixYtry5Ytio6Otl8h27p1a23cuFHt2rVzqOut74c3+uJN3tovX74/lSpV0qJFi/Taa69pyJAhSklJUZ8+fTJNGiYlJWnlypVasWKFYmNj1bRpU1WuXFmLFy/OcpIwOjpa27Zt04YNG3LlWcLuvEfufh/c+Z2MHDlShmFo4cKFGj58uCIiIvTEE0+oX79+ql27tkf75spnVK1aNbVt21YbNmxw66ptSQoJCdE333yjmTNnavny5Ro3bpzCwsJUvXp19evXz16vaNGi+u677zRp0iR9/vnnWrx4scqWLavBgwdrypQpKlKkiNv7VrBgQY0YMUKxsbFav369rl69ap8EHjt2rMqXL+/ye2DGF5+JGbN+zp8/36X9BAAAAOA9NsPVyx0AAAAA+J2LFy+qVKlSevbZZzVv3ry87g4AP5YxibdhwwaVLl06r7vjlsjISEVGRmrTpk153RW3Pfjgg9q2bZv++OMPhxOVAAAAACA/4pnbAAAAwG3i+vXrmdZlPHv0/vvvz+3uALiNDBs2TOvXr/fLiW1/duTIEX399dd6+umnmdgGAAAAAHHlNgAAAHDbaNu2rSpXrqxGjRopPT1dsbGxWrt2re699159++23CgwMzOsuAvBDJ06cUGRkpEJDQxUU9N+nm7Vq1UpfffVVHvbMdf525faOHTt04MABvfXWWzpw4IAOHDigyMjIvO4WAAAAAOQ5nrkNAAAA3CYefvhhvf/++1q9erWuX7+uv/3tb/rHP/6hSZMmMbENwGOVK1cW58Xnrrlz5+r9999XlSpV9OGHHzKxDQAAAAD/hyu3AQAAAAAAAAAAAACWxzO3AQAAAAAAAAAAAACWx+Q2AAAAAAAAAAAAAMDymNwGAAAAAAAAAAAAAFgek9sAAAAAAAAAAAAAAMtjchsAAAAAAAAAAAAAYHlMbgMAAAAAAAAAAAAALI/JbQAAAAAAAAtISkrSwIEDVaVKFRUpUkQ1atTQ22+/ndfdAgAAAADLCMrrDgAAAAAAAEBKTU1VRESEvvnmG1WpUkU//fSToqKiVLZsWT3++ON53T0AAAAAyHNcuQ0AAAD4qb179yooKEgxMTF51odNmzbJZrNpyZIledYH+L8lS5bIZrNp06ZNed0VWFR++VtTqFAhTZ06VdWqVVNAQIAaNGigzp07a+vWrXndNZdk9Tm5+9lFRkaqTZs2PulfTnz22WcKCQnR4cOH87orAAAAQL7G5DYAAADgp6Kjo9WiRQvdf//9ed2V28qePXs0efJkHT9+PK+7om3btumuu+5SyZIltX///hy3lzHJZLPZNHTo0CzrnDt3TiEhIbLZbJacYHKHs8/SSp+zv/OX99JK/XT1t52SkqItW7aoXr162bZ58+87qyUoiJv35USXLl1Ut25djR49Oq+7AgAAAORrTG4DAAAAfmjbtm2KiYlRdHR0nvbjvvvu0/Xr19WrV6887Yc37dmzR1OmTMnzCbCUlBQ98cQTeuSRRyRJcXFxXms7LCxMy5cvV1JSUqayZcuWyTCM22IizNlnaZXP+XaQG++lN/7WWOUzd+e3PXToUBUpUkS9e/d2uf0ePXpo2bJlmZb3338/x333xO2UE8OHD9fq1au1b9++vO4KAAAAkG8xuQ0AAAD4oXfffVelSpXSgw8+mKf9CAgIUFhYmAIDA/O0H7ejtWvX6ty5cxoyZIji4+NVvnx5r7X9yCOP6PLly/rss88ylS1evFgPPvigQkNDvfZ6wM3S0tJ07do1t7a5nf7WuPrbjo6O1rZt2/TVV18pJCTE5fYbNWqkp59+OtPy1FNPeWsX3HI7fXaPPvqoChYsqHnz5uV1VwAAAIB8i8ltAAAAwM+kpqZqzZo1at++vYKDgx3KMp5dHBsbq5deekmVK1dWgQIF1KxZM23fvl2StHnzZrVs2VKFChVSuXLlNHXqVIc2rly5ovHjx6tZs2YqVaqUQkNDVa1aNY0ZMybThFRWz1LN6MOGDRv0z3/+U1WrVlVoaKhq1KihpUuXuryfycnJmjFjhho0aKCCBQuqaNGiatKkiebMmeNQ78KFC3ruuedUsWJFhYSEqGLFinruued08eJFh3o3btzQ5MmTVbNmTRUsWFDFihVT3bp1NXLkSHudyZMnq1+/fpKktm3b2m/n27dvX5f77S2rV69W69atdezYMdlsNjVu3NhrbTdq1Ej16tXT4sWLHdbv3LlT+/bts78HN3Pne+HpdyA9PT3b+q72w9lnmd3nnBv76ipv/w7c6a83fjMZr7d+/XpNnTpVVatWVVhYmD755JNc/Vvjb7/tESNGKCYmRrGxsSpVqpRP+jF58mTZbLYsr2TP6rnXrn4Xb2X2zO1Tp07p8ccfV9GiRRUeHq5OnTrp6NGjWbaRlJSkV199VXfddZfCwsJUrFgxderUSbt373ao5+vfbuHChdWqVSt9+umnTvcZAAAAgO/4/33mAAAAgHxm165dunr1qpo2bWpaZ8yYMUpLS9Pw4cOVnJysmTNnqkOHDnr//fc1YMAADRo0SD179tQnn3yiiRMn6o477tDTTz8tSfr999/13nvvqVu3bnrqqacUFBSkzZs3a8aMGdq9e7e+/vprl/o5btw4Xb9+Xc8++6xCQ0M1d+5c9e3bV9WqVVOLFi2cbpucnKyoqCht2rRJHTp00NNPP62wsDD9/PPPWrVqlf150fHx8br33nt15MgR9e/fX40aNdLu3bs1d+5cbdiwQTt37lSRIkUkSc8995wWLVqk3r17Kzo6WqmpqTp8+LA2bNhgf91HH31Up0+f1vz58zVu3DjVqlVLklS1alXTvqanp+vSpUsuvSeSVKJECQUEZH+e8ddff60XX3xRMTExuvfee1W0aFGXX8MV/fv3V3R0tH7//XdVqFBBkrRo0SKVKVNGDz/8cKb6nnwv3P0OuFLf1X44+ywLFSrk9HPOjX11hS9+B+7015u/mRdffFEpKSkaOHCgwsPDVbNmzVz9W+PJb1vyze87u9/2sGHDtGHDBm3cuFGlS5d2+bUzXLt2TRcuXMi0PiQkROHh4W63J7n+XXTVn3/+qfvuu0+nTp3S4MGDVbt2bW3evFlt27bV9evXHeqmpKSoY8eO+v7779WrVy8NHTpU8fHxWrBggVq0aKFvv/1WTZo0kZQ7v93mzZvr66+/1sGDB3XnnXe6td8AAAAAvMAAAAAA4FcWLVpkSDI+++yzTGWLFy82JBkNGzY0kpKS7Os/++wzQ5IRFBRk/PDDD/b1SUlJRkREhHHPPfc4rEtOTs7U9vjx4w1Jxo4dO+zrNm7caEgyFi9enKkPDRo0cOjDb7/9ZoSEhBhPPvlktvv42muvGZKMsWPHZipLS0uz//9x48YZkox33nnHoc6cOXMMScb48ePt64oXL2488MAD2b52Rv83btyYbV3DMIxjx44Zklxejh07lm2bBw8eNCQZmzdvNu68805jyZIlLvUlOxmf1+uvv25cuHDBCAkJMV555RXDMAzj2rVrRtGiRY1//OMfhmEYRqFChYzWrVvbt3Xne+Hud8Cd+p70I6vP0lmZL/fVHb74HbjTX2/8ZjLKatSoYSQmJjqU5fbfGnd/24bh/d93dr/t48ePG5KM0NBQo1ChQvalY8eO2fY14z0yWx566CGH+pMmTTLtc+XKlR1+/65+F7P6nLJaN3bsWEOSsWjRIoe2hg8fbkhyeO1Zs2YZkox169Y51I2PjzcqVqyYa3+nMixbtsyQZHz66adZlgMAAADwLa7cBgAAAPzM+fPnJf11haCZIUOGODyjtVWrVpKkZs2a2a9wk/66kq9p06b67rvvHNZlSE1N1ZUrV5SWlqb27dvr5Zdf1o4dO5xeNZ7hf/7nfxzaqlChgmrUqKHDhw9nu+2HH36o4sWLa+LEiZnKbr4qcvXq1SpdurQGDRrkUOfZZ5/VlClTtHr1avtt14sWLap9+/Zp7969qlOnTrZ9cFVERIRiYmLcqp+d7du3y2az6cqVK7p06ZIef/zxnHQxSyVLllTnzp21ZMkSjRs3TqtWrVJ8fLz69++fZX1Pvhfufgdcqe+t76czubGvrvDF78Cd/nrzNzNkyBAVLFjQYZ0V/tZkx9u/7+x+25UrV5ZhGB71NcOgQYPUvXv3TOs9uQo8g6vfRVetWbNGZcuWVe/evR3Wjx49Wm+++abDug8++EB33nmnGjdunOmK9Pvvv19Lly7V9evXVaBAgVz57ZYsWVKSdO7cObf3GwAAAEDOMbkNAAAA+BmbzSZJTidAqlSp4vDv4sWLS5LuuOOOTHWLFy+e6bm87777rubNm6d9+/YpPT3doezy5csu9fPWPkh/TQqcOHEi220PHz6sBg0aKCwszGm9Y8eOqUmTJgoKcjy0CQoKUo0aNRQXF2dfN3v2bPXq1Ut169ZVlSpV1LZtW3Xq1EmdOnXyaHImQ1hYmNq3b+/x9lnZv3+/KlSooH/9618aN26cChQo4FD+ySef6K233tKePXtUqlSpLJ+Z64p+/frpoYce0tatW7Vo0SI1bdpUtWvXNq3v7vfC3e+Aq/W98f3Mjq/31RW++B24019v/mZq1KiR5fq8/luTHW//vrP7bXtD9erVvf43ydXvoqt+/fVX3X333QoMDHRYX65cORUrVsxh3YEDB3T9+nWnk/MXLlxQxYoVJfn+t5uRvRlZDAAAACB3MbkNAAAA+JmMAX5nz4G9dcIgu/U3mzVrlv7xj3+oQ4cOGjZsmMqXL6+QkBD9/vvv6tu3b6bJAnf7kNOrEj3VpUsXHT9+XF9++aU2b96s9evXa+HChWrVqpXWr1/vcOWeO9LS0uxX07uidOnS2X4OFy5cUFJSkvbt26eVK1dmKi9evLiGDh2qs2fP6o033nC7zxmioqJUoUIFTZkyRRs3btTcuXNN63ryvXD3O+BKfW99P53JjX3Na67015u/mVuv2pb842+Nt3/f2f22c5OzydnU1NRc7IlzhmGobt26mjVrlmmdjFzMjd9uRvbm5Ep4AAAAAJ5jchsAAADwMxm3B/bGLXezsmzZMkVGRuqrr75yuDpz3bp1Pnm9rNSoUUMHDx5UUlKSQkNDTetVqVJFhw4dUmpqqsNVq6mpqfrll18yXZFXokQJPf3003r66adlGIbGjBmjGTNm6LPPPrPfxtfdq/FOnTqV5RXxZo4dO6bIyMhs650/f16LFy/Ocv/vv/9+SX/d2jcnAgMD1bt3b02bNk0FChRQjx49TOta4Xvhbj+cfZbOyqyyr776HbjDF7+ZDLn9PnvST1/8vs1+20lJSRo6dKhiY2N1/vx5lStXTs8//7yef/55t/vtioxHW1y6dMmhzzdu3NDp06dVrVo1+zpXv4uuqlKlig4fPqy0tDSHyeXTp0/rzz//dKhbvXp1nT9/Xn//+9+zvWNAbnynjhw5IklefbwFAAAAANd5fu89AAAAAHmiYcOGCg8P1/bt233SfmBgoGw2m8NVa6mpqZo+fbpPXi8rPXv21OXLl/Xyyy9nKru5X127dtX58+f13nvvOdRZsGCBzp8/r0ceeUTSX1df3jphYrPZ1LBhQ0mOV8EXLlw40zpnMp7J6+riyjO3z5w5owYNGuihhx5yqQ85MXjwYE2aNEnz5s1TeHi4aT0rfC/c7Yezz9JZmVX21du/A3f48jeTIbffZ0/66e3ft7PfdmpqqiIiIvTNN98oPj5en3zyiV5++WV98skn7u2oizJuFb9+/XqH9W+88UamK5xd/S66qkuXLjp79qzef/99h/WvvfZaprq9e/fWmTNnTK/cPnv2rP3/58Z3avv27Spbtqxq1qzptTYBAAAAuI4rtwEAAAA/ExgYqEcffVRr1qzx2lV0N3vsscc0duxYPfDAA3r00UeVkJCg5cuXKzg42Kuv48zw4cP1n//8Ry+//LJ++OEHdejQQWFhYdq3b58OHTpkn4wZNWqUVq5cqeeee05xcXFq2LChdu/erYULF6pmzZoaNWqUJOnKlSsqV66cOnfurIYNG6pMmTI6duyY5s6dq+LFi6tTp07217777rsVEBCgV155RZcvX1ahQoV0xx13qFmzZln21dvP5P3oo48UExOjwMBA/fbbb9q3b5/+9a9/6eOPP/bJZ1CpUiVNnjw523pW+F642w9nn6WzMl/va2RkpE6cOJHtpKC3fwfu8OVvJkNuf6c86ac3f9/Z/bYLFSqkqVOn2us3aNBAnTt31tatW/X444+7/DpxcXH64IMPsizr2rWrfZK/ffv2qlmzpiZOnKiLFy/qjjvu0NatW7V9+3aVKlXKYTtXv4uuGjVqlJYvX66BAwdq165duuuuu7Rp0yZt27Yty9eOiYnRyJEjtWHDBv39739XeHi4Tp48qdjYWIWFhWnjxo2SfP+dunr1qrZs2aL+/ft7pT0AAAAA7mNyGwAAAPBDQ4YM0ZIlS7R27Vp169bNq22PHDlShmFo4cKFGj58uCIiIvTEE0+oX79+ql27tldfy0xISIi++eYbzZw5U8uXL9e4ceMUFham6tWrq1+/fvZ6RYsW1XfffadJkybp888/1+LFi1W2bFkNHjxYU6ZMUZEiRST99bzfESNGKDY2VuvXr9fVq1ftE3djx45V+fLl7W1WqlRJixYt0muvvaYhQ4YoJSVFffr0yXaizhuSkpK0cuVKrVixQrGxsWratKkqV66sxYsX5/ok8q2s8L1wtx/OPktnZb7e16tXrzp858x4+3fgjtz4zeT2d8rfftspKSnasmWLXnzxRbde66OPPtJHH32UZdnhw4fttxsPDAzU559/rmHDhuntt99WSEiIOnTooM2bN6tFixYO27n6XXRV8eLFtWXLFkVHR9uv3m7durU2btyodu3aOdQNDg7WF198oXfffVfLli3TpEmTJEnly5dX06ZN1adPH3tdX3+n/v3vf+vatWt69tlnc9wWAAAAAM/YDE/uHwUAAAAgz3Xs2FGJiYnasmVLXncFeWTNmjUaMWKEjh8/ntddgYt++ukn1a9fX4sWLfJoUhD5x7PPPqu4uDh99913CgkJyevuQFKjRo0UGRmpVatW5XVXAAAAgHyLZ24DAAAAfmrmzJnatm2bvvnmm7zuCnJZWlqabty4oZSUFBmGoRs3bigpKSmvuwUXfP3116pfv77D1abAraKjo7Vt2zZ99dVXTGxbxJo1a7R3794snwsOAAAAIPdw5TYAAAAA+JklS5Zkuuq3cuXKXMEN3AYybge/YcMGlS5dOq+7AwAAAACWwuQ2AAAAAACABQwbNkwbNmzQxo0bmdgGAAAAgCwwuQ0AAAAAAJDHTpw4ocjISIWGhiooKMi+vlWrVvrqq6/ysGcAAAAAYB1MbgMAAAAAAAAAAAAALC8grzsAAAAAAAAAAAAAAEB2mNwGAAAAAAAAAAAAAFgek9sAAAAAAAAAAAAAAMtjchsAAAAAAAAAAAAAYHlMbgMAAAAAAAAAAAAALI/JbQAAAAAAAAAAAACA5TG5DQAAAAAAAAAAAACwPCa3AQAAAAAAAAAAAACWx+Q2AAAAAAAAAAAAAMDymNwGAAAAAAAAAAAAAFgek9sAAAAAAAAAAAAAAMtjchsAAAAAAAAAAAAAYHlMbgMAAAAAAAAAAAAALI/JbQAAAAAAAAAAAACA5TG5DQAAAAAAAAAAAACwPCa3AQAAAAAAAAAAAACWx+Q2AAAAAAAAAAAAAMDymNwGAAAAAAAAAAAAAFgek9sAAAAAAAAAAAAAAMtjchsAAAAAAAAAAAAAYHlMbgMAAAAAAAAAAAAALI/JbQAAAAAAAAAAAACA5TG5Db/Tpk0bjRgxwqvteKtNV/z222967LHHNGHChFx5PavLzfceAAAAAAAAAJA9xrH/izFswFqY3Ial9O3bVzabzb6ULFlSHTt21E8//WSvs2rVKk2dOtWrr+uLNs288MILql69ulauXJkrr+dt3g5yd997/kMCAAAAAAAAQH6XMZY+ffp0h/Vr1qyRzWbLcfuMY/+XJ/MHjGMDvsPkNiynY8eOOn36tE6fPq3Y2FgFBQXp4YcftpeXKFFCRYoU8epr+qLNrMTHx2vTpk1q2bKlypcv7/PXy0vJycku1cut9x4AAAAAAAAAbidhYWF67bXXdPnyZa+2m1/GsRnDBvwTk9uwnNDQUEVERCgiIkINGjTQmDFjdOrUKZ0/f15S5jOe2rRpo2HDhmnUqFEqUaKEIiIiNHnyZIc2ExMT1bt3bxUuXFjlypXTzJkzHco9afPKlSvq2bOnChUqpHLlyumNN97I9mysDRs2qFWrVtq8ebNatGjhyduj9PR0zZgxQ9WqVVNoaKgqVaqkV155xV6elJSkYcOGqUyZMgoLC1PLli31ww8/uLVvn376qerWrasCBQqoZMmSat++vRITE9W3b19t3rxZb775pv3q+uPHj9vbHTp0qEaMGKFSpUopKipK69atU8uWLVWsWDGVLFlSDz/8sI4ePerwWu68985eHwAAAAAAAADyk/bt2ysiIkLTpk0zrZPdeHFWcjqOndMxbCn7cWyzMWzJfBw5qzFsSdmOY7s7f8A4NuBbTG7D0q5evaoPPvhA1apVU8mSJU3rLV26VIUKFdKOHTs0Y8YMvfTSS4qJibGXjxw5Ups3b9Znn32mb775Rps2bVJcXJzT186uzejoaH333Xf6/PPPFRMToy1btmTb5pYtW3TPPfdo7dq1euSRR1x8FxyNHTtW06dP14QJE7R//34tX75cZcuWtZePGjVK//73v7V06VLFxcWpWrVqioqK0qVLl1zat9OnT6tHjx7q37+/Dhw4oE2bNunRRx+VYRh688031bx5cw0cONB+dX3FihUd2g0JCdF3332nefPmKTExUdHR0frxxx8VGxurgIAAPfLII0pPT3e6j2b9y+71AQAAAAAAACC/CAwM1Kuvvqq3335bv/32W5Z1XBkvvlVOx7G9MYYtmY8TOxvDluR0HPnWMWxJHo1jOxtjZxwb8DEDsJA+ffoYgYGBRqFChYxChQoZkoxy5coZu3btstdp3bq1MXz4cId/t2zZ0qGdu+++2xg9erRhGIZx5coVIyQkxPjkk0/s5RcvXjQKFChgb8fdNhMSEozg4GBj5cqV9vI///zTKFiwoEM7t2rVqpUxZswYo2HDhvZ1J0+eNFq3bm3UqlXLqFu3rkM/b5WQkGCEhoYaCxYsyLL86tWrRnBwsPHhhx/a1yUnJxvly5c3ZsyY4dK+7dq1y5BkHD9+PMvXuPW9unn9zfuVlfPnzxuSjJ9//tm0vez6Z/b6AAAAAAAAAJBf9OnTx+jSpYthGIZxzz33GP379zcMwzBWr15tZEz9uDJenJWcjGN7YwzbMJyPE2c3hp2x/a3jyK6MYRtG5nFsd8ewzV4fgHdw5TYsp23bttqzZ4/27NmjnTt3KioqSg888IBOnDhhuk29evUc/l2uXDmdO3dOknT06FElJyerWbNm9vISJUqoZs2aTvvhrM1ff/1VKSkpatq0qb28aNGi2bZ5/PhxrVmzRmPGjLGvCwoK0uzZs7V//3598803GjFihP32Kbc6cOCAkpKS1K5duyzLjx49qpSUFIdbxQQHB6tp06Y6cOCAS/tWv359tWvXTnXr1lX37t21YMECl5/Z0rhxY4d/Hz58WD169FCVKlUUHh6uyMhISdLJkyedtuOsfwAAAAAAAACA/3rttde0dOlShzFgyfXx4lvlZBzbW2PYkvk4sTfHsCXPxrEZwwbyDpPbsJxChQqpWrVqqlatmu6++2699957SkxM1IIFC0y3CQ4Odvi3zWbL9tbX2fFFm2fPnlVISIgee+wx+7py5cqpQYMGkqSIiAiVKlXK9JYwBQoUyNHrZ3C2b4GBgYqJidFXX32l2rVr6+2331bNmjV17NixbNstVKiQw787deqkS5cuacGCBdqxY4d27NghSUpOTva4fwAAAAAAAACA/7rvvvsUFRWlsWPHeqW9nIxje2sMWzIfJ/bmGLbk2Tg2Y9hA3mFyG5Zns9kUEBCg69eve7R91apVFRwcbA8kSbp8+bJ++eUXj/tUpUoVBQcH64cffrCvi4+Pz7bN4OBgzZw5UwEBWf/0du3apbS0NNPnb1SvXl0FChRQbGxsluVVq1a1Py8kQ0pKin744QfVrl07u92ys9lsatGihaZMmaLdu3crJCREq1evliSFhIQoLS0t2zYuXryoQ4cOafz48WrXrp1q1arl8tlzzrj6+gAAAAAAAACQX0yfPl3/+c9/tG3bNvs6T8eLczKObYUxbIlxbOB2FpTXHQBulZSUpDNnzkj6axJ6zpw5unr1qjp16uRRe4ULF9aAAQM0cuRIlSxZUmXKlNH/+3//zzSYXVGkSBH16dNHI0eOVIkSJVSmTBlNmjRJAQEBstlsWW6zdOlSJSYmKjQ0VNu3b9eFCxf08MMP28svXbqk3r17O71CPSwsTKNHj9aoUaMUEhKiFi1a6Pz589q3b58GDBigQoUKaciQIfZ+VapUSTNmzNC1a9c0YMAAl/Ztx44dio2NVYcOHVSmTBnt2LFD58+fV61atSRJkZGR2rFjh44fP67ChQurRIkSWb6XxYsXV8mSJTV//nyVK1dOJ0+edLiNjadcfX0AAAAAAAAAyC/q1q2rnj176q233rKv82S8OKfj2FYYw5ayHkfOCuPYgP9hchuWs27dOpUrV07SX5PId955p1auXKk2bdp43Obrr79unyAvUqSI/vGPfyg+Pj5H/Zw1a5YGDx6shx9+WOHh4Ro1apROnTqlsLCwTHVv3LihVatWaenSpRo0aJAiIyP1/vvv28uTkpLUtWtXjRkzRvfee6/T150wYYKCgoI0ceJE/fHHHypXrpwGDx5sL58+fbrS09PVq1cvXblyRU2aNNHXX3+t4sWLu7Rf4eHh+vbbbzV79mwlJCSocuXKmjlzph544AFJ0osvvqg+ffqodu3aun79uo4dO2Z/BsnNAgICtGLFCg0bNkx16tRRzZo19dZbb+Xoc3Tn9QEAAAAAAAAgP3nppZf08ccfO6xzZ7zYW+PYeT2GLWU9jpwVxrEB/2MzDMPI604At4PExERVqFBBM2fOdPkMM0kyDENPPfWUatasqcmTJ/uugwAAAAAAAAAAeIBxbABWweQ24KHdu3fr4MGDatq0qeLj4/XSSy9p06ZNOnLkiEqVKuVyO1u3btV9992nevXq2dctW7ZMdevW9UW3AQAAAAAAAABwC+PYAKyCyW3AQ7t379YzzzyjQ4cOKSQkRI0bN9asWbMIcwAAAAAAAAAAAMAHmNwGAAAAAAAAAAAAAFheQF53AAAAAAAAAAAAAACA7DC5DQAAAAAAAAAAAACwPCa3AQAAAAAAAAAAAACWx+Q2AAAAAAAAAAAAAMDymNwGAAAAAAAAAAAAAFgek9u4bVy8eFFlypTR8ePHPdq+TZs2GjFihM+3cZerr/Hbb7/pscce04QJE3zaH2958sknNXPmzLzuBgAAAADARbcen7pyvOpKHX87nvW13BhrAAD4L8bB/eu/GxgHB7yPyW3cNl555RV16dJFkZGRkqS+ffvKZrNp8ODBmeo+99xzstls6tu3r33dqlWrNHXqVLde05NtfOWFF15Q9erVtXLlyrzuSiZZ/YfJ+PHj9corryg+Pj5vOgUAwE387eDYlxhQB4D8xZ1jZ18dA1v5eNYV3s5OT95n8hsA8g/Gwa373w2MgwO5g8lt3BauXbumhQsXasCAAQ7rK1asqBUrVuj69ev2dTdu3NDy5ctVqVIlh7olSpRQkSJF3HpdT7bxhfj4eG3atEktW7ZU+fLl87o7LqlTp46qVq2qDz74IK+7AgDw0Pnz5zVkyBBVqlRJoaGhioiIUFRUlL777jtJ/jXIauWDY1d4871mQB0A8h9Xj519cQzsj8eznkpOTnapnlXGGgAA1sM4uP/9dwPj4ID3MbmN28KXX36p0NBQ3XPPPQ7rGzVqpIoVK2rVqlX2datWrVKlSpXUsGFDh7pZ3V5t2LBhGjVqlEqUKKGIiAhNnjzZdJs2bdro+eef14gRI1S8eHGVLVtWCxYsUGJiovr166ciRYqoWrVq+uqrrxzaWLdunVq2bKlixYqpZMmSevjhh3X06FG39n/Dhg1q1aqVNm/erBYtWri1rSSlp6drxowZqlatmkJDQ1WpUiW98sor9vKkpCQNGzZMZcqUUVhYmFq2bKkffvjBoY1PP/1UdevWVYECBVSyZEm1b99eiYmJ6tu3rzZv3qw333xTNptNNpvNfsucTp06acWKFW73FwBgDd26ddPu3bu1dOlS/fLLL/r888/Vpk0bXbx40a12XB3o9RV/PDj2BAPqAAAzrh47Z3cyU2Jionr37q3ChQurXLlyLt2CM6fHs1LOj2ldOf735Ji3TZs2Gjp0qEaMGKFSpUopKirKpTEAd8cnnPUBAHB7YRyccXAATG7jNrFlyxY1btw4y7L+/ftr8eLF9n8vWrRI/fr1c6ndpUuXqlChQtqxY4dmzJihl156STExMU7rlypVSjt37tTzzz+vIUOGqHv37rr33nsVFxenDh06qFevXrp27Zp9m8TEREVHR+vHH39UbGysAgIC9Mgjjyg9Pd3Fvf9r/++55x6tXbtWjzzyiMvbZRg7dqymT5+uCRMmaP/+/Vq+fLnKli1rLx81apT+/e9/a+nSpYqLi1O1atUUFRWlS5cuSZJOnz6tHj16qH///jpw4IA2bdqkRx99VIZh6M0331Tz5s01cOBAnT59WqdPn1bFihUlSU2bNtXOnTuVlJTkdp8BAHnrzz//1JYtW/Taa6+pbdu2qly5spo2baqxY8eqc+fObg/0unIA6coB95UrV9SzZ08VKlRI5cqV0xtvvJHtQLyvB9W9tW/uHkBn9T5L2Q8oMKAOAPlTTo6dM4wcOVKbN2/WZ599pm+++UabNm1SXFyc021yejwr5fyYVnJ+/O/pMW9GuyEhIfruu+80b948j8cAnPUvuz4AAG4fjIMzDg5AkgHcBrp06WL079/fYV2fPn2MLl26GOfOnTNCQ0ON48ePG8ePHzfCwsKM8+fPG126dDH69Oljr9+6dWtj+PDhDv9u2bKlQ5t33323MXr06Cy3ubV+amqqUahQIaNXr172dadPnzYkGdu2bTPdl/PnzxuSjJ9//jnLfmWlVatWxpgxY4yGDRsahmEYJ0+eNFq3bm3UqlXLqFu3rvHJJ5+YbpuQkGCEhoYaCxYsyLL86tWrRnBwsPHhhx/a1yUnJxvly5c3ZsyYYRiGYezatcuQZBw/fjzLNsz24X//93+dbgcAsK6UlBSjcOHCxogRI4wbN25kKv/zzz+N5s2bGwMHDjROnz5tnD592khNTTUM469cKFy4sDFy5Ejj4MGDxsGDB41hw4YZ5cuXN7788ktj3759Rp8+fYzixYsbFy9etLfZunVrIzw83Jg8ebLxyy+/GEuXLjVsNpvxzTff2Os888wzRuXKlY3169cbP//8s/HII48YRYoUcZqlL7zwgvHaa68ZtWrVMnbt2uXR+zFq1CijePHixpIlS4wjR44YW7ZssWerN/btjz/+MIKCgoxZs2YZx44dM3766SfjnXfeMa5cuWL6Xmf1PhuGYXz66afGv//9b+Pw4cPG7t27jU6dOhl169Y10tLS7H259b+JnPXN2WcNALA+d46ds8qIjH9fuXLFCAkJcTj+vHjxolGgQAGnOXzr8axh5P4xbXbH/54e87Zu3dphv7Jy6xhAVu25Oz4BALh9MQ6e+b8bunbtahQrVszo1q2b020ZBwduH0F5NakOeNP169cVFhaWZVnp0qX10EMPacmSJTIMQw899JBKlSrlUrv16tVz+He5cuV07tw5l+oHBgaqZMmSqlu3rn1dxllgN7dx+PBhTZw4UTt27NCFCxfsZ6qdPHlSderUcamfx48f1/nz5zVlyhRJUlBQkGbPnq0GDRrozJkzaty4sR588EEVKlQo07YHDhxQUlKS2rVrl2XbR48eVUpKisOVbMHBwWratKkOHDggSapfv77atWununXrKioqSh06dNBjjz2m4sWLO+13gQIFJMnhDD4AgH8ICgrSkiVLNHDgQM2bN0+NGjVS69at9eSTT6pevXoqWrSoQkJCVLBgQUVERGTavnr16poxY4akv87enjt3rpYsWaIHHnhAkrRgwQLFxMRo4cKFGjlypH27evXqadKkSfY25syZo9jYWN1///26cuWKli5dquXLl9tzbfHixdneavzHH39UixYtFBYWpkaNGkmSTp06pV69euncuXMKCgrShAkT1L179yy3v3Llit58803NmTNHffr0kSRVrVpVLVu29Nq+nT59WqmpqXr00UdVuXJlSXL4bwyz9/rm9zlDt27dHP69aNEilS5dWvv37zf9bw9nfcvuswYA+IecHDtLfx07Jicnq1mzZvZ1JUqUUM2aNZ1ud+vxrJT7x7SS8+N/T495JWW6us7TMQB3xycAALcnxsEz/3fD8OHD1b9/fy1dutTptoyDA7cPbkuO20KpUqV0+fJl0/L+/ftryZIlWrp0qfr37+9yu8HBwQ7/ttlsTm+TklX9m9fZbDZJcmijU6dOunTpkhYsWKAdO3Zox44dktx7/ujZs2cVEhKixx57TNJf//HRoEEDSVJERIRKlSrlcLu1m2UEa04EBgYqJiZGX331lWrXrq23335bNWvW1LFjx5xul9Gn0qVL57gPAIDc161bN/3xxx/6/PPP1bFjR23atEmNGjXSkiVLst325oFeVwedJecH3L/++qtSUlLUtGlTe3nRokVdGlRfs2aNxowZY1+XMai+f/9+ffPNNxoxYoQSExOz3N7ZAbK39u3mA+ju3btrwYIFTv/bJ0NWt6s7fPiwevTooSpVqig8PFyRkZGS/hpQMMOAOgDkD54eO+fErcezUu4f00rOj/89PeaVlGlC3tMxAHfHJwAAtyfGwTP/d0ObNm1UpEiRbLdlHBy4fTC5jdtCw4YNtX//ftPyjh07Kjk5WSkpKfbnTVrBxYsXdejQIY0fP17t2rVTrVq1XBqovlVwcLBmzpypgIDMP+ldu3YpLS3N9Hlb1atXV4ECBRQbG5tledWqVe3PB8uQkpKiH374QbVr17avs9lsatGihaZMmaLdu3crJCREq1evlvTX1WRpaWmZ2t67d6/+9re/uXU1AADAWsLCwnT//fdrwoQJ+v7779W3b1/7Fb7OZHXllSt8MbB7Ow+qZ/U+ezKgwIA6AOQPOTl2rlq1qoKDg+25IkmXL1/WL7/84nQ7Z8ezUu4d02bHk2PeW3lrDCArrvYBAODfGAd3/t8NzjAODtw+mNzGbSEqKkr79u0zDcTAwEAdOHBA+/fvV2BgYC73zlzx4sVVsmRJzZ8/X0eOHNGGDRsUHR3tVhtLly5VYmKiQkNDtX37dq1du9ZedunSJfXu3Vvz58833T4sLEyjR4/WqFGj9P777+vo0aPavn27Fi5cKOmvQfEhQ4Zo5MiRWrdunfbv36+BAwfq2rVrGjBggCRpx44devXVV/Xjjz/q5MmTWrVqlc6fP69atWpJkiIjI7Vjxw4dP37c4ZYzW7ZsUYcOHdzaXwCAtdWuXdt+hbOrg6zeGnSuUqWKgoOD9cMPP9jXxcfH5+mgurf2TbL2oDoD6gBwe8jJsXPhwoU1YMAAjRw5Uhs2bNDevXvVt29fp4PPzo5npdw7ps2Op8e8t/LGGIAZV/sAAPBvjIOb/3dDdhgHB24fPHMbt4W6deuqUaNG+uSTT/Tss89mWSc8PDyXe5W9gIAArVixQsOGDVOdOnVUs2ZNvfXWW2rTpo1L29+4cUOrVq3S0qVLNWjQIEVGRur999+XJCUlJalr164aM2aM7r33XqftTJgwQUFBQZo4caL++OMPlStXToMHD7aXT58+Xenp6erVq5euXLmiJk2a6Ouvv7Y/SyQ8PFzffvutZs+erYSEBFWuXFkzZ860P1v0xRdfVJ8+fVS7dm1dv35dx44dU0REhNasWaN169Z58M4BAPLaxYsX1b17d/Xv31/16tVTkSJF9OOPP2rGjBnq0qWLJMeDusKFC6tEiRJZDnDffABZokQJVapUSTNmzHBr0FmSihQpoj59+tjbKVOmjCZNmqSAgAD7LdFudevB8YULF/Twww/byzMG1RcsWGD6ujcfIIeEhKhFixY6f/689u3bpwEDBnhl33bs2KHY2Fh16NBBZcqU0Y4dO0wPoDPe66zcPKBQrlw5nTx50uF27J5y9bMGAFhfTo6dX3/9dV29elWdOnVSkSJF9I9//EPx8fFZ1nV2PCvl7jFtdjw55s147MfNcjoG4IyrfQAA+DfGwbP+7wZXMQ4O3CYM4Daxdu1ao1atWkZaWlpedyXPpaenG08++aQxadKkvO6KqXfffde4//7787obAAAP3bhxwxgzZozRqFEjo2jRokbBggWNmjVrGuPHjzeuXbtmGIZhHDp0yLjnnnuMAgUKGJKMY8eOGYZhGK1btzaGDx/u0N7169eN559/3ihVqpQRGhpqtGjRwti5c6dDnay269Kli9GnTx/7vxMSEoynnnrKKFiwoBEREWHMmjXLaNq0qTFmzJhM+3D9+nWjc+fOxtKlS40777zT6Nixo3Hu3DmHfWzVqpXx/vvvZ/t+pKWlGS+//LJRuXJlIzg42KhUqZLx6quvem3f9u/fb0RFRRmlS5c2QkNDjRo1ahhvv/22vW5W73VWbRqGYcTExBi1atUyQkNDjXr16hmbNm0yJBmrV6/Osi+uvO9mnzUAAJ7wh2NaAADyCuPgmW3cuNHo1q1bXncjS4yDA95nMwzDyNvpdcB7Zs+erW7dupneNjS/2Lp1q+677z7Vq1fPvm7ZsmWqW7duHvbK0XvvvadWrVqpZs2aed0VAMBtLDExURUqVNDMmTPdulLaMAw99dRTqlmzpiZPnuy7DgIAgEz84ZgWAIC8xDj4f7Vv317/+7//q8TERJUoUUIrV65U8+bN87pbdoyDA97H5DYAAABuG7t379bBgwfVtGlTxcfH66WXXtKmTZt05MgRlSpVyuV2GFQHAAAAAAAArIdnbgMAAOC28s9//lOHDh1SSEiIGjdurC1btrg1sS1JLVu2VHp6uo96CAAAAAAAAMATXLkNAAAAAAAAAAAAALC8gLx88W+//VadOnVS+fLlZbPZtGbNmmy32bRpkxo1aqTQ0FBVq1ZNS5Ys8Xk/AQDIr8hqAACsj7wGAMDayGoAALwnTye3ExMTVb9+fb3zzjsu1T927JgeeughtW3bVnv27NGIESP0zDPP6Ouvv/ZxTwEAyJ/IagAArI+8BgDA2shqAAC8xzK3JbfZbFq9erW6du1qWmf06NH64osvtHfvXvu6J598Un/++afWrVuXC70EACD/IqsBALA+8hoAAGsjqwEAyJmgvO6AO7Zt26b27ds7rIuKitKIESNMt0lKSlJSUpL93+np6bp06ZJKliwpm83mq64CQL5lGIauXLmi8uXLKyAgT28QgjzgSVZL5DUA5DbyOn/j2BoArI+szt/IagCwPrI67/jV5PaZM2dUtmxZh3Vly5ZVQkKCrl+/rgIFCmTaZtq0aZoyZUpudREA8H9OnTqlv/3tb3ndDeQyT7JaIq8BIK+Q1/kTx9YA4D/I6vyJrAYA/0FW5z6/mtz2xNixYxUdHW3/d3x8vCpVqqQjx06pSHh4HvYMAG5PVxISVO2OiipSpEhedwV+hLwGgNxFXsNdZDUA5C6yGu4iqwEgd5HVecevJrcjIiJ09uxZh3Vnz55VeHi46ZVgoaGhCg0NzbS+SHi4wgl1APAZbnmVP3mS1RJ5DQB5hbzOnzi2BgD/QVbnT2Q1APgPsjr3+dVN4Js3b67Y2FiHdTExMWrevHke9QgAANyMrAYAwPrIawAArI2sBgDAXJ5Obl+9elV79uzRnj17JEnHjh3Tnj17dPLkSUl/3Uqld+/e9vqDBw/Wr7/+qlGjRungwYN699139cknn+iFF17Ii+4DAHDbI6sBALA+8hoAAGsjqwEA8J48ndz+8ccf1bBhQzVs2FCSFB0drYYNG2rixImSpNOnT9sDXpLuuOMOffHFF4qJiVH9+vU1c+ZMvffee4qKisqT/gMAcLsjqwEAsD7yGgAAayOrAQDwHpthGEZedyI3JSQkqGjRojp7MZ5njQCADyQkJKhsyaKKj+fvLDxHXgOAb5HXyCmyGgB8i6xGTpHVAOBbZHXe8atnbgMAAAAAAAAAAAAA8icmtwEAAAAAAAAAAAAAlsfkNgAAAAAAAAAAAADA8pjcBgAAAAAAAAAAAABYHpPbAAAAAAAAAAAAAADLY3IbAAAAAAAAAAAAAGB5TG4DAAAAAAAAAAAAACyPyW0AAAAAAAAAAAAAgOUxuQ0AAAAAAAAAAAAAsDwmtwEAAAAAAAAAAAAAlsfkNgAAAAAAAAAAAADA8pjcBgAAAAAAAAAAAABYHpPbAAAAAAAAAAAAAADLY3IbAAAAAAAAAAAAAGB5TG4DAAAAAAAAAAAAACyPyW0AAAAAAAAAAAAAgOUxuQ0AAAAAAAAAAAAAsDwmtwEAAAAAAAAAAAAAlsfkNgAAAAAAAAAAAADA8pjcBgAAAAAAAAAAAABYXlBedwAAbgeGYbi9jc1m80FPAACAGfIaAAAAAADAv3HlNgAAAAAAAAAAAADA8pjcBgAAAAAAAAAAAABYHpPbAAAAAAAAAAAAAADLY3IbAAAAAAAAAAAAAGB5TG4DAAAAAAAAAAAAACyPyW0AAAAAAAAAAAAAgOUF5XUHAMAXDMPwcDsvd8Tpa3n2YjabszInhQAAWAx5DQCAtflDVnuKrAYA+AtP89h5m15v0uvIapjhym0AAAAAAAAAAAAAgOUxuQ0AAAAAAAAAAAAAsDwmtwEAAAAAAAAAAAAAlsfkNgAAAAAAAAAAAADA8pjcBgAAAAAAAAAAAABYHpPbAAAAAAAAAAAAAADLC8rrDgCApwzDMC1LNy9yup3z1/NoM1M2m2fbBQVwXhIAwH+kOwllZ9FKXgMAkDv8Jas9zWQzZDUAwF84y05fZLWnfSGrkVuY3AYA+I0bN24oOTnZrW1CQkIUFhbmox4BAICbkdUAAFgfeQ0AgLWR1c4xuQ0A8As3btxQgSIlpdRrbm0XERGhY8eO5ZtgBwAgr5DVAABYH3kNAIC1kdXZY3IbAOAXkpOTpdRrCr2rnxQY4tpGack6s2+xkpOT80WoAwCQl8hqAACsj7wGAMDayOrsMbkNAPAvQSGyBYa6VNXw8nNeAACAC8hqAACsj7wGAMDayGpTTG4DAPyLLeCvxdW6AAAgd5HVAABYH3kNAIC1kdWmmNwGAPgXm+2vxdW6AAAgd5HVAABYH3kNAIC1kdWmmNwGYGmGYZiWpaWblzkpUrrT7czLnDTplFms2JwEToGQQA9fLR/gjDUAsBxnee0sW8nr2xRZDQCWQ1YjE/IaACwlt7PaWR4764uz7chqLyOrTTG5DQDwL5yxBgCAtZHVAABYH3kNAIC1kdWmmNwGAPgZN85YU/46Yw0AAGsgqwEAsD7yGgAAayOrzTC5DQDwL5yxBgCAtZHVAABYH3kNAIC1kdWmmNwGAPgXnjUCAIC1kdUAAFgfeQ0AgLWR1aaY3AYA+BfOWAMAwNrIagAArI+8BgDA2shqU0xuAwD8C2esAQBgbWQ1AADWR14DAGBtZLWp/LW3APxOumG+pKYbpktyarrpciMlzXS5nmy+XEvybDFrr0BIoOkCJzLOWHN1AQD4nNO8TjNMF/L6NkVWA4Dl3A5ZfS0564Ws9pCP8/qdd95RZGSkwsLC1KxZM+3cudNp/ZUrV+rOO+9UWFiY6tatqy+//NKhfPLkybrzzjtVqFAhFS9eXO3bt9eOHTsc6ly6dEk9e/ZUeHi4ihUrpgEDBujq1atu9x0A8kJuZ7VZrma7kNW5JxeOrf01r5ncBgD4l4wz1lxdAABA7sqFrPbXA3AAACzDh3n98ccfKzo6WpMmTVJcXJzq16+vqKgonTt3Lsv633//vXr06KEBAwZo9+7d6tq1q7p27aq9e/fa69SoUUNz5szRzz//rK1btyoyMlIdOnTQ+fPn7XV69uypffv2KSYmRmvXrtW3336rQYMGefb+AACQ13x8bO3Pec2oPwDAv9hsboQ6V4MBAJDrfJzV/nwADgCAZfgwr2fNmqWBAweqX79+ql27tubNm6eCBQtq0aJFWdZ/88031bFjR40cOVK1atXS1KlT1ahRI82ZM8de56mnnlL79u1VpUoV3XXXXZo1a5YSEhL0008/SZIOHDigdevW6b333lOzZs3UsmVLvf3221qxYoX++OMPz98nAADyio+Prf05r5ncBgD4lwCbe4sHuBoMAIAc8HFW+/MBOAAAluFBXickJDgsSUlJmZpNTk7Wrl271L59+/++VECA2rdvr23btmXZlW3btjnUl6SoqCjT+snJyZo/f76KFi2q+vXr29soVqyYmjRpYq/Xvn17BQQEZDr+BgDAL/goqyX/z2smtwEA/oXbsQAAYG0eZHV+OQAHAMAyPMjrihUrqmjRovZl2rRpmZq9cOGC0tLSVLZsWYf1ZcuW1ZkzZ7LsypkzZ1yqv3btWhUuXFhhYWF64403FBMTo1KlStnbKFOmjEP9oKAglShRwvR1AQCwNB9lteT/ec3kNgDAv9hs7i1u4mowAAByyIOszi8H4AAAWIYHeX3q1CnFx8fbl7Fjx+Zql9u2bas9e/bo+++/V8eOHfX444+bnogOAIDf88OslnInr4O82hoAeFlaumFalpSSblp2IyXNSZn5dqlp5mVOuiJnU6jVIgo7KYXb3Lki+6arwW4WGhqq0NDQTNUzrga7OfRduRosOjraYV1UVJTWrFmTZX1PrgZ75JFHst9XAMhDTvM61bO8dpbzKU7y2llfbE5OeqpOXnuPB1l96tQphYeH21dnldO+lnEAfuHCBS1YsECPP/64duzYkWlSGwD8kb9ktTM1yhXxaDuY8CCvw8PDHfI6K6VKlVJgYKDOnj3rsP7s2bOKiIjIcpuIiAiX6hcqVEjVqlVTtWrVdM8996h69epauHChxo4dq4iIiEwD56mpqbp06ZLp6wKAlfgiq5OdbOeszFlUG4Z5IVntZT7Kasn/85ortwEA/oWrwQAAsDYPsjrjADxjMZvczo0D8HvuuUcLFy5UUFCQFi5caG+DAXMAwG3FR3dFCwkJUePGjRUbG2tfl56ertjYWDVv3jzLbZo3b+5QX5JiYmJM69/cbsajTJo3b64///xTu3btspdv2LBB6enpatasmcv9BwDAMnx4B1N/z2smtwEA/sWDZ43kl9uxAABgCR5ktav8/QAcAADL8GFeR0dHa8GCBVq6dKkOHDigIUOGKDExUf369ZMk9e7d2+G4fPjw4Vq3bp1mzpypgwcPavLkyfrxxx81dOhQSVJiYqLGjRun7du368SJE9q1a5f69++v33//Xd27d5ck1apVSx07dtTAgQO1c+dOfffddxo6dKiefPJJlS9f3ktvGgAAuciHWS35d15zW3IAgH9x50y0W64Gy46/344FAABL8CCr3REdHa0+ffqoSZMmatq0qWbPnp3pALxChQr2O7UMHz5crVu31syZM/XQQw9pxYoV+vHHHzV//nxJfx2Av/LKK+rcubPKlSunCxcu6J133jE9AJ83b55SUlIYMAcA+Dcf5vUTTzyh8+fPa+LEiTpz5owaNGigdevW2e96dvLkSQUE/HcQ/t5779Xy5cs1fvx4jRs3TtWrV9eaNWtUp04dSVJgYKAOHjyopUuX6sKFCypZsqTuvvtubdmyRXfddZe9nQ8//FBDhw5Vu3btFBAQoG7duumtt95yq+8AAFiGj4+t/Tmv8/zK7XfeeUeRkZEKCwtTs2bNtHPnTqf1Z8+erZo1a6pAgQKqWLGiXnjhBd24cSOXegsAyHNcDZYnyGsAgMt8fHb5E088oX/+85+aOHGiGjRooD179mQ6AD99+rS9fsYB+Pz581W/fn19+umnWR6Ad+vWTTVq1FCnTp108eLFLA/A77zzTrVr104PPvigWrZsaZ8gtwKyGgDgFh/n9dChQ3XixAklJSVpx44dDse2mzZt0pIlSxzqd+/eXYcOHVJSUpL27t2rBx980F4WFhamVatW6ffff1dSUpL++OMPffbZZ7r77rsd2ihRooSWL1+uK1euKD4+XosWLVLhwoXd7rsvkdcAAJf5OKsl/83rPL1y++OPP1Z0dLTmzZunZs2aafbs2YqKitKhQ4cyPXtUkpYvX64xY8Zo0aJFuvfee/XLL7+ob9++stlsmjVrVh7sAQAg13E1WK4jrwEAbvFxVkt/HYBn3PrsVps2bcq0rnv37vbcvVXGAXh2Mg7ArYisBgC4LRfyGo7IawCAW8hqU3l65fasWbM0cOBA9evXT7Vr19a8efNUsGBBLVq0KMv633//vVq0aKGnnnpKkZGR6tChg3r06JHtGW4AgNuJO2ercTWYN5DXAAD3+DarkRlZDQBwH3md28hrAIB7yGozeXbldnJysnbt2uXwMPKAgAC1b99e27Zty3Kbe++9Vx988IF27typpk2b6tdff9WXX36pXr16mb5OUlKS/bavkpSQkOC9nQDgcylp6aZlV2+kmpYlXE8xLbuSZL5dYop5WbphmJb9vWbmM2zhI1wNlqvIawCuSHWS14lOcjfhunmZs5y/kmKe8ynp5n25/86ypmXwIs4uz1VkNQBX5HZWJySbZ3VyepppWVStCNMyeBl5natyI6/JasC/eZrVV5xltbOMd5LVSWlktSWQ1abybHL7woULSktLs18Jl6Fs2bI6ePBglts89dRTunDhglq2bCnDMJSamqrBgwdr3Lhxpq8zbdo0TZkyxat9BwDkIZvN9WeI5LNQ9wXyGgDgNrI6V5HVAACPkNe5KjfymqwGgNsMWW3Kr65T37Rpk1599VW9++67iouL06pVq/TFF19o6tSpptuMHTtW8fHx9uXUqVO52GMAgNe5eisW+y1ZkNvIawDI58hqyyOrAQDktfW5m9dkNQDcZshqU3l25XapUqUUGBios2fPOqw/e/asIiKyvq3BhAkT1KtXLz3zzDOSpLp16yoxMVGDBg3S//t//08BAZk/vNDQUIWGhnp/BwAAeYPbseQq8hoA4DayOleR1QAAj5DXuSo38pqsBoDbDFltKs+m8kNCQtS4cWPFxsba16Wnpys2NlbNmzfPcptr165lCu3AwEBJkuHkWbgAgNsIZ6zlKvIaAOA2sjpXkdUAAI+Q17mKvAYAuI2sNpVnV25LUnR0tPr06aMmTZqoadOmmj17thITE9WvXz9JUu/evVWhQgVNmzZNktSpUyfNmjVLDRs2VLNmzXTkyBFNmDBBnTp1sgc7AOA2xxlruY68BgC4hazOdWQ1AMBt5HWuI68BAG4hq03l6eT2E088ofPnz2vixIk6c+aMGjRooHXr1qls2bKSpJMnTzqcnTZ+/HjZbDaNHz9ev//+u0qXLq1OnTrplVdeyatdAADkNnfORMtnZ6z5CnkNAHALWZ3ryGoAgNvI61xHXgMA3EJWm7IZ+eweJgkJCSpatKjOXoxXeHh4XncHgJzfSulyYopp2R+Xr5uWHY9PNC079qf5dhcTU03LJt5fw7QsICB/nRnlTEJCgsqWLKr4eO/+nc34+x368NuyBRdwaRsj5bqS1j7v9b7A98hrwHqc5fWf18zz+vdL3s/rs1fMX++lqJqmZeT1f/kir8nq/IWsBqzHF1l9IuGaadmvl83Lfo9PNi179YE7TcvI6v/i2Bo5RVYD1uMsq+OdZfXlG6ZlJ5wcV//q5Lj6tz+TTMvIateQ1XknT6/cBgDAXTabTTZuxwIAgGWR1QAAWB95DQCAtZHV5pjcBgD4FUIdAABrI6sBALA+8hoAAGsjq80xuQ0A8C+2/1tcrQsAAHIXWQ0AgPWR1wAAWBtZbYrJbQCAX+GMNQAArI2sBgDA+shrAACsjaw2x+Q2AMCvEOoAAFgbWQ0AgPWR1wAAWBtZbY7JbQCAXyHUAQCwNrIaAADrI68BALA2stock9sALC0xKdW07OClK6ZlG49cNi375bd407J1z7cwLQsIyF8BAQCAqxKT0kzLfrlsntexh83z+sAJ87KYEa1My8hrAAAyu+Ykqw87yer1TrL656MXTcs2vtjatIysBgAgs2vJ3s/q/z18wbRs00iyGv6LyW0AgF/hjDUAAKyNrAYAwPrIawAArI2sNsfkNgDAv9j+b3G1LgAAyF1kNQAA1kdeAwBgbWS1KSa3AQB+hTPWAACwNrIaAADrI68BALA2stock9sAAL9is8mNUPdtXwAAQGZkNQAA1kdeAwBgbWS1OSa3AQB+xSY3zljLb6kOAIAFkNUAAFgfeQ0AgLWR1eaY3AYA+BVuxwIAgLWR1QAAWB95DQCAtZHV5pjcBpDn0g3zssuJKaZln/10zrRs4+ZfTMuO/+tx07LAgPwVAn7JJtdPROPjBACvcZbXfyYmm5atcZLXMbEHTMtOvdfDtIy8tjiyGgDyhNOsvmZ+bL3aSVZ//c0+07LfF/c0LSOr/QB5DQC5zllWxzvJ6s9+Ns/qL7762bTsjyVPm5aR1X6ArDbF5DYAwL+4ccaakc/OWAMAwBLIagAArI+8BgDA2shqU0xuAwD8iju3Y3H9mSQAAMBbyGoAAKyPvAYAwNrIanNMbgMA/AqhDgCAtZHVAABYH3kNAIC1kdXmmNwGAPgXnjUCAIC1kdUAAFgfeQ0AgLWR1aaY3AYA+BXOWAMAwNrIagAArI+8BgDA2shqcwF53QEAANyREequLgAAIHeR1QAAWJ+v8/qdd95RZGSkwsLC1KxZM+3cudNp/ZUrV+rOO+9UWFiY6tatqy+//NJelpKSotGjR6tu3boqVKiQypcvr969e+uPP/5waCMyMjJTv6dPn+523wEAsAKOrc1x5TaAPGcYhmnZjj8umZat/dcK07Lzm80PXoICOa/Hn3HGGgDkDWd5/eOZy6Zl/5nzvmnZua2zTMvIa/9FVgNA3nCW1XHOsvrNRaZl57a9ZVpGVvs3X+b1xx9/rOjoaM2bN0/NmjXT7NmzFRUVpUOHDqlMmTKZ6n///ffq0aOHpk2bpocffljLly9X165dFRcXpzp16ujatWuKi4vThAkTVL9+fV2+fFnDhw9X586d9eOPPzq09dJLL2ngwIH2fxcpUsStvgOALznL6j3n/jQtW/PGQtMysvr2lRvH1u+8845ef/11nTlzRvXr19fbb7+tpk2bmtZfuXKlJkyYoOPHj6t69ep67bXX9OCDD0r662S08ePH68svv9Svv/6qokWLqn379po+fbrKly9vbyMyMlInTpxwaHfatGkaM2aMy/3mmw0A8CucsQYAgLXlRlZzNRgAADnjy7yeNWuWBg4cqH79+ql27dqaN2+eChYsqEWLsj6R4s0331THjh01cuRI1apVS1OnTlWjRo00Z84cSVLRokUVExOjxx9/XDVr1tQ999yjOXPmaNeuXTp58qRDW0WKFFFERIR9KVSokGdvEAAAeczXx9YZJ6NNmjRJcXFxql+/vqKionTu3Lks62ecjDZgwADt3r1bXbt2VdeuXbV3715JcjgZLS4uTqtWrdKhQ4fUuXPnTG299NJLOn36tH15/vnn3eo7k9sAAP9ic3PxAAPmAADkgI+z2p8PwAEAsAwP8johIcFhSUpKytRscnKydu3apfbt29vXBQQEqH379tq2bVuWXdm2bZtDfUmKiooyrS9J8fHxstlsKlasmMP66dOnq2TJkmrYsKFef/11paamOnsXAACwLh8fW/vzyWhMbgMA/ApnrAEAYG2+zmp/PgAHAMAqPMnrihUrqmjRovZl2rRpmdq9cOGC0tLSVLZsWYf1ZcuW1ZkzZ7Lsy5kzZ9yqf+PGDY0ePVo9evRQeHi4ff2wYcO0YsUKbdy4Uc8++6xeffVVjRo1yq33BQAAq/Akq105EU3y/5PReOY2AMCv+PpZIzcPmEvSvHnz9MUXX2jRokVZPvfj5gFzSZo6dapiYmI0Z84czZs3zz5gfrM5c+aoadOmOnnypCpVqmRfnzFgDgCAP/MkqxMSEhzWh4aGKjQ0NFP9jAPwsWPH2te5cgAeHR3tsC4qKkpr1qwx7ZezA/CpU6eqUqVKeuqpp/TCCy8oKIjDagCA//Ekr0+dOuUwmZxVVvtaSkqKHn/8cRmGoblz5zqU3Zz39erVU0hIiJ599llNmzYtT/oKAEBOeJLVFStWdFg/adIkTZ48OVN9ZyejHTx4MMvX8ObJaI0aNVKJEiX0/fffa+zYsTp9+rRmzZqV7X5m4CgcAOBXGDAHAMDaOAAHAMD6PMnr8PBwh2zMSqlSpRQYGKizZ886rD979qzpydwREREu1c+Y2D5x4oQ2bNiQbV+aNWum1NRUHT9+XDVr1nRaFwAAq/HXE9Ek35+Mxog5AMC/uPMMkf+rx4A5AAC5yIOszi8H4AAAWIYHee2KkJAQNW7cWLGxserataskKT09XbGxsRo6dGiW2zRv3lyxsbEaMWKEfV1MTIyaN29u/3dGRh8+fFgbN25UyZIls+3Lnj17FBAQoDJlyri+AwAAWIUHWe3KiWiS/5+MxuQ2gDyXlJpuWjbq+ZmmZae/e9O0LCgwIEd9gnVxxhoA5I2kFPO8fuF//mla5iyvg4PI69uRr64Ek/z/ABwAfMlZVg8b8rpp2R9kdb7ky0d+RUdHq0+fPmrSpImaNm2q2bNnKzEx0f74r969e6tChQr2Z3YPHz5crVu31syZM/XQQw9pxYoV+vHHHzV//nxJf2X0Y489pri4OK1du1ZpaWn2E8pLlCihkJAQbdu2TTt27FDbtm1VpEgRbdu2TS+88IKefvppFS9e3K3+A4CvOMvq/xlsPg5OVudPvsxqfz8ZjcltAIBfYcAcAABr4wAcAADr82VeP/HEEzp//rwmTpyoM2fOqEGDBlq3bp39rmcnT55UQMB/J2PuvfdeLV++XOPHj9e4ceNUvXp1rVmzRnXq1JEk/f777/r8888lSQ0aNHB4rY0bN6pNmzYKDQ3VihUrNHnyZCUlJemOO+7QCy+8kOkxYgAA+AtfZrXk3yejMbkNAPArNrkR6u7cO00MmAMA4A2+zGrJvw/AAQCwCl/n9dChQ02Pozdt2pRpXffu3dW9e/cs60dGRsowDKev16hRI23fvt3tfgIAYFW+zmp/PhmNyW0AgF/hjDUAAKzN11ntzwfgAABYha/zGgAA5ExuZLW/nozG5DYAwL/Y/m9xta6bGDAHACCHfJzVkv8egAMAYBm5kNcAACAHyGpTTG4DAPwKZ6wBAGBtXAkGAID1kdcAAFgbWW2OyW0AgF8h1AEAsDayGgAA6yOvAQCwNrLaHJPbAHJFerr5latDVv5kWvbj2ummZaHBAaZluH3ZbH8trtYFALjOaV5/ap7Xu754zbSMvM5/yGoA8B1nWT3YybG1s6wOI6vzJfIaAHzD06yO+8+rpmVkdf5EVptjchsA4Ff+CnVXz1jzcWcAAEAmZDUAANZHXgMAYG1ktTkmtwEA/sWNM9aUz0IdAABLIKsBALA+8hoAAGsjq00xuQ0A8Cs8awQAAGsjqwEAsD7yGgAAayOrzTG5DQDwKzxrBAAAayOrAQCwPvIaAABrI6vNMbkNAPArAQE2BQS4ltaGi/UAAID3kNUAAFgfeQ0AgLWR1eaY3AYA+BXOWAMAwNrIagAArI+8BgDA2shqc0xuA8gVv55LNC0b06aaaVmZ8FDTsvz2HAn8hWeNAIDvHDl71bRs3N+rm5aR17gZWQ0AvnPo9BXTsgntzbO6LFmNW5DXAOAbB/8wz+qJ99cwLeO4Grciq80xuQ0A8CucsQYAgLWR1QAAWB95DQCAtZHV5pjcBgD4Fc5YAwDA2shqAACsj7wGAMDayGpzTG4DAPwKoQ4AgLWR1QAAWB95DQCAtZHV5pjcBgD4FW7HAgCAtZHVAABYH3kNAIC1kdXm3J7cTkxM1PTp0xUbG6tz584pPT3dofzXX3/1WucAALiVTW6csaZ8luoAAFgAWQ0AgPWR1wAAWBtZbc7tye1nnnlGmzdvVq9evVSuXLl8d6k7ACBvccYaAADWRlYDAGB95DUAANZGVptze3L7q6++0hdffKEWLVr4oj8A/Ni1pFTTslJFQkzL0gzzNsOCA3PSJdyGeNYIAORM4g3zvC4dHmpaZjjJ69CggJx0CbcZshoAcubK9RTTsohiYaZlzv6ihpDVuAV5DQCeS3CS1eWKm2d1gJO/p6HBZDUckdXm3J7cLl68uEqUKOGLvgAAkC3OWAMAwNrIagAArI+8BgDA2shqc26fCjJ16lRNnDhR165d80V/AABwKuOMNVcXAACQu8hqAACsj7wGAMDayGpzbl+5PXPmTB09elRly5ZVZGSkgoODHcrj4uK81jkAAG7FGWsAAFgbWQ0AgPWR1wAAWBtZbc7tye2uXbv6oBsAALiGZ40AAGBtZDUAANZHXgMAYG1ktTm3J7cnTZrki34AAOAaN85YU/7KdAAArIGsBgDA+shrAACsjaw25fbkdoZdu3bpwIEDkqS77rpLDRs29FqnAAAwwxlrAABYG1kNAID1kdcAAFgbWW3O7cntc+fO6cknn9SmTZtUrFgxSdKff/6ptm3basWKFSpdurS3+wjAQtLTDdOykKAA07KAACd/XM2bVKCz7ZAv8awRAMhemod5HRjo2R9O8ho3I6sBIHupaemmZWHBgaZlaYZ5xjv7k0pW41bkNQA452lWG86y2skf1AD+2OIWZLU585EtE88//7yuXLmiffv26dKlS7p06ZL27t2rhIQEDRs2zBd9BADALuOMNVcXAACQu8hqAACsj7wGAMDayGpzbl+5vW7dOq1fv161atWyr6tdu7beeecddejQwaudAwDgVpyxBgCAtZHVAABYH3kNAIC1kdXm3L5yOz09XcHBwZnWBwcHKz3d/DYNZt555x1FRkYqLCxMzZo1086dO53W//PPP/Xcc8+pXLlyCg0NVY0aNfTll1+6/boAAP/EGWt5g7wGALiKrM4bZDUAwB3kdd4grwEAriKrzbl95fbf//53DR8+XB999JHKly8vSfr999/1wgsvqF27dm619fHHHys6Olrz5s1Ts2bNNHv2bEVFRenQoUMqU6ZMpvrJycm6//77VaZMGX366aeqUKGCTpw4YX/2NwDg9udOWOe3UPcV8hoA4A6yOveR1QAAd5HXuY+8BgC4g6w25/aV23PmzFFCQoIiIyNVtWpVVa1aVXfccYcSEhL09ttvu9XWrFmzNHDgQPXr10+1a9fWvHnzVLBgQS1atCjL+osWLdKlS5e0Zs0atWjRQpGRkWrdurXq16/v7m4AAPxUxu1YXF2Qc+Q1AMAdZHXuI6sBAO7ydV67e4XyypUrdeeddyosLEx169Z1uDo5JSVFo0ePVt26dVWoUCGVL19evXv31h9//OHQxqVLl9SzZ0+Fh4erWLFiGjBggK5evep+532EvAYAuCM3jq39Na/dntyuWLGi4uLi9MUXX2jEiBEaMWKEvvzyS8XFxelvf/uby+0kJydr165dat++/X87ExCg9u3ba9u2bVlu8/nnn6t58+Z67rnnVLZsWdWpU0evvvqq0tLSTF8nKSlJCQkJDgsAwH9xO5bcRV4DANyVG1ntrwfgvkBWAwA84cu8zrhCedKkSYqLi1P9+vUVFRWlc+fOZVn/+++/V48ePTRgwADt3r1bXbt2VdeuXbV3715J0rVr1xQXF6cJEyYoLi5Oq1at0qFDh9S5c2eHdnr27Kl9+/YpJiZGa9eu1bfffqtBgwZ59gZ5WW7kNVkNALcXXx9b+3Neuz25Lf31ht5///16/vnn9fzzzzuEsqsuXLigtLQ0lS1b1mF92bJldebMmSy3+fXXX/Xpp58qLS1NX375pSZMmKCZM2fq5ZdfNn2dadOmqWjRovalYsWKbvcVwH85OzMowGYzXYIDA0yXoECb6cJVPrgVZ6zlLvIa8E8BNvMlMMBmujjLa2cLeY2b+Tqr/fkA3BfIasA/OctjT7M6yMkSEGAzXZA/+TKv3b1C+c0331THjh01cuRI1apVS1OnTlWjRo00Z84cSVLRokUVExOjxx9/XDVr1tQ999yjOXPmaNeuXTp58qQk6cCBA1q3bp3ee+89NWvWTC1bttTbb7+tFStWZDr+zgu5kddkNeA+wzBMF2d5HORscZLHztokq3ErXx9b+3NeuzS5/dZbb+nGjRv2/+9s8aX09HSVKVNG8+fPV+PGjfXEE0/o//2//6d58+aZbjN27FjFx8fbl1OnTvm0jwAA3+KMNesjrwEgf/N1VvvzAbhVkNUAAE/y+targpOSkjK168kVytu2bct08VRUVJRpfUmKj4+XzWazP39627ZtKlasmJo0aWKv0759ewUEBGjHjh0uvy9W4m5ek9UAcHvxVVZL/p/XQa5UeuONN9SzZ0+FhYXpjTfeMK1ns9k0bNgwl164VKlSCgwM1NmzZx3Wnz17VhEREVluU65cOQUHByswMNC+rlatWjpz5oySk5MVEhKSaZvQ0FCFhoa61CcAgPXZ5PqZaJ6c13jzgLkkzZs3T1988YUWLVqkMWPGZKp/84C5JE2dOlUxMTGaM2eO5s2bZx8wv9mcOXPUtGlTnTx5UpUqVbIPmP/www/2YH/77bf14IMP6p///KfKly/vwZ54B3kNAHCXJ1l9620zzXIh4wB87Nix9nWuHIBHR0c7rIuKitKaNWtM++XuAfgjjzziZC99i6wGAHjCk7y+9UrgSZMmafLkyQ7rnF2hfPDgwSzbP3PmjFtXNN+4cUOjR49Wjx49FB4ebm+jTJkyDvWCgoJUokQJ03ZyU27kNVkNALcXX2W15P957dKV28eOHVPJkiXt/99s+fXXX11+4ZCQEDVu3FixsbH2denp6YqNjVXz5s2z3KZFixY6cuSI0tPT7et++eUXlStXLsuDbwDA7cfZ7e+zWqT8c8aaL5DXAAB3eZLVFStWdLiN5rRp07Js25NbelrpANwXyGoAgCc8yetTp045XBl888lmuSUlJUWPP/64DMPQ3Llzc/31PUVeAwDc5a9ZLfk+r91+5vZLL72ka9euZVp//fp1vfTSS261FR0drQULFmjp0qU6cOCAhgwZosTERPvVcr1793Z444cMGaJLly5p+PDh+uWXX/TFF1/o1Vdf1XPPPefubgAA/JQnzxphwDxnyGsAgDs8yer8cgDuK2Q1AMBdnuR1eHi4w5LVVcKeXKEcERHhUv2MnD5x4oRiYmLsx9QZbdz6OLHU1FRdunTJ9HVzG3kNAHCHr7Ja8v+8dum25DebMmWKBg8erIIFCzqsv3btmqZMmaKJEye63NYTTzyh8+fPa+LEiTpz5owaNGigdevW2ScJTp48qYCA/86/V6xYUV9//bVeeOEF1atXTxUqVNDw4cM1evRod3cDAOCn3Hk+Z0a9U6dOOYRoXt2my18HzMlrAIA7PMnqjAPv7OTWAfiGDRv8asCcrAYAuMuTvHbFzVcod+3aVdJ/r1AeOnRolts0b95csbGxGjFihH1dTEyMwxXNGTl9+PBhbdy40X6X0Zvb+PPPP7Vr1y41btxYkrRhwwalp6erWbNmLvffl8hrAIA7fJXVkv/ntduT24ZhZPkm/e///q9KlCjhbnMaOnSo6Ru1adOmTOuaN2+u7du3u/06AIDbQ4Dtr8XVuhID5t5AXgMAXOVJVrvK3w/AfYmsBgC4w5d5HR0drT59+qhJkyZq2rSpZs+enekK5QoVKtjvqjZ8+HC1bt1aM2fO1EMPPaQVK1boxx9/1Pz58yX9ldOPPfaY4uLitHbtWqWlpdnvclaiRAmFhISoVq1a6tixowYOHKh58+YpJSVFQ4cO1ZNPPqny5cu7twM+RF4DAFzly6yW/DuvXZ7cLl68uP0sgRo1ajhMcKelpenq1asaPHiwyy8MwD85OwPI2clBhmE4a9Sj10M+ZXPje8GAOYB8yid57eHrIR/yYVZL/n0ADgAZnP+dNM9jEhde48O8dvcK5XvvvVfLly/X+PHjNW7cOFWvXl1r1qxRnTp1JEm///67Pv/8c0lSgwYNHF5r48aNatOmjSTpww8/1NChQ9WuXTsFBASoW7dueuutt9zrPIB8xfNjWbIaucDHx9b+nNcuT27Pnj1bhmGof//+mjJliooWLWovCwkJUWRkpMNAPgAAvnDzM0RcqesuBswBAMgZX2e1Px+AAwBgFb7Oa3evUO7evbu6d++eZf3IyEiXTsIsUaKEli9f7lY/AQCwKl9nteS/ee3y5HafPn0kSXfccYfuvfdeBQcH5+iFAQDwhO3//udqXXcxYA4AQM74Oqsl/z0ABwDAKnIjrwEAgOfIanMuTW4nJCTYnw3asGFDXb9+XdevX8+yrivPNAUAwFO+ftaIxIA5AAA5kRtZDQAAcoa8BgDA2shqcy5NbhcvXlynT59WmTJlVKxYsSzv8W4Yhmw2m9LS0rzeSQAAMthsNpefNcIzYAEAyH1kNQAA1kdeAwBgbWS1OZcmtzds2KASJUpI+usWqgAA5JXceNYIAADwHFkNAID1kdcAAFgbWW3Opcnt1q1bZ/n/AQDIbQE2mwJcTGtX6wEAAO8hqwEAsD7yGgAAayOrzQW4u8G6deu0detW+7/feecdNWjQQE899ZQuX77s1c4BuH1k3ELD3QW4VcYZa64uAADXkdfwBrIaAHLG0zwmq+EO8hoAPEdWIzeQ1ebcntweOXKkEhISJEk///yzoqOj9eCDD+rYsWOKjo72egcBALgZ/2EIAIC1kdUAAFgfeQ0AgLWR1eZcui35zY4dO6batWtLkv7973+rU6dOevXVVxUXF6cHH3zQ6x0EAOBm7pyJls8yHQAASyCrAQCwPvIaAABrI6vNuT25HRISomvXrkmS1q9fr969e0uSSpQoYb+iGwAAX+FZIwAAWBtZDQCA9ZHXAABYG1ltzu3J7ZYtWyo6OlotWrTQzp079fHHH0uSfvnlF/3tb3/zegcBAAAAAAAAAAAAAHD7mdtz5sxRUFCQPv30U82dO1cVKlSQJH311Vfq2LGj1zsIAMDNbG4uAAAgd5HVAABYH3kNAIC1kdXm3L5yu1KlSlq7dm2m9W+88YZXOgQAgDM2m002F2+z4mo9AADgPWQ1AADWR14DAGBtZLU5tye3JSktLU1r1qzRgQMHJEl33XWXOnfurMDAQK92DgCAWwXY/lpcrQsAAHIXWQ0AgPWR1wAAWBtZbc7tye0jR47owQcf1O+//66aNWtKkqZNm6aKFSvqiy++UNWqVb3eSQAAMnDGGgAA1kZWAwBgfeQ1AADWRlabc/uZ28OGDVPVqlV16tQpxcXFKS4uTidPntQdd9yhYcOG+aKPAAA4sNlcWwAAQN4gqwEAsD7yGgAAayOrs+b2ldubN2/W9u3bVaJECfu6kiVLavr06WrRooVXOwcAwK04Yw0AAGsjqwEAsD7yGgAAayOrzbk9uR0aGqorV65kWn/16lWFhIR4pVMAAJjhWSMAAFgbWQ0AgPWR1wAAWBtZbc7t25I//PDDGjRokHbs2CHDMGQYhrZv367Bgwerc+fOvugjAAB2GWesuboAAIDcRVYDAGB95DUAANZGVptze3L7rbfeUtWqVdW8eXOFhYUpLCxMLVq0ULVq1fTmm2/6oo8AANjZ3FwAAEDuIqsBALA+8hoAAGsjq825fVvyYsWK6bPPPtPhw4d14MAB2Ww21apVS9WqVfNF/wAAcBBgsynAxTPRXK0HAAC8h6wGAMD6yGsAAKyNrDbn9uR2hurVq9sntPPb5e4AgLxjs/21uFoXAADkLrIaAADrI68BALA2stqc27cll6SFCxeqTp069tuS16lTR++99563+wYAQCY8awQAAGsjqwEAsD7yGgAAayOrzbl95fbEiRM1a9YsPf/882revLkkadu2bXrhhRd08uRJvfTSS17vJAAAGThjDQAAayOrAQCwPvIaAABrI6vNuT25PXfuXC1YsEA9evSwr+vcubPq1aun559/nsltAIBP8awRAACsjawGAMD6yGsAAKyNrDbn9uR2SkqKmjRpkml948aNlZqa6pVOAQBghjPWAACwNrIaAADrI68BALA2stqc28/c7tWrl+bOnZtp/fz589WzZ0+vdAoAADM8awQAAGsjqwEAsD7yGgAAayOrzbl95bYkLVy4UN98843uueceSdKOHTt08uRJ9e7dW9HR0fZ6s2bN8k4vAQD4PwFy/cwst8/gAgAAOUZWAwBgfeQ1AADWRlabc3tye+/evWrUqJEk6ejRo5KkUqVKqVSpUtq7d6+9Xn47SwAAkDvcORONLAIAIPeR1QAAWB95DQCAtZHV5tye3N64caMv+gEAgEtsNimAZ40AAGBZZDUAANZHXgMAYG1ktbn8dqU6AMDPBdjcWwAAQO4iqwEAsD5f5/U777yjyMhIhYWFqVmzZtq5c6fT+itXrtSdd96psLAw1a1bV19++aVD+apVq9ShQweVLFlSNptNe/bsydRGmzZtMj1/dPDgwe53HgAAC+DY2hyT2wAAv3LrgWp2CwAAyF25kdUMmAMAkDO+zOuPP/5Y0dHRmjRpkuLi4lS/fn1FRUXp3LlzWdb//vvv1aNHDw0YMEC7d+9W165d1bVrV4dHYCYmJqply5Z67bXXnL72wIEDdfr0afsyY8YMt/oOAIBVcGxtjsltAIBfyY0z1vw11AEAsAJfZzUD5gAA5Jwv83rWrFkaOHCg+vXrp9q1a2vevHkqWLCgFi1alGX9N998Ux07dtTIkSNVq1YtTZ06VY0aNdKcOXPsdXr16qWJEyeqffv2Tl+7YMGCioiIsC/h4eHudR4AAIvg2Nock9sAAL9is7m3uMufQx0AACvwdVYzYA4AQM75Kq+Tk5O1a9cuh0wNCAhQ+/bttW3btiy32bZtW6YMjoqKMq3vzIcffqhSpUqpTp06Gjt2rK5du+Z2GwAAWAHH1uaY3AYA+JUAm82txV3+HOoAAFiBJ1mdkJDgsCQlJWXZNgPmAAB4h6/y+sKFC0pLS1PZsmUd1pctW1ZnzpzJsi9nzpxxq76Zp556Sh988IE2btyosWPHatmyZXr66afdagMAAKvg2NqcR5Pby5YtU4sWLVS+fHmdOHFCkjR79mx99tlnnjQHAIDLAtxcpPwT6gAAWIEnWV2xYkUVLVrUvkybNi3LthkwBwDAO3yZ13ll0KBBioqKUt26ddWzZ0+9//77Wr16tY4ePZrXXQMAwG0cW5sLcqu2pLlz52rixIkaMWKEXnnlFaWlpUmSihUrptmzZ6tLly7uNgkAgMvcuc1KRr2KFSs6rJ80aZImT56cqb6zUD948GCWr+HNUK9cubLKly+vn376SaNHj9ahQ4e0atUqt9oBACCveZLVp06dcrhjSWhoqA96ljODBg2y//+6deuqXLlyateunY4ePaqqVavmYc8AAHCfr/K6VKlSCgwM1NmzZx3Wnz17VhEREVm2HxER4VZ9VzVr1kySdOTIEbIaAOB3OLY25/bk9ttvv60FCxaoa9eumj59un19kyZN9OKLL7rbHAAAbgmQ67cbD9Bf9fJLqAMAYAWeZHV4eLhLj+NgwBwAAO/wVV6HhISocePGio2NVdeuXSVJ6enpio2N1dChQ7Pcpnnz5oqNjdWIESPs62JiYtS8eXOX+mdmz549kqRy5crlqB0AAPICx9bm3L4t+bFjx9SwYcNM60NDQ5WYmOhucwAAuCXjjDVXF+m/oZ6xmE1uWzXUAQDwJ55ktatuHjDPkDFgbjYAnjFgfjMGzAEA+Z0v8zo6OloLFizQ0qVLdeDAAQ0ZMkSJiYnq16+fJKl3794aO3asvf7w4cO1bt06zZw5UwcPHtTkyZP1448/OkyGX7p0SXv27NH+/fslSYcOHdKePXvsd007evSopk6dql27dun48eP6/PPP1bt3b913332qV69eDt8tAAByH8fW5ty+cvuOO+7Qnj17VLlyZYf169atU61atdxtDgAAtwTY/lpcresOzjAHACDnfJnV0l8D5n369FGTJk3UtGlTzZ49O9OAeYUKFezPFhs+fLhat26tmTNn6qGHHtKKFSv0448/av78+fY2L126pJMnT+qPP/6Q9NeAufTXSWwRERE6evSoli9frgcffFAlS5bUTz/9pBdeeIEBcwCA3/JlXj/xxBM6f/68Jk6cqDNnzqhBgwZat26d/ZFeJ0+eVEDAf6+5uvfee7V8+XKNHz9e48aNU/Xq1bVmzRrVqVPHXufzzz+3Z70kPfnkk5L++9ixkJAQrV+/3v7fBRUrVlS3bt00fvx49zoPAIBFcGxtzu3J7ejoaD333HO6ceOGDMPQzp079dFHH2natGl677333G0OAAC32Gxy+XYs7p6xJvl3qAMAYAW+zmoGzAEAyDlf5/XQoUNNTxLftGlTpnXdu3dX9+7dTdvr27ev+vbta1pesWJFbd682d1uAgBgWRxbm3N7cvuZZ55RgQIFNH78eF27dk1PPfWUypcvrzfffNPeSQAAfMWd26zkt1AHAMAKfJ3VEgPmAADkVG7kNQAA8BzH1ubcntyWpJ49e6pnz566du2arl69qjJlyuS4IwAAuMLXt2OR/DfUAQCwgtzIagAAkDPkNQAA1kZWm3N7cvvYsWNKTU1V9erVVbBgQRUsWFCSdPjwYQUHBysyMtLbfQQAwM72f/9ztS4AAMhdZDUAANZHXgMAYG1ktbmA7Ks46tu3r77//vtM63fs2OH0qjQAALwh44w1VxcAAJC7yGoAAKyPvAYAwNrIanNuT27v3r1bLVq0yLT+nnvu0Z49e7zRJwAATBHqAABYG1kNAID1kdcAAFgbWW3O7duS22w2XblyJdP6+Ph4paWleaVTAACYsdlsstlcvB2Li/UAAID3kNUAAFgfeQ0AgLWR1ebcnty+7777NG3aNH300UcKDAyUJKWlpWnatGlq2bKl1zsIwH8YhuGkzHw7Z39389sfZWTPnTPR8tsZawDgCmd57SnyGjcjqwEgZ8hq5AbyGgA8R1YjN5DV5tye3J4+fbpat26tmjVrqlWrVpKkLVu2KCEhQRs2bPB6BwEAuJnN5vyEiFvrAgCA3EVWAwBgfeQ1AADWRlabc/uZ23fddZd++uknPf744zp37pyuXLmi3r176+DBg6pTp44v+ggAgF2AzebWAgAAchdZDQCA9ZHXAABYG1ltzq0rt1NSUtSxY0fNmzdPr776qq/6BACAKW7HAgCAtZHVAABYH3kNAIC1kdXm3JrcDg4O1k8//eSrvgAAkD03bseifBbqAABYAlkNAID1kdcAAFgbWW3K7duSP/3001q4cKEv+gIAQLYCZHNrAQAAuYusBgDA+shrAACsjaw259aV25KUmpqqRYsWaf369WrcuLEKFSrkUD5r1iyvdQ4AgFvZ3DhjLZ89agQAAEsgqwEAsD7yGgAAayOrzbk9ub137141atRIkvTLL784lNny27sH3KYMw3BS5mQ7D9t0tmGgk/tL8Dcnf+JZIwDwl1zPayfIa9yMrAaAv5DVsDLyGgByP6udZW6Ak1bJ6vyJrDbn9uT2xo0bfdEPAABcEmCzKcDF/6BztR4AAPAeshoAAOsjrwEAsDay2pzbk9sAAOQlbscCAIC1kdUAAFgfeQ0AgLWR1eac3JQoa23bttXf//5308UT77zzjiIjIxUWFqZmzZpp586dLm23YsUK2Ww2de3a1aPXBQD4nwDZ7GetZbson6W6D5HVAABXkdV5g6wGALiDvM4b5DUAwFVktTm3J7cbNGig+vXr25fatWsrOTlZcXFxqlu3rtsd+PjjjxUdHa1JkyYpLi5O9evXV1RUlM6dO+d0u+PHj+vFF19Uq1at3H5NAID/yjhjzdUFOUdWAwDcQVbnPrIaAOAu8jr3kdcAAHeQ1ebcntx+4403HJY5c+Zo69atGjFihIKDg93uwKxZszRw4ED169dPtWvX1rx581SwYEEtWrTIdJu0tDT17NlTU6ZMUZUqVdx+TQCA/wpwc0HOkdUAAHeQ1bmPrAYAuIu8zn3kNQD8//buPzyq6sD/+GdmkkmCIUGkJIBoEFFAUSpIDOqXWlPT1bqmtlu0VpFS7K7CqmltRZGguEL9QVGh5YFWtPtocXWV7VqalUaxVVKo/Gi1AlWBoisJoEIwLvk15/uHMnWYuYfMkJk5M3m/fPK0zLl35s6d5L5z50xmEA9a7a3b7u+3vvUta4hjaWtr0/r161VZWfn3DfL7VVlZqYaGBs/17rrrLvXv319Tpkw54m20traqubk54gsAkLl8Pl9cXzg6qWi1RK8BIJvQ6tSi1QCARNDr1OJ5cABAvGi1t5zuuqKGhgbl5+fHtc7evXvV2dmpkpKSiMtLSkq0ZcuWmOu8/PLL+vnPf65NmzZ16Tbmzp2rO++8M67tQs9ijPEcy+YDgu1+W4YUsgx2hrzHLEPWT4PwWx6DLH54YOGT/Xvm8GVxdFLRaole48hs3bLJ9JanutfW3WzZlfQan0WrU4tWwxUhS19sLcjmVtvOg23r0WqkAr1OLZ4HhwtsfbH9nGd6x5PRatt6tvNxn7zHcgPef4uaAbsZSUCrvcU9uX355ZdH/NsYo127dunVV1/VHXfc0W0bFsuBAwd09dVXa+nSperXr1+X1pkxY4ZqamrC/25ubtbgwYOTtYkAgCTz+3zWJ2YOXxaplUirJXoNANmEVruNVgMAJHrtOp4HBwDQam9xT24XFxdH/Nvv9+vUU0/VXXfdpYsuuiiu6+rXr58CgYCampoiLm9qalJpaWnU8m+//bZ27NihSy+9NHxZKBSSJOXk5Gjr1q0aOnRoxDp5eXnKy8uLa7sAAG7rWalOr1S0WqLXAJBtaHXq0GoAQKLoderwPDgAIBG0Ora4J7eXLVvWbTceDAY1ZswY1dfXq7q6WtInka6vr9e0adOilh8+fLhee+21iMtmzpypAwcO6MEHH+SVaADQA/h8XX8rnh72grWkoNUAgHjR6tSi1QCARNDr1KLXAIB40WpvCX/m9vr167V582ZJ0mmnnabPf/7zCV1PTU2NJk2apLFjx2rcuHFasGCBWlpaNHnyZEnSNddco0GDBmnu3LnKz8/X6aefHrF+nz59JCnqcgBAdvL5fF3+PJ9M+NyfTECrAQDxoNWpR6sBAPGi16lHrwEA8aDV3uKe3N69e7euuOIKrV69OhzUffv26YILLtDy5cv1uc99Lq7rmzhxovbs2aNZs2apsbFRo0ePVl1dnUpKSiRJO3fulN/vj3czAQBZyv/pV1eXxdGj1QCAeNDq1KPVAIB40evUo9cAgHjQam8+Y4yJZ4WJEydq27Zt+sUvfqERI0ZIkt544w1NmjRJJ598sn75y18mZUO7S3Nzs4qLi9X0/n4VFRWle3PgANuPQDa/2sV2v21HhZBlsDPkPWYZsn5uRDDH+7Ds92fv45PJmpubVXJcsfbv797j7KHj97Lfb1Gvwt5dWufjjw5o8vnDu31bkHz0GoeL81fWsExveap7bd3Nll0ZDNDrTJOMXtPqnoVW43AhS19sOc7mVtvOg23r0WpInFvj6NFqHM7WF1sJMr3jyWi1bT3b+bhtb+XS6oxDq9Mn7r/crqur029/+9vwxLYkjRw5UosWLdJFF13UrRsHdJdEwy1l9sR3Mp4Q7+j0HmvrDCW0XsAS55yA95j/CI8espNPR/q5jVwWQObosHTE1t1Mb3mqe91u2c/W35ks+yvH0nJ63fPQaiB7tbZ3eo7Zzut8lp92fxa32raerdUdlh7bWm173ptW43D0GshOB9sSa7Vt8jQTjgFJabWtx5aOd1qu02/53Sbgt6yXEY8Cuhut9hb35HYoFFJubm7U5bm5uQqFvJ8oAwCgO/BZIwAAuI1WAwDgPnoNAIDbaLW3uN+G/Ytf/KJuvPFGvffee+HL/vd//1c333yzLrzwwm7dOAAADueP8wsAAKQWrQYAwH30GgAAt9Fqb3Hf34ULF6q5uVllZWUaOnSohg4dqiFDhqi5uVkPP/xwMrYRAICwQ69Y6+oXAABILVoNAID7kt3rRYsWqaysTPn5+SovL9e6deusyz/11FMaPny48vPzNWrUKK1cuTJi/JlnntFFF12k4447Tj6fT5s2bYq6joMHD+qGG27Qcccdp8LCQn3ta19TU1NT3NsOAIALUnFunam9jntye/DgwdqwYYN+/etf66abbtJNN92klStXasOGDTr++OPjvToAAOLii/MrEZkadQAAXECrAQBwXzJ7/eSTT6qmpka1tbXasGGDzjzzTFVVVWn37t0xl1+zZo2uvPJKTZkyRRs3blR1dbWqq6v1+uuvh5dpaWnReeedpx/96Eeet3vzzTfrv//7v/XUU0/ppZde0nvvvafLL788zq0HAMANyT63zuReJ/SX6j6fT1/60pc0ffp0TZ8+XZWVlYlcDQAAcfP54vuKVyZHHQAAF9BqAADcl8xez58/X1OnTtXkyZM1cuRILV68WL169dIjjzwSc/kHH3xQX/7yl3XLLbdoxIgRmjNnjs466ywtXLgwvMzVV1+tWbNmeT4PvX//fv385z/X/Pnz9cUvflFjxozRsmXLtGbNGv3hD3+I7w4AAOCAZJ9bZ3Kvuzy53dDQoOeeey7isl/84hcaMmSI+vfvr+uuu06tra1dvmEAABLhly+ur3hlctQBAHABrQYAwH2J9Lq5uTniK9ZzwW1tbVq/fn1EU/1+vyorK9XQ0BBzWxoaGqIaXFVV5bl8LOvXr1d7e3vE9QwfPlwnnHBCXNcDAIArktVqKfN7ndPVBe+66y594Qtf0Fe+8hVJ0muvvaYpU6bo2muv1YgRI3Tfffdp4MCBmj17dpdvHOhOB9s6Pcdsr1oJ+L0HfZYn2/wylttz47MDjfcmKmQZ7Oj0HmvtCHmOHWz3fgzaLOvZHoNewYDnGHqmeF6Jdmi55ubmiMvz8vKUl5cXtfyhqM+YMSN8WVeiXlNTE3FZVVWVVqxY0bWN1JGjfs4553T5ugDXtRzs8BzzW3oQsLwk07LaESbO3Gh5wr0OeY+1dXp3t7Xde6zdsp6t1wX0Gp9Bq4HMtv/jds+xHEsLLDlTjqXjpoe22tZc2/lzu+V83dbqvFxajUiJ9Hrw4MERl9fW1kY9F7x37151dnaqpKQk4vKSkhJt2bIl5vU3NjbGXL6xsbFrG/jpdQSDQfXp0+eorgfIBO8f8P4jw6Atuom9ka/1pNuV58gTbXWnpdW258htrbb139pq62OHnihZrZYyv9ddntzetGmT5syZE/738uXLVV5erqVLl0r6ZId57SQAALqL79P/urqs1HOiDgCAC2g1AADuS6TX77zzjoqKisKXx3ohGgAA6B602luXJ7c//PDDiCcEXnrpJf3DP/xD+N9nn3223nnnne7dOgAADpPIK9Z6StQBAHABrQYAwH2J9LqoqCii17H069dPgUBATU1NEZc3NTWptLQ05jqlpaVxLe91HW1tbdq3b1/Ei9HivR4AAFyRrFZLmd/rLr/PQUlJibZv3y7pk7eC27BhQ8Rbrx04cEC5ubldvmEAABLhi+NzRg69Yu1Q1A99eT1h7kLUj+Z6AABwAa0GAMB9ifS6K4LBoMaMGaP6+vrwZaFQSPX19aqoqIi5TkVFRcTykrRq1SrP5WMZM2aMcnNzI65n69at2rlzZ1zXAwCAK5LVainze93lye2LL75Yt956q37/+99rxowZ6tWrl84///zw+J///GcNHTq0yzcMAEAiDr1iratf8cj0qAMA4AJaDQCA+5LZ65qaGi1dulSPPfaYNm/erH/5l39RS0uLJk+eLEm65pprNGPGjPDyN954o+rq6vTAAw9oy5Ytmj17tl599VVNmzYtvMwHH3ygTZs26Y033pD0SYc3bdoU/niQ4uJiTZkyRTU1NXrxxRe1fv16TZ48WRUVFRF/oAUAQKZIZqulzO51l9+WfM6cObr88ss1YcIEFRYW6rHHHlMwGAyPP/LII7rooou6fMMAACQikbdjiUdNTY0mTZqksWPHaty4cVqwYEFU1AcNGqS5c+dK+iTqEyZM0AMPPKBLLrlEy5cv16uvvqolS5aEr/ODDz7Qzp079d5770n6JOrSJ38FVlpaGhH1vn37qqioSNOnT+ckHACQkWg1AADuS2avJ06cqD179mjWrFlqbGzU6NGjVVdXF/7Iy507d8rv//vfXI0fP15PPPGEZs6cqdtuu03Dhg3TihUrdPrpp4eX+dWvfhVuvSRdccUVkqTa2lrNnj1bkvTjH/9Yfr9fX/va19Ta2qqqqir95Cc/iW/jAQBwRLLPrTO51z5jjIlnhf3796uwsFCBQCDi8g8++ECFhYURE94uam5uVnFxsZre39+l951H5jjY1uk5ZvvBDvi9B32WFS2rWddLpVDI+8c7ZPnR7+j0HmvtCHmOHWz3fgzaLOvZHoPP9fb+vMXcnC6/+QRSqLm5WSXHFWv//u49zh46fj+7bpuOKezdpXVaPjqgr447Ke5tWbhwoe67775w1B966CGVl5dLkr7whS+orKxMjz76aHj5p556SjNnztSOHTs0bNgw3Xvvvbr44ovD448++mhE1A/5bNQPHjyo733ve/rlL38ZEfWe+lan9Dp7tRzs8BzzW3pga4WtyX5Lk225TmXLE+61ZT1bd1vbvcfaOxPr9bHHeJ8DBOm1k5LRa1rds9Dq7LX/43bPsRxbqwPeY7b1emqr2y2ttnW83XK+bmt1cS/vjxKk1W7KhnNrpBetzl7vH2j1HLMd03MD3mPWc+4Ez8czodWdtlZbmmtrta3/tv1clO/9t6g8D+4mWp0+cU9uZzqintne/6jNc8x6km0ZS3Q9V6Ju+xG2dNQablucP7a8iODjVu8Ji/+zrJeXG/AcG3hsvudYvmU9pE+yo/5ff4wv6ped3XOink3odWZr3HfQcyzX9sS35UQ70fUSnRTPhF7bnhS3veDM1nLbxHeO5TEo7UOvM00yJ7dpdc9AqzPbzr0fe47l5Xp31faEeaJPptt6nOiL0RORaKttT6a3WV40Zmuu7fzZdr5ua/XnirxfOE6r3cS5NY4Wrc5sbzd95DlWEPQ+btueZ823tNrWEFuPXWm1bZbLPoGd2IvNDibhReN9C71fNE6r3USr06fLb0sOAIALfJ/+19VlAQBAatFqAADcR68BAHAbrfbGexkAAAAAAAAAAAAAAJzHX24DADKKz2f/7L3DlwUAAKlFqwEAcB+9BgDAbbTaG5PbAICM4lPX32alhzUdAAAn0GoAANxHrwEAcBut9sbkNgAgo/h9n3x1dVkAAJBatBoAAPfRawAA3EarvTG5DQDIKL5P/+vqsgAAILVoNQAA7qPXAAC4jVZ7Y3Ibztmxp8VzzG/54IDcHL/nWNAyZgKWH3qf93o+472asR5IvFf0dfMHIxjjfVshy1hnyHusozPkOdba7j3W1uE9ZntcLbsLPRSfNQK4Yet7BzzHcixtzcsNeI7lW27Pb2my3+8dC8uQM722sbW80zLW3uk9Zmtyq2XMyPL7FL3GZ9BqIHVsnfjzzv2eY/mWHtvYGhiwRDdgOYH2W35vsPfFjVbbzq1DtnNry1i75bzb1nFajXjQayA1bK3+4/YPPccKg95TNwHLn2jmWMY6Lc0NmMR+0F1ptW0zrGOWwWS02vbY0WocjlZ7Y3IbAJBRfOr6Z4j0sKYDAOAEWg0AgPvoNQAAbqPV3pjcBgBkFL989r/2P2xZAACQWrQaAAD30WsAANxGq70xuQ0AyCi8Yg0AALfRagAA3EevAQBwG632xuQ2ACCzUHUAANxGqwEAcB+9BgDAbbTaE5PbAICM4vv0v64uCwAAUotWAwDgPnoNAIDbaLU3JrcBAJnFJ3Xxo0Z63CvWAABwAq0GAMB99BoAALfRak9MbiMt/vS3fZ5jfstPa07AeywvFEhoWwI+v/dYyHiOGdtRJRMOJN53TcZY7rdlvURZD9CZsC+RUrwbC9C9bMf8tds+8BwL+r37mZebYJP9lt8B/N7bmWNtmuUGHTlIWLtrXc8yZlnP8uuNdVuArqLVQPeyHZtf2Lrbc6xXjvdTPomed+eGvPsfsp0/J/argTMS7aO9496jtlaHbOvZHgMaj8PQa6D72I6xz/1ll+dYcTDXcywv4N3cTkuPOxM8P87mStibm1iPOy2DtjGf5XcwWo3D0WpvTG4DADILVQcAwG20GgAA99FrAADcRqs9MbkNAMgofNYIAABuo9UAALiPXgMA4DZa7Y3JbQBARvHF8VkjXf5MEgAA0G1oNQAA7qPXAAC4jVZ7Y3IbAJBReDcWAADcRqsBAHAfvQYAwG202huT2wCAzELVAQBwG60GAMB99BoAALfRak9MbgMAMgqfNQIAgNtoNQAA7qPXAAC4jVZ7Y3IbR8UY4zm25u33PcdClvVyfH7PsV65Ac+xgN/7h7cz5H2dnd6bIsuQlW09Zw4xlg3xW/ZlTsB7LDfHez/bBC3rObO/4Aw+awSIn63XL2zd7Tlm+8W4M8f7Om1NDlmO+bbttLXVtl6ml8S29bZjnOUhUI7td6aA9+NjW4/jLT6LVgPxs7Xsub/s8hwL+m3nut7XaTsnt2XV3tzs5bMerLz3ibXjljFbx/2WbbGdy9vvA3oieg3Ex9bAp//0rudYQY7389m2n61Ej9uJ/rjyYw64h1Z7Y3IbAJBReDcWAADcRqsBAHAfvQYAwG202huT2wCAzELVAQBwG60GAMB99BoAALfRak9MbgMAMgqfNQIAgNtoNQAA7qPXAAC4jVZ7Y3IbAJBR+KwRAADcRqsBAHAfvQYAwG202huT2wCAjMK7sQAA4DZaDQCA++g1AABuo9XemNwGAGQWqg4AgNtoNQAA7qPXAAC4jVZ7YnIbR2SM8Rz7w9sfeI61d3qvZ+Q95g9Y1vMesrL9XPstg4keD1w5jvgs70Vhu98By2BuwO85VhD0vs7cgPd15uUGEtoW9Ex81ggQm63XL/11b0LXaXtLo4C1MYn1x94t7zHbepnwtkyJbr/tMcix9Dov1/t7xXZ7wRzv67Q9Puh5aDUQm63Vz29u8hwL+r2Pv0HL8d42ZuuE7RzMNmbtmedI5kv09xfbvrQ9PrbnRXIs592kGoej10A0W6uf+8suz7Fjcr2nWQpyvJ/37JXjvZ7tuVRbJ/zWVnsOJSXWtkamUsKttmy+7TEI2Vqd4O9S6JlotTcmtwEAGYXPGgEAwG20GgAA99FrAADcRqu9eb+8BAAAB/ni/ErEokWLVFZWpvz8fJWXl2vdunXW5Z966ikNHz5c+fn5GjVqlFauXBkxbozRrFmzNGDAABUUFKiyslJvvvlmxDJlZWXy+XwRX/PmzUvwHgAAkD60GgAA9yW717QaAICjw7m1Nya3AQCZJclVf/LJJ1VTU6Pa2lpt2LBBZ555pqqqqrR79+6Yy69Zs0ZXXnmlpkyZoo0bN6q6ulrV1dV6/fXXw8vce++9euihh7R48WKtXbtWxxxzjKqqqnTw4MGI67rrrru0a9eu8Nf06dPjvwMAAKQbrQYAwH1J7DWtBgCgG3Bu7YnJbQBARvHF+V+85s+fr6lTp2ry5MkaOXKkFi9erF69eumRRx6JufyDDz6oL3/5y7rllls0YsQIzZkzR2eddZYWLlwo6ZNXqy1YsEAzZ87UZZddpjPOOEO/+MUv9N5772nFihUR19W7d2+VlpaGv4455pi4tx8AgHSj1QAAuC+ZvabVAAAcPc6tvTG5DQDIKIc+a6SrX5LU3Nwc8dXa2hrzutva2rR+/XpVVlaGL/P7/aqsrFRDQ0PMdRoaGiKWl6Sqqqrw8tu3b1djY2PEMsXFxSovL4+6znnz5um4447T5z//ed13333q6OiIe/8AAJButBoAAPclq9e0GgCA7sG5tbecuJYGACDN4nmXlUPLDR48OOLy2tpazZ49O2r5vXv3qrOzUyUlJRGXl5SUaMuWLTFvo7GxMebyjY2N4fFDl3ktI0n/+q//qrPOOkt9+/bVmjVrNGPGDO3atUvz588/4v0EAMAltBoAAPclq9e0GgCA7sG5tTcmt3FEr7/T7DmW6/f+439fV3/qDpMXCHiP5XrfXjDHeyw34L0xAb/3mN8yluj98yW6YiK3ZRmz3DXl2Ac9h2z7sjOU2OOTyv2FDJFA1d955x0VFRWFL87Ly+v2zTpaNTU14f9/xhlnKBgM6rvf/a7mzp3r5PbCPX/6237PsV653m0NGe/jc44vse4m2mtbf2xpSkYqMqHXtu4GLW2VvL8fcgLd//igB6LV6MGMMZ5jG3bs8xzrkxdM6PZybS2wtdoylmirbV1K/PzZNuZ+q23PKdiam+f9bWQ/l7dcp59zaxwuC3tNq9EVtlav3faB59jnCry/h2zH2DzLsdl+Xu193mZ7LjXRVifaCVfykupW54Yssba8YbLt8XFlX8IhWdhqqXt6zduSAwAySiKfNVJUVBTx5RXJfv36KRAIqKmpKeLypqYmlZaWxlyntLTUuvyh/43nOiWpvLxcHR0d2rFjh/fOAADAQbQaAAD3JavXtBoAgO7BubU3JrcBAJklns8ZifMVj8FgUGPGjFF9fX34slAopPr6elVUVMRcp6KiImJ5SVq1alV4+SFDhqi0tDRimebmZq1du9bzOiVp06ZN8vv96t+/f3x3AgCAdKPVAAC4L0m9ptUAAHQTzq098bbkAICMkshnjcSjpqZGkyZN0tixYzVu3DgtWLBALS0tmjx5siTpmmuu0aBBgzR37lxJ0o033qgJEybogQce0CWXXKLly5fr1Vdf1ZIlSz7ZBp9PN910k+6++24NGzZMQ4YM0R133KGBAwequrpaktTQ0KC1a9fqggsuUO/evdXQ0KCbb75Z3/rWt3TssccmcC8AAEgfWg0AgPuS2WtaDQDA0ePc2huT2wCAzJLkqk+cOFF79uzRrFmz1NjYqNGjR6uurk4lJSWSpJ07d8rv//sbn4wfP15PPPGEZs6cqdtuu03Dhg3TihUrdPrpp4eX+cEPfqCWlhZdd9112rdvn8477zzV1dUpPz9f0ieffbJ8+XLNnj1bra2tGjJkiG6++eaIzx8BACBj0GoAANyXxF7TagAAugHn1t531xhj++T7rNPc3Kzi4mI1vb8/4kPV4e21nfs9x1o7Qp5jHcZ7zCYvEPAcy8/1fif9gqD3er0sY/mWsWCO9+3l+L2PFn6fZcyyXiJsP8K2n+6QZbAz5D3W0WkZs6xnu87cgPc+6ZXn/RqcQDfvS3SP5uZmlRxXrP37u/c4e+j4ventJvXu3bXrPXCgWaOHlnT7tiD56HX8Nu3Y5znWFvJusq0HOT7vDtoamWiv83MT+x0gt4f22tbk9k7vx7zNsp6t17Z9WVTg3eucAJ+E5KJk9JpW9yy0Ojbb8X6DpdW2cymbXMux2d5q7+baW+19nXmW67Sd89nO62xjPkvHE5GMVrdbmttmeT7FNmZrvK25fXrleo7ZvleQPpxb42jR6thsx/u12z5I6Dpt55Z5lmOz7fhr66qtx7bG59nOqy3baTs97qmtPtjemdB12s6r+xzj3Wrb44r0odXpw19uAwAySvhzRLq4LAAASC1aDQCA++g1AABuo9XemNyGJGn77hbPMdsrumyvzAoZ26u9vNfLsbyiO9FXnif6F9gBy3baXnmWygOJ/RVw3q8S89vepyLBF23bvx+8t8W2Hn+cjcMl+7NGAJe93fSR51ivPMurs0PeB3bbq5ttf71s+wss66vPLWPWXlteRZ7NvfZZjmT2P4i2/R5m+cvtBHvd3a/IR2aj1ejJ3m7yPrcutLwrle18yXaMtZ0v2f4Cy/6XYom9W4r9vM79hiTaatuQbZ/Yfpeyddz2nIn1XWnc2M1wCL1GT/Vmo/d59bEFQc8x23usJPqXzbbzXFurbWPWd0tJsBO2vqSSrdWJvkGx7X7b5g1sj0HAb3lO3nIfbI8PeiZa7Y3JbQBAZqHqAAC4jVYDAOA+eg0AgNtotScnPlRn0aJFKisrU35+vsrLy7Vu3TrPZZcuXarzzz9fxx57rI499lhVVlZalwcAZBdfnP+he9BqAEBX0er0oNUAgHjQ6/Sg1wCArqLV3tI+uf3kk0+qpqZGtbW12rBhg84880xVVVVp9+7dMZdfvXq1rrzySr344otqaGjQ4MGDddFFF+l///d/U7zlAIB08OnvnzdyxK90b2yWoNUAgHjQ6tSj1QCAeNHr1KPXAIB40GpvaZ/cnj9/vqZOnarJkydr5MiRWrx4sXr16qVHHnkk5vKPP/64rr/+eo0ePVrDhw/Xz372M4VCIdXX16d4ywEA6eCL8wtHj1YDAOJBq1OPVgMA4kWvU49eAwDiQau9pXVyu62tTevXr1dlZWX4Mr/fr8rKSjU0NHTpOj7++GO1t7erb9++ydpMAIBDuvxqtU+/cHRoNQAgXrQ6tWg1ACAR9Dq16DUAIF602ltOOm9879696uzsVElJScTlJSUl2rJlS5eu44c//KEGDhwY8YvBZ7W2tqq1tTX87+bm5sQ3GADggHhei9bDqp4EqWi1RK8BILvQ6lSi1QCAxNDrVOJ5cABA/Gi1l7RObh+tefPmafny5Vq9erXy8/NjLjN37lzdeeedKd4yN+1pbvUcyw8GPMdyOr1/KEImsW3xW37OApbB3ID3mw0Ec7zHcgLe15ljuU7bq11s98HnyMtk7Nvh/eD5LQdCn2VfGmP7hkhsn7iyL+GOeF6JxrdP+nWl1RK9/qzdll4XWHpta7L9+OzNdgy29zqxltvWs92e7WfddhhwpTG27fBZHjt/ovvE9nuYZY/5LSvafi9Cz0OrMwutjt/u/Qc9x3rldX+rbZ2wHX9zLIOBBHts63g2t9rG1kfbebf9uQjv9UK23w2s3yvu70ukFr3OLDwPHp8mS6sL872nRBI8dT7Cc8iJnefans+29djaf2urEz2/dOMAYX0MbL9tWN7b2PbtELSs2OlPrNWO7Eo4hFZ7S+vbkvfr10+BQEBNTU0Rlzc1Nam0tNS67v3336958+bp+eef1xlnnOG53IwZM7R///7w1zvvvNMt2w4ASA9fnF84OqlotUSvASCb0OrUotUAgETQ69TieXAAQLxotbe0Tm4Hg0GNGTNG9fX14ctCoZDq6+tVUVHhud69996rOXPmqK6uTmPHjrXeRl5enoqKiiK+AACZi88aSa1UtFqi1wCQTWh1atFqAEAi6HVq8Tw4ACBetNpb2t+WvKamRpMmTdLYsWM1btw4LViwQC0tLZo8ebIk6ZprrtGgQYM0d+5cSdKPfvQjzZo1S0888YTKysrU2NgoSSosLFRhYWHa7gcAIDV8n/7X1WVx9Gg1ACAetDr1aDUAIF70OvXoNQAgHrTaW9ontydOnKg9e/Zo1qxZamxs1OjRo1VXV6eSkhJJ0s6dO+X3//0PzH/605+qra1NX//61yOup7a2VrNnz07lpgMA0iGe91npWU1PGloNAIgLrU45Wg0AiBu9Tjl6DQCIC632lPbJbUmaNm2apk2bFnNs9erVEf/esWNH8jcIAOAsmp4etBoA0FW0Oj1oNQAgHvQ6Peg1AKCraLU3Jya3AQDoqng+Q6SnfdYIAAAuoNUAALiPXgMA4DZa7Y3J7SzUcrAj5uV5uf6Yl0tSTsh4joVyvNczxns968+S5Sctx+89FkjCmGVI/iw+Ivis9y2xx9Uk+PqgLN7NSAI+awTZwqvX+ZZehwLe39OWlFuO6naJNtKlXmd6Y2zbn3CTvb/FjrAttv2c4Tsa3YpWI1t85HluHfBcJ2g5R7b12LJawi3wJ9pjyw3arjObW53o+bP/CLX2vD3LzrQ13nZrmbCfkVr0GtnAq9UFQe9W25pre67bJtHzoYTPgRPseKK/U2SCxJ/rts1TWG7PcmsBk9iOzua5CCSGVntjchsAkFl4PxYAANxGqwEAcB+9BgDAbbTaE5PbAICMQtMBAHAbrQYAwH30GgAAt9Fqb0xuAwAyCp81AgCA22g1AADuo9cAALiNVntjchsAkGG6/lkjPe81awAAuIBWAwDgPnoNAIDbaLUXJrcBABmFV6wBAOA2Wg0AgPvoNQAAbqPV3pjczlDtHSHPsZxA7O/igPH+7g75jeeY98gRBi0/TLafM7/lp9D2Axrw29azjFm2xXZ7tuvMdInetyzeJQCQkLZu7rUJeN+WMQm23MLaZMt69Lr7JL793o96IMP3CQB0J1urcz1anWNpma25tlYnytpOy+He1njL3aPVMdi33/sx91v2mOXXwYRl+n4G0HMl0mpjq5Ilx4mW2nqETfFz5LQ6mm37/ZZHPWTZYwH/UW1STBm+m4GUYnIbAJBReMUaAABuo9UAALiPXgMA4DZa7Y3JbQBARvHF8VkjXf9MEgAA0F1oNQAA7qPXAAC4jVZ7Y3IbAJBReMUaAABuo9UAALiPXgMA4DZa7Y3JbQBARvHpCJ9ldNiyAAAgtWg1AADuo9cAALiNVntjchsAkFmoOgAAbqPVAAC4j14DAOA2Wu2JyW0AQEbhs0YAAHAbrQYAwH30GgAAt9Fqb0xuOywUMp5jfr/3N6rPezVPxiT2jW+7Kds1Jvr+/z7Lionenu06AbiHzxqBa2y9Dlh6bRLqtW20+7/hXep1orfXU7FPkE60Gq7pTGGrbRI9707GzwmtTg3rfrbsEtPd33xADPQaLnGl1S6h1alhu99+y+yHScpzMD3zMYA3Wu3Nn+4NAAAgHr44vxKxaNEilZWVKT8/X+Xl5Vq3bp11+aeeekrDhw9Xfn6+Ro0apZUrV0aMG2M0a9YsDRgwQAUFBaqsrNSbb74ZscwHH3ygq666SkVFRerTp4+mTJmijz76KMF7AABA+tBqAADcl+xe02oAAI4O59bemNwGAGSWJFf9ySefVE1NjWpra7VhwwadeeaZqqqq0u7du2Muv2bNGl155ZWaMmWKNm7cqOrqalVXV+v1118PL3PvvffqoYce0uLFi7V27Vodc8wxqqqq0sGDB8PLXHXVVfrLX/6iVatW6bnnntPvfvc7XXfddfHfAQAA0o1WAwDgviT2mlYDANANOLf25DM97L2OmpubVVxcrKb396uoqCjdm2Nle5tT24OWyEOa6HcBb0sO4HDNzc0qOa5Y+/d373H20PG7cW/Xr7e5uVml/eLblvLycp199tlauHChJCkUCmnw4MGaPn26br311qjlJ06cqJaWFj333HPhy8455xyNHj1aixcvljFGAwcO1Pe+9z19//vflyTt379fJSUlevTRR3XFFVdo8+bNGjlypP74xz9q7NixkqS6ujpdfPHFevfddzVw4MAubXs2odde68S9ylFxqdeJ3h6A2JLRa1rds2RSq21vdWrT3U+XJHp1vC15z5OMp+p4DDJPJp9b02o30OrMRqvTz/b9lYxvPdtH0cJNmdxqKbN7zV9uAwAyyoEDzXF9SZ/E/bNfra2tMa+7ra1N69evV2VlZfgyv9+vyspKNTQ0xFynoaEhYnlJqqqqCi+/fft2NTY2RixTXFys8vLy8DINDQ3q06dPOOiSVFlZKb/fr7Vr1yawlwAASB9aDQCA+5LVa1oNAED34NzaW06XlwQAII2CwaBKS0s1bMjguNYrLCzU4MGR69TW1mr27NlRy+7du1ednZ0qKSmJuLykpERbtmyJef2NjY0xl29sbAyPH7rMtkz//v0jxnNyctS3b9/wMgAAuI5WAwDgvmT3mlYDAHB0OLc+Mia3AQAZIT8/X9u3b1dbW1tc6xljot5aKS8vrzs3DQAAiFYDAJAJ6DUAAG6j1UfG5Haa2T63wfo50bbrTOST4y2rJPrxEcn4hAg+MwTo2fLz85Wfn5+06+/Xr58CgYCampoiLm9qalJpaWnMdUpLS63LH/rfpqYmDRgwIGKZ0aNHh5fZvXt3xHV0dHTogw8+8LxdpFYm9Nol9BrouWg10sXW6kQ/vjChVts4lDla7Tb2M5Itmb2m1fCSjFZ7BS1TPoo7GYd7GtJ9rJ97bptPyZRvQDiNc2s7PnMbAIBPBYNBjRkzRvX19eHLQqGQ6uvrVVFREXOdioqKiOUladWqVeHlhwwZotLS0ohlmpubtXbt2vAyFRUV2rdvn9avXx9e5oUXXlAoFFJ5eXm33T8AADIdrQYAwG20GgAA92V6r/nLbQAAPqOmpkaTJk3S2LFjNW7cOC1YsEAtLS2aPHmyJOmaa67RoEGDNHfuXEnSjTfeqAkTJuiBBx7QJZdcouXLl+vVV1/VkiVLJH3yKs+bbrpJd999t4YNG6YhQ4bojjvu0MCBA1VdXS1JGjFihL785S9r6tSpWrx4sdrb2zVt2jRdccUVGjhwYFr2AwAArqLVAAC4jVYDAOC+TO41k9sAAHzGxIkTtWfPHs2aNUuNjY0aPXq06urqVFJSIknauXOn/P6/v/HJ+PHj9cQTT2jmzJm67bbbNGzYMK1YsUKnn356eJkf/OAHamlp0XXXXad9+/bpvPPOU11dXcRbyzz++OOaNm2aLrzwQvn9fn3ta1/TQw89lLo7DgBAhqDVAAC4jVYDAOC+TO61z/SwDwBobm5WcXGxmt7fr6KionRvTlI+f6G7r5LP3AYQj+bmZpUcV6z9+904ziIz0evMRq8B99FrHC1andloNeA+Wo2j1RNa7X1bKbupo8JnbmenRL/XeewyD61OHz5zGwAAAAAAAAAAAADgPCa3AQAAAAAAAAAAAADO4zO30ywZbzXhdZUJvx3GUWyL53XyFhsAgAySCb1OBnoNAMgUqWy1Tao7TqsBAJkilc0ij0gnfj8Dko+/3AYAAAAAAAAAAAAAOI/JbQAAAAAAAAAAAACA85jcBgAAAAAAAAAAAAA4j8ltAAAAAAAAAAAAAIDzmNwGAAAAAAAAAAAAADiPyW0AAAAAAAAAAAAAgPNy0r0BSB2fz5fuTQAAAEdArwEAyFx0HAAAAACSi7/cBgAAAAAAAAAAAAA4j8ltAAAAAAAAAAAAAIDzmNwGAAAAAAAAAAAAADiPyW0AAAAAAAAAAAAAgPOY3AYAAAAAAAAAAAAAOI/JbQAAAAAAAAAAAACA85jcBgAAAAAAAAAAAAA4j8ltAAAAAAAAAAAAAIDzmNwGAAAAAAAAAAAAADiPyW0AAAAAAAAAAAAAgPOY3AYAAAAAAAAAAAAAOI/JbQAAAAAAAAAAAACA85jcBgAAAAAAAAAAAAA4j8ltAAAAAAAAAAAAAIDzmNwGAAAAAAAAAAAAADiPyW0AAAAAAAAAAAAAgPOY3AYAAAAAAAAAAAAAOI/JbQAAAAAAAAAAAACA85jcBgAAAAAAAAAAAAA4j8ltAAAAAAAAAAAAAIDzmNwGAAAAAAAAAAAAADiPyW0AAAAAAAAAAAAAgPOY3AYAAAAAAAAAAAAAOM+Jye1FixaprKxM+fn5Ki8v17p166zLP/XUUxo+fLjy8/M1atQorVy5MkVbCgBAz0SrAQBwG60GAMB99BoAgKOX9sntJ598UjU1NaqtrdWGDRt05plnqqqqSrt37465/Jo1a3TllVdqypQp2rhxo6qrq1VdXa3XX389xVsOAEDPQKsBAHAbrQYAwH30GgCA7uEzxph0bkB5ebnOPvtsLVy4UJIUCoU0ePBgTZ8+XbfeemvU8hMnTlRLS4uee+658GXnnHOORo8ercWLFx/x9pqbm1VcXKym9/erqKio++4IAEDSJ8fZkuOKtX8/x9lskepWS/QaAJKNXmcXWg0A2YdWZx+eBweA7EKr0ycnnTfe1tam9evXa8aMGeHL/H6/Kisr1dDQEHOdhoYG1dTURFxWVVWlFStWxFy+tbVVra2t4X/v379fknSgufkotx4AEMuh42uaXzuFbpKKVkv0GgBSjV5nD1oNANmJVmcXngcHgOxDq9MnrZPbe/fuVWdnp0pKSiIuLykp0ZYtW2Ku09jYGHP5xsbGmMvPnTtXd955Z9TlJw8ZnOBWAwC64v3331dxcXG6NwNHKRWtlug1AKQLvc58tBoAshutzg48Dw4A2YtWp15aJ7dTYcaMGRGvcNu3b59OPPFE7dy5k2+2TzU3N2vw4MF65513eOuEz2C/RGOfRGOfRNu/f79OOOEE9e3bN92bggxCr4+M40009kk09kk09kls9BrxotVHxvEmGvskNvZLNPZJNFqNeNHqruF4E419Eo19Eo19Eo1Wp09aJ7f79eunQCCgpqamiMubmppUWloac53S0tK4ls/Ly1NeXl7U5cXFxfwAHqaoqIh9EgP7JRr7JBr7JJrf70/3JqAbpKLVEr2OB8ebaOyTaOyTaOyT2Oh15qPV7uF4E419Ehv7JRr7JBqtzg48D+4ejjfR2CfR2CfR2CfRaHXqpXWPB4NBjRkzRvX19eHLQqGQ6uvrVVFREXOdioqKiOUladWqVZ7LAwCAxNFqAADcRqsBAHAfvQYAoPuk/W3Ja2pqNGnSJI0dO1bjxo3TggUL1NLSosmTJ0uSrrnmGg0aNEhz586VJN14442aMGGCHnjgAV1yySVavny5Xn31VS1ZsiSddwMAgKxFqwEAcButBgDAffQaAIDukfbJ7YkTJ2rPnj2aNWuWGhsbNXr0aNXV1amkpESStHPnzog/6R8/fryeeOIJzZw5U7fddpuGDRumFStW6PTTT+/S7eXl5am2tjbmW7T0VOyT2Ngv0dgn0dgn0dgn2SfVrZb4PoqFfRKNfRKNfRKNfRIb+yW70Go3sE+isU9iY79EY59EY59kH54HdwP7JRr7JBr7JBr7JBr7JH18xhiT7o0AAAAAAAAAAAAAAMCGTzkHAAAAAAAAAAAAADiPyW0AAAAAAAAAAAAAgPOY3AYAAAAAAAAAAAAAOI/JbQAAAAAAAAAAAACA87JycnvRokUqKytTfn6+ysvLtW7dOuvyTz31lIYPH678/HyNGjVKK1euTNGWpk48+2Tp0qU6//zzdeyxx+rYY49VZWXlEfdhpor3e+WQ5cuXy+fzqbq6OrkbmAbx7pN9+/bphhtu0IABA5SXl6dTTjkl636G4t0nCxYs0KmnnqqCggINHjxYN998sw4ePJiirU2+3/3ud7r00ks1cOBA+Xw+rVix4ojrrF69WmeddZby8vJ08skn69FHH036dsJ99DoavY5Gq6PR6mi0OhKtRneh1dFodTRaHY1Wx0avI9FrdAdaHY1Wx0avo9HraLQ6Eq12mMkyy5cvN8Fg0DzyyCPmL3/5i5k6darp06ePaWpqirn8K6+8YgKBgLn33nvNG2+8YWbOnGlyc3PNa6+9luItT55498k3v/lNs2jRIrNx40azefNmc+2115ri4mLz7rvvpnjLkyve/XLI9u3bzaBBg8z5559vLrvsstRsbIrEu09aW1vN2LFjzcUXX2xefvlls337drN69WqzadOmFG958sS7Tx5//HGTl5dnHn/8cbN9+3bzP//zP2bAgAHm5ptvTvGWJ8/KlSvN7bffbp555hkjyTz77LPW5bdt22Z69eplampqzBtvvGEefvhhEwgETF1dXWo2GE6i19HodTRaHY1WR6PV0Wg1ugOtjkaro9HqaLQ6NnodjV7jaNHqaLQ6NnodjV5Ho9XRaLW7sm5ye9y4ceaGG24I/7uzs9MMHDjQzJ07N+by3/jGN8wll1wScVl5ebn57ne/m9TtTKV498nhOjo6TO/evc1jjz2WrE1Mi0T2S0dHhxk/frz52c9+ZiZNmpR1UY93n/z0pz81J510kmlra0vVJqZcvPvkhhtuMF/84hcjLqupqTHnnntuUrczXboS9R/84AfmtNNOi7hs4sSJpqqqKolbBtfR62j0Ohqtjkaro9FqO1qNRNHqaLQ6Gq2ORqtjo9d29BqJoNXRaHVs9DoavY5Gq+1otVuy6m3J29ratH79elVWVoYv8/v9qqysVENDQ8x1GhoaIpaXpKqqKs/lM00i++RwH3/8sdrb29W3b99kbWbKJbpf7rrrLvXv319TpkxJxWamVCL75Fe/+pUqKip0ww03qKSkRKeffrruuecedXZ2pmqzkyqRfTJ+/HitX78+/JYt27Zt08qVK3XxxRenZJtdlO3HWcSPXkej19FodTRaHY1Wd49sP8YifrQ6Gq2ORquj0erY6HX3yPbjLOJDq6PR6tjodTR6HY1Wd49sP866JCfdG9Cd9u7dq87OTpWUlERcXlJSoi1btsRcp7GxMebyjY2NSdvOVEpknxzuhz/8oQYOHBj1Q5nJEtkvL7/8sn7+859r06ZNKdjC1Etkn2zbtk0vvPCCrrrqKq1cuVJvvfWWrr/+erW3t6u2tjYVm51UieyTb37zm9q7d6/OO+88GWPU0dGhf/7nf9Ztt92Wik12ktdxtrm5Wf/3f/+ngoKCNG0Z0oVeR6PX0Wh1NFodjVZ3D1qNw9HqaLQ6Gq2ORqtjo9fdg17js2h1NFodG72ORq+j0eruQatTJ6v+chvdb968eVq+fLmeffZZ5efnp3tz0ubAgQO6+uqrtXTpUvXr1y/dm+OMUCik/v37a8mSJRozZowmTpyo22+/XYsXL073pqXN6tWrdc899+gnP/mJNmzYoGeeeUa//vWvNWfOnHRvGoAsRq9ptRdaHY1WA0gHWk2rvdDq2Og1gFSj1Z+g17HR62i0GumUVX+53a9fPwUCATU1NUVc3tTUpNLS0pjrlJaWxrV8pklknxxy//33a968efrtb3+rM844I5mbmXLx7pe3335bO3bs0KWXXhq+LBQKSZJycnK0detWDR06NLkbnWSJfK8MGDBAubm5CgQC4ctGjBihxsZGtbW1KRgMJnWbky2RfXLHHXfo6quv1ne+8x1J0qhRo9TS0qLrrrtOt99+u/z+nveaIq/jbFFREa9W66HodTR6HY1WR6PV0Wh196DVOBytjkaro9HqaLQ6NnrdPeg1PotWR6PVsdHraPQ6Gq3uHrQ6dbLquysYDGrMmDGqr68PXxYKhVRfX6+KioqY61RUVEQsL0mrVq3yXD7TJLJPJOnee+/VnDlzVFdXp7Fjx6ZiU1Mq3v0yfPhwvfbaa9q0aVP46x//8R91wQUXaNOmTRo8eHAqNz8pEvleOffcc/XWW2+Ff8GRpL/+9a8aMGBAxgddSmyffPzxx1HhPvRLjzEmeRvrsGw/ziJ+9DoavY5Gq6PR6mi0untk+zEW8aPV0Wh1NFodjVbHRq+7R7YfZxEfWh2NVsdGr6PR62i0untk+3HWKSbLLF++3OTl5ZlHH33UvPHGG+a6664zffr0MY2NjcYYY66++mpz6623hpd/5ZVXTE5Ojrn//vvN5s2bTW1trcnNzTWvvfZauu5Ct4t3n8ybN88Eg0Hz9NNPm127doW/Dhw4kK67kBTx7pfDTZo0yVx22WUp2trUiHef7Ny50/Tu3dtMmzbNbN261Tz33HOmf//+5u67707XXeh28e6T2tpa07t3b/PLX/7SbNu2zTz//PNm6NCh5hvf+Ea67kK3O3DggNm4caPZuHGjkWTmz59vNm7caP72t78ZY4y59dZbzdVXXx1eftu2baZXr17mlltuMZs3bzaLFi0ygUDA1NXVpesuwAH0Ohq9jkaro9HqaLQ6Gq1Gd6DV0Wh1NFodjVbHRq+j0WscLVodjVbHRq+j0etotDoarXZX1k1uG2PMww8/bE444QQTDAbNuHHjzB/+8Ifw2IQJE8ykSZMilv+P//gPc8opp5hgMGhOO+008+tf/zrFW5x88eyTE0880UiK+qqtrU39hidZvN8rn5WNUTcm/n2yZs0aU15ebvLy8sxJJ51k/u3f/s10dHSkeKuTK5590t7ebmbPnm2GDh1q8vPzzeDBg831119vPvzww9RveJK8+OKLMY8Rh/bDpEmTzIQJE6LWGT16tAkGg+akk04yy5YtS/l2wz30Ohq9jkaro9HqaLQ6Eq1Gd6HV0Wh1NFodjVbHRq8j0Wt0B1odjVbHRq+j0etotDoSrXaXz5ge+v4AAAAAAAAAAAAAAICMkVWfuQ0AAAAAAAAAAAAAyE5MbgMAAAAAAAAAAAAAnMfkNgAAAAAAAAAAAADAeUxuAwAAAAAAAAAAAACcx+Q2AAAAAAAAAAAAAMB5TG4DAAAAAAAAAAAAAJzH5DYAAAAAAAAAAAAAwHlMbgM92LXXXqvq6mrrMqtXr5bP59O+fftSsk0AAODvaDUAAO6j1wAAuI1WA9mFyW1kLZ/PZ/2aPXt2WrdtxYoVabv9Qx588EE9+uij4X9/4Qtf0E033RSxzPjx47Vr1y4VFxenduMAAFmPVh8ZrQYApBu9PjJ6DQBIJ1p9ZLQayC456d4AIFl27doV/v9PPvmkZs2apa1bt4YvKywsjOv62traFAwGu237XNCVUAeDQZWWlqZgawAAPQ2tPjJaDQBIN3p9ZPQaAJBOtPrIaDWQXfjLbWSt0tLS8FdxcbF8Pl/43y0tLbrqqqtUUlKiwsJCnX322frtb38bsX5ZWZnmzJmja665RkVFRbruuuskSUuXLtXgwYPVq1cvffWrX9X8+fPVp0+fiHX/67/+S2eddZby8/N10kkn6c4771RHR0f4eiXpq1/9qnw+X/jfh9uxY4d8Pp+WL1+u8ePHKz8/X6effrpeeumliOVeeukljRs3Tnl5eRowYIBuvfXW8G1J0tNPP61Ro0apoKBAxx13nCorK9XS0iIp8u1Yrr32Wr300kt68MEHw6/q27FjR8y3Y/nP//xPnXbaacrLy1NZWZkeeOCBqH13zz336Nvf/rZ69+6tE044QUuWLDniYwYA6Flo9SdoNQDAZfT6E/QaAOAqWv0JWg30IAboAZYtW2aKi4vD/960aZNZvHixee2118xf//pXM3PmTJOfn2/+9re/hZc58cQTTVFRkbn//vvNW2+9Zd566y3z8ssvG7/fb+677z6zdetWs2jRItO3b9+I6/7d735nioqKzKOPPmrefvtt8/zzz5uysjIze/ZsY4wxu3fvNpLMsmXLzK5du8zu3btjbvP27duNJHP88cebp59+2rzxxhvmO9/5jundu7fZu3evMcaYd9991/Tq1ctcf/31ZvPmzebZZ581/fr1M7W1tcYYY9577z2Tk5Nj5s+fb7Zv327+/Oc/m0WLFpkDBw4YY4yZNGmSueyyy4wxxuzbt89UVFSYqVOnml27dpldu3aZjo4O8+KLLxpJ5sMPPzTGGPPqq68av99v7rrrLrN161azbNkyU1BQYJYtWxax7/r27WsWLVpk3nzzTTN37lzj9/vNli1bjuJRBABkM1pNqwEA7qPX9BoA4DZaTauBnoDJbfQIh0c9ltNOO808/PDD4X+feOKJprq6OmKZiRMnmksuuSTisquuuiriui+88EJzzz33RCzz7//+72bAgAHhf0syzz77rHV7DkV93rx54cva29vN8ccfb370ox8ZY4y57bbbzKmnnmpCoVB4mUWLFpnCwkLT2dlp1q9fbySZHTt2xLyNz0bdGGMmTJhgbrzxxohlDo/6N7/5TfOlL30pYplbbrnFjBw5MvzvE0880XzrW98K/zsUCpn+/fubn/70p9b7DADouWg1rQYAuI9e02sAgNtoNa0GegLelhw90kcffaTvf//7GjFihPr06aPCwkJt3rxZO3fujFhu7NixEf/eunWrxo0bF3HZ4f/+05/+pLvuukuFhYXhr6lTp2rXrl36+OOP497WioqK8P/PycnR2LFjtXnzZknS5s2bVVFRIZ/PF17m3HPP1UcffaR3331XZ555pi688EKNGjVK//RP/6SlS5fqww8/jHsbPmvz5s0699xzIy4799xz9eabb6qzszN82RlnnBH+/4feCmf37t1HddsAgJ6DVieOVgMAUoVeJ45eAwBSgVYnjlYD7spJ9wYA6fD9739fq1at0v3336+TTz5ZBQUF+vrXv662traI5Y455pi4r/ujjz7SnXfeqcsvvzxqLD8/P+FtTkQgENCqVau0Zs0aPf/883r44Yd1++23a+3atRoyZEhSbzs3Nzfi3z6fT6FQKKm3CQDIHrSaVgMA3Eev6TUAwG20mlYD2Yi/3EaP9Morr+jaa6/VV7/6VY0aNUqlpaXasWPHEdc79dRT9cc//jHissP/fdZZZ2nr1q06+eSTo778/k9+5HJzcyNe3WXzhz/8Ifz/Ozo6tH79eo0YMUKSNGLECDU0NMgYE3HfevfureOPP17SJzE999xzdeedd2rjxo0KBoN69tlnY95WMBg84naNGDFCr7zySsRlr7zyik455RQFAoEu3ScAAI6EVtNqAID76DW9BgC4jVbTaiAb8Zfb6JGGDRumZ555Rpdeeql8Pp/uuOOOLr2aavr06fp//+//af78+br00kv1wgsv6De/+U3E26HMmjVLX/nKV3TCCSfo61//uvx+v/70pz/p9ddf19133y1JKisrU319vc4991zl5eXp2GOP9bzNRYsWadiwYRoxYoR+/OMf68MPP9S3v/1tSdL111+vBQsWaPr06Zo2bZq2bt2q2tpa1dTUyO/3a+3ataqvr9dFF12k/v37a+3atdqzZ0/4l4LDlZWVae3atdqxY4cKCwvVt2/fqGW+973v6eyzz9acOXM0ceJENTQ0aOHChfrJT35yxP0HAEBX0WpaDQBwH72m1wAAt9FqWg1kpfR+5DeQGsuWLTPFxcXhf2/fvt1ccMEFpqCgwAwePNgsXLjQTJgwwdx4443hZU488UTz4x//OOq6lixZYgYNGmQKCgpMdXW1ufvuu01paWnEMnV1dWb8+PGmoKDAFBUVmXHjxpklS5aEx3/1q1+Zk08+2eTk5JgTTzwx5jZv377dSDJPPPGEGTdunAkGg2bkyJHmhRdeiFhu9erV5uyzzzbBYNCUlpaaH/7wh6a9vd0YY8wbb7xhqqqqzOc+9zmTl5dnTjnlFPPwww+H1500aZK57LLLwv/eunWrOeecc0xBQYGRZLZv325efPFFI8l8+OGH4eWefvppM3LkSJObm2tOOOEEc99990VsU6x9d+aZZ5ra2tqY9xUAAFpNqwEA7qPX9BoA4DZaTauBnsBnzGfexwFA3KZOnaotW7bo97//fbde744dOzRkyBBt3LhRo0eP7tbrBgCgJ6HVAAC4j14DAOA2Wg3AFbwtORCn+++/X1/60pd0zDHH6De/+Y0ee+wx3ooEAACH0GoAANxHrwEAcButBuAqJreBOK1bt0733nuvDhw4oJNOOkkPPfSQvvOd76R7swAAwKdoNQAA7qPXAAC4jVYDcBVvSw4AAAAAAAAAAAAAcJ4/3RsAAAAAAAAAAAAAAMCRMLkNAAAAAAAAAAAAAHAek9sAAAAAAAAAAAAAAOcxuQ0AAAAAAAAAAAAAcB6T2wAAAAAAAAAAAAAA5zG5DQAAAAAAAAAAAABwHpPbAAAAAAAAAAAAAADnMbkNAAAAAAAAAAAAAHAek9sAAAAAAAAAAAAAAOf9f0UcWaPmQfV2AAAAAElFTkSuQmCC", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# -- 1D distributions for Figure 1 --\n", + "n_1d = 50\n", + "grid = jnp.linspace(0, 1, n_1d)\n", + "\n", + "# Simple source and target\n", + "def gaussian(x, mu, sigma):\n", + " \"\"\"Gaussian bump (un-normalised).\"\"\"\n", + " p = jnp.exp(-0.5 * ((x - mu) / sigma) ** 2)\n", + " return p / p.sum()\n", + "\n", + "# Source: bimodal\n", + "r_1d = 0.6 * gaussian(grid, 0.25, 0.06) + 0.4 * gaussian(grid, 0.75, 0.06)\n", + "r_1d = r_1d / r_1d.sum()\n", + "\n", + "# Target: bimodal (shifted)\n", + "c_1d = 0.4 * gaussian(grid, 0.30, 0.06) + 0.6 * gaussian(grid, 0.70, 0.06)\n", + "c_1d = c_1d / c_1d.sum()\n", + "\n", + "# Cost matrices\n", + "C1_1d = jnp.abs(grid[:, None] - grid[None, :])\n", + "C2_1d = (grid[:, None] - grid[None, :]) ** 2\n", + "\n", + "eta_1d = 100\n", + "\n", + "# -- Compute t_min (W2 distance) and t_max (L2 cost of W1-optimal plan) --\n", + "# These unconstrained references are solved directly with OTT-JAX Sinkhorn.\n", + "P_w2 = solve_ott_entropic_plan(C2_1d, r_1d, c_1d, eta_1d)\n", + "t_min = float(jnp.sum(C2_1d * P_w2))\n", + "\n", + "P_w1 = solve_ott_entropic_plan(C1_1d, r_1d, c_1d, eta_1d)\n", + "t_max = float(jnp.sum(C2_1d * P_w1))\n", + "\n", + "print(f\"t_min (W2 distance) = {t_min:.4f}\")\n", + "print(f\"t_max (L2 cost of L1-optimal plan) = {t_max:.4f}\")\n", + "\n", + "# -- Solve for 4 threshold values: binding -> no constraint --\n", + "t_values = np.linspace(t_min, t_max, 4)\n", + "labels = [\n", + " \"Binding $\\\\ell_2$ constraint\\n(Minimal $\\\\ell_2$ cost)\",\n", + " \"Strong $\\\\ell_2$ constraint\",\n", + " \"Mild $\\\\ell_2$ constraint\",\n", + " \"No $\\\\ell_2$ constraint\\n(Minimal $\\\\ell_1$ cost)\"\n", + "]\n", + "\n", + "plans_1d = []\n", + "for idx, t_val in enumerate(t_values):\n", + " D_ineq = convert_leq_constraint(C2_1d, float(t_val), n_1d)\n", + " _, _, _, P_cot, h = constrained_sinkhorn(\n", + " C1_1d, [D_ineq], r_1d, c_1d, eta_1d,\n", + " K=1, L=0, n_iters=200, n_newton=5, verbose=False\n", + " )\n", + " l1_cost = float(jnp.sum(C1_1d * P_cot))\n", + " l2_cost = float(jnp.sum(C2_1d * P_cot))\n", + " print(f\"[{idx+1}/4] t={t_val:.4f} | L1={l1_cost:.4f}, L2={l2_cost:.4f}, \"\n", + " f\"viol={h['violation'][-1]:.2e}\")\n", + " plans_1d.append(P_cot)\n", + "\n", + "# -- Plot (reproducing Figure 1) --\n", + "fig, axes = plt.subplots(1, 4, figsize=(20, 4.5))\n", + "\n", + "for idx, (P_plan, label) in enumerate(zip(plans_1d, labels)):\n", + " ax = axes[idx]\n", + " im = ax.imshow(\n", + " np.array(P_plan), cmap=\"Blues\", origin=\"lower\",\n", + " extent=[0, 1, 0, 1], aspect=\"equal\"\n", + " )\n", + " ax.set_xlabel(\"Target position\")\n", + " if idx == 0:\n", + " ax.set_ylabel(\"Source position\")\n", + " ax.set_title(label, fontsize=10)\n", + " plt.colorbar(im, ax=ax, shrink=0.75)\n", + "\n", + "fig.suptitle(\n", + " \"Figure 1: 1D OT under different $\\\\ell_2$ inequality constraints\\n\"\n", + " \"(main cost = $\\\\ell_1$ Manhattan, constraint = $\\\\ell_2^2$ Euclidean)\",\n", + " fontsize=13, y=1.06\n", + ")\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Experiment 3: Toy Pareto Front -- Manhattan vs. Squared Euclidean\n", + "\n", + "Following Section 5 of the paper:\n", + "\n", + "$$\\min_{P \\in \\mathcal{U}_{r,c}} \\langle C_1, P\\rangle \\quad\\text{s.t.}\\quad \\langle C_2, P\\rangle \\le t$$\n", + "\n", + "By sweeping $t \\in [t_{min}, t_{max}]$ for different $\\eta \\in \\{10, 100, 1000\\}$, we trace Pareto fronts between\n", + "$\\ell_1$ and $\\ell_2^2$ transport cost. As $\\eta$ increases, the front converges to the true Pareto front." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "eta = 10: done\n", + "eta = 100: done\n", + "eta = 1000: done\n", + "True Pareto (eta=5000): done\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAxYAAAJOCAYAAAAqFJGJAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAmkpJREFUeJzs3Xd8E/X/B/BX0kn3XlgoslpmmaWsAhYKAlIEmcoQQURALCJDRhGkLhAVtKIoQxCEHyACsiqoSNkFZBRFCv0yuuiku7n7/VESG5JC00vppbyej0cfmrvP3edzeSch79xnKERRFEFERERERCSBsrobQEREREREpo+JBRERERERScbEgoiIiIiIJGNiQUREREREkjGxICIiIiIiyZhYEBERERGRZEwsiIiIiIhIMiYWREREREQkGRMLIiIiIiKSjIlFDXL48GEoFAqsWbOmuptCJuCLL76Av78/rKysoFAocP369epuUo3QrVs3+Pn5VXczNBQKBcaMGfPIbYIgIDIyEk8//TTMzc2hUCg0+7Zu3YqWLVuiVq1aUCgUOHz4cNU3nKrFmDFjtGIvF3l5eZg6dSrq1KkDMzMzWb3HqOpFRkZW6b9T1f25ff36dSgUCkRGRlZbG4yFiYWJUCcN5f0dO3asuptodOo3mr6/Zs2aST7/g+e0trZGw4YNERERgfT0dCNcQeXt2LGjSj9gDh06hNdffx3+/v6Ijo7G+vXr4e7uXmX1PUxmZiYiIyP5ZbWarV27FgsXLkT37t2xevVqrF+/HgDw999/Y/jw4XB0dMSKFSuwfv16BAQEVHNrK2bNmjVYvny50c9b1e/PqmaK7f/ggw/w+eefY+jQoVUW14owxeeO5OH69euIjIzE2bNnq7spVcq8uhtAhhk+fDieffZZne0NGjSAi4sL8vPzYWFhUQ0tqzoDBw7E888/r7XNycnJKOcODAzE9OnTAQDp6enYs2cPPvnkExw4cACnT5+GpaWlUeox1I4dO7B27doq+wfswIEDAIBvv/0WLi4uVVJHRWVmZmLhwoUASn81oqqXn58PMzMzrW0HDhyAo6MjvvnmG61frA8fPoySkhIsX74crVu3ftxNlWTNmjW4fv06pk2bZtTzVvX7s6qZYvsPHDiA5s2b46OPPqrWdpjic0ePtn//foiiWKV1XL9+HQsXLoSfnx8CAwO19tWtWxf5+fkwNzf9r+WmfwVPmNatW+PFF18sd7+1tfVjbI2u4uJiqFQqo7ajRYsWD71mKWrXrq117qlTp6J///7YtWsXfvrpJ7zwwguS68jJyYG9vb3k8xhTUlISAFQoqVAnqzXhA89UGfs1pO/9mZSUBCcnJ51uMIa8VgyhUqlQWFgIGxsbo55Xjvgeki4pKQl16tSp7mYYzFRjb6rtNkTZz6Dq+hFRTd1roiZgV6gapLwxFnfv3sXLL78MV1dX2NnZoUePHoiLi9Pbp1Bf32ug9Je/B/tWq/s8Xrx4EREREXjqqadgbW2t6ZZVWFiIJUuWoGnTprC2toaTkxP69++PuLg4g6+toKAAeXl5Dy2TlpaG+Ph4ZGVlGXz+ssLCwgAAV69eBQBs3rwZzz33HOrUqQMrKyu4ubkhPDwc58+f1znWz88P3bp1Q1xcHMLCwuDo6IgWLVpo9v/zzz946aWX4O3tDUtLS/j5+WHGjBnIzc3VlOnWrRvWrl0LQLu7Vtm4nj9/HgMHDoSrqyusra3RpEkTfPjhh1CpVA+9NnX3su+++07r/Oo7Ber+1ampqXj55Zfh6ekJW1tb3Lx5U3P8Sy+9BE9PT1hZWaF+/fqYM2eOTmzUr40rV65gzpw5eOqpp2BlZYWWLVtiz549mnKHDx9GvXr1AAALFy7UtKfs63LdunVo3749nJycYGtri6effhojR45Eampqha5V3y+L+vrrqq89KysLr732Gjw8PGBtbY1OnTrh+PHjOufIyMjA+PHj4ebmBltbW3Tr1g2nT58utz2nTp3CwIED4ebmBisrKzRu3BjvvfceSkpKtMqp35fXrl3D4MGD4eLiAgcHh4deKwBcvHgRvXv3hq2tLVxcXDBy5EikpKToLVv2fa7+3Dh06BBu3LihiYH6+ViwYAEAoF69ejqxycrKwsyZM9GgQQNYWVnB3d0dw4cPx7Vr17TqU39+HDx4EIsWLUL9+vVhbW2NH3/8EQAgiiK+/PJLtGnTBjY2NrCzs0P37t1x6NAhrfOUjemuXbvQrl07WFtbw9vbGzNmzNB6Lv38/PDbb79pXVNFxofs3r0bISEhcHNzQ61atVCnTh08//zz+PvvvzXxedj781HvIUOfs19//RUff/wx6tevDysrKzRq1EhTf1kqlQqLFi1C3bp1YW1tjRYtWmDz5s06r/WKfL6o2/mo94EgCFi+fDlatGgBe3t7ODg4oHHjxhg3bhyKi4sf+jwDQElJCT744AM0adIE1tbWcHV1xcCBA/HXX3/pPA8JCQn47bffNO2tyB2DzZs3o3PnzrC3t4eNjQ2CgoKwdetWnXLq13tsbCxCQkJga2sLV1dXvPLKK7h3756mnNTYG/r5efHiRUydOhVeXl6oVasWgoKCEBMToylXVFQEd3d3dOrUSe/1f/TRR1AoFPj9998f+jwZ6zWrvsZBgwbBwcEBDg4OGDBgABISEjT/NpYtZ8jnsz63b9/G9OnTERgYCGdnZ82/hR988IHOv4WP+gzS933o4sWLeOGFF1C7dm1YWVnBy8sL3bt3x+7duzVlcnJyMHfuXAQFBWk+2xs0aIBZs2ZpxXXNmjXo3r07AGDs2LE6//aW93xU5D3y4PGP+mysajU3Fa2h8vLykJaWprXNysqq3F8zCwsLERoairNnz2LMmDFo3749zp8/j9DQUKP9Ajly5EjUqlUL06dPh0KhgLe3N4qLi9G7d28cPXoUL730EiZPnoysrCx8/fXX6NSpE37//Xe0bdu2QudfunQp3n33XYiiiKeeegpjx47FO++8AysrK61yK1aswMKFC/Hdd9/pTY4q6p9//gEAuLm5ac7r6uqKCRMmwMvLC//++y9WrVqFTp064cyZM2jYsKHW8YmJiejRowdeeOEFDBo0SPMP0+nTp9GjRw84OTnh1VdfRe3atXHu3Dl89tln+PPPP/Hbb7/BwsIC77zzDgRBwB9//KHp5w4AHTt2BFD6BTUkJAQWFhZ4/fXX4eXlhZ9//hkzZ87EuXPnsGHDhnKvzd3dHevXr8eqVau0zu/p6alVrmfPnvDy8sK8efOQm5sLOzs73LhxA+3bt0dWVhYmTZqEhg0b4vDhw4iKisKff/6JmJgYnV+3Ro8eDQsLC7z11lsoKirC8uXLER4ejr///ht+fn4ICAjAJ598gjfffFOry5udnR0AYP369Rg9ejS6dOmCd999F7Vq1cL//vc/7NmzBykpKVUyLiQsLAzu7u6YP38+7t69i2XLlqFv375ISEjQvM+Ki4sRFhaGkydP4qWXXkKHDh1w9uxZhIaGwtXVVeecu3fvxvPPP48GDRpg+vTpcHFxQWxsLObPn4+zZ89iy5YtWuXv3buHkJAQdOrUCe+99165CYJaQkICunTpgsLCQkyePBm+vr74+eef0bt370deb0BAANavX4/33nsPaWlp+OSTTwAA9evXR2hoKLZt24bt27fjk08+gZubmyY2WVlZ6NixIxITE/Hyyy+jadOmuHPnDr744gsEBQXh1KlTqFu3rlZdb731FoqLizF+/HjNl1AAeOmll/DDDz9g8ODBGDt2LAoLC7Fhwwb07NkT27Ztw3PPPad1nj179uCLL77AxIkT8fLLL+Onn37Cxx9/DGdnZ8yZMwcAsHz5csyePVvrmtTXW57ffvsNzz33HJo1a4bZs2fDyckJt2/fxsGDB3H16lU0atToke9PNX3voco8Z3PmzEF+fj5effVVWFlZ4csvv8SYMWPQoEEDrS+UkydPRnR0NLp374633noLqampmDRpkiZxV6to+yvyPnjvvfcwf/589O/fHxMnToSZmRkSEhKwc+dOFBYWPrJb7siRI/Hjjz+iZ8+eeO2115CUlISVK1ciODgYf/zxB1q1aoWuXbti/fr1ePPNN+Hm5oZ33nkHALR+sNFn7ty5eO+999C7d28sWrQISqUS27dvxwsvvIAVK1bg9ddf1yp/9uxZ9OvXD2PHjsWIESNw+PBhrF69GkqlEqtWrTLouTPW5+eoUaNgZmaGmTNnIicnB1999RV69+6NX375BaGhobC0tMTo0aOxdOlSXLlyRfN+Uvv222/RqFEjdO3a9aHP1cPabchr9u7du+jSpQuSk5MxceJEBAQE4I8//kD37t21fjwzlvPnz2Pbtm0YOHAg6tevj+LiYuzduxezZs3CtWvX8NVXX+kcU95n0IPu3r2LHj16AAAmTpyIunXrIi0tDadOncLx48fRt29fAMCtW7fwzTffYNCgQRgxYgTMzc3x22+/4cMPP0RcXBz27dsHAOjatSvmzJmDJUuWYMKECejSpQsA3X97H1SR90hZFflsrHIimYRDhw6JAPT+DR06VKvMd999pzlu5cqVIgBx8eLFWudTb69bt67WdgDi6NGjder/7rvvRADioUOHNNsWLFggAhBDQkLE4uJirfLLli0TAYh79+7V2p6VlSX6+vqKISEhj7zmGzduiD169BA/++wz8aeffhK/+eYbsVevXiIAMTQ0VCwpKdEqr25P2et/GABir169xNTUVDE1NVX8+++/xWXLlokWFhaio6OjmJycLIqiKN67d0/n2EuXLomWlpbia6+9prW9bt26IgDx66+/1jmmRYsWYuPGjcXs7Gyt7du2bdNp9+jRo8Xy3p4dO3YUzczMxHPnzmm2CYIgvvDCCyIA8eDBg4+89vLOr94+cuRInX0jRowQAYi7d+/W2v7WW2+JAMRvvvlGs00di759+4qCIGi2nzhxQgQgzpo1S7MtISFBBCAuWLBAp86BAweK9vb2Oq+vinjYedXtS0hI0GxTX/uDMf3xxx9FAGJ0dLRm21dffSUCEOfPn69V9pNPPtF5X+Xn54uenp5ily5dyn2flH1fhYSEiADEd955p8LXOnz4cBGA+Ouvv2q2CYIghoeH631P69sWEhKi83kgivqfK1EUxalTp4rW1tbi2bNntbZfv35dtLe31zq/+vOjUaNGYm5urlZ59ev/q6++0tpeXFwstmnTRvTz89O8htQxtbGx0WqPIAhi06ZNRS8vrwpdU3nefPNNEYDmvV+eh70/H/YeqsxzFhgYKBYWFmq237x5U7S0tBSHDRum2XbhwgURgBgWFiaqVCrN9vPnz4tKpbLc1/rD2l+R90GrVq3EgIAAved5lP3794sAxCFDhmh9Rpw9e1Y0MzMTO3furFW+bt26Ffp3QxRF8fTp0yIAcfbs2Tr7BgwYINrb22t9DgMQFQqFeOzYMa2yzz77rGhubi7m5ORotlU29pX5/Gzfvr1W7P/3v/+Jtra2or+/v2bblStXRADijBkztM575MgREYD4wQcf6G1rRdttyGt2xowZIgDx+++/1yqr3l42foZ+PuvblpeXp/XaUXvxxRdFpVIp3r59W7PtYZ9Boqj7WfHTTz+JAMTNmzfrlC2rsLBQLCoq0tk+d+5cEYB4/PhxzTZ939HU9D0fhrxHDP1srErsCmViJkyYgAMHDmj9zZ07t9zyP//8M8zMzPDGG29obX/llVfg6OholDZNmzZN55eW77//Hv7+/mjTpg3S0tI0f0VFRejZsyeOHDmC/Pz8h563Tp06iImJwZQpU/Dcc89h3Lhx2LdvH8aPH4+DBw9i06ZNWuUjIyMhiqJBdyv2798Pd3d3uLu7o1GjRoiIiECTJk2wf/9+eHh4AABsbW0BlHbXyM7ORlpaGtzd3dG4cWO9XWRcXFwwduxYrW1//fUXzp8/jxEjRqCwsFDrOencuTNsbW2xf//+R7Y3JSUFR48exXPPPaf1i51CodD8krd9+/YKX3953nrrLa3HgiBg586daNWqlc7kAbNnz9b8GvigN954Q6vPfrt27WBnZ6e5K/Qojo6OyMvLw+7du6t8YJ3am2++qfVY/atV2Tbv2LEDZmZmmoH/aq+99ppOt6UDBw4gOTkZY8eORWZmplbs1c+lvtg/GIPyCIKAn3/+GW3bttXcagdKXxNvv/12hc5hKFEUsWHDBnTt2hW1a9fWuiZbW1t06NBB7zW99tprOmMqvv/+e9jb2yM8PFzrPJmZmejfvz+uX7+u83oJDw/X6ragUCjQvXt3JCUlaXVdMZT6M/H//u//JHcdeDB+lX3OJk2apNX/u3bt2mjUqJHWc7Jr1y4Ape83pfK/f9abN2+u6dppqIq8DxwdHXHr1i0cOXLE4POrPy/eeecdrc+Ili1bon///jhy5MgjuzuWZ8OGDVAoFBg9erTW85yWlobnnnsOOTk5iI2N1TomODgYQUFBWtt69OiBkpISg6c4Ndbn55tvvqkV+6eeegojR45EfHw8Ll++DABo1KgRQkJCsG7dOq3X7OrVq2Fubo7Ro0dXut2GvmZ//vlneHt7Y/jw4Q89r7Gop78GSruFpaenIy0tDWFhYRAEAadOndI5Rt9nkD7qz4JffvkF2dnZ5ZaztLTU3JkrKSlBRkYG0tLSEBoaCgB6vyNUVGXeI1X12WgIdoUyMQ0bNtS8YCsiISEBPj4+mu4LapaWlqhXrx4yMjIkt6lRo0Y62y5fvoz8/PyHdlVJS0uDr6+vwfW98847+Prrr7F7926MHDnS4OPLCgoKwuLFiwGUdimrW7euzgDBuLg4zJs3D4cPH9a5nftgNwOgtAvJgzPuqP8RWLBggabP+oOSk5Mf2d6EhAQAQNOmTXX2BQQEQKlU6u33aqgHY5qamop79+7prdfFxQXe3t5663366ad1trm6uuLu3bsVasecOXPw+++/Izw8HK6urggJCUGfPn0wdOjQKhsQ/2Cb1V2byrb52rVr8Pb21kkirKys8PTTT2u9r9Sxf/nll8ut88HYu7u7V3jms5SUFNy7dw/+/v46+5o0aVKhcxgqNTUVd+/e1STm+pT9gqtW3mdFTk7OQ7sEJCcnax1b3usKKI3Tg593FTV58mT89NNPmDRpEmbOnInOnTujd+/eGD58uMHd7vS9hyrznJV3rTdu3NA8Vn8u6OvW0bhxY/zyyy8GtV1fvfreB0uWLEF4eDi6dOkCHx8fdOvWDX379sXgwYMfORg2ISEBSqVSb9e0pk2bYseOHUhISKhUd8fLly9DFEW97wm1B99zj3pNGcJYn5/6nhv1e/ratWua/RMmTMDIkSOxa9cuhIeHIycnBz/++CP69ev3yK42j2q3Ia/ZhIQEtG/fXud17OHhYbSZHMsqKSnB+++/j3Xr1uHq1as6Pz7p+36j7zNIn5CQEIwaNQpr1qzBhg0b0K5dO4SGhmLo0KE6n6tffPEFoqOjcfHiRQiC8Mg2VFRl3iNV9dloCCYWVCEP+/VOX/YviiKaN2+OZcuWlXtcZfvH+/r6wszMTGesSWW4ubk9NFFLTExE165d4eDggHnz5qFx48awtbWFQqHAtGnT9P4CUN7zAQDTp08vt9+7s7NzJa/C+Iw1U8+DCZZaRe8+NGzYEJcuXUJMTAxiYmLw22+/Yfz48ViwYAF+//131K9fv9xjH7bI18Nez1LbXN5xH330kc4Ug2o+Pj5aj+U+U5L6mkJDQzFz5swKH1fee8Pd3R0bN24s97gH160pL0Zl21YZrq6uOHnyJP744w8cOHAAv//+O958800sWLAAe/bsQXBwcIXP9eC1VvY5M/br0Zj1BgcH499//8W+fftw6NAhHDp0CBs3bsTixYtx5MiRapvKWhRFKBQK/PLLL+Vex4Nf8o35mnrc799BgwZh6tSpWL16NcLDw7F582bk5ubilVdeMeg8xnrNVkRlP5/LioiI0Kxt8s4778DDwwMWFhY4c+YMZs6cqfMlHzAsNmvXrsWMGTPwyy+/4I8//sDSpUvx3nvvYfny5Zg8eTIAYNmyZZg+fTp69eqFqVOnwsfHB5aWlrh16xbGjBmjtw1Vqao+Gw3BxKKG8/Pzw8GDB3Hv3j2tTLW4uBgJCQk6vyK4uLjoXRzO0F/BGzZsiNTUVPTo0UPvr3BSXLt2DSqVyqBfYipr+/btuHfvHnbu3KnVzQQozf4fHEBeHvUAbzMzswrdcSrvQ1d9h+TixYs6++Lj4yEIgt5fLKRyd3eHvb293nozMjJw586dcr80P8qjVvm1srLCs88+q+lCsGfPHvTt2xfLli3DypUryz1O/aXGGK/nBz399NPYv38/srOzte5aFBYW4tq1a1pJojr2tra2Bt1trCh3d3fY2dkhPj5eZ9+lS5eMXp+6TicnJ2RnZ0u+poYNG+Lvv/9Ghw4djP5rWmVWkDYzM0O3bt00s7WcP38ebdq0weLFizWzwVTmvMZ8zh6k7vpw5coVnff/lStXdMobc2VtOzs7DBo0CIMGDQJQ+uvt66+/jtWrV2PGjBnlHvf0009DEARcvnxZZyC2+nWr745wRTRs2BB79+5FnTp1jL6YY2VjX5nPz8uXL6Nly5Za29TPTdk4W1lZYdSoUfjss89w+/ZtrF69GrVr167Q5A2Parchr1k/Pz9cvXoVgiBo/bufkpKCzMxMrbLG+Hxev349unbtqtMtWj2jozE0a9YMzZo1w4wZM5CZmYmgoCDMmjULr7/+OhQKBdavXw8/Pz/88ssvWte8d+9enXMZ+tqpyvdIVeIYixquf//+UKlU+PTTT7W2f/3113qnZW3UqBFiY2O1pknLyMjQTE9aUaNGjUJSUlK5dywq0u1H3+1nQRA0Y0r69++vtc9Y082Wpc7+H8z0v/76a838/hXRqlUrNGvWDNHR0Xo/NEtKSrQ+YNVfsB780PXw8EDHjh3x888/48KFC5rtoigiKioKQOmCgsamVCo1UwU/+IH5/vvvQxCEStdb3rUC0HtXSr1I26NWR7e3t4eXlxd+/fVXrfhdu3YNO3bsqFRb1QYMGACVSoWlS5dqbf/yyy91+uOGhYXBw8MD77//vt425+fnIycnp9JtMTMzQ79+/XDq1Cmt6VlFUcSHH35Y6fM+jFKpxMiRI3HixAm903cCeORMVmqjRo2CIAiYPXu23v0V+awoj52dHTIyMir8S52+15u/vz9q1apVoffnwxjzOXuQ+rPw008/1fqF9K+//tLMSlNWZdqvj5T3Z3h4OAAgKipKKz4XLlzAzp070blz50rf1X7ppZcAlHal1DcFt9TXFGB47Cvz+fnJJ5+gqKhI8/jmzZvYuHEjGjdurJMwjR8/HiqVCjNnzsSxY8cwZsyYh/56XdF2G/Ka7d+/P+7cuYMffvhBq8zHH3+sc5wxPp/NzMx03tu5ublas8BVVnp6us7dBicnJ9SrVw95eXkoKCjQtEGhUGi1Q91F60GGvnaq8j1SlXjHooZ75ZVX8NVXX2Hu3Lm4evWqZrrZH3/8EQ0aNNC55Th58mS8+OKL6NGjB1566SVkZmbi66+/Rt26dQ36Iv3GG2/gwIEDmDFjBn799Vf06NEDDg4OSExMRExMDKytrXXmqH/Q+PHjkZ2djY4dO8LX1xdpaWn4v//7P5w+fRoDBgzA4MGDtcoba7rZsvr06QMbGxvNlLnOzs74888/sWfPHtSvX7/Ct2zVv2z06NEDLVq00Ezbl5eXh6tXr2Lbtm2IiorStLtDhw5YsWIFJk2ahL59+8LCwgJBQUGoV68ePv30U4SEhKBLly6a6WZ37dqFffv2YcSIEXjmmWeMcu0PWrJkCQ4cOIDw8HBMmjQJDRo0wO+//47Nmzeja9euBg0SLMvV1RUNGjTApk2bUL9+fc0c6v3790evXr3g5OSELl26wNfXF5mZmZr5yNVfHh5m8uTJmDt3Lvr06YPw8HDcvn0b0dHRaNasGU6ePFmp9gKl85CvWrUK7777LhISEhAcHIy4uDhs2bJF53Vha2uLdevWITw8HI0bN8bLL7+MBg0aIDMzE/Hx8ZrpXKWsOr548WL88ssv6NevH6ZMmYKnnnoKP//8c6UHv1bEe++9hz///BNDhgzBkCFD0KFDB1haWuLGjRvYs2cP2rRpo7M2gj7qKWZXrFiBM2fOoF+/fnBzc8PNmzcRGxuLq1evVvoOU4cOHbBr1y5MnjwZHTt2hJmZGXr06KGZmOFB48ePx82bN9GrVy/NSribN29GTk4ORo0apXXe8t6fD2Os5+xBTZs2xYQJE7Bq1SqEhoZi4MCBSE1NxcqVK9GqVSucPn1a69fSyrb/QQEBAejQoQOCgoLg4+ODO3fuYNWqVbC0tMSwYcMeemzPnj0xZMgQbNq0CRkZGejXr59mKk1ra2t89tlnBj8Pau3atUNkZCQiIyMRGBiIF154QdO+06dPY8+ePVpf2A1R2eeuMp+fJSUl6NKlC4YPH46cnBxER0cjPz9f73MTEBCAzp074/vvv4dCoXjomC5DGPKanTlzJjZu3IixY8fixIkT8Pf3xx9//IGjR4/Czc1N5xd7qZ/PgwcPxldffYWhQ4ciNDQUycnJ+Pbbb/VO+W2odevW4ZNPPsHAgQPRoEEDWFhY4LfffsO+ffswZMgQ1KpVS9OG2bNno0+fPnj++eeRnZ2NjRs36p1quUmTJrC3t8cXX3wBGxsbODk5wcPDQzMxwoOq8j1Spap83ikyCvU0ZR999NEjyzw4lVlKSoo4evRo0dnZWbSxsRG7d+8uxsXFiW3atNE7VeCHH34o1qlTR7S0tBT9/f3F1atXP3S62QenoVQrLi4WP/30U7Ft27aijY2NaGNjIzZo0EAcMWKEuG/fvkde8zfffCOGhISInp6eooWFhWhnZycGBQWJK1eu1JpS8cH2GDLdbN++fR9Z7rfffhM7deok2tnZiY6OjuKzzz4r/vXXX3qnsnzUlIjXr18XX331VbFu3bqihYWF6OLiIrZu3VqcNWuWmJiYqCmnUqnE6dOni7Vr19ZMF1n2us6ePSsOGDBAdHZ21sTpgw8+0JmCtzyPmm62PNeuXRNffPFF0d3dXbSwsBDr1asnzp49W2f6voe9NvQ9R8ePHxc7duwo2tjYaE3XumrVKjE0NFTzGvDy8hL79OmjNa3qwxQXF4szZswQvby8RCsrK7FVq1bizp07HzrdrD7QMz3r3bt3xZdffll0cXERbWxsxJCQEPHkyZPlTnH6119/iSNHjhR9fHxECwsL0cPDQwwODhbfffdd8e7du5pyhk6Rqnb+/HmxZ8+eoo2Njejs7CyOGDFCTE5OrrLpZkVRFHNzc8V3331XbNasmWhtbS3a2dmJ/v7+4iuvvKI1dae+z48HrVu3TuzcubNob28vWllZiXXr1hUHDhwobtq0SVPG0Ckqc3NzxZdffln08PDQvI8e1ob/+7//E/v37y/Wrl1btLS0FN3c3MSuXbuKW7du1Sr3sPfno95DxnjO9MWqpKREjIyMFH19fUVLS0uxefPm4ubNm8Xp06frTKFb2fY/+LqJiooSu3TpIrq7u4uWlpbiU089JQ4ePFg8ffp0uddfVnFxsfj++++L/v7+oqWlpejs7CwOGDBAPH/+vE5ZQ6abVdu1a5fYq1cvzefkU089Jfbu3Vv88ssvH3pdavpiICX2hn5+XrhwQZw8ebLo6ekpWllZie3atRP3799f7vnXrVsnAhB79Ojx6CenDGO9ZtXXOHDgQNHOzk60t7cXn3vuOfHatWuiq6ur2KdPH62yhnw+l/f+fuutt8Q6deqIVlZWYoMGDcSoqCjx4MGDOv9mPuoz6MH3VFxcnDhq1Cixfv36oo2NjWhvby+2aNFC/Pjjj8WCggJNuZKSEnHJkiVi/fr1RUtLS7FOnTrijBkzxEuXLun9rNq9e7fYqlUr0crKSmsK3vI+2yr6HjH0s7EqKUTxMY3mIFlRqVRwc3NDUFCQ3r6ARERk+vr3749ff/0V2dnZkrvG0OMRGRmJhQsXalasrqgff/wRQ4cOxcaNG3WmfK1Od+/ehZubG1599VVER0dXd3OoinGMxRNA33oR0dHRyMzMRM+ePauhRUREZEz6PufPnz+PX375BT169GBS8QRYuXIl3Nzc8Pzzz1dbG/S9DtXjDfh948nAMRZPgPHjx6OgoAAdO3aElZUVYmNjsXHjRjRo0AATJkyo7uYREZFEa9euxbp169C3b1+4u7sjPj5eM97h3Xffre7mURVJSUlBTEwM/vjjD/z++++Iioqq8GyFVeHZZ59F3bp10bp1awiCgJiYGOzatQsdO3bUDEammo2JxROgV69eWLlyJRYtWoR79+7B09MTr7zyChYtWlRli4wREdHj07p1a2zfvh2fffYZ0tPTYW9vjx49emDBggVo1apVdTePqsilS5cwYsQIODk5YeLEiZg+fXq1tqdfv35Yt24dtm/fjvz8fDz11FOYPn06FixYwLtmTwiOsSAiIiIiIsk4xoKIiIiIiCRjYkFERERERJJxjMVjIAgCbt++DXt7e4OXdCciIiIiqi6iKCInJwc+Pj5QKh9+T4KJxWNw+/Zt+Pr6VncziIiIiIgq5X//+x+eeuqph5ZhYvEYqGde+t///gcHB4fHWrcgCEhNTYW7u/sjs0x6/Bgf+WOM5I3xkTfGR/4YI3mTQ3yys7Ph6+tboZlEmVg8BuruTw4ODtWSWBQUFMDBwYEfGDLE+MgfYyRvjI+8MT7yxxjJm5ziU5Hu/HwFERERERGRZEwsiIiIiIhIMiYWREREREQkGRMLIiIiIiKSjIO3ZUalUqG4uNho5xMEAcXFxSgoKKj2QT+kq6bHx8LCAmZmZtXdDCIiInoMmFjIhCiKSEpKQmZmptHPKwgCcnJyuDifDD0J8XFycoKXl1eNvT4iIiIqxcRCJtRJhYeHB2xsbIz2JUwURZSUlMDc3Jxf7GSoJsdHFEXk5eUhJSUFAODt7V3NLSIiIqKqxMRCBlQqlSapcHV1Neq5a/IX15qgpsenVq1aAICUlBR4eHiwWxQREVENVvM6dZsg9ZgKGxubam4JkfGpX9fGHDtERERE8sPEQkZq4i/WRHxdExERPRmYWBARERERkWRMLMgoFArFI//WrFlT7W2ztLRE48aNMWfOHOTm5j62NuzYsQNffPGF0c53+PDhcp/ntLQ0o9XzMGvWrMHGjRsfS11EREQkfxy8TUYRGxur9Tg4OBhTpkzBiBEjNNvq16//uJuloW5LQUEBDh48iPfffx8JCQn44YcfHkv9O3bswKlTpzBp0iSjnve7776Dv7+/1jYnJyej1lGeNWvWwM7OTivGRERE9ORiYkFG0aFDB51tderU0btdLT8/XzNrUFUr25Zu3brhzp07+Pbbb/H555/Dzc2tUud8nO0vT7NmzdC2bdtHllOpVBAEARYWFo+hVURERPQkYlcoeiwiIyNhZ2eHEydOIDg4GNbW1li5cqWmS8+pU6e0yoeHh6Nbt25a2y5fvowBAwbA0dERtra26Nu3L/79999KtUf9ZTwhIQHx8fEYNmwYfH19YWNjgyZNmmDp0qUQBEFT/vr165ruXOPHj4erqyvat28PACgsLMScOXNQt25dWFlZISAgQKuL0JgxY7B27VpcvHhR011pzJgxmv3bt29Hq1atYG1tDR8fH0RERKCgoKBS16XWrVs39OvXD2vXrkXjxo1hZWWFc+fOAQC++uorzTY/Pz8sXrxY61rXrFkDhUKBuLg49OnTB7a2tmjYsCHWrVundf7ffvsNu3fv1lxTZGSkpDYTERGRaeMdC3psioqKMGLECLz55ptYsmQJXF1dkZ6eXqFjr127ho4dO6JZs2ZYs2YNlEol3nvvPTzzzDO4cuUKrKysDGpLQkICAMDHxwfx8fFo3LgxRo4cCXt7e5w9exYLFizAvXv3sGDBAq3jZs+ejb59++KHH37QfBkfMmQIjhw5ggULFiAgIAB79uzBiy++CGdnZ/Tp0wfz5s1Damoq4uPjsWHDBgCAu7s7AGDnzp0YNmwYhg0bhvfffx/x8fGYM2cOEhMTsXXr1kdeh0qlQklJieaxUqmEUln6e8GpU6dw/fp1vPvuu3B2doavry8+//xzTJ06FVOmTEG/fv1w9OhRREZGIjMzEx9//LHWuUeOHInx48cjIiICX3/9NcaMGYN27dohICAAX3zxBV588UXY2NhojnvqqacMigERERHVLEwsZCorKwt//fWX5POIogiVSgUzM7NKTfvZvHlzODo6Sm4HULqOwXvvvYehQ4dqth0+fLhCxy5cuBAuLi44cOAArK2tAQAdO3bE008/jdWrVz9y7IIgCCgpKdGMsfjyyy8RHByM2rVro3bt2njmmWcAlD5fnTt3Rl5eHlasWKGTWAQGBuKbb77RPD506BB27tyJffv2oVevXgCAnj174s6dO1iwYAH69OmD+vXrw93dHTdu3NDpGrZw4UIEBQVhw4YNUCgU6N27N2xsbPDqq6/ir7/+QvPmzR96XQ+eb9y4cZr2paen4+TJk/D19QVQmoS8++67GDZsGD777DMAQK9evVBUVISlS5di9uzZWgs0Tp48WfO8duzYEbt378b//d//Ye7cuWjSpAkcHBxgZ2f30O5uREREBslJBgqydLdbOwL2no+/PTKQlF2Evf/eQEpOITwdrBHaxBO1naq3K3Z5mFjI1F9//YUuXbpUdzPwxx9/oHPnzkY7X9++fSt13P79+zFs2DCYm5trfqF3dnZGq1atcPLkyUceP3PmTMycOVPzuGfPnli1ahUAoKCgAFFRUdiwYQMSExO1FnK7d+8e7Ozsym3//v374eLigh49emjdOejZsycmTpyoSer0uXfvHs6ePYsPPvhAa/vQoUPx6quv4siRI49MLNatW4eAgADNY/WdEABo0aKFJqkAgPj4eKSlpeGFF17QqS8qKgonTpxAnz59NNvViRIA2Nraom7durh58+ZD20NERFRpOcnA98+Xn1i8uO2JSy4OXk7Ge7v+Rm6xCPXPw+tir2N2nwCENpHfc8HEgh4bGxsbrS/phkhLS8Py5cuxfPlynX2WlpaPPP6NN97Aiy++qBlXYG9vr9k3c+ZMfP3111iwYAHatGkDJycn/PTTT1i8eDEKCgq02uzpqf0mTktLQ3p6ermDou/cuVNuF6HMzEyIoggPDw+t7Y6OjrCysqpQN7GAgIByB28/2NaMjAy929WPH6zvwdmlLC0tJY/9ICIiKldBVumf0gyAAjC3LP1vSdF/+56gxOJWZj7e33sFuUUC7K0tYGttDlEE0u4VIuqXy2ji4wAfmd25YGJBj42+rljqbk1FRUVa2zMyMrTKu7i4oG/fvnq7PJVNEsrz1FNPlfsFfMuWLXj11Ve17mjs3r27Qtfg4uICd3d37NmzR2/5B5OGspycnKBQKJCamqq1PSsrC4WFhXBxcSn32IrQ11YASElJ0dqenJystZ+IiKha5KYBquLSv4JMwMwKcPApTTSE4kceXtNsPX0TydkFUIgisgpKYGNpjjouteBmZ4WU7AIcuJSM0R39qruZWphYyFTz5s3xxx9/SD6PMcZYVCX1r/mXL19Gx44dAZTeBThz5gzatGmjKRcaGooLFy6gVatW5XYtqqz8/Hytux4qlQqbNm2q0LGhoaH48MMPYWlpiRYtWpRbTt+v/XZ2dggMDMS2bdswffp0zfYff/wRAIzaBQ0AGjduDHd3d2zZsgUDBw7Uqs/S0lIzy1VF8Q4GEREZTU4y8PNUIDcVEEUAYmmCkX4VgAKwqdzU8KbqwKVkfHckAXmFKgCACCC3sATZBSVwsC7tJZGcLb9/g5lYyJSjo6NRvliKooiSkhKYm5tXKrGoak899RSCgoKwcOFCODo6wtzcHB988IHOgPGFCxeiXbt2CAsLw4QJE+Dp6YmkpCT89ttv6NKlC4YPH17pNvTs2RNff/01mjRpAjc3N3zxxRcoLCys8LH9+/dH79698fbbb6NFixbIzc3FxYsXcfXqVc1A6oCAAHz77bf44Ycf0LBhQ7i5ucHPzw8LFizAwIED8dJLL+HFF1/ElStXMGfOHAwaNMjoSZ2ZmRnmzZuHqVOnwsPDA88++yyOHTuGDz74ANOmTdMauF0RAQEBWLt2LX7++Wd4e3vDx8cHPj4+Rm0zERE9IQqygMIcQKG4n1gAUCoBKABRKP17QtzKzMf7v1yGShShUACCCCjvf4VLyiqAlXnp7I+eDtbV2Er9uI4FVbsNGzagQYMGGDNmDN566y288cYbOt2WGjRogBMnTsDV1RWTJk1CWFgYZs2ahdzc3IfeKaiIzz//HCEhIZgyZQrGjRuH5s2bY86cORU+fuvWrZg4cSK++OIL9OnTB+PGjcP+/fsREhKiKTNu3Di88MILmDJlCtq1a6dZ8+G5557DDz/8gL/++gsDBgzA+++/jwkTJuD777+XdE3lmTJlCr788kvs2bMH/fr1w+rVqxEZGYkPP/zQ4HO9/fbb6NSpE0aNGoV27dppBsMTEREZLDcNEEpQ+tu8UPpfEf/dvXiCqLtAmZfNsRSAuZkCJSoBd7IK4FDLAj1lOHhbIYrikxWtapCdnQ1HR0dkZWXBwcFBZ39BQQESEhJQr149zZgDY5H7HYsn3ZMQn6p8fT8OgiAgJSUFHh4emjVCSD4YH3ljfORPFjHKSQbW9AXu/ovSpOIBCjPA5WlgzO4aP3j7wKVkzNhyDln5xZq7FUBpYqFUKKASRTjWssDHg1s+tlmhHvU9tix2hSIiIiKi6qPuBqVUAoICgAqAsnTQtigAtZyB/p/V+KSibBcopVIBQfhvilmFQgEHa3MUlQh4uVM9WU41C7ArFBERERFVN6F0kDIU93+iL3sXX2kO2Nb8wdsHLyUjI68YtcyVEEVR0wHM8n6fqGJBhKejNQa30T+NvRyYXGKxcuVK+Pn5wdraGkFBQThx4sRDy2/ZsgX+/v6wtrZG8+bNdaYFjYyMhL+/P2xtbeHs7IzQ0FAcP35cq0x6ejpGjhwJBwcHODk5Ydy4cbh3757Rr42IiIjoiZObBuRnlI6xUA/SFlWlU8yKKsDCpnSBvBru2LW7yMorQnpesaYLlAigWCVCAGCmVGB2nwDZrV1RlkklFps3b0ZERAQWLFiAM2fOoGXLlggLC9OZl1/t6NGjGD58OMaNG4e4uDiEh4cjPDwcFy5c0JRp1KgRVqxYgb/++gtHjhyBn58fevXqpbW2wMiRI3Hx4kUcOHAAu3btwu+//44JEyZU+fUSERER1XhF9wCIgKJMD32lxf3HSqDbrCeiG9TJ6+n371KUzgalwP3ZoESgloVS1l2g1EwqsVi2bBnGjx+PsWPHokmTJoiOjoaNjQ2+/fZbveU//fRT9O7dGzNmzEBAQAAWLVqE1q1bY8WKFZoyI0aMQGhoKJ5++mk0bdoUy5YtQ3Z2Ns6fPw+gdH2FvXv34ptvvkFQUBA6d+6Mzz//HJs2bcLt27cfy3UTERER1Vh5GShNLNSdfxTQfLNWKErHWNRwW0/fRF5RCQCgRCidDcrCTAkLMwVEBWBjYS7rLlBqJjN4u6ioCKdPn8bs2bM125RKJUJDQxEbG6v3mNjYWERERGhtCwsLw44dO8qtY9WqVXB0dETLli0153ByctKa/jQ0NBRKpRLHjx/XWmhMrbCwUGsdhOzsbAClMy8Igu5sB4IglPalu/9nbOpzcgIwearp8VG/rst7/cud+v1pim1/EjA+8sb4yF+1x+heMhS/Rd1fq0K9UQRURaX/qzSHaGkL1ODX0MHLpYvh5RdpX2OJIECpUEABoK2fM7wcrKolTobUaTKJRVpaGlQqFTw9tW8BeXp6Ij4+Xu8xSUlJessnJSVpbdu1axeGDRuGvLw8eHt748CBA3Bzc9Ocw8PDQ6u8ubk5XFxcdM6jFhUVhYULF+psT01N1btScXFxMQRBQElJCUpKSvSes7LUK28DqLHTmZqyJyE+JSUlEAQBd+/ehYWFRXU3x2CCICArKwuiKHK6TBlifOSN8ZG/6o6RWcY1OBfegxmUuL9wBQAzQKEERAGClSPS8wFVOd3eTV1SdhHe2/U3ilUqKBWASj12/f5+R2szlKgENHGzKLfrf1XLycmpcFmTSSyqUvfu3XH27FmkpaXh66+/xpAhQ3D8+HGdhKKiZs+erXWnJDs7G76+vnB3dy93HYucnByYm5vD3LxqQmKKX+ieJDU5Pubm5lAqlXB1dTXZdSwUCgXc3d35xUiGGB95Y3zkr9pjVPAvFLi/tHTZRRsAQAEozSzg6uoKuFXuO5nc7f33BnKKBNhYmiOvqPQujQKAlYUSxSUCilSAi40Fwts9DQ8X22ppoyH/dptMYuHm5gYzMzMkJydrbU9OToaXl5feY7y8vCpU3tbWFg0aNECDBg3QoUMHNGzYEKtXr8bs2bPh5eWlkyGWlJQgPT293HqtrKxgZWWls12pVOp90yqVSigUCs2fMYmiqDlnTf1F3JQ9CfFRv67Le/2bAlNvf03H+Mgb4yN/1RajnGTg5zeAvHRoLYwn3O+9oTADrOyhqOVcusZFDXQ8IR1ZecWAQqHpCSYCKCoRSu/dKBWY3OUpPOViW23vIUPqNZkoWVpaok2bNoiJidFsEwQBMTExCA4O1ntMcHCwVnkAOHDgQLnly55XPUYiODgYmZmZOH36tGb/r7/+CkEQEBQUVNnLISIiInqylV0YTzMjlOL+jFBmNX5hPK2ZoMTSmaCUiv9u2NSyMMPYTn7o8rRTNbbSMCZzxwIAIiIiMHr0aLRt2xbt27fH8uXLkZubi7FjxwIARo0ahdq1ayMqKgoA8MYbbyAkJARLly5F3759sWnTJpw6dQqrVq0CAOTm5uK9997Dc889B29vb6SlpWHlypW4desWXnjhBQBAQEAAevfujfHjxyM6OhrFxcWYPHkyhg0bBh8fn+p5IoiIiIhqAvXCeCgzI5RaDV8Yr+xMUOqxFZbmpWNNilUibCzNMKh1baCo4mMcqptJJRZDhw5Famoq5s+fj6SkJAQGBmLv3r2aAdqJiYlat2s6duyIjRs3Yu7cuZgzZw4aNmyIHTt2oFmzZgAAMzMzxMfHY+3atUhLS4OrqyvatWuHP/74A02bNtWcZ8OGDZg8eTKeeeYZKJVKDBo0CJ999tnjvXgiIiKimkS9MJ5YduIa4b8ZoGrwwngHLv03E1TZOSGLVf/NBNXezwU+TrWQksLEospMnjwZkydP1rvv8OHDOtteeOEFzd2HB1lbW2Pbtm2PrNPFxQUbN240qJ30eFy9ehUff/wxjh07hgsXLsDf319rAcSy4uPjMWXKFBw9ehT29vYYNWoUFi9eDEtLy8fcaiIiItJaGE9Ulf6/0qJ0EQdRqLEL493KzMf7v1yGShS1ZoJS3l+6w9nWAiqViKCnXau1nZVhcokFUVkXL17E7t27ERQU9NB1EjIyMtCjRw80bNgQ27Ztw61btxAREYG8vDytBROJiIjoMVOgNJkAShfEU2+soQvjHbyUjOz8Yng5WOHf1NK7NeZKwNysdCaogmIBHvZW6CnzVbb1YWJRA93KzMfBS8lIzi6Ah70VujdyRR03++puVpXo378/BgwYAAAYM2YMTp06pbdcdHQ0srOzsX37dri4uAAond1r0qRJmDNnDsfLEBERPW4lRaWJhFhmfIV4/wdChRKwtKu2plWlq6n3kF8soDC7sHSBcREQoUCJSoSA0pmgZvcJgI9TLZNbXNJkZoWiijlwKRmjVh/Hil//wZZT/8OKQ1cxdu1pHLyc/OiDjWT8+PEIDg7Gvn370K5dO9jY2KBly5Y4efKk0euq6BRov/zyC0JDQzVJBQAMGTIEgiBg//79Rm8XERERPUROMhATWTqeQnzwy7MCqOVSIwduH7iUjF/+uoPcohJkF5RAJZR2gXKwNoeTjQVsLczwcqd6CDXBuxUA71jIliiKKCwxLEu9nZmPJXsuI7ewBG72VlAqFBBEEXfvFWLJnng08LCDt2OtCp/PylxZqbUVzp07hzt37uDjjz/GnDlzoFQqMWXKFEybNg1//vmnVtmyK08/jNSFA+Pj4/Hyyy9rbXNycoK3t3e5K7cTERFRFSk71awgoHSshRlK+0UJZbpE1RzqsRUAoFQooBJFKFC6jkhuYQkcalnA09Eag9s8Vb0NlYCJhUwVlgh4ITrWoGPS7hUiNacQ5koFcgrUMyyIEEQR6bnFGPH1cbjZ6S7cV54tE4NhbWFmUBsEQcDFixfRqlUr7N27F2ZmpcefOHEC3333nU75tWvXaqYLfpiEhAT4+fkZ1JayMjIy4OTkpLPd2dkZ6enplT4vERERVUJuWulUswolgPs/MCrud4kSAVjZ17gZobaevonk7AJYmimhQGlSoV6zovj+quPqLlCmiolFDVJyf1qBB+8ylD4WNfur0tWrV5GXl4fp06drkgoAKCoqgpub7i3N/v37V6iLFMdAEBER1RA5ycDPU4H89PvLV6gTCkXpnYoauDCeenrZ3EIV8hQqCGLpvRmHWhYwVyqRX1SCPs28TbYLlBoTC5myMldiy8SHrxD+oO+P3UD0b//C/X43KDWVICDtXhFeDXkaL3aoa1AbDHX+/HkAQLdu3bS2X7x4UWttEDUXFxc4Oj76FwmpXaGcnZ2RlZWlsz0jI0Nr3AURERFVMXU3KMX9+VWF+wmFUlk63kJpVqPGV2hNL6tUQCXcv1uhVCCvSIW6rlYARDTwMP3B6hy8LVMKhQLWFmYG/fVp7g3HWhZIzy0CgPvJhYiM3GI41rLAs829DTpfZcdX+Pr6wtlZe4q4uLg4BAYG6pRfu3YtLCwsHvl3/fr1SjyL//H399cZS5GVlYU7d+7A399f0rmJiIjIQOoVt8X74ytwf3Yosep7VzxuBy8lIyOvGNbmSgiCuncJYGGmQIlKwJ2sAjjUsjDJ6WUfxDsWNUhtp1qY3ScAUb9cRkp2gWa7vbU5ZvXxfyx99s6fP4+WLVtqbUtKSkJSUhJatWqlU/5xdYXq06cPlixZgszMTM1Yiy1btkCpVKJXr16Szk1EREQG0Lfitqi6v0geatyK28eu3UVWXhEAhWaVbUEEikpKV90uO72sqWNiUcOENvFEEx8HHCizjkWPxm7wdX08t9fOnz+PESNGaG07e/YsAOi9Y+Hq6gpX18qvLJmXl4c9e/YAAG7cuIHs7Gxs3boVABASEgJ3d3cAwMSJE/H5558jPDwcc+bMwa1btzBjxgxMnDiR4zeIiIgeJ60Vt+8nFzV0xe1bmfk4eT0dIkpnwlQoADNFaWIBALVMfHrZBzGxqIF8nGphdEc/AKUv4pKSkocfYCTZ2dm4ceMGWrRoobU9Li4OPj4+8PDwMHqdKSkpeOGFF7S2qR8fOnRIM9bD2dkZMTExmDJlCsLDw2Fvb49XXnkF7733ntHbRERERA+RlwHNgG0AwP0xFuouUTVoxe2tp28ir6ikNGdC6cRX5uZKiKKIYpUIG0szk55e9kFMLMhoHBwc9K4QWd74CmPw8/ODWMH+mAEBATh48GCVtIOIiIgqICcZ+C3qgUXxREBVOj4USvMas+K2eiao/CJBk0aJKO0CpVQooADQ3s+lRnSBUuPgbapycXFxesdXEBER0ROmIAsozi+zGB5Ku0QpLUq3WTvXiBmhys4Epb5MJUrXrVAAcLa1gGMtCwQ9Xfnu4HLExIKqVE5ODv79998qu2NBREREJiQ3DRBK7n/Zvv87vnoWSgVKp5qtAdSL4SlEUTPRlbmZApb3p/IvKBbgbGtZI2aCKotdoahK2dvb6+0eRURERE8Y9cJ4eekAynw3EIpL/6swqxErbpddDK+sEkGEUiy98po0E1RZTCyIiIiIqOqpF8ZTKu/nFQIAZeldClGoEStul+0Cpbg/+5O6+5NCoYCDtTmKSoQaNRNUWewKRURERERVT90NqizF/a+iCpQO3Dbx8RXqLlBK/DelrJlCAQv1TFCCCE9H6xo1E1RZvGNBRERERFWrvG5QYsn9eVhNvxtU2S5Q6mEjAKASRYiqmt0FSo2JBRERERFVrbLdoETl/YXxFKV3KWpAN6iyXaCUCkVpVyiUjktXPgFdoNTYFYqIiIiIqp6gHsx8/45F2Z/1TbwblLoLlLkCEO5PA6VUAJZPSBcoNd6xICIiIqKqlZsG5Gfcv1Nxnyj8t1CehY3JdoN6sAtU2cXwSlTiE9EFSo2JBRERERFVraJ7AMTSxfDUyYXSAhDF0uSi2yyT7Aal1QVKqYAg3O8ChSerC5Qau0IRERERUdXKy0Dpb/jqgduK0q5Q6oEItZyrr20S/NcFqjSpEFF6OZYWT1YXKDXesSAiIiKiqpOTDPwW9V+3JwCACKiKSv9XaQ5Y2lVL06Qo2wVKqVBpukABT14XKDUmFkRERERUdQqygOL80illxfsDuJUWpf8VBcDa2eQGbpfXBQp4MrtAqbErFJm0q1evYuLEiQgMDIS5uTmaNWtWbtn4+Hj07NkTtra28PLywttvv42ioiKDyxAREVElaGaBUpROO6vuCqU0q85WVcrBS8nIzi+Gl4M1gP8GbFs9oV2g1HjHgkzaxYsXsXv3bgQFBUEQBAiCoLdcRkYGevTogYYNG2Lbtm24desWIiIikJeXhxUrVlS4DBEREVWCeoVt9dgKUbg/cBsmuTBecnYBgNLpZC3NFCgQSu9cPKldoNSYWJBJ69+/PwYMGAAAGDNmDE6dOqW3XHR0NLKzs7F9+3a4uLgAAEpKSjBp0iTMmTMHPj4+FSpDRERElaA0A5z97v+0LwBmlkBJYWkXKRNcGM/z/p0KAKjvboes/GKUCCKKVQIKilRPXBcoNXaFqmlykoHUv//7S/sbSPsHuJf82Jowfvx4BAcHY9++fWjXrh1sbGzQsmVLnDx50uh1KZUVewn/8ssvCA0N1SQMADBkyBAIgoD9+/dXuAwRERFVkiAAZualSQUAQAGYWZjc+AoACG3iCYdaFki7VwgRgJONJVxsLWGmVDyRXaDUeMeiJslJBr5/vnSQVBnmogjUcgJe3PZYfhE4d+4c7ty5g48//hhz5syBUqnElClTMG3aNPz5559aZUVRhEqlKudM/zE3l/ZSjY+Px8svv6y1zcnJCd7e3oiPj69wGSIiIjKQtWPpX0EWUFSsf5+Jqe1UC7P7BCDql8tIud8tCgAcalk8kV2g1JhYyJUoAiUFjy5X1r1kID/zgV8DALGkCIr8jNL91g4VP5+5dZmBVhUjCAIuXryIVq1aYe/evTAzKx2QdeLECXz33Xc65deuXYuxY8c+8rwJCQnw8/MzqC1lZWRkwMnJSWe7s7Mz0tPTK1yGiIiIDGTvWfrj5gM/fAIoTSpMrBuUWmgTTzTxccCBS8lIzi6Ap4M1ejbxfGKTCoCJhXyVFADf9jb8mNzk0gFSiv+6CCkEAYAAbBtfmixU1Mt7AQvD3hxXr15FXl4epk+frkkqAKCoqAhubrq3Ovv371+hLlIc30BERGTC7D1NNoF4GB+nWhjd0a+6myEbTCzIqM6fPw8A6Natm9b2ixcvomnTpjrlXVxc4Oj46FugUrtCOTs7IytL95eSjIwMzZiKipQhIiIiIv2YWMiVuXXpHQNDpP0D/DAcsLTRujMhFudDUZwHPP814NbQsDYY6Ny5c/D19YWzs7PW9ri4OEybNk2n/OPqCuXv768zTiIrKwt37tyBv79/hcsQERERkX5MLORKoTC4G5JmTISqBFCUWdRNKCntGmVubfg5DXT+/Hm0bNlSa1tSUhKSkpLQqlUrnfKPqytUnz59sGTJEmRmZmrGUWzZsgVKpRK9evWqcBkiIiIi0o+JRU1SzqwLCvWsUI9h1oXz589jxIgRWtvOnj0LAAgMDNQp7+rqCldX10rXl5eXhz179gAAbty4gezsbGzduhUAEBISAnd3dwDAxIkT8fnnnyM8PBxz5szBrVu3MGPGDEycOFGTtFSkDBERERHpx8SiJtE764KIkhIVzO1cqnzQVHZ2Nm7cuIEWLVpobY+Li4OPjw88PDyMXmdKSgpeeOEFrW3qx4cOHdKM9XB2dkZMTAymTJmC8PBw2Nvb45VXXsF7772nOa4iZYiIiIhIPyYWNc2Dsy6IIlBSAkgc/FwRDg4OEARBZ3tcXJzeuxXG4OfnB1EUK1Q2ICAABw8elFyGiIiIiHRx5W2qcnFxcXrHVxARERFRzcHEgqpUTk4O/v333yq7Y0FERERE8sCuUFSl7O3t9XaPIiIiIqKahXcsiIiIiIhIMiYWREREREQkGRMLIiIiIiKSjIkFERERERFJxsSCiIiIiIgkY2JBRERERESSMbEgIiIiIiLJmFgQEREREZFkTCyIiIiIiEgyJhZERERERCQZEwsiIiIiIpKMiQWZtKtXr2LixIkIDAyEubk5mjVrVm7Z+Ph49OzZE7a2tvDy8sLbb7+NoqIig8sYUo6IiIjoSWFe3Q0gkuLixYvYvXs3goKCIAgCBEHQWy4jIwM9evRAw4YNsW3bNty6dQsRERHIy8vDihUrKlzGkHJERERETxImFmTS+vfvjwEDBgAAxowZg1OnTuktFx0djezsbGzfvh0uLi4AgJKSEkyaNAlz5syBj49PhcpU9FxERERETxp2hSKjGz9+PIKDg7Fv3z60a9cONjY2aNmyJU6ePGn0upTKir2Ef/nlF4SGhmoSAQAYMmQIBEHA/v37K1zGkHJERERETxImFjXUX6l/4Y1f38BfaX899rrPnTuHmzdv4uOPP8acOXPwww8/ICMjA9OmTdMpK4oiSkpKHvknVXx8PPz9/bW2OTk5wdvbG/Hx8RUuY0g5IiIioicJu0LJlCiKKFQVVvrYdZfW4UTSCViaWWJeu3mwgAUUCoVB57EyszL4GEEQcPHiRbRq1Qp79+6FmZkZAODEiRP47rvvdMqvXbsWY8eOfeR5ExIS4OfnZ1BbysrIyICTk5POdmdnZ6Snp1e4jCHliIiIiJ4kTCxkqlBViNF7R1fq2LziPNy6dwuiKCImMQaX716GraWtwedZ23strM2tDTrm6tWryMvLw/Tp0zVJBQAUFRXBzc1Np3z//v0r1EWK4xaIiIiI5I2JRQ0jiiLSC9IhiALMFeYoEUuQUZgBGwsbg+8+VMb58+cBAN26ddPafvHiRTRt2lSnvIuLCxwdHR95XnNzaS9VZ2dnZGVl6WzPyMjQjJWoSBlDyhERERE9SZhYyJSVmRXW9l5r8HHnU8/jnSPvwNnKGTYWNsgrzkOBqgDT20xHS4+WBrfBUOfOnYOvry+cnZ21tsfFxekdY/G4ukL5+/vrjH/IysrCnTt3NOMlKlLGkHJERERETxImFjKlUCgM7oYkiiK2Xd2GAlUBbC1sUSQUwVxpjoLCAmy/uh3tvdtX+V2L8+fPo2VL7QQmKSkJSUlJaNWqlU75x9UVqk+fPliyZAkyMzM14yO2bNkCpVKJXr16VbiMIeWIiIiIniQmNyvUypUr4efnB2trawQFBeHEiRMPLb9lyxb4+/vD2toazZs3x549ezT7iouLMXPmTDRv3hy2trbw8fHBqFGjcPv2ba1z+Pn5QaFQaP29//77VXJ9UpQIJUjKTUIt81rIK8lDXnEe8kryUMu8FpLzklEiSJ9d6VHOnz+PFi1aaG07e/YsACAwMFCnvKurK9q2bfvIP0tLS7315eXlYevWrdi6dStu3LiB7OxszePU1FRNuYkTJ8Le3h7h4eHYv38/vvvuO8yYMQMTJ07UJC0VKWNIOSIiIqInimhCNm3aJFpaWorffvutePHiRXH8+PGik5OTmJycrLf8n3/+KZqZmYkffviheOnSJXHu3LmihYWF+Ndff4miKIqZmZliaGiouHnzZjE+Pl6MjY0V27dvL7Zp00brPHXr1hXfffdd8c6dO5q/e/fuVbjdWVlZIgAxKytL7/78/Hzx0qVLYn5+foXPWZ7UvFTx38x///vL+Fe8knpFTM1NlXzuR8nKyhIVCoW4adMmre1LliwRfXx8qqTOhIQEEYDev0OHDmmVvXTpkvjMM8+ItWrVEj08PMS33npLLCwsNLiMIeUeRRAEsaioSBQEweBjTYUxX9/VQaVSiXfu3BFVKlV1N4X0YHzkjfGRP8ZI3uQQn0d9jy1LIYqiWH1pjWGCgoLQrl07rFixAkDp1Ka+vr6YMmUKZs2apVN+6NChyM3Nxa5duzTbOnTogMDAQERHR+ut4+TJk2jfvj1u3LiBOnXqACi9YzFt2jS9YwQqIjs7G46OjsjKyoKDg4PO/oKCAiQkJKBevXqwtjas+9OjiPfXiTA3N38sg7f1GTJkCHJzc7F79+5qqV/O5BCfqlaVr+/HQRAEpKSkwMPDo8ILMtLjw/jIG+Mjf4yRvMkhPo/6HluWybyCioqKcPr0aYSGhmq2KZVKhIaGIjY2Vu8xsbGxWuUBICwsrNzyQOkgXIVCobNOwfvvvw9XV1e0atUKH330kVEWbXtSxMXF6R1fQUREREQ1h8kM3k5LS4NKpYKnp6fWdk9Pz3JXO05KStJbPikpSW/5goICzJw5E8OHD9fKyKZOnYrWrVvDxcUFR48exezZs3Hnzh0sW7ZM73kKCwtRWPjf4nbZ2dkASrNOQRB0yguCAFEUNX/Gpj5nddycysnJwb///ouWLVtWS/2moDrj8zioX9flvf7lTv3+NMW2PwkYH3ljfOSPMZI3OcTHkLpNJrGoasXFxRgyZAhEUcSXX36ptS8iIkLz/y1atIClpSVeffVVREVFwcpKd0rWqKgoLFy4UGd7amoqCgoK9NYtCAJKSkqMfidEFEWoVCoAqJauNrVq1dIkWbzLo6u64/M4lJSUQBAE3L17FxYWFtXdHIMJgoCsrCyIoshuAjLE+Mgb4yN/jJG8ySE+OTk5FS5rMomFm5sbzMzMkJycrLU9OTkZXl5eeo/x8vKqUHl1UnHjxg38+uuvj+w/FhQUhJKSEly/fh2NGzfW2T979mytZCQ7Oxu+vr5wd3cvd4xFTk4OzM3NJS8EVx5T/EL3JKnJ8TE3N4dSqYSrq6vJjrFQKBRwd3fnP7oyxPjIG+Mjf4yRvMkhPob8220yiYWlpSXatGmDmJgYhIeHAyh9smNiYjB58mS9xwQHByMmJkZr0PWBAwcQHByseaxOKv755x8cOnQIrq6uj2zL2bNnoVQq4eHhoXe/lZWV3jsZSqVS74tCqVRqTWVrTKIoas5ZU38RN2VPQnzUr+vyXv+mwNTbX9MxPvLG+MgfYyRv1R0fQ+o1mcQCKO2SNHr0aLRt2xbt27fH8uXLkZubq1m5edSoUahduzaioqIAAG+88QZCQkKwdOlS9O3bF5s2bcKpU6ewatUqAKVJxeDBg3HmzBns2rULKpVKM/7CxcUFlpaWiI2NxfHjx9G9e3fY29sjNjYWb775Jl588UWd1aWJiIiIiJ5UJpVYDB06FKmpqZg/fz6SkpIQGBiIvXv3agZoJyYmamVVHTt2xMaNGzF37lzMmTMHDRs2xI4dO9CsWTMAwK1bt7Bz504Auou3HTp0CN26dYOVlRU2bdqEyMhIFBYWol69enjzzTe1ujoRERERET3pTGodC1NV0XUs/Pz8UKtWLaPW/SSsk2DKnoT45Ofn4/r161zHgqoE4yNvjI/8MUbyJof41Mh1LGoy9cDdvLy8am4JkfGpX9c1eYA6ERERmVhXqJrKzMwMTk5OSElJAQDY2NgY7dfrJ+EXcVNWk+MjiiLy8vKQkpICJycnmJmZVXeTiIiIqAoxsZAJ9RS46uTCWNSLqqhnniJ5eRLi4+TkVO6U0ERERFRzMLGQCYVCAW9vb3h4eKC4uNho51UvTObq6sq+kzJU0+NjYWHBOxVERERPCCYWMmNmZmbUL2KCIMDCwgLW1tY18ourqWN8iIiIqKbgNxkiIiIiIpKMiQUREREREUnGxIKIiIiIiCRjYkFERERERJIxsSAiIiIiIsmYWBARERERkWRMLIiIiIiISDImFkREREREJBkTCyIiIiIikoyJBRERERERScbEgoiIiIiIJGNiQUREREREkjGxICIiIiIiyZhYEBERERGRZEwsiIiIiIhIMiYWREREREQkGRMLIiIiIiKSjIkFERERERFJxsSCiIiIiIgkk5xYJCYmIj8/v9z9+fn5SExMlFoNERERERHJmOTEol69eti+fXu5+3fu3Il69epJrYaIiIiIiGRMcmIhiuJD9xcXF0OpZI8rIiIiIqKazLwyB2VnZyMzM1Pz+O7du3q7O2VmZmLTpk3w9vaudAOJiIiIiEj+KpVYfPLJJ3j33XcBAAqFAtOmTcO0adP0lhVFEYsXL650A4mIiIiISP4qlVj06tULdnZ2EEURb7/9NoYPH47WrVtrlVEoFLC1tUWbNm3Qtm1bozSWiIiIiIjkqVKJRXBwMIKDgwEAubm5eP7559G8eXOjNoyIiIiIiExHpRKLshYsWGCMdhARERERkQmTPF3T2bNn8cMPP2ht27dvH7p27YqgoCB8+umnUqsgIiIiIiKZk5xYvP3229i8ebPmcUJCAgYOHIiEhAQAQEREBFatWiW1GiIiIiIikjHJicW5c+fQuXNnzeN169bBzMwMcXFxOH78OAYPHozo6Gip1RARERERkYxJTiyysrLg6uqqebxnzx707NkTbm5uAICePXvi6tWrUqshIiIiIiIZk5xYeHt74/LlywCAO3fu4PTp0+jVq5dm/71797jyNhERERFRDSd5VqgBAwbg888/R0FBAY4fPw4rKysMHDhQs//cuXN4+umnpVZDREREREQyJjmxWLx4MVJTU7F+/Xo4OTlhzZo18PT0BABkZ2dj69ateP311yU3lIiIiIiI5EtyYmFnZ4cNGzaUu+/mzZuwsbGRWg0REREREcmY5MSirHv37uF///sfAMDX1xd2dnZwdHQ0ZhVERERERCRDRhlVffLkSXTv3h3Ozs5o1qwZmjVrBmdnZ/To0QOnTp0yRhVERERERCRjku9YHD9+HN26dYOlpSVeeeUVBAQEAAAuX76MH374AV27dsXhw4fRvn17yY0lIiIiIiJ5kpxYvPPOO6hduzaOHDkCLy8vrX2RkZHo1KkT3nnnHRw4cEBqVUREREREJFOSu0IdP34cr776qk5SAQCenp6YMGECjh07JrUaIiIiIiKSMcmJhVKpRElJSbn7VSoVF8gjIiIiIqrhJH/j79ixI1auXIkbN27o7EtMTMQXX3yBTp06Sa2GiIiIiIhkTPIYiyVLlqBr167w9/fHwIED0ahRIwDAlStX8NNPP8Hc3BxRUVGSG0pERERERPIlObFo1aoVjh8/jnfeeQc7d+5EXl4eAMDGxga9e/fG4sWL0aRJE8kNJSIiIiIi+TLKAnlNmjTB9u3bIQgCUlNTAQDu7u4cW0FERERE9ISQnFiUlJQgLy8PDg4OUCqV8PT01NqfnZ0NGxsbmJsbdZFvIiIiIiKSEcm3FKZOnYqOHTuWu79Tp06YPn261GqIiIiIiEjGJCcWe/fuxeDBg8vdP3jwYOzZs0dqNUREREREJGOSE4vbt2+jdu3a5e738fHBrVu3pFZDREREREQyJjmxcHV1xZUrV8rdf/nyZTg4OEithoiIiIiIZExyYtG7d2989dVXiIuL09l35swZrFq1Cn369JFaDRERERERyZjkqZoWLVqEvXv3on379njuuefQtGlTAMCFCxfw888/w8PDA4sWLZLcUCIiIiIiki/JiYWPjw9OnTqFWbNm4aeffsL27dsBAA4ODhg5ciSWLFkCHx8fyQ0lIiIiIiL5MsriEt7e3li7di1EUdRaIE+hUBjj9EREREREJHNGXbVOoVDAw8PDmKckIiIiIiITIHnwNhERERERERMLIiIiIiKSjIkFERERERFJZnKJxcqVK+Hn5wdra2sEBQXhxIkTDy2/ZcsW+Pv7w9raGs2bN8eePXs0+4qLizFz5kw0b94ctra28PHxwahRo3D79m2tc6Snp2PkyJFwcHCAk5MTxo0bh3v37lXJ9RERERERmSLJiUViYiLy8/PL3Z+fn4/ExESp1QAANm/ejIiICCxYsABnzpxBy5YtERYWhpSUFL3ljx49iuHDh2PcuHGIi4tDeHg4wsPDceHCBQBAXl4ezpw5g3nz5uHMmTPYtm0brly5gueee07rPCNHjsTFixdx4MAB7Nq1C7///jsmTJhglGsiIiIiIqoJFKIoilJOYGZmhvXr12PEiBF692/evBkjRoyASqWSUg0AICgoCO3atcOKFSsAAIIgwNfXF1OmTMGsWbN0yg8dOhS5ubnYtWuXZluHDh0QGBiI6OhovXWcPHkS7du3x40bN1CnTh1cvnwZTZo0wcmTJ9G2bVsAwN69e/Hss8/i5s2bFVqjIzs7G46OjsjKyoKDg0NlLr3SBEFASkoKPDw8oFSa3A2qGo/xkT/GSN4YH3ljfOSPMZI3OcTHkO+xkqebfVReUlxcbJQnoqioCKdPn8bs2bM125RKJUJDQxEbG6v3mNjYWERERGhtCwsLw44dO8qtJysrCwqFAk5OTppzODk5aZIKAAgNDYVSqcTx48cxcOBAnXMUFhaisLBQ8zg7OxtA6YtDEIRHXqsxCYIAURQfe71UMYyP/DFG8sb4yBvjI3+MkbzJIT6G1F2pxCI7OxuZmZmax3fv3tXb3SkzMxObNm2Ct7d3ZarRkpaWBpVKBU9PT63tnp6eiI+P13tMUlKS3vJJSUl6yxcUFGDmzJkYPny4JiNLSkrSWZvD3NwcLi4u5Z4nKioKCxcu1NmempqKgoIC/RdYRQRBQFZWFkRR5C8RMsT4yB9jJG+Mj7wxPvLHGMmbHOKTk5NT4bKVSiw++eQTvPvuuwBKF8WbNm0apk2bpresKIpYvHhxZap5rIqLizFkyBCIoogvv/xS0rlmz56tdackOzsbvr6+cHd3r5auUAqFAu7u7vzAkCHGR/4YI3ljfOSN8ZE/xkje5BAfa2vrCpetVGLRq1cv2NnZQRRFvP322xg+fDhat26tVUahUMDW1hZt2rTR6kZUWW5ubjAzM0NycrLW9uTkZHh5eek9xsvLq0Ll1UnFjRs38Ouvv2p9+ffy8tIZHF5SUoL09PRy67WysoKVlZXOdqVSWS0vCoVCUW1106MxPvLHGMkb4yNvjI/8MUbyVt3xMaTeSiUWwcHBCA4OBgDk5uZi0KBBaNasWWVOVWGWlpZo06YNYmJiEB4eDqA0i4uJicHkyZPLbWdMTIzW3ZQDBw5o2g78l1T8888/OHToEFxdXXXOkZmZidOnT6NNmzYAgF9//RWCICAoKMi4F0lEREREZKIkDd7Oy8vDzp074enpWeWJBQBERERg9OjRaNu2Ldq3b4/ly5cjNzcXY8eOBQCMGjUKtWvXRlRUFADgjTfeQEhICJYuXYq+ffti06ZNOHXqFFatWgWgNKkYPHgwzpw5g127dkGlUmnGTbi4uMDS0hIBAQHo3bs3xo8fj+joaBQXF2Py5MkYNmxYhWaEIiIiIiJ6EkhKLGxsbJCQkACFQmGs9jzU0KFDkZqaivnz5yMpKQmBgYHYu3evZoB2YmKi1u2ajh07YuPGjZg7dy7mzJmDhg0bYseOHZok6NatW9i5cycAIDAwUKuuQ4cOoVu3bgCADRs2YPLkyXjmmWegVCoxaNAgfPbZZ1V/wUREREREJkLyOhYjRoxAQUEBtm3bZqw21Thcx4LKw/jIH2Mkb4yPvDE+8scYyZsc4mPI91jJLZw3bx7+/vtvvPTSSzhy5Ahu3bqF9PR0nT8iIiIiIqq5JC+Q17RpUwDApUuXsHHjxnLLGWPlbSIiIiIikifJicX8+fMf2xgLIiIiIiKSJ8mJRWRkpBGaQUREREREpkxyYlHWvXv38L///Q8A4OvrCzs7O2OenoiIiIiIZMoow8tPnjyJ7t27w9nZGc2aNUOzZs3g7OyMHj164NSpU8aogoiIiIiIZEzyHYvjx4+jW7dusLS0xCuvvIKAgAAAwOXLl/HDDz+ga9euOHz4MNq3by+5sUREREREJE+SE4t33nkHtWvXxpEjR+Dl5aW1LzIyEp06dcI777yDAwcOSK2KiIiIiIhkSnJXqOPHj+PVV1/VSSoAwNPTExMmTMCxY8ekVkNERERERDImObFQKpUoKSkpd79KpeJKjkRERERENZzkb/wdO3bEypUrcePGDZ19iYmJ+OKLL9CpUyep1RARERERkYxJHmOxZMkSdO3aFf7+/hg4cCAaNWoEALhy5Qp++uknmJubIyoqSnJDiYiIiIhIviQnFq1atcLx48fxzjvvYOfOncjLywMA2NjYoHfv3li8eDGaNGkiuaFERERERCRfRlkgr0mTJti+fTsEQUBqaioAwN3dnWMriIiIiIieEEZbefvChQvYs2cPrl+/DgCoV68e+vTpg2bNmhmrCiIiIiIikinJiUVhYSFeffVVrF+/HqIoau5SCIKAWbNmYeTIkfjmm29gaWkpubFERERERCRPkvsqzZw5E+vWrcNrr72Gy5cvo6CgAIWFhbh8+TImTpyI77//Hm+//bYx2kpERERERDIl+Y7F999/j5deegkrVqzQ2t64cWOsXLkS2dnZ+P7777F8+XKpVRERERERkUxJvmNRXFyMDh06lLu/Y8eOD11Aj4iIiIiITJ/kxCIsLAz79u0rd//evXvRq1cvqdUQEREREZGMSe4KtWjRIgwZMgTPP/88Xn/9dTRo0AAA8M8//2hW5N68eTPS09O1jnNxcZFaNRERERERyYTkxCIgIAAA8Ndff+Gnn37S2ieKIgDoXSBPpVJJrZqIiIiIiGRCcmIxf/58KBQKY7SFiIiIiIhMlOTEIjIy0gjNICIiIiIiUyZ58DYREREREZHkxOLs2bP44YcftLbt27cPXbt2RVBQED799FOpVRARERERkcxJTizefvttbN68WfM4ISEBAwcOREJCAgAgIiICq1atkloNERERERHJmOTE4ty5c+jcubPm8bp162BmZoa4uDgcP34cgwcPRnR0tNRqiIiIiIhIxiQnFllZWXB1ddU83rNnD3r27Ak3NzcAQM+ePXH16lWp1RARERERkYxJTiy8vb1x+fJlAMCdO3dw+vRprZW27927B6WSY8SJiIiIiGoyydPNDhgwAJ9//jkKCgpw/PhxWFlZYeDAgZr9586dw9NPPy21GiIiIiIikjHJicXixYuRmpqK9evXw8nJCWvWrIGnpycAIDs7G1u3bsXrr78uuaFERERERCRfkhMLOzs7bNiwodx9N2/ehI2NjdRqiIiIiIhIxiQnFmXdu3cPGRkZEEVRa3tWVhbq1KljzKqIiIiIiEhGJCcWBQUFWLhwIVavXo27d++WW06lUkmtioiIiIiIZEpyYjFp0iSsXbsW4eHh6NKlC5ydnY3RLiIiIiIiMiGSE4tt27bhlVdewVdffWWM9hARERERkQmSvMCEQqFA69atjdEWIiIiIiIyUZITiwEDBuDgwYPGaAsREREREZkoyYnFvHnzcO3aNUyYMAGnT59Gamoq0tPTdf6IiIiIiKjmkjzGomHDhgCAuLg4rF69utxynBWKiIiIiKjmkpxYzJ8/HwqFwhhtISIiIiIiEyU5sYiMjDRCM4iIiIiIyJRJHmNBRERERETExIKIiIiIiCST3BUKAAoKCvB///d/OHPmDLKysiAIgtZ+hULx0IHdRERERERk2iQnFjdu3ED37t1x/fp1ODk5ISsrCy4uLsjMzIRKpYKbmxvs7OyM0VYiIiIiIpIpyV2hZsyYgaysLBw7dgx///03RFHE5s2bce/ePXzwwQeoVasW9u3bZ4y2EhERERGRTElOLH799VdMmjQJ7du3h1JZejpRFGFlZYUZM2bgmWeewbRp06RWQ0REREREMiY5scjLy4Ofnx8AwMHBAQqFAllZWZr9wcHBOHLkiNRqiIiIiIhIxiQnFnXq1MHNmzcBAObm5qhduzaOHTum2X/p0iVYW1tLrYaIiIiIiGRM8uDtHj164KeffsKCBQsAAGPGjEFUVBQyMjIgCALWr1+PUaNGSW4oERERERHJl+TEYtasWTh58iQKCwthZWWFOXPm4Pbt29i6dSvMzMwwYsQILFu2zBhtJSIiIiIimZLUFSovLw8DBw5EamoqrKysAADW1tb45ptvkJGRgbS0NKxZswYODg5GaSwREREREcmTpMTCxsYGCQkJUCgUxmoPERERERGZIMmDt3v37s11KoiIiIiInnCSE4t58+bh77//xksvvYQjR47g1q1bSE9P1/kjIiIiIqKaS/Lg7aZNmwIonVZ248aN5ZZTqVRSqyIiIiIiIpmSnFjMnz+fYyyIiIiIiJ5wkhOLyMhIIzSDiIiIiIhMmeQxFomJicjPzy93f35+PhITE6VWQ0REREREMiY5sahXrx62b99e7v6dO3eiXr16UqshIiIiIiIZk5xYiKL40P3FxcVQKiVXo7Fy5Ur4+fnB2toaQUFBOHHixEPLb9myBf7+/rC2tkbz5s2xZ88erf3btm1Dr1694OrqCoVCgbNnz+qco1u3blAoFFp/EydONNo1ERERERGZukp948/OzkZiYqKmi9Pdu3c1j8v+nT9/Hps2bYK3t7dRGrt582ZERERgwYIFOHPmDFq2bImwsDCkpKToLX/06FEMHz4c48aNQ1xcHMLDwxEeHo4LFy5oyuTm5qJz58744IMPHlr3+PHjcefOHc3fhx9+aJRrIiIiIiKqCSo1ePuTTz7Bu+++CwBQKBSYNm0apk2bpresKIpYvHhxpRtY1rJlyzB+/HiMHTsWABAdHY3du3fj22+/xaxZs3TKf/rpp+jduzdmzJgBAFi0aBEOHDiAFStWIDo6GgDw0ksvAQCuX7/+0LptbGzg5eVllOsgIiIiIqppKpVY9OrVC3Z2dhBFEW+//TaGDx+O1q1ba5VRKBSwtbVFmzZt0LZtW8kNLSoqwunTpzF79mzNNqVSidDQUMTGxuo9JjY2FhEREVrbwsLCsGPHDoPr37BhA77//nt4eXmhf//+mDdvHmxsbAw+DxERERFRTVSpxCI4OBjBwcEASrsSDRo0CM2aNTNqwx6UlpYGlUoFT09Pre2enp6Ij4/Xe0xSUpLe8klJSQbVPWLECNStWxc+Pj44f/48Zs6ciStXrmDbtm16yxcWFqKwsFDzODs7GwAgCAIEQTCobqkEQYAoio+9XqoYxkf+GCN5Y3zkjfGRP8ZI3uQQH0PqlryOxYIFC6SeQvYmTJig+f/mzZvD29sbzzzzDP7991/Ur19fp3xUVBQWLlyosz01NRUFBQVV2tYHCYKArKwsiKJo1EH0ZByMj/wxRvLG+Mgb4yN/jJG8ySE+OTk5FS4rObF4XNzc3GBmZobk5GSt7cnJyeWOffDy8jKofEUFBQUBAK5evao3sZg9e7ZWF6zs7Gz4+vrC3d0dDg4Okuo2lCAIUCgUcHd35weGDDE+8scYyRvjI2+Mj/wxRvImh/hYW1tXuKzJJBaWlpZo06YNYmJiEB4eDqD0yY6JicHkyZP1HhMcHIyYmBitgeUHDhzQdOOqLPWUtOXNdmVlZQUrKyud7UqlslpeFAqFotrqpkdjfOSPMZI3xkfeGB/5Y4zkrbrjY0i9JpNYAEBERARGjx6Ntm3bon379li+fDlyc3M1s0SNGjUKtWvXRlRUFADgjTfeQEhICJYuXYq+ffti06ZNOHXqFFatWqU5Z3p6OhITE3H79m0AwJUrVwCU3u3w8vLCv//+i40bN+LZZ5+Fq6srzp8/jzfffBNdu3ZFixYtHvMzQEREREQkTyaVWAwdOhSpqamYP38+kpKSEBgYiL1792oGaCcmJmplVR07dsTGjRsxd+5czJkzBw0bNsSOHTu0Bprv3LlTk5gAwLBhwwCUjh2JjIyEpaUlDh48qElifH19MWjQIMydO/cxXTURERERkfwpxEctnU2SZWdnw9HREVlZWdUyxiIlJQUeHh68xSlDjI/8MUbyxvjIG+Mjf4yRvMkhPoZ8jzWohXl5eYiLi9M7OvzPP/80rJVERERERFRjVDixOHbsGOrWrYt+/frB09NTZzXtPn366BzDRISIiIiI6MlQ4cQiIiICK1aswK1bt3Du3Dns2rULo0aNgron1YM9qiqTiBARERERkWmqcGJx6dIlDB06FADQsGFDHD58GOnp6Rg4cCCKiop0yhuaiBARERERkemqcGLh6OiIW7duaR5bW1tjx44dqFWrFsLCwnSW+zY0ESEiIiIiItNV4cQiNDQU3333ndY2c3NzbNy4EQ0aNEB+fr7WPkMTESIiIiIiMl0VTiy+/PJLRERE6GxXKBT4+uuvcf36da3thiYiRERERERkuiq8QJ6lpSUsLS3L3V+nTh2tx19++SVKSkp0yqkTkXnz5hnQTCIiIiIikrMqW3nb0ESEiIiIiIhMV6WW8Js+fToaNmyIjh074vXXX8e2bdtQWFiot+y2bds4AxQRERERUQ1XqcRiz549ePHFFzFy5Ejk5uZi6tSp8Pb2xpQpU3Djxg2tsidPnsTUqVON0lgiIiIiIpKnSiUWM2fOxPvvv4/9+/ejS5cuOHz4MLZv3460tDQ0adIEr7zyCu7cuQMAWLJkCfbs2WPURhMRERERkbxUKrEYM2YMLly4gBYtWuDzzz+Hv78/Bg0ahKSkJHTo0AFr165Fo0aN8PrrryMoKAgNGjQwdruJiIiIiEhGKj14u379+li0aBEWLVqEu3fvIjY2FidPnsSVK1cQEhICpVKJEydO4OzZs9i9e7cx20xERERERDJjlFmhXF1d0a9fP/Tr109n36pVqxAdHY1evXoZoyoiIiIiIpKhSnWFMsSECRPwww8/VHU1RERERERUjao8sQAAKyurx1ENERERERFVk8eSWBARERERUc3GxIKIiIiIiCRjYkFERERERJIxsSAiIiIiIsmMMt2s2r1795CRkQFRFHX21alTx5hVERERERGRjEhOLAoKCrBw4UKsXr0ad+/eLbecSqWSWhUREREREcmU5MRi0qRJWLt2LcLDw9GlSxc4Ozsbo11ERERERGRCJCcW27ZtwyuvvIKvvvrKGO0hIiIiIiITJHnwtkKhQOvWrY3RFiIiIiIiMlGSE4sBAwbg4MGDxmgLERERERGZKMmJxbx583Dt2jVMmDABp0+fRmpqKtLT03X+iIiIiIio5pI8xqJhw4YAgLi4OKxevbrccpwVioiIiIio5pKcWMyfPx8KhcIYbSEiIiIiIhMlObGIjIw0QjOIiIiIiMiUSR5jQUREREREJPmOBVC6+vb//d//4cyZM8jKyoIgCFr7FQrFQ8dfEBERERGRaZOcWNy4cQPdu3fH9evX4eTkhKysLLi4uCAzMxMqlQpubm6ws7MzRluJiIiIiEimJHeFmjFjBrKysnDs2DH8/fffEEURmzdvxr179/DBBx+gVq1a2LdvnzHaSkREREREMiU5sfj1118xadIktG/fHkpl6elEUYSVlRVmzJiBZ555BtOmTZNaDRERERERyZjkxCIvLw9+fn4AAAcHBygUCmRlZWn2BwcH48iRI1KrISIiIiIiGZOcWNSpUwc3b94EAJibm6N27do4duyYZv+lS5dgbW0ttRoiIiIiIpIxyYO3e/TogZ9++gkLFiwAAIwZMwZRUVHIyMiAIAhYv349Ro0aJbmhREREREQkX5ITi1mzZuHkyZMoLCyElZUV5syZg9u3b2Pr1q0wMzPDiBEjsGzZMmO0lYiIiIiIZEpyYlGnTh3UqVNH89ja2hrffPMNvvnmG6mnJiIiIiIiE8GVt4mIiIiISDKjrLwNAIWFhThz5gxSUlLQqVMnuLm5GevUREREREQkc0a5Y/HZZ5/B29sbnTt3xvPPP4/z588DANLS0uDm5oZvv/3WGNUQEREREZFMSU4svvvuO0ybNg29e/fG6tWrIYqiZp+bmxt69OiBTZs2Sa2GiIiIiIhkTHJisXTpUgwYMAAbN25E//79dfa3adMGFy9elFoNERERERHJmOTE4urVq+jTp0+5+11cXHD37l2p1RARERERkYxJTiycnJyQlpZW7v5Lly7By8tLajVERERERCRjkhOLZ599FqtWrUJmZqbOvosXL+Lrr7/Gc889J7UaIiIiIiKSMcmJxeLFi6FSqdCsWTPMnTsXCoUCa9euxYsvvoi2bdvCw8MD8+fPN0ZbiYiIiIhIpiQnFj4+Pjh9+jR69+6NzZs3QxRFrF+/Hj///DOGDx+OY8eOcU0LIiIiIqIazigL5Hl4eOCbb77BN998g9TUVAiCAHd3dyiVXNibiIiIiOhJYLSVt9Xc3d2NfUoiIiIiIpI5gxMLpVIJhUJhcEUqlcrgY4iIiIiIyDQYnFjMnz9fJ7HYvn07Ll68iLCwMDRu3BgAEB8fj/3796NZs2YIDw83SmOJiIiIiEieDE4sIiMjtR6vWrUKKSkpuHDhgiapULt8+TJ69OgBHx8fSY0kIiIiIiJ5kzy6+qOPPsLkyZN1kgoACAgIwOTJk/Hhhx9KrYaIiIiIiGRMcmJx8+ZNWFhYlLvfwsICN2/elFoNERERERHJmOTEolmzZvjiiy9w69YtnX03b97EF198gebNm0uthoiIiIiIZEzydLOffPIJwsLC0KhRIwwcOBANGjQAAPzzzz/YsWMHRFHE999/L7mhREREREQkX5ITi86dO+P48eOYN28etm/fjvz8fABArVq1EBYWhoULF/KOBRERERFRDWeUBfKaNWuG7du3QxAEpKamAgBX3iYiIiIieoIYdeVtpVIJT09PY56SiIiIiIhMgMG3FBITE5GYmKjz+FF/xrJy5Ur4+fnB2toaQUFBOHHixEPLb9myBf7+/rC2tkbz5s2xZ88erf3btm1Dr1694OrqCoVCgbNnz+qco6CgAK+//jpcXV1hZ2eHQYMGITk52WjXRERERERk6gxOLPz8/FCvXj0UFRVpPX7UnzFs3rwZERERWLBgAc6cOYOWLVsiLCwMKSkpessfPXoUw4cPx7hx4xAXF4fw8HCEh4fjwoULmjK5ubno3LkzPvjgg3LrffPNN/Hzzz9jy5Yt+O2333D79m08//zzRrkmIiIiIqKaQCGKomjIAWvWrIFCocCoUaOgUCg0jx9l9OjRlW6kWlBQENq1a4cVK1YAAARBgK+vL6ZMmYJZs2bplB86dChyc3Oxa9cuzbYOHTogMDAQ0dHRWmWvX7+OevXqIS4uDoGBgZrtWVlZcHd3x8aNGzF48GAAQHx8PAICAhAbG4sOHTo8st3Z2dlwdHREVlYWHBwcKnPplSYIAlJSUuDh4cExLzLE+MgfYyRvjI+8MT7yxxjJmxziY8j3WIPHWIwZM+ahj6tKUVERTp8+jdmzZ2u2KZVKhIaGIjY2Vu8xsbGxiIiI0NoWFhaGHTt2VLje06dPo7i4GKGhoZpt/v7+qFOnTrmJRWFhIQoLCzWPs7OzAZS+OARBqHDdxiAIAkRRfOz1UsUwPvLHGMkb4yNvjI/8MUbyJof4GFK3wYlFZcdL1KlTp1LHqaWlpUGlUukMDvf09ER8fLzeY5KSkvSWT0pKqnC9SUlJsLS0hJOTU4XPExUVhYULF+psT01NRUFBQYXrNgZBEJCVlQVRFPlLhAwxPvLHGMkb4yNvjI/8MUbyJof45OTkVLiswYmFn59fhbo+PUilUhl8jKmaPXu21p2S7Oxs+Pr6wt3dvVq6QikUCk7/K1OMj/wxRvLG+Mgb4yN/jJG8ySE+1tbWFS5rcGLx7bffViqxkMrNzQ1mZmY6szElJyfDy8tL7zFeXl4GlS/vHEVFRcjMzNS6a/Gw81hZWcHKykpnu1KprJYXhUKhqLa66dEYH/ljjOSN8ZE3xkf+GCN5q+74GFKv5DEWj4ulpSXatGmDmJgYhIeHAyjN4mJiYjB58mS9xwQHByMmJgbTpk3TbDtw4ACCg4MrXG+bNm1gYWGBmJgYDBo0CABw5coVJCYmGnQeIiIiIqKaTPICeSUlJcjLyyu3i092djZsbGxgbi59Lb6IiAiMHj0abdu2Rfv27bF8+XLk5uZi7NixAIBRo0ahdu3aiIqKAgC88cYbCAkJwdKlS9G3b19s2rQJp06dwqpVqzTnTE9PR2JiIm7fvg2gNGkASu9UeHl5wdHREePGjUNERARcXFzg4OCAKVOmIDg4uEIzQhERERERPQkk31OZOnUqOnbsWO7+Tp06Yfr06VKrAVA6fezHH3+M+fPnIzAwEGfPnsXevXs1A7QTExNx584dTfmOHTti48aNWLVqFVq2bImtW7dix44daNasmabMzp070apVK/Tt2xcAMGzYMLRq1UprOtpPPvkE/fr1w6BBg9C1a1d4eXlh27ZtRrkmIiIiIqKawOB1LB709NNPY9SoUYiMjNS7f+HChfj+++/xzz//SKnGpHEdCyoP4yN/jJG8MT7yxvjIH2Mkb3KIjyHfYyW38Pbt26hdu3a5+318fHDr1i2p1RARERERkYxJTixcXV014xL0uXz58mP/lZ6IiIiIiB4vyYlF79698dVXXyEuLk5n35kzZ7Bq1Sr06dNHajVERERERCRjkqdqWrRoEfbu3Yv27dvjueeeQ9OmTQEAFy5cwM8//wwPDw8sWrRIckOJiIiIiEi+JCcWPj4+OHXqFGbNmoWffvoJ27dvBwA4ODhg5MiRWLJkCXx8fCQ3lIiIiIiI5Ev64hIAvL29sXbtWoiiiNTUVACAu7t7tazQTUREREREj59REgs1hUIBDw8PY56SiIiIiIhMACcsJiIiIiIiySTfsVAqlRXq8qRSqaRWRUREREREMiU5sZg/f75OYqFSqXD9+nXs2LEDjRs3Rr9+/aRWQ0REREREMiY5sYiMjCx33507d9ChQwc0atRIajVERERERCRjVTrGwtvbGxMnTuQ6FkRERERENVyVD962tbVFQkJCVVdDRERERETVqEoTiwsXLuCzzz5jVygiIiIiohpO8hiLevXq6Z0VKjMzE1lZWbCxscGOHTukVkNERERERDImObEICQnRSSwUCgWcnZ1Rv359DBs2DC4uLlKrISIiIiIiGZOcWKxZs8YIzSAiIiIiIlNWqTEWJ06cQHp6eoXKJiQkYN26dZWphoiIiIiITESlEovg4GDs3btX8zg9PR02Njb47bffdMoePXoUY8eOrXwLiYiIiIhI9iqVWIiiqPO4oKAAKpXKKI0iIiIiIiLTUuXrWBARERERUc3HxIKIiIiIiCRjYkFERERERJJVerrZ69ev48yZMwCArKwsAMA///wDJycnrXIJCQmVbx0REREREZmESicW8+bNw7x587S2TZo0SaecKIp6V+YmIiIiIqKao1KJxXfffWfsdhARERERkQmrVGIxevRoY7eDiIiIiIhMGAdvExERERGRZEwsiIiIiIhIMiYWREREREQkGRMLIiIiIiKSjIkFERERERFJxsSCiIiIiIgkY2JBRERERESSMbEgIiIiIiLJmFgQEREREZFkTCyIiIiIiEgyJhZERERERCQZEwsiIiIiIpKMiQUREREREUnGxIKIiIiIiCRjYkFERERERJIxsSAiIiIiIsmYWBARERERkWRMLIiIiIiISDImFkREREREJBkTCyIiIiIikoyJBRERERERScbEgoiIiIiIJGNiQUREREREkjGxICIiIiIiyZhYEBERERGRZEwsiIiIiIhIMiYWREREREQkGRMLIiIiIiKSzLy6G0BERPISuDYQKqhgBjOcHX22uptDREQmgncsiIhIIysrCyqoAAAqqJCVlVXNLSIiIlPBxIKIiDRCdoQ89DEREVF5mFgQEREA7bsVarxrQUREFcXEgoiIAJR/d4J3LYiIqCKYWBARkd67FWq8a0FERBXBxIKIiHDi7glJ+4mIiEwusVi5ciX8/PxgbW2NoKAgnDjx8H/stmzZAn9/f1hbW6N58+bYs2eP1n5RFDF//nx4e3ujVq1aCA0NxT///KNVxs/PDwqFQuvv/fffN/q1ERFVl55P90RL15ZwMHfQ+Wvt1ho9n+5Z3U0kIiKZM6l1LDZv3oyIiAhER0cjKCgIy5cvR1hYGK5cuQIPDw+d8kePHsXw4cMRFRWFfv36YePGjQgPD8eZM2fQrFkzAMCHH36Izz77DGvXrkW9evUwb948hIWF4dKlS7C2ttac691338X48eM1j+3t7av+gomIHqPv+31f3U0gIiITZlJ3LJYtW4bx48dj7NixaNKkCaKjo2FjY4Nvv/1Wb/lPP/0UvXv3xowZMxAQEIBFixahdevWWLFiBYDSuxXLly/H3LlzMWDAALRo0QLr1q3D7du3sWPHDq1z2dvbw8vLS/Nna2tb1ZdLRERERGQyTCaxKCoqwunTpxEaGqrZplQqERoaitjYWL3HxMbGapUHgLCwME35hIQEJCUlaZVxdHREUFCQzjnff/99uLq6olWrVvjoo49QUlJirEsjIiIiIjJ5JtMVKi0tDSqVCp6enlrbPT09ER8fr/eYpKQkveWTkpI0+9XbyisDAFOnTkXr1q3h4uKCo0ePYvbs2bhz5w6WLVumt97CwkIUFhZqHmdnZwMABEGAIAgVuVyjEQQBoig+9nqpYhgf+WOM5I3xkTfGR/4YI3mTQ3wMqdtkEovqFBERofn/Fi1awNLSEq+++iqioqJgZWWlUz4qKgoLFy7U2Z6amoqCgoIqbeuDBEFAVlYWRFGEUmkyN6ieGIyP/DFG8sb4yBvjI3+MkbzJIT45OTkVLmsyiYWbmxvMzMyQnJystT05ORleXl56j/Hy8npoefV/k5OT4e3trVUmMDCw3LYEBQWhpKQE169fR+PGjXX2z549WysZyc7Ohq+vL9zd3eHg4PDwCzUyQRCgUCjg7u7ODwwZYnzkjzGSN8ZH3hgf+WOM5E0O8Sk7mdGjmExiYWlpiTZt2iAmJgbh4eEASp/smJgYTJ48We8xwcHBiImJwbRp0zTbDhw4gODgYABAvXr14OXlhZiYGE0ikZ2djePHj+O1114rty1nz56FUqnUOxMVAFhZWem9k6FUKqvlRaFQKKqtbno0xkf+GCN5Y3zkjfGRP8ZI3qo7PobUazKJBVDaJWn06NFo27Yt2rdvj+XLlyM3Nxdjx44FAIwaNQq1a9dGVFQUAOCNN95ASEgIli5dir59+2LTpk04deoUVq1aBaA0UNOmTcPixYvRsGFDzXSzPj4+muQlNjYWx48fR/fu3WFvb4/Y2Fi8+eabePHFF+Hs7FwtzwMRERERkdyYVGIxdOhQpKamYv78+UhKSkJgYCD27t2rGXydmJiolVV17NgRGzduxNy5czFnzhw0bNgQO3bs0KxhAQBvv/02cnNzMWHCBGRmZqJz587Yu3ev5raPlZUVNm3ahMjISBQWFqJevXp48803tbo6ERERERE96RSiKIrV3YiaLjs7G46OjsjKyqqWMRYpKSnw8PDgLU4ZYnzkjzGSN8ZH3hgf+WOM5E0O8THkeyxfQUREREREJBkTCyIiIiIikoyJBRERERERScbEgoiIiIiIJGNiQUREREREkjGxICIiIiIiyZhYEBERERGRZEwsiIiIiIhIMiYWREREREQkGRMLIiIiIiKSjIkFERERERFJxsSCiIiIiIgkY2JBRERERESSMbEgIiIiIiLJmFgQEREREZFkTCyIiIiIiEgyJhZERERERCQZEwsiIiIiIpKMiQUREREREUnGxIKIiIiIiCRjYkFERERERJIxsSAiIiIiIsmYWBARERERkWRMLIiIiIiISDImFkREREREJBkTCyIiIiIikoyJBRERERERScbEgoiIiIiIJGNiQUREREREkjGxICIiIiIiyZhYEBERERGRZEwsiIiIiIhIMiYWREREREQkGRMLIiIiIiKSjIkFERERERFJxsSCiIiIiIgkY2JBRERERESSMbEgIiIiIiLJmFgQEREREZFkTCyIiIiIiEgyJhZERERERCQZEwsiIiIiIpKMiQUREREREUnGxIKIiIiIiCRjYkFERERERJIxsSAiIiIiIsmYWBARERERkWRMLIiIiIiISDImFkREREREJBkTCyIiIiIikoyJBRERERERScbEgoiIiIiIJGNiQUREREREkjGxICIiIiIiyZhYEBERERGRZEwsiIiIiIhIMiYWREREREQkGRMLIiIiIiKSjIkFERERERFJxsSCiIiIiIgkY2JBRERERESSMbEgIiIiIiLJmFgQEREREZFkJpdYrFy5En5+frC2tkZQUBBOnDjx0PJbtmyBv78/rK2t0bx5c+zZs0drvyiKmD9/Pry9vVGrVi2Ehobin3/+0SqTnp6OkSNHwsHBAU5OThg3bhzu3btn9GsjIiIiIjJVJpVYbN68GREREViwYAHOnDmDli1bIiwsDCkpKXrLHz16FMOHD8e4ceMQFxeH8PBwhIeH48KFC5oyH374IT777DNER0fj+PHjsLW1RVhYGAoKCjRlRo4ciYsXL+LAgQPYtWsXfv/9d0yYMKHKr5eIiIiIyFQoRFEUq7sRFRUUFIR27dphxYoVAABBEODr64spU6Zg1qxZOuWHDh2K3Nxc7Nq1S7OtQ4cOCAwMRHR0NERRhI+PD6ZPn4633noLAJCVlQVPT0+sWbMGw4YNw+XLl9GkSROcPHkSbdu2BQDs3bsXzz77LG7evAkfH59Htjs7OxuOjo7IysqCg4ODMZ6KChMEASkpKfDw8IBSaVJ55BOB8ZE/xkjeGB95Y3zkjzGSNznEx5DvseaPqU2SFRUV4fTp05g9e7Zmm1KpRGhoKGJjY/UeExsbi4iICK1tYWFh2LFjBwAgISEBSUlJCA0N1ex3dHREUFAQYmNjMWzYMMTGxsLJyUmTVABAaGgolEoljh8/joEDB+rUW1hYiMLCQs3j7OxsAKUvDkEQDL94CQRBgCiKj71eqhjGR/4YI3ljfOSN8ZE/xkje5BAfQ+o2mcQiLS0NKpUKnp6eWts9PT0RHx+v95ikpCS95ZOSkjT71dseVsbDw0Nrv7m5OVxcXDRlHhQVFYWFCxfqbE9NTdXqYvU4CIKArKwsiKLIXyJkiPGRP8ZI3hgfeWN85I8xkjc5xCcnJ6fCZU0msTAls2fP1rpTkp2dDV9fX7i7u1dLVyiFQgF3d3d+YMgQ4yN/jJG8MT7yxvjIH2Mkb3KIj7W1dYXLmkxi4ebmBjMzMyQnJ2ttT05OhpeXl95jvLy8Hlpe/d/k5GR4e3trlQkMDNSUeXBweElJCdLT08ut18rKClZWVjrblUpltbwoFApFtdVNj8b4yB9jJG+Mj7wxPvLHGMlbdcfHkHpN5hVkaWmJNm3aICYmRrNNEATExMQgODhY7zHBwcFa5QHgwIEDmvL16tWDl5eXVpns7GwcP35cUyY4OBiZmZk4ffq0psyvv/4KQRAQFBRktOsjIiIiIjJlJnPHAgAiIiIwevRotG3bFu3bt8fy5cuRm5uLsWPHAgBGjRqF2rVrIyoqCgDwxhtvICQkBEuXLkXfvn2xadMmnDp1CqtWrQJQmgFOmzYNixcvRsOGDVGvXj3MmzcPPj4+CA8PBwAEBASgd+/eGD9+PKKjo1FcXIzJkydj2LBhFZoRioiIiIjoSWBSicXQoUORmpqK+fPnIykpCYGBgdi7d69m8HViYqLW7ZqOHTti48aNmDt3LubMmYOGDRtix44daNasmabM22+/jdzcXEyYMAGZmZno3Lkz9u7dq9WfbMOGDZg8eTKeeeYZKJVKDBo0CJ999tnju3AiIiIiIpkzqXUsTBXXsaDyMD7yxxjJG+Mjb4yP/DFG8iaH+BjyPZavICIiIiIikoyJBRERERERScbEgoiIiIiIJGNiQUREREREkjGxICIiIiIiyZhYEBERERGRZEwsiIiIiIhIMiYWREREREQkGRMLIiIiIiKSjIkF0f+3d/dRUdZZHMC/MwMMiDIYgjCCwmqpkOuuJhw3i/ZE+VISVqJSvuzuyRBot87q1tam06ZRmtsaa+Juu3p60UyPK2ZaYMuaJCibWuErrOiW+H5EcBElnrt/FLOOMwzz9gzD8P2cw1Ge5/J77v3dA871mRmIiIiIyG0cLIiIiIiIyG0cLIiIiIiIyG0cLIiIiIiIyG0cLIiIiIiIyG0cLIiIiIiIyG0BnZ1AdyAiAICGhgavX1tRFDQ2NiI4OBhaLedIX8P++D72yLexP76N/fF97JFv84X+tD1+bXs8aw8HCy9obGwEAMTFxXVyJkREREREzmtsbITBYLAboxFHxg9yi6IoqKurQ69evaDRaLx67YaGBsTFxeHrr79GWFiYV69NHWN/fB975NvYH9/G/vg+9si3+UJ/RASNjY0wGo0d3jXhHQsv0Gq1iI2N7dQcwsLC+APDh7E/vo898m3sj29jf3wfe+TbOrs/Hd2paMMn0xERERERkds4WBARERERkds4WPg5vV6PBQsWQK/Xd3YqZAP74/vYI9/G/vg29sf3sUe+rav1hy/eJiIiIiIit/GOBRERERERuY2DBRERERERuY2DBRERERERuY2DRRe0fPlyxMfHIzg4GCkpKdizZ4/d+PXr12PIkCEIDg7GsGHDsHXrVovzIoL58+cjJiYGISEhSEtLQ3V1tZol+DVP9qelpQVPP/00hg0bhtDQUBiNRsyYMQN1dXVql+G3PP39c73s7GxoNBr88Y9/9HDW3YsaPTp06BDS09NhMBgQGhqKUaNG4T//+Y9aJfg1T/fn8uXLyMvLQ2xsLEJCQpCYmIjCwkI1S/BrzvTnwIEDeOihhxAfH2/3Z5ezPSf7PN2j/Px8jBo1Cr169UJUVBQyMjJw5MgRFSuwQ6hLee+99yQoKEj+9re/yYEDB+Sxxx6T8PBwOXPmjM34zz77THQ6nSxevFgOHjwov/vd7yQwMFC++uorc8zLL78sBoNBNm3aJF988YWkp6dLQkKCXLlyxVtl+Q1P96e+vl7S0tJk3bp1cvjwYSkvL5fk5GQZOXKkN8vyG2p8/7TZuHGjDB8+XIxGo7z22msqV+K/1OhRTU2N3HTTTTJv3jzZu3ev1NTUSFFRUbtrUvvU6M9jjz0mAwcOlNLSUqmtrZWVK1eKTqeToqIib5XlN5ztz549e2Tu3Lmydu1aiY6Otvmzy9k1yT41ejR27FhZtWqVVFVVyf79+2XChAnSv39/uXz5ssrVWONg0cUkJydLbm6u+fPW1lYxGo2Sn59vMz4zM1Puu+8+i2MpKSny+OOPi4iIoigSHR0tS5YsMZ+vr68XvV4va9euVaEC/+bp/tiyZ88eASAnTpzwTNLdiFr9+eabb6Rfv35SVVUlAwYM4GDhBjV6NGXKFHn00UfVSbibUaM/SUlJ8vvf/94iZsSIEfLcc895MPPuwdn+XK+9n13urEnW1OjRjc6ePSsAZMeOHe6k6hI+FaoLuXbtGj7//HOkpaWZj2m1WqSlpaG8vNzm15SXl1vEA8DYsWPN8bW1tTh9+rRFjMFgQEpKSrtrkm1q9MeWS5cuQaPRIDw83CN5dxdq9UdRFEyfPh3z5s1DUlKSOsl3E2r0SFEUfPjhh7jlllswduxYREVFISUlBZs2bVKtDn+l1vfQT37yE2zevBknT56EiKC0tBRHjx7Fvffeq04hfsqV/nTGmt2Zt/bz0qVLAICbbrrJY2s6ioNFF3L+/Hm0traib9++Fsf79u2L06dP2/ya06dP241v+9OZNck2Nfpzo+bmZjz99NOYNm0awsLCPJN4N6FWf1555RUEBATgl7/8peeT7mbU6NHZs2dx+fJlvPzyyxg3bhyKi4sxadIkPPjgg9ixY4c6hfgptb6HCgoKkJiYiNjYWAQFBWHcuHFYvnw57rzzTs8X4cdc6U9nrNmdeWM/FUXBk08+idtvvx233nqrR9Z0RoDXr0hELmlpaUFmZiZEBCtWrOjsdAjA559/jmXLlmHv3r3QaDSdnQ7ZoCgKAOCBBx7AU089BQD40Y9+hF27dqGwsBCpqamdmR7hu8GioqICmzdvxoABA/Dpp58iNzcXRqPR6m4HEdmXm5uLqqoqlJWVdcr1eceiC+nTpw90Oh3OnDljcfzMmTOIjo62+TXR0dF249v+dGZNsk2N/rRpGypOnDiBkpIS3q1wgRr92blzJ86ePYv+/fsjICAAAQEBOHHiBH79618jPj5elTr8mRo96tOnDwICApCYmGgRM3ToUL4rlJPU6M+VK1fw7LPP4g9/+AMmTpyIH/7wh8jLy8OUKVPw6quvqlOIn3KlP52xZnem9n7m5eVhy5YtKC0tRWxsrNvruYKDRRcSFBSEkSNH4pNPPjEfUxQFn3zyCUaPHm3za0aPHm0RDwAlJSXm+ISEBERHR1vENDQ0YPfu3e2uSbap0R/g/0NFdXU1tm/fjoiICHUK8HNq9Gf69On48ssvsX//fvOH0WjEvHnz8PHHH6tXjJ9So0dBQUEYNWqU1VsvHj16FAMGDPBwBf5Njf60tLSgpaUFWq3lwxGdTme+20SOcaU/nbFmd6bWfooI8vLy8Pe//x3/+Mc/kJCQ4Il0XU6GupD33ntP9Hq9rF69Wg4ePCizZ8+W8PBwOX36tIiITJ8+XZ555hlz/GeffSYBAQHy6quvyqFDh2TBggU23242PDxcioqK5Msvv5QHHniAbzfrIk/359q1a5Keni6xsbGyf/9+OXXqlPnj6tWrnVJjV6bG98+N+K5Q7lGjRxs3bpTAwED585//LNXV1VJQUCA6nU527tzp9fq6OjX6k5qaKklJSVJaWirHjh2TVatWSXBwsLzxxhter6+rc7Y/V69elX379sm+ffskJiZG5s6dK/v27ZPq6mqH1yTnqNGjOXPmiMFgkH/+858WjxOampq8Xh8Hiy6ooKBA+vfvL0FBQZKcnCwVFRXmc6mpqTJz5kyL+Pfff19uueUWCQoKkqSkJPnwww8tziuKIs8//7z07dtX9Hq93H333XLkyBFvlOKXPNmf2tpaAWDzo7S01EsV+RdPf//ciIOF+9To0V//+lcZNGiQBAcHy/Dhw2XTpk1ql+G3PN2fU6dOyaxZs8RoNEpwcLAMHjxYli5dKoqieKMcv+NMf9r7NyY1NdXhNcl5nu5Re48TVq1a5b2ivqf5PiEiIiIiIiKX8TUWRERERETkNg4WRERERETkNg4WRERERETkNg4WRERERETkNg4WRERERETkNg4WRERERETkNg4WRERERETkNg4WRERERETkNg4WRERERETkNg4WRERERETkNg4WREQ+wGQyQaPR4Pz5852dChERkUs4WBBRt7V69WpoNBpoNBqUlZVZnRcRxMXFQaPR4P777++EDF23a9cumEwm1NfXO3XOm1avXo0hQ4bgqaeecmsdX6nHF/ny3vhybkTkGg4WRNTtBQcHY82aNVbHd+zYgW+++QZ6vb4TsnLPrl278MILL7Q7WLR3zltqamqQnZ2NKVOmYPPmzW6t5Qv1+Cpf3htfzo2IXMPBgoi6vQkTJmD9+vX49ttvLY6vWbMGI0eORHR0dCdl5r9Wr16Ne+65B5GRkYiMjPTqtf/73/969XqdoTvUSES+h4MFEXV706ZNw4ULF1BSUmI+du3aNWzYsAFZWVlW8SdOnEBOTg4GDx6MkJAQREREYPLkyTh+/LhFXNvrJmpqajBr1iyEh4fDYDDgZz/7GZqammzmUl9fbzfWkWubTCbMmzcPAJCQkGB+utfx48ftnnO0Lldru94HH3yA9PR0VFZW4sc//nGH8e2xV8/1eR48eBBZWVno3bs3xowZ4/BeOlNrY2MjnnzyScTHx0Ov1yMqKgr33HMP9u7da7XW4cOHkZmZibCwMEREROBXv/oVmpubrerbt28fxo8fj7CwMPTs2RN33303KioqbOZ3Y40d7Y0tJ0+exC9+8QsYjUbo9XokJCRgzpw5uHbtmtN52dsPV3IjIt8X0NkJEBF1tvj4eIwePRpr167F+PHjAQDbtm3DpUuXMHXqVLz++usW8ZWVldi1axemTp2K2NhYHD9+HCtWrMBdd92FgwcPokePHhbxmZmZSEhIQH5+Pvbu3Ys333wTUVFReOWVV6xy6SjWkWs/+OCDOHr0KNauXYvXXnsNffr0AQBERkbaPbdt2zan6nK2tjbnz5/HV199hdTUVJhMJixbtsyJblmyV8/1Jk+ejJtvvhkvvfQSRMThvXSm1uzsbGzYsAF5eXlITEzEhQsXUFZWhkOHDmHEiBFWa8XHxyM/Px8VFRV4/fXXcfHiRbz11lvmmAMHDuCOO+5AWFgYfvOb3yAwMBArV67EXXfdhR07diAlJcVujWPGjHFob9rU1dUhOTkZ9fX1mD17NoYMGYKTJ09iw4YNaGpqQlBQkFN52dsPR/tGRF2MEBF1U6tWrRIAUllZKX/605+kV69e0tTUJCIikydPlp/+9KciIjJgwAC57777zF/XFnO98vJyASBvvfWW+diCBQsEgPz85z+3iJ00aZJERERYHHM01tFrL1myRABIbW2tVXx75xxd29nabrR582bp3bu3lJaWisFgsHldZ9irtS3PadOmWZ3zdB8NBoPk5ubazbVtrfT0dIvjOTk5AkC++OIL87GMjAwJCgqSf//73+ZjdXV10qtXL7nzzjsdqtHe3txoxowZotVqpbKy0uqcoihO59XRfjiTGxF1DXwqFBERvvsf5CtXrmDLli1obGzEli1bbD4NCgBCQkLMf29pacGFCxcwaNAghIeHWzztpU12drbF53fccQcuXLiAhoYGp2OdvbYzXFnbmdraVFZWYtiwYSgsLMQjjzxicV213Jgn4Pk+hoeHY/fu3airq+swn9zcXIvPn3jiCQDA1q1bAQCtra0oLi5GRkYGfvCDH5jjYmJikJWVhbKyMqs9tlWjoxRFwaZNmzBx4kTcdtttVuc1Go3TeTmzH0TkHzhYEBHhu6dgpKWlYc2aNdi4cSNaW1vx8MMP24y9cuUK5s+fj7i4OOj1evTp0weRkZGor6/HpUuXrOL79+9v8Xnv3r0BABcvXnQ61tlrO8OVtZ2prc2RI0eg0+lQVFSEuXPnWpxbsWIFRowYgcDAQJhMJrfquV5CQoLVMU/3cfHixaiqqkJcXBySk5NhMplw7Ngxm/ncfPPNFp8PHDgQWq3W/BqDc+fOoampCYMHD7b62qFDh0JRFHz99dcd1uioc+fOoaGhAbfeemuHcY7m5cx+EJF/4GBBRPS9rKwsbNu2DYWFhRg/fjzCw8Ntxj3xxBNYtGgRMjMz8f7776O4uBglJSWIiIiAoihW8TqdzuY68v1z/Z2JdfbaznBlbWdqa3P+/Hl8+umnmDp1qtWD4ZiYGJhMJjz00EOuF2KDrbsinu5jZmYmjh07hoKCAhiNRixZsgRJSUnYtm1bh/m13RFwhzfu/DjDnf0goq6JL94mIvrepEmT8Pjjj6OiogLr1q1rN27Dhg2YOXMmli5daj7W3Nzslffjd/Ta9h6otnfOW3VptVro9XosXLjQ6lxGRgaA/z8lyBGuPihXo96YmBjk5OQgJycHZ8+exYgRI7Bo0SLzmwK0qa6uthiqampqoCgK4uPjAXx3B61Hjx44cuSI1TUOHz4MrVaLuLi4DvNxdG8iIyMRFhaGqqqqDuOcycvefnhimCIi38I7FkRE3+vZsydWrFgBk8mEiRMnthun0+ms/ke+oKAAra2taqfo8LVDQ0MBwOaD5PbOeaMuEcHFixcxY8YM9OvXzyNr2qvVHk/W29raavX0qaioKBiNRly9etUqfvny5VbXBWAeQHQ6He69914UFRVZvAXrmTNnsGbNGowZMwZhYWEd5uXo3mi1WmRkZOCDDz7Av/71L6vzbfvkaF6O7IerfSMi38U7FkRE15k5c2aHMffffz/efvttGAwGJCYmory8HNu3b0dERITq+Tl67ZEjRwIAnnvuOUydOhWBgYGYOHEiQkND2z3njbr+8pe/YP/+/QC+e8Hws88+i9tuu63d17M4wl6t9niy3sbGRsTGxuLhhx/G8OHD0bNnT2zfvh2VlZUWd0Ta1NbWIj09HePGjUN5eTneeecdZGVlYfjw4eaYhQsXoqSkBGPGjEFOTg4CAgKwcuVKXL16FYsXL3YoL2f25qWXXkJxcTFSU1Mxe/ZsDB06FKdOncL69etRVlZmfmqgI3k5sh+u9o2IfBcHCyIiJy1btgw6nQ7vvvsumpubcfvtt2P79u0YO3asz1x71KhRePHFF1FYWIiPPvoIiqKgtrYWoaGh7Z5Tu67m5mZs3boVW7ZswdKlSzFw4EBMmDABkyZNcmtde7Xa48l6e/TogZycHBQXF2Pjxo1QFAWDBg3CG2+8gTlz5ljFr1u3DvPnz8czzzyDgIAA5OXlYcmSJRYxSUlJ2LlzJ377298iPz8fiqIgJSUF77zzjtXvsGiPM3vTr18/7N69G88//zzeffddNDQ0oF+/fhg/frzF7/RwJC9H9sPVvhGR79KIvVfYERERdYLs7GxER0d79J2hfIHJZMILL7yAc+fOmX8pHBGRv+BrLIiIyGd8++23aG5uRmtrq8XfiYjI93GwICIin7Fw4UKEhITgzTffxKJFixASEoK33367s9MiIiIHcLAgIiKfYTKZICIWH7NmzerstIiIyAF8jQUREREREbmNdyyIiIiIiMhtHCyIiIiIiMhtHCyIiIiIiMhtHCyIiIiIiMhtHCyIiIiIiMhtHCyIiIiIiMhtHCyIiIiIiMhtHCyIiIiIiMhtHCyIiIiIiMhtHCyIiIiIiMhtHCyIiIiIiMht/wOJRUnu/YbsQAAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# -- Pareto front profiling (Figure 5) --\n", + "t_sweep = np.linspace(t_min * 1.01, t_max * 0.99, 15)\n", + "\n", + "pareto_data = {}\n", + "for eta_p in [10, 100, 1000]:\n", + " costs_l1, costs_l2 = [], []\n", + " for t_val in t_sweep:\n", + " D_ineq = convert_leq_constraint(C2_1d, float(t_val), n_1d)\n", + " _, _, _, P_p, _ = constrained_sinkhorn(\n", + " C1_1d, [D_ineq], r_1d, c_1d, eta_p,\n", + " K=1, L=0, n_iters=200, n_newton=5, verbose=False\n", + " )\n", + " costs_l1.append(float(jnp.sum(C1_1d * P_p)))\n", + " costs_l2.append(float(jnp.sum(C2_1d * P_p)))\n", + " pareto_data[eta_p] = (costs_l1, costs_l2)\n", + " print(f\"eta = {eta_p}: done\")\n", + "\n", + "# \"True\" Pareto front at very high eta\n", + "costs_l1_true, costs_l2_true = [], []\n", + "eta_true = 5000\n", + "for t_val in t_sweep:\n", + " D_ineq = convert_leq_constraint(C2_1d, float(t_val), n_1d)\n", + " _, _, _, P_p, _ = constrained_sinkhorn(\n", + " C1_1d, [D_ineq], r_1d, c_1d, eta_true,\n", + " K=1, L=0, n_iters=300, n_newton=5, verbose=False\n", + " )\n", + " costs_l1_true.append(float(jnp.sum(C1_1d * P_p)))\n", + " costs_l2_true.append(float(jnp.sum(C2_1d * P_p)))\n", + "print(f\"True Pareto (eta={eta_true}): done\")\n", + "\n", + "# -- Plot (reproducing Figure 5) --\n", + "fig, ax = plt.subplots(figsize=(8, 6))\n", + "\n", + "ax.plot(costs_l1_true, costs_l2_true, \"k-\", linewidth=2.5, label=\"True Pareto Front\", zorder=5)\n", + "\n", + "markers = [\"o\", \"s\", \"^\"]\n", + "for (eta_p, (cl1, cl2)), marker in zip(pareto_data.items(), markers):\n", + " ax.plot(cl1, cl2, f\"-{marker}\", linewidth=1.5, markersize=5, label=f\"$\\\\eta = {eta_p}$\", alpha=0.8)\n", + "\n", + "ax.set_xlabel(\"Manhattan $\\\\ell_1$ transport cost\", fontsize=12)\n", + "ax.set_ylabel(\"Euclidean $\\\\ell_2^2$ transport cost\", fontsize=12)\n", + "ax.set_title(\"Figure 5: Pareto fronts under different strengths of entropy regularisation\", fontsize=13)\n", + "ax.legend(fontsize=11)\n", + "ax.grid(True, alpha=0.3)\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAxYAAAFUCAYAAAC5sarpAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAiBJJREFUeJzt3Xd8U+X+B/BPkjZJZwp0QS2UJVuqjMpQHOUWRAVn5V5ZghtXBQWuAopaQRH8iYrgAO8VQbiIKIhgBRdVZpG9WtrSNunO3uf5/VEbCE1HmnFy0u/79coLevLknG/S0yTf83yf5xExxhgIIYQQQgghxANivgMghBBCCCGECB8lFoQQQgghhBCPUWJBCCGEEEII8RglFoQQQgghhBCPUWJBCCGEEEII8RglFoQQQgghhBCPUWJBCCGEEEII8RglFoQQQgghhBCPUWJBCCGEEEII8RglFoQQQZo6dSpSUlL4DiMgXbhwASKRCGvWrHFsW7hwIUQikV+Of9NNN+Gmm25y/Lxnzx6IRCJs2rTJL8cP9HNj//79GD58OCIiIiASiZCXl8d3SOQyIpEICxcu5DsMQgSJEgtCfOTo0aO499570aVLF8jlciQlJWH06NF47733+A6tTfnggw+cvmAHukCKt7S0FAsXLgzIL76BHFtTrFYr7rvvPlRXV2PZsmX4z3/+gy5duvj0mK+//jruvPNOJCQkNPmleerUqRCJRI5bZGQkunXrhnvvvRf/+9//wHGcT+MkLbd3714sXLgQtbW1fIdCiJMQvgMgJBjt3bsXN998Mzp37oyHH34YiYmJKC4uxh9//IF3330XTz31FN8hthkffPABYmNjMXXqVL5DaRFfxfvSSy9hzpw5bj2mtLQUr7zyClJSUpCamtrix+3cudPN6NzXVGyrV68O2C/B58+fR2FhIVavXo0ZM2b45ZgvvfQSEhMTce211+KHH35osq1MJsPHH38MADAajSgsLMS3336Le++9FzfddBO++eYbREdH+yNs3hiNRoSEBPbXo7179+KVV17B1KlTERMTw3c4hDgE9l8OIQL1+uuvQ6FQYP/+/Q3e9MvLy/0ej16vR0REhN+PG6xMJhOkUinEYuF0+oaEhPj8y5LBYEB4eDikUqlPj9Oc0NBQXo/flPq/f29+GWzu77ugoAApKSmorKxEXFxck/sKCQnBgw8+6LTttddew5tvvom5c+fi4YcfxoYNG7wSdyDhOA4WiwVyuRxyuZzvcAgRLOF8KhIiIOfPn0e/fv1cfnmIj493+tlms2HRokXo3r07ZDIZUlJSMG/ePJjNZqd2jZUwpKSkOF3dXrNmDUQiEX7++Wc88cQTiI+Px1VXXeW4//vvv8eoUaMQFRWF6OhoDBkyBOvWrXPa559//okxY8ZAoVAgPDwco0aNwu+//97s87ZYLJg/fz4GDRoEhUKBiIgI3HDDDdi9e7dTu/oxAG+//TZWrVrleO5DhgzB/v37G+x3y5Yt6N+/P+RyOfr374+vv/662VjqX5vjx4/j559/dpR31Nf+V1dXY9asWRgwYAAiIyMRHR2NsWPH4siRI077qB8fsH79erz00ktISkpCeHg4NBoNAGDjxo3o27evU2yuavw5jsPy5cvRr18/yOVyJCQk4NFHH0VNTU2L4m1MbW0tpk6dCoVCgZiYGEyZMsVleYSrMRa7du3CyJEjERMTg8jISPTq1Qvz5s1zPO8hQ4YAAKZNm+aIp75M66abbkL//v1x8OBB3HjjjQgPD3c89soxFvXsdjvmzZuHxMRERERE4M4770RxcbFTmyvP53qX77O52Fy9/nq9Hs8//zySk5Mhk8nQq1cvvP3222CMObUTiUSYOXOm45yTyWTo168fduzY4dROq9Xi2WefRUpKCmQyGeLj4zF69GgcOnSoQez1pk6dilGjRgEA7rvvvga/359++gk33HADIiIiEBMTg/Hjx+PkyZNO+6j/PZ44cQL//Oc/0a5dO4wcObLRYwLwyniTOXPm4B//+Ac2btyIM2fONNv+1KlTuPfee9G+fXvI5XIMHjwYW7duddxfXl6OuLg43HTTTU6/g3PnziEiIgKZmZmObZefa8OHD0dYWBi6du2KlStXNjiu2WzGggUL0KNHD8hkMiQnJ+OFF15w+X46c+ZMfPHFF+jXrx9kMpnjd3zle239a37mzBk8+OCDUCgUiIuLw8svvwzGGIqLizF+/HhER0cjMTERS5cu9Tiups6/hQsXYvbs2QCArl27Os7/CxcuAGj675oQX6MeC0J8oEuXLsjNzcWxY8fQv3//JtvOmDEDa9euxb333ovnn38ef/75J7Kzs3Hy5MkWf4F25YknnkBcXBzmz58PvV4PoC7peOihh9CvXz/MnTsXMTExOHz4MHbs2IF//vOfAOq+3IwdOxaDBg3CggULIBaL8dlnn+GWW27Br7/+iqFDhzZ6TI1Gg48//hgTJ07Eww8/DK1Wi08++QQZGRnYt29fg5KVdevWQavV4tFHH4VIJMKSJUtw9913Iz8/33HVeefOnbjnnnvQt29fZGdno6qqCtOmTXNKlhqzfPlyPPXUU4iMjMS///1vAEBCQgIAID8/H1u2bMF9992Hrl27QqVS4aOPPsKoUaNw4sQJdOrUyWlfixYtglQqxaxZs2A2myGVSrFt2zZkZmZiwIAByM7ORk1NDaZPn46kpKQGsTz66KNYs2YNpk2bhqeffhoFBQVYsWIFDh8+jN9//x2hoaFNxusKYwzjx4/Hb7/9hsceewx9+vTB119/jSlTpjT72hw/fhy33347rrnmGrz66quQyWQ4d+6cI4Hs06cPXn31VcyfPx+PPPIIbrjhBgDA8OHDHfuoqqrC2LFj8cADD+DBBx9sMlagridPJBLhxRdfRHl5OZYvX4709HTk5eUhLCys2ZjrtSS2yzHGcOedd2L37t2YPn06UlNT8cMPP2D27NkoKSnBsmXLnNr/9ttv2Lx5M5544glERUXh//7v/3DPPfegqKgIHTp0AAA89thj2LRpE2bOnIm+ffuiqqoKv/32G06ePInrrrvOZRyPPvookpKS8MYbb+Dpp5/GkCFDHK/Zjz/+iLFjx6Jbt25YuHAhjEYj3nvvPYwYMQKHDh1qkBzcd9996NmzJ954440GyZGvTJo0CTt37sSuXbtw9dVXN9ru+PHjGDFiBJKSkjBnzhxERETgq6++woQJE/C///0Pd911F+Lj4/Hhhx/ivvvuw3vvvYenn34aHMdh6tSpiIqKwgcffOC0z5qaGtx22224//77MXHiRHz11Vd4/PHHIZVK8dBDDwGoS97vvPNO/Pbbb3jkkUfQp08fHD16FMuWLcOZM2ewZcsWp33+9NNP+OqrrzBz5kzExsY2m4BlZmaiT58+ePPNN7Ft2za89tpraN++PT766CPccsstWLx4Mb744gvMmjULQ4YMwY033tiquJo7/+6++26cOXMGX375JZYtW4bY2FgAQFxcXLN/14T4HCOEeN3OnTuZRCJhEomEDRs2jL3wwgvshx9+YBaLxaldXl4eA8BmzJjhtH3WrFkMAPvpp58c2wCwBQsWNDhWly5d2JQpUxw/f/bZZwwAGzlyJLPZbI7ttbW1LCoqiqWlpTGj0ei0D47jHP/27NmTZWRkOLYxxpjBYGBdu3Zlo0ePbvJ522w2ZjabnbbV1NSwhIQE9tBDDzm2FRQUMACsQ4cOrLq62rH9m2++YQDYt99+69iWmprKOnbsyGprax3bdu7cyQCwLl26NBkPY4z169ePjRo1qsF2k8nE7Ha707aCggImk8nYq6++6ti2e/duBoB169aNGQwGp/YDBgxgV111FdNqtY5te/bsaRDbr7/+ygCwL774wunxO3bsaLC9sXhd2bJlCwPAlixZ4thms9nYDTfcwACwzz77zLF9wYIF7PK3/GXLljEArKKiotH979+/v8F+6o0aNYoBYCtXrnR53+XPof41TEpKYhqNxrH9q6++YgDYu+++69h25fnc2D6bim3KlClOr3/96/Taa685tbv33nuZSCRi586dc2wDwKRSqdO2I0eOMADsvffec2xTKBTsySefbHDs5tS/Fhs3bnTanpqayuLj41lVVZXTccViMZs8ebJjW/3vceLEiW4fu6KiotH3EcbqXreIiIhGH3/48GEGgD333HNNHufWW29lAwYMYCaTybGN4zg2fPhw1rNnT6e2EydOZOHh4ezMmTPsrbfeYgDYli1bnNrUn2tLly51bDObzY7XrP599T//+Q8Ti8Xs119/dXr8ypUrGQD2+++/O7YBYGKxmB0/frxB/Fe+RvWv+SOPPOLYZrPZ2FVXXcVEIhF78803HdtrampYWFiY0znsblwtOf/qX6uCggKnfbbk75oQX6JSKEJ8YPTo0cjNzcWdd96JI0eOYMmSJcjIyEBSUpJTOcD27dsBAFlZWU6Pf/755wEA27Zta3UMDz/8MCQSiePnXbt2QavVYs6cOQ1qiOtLZPLy8nD27Fn885//RFVVFSorK1FZWQm9Xo9bb70Vv/zyS5ODYiUSiaO+nuM4VFdXw2azYfDgwS5LRDIzM9GuXTvHz/VXnvPz8wEAZWVlyMvLw5QpU6BQKBztRo8ejb59+7r7kjiRyWSOMRJ2ux1VVVWOsgFXsU6ZMsXpqnppaSmOHj2KyZMnIzIy0rF91KhRGDBggNNjN27cCIVCgdGjRzte08rKSgwaNAiRkZENSsVaavv27QgJCcHjjz/u2CaRSFo0OUB9md4333zT6oHOMpkM06ZNa3H7yZMnIyoqyvHzvffei44dOzr+Dnxl+/btkEgkePrpp522P//882CM4fvvv3fanp6eju7duzt+vuaaaxAdHe04L4G61+/PP/9EaWmpx/HVn+dTp05F+/btnY47evRol6/PY4895vFx3VV/nmu12kbbVFdX46effsL9998PrVbrONerqqqQkZGBs2fPoqSkxNF+xYoVUCgUuPfee/Hyyy9j0qRJGD9+fIP9hoSE4NFHH3X8LJVK8eijj6K8vBwHDx4EUPd31qdPH/Tu3dvp7+yWW24BgAZ/Z6NGjXLrfeTywfYSiQSDBw8GYwzTp093bI+JiUGvXr2czhV342rJ+dcYb/xdE+IJSiwI8ZEhQ4Zg8+bNqKmpwb59+zB37lxotVrce++9OHHiBACgsLAQYrEYPXr0cHpsYmIiYmJiUFhY2Orjd+3a1enn8+fPA0CTpVlnz54FUPclOi4uzun28ccfw2w2Q61WN3nctWvX4pprroFcLkeHDh0QFxeHbdu2uXxc586dnX6uTzLqxx3UP/+ePXs2eGyvXr2ajKM5HMdh2bJl6NmzJ2QyGWJjYxEXF4e//vrLZaxXvp71sV35u3O17ezZs1Cr1YiPj2/wuup0ulYP6C8sLETHjh2dEhugZa9NZmYmRowYgRkzZiAhIQEPPPAAvvrqK7e+jCQlJbk1UPvK36NIJEKPHj0cteG+UlhYiE6dOjklNUBdSVX9/Ze78rwE6s7Ny8fDLFmyBMeOHUNycjKGDh2KhQsXtuiLX2PxAa5/b3369HEk95e78nz0B51OBwANXsfLnTt3DowxvPzyyw3O9QULFgBwnsCiffv2+L//+z/89ddfUCgU+L//+z+X++3UqVODAer15Vj158/Zs2dx/PjxBsetb3fl35m7r+GV54VCoYBcLneUIl2+/fJzxd24WnL+NcYbf9eEeILGWBDiY1KpFEOGDMGQIUNw9dVXY9q0adi4caPjQxZAg0G17rDb7S63u1OzXq/+w+ett95qdHrRK7/EXu6///0vpk6digkTJmD27NmIj4+HRCJBdna2I7G53OU9KpdjfqgZf+ONN/Dyyy/joYcewqJFi9C+fXuIxWI8++yzLj+EW/N61uM4DvHx8fjiiy9c3t/cTD2+EBYWhl9++QW7d+/Gtm3bsGPHDmzYsAG33HILdu7c2ejv5sp9eFtjfwt2u71FMXlDS87L+++/HzfccAO+/vpr7Ny5E2+99RYWL16MzZs3Y+zYsT6P0RevfXOOHTsGwHUyXa/+b2fWrFnIyMhw2ebKx9dPgVtTU4OLFy+2esYsjuMwYMAAvPPOOy7vT05OdvrZ3dfQ1XnRknPF3bg8eV/0xt81IZ6gxIIQPxo8eDCAutIHoG6QN8dxOHv2rOPqKQCoVCrU1tY6LZzVrl27BrP9WCwWx76aU9+1fuzYsUa/GNS3iY6ORnp6esue1GU2bdqEbt26YfPmzU5fEC9PotxR//zre1Iud/r06Rbto7Evqps2bcLNN9+MTz75xGl7bW1tgyuQTcV27ty5Bvddua179+748ccfMWLEiGa/zLiTZHbp0gU5OTnQ6XROCV9LXxuxWIxbb70Vt956K9555x288cYb+Pe//43du3cjPT3d6yt1X/l7ZIzh3LlzuOaaaxzbXJ3nQN1V/W7dujl+dvd1+vHHH6HVap2utp86dcpxf2t07NgRTzzxBJ544gmUl5fjuuuuw+uvv+52YlF/fFe/t1OnTiE2NjYgpov+z3/+A5FIhNGjRzfapv53FBoa2qL3kB07duDjjz/GCy+8gC+++AJTpkzBn3/+2WBq5NLS0gbT6tbPTlU/6Lp79+44cuQIbr31Vr+tMt8Svoirqf0093dNiC9RKRQhPrB7926XV5fqa6XrSx5uu+02AHWzF12u/srWuHHjHNu6d++OX375xandqlWrGu2xuNI//vEPREVFITs7GyaTyem++lgHDRqE7t274+2333aUPVyuoqKiyWPUXw27/Ln/+eefyM3NbVGMV+rYsSNSU1Oxdu1ap/KkXbt2OcrJmhMREeHyi6pEImnwO9q4caNT/XdTOnXqhP79++Pzzz93eq1+/vlnHD161Knt/fffD7vdjkWLFjXYj81mc4qvsXhdue2222Cz2fDhhx86ttnt9hat7l5dXd1gW30vVf0UmPVf4ry1uu/nn3/uVJ+/adMmlJWVOX0R7969O/744w9YLBbHtu+++67BtLTuxHbbbbfBbrdjxYoVTtuXLVsGkUjkdiJgt9sblMvFx8ejU6dODaYPbYnLz/PLn8+xY8ewc+dOx/sEn958803s3LkTmZmZLksT68XHx+Omm27CRx995PKix+XvIbW1tZgxYwaGDh2KN954Ax9//DEOHTqEN954o8HjbDYbPvroI8fPFosFH330EeLi4jBo0CAAdX9nJSUlWL16dYPHG43GBuVk/uKLuBo7/1vyd02IL1GPBSE+8NRTT8FgMOCuu+5C7969YbFYsHfvXmzYsAEpKSmOAa8DBw7ElClTsGrVKtTW1mLUqFHYt28f1q5diwkTJuDmm2927HPGjBl47LHHcM8992D06NE4cuQIfvjhhxZdXQfqeiGWLVuGGTNmYMiQIY458I8cOQKDwYC1a9dCLBbj448/xtixY9GvXz9MmzYNSUlJKCkpwe7duxEdHY1vv/220WPcfvvt2Lx5M+666y6MGzcOBQUFWLlyJfr27esyUWmJ7OxsjBs3DiNHjsRDDz2E6upqvPfee+jXr1+L9jlo0CB8+OGHeO2119CjRw/Ex8fjlltuwe23345XX30V06ZNw/Dhw3H06FF88cUXTlfFm/PGG29g/PjxGDFiBKZNm4aamhqsWLEC/fv3d4pt1KhRePTRR5GdnY28vDz84x//QGhoKM6ePYuNGzfi3Xffxb333ttkvK7ccccdGDFiBObMmYMLFy6gb9++2Lx5c7PjYADg1VdfxS+//IJx48ahS5cuKC8vxwcffICrrrrKsS5C9+7dERMTg5UrVyIqKgoRERFIS0trdX1/+/btMXLkSEybNg0qlQrLly9Hjx498PDDDzvazJgxA5s2bcKYMWNw//334/z58/jvf//rNJjV3djuuOMO3Hzzzfj3v/+NCxcuYODAgdi5cye++eYbPPvssw323RytVourrroK9957LwYOHIjIyEj8+OOP2L9/v8s1DFrirbfewtixYzFs2DBMnz7dMd2sQqFwuX6NO/7zn/+gsLAQBoMBAPDLL7/gtddeA1A3hezlPTY2mw3//e9/AdQtBFlYWIitW7fir7/+ws0334xVq1Y1e7z3338fI0eOxIABA/Dwww+jW7duUKlUyM3NxcWLFx1rxTzzzDOoqqrCjz/+CIlEgjFjxmDGjBl47bXXMH78eAwcONCxz06dOmHx4sW4cOECrr76amzYsAF5eXlYtWqVY2rqSZMm4auvvsJjjz2G3bt3Y8SIEbDb7Th16hS++uor/PDDD45eY3/yRVz1ydS///1vPPDAAwgNDcUdd9zRor9rQnyKn8moCAlu33//PXvooYdY7969WWRkJJNKpaxHjx7sqaeeYiqVyqmt1Wplr7zyCuvatSsLDQ1lycnJbO7cuU5TNTLGmN1uZy+++CKLjY1l4eHhLCMjg507d67R6Wb379/vMratW7ey4cOHs7CwMBYdHc2GDh3KvvzyS6c2hw8fZnfffTfr0KEDk8lkrEuXLuz+++9nOTk5TT5vjuPYG2+8wbp06cJkMhm79tpr2Xfffddg+s/66WbfeuutBvuAi+kw//e//7E+ffowmUzG+vbtyzZv3txgn41RKpVs3LhxLCoqigFwTFlqMpnY888/zzp27MjCwsLYiBEjWG5ubqNTpV45PWi99evXs969ezOZTMb69+/Ptm7dyu655x7Wu3fvBm1XrVrFBg0axMLCwlhUVBQbMGAAe+GFF1hpaWmz8TamqqqKTZo0iUVHRzOFQsEmTZrkmBa0qelmc3Jy2Pjx41mnTp2YVCplnTp1YhMnTmRnzpxx2v8333zD+vbty0JCQpz2OWrUKNavXz+XMTX2Gn755Zds7ty5LD4+noWFhbFx48axwsLCBo9funQpS0pKYjKZjI0YMYIdOHCgwT6bis3VuaHVatlzzz3HOnXqxEJDQ1nPnj3ZW2+95TStMmN155+raWQv/zszm81s9uzZbODAgSwqKopFRESwgQMHsg8++MDl63G5ps6nH3/8kY0YMcLxt3nHHXewEydOOLWp/z26M51o/XStrm67d+92tJsyZYrTfeHh4SwlJYXdc889bNOmTQ2mZ27K+fPn2eTJk1liYiILDQ1lSUlJ7Pbbb2ebNm1ijF2aWvryKWQZY0yj0bAuXbqwgQMHOqaRrT/XDhw4wIYNG8bkcjnr0qULW7FiRYPjWiwWtnjxYtavXz8mk8lYu3bt2KBBg9grr7zC1Gq1o11jv+f6+1xNN3vla97Y9Lyu/jY8jcvVNMyLFi1iSUlJTCwWO6aebenfNSG+ImLMTyvrEEJIG5Gamoq4uDjs2rWL71AIEbybbroJlZWVjsHjhJDARWMsCCGklaxWK2w2m9O2PXv24MiRI7jpppv4CYoQQgjhCY2xIISQViopKUF6ejoefPBBdOrUCadOncLKlSuRmJjIywJmhBBCCJ8osSCEkFZq164dBg0ahI8//hgVFRWIiIjAuHHj8Oabb6JDhw58h0cIIYT4FY2xIIQQQgghhHiMxlgQQgghhBBCPEaJBSGEEEIIIcRjNMbCBY7jUFpaiqioKIhEIr7DIYQQQgghhBeMMWi1WnTq1AlicdN9EpRYuFBaWork5GS+wyCEEEIIISQgFBcX46qrrmqyDSUWLkRFRQGoewGjo6N5joYQQgghhBB+aDQaJCcnO74fN4USCxfqy5+io6MpsSCEEEIIIW1eS4YH0OBtQgghhBBCiMcosSCEEEIIIYR4jBILQgghhBBCiMdojAUhhBBCCPEqu90Oq9XKdxikBUJDQyGRSLyyL0osCCGEEEKIVzDGoFQqUVtby3coxA0xMTFITEz0eP02SiwIIYQQQohX1CcV8fHxCA8Pp4WGAxxjDAaDAeXl5QCAjh07erQ/SiwIIYQQQojH7Ha7I6no0KED3+GQFgoLCwMAlJeXIz4+3qOyKBq8TQghhBBCPFY/piI8PJznSIi76n9nno6L4T2xeP/995GSkgK5XI60tDTs27evyfYbN25E7969IZfLMWDAAGzfvt3pfp1Oh5kzZ+Kqq65CWFgY+vbti5UrV/ryKRASNBhjfIdACCFE4Kj8SXi89TvjNbHYsGEDsrKysGDBAhw6dAgDBw5ERkaGo87rSnv37sXEiRMxffp0HD58GBMmTMCECRNw7NgxR5usrCzs2LED//3vf3Hy5Ek8++yzmDlzJrZu3eqvp0WIIOnNNnyTV4rDRTUwWe18h0MIaYusJqDoT0BfxXckhJBWEDEeL1GmpaVhyJAhWLFiBQCA4zgkJyfjqaeewpw5cxq0z8zMhF6vx3fffefYdv311yM1NdXRK9G/f39kZmbi5ZdfdrQZNGgQxo4di9dee61FcWk0GigUCqjVakRHR3vyFAkRjJ9OqaBUmwEAIWIReiZEok/HaMhDvTMFHSGENCv/Z6A6v+7/EbFAbC+gfTdAQkNChcBkMqGgoABdu3aFXC7nOxzihqZ+d+58L+btL9ViseDgwYOYO3euY5tYLEZ6ejpyc3NdPiY3NxdZWVlO2zIyMrBlyxbHz8OHD8fWrVvx0EMPoVOnTtizZw/OnDmDZcuWNRqL2WyG2Wx2/KzRaFr5rAgRpjMqrSOpAAAbx3CyTIuzKh0lGIQQ/6gtvpRUAIC+su52cX9dchHXCwhvz198xCPr/izy6/H+mdbZ7cdUVFRg/vz52LZtG1QqFdq1a4eBAwdi/vz5GDFihA+iDD68JRaVlZWw2+1ISEhw2p6QkIBTp065fIxSqXTZXqlUOn5+77338Mgjj+Cqq65CSEgIxGIxVq9ejRtvvLHRWLKzs/HKK6948GwIES6NyYq8olqX91GCQQjxC7sVKPqjkfssQMWpultEHBB7NfViEJ+45557YLFYsHbtWnTr1g0qlQo5OTmoqvJtaZ7FYoFUKvXpMerZbDaEhPjub4f3wdve9t577+GPP/7A1q1bcfDgQSxduhRPPvkkfvzxx0YfM3fuXKjVasetuLjYjxETwh/GGHLPV8HGNV0RWZ9gbD1SirziWhqDQQjxrosHAIuu+Xb6CqDwd+CvDUBhLmCo9n1spE2ora3Fr7/+isWLF+Pmm29Gly5dMHToUMydOxd33nkngLoKl6effhrx8fGQy+UYOXIk9u/f77SflJQULF++3GlbamoqFi5c6Pj5pptuwsyZM/Hss88iNjYWGRkZ4DgOS5YsQY8ePSCTydC5c2e8/vrrjsdwHIfs7Gx07doVYWFhGDhwIDZt2tTkc7pw4QJEIhG++uor3HDDDZDJZD4fc8xbuh8bGwuJRAKVSuW0XaVSITEx0eVjEhMTm2xvNBoxb948fP311xg3bhwA4JprrkFeXh7efvttpKenu9yvTCaDTCbz9CkRIjjHSzWo0lla3N5mZzhRqsEZlRZXJ0Shd2IU9WAQQjyjVdX1Rrjjyl6M5DQgMs438ZE2ITIyEpGRkdiyZQuuv/56l98LX3jhBfzvf//D2rVr0aVLFyxZsgQZGRk4d+4c2rd3r0xv7dq1ePzxx/H7778DqLvIvXr1aixbtgwjR45EWVmZUwVPdnY2/vvf/2LlypXo2bMnfvnlFzz44IOIi4vDqFGjXB7jyJEjAIC33noLb7zxBrp27Yq4ON/+nfDWYyGVSjFo0CDk5OQ4tnEch5ycHAwbNszlY4YNG+bUHgB27drlaG+1WmG1WiEWOz8tiUQCjuO8/AwIEbYavQXHStStemx9grH1SCnOlWu9HBkhpM3g7HU9EJ7QVwAlB70TD2mzQkJCsGbNGqxduxYxMTEYMWIE5s2bh7/++gsAoNfr8eGHH+Ktt97C2LFj0bdvX6xevRphYWH45JNP3D5ez549sWTJEvTq1QudOnXCu+++iyVLlmDKlCno3r07Ro4ciRkzZgCo6yl544038OmnnyIjIwPdunXD1KlT8eCDD+Kjjz5q9Bh5eXmIiIjAxo0bMXr0aPTo0QMKhaJ1L1AL8VqgmJWVhSlTpmDw4MEYOnQoli9fDr1ej2nTpgEAJk+ejKSkJGRnZwMAnnnmGYwaNQpLly7FuHHjsH79ehw4cACrVq0CAERHR2PUqFGYPXs2wsLC0KVLF/z888/4/PPP8c477/D2PAkJNBzHkJtfhWYqoJplszPkFauR0iECIZKgq6wkhPhaWR5gat0FDifasropaiNotWfSevfccw/GjRuHX3/9FX/88Qe+//57LFmyBB9//DGuu+46WK1Wp0HcoaGhGDp0KE6ePOn2sQYNGuT4/8mTJ2E2m3Hrrbe6bHvu3DkYDAaMHj3aabvFYsG1117b6DGOHDmCO++8EykpKW7H11q8JhaZmZmOEfhKpRKpqanYsWOHY4B2UVGRU+/D8OHDsW7dOrz00kuYN28eevbsiS1btqB///6ONuvXr8fcuXPxr3/9C9XV1ejSpQtef/11PPbYY35/foQEqiMXa1Fr8Gx1zXoWG4eCSj16JkR5ZX+EkDbCUA0ojzXfrqVUx4BurktCCGkpuVyO0aNHY/To0Xj55ZcxY8YMLFiwAN9++22LHi8WixssNutqNeuIiAjH/8PCwprcp05XN/5o27ZtSEpKcrqvqVL+vLw8l8s3+BLvUyrMnDkTM2fOdHnfnj17Gmy77777cN999zW6v8TERHz22WfeCo+QoFOuNeGU0rvlS6dVWkosCCEtxxhw4TeAebFMueYCYB4EyCK9t0/S5vXt2xdbtmxB9+7dIZVK8fvvv6NLly4A6hKG/fv349lnn3W0j4uLQ1lZmeNnjUaDgoKCJo/Rs2dPhIWFIScnx1H+dGUMMpkMRUVFjY6nuJJGo8GFCxea7NHwBd4TC0KI/9jsHP7Ir4a3l8XUGG0oUxvRUdH0VRdCCAFQ17tg8PIUnowDyk8CyUO8u1/SJlRVVeG+++7DQw89hGuuuQZRUVE4cOAAlixZgvHjxyMiIgKPP/44Zs+ejfbt26Nz585YsmQJDAYDpk+f7tjPLbfcgjVr1uCOO+5ATEwM5s+fD4mk6UlO5HI5XnzxRbzwwguQSqUYMWIEKioqcPz4cUyfPh1RUVGYNWsWnnvuOXAch5EjR0KtVuP3339HdHQ0pkyZ0mCfR44cgUQiwYABA7z+WjWFEgtC2pDDxbXQmWw+2fcppZYSC0JI80waoDTPN/uuPAN0HAiE+GdNABI8IiMjkZaWhmXLluH8+fOwWq1ITk7Gww8/jHnz5gEA3nzzTXAch0mTJkGr1WLw4MH44Ycf0K5dO8d+5s6di4KCAtx+++1QKBRYtGhRsz0WAPDyyy8jJCQE8+fPR2lpKTp27OhUxr9o0SLExcUhOzsb+fn5iImJwXXXXeeI7UpHjhxBr169/L4CuohdWQhG3Fq6nBChKFMbsftUhU+PMe6ajlCEhfr0GIQQgTu9o26wta9cNQRI7N98O+J1JpMJBQUF6Nq1q9+/0BLPNPW7c+d7MU3jQkgbYLbZ8We+7xeSOqOiqWcJIU2oOOPbpAKoK4eiKeYJ4QUlFoS0AQcv1MBg8f1q2QUVelhs9IFOCHHBYgAu7m++ncfH0QE1zZeeEEK8jxILQoJcUZUBF6oMfjmWjWM4V67zy7EIIQJTlFu3YrY/qI775ziEECeUWBASxIwWO/Zf8H0J1OXOlmsbzOFNCGnjqguA2iL/Hc9QBWh8XHJFCGmAEgtCgtifBVUw+7k0SW+2o7ja6NdjEkICmM0MFP/p/+NSrwUhfkeJBSFB6ly5DqW1Jl6OfZoGcRNC6hXvA6w8XGxQFwPGWv8fl5A2jBILQoIQxzEcKa7l7fgVWjOq9X6qpSaEBC51CVB1jr/jU68FIX5FiQUhQUilNfm9BOpKp5QaXo9PCOEZx9UN2OZT9Xl+eksIaaMosSAkCBX6aRaophRVGWD0wxS3hJAApSkBzDyXRXL2unUtCCF+QYkFIUGG4xiKq/lPLDhWN0MUIaSNCpS1JCpOA3Yb31EQ0iZQYkFIkCnTmGC1B8Z0r+fKdbBzgRELIcSPOA6oLeY7ijo2U11JFCHE50L4DoAQ4l2FVXq+Q3AwWTkUVunRLS6S71AIIf6kKfHfYngtoToOxPXiO4q27cBn/j3e4GluNReJRE3ev2DBAixcuNCDgDx30003ITU1FcuXL+c1jqZQYkFIELFzDCU1gTVQ8bRSS4kFIW1NzQW+I3BmUtct0BfTme9ISIAqK7u0oOKGDRswf/58nD592rEtMtL9zzGLxQKpVOqV+ISCSqEICSKltcaAKYOqV2OwolzDz3oahBAecJx/V9luKZp6ljQhMTHRcVMoFBCJRE7bIiMjsWPHDowcORIxMTHo0KEDbr/9dpw/f6nM7qabbsLMmTPx7LPPIjY2FhkZGQAArVaLf/3rX4iIiEDHjh2xbNky3HTTTXj22Wcdj+U4DtnZ2ejatSvCwsIwcOBAbNq0yXH/1KlT8fPPP+Pdd9+FSCSCSCTChQsXXD6XoqIiTJkyBQkJCY59/fbbbz553a4UEInF+++/j5SUFMjlcqSlpWHfvn1Ntt+4cSN69+4NuVyOAQMGYPv27U7317/gV97eeustXz4NQngXCIO2XTmlpEHchLQZgVYGVU+rBPSVfEdBBEyv1yMrKwsHDhxATk4OxGIx7rrrLnDcpend165dC6lUit9//x0rV64EAGRlZeH333/H1q1bsWvXLvz66684dOiQ076zs7Px+eefY+XKlTh+/Diee+45PPjgg/j5558BAO+++y6GDRuGhx9+GGVlZSgrK0NycnKDGAsLCzF06FAYjUZs3boVf/31F2bOnIno6GgfvjKX8F4KtWHDBmRlZWHlypVIS0vD8uXLkZGRgdOnTyM+Pr5B+71792LixInIzs7G7bffjnXr1mHChAk4dOgQ+vfvD8C5OwsAvv/+e0yfPh333HOPX54TIXyw2TlcrA2sMqh6JbVG6Mw2RMp4f8shhPhaoJVBXU51DOh2E99REIG68nvkp59+iri4OJw4ccLxHbRnz55YsmSJo41Wq8XatWuxbt063HrrrQCAzz77DJ06dXK0MZvNeOONN/Djjz9i2LBhAIBu3brht99+w0cffYRRo0ZBoVBAKpUiPDwciYmJjcb4+OOP4/rrr8dXX33l2NazZ0/Pn3wL8d5j8c477+Dhhx/GtGnT0LdvX6xcuRLh4eH49NNPXbZ/9913MWbMGMyePRt9+vTBokWLcN1112HFihWONpd3XSUmJuKbb77BzTffjG7duvnraRHid6W1JtgCrAyqHmPAGRX1WhAS9AK1DKpezQXArOM7CiJQZ8+excSJE9GtWzdER0cjJSUFQF3pUb1BgwY5PSY/Px9WqxVDhw51bFMoFOjV69JkAufOnYPBYMDo0aMRGRnpuH3++edOpVbNKSwsxPfff8/rIHNeLx9aLBYcPHgQc+fOdWwTi8VIT09Hbq7r1Tpzc3ORlZXltC0jIwNbtmxx2V6lUmHbtm1Yu3at1+ImJBAVBWgZVL3z5ToMSFIgVML79QxCiK8EahlUPcaA8hNA8tDm2xJyhTvuuANdunTB6tWr0alTJ3Ach/79+8NiuXTOR0REuL1fna4u2d22bRuSkpKc7pPJZC3eT15eHqRSKVJTU92OwVt4TSwqKytht9uRkJDgtD0hIQGnTp1y+RilUumyvVKpdNl+7dq1iIqKwt13391oHGazGWaz2fGzRqNp6VMgJCDY7BxKA7QMqp7VzlBQqcfVCVF8h0II8ZVALoOqV3kG6JgKhLSt2XqIZ6qqqnD69GmsXr0aN9xwAwC0aEB0t27dEBoaiv3796Nz57pZydRqNc6cOYMbb7wRANC3b1/IZDIUFRVh1KhRje5LKpXCbrc3en9oaChsNhsMBgPCw8PdeXpeE/QFz59++in+9a9/QS6XN9omOzsbr7zyih+jIsS7SmqNsAlgIbrTSi0lFoQEq0Avg6pntwKVp4HEAXxHQgSkXbt26NChA1atWoWOHTuiqKgIc+bMafZxUVFRmDJlCmbPno327dsjPj4eCxYsgFgsdqydERUVhVmzZuG5554Dx3EYOXIk1Go1fv/9d0RHR2PKlCkAgJSUFPz555+4cOECIiMj0b59e4jFl6oA0tLSoFAo8Pjjj2POnDlgjOGXX37Brbfe6rdxFrzWJMTGxkIikUClUjltV6lUjQ5MSUxMbHH7X3/9FadPn8aMGTOajGPu3LlQq9WOW3FxgKwWSkgLFVYFdhlUPa3JhpIA71khhLRSoJdBXa78ZF0iREgLicVirF+/HgcPHkT//v3x3HPPtXi20XfeeQfDhg3D7bffjvT0dIwYMQJ9+vRxuui9aNEivPzyy8jOzkafPn0wZswYbNu2DV27dnW0mTVrFiQSCfr27Yu4uDinsR0A0KFDB3z77bc4e/YshgwZgpEjR2Lr1q0uJ0PyFRFjjNfLnGlpaRg6dCjee+89AHXz+Hbu3BkzZ850mQlmZmbCYDDg22+/dWwbPnw4rrnmGse0XvWmTp2KY8eO4cCBA27FpNFooFAooFar/TY9FyGtZbFx+PrwRdgF8hmZqJDhlt4JzTckhAhLwa9A1Tm+o2i5rjcCHbrzHUVQMZlMKCgoQNeuXZusFGnr9Ho9kpKSsHTpUkyfPp3vcAA0/btz53sx76VQWVlZmDJlCgYPHoyhQ4di+fLl0Ov1mDatbin2yZMnIykpCdnZ2QCAZ555BqNGjcLSpUsxbtw4rF+/HgcOHMCqVauc9qvRaLBx40YsXbrU78+JEH8qqTUKJqkAAKXaDLXBCkV4KN+hEEK8RShlUJerLqDEgvjF4cOHcerUKQwdOhRqtRqvvvoqAGD8+PE8R+Z9vCcWmZmZqKiowPz586FUKpGamoodO3Y4BmgXFRU51Y8NHz4c69atw0svvYR58+ahZ8+e2LJli2P+4Hrr168HYwwTJ0706/MhxN8Kq/R8h+C20yothnZtz3cYhBBvEVIZVD1tGWC3ARLevwqRNuDtt9/G6dOnIZVKMWjQIPz666+IjY3lOyyv470UKhBRKRQRCouNw+ZDFyGAcdtOZCFi3H1dkmPgGiFE4IRWBlWvRzoQ03D1YtI6VAolXN4qhaIJ5QkRsOIag+CSCgAw2zhU6gR2dZMQ4poQy6DqqWmyFkK8iRILQgSsSCCzQbkS6OtuEEJaSIhlUPXUF/mOgJCgQokFIQJlstqh0pj4DqPVaNpZQoKEEBbFa4xFD+ir+I6CkKBBiQUhAnWxxijIMqh6tQYr9GYb32EQQjwh5DKoelQORYjXUGJBiEAVVQtvNqgrUTkUIQIn5DKoepRYeB1Hiw8Kjrd+ZzTHGiECVFcGZeY7DI+V1BrRMyGK7zAIIa0l5DKoevpKwGIApOF8RyJ4UqkUYrEYpaWliIuLg1Qqpdn/AhxjDBaLBRUVFRCLxZBKpR7tjxILQgSouNqAYJgoulxjhs3OIURCnaeECA7HBc/VfvVFIO5qvqMQPLFYjK5du6KsrAylpaV8h0PcEB4ejs6dOzutHdcalFgQIkCFAp4N6nI2jkGlNSMpJozvUAgh7tKUADbh95wCqEuQKLHwCqlUis6dO8Nms8Fut/MdDmkBiUSCkJAQr/QuUWJBiMAYLXZU6ILkwxx14ywosSBEgIKhDKqephTg7IBYwnckQUEkEiE0NBShoaF8h0L8jOoPCBGYoiApg6pHA7gJEaBgKoMCAM4GaMv4joIQwaPEghCBKawS/mxQl9Ob7ag1CHxWGULaGm1p8JRB1asNokSJEJ5QYkGIgBgsNlTqgu9L+MUa6rUgRFCCqQyqHq3CTYjHKLEgRECCZdD2lagcihABCYZF8Vyx6ABDNd9RECJolFgQIiBF1cGZWFTpLTBZafYQQgQhGMug6gXTuBFCeECJBSECoTPbUBWEZVAAwBhQpjbxHQYhpCWCsQyqHpVDEeIRSiwIEYiiIC2DqldC4ywICXzBWgZVT18BWOkiByGtRYkFIQJRVB1cs0FdqUxtBMcF0Ty6hASjYC6DAuq6T6nXgpBWo8SCEAHQmqyo1lv5DsOnrHYWVAv/ERKUgrkMqp46iHtkCPEx3hOL999/HykpKZDL5UhLS8O+ffuabL9x40b07t0bcrkcAwYMwPbt2xu0OXnyJO68804oFApERERgyJAhKCqiNwoiXME6G9SVSmh2KEICV7CXQdXTlNY9V0KI23hNLDZs2ICsrCwsWLAAhw4dwsCBA5GRkYHy8nKX7ffu3YuJEydi+vTpOHz4MCZMmIAJEybg2LFjjjbnz5/HyJEj0bt3b+zZswd//fUXXn75Zcjlcn89LUK8Llhng7oSTTtLSAAL9jKoenYrrcJNSCuJGGO8FTWnpaVhyJAhWLFiBQCA4zgkJyfjqaeewpw5cxq0z8zMhF6vx3fffefYdv311yM1NRUrV64EADzwwAMIDQ3Ff/7zn1bHpdFooFAooFarER0d3er9EOINOrMNW/NK+Q7Db+4Y2BFR8lC+wyCEXOnCb0DlWb6j8I/4vkDnNL6jICQguPO9mLceC4vFgoMHDyI9Pf1SMGIx0tPTkZub6/Ixubm5Tu0BICMjw9Ge4zhs27YNV199NTIyMhAfH4+0tDRs2bKlyVjMZjM0Go3TjZBAUdbGruKX1tKMLIQEnLZSBlWP1rMgpFV4SywqKytht9uRkJDgtD0hIQFKpdLlY5RKZZPty8vLodPp8Oabb2LMmDHYuXMn7rrrLtx99934+eefG40lOzsbCoXCcUtOTvbw2RHiPW1tfYeS2rZR9kWIoOiUbaMMqp5ZCxhr+I6CEMHhffC2N3F/D7YaP348nnvuOaSmpmLOnDm4/fbbHaVSrsydOxdqtdpxKy6mKxUkMHAcg0rTthKLco0ZVjsNnCQkoKhL+I7A/2jaWULcxltiERsbC4lEApVK5bRdpVIhMTHR5WMSExObbB8bG4uQkBD07dvXqU2fPn2anBVKJpMhOjra6UZIIKjUmWG1t621HTgGKNtYLw0hAU/TBhOLWrrISIi7eEsspFIpBg0ahJycHMc2juOQk5ODYcOGuXzMsGHDnNoDwK5duxztpVIphgwZgtOnTzu1OXPmDLp06eLlZ0CI75W20S/YNO0sIQHEom+bZUH68rZV/kWIF4TwefCsrCxMmTIFgwcPxtChQ7F8+XLo9XpMmzYNADB58mQkJSUhOzsbAPDMM89g1KhRWLp0KcaNG4f169fjwIEDWLVqlWOfs2fPRmZmJm688UbcfPPN2LFjB7799lvs2bOHj6dIiEeU6rb5BZumnSUkgGjazqx0TupX4e7Qne9ICBEMXhOLzMxMVFRUYP78+VAqlUhNTcWOHTscA7SLioogFl/qVBk+fDjWrVuHl156CfPmzUPPnj2xZcsW9O/f39HmrrvuwsqVK5GdnY2nn34avXr1wv/+9z+MHDnS78+PEE+YrPagX227MSYrhyqdGR0iZXyHQghpy2MN1MWUWBDiBl7XsQhUtI4FCQQFlXrknq/iOwze9E+KxjVXxfAdBiFtG2PAkS/bbkmQRAoMnAiIg2quG0LcIoh1LAghTStro2VQ9agcipAAoK9su0kFANgtgE7VfDtCCABKLAgJWG19ZqRqvRVGi53vMAhp2zRtuAyqXlsuBSPETZRYEBKAqvUWmKy0lgPNDkUIz9rqwO3LqdvQiuOEeIgSC0ICEJUB1aHXgRAe2SyAvoLvKPhn0gAmNd9RECIIlFgQEoDaehlUPaXGBDtH80sQwgtNSd3gbUKL5RHSQpRYEBJgLDYOlbo2PFjyMjY7Q7mWkixCeEFlUJfQOAtCWoQSC0ICjEpjAl2kv4TKoQjhiaaE7wgCh05VVxpGCGkSJRaEBJgyKoNyUlJLrwchfmesASx6vqMIHIyjGbIIaQFKLAgJMG19/Yor6Uw2qA1tcwVyQnijpt6KBmicBSHNosSCkACiNlqhN9PaDVeiaWcJ8TMaX9EQDWYnpFkhfAdACLmEZoNqKMSqQ21RMWCyAgkDgIgOfIdESHCz2wCdku8oAgbHALXJimqdHmeMRzCoX2/ER8v5DouQgESJBSEBhMqgALHdjDBzOcJNKoSbyhFi00MsAqziGIRqSoGrxwDh7fkOk5DgpVMCXNvuOb2UTFhQY7A4pr0Wiy9iz+kOuPHqOCQqKLkg5EqUWBASIOwcQ7mm7U0zK2J2yM2VCDeVI8ykgsxa26DcgGOAxmhDB7EZOLuzLrkIi+ElXkKCXhstg2osmbhcuFGJKu4a/HymHDf0jEOnmDAeIiUkcFFiQUiAqNCaYWsj88xKbAZEGYoQbiqH3FwJEWv+6mitwYIOEVLAagTO/AD0GgvIo/0QLSFtTBtas6ElycTlpFY1JDYD7CHh+PVsBUb0iMVV7cL9FC0hgY8GbxMSIErbShkUY+hYlYsOtUcRZlK1KKkA6ga2Oz7yrYa65MKs9VmYhLRJZh1gUvMdhV+U1BpxqKgGZ5RaVOrMzSYV9cJNKgCAnQN+O1uJ4mqDL8MkRFAosSAkQLSVgdvtNKcgM1e7/TirnUFntl3aYNHVJRc01z4h3tNGyqA0Jisu1hhbnExcLtx0aWA7x4DfzlXiQiW9DxECUGJBSEAwWGyobQNrNUgtarTXnGj149XGK14jsxY4swOw0BVDQryiDSwCZ2cMBZWtf88IN5U7jQNjDMjNr0J+hc4b4REiaAGRWLz//vtISUmBXC5HWloa9u3b12T7jRs3onfv3pDL5RgwYAC2b9/udP/UqVMhEomcbmPGjPHlUyDEI21itW3GIaF6X90Ktq3UILEAAJMGOPsDYG0DryEhvsQYoCnjOwqfu1hjhMna+lmvxJwFckuV0zbGgD/yq3GunJIL0rbxnlhs2LABWVlZWLBgAQ4dOoSBAwciIyMD5eXlLtvv3bsXEydOxPTp03H48GFMmDABEyZMwLFjx5zajRkzBmVlZY7bl19+6Y+nQ0irlNUG/5fi9pqTkFpqPdqH3myDlXORmBhr65ILW9ubVYsQr9GVA3YL31H4lNZsg0rj+ftt2N/jLK60r6AaZ1Q09ou0XbwnFu+88w4efvhhTJs2DX379sXKlSsRHh6OTz/91GX7d999F2PGjMHs2bPRp08fLFq0CNdddx1WrFjh1E4mkyExMdFxa9eunT+eDiFuY4xB6YUPukAms9SgneaUx/thf08765Khum4qWltwfzEixGc0JXxH4FMcAwoq9V5ZPDvC1PgCggcu1OBkmcbzgxAiQLwmFhaLBQcPHkR6erpjm1gsRnp6OnJzc10+Jjc316k9AGRkZDRov2fPHsTHx6NXr154/PHHUVXl3G1JSKCo1FlgsbW+PCjgMQ7xVfs9KoG6nMtyqHr6SuDcLsAe/ONVCPG6IB+4Xao2wmjxzsJ/MksNxE307hwuqsXx0rYxuxYhl+M1saisrITdbkdCQoLT9oSEBCiVrq8GKJXKZtuPGTMGn3/+OXJycrB48WL8/PPPGDt2LOx2128oZrMZGo3G6UaIvwT7bFAd1MchtXrvA1ZtbKZHQlcOnPsRsDfSs0EIachqAgyVfEfhMwaLHWW1XpzSmzHHtLONOVKsxl8Xa713TEIEICgXyHvggQcc/x8wYACuueYadO/eHXv27MGtt97aoH12djZeeeUVf4ZIiEMwr18hM1chRnvaq/u02BgMFjvCpZLGG2mVwPmfgB63AuIm2hFC6mhLG6x4HywYgPxKHby9/mi4SQldRHKTbY6VaMAYMDA5xrsHJyRA8dpjERsbC4lEApXKOetXqVRITEx0+ZjExES32gNAt27dEBsbi3Pnzrm8f+7cuVCr1Y5bcXGxm8+EkNYx2+yo1gfnmAARZ0NC9X6ffFmpbaocqp6mBDi/G3A12JsQ4kwdvOMrytRG6M3eKYG6XHgT4ywud7xUE7Tv84RcidfEQiqVYtCgQcjJyXFs4zgOOTk5GDZsmMvHDBs2zKk9AOzatavR9gBw8eJFVFVVoWPHji7vl8lkiI6OdroR4g8qtTlYLxKig/oYQq2+mR2l2XIoR8NioPq8T2IgJKgE6fgKo82OEm+WQF1GYjdBamlZmSeVRJG2olWJRX5+vtcCyMrKwurVq7F27VqcPHkSjz/+OPR6PaZNmwYAmDx5MubOneto/8wzz2DHjh1YunQpTp06hYULF+LAgQOYOXMmAECn02H27Nn4448/cOHCBeTk5GD8+PHo0aMHMjIyvBY3Id4QrGVQcnMlFDrXPYTeoDPZYG9pRlZ2hHotCGmKoRqwBucikwUVep/++be016K01oQqHU2HTYJfqxKLHj164Oabb8Z///tfmEyeDTzNzMzE22+/jfnz5yM1NRV5eXnYsWOHY4B2UVERysouLdgzfPhwrFu3DqtWrcLAgQOxadMmbNmyBf379wcASCQS/PXXX7jzzjtx9dVXY/r06Rg0aBB+/fVXyGQyj2IlxNuCceC2iLMhoco3JVD1OAZoTC0cnG3WUq8FIU0J0mlmlRoTtC19n2illiYWAPBXCc0SRYKfiDH3P/3z8vLw2Wef4csvv4TFYkFmZiamT5+OoUOH+iJGv9NoNFAoFFCr1VQWRXym1mDB9qMt/1ASirjqQ4jW+f6LfEK0DCkdIlrWWBYF9LsbEPO+dA8hgefMD0FXCmWycThWoobd2yO2ryQSIz9pPJi4ZXPhjO6bgLgoushJhMWd78Wt+pRNTU3Fu+++i9LSUnz66acoKyvDyJEj0b9/f7zzzjuoqKhoVeCEtCVlQdhbEWZS+SWpAAB1YwvluUK9FoS4ZrcBuqanTRWiC5V63ycVAMA4hJlb/p3naEmt72IhJAB4dPkuJCQEd999NzZu3IjFixfj3LlzmDVrFpKTkzF58mSnEiZCiLOyIBtfIeasiK8+4Lfjmax2mGxuzPRCYy0IaUhbBnDenzGJT+U6c9MLaXqZO+VQSrUZ5Zrgu6hESD2PEosDBw7giSeeQMeOHfHOO+9g1qxZOH/+PHbt2oXS0lKMHz/eW3ESElRsdg4V2uAayNeh9ghCbP4dAOrWlwezFqj23sQThASFICuBstg5FFX5930owuheSetfF2msBQlerVog75133sFnn32G06dP47bbbsPnn3+O2267DeK/65e7du2KNWvWICUlxZuxEhI0VFoz7EF08TzcWIZoXYHfj1trsCIhSt7yB5TlAe270VgLQuppLvIdgVddqDL4pwTqMiE2HUKsOthCI1vUvlxrhkpjQkK0G+9dhAhEqz5dP/zwQ/zzn/9EYWEhtmzZgttvv92RVNSLj4/HJ5984pUgCQk2yiAqgxLbLX4tgbqc1mRzbzVd6rUg5BKzFjBp+I7Ca6r0FtTwtBBdhBvlUAD1WpDg1aoei127dqFz584NkgnGGIqLi9G5c2dIpVJMmTLFK0ESEmxKa4Onxra95gQkdn6ej51j0JqsUISFtvxB1GtBSJ0gWm3baudQWKXn7fjhJhXUUT1a3L5Ca0aZ2oiOijAfRkWI/7Xqk7V79+6orKxssL26uhpdu3b1OChCgpnObPP53Or+IrZbeCmBulytu4M0qdeCkDpBtH5FUbURVrt/S6AuF2YuB5h79a3Ua0GCUasSi8aWvtDpdJDLqWaQkKaU1QZPGZRCdx4ixm+SpGnN7C9leT5dwI+QgMdxdTNCBQGzjUOVnt/JMEScDWHmhhdcm1Kls6AkiD4PCAHcLIXKysoCAIhEIsyfPx/h4eGO++x2O/7880+kpqZ6NUBCgk3QrF/BOCh05/iOAgaLHWY7B5nEjeskZi1QdR6IbXnpAiFBRV8O2P03JasvKTWmgLhOEG5SwiiPd+sxRy/WIimGyqFI8HArsTh8+DCAuh6Lo0ePQiqVOu6TSqUYOHAgZs2a5d0ICQkiHMegDJI5zKMMRbyNrbiS2mhFfKSbq9mW5QEdugMikU9iIiSgBUkZlI2xgJm6O9yoRFXMNW49plpvRXG1Acntw5tvTIgAuJVY7N69GwAwbdo0vPvuu80u600IcVahM8PGYx2wN8VozvAdgoPa0IrEgnotSFsWJAO3K7Vmv08v2xipVQ2J3Qi7xL0eiGMlakosSNBo1RiLzz77jJIKQlohWOppw0wqSK2BM/BQY7KiVV8taKwFaYusRsBQxXcUHmNAwPUAh5tUbj+mxlDXa0FIMGhxj8Xdd9+NNWvWIDo6GnfffXeTbTdv3uxxYIQEo5Ka4EgsYrSB01sBADY7g85sQ5TMzRm0qdeCtEVBstp2jcECszWwVhoNNyqhjUhx+3FHqdeCBIkWfworFAqI/q5FVigUPguIkGClMVmDYprZUKsa4Ub3FoPyB7XR6n5iAQDKIzTWgrQt6mK+I/AKZQBOhBFuLq/rBXXz/aTWYEVhlR5dOkT4KDJC/KPFn8KfffaZy/8TQlqmNEjKoGK0Z/kOwSW10YqrWjO7iklDvRak7eC4oBhfEajrAYntZsgs1TDLOrj92KMlanRuH+64iEuIELVqjIXRaITBcKkesLCwEMuXL8fOnTu9FhghwSYYyqAkdhOi9EV8h+GS3myD1d7KsgjlERprQdoGnQqwW/iOwmOBNrbicq0ZZwEAGqMNF6porAURtlYlFuPHj8fnn38OAKitrcXQoUOxdOlSjB8/Hh9++KFXAyQkGFhsXMBMieiJugXx7HyH4RJjgNrYyiuYJg2txk3ahiAogzLbONToAzc5Cje1vlT0WIm60UWICRGCViUWhw4dwg033AAA2LRpExITE1FYWIjPP/8c//d//+fVAAkJBkq1CQEyI2KriZgdCt15vsNokro1q3DXoxmiSFtQG5g9ju5QaQL7/VRuqYaYa917kdZkQ0Gl3ssREeI/rUosDAYDoqKiAAA7d+7E3XffDbFYjOuvvx6FhYVu7+/9999HSkoK5HI50tLSsG/fvibbb9y4Eb1794ZcLseAAQOwffv2Rts+9thjEIlEWL58udtxEeItwTDNbJS+EGJ7YPe6aEweXMWkXgsS7Iy1dTOhCZidMZQHeu8vYwjzpNeiVAMukDMnQprQqsSiR48e2LJlC4qLi/HDDz/gH//4BwCgvLzc7fUtNmzYgKysLCxYsACHDh3CwIEDkZGRgfLycpft9+7di4kTJ2L69Ok4fPgwJkyYgAkTJuDYsWMN2n799df4448/0KlTJ/efJCFewhgLioHbgTbFrCsWG4Pe4kGpFvVakGCmvsh3BB6rCKAF8ZrS2nEWAKAz2ZBPvRZEoFqVWMyfPx+zZs1CSkoK0tLSMGzYMAB1vRfXXnutW/t655138PDDD2PatGno27cvVq5cifDwcHz66acu27/77rsYM2YMZs+ejT59+mDRokW47rrrsGLFCqd2JSUleOqpp/DFF18gNDS0NU+TEK+o1FlgtgXWXOvuCjeWIdQqjCudaiP1WhDiksDHVzAAKk2A91b8LcLDKbmPl6qp14IIUqsSi3vvvRdFRUU4cOAAduzY4dh+6623YtmyZS3ej8ViwcGDB5Genn4pILEY6enpyM3NdfmY3Nxcp/YAkJGR4dSe4zhMmjQJs2fPRr9+/ZqNw2w2Q6PRON0I8RbqrfAvtcGDcRYA9VqQ4GQzAzrXlQBCUWOwwGQNzMkjriSxGyG1qFv9eL3ZjotBMJMgaXtalVgAQGJiIq699lqIxZd2MXToUPTu3bvF+6isrITdbkdCQoLT9oSEBCiVrrN9pVLZbPvFixcjJCQETz/9dIviyM7OhkKhcNySk5Nb/BwIaY7Qx1fILDUIMwnnC4nObIPNk8TApAFqLngtHkICgvoiwITdc6oKwAXxmuLJ7FAAcEYljF5iQi7XimVqAb1ejzfffBM5OTkoLy8Hxzm/WeXn81dKcPDgQbz77rs4dOhQixeZmTt3LrKyshw/azQaSi6IV+jNNtR6egWdZ0LqrQAAjgEaoxXtw6Wt30nFKaB9V+8FRQjfBF4GpbfYoQnABfGaEm5SoTa6V6sfX641o9ZgQYwn72WE+FmrEosZM2bg559/xqRJk9CxY8dWrxIZGxsLiUQClcp5kJNKpUJiYqLLxyQmJjbZ/tdff0V5eTk6d+7suN9ut+P555/H8uXLceHChQb7lMlkkMlkrXoOhDRF6GVQEpsBkQbhDfj0OLHQKgFDNRDe3ntBEcKXIFhtW6kW3nup3FwJEWcDE7fqqxYA4IxKh6Fd6X2ICEerzvbvv/8e27Ztw4gRIzw6uFQqxaBBg5CTk4MJEyYAqBsfkZOTg5kzZ7p8zLBhw5CTk4Nnn33WsW3Xrl2OAeSTJk1yOQZj0qRJmDZtmkfxEuIuoZdBxejOCbJ8otaT9SzqVZwGugzzfD+E8E3gq22b7RyqA3hBvMaImB1h5goYwjq2eh8XKvVITY6BNKTVleuE+FWrEot27dqhfXvvZNBZWVmYMmUKBg8ejKFDh2L58uXQ6/WOJGDy5MlISkpCdnY2AOCZZ57BqFGjsHTpUowbNw7r16/HgQMHsGrVKgBAhw4d0KFDB6djhIaGIjExEb16tb5LkhB32ewcVBph1QRfTsRZEa0T5gxJZisHg9WO8FBJ63dSfR5IGgSEUBkCETiBTzNbrjEH9IJ4TQk3qTxKLGwcQ0GlHr0So7wYFSG+06oUeNGiRZg/fz4MBoPHAWRmZuLtt9/G/PnzkZqairy8POzYscMxQLuoqAhlZWWO9sOHD8e6deuwatUqDBw4EJs2bcKWLVvQv39/j2MhxJuUGhPswrvY7xCtv9Dq1WMDgcbTXgu7Fag6551gCOGTWrirbdctiCfcCzSeDuAGgLPlNIibCIeIMfenT7n22mtx/vx5MMaQkpLSYJ2IQ4cOeS1APmg0GigUCqjVarcX/COk3r6Capwr1/EdRuswhi5l3yPEJtxFmhThoeid4OFVPrkC6H+3dwIihA8mNXBsM99RtJpKa8KFSs8vYvKpsNNtsIVEeLSPW3rHI1Eh91JEhLjHne/FrSqFqh8PQQhpnJAHbkcYSwSdVAB1q9dyDBC3bm6JOiY1oCkFojt5LS5C/KpW2LNBKdXCWBCvKeFGJTRR3T3axxmVlhILIgitSiwWLFjg7TgICSrVegsMFmEs5OSK0KaYdcXOMWhMVsSEhTbfuCnlJymxIMIl4PEVQloQrynhJpXHiUVJrRF6sw0RstbPMEWIP7R6moHa2lp8/PHHmDt3LqqrqwHUlUCVlAh7SjtCvEHIvRVycyXk5iq+w/AKtTdmh1IXA2aBlrSRts1mqZsRSqCUAp784nLhJhVEnGdrcDAGnBVqaS1pU1qVWPz111+4+uqrsXjxYrz99tuora0FAGzevBlz5871ZnyECJKQp5lVaM/yHYLXeCWxYAyoPO35fgjxN3WxIKeLBv5eEM8orAXxGiNiNoSbPE/wzpfrYBfq9FikzWhVYpGVlYWpU6fi7NmzkMsv1fzddttt+OWXX7wWHCFCZLLaUaUT3pzrABBi0yPSGDy9jkaLHSabF75YVZ6tW2SMECERcBmUEBfEa0qk0fPfhdnGoaha2APZSfBrVWKxf/9+PProow22JyUlQan0fGo1QoRMyL0VMdqzdVfog0iNNxbWshqBmgLP90OIv3AcoBHmRQKLQBfEa0q4scwrvUdnVDT1LAlsrUosZDIZNBpNg+1nzpxBXFycx0ERImRCHV8h4qyI0l/gOwyvqzZ46QtKxSnv7IcQf9CXAzZhzqikEvCCeI0Rc1avlENV6SxBl3SR4NKqxOLOO+/Eq6++Cqu1rn5ZJBKhqKgIL774Iu655x6vBkiIkHAcQ5lamAMOo/RFgl4QrzE6kw1mb5RD6coBQ7Xn+yHEHwQ6zSzHgAoBL4jXlEiDd0rTqNeCBLJWJRZLly6FTqdDXFwcjEYjRo0ahR49eiAqKgqvv/66t2MkRDBUWhNsdmFealPozvMdgs947Qpf+Unv7IcQX1MLM7Go1JthFeh7aHMijKVeKYcqqjLAbBP+NLwkOLVqQmSFQoFdu3bh999/x5EjR6DT6XDdddchPT3d2/ERIihCLYOSmyogtar5DsNnqg0WdPTG4lLV+cBVQ4AQqef7IsRXTOq6mwCVB8kUs66IOQvCzBUwyhM82o+NY8iv0KNPx6ZXQCaED24nFhzHYc2aNdi8eTMuXLgAkUiErl27IjExEYwxiESeLHNLiLCV1ArzQ1GhO8d3CD6lM9lgtnOQSVq9dE8dzgZUnQUS+nknMEJ8QaBlUFqzDXpzcF+JjzRc9DixAOrWtOidGEXfuUjAcetTljGGO++8EzNmzEBJSQkGDBiAfv36obCwEFOnTsVdd93lqzgJCXhqoxU6k/DmXZfYjYg0lvIdhs95ZXYogAZxk8An0Glmg7m3ol5dOZTnpV46kw2lAh3PR4KbWz0Wa9aswS+//IKcnBzcfPPNTvf99NNPmDBhAj7//HNMnjzZq0ESIgQlNcIsg1Lo8gW7iJY7qvUWJEZ7oRzKpAHUJYAiyfN9EeJtAl1t2xqEU8y6IrGbIDdXwiT3fAbNMyotkmLCvBAVId7jVo/Fl19+iXnz5jVIKgDglltuwZw5c/DFF194LThChESQ61cwDtG6fL6j8Aud2QaL3UsJVAUN4iYBSnNRkBcKyrXBN8VsY7y1CGlZrQlaU/DN5EeEza3E4q+//sKYMWMavX/s2LE4cuSIx0ERIjRmmx2VOuHNGR9pKIHE3ja60xkDary1poX6ImDWeWdfhHiTAMdXMNQlFm2Ft6adBerGWhASSNxKLKqrq5GQ0Pigo4SEBNTU1HgcFCFCU1ZrEuSC1cE+aPtK1TovJRaMARWnvbMvQryFMUGutl1jsMDijbVmBEJiN0JmrvLKvvIr9LB5qyeWEC9wK7Gw2+0ICWl8WIZEIoHNJrzBq4R4SojTzEottZCbK/kOw6+0Zhus3voQrjwDcME9gw0RGJ1KkKttq9rAoO0reavXwmLjUFht8Mq+CPEGt2eFmjp1Ku6++26Xt4ceeqhVQbz//vtISUmBXC5HWloa9u3b12T7jRs3onfv3pDL5RgwYAC2b9/udP/ChQvRu3dvREREoF27dkhPT8eff/7ZqtgIaQ5jTJCzc7S13gqg7oJutbfKoWwmoOaCd/ZFiDcIcDYog8UOjbHtXZD01jgLADhLK3GTAOJWYjFlyhTEx8dDoVC4vMXHx7s9I9SGDRuQlZWFBQsW4NChQxg4cCAyMjJQXl7usv3evXsxceJETJ8+HYcPH8aECRMwYcIEHDt2zNHm6quvxooVK3D06FH89ttvSElJwT/+8Q9UVFS4FRshLVGhNQuuG1/MWRGlF14ttjd4deYZWombBJLaIr4jcFu5VngXZbwhxKaHzOKd0vFqvRUVbWiMCglsIsb4rQxPS0vDkCFDsGLFCgB1C/AlJyfjqaeewpw5cxq0z8zMhF6vx3fffefYdv311yM1NRUrV650eQyNRgOFQoEff/wRt956a7Mx1bdXq9WIjqaVLUnTDhfV4GSZsK4YKbRnEFvTNidaEImAa5NjEOrpYnn1+twJRHTwzr4IaS2TBjj2P76jcIuNY8grroW9rUwHdYXa6N6oihnglX2ldAjH8B6xXtkXIVdy53uxlz5ZW8diseDgwYNIT093bBOLxUhPT0dubq7Lx+Tm5jq1B4CMjIxG21ssFqxatQoKhQIDBw70XvCE/E2I08zGaM/zHQJvvDo7FEBTz5LAoBZeD2SlztxmkwoAiPBiOVRRtQEmK435IvzjNbGorKyE3W5vMNNUQkIClEqly8colcoWtf/uu+8QGRkJuVyOZcuWYdeuXYiNdZ3Nm81maDQapxshLaE1WQVXHxxmVCLE1ranKKzWe3Hu9+oCQQ6YJUFGgNPMtqUpZl0JtWohtai9si+OAedo6lkSAHhNLHzp5ptvRl5eHvbu3YsxY8bg/vvvb3TcRnZ2ttNYkeTkZD9HS4SqtFZ49cExbXDQ9pW0Jqv3ZofibEDlWe/si5DWEOBq27VGK4wWusIeYfTegPvzFTrwXN1OCL+JRWxsLCQSCVQq5zdElUqFxMREl49JTExsUfuIiAj06NED119/PT755BOEhITgk08+cbnPuXPnQq1WO27FxcK78kP4IbRpZkNseoSbXPcGtiUcA2oMXuy1qDgFQS5kQoKDAFfbbuu9FfUiDd4rh9Kb7bhYI6zPJBJ8eE0spFIpBg0ahJycHMc2juOQk5ODYcOGuXzMsGHDnNoDwK5duxptf/l+zWbXb2QymQzR0dFON0KaY7Vzgpt/XaE7T1+A/+bV2aHMWkFO9UmChMDOPbONQ603xzkJmNSqRqjVe5N/nCyjUm7CL95LobKysrB69WqsXbsWJ0+exOOPPw69Xo9p06YBACZPnoy5c+c62j/zzDPYsWMHli5dilOnTmHhwoU4cOAAZs6cCQDQ6/WYN28e/vjjDxQWFuLgwYN46KGHUFJSgvvuu4+X50iCU1mtCUIadyhidkTrL/AdRsDQmqywcl68ylt+wnv7IqSlGBNcYlGuNdP1jct4a7E8AKjUWVCpo94gwp/Gl9H2k8zMTFRUVGD+/PlQKpVITU3Fjh07HAO0i4qKIBZfyn+GDx+OdevW4aWXXsK8efPQs2dPbNmyBf379wdQt/r3qVOnsHbtWlRWVqJDhw4YMmQIfv31V/Tr14+X50iCU36lsAbKRRqKIbbTB049jgG1BiviImXe2aGmFDBUA+HtvbM/QlpCVy6oyQM4BlS00bUrGhNhLEGNoo/X9neqTIuRPb30vkaIm3hfxyIQ0ToWpDlGix1b8koEddXtKuWPXluQKVjEhIeiV0KU93bYoQfQ9Qbv7Y+Q5hTvB1THmm8XICp0ZuRX6PkOI+AUdroNtpAIr+xLJALuHNgJETLerx2TICGYdSwIEar8Sp2gkgqZuYqSChc0Rits3qxnq84HLAbv7Y+QpnAcUC2sNWlUGuH0rviTN8uhGANOKYW1aCsJHpRYENIKBZXCuuKm0Anry4e/1JdDeQ3jgHJaMI/4iboYsApnFiCt2Qa9WVjr/vhLpBennQWA/AodLDZhzRRGggMlFoS4qUJrFtSieGK72atXw4JNtbdnp6k8DdiFc34QARPY+inlAptFz59k5mpIbN7r7bTaGc5XCGscIAkOlFgQ4qZ8gb1ZR+sLIGK0EFVj1EaLd8uhbGag8oz39keIKxZD3foVAmG1c95P4oNMpNF7a1oAwBmVlhbMI35HiQUhbrDZORRWC6iGnjEqg2oGx3m5HAqom3qWPtCJL1WdFdQ5VqEzw5uzOwcjby6WB9QtmFdcLZxSORIcKLEgxA1F1QbY7ML5MI8wliLEi93rwcrrV1LNWqC20Lv7JORyAiqDYqCVtltCbqmExO7dcrGTSlowj/gXJRaEuEFo0yRSb0XLqI0W2Lx99Vd13Lv7I6SepqwueRWIGoMFZit1VzSLMUQYS726yyqdBRWU1BE/osSCkBbSmqyCuuoWatUgzKTiOwxB4DhA7e1yKF153Y0QbxPYGB4VDdpuMV9MtHGKei2IH1FiQUgLCa+3Ip/vEASlWu+DgaUCWriMCITNIqgyO4PVLqhZ9PgWZq6A2O7d96KLNUboaJpf4ieUWBDSAowxQa1dIbZbEK0v4DsMQVEbrbB7uxyqtggw0dVC4kXV5wFOOLO80RSzbmIcIrw8OxRjwGnqtSB+QokFIS2g1JhgsAjnwzxGdxYijq5QucPOMe/PDsUYLZhHvEtAg7ZtHEOljqaYdZe3p50FgPMVelowj/gFJRaEtMD5cgH1VnBWKLTC+fIRSHxSDlV5pm5tC0I8pa8CDFV8R9FiKo0Jdm+uEdNGhJtUEHPevchhowXziJ9QYkFIM8w2O0pqhTNlq0J71usfSm2FT8qhOBtQcdq7+yRtk4AGbdsYg5LKoFqHcQg3lnl9t2dUWnCU6BEfo8SCkGZcqDTALpAeZBFnRYzuHN9hCJadY1AbfZCUlZ8ErQ5GPMLZgWrhTMigUpsEteZPoPHF7FB6sx3FNcK5SEaEiRILQppRUCmc7mOFLh9iO5XdeKLGF+VQVoOgvhSSAFRzAfDybEG+YmOMppj1ULhJ6ZOe55Nlwln/hAgTJRaENKFGb0G1XhhlRSLOhhitcEolAlWNwQqfVAvQ1LPEEwIatF2uMcFKvRUeETE7on0wZXi13oJyLSV9xHcosSCkCfkC6q2I1udDYqcPDE/ZOeabQdzGGkDj3VV1SRth0gBa79fc+4KdMSjV9D7kDTHaswDzfgnlKeq1ID5EiQUhjeA4hguVwqhHFTE72mlogLC3KDVG3+xYddw3+yXBrUpIvRVm6q3wEondiChDsdf3W1JrhNYkjJ54IjwBkVi8//77SElJgVwuR1paGvbt29dk+40bN6J3796Qy+UYMGAAtm/f7rjParXixRdfxIABAxAREYFOnTph8uTJKC2lK4XEPRdrjDALZN7vKN0F6q3wIr3Z7ptB3OqLdT0XhLQUY0ClMCZksDOGMrWPkvI2yhflrXUL5lGvBfEN3hOLDRs2ICsrCwsWLMChQ4cwcOBAZGRkoLy83GX7vXv3YuLEiZg+fToOHz6MCRMmYMKECTh2rK5+2WAw4NChQ3j55Zdx6NAhbN68GadPn8add97pz6dFgsB5oZRBMQ7ttKf4jiLolPmqnIN6LYg71BfrBv8LQLmWeiu8TWqpRZhJ5fX95tOCecRHRIx5e9J296SlpWHIkCFYsWIFAIDjOCQnJ+Opp57CnDlzGrTPzMyEXq/Hd99959h2/fXXIzU1FStXrnR5jP3792Po0KEoLCxE586dm41Jo9FAoVBArVYjOjq6lc+MCJnBYsM3eaXg96+jZaJ1+YirPsh3GEGpf5ICEVKJd3cqlgAD7gNCw7y7XxKczuUAtUV8R9EsjgFHLtbAYhPAm6bAGMISURZ3g9f3OzBZgX6dFF7fLwk+7nwv5rXHwmKx4ODBg0hPT3dsE4vFSE9PR25ursvH5ObmOrUHgIyMjEbbA4BarYZIJEJMTIzL+81mMzQajdONtG35FXpBJBVgHI2t8CGlL8o6OHvduhaENMdiqOuxEIByrYmSCh8JNyoRalV7fb9nVTpaMI94Ha+JRWVlJex2OxISEpy2JyQkQKlUunyMUql0q73JZMKLL76IiRMnNpplZWdnQ6FQOG7JycmteDYkmORX6vkOoUWiDEUIsQmkZEuAqvUW34yzqTgN2G3e3y8JLtXnfTIrkLdxDDS2wsfaabw/1sJgsaOoWhhldkQ4eB9j4UtWqxX3338/GGP48MMPG203d+5cqNVqx6242PuzMBDhKNeYoDMJ4EsfY2inobEVvsQxH80QZTMBVcIYkEt4VCmMdWkqdNRb4WuRhmJI7N5/LzqlpAoN4l28JhaxsbGQSCRQqZwHJqlUKiQmJrp8TGJiYova1ycVhYWF2LVrV5M1YTKZDNHR0U430nadrxBGb0Wk4SJCrTSzh69VaC2wcj64alx+HMKotyO80Krq1q8IcBwDSmtpRjpfEzE7FNrzXt9vtd5Kq6QTr+I1sZBKpRg0aBBycnIc2ziOQ05ODoYNG+byMcOGDXNqDwC7du1yal+fVJw9exY//vgjOnTo4JsnQIKO1c6hWCBdw+00VKfvD3aOoVxj9v6OTRpBDMolPBFUb0Xgl2sFA4XuPESc93vT84prvb5P0nbxXgqVlZWF1atXY+3atTh58iQef/xx6PV6TJs2DQAwefJkzJ0719H+mWeewY4dO7B06VKcOnUKCxcuxIEDBzBz5kwAdUnFvffeiwMHDuCLL76A3W6HUqmEUqmExeKD1XRJUCmsMsAmgMFsEYaLkPpgMB9xTaUxwSenRVke9VqQhmwWoOYC31E0q25sBV3t9hcxZ0G0/oLX91uls6BAIOMKSeAL4TuAzMxMVFRUYP78+VAqlUhNTcWOHTscA7SLioogFl/Kf4YPH45169bhpZdewrx589CzZ09s2bIF/fv3BwCUlJRg69atAIDU1FSnY+3evRs33XSTX54XESahvLnS2Ar/stoZKnUmxEfJvbtjQ3Xdlem4Xt7dLxG2mgLAB1emva1SZ4LZSr0V/hSjPQt1ZHdAJPLqfo8U1yK5XRhCJLxfbyYCx/s6FoGI1rFomzQmK747UsZ3GM0KN5ahY8VvfIfR5shDJbjmKgW8+3EOIEQO9L8HCJF6e89EqE5+C+gr+Y6iSRwD/iqppcSCB8rY4dCHJ3l9v/2TonHNVTFe3y8RPsGsY0FIIMkXyKDt9poTfIfQJpmsdtQYfFBOaTMBpYe9v18iTIbqgE8qAKBSZ6akgicxWt+sXXSqTAu9OfB7ykhgo8SCEACMMRRUBv56EGFGJWTmar7DaLN8Vk9ecQow1vhm30RYKs/yHUGzaGwFv+TmKsjMVV7fr41jNJCbeIwSC0IAlKpNMFoC/+pbe5oJilc6kw0ak9X7O2YcULzf+/slwsLZ6xbFC3BVejNMVjvfYbRpMVrfzBpWWGVAhdYHs+CRNoMSC0IAnCgN/Pniw0zlkJsDv0Qi2PnsSq2mBKgp9M2+iTDUFgK2wP5Sx0DrVgSCSGMJQqy+6WU/WEi9p6T1KLEgbV5hlV4QV2ho3YrAUGuwwmDx0dXai/vrrlqTtqk88Gd7q9JZqLciEDCGGN05n+y6Wm9BfkXglwaTwESJBWnTbHZOEDWlcnMlwkzlfIdB/uazXguzFlAd882+SWCrLgB0Kr6jaBIDUFJr5DsM8rdoXQHEdt+sz3XkYi1s9sAvDyaBhxIL0qadUmqhNwf+1TfqrQgs1XozzL760C37C7AIY4Yy4iV2G3DxAN9RNIt6KwKLiNkQrc/3yb6NFg4nygK/RJgEHkosSJtlsNgEMbYi3FiGcKOS7zDIZTgGqHzVa8EJ40sm8SLlX4AlsEtPGIBSNfVWBJoY7dm6yR984GSZhqafJW6jxIK0WXlFtbBxgb0+pNhuQXw1fckMROVas+/On+p8QEelb22CQMrfKrQmGH01toi0msRuQpShyCf7tnMQRKkwCSyUWJA2qVJnxoUqA99hNCuu9jAkdpqBJRDZOYZyXw76L/oDYIGd+BIvKN4X8AP2TTY7iqqptyJQxWh8M/UsUDf9bLmWPoNIy1FiQdokIUynF2EoQaTeN1eiiHeoNEb4rNPLUCWIxdKIB9QlQG1g/40zAPkVetgDvHe3LZNa1QjzYbnsocJan+2bBB9KLEibk1+hQ5XONzNpeIvYbkZczSG+wyDNsNgYKvU+7LUoPQTYAvtcJa3EcUDxn3xH0awytRFaE9XZB7p2PlowD6DpZ4l7KLEgbYrNzuHIxVq+w2hWXM0hKoESCKWvBnEDgNUIlB3x3f4Jf8qPAyY131E0SW+xo6SGSqCEIMykgtTiu/PpyMVaWGn6WdIClFiQNuV4qQZGS2C/OUbqixFpuMh3GKSFjBY7agw+7FUoPwEYa323f+J/FkPAJ4wcq+vdpQoo4YjxYa+F0cIJYhZFwj9KLEiboTPbcEoZ2G+MEruRSqAEyGcL5gF1U0le3O+7/RP/KzkA2K18R9GkizUG360wT3wiylAEic13k5KcUmqgo+lnSTMosSBtRl5RLQK9Jze++hDEHNXUC43WZIPWlx+46otAbbHv9k/8R6sCqs7zHUWTNGYrlBoqxRQcxiGh+oDPZpOzc3Wfo4Q0hRIL0iaUa0woqg7s6WWj9BcQbizlOwzSShcq9bD7cnrY4j/rBvwS4WIs4Ads2xhDQYWBZjoWqDCTCjHa0z7bf1E1TT9LmsZ7YvH+++8jJSUFcrkcaWlp2LdvX5PtN27ciN69e0Mul2PAgAHYvn270/2bN2/GP/7xD3To0AEikQh5eXk+jJ4IAWMs4KeXldgMiK3J4zsM4gGDxY5CX66NYtbWDfglwlVxum4a4QBWVGWAyUolUELWQX0cMrPvzrNDhTVglHmSRvCaWGzYsAFZWVlYsGABDh06hIEDByIjIwPl5a5XnN27dy8mTpyI6dOn4/Dhw5gwYQImTJiAY8curVqq1+sxcuRILF682F9PgwS48xV61BgCu545ofoAxFxgx0iaV6E1o9KXUxmXHakb+EuEx2aumz44gNUYrajw5aKPxD8Yh8SqfT77TKnWW3G+Qu+TfRPhEzEe0860tDQMGTIEK1asAABwHIfk5GQ89dRTmDNnToP2mZmZ0Ov1+O677xzbrr/+eqSmpmLlypVObS9cuICuXbvi8OHDSE1NdSsujUYDhUIBtVqN6Oho958YCRgWG4fv/iqFyRq4JSTRunzEVR/kOwziJRKxCP2SohEWIvHNAdp3A7qN8s2+ie8U5gIVp/iOolFWO4ejJWpY7XQlOljoIjpD1SHNJ/sOEYtwS594xEbKfLJ/Eljc+V7MW4+FxWLBwYMHkZ6efikYsRjp6enIzc11+Zjc3Fyn9gCQkZHRaHtCjpWqAzqpCLHpEVsb2NNOEvfYOYbzKh9O01mdDyiPNd+OBA5DNVDpu7p3b7hQZaCkIshE6osQpb/gk33bOIZfzlTQLFGkAd4Si8rKStjtdiQkJDhtT0hIgFLpeml6pVLpVvuWMpvN0Gg0TjcifBqTFWeUWr7DaFJ89QGIOHpjDjZ6ix1F1T4sFbi4vy7BIMJQ9IfPZurxhkqdBdV6mo0uGMXVHEao1TefgyYrhz2ny2GxBe7FO+J/vA/eDgTZ2dlQKBSOW3JyMt8hES84XFQb0Is7KbRnEWZyPZ6ICJ9KY0aVL7+sFfwKaMp8t3/iHVXnAZ2K7ygaZbZxuFBF9fLBSsTZkFD1Z916OD6gMdrw69kKcIH8YUv8irfEIjY2FhKJBCqV8xuuSqVCYmKiy8ckJia61b6l5s6dC7Va7bgVF9N88UJXpjaipMbIdxiNCrVq0aGWylmCXUGlHiabj2bYYRxw/qe6MhsSmOxW4OIBvqNoUn6lDnb6UhjUZJYadKg96rP9qzRm/FlA70OkDm+JhVQqxaBBg5CTk+PYxnEccnJyMGzYMJePGTZsmFN7ANi1a1ej7VtKJpMhOjra6UaEi+MYDhXW8h1G4xhDfPV+iBiVQAU7O8dwrtyH4y3sFuDcj4CFrjgHpLIjgDVwZ/FSakzQGOl9qC2I0Z5BuNGzsvGmFFTqcaxE7bP9E+HgtRQqKysLq1evxtq1a3Hy5Ek8/vjj0Ov1mDZtGgBg8uTJmDt3rqP9M888gx07dmDp0qU4deoUFi5ciAMHDmDmzJmONtXV1cjLy8OJEycAAKdPn0ZeXp7H4zCIcJyr0EFtDNypW2O0ZyD34RzjJLDozT4eb2HRA2d3ATaqkQ8oJjWgCtx1RwxWO4prAjfpId4XX70fErvvFrf766IaFyrpIkdbx2tikZmZibfffhvz589Hamoq8vLysGPHDscA7aKiIpSVXaohHj58ONatW4dVq1Zh4MCB2LRpE7Zs2YL+/fs72mzduhXXXnstxo0bBwB44IEHcO211zaYjpYEpzK1EYeLAncxPJm5Gu3Vgftlg/iGSmNGtcGHX/yNNXVlUbQyd+Ao3uezunZPcQzIr9DT6dLGSOwmJFQ1vQixp/7Ir0K5hlbmbst4XcciUNE6FsJUpjbilzMVsAfoh2WYqRwdK3+nWaDaKIlYhP5JCshDfHg9h9a44B9jQOFeoPIM35G4xABcqNKjXEML4bVVVTHXoDa6l8/2Lw0R4x/9EhAtD/XZMYh/CWIdC0K8qbQ2sJOKCMNFdKz4jZKKNszn4y2AuiloL9Jii7xhDCj4JWCTCo4B58p1lFS0cR3UxyAz+26wtcXGYc/pCpisPpq4ggQ0SiyI4JXWGvHr2cBNKqJ1+Uis+gMiRm+ybZ3ebPN9XbvyL6D8pG+PQRriOCB/d8CuL2LlOJxSami9CgIwDolVf0LE+W4sos5k+/tiHxXFtDWUWBBBC/Skop36JOKqDwb04ljEv5Rqk2/HWwBA8Z9ATaFvj0Eu4ex1Y1wC9DU32zicLNNCa6IeU1InxKZDXM1hnx6jUmdB7nmaqKStocSCCFagJxWxNXlor6a1KkhDBZV6mH25Wm19SY6OFmD0ObsNOJcDqANz/SODxY4TZRoYLdRjSpxF6QsRpfdtMlxUbUBeca1Pj0ECCyUWRJBKAjmpYBziq/ZBoT3LdyQkQNnsDOcqfDzegvv7C6+J5pb3GbsVOLcL0JTwHYlLGpMVJ8o0sPgyiSWCFldzCKFWrU+PcaJUg3PlOp8egwQOSiyI4JTUGvFbgCYVIs6GjpV7fX4ViAifzmTDhSq9b5MLm6lujQtr4K5CL1g2C3DmB0AbmGskVektOK3SUo07aZKIs+Gq8t0IN5Y139gDBy5Uo0xN70NtASUWRFACOakQc1Z0qvjV52/QJHhUaM04qdTAZPNhmYpZW5dc0Orc3mM1AWd2APoKviNxSakx4XyFjtapIC0itpvRseI3dKg54rO1VzgG/HKmAidKNaBVDoIbJRZEMEpqjfg1QKeUldiNSFLthtxcyXcoRGB0JhuOlWhQofPhFKCGKuD4FqD8lO+O0VZYjXVJhSEwB6UWVhtQWGWg+SKI22K0Z3BV+W6E2Hwzc52dA/KKa7HzhApqo+9mpCL8osSCCEJ9UhGIvfohVh2uUu2G1Eq17KR17BxDfoUe5yp0sPnqJLdbgKJc4PT3NO6itSz6utfPWMN3JA1wDDhfoYNSTasek9aTmauRrNyFCIPvxg1V6SzYcawMx0vV1HsRhCixIAHvYo0hYJMKqaX27ys8VGZCPFels+BYqRoasw+v5mmVwIlvgLIjoFoZN5h1AZuU2RjDGZUGlTpao4J4TsxZkFi5F7E1h322/pKdA44Uq6n3IghRYkEC2sUaA347WxmQSUW4sRRJ5XsgsdMVQuI9ZiuHU2VaXKw1wmenPWcHSg4BJ7cCeirfa5ZJU5dUmH07e05rWOwcTpVpoDbSGhXEuxTac0hS7UaI1XczOlHvRfARMfpNNqDRaKBQKKBWqxEdHc13OG2SzmzD8RI1Cip9PGtOK4Qby9BefQIySzXfoZAgFykPQfe4SMhDfHgNSCQC4vsCna4DJCG+O45Q1RYDhXsBq49XTHeTjWMo15qgVJtgtQfYmyQJKkwcgvJ2g6GLSPbpcdpHSDGsWwcowkN9ehziPne+F1Ni4QIlFvzRmW04VqLGhQBMKCIMJWinOQmZJfDqq0nwkohFSOkQgdhIqW8PJIsCOg8DFEm+PY4QMAbUFADKo4AhsC4gWO0clBozVBoTTSVL/EoT2RWVMalgYt9dgJCIgX6dFOjXKRoikchnxyHuocTCQ5RY+J/WZMXxUk3AJhTtNScgtdTyHQppw2IjpegSG4EQX3/YdugBJA8FQmS+PU4g4jig6hyg/Cvgyp7Mdg5KtQkVWjMlFIQ3llAFlLHXwxrq2+9G7SOkuL5be8SE+/iCCmkRSiw8RImF/wR2QlGM9uqTNNsTCRjyUAm6xUUgSubjkqXQsLrkon033x4nUNhtQMUpoPxEwK33YbLZUVprQpXOHHDvkaRtYqIQVMUMgCaiC5jYd2VLYhHQP0mBqxOiIPVlOShpFiUWHqLEwve0JiuOlWhQ6OuVh93FGCINF9FOcwJSq4bvaAhxKVIegvYRUrSPkEIm8eEHrlwBxHQBYjoDkXG+Ow5fbOa6ZKL8VN0q5QHEYLGjTG1Eld5Ca1KQgMREEujDOkIXngxDWEcwkcQnx5GIgURFGDq3D0dSTBglGTygxMJDlFj4TmAnFMVopzlJCQURlEh5CDr8nWRIfZlkhIYDMcl1SUZUJ0As4A93iwFQHQcqTwP2wJrqUmu2oUxtQo2epo4lwsGJQ6EPS4I2PBlGeULdpBA+IBYBiQo5unSIoCTDjyix8BAlFt5ltXOo1luQX6EPqIQi1KpBmLkScnMVwszlPlttlBB/ibqsJ8OnSYZEWjfIW5FcdwsRQB00Z69b2K7yTN04Cs438/O7i2OAzmKFzmSH2miBhqaNJQJnl8ihC0+GLvwqmGSxPjtOfZLRuX04rmoXTkmGDwkusXj//ffx1ltvQalUYuDAgXjvvfcwdOjQRttv3LgRL7/8Mi5cuICePXti8eLFuO222xz3M8awYMECrF69GrW1tRgxYgQ+/PBD9OzZs0XxUGLhGbXBigqdGVU6M6r0FqiNVv678hkHuaX67ySiEnJLFcR2M89BEeIbIhEQJQtB+0gp2oX7OMkQiYGojnU9GTGdAWm4747VUjYLYKyum9HJUFV3M6kBxv+CgGYbB53ZBp3ZCp3ZDoPZFjAXWwjxNltIBHThydBEdIY1VOGz41CS4VuCSiw2bNiAyZMnY+XKlUhLS8Py5cuxceNGnD59GvHx8Q3a7927FzfeeCOys7Nx++23Y926dVi8eDEOHTqE/v37AwAWL16M7OxsrF27Fl27dsXLL7+Mo0eP4sSJE5DL5c3GRIlFy5msdlTpLajSmVGpM6NKZwmIOdXFnBVyc+XfPRKVkFlqfLaCKCGBTCSq68mIlIVCFiKGPFQMWYgEMl998MoVgDQSkEbUJRmhEX///++bxMuDPS2GusTBWJ9EVAfMjE4cA/QWG3Qm29/JhA0WG//JDSF8sIRGwyxtB2tI5KVbaBQ4Lw8AF4uAmHApouUhiJKHIkoe8vctlBKOVhJUYpGWloYhQ4ZgxYoVAACO45CcnIynnnoKc+bMadA+MzMTer0e3333nWPb9ddfj9TUVKxcuRKMMXTq1AnPP/88Zs2aBQBQq9VISEjAmjVr8MADDzQbEyUWl9jsHMw2DiarHSYbB7PVDpOVQ63RgkqdBTqT/7vtRcwOid0Eid2EkPp/ORMkdiNC7CaE2PQ0ToKQZojFcCQY8lAx5KGSup9DxZBJJBD7alZbifRSkhEafulfxgHMXlei5Pg/d8X2+n8ZwNkAUy1gNfoo0OZZ7RysdnbFv3U3k42DwWIDR3kEIU2yS+ROyYYlNBLWkChYQyK9vmaGLETsSDKi5CGIvizxCPFlz67AufO9mNdlVi0WCw4ePIi5c+c6tonFYqSnpyM3N9flY3Jzc5GVleW0LSMjA1u2bAEAFBQUQKlUIj093XG/QqFAWloacnNzXSYWZrMZZvOlshiNRrhfSu0cg51j4NiV/8Lp58v/b7ZdljxY7Y7/m60cbJ700TMGgEEEBjAGEeMggh0izg4xs0N02U3c2P85699JgwkhdiMkdjPEHA1qJMRTHAcYLXYYLQ178kQiQBoiRqhEDIlIBIlYBLFIhBAJIBaJIRHDse3yfyUiEcRiEUR/70MEACJABNHfP4sgslsgMlrqxjv4AQMcpZgMrC4nYXXviRzHwADYGQfGARz+3uZoc+m902LjYOU42P5OHGx2RiVMhHhB/YVCubmywX12iRx2iRycKAScOPTvf0PAiULr/hWHgolCYBeHgtVvE4WCE0kAkQgMIkAkdvxrsYhQabWhUtfwe4QsRIwQiQhSSd17X2iIGKFiUd2/EjFC/74v5Ir/i0WASCSCWASIRXXvhyLH///+12dXagIPr4lFZWUl7HY7EhISnLYnJCTg1KlTLh+jVCpdtlcqlY7767c11uZK2dnZeOWVV1r1HHyl/os/Y5c+DAE0+Jn7+z+OD08XHVAiABIRIJGI/v6piU9Dx+NZ09v/Thou/cu52E4IaSu4v28tdfmHcV0qUsfdyWSufC9s8J7oTkx/3wghbYDo7+9Eorqkw/H/ujsva9PUz5e+E3HM1XvgpTe0S0mG2HlXl4VT/15Yf5/I8VjhJCe8JhaBYu7cuU69IBqNBsnJyTxGVHc1UAJhnESEEEIIIYTwenEmNjYWEokEKpXKabtKpUJiYqLLxyQmJjbZvv5fd/Ypk8kQHR3tdCOEEEIIIYS0HK+JhVQqxaBBg5CTk+PYxnEccnJyMGzYMJePGTZsmFN7ANi1a5ejfdeuXZGYmOjURqPR4M8//2x0n4QQQgghhBDP8F4KlZWVhSlTpmDw4MEYOnQoli9fDr1ej2nTpgEAJk+ejKSkJGRnZwMAnnnmGYwaNQpLly7FuHHjsH79ehw4cACrVq0CUFez++yzz+K1115Dz549HdPNdurUCRMmTODraRJCCCGEEBLUeE8sMjMzUVFRgfnz50OpVCI1NRU7duxwDL4uKipyDHQBgOHDh2PdunV46aWXMG/ePPTs2RNbtmxxrGEBAC+88AL0ej0eeeQR1NbWYuTIkdixY0eL1rAghBBCCCGEuI/3dSwCEa1jQQghhBBCiHvfi2lmPUIIIYQQQojHKLEghBBCCCGEeIwSC0IIIYQQQojHKLEghBBCCCGEeIz3WaECUf14do1Gw3MkhBBCCCGE8Kf++3BL5nuixMIFrVYLAEhOTuY5EkIIIYQQQvin1WqhUCiabEPTzbrAcRxKS0sRFRUFkUjk9+NrNBokJyejuLiYprtto+gcIHQOEIDOA0LnAOH/HGCMQavVolOnTk5ry7lCPRYuiMViXHXVVXyHgejoaHoTaePoHCB0DhCAzgNC5wDh9xxorqeiHg3eJoQQQgghhHiMEgtCCCGEEEKIxyixCEAymQwLFiyATCbjOxTCEzoHCJ0DBKDzgNA5QIR1DtDgbUIIIYQQQojHqMeCEEIIIYQQ4jFKLAghhBBCCCEeo8SCEEIIIYQQ4jFKLHjy/vvvIyUlBXK5HGlpadi3b1+T7Tdu3IjevXtDLpdjwIAB2L59u58iJb7izjmwevVq3HDDDWjXrh3atWuH9PT0Zs8ZEvjcfR+ot379eohEIkyYMMG3ARKfc/ccqK2txZNPPomOHTtCJpPh6quvps+DIODuebB8+XL06tULYWFhSE5OxnPPPQeTyeSnaIk3/fLLL7jjjjvQqVMniEQibNmypdnH7NmzB9dddx1kMhl69OiBNWvW+DzOFmPE79avX8+kUin79NNP2fHjx9nDDz/MYmJimEqlctn+999/ZxKJhC1ZsoSdOHGCvfTSSyw0NJQdPXrUz5ETb3H3HPjnP//J3n//fXb48GF28uRJNnXqVKZQKNjFixf9HDnxFnfPgXoFBQUsKSmJ3XDDDWz8+PH+CZb4hLvngNlsZoMHD2a33XYb++2331hBQQHbs2cPy8vL83PkxJvcPQ+++OILJpPJ2BdffMEKCgrYDz/8wDp27Miee+45P0dOvGH79u3s3//+N9u8eTMDwL7++usm2+fn57Pw8HCWlZXFTpw4wd577z0mkUjYjh07/BNwMyix4MHQoUPZk08+6fjZbrezTp06sezsbJft77//fjZu3DinbWlpaezRRx/1aZzEd9w9B65ks9lYVFQUW7t2ra9CJD7WmnPAZrOx4cOHs48//phNmTKFEguBc/cc+PDDD1m3bt2YxWLxV4jED9w9D5588kl2yy23OG3LyspiI0aM8GmcxPdakli88MILrF+/fk7bMjMzWUZGhg8jazkqhfIzi8WCgwcPIj093bFNLBYjPT0dubm5Lh+Tm5vr1B4AMjIyGm1PAltrzoErGQwGWK1WtG/f3ldhEh9q7Tnw6quvIj4+HtOnT/dHmMSHWnMObN26FcOGDcOTTz6JhIQE9O/fH2+88Qbsdru/wiZe1przYPjw4Th48KCjXCo/Px/bt2/Hbbfd5peYCb8C/TthCN8BtDWVlZWw2+1ISEhw2p6QkIBTp065fIxSqXTZXqlU+ixO4jutOQeu9OKLL6JTp04N3lyIMLTmHPjtt9/wySefIC8vzw8REl9rzTmQn5+Pn376Cf/617+wfft2nDt3Dk888QSsVisWLFjgj7CJl7XmPPjnP/+JyspKjBw5Eowx2Gw2PPbYY5g3b54/QiY8a+w7oUajgdFoRFhYGE+R1aEeC0IE5s0338T69evx9ddfQy6X8x0O8QOtVotJkyZh9erViI2N5TscwhOO4xAfH49Vq1Zh0KBByMzMxL///W+sXLmS79CIH+3ZswdvvPEGPvjgAxw6dAibN2/Gtm3bsGjRIr5DI4R6LPwtNjYWEokEKpXKabtKpUJiYqLLxyQmJrrVngS21pwD9d5++228+eab+PHHH3HNNdf4MkziQ+6eA+fPn8eFCxdwxx13OLZxHAcACAkJwenTp9G9e3ffBk28qjXvAx07dkRoaCgkEoljW58+faBUKmGxWCCVSn0aM/G+1pwHL7/8MiZNmoQZM2YAAAYMGAC9Xo9HHnkE//73vyEW0zXjYNbYd8Lo6GjeeysA6rHwO6lUikGDBiEnJ8exjeM45OTkYNiwYS4fM2zYMKf2ALBr165G25PA1ppzAACWLFmCRYsWYceOHRg8eLA/QiU+4u450Lt3bxw9ehR5eXmO25133ombb74ZeXl5SE5O9mf4xAta8z4wYsQInDt3zpFUAsCZM2fQsWNHSioEqjXngcFgaJA81CebjDHfBUsCQsB/J+R79HhbtH79eiaTydiaNWvYiRMn2COPPMJiYmKYUqlkjDE2adIkNmfOHEf733//nYWEhLC3336bnTx5ki1YsICmmxU4d8+BN998k0mlUrZp0yZWVlbmuGm1Wr6eAvGQu+fAlWhWKOFz9xwoKipiUVFRbObMmez06dPsu+++Y/Hx8ey1117j6ykQL3D3PFiwYAGLiopiX375JcvPz2c7d+5k3bt3Z/fffz9fT4F4QKvVssOHD7PDhw8zAOydd95hhw8fZoWFhYwxxubMmcMmTZrkaF8/3ezs2bPZyZMn2fvvv0/TzRLG3nvvPda5c2cmlUrZ0KFD2R9//OG4b9SoUWzKlClO7b/66it29dVXM6lUyvr168e2bdvm54iJt7lzDnTp0oUBaHBbsGCB/wMnXuPu+8DlKLEIDu6eA3v37mVpaWlMJpOxbt26sddff53ZbDY/R028zZ3zwGq1soULF7Lu3bszuVzOkpOT2RNPPMFqamr8Hzjx2O7du11+vtf/zqdMmcJGjRrV4DGpqalMKpWybt26sc8++8zvcTdGxBj1mxFCCCGEEEI8Q2MsCCGEEEIIIR6jxIIQQgghhBDiMUosCCGEEEIIIR6jxIIQQgghhBDiMUosCCGEEEIIIR6jxIIQQgghhBDiMUosCCGEEEIIIR6jxIIQQgghhBDiMUosCCGEBIw9e/ZAJBKhtra2yXYpKSlYvny5X2IihBDSMpRYEEIIcdvUqVMhEokgEokglUrRo0cPvPrqq7DZbB7td/jw4SgrK4NCoQAArFmzBjExMQ3a7d+/H4888ohHxyKEEOJdIXwHQAghRJjGjBmDzz77DGazGdu3b8eTTz6J0NBQzJ07t9X7lEqlSExMbLZdXFxcq49BCCHEN6jHghBCSKvIZDIkJiaiS5cuePzxx5Geno6tW7eipqYGkydPRrt27RAeHo6xY8fi7NmzjscVFhbijjvuQLt27RAREYF+/fph+/btAJxLofbs2YNp06ZBrVY7ekcWLlwIoGEpVFFREcaPH4/IyEhER0fj/vvvh0qlcty/cOFCpKam4j//+Q9SUlKgUCjwwAMPQKvV+uW1IoSQtoASC0IIIV4RFhYGi8WCqVOn4sCBA9i6dStyc3PBGMNtt90Gq9UKAHjyySdhNpvxyy+/4OjRo1i8eDEiIyMb7G/48OFYvnw5oqOjUVZWhrKyMsyaNatBO47jMH78eFRXV+Pnn3/Grl27kJ+fj8zMTKd258+fx5YtW/Ddd9/hu+++w88//4w333zTNy8GIYS0QVQKRQghxCOMMeTk5OCHH37A2LFjsWXLFvz+++8YPnw4AOCLL75AcnIytmzZgvvuuw9FRUW45557MGDAAABAt27dXO5XKpVCoVBAJBI1WR6Vk5ODo0ePoqCgAMnJyQCAzz//HP369cP+/fsxZMgQAHUJyJo1axAVFQUAmDRpEnJycvD666977bUghJC2jHosCCGEtMp3332HyMhIyOVyjB07FpmZmZg6dSpCQkKQlpbmaNehQwf06tULJ0+eBAA8/fTTeO211zBixAgsWLAAf/31l0dxnDx5EsnJyY6kAgD69u2LmJgYxzGBuvKp+qQCADp27Ijy8nKPjk0IIeQSSiwIIYS0ys0334y8vDycPXsWRqMRa9euhUgkavZxM2bMQH5+PiZNmoSjR49i8ODBeO+993web2hoqNPPIpEIHMf5/LiEENJWUGJBCCGkVSIiItCjRw907twZISF1lbV9+vSBzWbDn3/+6WhXVVWF06dPo2/fvo5tycnJeOyxx7B582Y8//zzWL16tctjSKVS2O32JuPo06cPiouLUVxc7Nh24sQJ1NbWOh2TEEKIb1FiQQghxGt69uyJ8ePH4+GHH8Zvv/2GI0eO4MEHH0RSUhLGjx8PAHj22Wfxww8/oKCgAIcOHcLu3bvRp08fl/tLSUmBTqdDTk4OKisrYTAYGrRJT0/HgAED8K9//QuHDh3Cvn37MHnyZIwaNQqDBw/26fMlhBByCSUWhBBCvOqzzz7DoEGDcPvtt2PYsGFgjGH79u2OUiS73Y4nn3wSffr0wZgxY3D11Vfjgw8+cLmv4cOH47HHHkNmZibi4uKwZMmSBm1EIhG++eYbtGvXDjfeeCPS09PRrVs3bNiwwafPkxBCiDMRY4zxHQQhhBBCCCFE2KjHghBCCCGEEOIxSiwIIYQQQgghHqPEghBCCCGEEOIxSiwIIYQQQgghHqPEghBCCCGEEOIxSiwIIYQQQgghHqPEghBCCCGEEOIxSiwIIYQQQgghHqPEghBCCCGEEOIxSiwIIYQQQgghHqPEghBCCCGEEOIxSiwIIYQQQgghHvt/EUz3eMK1C+YAAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(8, 3.5))\n", + "ax.fill_between(np.array(grid), 0, np.array(r_1d), alpha=0.4, label=\"Source $r$\")\n", + "ax.fill_between(np.array(grid), 0, np.array(c_1d), alpha=0.4, label=\"Target $c$\")\n", + "ax.set_xlabel(\"Position\")\n", + "ax.set_ylabel(\"Density\")\n", + "ax.set_title(\"Source and target distributions for 1D experiments\")\n", + "ax.legend()\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "This tutorial demonstrats how to leverage OTT-JAX to solve constrained optimal transport problems efficiently using the algorithms introduced in the paper. By using the Geometry class to encode dual shifts, we can seamlessly integrate additional linear constraints into the Sinkhorn loop." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.14.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}