Skip to content

Commit b454071

Browse files
committed
Add postgres database via sqlx
1 parent 6c3124e commit b454071

12 files changed

+218
-24
lines changed

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ anyhow = "1.0.33"
1717
config = "0.10.1"
1818
serde = "1.0.117"
1919
serde_derive = "1.0.117"
20+
sqlx = { version = "0.4.0-beta.1", default-features = false, features = [ "runtime-tokio", "macros", "postgres", "uuid" ] }
2021

2122
[dependencies.serenity]
2223
version = "0.9.0-rc.2"

config/development.toml

+2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
database_url="postgres://postgres:[email protected]/sheepbot"
2+
13
[discord]
24
prefix = "?"
35

config/example.toml

+2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
database_url="postgres://[email protected]/example"
2+
13
[discord]
24
prefix = "?"
35
token = "DISCORD_TOKEN"

docker-compose.prod.yml

+3
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,6 @@ services:
99
image: fredboat/lavalink:master
1010
volumes:
1111
- ./lavalink/application.yml:/opt/Lavalink/application.yml
12+
db:
13+
image: postgres:12
14+
restart: always

docker-compose.yml

+9-1
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,12 @@ services:
1717
volumes:
1818
- ./lavalink/application.yml:/opt/Lavalink/application.yml
1919
ports:
20-
- 2333:2333
20+
- 2333:2333
21+
db:
22+
image: postgres:12
23+
restart: always
24+
environment:
25+
POSTGRES_PASSWORD: passworddevelop
26+
POSTGRES_DB: sheepbot
27+
ports:
28+
- 5432:5432
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
--
2+
-- Editor SQL for DB table only_link_channel
3+
-- Created by http://editor.datatables.net/generator
4+
--
5+
6+
CREATE TABLE IF NOT EXISTS "only_link_channel" (
7+
"id" serial,
8+
"guild_id" numeric(9,2),
9+
"user_id" numeric(9,2),
10+
"url" text,
11+
"channel_id" numeric(9,2),
12+
PRIMARY KEY( id )
13+
);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
-- Add migration script here
2+
3+
CREATE TABLE config (
4+
guild_id bigint NOT NULL PRIMARY KEY,
5+
host_role_id bigint
6+
);
7+
8+
CREATE TABLE prefixes (
9+
guild_id bigint NOT NULL PRIMARY KEY,
10+
prefix text
11+
);

src/handler.rs

+42-8
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use tracing::{error, info};
2+
13
use std::{
24
sync::Arc,
35
collections::HashSet,
@@ -28,6 +30,7 @@ use serenity::{
2830
gateway::Ready,
2931
id::GuildId,
3032
event::VoiceServerUpdateEvent,
33+
prelude::{Guild},
3134
},
3235
};
3336

@@ -37,6 +40,10 @@ use lavalink_rs::{
3740
model::*,
3841
gateway::*,
3942
};
43+
use sqlx::PgPool;
44+
use crate::settings::Settings;
45+
46+
use crate::utils::database::{initialize_tables};
4047

4148

4249
pub(crate) struct VoiceManager;
@@ -55,14 +62,25 @@ impl TypeMapKey for VoiceGuildUpdate {
5562
type Value = Arc<RwLock<HashSet<GuildId>>>;
5663
}
5764

65+
pub(crate) struct ConnectionPool;
66+
67+
impl TypeMapKey for ConnectionPool {
68+
type Value = PgPool;
69+
}
70+
71+
pub(crate) struct SettingsConf;
72+
73+
impl TypeMapKey for SettingsConf {
74+
type Value = Settings;
75+
}
5876

5977
pub(crate) struct Handler;
6078

6179
#[async_trait]
6280
impl EventHandler for Handler {
6381
async fn ready(&self, _: Context, ready: Ready) {
6482
if let Some(shard) = ready.shard {
65-
println!(
83+
info!(
6684
"{} is connected on shard {}/{}!",
6785
ready.user.name,
6886
shard[0],
@@ -80,25 +98,41 @@ impl EventHandler for Handler {
8098
}
8199
}
82100

101+
async fn guild_create(&self, ctx: Context, guild: Guild, is_new: bool) {
102+
// We'll initialize the database tables for a guild if it's new.
103+
if !is_new {
104+
return;
105+
}
106+
107+
initialize_tables(&ctx, &guild).await;
108+
}
109+
83110
}
84111

85112
pub(crate) struct LavalinkHandler;
86113

87114
#[async_trait]
88115
impl LavalinkEventHandler for LavalinkHandler {
89116
async fn track_start(&self, _client: Arc<Mutex<LavalinkClient>>, event: TrackStart) {
90-
println!("Track started!\nGuild: {}", event.guild_id);
117+
info!("Track started!\nGuild: {}", event.guild_id);
91118
}
92119
async fn track_finish(&self, _client: Arc<Mutex<LavalinkClient>>, event: TrackFinish) {
93-
println!("Track finished!\nGuild: {}", event.guild_id);
120+
info!("Track finished!\nGuild: {}", event.guild_id);
94121
}
95122
}
96123

97124
#[hook]
98-
async fn after(_ctx: &Context, _msg: &Message, command_name: &str, command_result: CommandResult) {
99-
match command_result {
100-
Err(why) => println!("Command '{}' returned error {:?} => {}", command_name, why, why),
101-
_ => (),
125+
pub(crate) async fn after(ctx: &Context, msg: &Message, cmd_name: &str, error: CommandResult) {
126+
if let Err(why) = &error {
127+
error!("Error while running command {}", &cmd_name);
128+
error!("{:?}", &error);
129+
130+
let err = why.to_string();
131+
if msg.channel_id.say(ctx, &err).await.is_err() {
132+
error!(
133+
"Unable to send messages on channel id {}",
134+
&msg.channel_id.0
135+
);
136+
};
102137
}
103138
}
104-

src/main.rs

+73-15
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ mod utils;
1010
mod handler;
1111
mod settings;
1212

13+
use tracing::{error, info};
14+
1315
use std::{
1416
sync::Arc,
1517
collections::HashSet,
@@ -24,13 +26,12 @@ use serenity::{
2426
framework::standard::{
2527
Args, StandardFramework, CommandGroup,
2628
HelpOptions, help_commands, CommandResult,
27-
macros::{help},
29+
macros::{help, hook},
2830
},
2931
model::prelude::*,
3032
prelude::*,
3133
};
3234

33-
use serenity::prelude::*;
3435
use lavalink_rs::{
3536
LavalinkClient
3637
};
@@ -43,11 +44,16 @@ use crate::handler::{
4344
Lavalink,
4445
VoiceManager,
4546
VoiceGuildUpdate,
46-
LavalinkHandler
47+
LavalinkHandler,
48+
ConnectionPool,
49+
after,
50+
SettingsConf
4751
};
4852

4953
use settings::Settings;
5054

55+
use crate::utils::database::{obtain_pool};
56+
5157

5258
#[help]
5359
async fn my_help(
@@ -62,37 +68,85 @@ async fn my_help(
6268
Ok(())
6369
}
6470

71+
#[hook]
72+
// Sets a custom prefix for a guild.
73+
pub async fn dynamic_prefix(ctx: &Context, msg: &Message) -> Option<String> {
74+
let guild_id = &msg.guild_id;
75+
76+
let data = ctx.data.read().await;
77+
let settings = data.get::<SettingsConf>().unwrap();
78+
let default_prefix = settings.discord.prefix.as_str();
79+
80+
let prefix: String;
81+
if let Some(id) = guild_id {
82+
let pool = data.get::<ConnectionPool>().unwrap();
83+
84+
let res = sqlx::query!(
85+
"SELECT prefix FROM prefixes WHERE guild_id = $1",
86+
id.0 as i64
87+
)
88+
.fetch_one(pool)
89+
.await;
90+
91+
prefix = if let Ok(data) = res {
92+
if let Some(p) = data.prefix {
93+
p
94+
} else {
95+
default_prefix.to_string()
96+
}
97+
} else {
98+
error!("I couldn't query the database for getting guild prefix.");
99+
default_prefix.to_string()
100+
}
101+
} else {
102+
prefix = default_prefix.to_string();
103+
};
104+
105+
Some(prefix)
106+
}
107+
65108
#[tokio::main]
66109
async fn main() -> Result<(), Box<dyn std::error::Error>> {
67110
let settings = match Settings::new() {
68111
Ok(conf) => conf,
69112
Err(why) => panic!("Could not read config: {:?}", why),
70113
};
71114

72-
let token = settings.discord.token;
73-
let lavalink_url = settings.lavalink.url;
74-
let lavalink_password = settings.lavalink.password;
115+
let token = &settings.discord.token;
116+
let lavalink_url = &settings.lavalink.url;
117+
let lavalink_password = &settings.lavalink.password;
75118

76-
let prefix = settings.discord.prefix.as_str();
119+
let db_url = &settings.database_url;
77120

78-
let http = Http::new_with_token(&token);
121+
let http = Http::new_with_token(token);
79122

80-
let bot_id = match http.get_current_application_info().await {
81-
Ok(info) => info.id,
123+
// We will fetch your bot's owners and id
124+
let (owners, bot_id) = match http.get_current_application_info().await {
125+
Ok(info) => {
126+
let mut owners = HashSet::new();
127+
owners.insert(info.owner.id);
128+
129+
(owners, info.id)
130+
}
82131
Err(why) => panic!("Could not access application info: {:?}", why),
83132
};
84133

85134

86135
let framework = StandardFramework::new()
87-
.configure(|c| c
88-
.prefix(prefix))
136+
.configure(|c| {
137+
c.owners(owners)
138+
.dynamic_prefix(dynamic_prefix)
139+
.with_whitespace(false)
140+
.on_mention(Some(bot_id))
141+
})
89142
.help(&MY_HELP)
143+
.after(after)
90144
.group(&MUSIC_GROUP)
91145
.group(&FUN_GROUP);
92146

93-
let mut client = Client::new(&token)
94-
.event_handler(Handler)
147+
let mut client = Client::new(token)
95148
.framework(framework)
149+
.event_handler(Handler)
96150
.await
97151
.expect("Err creating client");
98152

@@ -109,8 +163,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
109163

110164
let lava = lava_client.initialize(LavalinkHandler).await?;
111165
data.insert::<Lavalink>(lava);
166+
167+
let pool = obtain_pool(db_url).await?;
168+
data.insert::<ConnectionPool>(pool);
169+
data.insert::<SettingsConf>(settings);
112170
}
113-
114171

115172
// Here we clone a lock to the Shard Manager, and then move it into a new
116173
// thread. The thread will unlock the manager and print shards' status on a
@@ -134,6 +191,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
134191
}
135192
}
136193
});
194+
137195

138196
// Start two shards. Note that there is an ~5 second ratelimit period
139197
// between when one shard can start after another.

src/settings.rs

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ pub struct Lavalink {
1616

1717
#[derive(Debug, Deserialize)]
1818
pub struct Settings {
19+
pub database_url: String,
1920
pub discord: Discord,
2021
pub lavalink: Lavalink,
2122
}

0 commit comments

Comments
 (0)