Skip to content

Commit 1940f81

Browse files
authored
[Backend] Add custom sqs retention time (#4185)
* Add custom sqs retention time * Fix flake8 issues * Add missing migration * Add fixes * Fix errors * Update common.py * Fix migrations
1 parent df03312 commit 1940f81

File tree

7 files changed

+86
-4
lines changed

7 files changed

+86
-4
lines changed

apps/base/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,9 +201,10 @@ def get_or_create_sqs_queue(queue_name, challenge=None):
201201
!= "AWS.SimpleQueueService.NonExistentQueue"
202202
):
203203
logger.exception("Cannot get queue: {}".format(queue_name))
204+
sqs_retention_period = SQS_RETENTION_PERIOD if challenge is None else str(challenge.sqs_retention_period)
204205
queue = sqs.create_queue(
205206
QueueName=queue_name,
206-
Attributes={"MessageRetentionPeriod": SQS_RETENTION_PERIOD},
207+
Attributes={"MessageRetentionPeriod": sqs_retention_period},
207208
)
208209
return queue
209210

apps/challenges/aws_utils.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -771,6 +771,34 @@ def create_ec2_instance(challenge, ec2_storage=None, worker_instance_type=None,
771771
}
772772

773773

774+
def update_sqs_retention_period(challenge):
775+
"""
776+
Update the SQS retention period for a challenge.
777+
778+
Args:
779+
challenge (Challenge): The challenge for which the SQS retention period is to be updated.
780+
781+
Returns:
782+
dict: A dictionary containing the status and message of the operation.
783+
"""
784+
sqs_retention_period = str(challenge.sqs_retention_period)
785+
try:
786+
sqs = get_boto3_client("sqs", aws_keys)
787+
queue_url = sqs.get_queue_url(QueueName=challenge.queue)['QueueUrl']
788+
response = sqs.set_queue_attributes(
789+
QueueUrl=queue_url,
790+
Attributes={
791+
'MessageRetentionPeriod': sqs_retention_period
792+
}
793+
)
794+
return {"message": response.json()}
795+
except Exception as e:
796+
logger.exception(e)
797+
return {
798+
"error": e.response,
799+
}
800+
801+
774802
def start_workers(queryset):
775803
"""
776804
The function called by the admin action method to start all the selected workers.
@@ -1794,3 +1822,16 @@ def setup_ec2(challenge):
17941822
if challenge_obj.ec2_instance_id:
17951823
return start_ec2_instance(challenge_obj)
17961824
return create_ec2_instance(challenge_obj)
1825+
1826+
1827+
@app.task
1828+
def update_sqs_retention_period_task(challenge):
1829+
"""
1830+
Updates sqs retention period for a challenge when the attribute is changed.
1831+
1832+
Arguments:
1833+
challenge {<class 'django.db.models.query.QuerySet'>} -- instance of the model calling the post hook
1834+
"""
1835+
for obj in serializers.deserialize("json", challenge):
1836+
challenge_obj = obj.object
1837+
return update_sqs_retention_period(challenge_obj)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Generated by Django 2.2.20 on 2023-12-10 15:16
2+
3+
from django.db import migrations, models
4+
5+
6+
class Migration(migrations.Migration):
7+
8+
dependencies = [
9+
('challenges', '0111_alter_challenge_ephemeral_storage_default'),
10+
]
11+
12+
operations = [
13+
migrations.AddField(
14+
model_name='challenge',
15+
name='sqs_retention_period',
16+
field=models.PositiveIntegerField(default=259200, verbose_name='SQS Retention Period'),
17+
),
18+
]

apps/challenges/models.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def __init__(self, *args, **kwargs):
3535
super(Challenge, self).__init__(*args, **kwargs)
3636
self._original_evaluation_script = self.evaluation_script
3737
self._original_approved_by_admin = self.approved_by_admin
38+
self._original_sqs_retention_period = self.sqs_retention_period
3839

3940
title = models.CharField(max_length=100, db_index=True)
4041
short_description = models.TextField(null=True, blank=True)
@@ -126,6 +127,10 @@ def __init__(self, *args, **kwargs):
126127
verbose_name="SQS queue name",
127128
db_index=True,
128129
)
130+
sqs_retention_period = models.PositiveIntegerField(
131+
default=259200,
132+
verbose_name="SQS Retention Period"
133+
)
129134
is_docker_based = models.BooleanField(
130135
default=False, verbose_name="Is Docker Based", db_index=True
131136
)
@@ -277,6 +282,21 @@ def create_eks_cluster_or_ec2_for_challenge(sender, instance, created, **kwargs)
277282
aws.challenge_approval_callback(sender, instance, field_name, **kwargs)
278283

279284

285+
@receiver(signals.post_save, sender="challenges.Challenge")
286+
def update_sqs_retention_period_for_challenge(sender, instance, created, **kwargs):
287+
field_name = "sqs_retention_period"
288+
import challenges.aws_utils as aws
289+
290+
if not created and is_model_field_changed(instance, field_name):
291+
serialized_obj = serializers.serialize("json", [instance])
292+
aws.update_sqs_retention_period_task.delay(serialized_obj)
293+
# Update challenge
294+
curr = getattr(instance, "{}".format(field_name))
295+
challenge = instance
296+
challenge._original_sqs_retention_period = curr
297+
challenge.save()
298+
299+
280300
class DatasetSplit(TimeStampedModel):
281301
name = models.CharField(max_length=100)
282302
codename = models.CharField(max_length=100)

apps/jobs/sender.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,10 @@ def get_or_create_sqs_queue(queue_name, challenge=None):
6262
ex.response["Error"]["Code"]
6363
== "AWS.SimpleQueueService.NonExistentQueue"
6464
):
65+
sqs_retention_period = SQS_RETENTION_PERIOD if challenge is None else str(challenge.sqs_retention_period)
6566
queue = sqs.create_queue(
6667
QueueName=queue_name,
67-
Attributes={"MessageRetentionPeriod": SQS_RETENTION_PERIOD},
68+
Attributes={"MessageRetentionPeriod": sqs_retention_period},
6869
)
6970
else:
7071
logger.exception("Cannot get or create Queue")

scripts/workers/submission_worker.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -806,9 +806,10 @@ def get_or_create_sqs_queue(queue_name, challenge=None):
806806
!= "AWS.SimpleQueueService.NonExistentQueue"
807807
):
808808
logger.exception("Cannot get queue: {}".format(queue_name))
809+
sqs_retention_period = SQS_RETENTION_PERIOD if challenge is None else str(challenge.sqs_retention_period)
809810
queue = sqs.create_queue(
810811
QueueName=queue_name,
811-
Attributes={"MessageRetentionPeriod": SQS_RETENTION_PERIOD},
812+
Attributes={"MessageRetentionPeriod": sqs_retention_period},
812813
)
813814
return queue
814815

settings/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,4 +406,4 @@
406406
}
407407

408408
# SQS Queue Message Retention Period
409-
SQS_RETENTION_PERIOD = "1209600"
409+
SQS_RETENTION_PERIOD = "345600"

0 commit comments

Comments
 (0)