Coverage for /builds/kinetik161/ase/ase/calculators/kim/kimpy_wrappers.py: 74.53%
318 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
1"""
2Wrappers that provide a minimal interface to kimpy methods and objects
4Daniel S. Karls
5University of Minnesota
6"""
8import functools
9from abc import ABC
11import kimpy
12import numpy as np
14from .exceptions import (KIMModelInitializationError, KIMModelNotFound,
15 KIMModelParameterError, KimpyError)
# Function used for casting parameter/extent indices to C-compatible ints
c_int = np.intc

# Function used for casting floating point parameter values to C-compatible
# doubles
c_double = np.double


def c_int_args(func):
    """
    Decorator for instance methods that will cast all of the positional
    args passed, excluding the first (which corresponds to 'self'), to
    C-compatible integers (``c_int``) before calling ``func``.

    Keyword arguments are forwarded unchanged.
    """

    @functools.wraps(func)
    def myfunc(*args, **kwargs):
        # Keep 'self' untouched (slicing also tolerates a call with no
        # positional args) and cast every remaining positional argument.
        args_cast = [*args[:1], *map(c_int, args[1:])]
        # Forward the *cast* arguments so the conversion actually takes
        # effect.
        return func(*args_cast, **kwargs)

    return myfunc
def check_call(f, *args, **kwargs):
    """Call a kimpy function using its arguments and, if a RuntimeError is
    raised, catch it and raise a KimpyError with the exception's
    message.

    (Starting with kimpy 2.0.0, a RuntimeError is the only exception
    type raised when something goes wrong.)

    Parameters
    ----------
    f : callable
        kimpy function or bound method to call.
    *args, **kwargs
        Forwarded to ``f`` unchanged.

    Returns
    -------
    Whatever ``f`` returns.

    Raises
    ------
    KimpyError
        If ``f`` raises a RuntimeError. The RuntimeError is chained as the
        cause. Other exception types propagate unchanged.
    """
    try:
        return f(*args, **kwargs)
    except RuntimeError as e:
        # Not every callable exposes __name__ (e.g. functools.partial
        # objects); fall back to repr() so that formatting the message can
        # never raise and hide the original error.
        name = getattr(f, "__name__", repr(f))
        raise KimpyError(
            f'Calling kimpy function "{name}" failed:\n {str(e)}') from e


def check_call_wrapper(func):
    """Decorator that routes every call of ``func`` through ``check_call``,
    so RuntimeErrors raised by ``func`` surface as KimpyError."""

    @functools.wraps(func)
    def myfunc(*args, **kwargs):
        return check_call(func, *args, **kwargs)

    return myfunc
def _checked(kimpy_function):
    """Bind ``kimpy_function`` to check_call so that any RuntimeError it
    raises is re-raised as a KimpyError."""
    return functools.partial(check_call, kimpy_function)


# kimpy methods
collections_create = _checked(kimpy.collections.create)
model_create = _checked(kimpy.model.create)
simulator_model_create = _checked(kimpy.simulator_model.create)
get_species_name = _checked(kimpy.species_name.get_species_name)
get_number_of_species_names = _checked(
    kimpy.species_name.get_number_of_species_names)

# kimpy attributes (here to avoid importing kimpy in higher-level modules)
collection_item_type_portableModel = kimpy.collection_item_type.portableModel
class ModelCollections:
    """
    Handle to the KIM API model collections on this system.

    KIM Portable Models and Simulator Models are installed/managed into
    different "collections". Searching through the different KIM API
    model collections on the system requires a corresponding object,
    which this class creates on instantiation. For more on model
    collections, see the KIM API's install file:
    https://github.com/openkim/kim-api/blob/master/INSTALL
    """

    def __init__(self):
        self.collection = collections_create()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, value, traceback):
        # Nothing to release when leaving the context.
        pass

    def get_item_type(self, model_name):
        """Return the collection item type reported by kimpy for
        ``model_name``.

        Raises
        ------
        KIMModelNotFound
            If the kimpy lookup fails (surfaced as KimpyError by
            check_call); this is reported as the model not being installed.
        """
        lookup = self.collection.get_item_type
        try:
            return check_call(lookup, model_name)
        except KimpyError:
            raise KIMModelNotFound(
                "Could not find model {} installed in any of the KIM API "
                "model collections on this system. See "
                "https://openkim.org/doc/usage/obtaining-models/ for "
                "instructions on installing models.".format(model_name)
            )

    @property
    def initialized(self):
        """True once ``__init__`` has stored the kimpy collection object."""
        return hasattr(self, "collection")
class PortableModel:
    """Creates a KIM API Portable Model object and provides a minimal
    interface to it"""

    def __init__(self, model_name, debug):
        """
        Parameters
        ----------
        model_name : str
            Name of the KIM Portable Model passed to ``kimpy.model.create``.
        debug : bool
            If true, print the units reported by the created model.

        Raises
        ------
        KIMModelInitializationError
            If the model does not accept the units requested below
            (zero-based numbering, A, eV, e, K, ps).
        """
        self.model_name = model_name
        self.debug = debug

        # Create KIM API Model object
        units_accepted, self.kim_model = model_create(
            kimpy.numbering.zeroBased,
            kimpy.length_unit.A,
            kimpy.energy_unit.eV,
            kimpy.charge_unit.e,
            kimpy.temperature_unit.K,
            kimpy.time_unit.ps,
            self.model_name,
        )

        if not units_accepted:
            raise KIMModelInitializationError(
                "Requested units not accepted in kimpy.model.create"
            )

        if self.debug:
            l_unit, e_unit, c_unit, te_unit, ti_unit = check_call(
                self.kim_model.get_units
            )
            print(f"Length unit is: {l_unit}")
            print(f"Energy unit is: {e_unit}")
            print(f"Charge unit is: {c_unit}")
            print(f"Temperature unit is: {te_unit}")
            print(f"Time unit is: {ti_unit}")
            print()

        self._create_parameters()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, value, traceback):
        pass

    @check_call_wrapper
    def _get_number_of_parameters(self):
        return self.kim_model.get_number_of_parameters()

    def _create_parameters(self):
        """Populate ``self._parameters``, a dict mapping each model parameter
        name to a KIMModelParameterInteger/KIMModelParameterDouble wrapper.

        Raises
        ------
        KIMModelParameterError
            If a parameter has a type other than 'Integer' or 'Double'.
        """

        def _kim_model_parameter(**kwargs):
            dtype = kwargs["dtype"]

            if dtype == "Integer":
                return KIMModelParameterInteger(**kwargs)
            elif dtype == "Double":
                return KIMModelParameterDouble(**kwargs)
            else:
                raise KIMModelParameterError(
                    f"Invalid model parameter type {dtype}. Supported types "
                    "'Integer' and 'Double'."
                )

        self._parameters = {}
        num_params = self._get_number_of_parameters()
        for index_param in range(num_params):
            parameter_metadata = self._get_one_parameter_metadata(index_param)
            name = parameter_metadata["name"]

            self._parameters[name] = _kim_model_parameter(
                kim_model=self.kim_model,
                dtype=parameter_metadata["dtype"],
                extent=parameter_metadata["extent"],
                name=name,
                description=parameter_metadata["description"],
                parameter_index=index_param,
            )

    def get_model_supported_species_and_codes(self):
        """Get all of the supported species for this model and their
        corresponding integer codes that are defined in the KIM API

        Returns
        -------
        species : list of str
            Abbreviated chemical symbols of all species the model
            supports (e.g. ["Mo", "S"])

        codes : list of int
            Integer codes used by the model for each species (order
            corresponds to the order of ``species``)
        """
        species = []
        codes = []
        num_kim_species = get_number_of_species_names()

        # Walk every species name known to the KIM API and keep only the
        # ones this model reports as supported.
        for i in range(num_kim_species):
            species_name = get_species_name(i)

            species_is_supported, code = self.get_species_support_and_code(
                species_name)

            if species_is_supported:
                species.append(str(species_name))
                codes.append(code)

        return species, codes

    @check_call_wrapper
    def clear_then_refresh(self):
        self.kim_model.clear_then_refresh()

    @c_int_args
    def _get_parameter_metadata(self, index_parameter):
        """Return ``(dtype, extent, name, description)`` reported by kimpy
        for the parameter at ``index_parameter``.

        Raises
        ------
        KIMModelParameterError
            If kimpy fails to provide the metadata.
        """
        try:
            dtype, extent, name, description = check_call(
                self.kim_model.get_parameter_metadata, index_parameter
            )
        except KimpyError as e:
            raise KIMModelParameterError(
                "Failed to retrieve metadata for "
                f"parameter at index {index_parameter}"
            ) from e

        return dtype, extent, name, description

    def parameters_metadata(self):
        """Metadata associated with all model parameters.

        Returns
        -------
        dict
            Metadata associated with all model parameters.
        """
        return {
            param_name: param.metadata
            for param_name, param in self._parameters.items()
        }

    def parameter_names(self):
        """Names of model parameters registered in the KIM API.

        Returns
        -------
        tuple
            Names of model parameters registered in the KIM API
        """
        return tuple(self._parameters.keys())

    def get_parameters(self, **kwargs):
        """
        Get the values of one or more model parameter arrays.

        Given the names of one or more model parameters and a set of indices
        for each of them, retrieve the corresponding elements of the relevant
        model parameter arrays.

        Parameters
        ----------
        **kwargs
            Names of the model parameters and the indices whose values should
            be retrieved.

        Returns
        -------
        dict
            The requested indices and the values of the model's parameters.

        Note
        ----
        The output of this method can be used as input of
        ``set_parameters``.

        Example
        -------
        To get `epsilons` and `sigmas` in the LJ universal model for Mo-Mo
        (index 4879), Mo-S (index 2006) and S-S (index 1980) interactions::

            >>> LJ = 'LJ_ElliottAkerson_2015_Universal__MO_959249795837_003'
            >>> calc = KIM(LJ)
            >>> calc.get_parameters(epsilons=[4879, 2006, 1980],
            ...                     sigmas=[4879, 2006, 1980])
            {'epsilons': [[4879, 2006, 1980],
                          [4.47499, 4.421814057295943, 4.36927]],
             'sigmas': [[4879, 2006, 1980],
                        [2.74397, 2.30743, 1.87089]]}
        """
        parameters = {}
        for parameter_name, index_range in kwargs.items():
            parameters.update(
                self._get_one_parameter(
                    parameter_name,
                    index_range))
        return parameters

    def set_parameters(self, **kwargs):
        """
        Set the values of one or more model parameter arrays.

        Given the names of one or more model parameters and a set of indices
        and corresponding values for each of them, mutate the corresponding
        elements of the relevant model parameter arrays.

        Parameters
        ----------
        **kwargs
            Names of the model parameters to mutate and, for each, a pair
            ``[indices, values]`` to set.

        Returns
        -------
        dict
            The requested indices and the values of the model's parameters
            that were set.

        Example
        -------
        To set `epsilons` in the LJ universal model for Mo-Mo (index 4879),
        Mo-S (index 2006) and S-S (index 1980) interactions to 5.0, 4.5, and
        4.0, respectively::

            >>> LJ = 'LJ_ElliottAkerson_2015_Universal__MO_959249795837_003'
            >>> calc = KIM(LJ)
            >>> calc.set_parameters(epsilons=[[4879, 2006, 1980],
            ...                               [5.0, 4.5, 4.0]])
            {'epsilons': [[4879, 2006, 1980],
                          [5.0, 4.5, 4.0]]}
        """
        parameters = {}
        for parameter_name, parameter_data in kwargs.items():
            index_range, values = parameter_data
            self._set_one_parameter(parameter_name, index_range, values)
            parameters[parameter_name] = parameter_data

        return parameters

    def _get_parameter_object(self, parameter_name):
        """Return the parameter wrapper registered under ``parameter_name``.

        Raises
        ------
        KIMModelParameterError
            If the model has no parameter with that name.
        """
        if parameter_name not in self._parameters:
            raise KIMModelParameterError(
                f"Parameter '{parameter_name}' is not "
                "supported by this model. "
                "Please check that the parameter name is spelled correctly."
            )

        return self._parameters[parameter_name]

    def _get_one_parameter(self, parameter_name, index_range):
        """
        Retrieve value of one or more components of a model parameter array.

        Parameters
        ----------
        parameter_name : str
            Name of model parameter registered in the KIM API.
        index_range : int or list
            Zero-based index (int) or indices (list of int) specifying the
            component(s) of the corresponding model parameter array that are
            to be retrieved.

        Returns
        -------
        dict
            The requested indices and the corresponding values of the model
            parameter array.

        Raises
        ------
        KIMModelParameterError
            If ``parameter_name`` is not a parameter of this model.
        """
        parameter = self._get_parameter_object(parameter_name)
        return parameter.get_values(index_range)

    def _set_one_parameter(self, parameter_name, index_range, values):
        """
        Set the value of one or more components of a model parameter array.

        Parameters
        ----------
        parameter_name : str
            Name of model parameter registered in the KIM API.
        index_range : int or list
            Zero-based index (int) or indices (list of int) specifying the
            component(s) of the corresponding model parameter array that are
            to be mutated.
        values : int/float or list
            Value(s) to assign to the component(s) of the model parameter
            array specified by ``index_range``.

        Raises
        ------
        KIMModelParameterError
            If ``parameter_name`` is not a parameter of this model.
        """
        parameter = self._get_parameter_object(parameter_name)
        parameter.set_values(index_range, values)

    def _get_one_parameter_metadata(self, index_parameter):
        """
        Get metadata associated with a single model parameter.

        Parameters
        ----------
        index_parameter : int
            Zero-based index used by the KIM API to refer to this model
            parameter.

        Returns
        -------
        dict
            Metadata associated with the requested model parameter.
        """
        dtype, extent, name, description = self._get_parameter_metadata(
            index_parameter)
        parameter_metadata = {
            "name": name,
            # The repr of the kimpy data type is the string that
            # _create_parameters compares against 'Integer'/'Double'.
            "dtype": repr(dtype),
            "extent": extent,
            "description": description,
        }
        return parameter_metadata

    @check_call_wrapper
    def compute(self, compute_args_wrapped, release_GIL):
        return self.kim_model.compute(
            compute_args_wrapped.compute_args, release_GIL)

    @check_call_wrapper
    def get_species_support_and_code(self, species_name):
        return self.kim_model.get_species_support_and_code(species_name)

    @check_call_wrapper
    def get_influence_distance(self):
        return self.kim_model.get_influence_distance()

    @check_call_wrapper
    def get_neighbor_list_cutoffs_and_hints(self):
        return self.kim_model.get_neighbor_list_cutoffs_and_hints()

    def compute_arguments_create(self):
        return ComputeArguments(self, self.debug)

    @property
    def initialized(self):
        return hasattr(self, "kim_model")
class KIMModelParameter(ABC):
    """Base wrapper around one KIM API model parameter array.

    Subclasses provide two class attributes:

    ``_dtype_c``
        Callable used to cast values before passing them to kimpy.
    ``_dtype_accessor``
        Name of the kimpy model method that reads one component.
    """

    def __init__(self, kim_model, dtype, extent,
                 name, description, parameter_index):
        self._kim_model = kim_model
        self._dtype = dtype
        self._extent = extent
        self._name = name
        self._description = description

        # Ensure that parameter_index is cast to a C-compatible integer. This
        # is necessary because this is passed to kimpy.
        self._parameter_index = c_int(parameter_index)

    @property
    def metadata(self):
        return {
            "dtype": self._dtype,
            "extent": self._extent,
            "name": self._name,
            "description": self._description,
        }

    def _get_one_value(self, index_extent):
        """Return component ``index_extent`` of this parameter array.

        Raises
        ------
        KIMModelParameterError
            If kimpy fails to read the component.
        """
        get_parameter = getattr(self._kim_model, self._dtype_accessor)
        # Cast the component index explicitly, exactly as _set_one_value
        # does, since it is passed to kimpy.
        index_extent_c = c_int(index_extent)
        try:
            return check_call(
                get_parameter, self._parameter_index, index_extent_c)
        except KimpyError as exception:
            raise KIMModelParameterError(
                f"Failed to access component {index_extent} of model "
                f"parameter of type '{self._dtype}' at parameter index "
                f"{self._parameter_index}"
            ) from exception

    def _set_one_value(self, index_extent, value):
        """Set component ``index_extent`` of this parameter array to
        ``value`` (cast with ``_dtype_c``).

        Raises
        ------
        KIMModelParameterError
            If kimpy fails to set the component.
        """
        value_typecast = self._dtype_c(value)

        try:
            check_call(
                self._kim_model.set_parameter,
                self._parameter_index,
                c_int(index_extent),
                value_typecast,
            )
        except KimpyError as exception:
            raise KIMModelParameterError(
                f"Failed to set component {index_extent} at parameter index "
                f"{self._parameter_index} to {self._dtype} value "
                f"{value_typecast}"
            ) from exception

    def get_values(self, index_range):
        """Read one component (int index) or several (1-D list of indices).

        Returns
        -------
        dict
            ``{name: [index_range, values]}``.

        Raises
        ------
        KIMModelParameterError
            If ``index_range`` is neither scalar nor one-dimensional.
        """
        index_range_dim = np.ndim(index_range)
        if index_range_dim == 0:
            values = self._get_one_value(index_range)
        elif index_range_dim == 1:
            values = [self._get_one_value(idx) for idx in index_range]
        else:
            raise KIMModelParameterError(
                "Index range must be an integer or a list of integers"
            )
        return {self._name: [index_range, values]}

    def set_values(self, index_range, values):
        """Set one component (scalar index and value) or several (1-D lists
        of equal length).

        Raises
        ------
        AssertionError
            If ``index_range`` and ``values`` do not have the same shape.
        KIMModelParameterError
            If ``index_range`` is neither scalar nor one-dimensional.
        """
        index_range_dim = np.ndim(index_range)
        values_dim = np.ndim(values)

        # Check the shape of index_range and values. Raised explicitly
        # rather than via `assert` so the check survives `python -O`;
        # AssertionError is kept for backward compatibility.
        msg = "index_range and values must have the same shape"
        if index_range_dim != values_dim:
            raise AssertionError(msg)

        if index_range_dim == 0:
            self._set_one_value(index_range, values)
        elif index_range_dim == 1:
            if len(index_range) != len(values):
                raise AssertionError(msg)
            for idx, value in zip(index_range, values):
                self._set_one_value(idx, value)
        else:
            raise KIMModelParameterError(
                "Index range must be an integer or a list of integers"
            )
class KIMModelParameterInteger(KIMModelParameter):
    """Model parameter whose components are integers."""

    # kimpy model method used to read a single component
    _dtype_accessor = "get_parameter_int"
    # Cast applied to values before they are handed to kimpy
    _dtype_c = c_int
class KIMModelParameterDouble(KIMModelParameter):
    """Model parameter whose components are doubles."""

    # kimpy model method used to read a single component
    _dtype_accessor = "get_parameter_double"
    # Cast applied to values before they are handed to kimpy
    _dtype_c = c_double
class ComputeArguments:
    """Creates a KIM API ComputeArguments object from a KIM Portable
    Model object and configures it for ASE. A ComputeArguments object
    is associated with a KIM Portable Model and is used to inform the
    KIM API of what the model can compute. It is also used to
    register the data arrays that allow the KIM API to pass the atomic
    coordinates to the model and retrieve the corresponding energy and
    forces, etc."""

    def __init__(self, kim_model_wrapped, debug):
        """
        Parameters
        ----------
        kim_model_wrapped : PortableModel
            Wrapper whose ``kim_model`` creates the compute arguments.
        debug : bool
            If true, print the support status of every compute argument
            and callback.

        Raises
        ------
        KIMModelInitializationError
            If the model requires a compute argument or callback this
            wrapper cannot provide.
        """
        self.kim_model_wrapped = kim_model_wrapped
        self.debug = debug

        # Create KIM API ComputeArguments object
        self.compute_args = check_call(
            self.kim_model_wrapped.kim_model.compute_arguments_create
        )

        self._check_compute_arguments()
        self._check_compute_callbacks()

    def _check_compute_arguments(self):
        """Raise if the model marks as required any compute argument other
        than partialEnergy and partialForces (the outputs registered in
        ``update``)."""
        kimpy_arg_name = kimpy.compute_argument_name
        num_arguments = check_call(
            kimpy_arg_name.get_number_of_compute_argument_names)
        if self.debug:
            print(f"Number of compute_args: {num_arguments}")

        for i in range(num_arguments):
            name = check_call(kimpy_arg_name.get_compute_argument_name, i)
            dtype = check_call(
                kimpy_arg_name.get_compute_argument_data_type, name)

            arg_support = self.get_argument_support_status(name)

            if self.debug:
                print(
                    "Compute Argument name {:21} is of type {:7} "
                    "and has support "
                    "status {}".format(*[str(x)
                                         for x in [name, dtype, arg_support]])
                )

            # See if the model demands that we ask it for anything
            # other than energy and forces. If so, raise an
            # exception.
            if arg_support == kimpy.support_status.required:
                if (
                    name != kimpy_arg_name.partialEnergy
                    and name != kimpy_arg_name.partialForces
                ):
                    raise KIMModelInitializationError(
                        f"Unsupported required ComputeArgument {name}"
                    )

    def _check_compute_callbacks(self):
        """Raise if the model marks any compute callback as required."""
        callback_name = kimpy.compute_callback_name
        num_callbacks = check_call(
            callback_name.get_number_of_compute_callback_names)
        if self.debug:
            print()
            print(f"Number of callbacks: {num_callbacks}")

        for i in range(num_callbacks):
            name = check_call(callback_name.get_compute_callback_name, i)

            support_status = self.get_callback_support_status(name)

            if self.debug:
                print(
                    "Compute callback {:17} has support status {}".format(
                        str(name), support_status
                    )
                )

            # Cannot handle any "required" callbacks
            if support_status == kimpy.support_status.required:
                raise KIMModelInitializationError(
                    f"Unsupported required ComputeCallback: {name}"
                )

    @check_call_wrapper
    def set_argument_pointer(self, compute_arg_name, data_object):
        return self.compute_args.set_argument_pointer(
            compute_arg_name, data_object)

    @check_call_wrapper
    def get_argument_support_status(self, name):
        return self.compute_args.get_argument_support_status(name)

    @check_call_wrapper
    def get_callback_support_status(self, name):
        return self.compute_args.get_callback_support_status(name)

    @check_call_wrapper
    def set_callback(self, compute_callback_name,
                     callback_function, data_object):
        return self.compute_args.set_callback(
            compute_callback_name, callback_function, data_object
        )

    @check_call_wrapper
    def set_callback_pointer(
            self, compute_callback_name, callback, data_object):
        return self.compute_args.set_callback_pointer(
            compute_callback_name, callback, data_object
        )

    def update(
        self, num_particles, species_code, particle_contributing,
        coords, energy, forces
    ):
        """Register model input and output in the kim_model object."""
        compute_arg_name = kimpy.compute_argument_name
        set_argument_pointer = self.set_argument_pointer

        set_argument_pointer(compute_arg_name.numberOfParticles, num_particles)
        set_argument_pointer(
            compute_arg_name.particleSpeciesCodes,
            species_code)
        set_argument_pointer(
            compute_arg_name.particleContributing, particle_contributing
        )
        set_argument_pointer(compute_arg_name.coordinates, coords)
        set_argument_pointer(compute_arg_name.partialEnergy, energy)
        set_argument_pointer(compute_arg_name.partialForces, forces)

        if self.debug:
            print("Debug: called update_kim")
            print()
class SimulatorModel:
    """Creates a KIM API Simulator Model object and provides a minimal
    interface to it. This is only necessary in this package in order to
    extract any information about a given simulator model because it is
    generally embedded in a shared object.

    Every kimpy call is routed through ``check_call`` so failures surface
    as KimpyError, consistent with the rest of this module.
    """

    def __init__(self, model_name):
        # Create a KIM API Simulator Model object for this model
        self.model_name = model_name
        self.simulator_model = simulator_model_create(self.model_name)

        # Need to close template map in order to access simulator
        # model metadata
        check_call(self.simulator_model.close_template_map)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, value, traceback):
        pass

    @property
    def simulator_name(self):
        simulator_name, _ = check_call(
            self.simulator_model.get_simulator_name_and_version)
        return simulator_name

    @property
    def num_supported_species(self):
        """Number of species reported by the simulator model.

        Raises
        ------
        KIMModelInitializationError
            If the simulator model reports zero supported species.
        """
        num_supported_species = check_call(
            self.simulator_model.get_number_of_supported_species)
        if num_supported_species == 0:
            raise KIMModelInitializationError(
                "Unable to determine supported species of "
                "simulator model {}.".format(self.model_name)
            )
        return num_supported_species

    @property
    def supported_species(self):
        """Tuple of the species reported by the simulator model."""
        return tuple(
            check_call(self.simulator_model.get_supported_species, spec_code)
            for spec_code in range(self.num_supported_species)
        )

    @property
    def num_metadata_fields(self):
        return check_call(self.simulator_model.get_number_of_simulator_fields)

    @property
    def metadata(self):
        """Dict mapping each simulator field name to the list of its lines.

        Rebuilt from kimpy on every access.
        """
        sm_metadata_fields = {}
        for field in range(self.num_metadata_fields):
            extent, field_name = check_call(
                self.simulator_model.get_simulator_field_metadata, field
            )
            sm_metadata_fields[field_name] = [
                check_call(
                    self.simulator_model.get_simulator_field_line, field, ln
                )
                for ln in range(extent)
            ]

        return sm_metadata_fields

    @property
    def supported_units(self):
        """First line of the 'units' metadata field.

        Raises
        ------
        KIMModelInitializationError
            If the 'units' field is missing or empty.
        """
        try:
            supported_units = self.metadata["units"][0]
        except (KeyError, IndexError) as exc:
            raise KIMModelInitializationError(
                "Unable to determine supported units of "
                "simulator model {}.".format(self.model_name)
            ) from exc

        return supported_units

    @property
    def atom_style(self):
        """
        See if a 'model-init' field exists in the SM metadata and, if
        so, whether it contains any entries including an "atom_style"
        command. This is specific to LAMMPS SMs and is only required
        for using the LAMMPSrun calculator because it uses
        lammps.inputwriter to create a data file. All other content in
        'model-init', if it exists, is ignored.

        If several lines mention "atom_style", the last one wins; the
        returned value is the second whitespace-separated token of that
        line, or None if no line mentions it.
        """
        atom_style = None
        for ln in self.metadata.get("model-init", []):
            if "atom_style" in ln:
                atom_style = ln.split()[1]

        return atom_style

    @property
    def model_defn(self):
        return self.metadata["model-defn"]

    @property
    def initialized(self):
        return hasattr(self, "simulator_model")