Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/add test pytorch imported models #79

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
cmake_minimum_required(VERSION 3.1)
project(RTNeural VERSION 1.0.0)
set(CMAKE_CXX_STANDARD 17)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RTNeural is intended to be compatible with C++14, so we shouldn't be setting this in the top-level CMakeList.

include(cmake/CXXStandard.cmake)

add_subdirectory(RTNeural)
Expand Down
1 change: 1 addition & 0 deletions models/pytorch.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"model_data": {"model": "SimpleRNN", "input_size": 1, "skip": 1, "output_size": 1, "unit_type": "LSTM", "num_layers": 1, "hidden_size": 12, "bias_fl": true}, "state_dict": {"rec.weight_ih_l0": [[0.05020524561405182], [-0.31360065937042236], [-0.15067864954471588], [0.2806737422943115], [0.02998657338321209], [-0.6892085075378418], [-0.04979453980922699], [0.4359737038612366], [-0.3269232511520386], [0.03499136492609978], [0.04311678558588028], [-0.018392041325569153], [-0.09896978735923767], [0.5271430015563965], [-0.5696582794189453], [0.32318925857543945], [-0.13496124744415283], [0.405482679605484], [-0.313450425863266], [0.0730346143245697], [-0.12087029963731766], [0.19023048877716064], [-0.06474478542804718], [0.15336623787879944], [-0.5249449610710144], [-0.22947825491428375], [0.13840869069099426], [-0.8544607162475586], [-1.0087761878967285], [0.11612961441278458], [0.18440590798854828], [0.9191402792930603], [0.9374542236328125], [-0.3703789710998535], [1.7154607772827148], [0.5464937090873718], [0.02379165217280388], [0.01072766538709402], [-0.030283097177743912], [0.13924889266490936], [-0.018017185851931572], [-0.18404167890548706], [0.03342645242810249], [0.22683492302894592], [-0.02473754994571209], [0.27487701177597046], [0.00877404399216175], [-0.0435529425740242]], "rec.weight_hh_l0": [[-0.0419795922935009, 0.1817820817232132, -0.2635403871536255, -0.38822832703590393, -0.10167697072029114, -0.2753196656703949, -0.08735445141792297, -0.13422825932502747, -0.2680502235889435, -0.1305692046880722, -0.003557975636795163, -0.13957737386226654], [0.3447078764438629, 0.1452958583831787, 0.21235381066799164, 0.35688483715057373, -0.17178060114383698, 0.1822955161333084, 1.0669270753860474, -0.9317453503608704, -0.2226657122373581, 0.6336563229560852, -0.17045281827449799, 0.7514078617095947], [-0.2533455491065979, -0.06069974973797798, -0.35844650864601135, 0.003759554587304592, 0.11620926856994629, 0.08196529000997543, 0.443137526512146, -0.11686224490404129, 0.024045534431934357, 0.16937142610549927, -0.0515710711479187, -0.32831907272338867], [0.13250599801540375, -0.019869109615683556, -0.2595517039299011, 0.2735815942287445, 0.2469712197780609, -0.10645327717065811, -0.2007511705160141, 0.06993260234594345, 0.19377069175243378, -0.420617938041687, 0.1740662157535553, -0.37211093306541443], [0.09176705777645111, 0.054482121020555496, -0.11108475923538208, -0.05554075911641121, 0.028892921283841133, -0.5341330170631409, -0.44128209352493286, 0.48233988881111145, 0.11457467824220657, -0.3316223919391632, 0.012944720685482025, -0.4595147371292114], [0.10779893398284912, -0.39388924837112427, 0.20794905722141266, 0.38155287504196167, 0.26059120893478394, 0.4953070282936096, 0.593137800693512, -0.5435242056846619, -0.7978695034980774, 0.8115385174751282, -0.026878798380494118, 0.5990843772888184], [-0.031853239983320236, -0.17585283517837524, -0.10184640437364578, 0.1836632341146469, 0.2556924819946289, 0.036626849323511124, 0.21951058506965637, -0.4338825047016144, 0.02179029770195484, 0.2786271870136261, -0.10784928500652313, -0.03530285134911537], [-0.09768553078174591, 0.2312772125005722, -0.3799952268600464, -0.020916374400258064, -0.1588582694530487, 0.3677980601787567, 0.38945162296295166, 0.19620734453201294, 0.411011666059494, 0.1406802237033844, 0.047509368509054184, -0.4520294666290283], [-0.21391651034355164, 0.08397214859724045, -0.4814964234828949, 0.04342612996697426, -0.007354023400694132, 0.08117345720529556, -0.014773029834032059, -0.107285276055336, 0.4101288914680481, 0.716863751411438, -0.15310686826705933, -0.6229735016822815], [-0.04274970665574074, 0.300245076417923, -0.5902673602104187, -0.26268288493156433, -0.08551324158906937, 0.10370974242687225, 0.15658539533615112, -0.6740304827690125, 0.3825993835926056, -0.7310903668403625, -0.2002175748348236, -0.12198829650878906], [0.007586871739476919, -0.023996934294700623, -0.3798433840274811, 0.10259561985731125, 0.13029052317142487, -0.14023162424564362, -0.1841200888156891, 0.1720430701971054, 0.13423040509223938, -0.1883481740951538, -0.1360965520143509, -0.35829469561576843], [0.044670674949884415, 0.0580819770693779, -0.4002642035484314, -0.1053876131772995, 0.19849006831645966, 0.3271082043647766, 0.022483645007014275, -0.4021897614002228, 0.07445303350687027, -0.19105540215969086, -0.07941491901874542, -0.2842693030834198], [0.2131032794713974, -0.1319408267736435, -0.3117741346359253, -0.3962928354740143, -0.06026780977845192, -0.6062923073768616, -0.19696640968322754, 0.021591667085886, -0.16526609659194946, -0.22705307602882385, -0.09513083100318909, -0.032005418092012405], [-0.15700161457061768, 0.2977375090122223, 0.2477426379919052, 0.43106889724731445, -0.006900585722178221, -0.3350847363471985, -0.5528604388237, 0.6616589426994324, 0.6194667816162109, -0.5455767512321472, -0.035018645226955414, -0.9212391972541809], [-0.11113276332616806, -0.31088364124298096, -0.37277036905288696, -0.027676252648234367, -0.18657110631465912, -0.29180675745010376, -0.49224501848220825, 0.9204225540161133, -0.0398826077580452, -0.4514818489551544, -0.09501200914382935, -0.2861374020576477], [0.08624312281608582, -0.29853367805480957, 0.03387894108891487, 0.2551155090332031, 0.08103425800800323, 0.08212204277515411, -0.053582001477479935, -0.11621475219726562, 0.03102377988398075, -0.4537014663219452, 0.5775313973426819, 0.1742255538702011], [-0.11658778786659241, 0.025222918018698692, -0.5548572540283203, -0.004994192160665989, 0.059502869844436646, -0.05820043757557869, -0.20483215153217316, 0.756228506565094, -0.6063075065612793, -0.5825205445289612, -0.07070402055978775, 0.023446781560778618], [-0.3176962435245514, 0.4853091835975647, 0.02597038447856903, -0.2562324106693268, -0.23233185708522797, -0.13416479527950287, -0.45839011669158936, 0.7953363656997681, -0.8111507296562195, -0.7934166789054871, -0.4952712655067444, -0.565464437007904], [-0.02637968398630619, 0.202874094247818, -0.38101905584335327, -0.08340650051832199, -0.21255797147750854, -0.6426507234573364, -0.3722417950630188, -0.13830359280109406, -0.040362462401390076, -0.7064327597618103, -0.22895532846450806, -0.6142423152923584], [0.038974057883024216, 0.5124906897544861, -0.41787290573120117, -0.631679892539978, 0.11833527684211731, -0.0908958688378334, 0.31528589129447937, 0.3563598096370697, 0.1361166536808014, -0.7814971208572388, 0.03258926048874855, -0.49021443724632263], [0.3476903736591339, 0.08684235066175461, -0.047647569328546524, -0.5132737755775452, -0.16795197129249573, 0.18226490914821625, 0.17241425812244415, -0.15925873816013336, 0.27610304951667786, -0.9138140082359314, 0.06866197288036346, -0.6046007871627808], [-0.15014983713626862, 0.06521818041801453, -0.1884533166885376, -0.2480771541595459, -0.1519910991191864, 0.26398321986198425, 0.20219819247722626, -0.46542179584503174, 0.5596254467964172, -0.8196132183074951, -0.03035290725529194, -0.40090158581733704], [0.0800454393029213, -0.21852022409439087, -0.39614754915237427, 0.34936994314193726, -0.04554533213376999, -0.17933227121829987, -0.1229795441031456, 0.21646521985530853, 0.22071956098079681, -0.14098012447357178, -0.25386756658554077, -0.5645362138748169], [-0.18772798776626587, 0.08124112337827682, -0.3827485740184784, -0.2828895151615143, 0.006489857099950314, -0.31992727518081665, -0.7471112012863159, 0.8270829916000366, 0.40388038754463196, -0.07709557563066483, -0.09612168371677399, -0.533905565738678], [-0.05556026101112366, -0.9484651684761047, -0.3851669132709503, 0.4081540107727051, 0.10818988084793091, 0.05053427442908287, -0.0068828146904706955, 0.5549035668373108, 0.430528461933136, 0.19256411492824554, -0.5549635887145996, 0.4220047891139984], [-0.22861865162849426, 0.2707577645778656, 0.15497872233390808, -0.00420735776424408, 0.1368376761674881, 0.11974138021469116, 0.27559995651245117, 0.47974255681037903, -0.08727943152189255, 0.23193541169166565, -0.3137245178222656, -0.11538560688495636], [0.5224423408508301, 0.0007433576975017786, 0.6795427203178406, -0.13858827948570251, -0.09943308681249619, 0.23152537643909454, -0.06098189577460289, 0.13208699226379395, -0.014613705687224865, 0.34401029348373413, -0.11229292303323746, -0.11098750680685043], [-0.02699975296854973, -0.38032159209251404, 0.1251433938741684, 0.08602305501699448, -0.07317224144935608, 0.35677748918533325, 0.5047268867492676, 0.38798922300338745, 0.8084245324134827, -0.4468041658401489, -0.11837373673915863, -0.021292971447110176], [-2.9368317127227783, -0.46928903460502625, 1.7071194648742676, 0.32532450556755066, 1.2166223526000977, -0.26895201206207275, -0.45305702090263367, 0.12723864614963531, 0.08562743663787842, -0.3800410330295563, 1.57733952999115, -0.6840223073959351], [-0.058886781334877014, -0.13382424414157867, -0.23539793491363525, 0.19094085693359375, 0.07505623996257782, 0.34596890211105347, 0.25001296401023865, -0.06199520081281662, 0.000489857979118824, 0.11955216526985168, -0.012012261897325516, -0.1772087812423706], [0.006041266955435276, 0.26415735483169556, 0.14041587710380554, -0.11104265600442886, 0.028746692463755608, 0.58452308177948, 0.12349818646907806, 0.019248418509960175, 0.49999186396598816, -0.11842067539691925, -0.0003830951754935086, 0.20304101705551147], [-0.08140194416046143, 0.2235168218612671, -0.3840726315975189, -0.4758818745613098, 0.030706197023391724, -0.7356177568435669, -0.17233802378177643, 0.4515981078147888, -0.7734667062759399, 0.255995512008667, 0.6352667212486267, 0.08383439481258392], [-0.04103945195674896, -0.1180829331278801, 0.04661482945084572, -0.3084677457809448, -0.013833864592015743, -1.122132658958435, 0.06130778044462204, 0.3286936581134796, 0.4442979693412781, -0.10064051300287247, -0.0011290592374280095, 0.08868379145860672], [-0.06547056883573532, 0.016941610723733902, 0.07246027141809464, -0.31740376353263855, -0.11826825886964798, 0.9346165657043457, 0.07284092158079147, 1.1103997230529785, -0.6557912230491638, 0.12093563377857208, -0.319370836019516, -0.055965013802051544], [0.014858919195830822, -0.29892420768737793, 0.35606926679611206, -0.25534215569496155, -0.142509326338768, -0.16611771285533905, -0.3154788017272949, -0.444398432970047, -0.5605687499046326, -0.233811154961586, -0.34681662917137146, -0.062379416078329086], [-0.48060405254364014, -0.2162168323993683, -0.17593473196029663, 0.40442851185798645, 0.1034175232052803, -1.0842241048812866, 0.3691804111003876, -0.3307529091835022, -0.33082932233810425, -0.21694831550121307, -0.16666562855243683, 0.1988600492477417], [0.11597444117069244, 0.2257516235113144, -0.2621161937713623, -0.27512669563293457, -0.04655130207538605, -0.4819784164428711, -0.3634260594844818, 0.3025965988636017, -0.0394701287150383, -0.26852646470069885, -0.07453807443380356, -0.2688557505607605], [0.28041157126426697, -0.12446824461221695, 0.09645140171051025, 0.46718695759773254, 0.09595435857772827, -0.22437851130962372, 0.046361811459064484, -0.03138069063425064, -0.24360516667366028, 0.04627212882041931, -0.05742642283439636, -0.2166488766670227], [0.4229910373687744, -0.3798232972621918, -0.263976514339447, 0.019098546355962753, -0.055619072169065475, -0.19564153254032135, -0.3931824862957001, -0.03400629386305809, -0.07653382420539856, -0.4449172914028168, 0.2134639173746109, -0.09262046217918396], [0.11228057742118835, 0.04794364422559738, -0.252733051776886, 0.1409764587879181, -0.32227617502212524, -0.038604140281677246, -0.2105533331632614, 0.29489123821258545, 0.15710194408893585, -0.3597046732902527, 0.17894583940505981, -0.5184696316719055], [-0.08132879436016083, -0.19466619193553925, 0.03514159470796585, 0.07888448983430862, -0.30759933590888977, -0.09117507934570312, 0.1740230917930603, 0.02857838198542595, 0.035044703632593155, 0.026516977697610855, -0.10298246145248413, 0.07899776846170425], [0.17437061667442322, 0.18775694072246552, 0.2819613516330719, 0.3819820284843445, -0.05457629635930061, -1.1224738359451294, -0.5681564211845398, -0.03331182524561882, 0.15846768021583557, -0.0976681113243103, -0.3201681077480316, -0.10659728944301605], [0.1399640142917633, -0.23175092041492462, 0.3789215683937073, 0.313232958316803, 0.11130651086568832, -0.4792938828468323, 0.05871028080582619, -0.12333610653877258, -0.2872529923915863, 0.03286588191986084, 0.4404626488685608, -0.12851369380950928], [0.1265312135219574, -0.09050451219081879, -0.0037765454035252333, 0.04323415830731392, -0.09308173507452011, -0.016297774389386177, -0.05721269175410271, 0.584079384803772, 0.35873615741729736, 0.11946998536586761, 0.10912395268678665, -0.6366721987724304], [0.04909504950046539, 0.32090526819229126, -0.5680671334266663, -0.1072963997721672, -0.03254212066531181, -0.2975282669067383, -0.2773382067680359, 0.324497252702713, 0.2021632343530655, -0.2434045672416687, -0.3057680130004883, -0.43721804022789], [0.08271791785955429, 0.5097098350524902, 0.03258184716105461, -0.4940716326236725, 0.32611075043678284, 0.47777312994003296, 0.3004067838191986, 1.0320416688919067, 0.29032155871391296, -0.4478389322757721, -0.19198855757713318, -0.5681642293930054], [-0.13038671016693115, -0.0025995909236371517, -0.15705037117004395, 0.14669963717460632, -0.02158268913626671, -0.22054298222064972, -0.26781514286994934, 0.16741982102394104, 0.14742746949195862, -0.08125656098127365, -0.10222998261451721, -0.4331514239311218], [-0.7470344305038452, 0.1750098615884781, 0.8109745979309082, -0.6849470734596252, 0.10942117124795914, -0.125234454870224, -0.24239204823970795, -0.3907221555709839, -0.15082667768001556, -0.4020195007324219, 0.577019214630127, -0.813296914100647]], "rec.bias_ih_l0": [0.18495628237724304, -0.506679892539978, -0.46178513765335083, 0.05456842482089996, 1.1718180179595947, -1.0529217720031738, -0.19310152530670166, 0.28850769996643066, -0.1551726907491684, 0.2721996307373047, 0.37355145812034607, -0.2473764419555664, 0.1638660728931427, 1.1433839797973633, 0.5559374690055847, 0.22115133702754974, 0.7331377863883972, 1.3374582529067993, 0.5825927257537842, 0.054831571877002716, 0.3157130777835846, 0.32288771867752075, 0.21341075003147125, 0.4869515001773834, -0.05533791705965996, 0.09520503133535385, 0.08175188302993774, 0.08855045586824417, -0.0142483776435256, -0.14839479327201843, 0.008983447216451168, -0.06693097203969955, -0.14205633103847504, -0.09493766725063324, -0.13136929273605347, -0.16240379214286804, 0.6593663692474365, 0.016166353598237038, 0.5665537714958191, 0.19291839003562927, -0.7089212536811829, -0.04102630540728569, 0.07033070921897888, 0.6546640396118164, 0.36995425820350647, 0.21640649437904358, 0.5367745757102966, 0.15529075264930725], "rec.bias_hh_l0": [0.15469570457935333, -0.7044667601585388, -0.10761204361915588, 0.037127215415239334, 1.1718180179595947, -0.8123846650123596, -0.5138460397720337, 0.3279798626899719, -0.13171398639678955, 0.06436160951852798, 0.3052067756652832, 0.040790706872940063, 0.047532640397548676, 0.9885876774787903, 0.5453023910522461, 0.22449643909931183, 0.49310818314552307, 1.3622493743896484, 0.5394005179405212, 0.24795567989349365, 0.657709002494812, 0.26866820454597473, 0.015237968415021896, 0.6372072696685791, 0.14401943981647491, -0.01898677460849285, 0.013817887753248215, -0.08701165020465851, -0.3654780387878418, -0.1833612322807312, -0.08087602257728577, -0.203774094581604, -0.4889949858188629, -0.05346549674868584, 0.08078761398792267, -0.186619371175766, 0.6807979941368103, 0.15421076118946075, 0.809780478477478, 0.23488816618919373, -0.4570198059082031, 0.009004351682960987, 0.16159342229366302, 0.7161290645599365, 0.4933304786682129, -0.0530376173555851, 0.5819734334945679, 0.13463646173477173], "lin.weight": [[-0.11094890534877777, -0.25576481223106384, 0.06525841355323792, -0.06567507237195969, 1.016162633895874, -0.26113417744636536, -0.21008238196372986, -0.13794735074043274, -0.33709046244621277, -0.07722245156764984, -0.5268365144729614, -0.008685651235282421]], "lin.bias": [-0.1556263417005539]}}
Loading