Skip to content
This repository was archived by the owner on Oct 2, 2024. It is now read-only.

Commit 607cd27

Browse files
frascuchon, pre-commit-ci[bot], davidberenstein1957
authored
[FEATURE] Allow update dataset settings for fields, vectors and metadata (#232)
* refactor: Define VectorFieldModel as a ResourceModel * feat: Align VectorsAPI methods with endpoints * refactor: Define VectorField as Resource * refactor: Using VectorField methods * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * chore: apply PR suggestions * refactor: Redefine field model to customize only settings * feat: Implement Fields API methods using new model def * chore: Review some naming * feat: Redefine TextField as a Resource * chore: define private VectorField properties * chore: Redefine fields refs based on TextField * chore: Remove unused defs * tests: Update tests * chore: Using proper import * refactor: Review Metadata API model and methods * refactor: Align metadata fields with Resource class * refactor: implement upsert metadata using metadata resource methods * tests: Adapt tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat: Allow update fields, vectors and metadata settings * feat: update dataset with its settings * tests: Add integration tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix: dataset resource tests * chore: using model_dump instead of dict * refactor: Using SettingsProperties class for field, vectors, and metadata management * chore: Using id instead of external_id * chore: Change test condition * chore: Update settings tests with new settings properties container * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [BUGFIX] Export and import settings including vectors and metadata (#235) * fix: metadata and vector infos are included for serialization * tests: Add more tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [pre-commit.ci] 
auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Apply suggestions from code review Co-authored-by: David Berenstein <[email protected]> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: David Berenstein <[email protected]>
1 parent d3c82e1 commit 607cd27

23 files changed

+575
-362
lines changed

src/argilla_sdk/_api/_fields.py

Lines changed: 24 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,15 @@
1616
from uuid import UUID
1717

1818
import httpx
19+
1920
from argilla_sdk._api._base import ResourceAPI
2021
from argilla_sdk._exceptions import api_error_handler
21-
from argilla_sdk._models import FieldBaseModel, TextFieldModel, FieldModel
22+
from argilla_sdk._models import FieldModel
2223

2324
__all__ = ["FieldsAPI"]
2425

2526

26-
class FieldsAPI(ResourceAPI[FieldBaseModel]):
27+
class FieldsAPI(ResourceAPI[FieldModel]):
2728
"""Manage datasets via the API"""
2829

2930
http_client: httpx.Client
@@ -33,36 +34,39 @@ class FieldsAPI(ResourceAPI[FieldBaseModel]):
3334
################
3435

3536
@api_error_handler
36-
def create(self, dataset_id: UUID, field: FieldModel) -> FieldModel:
37-
url = f"/api/v1/datasets/{dataset_id}/fields"
37+
def get(self, id: UUID) -> FieldModel:
38+
raise NotImplementedError()
39+
40+
@api_error_handler
41+
def create(self, field: FieldModel) -> FieldModel:
42+
url = f"/api/v1/datasets/{field.dataset_id}/fields"
3843
response = self.http_client.post(url=url, json=field.model_dump())
3944
response.raise_for_status()
4045
response_json = response.json()
41-
field_model = self._model_from_json(response_json=response_json)
42-
self._log_message(message=f"Created field {field_model.name} in dataset {dataset_id}")
43-
return field_model
46+
created_field = self._model_from_json(response_json=response_json)
47+
self._log_message(message=f"Created field {created_field.name} in dataset {field.dataset_id}")
48+
return created_field
4449

4550
@api_error_handler
4651
def update(self, field: FieldModel) -> FieldModel:
47-
# TODO: Implement update method for fields with server side ID
48-
raise NotImplementedError
52+
url = f"/api/v1/fields/{field.id}"
53+
response = self.http_client.patch(url, json=field.model_dump())
54+
response.raise_for_status()
55+
response_json = response.json()
56+
updated_field = self._model_from_json(response_json)
57+
self._log_message(message=f"Update field {updated_field.name} with id {field.id}")
58+
return updated_field
4959

5060
@api_error_handler
51-
def delete(self, dataset_id: UUID) -> None:
52-
# TODO: Implement delete method for fields with server side ID
53-
raise NotImplementedError
61+
def delete(self, field_id: UUID) -> None:
62+
url = f"/api/v1/fields/{field_id}"
63+
self.http_client.delete(url).raise_for_status()
64+
self._log_message(message=f"Deleted field {field_id}")
5465

5566
####################
5667
# Utility methods #
5768
####################
5869

59-
def create_many(self, dataset_id: UUID, fields: List[FieldModel]) -> List[FieldModel]:
60-
field_models = []
61-
for field in fields:
62-
field_model = self.create(dataset_id=dataset_id, field=field)
63-
field_models.append(field_model)
64-
return field_models
65-
6670
@api_error_handler
6771
def list(self, dataset_id: UUID) -> List[FieldModel]:
6872
response = self.http_client.get(f"/api/v1/datasets/{dataset_id}/fields")
@@ -78,19 +82,7 @@ def list(self, dataset_id: UUID) -> List[FieldModel]:
7882
def _model_from_json(self, response_json: Dict) -> FieldModel:
7983
response_json["inserted_at"] = self._date_from_iso_format(date=response_json["inserted_at"])
8084
response_json["updated_at"] = self._date_from_iso_format(date=response_json["updated_at"])
81-
return self._get_model_from_response(response_json=response_json)
85+
return FieldModel(**response_json)
8286

8387
def _model_from_jsons(self, response_jsons: List[Dict]) -> List[FieldModel]:
8488
return list(map(self._model_from_json, response_jsons))
85-
86-
def _get_model_from_response(self, response_json: Dict) -> FieldModel:
87-
try:
88-
field_type = response_json.get("settings", {}).get("type")
89-
except Exception as e:
90-
raise ValueError("Invalid response type: missing 'settings.type' in response") from e
91-
if field_type == "text":
92-
# TODO: Avoid apply validations here (check_fields=False?)
93-
return TextFieldModel(**response_json)
94-
else:
95-
# TODO: Add more field types
96-
raise ValueError(f"Invalid field type: {field_type}")

src/argilla_sdk/_api/_metadata.py

Lines changed: 21 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from uuid import UUID
1717

1818
import httpx
19+
1920
from argilla_sdk._api._base import ResourceAPI
2021
from argilla_sdk._exceptions import api_error_handler
2122
from argilla_sdk._models import MetadataFieldModel
@@ -33,53 +34,44 @@ class MetadataAPI(ResourceAPI[MetadataFieldModel]):
3334
################
3435

3536
@api_error_handler
36-
def create(self, dataset_id: UUID, metadata_field: MetadataFieldModel) -> MetadataFieldModel:
37-
url = f"/api/v1/datasets/{dataset_id}/metadata-properties"
38-
response = self.http_client.post(url=url, json=metadata_field.model_dump())
37+
def get(self, metadata_id: UUID) -> MetadataFieldModel:
38+
raise NotImplementedError()
39+
40+
@api_error_handler
41+
def create(self, metadata: MetadataFieldModel) -> MetadataFieldModel:
42+
url = f"/api/v1/datasets/{metadata.dataset_id}/metadata-properties"
43+
response = self.http_client.post(url=url, json=metadata.model_dump())
3944
response.raise_for_status()
4045
response_json = response.json()
41-
metadata_field_model = self._model_from_json(response_json=response_json)
42-
self._log_message(message=f"Created metadata field {metadata_field_model.name} in dataset {dataset_id}")
43-
return metadata_field_model
46+
created_metadata = self._model_from_json(response_json=response_json)
47+
self._log_message(message=f"Created metadata field {created_metadata.name} in dataset {metadata.dataset_id}")
48+
return created_metadata
4449

4550
@api_error_handler
46-
def update(self, metadata_field: MetadataFieldModel) -> MetadataFieldModel:
47-
url = f"/api/v1/metadata-properties/{metadata_field.id}"
48-
response = self.http_client.patch(url=url, json=metadata_field.model_dump())
51+
def update(self, metadata: MetadataFieldModel) -> MetadataFieldModel:
52+
url = f"/api/v1/metadata-properties/{metadata.id}"
53+
response = self.http_client.patch(url=url, json=metadata.model_dump())
4954
response.raise_for_status()
5055
response_json = response.json()
51-
metadata_field_model = self._model_from_json(response_json=response_json)
52-
self._log_message(message=f"Updated field {metadata_field_model.name}")
53-
return metadata_field_model
56+
updated_metadata = self._model_from_json(response_json=response_json)
57+
self._log_message(message=f"Updated metadata field {updated_metadata.name}")
58+
return updated_metadata
5459

55-
@api_error_handler
56-
def delete(self, id: UUID) -> None:
57-
url = f"/api/v1/metadata-properties/{id}"
60+
def delete(self, metadata_id: UUID) -> None:
61+
url = f"/api/v1/metadata-properties/{metadata_id}"
5862
self.http_client.delete(url=url).raise_for_status()
59-
self._log_message(message=f"Deleted field {id}")
60-
61-
@api_error_handler
62-
def get(self, id: UUID) -> MetadataFieldModel:
63-
raise NotImplementedError()
63+
self._log_message(message=f"Deleted metadata field {metadata_id}")
6464

6565
####################
6666
# Utility methods #
6767
####################
6868

69-
def create_many(self, dataset_id: UUID, metadata_fields: List[MetadataFieldModel]) -> List[MetadataFieldModel]:
70-
metadata_field_models = []
71-
for metadata_field in metadata_fields:
72-
metadata_field_model = self.create(dataset_id=dataset_id, metadata_field=metadata_field)
73-
metadata_field_models.append(metadata_field_model)
74-
return metadata_field_models
75-
7669
@api_error_handler
7770
def list(self, dataset_id: UUID) -> List[MetadataFieldModel]:
7871
response = self.http_client.get(f"/api/v1/me/datasets/{dataset_id}/metadata-properties")
7972
response.raise_for_status()
8073
response_json = response.json()
81-
metadata_field_model = self._model_from_jsons(response_jsons=response_json["items"])
82-
return metadata_field_model
74+
return self._model_from_jsons(response_jsons=response_json["items"])
8375

8476
####################
8577
# Private methods #

src/argilla_sdk/_api/_vectors.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from uuid import UUID
1717

1818
import httpx
19+
1920
from argilla_sdk._api._base import ResourceAPI
2021
from argilla_sdk._exceptions import api_error_handler
2122
from argilla_sdk._models import VectorFieldModel
@@ -33,36 +34,36 @@ class VectorsAPI(ResourceAPI[VectorFieldModel]):
3334
################
3435

3536
@api_error_handler
36-
def create(self, dataset_id: UUID, vector: VectorFieldModel) -> VectorFieldModel:
37-
url = f"/api/v1/datasets/{dataset_id}/vectors-settings"
37+
def create(self, vector: VectorFieldModel) -> VectorFieldModel:
38+
url = f"/api/v1/datasets/{vector.dataset_id}/vectors-settings"
3839
response = self.http_client.post(url=url, json=vector.model_dump())
3940
response.raise_for_status()
4041
response_json = response.json()
41-
vector_model = self._model_from_json(response_json=response_json)
42-
self._log_message(message=f"Created vector {vector_model.name} in dataset {dataset_id}")
43-
return vector_model
42+
created_vector = self._model_from_json(response_json=response_json)
43+
self._log_message(message=f"Created vector {created_vector.name} in dataset {created_vector.dataset_id}")
44+
return created_vector
4445

4546
@api_error_handler
4647
def update(self, vector: VectorFieldModel) -> VectorFieldModel:
47-
# TODO: Implement update method for vectors with server side ID
48-
raise NotImplementedError
48+
url = f"/api/v1/vectors-settings/{vector.id}"
49+
response = self.http_client.patch(url, json=vector.model_dump())
50+
response.raise_for_status()
51+
response_json = response.json()
52+
updated_vector = self._model_from_json(response_json)
53+
self._log_message(message=f"Updated vector {updated_vector.name} with id {updated_vector.id}")
54+
return updated_vector
4955

5056
@api_error_handler
5157
def delete(self, vector_id: UUID) -> None:
52-
# TODO: Implement delete method for vectors with server side ID
53-
raise NotImplementedError
58+
url = f"/api/v1/vectors-settings/{vector_id}"
59+
response = self.http_client.delete(url)
60+
response.raise_for_status()
61+
self._log_message(message=f"Deleted vector with id {vector_id}")
5462

5563
####################
5664
# Utility methods #
5765
####################
5866

59-
def create_many(self, dataset_id: UUID, vectors: List[VectorFieldModel]) -> List[VectorFieldModel]:
60-
vector_models = []
61-
for vector in vectors:
62-
vector_model = self.create(dataset_id=dataset_id, vector=vector)
63-
vector_models.append(vector_model)
64-
return vector_models
65-
6667
@api_error_handler
6768
def list(self, dataset_id: UUID) -> List[VectorFieldModel]:
6869
response = self.http_client.get(f"/api/v1/datasets/{dataset_id}/vectors-settings")

src/argilla_sdk/_models/__init__.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,8 @@
3232
ScopeModel,
3333
)
3434
from argilla_sdk._models._settings._fields import (
35-
TextFieldModel,
36-
FieldSettings,
37-
FieldBaseModel,
35+
FieldModel,
36+
TextFieldSettings,
3837
FieldModel,
3938
)
4039
from argilla_sdk._models._settings._questions import (

src/argilla_sdk/_models/_settings/_fields.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,27 +12,28 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Optional
15+
from typing import Optional, Literal
1616
from uuid import UUID
1717

18-
from pydantic import BaseModel, field_serializer, field_validator, Field
18+
from pydantic import BaseModel, field_serializer, field_validator
1919
from pydantic_core.core_schema import ValidationInfo
2020

2121
from argilla_sdk._helpers import log_message
22+
from argilla_sdk._models import ResourceModel
2223

2324

24-
class FieldSettings(BaseModel):
25-
type: str = Field(validate_default=True)
25+
class TextFieldSettings(BaseModel):
26+
type: Literal["text"] = "text"
2627
use_markdown: Optional[bool] = False
2728

2829

29-
class FieldBaseModel(BaseModel):
30-
id: Optional[UUID] = None
30+
class FieldModel(ResourceModel):
3131
name: str
32-
3332
title: Optional[str] = None
3433
required: bool = True
3534
description: Optional[str] = None
35+
settings: TextFieldSettings = TextFieldSettings(use_markdown=False)
36+
dataset_id: Optional[UUID] = None
3637

3738
@field_validator("name")
3839
@classmethod
@@ -48,13 +49,6 @@ def __title_default(cls, title: str, info: ValidationInfo) -> str:
4849
log_message(f"TextField title is {validated_title}")
4950
return validated_title
5051

51-
@field_serializer("id", when_used="unless-none")
52+
@field_serializer("id", "dataset_id", when_used="unless-none")
5253
def serialize_id(self, value: UUID) -> str:
5354
return str(value)
54-
55-
56-
class TextFieldModel(FieldBaseModel):
57-
settings: FieldSettings = FieldSettings(type="text", use_markdown=False)
58-
59-
60-
FieldModel = TextFieldModel

src/argilla_sdk/_models/_settings/_metadata.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from pydantic import BaseModel, Field, field_serializer, field_validator, model_validator
2020

2121
from argilla_sdk._exceptions import MetadataError
22+
from argilla_sdk._models import ResourceModel
2223

2324

2425
class MetadataPropertyType(str, Enum):
@@ -36,9 +37,7 @@ class TermsMetadataPropertySettings(BaseMetadataPropertySettings):
3637
type: Literal[MetadataPropertyType.terms]
3738
values: Optional[List[str]] = None
3839

39-
@field_validator(
40-
"values",
41-
)
40+
@field_validator("values")
4241
@classmethod
4342
def __validate_values(cls, values):
4443
if values is None:
@@ -94,17 +93,18 @@ class FloatMetadataPropertySettings(NumericMetadataPropertySettings):
9493
]
9594

9695

97-
class MetadataFieldModel(BaseModel):
96+
class MetadataFieldModel(ResourceModel):
9897
"""The schema definition of a metadata field in an Argilla dataset."""
9998

100-
id: Optional[UUID] = None
10199
name: str
102100
settings: MetadataPropertySettings
103101

104102
type: Optional[MetadataPropertyType] = Field(None, validate_default=True)
105103
title: Optional[str] = None
106104
visible_for_annotators: Optional[bool] = True
107105

106+
dataset_id: Optional[UUID] = None
107+
108108
@field_validator("name")
109109
@classmethod
110110
def __name_lower(cls, name):
@@ -117,7 +117,7 @@ def __title_default(cls, title, values):
117117
validated_title = title or values.data["name"]
118118
return validated_title
119119

120-
@field_serializer("id", when_used="unless-none")
120+
@field_serializer("id", "dataset_id", when_used="unless-none")
121121
def serialize_id(self, value: UUID) -> str:
122122
return str(value)
123123

src/argilla_sdk/_models/_settings/_vectors.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,17 @@
1515
from typing import Optional
1616
from uuid import UUID
1717

18-
from pydantic import BaseModel, field_validator, field_serializer
18+
from pydantic import field_validator, field_serializer
1919
from pydantic_core.core_schema import ValidationInfo
2020

21+
from argilla_sdk._models import ResourceModel
2122
from argilla_sdk._helpers import log_message
2223

2324

24-
class VectorFieldModel(BaseModel):
25+
class VectorFieldModel(ResourceModel):
2526
name: str
2627
title: Optional[str] = None
2728
dimensions: int
28-
29-
id: Optional[UUID] = None
3029
dataset_id: Optional[UUID] = None
3130

3231
@field_serializer("id", "dataset_id", when_used="unless-none")

src/argilla_sdk/datasets/_resource.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,15 @@ def create(self) -> "Dataset":
164164
self.__rollback_dataset_creation()
165165
raise SettingsError from e
166166

167+
def update(self) -> "Dataset":
168+
"""Updates the dataset on the server with the current settings.
169+
170+
Returns:
171+
Dataset: The updated dataset object.
172+
"""
173+
self.settings.update()
174+
return self
175+
167176
@classmethod
168177
def from_model(cls, model: DatasetModel, client: "Argilla") -> "Dataset":
169178
return cls(client=client, _model=model)
@@ -173,7 +182,6 @@ def from_model(cls, model: DatasetModel, client: "Argilla") -> "Dataset":
173182
#####################
174183

175184
def _publish(self) -> "Dataset":
176-
self.settings.validate()
177185
self._settings.create()
178186
self._api.publish(dataset_id=self._model.id)
179187

0 commit comments

Comments
 (0)