springDamperUserFunctionNumbaJIT.py
You can view and download this file on GitHub: springDamperUserFunctionNumbaJIT.py
1#+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2# This is an EXUDYN example
3#
4# Details: Test with user-defined load function and user-defined spring-damper function (Duffing oscillator)
5#
6# Author: Johannes Gerstmayr
7# Date: 2019-11-15
8#
9# Copyright:This file is part of Exudyn. Exudyn is free software. You can redistribute it and/or modify it under the terms of the Exudyn license. See 'LICENSE.txt' for more details.
10#
11#+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
12
13import sys
14sys.exudynFast = True
15
16from exudyn.utilities import ClearWorkspace
17ClearWorkspace()
18
19import exudyn as exu
20from exudyn.utilities import *
21
22import numpy as np
23
24#+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
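#NUMBA PART; the user functions themselves cannot be jit-compiled, because numba would need to know the MainSystem mbs (which is not registered in numba); therefore only helper functions without mbs are jit-compiled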
26#import numba jit for compilation of functions:
27# from numba import jit
28
#create identity operator for replacement of jit:
try:
    from numba import jit
    print('running WITH JIT')
except ImportError: #numba not installed (or not importable): fall back to plain Python
    print('running WITHOUT JIT')
    def jit(ob=None, *args, **kwargs):
        """Identity replacement for numba.jit, used when numba is unavailable.

        Supports the bare decorator form ``@jit`` as well as the decorator
        factory form ``@jit(nopython=True)`` / ``@jit('float64(float64)')``.
        In both cases the decorated function is returned unchanged, so it
        simply runs as regular Python code.

        Args:
            ob: the function to decorate (bare form), or a non-callable
                first argument such as a signature (factory form), or None.
            *args, **kwargs: numba options; ignored.

        Returns:
            ob itself (bare form) or an identity decorator (factory form).
        """
        if callable(ob):
            return ob #bare usage: @jit
        return lambda func: func #factory usage: @jit(...) must return the actual decorator
37
38# from numba import jit, cfunc, types, njit
39# from numba.types import float64, void, int64 #for signatures of user functions!
40#+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
41
42
43# @jit
44# def myfunc():
45# print("my function")
46
47#+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
48
49
useGraphics = False #without test

#system container with one main system (mbs) that holds the model:
SC = exu.SystemContainer()
mbs = SC.AddSystem()
exu.Print('EXUDYN version='+exu.GetVersionString())

#model parameters:
L = 0.5          #reference x-coordinate of the mass point node
mass = 1.6       #mass in kg
spring = 4000    #stiffness of spring-damper in N/m
damper = 4       #damping constant in N/(m/s)
load0 = 80       #load value given to the coordinate load

#omega0 = sqrt(spring/mass) in rad/s; the load frequency is swept from f0 to f1 (in Hz)
#NOTE: the user spring force uses 0.1*k as its linear stiffness term, so omega0 is based on the full stiffness
omega0 = np.sqrt(spring/mass)
f0 = 0.*omega0/(2*np.pi)   #sweep start frequency (0 Hz)
f1 = 1.*omega0/(2*np.pi)   #sweep end frequency (omega0 converted to Hz)
exu.Print('resonance frequency = '+str(omega0))

#time integration parameters:
tEnd = 50         #end time of simulation
steps = 1000000   #number of steps
70
#first test without JIT:

def sf(u,v,k,d):
    """Duffing-type spring-damper force (plain Python version).

    Odd polynomial in u with terms of order 1, 3, 5 and 7 (all scaled by the
    stiffness k, the linear one by 0.1) plus the viscous term v*d.
    """
    linearTerm = 0.1*k*u
    cubicTerm = k*u**3
    quinticTerm = 1e-3*k*u**5
    septicTerm = 1e-6*k*u**7
    dampingTerm = v*d
    #summed left to right, i.e. in the same order as the single expression it replaces
    return linearTerm + cubicTerm + quinticTerm + septicTerm + dampingTerm
75
def springForce(mbs2, t, itemIndex, u, v, k, d, offset):
    """Spring-force user function of the CoordinateSpringDamper (plain Python version).

    Only u, v, k and d are forwarded to sf(); mbs2, t, itemIndex and offset are
    accepted because they belong to the user-function signature, but are unused.
    Performance note: calling mbs.systemData.GetTime() or
    mbs.systemData.GetODE2Coordinates() in here costs about 5 microseconds per call.
    """
    force = sf(u, v, k, d)
    return force
81
#linear frequency sweep in time interval [0, t1] and frequency interval [f0,f1];
def Sweep(t, t1, f0, f1):
    """Return sin(phase(t)) of a linear chirp going from f0 (at t=0) to f1 (at t=t1), frequencies in Hz.

    The instantaneous frequency is f0 + (f1-f0)/t1*t; integrating it gives the
    phase 2*pi*(f0 + 0.5*(f1-f0)/t1*t)*t, hence the factor 0.5 (needed to obtain
    the correct frequencies).
    """
    sweepRate = (f1-f0)/t1
    phase = 2*np.pi*(f0+sweepRate*0.5*t)*t
    return np.sin(phase)
86
#user function for load; void replaces mbs, which then may not be used!!!
#most time lost due to pybind11 std::function capturing; no simple way to overcome problem at this point (avoid many function calls!)
#@cfunc(float64(void, float64, float64)) #possible, but does not lead to speed up
#@jit #not possible because of mbs not recognized by numba
def userLoad(mbs, t, load):
    """Load user function (plain Python version): load times a linear frequency sweep.

    Returns load*Sweep(t, tEnd, f0, f1). tEnd, f0 and f1 are read as module-level
    globals; declaring them 'global' made no performance difference. Calling
    mbs.systemData functions (e.g. GetTime()) would add about 5 microseconds per call.
    """
    sweepFactor = Sweep(t, tEnd, f0, f1)
    return load*sweepFactor
95
#node for the 3D mass point, placed at [L,0,0]:
n1 = mbs.AddNode(Point(referenceCoordinates=[L,0,0]))

#ground node at the origin:
nGround = mbs.AddNode(NodePointGround(referenceCoordinates=[0,0,0]))

#mass point object (3D, three coordinates) on node n1:
massPoint = mbs.AddObject(MassPoint(physicsMass=mass, nodeNumber=n1))

#markers on the first (x-)coordinate of the ground node (fixed) and of the mass point node:
groundMarker = mbs.AddMarker(MarkerNodeCoordinate(nodeNumber=nGround, coordinate=0))
nodeMarker = mbs.AddMarker(MarkerNodeCoordinate(nodeNumber=n1, coordinate=0))

#spring-damper between the two marker coordinates; force given by the Python user function springForce:
oSD = mbs.AddObject(CoordinateSpringDamper(markerNumbers=[groundMarker, nodeMarker],
                                           stiffness=spring, damping=damper,
                                           springForceUserFunction=springForce))

#load on the x-coordinate of the mass point; load value modified by the user function userLoad:
loadC = mbs.AddLoad(LoadCoordinate(markerNumber=nodeMarker,
                                   load=load0,
                                   loadUserFunction=userLoad))

mbs.Assemble()
123
#simulation settings:
simulationSettings = exu.SimulationSettings()

#no solution file is written; only the final state is evaluated below
simulationSettings.solutionSettings.writeSolutionToFile = False

#time integration: 'steps' steps up to tEnd, modified Newton, generalized-alpha spectral radius 1
simulationSettings.timeIntegration.numberOfSteps = steps
simulationSettings.timeIntegration.endTime = tEnd
simulationSettings.timeIntegration.newton.useModifiedNewton = True
simulationSettings.timeIntegration.generalizedAlpha.spectralRadius = 1

#solver output: statistics, computation time and verbose level 1
simulationSettings.displayStatistics = True
simulationSettings.displayComputationTime = True
simulationSettings.timeIntegration.verboseMode = 1

#start solver (plain Python user functions):
mbs.SolveDynamic(simulationSettings)

#final (=current) position of the mass point node; component 0 is the x-coordinate
#NOTE(review): OutputVariableType.Position presumably includes the reference coordinate L, so the printed value is not the pure displacement — confirm
u = mbs.GetNodeOutput(n1, exu.OutputVariableType.Position)
exu.Print('displacement=',u[0])
142
143
144#%%+++++++++++++++++++++++++++++++++++++++++++++++++++++
145#run again with JIT included:
146
#use jit for all time-consuming parts;
#the more complex these parts get, the larger the speedup will be!
#however, jit functions may only contain simple structures (no mbs, no exudyn functions [but you could @jit them!])
@jit
def sf2(u,v,k,d):
    """JIT-compiled twin of sf: Duffing-type force polynomial in u plus damping v*d.

    The terms are accumulated in the same left-to-right order as in sf.
    """
    force = 0.1*k*u
    force += k*u**3
    force += 1e-3*k*u**5
    force += 1e-6*k*u**7
    force += v*d
    return force
153
def springForce2(mbs2, t, itemIndex, u, v, k, d, offset):
    """Spring-force user function (JIT variant); same signature as springForce.

    The user function itself stays plain Python (it receives mbs2, which numba
    cannot handle); only the numerical part sf2 is jit-compiled.
    """
    force = sf2(u, v, k, d)
    return force

#exchange the spring-force user function of the existing spring-damper for the JIT variant:
mbs.SetObjectParameter(oSD, 'springForceUserFunction', springForce2)
159
#jit gives us speedup and works out of the box:
@jit
def Sweep2(t, t1, f0, f1):
    """JIT-compiled twin of Sweep: sin of a linear sweep from f0 (t=0) to f1 (t=t1), frequencies in Hz.

    The phase integrates the instantaneous frequency f0 + (f1-f0)/t1*t, which
    yields the factor 0.5 in front of the rate term (needed for correct frequencies).
    """
    rate = (f1-f0)/t1
    phase = 2*np.pi*(f0+rate*0.5*t)*t
    return np.sin(phase)
165
#user function for load; void replaces mbs, which then may not be used!!!
# @cfunc(float64(void, float64, float64), nopython=True, fastmath=True) #possible, but does not lead to speed up
def userLoad2(mbs, t, load):
    """Load user function (JIT variant): load*Sweep2(t, tEnd, f0, f1).

    Stays plain Python because it receives mbs; the sweep itself is jit-compiled.
    tEnd, f0 and f1 are read as module-level globals.
    """
    sweepFactor = Sweep2(t, tEnd, f0, f1)
    return load*sweepFactor
170
#exchange the load user function for the JIT variant:
mbs.SetLoadParameter(loadC,'loadUserFunction', userLoad2)

#solve again, now with jit-compiled helpers inside both user functions:
mbs.SolveDynamic(simulationSettings)

#final (=current) position of the mass point node; component 0 is the x-coordinate:
u = mbs.GetNodeOutput(n1, exu.OutputVariableType.Position)
exu.Print('JIT, displacement=',u[0])
178
179
180#performance:
181#1e6 time steps
182# no user functions:
183# tCPU=1.15 seconds
184
185# regular, Python user function for spring-damper and load:
186# tCPU=16.7 seconds
187
188# jit, Python user function for spring-damper and load:
189# tCPU=5.58 seconds (on average)
#==>speedup of user function part: (16.7-1.15)/(5.58-1.15) = 15.55/4.43 ≈ 3.5
191#speedup will be much larger if Python functions are larger!
192#approx. 400.000 Python function calls/second!