|
|
{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Example usage of Tempestpy" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "vscode": { "languageId": "plaintext" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "pygame 2.6.1 (SDL 2.28.4, Python 3.10.12)\n", "Hello from the pygame community. https://www.pygame.org/contribute.html\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2024-11-29 12:18:59.459471: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "2024-11-29 12:18:59.474489: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "2024-11-29 12:18:59.478811: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", "2024-11-29 12:18:59.488641: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", "2024-11-29 12:19:00.368388: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n", "error: XDG_RUNTIME_DIR not set in the environment.\n" ] } ], "source": [ "from sb3_contrib import MaskablePPO\n", "from sb3_contrib.common.wrappers import ActionMasker\n", "from stable_baselines3.common.logger import Logger, CSVOutputFormat, TensorBoardOutputFormat, HumanOutputFormat\n", "\n", "import gymnasium as gym\n", "\n", "from minigrid.core.actions import Actions\n", "from minigrid.core.constants import TILE_PIXELS\n", "from 
minigrid.wrappers import RGBImgObsWrapper, ImgObsWrapper\n", "\n", "import tempfile, datetime, shutil\n", "\n", "import time\n", "import os\n", "\n", "from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWrapper, expname, shield_needed, shielded_evaluation, create_shield_overlay_image\n", "from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecorderCallback, InfoCallback\n", "\n", "import os, sys\n", "from copy import deepcopy\n", "\n", "from PIL import Image" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "vscode": { "languageId": "plaintext" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Starting the training\n" ] }, { "data": { "image/jpeg": "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAEgAWADASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwDDooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooA8y1TVNQj1e9RL+6VFncKqzMABuPA5qp/a+p/wDQRu/+/wC3+NGr/wDIav8A/r4k/wDQjVOgC5/a+p/9BG7/AO/7f40f2vqf/QRu/wDv+3+NU6KANfVNU1CPV71Ev7pUWdwqrMwAG48Dmqn9r6n/ANBG7/7/ALf40av/AMhq/wD+viT
/ANCNU6ALn9r6n/0Ebv8A7/t/jR/a+p/9BG7/AO/7f41TooA19U1TUI9XvUS/ulRZ3CqszAAbjwOaqf2vqf8A0Ebv/v8At/jRq/8AyGr/AP6+JP8A0I1ToAuf2vqf/QRu/wDv+3+NH9r6n/0Ebv8A7/t/jVOigDX1TVNQj1e9RL+6VFncKqzMABuPA5qp/a+p/wDQRu/+/wC3+NGr/wDIav8A/r4k/wDQjVOgC5/a+p/9BG7/AO/7f40f2vqf/QRu/wDv+3+NU6KAPpaiiivzQ/IAooooA8A1TVNQj1e9RL+6VFncKqzMABuPA5qp/a+p/wDQRu/+/wC3+NGr/wDIav8A/r4k/wDQjVOv0s/Xy5/a+p/9BG7/AO/7f40f2vqf/QRu/wDv+3+NU6KANfVNU1CPV71Ev7pUWdwqrMwAG48Dmqn9r6n/ANBG7/7/ALf40av/AMhq/wD+viT/ANCNU6ALn9r6n/0Ebv8A7/t/jR/a+p/9BG7/AO/7f41TooA19U1TUI9XvUS/ulRZ3CqszAAbjwOaqf2vqf8A0Ebv/v8At/jRq/8AyGr/AP6+JP8A0I1ToAuf2vqf/QRu/wDv+3+NH9r6n/0Ebv8A7/t/jVOigDX1TVNQj1e9RL+6VFncKqzMABuPA5qp/a+p/wDQRu/+/wC3+NGr/wDIav8A/r4k/wDQjVOgC5/a+p/9BG7/AO/7f40f2vqf/QRu/wDv+3+NU6KAPZKKKKACiiigAooooAKKKKAPJ9X/AOQ1f/8AXxJ/6Eap1c1f/kNX/wD18Sf+hGqdABRRRQBc1f8A5DV//wBfEn/oRqnVzV/+Q1f/APXxJ/6Eap0AFFFFAFzV/wDkNX//AF8Sf+hGqdXNX/5DV/8A9fEn/oRqnQAUUUUAXNX/AOQ1f/8AXxJ/6Eap1c1f/kNX/wD18Sf+hGqdABRRRQB9LUUUV+aH5AFFFFAHzrq//Iav/wDr4k/9CNU6uav/AMhq/wD+viT/ANCNU6/Sz9fCiiigC5q//Iav/wDr4k/9CNU6uav/AMhq/wD+viT/ANCNU6ACiiigC5q//Iav/wDr4k/9CNU6uav/AMhq/wD+viT/ANCNU6ACiiigC5q//Iav/wDr4k/9CNU6uav/AMhq/wD+viT/ANCNU6ACiiigD2SiiigAooooAKKKKACiiigDzLVNU1CPV71Ev7pUWdwqrMwAG48Dmqn9r6n/ANBG7/7/ALf40av/AMhq/wD+viT/ANCNU6ALn9r6n/0Ebv8A7/t/jR/a+p/9BG7/AO/7f41TooA19U1TUI9XvUS/ulRZ3CqszAAbjwOaqf2vqf8A0Ebv/v8At/jRq/8AyGr/AP6+JP8A0I1ToAuf2vqf/QRu/wDv+3+NH9r6n/0Ebv8A7/t/jVOigDX1TVNQj1e9RL+6VFncKqzMABuPA5qp/a+p/wDQRu/+/wC3+NGr/wDIav8A/r4k/wDQjVOgC5/a+p/9BG7/AO/7f40f2vqf/QRu/wDv+3+NU6KANfVNU1CPV71Ev7pUWdwqrMwAG48Dmqn9r6n/ANBG7/7/ALf40av/AMhq/wD+viT/ANCNU6ALn9r6n/0Ebv8A7/t/jR/a+p/9BG7/AO/7f41TooA19U1TUI9XvUS/ulRZ3CqszAAbjwOaqf2vqf8A0Ebv/v8At/jRq/8AyGr/AP6+JP8A0I1ToAuf2vqf/QRu/wDv+3+NH9r6n/0Ebv8A7/t/jVOigDX1TVNQj1e9RL+6VFncKqzMABuPA5qp/a+p/wDQRu/+/wC3+NGr/wDIav8A/r4k/wDQjVOgC5/a+p/9BG7/AO/7f40f2vqf/QRu/wDv+3+NU6KANfVNU1CPV71Ev7pUWdwqrMwAG48Dmqn9r6n/ANBG7/7/ALf40av/AMhq/wD+viT/ANCNU6ALn9r6n/0Ebv8A7/t/jR/a+p/9BG7/AO/7f41TooA19U1TUI9XvUS/ulRZ3CqszAAbjwOaqf2vqf8A0Ebv/v8
At/jRq/8AyGr/AP6+JP8A0I1ToAuf2vqf/QRu/wDv+3+NH9r6n/0Ebv8A7/t/jVOigDX1TVNQj1e9RL+6VFncKqzMABuPA5qp/a+p/wDQRu/+/wC3+NGr/wDIav8A/r4k/wDQjVOgC5/a+p/9BG7/AO/7f40f2vqf/QRu/wDv+3+NU6KAPZKKKKACiiigAooooAKKKKAPJ9X/AOQ1f/8AXxJ/6Eap1c1f/kNX/wD18Sf+hGqdABRRRQBc1f8A5DV//wBfEn/oRqnVzV/+Q1f/APXxJ/6Eap0AFFFFAFzV/wDkNX//AF8Sf+hGqdXNX/5DV/8A9fEn/oRqnQAUUUUAXNX/AOQ1f/8AXxJ/6Eap1c1f/kNX/wD18Sf+hGqdABRRRQBc1f8A5DV//wBfEn/oRqnVzV/+Q1f/APXxJ/6Eap0AFFFFAFzV/wDkNX//AF8Sf+hGqdXNX/5DV/8A9fEn/oRqnQAUUUUAXNX/AOQ1f/8AXxJ/6Eap1c1f/kNX/wD18Sf+hGqdABRRRQBc1f8A5DV//wBfEn/oRqnVzV/+Q1f/APXxJ/6Eap0AFFFFAFzV/wDkNX//AF8Sf+hGqdXNX/5DV/8A9fEn/oRqnQAUUUUAeyUUUUAFFFFABRRRQAUUUUAeZapqmoR6veol/dKizuFVZmAA3Hgc1U/tfU/+gjd/9/2/xo1f/kNX/wD18Sf+hGqdAFz+19T/AOgjd/8Af9v8aP7X1P8A6CN3/wB/2/xqnRQBr6pqmoR6veol/dKizuFVZmAA3Hgc1U/tfU/+gjd/9/2/xo1f/kNX/wD18Sf+hGqdAFz+19 "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAAEgCAIAAAAFWz53AAAR8ElEQVR4Ae2d4XnUSBKGxzz8vksAXwC7CeBN4DYBNoLZADAJkMCZAJiLgAQggTMJ7AWwJoEjgb1PFm6Klrqnx6WR5NKrxw+0uqtKU2/1fJZa1szFfr/fsUEAAhAYI/BsrJM+CEAAAh0BBIJ5AAEIFAkgEEU0DEAAAggEcwACECgSQCCKaBiAAAQQCOYABCBQJIBAFNEwAAEIIBDMAQhAoEgAgSiiYQACEEAgmAMQgECRAAJRRMMABCCAQDAHIACBIgEEooiGAQhAAIFgDkAAAkUCCEQRDQMQgAACwRyAAASKBBCIIhoGIAABBII5AAEIFAkgEEU0DEAAAggEcwACECgSQCCKaBiAAAQQCOYABCBQJIBAFNEwAAEIIBDMAQhAoEgAgSiiYQACEEAgmAMQgECRAAJRRMMABCCAQDAHIACBIgEEooiGAQhAAIFgDkAAAkUCCEQRDQMQgMBzEMQmcHh/OGuC+9/3Z41P8GUJzCEQh8MZ5+h+vyd+bQ69rw1OMgb/CsYZ5mfl6P4hLjH8DIkAgbAEEIiwpSUxCPgJIBB+hkSAQFgCCETY0pIYBPwEEAg/QyJAICwBBCJsaUkMAn4CCISfIREgEJYAAhG2tCQGAT8BBMLPkAgQCEsAgQhbWhKDgJ8AAuFnSAQIhCWAQIQtLYlBwE8AgfAzJAIEwhJAIMKWlsQg4CeAQPgZEgECYQkgEGFLS2IQ8BNAIPwMiQCBsAQQiLClJTEI+AkgEH6GRIBAWAIIRNjSkhgE/AQQCD9DIkAgLAEEImxpSQwCfgIIhJ8hESAQlgACEba0JAYBPwEEws+QCBAISwCBCFtaEoOAnwAC4WdIBAiEJYBAhC0tiUHATwCB8DMkAgTCEkAgwpaWxCDgJ3Chryf3RyECBCAQksDzGbI6HA7nO4oEjvgVvPCpwNFQAD71BJ2jXGI4AeIOgcgEEIjI1SU3CDgJIBBOgLhDIDIBBCJydckNAk4CCIQTIO4QiEwAgYhcXXKDgJMAAuEEiDsEIhNAICJXl9wg4CSAQDgB4g6ByAQQiMjVJTcIOAkgEE6AuEMgMgEEInJ1yQ0CTgIIhBMg7hCITACBiFxdcoOAkwAC4QSIOwQiE0
AgIleX3CDgJIBAOAHiDoHIBBCIyNUlNwg4CSAQToC4QyAyAQQicnXJDQJOAgiEEyDuEIhMAIGIXF1yg4CTAALhBIg7BCITQCAiV5fcIOAkgEA4AeIOgcgEEIjI1SU3CDgJIBBOgLhDIDIBBCJydckNAk4CCIQTIO4QiEwAgYhcXXKDgJPAhb7+3Bkitvv7F4fYCW48u9+/MP9rU+B5bXCiscPhjO8xCdxZ479/OxEFwqyVwFnnz7nn57l/wXOJsdZpy+uCwAoIIBArKAIvAQJrJYBArLUyvC4IrIAAArGCIvASILBWAgjEWivD64LACgggECsoAi8BAmslgEA0Vebuf7vbu+OWsmk0U0C7Ed/SEMOZ+dij07YEEAhLo9h+82n3y793V4fi+19zWqOyaTRTQLsRv6eRMM7Mx9aCtiUwxx9K2eM90fbrq93d193nL937/+WL3c0/d1eX31LRnL7+1A1pe/G37t+jZoqggHYjvsU4Px9bC9qWwBx/av2k/1Ltr7ffcWWT+Ppqd3P7XRr0JlePNnW+u919+dq1NdetWSYunYXZiD8/H/2p9ZOen+f+S0oEwrxBx5pWIPpx+zZWj84akjTYAFYm1F+f+taR+JbGsD0tHwRiSNj2IBCWxkh7KBC9kaapThNeXn47axjxvO/qTjHuOgVJlyQly6yf+BmQbHcqPghEBjbbZQ0iA9K6qzd8y3u+u+j4cbmh8QDEr4M6N5/60bczyl2Mplpff9xd3nSLC5VNv9N0I6Nyp6P3VRCFUkC7ET/RWIRPOjqNjABnEBmQ4q4WHd986i4rhisOkoZ0I0P+wzsdfVBN/bR4OTwM8ZflM6wIPSLAGsSRaZDWIKwKpIVJ29kvQypcEou0MGmnfurMDmxDEX94+XYmPqxBZPMw20UgMiD5bhKIfsBO02Q6fM83mqUIqdHo2GiWwqZGo2OjWQqbGo2OjWYpbGo0OjaaKSwCkdiONrjEGMVS7NRvttt99/eU/WnCUBp6z0az4WEaHRvNiF8v05APPRkBBCID0rTbvz+PmjaaDeM0OjaaEX9IgJ5GAtzFaALVnTJ8zJ8gGnpqrUE/9U2PISmUAtqN+InGInzS0WlkBDiDyICM7+ruw4f/7t593r1+2d3FuPx7bmaXIUfvdMhBU19DCqJN9yz0+z9txBeKBfmkQtDICLBImQHJd/tFSjt3ZWFlwkqDliS0pQe30g3Rint/vIoB8YXofHxYpOxnYOlfBKJE5lu/vYuRTdNXP+1uv3x/KCs94pmWMBVCNyyvXnRnH/1mleVbl/mP+PPzQSDMBBxpIhAjUGyXFYi+P3sbl25kWJmQY33q2yMS39IYtqflg0AMCdseBMLSGGkPBaI36qfpq59/WEoY+ksmPvwxvmwxNLY9xLc0hu2p+CAQQ7a2h0VKS+OEttYpb349bq+VSLsYedzhwYL4DyTG/z83n/Gjbq+X25zbqzkZQ6CZAALRjApDCGyPAAKxvZqTMQSaCSAQzagwhMD2CMxxF+NJU33/4vCkXz8vvk5AdzHqBhsfneMuxpP+1OD3bzc+Q+Kn/6Tn57k/1ZpLjPhvADKEwKMJIBCPRocjBOIT2JpA6Gmq+weq4leWDCEwAYGtCcRrPRUxATZCQGAbBLYmEK92O/2wQQACTQQ2JRA6d7i8/+EkomlyYASBTQlEOndIjdYJoGcH9Vzm0U02jWYKaDfiWxpiODMfe3TalsB2BOKl+Q48fRmedk/Y9JU5+jqcyrdmaU5rVDaNZgpoN+L3NBLGmfnYWtC2BOb4Qyl7vOXa2VmDdu8/HLLtBb2+2t197T5LbvitWZrT6Zty9PlR2o6a6TNmFNBuxLcY5+dja0HbEpjjT61X8JdqurU5vELQesQXy2K0bT8wJpvE+mJefWDk8BMos0+RtGalj5/qD038+fnwgTGj0z51bkQgtCp5k3J+aFzvdu8e2sX/rUD0RvZtrJ70NXlZCCsTGqpPfetLfEtj2J6WDwIxJG
x7NiIQf97fvLCJq61zin9kXcPdoUD0Npqm+qz6l5c7nSBUtu4U4667oDj1c6WIX6Gqoan4IBB1zltYg9B6pK4mhps6NXTCSoQN0fhZcp18VBXExrRt4lsaw/a5+QyPuM2eLdzF0KVEaasM/eCi78K6vDnyrVn6naYbGZU7HX1EnVMolALajfiJxiJ80tFpZATCn0FoeTK7f2EJaEgGx5cq5aPvwnrzqbus0PVCdlkhaUg3MmQ5vNPRHzJblbCvg/gisCyfrBzs9gTCr0H8a7ernyZo8fJNZTakNQirAmlh0nb2y5AKlcQiLUzaqZ86s4PaUMQfLtmciQ9rENk8zHbDC8SfhQWIxOHIUmUSiN7BTtMUYviebzRLEVKj0bHRLIVNjUbHRrMUNjUaHRvNUtjUaHRsNFNYBCKxHW3EvsTQFYRWIuubDGT2oW6URvWb7XbfLaH3pwlDaegtG81S2NRodGw0S2FTo9Gx0SyFTY1Gx0azFDY1Gh0bzVJYGiUCsQWi8aEsmbUKRM+xn38lpqm/0SzZp0ajY6NZCpsajY6NZilsajQ6NpqlsKnR6NholsLSGBJ4NuyK0qNbmI03GGUm49rWnTJ8zJ8gGjporUE/9U2PISmUAtqN+InGInzS0WlkBAKfQVRuXmQQtCvj2h9E6OaFvqH73efid/DaZcjROx06hqa+hhREm+6J6Pdb2ogvFAvySYWgkREIvEj5V5bqsd2LUYN+kdLOXZnZr+q20qAlCW3DpzMq7v1BKwbEF6Lz8WGRsp+BpX+jCsTowxclCH2/7oaOPJrRC0RvkU3TVz/tbr905wLa7GplWsJUv25YXr3ozj76zSrLty7zH/Hn54NAmAk40owqEP9pXoBIULR48EvaSQ0rEH1n9ja20pC81LAyod361LeOxLc "text/plain": [ "<PIL.Image.Image image mode=RGB size=352x288>" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "Computing new shield\n", "LOG: Starting with explicit model creation...\n" ] } ], "source": [ "GRID_TO_PRISM_BINARY=os.getenv(\"M2P_BINARY\")\n", "\n", "def mask_fn(env: gym.Env):\n", " return env.create_action_mask()\n", "\n", "def nomask_fn(env: gym.Env):\n", " return [1.0] * 7\n", "\n", "def main():\n", " #env = \"MiniGrid-LavaSlipperyCliff-16x13-Slip10-Time-v0\"\n", " env = \"MiniGrid-WindyCity-Adv-v0\"\n", "\n", " formula = \"Pmax=? [G ! 
AgentIsOnLava]\"\n", " value_for_training = 0.99\n", " shield_comparison = \"absolute\"\n", " shielding = ShieldingConfig.Training\n", " # Shield thresholds that are computed and visualised; value_for_training must be one of them.\n", " shield_values = [0.9, 0.95, 0.99, 1.0]\n", "\n", " logger = Logger(\"/tmp\", output_formats=[HumanOutputFormat(sys.stdout)])\n", "\n", " env = gym.make(env, render_mode=\"rgb_array\")\n", " # image_env (tile size TILE_PIXELS) is only used to draw the shield overlays;\n", " # the environment that is trained on uses tile size 8.\n", " image_env = RGBImgObsWrapper(env, TILE_PIXELS)\n", " env = RGBImgObsWrapper(env, 8)\n", " env = ImgObsWrapper(env)\n", " env = MiniWrapper(env)\n", "\n", " env.reset()\n", " Image.fromarray(env.render()).show()\n", "\n", " shield_handlers = dict()\n", " if shield_needed(shielding):\n", "  # Only create the handlers here. The env is wrapped exactly once below, with the\n", "  # shield chosen for training; wrapping it inside this loop stacked one extra\n", "  # shielding wrapper per threshold underneath the training wrapper.\n", "  for value in shield_values:\n", "   shield_handlers[value] = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, \"grid.txt\", \"grid.prism\", formula, shield_value=value, shield_comparison=shield_comparison, nocleanup=True, prism_file=None)\n", "  for value in shield_values:\n", "   # Print the label before drawing, like the \"Training with shield:\" case below.\n", "   print(f\"The shield for shield_value = {value}\")\n", "   create_shield_overlay_image(image_env, shield_handlers[value].create_shield())\n", "\n", " if shielding == ShieldingConfig.Training:\n", "  if value_for_training not in shield_handlers:\n", "   raise ValueError(f\"No shield handler for value_for_training = {value_for_training}\")\n", "  env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handlers[value_for_training], create_shield_at_reset=False)\n", "  env = ActionMasker(env, mask_fn)\n", "  print(\"Training with shield:\")\n", "  create_shield_overlay_image(image_env, shield_handlers[value_for_training].create_shield())\n", " elif shielding == ShieldingConfig.Disabled:\n", "  env = ActionMasker(env, nomask_fn)\n", " else:\n", "  # Raise instead of assert(False): assert statements are removed under python -O.\n", "  raise ValueError(f\"Unsupported shielding configuration: {shielding}\")\n", "\n", " model = MaskablePPO(\"CnnPolicy\", env, verbose=1, device=\"auto\")\n", " model.set_logger(logger)\n", " steps = 200\n", "\n", " model.learn(steps, callback=[InfoCallback()])\n", "\n", "\n", "if __name__ == '__main__':\n", " print(\"Starting the training\")\n", " main()" ] }, { "cell_type": 
"code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 4 }
|