Error when using saved weights to continue learning #23

Open
abhinavrai44 opened this issue Jun 1, 2017 · 0 comments
I am getting the following warning when I try to continue learning from saved weights; here I am loading the weights from a previously trained model.

{'warnflag': 1, 'task': 'STOP: TOTAL NO. of ITERATIONS EXCEEDS LIMIT', 'nit': 26, 'funcalls': 30}
got zero gradient. not updating
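
For context, that dict has the shape of the info dict returned by scipy.optimize.fmin_l_bfgs_b when it stops at its iteration cap (warnflag 1 means "too many iterations", a warning rather than a hard failure). A minimal sketch of how such a dict arises, assuming (my assumption, not confirmed in this issue) that the value-function fit here goes through that scipy routine:

import numpy as np
import scipy.optimize

# Deliberately tiny maxiter so the optimizer stops early and reports warnflag 1.
x0 = np.zeros(10)
x, f, opt_info = scipy.optimize.fmin_l_bfgs_b(
    scipy.optimize.rosen, x0, fprime=scipy.optimize.rosen_der, maxiter=2)
print opt_info['warnflag']  # 1 -> iteration limit reached
print opt_info['task']      # e.g. 'STOP: TOTAL NO. of ITERATIONS EXCEEDS LIMIT'

The 'nit' and 'funcalls' fields in the printed dict just report how many iterations and function evaluations ran before the limit was hit.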

This is the code that I am using:

if __name__ == "__main__":
    # Imports are omitted in this snippet; it assumes argparse, sys, logging,
    # h5py, cPickle, numpy as np, gym, tabulate, and the repository's helpers
    # (update_argument_parser, get_agent_cls, prepare_h5_file, etc.) are available.
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    update_argument_parser(parser, GENERAL_OPTIONS)
    parser.add_argument("--agent", required=True)
    parser.add_argument("--plot", action="store_true")
    parser.add_argument('--visualize', dest='visualize', action='store_true', default=False)
    args, _ = parser.parse_known_args([arg for arg in sys.argv[1:] if arg not in ('-h', '--help')])
    env = StandEnv(args.visualize)

    # Load the latest agent snapshot from a previously trained run
    hdf = h5py.File('a.h5', 'r')
    snapnames = hdf['agent_snapshots'].keys()
    snapname = snapnames[-1]
    agent = cPickle.loads(hdf['agent_snapshots'][snapname].value)
    agent.stochastic = False
    env_spec = env.spec

    agent_ctor = get_agent_cls(args.agent)
    update_argument_parser(parser, agent_ctor.options)
    args = parser.parse_args()

    args.timestep_limit = 200
    cfg = args.__dict__
    np.random.seed(args.seed)
    if args.use_hdf:
        hdf, diagnostics = prepare_h5_file(args)
    gym.logger.setLevel(logging.WARN)

    COUNTER = 0
    def callback(stats):
        global COUNTER
        COUNTER += 1
        # Print stats
        print "*********** Iteration %i ****************" % COUNTER
        print tabulate(filter(lambda (k,v): np.asarray(v).size == 1, stats.items())) #pylint: disable=W0110
        # Store to hdf5
        if args.use_hdf:
            for (stat, val) in stats.items():
                if np.asarray(val).ndim == 0:
                    diagnostics[stat].append(val)
                else:
                    assert val.ndim == 1
                    diagnostics[stat].extend(val)
            if args.snapshot_every and ((COUNTER % args.snapshot_every == 0) or (COUNTER == args.n_iter)):
                hdf['/agent_snapshots/%0.4i' % COUNTER] = np.array(cPickle.dumps(agent, -1))
        # Plot
        if args.plot:
            animate_rollout(env, agent, min(500, args.timestep_limit))

    run_policy_gradient_algorithm(env, agent, callback=callback, usercfg=cfg)

    if args.use_hdf:
        hdf['env_id'] = env_spec.id
        try: hdf['env'] = np.array(cPickle.dumps(env, -1))
        except Exception: print "failed to pickle env" #pylint: disable=W0703
    env.close()
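
As a side note, picking the snapshot with keys()[-1] relies on the group keys coming back in order. A small sketch that makes "latest snapshot" explicit, using the same 'a.h5' layout as above (the keys are the zero-padded iteration counters written by the callback, so a lexicographic sort is enough):

import h5py
import cPickle

hdf = h5py.File('a.h5', 'r')
snapnames = sorted(hdf['agent_snapshots'].keys())
print snapnames  # e.g. ['0005', '0010', ...]
agent = cPickle.loads(hdf['agent_snapshots'][snapnames[-1]].value)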