-
-
Notifications
You must be signed in to change notification settings - Fork 68
Atmorep Implementation with its tests #143
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi,
Thanks for working on this. I think this needs a lot of changes. A few general notes:
- Don't put test logic in the actual library code. So don't have logic for
MagicMock
s and don't return dummy values, or add dummy values if something doesn't exist. That is the test code job to do, not the code in the library. This also applies to the other PRs you've opened. - Please don't just have a
config
as the argument toinit
, each class should take all the arguments it needs in theinit
spelled out as keyword arguments. Theconfig
can be nice to bundle them up, but shouldn't be theinit
arguments for nearly all classes or functions. - Model implementations shouldn't have their own datasets, unless really necessary. We are aiming to have these codes be as interoperable as possible, and reduce duplication where possible, so things like the
ERA5Dataset
can be removed, they have equivalents underdata/
and work for any model in this repo. Same for the normalizer, there are simpler or equivalent normalizers that already exist in this repo, so those could be used instead, the implementation here isn't needed. - The CRPS, etc. are all useful, but those would be generally useful. My advice would be to split those metrics out into its own PR that adds them to a
graph_weather/metrics/crps.py
, etc. and tests them separately. They would be useful for all the models here! - All imports should be at the top of the file, you shouldn't need to import inside functions, and if that is necessary for tests to pass, then the code should be refactored so that is not the case, as it is saying that there is too much of a linking between test and library code.
- Training script should also be quite general, unless the model has very specific needs, it should be able to be trained by the other training scripts.
- As much as possible, there should be type hints for all the arguments and return types in the code, so it is easier to see what is expected in and out.
Please take the time to go through and address the comments. Also, please comment on issues or open one up if you want to add a model or a large PR like this one. I can help with scoping it or breaking it up so it is easier to review, and there is less wasted effort.
from graph_weather.models.atmorep.config import AtmoRepConfig | ||
|
||
|
||
class ERA5Dataset(Dataset): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For this, we could probably just run from ARCO-ERA5, it might be easier. Ideally, this would also be a more separate PR for the dataset, under /data/
and be more general than AtmoRep-specific.
self.transform = transform | ||
|
||
# Expect a data index file in the data directory | ||
index_file = os.path.join(data_dir, "data_index.txt") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We shouldn't need a index file for dataset like this, the ARCO-ERA5 is quite simple to use and index into, so this shouldn't be necessary.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would recommend removing this. An ERA5Dataset
is/would be useful, but that would be a different PR, and more general than just to AtmoRep. Additionally, there is the ARCO-ERA5 dataset on GCP that is very simple to use and read from, so most of the functionality here could be removed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a more general normalizer than AtmoRep, and so shouldn't be in this repo. Additionally, generally, these classes and functions shouldn't make data for tests. Tests should handle that. I would remove this normalizer and file, as there are other normalizers already present.
Returns: | ||
dict: A dictionary of statistics for each field. | ||
""" | ||
self.stats = {field: {"mean": 0.0, "std": 1.0} for field in self.config.input_fields} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self.stats = {field: {"mean": 0.0, "std": 1.0} for field in self.config.input_fields} | |
return NotImplementedError |
If this isn't actually computing the mean/stddev then it should return a NotImplementedError, not some fake stats, as that can cause downstream errors.
|
||
|
||
class UncertaintyEstimator: | ||
def __init__(self, config): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Have this take the actual args, not the config
return entropy | ||
|
||
|
||
class UncertaintyEstimator: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is redundant to the one above, please remove it.
return entropy.reshape(B, T, H, W) | ||
|
||
|
||
class CalibrationMetrics: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These aren't bad, but I would maybe generalize these rank histograms and CRPS to a more general losses/
or the losses.py
file
mapped_preds = torch.zeros_like(ensemble_preds) | ||
|
||
# Process each location separately | ||
for b in range(B): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please do this vectorized or something, this won't scale at all.
tests/test_atmorep.py
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please only use pytest
for the testing, so don't use the unittest.mock
or such. Also, the tests might be better if they are split across a few files, so maybe split the tests into model ones under tests/atmorep/test_model.py
, loss ones under tests/atmorep/test_loss.py
, etc.
Finally, instead of tagging me in the description, please request my review when it's ready. And do check that the code you were doing passes the precommit tests. |
…meters, and address reviewer feedback
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
Overview of the changes:
If the updated code aligns with your feedback, or if there are other areas you'd like me to prioritize or additional features to implement, please let me know. Otherwise, I will start working on the formatting for these files. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for all this work. There are some changes I would like to see still, and you missed a few of the comments I made in the previous round. Just a suggestion, but if you go through and respond to the comments I make on the PR, it can help me see where the changes occured, as well as ensure they aren't missed.
Few other notes: I don't think we need the training script for this, thank you for the work making it and adding the tests, but I'm more thinking this repo should have a single train script that can train any of the models in this repo, rather than lots of individual training scripts for each model on its own. So if you could remove it that would be great. Same with the ERA5Dataset file, as there is already an ARCO-ERA5 dataset in the repo, that is more generic.
I also am not sure we need the sampler.py, I think that can be removed as its somewhat specific to NetCDF files, and I don't quite see the utility of it.
One final note in general. I'm very happy you are quite excited and dedicated to contribute to this repository. One thing that might also help is making smaller PRs. Rather than very large ones that include everything to do with model, training, etc. its easier for me to review and go over smaller ones. So maybe one only adding the unique layer/module from a paper, then a follow on PR that adds the encoder or decoder, and a another one that adds the processor, for example. Then a final one if needed for the dataset, possibly. It just makes it easier for me to review, faster for me to review, and gives more chances for feedback before you spend a lot time writing up a lot of code and opening the PR.
num_heads: int, | ||
dropout: float = 0.1, | ||
attention_dropout: float = 0.1, | ||
transformer_block_cls=nn.Module, # replace with your actual block class if needed |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would remove this bit or if set to a default, default to the one that works with the decoder. But the class is used in a quite specific way, so I would recommend removing this as a configuration option.
time_steps: int, | ||
num_layers: int, | ||
field_name: str = "unknown_field", | ||
transformer_block_cls=nn.Module, # replace with your actual block class |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same as for the decoder, its used in quite a specific way, so removing this as an option makes more sense.
transformer_block_cls=nn.Module, # replace with your actual block class |
after the regular transformer blocks to better capture temporal/spatial dependencies. | ||
|
||
Args: | ||
All the same args as MultiFormer, plus any needed for SpatioTemporalAttention. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would still copy the args over here, duplication in these kinds of docstrings is okay so that the documentation is right next to the code it is describing.
""" | ||
Args: | ||
predictions (dict): Dict of field predictions, each with shape | ||
[E, B, T, H, W] or [B, T, H, W]. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[E, B, T, H, W] or [B, T, H, W]. | |
[Ensemble, Batch, Time, Height, Width] or [Batch, Time, Height, Width]. |
For these in the docstrings, I think it is helpful to write out what the dimension ordering means, just to be a bit clearer. If you could do that on the other docstrings that would be great!
# Ensure the model has parameters; if it fails, let it break. | ||
params = list(model.parameters()) | ||
if len(params) == 0: | ||
# If for some reason the model has no parameters, register a dummy parameter. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This falls under the test logic, if there is no parameters, you want the script to fail, not continue working but training something that you don't know what it is.
graph_weather/data/sampler.py
Outdated
self.logger = logging.getLogger("HierarchicalSampler") | ||
self.logger.setLevel(logging.INFO) | ||
|
||
def _get_available_time_segments(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This isn't actually checking available times in the dataset, just generating a list of years and months.
Thank you so much for your detailed feedback! I appreciate you taking the time to review my work and provide such specific guidance. I apologize for missing some of your earlier comments — that was an oversight on my part. If you could point me to any specific ones that I missed, I’ll make sure to address them as soon as possible. As per your suggestion, I’ve removed the training script, ERA5Dataset file, sampler.py removing test_loss.py and test_training.py. I now understand your vision for a more streamlined codebase with a single training script, rather than multiple model-specific implementations. Your advice on submitting smaller, more focused PRs makes perfect sense, and I’ll adopt this approach going forward. Breaking down contributions into more manageable pieces will certainly make the review process smoother for everyone involved. I’m excited about contributing to this project and learning from your expertise. I’m committed to aligning my workflow with the team’s practices and improving my contributions. Is there anything else you would suggest I prioritize in addressing the current PR, aside from the files mentioned? Thanks again for your valuable guidance! |
Hi @jacobbieker, |
Pull Request
Description
This PR introduces a comprehensive test suite for the AtmoRep model in
test_atmorep.py
. It includes unit tests for configuration, data preprocessing, model inference, normalization, training components, and robustness checks.Issue Addressed #103
Key Features
create_forecast
.How Has This Been Tested?
Checklist:
@jacobbieker