Skip to content
Snippets Groups Projects
Commit c1cacec9 authored by Yoann GOURVES's avatar Yoann GOURVES
Browse files

rework interactor

parent a22032a6
No related branches found
No related tags found
No related merge requests found
No preview for this file type
...@@ -12,23 +12,21 @@ void UCarTrainingEnv::GatherAgentReward_Implementation(float& OutReward, const i ...@@ -12,23 +12,21 @@ void UCarTrainingEnv::GatherAgentReward_Implementation(float& OutReward, const i
GEngine->AddOnScreenDebugMessage(INDEX_NONE, 15.0f, FColor::Red, FString::Printf(TEXT("ERROR UCarTrainingEnv Initialized GatherAgentReward_Implementation."))); GEngine->AddOnScreenDebugMessage(INDEX_NONE, 15.0f, FColor::Red, FString::Printf(TEXT("ERROR UCarTrainingEnv Initialized GatherAgentReward_Implementation.")));
return; return;
} }
else /*else
GEngine->AddOnScreenDebugMessage(INDEX_NONE, 15.0f, FColor::Green, FString::Printf(TEXT("SUCCESS UCarTrainingEnv Initialized GatherAgentReward_Implementation."))); GEngine->AddOnScreenDebugMessage(INDEX_NONE, 15.0f, FColor::Green, FString::Printf(TEXT("SUCCESS UCarTrainingEnv Initialized GatherAgentReward_Implementation.")));
*/
auto scored = PlayerCar->GetPlayerState<APlayerCarState>()->GetVehicleHit(); auto scored = PlayerCar->GetPlayerState<APlayerCarState>()->GetVehicleHit();
AgentHitCount.FindOrAdd(AgentId) += scored;
float weight = -1.0f; float weight = -1.0f;
OutReward = weight * scored; OutReward = weight * scored;
PlayerCar->GetPlayerState<APlayerCarState>()->ResetVehicleHit(); PlayerCar->GetPlayerState<APlayerCarState>()->ResetVehicleHit();
} }
void UCarTrainingEnv::GatherAgentCompletion_Implementation(ELearningAgentsCompletion& OutCompletion, const int32 AgentId) void UCarTrainingEnv::GatherAgentCompletion_Implementation(ELearningAgentsCompletion& OutCompletion, const int32 AgentId)
{ {
//on time out make this episode complete (timeout 30 seconds) <- idea: map of agentId associated with time of episode star auto hitCount = AgentHitCount.Find(AgentId);
auto time = AgentTimeMap.Find(AgentId); if (*hitCount>maxHitCount) {
auto currentTime = GetWorld()->GetTimeSeconds(); OutCompletion = ELearningAgentsCompletion::Termination;
if (time != nullptr && currentTime - *time > 30.0f) {
OutCompletion = ELearningAgentsCompletion::Truncation;
} }
else { else {
OutCompletion = ELearningAgentsCompletion::Running; OutCompletion = ELearningAgentsCompletion::Running;
...@@ -38,7 +36,7 @@ void UCarTrainingEnv::GatherAgentCompletion_Implementation(ELearningAgentsComple ...@@ -38,7 +36,7 @@ void UCarTrainingEnv::GatherAgentCompletion_Implementation(ELearningAgentsComple
void UCarTrainingEnv::ResetAgentEpisode_Implementation(const int32 AgentId) void UCarTrainingEnv::ResetAgentEpisode_Implementation(const int32 AgentId)
{ {
//reset the time of the agent //reset the time of the agent
AgentTimeMap.Add(AgentId, GetWorld()->GetTimeSeconds()); AgentHitCount.FindOrAdd(AgentId, 0);
//reset environment //reset environment
auto CustomGameState = GetWorld()->GetGameState<ACustomGameState>(); auto CustomGameState = GetWorld()->GetGameState<ACustomGameState>();
......
...@@ -234,6 +234,8 @@ void ACustomGameMode::OnScoreCollisionBoxEndOverlap(UPrimitiveComponent* Overlap ...@@ -234,6 +234,8 @@ void ACustomGameMode::OnScoreCollisionBoxEndOverlap(UPrimitiveComponent* Overlap
// Get the game state // Get the game state
ACustomGameState* CustomGameState = GetGameState<ACustomGameState>(); ACustomGameState* CustomGameState = GetGameState<ACustomGameState>();
CustomGameState->fSpeed *= 1.05; CustomGameState->fSpeed *= 1.05;
//clamp the speed
CustomGameState->fSpeed = FMath::Clamp(CustomGameState->fSpeed, 100.0f, 3000.0f);
} }
......
...@@ -9,9 +9,19 @@ void UCustomLearningInteractor::SpecifyAgentObservation_Implementation(FLearning ...@@ -9,9 +9,19 @@ void UCustomLearningInteractor::SpecifyAgentObservation_Implementation(FLearning
{ {
auto CustomGameState = GetWorld()->GetGameState<ACustomGameState>(); auto CustomGameState = GetWorld()->GetGameState<ACustomGameState>();
int numberOfLanes = CustomGameState->iNumberOfLane*2; int numberOfLanes = CustomGameState->iNumberOfLane*2;
//TMap<FName, FLearningAgentsObservationSchemaElement> ObstaclesObservation;
//ObstaclesObservation.Add("ObstacleAngle", ULearningAgentsObservations::SpecifyAngleObservation(InObservationSchema, "ObstacleAngle"));
//ObstaclesObservation.Add("ObstacleDistance", ULearningAgentsObservations::SpecifyProportionAlongRayObservation(InObservationSchema, "ObstacleDistance"));
TMap<FName, FLearningAgentsObservationSchemaElement> ObstaclesObservation; TMap<FName, FLearningAgentsObservationSchemaElement> ObstaclesObservation;
ObstaclesObservation.Add("ObstacleAngle", ULearningAgentsObservations::SpecifyAngleObservation(InObservationSchema, "ObstacleAngle")); ObstaclesObservation.Add("ObstacleAngles",
ObstaclesObservation.Add("ObstacleDistance", ULearningAgentsObservations::SpecifyProportionAlongRayObservation(InObservationSchema, "ObstacleDistance")); ULearningAgentsObservations::SpecifyEitherObservation(InObservationSchema,
ULearningAgentsObservations::SpecifyArrayObservation(InObservationSchema, ULearningAgentsObservations::SpecifyAngleObservation(InObservationSchema, "ObstacleAngle"), 200, 32, 4, 32, "ObstacleAngles"),
ULearningAgentsObservations::SpecifyNullObservation(InObservationSchema, "ObstacleAngle"),128,"ObstacleAngles"));
ObstaclesObservation.Add("ObstacleDistances",
ULearningAgentsObservations::SpecifyEitherObservation(InObservationSchema,
ULearningAgentsObservations::SpecifyArrayObservation(InObservationSchema, ULearningAgentsObservations::SpecifyProportionAlongRayObservation(InObservationSchema, "ObstacleDistance"), 200, 32, 4, 32, "ObstacleDistances"),
ULearningAgentsObservations::SpecifyNullObservation(InObservationSchema, "ObstacleDistance"), 128, "ObstacleDistances"));
TMap<FName, FLearningAgentsObservationSchemaElement> SelfObservation; TMap<FName, FLearningAgentsObservationSchemaElement> SelfObservation;
...@@ -38,10 +48,10 @@ void UCustomLearningInteractor::GatherAgentObservation_Implementation(FLearningA ...@@ -38,10 +48,10 @@ void UCustomLearningInteractor::GatherAgentObservation_Implementation(FLearningA
if (!PlayerCarState) PlayerCarState = GetWorld()->SpawnActor<APlayerCarState>(); if (!PlayerCarState) PlayerCarState = GetWorld()->SpawnActor<APlayerCarState>();
TMap<FName, FLearningAgentsObservationObjectElement> ObstaclesObservations; TMap<FName, FLearningAgentsObservationObjectElement> ObstaclesObservations;
TMap<FName, FLearningAgentsObservationObjectElement> SelfObservations; TMap<FName, FLearningAgentsObservationObjectElement> SelfObservations;
if (Obstacles.Num() == 0) TArray<FLearningAgentsObservationObjectElement> AngleObservations;
TArray<FLearningAgentsObservationObjectElement> DistanceObservations;
if (Obstacles.Num() != 0)
{ {
return;
}
for (auto Obstacle : Obstacles) for (auto Obstacle : Obstacles)
{ {
// Get the angle between the player car and the obstacle & distance // Get the angle between the player car and the obstacle & distance
...@@ -51,8 +61,15 @@ void UCustomLearningInteractor::GatherAgentObservation_Implementation(FLearningA ...@@ -51,8 +61,15 @@ void UCustomLearningInteractor::GatherAgentObservation_Implementation(FLearningA
float Angle = FMath::RadiansToDegrees(FMath::Acos(FVector::DotProduct(Direction.GetSafeNormal(), FVector(1.0f, 0.0f, 0.0f)))); float Angle = FMath::RadiansToDegrees(FMath::Acos(FVector::DotProduct(Direction.GetSafeNormal(), FVector(1.0f, 0.0f, 0.0f))));
auto AngleObservation = ULearningAgentsObservations::MakeAngleObservation(InObservationObject, Angle, 0.0f, "ObstacleAngle"); // add logger possible auto AngleObservation = ULearningAgentsObservations::MakeAngleObservation(InObservationObject, Angle, 0.0f, "ObstacleAngle"); // add logger possible
auto DistanceObservation = ULearningAgentsObservations::MakeProportionAlongRayObservation(InObservationObject, PlayerCarLocation, ObstacleLocation, FTransform(), ECC_WorldStatic, "ObstacleDistance"); // add logger possible auto DistanceObservation = ULearningAgentsObservations::MakeProportionAlongRayObservation(InObservationObject, PlayerCarLocation, ObstacleLocation, FTransform(), ECC_WorldStatic, "ObstacleDistance"); // add logger possible
ObstaclesObservations.Add("ObstacleAngle", AngleObservation); AngleObservations.Add(AngleObservation);
ObstaclesObservations.Add("ObstacleDistance", DistanceObservation); DistanceObservations.Add(DistanceObservation);
}
ObstaclesObservations.Add("ObstacleAngles", ULearningAgentsObservations::MakeEitherAObservation(InObservationObject,ULearningAgentsObservations::MakeArrayObservation(InObservationObject, AngleObservations, 200, "ObstacleAngles"), "ObstacleAngles"));
ObstaclesObservations.Add("ObstacleDistances", ULearningAgentsObservations::MakeEitherAObservation(InObservationObject, ULearningAgentsObservations::MakeArrayObservation(InObservationObject, DistanceObservations, 200, "ObstacleDistances"), "ObstacleDistances"));
}
else {
ObstaclesObservations.Add("ObstacleAngles", ULearningAgentsObservations::MakeEitherBObservation(InObservationObject, ULearningAgentsObservations::MakeNullObservation(InObservationObject, "ObstacleAngle"), "ObstacleAngles"));
ObstaclesObservations.Add("ObstacleDistances", ULearningAgentsObservations::MakeEitherBObservation(InObservationObject, ULearningAgentsObservations::MakeNullObservation(InObservationObject, "ObstacleDistance"), "ObstacleDistances"));
} }
// Get available lane observation // Get available lane observation
int Currentlane = PlayerCarState->GetPlayerLane(); int Currentlane = PlayerCarState->GetPlayerLane();
...@@ -76,9 +93,15 @@ void UCustomLearningInteractor::SpecifyAgentAction_Implementation(FLearningAgent ...@@ -76,9 +93,15 @@ void UCustomLearningInteractor::SpecifyAgentAction_Implementation(FLearningAgent
void UCustomLearningInteractor::PerformAgentAction_Implementation(const ULearningAgentsActionObject* InActionObject, const FLearningAgentsActionObjectElement& InActionObjectElement, const int32 AgentId) void UCustomLearningInteractor::PerformAgentAction_Implementation(const ULearningAgentsActionObject* InActionObject, const FLearningAgentsActionObjectElement& InActionObjectElement, const int32 AgentId)
{ {
static float lastActionTime = 0.0f;
if (GetWorld()->GetTimeSeconds() - lastActionTime < 0.4f)
{
return;
}
auto PlayerCar = Cast<APlayerCar>(Manager->GetAgent(AgentId)); auto PlayerCar = Cast<APlayerCar>(Manager->GetAgent(AgentId));
auto PlayerCarState = Cast<APlayerCarState>(PlayerCar->GetPlayerState()); auto PlayerCarState = Cast<APlayerCarState>(PlayerCar->GetPlayerState());
int CurrentLane = PlayerCarState->GetPlayerLane();
TMap<FName, FLearningAgentsActionObjectElement> ActionMap; TMap<FName, FLearningAgentsActionObjectElement> ActionMap;
if (!ULearningAgentsActions::GetStructAction(ActionMap, InActionObject, InActionObjectElement)) { if (!ULearningAgentsActions::GetStructAction(ActionMap, InActionObject, InActionObjectElement)) {
return; return;
...@@ -98,19 +121,18 @@ void UCustomLearningInteractor::PerformAgentAction_Implementation(const ULearnin ...@@ -98,19 +121,18 @@ void UCustomLearningInteractor::PerformAgentAction_Implementation(const ULearnin
{ {
UE_LOG(LogTemp, Warning, TEXT("Steering: %d"), Steering); UE_LOG(LogTemp, Warning, TEXT("Steering: %d"), Steering);
} }
int TargetLane = CurrentLane;
switch (Steering) switch (Steering)
{ {
case 0: case 0:
PlayerCar->EnhancedInputMoveLeft(FInputActionInstance()); PlayerCar->EnhancedInputMoveLeft(FInputActionInstance());
TargetLane = CurrentLane - 1;
break; break;
case 1: case 1:
TargetLane = CurrentLane;
break; break;
case 2: case 2:
PlayerCar->EnhancedInputMoveRight(FInputActionInstance()); PlayerCar->EnhancedInputMoveRight(FInputActionInstance());
TargetLane = CurrentLane +1;
break; break;
} }
lastActionTime = GetWorld()->GetTimeSeconds();
} }
...@@ -16,8 +16,13 @@ class DRIVESAFE_API UCarTrainingEnv : public ULearningAgentsTrainingEnvironment ...@@ -16,8 +16,13 @@ class DRIVESAFE_API UCarTrainingEnv : public ULearningAgentsTrainingEnvironment
private: private:
//map of timecreation for id of agent //map of timecreation for id of agent
TMap<int32, float> AgentTimeMap; TMap<int32, int64> AgentHitCount;
public: public:
UPROPERTY(EditAnywhere, Category = "Trainer Env Settings")
int64 maxHitCount = 2;
virtual void GatherAgentReward_Implementation(float& OutReward, const int32 AgentId) override; virtual void GatherAgentReward_Implementation(float& OutReward, const int32 AgentId) override;
virtual void GatherAgentCompletion_Implementation(ELearningAgentsCompletion& OutCompletion, const int32 AgentId) override; virtual void GatherAgentCompletion_Implementation(ELearningAgentsCompletion& OutCompletion, const int32 AgentId) override;
virtual void ResetAgentEpisode_Implementation(const int32 AgentId) override; virtual void ResetAgentEpisode_Implementation(const int32 AgentId) override;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment