
[Feature Request] A way to specify the input of environment resets through the DataCollector #1906

Closed
jensbreitung opened this issue Feb 14, 2024 · 11 comments · Fixed by #2071
Labels: enhancement (New feature or request)

@jensbreitung

Motivation

I'm working on an RL task in a (continuous) domain. However, the initial state the environment assumes on a reset comes from a curated dataset, since we have prior knowledge of what the state of the environment typically looks like in practice.
The environment should ideally not contain the entire dataset but only work with a single example (since that's all it "needs" to know to simulate the agent's actions and their effects). However, I would also like to make use of DataCollectors for training and validation.

Solution

Add an optional parameter, reset_env_kwargs or similar, to DataCollectors that allows specifying the arguments used when the collector calls the environment's reset function.
This way, the input of the reset can be specified outside of the environment code, so the entire dataset does not have to be moved into the environment just to be able to reset to specific environment states.
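
For illustration, usage could look something like this (a rough sketch only: `reset_env_kwargs` is the proposed, non-existing argument, and `make_env`, `policy`, and `dataset` are placeholders):

```python
from torchrl.collectors import SyncDataCollector

collector = SyncDataCollector(
    create_env_fn=make_env,   # placeholder environment factory
    policy=policy,            # placeholder policy
    frames_per_batch=100,
    total_frames=10_000,
    # proposed (non-existing) argument, forwarded to env.reset(...):
    reset_env_kwargs={"initial_state": dataset[0]},
)
```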

Alternatives

It is possible that I missed another (easier) way of doing this, e.g. by registering some sort of hook. In that case, I'd appreciate a pointer or a small example of how this could be implemented.

Additional context

If you find this addition useful, I'd be happy to contribute.

Checklist

  • I have checked that there is no similar issue in the repo (required)
jensbreitung added the enhancement (New feature or request) label on Feb 14, 2024
@vmoens
Contributor

vmoens commented Feb 14, 2024

That makes total sense, yeah.
Can you provide pseudo-code of how you think this should work?

I can imagine several cases:

  • the input is fixed
  • the input is random
  • the input depends on the previous trajectory
  • the input is sampled from a dataset

Maybe a callable would be a good idea?

collector = DataCollector(..., env_reset_func=my_fn)  # env_reset_func: Callable[[], TensorDictBase]
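
For instance, the cases above could map to small callables like these (just a sketch, assuming the collector would call the function before each reset and forward the returned tensordict to env.reset; `dataset` is a placeholder):

```python
import torch
from tensordict import TensorDict

# Fixed input: always reset to the same state.
def fixed_reset() -> TensorDict:
    return TensorDict({"initial_state": torch.zeros(3)}, batch_size=[])

# Random input: draw a fresh state on every reset.
def random_reset() -> TensorDict:
    return TensorDict({"initial_state": torch.randn(3)}, batch_size=[])

# Dataset input: sample an empirical state from a curated dataset.
def dataset_reset() -> TensorDict:
    idx = torch.randint(len(dataset), ()).item()  # `dataset` is a placeholder
    return TensorDict(
        {"initial_state": torch.as_tensor(dataset[idx])}, batch_size=[]
    )
```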

wdyt?

@jensbreitung
Author

A callable sounds good to me and is probably much more versatile than just kwargs, but at least to me the name env_reset_func suggests that it overrides the entire reset function, which is probably not ideal.

Other than that, it's exactly what I'd like to have :)

@vmoens
Contributor

vmoens commented Feb 14, 2024

I'm pretty bad at naming things, as you might have noticed :p
Happy with any other name.

@jensbreitung
Author

Maybe env_pre_reset_fn? Not sure if that's any better, the name gets long quickly ^^
In principle, it could also be a function you register directly with the environment(s), similar to how you can register recorders with trainers?

@vmoens
Contributor

vmoens commented Feb 14, 2024

That actually sounds like TensorDictPrimer. Does that solve your problem or do you need something a bit more fancy?
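
For reference, a typical TensorDictPrimer use looks roughly like this (a sketch; `base_env` and the spec shape are placeholders):

```python
from torchrl.data import UnboundedContinuousTensorSpec
from torchrl.envs import TransformedEnv
from torchrl.envs.transforms import TensorDictPrimer

# At reset, the primer inserts an `initial_state` entry into the
# tensordict, here filled by sampling randomly from its spec.
env = TransformedEnv(
    base_env,  # placeholder for the actual environment
    TensorDictPrimer(
        initial_state=UnboundedContinuousTensorSpec(shape=(3,)),
        random=True,
    ),
)
```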

@jensbreitung
Author

I feel like TensorDictPrimer is a bit too limiting. If I understand it correctly, it only allows me to set keys that are populated according to distributions specified by their corresponding specs.

In my use case, I need to load an individual (empirical) sample from a recorded dataset (in an unfortunately rather unwieldy format) into the environment. A callback as discussed earlier sounds like a much simpler solution for this.

@vmoens
Contributor

vmoens commented Feb 14, 2024

Got it!
We can make that happen: some sort of TensorDictPrimer where you pass a callable, then?
I think TensorDictPrimer is the right abstraction for this problem because it takes care of patching your specs in such a way that we will be able to know in advance how to populate buffers and such (that could be done lazily, without you needing to write the specs explicitly).

I can make a PR with that if you think that could work!
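
Something like this, perhaps (a purely hypothetical sketch: `default_value` accepting a callable is the proposed extension, not the released API at this point; `base_env` and `dataset` are placeholders):

```python
import torch
from torchrl.data import UnboundedContinuousTensorSpec
from torchrl.envs import TransformedEnv
from torchrl.envs.transforms import TensorDictPrimer

def sample_initial_state() -> torch.Tensor:
    # Draw one empirical sample from the curated dataset (placeholder).
    idx = torch.randint(len(dataset), ()).item()
    return torch.as_tensor(dataset[idx])

env = TransformedEnv(
    base_env,
    TensorDictPrimer(
        initial_state=UnboundedContinuousTensorSpec(shape=(3,)),
        # proposed: a callable instead of a constant default value
        default_value=sample_initial_state,
    ),
)
```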

@jensbreitung
Author

I think this would work. My data format is a bit odd in that it also contains strings and all kinds of other stuff that is not supported within TensorDict, but I could probably work around it.
In general, using TensorDictPrimer would not allow dynamically populating the kwargs of reset during e.g. data collection. Maybe that is intended, but it feels like an odd limitation.

@vmoens
Contributor

vmoens commented Feb 15, 2024

I love odd :)
You can now carry non-tensor data in a tensordict, if that's of interest!
If you can give me some insight into what you'd like to do but can't, that'd be amazing.
It's still a bit experimental, but I'd love to get feedback. In the meantime, I can work on your issue.
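
A minimal sketch of what this looks like (assuming tensordict >= 0.3 and its `NonTensorData` wrapper):

```python
import torch
from tensordict import TensorDict, NonTensorData

# Strings (and other arbitrary Python objects) can ride along with tensors.
td = TensorDict(
    {"obs": torch.zeros(3), "system_id": NonTensorData("plant-42")},
    batch_size=[],
)
print(td["system_id"])  # -> "plant-42" (indexing unwraps the value)
```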

@jensbreitung
Author

jensbreitung commented Feb 15, 2024

Oh interesting, my tensordict version was outdated. I played around with 0.3.0 a bit and most data works fine; however, pandas DataFrames do not seem to work: they are simply converted to tensors of the corresponding dimension with dtype=torch.float32, filled with NaN values.

Some more details about what I'm trying to achieve and what the data looks like:

The dataset contains dictionaries with all sorts of different types of data which together represent a complex system and its corresponding state.
The overall "shape" of the system stays the same throughout the entire training/validation process, hence the environment can be constructed by taking a random dataset element and building the env accordingly.
However, each system state is defined by (potentially) all values contained in the dictionary.
The environment produces observations which loosely depend on the system state (but are pretty much defined by the data required to construct the environment in the first place).
The reward signal, however, depends (potentially) on pretty much all the information we have about the system state.
In any case, most of the system state needs to be available in the environment for logging purposes and contained in the tensordict returned by step.

The current limitations

To have all this additional data available in the environment and to use collectors during training/inference, all of this data needs to be given to the environment so it can manually fetch everything corresponding to a system state at reset. This adds unnecessary state to the environment, since from the environment's perspective only one system state matters at each point in time.
It also makes it harder to split the data across multiple environments when moving to vectorized environments and training on multiple nodes.

In principle, one could probably make do with less information in the environment, but having it all there makes the entire workflow much simpler. For instance, during validation, one can simply feed the entire system and its initial state into the environment, perform an action as per the trained policy, and then take out the resulting system + state in one go.

@vmoens
Contributor

vmoens commented Feb 15, 2024

OK, quite clear, thanks! Let me sleep on this a bit and get back to you!

vmoens linked a pull request (#2071) on Apr 16, 2024 that will close this issue