diff --git a/CP5/active_plugins/cpforeign/__init__.py b/CP5/active_plugins/cpforeign/__init__.py new file mode 100644 index 00000000..10bc1553 --- /dev/null +++ b/CP5/active_plugins/cpforeign/__init__.py @@ -0,0 +1,3 @@ +# -*- coding: utf-8 -*- + +__all__ = ["server"] diff --git a/CP5/active_plugins/cpforeign/bimz.py b/CP5/active_plugins/cpforeign/bimz.py new file mode 100644 index 00000000..ea119e21 --- /dev/null +++ b/CP5/active_plugins/cpforeign/bimz.py @@ -0,0 +1,109 @@ +import numpy as np +import logging + +from bioimageio.spec import InvalidDescr, load_description +from bioimageio.spec.model.v0_5 import ModelDescr +import bioimageio.core.prediction as bi_pred + +from skimage.filters import threshold_otsu +from skimage.measure import label +from skimage.morphology import closing, square +from skimage.segmentation import clear_border + +from server import ForeignToolClient + +logger = logging.getLogger(__name__) + +# https://bioimage.io/#/?tags=affable-shark&id=10.5281%2Fzenodo.5764892 +MODEL_ID = "affable-shark" +MODEL_DOI = "10.5281/zenodo.11092561" + +def load_model(): + loaded_description = load_description(MODEL_ID) + if isinstance(loaded_description, InvalidDescr): + raise ValueError(f"Failed to load {MODEL_ID}") + elif not isinstance(loaded_description, ModelDescr): + raise ValueError("This notebook expects a model 0.5 description") + + model = loaded_description + example_model_id = model.id + assert example_model_id is not None + + try: + descr = load_description(MODEL_ID) + except InvalidDescr as e: + logger.error(f"Invalid description: {e}") + return None + + return descr + +def predict(input_image, model): + out = bi_pred.predict(model=model, inputs={'input0': input_image}, skip_postprocessing=True, skip_preprocessing=True) + return np.array(out.members['output0'].data[0]) + +def run(image_data, image_header): + model = load_model() + + logger.debug("loaded model") + + # scaled image + im = image_data.copy() + logger.debug(f"provided image of shape {im.shape}, type {im.dtype}") + # im = (image_data / np.iinfo(image_data.dtype).max).astype(np.float32) + + pad_y = (64 - image_data.shape[0] % 64) % 64 + pad_x = (64 - image_data.shape[1] % 64) % 64 + # padded image + im = np.pad(im, ((0, pad_y), (0, pad_x)), mode='constant', constant_values=0) + logger.debug(f"padded image of shape {im.shape}, type {im.dtype}") + + # input image + im = im.reshape([1,1,im.shape[0],im.shape[1]]) + logger.debug(f"input image of shape {im.shape}, type {im.dtype}") + + # output image + logger.debug("running prediction") + res = predict(im, model) + del im + logger.debug(f"output image of shape {res.shape}, dtype {res.dtype}") + + # unpadded result + res = res[:, :image_data.shape[0], :image_data.shape[1]] + logger.debug(f"de-padded output image of shape {res.shape}, dtype {res.dtype}") + + # just the foreground probabilities, ignore boundaries + res = res[0] + logger.debug(f"using only fg prob of shape {res.shape}, dtype {res.dtype}") + + # threshold above certain prob + thresh = threshold_otsu(res) + logger.debug(f"threshold image shape {thresh.shape}, dtype {thresh.dtype}") + # make binary, with closing (remove small holes in fg with dilate then erode) + bw = closing(res > thresh, square(3)) + logger.debug(f"binary image of shape {bw.shape}, type {bw.dtype}") + + # remove border cells + # cleared = clear_border(bw) + # labels = label(cleared) + + # convert to labels + labels = label(bw) + logger.debug(f"labels of shape {labels.shape}, dtype {labels.dtype}") + + return labels + + +def main(): + client = ForeignToolClient(7878, cb=run) + client.receive_images() + +if __name__ == "__main__": + # init logging + logging.root.setLevel(logging.DEBUG) + stream_handler = logging.StreamHandler() + fmt = logging.Formatter(" [%(process)d|%(levelno)s] %(name)s::%(funcName)s: %(message)s") + stream_handler.setFormatter(fmt) + logging.root.addHandler(stream_handler) + + logger.debug("Starting bimz.py") + main() \ No newline at end of file diff --git a/CP5/active_plugins/cpforeign/bioimage_server.ipynb b/CP5/active_plugins/cpforeign/bioimage_server.ipynb new file mode 100644 index 00000000..c84da27d --- /dev/null +++ b/CP5/active_plugins/cpforeign/bioimage_server.ipynb @@ -0,0 +1,607 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import zmq\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Server setup and image retrieval" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "PORT = \"7878\"\n", + "DOMAIN = \"*\"\n", + "SOCKET_ADDR = f\"tcp://{DOMAIN}:{PORT}\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "HELLO = \"Hello\"\n", + "ACK = \"Acknowledge\"\n", + "DENIED = \"Denied\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "context = zmq.Context()\n", + "socket = context.socket(zmq.PAIR)\n", + "#socket.copy_threshold = 0\n", + "b = socket.bind(SOCKET_ADDR)\n", + "b" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def cleanup(context, socket):\n", + " print(\"destroying existing context\")\n", + " if socket:\n", + " socket.close()\n", + " if context:\n", + " context.term()\n", + " # destroy is more destructive\n", + " # doesn't require sockets closed first\n", + " # may leave them hanging if managed by other threads\n", + " context.destroy()\n", + " print(\"socket closed\", socket.closed, \"context closed\", context.closed)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# handshake\n", + "socket.send_string(HELLO)\n", + "print(\"Sent hello, waiting for acknowledgement...\")\n", + "ack = socket.recv_string()\n", + "if ack == ACK:\n", + " print('Received connection ack:', ack)\n", + "else:\n", + " print(\"Received unkown message\", ack)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#cleanup(context, socket)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "header = socket.recv_json()\n", + "header" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# acknowledge receipt of header, ask for image data\n", + "socket.send_string(ACK)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "im_bytes = socket.recv(copy=False)\n", + "im_bytes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "buf = memoryview(im_bytes)\n", + "im = np.frombuffer(buf, dtype=header['descr'])\n", + "im = (im * 255).astype(np.uint8)\n", + "im.shape = header['shape']\n", + "im" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# acknowledge receipt of image data\n", + "socket.send_string(ACK)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.imshow(im, cmap='gray')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "BIOIMAGE download and inspect the model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pprint import pprint\n", + "from typing_extensions import assert_never\n", + "\n", + "\n", + "from bioimageio.spec.pretty_validation_errors import enable_pretty_validation_errors_in_ipynb\n", + "from bioimageio.spec import InvalidDescr, load_description\n", + "from bioimageio.spec.model.v0_5 import ModelDescr\n", + "\n", + "from bioimageio.spec.model.v0_5 import ArchitectureFromLibraryDescr, ArchitectureFromFileDescr" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "enable_pretty_validation_errors_in_ipynb()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# https://bioimage.io/#/?tags=affable-shark&id=10.5281%2Fzenodo.5764892\n", + "MODEL_ID = \"affable-shark\"\n", + "MODEL_DOI = \"10.5281/zenodo.11092561\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "source = MODEL_ID\n", + "\n", + "loaded_description = load_description(source)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "loaded_description" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "loaded_description.validation_summary.display()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# let's make sure we have a valid model...\n", + "if isinstance(loaded_description, InvalidDescr):\n", + " raise ValueError(f\"Failed to load {source}\")\n", + "elif not isinstance(loaded_description, ModelDescr):\n", + " raise ValueError(\"This notebook expects a model 0.5 description\")\n", + "\n", + "model = loaded_description\n", + "example_model_id = model.id\n", + "assert example_model_id is not None" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(f\"The model is named '{model.name}'\")\n", + "print(f\"Description:\\n{model.description}\")\n", + "print(f\"License: {model.license}\")\n", + "\n", + "print(\"\\nThe authors of the model are:\")\n", + "pprint(model.authors)\n", + "print(f\"\\nIn addition to the authors it is maintained by:\")\n", + "pprint(model.maintainers)\n", + "\n", + "print(\"\\nIf you use this model, you are expected to cite:\")\n", + "pprint(model.cite)\n", + "\n", + "print(f\"\\nFurther documentation can be found here: {model.documentation}\")\n", + "\n", + "if model.git_repo is None:\n", + " print(\"\\nThere is no associated GitHub repository.\")\n", + "else:\n", + " print(f\"\\nThere is an associated GitHub repository: {model.git_repo}.\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for w in [(weights := model.weights).onnx, weights.keras_hdf5, weights.tensorflow_js, weights.tensorflow_saved_model_bundle, weights.torchscript,weights.pytorch_state_dict]:\n", + " if w is None:\n", + " continue\n", + "\n", + " print(w.weights_format_name)\n", + " print(f\"weights are available at {w.source.absolute()}\")\n", + " print(f\"and have a SHA-256 value of {w.sha256}\")\n", + " details = {k: v for k, v in w.model_dump(mode=\"json\", exclude_none=True).items() if k not in (\"source\", \"sha256\")}\n", + " if details:\n", + " print(f\"additonal metadata for {w.weights_format_name}:\")\n", + " pprint(details)\n", + "\n", + " print()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(f\"Model '{model.name}' requires {len(model.inputs)} input(s) with the following features:\")\n", + "for ipt in model.inputs:\n", + " print(f\"\\ninput '{ipt.id}' with axes:\")\n", + " pprint(ipt.axes)\n", + " print(f\"Data description: {ipt.data}\")\n", + " print(f\"Test tensor available at: {ipt.test_tensor.source.absolute()}\")\n", + " if len(ipt.preprocessing) > 1:\n", + " print(\"This input is preprocessed with: \")\n", + " for p in ipt.preprocessing:\n", + " print(p)\n", + "\n", + "print(\"\\n-------------------------------------------------------------------------------\")\n", + "# # and what the model outputs are\n", + "print(f\"Model '{model.name}' requires {len(model.outputs)} output(s) with the following features:\")\n", + "for out in model.outputs:\n", + " print(f\"\\noutput '{out.id}' with axes:\")\n", + " pprint(out.axes)\n", + " print(f\"Data description: {out.data}\")\n", + " print(f\"Test tensor available at: {out.test_tensor.source.absolute()}\")\n", + " if len(out.postprocessing) > 1:\n", + " print(\"This output is postprocessed with: \")\n", + " for p in out.postprocessing:\n", + " print(p)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "assert isinstance(model, ModelDescr)\n", + "if (w:=model.weights.pytorch_state_dict) is not None:\n", + " arch = w.architecture\n", + " print(f\"callable: {arch.callable}\")\n", + " if isinstance(arch, ArchitectureFromFileDescr):\n", + " print(f\"import from file: {arch.source.absolute()}\")\n", + " if arch.sha256 is not None:\n", + " print(f\"SHA-256: {arch.sha256}\")\n", + " elif isinstance(arch, ArchitectureFromLibraryDescr):\n", + " print(f\"import from module: {arch.import_from}\")\n", + " else:\n", + " assert_never(arch)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "BIOIMAGE - run prediction" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import bioimageio.core.prediction as bi_pred" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "im = (im / np.iinfo(im.dtype).max).astype(np.float32)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pad_y = (64 - im.shape[0] % 64) % 64\n", + "pad_x = (64 - im.shape[1] % 64) % 64\n", + "padded_image = np.pad(im, ((0, pad_y), (0, pad_x)), mode='constant', constant_values=0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "input_image = padded_image.reshape([1,1,padded_image.shape[0],padded_image.shape[1]])\n", + "del padded_image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "input_image.shape, input_image.dtype" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "out = bi_pred.predict(model=model, inputs={'input0': input_image}, skip_postprocessing=True, skip_preprocessing=True)\n", + "del input_image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "res = np.array(out.members['output0'].data[0])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "res = res[:, :im.shape[0], :im.shape[1]]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.imshow(res[0,:,:])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.imshow(res[1,:,:])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.imshow(np.stack([\n", + " np.zeros_like(res[0]),\n", + " res[0],\n", + " res[1]\n", + "]).transpose(1, 2, 0))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from skimage.filters import threshold_otsu\n", + "from skimage.segmentation import clear_border\n", + "from skimage.measure import label, regionprops\n", + "from skimage.morphology import closing, square\n", + "from skimage.color import label2rgb" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "res_im = res[0]\n", + "\n", + "thresh = threshold_otsu(res_im)\n", + "bw = closing(res_im > thresh, square(3))\n", + "# cleared = clear_border(bw)\n", + "# label_image = label(cleared)\n", + "label_image = label(bw)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.imshow(bw)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.imshow(label_image)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.imshow(label2rgb(label_image, image=im, bg_label=0))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Send back the results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "return_header = np.lib.format.header_data_from_array_1_0(label_image)\n", + "return_header" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "socket.send_json(return_header)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ack = socket.recv_string()\n", + "if ack == ACK:\n", + " print('Received return header ack:', ack)\n", + "else:\n", + " print(\"Received unkown message\", ack)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "socket.send(label_image, copy=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# socket.send_string(\"Cancel\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cleanup(context, socket)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "cp_lis", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.1.-1" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/CP5/active_plugins/cpforeign/server.py b/CP5/active_plugins/cpforeign/server.py new file mode 100644 index 00000000..d5a6f668 --- /dev/null +++ b/CP5/active_plugins/cpforeign/server.py @@ -0,0 +1,189 @@ +import zmq +import numpy as np +import logging + +HELLO = "Hello" +ACK = "Acknowledge" +DENY = "Deny" +CANCEL = "Cancel" + +class ForeignToolError(Exception): + pass + +def _receive_ack(socket, logger, subject=None): + ack = socket.recv_string() + if ack == ACK: + if subject: + logger.debug(f"received ack for {subject}") + else: + logger.debug("received ack") + return True + elif ack == DENY: + raise ForeignToolError("denied, aborting") + elif ack == CANCEL: + raise ForeignToolError("canceled, aborting") + else: + raise ForeignToolError("unexpected response", ack) + +def _send_ack(socket, logger, subject): + if subject: + logger.debug(f"sending ack for {subject}") + else: + logger.debug("sending ack") + socket.send_string(ACK) + +class ForeignToolServer(object): + def __init__(self, port, domain='*', protocol='tcp', wait_for_handshake=True): + """ + Launch a server on the given port. + """ + self._logger = logging.getLogger(f"{__name__} [server]") + + self._context = zmq.Context() + self._server_socket = self._context.socket(zmq.PAIR) + self._server_socket.bind(f"{protocol}://{domain}:{port}") + self._logger.info(f"launched on {self._server_socket.getsockopt(zmq.LAST_ENDPOINT)}") + + if wait_for_handshake: + self.wait_for_handshake() + + def wait_for_handshake(self): + self._logger.debug("waiting for handshake from client") + client_hello = self._server_socket.recv_string() + if client_hello == HELLO: + self._logger.debug("received correct handshake") + _send_ack(self._server_socket, self._logger, subject="handshake") + else: + self._logger.debug(f"received incorrect handshake {client_hello}") + self._logger.debug("sending deny") + self._server_socket.send_string(DENY) + raise ForeignToolError("server received incorrect handshake") + + def _serve_image(self, image_data): + """ + Serve an image to the client. + """ + header = np.lib.format.header_data_from_array_1_0(image_data) + + self._logger.debug(f"sending header {header} waiting for acknowledgement") + self._server_socket.send_json(header) + + ack = _receive_ack(self._server_socket, self._logger, subject="header") + + self._logger.debug(f"sending image data {image_data.shape} waiting for acknowledgement") + + self._server_socket.send(image_data, copy=False) + + ack = _receive_ack(self._server_socket, self._logger, subject="image data") + + labels_header = self._server_socket.recv_json() + + ack = _send_ack(self._server_socket, self._logger, subject="return header") + + label_bytes = self._server_socket.recv(copy=False) + + self._logger.debug("received label byte data") + + self._logger.debug("parsing label data") + labels = np.frombuffer(label_bytes, dtype=labels_header['descr']) + labels.shape = labels_header['shape'] + self._logger.debug(f"parse label data of shape {labels.shape}") + + _send_ack(self._server_socket, self._logger, subject="return data") + + return labels + + def serve_one_image(self, image_data): + """ + Serve an image to the client. + """ + return self._serve_image(image_data) + +class ForeignToolClient(object): + def __init__(self, port, domain='localhost', protocol='tcp', do_handshake=True, cb=None): + """ + Connect to a server on the given port. + """ + self._logger = logging.getLogger(f"{__name__} [client]") + + self._context = zmq.Context() + self._client_socket = self._context.socket(zmq.PAIR) + self._client_socket.connect(f"{protocol}://{domain}:{port}") + self._logger.info(f"connected to {self._client_socket.getsockopt(zmq.LAST_ENDPOINT)}") + + if cb: + self.register_cb(cb) + + if do_handshake: + self.do_handshake() + + def do_handshake(self): + """ + Handshake with the server. + """ + self._client_socket.send_string(HELLO) + response = _receive_ack(self._client_socket, self._logger, subject="handshake") + + def register_cb(self, cb): + """ + Register a callback to be executed on the server. + Must be run before receeive_image + """ + self._cb = cb + + def _execute_cb(self, im, header): + """ + Execute the callback on the server. + """ + return self._cb(im, header) + + def _receive_image(self): + """ + Receive an image from the server. + """ + header = self._client_socket.recv_json() + self._logger.debug(f"received header {header}") + + _send_ack(self._client_socket, self._logger, subject="header") + + im_bytes = self._client_socket.recv(copy=False) + self._logger.debug("received image bytes") + + self._logger.debug("parsing image data") + buf = memoryview(im_bytes) + im = np.frombuffer(buf, dtype=header['descr']) + im.shape = header['shape'] + self._logger.debug(f"parsed image data {im.shape}") + + _send_ack(self._client_socket, self._logger, subject="image data") + + self._logger.debug("executing callback") + return_data = self._execute_cb(im, header) + self._logger.debug("executed callback") + + return_header = np.lib.format.header_data_from_array_1_0(return_data) + self._logger.debug(f"returning header {return_header}") + self._client_socket.send_json(return_header) + + ack = _receive_ack(self._client_socket, self._logger, subject="return header") + + self._logger.debug("returning data") + self._client_socket.send(return_data, copy=False) + + ack = _receive_ack(self._client_socket, self._logger, subject="return data") + + def receive_one_image(self): + """ + Receive a single image from the server. + """ + self._receive_image() + + def receive_images(self): + """ + Receive images from the server. + """ + while True: + try: + self._receive_image() + except ForeignToolError: + break diff --git a/CP5/active_plugins/cpforeign/thresh.py b/CP5/active_plugins/cpforeign/thresh.py new file mode 100644 index 00000000..25c1956f --- /dev/null +++ b/CP5/active_plugins/cpforeign/thresh.py @@ -0,0 +1,52 @@ +import logging +import numpy as np +import skimage as ski +import scipy as sp + +from server import ForeignToolClient + +logger = logging.getLogger(__name__) + + +def run(image_data, image_header): + im = (image_data * 255).astype(np.uint8) + + markers = np.zeros_like(im, dtype=np.uint8) + IDK = 0 + BG = 1 + FG = 2 + markers[im < 30] = BG + markers[im > 50] = FG + # rest = IDK + + elevation_map = ski.filters.sobel(im) + segmentation = ski.segmentation.watershed(elevation_map, markers) + segmentation = sp.ndimage.binary_fill_holes(segmentation - 1) + + labels, _ = sp.ndimage.label(segmentation) + + # remove small objects + sizes = np.bincount(labels.ravel()) + mask_sizes = sizes > 20 + mask_sizes[0] = 0 + segmentation = mask_sizes[labels] + + labels, _ = sp.ndimage.label(segmentation) + + return labels + + +def main(): + client = ForeignToolClient(7878, cb=run) + client.receive_images() + +if __name__ == "__main__": + # init logging + logging.root.setLevel(logging.DEBUG) + stream_handler = logging.StreamHandler() + fmt = logging.Formatter(" [%(process)d|%(levelno)s] %(name)s::%(funcName)s: %(message)s") + stream_handler.setFormatter(fmt) + logging.root.addHandler(stream_handler) + + logger.debug("Starting thresh.py") + main() \ No newline at end of file diff --git a/CP5/active_plugins/cpforeign/zmq_server.ipynb b/CP5/active_plugins/cpforeign/zmq_server.ipynb new file mode 100644 index 00000000..5c88d461 --- /dev/null +++ b/CP5/active_plugins/cpforeign/zmq_server.ipynb @@ -0,0 +1,408 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import zmq\n", + "import numpy as np\n", + "import skimage as ski\n", + "import matplotlib.pyplot as plt\n", + "import scipy as sp" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "PORT = \"7878\"\n", + "DOMAIN = \"*\"\n", + "SOCKET_ADDR = f\"tcp://{DOMAIN}:{PORT}\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "HELLO = \"Hello\"\n", + "ACK = \"Acknowledge\"\n", + "DENIED = \"Denied\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def cleanup(context, socket):\n", + " print(\"destroying existing context\")\n", + " if socket:\n", + " socket.close()\n", + " if context:\n", + " context.term()\n", + " # destroy is more destructive\n", + " # doesn't require sockets closed first\n", + " # may leave them hanging if managed by other threads\n", + " context.destroy()\n", + " print(\"socket closed\", socket.closed, \"context closed\", context.closed)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "context = zmq.Context()\n", + "socket = context.socket(zmq.PAIR)\n", + "#socket.copy_threshold = 0\n", + "b = socket.bind(SOCKET_ADDR)\n", + "b" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#cleanup(context, socket)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# handshake\n", + "socket.send_string(HELLO)\n", + "print(\"Sent hello, waiting for acknowledgement...\")\n", + "ack = socket.recv_string()\n", + "if ack == ACK:\n", + " print('Received connection ack:', ack)\n", + "else:\n", + " print(\"Received unkown message\", ack)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "header = socket.recv_json()\n", + "header" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# acknowledge receipt of header, ask for image data\n", + "socket.send_string(ACK)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "im_bytes = socket.recv(copy=False)\n", + "im_bytes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "buf = memoryview(im_bytes)\n", + "im = np.frombuffer(buf, dtype=header['descr'])\n", + "im = (im * 255).astype(np.uint8)\n", + "im.shape = header['shape']\n", + "im" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# acknowledge receipt of image data\n", + "socket.send_string(ACK)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.imshow(im, cmap='gray')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import napari" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "viewer = napari.view_image(im)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create the histogram\n", + "plt.figure(figsize=(10, 6))\n", + "plt.hist(im.flatten(), bins=256, range=(0, 256), edgecolor='black')\n", + "\n", + "# Labeling the axes and adding a title\n", + "plt.xlabel('Pixel Value')\n", + "plt.ylabel('Frequency')\n", + "plt.title('Pixel Value Frequency Distribution')\n", + "\n", + "# Display the plot\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(plt.colormaps())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cmap_name = 'Paired'\n", + "# Reshape to a 2D array (gradient image) with shape (1, 256)\n", + "gradient = np.arange(256)[np.newaxis, :]\n", + "\n", + "# Display the gradient image using the specified colormap\n", + "plt.figure(figsize=(8, 2))\n", + "plt.imshow(gradient, aspect='auto', cmap=cmap_name)\n", + "plt.colorbar(label='Pixel Value')\n", + "plt.title(f'Colormap: {cmap_name}')\n", + "plt.xlabel('Value')\n", + "plt.ylabel('Intensity')\n", + "plt.yticks([])\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ets = lambda v: int(v/65535*255)\n", + "ste = lambda v: int(v/255*65535)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "markers = np.zeros_like(im, dtype=np.uint8)\n", + "IDK = 0\n", + "BG = 1\n", + "FG = 2\n", + "# markers[im < ste(30)] = BG # blue\n", + "# markers[im > ste(50)] = FG # brown\n", + "markers[im < 30] = BG # blue\n", + "markers[im > 50] = FG # brown\n", + "# IDK = RED\n", + "\n", + "# plt.imshow(np.array(markers / markers.max() * (2**8-1), dtype=np.uint8), cmap='gray')\n", + "plt.imshow(np.array(markers / markers.max() * (2**8-1), dtype=np.uint8), cmap=cmap_name)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "elevation_map = ski.filters.sobel(im)\n", + "\n", + "plt.imshow(elevation_map)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "segmentation = ski.segmentation.watershed(elevation_map, markers)\n", + "segmentation = sp.ndimage.binary_fill_holes(segmentation - 1)\n", + "\n", + "plt.imshow(segmentation)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a figure with two subplots side by side\n", + "fig, axes = plt.subplots(1, 2, figsize=(10, 5))\n", + "\n", + "# Display the first image in the first subplot\n", + "axes[0].imshow(im, cmap='gray', vmin=0, vmax=255)\n", + "axes[0].set_title('Orig')\n", + "axes[0].axis('off') # Hide the axes\n", + "\n", + "# Display the second image in the second subplot\n", + "axes[1].imshow(segmentation, cmap=cmap_name, vmin=0, vmax=segmentation.max())\n", + "axes[1].set_title('Seg')\n", + "axes[1].axis('off') # Hide the axes\n", + "\n", + "# Adjust layout to prevent overlap\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "labels, _ = sp.ndimage.label(segmentation)\n", + "\n", + "# remove small objects\n", + "sizes = np.bincount(labels.ravel())\n", + "mask_sizes = sizes > 20\n", + "mask_sizes[0] = 0\n", + "segmentation = mask_sizes[labels]\n", + "\n", + "labels, _ = sp.ndimage.label(segmentation)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.imshow(labels)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "return_header = np.lib.format.header_data_from_array_1_0(labels)\n", + "return_header" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "socket.send_json(return_header)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ack = socket.recv_string()\n", + "if ack == ACK:\n", + " print('Received return header ack:', ack)\n", + "else:\n", + " print(\"Received unkown message\", ack)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "socket.send(labels, copy=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# socket.send_string(\"Cancel\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cleanup(context, socket)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "cp_lis", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/CP5/active_plugins/runforeignenv.py b/CP5/active_plugins/runforeignenv.py new file mode 100644 index 00000000..20c8b54c --- /dev/null +++ b/CP5/active_plugins/runforeignenv.py @@ -0,0 +1,172 @@ +import shlex +import sys +import os +import re +import subprocess +import threading +import logging + +from cellprofiler_core.module.image_segmentation import ImageSegmentation +from cellprofiler_core.setting.text import Text +from cellprofiler_core.setting.text import Filename +from cellprofiler_core.setting.text import Integer +from cellprofiler_core.object import Objects + +from cpforeign.server import ForeignToolServer + +LOGGER = logging.getLogger(__name__) + +HELLO = "Hello" +ACK = "Acknowledge" +DENIED = "Denied" + +__doc__ = """\ +RunForeignEnv +============ + +**RunForeign** runs a foreign tool, in a foreign (conda) environment, via sockets. + + +Assumes there is a client up and running. + +| + +============ ============ =============== +Supports 2D? Supports 3D? Respects masks? +============ ============ =============== +YES NO YES +============ ============ =============== + +""" + +def _run_logger(workR): + # this thread shuts itself down by reading from worker's stdout + # which either reads content from stdout or blocks until it can do so + # when the worker is shut down, empty byte string is returned continuously + # which evaluates as None so the break is hit + # I don't really like this approach; we should just shut it down with the other + # threads explicitly + while True: + try: + print('reading') + line = workR.stdout.readline() + if (type(line) == bytes): + line = line.decode("utf-8") + if not line: + break + log_msg_match = re.match(fr"{workR.pid}\|(10|20|30|40|50)\|(.*)", line) + if log_msg_match: + levelno = int(log_msg_match.group(1)) + msg = log_msg_match.group(2) + else: + levelno = 20 + msg = line + + LOGGER.log(levelno, "\n\r [Worker (%d)] %s", workR.pid, msg.rstrip()) + + except Exception as e: + LOGGER.exception(e) + break + + +class RunForeignEnv(ImageSegmentation): + category = "Object Processing" + + module_name = "RunForeignEnv" + + variable_revision_number = 1 + + def create_settings(self): + super().create_settings() + + self._server = None + self._client_launched = False + + self.server_port = Integer( + text="Server port number", + value=7878, + minval=0, + doc="""\ +The port number which the server is listening on. The server must be launched manually first. +""", + ) + + self.env_name = Text(text="Conda environment name", value="foreign-thresh") + + self.algo_path = Filename(text="Algorithm path", value="/Users/ngogober/Developer/CellProfiler/CellProfiler-plugins/CP5/active_plugins/cpforeign/thresh.py") + + def settings(self): + return super().settings() + [self.server_port, self.env_name, self.algo_path] + + # ImageSegmentation defines this so we have to overide it + def visible_settings(self): + return self.settings() + + # ImageSegmentation defines this so we have to overide it + def volumetric(self): + return False + + def prepare_run(self, workspace): + + LOGGER.debug(">>> Preparing run") + if not self._server: + LOGGER.debug(">>> Initializing server") + self._server = ForeignToolServer(self.server_port.value, wait_for_handshake=False) + + if not self._client_launched: + LOGGER.debug(">>> Launching client") + command = f"conda run --no-capture-output -n {self.env_name.value} python {self.algo_path.value}" + args = shlex.split(command) + env = os.environ.copy() + env["PYTHONUNBUFFERED"] = "1" + self._client_proc = subprocess.Popen(args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=sys.stdout, bufsize=1, universal_newlines=True, env=env, close_fds=False) + #self._client_thread = threading.Thread(target=_run_logger, args=(self._client_proc,), name="foreign client stdout logger thread") + #self._client_thread.start() + + self._client_launched = True + self._server.wait_for_handshake() + + return True + + def post_run(self, workspace): + if self._client_launched: + LOGGER.debug(">>> Shuttding down client") + #self._client_thread.join() + self._client_proc.terminate() + + def run(self, workspace): + # TODO: is this supposed to not run in test mode? because it doesn't... + self.prepare_run(workspace) + + x_name = self.x_name.value + + y_name = self.y_name.value + + images = workspace.image_set + + x = images.get_image(x_name) + + dimensions = x.dimensions + + x_data = x.pixel_data + + y_data = self._server.serve_one_image(x_data) + + y = Objects() + + y.segmented = y_data + + y.parent_image = x.parent_image + + objects = workspace.object_set + + objects.add_objects(y, y_name) + + self.add_measurements(workspace) + + if self.show_window: + workspace.display_data.x_data = x_data + + workspace.display_data.y_data = y_data + + workspace.display_data.dimensions = dimensions diff --git a/CP5/active_plugins/runforeignnb.py b/CP5/active_plugins/runforeignnb.py new file mode 100644 index 00000000..d26dd9b5 --- /dev/null +++ b/CP5/active_plugins/runforeignnb.py @@ -0,0 +1,216 @@ +import zmq +import numpy as np +import logging + +from cellprofiler_core.module.image_segmentation import ImageSegmentation +from cellprofiler_core.setting.do_something import DoSomething +from cellprofiler_core.setting.text import Integer +from cellprofiler_core.object import Objects + +LOGGER = logging.getLogger(__name__) + +HELLO = "Hello" +ACK = "Acknowledge" +DENIED = "Denied" + +__doc__ = """\ +RunForeignNb +============ + +**RunForeign** runs a foreign notebook via sockets. + + +Assumes there is a notebook running as a server to do the handshake, receive an image, and send back labels. + +| + +============ ============ =============== +Supports 2D? Supports 3D? Respects masks? +============ ============ =============== +YES NO YES +============ ============ =============== + +""" + +class RunForeignNb(ImageSegmentation): + category = "Object Processing" + + module_name = "RunForeignNb" + + variable_revision_number = 1 + + def create_settings(self): + super().create_settings() + + self.context = None + self.server_socket = None + + # TODO: launch server automatically, if necessary + self.server_port = Integer( + text="Server port number", + value=7878, + minval=0, + doc="""\ +The port number which the server is listening on. The server must be launched manually first. +""", + ) + + # TODO: perform handshake automatically, if necessary + self.server_handshake = DoSomething( + "", + "Perform Server Handshake", + self.do_server_handshake, + doc=f"""\ +Press this button to do an initial handshake with the server. +This must be done manually, once. +""", + ) + + def settings(self): + return super().settings() + [self.server_port, self.server_handshake] + + # ImageSegmentation defines this so we have to overide it + def visible_settings(self): + return self.settings() + + # ImageSegmentation defines this so we have to overide it + def volumetric(self): + return False + + def run(self, workspace): + x_name = self.x_name.value + + y_name = self.y_name.value + + images = workspace.image_set + + x = images.get_image(x_name) + + dimensions = x.dimensions + + x_data = x.pixel_data + + y_data = self.do_server_execute(x_data) + + y = Objects() + + y.segmented = y_data + + y.parent_image = x.parent_image + + objects = workspace.object_set + + objects.add_objects(y, y_name) + + self.add_measurements(workspace) + + if self.show_window: + workspace.display_data.x_data = x_data + + workspace.display_data.y_data = y_data + + workspace.display_data.dimensions = dimensions + + def do_server_handshake(self): + def cleanup(): + LOGGER.debug("destroying existing context") + if self.server_socket is not None: + self.server_socket.close() + LOGGER.debug(f"socket closed: {self.server_socket.closed}") + self.server_socket = None + if self.context is not None: + self.context.term() + # destroy is more destructive + # doesn't require sockets closed first + # may leave them hanging if managed by other threads + self.context.destroy() + LOGGER.debug(f"context closed: {self.context.closed}") + self.context = None + + port = str(self.server_port.value) + domain = "localhost" + socket_addr = f"tcp://{domain}:{port}" + + if self.context is not None or self.server_socket is not None: + cleanup() + + self.context = zmq.Context() + self.server_socket = self.context.socket(zmq.PAIR) + #self.server_socket.copy_threshold = 0 + + LOGGER.debug(f"connecting to {socket_addr}") + + c = self.server_socket.connect(socket_addr) + + LOGGER.debug(f"setup socket at {c}") + + LOGGER.debug("receiving handshake, waiting for acknowledgement") + + poller = zmq.Poller() + poller.register(self.server_socket, zmq.POLLIN) + while True: + socks = dict(poller.poll(5000)) + if socks.get(self.server_socket) == zmq.POLLIN: + break + else: + LOGGER.debug("handshake timeout") + cleanup() + return + + hello = self.server_socket.recv_string() + + if hello == HELLO: + LOGGER.debug(f"received correct greeting {hello}") + else: + LOGGER.debug(f"received unexpected greeting {hello}") + + LOGGER.debug("acknowledging handshake") + + self.server_socket.send_string(ACK) + + + def do_server_execute(self, im_data): + dummy_data = lambda: np.array([[]]) + + socket = self.server_socket + header = np.lib.format.header_data_from_array_1_0(im_data) + + LOGGER.debug(f"sending header {header}; waiting for acknowledgement") + socket.send_json(header) + + ack = socket.recv_string() + if ack == ACK: + LOGGER.debug(f"header acknowledged: {ack}") + else: + LOGGER.debug(f"unexpected response {ack}") + return dummy_data() + + LOGGER.debug(f"sending image data {im_data.shape}; waiting for acknowledgement") + socket.send(im_data, copy=False) + + ack = socket.recv_string() + if ack == ACK: + LOGGER.debug(f"image data acknowledged {ack}") + elif ack == DENIED: + LOGGER.debug(f"image data denied, aborting {ack}") + return dummy_data() + else: + LOGGER.debug(f"unknown response to image data {ack}") + return dummy_data() + + LOGGER.debug("waiting for return header") + return_header = socket.recv_json() + LOGGER.debug(f"received return header {return_header}") + + LOGGER.debug("acknowledging header reciept") + socket.send_string(ACK) + + LOGGER.debug("waiting for image data") + label_data_buf = socket.recv(copy=False) + LOGGER.debug("image data received") + + labels = np.frombuffer(label_data_buf, dtype=return_header['descr']) + labels.shape = return_header['shape'] + LOGGER.debug(f"returning label data {labels.shape}") + + return labels