Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Partial model support generic #1630

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions sea-orm-macros/src/derives/partial_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::token::Comma;
use syn::Expr;
use syn::Generics;

use syn::Meta;

Expand All @@ -17,7 +18,6 @@ use self::util::GetAsKVMeta;
enum Error {
InputNotStruct,
EntityNotSpecific,
NotSupportGeneric(Span),
BothFromColAndFromExpr(Span),
Syn(syn::Error),
}
Expand All @@ -35,14 +35,11 @@ struct DerivePartialModel {
entity_ident: Option<syn::Ident>,
ident: syn::Ident,
fields: Vec<ColumnAs>,
generic: Generics,
}

impl DerivePartialModel {
fn new(input: syn::DeriveInput) -> Result<Self, Error> {
if !input.generics.params.is_empty() {
return Err(Error::NotSupportGeneric(input.generics.params.span()));
}

let syn::Data::Struct(syn::DataStruct{fields:syn::Fields::Named(syn::FieldsNamed{named:fields,..}),..},..) = input.data else{
return Err(Error::InputNotStruct);
};
Expand Down Expand Up @@ -128,6 +125,7 @@ impl DerivePartialModel {
entity_ident,
ident: input.ident,
fields: column_as_list,
generic: input.generics,
})
}

Expand All @@ -141,6 +139,7 @@ impl DerivePartialModel {
entity_ident,
ident,
fields,
generic,
} = self;
let select_col_code_gen = fields.iter().map(|col_as| match col_as {
ColumnAs::Col(ident) => {
Expand All @@ -158,9 +157,11 @@ impl DerivePartialModel {
},
});

let (impl_generic, type_generic, where_clause) = generic.split_for_impl();

quote! {
#[automatically_derived]
impl sea_orm::PartialModelTrait for #ident{
impl #impl_generic sea_orm::PartialModelTrait for #ident #type_generic #where_clause {
fn select_cols<S: sea_orm::SelectColumns>(#select_ident: S) -> S{
#(#select_col_code_gen)*
#select_ident
Expand All @@ -175,14 +176,11 @@ pub fn expand_derive_partial_model(input: syn::DeriveInput) -> syn::Result<Token

match DerivePartialModel::new(input) {
Ok(partial_model) => partial_model.expand(),
Err(Error::NotSupportGeneric(span)) => Ok(quote_spanned! {
span => compile_error!("you can only derive `DerivePartialModel` on named struct");
}),
Err(Error::BothFromColAndFromExpr(span)) => Ok(quote_spanned! {
span => compile_error!("you can only use one of `from_col` or `from_expr`");
}),
Err(Error::EntityNotSpecific) => Ok(quote_spanned! {
ident_span => compile_error!("you need specific which entity you are using")
ident_span => compile_error!("you need specific which entity you are using");
}),
Err(Error::InputNotStruct) => Ok(quote_spanned! {
ident_span => compile_error!("you can only derive `DerivePartialModel` on named struct");
Expand Down
15 changes: 15 additions & 0 deletions sea-orm-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,21 @@ pub fn derive_from_json_query_result(input: TokenStream) -> TokenStream {
/// sum: i32
/// }
/// ```
///
/// `DerivePartialModel` support generic argument, it not require extra generic bound.
///
/// Note that, the bound [`TryGetable`](sea_orm::TryGetable) is for [`FromQueryResult`]
/// ```rust
/// use sea_orm::{
/// entity::prelude::*, sea_query::Expr, DerivePartialModel, FromQueryResult, TryGetable,
/// };
///
/// #[derive(Debug, FromQueryResult, DerivePartialModel)]
/// struct SelectResult<T: TryGetable> {
/// #[sea_orm(from_expr = "Expr::val(1).add(1)")]
/// sum: T,
/// }
/// ```
#[cfg(feature = "derive")]
#[proc_macro_derive(DerivePartialModel, attributes(sea_orm))]
pub fn derive_partial_model(input: TokenStream) -> TokenStream {
Expand Down
48 changes: 47 additions & 1 deletion tests/partial_model_tests.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::ops::Deref;

use entity::{Column, Entity};
use sea_orm::{ColumnTrait, DerivePartialModel, FromQueryResult};
use sea_orm::{ColumnTrait, DerivePartialModel, FromQueryResult, TryGetable};
use sea_query::Expr;

mod entity {
Expand Down Expand Up @@ -45,3 +47,47 @@ struct FieldFromExpr {
#[sea_orm(from_expr = "Expr::col(Column::Id).equals(Column::Foo)")]
_bar: bool,
}

#[derive(FromQueryResult, DerivePartialModel)]
#[sea_orm(entity = "Entity")]
struct GenericTest<T>
where
T: TryGetable,
{
_foo: i32,
_bar: T,
}
#[derive(FromQueryResult, DerivePartialModel)]
#[sea_orm(entity = "Entity")]
struct MultiGenericTest<T: TryGetable, F: TryGetable> {
#[sea_orm(from_expr = "Column::Bar2.sum()")]
_foo: T,
_bar: F,
}

#[derive(FromQueryResult, DerivePartialModel)]
#[sea_orm(entity = "Entity")]
struct GenericWithBoundsTest<T: TryGetable + Copy + Clone + 'static> {
_foo: T,
}

#[derive(FromQueryResult, DerivePartialModel)]
#[sea_orm(entity = "Entity")]
struct WhereGenericTest<T>
where
T: TryGetable + Deref,
<T as Deref>::Target: Clone,
{
_foo: T,
}

#[derive(FromQueryResult, DerivePartialModel)]
struct MixedBoundTest<T: TryGetable + Clone, F>
where
F: TryGetable + Clone,
{
#[sea_orm(from_expr = "Column::Bar2.sum()")]
_foo: T,
#[sea_orm(from_expr = "Expr::col(Column::Id).equals(Column::Foo)")]
_bar: F,
}