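Thanks for the reply. Here is the code I am working on. It is the argument list of my `tfp.vi.fit_surrogate_posterior` call; the full call, including the `elbo_loss_curve = tfp.vi.fit_surrogate_posterior(` line, is visible in the traceback below: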
target_log_prob_fn=model.joint_distribution(
observed_time_series=df_train["coverage"]).log_prob,
surrogate_posterior=variational_posteriors,
optimizer=tf.optimizers.Adam(learning_rate=0.1),
num_steps=num_variational_steps,
jit_compile=True)
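The error output is long, but here it is in full. The op `__inference_run_jitted_minimize_26194` seems to be the culprit. It is the `tf.function(autograph=False, jit_compile=True)` wrapper around the minimization loop. The root message is that no XLA compiler is registered for the METAL platform, so XLA compilation on the Metal device appears to be what fails. What do you think?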
2022-06-27 10:58:32.006790: W tensorflow/core/framework/op_kernel.cc:1745] OP_REQUIRES failed at xla_ops.cc:296 : UNIMPLEMENTED: Could not find compiler for platform METAL: NOT_FOUND: could not find registered compiler for platform METAL -- check target linkage
---------------------------------------------------------------------------
UnimplementedError Traceback (most recent call last)
/Users/joseph/Downloads/user_model.ipynb Cell 16' in <cell line: 8>()
5 num_variational_steps = int(num_variational_steps)
7 # Build and optimize the variational loss function.
----> 8 elbo_loss_curve = tfp.vi.fit_surrogate_posterior(
9 target_log_prob_fn=model.joint_distribution(
10 observed_time_series=df_train["coverage"]).log_prob,
11 surrogate_posterior=variational_posteriors,
12 optimizer=tf.optimizers.Adam(learning_rate=0.1),
13 num_steps=num_variational_steps,
14 jit_compile=True)
16 fig, ax = plt.subplots(figsize=(12, 8))
17 ax.plot(elbo_loss_curve, marker='.')
File /opt/homebrew/Caskroom/miniforge/base/envs/mlp/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py:561, in deprecated_args.<locals>.deprecated_wrapper.<locals>.new_func(*args, **kwargs)
553 _PRINTED_WARNING[(func, arg_name)] = True
554 logging.warning(
555 'From %s: calling %s (from %s) with %s is deprecated and will '
556 'be removed %s.\nInstructions for updating:\n%s',
(...)
559 'in a future version' if date is None else ('after %s' % date),
560 instructions)
--> 561 return func(*args, **kwargs)
File /opt/homebrew/Caskroom/miniforge/base/envs/mlp/lib/python3.8/site-packages/tensorflow_probability/python/vi/optimization.py:751, in fit_surrogate_posterior(target_log_prob_fn, surrogate_posterior, optimizer, num_steps, convergence_criterion, trace_fn, variational_loss_fn, discrepancy_fn, sample_size, importance_sample_size, trainable_variables, jit_compile, seed, name)
744 def complete_variational_loss_fn(seed=None):
745 return variational_loss_fn(
746 target_log_prob_fn,
747 surrogate_posterior,
748 sample_size=sample_size,
749 seed=seed)
--> 751 return tfp_math.minimize(complete_variational_loss_fn,
752 num_steps=num_steps,
753 optimizer=optimizer,
754 convergence_criterion=convergence_criterion,
755 trace_fn=trace_fn,
756 trainable_variables=trainable_variables,
757 jit_compile=jit_compile,
758 seed=seed,
759 name=name)
File /opt/homebrew/Caskroom/miniforge/base/envs/mlp/lib/python3.8/site-packages/tensorflow_probability/python/math/minimize.py:610, in minimize(loss_fn, num_steps, optimizer, convergence_criterion, batch_convergence_reduce_fn, trainable_variables, trace_fn, return_full_length_trace, jit_compile, seed, name)
442 def minimize(loss_fn,
443 num_steps,
444 optimizer,
(...)
451 seed=None,
452 name='minimize'):
453 """Minimize a loss function using a provided optimizer.
454
455 Args:
(...)
608
609 """
--> 610 _, traced_values = _minimize_common(
611 num_steps=num_steps,
612 optimizer_step_fn=_make_stateful_optimizer_step_fn(
613 loss_fn=loss_fn,
614 optimizer=optimizer,
615 trainable_variables=trainable_variables),
616 initial_parameters=(),
617 initial_optimizer_state=(),
618 convergence_criterion=convergence_criterion,
619 batch_convergence_reduce_fn=batch_convergence_reduce_fn,
620 trace_fn=trace_fn,
621 return_full_length_trace=return_full_length_trace,
622 jit_compile=jit_compile,
623 seed=seed,
624 name=name)
625 return traced_values
File /opt/homebrew/Caskroom/miniforge/base/envs/mlp/lib/python3.8/site-packages/tensorflow_probability/python/math/minimize.py:134, in _minimize_common(num_steps, optimizer_step_fn, initial_parameters, initial_optimizer_state, convergence_criterion, batch_convergence_reduce_fn, trace_fn, return_full_length_trace, jit_compile, seed, name)
131 @tf.function(autograph=False, jit_compile=True)
132 def run_jitted_minimize():
133 return _minimize_common(**kwargs)
--> 134 return run_jitted_minimize()
136 # Main optimization routine.
137 with tf.name_scope(name) as name:
File /opt/homebrew/Caskroom/miniforge/base/envs/mlp/lib/python3.8/site-packages/tensorflow/python/util/traceback_utils.py:153, in filter_traceback.<locals>.error_handler(*args, **kwargs)
151 except Exception as e:
152 filtered_tb = _process_traceback_frames(e.__traceback__)
--> 153 raise e.with_traceback(filtered_tb) from None
154 finally:
155 del filtered_tb
File /opt/homebrew/Caskroom/miniforge/base/envs/mlp/lib/python3.8/site-packages/tensorflow/python/eager/execute.py:54, in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
52 try:
53 ctx.ensure_initialized()
---> 54 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
55 inputs, attrs, num_outputs)
56 except core._NotOkStatusException as e:
57 if name is not None:
UnimplementedError: Could not find compiler for platform METAL: NOT_FOUND: could not find registered compiler for platform METAL -- check target linkage [Op:__inference_run_jitted_minimize_26194]