{ "cells": [ { "cell_type": "markdown", "id": "1e240c11-442b-4c09-b245-f286907ec7a7", "metadata": {}, "source": [
 "# Using `refnx` on a cluster with MPI\n",
 "\n",
 "`refnx` can be used on a compute cluster, typically when you want to do a large MCMC sampling run. You will need to install these packages in the Python environment:\n",
 "\n",
 "- refnx\n",
 "- numpy\n",
 "- cython\n",
 "- schwimmbad\n",
 "- mpi4py\n",
 "- scipy\n",
 "\n",
 "For this specific example you'll also need the `corner` and `matplotlib` packages. Setting up a Python environment on a cluster can be tricky, so contact your cluster administrator if you need help.\n",
 "\n",
 "You would typically start the code running with something along the lines of:\n",
 "\n",
 "```\n",
 "mpiexec -n 8 python cf.py # requests parallelisation over 8 processes\n",
 "```\n",
 "\n",
 "(assuming the script is saved as `cf.py`). On a cluster this command is usually submitted through a job scheduler, such as PBS. Using a scheduler is outside the scope of this tutorial; again, your cluster administrator should be able to help there.\n",
 "The script writes the sampled chain to a text file called `steps.chain`, which can then be processed further to produce useful output.\n",
 "\n",
 "When you start modifying this example for your own purposes, begin by tailoring the `setup` function so that it returns a `refnx.analysis.Objective` for your system."
] }, { "cell_type": "markdown", "id": "d47ad3e9-3759-447d-a4d5-d204f116ba7e", "metadata": {}, "source": [
 "```python\n",
 "import sys\n",
 "import os.path\n",
 "\n",
 "import numpy as np\n",
 "\n",
 "import refnx\n",
 "from schwimmbad import MPIPool\n",
 "\n",
 "from refnx.reflect import SLD, Slab, ReflectModel\n",
 "from refnx.dataset import ReflectDataset\n",
 "from refnx.analysis import (Objective, CurveFitter, Transform, GlobalObjective)\n",
 "\n",
 "\n",
 "def setup():\n",
 "    # Tailor this function for your own system\n",
 "\n",
 "    # path to the example dataset bundled with the refnx test suite\n",
 "    DATASET_NAME = os.path.join(refnx.__path__[0],\n",
 "                                'analysis',\n",
 "                                'test',\n",
 "                                'c_PLP0011859_q.txt')\n",
 "\n",
 "    # load the data\n",
 "    data = ReflectDataset(DATASET_NAME)\n",
 "\n",
 "    # the materials we're using\n",
 "    si = SLD(2.07, name='Si')\n",
 "    sio2 = SLD(3.47, name='SiO2')\n",
 "    film = SLD(2, name='film')\n",
 "    d2o = SLD(6.36, name='d2o')\n",
 "\n",
 "    structure = si | sio2(30, 3) | film(250, 3) | d2o(0, 3)\n",
 "    structure[1].thick.setp(vary=True, bounds=(15., 50.))\n",
 "    structure[1].rough.setp(vary=True, bounds=(1., 6.))\n",
 "    structure[2].thick.setp(vary=True, bounds=(200, 300))\n",
 "    structure[2].sld.real.setp(vary=True, bounds=(0.1, 3))\n",
 "    structure[2].rough.setp(vary=True, bounds=(1, 6))\n",
 "\n",
 "    model = ReflectModel(structure, bkg=9e-6, scale=1.)\n",
 "    model.bkg.setp(vary=True, bounds=(1e-8, 1e-5))\n",
 "    model.scale.setp(vary=True, bounds=(0.9, 1.1))\n",
 "\n",
 "    # model.threads controls the parallelisation of the reflectivity calculation\n",
 "    # because we're parallelising the MCMC calculation we don't want oversubscription\n",
 "    # of the computer, so we only calculate the reflectivity with one thread.\n",
 "    model.threads = 1\n",
 "\n",
 "    # fit on a logR scale, but use weighting\n",
 "    objective = Objective(model, data, transform=Transform('logY'),\n",
 "                          use_weights=True)\n",
 "\n",
 "    return objective\n",
 "\n",
 "\n",
 "def structure_plot(obj, samples=0):\n",
 "    # plot sld profiles\n",
 "    import matplotlib.pyplot as plt\n",
 "    fig = plt.figure()\n",
 "    ax = fig.add_subplot(111)\n",
 "\n",
 "    if isinstance(obj, GlobalObjective):\n",
 "        if samples > 0:\n",
 "            # remember the current parameter values so they can be restored\n",
 "            savedparams = np.array(obj.parameters)\n",
 "            for pvec in obj.parameters.pgen(ngen=samples):\n",
 "                obj.setp(pvec)\n",
 "                for o in obj.objectives:\n",
 "                    if hasattr(o.model, 'structure'):\n",
 "                        ax.plot(*o.model.structure.sld_profile(),\n",
 "                                color=\"k\", alpha=0.01)\n",
 "\n",
 "            # put back saved_params\n",
 "            obj.setp(savedparams)\n",
 "\n",
 "        for o in obj.objectives:\n",
 "            if hasattr(o.model, 'structure'):\n",
 "                ax.plot(*o.model.structure.sld_profile(), zorder=20)\n",
 "\n",
 "        ax.set_ylabel('SLD / $10^{-6}\\\\AA^{-2}$')\n",
 "        ax.set_xlabel(\"z / $\\\\AA$\")\n",
 "\n",
 "    elif isinstance(obj, Objective) and hasattr(obj.model, 'structure'):\n",
 "        fig, ax = obj.model.structure.plot(samples=samples)\n",
 "\n",
 "    fig.savefig('steps_sld.png', dpi=1000)\n",
 "\n",
 "\n",
 "if __name__ == \"__main__\":\n",
 "    with MPIPool() as pool:\n",
 "        # worker processes wait here for tasks from the master, then exit\n",
 "        if not pool.is_master():\n",
 "            pool.wait()\n",
 "            sys.exit(0)\n",
 "        # buffering so the program doesn't try to write to the file\n",
 "        # constantly\n",
 "        with open('steps.chain', 'w', buffering=500000) as f:\n",
 "            objective = setup()\n",
 "            # Create the fitter and fit\n",
 "            fitter = CurveFitter(objective, nwalkers=300)\n",
 "            fitter.initialise('prior')\n",
 "            fitter.fit('differential_evolution')\n",
 "            # Collect 200 saved steps, which are thinned/separated by 10 steps.\n",
 "            fitter.sample(200, pool=pool.map, f=f, verbose=False, nthin=10)\n",
 "            f.flush()\n",
 "\n",
 "    # the following section is only necessary if you want to make some pretty graphs\n",
 "    try:\n",
 "        # create graphs of reflectivity and SLD profiles\n",
 "        # select the non-interactive backend before pyplot is imported\n",
 "        import matplotlib\n",
 "        matplotlib.use('agg')\n",
 "        import matplotlib.pyplot as plt\n",
 "\n",
 "        fig, ax = objective.plot(samples=1000)\n",
 "        ax.set_ylabel('R')\n",
 "        ax.set_xlabel(\"Q / $\\\\AA^{-1}$\")\n",
 "        fig.savefig('steps.png', dpi=1000)\n",
 "\n",
 "        structure_plot(objective, samples=1000)\n",
 "\n",
 "        # corner plot\n",
 "        fig = objective.corner()\n",
 "        fig.savefig('steps_corner.png')\n",
 "\n",
 "        # plot the Autocorrelation function of the chain\n",
 "        fig = plt.figure()\n",
 "        ax = fig.add_subplot(111)\n",
 "        ax.plot(fitter.acf())\n",
 "        ax.set_ylabel('autocorrelation')\n",
 "        ax.set_xlabel('step')\n",
 "        fig.savefig('steps-autocorrelation.png')\n",
 "    except ImportError:\n",
 "        # plotting is optional; skip it if matplotlib or corner is missing\n",
 "        pass\n",
 "```"
] }, { "cell_type": "markdown", "id": "b40748db-95db-4b34-881e-dde9efaf7aa6", "metadata": {}, "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" } }, "nbformat": 4, "nbformat_minor": 5 }