
Commit 3a48fdd

pghistory_backfill: avoid prefetching - dry-run working
1 parent: 713a69f

1 file changed

Lines changed: 122 additions & 13 deletions

File tree

dojo/management/commands/pghistory_backfill.py

@@ -4,6 +4,7 @@
 This command creates initial snapshots for all existing records in tracked models.
 """
 import logging
+import time
 
 from django.apps import apps
 from django.conf import settings
@@ -33,6 +34,16 @@ def add_arguments(self, parser):
             action="store_true",
             help="Show what would be done without actually creating events",
         )
+        parser.add_argument(
+            "--log-queries",
+            action="store_true",
+            help="Enable database query logging (default: enabled)",
+        )
+        parser.add_argument(
+            "--no-log-queries",
+            action="store_true",
+            help="Disable database query logging",
+        )
 
     def get_excluded_fields(self, model_name):
         """Get the list of excluded fields for a specific model from pghistory configuration."""
@@ -45,6 +56,51 @@ def get_excluded_fields(self, model_name):
         }
         return excluded_fields_map.get(model_name, [])
 
+    def enable_db_logging(self):
+        """Enable database query logging for this command."""
+        # Store original DEBUG setting
+        self.original_debug = settings.DEBUG
+
+        # Configure database query logging
+        db_logger = logging.getLogger("django.db.backends")
+        db_logger.setLevel(logging.DEBUG)
+
+        # Add a handler if one doesn't exist
+        if not db_logger.handlers:
+            handler = logging.StreamHandler()
+            formatter = logging.Formatter(
+                "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+            )
+            handler.setFormatter(formatter)
+            db_logger.addHandler(handler)
+
+        # Also enable the SQL logger specifically
+        sql_logger = logging.getLogger("django.db.backends.sql")
+        sql_logger.setLevel(logging.DEBUG)
+
+        # Ensure the root logger propagates to our handlers
+        if not sql_logger.handlers:
+            sql_logger.addHandler(handler)
+
+        # Enable query logging in Django settings
+        settings.DEBUG = True
+
+        self.stdout.write(
+            self.style.SUCCESS("Database query logging enabled"),
+        )
+
+    def disable_db_logging(self):
+        """Disable database query logging."""
+        # Restore original DEBUG setting
+        settings.DEBUG = self.original_debug
+
+        # Disable query logging by setting a higher level
+        logging.getLogger("django.db.backends").setLevel(logging.INFO)
+        logging.getLogger("django.db.backends.sql").setLevel(logging.INFO)
+        self.stdout.write(
+            self.style.SUCCESS("Database query logging disabled"),
+        )
+
     def handle(self, *args, **options):
         if not settings.ENABLE_AUDITLOG or settings.AUDITLOG_TYPE != "django-pghistory":
             self.stdout.write(
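For context, a minimal standalone sketch of the SQL-echo toggle that enable_db_logging()/disable_db_logging() implement above. Django only emits queries on the "django.db.backends" logger while settings.DEBUG is true, which is why the method flips DEBUG; the sketch keeps handler inside the branch that defines it (in the committed version, sql_logger.addHandler(handler) can hit a NameError when db_logger already had handlers):

    # Minimal sketch, assuming it runs inside a configured Django project.
    import logging

    from django.conf import settings


    def toggle_sql_logging(*, enabled: bool) -> None:
        """Turn Django's SQL echo on or off for the current process."""
        settings.DEBUG = enabled  # Django only records queries when DEBUG is True
        db_logger = logging.getLogger("django.db.backends")
        db_logger.setLevel(logging.DEBUG if enabled else logging.INFO)
        if enabled and not db_logger.handlers:
            handler = logging.StreamHandler()
            handler.setFormatter(logging.Formatter(
                "%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
            db_logger.addHandler(handler)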
@@ -55,6 +111,17 @@ def handle(self, *args, **options):
             )
             return
 
+        # Enable database query logging based on options
+        # Default to enabled unless explicitly disabled
+        enable_query_logging = not options.get("no_log_queries")
+
+        if enable_query_logging:
+            self.enable_db_logging()
+        else:
+            self.stdout.write(
+                self.style.WARNING("Database query logging disabled"),
+            )
+
         # Models that are tracked by pghistory
         tracked_models = [
             "Dojo_User", "Endpoint", "Engagement", "Finding", "Finding_Group",
@@ -83,9 +150,11 @@ def handle(self, *args, **options):
         )
 
         total_processed = 0
+        total_start_time = time.time()
         self.stdout.write(f"Starting backfill for {len(tracked_models)} model(s)...")
 
         for model_name in tracked_models:
+            model_start_time = time.time()
             self.stdout.write(f"\nProcessing {model_name}...")
 
             try:
@@ -143,6 +212,7 @@ def handle(self, *args, **options):
                 processed = 0
                 event_records = []
                 failed_records = []
+                batch_start_time = time.time()
 
                 for instance in records_needing_backfill.iterator():
                     try:
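One detail worth flagging in the loop above: records_needing_backfill.iterator() streams rows from the database instead of caching the entire queryset in memory, which matters when backfilling large tables. A small sketch of the same pattern, using DefectDojo's Finding model and an illustrative chunk size:

    # Sketch: .iterator() fetches rows in chunks rather than materializing
    # the whole result set; chunk_size=2000 is illustrative, not from the diff.
    from dojo.models import Finding

    for instance in Finding.objects.all().iterator(chunk_size=2000):
        pass  # build per-row event data without holding every Finding at once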
@@ -156,8 +226,17 @@ def handle(self, *args, **options):
                         for field in instance._meta.fields:
                             field_name = field.name
                             if field_name not in excluded_fields:
-                                field_value = getattr(instance, field_name)
-                                event_data[field_name] = field_value
+                                # Handle foreign key fields differently
+                                if field.many_to_one:  # ForeignKey field
+                                    # For foreign keys, use the _id field to get the raw ID value
+                                    # Store it under the _id field name for the Event model
+                                    field_id_name = f"{field_name}_id"
+                                    field_value = getattr(instance, field_id_name)
+                                    event_data[field_id_name] = field_value
+                                else:
+                                    # For non-foreign key fields, use value_from_object() to avoid queries
+                                    field_value = field.value_from_object(instance)
+                                    event_data[field_name] = field_value
 
                         # Explicitly preserve created timestamp from the original instance
                         # Only if not excluded and exists
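This hunk is the "avoid prefetching" fix from the commit message: getattr(instance, field_name) on a ForeignKey lazily loads the related row, costing one query per FK field per record, whereas the _id attribute is a raw column value already present on the instance. A hedged shell sketch (reporter is assumed here as an example FK on Finding):

    # Run in a Django shell; assumes at least one Finding exists.
    from dojo.models import Finding

    f = Finding.objects.first()
    _ = f.reporter_id   # no extra query: raw FK column already loaded
    _ = f.reporter      # lazy relation access: issues a SELECT for the related row
    # value_from_object() reads via field.attname, so it is also query-free:
    _ = f._meta.get_field("reporter").value_from_object(f)  # returns the raw id

Since Field.value_from_object() returns getattr(obj, field.attname) even for foreign keys, the dedicated many_to_one branch above appears to matter mainly for the key name (the *_id form) under which the value lands in event_data.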
@@ -180,12 +259,16 @@ def handle(self, *args, **options):
 
                     except Exception as e:
                         failed_records.append(instance.id)
-                        logger.error(
-                            f"Failed to prepare event for {model_name} ID {instance.id}: {e}",
+                        logger.exception(
+                            f"Failed to prepare event for {model_name} ID {instance.id}",
                         )
 
                     # Bulk create when we hit batch_size records
                     if len(event_records) >= batch_size:
+                        batch_end_time = time.time()
+                        batch_duration = batch_end_time - batch_start_time
+                        batch_records_per_second = len(event_records) / batch_duration if batch_duration > 0 else 0
+
                         if not dry_run and event_records:
                             try:
                                 attempted = len(event_records)
@@ -199,19 +282,25 @@ def handle(self, *args, **options):
                                     f"actually created {actually_created} ({attempted - actually_created} skipped)",
                                 )
                             except Exception as e:
-                                logger.error(f"Failed to bulk create events for {model_name}: {e}")
+                                logger.exception(f"Failed to bulk create events for {model_name}")
                                 raise
                         elif dry_run:
                             processed += len(event_records)
 
                         event_records = []  # Reset for next batch
+                        batch_start_time = time.time()  # Reset batch timer
 
-                        # Progress update
+                        # Progress update with batch timing
                         progress = (processed / backfill_count) * 100
-                        self.stdout.write(f"  Processed {processed:,}/{backfill_count:,} records needing backfill ({progress:.1f}%)")
+                        self.stdout.write(f"  Processed {processed:,}/{backfill_count:,} records needing backfill ({progress:.1f}%) - "
+                                          f"Last batch: {batch_duration:.2f}s ({batch_records_per_second:.1f} records/sec)")
 
                 # Handle remaining records
                 if event_records:
+                    batch_end_time = time.time()
+                    batch_duration = batch_end_time - batch_start_time
+                    batch_records_per_second = len(event_records) / batch_duration if batch_duration > 0 else 0
+
                     if not dry_run:
                         try:
                             attempted = len(event_records)
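The per-batch bookkeeping introduced here follows a simple pattern: time the batch, divide count by duration, guard against a zero duration, reset the timer. Isolated as a runnable sketch with placeholder work:

    # Standalone sketch of the throughput pattern above (placeholder workload).
    import time

    batch_start_time = time.time()
    event_records = [object() for _ in range(1000)]  # stand-in for prepared events

    batch_duration = time.time() - batch_start_time
    # Guard against division by zero when a batch finishes within timer resolution.
    batch_records_per_second = len(event_records) / batch_duration if batch_duration > 0 else 0
    print(f"Last batch: {batch_duration:.2f}s ({batch_records_per_second:.1f} records/sec)")
    batch_start_time = time.time()  # reset for the next batch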
@@ -225,41 +314,61 @@ def handle(self, *args, **options):
                                 f"actually created {actually_created} ({attempted - actually_created} skipped)",
                             )
                         except Exception as e:
-                            logger.error(f"Failed to bulk create final batch for {model_name}: {e}")
+                            logger.exception(f"Failed to bulk create final batch for {model_name}")
                             raise
                     else:
                         processed += len(event_records)
 
+                    # Log final batch timing
+                    self.stdout.write(f"  Final batch: {batch_duration:.2f}s ({batch_records_per_second:.1f} records/sec)")
+
                 # Final progress update
                 if backfill_count > 0:
                     progress = (processed / backfill_count) * 100
                     self.stdout.write(f"  Processed {processed:,}/{backfill_count:,} records needing backfill ({progress:.1f}%)")
 
                 total_processed += processed
 
-                # Show completion summary
+                # Calculate timing for this model
+                model_end_time = time.time()
+                model_duration = model_end_time - model_start_time
+                records_per_second = processed / model_duration if model_duration > 0 else 0
+
+                # Show completion summary with timing
                 if failed_records:
                     self.stdout.write(
                         self.style.WARNING(
                             f"  ⚠ Completed {model_name}: {processed:,} records processed, "
-                            f"{len(failed_records)} records failed",
+                            f"{len(failed_records)} records failed in {model_duration:.2f}s "
+                            f"({records_per_second:.1f} records/sec)",
                         ),
                     )
                 else:
                     self.stdout.write(
                         self.style.SUCCESS(
-                            f"  ✓ Completed {model_name}: {processed:,} records",
+                            f"  ✓ Completed {model_name}: {processed:,} records in {model_duration:.2f}s "
+                            f"({records_per_second:.1f} records/sec)",
                         ),
                     )
 
             except Exception as e:
                 self.stdout.write(
                     self.style.ERROR(f"  ✗ Failed to process {model_name}: {e}"),
                 )
-                logger.error(f"Error processing {model_name}: {e}")
+                logger.exception(f"Error processing {model_name}")
+
+        # Calculate total timing
+        total_end_time = time.time()
+        total_duration = total_end_time - total_start_time
+        total_records_per_second = total_processed / total_duration if total_duration > 0 else 0
+
+        # Disable database query logging if it was enabled
+        if enable_query_logging:
+            self.disable_db_logging()
 
         self.stdout.write(
             self.style.SUCCESS(
-                f"\nBACKFILL COMPLETE: Processed {total_processed:,} records",
+                f"\nBACKFILL COMPLETE: Processed {total_processed:,} records in {total_duration:.2f}s "
+                f"({total_records_per_second:.1f} records/sec)",
             ),
         )
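The attempted/actually-created accounting in the context lines implies the bulk insert tolerates already-existing events. The bulk_create call itself sits outside this diff, so the following is an assumption: one plausible shape counts rows before and after an ignore_conflicts insert, since with ignore_conflicts=True the returned objects do not reveal which rows were skipped:

    # Hedged sketch; event_model and event_records are hypothetical stand-ins
    # for the pghistory event model and the prepared batch.
    def bulk_create_counted(event_model, event_records):
        """Insert a batch, reporting how many rows were actually created."""
        attempted = len(event_records)
        before = event_model.objects.count()
        event_model.objects.bulk_create(event_records, ignore_conflicts=True)
        actually_created = event_model.objects.count() - before
        return attempted, actually_created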
