Discrete cartpole environment does not train properly

ryuuzi on 25 Oct 2024 at 11:39
Answered: Hiro Yoshino on 5 Nov 2024 at 7:17
I created a discrete cartpole environment by following "Create Custom Environment from Class Template" and imported it into Reinforcement Learning Designer.
However, training does not converge stably. I have tried various changes by trial and error, but I cannot figure out how to fix it.
Any advice would be appreciated.
classdef matlab < rl.env.MATLABEnvironment
    properties
        % Acceleration due to gravity in m/s^2
        Gravity = 9.8
        % Mass of the cart
        MassCart = 1.0
        % Mass of the pole
        MassPole = 0.1
        % Half the length of the pole
        Length = 0.5
        % Max force the input can apply
        MaxForce = 10
        % Sample time
        Ts = 0.02
        % Angle at which to fail the episode
        ThetaThresholdRadians = 12 * pi/180
        % Distance at which to fail the episode
        XThreshold = 2.4
        % Reward each time step the cart-pole is balanced
        RewardForNotFalling = 1
        % Penalty when the cart-pole fails to balance
        PenaltyForFalling = -5
    end
    properties
        % System state [x,dx,theta,dtheta]'
        State = zeros(4,1)
    end
    properties(Access = protected)
        % Internal flag to indicate that the episode has terminated
        IsDone = false
    end
    properties (Transient,Access = private)
        Visualizer = []
    end
    methods
        function this = matlab()
            % Define the observation and action specifications, then call
            % the base class constructor
            ObservationInfo = rlNumericSpec([4 1]);
            ObservationInfo.Name = 'CartPole States';
            ObservationInfo.Description = 'x, dx, theta, dtheta';
            ActionInfo = rlFiniteSetSpec([-1 1]);
            ActionInfo.Name = 'CartPole Action';
            this = this@rl.env.MATLABEnvironment(ObservationInfo, ActionInfo);
            updateActionInfo(this);
        end
        function set.State(this,state)
            validateattributes(state,{'numeric'},{'finite','real','vector','numel',4},'','State');
            this.State = double(state(:));
            notifyEnvUpdated(this);
        end
        function set.Length(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','Length');
            this.Length = val;
            notifyEnvUpdated(this);
        end
        function set.Gravity(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','Gravity');
            this.Gravity = val;
        end
        function set.MassCart(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','MassCart');
            this.MassCart = val;
        end
        function set.MassPole(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','MassPole');
            this.MassPole = val;
        end
        function set.MaxForce(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','MaxForce');
            this.MaxForce = val;
            updateActionInfo(this);
        end
        function set.Ts(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','Ts');
            this.Ts = val;
        end
        function set.ThetaThresholdRadians(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','ThetaThresholdRadians');
            this.ThetaThresholdRadians = val;
            notifyEnvUpdated(this);
        end
        function set.XThreshold(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','XThreshold');
            this.XThreshold = val;
            notifyEnvUpdated(this);
        end
        function set.RewardForNotFalling(this,val)
            validateattributes(val,{'numeric'},{'real','finite','scalar'},'','RewardForNotFalling');
            this.RewardForNotFalling = val;
        end
        function set.PenaltyForFalling(this,val)
            validateattributes(val,{'numeric'},{'real','finite','scalar'},'','PenaltyForFalling');
            this.PenaltyForFalling = val;
        end
        function [observation,reward,isdone,loggedSignals] = step(this,action)
            loggedSignals = [];
            % Get action
            force = getForce(this,action);
            % Unpack state vector
            state = this.State;
            %x = state(1);
            x_dot = state(2);
            theta = state(3);
            theta_dot = state(4);
            % Apply motion equations
            costheta = cos(theta);
            sintheta = sin(theta);
            totalmass = this.MassCart + this.MassPole;
            polemasslength = this.MassPole*this.Length;
            temp = (force + polemasslength * theta_dot * theta_dot * sintheta) / totalmass;
            thetaacc = (this.Gravity * sintheta - costheta* temp) / (this.Length * (4.0/3.0 - this.MassPole * costheta * costheta / totalmass));
            xacc = temp - polemasslength * thetaacc * costheta / totalmass;
            % Euler integration
            observation = state + this.Ts.*[x_dot;xacc;theta_dot;thetaacc];
            this.State = observation;
            x = observation(1);
            theta = observation(3);
            isdone = abs(x) > this.XThreshold || abs(theta) > this.ThetaThresholdRadians;
            this.IsDone = isdone;
            % Get reward
            reward = getReward(this,x,force);
        end
        function initialState = reset(this)
            % Randomize the initial pendulum angle theta between +/- 0.05 rad
            T0 = 2*0.05*rand - 0.05;
            % Thetadot
            Td0 = 0;
            % X
            X0 = 0;
            % Xdot
            Xd0 = 0;
            initialState = [X0;Xd0;T0;Td0];
            this.State = initialState;
        end
        function varargout = plot(this)
            % Visualizes the environment
            if isempty(this.Visualizer) || ~isvalid(this.Visualizer)
                this.Visualizer = rl.env.viz.CartPoleVisualizer(this);
            else
                bringToFront(this.Visualizer);
            end
            if nargout
                varargout{1} = this.Visualizer;
            end
        end
    end
    methods (Access = protected)
        function force = getForce(this,action)
            if ~ismember(action,this.ActionInfo.Elements)
                error(message('rl:env:CartPoleDiscreteInvalidAction',sprintf('%g',-this.MaxForce),sprintf('%g',this.MaxForce)));
            end
            force = action;
        end
        % Update the action info based on max force
        function updateActionInfo(this)
            this.ActionInfo.Elements = this.MaxForce*[-1 10];
        end
        function Reward = getReward(this,~,~)
            if ~this.IsDone
                Reward = this.RewardForNotFalling;
            else
                Reward = this.PenaltyForFalling;
            end
        end
    end
end
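
For reference, a class-based environment like this can be instantiated and sanity-checked at the command line before importing it into Reinforcement Learning Designer. Below is a minimal sketch, assuming the class above is saved as matlab.m so the file name matches the classdef name:

    % Instantiate the custom environment (assumes the class file is named
    % matlab.m to match the classdef name above).
    env = matlab();

    % validateEnvironment resets the environment and runs a short simulation
    % using the declared observation and action specifications, reporting
    % any inconsistencies it finds.
    validateEnvironment(env)

    % Inspect the action specification the agent will actually see during
    % training (the elements set by updateActionInfo).
    getActionInfo(env)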

Accepted Answer

Hiro Yoshino on 5 Nov 2024 at 7:17
There is a discrete cartpole environment there, so opening the working version and examining how it is implemented should be a useful reference (the answer may well be in there).
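
One working discrete cart-pole that is easy to open for comparison is the predefined environment that ships with Reinforcement Learning Toolbox. A minimal sketch, assuming that toolbox is installed:

    % Create the predefined discrete cart-pole environment shipped with
    % Reinforcement Learning Toolbox.
    refEnv = rlPredefinedEnv("CartPole-Discrete");

    % Compare its specifications with those of the custom class, in
    % particular the finite action set and the reward definition.
    getObservationInfo(refEnv)
    getActionInfo(refEnv)

Comparing the action elements and the reward returned by step in this reference environment against the custom class is a quick way to spot where the custom implementation diverges from the working one.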
