"""
refnx is distributed under the following license:
Copyright (c) 2015 A. R. J. Nelson, ANSTO
Permission to use and redistribute the source code or binary forms of this
software and its documentation, with or without modification is hereby
granted provided that the above notice of copyright, these terms of use,
and the disclaimer of warranty below appear in the source code and
documentation, and that none of the names of above institutions or
authors appear in advertising or endorsement of works derived from this
software without specific prior written permission from all parties.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THIS SOFTWARE.
"""
import time
import datetime
import pickle
import warnings
import glob
import os
import numpy as np
import ipywidgets as widgets
import traitlets
from traitlets import HasTraits
import matplotlib.gridspec as gridspec
from refnx.reflect import Slab, ReflectModel
from refnx.dataset import ReflectDataset
from refnx.analysis import Objective, CurveFitter, Transform
from refnx._lib import flatten, possibly_open_file
class ReflectModelView(HasTraits):
    """
    An ipywidgets viewport of a `refnx.reflect.ReflectModel`.

    Parameters
    ----------
    reflect_model: refnx.reflect.ReflectModel

    Notes
    -----
    Use the `model_box` property to view/modify the ReflectModel parameters.
    Use the `limits_box` property to view the limits for the varying
    parameters.
    Observe the `view_changed` traitlet to determine when widget values are
    changed.
    Observe the `view_redraw` traitlet to determine when a complete redraw
    of the view is required (because the number of widgets has changed for
    example).
    Observe the `num_varying` traitlet to determine when the number of
    varying parameters in the model changes.
    """

    # traitlet to say when params were last altered
    view_changed = traitlets.Float(time.time())

    # traitlet to ask when a redraw of the GUI is requested.
    # e.g. the number of layers has changed, or there are a
    # different number of fitted parameters requiring
    # different limit widgets.
    view_redraw = traitlets.Float(time.time())

    # number of varying parameters in the model
    num_varying = traitlets.Int(0)

    def __init__(self, reflect_model):
        super().__init__()

        self.model = reflect_model
        self.structure_view = StructureView(self.model.structure)
        # the value widget that the main slider is currently linked to
        self.last_selected_param = None
        self.param_widgets_link = {}

        slab_views = self.structure_view.slab_views

        # the thickness/roughness of the fronting slab, and the thickness of
        # the backing slab, are not editable from the GUI.
        slab_views[0].w_thick.disabled = True
        slab_views[0].c_thick.disabled = True
        slab_views[0].w_rough.disabled = True
        slab_views[0].c_rough.disabled = True
        slab_views[-1].w_thick.disabled = True
        slab_views[-1].c_thick.disabled = True

        # got to listen to all the slab views
        for slab_view in slab_views:
            slab_view.observe(
                self._on_slab_params_modified, names=["view_changed"]
            )

        # if you'd like to change the number of layers
        self.w_layers = widgets.BoundedIntText(
            description="Number of layers",
            value=len(slab_views) - 2,
            min=0,
            max=1000,
            style={"description_width": "120px"},
            continuous_update=False,
        )
        self.w_layers.observe(self._on_change_layers, names=["value"])

        # where you're going to add/remove layers
        # varying layers is a flag to say if you're currently in the process
        # of adding/removing layers
        self._varying_layers = False
        self._location = None
        self.ok_button = None
        self.cancel_button = None

        # associated with ReflectModel
        p = reflect_model.scale
        self.w_scale = widgets.FloatText(
            value=p.value,
            description="scale",
            step=0.01,
            style={"description_width": "120px"},
        )
        self.c_scale = widgets.Checkbox(value=p.vary)
        self.scale_low_limit = widgets.FloatText(value=p.bounds.lb, step=0.01)
        self.scale_hi_limit = widgets.FloatText(value=p.bounds.ub, step=0.01)

        p = reflect_model.bkg
        self.w_bkg = widgets.FloatText(
            value=p.value,
            description="background",
            step=1e-7,
            style={"description_width": "120px"},
        )
        self.c_bkg = widgets.Checkbox(value=p.vary)
        self.bkg_low_limit = widgets.FloatText(value=p.bounds.lb, step=1e-8)
        self.bkg_hi_limit = widgets.FloatText(value=p.bounds.ub, step=1e-7)

        p = reflect_model.dq
        self.w_dq = widgets.BoundedFloatText(
            value=p.value, description="dq/q", step=0.1, min=0, max=20.0
        )
        self.c_dq = widgets.Checkbox(value=p.vary)
        self.dq_low_limit = widgets.BoundedFloatText(
            value=p.bounds.lb, min=0, max=20, step=0.1
        )
        self.dq_hi_limit = widgets.BoundedFloatText(
            value=p.bounds.ub, min=0, max=20, step=0.1
        )

        self.c_scale.style.description_width = "0px"
        self.c_bkg.style.description_width = "0px"
        self.c_dq.style.description_width = "0px"

        self.do_fit_button = widgets.Button(description="Do Fit")
        self.to_code_button = widgets.Button(description="To code")
        self.save_model_button = widgets.Button(description="Save Model")
        self.load_model_button = widgets.Button(description="Load Model")

        widget_list = [
            self.w_scale,
            self.c_scale,
            self.w_bkg,
            self.c_bkg,
            self.w_dq,
            self.c_dq,
        ]

        limits_widgets_list = [
            self.scale_low_limit,
            self.scale_hi_limit,
            self.bkg_low_limit,
            self.bkg_hi_limit,
            self.dq_low_limit,
            self.dq_hi_limit,
        ]

        for widget in widget_list:
            widget.observe(self._on_model_params_modified, names=["value"])

        for widget in limits_widgets_list:
            widget.observe(self._on_model_limits_modified, names=["value"])

        # button to create default limits
        self.default_limits_button = widgets.Button(
            description="Set default limits"
        )
        self.default_limits_button.on_click(self.default_limits)

        # widgets for easy model change
        self.model_slider = widgets.FloatSlider()
        self.model_slider.layout = widgets.Layout(width="100%")
        self.model_slider_link = None
        self.model_slider_min = widgets.FloatText()
        self.model_slider_min.layout = widgets.Layout(width="10%")
        self.model_slider_max = widgets.FloatText()
        self.model_slider_max.layout = widgets.Layout(width="10%")
        self.model_slider_max.observe(
            self._on_slider_limits_modified, names=["value"]
        )
        self.model_slider_min.observe(
            self._on_slider_limits_modified, names=["value"]
        )

        self.num_varying = len(self.model.parameters.varying_parameters())
        self._link_param_widgets()

    def _on_model_params_modified(self, change):
        """
        Called when ReflectModel parameters are varied.

        Only the value widgets (position 0) and vary checkboxes (position 1)
        of scale/bkg/dq are observed by this callback.
        """
        owner = change["owner"]
        for par in (self.model.scale, self.model.bkg, self.model.dq):
            wids = self.param_widgets_link[id(par)]
            if owner not in wids:
                continue

            loc = wids.index(owner)
            if loc == 0:
                par.value = owner.value
                # this captures when the user starts modifying a different
                # parameter
                self._possibly_link_slider(owner)
                self.view_changed = time.time()
            elif loc == 1:
                # you are changing the number of varying parameters
                par.vary = owner.value
                # set the number of varying parameters, then rebuild the
                # limit widgets by redrawing the box.
                # The slider must NOT be linked to a checkbox: its value is a
                # bool, not a float.
                self.num_varying = len(
                    self.model.parameters.varying_parameters()
                )
                self.view_redraw = time.time()
            return

    def _on_model_limits_modified(self, change):
        """
        When a limit widget is changed, update corresponding limits in the
        underlying ReflectModel.
        """
        owner = change["owner"]
        for par in (self.model.scale, self.model.bkg, self.model.dq):
            wids = self.param_widgets_link[id(par)]
            if owner in wids:
                loc = wids.index(owner)
                if loc == 2:
                    par.bounds.lb = owner.value
                elif loc == 3:
                    par.bounds.ub = owner.value
                return

    def default_limits(self, change):
        """
        Makes default limits for the parameters being varied.

        Each varying parameter gets bounds spanning 0 and twice its current
        value, after which the widgets are refreshed.

        Parameters
        ----------
        change
            Ignored; this method is a `Button.on_click` callback.
        """
        for par in self.model.parameters.varying_parameters():
            par.bounds.lb = min(0, 2 * par.value)
            par.bounds.ub = max(0, 2 * par.value)

        self.refresh()

    def _on_slab_params_modified(self, change):
        """
        Called when slab parameters are varied.

        `change["owner"]` is the `SlabView` that fired; its
        `param_being_varied` attribute holds the widget the user altered.
        """
        widget = change["owner"].param_being_varied
        if isinstance(widget, widgets.Checkbox):
            # you are changing the number of fitted parameters
            # set the number of varying parameters
            self.num_varying = len(self.model.parameters.varying_parameters())
            # need to rebuild the limit widgets, achieved by redrawing box
            self.view_redraw = time.time()
        else:
            # this captures when the user starts modifying a different
            # parameter
            self._possibly_link_slider(widget)
            self.view_changed = time.time()

    def _possibly_link_slider(self, change_owner):
        """
        When a ReflectModel value is changed link a slider widget to the
        parameter that is being varied.

        The slider range is set to span 0 and twice the widget's value.
        """
        if change_owner is self.last_selected_param:
            return

        self.last_selected_param = change_owner

        # Break the old link *before* changing the slider range. Otherwise
        # the slider clamping its value to the new range would write that
        # clamped value back into the previously linked parameter.
        if self.model_slider_link is not None:
            self.model_slider_link.unlink()
            self.model_slider_link = None

        value = change_owner.value
        self.model_slider_max.value = max(0, 2.0 * value)
        self.model_slider_min.value = min(0, 2.0 * value)

        # Link only once the slider range contains the parameter value, so
        # that the initial synchronisation doesn't clamp (and misreport) it.
        self.model_slider_link = widgets.link(
            (change_owner, "value"),
            (self.model_slider, "value"),
        )

    def _on_slider_limits_modified(self, change):
        """
        Callback when adjusting the min, max widgets for the main slider
        widget. The slider step is 1/200 of its range.
        """
        self.model_slider.max = self.model_slider_max.value
        self.model_slider.min = self.model_slider_min.value
        self.model_slider.step = (
            self.model_slider.max - self.model_slider.min
        ) / 200.0

    def _set_controls_disabled(self, disabled):
        """
        Enable/disable the layer-count widget and the action buttons.

        They are disabled while the user chooses where to add/remove layers,
        and re-enabled once that choice is confirmed or cancelled.
        """
        for wid in (
            self.w_layers,
            self.do_fit_button,
            self.to_code_button,
            self.save_model_button,
            self.load_model_button,
        ):
            wid.disabled = disabled

    def _on_change_layers(self, change):
        """
        Called when the number-of-layers widget changes. Shows a location
        chooser plus OK/Cancel buttons (via a redraw) so the user can say
        where layers are inserted/removed.
        """
        self.ok_button = widgets.Button(description="OK")
        if change["new"] > change["old"]:
            description = "Insert before which layer?"
            min_loc = 1
            max_loc = len(self.model.structure) - 2 + 1
            self.ok_button.on_click(self._increase_layers)
        elif change["new"] < change["old"]:
            min_loc = 1
            max_loc = (
                len(self.model.structure)
                - 2
                - (change["old"] - change["new"])
                + 1
            )
            description = "Remove from which layer?"
            self.ok_button.on_click(self._decrease_layers)
        else:
            return

        self._varying_layers = True
        self._set_controls_disabled(True)

        self._location = widgets.BoundedIntText(
            value=min_loc,
            description=description,
            min=min_loc,
            max=max_loc,
            style={"description_width": "initial"},
        )
        self.cancel_button = widgets.Button(description="Cancel")
        self.cancel_button.on_click(self._cancel_layers)
        self.view_redraw = time.time()

    def _increase_layers(self, b):
        """
        OK-button callback: insert new slabs (and their views) at the chosen
        location until the structure has `w_layers.value` layers.
        """
        self._set_controls_disabled(False)

        how_many = self.w_layers.value - (len(self.model.structure) - 2)
        loc = self._location.value

        for _ in range(how_many):
            slab = Slab(0, 0, 3)
            slab.thick.bounds = (0, 2 * slab.thick.value)
            slab.sld.real.bounds = (0, 2 * slab.sld.real.value)
            slab.sld.imag.bounds = (0, 2 * slab.sld.imag.value)
            slab.rough.bounds = (0, 2 * slab.rough.value)

            slab_view = SlabView(slab)
            self.model.structure.insert(loc, slab)
            self.structure_view.slab_views.insert(loc, slab_view)
            # listen to the same traitlet as for the initial slab views
            slab_view.observe(
                self._on_slab_params_modified, names=["view_changed"]
            )

        rename_params(self.model.structure)
        self._varying_layers = False

        # set the number of varying parameters
        self.num_varying = len(self.model.parameters.varying_parameters())

        self.view_redraw = time.time()

    def _decrease_layers(self, b):
        """
        OK-button callback: remove slabs (and their views) starting at the
        chosen location until the structure has `w_layers.value` layers.
        """
        self._set_controls_disabled(False)

        loc = self._location.value
        how_many = len(self.model.structure) - 2 - self.w_layers.value
        for _ in range(how_many):
            self.model.structure.pop(loc)
            slab_view = self.structure_view.slab_views.pop(loc)
            slab_view.unobserve_all()

        rename_params(self.model.structure)
        self._varying_layers = False

        # set the number of varying parameters
        self.num_varying = len(self.model.parameters.varying_parameters())

        self.view_redraw = time.time()

    def _link_param_widgets(self):
        """
        Creates a dictionary of {id(parameter): (associated_widgets_tuple)}.

        The tuple is (value, vary checkbox, lower limit, upper limit).
        """
        model = self.model
        self.param_widgets_link = {
            id(model.scale): (
                self.w_scale,
                self.c_scale,
                self.scale_low_limit,
                self.scale_hi_limit,
            ),
            id(model.bkg): (
                self.w_bkg,
                self.c_bkg,
                self.bkg_low_limit,
                self.bkg_hi_limit,
            ),
            id(model.dq): (
                self.w_dq,
                self.c_dq,
                self.dq_low_limit,
                self.dq_hi_limit,
            ),
        }

    def _cancel_layers(self, b):
        """
        Cancel-button callback: restore the layer count widget to the
        current number of layers and re-enable the controls.
        """
        # disable the change layers widget to prevent recursion
        self.w_layers.unobserve(self._on_change_layers, names=["value"])
        self.w_layers.value = len(self.model.structure) - 2
        self.w_layers.observe(self._on_change_layers, names=["value"])

        self._set_controls_disabled(False)

        self._varying_layers = False
        self.view_redraw = time.time()

    def refresh(self):
        """
        Updates the widget values from the underlying `ReflectModel`.
        """
        for par in (self.model.scale, self.model.bkg, self.model.dq):
            wid = self.param_widgets_link[id(par)]
            wid[0].value = par.value
            wid[1].value = par.vary
            wid[2].value = par.bounds.lb
            wid[3].value = par.bounds.ub

        for slab_view in self.structure_view.slab_views:
            slab_view.refresh()

    @property
    def model_box(self):
        """
        `ipywidgets.VBox` displaying model relevant widgets.
        """
        output = [
            self.w_layers,
            widgets.HBox([self.w_scale, self.c_scale, self.w_dq, self.c_dq]),
            widgets.HBox([self.w_bkg, self.c_bkg]),
            self.structure_view.box,
            widgets.HBox(
                [
                    self.model_slider_min,
                    self.model_slider,
                    self.model_slider_max,
                ]
            ),
        ]

        if self._varying_layers:
            output.append(
                widgets.HBox(
                    [self._location, self.ok_button, self.cancel_button]
                )
            )

        output.append(
            widgets.HBox(
                [
                    self.do_fit_button,
                    self.to_code_button,
                    self.save_model_button,
                    self.load_model_button,
                ]
            )
        )

        return widgets.VBox(output)

    @property
    def limits_box(self):
        """
        `ipywidgets.VBox` holding the "Set default limits" button followed by
        one (name, lower limit, value, upper limit) row per varying
        parameter.
        """
        varying_pars = self.model.parameters.varying_parameters()

        links = dict(self.param_widgets_link)
        for slab_view in self.structure_view.slab_views:
            links.update(slab_view.param_widgets_link)

        hboxes = [self.default_limits_button]
        for par in varying_pars:
            name = widgets.Text(par.name)
            name.disabled = True

            val, _, low, high = links[id(par)]
            hboxes.append(widgets.HBox([name, low, val, high]))

        return widgets.VBox(hboxes)
class StructureView:
    """
    Displays a `refnx.reflect.Structure` as rows of `SlabView` widgets.

    Parameters
    ----------
    structure: refnx.reflect.Structure
        Iterated to create one `SlabView` per slab.
    """

    def __init__(self, structure):
        self.structure = structure
        self.slab_views = list(map(SlabView, structure))

    @property
    def box(self):
        """
        `ipywidgets.VBox` with a header row followed by one row per slab.

        As a side effect the thickness widget of each slab view is labelled:
        "fronting" for the first, "backing" for the last, and the layer
        number for the ones in between.
        """
        layout = widgets.Layout(flex="1 1 auto", width="auto")
        headings = ("thick", "sld", "isld", "rough")
        header = widgets.HBox(
            [widgets.HTML(heading, layout=layout) for heading in headings]
        )

        rows = [header] + [view.box for view in self.slab_views]

        # add in layer numbers
        views = self.slab_views
        views[0].w_thick.description = "fronting"
        views[-1].w_thick.description = "backing"
        for number, view in enumerate(views[1:-1], start=1):
            view.w_thick.description = str(number)

        return widgets.VBox(rows)
class SlabView(HasTraits):
    """
    An ipywidgets viewport of a `refnx.reflect.Slab`.

    Parameters
    ----------
    slab: refnx.reflect.Slab

    Notes
    -----
    Use the `box` property to view/modify the `Slab` parameters.
    Observe the `view_changed` traitlet to determine when widget values are
    changed. After a change the `param_being_varied` attribute holds the
    widget that the user altered.
    """

    # traitlet to say when params were last altered
    view_changed = traitlets.Float(time.time())

    def __init__(self, slab):
        super().__init__()

        self.slab = slab
        # {id(parameter): (value, vary, lower limit, upper limit) widgets}
        self.param_widgets_link = {}
        # {value widget: parameter}
        self.widgets_param_link = {}
        # the widget most recently modified by the user
        self.param_being_varied = None

        p = slab.thick
        self.w_thick = widgets.FloatText(value=p.value, step=1)
        self.widgets_param_link[self.w_thick] = p
        self.c_thick = widgets.Checkbox(value=p.vary)
        self.thick_low_limit = widgets.FloatText(value=p.bounds.lb, step=1)
        self.thick_hi_limit = widgets.FloatText(value=p.bounds.ub, step=1)

        p = slab.sld.real
        self.w_sld = widgets.FloatText(value=p.value, step=0.01)
        self.widgets_param_link[self.w_sld] = p
        self.c_sld = widgets.Checkbox(value=p.vary)
        self.sld_low_limit = widgets.FloatText(value=p.bounds.lb, step=0.01)
        self.sld_hi_limit = widgets.FloatText(value=p.bounds.ub, step=0.01)

        p = slab.sld.imag
        self.w_isld = widgets.FloatText(value=p.value, step=0.01)
        self.widgets_param_link[self.w_isld] = p
        self.c_isld = widgets.Checkbox(value=p.vary)
        self.isld_low_limit = widgets.FloatText(value=p.bounds.lb, step=0.01)
        self.isld_hi_limit = widgets.FloatText(value=p.bounds.ub, step=0.01)

        p = slab.rough
        # the widget holds the parameter's numeric value, not the Parameter
        self.w_rough = widgets.FloatText(value=p.value, step=1)
        self.widgets_param_link[self.w_rough] = p
        self.c_rough = widgets.Checkbox(value=p.vary)
        self.rough_low_limit = widgets.FloatText(value=p.bounds.lb, step=0.01)
        self.rough_hi_limit = widgets.FloatText(value=p.bounds.ub, step=0.01)

        self._widget_list = [
            self.w_thick,
            self.c_thick,
            self.w_sld,
            self.c_sld,
            self.w_isld,
            self.c_isld,
            self.w_rough,
            self.c_rough,
        ]

        self._limits_list = [
            self.thick_low_limit,
            self.thick_hi_limit,
            self.sld_low_limit,
            self.sld_hi_limit,
            self.isld_low_limit,
            self.isld_hi_limit,
            self.rough_low_limit,
            self.rough_hi_limit,
        ]

        # link widgets to observers
        for widget in [self.w_thick, self.w_sld, self.w_isld, self.w_rough]:
            widget.style.description_width = "0px"
            widget.observe(self._on_slab_values_modified, names="value")

        self.w_thick.style.description_width = "50px"

        for widget in [self.c_thick, self.c_sld, self.c_isld, self.c_rough]:
            widget.style.description_width = "0px"
            widget.observe(self._on_slab_varies_modified, names="value")

        for widget in self._limits_list:
            widget.observe(self._on_slab_limits_modified, names="value")

        self._link_param_widgets()

    def _on_slab_values_modified(self, change):
        """
        A value widget changed: copy its value into the linked parameter and
        signal the change.
        """
        owner = change["owner"]
        self.widgets_param_link[owner].value = owner.value
        self.param_being_varied = owner
        self.view_changed = time.time()

    def _on_slab_varies_modified(self, change):
        """
        A vary checkbox changed: update the `vary` flag of the matching
        parameter and signal the change.
        """
        d = self.param_widgets_link
        for par in flatten(self.slab.parameters):
            if id(par) in d and change["owner"] in d[id(par)]:
                par.vary = d[id(par)][1].value
                break

        self.param_being_varied = change["owner"]
        self.view_changed = time.time()

    def _on_slab_limits_modified(self, change):
        """
        A limit widget changed: update the lower (position 2) or upper
        (position 3) bound of the matching parameter.
        """
        d = self.param_widgets_link
        for par in flatten(self.slab.parameters):
            if id(par) in d and change["owner"] in d[id(par)]:
                wids = d[id(par)]
                loc = wids.index(change["owner"])
                if loc == 2:
                    par.bounds.lb = wids[loc].value
                elif loc == 3:
                    par.bounds.ub = wids[loc].value
                return

    def _link_param_widgets(self):
        """
        Creates a dictionary of {id(parameter): (associated_widgets_tuple)}.

        The tuple is (value, vary checkbox, lower limit, upper limit).
        """
        d = self.param_widgets_link
        d[id(self.slab.thick)] = (
            self.w_thick,
            self.c_thick,
            self.thick_low_limit,
            self.thick_hi_limit,
        )
        d[id(self.slab.sld.real)] = (
            self.w_sld,
            self.c_sld,
            self.sld_low_limit,
            self.sld_hi_limit,
        )
        d[id(self.slab.sld.imag)] = (
            self.w_isld,
            self.c_isld,
            self.isld_low_limit,
            self.isld_hi_limit,
        )
        d[id(self.slab.rough)] = (
            self.w_rough,
            self.c_rough,
            self.rough_low_limit,
            self.rough_hi_limit,
        )

    def refresh(self):
        """
        Updates the widget values from the underlying `Slab` parameters.
        """
        d = self.param_widgets_link
        pars = {id(p): p for p in flatten(self.slab.parameters) if id(p) in d}
        for idx, par in pars.items():
            # named `wids` so the module-level `widgets` isn't shadowed
            wids = d[idx]
            wids[0].value = par.value
            wids[1].value = par.vary
            wids[2].value = par.bounds.lb
            wids[3].value = par.bounds.ub

    @property
    def box(self):
        """
        `ipywidgets.HBox` of the value widgets and vary checkboxes.
        """
        return widgets.HBox(self._widget_list)
[docs]class Motofit:
"""
An interactive slab modeller (Jupyter/ipywidgets based) for Neutron and
X-ray reflectometry data.
The interactive modeller is designed to be used in a Jupyter notebook.
>>> # specify that plots are in a separate graph window
>>> %matplotlib qt
>>> # alternately if you want the graph to be embedded in the notebook use
>>> # %matplotlib notebook
>>> from refnx.reflect import Motofit
>>> # create an instance of the modeller
>>> app = Motofit()
>>> # display it in the notebook by calling the object with a datafile.
>>> app('dataset1.txt')
>>> # lets fit a different dataset
>>> app2 = Motofit()
>>> app2('dataset2.txt')
The `Motofit` instance has several useful attributes that can be used in
other cells. For example, one can access the `objective` and `curvefitter`
attributes for more advanced fitting functionality than is available in the
GUI. A `code` attribute can be used to retrieve a Python code fragment that
can be used as a basis for developing more complicated models, such as
interparameter constraints, global fitting, etc.
Attributes
----------
dataset: :class:`refnx.dataset.Data1D`
The dataset associated with the modeller
model: :class:`refnx.reflect.ReflectModel`
Calculates a theoretical model, from an interfacial structure
(`model.Structure`).
objective: :class:`refnx.analysis.Objective`
The Objective that allows one to compare the model against the data.
fig: :class:`matplotlib.figure.Figure`
Graph displaying the data.
"""
def __init__(self):
# attributes for the graph
# for the graph
self.qmin = 0.005
self.qmax = 0.5
self.qpnt = 1000
self.fig = None
self.ax_data = None
self.ax_residual = None
self.ax_sld = None
# gridspecs specify how the plots are laid out. Gridspec1 is when the
# residuals plot is displayed. Gridspec2 is when it's not visible
self._gridspec1 = gridspec.GridSpec(
2, 2, height_ratios=[5, 1], width_ratios=[1, 1], hspace=0.01
)
self._gridspec2 = gridspec.GridSpec(1, 2)
self.theoretical_plot = None
self.theoretical_plot_sld = None
# attributes for a user dataset
self.dataset = None
self.objective = None
self._curvefitter = None
self.data_plot = None
self.residuals_plot = None
self.data_plot_sld = None
self.dataset_name = widgets.Text(description="dataset:")
self.dataset_name.disabled = True
self.chisqr = widgets.FloatText(description="chi-squared:")
self.chisqr.disabled = True
# fronting
slab0 = Slab(0, 0, 0)
slab1 = Slab(25, 3.47, 3)
slab2 = Slab(0, 2.07, 3)
structure = slab0 | slab1 | slab2
rename_params(structure)
self.model = ReflectModel(structure)
structure = slab0 | slab1 | slab2
self.model = ReflectModel(structure)
# give some default parameter limits
self.model.scale.bounds = (0.1, 2)
self.model.bkg.bounds = (1e-8, 2e-5)
self.model.dq.bounds = (0, 20)
for slab in self.model.structure:
slab.thick.bounds = (0, 2 * slab.thick.value)
slab.sld.real.bounds = (0, 2 * slab.sld.real.value)
slab.sld.imag.bounds = (0, 2 * slab.sld.imag.value)
slab.rough.bounds = (0, 2 * slab.rough.value)
# the main GUI widget
self.display_box = widgets.VBox()
self.tab = widgets.Tab()
self.tab.observe(self._on_tab_changed, names="selected_index")
# an output area for messages.
self.output = widgets.Output()
# options tab
self.plot_type = widgets.Dropdown(
options=["lin", "logY", "YX4", "YX2"],
value="lin",
description="Plot Type:",
disabled=False,
)
self.plot_type.observe(self._on_plot_type_changed, names="value")
self.use_weights = widgets.RadioButtons(
options=["Yes", "No"],
value="Yes",
description="use dataset weights?",
style={"description_width": "initial"},
)
self.use_weights.observe(self._on_use_weights_changed, names="value")
self.transform = Transform("lin")
self.display_residuals = widgets.Checkbox(
value=False, description="Display residuals"
)
self.display_residuals.observe(
self._on_display_residuals_changed, names="value"
)
self.model_view = None
self.set_model(self.model)
[docs] def save_model(self, *args, f=None):
"""
Serialise a model to a pickle file.
If `f` is not specified then the file name is constructed from the
current dataset name; if there is no current dataset then the filename
is constructed from the current time. These constructed filenames will
be in the current working directory, for a specific save location `f`
must be provided.
This method is only intended to be used to serialise models created by
this interactive Jupyter widget modeller.
Parameters
----------
f: file like or str, optional
File to save model to.
"""
if f is None:
f = "model_" + datetime.datetime.now().isoformat() + ".pkl"
if self.dataset is not None:
f = "model_" + self.dataset.name + ".pkl"
with possibly_open_file(f) as g:
pickle.dump(self.model, g)
[docs] def load_model(self, *args, f=None):
"""
Load a serialised model.
If `f` is not specified then an attempt will be made to find a model
corresponding to the current dataset name,
`'model_' + self.dataset.name + '.pkl'`. If there is no current
dataset then the most recent model will be loaded.
This method is only intended to be used to deserialise models created
by this interactive Jupyter widget modeller, and will not successfully
load complicated ReflectModel created outside of the interactive
modeller.
Parameters
----------
f: file like or str, optional
pickle file to load model from.
"""
if f is None and self.dataset is not None:
# try and load the model corresponding to the current dataset
f = "model_" + self.dataset.name + ".pkl"
elif f is None:
# load the most recent model file
files = list(filter(os.path.isfile, glob.glob("model_*.pkl")))
files.sort(key=lambda x: os.path.getmtime(x))
files.reverse()
if len(files):
f = files[0]
if f is None:
self._print("No model file is specified/available.")
return
try:
with possibly_open_file(f, "rb") as g:
reflect_model = pickle.load(g)
self.set_model(reflect_model)
except (RuntimeError, FileNotFoundError) as exc:
# RuntimeError if the file isn't a ReflectModel
# FileNotFoundError if the specified file name wasn't found
self._print(repr(exc), repr(f))
[docs] def set_model(self, model):
"""
Change the `refnx.reflect.ReflectModel` associated with the `Motofit`
instance.
Parameters
----------
model: refnx.reflect.ReflectModel
"""
if not isinstance(model, ReflectModel):
raise RuntimeError("`model` was not an instance of ReflectModel")
if self.model_view is not None:
self.model_view.unobserve_all()
# figure out if the reflect_model is a different instance. If it is
# then the objective has to be updated.
if model is not self.model:
self.model = model
self._update_analysis_objects()
self.model = model
self.model_view = ReflectModelView(self.model)
self.model_view.observe(self.update_model, names=["view_changed"])
self.model_view.observe(self.redraw, names=["view_redraw"])
# observe when the number of varying parameters changed. This
# invalidates a curvefitter, and a new one has to be produced.
self.model_view.observe(
self._on_num_varying_changed, names=["num_varying"]
)
self.model_view.do_fit_button.on_click(self.do_fit)
self.model_view.to_code_button.on_click(self._to_code)
self.model_view.save_model_button.on_click(self.save_model)
self.model_view.load_model_button.on_click(self.load_model)
self.redraw(None)
[docs] def update_model(self, change):
"""
Updates the plots when the parameters change
Parameters
----------
change
"""
if not self.fig:
return
q = np.linspace(self.qmin, self.qmax, self.qpnt)
theoretical = self.model.model(q)
yt, _ = self.transform(q, theoretical)
sld_profile = self.model.structure.sld_profile()
z, sld = sld_profile
if self.theoretical_plot is not None:
self.theoretical_plot.set_data(q, yt)
self.theoretical_plot_sld.set_data(z, sld)
self.ax_sld.relim()
self.ax_sld.autoscale_view()
if self.dataset is not None:
# if there's a dataset loaded then residuals_plot
# should exist
residuals = self.objective.residuals()
self.chisqr.value = np.sum(residuals**2)
self.residuals_plot.set_data(self.dataset.x, residuals)
self.ax_residual.relim()
self.ax_residual.autoscale_view()
self.fig.canvas.draw()
def _on_num_varying_changed(self, change):
# observe when the number of varying parameters changed. This
# invalidates a curvefitter, and a new one has to be produced.
if change["new"] != change["old"]:
self._curvefitter = None
def _update_analysis_objects(self):
use_weights = self.use_weights.value == "Yes"
self.objective = Objective(
self.model,
self.dataset,
transform=self.transform,
use_weights=use_weights,
)
self._curvefitter = None
[docs] def __call__(self, data=None, model=None):
"""
Display the `Motofit` GUI in a Jupyter notebook cell.
Parameters
----------
data: refnx.dataset.Data1D
The dataset to associate with the `Motofit` instance.
model: refnx.reflect.ReflectModel or str or file-like
A model to associate with the data.
If `model` is a `str` or `file`-like then the `load_model` method
will be used to try and load the model from file. This assumes that
the file is a pickle of a `ReflectModel`
"""
# the theoretical model
# display the main graph
import matplotlib.pyplot as plt
self.fig = plt.figure(figsize=(9, 4))
# grid specs depending on whether the residuals are displayed
if self.display_residuals.value:
d_gs = self._gridspec1[0, 0]
sld_gs = self._gridspec1[:, 1]
else:
d_gs = self._gridspec2[0, 0]
sld_gs = self._gridspec2[0, 1]
self.ax_data = self.fig.add_subplot(d_gs)
self.ax_data.set_xlabel(r"$Q/\AA^{-1}$")
self.ax_data.set_ylabel("Reflectivity")
self.ax_data.grid(True, color="b", linestyle="--", linewidth=0.1)
self.ax_sld = self.fig.add_subplot(sld_gs)
self.ax_sld.set_ylabel(r"$\rho/10^{-6}\AA^{-2}$")
self.ax_sld.set_xlabel(r"$z/\AA$")
self.ax_residual = self.fig.add_subplot(
self._gridspec1[1, 0], sharex=self.ax_data
)
self.ax_residual.set_xlabel(r"$Q/\AA^{-1}$")
self.ax_residual.grid(True, color="b", linestyle="--", linewidth=0.1)
self.ax_residual.set_visible(self.display_residuals.value)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
self.fig.tight_layout()
q = np.linspace(self.qmin, self.qmax, self.qpnt)
theoretical = self.model.model(q)
yt, _ = self.transform(q, theoretical)
self.theoretical_plot = self.ax_data.plot(q, yt, zorder=2)[0]
self.ax_data.set_yscale("log")
z, sld = self.model.structure.sld_profile()
self.theoretical_plot_sld = self.ax_sld.plot(z, sld)[0]
# the figure has been reset, so remove ref to the data_plot,
# residual_plot
self.data_plot = None
self.residuals_plot = None
self.dataset = None
if data is not None:
self.load_data(data)
if isinstance(model, ReflectModel):
self.set_model(model)
return self.display_box
elif model is not None:
self.load_model(model)
return self.display_box
self.redraw(None)
return self.display_box
[docs] def load_data(self, data):
"""
Load a dataset into the `Motofit` instance.
Parameters
----------
data: refnx.dataset.Data1D, or str, or file-like
"""
if isinstance(data, ReflectDataset):
self.dataset = data
else:
self.dataset = ReflectDataset(data)
self.dataset_name.value = self.dataset.name
# loading a dataset changes the objective and curvefitter
self._update_analysis_objects()
self.qmin = np.min(self.dataset.x)
self.qmax = np.max(self.dataset.x)
if self.fig is not None:
yt, et = self.transform(self.dataset.x, self.dataset.y)
if self.data_plot is None:
(self.data_plot,) = self.ax_data.plot(
self.dataset.x,
yt,
label=self.dataset.name,
ms=2,
marker="o",
ls="",
zorder=1,
)
self.data_plot.set_label(self.dataset.name)
self.ax_data.legend()
# no need to calculate residuals here, that'll be updated in
# the redraw method
(self.residuals_plot,) = self.ax_residual.plot(self.dataset.x)
else:
self.data_plot.set_xdata(self.dataset.x)
self.data_plot.set_ydata(yt)
# calculate theoretical model over same range as data
# use redraw over update_model because it ensures chi2 widget gets
# displayed
self.redraw(None)
self.ax_data.relim()
self.ax_data.autoscale_view()
self.ax_residual.relim()
self.ax_residual.autoscale_view()
self.fig.canvas.draw()
[docs] def redraw(self, change):
"""
Redraw the Jupyter GUI associated with the `Motofit` instance.
"""
self._update_display_box(self.display_box)
self.update_model(None)
@property
def curvefitter(self):
"""
class:`CurveFitter` : Object for fitting the data based on the
objective.
"""
if self.objective is not None and self._curvefitter is None:
self._curvefitter = CurveFitter(self.objective)
return self._curvefitter
def _print(self, string):
    """
    Replace whatever is shown in the output widget with `string`.
    """
    from IPython import display as ipy_display

    out = self.output
    with out:
        ipy_display.clear_output()
        print(string)
def do_fit(self, *args):
    """
    Ask the Motofit object to perform a fit (differential evolution).

    Parameters
    ----------
    *args
        Not used; presumably accepted so the method can be connected
        directly as a widget callback.

    Notes
    -----
    Nothing happens when no dataset is loaded. When no parameter varies, a
    parameter lies outside its bounds (non-finite log-probability), or the
    log-probability calculation raises `ZeroDivisionError`, a message is
    printed to the output widget and no fit is done. After performing the
    fit the Jupyter display is updated and the objective is printed.
    """
    if self.dataset is None:
        return

    if not self.model.parameters.varying_parameters():
        self._print("No parameters are being varied")
        return

    try:
        logp = self.objective.logp()
    except ZeroDivisionError:
        # the case described by the message below
        self._print(
            "One parameter has equal lower and upper bounds."
            " Either alter the bounds, or don't let that"
            " parameter vary."
        )
        return

    if not np.isfinite(logp):
        self._print(
            "One of your parameter values lies outside its"
            " bounds. Please adjust the value, or the bounds."
        )
        return

    # the argument names must stay as-is: the optimiser may pass
    # ``convergence`` by keyword
    def callback(xk, convergence):
        # show the chi2 of the parameter vector reported by the optimiser
        self.chisqr.value = self.objective.chisqr(xk)

    self.curvefitter.fit("differential_evolution", callback=callback)

    # the fit changes the model, so re-set it to refresh the widgets
    # (this also redraws the GUI)
    self.set_model(self.model)
    self._print(str(self.objective))
def _to_code(self, change=None):
self._print(self.code)
@property
def code(self):
    """
    str : A Python code fragment capable of fitting the data.

    Built by `to_code` from the current objective; the analysis objects are
    (re)built first when no objective exists yet.
    """
    objective = self.objective
    if objective is None:
        self._update_analysis_objects()
        objective = self.objective
    return to_code(objective)
def _on_tab_changed(self, change):
pass
def _on_plot_type_changed(self, change):
    """
    User would like to plot and fit as logR/linR/RQ4/RQ2, etc.

    Parameters
    ----------
    change: dict
        traitlets change notification; ``change["new"]`` is the name of the
        new transform and is passed straight to `Transform`.
    """
    self.transform = Transform(change["new"])
    if self.objective is not None:
        self.objective.transform = self.transform

    # `load_data` only creates the data line when a figure exists, so a
    # dataset may be loaded while ``data_plot`` is still None
    if self.dataset is not None and self.data_plot is not None:
        yt, _ = self.transform(self.dataset.x, self.dataset.y)
        self.data_plot.set_xdata(self.dataset.x)
        self.data_plot.set_ydata(yt)

    self.update_model(None)

    if self.fig is None:
        # no figure yet, hence no axes to rescale or canvas to draw
        return

    # the LHS axis of the data plot has to change between plot types:
    # "logY" gets a linear y axis (presumably because that transform is
    # already logarithmic), every other transform a log y axis
    if change["new"] == "logY":
        self.ax_data.set_yscale("linear")
    else:
        self.ax_data.set_yscale("log")

    self.ax_data.relim()
    self.ax_data.autoscale_view()
    self.fig.canvas.draw()
def _on_use_weights_changed(self, change):
self._update_analysis_objects()
self.update_model(None)
def _on_display_residuals_changed(self, change):
import matplotlib.pyplot as plt
if change["new"]:
self.ax_residual.set_visible(True)
self.ax_data.set_position(
self._gridspec1[0, 0].get_position(self.fig)
)
self.ax_sld.set_position(
self._gridspec1[:, 1].get_position(self.fig)
)
plt.setp(self.ax_data.get_xticklabels(), visible=False)
else:
self.ax_residual.set_visible(False)
self.ax_data.set_position(
self._gridspec2[:, 0].get_position(self.fig)
)
self.ax_sld.set_position(
self._gridspec2[:, 1].get_position(self.fig)
)
plt.setp(self.ax_data.get_xticklabels(), visible=True)
@property
def _options_box(self):
    """
    widgets.VBox : freshly built column holding the plot-type, use-weights
    and display-residuals controls (the "Options" tab).
    """
    controls = [self.plot_type, self.use_weights, self.display_residuals]
    return widgets.VBox(controls)
def _update_display_box(self, box):
"""
Redraw the Jupyter GUI associated with the `Motofit` instance
"""
vbox_widgets = []
if self.dataset is not None:
vbox_widgets.append(widgets.HBox([self.dataset_name, self.chisqr]))
self.tab.children = [
self.model_view.model_box,
self.model_view.limits_box,
self._options_box,
]
self.tab.set_title(0, "Model")
self.tab.set_title(1, "Limits")
self.tab.set_title(2, "Options")
vbox_widgets.append(self.tab)
vbox_widgets.append(self.output)
box.children = tuple(vbox_widgets)
def rename_params(structure):
    """
    Give the parameters of a slab structure human-readable names, in place.

    Interior components ``1 .. len(structure) - 2`` get ``"<i> - thick"``,
    ``"<i> - sld"``, ``"<i> - isld"`` and ``"<i> - rough"``. The first
    component (fronting) only has its real/imaginary SLD renamed. The last
    component (backing) has its real/imaginary SLD and roughness renamed.

    Parameters
    ----------
    structure: refnx.reflect.Structure
        Indexable sequence of components, each with ``thick``, ``rough`` and
        ``sld.real``/``sld.imag`` parameters.
    """
    last_interior = len(structure) - 1
    for idx in range(1, last_interior):
        layer = structure[idx]
        layer.thick.name = "%d - thick" % idx
        layer.sld.real.name = "%d - sld" % idx
        layer.sld.imag.name = "%d - isld" % idx
        layer.rough.name = "%d - rough" % idx

    fronting = structure[0]
    fronting.sld.real.name = "sld - fronting"
    fronting.sld.imag.name = "isld - fronting"

    backing = structure[-1]
    backing.sld.real.name = "sld - backing"
    backing.sld.imag.name = "isld - backing"
    backing.rough.name = "rough - backing"
def to_code(objective):
    """
    Create executable Python code fragment that corresponds to the model in the
    GUI.

    Parameters
    ----------
    objective: refnx.analysis.Objective
        Objective whose ``model.structure`` is a sequence of slab-like
        components (each with ``name``, ``thick``, ``rough`` and an ``sld``
        having ``real``/``imag`` parameters).

    Returns
    -------
    code: str
        Python code that can construct a reflectometry fitting system

    Notes
    -----
    Every varying parameter gets a ``setp(vary=True, bounds=(lb, ub))``
    line. Non-finite bounds are written as ``np.inf``/``-np.inf``/``np.nan``
    so the fragment stays executable. The data file name and layer names
    are written with ``repr``, so quotes and backslashes (e.g. Windows
    paths) are escaped.
    """

    def _num(value):
        # formatting a non-finite float gives bare ``inf``/``nan``, which
        # are undefined names in the generated code
        if np.isnan(value):
            return "np.nan"
        if np.isinf(value):
            return "np.inf" if value > 0 else "-np.inf"
        return f"{value}"

    def _setp(name, param):
        # code line that marks ``param`` (reachable as ``name``) as varying
        lb = _num(param.bounds.lb)
        ub = _num(param.bounds.ub)
        return f"{name}.setp(vary=True, bounds=({lb}, {ub}))"

    header = """import numpy as np
import refnx
from refnx.analysis import Objective, CurveFitter, Transform
from refnx.reflect import SLD, Slab, ReflectModel, Structure
from refnx.dataset import ReflectDataset
print(refnx.version.version)
"""
    code = [header]

    # the dataset
    code.append(f"data = ReflectDataset({objective.data.filename!r})")

    # make some SLD's and slabs
    slds = ["\n# set up the SLD objects for each layer"]
    slabs = ["\n# set up the Slab objects for each layer"]
    limits = []
    slab_names = []
    for i, slab in enumerate(objective.model.structure):
        sld = slab.sld
        slds.append(
            f"sld{i} = SLD("
            f"{sld.real.value} + {sld.imag.value}j, name={slab.name!r})"
        )
        slabs.append(
            f"slab{i} = Slab("
            f"{slab.thick.value}, sld{i}, {slab.rough.value}, "
            f"name={slab.name!r})"
        )
        lims = [
            (sld.real, f"sld{i}.real"),
            (sld.imag, f"sld{i}.imag"),
            (slab.thick, f"slab{i}.thick"),
            (slab.rough, f"slab{i}.rough"),
        ]
        limits.extend(_setp(name, p) for p, name in lims if p.vary)
        slab_names.append(f"slab{i}")

    code.extend(slds)
    code.extend(slabs)
    code.append("\n# set up the limits for SLD's and Slabs")
    code.extend(limits)
    code.append("\n# set up the Structure object from the Slabs")
    code.append("structure = " + " | ".join(slab_names))

    model = objective.model
    code.append("\n# make the reflectometry model")
    code.append(
        f"model = ReflectModel("
        f"structure, "
        f"scale={model.scale.value}, "
        f"bkg={model.bkg.value}, "
        f"dq={model.dq.value})"
    )
    model_lims = [
        (model.scale, "model.scale"),
        (model.bkg, "model.bkg"),
        (model.dq, "model.dq"),
    ]
    code.extend(_setp(name, p) for p, name in model_lims if p.vary)

    code.append("\n# make the objective")
    code.append(f"transform = {objective.transform!r}")
    code.append(
        f"objective = Objective(model, data, transform=transform,"
        f" use_weights={objective.weighted})"
    )
    code.append("\n# make the curvefitter")
    code.append("fitter = CurveFitter(objective)")
    code.append("fitter.fit('differential_evolution')")
    code.append("print(objective)")
    return "\n".join(code)