1
+ import logging
2
+ import shlex
1
3
import shutil
4
+ import subprocess
2
5
import tempfile
3
- from copy import copy
4
6
from datetime import datetime , UTC
5
7
6
8
import pytest
7
9
import os
8
10
9
11
from flask import Flask
10
- from sqlalchemy import create_engine , select
11
- from sqlalchemy .orm import Session , make_transient_to_detached
12
+ from sqlalchemy import create_engine , text , CursorResult
13
+ from sqlalchemy .orm import Session
14
+ from sqlalchemy .pool import NullPool
12
15
13
16
from .legacy import util
14
17
from .legacy .passwords import hash_password
20
23
from ..auth .auth import Auth
21
24
from ..auth .auth .middleware import AuthMiddleware
22
25
26
+ logging .basicConfig (level = logging .INFO )
23
27
24
- @pytest .fixture
25
- def classic_db_engine ():
26
- db_path = tempfile .mkdtemp ()
27
- uri = f'sqlite:///{ db_path } /test.db'
28
- engine = create_engine (uri )
29
- util .create_all (engine )
30
- yield engine
31
- shutil .rmtree (db_path )
28
+ DB_PORT = 25336
29
+ DB_NAME = "testdb"
30
+ ROOT_PASSWORD = "rootpassword"
32
31
32
+ my_sql_cmd = ["mysql" , f"--port={ DB_PORT } " , "-h" , "127.0.0.1" , "-u" , "root" , f"--password={ ROOT_PASSWORD } " ,
33
+ # "--ssl-mode=DISABLED",
34
+ DB_NAME ]
33
35
36
+ def arxiv_base_dir () -> str :
37
+ """
38
+ Returns:
39
+ "arxiv-base" directory abs path
40
+ """
41
+ here = os .path .abspath (__file__ )
42
+ root_dir = here
43
+ for _ in range (3 ):
44
+ root_dir = os .path .dirname (root_dir )
45
+ return root_dir
46
+
47
+
48
+ @pytest .fixture (scope = "session" )
49
+ def db_uri (request ):
50
+ db_type = request .config .getoption ("--db" )
51
+
52
+ if db_type == "sqlite" :
53
+ # db_path = tempfile.mkdtemp()
54
+ # uri = f'sqlite:///{db_path}/test.db'
55
+ uri = f'sqlite'
56
+ elif db_type == "mysql" :
57
+ # load_arxiv_db_schema.py sets up the docker and load the db schema
58
+ loader_py = os .path .join (arxiv_base_dir (), "development" , "load_arxiv_db_schema.py" )
59
+ subprocess .run (["poetry" , "run" , "python" , loader_py , f"--db_name={ DB_NAME } " , f"--db_port={ DB_PORT } " ,
60
+ f"--root_password={ ROOT_PASSWORD } " ], encoding = "utf-8" , check = True )
61
+ uri = f"mysql://testuser:[email protected] :{ DB_PORT } /{ DB_NAME } "
62
+ else :
63
+ raise ValueError (f"Unsupported database dialect: { db_type } " )
64
+
65
+ yield uri
66
+
67
+
68
+ @pytest .fixture (scope = "function" )
69
+ def classic_db_engine (db_uri ):
70
+ logger = logging .getLogger ()
71
+ db_path = None
72
+ if db_uri .startswith ("sqlite" ):
73
+ db_path = tempfile .mkdtemp ()
74
+ uri = f'sqlite:///{ db_path } /test.db'
75
+ db_engine = create_engine (uri )
76
+ util .create_arxiv_db_schema (db_engine )
77
+ else :
78
+ conn_args = {}
79
+ # conn_args["ssl"] = None
80
+ db_engine = create_engine (db_uri , connect_args = conn_args , poolclass = NullPool )
81
+
82
+ # Clean up the tables to real fresh
83
+ targets = []
84
+ with db_engine .connect () as connection :
85
+ tables = [row [0 ] for row in connection .execute (text ("SHOW TABLES" ))]
86
+ for table_name in tables :
87
+ counter : CursorResult = connection .execute (text (f"select count(*) from { table_name } " ))
88
+ count = counter .first ()[0 ]
89
+ if count and int (count ):
90
+ targets .append (table_name )
91
+ connection .invalidate ()
92
+
93
+ if targets :
94
+ statements = [ "SET FOREIGN_KEY_CHECKS = 0;" ] + [f"TRUNCATE TABLE { table_name } ;" for table_name in targets ] + ["SET FOREIGN_KEY_CHECKS = 1;" ]
95
+ debug_sql = "SHOW PROCESSLIST;\n SELECT * FROM INFORMATION_SCHEMA.INNODB_LOCKS;\n "
96
+ sql = "\n " .join (statements )
97
+ mysql = subprocess .Popen (my_sql_cmd , stdout = subprocess .PIPE , stderr = subprocess .PIPE , stdin = subprocess .PIPE , encoding = "utf-8" )
98
+ try :
99
+ logger .info (debug_sql + sql )
100
+ out , err = mysql .communicate (sql , timeout = 9999 )
101
+ if out :
102
+ logger .info (out )
103
+ if err :
104
+ logger .info (err )
105
+ except Exception as exc :
106
+ logger .error (f"BOO: { str (exc )} " , exc_info = True )
107
+
108
+ util .bootstrap_arxiv_db (db_engine )
109
+
110
+ yield db_engine
111
+
112
+ if db_path :
113
+ shutil .rmtree (db_path )
114
+ else :
115
+ with db_engine .connect () as connection :
116
+ danglings : CursorResult = connection .execute (text ("select id from information_schema.processlist where user = 'testuser';" )).all ()
117
+ connection .invalidate ()
118
+
119
+ if danglings :
120
+ kill_conn = "\n " .join ([ f"kill { id [0 ]} ;" for id in danglings ])
121
+ logger .info (kill_conn )
122
+ mysql = subprocess .Popen (my_sql_cmd , stdout = subprocess .PIPE , stderr = subprocess .PIPE , stdin = subprocess .PIPE , encoding = "utf-8" )
123
+ mysql .communicate (kill_conn )
124
+ db_engine .dispose ()
34
125
35
- @pytest .fixture
36
- def classic_db_engine ():
37
- db_path = tempfile .mkdtemp ()
38
- uri = f'sqlite:///{ db_path } /test.db'
39
- engine = create_engine (uri )
40
- util .create_all (engine )
41
- yield engine
42
- shutil .rmtree (db_path )
43
126
44
127
45
128
@pytest .fixture
@@ -90,7 +173,7 @@ def foouser(mocker):
90
173
issued_when = n ,
91
174
issued_to = '127.0.0.1' ,
92
175
remote_host = 'foohost.foo.com' ,
93
- session_id = 0
176
+ session_id = 1
94
177
)
95
178
user .tapir_nicknames = nick
96
179
user .tapir_passwords = password
@@ -100,8 +183,17 @@ def foouser(mocker):
100
183
101
184
@pytest .fixture
102
185
def db_with_user (classic_db_engine , foouser ):
103
- # just combines classic_db_engine and foouser
104
- with Session (classic_db_engine , expire_on_commit = False ) as session :
186
+ try :
187
+ _load_test_user (classic_db_engine , foouser )
188
+ except Exception as e :
189
+ pass
190
+ yield classic_db_engine
191
+
192
+
193
+ def _load_test_user (db_engine , foouser ):
194
+ # just combines db_engine and foouser
195
+ with Session (db_engine ) as session :
196
+
105
197
user = models .TapirUser (
106
198
user_id = foouser .user_id ,
107
199
first_name = foouser .first_name ,
@@ -117,6 +209,15 @@ def db_with_user(classic_db_engine, foouser):
117
209
flag_banned = foouser .flag_banned ,
118
210
tracking_cookie = foouser .tracking_cookie ,
119
211
)
212
+ session .add (user )
213
+ session .commit ()
214
+
215
+ # Make sure the ID is correct. If you are using mysql with different auto-increment. you may get an different id
216
+ # However, domain.User's user_id is str, and the db/models User model user_id is int.
217
+ # wish they match but since tapir's user id came from auto-increment id which has to be int, I guess
218
+ # "it is what it is".
219
+ assert str (foouser .user_id ) == str (user .user_id )
220
+
120
221
nick = models .TapirNickname (
121
222
nickname = foouser .tapir_nicknames .nickname ,
122
223
user_id = foouser .tapir_nicknames .user_id ,
@@ -126,11 +227,30 @@ def db_with_user(classic_db_engine, foouser):
126
227
policy = foouser .tapir_nicknames .policy ,
127
228
flag_primary = foouser .tapir_nicknames .flag_primary ,
128
229
)
230
+ session .add (nick )
231
+ session .commit ()
232
+
129
233
password = models .TapirUsersPassword (
130
234
user_id = foouser .user_id ,
131
235
password_storage = foouser .tapir_passwords .password_storage ,
132
236
password_enc = foouser .tapir_passwords .password_enc ,
133
237
)
238
+ session .add (password )
239
+ session .commit ()
240
+
241
+ with Session (db_engine ) as session :
242
+ tapir_session_1 = models .TapirSession (
243
+ session_id = foouser .tapir_tokens .session_id ,
244
+ user_id = foouser .user_id ,
245
+ last_reissue = 0 ,
246
+ start_time = 0 ,
247
+ end_time = 0
248
+ )
249
+ session .add (tapir_session_1 )
250
+ session .commit ()
251
+ assert foouser .tapir_tokens .session_id == tapir_session_1 .session_id
252
+
253
+ with Session (db_engine ) as session :
134
254
token = models .TapirPermanentToken (
135
255
user_id = foouser .user_id ,
136
256
secret = foouser .tapir_tokens .secret ,
@@ -140,20 +260,14 @@ def db_with_user(classic_db_engine, foouser):
140
260
remote_host = foouser .tapir_tokens .remote_host ,
141
261
session_id = foouser .tapir_tokens .session_id ,
142
262
)
143
- session .add (user )
144
263
session .add (token )
145
- session .add (password )
146
- session .add (nick )
147
264
session .commit ()
148
- session .close ()
149
265
150
- foouser .tapir_nicknames .nickname
151
- yield classic_db_engine
152
266
153
267
@pytest .fixture
154
268
def db_configed (db_with_user ):
155
- configure_db_engine (db_with_user ,None )
156
-
269
+ db_engine , _ = configure_db_engine (db_with_user ,None )
270
+ yield None
157
271
158
272
@pytest .fixture
159
273
def app (db_with_user ):
@@ -169,3 +283,8 @@ def app(db_with_user):
169
283
@pytest .fixture
170
284
def request_context (app ):
171
285
yield app .test_request_context ()
286
+
287
+
288
+ def pytest_addoption (parser ):
289
+ parser .addoption ("--db" , action = "store" , default = "sqlite" ,
290
+ help = "Database type to test against (sqlite/mysql)" )
0 commit comments