1
1
from django_dbml .utils import to_snake_case
2
2
from django .apps import apps
3
- from django .core .management .base import BaseCommand
3
+ from django .core .management .base import BaseCommand , CommandError
4
4
from django .db import models
5
5
6
6
7
7
class Command (BaseCommand ):
8
- help = "The main DBML management file"
8
+ help = "Generate a DBML file based on Django models"
9
+
10
+ def add_arguments (self , parser ):
11
+ parser .add_argument (
12
+ 'args' , metavar = 'app_label[.ModelName]' , nargs = '*' ,
13
+ help = 'Restricts dbml generation to the specified app_label or app_label.ModelName.' ,
14
+ )
9
15
10
16
def get_field_notes (self , field ):
11
17
if len (field .keys ()) == 1 :
@@ -29,7 +35,35 @@ def get_field_notes(self, field):
29
35
return ""
30
36
return "[{}]" .format (", " .join (attributes ))
31
37
32
- def handle (self , * args , ** kwargs ):
38
+ def get_app_tables (self , app_labels ):
39
+ # get the list of models to generate DBML for
40
+
41
+ # if no apps are specified, process all models
42
+ if not app_labels :
43
+ return apps .get_models ()
44
+
45
+ # get specific models when app or app.model is specified
46
+ app_tables = []
47
+ for app in app_labels :
48
+ app_label_parts = app .split ('.' )
49
+ # first part is always the app label
50
+ app_label = app_label_parts [0 ]
51
+ # use the second part as model label if set
52
+ model_label = app_label_parts [1 ] if len (app_label_parts ) > 1 else None
53
+ try :
54
+ app_config = apps .get_app_config (app_label )
55
+ except LookupError as e :
56
+ raise CommandError (str (e ))
57
+
58
+ app_config = apps .get_app_config (app_label )
59
+ if model_label :
60
+ app_tables .append (app_config .get_model (model_label ))
61
+ else :
62
+ app_tables .extend (app_config .get_models ())
63
+
64
+ return app_tables
65
+
66
+ def handle (self , * app_labels , ** kwargs ):
33
67
all_fields = {}
34
68
allowed_types = ["ForeignKey" , "ManyToManyField" ]
35
69
for field_type in models .__all__ :
@@ -44,7 +78,8 @@ def handle(self, *args, **kwargs):
44
78
)
45
79
46
80
tables = {}
47
- app_tables = apps .get_models ()
81
+ app_tables = self .get_app_tables (app_labels )
82
+
48
83
for app_table in app_tables :
49
84
table_name = app_table .__name__
50
85
tables [table_name ] = {"fields" : {}, "relations" : []}
0 commit comments