You cannot select more than 25 topics. Topics must start with a letter or number, may include dashes ('-'), and can be up to 35 characters long.

56 lines
1.7 KiB

  1. from typing import Dict, Optional, Union
  2. from ray.rllib.algorithms.dqn.dqn_torch_model import DQNTorchModel
  3. from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
  4. from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
  5. from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
  6. from ray.rllib.utils.framework import try_import_tf, try_import_torch
  7. from ray.rllib.utils.torch_utils import FLOAT_MIN, FLOAT_MAX
# Bind the torch module and torch.nn via RLlib's guarded importer so this
# file can be imported even when torch is absent (presumably RLlib's usual
# try_import_* convention — returns stubs/None if torch is unavailable;
# TODO confirm against ray.rllib.utils.framework).
torch, nn = try_import_torch()
  9. class TorchActionMaskModel(TorchModelV2, nn.Module):
  10. def __init__(
  11. self,
  12. obs_space,
  13. action_space,
  14. num_outputs,
  15. model_config,
  16. name,
  17. **kwargs,
  18. ):
  19. orig_space = getattr(obs_space, "original_space", obs_space)
  20. TorchModelV2.__init__(
  21. self, obs_space, action_space, num_outputs, model_config, name, **kwargs
  22. )
  23. nn.Module.__init__(self)
  24. self.count = 0
  25. self.internal_model = TorchFC(
  26. orig_space["data"],
  27. action_space,
  28. num_outputs,
  29. model_config,
  30. name + "_internal",
  31. )
  32. def forward(self, input_dict, state, seq_lens):
  33. # Extract the available actions tensor from the observation.
  34. # Compute the unmasked logits.
  35. logits, _ = self.internal_model({"obs": input_dict["obs"]["data"]})
  36. action_mask = input_dict["obs"]["action_mask"]
  37. inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
  38. masked_logits = logits + inf_mask
  39. # Return masked logits.
  40. return masked_logits, state
  41. def value_function(self):
  42. return self.internal_model.value_function()