Commit 832611cc authored by 高雅喆

update .gitignore

parent 78fbd701
data/
*.pyc
.DS_Store
.idea/*
.ipynb_checkpoints/
*.csv
*.ipynb
*.txt
$HOME/*
.idea/
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="TestRunnerService">
<option name="PROJECT_TEST_RUNNER" value="Unittests" />
</component>
</module>
\ No newline at end of file
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="PyUnresolvedReferencesInspection" enabled="true" level="WARNING" enabled_by_default="true">
<option name="ignoredIdentifiers">
<list>
<option value="virtualTB" />
</list>
</option>
</inspection_tool>
</profile>
</component>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.6 (venv)" project-jdk-type="Python SDK" />
</project>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/gm-rl-rec.iml" filepath="$PROJECT_DIR$/.idea/gm-rl-rec.iml" />
</modules>
</component>
</project>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ChangeListManager">
<list default="true" id="8abf2ea5-ed71-4a23-b5cd-c1b9c2061907" name="Default Changelist" comment="" />
<option name="EXCLUDED_CONVERTED_TO_IGNORED" value="true" />
<option name="SHOW_DIALOG" value="false" />
<option name="HIGHLIGHT_CONFLICTS" value="true" />
<option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
<option name="LAST_RESOLUTION" value="IGNORE" />
</component>
<component name="FileEditorManager">
<leaf SIDE_TABS_SIZE_LIMIT_KEY="300">
<file pinned="false" current-in-tab="false">
<entry file="file://$PROJECT_DIR$/gm_train_ddpg.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="42">
<caret line="56" column="23" lean-forward="true" selection-start-line="56" selection-start-column="23" selection-end-line="56" selection-end-column="23" />
</state>
</provider>
</entry>
</file>
<file pinned="false" current-in-tab="false">
<entry file="file://$PROJECT_DIR$/test_policies.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="150">
<caret line="10" column="16" lean-forward="true" selection-start-line="10" selection-start-column="16" selection-end-line="10" selection-end-column="16" />
<folding>
<element signature="e#0#38#0" expanded="true" />
</folding>
</state>
</provider>
</entry>
</file>
<file pinned="false" current-in-tab="false">
<entry file="file://$PROJECT_DIR$/tf_agents_test.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="234">
<caret line="169" column="45" lean-forward="true" selection-start-line="169" selection-start-column="45" selection-end-line="169" selection-end-column="45" />
<folding>
<element signature="e#881#919#0" expanded="true" />
</folding>
</state>
</provider>
</entry>
</file>
<file pinned="false" current-in-tab="true">
<entry file="file://$PROJECT_DIR$/rl_21_test.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="105">
<caret line="10" column="38" lean-forward="true" selection-start-line="10" selection-start-column="38" selection-end-line="10" selection-end-column="38" />
</state>
</provider>
</entry>
</file>
</leaf>
</component>
<component name="FileTemplateManagerImpl">
<option name="RECENT_TEMPLATES">
<list>
<option value="Python Script" />
</list>
</option>
</component>
<component name="FindInProjectRecents">
<findStrings>
<find>num_iterations</find>
<find>environments</find>
<find>ts</find>
<find>driver</find>
<find>.train</find>
<find>dynamic_step_driver</find>
<find>optimizer</find>
<find>next</find>
<find>policy</find>
<find>eval_poli</find>
<find>compute</find>
<find>root_dir</find>
<find>train_dir</find>
<find>eval_dir</find>
<find>results</find>
<find>eval_summary_writer</find>
<find>global_step</find>
<find>eval_policy</find>
<find>spec</find>
<find>collect_policy</find>
<find>reset</find>
<find>AverageReturn</find>
<find>agent</find>
<find>num_repeates</find>
<find>time_step</find>
<find>policy_state</find>
<find>numpy</find>
<find>action_spec</find>
<find>action</find>
<find>train</find>
</findStrings>
<replaceStrings>
<replace />
</replaceStrings>
</component>
<component name="IdeDocumentHistory">
<option name="CHANGED_PATHS">
<list>
<option value="$USER_HOME$/Downloads/code/agents/venv/lib/python3.6/site-packages/tf_agents/environments/suite_gym.py" />
<option value="$PROJECT_DIR$/gm_train_eval.py" />
<option value="$PROJECT_DIR$/rl_21_test.py" />
<option value="$PROJECT_DIR$/test.py" />
<option value="$PROJECT_DIR$/test_environments.py" />
<option value="$PROJECT_DIR$/test_drivers.py" />
<option value="$PROJECT_DIR$/replay_buffers_test.py" />
<option value="$PROJECT_DIR$/test_replay_buffers.py" />
<option value="$PROJECT_DIR$/rl_CartPole_test.py" />
<option value="$PROJECT_DIR$/gm_train_ddpg.py" />
<option value="$PROJECT_DIR$/test_policies.py" />
<option value="$PROJECT_DIR$/tf_agents_test.py" />
</list>
</option>
</component>
<component name="ProjectFrameBounds">
<option name="x" value="57" />
<option name="y" value="23" />
<option name="width" value="1360" />
<option name="height" value="877" />
</component>
<component name="ProjectView">
<navigator proportions="" version="1">
<foldersAlwaysOnTop value="true" />
</navigator>
<panes>
<pane id="Scope" />
<pane id="ProjectPane">
<subPane>
<expand>
<path>
<item name="gm-rl-rec" type="b2602c69:ProjectViewProjectNode" />
<item name="gm-rl-rec" type="462c0819:PsiDirectoryNode" />
</path>
</expand>
<select />
</subPane>
</pane>
</panes>
</component>
<component name="PropertiesComponent">
<property name="settings.editor.selected.configurable" value="com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable" />
</component>
<component name="RunDashboard">
<option name="ruleStates">
<list>
<RuleState>
<option name="name" value="ConfigurationTypeDashboardGroupingRule" />
</RuleState>
<RuleState>
<option name="name" value="StatusDashboardGroupingRule" />
</RuleState>
</list>
</option>
</component>
<component name="RunManager" selected="Python.tf_agents_test">
<configuration name="rl_CartPole_test" type="PythonConfigurationType" factoryName="Python" temporary="true">
<module name="gm-rl-rec" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
<env name="PYTHONUNBUFFERED" value="1" />
</envs>
<option name="SDK_HOME" value="" />
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
<option name="IS_MODULE_SDK" value="true" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/rl_CartPole_test.py" />
<option name="PARAMETERS" value="" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
<option name="REDIRECT_INPUT" value="false" />
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<configuration name="test_drivers" type="PythonConfigurationType" factoryName="Python" temporary="true">
<module name="gm-rl-rec" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
<env name="PYTHONUNBUFFERED" value="1" />
</envs>
<option name="SDK_HOME" value="" />
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
<option name="IS_MODULE_SDK" value="true" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/test_drivers.py" />
<option name="PARAMETERS" value="" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
<option name="REDIRECT_INPUT" value="false" />
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<configuration name="test_policies" type="PythonConfigurationType" factoryName="Python" temporary="true">
<module name="gm-rl-rec" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
<env name="PYTHONUNBUFFERED" value="1" />
</envs>
<option name="SDK_HOME" value="" />
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
<option name="IS_MODULE_SDK" value="true" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/test_policies.py" />
<option name="PARAMETERS" value="" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
<option name="REDIRECT_INPUT" value="false" />
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<configuration name="test_replay_buffers" type="PythonConfigurationType" factoryName="Python" temporary="true">
<module name="gm-rl-rec" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
<env name="PYTHONUNBUFFERED" value="1" />
</envs>
<option name="SDK_HOME" value="" />
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
<option name="IS_MODULE_SDK" value="true" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/test_replay_buffers.py" />
<option name="PARAMETERS" value="" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
<option name="REDIRECT_INPUT" value="false" />
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<configuration name="tf_agents_test" type="PythonConfigurationType" factoryName="Python" temporary="true">
<module name="gm-rl-rec" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
<env name="PYTHONUNBUFFERED" value="1" />
</envs>
<option name="SDK_HOME" value="" />
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
<option name="IS_MODULE_SDK" value="true" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/tf_agents_test.py" />
<option name="PARAMETERS" value="" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
<option name="REDIRECT_INPUT" value="false" />
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<list>
<item itemvalue="Python.rl_CartPole_test" />
<item itemvalue="Python.test_drivers" />
<item itemvalue="Python.test_policies" />
<item itemvalue="Python.test_replay_buffers" />
<item itemvalue="Python.tf_agents_test" />
</list>
<recent_temporary>
<list>
<item itemvalue="Python.tf_agents_test" />
<item itemvalue="Python.test_policies" />
<item itemvalue="Python.rl_CartPole_test" />
<item itemvalue="Python.test_replay_buffers" />
<item itemvalue="Python.test_drivers" />
</list>
</recent_temporary>
</component>
<component name="SvnConfiguration">
<configuration />
</component>
<component name="TaskManager">
<task active="true" id="Default" summary="Default task">
<changelist id="8abf2ea5-ed71-4a23-b5cd-c1b9c2061907" name="Default Changelist" comment="" />
<created>1562123253186</created>
<option name="number" value="Default" />
<option name="presentableId" value="Default" />
<updated>1562123253186</updated>
</task>
<servers />
</component>
<component name="TodoView">
<todo-panel id="selected-file">
<is-autoscroll-to-source value="true" />
</todo-panel>
<todo-panel id="all">
<are-packages-shown value="true" />
<is-autoscroll-to-source value="true" />
</todo-panel>
</component>
<component name="ToolWindowManager">
<frame x="0" y="23" width="1360" height="877" extended-state="0" />
<layout>
<window_info active="true" content_ui="combo" id="Project" order="0" visible="true" weight="0.24943988" />
<window_info id="Structure" order="1" side_tool="true" weight="0.25" />
<window_info id="Favorites" order="2" side_tool="true" />
<window_info id="Project Explorer" order="3" />
<window_info id="Job Explorer" order="4" />
<window_info anchor="bottom" id="Message" order="0" />
<window_info anchor="bottom" id="Find" order="1" />
<window_info anchor="bottom" id="Run" order="2" visible="true" weight="0.41338584" />
<window_info anchor="bottom" id="Debug" order="3" weight="0.4527559" />
<window_info anchor="bottom" id="Cvs" order="4" weight="0.25" />
<window_info anchor="bottom" id="Inspection" order="5" weight="0.4" />
<window_info anchor="bottom" id="TODO" order="6" weight="0.3299363" />
<window_info anchor="bottom" id="Version Control" order="7" />
<window_info anchor="bottom" id="Terminal" order="8" weight="0.32939634" />
<window_info anchor="bottom" id="Event Log" order="9" side_tool="true" />
<window_info anchor="bottom" id="Python Console" order="10" weight="0.3299363" />
<window_info anchor="bottom" id="Console" order="11" />
<window_info anchor="right" id="Commander" internal_type="SLIDING" order="0" type="SLIDING" weight="0.4" />
<window_info anchor="right" id="Ant Build" order="1" weight="0.25" />
<window_info anchor="right" content_ui="combo" id="Hierarchy" order="2" weight="0.25" />
</layout>
</component>
<component name="XDebuggerManager">
<breakpoint-manager>
<breakpoints>
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
<url>file://$PROJECT_DIR$/test.py</url>
<line>7</line>
<option name="timeStamp" value="30" />
</line-breakpoint>
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
<url>file://$PROJECT_DIR$/test.py</url>
<line>8</line>
<option name="timeStamp" value="31" />
</line-breakpoint>
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
<url>file://$PROJECT_DIR$/test.py</url>
<line>11</line>
<option name="timeStamp" value="32" />
</line-breakpoint>
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
<url>file://$PROJECT_DIR$/test.py</url>
<line>13</line>
<option name="timeStamp" value="33" />
</line-breakpoint>
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
<url>file://$PROJECT_DIR$/test.py</url>
<line>6</line>
<option name="timeStamp" value="34" />
</line-breakpoint>
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
<url>file://$PROJECT_DIR$/tf_agents_test.py</url>
<line>168</line>
<option name="timeStamp" value="40" />
</line-breakpoint>
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
<url>file://$PROJECT_DIR$/tf_agents_test.py</url>
<line>298</line>
<option name="timeStamp" value="52" />
</line-breakpoint>
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
<url>file://$PROJECT_DIR$/tf_agents_test.py</url>
<line>304</line>
<option name="timeStamp" value="53" />
</line-breakpoint>
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
<url>file://$PROJECT_DIR$/tf_agents_test.py</url>
<line>342</line>
<option name="timeStamp" value="62" />
</line-breakpoint>
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
<url>file://$PROJECT_DIR$/tf_agents_test.py</url>
<line>343</line>
<option name="timeStamp" value="63" />
</line-breakpoint>
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
<url>file://$PROJECT_DIR$/tf_agents_test.py</url>
<line>344</line>
<option name="timeStamp" value="64" />
</line-breakpoint>
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
<url>file://$PROJECT_DIR$/tf_agents_test.py</url>
<line>345</line>
<option name="timeStamp" value="65" />
</line-breakpoint>
</breakpoints>
</breakpoint-manager>
</component>
<component name="editorHistoryManager">
<entry file="file://$USER_HOME$/Downloads/code/agents/venv/lib/python3.6/site-packages/gin/config.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="167">
<caret line="1008" selection-start-line="1008" selection-end-line="1008" />
</state>
</provider>
</entry>
<entry file="file://$USER_HOME$/Downloads/code/agents/venv/lib/python3.6/site-packages/tensorflow/python/eager/function.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="1037">
<caret line="1346" column="28" lean-forward="true" selection-start-line="1346" selection-start-column="28" selection-end-line="1346" selection-end-column="28" />
</state>
</provider>
</entry>
<entry file="file://$USER_HOME$/Downloads/code/agents/venv/lib/python3.6/site-packages/tf_agents/environments/utils.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="161">
<caret line="52" column="19" lean-forward="true" selection-start-line="52" selection-start-column="19" selection-end-line="52" selection-end-column="19" />
</state>
</provider>
</entry>
<entry file="file://$USER_HOME$/Downloads/code/agents/venv/lib/python3.6/site-packages/tf_agents/environments/py_environment.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="151">
<caret line="143" column="6" selection-start-line="143" selection-start-column="6" selection-end-line="143" selection-end-column="6" />
</state>
</provider>
</entry>
<entry file="file://$USER_HOME$/Downloads/code/agents/venv/lib/python3.6/site-packages/gym/envs/mujoco/mujoco_env.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="210">
<caret line="14" column="18" lean-forward="true" selection-start-line="14" selection-start-column="18" selection-end-line="14" selection-end-column="18" />
<folding>
<element signature="e#0#9#0" expanded="true" />
</folding>
</state>
</provider>
</entry>
<entry file="file://$USER_HOME$/Downloads/code/agents/venv/lib/python3.6/site-packages/tf_agents/environments/suite_mujoco.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="168">
<caret line="44" column="26" selection-start-line="44" selection-start-column="26" selection-end-line="44" selection-end-column="26" />
</state>
</provider>
</entry>
<entry file="file://$USER_HOME$/Downloads/code/agents/venv/lib/python3.6/site-packages/tf_agents/environments/wrappers.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="189">
<caret line="68" column="34" lean-forward="true" selection-start-line="68" selection-start-column="34" selection-end-line="68" selection-end-column="34" />
</state>
</provider>
</entry>
<entry file="file://$USER_HOME$/Library/Caches/PyCharmCE2019.1/python_stubs/941605742/numpy/random/mtrand.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="146">
<caret line="2024" column="4" selection-start-line="2024" selection-start-column="4" selection-end-line="2024" selection-end-column="4" />
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/gm_train_eval.py" />
<entry file="file://$USER_HOME$/Downloads/code/agents/venv/lib/python3.6/site-packages/tensorflow/python/framework/ops.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="176">
<caret line="5552" selection-start-line="5552" selection-end-line="5552" />
</state>
</provider>
</entry>
<entry file="file://$USER_HOME$/Downloads/code/agents/venv/lib/python3.6/site-packages/tensorflow/python/eager/monitoring.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="176">
<caret line="327" selection-start-line="327" selection-end-line="327" />
</state>
</provider>
</entry>
<entry file="file://$USER_HOME$/Downloads/code/agents/venv/lib/python3.6/site-packages/tensorflow/python/ops/variable_scope.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="191">
<caret line="233" selection-start-line="233" selection-end-line="233" />
</state>
</provider>
</entry>
<entry file="file://$USER_HOME$/Downloads/code/agents/venv/lib/python3.6/site-packages/tensorflow/python/compat/v2_compat.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="306">
<caret line="44" selection-start-line="44" selection-end-line="44" />
</state>
</provider>
</entry>
<entry file="file://$USER_HOME$/Downloads/code/agents/venv/lib/python3.6/site-packages/absl/app.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="-2689">
<caret line="253" column="40" lean-forward="true" selection-start-line="253" selection-start-column="40" selection-end-line="253" selection-end-column="40" />
</state>
</provider>
</entry>
<entry file="file://$USER_HOME$/Downloads/code/agents/venv/lib/python3.6/site-packages/absl/logging/__init__.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="236">
<caret line="280" selection-start-line="280" selection-end-line="280" />
</state>
</provider>
</entry>
<entry file="file://$USER_HOME$/Library/Caches/PyCharmCE2019.1/python_stubs/941605742/builtins.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="118">
<caret line="4164" column="8" selection-start-line="4164" selection-start-column="8" selection-end-line="4164" selection-end-column="8" />
</state>
</provider>
</entry>
<entry file="file://$USER_HOME$/Downloads/code/agents/venv/lib/python3.6/site-packages/absl/flags/_flagvalues.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="281">
<caret line="477" selection-start-line="477" selection-end-line="477" />
</state>
</provider>
</entry>
<entry file="file://$USER_HOME$/Downloads/code/agents/venv/lib/python3.6/site-packages/tensorflow/python/eager/context.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="168">
<caret line="1451" selection-start-line="1451" selection-end-line="1451" />
</state>
</provider>
</entry>
<entry file="file://$USER_HOME$/Downloads/code/agents/venv/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="246">
<caret line="427" selection-start-line="427" selection-end-line="427" />
</state>
</provider>
</entry>
<entry file="file://$USER_HOME$/Downloads/code/agents/venv/lib/python3.6/site-packages/tf_agents/policies/random_py_policy.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="179">
<caret line="57" column="33" lean-forward="true" selection-start-line="57" selection-start-column="33" selection-end-line="57" selection-end-column="33" />
</state>
</provider>
</entry>
<entry file="file://$USER_HOME$/Downloads/code/agents/venv/lib/python3.6/site-packages/tf_agents/drivers/dynamic_step_driver.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="440">
<caret line="131" column="57" selection-start-line="131" selection-start-column="57" selection-end-line="131" selection-end-column="57" />
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/replay_buffers_test.py" />
<entry file="file://$PROJECT_DIR$/test_environments.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="60">
<caret line="4" column="14" selection-start-line="4" selection-start-column="5" selection-end-line="4" selection-end-column="14" />
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/test.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="105">
<caret line="7" column="3" lean-forward="true" selection-start-line="7" selection-start-column="3" selection-end-line="7" selection-end-column="3" />
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/test_drivers.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="302">
<caret line="33" column="38" lean-forward="true" selection-start-line="33" selection-start-column="38" selection-end-line="33" selection-end-column="38" />
<folding>
<element signature="e#0#18#0" expanded="true" />
</folding>
</state>
</provider>
</entry>
<entry file="file://$USER_HOME$/Downloads/code/agents/venv/lib/python3.6/site-packages/tf_agents/trajectories/time_step.py">
<provider selected="true" editor-type-id="text-editor" />
</entry>
<entry file="file://$PROJECT_DIR$/test_replay_buffers.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="288">
<caret line="134" selection-start-line="134" selection-end-line="134" />
<folding>
<element signature="e#0#23#0" expanded="true" />
</folding>
</state>
</provider>
</entry>
<entry file="file://$USER_HOME$/Downloads/code/agents/venv/lib/python3.6/site-packages/tf_agents/environments/suite_gym.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="181">
<caret line="44" column="51" selection-start-line="44" selection-start-column="51" selection-end-line="44" selection-end-column="51" />
<folding>
<element signature="e#1093#1103#0" expanded="true" />
</folding>
</state>
</provider>
</entry>
<entry file="file://$USER_HOME$/Downloads/code/agents/venv/lib/python3.6/site-packages/tf_agents/utils/common.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="163">
<caret line="58" column="26" lean-forward="true" selection-start-line="58" selection-start-column="26" selection-end-line="58" selection-end-column="26" />
</state>
</provider>
</entry>
<entry file="file://$USER_HOME$/Downloads/code/agents/venv/lib/python3.6/site-packages/tf_agents/eval/metric_utils.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="76">
<caret line="119" column="17" selection-start-line="119" selection-start-column="4" selection-end-line="119" selection-end-column="17" />
</state>
</provider>
</entry>
<entry file="file://$USER_HOME$/Downloads/code/agents/venv/lib/python3.6/site-packages/torch/__init__.pyi">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="143">
<caret line="491" column="8" selection-start-line="491" selection-start-column="8" selection-end-line="491" selection-end-column="8" />
</state>
</provider>
</entry>
<entry file="file://$USER_HOME$/Downloads/code/agents/venv/lib/python3.6/site-packages/tensorflow/python/training/training_util.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="15">
<caret line="147" column="4" selection-start-line="147" selection-start-column="4" selection-end-line="147" selection-end-column="4" />
</state>
</provider>
</entry>
<entry file="file://$USER_HOME$/Downloads/code/agents/venv/lib/python3.6/site-packages/tf_agents/specs/__init__.py">
<provider selected="true" editor-type-id="text-editor">
<state>
<folding>
<element signature="e#753#801#0" expanded="true" />
</folding>
</state>
</provider>
</entry>
<entry file="file://$USER_HOME$/Downloads/code/agents/venv/lib/python3.6/site-packages/tf_agents/specs/array_spec.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="-215" />
</provider>
</entry>
<entry file="file://$USER_HOME$/Downloads/code/agents/venv/lib/python3.6/site-packages/tf_agents/environments/tf_environment.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="276">
<caret line="148" column="33" selection-start-line="148" selection-start-column="33" selection-end-line="148" selection-end-column="33" />
</state>
</provider>
</entry>
<entry file="file://$USER_HOME$/Downloads/code/agents/venv/lib/python3.6/site-packages/tf_agents/agents/tf_agent.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="162">
<caret line="219" column="6" selection-start-line="219" selection-start-column="6" selection-end-line="219" selection-end-column="6" />
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/rl_CartPole_test.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="150">
<caret line="10" column="33" lean-forward="true" selection-start-line="10" selection-start-column="33" selection-end-line="10" selection-end-column="33" />
<folding>
<element signature="e#0#44#0" expanded="true" />
</folding>
</state>
</provider>
</entry>
<entry file="file://$USER_HOME$/Downloads/code/agents/venv/lib/python3.6/site-packages/tf_agents/policies/py_policy.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="198">
<caret line="101" column="6" selection-start-line="101" selection-start-column="6" selection-end-line="101" selection-end-column="6" />
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/tf_agents_test.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="234">
<caret line="169" column="45" lean-forward="true" selection-start-line="169" selection-start-column="45" selection-end-line="169" selection-end-column="45" />
<folding>
<element signature="e#881#919#0" expanded="true" />
</folding>
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/test_policies.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="150">
<caret line="10" column="16" lean-forward="true" selection-start-line="10" selection-start-column="16" selection-end-line="10" selection-end-column="16" />
<folding>
<element signature="e#0#38#0" expanded="true" />
</folding>
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/gm_train_ddpg.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="42">
<caret line="56" column="23" lean-forward="true" selection-start-line="56" selection-start-column="23" selection-end-line="56" selection-end-column="23" />
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/rl_21_test.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="105">
<caret line="10" column="38" lean-forward="true" selection-start-line="10" selection-start-column="38" selection-end-line="10" selection-end-column="38" />
</state>
</provider>
</entry>
</component>
</project>
\ No newline at end of file
# coding=utf-8
# Copyright 2018 The TF-Agents Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Train and Eval DDPG.
To run:
```bash
tensorboard --logdir $HOME/tmp/ddpg/gym/HalfCheetah-v2/ --port 2223 &
python tf_agents/agents/ddpg/examples/v2/train_eval.py \
--root_dir=$HOME/tmp/ddpg/gym/HalfCheetah-v2/ \
--num_iterations=2000000 \
--alsologtostderr
```
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
from absl import app
from absl import flags
from absl import logging
import gin
import tensorflow as tf
from tf_agents.agents.ddpg import actor_network
from tf_agents.agents.ddpg import critic_network
from tf_agents.agents.ddpg import ddpg_agent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import parallel_py_environment
from tf_agents.environments import suite_mujoco
from tf_agents.environments import tf_py_environment
from tf_agents.environments import py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.utils import common
import abc
import tensorflow as tf
import numpy as np
from tf_agents.environments import py_environment
from tf_agents.environments import tf_environment
from tf_agents.environments import tf_py_environment
from tf_agents.environments import utils
from tf_agents.specs import array_spec
from tf_agents.environments import wrappers
from tf_agents.environments import suite_gym
from tf_agents.trajectories import time_step as ts
flags.DEFINE_string('root_dir', '$HOME/tmp/ddpg/gym/HalfCheetah-v2/',
'Root directory for writing logs/summaries/checkpoints.')
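# Note: os.path.expanduser() (used in train_eval below) only expands '~', not
# environment variables, so a literal '$HOME' in --root_dir is taken at face
# value; pass e.g. --root_dir=~/tmp/ddpg/gym/HalfCheetah-v2/ instead.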
flags.DEFINE_integer('num_iterations', 100000,
'Total number train/eval iterations to perform.')
flags.DEFINE_multi_string('gin_file', None, 'Paths to the gin-config files.')
flags.DEFINE_multi_string('gin_param', None, 'Gin binding parameters.')
FLAGS = flags.FLAGS
class CardGameEnv(py_environment.PyEnvironment):
def __init__(self):
self._action_spec = array_spec.BoundedArraySpec(
shape=(), dtype=np.float, minimum=0.0, maximum=1.0, name='action')
self._observation_spec = array_spec.BoundedArraySpec(
shape=(1,), dtype=np.float, minimum=0.0, name='observation')
self._state = 0.0
self._episode_ended = False
self._current_time_step = self._reset()
def action_spec(self):
return self._action_spec
def observation_spec(self):
return self._observation_spec
def _reset(self):
self._state = 0.0
self._episode_ended = False
return ts.restart(np.array([self._state], dtype=np.float))
def _step(self, action):
if self._episode_ended:
# The last action ended the episode. Ignore the current action and start
# a new episode.
return self.reset()
# Make sure episodes don't go on forever.
if action >= 1.0:
self._episode_ended = True
else:
# Any action below 1.0 draws another card; the spec bounds the action to [0, 1].
new_card = np.random.randint(1, 11)
self._state += new_card
if self._episode_ended or self._state >= 21:
reward = self._state - 21 if self._state <= 21 else -21
return ts.termination(np.array([self._state], dtype=np.float), reward)
else:
return ts.transition(
np.array([self._state], dtype=np.float), reward=0.0, discount=1.0)
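# The environment above is a blackjack-style card game reworked for a
# continuous-action agent: each draw adds a random card value (1-10) to the
# state, an action >= 1.0 ends the round, and the terminal reward is
# state - 21 (going past 21 scores -21).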
@gin.configurable
def train_eval(
root_dir,
env_name='',
env_load_fn=suite_mujoco.load,
num_iterations=2000000,
actor_fc_layers=(400, 300),
critic_obs_fc_layers=(400,),
critic_action_fc_layers=None,
critic_joint_fc_layers=(300,),
# Params for collect
initial_collect_steps=1,
collect_steps_per_iteration=1,
num_parallel_environments=1,
replay_buffer_capacity=100000,
ou_stddev=0.2,
ou_damping=0.15,
# Params for target update
target_update_tau=0.05,
target_update_period=5,
# Params for train
train_steps_per_iteration=1,
batch_size=64,
actor_learning_rate=1e-4,
critic_learning_rate=1e-3,
dqda_clipping=None,
td_errors_loss_fn=tf.compat.v1.losses.huber_loss,
gamma=0.995,
reward_scale_factor=1.0,
gradient_clipping=None,
use_tf_functions=True,
# Params for eval
num_eval_episodes=10,
eval_interval=10,
# Params for checkpoints, summaries, and logging
log_interval=10,
summary_interval=10,
summaries_flush_secs=10,
debug_summaries=False,
summarize_grads_and_vars=False,
eval_metrics_callback=None):
"""A simple train and eval for DDPG."""
# tensorboard log
root_dir = os.path.expanduser(root_dir)
train_dir = os.path.join(root_dir, 'train')
eval_dir = os.path.join(root_dir, 'eval')
train_summary_writer = tf.compat.v2.summary.create_file_writer(
train_dir, flush_millis=summaries_flush_secs * 1000)
train_summary_writer.set_as_default()
eval_summary_writer = tf.compat.v2.summary.create_file_writer(
eval_dir, flush_millis=summaries_flush_secs * 1000)
eval_metrics = [
tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
]
# initialize
env = CardGameEnv()
global_step = tf.compat.v1.train.get_or_create_global_step()
with tf.compat.v2.summary.record_if(
lambda: tf.math.equal(global_step % summary_interval, 0)):
tf_env = tf_py_environment.TFPyEnvironment(env)
eval_tf_env = tf_py_environment.TFPyEnvironment(env)
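# Both tf_env and eval_tf_env wrap the same CardGameEnv instance here, so
# collection and evaluation share the underlying Python environment state.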
actor_net = actor_network.ActorNetwork(
tf_env.time_step_spec().observation,
tf_env.action_spec(),
fc_layer_params=actor_fc_layers,
)
critic_net_input_specs = (tf_env.time_step_spec().observation,
tf_env.action_spec())
critic_net = critic_network.CriticNetwork(
critic_net_input_specs,
observation_fc_layer_params=critic_obs_fc_layers,
action_fc_layer_params=critic_action_fc_layers,
joint_fc_layer_params=critic_joint_fc_layers,
)
tf_agent = ddpg_agent.DdpgAgent(
tf_env.time_step_spec(),
tf_env.action_spec(),
actor_network=actor_net,
critic_network=critic_net,
actor_optimizer=tf.compat.v1.train.AdamOptimizer(
learning_rate=actor_learning_rate),
critic_optimizer=tf.compat.v1.train.AdamOptimizer(
learning_rate=critic_learning_rate),
ou_stddev=ou_stddev,
ou_damping=ou_damping,
target_update_tau=target_update_tau,
target_update_period=target_update_period,
dqda_clipping=dqda_clipping,
td_errors_loss_fn=td_errors_loss_fn,
gamma=gamma,
reward_scale_factor=reward_scale_factor,
gradient_clipping=gradient_clipping,
debug_summaries=debug_summaries,
summarize_grads_and_vars=summarize_grads_and_vars,
train_step_counter=global_step)
tf_agent.initialize()
train_metrics = [
tf_metrics.NumberOfEpisodes(),
tf_metrics.EnvironmentSteps(),
tf_metrics.AverageReturnMetric(),
tf_metrics.AverageEpisodeLengthMetric(),
]
eval_policy = tf_agent.policy
collect_policy = tf_agent.collect_policy
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
tf_agent.collect_data_spec,
batch_size=tf_env.batch_size,
max_length=replay_buffer_capacity)
initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
tf_env,
collect_policy,
observers=[replay_buffer.add_batch],
num_steps=initial_collect_steps)
collect_driver = dynamic_step_driver.DynamicStepDriver(
tf_env,
collect_policy,
observers=[replay_buffer.add_batch] + train_metrics,
num_steps=collect_steps_per_iteration)
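# Each driver.run() advances tf_env with the collect policy for the configured
# number of steps and hands every resulting trajectory to the observers: the
# replay buffer's add_batch, plus the train metrics for collect_driver.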
if use_tf_functions:
initial_collect_driver.run = common.function(initial_collect_driver.run)
collect_driver.run = common.function(collect_driver.run)
tf_agent.train = common.function(tf_agent.train)
# Collect initial replay data.
logging.info(
'Initializing replay buffer by collecting experience for %d steps with '
'the collect policy.', initial_collect_steps)
initial_collect_driver.run()
results = metric_utils.eager_compute(
eval_metrics,
eval_tf_env,
eval_policy,
num_episodes=num_eval_episodes,
train_step=global_step,
summary_writer=eval_summary_writer,
summary_prefix='Metrics',
)
if eval_metrics_callback is not None:
eval_metrics_callback(results, global_step.numpy())
metric_utils.log_metrics(eval_metrics)
time_step = None
policy_state = collect_policy.get_initial_state(tf_env.batch_size)
timed_at_step = global_step.numpy()
time_acc = 0
# Dataset generates trajectories with shape [Bx2x...]
dataset = replay_buffer.as_dataset(
num_parallel_calls=3,
sample_batch_size=batch_size,
num_steps=2).prefetch(3)
iterator = iter(dataset)
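# num_steps=2 samples pairs of adjacent time steps, which the DDPG agent
# consumes as (s_t, a_t, r_t, s_{t+1}) transitions when forming its TD targets;
# prefetch(3) keeps a few batches ready ahead of the training loop.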
# train and eval
for _ in range(num_iterations):
start_time = time.time()
time_step, policy_state = collect_driver.run(
time_step=time_step,
policy_state=policy_state,
)
for _ in range(train_steps_per_iteration):
experience, _ = next(iterator)
train_loss = tf_agent.train(experience)
action_step = eval_policy.action(time_step)
print("print eval action", "-" * 100)
print(action_step.action)
time_acc += time.time() - start_time
if global_step.numpy() % log_interval == 0:
logging.info('step = %d, loss = %f', global_step.numpy(),
train_loss.loss)
steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc
logging.info('%.3f steps/sec', steps_per_sec)
tf.compat.v2.summary.scalar(
name='global_steps_per_sec', data=steps_per_sec, step=global_step)
timed_at_step = global_step.numpy()
time_acc = 0
for train_metric in train_metrics:
train_metric.tf_summaries(
train_step=global_step, step_metrics=train_metrics[:2])
if global_step.numpy() % eval_interval == 0:
results = metric_utils.eager_compute(
eval_metrics,
eval_tf_env,
eval_policy,
num_episodes=num_eval_episodes,
train_step=global_step,
summary_writer=eval_summary_writer,
summary_prefix='Metrics',
)
if eval_metrics_callback is not None:
eval_metrics_callback(results, global_step.numpy())
metric_utils.log_metrics(eval_metrics)
return train_loss
def main(_):
tf.compat.v1.enable_v2_behavior()
logging.set_verbosity(logging.INFO)
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
train_eval(FLAGS.root_dir, num_iterations=FLAGS.num_iterations)
if __name__ == '__main__':
app.run(main)
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.networks import q_network
import tensorflow as tf
from tf_agents.agents.dqn import dqn_agent
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import trajectory
from tf_agents.utils import common
from tf_agents.environments import py_environment
from tf_agents.specs import array_spec
import numpy as np
from tf_agents.trajectories import time_step as ts
tf.compat.v1.enable_v2_behavior()
class CardGameEnv(py_environment.PyEnvironment):
def __init__(self):
self._action_spec = array_spec.BoundedArraySpec(
shape=(), dtype=np.int32, minimum=0, maximum=1, name='action')
self._observation_spec = array_spec.BoundedArraySpec(
shape=(1,), dtype=np.int32, minimum=0, name='observation')
self._state = 0
self._episode_ended = False
def action_spec(self):
return self._action_spec
def observation_spec(self):
return self._observation_spec
def _reset(self):
self._state = 0
self._episode_ended = False
return ts.restart(np.array([self._state], dtype=np.int32))
def _step(self, action):
if self._episode_ended:
# The last action ended the episode. Ignore the current action and start
# a new episode.
return self.reset()
# Make sure episodes don't go on forever.
if action == 1:
self._episode_ended = True
elif action == 0:
new_card = np.random.randint(1, 11)
# print("random card")
# print(new_card)
# print("state")
# print(self._state)
self._state += new_card
else:
raise ValueError('`action` should be 0 or 1.')
if self._episode_ended or self._state >= 21:
reward = self._state - 21 if self._state <= 21 else -21
return ts.termination(np.array([self._state], dtype=np.int32), reward)
else:
return ts.transition(
np.array([self._state], dtype=np.int32), reward=0.0, discount=1.0)
num_iterations = 3000 # @param
initial_collect_steps = 1000 # @param
collect_steps_per_iteration = 1 # @param
replay_buffer_capacity = 100000 # @param
fc_layer_params = (100,)
batch_size = 64 # @param
learning_rate = 1e-3 # @param
log_interval = 200 # @param
num_eval_episodes = 10 # @param
eval_interval = 1000 # @param
env = CardGameEnv()
env.reset()
print('Observation Spec:')
print(env.time_step_spec().observation)
print('Action Spec:')
print(env.action_spec())
time_step = env.reset()
print('Time step:')
print(time_step)
action = 1
next_time_step = env.step(action)
print('Next time step:')
print(next_time_step)
train_py_env = CardGameEnv()
eval_py_env = CardGameEnv()
train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)
#agent
q_net = q_network.QNetwork(
train_env.observation_spec(),
train_env.action_spec(),
fc_layer_params=fc_layer_params)
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
train_step_counter = tf.compat.v2.Variable(0)
tf_agent = dqn_agent.DqnAgent(
train_env.time_step_spec(),
train_env.action_spec(),
q_network=q_net,
optimizer=optimizer,
td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
train_step_counter=train_step_counter)
tf_agent.initialize()
#policy
eval_policy = tf_agent.policy
collect_policy = tf_agent.collect_policy
random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),
train_env.action_spec())
#Metrics and Evaluation
def compute_avg_return(environment, policy, num_episodes=10):
total_return = 0.0
for _ in range(num_episodes):
time_step = environment.reset()
episode_return = 0.0
while not time_step.is_last():
action_step = policy.action(time_step)
# print("eval action","-"*100)
# print(action_step.action)
time_step = environment.step(action_step.action)
episode_return += time_step.reward
total_return += episode_return
avg_return = total_return / num_episodes
return avg_return.numpy()[0]
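# environment here is a TFPyEnvironment, so rewards and returns come back as
# batched tensors; .numpy()[0] extracts the scalar return of the single
# batched environment.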
#Replay Buffer
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
data_spec=tf_agent.collect_data_spec,
batch_size=train_env.batch_size,
max_length=replay_buffer_capacity)
#Data Collection
def collect_step(environment, policy):
time_step = environment.current_time_step()
action_step = policy.action(time_step)
next_time_step = environment.step(action_step.action)
traj = trajectory.from_transition(time_step, action_step, next_time_step)
# Add trajectory to the replay buffer
replay_buffer.add_batch(traj)
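# trajectory.from_transition packs (time_step, action_step, next_time_step)
# into a Trajectory matching the agent's collect_data_spec; since train_env is
# a TFPyEnvironment its tensors already carry a batch dimension, which is what
# add_batch expects.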
for _ in range(initial_collect_steps):
collect_step(train_env, random_policy)
# Dataset generates trajectories with shape [Bx2x...]
dataset = replay_buffer.as_dataset(
num_parallel_calls=3, sample_batch_size=batch_size, num_steps=2).prefetch(3)
iterator = iter(dataset)
#Training the agent
tf_agent.train = common.function(tf_agent.train)
# Reset the train step
tf_agent.train_step_counter.assign(0)
# Evaluate the agent's policy once before training.
avg_return = compute_avg_return(eval_env, tf_agent.policy, num_eval_episodes)
returns = [avg_return]
# returns = []
for _ in range(num_iterations):
# Collect a few steps using collect_policy and save to the replay buffer.
for _ in range(collect_steps_per_iteration):
collect_step(train_env, tf_agent.collect_policy)
# Sample a batch of data from the buffer and update the agent's network.
experience, unused_info = next(iterator)
train_loss = tf_agent.train(experience)
step = tf_agent.train_step_counter.numpy()
if step % log_interval == 0:
print('step = {0}: loss = {1}'.format(step, train_loss.loss))
if step % eval_interval == 0:
avg_return = compute_avg_return(eval_env, tf_agent.policy, num_eval_episodes)
print('step = {0}: Average Return = {1}'.format(step, avg_return))
returns.append(avg_return)
#plots
steps = range(0, num_iterations + 1, eval_interval)
print("-"*100)
print(steps,returns)
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.networks import q_network
import tensorflow as tf
from tf_agents.agents.dqn import dqn_agent
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import trajectory
from tf_agents.utils import common
tf.compat.v1.enable_v2_behavior()
# CartPole-v0 Environment
env_name = 'CartPole-v0'
num_iterations = 1000 # @param
initial_collect_steps = 1000 # @param
collect_steps_per_iteration = 1 # @param
replay_buffer_capacity = 100000 # @param
fc_layer_params = (100,)
batch_size = 64 # @param
learning_rate = 1e-3 # @param
log_interval = 200 # @param
num_eval_episodes = 10 # @param
eval_interval = 1000 # @param
env = suite_gym.load(env_name)
env.reset()
print('Observation Spec:')
print(env.time_step_spec().observation)
print('Action Spec:')
print(env.action_spec())
time_step = env.reset()
print('Time step:')
print(time_step)
action = 1
next_time_step = env.step(action)
print('Next time step:')
print(next_time_step)
train_py_env = suite_gym.load(env_name)
eval_py_env = suite_gym.load(env_name)
train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)
#agent
q_net = q_network.QNetwork(
train_env.observation_spec(),
train_env.action_spec(),
fc_layer_params=fc_layer_params)
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
train_step_counter = tf.compat.v2.Variable(0)
tf_agent = dqn_agent.DqnAgent(
train_env.time_step_spec(),
train_env.action_spec(),
q_network=q_net,
optimizer=optimizer,
td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
train_step_counter=train_step_counter)
tf_agent.initialize()
#policy
eval_policy = tf_agent.policy
collect_policy = tf_agent.collect_policy
random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),
train_env.action_spec())
#Metrics and Evaluation
def compute_avg_return(environment, policy, num_episodes=10):
total_return = 0.0
for _ in range(num_episodes):
time_step = environment.reset()
episode_return = 0.0
while not time_step.is_last():
action_step = policy.action(time_step)
print("print eval action","-"*100)
print(action_step.action)
time_step = environment.step(action_step.action)
episode_return += time_step.reward
total_return += episode_return
avg_return = total_return / num_episodes
return avg_return.numpy()[0]
#Replay Buffer
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
data_spec=tf_agent.collect_data_spec,
batch_size=train_env.batch_size,
max_length=replay_buffer_capacity)
#Data Collection
def collect_step(environment, policy):
time_step = environment.current_time_step()
action_step = policy.action(time_step)
next_time_step = environment.step(action_step.action)
traj = trajectory.from_transition(time_step, action_step, next_time_step)
# Add trajectory to the replay buffer
replay_buffer.add_batch(traj)
for _ in range(initial_collect_steps):
collect_step(train_env, random_policy)
# Dataset generates trajectories with shape [Bx2x...]
dataset = replay_buffer.as_dataset(
num_parallel_calls=3, sample_batch_size=batch_size, num_steps=2).prefetch(3)
iterator = iter(dataset)
#Training the agent
tf_agent.train = common.function(tf_agent.train)
# Reset the train step
tf_agent.train_step_counter.assign(0)
# Evaluate the agent's policy once before training.
# avg_return = compute_avg_return(eval_env, tf_agent.policy, num_eval_episodes)
# returns = [avg_return]
# print("before training returns")
# print(returns)
returns = []
for _ in range(num_iterations):
# Collect a few steps using collect_policy and save to the replay buffer.
for _ in range(collect_steps_per_iteration):
collect_step(train_env, tf_agent.collect_policy)
# Sample a batch of data from the buffer and update the agent's network.
experience, unused_info = next(iterator)
train_loss = tf_agent.train(experience)
step = tf_agent.train_step_counter.numpy()
if step % log_interval == 0:
print('step = {0}: loss = {1}'.format(step, train_loss.loss))
if step % eval_interval == 0:
avg_return = compute_avg_return(eval_env, tf_agent.policy, num_eval_episodes)
print('step = {0}: Average Return = {1}'.format(step, avg_return))
returns.append(avg_return)
#plots
steps = range(0, num_iterations + 1, eval_interval)
print("-"*100)
print("after training returns")
print(returns)
def helloworld():
print("hello world")
print("hello world2")
print("hello world3")
print("hello world4")
a=1
b=2
c=3
for i in range(1,3):
print("i",i)
print(i)
a=1
b=2
c=3
helloworld()
a=1
b=2
c=3
\ No newline at end of file
import numpy as np
import tensorflow as tf
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.policies import random_tf_policy
from tf_agents.metrics import tf_metrics
from tf_agents.drivers import dynamic_episode_driver
tf.compat.v1.enable_v2_behavior()
#TensorFlow Drivers
env = suite_gym.load('CartPole-v0')
tf_env = tf_py_environment.TFPyEnvironment(env)
tf_policy = random_tf_policy.RandomTFPolicy(action_spec=tf_env.action_spec(),
time_step_spec=tf_env.time_step_spec())
num_episodes = tf_metrics.NumberOfEpisodes()
env_steps = tf_metrics.EnvironmentSteps()
observers = [num_episodes, env_steps]
driver = dynamic_episode_driver.DynamicEpisodeDriver(
tf_env, tf_policy, observers, num_episodes=2)
# Initial driver.run will reset the environment and initialize the policy.
final_time_step, policy_state = driver.run()
print('final_time_step', final_time_step)
print('Number of Steps: ', env_steps.result().numpy())
print('Number of Episodes: ', num_episodes.result().numpy())
# Continue running from previous state
final_time_step, _ = driver.run(final_time_step, policy_state)
print('final_time_step', final_time_step)
print('Number of Steps: ', env_steps.result().numpy())
print('Number of Episodes: ', num_episodes.result().numpy())
\ No newline at end of file
import tensorflow as tf
from tf_agents.environments import py_environment
from tf_agents.specs import array_spec
import numpy as np
from tf_agents.trajectories import time_step as ts
tf.compat.v1.enable_v2_behavior()
class CardGameEnv(py_environment.PyEnvironment):
def __init__(self):
self._action_spec = array_spec.BoundedArraySpec(
shape=(), dtype=np.int32, minimum=0, maximum=1, name='action')
self._observation_spec = array_spec.BoundedArraySpec(
shape=(1,), dtype=np.int32, minimum=0, name='observation')
self._state = 0
self._episode_ended = False
def action_spec(self):
return self._action_spec
def observation_spec(self):
return self._observation_spec
def _reset(self):
self._state = 0
self._episode_ended = False
return ts.restart(np.array([self._state], dtype=np.int32))
def _step(self, action):
if self._episode_ended:
# The last action ended the episode. Ignore the current action and start
# a new episode.
return self.reset()
# Make sure episodes don't go on forever.
if action == 1:
self._episode_ended = True
elif action == 0:
new_card = np.random.randint(1, 11)
self._state += new_card
else:
raise ValueError('`action` should be 0 or 1.')
if self._episode_ended or self._state >= 21:
reward = self._state - 21 if self._state <= 21 else -21
return ts.termination(np.array([self._state], dtype=np.int32), reward)
else:
return ts.transition(
np.array([self._state], dtype=np.int32), reward=0.0, discount=1.0)
get_new_card_action = 0
end_round_action = 1
environment = CardGameEnv()
time_step = environment.reset()
print(time_step)
cumulative_reward = time_step.reward
for _ in range(3):
time_step = environment.step(get_new_card_action)
print(time_step)
cumulative_reward += time_step.reward
time_step = environment.step(end_round_action)
print(time_step)
cumulative_reward += time_step.reward
print('Final Reward = ', cumulative_reward)
\ No newline at end of file
from tf_agents.specs import array_spec
from tf_agents.policies import random_py_policy
import numpy as np
from tf_agents.policies import scripted_py_policy
#Random Python Policy
action_spec = array_spec.BoundedArraySpec(shape=(1,), dtype=np.int32, minimum=0, maximum=10)
my_random_py_policy = random_py_policy.RandomPyPolicy(time_step_spec=None,
action_spec=action_spec)
time_step = None
action_step = my_random_py_policy.action(time_step)
print(action_step)
action_step = my_random_py_policy.action(time_step)
print(action_step)
print("*"*100)
#Scripted Python Policy
action_spec = array_spec.BoundedArraySpec((2,), np.int32, -10, 10)
action_script = [(1, np.array([5, 2], dtype=np.int32)),
(0, np.array([0, 0], dtype=np.int32)), # Setting `num_repeats` to 0 will skip this action.
(2, np.array([1, 2], dtype=np.int32)),
(1, np.array([3, 4], dtype=np.int32))]
my_scripted_py_policy = scripted_py_policy.ScriptedPyPolicy(
time_step_spec=None, action_spec=action_spec, action_script=action_script)
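# action_script is a list of (num_repeats, action) pairs; the scripted policy
# replays them in order and uses policy_state to remember its position, which
# is why each call below threads action_step.state back into the next action().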
policy_state = my_scripted_py_policy.get_initial_state()
time_step = None
print('Executing scripted policy...')
action_step = my_scripted_py_policy.action(time_step, policy_state)
print(action_step.action[0])
action_step= my_scripted_py_policy.action(time_step, action_step.state)
print(action_step.action[0])
action_step = my_scripted_py_policy.action(time_step, action_step.state)
print(action_step.action[0])
action_step = my_scripted_py_policy.action(time_step, action_step.state)
print(action_step.action[0])
print('Resetting my_scripted_py_policy...')
policy_state = my_scripted_py_policy.get_initial_state()
action_step = my_scripted_py_policy.action(time_step, policy_state)
print(action_step)
print("*"*100)
import tensorflow as tf
import numpy as np
from tf_agents import specs
from tf_agents.agents.dqn import dqn_agent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.networks import q_network
from tf_agents.replay_buffers import py_uniform_replay_buffer
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step
tf.compat.v1.enable_v2_behavior()
#Creating the buffer
data_spec = (
tf.TensorSpec([3], tf.float32, 'action'),
(
tf.TensorSpec([5], tf.float32, 'lidar'),
tf.TensorSpec([3, 2], tf.float32, 'camera')
)
)
batch_size = 32
max_length = 1000
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
data_spec,
batch_size=batch_size,
max_length=max_length)
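# The buffer keeps `batch_size` parallel rows, each holding up to `max_length`
# frames, so its total capacity here is 32 * 1000 frames.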
#Writing to the buffer
action = tf.constant(1 * np.ones(
data_spec[0].shape.as_list(), dtype=np.float32))
lidar = tf.constant(
2 * np.ones(data_spec[1][0].shape.as_list(), dtype=np.float32))
camera = tf.constant(
3 * np.ones(data_spec[1][1].shape.as_list(), dtype=np.float32))
values = (action, (lidar, camera))
values_batched = tf.nest.map_structure(lambda t: tf.stack([t] * batch_size),
values)
replay_buffer.add_batch(values_batched)
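# add_batch expects an outer batch dimension equal to `batch_size`, which is
# why each value was stacked 32 times above.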
#Reading from the buffer
# add more items to the buffer before reading
for _ in range(5):
replay_buffer.add_batch(values_batched)
# Get one sample from the replay buffer with batch size 10 and 1 timestep:
sample = replay_buffer.get_next(sample_batch_size=10, num_steps=1)
# Convert the replay buffer to a tf.data.Dataset and iterate through it
dataset = replay_buffer.as_dataset(
sample_batch_size=4,
num_steps=2)
iterator = iter(dataset)
print("Iterator trajectories:")
trajectories = []
for _ in range(3):
t, _ = next(iterator)
trajectories.append(t)
print(tf.nest.map_structure(lambda t: t.shape, trajectories))
# Read all elements in the replay buffer:
trajectories = replay_buffer.gather_all()
print("Trajectories from gather all:")
print(tf.nest.map_structure(lambda t: t.shape, trajectories))
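# Each dataset element above has shape [sample_batch_size, num_steps, ...]
# (here [4, 2, ...]); gather_all instead returns every frame currently stored,
# shaped [batch_size, frames_per_row, ...].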
#PyUniformReplayBuffer
replay_buffer_capacity = 1000*32 # same capacity as the TFUniformReplayBuffer
py_replay_buffer = py_uniform_replay_buffer.PyUniformReplayBuffer(
capacity=replay_buffer_capacity,
data_spec=tensor_spec.to_nest_array_spec(data_spec))
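# Unlike the TF buffer above, PyUniformReplayBuffer keeps its storage as numpy
# arrays in host memory, which makes it a natural fit for pure-Python drivers
# and policies rather than in-graph data collection.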
#Using replay buffers during training
#Data collection
env = suite_gym.load('CartPole-v0')
tf_env = tf_py_environment.TFPyEnvironment(env)
q_net = q_network.QNetwork(
tf_env.time_step_spec().observation,
tf_env.action_spec(),
fc_layer_params=(100,))
agent = dqn_agent.DqnAgent(
tf_env.time_step_spec(),
tf_env.action_spec(),
q_network=q_net,
optimizer=tf.compat.v1.train.AdamOptimizer(0.001))
replay_buffer_capacity = 1000
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
agent.collect_data_spec,
batch_size=tf_env.batch_size,
max_length=replay_buffer_capacity)
# Add an observer that adds to the replay buffer:
replay_observer = [replay_buffer.add_batch]
collect_steps_per_iteration = 10
collect_op = dynamic_step_driver.DynamicStepDriver(
tf_env,
agent.collect_policy,
observers=replay_observer,
num_steps=collect_steps_per_iteration).run()
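# Each run() call steps the environment `collect_steps_per_iteration` times and
# hands every resulting trajectory to the observers, i.e. writes it into the
# replay buffer via add_batch.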
#Reading data for a train step
# Read the replay buffer as a Dataset,
# read batches of 4 elements, each with 2 timesteps:
dataset = replay_buffer.as_dataset(
sample_batch_size=4,
num_steps=2)
iterator = iter(dataset)
num_train_steps = 10
for _ in range(num_train_steps):
trajectories, _ = next(iterator)
loss = agent.train(experience=trajectories)
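# A minimal sketch (assumption, not part of the snippet above): in practice
# collection and training are interleaved so fresh experience keeps flowing
# into the buffer. All arguments reuse the objects defined above.
collect_driver = dynamic_step_driver.DynamicStepDriver(
    tf_env,
    agent.collect_policy,
    observers=replay_observer,
    num_steps=collect_steps_per_iteration)
for _ in range(num_train_steps):
  collect_driver.run()                      # collect a few environment steps
  trajectories, _ = next(iterator)          # sample a [4, 2, ...] training batch
  loss = agent.train(experience=trajectories)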
# coding=utf-8
# Copyright 2018 The TF-Agents Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Train and Eval DDPG.
To run:
```bash
tensorboard --logdir $HOME/tmp/ddpg/gym/HalfCheetah-v2/ --port 2223 &
python tf_agents/agents/ddpg/examples/v2/train_eval.py \
--root_dir=$HOME/tmp/ddpg/gym/HalfCheetah-v2/ \
--num_iterations=2000000 \
--alsologtostderr
```
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
from absl import app
from absl import flags
from absl import logging
import gin
import tensorflow as tf
from tf_agents.agents.ddpg import actor_network
from tf_agents.agents.ddpg import critic_network
from tf_agents.agents.ddpg import ddpg_agent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import parallel_py_environment
from tf_agents.environments import suite_mujoco
from tf_agents.environments import tf_py_environment
from tf_agents.environments import py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.utils import common
import abc
import numpy as np
from tf_agents.environments import tf_environment
from tf_agents.environments import utils
from tf_agents.environments import wrappers
from tf_agents.environments import suite_gym
from tf_agents.specs import array_spec
from tf_agents.trajectories import time_step as ts
flags.DEFINE_string('root_dir', '$HOME/tmp/ddpg/gym/HalfCheetah-v2/',
'Root directory for writing logs/summaries/checkpoints.')
flags.DEFINE_integer('num_iterations', 100000,
'Total number train/eval iterations to perform.')
flags.DEFINE_multi_string('gin_file', None, 'Paths to the gin-config files.')
flags.DEFINE_multi_string('gin_param', None, 'Gin binding parameters.')
FLAGS = flags.FLAGS
class CardGameEnv(py_environment.PyEnvironment):
def __init__(self):
self._action_spec = array_spec.BoundedArraySpec(
            shape=(), dtype=np.float32, minimum=0.0, maximum=10.0, name='action')
self._observation_spec = array_spec.BoundedArraySpec(
            shape=(1,), dtype=np.float32, minimum=0.0, name='observation')
self._state = 0.0
self._episode_ended = False
self._current_time_step = self._reset()
def action_spec(self):
return self._action_spec
def observation_spec(self):
return self._observation_spec
def _reset(self):
self._state = 0.0
self._episode_ended = False
        return ts.restart(np.array([self._state], dtype=np.float32))
def _step(self, action):
if self._episode_ended:
# The last action ended the episode. Ignore the current action and start
# a new episode.
return self.reset()
        # Make sure episodes don't go on forever. The continuous action in
        # [0, 10] is thresholded: values >= 5.0 end the round, anything else
        # within the spec draws a new card.
        if action >= 5.0:
            self._episode_ended = True
        elif action >= 0.0:
            new_card = np.random.randint(1, 11)
            self._state += new_card
        else:
            raise ValueError('`action` should be within [0.0, 10.0].')
if self._episode_ended or self._state >= 21:
reward = self._state - 21 if self._state <= 21 else -21
            return ts.termination(np.array([self._state], dtype=np.float32), reward)
else:
return ts.transition(
                np.array([self._state], dtype=np.float32), reward=0.0, discount=1.0)
@gin.configurable
def train_eval(
root_dir,
env_name='',
env_load_fn=suite_mujoco.load,
num_iterations=2000000,
actor_fc_layers=(400, 300),
critic_obs_fc_layers=(400,),
critic_action_fc_layers=None,
critic_joint_fc_layers=(300,),
# Params for collect
initial_collect_steps=1,
collect_steps_per_iteration=1,
num_parallel_environments=1,
replay_buffer_capacity=100000,
ou_stddev=0.2,
ou_damping=0.15,
# Params for target update
target_update_tau=0.05,
target_update_period=5,
# Params for train
train_steps_per_iteration=1,
batch_size=64,
actor_learning_rate=1e-4,
critic_learning_rate=1e-3,
dqda_clipping=None,
td_errors_loss_fn=tf.compat.v1.losses.huber_loss,
gamma=0.995,
reward_scale_factor=1.0,
gradient_clipping=None,
use_tf_functions=True,
# Params for eval
num_eval_episodes=10,
eval_interval=10,
# Params for checkpoints, summaries, and logging
log_interval=10,
summary_interval=10,
summaries_flush_secs=10,
debug_summaries=False,
summarize_grads_and_vars=False,
eval_metrics_callback=None):
"""A simple train and eval for DDPG."""
# tensorboard log
  root_dir = os.path.expanduser(os.path.expandvars(root_dir))
train_dir = os.path.join(root_dir, 'train')
eval_dir = os.path.join(root_dir, 'eval')
train_summary_writer = tf.compat.v2.summary.create_file_writer(
train_dir, flush_millis=summaries_flush_secs * 1000)
train_summary_writer.set_as_default()
eval_summary_writer = tf.compat.v2.summary.create_file_writer(
eval_dir, flush_millis=summaries_flush_secs * 1000)
eval_metrics = [
tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
]
# initialize
env = CardGameEnv()
global_step = tf.compat.v1.train.get_or_create_global_step()
with tf.compat.v2.summary.record_if(
lambda: tf.math.equal(global_step % summary_interval, 0)):
tf_env = tf_py_environment.TFPyEnvironment(env)
eval_tf_env = tf_py_environment.TFPyEnvironment(env)
actor_net = actor_network.ActorNetwork(
tf_env.time_step_spec().observation,
tf_env.action_spec(),
fc_layer_params=actor_fc_layers,
)
critic_net_input_specs = (tf_env.time_step_spec().observation,
tf_env.action_spec())
critic_net = critic_network.CriticNetwork(
critic_net_input_specs,
observation_fc_layer_params=critic_obs_fc_layers,
action_fc_layer_params=critic_action_fc_layers,
joint_fc_layer_params=critic_joint_fc_layers,
)
tf_agent = ddpg_agent.DdpgAgent(
tf_env.time_step_spec(),
tf_env.action_spec(),
actor_network=actor_net,
critic_network=critic_net,
actor_optimizer=tf.compat.v1.train.AdamOptimizer(
learning_rate=actor_learning_rate),
critic_optimizer=tf.compat.v1.train.AdamOptimizer(
learning_rate=critic_learning_rate),
ou_stddev=ou_stddev,
ou_damping=ou_damping,
target_update_tau=target_update_tau,
target_update_period=target_update_period,
dqda_clipping=dqda_clipping,
td_errors_loss_fn=td_errors_loss_fn,
gamma=gamma,
reward_scale_factor=reward_scale_factor,
gradient_clipping=gradient_clipping,
debug_summaries=debug_summaries,
summarize_grads_and_vars=summarize_grads_and_vars,
train_step_counter=global_step)
tf_agent.initialize()
train_metrics = [
tf_metrics.NumberOfEpisodes(),
tf_metrics.EnvironmentSteps(),
tf_metrics.AverageReturnMetric(),
tf_metrics.AverageEpisodeLengthMetric(),
]
eval_policy = tf_agent.policy
collect_policy = tf_agent.collect_policy
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
tf_agent.collect_data_spec,
batch_size=tf_env.batch_size,
max_length=replay_buffer_capacity)
initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
tf_env,
collect_policy,
observers=[replay_buffer.add_batch],
num_steps=initial_collect_steps)
collect_driver = dynamic_step_driver.DynamicStepDriver(
tf_env,
collect_policy,
observers=[replay_buffer.add_batch] + train_metrics,
num_steps=collect_steps_per_iteration)
if use_tf_functions:
initial_collect_driver.run = common.function(initial_collect_driver.run)
collect_driver.run = common.function(collect_driver.run)
tf_agent.train = common.function(tf_agent.train)
# Collect initial replay data.
logging.info(
'Initializing replay buffer by collecting experience for %d steps with '
'a random policy.', initial_collect_steps)
initial_collect_driver.run()
results = metric_utils.eager_compute(
eval_metrics,
eval_tf_env,
eval_policy,
num_episodes=num_eval_episodes,
train_step=global_step,
summary_writer=eval_summary_writer,
summary_prefix='Metrics',
)
if eval_metrics_callback is not None:
eval_metrics_callback(results, global_step.numpy())
metric_utils.log_metrics(eval_metrics)
time_step = None
policy_state = collect_policy.get_initial_state(tf_env.batch_size)
timed_at_step = global_step.numpy()
time_acc = 0
# Dataset generates trajectories with shape [Bx2x...]
dataset = replay_buffer.as_dataset(
num_parallel_calls=3,
sample_batch_size=batch_size,
num_steps=2).prefetch(3)
iterator = iter(dataset)
# train and eval
for _ in range(num_iterations):
start_time = time.time()
time_step, policy_state = collect_driver.run(
time_step=time_step,
policy_state=policy_state,
)
for _ in range(train_steps_per_iteration):
experience, _ = next(iterator)
train_loss = tf_agent.train(experience)
action_step = eval_policy.action(time_step)
print("-" * 100)
print(int(action_step.action.numpy()[0]))
time_acc += time.time() - start_time
if global_step.numpy() % log_interval == 0:
logging.info('step = %d, loss = %f', global_step.numpy(),
train_loss.loss)
steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc
logging.info('%.3f steps/sec', steps_per_sec)
tf.compat.v2.summary.scalar(
name='global_steps_per_sec', data=steps_per_sec, step=global_step)
timed_at_step = global_step.numpy()
time_acc = 0
for train_metric in train_metrics:
train_metric.tf_summaries(
train_step=global_step, step_metrics=train_metrics[:2])
if global_step.numpy() % eval_interval == 0:
results = metric_utils.eager_compute(
eval_metrics,
eval_tf_env,
eval_policy,
num_episodes=num_eval_episodes,
train_step=global_step,
summary_writer=eval_summary_writer,
summary_prefix='Metrics',
)
if eval_metrics_callback is not None:
eval_metrics_callback(results, global_step.numpy())
metric_utils.log_metrics(eval_metrics)
return train_loss
def main(_):
tf.compat.v1.enable_v2_behavior()
logging.set_verbosity(logging.INFO)
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
train_eval(FLAGS.root_dir, num_iterations=FLAGS.num_iterations)
if __name__ == '__main__':
app.run(main)