springDamperUserFunctionNumbaJIT.py

You can view and download this file on GitHub: springDamperUserFunctionNumbaJIT.py

#+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# This is an EXUDYN example
#
# Details:  Test with user-defined load function and user-defined spring-damper function (Duffing oscillator);
#           the model is solved once with plain Python user functions and once with Numba-JIT-compiled helpers
#
# Author:   Johannes Gerstmayr
# Date:     2019-11-15
#
# Copyright: This file is part of Exudyn. Exudyn is free software. You can redistribute it and/or modify it under the terms of the Exudyn license. See 'LICENSE.txt' for more details.
#
#+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

import sys
# Set on the sys module BEFORE exudyn is imported; presumably read by exudyn at
# import time to select its fast build -- verify against the exudyn documentation.
sys.exudynFast = True

# ClearWorkspace() runs before the remaining exudyn imports; presumably it resets
# state left over from a previous run in the same interpreter (e.g. inside an IDE)
# -- TODO confirm.
from exudyn.utilities import ClearWorkspace
ClearWorkspace()

import exudyn as exu
# Point, NodePointGround, MassPoint, MarkerNodeCoordinate, CoordinateSpringDamper
# and LoadCoordinate are used below but not defined in this file, so they
# presumably come from this star import.
from exudyn.utilities import *

import numpy as np
 24#+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 25#NUMBA PART; mainly, we need to register MainSystem mbs in numba to get user functions work
 26#import numba jit for compilation of functions:
 27# from numba import jit
 28
 29#create identity operator for replacement of jit:
 30try:
 31    from numba import jit
 32    print('running WITH JIT')
 33except: #define replacement operator
 34    print('running WITHOUT JIT')
 35    def jit(ob):
 36        return ob
 37
 38# from numba import jit, cfunc, types, njit
 39# from numba.types import float64, void, int64 #for signatures of user functions!
 40#+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 41
 42
 43# @jit
 44# def myfunc():
 45#     print("my function")
 46
 47#+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 48
 49
 50useGraphics = False #without test
 51
 52
 53SC = exu.SystemContainer()
 54mbs = SC.AddSystem()
 55exu.Print('EXUDYN version='+exu.GetVersionString())
 56
 57L=0.5
 58mass = 1.6          #mass in kg
 59spring = 4000       #stiffness of spring-damper in N/m
 60damper = 4          #damping constant in N/(m/s)
 61load0 = 80
 62
 63omega0=np.sqrt(spring/mass)
 64f0 = 0.*omega0/(2*np.pi)
 65f1 = 1.*omega0/(2*np.pi)
 66
 67exu.Print('resonance frequency = '+str(omega0))
 68tEnd = 50     #end time of simulation
 69steps = 1000000  #number of steps
 70
 71#first test without JIT:
 72
 73def sf(u,v,k,d):
 74    return 0.1*k*u+k*u**3 + 1e-3*k*u**5 + 1e-6*k*u**7+v*d
 75
 76def springForce(mbs2, t, itemIndex, u, v, k, d, offset):
 77    return sf(u,v,k,d)
 78    # x=test(mbs.systemData.GetTime()) #5 microseconds
 79    # q=mbs.systemData.GetODE2Coordinates() #5 microseconds
 80    # return 0.1*k*u+k*u**3+v*d
 81
 82#linear frequency sweep in time interval [0, t1] and frequency interval [f0,f1];
 83def Sweep(t, t1, f0, f1):
 84    k = (f1-f0)/t1
 85    return np.sin(2*np.pi*(f0+k*0.5*t)*t) #take care of factor 0.5 in k*0.5*t, in order to obtain correct frequencies!!!
 86
 87#user function for load; void replaces mbs, which then may not be used!!!
 88#most time lost due to pybind11 std::function capturing; no simple way to overcome problem at this point (avoid many function calls!)
 89#@cfunc(float64(void, float64, float64)) #possible, but does not lead to speed up
 90#@jit #not possible because of mbs not recognized by numba
 91def userLoad(mbs, t, load):
 92    #x=mbs.systemData.GetTime() #call to systemData function takes around 5us ! Cannot be optimized!
 93    #global tEnd, f0, f1 #global does not change performance
 94    return load*Sweep(t, tEnd, f0, f1) #global variable does not seem to make problems!
 95
#node for 3D mass point, placed at [L,0,0] in the reference configuration:
n1=mbs.AddNode(Point(referenceCoordinates = [L,0,0]))

#ground node at the origin (fixed end of the spring-damper)
nGround=mbs.AddNode(NodePointGround(referenceCoordinates = [0,0,0]))

#add mass point (this is a 3D object with 3 coordinates):
massPoint = mbs.AddObject(MassPoint(physicsMass = mass, nodeNumber = n1))

#marker for ground (=fixed):
groundMarker=mbs.AddMarker(MarkerNodeCoordinate(nodeNumber= nGround, coordinate = 0))
#marker for springDamper for first (x-)coordinate:
nodeMarker  =mbs.AddMarker(MarkerNodeCoordinate(nodeNumber= n1, coordinate = 0))

#Spring-Damper between two marker coordinates; the force comes from the plain
#Python user function springForce (stiffness/damping are presumably forwarded to it
#as k, d -- verify). Later in the script it is swapped for springForce2 via
#SetObjectParameter for the JIT run, hence the handle oSD.
oSD=mbs.AddObject(CoordinateSpringDamper(markerNumbers = [groundMarker, nodeMarker],
                                     stiffness = spring, damping = damper,
                                     springForceUserFunction = springForce,
                                     ))

#add load on the x-coordinate of the mass point; userLoad scales load0 by the
#frequency sweep. Later swapped for userLoad2 via SetLoadParameter, hence loadC.
loadC = mbs.AddLoad(LoadCoordinate(markerNumber = nodeMarker,
                           load = load0,
                           loadUserFunction=userLoad,
                           ))

#finalize the system before solving
mbs.Assemble()
#solver settings, shared by both runs (plain Python and JIT)
simulationSettings = exu.SimulationSettings()

#time integration: 'steps' steps up to tEnd, modified Newton, generalized-alpha
simulationSettings.timeIntegration.numberOfSteps = steps
simulationSettings.timeIntegration.endTime = tEnd
simulationSettings.timeIntegration.newton.useModifiedNewton = True
simulationSettings.timeIntegration.generalizedAlpha.spectralRadius = 1
simulationSettings.timeIntegration.verboseMode = 1

#output: no solution file, but statistics and computation time for the timing comparison
simulationSettings.solutionSettings.writeSolutionToFile = False
simulationSettings.displayStatistics = True
simulationSettings.displayComputationTime = True

#run 1: plain Python user functions
mbs.SolveDynamic(simulationSettings)

#final (=current) x-displacement of the mass point
u = mbs.GetNodeOutput(n1, exu.OutputVariableType.Position)
exu.Print('displacement=', u[0])
#%%+++++++++++++++++++++++++++++++++++++++++++++++++++++
#run again with JIT included:

#use jit for every time-consuming part; the more complex it gets, the larger
#the speedup. Jitted code may only contain simple structures (no mbs, no exudyn
#functions [but you could @jit them!]).
@jit
def sf2(u,v,k,d):
    """JIT-compiled twin of sf: the same Duffing-type spring-damper force."""
    elasticForce = 0.1*k*u + k*u**3 + 1e-3*k*u**5 + 1e-6*k*u**7
    return elasticForce + v*d
def springForce2(mbs2, t, itemIndex, u, v, k, d, offset):
    """Spring-force user function for the JIT run; delegates to the jitted sf2.

    This wrapper stays plain Python because it receives mbs2, which numba cannot
    handle; only the numeric core sf2 is compiled. mbs2, t, itemIndex and offset
    are unused.
    """
    force = sf2(u,v,k,d)
    return force
#swap the spring-damper's user function for the JIT-backed springForce2
#(replaces springForce from the first run); the load user function is swapped
#separately further below
mbs.SetObjectParameter(oSD, 'springForceUserFunction', springForce2)
#jit gives us speedup and works out of the box:
@jit
def Sweep2(t, t1, f0, f1):
    """JIT-compiled twin of Sweep: linear frequency sweep from f0 to f1 over [0, t1].

    Phase 2*pi*(f0 + 0.5*k*t)*t with k = (f1-f0)/t1; the factor 0.5 makes the
    instantaneous frequency equal to f0 + k*t.
    """
    sweepRate = (f1-f0)/t1
    phase = 2*np.pi*(f0+sweepRate*0.5*t)*t
    return np.sin(phase)
def userLoad2(mbs, t, load):
    """Load user function for the JIT run: load scaled by the jitted Sweep2.

    Reads tEnd, f0 and f1 from module scope; mbs is unused. Decorating this
    function with @cfunc(float64(void, float64, float64), nopython=True,
    fastmath=True) is possible (void replaces mbs, which then may not be used),
    but does not lead to a speedup.
    """
    sweepValue = Sweep2(t, tEnd, f0, f1)
    return load*sweepValue
#swap the load user function for the JIT-backed userLoad2
mbs.SetLoadParameter(loadC, 'loadUserFunction', userLoad2)

#run 2: same model and solver settings, now with jitted helpers
mbs.SolveDynamic(simulationSettings)

#final (=current) x-displacement of the mass point, to compare with run 1
u = mbs.GetNodeOutput(n1, exu.OutputVariableType.Position)
exu.Print('JIT, displacement=', u[0])
#performance:
#1e6 time steps
# no user functions:
# tCPU=1.15 seconds

# regular, Python user function for spring-damper and load:
# tCPU=16.7 seconds

# jit, Python user function for spring-damper and load:
# tCPU=5.58 seconds (on average)
#==>speedup of user function part: (16.7-1.15)/(5.58-1.15) = 15.55/4.43 = approx. 3.5
#   (overall speedup: 16.7/5.58 = approx. 3.0)
#speedup will be much larger if Python functions are larger!
#approx. 400,000 Python function calls/second!